├── .gitignore ├── LICENSE ├── README.md ├── data └── pluto.jpg ├── data_structs.py ├── dataio.py ├── environment.yml ├── experiment_scripts ├── config_img │ ├── config_mars.ini │ ├── config_pluto_acorn_1k.ini │ ├── config_pluto_acorn_4k.ini │ ├── config_pluto_acorn_8k.ini │ ├── config_pluto_pe_8k.ini │ ├── config_pluto_siren_8k.ini │ └── config_tokyo.ini ├── config_occupancy │ ├── config_dragon_acorn.ini │ ├── config_engine_acorn.ini │ ├── config_lucy_acorn.ini │ ├── config_lucy_small_acorn.ini │ └── config_thai_acorn.ini ├── train_img.py └── train_occupancy.py ├── img └── teaser.png ├── inside_mesh ├── inside_mesh.py ├── setup.py └── triangle_hash.pyx ├── loss_functions.py ├── metrics.py ├── modules.py ├── pruning_functions.py ├── training.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | inside_mesh/ 3 | logs/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Stanford Computational Imaging Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ACORN: Adaptive Coordinate Networks for Neural Scene Representation
SIGGRAPH 2021 2 | ### [Project Page](http://www.computationalimaging.org/publications/acorn/) | [Video](https://www.youtube.com/watch?v=P192X3J6cg4) | [Paper](https://arxiv.org/abs/2105.02788) 3 | PyTorch implementation of ACORN.
4 | [ACORN: Adaptive Coordinate Networks for Neural Scene Representation](http://www.computationalimaging.org/publications/acorn/)
5 | [Julien N. P. Martel](http://web.stanford.edu/~jnmartel/)\*, 6 | [David B. Lindell](https://davidlindell.com)\*, 7 | [Connor Z. Lin](https://connorzlin.com/), 8 | [Eric R. Chan](https://ericryanchan.github.io/about.html), 9 | [Marco Monteiro](https://twitter.com/monteiroamarco), 10 | [Gordon Wetzstein](https://computationalimaging.org)
11 | Stanford University
12 | \*denotes equal contribution 13 | in SIGGRAPH 2021 14 | 15 | 16 | 17 | ## Quickstart 18 | 19 | To set up a conda environment, download example training data, begin the training process, and launch Tensorboard, run the commands below. As part of this you will also need to [register for and install an academic license](https://www.gurobi.com/downloads/free-academic-license/) for the Gurobi optimizer (this is free for academic use). 20 | ``` 21 | conda env create -f environment.yml 22 | # before proceeding, install Gurobi optimizer license (see above web link) 23 | conda activate acorn 24 | cd inside_mesh 25 | python setup.py build_ext --inplace 26 | cd ../experiment_scripts 27 | python train_img.py --config ./config_img/config_pluto_acorn_1k.ini 28 | tensorboard --logdir=../logs --port=6006 29 | ``` 30 | 31 | This example will fit a 1 MP image of Pluto. You can monitor the training in your browser at `localhost:6006`. 32 | 33 | ### Adaptive Coordinate Networks 34 | 35 | An adaptive coordinate network learns an adaptive decomposition of the signal domain, allowing the network to fit signals faster and more accurately. We demonstrate using ACORN to fit large-scale images and detailed 3D occupancy fields. 36 | 37 | #### Datasets 38 | 39 | Image and 3D model datasets should be downloaded and placed in the `data` directory. The datasets used in the paper can be accessed as follows. 
40 | 41 | - Public domain image of Pluto is included in the repository *(NASA/Johns Hopkins University Applied Physics Laboratory/Southwest Research Institute/Alex Parker)* 42 | - [Gigapixel image of Tokyo](https://drive.google.com/file/d/1ITWSv8KcZ_HPNrCXbbbkwzXDSDMr7ACg/view?usp=sharing) *(Trevor Dobson [CC BY-NC-ND 2.0](https://creativecommons.org/licenses/by-nc-nd/2.0/) image resized from [original](https://www.flickr.com/photos/trevor_dobson_inefekt69/29314390837))* 43 | - Public domain [Gigapixel image of Mars](https://drive.google.com/file/d/1Ro1lWxRsl97Jbzm9EA2k9nUEyyVUwxEu/view?usp=sharing) *(NASA/JPL-Caltech/MSSS)* 44 | - Blender engine model [.obj](https://drive.google.com/file/d/1NU2I1Vly6X7YZWD1z_JiBx67XSJ_iR8d/view?usp=sharing), [.blend](https://www.blendswap.com/blend/17636) *(ChrisKuhn [CC-BY](https://creativecommons.org/licenses/by/2.0/))* 45 | - Lucy dataset [.ply](http://graphics.stanford.edu/data/3Dscanrep/) (Stanford 3D Scanning Repository) 46 | - Thai Statue dataset [.ply](http://graphics.stanford.edu/data/3Dscanrep/) (Stanford 3D Scanning Repository) 47 | - Dragon dataset ([TurboSquid](https://www.turbosquid.com/3d-models/chinese-printing-3d-model-1548953)) 48 | 49 | #### Training 50 | 51 | To use ACORN, first set up the conda environment and build the Cython extension with 52 | ``` 53 | conda env create -f environment.yml 54 | conda activate acorn 55 | cd inside_mesh 56 | python setup.py build_ext --inplace 57 | ``` 58 | 59 | Then, download the datasets to the `data` folder. 60 | 61 | We use Gurobi to solve the integer linear program used in the optimization. A free academic license can be installed from [this link](https://www.gurobi.com/downloads/free-academic-license/). 62 | 63 | To train image representations, use the config files in the `experiment_scripts/config_img` folder. 
For example, to train on the Pluto image, run the following 64 | ``` 65 | python train_img.py --config ./config_img/config_pluto_acorn_1k.ini 66 | tensorboard --logdir=../logs/ --port=6006 67 | ``` 68 | 69 | After the image representation has been trained, the decomposition and images can be exported using the following command. 70 | 71 | ``` 72 | python train_img.py --config ../logs/<experiment_name>/config.ini --resume ../logs/<experiment_name> --eval 73 | ``` 74 | 75 | Exported images will appear in the `../logs/<experiment_name>/eval` folder, where `<experiment_name>` is the subdirectory in the `logs` folder corresponding to the particular training run. 76 | 77 | To train 3D models, download the datasets, and then use the corresponding config file in `experiment_scripts/config_occupancy`. For example, a small model representing the Lucy statue can be trained with 78 | 79 | ``` 80 | python train_occupancy.py --config ./config_occupancy/config_lucy_small_acorn.ini 81 | ``` 82 | 83 | Then a mesh of the final model can be exported with 84 | ``` 85 | python train_occupancy.py --config ../logs/<experiment_name>/config.ini --load ../logs/<experiment_name> --export 86 | ``` 87 | 88 | This will create a `.dae` mesh file in the `../logs/<experiment_name>` folder. 89 | 90 | ## Citation 91 | 92 | ``` 93 | @article{martel2021acorn, 94 | title={ACORN: {Adaptive} coordinate networks for neural scene representation}, 95 | author={Julien N. P. Martel and David B. Lindell and Connor Z. Lin and Eric R. Chan and Marco Monteiro and Gordon Wetzstein}, 96 | journal={ACM Trans. Graph. (SIGGRAPH)}, 97 | volume={40}, 98 | number={4}, 99 | year={2021}, 100 | } 101 | ``` 102 | ## Acknowledgments 103 | 104 | We include the MIT licensed `inside_mesh` code in this repo from Lars Mescheder, Michael Oechsle, Michael Niemeyer, Andreas Geiger, and Sebastian Nowozin, which is originally included in their [Occupancy Networks repository](https://github.com/autonomousvision/occupancy_networks/tree/ddb2908f96de9c0c5a30c093f2a701878ffc1f4a/im2mesh/utils/libmesh 105 | ). 106 | 107 | J.N.P. 
class QuadTree():
    """Quadtree of image patches whose active set is chosen by an ILP.

    The tree is fully allocated down to ``max_quadtree_level`` when it is
    constructed. Every node is a :class:`Patch`; a patch takes part in
    sampling and in the optimization only while its ``active`` flag is set.

    Args:
        sidelength: (rows, cols) resolution of the image in pixels.
        patch_size: (rows, cols) pixel size of the finest patch.
    """

    # attributes that hold Gurobi objects; they can neither be deep-copied
    # nor pickled and are therefore left out of copies / pickled state
    _MODEL_ATTRS = ('optim_model', 'c_max_patches')

    def __init__(self, sidelength, patch_size):
        self.sidelength = sidelength

        self.patch_size = patch_size
        self.min_patch_size = np.min(patch_size)
        self.max_patch_size = np.min(sidelength)
        self.aspect_ratio = np.array(patch_size) / np.min(patch_size)

        # how many levels of quadtree are there
        self.min_quadtree_level = int(np.log2(np.min(self.sidelength) // self.max_patch_size))
        self.max_quadtree_level = int(np.log2(np.min(sidelength) // self.min_patch_size))
        self.num_scales = self.max_quadtree_level - self.min_quadtree_level + 1

        # optimization model
        self.optim_model = gp.Model()
        self.optim_model.setParam('OutputFlag', 0)
        self.c_max_patches = None

        # initialize tree
        self.root = self.init_root(self.max_quadtree_level)

        # populate tree nodes with coordinate values, metadata
        self.populate_tree()

    def __deepcopy__(self, memo):
        """Return a deep copy that carries no Gurobi model.

        The clone is created with ``__new__`` so that copying does not build
        (and immediately discard) a fresh model and tree, and it is
        registered in ``memo`` before any attribute is copied so that
        references back to this tree resolve to the clone.
        """
        clone = type(self).__new__(type(self))
        memo[id(self)] = clone

        for k, v in self.__dict__.items():
            if k in self._MODEL_ATTRS:
                continue
            setattr(clone, k, copy.deepcopy(v, memo))
        return clone

    def __getstate__(self):
        """Return the instance dict without the Gurobi attributes."""
        return {k: v for k, v in self.__dict__.items()
                if k not in self._MODEL_ATTRS}

    def __load__(self, obj):
        """Restore this tree from ``obj`` (e.g. an unpickled QuadTree).

        All attributes except ``root`` are taken from ``obj``; the node tree
        is rebuilt against this instance's own ``optim_model`` and every node
        then loads its state from the matching node of ``obj``.
        """
        for k, v in obj.__dict__.items():
            if k == 'root':
                continue
            setattr(self, k, v)
        self.root = self.init_root(obj.max_quadtree_level)
        self.populate_tree()

        def _load_helper(curr_patch, curr_obj_patch):
            curr_patch.__load__(curr_obj_patch)
            for child, obj_child in zip(curr_patch.children, curr_obj_patch.children):
                _load_helper(child, obj_child)

        _load_helper(self.root, obj.root)

    def __str__(self, level=0):
        """Return one line per node (its ``active`` flag), tab-indented by depth."""

        def _str_helper(curr_patch, level):
            ret = "\t"*level+repr(curr_patch.active)+"\n"
            for child in curr_patch.children:
                ret += _str_helper(child, level+1)
            return ret
        return _str_helper(self.root, 0)

    def populate_tree(self):
        """Fill every node with its block coordinate, size and sample grids.

        Raises:
            ValueError: if halving ``max_patch_size`` repeatedly never lands
                exactly on ``min_patch_size``.
        """
        # pixel size of a patch at each scale, coarsest first.
        # The loop also handles the single-scale case (max == min), which
        # used to be rejected as "incompatible".
        patch_sizes = []
        curr_size = self.max_patch_size
        while curr_size > self.min_patch_size:
            patch_sizes.append(curr_size)
            curr_size //= 2
        if curr_size != self.min_patch_size:
            raise ValueError('Patch sizes and resolution are incompatible')
        patch_sizes.append(curr_size)

        # get block coords for patches at each scale
        block_coords = [self.get_block_coords(patch_size=patch_size, include_ends=True) for patch_size in patch_sizes]
        block_sizes = [block[1, 1, :] - block[0, 0, :] for block in block_coords]
        block_coords = [block[:-1, :-1, :] for block in block_coords]

        # create sampling grids for training
        num_samples = self.min_patch_size * self.aspect_ratio
        row_posts = torch.linspace(-1, 1, int(self.min_patch_size*self.aspect_ratio[0])+1)[:-1]
        col_posts = torch.linspace(-1, 1, int(self.min_patch_size*self.aspect_ratio[1])+1)[:-1]
        row_coords, col_coords = torch.meshgrid(row_posts, col_posts)
        row_coords = row_coords.flatten()
        col_coords = col_coords.flatten()

        # create sampling grids for evaluation
        # here we need to sample every pixel within each block
        row_posts = [torch.linspace(-1, 1, int(pixel_size*self.aspect_ratio[0])+1)[:-1] for pixel_size in patch_sizes]
        col_posts = [torch.linspace(-1, 1, int(pixel_size*self.aspect_ratio[1])+1)[:-1] for pixel_size in patch_sizes]
        eval_coords = [torch.meshgrid(row_post, col_post) for row_post, col_post in zip(row_posts, col_posts)]
        eval_row_coords = [eval_coord[0].flatten() for eval_coord in eval_coords]
        eval_col_coords = [eval_coord[1].flatten() for eval_coord in eval_coords]

        # given list of tree idxs in {0,1,2,3}^N, retrieve the block coordinate
        def _index_block_coord(tree_idx, length):
            row, col = 0, 0
            depth = 0
            while length != 1:
                quadrant = tree_idx[depth]
                half = length // 2
                if quadrant == 0:
                    pass
                elif quadrant == 1:
                    col += half
                elif quadrant == 2:
                    row += half
                elif quadrant == 3:
                    row += half
                    col += half
                else:
                    raise ValueError("Unexpected child value, should be 0, 1, 2, or 3")
                length = half
                depth += 1
            return row, col

        def _populate_tree_helper(patch, idx):
            # get block scale idx
            scale_idx = len(idx) - (self.min_quadtree_level)

            # do we have patches at this level?
            if scale_idx >= 0:

                # set patch parameters
                coords = block_coords[scale_idx]
                coord_idx = _index_block_coord(idx, coords.shape[0])
                patch.block_coord = coords[coord_idx[0], coord_idx[1]]
                patch.block_size = block_sizes[scale_idx]
                patch.scale = scale_idx
                patch.pixel_size = patch_sizes[scale_idx]
                patch.num_samples = num_samples
                patch.row_coords = row_coords
                patch.col_coords = col_coords
                patch.eval_row_coords = eval_row_coords[scale_idx]
                patch.eval_col_coords = eval_col_coords[scale_idx]

            if not patch.children:
                return

            # recurse
            for i in range(4):
                child = patch.children[i]
                _populate_tree_helper(child, [*idx, i])

        # done with setup, now actually populate the tree
        _populate_tree_helper(self.root, [])

    def init_root(self, max_level):
        """Allocate a complete tree of depth ``max_level`` and return its root.

        Every node is created with this tree's ``optim_model`` and has its
        ``parent`` link set.
        """

        def _init_root_helper(curr_patch, curr_level, max_level, optim_model):
            if curr_level == max_level:
                return
            curr_patch.children = [Patch(optim_model) for _ in range(4)]

            for patch in curr_patch.children:
                patch.parent = curr_patch
                _init_root_helper(patch, curr_level+1, max_level, optim_model)

        # create root node
        root = Patch(self.optim_model)
        _init_root_helper(root, 0, max_level, self.optim_model)
        return root

    def get_block_coords(self, flatten=False, include_ends=False, patch_size=None):
        """Return the grid of block corner coordinates for one patch size.

        Args:
            flatten: reshape the result to (-1, 2) instead of a 2D grid.
            include_ends: extend each axis range by one block size so the end
                coordinate of the last block is part of the grid.
            patch_size: scalar pixel size of a patch; it is multiplied by
                ``aspect_ratio``, so it must be given (``None`` fails).

        Returns:
            Tensor whose last dimension holds (row, col) coordinates, starting
            at -1 on both axes.
        """
        patch_size = patch_size * self.aspect_ratio

        # get size of each block
        block_size = (2 / (self.sidelength[0]-1) * patch_size[0], 2 / (self.sidelength[1]-1) * patch_size[1])

        # get block begin/end coordinates
        if include_ends:
            block_coords_y = torch.arange(-1, 1+block_size[0], block_size[0])
            block_coords_x = torch.arange(-1, 1+block_size[1], block_size[1])
        else:
            block_coords_y = torch.arange(-1, 1, block_size[0])
            block_coords_x = torch.arange(-1, 1, block_size[1])

        # repeat for every single block
        block_coords = torch.meshgrid(block_coords_y, block_coords_x)
        block_coords = torch.stack((block_coords[0], block_coords[1]), dim=-1)
        if flatten:
            block_coords = block_coords.reshape(-1, 2)

        return block_coords

    def get_patches_at_level(self, level):
        """Return all nodes (active or not) at image scale ``level``.

        ``level`` is the image scale: 0 -> coarsest, N -> finest; -1 selects
        ``max_quadtree_level``.
        """
        if level == -1:
            level = self.max_quadtree_level

        # what quadtree level do our patches start at?
        # check input, too
        target_level = level + self.min_quadtree_level
        assert level <= (self.max_quadtree_level - self.min_quadtree_level), \
            "invalid 'level' input to get_blocks_at_level"

        def _get_patches_at_level_helper(curr_patch, curr_level, patches):
            if curr_level > target_level:
                return

            for patch in curr_patch.children:
                _get_patches_at_level_helper(patch, curr_level+1, patches)

            if curr_level == target_level:
                patches.append(curr_patch)
            return patches

        return _get_patches_at_level_helper(self.root, 0, [])

    def get_frozen_patches(self):
        """Return every node that is both frozen and active (pre-order)."""

        def _get_frozen_patches_helper(curr_patch, patches):
            if curr_patch.frozen and curr_patch.active:
                patches.append(curr_patch)
            for patch in curr_patch.children:
                _get_frozen_patches_helper(patch, patches)
            return patches
        return _get_frozen_patches_helper(self.root, [])

    def get_active_patches(self, include_frozen_patches=False):
        """Return the active nodes (pre-order); frozen ones only on request."""

        def _get_active_patches_helper(curr_patch, patches):
            if curr_patch.active and (include_frozen_patches or not curr_patch.frozen):
                patches.append(curr_patch)
            for patch in curr_patch.children:
                _get_active_patches_helper(patch, patches)
            return patches
        return _get_active_patches_helper(self.root, [])

    def activate_random(self):
        """Activate a random cut through the tree.

        Descending from the root, a node that has a scale is activated with
        probability 0.2 (and its subtree skipped); nodes without children are
        always activated.
        """
        def _activate_random_helper(curr_patch):
            if not curr_patch.children:
                curr_patch.activate()
                return
            elif (curr_patch.scale is not None) and (torch.rand(1).item() < 0.2):
                curr_patch.activate()
                return

            for patch in curr_patch.children:
                _activate_random_helper(patch)

        _activate_random_helper(self.root)

    def synchronize(self, master):
        """Copy the ``active`` flag of every node from the tree ``master``.

        Only the flags are toggled; the Gurobi variables are not touched.
        """
        def _synchronize_helper(curr_patch, curr_patch_master):
            curr_patch.active = curr_patch_master.active
            if not curr_patch.children:
                return

            for patch, patch_master in zip(curr_patch.children, curr_patch_master.children):
                _synchronize_helper(patch, patch_master)

        _synchronize_helper(self.root, master.root)

    def get_frozen_samples(self):
        """Return eval-grid samples and constant values of the frozen patches.

        Returns:
            ``(rel_coords, abs_coords, vals, patch_idx)``, or four ``None``
            when there is no frozen active patch. ``patch_idx[i]`` is the
            index of the patch that batch row ``i`` belongs to.
        """
        patches = self.get_frozen_patches()
        if not patches:
            return None, None, None, None

        batch_len = int(np.prod(self.min_patch_size * self.aspect_ratio))
        rel_coords, abs_coords, vals = [], [], []
        patch_idx = []

        for idx, p in enumerate(patches):
            rel_samp, abs_samp = p.get_stratified_samples(jitter=False, eval=True)

            rel_samp = rel_samp.reshape(-1, batch_len, 2)
            abs_samp = abs_samp.reshape(-1, batch_len, 2)
            patch_idx.extend(rel_samp.shape[0] * [idx, ])

            rel_coords.append(rel_samp)
            abs_coords.append(abs_samp)
            # values have the same size as rel_samp but last dim is a scalar
            vals.append(p.value*torch.ones(abs_samp.shape[:-1] + (1,)))

        return torch.cat(rel_coords, dim=0), torch.cat(abs_coords, dim=0), \
            torch.cat(vals, dim=0), patch_idx

    def get_stratified_samples(self, jitter=True, eval=False):
        """Return samples for every active, non-frozen patch.

        Returns:
            ``(rel_coords, abs_coords, patch_idx)``. Coordinates are batched
            in rows of one finest-patch worth of samples; ``patch_idx[i]`` is
            the index of the patch that batch row ``i`` belongs to.
        """
        patches = self.get_active_patches()

        # always batch the coordinates in groups of a specific patch size
        # so we can process them in parallel
        batch_len = int(np.prod(self.min_patch_size * self.aspect_ratio))
        rel_coords, abs_coords = [], []
        patch_idx = []

        for idx, p in enumerate(patches):
            rel_samp, abs_samp = p.get_stratified_samples(jitter=jitter, eval=eval)

            rel_samp = rel_samp.reshape(-1, batch_len, 2)
            abs_samp = abs_samp.reshape(-1, batch_len, 2)

            # since patch samples could be split across batches,
            # keep track of which batch idx maps to which patch idx
            patch_idx.extend(rel_samp.shape[0] * [idx, ])

            rel_coords.append(rel_samp)
            abs_coords.append(abs_samp)
        return torch.cat(rel_coords, dim=0), torch.cat(abs_coords, dim=0), patch_idx

    def solve_optim(self, max_num_patches=1024):
        """Solve the split/merge ILP and apply the decisions to the tree.

        Args:
            max_num_patches: upper bound of the global "knapsack" constraint.

        Returns:
            dict with the number of ``'merged'``, ``'splits'`` and ``'none'``
            decisions and the objective value ``'obj'``.
        """
        patches = self.get_active_patches()

        assert (len(patches) <= max_num_patches), \
            "You are trying to solve a model which is infeasible: " \
            "Number of active patches > Max number of patches"

        if self.c_max_patches is not None:
            self.optim_model.remove(self.c_max_patches)

        # global "knapsack" constraint
        expr_sum_patches = [p.update_merge() for p in patches]
        self.c_max_patches = self.optim_model.addConstr(gp.quicksum(expr_sum_patches) <= max_num_patches)

        # objective
        self.optim_model.setObjective(gp.quicksum([p.get_cost() for p in patches]), GRB.MINIMIZE)
        self.optim_model.optimize()

        # Dump the infeasibility diagnostics *before* touching objVal:
        # reading objVal first made this branch unreachable, because no
        # objective value is available for an infeasible model.
        if self.optim_model.Status == GRB.INFEASIBLE:
            print("----------- Model is infeasible")
            self.optim_model.computeIIS()
            self.optim_model.write("model.ilp")

        obj_val = self.optim_model.objVal

        # split and merge
        merged = 0
        split = 0
        none = 0
        for p in patches:
            if p.has_split() and p.scale < self.max_quadtree_level:
                p.deactivate()
                for child in p.get_children():
                    child.activate()
                split += 1
            elif p.has_merged() and p.scale >= self.min_quadtree_level and p.scale > 0:
                # we first check if it is active,
                # since we could have already been activated by a neighbor
                if p.active:
                    for neighbor in p.get_neighbors():
                        neighbor.deactivate()
                    p.parent.activate()
                    merged += 1

            else:
                p.update()
                none += 1
        stats_dict = {'merged': merged,
                      'splits': split,
                      'none': none,
                      'obj': obj_val}
        print(f"============================= Total patches:{len(patches)}, split/merge:{split}/{merged}")
        return stats_dict

    def draw(self):
        """Return a matplotlib figure of the active patches, colored by error.

        Frozen patches are drawn white with a red edge; all others use the
        viridis colormap over the error range of the active non-frozen
        patches.
        """
        fig, ax = plt.subplots(1, figsize=(5, 5))
        depth = 1 + self.max_quadtree_level - self.min_quadtree_level
        sidelen = 2**(depth-1)

        # calculate scale
        patch_list = self.get_active_patches()
        patches_err = [p.err for p in patch_list]

        max_err = np.max(patches_err)
        min_err = np.min(patches_err)
        err_span = max_err - min_err

        # pyplot.get_cmap instead of cm.get_cmap, which newer Matplotlib
        # releases no longer provide
        cmap = plt.get_cmap('viridis')

        def _draw_level(patch, curr_level, ax, sidelen, offset, scale):
            if curr_level > self.max_quadtree_level:
                return ax

            scale = scale/2.

            for i, child in enumerate(patch.children):
                if i == 0:
                    new_offset = (offset[0], offset[1])
                elif i == 1:
                    new_offset = (offset[0] + scale * sidelen, offset[1])
                elif i == 2:
                    new_offset = (offset[0], offset[1] + scale * sidelen)
                else:
                    new_offset = (offset[0] + scale * sidelen, offset[1] + scale * sidelen)

                if child.active:
                    if child.frozen:
                        facecolor = 'white'
                        edgecolor = 'red'
                    else:
                        # all errors equal -> avoid 0/0, use the low end
                        norm_err = (child.err-min_err)/err_span if err_span > 0 else 0.
                        facecolor = cmap(norm_err)
                        edgecolor = 'black'
                    rect = patches.Rectangle(new_offset, scale * sidelen, scale * sidelen, linewidth=1,
                                             edgecolor=edgecolor,
                                             facecolor=facecolor, fill=True)
                    ax.add_patch(rect)
                else:
                    ax = _draw_level(child, curr_level+1, ax, sidelen, new_offset, scale)

            return ax

        ax = _draw_level(self.root, self.min_quadtree_level, ax, sidelen, (0., 0.), 1.)
        ax.set_aspect('equal')
        plt.xlim(-1, sidelen + 1)
        plt.ylim(-1, sidelen + 1)
        plt.gca().invert_yaxis()  # we want 0,0 to be on top-left

        norm = colors.Normalize(vmin=min_err, vmax=max_err)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        plt.colorbar(sm)
        return fig


# patch class
class Patch():
    """One node of a :class:`QuadTree`.

    While active, a patch owns three binary Gurobi variables (merge with its
    siblings / split into its children / do nothing) in the shared model.

    Args:
        optim_model: Gurobi model shared by all patches of a tree.
        block_coord: absolute coordinate of the block.
        scale: scale index of the block.
    """

    def __init__(self, optim_model=None, block_coord=None, scale=None):
        self.active = False

        self.parent = None
        self.children = []

        # absolute block coordinate
        self.block_coord = block_coord

        # size of block in absolute coord frame
        self.block_size = None

        # scale level of block
        self.scale = scale

        # num samples to be generated for this block
        self.num_samples = None

        # num pixels in this patch
        self.pixel_size = None

        # optimization model
        self.optim = optim_model

        # row/column coords for sampling at test time
        # initialized by set_samples() function
        self.row_coords = None
        self.col_coords = None
        self.eval_row_coords = None
        self.eval_col_coords = None

        # error for doing nothing, merging, splitting
        self.err = 0.
        self.last_updated = 0.

        self._nocopy = ['optim', 'I_grp', 'I_split', 'I_none',
                        'I_merge', 'c_joinable', 'c_merge_split']
        self._pickle_vars = ['parent', 'children', 'active', 'err', 'last_updated']
        self.spec_cstrs = []

        # options for pruning
        self.frozen = False
        self.value = 0.0

    def __str__(self):
        text = f"Patch id={id(self)}\n" \
               f"    . active={self.active}\n" \
               f"    . level={self.scale}\n" \
               f"    . model={self.optim}"

        if self.active:
            text += f"\n    . g={self.I_grp.x}, s={self.I_split.x}, n={self.I_none.x}"

        return text

    def __deepcopy__(self, memo):
        """Return a deep copy without the attributes listed in ``_nocopy``.

        The clone is registered in ``memo`` before the attributes are copied.
        Without that, copying ``children`` re-entered this method through each
        child's ``parent`` link and the copied children ended up pointing at
        a duplicate of their parent instead of the parent in the copied tree.
        """
        deep_copied_obj = Patch()
        memo[id(self)] = deep_copied_obj
        for k, v in self.__dict__.items():
            if k in self._nocopy:
                # Gurobi objects cannot be copied: drop the attribute
                deep_copied_obj.__dict__.pop(k, None)
            else:
                setattr(deep_copied_obj, k, copy.deepcopy(v, memo))

        return deep_copied_obj

    def __getstate__(self):
        """Return the instance dict without the attributes in ``_nocopy``."""
        state = self.__dict__.copy()
        for k in self._nocopy:
            state.pop(k, None)
        return state

    def __load__(self, obj):
        """Take over all attributes of ``obj`` except the tree links.

        If the loaded state is active, the Gurobi variables are re-created
        through :meth:`activate`.
        """
        for k, v in obj.__dict__.items():
            if k in ['children', 'parent']:
                continue
            setattr(self, k, v)
        if self.active:
            self.activate()

    def update(self):
        """Re-create this patch's variables and constraints in the model."""
        self.deactivate()
        self.activate()

    def activate(self):
        """Mark the patch active and add its variables to the model."""
        self.active = True

        # indicator variables
        self.I_grp = self.optim.addVar(vtype=GRB.BINARY)
        self.I_split = self.optim.addVar(vtype=GRB.BINARY)
        self.I_none = self.optim.addVar(vtype=GRB.BINARY)

        self.I_merge = gp.LinExpr(0.0)

        # local constraint "merge/none/split"
        self.c_joinable = self.optim.addConstr(self.I_grp + self.I_none + self.I_split == 1)

        # local constraint "merge-split"
        self.c_merge_split = None

    def deactivate(self):
        """Mark the patch inactive and remove its variables/constraints.

        Expects that :meth:`activate` ran before (the variables must exist).
        """
        self.active = False

        self.optim.remove(self.I_grp)
        self.optim.remove(self.I_split)
        self.optim.remove(self.I_none)

        self.I_merge = gp.LinExpr(0.0)

        self.optim.remove(self.c_joinable)

        if self.c_merge_split is not None:
            self.optim.remove(self.c_merge_split)

        for cstr in self.spec_cstrs:
            self.optim.remove(cstr)
        self.spec_cstrs = []

    def is_mergeable(self):
        """Return True if this patch and all its siblings are active.

        The root has no siblings and is never mergeable.
        """
        if self.parent is None:
            return False
        return all(sib.active for sib in self.parent.children)

    def set_sample_params(self, num_samples):
        """Set a square training grid of ``num_samples`` posts per axis."""
        self.num_samples = num_samples
        posts = torch.linspace(-1, 1, self.num_samples+1)[:-1]
        row_coords, col_coords = torch.meshgrid(posts, posts)
        self.row_coords = row_coords.flatten()
        self.col_coords = col_coords.flatten()

    def must_split(self):
        """Add a constraint that forces the split decision."""
        self.spec_cstrs.append(
            self.optim.addConstr(self.I_split == 1)
        )

    def must_merge(self):
        """Add a constraint that forces the merge decision."""
        self.spec_cstrs.append(
            self.optim.addConstr(self.I_grp == 1)
        )

    def has_split(self):
        return self.I_split.x == 1

    def has_merged(self):
        return self.I_grp.x == 1

    def has_done_nothing(self):
        return self.I_none.x == 1

    def get_cost(self):
        """Return this patch's contribution to the ILP objective.

        The expression weights the none/split/merge indicators with the
        corresponding error estimates.
        """
        area = self.block_size[0]**2
        alpha = 0.2  # how much worse we expect the error to be when merging
        beta = -0.02  # how much better we expect the error to be when splitting

        # == Merge
        if self.scale > 0:  # it should never be root, but still..
            err_merge = (4+alpha) * area * self.err

            # prefer the parent's recorded error once it has been updated
            if self.parent.last_updated:
                parent_area = self.parent.block_size[0]**2
                err_merge = parent_area * self.parent.err  # can multiply by 1/4 as in paper to make merging more aggressive
        else:
            err_merge = self.err

        # == Split
        if self.children:
            err_split = (0.25+beta) * area * self.err

            # prefer the children's recorded errors once they have been updated
            if self.children[0].last_updated:
                err_children = np.sum([child.err for child in self.children])
                err_split = area * err_children
        else:
            err_split = 1.  # in case you don't have children, high to avoid splitting

        # == None
        err_none = area * self.err

        return err_none * self.I_none \
            + err_split * self.I_split \
            + err_merge * self.I_grp

    def update_merge(self):
        """Add the merge constraint if possible; return the patch-count term.

        Returns:
            Linear expression counting how many patches this one turns into
            (4 when split, 1 when kept, 1/4 when merged); zero for the root.
        """
        if self.parent is None:  # if root
            return gp.LinExpr(0)

        siblings = self.parent.children
        if all(sib.active for sib in siblings):
            I_grp_neighs = [s.I_grp for s in siblings]
            self.I_merge = gp.quicksum(I_grp_neighs)

        # local constraint "joinable"
        self.c_merge_split = self.optim.addConstr(self.I_none + self.I_split + .25*self.I_merge == 1)
        expr_max_patches = 4 * self.I_split + 1 * self.I_none + .25 * self.I_grp

        return expr_max_patches

    def get_neighbors(self):
        return self.parent.children

    def get_children(self):
        return self.children

    def get_parent(self):
        return self.parent

    def is_joinable(self):
        """Return True if all siblings (including this patch) are active.

        The root has no siblings and is never joinable.
        """
        if self.parent is None:
            return False
        return all(sib.active for sib in self.parent.children)

    def get_block_coord(self):
        return self.block_coord

    def get_scale(self):
        return self.scale

    def update_error(self, error, iter):
        """Record ``error`` and the iteration ``iter`` at which it was measured."""
        self.err = error
        self.last_updated = iter

    def get_stratified_samples(self, jitter=True, eval=False):
        """Return ``(rel_samples, abs_samples)`` for this patch.

        ``rel_samples`` are coordinates relative to the block; the absolute
        samples map them onto ``block_coord + block_size * (rel + 1) / 2``.
        """
        # Block coords are always aligned to the pixel grid,
        # e.g., they align with pixels 0, 8, 16, 24, etc. for
        # patch size 8
        #
        # To normalize the coordinates between (-1, 1), consider
        # we have an image of 64x64 and patch size 8x8.
        # The block coordinate (-1, -1) aligns with pixel (0, 0)
        # and coordinate (1, 1) aligns with pixel (63, 63)
        #
        # Absolute coordinates within a block should stretch all the way
        # from the absolute position of one block coordinate to another.
        # Say each block contains 8x8 pixels and we use a feature grid
        # of 8x8 features to interpolate values within a block.
        # This means that the feature positions are not actually
        # aligned to the pixel positions. The features are positioned
        # on a grid stretching from one block coord to another whereas
        # the pixel grid ends just short of the next block coordinate
        #
        # Example patch (x = pixel position, B = block coordinate position)
        # and relative coordinate positions.
        #
        #  -1 ^ B x x x x x x x B
        #     | x x x x x x x x x
        #     | x x x x x x x x x
        #     | x x x x x x x x x
        #     | x x x x x x x x x
        #     | x x x x x x x x x
        #     | x x x x x x x x x
        #     | x x x x x x x x x
        #   1 v B x x x x x x x B
        #       <--------------->
        #       -1              1
        #
        # When we generate samples for a patch, we sample an
        # 8x8 grid that extends between block coords, i.e.
        # between the arrows above
        #
        if eval:
            row_coords = self.eval_row_coords.flatten()
            col_coords = self.eval_col_coords.flatten()
        else:
            row_coords = self.row_coords
            col_coords = self.col_coords

        if jitter:
            # NOTE(review): jitter always starts from the training grid, so
            # eval=True combined with jitter=True discards the eval grid.
            # Kept as is; confirm whether that combination is ever intended.
            row_coords = self.row_coords + torch.rand_like(self.row_coords) * 2./self.num_samples[0]
            col_coords = self.col_coords + torch.rand_like(self.col_coords) * 2./self.num_samples[1]

        rel_samples = torch.stack((row_coords, col_coords), dim=-1)
        abs_samples = self.block_coord[None, :] + self.block_size[None, :] * (rel_samples+1)/2
        return rel_samples, abs_samples
self.max_octant_size)) 726 | self.max_octtree_level = int(np.log2(sidelength[0] // min_octant_size)) 727 | self.num_scales = self.max_octtree_level - self.min_octtree_level + 1 728 | 729 | # optimization model 730 | self.optim_model = gp.Model() 731 | self.optim_model.setParam('OutputFlag', 0) 732 | self.c_max_octants = None 733 | 734 | # set bounds 735 | self.z_min, self.z_max = bounds[0] 736 | self.y_min, self.y_max = bounds[1] 737 | self.x_min, self.x_max = bounds[2] 738 | 739 | # KD tree that stores points on mesh surface 740 | self.surface_tree = mesh_kd_tree 741 | 742 | # initialize tree 743 | self.root = self.init_root(self.max_octtree_level) 744 | 745 | # populate tree nodes with coordinate values, metadata 746 | self.populate_tree() 747 | 748 | def __deepcopy__(self, memo): 749 | deep_copied_obj = OctTree(self.sidelength, self.min_octant_size) 750 | 751 | for k, v in self.__dict__.items(): 752 | if k in ['optim_model', 'c_max_octants']: 753 | setattr(deep_copied_obj, k, v) 754 | else: 755 | setattr(deep_copied_obj, k, copy.deepcopy(v, memo)) 756 | return deep_copied_obj 757 | 758 | def __getstate__(self): 759 | state = self.__dict__.copy() 760 | for k, v in self.__dict__.items(): 761 | if k in ['optim_model', 'c_max_octants']: 762 | del(state[k]) 763 | return state 764 | 765 | def __load__(self, obj): 766 | for k, v in obj.__dict__.items(): 767 | if k == 'root': 768 | continue 769 | setattr(self, k, v) 770 | self.root = self.init_root(obj.max_octtree_level) 771 | self.populate_tree() 772 | 773 | def _load_helper(curr_patch, curr_obj_patch): 774 | curr_patch.__load__(curr_obj_patch) 775 | for child, obj_child in zip(curr_patch.children, curr_obj_patch.children): 776 | _load_helper(child, obj_child) 777 | return 778 | _load_helper(self.root, obj.root) 779 | 780 | def __str__(self, level=0): 781 | 782 | def _str_helper(curr_octant, level): 783 | ret = "\t"*level+repr(curr_octant.active)+"\n" 784 | for child in curr_octant.children: 785 | ret += 
_str_helper(child, level+1) 786 | return ret 787 | return _str_helper(self.root, 0) 788 | 789 | def populate_tree(self): 790 | 791 | # maximum octant scale 792 | max_octant_scale = int(np.log2(self.max_octant_size)) 793 | min_octant_scale = int(np.log2(self.min_octant_size)) 794 | 795 | # get block coords for octants at each scale 796 | octant_sizes = [2**s for s in range(min_octant_scale, max_octant_scale+1)] 797 | octant_sizes.reverse() 798 | block_coords = [self.get_block_coords(octant_size=octant_size, include_ends=True) for octant_size in octant_sizes] 799 | block_sizes = [block[1, 1, 1, :] - block[0, 0, 0, :] for block in block_coords] 800 | block_coords = [block[:-1, :-1, :-1, :] for block in block_coords] 801 | 802 | # create sampling grids for training 803 | num_samples = self.min_octant_size 804 | posts = torch.linspace(-1, 1, self.min_octant_size+1)[:-1] 805 | row_coords, col_coords, dep_coords = torch.meshgrid(posts, posts, posts) 806 | row_coords = row_coords.flatten() 807 | col_coords = col_coords.flatten() 808 | dep_coords = dep_coords.flatten() 809 | 810 | # create sampling grids for evaluation 811 | # here we need to sample every voxel within each block 812 | posts = [torch.linspace(-1, 1, voxel_size+1)[:-1] + (1/voxel_size)/2 for voxel_size in octant_sizes] 813 | eval_coords = [torch.meshgrid(post, post, post) for post in posts] 814 | eval_row_coords = [eval_coord[0].flatten() for eval_coord in eval_coords] 815 | eval_col_coords = [eval_coord[1].flatten() for eval_coord in eval_coords] 816 | eval_dep_coords = [eval_coord[2].flatten() for eval_coord in eval_coords] 817 | 818 | def _populate_tree_helper(octant, idx): 819 | # get block scale idx 820 | scale_idx = len(idx) - (self.min_octtree_level) 821 | 822 | # do we have octants at this level? 
823 | if scale_idx >= 0: 824 | 825 | # set patch parameters 826 | coords = block_coords[scale_idx] 827 | coord_idx = _index_block_coord(idx, coords.shape[0], coord=[0, 0, 0]) 828 | octant.block_coord = coords[coord_idx[0], coord_idx[1], coord_idx[2]] 829 | octant.block_size = block_sizes[scale_idx] 830 | octant.scale = scale_idx 831 | octant.voxel_size = octant_sizes[scale_idx] 832 | octant.num_samples = num_samples 833 | octant.row_coords = row_coords 834 | octant.col_coords = col_coords 835 | octant.dep_coords = dep_coords 836 | octant.eval_row_coords = eval_row_coords[scale_idx] 837 | octant.eval_col_coords = eval_col_coords[scale_idx] 838 | octant.eval_dep_coords = eval_dep_coords[scale_idx] 839 | octant.surface_tree = self.surface_tree 840 | 841 | if not octant.children: 842 | return 843 | 844 | # recurse 845 | for i in range(8): 846 | child = octant.children[i] 847 | _populate_tree_helper(child, [*idx, i]) 848 | 849 | return 850 | 851 | # given list of tree idxs in {0,1,2,3}^N, retrieve the block coordinate 852 | def _index_block_coord(tree_idx, length, coord=[0, 0, 0]): 853 | if length == 1: 854 | return coord 855 | 856 | # depth 0 857 | if tree_idx[0] == 0: 858 | pass 859 | elif tree_idx[0] == 1: 860 | coord[1] += length//2 861 | elif tree_idx[0] == 2: 862 | coord[0] += length//2 863 | elif tree_idx[0] == 3: 864 | coord[0] += length//2 865 | coord[1] += length//2 866 | # depth 1 867 | elif tree_idx[0] == 4: 868 | coord[2] += length//2 869 | elif tree_idx[0] == 5: 870 | coord[1] += length//2 871 | coord[2] += length//2 872 | elif tree_idx[0] == 6: 873 | coord[0] += length//2 874 | coord[2] += length//2 875 | elif tree_idx[0] == 7: 876 | coord[0] += length//2 877 | coord[1] += length//2 878 | coord[2] += length//2 879 | else: 880 | raise ValueError("Unexpected child value, should be in{0,...7}") 881 | 882 | return _index_block_coord(tree_idx[1:], length//2, coord) 883 | 884 | # done with setup, now actually populate the tree 885 | 
_populate_tree_helper(self.root, []) 886 | 887 | def init_root(self, max_level): 888 | 889 | def _init_root_helper(curr_octant, curr_level, max_level, optim_model): 890 | if curr_level == max_level: 891 | return 892 | curr_octant.children = [Octant(optim_model) for _ in range(8)] 893 | 894 | for octant in curr_octant.children: 895 | octant.parent = curr_octant 896 | _init_root_helper(octant, curr_level+1, max_level, optim_model) 897 | return 898 | 899 | # create root node 900 | root = Octant(self.optim_model) 901 | _init_root_helper(root, 0, max_level, self.optim_model) 902 | return root 903 | 904 | def get_block_coords(self, flatten=False, include_ends=False, octant_size=None): 905 | 906 | # use finest scale patch by default 907 | if octant_size is None: 908 | octant_size = self.min_octant_size # TODO: ?? verify 909 | 910 | # get size of each block 911 | z_len = self.z_max - self.z_min 912 | y_len = self.y_max - self.y_min 913 | x_len = self.x_max - self.x_min 914 | 915 | block_size = (z_len / (self.sidelength[0]) * octant_size, 916 | y_len / (self.sidelength[1]) * octant_size, 917 | x_len / (self.sidelength[2]) * octant_size) 918 | 919 | # get block begin/end coordinates 920 | if include_ends: 921 | block_coords_z = torch.arange(self.z_min, self.z_max + block_size[0], block_size[0]) 922 | block_coords_y = torch.arange(self.y_min, self.y_max + block_size[1], block_size[1]) 923 | block_coords_x = torch.arange(self.x_min, self.x_max + block_size[2], block_size[2]) 924 | else: 925 | block_coords_z = torch.arange(self.z_min, self.z_max, block_size[0]) 926 | block_coords_y = torch.arange(self.y_min, self.y_max, block_size[1]) 927 | block_coords_x = torch.arange(self.x_min, self.x_max, block_size[2]) 928 | 929 | # repeat for every single block 930 | block_coords = torch.meshgrid(block_coords_z, block_coords_y, block_coords_x) 931 | block_coords = torch.stack((block_coords[0], block_coords[1], block_coords[2]), dim=-1) 932 | if flatten: 933 | block_coords = 
block_coords.reshape(-1, 3) 934 | 935 | return block_coords 936 | 937 | def get_octants_at_level(self, level): 938 | # level is the image scale: 0-> coarsest, N->finest 939 | 940 | # what quadtree level do our octants start at? 941 | # check input, too 942 | target_level = level + self.min_octtree_level 943 | assert level <= (self.max_octtree_level - self.min_octtree_level), \ 944 | "invalid 'level' input to get_blocks_at_level" 945 | 946 | def _get_octants_at_level_helper(curr_octant, curr_level, octants): 947 | if curr_level > target_level: 948 | return 949 | 950 | for octant in curr_octant.children: 951 | _get_octants_at_level_helper(octant, curr_level+1, octants) 952 | 953 | if curr_level == target_level: 954 | octants.append(curr_octant) 955 | return octants 956 | 957 | return _get_octants_at_level_helper(self.root, 0, []) 958 | 959 | def get_frozen_octants(self): 960 | 961 | def _get_frozen_octants_helper(curr_octant, octants): 962 | if curr_octant.frozen and curr_octant.active: 963 | octants.append(curr_octant) 964 | for octant in curr_octant.children: 965 | _get_frozen_octants_helper(octant, octants) 966 | return octants 967 | return _get_frozen_octants_helper(self.root, []) 968 | 969 | def get_active_octants(self, include_frozen_octants=False): 970 | 971 | def _get_active_octants_helper(curr_octant, octants): 972 | if curr_octant.active and \ 973 | (include_frozen_octants or 974 | (not include_frozen_octants and not curr_octant.frozen)): 975 | octants.append(curr_octant) 976 | return octants 977 | for octant in curr_octant.children: 978 | _get_active_octants_helper(octant, octants) 979 | return octants 980 | return _get_active_octants_helper(self.root, []) 981 | 982 | def activate_random(self): 983 | def _activate_random_helper(curr_octant): 984 | if not curr_octant.children: 985 | curr_octant.activate() 986 | return 987 | elif (curr_octant.scale is not None) and (torch.rand(1).item() < 0.2): 988 | curr_octant.activate() 989 | return 990 | 991 | for patch 
in curr_octant.children: 992 | _activate_random_helper(patch) 993 | return 994 | 995 | _activate_random_helper(self.root) 996 | 997 | def synchronize(self, master): 998 | # set active/inactive nodes to be the same as master 999 | # for now just toggle the flags without worrying about the gurobi variables 1000 | def _synchronize_helper(curr_octant, curr_octant_master): 1001 | curr_octant.active = curr_octant_master.active 1002 | curr_octant.frozen = curr_octant_master.frozen 1003 | curr_octant.value = curr_octant_master.value 1004 | if not curr_octant.children: 1005 | return 1006 | 1007 | for octant, octant_master in zip(curr_octant.children, curr_octant_master.children): 1008 | _synchronize_helper(octant, octant_master) 1009 | return 1010 | 1011 | _synchronize_helper(self.root, master.root) 1012 | 1013 | def get_frozen_samples(self, oversample): 1014 | octants = self.get_frozen_octants() 1015 | 1016 | if not octants: 1017 | return None, None, None, None 1018 | 1019 | rel_coords, abs_coords, vals = [], [], [] 1020 | octant_idx = [] 1021 | 1022 | for idx, p in enumerate(octants): 1023 | rel_samp, abs_samp, _ = p.get_stratified_samples(jitter=False, eval=True, oversample=oversample) 1024 | 1025 | rel_samp = rel_samp.reshape(-1, self.min_octant_size**3, 3) 1026 | abs_samp = abs_samp.reshape(-1, self.min_octant_size**3, 3) 1027 | octant_idx.extend(rel_samp.shape[0] * [idx, ]) 1028 | 1029 | rel_coords.append(rel_samp) 1030 | abs_coords.append(abs_samp) 1031 | # values have the same size as rel_samp but last dim is a scalar 1032 | vals.append(p.value*torch.ones(abs_samp.shape[:-1] + (1,))) 1033 | 1034 | return torch.cat(rel_coords, dim=0), torch.cat(abs_coords, dim=0), \ 1035 | torch.cat(vals, dim=0), octant_idx 1036 | 1037 | def get_stratified_samples(self, jitter=True, eval=False, oversample=1): 1038 | octants = self.get_active_octants() 1039 | 1040 | rel_coords, abs_coords = [], [] 1041 | all_global_indices = [] 1042 | octant_idx = [] 1043 | 1044 | for idx, p in 
enumerate(octants): 1045 | 1046 | rel_samp, abs_samp, global_indices = p.get_stratified_samples(jitter=jitter, eval=eval, oversample=oversample) 1047 | 1048 | # always batch the coordinates in groups of a specific patch size 1049 | # so we can process them in parallel 1050 | rel_samp = rel_samp.reshape(-1, int(self.min_octant_size * oversample)**3, 3) 1051 | abs_samp = abs_samp.reshape(-1, int(self.min_octant_size * oversample)**3, 3) 1052 | if global_indices is not None: 1053 | global_indices = global_indices.reshape(-1, int(self.min_octant_size * oversample)**3, 1) 1054 | 1055 | # since patch samples could be split across batches, 1056 | # keep track of which batch idx maps to which patch idx 1057 | octant_idx.extend(rel_samp.shape[0] * [idx, ]) 1058 | 1059 | rel_coords.append(rel_samp) 1060 | abs_coords.append(abs_samp) 1061 | if global_indices is not None: 1062 | all_global_indices.append(global_indices) 1063 | 1064 | return torch.cat(rel_coords, dim=0), torch.cat(abs_coords, dim=0), octant_idx, None 1065 | 1066 | def solve_optim(self, Max_Num_Octants=150): 1067 | octants = self.get_active_octants() 1068 | 1069 | assert (len(octants) <= Max_Num_Octants), \ 1070 | "You are trying to solve a model which is infeasible: " \ 1071 | "Number of active octants > Max number of octants" 1072 | 1073 | if self.c_max_octants is not None: 1074 | self.optim_model.remove(self.c_max_octants) 1075 | 1076 | # global "knapsack" constraint 1077 | expr_sum_octants = [p.update_merge() for p in octants] 1078 | self.c_max_octants = self.optim_model.addConstr(gp.quicksum(expr_sum_octants) <= Max_Num_Octants) 1079 | 1080 | # objective 1081 | self.optim_model.setObjective(gp.quicksum([p.get_cost() for p in octants]), GRB.MINIMIZE) 1082 | self.optim_model.optimize() 1083 | obj_val = self.optim_model.objVal 1084 | 1085 | if self.optim_model.Status == GRB.INFEASIBLE: 1086 | print("----------- Model is infeasible") 1087 | self.optim_model.computeIIS() 1088 | 
self.optim_model.write("model.ilp") 1089 | 1090 | # split and merge 1091 | merged = 0 1092 | split = 0 1093 | none = 0 1094 | for p in octants: 1095 | # print(p) 1096 | if p.has_split() and p.scale < self.max_octtree_level: 1097 | p.deactivate() 1098 | for child in p.get_children(): 1099 | child.activate() 1100 | split += 1 1101 | elif p.has_merged() and p.scale >= self.min_octtree_level and p.scale > 0: 1102 | # we first check if it is active, 1103 | # since we could have already been activated by a neighbor 1104 | if p.active: 1105 | for neighbor in p.get_neighbors(): 1106 | neighbor.deactivate() 1107 | p.parent.activate() 1108 | merged += 1 1109 | 1110 | else: 1111 | p.update() 1112 | none += 1 1113 | 1114 | stats_dict = {'merged': merged, 1115 | 'splits': split, 1116 | 'none': none, 1117 | 'obj': obj_val} 1118 | print(f"============================= Total octants:{len(octants)}, split/merge:{split}/{merged}") 1119 | print(f"Vars={len(self.optim_model.getVars())}, Cstrs={len(self.optim_model.getConstrs())}") 1120 | return stats_dict 1121 | 1122 | def draw(self, color_by_scale=False, save_fig=True): 1123 | fig = plt.figure(figsize=(5, 5)) 1124 | ax = fig.gca(projection='3d') 1125 | 1126 | depth = 1 + self.max_octtree_level - self.min_octtree_level 1127 | sidelen = 4**(depth-1) // 2**(depth-1) 1128 | 1129 | # calculate scale 1130 | octant_list = self.get_active_octants() 1131 | octants_err = [p.err for p in octant_list] 1132 | max_err = np.max(octants_err) 1133 | min_err = np.min(octants_err) 1134 | 1135 | cmap = plt.cm.get_cmap('viridis') 1136 | 1137 | def cuboid_data(pos, size=(1, 1, 1)): 1138 | eps = 0.2 1139 | o = pos + (eps, eps, eps) 1140 | # get the length, width, and height 1141 | l, w, h = size 1142 | l -= eps 1143 | w -= eps 1144 | h -= eps 1145 | x = [[o[0], o[0] + l, o[0] + l, o[0], o[0]], 1146 | [o[0], o[0] + l, o[0] + l, o[0], o[0]], 1147 | [o[0], o[0] + l, o[0] + l, o[0], o[0]], 1148 | [o[0], o[0] + l, o[0] + l, o[0], o[0]]] 1149 | y = [[o[1], o[1], 
o[1] + w, o[1] + w, o[1]], 1150 | [o[1], o[1], o[1] + w, o[1] + w, o[1]], 1151 | [o[1], o[1], o[1], o[1], o[1]], 1152 | [o[1] + w, o[1] + w, o[1] + w, o[1] + w, o[1] + w]] 1153 | z = [[o[2], o[2], o[2], o[2], o[2]], 1154 | [o[2] + h, o[2] + h, o[2] + h, o[2] + h, o[2] + h], 1155 | [o[2], o[2], o[2] + h, o[2] + h, o[2]], 1156 | [o[2], o[2], o[2] + h, o[2] + h, o[2]]] 1157 | return np.array(x), np.array(y), np.array(z) 1158 | 1159 | def draw_cube_at(pos=(0, 0, 0), size=(1, 1, 1), 1160 | color='b', edgecolor='b', alpha=1., ax=None): 1161 | 1162 | # Plotting a cube element at position pos 1163 | if ax is not None: 1164 | Z, Y, X = cuboid_data(pos, size) 1165 | ax.plot_surface(X, Y, Z, color=color, rstride=1, cstride=1, alpha=alpha, 1166 | edgecolors=edgecolor, linewidth=0.1) 1167 | 1168 | def _draw_level_2(octant, curr_level, 1169 | ax, sidelen, offset, scale): 1170 | if curr_level > self.max_octtree_level: 1171 | return ax 1172 | 1173 | scale = scale/2. 1174 | 1175 | for i, child in enumerate(octant.children): 1176 | # depth 0 1177 | if i == 0: 1178 | new_offset = (offset[0], 1179 | offset[1], 1180 | offset[2]) 1181 | elif i == 1: 1182 | new_offset = (offset[0] + scale * sidelen, 1183 | offset[1], 1184 | offset[2]) 1185 | elif i == 2: 1186 | new_offset = (offset[0], 1187 | offset[1] + scale * sidelen, 1188 | offset[2]) 1189 | elif i == 3: 1190 | new_offset = (offset[0] + scale * sidelen, 1191 | offset[1] + scale * sidelen, 1192 | offset[2]) 1193 | # depth 1 1194 | elif i == 4: 1195 | new_offset = (offset[0], 1196 | offset[1], 1197 | offset[2] + scale * sidelen) 1198 | elif i == 5: 1199 | new_offset = (offset[0] + scale * sidelen, 1200 | offset[1], 1201 | offset[2] + scale * sidelen) 1202 | elif i == 6: 1203 | new_offset = (offset[0], 1204 | offset[1] + scale * sidelen, 1205 | offset[2] + scale * sidelen) 1206 | else: 1207 | new_offset = (offset[0] + scale * sidelen, 1208 | offset[1] + scale * sidelen, 1209 | offset[2] + scale * sidelen) 1210 | 1211 | if child.active: 
1212 | norm_err = (child.err-min_err)/(max_err-min_err) 1213 | sz = scale*sidelen 1214 | 1215 | if child.frozen: 1216 | color = [0.5, 0.5, 0.5] 1217 | edgecolor = [1.0, 0.0, 0.0, 0.1] 1218 | alpha = 0. 1219 | else: 1220 | color = cmap(norm_err)[0:3] 1221 | edgecolor = 'none' 1222 | alpha = 0.1 1223 | 1224 | draw_cube_at(pos=new_offset, size=(sz, sz, sz), 1225 | color=color, edgecolor=edgecolor, 1226 | alpha=alpha, ax=ax) 1227 | 1228 | else: 1229 | ax = _draw_level_2(child, curr_level+1, 1230 | ax, sidelen, new_offset, scale) 1231 | 1232 | return ax 1233 | 1234 | ax = _draw_level_2(self.root, self.min_octtree_level, 1235 | ax, sidelen, 1236 | (0., 0., 0.), 1.) 1237 | 1238 | ax = fig.gca(projection='3d') 1239 | ax.grid(False) 1240 | ax.set_xlim(-1, sidelen+1) 1241 | ax.set_ylim(-1, sidelen+1) 1242 | ax.set_zlim(-1, sidelen+1) 1243 | 1244 | ax.set_xticks([-1, sidelen+1]) 1245 | ax.set_xticklabels([-1, 1]) 1246 | ax.set_yticks([-1, sidelen+1]) 1247 | ax.set_yticklabels([-1, 1]) 1248 | ax.set_zticks([-1, sidelen+1]) 1249 | ax.set_zticklabels([-1, 1]) 1250 | 1251 | return fig 1252 | 1253 | 1254 | class Octant(): 1255 | def __init__(self, optim_model=None, block_coord=None, scale=None, gamma=0.95): 1256 | self.active = False 1257 | 1258 | self.parent = None 1259 | self.children = [] 1260 | 1261 | # absolute block coordinate 1262 | self.block_coord = block_coord 1263 | 1264 | # size of block in absolute coord frame 1265 | self.block_size = None 1266 | 1267 | self.old_block_coord = None 1268 | self.old_block_size = None 1269 | 1270 | # scale level of block 1271 | self.scale = scale 1272 | 1273 | # num samples to be generated for this block 1274 | self.num_samples = None 1275 | 1276 | # num pixels in this patch 1277 | self.voxel_size = None 1278 | 1279 | # optimization model 1280 | self.optim = optim_model 1281 | 1282 | # row/column coords for sampling at test time 1283 | # initialized by set_samples() function 1284 | self.row_coords = None 1285 | self.col_coords = None 1286 
| self.dep_coords = None 1287 | 1288 | self.near_mesh_abs_samples = None 1289 | self.near_mesh_rel_samples = None 1290 | 1291 | # error for doing nothing, merging, splitting 1292 | self.err = 0. 1293 | self.last_updated = 0. 1294 | 1295 | self.gamma = gamma 1296 | 1297 | self._nocopy = ['optim', 'I_grp', 'I_split', 'I_none', 1298 | 'I_merge', 'c_joinable', 'c_merge_split', 1299 | 'children', 'parent', 'loss', 'loss_iter', 1300 | 'err'] 1301 | self.spec_cstrs = [] 1302 | 1303 | self._pickle_vars = ['parent', 'children', 'active', 'err', 'last_updated', 'frozen', 'value'] 1304 | 1305 | # options for pruning 1306 | self.frozen = False 1307 | self.value = 0 1308 | 1309 | def __str__(self): 1310 | str = f"Octant id={id(self)}\n" \ 1311 | f" . active={self.active}\n" \ 1312 | f" . level={self.scale}\n" \ 1313 | f" . model={self.optim}" 1314 | 1315 | if self.active: 1316 | str += f"\n . g={self.I_grp.x}, s={self.I_split.x}, n={self.I_none.x}" 1317 | 1318 | return str 1319 | 1320 | # override deep copy to copy undeepcopyable objects by reference 1321 | def __deepcopy__(self, memo): 1322 | deep_copied_obj = Octant() 1323 | for k, v in self.__dict__.items(): 1324 | if k in self._nocopy: 1325 | setattr(deep_copied_obj, k, v) 1326 | else: 1327 | setattr(deep_copied_obj, k, copy.deepcopy(v, memo)) 1328 | 1329 | return deep_copied_obj 1330 | 1331 | def __getstate__(self): 1332 | state = self.__dict__.copy() 1333 | for k, v in self.__dict__.items(): 1334 | if k not in self._pickle_vars: 1335 | del(state[k]) 1336 | return state 1337 | 1338 | def __load__(self, obj): 1339 | for k, v in obj.__dict__.items(): 1340 | if k in ['children', 'parent']: 1341 | continue 1342 | setattr(self, k, v) 1343 | if self.active: 1344 | self.activate() 1345 | 1346 | def update(self): 1347 | self.deactivate() 1348 | self.activate() 1349 | 1350 | def activate(self): 1351 | self.active = True 1352 | 1353 | # indicator variables 1354 | self.I_grp = self.optim.addVar(vtype=GRB.BINARY) 1355 | self.I_split = 
self.optim.addVar(vtype=GRB.BINARY) 1356 | self.I_none = self.optim.addVar(vtype=GRB.BINARY) 1357 | 1358 | self.I_merge = gp.LinExpr(0.0) 1359 | 1360 | # local constraint "merge/none/split" 1361 | self.c_joinable = self.optim.addConstr(self.I_grp + self.I_none + self.I_split == 1) 1362 | 1363 | # local constraint "merge-split" 1364 | self.c_merge_split = None 1365 | 1366 | def deactivate(self): 1367 | self.active = False 1368 | 1369 | self.optim.remove(self.I_grp) 1370 | self.optim.remove(self.I_split) 1371 | self.optim.remove(self.I_none) 1372 | 1373 | self.I_merge = gp.LinExpr(0.0) 1374 | 1375 | self.optim.remove(self.c_joinable) 1376 | 1377 | if self.c_merge_split is not None: 1378 | self.optim.remove(self.c_merge_split) 1379 | 1380 | for cstr in self.spec_cstrs: 1381 | self.optim.remove(cstr) 1382 | self.spec_cstrs = [] 1383 | 1384 | def is_mergeable(self): 1385 | siblings = self.parent.children 1386 | return np.all(np.all([sib.active for sib in siblings])) 1387 | 1388 | def set_sample_params(self, num_samples): 1389 | self.num_samples = num_samples 1390 | posts = torch.linspace(-1, 1, self.num_samples+1)[:-1] 1391 | row_coords, col_coords, dep_coords = torch.meshgrid(posts, posts, posts) 1392 | self.row_coords = row_coords.flatten() 1393 | self.col_coords = col_coords.flatten() 1394 | self.dep_coords = dep_coords.flatten() 1395 | 1396 | def must_split(self): 1397 | self.spec_cstrs.append( 1398 | self.optim.addConstr(self.I_split == 1) 1399 | ) 1400 | 1401 | def must_merge(self): 1402 | self.spec_cstrs.append( 1403 | self.optim.addConstr(self.I_grp == 1) 1404 | ) 1405 | 1406 | def has_split(self): 1407 | return self.I_split.x == 1 1408 | 1409 | def has_merged(self): 1410 | return self.I_grp.x == 1 1411 | # return self.I_none.x==0 and self.I_split.x==0 1412 | 1413 | def has_done_nothing(self): 1414 | return self.I_none.x == 1 1415 | 1416 | def get_cost(self): 1417 | area = self.block_size[0]**2 1418 | alpha = 0.2 # how much worse we expect the error to be when 
merging 1419 | beta = -0.02 # how much better we expect the error to be when splitting 1420 | 1421 | # == Merge 1422 | if self.scale > 0: # it should never be root, but still.. 1423 | err_merge = (8+alpha) * area * self.err 1424 | 1425 | if self.parent.last_updated: 1426 | parent_area = self.parent.block_size[0]**2 1427 | err_merge = parent_area * self.parent.err # can multiply by 1/8 as in paper to make merging more aggressive 1428 | else: 1429 | err_merge = self.err 1430 | 1431 | # == Split 1432 | if self.children: 1433 | err_split = (0.125+beta) * area * self.err 1434 | 1435 | if self.children[0].last_updated: 1436 | err_children = np.sum([child.err for child in self.children]) 1437 | err_split = area * err_children 1438 | else: 1439 | err_split = 1. # in case you don't have children, high to avoid splitting 1440 | 1441 | err_none = area * self.err 1442 | 1443 | return err_none * self.I_none \ 1444 | + err_split * self.I_split \ 1445 | + err_merge * self.I_grp 1446 | 1447 | def update_merge(self): 1448 | if self.parent is None: # if root 1449 | return gp.LinExpr(0) 1450 | 1451 | siblings = self.parent.children 1452 | if np.all([sib.active for sib in siblings]): 1453 | I_grp_neighs = [s.I_grp for s in siblings] 1454 | self.I_merge = gp.quicksum(I_grp_neighs) 1455 | 1456 | # local constraint "joinable" 1457 | self.c_merge_split = self.optim.addConstr(self.I_none + self.I_split + .125*self.I_merge == 1) 1458 | expr_max_patches = 8 * self.I_split + 1 * self.I_none + .125 * self.I_grp 1459 | 1460 | return expr_max_patches 1461 | 1462 | def get_neighbors(self): 1463 | return self.parent.children 1464 | 1465 | def get_children(self): 1466 | return self.children 1467 | 1468 | def get_parent(self): 1469 | return self.parent 1470 | 1471 | def is_joinable(self): 1472 | # test if siblings are all leaf nodes 1473 | siblings = self.parent.children 1474 | return np.all([sib.active for sib in siblings]) 1475 | 1476 | def get_block_coord(self): 1477 | return self.block_coord 
1478 | 1479 | def get_scale(self): 1480 | return self.scale 1481 | 1482 | def update_error(self, error, iter): 1483 | self.err = self.gamma*self.err + (1-self.gamma)*error 1484 | self.last_updated = iter 1485 | 1486 | def get_block_coords(self, flatten=False, include_ends=False, octant_size=None): 1487 | # get size of each block 1488 | z_len = 2 1489 | y_len = 2 1490 | x_len = 2 1491 | 1492 | sidelength = 256 1493 | 1494 | block_size = (z_len / (sidelength) * octant_size, 1495 | y_len / (sidelength) * octant_size, 1496 | x_len / (sidelength) * octant_size) 1497 | 1498 | # get block begin/end coordinates 1499 | if include_ends: 1500 | block_coords_z = torch.arange(-1, -1 + block_size[0], block_size[0]) 1501 | block_coords_y = torch.arange(-1, 1 + block_size[1], block_size[1]) 1502 | block_coords_x = torch.arange(-1, 1 + block_size[2], block_size[2]) 1503 | else: 1504 | block_coords_z = torch.arange(-1, 1, block_size[0]) 1505 | block_coords_y = torch.arange(-1, 1, block_size[1]) 1506 | block_coords_x = torch.arange(-1, 1, block_size[2]) 1507 | 1508 | # repeat for every single block 1509 | block_coords = torch.meshgrid(block_coords_z, block_coords_y, block_coords_x) 1510 | block_coords = torch.stack((block_coords[0], block_coords[1], block_coords[2]), dim=-1) 1511 | if flatten: 1512 | block_coords = block_coords.reshape(-1, 3) 1513 | 1514 | return block_coords 1515 | 1516 | def get_stratified_samples(self, jitter=True, eval=False, oversample=1., kd_tree=None): 1517 | # Block coords are always aligned to the pixel grid, 1518 | # e.g., they align with pixels 0, 8, 16, 24, etc. for 1519 | # patch size 8 1520 | # 1521 | # To normalize the coordinates between (-1, 1), consider 1522 | # we have an image of 64x64 and patch size 8x8. 
1523 | # The block coordinate (-1, -1) aligns with pixel (0, 0) 1524 | # and coordinate (1, 1) aligns with pixel (63, 63) 1525 | # 1526 | # Absolute coordinates within a block should stretch all the way 1527 | # from the absolute position of one block coordinate to another. 1528 | # Say each block contains 8x8 pixels and we use a feature grid 1529 | # of 8x8 features to interpolate values within a block. 1530 | # This means that the feature positions are not actually 1531 | # aligned to the pixel positions. The features are positioned 1532 | # on a grid stretching from one block coord to another whereas 1533 | # the pixel grid ends just short of the next block coordinate 1534 | # 1535 | # Example patch (x = pixel position, B = block coordinate position) 1536 | # and relative coordinate positions. 1537 | # 1538 | # -1 ^ B x x x x x x x B 1539 | # | x x x x x x x x x 1540 | # | x x x x x x x x x 1541 | # | x x x x x x x x x 1542 | # | x x x x x x x x x 1543 | # | x x x x x x x x x 1544 | # | x x x x x x x x x 1545 | # | x x x x x x x x x 1546 | # 1 v B x x x x x x x B 1547 | # <---------------> 1548 | # -1 1 1549 | # 1550 | # When we generate samples for a patch, we sample an 1551 | # 8x8 grid that extends between block coords, i.e. 
1552 | # between the arrows above 1553 | # 1554 | if eval: 1555 | if True or oversample != 1.: 1556 | post = torch.linspace(-1, 1, int(self.voxel_size * oversample + 1))[:-1] 1557 | post += (post[1] - post[0])/2 1558 | 1559 | eval_coords = torch.meshgrid(post, post, post) 1560 | row_coords = eval_coords[0].flatten() 1561 | col_coords = eval_coords[1].flatten() 1562 | dep_coords = eval_coords[2].flatten() 1563 | else: 1564 | row_coords = self.eval_row_coords.flatten() 1565 | col_coords = self.eval_col_coords.flatten() 1566 | dep_coords = self.eval_dep_coords.flatten() 1567 | 1568 | rel_samples = torch.stack((row_coords, col_coords, dep_coords), dim=-1) 1569 | abs_samples = self.block_coord[None, :] + (self.block_size[None, :]) * (rel_samples+1)/2 1570 | 1571 | return rel_samples, abs_samples, None 1572 | 1573 | else: 1574 | if (self.old_block_coord is None or self.old_block_size is None or 1575 | (self.old_block_coord != self.block_coord).any() 1576 | or (self.old_block_size != self.block_size).any()): 1577 | 1578 | self.old_block_coord = self.block_coord 1579 | self.old_block_size = self.block_size 1580 | self.update_surface_coords() 1581 | return self.select_near_mesh(self.num_samples**3) 1582 | 1583 | def update_surface_coords(self): 1584 | center = (self.block_coord + 0.5 * self.block_size).cpu().numpy() 1585 | side_length = float(self.block_size[0]) 1586 | 1587 | search_radius = (side_length/2)*(3**0.5) 1588 | indices = np.array(self.surface_tree.query_ball_point(center, search_radius)) 1589 | 1590 | if indices.shape[0] == 0: 1591 | self.near_mesh_abs_samples = np.zeros((0, 3)) 1592 | self.near_mesh_rel_samples = np.zeros((0, 3)) 1593 | else: 1594 | coordinates = self.surface_tree.data[indices] 1595 | in_cube = np.linalg.norm(coordinates - center, ord=np.inf, axis=1) < (side_length/2) 1596 | self.near_mesh_abs_samples = torch.FloatTensor(coordinates[in_cube]) 1597 | 1598 | self.near_mesh_rel_samples = 2 * (self.near_mesh_abs_samples - self.block_coord[None, 
def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.

    Args:
        sidelen: an int (same number of samples along every axis) or a
            sequence holding at least `dim` per-axis sample counts.
        dim: number of spatial dimensions; only 2 and 3 are supported.

    Returns:
        A float tensor of shape (prod(sidelen[:dim]), dim). Along each axis
        the first sample sits at -1 and the last at +1.

    Raises:
        NotImplementedError: if `dim` is neither 2 nor 3.
    '''
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)

    if dim not in (2, 3):
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    # Index explicitly so a too-short `sidelen` still raises IndexError.
    sizes = [sidelen[axis] for axis in range(dim)]

    grid_slices = tuple(slice(n) for n in sizes)
    pixel_coords = np.stack(np.mgrid[grid_slices], axis=-1)[None, ...].astype(np.float32)

    # Normalise every axis to [0, 1]. A length-1 axis has a single sample at
    # index 0; dividing by max(n - 1, 1) keeps it at 0 (-> -1 after the shift
    # below) instead of producing 0/0 = NaN.
    for axis, n in enumerate(sizes):
        pixel_coords[..., axis] /= max(n - 1, 1)

    # Map [0, 1] -> [-1, 1].
    pixel_coords -= 0.5
    pixel_coords *= 2.
    pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)
    return pixel_coords
def rescale_img(x, mode='scale', perc=None, tmax=1.0, tmin=0.0):
    """Rescale or clamp the values of tensor `x`.

    mode='scale': linearly map the value range of `x` onto [tmin, tmax].
        The range is min/max of `x`, or, when `perc` is given, the
        `perc`-th and `(100 - perc)`-th percentiles (values outside are
        clamped first). A constant input maps to the midpoint of
        [tmin, tmax].
    mode='clamp': clamp `x` to [0, 1].
    Any other mode returns `x` unchanged.
    """
    if mode == 'clamp':
        return torch.clamp(x, 0, 1)
    if mode != 'scale':
        return x

    target_span = tmax - tmin

    if perc is None:
        hi = torch.max(x)
        lo = torch.min(x)
    else:
        lo = np.percentile(x.detach().cpu().numpy(), perc)
        hi = np.percentile(x.detach().cpu().numpy(), 100 - perc)
        x = torch.clamp(x, lo, hi)

    # Degenerate range: avoid dividing by zero, return the midpoint instead.
    if lo == hi:
        return 0.5 * torch.ones_like(x) * target_span + tmin

    return ((x - lo) / (hi - lo)) * target_span + tmin
class OccupancyDataset():
    """Occupancy ground truth backed by a triangle mesh.

    The mesh is loaded with trimesh, centred on its bounding-box centroid and
    scaled so the largest bounding-box half-extent equals 1. A
    `MeshIntersector` answers inside/outside queries, and a KD-tree is built
    over points sampled on the mesh surface.

    Args:
        pc_or_mesh_filename: path of the mesh file. When empty/None nothing
            is loaded and `mesh`, `intersector`, `kd_tree_sp` and `mode`
            stay None.
        num_surface_samples: number of points sampled on the mesh surface
            for the KD-tree (default keeps the previous hard-coded value).
    """

    def __init__(self, pc_or_mesh_filename, num_surface_samples=20000000):
        # Always define every attribute so an "empty" dataset can be
        # inspected without AttributeError.
        self.mesh = None
        self.intersector = None
        self.kd_tree = None
        self.kd_tree_sp = None
        self.mode = None

        if not pc_or_mesh_filename:
            return

        print("Dataset: loading mesh")
        self.mesh = trimesh.load(pc_or_mesh_filename, process=False, force='mesh', skip_materials=True)

        self._normalize_mesh(self.mesh)

        # NOTE(review): 2048 is passed straight to MeshIntersector; presumably
        # a hash-grid resolution -- confirm against inside_mesh.
        self.intersector = inside_mesh.MeshIntersector(self.mesh, 2048)
        self.mode = 'volume'

        print('Dataset: sampling points on mesh')
        samples = trimesh.sample.sample_surface(self.mesh, num_surface_samples)[0]

        self.kd_tree_sp = spKDTree(samples)

    @staticmethod
    def _normalize_mesh(mesh):
        """Centre `mesh` in place and scale its largest half-extent to 1."""
        print("Dataset: scaling parameters: ", mesh.bounding_box.extents)
        mesh.vertices -= mesh.bounding_box.centroid
        mesh.vertices /= np.max(mesh.bounding_box.extents / 2)

    def __len__(self):
        return 1

    def evaluate_occupancy(self, pts):
        """Return the intersector's query result for `pts` as an int column vector (N, 1)."""
        return self.intersector.query(pts).astype(int).reshape(-1, 1)
def interpolate_bilinear(self, img, fine_abs_coords, psize):
    """Bilinearly sample `img` at the per-block coordinates.

    Args:
        img: image tensor indexed (channels, height, width).
        fine_abs_coords: tensor reshapeable to
            (n_blocks, psize[0], psize[1], 2) holding sample positions in
            the normalised [-1, 1] convention of `grid_sample`.
        psize: (rows, cols) number of samples per block.

    Returns:
        Tensor of shape (n_blocks, psize[0] * psize[1], channels).
    """
    n_blocks = fine_abs_coords.shape[0]
    n_channels = img.shape[0]
    rows, cols = psize[0], psize[1]

    grid = fine_abs_coords.reshape(n_blocks, rows, cols, 2)
    # The stored coordinate order is the reverse of what grid_sample wants,
    # so swap the two components.
    grid = torch.cat([grid[..., 1:], grid[..., :1]], dim=-1)

    # grid_sample evaluates every grid location independently, so all blocks
    # can be stacked along the height axis and sampled in one call instead of
    # one call per block.
    grid = grid.reshape(1, n_blocks * rows, cols, 2)
    sampled = torch.nn.functional.grid_sample(img[None, ...], grid,
                                              mode='bilinear',
                                              padding_mode='reflection',
                                              align_corners=False)

    # (1, C, n_blocks * rows, cols) -> (n_blocks, rows, cols, C)
    sampled = sampled.reshape(n_channels, n_blocks, rows, cols).permute(1, 2, 3, 0)
    return sampled.reshape(n_blocks, rows * cols, n_channels)
class Block3DWrapperMultiscaleAdaptive(torch.utils.data.Dataset):
    """Dataset that serves samples drawn from the active octants of an OctTree.

    Each item holds, per active octant, the block coordinate with a
    normalised scale appended (`coords`), the fine sample positions in
    absolute and block-relative form, and the ground-truth occupancy queried
    from `dataset`.

    When `num_workers > 0` a deep copy of the octree is kept per dataloader
    worker; `synchronize()` pushes the state of the main tree to the copies.
    """

    def __init__(self, dataset, octant_size=16, sidelength=None, random_coords=False,
                 max_octants=600, jitter=True, num_workers=0, length=1000, scale_init=3):

        self.length = length
        if isinstance(sidelength, int):
            sidelength = (sidelength, sidelength, sidelength)
        self.sidelength = sidelength

        # initialize octree
        self.octtree = OctTree(sidelength, octant_size, mesh_kd_tree=dataset.kd_tree_sp)
        self.num_scales = self.octtree.max_octtree_level - self.octtree.min_octtree_level + 1

        # set octants at the initial level to be active
        for octant in self.octtree.get_octants_at_level(scale_init):
            octant.activate()

        # handle parallelization
        self.num_workers = num_workers

        # make a copy of the tree for each worker
        self.octtrees = []
        print('Dataset: preparing dataloaders...')
        for _ in tqdm(range(num_workers)):
            self.octtrees.append(copy.deepcopy(self.octtree))
        self.last_active_octants = self.octtree.get_active_octants()

        self.octant_size = octant_size
        self.dataset = dataset
        self.pointcloud = None
        self.jitter = jitter
        self.eval = False

        self.max_octants = max_octants

        self.iter = 0

    def toggle_eval(self):
        """Switch between training and eval mode; jitter is disabled while in eval."""
        if not self.eval:
            self.jitter_bak = self.jitter
            self.jitter = False
            self.eval = True
        else:
            self.jitter = self.jitter_bak
            self.eval = False

    def synchronize(self):
        """Record the currently active octants and sync every worker copy to the main tree."""
        self.last_active_octants = self.octtree.get_active_octants()
        for worker_tree in self.octtrees[:self.num_workers]:
            worker_tree.synchronize(self.octtree)

    def __len__(self):
        return self.length

    def _block_coords(self, octtree):
        """Return one row per active octant: block coordinate plus scale mapped to [-1, 1]."""
        octants = octtree.get_active_octants()
        coords = torch.stack([p.block_coord for p in octants], dim=0)
        scales = torch.stack([torch.tensor(p.scale) for p in octants], dim=0)[:, None]
        scales = 2*scales / (self.num_scales-1) - 1
        return torch.cat((coords, scales), dim=-1)

    def get_frozen_octants(self, oversample):
        """Return (absolute coords, values) of the frozen octants; used at eval only."""
        fine_rel_coords, fine_abs_coords, vals,\
            coord_patch_idx = self.octtree.get_frozen_samples(oversample)

        return fine_abs_coords, vals

    def get_eval_samples(self, oversample):
        """Build the network input dict for evaluation from the main octree."""
        octtree = self.octtree

        # get fine coords
        fine_rel_coords, fine_abs_coords, coord_octant_idx, _ = \
            octtree.get_stratified_samples(self.jitter, eval=True, oversample=oversample)

        # get block coords, one row per returned sample block
        coords = self._block_coords(octtree)[coord_octant_idx]

        in_dict = {'coords': coords,
                   'fine_abs_coords': fine_abs_coords,
                   'fine_rel_coords': fine_rel_coords,
                   'coord_octant_idx': torch.tensor(coord_octant_idx, dtype=torch.int)}

        return in_dict

    def __getitem__(self, idx):
        # Training-only entry point. The eval branches below are reachable
        # only when assertions are disabled (python -O).
        assert(not self.eval)

        octtree = self.octtree
        if not self.eval and self.num_workers > 0:
            # get_worker_info() returns None in the main process; fall back
            # to the main tree instead of failing on `.id`.
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                octtree = self.octtrees[worker_info.id]

        # get fine coords
        fine_rel_coords, fine_abs_coords, coord_octant_idx, coord_global_idx = \
            octtree.get_stratified_samples(self.jitter, eval=self.eval)

        # get block coords
        coords = self._block_coords(octtree)

        if self.eval:
            coords = coords[coord_octant_idx]

        # query for occupancy
        sz_b, sz_p, _ = fine_abs_coords.shape
        fine_abs_coords_query = fine_abs_coords.reshape(-1, 3).detach().cpu().numpy()

        if self.eval:
            occupancy = np.zeros(fine_abs_coords_query.shape[0])
        else:
            occupancy = self.dataset.evaluate_occupancy(fine_abs_coords_query)
        occupancy = torch.from_numpy(occupancy).reshape(sz_b, sz_p, 1)

        self.iter += 1

        in_dict = {'coords': coords,
                   'fine_abs_coords': fine_abs_coords,
                   'fine_rel_coords': fine_rel_coords}

        if self.eval:
            in_dict.update({'coord_octant_idx': torch.tensor(coord_octant_idx, dtype=torch.int)})

        gt_dict = {'occupancy': occupancy}

        return in_dict, gt_dict

    def update_octant_err(self, err_per_octant, step):
        """Store the per-octant error on the octants recorded by the last `synchronize()`."""
        assert err_per_octant.shape[0] == len(self.last_active_octants), \
            f"Trying to update the error in active patches but list of patches and error tensor" \
            f" sizes are mismatched: {err_per_octant.shape[0]} vs {len(self.last_active_octants)}" \
            f", step: {step}"

        for i, p in enumerate(self.last_active_octants):
            # Log the history of error
            p.update_error(err_per_octant[i], step)

        self.per_octant_error = err_per_octant

    def update_tiling(self):
        """Re-solve the octree tiling under the `max_octants` budget."""
        return self.octtree.solve_optim(self.max_octants)
epochs_til_ckpt = 2 12 | steps_til_summary = 500 13 | patch_size = [16, 8, 8] 14 | w0 = 5 15 | model_type = multiscale 16 | logging_root = ../logs 17 | hidden_features = 256 18 | hidden_layers = 4 19 | -------------------------------------------------------------------------------- /experiment_scripts/config_img/config_pluto_acorn_4k.ini: -------------------------------------------------------------------------------- 1 | experiment_name = pluto_acorn_4K 2 | epochs_til_ckpt = 1 3 | num_workers = 8 4 | res = [4096] 5 | dataset = pluto 6 | steps_til_tiling = 500 7 | max_patches = 256 8 | skip_logging = false 9 | lr = 0.001 10 | num_iters = 100300 11 | steps_til_summary = 500 12 | patch_size = [16, 32, 32] 13 | w0 = 50 14 | model_type = multiscale 15 | scale_init = 3 16 | logging_root = ../logs 17 | hidden_features = 512 18 | hidden_layers = 4 19 | -------------------------------------------------------------------------------- /experiment_scripts/config_img/config_pluto_acorn_8k.ini: -------------------------------------------------------------------------------- 1 | experiment_name = pluto_acorn_8K 2 | num_workers = 12 3 | res = 8192 4 | dataset = pluto 5 | max_patches = 1024 6 | skip_logging = false 7 | lr = 0.001 8 | num_iters = 100300 9 | epochs_til_ckpt = 2 10 | steps_til_summary = 500 11 | patch_size = [16, 32, 32] 12 | w0 = 5 13 | steps_til_tiling = 500 14 | model_type = multiscale 15 | scale_init = 3 16 | logging_root = ../logs 17 | hidden_features = 512 18 | hidden_layers = 4 19 | -------------------------------------------------------------------------------- /experiment_scripts/config_img/config_pluto_pe_8k.ini: -------------------------------------------------------------------------------- 1 | experiment_name = pluto_pe_8K 2 | num_workers = 8 3 | res = 8192 4 | dataset = pluto 5 | model_type = pe 6 | scale_init = 4 7 | skip_logging = true 8 | lr = 0.001 9 | num_iters = 100300 10 | epochs_til_ckpt = 2 11 | steps_til_summary = 500 12 | patch_size = [0, 32, 
32] 13 | w0 = 5 14 | steps_til_tiling = 500 15 | max_patches = 1024 16 | logging_root = ../logs 17 | hidden_features = 1536 18 | hidden_layers = 4 19 | -------------------------------------------------------------------------------- /experiment_scripts/config_img/config_pluto_siren_8k.ini: -------------------------------------------------------------------------------- 1 | experiment_name = pluto_siren_8K 2 | lr = 1e-05 3 | num_workers = 8 4 | w0 = 50 5 | res = 8192 6 | dataset = pluto 7 | model_type = siren 8 | scale_init = 4 9 | skip_logging = true 10 | num_iters = 100300 11 | epochs_til_ckpt = 2 12 | steps_til_summary = 500 13 | patch_size = [0, 32, 32] 14 | steps_til_tiling = 500 15 | max_patches = 1024 16 | logging_root = ../logs 17 | hidden_features = 1536 18 | hidden_layers = 4 19 | -------------------------------------------------------------------------------- /experiment_scripts/config_img/config_tokyo.ini: -------------------------------------------------------------------------------- 1 | experiment_name = tokyo_acorn_1G 2 | num_workers = 4 3 | res = [19456, 51200] 4 | dataset = tokyo 5 | max_patches = 16384 6 | skip_logging = true 7 | lr = 0.001 8 | num_iters = 100300 9 | epochs_til_ckpt = 2 10 | steps_til_summary = 500 11 | patch_size = [64, 19, 50] 12 | w0 = 5 13 | steps_til_tiling = 500 14 | model_type = multiscale 15 | scale_init = 3 16 | logging_root = ../logs 17 | hidden_features = 2048 18 | hidden_layers = 8 19 | -------------------------------------------------------------------------------- /experiment_scripts/config_occupancy/config_dragon_acorn.ini: -------------------------------------------------------------------------------- 1 | experiment_name = dragon 2 | num_workers = 8 3 | lr = 0.001 4 | num_epochs = 10000 5 | epochs_til_pruning = 4 6 | epochs_til_ckpt = 4 7 | steps_til_summary = 4000 8 | octant_size = 4 9 | res = 512 10 | steps_til_tiling = 1000 11 | max_octants = 1024 12 | logging_root = ../logs 13 | feature_grid_size = [18, 12, 
12, 12] 14 | pc_filepath = ../data/dragon.obj 15 | pruning_threshold = -10 16 | hidden_features = 512 17 | hidden_layers = 4 18 | -------------------------------------------------------------------------------- /experiment_scripts/config_occupancy/config_engine_acorn.ini: -------------------------------------------------------------------------------- 1 | experiment_name = engine 2 | lr = 0.001 3 | num_workers = 8 4 | octant_size = 4 5 | steps_til_tiling = 1000 6 | pruning_threshold = -10.0 7 | num_epochs = 10000 8 | epochs_til_ckpt = 4 9 | steps_til_summary = 4000 10 | res = 512 11 | max_octants = 1024 12 | logging_root = ../logs 13 | feature_grid_size = [18, 12, 12, 12] 14 | pc_filepath = ../data/engine.obj 15 | epochs_til_pruning = 4 16 | hidden_features = 512 17 | hidden_layers = 4 18 | -------------------------------------------------------------------------------- /experiment_scripts/config_occupancy/config_lucy_acorn.ini: -------------------------------------------------------------------------------- 1 | experiment_name = lucy 2 | num_workers = 6 3 | lr = 0.001 4 | num_epochs = 10000 5 | epochs_til_pruning = 4 6 | epochs_til_ckpt = 4 7 | steps_til_summary = 4000 8 | octant_size = 4 9 | res = 512 10 | steps_til_tiling = 1000 11 | max_octants = 1024 12 | logging_root = ../logs 13 | feature_grid_size = [18, 12, 12, 12] 14 | pc_filepath = ../data/lucy.ply 15 | pruning_threshold = -10 16 | hidden_features = 512 17 | hidden_layers = 4 18 | -------------------------------------------------------------------------------- /experiment_scripts/config_occupancy/config_lucy_small_acorn.ini: -------------------------------------------------------------------------------- 1 | experiment_name = lucy_small 2 | num_workers = 12 3 | lr = 0.001 4 | num_epochs = 10000 5 | epochs_til_pruning = 4 6 | epochs_til_ckpt = 4 7 | steps_til_summary = 4000 8 | octant_size = 4 9 | res = 256 10 | steps_til_tiling = 1000 11 | max_octants = 256 12 | logging_root = ../logs 13 | 
feature_grid_size = [18, 12, 12, 12] 14 | pc_filepath = ../data/lucy.ply 15 | pruning_threshold = -10 16 | hidden_features = 256 17 | hidden_layers = 4 18 | scale_init = 2 19 | -------------------------------------------------------------------------------- /experiment_scripts/config_occupancy/config_thai_acorn.ini: -------------------------------------------------------------------------------- 1 | experiment_name = thai_statue 2 | num_workers = 8 3 | lr = 0.001 4 | num_epochs = 10000 5 | epochs_til_ckpt = 4 6 | steps_til_summary = 4000 7 | octant_size = 4 8 | res = 512 9 | steps_til_tiling = 1000 10 | max_octants = 1024 11 | logging_root = ../logs 12 | feature_grid_size = [18, 12, 12, 12] 13 | pc_filepath = ../data/thai_statue.ply 14 | pruning_threshold = -10 15 | epochs_til_pruning = 4 16 | hidden_features = 512 17 | hidden_layers = 4 18 | -------------------------------------------------------------------------------- /experiment_scripts/train_img.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | import dataio 6 | import utils 7 | import training 8 | import loss_functions 9 | import pruning_functions 10 | import modules 11 | from torch.utils.data import DataLoader 12 | import configargparse 13 | from functools import partial 14 | import numpy as np 15 | import skimage 16 | import cv2 17 | from tqdm import tqdm 18 | import matplotlib.pyplot as plt 19 | import re 20 | import torch 21 | from time import time 22 | 23 | 24 | p = configargparse.ArgumentParser() 25 | p.add('-c', '--config', required=False, is_config_file=True, help='Path to config file.') 26 | 27 | # General training options 28 | p.add_argument('--lr', type=float, default=1e-3, help='learning rate. 
default=1e-3') 29 | p.add_argument('--num_iters', type=int, default=100300, 30 | help='Number of iterations to train for.') 31 | p.add_argument('--num_workers', type=int, default=0, 32 | help='number of dataloader workers.') 33 | p.add_argument('--skip_logging', action='store_true', default=False, 34 | help="don't use summary function, only save loss and models") 35 | p.add_argument('--eval', action='store_true', default=False, 36 | help='run evaluation') 37 | p.add_argument('--resume', nargs=2, type=str, default=None, 38 | help='resume training, specify path to directory where model is stored and the iteration of ckpt.') 39 | p.add_argument('--gpu', type=int, default=0, 40 | help='GPU ID to use') 41 | 42 | # logging options 43 | p.add_argument('--experiment_name', type=str, required=True, 44 | help='path to directory where checkpoints & tensorboard events will be saved.') 45 | p.add_argument('--epochs_til_ckpt', type=int, default=2, 46 | help='Epochs until checkpoint is saved.') 47 | p.add_argument('--steps_til_summary', type=int, default=500, 48 | help='Number of iterations until tensorboard summary is saved.') 49 | 50 | # dataset options 51 | p.add_argument('--res', nargs='+', type=int, default=[512], 52 | help='image resolution.') 53 | p.add_argument('--dataset', type=str, default='camera', choices=['camera', 'pluto', 'tokyo', 'mars'], 54 | help='which dataset to use') 55 | p.add_argument('--grayscale', action='store_true', default=False, 56 | help='whether to use grayscale') 57 | 58 | # model options 59 | p.add_argument('--patch_size', nargs='+', type=int, default=[32], 60 | help='patch size.') 61 | p.add_argument('--hidden_features', type=int, default=512, 62 | help='hidden features in network') 63 | p.add_argument('--hidden_layers', type=int, default=4, 64 | help='hidden layers in network') 65 | p.add_argument('--w0', type=int, default=5, 66 | help='w0 for the siren model.') 67 | p.add_argument('--steps_til_tiling', type=int, default=500, 68 | help='How 
often to recompute the tiling, also defines number of steps per epoch.') 69 | p.add_argument('--max_patches', type=int, default=1024, 70 | help='maximum number of patches in the optimization') 71 | p.add_argument('--model_type', type=str, default='multiscale', required=False, choices=['multiscale', 'siren', 'pe'], 72 | help='Type of model to evaluate, default is multiscale.') 73 | p.add_argument('--scale_init', type=int, default=3, 74 | help='which scale to initialize active patches in the quadtree') 75 | 76 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 77 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging') 78 | opt = p.parse_args() 79 | 80 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 81 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu) 82 | 83 | for k, v in opt.__dict__.items(): 84 | print(k, v) 85 | 86 | 87 | def main(): 88 | if opt.dataset == 'camera': 89 | img_dataset = dataio.Camera() 90 | elif opt.dataset == 'pluto': 91 | pluto_url = "https://upload.wikimedia.org/wikipedia/commons/e/ef/Pluto_in_True_Color_-_High-Res.jpg" 92 | img_dataset = dataio.ImageFile('../data/pluto.jpg', url=pluto_url, grayscale=opt.grayscale) 93 | elif opt.dataset == 'tokyo': 94 | img_dataset = dataio.ImageFile('../data/tokyo.tif', grayscale=opt.grayscale) 95 | elif opt.dataset == 'mars': 96 | img_dataset = dataio.ImageFile('../data/mars.tif', grayscale=opt.grayscale) 97 | 98 | if len(opt.patch_size) == 1: 99 | opt.patch_size = 3*opt.patch_size 100 | 101 | # set up dataset 102 | coord_dataset = dataio.Patch2DWrapperMultiscaleAdaptive(img_dataset, 103 | sidelength=opt.res, 104 | patch_size=opt.patch_size[1:], jitter=True, 105 | num_workers=opt.num_workers, length=opt.steps_til_tiling, 106 | scale_init=opt.scale_init, max_patches=opt.max_patches) 107 | 108 | opt.num_epochs = opt.num_iters // coord_dataset.__len__() 109 | 110 | image_resolution = (opt.res, opt.res) 111 | 112 | dataloader = 
DataLoader(coord_dataset, shuffle=False, batch_size=1, pin_memory=True, 113 | num_workers=opt.num_workers) 114 | 115 | if opt.resume is not None: 116 | path, iter = opt.resume 117 | iter = int(iter) 118 | assert(os.path.isdir(path)) 119 | assert opt.config is not None, 'Specify config file' 120 | 121 | # Define the model. 122 | if opt.grayscale: 123 | out_features = 1 124 | else: 125 | out_features = 3 126 | 127 | if opt.model_type == 'multiscale': 128 | model = modules.ImplicitAdaptivePatchNet(in_features=3, out_features=out_features, 129 | num_hidden_layers=opt.hidden_layers, 130 | hidden_features=opt.hidden_features, 131 | feature_grid_size=(opt.patch_size[0], opt.patch_size[1], opt.patch_size[2]), 132 | sidelength=opt.res, 133 | num_encoding_functions=10, 134 | patch_size=opt.patch_size[1:]) 135 | 136 | elif opt.model_type == 'siren': 137 | model = modules.ImplicitNet(opt.res, in_features=2, 138 | out_features=out_features, 139 | num_hidden_layers=4, 140 | hidden_features=1536, 141 | mode='siren', w0=opt.w0) 142 | elif opt.model_type == 'pe': 143 | model = modules.ImplicitNet(opt.res, in_features=2, 144 | out_features=out_features, 145 | num_hidden_layers=4, 146 | hidden_features=1536, 147 | mode='pe') 148 | else: 149 | raise NotImplementedError('Only model types multiscale, siren, and pe are implemented') 150 | 151 | model.cuda() 152 | 153 | # print number of model parameters 154 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 155 | params = sum([np.prod(p.size()) for p in model_parameters]) 156 | print(f'Num. 
Parameters: {params}') 157 | 158 | # Define the loss 159 | loss_fn = partial(loss_functions.image_mse, 160 | tiling_every=opt.steps_til_tiling, 161 | dataset=coord_dataset, 162 | model_type=opt.model_type) 163 | summary_fn = partial(utils.write_image_patch_multiscale_summary, image_resolution, opt.patch_size[1:], coord_dataset, model_type=opt.model_type, skip=opt.skip_logging) 164 | 165 | # Define the pruning function 166 | pruning_fn = partial(pruning_functions.no_pruning, 167 | pruning_every=1) 168 | 169 | # if we are resuming from a saved checkpoint 170 | if opt.resume is not None: 171 | print('Loading checkpoints') 172 | model_dict = torch.load(path + '/checkpoints/' + f'model_{iter:06d}.pth') 173 | model.load_state_dict(model_dict) 174 | 175 | # load optimizers 176 | try: 177 | resume_checkpoint = {} 178 | optim_dict = torch.load(path + '/checkpoints/' + f'optim_{iter:06d}.pth') 179 | for g in optim_dict['optimizer_state_dict']['param_groups']: 180 | g['lr'] = opt.lr 181 | resume_checkpoint['optimizer_state_dict'] = optim_dict['optimizer_state_dict'] 182 | resume_checkpoint['total_steps'] = optim_dict['total_steps'] 183 | resume_checkpoint['epoch'] = optim_dict['epoch'] 184 | 185 | # initialize model state_dict 186 | print('Initializing models') 187 | coord_dataset.quadtree.__load__(optim_dict['quadtree']) 188 | coord_dataset.synchronize() 189 | 190 | except FileNotFoundError: 191 | print('Unable to load optimizer checkpoints') 192 | else: 193 | resume_checkpoint = {} 194 | 195 | if opt.eval: 196 | run_eval(model, coord_dataset) 197 | else: 198 | # Save command-line parameters log directory. 199 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 200 | utils.cond_mkdir(root_path) 201 | p.write_config_file(opt, [os.path.join(root_path, 'config.ini')]) 202 | 203 | # Save text summary of model into log directory. 
def _add_batch_dim(sample):
    """Return a copy of ``sample`` whose tensors gain a leading batch axis.

    Tensor values are indexed with ``[None, ...]`` and kept on the CPU;
    non-tensor values are passed through unchanged.
    """
    return {key: value[None, ...].cpu() if isinstance(value, torch.Tensor) else value
            for key, value in sample.items()}


def _scatter_to_image(values, pixel_idx, sidelength, n_channels):
    """Scatter per-sample values into a dense (1, C, H, W) tensor.

    Args:
        values: tensor that reshapes to (num_samples, n_channels).
        pixel_idx: integer array of shape (num_samples, 2) with row/column indices.
        sidelength: image size, unpacked as the leading dims of the dense array.
        n_channels: number of channels per sample.
    """
    # numpy is used because it supports assignment through index arrays
    image = np.zeros((*sidelength, n_channels))
    image[pixel_idx[:, 0], pixel_idx[:, 1]] = values.detach().cpu().numpy().reshape(-1, n_channels)
    return torch.tensor(image)[None, ...].permute(0, 3, 1, 2)


def _to_uint8_hwc(display):
    """Convert a (1, C, H, W) tensor to an (H, W, C) uint8 array.

    Values are mapped with ``x / 2 + 0.5``, clipped to [0, 1] and scaled by 255.
    """
    # Index the batch axis instead of calling squeeze(): squeeze() would also drop
    # a single (grayscale) channel axis and make the transpose below fail.
    img = np.clip((display[0].numpy() / 2.) + 0.5, a_min=0., a_max=1.).transpose(1, 2, 0) * 255
    return img.astype(np.uint8)


def _write_image(fname, img):
    """Write an (H, W, C) uint8 image; C == 1 is saved as a plain grayscale image."""
    if img.shape[-1] == 1:
        cv2.imwrite(fname, np.ascontiguousarray(img[..., 0]))
    else:
        cv2.imwrite(fname, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


# evaluate PSNR at saved checkpoints and save model outputs
def run_eval(model, coord_dataset):
    """Evaluate every saved checkpoint and write predictions, tilings and metrics.

    Checkpoints are read from ``<dirname(opt.config)>/checkpoints`` and results are
    written to ``<dirname(opt.config)>/eval``. For each checkpoint the model and the
    quadtree are restored, the model is run on one evaluation sample, the output is
    scattered into a dense image, PSNR/SSIM are computed against the ground truth,
    and the prediction, the tiling plot and the accumulated metrics are saved.
    The ground-truth image is built and saved once, on the first checkpoint.

    Args:
        model: network whose ``load_state_dict`` accepts the saved model files.
        coord_dataset: dataset exposing ``quadtree``, ``synchronize``,
            ``toggle_eval``, ``sidelength`` and item access.
    """
    run_dir = os.path.dirname(opt.config)

    # get checkpoint directory
    checkpoint_dir = os.path.join(run_dir, 'checkpoints')

    # make eval directory
    eval_dir = os.path.join(run_dir, 'eval')
    utils.cond_mkdir(eval_dir)

    # get model & optim files (reverse lexicographic order)
    model_files = sorted([f for f in os.listdir(checkpoint_dir) if re.search(r'model_[0-9]+.pth', f)], reverse=True)
    optim_files = sorted([f for f in os.listdir(checkpoint_dir) if re.search(r'optim_[0-9]+.pth', f)], reverse=True)

    # extract iterations
    iters = [int(re.search(r'[0-9]+', f)[0]) for f in model_files]

    # append beginning of path
    model_files = [os.path.join(checkpoint_dir, f) for f in model_files]
    optim_files = [os.path.join(checkpoint_dir, f) for f in optim_files]

    # iterate through model and optim files
    metrics = {}
    saved_gt = False
    for curr_iter, model_path, optim_path in zip(tqdm(iters), model_files, optim_files):

        # load model and optimizer files
        print('Loading models')
        model_dict = torch.load(model_path)
        optim_dict = torch.load(optim_path)

        # initialize model state_dict
        print('Initializing models')
        model.load_state_dict(model_dict)
        coord_dataset.quadtree.__load__(optim_dict['quadtree'])
        coord_dataset.synchronize()

        # fetch one sample with the dataset switched to eval mode, then switch back
        coord_dataset.toggle_eval()
        model_input, gt = coord_dataset[0]
        coord_dataset.toggle_eval()

        # add batch dimension (tensors stay on the CPU)
        model_input = _add_batch_dim(model_input)
        gt = _add_batch_dim(gt)

        # run the model on uniform samples
        print('Running forward pass')
        n_channels = gt['img'].shape[-1]
        start = time()
        with torch.no_grad():
            pred_img = utils.process_batch_in_chunks(model_input, model, max_chunk_size=512)['model_out']['output']
        torch.cuda.synchronize()
        print(f'Model: {time() - start:.02f}')

        # get pixel idx for each coordinate: (c + 1) / 2 * (side - 1), rounded
        start = time()
        coords = model_input['fine_abs_coords'].detach().cpu().numpy()
        pixel_idx = np.zeros_like(coords).astype(np.int32)
        pixel_idx[..., 0] = np.round((coords[..., 0] + 1.)/2. * (coord_dataset.sidelength[0]-1)).astype(np.int32)
        pixel_idx[..., 1] = np.round((coords[..., 1] + 1.)/2. * (coord_dataset.sidelength[1]-1)).astype(np.int32)
        pixel_idx = pixel_idx.reshape(-1, 2)

        # assign predicted image values into a new dense array
        display_pred = _scatter_to_image(pred_img, pixel_idx, coord_dataset.sidelength, n_channels)

        if not saved_gt:
            display_gt = _scatter_to_image(gt['img'], pixel_idx, coord_dataset.sidelength, n_channels)
        print(f'Reshape: {time() - start:.02f}')

        # record metrics
        start = time()
        psnr, ssim = get_metrics(display_pred, display_gt)
        metrics.update({curr_iter: {'psnr': psnr, 'ssim': ssim}})
        print(f'Metrics: {time() - start:.02f}')
        print(f'Iter: {curr_iter}, PSNR: {psnr:.02f}')

        # save images
        pred_fname = os.path.join(eval_dir, f'pred_{curr_iter:06d}.png')
        print('Saving image')
        _write_image(pred_fname, _to_uint8_hwc(display_pred))

        if not saved_gt:
            print('Saving gt')
            gt_fname = os.path.join(eval_dir, 'gt.png')
            _write_image(gt_fname, _to_uint8_hwc(display_gt))
            saved_gt = True

        # save tiling
        tiling_fname = os.path.join(eval_dir, f'tiling_{curr_iter:06d}.pdf')
        coord_dataset.quadtree.draw()
        plt.savefig(tiling_fname)

        # save metrics (the dict accumulates every checkpoint evaluated so far)
        metric_fname = os.path.join(eval_dir, f'metrics_{curr_iter:06d}.npy')
        np.save(metric_fname, metrics)


def _to_unit_range_hwc(img):
    """Return ``img`` as an (H, W, C) numpy array mapped with ``x / 2 + 0.5``.

    Accepts (1, C, H, W) or (C, H, W) tensors. Only the batch axis is removed so
    that single-channel images keep their channel axis.
    """
    arr = img.detach().cpu().numpy()
    if arr.ndim == 4:
        arr = arr.squeeze(axis=0)  # raises ValueError if the batch size is not 1
    return (arr.transpose(1, 2, 0) / 2.) + 0.5


def get_metrics(pred_img, gt_img):
    """Compute PSNR and SSIM between a predicted and a ground-truth image.

    Args:
        pred_img: tensor shaped (1, C, H, W) or (C, H, W); rescaled with
            ``x / 2 + 0.5`` and clipped to [0, 1].
        gt_img: tensor with the same layout; rescaled the same way, not clipped.

    Returns:
        Tuple ``(psnr, ssim)`` computed with ``data_range=1``.
    """
    p = np.clip(_to_unit_range_hwc(pred_img), a_min=0., a_max=1.)
    trgt = _to_unit_range_hwc(gt_img)

    # NOTE(review): `multichannel` was deprecated in scikit-image 0.19 in favour of
    # `channel_axis`; kept as-is for the environment this project pins - confirm
    # against environment.yml before upgrading scikit-image.
    ssim = skimage.metrics.structural_similarity(p, trgt, multichannel=True, data_range=1)
    psnr = skimage.metrics.peak_signal_noise_ratio(p, trgt, data_range=1)

    return psnr, ssim
default=5e-5') 24 | p.add_argument('--num_epochs', type=int, default=10000, 25 | help='Number of epochs to train for.') 26 | p.add_argument('--num_workers', type=int, default=0, 27 | help='number of dataloader workers.') 28 | p.add_argument('--pc_filepath', type=str, default='', help='pc_or_mesh') 29 | p.add_argument('--load', type=str, default=None, help='logging directory to resume from') 30 | p.add_argument('--gpu', type=int, default=0, 31 | help='GPU ID to use') 32 | 33 | # logging options 34 | p.add_argument('--experiment_name', type=str, required=True, 35 | help='path to directory where checkpoints & tensorboard events will be saved.') 36 | p.add_argument('--skip_logging', action='store_true', default=False, 37 | help="don't use summary function, only save loss and models") 38 | p.add_argument('--epochs_til_ckpt', type=int, default=4, 39 | help='Time interval in seconds until checkpoint is saved.') 40 | p.add_argument('--steps_til_summary', type=int, default=500, 41 | help='Time interval in seconds until tensorboard summary is saved.') 42 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 43 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging') 44 | 45 | # dataset options 46 | p.add_argument('--res', type=int, default=512, 47 | help='image resolution.') 48 | p.add_argument('--octant_size', type=int, default=4, 49 | help='patch size.') 50 | p.add_argument('--max_octants', type=int, default=1024) 51 | p.add_argument('--scale_init', type=int, default=3, 52 | help='which scale to initialize active octants in the octree') 53 | 54 | # model options 55 | p.add_argument('--steps_til_tiling', type=int, default=1000, 56 | help='how often to recompute the tiling') 57 | p.add_argument('--hidden_features', type=int, default=512, 58 | help='hidden features in network') 59 | p.add_argument('--hidden_layers', type=int, default=4, 60 | help='hidden layers in network') 61 | 
p.add_argument('--feature_grid_size', nargs='+', type=int, default=(18, 12, 12, 12)) 62 | p.add_argument('--pruning_threshold', type=float, default=-10) 63 | p.add_argument('--epochs_til_pruning', type=int, default=4) 64 | 65 | # export options 66 | p.add_argument('--export', action='store_true', default=False, 67 | help='export mesh from checkpoint (requires load flag)') 68 | p.add_argument('--upsample', type=int, default=2, 69 | help='how much to upsamples the occupancies used to generate the output mesh') 70 | p.add_argument('--mc_threshold', type=float, default=0.005, help='threshold for marching cubes') 71 | 72 | 73 | opt = p.parse_args() 74 | 75 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 76 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu) 77 | 78 | feature_grid_size = tuple(opt.feature_grid_size) 79 | 80 | assert opt.pc_filepath is not None, "Must specify dataset input" 81 | 82 | for k, v in opt.__dict__.items(): 83 | print(k, v) 84 | print() 85 | 86 | 87 | def main(): 88 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 89 | utils.cond_mkdir(root_path) 90 | 91 | point_cloud_dataset = dataio.OccupancyDataset(opt.pc_filepath) 92 | 93 | coord_dataset = dataio.Block3DWrapperMultiscaleAdaptive(point_cloud_dataset, 94 | sidelength=opt.res, 95 | octant_size=opt.octant_size, 96 | jitter=True, max_octants=opt.max_octants, 97 | num_workers=opt.num_workers, 98 | length=opt.steps_til_tiling, 99 | scale_init=opt.scale_init) 100 | 101 | model = modules.ImplicitAdaptiveOctantNet(in_features=3+1, out_features=1, 102 | num_hidden_layers=opt.hidden_layers, 103 | hidden_features=opt.hidden_features, 104 | feature_grid_size=feature_grid_size, 105 | octant_size=opt.octant_size) 106 | model.cuda() 107 | 108 | resume_checkpoint = {} 109 | if opt.load is not None: 110 | resume_checkpoint = load_from_checkpoint(opt.load, model, coord_dataset) 111 | 112 | if opt.export: 113 | assert opt.load is not None, 'Need to specify which model to export with --load' 114 
| 115 | export_mesh(model, coord_dataset, opt.upsample, opt.mc_threshold) 116 | return 117 | 118 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 119 | print(f"\n\nTrainable Parameters: {num_params}\n\n") 120 | 121 | dataloader = DataLoader(coord_dataset, shuffle=False, batch_size=1, pin_memory=True, 122 | num_workers=opt.num_workers) 123 | 124 | # Define the loss 125 | loss_fn = partial(loss_functions.occupancy_bce, 126 | tiling_every=opt.steps_til_tiling, 127 | dataset=coord_dataset) 128 | 129 | summary_fn = partial(utils.write_occupancy_multiscale_summary, (opt.res, opt.res, opt.res), 130 | coord_dataset, output_mrc=f'{opt.experiment_name}.mrc', 131 | skip=opt.skip_logging) 132 | 133 | # Define the pruning 134 | pruning_fn = partial(pruning_functions.pruning_occupancy, 135 | threshold=opt.pruning_threshold) 136 | 137 | # Save command-line parameters log directory. 138 | p.write_config_file(opt, [os.path.join(root_path, 'config.ini')]) 139 | 140 | # Save text summary of model into log directory. 
def _latest_checkpoint(checkpoint_dir, prefix):
    """Return ``(iteration, path)`` of the newest ``<prefix>_<iteration>.pth`` file.

    The iteration is compared numerically, so ``model_1000000.pth`` is newer than
    ``model_999999.pth`` (a plain filename sort would order them the other way).

    Raises:
        FileNotFoundError: if the directory holds no matching file.
    """
    pattern = re.compile(rf'{prefix}_([0-9]+)\.pth$')
    candidates = []
    for fname in os.listdir(checkpoint_dir):
        match = pattern.search(fname)
        if match:
            candidates.append((int(match.group(1)), fname))

    if not candidates:
        raise FileNotFoundError(f"No '{prefix}_<iteration>.pth' checkpoint found in {checkpoint_dir}")

    iteration, fname = max(candidates)
    return iteration, os.path.join(checkpoint_dir, fname)


def load_from_checkpoint(experiment_dir, model, coord_dataset):
    """Restore the newest checkpoint found in ``<experiment_dir>/checkpoints``.

    Loads the newest ``model_*.pth`` into ``model`` and restores the octree stored
    under the ``'octtree'`` key of the newest ``optim_*.pth`` into
    ``coord_dataset``, then calls ``coord_dataset.synchronize()``.

    Args:
        experiment_dir: logging directory of the run to resume.
        model: network whose ``load_state_dict`` accepts the saved model file.
        coord_dataset: dataset exposing ``octtree.__load__`` and ``synchronize``.

    Returns:
        dict with the keys ``'optimizer_state_dict'``, ``'total_steps'`` and
        ``'epoch'`` copied from the optimizer checkpoint.

    Raises:
        FileNotFoundError: if no model or optimizer checkpoint exists.
    """
    checkpoint_dir = os.path.join(experiment_dir, 'checkpoints')
    model_iter, model_path = _latest_checkpoint(checkpoint_dir, 'model')
    optim_iter, optim_path = _latest_checkpoint(checkpoint_dir, 'optim')

    print("MODEL PATH: ", model_path)
    if model_iter != optim_iter:
        # The two files are selected independently; make a mismatch visible.
        print(f'Warning: model checkpoint is from iteration {model_iter} '
              f'but optimizer checkpoint is from iteration {optim_iter}')

    # load model and octree
    print('Loading models')
    model_dict = torch.load(model_path)
    optim_dict = torch.load(optim_path)

    # initialize model state_dict
    print('Initializing models')
    model.load_state_dict(model_dict)
    coord_dataset.octtree.__load__(optim_dict['octtree'])
    coord_dataset.synchronize()

    return {
        'optimizer_state_dict': optim_dict['optimizer_state_dict'],
        'total_steps': optim_dict['total_steps'],
        'epoch': optim_dict['epoch'],
    }


def export_mesh(model, dataset, upsample, mcubes_threshold=0.005):
    """Compute occupancies at an upsampled resolution and export a marching-cubes mesh.

    Writes ``<experiment_name>.mrc`` and ``<experiment_name>.dae`` into
    ``<logging_root>/<experiment_name>``.

    Note: ``model.octant_size`` is multiplied by ``upsample`` in place and is not
    restored afterwards.

    Args:
        model: network passed to the occupancy summary function.
        dataset: dataset passed to the occupancy summary function.
        upsample: integer factor applied to ``opt.res`` and ``model.octant_size``.
        mcubes_threshold: iso-value handed to marching cubes.
    """
    res = 3*(upsample*opt.res,)
    model.octant_size = model.octant_size * upsample

    out_dir = os.path.join(opt.logging_root, opt.experiment_name)

    print('Export: calculating occupancy...')
    mrc_fname = os.path.join(out_dir, f"{opt.experiment_name}.mrc")
    occupancy = utils.write_occupancy_multiscale_summary(res, dataset, model,
                                                         None, None, None, None, None,
                                                         output_mrc=mrc_fname,
                                                         oversample=upsample,
                                                         mode='hq')

    print('Export: running marching cubes...')
    vertices, faces = mcubes.marching_cubes(occupancy, mcubes_threshold)

    print('Export: exporting mesh...')
    out_fname = os.path.join(out_dir, f"{opt.experiment_name}.dae")
    mcubes.export_mesh(vertices, faces, out_fname)
def check_mesh_contains(mesh, points, hash_resolution=512):
    """Return a boolean array telling, per point, whether it lies inside ``mesh``.

    Args:
        mesh: object exposing ``vertices`` and ``faces`` arrays.
        points: array of shape (num_points, 3).
        hash_resolution: resolution of the 2D spatial hash used for the lookup.
    """
    intersector = MeshIntersector(mesh, hash_resolution)
    contains = intersector.query(points)
    return contains


class MeshIntersector:
    """Point-in-mesh test based on counting triangle crossings along the z axis."""

    def __init__(self, mesh, resolution=512):
        """Rescale the mesh triangles into the hash grid and build the 2D intersector."""
        triangles = mesh.vertices[mesh.faces].astype(np.float64)
        n_tri = triangles.shape[0]

        self.resolution = resolution
        self.bbox_min = triangles.reshape(3 * n_tri, 3).min(axis=0)
        self.bbox_max = triangles.reshape(3 * n_tri, 3).max(axis=0)
        # Translate and scale it to [0.5, self.resolution - 0.5]^3
        self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min)
        self.translate = 0.5 - self.scale * self.bbox_min

        self._triangles = triangles = self.rescale(triangles)

        # Only the x/y coordinates are hashed; z is resolved in query().
        triangles2d = triangles[:, :, :2]
        self._tri_intersector2d = TriangleIntersector2d(
            triangles2d, resolution)

    def query(self, points):
        """Return a boolean array of length ``len(points)``: True where a point is inside.

        A point counts as inside only when the number of triangles crossed at or
        above it and the number crossed below it (along z) are both odd.
        """
        # Rescale points
        points = self.rescale(points)

        # placeholder result with no hits we'll fill in later
        # (builtin bool: the `np.bool` alias is missing from NumPy 1.24-1.26)
        contains = np.zeros(len(points), dtype=bool)

        # cull points outside of the axis aligned bounding box
        # this avoids running ray tests unless points are close
        inside_aabb = np.all(
            (0 <= points) & (points <= self.resolution), axis=1)
        if not inside_aabb.any():
            return contains

        # Only consider points inside bounding box
        mask = inside_aabb
        points = points[mask]

        # Compute intersection depth and check order
        points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2])

        triangles_intersect = self._triangles[tri_indices]
        points_intersect = points[points_indices]

        depth_intersect, abs_n_2 = self.compute_intersection_depth(
            points_intersect, triangles_intersect)

        # Count number of intersections in both directions
        # (depths are scaled by |n_z|, so the point's z is scaled the same way)
        smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
        bigger_depth = depth_intersect < points_intersect[:, 2] * abs_n_2
        points_indices_0 = points_indices[smaller_depth]
        points_indices_1 = points_indices[bigger_depth]

        nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0])
        nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0])

        # Check if point contained in mesh
        contains1 = (np.mod(nintersect0, 2) == 1)
        contains2 = (np.mod(nintersect1, 2) == 1)
        contains[mask] = (contains1 & contains2)
        return contains

    def compute_intersection_depth(self, points, triangles):
        """Return ``(depth * |n_z|, |n_z|)`` for each point/triangle pair.

        ``depth`` is the z value at which the vertical line through the point's
        x/y position meets the triangle's plane; it is returned pre-multiplied by
        the absolute z component of the (unnormalised) triangle normal.
        Triangles whose normal has a zero z component yield ``nan``.
        """
        t1 = triangles[:, 0, :]
        t2 = triangles[:, 1, :]
        t3 = triangles[:, 2, :]

        v1 = t3 - t1
        v2 = t2 - t1

        normals = np.cross(v1, v2)
        alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1)

        n_2 = normals[:, 2]
        t1_2 = t1[:, 2]
        s_n_2 = np.sign(n_2)
        abs_n_2 = np.abs(n_2)

        mask = (abs_n_2 != 0)

        depth_intersect = np.full(points.shape[0], np.nan)
        depth_intersect[mask] = \
            t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask]

        return depth_intersect, abs_n_2

    def rescale(self, array):
        """Apply the mesh-to-grid affine transform ``scale * array + translate``."""
        array = self.scale * array + self.translate
        return array


class TriangleIntersector2d:
    """Find which 2D triangles contain which 2D points, using a spatial hash."""

    def __init__(self, triangles, resolution=128):
        """Store the (num_triangles, 3, 2) triangles and build the spatial hash."""
        self.triangles = triangles
        self.tri_hash = _TriangleHash(triangles, resolution)

    def query(self, points):
        """Return ``(point_indices, tri_indices)`` pairs where the triangle contains the point."""
        point_indices, tri_indices = self.tri_hash.query(points)
        point_indices = np.array(point_indices, dtype=np.int64)
        tri_indices = np.array(tri_indices, dtype=np.int64)
        points = points[point_indices]
        triangles = self.triangles[tri_indices]
        # The hash only yields candidates; keep the pairs that truly intersect.
        mask = self.check_triangles(points, triangles)
        point_indices = point_indices[mask]
        tri_indices = tri_indices[mask]
        return point_indices, tri_indices

    def check_triangles(self, points, triangles):
        """Return a boolean array: True where ``points[i]`` is strictly inside ``triangles[i]``.

        Points on an edge or vertex, and degenerate (zero-area) triangles, give False.

        Args:
            points: array of shape (n, 2).
            triangles: array of shape (n, 3, 2).
        """
        # builtin bool: the `np.bool` alias is missing from NumPy 1.24-1.26
        contains = np.zeros(points.shape[0], dtype=bool)
        A = triangles[:, :2] - triangles[:, 2:]
        A = A.transpose([0, 2, 1])
        y = points - triangles[:, 2]

        detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0]

        mask = (np.abs(detA) != 0.)
        A = A[mask]
        y = y[mask]
        detA = detA[mask]

        s_detA = np.sign(detA)
        abs_detA = np.abs(detA)

        # u and v are barycentric coordinates scaled by |detA| (no division needed)
        u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA
        v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA

        sum_uv = u + v
        contains[mask] = (
            (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA)
            & (0 < sum_uv) & (sum_uv < abs_detA)
        )
        return contains
cdef class TriangleHash:
    # Uniform 2D grid used as a spatial hash for triangles.
    # Cell (x, y) lives at flat index ``resolution * x + y`` and stores the
    # indices of all triangles whose axis-aligned bounding box covers the cell.
    cdef vector[vector[int]] spatial_hash
    cdef int resolution

    def __cinit__(self, double[:, :, :] triangles, int resolution):
        # Allocate one bucket per grid cell, then fill the buckets.
        self.spatial_hash.resize(resolution * resolution)
        self.resolution = resolution
        self._build_hash(triangles)

    @cython.boundscheck(False)  # Deactivate bounds checking
    @cython.wraparound(False)   # Deactivate negative indexing.
    cdef int _build_hash(self, double[:, :, :] triangles):
        # Insert every triangle into each grid cell touched by its bounding box.
        # ``triangles`` must have shape (n_tri, 3, 2): three 2D vertices each.
        assert(triangles.shape[1] == 3)
        assert(triangles.shape[2] == 2)

        cdef int n_tri = triangles.shape[0]
        # NOTE(review): the bounding box is stored in C ints although the
        # vertices are doubles; presumably the values are truncated on
        # assignment -- confirm this is the intended cell index.
        cdef int bbox_min[2]
        cdef int bbox_max[2]

        cdef int i_tri, j, x, y
        cdef int spatial_idx

        for i_tri in range(n_tri):
            # Compute bounding box
            for j in range(2):
                bbox_min[j] = min(
                    triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
                )
                bbox_max[j] = max(
                    triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
                )
                # Clamp to the valid cell range [0, resolution - 1]
                bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1)
                bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1)

            # Find all voxels where bounding box intersects
            for x in range(bbox_min[0], bbox_max[0] + 1):
                for y in range(bbox_min[1], bbox_max[1] + 1):
                    spatial_idx = self.resolution * x + y
                    self.spatial_hash[spatial_idx].push_back(i_tri)

    @cython.boundscheck(False)  # Deactivate bounds checking
    @cython.wraparound(False)   # Deactivate negative indexing.
    cpdef query(self, double[:, :] points):
        # Return candidate (point, triangle) pairs for 2D ``points`` of shape
        # (n_points, 2). The result is two int32 arrays of equal length:
        # point indices and the triangle indices stored in each point's cell.
        # Only the bounding-box buckets are consulted here; no exact
        # point-in-triangle test is performed in this method.
        assert(points.shape[1] == 2)
        cdef int n_points = points.shape[0]

        cdef vector[int] points_indices
        cdef vector[int] tri_indices
        # cdef int[:] points_indices_np
        # cdef int[:] tri_indices_np

        cdef int i_point, k, x, y
        cdef int spatial_idx

        for i_point in range(n_points):
            # Cell containing the point
            x = int(points[i_point, 0])
            y = int(points[i_point, 1])
            # Points outside the grid produce no candidates
            if not (0 <= x < self.resolution and 0 <= y < self.resolution):
                continue

            spatial_idx = self.resolution * x + y
            # NOTE(review): i_tri has no cdef declaration in this method.
            for i_tri in self.spatial_hash[spatial_idx]:
                points_indices.push_back(i_point)
                tri_indices.push_back(i_tri)

        # Copy the C++ vectors into NumPy arrays that can be returned to Python
        points_indices_np = np.zeros(points_indices.size(), dtype=np.int32)
        tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32)

        cdef int[:] points_indices_view = points_indices_np
        cdef int[:] tri_indices_view = tri_indices_np

        for k in range(points_indices.size()):
            points_indices_view[k] = points_indices[k]

        for k in range(tri_indices.size()):
            tri_indices_view[k] = tri_indices[k]

        return points_indices_np, tri_indices_np
def occupancy_bce(model_output, gt, step, tiling_every=100, dataset=None,
                  model_type='multiscale', pruning_fn=None, retile=True):
    '''Binary cross-entropy (with logits) occupancy loss.

    The loss is computed element-wise between ``model_output['model_out']['output']``
    and ``gt['occupancy']`` and its mean is returned under ``'occupancy_loss'``.

    When ``model_type == 'multiscale'`` the loss averaged over the last two
    dimensions is reported to ``dataset.update_octant_err``; on the last step
    of every ``tiling_every`` window (and if ``retile`` is truthy) the tiling
    is updated and, if anything merged or split, ``dataset.synchronize()`` is
    called. ``pruning_fn`` is accepted but not used here.
    '''
    logits = model_output['model_out']['output']
    target = gt['occupancy'].float()
    elementwise = torch.nn.BCEWithLogitsLoss(reduction='none')(logits, target)
    result = {'occupancy_loss': elementwise.mean()}

    # Nothing else to do for non-adaptive models
    if model_type != 'multiscale':
        return result

    # One error value per octant, handed to the dataset as a NumPy array
    octant_err = torch.mean(elementwise, dim=(-1, -2)).squeeze(0).detach().cpu().numpy()
    dataset.update_octant_err(octant_err, step)

    last_step_of_window = (step % tiling_every == tiling_every - 1)
    if last_step_of_window and retile:
        stats = dataset.update_tiling()
        tiling_changed = not (stats['merged'] == 0 and stats['splits'] == 0)
        if tiling_changed:
            dataset.synchronize()

    return result
def compute_iou(path_gt, path_pr, N=128, sphere=False, sphere_radius=0.25):
    ''' compute the intersection-over-union of two occupancy volumes
        parameters
            path_gt: path to ground-truth mesh (.ply or .obj); not loaded when sphere is set
            path_pr: path to predicted mesh (.ply or .obj)
            N: NxNxN grid resolution at which to compute iou
            sphere: if set, the ground truth is the ball of radius
                    sphere_radius around the origin instead of a mesh
        returns
            the iou as a python float
            NOTE(review): if neither volume contains a grid point the union is
            empty and the division presumably yields nan - confirm '''

    # NxNxN query points across [-1,1]
    query_pts = np.array(define_grid_3d(N))

    # the predicted mesh is loaded first
    predicted = MeshDataset(path_pr)

    # ground-truth occupancy at the query points
    if not sphere:
        reference = MeshDataset(path_gt)
        inside_gt = torch.tensor(check_mesh_contains(reference.mesh, query_pts))
    else:
        inside_gt = torch.from_numpy(np.linalg.norm(query_pts, axis=-1) <= sphere_radius)

    # predicted occupancy at the query points
    inside_pr = torch.tensor(check_mesh_contains(predicted.mesh, query_pts))

    # count voxels in the union and in the intersection
    n_union = (inside_gt | inside_pr).float().sum()
    n_intersect = (inside_gt & inside_pr).float().sum()

    return (n_intersect / n_union).item()
class MeshDataset():
    '''Thin holder for a mesh loaded with trimesh.

    After construction ``self.mesh`` is the loaded mesh, or ``None`` when no
    path was given.
    '''

    def __init__(self, path_mesh, sample=False, num_pts=0):
        '''
        parameters
            path_mesh: path of the mesh file; a falsy value loads nothing
            sample, num_pts: accepted for interface compatibility, unused here
        '''
        # Always define the attribute so that an instance created without a
        # path does not raise AttributeError on ``.mesh`` access.
        self.mesh = None

        if not path_mesh:
            return

        self.mesh = trimesh.load(path_mesh, process=False,
                                 force='mesh', skip_materials=True)
= self.rescale(triangles) 125 | 126 | triangles2d = triangles[:, :, :2] 127 | self._tri_intersector2d = TriangleIntersector2d( 128 | triangles2d, resolution) 129 | 130 | def query(self, points): 131 | # Rescale points 132 | points = self.rescale(points) 133 | 134 | # placeholder result with no hits we'll fill in later 135 | contains = np.zeros(len(points), dtype=np.bool) 136 | 137 | # cull points outside of the axis aligned bounding box 138 | # this avoids running ray tests unless points are close 139 | inside_aabb = np.all( 140 | (0 <= points) & (points <= self.resolution), axis=1) 141 | if not inside_aabb.any(): 142 | return contains 143 | 144 | # Only consider points inside bounding box 145 | mask = inside_aabb 146 | points = points[mask] 147 | 148 | # Compute intersection depth and check order 149 | points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2]) 150 | 151 | triangles_intersect = self._triangles[tri_indices] 152 | points_intersect = points[points_indices] 153 | 154 | depth_intersect, abs_n_2 = self.compute_intersection_depth( 155 | points_intersect, triangles_intersect) 156 | 157 | # Count number of intersections in both directions 158 | smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2 159 | bigger_depth = depth_intersect < points_intersect[:, 2] * abs_n_2 160 | points_indices_0 = points_indices[smaller_depth] 161 | points_indices_1 = points_indices[bigger_depth] 162 | 163 | nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0]) 164 | nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0]) 165 | 166 | # Check if point contained in mesh 167 | contains1 = (np.mod(nintersect0, 2) == 1) 168 | contains2 = (np.mod(nintersect1, 2) == 1) 169 | contains[mask] = (contains1 & contains2) 170 | return contains 171 | 172 | def compute_intersection_depth(self, points, triangles): 173 | t1 = triangles[:, 0, :] 174 | t2 = triangles[:, 1, :] 175 | t3 = triangles[:, 2, :] 176 | 177 | v1 = t3 - t1 178 | 
class TriangleIntersector2d:
    '''Point-in-triangle queries in 2D, accelerated by a triangle hash.

    The hash proposes candidate (point, triangle) pairs; ``check_triangles``
    then keeps only the pairs where the point lies strictly inside the
    triangle.
    '''

    def __init__(self, triangles, resolution=128):
        # triangles: array indexed as [triangle, vertex, xy]
        self.triangles = triangles
        self.tri_hash = _TriangleHash(triangles, resolution)

    def query(self, points):
        '''Return index arrays ``(point_indices, tri_indices)`` of all pairs
        for which the point lies strictly inside the triangle.'''
        point_indices, tri_indices = self.tri_hash.query(points)
        point_indices = np.array(point_indices, dtype=np.int64)
        tri_indices = np.array(tri_indices, dtype=np.int64)
        points = points[point_indices]
        triangles = self.triangles[tri_indices]
        # Exact test on the candidate pairs proposed by the hash
        mask = self.check_triangles(points, triangles)
        point_indices = point_indices[mask]
        tri_indices = tri_indices[mask]
        return point_indices, tri_indices

    def check_triangles(self, points, triangles):
        '''Pairwise strict point-in-triangle test.

        parameters
            points: (n, 2) array
            triangles: (n, 3, 2) array; row i is tested against points[i]
        returns
            boolean array of length n; degenerate triangles (zero
            determinant) and points on an edge or vertex give False
        '''
        # (the ``np.bool`` alias was removed in NumPy 1.24; use builtin bool)
        contains = np.zeros(points.shape[0], dtype=bool)
        # Edge vectors relative to the third vertex, as columns of A
        A = triangles[:, :2] - triangles[:, 2:]
        A = A.transpose([0, 2, 1])
        y = points - triangles[:, 2]

        detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0]

        # Skip degenerate triangles
        mask = (np.abs(detA) != 0.)
        A = A[mask]
        y = y[mask]
        detA = detA[mask]

        s_detA = np.sign(detA)
        abs_detA = np.abs(detA)

        # Solve A @ (u, v) = y by Cramer's rule, scaled by |detA| so that no
        # division is needed; the bounds below are scaled accordingly.
        u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA
        v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA

        sum_uv = u + v
        contains[mask] = (
            (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA)
            & (0 < sum_uv) & (sum_uv < abs_detA)
        )
        return contains
21 | ''' 22 | 23 | def __init__(self, in_features, out_features, num_hidden_layers, hidden_features, 24 | outermost_linear=False, nonlinearity='relu', weight_init=None, w0=30): 25 | super().__init__() 26 | 27 | self.first_layer_init = None 28 | 29 | # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable, 30 | # special first-layer initialization scheme 31 | nls_and_inits = {'sine': (Sine(w0=w0), partial(sine_init, w0=w0), first_layer_sine_init), 32 | 'relu': (nn.ReLU(inplace=True), init_weights_normal, None)} 33 | 34 | nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity] 35 | 36 | if weight_init is not None: # Overwrite weight init if passed 37 | self.weight_init = weight_init 38 | else: 39 | self.weight_init = nl_weight_init 40 | 41 | self.net = [] 42 | self.net.append(nn.Sequential( 43 | nn.Linear(in_features, hidden_features), nl 44 | )) 45 | 46 | for i in range(num_hidden_layers): 47 | self.net.append(nn.Sequential( 48 | nn.Linear(hidden_features, hidden_features), nl 49 | )) 50 | 51 | if outermost_linear: 52 | self.net.append(nn.Sequential(nn.Linear(hidden_features, out_features))) 53 | else: 54 | self.net.append(nn.Sequential( 55 | nn.Linear(hidden_features, out_features), nl 56 | )) 57 | 58 | self.net = nn.Sequential(*self.net) 59 | if self.weight_init is not None: 60 | self.net.apply(self.weight_init) 61 | 62 | if first_layer_init is not None: # Apply special initialization to first layer, if applicable. 
class PositionalEncoding(nn.Module):
    '''Sinusoidal positional encoding of input coordinates.

    Two variants are supported:
      * frequency bands (default): for every band ``f`` the encoding appends
        ``sin(f * x)`` and ``cos(f * x)``; bands are powers of two when
        ``log_sampling`` is set, otherwise linearly spaced between ``2**0``
        and ``2**(num_encoding_functions - 1)``.
      * gaussian (``gaussian_pe=True``): the input is projected with a fixed
        random matrix scaled by ``gaussian_variance`` and ``sin`` / ``cos`` of
        the projection are appended.

    ``normalize`` scales each band's output by ``1 / f``; it only applies to
    the frequency-band variant. ``input_dim`` is only used to size the
    gaussian projection.
    '''

    def __init__(self, num_encoding_functions=6, include_input=True, log_sampling=True, normalize=False,
                 input_dim=3, gaussian_pe=False, gaussian_variance=38):
        super().__init__()
        self.num_encoding_functions = num_encoding_functions
        self.include_input = include_input
        self.log_sampling = log_sampling
        self.normalize = normalize
        self.gaussian_pe = gaussian_pe
        self.normalization = None

        if self.gaussian_pe:
            # this needs to be registered as a parameter so that it is saved in the model state dict
            # and so that it is converted using .cuda(). Doesn't need to be trained though
            self.gaussian_weights = nn.Parameter(gaussian_variance * torch.randn(num_encoding_functions, input_dim),
                                                 requires_grad=False)
        else:
            self.frequency_bands = None
            if self.log_sampling:
                self.frequency_bands = 2.0 ** torch.linspace(
                    0.0,
                    self.num_encoding_functions - 1,
                    self.num_encoding_functions)
            else:
                self.frequency_bands = torch.linspace(
                    2.0 ** 0.0,
                    2.0 ** (self.num_encoding_functions - 1),
                    self.num_encoding_functions)

            if normalize:
                # The division already yields a new tensor; wrapping it in
                # torch.tensor() only triggered the copy-construct UserWarning.
                self.normalization = 1.0 / self.frequency_bands

    def forward(self, tensor) -> torch.Tensor:
        r"""Apply positional encoding to the input.

        Args:
            tensor (torch.Tensor): Input tensor to be positionally encoded.
                The encoding is applied along the last dimension.

        Returns:
            (torch.Tensor): The encoded pieces concatenated along the last
            dimension (the raw input first when ``include_input`` is set).
            If the input is the only piece, it is returned unchanged.
        """

        encoding = [tensor] if self.include_input else []
        if self.gaussian_pe:
            for func in [torch.sin, torch.cos]:
                encoding.append(func(torch.matmul(tensor, self.gaussian_weights.T)))
        else:
            for idx, freq in enumerate(self.frequency_bands):
                for func in [torch.sin, torch.cos]:
                    if self.normalization is not None:
                        encoding.append(self.normalization[idx]*func(tensor * freq))
                    else:
                        encoding.append(func(tensor * freq))

        # Special case, for no positional encoding
        if len(encoding) == 1:
            return encoding[0]
        else:
            return torch.cat(encoding, dim=-1)
class ImplicitAdaptivePatchNet(nn.Module):
    '''Two-stage coordinate network for 2D patches.

    A coarse network maps each block coordinate (optionally positionally
    encoded) to a feature grid of shape ``feature_grid_size``
    (channels, dx, dy). The grid is sampled at the fine, block-relative
    coordinates and a small second network maps every sampled feature vector
    to ``out_features`` values.
    '''

    def __init__(self, in_features=3, out_features=1, feature_grid_size=(8, 8, 8),
                 hidden_features=256, num_hidden_layers=3, patch_size=8,
                 code_dim=8, use_pe=True, num_encoding_functions=6, **kwargs):
        '''
        patch_size may be a (height, width) pair or a single int, which is
        treated as a square patch. code_dim and extra keyword arguments are
        accepted but not used here.
        '''
        super().__init__()

        # forward() indexes patch_size[0] and patch_size[1]; a plain int (the
        # default) would raise TypeError there, so promote it to a pair.
        if isinstance(patch_size, (int, np.integer)):
            patch_size = (patch_size, patch_size)

        self.in_features = in_features
        self.out_features = out_features
        self.feature_grid_size = feature_grid_size
        self.patch_size = patch_size
        self.use_pe = use_pe

        if self.use_pe:
            self.positional_encoding = PositionalEncoding(num_encoding_functions=num_encoding_functions)
            # input itself plus sin and cos per encoding function
            in_features = 2*in_features*num_encoding_functions + in_features

        self.coord2features_net = FCBlock(in_features=in_features, out_features=np.prod(feature_grid_size),
                                          num_hidden_layers=num_hidden_layers, hidden_features=hidden_features,
                                          outermost_linear=True, nonlinearity='relu')

        self.features2sample_net = FCBlock(in_features=self.feature_grid_size[0], out_features=out_features,
                                           num_hidden_layers=1, hidden_features=64,
                                           outermost_linear=True, nonlinearity='relu')
        print(self)

    def forward(self, model_input):
        '''Evaluate all blocks.

        Reads ``model_input['coords']`` (coarse block coordinates) and
        ``model_input['fine_rel_coords']`` (sample positions inside each
        block). Returns a dict with the inputs actually used under
        ``'model_in'`` and the prediction under ``'model_out'['output']``.
        '''

        # Enables us to compute gradients w.r.t. coordinates
        coords = model_input['coords'].clone().detach().requires_grad_(True)
        fine_coords = model_input['fine_rel_coords'].clone().detach().requires_grad_(True)

        if self.use_pe:
            coords = self.positional_encoding(coords)

        features = self.coord2features_net(coords)

        # features is size (Batch Size, Blocks, prod(feature_grid_size))
        # but currently interpolate bilinear only supports one batch dimension,
        # therefore, for now assume that Batch Size == 1
        assert features.shape[0] == 1, 'Code currently only supports Batch Size == 1'

        n_channels, dx, dy = self.feature_grid_size
        features = features.squeeze(0)
        b_size = features.shape[0]

        # blocks take the role of the batch dimension for grid_sample
        features_in = features.squeeze().reshape(b_size, n_channels, dx, dy)
        sample_coords_out = fine_coords[0, ...].reshape(1, -1, 2)
        sample_coords = sample_coords_out.reshape(b_size, self.patch_size[0], self.patch_size[1], 2)

        # split and re-join the two coordinate channels (order is unchanged)
        y = sample_coords[..., :1]
        x = sample_coords[..., 1:]
        sample_coords = torch.cat([y, x], dim=-1)

        features_out = torch.nn.functional.grid_sample(features_in, sample_coords,
                                                       mode='bilinear',
                                                       padding_mode='border',
                                                       align_corners=True).reshape(b_size, n_channels, np.prod(self.patch_size))

        # permute from (Blocks, feature_grid_size[0], patch_size**2)->(Blocks, patch_size**2, feature_grid_size[0])
        # so the network maps features to function output
        features_out = features_out.permute(0, 2, 1)

        # for all spatial feature vectors, extract function value
        patch_out = self.features2sample_net(features_out)

        # restore batch dimension
        patch_out = patch_out.unsqueeze(0)

        return {'model_in': {'sample_coords_out': sample_coords_out, 'model_in_coarse': coords},
                'model_out': {'output': patch_out, 'codes': None}}
def __init__(self, in_features=4, out_features=1, feature_grid_size=(4, 16, 16, 16), 236 | hidden_features=256, num_hidden_layers=3, octant_size=8, 237 | code_dim=8, use_pe=True, num_encoding_functions=6): 238 | super().__init__() 239 | self.in_features = in_features 240 | self.out_features = out_features 241 | self.feature_grid_size = feature_grid_size 242 | self.octant_size = octant_size 243 | self.use_pe = use_pe 244 | 245 | if self.use_pe: 246 | self.positional_encoding = PositionalEncoding(num_encoding_functions=num_encoding_functions) 247 | in_features = 2*in_features*num_encoding_functions + in_features 248 | 249 | self.coord2features_net = FCBlock(in_features=in_features, out_features=np.prod(feature_grid_size), 250 | num_hidden_layers=num_hidden_layers, hidden_features=hidden_features, 251 | outermost_linear=True, nonlinearity='relu') 252 | 253 | self.features2sample_net = FCBlock(in_features=feature_grid_size[0], out_features=out_features, 254 | num_hidden_layers=1, hidden_features=64, 255 | outermost_linear=True, nonlinearity='relu') 256 | 257 | def forward(self, model_input, oversample=1.0): 258 | 259 | # Enables us to compute gradients w.r.t. 
class ImplicitNet(nn.Module):
    '''Single fully connected coordinate network.

    mode 'pe' feeds positionally encoded coordinates to a ReLU network,
    mode 'siren' feeds raw coordinates to a sine network.
    '''

    def __init__(self, sidelength, out_features=1, in_features=2,
                 mode='pe', hidden_features=256, num_hidden_layers=3, w0=30, **kwargs):

        super().__init__()
        self.mode = mode

        if self.mode == 'siren':
            activation = 'sine'
        elif self.mode == 'pe':
            activation = 'relu'

            # number of octaves derived from the largest side length
            nyquist_rate = 1 / (2 * (2 * 1/np.max(sidelength)))
            n_bands = int(math.floor(math.log(nyquist_rate, 2)))

            self.positional_encoding = PositionalEncoding(num_encoding_functions=n_bands)
            # raw input plus sin and cos for every band
            in_features = 2*in_features*n_bands + in_features
        else:
            raise NotImplementedError(f'mode=={self.mode} not implemented')

        self.net = FCBlock(in_features=in_features,
                           out_features=out_features,
                           num_hidden_layers=num_hidden_layers,
                           hidden_features=hidden_features,
                           outermost_linear=True,
                           nonlinearity=activation,
                           w0=w0)
        print(self)

    def forward(self, model_input):
        '''Evaluate the network on the first two channels of
        model_input['fine_abs_coords']; returns the (possibly encoded)
        coordinates and the network output.'''

        net_in = model_input['fine_abs_coords'][..., :2]
        if self.mode == 'pe':
            net_in = self.positional_encoding(net_in)

        return {'model_in': {'coords': net_in},
                'model_out': {'output': self.net(net_in)}}
def train(model, train_dataloader, epochs, lr, steps_til_summary, epochs_til_checkpoint, model_dir,
          loss_fn, pruning_fn, summary_fn, double_precision=False, clip_grad=False,
          loss_schedules=None, resume_checkpoint={}, objs_to_save={}, epochs_til_pruning=4):
    '''Run the training loop with Adam, periodic checkpoints, pruning and summaries.

    parameters
        model, train_dataloader: network and data source; batches are
            (model_input, gt) dict pairs whose tensors are moved to cuda
        epochs, lr: number of epochs and learning rate
        steps_til_summary: call summary_fn every this many steps
        epochs_til_checkpoint: save a checkpoint every this many epochs
        model_dir: output directory; 'summaries' and 'checkpoints' are
            created inside it
        loss_fn: called as loss_fn(model_output, gt, total_steps, retile=...)
            and must return a dict of named losses
        pruning_fn: called as pruning_fn(model, train_dataloader.dataset)
        summary_fn: called as summary_fn(model, model_input, gt, model_output,
            writer, total_steps)
        double_precision: convert every input and gt value with .double()
        clip_grad: False, True (max_norm=1) or a number used as max_norm
        loss_schedules: optional dict name -> callable(step) giving a weight
        resume_checkpoint: optional dict with 'optimizer_state_dict',
            'total_steps' and/or 'epoch' to resume from
        objs_to_save: extra entries stored in the optimizer checkpoint
        epochs_til_pruning: prune every this many epochs

    NOTE(review): resume_checkpoint and objs_to_save use mutable default
    arguments; they are only read here.
    '''
    optim = torch.optim.Adam(lr=lr, params=model.parameters())

    # load optimizer if supplied
    if 'optimizer_state_dict' in resume_checkpoint:
        optim.load_state_dict(resume_checkpoint['optimizer_state_dict'])

        # the learning rate passed to this call wins over the stored one
        for g in optim.param_groups:
            g['lr'] = lr

    # Ask before deleting the results of a previous run. Any answer other
    # than 'y' keeps the old directories and training continues into them.
    if os.path.exists(os.path.join(model_dir, 'summaries')):
        val = input("The model directory %s exists. Overwrite? (y/n)" % model_dir)
        if val == 'y':
            if os.path.exists(os.path.join(model_dir, 'summaries')):
                shutil.rmtree(os.path.join(model_dir, 'summaries'))
            if os.path.exists(os.path.join(model_dir, 'checkpoints')):
                shutil.rmtree(os.path.join(model_dir, 'checkpoints'))

    os.makedirs(model_dir, exist_ok=True)

    summaries_dir = os.path.join(model_dir, 'summaries')
    utils.cond_mkdir(summaries_dir)

    checkpoints_dir = os.path.join(model_dir, 'checkpoints')
    utils.cond_mkdir(checkpoints_dir)

    writer = SummaryWriter(summaries_dir)
    # step and epoch counters continue from the checkpoint when resuming
    total_steps = 0
    if 'total_steps' in resume_checkpoint:
        total_steps = resume_checkpoint['total_steps']

    start_epoch = 0
    if 'epoch' in resume_checkpoint:
        start_epoch = resume_checkpoint['epoch']

    with tqdm(total=len(train_dataloader) * epochs) as pbar:
        pbar.update(total_steps)
        train_losses = []
        for epoch in range(start_epoch, epochs):
            # periodic checkpoint (never at epoch 0): weights, loss history
            # and optimizer state are written to three separate files
            if not epoch % epochs_til_checkpoint and epoch:
                torch.save(model.state_dict(),
                           os.path.join(checkpoints_dir, 'model_%06d.pth' % total_steps))
                np.savetxt(os.path.join(checkpoints_dir, 'train_losses_%06d.txt' % total_steps),
                           np.array(train_losses))
                save_dict = {'epoch': epoch,
                             'total_steps': total_steps,
                             'optimizer_state_dict': optim.state_dict()}
                save_dict.update(objs_to_save)
                torch.save(save_dict, os.path.join(checkpoints_dir, 'optim_%06d.pth' % total_steps))

            # prune
            if not epoch % epochs_til_pruning and epoch:
                pruning_fn(model, train_dataloader.dataset)

            # retiling is switched off during the epoch right before a pruning epoch
            if not (epoch + 1) % epochs_til_pruning:
                retile = False
            else:
                retile = True

            for step, (model_input, gt) in enumerate(train_dataloader):
                start_time = time.time()

                # move tensors to the GPU, pass other values through unchanged
                tmp = {}
                for key, value in model_input.items():
                    if isinstance(value, torch.Tensor):
                        tmp.update({key: value.cuda()})
                    else:
                        tmp.update({key: value})
                model_input = tmp

                tmp = {}
                for key, value in gt.items():
                    if isinstance(value, torch.Tensor):
                        tmp.update({key: value.cuda()})
                    else:
                        tmp.update({key: value})
                gt = tmp

                # NOTE(review): .double() is called on every value, so this
                # path presumably requires all values to be tensors - confirm
                if double_precision:
                    model_input = {key: value.double() for key, value in model_input.items()}
                    gt = {key: value.double() for key, value in gt.items()}

                model_output = model(model_input)
                losses = loss_fn(model_output, gt, total_steps, retile=retile)

                # sum the (optionally scheduled) loss terms and log each of them
                train_loss = 0.
                for loss_name, loss in losses.items():
                    single_loss = loss.mean()

                    if loss_schedules is not None and loss_name in loss_schedules:
                        writer.add_scalar(loss_name + "_weight", loss_schedules[loss_name](total_steps), total_steps)
                        single_loss *= loss_schedules[loss_name](total_steps)

                    writer.add_scalar(loss_name, single_loss, total_steps)
                    train_loss += single_loss

                train_losses.append(train_loss.item())
                writer.add_scalar("total_train_loss", train_loss, total_steps)

                optim.zero_grad()
                train_loss.backward()

                # clip_grad=True clips to norm 1, any other truthy value is the norm itself
                if clip_grad:
                    if isinstance(clip_grad, bool):
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)

                optim.step()

                pbar.update(1)

                if not total_steps % steps_til_summary:
                    tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))
                    summary_fn(model, model_input, gt, model_output, writer, total_steps)

                total_steps += 1

            # after epoch
            tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))

        # save model at end of epoch
        # NOTE(review): 'epoch' below is the last loop value; it is undefined
        # if the epoch loop never ran (start_epoch >= epochs) - confirm callers
        torch.save(model.state_dict(),
                   os.path.join(checkpoints_dir, 'model_final_%06d.pth' % total_steps))
        np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final_%06d.txt' % total_steps),
                   np.array(train_losses))
        save_dict = {'epoch': epoch,
                     'total_steps': total_steps,
                     'optimizer_state_dict': optim.state_dict()}
        save_dict.update(objs_to_save)
        torch.save(save_dict, os.path.join(checkpoints_dir, 'optim_final_%06d.pth' % total_steps))
def write_occupancy_multiscale_summary(image_resolution, dataset, model, model_input, gt,
                                       model_output, writer, total_steps, prefix='train_',
                                       output_mrc='test.mrc', skip=False,
                                       oversample=1.0, max_chunk_size=1024, mode='binary'):
    """Evaluate the model on the dataset's eval samples and dump an occupancy volume.

    The incoming ``model_input`` is discarded and replaced by
    ``dataset.get_eval_samples(oversample)``. Predictions are scattered into a
    volume of shape ``image_resolution`` (initialised to -1) and, depending on
    ``mode``, written to ``output_mrc``: ``'hq'`` stores the raw predictions
    with ``mrc_mode=2``, ``'binary'`` stores ``prediction > 0``. Any other
    ``mode`` writes no file. When ``writer`` is not None the dataset's octtree
    figure is logged under ``prefix + 'tiling'``.

    Returns:
        The filled occupancy volume as a float32 numpy array, or ``None`` when
        ``skip`` is set.
    """
    if skip:
        return

    eval_samples = dataset.get_eval_samples(oversample)

    print("Summary: Write occupancy multiscale summary...")

    # add a leading batch dimension to every tensor; other values pass through
    # (no device transfer happens here)
    model_input = {
        name: (entry[None, ...] if isinstance(entry, torch.Tensor) else entry)
        for name, entry in eval_samples.items()
    }

    print("Summary: processing...")
    chunked_result = process_batch_in_chunks(model_input, model, max_chunk_size=max_chunk_size)
    pred_occupancy = chunked_result['model_out']['output']

    # map each coordinate to a voxel index
    # NOTE(review): presumably coordinates lie in [-1, 1) — confirm against the dataset
    coords = model_input['fine_abs_coords'].detach().cpu().numpy()
    grid_extent = dataset.sidelength[0] * oversample
    voxel_idx = np.floor((coords + 1.) / 2. * grid_extent).astype(np.int32).reshape(-1, 3)
    vx, vy, vz = voxel_idx.T

    # fresh occupancy volume; cells that receive no prediction keep the value -1
    display_occupancy = np.full(image_resolution, -1, dtype=np.float32)

    # scatter the predicted occupancy values into the volume
    occupancy_values = pred_occupancy.reshape(-1, 1).detach().cpu().numpy()[..., 0]
    display_occupancy[vx, vy, vz] = occupancy_values

    print(f"Summary: write MRC file {image_resolution}")
    if mode == 'hq':
        print("\tWriting float")
        with mrcfile.new_mmap(output_mrc, overwrite=True, shape=image_resolution, mrc_mode=2) as mrc:
            mrc.data[vx, vy, vz] = occupancy_values
    elif mode == 'binary':
        print("\tWriting binary")
        with mrcfile.new_mmap(output_mrc, overwrite=True, shape=image_resolution) as mrc:
            mrc.data[vx, vy, vz] = occupancy_values > 0

    if writer is not None:
        print("Summary: Draw octtree")
        tiling_fig = dataset.octtree.draw()
        writer.add_figure(prefix + 'tiling', tiling_fig, global_step=total_steps)

    return display_occupancy
113 | else: 114 | tmp.update({key: value}) 115 | gt = tmp 116 | 117 | # run the model on uniform samples 118 | n_channels = gt['img'].shape[-1] 119 | pred_img = process_batch_in_chunks(model_input, model)['model_out']['output'] 120 | 121 | # get pixel idx for each coordinate 122 | coords = model_input['fine_abs_coords'].detach().cpu().numpy() 123 | pixel_idx = np.zeros_like(coords).astype(np.int32) 124 | pixel_idx[..., 0] = np.round((coords[..., 0] + 1.)/2. * (dataset.sidelength[0]-1)).astype(np.int32) 125 | pixel_idx[..., 1] = np.round((coords[..., 1] + 1.)/2. * (dataset.sidelength[1]-1)).astype(np.int32) 126 | pixel_idx = pixel_idx.reshape(-1, 2) 127 | 128 | # get pixel idx for each coordinate in frozen patches 129 | frozen_coords, frozen_values = dataset.get_frozen_patches() 130 | if frozen_coords is not None: 131 | frozen_coords = frozen_coords.detach().cpu().numpy() 132 | frozen_pixel_idx = np.zeros_like(frozen_coords).astype(np.int32) 133 | frozen_pixel_idx[..., 0] = np.round((frozen_coords[..., 0] + 1.) / 2. * (dataset.sidelength[0] - 1)).astype(np.int32) 134 | frozen_pixel_idx[..., 1] = np.round((frozen_coords[..., 1] + 1.) / 2. * (dataset.sidelength[1] - 1)).astype(np.int32) 135 | frozen_pixel_idx = frozen_pixel_idx.reshape(-1, 2) 136 | 137 | # init a new reconstructed image 138 | display_pred = np.zeros((*dataset.sidelength, n_channels)) 139 | 140 | # assign predicted image values into a new array 141 | # need to use numpy since it supports index assignment 142 | pred_img = pred_img.reshape(-1, n_channels).detach().cpu().numpy() 143 | display_pred[[pixel_idx[:, 0]], [pixel_idx[:, 1]]] = pred_img 144 | 145 | # assign frozen image values into the array too 146 | if frozen_coords is not None: 147 | frozen_values = frozen_values.reshape(-1, n_channels).detach().cpu().numpy() 148 | display_pred[[frozen_pixel_idx[:, 0]], [frozen_pixel_idx[:, 1]]] = frozen_values 149 | 150 | # show reconstructed img 151 | display_pred = torch.tensor(display_pred)[None, ...] 
def _map_tensors(a_dict, tensor_fn):
    """Return a copy of ``a_dict`` with ``tensor_fn`` applied to every tensor value.

    Nested dicts are processed recursively; every other value is passed
    through unchanged.
    """
    mapped = {}
    for key, value in a_dict.items():
        if isinstance(value, torch.Tensor):
            mapped[key] = tensor_fn(value)
        elif isinstance(value, dict):
            mapped[key] = _map_tensors(value, tensor_fn)
        else:
            mapped[key] = value
    return mapped


def dict2cuda(a_dict):
    """Return a copy of ``a_dict`` whose tensors have been moved with ``.cuda()``.

    Tensors inside nested dicts are moved as well, matching ``dict2cpu``.
    Non-tensor values are kept as they are.
    """
    return _map_tensors(a_dict, lambda tensor: tensor.cuda())


def dict2cpu(a_dict):
    """Return a copy of ``a_dict`` whose tensors have been moved with ``.cpu()``.

    Tensors inside nested dicts are moved as well. Non-tensor values are kept
    as they are.
    """
    return _map_tensors(a_dict, lambda tensor: tensor.cpu())
def subsample_dict(in_dict, num_views, multiscale=False):
    """Truncate every value of ``in_dict`` to its leading entries.

    Args:
        in_dict: Mapping from names to sliceable array-like values.
        num_views: Number of leading entries to keep. With ``multiscale`` it
            is indexed as ``num_views[0]`` and ``num_views[1]``.
        multiscale: When set, a value whose first dimension matches that of
            ``in_dict['octant_coords']`` is cut to ``num_views[0]`` entries;
            every other value is cut to ``num_views[1]`` entries.

    Returns:
        A new dict with the same keys and the truncated values.
    """
    if not multiscale:
        return {name: entry[0:num_views, ...] for name, entry in in_dict.items()}

    picked = {}
    for name, entry in in_dict.items():
        # values sharing the leading size of 'octant_coords' are arranged by
        # blocks; everything else is arranged by rays
        by_block = entry.shape[0] == in_dict['octant_coords'].shape[0]
        limit = num_views[0] if by_block else num_views[1]
        picked[name] = entry[0:limit]
    return picked