├── README.md ├── convert_blender_data.py ├── dataio.py ├── download_datasets.py ├── environment.yml ├── experiments ├── config │ ├── 1d │ │ ├── bacon_freq1.ini │ │ ├── bacon_freq2.ini │ │ ├── ff.ini │ │ └── siren.ini │ ├── img │ │ ├── bacon.ini │ │ ├── ff.ini │ │ ├── mip.ini │ │ └── siren.ini │ ├── nerf │ │ ├── bacon.ini │ │ ├── bacon_lr.ini │ │ └── bacon_semisupervise.ini │ └── sdf │ │ ├── bacon_armadillo.ini │ │ ├── bacon_dragon.ini │ │ ├── bacon_lucy.ini │ │ ├── bacon_thai.ini │ │ ├── ff_armadillo.ini │ │ ├── ff_dragon.ini │ │ ├── ff_lucy.ini │ │ ├── ff_thai.ini │ │ ├── siren_armadillo.ini │ │ ├── siren_dragon.ini │ │ ├── siren_lucy.ini │ │ └── siren_thai.ini ├── figure_setup.py ├── plot_activation_distributions.py ├── render_nerf.py ├── render_sdf.py ├── train_1d.py ├── train_img.py ├── train_radiance_field.py └── train_sdf.py ├── forward_models.py ├── img └── teaser.png ├── loss_functions.py ├── modules.py ├── spectrum_visualization ├── README.md ├── environment.yml ├── get_shape_spectra.py ├── magicwand ├── model_final.pth └── ref_armadillo.xyz ├── trained_models ├── armadillo.pth ├── dragon.pth ├── lego.pth ├── lego_semisupervise.pth ├── lucy.pth └── thai.pth ├── training.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # BACON: Band-limited Coordinate Networks for Multiscale Scene Representation | CVPR 2022 (oral) 2 | ### [Project Page](http://www.computationalimaging.org/publications/bacon/) | [Video](https://www.youtube.com/watch?v=zIH3KUCgJEA) | [Paper](https://arxiv.org/abs/2112.04645) 3 | Official PyTorch implementation of BACON.
4 | [BACON: Band-limited Coordinate Networks for Multiscale Scene Representation](http://www.computationalimaging.org/publications/bacon/)
5 | [David B. Lindell](https://davidlindell.com)\*, 6 | [Dave Van Veen](https://davevanveen.com/), 7 | [Jeong Joon Park](https://jjparkcv.github.io/), 8 | [Gordon Wetzstein](https://computationalimaging.org)
9 | Stanford University
10 | 11 | 12 | 13 | ## Quickstart 14 | 15 | To set up a conda environment, use these commands: 16 | ``` 17 | conda env create -f environment.yml 18 | conda activate bacon 19 | 20 | # download all datasets 21 | python download_datasets.py 22 | ``` 23 | Now you can train networks to fit a 1D function, images, signed distance fields, or neural radiance fields with the following commands. 24 | 25 | ``` 26 | cd experiments 27 | python train_1d.py --config ./config/1d/bacon_freq1.ini # train 1D function 28 | python train_img.py --config ./config/img/bacon.ini # train image 29 | python train_sdf.py --config ./config/sdf/bacon_armadillo.ini # train SDF 30 | python train_radiance_field.py --config ./config/nerf/bacon_lr.ini # train NeRF 31 | ``` 32 | 33 | To visualize outputs in TensorBoard, run the following. 34 | ``` 35 | tensorboard --logdir=../logs --port=6006 36 | ``` 37 | 38 | ## Band-limited Coordinate Networks 39 | 40 | Band-limited coordinate networks have an analytical Fourier spectrum and interpretable behavior. We demonstrate using these networks to fit simple 1D signals, images, 3D shapes via signed distance functions, and neural radiance fields. 41 | 42 | ### Datasets 43 | 44 | Datasets can be downloaded using the `download_datasets.py` script. This script 45 | - downloads the synthetic Blender dataset from the [original NeRF paper](https://github.com/bmild/nerf), 46 | - generates a multiscale version of the Blender dataset, 47 | - downloads 3D models originating from the [Stanford 3D Scanning Repository](http://graphics.stanford.edu/data/3Dscanrep/), which we have adjusted to make watertight, and 48 | - downloads an example image from the [Kodak dataset](http://www.cs.albany.edu/~xypan/research/snr/Kodak.html). 49 | 50 | ### Training 51 | 52 | We provide training scripts and configuration files to reproduce the results in the paper. 53 | 54 | #### 1D Examples 55 | To run the 1D examples, use the `experiments/train_1d.py` script with any of the config files in `experiments/config/1d`. These configs allow training models with BACON, [Fourier Features](https://github.com/tancik/fourier-feature-networks), or [SIREN](https://github.com/vsitzmann/siren). 56 | For example, to train a BACON model, you can run 57 | 58 | ``` 59 | python train_1d.py --config ./config/1d/bacon_freq1.ini 60 | ``` 61 | 62 | To change the bandwidth of BACON, adjust the maximum frequency with the `--max_freq` flag. 63 | This sets the network-equivalent sampling rate used to represent the signal. 64 | For example, if the signal you wish to represent has a maximum frequency of 5 cycles per unit interval, this value should be set to at least the Nyquist rate of 2 samples per cycle, i.e., 10 samples per unit interval. 65 | By default, the frequencies represented by BACON are quantized to intervals of 2*pi; thus, the network is periodic over an interval from -0.5 to 0.5. 66 | That is, the output of the network will repeat for input coordinates that exceed an absolute value of 0.5. 67 | 68 | #### Image Fitting 69 | 70 | Image fitting can be performed using the config files in `experiments/config/img` and the `train_img.py` script. We support training BACON, Fourier Features, SIREN, and networks with the positional encoding from [Mip-NeRF](https://github.com/google/mipnerf). 71 | 72 | #### SDF Fitting 73 | 74 | Config files for SDF fitting are in `experiments/config/sdf` and can be used with the `train_sdf.py` script. 75 | Be sure to download the example datasets before running this script.
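For example, a typical invocation (after fetching the data with `download_datasets.py`) is:

```
cd experiments
python train_sdf.py --config ./config/sdf/bacon_dragon.ini
```

Any of the other config files in that directory (e.g., `ff_dragon.ini` or `siren_dragon.ini`) can be substituted to train the baseline models.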
76 | 77 | We also provide a rendering script to extract meshes from the trained models. 78 | The `render_sdf.py` program extracts a mesh using marching cubes and, optionally, our proposed multiscale adaptive SDF evaluation procedure. 79 | 80 | #### NeRF Reconstruction 81 | 82 | Use the config files in `experiments/config/nerf` with the `train_radiance_field.py` script to train neural radiance fields. 83 | Note that training the full-resolution model can take a while (a few days), so it may be easier to train a low-resolution model to get started. 84 | We provide a low-resolution config file in `experiments/config/nerf/bacon_lr.ini`. 85 | 86 | To render output images from a trained model, use the `render_nerf.py` script. 87 | Note that the Blender synthetic datasets should be downloaded and the multiscale dataset generated before running this script. 88 | 89 | #### Initialization Scheme 90 | 91 | Finally, we also show a visualization of our initialization scheme in `experiments/plot_activation_distributions.py`. As shown in the paper, our initialization scheme prevents the distribution of activations from becoming vanishingly small, even for deep networks. 92 | 93 | 94 | #### Pretrained Models 95 | 96 | For convenience, we include pretrained models for the SDF fitting and NeRF reconstruction tasks in the `trained_models` directory. 97 | The outputs of these models can be rendered directly using the `experiments/render_sdf.py` and `experiments/render_nerf.py` scripts. 98 | 99 | ## Citation 100 | 101 | ``` 102 | @article{lindell2021bacon, 103 | author = {Lindell, David B. and Van Veen, Dave and Park, Jeong Joon and Wetzstein, Gordon}, 104 | title = {BACON: Band-limited coordinate networks for multiscale scene representation}, 105 | journal = {arXiv preprint arXiv:2112.04645}, 106 | year={2021} 107 | } 108 | ``` 109 | ## Acknowledgments 110 | 111 | This project was supported in part by a PECASE from the ARO and by NSF award 1839974. 112 | -------------------------------------------------------------------------------- /convert_blender_data.py: -------------------------------------------------------------------------------- 1 | # adapted from Jon Barron's mipnerf conversion script 2 | # https://github.com/google/mipnerf/blob/main/scripts/convert_blender_data.py 3 | 4 | import json 5 | import os 6 | from os import path 7 | 8 | from absl import app 9 | from absl import flags 10 | import numpy as np 11 | from PIL import Image 12 | 13 | import skimage.transform 14 | 15 | FLAGS = flags.FLAGS 16 | 17 | flags.DEFINE_string('blenderdir', None, 'Base directory for all Blender data.') 18 | flags.DEFINE_integer('n_down', 4, 'How many levels of downscaling to use.') 19 | 20 | 21 | def load_renderings(data_dir, split): 22 | """Load images and metadata from disk.""" 23 | f = 'transforms_{}.json'.format(split) 24 | with open(path.join(data_dir, f), 'r') as fp: 25 | meta = json.load(fp) 26 | images = [] 27 | cams = [] 28 | print('Loading imgs') 29 | for frame in meta['frames']: 30 | fname = os.path.join(data_dir, frame['file_path'] + '.png') 31 | with open(fname, 'rb') as imgin: 32 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.
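# store the camera-to-world transform and the [0, 1]-normalized image for this frame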
33 | cams.append(frame['transform_matrix']) 34 | images.append(image) 35 | ret = {} 36 | ret['images'] = np.stack(images, axis=0) 37 | print('Loaded all images, shape is', ret['images'].shape) 38 | ret['camtoworlds'] = np.stack(cams, axis=0) 39 | w = ret['images'].shape[2] 40 | camera_angle_x = float(meta['camera_angle_x']) 41 | ret['focal'] = .5 * w / np.tan(.5 * camera_angle_x) 42 | return ret 43 | 44 | 45 | def down2(img): 46 | sh = img.shape 47 | return np.mean(np.reshape(img, [sh[0] // 2, 2, sh[1] // 2, 2, -1]), (1, 3)) 48 | 49 | 50 | def convert_to_nerfdata(basedir, n_down): 51 | """Convert Blender data to multiscale.""" 52 | splits = ['train', 'val', 'test'] 53 | # Foreach split in the dataset 54 | for split in splits: 55 | print('Split', split) 56 | # Load everything 57 | data = load_renderings(basedir, split) 58 | 59 | # Save out all the images 60 | imgdir = '{}_multiscale'.format(split) 61 | os.makedirs(os.path.join(basedir, imgdir), exist_ok=True) 62 | print('Saving images') 63 | for i, img in enumerate(data['images']): 64 | for j in range(n_down): 65 | fname = '{}/r_{:d}_d{}.png'.format(imgdir, i, j) 66 | fname = os.path.join(basedir, fname) 67 | with open(fname, 'wb') as imgout: 68 | img = skimage.transform.resize(img, 2*(512//2**j,)) 69 | img8 = Image.fromarray(np.uint8(img * 255)) 70 | img8.save(imgout) 71 | # img = down2(img) 72 | 73 | 74 | def main(unused_argv): 75 | 76 | blenderdir = FLAGS.blenderdir 77 | n_down = FLAGS.n_down 78 | 79 | dirs = [os.path.join(blenderdir, f) for f in os.listdir(blenderdir)] 80 | dirs = [d for d in dirs if os.path.isdir(d)] 81 | print(dirs) 82 | for basedir in dirs: 83 | print() 84 | convert_to_nerfdata(basedir, n_down) 85 | 86 | 87 | if __name__ == '__main__': 88 | app.run(main) 89 | -------------------------------------------------------------------------------- /dataio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | from PIL import Image 5 | import skimage 6 | from torchvision.transforms import Compose, ToTensor, Resize, Lambda 7 | import skimage.transform 8 | import json 9 | import os 10 | import re 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | from pykdtree.kdtree import KDTree 14 | import errno 15 | import urllib.request 16 | 17 | 18 | def get_mgrid(sidelen, dim=2, centered=True, include_end=False): 19 | '''Generates a flattened grid of (x,y,...) 
coordinates in a range of -1 to 1.''' 20 | if isinstance(sidelen, int): 21 | sidelen = dim * (sidelen,) 22 | 23 | if include_end: 24 | denom = [s-1 for s in sidelen] 25 | else: 26 | denom = sidelen 27 | 28 | if dim == 2: 29 | pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32) 30 | pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / denom[0] 31 | pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / denom[1] 32 | elif dim == 3: 33 | pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32) 34 | pixel_coords[..., 0] = pixel_coords[..., 0] / denom[0] 35 | pixel_coords[..., 1] = pixel_coords[..., 1] / denom[1] 36 | pixel_coords[..., 2] = pixel_coords[..., 2] / denom[2] 37 | else: 38 | raise NotImplementedError('Not implemented for dim=%d' % dim) 39 | 40 | if centered: 41 | pixel_coords -= 0.5 42 | 43 | pixel_coords = torch.Tensor(pixel_coords).view(-1, dim) 44 | return pixel_coords 45 | 46 | 47 | def lin2img(tensor, image_resolution=None): 48 | batch_size, num_samples, channels = tensor.shape 49 | if image_resolution is None: 50 | width = np.sqrt(num_samples).astype(int) 51 | height = width 52 | else: 53 | height = image_resolution[0] 54 | width = image_resolution[1] 55 | 56 | return tensor.permute(0, 2, 1).view(batch_size, channels, height, width) 57 | 58 | 59 | class Func1DWrapper(torch.utils.data.Dataset): 60 | def __init__(self, range, fn, grad_fn=None, 61 | sampling_density=100, train_every=10): 62 | 63 | coords = self.get_samples(range, sampling_density) 64 | self.fn_vals = fn(coords) 65 | self.train_idx = torch.arange(0, coords.shape[0], train_every).float() 66 | 67 | self.grid = coords 68 | self.grid.requires_grad_(True) 69 | self.range = range 70 | 71 | def get_samples(self, range, sampling_density): 72 | num = int(range[1] - range[0])*sampling_density 73 | coords = np.linspace(start=range[0], stop=range[1], num=num) 74 | coords.astype(np.float32) 75 | coords = torch.Tensor(coords).view(-1, 1) 76 | return coords 77 | 78 | def get_num_samples(self): 79 | return self.grid.shape[0] 80 | 81 | def __len__(self): 82 | return 1 83 | 84 | def __getitem__(self, idx): 85 | 86 | return {'idx': self.train_idx, 'coords': self.grid}, \ 87 | {'func': self.fn_vals, 'coords': self.grid} 88 | 89 | 90 | def rect(coords, width=1): 91 | return torch.where(abs(coords) < width/2, 1.0/width, 0.0) 92 | 93 | 94 | def gaussian(coords, sigma=1, center=0.5): 95 | return 1 / (sigma * math.sqrt(2*np.pi)) * torch.exp(-(coords-center)**2 / (2*sigma**2)) 96 | 97 | 98 | def sines1(coords): 99 | return 0.3 * torch.sin(2*np.pi*8*coords + np.pi/3) + 0.65 * torch.sin(2*np.pi*2*coords + np.pi) 100 | 101 | 102 | def polynomial_1(coords): 103 | return .1*((coords+.2)*3)**5 - .2*((coords+.2)*3)**4 + .2*((coords+.2)*3)**3 - .4*((coords+.2)*3)**2 + .1*((coords+.2)*3) 104 | 105 | 106 | def sinc(coords): 107 | coords[coords == 0] += 1 108 | return torch.div(torch.sin(20*coords), 20*coords) 109 | 110 | 111 | def linear(coords): 112 | return 1.0 * coords 113 | 114 | 115 | def xcosx(coords): 116 | return coords * torch.cos(coords) 117 | 118 | 119 | class ImageWrapper(torch.utils.data.Dataset): 120 | def __init__(self, dataset, compute_diff='all', centered=True, 121 | include_end=False, multiscale=False, stages=3): 122 | 123 | self.compute_diff = compute_diff 124 | self.centered = centered 125 | self.include_end = include_end 126 | self.transform = Compose([ 127 | ToTensor(), 128 | ]) 129 | 130 | self.dataset = dataset 131 | 
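# flattened grid of pixel-coordinate samples; with centered=True the coordinates lie in [-0.5, 0.5)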
self.mgrid = get_mgrid(self.dataset.resolution, centered=centered, include_end=include_end) 132 | 133 | # sample pixel centers 134 | self.mgrid = self.mgrid + 1 / (2 * self.dataset.resolution[0]) 135 | self.radii = 1 / self.dataset.resolution[0] * 2/np.sqrt(12) 136 | self.radii = [(self.radii * 2**i).astype(np.float32) for i in range(3)] 137 | self.radii.reverse() 138 | 139 | img = self.transform(self.dataset[0]) 140 | _, self.rows, self.cols = img.shape 141 | 142 | self.img_chw = img 143 | self.img = img.permute(1, 2, 0).view(-1, self.dataset.img_channels) 144 | 145 | self.imgs = [] 146 | self.multiscale = multiscale 147 | img = img.permute(1, 2, 0).numpy() 148 | for i in range(stages): 149 | tmp = skimage.transform.resize(img, [s//2**i for s in (self.rows, self.cols)]) 150 | tmp = skimage.transform.resize(tmp, (self.rows, self.cols)) 151 | self.imgs.append(torch.from_numpy(tmp).view(-1, self.dataset.img_channels)) 152 | self.imgs.reverse() 153 | 154 | def __len__(self): 155 | return len(self.dataset) 156 | 157 | def __getitem__(self, idx): 158 | 159 | coords = self.mgrid 160 | img = self.img 161 | 162 | in_dict = {'coords': coords, 'radii': self.radii} 163 | gt_dict = {'img': img} 164 | 165 | if self.multiscale: 166 | gt_dict['img'] = self.imgs 167 | 168 | return in_dict, gt_dict 169 | 170 | 171 | def save_img(img, filename): 172 | ''' given np array, convert to image and save ''' 173 | img = Image.fromarray((255*img).astype(np.uint8)) 174 | img.save(filename) 175 | 176 | 177 | def crop_center(pil_img, crop_width, crop_height): 178 | img_width, img_height = pil_img.size 179 | return pil_img.crop(((img_width - crop_width) // 2, 180 | (img_height - crop_height) // 2, 181 | (img_width + crop_width) // 2, 182 | (img_height + crop_height) // 2)) 183 | 184 | 185 | def crop_max_square(pil_img): 186 | return crop_center(pil_img, min(pil_img.size), min(pil_img.size)) 187 | 188 | 189 | class ImageFile(Dataset): 190 | def __init__(self, filename, grayscale=False, resolution=None, 191 | root_path=None, crop_square=True, url=None): 192 | 193 | super().__init__() 194 | 195 | if not os.path.exists(filename): 196 | if url is None: 197 | raise FileNotFoundError( 198 | errno.ENOENT, os.strerror(errno.ENOENT), filename) 199 | else: 200 | print('Downloading image file...') 201 | os.makedirs(os.path.dirname(filename), exist_ok=True) 202 | urllib.request.urlretrieve(url, filename) 203 | 204 | self.img = Image.open(filename) 205 | if grayscale: 206 | self.img = self.img.convert('L') 207 | else: 208 | self.img = self.img.convert('RGB') 209 | 210 | self.img_channels = len(self.img.mode) 211 | self.resolution = self.img.size 212 | 213 | if crop_square: # preserve aspect ratio 214 | self.img = crop_max_square(self.img) 215 | 216 | if resolution is not None: 217 | self.resolution = resolution 218 | self.img = self.img.resize(resolution, Image.ANTIALIAS) 219 | 220 | self.img = np.array(self.img) 221 | self.img = self.img.astype(np.float32)/255. 
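# ImageFile holds a single image: __len__ is 1 and __getitem__ returns the full array,
# which ImageWrapper (above) pairs with sampling coordinates and multiscale targets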
222 | 223 | def __len__(self): 224 | return 1 225 | 226 | def __getitem__(self, idx): 227 | return self.img 228 | 229 | 230 | def chunk_lists_from_batch_reduce_to_raysamples_fn(model_input, meta, gt, max_chunk_size): 231 | 232 | model_in_chunked = [] 233 | for key in model_input: 234 | num_views, num_rays, num_samples_per_rays, num_dims = model_input[key].shape 235 | chunks = torch.split(model_input[key].view(-1, num_samples_per_rays, num_dims), max_chunk_size) 236 | model_in_chunked.append(chunks) 237 | 238 | list_chunked_model_input = \ 239 | [{k: v for k, v in zip(model_input.keys(), curr_chunks)} for curr_chunks in zip(*model_in_chunked)] 240 | 241 | # meta_dict 242 | list_chunked_zs = torch.split(meta['zs'].view(-1, num_samples_per_rays, 1), 243 | max_chunk_size) 244 | list_chunked_meta = [{'zs': zs} for zs in list_chunked_zs] 245 | 246 | # gt_dict 247 | gt_chunked = [] 248 | for key in gt: 249 | if isinstance(gt[key], list): 250 | # this handles lists of gt tensors (e.g., for multiscale) 251 | num_dims = gt[key][0].shape[-1] 252 | 253 | # this chunks the list elements so you have [num_tensors, num_chunks] 254 | chunks = [torch.split(x.view(-1, num_dims), max_chunk_size) for x in gt[key]] 255 | 256 | # this switches it to [num_chunks, num_tensors] 257 | chunks = [chunk for chunk in zip(*chunks)] 258 | gt_chunked.append(chunks) 259 | else: 260 | *_, num_dims = gt[key].shape 261 | chunks = torch.split(gt[key].view(-1, num_dims), max_chunk_size) 262 | gt_chunked.append(chunks) 263 | 264 | list_chunked_gt = \ 265 | [{k: v for k, v in zip(gt.keys(), curr_chunks)} for curr_chunks in zip(*gt_chunked)] 266 | 267 | return list_chunked_model_input, list_chunked_meta, list_chunked_gt 268 | 269 | 270 | class NerfBlenderDataset(torch.utils.data.Dataset): 271 | def __init__(self, basedir, mode='train', 272 | splits=['train', 'val', 'test'], 273 | select_idx=None, 274 | testskip=1, resize_to=None, final_render=False, 275 | d_rot=0, bounds=((-2, 2), (-2, 2), (0, 2)), 276 | multiscale=False, 277 | black_background=False, 278 | override_scale=None): 279 | 280 | self.mode = mode 281 | self.basedir = basedir 282 | self.resize_to = resize_to 283 | self.final_render = final_render 284 | self.bounds = bounds 285 | self.multiscale = multiscale 286 | self.select_idx = select_idx 287 | self.d_rot = d_rot 288 | 289 | metas = {} 290 | for s in splits: 291 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 292 | metas[s] = json.load(fp) 293 | 294 | # Eventually transform the inputs 295 | transform_list = [ToTensor()] 296 | if resize_to is not None: 297 | transform_list.insert(0, Resize(resize_to, 298 | interpolation=Image.BILINEAR)) 299 | 300 | def multiscale_resize(x): 301 | scale = 512 // x.size[0] 302 | return x.resize([r//scale for r in resize_to], 303 | resample=Image.BILINEAR) 304 | 305 | if multiscale and override_scale is None: 306 | # this will scale the image down appropriately 307 | # (e.g., to 1/2, 1/4, 1/8 of desired resolution) 308 | # Then the next transform will scale it back up so we use the same rays 309 | # to supervise 310 | transform_list.insert(0, Lambda(lambda x: multiscale_resize(x))) 311 | if black_background: 312 | transform_list.append(Lambda(lambda x: x[:3] * x[[-1]])) 313 | else: 314 | transform_list.append(Lambda(lambda x: x[:3] * x[[-1]] + (1 - x[[-1]]))) 315 | 316 | self.transforms = Compose(transform_list) 317 | 318 | # Gather images and poses 319 | self.all_imgs = {} 320 | self.all_poses = {} 321 | for s in splits: 322 | meta = metas[s] 323 | imgs, 
poses = self.load_images(s, meta, testskip) 324 | 325 | self.all_imgs.update({s: imgs}) 326 | self.all_poses.update({s: poses}) 327 | 328 | if self.final_render: 329 | self.poses = [torch.from_numpy(self.pose_spherical(angle, -30.0, 4.0)).float() 330 | for angle in np.linspace(-180, 180, 40 + 1)[:-1]] 331 | 332 | if override_scale is not None: 333 | assert multiscale, 'only for multiscale' 334 | if override_scale > 3: 335 | override_scale = 3 336 | H, W = self.multiscale_imgs[0][override_scale].shape[:2] 337 | self.img_shape = self.multiscale_imgs[0][override_scale].shape 338 | else: 339 | H, W = imgs[0].shape[:2] 340 | self.img_shape = imgs[0].shape 341 | 342 | # projective camera 343 | camera_angle_x = float(meta['camera_angle_x']) 344 | focal = .5 * W / np.tan(.5 * camera_angle_x) 345 | self.camera_params = {'H': H, 'W': W, 346 | 'camera_angle_x': camera_angle_x, 347 | 'focal': focal, 348 | 'near': 2.0, 349 | 'far': 6.0} 350 | 351 | def load_images(self, s, meta, testskip): 352 | imgs = [] 353 | poses = [] 354 | 355 | if s == 'train' or testskip == 0: 356 | skip = 1 357 | else: 358 | skip = testskip 359 | 360 | for frame in tqdm(meta['frames'][::skip]): 361 | if self.select_idx is not None: 362 | if re.search('[0-9]+', frame['file_path']).group(0) != self.select_idx: 363 | continue 364 | 365 | def load_image(fname): 366 | img = Image.open(fname) 367 | pose = torch.from_numpy(np.array(frame['transform_matrix'], dtype=np.float32)) 368 | 369 | img_t = self.transforms(img) 370 | imgs.append(img_t.permute(1, 2, 0)) 371 | poses.append(pose) 372 | 373 | if self.multiscale: 374 | for i in range(4): 375 | fname = os.path.join(self.basedir, frame['file_path']).replace(s, s + '_multiscale') + f'_d{i}.png' 376 | load_image(fname) 377 | 378 | else: 379 | fname = os.path.join(self.basedir, frame['file_path'] + '.png') 380 | load_image(fname) 381 | 382 | if self.multiscale: 383 | poses = poses[::4] 384 | self.multiscale_imgs = [imgs[i:i+4][::-1] for i in range(0, len(imgs), 4)] 385 | imgs = imgs[::4] 386 | 387 | return imgs, poses 388 | 389 | # adapted from https://github.com/krrish94/nerf-pytorch 390 | # derived from original NeRF repo (MIT License) 391 | def translate_by_t_along_z(self, t): 392 | tform = np.eye(4).astype(np.float32) 393 | tform[2][3] = t 394 | return tform 395 | 396 | def rotate_by_phi_along_x(self, phi): 397 | tform = np.eye(4).astype(np.float32) 398 | tform[1, 1] = tform[2, 2] = np.cos(phi) 399 | tform[1, 2] = -np.sin(phi) 400 | tform[2, 1] = -tform[1, 2] 401 | return tform 402 | 403 | def rotate_by_theta_along_y(self, theta): 404 | tform = np.eye(4).astype(np.float32) 405 | tform[0, 0] = tform[2, 2] = np.cos(theta) 406 | tform[0, 2] = -np.sin(theta) 407 | tform[2, 0] = -tform[0, 2] 408 | return tform 409 | 410 | def pose_spherical(self, theta, phi, radius): 411 | c2w = self.translate_by_t_along_z(radius) 412 | c2w = self.rotate_by_phi_along_x(phi / 180.0 * np.pi) @ c2w 413 | c2w = self.rotate_by_theta_along_y(theta / 180 * np.pi) @ c2w 414 | c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w 415 | return c2w 416 | 417 | def set_mode(self, mode): 418 | self.mode = mode 419 | 420 | def get_img_shape(self): 421 | return self.img_shape 422 | 423 | def get_camera_params(self): 424 | return self.camera_params 425 | 426 | def __len__(self): 427 | if self.final_render: 428 | return len(self.poses) 429 | else: 430 | return len(self.all_imgs[self.mode]) 431 | 432 | def __getitem__(self, item): 433 | # render out trajectory (no GT images) 434 | if 
self.final_render: 435 | return {'img': torch.zeros(4), # we have to pass something... 436 | 'pose': self.poses[item]} 437 | 438 | # otherwise, return GT images and pose 439 | else: 440 | return {'img': self.all_imgs[self.mode][item], 441 | 'pose': self.all_poses[self.mode][item]} 442 | 443 | 444 | class Implicit6DMultiviewDataWrapper(torch.utils.data.Dataset): 445 | def __init__(self, dataset, img_shape, camera_params, 446 | samples_per_ray=128, 447 | samples_per_view=32000, 448 | num_workers=4, 449 | multiscale=False, 450 | supervise_hr=False, 451 | scales=[1/8, 1/4, 1/2, 1]): 452 | 453 | self.dataset = dataset 454 | self.num_workers = num_workers 455 | self.multiscale = multiscale 456 | self.scales = scales 457 | self.supervise_hr = supervise_hr 458 | 459 | self.img_shape = img_shape 460 | self.camera_params = camera_params 461 | 462 | self.samples_per_view = samples_per_view 463 | self.default_samples_per_view = samples_per_view 464 | self.samples_per_ray = samples_per_ray 465 | 466 | self._generate_rays_normalized() 467 | self._precompute_rays() 468 | 469 | self.is_logging = False 470 | 471 | self.val_idx = 0 472 | 473 | self.num_rays = self.all_ray_orgs.view(-1, 3).shape[0] 474 | 475 | self.shuffle_rays() 476 | 477 | if multiscale: 478 | self.multiscale_imgs = dataset.multiscale_imgs 479 | 480 | # switch to size [num_scales, num_views, img_size[0], img_size[1], 3] 481 | self.multiscale_imgs = torch.stack([torch.stack(m, dim=0) 482 | for m in zip(*self.multiscale_imgs)], dim=0) 483 | 484 | def toggle_logging_sampling(self): 485 | if self.is_logging: 486 | self.samples_per_view = self.default_samples_per_view 487 | self.is_logging = False 488 | else: 489 | self.samples_per_view = self.img_shape[0] * self.img_shape[1] 490 | self.is_logging = True 491 | 492 | def _generate_rays_normalized(self): 493 | 494 | # projective camera 495 | rows = torch.arange(0, self.img_shape[0], dtype=torch.float32) 496 | cols = torch.arange(0, self.img_shape[1], dtype=torch.float32) 497 | g_rows, g_cols = torch.meshgrid(rows, cols) 498 | 499 | W = self.camera_params['W'] 500 | H = self.camera_params['H'] 501 | f = self.camera_params['focal'] 502 | 503 | self.norm_rays = torch.stack([(g_cols-.5*W + 0.5)/f, 504 | -(g_rows-.5*H + 0.5)/f, 505 | -torch.ones_like(g_rows)], 506 | dim=2).view(-1, 3).permute(1, 0) 507 | 508 | self.num_rays_per_view = self.norm_rays.shape[1] 509 | 510 | def shuffle_rays(self): 511 | self.shuffle_idxs = torch.randperm(self.num_rays) 512 | 513 | def _precompute_rays(self): 514 | img_list = [] 515 | pose_list = [] 516 | ray_orgs_list = [] 517 | ray_dirs_list = [] 518 | 519 | print('Precomputing rays...') 520 | for img_pose in tqdm(self.dataset): 521 | img = img_pose['img'] 522 | img_list.append(img) 523 | 524 | pose = img_pose['pose'] 525 | pose_list.append(pose) 526 | 527 | ray_dirs = pose[:3, :3].matmul(self.norm_rays).permute(1, 0) 528 | ray_dirs_list.append(ray_dirs) 529 | 530 | ray_orgs = pose[:3, 3].repeat((self.num_rays_per_view, 1)) 531 | ray_orgs_list.append(ray_orgs) 532 | 533 | self.all_imgs = torch.stack(img_list, dim=0) 534 | self.all_poses = torch.stack(pose_list, dim=0) 535 | self.all_ray_orgs = torch.stack(ray_orgs_list, dim=0) 536 | self.all_ray_dirs = torch.stack(ray_dirs_list, dim=0) 537 | 538 | self.hit = torch.zeros(self.all_ray_dirs.view(-1, 3).shape[0]) 539 | 540 | def __len__(self): 541 | if self.is_logging: 542 | return self.all_imgs.shape[0] 543 | else: 544 | return self.num_rays // self.samples_per_view 545 | 546 | def get_val_rays(self): 547 | img = 
self.all_imgs[self.val_idx, ...] 548 | ray_dirs = self.all_ray_dirs[self.val_idx, ...] 549 | ray_orgs = self.all_ray_orgs[self.val_idx, ...] 550 | view_samples = img 551 | 552 | if self.multiscale: 553 | img = self.multiscale_imgs[:, self.val_idx, ...] 554 | if self.supervise_hr: 555 | img = [img[-1] for _ in img] 556 | view_samples = [im for im in img] 557 | 558 | self.val_idx += 1 559 | self.val_idx %= self.all_imgs.shape[0] 560 | 561 | return view_samples, ray_orgs, ray_dirs 562 | 563 | def get_rays(self, idx): 564 | idxs = self.shuffle_idxs[self.samples_per_view * idx:self.samples_per_view * (idx+1)] 565 | ray_dirs = self.all_ray_dirs.view(-1, 3)[idxs, ...] 566 | ray_orgs = self.all_ray_orgs.view(-1, 3)[idxs, ...] 567 | 568 | if self.multiscale: 569 | view_samples = [mimg.view(-1, 3)[idxs] for mimg in self.multiscale_imgs] 570 | 571 | if self.supervise_hr: 572 | view_samples = [view_samples[-1] for _ in view_samples] 573 | else: 574 | img = self.all_imgs.view(-1, 3)[idxs, ...] 575 | view_samples = img.reshape(-1, 3) 576 | 577 | self.hit[idxs] += 1 578 | 579 | return view_samples, ray_orgs, ray_dirs 580 | 581 | def __getitem__(self, idx): 582 | 583 | if self.is_logging: 584 | view_samples, ray_orgs, ray_dirs = self.get_val_rays() 585 | else: 586 | view_samples, ray_orgs, ray_dirs = self.get_rays(idx) 587 | 588 | # Transform coordinate systems 589 | camera_params = self.dataset.get_camera_params() 590 | 591 | ray_dirs = ray_dirs[:, None, :] 592 | ray_orgs = ray_orgs[:, None, :] 593 | 594 | t_vals = torch.linspace(0.0, 1.0, self.samples_per_ray) 595 | t_vals = camera_params['near'] * (1.0 - t_vals) + camera_params['far'] * t_vals 596 | t_vals = t_vals[None, :].repeat(self.samples_per_view, 1) 597 | 598 | mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1]) 599 | upper = torch.cat((mids, t_vals[..., -1:]), dim=-1) 600 | lower = torch.cat((t_vals[..., :1], mids), dim=-1) 601 | 602 | # Stratified samples in those intervals. 
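# lower/upper bound each depth interval along the ray; drawing one uniform offset
# per interval (below) jitters the sample depths so the [near, far] range is
# covered over training rather than always querying a fixed set of depths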
603 | t_rand = torch.rand(t_vals.shape) 604 | t_vals = lower + (upper - lower) * t_rand 605 | 606 | ray_samples = ray_orgs + ray_dirs * t_vals[..., None] 607 | 608 | t_intervals = t_vals[..., 1:] - t_vals[..., :-1] 609 | t_intervals = torch.cat((t_intervals, 1e10*torch.ones_like(t_intervals[:, 0:1])), dim=-1) 610 | t_intervals = (t_intervals * ray_dirs.norm(p=2, dim=-1))[..., None] 611 | 612 | # Compute distance samples from orgs 613 | dist_samples_to_org = torch.sqrt(torch.sum((ray_samples-ray_orgs)**2, dim=-1, keepdim=True)) 614 | 615 | # broadcast tensors 616 | view_dirs = ray_dirs / ray_dirs.norm(p=2, dim=-1, keepdim=True).repeat(1, self.samples_per_ray, 1) 617 | 618 | in_dict = {'ray_samples': ray_samples, 619 | 'ray_orientations': view_dirs, 620 | 'ray_origins': ray_orgs, 621 | 't_intervals': t_intervals, 622 | 't': t_vals[..., None], 623 | 'ray_directions': ray_dirs} 624 | meta_dict = {'zs': dist_samples_to_org} 625 | 626 | gt_dict = {'pixel_samples': view_samples} 627 | 628 | return in_dict, meta_dict, gt_dict 629 | 630 | 631 | class MeshSDF(Dataset): 632 | ''' convert point cloud to SDF ''' 633 | 634 | def __init__(self, pointcloud_path, num_samples=30**3, 635 | coarse_scale=1e-1, fine_scale=1e-3): 636 | super().__init__() 637 | self.num_samples = num_samples 638 | self.pointcloud_path = pointcloud_path 639 | self.coarse_scale = coarse_scale 640 | self.fine_scale = fine_scale 641 | 642 | self.load_mesh(pointcloud_path) 643 | 644 | def __len__(self): 645 | return 10000 # arbitrary 646 | 647 | def load_mesh(self, pointcloud_path): 648 | pointcloud = np.genfromtxt(pointcloud_path) 649 | self.v = pointcloud[:, :3] 650 | self.n = pointcloud[:, 3:] 651 | 652 | n_norm = (np.linalg.norm(self.n, axis=-1)[:, None]) 653 | n_norm[n_norm == 0] = 1. 654 | self.n = self.n / n_norm 655 | self.v = self.normalize(self.v) 656 | self.kd_tree = KDTree(self.v) 657 | print('loaded pc') 658 | 659 | def normalize(self, coords): 660 | coords -= np.mean(coords, axis=0, keepdims=True) 661 | coord_max = np.amax(coords) 662 | coord_min = np.amin(coords) 663 | coords = (coords - coord_min) / (coord_max - coord_min) * 0.9 664 | coords -= 0.45 665 | return coords 666 | 667 | def sample_surface(self): 668 | idx = np.random.randint(0, self.v.shape[0], self.num_samples) 669 | points = self.v[idx] 670 | points[::2] += np.random.laplace(scale=self.coarse_scale, size=(points.shape[0]//2, points.shape[-1])) 671 | points[1::2] += np.random.laplace(scale=self.fine_scale, size=(points.shape[0]//2, points.shape[-1])) 672 | 673 | # wrap around any points that are sampled out of bounds 674 | points[points > 0.5] -= 1 675 | points[points < -0.5] += 1 676 | 677 | # use KDTree to get distance to surface and estimate the normal 678 | sdf, idx = self.kd_tree.query(points, k=3) 679 | avg_normal = np.mean(self.n[idx], axis=1) 680 | sdf = np.sum((points - self.v[idx][:, 0]) * avg_normal, axis=-1) 681 | sdf = sdf[..., None] 682 | 683 | return points, sdf 684 | 685 | def __getitem__(self, idx): 686 | coords, sdf = self.sample_surface() 687 | 688 | return {'coords': torch.from_numpy(coords).float()}, \ 689 | {'sdf': torch.from_numpy(sdf).float()} 690 | -------------------------------------------------------------------------------- /download_datasets.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | import zipfile 3 | import subprocess 4 | import urllib.request 5 | import os 6 | 7 | # Make folder to save data in. 
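# Everything is saved under ./data; the configs in experiments/config refer to this
# directory as ../data (relative to the experiments/ folder).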
8 | os.makedirs('./data', exist_ok=True) 9 | 10 | # nerf 11 | print('Downloading blender dataset') 12 | gdown.download("https://drive.google.com/uc?id=18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG", './data/nerf_synthetic.zip') 13 | 14 | print('Extracting ...') 15 | with zipfile.ZipFile('./data/nerf_synthetic.zip', 'r') as f: 16 | f.extractall('./data/') 17 | 18 | print('Generating multiscale nerf dataset') 19 | subprocess.run(['python', 'convert_blender_data.py', '--blenderdir', './data/nerf_synthetic']) 20 | 21 | print('Downloading SDF datasets') 22 | gdown.download("https://drive.google.com/uc?id=1xBo6OCGmyWi0qD74EZW4lc45Gs4HXjWw", './data/gt_armadillo.xyz') 23 | gdown.download("https://drive.google.com/uc?id=1Pm3WHUvJiMJEKUnnhMjB6mUAnR9qhnxm", './data/gt_dragon.xyz') 24 | gdown.download("https://drive.google.com/uc?id=1wE24AZtXS8jbIIc-amYeEUtlxN8dFYCo", './data/gt_lucy.xyz') 25 | gdown.download("https://drive.google.com/uc?id=1OVw0JNA-NZtDXVmkf57erqwqDjqmF5Mc", './data/gt_thai.xyz') 26 | 27 | print('Downloading image dataset') 28 | urllib.request.urlretrieve('http://www.cs.albany.edu/~xypan/research/img/Kodak/kodim19.png', './data/lighthouse.png') 29 | 30 | print('Done!') 31 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bacon 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - numpy=1.21.2 7 | - python=3.8.12 8 | - pytorch=1.10.1 9 | - tensorboard=2.6.0 10 | - tensorboard-data-server=0.6.0 11 | - tensorboard-plugin-wit=1.6.0 12 | - torchvision=0.11.2 13 | - pykdtree=1.3.4 14 | - pip 15 | - pip: 16 | - argparse==1.4.0 17 | - configargparse==1.5.3 18 | - gdown==4.2.0 19 | - matplotlib==3.4.3 20 | - pymcubes==0.1.2 21 | - scikit-image==0.18.3 22 | - scipy==1.7.1 23 | - tqdm==4.62.3 24 | - trimesh==3.9.35 25 | - setuptools==59.5.0 26 | -------------------------------------------------------------------------------- /experiments/config/1d/bacon_freq1.ini: -------------------------------------------------------------------------------- 1 | experiment_name = bacon1_1d 2 | lr = 0.01 3 | gpu = 0 4 | model = mfn 5 | max_freq = 16.5 6 | batch_size = 1 7 | hidden_features = 128 8 | hidden_layers = 4 9 | num_steps = 1001 10 | activation = sine 11 | w0 = 10 12 | pe_scale = 3 13 | steps_til_ckpt = 100 14 | steps_til_summary = 100 15 | logging_root = ../logs 16 | -------------------------------------------------------------------------------- /experiments/config/1d/bacon_freq2.ini: -------------------------------------------------------------------------------- 1 | experiment_name = bacon2_1d 2 | lr = 0.01 3 | gpu = 0 4 | model = mfn 5 | max_freq = 8.0 6 | batch_size = 1 7 | hidden_features = 128 8 | hidden_layers = 4 9 | num_steps = 1001 10 | activation = sine 11 | w0 = 10 12 | pe_scale = 3 13 | steps_til_ckpt = 100 14 | steps_til_summary = 100 15 | logging_root = ../logs 16 | -------------------------------------------------------------------------------- /experiments/config/1d/ff.ini: -------------------------------------------------------------------------------- 1 | hidden_layers = 3 2 | experiment_name = ff_1d 3 | lr = 0.001 4 | gpu = 0 5 | model = mlp 6 | activation = relu 7 | w0 = 30.0 8 | pe_scale = 4.0 9 | max_freq = 16.0 10 | batch_size = 1 11 | hidden_features = 128 12 | num_steps = 1001 13 | steps_til_ckpt = 100 14 | steps_til_summary = 100 15 | logging_root = ../logs 16 | 
-------------------------------------------------------------------------------- /experiments/config/1d/siren.ini: -------------------------------------------------------------------------------- 1 | experiment_name = siren_1d 2 | lr = 0.0001 3 | gpu = 0 4 | model = mlp 5 | activation = sine 6 | w0 = 30.0 7 | pe_scale = 4.0 8 | max_freq = 16.0 9 | batch_size = 1 10 | hidden_features = 128 11 | hidden_layers = 4 12 | num_steps = 1001 13 | steps_til_ckpt = 100 14 | steps_til_summary = 100 15 | logging_root = ../logs 16 | -------------------------------------------------------------------------------- /experiments/config/img/bacon.ini: -------------------------------------------------------------------------------- 1 | config = ./image_bacon_config.ini 2 | experiment_name = bacon_img 3 | gpu = 1 4 | hidden_features = 256 5 | hidden_layers = 4 6 | res = 256 7 | centered = true 8 | model = mfn 9 | lr = 5e-3 10 | multiscale = true 11 | steps_til_summary = 100 12 | batch_size = 1 13 | activation = sine 14 | w0 = 10 15 | pe_scale = 3 16 | num_steps = 5001 17 | steps_til_ckpt = 100 18 | logging_root = ../logs 19 | -------------------------------------------------------------------------------- /experiments/config/img/ff.ini: -------------------------------------------------------------------------------- 1 | config = ./image_ff_config.ini 2 | experiment_name = ff_img 3 | gpu = 1 4 | hidden_features = 256 5 | hidden_layers = 3 6 | res = 256 7 | centered = true 8 | model = mlp 9 | lr = 0.001 10 | multiscale = false 11 | steps_til_summary = 100 12 | batch_size = 1 13 | activation = relu 14 | w0 = 30 15 | pe_scale = 6 16 | num_steps = 5001 17 | steps_til_ckpt = 100 18 | logging_root = ../logs 19 | -------------------------------------------------------------------------------- /experiments/config/img/mip.ini: -------------------------------------------------------------------------------- 1 | config = ./image_mip_config.ini 2 | experiment_name = mip_img 3 | gpu = 1 4 | hidden_features = 256 5 | hidden_layers = 4 6 | res = 256 7 | centered = true 8 | model = mlp 9 | lr = 0.001 10 | multiscale = true 11 | use_resized = true 12 | steps_til_summary = 100 13 | batch_size = 1 14 | activation = relu 15 | ipe = true 16 | w0 = 30 17 | pe_scale = 10 18 | num_steps = 5001 19 | steps_til_ckpt = 100 20 | logging_root = ../logs 21 | -------------------------------------------------------------------------------- /experiments/config/img/siren.ini: -------------------------------------------------------------------------------- 1 | config = ./image_siren_config.ini 2 | experiment_name = siren_img 3 | gpu = 1 4 | hidden_features = 256 5 | hidden_layers = 4 6 | res = 256 7 | centered = true 8 | model = mlp 9 | lr = 0.0001 10 | multiscale = false 11 | steps_til_summary = 100 12 | batch_size = 1 13 | activation = sine 14 | w0 = 30 15 | pe_scale = 6 16 | num_steps = 5001 17 | steps_til_ckpt = 100 18 | logging_root = ../logs 19 | -------------------------------------------------------------------------------- /experiments/config/nerf/bacon.ini: -------------------------------------------------------------------------------- 1 | experiment_name = bacon_lego 2 | gpu = 0 3 | num_workers = 0 4 | chunk_size_train = 4096 5 | chunk_size_eval = 4096 6 | hidden_layers = 8 7 | img_size = 512 8 | dataset_path = ../data/nerf_synthetic/lego 9 | steps_til_summary = 2000 10 | lr = 0.001 11 | hidden_features = 256 12 | multiscale = true 13 | use_resized = true 14 | reuse_filters = true 15 | samples_per_ray = 128 16 | samples_per_view = 
4096 17 | logging_root = ../logs 18 | num_steps = 1000000 19 | steps_til_ckpt = 50000 20 | batch_size = 1 21 | -------------------------------------------------------------------------------- /experiments/config/nerf/bacon_lr.ini: -------------------------------------------------------------------------------- 1 | experiment_name = bacon_lego_lowres 2 | gpu = 0 3 | num_workers = 0 4 | chunk_size_train = 4096 5 | chunk_size_eval = 4096 6 | hidden_layers = 8 7 | img_size = 64 8 | dataset_path = ../data/nerf_synthetic/lego 9 | steps_til_summary = 1000 10 | lr = 0.001 11 | hidden_features = 256 12 | multiscale = true 13 | use_resized = true 14 | reuse_filters = true 15 | samples_per_ray = 128 16 | samples_per_view = 1024 17 | logging_root = ../logs 18 | num_steps = 1000000 19 | steps_til_ckpt = 50000 20 | batch_size = 1 21 | -------------------------------------------------------------------------------- /experiments/config/nerf/bacon_semisupervise.ini: -------------------------------------------------------------------------------- 1 | experiment_name = bacon_lego_semisupervise 2 | gpu = 0 3 | supervise_hr = true 4 | chunk_size_train = 4096 5 | chunk_size_eval = 4096 6 | hidden_layers = 8 7 | img_size = 512 8 | dataset_path = ../data/nerf_synthetic/lego 9 | steps_til_summary = 2000 10 | num_workers = 0 11 | lr = 0.001 12 | hidden_features = 256 13 | multiscale = true 14 | use_resized = true 15 | reuse_filters = true 16 | samples_per_ray = 128 17 | samples_per_view = 4096 18 | logging_root = ../logs 19 | num_steps = 1000000 20 | steps_til_ckpt = 50000 21 | batch_size = 1 22 | -------------------------------------------------------------------------------- /experiments/config/sdf/bacon_armadillo.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_armadillo.xyz 2 | experiment_name = bacon_armadillo 3 | num_pts_on = 10000 4 | num_steps = 200000 5 | gpu = 0 6 | coarse_scale = 0.1 7 | fine_scale = 0.001 8 | max_freq = 384 9 | logging_root = ../logs 10 | steps_til_summary = 1000 11 | lr = 0.01 12 | model_type = mfn 13 | multiscale = true 14 | hidden_size = 256 15 | steps_til_ckpt = 50000 16 | ckpt_step = 0 17 | hidden_layers = 8 18 | w0 = 30 19 | num_workers = 0 20 | coarse_weight = 1e-2 21 | -------------------------------------------------------------------------------- /experiments/config/sdf/bacon_dragon.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_dragon.xyz 2 | experiment_name = bacon_dragon 3 | num_pts_on = 10000 4 | num_steps = 200000 5 | gpu = 0 6 | coarse_scale = 0.1 7 | fine_scale = 0.001 8 | max_freq = 384 9 | logging_root = ../logs 10 | steps_til_summary = 1000 11 | lr = 0.01 12 | model_type = mfn 13 | multiscale = true 14 | hidden_size = 256 15 | steps_til_ckpt = 50000 16 | ckpt_step = 0 17 | hidden_layers = 8 18 | w0 = 30 19 | num_workers = 0 20 | coarse_weight = 1e-2 21 | -------------------------------------------------------------------------------- /experiments/config/sdf/bacon_lucy.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_lucy.xyz 2 | experiment_name = bacon_lucy 3 | num_pts_on = 10000 4 | num_steps = 200000 5 | gpu = 0 6 | coarse_scale = 0.1 7 | fine_scale = 0.001 8 | max_freq = 512 9 | logging_root = ../logs 10 | steps_til_summary = 1000 11 | lr = 0.01 12 | model_type = mfn 13 | multiscale = true 14 | hidden_size = 256 15 | steps_til_ckpt = 50000 16 | 
ckpt_step = 0 17 | hidden_layers = 8 18 | w0 = 30 19 | num_workers = 0 20 | coarse_weight = 1e-2 21 | -------------------------------------------------------------------------------- /experiments/config/sdf/bacon_thai.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_thai.xyz 2 | experiment_name = bacon_thai 3 | num_pts_on = 10000 4 | num_steps = 200000 5 | gpu = 0 6 | coarse_scale = 0.1 7 | fine_scale = 0.001 8 | max_freq = 512 9 | logging_root = ../logs 10 | steps_til_summary = 1000 11 | lr = 0.01 12 | model_type = mfn 13 | multiscale = true 14 | hidden_size = 256 15 | steps_til_ckpt = 50000 16 | ckpt_step = 0 17 | hidden_layers = 8 18 | w0 = 30 19 | num_workers = 0 20 | coarse_weight = 1e-2 21 | -------------------------------------------------------------------------------- /experiments/config/sdf/ff_armadillo.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_armadillo.xyz 2 | experiment_name = ff_armadillo 3 | num_pts_on = 10000 4 | lr = 0.001 5 | num_steps = 200000 6 | gpu = 0 7 | model_type = ff 8 | coarse_scale = 0.1 9 | fine_scale = 0.001 10 | hidden_layers = 7 11 | max_freq = 384 12 | pe_scale = 8.0 13 | logging_root = ../logs 14 | steps_til_summary = 1000 15 | multiscale = false 16 | hidden_size = 256 17 | steps_til_ckpt = 50000 18 | ckpt_step = 0 19 | w0 = 30 20 | num_workers = 0 21 | coarse_weight = 1e-2 22 | -------------------------------------------------------------------------------- /experiments/config/sdf/ff_dragon.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_dragon.xyz 2 | experiment_name = ff_dragon 3 | num_pts_on = 10000 4 | lr = 0.001 5 | num_steps = 200000 6 | gpu = 0 7 | model_type = ff 8 | coarse_scale = 0.1 9 | fine_scale = 0.001 10 | hidden_layers = 7 11 | max_freq = 384 12 | pe_scale = 8.0 13 | logging_root = ../logs 14 | steps_til_summary = 1000 15 | multiscale = false 16 | hidden_size = 256 17 | steps_til_ckpt = 50000 18 | ckpt_step = 0 19 | w0 = 30 20 | num_workers = 0 21 | coarse_weight = 1e-2 22 | -------------------------------------------------------------------------------- /experiments/config/sdf/ff_lucy.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_lucy.xyz 2 | experiment_name = ff_lucy 3 | num_pts_on = 10000 4 | lr = 0.001 5 | num_steps = 200000 6 | gpu = 0 7 | model_type = ff 8 | coarse_scale = 0.1 9 | fine_scale = 0.001 10 | hidden_layers = 7 11 | max_freq = 384 12 | pe_scale = 8.0 13 | logging_root = ../logs 14 | steps_til_summary = 1000 15 | multiscale = false 16 | hidden_size = 256 17 | steps_til_ckpt = 50000 18 | ckpt_step = 0 19 | w0 = 30 20 | num_workers = 0 21 | coarse_weight = 1e-2 22 | -------------------------------------------------------------------------------- /experiments/config/sdf/ff_thai.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_thai.xyz 2 | experiment_name = ff_thai 3 | num_pts_on = 10000 4 | lr = 0.001 5 | num_steps = 200000 6 | gpu = 0 7 | model_type = ff 8 | coarse_scale = 0.1 9 | fine_scale = 0.001 10 | hidden_layers = 7 11 | max_freq = 384 12 | pe_scale = 8.0 13 | logging_root = ../logs 14 | steps_til_summary = 1000 15 | multiscale = false 16 | hidden_size = 256 17 | steps_til_ckpt = 50000 18 | ckpt_step = 0 19 | w0 = 30 20 | num_workers = 0 21 | coarse_weight = 
1e-2 22 | -------------------------------------------------------------------------------- /experiments/config/sdf/siren_armadillo.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_armadillo.xyz 2 | experiment_name = siren_armadillo 3 | num_pts_on = 10000 4 | lr = 0.0001 5 | num_steps = 200000 6 | gpu = 0 7 | model_type = siren 8 | coarse_scale = 0.1 9 | fine_scale = 0.001 10 | max_freq = 384 11 | logging_root = ../logs 12 | steps_til_summary = 1000 13 | multiscale = false 14 | hidden_size = 256 15 | steps_til_ckpt = 50000 16 | ckpt_step = 0 17 | hidden_layers = 8 18 | w0 = 30 19 | num_workers = 0 20 | coarse_weight = 1e-2 21 | pe_scale = 5 22 | -------------------------------------------------------------------------------- /experiments/config/sdf/siren_dragon.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_dragon.xyz 2 | experiment_name = siren_dragon 3 | num_pts_on = 10000 4 | lr = 0.0001 5 | num_steps = 200000 6 | gpu = 0 7 | model_type = siren 8 | coarse_scale = 0.1 9 | fine_scale = 0.001 10 | max_freq = 384 11 | logging_root = ../logs 12 | steps_til_summary = 1000 13 | multiscale = false 14 | hidden_size = 256 15 | steps_til_ckpt = 50000 16 | ckpt_step = 0 17 | hidden_layers = 8 18 | w0 = 30 19 | num_workers = 0 20 | coarse_weight = 1e-2 21 | pe_scale = 5 22 | -------------------------------------------------------------------------------- /experiments/config/sdf/siren_lucy.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_lucy.xyz 2 | experiment_name = siren_lucy 3 | num_pts_on = 10000 4 | lr = 0.0001 5 | num_steps = 200000 6 | gpu = 0 7 | model_type = siren 8 | coarse_scale = 0.1 9 | fine_scale = 0.001 10 | max_freq = 384 11 | logging_root = ../logs 12 | steps_til_summary = 1000 13 | multiscale = false 14 | hidden_size = 256 15 | steps_til_ckpt = 50000 16 | ckpt_step = 0 17 | hidden_layers = 8 18 | w0 = 30 19 | num_workers = 0 20 | coarse_weight = 1e-2 21 | pe_scale = 5 22 | -------------------------------------------------------------------------------- /experiments/config/sdf/siren_thai.ini: -------------------------------------------------------------------------------- 1 | point_cloud_path = ../data/gt_thai.xyz 2 | experiment_name = siren_thai 3 | num_pts_on = 10000 4 | lr = 0.0001 5 | num_steps = 200000 6 | gpu = 0 7 | model_type = siren 8 | coarse_scale = 0.1 9 | fine_scale = 0.001 10 | max_freq = 384 11 | logging_root = ../logs 12 | steps_til_summary = 1000 13 | multiscale = false 14 | hidden_size = 256 15 | steps_til_ckpt = 50000 16 | ckpt_step = 0 17 | hidden_layers = 8 18 | w0 = 30 19 | num_workers = 0 20 | coarse_weight = 1e-2 21 | pe_scale = 5 22 | -------------------------------------------------------------------------------- /experiments/figure_setup.py: -------------------------------------------------------------------------------- 1 | import math 2 | import matplotlib.pyplot as plt 3 | import os 4 | import subprocess 5 | import tempfile 6 | 7 | 8 | def get_fig_size(fig_width_cm, fig_height_cm=None): 9 | """Convert dimensions in centimeters to inches. 10 | If no height is given, it is computed using the golden ratio. 
11 | """ 12 | if not fig_height_cm: 13 | golden_ratio = (1 + math.sqrt(5))/2 14 | fig_height_cm = fig_width_cm / golden_ratio 15 | 16 | size_cm = (fig_width_cm, fig_height_cm) 17 | return map(lambda x: x/2.54, size_cm) 18 | 19 | 20 | """ 21 | The following functions can be used by scripts to get the sizes of 22 | the various elements of the figures. 23 | """ 24 | 25 | 26 | def label_size(): 27 | """Size of axis labels 28 | """ 29 | return 8 30 | 31 | 32 | def font_size(): 33 | """Size of all texts shown in plots 34 | """ 35 | return 8 36 | 37 | 38 | def ticks_size(): 39 | """Size of axes' ticks 40 | """ 41 | return 6 42 | 43 | 44 | def axis_lw(): 45 | """Line width of the axes 46 | """ 47 | return 0.6 48 | 49 | 50 | def plot_lw(): 51 | """Line width of the plotted curves 52 | """ 53 | return 1.0 54 | 55 | 56 | def figure_setup(): 57 | """Set all the sizes to the correct values and use 58 | tex fonts for all texts. 59 | """ 60 | 61 | params = {'text.usetex': False, 62 | 'figure.dpi': 150, 63 | 'font.size': font_size(), 64 | 'font.sans-serif': ['helvetica', 'Arial'], 65 | 'font.serif': ['Times', 'NimbusRomNo9L-Reg', 'TeX Gyre Termes', 'Times New Roman'], 66 | 'font.monospace': [], 67 | 'lines.linewidth': plot_lw(), 68 | 'axes.labelsize': label_size(), 69 | 'axes.titlesize': font_size(), 70 | 'axes.linewidth': axis_lw(), 71 | 'legend.fontsize': font_size(), 72 | 'xtick.labelsize': ticks_size(), 73 | 'ytick.labelsize': ticks_size(), 74 | 'font.family': 'sans-serif', 75 | 'xtick.bottom': False, 76 | 'xtick.top': False, 77 | 'ytick.left': False, 78 | 'ytick.right': False, 79 | 'xtick.major.pad': -1, 80 | 'ytick.major.pad': -2, 81 | 'xtick.minor.visible': False, 82 | 'ytick.minor.visible': False, 83 | 'xtick.labelsize': 8, 84 | 'ytick.labelsize': 8, 85 | 'axes.labelpad': 0, 86 | 'axes.titlepad': 3, 87 | 'axes.unicode_minus': False, 88 | 'pdf.fonttype': 42, 89 | 'ps.fonttype': 42} 90 | #'figure.constrained_layout.use': True} 91 | plt.rcParams.update(params) 92 | 93 | 94 | def save_fig(fig, file_name, fmt=None, dpi=150, tight=False): 95 | """Save a Matplotlib figure as EPS/PNG/PDF to the given path and trim it. 
96 | """ 97 | 98 | if not fmt: 99 | fmt = file_name.strip().split('.')[-1] 100 | 101 | if fmt not in ['eps', 'png', 'pdf']: 102 | raise ValueError('unsupported format: %s' % (fmt,)) 103 | 104 | extension = '.%s' % (fmt,) 105 | if not file_name.endswith(extension): 106 | file_name += extension 107 | 108 | file_name = os.path.abspath(file_name) 109 | with tempfile.NamedTemporaryFile() as tmp_file: 110 | tmp_name = tmp_file.name + extension 111 | 112 | # save figure 113 | if tight: 114 | #fig.savefig(tmp_name, dpi=dpi, bbox_inches='tight') 115 | fig.savefig(file_name, dpi=dpi, bbox_inches='tight', pad_inches=0) 116 | else: 117 | fig.savefig(file_name, dpi=dpi) 118 | 119 | # trim it 120 | if fmt == 'eps': 121 | subprocess.call('epstool --bbox --copy %s %s' % 122 | (tmp_name, file_name), shell=True) 123 | elif fmt == 'png': 124 | subprocess.call('convert %s -trim %s' % 125 | (tmp_name, file_name), shell=True) 126 | #elif fmt == 'pdf': 127 | # subprocess.call('pdfcrop %s %s' % (tmp_name, file_name), shell=True) 128 | -------------------------------------------------------------------------------- /experiments/plot_activation_distributions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | 6 | import torch 7 | import modules 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import torch.nn as nn 11 | import warnings 12 | import figure_setup 13 | 14 | 15 | def make_square_axes(ax, scale=1): 16 | """Make an axes square in screen units. 17 | Should be called after plotting. 18 | """ 19 | ax.set_aspect(scale / ax.get_data_ratio()) 20 | 21 | 22 | # original MFN implementation 23 | class MFNBase(nn.Module): 24 | """ 25 | Multiplicative filter network base class. 26 | Expects the child class to define the 'filters' attribute, which should be 27 | a nn.ModuleList of n_layers+1 filters with output equal to hidden_size. 28 | """ 29 | 30 | def __init__( 31 | self, hidden_size, out_size, n_layers, weight_scale, bias=True, output_act=False 32 | ): 33 | super().__init__() 34 | 35 | self.linear = nn.ModuleList( 36 | [nn.Linear(hidden_size, hidden_size, bias) for _ in range(n_layers)] 37 | ) 38 | self.output_linear = nn.Linear(hidden_size, out_size) 39 | self.output_act = output_act 40 | 41 | for lin in self.linear: 42 | lin.weight.data.uniform_( 43 | -np.sqrt(weight_scale / hidden_size), 44 | np.sqrt(weight_scale / hidden_size), 45 | ) 46 | 47 | return 48 | 49 | def forward(self, x): 50 | out = self.filters[0](x) 51 | for i in range(1, len(self.filters)): 52 | out = self.filters[i](x) * self.linear[i - 1](out) 53 | out = self.output_linear(out) 54 | 55 | if self.output_act: 56 | out = torch.sin(out) 57 | 58 | return out 59 | 60 | 61 | class FourierLayer(nn.Module): 62 | """ 63 | Sine filter as used in FourierNet. 
64 | """ 65 | 66 | def __init__(self, in_features, out_features, weight_scale): 67 | super().__init__() 68 | self.linear = nn.Linear(in_features, out_features) 69 | self.linear.weight.data *= weight_scale # gamma 70 | self.linear.bias.data.uniform_(-np.pi, np.pi) 71 | return 72 | 73 | def forward(self, x): 74 | return torch.sin(self.linear(x)) 75 | 76 | 77 | class FourierNet(MFNBase): 78 | def __init__( 79 | self, 80 | in_size, 81 | hidden_size, 82 | out_size, 83 | n_layers=3, 84 | input_scale=256.0, 85 | weight_scale=1.0, 86 | bias=True, 87 | output_act=False, 88 | ): 89 | super().__init__( 90 | hidden_size, out_size, n_layers, weight_scale, bias, output_act 91 | ) 92 | self.filters = nn.ModuleList( 93 | [ 94 | FourierLayer(in_size, hidden_size, input_scale / np.sqrt(n_layers + 1)) 95 | for _ in range(n_layers + 1) 96 | ] 97 | ) 98 | 99 | 100 | def plot_initialization(use_original=False, plot_fit=True): 101 | 102 | f = 256 103 | nl = 8 104 | 105 | if use_original: 106 | model = FourierNet(1, 1024, 1, nl) 107 | else: 108 | model = modules.BACON(1, 1024, 1, nl, frequency=(f, f)) 109 | 110 | coords = torch.linspace(-0.5, 0.5, 1000)[:, None] 111 | 112 | activations = [] 113 | activations.append(coords.mean(-1)) 114 | out = model.filters[0].linear(coords) 115 | activations.append(out.flatten()) 116 | out = model.filters[0](coords) 117 | activations.append(out.flatten()) 118 | for i in range(1, len(model.filters)): 119 | activations.append(model.linear[i-1](out).flatten()) 120 | out = model.filters[i](coords) * model.linear[i - 1](out) 121 | activations.append(out.flatten()) 122 | 123 | figure_setup.figure_setup() 124 | fig = plt.figure(figsize=(6.9, 5)) 125 | gs = fig.add_gridspec(5, 4, wspace=0.25, hspace=0.25) 126 | gs.update(left=0.05, right=1.0, top=0.95, bottom=0.05) 127 | 128 | for idx, act in enumerate(activations): 129 | 130 | if idx > 2: 131 | # plt.subplot(2, len(activations)//2 + 1, idx+2) 132 | fig.add_subplot(gs[(idx+1)//4, (idx+1) % 4]) 133 | else: 134 | fig.add_subplot(gs[0, idx]) 135 | 136 | # plt.subplot(2, len(activations)//2 + 1, idx+1) 137 | plt.hist(act.detach().cpu().numpy(), 50, density=True) 138 | 139 | if idx == 0: 140 | plt.title('Input') 141 | plt.xlim(-0.5, 0.5) 142 | 143 | if idx == 1: 144 | x = np.linspace(-120, 120, 1000) 145 | 146 | if plot_fit: 147 | plt.plot(x, 2/(2*np.pi*f/(nl+1)) * np.log(np.pi*f/(nl+1)/(np.minimum(abs(2*x), np.pi*f/(nl+1)))), 'r') 148 | plt.xlim(-120, 120) 149 | plt.title('Before Sine') 150 | 151 | if idx == 2: 152 | x = np.linspace(-1, 1, 1000) 153 | 154 | if plot_fit: 155 | with warnings.catch_warnings(): 156 | warnings.simplefilter('ignore') 157 | plt.plot(x, 1/np.pi * 1 / (np.sqrt(1 - x**2)), 'r') 158 | plt.xlim(-1, 1) 159 | plt.title('After Sine') 160 | 161 | if idx % 2 == 1 and idx > 2: 162 | x = np.linspace(-4, 4, 1000) 163 | if plot_fit: 164 | plt.plot(x, 1 / np.sqrt(2*np.pi) * np.exp(-x**2/2), 'r') 165 | plt.xlim(-4, 4) 166 | plt.title('After Linear') 167 | 168 | if idx % 2 == 0 and idx > 2: 169 | # this is the product of a standard normal and 170 | # an arcsine distributed RV, which is an RV 171 | # that obeys the product rule 172 | # https://en.wikipedia.org/wiki/Distribution_of_the_product_of_two_random_variables 173 | # and its variance will be 1/2. 
Then, we can use the initialization scheme of siren (Thm 1.8) 174 | 175 | zs = np.linspace(-4, 4, 1000) 176 | x = np.linspace(-4, 4, 5000) 177 | out = np.zeros_like(zs) 178 | dx = x[1] - x[0] 179 | for idx, z in enumerate(zs): 180 | zdx2 = (z/x)**2 181 | with warnings.catch_warnings(): 182 | warnings.simplefilter('ignore') 183 | tmp = 1 / np.sqrt(2*np.pi) * np.exp(-x**2/2) * 1/np.pi * 1/np.sqrt(1 - zdx2) * 1/abs(x) * dx 184 | tmp[np.isnan(tmp)] = 0 185 | out[idx] = tmp.sum() 186 | 187 | if plot_fit: 188 | plt.plot(zs, out, 'r') 189 | plt.title('After Product') 190 | plt.xlim(-4, 4) 191 | 192 | make_square_axes(plt.gca(), 0.5) 193 | plt.show() 194 | 195 | 196 | if __name__ == '__main__': 197 | plot_initialization(use_original=False, plot_fit=True) 198 | plot_initialization(use_original=True, plot_fit=False) 199 | -------------------------------------------------------------------------------- /experiments/render_nerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import utils 6 | import dataio 7 | import modules 8 | import forward_models 9 | import training 10 | import torch 11 | import numpy as np 12 | import configargparse 13 | import dataclasses 14 | from dataclasses import dataclass 15 | from tqdm import tqdm 16 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 17 | import skimage.io 18 | from functools import partial 19 | 20 | 21 | torch.backends.cudnn.benchmark = True 22 | 23 | ssim_fn = partial(structural_similarity, data_range=1, 24 | gaussian_weights=True, sigma=1.5, 25 | use_sample_covariance=False, 26 | multichannel=True, 27 | channel_axis=-1) 28 | 29 | 30 | @dataclass 31 | class Options: 32 | config: str 33 | experiment_name: str 34 | logging_root: str 35 | dataset_path: str 36 | num_epochs: int 37 | epochs_til_ckpt: int 38 | steps_til_summary: int 39 | gpu: int 40 | img_size: int 41 | chunk_size_train: int 42 | chunk_size_eval: int 43 | num_workers: int 44 | lr: float 45 | batch_size: int 46 | hidden_features: int 47 | hidden_layers: int 48 | model: str 49 | activation: str 50 | multiscale: bool 51 | single_network: bool 52 | use_resized: bool 53 | reuse_filters: bool 54 | samples_per_ray: int 55 | samples_per_view: int 56 | forward_mode: str 57 | supervise_hr: bool 58 | rank: int 59 | 60 | def __init__(self, **kwargs): 61 | names = set([f.name for f in dataclasses.fields(self)]) 62 | for k, v in kwargs.items(): 63 | if k in names: 64 | setattr(self, k, self.__annotations__[k](v)) 65 | 66 | if 'supervise_hr' not in kwargs.keys(): 67 | self.supervise_hr = False 68 | 69 | self.img_size = 512 70 | 71 | 72 | def load_dataset(opt, res, scale): 73 | dataset = dataio.NerfBlenderDataset(opt.dataset_path, 74 | splits=['test'], 75 | mode='test', 76 | resize_to=(int(res/2**(3-scale)), int(res/2**(3-scale))), 77 | multiscale=opt.multiscale, 78 | override_scale=scale, 79 | testskip=1) 80 | 81 | coords_dataset = dataio.Implicit6DMultiviewDataWrapper(dataset, 82 | (int(res/2**(3-scale)), int(res/2**(3-scale))), 83 | dataset.get_camera_params(), 84 | samples_per_ray=256, # opt.samples_per_ray, 85 | samples_per_view=opt.samples_per_view, 86 | num_workers=opt.num_workers, 87 | multiscale=opt.use_resized, 88 | supervise_hr=opt.supervise_hr, 89 | scales=[1/8, 1/4, 1/2, 1]) 90 | coords_dataset.toggle_logging_sampling() 91 | 92 | return coords_dataset 93 | 94 | 95 | def load_model(opt, checkpoint): 96 | # since model 
goes between -4 and 4 instead of -0.5 to 0.5 97 | sample_frequency = 3*(opt.img_size/4,) 98 | 99 | if opt.multiscale: 100 | # scale the frequencies of each layer accordingly 101 | input_scales = [1/24, 1/24, 1/24, 1/16, 1/16, 1/8, 1/8, 1/4, 1/4] 102 | output_layers = [2, 4, 6, 8] 103 | 104 | with utils.HiddenPrint(): 105 | model = modules.MultiscaleBACON(3, opt.hidden_features, 4, 106 | hidden_layers=opt.hidden_layers, 107 | bias=True, 108 | frequency=sample_frequency, 109 | quantization_interval=np.pi/4, 110 | input_scales=input_scales, 111 | output_layers=output_layers, 112 | reuse_filters=opt.reuse_filters) 113 | model.cuda() 114 | 115 | print('Loading checkpoints') 116 | state_dict = torch.load(checkpoint) 117 | model.load_state_dict(state_dict, strict=False) 118 | 119 | models = {'combined': model} 120 | 121 | return models 122 | 123 | 124 | def render_in_chunks(in_dict, model, chunk_size, return_all=False): 125 | batches, rays, samples, dims = in_dict['ray_samples'].shape 126 | 127 | in_dict['ray_samples'] = in_dict['ray_samples'].reshape(-1, 3) 128 | 129 | if return_all: 130 | out = [torch.zeros(batches, rays, samples, 4, device=in_dict['ray_samples'].device) for i in range(4)] 131 | out = [o.reshape(-1, 4) for o in out] 132 | else: 133 | out = torch.zeros(batches, rays, samples, 4, device=in_dict['ray_samples'].device) 134 | out = out.reshape(-1, 4) 135 | 136 | chunk_size *= 128 137 | num_chunks = int(np.ceil(rays*samples / (chunk_size))) 138 | 139 | for i in range(num_chunks): 140 | tmp = {'ray_samples': in_dict['ray_samples'][i*chunk_size:(i+1)*chunk_size, ...]} 141 | 142 | if return_all: 143 | for j in range(4): 144 | out[j][i*chunk_size:(i+1)*chunk_size, ...] = model(tmp)['model_out']['output'][j] 145 | else: 146 | out[i*chunk_size:(i+1)*chunk_size, ...] = model(tmp)['model_out']['output'][-1] 147 | 148 | if return_all: 149 | out = [o.reshape(batches, rays, samples, 4) for o in out] 150 | return {'model_out': {'output': out}, 'model_in': {'t_intervals': in_dict['t_intervals']}} 151 | 152 | else: 153 | out = out.reshape(batches, rays, samples, 4) 154 | return {'model_out': {'output': [out]}, 'model_in': {'t_intervals': in_dict['t_intervals']}} 155 | 156 | 157 | def render_all_in_chunks(in_dict, model, scale, chunk_size): 158 | batches, rays, samples, dims = in_dict['ray_samples'].shape 159 | out = torch.zeros(batches, rays, 3) 160 | num_chunks = int(np.ceil(rays / (chunk_size))) 161 | 162 | with torch.no_grad(): 163 | for i in range(num_chunks): 164 | # transfer to cuda 165 | model_in = {k: v[:, i*chunk_size:(i+1)*chunk_size, ...].cuda() for k, v in in_dict.items()} 166 | model_in = training.dict2cuda(model_in) 167 | 168 | # run first forward pass for importance sampling 169 | model.stop_after = 0 170 | model_out = {'combined': model(model_in)} 171 | 172 | # resample rays 173 | model_in = training.sample_pdf(model_in, model_out, idx=0) 174 | 175 | # importance sampled pass 176 | model.stop_after = scale 177 | model_out = {'combined': model(model_in)} 178 | 179 | # render outputs 180 | sigma = model_out['combined']['model_out']['output'][-1][..., -1:] 181 | rgb = model_out['combined']['model_out']['output'][-1][..., :-1] 182 | 183 | t_interval = model_in['t_intervals'] 184 | 185 | pred_weights = forward_models.compute_transmittance_weights(sigma, t_interval) 186 | pred_pixels = forward_models.compute_tomo_radiance(pred_weights, rgb) 187 | 188 | out[:, i*chunk_size:(i+1)*chunk_size, ...] 
= pred_pixels.cpu() 189 | 190 | pred_view = out.view(int(np.sqrt(rays)), int(np.sqrt(rays)), 3).detach().cpu() 191 | pred_view = torch.clamp(pred_view, 0, 1).numpy() 192 | return pred_view 193 | 194 | 195 | def render_image(opt, models, dataset, chunk_size, 196 | in_dict, meta_dict, gt_dict, scale, 197 | return_all=False): 198 | 199 | # add batch dimension 200 | for k, v in in_dict.items(): 201 | in_dict[k].unsqueeze_(0) 202 | 203 | for i in range(len(gt_dict['pixel_samples'])): 204 | gt_dict['pixel_samples'][i].unsqueeze_(0) 205 | 206 | use_chunks = True 207 | if in_dict['ray_samples'].shape[1] < chunk_size: 208 | use_chunks = False 209 | use_chunks = True 210 | 211 | # render the whole thing in chunks 212 | if scale > 3: 213 | pred_view = render_all_in_chunks(in_dict, models['combined'], scale, chunk_size) 214 | return pred_view, 0.0, 0.0, 0.0 215 | 216 | in_dict = training.dict2cuda(in_dict) 217 | 218 | start = torch.cuda.Event(enable_timing=True) 219 | end = torch.cuda.Event(enable_timing=True) 220 | start.record() 221 | 222 | with torch.no_grad(): 223 | models['combined'].stop_after = 0 224 | if use_chunks: 225 | out_dict = {key: render_in_chunks(in_dict, model, chunk_size) 226 | for key, model in models.items()} 227 | else: 228 | out_dict = {key: model(in_dict) for key, model in models.items()} 229 | models['combined'].stop_after = scale 230 | 231 | in_dict = training.sample_pdf(in_dict, out_dict, idx=0) 232 | 233 | if use_chunks: 234 | out_dict = {key: render_in_chunks(in_dict, model, chunk_size, return_all=return_all) 235 | for key, model in models.items()} 236 | 237 | else: 238 | out_dict = {key: model(in_dict) for key, model in models.items()} 239 | 240 | if return_all: 241 | sigma = [s[..., -1:] for s in out_dict['combined']['model_out']['output']] 242 | rgb = [c[..., :-1] for c in out_dict['combined']['model_out']['output']] 243 | t_interval = in_dict['t_intervals'] 244 | 245 | if isinstance(gt_dict['pixel_samples'], list): 246 | gt_view = gt_dict['pixel_samples'][scale].squeeze(0).numpy() 247 | else: 248 | gt_view = gt_dict['pixel_samples'].detach().squeeze(0).numpy() 249 | view_shape = gt_view.shape 250 | 251 | pred_weights = [forward_models.compute_transmittance_weights(s, t_interval) for s in sigma] 252 | pred_pixels = [forward_models.compute_tomo_radiance(w, c) for w, c in zip(pred_weights, rgb)] 253 | 254 | pred_view = [p.view(view_shape).detach().cpu() for p in pred_pixels] 255 | pred_view = [torch.clamp(p, 0, 1).numpy() for p in pred_view] 256 | 257 | end.record() 258 | torch.cuda.synchronize() 259 | elapsed = start.elapsed_time(end) 260 | 261 | return pred_view, 0, 0, elapsed 262 | 263 | else: 264 | sigma = out_dict['combined']['model_out']['output'][-1][..., -1:] 265 | rgb = out_dict['combined']['model_out']['output'][-1][..., :-1] 266 | t_interval = in_dict['t_intervals'] 267 | 268 | if isinstance(gt_dict['pixel_samples'], list): 269 | gt_view = gt_dict['pixel_samples'][scale].squeeze(0).numpy() 270 | else: 271 | gt_view = gt_dict['pixel_samples'].detach().squeeze(0).numpy() 272 | view_shape = gt_view.shape 273 | 274 | pred_weights = forward_models.compute_transmittance_weights(sigma, t_interval) 275 | pred_pixels = forward_models.compute_tomo_radiance(pred_weights, rgb) 276 | 277 | # log the images 278 | end.record() 279 | torch.cuda.synchronize() 280 | elapsed = start.elapsed_time(end) 281 | 282 | pred_view = pred_pixels.view(view_shape).detach().cpu() 283 | pred_view = torch.clamp(pred_view, 0, 1).numpy() 284 | 285 | psnr = peak_signal_noise_ratio(gt_view, 
pred_view) 286 | ssim = ssim_fn(gt_view, pred_view) 287 | 288 | return pred_view, psnr, ssim, elapsed 289 | 290 | 291 | def eval_nerf_bacon(scene, config, checkpoint, outdir, res, scale, chunk_size=10000, return_all=False, val_idx=None): 292 | 293 | os.makedirs(f'./outputs/nerf/{outdir}', exist_ok=True) 294 | 295 | p = configargparse.DefaultConfigFileParser() 296 | with open(config) as f: 297 | opt = p.parse(f) 298 | 299 | opt = Options(**opt) 300 | dataset = load_dataset(opt, res, scale) 301 | models = load_model(opt, checkpoint) 302 | 303 | for k in models.keys(): 304 | models[k].stop_after = scale 305 | 306 | # render images 307 | psnrs = [] 308 | ssims = [] 309 | dataset_generator = iter(dataset) 310 | for idx in range(len(dataset)): 311 | 312 | if val_idx is not None: 313 | dataset.val_idx = val_idx 314 | idx = val_idx 315 | 316 | in_dict, meta_dict, gt_dict = next(dataset_generator) 317 | 318 | images, psnr, ssim, elapsed = render_image(opt, models, dataset, chunk_size, 319 | in_dict, meta_dict, gt_dict, 320 | scale, return_all=return_all) 321 | 322 | tqdm.write(f'Scale: {scale} | PSNR: {psnr:.02f} dB, SSIM: {ssim:.02f}, Elapsed: {elapsed:.02f} ms') 323 | 324 | if return_all: 325 | for s in range(4): 326 | skimage.io.imsave(f'./outputs/nerf/{outdir}/r_{idx}_d{3-s}.png', (images[s]*255).astype(np.uint8)) 327 | else: 328 | np.save(f'./outputs/nerf/{outdir}/r_{idx}_d{3-scale}.npy', {'psnr': psnr, 'ssim': ssim}) 329 | skimage.io.imsave(f'./outputs/nerf/{outdir}/r_{idx}_d{3-scale}.png', (images*255).astype(np.uint8)) 330 | 331 | psnrs.append(psnr) 332 | ssims.append(ssim) 333 | 334 | if val_idx is not None: 335 | break 336 | 337 | if not return_all and val_idx is not None: 338 | np.save(f'./outputs/nerf/{outdir}/metrics_d{3-scale}.npy', {'psnr': psnrs, 'ssim': ssims, 339 | 'avg_psnr': np.mean(psnrs), 340 | 'avg_ssim': np.mean(ssims)}) 341 | 342 | print(f'Avg. PSNR: {np.mean(psnrs):.02f}, Avg. 
SSIM: {np.mean(ssims):.02f}') 343 | 344 | 345 | if __name__ == '__main__': 346 | # before running this you need to download the nerf blender datasets for the lego model 347 | # and place in ../data/nerf_synthetic/lego 348 | # these can be downloaded here 349 | # https://drive.google.com/drive/folders/1lrDkQanWtTznf48FCaW5lX9ToRdNDF1a 350 | 351 | # render the model trained with explicit supervision at each scale 352 | config = './config/nerf/bacon.ini' 353 | checkpoint = '../trained_models/lego.pth' 354 | outdir = 'lego' 355 | res = 512 356 | for scale in range(4): 357 | eval_nerf_bacon('lego', config, checkpoint, outdir, res, scale) 358 | 359 | # render the semisupervised model 360 | config = './config/nerf/bacon_semisupervise.ini' 361 | checkpoint = '../trained_models/lego_semisupervise.pth' 362 | outdir = 'lego_semisupervise' 363 | res = 512 364 | for scale in range(4): 365 | eval_nerf_bacon('lego_semisupervise', config, checkpoint, outdir, res, scale) 366 | -------------------------------------------------------------------------------- /experiments/render_sdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import utils 6 | import modules 7 | import torch 8 | import numpy as np 9 | from tqdm import tqdm 10 | import mcubes 11 | import trimesh 12 | import dataio 13 | import math 14 | 15 | 16 | def export_model(ckpt_path, model_name, N=512, model_type='bacon', hidden_layers=8, 17 | hidden_size=256, output_layers=[1, 2, 4, 8], 18 | return_sdf=False, adaptive=True): 19 | 20 | # the network has 4 output levels of detail 21 | num_outputs = len(output_layers) 22 | max_frequency = 3*(32,) 23 | 24 | # load model 25 | with utils.HiddenPrint(): 26 | model = modules.MultiscaleBACON(3, hidden_size, 1, 27 | hidden_layers=hidden_layers, 28 | bias=True, 29 | frequency=max_frequency, 30 | quantization_interval=np.pi, 31 | is_sdf=True, 32 | output_layers=output_layers, 33 | reuse_filters=True) 34 | 35 | ckpt = torch.load(ckpt_path) 36 | model.load_state_dict(ckpt) 37 | model.cuda() 38 | 39 | if not adaptive: 40 | # extracts separate meshes for each scale 41 | generate_mesh(model, N, return_sdf, num_outputs, model_name) 42 | 43 | else: 44 | # extracts single-scale output 45 | generate_mesh_adaptive(model, model_name) 46 | 47 | 48 | def generate_mesh(model, N, return_sdf=False, num_outputs=4, model_name='model'): 49 | 50 | # write output 51 | x = torch.linspace(-0.5, 0.5, N) 52 | if return_sdf: 53 | x = torch.arange(-N//2, N//2) / N 54 | x = x.float() 55 | x, y, z = torch.meshgrid(x, x, x) 56 | render_coords = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=-1).cuda() 57 | sdf_values = [np.zeros((N**3, 1)) for i in range(num_outputs)] 58 | 59 | # render in a batched fashion to save memory 60 | bsize = int(128**2) 61 | for i in tqdm(range(int(N**3 / bsize))): 62 | coords = render_coords[i*bsize:(i+1)*bsize, :] 63 | out = model({'coords': coords})['model_out'] 64 | 65 | if not isinstance(out, list): 66 | out = [out, ] 67 | 68 | for idx, sdf in enumerate(out): 69 | sdf_values[idx][i*bsize:(i+1)*bsize] = sdf.detach().cpu().numpy() 70 | 71 | if return_sdf: 72 | return [sdf.reshape(N, N, N) for sdf in sdf_values] 73 | 74 | for idx, sdf in enumerate(sdf_values): 75 | sdf = sdf.reshape(N, N, N) 76 | vertices, triangles = mcubes.marching_cubes(-sdf, 0) 77 | mesh = trimesh.Trimesh(vertices=vertices, faces=triangles) 78 | mesh.vertices = (mesh.vertices / N 
- 0.5) + 0.5/N 79 | 80 | os.makedirs('./outputs/meshes', exist_ok=True) 81 | mesh.export(f"./outputs/meshes/{model_name}_{idx+1}.obj") 82 | 83 | 84 | def prepare_multi_scale(res, num_scales): 85 | def coord2ind(xyz_coord, res): 86 | # xyz_coord: * x 3 87 | x, y, z = torch.split(xyz_coord, 1, dim=-1) 88 | flat_ind = x * res**2 + y * res + z 89 | return flat_ind.squeeze(-1) # * 90 | 91 | shifts = torch.from_numpy(np.stack(np.mgrid[:2, :2, :2], axis=-1)).view(-1, 3) 92 | 93 | def subdiv_index(xyz_prev, next_res): # should output (N^3)*8 94 | xyz_next = xyz_prev.unsqueeze(1) * 2 + shifts # (N^3)x8x3 95 | flat_ind_next = coord2ind(xyz_next, next_res) # (N^3)*8 96 | return flat_ind_next 97 | 98 | lowest_res = res / 2**(num_scales-1) 99 | subdiv_hash_list = [] 100 | 101 | for i in range(num_scales-1): 102 | curr_res = int(lowest_res*2**i) 103 | xyz_ind = torch.from_numpy(np.stack(np.mgrid[:curr_res, :curr_res, :curr_res], axis=-1)).view(-1, 3) # (N^3)x3 104 | subdiv_hash = subdiv_index(xyz_ind, curr_res * 2) 105 | subdiv_hash_list.append(subdiv_hash.cuda().long()) 106 | return subdiv_hash_list 107 | 108 | 109 | # multi-scale marching cubes 110 | def compute_one_scale(model, layer_ind, render_coords, sdf_values, hash_ind): 111 | assert(len(render_coords) == len(hash_ind)) 112 | bsize = int(128 ** 2) 113 | for i in range(int(len(render_coords) / bsize)+1): 114 | coords = render_coords[i * bsize:(i + 1) * bsize, :] 115 | out = model({'coords': coords}, specified_layers=output_layers[layer_ind])['model_out'] 116 | sdf_values[hash_ind[i * bsize:(i + 1) * bsize]] = out[0] 117 | 118 | 119 | def compute_one_scale_adaptive(model, layer_ind, render_coords, sdf_values, hash_ind, threshold=0.003): 120 | assert(len(render_coords) == len(hash_ind)) 121 | bsize = int(128 ** 2) 122 | for i in range(int(len(render_coords) / bsize)+1): 123 | coords = render_coords[i * bsize:(i + 1) * bsize, :] 124 | out = model({'coords': coords}, specified_layers=2, get_feature=True)['model_out'] 125 | sdf = out[0][0] 126 | if output_layers[layer_ind] > 2: 127 | feature = out[0][1] 128 | near_surf = (sdf.abs() < threshold).squeeze() 129 | coords_surf = coords[near_surf] 130 | feature_surf = feature[near_surf] 131 | out = model({'coords': coords_surf}, specified_layers=output_layers[layer_ind], 132 | continue_layer=2, continue_feature=feature_surf)['model_out'] 133 | sdf_near = out[0] 134 | sdf[near_surf] = sdf_near 135 | 136 | sdf_values[hash_ind[i * bsize:(i + 1) * bsize]] = sdf 137 | 138 | 139 | def generate_mesh_adaptive(model, model_name): 140 | with torch.no_grad(): 141 | lowest_res = N / 2 ** (len(output_layers) - 1) 142 | compute_one_scale(model, 0, coords_list[0], sdf_out_list[0], subdiv_hashes[0]) 143 | 144 | for i in range(1, num_outputs): 145 | curr_res = int(lowest_res*2**(i-1)) 146 | pixel_len = 1 / curr_res 147 | threshold = (math.sqrt(2)*pixel_len*0.5)*2 148 | sdf_prev = sdf_out_list[i-1] 149 | sdf_curr = sdf_out_list[i] 150 | hash_curr = subdiv_hashes[i] 151 | coords_curr = coords_list[i] 152 | near_surf_prev = (sdf_prev.abs() <= threshold).squeeze(-1) 153 | 154 | # empty space 155 | sdf_curr[hash_curr[~near_surf_prev]] = sdf_prev[~near_surf_prev].unsqueeze(-1) 156 | 157 | # non-empty space 158 | non_empty_ind = hash_curr[near_surf_prev].flatten() 159 | 160 | if i == num_outputs-1: 161 | compute_one_scale_adaptive(model, i, coords_curr[non_empty_ind], sdf_curr, 162 | non_empty_ind, threshold=pixel_len*0.5*2.) 
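            # note: the adaptive call above first queries the coarse layer-2 SDF and
            # only refines cells whose coarse value lies near the surface; the branch
            # below instead evaluates the requested output layer for every non-empty
            # cell at this scale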
163 | else: 164 | compute_one_scale(model, i, coords_curr[non_empty_ind], sdf_curr, non_empty_ind) 165 | 166 | # run marching cubes 167 | sdf = sdf_curr.reshape(N, N, N).detach().cpu().numpy() 168 | vertices, triangles = mcubes.marching_cubes(-sdf, 0) 169 | mesh = trimesh.Trimesh(vertices=vertices, faces=triangles) 170 | mesh.vertices = (mesh.vertices / N - 0.5) + 0.5/N 171 | 172 | os.makedirs('./outputs/meshes', exist_ok=True) 173 | mesh.export(f"./outputs/meshes/{model_name}.obj") 174 | 175 | 176 | def export_meshes(adaptive=True): 177 | bacon_ckpts = ['../trained_models/dragon.pth', 178 | '../trained_models/armadillo.pth', 179 | '../trained_models/lucy.pth', 180 | '../trained_models/thai.pth'] 181 | 182 | bacon_names = ['bacon_dragon', 183 | 'bacon_armadillo', 184 | 'bacon_lucy', 185 | 'bacon_thai'] 186 | 187 | print('Exporting BACON') 188 | for ckpt, name in tqdm(zip(bacon_ckpts, bacon_names), total=len(bacon_ckpts)): 189 | export_model(ckpt, name, model_type='bacon', output_layers=output_layers, adaptive=adaptive) 190 | 191 | 192 | def init_multiscale_mc(): 193 | subdiv_hashes = prepare_multi_scale(N, len(output_layers)) # (N^3)*8 194 | subdiv_hashes = [torch.arange((N // 8) ** 3).cuda().long(), ] + subdiv_hashes 195 | lowest_res = N // 2**(len(output_layers)-1) 196 | coords_list = [dataio.get_mgrid(lowest_res*(2**i), dim=3).cuda() for i in range(len(output_layers))] # (N^3)*3 197 | sdf_out_list = [torch.zeros(((lowest_res*(2**i))**3), 1).cuda() for i in range(len(output_layers))] # (N^3) 198 | 199 | return subdiv_hashes, lowest_res, coords_list, sdf_out_list 200 | 201 | 202 | if __name__ == '__main__': 203 | global N, output_layers, subdiv_hashes, lowest_res, coords_list, sdf_out_list, num_outputs 204 | N = 512 205 | output_layers = [2, 4, 6, 8] 206 | num_outputs = len(output_layers) 207 | 208 | subdiv_hashes, lowest_res, coords_list, sdf_out_list = init_multiscale_mc() 209 | 210 | # export meshes, use adaptive SDF evaluation or not 211 | # setting adaptive=False will output meshes at all resolutions 212 | # while adaptive=True while extract only a high-resolution mesh 213 | export_meshes(adaptive=True) 214 | -------------------------------------------------------------------------------- /experiments/train_1d.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | 6 | from torch.utils.tensorboard import SummaryWriter 7 | import dataio 8 | import utils 9 | import training 10 | import loss_functions 11 | import modules 12 | from torch.utils.data import DataLoader 13 | import configargparse 14 | import torch 15 | from functools import partial 16 | import numpy as np 17 | 18 | torch.backends.cudnn.benchmark = True 19 | torch.set_num_threads(4) 20 | 21 | p = configargparse.ArgumentParser() 22 | p.add('-c', '--config', required=False, is_config_file=True, help='Path to config file.') 23 | 24 | # General training options 25 | p.add_argument('--batch_size', type=int, default=1) 26 | p.add_argument('--hidden_features', type=int, default=128) 27 | p.add_argument('--hidden_layers', type=int, default=4) 28 | p.add_argument('--experiment_name', type=str, default='train_1d_mfn', 29 | help='path to directory where checkpoints & tensorboard events will be saved.') 30 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. 
default=5e-5') 31 | p.add_argument('--num_steps', type=int, default=1001, 32 | help='Number of epochs to train for.') 33 | p.add_argument('--gpu', type=int, default=0, 34 | help='GPU ID to use') 35 | 36 | p.add_argument('--model', default='mfn', choices=['mfn', 'mlp'], 37 | help='use MFN or standard MLP') 38 | p.add_argument('--activation', type=str, default='sine', 39 | choices=['sine', 'relu', 'requ', 'gelu', 'selu', 'softplus', 'tanh', 'swish'], 40 | help='activation to use (for model mlp only)') 41 | p.add_argument('--w0', type=float, default=10) 42 | p.add_argument('--pe_scale', type=float, default=3, help='positional encoding scale') 43 | p.add_argument('--no_pe', action='store_true', default=False, help='override to have no positional encoding for relu mlp') 44 | p.add_argument('--max_freq', type=float, default=5, help='The network-equivalent sample rate used to represent the signal. Should be at least twice the Nyquist frequency.') 45 | 46 | # summary options 47 | p.add_argument('--steps_til_ckpt', type=int, default=100, 48 | help='Time interval in seconds until checkpoint is saved.') 49 | p.add_argument('--steps_til_summary', type=int, default=100, 50 | help='Time interval in seconds until tensorboard summary is saved.') 51 | 52 | # logging options 53 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging') 54 | 55 | opt = p.parse_args() 56 | 57 | if opt.experiment_name is None and opt.render_model is None: 58 | p.error('--experiment_name is required.') 59 | 60 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu) 61 | 62 | 63 | def main(): 64 | 65 | print('--- Run Configuration ---') 66 | for k, v in vars(opt).items(): 67 | print(k, v) 68 | 69 | train() 70 | 71 | 72 | def train(): 73 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 74 | utils.cond_mkdir(root_path) 75 | 76 | fn = dataio.sines1 77 | train_dataset = dataio.Func1DWrapper(range=(-0.5, 0.5), 78 | fn=fn, 79 | sampling_density=1000, 80 | train_every=1000/18) # 18 samples is ~1.1 the nyquist rate assuming fmax=8 81 | 82 | train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0) 83 | 84 | if opt.model == 'mlp': 85 | model = modules.CoordinateNet(nl=opt.activation, 86 | in_features=1, 87 | out_features=1, 88 | hidden_features=opt.hidden_features, 89 | num_hidden_layers=opt.hidden_layers, 90 | w0=opt.w0, 91 | pe_scale=opt.pe_scale, 92 | use_sigmoid=False, 93 | no_pe=opt.no_pe) 94 | 95 | elif opt.model == 'mfn': 96 | model = modules.BACON(1, opt.hidden_features, 1, 97 | hidden_layers=opt.hidden_layers, 98 | bias=True, 99 | frequency=[opt.max_freq, ], 100 | quantization_interval=2*np.pi) 101 | else: 102 | raise ValueError('model must be mlp or mfn') 103 | 104 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 105 | params = sum([np.prod(p.size()) for p in model_parameters]) 106 | print(f'Num. Parameters: {params}') 107 | 108 | model.cuda() 109 | 110 | # Define the loss 111 | loss_fn = loss_functions.function_mse 112 | summary_fn = partial(utils.write_simple_1D_function_summary, train_dataset) 113 | 114 | # Save command-line parameters log directory. 115 | p.write_config_file(opt, [os.path.join(root_path, 'config.ini')]) 116 | with open(os.path.join(root_path, "params.txt"), "w") as out_file: 117 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()])) 118 | 119 | # Save text summary of model into log directory. 
120 | with open(os.path.join(root_path, "model.txt"), "w") as out_file: 121 | out_file.write(str(model)) 122 | 123 | training.train(model=model, train_dataloader=train_dataloader, steps=opt.num_steps, lr=opt.lr, 124 | steps_til_summary=opt.steps_til_summary, steps_til_checkpoint=opt.steps_til_ckpt, 125 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn) 126 | 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /experiments/train_img.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | 6 | from torch.utils.tensorboard import SummaryWriter 7 | import numpy as np 8 | import dataio 9 | import utils 10 | import training 11 | import loss_functions 12 | import modules 13 | from torch.utils.data import DataLoader 14 | import configargparse 15 | import torch 16 | from functools import partial 17 | 18 | torch.backends.cudnn.benchmark = True 19 | torch.set_num_threads(4) 20 | 21 | p = configargparse.ArgumentParser() 22 | 23 | # config file, output directories 24 | p.add('-c', '--config', required=False, is_config_file=True, 25 | help='Path to config file.') 26 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging') 27 | p.add_argument('--experiment_name', type=str, default='train_img', 28 | help='path to directory where checkpoints & tensorboard events will be saved.') 29 | 30 | # general training options 31 | p.add_argument('--model', default='mfn', choices=['mfn', 'mlp'], 32 | help='use MFN or standard MLP') 33 | p.add_argument('--batch_size', type=int, default=1) 34 | p.add_argument('--hidden_features', type=int, default=32) 35 | p.add_argument('--hidden_layers', type=int, default=4) 36 | p.add_argument('--res', type=int, default=256, 37 | help='resolution of image to fit, also used to set the network-equivalent sample rate' 38 | + ' i.e., the maximum network bandwidth in cycles per unit interval is half this value') 39 | p.add_argument('--lr', type=float, default=5e-4, help='learning rate') 40 | p.add_argument('--num_steps', type=int, default=5001, 41 | help='number of training steps') 42 | p.add_argument('--gpu', type=int, default=0, 43 | help='gpu id to use for training') 44 | 45 | # mfn options 46 | p.add_argument('--multiscale', action='store_true', default=False, 47 | help='use multiscale') 48 | p.add_argument('--use_resized', action='store_true', default=False, 49 | help='use multiscale') 50 | 51 | # mlp options 52 | p.add_argument('--activation', type=str, default='sine', 53 | choices=['sine', 'relu', 'requ', 'gelu', 'selu', 'softplus', 'tanh', 'swish'], 54 | help='activation to use (for model mlp only)') 55 | p.add_argument('--ipe', action='store_true', default=False, 56 | help='use integrated positional encoding') 57 | p.add_argument('--w0', type=float, default=10) 58 | p.add_argument('--pe_scale', type=float, default=3, help='positional encoding scale') 59 | p.add_argument('--no_pe', action='store_true', default=False, 60 | help='override to have no positional encoding for relu mlp') 61 | 62 | # data processing and i/o 63 | p.add_argument('--centered', action='store_true', default=False, 64 | help='centere input coordinates as -1 to 1') 65 | p.add_argument('--img_fn', type=str, default='../data/lighthouse.png', 66 | help='path to specific png filename') 67 | 
p.add_argument('--grayscale', action='store_true', default=False, 68 | help='if grayscale image') 69 | 70 | # summary, logging options 71 | p.add_argument('--steps_til_ckpt', type=int, default=100, 72 | help='Time interval in seconds until checkpoint is saved.') 73 | p.add_argument('--steps_til_summary', type=int, default=100, 74 | help='Time interval in seconds until tensorboard summary is saved.') 75 | 76 | opt = p.parse_args() 77 | 78 | if opt.experiment_name is None and opt.render_model is None: 79 | p.error('--experiment_name is required.') 80 | 81 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu) 82 | 83 | 84 | def main(): 85 | 86 | print('--- Run Configuration ---') 87 | for k, v in vars(opt).items(): 88 | print(k, v) 89 | 90 | train() 91 | 92 | 93 | def train(): 94 | 95 | # set up logging dir 96 | opt.root_path = os.path.join(opt.logging_root, opt.experiment_name) 97 | utils.cond_mkdir(opt.root_path) 98 | 99 | # get datasets 100 | trn_dataset, val_dataset, dataloader = init_dataloader(opt) 101 | 102 | # set up coordinate network 103 | model = init_model(opt) 104 | 105 | # loss and tensorboard logging functions 106 | loss_fn, summary_fn = init_loss(opt, trn_dataset, val_dataset) 107 | 108 | # back up config file 109 | save_params(opt, model) 110 | 111 | # start training 112 | training.train(model=model, train_dataloader=dataloader, 113 | steps=opt.num_steps, lr=opt.lr, 114 | steps_til_summary=opt.steps_til_summary, 115 | steps_til_checkpoint=opt.steps_til_ckpt, 116 | model_dir=opt.root_path, loss_fn=loss_fn, summary_fn=summary_fn) 117 | 118 | 119 | def init_dataloader(opt): 120 | ''' load image datasets, dataloader ''' 121 | 122 | if opt.img_fn == '../data/lighthouse.png': 123 | url = 'http://www.cs.albany.edu/~xypan/research/img/Kodak/kodim19.png' 124 | else: 125 | url = None 126 | 127 | # init datasets 128 | trn_dataset = dataio.ImageFile(opt.img_fn, grayscale=opt.grayscale, resolution=(opt.res, opt.res), url=url) 129 | 130 | val_dataset = dataio.ImageFile(opt.img_fn, grayscale=opt.grayscale, resolution=(2*opt.res, 2*opt.res), url=url) 131 | 132 | trn_dataset = dataio.ImageWrapper(trn_dataset, centered=opt.centered, 133 | include_end=False, 134 | multiscale=opt.use_resized, 135 | stages=3) 136 | 137 | val_dataset = dataio.ImageWrapper(val_dataset, centered=opt.centered, 138 | include_end=False, 139 | multiscale=opt.use_resized, 140 | stages=3) 141 | 142 | dataloader = DataLoader(trn_dataset, shuffle=True, batch_size=opt.batch_size, 143 | pin_memory=True, num_workers=0) 144 | 145 | return trn_dataset, val_dataset, dataloader 146 | 147 | 148 | def init_model(opt): 149 | 150 | if opt.grayscale: 151 | out_features = 1 152 | else: 153 | out_features = 3 154 | 155 | if opt.model == 'mlp': 156 | 157 | if opt.multiscale: 158 | m = modules.MultiscaleCoordinateNet 159 | else: 160 | m = modules.CoordinateNet 161 | 162 | model = m(nl=opt.activation, 163 | in_features=2, 164 | out_features=out_features, 165 | hidden_features=opt.hidden_features, 166 | num_hidden_layers=opt.hidden_layers, 167 | w0=opt.w0, 168 | pe_scale=opt.pe_scale, 169 | no_pe=opt.no_pe, 170 | integrated_pe=opt.ipe) 171 | 172 | elif opt.model == 'mfn': 173 | 174 | if opt.multiscale: 175 | m = modules.MultiscaleBACON 176 | else: 177 | m = modules.BACON 178 | 179 | input_scales = [1/8, 1/8, 1/4, 1/4, 1/4] 180 | output_layers = [1, 2, 4] 181 | 182 | model = m(2, opt.hidden_features, out_size=out_features, 183 | hidden_layers=opt.hidden_layers, 184 | bias=True, 185 | frequency=(opt.res, opt.res), 186 | 
quantization_interval=2*np.pi, 187 | input_scales=input_scales, 188 | output_layers=output_layers, 189 | reuse_filters=False) 190 | 191 | else: 192 | raise ValueError('model must be mlp or mfn') 193 | 194 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 195 | params = sum([np.prod(p.size()) for p in model_parameters]) 196 | print(f'Num. Parameters: {params}') 197 | model.cuda() 198 | 199 | return model 200 | 201 | 202 | def init_loss(opt, trn_dataset, val_dataset): 203 | ''' define loss, summary functions given expmt configs ''' 204 | 205 | # initialize the loss 206 | if opt.multiscale: 207 | loss_fn = partial(loss_functions.multiscale_image_mse, use_resized=opt.use_resized) 208 | summary_fn = partial(utils.write_multiscale_image_summary, (opt.res, opt.res), 209 | trn_dataset, use_resized=opt.use_resized, val_dataset=val_dataset) 210 | else: 211 | loss_fn = loss_functions.image_mse 212 | summary_fn = partial(utils.write_image_summary, (opt.res, opt.res), trn_dataset, 213 | val_dataset=val_dataset) 214 | 215 | return loss_fn, summary_fn 216 | 217 | 218 | def save_params(opt, model): 219 | 220 | # Save command-line parameters log directory. 221 | p.write_config_file(opt, [os.path.join(opt.root_path, 'config.ini')]) 222 | with open(os.path.join(opt.root_path, "params.txt"), "w") as out_file: 223 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()])) 224 | 225 | # Save text summary of model into log directory. 226 | with open(os.path.join(opt.root_path, "model.txt"), "w") as out_file: 227 | out_file.write(str(model)) 228 | 229 | 230 | if __name__ == '__main__': 231 | main() 232 | -------------------------------------------------------------------------------- /experiments/train_radiance_field.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | from torch.utils.tensorboard import SummaryWriter 6 | import dataio 7 | import utils 8 | import training 9 | import loss_functions 10 | import modules 11 | from torch.utils.data import DataLoader 12 | import configargparse 13 | from functools import partial 14 | import torch 15 | import numpy as np 16 | 17 | torch.set_num_threads(8) 18 | torch.backends.cudnn.benchmark = True 19 | 20 | p = configargparse.ArgumentParser() 21 | p.add('-c', '--config', required=False, is_config_file=True, help='Path to config file.') 22 | 23 | # Experiment & I/O general properties 24 | p.add_argument('--experiment_name', type=str, default=None, 25 | help='path to directory where checkpoints & tensorboard events will be saved.') 26 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 27 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging') 28 | p.add_argument('--dataset_path', type=str, default='../data/nerf_synthetic/lego/', 29 | help='path to directory where dataset is stored') 30 | p.add_argument('--resume', nargs=2, type=str, default=None, 31 | help='resume training, specify path to directory where model is stored.') 32 | p.add_argument('--num_steps', type=int, default=1000000, 33 | help='Number of iterations to train for.') 34 | p.add_argument('--steps_til_ckpt', type=int, default=50000, 35 | help='Iterations until checkpoint is saved.') 36 | p.add_argument('--steps_til_summary', type=int, default=2000, 37 | help='Iterations until tensorboard summary is saved.') 38 | 39 | # GPU & other 
computing properties 40 | p.add_argument('--gpu', type=int, default=0, 41 | help='GPU ID to use') 42 | p.add_argument('--chunk_size_train', type=int, default=1024, 43 | help='max chunk size to process data during training') 44 | p.add_argument('--chunk_size_eval', type=int, default=512, 45 | help='max chunk size to process data during eval') 46 | p.add_argument('--num_workers', type=int, default=0, help='number of dataloader workers.') 47 | 48 | # Learning properties 49 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=5e-5') 50 | p.add_argument('--batch_size', type=int, default=1) 51 | 52 | # Network architecture properties 53 | p.add_argument('--hidden_features', type=int, default=128) 54 | p.add_argument('--hidden_layers', type=int, default=4) 55 | 56 | p.add_argument('--multiscale', action='store_true', help='use multiscale architecture') 57 | p.add_argument('--supervise_hr', action='store_true', help='supervise only with high resolution signal') 58 | p.add_argument('--use_resized', action='store_true', help='use explicit multiscale supervision') 59 | p.add_argument('--reuse_filters', action='store_true', help='reuse fourier filters for faster training/inference') 60 | 61 | # NeRF Properties 62 | p.add_argument('--img_size', type=int, default=64, 63 | help='image resolution to train on (assumed symmetric)') 64 | p.add_argument('--samples_per_ray', type=int, default=128, 65 | help='samples to evaluate along each ray') 66 | p.add_argument('--samples_per_view', type=int, default=1024, 67 | help='samples to evaluate along each view') 68 | 69 | opt = p.parse_args() 70 | 71 | if opt.experiment_name is None and opt.render_model is None: 72 | p.error('--experiment_name is required.') 73 | 74 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu) 75 | 76 | 77 | def main(): 78 | print('--- Run Configuration ---') 79 | for k, v in vars(opt).items(): 80 | print(k, v) 81 | train() 82 | 83 | 84 | def train(validation=True): 85 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 86 | utils.cond_mkdir(root_path) 87 | 88 | ''' Training dataset ''' 89 | dataset = dataio.NerfBlenderDataset(opt.dataset_path, 90 | splits=['train'], 91 | mode='train', 92 | resize_to=2*(opt.img_size,), 93 | multiscale=opt.multiscale) 94 | 95 | coords_dataset = dataio.Implicit6DMultiviewDataWrapper(dataset, 96 | dataset.get_img_shape(), 97 | dataset.get_camera_params(), 98 | samples_per_view=opt.samples_per_view, 99 | num_workers=opt.num_workers, 100 | multiscale=opt.use_resized, 101 | supervise_hr=opt.supervise_hr, 102 | scales=[1/8, 1/4, 1/2, 1]) 103 | ''' Validation dataset ''' 104 | if validation: 105 | val_dataset = dataio.NerfBlenderDataset(opt.dataset_path, 106 | splits=['val'], 107 | mode='val', 108 | resize_to=2*(opt.img_size,), 109 | multiscale=opt.multiscale) 110 | 111 | val_coords_dataset = dataio.Implicit6DMultiviewDataWrapper(val_dataset, 112 | val_dataset.get_img_shape(), 113 | val_dataset.get_camera_params(), 114 | samples_per_view=opt.img_size**2, 115 | num_workers=opt.num_workers, 116 | multiscale=opt.use_resized, 117 | supervise_hr=opt.supervise_hr, 118 | scales=[1/8, 1/4, 1/2, 1]) 119 | 120 | ''' Dataloaders''' 121 | dataloader = DataLoader(coords_dataset, shuffle=True, batch_size=opt.batch_size, # num of views in a batch 122 | pin_memory=True, num_workers=opt.num_workers) 123 | 124 | if validation: 125 | val_dataloader = DataLoader(val_coords_dataset, shuffle=True, batch_size=1, 126 | pin_memory=True, num_workers=opt.num_workers) 127 | else: 128 | val_dataloader = 
None 129 | 130 | # get model paths 131 | if opt.resume is not None: 132 | path, step = opt.resume 133 | step = int(step) 134 | assert(os.path.isdir(path)) 135 | assert opt.config is not None, 'Specify config file' 136 | 137 | # since model goes between -4 and 4 instead of -0.5 to 0.5 138 | # we divide by a factor of 8. Then this is Nyquist sampled 139 | # assuming a maximum frequency of opt.img_size/8 cycles per unit interval 140 | # (where the blender dataset scenes typically span from -4 to 4 units) 141 | rgb_sample_freq = 3*(2*opt.img_size/8,) 142 | 143 | if opt.multiscale: 144 | # scale the frequencies of each layer 145 | # so that we have outputs at 1/8, 1/4, 1/2, and 1x 146 | # the maximum network bnadiwdth 147 | input_scales = [1/24, 1/24, 1/24, 1/16, 1/16, 1/8, 1/8, 1/4, 1/4] 148 | output_layers = [2, 4, 6, 8] 149 | 150 | model = modules.MultiscaleBACON(3, opt.hidden_features, 4, 151 | hidden_layers=opt.hidden_layers, 152 | bias=True, 153 | frequency=rgb_sample_freq, 154 | quantization_interval=np.pi/4, 155 | input_scales=input_scales, 156 | output_layers=output_layers, 157 | reuse_filters=opt.reuse_filters) 158 | model.cuda() 159 | 160 | else: 161 | input_scales = [1/24, 1/24, 1/24, 1/16, 1/16, 1/8, 1/8, 1/4, 1/4] 162 | input_scales = input_scales[:opt.hidden_layers+1] 163 | 164 | model = modules.BACON(3, opt.hidden_features, 4, 165 | hidden_layers=opt.hidden_layers, 166 | bias=True, 167 | frequency=rgb_sample_freq, 168 | quantization_interval=np.pi/4, 169 | reuse_filters=opt.reuse_filters, 170 | input_scales=input_scales) 171 | model.cuda() 172 | 173 | if opt.resume is not None: 174 | print('Loading checkpoints') 175 | 176 | state_dict = torch.load(path + '/checkpoints/' + f'model_combined_step_{step:04d}.pth') 177 | model.load_state_dict(state_dict, strict=False) 178 | 179 | # load optimizers 180 | try: 181 | resume_checkpoint = {} 182 | ckpt = torch.load(path + '/checkpoints/' + f'optim_combined_step_{step:04d}.pth') 183 | for g in ckpt['optimizer_state_dict']['param_groups']: 184 | g['lr'] = opt.lr 185 | 186 | resume_checkpoint['combined'] = {} 187 | resume_checkpoint['combined']['optim'] = ckpt['optimizer_state_dict'] 188 | resume_checkpoint['combined']['scheduler'] = ckpt['scheduler_state_dict'] 189 | resume_checkpoint['step'] = ckpt['step'] 190 | except FileNotFoundError: 191 | print('Unable to load optimizer checkpoints') 192 | else: 193 | resume_checkpoint = {} 194 | 195 | models = {'combined': model} 196 | 197 | # Define the loss 198 | if opt.multiscale: 199 | loss_fn = partial(loss_functions.multiscale_radiance_loss, use_resized=opt.use_resized) 200 | summary_fn = partial(utils.write_multiscale_radiance_summary, 201 | chunk_size_eval=opt.chunk_size_eval, 202 | num_views_to_disp_at_training=1, 203 | hierarchical_sampling=True) 204 | else: 205 | loss_fn = partial(loss_functions.radiance_sigma_rgb_loss) 206 | 207 | summary_fn = partial(utils.write_radiance_summary, 208 | chunk_size_eval=opt.chunk_size_eval, 209 | num_views_to_disp_at_training=1, 210 | hierarchical_sampling=True) 211 | 212 | chunk_lists_from_batch_fn = dataio.chunk_lists_from_batch_reduce_to_raysamples_fn 213 | 214 | # Save command-line parameters log directory. 215 | p.write_config_file(opt, [os.path.join(root_path, 'config.ini')]) 216 | with open(os.path.join(root_path, "params.txt"), "w") as out_file: 217 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()])) 218 | 219 | # Save text summary of model into log directory. 
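    # worked example of the frequency bookkeeping above (illustrative, assuming the
    # default --img_size of 64): rgb_sample_freq = 3*(2*64/8,) = (16, 16, 16), a
    # Nyquist-matched sample rate for a signal with at most 8 cycles per unit
    # interval; in the multiscale case the input_scales accumulate so that output
    # layers 2, 4, 6, and 8 reach roughly 1/8, 1/4, 1/2, and 1x of that bandwidth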
220 | with open(os.path.join(root_path, "model.txt"), "w") as out_file: 221 | for model_name, model in models.items(): 222 | out_file.write(model_name) 223 | out_file.write(str(model)) 224 | 225 | training.train_wchunks(models, dataloader, 226 | num_steps=opt.num_steps, lr=opt.lr, 227 | steps_til_summary=opt.steps_til_summary, steps_til_checkpoint=opt.steps_til_ckpt, 228 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, 229 | val_dataloader=val_dataloader, 230 | chunk_lists_from_batch_fn=chunk_lists_from_batch_fn, 231 | max_chunk_size=opt.chunk_size_train, 232 | resume_checkpoint=resume_checkpoint, 233 | chunked=True, 234 | hierarchical_sampling=True, 235 | stop_after=0) 236 | 237 | 238 | if __name__ == '__main__': 239 | main() 240 | -------------------------------------------------------------------------------- /experiments/train_sdf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | from torch.utils.tensorboard import SummaryWriter 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import configargparse 10 | import dataio 11 | import utils 12 | import training 13 | import loss_functions 14 | import modules 15 | from functools import partial 16 | 17 | torch.set_num_threads(8) 18 | 19 | p = configargparse.ArgumentParser() 20 | 21 | # config file, output directories 22 | p.add('-c', '--config', required=False, is_config_file=True, 23 | help='Path to config file.') 24 | p.add_argument('--logging_root', type=str, default='../logs', 25 | help='root for logging') 26 | p.add_argument('--experiment_name', type=str, required=True, 27 | help='subdirectory in logging_root for checkpoints, summaries') 28 | 29 | # general training 30 | p.add_argument('--model_type', type=str, default='mfn', 31 | help='options: mfn, siren, ff') 32 | p.add_argument('--hidden_size', type=int, default=128, 33 | help='size of hidden layer') 34 | p.add_argument('--hidden_layers', type=int, default=8) 35 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate') 36 | p.add_argument('--num_steps', type=int, default=20000, 37 | help='number of training steps') 38 | p.add_argument('--ckpt_step', type=int, default=0, 39 | help='step at which to resume training') 40 | p.add_argument('--gpu', type=int, default=1, help='GPU ID to use') 41 | p.add_argument('--seed', default=None, 42 | help='random seed for experiment reproducibility') 43 | 44 | # mfn options 45 | p.add_argument('--multiscale', action='store_true', default=False, 46 | help='use multiscale') 47 | p.add_argument('--max_freq', type=int, default=512, 48 | help='The network-equivalent sample rate used to represent the signal.' 
49 | + 'Should be at least twice the Nyquist frequency.') 50 | p.add_argument('--input_scales', nargs='*', type=float, default=None, 51 | help='fraction of resolution growth at each layer') 52 | p.add_argument('--output_layers', nargs='*', type=int, default=None, 53 | help='layer indices to output, beginning at 1') 54 | 55 | # mlp options 56 | p.add_argument('--w0', default=30, type=int, 57 | help='omega_0 parameter for siren') 58 | p.add_argument('--pe_scale', default=5, type=float, 59 | help='positional encoding scale') 60 | 61 | # sdf model and sampling 62 | p.add_argument('--num_pts_on', type=int, default=10000, 63 | help='number of on-surface points to sample') 64 | p.add_argument('--coarse_scale', type=float, default=1e-1, 65 | help='laplacian scale factor for coarse samples') 66 | p.add_argument('--fine_scale', type=float, default=1e-3, 67 | help='laplacian scale factor for fine samples') 68 | p.add_argument('--coarse_weight', type=float, default=1e-2, 69 | help='weight to apply to coarse loss samples') 70 | 71 | # data i/o 72 | p.add_argument('--shape', type=str, default='bunny', 73 | help='name of point cloud shape in xyz format') 74 | p.add_argument('--point_cloud_path', type=str, 75 | default='../data/armadillo.xyz', 76 | help='path for input point cloud') 77 | p.add_argument('--num_workers', default=0, type=int, 78 | help='number of workers') 79 | 80 | # tensorboard summary 81 | p.add_argument('--steps_til_ckpt', type=int, default=50000, 82 | help='epoch frequency to save a checkpoint') 83 | p.add_argument('--steps_til_summary', type=int, default=1000, 84 | help='epoch frequency to update tensorboard summary') 85 | 86 | opt = p.parse_args() 87 | 88 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu) 89 | 90 | 91 | def main(): 92 | 93 | print('--- Run Configuration ---') 94 | for k, v in vars(opt).items(): 95 | print(k, v) 96 | 97 | train() 98 | 99 | 100 | def train(): 101 | 102 | opt.root_path = os.path.join(opt.logging_root, opt.experiment_name) 103 | utils.cond_mkdir(opt.root_path) 104 | 105 | if opt.seed: 106 | torch.manual_seed(int(opt.seed)) 107 | np.random.seed(int(opt.seed)) 108 | 109 | dataloader = init_dataloader(opt) 110 | 111 | model = init_model(opt) 112 | 113 | loss_fn, summary_fn = init_loss(opt) 114 | 115 | save_params(opt, model) 116 | 117 | training.train(model=model, train_dataloader=dataloader, steps=opt.num_steps, 118 | lr=opt.lr, steps_til_summary=opt.steps_til_summary, 119 | ckpt_step=opt.ckpt_step, 120 | steps_til_checkpoint=opt.steps_til_ckpt, 121 | model_dir=opt.root_path, loss_fn=loss_fn, summary_fn=summary_fn, 122 | double_precision=False, clip_grad=True, 123 | use_lr_scheduler=True) 124 | 125 | 126 | def init_dataloader(opt): 127 | ''' load sdf dataloader via eikonal equation or fitting sdf directly ''' 128 | 129 | sdf_dataset = dataio.MeshSDF(opt.point_cloud_path, 130 | num_samples=opt.num_pts_on, 131 | coarse_scale=opt.coarse_scale, 132 | fine_scale=opt.fine_scale) 133 | 134 | dataloader = DataLoader(sdf_dataset, shuffle=True, 135 | batch_size=1, pin_memory=True, 136 | num_workers=opt.num_workers) 137 | 138 | return dataloader 139 | 140 | 141 | def init_model(opt): 142 | ''' return appropriate model given experiment configs ''' 143 | 144 | if opt.model_type == 'mfn': 145 | 146 | opt.input_scales = [1/24, 1/24, 1/24, 1/16, 1/16, 1/8, 1/8, 1/4, 1/4] 147 | opt.output_layers = [2, 4, 6, 8] 148 | 149 | frequency = (opt.max_freq, opt.max_freq, opt.max_freq) 150 | 151 | if opt.multiscale: 152 | if opt.output_layers and len(opt.output_layers) == 1: 153 | 
raise ValueError('expects >1 layer extraction if multiscale') 154 | model_ = modules.MultiscaleBACON 155 | else: 156 | model_ = modules.BACON 157 | 158 | model = model_(in_size=3, hidden_size=opt.hidden_size, out_size=1, 159 | hidden_layers=opt.hidden_layers, 160 | bias=True, 161 | frequency=frequency, 162 | quantization_interval=2*np.pi, # data on range [-0.5, 0.5] 163 | input_scales=opt.input_scales, 164 | is_sdf=True, 165 | output_layers=opt.output_layers, 166 | reuse_filters=True) 167 | 168 | elif opt.model_type == 'siren': 169 | 170 | if opt.multiscale: 171 | model_ = modules.MultiscaleCoordinateNet 172 | else: 173 | model_ = modules.CoordinateNet 174 | 175 | model = model_(nl='sine', 176 | in_features=3, 177 | out_features=1, 178 | num_hidden_layers=opt.hidden_layers, 179 | hidden_features=opt.hidden_size, 180 | w0=opt.w0, 181 | is_sdf=True) 182 | 183 | elif opt.model_type == 'ff': # mlp w relu + positional encoding 184 | model = modules.CoordinateNet(nl='relu', 185 | in_features=3, 186 | out_features=1, 187 | num_hidden_layers=opt.hidden_layers, 188 | hidden_features=opt.hidden_size, 189 | is_sdf=True, 190 | pe_scale=opt.pe_scale, 191 | use_sigmoid=False) 192 | else: 193 | raise NotImplementedError 194 | 195 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 196 | params = sum([np.prod(p.size()) for p in model_parameters]) 197 | print(f'Num. Parameters: {params}') 198 | model.cuda() 199 | 200 | # if resuming model training 201 | if opt.ckpt_step: 202 | opt.num_steps -= opt.ckpt_step # steps remaning to train 203 | if opt.num_steps < 1: 204 | raise ValueError('ckpt_epoch must be less than num_epochs') 205 | print(opt.num_steps) 206 | 207 | pth_file = '{}/checkpoints/model_step_{}.pth'.format(opt.root_path, 208 | str(opt.ckpt_step).zfill(4)) 209 | model.load_state_dict(torch.load(pth_file)) 210 | 211 | return model 212 | 213 | 214 | def init_loss(opt): 215 | ''' define loss, summary functions given expmt configs ''' 216 | 217 | if opt.multiscale: 218 | summary_fn = utils.write_multiscale_sdf_summary 219 | loss_fn = partial(loss_functions.multiscale_overfit_sdf, 220 | coarse_loss_weight=opt.coarse_weight) 221 | else: 222 | summary_fn = utils.write_sdf_summary 223 | loss_fn = partial(loss_functions.overfit_sdf, 224 | coarse_loss_weight=opt.coarse_weight) 225 | 226 | return loss_fn, summary_fn 227 | 228 | 229 | def save_params(opt, model): 230 | 231 | # Save command-line parameters log directory. 232 | p.write_config_file(opt, [os.path.join(opt.root_path, 'config.ini')]) 233 | with open(os.path.join(opt.root_path, "params.txt"), "w") as out_file: 234 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()])) 235 | 236 | # Save text summary of model into log directory. 
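    # rough bandwidth check for the mfn defaults above: --max_freq 512 with
    # quantization_interval 2*pi (coordinates on [-0.5, 0.5]) corresponds to a
    # network-equivalent sample rate of 512 per unit interval, i.e. a maximum
    # representable frequency of roughly 256 cycles per unit interval at the
    # finest output scale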
237 | with open(os.path.join(opt.root_path, "model.txt"), "w") as out_file: 238 | out_file.write(str(model)) 239 | 240 | 241 | if __name__ == '__main__': 242 | main() 243 | -------------------------------------------------------------------------------- /forward_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def cumprod_exclusive(tensor, dim=-2): 5 | cumprod = torch.cumprod(tensor, dim) 6 | cumprod = torch.roll(cumprod, 1, dim) 7 | cumprod[..., 0, :] = 1.0 8 | return cumprod 9 | 10 | 11 | def compute_transmittance_weights(pred_sigma, t_intervals): 12 | # pred_alpha = 1.-torch.exp(-torch.relu(pred_sigma)*t_intervals) 13 | tau = torch.nn.functional.softplus(pred_sigma - 1) 14 | pred_alpha = 1.-torch.exp(-tau*t_intervals) 15 | pred_weights = pred_alpha * cumprod_exclusive(1.-pred_alpha+1e-10, dim=-2) 16 | return pred_weights 17 | 18 | 19 | def compute_tomo_radiance(pred_weights, pred_rgb, black_background=False): 20 | eps = 0.001 21 | pred_rgb_pos = torch.sigmoid(pred_rgb) 22 | pred_rgb_pos = pred_rgb_pos * (1 + 2 * eps) - eps 23 | pred_pixel_samples = torch.sum(pred_rgb_pos*pred_weights, dim=-2) # line integral 24 | 25 | if not black_background: 26 | pred_pixel_samples += 1 - pred_weights.sum(-2) 27 | return pred_pixel_samples 28 | 29 | 30 | def compute_tomo_depth(pred_weights, zs): 31 | pred_depth = torch.sum(pred_weights*zs, dim=-2) 32 | return pred_depth 33 | 34 | 35 | def compute_disp_from_depth(pred_depth, pred_weights): 36 | pred_disp = 1. / torch.max(torch.tensor(1e-10).to(pred_depth.device), 37 | pred_depth / torch.sum(pred_weights, -2)) 38 | return pred_disp 39 | -------------------------------------------------------------------------------- /img/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/img/teaser.png -------------------------------------------------------------------------------- /loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import forward_models 3 | 4 | 5 | def function_mse(model_output, gt): 6 | idx = model_output['model_in']['idx'].long().squeeze() 7 | loss = (model_output['model_out']['output'][:, idx] - gt['func'][:, idx]) ** 2 8 | return {'func_loss': loss.mean()} 9 | 10 | 11 | def image_mse(model_output, gt): 12 | if 'complex' in model_output['model_out']: 13 | c = model_output['model_out']['complex'] 14 | loss = (c.real - gt['img']) ** 2 15 | imag_loss = (c.imag) ** 2 16 | return {'func_loss': loss.mean(), 'imag_loss': imag_loss.mean()} 17 | 18 | else: 19 | loss = (model_output['model_out']['output'] - gt['img']) ** 2 20 | return {'func_loss': loss.mean()} 21 | 22 | 23 | def multiscale_image_mse(model_output, gt, use_resized=False): 24 | if use_resized: 25 | loss = [(out - gt_img)**2 for out, gt_img in zip(model_output['model_out']['output'], gt['img'])] 26 | else: 27 | loss = [(out - gt['img'])**2 for out in model_output['model_out']['output']] 28 | 29 | loss = torch.stack(loss).mean() 30 | 31 | return {'func_loss': loss} 32 | 33 | 34 | def multiscale_radiance_loss(model_outputs, gt, use_resized=False, weight=1.0, 35 | regularize_sigma=False, reg_lambda=1e-5, reg_c=0.5): 36 | tomo_loss = None 37 | sigma_reg = None 38 | 39 | pred_sigmas = [pred[..., -1:] for pred in model_outputs['combined']['model_out']['output']] 40 | pred_rgbs = [pred[..., :-1] for pred in 
model_outputs['combined']['model_out']['output']] 41 | if isinstance(model_outputs['combined']['model_in']['t_intervals'], list): 42 | t_intervals = [t_interval for t_interval in model_outputs['combined']['model_in']['t_intervals']] 43 | else: 44 | t_intervals = model_outputs['combined']['model_in']['t_intervals'] 45 | 46 | for idx, (pred_sigma, pred_rgb) in enumerate(zip(pred_sigmas, pred_rgbs)): 47 | 48 | if isinstance(t_intervals, list): 49 | t_interval = t_intervals[idx] 50 | else: 51 | t_interval = t_intervals 52 | 53 | # Pass through the forward models 54 | pred_weights = forward_models.compute_transmittance_weights(pred_sigma, t_interval) 55 | pred_pixel_samples = forward_models.compute_tomo_radiance(pred_weights, pred_rgb) 56 | 57 | # Target Ground truth 58 | if use_resized: 59 | target_pixel_samples = gt['pixel_samples'][idx] 60 | else: 61 | target_pixel_samples = gt['pixel_samples'] 62 | 63 | # Loss 64 | if tomo_loss is None: 65 | tomo_loss = (pred_pixel_samples - target_pixel_samples)**2 66 | else: 67 | tomo_loss += (pred_pixel_samples - target_pixel_samples)**2 68 | 69 | if regularize_sigma: 70 | tau = torch.nn.functional.softplus(pred_sigma - 1) 71 | if sigma_reg is None: 72 | sigma_reg = (torch.log(1 + tau**2 / reg_c)) 73 | else: 74 | sigma_reg += (torch.log(1 + tau**2 / reg_c)) 75 | 76 | loss = {'tomo_rad_loss': weight * tomo_loss.mean()} 77 | 78 | if regularize_sigma: 79 | loss['sigma_reg'] = reg_lambda * sigma_reg.mean() 80 | 81 | return loss 82 | 83 | 84 | def radiance_sigma_rgb_loss(model_outputs, gt, regularize_sigma=False, 85 | reg_lambda=1e-5, reg_c=0.5): 86 | pred_sigma = model_outputs['combined']['model_out']['output'][..., -1:] 87 | pred_rgb = model_outputs['combined']['model_out']['output'][..., :-1] 88 | t_intervals = model_outputs['combined']['model_in']['t_intervals'] 89 | 90 | # Pass through the forward models 91 | pred_weights = forward_models.compute_transmittance_weights(pred_sigma, t_intervals) 92 | pred_pixel_samples = forward_models.compute_tomo_radiance(pred_weights, pred_rgb) 93 | 94 | # Target Ground truth 95 | target_pixel_samples = gt['pixel_samples'][..., :3] # rgba -> rgb 96 | 97 | # Loss 98 | tomo_loss = (pred_pixel_samples - target_pixel_samples)**2 99 | 100 | if regularize_sigma: 101 | tau = torch.nn.functional.softplus(pred_sigma - 1) 102 | sigma_reg = (torch.log(1 + tau**2 / reg_c)) 103 | 104 | loss = {'tomo_rad_loss': tomo_loss.mean()} 105 | 106 | if regularize_sigma: 107 | loss['sigma_reg'] = reg_lambda * sigma_reg.mean() 108 | 109 | return loss 110 | 111 | 112 | def overfit_sdf(model_output, gt, coarse_loss_weight=1e-2): 113 | return overfit_sdf_loss_total(model_output, gt, is_multiscale=False, 114 | coarse_loss_weight=coarse_loss_weight) 115 | 116 | 117 | def multiscale_overfit_sdf(model_output, gt, coarse_loss_weight=1e-2): 118 | return overfit_sdf_loss_total(model_output, gt, is_multiscale=True, 119 | coarse_loss_weight=coarse_loss_weight) 120 | 121 | 122 | def overfit_sdf_loss_total(model_output, gt, is_multiscale, lambda_grad=1e-3, 123 | coarse_loss_weight=1e-2): 124 | ''' fit sdf to sphere via mse loss ''' 125 | 126 | gt_sdf = gt['sdf'] 127 | pred_sdf = model_output['model_out'] 128 | 129 | pred_sdf_ = pred_sdf[0] if is_multiscale else pred_sdf 130 | 131 | mse_ = (gt_sdf - pred_sdf_)**2 132 | 133 | if is_multiscale: 134 | for pred_sdf_ in pred_sdf[1:]: 135 | mse_ += (gt_sdf - pred_sdf_)**2 136 | 137 | mse_ = (mse_ / len(pred_sdf)) 138 | mse_[:, ::2] *= coarse_loss_weight 139 | 140 | mse_fine = mse_[:, 1::2].sum() 141 | mse_coarse = 
mse_[:, ::2].sum() 142 | 143 | return {'sdf_fine': mse_fine, 'sdf_coarse': mse_coarse} 144 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import torch 6 | import math 7 | import numpy as np 8 | from functools import partial 9 | from torch import nn 10 | import copy 11 | 12 | 13 | def init_weights_normal(m): 14 | if type(m) == nn.Linear: 15 | if hasattr(m, 'weight'): 16 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_out') 17 | 18 | 19 | def init_weights_selu(m): 20 | if type(m) == nn.Linear: 21 | if hasattr(m, 'weight'): 22 | num_input = m.weight.size(-1) 23 | nn.init.normal_(m.weight, std=1/math.sqrt(num_input)) 24 | 25 | 26 | def init_weights_elu(m): 27 | if type(m) == nn.Linear: 28 | if hasattr(m, 'weight'): 29 | num_input = m.weight.size(-1) 30 | nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277)/math.sqrt(num_input)) 31 | 32 | 33 | def init_weights_xavier(m): 34 | if type(m) == nn.Linear: 35 | if hasattr(m, 'weight'): 36 | nn.init.xavier_normal_(m.weight) 37 | 38 | 39 | def init_weights_uniform(m): 40 | if type(m) == nn.Linear: 41 | if hasattr(m, 'weight'): 42 | torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu') 43 | 44 | 45 | def sine_init(m, w0=30): 46 | with torch.no_grad(): 47 | if hasattr(m, 'weight'): 48 | num_input = m.weight.size(-1) 49 | m.weight.uniform_(-np.sqrt(6/num_input)/w0, np.sqrt(6/num_input)/w0) 50 | 51 | 52 | def first_layer_sine_init(m): 53 | with torch.no_grad(): 54 | if hasattr(m, 'weight'): 55 | num_input = m.weight.size(-1) 56 | m.weight.uniform_(-1/num_input, 1/num_input) 57 | 58 | 59 | class FirstSine(nn.Module): 60 | def __init__(self, w0=20): 61 | super().__init__() 62 | self.w0 = torch.tensor(w0) 63 | 64 | def forward(self, input): 65 | return torch.sin(self.w0*input) 66 | 67 | 68 | class Sine(nn.Module): 69 | def __init__(self, w0=20): 70 | super().__init__() 71 | self.w0 = torch.tensor(w0) 72 | 73 | def forward(self, input): 74 | return torch.sin(self.w0*input) 75 | 76 | 77 | class MSoftplus(nn.Module): 78 | def __init__(self): 79 | super().__init__() 80 | self.softplus = nn.Softplus() 81 | self.cst = torch.log(torch.tensor(2.)) 82 | 83 | def forward(self, input): 84 | return self.softplus(input)-self.cst 85 | 86 | 87 | class Swish(nn.Module): 88 | def __init__(self): 89 | super().__init__() 90 | 91 | def forward(self, input): 92 | return input*torch.sigmoid(input) 93 | 94 | 95 | def mfn_weights_init(m): 96 | with torch.no_grad(): 97 | if hasattr(m, 'weight'): 98 | num_input = m.weight.size(-1) 99 | m.weight.uniform_(-np.sqrt(6/num_input), np.sqrt(6/num_input)) 100 | 101 | 102 | class MFNBase(nn.Module): 103 | 104 | def __init__(self, hidden_size, out_size, n_layers, weight_scale, 105 | bias=True, output_act=False): 106 | super().__init__() 107 | 108 | self.linear = nn.ModuleList( 109 | [nn.Linear(hidden_size, hidden_size, bias) for _ in range(n_layers)] 110 | ) 111 | 112 | self.output_linear = nn.Linear(hidden_size, out_size) 113 | 114 | self.output_act = output_act 115 | 116 | self.linear.apply(mfn_weights_init) 117 | self.output_linear.apply(mfn_weights_init) 118 | 119 | def forward(self, model_input): 120 | 121 | input_dict = {key: input.clone().detach().requires_grad_(True) 122 | for key, input in model_input.items()} 123 | coords = input_dict['coords'] 124 | 
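# Multiplicative filter network recurrence: the first filter initializes the
# hidden state from the raw coordinates, and each subsequent layer multiplies a
# filter of the coordinates (a sinusoidal FourierLayer in BACON) with a linear
# transform of the previous hidden state,
#   z_1 = g_1(x),   z_i = g_i(x) * (W_{i-1} z_{i-1} + b_{i-1}),
# before a final linear layer maps the last hidden state to the output.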
125 | out = self.filters[0](coords) 126 | for i in range(1, len(self.filters)): 127 | out = self.filters[i](coords) * self.linear[i - 1](out) 128 | out = self.output_linear(out) 129 | 130 | if self.output_act: 131 | out = torch.sin(out) 132 | 133 | return {'model_in': input_dict, 'model_out': {'output': out}} 134 | 135 | 136 | class FourierLayer(nn.Module): 137 | 138 | def __init__(self, in_features, out_features, weight_scale, quantization_interval=2*np.pi): 139 | super().__init__() 140 | self.linear = nn.Linear(in_features, out_features) 141 | 142 | r = 2*weight_scale[0] / quantization_interval 143 | assert math.isclose(r, round(r)), \ 144 | 'weight_scale should be divisible by quantization interval' 145 | 146 | # sample discrete uniform distribution of frequencies 147 | for i in range(self.linear.weight.data.shape[1]): 148 | init = torch.randint_like(self.linear.weight.data[:, i], 149 | 0, int(2*weight_scale[i] / quantization_interval)+1) 150 | init = init * quantization_interval - weight_scale[i] 151 | self.linear.weight.data[:, i] = init 152 | 153 | self.linear.weight.requires_grad = False 154 | self.linear.bias.data.uniform_(-np.pi, np.pi) 155 | return 156 | 157 | def forward(self, x): 158 | return torch.sin(self.linear(x)) 159 | 160 | 161 | class BACON(MFNBase): 162 | def __init__(self, 163 | in_size, 164 | hidden_size, 165 | out_size, 166 | hidden_layers=3, 167 | weight_scale=1.0, 168 | bias=True, 169 | output_act=False, 170 | frequency=(128, 128), 171 | quantization_interval=2*np.pi, # assumes data range [-.5,.5] 172 | centered=True, 173 | input_scales=None, 174 | output_layers=None, 175 | is_sdf=False, 176 | reuse_filters=False, 177 | **kwargs): 178 | 179 | super().__init__(hidden_size, out_size, hidden_layers, 180 | weight_scale, bias, output_act) 181 | 182 | self.quantization_interval = quantization_interval 183 | self.hidden_layers = hidden_layers 184 | self.hidden_size = hidden_size 185 | self.centered = centered 186 | self.frequency = frequency 187 | self.is_sdf = is_sdf 188 | self.reuse_filters = reuse_filters 189 | self.in_size = in_size 190 | 191 | # we need to multiply by this to be able to fit the signal 192 | input_scale = [round((np.pi * freq / (hidden_layers + 1)) 193 | / quantization_interval) * quantization_interval for freq in frequency] 194 | 195 | self.filters = nn.ModuleList([ 196 | FourierLayer(in_size, hidden_size, input_scale, 197 | quantization_interval=quantization_interval) 198 | for i in range(hidden_layers + 1)]) 199 | 200 | print(self) 201 | 202 | def forward_mfn(self, input_dict): 203 | if 'coords' in input_dict: 204 | coords = input_dict['coords'] 205 | elif 'ray_samples' in input_dict: 206 | if self.in_size > 3: 207 | coords = torch.cat((input_dict['ray_samples'], input_dict['ray_orientations']), dim=-1) 208 | else: 209 | coords = input_dict['ray_samples'] 210 | 211 | if self.reuse_filters: 212 | filter_outputs = 3 * [self.filters[2](coords), ] + \ 213 | 2 * [self.filters[4](coords), ] + \ 214 | 2 * [self.filters[6](coords), ] + \ 215 | 2 * [self.filters[8](coords), ] 216 | 217 | out = filter_outputs[0] 218 | for i in range(1, len(self.filters)): 219 | out = filter_outputs[i] * self.linear[i - 1](out) 220 | else: 221 | out = self.filters[0](coords) 222 | for i in range(1, len(self.filters)): 223 | out = self.filters[i](coords) * self.linear[i - 1](out) 224 | 225 | out = self.output_linear(out) 226 | 227 | if self.output_act: 228 | out = torch.sin(out) 229 | 230 | return out 231 | 232 | def forward(self, model_input, mode=None, integral_dim=None): 
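# Evaluate the band-limited MFN on the input coordinates. For SDF fitting the raw
# output tensor is returned directly under 'model_out'; otherwise the output is
# wrapped in a nested dict. The mode and integral_dim arguments are unused here.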
233 | 234 | out = {'output': self.forward_mfn(model_input)} 235 | 236 | if self.is_sdf: 237 | return {'model_in': model_input['coords'], 238 | 'model_out': out['output']} 239 | 240 | return {'model_in': model_input, 'model_out': out} 241 | 242 | 243 | class MultiscaleBACON(MFNBase): 244 | def __init__(self, 245 | in_size, 246 | hidden_size, 247 | out_size, 248 | hidden_layers=3, 249 | weight_scale=1.0, 250 | bias=True, 251 | output_act=False, 252 | frequency=(128, 128), 253 | quantization_interval=2*np.pi, 254 | centered=True, 255 | is_sdf=False, 256 | input_scales=None, 257 | output_layers=None, 258 | reuse_filters=False): 259 | 260 | super().__init__(hidden_size, out_size, hidden_layers, 261 | weight_scale, bias, output_act) 262 | 263 | self.quantization_interval = quantization_interval 264 | self.hidden_layers = hidden_layers 265 | self.centered = centered 266 | self.is_sdf = is_sdf 267 | self.frequency = frequency 268 | self.output_layers = output_layers 269 | self.reuse_filters = reuse_filters 270 | self.stop_after = None 271 | 272 | # we need to multiply by this to be able to fit the signal 273 | if input_scales is None: 274 | input_scale = [round((np.pi * freq / (hidden_layers + 1)) 275 | / quantization_interval) * quantization_interval for freq in frequency] 276 | 277 | self.filters = nn.ModuleList([ 278 | FourierLayer(in_size, hidden_size, input_scale, 279 | quantization_interval=quantization_interval) 280 | for i in range(hidden_layers + 1)]) 281 | else: 282 | if len(input_scales) != hidden_layers+1: 283 | raise ValueError('require n+1 scales for n hidden_layers') 284 | input_scale = [[round((np.pi * freq * scale) / quantization_interval) * quantization_interval 285 | for freq in frequency] for scale in input_scales] 286 | 287 | self.filters = nn.ModuleList([ 288 | FourierLayer(in_size, hidden_size, input_scale[i], 289 | quantization_interval=quantization_interval) 290 | for i in range(hidden_layers + 1)]) 291 | 292 | # linear layers to extract intermediate outputs 293 | self.output_linear = nn.ModuleList([nn.Linear(hidden_size, out_size) for i in range(len(self.filters))]) 294 | self.output_linear.apply(mfn_weights_init) 295 | 296 | # if outputs layers is None, output at every possible layer 297 | if self.output_layers is None: 298 | self.output_layers = np.arange(1, len(self.filters)) 299 | 300 | print(self) 301 | 302 | def layer_forward(self, coords, filter_outputs, specified_layers, 303 | get_feature, continue_layer, continue_feature): 304 | """ for multiscale SDF extraction """ 305 | 306 | # hardcode the 8 layer network that we use for all sdf experiments 307 | filter_ind_dict = [2, 2, 2, 4, 4, 6, 6, 8, 8] 308 | outputs = [] 309 | 310 | if continue_feature is None: 311 | assert(continue_layer == 0) 312 | out = self.filters[filter_ind_dict[0]](coords) 313 | filter_output_dict = {filter_ind_dict[0]: out} 314 | else: 315 | out = continue_feature 316 | filter_output_dict = {} 317 | 318 | for i in range(continue_layer+1, len(self.filters)): 319 | if filter_ind_dict[i] not in filter_output_dict.keys(): 320 | filter_output_dict[filter_ind_dict[i]] = self.filters[filter_ind_dict[i]](coords) 321 | out = filter_output_dict[filter_ind_dict[i]] * self.linear[i - 1](out) 322 | 323 | if i in self.output_layers and i == specified_layers: 324 | if get_feature: 325 | outputs.append([self.output_linear[i](out), out]) 326 | else: 327 | outputs.append(self.output_linear[i](out)) 328 | return outputs 329 | 330 | return outputs 331 | 332 | def forward(self, model_input, specified_layers=None, 
get_feature=False, 333 | continue_layer=0, continue_feature=None): 334 | 335 | if self.is_sdf: 336 | model_input = {key: input.clone().detach().requires_grad_(True) 337 | for key, input in model_input.items()} 338 | 339 | if 'coords' in model_input: 340 | coords = model_input['coords'] 341 | elif 'ray_samples' in model_input: 342 | coords = model_input['ray_samples'] 343 | 344 | outputs = [] 345 | if self.reuse_filters: 346 | 347 | # which layers to reuse 348 | if len(self.filters) < 9: 349 | filter_outputs = 2 * [self.filters[0](coords), ] + \ 350 | (len(self.filters)-2) * [self.filters[-1](coords), ] 351 | else: 352 | filter_outputs = 3 * [self.filters[2](coords), ] + \ 353 | 2 * [self.filters[4](coords), ] + \ 354 | 2 * [self.filters[6](coords), ] + \ 355 | 2 * [self.filters[8](coords), ] 356 | 357 | # multiscale sdf extractions (evaluate only some layers) 358 | if specified_layers is not None: 359 | outputs = self.layer_forward(coords, filter_outputs, specified_layers, 360 | get_feature, continue_layer, continue_feature) 361 | 362 | # evaluate all layers 363 | else: 364 | out = filter_outputs[0] 365 | for i in range(1, len(self.filters)): 366 | out = filter_outputs[i] * self.linear[i - 1](out) 367 | 368 | if i in self.output_layers: 369 | outputs.append(self.output_linear[i](out)) 370 | if self.stop_after is not None and len(outputs) > self.stop_after: 371 | break 372 | 373 | # no layer reuse 374 | else: 375 | out = self.filters[0](coords) 376 | for i in range(1, len(self.filters)): 377 | out = self.filters[i](coords) * self.linear[i - 1](out) 378 | 379 | if i in self.output_layers: 380 | outputs.append(self.output_linear[i](out)) 381 | if self.stop_after is not None and len(outputs) > self.stop_after: 382 | break 383 | 384 | if self.is_sdf: # convert dtype 385 | return {'model_in': model_input['coords'], 386 | 'model_out': outputs} # outputs is a list of tensors 387 | 388 | return {'model_in': model_input, 'model_out': {'output': outputs}} 389 | 390 | 391 | class MultiscaleCoordinateNet(nn.Module): 392 | '''A canonical coordinate network''' 393 | def __init__(self, out_features=1, nl='sine', in_features=1, 394 | hidden_features=256, num_hidden_layers=3, 395 | w0=30, pe_scale=6, use_sigmoid=True, no_pe=False, 396 | integrated_pe=False): 397 | 398 | super().__init__() 399 | 400 | self.nl = nl 401 | dims = in_features 402 | self.use_sigmoid = use_sigmoid 403 | self.no_pe = no_pe 404 | self.integrated_pe = integrated_pe 405 | 406 | if integrated_pe: 407 | self.pe = partial(IntegratedPositionalEncoding, L=pe_scale) 408 | in_features = int(2 * in_features * pe_scale) 409 | 410 | elif self.nl != 'sine' and not self.no_pe: 411 | in_features = in_features * hidden_features 412 | self.pe = FFPositionalEncoding(hidden_features, pe_scale, dims=dims) 413 | 414 | self.net = FCBlock(in_features=in_features, 415 | out_features=out_features, 416 | num_hidden_layers=num_hidden_layers, 417 | hidden_features=hidden_features, 418 | outermost_linear=True, 419 | nonlinearity=nl, 420 | w0=w0).net 421 | 422 | if not integrated_pe: 423 | self.output_linear = nn.ModuleList([nn.Linear(hidden_features, out_features) 424 | for i in range(num_hidden_layers)]) 425 | self.net = self.net[:-1] 426 | 427 | print(self) 428 | 429 | def net_forward(self, coords): 430 | outputs = [] 431 | 432 | if self.use_sigmoid and self.nl != 'sine': 433 | def out_nl(x): return torch.sigmoid(x) 434 | else: 435 | def out_nl(x): return x 436 | 437 | # mipnerf baseline 438 | if self.integrated_pe: 439 | for c in coords: 440 | 
outputs.append(out_nl(self.net(c))) 441 | else: 442 | out = self.net[0](coords) 443 | outputs.append(self.output_linear[0](out)) 444 | for i, n in enumerate(self.net[1:]): 445 | # run main branch 446 | out = n(out) 447 | 448 | # extract intermediate output 449 | outputs.append(out_nl(self.output_linear[i](out))) 450 | 451 | return outputs 452 | 453 | def forward(self, model_input): 454 | 455 | coords = model_input['coords'] 456 | 457 | if self.integrated_pe: 458 | coords = [self.pe(coords, r) for r in model_input['radii']] 459 | elif self.nl != 'sine' and not self.no_pe: 460 | coords = self.pe(coords) 461 | 462 | output = self.net_forward(coords) 463 | 464 | return {'model_in': model_input, 'model_out': {'output': output}} 465 | 466 | 467 | class CoordinateNet(nn.Module): 468 | '''A canonical coordinate network''' 469 | def __init__(self, out_features=1, nl='sine', in_features=1, 470 | hidden_features=256, num_hidden_layers=3, 471 | w0=30, pe_scale=5, use_sigmoid=True, no_pe=False, 472 | is_sdf=False, **kwargs): 473 | 474 | super().__init__() 475 | 476 | self.nl = nl 477 | dims = in_features 478 | self.use_sigmoid = use_sigmoid 479 | self.no_pe = no_pe 480 | self.is_sdf = is_sdf 481 | 482 | if self.nl != 'sine' and not self.no_pe: 483 | in_features = hidden_features # in_features * hidden_features 484 | 485 | self.pe = FFPositionalEncoding(hidden_features, pe_scale, dims=dims) 486 | 487 | self.net = FCBlock(in_features=in_features, 488 | out_features=out_features, 489 | num_hidden_layers=num_hidden_layers, 490 | hidden_features=hidden_features, 491 | outermost_linear=True, 492 | nonlinearity=nl, 493 | w0=w0) 494 | print(self) 495 | 496 | def forward(self, model_input): 497 | 498 | coords = model_input['coords'] 499 | 500 | if self.nl != 'sine' and not self.no_pe: 501 | coords_pe = self.pe(coords) 502 | output = self.net(coords_pe) 503 | if self.use_sigmoid: 504 | output = torch.sigmoid(output) 505 | else: 506 | output = self.net(coords) 507 | 508 | if self.is_sdf: 509 | return {'model_in': model_input, 'model_out': output} 510 | 511 | else: 512 | return {'model_in': model_input, 'model_out': {'output': output}} 513 | 514 | 515 | def IntegratedPositionalEncoding(coords, radius, L=8): 516 | 517 | # adapted from mipnerf https://github.com/google/mipnerf 518 | def expected_sin(x, x_var): 519 | """Estimates mean and variance of sin(z), z ~ N(x, var).""" 520 | 521 | # When the variance is wide, shrink sin towards zero. 
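# For z ~ N(x, x_var) the moments have a closed form:
#   E[sin z]   = exp(-x_var / 2) * sin(x)
#   Var[sin z] = 0.5 * (1 - exp(-2 * x_var) * cos(2 * x)) - E[sin z]^2,
# clipped at zero below to guard against small negative values from round-off.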
522 | y = torch.exp(-0.5 * x_var) * torch.sin(x) 523 | y_var = torch.clip(0.5 * (1 - torch.exp(-2 * x_var) * torch.cos(2 * x)) - y**2, 0) 524 | return y, y_var 525 | 526 | def integrated_pos_enc(x_coord, min_deg, max_deg): 527 | """Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].""" 528 | 529 | x, x_cov_diag = x_coord 530 | scales = torch.tensor([2**i for i in range(int(min_deg), int(max_deg))], device=x.device) 531 | shape = list(x.shape[:-1]) + [-1] 532 | 533 | y = torch.reshape(x[..., None, :] * scales[:, None], shape) 534 | y_var = torch.reshape(x_cov_diag[..., None, :] * scales[:, None]**2, shape) 535 | 536 | return expected_sin( 537 | torch.cat([y, y + 0.5 * np.pi], dim=-1), 538 | torch.cat([y_var] * 2, dim=-1))[0] 539 | 540 | means = coords 541 | covs = (radius**2 / 4) * torch.ones((1, 2), device=coords.device).repeat(coords.shape[-2], 1) 542 | return integrated_pos_enc((means, covs), 0, L) 543 | 544 | 545 | class FFPositionalEncoding(nn.Module): 546 | def __init__(self, embedding_size, scale, dims=2, gaussian=True): 547 | super().__init__() 548 | self.embedding_size = embedding_size 549 | self.scale = scale 550 | 551 | if gaussian: 552 | bvals = torch.randn(embedding_size // 2, dims) * scale 553 | else: 554 | bvals = 2.**torch.linspace(0, scale, embedding_size//2) - 1 555 | 556 | if dims == 1: 557 | bvals = bvals[:, None] 558 | 559 | elif dims == 2: 560 | bvals = torch.stack([bvals, torch.zeros_like(bvals)], dim=-1) 561 | bvals = torch.cat([bvals, torch.roll(bvals, 1, -1)], dim=0) 562 | 563 | else: 564 | tmp = (dims-1)*(torch.zeros_like(bvals),) 565 | bvals = torch.stack([bvals, *tmp], dim=-1) 566 | 567 | tmp = [torch.roll(bvals, i, -1) for i in range(1, dims)] 568 | bvals = torch.cat([bvals, *tmp], dim=0) 569 | 570 | avals = torch.ones((bvals.shape[0])) 571 | self.avals = nn.Parameter(avals, requires_grad=False) 572 | self.bvals = nn.Parameter(bvals, requires_grad=False) 573 | 574 | def forward(self, tensor) -> torch.Tensor: 575 | """ 576 | Apply positional encoding to the input. 577 | """ 578 | 579 | return torch.cat([self.avals * torch.sin((2.*np.pi*tensor) @ self.bvals.T), 580 | self.avals * torch.cos((2.*np.pi*tensor) @ self.bvals.T)], dim=-1) 581 | 582 | 583 | class PositionalEncoding(nn.Module): 584 | def __init__(self, num_encoding_functions=6, include_input=True, log_sampling=True, normalize=False, 585 | input_dim=3, gaussian_pe=False, gaussian_variance=38): 586 | super().__init__() 587 | self.num_encoding_functions = num_encoding_functions 588 | self.include_input = include_input 589 | self.log_sampling = log_sampling 590 | self.normalize = normalize 591 | self.gaussian_pe = gaussian_pe 592 | self.normalization = None 593 | 594 | if self.gaussian_pe: 595 | # this needs to be registered as a parameter so that it is saved in the model state dict 596 | # and so that it is converted using .cuda(). 
Doesn't need to be trained though 597 | self.gaussian_weights = nn.Parameter(gaussian_variance * torch.randn(num_encoding_functions, input_dim), 598 | requires_grad=False) 599 | 600 | else: 601 | self.frequency_bands = None 602 | if self.log_sampling: 603 | self.frequency_bands = 2.0 ** torch.linspace( 604 | 0.0, 605 | self.num_encoding_functions - 1, 606 | self.num_encoding_functions) 607 | else: 608 | self.frequency_bands = torch.linspace( 609 | 2.0 ** 0.0, 610 | 2.0 ** (self.num_encoding_functions - 1), 611 | self.num_encoding_functions) 612 | 613 | if normalize: 614 | self.normalization = torch.tensor(1/self.frequency_bands) 615 | 616 | def forward(self, tensor) -> torch.Tensor: 617 | r"""Apply positional encoding to the input. 618 | 619 | Args: 620 | tensor (torch.Tensor): Input tensor to be positionally encoded. 621 | encoding_size (optional, int): Number of encoding functions used to compute 622 | a positional encoding (default: 6). 623 | include_input (optional, bool): Whether or not to include the input in the 624 | positional encoding (default: True). 625 | 626 | Returns: 627 | (torch.Tensor): Positional encoding of the input tensor. 628 | """ 629 | 630 | encoding = [tensor] if self.include_input else [] 631 | if self.gaussian_pe: 632 | for func in [torch.sin, torch.cos]: 633 | encoding.append(func(torch.matmul(tensor, self.gaussian_weights.T))) 634 | else: 635 | for idx, freq in enumerate(self.frequency_bands): 636 | for func in [torch.sin, torch.cos]: 637 | if self.normalization is not None: 638 | encoding.append(self.normalization[idx]*func(tensor * freq)) 639 | else: 640 | encoding.append(func(tensor * freq)) 641 | 642 | # Special case, for no positional encoding 643 | if len(encoding) == 1: 644 | return encoding[0] 645 | else: 646 | return torch.cat(encoding, dim=-1) 647 | 648 | 649 | def layer_factory(layer_type, w0=30): 650 | layer_dict = \ 651 | { 652 | 'relu': (nn.ReLU(inplace=True), init_weights_uniform), 653 | 'sigmoid': (nn.Sigmoid(), None), 654 | 'fsine': (Sine(), first_layer_sine_init), 655 | 'sine': (Sine(w0=w0), partial(sine_init, w0=w0)), 656 | 'tanh': (nn.Tanh(), init_weights_xavier), 657 | 'selu': (nn.SELU(inplace=True), init_weights_selu), 658 | 'gelu': (nn.GELU(), init_weights_selu), 659 | 'swish': (Swish(), init_weights_selu), 660 | 'softplus': (nn.Softplus(), init_weights_normal), 661 | 'msoftplus': (MSoftplus(), init_weights_normal), 662 | 'elu': (nn.ELU(), init_weights_elu) 663 | } 664 | return layer_dict[layer_type] 665 | 666 | 667 | class FCBlock(nn.Module): 668 | '''A fully connected neural network that also allows swapping out the weights when used with a hypernetwork. 669 | Can be used just as a normal neural network though, as well. 
670 | ''' 671 | def __init__(self, in_features, out_features, 672 | num_hidden_layers, hidden_features, 673 | outermost_linear=False, nonlinearity='relu', 674 | weight_init=None, w0=30, set_bias=None, 675 | dropout=0.0): 676 | super().__init__() 677 | 678 | self.first_layer_init = None 679 | self.dropout = dropout 680 | 681 | # Create hidden features list 682 | if not isinstance(hidden_features, list): 683 | num_hidden_features = hidden_features 684 | hidden_features = [] 685 | for i in range(num_hidden_layers+1): 686 | hidden_features.append(num_hidden_features) 687 | else: 688 | num_hidden_layers = len(hidden_features)-1 689 | print(f"net_size={hidden_features}") 690 | 691 | # Create the net 692 | print(f"num_layers={len(hidden_features)}") 693 | if isinstance(nonlinearity, list): 694 | print(f"num_non_lin={len(nonlinearity)}") 695 | assert len(hidden_features) == len(nonlinearity), "Num hidden layers needs to " \ 696 | "match the length of the list of non-linearities" 697 | 698 | self.net = [] 699 | self.net.append(nn.Sequential( 700 | nn.Linear(in_features, hidden_features[0]), 701 | layer_factory(nonlinearity[0])[0] 702 | )) 703 | for i in range(num_hidden_layers): 704 | self.net.append(nn.Sequential( 705 | nn.Linear(hidden_features[i], hidden_features[i+1]), 706 | layer_factory(nonlinearity[i+1])[0] 707 | )) 708 | 709 | if outermost_linear: 710 | self.net.append(nn.Sequential( 711 | nn.Linear(hidden_features[-1], out_features), 712 | )) 713 | else: 714 | self.net.append(nn.Sequential( 715 | nn.Linear(hidden_features[-1], out_features), 716 | layer_factory(nonlinearity[-1])[0] 717 | )) 718 | elif isinstance(nonlinearity, str): 719 | nl, weight_init = layer_factory(nonlinearity, w0=w0) 720 | if(nonlinearity == 'sine'): 721 | first_nl = FirstSine(w0=w0) 722 | self.first_layer_init = first_layer_sine_init 723 | else: 724 | first_nl = nl 725 | 726 | if weight_init is not None: 727 | self.weight_init = weight_init 728 | 729 | self.net = [] 730 | self.net.append(nn.Sequential( 731 | nn.Linear(in_features, hidden_features[0]), 732 | first_nl 733 | )) 734 | 735 | for i in range(num_hidden_layers): 736 | if(self.dropout > 0): 737 | self.net.append(nn.Dropout(self.dropout)) 738 | self.net.append(nn.Sequential( 739 | nn.Linear(hidden_features[i], hidden_features[i+1]), 740 | copy.deepcopy(nl) 741 | )) 742 | 743 | if (self.dropout > 0): 744 | self.net.append(nn.Dropout(self.dropout)) 745 | if outermost_linear: 746 | self.net.append(nn.Sequential( 747 | nn.Linear(hidden_features[-1], out_features), 748 | )) 749 | else: 750 | self.net.append(nn.Sequential( 751 | nn.Linear(hidden_features[-1], out_features), 752 | copy.deepcopy(nl) 753 | )) 754 | 755 | self.net = nn.Sequential(*self.net) 756 | 757 | if isinstance(nonlinearity, list): 758 | for layer_num, layer_name in enumerate(nonlinearity): 759 | self.net[layer_num].apply(layer_factory(layer_name, w0=w0)[1]) 760 | elif isinstance(nonlinearity, str): 761 | if self.weight_init is not None: 762 | self.net.apply(self.weight_init) 763 | 764 | if self.first_layer_init is not None: 765 | self.net[0].apply(self.first_layer_init) 766 | 767 | if set_bias is not None: 768 | self.net[-1][0].bias.data = set_bias * torch.ones_like(self.net[-1][0].bias.data) 769 | 770 | def forward(self, coords): 771 | output = self.net(coords) 772 | return output 773 | 774 | 775 | class RadianceNet(nn.Module): 776 | def __init__(self, in_features=2, out_features=1, 777 | hidden_features=256, num_hidden_layers=3, w0=30, 778 | input_pe_params=[('ray_samples', 3, 10), 
('ray_orientations', 6, 4)], 779 | nl='relu'): 780 | 781 | super().__init__() 782 | self.input_pe_params = input_pe_params 783 | self.nl = nl 784 | 785 | self.positional_encoding_fn = {} 786 | if nl != 'sine': 787 | for input_to_encode, input_dim, num_pe_fns in self.input_pe_params: 788 | self.positional_encoding_fn[input_to_encode] = PositionalEncoding(num_encoding_functions=num_pe_fns, 789 | input_dim=input_dim) 790 | 791 | self.net = FCBlock(in_features=in_features, out_features=out_features, num_hidden_layers=num_hidden_layers, 792 | hidden_features=hidden_features, outermost_linear=True, nonlinearity=nl, w0=w0) 793 | 794 | print(self) 795 | 796 | def forward(self, model_input): 797 | 798 | input_dict = {key: input.clone().detach().requires_grad_(True) 799 | for key, input in model_input.items()} 800 | 801 | if self.nl != 'sine': 802 | for input_to_encode, input_dim, num_pe_fns in self.input_pe_params: 803 | encoded_input = self.positional_encoding_fn[input_to_encode](input_dict[input_to_encode]) 804 | input_dict.update({input_to_encode: encoded_input}) 805 | 806 | input_list = [] 807 | for input_name, _, _ in self.input_pe_params: 808 | input_list.append(input_dict[input_name]) 809 | 810 | coords = torch.cat(input_list, dim=-1) 811 | 812 | if coords.ndim == 2: 813 | coords = coords[None, :, :] 814 | 815 | output = self.net(coords) 816 | 817 | output_dict = {'output': output} 818 | return {'model_in': input_dict, 'model_out': output_dict} 819 | -------------------------------------------------------------------------------- /spectrum_visualization/README.md: -------------------------------------------------------------------------------- 1 | A few folks have asked how the shape spectra are visualized so here is some demo code that should allow you to reproduce those figures for the armadillo scene. 
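For orientation, the figures are maximum-intensity renderings of the 3D Fourier magnitude of the predicted SDF volume. A minimal sketch of that step, assuming `sdf` is the N×N×N numpy array of SDF values produced by the model (the name is illustrative, not part of the script's API):

```
import numpy as np

# centered magnitude spectrum of the SDF volume
sdf_ft = np.abs(np.fft.fftshift(np.fft.fftn(sdf)))

# normalize and compress the dynamic range for rendering
sdf_ft = np.clip(sdf_ft / np.max(sdf_ft) * 1000, 0, 1) ** (1 / 3)
```

`get_shape_spectra.py` then writes these volumes to `.mrc` files and renders them with ChimeraX.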
2 | 3 | Prerequisites: 4 | - create a conda environment from the environment.yml file 5 | - install imagemagick and convert utility https://imagemagick.org/script/convert.php 6 | - install chimerax from https://www.cgl.ucsf.edu/chimerax/ and make sure the executible is on the path as 'chimerax' 7 | 8 | Then, run `get_shape_spectra.py` which will run the model to extract the `armadillo` scene mesh and then visualize the fourier transform 9 | -------------------------------------------------------------------------------- /spectrum_visualization/environment.yml: -------------------------------------------------------------------------------- 1 | name: bacon 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - numpy= 7 | - python=3.9 8 | - pytorch 9 | - tensorboard 10 | - torchvision 11 | - pip 12 | - pip: 13 | - argparse==1.4.0 14 | - configargparse==1.5.3 15 | - gdown==4.2.0 16 | - matplotlib==3.4.3 17 | - pymcubes==0.1.2 18 | - scikit-image 19 | - scipy 20 | - tqdm 21 | - trimesh 22 | - mrcfile 23 | - pykdtree 24 | -------------------------------------------------------------------------------- /spectrum_visualization/get_shape_spectra.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | 6 | import modules 7 | import torch 8 | import numpy as np 9 | from tqdm import tqdm 10 | from pykdtree.kdtree import KDTree 11 | import mrcfile 12 | 13 | device = torch.device('mps') 14 | 15 | 16 | class HiddenPrints: 17 | def __enter__(self): 18 | self._original_stdout = sys.stdout 19 | sys.stdout = open(os.devnull, 'w') 20 | 21 | def __exit__(self, exc_type, exc_val, exc_tb): 22 | sys.stdout.close() 23 | sys.stdout = self._original_stdout 24 | 25 | 26 | def export_model(ckpt_path, model_name, N=512, model_type='bacon', hidden_layers=8, 27 | hidden_size=256, output_layers=[1, 2, 4, 8], w0=30, pe=8, 28 | filter_mesh=False, scaling=None, return_sdf=False): 29 | 30 | with HiddenPrints(): 31 | # the network has 4 output levels of detail 32 | num_outputs = len(output_layers) 33 | max_frequency = 3*(32,) 34 | 35 | # load model 36 | if len(output_layers) > 1: 37 | model = modules.MultiscaleBACON(3, hidden_size, 1, 38 | hidden_layers=hidden_layers, 39 | bias=True, 40 | frequency=max_frequency, 41 | quantization_interval=np.pi, 42 | is_sdf=True, 43 | output_layers=output_layers, 44 | reuse_filters=True) 45 | 46 | print(model) 47 | ckpt = torch.load(ckpt_path, map_location=device) 48 | model.load_state_dict(ckpt) 49 | model = model.to(device) 50 | 51 | # write output 52 | x = torch.linspace(-0.5, 0.5, N) 53 | if return_sdf: 54 | x = torch.arange(-N//2, N//2) / N 55 | x = x.float() 56 | x, y, z = torch.meshgrid(x, x, x) 57 | render_coords = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=-1).to(device) 58 | sdf_values = [np.zeros((N**3, 1)) for i in range(num_outputs)] 59 | 60 | # render in a batched fashion to save memory 61 | bsize = int(128**2) 62 | for i in tqdm(range(int(N**3 / bsize))): 63 | coords = render_coords[i*bsize:(i+1)*bsize, :] 64 | out = model({'coords': coords})['model_out'] 65 | 66 | if not isinstance(out, list): 67 | out = [out,] 68 | 69 | for idx, sdf in enumerate(out): 70 | sdf_values[idx][i*bsize:(i+1)*bsize] = sdf.detach().cpu().numpy() 71 | 72 | return [sdf.reshape(N, N, N) for sdf in sdf_values] 73 | 74 | 75 | def normalize(coords, scaling=0.9): 76 | coords = np.array(coords).copy() 77 | cmean = np.mean(coords, axis=0, 
keepdims=True) 78 | coords -= cmean 79 | coord_max = np.amax(coords) 80 | coord_min = np.amin(coords) 81 | coords = (coords - coord_min) / (coord_max - coord_min) 82 | coords -= 0.5 83 | coords *= scaling 84 | 85 | scale = scaling / (coord_max - coord_min) 86 | offset = -scaling * (cmean + coord_min) / (coord_max - coord_min) - 0.5*scaling 87 | return coords, scale, offset 88 | 89 | 90 | def get_ref_spectrum(xyz_file, N): 91 | pointcloud = np.genfromtxt(xyz_file) 92 | v = pointcloud[:, :3] 93 | n = pointcloud[:, 3:] 94 | 95 | n = n / (np.linalg.norm(n, axis=-1)[:, None]) 96 | v, _, _ = normalize(v) 97 | print('loaded pc') 98 | 99 | # put pointcloud points into KDTree 100 | kd_tree = KDTree(v) 101 | print('made kd tree') 102 | 103 | # get sdf on grid and show 104 | x = (np.arange(-N//2, N//2) / N).astype(np.float32) 105 | coords = np.stack([arr.flatten() for arr in np.meshgrid(x, x, x)], axis=-1) 106 | 107 | sdf, idx = kd_tree.query(coords, k=3) 108 | 109 | # get average normal of hit point 110 | avg_normal = np.mean(n[idx], axis=1) 111 | sdf = np.sum((coords - v[idx][:, 0]) * avg_normal, axis=-1) 112 | sdf = sdf.reshape(N, N, N) 113 | return [sdf, ] 114 | 115 | 116 | def extract_spectrum(): 117 | 118 | scenes = ['armadillo'] 119 | Ns = [384, 384, 384, 512, 512] 120 | methods = ['bacon', 'ref'] 121 | 122 | for method in methods: 123 | for scene, N in zip(scenes, Ns): 124 | if method == 'ref': 125 | sdfs = get_ref_spectrum(f'ref_{scene}.xyz', N) 126 | else: 127 | 128 | ckpt = 'model_final.pth' 129 | 130 | sdfs = export_model(ckpt, scene, model_type=method, output_layers=[2, 4, 6, 8], 131 | return_sdf=True, N=N, pe=8, w0=30) 132 | 133 | sdfs_ft = [np.abs(np.fft.fftshift(np.fft.fftn(sdf))) for sdf in sdfs] 134 | sdfs_ft = [sdf_ft / np.max(sdf_ft)*1000 for sdf_ft in sdfs_ft] 135 | sdfs_ft = [np.clip(sdf_ft, 0, 1)**(1/3) for sdf_ft in sdfs_ft] 136 | 137 | for idx, sdf_ft in enumerate(sdfs_ft): 138 | with mrcfile.new_mmap(f'/tmp/sdf_ft_{idx}.mrc', overwrite=True, shape=(N, N, N), mrc_mode=2) as mrc: 139 | mrc.data[:] = sdf_ft 140 | 141 | # render with chimera 142 | with open('/tmp/render.cxc', 'w') as f: 143 | for i in range(len(sdfs_ft)): 144 | f.write(f'open /tmp/sdf_ft_{i}.mrc\n') 145 | f.write('volume #1 style solid level 0,0 level 1,1\n') 146 | f.write('volume #1 maximumIntensityProjection true\n') 147 | f.write('volume #!1 showOutlineBox true\n') 148 | f.write('volume #1 outlineBoxRgb slate gray\n') 149 | f.write('volume #1 step 1\n') 150 | f.write('view matrix camera 0.91721,0.10246,-0.38499,-533.37,-0.010261,0.97212,0.23427,631.78,0.39826,-0.21092,0.89269,1870.6\n') 151 | f.write('view\n') 152 | f.write(f'save /tmp/{method}_{scene}_{i+1}.png width 512 height 512 transparentBackground false\n') 153 | f.write('close #1\n') 154 | f.write('exit\n') 155 | os.system('chimerax /tmp/render.cxc') 156 | 157 | for idx in range(len(sdfs_ft)): 158 | fname = f'{method}_{scene}_{idx+1}.png' 159 | os.system(f'./magicwand 1,1 -t 9 -r outside -m overlay -o 0 /tmp/{fname} {fname}') 160 | os.system(f'convert {fname} -trim +repage {fname}') 161 | os.remove(f'/tmp/{fname}') 162 | 163 | # clean up 164 | for idx in range(len(sdfs_ft)): 165 | os.remove(f'/tmp/sdf_ft_{idx}.mrc') 166 | os.remove('/tmp/render.cxc') 167 | 168 | 169 | if __name__ == '__main__': 170 | extract_spectrum() 171 | -------------------------------------------------------------------------------- /spectrum_visualization/magicwand: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # 
Developed by Fred Weinhaus 11/2/2007 .......... revised 4/30/2015 4 | # 5 | # ------------------------------------------------------------------------------ 6 | # 7 | # Licensing: 8 | # 9 | # Copyright © Fred Weinhaus 10 | # 11 | # My scripts are available free of charge for non-commercial use, ONLY. 12 | # 13 | # For use of my scripts in commercial (for-profit) environments or 14 | # non-free applications, please contact me (Fred Weinhaus) for 15 | # licensing arrangements. My email address is fmw at alink dot net. 16 | # 17 | # If you: 1) redistribute, 2) incorporate any of these scripts into other 18 | # free applications or 3) reprogram them in another scripting language, 19 | # then you must contact me for permission, especially if the result might 20 | # be used in a commercial or for-profit environment. 21 | # 22 | # My scripts are also subject, in a subordinate manner, to the ImageMagick 23 | # license, which can be found at: http://www.imagemagick.org/script/license.php 24 | # 25 | # ------------------------------------------------------------------------------ 26 | # 27 | #### 28 | # 29 | # USAGE: magicwand x,y [-t threshold] [-f format] [-r region] [-m mask] [-c color] [-o opacity] infile outfile 30 | # USAGE: magicwand [-h or -help] 31 | # 32 | # OPTIONS: 33 | # 34 | # x,y x,y location to get color and seed floodfill 35 | # -t threshold percent color disimilarity threshold (fuzz factor); 36 | # values from 0 to 100; A value of 0 is an exact 37 | # match and a value of 100 is any color; default=10 38 | # -f format output format; image or mask; default=image 39 | # -r region region to display; inside or outside; default=inside 40 | # -m mask mask type; transparent, binary, edge, overlay, layer; 41 | # default=transparent 42 | # -c color color for background, edge outline or translucent layer; 43 | # color="none" indicates use opacity; 44 | # color="trans" indicates make background transparent; 45 | # default="black" 46 | # -o opacity opacity for transparent, overlay or layer mask; 47 | # values from 0 to 100; default=0 48 | # 49 | ### 50 | # 51 | # NAME: MAGICWAND 52 | # 53 | # PURPOSE: To isolate a contiguous region of an image based upon a color determined 54 | # from a user specified image coordinate. 55 | # 56 | # DESCRIPTION: MAGICWAND determines a contiguous region of an image based 57 | # upon a color determined from a user specified image coordinate and a color 58 | # similarity threshold value (fuzz factor). The output can be either an 59 | # image or a mask. If the region is set to inside, then the image can be 60 | # made to mask out the background as transparent, mask out the background 61 | # using an opacity channel, fill the background with a color, display a 62 | # boundary (outline) edge for the region or apply a translucent color layer 63 | # over the background. If the region is set to outside, then the inside will 64 | # be masked and the outside area will show normally. Alternately, the output 65 | # can be a mask which is either binary (black and white), transparent 66 | # (transparent and white) or a boundary edge (white on black background). 67 | # The boundary edge can be made to match the interior or exterior depending 68 | # upon the region setting. 69 | # 70 | # 71 | # OPTIONS: 72 | # 73 | # x,y ... x,y are the coordinates in the image where the color is to be 74 | # extracted and the floodfill is to be seeded. 75 | # 76 | # -f format ... FORMAT specifies whether the output will be the modified 77 | # input image or a mask image. 
The choices are image or mask. The default 78 | # is image. 79 | # 80 | # -r region ... REGION specifies whether the inside or outside are of the 81 | # image will show and the other be masked. The choices are inside or 82 | # outside. The default is inside. 83 | # 84 | # -m mask ... MASK specifies the type of mask to use or create. The choices 85 | # are transparent, binary, edge, overlay or layer. Only transparent, binary or edge 86 | # masks will be allowed as output. With a transparent mask, the image will be 87 | # modified to mask out the complement of the regions specified according to the 88 | # color setting. Specify the color setting to 1) "trans" to make the background 89 | # transparent, 2) "none" to mask by multiplying by the opacity setting or 90 | # 3) a color value to use a fill color. With an overlay mask, which will only 91 | # be effective on PNG format output images, the masking is done via the opacity 92 | # channel using the opacity setting. With a layer mask, a translucent color will 93 | # be layered over the background according to the color and opacity values. The 94 | # larger the opacity, the lighter the color overlay and the more the image will 95 | # show. 96 | # 97 | # -c color ... COLOR is the color to be used for the background fill or 98 | # the boundary edge overlaid on the image. Any IM color specification is 99 | # valid or a value of none. Be sure to enclose them in double quotes. 100 | # The color value over-rides the opacity for mask=transparent, so use 101 | # color="none" to allow the opacity to work or use color="trans" to make 102 | # the background transparent. The default="black". 103 | # 104 | # -o opacity ... OPACITY controls the degree of transparency for the area 105 | # that is masked out using a mask setting of transparent, overlay or layer. 106 | # A value of zero is fully transparent and a value of 100 is fully opaque. 107 | # 108 | # CAVEAT: No guarantee that this script will work on all platforms, 109 | # nor that trapping of inconsistent parameters is complete and 110 | # foolproof. Use At Your Own Risk. 111 | # 112 | ###### 113 | # 114 | 115 | # set default values 116 | threshold=10 117 | format="image" # image or mask 118 | region="inside" # inside or outside 119 | mask="trans" # trans or binary or edge or overlay (overlay only if png) 120 | bgcolor="black" # color or none when format=image 121 | opacity=0 122 | 123 | # set directory for temporary files 124 | dir="." # suggestions are dir="." 
or dir="/tmp" 125 | 126 | 127 | # set up functions to report Usage and Usage with Description 128 | PROGNAME=`type $0 | awk '{print $3}'` # search for executable on path 129 | PROGDIR=`dirname $PROGNAME` # extract directory of program 130 | PROGNAME=`basename $PROGNAME` # base name of program 131 | usage1() 132 | { 133 | echo >&2 "" 134 | echo >&2 "$PROGNAME:" "$@" 135 | sed >&2 -e '1,/^####/d; /^###/g; /^#/!q; s/^#//; s/^ //; 4,$p' "$PROGDIR/$PROGNAME" 136 | } 137 | usage2() 138 | { 139 | echo >&2 "" 140 | echo >&2 "$PROGNAME:" "$@" 141 | sed >&2 -e '1,/^####/d; /^######/g; /^#/!q; s/^#*//; s/^ //; 4,$p' "$PROGDIR/$PROGNAME" 142 | } 143 | 144 | 145 | # function to report error messages 146 | errMsg() 147 | { 148 | echo "" 149 | echo $1 150 | echo "" 151 | usage1 152 | exit 1 153 | } 154 | 155 | 156 | # function to test for minus at start of value of second part of option 1 or 2 157 | checkMinus() 158 | { 159 | test=`echo "$1" | grep -c '^-.*$'` # returns 1 if match; 0 otherwise 160 | [ $test -eq 1 ] && errMsg "$errorMsg" 161 | } 162 | 163 | # test for correct number of arguments and get values 164 | if [ $# -eq 0 ] 165 | then 166 | # help information 167 | echo "" 168 | usage2 169 | exit 0 170 | elif [ $# -gt 15 ] 171 | then 172 | errMsg "--- TOO MANY ARGUMENTS WERE PROVIDED ---" 173 | else 174 | while [ $# -gt 0 ] 175 | do 176 | # get parameter values 177 | case "$1" in 178 | -h|-help) # help information 179 | echo "" 180 | usage2 181 | exit 0 182 | ;; 183 | -f) # format 184 | shift # to get the next parameter - format 185 | # test if parameter starts with minus sign 186 | errorMsg="--- INVALID FORMAT SPECIFICATION ---" 187 | checkMinus "$1" 188 | # test region values 189 | format="$1" 190 | [ "$format" != "image" -a "$format" != "mask" ] && errMsg "--- FORMAT=$format IS NOT A VALID VALUE ---" 191 | ;; 192 | -r) # region 193 | shift # to get the next parameter - region 194 | # test if parameter starts with minus sign 195 | errorMsg="--- INVALID REGION SPECIFICATION ---" 196 | checkMinus "$1" 197 | # test region values 198 | region="$1" 199 | [ "$region" != "inside" -a "$region" != "outside" ] && errMsg "--- REGION=$region IS NOT A VALID VALUE ---" 200 | ;; 201 | -m) # mask 202 | shift # to get the next parameter - mask 203 | # test if parameter starts with minus sign 204 | errorMsg="--- INVALID MASK SPECIFICATION ---" 205 | checkMinus "$1" 206 | # test mask values 207 | mask="$1" 208 | [ "$mask" != "transparent" -a "$mask" != "binary" -a "$mask" != "edge" -a "$mask" != "overlay" -a "$mask" != "layer" ] && errMsg "--- MASK=$mask IS NOT A VALID VALUE ---" 209 | [ "$mask" = "transparent" ] && mask="trans" 210 | ;; 211 | -c) # get color 212 | shift # to get the next parameter - lineval 213 | # test if parameter starts with minus sign 214 | errorMsg="--- INVALID COLOR SPECIFICATION ---" 215 | checkMinus "$1" 216 | bgcolor="$1" 217 | ;; 218 | -t) # get threshold 219 | shift # to get the next parameter - threshold 220 | # test if parameter starts with minus sign 221 | errorMsg="--- INVALID THRESHOLD SPECIFICATION ---" 222 | checkMinus "$1" 223 | # test threshold values 224 | threshold=`expr "$1" : '\([.0-9]*\)'` 225 | [ "$threshold" = "" ] && errMsg "THRESHOLD=$threshold IS NOT A NON-NEGATIVE FLOATING POINT NUMBER" 226 | thresholdtestA=`echo "$threshold < 0" | bc` 227 | thresholdtestB=`echo "$threshold > 100" | bc` 228 | [ $thresholdtestA -eq 1 -o $thresholdtestB -eq 1 ] && errMsg "--- THRESHOLD=$threshold MUST BE GREATER THAN OR EQUAL 0 AND LESS THAN OR EQUAL 100 ---" 229 | ;; 230 | -o) # get 
opacity 231 | shift # to get the next parameter - opacity 232 | # test if parameter starts with minus sign 233 | errorMsg="--- INVALID OPACITY SPECIFICATION ---" 234 | checkMinus "$1" 235 | # test width values 236 | opacity=`expr "$1" : '\([.0-9]*\)'` 237 | [ "$opacity" = "" ] && errMsg "OPACITY=$opacity IS NOT A NON-NEGATIVE FLOATING POINT NUMBER" 238 | opacitytestA=`echo "$opacity < 0" | bc` 239 | opacitytestB=`echo "$opacity > 100" | bc` 240 | [ $opacitytestA -eq 1 -o $opacitytestB -eq 1 ] && errMsg "--- OPACITY=$opacity MUST BE GREATER THAN OR EQUAL 0 AND LESS THAN OR EQUAL 100 ---" 241 | ;; 242 | -) # STDIN, end of arguments 243 | break 244 | ;; 245 | -*) # any other - argument 246 | errMsg "--- UNKNOWN OPTION ---" 247 | ;; 248 | [0-9]*,[0-9]*) # Values supplied for coordinates 249 | coords="$1" 250 | ;; 251 | .*,.*) # Bogus Values supplied 252 | errMsg "--- COORDINATES ARE NOT VALID ---" 253 | ;; 254 | *) # end of arguments 255 | break 256 | ;; 257 | esac 258 | shift # next option 259 | done 260 | # 261 | # get infile and outfile 262 | infile="$1" 263 | outfile="$2" 264 | fi 265 | 266 | 267 | # test that infile provided 268 | [ "$infile" = "" ] && errMsg "NO INPUT FILE SPECIFIED" 269 | 270 | # test that outfile provided 271 | [ "$outfile" = "" ] && errMsg "NO OUTPUT FILE SPECIFIED" 272 | 273 | tmpA="$dir/magicwand_$$.mpc" 274 | tmpB="$dir/magicwand_$$.cache" 275 | tmp0="$dir/magicwand_0_$$.png" 276 | trap "rm -f $tmpA $tmpB $tmp0;" 0 277 | trap "rm -f $tmpA $tmpB $tmp0; exit 1" 1 2 3 15 278 | trap "rm -f $tmpA $tmpB $tmp0; exit 1" ERR 279 | 280 | if convert -quiet "$infile" +repage "$tmpA" 281 | then 282 | width=`identify -format %w $tmpA` 283 | height=`identify -format %h $tmpA` 284 | [ "$coords" = "" ] && errMsg "--- NO COORDINATES PROVIDED ---" 285 | else 286 | errMsg "--- FILE $infile DOES NOT EXIST OR IS NOT AN ORDINARY FILE, NOT READABLE OR HAS ZERO SIZE ---" 287 | fi 288 | 289 | 290 | # get im_version 291 | im_version=`convert -list configure | \ 292 | sed '/^LIB_VERSION_NUMBER /!d; s//,/; s/,/,0/g; s/,0*\([0-9][0-9]\)/\1/g' | head -n 1` 293 | 294 | # set up floodfill 295 | if [ "$im_version" -ge "07000000" ]; then 296 | matte_alpha="alpha" 297 | else 298 | matte_alpha="alpha" 299 | fi 300 | 301 | # create transparent mask for region=inside 302 | # make interior transparent and outside black 303 | convert $tmpA -fuzz $threshold% -fill none -draw "$matte_alpha $coords floodfill" \ 304 | -fill black +opaque none $tmp0 305 | 306 | # create negative mask - for region=outside 307 | if [ "$region" = "outside" ] 308 | then 309 | convert $tmp0 -channel rgba \ 310 | -fill white -opaque none \ 311 | -transparent black \ 312 | -fill black -opaque white $tmp0 313 | 314 | fi 315 | 316 | 317 | # convert mask to binary if appropriate 318 | if [ "$mask" != "trans" -a "$mask" != "layer" ] 319 | then 320 | # make transparent go to white 321 | convert \( -size ${width}x${height} xc:white \) $tmp0 \ 322 | -composite $tmp0 323 | fi 324 | 325 | 326 | #convert mask to edge if appropriate 327 | if [ "$mask" = "edge" ] 328 | then 329 | convert $tmp0 -convolve "-1,-1,-1,-1,8,-1,-1,-1,-1" -clamp $tmp0 330 | fi 331 | 332 | 333 | # process image if appropriate 334 | if [ "$format" = "image" -a "$mask" = "edge" ] 335 | then 336 | convert $tmpA \( $tmp0 -transparent black -fill $bgcolor -opaque white \) -composite $tmp0 337 | 338 | elif [ "$format" = "image" -a "$mask" = "trans" ] 339 | then 340 | # composite with input and set background 341 | if [ "$bgcolor" = "trans" ] 342 | then 343 | convert $tmpA 
$tmp0 -composite -transparent black $tmp0 344 | elif [ "$bgcolor" = "none" ] 345 | then 346 | convert $tmpA \( $tmp0 -fill "rgb($opacity%,$opacity%,$opacity%)" -opaque black \) -compose Multiply -composite $tmp0 347 | else 348 | convert $tmpA $tmp0 -composite -fill $bgcolor -opaque black $tmp0 349 | fi 350 | 351 | elif [ "$format" = "image" -a "$mask" = "overlay" ] 352 | then 353 | convert $tmpA \( $tmp0 -fill "rgb($opacity%,$opacity%,$opacity%)" -opaque black \) -compose Copy_Opacity -composite $tmp0 354 | 355 | elif [ "$format" = "image" -a "$mask" = "layer" ] 356 | then 357 | opacity=`expr 100 - $opacity` 358 | convert $tmp0 -fill $bgcolor -opaque black $tmp0 359 | if [ "$im_version" -lt "06050304" ]; then 360 | composite -dissolve $opacity% $tmp0 $tmpA $tmp0 361 | else 362 | convert $tmpA $tmp0 -define compose:args=$opacity% -compose dissolve -composite $tmp0 363 | fi 364 | fi 365 | convert $tmp0 "$outfile" 366 | exit 0 367 | -------------------------------------------------------------------------------- /spectrum_visualization/model_final.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/spectrum_visualization/model_final.pth -------------------------------------------------------------------------------- /trained_models/armadillo.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/armadillo.pth -------------------------------------------------------------------------------- /trained_models/dragon.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/dragon.pth -------------------------------------------------------------------------------- /trained_models/lego.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/lego.pth -------------------------------------------------------------------------------- /trained_models/lego_semisupervise.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/lego_semisupervise.pth -------------------------------------------------------------------------------- /trained_models/lucy.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/lucy.pth -------------------------------------------------------------------------------- /trained_models/thai.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/thai.pth -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import torch 3 | import utils 4 | from tqdm.autonotebook import tqdm 5 | import time 6 | import 
numpy as np 7 | import os 8 | import forward_models 9 | from functools import partial 10 | import shutil 11 | 12 | 13 | def train(model, train_dataloader, steps, lr, steps_til_summary, 14 | steps_til_checkpoint, model_dir, loss_fn, summary_fn, 15 | prefix_model_dir='', val_dataloader=None, double_precision=False, 16 | clip_grad=False, use_lbfgs=False, loss_schedules=None, params=None, 17 | ckpt_step=0, use_lr_scheduler=False): 18 | 19 | if params is None: 20 | optim = torch.optim.Adam(lr=lr, params=model.parameters(), amsgrad=True) 21 | else: 22 | optim = torch.optim.Adam(lr=lr, params=params, amsgrad=True) 23 | 24 | if use_lbfgs: 25 | optim = torch.optim.LBFGS(lr=lr, params=model.parameters(), 26 | max_iter=50000, max_eval=50000, 27 | history_size=50, line_search_fn='strong_wolfe') 28 | 29 | scheduler = None 30 | if use_lr_scheduler: 31 | def sampling_scheduler(step, start=0, lr0=1e-4, lrn=1e-4): 32 | 33 | if step > start: 34 | fine_scale = lr_log_schedule(step-start, num_steps=steps-start, nw=1, lr0=lr0, lrn=lrn) 35 | train_dataloader.dataset.fine_scale = fine_scale 36 | else: 37 | train_dataloader.dataset.fine_scale = lr0 38 | 39 | # lr scheduler 40 | optim.param_groups[0]['lr'] = 1 41 | log_scheduler = partial(lr_log_schedule, num_steps=steps, nw=1, lr0=lr, lrn=1e-4) 42 | scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=log_scheduler) 43 | 44 | if os.path.exists(model_dir): 45 | pass 46 | else: 47 | os.makedirs(model_dir) 48 | 49 | model_dir_postfixed = os.path.join(model_dir, prefix_model_dir) 50 | 51 | summaries_dir = os.path.join(model_dir_postfixed, 'summaries') 52 | utils.cond_mkdir(summaries_dir) 53 | 54 | checkpoints_dir = os.path.join(model_dir_postfixed, 'checkpoints') 55 | utils.cond_mkdir(checkpoints_dir) 56 | 57 | writer = SummaryWriter(summaries_dir) 58 | 59 | # e.g. epochs=1k, len(train_dataloader)=25 60 | train_generator = iter(train_dataloader) 61 | 62 | with tqdm(total=steps) as pbar: 63 | train_losses = [] 64 | for step in range(steps): 65 | 66 | if not step % steps_til_checkpoint and step: 67 | torch.save(model.state_dict(), 68 | os.path.join(checkpoints_dir, 69 | 'model_step_%04d.pth' % (step + ckpt_step))) 70 | np.savetxt(os.path.join(checkpoints_dir, 71 | 'train_losses_step_%04d.txt' % (step + ckpt_step)), 72 | np.array(train_losses)) 73 | 74 | try: 75 | # sampling_scheduler(step) 76 | model_input, gt = next(train_generator) 77 | except StopIteration: 78 | train_generator = iter(train_dataloader) 79 | model_input, gt = next(train_generator) 80 | 81 | start_time = time.time() 82 | 83 | model_input = dict2cuda(model_input) 84 | gt = dict2cuda(gt) 85 | 86 | if double_precision: 87 | model_input = {key: value.double() 88 | for key, value in model_input.items()} 89 | gt = {key: value.double() for key, value in gt.items()} 90 | 91 | if use_lbfgs: 92 | def closure(): 93 | optim.zero_grad(set_to_none=True) 94 | model_output = model(model_input) 95 | losses = loss_fn(model_output, gt) 96 | train_loss = 0. 97 | for loss_name, loss in losses.items(): 98 | train_loss += loss.mean() 99 | train_loss.backward() 100 | return train_loss 101 | optim.step(closure) 102 | 103 | model_output = model(model_input) 104 | losses = loss_fn(model_output, gt) 105 | 106 | train_loss = 0. 
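# Accumulate every loss term returned by loss_fn; when a schedule is given for a
# term in loss_schedules, its step-dependent weight is logged to Tensorboard and
# applied before the term is added to the total training loss.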
107 | for loss_name, loss in losses.items(): 108 | single_loss = loss.mean() 109 | 110 | if loss_schedules is not None and \ 111 | loss_name in loss_schedules: 112 | writer.add_scalar(loss_name + "_weight", 113 | loss_schedules[loss_name](step), step) 114 | single_loss *= loss_schedules[loss_name](step) 115 | 116 | writer.add_scalar(loss_name, single_loss, step) 117 | train_loss += single_loss 118 | 119 | train_losses.append(train_loss.item()) 120 | writer.add_scalar("total_train_loss", train_loss, step) 121 | writer.add_scalar("lr", optim.param_groups[0]['lr'], step) 122 | 123 | if not step % steps_til_summary: 124 | torch.save(model.state_dict(), 125 | os.path.join(checkpoints_dir, 126 | 'model_current.pth')) 127 | summary_fn(model, model_input, gt, model_output, writer, step) 128 | 129 | if not use_lbfgs: 130 | optim.zero_grad(set_to_none=True) 131 | train_loss.backward() 132 | 133 | if clip_grad: 134 | if isinstance(clip_grad, bool): 135 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.) 136 | else: 137 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad) 138 | 139 | optim.step() 140 | 141 | if scheduler is not None: 142 | scheduler.step() 143 | 144 | pbar.update(1) 145 | 146 | if not step % steps_til_summary: 147 | tqdm.write("Step %d, Total loss %0.6f, iteration time %0.6f" % (step, train_loss, time.time() - start_time)) 148 | 149 | if val_dataloader is not None: 150 | print("Running validation set...") 151 | model.eval() 152 | with torch.no_grad(): 153 | val_losses = [] 154 | for (model_input, gt) in val_dataloader: 155 | model_output = model(model_input) 156 | val_loss = loss_fn(model_output, gt) 157 | val_losses.append(val_loss) 158 | 159 | writer.add_scalar("val_loss", np.mean(val_losses), step) 160 | model.train() 161 | 162 | torch.save(model.state_dict(), 163 | os.path.join(checkpoints_dir, 'model_final.pth')) 164 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final.txt'), 165 | np.array(train_losses)) 166 | 167 | 168 | def dict2cuda(a_dict): 169 | tmp = {} 170 | for key, value in a_dict.items(): 171 | if isinstance(value, torch.Tensor): 172 | tmp.update({key: value.cuda()}) 173 | elif isinstance(value, dict): 174 | tmp.update({key: dict2cuda(value)}) 175 | elif isinstance(value, list) or isinstance(value, tuple): 176 | if isinstance(value[0], torch.Tensor): 177 | tmp.update({key: [v.cuda() for v in value]}) 178 | else: 179 | tmp.update({key: value}) 180 | return tmp 181 | 182 | 183 | def dict2cpu(a_dict): 184 | tmp = {} 185 | for key, value in a_dict.items(): 186 | if isinstance(value, torch.Tensor): 187 | tmp.update({key: value.cpu()}) 188 | elif isinstance(value, dict): 189 | tmp.update({key: dict2cpu(value)}) 190 | elif isinstance(value, list): 191 | if isinstance(value[0], torch.Tensor): 192 | tmp.update({key: [v.cpu() for v in value]}) 193 | else: 194 | tmp.update({key: value}) 195 | return tmp 196 | 197 | 198 | def reg_schedule(it, num_steps=1e6, lr0=1e-3, lrn=1e-4): 199 | return np.exp((1 - it/num_steps) * np.log(lr0) + (it/num_steps) * np.log(lrn)) 200 | 201 | 202 | def lr_log_schedule(it, num_steps=1e6, nw=2500, lr0=1e-3, lrn=5e-6, lambdaw=0.01): 203 | return (lambdaw + (1 - lambdaw) * np.sin(np.pi/2 * np.clip(it/nw, 0, 1))) \ 204 | * np.exp((1 - it/num_steps) * np.log(lr0) + (it/num_steps) * np.log(lrn)) 205 | 206 | 207 | def train_wchunks(models, train_dataloader, num_steps, lr, steps_til_summary, steps_til_checkpoint, model_dir, 208 | loss_fn, summary_fn, chunk_lists_from_batch_fn, 209 | val_dataloader=None, 
210 |                   num_cuts=128,
211 |                   max_chunk_size=4096,
212 |                   resume_checkpoint={},
213 |                   chunked=True,
214 |                   hierarchical_sampling=False,
215 |                   coarse_loss_weight=0.1,
216 |                   stop_after=None):
217 | 
218 |     optims = {key: torch.optim.Adam(lr=1, params=model.parameters())
219 |               for key, model in models.items()}
220 |     schedulers = {key: torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_log_schedule)
221 |                   for key, optim in optims.items()}
222 | 
223 |     # load optimizer if supplied
224 |     for key in models.keys():
225 |         if key in resume_checkpoint:
226 |             optims[key].load_state_dict(resume_checkpoint[key]['optim'])
227 |             schedulers[key].load_state_dict(resume_checkpoint[key]['scheduler'])
228 | 
229 |     if os.path.exists(os.path.join(model_dir, 'summaries')):
230 |         val = input("The model directory %s exists. Overwrite? (y/n)" % model_dir)
231 |         if val == 'y':
232 |             if os.path.exists(os.path.join(model_dir, 'summaries')):
233 |                 shutil.rmtree(os.path.join(model_dir, 'summaries'))
234 |             if os.path.exists(os.path.join(model_dir, 'checkpoints')):
235 |                 shutil.rmtree(os.path.join(model_dir, 'checkpoints'))
236 | 
237 |     os.makedirs(model_dir, exist_ok=True)
238 | 
239 |     summaries_dir = os.path.join(model_dir, 'summaries')
240 |     utils.cond_mkdir(summaries_dir)
241 | 
242 |     checkpoints_dir = os.path.join(model_dir, 'checkpoints')
243 |     utils.cond_mkdir(checkpoints_dir)
244 | 
245 |     writer = SummaryWriter(summaries_dir)
246 | 
247 |     start_step = 0
248 |     if 'step' in resume_checkpoint:
249 |         start_step = resume_checkpoint['step']
250 | 
251 |     train_generator = iter(train_dataloader)
252 | 
253 |     with tqdm(total=num_steps) as pbar:
254 |         pbar.update(start_step)
255 |         train_losses = []
256 |         for step in range(start_step, num_steps):
257 |             if not step % steps_til_checkpoint and step:
258 |                 for key, model in models.items():
259 |                     torch.save(model.state_dict(),
260 |                                os.path.join(checkpoints_dir, 'model_'+key+'_step_%04d.pth' % step))
261 |                 np.savetxt(os.path.join(checkpoints_dir, 'train_losses_step_%04d.txt' % step),
262 |                            np.array(train_losses))
263 |                 for key, optim in optims.items():
264 |                     torch.save({'step': step,
265 |                                 'optimizer_state_dict': optim.state_dict(),
266 |                                 'scheduler_state_dict': schedulers[key].state_dict()},
267 |                                os.path.join(checkpoints_dir, 'optim_'+key+'_step_%04d.pth' % step))
268 | 
269 |             try:
270 |                 model_input, meta, gt = next(train_generator)
271 |             except StopIteration:
272 |                 train_dataloader.dataset.shuffle_rays()
273 |                 train_generator = iter(train_dataloader)
274 |                 model_input, meta, gt = next(train_generator)
275 | 
276 |             start_time = time.time()
277 | 
278 |             for optim in optims.values():
279 |                 optim.zero_grad(set_to_none=True)
280 | 
281 |             batch_avged_losses = {}
282 |             if chunked:
283 |                 list_chunked_model_input, list_chunked_meta, list_chunked_gt = \
284 |                     chunk_lists_from_batch_fn(model_input, meta, gt, max_chunk_size)
285 | 
286 |                 num_chunks = len(list_chunked_gt)
287 |                 batch_avged_tot_loss = 0.
288 |                 for chunk_idx, (chunked_model_input, chunked_meta, chunked_gt) \
289 |                         in enumerate(zip(list_chunked_model_input, list_chunked_meta, list_chunked_gt)):
290 |                     chunked_model_input = dict2cuda(chunked_model_input)
291 |                     chunked_meta = dict2cuda(chunked_meta)
292 |                     chunked_gt = dict2cuda(chunked_gt)
293 | 
294 |                     # forward pass through model
295 |                     for k in models.keys():
296 |                         models[k].stop_after = stop_after
297 |                     chunk_model_outputs = {key: model(chunked_model_input) for key, model in models.items()}
298 |                     for k in models.keys():
299 |                         models[k].stop_after = None
300 | 
301 |                     losses = {} if hierarchical_sampling else loss_fn(chunk_model_outputs, chunked_gt)  # coarse-pass losses are needed when not using hierarchical sampling
302 | 
303 |                     if hierarchical_sampling:
304 |                         # set idx to use sigma from coarse level (idx=0)
305 |                         # for hierarchical sampling
306 |                         chunked_model_input_fine = sample_pdf(chunked_model_input,
307 |                                                               chunk_model_outputs, idx=0)
308 | 
309 |                         chunk_model_importance_outputs = {key: model(chunked_model_input_fine)
310 |                                                           for key, model in models.items()}
311 | 
312 |                         reg_lambda = reg_schedule(step)
313 |                         losses_importance = loss_fn(chunk_model_importance_outputs, chunked_gt,
314 |                                                     regularize_sigma=True, reg_lambda=reg_lambda)
315 | 
316 |                     # loss from forward pass
317 |                     train_loss = 0.
318 |                     for loss_name, loss in losses.items():
319 | 
320 |                         single_loss = loss.mean()
321 |                         train_loss += single_loss / num_chunks
322 | 
323 |                         batch_avged_tot_loss += float(single_loss / num_chunks)
324 |                         if loss_name in batch_avged_losses:
325 |                             batch_avged_losses[loss_name] += single_loss / num_chunks
326 |                         else:
327 |                             batch_avged_losses.update({loss_name: single_loss/num_chunks})
328 | 
329 |                     # loss from the optional second (importance sampling) pass
330 |                     if hierarchical_sampling:
331 |                         for loss_name, loss in losses_importance.items():
332 |                             single_loss = loss.mean()
333 |                             train_loss += single_loss / num_chunks
334 | 
335 |                             batch_avged_tot_loss += float(single_loss / num_chunks)  # accumulate this term's contribution, not the running total
336 |                             if loss_name + '_importance' in batch_avged_losses:
337 |                                 batch_avged_losses[loss_name+'_importance'] += single_loss / num_chunks
338 |                             else:
339 |                                 batch_avged_losses.update({loss_name + '_importance': single_loss / num_chunks})
340 | 
341 |                     train_loss.backward()
342 |             else:
343 |                 model_input = dict2cuda(model_input)
344 |                 meta = dict2cuda(meta)
345 |                 gt = dict2cuda(gt)
346 | 
347 |                 model_outputs = {key: model(model_input) for key, model in models.items()}
348 |                 losses = loss_fn(model_outputs, gt)
349 | 
350 |                 # loss from forward pass
351 |                 train_loss = 0.
352 |                 for loss_name, loss in losses.items():
353 | 
354 |                     single_loss = loss.mean()
355 |                     train_loss += single_loss
356 | 
357 |                     batch_avged_tot_loss = float(train_loss)  # track the running total across losses
358 |                     if loss_name in batch_avged_losses:
359 |                         batch_avged_losses[loss_name] += single_loss
360 |                     else:
361 |                         batch_avged_losses.update({loss_name: single_loss})
362 | 
363 |                 train_loss.backward()
364 | 
365 |             for loss_name, loss in batch_avged_losses.items():
366 |                 writer.add_scalar(loss_name, loss, step)
367 |             train_losses.append(batch_avged_tot_loss)
368 |             writer.add_scalar("total_train_loss", batch_avged_tot_loss, step)
369 |             if chunked and hierarchical_sampling:  # reg_lambda is only defined for the hierarchical chunked path
370 |                 writer.add_scalar("reg_lambda", reg_lambda, step)
371 |             for k in optims.keys():
372 |                 writer.add_scalar(f"{k}_lr", optims[k].param_groups[0]['lr'], step)
373 | 
374 |             if clip_grad:
375 |                 for model in models.values():
376 |                     if isinstance(clip_grad, bool):
377 |                         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
378 |                     else:
379 |                         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
380 | 
381 |             for optim in optims.values():
382 |                 optim.step()
383 | 
384 |             if not step % steps_til_summary:
385 |                 tqdm.write("Step %d, Total loss %0.6f, iteration time %0.6f" % (step, train_loss, time.time() - start_time))
386 |                 for key, model in models.items():
387 |                     torch.save(model.state_dict(),
388 |                                os.path.join(checkpoints_dir, 'model_'+key+'_current.pth'))
389 |                 for key, optim in optims.items():
390 |                     torch.save({'step': step,
391 |                                 'total_steps': step,
392 |                                 'optimizer_state_dict': optim.state_dict(),
393 |                                 'scheduler_state_dict': schedulers[key].state_dict()},
394 |                                os.path.join(checkpoints_dir, 'optim_'+key+'_current.pth'))
395 |                 summary_fn(models, train_dataloader, val_dataloader, loss_fn, optims, meta, gt,
396 |                            writer, step)
397 | 
398 |             pbar.update(1)
399 | 
400 |             for k in schedulers.keys():
401 |                 schedulers[k].step()
402 | 
403 |     for key, model in models.items():
404 |         torch.save(model.state_dict(),
405 |                    os.path.join(checkpoints_dir, 'model_' + key + '_final.pth'))
406 |     np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final.txt'),
407 |                np.array(train_losses))
408 | 
409 | 
410 | def sample_pdf(model_inputs, model_outputs, offset=5e-3,
411 |                idx=-1):
412 |     ''' hierarchical sampling code for neural radiance fields '''
413 | 
414 |     z_vals = model_inputs['t']
415 |     bins = .5*(z_vals[..., 1:, :] + z_vals[..., :-1, :]).squeeze()
416 |     bins = bins.clone().detach().requires_grad_(True)
417 | 
418 |     if 'combined' in model_outputs:
419 |         if isinstance(model_outputs['combined']['model_out']['output'], list):
420 |             pred_sigma = model_outputs['combined']['model_out']['output'][idx][..., -1:]
421 |             t_intervals = model_outputs['combined']['model_in']['t_intervals']
422 |         else:
423 |             pred_sigma = model_outputs['combined']['model_out']['output'][..., -1:]
424 |             t_intervals = model_outputs['combined']['model_in']['t_intervals']
425 |     else:
426 |         pred_sigma = model_outputs['sigma']['model_out']['output']
427 |         t_intervals = model_outputs['sigma']['model_in']['t_intervals']
428 | 
429 |     if isinstance(pred_sigma, list):
430 |         pred_sigma = pred_sigma[idx]
431 | 
432 |     pred_weights = forward_models.compute_transmittance_weights(pred_sigma, t_intervals)[..., :-1, 0]
433 | 
434 |     # blur weights
435 |     pred_weights = torch.cat((pred_weights, pred_weights[..., -1:]), dim=-1)
436 |     weights_max = torch.maximum(pred_weights[..., :-1], pred_weights[..., 1:])
437 |     weights_blur = 0.5 * (weights_max[..., :-1] + weights_max[..., 1:])
438 |     pred_weights = weights_blur + offset
439 | 
440 |     pdf = pred_weights / torch.sum(pred_weights, dim=-1, keepdim=True)
441 | 
442 |     cdf = torch.cumsum(pdf, dim=-1)
443 |     cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1).squeeze()  # (batch_pixels, num_bins=samples_per_ray-1)
444 |     cdf = cdf.detach()
445 |     num_samples = pred_sigma.shape[-2]
446 |     u = torch.rand(list(cdf.shape[:-1])+[num_samples], device=pred_weights.device)
447 | 
448 |     inds = torch.searchsorted(cdf, u, right=True)
449 |     below = torch.max(torch.zeros_like(inds), inds-1)
450 |     above = torch.min((cdf.shape[-1]-1)*torch.ones_like(inds), inds)
451 |     inds_g = torch.stack((below, above), -1)
452 | 
453 |     matched_shape = (inds_g.shape[0], inds_g.shape[1], cdf.shape[-1])
454 |     cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
455 |     bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
456 | 
457 |     denom = (cdf_g[..., 1]-cdf_g[..., 0])
458 |     denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
459 |     t = (u - cdf_g[..., 0])/denom
460 |     t_vals = (bins_g[..., 0] + t*(bins_g[..., 1]-bins_g[..., 0])).unsqueeze(-1)
461 |     t_vals, _ = torch.sort(t_vals, dim=-2)
462 | 
463 |     ray_dirs = model_inputs['ray_directions']
464 |     ray_orgs = model_inputs['ray_origins']
465 | 
466 |     t_vals = t_vals[..., 0]
467 |     t_intervals = t_vals[..., 1:] - t_vals[..., :-1]
468 |     t_intervals = torch.cat((t_intervals, 1e10*torch.ones_like(t_intervals[:, 0:1])), dim=-1)
469 |     t_intervals = (t_intervals * ray_dirs.norm(p=2, dim=-1))[..., None]
470 |     t_vals = t_vals[..., None]
471 | 
472 |     if ray_dirs.ndim == 4:
473 |         t_vals = t_vals[None, ...]
474 | 
475 |     model_inputs.update({'t': t_vals})
476 |     model_inputs.update({'ray_samples': ray_orgs + ray_dirs * t_vals})
477 |     model_inputs.update({'t_intervals': t_intervals})
478 | 
479 |     return model_inputs
480 | 
--------------------------------------------------------------------------------
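
A note on the interfaces above: `train` assumes the dataloader yields `(model_input, gt)` dictionaries of tensors, that the model consumes the input dictionary directly, that `loss_fn` returns a dictionary of named losses whose means are summed into the training loss, and that `summary_fn(model, model_input, gt, model_output, writer, step)` handles Tensorboard logging. Every batch is moved to the GPU with `dict2cuda`, so a CUDA device is required. The sketch below shows a minimal compatible caller for a toy 1D regression, assuming it is run from the repository root so `training.py` is importable; the `ToyDataset`/`ToyModel` classes and the log directory are illustrative placeholders, not part of the repository (the real entry points are the `experiments/train_*.py` scripts).

```
import torch
from torch.utils.data import DataLoader, Dataset

import training


class ToyDataset(Dataset):
    """Yields (model_input, gt) dicts in the format `train` expects."""
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        coords = torch.rand(128, 1) - 0.5
        return {'coords': coords}, {'func': torch.sin(20 * coords)}


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(1, 64),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(64, 1))

    def forward(self, model_input):
        # `train` passes the whole input dictionary to the model
        return self.net(model_input['coords'])


def loss_fn(model_output, gt):
    # return a dict of named losses; `train` sums their means
    return {'func_loss': (model_output - gt['func']) ** 2}


def summary_fn(model, model_input, gt, model_output, writer, step):
    writer.add_scalar('train_mse', ((model_output - gt['func']) ** 2).mean(), step)


training.train(ToyModel().cuda(), DataLoader(ToyDataset(), batch_size=1),
               steps=1000, lr=1e-3, steps_til_summary=100, steps_til_checkpoint=500,
               model_dir='../logs/toy_example', loss_fn=loss_fn, summary_fn=summary_fn)
```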
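
Both loops drive the learning rate through `torch.optim.lr_scheduler.LambdaLR` with the optimizer's base learning rate set to 1, so the value returned by `lr_log_schedule` is the effective learning rate: a short sine warm-up over the first `nw` steps multiplied by a log-linear interpolation from `lr0` down to `lrn` over `num_steps`. The snippet below evaluates the schedule the same way the training loops do; the step counts are arbitrary example values.

```
import torch
from functools import partial

from training import lr_log_schedule

net = torch.nn.Linear(2, 2)
# base lr of 1 means the LambdaLR multiplier *is* the learning rate
optim = torch.optim.Adam(net.parameters(), lr=1)
schedule = partial(lr_log_schedule, num_steps=10000, nw=250, lr0=1e-3, lrn=5e-6)
scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=schedule)

for step in range(10000):
    # optim.step() would run here during real training
    scheduler.step()
    if step in (0, 250, 5000, 9999):
        print(step, optim.param_groups[0]['lr'])
# the lr ramps up over the first nw steps, then decays from ~1e-3 toward 5e-6
```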
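
`dict2cuda` and `dict2cpu` move (possibly nested) dictionaries of tensors between devices, which is how both training loops handle arbitrary batch structures. As written, a list or tuple value whose first element is not a tensor is dropped from the result, so non-tensor metadata is best kept out of list values. A quick illustration, assuming a CUDA device is available:

```
import torch

from training import dict2cuda, dict2cpu

batch = {'coords': torch.rand(4, 3),
         'meta': {'scale': torch.tensor(1.0)},
         'dirs': [torch.rand(4, 3), torch.rand(4, 3)]}

gpu_batch = dict2cuda(batch)     # tensors, including nested ones, move to the GPU
print(gpu_batch['coords'].device, gpu_batch['meta']['scale'].device)
cpu_batch = dict2cpu(gpu_batch)  # and back again
```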
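
`train_wchunks` writes one optimizer checkpoint per model key (`optim_<key>_step_XXXX.pth`, containing `optimizer_state_dict`, `scheduler_state_dict`, and `step`), but on resume it expects a `resume_checkpoint` dictionary keyed differently: `'optim'` and `'scheduler'` per model key plus a top-level `'step'`. A caller therefore has to remap the saved entries, roughly as sketched below; the `'combined'` key and the checkpoint path are placeholders, and the matching `model_<key>_step_XXXX.pth` weights would be loaded into the networks separately before training resumes.

```
import torch


def build_resume_checkpoint(checkpoints_dir, model_keys, step):
    """Remap saved optimizer checkpoints into the dict train_wchunks expects."""
    resume_checkpoint = {'step': step}
    for key in model_keys:
        saved = torch.load(f'{checkpoints_dir}/optim_{key}_step_{step:04d}.pth')
        resume_checkpoint[key] = {'optim': saved['optimizer_state_dict'],
                                  'scheduler': saved['scheduler_state_dict']}
    return resume_checkpoint


# hypothetical usage, resuming a single-model ('combined') NeRF run from step 2000:
# resume = build_resume_checkpoint('../logs/lego/checkpoints', ['combined'], 2000)
# train_wchunks(models, dataloader, ..., resume_checkpoint=resume)
```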
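
`sample_pdf` implements NeRF-style hierarchical sampling: the coarse pass's transmittance weights are blurred, padded by a small `offset`, normalized into a piecewise-constant PDF over the depth bins, and new sample locations are drawn by inverting the resulting CDF. The snippet below reproduces only that inverse-transform step on a toy per-ray weight vector, to show how samples concentrate where the weights are large; it is a standalone illustration rather than a call into `sample_pdf`, which additionally expects the full model input/output dictionaries.

```
import torch

torch.manual_seed(0)

# toy weights over 8 depth bins for one ray: most mass near the middle of the ray
weights = torch.tensor([[0.01, 0.02, 0.05, 0.40, 0.40, 0.08, 0.03, 0.01]])
edges = torch.linspace(2.0, 6.0, steps=9).unsqueeze(0)   # bin edges along the ray

pdf = weights / weights.sum(dim=-1, keepdim=True)
cdf = torch.cat([torch.zeros_like(pdf[..., :1]), torch.cumsum(pdf, dim=-1)], dim=-1)

# inverse-transform sampling: uniform u -> bin index via searchsorted,
# then linear interpolation within the bin (the same idea used by sample_pdf)
u = torch.rand(1, 64)
inds = torch.searchsorted(cdf, u, right=True)
below = (inds - 1).clamp(min=0)
above = inds.clamp(max=cdf.shape[-1] - 1)
cdf_below, cdf_above = torch.gather(cdf, -1, below), torch.gather(cdf, -1, above)
edge_below, edge_above = torch.gather(edges, -1, below), torch.gather(edges, -1, above)

denom = torch.where(cdf_above - cdf_below < 1e-5,
                    torch.ones_like(cdf_below), cdf_above - cdf_below)
t = (u - cdf_below) / denom
samples = edge_below + t * (edge_above - edge_below)

print(samples.mean())  # lands near the high-weight region around depth ~4
```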