├── README.md
├── convert_blender_data.py
├── dataio.py
├── download_datasets.py
├── environment.yml
├── experiments
│   ├── config
│   │   ├── 1d
│   │   │   ├── bacon_freq1.ini
│   │   │   ├── bacon_freq2.ini
│   │   │   ├── ff.ini
│   │   │   └── siren.ini
│   │   ├── img
│   │   │   ├── bacon.ini
│   │   │   ├── ff.ini
│   │   │   ├── mip.ini
│   │   │   └── siren.ini
│   │   ├── nerf
│   │   │   ├── bacon.ini
│   │   │   ├── bacon_lr.ini
│   │   │   └── bacon_semisupervise.ini
│   │   └── sdf
│   │       ├── bacon_armadillo.ini
│   │       ├── bacon_dragon.ini
│   │       ├── bacon_lucy.ini
│   │       ├── bacon_thai.ini
│   │       ├── ff_armadillo.ini
│   │       ├── ff_dragon.ini
│   │       ├── ff_lucy.ini
│   │       ├── ff_thai.ini
│   │       ├── siren_armadillo.ini
│   │       ├── siren_dragon.ini
│   │       ├── siren_lucy.ini
│   │       └── siren_thai.ini
│   ├── figure_setup.py
│   ├── plot_activation_distributions.py
│   ├── render_nerf.py
│   ├── render_sdf.py
│   ├── train_1d.py
│   ├── train_img.py
│   ├── train_radiance_field.py
│   └── train_sdf.py
├── forward_models.py
├── img
│   └── teaser.png
├── loss_functions.py
├── modules.py
├── spectrum_visualization
│   ├── README.md
│   ├── environment.yml
│   ├── get_shape_spectra.py
│   ├── magicwand
│   ├── model_final.pth
│   └── ref_armadillo.xyz
├── trained_models
│   ├── armadillo.pth
│   ├── dragon.pth
│   ├── lego.pth
│   ├── lego_semisupervise.pth
│   ├── lucy.pth
│   └── thai.pth
├── training.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # BACON: Band-limited Coordinate Networks for Multiscale Scene Representation | CVPR 2022 (oral)
2 | ### [Project Page](http://www.computationalimaging.org/publications/bacon/) | [Video](https://www.youtube.com/watch?v=zIH3KUCgJEA) | [Paper](https://arxiv.org/abs/2112.04645)
3 | Official PyTorch implementation of BACON.
4 | [BACON: Band-limited Coordinate Networks for Multiscale Scene Representation](http://www.computationalimaging.org/publications/bacon/)
5 | [David B. Lindell](https://davidlindell.com)\*,
6 | [Dave Van Veen](https://davevanveen.com/),
7 | [Jeong Joon Park](https://jjparkcv.github.io/),
8 | [Gordon Wetzstein](https://computationalimaging.org)
9 | Stanford University
10 |
11 |
12 |
13 | ## Quickstart
14 |
15 | To set up a conda environment, use these commands:
16 | ```
17 | conda env create -f environment.yml
18 | conda activate bacon
19 |
20 | # download all datasets
21 | python download_datasets.py
22 | ```
23 | Now you can train networks to fit a 1D function, images, signed distance fields, or neural radiance fields with the following commands.
24 |
25 | ```
26 | cd experiments
27 | python train_1d.py --config ./config/1d/bacon_freq1.ini # train 1D function
28 | python train_img.py --config ./config/img/bacon.ini # train image
29 | python train_sdf.py --config ./config/sdf/bacon_armadillo.ini # train SDF
30 | python train_radiance_field.py --config ./config/nerf/bacon_lr.ini # train NeRF
31 | ```
32 |
33 | To visualize outputs in TensorBoard, run the following.
34 | ```
35 | tensorboard --logdir=../logs --port=6006
36 | ```
37 |
38 | ## Band-limited Coordinate Networks
39 |
40 | Band-limited coordinate networks have an analytical Fourier spectrum and interpretable behavior. We demonstrate using these networks for fitting simple 1D signals, images, 3D shapes via signed distance functions, and neural radiance fields.
41 |
42 | ### Datasets
43 |
44 | Datasets can be downloaded using the `download_datasets.py` script (the resulting `data/` layout is sketched below). This script
45 | - downloads the synthetic Blender dataset from the [original NeRF paper](https://github.com/bmild/nerf),
46 | - generates a multiscale version of the Blender dataset,
47 | - downloads 3D models originating from the [Stanford 3D Scanning Repository](http://graphics.stanford.edu/data/3Dscanrep/), which we have adjusted to make watertight, and
48 | - downloads an example image from the [Kodak dataset](http://www.cs.albany.edu/~xypan/research/snr/Kodak.html).
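After running the script, the `data` directory should contain roughly the following (inferred from `download_datasets.py` and the paths used in the config files):
```
data/
├── nerf_synthetic/      # Blender scenes (e.g., lego), plus generated *_multiscale image folders
├── gt_armadillo.xyz
├── gt_dragon.xyz
├── gt_lucy.xyz
├── gt_thai.xyz
└── lighthouse.png       # Kodak example image (kodim19)
```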
49 |
50 | ### Training
51 |
52 | We provide training scripts and configuration files to reproduce the results in the paper.
53 |
54 | #### 1D Examples
55 | To run the 1D examples, use the `experiments/train_1d.py` script with any of the config files in `experiments/config/1d`. These scripts allow training models with BACON, [Fourier Features](https://github.com/tancik/fourier-feature-networks), or [SIREN](https://github.com/vsitzmann/siren).
56 | For example, to train a BACON model you can run
57 |
58 | ```
59 | python train_1d.py --config ./config/1d/bacon_freq1.ini
60 | ```
61 |
62 | To change the bandwidth of BACON, adjust the maximum frequency with the `--max_freq` flag.
63 | This sets the network-equivalent sampling rate used to represent the signal.
64 | For example, if the signal you wish to represent has a maximum frequency of 5 cycles per unit interval, this value should be set to at least the Nyquist rate of 2 samples per cycle or 10 samples per unit interval.
65 | By default, the frequencies represented by BACON are quantized to intervals of 2*pi; thus, the network is periodic over an interval from -0.5 to 0.5.
66 | That is, the output of the network will repeat for input coordinates that exceed an absolute value of 0.5.
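For instance, to target a signal with content up to 5 cycles per unit interval, one could set the sampling rate to 10. The sketch below assumes the configargparse-style configs used here allow `max_freq` to be overridden from the command line:
```
python train_1d.py --config ./config/1d/bacon_freq1.ini --max_freq 10
```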
67 |
68 | #### Image Fitting
69 |
70 | Image fitting can be performed using the config files in `experiments/config/img` and the `train_img.py` script. We support training BACON, Fourier Features, SIREN, and networks with the positional encoding from [Mip-NeRF](https://github.com/google/mipnerf).
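For example, each encoding can be trained on the same image by running the corresponding config (same pattern as in the Quickstart):
```
cd experiments
python train_img.py --config ./config/img/bacon.ini   # BACON
python train_img.py --config ./config/img/ff.ini      # Fourier Features
python train_img.py --config ./config/img/siren.ini   # SIREN
python train_img.py --config ./config/img/mip.ini     # Mip-NeRF positional encoding
```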
71 |
72 | #### SDF Fitting
73 |
74 | Config files for SDF fitting are in `experiments/config/sdf` and can be used with the `train_sdf.py` script.
75 | Be sure to download the example datasets before running this script.
76 |
77 | We also provide a rendering script to extract meshes from the trained models.
78 | The `render_sdf.py` program extracts a mesh using marching cubes and, optionally, our proposed multiscale adaptive SDF evaluation procedure.
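For intuition, mesh extraction from a trained SDF network generally follows the pattern sketched below. This is only a minimal sketch using PyMCubes (pinned in `environment.yml`) and a hypothetical `sdf_model` callable; it is not the actual `render_sdf.py` implementation:
```
import numpy as np
import torch
import mcubes  # PyMCubes

def extract_mesh(sdf_model, res=256, chunk=64**3):
    # evaluate the SDF on a dense grid spanning the [-0.5, 0.5]^3 domain
    axis = np.linspace(-0.5, 0.5, res, dtype=np.float32)
    grid = np.stack(np.meshgrid(axis, axis, axis, indexing='ij'), axis=-1)
    coords = torch.from_numpy(grid.reshape(-1, 3))
    with torch.no_grad():
        sdf = torch.cat([sdf_model(c) for c in torch.split(coords, chunk)], dim=0)
    sdf = sdf.reshape(res, res, res).cpu().numpy()
    # extract the zero level set with marching cubes
    vertices, triangles = mcubes.marching_cubes(sdf, 0.0)
    return vertices, triangles
```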
79 |
80 | #### NeRF Reconstruction
81 |
82 | Use the config files in `experiments/config/nerf` with the `train_radiance_field.py` script to train neural radiance fields.
83 | Note that training the full-resolution model can take a while (a few days), so it may be easier to train a low-resolution model to get started.
84 | We provide a low-resolution config file in `experiments/config/nerf/bacon_lr.ini`.
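For reference, both configurations follow the same training pattern; the full-resolution run (`bacon.ini`) is the one that can take several days:
```
cd experiments
python train_radiance_field.py --config ./config/nerf/bacon_lr.ini  # low resolution, faster
python train_radiance_field.py --config ./config/nerf/bacon.ini     # full resolution
```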
85 |
86 | To render output images from a trained model, use the `render_nerf.py` script.
87 | Note that the Blender synthetic datasets should be downloaded and the multiscale dataset generated before running this script.
88 |
89 | #### Initialization Scheme
90 |
91 | Finally, we also show a visualization of our initialization scheme in `experiments/plot_activation_distributions.py`. As shown in the paper, our initialization scheme prevents the distribution of activations from becoming vanishingly small, even for deep networks.
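Running the script from the `experiments` directory should reproduce these plots:
```
cd experiments
python plot_activation_distributions.py
```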
92 |
93 |
94 | #### Pretrained models
95 |
96 | For convenience, we include pretrained models for the SDF fitting and NeRF reconstruction tasks in the `trained_models` directory.
97 | The outputs of these models can be rendered directly using the `experiments/render_sdf.py` and `experiments/render_nerf.py` scripts.
98 |
99 | ## Citation
100 |
101 | ```
102 | @article{lindell2021bacon,
103 | author = {Lindell, David B. and Van Veen, Dave and Park, Jeong Joon and Wetzstein, Gordon},
104 | title = {BACON: Band-limited coordinate networks for multiscale scene representation},
105 | journal = {arXiv preprint arXiv:2112.04645},
106 | year={2021}
107 | }
108 | ```
109 | ## Acknowledgments
110 |
111 | This project was supported in part by a PECASE from the ARO and NSF award 1839974.
112 |
--------------------------------------------------------------------------------
/convert_blender_data.py:
--------------------------------------------------------------------------------
1 | # adapted from Jon Barron's mipnerf conversion script
2 | # https://github.com/google/mipnerf/blob/main/scripts/convert_blender_data.py
3 |
4 | import json
5 | import os
6 | from os import path
7 |
8 | from absl import app
9 | from absl import flags
10 | import numpy as np
11 | from PIL import Image
12 |
13 | import skimage.transform
14 |
15 | FLAGS = flags.FLAGS
16 |
17 | flags.DEFINE_string('blenderdir', None, 'Base directory for all Blender data.')
18 | flags.DEFINE_integer('n_down', 4, 'How many levels of downscaling to use.')
19 |
20 |
21 | def load_renderings(data_dir, split):
22 | """Load images and metadata from disk."""
23 | f = 'transforms_{}.json'.format(split)
24 | with open(path.join(data_dir, f), 'r') as fp:
25 | meta = json.load(fp)
26 | images = []
27 | cams = []
28 | print('Loading imgs')
29 | for frame in meta['frames']:
30 | fname = os.path.join(data_dir, frame['file_path'] + '.png')
31 | with open(fname, 'rb') as imgin:
32 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.
33 | cams.append(frame['transform_matrix'])
34 | images.append(image)
35 | ret = {}
36 | ret['images'] = np.stack(images, axis=0)
37 | print('Loaded all images, shape is', ret['images'].shape)
38 | ret['camtoworlds'] = np.stack(cams, axis=0)
39 | w = ret['images'].shape[2]
40 | camera_angle_x = float(meta['camera_angle_x'])
41 | ret['focal'] = .5 * w / np.tan(.5 * camera_angle_x)
42 | return ret
43 |
44 |
45 | def down2(img):
46 | sh = img.shape
47 | return np.mean(np.reshape(img, [sh[0] // 2, 2, sh[1] // 2, 2, -1]), (1, 3))
48 |
49 |
50 | def convert_to_nerfdata(basedir, n_down):
51 | """Convert Blender data to multiscale."""
52 | splits = ['train', 'val', 'test']
53 | # For each split in the dataset
54 | for split in splits:
55 | print('Split', split)
56 | # Load everything
57 | data = load_renderings(basedir, split)
58 |
59 | # Save out all the images
60 | imgdir = '{}_multiscale'.format(split)
61 | os.makedirs(os.path.join(basedir, imgdir), exist_ok=True)
62 | print('Saving images')
63 | for i, img in enumerate(data['images']):
64 | for j in range(n_down):
65 | fname = '{}/r_{:d}_d{}.png'.format(imgdir, i, j)
66 | fname = os.path.join(basedir, fname)
67 | with open(fname, 'wb') as imgout:
68 | img = skimage.transform.resize(img, 2*(512//2**j,))
69 | img8 = Image.fromarray(np.uint8(img * 255))
70 | img8.save(imgout)
71 | # img = down2(img)
72 |
73 |
74 | def main(unused_argv):
75 |
76 | blenderdir = FLAGS.blenderdir
77 | n_down = FLAGS.n_down
78 |
79 | dirs = [os.path.join(blenderdir, f) for f in os.listdir(blenderdir)]
80 | dirs = [d for d in dirs if os.path.isdir(d)]
81 | print(dirs)
82 | for basedir in dirs:
83 | print()
84 | convert_to_nerfdata(basedir, n_down)
85 |
86 |
87 | if __name__ == '__main__':
88 | app.run(main)
89 |
--------------------------------------------------------------------------------
/dataio.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 | from PIL import Image
5 | import skimage
6 | from torchvision.transforms import Compose, ToTensor, Resize, Lambda
7 | import skimage.transform
8 | import json
9 | import os
10 | import re
11 | from tqdm import tqdm
12 | from torch.utils.data import Dataset
13 | from pykdtree.kdtree import KDTree
14 | import errno
15 | import urllib.request
16 |
17 |
18 | def get_mgrid(sidelen, dim=2, centered=True, include_end=False):
19 | '''Generates a flattened grid of (x,y,...) coordinates in the range 0 to 1 (shifted to -0.5 to 0.5 when centered=True).'''
20 | if isinstance(sidelen, int):
21 | sidelen = dim * (sidelen,)
22 |
23 | if include_end:
24 | denom = [s-1 for s in sidelen]
25 | else:
26 | denom = sidelen
27 |
28 | if dim == 2:
29 | pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32)
30 | pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / denom[0]
31 | pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / denom[1]
32 | elif dim == 3:
33 | pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
34 | pixel_coords[..., 0] = pixel_coords[..., 0] / denom[0]
35 | pixel_coords[..., 1] = pixel_coords[..., 1] / denom[1]
36 | pixel_coords[..., 2] = pixel_coords[..., 2] / denom[2]
37 | else:
38 | raise NotImplementedError('Not implemented for dim=%d' % dim)
39 |
40 | if centered:
41 | pixel_coords -= 0.5
42 |
43 | pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)
44 | return pixel_coords
45 |
46 |
47 | def lin2img(tensor, image_resolution=None):
48 | batch_size, num_samples, channels = tensor.shape
49 | if image_resolution is None:
50 | width = np.sqrt(num_samples).astype(int)
51 | height = width
52 | else:
53 | height = image_resolution[0]
54 | width = image_resolution[1]
55 |
56 | return tensor.permute(0, 2, 1).view(batch_size, channels, height, width)
57 |
58 |
59 | class Func1DWrapper(torch.utils.data.Dataset):
60 | def __init__(self, range, fn, grad_fn=None,
61 | sampling_density=100, train_every=10):
62 |
63 | coords = self.get_samples(range, sampling_density)
64 | self.fn_vals = fn(coords)
65 | self.train_idx = torch.arange(0, coords.shape[0], train_every).float()
66 |
67 | self.grid = coords
68 | self.grid.requires_grad_(True)
69 | self.range = range
70 |
71 | def get_samples(self, range, sampling_density):
72 | num = int(range[1] - range[0])*sampling_density
73 | coords = np.linspace(start=range[0], stop=range[1], num=num)
74 | coords = coords.astype(np.float32)
75 | coords = torch.Tensor(coords).view(-1, 1)
76 | return coords
77 |
78 | def get_num_samples(self):
79 | return self.grid.shape[0]
80 |
81 | def __len__(self):
82 | return 1
83 |
84 | def __getitem__(self, idx):
85 |
86 | return {'idx': self.train_idx, 'coords': self.grid}, \
87 | {'func': self.fn_vals, 'coords': self.grid}
88 |
89 |
90 | def rect(coords, width=1):
91 | return torch.where(abs(coords) < width/2, 1.0/width, 0.0)
92 |
93 |
94 | def gaussian(coords, sigma=1, center=0.5):
95 | return 1 / (sigma * math.sqrt(2*np.pi)) * torch.exp(-(coords-center)**2 / (2*sigma**2))
96 |
97 |
98 | def sines1(coords):
99 | return 0.3 * torch.sin(2*np.pi*8*coords + np.pi/3) + 0.65 * torch.sin(2*np.pi*2*coords + np.pi)
100 |
101 |
102 | def polynomial_1(coords):
103 | return .1*((coords+.2)*3)**5 - .2*((coords+.2)*3)**4 + .2*((coords+.2)*3)**3 - .4*((coords+.2)*3)**2 + .1*((coords+.2)*3)
104 |
105 |
106 | def sinc(coords):
107 | coords[coords == 0] += 1
108 | return torch.div(torch.sin(20*coords), 20*coords)
109 |
110 |
111 | def linear(coords):
112 | return 1.0 * coords
113 |
114 |
115 | def xcosx(coords):
116 | return coords * torch.cos(coords)
117 |
118 |
119 | class ImageWrapper(torch.utils.data.Dataset):
120 | def __init__(self, dataset, compute_diff='all', centered=True,
121 | include_end=False, multiscale=False, stages=3):
122 |
123 | self.compute_diff = compute_diff
124 | self.centered = centered
125 | self.include_end = include_end
126 | self.transform = Compose([
127 | ToTensor(),
128 | ])
129 |
130 | self.dataset = dataset
131 | self.mgrid = get_mgrid(self.dataset.resolution, centered=centered, include_end=include_end)
132 |
133 | # sample pixel centers
134 | self.mgrid = self.mgrid + 1 / (2 * self.dataset.resolution[0])
135 | self.radii = 1 / self.dataset.resolution[0] * 2/np.sqrt(12)
136 | self.radii = [(self.radii * 2**i).astype(np.float32) for i in range(3)]
137 | self.radii.reverse()
138 |
139 | img = self.transform(self.dataset[0])
140 | _, self.rows, self.cols = img.shape
141 |
142 | self.img_chw = img
143 | self.img = img.permute(1, 2, 0).view(-1, self.dataset.img_channels)
144 |
145 | self.imgs = []
146 | self.multiscale = multiscale
147 | img = img.permute(1, 2, 0).numpy()
148 | for i in range(stages):
149 | tmp = skimage.transform.resize(img, [s//2**i for s in (self.rows, self.cols)])
150 | tmp = skimage.transform.resize(tmp, (self.rows, self.cols))
151 | self.imgs.append(torch.from_numpy(tmp).view(-1, self.dataset.img_channels))
152 | self.imgs.reverse()
153 |
154 | def __len__(self):
155 | return len(self.dataset)
156 |
157 | def __getitem__(self, idx):
158 |
159 | coords = self.mgrid
160 | img = self.img
161 |
162 | in_dict = {'coords': coords, 'radii': self.radii}
163 | gt_dict = {'img': img}
164 |
165 | if self.multiscale:
166 | gt_dict['img'] = self.imgs
167 |
168 | return in_dict, gt_dict
169 |
170 |
171 | def save_img(img, filename):
172 | ''' given np array, convert to image and save '''
173 | img = Image.fromarray((255*img).astype(np.uint8))
174 | img.save(filename)
175 |
176 |
177 | def crop_center(pil_img, crop_width, crop_height):
178 | img_width, img_height = pil_img.size
179 | return pil_img.crop(((img_width - crop_width) // 2,
180 | (img_height - crop_height) // 2,
181 | (img_width + crop_width) // 2,
182 | (img_height + crop_height) // 2))
183 |
184 |
185 | def crop_max_square(pil_img):
186 | return crop_center(pil_img, min(pil_img.size), min(pil_img.size))
187 |
188 |
189 | class ImageFile(Dataset):
190 | def __init__(self, filename, grayscale=False, resolution=None,
191 | root_path=None, crop_square=True, url=None):
192 |
193 | super().__init__()
194 |
195 | if not os.path.exists(filename):
196 | if url is None:
197 | raise FileNotFoundError(
198 | errno.ENOENT, os.strerror(errno.ENOENT), filename)
199 | else:
200 | print('Downloading image file...')
201 | os.makedirs(os.path.dirname(filename), exist_ok=True)
202 | urllib.request.urlretrieve(url, filename)
203 |
204 | self.img = Image.open(filename)
205 | if grayscale:
206 | self.img = self.img.convert('L')
207 | else:
208 | self.img = self.img.convert('RGB')
209 |
210 | self.img_channels = len(self.img.mode)
211 | self.resolution = self.img.size
212 |
213 | if crop_square: # preserve aspect ratio
214 | self.img = crop_max_square(self.img)
215 |
216 | if resolution is not None:
217 | self.resolution = resolution
218 | self.img = self.img.resize(resolution, Image.ANTIALIAS)
219 |
220 | self.img = np.array(self.img)
221 | self.img = self.img.astype(np.float32)/255.
222 |
223 | def __len__(self):
224 | return 1
225 |
226 | def __getitem__(self, idx):
227 | return self.img
228 |
229 |
230 | def chunk_lists_from_batch_reduce_to_raysamples_fn(model_input, meta, gt, max_chunk_size):
231 |
232 | model_in_chunked = []
233 | for key in model_input:
234 | num_views, num_rays, num_samples_per_rays, num_dims = model_input[key].shape
235 | chunks = torch.split(model_input[key].view(-1, num_samples_per_rays, num_dims), max_chunk_size)
236 | model_in_chunked.append(chunks)
237 |
238 | list_chunked_model_input = \
239 | [{k: v for k, v in zip(model_input.keys(), curr_chunks)} for curr_chunks in zip(*model_in_chunked)]
240 |
241 | # meta_dict
242 | list_chunked_zs = torch.split(meta['zs'].view(-1, num_samples_per_rays, 1),
243 | max_chunk_size)
244 | list_chunked_meta = [{'zs': zs} for zs in list_chunked_zs]
245 |
246 | # gt_dict
247 | gt_chunked = []
248 | for key in gt:
249 | if isinstance(gt[key], list):
250 | # this handles lists of gt tensors (e.g., for multiscale)
251 | num_dims = gt[key][0].shape[-1]
252 |
253 | # this chunks the list elements so you have [num_tensors, num_chunks]
254 | chunks = [torch.split(x.view(-1, num_dims), max_chunk_size) for x in gt[key]]
255 |
256 | # this switches it to [num_chunks, num_tensors]
257 | chunks = [chunk for chunk in zip(*chunks)]
258 | gt_chunked.append(chunks)
259 | else:
260 | *_, num_dims = gt[key].shape
261 | chunks = torch.split(gt[key].view(-1, num_dims), max_chunk_size)
262 | gt_chunked.append(chunks)
263 |
264 | list_chunked_gt = \
265 | [{k: v for k, v in zip(gt.keys(), curr_chunks)} for curr_chunks in zip(*gt_chunked)]
266 |
267 | return list_chunked_model_input, list_chunked_meta, list_chunked_gt
268 |
269 |
270 | class NerfBlenderDataset(torch.utils.data.Dataset):
271 | def __init__(self, basedir, mode='train',
272 | splits=['train', 'val', 'test'],
273 | select_idx=None,
274 | testskip=1, resize_to=None, final_render=False,
275 | d_rot=0, bounds=((-2, 2), (-2, 2), (0, 2)),
276 | multiscale=False,
277 | black_background=False,
278 | override_scale=None):
279 |
280 | self.mode = mode
281 | self.basedir = basedir
282 | self.resize_to = resize_to
283 | self.final_render = final_render
284 | self.bounds = bounds
285 | self.multiscale = multiscale
286 | self.select_idx = select_idx
287 | self.d_rot = d_rot
288 |
289 | metas = {}
290 | for s in splits:
291 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
292 | metas[s] = json.load(fp)
293 |
294 | # Eventually transform the inputs
295 | transform_list = [ToTensor()]
296 | if resize_to is not None:
297 | transform_list.insert(0, Resize(resize_to,
298 | interpolation=Image.BILINEAR))
299 |
300 | def multiscale_resize(x):
301 | scale = 512 // x.size[0]
302 | return x.resize([r//scale for r in resize_to],
303 | resample=Image.BILINEAR)
304 |
305 | if multiscale and override_scale is None:
306 | # this will scale the image down appropriately
307 | # (e.g., to 1/2, 1/4, 1/8 of desired resolution)
308 | # Then the next transform will scale it back up so we use the same rays
309 | # to supervise
310 | transform_list.insert(0, Lambda(lambda x: multiscale_resize(x)))
311 | if black_background:
312 | transform_list.append(Lambda(lambda x: x[:3] * x[[-1]]))
313 | else:
314 | transform_list.append(Lambda(lambda x: x[:3] * x[[-1]] + (1 - x[[-1]])))
315 |
316 | self.transforms = Compose(transform_list)
317 |
318 | # Gather images and poses
319 | self.all_imgs = {}
320 | self.all_poses = {}
321 | for s in splits:
322 | meta = metas[s]
323 | imgs, poses = self.load_images(s, meta, testskip)
324 |
325 | self.all_imgs.update({s: imgs})
326 | self.all_poses.update({s: poses})
327 |
328 | if self.final_render:
329 | self.poses = [torch.from_numpy(self.pose_spherical(angle, -30.0, 4.0)).float()
330 | for angle in np.linspace(-180, 180, 40 + 1)[:-1]]
331 |
332 | if override_scale is not None:
333 | assert multiscale, 'only for multiscale'
334 | if override_scale > 3:
335 | override_scale = 3
336 | H, W = self.multiscale_imgs[0][override_scale].shape[:2]
337 | self.img_shape = self.multiscale_imgs[0][override_scale].shape
338 | else:
339 | H, W = imgs[0].shape[:2]
340 | self.img_shape = imgs[0].shape
341 |
342 | # projective camera
343 | camera_angle_x = float(meta['camera_angle_x'])
344 | focal = .5 * W / np.tan(.5 * camera_angle_x)
345 | self.camera_params = {'H': H, 'W': W,
346 | 'camera_angle_x': camera_angle_x,
347 | 'focal': focal,
348 | 'near': 2.0,
349 | 'far': 6.0}
350 |
351 | def load_images(self, s, meta, testskip):
352 | imgs = []
353 | poses = []
354 |
355 | if s == 'train' or testskip == 0:
356 | skip = 1
357 | else:
358 | skip = testskip
359 |
360 | for frame in tqdm(meta['frames'][::skip]):
361 | if self.select_idx is not None:
362 | if re.search('[0-9]+', frame['file_path']).group(0) != self.select_idx:
363 | continue
364 |
365 | def load_image(fname):
366 | img = Image.open(fname)
367 | pose = torch.from_numpy(np.array(frame['transform_matrix'], dtype=np.float32))
368 |
369 | img_t = self.transforms(img)
370 | imgs.append(img_t.permute(1, 2, 0))
371 | poses.append(pose)
372 |
373 | if self.multiscale:
374 | for i in range(4):
375 | fname = os.path.join(self.basedir, frame['file_path']).replace(s, s + '_multiscale') + f'_d{i}.png'
376 | load_image(fname)
377 |
378 | else:
379 | fname = os.path.join(self.basedir, frame['file_path'] + '.png')
380 | load_image(fname)
381 |
382 | if self.multiscale:
383 | poses = poses[::4]
384 | self.multiscale_imgs = [imgs[i:i+4][::-1] for i in range(0, len(imgs), 4)]
385 | imgs = imgs[::4]
386 |
387 | return imgs, poses
388 |
389 | # adapted from https://github.com/krrish94/nerf-pytorch
390 | # derived from original NeRF repo (MIT License)
391 | def translate_by_t_along_z(self, t):
392 | tform = np.eye(4).astype(np.float32)
393 | tform[2][3] = t
394 | return tform
395 |
396 | def rotate_by_phi_along_x(self, phi):
397 | tform = np.eye(4).astype(np.float32)
398 | tform[1, 1] = tform[2, 2] = np.cos(phi)
399 | tform[1, 2] = -np.sin(phi)
400 | tform[2, 1] = -tform[1, 2]
401 | return tform
402 |
403 | def rotate_by_theta_along_y(self, theta):
404 | tform = np.eye(4).astype(np.float32)
405 | tform[0, 0] = tform[2, 2] = np.cos(theta)
406 | tform[0, 2] = -np.sin(theta)
407 | tform[2, 0] = -tform[0, 2]
408 | return tform
409 |
410 | def pose_spherical(self, theta, phi, radius):
411 | c2w = self.translate_by_t_along_z(radius)
412 | c2w = self.rotate_by_phi_along_x(phi / 180.0 * np.pi) @ c2w
413 | c2w = self.rotate_by_theta_along_y(theta / 180 * np.pi) @ c2w
414 | c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
415 | return c2w
416 |
417 | def set_mode(self, mode):
418 | self.mode = mode
419 |
420 | def get_img_shape(self):
421 | return self.img_shape
422 |
423 | def get_camera_params(self):
424 | return self.camera_params
425 |
426 | def __len__(self):
427 | if self.final_render:
428 | return len(self.poses)
429 | else:
430 | return len(self.all_imgs[self.mode])
431 |
432 | def __getitem__(self, item):
433 | # render out trajectory (no GT images)
434 | if self.final_render:
435 | return {'img': torch.zeros(4), # we have to pass something...
436 | 'pose': self.poses[item]}
437 |
438 | # otherwise, return GT images and pose
439 | else:
440 | return {'img': self.all_imgs[self.mode][item],
441 | 'pose': self.all_poses[self.mode][item]}
442 |
443 |
444 | class Implicit6DMultiviewDataWrapper(torch.utils.data.Dataset):
445 | def __init__(self, dataset, img_shape, camera_params,
446 | samples_per_ray=128,
447 | samples_per_view=32000,
448 | num_workers=4,
449 | multiscale=False,
450 | supervise_hr=False,
451 | scales=[1/8, 1/4, 1/2, 1]):
452 |
453 | self.dataset = dataset
454 | self.num_workers = num_workers
455 | self.multiscale = multiscale
456 | self.scales = scales
457 | self.supervise_hr = supervise_hr
458 |
459 | self.img_shape = img_shape
460 | self.camera_params = camera_params
461 |
462 | self.samples_per_view = samples_per_view
463 | self.default_samples_per_view = samples_per_view
464 | self.samples_per_ray = samples_per_ray
465 |
466 | self._generate_rays_normalized()
467 | self._precompute_rays()
468 |
469 | self.is_logging = False
470 |
471 | self.val_idx = 0
472 |
473 | self.num_rays = self.all_ray_orgs.view(-1, 3).shape[0]
474 |
475 | self.shuffle_rays()
476 |
477 | if multiscale:
478 | self.multiscale_imgs = dataset.multiscale_imgs
479 |
480 | # switch to size [num_scales, num_views, img_size[0], img_size[1], 3]
481 | self.multiscale_imgs = torch.stack([torch.stack(m, dim=0)
482 | for m in zip(*self.multiscale_imgs)], dim=0)
483 |
484 | def toggle_logging_sampling(self):
485 | if self.is_logging:
486 | self.samples_per_view = self.default_samples_per_view
487 | self.is_logging = False
488 | else:
489 | self.samples_per_view = self.img_shape[0] * self.img_shape[1]
490 | self.is_logging = True
491 |
492 | def _generate_rays_normalized(self):
493 |
494 | # projective camera
495 | rows = torch.arange(0, self.img_shape[0], dtype=torch.float32)
496 | cols = torch.arange(0, self.img_shape[1], dtype=torch.float32)
497 | g_rows, g_cols = torch.meshgrid(rows, cols)
498 |
499 | W = self.camera_params['W']
500 | H = self.camera_params['H']
501 | f = self.camera_params['focal']
502 |
503 | self.norm_rays = torch.stack([(g_cols-.5*W + 0.5)/f,
504 | -(g_rows-.5*H + 0.5)/f,
505 | -torch.ones_like(g_rows)],
506 | dim=2).view(-1, 3).permute(1, 0)
507 |
508 | self.num_rays_per_view = self.norm_rays.shape[1]
509 |
510 | def shuffle_rays(self):
511 | self.shuffle_idxs = torch.randperm(self.num_rays)
512 |
513 | def _precompute_rays(self):
514 | img_list = []
515 | pose_list = []
516 | ray_orgs_list = []
517 | ray_dirs_list = []
518 |
519 | print('Precomputing rays...')
520 | for img_pose in tqdm(self.dataset):
521 | img = img_pose['img']
522 | img_list.append(img)
523 |
524 | pose = img_pose['pose']
525 | pose_list.append(pose)
526 |
527 | ray_dirs = pose[:3, :3].matmul(self.norm_rays).permute(1, 0)
528 | ray_dirs_list.append(ray_dirs)
529 |
530 | ray_orgs = pose[:3, 3].repeat((self.num_rays_per_view, 1))
531 | ray_orgs_list.append(ray_orgs)
532 |
533 | self.all_imgs = torch.stack(img_list, dim=0)
534 | self.all_poses = torch.stack(pose_list, dim=0)
535 | self.all_ray_orgs = torch.stack(ray_orgs_list, dim=0)
536 | self.all_ray_dirs = torch.stack(ray_dirs_list, dim=0)
537 |
538 | self.hit = torch.zeros(self.all_ray_dirs.view(-1, 3).shape[0])
539 |
540 | def __len__(self):
541 | if self.is_logging:
542 | return self.all_imgs.shape[0]
543 | else:
544 | return self.num_rays // self.samples_per_view
545 |
546 | def get_val_rays(self):
547 | img = self.all_imgs[self.val_idx, ...]
548 | ray_dirs = self.all_ray_dirs[self.val_idx, ...]
549 | ray_orgs = self.all_ray_orgs[self.val_idx, ...]
550 | view_samples = img
551 |
552 | if self.multiscale:
553 | img = self.multiscale_imgs[:, self.val_idx, ...]
554 | if self.supervise_hr:
555 | img = [img[-1] for _ in img]
556 | view_samples = [im for im in img]
557 |
558 | self.val_idx += 1
559 | self.val_idx %= self.all_imgs.shape[0]
560 |
561 | return view_samples, ray_orgs, ray_dirs
562 |
563 | def get_rays(self, idx):
564 | idxs = self.shuffle_idxs[self.samples_per_view * idx:self.samples_per_view * (idx+1)]
565 | ray_dirs = self.all_ray_dirs.view(-1, 3)[idxs, ...]
566 | ray_orgs = self.all_ray_orgs.view(-1, 3)[idxs, ...]
567 |
568 | if self.multiscale:
569 | view_samples = [mimg.view(-1, 3)[idxs] for mimg in self.multiscale_imgs]
570 |
571 | if self.supervise_hr:
572 | view_samples = [view_samples[-1] for _ in view_samples]
573 | else:
574 | img = self.all_imgs.view(-1, 3)[idxs, ...]
575 | view_samples = img.reshape(-1, 3)
576 |
577 | self.hit[idxs] += 1
578 |
579 | return view_samples, ray_orgs, ray_dirs
580 |
581 | def __getitem__(self, idx):
582 |
583 | if self.is_logging:
584 | view_samples, ray_orgs, ray_dirs = self.get_val_rays()
585 | else:
586 | view_samples, ray_orgs, ray_dirs = self.get_rays(idx)
587 |
588 | # Transform coordinate systems
589 | camera_params = self.dataset.get_camera_params()
590 |
591 | ray_dirs = ray_dirs[:, None, :]
592 | ray_orgs = ray_orgs[:, None, :]
593 |
594 | t_vals = torch.linspace(0.0, 1.0, self.samples_per_ray)
595 | t_vals = camera_params['near'] * (1.0 - t_vals) + camera_params['far'] * t_vals
596 | t_vals = t_vals[None, :].repeat(self.samples_per_view, 1)
597 |
598 | mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
599 | upper = torch.cat((mids, t_vals[..., -1:]), dim=-1)
600 | lower = torch.cat((t_vals[..., :1], mids), dim=-1)
601 |
602 | # Stratified samples in those intervals.
603 | t_rand = torch.rand(t_vals.shape)
604 | t_vals = lower + (upper - lower) * t_rand
605 |
606 | ray_samples = ray_orgs + ray_dirs * t_vals[..., None]
607 |
608 | t_intervals = t_vals[..., 1:] - t_vals[..., :-1]
609 | t_intervals = torch.cat((t_intervals, 1e10*torch.ones_like(t_intervals[:, 0:1])), dim=-1)
610 | t_intervals = (t_intervals * ray_dirs.norm(p=2, dim=-1))[..., None]
611 |
612 | # Compute distance samples from orgs
613 | dist_samples_to_org = torch.sqrt(torch.sum((ray_samples-ray_orgs)**2, dim=-1, keepdim=True))
614 |
615 | # broadcast tensors
616 | view_dirs = ray_dirs / ray_dirs.norm(p=2, dim=-1, keepdim=True).repeat(1, self.samples_per_ray, 1)
617 |
618 | in_dict = {'ray_samples': ray_samples,
619 | 'ray_orientations': view_dirs,
620 | 'ray_origins': ray_orgs,
621 | 't_intervals': t_intervals,
622 | 't': t_vals[..., None],
623 | 'ray_directions': ray_dirs}
624 | meta_dict = {'zs': dist_samples_to_org}
625 |
626 | gt_dict = {'pixel_samples': view_samples}
627 |
628 | return in_dict, meta_dict, gt_dict
629 |
630 |
631 | class MeshSDF(Dataset):
632 | ''' convert point cloud to SDF '''
633 |
634 | def __init__(self, pointcloud_path, num_samples=30**3,
635 | coarse_scale=1e-1, fine_scale=1e-3):
636 | super().__init__()
637 | self.num_samples = num_samples
638 | self.pointcloud_path = pointcloud_path
639 | self.coarse_scale = coarse_scale
640 | self.fine_scale = fine_scale
641 |
642 | self.load_mesh(pointcloud_path)
643 |
644 | def __len__(self):
645 | return 10000 # arbitrary
646 |
647 | def load_mesh(self, pointcloud_path):
648 | pointcloud = np.genfromtxt(pointcloud_path)
649 | self.v = pointcloud[:, :3]
650 | self.n = pointcloud[:, 3:]
651 |
652 | n_norm = (np.linalg.norm(self.n, axis=-1)[:, None])
653 | n_norm[n_norm == 0] = 1.
654 | self.n = self.n / n_norm
655 | self.v = self.normalize(self.v)
656 | self.kd_tree = KDTree(self.v)
657 | print('loaded pc')
658 |
659 | def normalize(self, coords):
660 | coords -= np.mean(coords, axis=0, keepdims=True)
661 | coord_max = np.amax(coords)
662 | coord_min = np.amin(coords)
663 | coords = (coords - coord_min) / (coord_max - coord_min) * 0.9
664 | coords -= 0.45
665 | return coords
666 |
667 | def sample_surface(self):
668 | idx = np.random.randint(0, self.v.shape[0], self.num_samples)
669 | points = self.v[idx]
670 | points[::2] += np.random.laplace(scale=self.coarse_scale, size=(points.shape[0]//2, points.shape[-1]))
671 | points[1::2] += np.random.laplace(scale=self.fine_scale, size=(points.shape[0]//2, points.shape[-1]))
672 |
673 | # wrap around any points that are sampled out of bounds
674 | points[points > 0.5] -= 1
675 | points[points < -0.5] += 1
676 |
677 | # use KDTree to get distance to surface and estimate the normal
678 | sdf, idx = self.kd_tree.query(points, k=3)
679 | avg_normal = np.mean(self.n[idx], axis=1)
680 | sdf = np.sum((points - self.v[idx][:, 0]) * avg_normal, axis=-1)
681 | sdf = sdf[..., None]
682 |
683 | return points, sdf
684 |
685 | def __getitem__(self, idx):
686 | coords, sdf = self.sample_surface()
687 |
688 | return {'coords': torch.from_numpy(coords).float()}, \
689 | {'sdf': torch.from_numpy(sdf).float()}
690 |
--------------------------------------------------------------------------------
/download_datasets.py:
--------------------------------------------------------------------------------
1 | import gdown
2 | import zipfile
3 | import subprocess
4 | import urllib.request
5 | import os
6 |
7 | # Make folder to save data in.
8 | os.makedirs('./data', exist_ok=True)
9 |
10 | # nerf
11 | print('Downloading blender dataset')
12 | gdown.download("https://drive.google.com/uc?id=18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG", './data/nerf_synthetic.zip')
13 |
14 | print('Extracting ...')
15 | with zipfile.ZipFile('./data/nerf_synthetic.zip', 'r') as f:
16 | f.extractall('./data/')
17 |
18 | print('Generating multiscale nerf dataset')
19 | subprocess.run(['python', 'convert_blender_data.py', '--blenderdir', './data/nerf_synthetic'])
20 |
21 | print('Downloading SDF datasets')
22 | gdown.download("https://drive.google.com/uc?id=1xBo6OCGmyWi0qD74EZW4lc45Gs4HXjWw", './data/gt_armadillo.xyz')
23 | gdown.download("https://drive.google.com/uc?id=1Pm3WHUvJiMJEKUnnhMjB6mUAnR9qhnxm", './data/gt_dragon.xyz')
24 | gdown.download("https://drive.google.com/uc?id=1wE24AZtXS8jbIIc-amYeEUtlxN8dFYCo", './data/gt_lucy.xyz')
25 | gdown.download("https://drive.google.com/uc?id=1OVw0JNA-NZtDXVmkf57erqwqDjqmF5Mc", './data/gt_thai.xyz')
26 |
27 | print('Downloading image dataset')
28 | urllib.request.urlretrieve('http://www.cs.albany.edu/~xypan/research/img/Kodak/kodim19.png', './data/lighthouse.png')
29 |
30 | print('Done!')
31 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: bacon
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - numpy=1.21.2
7 | - python=3.8.12
8 | - pytorch=1.10.1
9 | - tensorboard=2.6.0
10 | - tensorboard-data-server=0.6.0
11 | - tensorboard-plugin-wit=1.6.0
12 | - torchvision=0.11.2
13 | - pykdtree=1.3.4
14 | - pip
15 | - pip:
16 | - argparse==1.4.0
17 | - configargparse==1.5.3
18 | - gdown==4.2.0
19 | - matplotlib==3.4.3
20 | - pymcubes==0.1.2
21 | - scikit-image==0.18.3
22 | - scipy==1.7.1
23 | - tqdm==4.62.3
24 | - trimesh==3.9.35
25 | - setuptools==59.5.0
26 |
--------------------------------------------------------------------------------
/experiments/config/1d/bacon_freq1.ini:
--------------------------------------------------------------------------------
1 | experiment_name = bacon1_1d
2 | lr = 0.01
3 | gpu = 0
4 | model = mfn
5 | max_freq = 16.5
6 | batch_size = 1
7 | hidden_features = 128
8 | hidden_layers = 4
9 | num_steps = 1001
10 | activation = sine
11 | w0 = 10
12 | pe_scale = 3
13 | steps_til_ckpt = 100
14 | steps_til_summary = 100
15 | logging_root = ../logs
16 |
--------------------------------------------------------------------------------
/experiments/config/1d/bacon_freq2.ini:
--------------------------------------------------------------------------------
1 | experiment_name = bacon2_1d
2 | lr = 0.01
3 | gpu = 0
4 | model = mfn
5 | max_freq = 8.0
6 | batch_size = 1
7 | hidden_features = 128
8 | hidden_layers = 4
9 | num_steps = 1001
10 | activation = sine
11 | w0 = 10
12 | pe_scale = 3
13 | steps_til_ckpt = 100
14 | steps_til_summary = 100
15 | logging_root = ../logs
16 |
--------------------------------------------------------------------------------
/experiments/config/1d/ff.ini:
--------------------------------------------------------------------------------
1 | hidden_layers = 3
2 | experiment_name = ff_1d
3 | lr = 0.001
4 | gpu = 0
5 | model = mlp
6 | activation = relu
7 | w0 = 30.0
8 | pe_scale = 4.0
9 | max_freq = 16.0
10 | batch_size = 1
11 | hidden_features = 128
12 | num_steps = 1001
13 | steps_til_ckpt = 100
14 | steps_til_summary = 100
15 | logging_root = ../logs
16 |
--------------------------------------------------------------------------------
/experiments/config/1d/siren.ini:
--------------------------------------------------------------------------------
1 | experiment_name = siren_1d
2 | lr = 0.0001
3 | gpu = 0
4 | model = mlp
5 | activation = sine
6 | w0 = 30.0
7 | pe_scale = 4.0
8 | max_freq = 16.0
9 | batch_size = 1
10 | hidden_features = 128
11 | hidden_layers = 4
12 | num_steps = 1001
13 | steps_til_ckpt = 100
14 | steps_til_summary = 100
15 | logging_root = ../logs
16 |
--------------------------------------------------------------------------------
/experiments/config/img/bacon.ini:
--------------------------------------------------------------------------------
1 | config = ./image_bacon_config.ini
2 | experiment_name = bacon_img
3 | gpu = 1
4 | hidden_features = 256
5 | hidden_layers = 4
6 | res = 256
7 | centered = true
8 | model = mfn
9 | lr = 5e-3
10 | multiscale = true
11 | steps_til_summary = 100
12 | batch_size = 1
13 | activation = sine
14 | w0 = 10
15 | pe_scale = 3
16 | num_steps = 5001
17 | steps_til_ckpt = 100
18 | logging_root = ../logs
19 |
--------------------------------------------------------------------------------
/experiments/config/img/ff.ini:
--------------------------------------------------------------------------------
1 | config = ./image_ff_config.ini
2 | experiment_name = ff_img
3 | gpu = 1
4 | hidden_features = 256
5 | hidden_layers = 3
6 | res = 256
7 | centered = true
8 | model = mlp
9 | lr = 0.001
10 | multiscale = false
11 | steps_til_summary = 100
12 | batch_size = 1
13 | activation = relu
14 | w0 = 30
15 | pe_scale = 6
16 | num_steps = 5001
17 | steps_til_ckpt = 100
18 | logging_root = ../logs
19 |
--------------------------------------------------------------------------------
/experiments/config/img/mip.ini:
--------------------------------------------------------------------------------
1 | config = ./image_mip_config.ini
2 | experiment_name = mip_img
3 | gpu = 1
4 | hidden_features = 256
5 | hidden_layers = 4
6 | res = 256
7 | centered = true
8 | model = mlp
9 | lr = 0.001
10 | multiscale = true
11 | use_resized = true
12 | steps_til_summary = 100
13 | batch_size = 1
14 | activation = relu
15 | ipe = true
16 | w0 = 30
17 | pe_scale = 10
18 | num_steps = 5001
19 | steps_til_ckpt = 100
20 | logging_root = ../logs
21 |
--------------------------------------------------------------------------------
/experiments/config/img/siren.ini:
--------------------------------------------------------------------------------
1 | config = ./image_siren_config.ini
2 | experiment_name = siren_img
3 | gpu = 1
4 | hidden_features = 256
5 | hidden_layers = 4
6 | res = 256
7 | centered = true
8 | model = mlp
9 | lr = 0.0001
10 | multiscale = false
11 | steps_til_summary = 100
12 | batch_size = 1
13 | activation = sine
14 | w0 = 30
15 | pe_scale = 6
16 | num_steps = 5001
17 | steps_til_ckpt = 100
18 | logging_root = ../logs
19 |
--------------------------------------------------------------------------------
/experiments/config/nerf/bacon.ini:
--------------------------------------------------------------------------------
1 | experiment_name = bacon_lego
2 | gpu = 0
3 | num_workers = 0
4 | chunk_size_train = 4096
5 | chunk_size_eval = 4096
6 | hidden_layers = 8
7 | img_size = 512
8 | dataset_path = ../data/nerf_synthetic/lego
9 | steps_til_summary = 2000
10 | lr = 0.001
11 | hidden_features = 256
12 | multiscale = true
13 | use_resized = true
14 | reuse_filters = true
15 | samples_per_ray = 128
16 | samples_per_view = 4096
17 | logging_root = ../logs
18 | num_steps = 1000000
19 | steps_til_ckpt = 50000
20 | batch_size = 1
21 |
--------------------------------------------------------------------------------
/experiments/config/nerf/bacon_lr.ini:
--------------------------------------------------------------------------------
1 | experiment_name = bacon_lego_lowres
2 | gpu = 0
3 | num_workers = 0
4 | chunk_size_train = 4096
5 | chunk_size_eval = 4096
6 | hidden_layers = 8
7 | img_size = 64
8 | dataset_path = ../data/nerf_synthetic/lego
9 | steps_til_summary = 1000
10 | lr = 0.001
11 | hidden_features = 256
12 | multiscale = true
13 | use_resized = true
14 | reuse_filters = true
15 | samples_per_ray = 128
16 | samples_per_view = 1024
17 | logging_root = ../logs
18 | num_steps = 1000000
19 | steps_til_ckpt = 50000
20 | batch_size = 1
21 |
--------------------------------------------------------------------------------
/experiments/config/nerf/bacon_semisupervise.ini:
--------------------------------------------------------------------------------
1 | experiment_name = bacon_lego_semisupervise
2 | gpu = 0
3 | supervise_hr = true
4 | chunk_size_train = 4096
5 | chunk_size_eval = 4096
6 | hidden_layers = 8
7 | img_size = 512
8 | dataset_path = ../data/nerf_synthetic/lego
9 | steps_til_summary = 2000
10 | num_workers = 0
11 | lr = 0.001
12 | hidden_features = 256
13 | multiscale = true
14 | use_resized = true
15 | reuse_filters = true
16 | samples_per_ray = 128
17 | samples_per_view = 4096
18 | logging_root = ../logs
19 | num_steps = 1000000
20 | steps_til_ckpt = 50000
21 | batch_size = 1
22 |
--------------------------------------------------------------------------------
/experiments/config/sdf/bacon_armadillo.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_armadillo.xyz
2 | experiment_name = bacon_armadillo
3 | num_pts_on = 10000
4 | num_steps = 200000
5 | gpu = 0
6 | coarse_scale = 0.1
7 | fine_scale = 0.001
8 | max_freq = 384
9 | logging_root = ../logs
10 | steps_til_summary = 1000
11 | lr = 0.01
12 | model_type = mfn
13 | multiscale = true
14 | hidden_size = 256
15 | steps_til_ckpt = 50000
16 | ckpt_step = 0
17 | hidden_layers = 8
18 | w0 = 30
19 | num_workers = 0
20 | coarse_weight = 1e-2
21 |
--------------------------------------------------------------------------------
/experiments/config/sdf/bacon_dragon.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_dragon.xyz
2 | experiment_name = bacon_dragon
3 | num_pts_on = 10000
4 | num_steps = 200000
5 | gpu = 0
6 | coarse_scale = 0.1
7 | fine_scale = 0.001
8 | max_freq = 384
9 | logging_root = ../logs
10 | steps_til_summary = 1000
11 | lr = 0.01
12 | model_type = mfn
13 | multiscale = true
14 | hidden_size = 256
15 | steps_til_ckpt = 50000
16 | ckpt_step = 0
17 | hidden_layers = 8
18 | w0 = 30
19 | num_workers = 0
20 | coarse_weight = 1e-2
21 |
--------------------------------------------------------------------------------
/experiments/config/sdf/bacon_lucy.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_lucy.xyz
2 | experiment_name = bacon_lucy
3 | num_pts_on = 10000
4 | num_steps = 200000
5 | gpu = 0
6 | coarse_scale = 0.1
7 | fine_scale = 0.001
8 | max_freq = 512
9 | logging_root = ../logs
10 | steps_til_summary = 1000
11 | lr = 0.01
12 | model_type = mfn
13 | multiscale = true
14 | hidden_size = 256
15 | steps_til_ckpt = 50000
16 | ckpt_step = 0
17 | hidden_layers = 8
18 | w0 = 30
19 | num_workers = 0
20 | coarse_weight = 1e-2
21 |
--------------------------------------------------------------------------------
/experiments/config/sdf/bacon_thai.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_thai.xyz
2 | experiment_name = bacon_thai
3 | num_pts_on = 10000
4 | num_steps = 200000
5 | gpu = 0
6 | coarse_scale = 0.1
7 | fine_scale = 0.001
8 | max_freq = 512
9 | logging_root = ../logs
10 | steps_til_summary = 1000
11 | lr = 0.01
12 | model_type = mfn
13 | multiscale = true
14 | hidden_size = 256
15 | steps_til_ckpt = 50000
16 | ckpt_step = 0
17 | hidden_layers = 8
18 | w0 = 30
19 | num_workers = 0
20 | coarse_weight = 1e-2
21 |
--------------------------------------------------------------------------------
/experiments/config/sdf/ff_armadillo.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_armadillo.xyz
2 | experiment_name = ff_armadillo
3 | num_pts_on = 10000
4 | lr = 0.001
5 | num_steps = 200000
6 | gpu = 0
7 | model_type = ff
8 | coarse_scale = 0.1
9 | fine_scale = 0.001
10 | hidden_layers = 7
11 | max_freq = 384
12 | pe_scale = 8.0
13 | logging_root = ../logs
14 | steps_til_summary = 1000
15 | multiscale = false
16 | hidden_size = 256
17 | steps_til_ckpt = 50000
18 | ckpt_step = 0
19 | w0 = 30
20 | num_workers = 0
21 | coarse_weight = 1e-2
22 |
--------------------------------------------------------------------------------
/experiments/config/sdf/ff_dragon.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_dragon.xyz
2 | experiment_name = ff_dragon
3 | num_pts_on = 10000
4 | lr = 0.001
5 | num_steps = 200000
6 | gpu = 0
7 | model_type = ff
8 | coarse_scale = 0.1
9 | fine_scale = 0.001
10 | hidden_layers = 7
11 | max_freq = 384
12 | pe_scale = 8.0
13 | logging_root = ../logs
14 | steps_til_summary = 1000
15 | multiscale = false
16 | hidden_size = 256
17 | steps_til_ckpt = 50000
18 | ckpt_step = 0
19 | w0 = 30
20 | num_workers = 0
21 | coarse_weight = 1e-2
22 |
--------------------------------------------------------------------------------
/experiments/config/sdf/ff_lucy.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_lucy.xyz
2 | experiment_name = ff_lucy
3 | num_pts_on = 10000
4 | lr = 0.001
5 | num_steps = 200000
6 | gpu = 0
7 | model_type = ff
8 | coarse_scale = 0.1
9 | fine_scale = 0.001
10 | hidden_layers = 7
11 | max_freq = 384
12 | pe_scale = 8.0
13 | logging_root = ../logs
14 | steps_til_summary = 1000
15 | multiscale = false
16 | hidden_size = 256
17 | steps_til_ckpt = 50000
18 | ckpt_step = 0
19 | w0 = 30
20 | num_workers = 0
21 | coarse_weight = 1e-2
22 |
--------------------------------------------------------------------------------
/experiments/config/sdf/ff_thai.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_thai.xyz
2 | experiment_name = ff_thai
3 | num_pts_on = 10000
4 | lr = 0.001
5 | num_steps = 200000
6 | gpu = 0
7 | model_type = ff
8 | coarse_scale = 0.1
9 | fine_scale = 0.001
10 | hidden_layers = 7
11 | max_freq = 384
12 | pe_scale = 8.0
13 | logging_root = ../logs
14 | steps_til_summary = 1000
15 | multiscale = false
16 | hidden_size = 256
17 | steps_til_ckpt = 50000
18 | ckpt_step = 0
19 | w0 = 30
20 | num_workers = 0
21 | coarse_weight = 1e-2
22 |
--------------------------------------------------------------------------------
/experiments/config/sdf/siren_armadillo.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_armadillo.xyz
2 | experiment_name = siren_armadillo
3 | num_pts_on = 10000
4 | lr = 0.0001
5 | num_steps = 200000
6 | gpu = 0
7 | model_type = siren
8 | coarse_scale = 0.1
9 | fine_scale = 0.001
10 | max_freq = 384
11 | logging_root = ../logs
12 | steps_til_summary = 1000
13 | multiscale = false
14 | hidden_size = 256
15 | steps_til_ckpt = 50000
16 | ckpt_step = 0
17 | hidden_layers = 8
18 | w0 = 30
19 | num_workers = 0
20 | coarse_weight = 1e-2
21 | pe_scale = 5
22 |
--------------------------------------------------------------------------------
/experiments/config/sdf/siren_dragon.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_dragon.xyz
2 | experiment_name = siren_dragon
3 | num_pts_on = 10000
4 | lr = 0.0001
5 | num_steps = 200000
6 | gpu = 0
7 | model_type = siren
8 | coarse_scale = 0.1
9 | fine_scale = 0.001
10 | max_freq = 384
11 | logging_root = ../logs
12 | steps_til_summary = 1000
13 | multiscale = false
14 | hidden_size = 256
15 | steps_til_ckpt = 50000
16 | ckpt_step = 0
17 | hidden_layers = 8
18 | w0 = 30
19 | num_workers = 0
20 | coarse_weight = 1e-2
21 | pe_scale = 5
22 |
--------------------------------------------------------------------------------
/experiments/config/sdf/siren_lucy.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_lucy.xyz
2 | experiment_name = siren_lucy
3 | num_pts_on = 10000
4 | lr = 0.0001
5 | num_steps = 200000
6 | gpu = 0
7 | model_type = siren
8 | coarse_scale = 0.1
9 | fine_scale = 0.001
10 | max_freq = 384
11 | logging_root = ../logs
12 | steps_til_summary = 1000
13 | multiscale = false
14 | hidden_size = 256
15 | steps_til_ckpt = 50000
16 | ckpt_step = 0
17 | hidden_layers = 8
18 | w0 = 30
19 | num_workers = 0
20 | coarse_weight = 1e-2
21 | pe_scale = 5
22 |
--------------------------------------------------------------------------------
/experiments/config/sdf/siren_thai.ini:
--------------------------------------------------------------------------------
1 | point_cloud_path = ../data/gt_thai.xyz
2 | experiment_name = siren_thai
3 | num_pts_on = 10000
4 | lr = 0.0001
5 | num_steps = 200000
6 | gpu = 0
7 | model_type = siren
8 | coarse_scale = 0.1
9 | fine_scale = 0.001
10 | max_freq = 384
11 | logging_root = ../logs
12 | steps_til_summary = 1000
13 | multiscale = false
14 | hidden_size = 256
15 | steps_til_ckpt = 50000
16 | ckpt_step = 0
17 | hidden_layers = 8
18 | w0 = 30
19 | num_workers = 0
20 | coarse_weight = 1e-2
21 | pe_scale = 5
22 |
--------------------------------------------------------------------------------
/experiments/figure_setup.py:
--------------------------------------------------------------------------------
1 | import math
2 | import matplotlib.pyplot as plt
3 | import os
4 | import subprocess
5 | import tempfile
6 |
7 |
8 | def get_fig_size(fig_width_cm, fig_height_cm=None):
9 | """Convert dimensions in centimeters to inches.
10 | If no height is given, it is computed using the golden ratio.
11 | """
12 | if not fig_height_cm:
13 | golden_ratio = (1 + math.sqrt(5))/2
14 | fig_height_cm = fig_width_cm / golden_ratio
15 |
16 | size_cm = (fig_width_cm, fig_height_cm)
17 | return map(lambda x: x/2.54, size_cm)
18 |
19 |
20 | """
21 | The following functions can be used by scripts to get the sizes of
22 | the various elements of the figures.
23 | """
24 |
25 |
26 | def label_size():
27 | """Size of axis labels
28 | """
29 | return 8
30 |
31 |
32 | def font_size():
33 | """Size of all texts shown in plots
34 | """
35 | return 8
36 |
37 |
38 | def ticks_size():
39 | """Size of axes' ticks
40 | """
41 | return 6
42 |
43 |
44 | def axis_lw():
45 | """Line width of the axes
46 | """
47 | return 0.6
48 |
49 |
50 | def plot_lw():
51 | """Line width of the plotted curves
52 | """
53 | return 1.0
54 |
55 |
56 | def figure_setup():
57 | """Set all the sizes to the correct values and use
58 | tex fonts for all texts.
59 | """
60 |
61 | params = {'text.usetex': False,
62 | 'figure.dpi': 150,
63 | 'font.size': font_size(),
64 | 'font.sans-serif': ['helvetica', 'Arial'],
65 | 'font.serif': ['Times', 'NimbusRomNo9L-Reg', 'TeX Gyre Termes', 'Times New Roman'],
66 | 'font.monospace': [],
67 | 'lines.linewidth': plot_lw(),
68 | 'axes.labelsize': label_size(),
69 | 'axes.titlesize': font_size(),
70 | 'axes.linewidth': axis_lw(),
71 | 'legend.fontsize': font_size(),
72 | 'xtick.labelsize': ticks_size(),
73 | 'ytick.labelsize': ticks_size(),
74 | 'font.family': 'sans-serif',
75 | 'xtick.bottom': False,
76 | 'xtick.top': False,
77 | 'ytick.left': False,
78 | 'ytick.right': False,
79 | 'xtick.major.pad': -1,
80 | 'ytick.major.pad': -2,
81 | 'xtick.minor.visible': False,
82 | 'ytick.minor.visible': False,
83 | 'xtick.labelsize': 8,
84 | 'ytick.labelsize': 8,
85 | 'axes.labelpad': 0,
86 | 'axes.titlepad': 3,
87 | 'axes.unicode_minus': False,
88 | 'pdf.fonttype': 42,
89 | 'ps.fonttype': 42}
90 | #'figure.constrained_layout.use': True}
91 | plt.rcParams.update(params)
92 |
93 |
94 | def save_fig(fig, file_name, fmt=None, dpi=150, tight=False):
95 | """Save a Matplotlib figure as EPS/PNG/PDF to the given path and trim it.
96 | """
97 |
98 | if not fmt:
99 | fmt = file_name.strip().split('.')[-1]
100 |
101 | if fmt not in ['eps', 'png', 'pdf']:
102 | raise ValueError('unsupported format: %s' % (fmt,))
103 |
104 | extension = '.%s' % (fmt,)
105 | if not file_name.endswith(extension):
106 | file_name += extension
107 |
108 | file_name = os.path.abspath(file_name)
109 | with tempfile.NamedTemporaryFile() as tmp_file:
110 | tmp_name = tmp_file.name + extension
111 |
112 | # save figure
113 | if tight:
114 | #fig.savefig(tmp_name, dpi=dpi, bbox_inches='tight')
115 | fig.savefig(file_name, dpi=dpi, bbox_inches='tight', pad_inches=0)
116 | else:
117 | fig.savefig(file_name, dpi=dpi)
118 |
119 | # trim it
120 | if fmt == 'eps':
121 | subprocess.call('epstool --bbox --copy %s %s' %
122 | (tmp_name, file_name), shell=True)
123 | elif fmt == 'png':
124 | subprocess.call('convert %s -trim %s' %
125 | (tmp_name, file_name), shell=True)
126 | #elif fmt == 'pdf':
127 | # subprocess.call('pdfcrop %s %s' % (tmp_name, file_name), shell=True)
128 |
--------------------------------------------------------------------------------
/experiments/plot_activation_distributions.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 |
6 | import torch
7 | import modules
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | import torch.nn as nn
11 | import warnings
12 | import figure_setup
13 |
14 |
15 | def make_square_axes(ax, scale=1):
16 | """Make an axes square in screen units.
17 | Should be called after plotting.
18 | """
19 | ax.set_aspect(scale / ax.get_data_ratio())
20 |
21 |
22 | # original MFN implementation
23 | class MFNBase(nn.Module):
24 | """
25 | Multiplicative filter network base class.
26 | Expects the child class to define the 'filters' attribute, which should be
27 | a nn.ModuleList of n_layers+1 filters with output equal to hidden_size.
28 | """
29 |
30 | def __init__(
31 | self, hidden_size, out_size, n_layers, weight_scale, bias=True, output_act=False
32 | ):
33 | super().__init__()
34 |
35 | self.linear = nn.ModuleList(
36 | [nn.Linear(hidden_size, hidden_size, bias) for _ in range(n_layers)]
37 | )
38 | self.output_linear = nn.Linear(hidden_size, out_size)
39 | self.output_act = output_act
40 |
41 | for lin in self.linear:
42 | lin.weight.data.uniform_(
43 | -np.sqrt(weight_scale / hidden_size),
44 | np.sqrt(weight_scale / hidden_size),
45 | )
46 |
47 | return
48 |
49 | def forward(self, x):
50 | out = self.filters[0](x)
51 | for i in range(1, len(self.filters)):
52 | out = self.filters[i](x) * self.linear[i - 1](out)
53 | out = self.output_linear(out)
54 |
55 | if self.output_act:
56 | out = torch.sin(out)
57 |
58 | return out
59 |
60 |
61 | class FourierLayer(nn.Module):
62 | """
63 | Sine filter as used in FourierNet.
64 | """
65 |
66 | def __init__(self, in_features, out_features, weight_scale):
67 | super().__init__()
68 | self.linear = nn.Linear(in_features, out_features)
69 | self.linear.weight.data *= weight_scale # gamma
70 | self.linear.bias.data.uniform_(-np.pi, np.pi)
71 | return
72 |
73 | def forward(self, x):
74 | return torch.sin(self.linear(x))
75 |
76 |
77 | class FourierNet(MFNBase):
78 | def __init__(
79 | self,
80 | in_size,
81 | hidden_size,
82 | out_size,
83 | n_layers=3,
84 | input_scale=256.0,
85 | weight_scale=1.0,
86 | bias=True,
87 | output_act=False,
88 | ):
89 | super().__init__(
90 | hidden_size, out_size, n_layers, weight_scale, bias, output_act
91 | )
92 | self.filters = nn.ModuleList(
93 | [
94 | FourierLayer(in_size, hidden_size, input_scale / np.sqrt(n_layers + 1))
95 | for _ in range(n_layers + 1)
96 | ]
97 | )
98 |
99 |
100 | def plot_initialization(use_original=False, plot_fit=True):
101 |
102 | f = 256
103 | nl = 8
104 |
105 | if use_original:
106 | model = FourierNet(1, 1024, 1, nl)
107 | else:
108 | model = modules.BACON(1, 1024, 1, nl, frequency=(f, f))
109 |
110 | coords = torch.linspace(-0.5, 0.5, 1000)[:, None]
111 |
112 | activations = []
113 | activations.append(coords.mean(-1))
114 | out = model.filters[0].linear(coords)
115 | activations.append(out.flatten())
116 | out = model.filters[0](coords)
117 | activations.append(out.flatten())
118 | for i in range(1, len(model.filters)):
119 | activations.append(model.linear[i-1](out).flatten())
120 | out = model.filters[i](coords) * model.linear[i - 1](out)
121 | activations.append(out.flatten())
122 |
123 | figure_setup.figure_setup()
124 | fig = plt.figure(figsize=(6.9, 5))
125 | gs = fig.add_gridspec(5, 4, wspace=0.25, hspace=0.25)
126 | gs.update(left=0.05, right=1.0, top=0.95, bottom=0.05)
127 |
128 | for idx, act in enumerate(activations):
129 |
130 | if idx > 2:
131 | # plt.subplot(2, len(activations)//2 + 1, idx+2)
132 | fig.add_subplot(gs[(idx+1)//4, (idx+1) % 4])
133 | else:
134 | fig.add_subplot(gs[0, idx])
135 |
136 | # plt.subplot(2, len(activations)//2 + 1, idx+1)
137 | plt.hist(act.detach().cpu().numpy(), 50, density=True)
138 |
139 | if idx == 0:
140 | plt.title('Input')
141 | plt.xlim(-0.5, 0.5)
142 |
143 | if idx == 1:
144 | x = np.linspace(-120, 120, 1000)
145 |
146 | if plot_fit:
147 | plt.plot(x, 2/(2*np.pi*f/(nl+1)) * np.log(np.pi*f/(nl+1)/(np.minimum(abs(2*x), np.pi*f/(nl+1)))), 'r')
148 | plt.xlim(-120, 120)
149 | plt.title('Before Sine')
150 |
151 | if idx == 2:
152 | x = np.linspace(-1, 1, 1000)
153 |
154 | if plot_fit:
155 | with warnings.catch_warnings():
156 | warnings.simplefilter('ignore')
157 | plt.plot(x, 1/np.pi * 1 / (np.sqrt(1 - x**2)), 'r')
158 | plt.xlim(-1, 1)
159 | plt.title('After Sine')
160 |
161 | if idx % 2 == 1 and idx > 2:
162 | x = np.linspace(-4, 4, 1000)
163 | if plot_fit:
164 | plt.plot(x, 1 / np.sqrt(2*np.pi) * np.exp(-x**2/2), 'r')
165 | plt.xlim(-4, 4)
166 | plt.title('After Linear')
167 |
168 | if idx % 2 == 0 and idx > 2:
169 |             # this activation is the product of a standard normal RV and an
170 |             # arcsine-distributed RV; its density follows from the formula for
171 |             # the distribution of a product of two independent random variables
172 |             # https://en.wikipedia.org/wiki/Distribution_of_the_product_of_two_random_variables
173 |             # and its variance is 1/2, so we can use the initialization scheme of SIREN (Thm 1.8)
174 |
175 | zs = np.linspace(-4, 4, 1000)
176 | x = np.linspace(-4, 4, 5000)
177 | out = np.zeros_like(zs)
178 | dx = x[1] - x[0]
179 |             for z_idx, z in enumerate(zs):  # separate name to avoid shadowing the outer loop's idx
180 | zdx2 = (z/x)**2
181 | with warnings.catch_warnings():
182 | warnings.simplefilter('ignore')
183 | tmp = 1 / np.sqrt(2*np.pi) * np.exp(-x**2/2) * 1/np.pi * 1/np.sqrt(1 - zdx2) * 1/abs(x) * dx
184 | tmp[np.isnan(tmp)] = 0
185 |                 out[z_idx] = tmp.sum()
186 |
187 | if plot_fit:
188 | plt.plot(zs, out, 'r')
189 | plt.title('After Product')
190 | plt.xlim(-4, 4)
191 |
192 | make_square_axes(plt.gca(), 0.5)
193 | plt.show()
194 |
195 |
196 | if __name__ == '__main__':
197 | plot_initialization(use_original=False, plot_fit=True)
198 | plot_initialization(use_original=True, plot_fit=False)
199 |
--------------------------------------------------------------------------------
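A quick numerical sanity check of the distribution argument in `plot_initialization` above (this snippet is illustrative only and not part of the repository): the product of a standard normal pre-activation and an arcsine-distributed sine output has variance 1/2, which is what motivates reusing the SIREN initialization scale.

```
# Monte Carlo check: sin(U) with U ~ Uniform(-pi, pi) is arcsine-distributed on
# [-1, 1] with variance 1/2, and its product with an independent standard normal
# also has variance 1/2 (both factors are zero-mean and independent).
import numpy as np

rng = np.random.default_rng(0)
n = 1_000_000
normal = rng.standard_normal(n)                   # stand-in for a linear-layer pre-activation
arcsine = np.sin(rng.uniform(-np.pi, np.pi, n))   # stand-in for a sine-filter output
product = normal * arcsine

print(f'Var[arcsine] = {arcsine.var():.3f}')      # ~0.5
print(f'Var[product] = {product.var():.3f}')      # ~0.5 = Var[normal] * Var[arcsine]
```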
/experiments/render_nerf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import utils
6 | import dataio
7 | import modules
8 | import forward_models
9 | import training
10 | import torch
11 | import numpy as np
12 | import configargparse
13 | import dataclasses
14 | from dataclasses import dataclass
15 | from tqdm import tqdm
16 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity
17 | import skimage.io
18 | from functools import partial
19 |
20 |
21 | torch.backends.cudnn.benchmark = True
22 |
23 | ssim_fn = partial(structural_similarity, data_range=1,
24 | gaussian_weights=True, sigma=1.5,
25 | use_sample_covariance=False,
26 | multichannel=True,
27 | channel_axis=-1)
28 |
29 |
30 | @dataclass
31 | class Options:
32 | config: str
33 | experiment_name: str
34 | logging_root: str
35 | dataset_path: str
36 | num_epochs: int
37 | epochs_til_ckpt: int
38 | steps_til_summary: int
39 | gpu: int
40 | img_size: int
41 | chunk_size_train: int
42 | chunk_size_eval: int
43 | num_workers: int
44 | lr: float
45 | batch_size: int
46 | hidden_features: int
47 | hidden_layers: int
48 | model: str
49 | activation: str
50 | multiscale: bool
51 | single_network: bool
52 | use_resized: bool
53 | reuse_filters: bool
54 | samples_per_ray: int
55 | samples_per_view: int
56 | forward_mode: str
57 | supervise_hr: bool
58 | rank: int
59 |
60 | def __init__(self, **kwargs):
61 | names = set([f.name for f in dataclasses.fields(self)])
62 | for k, v in kwargs.items():
63 | if k in names:
64 | setattr(self, k, self.__annotations__[k](v))
65 |
66 | if 'supervise_hr' not in kwargs.keys():
67 | self.supervise_hr = False
68 |
69 | self.img_size = 512
70 |
71 |
72 | def load_dataset(opt, res, scale):
73 | dataset = dataio.NerfBlenderDataset(opt.dataset_path,
74 | splits=['test'],
75 | mode='test',
76 | resize_to=(int(res/2**(3-scale)), int(res/2**(3-scale))),
77 | multiscale=opt.multiscale,
78 | override_scale=scale,
79 | testskip=1)
80 |
81 | coords_dataset = dataio.Implicit6DMultiviewDataWrapper(dataset,
82 | (int(res/2**(3-scale)), int(res/2**(3-scale))),
83 | dataset.get_camera_params(),
84 | samples_per_ray=256, # opt.samples_per_ray,
85 | samples_per_view=opt.samples_per_view,
86 | num_workers=opt.num_workers,
87 | multiscale=opt.use_resized,
88 | supervise_hr=opt.supervise_hr,
89 | scales=[1/8, 1/4, 1/2, 1])
90 | coords_dataset.toggle_logging_sampling()
91 |
92 | return coords_dataset
93 |
94 |
95 | def load_model(opt, checkpoint):
96 |     # the scene spans -4 to 4 rather than -0.5 to 0.5, so the per-axis sample rate is img_size/4 (= 2*img_size/8) per unit interval
97 | sample_frequency = 3*(opt.img_size/4,)
98 |
99 | if opt.multiscale:
100 | # scale the frequencies of each layer accordingly
101 | input_scales = [1/24, 1/24, 1/24, 1/16, 1/16, 1/8, 1/8, 1/4, 1/4]
102 | output_layers = [2, 4, 6, 8]
103 |
104 | with utils.HiddenPrint():
105 | model = modules.MultiscaleBACON(3, opt.hidden_features, 4,
106 | hidden_layers=opt.hidden_layers,
107 | bias=True,
108 | frequency=sample_frequency,
109 | quantization_interval=np.pi/4,
110 | input_scales=input_scales,
111 | output_layers=output_layers,
112 | reuse_filters=opt.reuse_filters)
113 | model.cuda()
114 |
115 | print('Loading checkpoints')
116 | state_dict = torch.load(checkpoint)
117 | model.load_state_dict(state_dict, strict=False)
118 |
119 | models = {'combined': model}
120 |
121 | return models
122 |
123 |
124 | def render_in_chunks(in_dict, model, chunk_size, return_all=False):
125 | batches, rays, samples, dims = in_dict['ray_samples'].shape
126 |
127 | in_dict['ray_samples'] = in_dict['ray_samples'].reshape(-1, 3)
128 |
129 | if return_all:
130 | out = [torch.zeros(batches, rays, samples, 4, device=in_dict['ray_samples'].device) for i in range(4)]
131 | out = [o.reshape(-1, 4) for o in out]
132 | else:
133 | out = torch.zeros(batches, rays, samples, 4, device=in_dict['ray_samples'].device)
134 | out = out.reshape(-1, 4)
135 |
136 | chunk_size *= 128
137 | num_chunks = int(np.ceil(rays*samples / (chunk_size)))
138 |
139 | for i in range(num_chunks):
140 | tmp = {'ray_samples': in_dict['ray_samples'][i*chunk_size:(i+1)*chunk_size, ...]}
141 |
142 | if return_all:
143 | for j in range(4):
144 | out[j][i*chunk_size:(i+1)*chunk_size, ...] = model(tmp)['model_out']['output'][j]
145 | else:
146 | out[i*chunk_size:(i+1)*chunk_size, ...] = model(tmp)['model_out']['output'][-1]
147 |
148 | if return_all:
149 | out = [o.reshape(batches, rays, samples, 4) for o in out]
150 | return {'model_out': {'output': out}, 'model_in': {'t_intervals': in_dict['t_intervals']}}
151 |
152 | else:
153 | out = out.reshape(batches, rays, samples, 4)
154 | return {'model_out': {'output': [out]}, 'model_in': {'t_intervals': in_dict['t_intervals']}}
155 |
156 |
157 | def render_all_in_chunks(in_dict, model, scale, chunk_size):
158 | batches, rays, samples, dims = in_dict['ray_samples'].shape
159 | out = torch.zeros(batches, rays, 3)
160 | num_chunks = int(np.ceil(rays / (chunk_size)))
161 |
162 | with torch.no_grad():
163 | for i in range(num_chunks):
164 | # transfer to cuda
165 | model_in = {k: v[:, i*chunk_size:(i+1)*chunk_size, ...].cuda() for k, v in in_dict.items()}
166 | model_in = training.dict2cuda(model_in)
167 |
168 | # run first forward pass for importance sampling
169 | model.stop_after = 0
170 | model_out = {'combined': model(model_in)}
171 |
172 | # resample rays
173 | model_in = training.sample_pdf(model_in, model_out, idx=0)
174 |
175 | # importance sampled pass
176 | model.stop_after = scale
177 | model_out = {'combined': model(model_in)}
178 |
179 | # render outputs
180 | sigma = model_out['combined']['model_out']['output'][-1][..., -1:]
181 | rgb = model_out['combined']['model_out']['output'][-1][..., :-1]
182 |
183 | t_interval = model_in['t_intervals']
184 |
185 | pred_weights = forward_models.compute_transmittance_weights(sigma, t_interval)
186 | pred_pixels = forward_models.compute_tomo_radiance(pred_weights, rgb)
187 |
188 | out[:, i*chunk_size:(i+1)*chunk_size, ...] = pred_pixels.cpu()
189 |
190 | pred_view = out.view(int(np.sqrt(rays)), int(np.sqrt(rays)), 3).detach().cpu()
191 | pred_view = torch.clamp(pred_view, 0, 1).numpy()
192 | return pred_view
193 |
194 |
195 | def render_image(opt, models, dataset, chunk_size,
196 | in_dict, meta_dict, gt_dict, scale,
197 | return_all=False):
198 |
199 | # add batch dimension
200 | for k, v in in_dict.items():
201 | in_dict[k].unsqueeze_(0)
202 |
203 | for i in range(len(gt_dict['pixel_samples'])):
204 | gt_dict['pixel_samples'][i].unsqueeze_(0)
205 |
206 | use_chunks = True
207 | if in_dict['ray_samples'].shape[1] < chunk_size:
208 | use_chunks = False
209 |     use_chunks = True  # override: chunked rendering is always used, regardless of the size check above
210 |
211 | # render the whole thing in chunks
212 | if scale > 3:
213 | pred_view = render_all_in_chunks(in_dict, models['combined'], scale, chunk_size)
214 | return pred_view, 0.0, 0.0, 0.0
215 |
216 | in_dict = training.dict2cuda(in_dict)
217 |
218 | start = torch.cuda.Event(enable_timing=True)
219 | end = torch.cuda.Event(enable_timing=True)
220 | start.record()
221 |
222 | with torch.no_grad():
223 | models['combined'].stop_after = 0
224 | if use_chunks:
225 | out_dict = {key: render_in_chunks(in_dict, model, chunk_size)
226 | for key, model in models.items()}
227 | else:
228 | out_dict = {key: model(in_dict) for key, model in models.items()}
229 | models['combined'].stop_after = scale
230 |
231 | in_dict = training.sample_pdf(in_dict, out_dict, idx=0)
232 |
233 | if use_chunks:
234 | out_dict = {key: render_in_chunks(in_dict, model, chunk_size, return_all=return_all)
235 | for key, model in models.items()}
236 |
237 | else:
238 | out_dict = {key: model(in_dict) for key, model in models.items()}
239 |
240 | if return_all:
241 | sigma = [s[..., -1:] for s in out_dict['combined']['model_out']['output']]
242 | rgb = [c[..., :-1] for c in out_dict['combined']['model_out']['output']]
243 | t_interval = in_dict['t_intervals']
244 |
245 | if isinstance(gt_dict['pixel_samples'], list):
246 | gt_view = gt_dict['pixel_samples'][scale].squeeze(0).numpy()
247 | else:
248 | gt_view = gt_dict['pixel_samples'].detach().squeeze(0).numpy()
249 | view_shape = gt_view.shape
250 |
251 | pred_weights = [forward_models.compute_transmittance_weights(s, t_interval) for s in sigma]
252 | pred_pixels = [forward_models.compute_tomo_radiance(w, c) for w, c in zip(pred_weights, rgb)]
253 |
254 | pred_view = [p.view(view_shape).detach().cpu() for p in pred_pixels]
255 | pred_view = [torch.clamp(p, 0, 1).numpy() for p in pred_view]
256 |
257 | end.record()
258 | torch.cuda.synchronize()
259 | elapsed = start.elapsed_time(end)
260 |
261 | return pred_view, 0, 0, elapsed
262 |
263 | else:
264 | sigma = out_dict['combined']['model_out']['output'][-1][..., -1:]
265 | rgb = out_dict['combined']['model_out']['output'][-1][..., :-1]
266 | t_interval = in_dict['t_intervals']
267 |
268 | if isinstance(gt_dict['pixel_samples'], list):
269 | gt_view = gt_dict['pixel_samples'][scale].squeeze(0).numpy()
270 | else:
271 | gt_view = gt_dict['pixel_samples'].detach().squeeze(0).numpy()
272 | view_shape = gt_view.shape
273 |
274 | pred_weights = forward_models.compute_transmittance_weights(sigma, t_interval)
275 | pred_pixels = forward_models.compute_tomo_radiance(pred_weights, rgb)
276 |
277 | # log the images
278 | end.record()
279 | torch.cuda.synchronize()
280 | elapsed = start.elapsed_time(end)
281 |
282 | pred_view = pred_pixels.view(view_shape).detach().cpu()
283 | pred_view = torch.clamp(pred_view, 0, 1).numpy()
284 |
285 | psnr = peak_signal_noise_ratio(gt_view, pred_view)
286 | ssim = ssim_fn(gt_view, pred_view)
287 |
288 | return pred_view, psnr, ssim, elapsed
289 |
290 |
291 | def eval_nerf_bacon(scene, config, checkpoint, outdir, res, scale, chunk_size=10000, return_all=False, val_idx=None):
292 |
293 | os.makedirs(f'./outputs/nerf/{outdir}', exist_ok=True)
294 |
295 | p = configargparse.DefaultConfigFileParser()
296 | with open(config) as f:
297 | opt = p.parse(f)
298 |
299 | opt = Options(**opt)
300 | dataset = load_dataset(opt, res, scale)
301 | models = load_model(opt, checkpoint)
302 |
303 | for k in models.keys():
304 | models[k].stop_after = scale
305 |
306 | # render images
307 | psnrs = []
308 | ssims = []
309 | dataset_generator = iter(dataset)
310 | for idx in range(len(dataset)):
311 |
312 | if val_idx is not None:
313 | dataset.val_idx = val_idx
314 | idx = val_idx
315 |
316 | in_dict, meta_dict, gt_dict = next(dataset_generator)
317 |
318 | images, psnr, ssim, elapsed = render_image(opt, models, dataset, chunk_size,
319 | in_dict, meta_dict, gt_dict,
320 | scale, return_all=return_all)
321 |
322 | tqdm.write(f'Scale: {scale} | PSNR: {psnr:.02f} dB, SSIM: {ssim:.02f}, Elapsed: {elapsed:.02f} ms')
323 |
324 | if return_all:
325 | for s in range(4):
326 | skimage.io.imsave(f'./outputs/nerf/{outdir}/r_{idx}_d{3-s}.png', (images[s]*255).astype(np.uint8))
327 | else:
328 | np.save(f'./outputs/nerf/{outdir}/r_{idx}_d{3-scale}.npy', {'psnr': psnr, 'ssim': ssim})
329 | skimage.io.imsave(f'./outputs/nerf/{outdir}/r_{idx}_d{3-scale}.png', (images*255).astype(np.uint8))
330 |
331 | psnrs.append(psnr)
332 | ssims.append(ssim)
333 |
334 | if val_idx is not None:
335 | break
336 |
337 | if not return_all and val_idx is not None:
338 | np.save(f'./outputs/nerf/{outdir}/metrics_d{3-scale}.npy', {'psnr': psnrs, 'ssim': ssims,
339 | 'avg_psnr': np.mean(psnrs),
340 | 'avg_ssim': np.mean(ssims)})
341 |
342 | print(f'Avg. PSNR: {np.mean(psnrs):.02f}, Avg. SSIM: {np.mean(ssims):.02f}')
343 |
344 |
345 | if __name__ == '__main__':
346 |     # before running this you need to download the NeRF Blender dataset for the lego model
347 |     # and place it in ../data/nerf_synthetic/lego
348 | # these can be downloaded here
349 | # https://drive.google.com/drive/folders/1lrDkQanWtTznf48FCaW5lX9ToRdNDF1a
350 |
351 | # render the model trained with explicit supervision at each scale
352 | config = './config/nerf/bacon.ini'
353 | checkpoint = '../trained_models/lego.pth'
354 | outdir = 'lego'
355 | res = 512
356 | for scale in range(4):
357 | eval_nerf_bacon('lego', config, checkpoint, outdir, res, scale)
358 |
359 | # render the semisupervised model
360 | config = './config/nerf/bacon_semisupervise.ini'
361 | checkpoint = '../trained_models/lego_semisupervise.pth'
362 | outdir = 'lego_semisupervise'
363 | res = 512
364 | for scale in range(4):
365 | eval_nerf_bacon('lego_semisupervise', config, checkpoint, outdir, res, scale)
366 |
--------------------------------------------------------------------------------
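The chunked evaluation in `render_in_chunks` above boils down to a simple pattern: flatten the samples, run the network on fixed-size slices to bound GPU memory, and scatter the results back into a preallocated buffer. A minimal standalone sketch of that pattern, using a toy linear layer rather than the repository's BACON model:

```
import torch

def eval_in_chunks(fn, coords, out_dim, chunk_size=4096):
    # preallocate the output and fill it one slice at a time
    out = torch.zeros(coords.shape[0], out_dim, device=coords.device)
    for start in range(0, coords.shape[0], chunk_size):
        out[start:start + chunk_size] = fn(coords[start:start + chunk_size])
    return out

# toy usage with a stand-in model (not the repository's network)
fn = torch.nn.Linear(3, 4)
with torch.no_grad():
    result = eval_in_chunks(fn, torch.rand(100_000, 3), out_dim=4)
print(result.shape)  # torch.Size([100000, 4])
```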
/experiments/render_sdf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import utils
6 | import modules
7 | import torch
8 | import numpy as np
9 | from tqdm import tqdm
10 | import mcubes
11 | import trimesh
12 | import dataio
13 | import math
14 |
15 |
16 | def export_model(ckpt_path, model_name, N=512, model_type='bacon', hidden_layers=8,
17 | hidden_size=256, output_layers=[1, 2, 4, 8],
18 | return_sdf=False, adaptive=True):
19 |
20 | # the network has 4 output levels of detail
21 | num_outputs = len(output_layers)
22 | max_frequency = 3*(32,)
23 |
24 | # load model
25 | with utils.HiddenPrint():
26 | model = modules.MultiscaleBACON(3, hidden_size, 1,
27 | hidden_layers=hidden_layers,
28 | bias=True,
29 | frequency=max_frequency,
30 | quantization_interval=np.pi,
31 | is_sdf=True,
32 | output_layers=output_layers,
33 | reuse_filters=True)
34 |
35 | ckpt = torch.load(ckpt_path)
36 | model.load_state_dict(ckpt)
37 | model.cuda()
38 |
39 | if not adaptive:
40 | # extracts separate meshes for each scale
41 | generate_mesh(model, N, return_sdf, num_outputs, model_name)
42 |
43 | else:
44 | # extracts single-scale output
45 | generate_mesh_adaptive(model, model_name)
46 |
47 |
48 | def generate_mesh(model, N, return_sdf=False, num_outputs=4, model_name='model'):
49 |
50 | # write output
51 | x = torch.linspace(-0.5, 0.5, N)
52 | if return_sdf:
53 | x = torch.arange(-N//2, N//2) / N
54 | x = x.float()
55 | x, y, z = torch.meshgrid(x, x, x)
56 | render_coords = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=-1).cuda()
57 | sdf_values = [np.zeros((N**3, 1)) for i in range(num_outputs)]
58 |
59 | # render in a batched fashion to save memory
60 | bsize = int(128**2)
61 | for i in tqdm(range(int(N**3 / bsize))):
62 | coords = render_coords[i*bsize:(i+1)*bsize, :]
63 | out = model({'coords': coords})['model_out']
64 |
65 | if not isinstance(out, list):
66 | out = [out, ]
67 |
68 | for idx, sdf in enumerate(out):
69 | sdf_values[idx][i*bsize:(i+1)*bsize] = sdf.detach().cpu().numpy()
70 |
71 | if return_sdf:
72 | return [sdf.reshape(N, N, N) for sdf in sdf_values]
73 |
74 | for idx, sdf in enumerate(sdf_values):
75 | sdf = sdf.reshape(N, N, N)
76 | vertices, triangles = mcubes.marching_cubes(-sdf, 0)
77 | mesh = trimesh.Trimesh(vertices=vertices, faces=triangles)
78 | mesh.vertices = (mesh.vertices / N - 0.5) + 0.5/N
79 |
80 | os.makedirs('./outputs/meshes', exist_ok=True)
81 | mesh.export(f"./outputs/meshes/{model_name}_{idx+1}.obj")
82 |
83 |
84 | def prepare_multi_scale(res, num_scales):
85 | def coord2ind(xyz_coord, res):
86 | # xyz_coord: * x 3
87 | x, y, z = torch.split(xyz_coord, 1, dim=-1)
88 | flat_ind = x * res**2 + y * res + z
89 | return flat_ind.squeeze(-1) # *
90 |
91 | shifts = torch.from_numpy(np.stack(np.mgrid[:2, :2, :2], axis=-1)).view(-1, 3)
92 |
93 | def subdiv_index(xyz_prev, next_res): # should output (N^3)*8
94 | xyz_next = xyz_prev.unsqueeze(1) * 2 + shifts # (N^3)x8x3
95 | flat_ind_next = coord2ind(xyz_next, next_res) # (N^3)*8
96 | return flat_ind_next
97 |
98 | lowest_res = res / 2**(num_scales-1)
99 | subdiv_hash_list = []
100 |
101 | for i in range(num_scales-1):
102 | curr_res = int(lowest_res*2**i)
103 | xyz_ind = torch.from_numpy(np.stack(np.mgrid[:curr_res, :curr_res, :curr_res], axis=-1)).view(-1, 3) # (N^3)x3
104 | subdiv_hash = subdiv_index(xyz_ind, curr_res * 2)
105 | subdiv_hash_list.append(subdiv_hash.cuda().long())
106 | return subdiv_hash_list
107 |
108 |
109 | # multi-scale marching cubes
110 | def compute_one_scale(model, layer_ind, render_coords, sdf_values, hash_ind):
111 | assert(len(render_coords) == len(hash_ind))
112 | bsize = int(128 ** 2)
113 | for i in range(int(len(render_coords) / bsize)+1):
114 | coords = render_coords[i * bsize:(i + 1) * bsize, :]
115 | out = model({'coords': coords}, specified_layers=output_layers[layer_ind])['model_out']
116 | sdf_values[hash_ind[i * bsize:(i + 1) * bsize]] = out[0]
117 |
118 |
119 | def compute_one_scale_adaptive(model, layer_ind, render_coords, sdf_values, hash_ind, threshold=0.003):
120 | assert(len(render_coords) == len(hash_ind))
121 | bsize = int(128 ** 2)
122 | for i in range(int(len(render_coords) / bsize)+1):
123 | coords = render_coords[i * bsize:(i + 1) * bsize, :]
124 | out = model({'coords': coords}, specified_layers=2, get_feature=True)['model_out']
125 | sdf = out[0][0]
126 | if output_layers[layer_ind] > 2:
127 | feature = out[0][1]
128 | near_surf = (sdf.abs() < threshold).squeeze()
129 | coords_surf = coords[near_surf]
130 | feature_surf = feature[near_surf]
131 | out = model({'coords': coords_surf}, specified_layers=output_layers[layer_ind],
132 | continue_layer=2, continue_feature=feature_surf)['model_out']
133 | sdf_near = out[0]
134 | sdf[near_surf] = sdf_near
135 |
136 | sdf_values[hash_ind[i * bsize:(i + 1) * bsize]] = sdf
137 |
138 |
139 | def generate_mesh_adaptive(model, model_name):
140 | with torch.no_grad():
141 | lowest_res = N / 2 ** (len(output_layers) - 1)
142 | compute_one_scale(model, 0, coords_list[0], sdf_out_list[0], subdiv_hashes[0])
143 |
144 | for i in range(1, num_outputs):
145 | curr_res = int(lowest_res*2**(i-1))
146 | pixel_len = 1 / curr_res
147 | threshold = (math.sqrt(2)*pixel_len*0.5)*2
148 | sdf_prev = sdf_out_list[i-1]
149 | sdf_curr = sdf_out_list[i]
150 | hash_curr = subdiv_hashes[i]
151 | coords_curr = coords_list[i]
152 | near_surf_prev = (sdf_prev.abs() <= threshold).squeeze(-1)
153 |
154 | # empty space
155 | sdf_curr[hash_curr[~near_surf_prev]] = sdf_prev[~near_surf_prev].unsqueeze(-1)
156 |
157 | # non-empty space
158 | non_empty_ind = hash_curr[near_surf_prev].flatten()
159 |
160 | if i == num_outputs-1:
161 | compute_one_scale_adaptive(model, i, coords_curr[non_empty_ind], sdf_curr,
162 | non_empty_ind, threshold=pixel_len*0.5*2.)
163 | else:
164 | compute_one_scale(model, i, coords_curr[non_empty_ind], sdf_curr, non_empty_ind)
165 |
166 | # run marching cubes
167 | sdf = sdf_curr.reshape(N, N, N).detach().cpu().numpy()
168 | vertices, triangles = mcubes.marching_cubes(-sdf, 0)
169 | mesh = trimesh.Trimesh(vertices=vertices, faces=triangles)
170 | mesh.vertices = (mesh.vertices / N - 0.5) + 0.5/N
171 |
172 | os.makedirs('./outputs/meshes', exist_ok=True)
173 | mesh.export(f"./outputs/meshes/{model_name}.obj")
174 |
175 |
176 | def export_meshes(adaptive=True):
177 | bacon_ckpts = ['../trained_models/dragon.pth',
178 | '../trained_models/armadillo.pth',
179 | '../trained_models/lucy.pth',
180 | '../trained_models/thai.pth']
181 |
182 | bacon_names = ['bacon_dragon',
183 | 'bacon_armadillo',
184 | 'bacon_lucy',
185 | 'bacon_thai']
186 |
187 | print('Exporting BACON')
188 | for ckpt, name in tqdm(zip(bacon_ckpts, bacon_names), total=len(bacon_ckpts)):
189 | export_model(ckpt, name, model_type='bacon', output_layers=output_layers, adaptive=adaptive)
190 |
191 |
192 | def init_multiscale_mc():
193 | subdiv_hashes = prepare_multi_scale(N, len(output_layers)) # (N^3)*8
194 | subdiv_hashes = [torch.arange((N // 8) ** 3).cuda().long(), ] + subdiv_hashes
195 | lowest_res = N // 2**(len(output_layers)-1)
196 | coords_list = [dataio.get_mgrid(lowest_res*(2**i), dim=3).cuda() for i in range(len(output_layers))] # (N^3)*3
197 | sdf_out_list = [torch.zeros(((lowest_res*(2**i))**3), 1).cuda() for i in range(len(output_layers))] # (N^3)
198 |
199 | return subdiv_hashes, lowest_res, coords_list, sdf_out_list
200 |
201 |
202 | if __name__ == '__main__':
203 | global N, output_layers, subdiv_hashes, lowest_res, coords_list, sdf_out_list, num_outputs
204 | N = 512
205 | output_layers = [2, 4, 6, 8]
206 | num_outputs = len(output_layers)
207 |
208 | subdiv_hashes, lowest_res, coords_list, sdf_out_list = init_multiscale_mc()
209 |
210 | # export meshes, use adaptive SDF evaluation or not
211 | # setting adaptive=False will output meshes at all resolutions
212 |     # while adaptive=True will extract only a high-resolution mesh
213 | export_meshes(adaptive=True)
214 |
--------------------------------------------------------------------------------
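The coarse-to-fine bookkeeping in `prepare_multi_scale` above maps each voxel at resolution R to its eight children at resolution 2R through the flat index x*R^2 + y*R + z. A tiny standalone example of the same indexing on a 2^3 grid refined to 4^3 (the grid sizes are chosen only for illustration):

```
import numpy as np
import torch

def coord2ind(xyz, res):
    # flatten integer (x, y, z) coordinates at resolution `res`
    x, y, z = torch.split(xyz, 1, dim=-1)
    return (x * res**2 + y * res + z).squeeze(-1)

shifts = torch.from_numpy(np.stack(np.mgrid[:2, :2, :2], axis=-1)).view(-1, 3)  # 8 child offsets

coarse_res = 2
xyz = torch.from_numpy(np.stack(np.mgrid[:coarse_res, :coarse_res, :coarse_res], axis=-1)).view(-1, 3)
children = coord2ind(xyz.unsqueeze(1) * 2 + shifts, coarse_res * 2)  # (8 voxels, 8 children each)

print(children.shape)  # torch.Size([8, 8])
print(children[0])     # children of voxel (0, 0, 0): tensor([ 0,  1,  4,  5, 16, 17, 20, 21])
```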
/experiments/train_1d.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 |
6 | from torch.utils.tensorboard import SummaryWriter
7 | import dataio
8 | import utils
9 | import training
10 | import loss_functions
11 | import modules
12 | from torch.utils.data import DataLoader
13 | import configargparse
14 | import torch
15 | from functools import partial
16 | import numpy as np
17 |
18 | torch.backends.cudnn.benchmark = True
19 | torch.set_num_threads(4)
20 |
21 | p = configargparse.ArgumentParser()
22 | p.add('-c', '--config', required=False, is_config_file=True, help='Path to config file.')
23 |
24 | # General training options
25 | p.add_argument('--batch_size', type=int, default=1)
26 | p.add_argument('--hidden_features', type=int, default=128)
27 | p.add_argument('--hidden_layers', type=int, default=4)
28 | p.add_argument('--experiment_name', type=str, default='train_1d_mfn',
29 | help='path to directory where checkpoints & tensorboard events will be saved.')
30 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate')
31 | p.add_argument('--num_steps', type=int, default=1001,
32 | help='Number of epochs to train for.')
33 | p.add_argument('--gpu', type=int, default=0,
34 | help='GPU ID to use')
35 |
36 | p.add_argument('--model', default='mfn', choices=['mfn', 'mlp'],
37 | help='use MFN or standard MLP')
38 | p.add_argument('--activation', type=str, default='sine',
39 | choices=['sine', 'relu', 'requ', 'gelu', 'selu', 'softplus', 'tanh', 'swish'],
40 | help='activation to use (for model mlp only)')
41 | p.add_argument('--w0', type=float, default=10)
42 | p.add_argument('--pe_scale', type=float, default=3, help='positional encoding scale')
43 | p.add_argument('--no_pe', action='store_true', default=False, help='override to have no positional encoding for relu mlp')
44 | p.add_argument('--max_freq', type=float, default=5, help='The network-equivalent sample rate used to represent the signal. Should be at least twice the Nyquist frequency.')
45 |
46 | # summary options
47 | p.add_argument('--steps_til_ckpt', type=int, default=100,
48 | help='Time interval in seconds until checkpoint is saved.')
49 | p.add_argument('--steps_til_summary', type=int, default=100,
50 | help='Time interval in seconds until tensorboard summary is saved.')
51 |
52 | # logging options
53 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging')
54 |
55 | opt = p.parse_args()
56 |
57 | if opt.experiment_name is None:
58 | p.error('--experiment_name is required.')
59 |
60 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
61 |
62 |
63 | def main():
64 |
65 | print('--- Run Configuration ---')
66 | for k, v in vars(opt).items():
67 | print(k, v)
68 |
69 | train()
70 |
71 |
72 | def train():
73 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
74 | utils.cond_mkdir(root_path)
75 |
76 | fn = dataio.sines1
77 | train_dataset = dataio.Func1DWrapper(range=(-0.5, 0.5),
78 | fn=fn,
79 | sampling_density=1000,
80 |                                       train_every=1000/18)  # 18 samples is ~1.1x the Nyquist rate assuming fmax=8
81 |
82 | train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
83 |
84 | if opt.model == 'mlp':
85 | model = modules.CoordinateNet(nl=opt.activation,
86 | in_features=1,
87 | out_features=1,
88 | hidden_features=opt.hidden_features,
89 | num_hidden_layers=opt.hidden_layers,
90 | w0=opt.w0,
91 | pe_scale=opt.pe_scale,
92 | use_sigmoid=False,
93 | no_pe=opt.no_pe)
94 |
95 | elif opt.model == 'mfn':
96 | model = modules.BACON(1, opt.hidden_features, 1,
97 | hidden_layers=opt.hidden_layers,
98 | bias=True,
99 | frequency=[opt.max_freq, ],
100 | quantization_interval=2*np.pi)
101 | else:
102 | raise ValueError('model must be mlp or mfn')
103 |
104 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
105 | params = sum([np.prod(p.size()) for p in model_parameters])
106 | print(f'Num. Parameters: {params}')
107 |
108 | model.cuda()
109 |
110 | # Define the loss
111 | loss_fn = loss_functions.function_mse
112 | summary_fn = partial(utils.write_simple_1D_function_summary, train_dataset)
113 |
114 | # Save command-line parameters log directory.
115 | p.write_config_file(opt, [os.path.join(root_path, 'config.ini')])
116 | with open(os.path.join(root_path, "params.txt"), "w") as out_file:
117 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()]))
118 |
119 | # Save text summary of model into log directory.
120 | with open(os.path.join(root_path, "model.txt"), "w") as out_file:
121 | out_file.write(str(model))
122 |
123 | training.train(model=model, train_dataloader=train_dataloader, steps=opt.num_steps, lr=opt.lr,
124 | steps_til_summary=opt.steps_til_summary, steps_til_checkpoint=opt.steps_til_ckpt,
125 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn)
126 |
127 |
128 | if __name__ == '__main__':
129 | main()
130 |
--------------------------------------------------------------------------------
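The `train_every=1000/18` comment in `train_1d.py` above comes from a simple Nyquist argument; here is the arithmetic spelled out, assuming the fmax=8 figure from that comment:

```
# A 1D signal with maximum frequency fmax cycles per unit interval needs at
# least 2*fmax samples per unit interval; the script keeps 18 of the 1000 grid
# points, i.e. slightly above the Nyquist rate.
fmax = 8
nyquist_samples = 2 * fmax             # 16 samples over the unit interval
kept_samples = 1000 / (1000 / 18)      # 18 samples kept by train_every
print(kept_samples / nyquist_samples)  # ~1.125, i.e. roughly 1.1x the Nyquist rate
```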
/experiments/train_img.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 |
6 | from torch.utils.tensorboard import SummaryWriter
7 | import numpy as np
8 | import dataio
9 | import utils
10 | import training
11 | import loss_functions
12 | import modules
13 | from torch.utils.data import DataLoader
14 | import configargparse
15 | import torch
16 | from functools import partial
17 |
18 | torch.backends.cudnn.benchmark = True
19 | torch.set_num_threads(4)
20 |
21 | p = configargparse.ArgumentParser()
22 |
23 | # config file, output directories
24 | p.add('-c', '--config', required=False, is_config_file=True,
25 | help='Path to config file.')
26 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging')
27 | p.add_argument('--experiment_name', type=str, default='train_img',
28 | help='path to directory where checkpoints & tensorboard events will be saved.')
29 |
30 | # general training options
31 | p.add_argument('--model', default='mfn', choices=['mfn', 'mlp'],
32 | help='use MFN or standard MLP')
33 | p.add_argument('--batch_size', type=int, default=1)
34 | p.add_argument('--hidden_features', type=int, default=32)
35 | p.add_argument('--hidden_layers', type=int, default=4)
36 | p.add_argument('--res', type=int, default=256,
37 | help='resolution of image to fit, also used to set the network-equivalent sample rate'
38 | + ' i.e., the maximum network bandwidth in cycles per unit interval is half this value')
39 | p.add_argument('--lr', type=float, default=5e-4, help='learning rate')
40 | p.add_argument('--num_steps', type=int, default=5001,
41 | help='number of training steps')
42 | p.add_argument('--gpu', type=int, default=0,
43 | help='gpu id to use for training')
44 |
45 | # mfn options
46 | p.add_argument('--multiscale', action='store_true', default=False,
47 | help='use multiscale')
48 | p.add_argument('--use_resized', action='store_true', default=False,
49 |                help='use explicit multiscale supervision with resized ground-truth images')
50 |
51 | # mlp options
52 | p.add_argument('--activation', type=str, default='sine',
53 | choices=['sine', 'relu', 'requ', 'gelu', 'selu', 'softplus', 'tanh', 'swish'],
54 | help='activation to use (for model mlp only)')
55 | p.add_argument('--ipe', action='store_true', default=False,
56 | help='use integrated positional encoding')
57 | p.add_argument('--w0', type=float, default=10)
58 | p.add_argument('--pe_scale', type=float, default=3, help='positional encoding scale')
59 | p.add_argument('--no_pe', action='store_true', default=False,
60 | help='override to have no positional encoding for relu mlp')
61 |
62 | # data processing and i/o
63 | p.add_argument('--centered', action='store_true', default=False,
64 |                help='center input coordinates on the range -1 to 1')
65 | p.add_argument('--img_fn', type=str, default='../data/lighthouse.png',
66 | help='path to specific png filename')
67 | p.add_argument('--grayscale', action='store_true', default=False,
68 | help='if grayscale image')
69 |
70 | # summary, logging options
71 | p.add_argument('--steps_til_ckpt', type=int, default=100,
72 | help='Time interval in seconds until checkpoint is saved.')
73 | p.add_argument('--steps_til_summary', type=int, default=100,
74 | help='Time interval in seconds until tensorboard summary is saved.')
75 |
76 | opt = p.parse_args()
77 |
78 | if opt.experiment_name is None:
79 | p.error('--experiment_name is required.')
80 |
81 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
82 |
83 |
84 | def main():
85 |
86 | print('--- Run Configuration ---')
87 | for k, v in vars(opt).items():
88 | print(k, v)
89 |
90 | train()
91 |
92 |
93 | def train():
94 |
95 | # set up logging dir
96 | opt.root_path = os.path.join(opt.logging_root, opt.experiment_name)
97 | utils.cond_mkdir(opt.root_path)
98 |
99 | # get datasets
100 | trn_dataset, val_dataset, dataloader = init_dataloader(opt)
101 |
102 | # set up coordinate network
103 | model = init_model(opt)
104 |
105 | # loss and tensorboard logging functions
106 | loss_fn, summary_fn = init_loss(opt, trn_dataset, val_dataset)
107 |
108 | # back up config file
109 | save_params(opt, model)
110 |
111 | # start training
112 | training.train(model=model, train_dataloader=dataloader,
113 | steps=opt.num_steps, lr=opt.lr,
114 | steps_til_summary=opt.steps_til_summary,
115 | steps_til_checkpoint=opt.steps_til_ckpt,
116 | model_dir=opt.root_path, loss_fn=loss_fn, summary_fn=summary_fn)
117 |
118 |
119 | def init_dataloader(opt):
120 | ''' load image datasets, dataloader '''
121 |
122 | if opt.img_fn == '../data/lighthouse.png':
123 | url = 'http://www.cs.albany.edu/~xypan/research/img/Kodak/kodim19.png'
124 | else:
125 | url = None
126 |
127 | # init datasets
128 | trn_dataset = dataio.ImageFile(opt.img_fn, grayscale=opt.grayscale, resolution=(opt.res, opt.res), url=url)
129 |
130 | val_dataset = dataio.ImageFile(opt.img_fn, grayscale=opt.grayscale, resolution=(2*opt.res, 2*opt.res), url=url)
131 |
132 | trn_dataset = dataio.ImageWrapper(trn_dataset, centered=opt.centered,
133 | include_end=False,
134 | multiscale=opt.use_resized,
135 | stages=3)
136 |
137 | val_dataset = dataio.ImageWrapper(val_dataset, centered=opt.centered,
138 | include_end=False,
139 | multiscale=opt.use_resized,
140 | stages=3)
141 |
142 | dataloader = DataLoader(trn_dataset, shuffle=True, batch_size=opt.batch_size,
143 | pin_memory=True, num_workers=0)
144 |
145 | return trn_dataset, val_dataset, dataloader
146 |
147 |
148 | def init_model(opt):
149 |
150 | if opt.grayscale:
151 | out_features = 1
152 | else:
153 | out_features = 3
154 |
155 | if opt.model == 'mlp':
156 |
157 | if opt.multiscale:
158 | m = modules.MultiscaleCoordinateNet
159 | else:
160 | m = modules.CoordinateNet
161 |
162 | model = m(nl=opt.activation,
163 | in_features=2,
164 | out_features=out_features,
165 | hidden_features=opt.hidden_features,
166 | num_hidden_layers=opt.hidden_layers,
167 | w0=opt.w0,
168 | pe_scale=opt.pe_scale,
169 | no_pe=opt.no_pe,
170 | integrated_pe=opt.ipe)
171 |
172 | elif opt.model == 'mfn':
173 |
174 | if opt.multiscale:
175 | m = modules.MultiscaleBACON
176 | else:
177 | m = modules.BACON
178 |
179 | input_scales = [1/8, 1/8, 1/4, 1/4, 1/4]
180 | output_layers = [1, 2, 4]
181 |
182 | model = m(2, opt.hidden_features, out_size=out_features,
183 | hidden_layers=opt.hidden_layers,
184 | bias=True,
185 | frequency=(opt.res, opt.res),
186 | quantization_interval=2*np.pi,
187 | input_scales=input_scales,
188 | output_layers=output_layers,
189 | reuse_filters=False)
190 |
191 | else:
192 | raise ValueError('model must be mlp or mfn')
193 |
194 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
195 | params = sum([np.prod(p.size()) for p in model_parameters])
196 | print(f'Num. Parameters: {params}')
197 | model.cuda()
198 |
199 | return model
200 |
201 |
202 | def init_loss(opt, trn_dataset, val_dataset):
203 | ''' define loss, summary functions given expmt configs '''
204 |
205 | # initialize the loss
206 | if opt.multiscale:
207 | loss_fn = partial(loss_functions.multiscale_image_mse, use_resized=opt.use_resized)
208 | summary_fn = partial(utils.write_multiscale_image_summary, (opt.res, opt.res),
209 | trn_dataset, use_resized=opt.use_resized, val_dataset=val_dataset)
210 | else:
211 | loss_fn = loss_functions.image_mse
212 | summary_fn = partial(utils.write_image_summary, (opt.res, opt.res), trn_dataset,
213 | val_dataset=val_dataset)
214 |
215 | return loss_fn, summary_fn
216 |
217 |
218 | def save_params(opt, model):
219 |
220 | # Save command-line parameters log directory.
221 | p.write_config_file(opt, [os.path.join(opt.root_path, 'config.ini')])
222 | with open(os.path.join(opt.root_path, "params.txt"), "w") as out_file:
223 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()]))
224 |
225 | # Save text summary of model into log directory.
226 | with open(os.path.join(opt.root_path, "model.txt"), "w") as out_file:
227 | out_file.write(str(model))
228 |
229 |
230 | if __name__ == '__main__':
231 | main()
232 |
--------------------------------------------------------------------------------
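One way to read the `input_scales`/`output_layers` pair in `init_model` above (an interpretation based on the comments in train_radiance_field.py, not a definitive statement of the BACON internals): each layer's filters add a fraction of the maximum bandwidth, so the cumulative sum of the scales gives the bandwidth available at each output layer.

```
import numpy as np

input_scales = [1/8, 1/8, 1/4, 1/4, 1/4]   # values used for the mfn image model above
output_layers = [1, 2, 4]

cumulative = np.cumsum(input_scales)
for k in output_layers:
    print(f'output layer {k}: ~{cumulative[k]:.2f} of the maximum bandwidth')
# output layer 1: ~0.25, layer 2: ~0.50, layer 4: ~1.00, i.e. three image scales,
# matching the stages=3 setting of the multiscale ImageWrapper
```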
/experiments/train_radiance_field.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | from torch.utils.tensorboard import SummaryWriter
6 | import dataio
7 | import utils
8 | import training
9 | import loss_functions
10 | import modules
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 | from functools import partial
14 | import torch
15 | import numpy as np
16 |
17 | torch.set_num_threads(8)
18 | torch.backends.cudnn.benchmark = True
19 |
20 | p = configargparse.ArgumentParser()
21 | p.add('-c', '--config', required=False, is_config_file=True, help='Path to config file.')
22 |
23 | # Experiment & I/O general properties
24 | p.add_argument('--experiment_name', type=str, default=None,
25 | help='path to directory where checkpoints & tensorboard events will be saved.')
26 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
27 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging')
28 | p.add_argument('--dataset_path', type=str, default='../data/nerf_synthetic/lego/',
29 | help='path to directory where dataset is stored')
30 | p.add_argument('--resume', nargs=2, type=str, default=None,
31 | help='resume training, specify path to directory where model is stored.')
32 | p.add_argument('--num_steps', type=int, default=1000000,
33 | help='Number of iterations to train for.')
34 | p.add_argument('--steps_til_ckpt', type=int, default=50000,
35 | help='Iterations until checkpoint is saved.')
36 | p.add_argument('--steps_til_summary', type=int, default=2000,
37 | help='Iterations until tensorboard summary is saved.')
38 |
39 | # GPU & other computing properties
40 | p.add_argument('--gpu', type=int, default=0,
41 | help='GPU ID to use')
42 | p.add_argument('--chunk_size_train', type=int, default=1024,
43 | help='max chunk size to process data during training')
44 | p.add_argument('--chunk_size_eval', type=int, default=512,
45 | help='max chunk size to process data during eval')
46 | p.add_argument('--num_workers', type=int, default=0, help='number of dataloader workers.')
47 |
48 | # Learning properties
49 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate')
50 | p.add_argument('--batch_size', type=int, default=1)
51 |
52 | # Network architecture properties
53 | p.add_argument('--hidden_features', type=int, default=128)
54 | p.add_argument('--hidden_layers', type=int, default=4)
55 |
56 | p.add_argument('--multiscale', action='store_true', help='use multiscale architecture')
57 | p.add_argument('--supervise_hr', action='store_true', help='supervise only with high resolution signal')
58 | p.add_argument('--use_resized', action='store_true', help='use explicit multiscale supervision')
59 | p.add_argument('--reuse_filters', action='store_true', help='reuse fourier filters for faster training/inference')
60 |
61 | # NeRF Properties
62 | p.add_argument('--img_size', type=int, default=64,
63 | help='image resolution to train on (assumed symmetric)')
64 | p.add_argument('--samples_per_ray', type=int, default=128,
65 | help='samples to evaluate along each ray')
66 | p.add_argument('--samples_per_view', type=int, default=1024,
67 | help='samples to evaluate along each view')
68 |
69 | opt = p.parse_args()
70 |
71 | if opt.experiment_name is None:
72 | p.error('--experiment_name is required.')
73 |
74 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
75 |
76 |
77 | def main():
78 | print('--- Run Configuration ---')
79 | for k, v in vars(opt).items():
80 | print(k, v)
81 | train()
82 |
83 |
84 | def train(validation=True):
85 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
86 | utils.cond_mkdir(root_path)
87 |
88 | ''' Training dataset '''
89 | dataset = dataio.NerfBlenderDataset(opt.dataset_path,
90 | splits=['train'],
91 | mode='train',
92 | resize_to=2*(opt.img_size,),
93 | multiscale=opt.multiscale)
94 |
95 | coords_dataset = dataio.Implicit6DMultiviewDataWrapper(dataset,
96 | dataset.get_img_shape(),
97 | dataset.get_camera_params(),
98 | samples_per_view=opt.samples_per_view,
99 | num_workers=opt.num_workers,
100 | multiscale=opt.use_resized,
101 | supervise_hr=opt.supervise_hr,
102 | scales=[1/8, 1/4, 1/2, 1])
103 | ''' Validation dataset '''
104 | if validation:
105 | val_dataset = dataio.NerfBlenderDataset(opt.dataset_path,
106 | splits=['val'],
107 | mode='val',
108 | resize_to=2*(opt.img_size,),
109 | multiscale=opt.multiscale)
110 |
111 | val_coords_dataset = dataio.Implicit6DMultiviewDataWrapper(val_dataset,
112 | val_dataset.get_img_shape(),
113 | val_dataset.get_camera_params(),
114 | samples_per_view=opt.img_size**2,
115 | num_workers=opt.num_workers,
116 | multiscale=opt.use_resized,
117 | supervise_hr=opt.supervise_hr,
118 | scales=[1/8, 1/4, 1/2, 1])
119 |
120 | ''' Dataloaders'''
121 | dataloader = DataLoader(coords_dataset, shuffle=True, batch_size=opt.batch_size, # num of views in a batch
122 | pin_memory=True, num_workers=opt.num_workers)
123 |
124 | if validation:
125 | val_dataloader = DataLoader(val_coords_dataset, shuffle=True, batch_size=1,
126 | pin_memory=True, num_workers=opt.num_workers)
127 | else:
128 | val_dataloader = None
129 |
130 | # get model paths
131 | if opt.resume is not None:
132 | path, step = opt.resume
133 | step = int(step)
134 | assert(os.path.isdir(path))
135 | assert opt.config is not None, 'Specify config file'
136 |
137 | # since model goes between -4 and 4 instead of -0.5 to 0.5
138 | # we divide by a factor of 8. Then this is Nyquist sampled
139 | # assuming a maximum frequency of opt.img_size/8 cycles per unit interval
140 | # (where the blender dataset scenes typically span from -4 to 4 units)
141 | rgb_sample_freq = 3*(2*opt.img_size/8,)
142 |
143 | if opt.multiscale:
144 | # scale the frequencies of each layer
145 | # so that we have outputs at 1/8, 1/4, 1/2, and 1x
146 |         # the maximum network bandwidth
147 | input_scales = [1/24, 1/24, 1/24, 1/16, 1/16, 1/8, 1/8, 1/4, 1/4]
148 | output_layers = [2, 4, 6, 8]
149 |
150 | model = modules.MultiscaleBACON(3, opt.hidden_features, 4,
151 | hidden_layers=opt.hidden_layers,
152 | bias=True,
153 | frequency=rgb_sample_freq,
154 | quantization_interval=np.pi/4,
155 | input_scales=input_scales,
156 | output_layers=output_layers,
157 | reuse_filters=opt.reuse_filters)
158 | model.cuda()
159 |
160 | else:
161 | input_scales = [1/24, 1/24, 1/24, 1/16, 1/16, 1/8, 1/8, 1/4, 1/4]
162 | input_scales = input_scales[:opt.hidden_layers+1]
163 |
164 | model = modules.BACON(3, opt.hidden_features, 4,
165 | hidden_layers=opt.hidden_layers,
166 | bias=True,
167 | frequency=rgb_sample_freq,
168 | quantization_interval=np.pi/4,
169 | reuse_filters=opt.reuse_filters,
170 | input_scales=input_scales)
171 | model.cuda()
172 |
173 | if opt.resume is not None:
174 | print('Loading checkpoints')
175 |
176 | state_dict = torch.load(path + '/checkpoints/' + f'model_combined_step_{step:04d}.pth')
177 | model.load_state_dict(state_dict, strict=False)
178 |
179 | # load optimizers
180 | try:
181 | resume_checkpoint = {}
182 | ckpt = torch.load(path + '/checkpoints/' + f'optim_combined_step_{step:04d}.pth')
183 | for g in ckpt['optimizer_state_dict']['param_groups']:
184 | g['lr'] = opt.lr
185 |
186 | resume_checkpoint['combined'] = {}
187 | resume_checkpoint['combined']['optim'] = ckpt['optimizer_state_dict']
188 | resume_checkpoint['combined']['scheduler'] = ckpt['scheduler_state_dict']
189 | resume_checkpoint['step'] = ckpt['step']
190 | except FileNotFoundError:
191 | print('Unable to load optimizer checkpoints')
192 | else:
193 | resume_checkpoint = {}
194 |
195 | models = {'combined': model}
196 |
197 | # Define the loss
198 | if opt.multiscale:
199 | loss_fn = partial(loss_functions.multiscale_radiance_loss, use_resized=opt.use_resized)
200 | summary_fn = partial(utils.write_multiscale_radiance_summary,
201 | chunk_size_eval=opt.chunk_size_eval,
202 | num_views_to_disp_at_training=1,
203 | hierarchical_sampling=True)
204 | else:
205 | loss_fn = partial(loss_functions.radiance_sigma_rgb_loss)
206 |
207 | summary_fn = partial(utils.write_radiance_summary,
208 | chunk_size_eval=opt.chunk_size_eval,
209 | num_views_to_disp_at_training=1,
210 | hierarchical_sampling=True)
211 |
212 | chunk_lists_from_batch_fn = dataio.chunk_lists_from_batch_reduce_to_raysamples_fn
213 |
214 | # Save command-line parameters log directory.
215 | p.write_config_file(opt, [os.path.join(root_path, 'config.ini')])
216 | with open(os.path.join(root_path, "params.txt"), "w") as out_file:
217 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()]))
218 |
219 | # Save text summary of model into log directory.
220 | with open(os.path.join(root_path, "model.txt"), "w") as out_file:
221 | for model_name, model in models.items():
222 | out_file.write(model_name)
223 | out_file.write(str(model))
224 |
225 | training.train_wchunks(models, dataloader,
226 | num_steps=opt.num_steps, lr=opt.lr,
227 | steps_til_summary=opt.steps_til_summary, steps_til_checkpoint=opt.steps_til_ckpt,
228 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn,
229 | val_dataloader=val_dataloader,
230 | chunk_lists_from_batch_fn=chunk_lists_from_batch_fn,
231 | max_chunk_size=opt.chunk_size_train,
232 | resume_checkpoint=resume_checkpoint,
233 | chunked=True,
234 | hierarchical_sampling=True,
235 | stop_after=0)
236 |
237 |
238 | if __name__ == '__main__':
239 | main()
240 |
--------------------------------------------------------------------------------
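The sampling-rate comment in `train()` above reduces to a short calculation; spelled out here with the default `--img_size` of 64 (these numbers are just the script's defaults, not a recommendation):

```
# The Blender scenes span roughly -4 to 4, i.e. 8 spatial units. Nyquist-sampling
# a maximum frequency of img_size/8 cycles per unit interval requires
# 2 * img_size / 8 samples per unit interval along each axis.
img_size = 64
scene_extent = 8
rgb_sample_freq = 2 * img_size / scene_extent
print(rgb_sample_freq)          # 16.0 samples per unit interval
print(3 * (rgb_sample_freq,))   # (16.0, 16.0, 16.0), the frequency tuple passed to the model
```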
/experiments/train_sdf.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | from torch.utils.tensorboard import SummaryWriter
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import configargparse
10 | import dataio
11 | import utils
12 | import training
13 | import loss_functions
14 | import modules
15 | from functools import partial
16 |
17 | torch.set_num_threads(8)
18 |
19 | p = configargparse.ArgumentParser()
20 |
21 | # config file, output directories
22 | p.add('-c', '--config', required=False, is_config_file=True,
23 | help='Path to config file.')
24 | p.add_argument('--logging_root', type=str, default='../logs',
25 | help='root for logging')
26 | p.add_argument('--experiment_name', type=str, required=True,
27 | help='subdirectory in logging_root for checkpoints, summaries')
28 |
29 | # general training
30 | p.add_argument('--model_type', type=str, default='mfn',
31 | help='options: mfn, siren, ff')
32 | p.add_argument('--hidden_size', type=int, default=128,
33 | help='size of hidden layer')
34 | p.add_argument('--hidden_layers', type=int, default=8)
35 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate')
36 | p.add_argument('--num_steps', type=int, default=20000,
37 | help='number of training steps')
38 | p.add_argument('--ckpt_step', type=int, default=0,
39 | help='step at which to resume training')
40 | p.add_argument('--gpu', type=int, default=1, help='GPU ID to use')
41 | p.add_argument('--seed', default=None,
42 | help='random seed for experiment reproducibility')
43 |
44 | # mfn options
45 | p.add_argument('--multiscale', action='store_true', default=False,
46 | help='use multiscale')
47 | p.add_argument('--max_freq', type=int, default=512,
48 | help='The network-equivalent sample rate used to represent the signal.'
49 |                + ' Should be at least twice the Nyquist frequency.')
50 | p.add_argument('--input_scales', nargs='*', type=float, default=None,
51 | help='fraction of resolution growth at each layer')
52 | p.add_argument('--output_layers', nargs='*', type=int, default=None,
53 | help='layer indices to output, beginning at 1')
54 |
55 | # mlp options
56 | p.add_argument('--w0', default=30, type=int,
57 | help='omega_0 parameter for siren')
58 | p.add_argument('--pe_scale', default=5, type=float,
59 | help='positional encoding scale')
60 |
61 | # sdf model and sampling
62 | p.add_argument('--num_pts_on', type=int, default=10000,
63 | help='number of on-surface points to sample')
64 | p.add_argument('--coarse_scale', type=float, default=1e-1,
65 | help='laplacian scale factor for coarse samples')
66 | p.add_argument('--fine_scale', type=float, default=1e-3,
67 | help='laplacian scale factor for fine samples')
68 | p.add_argument('--coarse_weight', type=float, default=1e-2,
69 | help='weight to apply to coarse loss samples')
70 |
71 | # data i/o
72 | p.add_argument('--shape', type=str, default='bunny',
73 | help='name of point cloud shape in xyz format')
74 | p.add_argument('--point_cloud_path', type=str,
75 | default='../data/armadillo.xyz',
76 | help='path for input point cloud')
77 | p.add_argument('--num_workers', default=0, type=int,
78 | help='number of workers')
79 |
80 | # tensorboard summary
81 | p.add_argument('--steps_til_ckpt', type=int, default=50000,
82 | help='epoch frequency to save a checkpoint')
83 | p.add_argument('--steps_til_summary', type=int, default=1000,
84 | help='epoch frequency to update tensorboard summary')
85 |
86 | opt = p.parse_args()
87 |
88 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
89 |
90 |
91 | def main():
92 |
93 | print('--- Run Configuration ---')
94 | for k, v in vars(opt).items():
95 | print(k, v)
96 |
97 | train()
98 |
99 |
100 | def train():
101 |
102 | opt.root_path = os.path.join(opt.logging_root, opt.experiment_name)
103 | utils.cond_mkdir(opt.root_path)
104 |
105 | if opt.seed:
106 | torch.manual_seed(int(opt.seed))
107 | np.random.seed(int(opt.seed))
108 |
109 | dataloader = init_dataloader(opt)
110 |
111 | model = init_model(opt)
112 |
113 | loss_fn, summary_fn = init_loss(opt)
114 |
115 | save_params(opt, model)
116 |
117 | training.train(model=model, train_dataloader=dataloader, steps=opt.num_steps,
118 | lr=opt.lr, steps_til_summary=opt.steps_til_summary,
119 | ckpt_step=opt.ckpt_step,
120 | steps_til_checkpoint=opt.steps_til_ckpt,
121 | model_dir=opt.root_path, loss_fn=loss_fn, summary_fn=summary_fn,
122 | double_precision=False, clip_grad=True,
123 | use_lr_scheduler=True)
124 |
125 |
126 | def init_dataloader(opt):
127 | ''' load sdf dataloader via eikonal equation or fitting sdf directly '''
128 |
129 | sdf_dataset = dataio.MeshSDF(opt.point_cloud_path,
130 | num_samples=opt.num_pts_on,
131 | coarse_scale=opt.coarse_scale,
132 | fine_scale=opt.fine_scale)
133 |
134 | dataloader = DataLoader(sdf_dataset, shuffle=True,
135 | batch_size=1, pin_memory=True,
136 | num_workers=opt.num_workers)
137 |
138 | return dataloader
139 |
140 |
141 | def init_model(opt):
142 | ''' return appropriate model given experiment configs '''
143 |
144 | if opt.model_type == 'mfn':
145 |
146 | opt.input_scales = [1/24, 1/24, 1/24, 1/16, 1/16, 1/8, 1/8, 1/4, 1/4]
147 | opt.output_layers = [2, 4, 6, 8]
148 |
149 | frequency = (opt.max_freq, opt.max_freq, opt.max_freq)
150 |
151 | if opt.multiscale:
152 | if opt.output_layers and len(opt.output_layers) == 1:
153 | raise ValueError('expects >1 layer extraction if multiscale')
154 | model_ = modules.MultiscaleBACON
155 | else:
156 | model_ = modules.BACON
157 |
158 | model = model_(in_size=3, hidden_size=opt.hidden_size, out_size=1,
159 | hidden_layers=opt.hidden_layers,
160 | bias=True,
161 | frequency=frequency,
162 | quantization_interval=2*np.pi, # data on range [-0.5, 0.5]
163 | input_scales=opt.input_scales,
164 | is_sdf=True,
165 | output_layers=opt.output_layers,
166 | reuse_filters=True)
167 |
168 | elif opt.model_type == 'siren':
169 |
170 | if opt.multiscale:
171 | model_ = modules.MultiscaleCoordinateNet
172 | else:
173 | model_ = modules.CoordinateNet
174 |
175 | model = model_(nl='sine',
176 | in_features=3,
177 | out_features=1,
178 | num_hidden_layers=opt.hidden_layers,
179 | hidden_features=opt.hidden_size,
180 | w0=opt.w0,
181 | is_sdf=True)
182 |
183 | elif opt.model_type == 'ff': # mlp w relu + positional encoding
184 | model = modules.CoordinateNet(nl='relu',
185 | in_features=3,
186 | out_features=1,
187 | num_hidden_layers=opt.hidden_layers,
188 | hidden_features=opt.hidden_size,
189 | is_sdf=True,
190 | pe_scale=opt.pe_scale,
191 | use_sigmoid=False)
192 | else:
193 | raise NotImplementedError
194 |
195 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
196 | params = sum([np.prod(p.size()) for p in model_parameters])
197 | print(f'Num. Parameters: {params}')
198 | model.cuda()
199 |
200 | # if resuming model training
201 | if opt.ckpt_step:
202 |         opt.num_steps -= opt.ckpt_step  # steps remaining to train
203 |         if opt.num_steps < 1:
204 |             raise ValueError('ckpt_step must be less than num_steps')
205 | print(opt.num_steps)
206 |
207 | pth_file = '{}/checkpoints/model_step_{}.pth'.format(opt.root_path,
208 | str(opt.ckpt_step).zfill(4))
209 | model.load_state_dict(torch.load(pth_file))
210 |
211 | return model
212 |
213 |
214 | def init_loss(opt):
215 | ''' define loss, summary functions given expmt configs '''
216 |
217 | if opt.multiscale:
218 | summary_fn = utils.write_multiscale_sdf_summary
219 | loss_fn = partial(loss_functions.multiscale_overfit_sdf,
220 | coarse_loss_weight=opt.coarse_weight)
221 | else:
222 | summary_fn = utils.write_sdf_summary
223 | loss_fn = partial(loss_functions.overfit_sdf,
224 | coarse_loss_weight=opt.coarse_weight)
225 |
226 | return loss_fn, summary_fn
227 |
228 |
229 | def save_params(opt, model):
230 |
231 | # Save command-line parameters log directory.
232 | p.write_config_file(opt, [os.path.join(opt.root_path, 'config.ini')])
233 | with open(os.path.join(opt.root_path, "params.txt"), "w") as out_file:
234 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()]))
235 |
236 | # Save text summary of model into log directory.
237 | with open(os.path.join(opt.root_path, "model.txt"), "w") as out_file:
238 | out_file.write(str(model))
239 |
240 |
241 | if __name__ == '__main__':
242 | main()
243 |
--------------------------------------------------------------------------------
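When resuming SDF training with `--ckpt_step`, `init_model` above rebuilds the checkpoint path with a zero-padded step count and shrinks the remaining step budget. A small illustration with made-up values (the experiment name and step below are hypothetical):

```
import os

logging_root = '../logs'
experiment_name = 'bacon_armadillo'   # hypothetical experiment directory
ckpt_step = 5000
num_steps = 20000

root_path = os.path.join(logging_root, experiment_name)
pth_file = '{}/checkpoints/model_step_{}.pth'.format(root_path, str(ckpt_step).zfill(4))
remaining = num_steps - ckpt_step

print(pth_file)    # ../logs/bacon_armadillo/checkpoints/model_step_5000.pth
print(remaining)   # 15000 steps left to train
```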
/forward_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def cumprod_exclusive(tensor, dim=-2):
5 | cumprod = torch.cumprod(tensor, dim)
6 | cumprod = torch.roll(cumprod, 1, dim)
7 | cumprod[..., 0, :] = 1.0
8 | return cumprod
9 |
10 |
11 | def compute_transmittance_weights(pred_sigma, t_intervals):
12 | # pred_alpha = 1.-torch.exp(-torch.relu(pred_sigma)*t_intervals)
13 | tau = torch.nn.functional.softplus(pred_sigma - 1)
14 | pred_alpha = 1.-torch.exp(-tau*t_intervals)
15 | pred_weights = pred_alpha * cumprod_exclusive(1.-pred_alpha+1e-10, dim=-2)
16 | return pred_weights
17 |
18 |
19 | def compute_tomo_radiance(pred_weights, pred_rgb, black_background=False):
20 | eps = 0.001
21 | pred_rgb_pos = torch.sigmoid(pred_rgb)
22 | pred_rgb_pos = pred_rgb_pos * (1 + 2 * eps) - eps
23 | pred_pixel_samples = torch.sum(pred_rgb_pos*pred_weights, dim=-2) # line integral
24 |
25 | if not black_background:
26 | pred_pixel_samples += 1 - pred_weights.sum(-2)
27 | return pred_pixel_samples
28 |
29 |
30 | def compute_tomo_depth(pred_weights, zs):
31 | pred_depth = torch.sum(pred_weights*zs, dim=-2)
32 | return pred_depth
33 |
34 |
35 | def compute_disp_from_depth(pred_depth, pred_weights):
36 | pred_disp = 1. / torch.max(torch.tensor(1e-10).to(pred_depth.device),
37 | pred_depth / torch.sum(pred_weights, -2))
38 | return pred_disp
39 |
--------------------------------------------------------------------------------
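As a usage sketch (not one of the repository files), the helpers in `forward_models.py` compose into a standard volume-rendering forward pass; the tensor shapes and ray bounds below are assumptions chosen only for illustration.

```
import torch
import forward_models

# Hypothetical batch: 1024 rays with 64 samples each (shapes are illustrative only).
pred_sigma = torch.randn(1, 1024, 64, 1)    # raw density outputs from a network
pred_rgb = torch.randn(1, 1024, 64, 3)      # raw color outputs from a network
zs = torch.linspace(2.0, 6.0, 64).view(1, 1, 64, 1).expand(1, 1024, 64, 1)   # sample depths
t_intervals = torch.full((1, 1024, 64, 1), (6.0 - 2.0) / 64)                 # spacing along each ray

weights = forward_models.compute_transmittance_weights(pred_sigma, t_intervals)
rgb = forward_models.compute_tomo_radiance(weights, pred_rgb)   # (1, 1024, 3), white background by default
depth = forward_models.compute_tomo_depth(weights, zs)          # (1, 1024, 1) expected depth per ray
disp = forward_models.compute_disp_from_depth(depth, weights)   # inverse depth for visualization
```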
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/img/teaser.png
--------------------------------------------------------------------------------
/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import forward_models
3 |
4 |
5 | def function_mse(model_output, gt):
6 | idx = model_output['model_in']['idx'].long().squeeze()
7 | loss = (model_output['model_out']['output'][:, idx] - gt['func'][:, idx]) ** 2
8 | return {'func_loss': loss.mean()}
9 |
10 |
11 | def image_mse(model_output, gt):
12 | if 'complex' in model_output['model_out']:
13 | c = model_output['model_out']['complex']
14 | loss = (c.real - gt['img']) ** 2
15 | imag_loss = (c.imag) ** 2
16 | return {'func_loss': loss.mean(), 'imag_loss': imag_loss.mean()}
17 |
18 | else:
19 | loss = (model_output['model_out']['output'] - gt['img']) ** 2
20 | return {'func_loss': loss.mean()}
21 |
22 |
23 | def multiscale_image_mse(model_output, gt, use_resized=False):
24 | if use_resized:
25 | loss = [(out - gt_img)**2 for out, gt_img in zip(model_output['model_out']['output'], gt['img'])]
26 | else:
27 | loss = [(out - gt['img'])**2 for out in model_output['model_out']['output']]
28 |
29 | loss = torch.stack(loss).mean()
30 |
31 | return {'func_loss': loss}
32 |
33 |
34 | def multiscale_radiance_loss(model_outputs, gt, use_resized=False, weight=1.0,
35 | regularize_sigma=False, reg_lambda=1e-5, reg_c=0.5):
36 | tomo_loss = None
37 | sigma_reg = None
38 |
39 | pred_sigmas = [pred[..., -1:] for pred in model_outputs['combined']['model_out']['output']]
40 | pred_rgbs = [pred[..., :-1] for pred in model_outputs['combined']['model_out']['output']]
41 | if isinstance(model_outputs['combined']['model_in']['t_intervals'], list):
42 | t_intervals = [t_interval for t_interval in model_outputs['combined']['model_in']['t_intervals']]
43 | else:
44 | t_intervals = model_outputs['combined']['model_in']['t_intervals']
45 |
46 | for idx, (pred_sigma, pred_rgb) in enumerate(zip(pred_sigmas, pred_rgbs)):
47 |
48 | if isinstance(t_intervals, list):
49 | t_interval = t_intervals[idx]
50 | else:
51 | t_interval = t_intervals
52 |
53 | # Pass through the forward models
54 | pred_weights = forward_models.compute_transmittance_weights(pred_sigma, t_interval)
55 | pred_pixel_samples = forward_models.compute_tomo_radiance(pred_weights, pred_rgb)
56 |
57 | # Target Ground truth
58 | if use_resized:
59 | target_pixel_samples = gt['pixel_samples'][idx]
60 | else:
61 | target_pixel_samples = gt['pixel_samples']
62 |
63 | # Loss
64 | if tomo_loss is None:
65 | tomo_loss = (pred_pixel_samples - target_pixel_samples)**2
66 | else:
67 | tomo_loss += (pred_pixel_samples - target_pixel_samples)**2
68 |
69 | if regularize_sigma:
70 | tau = torch.nn.functional.softplus(pred_sigma - 1)
71 | if sigma_reg is None:
72 | sigma_reg = (torch.log(1 + tau**2 / reg_c))
73 | else:
74 | sigma_reg += (torch.log(1 + tau**2 / reg_c))
75 |
76 | loss = {'tomo_rad_loss': weight * tomo_loss.mean()}
77 |
78 | if regularize_sigma:
79 | loss['sigma_reg'] = reg_lambda * sigma_reg.mean()
80 |
81 | return loss
82 |
83 |
84 | def radiance_sigma_rgb_loss(model_outputs, gt, regularize_sigma=False,
85 | reg_lambda=1e-5, reg_c=0.5):
86 | pred_sigma = model_outputs['combined']['model_out']['output'][..., -1:]
87 | pred_rgb = model_outputs['combined']['model_out']['output'][..., :-1]
88 | t_intervals = model_outputs['combined']['model_in']['t_intervals']
89 |
90 | # Pass through the forward models
91 | pred_weights = forward_models.compute_transmittance_weights(pred_sigma, t_intervals)
92 | pred_pixel_samples = forward_models.compute_tomo_radiance(pred_weights, pred_rgb)
93 |
94 | # Target Ground truth
95 | target_pixel_samples = gt['pixel_samples'][..., :3] # rgba -> rgb
96 |
97 | # Loss
98 | tomo_loss = (pred_pixel_samples - target_pixel_samples)**2
99 |
100 | if regularize_sigma:
101 | tau = torch.nn.functional.softplus(pred_sigma - 1)
102 | sigma_reg = (torch.log(1 + tau**2 / reg_c))
103 |
104 | loss = {'tomo_rad_loss': tomo_loss.mean()}
105 |
106 | if regularize_sigma:
107 | loss['sigma_reg'] = reg_lambda * sigma_reg.mean()
108 |
109 | return loss
110 |
111 |
112 | def overfit_sdf(model_output, gt, coarse_loss_weight=1e-2):
113 | return overfit_sdf_loss_total(model_output, gt, is_multiscale=False,
114 | coarse_loss_weight=coarse_loss_weight)
115 |
116 |
117 | def multiscale_overfit_sdf(model_output, gt, coarse_loss_weight=1e-2):
118 | return overfit_sdf_loss_total(model_output, gt, is_multiscale=True,
119 | coarse_loss_weight=coarse_loss_weight)
120 |
121 |
122 | def overfit_sdf_loss_total(model_output, gt, is_multiscale, lambda_grad=1e-3,
123 | coarse_loss_weight=1e-2):
124 |     ''' fit SDF via MSE loss '''
125 |
126 | gt_sdf = gt['sdf']
127 | pred_sdf = model_output['model_out']
128 |
129 | pred_sdf_ = pred_sdf[0] if is_multiscale else pred_sdf
130 |
131 | mse_ = (gt_sdf - pred_sdf_)**2
132 |
133 | if is_multiscale:
134 | for pred_sdf_ in pred_sdf[1:]:
135 | mse_ += (gt_sdf - pred_sdf_)**2
136 |
137 | mse_ = (mse_ / len(pred_sdf))
138 | mse_[:, ::2] *= coarse_loss_weight
139 |
140 | mse_fine = mse_[:, 1::2].sum()
141 | mse_coarse = mse_[:, ::2].sum()
142 |
143 | return {'sdf_fine': mse_fine, 'sdf_coarse': mse_coarse}
144 |
--------------------------------------------------------------------------------
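As an illustrative sketch (not part of the repository), each loss above returns a dictionary of named scalar terms, which the training loop in `training.py` accumulates into a single objective; a hypothetical call to `image_mse` with random tensors looks like this:

```
import torch
import loss_functions

# Hypothetical stand-ins for a dataloader batch and a model forward pass.
gt = {'img': torch.rand(1, 256 * 256, 3)}
model_output = {'model_in': None,
                'model_out': {'output': torch.rand(1, 256 * 256, 3)}}

losses = loss_functions.image_mse(model_output, gt)   # {'func_loss': tensor(...)}
total_loss = sum(loss for loss in losses.values())    # named terms are summed during training
```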
/modules.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import torch
6 | import math
7 | import numpy as np
8 | from functools import partial
9 | from torch import nn
10 | import copy
11 |
12 |
13 | def init_weights_normal(m):
14 | if type(m) == nn.Linear:
15 | if hasattr(m, 'weight'):
16 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_out')
17 |
18 |
19 | def init_weights_selu(m):
20 | if type(m) == nn.Linear:
21 | if hasattr(m, 'weight'):
22 | num_input = m.weight.size(-1)
23 | nn.init.normal_(m.weight, std=1/math.sqrt(num_input))
24 |
25 |
26 | def init_weights_elu(m):
27 | if type(m) == nn.Linear:
28 | if hasattr(m, 'weight'):
29 | num_input = m.weight.size(-1)
30 | nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277)/math.sqrt(num_input))
31 |
32 |
33 | def init_weights_xavier(m):
34 | if type(m) == nn.Linear:
35 | if hasattr(m, 'weight'):
36 | nn.init.xavier_normal_(m.weight)
37 |
38 |
39 | def init_weights_uniform(m):
40 | if type(m) == nn.Linear:
41 | if hasattr(m, 'weight'):
42 | torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
43 |
44 |
45 | def sine_init(m, w0=30):
46 | with torch.no_grad():
47 | if hasattr(m, 'weight'):
48 | num_input = m.weight.size(-1)
49 | m.weight.uniform_(-np.sqrt(6/num_input)/w0, np.sqrt(6/num_input)/w0)
50 |
51 |
52 | def first_layer_sine_init(m):
53 | with torch.no_grad():
54 | if hasattr(m, 'weight'):
55 | num_input = m.weight.size(-1)
56 | m.weight.uniform_(-1/num_input, 1/num_input)
57 |
58 |
59 | class FirstSine(nn.Module):
60 | def __init__(self, w0=20):
61 | super().__init__()
62 | self.w0 = torch.tensor(w0)
63 |
64 | def forward(self, input):
65 | return torch.sin(self.w0*input)
66 |
67 |
68 | class Sine(nn.Module):
69 | def __init__(self, w0=20):
70 | super().__init__()
71 | self.w0 = torch.tensor(w0)
72 |
73 | def forward(self, input):
74 | return torch.sin(self.w0*input)
75 |
76 |
77 | class MSoftplus(nn.Module):
78 | def __init__(self):
79 | super().__init__()
80 | self.softplus = nn.Softplus()
81 | self.cst = torch.log(torch.tensor(2.))
82 |
83 | def forward(self, input):
84 | return self.softplus(input)-self.cst
85 |
86 |
87 | class Swish(nn.Module):
88 | def __init__(self):
89 | super().__init__()
90 |
91 | def forward(self, input):
92 | return input*torch.sigmoid(input)
93 |
94 |
95 | def mfn_weights_init(m):
96 | with torch.no_grad():
97 | if hasattr(m, 'weight'):
98 | num_input = m.weight.size(-1)
99 | m.weight.uniform_(-np.sqrt(6/num_input), np.sqrt(6/num_input))
100 |
101 |
102 | class MFNBase(nn.Module):
103 |
104 | def __init__(self, hidden_size, out_size, n_layers, weight_scale,
105 | bias=True, output_act=False):
106 | super().__init__()
107 |
108 | self.linear = nn.ModuleList(
109 | [nn.Linear(hidden_size, hidden_size, bias) for _ in range(n_layers)]
110 | )
111 |
112 | self.output_linear = nn.Linear(hidden_size, out_size)
113 |
114 | self.output_act = output_act
115 |
116 | self.linear.apply(mfn_weights_init)
117 | self.output_linear.apply(mfn_weights_init)
118 |
119 | def forward(self, model_input):
120 |
121 | input_dict = {key: input.clone().detach().requires_grad_(True)
122 | for key, input in model_input.items()}
123 | coords = input_dict['coords']
124 |
125 | out = self.filters[0](coords)
126 | for i in range(1, len(self.filters)):
127 | out = self.filters[i](coords) * self.linear[i - 1](out)
128 | out = self.output_linear(out)
129 |
130 | if self.output_act:
131 | out = torch.sin(out)
132 |
133 | return {'model_in': input_dict, 'model_out': {'output': out}}
134 |
135 |
136 | class FourierLayer(nn.Module):
137 |
138 | def __init__(self, in_features, out_features, weight_scale, quantization_interval=2*np.pi):
139 | super().__init__()
140 | self.linear = nn.Linear(in_features, out_features)
141 |
142 | r = 2*weight_scale[0] / quantization_interval
143 | assert math.isclose(r, round(r)), \
144 | 'weight_scale should be divisible by quantization interval'
145 |
146 | # sample discrete uniform distribution of frequencies
147 | for i in range(self.linear.weight.data.shape[1]):
148 | init = torch.randint_like(self.linear.weight.data[:, i],
149 | 0, int(2*weight_scale[i] / quantization_interval)+1)
150 | init = init * quantization_interval - weight_scale[i]
151 | self.linear.weight.data[:, i] = init
152 |
153 | self.linear.weight.requires_grad = False
154 | self.linear.bias.data.uniform_(-np.pi, np.pi)
155 | return
156 |
157 | def forward(self, x):
158 | return torch.sin(self.linear(x))
159 |
160 |
161 | class BACON(MFNBase):
162 | def __init__(self,
163 | in_size,
164 | hidden_size,
165 | out_size,
166 | hidden_layers=3,
167 | weight_scale=1.0,
168 | bias=True,
169 | output_act=False,
170 | frequency=(128, 128),
171 | quantization_interval=2*np.pi, # assumes data range [-.5,.5]
172 | centered=True,
173 | input_scales=None,
174 | output_layers=None,
175 | is_sdf=False,
176 | reuse_filters=False,
177 | **kwargs):
178 |
179 | super().__init__(hidden_size, out_size, hidden_layers,
180 | weight_scale, bias, output_act)
181 |
182 | self.quantization_interval = quantization_interval
183 | self.hidden_layers = hidden_layers
184 | self.hidden_size = hidden_size
185 | self.centered = centered
186 | self.frequency = frequency
187 | self.is_sdf = is_sdf
188 | self.reuse_filters = reuse_filters
189 | self.in_size = in_size
190 |
191 | # we need to multiply by this to be able to fit the signal
192 | input_scale = [round((np.pi * freq / (hidden_layers + 1))
193 | / quantization_interval) * quantization_interval for freq in frequency]
194 |
195 | self.filters = nn.ModuleList([
196 | FourierLayer(in_size, hidden_size, input_scale,
197 | quantization_interval=quantization_interval)
198 | for i in range(hidden_layers + 1)])
199 |
200 | print(self)
201 |
202 | def forward_mfn(self, input_dict):
203 | if 'coords' in input_dict:
204 | coords = input_dict['coords']
205 | elif 'ray_samples' in input_dict:
206 | if self.in_size > 3:
207 | coords = torch.cat((input_dict['ray_samples'], input_dict['ray_orientations']), dim=-1)
208 | else:
209 | coords = input_dict['ray_samples']
210 |
211 | if self.reuse_filters:
212 | filter_outputs = 3 * [self.filters[2](coords), ] + \
213 | 2 * [self.filters[4](coords), ] + \
214 | 2 * [self.filters[6](coords), ] + \
215 | 2 * [self.filters[8](coords), ]
216 |
217 | out = filter_outputs[0]
218 | for i in range(1, len(self.filters)):
219 | out = filter_outputs[i] * self.linear[i - 1](out)
220 | else:
221 | out = self.filters[0](coords)
222 | for i in range(1, len(self.filters)):
223 | out = self.filters[i](coords) * self.linear[i - 1](out)
224 |
225 | out = self.output_linear(out)
226 |
227 | if self.output_act:
228 | out = torch.sin(out)
229 |
230 | return out
231 |
232 | def forward(self, model_input, mode=None, integral_dim=None):
233 |
234 | out = {'output': self.forward_mfn(model_input)}
235 |
236 | if self.is_sdf:
237 | return {'model_in': model_input['coords'],
238 | 'model_out': out['output']}
239 |
240 | return {'model_in': model_input, 'model_out': out}
241 |
242 |
243 | class MultiscaleBACON(MFNBase):
244 | def __init__(self,
245 | in_size,
246 | hidden_size,
247 | out_size,
248 | hidden_layers=3,
249 | weight_scale=1.0,
250 | bias=True,
251 | output_act=False,
252 | frequency=(128, 128),
253 | quantization_interval=2*np.pi,
254 | centered=True,
255 | is_sdf=False,
256 | input_scales=None,
257 | output_layers=None,
258 | reuse_filters=False):
259 |
260 | super().__init__(hidden_size, out_size, hidden_layers,
261 | weight_scale, bias, output_act)
262 |
263 | self.quantization_interval = quantization_interval
264 | self.hidden_layers = hidden_layers
265 | self.centered = centered
266 | self.is_sdf = is_sdf
267 | self.frequency = frequency
268 | self.output_layers = output_layers
269 | self.reuse_filters = reuse_filters
270 | self.stop_after = None
271 |
272 | # we need to multiply by this to be able to fit the signal
273 | if input_scales is None:
274 | input_scale = [round((np.pi * freq / (hidden_layers + 1))
275 | / quantization_interval) * quantization_interval for freq in frequency]
276 |
277 | self.filters = nn.ModuleList([
278 | FourierLayer(in_size, hidden_size, input_scale,
279 | quantization_interval=quantization_interval)
280 | for i in range(hidden_layers + 1)])
281 | else:
282 | if len(input_scales) != hidden_layers+1:
283 | raise ValueError('require n+1 scales for n hidden_layers')
284 | input_scale = [[round((np.pi * freq * scale) / quantization_interval) * quantization_interval
285 | for freq in frequency] for scale in input_scales]
286 |
287 | self.filters = nn.ModuleList([
288 | FourierLayer(in_size, hidden_size, input_scale[i],
289 | quantization_interval=quantization_interval)
290 | for i in range(hidden_layers + 1)])
291 |
292 | # linear layers to extract intermediate outputs
293 | self.output_linear = nn.ModuleList([nn.Linear(hidden_size, out_size) for i in range(len(self.filters))])
294 | self.output_linear.apply(mfn_weights_init)
295 |
296 |         # if output_layers is None, output at every possible layer
297 | if self.output_layers is None:
298 | self.output_layers = np.arange(1, len(self.filters))
299 |
300 | print(self)
301 |
302 | def layer_forward(self, coords, filter_outputs, specified_layers,
303 | get_feature, continue_layer, continue_feature):
304 | """ for multiscale SDF extraction """
305 |
306 | # hardcode the 8 layer network that we use for all sdf experiments
307 | filter_ind_dict = [2, 2, 2, 4, 4, 6, 6, 8, 8]
308 | outputs = []
309 |
310 | if continue_feature is None:
311 | assert(continue_layer == 0)
312 | out = self.filters[filter_ind_dict[0]](coords)
313 | filter_output_dict = {filter_ind_dict[0]: out}
314 | else:
315 | out = continue_feature
316 | filter_output_dict = {}
317 |
318 | for i in range(continue_layer+1, len(self.filters)):
319 | if filter_ind_dict[i] not in filter_output_dict.keys():
320 | filter_output_dict[filter_ind_dict[i]] = self.filters[filter_ind_dict[i]](coords)
321 | out = filter_output_dict[filter_ind_dict[i]] * self.linear[i - 1](out)
322 |
323 | if i in self.output_layers and i == specified_layers:
324 | if get_feature:
325 | outputs.append([self.output_linear[i](out), out])
326 | else:
327 | outputs.append(self.output_linear[i](out))
328 | return outputs
329 |
330 | return outputs
331 |
332 | def forward(self, model_input, specified_layers=None, get_feature=False,
333 | continue_layer=0, continue_feature=None):
334 |
335 | if self.is_sdf:
336 | model_input = {key: input.clone().detach().requires_grad_(True)
337 | for key, input in model_input.items()}
338 |
339 | if 'coords' in model_input:
340 | coords = model_input['coords']
341 | elif 'ray_samples' in model_input:
342 | coords = model_input['ray_samples']
343 |
344 | outputs = []
345 | if self.reuse_filters:
346 |
347 | # which layers to reuse
348 | if len(self.filters) < 9:
349 | filter_outputs = 2 * [self.filters[0](coords), ] + \
350 | (len(self.filters)-2) * [self.filters[-1](coords), ]
351 | else:
352 | filter_outputs = 3 * [self.filters[2](coords), ] + \
353 | 2 * [self.filters[4](coords), ] + \
354 | 2 * [self.filters[6](coords), ] + \
355 | 2 * [self.filters[8](coords), ]
356 |
357 | # multiscale sdf extractions (evaluate only some layers)
358 | if specified_layers is not None:
359 | outputs = self.layer_forward(coords, filter_outputs, specified_layers,
360 | get_feature, continue_layer, continue_feature)
361 |
362 | # evaluate all layers
363 | else:
364 | out = filter_outputs[0]
365 | for i in range(1, len(self.filters)):
366 | out = filter_outputs[i] * self.linear[i - 1](out)
367 |
368 | if i in self.output_layers:
369 | outputs.append(self.output_linear[i](out))
370 | if self.stop_after is not None and len(outputs) > self.stop_after:
371 | break
372 |
373 | # no layer reuse
374 | else:
375 | out = self.filters[0](coords)
376 | for i in range(1, len(self.filters)):
377 | out = self.filters[i](coords) * self.linear[i - 1](out)
378 |
379 | if i in self.output_layers:
380 | outputs.append(self.output_linear[i](out))
381 | if self.stop_after is not None and len(outputs) > self.stop_after:
382 | break
383 |
384 | if self.is_sdf: # convert dtype
385 | return {'model_in': model_input['coords'],
386 | 'model_out': outputs} # outputs is a list of tensors
387 |
388 | return {'model_in': model_input, 'model_out': {'output': outputs}}
389 |
390 |
391 | class MultiscaleCoordinateNet(nn.Module):
392 |     '''A multiscale variant of the canonical coordinate network'''
393 | def __init__(self, out_features=1, nl='sine', in_features=1,
394 | hidden_features=256, num_hidden_layers=3,
395 | w0=30, pe_scale=6, use_sigmoid=True, no_pe=False,
396 | integrated_pe=False):
397 |
398 | super().__init__()
399 |
400 | self.nl = nl
401 | dims = in_features
402 | self.use_sigmoid = use_sigmoid
403 | self.no_pe = no_pe
404 | self.integrated_pe = integrated_pe
405 |
406 | if integrated_pe:
407 | self.pe = partial(IntegratedPositionalEncoding, L=pe_scale)
408 | in_features = int(2 * in_features * pe_scale)
409 |
410 | elif self.nl != 'sine' and not self.no_pe:
411 | in_features = in_features * hidden_features
412 | self.pe = FFPositionalEncoding(hidden_features, pe_scale, dims=dims)
413 |
414 | self.net = FCBlock(in_features=in_features,
415 | out_features=out_features,
416 | num_hidden_layers=num_hidden_layers,
417 | hidden_features=hidden_features,
418 | outermost_linear=True,
419 | nonlinearity=nl,
420 | w0=w0).net
421 |
422 | if not integrated_pe:
423 | self.output_linear = nn.ModuleList([nn.Linear(hidden_features, out_features)
424 | for i in range(num_hidden_layers)])
425 | self.net = self.net[:-1]
426 |
427 | print(self)
428 |
429 | def net_forward(self, coords):
430 | outputs = []
431 |
432 | if self.use_sigmoid and self.nl != 'sine':
433 | def out_nl(x): return torch.sigmoid(x)
434 | else:
435 | def out_nl(x): return x
436 |
437 | # mipnerf baseline
438 | if self.integrated_pe:
439 | for c in coords:
440 | outputs.append(out_nl(self.net(c)))
441 | else:
442 | out = self.net[0](coords)
443 | outputs.append(self.output_linear[0](out))
444 | for i, n in enumerate(self.net[1:]):
445 | # run main branch
446 | out = n(out)
447 |
448 | # extract intermediate output
449 | outputs.append(out_nl(self.output_linear[i](out)))
450 |
451 | return outputs
452 |
453 | def forward(self, model_input):
454 |
455 | coords = model_input['coords']
456 |
457 | if self.integrated_pe:
458 | coords = [self.pe(coords, r) for r in model_input['radii']]
459 | elif self.nl != 'sine' and not self.no_pe:
460 | coords = self.pe(coords)
461 |
462 | output = self.net_forward(coords)
463 |
464 | return {'model_in': model_input, 'model_out': {'output': output}}
465 |
466 |
467 | class CoordinateNet(nn.Module):
468 | '''A canonical coordinate network'''
469 | def __init__(self, out_features=1, nl='sine', in_features=1,
470 | hidden_features=256, num_hidden_layers=3,
471 | w0=30, pe_scale=5, use_sigmoid=True, no_pe=False,
472 | is_sdf=False, **kwargs):
473 |
474 | super().__init__()
475 |
476 | self.nl = nl
477 | dims = in_features
478 | self.use_sigmoid = use_sigmoid
479 | self.no_pe = no_pe
480 | self.is_sdf = is_sdf
481 |
482 | if self.nl != 'sine' and not self.no_pe:
483 | in_features = hidden_features # in_features * hidden_features
484 |
485 | self.pe = FFPositionalEncoding(hidden_features, pe_scale, dims=dims)
486 |
487 | self.net = FCBlock(in_features=in_features,
488 | out_features=out_features,
489 | num_hidden_layers=num_hidden_layers,
490 | hidden_features=hidden_features,
491 | outermost_linear=True,
492 | nonlinearity=nl,
493 | w0=w0)
494 | print(self)
495 |
496 | def forward(self, model_input):
497 |
498 | coords = model_input['coords']
499 |
500 | if self.nl != 'sine' and not self.no_pe:
501 | coords_pe = self.pe(coords)
502 | output = self.net(coords_pe)
503 | if self.use_sigmoid:
504 | output = torch.sigmoid(output)
505 | else:
506 | output = self.net(coords)
507 |
508 | if self.is_sdf:
509 | return {'model_in': model_input, 'model_out': output}
510 |
511 | else:
512 | return {'model_in': model_input, 'model_out': {'output': output}}
513 |
514 |
515 | def IntegratedPositionalEncoding(coords, radius, L=8):
516 |
517 | # adapted from mipnerf https://github.com/google/mipnerf
518 | def expected_sin(x, x_var):
519 | """Estimates mean and variance of sin(z), z ~ N(x, var)."""
520 |
521 | # When the variance is wide, shrink sin towards zero.
522 | y = torch.exp(-0.5 * x_var) * torch.sin(x)
523 | y_var = torch.clip(0.5 * (1 - torch.exp(-2 * x_var) * torch.cos(2 * x)) - y**2, 0)
524 | return y, y_var
525 |
526 | def integrated_pos_enc(x_coord, min_deg, max_deg):
527 | """Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1]."""
528 |
529 | x, x_cov_diag = x_coord
530 | scales = torch.tensor([2**i for i in range(int(min_deg), int(max_deg))], device=x.device)
531 | shape = list(x.shape[:-1]) + [-1]
532 |
533 | y = torch.reshape(x[..., None, :] * scales[:, None], shape)
534 | y_var = torch.reshape(x_cov_diag[..., None, :] * scales[:, None]**2, shape)
535 |
536 | return expected_sin(
537 | torch.cat([y, y + 0.5 * np.pi], dim=-1),
538 | torch.cat([y_var] * 2, dim=-1))[0]
539 |
540 | means = coords
541 | covs = (radius**2 / 4) * torch.ones((1, 2), device=coords.device).repeat(coords.shape[-2], 1)
542 | return integrated_pos_enc((means, covs), 0, L)
543 |
544 |
545 | class FFPositionalEncoding(nn.Module):
546 | def __init__(self, embedding_size, scale, dims=2, gaussian=True):
547 | super().__init__()
548 | self.embedding_size = embedding_size
549 | self.scale = scale
550 |
551 | if gaussian:
552 | bvals = torch.randn(embedding_size // 2, dims) * scale
553 | else:
554 | bvals = 2.**torch.linspace(0, scale, embedding_size//2) - 1
555 |
556 | if dims == 1:
557 | bvals = bvals[:, None]
558 |
559 | elif dims == 2:
560 | bvals = torch.stack([bvals, torch.zeros_like(bvals)], dim=-1)
561 | bvals = torch.cat([bvals, torch.roll(bvals, 1, -1)], dim=0)
562 |
563 | else:
564 | tmp = (dims-1)*(torch.zeros_like(bvals),)
565 | bvals = torch.stack([bvals, *tmp], dim=-1)
566 |
567 | tmp = [torch.roll(bvals, i, -1) for i in range(1, dims)]
568 | bvals = torch.cat([bvals, *tmp], dim=0)
569 |
570 | avals = torch.ones((bvals.shape[0]))
571 | self.avals = nn.Parameter(avals, requires_grad=False)
572 | self.bvals = nn.Parameter(bvals, requires_grad=False)
573 |
574 | def forward(self, tensor) -> torch.Tensor:
575 | """
576 | Apply positional encoding to the input.
577 | """
578 |
579 | return torch.cat([self.avals * torch.sin((2.*np.pi*tensor) @ self.bvals.T),
580 | self.avals * torch.cos((2.*np.pi*tensor) @ self.bvals.T)], dim=-1)
581 |
582 |
583 | class PositionalEncoding(nn.Module):
584 | def __init__(self, num_encoding_functions=6, include_input=True, log_sampling=True, normalize=False,
585 | input_dim=3, gaussian_pe=False, gaussian_variance=38):
586 | super().__init__()
587 | self.num_encoding_functions = num_encoding_functions
588 | self.include_input = include_input
589 | self.log_sampling = log_sampling
590 | self.normalize = normalize
591 | self.gaussian_pe = gaussian_pe
592 | self.normalization = None
593 |
594 | if self.gaussian_pe:
595 | # this needs to be registered as a parameter so that it is saved in the model state dict
596 | # and so that it is converted using .cuda(). Doesn't need to be trained though
597 | self.gaussian_weights = nn.Parameter(gaussian_variance * torch.randn(num_encoding_functions, input_dim),
598 | requires_grad=False)
599 |
600 | else:
601 | self.frequency_bands = None
602 | if self.log_sampling:
603 | self.frequency_bands = 2.0 ** torch.linspace(
604 | 0.0,
605 | self.num_encoding_functions - 1,
606 | self.num_encoding_functions)
607 | else:
608 | self.frequency_bands = torch.linspace(
609 | 2.0 ** 0.0,
610 | 2.0 ** (self.num_encoding_functions - 1),
611 | self.num_encoding_functions)
612 |
613 | if normalize:
614 | self.normalization = torch.tensor(1/self.frequency_bands)
615 |
616 | def forward(self, tensor) -> torch.Tensor:
617 | r"""Apply positional encoding to the input.
618 |
619 | Args:
620 | tensor (torch.Tensor): Input tensor to be positionally encoded.
621 | encoding_size (optional, int): Number of encoding functions used to compute
622 | a positional encoding (default: 6).
623 | include_input (optional, bool): Whether or not to include the input in the
624 | positional encoding (default: True).
625 |
626 | Returns:
627 | (torch.Tensor): Positional encoding of the input tensor.
628 | """
629 |
630 | encoding = [tensor] if self.include_input else []
631 | if self.gaussian_pe:
632 | for func in [torch.sin, torch.cos]:
633 | encoding.append(func(torch.matmul(tensor, self.gaussian_weights.T)))
634 | else:
635 | for idx, freq in enumerate(self.frequency_bands):
636 | for func in [torch.sin, torch.cos]:
637 | if self.normalization is not None:
638 | encoding.append(self.normalization[idx]*func(tensor * freq))
639 | else:
640 | encoding.append(func(tensor * freq))
641 |
642 | # Special case, for no positional encoding
643 | if len(encoding) == 1:
644 | return encoding[0]
645 | else:
646 | return torch.cat(encoding, dim=-1)
647 |
648 |
649 | def layer_factory(layer_type, w0=30):
650 | layer_dict = \
651 | {
652 | 'relu': (nn.ReLU(inplace=True), init_weights_uniform),
653 | 'sigmoid': (nn.Sigmoid(), None),
654 | 'fsine': (Sine(), first_layer_sine_init),
655 | 'sine': (Sine(w0=w0), partial(sine_init, w0=w0)),
656 | 'tanh': (nn.Tanh(), init_weights_xavier),
657 | 'selu': (nn.SELU(inplace=True), init_weights_selu),
658 | 'gelu': (nn.GELU(), init_weights_selu),
659 | 'swish': (Swish(), init_weights_selu),
660 | 'softplus': (nn.Softplus(), init_weights_normal),
661 | 'msoftplus': (MSoftplus(), init_weights_normal),
662 | 'elu': (nn.ELU(), init_weights_elu)
663 | }
664 | return layer_dict[layer_type]
665 |
666 |
667 | class FCBlock(nn.Module):
668 |     '''A fully connected neural network.
669 |     Supports configurable nonlinearities, weight initialization, and optional dropout.
670 | '''
671 | def __init__(self, in_features, out_features,
672 | num_hidden_layers, hidden_features,
673 | outermost_linear=False, nonlinearity='relu',
674 | weight_init=None, w0=30, set_bias=None,
675 | dropout=0.0):
676 | super().__init__()
677 |
678 | self.first_layer_init = None
679 | self.dropout = dropout
680 |
681 | # Create hidden features list
682 | if not isinstance(hidden_features, list):
683 | num_hidden_features = hidden_features
684 | hidden_features = []
685 | for i in range(num_hidden_layers+1):
686 | hidden_features.append(num_hidden_features)
687 | else:
688 | num_hidden_layers = len(hidden_features)-1
689 | print(f"net_size={hidden_features}")
690 |
691 | # Create the net
692 | print(f"num_layers={len(hidden_features)}")
693 | if isinstance(nonlinearity, list):
694 | print(f"num_non_lin={len(nonlinearity)}")
695 | assert len(hidden_features) == len(nonlinearity), "Num hidden layers needs to " \
696 | "match the length of the list of non-linearities"
697 |
698 | self.net = []
699 | self.net.append(nn.Sequential(
700 | nn.Linear(in_features, hidden_features[0]),
701 | layer_factory(nonlinearity[0])[0]
702 | ))
703 | for i in range(num_hidden_layers):
704 | self.net.append(nn.Sequential(
705 | nn.Linear(hidden_features[i], hidden_features[i+1]),
706 | layer_factory(nonlinearity[i+1])[0]
707 | ))
708 |
709 | if outermost_linear:
710 | self.net.append(nn.Sequential(
711 | nn.Linear(hidden_features[-1], out_features),
712 | ))
713 | else:
714 | self.net.append(nn.Sequential(
715 | nn.Linear(hidden_features[-1], out_features),
716 | layer_factory(nonlinearity[-1])[0]
717 | ))
718 | elif isinstance(nonlinearity, str):
719 | nl, weight_init = layer_factory(nonlinearity, w0=w0)
720 | if(nonlinearity == 'sine'):
721 | first_nl = FirstSine(w0=w0)
722 | self.first_layer_init = first_layer_sine_init
723 | else:
724 | first_nl = nl
725 |
726 | if weight_init is not None:
727 | self.weight_init = weight_init
728 |
729 | self.net = []
730 | self.net.append(nn.Sequential(
731 | nn.Linear(in_features, hidden_features[0]),
732 | first_nl
733 | ))
734 |
735 | for i in range(num_hidden_layers):
736 | if(self.dropout > 0):
737 | self.net.append(nn.Dropout(self.dropout))
738 | self.net.append(nn.Sequential(
739 | nn.Linear(hidden_features[i], hidden_features[i+1]),
740 | copy.deepcopy(nl)
741 | ))
742 |
743 | if (self.dropout > 0):
744 | self.net.append(nn.Dropout(self.dropout))
745 | if outermost_linear:
746 | self.net.append(nn.Sequential(
747 | nn.Linear(hidden_features[-1], out_features),
748 | ))
749 | else:
750 | self.net.append(nn.Sequential(
751 | nn.Linear(hidden_features[-1], out_features),
752 | copy.deepcopy(nl)
753 | ))
754 |
755 | self.net = nn.Sequential(*self.net)
756 |
757 | if isinstance(nonlinearity, list):
758 | for layer_num, layer_name in enumerate(nonlinearity):
759 | self.net[layer_num].apply(layer_factory(layer_name, w0=w0)[1])
760 | elif isinstance(nonlinearity, str):
761 | if self.weight_init is not None:
762 | self.net.apply(self.weight_init)
763 |
764 | if self.first_layer_init is not None:
765 | self.net[0].apply(self.first_layer_init)
766 |
767 | if set_bias is not None:
768 | self.net[-1][0].bias.data = set_bias * torch.ones_like(self.net[-1][0].bias.data)
769 |
770 | def forward(self, coords):
771 | output = self.net(coords)
772 | return output
773 |
774 |
775 | class RadianceNet(nn.Module):
776 | def __init__(self, in_features=2, out_features=1,
777 | hidden_features=256, num_hidden_layers=3, w0=30,
778 | input_pe_params=[('ray_samples', 3, 10), ('ray_orientations', 6, 4)],
779 | nl='relu'):
780 |
781 | super().__init__()
782 | self.input_pe_params = input_pe_params
783 | self.nl = nl
784 |
785 | self.positional_encoding_fn = {}
786 | if nl != 'sine':
787 | for input_to_encode, input_dim, num_pe_fns in self.input_pe_params:
788 | self.positional_encoding_fn[input_to_encode] = PositionalEncoding(num_encoding_functions=num_pe_fns,
789 | input_dim=input_dim)
790 |
791 | self.net = FCBlock(in_features=in_features, out_features=out_features, num_hidden_layers=num_hidden_layers,
792 | hidden_features=hidden_features, outermost_linear=True, nonlinearity=nl, w0=w0)
793 |
794 | print(self)
795 |
796 | def forward(self, model_input):
797 |
798 | input_dict = {key: input.clone().detach().requires_grad_(True)
799 | for key, input in model_input.items()}
800 |
801 | if self.nl != 'sine':
802 | for input_to_encode, input_dim, num_pe_fns in self.input_pe_params:
803 | encoded_input = self.positional_encoding_fn[input_to_encode](input_dict[input_to_encode])
804 | input_dict.update({input_to_encode: encoded_input})
805 |
806 | input_list = []
807 | for input_name, _, _ in self.input_pe_params:
808 | input_list.append(input_dict[input_name])
809 |
810 | coords = torch.cat(input_list, dim=-1)
811 |
812 | if coords.ndim == 2:
813 | coords = coords[None, :, :]
814 |
815 | output = self.net(coords)
816 |
817 | output_dict = {'output': output}
818 | return {'model_in': input_dict, 'model_out': output_dict}
819 |
--------------------------------------------------------------------------------
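As an illustrative sketch (not part of the repository), a small `MultiscaleBACON` can be instantiated and queried as follows; the layer sizes and coordinate range are assumptions rather than the configurations used in the paper's experiments.

```
import numpy as np
import torch
import modules

# Small 2D multiscale network with one output head per intermediate layer (sizes are illustrative).
model = modules.MultiscaleBACON(in_size=2, hidden_size=64, out_size=3,
                                hidden_layers=4, frequency=(128, 128),
                                quantization_interval=2*np.pi,
                                output_layers=[1, 2, 3, 4])

coords = torch.rand(1, 1024, 2) - 0.5   # coordinates assumed to lie in [-0.5, 0.5]
out = model({'coords': coords})
for o in out['model_out']['output']:
    print(o.shape)                      # one (1, 1024, 3) output per band-limited scale
```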
/spectrum_visualization/README.md:
--------------------------------------------------------------------------------
1 | A few folks have asked how the shape spectra are visualized, so here is some demo code that should allow you to reproduce those figures for the armadillo scene.
2 |
3 | Prerequisites:
4 | - create a conda environment from the environment.yml file
5 | - install ImageMagick and its `convert` utility (https://imagemagick.org/script/convert.php)
6 | - install ChimeraX from https://www.cgl.ucsf.edu/chimerax/ and make sure the executable is on the path as 'chimerax'
7 |
8 | Then, run `get_shape_spectra.py`, which evaluates the trained model to reconstruct the `armadillo` SDF on a grid and then renders a visualization of its Fourier transform. A minimal example run is sketched below.
9 |
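10 | For example, a minimal end-to-end run might look like the following (the working directory is an assumption):
11 | 
12 | ```
13 | cd spectrum_visualization
14 | conda env create -f environment.yml
15 | conda activate bacon
16 | # ImageMagick's convert and the chimerax executable must be on your PATH
17 | python get_shape_spectra.py
18 | ```
19 | 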
--------------------------------------------------------------------------------
/spectrum_visualization/environment.yml:
--------------------------------------------------------------------------------
1 | name: bacon
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 |   - numpy
7 | - python=3.9
8 | - pytorch
9 | - tensorboard
10 | - torchvision
11 | - pip
12 | - pip:
13 | - argparse==1.4.0
14 | - configargparse==1.5.3
15 | - gdown==4.2.0
16 | - matplotlib==3.4.3
17 | - pymcubes==0.1.2
18 | - scikit-image
19 | - scipy
20 | - tqdm
21 | - trimesh
22 | - mrcfile
23 | - pykdtree
24 |
--------------------------------------------------------------------------------
/spectrum_visualization/get_shape_spectra.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 |
6 | import modules
7 | import torch
8 | import numpy as np
9 | from tqdm import tqdm
10 | from pykdtree.kdtree import KDTree
11 | import mrcfile
12 |
13 | device = torch.device('mps')
14 |
15 |
16 | class HiddenPrints:
17 | def __enter__(self):
18 | self._original_stdout = sys.stdout
19 | sys.stdout = open(os.devnull, 'w')
20 |
21 | def __exit__(self, exc_type, exc_val, exc_tb):
22 | sys.stdout.close()
23 | sys.stdout = self._original_stdout
24 |
25 |
26 | def export_model(ckpt_path, model_name, N=512, model_type='bacon', hidden_layers=8,
27 | hidden_size=256, output_layers=[1, 2, 4, 8], w0=30, pe=8,
28 | filter_mesh=False, scaling=None, return_sdf=False):
29 |
30 | with HiddenPrints():
31 | # the network has 4 output levels of detail
32 | num_outputs = len(output_layers)
33 | max_frequency = 3*(32,)
34 |
35 | # load model
36 | if len(output_layers) > 1:
37 | model = modules.MultiscaleBACON(3, hidden_size, 1,
38 | hidden_layers=hidden_layers,
39 | bias=True,
40 | frequency=max_frequency,
41 | quantization_interval=np.pi,
42 | is_sdf=True,
43 | output_layers=output_layers,
44 | reuse_filters=True)
45 |
46 | print(model)
47 | ckpt = torch.load(ckpt_path, map_location=device)
48 | model.load_state_dict(ckpt)
49 | model = model.to(device)
50 |
51 | # write output
52 | x = torch.linspace(-0.5, 0.5, N)
53 | if return_sdf:
54 | x = torch.arange(-N//2, N//2) / N
55 | x = x.float()
56 | x, y, z = torch.meshgrid(x, x, x)
57 | render_coords = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=-1).to(device)
58 | sdf_values = [np.zeros((N**3, 1)) for i in range(num_outputs)]
59 |
60 | # render in a batched fashion to save memory
61 | bsize = int(128**2)
62 | for i in tqdm(range(int(N**3 / bsize))):
63 | coords = render_coords[i*bsize:(i+1)*bsize, :]
64 | out = model({'coords': coords})['model_out']
65 |
66 | if not isinstance(out, list):
67 | out = [out,]
68 |
69 | for idx, sdf in enumerate(out):
70 | sdf_values[idx][i*bsize:(i+1)*bsize] = sdf.detach().cpu().numpy()
71 |
72 | return [sdf.reshape(N, N, N) for sdf in sdf_values]
73 |
74 |
75 | def normalize(coords, scaling=0.9):
76 | coords = np.array(coords).copy()
77 | cmean = np.mean(coords, axis=0, keepdims=True)
78 | coords -= cmean
79 | coord_max = np.amax(coords)
80 | coord_min = np.amin(coords)
81 | coords = (coords - coord_min) / (coord_max - coord_min)
82 | coords -= 0.5
83 | coords *= scaling
84 |
85 | scale = scaling / (coord_max - coord_min)
86 | offset = -scaling * (cmean + coord_min) / (coord_max - coord_min) - 0.5*scaling
87 | return coords, scale, offset
88 |
89 |
90 | def get_ref_spectrum(xyz_file, N):
91 | pointcloud = np.genfromtxt(xyz_file)
92 | v = pointcloud[:, :3]
93 | n = pointcloud[:, 3:]
94 |
95 | n = n / (np.linalg.norm(n, axis=-1)[:, None])
96 | v, _, _ = normalize(v)
97 | print('loaded pc')
98 |
99 | # put pointcloud points into KDTree
100 | kd_tree = KDTree(v)
101 | print('made kd tree')
102 |
103 | # get sdf on grid and show
104 | x = (np.arange(-N//2, N//2) / N).astype(np.float32)
105 | coords = np.stack([arr.flatten() for arr in np.meshgrid(x, x, x)], axis=-1)
106 |
107 | sdf, idx = kd_tree.query(coords, k=3)
108 |
109 | # get average normal of hit point
110 | avg_normal = np.mean(n[idx], axis=1)
111 | sdf = np.sum((coords - v[idx][:, 0]) * avg_normal, axis=-1)
112 | sdf = sdf.reshape(N, N, N)
113 | return [sdf, ]
114 |
115 |
116 | def extract_spectrum():
117 |
118 | scenes = ['armadillo']
119 | Ns = [384, 384, 384, 512, 512]
120 | methods = ['bacon', 'ref']
121 |
122 | for method in methods:
123 | for scene, N in zip(scenes, Ns):
124 | if method == 'ref':
125 | sdfs = get_ref_spectrum(f'ref_{scene}.xyz', N)
126 | else:
127 |
128 | ckpt = 'model_final.pth'
129 |
130 | sdfs = export_model(ckpt, scene, model_type=method, output_layers=[2, 4, 6, 8],
131 | return_sdf=True, N=N, pe=8, w0=30)
132 |
133 | sdfs_ft = [np.abs(np.fft.fftshift(np.fft.fftn(sdf))) for sdf in sdfs]
134 | sdfs_ft = [sdf_ft / np.max(sdf_ft)*1000 for sdf_ft in sdfs_ft]
135 | sdfs_ft = [np.clip(sdf_ft, 0, 1)**(1/3) for sdf_ft in sdfs_ft]
136 |
137 | for idx, sdf_ft in enumerate(sdfs_ft):
138 | with mrcfile.new_mmap(f'/tmp/sdf_ft_{idx}.mrc', overwrite=True, shape=(N, N, N), mrc_mode=2) as mrc:
139 | mrc.data[:] = sdf_ft
140 |
141 | # render with chimera
142 | with open('/tmp/render.cxc', 'w') as f:
143 | for i in range(len(sdfs_ft)):
144 | f.write(f'open /tmp/sdf_ft_{i}.mrc\n')
145 | f.write('volume #1 style solid level 0,0 level 1,1\n')
146 | f.write('volume #1 maximumIntensityProjection true\n')
147 | f.write('volume #!1 showOutlineBox true\n')
148 | f.write('volume #1 outlineBoxRgb slate gray\n')
149 | f.write('volume #1 step 1\n')
150 | f.write('view matrix camera 0.91721,0.10246,-0.38499,-533.37,-0.010261,0.97212,0.23427,631.78,0.39826,-0.21092,0.89269,1870.6\n')
151 | f.write('view\n')
152 | f.write(f'save /tmp/{method}_{scene}_{i+1}.png width 512 height 512 transparentBackground false\n')
153 | f.write('close #1\n')
154 | f.write('exit\n')
155 | os.system('chimerax /tmp/render.cxc')
156 |
157 | for idx in range(len(sdfs_ft)):
158 | fname = f'{method}_{scene}_{idx+1}.png'
159 | os.system(f'./magicwand 1,1 -t 9 -r outside -m overlay -o 0 /tmp/{fname} {fname}')
160 | os.system(f'convert {fname} -trim +repage {fname}')
161 | os.remove(f'/tmp/{fname}')
162 |
163 | # clean up
164 | for idx in range(len(sdfs_ft)):
165 | os.remove(f'/tmp/sdf_ft_{idx}.mrc')
166 | os.remove('/tmp/render.cxc')
167 |
168 |
169 | if __name__ == '__main__':
170 | extract_spectrum()
171 |
--------------------------------------------------------------------------------
/spectrum_visualization/magicwand:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Developed by Fred Weinhaus 11/2/2007 .......... revised 4/30/2015
4 | #
5 | # ------------------------------------------------------------------------------
6 | #
7 | # Licensing:
8 | #
9 | # Copyright © Fred Weinhaus
10 | #
11 | # My scripts are available free of charge for non-commercial use, ONLY.
12 | #
13 | # For use of my scripts in commercial (for-profit) environments or
14 | # non-free applications, please contact me (Fred Weinhaus) for
15 | # licensing arrangements. My email address is fmw at alink dot net.
16 | #
17 | # If you: 1) redistribute, 2) incorporate any of these scripts into other
18 | # free applications or 3) reprogram them in another scripting language,
19 | # then you must contact me for permission, especially if the result might
20 | # be used in a commercial or for-profit environment.
21 | #
22 | # My scripts are also subject, in a subordinate manner, to the ImageMagick
23 | # license, which can be found at: http://www.imagemagick.org/script/license.php
24 | #
25 | # ------------------------------------------------------------------------------
26 | #
27 | ####
28 | #
29 | # USAGE: magicwand x,y [-t threshold] [-f format] [-r region] [-m mask] [-c color] [-o opacity] infile outfile
30 | # USAGE: magicwand [-h or -help]
31 | #
32 | # OPTIONS:
33 | #
34 | # x,y x,y location to get color and seed floodfill
35 | # -t threshold percent color disimilarity threshold (fuzz factor);
36 | # values from 0 to 100; A value of 0 is an exact
37 | # match and a value of 100 is any color; default=10
38 | # -f format output format; image or mask; default=image
39 | # -r region region to display; inside or outside; default=inside
40 | # -m mask mask type; transparent, binary, edge, overlay, layer;
41 | # default=transparent
42 | # -c color color for background, edge outline or translucent layer;
43 | # color="none" indicates use opacity;
44 | # color="trans" indicates make background transparent;
45 | # default="black"
46 | # -o opacity opacity for transparent, overlay or layer mask;
47 | # values from 0 to 100; default=0
48 | #
49 | ###
50 | #
51 | # NAME: MAGICWAND
52 | #
53 | # PURPOSE: To isolate a contiguous region of an image based upon a color determined
54 | # from a user specified image coordinate.
55 | #
56 | # DESCRIPTION: MAGICWAND determines a contiguous region of an image based
57 | # upon a color determined from a user specified image coordinate and a color
58 | # similarity threshold value (fuzz factor). The output can be either an
59 | # image or a mask. If the region is set to inside, then the image can be
60 | # made to mask out the background as transparent, mask out the background
61 | # using an opacity channel, fill the background with a color, display a
62 | # boundary (outline) edge for the region or apply a translucent color layer
63 | # over the background. If the region is set to outside, then the inside will
64 | # be masked and the outside area will show normally. Alternately, the output
65 | # can be a mask which is either binary (black and white), transparent
66 | # (transparent and white) or a boundary edge (white on black background).
67 | # The boundary edge can be made to match the interior or exterior depending
68 | # upon the region setting.
69 | #
70 | #
71 | # OPTIONS:
72 | #
73 | # x,y ... x,y are the coordinates in the image where the color is to be
74 | # extracted and the floodfill is to be seeded.
75 | #
76 | # -f format ... FORMAT specifies whether the output will be the modified
77 | # input image or a mask image. The choices are image or mask. The default
78 | # is image.
79 | #
80 | # -r region ... REGION specifies whether the inside or outside are of the
81 | # image will show and the other be masked. The choices are inside or
82 | # outside. The default is inside.
83 | #
84 | # -m mask ... MASK specifies the type of mask to use or create. The choices
85 | # are transparent, binary, edge, overlay or layer. Only transparent, binary or edge
86 | # masks will be allowed as output. With a transparent mask, the image will be
87 | # modified to mask out the complement of the regions specified according to the
88 | # color setting. Specify the color setting to 1) "trans" to make the background
89 | # transparent, 2) "none" to mask by multiplying by the opacity setting or
90 | # 3) a color value to use a fill color. With an overlay mask, which will only
91 | # be effective on PNG format output images, the masking is done via the opacity
92 | # channel using the opacity setting. With a layer mask, a translucent color will
93 | # be layered over the background according to the color and opacity values. The
94 | # larger the opacity, the lighter the color overlay and the more the image will
95 | # show.
96 | #
97 | # -c color ... COLOR is the color to be used for the background fill or
98 | # the boundary edge overlaid on the image. Any IM color specification is
99 | # valid or a value of none. Be sure to enclose them in double quotes.
100 | # The color value over-rides the opacity for mask=transparent, so use
101 | # color="none" to allow the opacity to work or use color="trans" to make
102 | # the background transparent. The default="black".
103 | #
104 | # -o opacity ... OPACITY controls the degree of transparency for the area
105 | # that is masked out using a mask setting of transparent, overlay or layer.
106 | # A value of zero is fully transparent and a value of 100 is fully opaque.
107 | #
108 | # CAVEAT: No guarantee that this script will work on all platforms,
109 | # nor that trapping of inconsistent parameters is complete and
110 | # foolproof. Use At Your Own Risk.
111 | #
112 | ######
113 | #
114 |
115 | # set default values
116 | threshold=10
117 | format="image" # image or mask
118 | region="inside" # inside or outside
119 | mask="trans" # trans or binary or edge or overlay (overlay only if png)
120 | bgcolor="black" # color or none when format=image
121 | opacity=0
122 |
123 | # set directory for temporary files
124 | dir="." # suggestions are dir="." or dir="/tmp"
125 |
126 |
127 | # set up functions to report Usage and Usage with Description
128 | PROGNAME=`type $0 | awk '{print $3}'` # search for executable on path
129 | PROGDIR=`dirname $PROGNAME` # extract directory of program
130 | PROGNAME=`basename $PROGNAME` # base name of program
131 | usage1()
132 | {
133 | echo >&2 ""
134 | echo >&2 "$PROGNAME:" "$@"
135 | sed >&2 -e '1,/^####/d; /^###/g; /^#/!q; s/^#//; s/^ //; 4,$p' "$PROGDIR/$PROGNAME"
136 | }
137 | usage2()
138 | {
139 | echo >&2 ""
140 | echo >&2 "$PROGNAME:" "$@"
141 | sed >&2 -e '1,/^####/d; /^######/g; /^#/!q; s/^#*//; s/^ //; 4,$p' "$PROGDIR/$PROGNAME"
142 | }
143 |
144 |
145 | # function to report error messages
146 | errMsg()
147 | {
148 | echo ""
149 | echo $1
150 | echo ""
151 | usage1
152 | exit 1
153 | }
154 |
155 |
156 | # function to test for minus at start of value of second part of option 1 or 2
157 | checkMinus()
158 | {
159 | test=`echo "$1" | grep -c '^-.*$'` # returns 1 if match; 0 otherwise
160 | [ $test -eq 1 ] && errMsg "$errorMsg"
161 | }
162 |
163 | # test for correct number of arguments and get values
164 | if [ $# -eq 0 ]
165 | then
166 | # help information
167 | echo ""
168 | usage2
169 | exit 0
170 | elif [ $# -gt 15 ]
171 | then
172 | errMsg "--- TOO MANY ARGUMENTS WERE PROVIDED ---"
173 | else
174 | while [ $# -gt 0 ]
175 | do
176 | # get parameter values
177 | case "$1" in
178 | -h|-help) # help information
179 | echo ""
180 | usage2
181 | exit 0
182 | ;;
183 | -f) # format
184 | shift # to get the next parameter - format
185 | # test if parameter starts with minus sign
186 | errorMsg="--- INVALID FORMAT SPECIFICATION ---"
187 | checkMinus "$1"
188 | # test region values
189 | format="$1"
190 | [ "$format" != "image" -a "$format" != "mask" ] && errMsg "--- FORMAT=$format IS NOT A VALID VALUE ---"
191 | ;;
192 | -r) # region
193 | shift # to get the next parameter - region
194 | # test if parameter starts with minus sign
195 | errorMsg="--- INVALID REGION SPECIFICATION ---"
196 | checkMinus "$1"
197 | # test region values
198 | region="$1"
199 | [ "$region" != "inside" -a "$region" != "outside" ] && errMsg "--- REGION=$region IS NOT A VALID VALUE ---"
200 | ;;
201 | -m) # mask
202 | shift # to get the next parameter - mask
203 | # test if parameter starts with minus sign
204 | errorMsg="--- INVALID MASK SPECIFICATION ---"
205 | checkMinus "$1"
206 | # test mask values
207 | mask="$1"
208 | [ "$mask" != "transparent" -a "$mask" != "binary" -a "$mask" != "edge" -a "$mask" != "overlay" -a "$mask" != "layer" ] && errMsg "--- MASK=$mask IS NOT A VALID VALUE ---"
209 | [ "$mask" = "transparent" ] && mask="trans"
210 | ;;
211 | -c) # get color
212 | shift # to get the next parameter - lineval
213 | # test if parameter starts with minus sign
214 | errorMsg="--- INVALID COLOR SPECIFICATION ---"
215 | checkMinus "$1"
216 | bgcolor="$1"
217 | ;;
218 | -t) # get threshold
219 | shift # to get the next parameter - threshold
220 | # test if parameter starts with minus sign
221 | errorMsg="--- INVALID THRESHOLD SPECIFICATION ---"
222 | checkMinus "$1"
223 | # test threshold values
224 | threshold=`expr "$1" : '\([.0-9]*\)'`
225 | [ "$threshold" = "" ] && errMsg "THRESHOLD=$threshold IS NOT A NON-NEGATIVE FLOATING POINT NUMBER"
226 | thresholdtestA=`echo "$threshold < 0" | bc`
227 | thresholdtestB=`echo "$threshold > 100" | bc`
228 | [ $thresholdtestA -eq 1 -o $thresholdtestB -eq 1 ] && errMsg "--- THRESHOLD=$threshold MUST BE GREATER THAN OR EQUAL 0 AND LESS THAN OR EQUAL 100 ---"
229 | ;;
230 | -o) # get opacity
231 | shift # to get the next parameter - opacity
232 | # test if parameter starts with minus sign
233 | errorMsg="--- INVALID OPACITY SPECIFICATION ---"
234 | checkMinus "$1"
235 | # test width values
236 | opacity=`expr "$1" : '\([.0-9]*\)'`
237 | [ "$opacity" = "" ] && errMsg "OPACITY=$opacity IS NOT A NON-NEGATIVE FLOATING POINT NUMBER"
238 | opacitytestA=`echo "$opacity < 0" | bc`
239 | opacitytestB=`echo "$opacity > 100" | bc`
240 | [ $opacitytestA -eq 1 -o $opacitytestB -eq 1 ] && errMsg "--- OPACITY=$opacity MUST BE GREATER THAN OR EQUAL 0 AND LESS THAN OR EQUAL 100 ---"
241 | ;;
242 | -) # STDIN, end of arguments
243 | break
244 | ;;
245 | -*) # any other - argument
246 | errMsg "--- UNKNOWN OPTION ---"
247 | ;;
248 | [0-9]*,[0-9]*) # Values supplied for coordinates
249 | coords="$1"
250 | ;;
251 | .*,.*) # Bogus Values supplied
252 | errMsg "--- COORDINATES ARE NOT VALID ---"
253 | ;;
254 | *) # end of arguments
255 | break
256 | ;;
257 | esac
258 | shift # next option
259 | done
260 | #
261 | # get infile and outfile
262 | infile="$1"
263 | outfile="$2"
264 | fi
265 |
266 |
267 | # test that infile provided
268 | [ "$infile" = "" ] && errMsg "NO INPUT FILE SPECIFIED"
269 |
270 | # test that outfile provided
271 | [ "$outfile" = "" ] && errMsg "NO OUTPUT FILE SPECIFIED"
272 |
273 | tmpA="$dir/magicwand_$$.mpc"
274 | tmpB="$dir/magicwand_$$.cache"
275 | tmp0="$dir/magicwand_0_$$.png"
276 | trap "rm -f $tmpA $tmpB $tmp0;" 0
277 | trap "rm -f $tmpA $tmpB $tmp0; exit 1" 1 2 3 15
278 | trap "rm -f $tmpA $tmpB $tmp0; exit 1" ERR
279 |
280 | if convert -quiet "$infile" +repage "$tmpA"
281 | then
282 | width=`identify -format %w $tmpA`
283 | height=`identify -format %h $tmpA`
284 | [ "$coords" = "" ] && errMsg "--- NO COORDINATES PROVIDED ---"
285 | else
286 | errMsg "--- FILE $infile DOES NOT EXIST OR IS NOT AN ORDINARY FILE, NOT READABLE OR HAS ZERO SIZE ---"
287 | fi
288 |
289 |
290 | # get im_version
291 | im_version=`convert -list configure | \
292 | sed '/^LIB_VERSION_NUMBER /!d; s//,/; s/,/,0/g; s/,0*\([0-9][0-9]\)/\1/g' | head -n 1`
293 |
294 | # set up floodfill
295 | if [ "$im_version" -ge "07000000" ]; then
296 | matte_alpha="alpha"
297 | else
298 | matte_alpha="alpha"
299 | fi
300 |
301 | # create transparent mask for region=inside
302 | # make interior transparent and outside black
303 | convert $tmpA -fuzz $threshold% -fill none -draw "$matte_alpha $coords floodfill" \
304 | -fill black +opaque none $tmp0
305 |
306 | # create negative mask - for region=outside
307 | if [ "$region" = "outside" ]
308 | then
309 | convert $tmp0 -channel rgba \
310 | -fill white -opaque none \
311 | -transparent black \
312 | -fill black -opaque white $tmp0
313 |
314 | fi
315 |
316 |
317 | # convert mask to binary if appropriate
318 | if [ "$mask" != "trans" -a "$mask" != "layer" ]
319 | then
320 | # make transparent go to white
321 | convert \( -size ${width}x${height} xc:white \) $tmp0 \
322 | -composite $tmp0
323 | fi
324 |
325 |
326 | #convert mask to edge if appropriate
327 | if [ "$mask" = "edge" ]
328 | then
329 | convert $tmp0 -convolve "-1,-1,-1,-1,8,-1,-1,-1,-1" -clamp $tmp0
330 | fi
331 |
332 |
333 | # process image if appropriate
334 | if [ "$format" = "image" -a "$mask" = "edge" ]
335 | then
336 | convert $tmpA \( $tmp0 -transparent black -fill $bgcolor -opaque white \) -composite $tmp0
337 |
338 | elif [ "$format" = "image" -a "$mask" = "trans" ]
339 | then
340 | # composite with input and set background
341 | if [ "$bgcolor" = "trans" ]
342 | then
343 | convert $tmpA $tmp0 -composite -transparent black $tmp0
344 | elif [ "$bgcolor" = "none" ]
345 | then
346 | convert $tmpA \( $tmp0 -fill "rgb($opacity%,$opacity%,$opacity%)" -opaque black \) -compose Multiply -composite $tmp0
347 | else
348 | convert $tmpA $tmp0 -composite -fill $bgcolor -opaque black $tmp0
349 | fi
350 |
351 | elif [ "$format" = "image" -a "$mask" = "overlay" ]
352 | then
353 | convert $tmpA \( $tmp0 -fill "rgb($opacity%,$opacity%,$opacity%)" -opaque black \) -compose Copy_Opacity -composite $tmp0
354 |
355 | elif [ "$format" = "image" -a "$mask" = "layer" ]
356 | then
357 | opacity=`expr 100 - $opacity`
358 | convert $tmp0 -fill $bgcolor -opaque black $tmp0
359 | if [ "$im_version" -lt "06050304" ]; then
360 | composite -dissolve $opacity% $tmp0 $tmpA $tmp0
361 | else
362 | convert $tmpA $tmp0 -define compose:args=$opacity% -compose dissolve -composite $tmp0
363 | fi
364 | fi
365 | convert $tmp0 "$outfile"
366 | exit 0
367 |
--------------------------------------------------------------------------------
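For reference (not part of the script itself), `get_shape_spectra.py` invokes this script on each rendered spectrum image roughly as follows, where the filename is just a representative example:

```
./magicwand 1,1 -t 9 -r outside -m overlay -o 0 /tmp/bacon_armadillo_1.png bacon_armadillo_1.png
```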
/spectrum_visualization/model_final.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/spectrum_visualization/model_final.pth
--------------------------------------------------------------------------------
/trained_models/armadillo.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/armadillo.pth
--------------------------------------------------------------------------------
/trained_models/dragon.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/dragon.pth
--------------------------------------------------------------------------------
/trained_models/lego.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/lego.pth
--------------------------------------------------------------------------------
/trained_models/lego_semisupervise.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/lego_semisupervise.pth
--------------------------------------------------------------------------------
/trained_models/lucy.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/lucy.pth
--------------------------------------------------------------------------------
/trained_models/thai.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/bacon/c691c69f1dd41a32329a03a32850bf5ced772343/trained_models/thai.pth
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | import torch
3 | import utils
4 | from tqdm.autonotebook import tqdm
5 | import time
6 | import numpy as np
7 | import os
8 | import forward_models
9 | from functools import partial
10 | import shutil
11 |
12 |
13 | def train(model, train_dataloader, steps, lr, steps_til_summary,
14 | steps_til_checkpoint, model_dir, loss_fn, summary_fn,
15 | prefix_model_dir='', val_dataloader=None, double_precision=False,
16 | clip_grad=False, use_lbfgs=False, loss_schedules=None, params=None,
17 | ckpt_step=0, use_lr_scheduler=False):
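   |     """Generic single-model training loop.
   | 
   |     Runs `steps` iterations of Adam (or LBFGS when `use_lbfgs` is set) on batches
   |     drawn from `train_dataloader`, logs each loss term and the total loss to
   |     TensorBoard, writes summaries every `steps_til_summary` steps, and saves
   |     checkpoints every `steps_til_checkpoint` steps under `model_dir`.
   |     """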
18 |
19 | if params is None:
20 | optim = torch.optim.Adam(lr=lr, params=model.parameters(), amsgrad=True)
21 | else:
22 | optim = torch.optim.Adam(lr=lr, params=params, amsgrad=True)
23 |
24 | if use_lbfgs:
25 | optim = torch.optim.LBFGS(lr=lr, params=model.parameters(),
26 | max_iter=50000, max_eval=50000,
27 | history_size=50, line_search_fn='strong_wolfe')
28 |
29 | scheduler = None
30 | if use_lr_scheduler:
31 | def sampling_scheduler(step, start=0, lr0=1e-4, lrn=1e-4):
32 |
33 | if step > start:
34 | fine_scale = lr_log_schedule(step-start, num_steps=steps-start, nw=1, lr0=lr0, lrn=lrn)
35 | train_dataloader.dataset.fine_scale = fine_scale
36 | else:
37 | train_dataloader.dataset.fine_scale = lr0
38 |
39 |         # LambdaLR multiplies the base lr by the lambda's output, so set the base lr to 1
40 | optim.param_groups[0]['lr'] = 1
41 | log_scheduler = partial(lr_log_schedule, num_steps=steps, nw=1, lr0=lr, lrn=1e-4)
42 | scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=log_scheduler)
43 |
44 |     # create the output directory if it does not already exist
45 |     if not os.path.exists(model_dir):
46 |         os.makedirs(model_dir)
47 |
48 |
49 | model_dir_postfixed = os.path.join(model_dir, prefix_model_dir)
50 |
51 | summaries_dir = os.path.join(model_dir_postfixed, 'summaries')
52 | utils.cond_mkdir(summaries_dir)
53 |
54 | checkpoints_dir = os.path.join(model_dir_postfixed, 'checkpoints')
55 | utils.cond_mkdir(checkpoints_dir)
56 |
57 | writer = SummaryWriter(summaries_dir)
58 |
59 |     # cycle through the dataloader; the iterator is re-created below whenever it raises StopIteration
60 | train_generator = iter(train_dataloader)
61 |
62 | with tqdm(total=steps) as pbar:
63 | train_losses = []
64 | for step in range(steps):
65 |
66 | if not step % steps_til_checkpoint and step:
67 | torch.save(model.state_dict(),
68 | os.path.join(checkpoints_dir,
69 | 'model_step_%04d.pth' % (step + ckpt_step)))
70 | np.savetxt(os.path.join(checkpoints_dir,
71 | 'train_losses_step_%04d.txt' % (step + ckpt_step)),
72 | np.array(train_losses))
73 |
74 | try:
75 | # sampling_scheduler(step)
76 | model_input, gt = next(train_generator)
77 | except StopIteration:
78 | train_generator = iter(train_dataloader)
79 | model_input, gt = next(train_generator)
80 |
81 | start_time = time.time()
82 |
83 | model_input = dict2cuda(model_input)
84 | gt = dict2cuda(gt)
85 |
86 | if double_precision:
87 | model_input = {key: value.double()
88 | for key, value in model_input.items()}
89 | gt = {key: value.double() for key, value in gt.items()}
90 |
91 | if use_lbfgs:
92 | def closure():
93 | optim.zero_grad(set_to_none=True)
94 | model_output = model(model_input)
95 | losses = loss_fn(model_output, gt)
96 | train_loss = 0.
97 | for loss_name, loss in losses.items():
98 | train_loss += loss.mean()
99 | train_loss.backward()
100 | return train_loss
101 | optim.step(closure)
102 |
103 | model_output = model(model_input)
104 | losses = loss_fn(model_output, gt)
105 |
106 | train_loss = 0.
107 | for loss_name, loss in losses.items():
108 | single_loss = loss.mean()
109 |
110 | if loss_schedules is not None and \
111 | loss_name in loss_schedules:
112 | writer.add_scalar(loss_name + "_weight",
113 | loss_schedules[loss_name](step), step)
114 | single_loss *= loss_schedules[loss_name](step)
115 |
116 | writer.add_scalar(loss_name, single_loss, step)
117 | train_loss += single_loss
118 |
119 | train_losses.append(train_loss.item())
120 | writer.add_scalar("total_train_loss", train_loss, step)
121 | writer.add_scalar("lr", optim.param_groups[0]['lr'], step)
122 |
123 | if not step % steps_til_summary:
124 | torch.save(model.state_dict(),
125 | os.path.join(checkpoints_dir,
126 | 'model_current.pth'))
127 | summary_fn(model, model_input, gt, model_output, writer, step)
128 |
129 | if not use_lbfgs:
130 | optim.zero_grad(set_to_none=True)
131 | train_loss.backward()
132 |
133 | if clip_grad:
134 | if isinstance(clip_grad, bool):
135 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
136 | else:
137 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
138 |
139 | optim.step()
140 |
141 | if scheduler is not None:
142 | scheduler.step()
143 |
144 | pbar.update(1)
145 |
146 | if not step % steps_til_summary:
147 | tqdm.write("Step %d, Total loss %0.6f, iteration time %0.6f" % (step, train_loss, time.time() - start_time))
148 |
149 |                 if val_dataloader is not None:
150 |                     print("Running validation set...")
151 |                     model.eval()
152 |                     with torch.no_grad():
153 |                         val_losses = []
154 |                         for (model_input, gt) in val_dataloader:
155 |                             model_output = model(dict2cuda(model_input))
156 |                             losses = loss_fn(model_output, dict2cuda(gt))
157 |                             val_losses.append(sum(loss.mean().item() for loss in losses.values()))
158 |                         # log the mean scalar validation loss over the whole validation set
159 |                         writer.add_scalar("val_loss", np.mean(val_losses), step)
160 |                     model.train()
161 |
162 | torch.save(model.state_dict(),
163 | os.path.join(checkpoints_dir, 'model_final.pth'))
164 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final.txt'),
165 | np.array(train_losses))
166 |
167 |
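   | # recursively move all tensors in a (possibly nested) dict onto the GPU;
   | # sub-dicts are recursed into and lists/tuples of tensors are moved element-wise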
168 | def dict2cuda(a_dict):
169 | tmp = {}
170 | for key, value in a_dict.items():
171 | if isinstance(value, torch.Tensor):
172 | tmp.update({key: value.cuda()})
173 | elif isinstance(value, dict):
174 | tmp.update({key: dict2cuda(value)})
175 | elif isinstance(value, list) or isinstance(value, tuple):
176 | if isinstance(value[0], torch.Tensor):
177 | tmp.update({key: [v.cuda() for v in value]})
178 | else:
179 | tmp.update({key: value})
180 | return tmp
181 |
182 |
183 | def dict2cpu(a_dict):
184 | tmp = {}
185 | for key, value in a_dict.items():
186 | if isinstance(value, torch.Tensor):
187 | tmp.update({key: value.cpu()})
188 | elif isinstance(value, dict):
189 | tmp.update({key: dict2cpu(value)})
190 | elif isinstance(value, list):
191 | if isinstance(value[0], torch.Tensor):
192 | tmp.update({key: [v.cpu() for v in value]})
193 | else:
194 | tmp.update({key: value})
195 | return tmp
196 |
197 |
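   | # log-linear decay from lr0 to lrn as `it` goes from 0 to num_steps; used for the
   | # sigma regularization weight (`reg_lambda`) passed to the loss in train_wchunks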
198 | def reg_schedule(it, num_steps=1e6, lr0=1e-3, lrn=1e-4):
199 | return np.exp((1 - it/num_steps) * np.log(lr0) + (it/num_steps) * np.log(lrn))
200 |
201 |
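   | # learning-rate schedule: a sine warm-up over the first `nw` steps (starting from a
   | # floor of `lambdaw`) multiplied by log-linear decay from `lr0` to `lrn`. It returns
   | # the absolute learning rate, which is why optimizers driven by it through LambdaLR
   | # are constructed with a base lr of 1.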
202 | def lr_log_schedule(it, num_steps=1e6, nw=2500, lr0=1e-3, lrn=5e-6, lambdaw=0.01):
203 | return (lambdaw + (1 - lambdaw) * np.sin(np.pi/2 * np.clip(it/nw, 0, 1))) \
204 | * np.exp((1 - it/num_steps) * np.log(lr0) + (it/num_steps) * np.log(lrn))
205 |
206 |
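   | # NeRF-style training loop that optimizes one or more models (e.g. a 'combined'
   | # network or separate 'sigma'/color networks) on ray batches, optionally split into
   | # chunks of at most `max_chunk_size` samples to bound memory. Each model gets its own
   | # Adam optimizer (base lr 1) scheduled by lr_log_schedule via LambdaLR, and setting
   | # `hierarchical_sampling` adds a second, importance-sampled pass through sample_pdf.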
207 | def train_wchunks(models, train_dataloader, num_steps, lr, steps_til_summary, steps_til_checkpoint, model_dir,
208 | loss_fn, summary_fn, chunk_lists_from_batch_fn,
209 | val_dataloader=None, double_precision=False, clip_grad=False, loss_schedules=None,
210 | num_cuts=128,
211 | max_chunk_size=4096,
212 | resume_checkpoint={},
213 | chunked=True,
214 | hierarchical_sampling=False,
215 | coarse_loss_weight=0.1,
216 | stop_after=None):
217 |
218 | optims = {key: torch.optim.Adam(lr=1, params=model.parameters())
219 | for key, model in models.items()}
220 | schedulers = {key: torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_log_schedule)
221 | for key, optim in optims.items()}
222 |
223 | # load optimizer if supplied
224 | for key in models.keys():
225 | if key in resume_checkpoint:
226 | optims[key].load_state_dict(resume_checkpoint[key]['optim'])
227 | schedulers[key].load_state_dict(resume_checkpoint[key]['scheduler'])
228 |
229 | if os.path.exists(os.path.join(model_dir, 'summaries')):
230 | val = input("The model directory %s exists. Overwrite? (y/n)" % model_dir)
231 | if val == 'y':
232 | if os.path.exists(os.path.join(model_dir, 'summaries')):
233 | shutil.rmtree(os.path.join(model_dir, 'summaries'))
234 | if os.path.exists(os.path.join(model_dir, 'checkpoints')):
235 | shutil.rmtree(os.path.join(model_dir, 'checkpoints'))
236 |
237 | os.makedirs(model_dir, exist_ok=True)
238 |
239 | summaries_dir = os.path.join(model_dir, 'summaries')
240 | utils.cond_mkdir(summaries_dir)
241 |
242 | checkpoints_dir = os.path.join(model_dir, 'checkpoints')
243 | utils.cond_mkdir(checkpoints_dir)
244 |
245 | writer = SummaryWriter(summaries_dir)
246 |
247 | start_step = 0
248 | if 'step' in resume_checkpoint:
249 | start_step = resume_checkpoint['step']
250 |
251 | train_generator = iter(train_dataloader)
252 |
253 | with tqdm(total=num_steps) as pbar:
254 | pbar.update(start_step)
255 | train_losses = []
256 | for step in range(start_step, num_steps):
257 | if not step % steps_til_checkpoint and step:
258 | for key, model in models.items():
259 | torch.save(model.state_dict(),
260 | os.path.join(checkpoints_dir, 'model_'+key+'_step_%04d.pth' % step))
261 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_step_%04d.txt' % step),
262 | np.array(train_losses))
263 | for key, optim in optims.items():
264 | torch.save({'step': step,
265 | 'optimizer_state_dict': optim.state_dict(),
266 | 'scheduler_state_dict': schedulers[key].state_dict()},
267 | os.path.join(checkpoints_dir, 'optim_'+key+'_step_%04d.pth' % step))
268 |
269 | try:
270 | model_input, meta, gt = next(train_generator)
271 | except StopIteration:
272 | train_dataloader.dataset.shuffle_rays()
273 | train_generator = iter(train_dataloader)
274 | model_input, meta, gt = next(train_generator)
275 |
276 | start_time = time.time()
277 |
278 | for optim in optims.values():
279 | optim.zero_grad(set_to_none=True)
280 |
281 | batch_avged_losses = {}
282 | if chunked:
283 | list_chunked_model_input, list_chunked_meta, list_chunked_gt = \
284 | chunk_lists_from_batch_fn(model_input, meta, gt, max_chunk_size)
285 |
286 | num_chunks = len(list_chunked_gt)
287 | batch_avged_tot_loss = 0.
288 | for chunk_idx, (chunked_model_input, chunked_meta, chunked_gt) \
289 | in enumerate(zip(list_chunked_model_input, list_chunked_meta, list_chunked_gt)):
290 | chunked_model_input = dict2cuda(chunked_model_input)
291 | chunked_meta = dict2cuda(chunked_meta)
292 | chunked_gt = dict2cuda(chunked_gt)
293 |
294 | # forward pass through model
295 | for k in models.keys():
296 | models[k].stop_after = stop_after
297 | chunk_model_outputs = {key: model(chunked_model_input) for key, model in models.items()}
298 | for k in models.keys():
299 | models[k].stop_after = None
300 |
301 | losses = {}
302 |
303 | if hierarchical_sampling:
304 | # set idx to use sigma from coarse level (idx=0)
305 | # for hierarchical sampling
306 | chunked_model_input_fine = sample_pdf(chunked_model_input,
307 | chunk_model_outputs, idx=0)
308 |
309 | chunk_model_importance_outputs = {key: model(chunked_model_input_fine)
310 | for key, model in models.items()}
311 |
312 | reg_lambda = reg_schedule(step)
313 | losses_importance = loss_fn(chunk_model_importance_outputs, chunked_gt,
314 | regularize_sigma=True, reg_lambda=reg_lambda)
315 |
316 | # loss from forward pass
317 | train_loss = 0.
318 | for loss_name, loss in losses.items():
319 |
320 | single_loss = loss.mean()
321 | train_loss += single_loss / num_chunks
322 |
323 | batch_avged_tot_loss += float(single_loss / num_chunks)
324 | if loss_name in batch_avged_losses:
325 | batch_avged_losses[loss_name] += single_loss / num_chunks
326 | else:
327 | batch_avged_losses.update({loss_name: single_loss/num_chunks})
328 |
329 | # Loss from eventual second pass
330 | if hierarchical_sampling:
331 | for loss_name, loss in losses_importance.items():
332 | single_loss = loss.mean()
333 | train_loss += single_loss / num_chunks
334 |
335 |                             batch_avged_tot_loss += float(single_loss / num_chunks)
336 | if loss_name + '_importance' in batch_avged_losses:
337 | batch_avged_losses[loss_name+'_importance'] += single_loss / num_chunks
338 | else:
339 | batch_avged_losses.update({loss_name + '_importance': single_loss / num_chunks})
340 |
341 | train_loss.backward()
342 | else:
343 | model_input = dict2cuda(model_input)
344 | meta = dict2cuda(meta)
345 | gt = dict2cuda(gt)
346 |
347 | model_outputs = {key: model(model_input) for key, model in models.items()}
348 | losses = loss_fn(model_outputs, gt)
349 |
350 |                 # loss from forward pass
351 |                 train_loss = 0.
352 |                 batch_avged_tot_loss = 0.
353 |                 for loss_name, loss in losses.items():
354 |                     single_loss = loss.mean()
355 |                     train_loss += single_loss
356 |
357 |                     batch_avged_tot_loss += float(single_loss)
358 | if loss_name in batch_avged_losses:
359 | batch_avged_losses[loss_name] += single_loss
360 | else:
361 | batch_avged_losses.update({loss_name: single_loss})
362 |
363 | train_loss.backward()
364 |
365 | for loss_name, loss in batch_avged_losses.items():
366 | writer.add_scalar(loss_name, loss, step)
367 | train_losses.append(batch_avged_tot_loss)
368 | writer.add_scalar("total_train_loss", batch_avged_tot_loss, step)
369 |             if chunked and hierarchical_sampling:  # reg_lambda only exists on this path
370 |                 writer.add_scalar("reg_lambda", reg_lambda, step)
371 | for k in optims.keys():
372 | writer.add_scalar(f"{k}_lr", optims[k].param_groups[0]['lr'], step)
373 |
374 | if clip_grad:
375 | for model in models.values():
376 | if isinstance(clip_grad, bool):
377 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
378 | else:
379 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
380 |
381 | for optim in optims.values():
382 | optim.step()
383 |
384 | if not step % steps_til_summary:
385 | tqdm.write("Step %d, Total loss %0.6f, iteration time %0.6f" % (step, train_loss, time.time() - start_time))
386 | for key, model in models.items():
387 | torch.save(model.state_dict(),
388 | os.path.join(checkpoints_dir, 'model_'+key+'_current.pth'))
389 | for key, optim in optims.items():
390 | torch.save({'step': step,
391 | 'total_steps': step,
392 | 'optimizer_state_dict': optim.state_dict(),
393 | 'scheduler_state_dict': schedulers[key].state_dict()},
394 | os.path.join(checkpoints_dir, 'optim_'+key+'_current.pth'))
395 | summary_fn(models, train_dataloader, val_dataloader, loss_fn, optims, meta, gt,
396 | writer, step)
397 |
398 | pbar.update(1)
399 |
400 | for k in schedulers.keys():
401 | schedulers[k].step()
402 |
403 | for key, model in models.items():
404 | torch.save(model.state_dict(),
405 | os.path.join(checkpoints_dir, 'model_' + key + '_final.pth'))
406 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final.txt'),
407 | np.array(train_losses))
408 |
409 |
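   | # Inverse-transform sampling along each ray: converts the predicted densities from the
   | # coarse pass into transmittance weights, applies a small blur and offset, normalizes
   | # them into a PDF over the depth bins, and draws new sample depths by inverting the
   | # CDF with torch.searchsorted. The returned model_inputs carry the resampled depths
   | # 't', the corresponding 3D 'ray_samples', and updated 't_intervals'.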
410 | def sample_pdf(model_inputs, model_outputs, offset=5e-3,
411 | idx=-1):
412 | ''' hierarchical sampling code for neural radiance fields '''
413 |
414 | z_vals = model_inputs['t']
415 | bins = .5*(z_vals[..., 1:, :] + z_vals[..., :-1, :]).squeeze()
416 | bins = bins.clone().detach().requires_grad_(True)
417 |
418 | if 'combined' in model_outputs:
419 | if isinstance(model_outputs['combined']['model_out']['output'], list):
420 | pred_sigma = model_outputs['combined']['model_out']['output'][idx][..., -1:]
421 | t_intervals = model_outputs['combined']['model_in']['t_intervals']
422 | else:
423 | pred_sigma = model_outputs['combined']['model_out']['output'][..., -1:]
424 | t_intervals = model_outputs['combined']['model_in']['t_intervals']
425 | else:
426 | pred_sigma = model_outputs['sigma']['model_out']['output']
427 | t_intervals = model_outputs['sigma']['model_in']['t_intervals']
428 |
429 | if isinstance(pred_sigma, list):
430 | pred_sigma = pred_sigma[idx]
431 |
432 | pred_weights = forward_models.compute_transmittance_weights(pred_sigma, t_intervals)[..., :-1, 0]
433 |
434 | # blur weights
435 | pred_weights = torch.cat((pred_weights, pred_weights[..., -1:]), dim=-1)
436 | weights_max = torch.maximum(pred_weights[..., :-1], pred_weights[..., 1:])
437 | weights_blur = 0.5 * (weights_max[..., :-1] + weights_max[..., 1:])
438 | pred_weights = weights_blur + offset
439 |
440 | pdf = pred_weights / torch.sum(pred_weights, dim=-1, keepdim=True)
441 |
442 | cdf = torch.cumsum(pdf, dim=-1)
443 |     cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1).squeeze()  # (batch_pixels, num_bins = samples_per_ray - 1)
444 | cdf = cdf.detach()
445 | num_samples = pred_sigma.shape[-2]
446 | u = torch.rand(list(cdf.shape[:-1])+[num_samples], device=pred_weights.device)
447 |
448 | inds = torch.searchsorted(cdf, u, right=True)
449 | below = torch.max(torch.zeros_like(inds), inds-1)
450 | above = torch.min((cdf.shape[-1]-1)*torch.ones_like(inds), inds)
451 | inds_g = torch.stack((below, above), -1)
452 |
453 | matched_shape = (inds_g.shape[0], inds_g.shape[1], cdf.shape[-1])
454 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
455 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
456 |
457 | denom = (cdf_g[..., 1]-cdf_g[..., 0])
458 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
459 | t = (u - cdf_g[..., 0])/denom
460 | t_vals = (bins_g[..., 0] + t*(bins_g[..., 1]-bins_g[..., 0])).unsqueeze(-1)
461 | t_vals, _ = torch.sort(t_vals, dim=-2)
462 |
463 | ray_dirs = model_inputs['ray_directions']
464 | ray_orgs = model_inputs['ray_origins']
465 |
466 | t_vals = t_vals[..., 0]
467 | t_intervals = t_vals[..., 1:] - t_vals[..., :-1]
468 | t_intervals = torch.cat((t_intervals, 1e10*torch.ones_like(t_intervals[:, 0:1])), dim=-1)
469 | t_intervals = (t_intervals * ray_dirs.norm(p=2, dim=-1))[..., None]
470 | t_vals = t_vals[..., None]
471 |
472 | if ray_dirs.ndim == 4:
473 | t_vals = t_vals[None, ...]
474 |
475 | model_inputs.update({'t': t_vals})
476 | model_inputs.update({'ray_samples': ray_orgs + ray_dirs * t_vals})
477 | model_inputs.update({'t_intervals': t_intervals})
478 |
479 | return model_inputs
480 |
--------------------------------------------------------------------------------