├── README.md
├── benchmarking
│   └── benchmark_synthetic_nerf.sh
├── config
│   ├── img
│   │   └── config.json
│   ├── nerf
│   │   └── config.json
│   └── sdf
│       └── config.json
├── datasets
│   ├── __init__.py
│   ├── img
│   │   └── imager.py
│   ├── nerf
│   │   ├── base.py
│   │   ├── colmap.py
│   │   ├── colmap_utils.py
│   │   ├── color_utils.py
│   │   ├── depth_utils.py
│   │   ├── nerf.py
│   │   ├── nerfpp.py
│   │   ├── nsvf.py
│   │   ├── ray_utils.py
│   │   └── rtmv.py
│   └── sdf
│       └── sampler.py
├── docs
│   └── figures
│       ├── 2d_fitting.png
│       ├── 3d_fitting.png
│       ├── nvs.png
│       └── teaser.png
├── models
│   ├── __init__.py
│   ├── csrc
│   │   ├── binding.cpp
│   │   ├── include
│   │   │   ├── helper_math.h
│   │   │   └── utils.h
│   │   ├── intersection.cu
│   │   ├── losses.cu
│   │   ├── raymarching.cu
│   │   ├── setup.py
│   │   └── volumerendering.cu
│   ├── loss
│   │   └── nerf
│   │       ├── __init__.py
│   │       └── losses.py
│   └── networks
│       ├── FFB_encoder.py
│       ├── Sine.py
│       ├── __init__.py
│       ├── img
│       │   ├── NFFB_2d.py
│       │   └── __init__.py
│       ├── nerf
│       │   ├── NFFB_nerf.py
│       │   ├── __init__.py
│       │   ├── custom_functions.py
│       │   └── rendering.py
│       └── sdf
│           ├── NFFB_3d.py
│           └── __init__.py
├── requirements.txt
├── scripts
│   ├── img
│   │   ├── common.py
│   │   ├── opt.py
│   │   └── utils.py
│   ├── nvs
│   │   ├── opt.py
│   │   └── prepare_rtmv.py
│   └── sdf
│       ├── opt.py
│       └── utils.py
├── train_img.py
├── train_nerf.py
├── train_sdf.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Fourier Filter Bank
2 |
3 |
4 | This repository contains the code (in [PyTorch Lightning](https://www.pytorchlightning.ai/index.html)) for the paper:
5 |
6 | [__Neural Fourier Filter Bank__](https://arxiv.org/abs/2212.01735)
7 |
8 | [Zhijie Wu](https://zhijiew94.github.io/), [Yuhe Jin](https://scholar.google.ca/citations?user=oAYi1YQAAAAJ&hl=en), [Kwang Moo Yi](https://www.cs.ubc.ca/~kmyi/)
9 |
10 | CVPR 2023
11 |
12 |
13 | ## Introduction
14 |
15 | In this project, we propose to learn a neural field by decomposing the signal both spatially and frequency-wise.
16 | We follow the grid-based paradigm for spatial decomposition but, unlike existing work, encourage specific frequencies to be stored in each grid via Fourier feature encodings.
17 | A multi-layer perceptron with sine activations then takes these Fourier-encoded features in at the appropriate layers, so that higher-frequency components are accumulated on top of lower-frequency ones, and the per-layer outputs are summed to form the final result.
18 | We evaluate on the tasks of 2D image fitting, 3D shape reconstruction, and neural radiance fields.
19 | All results are measured on an NVIDIA RTX 3090.
20 |
21 | If you have any questions, please feel free to contact Zhijie Wu (wzj.micker@gmail.com).
22 |
23 | 
24 |
25 |
26 |
27 | ## Key Requirements
28 | - Python 3.8
29 | - CUDA 11.6
30 | - [PyTorch 1.12.0](https://pytorch.org/)
31 | - [PyTorch Lightning](https://www.pytorchlightning.ai/index.html)
32 | - [torch-scatter](https://github.com/rusty1s/pytorch_scatter#installation)
33 | - [apex](https://github.com/NVIDIA/apex#linux)
34 | - [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn#pytorch-extension)
35 | - Install requirements by `pip install -r requirements.txt`
36 |
37 | > Note: Our current implementations are heavily based on the [ngp-pl](https://github.com/kwea123/ngp_pl) repo.
38 | > For further details, please also refer to their codebase.
39 |
40 |
41 |
42 |
43 | ## Novel View Synthesis
44 | 
45 |
46 | A quickstart:
47 | ```bash
48 | python train_nerf.py --root_dir <path/to/Synthetic_NeRF/Lego> --exp_name Lego --num_epochs 30 --lr 2e-2 --eval_lpips --no_save_test
49 | ```
50 | This trains the Lego scene for 30k steps. `--no_save_test` disables saving the synthesized test images.
51 |
52 | More options can be found in `scripts/nvs/opt.py` and in `config/nerf/config.json`.
53 |
54 | To compute the metrics for the eight Blender scenes, please run the script `benchmark_synthetic_nerf.sh` under the folder `benchmarking`.
55 |
56 |
57 | ## 2D Image Fitting
58 | 
59 |
60 | ```bash
61 | python train_img.py --config config/img/config.json --input_path <path/to/image>
62 | ```
63 |
64 | By default, the model is trained for 50k iterations, but in our experience it already reaches comparable results after roughly 20k iterations.
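
For reference, here is a rough sketch of how the training batches are produced, based on `datasets/img/imager.py`. The image loading and the `batch_size=None` loader below are our own illustration and may differ from what `train_img.py` actually does; `photo.jpg` is a placeholder path.

```python
import imageio
import numpy as np
import torch
from torch.utils.data import DataLoader

from datasets.img.imager import ImageDataset

# "photo.jpg" is a placeholder path; any RGB image works.
img = torch.from_numpy(np.asarray(imageio.imread("photo.jpg"), dtype=np.float32) / 255.0)  # (H, W, 3)

# Each __getitem__ call returns 2^18 random continuous coordinates together with
# bilinearly interpolated ground-truth colors; `size` sets the number of such batches per epoch.
train_set = ImageDataset(img, size=1000, num_samples=2**18, split="train")
loader = DataLoader(train_set, batch_size=None)  # each item is already a full batch

batch = next(iter(loader))
print(batch["points"].shape, batch["rgbs"].shape)  # (262144, 2), (262144, 3)
```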
65 |
66 |
67 | ## 3D Shape Fitting
68 | 
69 |
70 | ```bash
71 | python train_sdf.py --config config/sdf/config.json --input_path <path/to/mesh>
72 | ```
73 |
74 | As in **2D Image Fitting**, the model is trained for 50k iterations to capture finer geometric details. Using `size=100` instead of `size=1000` for the train_dataset in `train_sdf.py` slightly reduces output quality but significantly speeds up training.
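
For reference, `size` is simply the number of sampled batches per epoch in `datasets/sdf/sampler.py`. A rough usage sketch follows (our own illustration; `bunny.obj` is a placeholder mesh path and the loader settings may differ from `train_sdf.py`):

```python
from torch.utils.data import DataLoader

from datasets.sdf.sampler import SDFDataset

# Each __getitem__ call draws 2^18 points (2/3 on or near the surface, 1/3 uniform
# in [-1, 1]^3) together with their signed distances; `size` is the number of such
# batches per epoch.
train_set = SDFDataset("bunny.obj", size=100, num_samples=2**18, clip_sdf=None)
loader = DataLoader(train_set, batch_size=1, num_workers=0)

batch = next(iter(loader))
print(batch["points"].shape, batch["sdfs"].shape)  # (1, 262144, 3), (1, 262144, 1)
```

Reducing `size` shortens each epoch, so for the same number of epochs the network sees fewer SDF samples overall, which is where the speed/quality trade-off comes from.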
75 |
76 | ## Citation and License
77 |
78 | ```
79 | @InProceedings{Wu_2023_CVPR,
80 | author = {Wu, Zhijie and Jin, Yuhe and Yi, Kwang Moo},
81 | title = {Neural Fourier Filter Bank},
82 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
83 | month = {June},
84 | year = {2023},
85 | pages = {14153-14163}
86 | }
87 | ```
88 |
89 | Our codebase is under the MIT License.
90 |
91 |
92 | ## TODO
93 |
94 | - [ ] Finish the CUDA version
95 |
96 |
--------------------------------------------------------------------------------
/benchmarking/benchmark_synthetic_nerf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export ROOT_DIR=NeRF/Synthetic_NeRF
4 |
5 | python train_nerf.py \
6 | --root_dir $ROOT_DIR/Chair \
7 | --exp_name Chair --no_save_test \
8 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips
9 |
10 | python train_nerf.py \
11 | --root_dir $ROOT_DIR/Drums \
12 | --exp_name Drums --no_save_test \
13 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips
14 |
15 | python train_nerf.py \
16 | --root_dir $ROOT_DIR/Ficus \
17 | --exp_name Ficus --no_save_test \
18 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips
19 |
20 | python train_nerf.py \
21 | --root_dir $ROOT_DIR/Hotdog \
22 | --exp_name Hotdog --no_save_test \
23 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips
24 |
25 | python train_nerf.py \
26 | --root_dir $ROOT_DIR/Lego \
27 | --exp_name Lego --no_save_test \
28 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips
29 |
30 | python train_nerf.py \
31 | --root_dir $ROOT_DIR/Materials \
32 | --exp_name Materials --no_save_test \
33 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips
34 |
35 | python train_nerf.py \
36 | --root_dir $ROOT_DIR/Mic \
37 | --exp_name Mic --no_save_test \
38 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips
39 |
40 | python train_nerf.py \
41 | --root_dir $ROOT_DIR/Ship \
42 | --exp_name Ship --no_save_test \
43 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips
44 |
--------------------------------------------------------------------------------
/config/img/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "network": {
3 | "encoding": {
4 | "feat_dim": 2,
5 | "base_resolution": 96,
6 | "per_level_scale": 1.5,
7 | "base_sigma": 5.0,
8 | "exp_sigma": 2.0,
9 | "grid_embedding_std": 0.01
10 | },
11 | "SIREN": {
12 | "dims" : [128, 128, 128, 128, 128, 128, 128, 128],
13 | "w0": 100.0,
14 | "w1": 100.0,
15 | "size_factor": 1
16 | },
17 | "Backbone": {
18 | "dims": [64, 64]
19 | }
20 | },
21 | "training": {
22 | "LR_scheduler" : [
23 | {
24 | "type" : "Step",
25 | "initial" : 0.0001,
26 | "interval" : 5,
27 | "factor" : 0.5
28 | }]
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/config/nerf/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "network": {
3 | "encoding": {
4 | "feat_dim": 2,
5 | "base_resolution": 64,
6 | "per_level_scale": 2.0,
7 | "base_sigma": 8.0,
8 | "exp_sigma": 1.5,
9 | "grid_embedding_std": 0.001
10 | },
11 | "SIREN": {
12 | "dims" : [128, 128, 128, 128, 128],
13 | "w0": 15.0,
14 | "w1": 25.0,
15 | "size_factor": 2
16 | }
17 | },
18 | "training": {
19 | "LearningRateSchedule" : [
20 | {
21 | "type" : "Step",
22 | "initial" : 0.0001,
23 | "interval" : 5000,
24 | "factor" : 0.5
25 | },
26 | {
27 | "type" : "Step",
28 | "initial" : 0.0001,
29 | "interval" : 5000,
30 | "factor" : 0.5
31 | },
32 | {
33 | "type" : "Step",
34 | "initial" : 0.001,
35 | "interval" : 5000,
36 | "factor" : 0.5
37 | },
38 | {
39 | "type" : "Step",
40 | "initial" : 0.005,
41 | "interval" : 5000,
42 | "factor" : 0.5
43 | }],
44 | "lr_threshold": 1e-5
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/config/sdf/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "network": {
3 | "encoding": {
4 | "feat_dim": 2,
5 | "base_resolution": 8,
6 | "per_level_scale": 1.3,
7 | "base_sigma": 5.0,
8 | "exp_sigma": 1.2,
9 | "grid_embedding_std": 0.01
10 | },
11 | "SIREN": {
12 | "dims" : [256, 256, 256, 256, 256, 256],
13 | "w0": 45.0,
14 | "w1": 45.0,
15 | "size_factor": 1
16 | }
17 | },
18 | "training": {
19 | "LR_scheduler" : [
20 | {
21 | "type" : "Step",
22 | "initial" : 0.0001,
23 | "interval" : 5,
24 | "factor" : 0.5
25 | }]
26 | }
27 | }
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from datasets.nerf.nerf import NeRFDataset
2 | from datasets.nerf.nsvf import NSVFDataset
3 | from datasets.nerf.colmap import ColmapDataset
4 | from datasets.nerf.nerfpp import NeRFPPDataset
5 | from datasets.nerf.rtmv import RTMVDataset
6 |
7 |
8 | dataset_dict = {'nerf': NeRFDataset,
9 | 'nsvf': NSVFDataset,
10 | 'colmap': ColmapDataset,
11 | 'nerfpp': NeRFPPDataset,
12 | 'rtmv': RTMVDataset}
--------------------------------------------------------------------------------
/datasets/img/imager.py:
--------------------------------------------------------------------------------
1 | """
2 | These codes are adapted from tiny-cuda-nn (https://github.com/NVlabs/tiny-cuda-nn)
3 | """
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 | import math
9 |
10 |
11 | class ImageDataset(Dataset):
12 | def __init__(self, data, size=100, num_samples=2**18, split='train'):
13 | super().__init__()
14 |
15 | # assign image
16 | self.data = data
17 |
18 | self.img_wh = (self.data.shape[0], self.data.shape[1])
19 | self.img_shape = torch.tensor([self.img_wh[0], self.img_wh[1]]).float()
20 |
21 | print(f"[INFO] image: {self.data.shape}")
22 |
23 | self.num_samples = num_samples
24 |
25 | self.split = split
26 | self.size = size
27 |
28 | if self.split.startswith("test"):
29 | half_dx = 0.5 / self.img_wh[0]
30 | half_dy = 0.5 / self.img_wh[1]
31 | xs = torch.linspace(half_dx, 1-half_dx, self.img_wh[0])
32 | ys = torch.linspace(half_dy, 1-half_dy, self.img_wh[1])
33 | xv, yv = torch.meshgrid([xs, ys], indexing="ij")
34 |
35 | xy = torch.stack((xv.flatten(), yv.flatten())).t()
36 |
37 | xy_max_num = math.ceil(xy.shape[0] / 1024.0)
38 | padding_delta = xy_max_num * 1024 - xy.shape[0]
39 | zeros_padding = torch.zeros((padding_delta, 2))
40 | self.xs = torch.cat([xy, zeros_padding], dim=0)
41 |
42 |
43 | def __len__(self):
44 | return self.size
45 |
46 |
47 | def __getitem__(self, _):
48 | if self.split.startswith('train'):
49 | xs = torch.rand([self.num_samples, 2], dtype=torch.float32)
50 |
51 | assert torch.sum(xs < 0) == 0, "The coordinates for input image should be non-negative."
52 |
53 | with torch.no_grad():
54 | scaled_xs = xs * self.img_shape
55 | indices = scaled_xs.long()
56 | lerp_weights = scaled_xs - indices.float()
57 |
58 | x0 = indices[:, 0].clamp(min=0, max=self.img_wh[0]-1).long()
59 | y0 = indices[:, 1].clamp(min=0, max=self.img_wh[1]-1).long()
60 | x1 = (x0 + 1).clamp(min=0, max=self.img_wh[0]-1).long()
61 | y1 = (y0 + 1).clamp(min=0, max=self.img_wh[1]-1).long()
62 |
63 | rgbs = self.data[x0, y0] * (1.0 - lerp_weights[:, 0:1]) * (1.0 - lerp_weights[:, 1:2]) + \
64 | self.data[x0, y1] * (1.0 - lerp_weights[:, 0:1]) * lerp_weights[:, 1:2] + \
65 | self.data[x1, y0] * lerp_weights[:, 0:1] * (1.0 - lerp_weights[:, 1:2]) + \
66 | self.data[x1, y1] * lerp_weights[:, 0:1] * lerp_weights[:, 1:2]
67 | else:
68 | xs = self.xs
69 | rgbs = self.data
70 |
71 | results = {
72 | 'points': xs,
73 | 'rgbs': rgbs,
74 | }
75 |
76 | return results
--------------------------------------------------------------------------------
/datasets/nerf/base.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import numpy as np
3 |
4 |
5 | class BaseDataset(Dataset):
6 | """
7 | Define length and sampling method
8 | """
9 | def __init__(self, root_dir, split='train', downsample=1.0):
10 | self.root_dir = root_dir
11 | self.split = split
12 | self.downsample = downsample
13 |
14 | def read_intrinsics(self):
15 | raise NotImplementedError
16 |
17 | def __len__(self):
18 | if self.split.startswith('train'):
19 | return 1000
20 | return len(self.poses)
21 |
22 | def __getitem__(self, idx):
23 | if self.split.startswith('train'):
24 | # training pose is retrieved in train_nerf.py
25 | if self.ray_sampling_strategy == 'all_images': # randomly select images
26 | img_idxs = np.random.choice(len(self.poses), self.batch_size)
27 | elif self.ray_sampling_strategy == 'same_image': # randomly select ONE image
28 | img_idxs = np.random.choice(len(self.poses), 1)[0]
29 | # randomly select pixels
30 | pix_idxs = np.random.choice(self.img_wh[0]*self.img_wh[1], self.batch_size)
31 | rays = self.rays[img_idxs, pix_idxs]
32 | sample = {'img_idxs': img_idxs, 'pix_idxs': pix_idxs,
33 | 'rgb': rays[:, :3]}
34 | if self.rays.shape[-1] == 4: # HDR-NeRF data
35 | sample['exposure'] = rays[:, 3:]
36 | else:
37 | sample = {'pose': self.poses[idx], 'img_idxs': idx}
38 | if len(self.rays) > 0: # if ground truth available
39 | rays = self.rays[idx]
40 | sample['rgb'] = rays[:, :3]
41 | if rays.shape[1] == 4: # HDR-NeRF data
42 | sample['exposure'] = rays[0, 3] # same exposure for all rays
43 |
44 | return sample
--------------------------------------------------------------------------------
/datasets/nerf/colmap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import glob
5 | from tqdm import tqdm
6 |
7 | from .ray_utils import *
8 | from .color_utils import read_image
9 | from .colmap_utils import \
10 | read_cameras_binary, read_images_binary, read_points3d_binary
11 |
12 | from .base import BaseDataset
13 |
14 |
15 | class ColmapDataset(BaseDataset):
16 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs):
17 | super().__init__(root_dir, split, downsample)
18 |
19 | self.read_intrinsics()
20 |
21 | if kwargs.get('read_meta', True):
22 | self.read_meta(split, **kwargs)
23 |
24 | def read_intrinsics(self):
25 | # Step 1: read and scale intrinsics (same for all images)
26 | camdata = read_cameras_binary(os.path.join(self.root_dir, 'sparse/0/cameras.bin'))
27 | h = int(camdata[1].height*self.downsample)
28 | w = int(camdata[1].width*self.downsample)
29 | self.img_wh = (w, h)
30 |
31 | if camdata[1].model == 'SIMPLE_RADIAL':
32 | fx = fy = camdata[1].params[0]*self.downsample
33 | cx = camdata[1].params[1]*self.downsample
34 | cy = camdata[1].params[2]*self.downsample
35 | elif camdata[1].model in ['PINHOLE', 'OPENCV']:
36 | fx = camdata[1].params[0]*self.downsample
37 | fy = camdata[1].params[1]*self.downsample
38 | cx = camdata[1].params[2]*self.downsample
39 | cy = camdata[1].params[3]*self.downsample
40 | else:
41 | raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!")
42 | self.K = torch.FloatTensor([[fx, 0, cx],
43 | [0, fy, cy],
44 | [0, 0, 1]])
45 | self.directions = get_ray_directions(h, w, self.K)
46 |
47 | def read_meta(self, split, **kwargs):
48 | # Step 2: correct poses
49 | # read extrinsics (of successfully reconstructed images)
50 | imdata = read_images_binary(os.path.join(self.root_dir, 'sparse/0/images.bin'))
51 | img_names = [imdata[k].name for k in imdata]
52 | perm = np.argsort(img_names)
53 | if '360_v2' in self.root_dir and self.downsample<1: # mipnerf360 data
54 | folder = f'images_{int(1/self.downsample)}'
55 | else:
56 | folder = 'images'
57 | # read successfully reconstructed images and ignore others
58 | img_paths = [os.path.join(self.root_dir, folder, name)
59 | for name in sorted(img_names)]
60 | w2c_mats = []
61 | bottom = np.array([[0, 0, 0, 1.]])
62 | for k in imdata:
63 | im = imdata[k]
64 | R = im.qvec2rotmat(); t = im.tvec.reshape(3, 1)
65 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)]
66 | w2c_mats = np.stack(w2c_mats, 0)
67 | poses = np.linalg.inv(w2c_mats)[perm, :3] # (N_images, 3, 4) cam2world matrices
68 |
69 | pts3d = read_points3d_binary(os.path.join(self.root_dir, 'sparse/0/points3D.bin'))
70 | pts3d = np.array([pts3d[k].xyz for k in pts3d]) # (N, 3)
71 |
72 | self.poses, self.pts3d = center_poses(poses, pts3d)
73 |
74 | scale = np.linalg.norm(self.poses[..., 3], axis=-1).min()
75 | self.poses[..., 3] /= scale
76 | self.pts3d /= scale
77 |
78 | self.rays = []
79 | if split == 'test_traj': # use precomputed test poses
80 | self.poses = create_spheric_poses(1.2, self.poses[:, 1, 3].mean())
81 | self.poses = torch.FloatTensor(self.poses)
82 | return
83 |
84 | if 'HDR-NeRF' in self.root_dir: # HDR-NeRF data
85 | if 'syndata' in self.root_dir: # synthetic
86 | # first 17 are test, last 18 are train
87 | self.unit_exposure_rgb = 0.73
88 | if split=='train':
89 | img_paths = sorted(glob.glob(os.path.join(self.root_dir,
90 | f'train/*[024].png')))
91 | self.poses = np.repeat(self.poses[-18:], 3, 0)
92 | elif split=='test':
93 | img_paths = sorted(glob.glob(os.path.join(self.root_dir,
94 | f'test/*[13].png')))
95 | self.poses = np.repeat(self.poses[:17], 2, 0)
96 | else:
97 | raise ValueError(f"split {split} is invalid for HDR-NeRF!")
98 | else: # real
99 | self.unit_exposure_rgb = 0.5
100 | # even numbers are train, odd numbers are test
101 | if split=='train':
102 | img_paths = sorted(glob.glob(os.path.join(self.root_dir,
103 | f'input_images/*0.jpg')))[::2]
104 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir,
105 | f'input_images/*2.jpg')))[::2]
106 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir,
107 | f'input_images/*4.jpg')))[::2]
108 | self.poses = np.tile(self.poses[::2], (3, 1, 1))
109 | elif split=='test':
110 | img_paths = sorted(glob.glob(os.path.join(self.root_dir,
111 | f'input_images/*1.jpg')))[1::2]
112 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir,
113 | f'input_images/*3.jpg')))[1::2]
114 | self.poses = np.tile(self.poses[1::2], (2, 1, 1))
115 | else:
116 | raise ValueError(f"split {split} is invalid for HDR-NeRF!")
117 | else:
118 | # use every 8th image as test set
119 | if split=='train':
120 | img_paths = [x for i, x in enumerate(img_paths) if i%8!=0]
121 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8!=0])
122 | elif split=='test':
123 | img_paths = [x for i, x in enumerate(img_paths) if i%8==0]
124 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8==0])
125 |
126 | print(f'Loading {len(img_paths)} {split} images ...')
127 | for img_path in tqdm(img_paths):
128 | buf = [] # buffer for ray attributes: rgb, etc
129 |
130 | img = read_image(img_path, self.img_wh, blend_a=False)
131 | img = torch.FloatTensor(img)
132 | buf += [img]
133 |
134 | if 'HDR-NeRF' in self.root_dir: # get exposure
135 | folder = self.root_dir.split('/')
136 | scene = folder[-1] if folder[-1] != '' else folder[-2]
137 | if scene in ['bathroom', 'bear', 'chair', 'desk']:
138 | e_dict = {e: 1/8*4**e for e in range(5)}
139 | elif scene in ['diningroom', 'dog']:
140 | e_dict = {e: 1/16*4**e for e in range(5)}
141 | elif scene in ['sofa']:
142 | e_dict = {0:0.25, 1:1, 2:2, 3:4, 4:16}
143 | elif scene in ['sponza']:
144 | e_dict = {0:0.5, 1:2, 2:4, 3:8, 4:32}
145 | elif scene in ['box']:
146 | e_dict = {0:2/3, 1:1/3, 2:1/6, 3:0.1, 4:0.05}
147 | elif scene in ['computer']:
148 | e_dict = {0:1/3, 1:1/8, 2:1/15, 3:1/30, 4:1/60}
149 | elif scene in ['flower']:
150 | e_dict = {0:1/3, 1:1/6, 2:0.1, 3:0.05, 4:1/45}
151 | elif scene in ['luckycat']:
152 | e_dict = {0:2, 1:1, 2:0.5, 3:0.25, 4:0.125}
153 | e = int(img_path.split('.')[0][-1])
154 | buf += [e_dict[e]*torch.ones_like(img[:, :1])]
155 |
156 | self.rays += [torch.cat(buf, 1)]
157 |
158 | self.rays = torch.stack(self.rays) # (N_images, hw, ?)
159 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4)
--------------------------------------------------------------------------------
/datasets/nerf/colmap_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)
31 |
32 | import os
33 | import sys
34 | import collections
35 | import numpy as np
36 | import struct
37 |
38 |
39 | CameraModel = collections.namedtuple(
40 | "CameraModel", ["model_id", "model_name", "num_params"])
41 | Camera = collections.namedtuple(
42 | "Camera", ["id", "model", "width", "height", "params"])
43 | BaseImage = collections.namedtuple(
44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
45 | Point3D = collections.namedtuple(
46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
47 |
48 | class Image(BaseImage):
49 | def qvec2rotmat(self):
50 | return qvec2rotmat(self.qvec)
51 |
52 |
53 | CAMERA_MODELS = {
54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
61 | CameraModel(model_id=7, model_name="FOV", num_params=5),
62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
65 | }
66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \
67 | for camera_model in CAMERA_MODELS])
68 |
69 |
70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
71 | """Read and unpack the next bytes from a binary file.
72 | :param fid:
73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
75 | :param endian_character: Any of {@, =, <, >, !}
76 | :return: Tuple of read and unpacked values.
77 | """
78 | data = fid.read(num_bytes)
79 | return struct.unpack(endian_character + format_char_sequence, data)
80 |
81 |
82 | def read_cameras_text(path):
83 | """
84 | see: src/base/reconstruction.cc
85 | void Reconstruction::WriteCamerasText(const std::string& path)
86 | void Reconstruction::ReadCamerasText(const std::string& path)
87 | """
88 | cameras = {}
89 | with open(path, "r") as fid:
90 | while True:
91 | line = fid.readline()
92 | if not line:
93 | break
94 | line = line.strip()
95 | if len(line) > 0 and line[0] != "#":
96 | elems = line.split()
97 | camera_id = int(elems[0])
98 | model = elems[1]
99 | width = int(elems[2])
100 | height = int(elems[3])
101 | params = np.array(tuple(map(float, elems[4:])))
102 | cameras[camera_id] = Camera(id=camera_id, model=model,
103 | width=width, height=height,
104 | params=params)
105 | return cameras
106 |
107 |
108 | def read_cameras_binary(path_to_model_file):
109 | """
110 | see: src/base/reconstruction.cc
111 | void Reconstruction::WriteCamerasBinary(const std::string& path)
112 | void Reconstruction::ReadCamerasBinary(const std::string& path)
113 | """
114 | cameras = {}
115 | with open(path_to_model_file, "rb") as fid:
116 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
117 | for camera_line_index in range(num_cameras):
118 | camera_properties = read_next_bytes(
119 | fid, num_bytes=24, format_char_sequence="iiQQ")
120 | camera_id = camera_properties[0]
121 | model_id = camera_properties[1]
122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
123 | width = camera_properties[2]
124 | height = camera_properties[3]
125 | num_params = CAMERA_MODEL_IDS[model_id].num_params
126 | params = read_next_bytes(fid, num_bytes=8*num_params,
127 | format_char_sequence="d"*num_params)
128 | cameras[camera_id] = Camera(id=camera_id,
129 | model=model_name,
130 | width=width,
131 | height=height,
132 | params=np.array(params))
133 | assert len(cameras) == num_cameras
134 | return cameras
135 |
136 |
137 | def read_images_text(path):
138 | """
139 | see: src/base/reconstruction.cc
140 | void Reconstruction::ReadImagesText(const std::string& path)
141 | void Reconstruction::WriteImagesText(const std::string& path)
142 | """
143 | images = {}
144 | with open(path, "r") as fid:
145 | while True:
146 | line = fid.readline()
147 | if not line:
148 | break
149 | line = line.strip()
150 | if len(line) > 0 and line[0] != "#":
151 | elems = line.split()
152 | image_id = int(elems[0])
153 | qvec = np.array(tuple(map(float, elems[1:5])))
154 | tvec = np.array(tuple(map(float, elems[5:8])))
155 | camera_id = int(elems[8])
156 | image_name = elems[9]
157 | elems = fid.readline().split()
158 | xys = np.column_stack([tuple(map(float, elems[0::3])),
159 | tuple(map(float, elems[1::3]))])
160 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
161 | images[image_id] = Image(
162 | id=image_id, qvec=qvec, tvec=tvec,
163 | camera_id=camera_id, name=image_name,
164 | xys=xys, point3D_ids=point3D_ids)
165 | return images
166 |
167 |
168 | def read_images_binary(path_to_model_file):
169 | """
170 | see: src/base/reconstruction.cc
171 | void Reconstruction::ReadImagesBinary(const std::string& path)
172 | void Reconstruction::WriteImagesBinary(const std::string& path)
173 | """
174 | images = {}
175 | with open(path_to_model_file, "rb") as fid:
176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
177 | for image_index in range(num_reg_images):
178 | binary_image_properties = read_next_bytes(
179 | fid, num_bytes=64, format_char_sequence="idddddddi")
180 | image_id = binary_image_properties[0]
181 | qvec = np.array(binary_image_properties[1:5])
182 | tvec = np.array(binary_image_properties[5:8])
183 | camera_id = binary_image_properties[8]
184 | image_name = ""
185 | current_char = read_next_bytes(fid, 1, "c")[0]
186 | while current_char != b"\x00": # look for the ASCII 0 entry
187 | image_name += current_char.decode("utf-8")
188 | current_char = read_next_bytes(fid, 1, "c")[0]
189 | num_points2D = read_next_bytes(fid, num_bytes=8,
190 | format_char_sequence="Q")[0]
191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
192 | format_char_sequence="ddq"*num_points2D)
193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
194 | tuple(map(float, x_y_id_s[1::3]))])
195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
196 | images[image_id] = Image(
197 | id=image_id, qvec=qvec, tvec=tvec,
198 | camera_id=camera_id, name=image_name,
199 | xys=xys, point3D_ids=point3D_ids)
200 | return images
201 |
202 |
203 | def read_points3D_text(path):
204 | """
205 | see: src/base/reconstruction.cc
206 | void Reconstruction::ReadPoints3DText(const std::string& path)
207 | void Reconstruction::WritePoints3DText(const std::string& path)
208 | """
209 | points3D = {}
210 | with open(path, "r") as fid:
211 | while True:
212 | line = fid.readline()
213 | if not line:
214 | break
215 | line = line.strip()
216 | if len(line) > 0 and line[0] != "#":
217 | elems = line.split()
218 | point3D_id = int(elems[0])
219 | xyz = np.array(tuple(map(float, elems[1:4])))
220 | rgb = np.array(tuple(map(int, elems[4:7])))
221 | error = float(elems[7])
222 | image_ids = np.array(tuple(map(int, elems[8::2])))
223 | point2D_idxs = np.array(tuple(map(int, elems[9::2])))
224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
225 | error=error, image_ids=image_ids,
226 | point2D_idxs=point2D_idxs)
227 | return points3D
228 |
229 |
230 | def read_points3d_binary(path_to_model_file):
231 | """
232 | see: src/base/reconstruction.cc
233 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
234 | void Reconstruction::WritePoints3DBinary(const std::string& path)
235 | """
236 | points3D = {}
237 | with open(path_to_model_file, "rb") as fid:
238 | num_points = read_next_bytes(fid, 8, "Q")[0]
239 | for point_line_index in range(num_points):
240 | binary_point_line_properties = read_next_bytes(
241 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
242 | point3D_id = binary_point_line_properties[0]
243 | xyz = np.array(binary_point_line_properties[1:4])
244 | rgb = np.array(binary_point_line_properties[4:7])
245 | error = np.array(binary_point_line_properties[7])
246 | track_length = read_next_bytes(
247 | fid, num_bytes=8, format_char_sequence="Q")[0]
248 | track_elems = read_next_bytes(
249 | fid, num_bytes=8*track_length,
250 | format_char_sequence="ii"*track_length)
251 | image_ids = np.array(tuple(map(int, track_elems[0::2])))
252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
253 | points3D[point3D_id] = Point3D(
254 | id=point3D_id, xyz=xyz, rgb=rgb,
255 | error=error, image_ids=image_ids,
256 | point2D_idxs=point2D_idxs)
257 | return points3D
258 |
259 |
260 | def read_model(path, ext):
261 | if ext == ".txt":
262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
263 | images = read_images_text(os.path.join(path, "images" + ext))
264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
265 | else:
266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
267 | images = read_images_binary(os.path.join(path, "images" + ext))
268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
269 | return cameras, images, points3D
270 |
271 |
272 | def qvec2rotmat(qvec):
273 | return np.array([
274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
283 |
284 |
285 | def rotmat2qvec(R):
286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
287 | K = np.array([
288 | [Rxx - Ryy - Rzz, 0, 0, 0],
289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
292 | eigvals, eigvecs = np.linalg.eigh(K)
293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
294 | if qvec[0] < 0:
295 | qvec *= -1
296 | return qvec
--------------------------------------------------------------------------------
/datasets/nerf/color_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from einops import rearrange
3 | import imageio
4 | import numpy as np
5 |
6 |
7 | def srgb_to_linear(img):
8 | limit = 0.04045
9 | return np.where(img>limit, ((img+0.055)/1.055)**2.4, img/12.92)
10 |
11 |
12 | def linear_to_srgb(img):
13 | limit = 0.0031308
14 | img = np.where(img>limit, 1.055*img**(1/2.4)-0.055, 12.92*img)
15 | img[img>1] = 1 # "clamp" tonemapper
16 | return img
17 |
18 |
19 | def read_image(img_path, img_wh, blend_a=True):
20 | img = imageio.imread(img_path).astype(np.float32)/255.0
21 | # img[..., :3] = srgb_to_linear(img[..., :3])
22 | if img.shape[2] == 4: # blend A to RGB
23 | if blend_a:
24 | img = img[..., :3]*img[..., -1:]+(1-img[..., -1:])
25 | else:
26 | img = img[..., :3]*img[..., -1:]
27 |
28 | img = cv2.resize(img, img_wh)
29 | img = rearrange(img, 'h w c -> (h w) c')
30 |
31 | return img
--------------------------------------------------------------------------------
/datasets/nerf/depth_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 |
4 |
5 | def read_pfm(path):
6 | """Read pfm file.
7 |
8 | Args:
9 | path (str): path to file
10 |
11 | Returns:
12 | tuple: (data, scale)
13 | """
14 | with open(path, "rb") as file:
15 |
16 | color = None
17 | width = None
18 | height = None
19 | scale = None
20 | endian = None
21 |
22 | header = file.readline().rstrip()
23 | if header.decode("ascii") == "PF":
24 | color = True
25 | elif header.decode("ascii") == "Pf":
26 | color = False
27 | else:
28 | raise Exception("Not a PFM file: " + path)
29 |
30 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
31 | if dim_match:
32 | width, height = list(map(int, dim_match.groups()))
33 | else:
34 | raise Exception("Malformed PFM header.")
35 |
36 | scale = float(file.readline().decode("ascii").rstrip())
37 | if scale < 0:
38 | # little-endian
39 | endian = "<"
40 | scale = -scale
41 | else:
42 | # big-endian
43 | endian = ">"
44 |
45 | data = np.fromfile(file, endian + "f")
46 | shape = (height, width, 3) if color else (height, width)
47 |
48 | data = np.reshape(data, shape)
49 | data = np.flipud(data)
50 |
51 | return data, scale
--------------------------------------------------------------------------------
/datasets/nerf/nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import numpy as np
4 | import os
5 | from tqdm import tqdm
6 |
7 | from .ray_utils import get_ray_directions
8 | from .color_utils import read_image
9 |
10 | from .base import BaseDataset
11 |
12 |
13 | class NeRFDataset(BaseDataset):
14 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs):
15 | super().__init__(root_dir, split, downsample)
16 |
17 | self.read_intrinsics()
18 |
19 | if kwargs.get('read_meta', True):
20 | self.read_meta(split)
21 |
22 | def read_intrinsics(self):
23 | with open(os.path.join(self.root_dir, "transforms_train.json"), 'r') as f:
24 | meta = json.load(f)
25 |
26 | w = h = int(800*self.downsample)
27 | fx = fy = 0.5*800/np.tan(0.5*meta['camera_angle_x'])*self.downsample
28 |
29 | K = np.float32([[fx, 0, w/2],
30 | [0, fy, h/2],
31 | [0, 0, 1]])
32 |
33 | self.K = torch.FloatTensor(K)
34 | self.directions = get_ray_directions(h, w, self.K)
35 | self.img_wh = (w, h)
36 |
37 | def read_meta(self, split):
38 | self.rays = []
39 | self.poses = []
40 |
41 | if split == 'trainval':
42 | with open(os.path.join(self.root_dir, "transforms_train.json"), 'r') as f:
43 | frames = json.load(f)["frames"]
44 | with open(os.path.join(self.root_dir, "transforms_val.json"), 'r') as f:
45 | frames += json.load(f)["frames"]
46 | else:
47 | with open(os.path.join(self.root_dir, f"transforms_{split}.json"), 'r') as f:
48 | frames = json.load(f)["frames"]
49 |
50 | print(f'Loading {len(frames)} {split} images ...')
51 | for frame in tqdm(frames):
52 | c2w = np.array(frame['transform_matrix'])[:3, :4]
53 |
54 | # determine scale
55 | if 'Jrender_Dataset' in self.root_dir:
56 | c2w[:, :2] *= -1 # [left up front] to [right down front]
57 | folder = self.root_dir.split('/')
58 | scene = folder[-1] if folder[-1] != '' else folder[-2]
59 | if scene=='Easyship':
60 | pose_radius_scale = 1.2
61 | elif scene=='Scar':
62 | pose_radius_scale = 1.8
63 | elif scene=='Coffee':
64 | pose_radius_scale = 2.5
65 | elif scene=='Car':
66 | pose_radius_scale = 0.8
67 | else:
68 | pose_radius_scale = 1.5
69 | else:
70 | c2w[:, 1:3] *= -1 # [right up back] to [right down front]
71 | pose_radius_scale = 1.5
72 | c2w[:, 3] /= np.linalg.norm(c2w[:, 3])/pose_radius_scale
73 |
74 | # add shift
75 | if 'Jrender_Dataset' in self.root_dir:
76 | if scene=='Coffee':
77 | c2w[1, 3] -= 0.4465
78 | elif scene=='Car':
79 | c2w[0, 3] -= 0.7
80 | self.poses += [c2w]
81 |
82 | try:
83 | img_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
84 | img = read_image(img_path, self.img_wh)
85 | self.rays += [img]
86 | except: pass
87 |
88 | if len(self.rays)>0:
89 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?)
90 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4)
91 |
--------------------------------------------------------------------------------
/datasets/nerf/nerfpp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import glob
3 | import numpy as np
4 | import os
5 | from PIL import Image
6 | from tqdm import tqdm
7 |
8 | from .ray_utils import get_ray_directions
9 | from .color_utils import read_image
10 |
11 | from .base import BaseDataset
12 |
13 |
14 | class NeRFPPDataset(BaseDataset):
15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs):
16 | super().__init__(root_dir, split, downsample)
17 |
18 | self.read_intrinsics()
19 |
20 | if kwargs.get('read_meta', True):
21 | self.read_meta(split)
22 |
23 | def read_intrinsics(self):
24 | K = np.loadtxt(glob.glob(os.path.join(self.root_dir, 'train/intrinsics/*.txt'))[0],
25 | dtype=np.float32).reshape(4, 4)[:3, :3]
26 | K[:2] *= self.downsample
27 | w, h = Image.open(glob.glob(os.path.join(self.root_dir, 'train/rgb/*'))[0]).size
28 | w, h = int(w*self.downsample), int(h*self.downsample)
29 | self.K = torch.FloatTensor(K)
30 | self.directions = get_ray_directions(h, w, self.K)
31 | self.img_wh = (w, h)
32 |
33 | def read_meta(self, split):
34 | self.rays = []
35 | self.poses = []
36 |
37 | if split == 'test_traj':
38 | poses_path = \
39 | sorted(glob.glob(os.path.join(self.root_dir, 'camera_path/pose/*.txt')))
40 | self.poses = [np.loadtxt(p).reshape(4, 4)[:3] for p in poses_path]
41 | else:
42 | if split=='trainval':
43 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'train/rgb/*')))+\
44 | sorted(glob.glob(os.path.join(self.root_dir, 'val/rgb/*')))
45 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'train/pose/*.txt')))+\
46 | sorted(glob.glob(os.path.join(self.root_dir, 'val/pose/*.txt')))
47 | else:
48 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, split, 'rgb/*')))
49 | poses = sorted(glob.glob(os.path.join(self.root_dir, split, 'pose/*.txt')))
50 |
51 | print(f'Loading {len(img_paths)} {split} images ...')
52 | for img_path, pose in tqdm(zip(img_paths, poses)):
53 | self.poses += [np.loadtxt(pose).reshape(4, 4)[:3]]
54 |
55 | img = read_image(img_path, self.img_wh)
56 | self.rays += [img]
57 |
58 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?)
59 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4)
60 |
--------------------------------------------------------------------------------
/datasets/nerf/nsvf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import glob
3 | import numpy as np
4 | import os
5 | from tqdm import tqdm
6 |
7 | from .ray_utils import get_ray_directions
8 | from .color_utils import read_image
9 |
10 | from .base import BaseDataset
11 |
12 |
13 | class NSVFDataset(BaseDataset):
14 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs):
15 | super().__init__(root_dir, split, downsample)
16 |
17 | self.read_intrinsics()
18 |
19 | if kwargs.get('read_meta', True):
20 | xyz_min, xyz_max = \
21 | np.loadtxt(os.path.join(root_dir, 'bbox.txt'))[:6].reshape(2, 3)
22 | self.shift = (xyz_max + xyz_min) / 2
23 | self.scale = (xyz_max - xyz_min).max() / 2 * 1.05 # enlarge a little
24 |
25 | # hard-code fix the bound error for some scenes...
26 | if 'Mic' in self.root_dir: self.scale *= 1.2
27 | elif 'Lego' in self.root_dir: self.scale *= 1.1
28 |
29 | self.read_meta(split)
30 |
31 | def read_intrinsics(self):
32 | if 'Synthetic' in self.root_dir or 'Ignatius' in self.root_dir:
33 | with open(os.path.join(self.root_dir, 'intrinsics.txt')) as f:
34 | fx = fy = float(f.readline().split()[0]) * self.downsample
35 | if 'Synthetic' in self.root_dir:
36 | w = h = int(800*self.downsample)
37 | else:
38 | w, h = int(1920*self.downsample), int(1080*self.downsample)
39 |
40 | K = np.float32([[fx, 0, w/2],
41 | [0, fy, h/2],
42 | [0, 0, 1]])
43 | else:
44 | K = np.loadtxt(os.path.join(self.root_dir, 'intrinsics.txt'),
45 | dtype=np.float32)[:3, :3]
46 | if 'BlendedMVS' in self.root_dir:
47 | w, h = int(768*self.downsample), int(576*self.downsample)
48 | elif 'Tanks' in self.root_dir:
49 | w, h = int(1920*self.downsample), int(1080*self.downsample)
50 | K[:2] *= self.downsample
51 |
52 | self.K = torch.FloatTensor(K)
53 | self.directions = get_ray_directions(h, w, self.K)
54 | self.img_wh = (w, h)
55 |
56 | def read_meta(self, split):
57 | self.rays = []
58 | self.poses = []
59 |
60 | if split == 'test_traj': # BlendedMVS and TanksAndTemple
61 | if 'Ignatius' in self.root_dir:
62 | poses_path = \
63 | sorted(glob.glob(os.path.join(self.root_dir, 'test_pose/*.txt')))
64 | poses = [np.loadtxt(p) for p in poses_path]
65 | else:
66 | poses = np.loadtxt(os.path.join(self.root_dir, 'test_traj.txt'))
67 | poses = poses.reshape(-1, 4, 4)
68 | for pose in poses:
69 | c2w = pose[:3]
70 | c2w[:, 0] *= -1 # [left down front] to [right down front]
71 | c2w[:, 3] -= self.shift
72 | c2w[:, 3] /= 2*self.scale # to bound the scene inside [-0.5, 0.5]
73 | self.poses += [c2w]
74 | else:
75 | if split == 'train': prefix = '0_'
76 | elif split == 'trainval': prefix = '[0-1]_'
77 | elif split == 'trainvaltest': prefix = '[0-2]_'
78 | elif split == 'val': prefix = '1_'
79 | elif 'Synthetic' in self.root_dir: prefix = '2_' # test set for synthetic scenes
80 | elif split == 'test': prefix = '1_' # test set for real scenes
81 | else: raise ValueError(f'{split} split not recognized!')
82 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'rgb', prefix+'*.png')))
83 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'pose', prefix+'*.txt')))
84 |
85 | print(f'Loading {len(img_paths)} {split} images ...')
86 | for img_path, pose in tqdm(zip(img_paths, poses)):
87 | c2w = np.loadtxt(pose)[:3]
88 | c2w[:, 3] -= self.shift
89 | c2w[:, 3] /= 2*self.scale # to bound the scene inside [-0.5, 0.5]
90 | self.poses += [c2w]
91 |
92 | img = read_image(img_path, self.img_wh)
93 | if 'Jade' in self.root_dir or 'Fountain' in self.root_dir:
94 | # these scenes have black background, changing to white
95 | img[torch.all(img<=0.1, dim=-1)] = 1.0
96 |
97 | self.rays += [img]
98 |
99 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?)
100 |
101 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4)
102 |
--------------------------------------------------------------------------------
/datasets/nerf/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from kornia import create_meshgrid
4 | from einops import rearrange
5 |
6 |
7 | @torch.cuda.amp.autocast(dtype=torch.float32)
8 | def get_ray_directions(H, W, K, device='cpu', random=False, return_uv=False, flatten=True):
9 | """
10 | Get ray directions for all pixels in camera coordinate [right down front].
11 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
12 | ray-tracing-generating-camera-rays/standard-coordinate-systems
13 |
14 | Inputs:
15 | H, W: image height and width
16 | K: (3, 3) camera intrinsics
17 | random: whether the ray passes randomly inside the pixel
18 | return_uv: whether to return uv image coordinates
19 |
20 | Outputs: (shape depends on @flatten)
21 | directions: (H, W, 3) or (H*W, 3), the direction of the rays in camera coordinate
22 | uv: (H, W, 2) or (H*W, 2) image coordinates
23 | """
24 | grid = create_meshgrid(H, W, False, device=device)[0] # (H, W, 2)
25 | u, v = grid.unbind(-1)
26 |
27 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
28 | if random:
29 | directions = \
30 | torch.stack([(u-cx+torch.rand_like(u))/fx,
31 | (v-cy+torch.rand_like(v))/fy,
32 | torch.ones_like(u)], -1)
33 | else: # pass by the center
34 | directions = \
35 | torch.stack([(u-cx+0.5)/fx, (v-cy+0.5)/fy, torch.ones_like(u)], -1)
36 | if flatten:
37 | directions = directions.reshape(-1, 3)
38 | grid = grid.reshape(-1, 2)
39 |
40 | if return_uv:
41 | return directions, grid
42 | return directions
43 |
44 |
45 | @torch.cuda.amp.autocast(dtype=torch.float32)
46 | def get_rays(directions, c2w):
47 | """
48 | Get ray origin and directions in world coordinate for all pixels in one image.
49 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
50 | ray-tracing-generating-camera-rays/standard-coordinate-systems
51 |
52 | Inputs:
53 | directions: (N, 3) ray directions in camera coordinate
54 | c2w: (3, 4) or (N, 3, 4) transformation matrix from camera coordinate to world coordinate
55 |
56 | Outputs:
57 | rays_o: (N, 3), the origin of the rays in world coordinate
58 | rays_d: (N, 3), the direction of the rays in world coordinate
59 | """
60 | if c2w.ndim==2:
61 | # Rotate ray directions from camera coordinate to the world coordinate
62 | rays_d = directions @ c2w[:, :3].T
63 | else:
64 | rays_d = rearrange(directions, 'n c -> n 1 c') @ \
65 | rearrange(c2w[..., :3], 'n a b -> n b a')
66 | rays_d = rearrange(rays_d, 'n 1 c -> n c')
67 | # The origin of all rays is the camera origin in world coordinate
68 | rays_o = c2w[..., 3].expand_as(rays_d)
69 |
70 | return rays_o, rays_d
71 |
72 |
73 | @torch.cuda.amp.autocast(dtype=torch.float32)
74 | def axisangle_to_R(v):
75 | """
76 | Convert an axis-angle vector to rotation matrix
77 | from https://github.com/ActiveVisionLab/nerfmm/blob/main/utils/lie_group_helper.py#L47
78 |
79 | Inputs:
80 | v: (3) or (B, 3)
81 |
82 | Outputs:
83 | R: (3, 3) or (B, 3, 3)
84 | """
85 | v_ndim = v.ndim
86 | if v_ndim==1:
87 | v = rearrange(v, 'c -> 1 c')
88 | zero = torch.zeros_like(v[:, :1]) # (B, 1)
89 | skew_v0 = torch.cat([zero, -v[:, 2:3], v[:, 1:2]], 1) # (B, 3)
90 | skew_v1 = torch.cat([v[:, 2:3], zero, -v[:, 0:1]], 1)
91 | skew_v2 = torch.cat([-v[:, 1:2], v[:, 0:1], zero], 1)
92 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=1) # (B, 3, 3)
93 |
94 | norm_v = rearrange(torch.norm(v, dim=1)+1e-7, 'b -> b 1 1')
95 | eye = torch.eye(3, device=v.device)
96 | R = eye + (torch.sin(norm_v)/norm_v)*skew_v + \
97 | ((1-torch.cos(norm_v))/norm_v**2)*(skew_v@skew_v)
98 |
99 | if v_ndim==1:
100 | R = rearrange(R, '1 c d -> c d')
101 |
102 | return R
103 |
104 |
105 | def normalize(v):
106 | """Normalize a vector."""
107 | return v/np.linalg.norm(v)
108 |
109 |
110 | def average_poses(poses, pts3d=None):
111 | """
112 | Calculate the average pose, which is then used to center all poses
113 | using @center_poses. Its computation is as follows:
114 | 1. Compute the center: the average of 3d point cloud (if None, center of cameras).
115 | 2. Compute the z axis: the normalized average z axis.
116 | 3. Compute axis y': the average y axis.
117 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
118 | 5. Compute the y axis: z cross product x.
119 |
120 | Note that at step 3, we cannot directly use y' as y axis since it's
121 | not necessarily orthogonal to z axis. We need to pass from x to y.
122 | Inputs:
123 | poses: (N_images, 3, 4)
124 | pts3d: (N, 3)
125 |
126 | Outputs:
127 | pose_avg: (3, 4) the average pose
128 | """
129 | # 1. Compute the center
130 | if pts3d is not None:
131 | center = pts3d.mean(0)
132 | else:
133 | center = poses[..., 3].mean(0)
134 |
135 | # 2. Compute the z axis
136 | z = normalize(poses[..., 2].mean(0)) # (3)
137 |
138 | # 3. Compute axis y' (no need to normalize as it's not the final output)
139 | y_ = poses[..., 1].mean(0) # (3)
140 |
141 | # 4. Compute the x axis
142 | x = normalize(np.cross(y_, z)) # (3)
143 |
144 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
145 | y = np.cross(z, x) # (3)
146 |
147 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
148 |
149 | return pose_avg
150 |
151 |
152 | def center_poses(poses, pts3d=None):
153 | """
154 | See https://github.com/bmild/nerf/issues/34
155 | Inputs:
156 | poses: (N_images, 3, 4)
157 | pts3d: (N, 3) reconstructed point cloud
158 |
159 | Outputs:
160 | poses_centered: (N_images, 3, 4) the centered poses
161 | pts3d_centered: (N, 3) centered point cloud
162 | """
163 |
164 | pose_avg = average_poses(poses, pts3d) # (3, 4)
165 | pose_avg_homo = np.eye(4)
166 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation
167 | # by simply adding 0, 0, 0, 1 as the last row
168 | pose_avg_inv = np.linalg.inv(pose_avg_homo)
169 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
170 | poses_homo = \
171 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
172 |
173 | poses_centered = pose_avg_inv @ poses_homo # (N_images, 4, 4)
174 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
175 |
176 | if pts3d is not None:
177 | pts3d_centered = pts3d @ pose_avg_inv[:, :3].T + pose_avg_inv[:, 3:].T
178 | return poses_centered, pts3d_centered
179 |
180 | return poses_centered
181 |
182 |
183 | def create_spheric_poses(radius, mean_h, n_poses=120):
184 | """
185 | Create circular poses around z axis.
186 | Inputs:
187 | radius: the (negative) height and the radius of the circle.
188 | mean_h: mean camera height
189 | Outputs:
190 | spheric_poses: (n_poses, 3, 4) the poses in the circular path
191 | """
192 | def spheric_pose(theta, phi, radius):
193 | trans_t = lambda t : np.array([
194 | [1,0,0,0],
195 | [0,1,0,2*mean_h],
196 | [0,0,1,-t]
197 | ])
198 |
199 | rot_phi = lambda phi : np.array([
200 | [1,0,0],
201 | [0,np.cos(phi),-np.sin(phi)],
202 | [0,np.sin(phi), np.cos(phi)]
203 | ])
204 |
205 | rot_theta = lambda th : np.array([
206 | [np.cos(th),0,-np.sin(th)],
207 | [0,1,0],
208 | [np.sin(th),0, np.cos(th)]
209 | ])
210 |
211 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius)
212 | c2w = np.array([[-1,0,0],[0,0,1],[0,1,0]]) @ c2w
213 | return c2w
214 |
215 | spheric_poses = []
216 | for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]:
217 | spheric_poses += [spheric_pose(th, -np.pi/12, radius)]
218 | return np.stack(spheric_poses, 0)
--------------------------------------------------------------------------------
/datasets/nerf/rtmv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import glob
3 | import json
4 | import numpy as np
5 | import os
6 | from tqdm import tqdm
7 |
8 | from .ray_utils import get_ray_directions
9 | from .color_utils import read_image
10 |
11 | from .base import BaseDataset
12 |
13 |
14 | class RTMVDataset(BaseDataset):
15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs):
16 | super().__init__(root_dir, split, downsample)
17 |
18 | self.read_intrinsics()
19 |
20 | if kwargs.get('read_meta', True):
21 | self.read_meta(split)
22 |
23 | def read_intrinsics(self):
24 | with open(os.path.join(self.root_dir, '00000.json'), 'r') as f:
25 | meta = json.load(f)['camera_data']
26 |
27 | self.shift = np.array(meta['scene_center_3d_box'])
28 | self.scale = (np.array(meta['scene_max_3d_box'])-
29 | np.array(meta['scene_min_3d_box'])).max()/2 * 1.05 # enlarge a little
30 |
31 | fx = meta['intrinsics']['fx'] * self.downsample
32 | fy = meta['intrinsics']['fy'] * self.downsample
33 | cx = meta['intrinsics']['cx'] * self.downsample
34 | cy = meta['intrinsics']['cy'] * self.downsample
35 | w = int(meta['width']*self.downsample)
36 | h = int(meta['height']*self.downsample)
37 | K = np.float32([[fx, 0, cx],
38 | [0, fy, cy],
39 | [0, 0, 1]])
40 | self.K = torch.FloatTensor(K)
41 | self.directions = get_ray_directions(h, w, self.K)
42 | self.img_wh = (w, h)
43 |
44 | def read_meta(self, split):
45 | self.rays = []
46 | self.poses = []
47 |
48 | if split == 'train': start_idx, end_idx = 0, 100
49 | elif split == 'trainval': start_idx, end_idx = 0, 105
50 | elif split == 'test': start_idx, end_idx = 105, 150
51 | else: start_idx, end_idx = 0, 150
52 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images/*')))[start_idx:end_idx]
53 | poses = sorted(glob.glob(os.path.join(self.root_dir, '*.json')))[start_idx:end_idx]
54 |
55 | print(f'Loading {len(img_paths)} {split} images ...')
56 | for img_path, pose in tqdm(zip(img_paths, poses)):
57 | with open(pose, 'r') as f:
58 | p = json.load(f)['camera_data']
59 | c2w = np.array(p['cam2world']).T[:3]
60 | c2w[:, 1:3] *= -1
61 | if 'bricks' in self.root_dir:
62 | c2w[:, 3] -= self.shift
63 | c2w[:, 3] /= 2*self.scale # bound in [-0.5, 0.5]
64 | self.poses += [c2w]
65 |
66 | img = read_image(img_path, self.img_wh)
67 | self.rays += [img]
68 |
69 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?)
70 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4)
71 |
--------------------------------------------------------------------------------
/datasets/sdf/sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | These codes are adapted from torch-ngp (https://github.com/ashawkey/torch-ngp/tree/main)
3 | """
4 |
5 | from torch.utils.data import Dataset
6 |
7 | import numpy as np
8 | import trimesh
9 | import pysdf
10 |
11 |
12 | class SDFDataset(Dataset):
13 | def __init__(self, path, size=100, num_samples=2**18, clip_sdf=None):
14 | super().__init__()
15 | self.path = path
16 |
17 | # load obj
18 | self.mesh = trimesh.load(path, force='mesh')
19 |
20 | # normalize to [-1, 1] (different from instant-sdf where is [0, 1])
21 | vs = self.mesh.vertices
22 | vmin = vs.min(0)
23 | vmax = vs.max(0)
24 | v_center = (vmin + vmax) / 2
25 | v_scale = 2 / np.sqrt(np.sum((vmax - vmin) ** 2)) * 0.95
26 | vs = (vs - v_center[None, :]) * v_scale
27 | self.mesh.vertices = vs
28 |
29 | print(f"[INFO] mesh: {self.mesh.vertices.shape} {self.mesh.faces.shape}")
30 |
31 | if not self.mesh.is_watertight:
32 | print(f"[WARN] mesh is not watertight! SDF maybe incorrect.")
33 |
34 | self.sdf_fn = pysdf.SDF(self.mesh.vertices, self.mesh.faces)
35 |
36 | self.num_samples = num_samples
37 | assert self.num_samples % 8 == 0, "num_samples must be divisible by 8."
38 | self.clip_sdf = clip_sdf
39 |
40 | self.size = size
41 |
42 | def __len__(self):
43 | return self.size
44 |
45 | def __getitem__(self, _):
46 | # online sampling
47 | sdfs = np.zeros((self.num_samples, 1))
48 | # surface
49 | points_surface = self.mesh.sample(self.num_samples * 2 // 3)
50 | # perturb surface
51 | points_surface[self.num_samples // 3:] += 0.01 * np.random.randn(self.num_samples // 3, 3)
52 | # random
53 | points_uniform = np.random.rand(self.num_samples // 3, 3) * 2 - 1
54 | points = np.concatenate([points_surface, points_uniform], axis=0).astype(np.float32)
55 |
56 | sdfs[self.num_samples // 3:] = -self.sdf_fn(points[self.num_samples // 3:])[:,None].astype(np.float32)
57 |
58 | # clip sdf
59 | if self.clip_sdf is not None:
60 | sdfs = sdfs.clip(-self.clip_sdf, self.clip_sdf)
61 |
62 | results = {
63 | 'sdfs': sdfs,
64 | 'points': points,
65 | }
66 |
67 | return results
--------------------------------------------------------------------------------
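A minimal usage sketch (not part of the repo) for the `SDFDataset` above: each `__getitem__` call already draws a full batch of `num_samples` points, so it is typically wrapped in a `DataLoader` with `batch_size=1`. The mesh path below is a placeholder.
```python
# Hypothetical usage of datasets/sdf/sampler.py (sketch only, path is a placeholder).
from torch.utils.data import DataLoader
from datasets.sdf.sampler import SDFDataset

dataset = SDFDataset("path/to/mesh.obj", size=100, num_samples=2**18)
loader = DataLoader(dataset, batch_size=1)   # one item == one full batch of samples

for batch in loader:
    points = batch['points'].squeeze(0)      # (num_samples, 3), float32, in [-1, 1]
    sdfs = batch['sdfs'].squeeze(0)          # (num_samples, 1), signed distances (0 for surface samples)
    # feed (points, sdfs) to the SDF network and loss here
    break
```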
/docs/figures/2d_fitting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/docs/figures/2d_fitting.png
--------------------------------------------------------------------------------
/docs/figures/3d_fitting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/docs/figures/3d_fitting.png
--------------------------------------------------------------------------------
/docs/figures/nvs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/docs/figures/nvs.png
--------------------------------------------------------------------------------
/docs/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/docs/figures/teaser.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/__init__.py
--------------------------------------------------------------------------------
/models/csrc/binding.cpp:
--------------------------------------------------------------------------------
1 | #include "utils.h"
2 |
3 |
4 | std::vector<torch::Tensor> ray_aabb_intersect(
5 | const torch::Tensor rays_o,
6 | const torch::Tensor rays_d,
7 | const torch::Tensor centers,
8 | const torch::Tensor half_sizes,
9 | const int max_hits
10 | ){
11 | CHECK_INPUT(rays_o);
12 | CHECK_INPUT(rays_d);
13 | CHECK_INPUT(centers);
14 | CHECK_INPUT(half_sizes);
15 | return ray_aabb_intersect_cu(rays_o, rays_d, centers, half_sizes, max_hits);
16 | }
17 |
18 |
19 | std::vector<torch::Tensor> ray_sphere_intersect(
20 | const torch::Tensor rays_o,
21 | const torch::Tensor rays_d,
22 | const torch::Tensor centers,
23 | const torch::Tensor radii,
24 | const int max_hits
25 | ){
26 | CHECK_INPUT(rays_o);
27 | CHECK_INPUT(rays_d);
28 | CHECK_INPUT(centers);
29 | CHECK_INPUT(radii);
30 | return ray_sphere_intersect_cu(rays_o, rays_d, centers, radii, max_hits);
31 | }
32 |
33 |
34 | void packbits(
35 | torch::Tensor density_grid,
36 | const float density_threshold,
37 | torch::Tensor density_bitfield
38 | ){
39 | CHECK_INPUT(density_grid);
40 | CHECK_INPUT(density_bitfield);
41 |
42 | return packbits_cu(density_grid, density_threshold, density_bitfield);
43 | }
44 |
45 |
46 | torch::Tensor morton3D(const torch::Tensor coords){
47 | CHECK_INPUT(coords);
48 |
49 | return morton3D_cu(coords);
50 | }
51 |
52 |
53 | torch::Tensor morton3D_invert(const torch::Tensor indices){
54 | CHECK_INPUT(indices);
55 |
56 | return morton3D_invert_cu(indices);
57 | }
58 |
59 |
60 | std::vector<torch::Tensor> raymarching_train(
61 | const torch::Tensor rays_o,
62 | const torch::Tensor rays_d,
63 | const torch::Tensor hits_t,
64 | const torch::Tensor density_bitfield,
65 | const int cascades,
66 | const float scale,
67 | const float exp_step_factor,
68 | const torch::Tensor noise,
69 | const int grid_size,
70 | const int max_samples
71 | ){
72 | CHECK_INPUT(rays_o);
73 | CHECK_INPUT(rays_d);
74 | CHECK_INPUT(hits_t);
75 | CHECK_INPUT(density_bitfield);
76 | CHECK_INPUT(noise);
77 |
78 | return raymarching_train_cu(
79 | rays_o, rays_d, hits_t, density_bitfield, cascades,
80 | scale, exp_step_factor, noise, grid_size, max_samples);
81 | }
82 |
83 |
84 | std::vector<torch::Tensor> raymarching_test(
85 | const torch::Tensor rays_o,
86 | const torch::Tensor rays_d,
87 | torch::Tensor hits_t,
88 | const torch::Tensor alive_indices,
89 | const torch::Tensor density_bitfield,
90 | const int cascades,
91 | const float scale,
92 | const float exp_step_factor,
93 | const int grid_size,
94 | const int max_samples,
95 | const int N_samples
96 | ){
97 | CHECK_INPUT(rays_o);
98 | CHECK_INPUT(rays_d);
99 | CHECK_INPUT(hits_t);
100 | CHECK_INPUT(alive_indices);
101 | CHECK_INPUT(density_bitfield);
102 |
103 | return raymarching_test_cu(
104 | rays_o, rays_d, hits_t, alive_indices, density_bitfield, cascades,
105 | scale, exp_step_factor, grid_size, max_samples, N_samples);
106 | }
107 |
108 |
109 | std::vector<torch::Tensor> composite_train_fw(
110 | const torch::Tensor sigmas,
111 | const torch::Tensor rgbs,
112 | const torch::Tensor deltas,
113 | const torch::Tensor ts,
114 | const torch::Tensor rays_a,
115 | const float opacity_threshold
116 | ){
117 | CHECK_INPUT(sigmas);
118 | CHECK_INPUT(rgbs);
119 | CHECK_INPUT(deltas);
120 | CHECK_INPUT(ts);
121 | CHECK_INPUT(rays_a);
122 |
123 | return composite_train_fw_cu(
124 | sigmas, rgbs, deltas, ts,
125 | rays_a, opacity_threshold);
126 | }
127 |
128 |
129 | std::vector<torch::Tensor> composite_train_bw(
130 | const torch::Tensor dL_dopacity,
131 | const torch::Tensor dL_ddepth,
132 | const torch::Tensor dL_drgb,
133 | const torch::Tensor dL_dws,
134 | const torch::Tensor sigmas,
135 | const torch::Tensor rgbs,
136 | const torch::Tensor ws,
137 | const torch::Tensor deltas,
138 | const torch::Tensor ts,
139 | const torch::Tensor rays_a,
140 | const torch::Tensor opacity,
141 | const torch::Tensor depth,
142 | const torch::Tensor rgb,
143 | const float opacity_threshold
144 | ){
145 | CHECK_INPUT(dL_dopacity);
146 | CHECK_INPUT(dL_ddepth);
147 | CHECK_INPUT(dL_drgb);
148 | CHECK_INPUT(dL_dws);
149 | CHECK_INPUT(sigmas);
150 | CHECK_INPUT(rgbs);
151 | CHECK_INPUT(ws);
152 | CHECK_INPUT(deltas);
153 | CHECK_INPUT(ts);
154 | CHECK_INPUT(rays_a);
155 | CHECK_INPUT(opacity);
156 | CHECK_INPUT(depth);
157 | CHECK_INPUT(rgb);
158 |
159 | return composite_train_bw_cu(
160 | dL_dopacity, dL_ddepth, dL_drgb, dL_dws,
161 | sigmas, rgbs, ws, deltas, ts, rays_a,
162 | opacity, depth, rgb, opacity_threshold);
163 | }
164 |
165 |
166 | void composite_test_fw(
167 | const torch::Tensor sigmas,
168 | const torch::Tensor rgbs,
169 | const torch::Tensor deltas,
170 | const torch::Tensor ts,
171 | const torch::Tensor hits_t,
172 | const torch::Tensor alive_indices,
173 | const float T_threshold,
174 | const torch::Tensor N_eff_samples,
175 | torch::Tensor opacity,
176 | torch::Tensor depth,
177 | torch::Tensor rgb
178 | ){
179 | CHECK_INPUT(sigmas);
180 | CHECK_INPUT(rgbs);
181 | CHECK_INPUT(deltas);
182 | CHECK_INPUT(ts);
183 | CHECK_INPUT(hits_t);
184 | CHECK_INPUT(alive_indices);
185 | CHECK_INPUT(N_eff_samples);
186 | CHECK_INPUT(opacity);
187 | CHECK_INPUT(depth);
188 | CHECK_INPUT(rgb);
189 |
190 | composite_test_fw_cu(
191 | sigmas, rgbs, deltas, ts, hits_t, alive_indices,
192 | T_threshold, N_eff_samples,
193 | opacity, depth, rgb);
194 | }
195 |
196 |
197 | std::vector<torch::Tensor> distortion_loss_fw(
198 | const torch::Tensor ws,
199 | const torch::Tensor deltas,
200 | const torch::Tensor ts,
201 | const torch::Tensor rays_a
202 | ){
203 | CHECK_INPUT(ws);
204 | CHECK_INPUT(deltas);
205 | CHECK_INPUT(ts);
206 | CHECK_INPUT(rays_a);
207 |
208 | return distortion_loss_fw_cu(ws, deltas, ts, rays_a);
209 | }
210 |
211 |
212 | torch::Tensor distortion_loss_bw(
213 | const torch::Tensor dL_dloss,
214 | const torch::Tensor ws_inclusive_scan,
215 | const torch::Tensor wts_inclusive_scan,
216 | const torch::Tensor ws,
217 | const torch::Tensor deltas,
218 | const torch::Tensor ts,
219 | const torch::Tensor rays_a
220 | ){
221 | CHECK_INPUT(dL_dloss);
222 | CHECK_INPUT(ws_inclusive_scan);
223 | CHECK_INPUT(wts_inclusive_scan);
224 | CHECK_INPUT(ws);
225 | CHECK_INPUT(deltas);
226 | CHECK_INPUT(ts);
227 | CHECK_INPUT(rays_a);
228 |
229 | return distortion_loss_bw_cu(dL_dloss, ws_inclusive_scan, wts_inclusive_scan,
230 | ws, deltas, ts, rays_a);
231 | }
232 |
233 |
234 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
235 | m.def("ray_aabb_intersect", &ray_aabb_intersect);
236 | m.def("ray_sphere_intersect", &ray_sphere_intersect);
237 |
238 | m.def("morton3D", &morton3D);
239 | m.def("morton3D_invert", &morton3D_invert);
240 | m.def("packbits", &packbits);
241 |
242 | m.def("raymarching_train", &raymarching_train);
243 | m.def("raymarching_test", &raymarching_test);
244 | m.def("composite_train_fw", &composite_train_fw);
245 | m.def("composite_train_bw", &composite_train_bw);
246 | m.def("composite_test_fw", &composite_test_fw);
247 |
248 | m.def("distortion_loss_fw", &distortion_loss_fw);
249 | m.def("distortion_loss_bw", &distortion_loss_bw);
250 |
251 | }
--------------------------------------------------------------------------------
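A hedged sketch (not part of the repo) of calling a few of the compiled `vren` bindings from Python; the grid sizes mirror the buffers registered in `models/networks/nerf/NFFB_nerf.py`, and the int32 coordinate dtype for the Morton kernels is an assumption.
```python
# Hypothetical sketch: exercising a few vren bindings (assumes the extension is built
# via models/csrc/setup.py and a CUDA device is available).
import torch
import vren

cascades, grid_size = 1, 128
density_grid = torch.rand(cascades * grid_size**3, device='cuda')
density_bitfield = torch.zeros(cascades * grid_size**3 // 8,
                               dtype=torch.uint8, device='cuda')

# pack cells whose density exceeds the threshold into one bit each
vren.packbits(density_grid, 0.01, density_bitfield)

# Morton (Z-order) encode/decode integer grid coordinates (assumed int32)
coords = torch.randint(0, grid_size, (1024, 3), dtype=torch.int32, device='cuda')
indices = vren.morton3D(coords)
coords_back = vren.morton3D_invert(indices)
```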
/models/csrc/include/helper_math.h:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | *
3 | * Redistribution and use in source and binary forms, with or without
4 | * modification, are permitted provided that the following conditions
5 | * are met:
6 | * * Redistributions of source code must retain the above copyright
7 | * notice, this list of conditions and the following disclaimer.
8 | * * Redistributions in binary form must reproduce the above copyright
9 | * notice, this list of conditions and the following disclaimer in the
10 | * documentation and/or other materials provided with the distribution.
11 | * * Neither the name of NVIDIA CORPORATION nor the names of its
12 | * contributors may be used to endorse or promote products derived
13 | * from this software without specific prior written permission.
14 | *
15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23 | * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 | */
27 |
28 | /*
29 | * This file implements common mathematical operations on vector types
30 | * (float3, float4 etc.) since these are not provided as standard by CUDA.
31 | *
32 | * The syntax is modeled on the Cg standard library.
33 | *
34 | * This is part of the Helper library includes
35 | *
36 | * Thanks to Linh Hah for additions and fixes.
37 | */
38 |
39 | #ifndef HELPER_MATH_H
40 | #define HELPER_MATH_H
41 |
42 | #include "cuda_runtime.h"
43 |
44 | typedef unsigned int uint;
45 | typedef unsigned short ushort;
46 |
47 | #ifndef EXIT_WAIVED
48 | #define EXIT_WAIVED 2
49 | #endif
50 |
51 | #ifndef __CUDACC__
52 | #include <math.h>
53 |
54 | ////////////////////////////////////////////////////////////////////////////////
55 | // host implementations of CUDA functions
56 | ////////////////////////////////////////////////////////////////////////////////
57 |
58 | inline float fminf(float a, float b)
59 | {
60 | return a < b ? a : b;
61 | }
62 |
63 | inline float fmaxf(float a, float b)
64 | {
65 | return a > b ? a : b;
66 | }
67 |
68 | inline int max(int a, int b)
69 | {
70 | return a > b ? a : b;
71 | }
72 |
73 | inline int min(int a, int b)
74 | {
75 | return a < b ? a : b;
76 | }
77 |
78 | inline float rsqrtf(float x)
79 | {
80 | return 1.0f / sqrtf(x);
81 | }
82 | #endif
83 |
84 | ////////////////////////////////////////////////////////////////////////////////
85 | // constructors
86 | ////////////////////////////////////////////////////////////////////////////////
87 |
88 | inline __host__ __device__ float2 make_float2(float s)
89 | {
90 | return make_float2(s, s);
91 | }
92 | inline __host__ __device__ float2 make_float2(float3 a)
93 | {
94 | return make_float2(a.x, a.y);
95 | }
96 | inline __host__ __device__ float3 make_float3(float s)
97 | {
98 | return make_float3(s, s, s);
99 | }
100 | inline __host__ __device__ float3 make_float3(float2 a)
101 | {
102 | return make_float3(a.x, a.y, 0.0f);
103 | }
104 | inline __host__ __device__ float3 make_float3(float2 a, float s)
105 | {
106 | return make_float3(a.x, a.y, s);
107 | }
108 |
109 | ////////////////////////////////////////////////////////////////////////////////
110 | // negate
111 | ////////////////////////////////////////////////////////////////////////////////
112 |
113 | inline __host__ __device__ float3 operator-(float3 &a)
114 | {
115 | return make_float3(-a.x, -a.y, -a.z);
116 | }
117 |
118 | ////////////////////////////////////////////////////////////////////////////////
119 | // addition
120 | ////////////////////////////////////////////////////////////////////////////////
121 |
122 | inline __host__ __device__ float3 operator+(float3 a, float3 b)
123 | {
124 | return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
125 | }
126 | inline __host__ __device__ void operator+=(float3 &a, float3 b)
127 | {
128 | a.x += b.x;
129 | a.y += b.y;
130 | a.z += b.z;
131 | }
132 | inline __host__ __device__ float3 operator+(float3 a, float b)
133 | {
134 | return make_float3(a.x + b, a.y + b, a.z + b);
135 | }
136 | inline __host__ __device__ void operator+=(float3 &a, float b)
137 | {
138 | a.x += b;
139 | a.y += b;
140 | a.z += b;
141 | }
142 | inline __host__ __device__ float3 operator+(float b, float3 a)
143 | {
144 | return make_float3(a.x + b, a.y + b, a.z + b);
145 | }
146 |
147 | ////////////////////////////////////////////////////////////////////////////////
148 | // subtract
149 | ////////////////////////////////////////////////////////////////////////////////
150 |
151 | inline __host__ __device__ float3 operator-(float3 a, float3 b)
152 | {
153 | return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
154 | }
155 | inline __host__ __device__ void operator-=(float3 &a, float3 b)
156 | {
157 | a.x -= b.x;
158 | a.y -= b.y;
159 | a.z -= b.z;
160 | }
161 | inline __host__ __device__ float3 operator-(float3 a, float b)
162 | {
163 | return make_float3(a.x - b, a.y - b, a.z - b);
164 | }
165 | inline __host__ __device__ float3 operator-(float b, float3 a)
166 | {
167 | return make_float3(b - a.x, b - a.y, b - a.z);
168 | }
169 | inline __host__ __device__ void operator-=(float3 &a, float b)
170 | {
171 | a.x -= b;
172 | a.y -= b;
173 | a.z -= b;
174 | }
175 |
176 | ////////////////////////////////////////////////////////////////////////////////
177 | // multiply
178 | ////////////////////////////////////////////////////////////////////////////////
179 |
180 | inline __host__ __device__ float3 operator*(float3 a, float3 b)
181 | {
182 | return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
183 | }
184 | inline __host__ __device__ void operator*=(float3 &a, float3 b)
185 | {
186 | a.x *= b.x;
187 | a.y *= b.y;
188 | a.z *= b.z;
189 | }
190 | inline __host__ __device__ float3 operator*(float3 a, float b)
191 | {
192 | return make_float3(a.x * b, a.y * b, a.z * b);
193 | }
194 | inline __host__ __device__ float3 operator*(float b, float3 a)
195 | {
196 | return make_float3(b * a.x, b * a.y, b * a.z);
197 | }
198 | inline __host__ __device__ void operator*=(float3 &a, float b)
199 | {
200 | a.x *= b;
201 | a.y *= b;
202 | a.z *= b;
203 | }
204 |
205 | ////////////////////////////////////////////////////////////////////////////////
206 | // divide
207 | ////////////////////////////////////////////////////////////////////////////////
208 |
209 | inline __host__ __device__ float2 operator/(float2 a, float2 b)
210 | {
211 | return make_float2(a.x / b.x, a.y / b.y);
212 | }
213 | inline __host__ __device__ void operator/=(float2 &a, float2 b)
214 | {
215 | a.x /= b.x;
216 | a.y /= b.y;
217 | }
218 | inline __host__ __device__ float2 operator/(float2 a, float b)
219 | {
220 | return make_float2(a.x / b, a.y / b);
221 | }
222 | inline __host__ __device__ void operator/=(float2 &a, float b)
223 | {
224 | a.x /= b;
225 | a.y /= b;
226 | }
227 | inline __host__ __device__ float2 operator/(float b, float2 a)
228 | {
229 | return make_float2(b / a.x, b / a.y);
230 | }
231 |
232 | inline __host__ __device__ float3 operator/(float3 a, float3 b)
233 | {
234 | return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
235 | }
236 | inline __host__ __device__ void operator/=(float3 &a, float3 b)
237 | {
238 | a.x /= b.x;
239 | a.y /= b.y;
240 | a.z /= b.z;
241 | }
242 | inline __host__ __device__ float3 operator/(float3 a, float b)
243 | {
244 | return make_float3(a.x / b, a.y / b, a.z / b);
245 | }
246 | inline __host__ __device__ void operator/=(float3 &a, float b)
247 | {
248 | a.x /= b;
249 | a.y /= b;
250 | a.z /= b;
251 | }
252 | inline __host__ __device__ float3 operator/(float b, float3 a)
253 | {
254 | return make_float3(b / a.x, b / a.y, b / a.z);
255 | }
256 |
257 | ////////////////////////////////////////////////////////////////////////////////
258 | // min
259 | ////////////////////////////////////////////////////////////////////////////////
260 |
261 | inline __host__ __device__ float3 fminf(float3 a, float3 b)
262 | {
263 | return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
264 | }
265 |
266 | ////////////////////////////////////////////////////////////////////////////////
267 | // max
268 | ////////////////////////////////////////////////////////////////////////////////
269 |
270 | inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
271 | {
272 | return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
273 | }
274 |
275 | ////////////////////////////////////////////////////////////////////////////////
276 | // clamp
277 | // - clamp the value v to be in the range [a, b]
278 | ////////////////////////////////////////////////////////////////////////////////
279 |
280 | inline __device__ __host__ float clamp(float f, float a, float b)
281 | {
282 | return fmaxf(a, fminf(f, b));
283 | }
284 | inline __device__ __host__ int clamp(int f, int a, int b)
285 | {
286 | return max(a, min(f, b));
287 | }
288 | inline __device__ __host__ uint clamp(uint f, uint a, uint b)
289 | {
290 | return max(a, min(f, b));
291 | }
292 |
293 | inline __device__ __host__ float3 clamp(float3 v, float a, float b)
294 | {
295 | return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
296 | }
297 | inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
298 | {
299 | return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
300 | }
301 |
302 | ////////////////////////////////////////////////////////////////////////////////
303 | // dot product
304 | ////////////////////////////////////////////////////////////////////////////////
305 |
306 | inline __host__ __device__ float dot(float3 a, float3 b)
307 | {
308 | return a.x * b.x + a.y * b.y + a.z * b.z;
309 | }
310 |
311 | ////////////////////////////////////////////////////////////////////////////////
312 | // length
313 | ////////////////////////////////////////////////////////////////////////////////
314 |
315 | inline __host__ __device__ float length(float3 v)
316 | {
317 | return sqrtf(dot(v, v));
318 | }
319 |
320 | ////////////////////////////////////////////////////////////////////////////////
321 | // normalize
322 | ////////////////////////////////////////////////////////////////////////////////
323 |
324 | inline __host__ __device__ float3 normalize(float3 v)
325 | {
326 | float invLen = rsqrtf(dot(v, v));
327 | return v * invLen;
328 | }
329 |
330 | ////////////////////////////////////////////////////////////////////////////////
331 | // reflect
332 | // - returns reflection of incident ray I around surface normal N
333 | // - N should be normalized, reflected vector's length is equal to length of I
334 | ////////////////////////////////////////////////////////////////////////////////
335 |
336 | inline __host__ __device__ float3 reflect(float3 i, float3 n)
337 | {
338 | return i - 2.0f * n * dot(n,i);
339 | }
340 |
341 | ////////////////////////////////////////////////////////////////////////////////
342 | // cross product
343 | ////////////////////////////////////////////////////////////////////////////////
344 |
345 | inline __host__ __device__ float3 cross(float3 a, float3 b)
346 | {
347 | return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
348 | }
349 |
350 | ////////////////////////////////////////////////////////////////////////////////
351 | // smoothstep
352 | // - returns 0 if x < a
353 | // - returns 1 if x > b
354 | // - otherwise returns smooth interpolation between 0 and 1 based on x
355 | ////////////////////////////////////////////////////////////////////////////////
356 |
357 | inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
358 | {
359 | float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
360 | return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
361 | }
362 |
363 | #endif
364 |
--------------------------------------------------------------------------------
/models/csrc/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
5 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
7 |
8 |
9 | std::vector<torch::Tensor> ray_aabb_intersect_cu(
10 | const torch::Tensor rays_o,
11 | const torch::Tensor rays_d,
12 | const torch::Tensor centers,
13 | const torch::Tensor half_sizes,
14 | const int max_hits
15 | );
16 |
17 |
18 | std::vector<torch::Tensor> ray_sphere_intersect_cu(
19 | const torch::Tensor rays_o,
20 | const torch::Tensor rays_d,
21 | const torch::Tensor centers,
22 | const torch::Tensor radii,
23 | const int max_hits
24 | );
25 |
26 |
27 | void packbits_cu(
28 | torch::Tensor density_grid,
29 | const float density_threshold,
30 | torch::Tensor density_bitfield
31 | );
32 |
33 |
34 | torch::Tensor morton3D_cu(const torch::Tensor coords);
35 | torch::Tensor morton3D_invert_cu(const torch::Tensor indices);
36 |
37 |
38 | std::vector<torch::Tensor> raymarching_train_cu(
39 | const torch::Tensor rays_o,
40 | const torch::Tensor rays_d,
41 | const torch::Tensor hits_t,
42 | const torch::Tensor density_bitfield,
43 | const int cascades,
44 | const float scale,
45 | const float exp_step_factor,
46 | const torch::Tensor noise,
47 | const int grid_size,
48 | const int max_samples
49 | );
50 |
51 |
52 | std::vector<torch::Tensor> raymarching_test_cu(
53 | const torch::Tensor rays_o,
54 | const torch::Tensor rays_d,
55 | torch::Tensor hits_t,
56 | const torch::Tensor alive_indices,
57 | const torch::Tensor density_bitfield,
58 | const int cascades,
59 | const float scale,
60 | const float exp_step_factor,
61 | const int grid_size,
62 | const int max_samples,
63 | const int N_samples
64 | );
65 |
66 |
67 | std::vector<torch::Tensor> composite_train_fw_cu(
68 | const torch::Tensor sigmas,
69 | const torch::Tensor rgbs,
70 | const torch::Tensor deltas,
71 | const torch::Tensor ts,
72 | const torch::Tensor rays_a,
73 | const float T_threshold
74 | );
75 |
76 |
77 | std::vector<torch::Tensor> composite_train_bw_cu(
78 | const torch::Tensor dL_dopacity,
79 | const torch::Tensor dL_ddepth,
80 | const torch::Tensor dL_drgb,
81 | const torch::Tensor dL_dws,
82 | const torch::Tensor sigmas,
83 | const torch::Tensor rgbs,
84 | const torch::Tensor ws,
85 | const torch::Tensor deltas,
86 | const torch::Tensor ts,
87 | const torch::Tensor rays_a,
88 | const torch::Tensor opacity,
89 | const torch::Tensor depth,
90 | const torch::Tensor rgb,
91 | const float T_threshold
92 | );
93 |
94 |
95 | void composite_test_fw_cu(
96 | const torch::Tensor sigmas,
97 | const torch::Tensor rgbs,
98 | const torch::Tensor deltas,
99 | const torch::Tensor ts,
100 | const torch::Tensor hits_t,
101 | const torch::Tensor alive_indices,
102 | const float T_threshold,
103 | const torch::Tensor N_eff_samples,
104 | torch::Tensor opacity,
105 | torch::Tensor depth,
106 | torch::Tensor rgb
107 | );
108 |
109 |
110 | std::vector<torch::Tensor> distortion_loss_fw_cu(
111 | const torch::Tensor ws,
112 | const torch::Tensor deltas,
113 | const torch::Tensor ts,
114 | const torch::Tensor rays_a
115 | );
116 |
117 |
118 | torch::Tensor distortion_loss_bw_cu(
119 | const torch::Tensor dL_dloss,
120 | const torch::Tensor ws_inclusive_scan,
121 | const torch::Tensor wts_inclusive_scan,
122 | const torch::Tensor ws,
123 | const torch::Tensor deltas,
124 | const torch::Tensor ts,
125 | const torch::Tensor rays_a
126 | );
--------------------------------------------------------------------------------
/models/csrc/intersection.cu:
--------------------------------------------------------------------------------
1 | #include "helper_math.h"
2 | #include "utils.h"
3 |
4 |
5 | __device__ __forceinline__ float2 _ray_aabb_intersect(
6 | const float3 ray_o,
7 | const float3 inv_d,
8 | const float3 center,
9 | const float3 half_size
10 | ){
11 |
12 | const float3 t_min = (center-half_size-ray_o)*inv_d;
13 | const float3 t_max = (center+half_size-ray_o)*inv_d;
14 |
15 | const float3 _t1 = fminf(t_min, t_max);
16 | const float3 _t2 = fmaxf(t_min, t_max);
17 | const float t1 = fmaxf(fmaxf(_t1.x, _t1.y), _t1.z);
18 | const float t2 = fminf(fminf(_t2.x, _t2.y), _t2.z);
19 |
20 | if (t1 > t2) return make_float2(-1.0f); // no intersection
21 | return make_float2(t1, t2);
22 | }
23 |
24 |
25 | __global__ void ray_aabb_intersect_kernel(
26 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_o,
27 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_d,
28 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> centers,
29 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> half_sizes,
30 |     const int max_hits,
31 |     int* __restrict__ hit_cnt,
32 |     torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> hits_t,
33 |     torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> hits_voxel_idx
34 | ){
35 | const int r = blockIdx.x * blockDim.x + threadIdx.x;
36 | const int v = blockIdx.y * blockDim.y + threadIdx.y;
37 |
38 | if (v>=centers.size(0) || r>=rays_o.size(0)) return;
39 |
40 | const float3 ray_o = make_float3(rays_o[r][0], rays_o[r][1], rays_o[r][2]);
41 | const float3 ray_d = make_float3(rays_d[r][0], rays_d[r][1], rays_d[r][2]);
42 | const float3 inv_d = 1.0f/ray_d;
43 |
44 | const float3 center = make_float3(centers[v][0], centers[v][1], centers[v][2]);
45 | const float3 half_size = make_float3(half_sizes[v][0], half_sizes[v][1], half_sizes[v][2]);
46 | const float2 t1t2 = _ray_aabb_intersect(ray_o, inv_d, center, half_size);
47 |
48 | if (t1t2.y > 0){ // if ray hits the voxel
49 | const int cnt = atomicAdd(&hit_cnt[r], 1);
50 | if (cnt < max_hits){
51 | hits_t[r][cnt][0] = fmaxf(t1t2.x, 0.0f);
52 | hits_t[r][cnt][1] = t1t2.y;
53 | hits_voxel_idx[r][cnt] = v;
54 | }
55 | }
56 | }
57 |
58 |
59 | std::vector<torch::Tensor> ray_aabb_intersect_cu(
60 | const torch::Tensor rays_o,
61 | const torch::Tensor rays_d,
62 | const torch::Tensor centers,
63 | const torch::Tensor half_sizes,
64 | const int max_hits
65 | ){
66 |
67 | const int N_rays = rays_o.size(0), N_voxels = centers.size(0);
68 | auto hits_t = torch::zeros({N_rays, max_hits, 2}, rays_o.options())-1;
69 | auto hits_voxel_idx =
70 | torch::zeros({N_rays, max_hits},
71 | torch::dtype(torch::kLong).device(rays_o.device()))-1;
72 | auto hit_cnt =
73 | torch::zeros({N_rays},
74 | torch::dtype(torch::kInt32).device(rays_o.device()));
75 |
76 | const dim3 threads(256, 1);
77 | const dim3 blocks((N_rays+threads.x-1)/threads.x,
78 | (N_voxels+threads.y-1)/threads.y);
79 |
80 | AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "ray_aabb_intersect_cu",
81 | ([&] {
82 |         ray_aabb_intersect_kernel<<<blocks, threads>>>(
83 |             rays_o.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
84 |             rays_d.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
85 |             centers.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
86 |             half_sizes.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
87 |             max_hits,
88 |             hit_cnt.data_ptr<int>(),
89 |             hits_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
90 |             hits_voxel_idx.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>()
91 | );
92 | }));
93 |
94 | // sort intersections from near to far based on t1
95 | auto hits_order = std::get<1>(torch::sort(hits_t.index({"...", 0})));
96 | hits_voxel_idx = torch::gather(hits_voxel_idx, 1, hits_order);
97 | hits_t = torch::gather(hits_t, 1, hits_order.unsqueeze(-1).tile({1, 1, 2}));
98 |
99 | return {hit_cnt, hits_t, hits_voxel_idx};
100 | }
101 |
102 |
103 | __device__ __forceinline__ float2 _ray_sphere_intersect(
104 | const float3 ray_o,
105 | const float3 ray_d,
106 | const float3 center,
107 | const float radius
108 | ){
109 | const float3 co = ray_o-center;
110 |
111 | const float a = dot(ray_d, ray_d);
112 | const float half_b = dot(ray_d, co);
113 | const float c = dot(co, co)-radius*radius;
114 |
115 | const float discriminant = half_b*half_b-a*c;
116 |
117 | if (discriminant < 0) return make_float2(-1.0f); // no intersection
118 |
119 | const float disc_sqrt = sqrtf(discriminant);
120 | return make_float2(-half_b-disc_sqrt, -half_b+disc_sqrt)/a;
121 | }
122 |
123 |
124 | __global__ void ray_sphere_intersect_kernel(
125 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_o,
126 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_d,
127 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> centers,
128 |     const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> radii,
129 |     const int max_hits,
130 |     int* __restrict__ hit_cnt,
131 |     torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> hits_t,
132 |     torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> hits_sphere_idx
133 | ){
134 | const int r = blockIdx.x * blockDim.x + threadIdx.x;
135 | const int s = blockIdx.y * blockDim.y + threadIdx.y;
136 |
137 | if (s>=centers.size(0) || r>=rays_o.size(0)) return;
138 |
139 | const float3 ray_o = make_float3(rays_o[r][0], rays_o[r][1], rays_o[r][2]);
140 | const float3 ray_d = make_float3(rays_d[r][0], rays_d[r][1], rays_d[r][2]);
141 | const float3 center = make_float3(centers[s][0], centers[s][1], centers[s][2]);
142 |
143 | const float2 t1t2 = _ray_sphere_intersect(ray_o, ray_d, center, radii[s]);
144 |
145 | if (t1t2.y > 0){ // if ray hits the sphere
146 | const int cnt = atomicAdd(&hit_cnt[r], 1);
147 | if (cnt < max_hits){
148 | hits_t[r][cnt][0] = fmaxf(t1t2.x, 0.0f);
149 | hits_t[r][cnt][1] = t1t2.y;
150 | hits_sphere_idx[r][cnt] = s;
151 | }
152 | }
153 | }
154 |
155 |
156 | std::vector<torch::Tensor> ray_sphere_intersect_cu(
157 | const torch::Tensor rays_o,
158 | const torch::Tensor rays_d,
159 | const torch::Tensor centers,
160 | const torch::Tensor radii,
161 | const int max_hits
162 | ){
163 |
164 | const int N_rays = rays_o.size(0), N_spheres = centers.size(0);
165 | auto hits_t = torch::zeros({N_rays, max_hits, 2}, rays_o.options())-1;
166 | auto hits_sphere_idx =
167 | torch::zeros({N_rays, max_hits},
168 | torch::dtype(torch::kLong).device(rays_o.device()))-1;
169 | auto hit_cnt =
170 | torch::zeros({N_rays},
171 | torch::dtype(torch::kInt32).device(rays_o.device()));
172 |
173 | const dim3 threads(256, 1);
174 | const dim3 blocks((N_rays+threads.x-1)/threads.x,
175 | (N_spheres+threads.y-1)/threads.y);
176 |
177 | AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "ray_sphere_intersect_cu",
178 | ([&] {
179 |         ray_sphere_intersect_kernel<<<blocks, threads>>>(
180 |             rays_o.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
181 |             rays_d.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
182 |             centers.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
183 |             radii.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
184 |             max_hits,
185 |             hit_cnt.data_ptr<int>(),
186 |             hits_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
187 |             hits_sphere_idx.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>()
188 | );
189 | }));
190 |
191 | // sort intersections from near to far based on t1
192 | auto hits_order = std::get<1>(torch::sort(hits_t.index({"...", 0})));
193 | hits_sphere_idx = torch::gather(hits_sphere_idx, 1, hits_order);
194 | hits_t = torch::gather(hits_t, 1, hits_order.unsqueeze(-1).tile({1, 1, 2}));
195 |
196 | return {hit_cnt, hits_t, hits_sphere_idx};
197 | }
--------------------------------------------------------------------------------
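A hedged sketch (not part of the repo) of calling `ray_aabb_intersect` from Python once the extension is built; the shapes follow the tensors allocated in `ray_aabb_intersect_cu` above.
```python
# Hypothetical sketch: ray/AABB intersection through the vren extension.
import torch
import vren

N_rays, N_voxels, max_hits = 4096, 8, 4
rays_o = torch.rand(N_rays, 3, device='cuda')
rays_d = torch.nn.functional.normalize(torch.randn(N_rays, 3, device='cuda'), dim=-1)
centers = torch.zeros(N_voxels, 3, device='cuda')
half_sizes = torch.full((N_voxels, 3), 0.5, device='cuda')

hit_cnt, hits_t, hits_voxel_idx = vren.ray_aabb_intersect(
    rays_o, rays_d, centers, half_sizes, max_hits)

# hit_cnt:        (N_rays,)             int32, number of voxels each ray hits
# hits_t:         (N_rays, max_hits, 2) entry/exit distances, -1 where unused
# hits_voxel_idx: (N_rays, max_hits)    voxel index per hit, sorted near to far
```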
/models/csrc/losses.cu:
--------------------------------------------------------------------------------
1 | #include "utils.h"
2 | #include <thrust/execution_policy.h>
3 | #include <thrust/scan.h>
4 | #include <thrust/reduce.h>
5 |
6 |
7 | // for details of the formulae, please see https://arxiv.org/pdf/2206.05085.pdf
8 |
9 | template <typename scalar_t>
10 | __global__ void prefix_sums_kernel(
11 | const scalar_t* __restrict__ ws,
12 | const scalar_t* __restrict__ wts,
13 |     const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
14 | scalar_t* __restrict__ ws_inclusive_scan,
15 | scalar_t* __restrict__ ws_exclusive_scan,
16 | scalar_t* __restrict__ wts_inclusive_scan,
17 | scalar_t* __restrict__ wts_exclusive_scan
18 | ){
19 | const int n = blockIdx.x * blockDim.x + threadIdx.x;
20 | if (n >= rays_a.size(0)) return;
21 |
22 | const int start_idx = rays_a[n][1], N_samples = rays_a[n][2];
23 |
24 | // compute prefix sum of ws and ws*ts
25 | // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...]
26 | thrust::inclusive_scan(thrust::device,
27 | ws+start_idx,
28 | ws+start_idx+N_samples,
29 | ws_inclusive_scan+start_idx);
30 | thrust::inclusive_scan(thrust::device,
31 | wts+start_idx,
32 | wts+start_idx+N_samples,
33 | wts_inclusive_scan+start_idx);
34 | // [a0, a1, a2, a3, ...] -> [0, a0, a0+a1, a0+a1+a2, ...]
35 | thrust::exclusive_scan(thrust::device,
36 | ws+start_idx,
37 | ws+start_idx+N_samples,
38 | ws_exclusive_scan+start_idx);
39 | thrust::exclusive_scan(thrust::device,
40 | wts+start_idx,
41 | wts+start_idx+N_samples,
42 | wts_exclusive_scan+start_idx);
43 | }
44 |
45 |
46 | template <typename scalar_t>
47 | __global__ void distortion_loss_fw_kernel(
48 | const scalar_t* __restrict__ _loss,
49 |     const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
50 |     torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> loss
51 | ){
52 | const int n = blockIdx.x * blockDim.x + threadIdx.x;
53 | if (n >= rays_a.size(0)) return;
54 |
55 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2];
56 |
57 | loss[ray_idx] = thrust::reduce(thrust::device,
58 | _loss+start_idx,
59 | _loss+start_idx+N_samples,
60 | (scalar_t)0);
61 | }
62 |
63 |
64 | std::vector<torch::Tensor> distortion_loss_fw_cu(
65 | const torch::Tensor ws,
66 | const torch::Tensor deltas,
67 | const torch::Tensor ts,
68 | const torch::Tensor rays_a
69 | ){
70 | const int N_rays = rays_a.size(0), N = ws.size(0);
71 |
72 | auto wts = ws * ts;
73 |
74 | auto ws_inclusive_scan = torch::zeros({N}, ws.options());
75 | auto ws_exclusive_scan = torch::zeros({N}, ws.options());
76 | auto wts_inclusive_scan = torch::zeros({N}, ws.options());
77 | auto wts_exclusive_scan = torch::zeros({N}, ws.options());
78 |
79 | const int threads = 256, blocks = (N_rays+threads-1)/threads;
80 |
81 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_fw_cu_prefix_sums",
82 | ([&] {
83 |         prefix_sums_kernel<scalar_t><<<blocks, threads>>>(
84 |             ws.data_ptr<scalar_t>(),
85 |             wts.data_ptr<scalar_t>(),
86 |             rays_a.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>(),
87 |             ws_inclusive_scan.data_ptr<scalar_t>(),
88 |             ws_exclusive_scan.data_ptr<scalar_t>(),
89 |             wts_inclusive_scan.data_ptr<scalar_t>(),
90 |             wts_exclusive_scan.data_ptr<scalar_t>()
91 | );
92 | }));
93 |
94 | auto _loss = 2*(wts_inclusive_scan*ws_exclusive_scan-
95 | ws_inclusive_scan*wts_exclusive_scan) + 1.0f/3*ws*ws*deltas;
96 |
97 | auto loss = torch::zeros({N_rays}, ws.options());
98 |
99 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_fw_cu",
100 | ([&] {
101 |         distortion_loss_fw_kernel<scalar_t><<<blocks, threads>>>(
102 |             _loss.data_ptr<scalar_t>(),
103 |             rays_a.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>(),
104 |             loss.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>()
105 | );
106 | }));
107 |
108 | return {loss, ws_inclusive_scan, wts_inclusive_scan};
109 | }
110 |
111 |
112 | template <typename scalar_t>
113 | __global__ void distortion_loss_bw_kernel(
114 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_dloss,
115 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> ws_inclusive_scan,
116 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> wts_inclusive_scan,
117 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> ws,
118 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> deltas,
119 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> ts,
120 |     const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
121 |     torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_dws
122 | ){
123 | const int n = blockIdx.x * blockDim.x + threadIdx.x;
124 | if (n >= rays_a.size(0)) return;
125 |
126 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2];
127 | const int end_idx = start_idx+N_samples-1;
128 |
129 | const scalar_t ws_sum = ws_inclusive_scan[end_idx];
130 | const scalar_t wts_sum = wts_inclusive_scan[end_idx];
131 | // fill in dL_dws from start_idx to end_idx
132 | for (int s=start_idx; s<=end_idx; s++){
133 | dL_dws[s] = dL_dloss[ray_idx] * 2 * (
134 | (s==start_idx?
135 | (scalar_t)0:
136 | (ts[s]*ws_inclusive_scan[s-1]-wts_inclusive_scan[s-1])
137 | ) +
138 | (wts_sum-wts_inclusive_scan[s]-ts[s]*(ws_sum-ws_inclusive_scan[s]))
139 | );
140 | dL_dws[s] += dL_dloss[ray_idx] * (scalar_t)2/3*ws[s]*deltas[s];
141 | }
142 | }
143 |
144 |
145 | torch::Tensor distortion_loss_bw_cu(
146 | const torch::Tensor dL_dloss,
147 | const torch::Tensor ws_inclusive_scan,
148 | const torch::Tensor wts_inclusive_scan,
149 | const torch::Tensor ws,
150 | const torch::Tensor deltas,
151 | const torch::Tensor ts,
152 | const torch::Tensor rays_a
153 | ){
154 | const int N_rays = rays_a.size(0), N = ws.size(0);
155 |
156 | auto dL_dws = torch::zeros({N}, dL_dloss.options());
157 |
158 | const int threads = 256, blocks = (N_rays+threads-1)/threads;
159 |
160 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_bw_cu",
161 | ([&] {
162 |         distortion_loss_bw_kernel<scalar_t><<<blocks, threads>>>(
163 |             dL_dloss.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
164 |             ws_inclusive_scan.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
165 |             wts_inclusive_scan.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
166 |             ws.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
167 |             deltas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
168 |             ts.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
169 |             rays_a.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>(),
170 |             dL_dws.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>()
171 | );
172 | }));
173 |
174 | return dL_dws;
175 | }
--------------------------------------------------------------------------------
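For reference, the quantity assembled by `distortion_loss_fw_cu` above is the per-ray distortion loss of Mip-NeRF 360, written in the prefix-sum form of the DVGO-v2 paper linked in the header comment:
```latex
\mathcal{L}_{\text{dist}}
  = \sum_{i,j} w_i\, w_j\, \lvert t_i - t_j \rvert
  + \frac{1}{3}\sum_i w_i^2\, \delta_i,
\qquad
\sum_{i,j} w_i w_j \lvert t_i - t_j \rvert
  = 2\sum_i w_i \Big( t_i \sum_{j<i} w_j - \sum_{j<i} w_j t_j \Big),
```
which is exactly the `2*(wts_inclusive_scan*ws_exclusive_scan - ws_inclusive_scan*wts_exclusive_scan) + 1/3*ws*ws*deltas` expression, reduced per ray.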
/models/csrc/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | from setuptools import setup
4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
5 |
6 |
7 | ROOT_DIR = osp.dirname(osp.abspath(__file__))
8 | include_dirs = [osp.join(ROOT_DIR, "include")]
9 | # "helper_math.h" is copied from https://github.com/NVIDIA/cuda-samples/blob/master/Common/helper_math.h
10 |
11 | sources = glob.glob('*.cpp')+glob.glob('*.cu')
12 |
13 |
14 | setup(
15 | name='vren',
16 | version='2.0',
17 | author='kwea123',
18 | author_email='kwea123@gmail.com',
19 | description='cuda volume rendering library',
20 | long_description='cuda volume rendering library',
21 | ext_modules=[
22 | CUDAExtension(
23 | name='vren',
24 | sources=sources,
25 | include_dirs=include_dirs,
26 | extra_compile_args={'cxx': ['-O2'],
27 | 'nvcc': ['-O2']}
28 | )
29 | ],
30 | cmdclass={
31 | 'build_ext': BuildExtension
32 | }
33 | )
--------------------------------------------------------------------------------
/models/csrc/volumerendering.cu:
--------------------------------------------------------------------------------
1 | #include "utils.h"
2 | #include <thrust/execution_policy.h>
3 | #include <thrust/scan.h>
4 |
5 |
6 | template <typename scalar_t>
7 | __global__ void composite_train_fw_kernel(
8 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> sigmas,
9 |     const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgbs,
10 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> deltas,
11 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> ts,
12 |     const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
13 |     const scalar_t T_threshold,
14 |     torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> total_samples,
15 |     torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> opacity,
16 |     torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> depth,
17 |     torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgb,
18 |     torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> ws
19 | ){
20 | const int n = blockIdx.x * blockDim.x + threadIdx.x;
21 | if (n >= opacity.size(0)) return;
22 |
23 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2];
24 |
25 | // front to back compositing
26 | int samples = 0; scalar_t T = 1.0f;
27 |
28 | while (samples < N_samples) {
29 | const int s = start_idx + samples;
30 | const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]);
31 | const scalar_t w = a * T; // weight of the sample point
32 |
33 | rgb[ray_idx][0] += w*rgbs[s][0];
34 | rgb[ray_idx][1] += w*rgbs[s][1];
35 | rgb[ray_idx][2] += w*rgbs[s][2];
36 | depth[ray_idx] += w*ts[s];
37 | opacity[ray_idx] += w;
38 | ws[s] = w;
39 | T *= 1.0f-a;
40 |
41 | if (T <= T_threshold) break; // ray has enough opacity
42 | samples++;
43 | }
44 | total_samples[ray_idx] = samples;
45 | }
46 |
47 |
48 | std::vector<torch::Tensor> composite_train_fw_cu(
49 | const torch::Tensor sigmas,
50 | const torch::Tensor rgbs,
51 | const torch::Tensor deltas,
52 | const torch::Tensor ts,
53 | const torch::Tensor rays_a,
54 | const float T_threshold
55 | ){
56 | const int N_rays = rays_a.size(0), N = sigmas.size(0);
57 |
58 | auto opacity = torch::zeros({N_rays}, sigmas.options());
59 | auto depth = torch::zeros({N_rays}, sigmas.options());
60 | auto rgb = torch::zeros({N_rays, 3}, sigmas.options());
61 | auto ws = torch::zeros({N}, sigmas.options());
62 | auto total_samples = torch::zeros({N_rays}, torch::dtype(torch::kLong).device(sigmas.device()));
63 |
64 | const int threads = 256, blocks = (N_rays+threads-1)/threads;
65 |
66 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_train_fw_cu",
67 | ([&] {
68 |         composite_train_fw_kernel<scalar_t><<<blocks, threads>>>(
69 |             sigmas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
70 |             rgbs.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
71 |             deltas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
72 |             ts.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
73 |             rays_a.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>(),
74 |             T_threshold,
75 |             total_samples.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
76 |             opacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
77 |             depth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
78 |             rgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
79 |             ws.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>()
80 | );
81 | }));
82 |
83 | return {total_samples, opacity, depth, rgb, ws};
84 | }
85 |
86 |
87 | template <typename scalar_t>
88 | __global__ void composite_train_bw_kernel(
89 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_dopacity,
90 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_ddepth,
91 |     const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> dL_drgb,
92 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_dws,
93 |     scalar_t* __restrict__ dL_dws_times_ws,
94 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> sigmas,
95 |     const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgbs,
96 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> deltas,
97 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> ts,
98 |     const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
99 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> opacity,
100 |     const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> depth,
101 |     const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgb,
102 |     const scalar_t T_threshold,
103 |     torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_dsigmas,
104 |     torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> dL_drgbs
105 | ){
106 | const int n = blockIdx.x * blockDim.x + threadIdx.x;
107 | if (n >= opacity.size(0)) return;
108 |
109 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2];
110 |
111 | // front to back compositing
112 | int samples = 0;
113 | scalar_t R = rgb[ray_idx][0], G = rgb[ray_idx][1], B = rgb[ray_idx][2];
114 | scalar_t O = opacity[ray_idx], D = depth[ray_idx];
115 | scalar_t T = 1.0f, r = 0.0f, g = 0.0f, b = 0.0f, d = 0.0f;
116 |
117 | // compute prefix sum of dL_dws * ws
118 | // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...]
119 | thrust::inclusive_scan(thrust::device,
120 | dL_dws_times_ws+start_idx,
121 | dL_dws_times_ws+start_idx+N_samples,
122 | dL_dws_times_ws+start_idx);
123 | scalar_t dL_dws_times_ws_sum = dL_dws_times_ws[start_idx+N_samples-1];
124 |
125 | while (samples < N_samples) {
126 | const int s = start_idx + samples;
127 | const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]);
128 | const scalar_t w = a * T;
129 |
130 | r += w*rgbs[s][0]; g += w*rgbs[s][1]; b += w*rgbs[s][2];
131 | d += w*ts[s];
132 | T *= 1.0f-a;
133 |
134 | // compute gradients by math...
135 | dL_drgbs[s][0] = dL_drgb[ray_idx][0]*w;
136 | dL_drgbs[s][1] = dL_drgb[ray_idx][1]*w;
137 | dL_drgbs[s][2] = dL_drgb[ray_idx][2]*w;
138 |
139 | dL_dsigmas[s] = deltas[s] * (
140 | dL_drgb[ray_idx][0]*(rgbs[s][0]*T-(R-r)) +
141 | dL_drgb[ray_idx][1]*(rgbs[s][1]*T-(G-g)) +
142 | dL_drgb[ray_idx][2]*(rgbs[s][2]*T-(B-b)) + // gradients from rgb
143 | dL_dopacity[ray_idx]*(1-O) + // gradient from opacity
144 | dL_ddepth[ray_idx]*(ts[s]*T-(D-d)) + // gradient from depth
145 | T*dL_dws[s]-(dL_dws_times_ws_sum-dL_dws_times_ws[s]) // gradient from ws
146 | );
147 |
148 | if (T <= T_threshold) break; // ray has enough opacity
149 | samples++;
150 | }
151 | }
152 |
153 |
154 | std::vector<torch::Tensor> composite_train_bw_cu(
155 | const torch::Tensor dL_dopacity,
156 | const torch::Tensor dL_ddepth,
157 | const torch::Tensor dL_drgb,
158 | const torch::Tensor dL_dws,
159 | const torch::Tensor sigmas,
160 | const torch::Tensor rgbs,
161 | const torch::Tensor ws,
162 | const torch::Tensor deltas,
163 | const torch::Tensor ts,
164 | const torch::Tensor rays_a,
165 | const torch::Tensor opacity,
166 | const torch::Tensor depth,
167 | const torch::Tensor rgb,
168 | const float T_threshold
169 | ){
170 | const int N = sigmas.size(0), N_rays = rays_a.size(0);
171 |
172 | auto dL_dsigmas = torch::zeros({N}, sigmas.options());
173 | auto dL_drgbs = torch::zeros({N, 3}, sigmas.options());
174 |
175 | auto dL_dws_times_ws = dL_dws * ws; // auxiliary input
176 |
177 | const int threads = 256, blocks = (N_rays+threads-1)/threads;
178 |
179 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_train_bw_cu",
180 | ([&] {
181 |         composite_train_bw_kernel<scalar_t><<<blocks, threads>>>(
182 |             dL_dopacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
183 |             dL_ddepth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
184 |             dL_drgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
185 |             dL_dws.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
186 |             dL_dws_times_ws.data_ptr<scalar_t>(),
187 |             sigmas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
188 |             rgbs.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
189 |             deltas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
190 |             ts.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
191 |             rays_a.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>(),
192 |             opacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
193 |             depth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
194 |             rgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
195 |             T_threshold,
196 |             dL_dsigmas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
197 |             dL_drgbs.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()
198 | );
199 | }));
200 |
201 | return {dL_dsigmas, dL_drgbs};
202 | }
203 |
204 |
205 | template <typename scalar_t>
206 | __global__ void composite_test_fw_kernel(
207 |     const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> sigmas,
208 |     const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> rgbs,
209 |     const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> deltas,
210 |     const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> ts,
211 |     const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> hits_t,
212 |     torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> alive_indices,
213 |     const scalar_t T_threshold,
214 |     const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> N_eff_samples,
215 |     torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> opacity,
216 |     torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> depth,
217 |     torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgb
218 | ){
219 | const int n = blockIdx.x * blockDim.x + threadIdx.x;
220 | if (n >= alive_indices.size(0)) return;
221 |
222 | if (N_eff_samples[n]==0){ // no hit
223 | alive_indices[n] = -1;
224 | return;
225 | }
226 |
227 | const size_t r = alive_indices[n]; // ray index
228 |
229 | // front to back compositing
230 | int s = 0; scalar_t T = 1-opacity[r];
231 |
232 | while (s < N_eff_samples[n]) {
233 | const scalar_t a = 1.0f - __expf(-sigmas[n][s]*deltas[n][s]);
234 | const scalar_t w = a * T;
235 |
236 | rgb[r][0] += w*rgbs[n][s][0];
237 | rgb[r][1] += w*rgbs[n][s][1];
238 | rgb[r][2] += w*rgbs[n][s][2];
239 | depth[r] += w*ts[n][s];
240 | opacity[r] += w;
241 | T *= 1.0f-a;
242 |
243 | if (T <= T_threshold){ // ray has enough opacity
244 | alive_indices[n] = -1;
245 | break;
246 | }
247 | s++;
248 | }
249 | }
250 |
251 |
252 | void composite_test_fw_cu(
253 | const torch::Tensor sigmas,
254 | const torch::Tensor rgbs,
255 | const torch::Tensor deltas,
256 | const torch::Tensor ts,
257 | const torch::Tensor hits_t,
258 | torch::Tensor alive_indices,
259 | const float T_threshold,
260 | const torch::Tensor N_eff_samples,
261 | torch::Tensor opacity,
262 | torch::Tensor depth,
263 | torch::Tensor rgb
264 | ){
265 | const int N_rays = alive_indices.size(0);
266 |
267 | const int threads = 256, blocks = (N_rays+threads-1)/threads;
268 |
269 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_test_fw_cu",
270 | ([&] {
271 |         composite_test_fw_kernel<scalar_t><<<blocks, threads>>>(
272 |             sigmas.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
273 |             rgbs.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
274 |             deltas.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
275 |             ts.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
276 |             hits_t.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
277 |             alive_indices.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
278 |             T_threshold,
279 |             N_eff_samples.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
280 |             opacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
281 |             depth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
282 |             rgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()
283 | );
284 | }));
285 | }
--------------------------------------------------------------------------------
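For reference, the compositing kernels above implement the standard front-to-back volume-rendering quadrature; per sample $i$ along a ray,
```latex
\alpha_i = 1 - e^{-\sigma_i \delta_i}, \qquad
T_i = \prod_{j<i} (1 - \alpha_j), \qquad
w_i = T_i\, \alpha_i,
```
and the accumulated outputs are $\text{rgb} = \sum_i w_i\,\mathbf{c}_i$, $\text{depth} = \sum_i w_i t_i$, $\text{opacity} = \sum_i w_i$, with each loop terminating once the transmittance $T$ falls below `T_threshold`.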
/models/loss/nerf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/loss/nerf/__init__.py
--------------------------------------------------------------------------------
/models/loss/nerf/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import vren
4 |
5 |
6 | class DistortionLoss(torch.autograd.Function):
7 | """
8 | Distortion loss proposed in Mip-NeRF 360 (https://arxiv.org/pdf/2111.12077.pdf)
9 | Implementation is based on DVGO-v2 (https://arxiv.org/pdf/2206.05085.pdf)
10 |
11 | Inputs:
12 | ws: (N) sample point weights
13 | deltas: (N) considered as intervals
14 | ts: (N) considered as midpoints
15 | rays_a: (N_rays, 3) ray_idx, start_idx, N_samples
16 | meaning each entry corresponds to the @ray_idx th ray,
17 | whose samples are [start_idx:start_idx+N_samples]
18 |
19 | Outputs:
20 | loss: (N_rays)
21 | """
22 | @staticmethod
23 | def forward(ctx, ws, deltas, ts, rays_a):
24 | loss, ws_inclusive_scan, wts_inclusive_scan = \
25 | vren.distortion_loss_fw(ws, deltas, ts, rays_a)
26 | ctx.save_for_backward(ws_inclusive_scan, wts_inclusive_scan,
27 | ws, deltas, ts, rays_a)
28 | return loss
29 |
30 | @staticmethod
31 | def backward(ctx, dL_dloss):
32 | (ws_inclusive_scan, wts_inclusive_scan,
33 | ws, deltas, ts, rays_a) = ctx.saved_tensors
34 | dL_dws = vren.distortion_loss_bw(dL_dloss, ws_inclusive_scan,
35 | wts_inclusive_scan,
36 | ws, deltas, ts, rays_a)
37 | return dL_dws, None, None, None
38 |
39 |
40 | class NeRFLoss(nn.Module):
41 | def __init__(self, lambda_opacity=1e-3, lambda_distortion=1e-3):
42 | super().__init__()
43 |
44 | self.lambda_opacity = lambda_opacity
45 | self.lambda_distortion = lambda_distortion
46 |
47 | def forward(self, results, target, **kwargs):
48 | d = {}
49 | d['rgb'] = (results['rgb']-target['rgb'])**2
50 |
51 | o = results['opacity']+1e-10
52 |         # encourage opacity to be either 0 or 1 to avoid floaters
53 | d['opacity'] = self.lambda_opacity*(-o*torch.log(o))
54 |
55 | if self.lambda_distortion > 0:
56 | d['distortion'] = self.lambda_distortion * \
57 | DistortionLoss.apply(results['ws'], results['deltas'],
58 | results['ts'], results['rays_a'])
59 |
60 | return d
61 |
--------------------------------------------------------------------------------
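A minimal sketch (not part of the repo) of reducing the per-term dictionary returned by `NeRFLoss` to a single scalar; the tensors are dummies and the distortion term is disabled so no CUDA kernel is launched (importing the module still requires the built `vren` extension).
```python
# Hypothetical reduction of the NeRFLoss output (sketch only).
import torch
from models.loss.nerf.losses import NeRFLoss

loss_fn = NeRFLoss(lambda_opacity=1e-3, lambda_distortion=0)  # distortion off for the sketch

results = {'rgb': torch.rand(1024, 3, requires_grad=True),
           'opacity': torch.rand(1024)}
target = {'rgb': torch.rand(1024, 3)}

loss_d = loss_fn(results, target)              # {'rgb': (1024, 3), 'opacity': (1024,)}
loss = sum(l.mean() for l in loss_d.values())  # single scalar for backprop
loss.backward()
```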
/models/networks/FFB_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import tinycudann as tcnn
5 | import math
6 |
7 | from models.networks.Sine import Sine, sine_init, first_layer_sine_init
8 |
9 |
10 | class FFB_encoder(nn.Module):
11 | def __init__(self, encoding_config, network_config, n_input_dims, bound=1.0, has_out=True):
12 | super().__init__()
13 |
14 | self.bound = bound
15 |
16 | ### The encoder part
17 | sin_dims = network_config["dims"]
18 | sin_dims = [n_input_dims] + sin_dims
19 | self.num_sin_layers = len(sin_dims)
20 |
21 | feat_dim = encoding_config["feat_dim"]
22 | base_resolution = encoding_config["base_resolution"]
23 | per_level_scale = encoding_config["per_level_scale"]
24 |
25 | assert self.num_sin_layers > 3, "The layer number (SIREN branch) should be greater than 3."
26 | grid_level = int(self.num_sin_layers - 2)
27 | self.grid_encoder = tcnn.Encoding(
28 | n_input_dims=n_input_dims,
29 | encoding_config={
30 | "otype": "HashGrid",
31 | "n_levels": grid_level,
32 | "n_features_per_level": feat_dim,
33 | "log2_hashmap_size": 19,
34 | "base_resolution": base_resolution,
35 | "per_level_scale": per_level_scale,
36 | },
37 | )
38 | self.grid_level = grid_level
39 | print(f"Grid encoder levels: {grid_level}")
40 |
41 | self.feat_dim = feat_dim
42 |
43 |         ### Create the Fourier feature matrices (ffn) that map low-dim grid feats to high-dim SIREN feats
44 | base_sigma = encoding_config["base_sigma"]
45 | exp_sigma = encoding_config["exp_sigma"]
46 |
47 | ffn_list = []
48 | for i in range(grid_level):
49 | ffn = torch.randn((feat_dim, sin_dims[2 + i]), requires_grad=True) * base_sigma * exp_sigma ** i
50 |
51 | ffn_list.append(ffn)
52 |
53 | self.ffn = nn.Parameter(torch.stack(ffn_list, dim=0))
54 |
55 |
56 | ### The low-frequency MLP part
57 | for layer in range(0, self.num_sin_layers - 1):
58 | setattr(self, "sin_lin" + str(layer), nn.Linear(sin_dims[layer], sin_dims[layer + 1]))
59 |
60 | self.sin_w0 = network_config["w0"]
61 | self.sin_activation = Sine(w0=self.sin_w0)
62 | self.init_siren()
63 |
64 | ### The output layers
65 | self.has_out = has_out
66 | if has_out:
67 | size_factor = network_config["size_factor"]
68 | self.out_dim = sin_dims[-1] * size_factor
69 |
70 | for layer in range(0, grid_level):
71 | setattr(self, "out_lin" + str(layer), nn.Linear(sin_dims[layer + 1], self.out_dim))
72 |
73 | self.sin_w0_high = network_config["w1"]
74 | self.init_siren_out()
75 | self.out_activation = Sine(w0=self.sin_w0_high)
76 | else:
77 | self.out_dim = sin_dims[-1] * grid_level
78 |
79 |
80 | ### Initialize the parameters of SIREN branch
81 | def init_siren(self):
82 | for layer in range(0, self.num_sin_layers - 1):
83 | lin = getattr(self, "sin_lin" + str(layer))
84 |
85 | if layer == 0:
86 | first_layer_sine_init(lin)
87 | else:
88 | sine_init(lin, w0=self.sin_w0)
89 |
90 |
91 | def init_siren_out(self):
92 | for layer in range(0, self.grid_level):
93 | lin = getattr(self, "out_lin" + str(layer))
94 |
95 | sine_init(lin, w0=self.sin_w0_high)
96 |
97 |
98 | def forward(self, in_pos):
99 | """
100 | in_pos: [N, 3], in [-bound, bound]
101 |
102 | in_pos (for grid features) should always be located in [0.0, 1.0]
103 | x (for SIREN branch) should always be located in [-1.0, 1.0]
104 | """
105 |
106 | x = in_pos / self.bound # to [-1, 1]
107 | in_pos = (in_pos + self.bound) / (2 * self.bound) # to [0, 1]
108 |
109 | grid_x = self.grid_encoder(in_pos)
110 | grid_x = grid_x.view(-1, self.grid_level, self.feat_dim)
111 | grid_x = grid_x.permute(1, 0, 2)
112 |
113 | embedding_list = []
114 | for i in range(self.grid_level):
115 | grid_output = torch.matmul(grid_x[i], self.ffn[i])
116 | grid_output = torch.sin(2 * math.pi * grid_output)
117 | embedding_list.append(grid_output)
118 |
119 | if self.has_out:
120 | x_out = torch.zeros(x.shape[0], self.out_dim, device=in_pos.device)
121 | else:
122 | feat_list = []
123 |
124 |         ### SIREN branch, fusing the Fourier-encoded grid features layer by layer
125 | for layer in range(0, self.num_sin_layers - 1):
126 | sin_lin = getattr(self, "sin_lin" + str(layer))
127 | x = sin_lin(x)
128 | x = self.sin_activation(x)
129 |
130 | if layer > 0:
131 | x = embedding_list[layer-1] + x
132 |
133 | if self.has_out:
134 | out_lin = getattr(self, "out_lin" + str(layer-1))
135 | x_high = out_lin(x)
136 | x_high = self.out_activation(x_high)
137 |
138 | x_out = x_out + x_high
139 | else:
140 | feat_list.append(x)
141 |
142 | if self.has_out:
143 | x = x_out
144 | else:
145 | x = feat_list
146 |
147 | return x
--------------------------------------------------------------------------------
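A hedged sketch (not part of the repo) of the configuration `FFB_encoder` expects; the key names match the code above, but the values are illustrative placeholders, not the paper's settings (those live in `config/*/config.json`).
```python
# Hypothetical FFB_encoder configuration (values are placeholders; needs a CUDA device for tcnn).
from models.networks.FFB_encoder import FFB_encoder

encoding_config = {
    "feat_dim": 2,           # features per hash-grid level
    "base_resolution": 16,   # coarsest grid resolution
    "per_level_scale": 2.0,  # resolution growth between levels
    "base_sigma": 1.0,       # std of the first Fourier feature matrix
    "exp_sigma": 2.0,        # per-level growth of that std
}
network_config = {
    "dims": [64, 64, 64],    # SIREN branch widths (more than 3 layers incl. input required)
    "w0": 30.0,              # sine frequency of the SIREN branch
    "w1": 30.0,              # sine frequency of the high-frequency output layers
    "size_factor": 1,        # widens the summed output features
}

encoder = FFB_encoder(encoding_config, network_config, n_input_dims=3, bound=1.0)
print(encoder.grid_level, encoder.out_dim)  # grid levels and output feature width
# the forward pass is driven by NFFB_2d.py / NFFB_nerf.py during training
```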
/models/networks/Sine.py:
--------------------------------------------------------------------------------
1 | """
2 | These codes are adapted from SIREN (https://github.com/vsitzmann/siren)
3 | """
4 |
5 |
6 | import torch
7 | from torch import nn
8 | import numpy as np
9 |
10 |
11 | class Sine(nn.Module):
12 | def __init__(self, w0):
13 | super().__init__()
14 |
15 | self.w0 = w0
16 |
17 | def forward(self, input):
18 | return torch.sin(input * self.w0)
19 |
20 |
21 |
22 | def sine_init(m, w0, num_input=None):
23 | with torch.no_grad():
24 | if hasattr(m, 'weight'):
25 | if num_input is None:
26 | num_input = m.weight.size(-1)
27 | m.weight.uniform_(-np.sqrt(6 / num_input) / w0, np.sqrt(6 / num_input) / w0)
28 |
29 |
30 | def first_layer_sine_init(m):
31 | with torch.no_grad():
32 | if hasattr(m, 'weight'):
33 | num_input = m.weight.size(-1)
34 | m.weight.uniform_(-1.0 / num_input, 1.0 / num_input)
35 |
--------------------------------------------------------------------------------
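For reference, `sine_init` and `first_layer_sine_init` above draw weights uniformly from the ranges proposed in the SIREN paper, with $n$ the fan-in of the layer:
```latex
W \sim \mathcal{U}\!\left(-\frac{\sqrt{6/n}}{w_0},\ \frac{\sqrt{6/n}}{w_0}\right)
\ \text{(hidden layers)},
\qquad
W \sim \mathcal{U}\!\left(-\frac{1}{n},\ \frac{1}{n}\right)
\ \text{(first layer)}.
```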
/models/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/networks/__init__.py
--------------------------------------------------------------------------------
/models/networks/img/NFFB_2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from models.networks.FFB_encoder import FFB_encoder
5 |
6 |
7 | class NFFB(nn.Module):
8 | def __init__(self, config, out_dims=3):
9 | super().__init__()
10 |
11 | self.xyz_encoder = FFB_encoder(n_input_dims=2, encoding_config=config["encoding"],
12 | network_config=config["SIREN"], has_out=False)
13 |
14 | ### Initialize the backbone part that merges the multi-scale grid features
15 | backbone_dims = config["Backbone"]["dims"]
16 | grid_feat_len = self.xyz_encoder.out_dim
17 | backbone_dims = [grid_feat_len] + backbone_dims + [out_dims]
18 | self.num_backbone_layers = len(backbone_dims)
19 |
20 | for layer in range(0, self.num_backbone_layers - 1):
21 | out_dim = backbone_dims[layer + 1]
22 | setattr(self, "backbone_lin" + str(layer), nn.Linear(backbone_dims[layer], out_dim))
23 |
24 | self.relu_activation = nn.ReLU(inplace=True)
25 |
26 |
27 | @torch.no_grad()
28 | # optimizer utils
29 | def get_params(self, LR_schedulers):
30 | params = [
31 | {'params': self.parameters(), 'lr': LR_schedulers[0]["initial"]}
32 | ]
33 |
34 | return params
35 |
36 |
37 | def forward(self, in_pos):
38 | """
39 | Inputs:
40 | in_pos: (N, 2) xy coordinates in [0, 1]
41 | Outputs:
42 | out: (N, 1 or 3), the RGB values
43 | """
44 | x = (in_pos - 0.5) * 2.0
45 |
46 | grid_x = self.xyz_encoder(x)
47 | out_feat = torch.cat(grid_x, dim=1)
48 |
49 |
50 | ### Backbone transformation
51 | for layer in range(0, self.num_backbone_layers - 1):
52 | backbone_lin = getattr(self, "backbone_lin" + str(layer))
53 | out_feat = backbone_lin(out_feat)
54 |
55 | if layer < self.num_backbone_layers - 2:
56 | out_feat = self.relu_activation(out_feat)
57 |
58 | out_feat = out_feat.clamp(-1.0, 1.0)
59 |
60 | return out_feat
--------------------------------------------------------------------------------
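A hedged usage sketch for the 2D model above, mirroring how `train_img.py` constructs it. It assumes a CUDA device, tiny-cuda-nn, and the shipped `config/img/config.json`:

```python
import commentjson as json
import torch

from models.networks.img.NFFB_2d import NFFB

# Usage sketch; assumes a CUDA device, tiny-cuda-nn, and the shipped config file.
with open("config/img/config.json") as f:
    config = json.load(f)

model = NFFB(config["network"], out_dims=3).cuda()
coords = torch.rand(4096, 2, device="cuda")   # xy in [0, 1]; forward() rescales to [-1, 1]
rgb = model(coords)                           # (4096, 3), clamped to [-1, 1]
```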
/models/networks/img/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/networks/img/__init__.py
--------------------------------------------------------------------------------
/models/networks/nerf/NFFB_nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import tinycudann as tcnn
4 | import vren
5 | from einops import rearrange
6 | from .custom_functions import TruncExp
7 | import numpy as np
8 |
9 | from .rendering import NEAR_DISTANCE
10 |
11 | from models.networks.FFB_encoder import FFB_encoder
12 | from models.networks.Sine import sine_init, Sine
13 |
14 |
15 | class NFFB(nn.Module):
16 | def __init__(self, config, scale, rgb_act='Sigmoid'):
17 | super().__init__()
18 |
19 | self.rgb_act = rgb_act
20 |
21 | # scene bounding box
22 | self.scale = scale
23 | self.register_buffer('center', torch.zeros(1, 3))
24 | self.register_buffer('xyz_min', -torch.ones(1, 3)*scale)
25 | self.register_buffer('xyz_max', torch.ones(1, 3)*scale)
26 | self.register_buffer('half_size', (self.xyz_max-self.xyz_min)/2)
27 |
28 | # each density grid covers [-2^(k-1), 2^(k-1)]^3 for k in [0, C-1]
29 | self.cascades = max(1+int(np.ceil(np.log2(2*scale))), 1)
30 | self.grid_size = 128 ### Occupancy grid resolution, used to skip empty space and speed up training
31 | self.register_buffer('density_bitfield',
32 | torch.zeros(self.cascades*self.grid_size**3//8, dtype=torch.uint8))
33 |
34 |
35 | self.xyz_encoder = FFB_encoder(n_input_dims=3, encoding_config=config["encoding"],
36 | network_config=config["SIREN"], bound=self.scale)
37 |
38 | ## sigma network
39 | self.num_layers = num_layers = 1
40 | hidden_dim = 64
41 | geo_feat_dim = 15
42 |
43 | sigma_net = []
44 | for l in range(num_layers):
45 | if l == 0:
46 | in_dim = self.xyz_encoder.out_dim
47 | else:
48 | in_dim = hidden_dim
49 |
50 | if l == num_layers - 1:
51 | out_dim = 1 + geo_feat_dim # 1 sigma + 15 geometry features passed to the color network
52 | else:
53 | out_dim = hidden_dim
54 |
55 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))
56 | self.sigma_net = nn.ModuleList(sigma_net)
57 |
58 | self.sin_w0 = config["SIREN"]["w1"]
59 | self.sin_activation = Sine(w0=self.sin_w0)
60 | self.init_siren()
61 |
62 | # ### sigma network
63 | # self.sigma_net = \
64 | # tcnn.Network(
65 | # n_input_dims=self.xyz_encoder.out_dim, n_output_dims=16,
66 | # network_config={
67 | # "otype": "FullyFusedMLP",
68 | # "activation": "ReLU",
69 | # "output_activation": "None",
70 | # "n_neurons": 64,
71 | # "n_hidden_layers": 1,
72 | # }
73 | # )
74 |
75 | # self.dir_encoder = \
76 | # tcnn.Encoding(
77 | # n_input_dims=3,
78 | # encoding_config={
79 | # "otype": "SphericalHarmonics",
80 | # "degree": 4,
81 | # },
82 | # )
83 |
84 | self.dir_encoder = \
85 | tcnn.Encoding(
86 | n_input_dims=3,
87 | encoding_config={
88 | "otype": "Frequency",
89 | "n_frequencies": 5
90 | },
91 | )
92 |
93 | self.rgb_net = \
94 | tcnn.Network(
95 | n_input_dims=46, n_output_dims=3,
96 | network_config={
97 | "otype": "FullyFusedMLP",
98 | "activation": "ReLU",
99 | "output_activation": self.rgb_act,
100 | "n_neurons": 64,
101 | "n_hidden_layers": 2,
102 | }
103 | )
104 |
105 | if self.rgb_act == 'None': # rgb_net output is log-radiance
106 | for i in range(3): # independent tonemappers for r,g,b
107 | tonemapper_net = \
108 | tcnn.Network(
109 | n_input_dims=1, n_output_dims=1,
110 | network_config={
111 | "otype": "FullyFusedMLP",
112 | "activation": "ReLU",
113 | "output_activation": "Sigmoid",
114 | "n_neurons": 64,
115 | "n_hidden_layers": 1,
116 | }
117 | )
118 | setattr(self, f'tonemapper_net_{i}', tonemapper_net)
119 |
120 | ### Initialize the sine-activated parameters
121 | def init_siren(self):
122 | ### Initialize the sigma network
123 | for l in range(self.num_layers):
124 | lin = self.sigma_net[l]
125 | sine_init(lin, w0=self.sin_w0)
126 |
127 | ### TODO - Transform the input coordinates into the right range before feeding them into xyz_encoder()
128 | def density(self, x, return_feat=False):
129 | """
130 | Inputs:
131 | x: (N, 3) xyz in [-scale, scale]
132 | return_feat: whether to return intermediate feature
133 |
134 | Outputs:
135 | sigmas: (N)
136 | """
137 | # x = (x-self.xyz_min)/(self.xyz_max-self.xyz_min)
138 | h = self.xyz_encoder(x)
139 | # h = self.sigma_net(h)
140 | #
141 | for l in range(self.num_layers):
142 | h = self.sigma_net[l](h)
143 | if l != self.num_layers - 1:
144 | # h = F.relu(h, inplace=True)
145 | h = self.sin_activation(h)
146 | # h = self.sin_activation(h * self.sin_w0)
147 |
148 | sigmas = TruncExp.apply(h[:, 0])
149 | if return_feat: return sigmas, h
150 | return sigmas
151 |
152 | def log_radiance_to_rgb(self, log_radiances, **kwargs):
153 | """
154 | Convert log-radiance to rgb as the setting in HDR-NeRF.
155 | Called only when self.rgb_act == 'None' (with exposure)
156 |
157 | Inputs:
158 | log_radiances: (N, 3)
159 |
160 | Outputs:
161 | rgbs: (N, 3)
162 | """
163 | if 'exposure' in kwargs:
164 | log_exposure = torch.log(kwargs['exposure'])
165 | else: # unit exposure by default
166 | log_exposure = 0
167 |
168 | out = []
169 | for i in range(3):
170 | inp = log_radiances[:, i:i+1]+log_exposure
171 | out += [getattr(self, f'tonemapper_net_{i}')(inp)]
172 | rgbs = torch.cat(out, 1)
173 | return rgbs
174 |
175 | def forward(self, x, d, **kwargs):
176 | """
177 | Inputs:
178 | x: (N, 3) xyz in [-scale, scale]
179 | d: (N, 3) directions
180 |
181 | Outputs:
182 | sigmas: (N)
183 | rgbs: (N, 3)
184 | """
185 | sigmas, h = self.density(x, return_feat=True)
186 | d = d/torch.norm(d, dim=1, keepdim=True)
187 | d = self.dir_encoder((d+1)/2)
188 | rgbs = self.rgb_net(torch.cat([d, h], 1))
189 |
190 | if self.rgb_act == 'None': # rgbs is log-radiance
191 | if kwargs.get('output_radiance', False): # output HDR map
192 | rgbs = TruncExp.apply(rgbs)
193 | else: # convert to LDR using tonemapper networks
194 | rgbs = self.log_radiance_to_rgb(rgbs, **kwargs)
195 |
196 | return sigmas, rgbs
197 |
198 | @torch.no_grad()
199 | def get_all_cells(self):
200 | """
201 | Get all cells from the density grid.
202 |
203 | Outputs:
204 | cells: list (of length self.cascades) of indices and coords
205 | selected at each cascade
206 | """
207 | indices = vren.morton3D(self.grid_coords).long()
208 | cells = [(indices, self.grid_coords)] * self.cascades
209 |
210 | return cells
211 |
212 | @torch.no_grad()
213 | def sample_uniform_and_occupied_cells(self, M, density_threshold):
214 | """
215 | Sample both M uniform and occupied cells (per cascade)
216 | occupied cells are sampled from cells with density > @density_threshold
217 |
218 | Outputs:
219 | cells: list (of length self.cascades) of indices and coords
220 | selected at each cascade
221 | """
222 | cells = []
223 | for c in range(self.cascades):
224 | # uniform cells
225 | coords1 = torch.randint(self.grid_size, (M, 3), dtype=torch.int32,
226 | device=self.density_grid.device)
227 | indices1 = vren.morton3D(coords1).long()
228 | # occupied cells
229 | indices2 = torch.nonzero(self.density_grid[c]>density_threshold)[:, 0]
230 | if len(indices2) > 0:
231 | ### Randomly pick M occupied cells
232 | rand_idx = torch.randint(len(indices2), (M,), device=self.density_grid.device)
233 | indices2 = indices2[rand_idx]
234 | coords2 = vren.morton3D_invert(indices2.int())
235 | # concatenate
236 | cells += [(torch.cat([indices1, indices2]), torch.cat([coords1, coords2]))]
237 |
238 | return cells
239 |
240 | @torch.no_grad()
241 | def mark_invisible_cells(self, K, poses, img_wh, chunk=64**3):
242 | """
243 | mark the cells that aren't covered by the cameras with density -1
244 | only executed once before training starts
245 |
246 | Inputs:
247 | K: (3, 3) camera intrinsics
248 | poses: (N, 3, 4) camera to world poses
249 | img_wh: image width and height
250 | chunk: the chunk size to split the cells (to avoid OOM)
251 | """
252 | N_cams = poses.shape[0]
253 | self.count_grid = torch.zeros_like(self.density_grid)
254 | w2c_R = rearrange(poses[:, :3, :3], 'n a b -> n b a') # (N_cams, 3, 3)
255 | w2c_T = -w2c_R@poses[:, :3, 3:] # (N_cams, 3, 1)
256 | cells = self.get_all_cells()
257 | for c in range(self.cascades):
258 | indices, coords = cells[c]
259 | for i in range(0, len(indices), chunk):
260 | xyzs = coords[i:i+chunk]/(self.grid_size-1)*2-1 ### [-1, 1]
261 | s = min(2**(c-1), self.scale)
262 | half_grid_size = s/self.grid_size
263 | xyzs_w = (xyzs*(s-half_grid_size)).T # (3, chunk) ### The coordinates in world frame
264 | xyzs_c = w2c_R @ xyzs_w + w2c_T # (N_cams, 3, chunk) ### The coordinates in camera frame
265 | uvd = K @ xyzs_c # (N_cams, 3, chunk)
266 | uv = uvd[:, :2]/uvd[:, 2:] # (N_cams, 2, chunk) ### The coordinates in image frame
267 | in_image = (uvd[:, 2]>=0)& \
268 | (uv[:, 0]>=0)&(uv[:, 0]<img_wh[0])& \
269 | (uv[:, 1]>=0)&(uv[:, 1]<img_wh[1])
270 | covered_by_cam = (uvd[:, 2]>=NEAR_DISTANCE)&in_image # (N_cams, chunk)
271 | # if the cell is visible by at least one camera
272 | self.count_grid[c, indices[i:i+chunk]] = \
273 | count = covered_by_cam.sum(0)/N_cams
274 |
275 | too_near_to_cam = (uvd[:, 2]<NEAR_DISTANCE)&in_image # (N_cams, chunk)
276 | # if the cell is too close (in front) to any camera
277 | too_near_to_any_cam = too_near_to_cam.any(0)
278 | # a valid cell should be visible by at least one camera and not too close to any camera
279 | valid_mask = (count>0)&(~too_near_to_any_cam)
280 | self.density_grid[c, indices[i:i+chunk]] = \
281 | torch.where(valid_mask, 0., -1.)
282 |
283 | @torch.no_grad()
284 | def update_density_grid(self, density_threshold, warmup=False, decay=0.95, erode=False):
285 | density_grid_tmp = torch.zeros_like(self.density_grid)
286 | if warmup: # during the first steps
287 | cells = self.get_all_cells()
288 | else:
289 | cells = self.sample_uniform_and_occupied_cells(self.grid_size**3//4, density_threshold)
290 |
291 | # infer and then update sigmas, and store at the density_grid_tmp
292 | for c in range(self.cascades):
293 | indices, coords = cells[c]
294 | s = min(2**(c-1), self.scale)
295 | half_grid_size = s/self.grid_size
296 | xyzs_w = (coords/(self.grid_size-1)*2-1)*(s-half_grid_size)
297 | # pick random position in the cell by adding noise in [-hgs, hgs]
298 | xyzs_w += (torch.rand_like(xyzs_w)*2-1) * half_grid_size
299 | density_grid_tmp[c, indices] = self.density(xyzs_w)
300 |
301 | if erode:
302 | # My own logic. decay more the cells that are visible to few cameras
303 | decay = torch.clamp(decay**(1/self.count_grid), 0.1, 0.95)
304 | self.density_grid = \
305 | torch.where(self.density_grid < 0, self.density_grid,
306 | torch.maximum(self.density_grid*decay, density_grid_tmp))
307 |
308 | mean_density = self.density_grid[self.density_grid>0].mean().item()
309 |
310 | ### Pack the density grid into an 8-bit bitfield (1 bit per cell) to save space
311 | vren.packbits(self.density_grid, min(mean_density, density_threshold), self.density_bitfield)
312 |
313 | @torch.no_grad()
314 | # optimizer utils
315 | def get_params(self, LR_schedulers):
316 | params = [
317 | {'params': self.xyz_encoder.parameters(), 'lr': LR_schedulers[0]["initial"]},
318 | {'params': self.sigma_net.parameters(), 'lr': LR_schedulers[1]["initial"]},
319 | {'params': self.dir_encoder.parameters(), 'lr': LR_schedulers[2]["initial"]},
320 | {'params': self.rgb_net.parameters(), 'lr': LR_schedulers[3]["initial"]},
321 | ]
322 |
323 | return params
--------------------------------------------------------------------------------
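The density grid above is organized into cascades: cascade `c` roughly covers `[-min(2^(c-1), scale), min(2^(c-1), scale)]^3`, which is the arithmetic used in `mark_invisible_cells` and `update_density_grid`. A small worked sketch of that bookkeeping (the `scale` value is illustrative):

```python
import numpy as np

# Worked sketch of the cascade bookkeeping used above.
scale = 4.0
cascades = max(1 + int(np.ceil(np.log2(2 * scale))), 1)    # -> 4 for scale=4.0
for c in range(cascades):
    s = min(2 ** (c - 1), scale)
    print(f"cascade {c}: covers roughly [-{s}, {s}]^3")     # 0.5, 1, 2, 4
```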
/models/networks/nerf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/networks/nerf/__init__.py
--------------------------------------------------------------------------------
/models/networks/nerf/custom_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import vren
3 | from torch.cuda.amp import custom_fwd, custom_bwd
4 | from torch_scatter import segment_csr
5 | from einops import rearrange
6 |
7 |
8 | ### Compute the intersection information for rays and AABB
9 | class RayAABBIntersector(torch.autograd.Function):
10 | """
11 | Computes the intersections of rays and axis-aligned voxels.
12 |
13 | Inputs:
14 | rays_o: (N_rays, 3) ray origins
15 | rays_d: (N_rays, 3) ray directions
16 | centers: (N_voxels, 3) voxel centers
17 | half_sizes: (N_voxels, 3) voxel half sizes
18 | max_hits: maximum number of intersected voxels to keep for one ray
19 | (for a cubic scene, this is at most 3*N_voxels^(1/3)-2)
20 |
21 | Outputs:
22 | hits_cnt: (N_rays) number of hits for each ray
23 | (the following outputs are ordered from near to far)
24 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit)
25 | hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit)
26 | """
27 | @staticmethod
28 | @custom_fwd(cast_inputs=torch.float32)
29 | def forward(ctx, rays_o, rays_d, center, half_size, max_hits):
30 | return vren.ray_aabb_intersect(rays_o, rays_d, center, half_size, max_hits)
31 |
32 |
33 | ### Compute the intersection information between rays and a set of spheres
34 | class RaySphereIntersector(torch.autograd.Function):
35 | """
36 | Computes the intersections of rays and spheres.
37 |
38 | Inputs:
39 | rays_o: (N_rays, 3) ray origins
40 | rays_d: (N_rays, 3) ray directions
41 | centers: (N_spheres, 3) sphere centers
42 | radii: (N_spheres, 3) radii
43 | max_hits: maximum number of intersected spheres to keep for one ray
44 |
45 | Outputs:
46 | hits_cnt: (N_rays) number of hits for each ray
47 | (the following outputs are ordered from near to far)
48 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit)
49 | hits_sphere_idx: (N_rays, max_hits) hit sphere indices (-1 if no hit)
50 | """
51 | @staticmethod
52 | @custom_fwd(cast_inputs=torch.float32)
53 | def forward(ctx, rays_o, rays_d, center, radii, max_hits):
54 | return vren.ray_sphere_intersect(rays_o, rays_d, center, radii, max_hits)
55 |
56 |
57 | class RayMarcher(torch.autograd.Function):
58 | """
59 | March the rays to get sample point positions and directions.
60 |
61 | Inputs:
62 | rays_o: (N_rays, 3) ray origins
63 | rays_d: (N_rays, 3) normalized ray directions
64 | hits_t: (N_rays, 2) near and far bounds from aabb intersection
65 | density_bitfield: (C*G**3//8)
66 | cascades: int
67 | scale: float
68 | exp_step_factor: the exponential factor to scale the steps
69 | grid_size: int
70 | max_samples: int
71 |
72 | Outputs:
73 | rays_a: (N_rays) ray_idx, start_idx, N_samples
74 | xyzs: (N, 3) sample positions
75 | dirs: (N, 3) sample view directions
76 | deltas: (N) dt for integration
77 | ts: (N) sample ts
78 | """
79 | @staticmethod
80 | @custom_fwd(cast_inputs=torch.float32)
81 | def forward(ctx, rays_o, rays_d, hits_t,
82 | density_bitfield, cascades, scale, exp_step_factor,
83 | grid_size, max_samples):
84 | # noise to perturb the first sample of each ray
85 | noise = torch.rand_like(rays_o[:, 0])
86 |
87 | rays_a, xyzs, dirs, deltas, ts, counter = \
88 | vren.raymarching_train(
89 | rays_o, rays_d, hits_t,
90 | density_bitfield, cascades, scale,
91 | exp_step_factor, noise, grid_size, max_samples)
92 |
93 | total_samples = counter[0] # total samples for all rays
94 | # remove redundant output
95 | xyzs = xyzs[:total_samples]
96 | dirs = dirs[:total_samples]
97 | deltas = deltas[:total_samples]
98 | ts = ts[:total_samples]
99 |
100 | ctx.save_for_backward(rays_a, ts)
101 |
102 | return rays_a, xyzs, dirs, deltas, ts, total_samples
103 |
104 | @staticmethod
105 | @custom_bwd
106 | def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs,
107 | dL_ddeltas, dL_dts, dL_dtotal_samples):
108 | rays_a, ts = ctx.saved_tensors
109 | segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1]+rays_a[-1:, 2]])
110 | dL_drays_o = segment_csr(dL_dxyzs, segments)
111 | dL_drays_d = \
112 | segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments)
113 |
114 | return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None
115 |
116 |
117 | ### Compute the information for RGB, depth and opacity
118 | class VolumeRenderer(torch.autograd.Function):
119 | """
120 | Volume rendering with different number of samples per ray
121 | Used in training only
122 |
123 | Inputs:
124 | sigmas: (N)
125 | rgbs: (N, 3)
126 | deltas: (N)
127 | ts: (N)
128 | rays_a: (N_rays, 3) ray_idx, start_idx, N_samples
129 | meaning each entry corresponds to the @ray_idx th ray,
130 | whose samples are [start_idx:start_idx+N_samples]
131 | T_threshold: float, stop the ray if the transmittance is below it
132 |
133 | Outputs:
134 | total_samples: int, total effective samples
135 | opacity: (N_rays)
136 | depth: (N_rays)
137 | rgb: (N_rays, 3)
138 | ws: (N) sample point weights
139 | """
140 | @staticmethod
141 | @custom_fwd(cast_inputs=torch.float32)
142 | def forward(ctx, sigmas, rgbs, deltas, ts, rays_a, T_threshold):
143 | total_samples, opacity, depth, rgb, ws = \
144 | vren.composite_train_fw(sigmas, rgbs, deltas, ts,
145 | rays_a, T_threshold)
146 | ctx.save_for_backward(sigmas, rgbs, deltas, ts, rays_a,
147 | opacity, depth, rgb, ws)
148 | ctx.T_threshold = T_threshold
149 | return total_samples.sum(), opacity, depth, rgb, ws
150 |
151 | @staticmethod
152 | @custom_bwd
153 | def backward(ctx, dL_dtotal_samples, dL_dopacity, dL_ddepth, dL_drgb, dL_dws):
154 | sigmas, rgbs, deltas, ts, rays_a, \
155 | opacity, depth, rgb, ws = ctx.saved_tensors
156 | dL_dsigmas, dL_drgbs = \
157 | vren.composite_train_bw(dL_dopacity, dL_ddepth, dL_drgb, dL_dws,
158 | sigmas, rgbs, ws, deltas, ts,
159 | rays_a,
160 | opacity, depth, rgb,
161 | ctx.T_threshold)
162 | return dL_dsigmas, dL_drgbs, None, None, None, None
163 |
164 |
165 | class TruncExp(torch.autograd.Function):
166 | @staticmethod
167 | @custom_fwd(cast_inputs=torch.float32)
168 | def forward(ctx, x):
169 | ctx.save_for_backward(x)
170 | return torch.exp(x)
171 |
172 | @staticmethod
173 | @custom_bwd
174 | def backward(ctx, dL_dout):
175 | x = ctx.saved_tensors[0]
176 | return dL_dout * torch.exp(x.clamp(-15, 15))
177 |
--------------------------------------------------------------------------------
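For reference, the front-to-back compositing that `VolumeRenderer` delegates to the CUDA kernel `vren.composite_train_fw` can be written for a single ray in plain PyTorch as below. Early termination at `T_threshold`, the per-ray sample offsets from `rays_a`, and the depth output are omitted; this is a sketch of the math, not the kernel itself:

```python
import torch

# Pure-PyTorch reference of front-to-back alpha compositing for one ray.
def composite_one_ray(sigmas, rgbs, deltas):
    alphas = 1 - torch.exp(-sigmas * deltas)          # (N,) per-sample opacity
    T = torch.cumprod(1 - alphas + 1e-10, dim=0)      # transmittance after each sample
    T = torch.cat([torch.ones(1), T[:-1]])            # transmittance before each sample
    ws = alphas * T                                   # (N,) sample weights
    rgb = (ws[:, None] * rgbs).sum(0)                 # (3,) accumulated color
    opacity = ws.sum()
    return rgb, opacity, ws

rgb, opacity, ws = composite_one_ray(torch.rand(16), torch.rand(16, 3), torch.full((16,), 0.01))
```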
/models/networks/nerf/rendering.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .custom_functions import \
3 | RayAABBIntersector, RayMarcher, VolumeRenderer
4 | from einops import rearrange
5 | import vren
6 |
7 | MAX_SAMPLES = 1024
8 | NEAR_DISTANCE = 0.01
9 |
10 |
11 | @torch.cuda.amp.autocast()
12 | def render(model, rays_o, rays_d, **kwargs):
13 | """
14 | Render rays by
15 | 1. Compute the intersection of the rays with the scene bounding box
16 | 2. Follow the process in @render_func (different for train/test)
17 |
18 | Inputs:
19 | model: NGP
20 | rays_o: (N_rays, 3) ray origins
21 | rays_d: (N_rays, 3) ray directions
22 |
23 | Outputs:
24 | result: dictionary containing final rgb and depth
25 | """
26 | rays_o = rays_o.contiguous(); rays_d = rays_d.contiguous()
27 | _, hits_t, _ = \
28 | RayAABBIntersector.apply(rays_o, rays_d, model.center, model.half_size, 1)
29 | hits_t[(hits_t[:, 0, 0]>=0)&(hits_t[:, 0, 0]<NEAR_DISTANCE), 0, 0] = NEAR_DISTANCE
90 | xyzs = rearrange(xyzs, 'n1 n2 c -> (n1 n2) c')
91 | dirs = rearrange(dirs, 'n1 n2 c -> (n1 n2) c')
92 | valid_mask = ~torch.all(dirs==0, dim=1)
93 | if valid_mask.sum()==0: break
94 |
95 | sigmas = torch.zeros(len(xyzs), device=device)
96 | rgbs = torch.zeros(len(xyzs), 3, device=device)
97 | sigmas[valid_mask], _rgbs = model(xyzs[valid_mask], dirs[valid_mask], **kwargs)
98 | rgbs[valid_mask] = _rgbs.float()
99 | sigmas = rearrange(sigmas, '(n1 n2) -> n1 n2', n2=N_samples)
100 | rgbs = rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=N_samples)
101 |
102 | vren.composite_test_fw(
103 | sigmas, rgbs, deltas, ts,
104 | hits_t[:, 0], alive_indices, kwargs.get('T_threshold', 1e-4),
105 | N_eff_samples, opacity, depth, rgb)
106 | alive_indices = alive_indices[alive_indices>=0] # remove converged rays
107 |
108 | results['opacity'] = opacity
109 | results['depth'] = depth
110 | results['rgb'] = rgb
111 | results['total_samples'] = total_samples # total samples for all rays
112 |
113 | if exp_step_factor==0: # synthetic
114 | rgb_bg = torch.ones(3, device=device)
115 | else: # real
116 | rgb_bg = torch.zeros(3, device=device)
117 | results['rgb'] += rgb_bg*rearrange(1-opacity, 'n -> n 1')
118 |
119 | return results
120 |
121 |
122 | ### Given the ray information, render RGB images for training stages
123 | def __render_rays_train(model, rays_o, rays_d, hits_t, **kwargs):
124 | """
125 | Render rays by
126 | 1. March the rays along their directions, querying @density_bitfield
127 | to skip empty space, and get the effective sample points (where
128 | there is object)
129 | 2. Infer the NN at these positions and view directions to get properties
130 | (currently sigmas and rgbs)
131 | 3. Use volume rendering to combine the result (front to back compositing
132 | and early stop the ray if its transmittance is below a threshold)
133 | """
134 | exp_step_factor = kwargs.get('exp_step_factor', 0.)
135 | results = {}
136 |
137 | (rays_a, xyzs, dirs,
138 | results['deltas'], results['ts'], results['rm_samples']) = \
139 | RayMarcher.apply(
140 | rays_o, rays_d, hits_t[:, 0], model.density_bitfield,
141 | model.cascades, model.scale,
142 | exp_step_factor, model.grid_size, MAX_SAMPLES)
143 |
144 | for k, v in kwargs.items(): # supply additional inputs, repeated per ray
145 | if isinstance(v, torch.Tensor):
146 | kwargs[k] = torch.repeat_interleave(v[rays_a[:, 0]], rays_a[:, 2], 0)
147 | sigmas, rgbs = model(xyzs, dirs, **kwargs)
148 |
149 | (results['vr_samples'], results['opacity'],
150 | results['depth'], results['rgb'], results['ws']) = \
151 | VolumeRenderer.apply(sigmas, rgbs.contiguous(), results['deltas'], results['ts'],
152 | rays_a, kwargs.get('T_threshold', 1e-4))
153 | results['rays_a'] = rays_a
154 |
155 | if exp_step_factor==0: # synthetic
156 | rgb_bg = torch.ones(3, device=rays_o.device)
157 | else: # real
158 | if kwargs.get('random_bg', False):
159 | rgb_bg = torch.rand(3, device=rays_o.device)
160 | else:
161 | rgb_bg = torch.zeros(3, device=rays_o.device)
162 | results['rgb'] = results['rgb'] + \
163 | rgb_bg*rearrange(1-results['opacity'], 'n -> n 1')
164 |
165 | return results
166 |
--------------------------------------------------------------------------------
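Both render paths above finish by alpha-compositing the accumulated color over a constant background (white for synthetic scenes, black otherwise). A minimal sketch of that last step:

```python
import torch
from einops import rearrange

# Sketch of the final background compositing shared by the train/test render paths.
rgb = torch.rand(4, 3)        # accumulated foreground radiance per ray
opacity = torch.rand(4)       # per-ray opacity (sum of sample weights)
rgb_bg = torch.ones(3)        # white background (exp_step_factor == 0, synthetic scenes)
rgb_final = rgb + rgb_bg * rearrange(1 - opacity, 'n -> n 1')   # (4, 3)
```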
/models/networks/sdf/NFFB_3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from models.networks.FFB_encoder import FFB_encoder
5 | from models.networks.Sine import sine_init
6 |
7 |
8 | class NFFB(nn.Module):
9 | def __init__(self, config):
10 | super().__init__()
11 |
12 | self.xyz_encoder = FFB_encoder(n_input_dims=3, encoding_config=config["encoding"],
13 | network_config=config["SIREN"], has_out=False)
14 | enc_out_dim = self.xyz_encoder.out_dim
15 |
16 | self.out_lin = nn.Linear(enc_out_dim, 1)
17 |
18 | self.init_output(config["SIREN"]["dims"][-1])
19 |
20 |
21 | def init_output(self, layer_size):
22 | sine_init(self.out_lin, self.xyz_encoder.sin_w0, layer_size)
23 |
24 |
25 | def forward(self, x):
26 | """
27 | Inputs:
28 | x: (N, 3) xyz in [-scale, scale]
29 | Outputs:
30 | out: (N), the final sdf value
31 | """
32 | out = self.xyz_encoder(x)
33 |
34 | out_feat = torch.cat(out, dim=1)
35 | out_feat = self.out_lin(out_feat)
36 | out = out_feat / self.xyz_encoder.grid_level
37 |
38 | return out
39 |
40 |
41 | @torch.no_grad()
42 | # optimizer utils
43 | def get_params(self, LR_schedulers):
44 | params = [
45 | {'params': self.parameters(), 'lr': LR_schedulers[0]["initial"]}
46 | ]
47 |
48 | return params
--------------------------------------------------------------------------------
/models/networks/sdf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/networks/sdf/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.4.1
2 | kornia==0.6.5
3 | pytorch-lightning==1.7.7
4 | matplotlib==3.5.2
5 | opencv-python==4.6.0.66
6 | lpips
7 | imageio
8 | imageio-ffmpeg
9 | jupyter
10 | scipy
11 | pymcubes
12 | trimesh
13 | dearpygui
--------------------------------------------------------------------------------
/scripts/img/common.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """
4 | These codes are adapted from tiny-cuda-nn (https://github.com/NVlabs/tiny-cuda-nn)
5 | """
6 |
7 | import imageio
8 | import numpy as np
9 | import os
10 | import struct
11 |
12 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
13 |
14 | def mse2psnr(x):
15 | return -10.*np.log(x)/np.log(10.)
16 |
17 | def write_image_imageio(img_file, img, quality):
18 | img = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
19 | kwargs = {}
20 | if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]:
21 | if img.ndim >= 3 and img.shape[2] > 3:
22 | img = img[:,:,:3]
23 | kwargs["quality"] = quality
24 | kwargs["subsampling"] = 0
25 | imageio.imwrite(img_file, img, **kwargs)
26 |
27 | def read_image_imageio(img_file):
28 | img = imageio.imread(img_file)
29 | img = np.asarray(img).astype(np.float32)
30 | if len(img.shape) == 2:
31 | img = img[:,:,np.newaxis]
32 | return img / 255.0
33 |
34 | ### Convert between sRGB and linear color values (gamma transfer functions)
35 | def srgb_to_linear(img):
36 | limit = 0.04045
37 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92)
38 |
39 | def linear_to_srgb(img):
40 | limit = 0.0031308
41 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)
42 |
43 | def read_image(file):
44 | if os.path.splitext(file)[1] == ".bin":
45 | with open(file, "rb") as f:
46 | bytes = f.read()
47 | h, w = struct.unpack("ii", bytes[:8])
48 | img = np.frombuffer(bytes, dtype=np.float16, count=h*w*4, offset=8).astype(np.float32).reshape([h, w, 4])
49 | else:
50 | img = read_image_imageio(file)
51 | if img.shape[2] == 4:
52 | img[...,0:3] = srgb_to_linear(img[...,0:3])
53 | # Premultiply alpha
54 | img[...,0:3] *= img[...,3:4]
55 | else:
56 | img = srgb_to_linear(img)
57 | return img
58 |
59 | def write_image(file, img, quality=95):
60 | if os.path.splitext(file)[1] == ".bin":
61 | if img.shape[2] < 4:
62 | img = np.dstack((img, np.ones([img.shape[0], img.shape[1], 4 - img.shape[2]])))
63 | with open(file, "wb") as f:
64 | f.write(struct.pack("ii", img.shape[0], img.shape[1]))
65 | f.write(img.astype(np.float16).tobytes())
66 | else:
67 | if img.shape[2] == 4:
68 | img = np.copy(img)
69 | # Unmultiply alpha
70 | img[...,0:3] = np.divide(img[...,0:3], img[...,3:4], out=np.zeros_like(img[...,0:3]), where=img[...,3:4] != 0)
71 | img[...,0:3] = linear_to_srgb(img[...,0:3])
72 | else:
73 | img = linear_to_srgb(img)
74 | write_image_imageio(file, img, quality)
75 |
76 | def trim(error, skip=0.000001):
77 | error = np.sort(error.flatten())
78 | size = error.size
79 | skip = int(skip * size)
80 | return error[skip:size-skip].mean()
81 |
82 | def luminance(a):
83 | a = np.maximum(0, a)**0.4545454545
84 | return 0.2126 * a[:,:,0] + 0.7152 * a[:,:,1] + 0.0722 * a[:,:,2]
85 |
86 | def L1(img, ref):
87 | return np.abs(img - ref)
88 |
89 | def APE(img, ref):
90 | return L1(img, ref) / (1e-2 + ref)
91 |
92 | def SAPE(img, ref):
93 | return L1(img, ref) / (1e-2 + (ref + img) / 2.)
94 |
95 | def L2(img, ref):
96 | return (img - ref)**2
97 |
98 | def RSE(img, ref):
99 | return L2(img, ref) / (1e-2 + ref**2)
100 |
101 | def rgb_mean(img):
102 | return np.mean(img, axis=2)
103 |
104 | def compute_error_img(metric, img, ref):
105 | img[np.logical_not(np.isfinite(img))] = 0
106 | img = np.maximum(img, 0.)
107 | if metric == "MAE":
108 | return L1(img, ref)
109 | elif metric == "MAPE":
110 | return APE(img, ref)
111 | elif metric == "SMAPE":
112 | return SAPE(img, ref)
113 | elif metric == "MSE":
114 | return L2(img, ref)
115 | elif metric == "MScE":
116 | return L2(np.clip(img, 0.0, 1.0), np.clip(ref, 0.0, 1.0))
117 | elif metric == "MRSE":
118 | return RSE(img, ref)
119 | elif metric == "MtRSE":
120 | return trim(RSE(img, ref))
121 | elif metric == "MRScE":
122 | return RSE(np.clip(img, 0, 100), np.clip(ref, 0, 100))
123 |
124 | raise ValueError(f"Unknown metric: {metric}.")
125 |
126 |
127 | def compute_error(metric, img, ref):
128 | metric_map = compute_error_img(metric, img, ref)
129 | metric_map[np.logical_not(np.isfinite(metric_map))] = 0
130 | if len(metric_map.shape) == 3:
131 | metric_map = np.mean(metric_map, axis=2)
132 | mean = np.mean(metric_map)
133 | return mean
134 |
--------------------------------------------------------------------------------
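A quick sanity check of the color-space helpers and the PSNR conversion defined above. The import path assumes the repository root is on `PYTHONPATH`; the tolerance is illustrative:

```python
import numpy as np

from scripts.img.common import srgb_to_linear, linear_to_srgb, mse2psnr

# sRGB <-> linear should round-trip, and mse2psnr is -10*log10(mse).
img = np.random.rand(8, 8, 3).astype(np.float32)
assert np.allclose(linear_to_srgb(srgb_to_linear(img)), img, atol=1e-4)
print(mse2psnr(1e-3))   # 30.0 dB
```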
/scripts/img/opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_opts():
5 | parser = argparse.ArgumentParser(description="Parsing parameters for 2D image fitting.")
6 |
7 | # config file
8 | parser.add_argument("--config", type=str, required=True,
9 | default="config/img/config.json",
10 | help="network configuration")
11 |
12 | # data file
13 | parser.add_argument("--input_path", type=str, required=True)
14 | parser.add_argument("--output_dir", type=str, default="experiments",
15 | help="output directory")
16 |
17 | # training options
18 | parser.add_argument('--batch_size', type=int, default=2**18,
19 | help='number of points in a batch')
20 | parser.add_argument('--num_epochs', type=int, default=50,
21 | help='number of training epochs')
22 | parser.add_argument('--seed', type=int, default=42,
23 | help='random seed for training')
24 |
25 | # validation options
26 | parser.add_argument('--val_only', action='store_true', default=False,
27 | help='run only validation (need to provide ckpt_path)')
28 | parser.add_argument('--no_save_test', action='store_true', default=False,
29 | help='whether to disable saving predicted images during validation')
30 |
31 | # misc
32 | parser.add_argument('--ckpt_path', type=str, default=None,
33 | help='pretrained checkpoint to load')
34 | parser.add_argument('--clamp_distance', type=float, default=1.0,
35 | help='the value range for sdfs')
36 |
37 |
38 | args = parser.parse_args()
39 | return args
--------------------------------------------------------------------------------
/scripts/img/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """
4 | These codes are adapted from tiny-cuda-nn (https://github.com/NVlabs/tiny-cuda-nn)
5 | """
6 |
7 | import imageio
8 | import numpy as np
9 | import os
10 | import struct
11 |
12 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
13 |
14 |
15 | def write_image_imageio(img_file, img, quality):
16 | img = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
17 | kwargs = {}
18 | if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]:
19 | if img.ndim >= 3 and img.shape[2] > 3:
20 | img = img[:,:,:3]
21 | kwargs["quality"] = quality
22 | kwargs["subsampling"] = 0
23 | imageio.imwrite(img_file, img, **kwargs)
24 |
25 | def read_image_imageio(img_file):
26 | img = imageio.imread(img_file)
27 | img = np.asarray(img).astype(np.float32)
28 | if len(img.shape) == 2:
29 | img = img[:,:,np.newaxis]
30 | return img / 255.0
31 |
32 | ### Convert between sRGB and linear color values (gamma transfer functions)
33 | def srgb_to_linear(img):
34 | limit = 0.04045
35 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92)
36 |
37 | def linear_to_srgb(img):
38 | limit = 0.0031308
39 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)
40 |
41 | def read_image(file):
42 | if os.path.splitext(file)[1] == ".bin":
43 | with open(file, "rb") as f:
44 | bytes = f.read()
45 | h, w = struct.unpack("ii", bytes[:8])
46 | img = np.frombuffer(bytes, dtype=np.float16, count=h*w*4, offset=8).astype(np.float32).reshape([h, w, 4])
47 | else:
48 | img = read_image_imageio(file)
49 | if img.shape[2] == 4:
50 | img[...,0:3] = srgb_to_linear(img[...,0:3])
51 | # Premultiply alpha
52 | img[...,0:3] *= img[...,3:4]
53 | else:
54 | img = srgb_to_linear(img)
55 | return img
56 |
57 | def write_image(file, img, quality=95):
58 | if os.path.splitext(file)[1] == ".bin":
59 | if img.shape[2] < 4:
60 | img = np.dstack((img, np.ones([img.shape[0], img.shape[1], 4 - img.shape[2]])))
61 | with open(file, "wb") as f:
62 | f.write(struct.pack("ii", img.shape[0], img.shape[1]))
63 | f.write(img.astype(np.float16).tobytes())
64 | else:
65 | if img.shape[2] == 4:
66 | img = np.copy(img)
67 | # Unmultiply alpha
68 | img[...,0:3] = np.divide(img[...,0:3], img[...,3:4], out=np.zeros_like(img[...,0:3]), where=img[...,3:4] != 0)
69 | img[...,0:3] = linear_to_srgb(img[...,0:3])
70 | else:
71 | img = linear_to_srgb(img)
72 | write_image_imageio(file, img, quality)
--------------------------------------------------------------------------------
/scripts/nvs/opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_opts():
4 | parser = argparse.ArgumentParser()
5 |
6 | # dataset parameters
7 | parser.add_argument('--root_dir', type=str, required=True,
8 | help='root directory of dataset')
9 | parser.add_argument('--dataset_name', type=str, default='nsvf',
10 | choices=['nerf', 'nsvf', 'colmap', 'nerfpp', 'rtmv'],
11 | help='which dataset to train/test')
12 | parser.add_argument('--split', type=str, default='train',
13 | choices=['train', 'trainval', 'trainvaltest'],
14 | help='use which split to train')
15 | parser.add_argument('--downsample', type=float, default=1.0,
16 | help='downsample factor (<=1.0) for the images')
17 |
18 | # model parameters
19 | parser.add_argument('--scale', type=float, default=0.5,
20 | help='scene scale (whole scene must lie in [-scale, scale]^3)')
21 | parser.add_argument('--use_exposure', action='store_true', default=False,
22 | help='whether to train in HDR-NeRF setting')
23 | parser.add_argument('--config', nargs="?", type=str, default="config/nerf/config.json")
24 |
25 | # loss parameters
26 | parser.add_argument('--distortion_loss_w', type=float, default=0,
27 | help='''weight of distortion loss (see losses.py),
28 | 0 to disable (default), to enable,
29 | a good value is 1e-3 for real scene and 1e-2 for synthetic scene
30 | ''')
31 |
32 | # training options
33 | parser.add_argument('--batch_size', type=int, default=4096,
34 | help='number of rays in a batch')
35 | parser.add_argument('--ray_sampling_strategy', type=str, default='all_images',
36 | choices=['all_images', 'same_image'],
37 | help='''
38 | all_images: uniformly from all pixels of ALL images
39 | same_image: uniformly from all pixels of a SAME image
40 | ''')
41 | parser.add_argument('--num_epochs', type=int, default=30,
42 | help='number of training epochs')
43 | parser.add_argument('--num_gpus', type=int, default=1,
44 | help='number of gpus')
45 | parser.add_argument('--lr', type=float, default=1e-2,
46 | help='learning rate')
47 | parser.add_argument('--seed', type=int, default=42,
48 | help='random seed for training')
49 | # experimental training options
50 | parser.add_argument('--optimize_ext', action='store_true', default=False,
51 | help='whether to optimize extrinsics')
52 | parser.add_argument('--random_bg', action='store_true', default=False,
53 | help='''whether to train with random bg color (real scene only)
54 | to avoid objects with black color to be predicted as transparent
55 | ''')
56 |
57 | # validation options
58 | parser.add_argument('--eval_lpips', action='store_true', default=False,
59 | help='evaluate lpips metric (consumes more VRAM)')
60 | parser.add_argument('--val_only', action='store_true', default=False,
61 | help='run only validation (need to provide ckpt_path)')
62 | parser.add_argument('--no_save_test', action='store_true', default=False,
63 | help='whether to disable saving test images and video')
64 |
65 | # misc
66 | parser.add_argument('--exp_name', type=str, default='exp',
67 | help='experiment name')
68 | parser.add_argument('--ckpt_path', type=str, default=None,
69 | help='pretrained checkpoint to load (including optimizers, etc)')
70 | parser.add_argument('--weight_path', type=str, default=None,
71 | help='pretrained checkpoint to load (excluding optimizers, etc)')
72 |
73 | return parser.parse_args()
74 |
--------------------------------------------------------------------------------
/scripts/nvs/prepare_rtmv.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import glob
3 | import sys
4 | from tqdm import tqdm
5 | import os
6 | import numpy as np
7 | sys.path.append('datasets')
8 | from color_utils import linear_to_srgb
9 |
10 | import warnings; warnings.filterwarnings("ignore")
11 |
12 |
13 | if __name__ == '__main__':
14 | # convert hdr images to ldr by applying linear_to_srgb tone-mapping and clamping
15 | # and save into images/ folder to accelerate reading
16 | root_dir = sys.argv[1]
17 | envs = sorted(os.listdir(root_dir))
18 | print('Generating ldr images from hdr images ...')
19 | for env in tqdm(envs):
20 | for scene in tqdm(sorted(os.listdir(os.path.join(root_dir, env)))):
21 | os.makedirs(os.path.join(root_dir, env, scene, 'images'), exist_ok=True)
22 | for i, img_p in enumerate(tqdm(sorted(glob.glob(os.path.join(root_dir, env, scene, '*[0-9].exr'))))):
23 | img = imageio.imread(img_p) # hdr
24 | img[..., :3] = linear_to_srgb(img[..., :3])
25 | img = (255*img).astype(np.uint8)
26 | imageio.imsave(os.path.join(root_dir, env, scene, f'images/{i:05d}.png'), img)
--------------------------------------------------------------------------------
/scripts/sdf/opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_opts():
5 | parser = argparse.ArgumentParser(description="Parsing parameters for 3D SDF fitting.")
6 |
7 | # config file
8 | parser.add_argument("--config", type=str, required=True,
9 | default="config/sdf/config.json",
10 | help="network configuration")
11 |
12 | # data file
13 | parser.add_argument("--input_path", type=str, required=True)
14 | parser.add_argument("--output_dir", type=str, default="experiments",
15 | help="output directory")
16 |
17 | # training options
18 | parser.add_argument('--batch_size', type=int, default=49152,
19 | help='number of points in a batch')
20 | parser.add_argument('--num_epochs', type=int, default=50,
21 | help='number of training epochs')
22 | parser.add_argument('--seed', type=int, default=42,
23 | help='random seed for training')
24 |
25 | # validation options
26 | parser.add_argument('--val_only', action='store_true', default=False,
27 | help='run only validation (need to provide ckpt_path)')
28 | parser.add_argument('--no_save_test', action='store_true', default=False,
29 | help='whether to perform marching cubes for input shapes')
30 |
31 | # misc
32 | parser.add_argument('--ckpt_path', type=str, default=None,
33 | help='pretrained checkpoint to load')
34 | parser.add_argument('--clamp_distance', type=float, default=0.1,
35 | help='the value range for sdfs')
36 |
37 |
38 | args = parser.parse_args()
39 | return args
--------------------------------------------------------------------------------
/scripts/sdf/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import trimesh
3 | import mcubes
4 |
5 |
6 | def create_mesh(model, mesh_out_path, grid_res):
7 | # Sample the SDF on a dense grid over [-1, 1]^3
8 | num_samples = grid_res ** 3
9 |
10 | sdf_values = torch.zeros(num_samples, 1)
11 |
12 | bound_min = torch.FloatTensor([-1.0, -1.0, -1.0])
13 | bound_max = torch.FloatTensor([1.0, 1.0, 1.0])
14 |
15 | X = torch.linspace(bound_min[0], bound_max[0], grid_res)
16 | Y = torch.linspace(bound_min[1], bound_max[1], grid_res)
17 | Z = torch.linspace(bound_min[2], bound_max[2], grid_res)
18 |
19 | xx, yy, zz = torch.meshgrid(X, Y, Z, indexing='ij')
20 | inputs = torch.concat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda() # [N, 3]
21 |
22 | head = 0
23 | max_batch = int(2 ** 18)
24 |
25 | while head < num_samples:
26 | sample_subset = inputs[head : min(head + max_batch, num_samples), :]
27 |
28 | sdf_values[head : min(head + max_batch, num_samples), 0] = (
29 | model(sample_subset).squeeze(1).detach().cpu()
30 | )
31 | head += max_batch
32 |
33 | sdf_values = sdf_values.reshape(grid_res, grid_res, grid_res)
34 |
35 | numpy_3d_sdf_tensor = sdf_values.data.cpu().numpy()
36 |
37 | verts, faces = mcubes.marching_cubes(numpy_3d_sdf_tensor, 0.0)
38 |
39 | vertices = verts / (grid_res - 1.0) * 2.0 - 1.0
40 |
41 | print(f'\nSaving mesh to {mesh_out_path}...', end="")
42 |
43 | mesh = trimesh.Trimesh(vertices, faces, process=False) # important, process=True leads to seg fault...
44 | mesh.export(mesh_out_path)
45 |
46 | print(f"==> Finished saving mesh.")
47 |
--------------------------------------------------------------------------------
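A hedged usage sketch for `create_mesh` above, with a toy analytic sphere SDF standing in for a trained model. It assumes a CUDA device, PyMCubes, and trimesh; the output path is illustrative:

```python
import torch

from scripts.sdf.utils import create_mesh

# Toy sphere SDF standing in for a trained NFFB SDF network.
class SphereSDF(torch.nn.Module):
    def forward(self, x):                                  # x: (N, 3) in [-1, 1]
        return x.norm(dim=-1, keepdim=True) - 0.5          # (N, 1) signed distance

create_mesh(SphereSDF().cuda(), "sphere.obj", grid_res=128)  # extracts and saves the isosurface
```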
/train_img.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from scripts.img.opt import get_opts
4 |
5 | # data
6 | from torch.utils.data import DataLoader
7 | from datasets.img.imager import ImageDataset
8 | from scripts.img.common import read_image
9 |
10 | # models
11 | import commentjson as json
12 | from models.networks.img.NFFB_2d import NFFB
13 |
14 | # optimizer, losses
15 | from apex.optimizers import FusedAdam
16 | from torch.optim.lr_scheduler import StepLR
17 |
18 |
19 | # pytorch-lightning
20 | from pytorch_lightning import LightningModule, Trainer
21 | from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
22 | from pytorch_lightning.loggers import TensorBoardLogger
23 |
24 |
25 | from utils import load_ckpt, seed_everything, process_batch_in_chunks
26 |
27 | # output
28 | import time
29 | from scripts.img.utils import write_image
30 |
31 |
32 | import warnings; warnings.filterwarnings("ignore")
33 |
34 |
35 | class ImageSystem(LightningModule):
36 | def __init__(self, hparams):
37 | super().__init__()
38 | self.save_hyperparameters(hparams)
39 |
40 | self.time = str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
41 |
42 | exp_dir = os.path.join(self.hparams.output_dir, self.time)
43 | if not os.path.isdir(exp_dir):
44 | os.makedirs(exp_dir)
45 |
46 | ### Load the configuration file
47 | with open(self.hparams.config) as config_file:
48 | self.config = json.load(config_file)
49 |
50 | ### Save the configuration file
51 | path = f"{exp_dir}/config.json"
52 | with open(path, 'w') as f:
53 | json.dump(self.config, f, indent=4, separators=(", ", ": "), sort_keys=True)
54 |
55 | self.img_data = torch.from_numpy(read_image(self.hparams.input_path)).float()
56 |
57 |
58 | def setup(self, stage):
59 | self.model = NFFB(self.config["network"], out_dims=self.img_data.shape[2])
60 |
61 | ema_decay = 0.95
62 | ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: \
63 | ema_decay * averaged_model_parameter + (1-ema_decay) * model_parameter
64 | self.ema_model = torch.optim.swa_utils.AveragedModel(self.model, avg_fn=ema_avg)
65 |
66 |
67 | self.train_dataset = ImageDataset(data=self.img_data,
68 | size=1000,
69 | num_samples=self.hparams.batch_size,
70 | split='train')
71 |
72 | self.test_dataset = ImageDataset(data=self.img_data,
73 | size=1,
74 | num_samples=self.hparams.batch_size,
75 | split='test')
76 |
77 |
78 | def forward(self, batch):
79 | b_pos = batch["points"]
80 |
81 | pred = self.model(b_pos)
82 |
83 | return pred
84 |
85 |
86 | def on_fit_start(self):
87 | seed_everything(self.hparams.seed)
88 |
89 |
90 | def configure_optimizers(self):
91 | load_ckpt(self.model, self.hparams.ckpt_path)
92 |
93 | opts = []
94 | net_params = self.model.get_params(self.config["training"]["LR_scheduler"])
95 | self.net_opt = FusedAdam(net_params, betas=(0.9, 0.99), eps=1e-15)
96 | opts += [self.net_opt]
97 |
98 | lr_interval = self.config["training"]["LR_scheduler"][0]["interval"]
99 | lr_factor = self.config["training"]["LR_scheduler"][0]["factor"]
100 |
101 | if self.config["training"]["LR_scheduler"][0]["type"] == "Step":
102 | net_sch = StepLR(self.net_opt, step_size=lr_interval, gamma=lr_factor)
103 | else:
104 | net_sch = None
105 |
106 | return opts, [net_sch]
107 |
108 |
109 | def train_dataloader(self):
110 | return DataLoader(self.train_dataset,
111 | num_workers=16,
112 | persistent_workers=True,
113 | batch_size=None,
114 | pin_memory=True)
115 |
116 |
117 | def val_dataloader(self):
118 | return DataLoader(self.test_dataset,
119 | num_workers=8,
120 | batch_size=None,
121 | pin_memory=True)
122 |
123 |
124 | def predict_dataloader(self):
125 | return DataLoader(self.test_dataset,
126 | num_workers=8,
127 | batch_size=None,
128 | pin_memory=True)
129 |
130 |
131 | def training_step(self, batch, batch_nb, *args):
132 | results = self(batch)
133 |
134 | b_occ = batch['rgbs'].to(results.dtype)
135 |
136 | batch_loss = (results - b_occ)**2 / (b_occ.detach()**2 + 1e-2)
137 | loss = batch_loss.mean()
138 |
139 | self.log('lr/network', self.net_opt.param_groups[0]['lr'], True)
140 | self.log('train/loss', loss)
141 |
142 | return loss
143 |
144 |
145 | def training_epoch_end(self, training_step_outputs):
146 | for name, cur_para in self.model.named_parameters():
147 | if len(cur_para) == 0:
148 | print(f"The len of parameter {name} is 0 at epoch {self.current_epoch}.")
149 | continue
150 |
151 | if cur_para is not None and cur_para.requires_grad and cur_para.grad is not None:
152 | para_norm = torch.norm(cur_para.grad.detach(), 2)
153 | self.log('Grad/%s_norm' % name, para_norm)
154 |
155 |
156 | def on_before_zero_grad(self, optimizer):
157 | if self.ema_model is not None:
158 | self.ema_model.update_parameters(self.model)
159 |
160 |
161 | def backward(self, loss, optimizer, optimizer_idx):
162 | # do a custom way of backward to retain graph
163 | loss.backward(retain_graph=True)
164 |
165 |
166 | def on_train_start(self):
167 | gt_img = self.img_data.reshape(self.img_data.shape).float().clamp(0.0, 1.0)
168 | gt_img = gt_img.cpu().numpy()
169 |
170 | img_path = f'{self.hparams.output_dir}/{self.time}/reference.jpg'
171 | write_image(img_path, gt_img)
172 | print(f"\nWriting '{img_path}'... ", end="")
173 |
174 |
175 | model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
176 | self.log("misc/model_size", model_size)
177 | print(f"\nThe model size: {model_size}")
178 |
179 |
180 | def on_train_end(self):
181 | # The final validation will use the ema model, as it replaces our normal model
182 | if self.ema_model is not None:
183 | print("Replacing the standard model with the EMA model for last validation run")
184 | self.model = self.ema_model
185 |
186 |
187 | def on_validation_start(self):
188 | torch.cuda.empty_cache()
189 |
190 | if not self.hparams.no_save_test:
191 | self.val_dir = f'{self.hparams.output_dir}/{self.time}/validation/'
192 | os.makedirs(self.val_dir, exist_ok=True)
193 |
194 |
195 | def validation_step(self, batch, batch_nb):
196 | img_size = self.img_data.shape[0] * self.img_data.shape[1]
197 |
198 | pred_img = process_batch_in_chunks(batch["points"], self.ema_model, max_chunk_size=2**18)
199 | pred_img = pred_img[:img_size, :].reshape(self.img_data.shape).float().clamp(0.0, 1.0)
200 |
201 | pred_img = pred_img.cpu().numpy()
202 |
203 | if not self.hparams.no_save_test:
204 | img_path = f"{self.val_dir}/{self.current_epoch}.jpg"
205 | write_image(img_path, pred_img)
206 |
207 |
208 | def predict_step(self, batch, batch_idx):
209 | img_size = self.img_data.shape[0] * self.img_data.shape[1]
210 |
211 | pred_img = process_batch_in_chunks(batch["points"], self.ema_model, max_chunk_size=2**18)
212 | pred_img = pred_img[:img_size, :].reshape(self.img_data.shape).float().clamp(0.0, 1.0)
213 | pred_img = pred_img.cpu().numpy()
214 |
215 | img_path = f"{self.val_dir}/result.jpg"
216 | write_image(img_path, pred_img)
217 |
218 |
219 | def get_progress_bar_dict(self):
220 | # don't show the version number
221 | items = super().get_progress_bar_dict()
222 | items.pop("v_num", None)
223 | return items
224 |
225 |
226 | if __name__ == '__main__':
227 | hparams = get_opts()
228 | if hparams.val_only and (not hparams.ckpt_path):
229 | raise ValueError('You need to provide a @ckpt_path for validation!')
230 | system = ImageSystem(hparams)
231 |
232 | ckpt_cb = ModelCheckpoint(dirpath=f'{hparams.output_dir}/{system.time}/ckpts/',
233 | filename='{epoch:d}',
234 | save_weights_only=True,
235 | every_n_epochs=hparams.num_epochs,
236 | save_on_train_epoch_end=True,
237 | save_top_k=-1)
238 |
239 | callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)]
240 |
241 | logger = TensorBoardLogger(save_dir=f"{hparams.output_dir}/{system.time}/logs/",
242 | name="",
243 | default_hp_metric=False)
244 |
245 | trainer = Trainer(max_epochs=hparams.num_epochs,
246 | check_val_every_n_epoch=5,
247 | callbacks=callbacks,
248 | logger=logger,
249 | enable_model_summary=False,
250 | accelerator='gpu',
251 | gradient_clip_val=1.0,
252 | strategy=None,
253 | num_sanity_val_steps=-1 if hparams.val_only else 0,
254 | precision=16)
255 |
256 | if hparams.val_only:
257 | trainer.predict(system, ckpt_path=hparams.ckpt_path)
258 | system.output_metrics(logger)
259 | else:
260 | trainer.fit(system, ckpt_path=hparams.ckpt_path)
261 | trainer.predict()
262 | system.output_metrics(logger)
--------------------------------------------------------------------------------
/train_nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from scripts.nvs.opt import get_opts
4 | import os
5 | import glob
6 | import imageio
7 | import numpy as np
8 | import cv2
9 | from einops import rearrange
10 |
11 | # data
12 | from torch.utils.data import DataLoader
13 | from datasets import dataset_dict
14 | from datasets.nerf.ray_utils import axisangle_to_R, get_rays
15 |
16 | # models
17 | import commentjson as json
18 | from kornia.utils.grid import create_meshgrid3d
19 | from models.networks.nerf.NFFB_nerf import NFFB
20 | from models.networks.nerf.rendering import render, MAX_SAMPLES
21 |
22 | # optimizer, losses
23 | from apex.optimizers import FusedAdam
24 | from torch.optim.lr_scheduler import CosineAnnealingLR
25 | from models.loss.nerf.losses import NeRFLoss
26 |
27 | # metrics
28 | from torchmetrics import (
29 | PeakSignalNoiseRatio,
30 | StructuralSimilarityIndexMeasure
31 | )
32 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
33 |
34 | # pytorch-lightning
35 | from pytorch_lightning.plugins import DDPPlugin
36 | from pytorch_lightning import LightningModule, Trainer
37 | from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
38 | from pytorch_lightning.loggers import TensorBoardLogger
39 | from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
40 |
41 | from utils import slim_ckpt, load_ckpt, seed_everything
42 |
43 | # output
44 | import time
45 |
46 | import warnings; warnings.filterwarnings("ignore")
47 |
48 |
49 | def depth2img(depth):
50 | depth = (depth-depth.min())/(depth.max()-depth.min())
51 | depth_img = cv2.applyColorMap((depth*255).astype(np.uint8), cv2.COLORMAP_TURBO)
52 |
53 | return depth_img
54 |
55 | class NeRFSystem(LightningModule):
56 | def __init__(self, hparams):
57 | super().__init__()
58 | self.save_hyperparameters(hparams)
59 |
60 | self.warmup_steps = 256
61 | self.update_interval = 16
62 |
63 | self.time = str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
64 |
65 | exp_dir = os.path.join(f"experiments/", self.time)
66 | if not os.path.isdir(exp_dir):
67 | os.makedirs(exp_dir)
68 |
69 | with open(self.hparams.config) as config_file:
70 | self.net_config = json.load(config_file)
71 |
72 | ### Save the configuration file
73 | path = f"{exp_dir}/config.json"
74 | with open(path, 'w') as f:
75 | json.dump(self.net_config, f, indent=4, separators=(", ", ": "), sort_keys=True)
76 |
77 | self.loss = NeRFLoss(lambda_distortion=self.hparams.distortion_loss_w)
78 | self.train_psnr = PeakSignalNoiseRatio(data_range=1)
79 | self.val_psnr = PeakSignalNoiseRatio(data_range=1)
80 | self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1)
81 | if self.hparams.eval_lpips:
82 | self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg')
83 | ### Do not train the network parameters which are used to compute metrics
84 | for p in self.val_lpips.net.parameters():
85 | p.requires_grad = False
86 |
87 | rgb_act = 'None' if self.hparams.use_exposure else 'Sigmoid'
88 | self.model = NFFB(self.net_config["network"], scale=self.hparams.scale, rgb_act=rgb_act)
89 | G = self.model.grid_size
90 | self.model.register_buffer('density_grid', torch.zeros(self.model.cascades, G**3))
91 | self.model.register_buffer('grid_coords',
92 | create_meshgrid3d(G, G, G, False, dtype=torch.int32).reshape(-1, 3))
93 |
94 | def forward(self, batch, split):
95 | if split=='train':
96 | poses = self.poses[batch['img_idxs']]
97 | directions = self.directions[batch['pix_idxs']]
98 | else:
99 | poses = batch['pose']
100 | directions = self.directions
101 |
102 | if self.hparams.optimize_ext:
103 | dR = axisangle_to_R(self.dR[batch['img_idxs']])
104 | poses[..., :3] = dR @ poses[..., :3] # Do the rotation for poses
105 | poses[..., 3] += self.dT[batch['img_idxs']] # Do the translation for poses
106 |
107 | rays_o, rays_d = get_rays(directions, poses)
108 |
109 | kwargs = {'test_time': split!='train',
110 | 'random_bg': self.hparams.random_bg}
111 | if self.hparams.scale > 0.5:
112 | kwargs['exp_step_factor'] = 1 / 256
113 | if self.hparams.use_exposure:
114 | kwargs['exposure'] = batch['exposure']
115 |
116 | return render(self.model, rays_o, rays_d, **kwargs)
117 |
118 | ### Setup the dataset for training and testing
119 | def setup(self, stage):
120 | dataset = dataset_dict[self.hparams.dataset_name]
121 | kwargs = {'root_dir': self.hparams.root_dir,
122 | 'downsample': self.hparams.downsample}
123 | self.train_dataset = dataset(split=self.hparams.split, **kwargs)
124 | self.train_dataset.batch_size = self.hparams.batch_size
125 | self.train_dataset.ray_sampling_strategy = self.hparams.ray_sampling_strategy
126 |
127 | self.test_dataset = dataset(split='test', **kwargs)
128 |
129 | def on_fit_start(self):
130 | seed_everything(self.hparams.seed)
131 |
132 | def configure_optimizers(self):
133 | # define additional parameters
134 | self.register_buffer('directions', self.train_dataset.directions.to(self.device))
135 | self.register_buffer('poses', self.train_dataset.poses.to(self.device))
136 |
137 | if self.hparams.optimize_ext:
138 | N = len(self.train_dataset.poses)
139 | self.register_parameter('dR',
140 | nn.Parameter(torch.zeros(N, 3, device=self.device)))
141 | self.register_parameter('dT',
142 | nn.Parameter(torch.zeros(N, 3, device=self.device)))
143 |
144 | load_ckpt(self.model, self.hparams.weight_path)
145 |
146 | ### Exclude the camera-extrinsic parameters (dR, dT); net_params is then rebuilt from the model's learning-rate groups below
147 | net_params = []
148 | for n, p in self.named_parameters():
149 | if n not in ['dR', 'dT']: net_params += [p]
150 |
151 | opts = []
152 | net_params = self.model.get_params(self.net_config["training"]["LearningRateSchedule"])
153 | self.net_opt = FusedAdam(net_params, betas=(0.9, 0.99), eps=1e-15)
154 | opts += [self.net_opt]
155 | if self.hparams.optimize_ext:
156 | opts += [FusedAdam([self.dR, self.dT], 1e-6)] # learning rate is hard-coded
157 | net_sch = CosineAnnealingLR(self.net_opt,
158 | self.hparams.num_epochs,
159 | self.net_config["training"]["lr_threshold"])
160 |
161 | return opts, [net_sch]
162 |
163 | def train_dataloader(self):
164 | return DataLoader(self.train_dataset,
165 | num_workers=16,
166 | persistent_workers=True,
167 | batch_size=None,
168 | pin_memory=True)
169 |
170 | def val_dataloader(self):
171 | return DataLoader(self.test_dataset,
172 | num_workers=8,
173 | batch_size=None,
174 | pin_memory=True)
175 |
176 | def on_train_start(self):
177 | self.model.mark_invisible_cells(self.train_dataset.K.to(self.device),
178 | self.poses,
179 | self.train_dataset.img_wh)
180 |
181 | model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
182 | self.log("misc/model_size", model_size)
183 | print(f"\nThe model size: {model_size}")
184 |
185 |
186 | def training_step(self, batch, batch_nb, *args):
187 | if self.global_step % self.update_interval == 0:
188 | self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5,
189 | warmup=self.global_step<self.warmup_steps)
247 | rgb_pred = rearrange(results['rgb'], '(h w) c -> 1 c h w', h=h)
248 | rgb_gt = rearrange(rgb_gt, '(h w) c -> 1 c h w', h=h)
249 | self.val_ssim(rgb_pred, rgb_gt)
250 | logs['ssim'] = self.val_ssim.compute()
251 | self.val_ssim.reset()
252 | if self.hparams.eval_lpips:
253 | self.val_lpips(torch.clip(rgb_pred * 2 - 1, -1, 1),
254 | torch.clip(rgb_gt * 2 - 1, -1, 1))
255 | logs['lpips'] = self.val_lpips.compute()
256 | self.val_lpips.reset()
257 |
258 | if not self.hparams.no_save_test: # save test image to disk
259 | idx = batch['img_idxs']
260 | rgb_pred = rearrange(results['rgb'].cpu().numpy(), '(h w) c -> h w c', h=h)
261 | rgb_pred = (rgb_pred*255).astype(np.uint8)
262 | depth = depth2img(rearrange(results['depth'].cpu().numpy(), '(h w) -> h w', h=h))
263 | imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}.png'), rgb_pred)
264 | imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}_d.png'), depth)
265 |
266 | return logs
267 |
268 | def validation_epoch_end(self, outputs):
269 | psnrs = torch.stack([x['psnr'] for x in outputs])
270 | mean_psnr = all_gather_ddp_if_available(psnrs).mean()
271 | self.log('test/psnr', mean_psnr, True)
272 |
273 | ssims = torch.stack([x['ssim'] for x in outputs])
274 | mean_ssim = all_gather_ddp_if_available(ssims).mean()
275 | self.log('test/ssim', mean_ssim)
276 |
277 | if self.hparams.eval_lpips:
278 | lpipss = torch.stack([x['lpips'] for x in outputs])
279 | mean_lpips = all_gather_ddp_if_available(lpipss).mean()
280 | self.log('test/lpips_vgg', mean_lpips)
281 |
282 | def get_progress_bar_dict(self):
283 | # don't show the version number
284 | items = super().get_progress_bar_dict()
285 | items.pop("v_num", None)
286 | return items
287 |
288 |
289 | if __name__ == '__main__':
290 | hparams = get_opts()
291 | if hparams.val_only and (not hparams.ckpt_path):
292 | raise ValueError('You need to provide a @ckpt_path for validation!')
293 | system = NeRFSystem(hparams)
294 |
295 | ckpt_cb = ModelCheckpoint(dirpath=f'experiments/{system.time}/ckpts/',
296 | filename='{epoch:d}',
297 | save_weights_only=True,
298 | every_n_epochs=hparams.num_epochs,
299 | save_on_train_epoch_end=True,
300 | save_top_k=-1)
301 |
302 | callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)]
303 |
304 | logger = TensorBoardLogger(save_dir=f"experiments/{system.time}/logs/",
305 | name="",
306 | default_hp_metric=False)
307 |
308 | trainer = Trainer(max_epochs=hparams.num_epochs,
309 | check_val_every_n_epoch=hparams.num_epochs,
310 | callbacks=callbacks,
311 | logger=logger,
312 | enable_model_summary=False,
313 | accelerator='gpu',
314 | devices=hparams.num_gpus,
315 | strategy=DDPPlugin(find_unused_parameters=True)
316 | if hparams.num_gpus>1 else None,
317 | num_sanity_val_steps=-1 if hparams.val_only else 0,
318 | precision=16)
319 |
320 | trainer.fit(system, ckpt_path=hparams.ckpt_path)
321 |
322 | if not hparams.val_only: # save slimmed ckpt for the last epoch
323 | ckpt_ = \
324 | slim_ckpt(f'experiments/{system.time}/ckpts/epoch={hparams.num_epochs-1}.ckpt',
325 | save_poses=hparams.optimize_ext)
326 | torch.save(ckpt_, f'experiments/{system.time}/ckpts/epoch={hparams.num_epochs-1}_slim.ckpt')
327 |
328 | if (not hparams.no_save_test) and \
329 | hparams.dataset_name=='nsvf' and \
330 | 'Synthetic' in hparams.root_dir: # save video
331 | imgs = sorted(glob.glob(os.path.join(system.val_dir, '*.png')))
332 | imageio.mimsave(os.path.join(system.val_dir, 'rgb.mp4'),
333 | [imageio.imread(img) for img in imgs[::2]],
334 | fps=30, macro_block_size=1)
335 | imageio.mimsave(os.path.join(system.val_dir, 'depth.mp4'),
336 | [imageio.imread(img) for img in imgs[1::2]],
337 | fps=30, macro_block_size=1)
--------------------------------------------------------------------------------
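
> Note: the pose refinement in `forward()` above converts the learned axis-angle residuals `dR` into rotation matrices through `axisangle_to_R` (defined elsewhere in the repo). As a reference for what such a conversion typically looks like, below is a minimal, self-contained sketch of Rodrigues' formula; the function name `axisangle_to_rotation` is illustrative and is not the repo's implementation.

```python
import torch

def axisangle_to_rotation(v: torch.Tensor) -> torch.Tensor:
    """v: (N, 3) axis-angle vectors; returns (N, 3, 3) rotation matrices."""
    theta = v.norm(dim=-1, keepdim=True).clamp(min=1e-8)   # (N, 1) rotation angles
    k = v / theta                                           # (N, 3) unit rotation axes
    # batched skew-symmetric cross-product matrices K such that K @ x = k x x
    K = torch.zeros(v.shape[0], 3, 3, dtype=v.dtype, device=v.device)
    K[:, 0, 1], K[:, 0, 2] = -k[:, 2], k[:, 1]
    K[:, 1, 0], K[:, 1, 2] = k[:, 2], -k[:, 0]
    K[:, 2, 0], K[:, 2, 1] = -k[:, 1], k[:, 0]
    eye = torch.eye(3, dtype=v.dtype, device=v.device).expand(v.shape[0], 3, 3)
    sin = torch.sin(theta)[..., None]   # (N, 1, 1) for broadcasting
    cos = torch.cos(theta)[..., None]
    # Rodrigues' formula: R = I + sin(theta) K + (1 - cos(theta)) K^2
    return eye + sin * K + (1.0 - cos) * (K @ K)
```

Given `dR` of shape `(N, 3)`, the resulting `(N, 3, 3)` matrices are what get left-multiplied onto `poses[..., :3]` in `forward()`.
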
/train_sdf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from scripts.sdf.opt import get_opts
4 |
5 | # data
6 | from torch.utils.data import DataLoader
7 | from datasets.sdf.sampler import SDFDataset
8 |
9 | # models
10 | import commentjson as json
11 | from models.networks.sdf.NFFB_3d import NFFB
12 |
13 | # optimizer, losses
14 | from apex.optimizers import FusedAdam
15 | from torch.optim.lr_scheduler import StepLR
16 |
17 | # pytorch-lightning
18 | from pytorch_lightning import LightningModule, Trainer
19 | from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
20 | from pytorch_lightning.loggers import TensorBoardLogger
21 |
22 | from utils import load_ckpt, seed_everything
23 |
24 | # output
25 | import time
26 | from scripts.sdf.utils import create_mesh
27 |
28 |
29 | import warnings; warnings.filterwarnings("ignore")
30 |
31 |
32 | class SDFSystem(LightningModule):
33 | def __init__(self, hparams):
34 | super().__init__()
35 | self.save_hyperparameters(hparams)
36 |
37 | self.time = str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
38 |
39 | exp_dir = os.path.join(self.hparams.output_dir, self.time)
40 | if not os.path.isdir(exp_dir):
41 | os.makedirs(exp_dir)
42 |
43 | ### Load the configuration file
44 | with open(self.hparams.config) as config_file:
45 | self.config = json.load(config_file)
46 |
47 | ### Save the configuration file
48 | path = f"{exp_dir}/config.json"
49 | with open(path, 'w') as f:
50 | json.dump(self.config, f, indent=4, separators=(", ", ": "), sort_keys=True)
51 |
52 | self.model = NFFB(self.config["network"])
53 |
54 | ema_decay = 0.95
55 | ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: \
56 | ema_decay * averaged_model_parameter + (1-ema_decay) * model_parameter
57 | self.ema_model = torch.optim.swa_utils.AveragedModel(self.model, avg_fn=ema_avg)
58 |
59 |
60 | def setup(self, stage):
61 | self.train_dataset = SDFDataset(path=self.hparams.input_path,
62 | size=1000,
63 | num_samples=self.hparams.batch_size,
64 | clip_sdf=self.hparams.clamp_distance)
65 |
66 | self.test_dataset = SDFDataset(path=self.hparams.input_path,
67 | size=1,
68 | num_samples=self.hparams.batch_size,
69 | clip_sdf=self.hparams.clamp_distance)
70 |
71 |
72 | def forward(self, batch):
73 | b_pos = batch["points"]
74 |
75 | pred = self.model(b_pos)
76 |
77 | if self.hparams.clamp_distance > 0.0:
78 | pred = torch.clamp(pred, -self.hparams.clamp_distance, self.hparams.clamp_distance)
79 |
80 | return pred
81 |
82 |
83 | def on_fit_start(self):
84 | seed_everything(self.hparams.seed)
85 |
86 |
87 | def configure_optimizers(self):
88 | load_ckpt(self.model, self.hparams.ckpt_path)
89 |
90 | opts = []
91 | net_params = self.model.get_params(self.config["training"]["LR_scheduler"])
92 | self.net_opt = FusedAdam(net_params, betas=(0.9, 0.99), eps=1e-15)
93 | opts += [self.net_opt]
94 |
95 | lr_interval = self.config["training"]["LR_scheduler"][0]["interval"]
96 | lr_factor = self.config["training"]["LR_scheduler"][0]["factor"]
97 |
98 | if self.config["training"]["LR_scheduler"][0]["type"] == "Step":
99 | net_sch = StepLR(self.net_opt, step_size=lr_interval, gamma=lr_factor)
100 | else:
101 | net_sch = None
102 |
103 | return opts, [net_sch]
104 |
105 | def train_dataloader(self):
106 | return DataLoader(self.train_dataset,
107 | num_workers=16,
108 | persistent_workers=True,
109 | batch_size=None,
110 | pin_memory=True)
111 |
112 | def val_dataloader(self):
113 | return DataLoader(self.test_dataset,
114 | num_workers=8,
115 | batch_size=None,
116 | pin_memory=True)
117 |
118 | def predict_dataloader(self):
119 | return DataLoader(self.test_dataset,
120 | num_workers=8,
121 | batch_size=None,
122 | pin_memory=True)
123 |
124 | def training_step(self, batch, batch_nb, *args):
125 | results = self(batch)
126 |
127 | b_occ = batch['sdfs'].to(results.dtype)
128 | if self.hparams.clamp_distance > 0.0:
129 | b_occ = torch.clamp(b_occ, -self.hparams.clamp_distance, self.hparams.clamp_distance)
130 |
131 | batch_loss = (results - b_occ)**2 / (b_occ.detach()**2 + 1e-4) # relative squared error; the denominator up-weights near-surface samples
132 | loss = batch_loss.mean()
133 |
134 | self.log('lr/network', self.net_opt.param_groups[0]['lr'], True)
135 | self.log('train/loss', loss)
136 |
137 | return loss
138 |
139 | def training_epoch_end(self, training_step_outputs):
140 | for name, cur_para in self.model.named_parameters():
141 | if len(cur_para) == 0:
142 | print(f"The len of parameter {name} is 0 at epoch {self.current_epoch}.")
143 | continue
144 |
145 | if cur_para.grad is not None and cur_para.requires_grad:
146 | para_norm = torch.norm(cur_para.grad.detach(), 2)
147 | self.log('Grad/%s_norm' % name, para_norm)
148 |
149 | def on_before_zero_grad(self, optimizer):
150 | if self.ema_model is not None:
151 | self.ema_model.update_parameters(self.model)
152 |
153 | def backward(self, loss, optimizer, optimizer_idx):
154 | # to retain graph
155 | loss.backward(retain_graph=True)
156 |
157 | def on_train_start(self):
158 | model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
159 | self.log("misc/model_size", model_size)
160 | print(f"\nThe model size: {model_size}")
161 |
162 | def on_train_end(self):
163 | # Swap in the EMA model so that the final validation/prediction uses the averaged weights
164 | if self.ema_model is not None:
165 | print("Replacing the standard model with the EMA model for last validation run")
166 | self.model = self.ema_model
167 |
168 | def on_validation_start(self):
169 | torch.cuda.empty_cache()
170 |
171 | if not self.hparams.no_save_test:
172 | self.val_dir = f'{self.hparams.output_dir}/{self.time}/validation/'
173 | os.makedirs(self.val_dir, exist_ok=True)
174 |
175 | def validation_step(self, batch, batch_nb):
176 | if not self.hparams.no_save_test:
177 | res = 256
178 | mesh_path = os.path.join(self.val_dir, f'val_{self.current_epoch}_{res}.ply')
179 |
180 | create_mesh(self.ema_model, mesh_path, res)
181 |
182 |
183 | def on_predict_start(self):
184 | torch.cuda.empty_cache()
185 |
186 | if not self.hparams.no_save_test:
187 | self.pred_dir = f'{self.hparams.output_dir}/{self.time}/results/'
188 | os.makedirs(self.pred_dir, exist_ok=True)
189 |
190 | def predict_step(self, batch, batch_nb):
191 | if not self.hparams.no_save_test:
192 | res = 1024
193 | mesh_path = os.path.join(self.pred_dir, f'output_{res}.ply')
194 |
195 | create_mesh(self.model, mesh_path, res)
196 |
197 |
198 | def get_progress_bar_dict(self):
199 | # don't show the version number
200 | items = super().get_progress_bar_dict()
201 | items.pop("v_num", None)
202 | return items
203 |
204 |
205 | if __name__ == '__main__':
206 | hparams = get_opts()
207 | if hparams.val_only and (not hparams.ckpt_path):
208 | raise ValueError('You need to provide a @ckpt_path for validation!')
209 | system = SDFSystem(hparams)
210 |
211 | ckpt_cb = ModelCheckpoint(dirpath=f'{hparams.output_dir}/{system.time}/ckpts/',
212 | filename='{epoch:d}',
213 | save_weights_only=True,
214 | every_n_epochs=hparams.num_epochs,
215 | save_on_train_epoch_end=True,
216 | save_top_k=-1)
217 |
218 | callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)]
219 |
220 | logger = TensorBoardLogger(save_dir=f"{hparams.output_dir}/{system.time}/logs/",
221 | name="",
222 | default_hp_metric=False)
223 |
224 | trainer = Trainer(max_epochs=hparams.num_epochs,
225 | check_val_every_n_epoch=5,
226 | callbacks=callbacks,
227 | logger=logger,
228 | enable_model_summary=False,
229 | accelerator='gpu',
230 | gradient_clip_val=1.0,
231 | strategy=None,
232 | num_sanity_val_steps=-1 if hparams.val_only else 0,
233 | precision=16)
234 |
235 | if hparams.val_only:
236 | trainer.predict(system, ckpt_path=hparams.ckpt_path)
237 | else:
238 | trainer.fit(system, ckpt_path=hparams.ckpt_path)
239 |
240 | if (not hparams.no_save_test): # save mesh
241 | trainer.predict(system)
--------------------------------------------------------------------------------
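
> Note: the objective in `training_step` of `train_sdf.py` is a truncated, relative squared error on signed-distance values. The helper below is a standalone sketch of that computation for reference only; the name `truncated_relative_l2` and its default arguments are illustrative, and the repo computes the same quantity inline.

```python
import torch

def truncated_relative_l2(pred: torch.Tensor,
                          target: torch.Tensor,
                          clamp_distance: float = 0.1,
                          eps: float = 1e-4) -> torch.Tensor:
    """Clamped ("truncated") SDF regression loss, normalized by target magnitude."""
    if clamp_distance > 0.0:
        # both the prediction and the ground truth are clipped to the truncation band
        pred = torch.clamp(pred, -clamp_distance, clamp_distance)
        target = torch.clamp(target, -clamp_distance, clamp_distance)
    # dividing by |sdf|^2 + eps up-weights samples close to the surface
    return ((pred - target) ** 2 / (target.detach() ** 2 + eps)).mean()
```

Clamping both sides keeps the network from spending capacity on exact distances far from the surface, which is also why `forward()` clamps its prediction before the loss is taken.
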
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import random
5 |
6 | import torch
7 | import pytorch_lightning
8 |
9 |
10 | def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]):
11 | checkpoint = torch.load(ckpt_path, map_location='cpu')
12 | checkpoint_ = {}
13 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint
14 | checkpoint = checkpoint['state_dict']
15 | for k, v in checkpoint.items():
16 | if not k.startswith(model_name):
17 | continue
18 | k = k[len(model_name)+1:]
19 | for prefix in prefixes_to_ignore:
20 | if k.startswith(prefix):
21 | break
22 | else:
23 | checkpoint_[k] = v
24 | return checkpoint_
25 |
26 |
27 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):
28 | if not ckpt_path: return
29 | model_dict = model.state_dict()
30 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore)
31 | model_dict.update(checkpoint_)
32 | model.load_state_dict(model_dict)
33 |
34 |
35 | def slim_ckpt(ckpt_path, save_poses=False):
36 | ckpt = torch.load(ckpt_path, map_location='cpu')
37 | # pop unused parameters
38 | keys_to_pop = ['directions', 'model.density_grid', 'model.grid_coords']
39 | if not save_poses: keys_to_pop += ['poses']
40 | for k in ckpt['state_dict']:
41 | if k.startswith('val_lpips'):
42 | keys_to_pop += [k]
43 | for k in keys_to_pop:
44 | ckpt['state_dict'].pop(k, None)
45 | return ckpt['state_dict']
46 |
47 |
48 |
49 | def seed_everything(seed):
50 | random.seed(seed)
51 | os.environ['PYTHONHASHSEED'] = str(seed)
52 | np.random.seed(seed)
53 | torch.manual_seed(seed)
54 | torch.cuda.manual_seed(seed)
55 | pytorch_lightning.seed_everything(seed, workers=True)
56 | #torch.backends.cudnn.deterministic = True
57 | #torch.backends.cudnn.benchmark = True
58 |
59 |
60 |
61 | def process_batch_in_chunks(in_coords, model, max_chunk_size=1024):
62 | chunk_outs = []
63 |
64 | coord_chunks = torch.split(in_coords, max_chunk_size)
65 | for chunk_batched_in in coord_chunks:
66 | tmp_img = model(chunk_batched_in)
67 | chunk_outs.append(tmp_img.detach())
68 |
69 | batched_out = torch.cat(chunk_outs, dim=0)
70 |
71 | return batched_out
--------------------------------------------------------------------------------
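
> Note: to make the helpers in `utils.py` concrete, here is a hypothetical usage sketch. The run directory, checkpoint paths, and the commented-out model construction are placeholders, not files or APIs shipped with the repo.

```python
import torch
from utils import slim_ckpt, load_ckpt, process_batch_in_chunks

run_dir = 'experiments/<run>'  # placeholder for a timestamped run directory

# Strip cached buffers and metric-network weights from a Lightning checkpoint,
# keeping only the model parameters (and, optionally, the refined poses).
state = slim_ckpt(f'{run_dir}/ckpts/epoch=29.ckpt', save_poses=False)
torch.save(state, f'{run_dir}/ckpts/epoch=29_slim.ckpt')

# Load the slimmed weights into a freshly built model (construction elided),
# then evaluate a large coordinate batch in fixed-size chunks to bound memory.
# model = NFFB(config["network"])                      # hypothetical construction
# load_ckpt(model, f'{run_dir}/ckpts/epoch=29_slim.ckpt')
# coords = torch.rand(1_000_000, 3, device='cuda')
# sdf_values = process_batch_in_chunks(coords, model, max_chunk_size=4096)
```
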