├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── pluto.jpg
├── data_structs.py
├── dataio.py
├── environment.yml
├── experiment_scripts
│   ├── config_img
│   │   ├── config_mars.ini
│   │   ├── config_pluto_acorn_1k.ini
│   │   ├── config_pluto_acorn_4k.ini
│   │   ├── config_pluto_acorn_8k.ini
│   │   ├── config_pluto_pe_8k.ini
│   │   ├── config_pluto_siren_8k.ini
│   │   └── config_tokyo.ini
│   ├── config_occupancy
│   │   ├── config_dragon_acorn.ini
│   │   ├── config_engine_acorn.ini
│   │   ├── config_lucy_acorn.ini
│   │   ├── config_lucy_small_acorn.ini
│   │   └── config_thai_acorn.ini
│   ├── train_img.py
│   └── train_occupancy.py
├── img
│   └── teaser.png
├── inside_mesh
│   ├── inside_mesh.py
│   ├── setup.py
│   └── triangle_hash.pyx
├── loss_functions.py
├── metrics.py
├── modules.py
├── pruning_functions.py
├── training.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | inside_mesh/
3 | logs/
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Stanford Computational Imaging Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ACORN: Adaptive Coordinate Networks for Neural Scene Representation | SIGGRAPH 2021
2 | ### [Project Page](http://www.computationalimaging.org/publications/acorn/) | [Video](https://www.youtube.com/watch?v=P192X3J6cg4) | [Paper](https://arxiv.org/abs/2105.02788)
3 | PyTorch implementation of ACORN.
4 | [ACORN: Adaptive Coordinate Networks for Neural Scene Representation](http://www.computationalimaging.org/publications/acorn/)
5 | [Julien N. P. Martel](http://web.stanford.edu/~jnmartel/)\*,
6 | [David B. Lindell](https://davidlindell.com)\*,
7 | [Connor Z. Lin](https://connorzlin.com/),
8 | [Eric R. Chan](https://ericryanchan.github.io/about.html),
9 | [Marco Monteiro](https://twitter.com/monteiroamarco),
10 | [Gordon Wetzstein](https://computationalimaging.org)
11 | Stanford University
12 | \*denotes equal contribution
13 | in SIGGRAPH 2021
14 |
15 |
16 |
17 | ## Quickstart
18 |
19 | To set up a conda environment, download example training data, begin training, and launch TensorBoard, run the commands below. You will also need to [register for and install an academic license](https://www.gurobi.com/downloads/free-academic-license/) for the Gurobi optimizer (free for academic use).
20 | ```
21 | conda env create -f environment.yml
22 | # before proceeding, install Gurobi optimizer license (see above web link)
23 | conda activate acorn
24 | cd inside_mesh
25 | python setup.py build_ext --inplace
26 | cd ../experiment_scripts
27 | python train_img.py --config ./config_img/config_pluto_acorn_1k.ini
28 | tensorboard --logdir=../logs --port=6006
29 | ```
30 |
31 | This example will fit a 1 megapixel image of Pluto. You can monitor training in your browser at `localhost:6006`.
32 |
33 | ### Adaptive Coordinate Networks
34 |
35 | An adaptive coordinate network learns an adaptive decomposition of the signal domain, allowing the network to fit signals faster and more accurately. We demonstrate using ACORN to fit large-scale images and detailed 3D occupancy fields.
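
A toy sketch of the idea (hypothetical names, not this repository's code): patches whose fit error stays high are subdivided, while easy regions remain coarse. In ACORN these decisions are made jointly by an integer linear program over a quadtree or octree (see `data_structs.py`):

```
import numpy as np

def refine(x0, y0, size, img, tol, patches):
    # split a square patch while a stand-in error (here: pixel variance) is high
    block = img[y0:y0 + size, x0:x0 + size]
    if size > 2 and block.var() > tol:
        half = size // 2
        for dy in (0, half):
            for dx in (0, half):
                refine(x0 + dx, y0 + dy, half, img, tol, patches)
    else:
        patches.append((x0, y0, size))  # keep this patch at its current scale

img = np.random.rand(64, 64)
patches = []
refine(0, 0, 64, img, tol=0.05, patches=patches)
print(len(patches), "patches cover the image")
```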
36 |
37 | #### Datasets
38 |
39 | Image and 3D model datasets should be downloaded and placed in the `data` directory. The datasets used in the paper can be accessed as follows.
40 |
41 | - Public domain image of Pluto is included in the repository *(NASA/Johns Hopkins University Applied Physics Laboratory/Southwest Research Institute/Alex Parker)*
42 | - [Gigapixel image of Tokyo](https://drive.google.com/file/d/1ITWSv8KcZ_HPNrCXbbbkwzXDSDMr7ACg/view?usp=sharing) *(Trevor Dobson [CC BY-NC-ND 2.0](https://creativecommons.org/licenses/by-nc-nd/2.0/) image resized from [original](https://www.flickr.com/photos/trevor_dobson_inefekt69/29314390837))*
43 | - Public domain [Gigapixel image of Mars](https://drive.google.com/file/d/1Ro1lWxRsl97Jbzm9EA2k9nUEyyVUwxEu/view?usp=sharing) *(NASA/JPL-Caltech/MSSS)*
44 | - Blender engine model [.obj](https://drive.google.com/file/d/1NU2I1Vly6X7YZWD1z_JiBx67XSJ_iR8d/view?usp=sharing), [.blend](https://www.blendswap.com/blend/17636) *(ChrisKuhn [CC-BY](https://creativecommons.org/licenses/by/2.0/))*
45 | - Lucy dataset [.ply](http://graphics.stanford.edu/data/3Dscanrep/) (Stanford 3D Scanning Repository)
46 | - Thai Statue dataset [.ply](http://graphics.stanford.edu/data/3Dscanrep/) (Stanford 3D Scanning Repository)
47 | - Dragon dataset ([TurboSquid](https://www.turbosquid.com/3d-models/chinese-printing-3d-model-1548953))
48 |
49 | #### Training
50 |
51 | To use ACORN, first set up the conda environment and build the Cython extension with
52 | ```
53 | conda env create -f environment.yml
54 | conda activate acorn
55 | cd inside_mesh
56 | python setup.py build_ext --inplace
57 | ```
58 |
59 | Then, download the datasets to the `data` folder.
60 |
61 | We use Gurobi to solve the integer linear program used in the optimization. A free academic license can be installed from [this link](https://www.gurobi.com/downloads/free-academic-license/).
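
As a rough sketch of what this program looks like (illustrative costs and budget, not the real model in `data_structs.py`), each active patch gets binary indicators for merge/keep/split, exactly one of which is chosen, subject to a global budget on the number of patches:

```
import gurobipy as gp
from gurobipy import GRB

m = gp.Model()
m.setParam('OutputFlag', 0)

# per-patch costs (merge, keep, split) -- made-up numbers for illustration
costs = {'patch_a': (0.9, 0.5, 0.1), 'patch_b': (0.2, 0.3, 0.4)}
budget = 5

total_cost, total_count = [], []
for name, (c_merge, c_none, c_split) in costs.items():
    I_grp = m.addVar(vtype=GRB.BINARY, name=name + '_merge')
    I_none = m.addVar(vtype=GRB.BINARY, name=name + '_none')
    I_split = m.addVar(vtype=GRB.BINARY, name=name + '_split')
    m.addConstr(I_grp + I_none + I_split == 1)       # exactly one action
    total_cost.append(c_merge * I_grp + c_none * I_none + c_split * I_split)
    # a split quadruples the patch; a merge collapses 4 siblings into 1
    total_count.append(0.25 * I_grp + 1 * I_none + 4 * I_split)

m.addConstr(gp.quicksum(total_count) <= budget)      # global "knapsack" constraint
m.setObjective(gp.quicksum(total_cost), GRB.MINIMIZE)
m.optimize()
print([v.VarName for v in m.getVars() if v.X > 0.5])  # chosen actions
```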
62 |
63 | To train image representations, use the config files in the `experiment_scripts/config_img` folder. For example, to train on the Pluto image, run the following
64 | ```
65 | python train_img.py --config ./config_img/config_pluto_acorn_1k.ini
66 | tensorboard --logdir=../logs/ --port=6006
67 | ```
68 |
69 | After the image representation has been trained, the decomposition and images can be exported using the following command.
70 |
71 | ```
72 | python train_img.py --config ../logs/<experiment_name>/config.ini --resume ../logs/<experiment_name> --eval
73 | ```
74 |
75 | Exported images will appear in the `../logs/<experiment_name>/eval` folder, where `<experiment_name>` is the subdirectory of the `logs` folder corresponding to the particular training run.
76 |
77 | To train 3D models, download the datasets, and then use the corresponding config file in `experiment_scripts/config_occupancy`. For example, a small model representing the Lucy statue can be trained with
78 |
79 | ```
80 | python train_occupancy.py --config ./config_occupancy/config_lucy_small_acorn.ini
81 | ```
82 |
83 | Then a mesh of the final model can be exported with
84 | ```
85 | python train_occupancy.py --config ../logs/<experiment_name>/config.ini --load ../logs/<experiment_name> --export
86 | ```
87 |
88 | This will create a `.dae` mesh file in the `../logs/<experiment_name>` folder.
89 |
90 | ## Citation
91 |
92 | ```
93 | @article{martel2021acorn,
94 | title={ACORN: {Adaptive} coordinate networks for neural scene representation},
95 | author={Julien N. P. Martel and David B. Lindell and Connor Z. Lin and Eric R. Chan and Marco Monteiro and Gordon Wetzstein},
96 | journal={ACM Trans. Graph. (SIGGRAPH)},
97 | volume={40},
98 | number={4},
99 | year={2021},
100 | }
101 | ```
102 | ## Acknowledgments
103 |
104 | We include the MIT-licensed `inside_mesh` code in this repo from Lars Mescheder, Michael Oechsle, Michael Niemeyer, Andreas Geiger, and Sebastian Nowozin, originally included in their [Occupancy Networks repository](https://github.com/autonomousvision/occupancy_networks/tree/ddb2908f96de9c0c5a30c093f2a701878ffc1f4a/im2mesh/utils/libmesh).
105 | 
106 |
107 | J.N.P. Martel was supported by a Swiss National Foundation (SNF) Fellowship (P2EZP2 181817). C.Z. Lin was supported by a David Cheriton Stanford Graduate Fellowship. G.W. was supported by an Okawa Research Grant, a Sloan Fellowship, and a PECASE by the ARO. Other funding for the project was provided by NSF (award numbers 1553333 and 1839974).
108 |
109 | ## Errata
110 | - The 3D shape fitting metrics were reported in the paper as calculated using the Chamfer-L1 distance. The metric should have been labeled Chamfer-L2, which is consistent with the implementation in this repository.
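
For reference, a minimal NumPy sketch of the symmetric Chamfer-L2 distance between two point sets (the implementation in `metrics.py` is the one used for the paper):

```
import numpy as np

def chamfer_l2(P, Q):
    # P: (N, 3), Q: (M, 3); pairwise squared Euclidean distances
    d2 = ((P[:, None, :] - Q[None, :, :]) ** 2).sum(-1)
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()
```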
111 |
--------------------------------------------------------------------------------
/data/pluto.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/ACORN/67817daaf20d38840f043e81704e7a6ea5da584e/data/pluto.jpg
--------------------------------------------------------------------------------
/data_structs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import gurobipy as gp
4 | from gurobipy import GRB
5 | import copy
6 | import matplotlib.pyplot as plt
7 | import matplotlib.colors as colors
8 | import matplotlib.patches as patches
9 |
10 |
11 | class QuadTree():
12 | def __init__(self, sidelength, patch_size):
13 | self.sidelength = sidelength
14 |
15 | self.patch_size = patch_size
16 | self.min_patch_size = np.min(patch_size)
17 | self.max_patch_size = np.min(sidelength)
18 | self.aspect_ratio = np.array(patch_size) / np.min(patch_size)
19 |
20 |         # how many levels of the quadtree are there
21 | self.min_quadtree_level = int(np.log2(np.min(self.sidelength) // self.max_patch_size))
22 | self.max_quadtree_level = int(np.log2(np.min(sidelength) // self.min_patch_size))
23 | self.num_scales = self.max_quadtree_level - self.min_quadtree_level + 1
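        # e.g., sidelength=(512, 512) and patch_size=(32, 32) give
        # min_quadtree_level = log2(512/512) = 0, max_quadtree_level =
        # log2(512/32) = 4, and num_scales = 5 (illustrative numbers,
        # not tied to any shipped config)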
24 |
25 | # optimization model
26 | self.optim_model = gp.Model()
27 | self.optim_model.setParam('OutputFlag', 0)
28 | self.c_max_patches = None
29 |
30 | # initialize tree
31 | self.root = self.init_root(self.max_quadtree_level)
32 |
33 | # populate tree nodes with coordinate values, metadata
34 | self.populate_tree()
35 |
36 | def __deepcopy__(self, memo):
37 | deep_copied_obj = QuadTree(self.sidelength, self.patch_size)
38 |
39 | for k, v in self.__dict__.items():
40 | if k in ['optim_model', 'c_max_patches']:
41 | # setattr(deep_copied_obj, k, v)
42 | del(deep_copied_obj.__dict__[k])
43 | else:
44 | setattr(deep_copied_obj, k, copy.deepcopy(v, memo))
45 | return deep_copied_obj
46 |
47 | def __getstate__(self):
48 | state = self.__dict__.copy()
49 | for k, v in self.__dict__.items():
50 | if k in ['optim_model', 'c_max_patches']:
51 | del(state[k])
52 | return state
53 |
54 | def __load__(self, obj):
55 | for k, v in obj.__dict__.items():
56 | if k == 'root':
57 | continue
58 | setattr(self, k, v)
59 | self.root = self.init_root(obj.max_quadtree_level)
60 | self.populate_tree()
61 |
62 | def _load_helper(curr_patch, curr_obj_patch):
63 | curr_patch.__load__(curr_obj_patch)
64 | for child, obj_child in zip(curr_patch.children, curr_obj_patch.children):
65 | _load_helper(child, obj_child)
66 | return
67 | _load_helper(self.root, obj.root)
68 |
69 | def __str__(self, level=0):
70 |
71 | def _str_helper(curr_patch, level):
72 | ret = "\t"*level+repr(curr_patch.active)+"\n"
73 | for child in curr_patch.children:
74 | ret += _str_helper(child, level+1)
75 | return ret
76 | return _str_helper(self.root, 0)
77 |
78 | def populate_tree(self):
79 |
80 | # get block coords for patches at each scale
81 | patch_sizes = []
82 | curr_size = self.max_patch_size
83 | while True:
84 | patch_sizes.append(curr_size)
85 | curr_size //= 2
86 | if curr_size == self.min_patch_size:
87 | patch_sizes.append(curr_size)
88 | break
89 | elif curr_size < self.min_patch_size:
90 | raise ValueError('Patch sizes and resolution are incompatible')
91 |
92 | block_coords = [self.get_block_coords(patch_size=patch_size, include_ends=True) for patch_size in patch_sizes]
93 | block_sizes = [block[1, 1, :] - block[0, 0, :] for block in block_coords]
94 | block_coords = [block[:-1, :-1, :] for block in block_coords]
95 |
96 | # create sampling grids for training
97 | num_samples = self.min_patch_size * self.aspect_ratio
98 | row_posts = torch.linspace(-1, 1, int(self.min_patch_size*self.aspect_ratio[0])+1)[:-1]
99 | col_posts = torch.linspace(-1, 1, int(self.min_patch_size*self.aspect_ratio[1])+1)[:-1]
100 | row_coords, col_coords = torch.meshgrid(row_posts, col_posts)
101 | row_coords = row_coords.flatten()
102 | col_coords = col_coords.flatten()
103 |
104 | # create sampling grids for evaluation
105 | # here we need to sample every pixel within each block
106 | row_posts = [torch.linspace(-1, 1, int(pixel_size*self.aspect_ratio[0])+1)[:-1] for pixel_size in patch_sizes]
107 | col_posts = [torch.linspace(-1, 1, int(pixel_size*self.aspect_ratio[1])+1)[:-1] for pixel_size in patch_sizes]
108 | eval_coords = [torch.meshgrid(row_post, col_post) for row_post, col_post in zip(row_posts, col_posts)]
109 | eval_row_coords = [eval_coord[0].flatten() for eval_coord in eval_coords]
110 | eval_col_coords = [eval_coord[1].flatten() for eval_coord in eval_coords]
111 |
112 | def _populate_tree_helper(patch, idx):
113 | # get block scale idx
114 | scale_idx = len(idx) - (self.min_quadtree_level)
115 |
116 | # do we have patches at this level?
117 | if scale_idx >= 0:
118 |
119 | # set patch parameters
120 | coords = block_coords[scale_idx]
121 | coord_idx = _index_block_coord(idx, coords.shape[0], coord=[0, 0])
122 | patch.block_coord = coords[coord_idx[0], coord_idx[1]]
123 | patch.block_size = block_sizes[scale_idx]
124 | patch.scale = scale_idx
125 | patch.pixel_size = patch_sizes[scale_idx]
126 | patch.num_samples = num_samples
127 | patch.row_coords = row_coords
128 | patch.col_coords = col_coords
129 | patch.eval_row_coords = eval_row_coords[scale_idx]
130 | patch.eval_col_coords = eval_col_coords[scale_idx]
131 |
132 | if not patch.children:
133 | return
134 |
135 | # recurse
136 | for i in range(4):
137 | child = patch.children[i]
138 | _populate_tree_helper(child, [*idx, i])
139 |
140 | return
141 |
142 | # given list of tree idxs in {0,1,2,3}^N, retrieve the block coordinate
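        # e.g., tree_idx=[3, 1] with length=4: child 3 adds (2, 2) at the top
        # level, then child 1 adds (0, 1) at length=2, giving coord [2, 3];
        # children are numbered 0=top-left, 1=top-right, 2=bottom-left,
        # 3=bottom-right, and coord is [row, col]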
143 | def _index_block_coord(tree_idx, length, coord=[0, 0]):
144 | if length == 1:
145 | return coord
146 |
147 | if tree_idx[0] == 0:
148 | pass
149 | elif tree_idx[0] == 1:
150 | coord[1] += length//2
151 | elif tree_idx[0] == 2:
152 | coord[0] += length//2
153 | elif tree_idx[0] == 3:
154 | coord[0] += length//2
155 | coord[1] += length//2
156 | else:
157 | raise ValueError("Unexpected child value, should be 0, 1, 2, or 3")
158 |
159 | return _index_block_coord(tree_idx[1:], length//2, coord)
160 |
161 | # done with setup, now actually populate the tree
162 | _populate_tree_helper(self.root, [])
163 |
164 | def init_root(self, max_level):
165 |
166 | def _init_root_helper(curr_patch, curr_level, max_level, optim_model):
167 | if curr_level == max_level:
168 | return
169 | curr_patch.children = [Patch(optim_model) for _ in range(4)]
170 |
171 | for patch in curr_patch.children:
172 | patch.parent = curr_patch
173 | _init_root_helper(patch, curr_level+1, max_level, optim_model)
174 | return
175 |
176 | # create root node
177 | root = Patch(self.optim_model)
178 | _init_root_helper(root, 0, max_level, self.optim_model)
179 | return root
180 |
181 | def get_block_coords(self, flatten=False, include_ends=False, patch_size=None):
182 |         # use the finest-scale patch by default (mirrors OctTree.get_block_coords)
183 |         patch_size = (patch_size if patch_size is not None else self.min_patch_size) * self.aspect_ratio
184 |
185 | # get size of each block
186 | block_size = (2 / (self.sidelength[0]-1) * patch_size[0], 2 / (self.sidelength[1]-1) * patch_size[1])
187 |
188 | # get block begin/end coordinates
189 | if include_ends:
190 | block_coords_y = torch.arange(-1, 1+block_size[0], block_size[0])
191 | block_coords_x = torch.arange(-1, 1+block_size[1], block_size[1])
192 | else:
193 | block_coords_y = torch.arange(-1, 1, block_size[0])
194 | block_coords_x = torch.arange(-1, 1, block_size[1])
195 |
196 | # repeat for every single block
197 | block_coords = torch.meshgrid(block_coords_y, block_coords_x)
198 | block_coords = torch.stack((block_coords[0], block_coords[1]), dim=-1)
199 | if flatten:
200 | block_coords = block_coords.reshape(-1, 2)
201 |
202 | return block_coords
203 |
204 | def get_patches_at_level(self, level):
205 | # level is the image scale: 0-> coarsest, N->finest
206 | if level == -1:
207 | level = self.max_quadtree_level
208 |
209 | # what quadtree level do our patches start at?
210 | # check input, too
211 | target_level = level + self.min_quadtree_level
212 | assert level <= (self.max_quadtree_level - self.min_quadtree_level), \
213 |             "invalid 'level' input to get_patches_at_level"
214 |
215 | def _get_patches_at_level_helper(curr_patch, curr_level, patches):
216 | if curr_level > target_level:
217 | return
218 |
219 | for patch in curr_patch.children:
220 | _get_patches_at_level_helper(patch, curr_level+1, patches)
221 |
222 | if curr_level == target_level:
223 | patches.append(curr_patch)
224 | return patches
225 |
226 | return _get_patches_at_level_helper(self.root, 0, [])
227 |
228 | def get_frozen_patches(self):
229 |
230 | def _get_frozen_patches_helper(curr_patch, patches):
231 | if curr_patch.frozen and curr_patch.active:
232 | patches.append(curr_patch)
233 | for patch in curr_patch.children:
234 | _get_frozen_patches_helper(patch, patches)
235 | return patches
236 | return _get_frozen_patches_helper(self.root, [])
237 |
238 | def get_active_patches(self, include_frozen_patches=False):
239 |
240 | def _get_active_patches_helper(curr_patch, patches):
241 | if curr_patch.active and \
242 | (include_frozen_patches or
243 | (not include_frozen_patches and not curr_patch.frozen)):
244 | patches.append(curr_patch)
245 | for patch in curr_patch.children:
246 | _get_active_patches_helper(patch, patches)
247 | return patches
248 | return _get_active_patches_helper(self.root, [])
249 |
250 | def activate_random(self):
251 | def _activate_random_helper(curr_patch):
252 | if not curr_patch.children:
253 | curr_patch.activate()
254 | return
255 | elif (curr_patch.scale is not None) and (torch.rand(1).item() < 0.2):
256 | curr_patch.activate()
257 | return
258 |
259 | for patch in curr_patch.children:
260 | _activate_random_helper(patch)
261 | return
262 |
263 | _activate_random_helper(self.root)
264 |
265 | def synchronize(self, master):
266 | # set active/inactive nodes to be the same as master
267 | # for now just toggle the flags without worrying about the gurobi variables
268 | def _synchronize_helper(curr_patch, curr_patch_master):
269 | curr_patch.active = curr_patch_master.active
270 | if not curr_patch.children:
271 | return
272 |
273 | for patch, patch_master in zip(curr_patch.children, curr_patch_master.children):
274 | _synchronize_helper(patch, patch_master)
275 | return
276 |
277 | _synchronize_helper(self.root, master.root)
278 |
279 | def get_frozen_samples(self):
280 | patches = self.get_frozen_patches()
281 | if not patches:
282 | return None, None, None, None
283 |
284 | rel_coords, abs_coords, vals = [], [], []
285 | patch_idx = []
286 |
287 | for idx, p in enumerate(patches):
288 | rel_samp, abs_samp = p.get_stratified_samples(jitter=False, eval=True)
289 |
290 | rel_samp = rel_samp.reshape(-1, int(np.prod(self.min_patch_size * self.aspect_ratio)), 2)
291 | abs_samp = abs_samp.reshape(-1, int(np.prod(self.min_patch_size * self.aspect_ratio)), 2)
292 | patch_idx.extend(rel_samp.shape[0] * [idx, ])
293 |
294 | rel_coords.append(rel_samp)
295 | abs_coords.append(abs_samp)
296 | # values have the same size as rel_samp but last dim is a scalar
297 | vals.append(p.value*torch.ones(abs_samp.shape[:-1] + (1,)))
298 |
299 | return torch.cat(rel_coords, dim=0), torch.cat(abs_coords, dim=0), \
300 | torch.cat(vals, dim=0), patch_idx
301 |
302 | def get_stratified_samples(self, jitter=True, eval=False):
303 | patches = self.get_active_patches()
304 |
305 | rel_coords, abs_coords = [], []
306 | patch_idx = []
307 |
308 | for idx, p in enumerate(patches):
309 | rel_samp, abs_samp = p.get_stratified_samples(jitter=jitter, eval=eval)
310 |
311 | # always batch the coordinates in groups of a specific patch size
312 | # so we can process them in parallel
313 | rel_samp = rel_samp.reshape(-1, int(np.prod(self.min_patch_size * self.aspect_ratio)), 2)
314 | abs_samp = abs_samp.reshape(-1, int(np.prod(self.min_patch_size * self.aspect_ratio)), 2)
315 |
316 | # since patch samples could be split across batches,
317 | # keep track of which batch idx maps to which patch idx
318 | patch_idx.extend(rel_samp.shape[0] * [idx, ])
319 |
320 | rel_coords.append(rel_samp)
321 | abs_coords.append(abs_samp)
322 | return torch.cat(rel_coords, dim=0), torch.cat(abs_coords, dim=0), patch_idx
323 |
324 | def solve_optim(self, max_num_patches=1024):
325 | patches = self.get_active_patches()
326 |
327 | assert (len(patches) <= max_num_patches), \
328 | "You are trying to solve a model which is infeasible: " \
329 | "Number of active patches > Max number of patches"
330 |
331 | if self.c_max_patches is not None:
332 | self.optim_model.remove(self.c_max_patches)
333 |
334 | # global "knapsack" constraint
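        # each patch contributes 4 to the budget if it splits, 1 if it stays,
        # and 1/4 if it merges (four merging siblings collapse into one parent),
        # so this bounds the number of patches after the update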
335 | expr_sum_patches = [p.update_merge() for p in patches]
336 | self.c_max_patches = self.optim_model.addConstr(gp.quicksum(expr_sum_patches) <= max_num_patches)
337 |
338 | # objective
339 | self.optim_model.setObjective(gp.quicksum([p.get_cost() for p in patches]), GRB.MINIMIZE)
340 |         self.optim_model.optimize()
341 | 
342 |         if self.optim_model.Status == GRB.INFEASIBLE:
343 |             print("----------- Model is infeasible")
344 |             self.optim_model.computeIIS()
345 |             self.optim_model.write("model.ilp")
346 | 
347 |         obj_val = self.optim_model.objVal
348 | # split and merge
349 | merged = 0
350 | split = 0
351 | none = 0
352 | for p in patches:
353 | # print(p)
354 | if p.has_split() and p.scale < self.max_quadtree_level:
355 | p.deactivate()
356 | for child in p.get_children():
357 | child.activate()
358 | split += 1
359 | elif p.has_merged() and p.scale >= self.min_quadtree_level and p.scale > 0:
360 | # we first check if it is active,
361 | # since we could have already been activated by a neighbor
362 | if p.active:
363 | for neighbor in p.get_neighbors():
364 | neighbor.deactivate()
365 | p.parent.activate()
366 | merged += 1
367 |
368 | else:
369 | p.update()
370 | none += 1
371 | stats_dict = {'merged': merged,
372 | 'splits': split,
373 | 'none': none,
374 | 'obj': obj_val}
375 | print(f"============================= Total patches:{len(patches)}, split/merge:{split}/{merged}")
376 | return stats_dict
377 |
378 | def draw(self):
379 | fig, ax = plt.subplots(1, figsize=(5, 5))
380 | depth = 1 + self.max_quadtree_level - self.min_quadtree_level
381 | sidelen = 4**(depth-1) // 2**(depth-1)
382 |
383 | # calculate scale
384 | patch_list = self.get_active_patches()
385 | patches_err = [p.err for p in patch_list]
386 |
387 | max_err = np.max(patches_err)
388 | min_err = np.min(patches_err)
389 |
390 | cmap = plt.cm.get_cmap('viridis')
391 |
392 | def _draw_level(patch, curr_level, ax, sidelen, offset, scale):
393 | if curr_level > self.max_quadtree_level:
394 | return ax
395 |
396 | scale = scale/2.
397 |
398 | for i, child in enumerate(patch.children):
399 | if i == 0:
400 | new_offset = (offset[0], offset[1])
401 | elif i == 1:
402 | new_offset = (offset[0] + scale * sidelen, offset[1])
403 | elif i == 2:
404 | new_offset = (offset[0], offset[1] + scale * sidelen)
405 | else:
406 | new_offset = (offset[0] + scale * sidelen, offset[1] + scale * sidelen)
407 |
408 | if child.active:
409 | norm_err = (child.err-min_err)/(max_err-min_err)
410 | if child.frozen:
411 | facecolor = 'white'
412 | edgecolor = 'red'
413 | else:
414 | facecolor = cmap(norm_err)
415 | edgecolor = 'black'
416 | rect = patches.Rectangle(new_offset, scale * sidelen, scale * sidelen, linewidth=1,
417 | edgecolor=edgecolor,
418 | facecolor=facecolor, fill=True)
419 | ax.add_patch(rect)
420 | else:
421 | ax = _draw_level(child, curr_level+1, ax, sidelen, new_offset, scale)
422 |
423 | return ax
424 |
425 | ax = _draw_level(self.root, self.min_quadtree_level, ax, sidelen, (0., 0.), 1.)
426 | ax.set_aspect('equal')
427 | plt.xlim(-1, sidelen + 1)
428 | plt.ylim(-1, sidelen + 1)
429 | plt.gca().invert_yaxis() # we want 0,0 to be on top-left
430 |
431 | norm = colors.Normalize(vmin=min_err, vmax=max_err)
432 | sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
433 | sm.set_array([])
434 | plt.colorbar(sm)
435 | return fig
436 |
437 |
438 | # patch class
439 | class Patch():
440 | def __init__(self, optim_model=None, block_coord=None, scale=None):
441 | self.active = False
442 |
443 | self.parent = None
444 | self.children = []
445 |
446 | # absolute block coordinate
447 | self.block_coord = block_coord
448 |
449 | # size of block in absolute coord frame
450 | self.block_size = None
451 |
452 | # scale level of block
453 | self.scale = scale
454 |
455 | # num samples to be generated for this block
456 | self.num_samples = None
457 |
458 | # num pixels in this patch
459 | self.pixel_size = None
460 |
461 | # optimization model
462 | self.optim = optim_model
463 |
464 | # row/column coords for sampling at test time
465 | # initialized by set_samples() function
466 | self.row_coords = None
467 | self.col_coords = None
468 | self.eval_row_coords = None
469 | self.eval_col_coords = None
470 |
471 | # error for doing nothing, merging, splitting
472 | self.err = 0.
473 | self.last_updated = 0.
474 |
475 | self._nocopy = ['optim', 'I_grp', 'I_split', 'I_none',
476 | 'I_merge', 'c_joinable', 'c_merge_split']
477 | self._pickle_vars = ['parent', 'children', 'active', 'err', 'last_updated']
478 | self.spec_cstrs = []
479 |
480 | # options for pruning
481 | self.frozen = False
482 | self.value = 0.0
483 |
484 | def __str__(self):
485 | str = f"Patch id={id(self)}\n" \
486 | f" . active={self.active}\n" \
487 | f" . level={self.scale}\n" \
488 | f" . model={self.optim}"
489 |
490 | if self.active:
491 | str += f"\n . g={self.I_grp.x}, s={self.I_split.x}, n={self.I_none.x}"
492 |
493 | return str
494 |
495 | # override deep copy to copy undeepcopyable objects by reference
496 | def __deepcopy__(self, memo):
497 | deep_copied_obj = Patch()
498 | for k, v in self.__dict__.items():
499 | if k in self._nocopy:
500 | # setattr(deep_copied_obj, k, None)
501 | if k in deep_copied_obj.__dict__.keys():
502 | del(deep_copied_obj.__dict__[k])
503 | else:
504 | setattr(deep_copied_obj, k, copy.deepcopy(v, memo))
505 |
506 | return deep_copied_obj
507 |
508 | def __getstate__(self):
509 | state = self.__dict__.copy()
510 | for k, v in self.__dict__.items():
511 | if k in self._nocopy:
512 | # if k not in self._pickle_vars:
513 | del(state[k])
514 | return state
515 |
516 | def __load__(self, obj):
517 | for k, v in obj.__dict__.items():
518 | if k in ['children', 'parent']:
519 | continue
520 | setattr(self, k, v)
521 | if self.active:
522 | self.activate()
523 |
524 | def update(self):
525 | self.deactivate()
526 | self.activate()
527 |
528 | def activate(self):
529 | self.active = True
530 |
531 | # indicator variables
532 | self.I_grp = self.optim.addVar(vtype=GRB.BINARY)
533 | self.I_split = self.optim.addVar(vtype=GRB.BINARY)
534 | self.I_none = self.optim.addVar(vtype=GRB.BINARY)
535 |
536 | self.I_merge = gp.LinExpr(0.0)
537 |
538 | # local constraint "merge/none/split"
539 | self.c_joinable = self.optim.addConstr(self.I_grp + self.I_none + self.I_split == 1)
540 |
541 | # local constraint "merge-split"
542 | self.c_merge_split = None
543 |
544 | def deactivate(self):
545 | self.active = False
546 |
547 | self.optim.remove(self.I_grp)
548 | self.optim.remove(self.I_split)
549 | self.optim.remove(self.I_none)
550 |
551 | self.I_merge = gp.LinExpr(0.0)
552 |
553 | self.optim.remove(self.c_joinable)
554 |
555 | if self.c_merge_split is not None:
556 | self.optim.remove(self.c_merge_split)
557 |
558 | for cstr in self.spec_cstrs:
559 | self.optim.remove(cstr)
560 | self.spec_cstrs = []
561 |
562 | def is_mergeable(self):
563 | siblings = self.parent.children
564 |         return np.all([sib.active for sib in siblings])
565 |
566 | def set_sample_params(self, num_samples):
567 | self.num_samples = num_samples
568 | posts = torch.linspace(-1, 1, self.num_samples+1)[:-1]
569 | row_coords, col_coords = torch.meshgrid(posts, posts)
570 | self.row_coords = row_coords.flatten()
571 | self.col_coords = col_coords.flatten()
572 |
573 | def must_split(self):
574 | self.spec_cstrs.append(
575 | self.optim.addConstr(self.I_split == 1)
576 | )
577 |
578 | def must_merge(self):
579 | self.spec_cstrs.append(
580 | self.optim.addConstr(self.I_grp == 1)
581 | )
582 |
583 | def has_split(self):
584 | return self.I_split.x == 1
585 |
586 | def has_merged(self):
587 | return self.I_grp.x == 1
588 | # return self.I_none.x==0 and self.I_split.x==0
589 |
590 | def has_done_nothing(self):
591 | return self.I_none.x == 1
592 |
593 | def get_cost(self):
594 | area = self.block_size[0]**2
595 | alpha = 0.2 # how much worse we expect the error to be when merging
596 | beta = -0.02 # how much better we expect the error to be when splitting
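        # each action's cost is (patch area) x (predicted error after the
        # action); when the parent/children have recorded errors we use those
        # directly, otherwise we extrapolate from this patch's error via
        # alpha and beta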
597 |
598 | # == Merge
599 | if self.scale > 0: # it should never be root, but still..
600 | err_merge = (4+alpha) * area * self.err
601 |
602 | if self.parent.last_updated:
603 | parent_area = self.parent.block_size[0]**2
604 | err_merge = parent_area * self.parent.err # can multiply by 1/4 as in paper to make merging more aggressive
605 | else:
606 | err_merge = self.err
607 |
608 | # == Split
609 | if self.children:
610 | err_split = (0.25+beta) * area * self.err
611 |
612 | if self.children[0].last_updated:
613 | err_children = np.sum([child.err for child in self.children])
614 | err_split = area * err_children
615 | else:
616 | err_split = 1. # in case you don't have children, high to avoid splitting
617 |
618 | # == None
619 | err_none = area * self.err
620 |
621 | return err_none * self.I_none \
622 | + err_split * self.I_split \
623 | + err_merge * self.I_grp
624 |
625 | def update_merge(self):
626 | if self.parent is None: # if root
627 | return gp.LinExpr(0)
628 |
629 | siblings = self.parent.children
630 | if np.all([sib.active for sib in siblings]):
631 | I_grp_neighs = [s.I_grp for s in siblings]
632 | self.I_merge = gp.quicksum(I_grp_neighs)
633 |
634 | # local constraint "joinable"
635 | self.c_merge_split = self.optim.addConstr(self.I_none + self.I_split + .25*self.I_merge == 1)
636 | expr_max_patches = 4 * self.I_split + 1 * self.I_none + .25 * self.I_grp
637 |
638 | return expr_max_patches
639 |
640 | def get_neighbors(self):
641 | return self.parent.children
642 |
643 | def get_children(self):
644 | return self.children
645 |
646 | def get_parent(self):
647 | return self.parent
648 |
649 | def is_joinable(self):
650 | # test if siblings are all leaf nodes
651 | siblings = self.parent.children
652 | return np.all([sib.active for sib in siblings])
653 |
654 | def get_block_coord(self):
655 | return self.block_coord
656 |
657 | def get_scale(self):
658 | return self.scale
659 |
660 | def update_error(self, error, iter):
661 | self.err = error
662 | self.last_updated = iter
663 |
664 | def get_stratified_samples(self, jitter=True, eval=False):
665 | # Block coords are always aligned to the pixel grid,
666 | # e.g., they align with pixels 0, 8, 16, 24, etc. for
667 | # patch size 8
668 | #
669 | # To normalize the coordinates between (-1, 1), consider
670 | # we have an image of 64x64 and patch size 8x8.
671 | # The block coordinate (-1, -1) aligns with pixel (0, 0)
672 | # and coordinate (1, 1) aligns with pixel (63, 63)
673 | #
674 | # Absolute coordinates within a block should stretch all the way
675 | # from the absolute position of one block coordinate to another.
676 | # Say each block contains 8x8 pixels and we use a feature grid
677 | # of 8x8 features to interpolate values within a block.
678 |         # This means that the feature positions are not actually
679 | # aligned to the pixel positions. The features are positioned
680 | # on a grid stretching from one block coord to another whereas
681 | # the pixel grid ends just short of the next block coordinate
682 | #
683 | # Example patch (x = pixel position, B = block coordinate position)
684 | # and relative coordinate positions.
685 | #
686 | # -1 ^ B x x x x x x x B
687 | # | x x x x x x x x x
688 | # | x x x x x x x x x
689 | # | x x x x x x x x x
690 | # | x x x x x x x x x
691 | # | x x x x x x x x x
692 | # | x x x x x x x x x
693 | # | x x x x x x x x x
694 | # 1 v B x x x x x x x B
695 | # <--------------->
696 | # -1 1
697 | #
698 | # When we generate samples for a patch, we sample an
699 | # 8x8 grid that extends between block coords, i.e.
700 | # between the arrows above
701 | #
702 | if eval:
703 | row_coords = self.eval_row_coords.flatten()
704 | col_coords = self.eval_col_coords.flatten()
705 | else:
706 | row_coords = self.row_coords
707 | col_coords = self.col_coords
708 |
709 |         if jitter:
710 |             row_coords = row_coords + torch.rand_like(row_coords) * 2./self.num_samples[0]
711 |             col_coords = col_coords + torch.rand_like(col_coords) * 2./self.num_samples[1]
712 |
713 | rel_samples = torch.stack((row_coords, col_coords), dim=-1)
714 | abs_samples = self.block_coord[None, :] + self.block_size[None, :] * (rel_samples+1)/2
715 | return rel_samples, abs_samples
716 |
717 |
718 | class OctTree():
719 | def __init__(self, sidelength, min_octant_size, bounds=((-1, 1), (-1, 1), (-1, 1)), mesh_kd_tree=None):
720 | self.sidelength = sidelength
721 | self.min_octant_size = min_octant_size
722 | self.max_octant_size = sidelength[0]
723 |
724 |         # how many levels of the octree are there
725 | self.min_octtree_level = int(np.log2(self.sidelength[0] // self.max_octant_size))
726 | self.max_octtree_level = int(np.log2(sidelength[0] // min_octant_size))
727 | self.num_scales = self.max_octtree_level - self.min_octtree_level + 1
728 |
729 | # optimization model
730 | self.optim_model = gp.Model()
731 | self.optim_model.setParam('OutputFlag', 0)
732 | self.c_max_octants = None
733 |
734 | # set bounds
735 | self.z_min, self.z_max = bounds[0]
736 | self.y_min, self.y_max = bounds[1]
737 | self.x_min, self.x_max = bounds[2]
738 |
739 | # KD tree that stores points on mesh surface
740 | self.surface_tree = mesh_kd_tree
741 |
742 | # initialize tree
743 | self.root = self.init_root(self.max_octtree_level)
744 |
745 | # populate tree nodes with coordinate values, metadata
746 | self.populate_tree()
747 |
748 | def __deepcopy__(self, memo):
749 | deep_copied_obj = OctTree(self.sidelength, self.min_octant_size)
750 |
751 | for k, v in self.__dict__.items():
752 | if k in ['optim_model', 'c_max_octants']:
753 | setattr(deep_copied_obj, k, v)
754 | else:
755 | setattr(deep_copied_obj, k, copy.deepcopy(v, memo))
756 | return deep_copied_obj
757 |
758 | def __getstate__(self):
759 | state = self.__dict__.copy()
760 | for k, v in self.__dict__.items():
761 | if k in ['optim_model', 'c_max_octants']:
762 | del(state[k])
763 | return state
764 |
765 | def __load__(self, obj):
766 | for k, v in obj.__dict__.items():
767 | if k == 'root':
768 | continue
769 | setattr(self, k, v)
770 | self.root = self.init_root(obj.max_octtree_level)
771 | self.populate_tree()
772 |
773 | def _load_helper(curr_patch, curr_obj_patch):
774 | curr_patch.__load__(curr_obj_patch)
775 | for child, obj_child in zip(curr_patch.children, curr_obj_patch.children):
776 | _load_helper(child, obj_child)
777 | return
778 | _load_helper(self.root, obj.root)
779 |
780 | def __str__(self, level=0):
781 |
782 | def _str_helper(curr_octant, level):
783 | ret = "\t"*level+repr(curr_octant.active)+"\n"
784 | for child in curr_octant.children:
785 | ret += _str_helper(child, level+1)
786 | return ret
787 | return _str_helper(self.root, 0)
788 |
789 | def populate_tree(self):
790 |
791 | # maximum octant scale
792 | max_octant_scale = int(np.log2(self.max_octant_size))
793 | min_octant_scale = int(np.log2(self.min_octant_size))
794 |
795 | # get block coords for octants at each scale
796 | octant_sizes = [2**s for s in range(min_octant_scale, max_octant_scale+1)]
797 | octant_sizes.reverse()
798 | block_coords = [self.get_block_coords(octant_size=octant_size, include_ends=True) for octant_size in octant_sizes]
799 | block_sizes = [block[1, 1, 1, :] - block[0, 0, 0, :] for block in block_coords]
800 | block_coords = [block[:-1, :-1, :-1, :] for block in block_coords]
801 |
802 | # create sampling grids for training
803 | num_samples = self.min_octant_size
804 | posts = torch.linspace(-1, 1, self.min_octant_size+1)[:-1]
805 | row_coords, col_coords, dep_coords = torch.meshgrid(posts, posts, posts)
806 | row_coords = row_coords.flatten()
807 | col_coords = col_coords.flatten()
808 | dep_coords = dep_coords.flatten()
809 |
810 | # create sampling grids for evaluation
811 | # here we need to sample every voxel within each block
812 | posts = [torch.linspace(-1, 1, voxel_size+1)[:-1] + (1/voxel_size)/2 for voxel_size in octant_sizes]
813 | eval_coords = [torch.meshgrid(post, post, post) for post in posts]
814 | eval_row_coords = [eval_coord[0].flatten() for eval_coord in eval_coords]
815 | eval_col_coords = [eval_coord[1].flatten() for eval_coord in eval_coords]
816 | eval_dep_coords = [eval_coord[2].flatten() for eval_coord in eval_coords]
817 |
818 | def _populate_tree_helper(octant, idx):
819 | # get block scale idx
820 | scale_idx = len(idx) - (self.min_octtree_level)
821 |
822 | # do we have octants at this level?
823 | if scale_idx >= 0:
824 |
825 | # set patch parameters
826 | coords = block_coords[scale_idx]
827 | coord_idx = _index_block_coord(idx, coords.shape[0], coord=[0, 0, 0])
828 | octant.block_coord = coords[coord_idx[0], coord_idx[1], coord_idx[2]]
829 | octant.block_size = block_sizes[scale_idx]
830 | octant.scale = scale_idx
831 | octant.voxel_size = octant_sizes[scale_idx]
832 | octant.num_samples = num_samples
833 | octant.row_coords = row_coords
834 | octant.col_coords = col_coords
835 | octant.dep_coords = dep_coords
836 | octant.eval_row_coords = eval_row_coords[scale_idx]
837 | octant.eval_col_coords = eval_col_coords[scale_idx]
838 | octant.eval_dep_coords = eval_dep_coords[scale_idx]
839 | octant.surface_tree = self.surface_tree
840 |
841 | if not octant.children:
842 | return
843 |
844 | # recurse
845 | for i in range(8):
846 | child = octant.children[i]
847 | _populate_tree_helper(child, [*idx, i])
848 |
849 | return
850 |
851 |         # given list of tree idxs in {0,...,7}^N, retrieve the block coordinate
852 | def _index_block_coord(tree_idx, length, coord=[0, 0, 0]):
853 | if length == 1:
854 | return coord
855 |
856 | # depth 0
857 | if tree_idx[0] == 0:
858 | pass
859 | elif tree_idx[0] == 1:
860 | coord[1] += length//2
861 | elif tree_idx[0] == 2:
862 | coord[0] += length//2
863 | elif tree_idx[0] == 3:
864 | coord[0] += length//2
865 | coord[1] += length//2
866 | # depth 1
867 | elif tree_idx[0] == 4:
868 | coord[2] += length//2
869 | elif tree_idx[0] == 5:
870 | coord[1] += length//2
871 | coord[2] += length//2
872 | elif tree_idx[0] == 6:
873 | coord[0] += length//2
874 | coord[2] += length//2
875 | elif tree_idx[0] == 7:
876 | coord[0] += length//2
877 | coord[1] += length//2
878 | coord[2] += length//2
879 | else:
880 |                 raise ValueError("Unexpected child value, should be in {0, ..., 7}")
881 |
882 | return _index_block_coord(tree_idx[1:], length//2, coord)
883 |
884 | # done with setup, now actually populate the tree
885 | _populate_tree_helper(self.root, [])
886 |
887 | def init_root(self, max_level):
888 |
889 | def _init_root_helper(curr_octant, curr_level, max_level, optim_model):
890 | if curr_level == max_level:
891 | return
892 | curr_octant.children = [Octant(optim_model) for _ in range(8)]
893 |
894 | for octant in curr_octant.children:
895 | octant.parent = curr_octant
896 | _init_root_helper(octant, curr_level+1, max_level, optim_model)
897 | return
898 |
899 | # create root node
900 | root = Octant(self.optim_model)
901 | _init_root_helper(root, 0, max_level, self.optim_model)
902 | return root
903 |
904 | def get_block_coords(self, flatten=False, include_ends=False, octant_size=None):
905 |
906 | # use finest scale patch by default
907 | if octant_size is None:
908 | octant_size = self.min_octant_size # TODO: ?? verify
909 |
910 | # get size of each block
911 | z_len = self.z_max - self.z_min
912 | y_len = self.y_max - self.y_min
913 | x_len = self.x_max - self.x_min
914 |
915 | block_size = (z_len / (self.sidelength[0]) * octant_size,
916 | y_len / (self.sidelength[1]) * octant_size,
917 | x_len / (self.sidelength[2]) * octant_size)
918 |
919 | # get block begin/end coordinates
920 | if include_ends:
921 | block_coords_z = torch.arange(self.z_min, self.z_max + block_size[0], block_size[0])
922 | block_coords_y = torch.arange(self.y_min, self.y_max + block_size[1], block_size[1])
923 | block_coords_x = torch.arange(self.x_min, self.x_max + block_size[2], block_size[2])
924 | else:
925 | block_coords_z = torch.arange(self.z_min, self.z_max, block_size[0])
926 | block_coords_y = torch.arange(self.y_min, self.y_max, block_size[1])
927 | block_coords_x = torch.arange(self.x_min, self.x_max, block_size[2])
928 |
929 | # repeat for every single block
930 | block_coords = torch.meshgrid(block_coords_z, block_coords_y, block_coords_x)
931 | block_coords = torch.stack((block_coords[0], block_coords[1], block_coords[2]), dim=-1)
932 | if flatten:
933 | block_coords = block_coords.reshape(-1, 3)
934 |
935 | return block_coords
936 |
937 | def get_octants_at_level(self, level):
938 | # level is the image scale: 0-> coarsest, N->finest
939 |
940 |         # what octree level do our octants start at?
941 | # check input, too
942 | target_level = level + self.min_octtree_level
943 | assert level <= (self.max_octtree_level - self.min_octtree_level), \
944 |             "invalid 'level' input to get_octants_at_level"
945 |
946 | def _get_octants_at_level_helper(curr_octant, curr_level, octants):
947 | if curr_level > target_level:
948 | return
949 |
950 | for octant in curr_octant.children:
951 | _get_octants_at_level_helper(octant, curr_level+1, octants)
952 |
953 | if curr_level == target_level:
954 | octants.append(curr_octant)
955 | return octants
956 |
957 | return _get_octants_at_level_helper(self.root, 0, [])
958 |
959 | def get_frozen_octants(self):
960 |
961 | def _get_frozen_octants_helper(curr_octant, octants):
962 | if curr_octant.frozen and curr_octant.active:
963 | octants.append(curr_octant)
964 | for octant in curr_octant.children:
965 | _get_frozen_octants_helper(octant, octants)
966 | return octants
967 | return _get_frozen_octants_helper(self.root, [])
968 |
969 | def get_active_octants(self, include_frozen_octants=False):
970 |
971 | def _get_active_octants_helper(curr_octant, octants):
972 | if curr_octant.active and \
973 | (include_frozen_octants or
974 | (not include_frozen_octants and not curr_octant.frozen)):
975 | octants.append(curr_octant)
976 | return octants
977 | for octant in curr_octant.children:
978 | _get_active_octants_helper(octant, octants)
979 | return octants
980 | return _get_active_octants_helper(self.root, [])
981 |
982 | def activate_random(self):
983 | def _activate_random_helper(curr_octant):
984 | if not curr_octant.children:
985 | curr_octant.activate()
986 | return
987 | elif (curr_octant.scale is not None) and (torch.rand(1).item() < 0.2):
988 | curr_octant.activate()
989 | return
990 |
991 | for patch in curr_octant.children:
992 | _activate_random_helper(patch)
993 | return
994 |
995 | _activate_random_helper(self.root)
996 |
997 | def synchronize(self, master):
998 | # set active/inactive nodes to be the same as master
999 | # for now just toggle the flags without worrying about the gurobi variables
1000 | def _synchronize_helper(curr_octant, curr_octant_master):
1001 | curr_octant.active = curr_octant_master.active
1002 | curr_octant.frozen = curr_octant_master.frozen
1003 | curr_octant.value = curr_octant_master.value
1004 | if not curr_octant.children:
1005 | return
1006 |
1007 | for octant, octant_master in zip(curr_octant.children, curr_octant_master.children):
1008 | _synchronize_helper(octant, octant_master)
1009 | return
1010 |
1011 | _synchronize_helper(self.root, master.root)
1012 |
1013 | def get_frozen_samples(self, oversample):
1014 | octants = self.get_frozen_octants()
1015 |
1016 | if not octants:
1017 | return None, None, None, None
1018 |
1019 | rel_coords, abs_coords, vals = [], [], []
1020 | octant_idx = []
1021 |
1022 | for idx, p in enumerate(octants):
1023 | rel_samp, abs_samp, _ = p.get_stratified_samples(jitter=False, eval=True, oversample=oversample)
1024 |
1025 | rel_samp = rel_samp.reshape(-1, self.min_octant_size**3, 3)
1026 | abs_samp = abs_samp.reshape(-1, self.min_octant_size**3, 3)
1027 | octant_idx.extend(rel_samp.shape[0] * [idx, ])
1028 |
1029 | rel_coords.append(rel_samp)
1030 | abs_coords.append(abs_samp)
1031 | # values have the same size as rel_samp but last dim is a scalar
1032 | vals.append(p.value*torch.ones(abs_samp.shape[:-1] + (1,)))
1033 |
1034 | return torch.cat(rel_coords, dim=0), torch.cat(abs_coords, dim=0), \
1035 | torch.cat(vals, dim=0), octant_idx
1036 |
1037 | def get_stratified_samples(self, jitter=True, eval=False, oversample=1):
1038 | octants = self.get_active_octants()
1039 |
1040 | rel_coords, abs_coords = [], []
1041 | all_global_indices = []
1042 | octant_idx = []
1043 |
1044 | for idx, p in enumerate(octants):
1045 |
1046 | rel_samp, abs_samp, global_indices = p.get_stratified_samples(jitter=jitter, eval=eval, oversample=oversample)
1047 |
1048 | # always batch the coordinates in groups of a specific patch size
1049 | # so we can process them in parallel
1050 | rel_samp = rel_samp.reshape(-1, int(self.min_octant_size * oversample)**3, 3)
1051 | abs_samp = abs_samp.reshape(-1, int(self.min_octant_size * oversample)**3, 3)
1052 | if global_indices is not None:
1053 | global_indices = global_indices.reshape(-1, int(self.min_octant_size * oversample)**3, 1)
1054 |
1055 | # since patch samples could be split across batches,
1056 | # keep track of which batch idx maps to which patch idx
1057 | octant_idx.extend(rel_samp.shape[0] * [idx, ])
1058 |
1059 | rel_coords.append(rel_samp)
1060 | abs_coords.append(abs_samp)
1061 | if global_indices is not None:
1062 | all_global_indices.append(global_indices)
1063 |
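        # note: all_global_indices is accumulated above but not returned;
        # the final element of the returned tuple is always None here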
1064 | return torch.cat(rel_coords, dim=0), torch.cat(abs_coords, dim=0), octant_idx, None
1065 |
1066 | def solve_optim(self, Max_Num_Octants=150):
1067 | octants = self.get_active_octants()
1068 |
1069 | assert (len(octants) <= Max_Num_Octants), \
1070 | "You are trying to solve a model which is infeasible: " \
1071 | "Number of active octants > Max number of octants"
1072 |
1073 | if self.c_max_octants is not None:
1074 | self.optim_model.remove(self.c_max_octants)
1075 |
1076 | # global "knapsack" constraint
1077 | expr_sum_octants = [p.update_merge() for p in octants]
1078 | self.c_max_octants = self.optim_model.addConstr(gp.quicksum(expr_sum_octants) <= Max_Num_Octants)
1079 |
1080 | # objective
1081 | self.optim_model.setObjective(gp.quicksum([p.get_cost() for p in octants]), GRB.MINIMIZE)
1082 |         self.optim_model.optimize()
1083 | 
1084 |         if self.optim_model.Status == GRB.INFEASIBLE:
1085 |             print("----------- Model is infeasible")
1086 |             self.optim_model.computeIIS()
1087 |             self.optim_model.write("model.ilp")
1088 | 
1089 |         obj_val = self.optim_model.objVal
1090 | # split and merge
1091 | merged = 0
1092 | split = 0
1093 | none = 0
1094 | for p in octants:
1095 | # print(p)
1096 | if p.has_split() and p.scale < self.max_octtree_level:
1097 | p.deactivate()
1098 | for child in p.get_children():
1099 | child.activate()
1100 | split += 1
1101 | elif p.has_merged() and p.scale >= self.min_octtree_level and p.scale > 0:
1102 | # we first check if it is active,
1103 | # since we could have already been activated by a neighbor
1104 | if p.active:
1105 | for neighbor in p.get_neighbors():
1106 | neighbor.deactivate()
1107 | p.parent.activate()
1108 | merged += 1
1109 |
1110 | else:
1111 | p.update()
1112 | none += 1
1113 |
1114 | stats_dict = {'merged': merged,
1115 | 'splits': split,
1116 | 'none': none,
1117 | 'obj': obj_val}
1118 | print(f"============================= Total octants:{len(octants)}, split/merge:{split}/{merged}")
1119 | print(f"Vars={len(self.optim_model.getVars())}, Cstrs={len(self.optim_model.getConstrs())}")
1120 | return stats_dict
1121 |
1122 | def draw(self, color_by_scale=False, save_fig=True):
1123 | fig = plt.figure(figsize=(5, 5))
1124 |         ax = fig.add_subplot(111, projection='3d')
1125 |
1126 | depth = 1 + self.max_octtree_level - self.min_octtree_level
1127 | sidelen = 4**(depth-1) // 2**(depth-1)
1128 |
1129 | # calculate scale
1130 | octant_list = self.get_active_octants()
1131 | octants_err = [p.err for p in octant_list]
1132 | max_err = np.max(octants_err)
1133 | min_err = np.min(octants_err)
1134 |
1135 | cmap = plt.cm.get_cmap('viridis')
1136 |
1137 | def cuboid_data(pos, size=(1, 1, 1)):
1138 | eps = 0.2
1139 | o = pos + (eps, eps, eps)
1140 | # get the length, width, and height
1141 | l, w, h = size
1142 | l -= eps
1143 | w -= eps
1144 | h -= eps
1145 | x = [[o[0], o[0] + l, o[0] + l, o[0], o[0]],
1146 | [o[0], o[0] + l, o[0] + l, o[0], o[0]],
1147 | [o[0], o[0] + l, o[0] + l, o[0], o[0]],
1148 | [o[0], o[0] + l, o[0] + l, o[0], o[0]]]
1149 | y = [[o[1], o[1], o[1] + w, o[1] + w, o[1]],
1150 | [o[1], o[1], o[1] + w, o[1] + w, o[1]],
1151 | [o[1], o[1], o[1], o[1], o[1]],
1152 | [o[1] + w, o[1] + w, o[1] + w, o[1] + w, o[1] + w]]
1153 | z = [[o[2], o[2], o[2], o[2], o[2]],
1154 | [o[2] + h, o[2] + h, o[2] + h, o[2] + h, o[2] + h],
1155 | [o[2], o[2], o[2] + h, o[2] + h, o[2]],
1156 | [o[2], o[2], o[2] + h, o[2] + h, o[2]]]
1157 | return np.array(x), np.array(y), np.array(z)
1158 |
1159 | def draw_cube_at(pos=(0, 0, 0), size=(1, 1, 1),
1160 | color='b', edgecolor='b', alpha=1., ax=None):
1161 |
1162 | # Plotting a cube element at position pos
1163 | if ax is not None:
1164 | Z, Y, X = cuboid_data(pos, size)
1165 | ax.plot_surface(X, Y, Z, color=color, rstride=1, cstride=1, alpha=alpha,
1166 | edgecolors=edgecolor, linewidth=0.1)
1167 |
1168 | def _draw_level_2(octant, curr_level,
1169 | ax, sidelen, offset, scale):
1170 | if curr_level > self.max_octtree_level:
1171 | return ax
1172 |
1173 | scale = scale/2.
1174 |
1175 | for i, child in enumerate(octant.children):
1176 | # depth 0
1177 | if i == 0:
1178 | new_offset = (offset[0],
1179 | offset[1],
1180 | offset[2])
1181 | elif i == 1:
1182 | new_offset = (offset[0] + scale * sidelen,
1183 | offset[1],
1184 | offset[2])
1185 | elif i == 2:
1186 | new_offset = (offset[0],
1187 | offset[1] + scale * sidelen,
1188 | offset[2])
1189 | elif i == 3:
1190 | new_offset = (offset[0] + scale * sidelen,
1191 | offset[1] + scale * sidelen,
1192 | offset[2])
1193 | # depth 1
1194 | elif i == 4:
1195 | new_offset = (offset[0],
1196 | offset[1],
1197 | offset[2] + scale * sidelen)
1198 | elif i == 5:
1199 | new_offset = (offset[0] + scale * sidelen,
1200 | offset[1],
1201 | offset[2] + scale * sidelen)
1202 | elif i == 6:
1203 | new_offset = (offset[0],
1204 | offset[1] + scale * sidelen,
1205 | offset[2] + scale * sidelen)
1206 | else:
1207 | new_offset = (offset[0] + scale * sidelen,
1208 | offset[1] + scale * sidelen,
1209 | offset[2] + scale * sidelen)
1210 |
1211 | if child.active:
1212 | norm_err = (child.err-min_err)/(max_err-min_err)
1213 | sz = scale*sidelen
1214 |
1215 | if child.frozen:
1216 | color = [0.5, 0.5, 0.5]
1217 | edgecolor = [1.0, 0.0, 0.0, 0.1]
1218 | alpha = 0.
1219 | else:
1220 | color = cmap(norm_err)[0:3]
1221 | edgecolor = 'none'
1222 | alpha = 0.1
1223 |
1224 | draw_cube_at(pos=new_offset, size=(sz, sz, sz),
1225 | color=color, edgecolor=edgecolor,
1226 | alpha=alpha, ax=ax)
1227 |
1228 | else:
1229 | ax = _draw_level_2(child, curr_level+1,
1230 | ax, sidelen, new_offset, scale)
1231 |
1232 | return ax
1233 |
1234 | ax = _draw_level_2(self.root, self.min_octtree_level,
1235 | ax, sidelen,
1236 | (0., 0., 0.), 1.)
1237 |
1238 |         # ax already refers to the 3D axes created above
1239 | ax.grid(False)
1240 | ax.set_xlim(-1, sidelen+1)
1241 | ax.set_ylim(-1, sidelen+1)
1242 | ax.set_zlim(-1, sidelen+1)
1243 |
1244 | ax.set_xticks([-1, sidelen+1])
1245 | ax.set_xticklabels([-1, 1])
1246 | ax.set_yticks([-1, sidelen+1])
1247 | ax.set_yticklabels([-1, 1])
1248 | ax.set_zticks([-1, sidelen+1])
1249 | ax.set_zticklabels([-1, 1])
1250 |
1251 | return fig
1252 |
1253 |
1254 | class Octant():
1255 | def __init__(self, optim_model=None, block_coord=None, scale=None, gamma=0.95):
1256 | self.active = False
1257 |
1258 | self.parent = None
1259 | self.children = []
1260 |
1261 | # absolute block coordinate
1262 | self.block_coord = block_coord
1263 |
1264 | # size of block in absolute coord frame
1265 | self.block_size = None
1266 |
1267 | self.old_block_coord = None
1268 | self.old_block_size = None
1269 |
1270 | # scale level of block
1271 | self.scale = scale
1272 |
1273 | # num samples to be generated for this block
1274 | self.num_samples = None
1275 |
1276 | # num pixels in this patch
1277 | self.voxel_size = None
1278 |
1279 | # optimization model
1280 | self.optim = optim_model
1281 |
1282 | # row/column coords for sampling at test time
1283 | # initialized by set_samples() function
1284 | self.row_coords = None
1285 | self.col_coords = None
1286 | self.dep_coords = None
1287 |
1288 | self.near_mesh_abs_samples = None
1289 | self.near_mesh_rel_samples = None
1290 |
1291 | # error for doing nothing, merging, splitting
1292 | self.err = 0.
1293 | self.last_updated = 0.
1294 |
1295 | self.gamma = gamma
1296 |
1297 | self._nocopy = ['optim', 'I_grp', 'I_split', 'I_none',
1298 | 'I_merge', 'c_joinable', 'c_merge_split',
1299 | 'children', 'parent', 'loss', 'loss_iter',
1300 | 'err']
1301 | self.spec_cstrs = []
1302 |
1303 | self._pickle_vars = ['parent', 'children', 'active', 'err', 'last_updated', 'frozen', 'value']
1304 |
1305 | # options for pruning
1306 | self.frozen = False
1307 | self.value = 0
1308 |
1309 | def __str__(self):
1310 | str = f"Octant id={id(self)}\n" \
1311 | f" . active={self.active}\n" \
1312 | f" . level={self.scale}\n" \
1313 | f" . model={self.optim}"
1314 |
1315 | if self.active:
1316 | str += f"\n . g={self.I_grp.x}, s={self.I_split.x}, n={self.I_none.x}"
1317 |
1318 | return str
1319 |
1320 | # override deep copy to copy undeepcopyable objects by reference
1321 | def __deepcopy__(self, memo):
1322 | deep_copied_obj = Octant()
1323 | for k, v in self.__dict__.items():
1324 | if k in self._nocopy:
1325 | setattr(deep_copied_obj, k, v)
1326 | else:
1327 | setattr(deep_copied_obj, k, copy.deepcopy(v, memo))
1328 |
1329 | return deep_copied_obj
1330 |
1331 | def __getstate__(self):
1332 | state = self.__dict__.copy()
1333 | for k, v in self.__dict__.items():
1334 | if k not in self._pickle_vars:
1335 | del(state[k])
1336 | return state
1337 |
1338 | def __load__(self, obj):
1339 | for k, v in obj.__dict__.items():
1340 | if k in ['children', 'parent']:
1341 | continue
1342 | setattr(self, k, v)
1343 | if self.active:
1344 | self.activate()
1345 |
1346 | def update(self):
1347 | self.deactivate()
1348 | self.activate()
1349 |
1350 | def activate(self):
1351 | self.active = True
1352 |
1353 | # indicator variables
1354 | self.I_grp = self.optim.addVar(vtype=GRB.BINARY)
1355 | self.I_split = self.optim.addVar(vtype=GRB.BINARY)
1356 | self.I_none = self.optim.addVar(vtype=GRB.BINARY)
1357 |
1358 | self.I_merge = gp.LinExpr(0.0)
1359 |
1360 | # local constraint "merge/none/split"
1361 | self.c_joinable = self.optim.addConstr(self.I_grp + self.I_none + self.I_split == 1)
1362 |
1363 | # local constraint "merge-split"
1364 | self.c_merge_split = None
1365 |
1366 | def deactivate(self):
1367 | self.active = False
1368 |
1369 | self.optim.remove(self.I_grp)
1370 | self.optim.remove(self.I_split)
1371 | self.optim.remove(self.I_none)
1372 |
1373 | self.I_merge = gp.LinExpr(0.0)
1374 |
1375 | self.optim.remove(self.c_joinable)
1376 |
1377 | if self.c_merge_split is not None:
1378 | self.optim.remove(self.c_merge_split)
1379 |
1380 | for cstr in self.spec_cstrs:
1381 | self.optim.remove(cstr)
1382 | self.spec_cstrs = []
1383 |
1384 | def is_mergeable(self):
1385 | siblings = self.parent.children
1386 |         return np.all([sib.active for sib in siblings])
1387 |
1388 | def set_sample_params(self, num_samples):
1389 | self.num_samples = num_samples
1390 | posts = torch.linspace(-1, 1, self.num_samples+1)[:-1]
1391 | row_coords, col_coords, dep_coords = torch.meshgrid(posts, posts, posts)
1392 | self.row_coords = row_coords.flatten()
1393 | self.col_coords = col_coords.flatten()
1394 | self.dep_coords = dep_coords.flatten()
1395 |
1396 | def must_split(self):
1397 | self.spec_cstrs.append(
1398 | self.optim.addConstr(self.I_split == 1)
1399 | )
1400 |
1401 | def must_merge(self):
1402 | self.spec_cstrs.append(
1403 | self.optim.addConstr(self.I_grp == 1)
1404 | )
1405 |
1406 | def has_split(self):
1407 | return self.I_split.x == 1
1408 |
1409 | def has_merged(self):
1410 | return self.I_grp.x == 1
1411 | # return self.I_none.x==0 and self.I_split.x==0
1412 |
1413 | def has_done_nothing(self):
1414 | return self.I_none.x == 1
1415 |
1416 | def get_cost(self):
1417 | area = self.block_size[0]**2
1418 | alpha = 0.2 # how much worse we expect the error to be when merging
1419 | beta = -0.02 # how much better we expect the error to be when splitting
1420 |
1421 |         # == Merge
1422 |         if self.scale > 0:  # this should never be the root, but guard anyway
1423 |             err_merge = (8+alpha) * area * self.err  # heuristic until the parent has a measured error
1424 | 
1425 |             if self.parent.last_updated:
1426 |                 parent_area = self.parent.block_size[0]**2
1427 |                 err_merge = parent_area * self.parent.err  # can multiply by 1/8 as in paper to make merging more aggressive
1428 | else:
1429 | err_merge = self.err
1430 |
1431 | # == Split
1432 | if self.children:
1433 | err_split = (0.125+beta) * area * self.err
1434 |
1435 | if self.children[0].last_updated:
1436 | err_children = np.sum([child.err for child in self.children])
1437 | err_split = area * err_children
1438 | else:
1439 |             err_split = 1.  # no children yet; keep this high to discourage splitting
1440 |
1441 | err_none = area * self.err
1442 |
1443 | return err_none * self.I_none \
1444 | + err_split * self.I_split \
1445 | + err_merge * self.I_grp
1446 |
1447 | def update_merge(self):
1448 | if self.parent is None: # if root
1449 | return gp.LinExpr(0)
1450 |
1451 | siblings = self.parent.children
1452 | if np.all([sib.active for sib in siblings]):
1453 | I_grp_neighs = [s.I_grp for s in siblings]
1454 | self.I_merge = gp.quicksum(I_grp_neighs)
1455 |
1456 |             # local constraint "merge-split"
1457 | self.c_merge_split = self.optim.addConstr(self.I_none + self.I_split + .125*self.I_merge == 1)
1458 | expr_max_patches = 8 * self.I_split + 1 * self.I_none + .125 * self.I_grp
1459 |
1460 | return expr_max_patches
1461 |
1462 | def get_neighbors(self):
1463 | return self.parent.children
1464 |
1465 | def get_children(self):
1466 | return self.children
1467 |
1468 | def get_parent(self):
1469 | return self.parent
1470 |
1471 | def is_joinable(self):
1472 |         # test if all siblings are active (i.e., leaf) blocks
1473 | siblings = self.parent.children
1474 | return np.all([sib.active for sib in siblings])
1475 |
1476 | def get_block_coord(self):
1477 | return self.block_coord
1478 |
1479 | def get_scale(self):
1480 | return self.scale
1481 |
1482 | def update_error(self, error, iter):
1483 | self.err = self.gamma*self.err + (1-self.gamma)*error
1484 | self.last_updated = iter
1485 |
1486 | def get_block_coords(self, flatten=False, include_ends=False, octant_size=None):
1487 | # get size of each block
1488 | z_len = 2
1489 | y_len = 2
1490 | x_len = 2
1491 |
1492 | sidelength = 256
1493 |
1494 | block_size = (z_len / (sidelength) * octant_size,
1495 | y_len / (sidelength) * octant_size,
1496 | x_len / (sidelength) * octant_size)
1497 |
1498 | # get block begin/end coordinates
1499 | if include_ends:
1500 |             block_coords_z = torch.arange(-1, 1 + block_size[0], block_size[0])
1501 | block_coords_y = torch.arange(-1, 1 + block_size[1], block_size[1])
1502 | block_coords_x = torch.arange(-1, 1 + block_size[2], block_size[2])
1503 | else:
1504 | block_coords_z = torch.arange(-1, 1, block_size[0])
1505 | block_coords_y = torch.arange(-1, 1, block_size[1])
1506 | block_coords_x = torch.arange(-1, 1, block_size[2])
1507 |
1508 | # repeat for every single block
1509 | block_coords = torch.meshgrid(block_coords_z, block_coords_y, block_coords_x)
1510 | block_coords = torch.stack((block_coords[0], block_coords[1], block_coords[2]), dim=-1)
1511 | if flatten:
1512 | block_coords = block_coords.reshape(-1, 3)
1513 |
1514 | return block_coords
1515 |
1516 | def get_stratified_samples(self, jitter=True, eval=False, oversample=1., kd_tree=None):
1517 | # Block coords are always aligned to the pixel grid,
1518 | # e.g., they align with pixels 0, 8, 16, 24, etc. for
1519 | # patch size 8
1520 | #
1521 | # To normalize the coordinates between (-1, 1), consider
1522 | # we have an image of 64x64 and patch size 8x8.
1523 | # The block coordinate (-1, -1) aligns with pixel (0, 0)
1524 | # and coordinate (1, 1) aligns with pixel (63, 63)
1525 | #
1526 | # Absolute coordinates within a block should stretch all the way
1527 | # from the absolute position of one block coordinate to another.
1528 | # Say each block contains 8x8 pixels and we use a feature grid
1529 | # of 8x8 features to interpolate values within a block.
1530 | # This means is that the feature positions are not actually
1531 | # aligned to the pixel positions. The features are positioned
1532 | # on a grid stretching from one block coord to another whereas
1533 | # the pixel grid ends just short of the next block coordinate
1534 | #
1535 | # Example patch (x = pixel position, B = block coordinate position)
1536 | # and relative coordinate positions.
1537 | #
1538 | # -1 ^ B x x x x x x x B
1539 | # | x x x x x x x x x
1540 | # | x x x x x x x x x
1541 | # | x x x x x x x x x
1542 | # | x x x x x x x x x
1543 | # | x x x x x x x x x
1544 | # | x x x x x x x x x
1545 | # | x x x x x x x x x
1546 | # 1 v B x x x x x x x B
1547 | # <--------------->
1548 | # -1 1
1549 | #
1550 | # When we generate samples for a patch, we sample an
1551 | # 8x8 grid that extends between block coords, i.e.
1552 | # between the arrows above
1553 | #
1554 |         if eval:
1555 |             # sample a regular grid spanning the block; the `oversample`
1556 |             # factor refines the grid beyond the stored voxel resolution
1557 |             # at test time
1558 |             post = torch.linspace(-1, 1, int(self.voxel_size * oversample + 1))[:-1]
1559 |             post += (post[1] - post[0])/2
1560 | 
1561 |             eval_coords = torch.meshgrid(post, post, post)
1562 |             row_coords = eval_coords[0].flatten()
1563 |             col_coords = eval_coords[1].flatten()
1564 |             dep_coords = eval_coords[2].flatten()
1565 | 
1566 | 
1567 |
1568 | rel_samples = torch.stack((row_coords, col_coords, dep_coords), dim=-1)
1569 | abs_samples = self.block_coord[None, :] + (self.block_size[None, :]) * (rel_samples+1)/2
1570 |
1571 | return rel_samples, abs_samples, None
1572 |
1573 | else:
1574 | if (self.old_block_coord is None or self.old_block_size is None or
1575 | (self.old_block_coord != self.block_coord).any()
1576 | or (self.old_block_size != self.block_size).any()):
1577 |
1578 | self.old_block_coord = self.block_coord
1579 | self.old_block_size = self.block_size
1580 | self.update_surface_coords()
1581 | return self.select_near_mesh(self.num_samples**3)
1582 |
1583 | def update_surface_coords(self):
1584 | center = (self.block_coord + 0.5 * self.block_size).cpu().numpy()
1585 | side_length = float(self.block_size[0])
1586 |
1587 | search_radius = (side_length/2)*(3**0.5)
1588 | indices = np.array(self.surface_tree.query_ball_point(center, search_radius))
1589 |
1590 | if indices.shape[0] == 0:
1591 | self.near_mesh_abs_samples = np.zeros((0, 3))
1592 | self.near_mesh_rel_samples = np.zeros((0, 3))
1593 | else:
1594 | coordinates = self.surface_tree.data[indices]
1595 | in_cube = np.linalg.norm(coordinates - center, ord=np.inf, axis=1) < (side_length/2)
1596 | self.near_mesh_abs_samples = torch.FloatTensor(coordinates[in_cube])
1597 |
1598 | self.near_mesh_rel_samples = 2 * (self.near_mesh_abs_samples - self.block_coord[None, :])/self.block_size[None, :] - 1
1599 |
1600 | def select_near_mesh(self, num_samples, jitter=True):
1601 | if np.random.rand() > 0.9 and self.near_mesh_abs_samples.shape[0] > 0: # Near surface
1602 | selection_indices = torch.randint(self.near_mesh_abs_samples.shape[0], (num_samples,))
1603 | rel_samples_mesh = self.near_mesh_rel_samples[selection_indices]
1604 | abs_samples_mesh = self.near_mesh_abs_samples[selection_indices]
1605 |
1606 | rel_samples_mesh += torch.randn_like(rel_samples_mesh) * 0.05 * self.block_size
1607 | abs_samples_mesh = self.block_coord[None, :] + (self.block_size[None, :]) * (rel_samples_mesh+1)/2
1608 |
1609 | return rel_samples_mesh, abs_samples_mesh, None
1610 | else: # Uniform
1611 | row_coords = self.row_coords + torch.rand_like(self.row_coords) * 2./self.num_samples
1612 | col_coords = self.col_coords + torch.rand_like(self.col_coords) * 2./self.num_samples
1613 | dep_coords = self.dep_coords + torch.rand_like(self.dep_coords) * 2./self.num_samples
1614 |
1615 | rel_samples = torch.stack((row_coords, col_coords, dep_coords), dim=-1)
1616 | abs_samples = self.block_coord[None, :] + (self.block_size[None, :]) * (rel_samples+1)/2
1617 |
1618 | return rel_samples, abs_samples, None
1619 |
--------------------------------------------------------------------------------
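How the pieces of `Octant` fit together: `activate()` registers one binary Gurobi indicator per action (merge, split, do nothing) plus a one-hot constraint, `update_merge()` couples sibling indicators and reports each block's contribution to the patch budget, and `get_cost()` supplies the error-weighted objective terms. The toy model below is a minimal sketch of that pattern for a single block with made-up error values, not the repository's actual solver setup (the full objective, sibling coupling, and budget are assembled by the tree classes); it assumes a working Gurobi installation and license.

```python
import gurobipy as gp
from gurobipy import GRB

# hypothetical per-action error estimates for one block
err_none, err_split, err_merge = 0.10, 0.02, 0.45

m = gp.Model('toy_tiling')
m.Params.OutputFlag = 0  # silence solver output

# binary indicators: do nothing, split into 8 children, or merge with siblings
I_none = m.addVar(vtype=GRB.BINARY, name='I_none')
I_split = m.addVar(vtype=GRB.BINARY, name='I_split')
I_grp = m.addVar(vtype=GRB.BINARY, name='I_grp')

# exactly one action per block (cf. the c_joinable constraint in activate())
m.addConstr(I_none + I_split + I_grp == 1)

# splitting yields 8 blocks, keeping yields 1, merging 1/8 per sibling;
# cap the total as solve_optim does with max_octants (4 is a made-up budget)
m.addConstr(8 * I_split + 1 * I_none + 0.125 * I_grp <= 4)

# minimize the error-weighted objective, as in get_cost()
m.setObjective(err_none * I_none + err_split * I_split + err_merge * I_grp,
               GRB.MINIMIZE)
m.optimize()

# splitting has the lowest error here but is ruled out by the budget,
# so the solver keeps the block as-is
print('none:', I_none.x, 'split:', I_split.x, 'merge:', I_grp.x)
```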
/dataio.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import errno
4 | import matplotlib.colors as colors
5 | import skimage
6 | import skimage.filters
7 | import torch
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import Resize, Compose, ToTensor, Normalize
11 | import urllib.request
12 | from tqdm import tqdm
13 | import numpy as np
14 | import copy
15 | import trimesh
16 | from inside_mesh import inside_mesh
17 |
18 | from scipy.spatial import cKDTree as spKDTree
19 | from data_structs import QuadTree, OctTree
20 |
21 |
22 | def get_mgrid(sidelen, dim=2):
23 | '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.'''
24 | if isinstance(sidelen, int):
25 | sidelen = dim * (sidelen,)
26 |
27 | if dim == 2:
28 | pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32)
29 | pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1)
30 | pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1)
31 | elif dim == 3:
32 | pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
33 | pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
34 | pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1)
35 | pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1)
36 | else:
37 | raise NotImplementedError('Not implemented for dim=%d' % dim)
38 |
39 | pixel_coords -= 0.5
40 | pixel_coords *= 2.
41 | pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)
42 | return pixel_coords
43 |
44 |
45 | def lin2img(tensor, image_resolution=None):
46 | batch_size, num_samples, channels = tensor.shape
47 | if image_resolution is None:
48 | width = np.sqrt(num_samples).astype(int)
49 | height = width
50 | else:
51 | height = image_resolution[0]
52 | width = image_resolution[1]
53 |
54 | return tensor.permute(0, 2, 1).view(batch_size, channels, height, width)
55 |
56 |
57 | def grads2img(gradients):
58 | mG = gradients.detach().squeeze(0).permute(-2, -1, -3).cpu()
59 |
60 | # assumes mG is [row,cols,2]
61 | nRows = mG.shape[0]
62 | nCols = mG.shape[1]
63 | mGr = mG[:, :, 0]
64 | mGc = mG[:, :, 1]
65 | mGa = np.arctan2(mGc, mGr)
66 | mGm = np.hypot(mGc, mGr)
67 | mGhsv = np.zeros((nRows, nCols, 3), dtype=np.float32)
68 | mGhsv[:, :, 0] = (mGa + math.pi) / (2. * math.pi)
69 | mGhsv[:, :, 1] = 1.
70 |
71 | nPerMin = np.percentile(mGm, 5)
72 | nPerMax = np.percentile(mGm, 95)
73 | mGm = (mGm - nPerMin) / (nPerMax - nPerMin)
74 | mGm = np.clip(mGm, 0, 1)
75 |
76 | mGhsv[:, :, 2] = mGm
77 | mGrgb = colors.hsv_to_rgb(mGhsv)
78 | return torch.from_numpy(mGrgb).permute(2, 0, 1)
79 |
80 |
81 | def rescale_img(x, mode='scale', perc=None, tmax=1.0, tmin=0.0):
82 | if (mode == 'scale'):
83 | if perc is None:
84 | xmax = torch.max(x)
85 | xmin = torch.min(x)
86 | else:
87 | xmin = np.percentile(x.detach().cpu().numpy(), perc)
88 | xmax = np.percentile(x.detach().cpu().numpy(), 100 - perc)
89 | x = torch.clamp(x, xmin, xmax)
90 | if xmin == xmax:
91 | return 0.5 * torch.ones_like(x) * (tmax - tmin) + tmin
92 | x = ((x - xmin) / (xmax - xmin)) * (tmax - tmin) + tmin
93 | elif (mode == 'clamp'):
94 | x = torch.clamp(x, 0, 1)
95 | return x
96 |
97 |
98 | def to_uint8(x):
99 | return (255. * x).astype(np.uint8)
100 |
101 |
102 | def to_numpy(x):
103 | return x.detach().cpu().numpy()
104 |
105 |
106 | class PointCloud(Dataset):
107 | def __init__(self, pointcloud_path, on_surface_points, keep_aspect_ratio=True):
108 | super().__init__()
109 |
110 | print("Dataset: loading point cloud")
111 | point_cloud = np.genfromtxt(pointcloud_path)
112 | print("Dataset: finished loading point cloud")
113 |
114 | coords = point_cloud[:, :3]
115 | self.normals = point_cloud[:, 3:]
116 |
117 | # Reshape point cloud such that it lies in bounding box of (-1, 1) (distorts geometry, but makes for high
118 | # sample efficiency)
119 | coords -= np.mean(coords, axis=0, keepdims=True)
120 | if keep_aspect_ratio:
121 | coord_max = np.amax(coords)
122 | coord_min = np.amin(coords)
123 | else:
124 | coord_max = np.amax(coords, axis=0, keepdims=True)
125 | coord_min = np.amin(coords, axis=0, keepdims=True)
126 |
127 | self.coords = (coords - coord_min) / (coord_max - coord_min)
128 | self.coords -= 0.5
129 | self.coords *= 2.
130 |
131 | self.on_surface_points = on_surface_points
132 |
133 | def __len__(self):
134 | return self.coords.shape[0] // self.on_surface_points
135 |
136 | def __getitem__(self, idx):
137 | point_cloud_size = self.coords.shape[0]
138 |
139 | off_surface_samples = self.on_surface_points # **2
140 | total_samples = self.on_surface_points + off_surface_samples
141 |
142 | # Random coords
143 | rand_idcs = np.random.choice(point_cloud_size, size=self.on_surface_points)
144 |
145 | on_surface_coords = self.coords[rand_idcs, :]
146 | on_surface_normals = self.normals[rand_idcs, :]
147 |
148 | off_surface_coords = np.random.uniform(-1, 1, size=(off_surface_samples, 3))
149 | off_surface_normals = np.ones((off_surface_samples, 3)) * -1
150 |
151 | sdf = np.zeros((total_samples, 1)) # on-surface = 0
152 | sdf[self.on_surface_points:, :] = -1 # off-surface = -1
153 |
154 | coords = np.concatenate((on_surface_coords, off_surface_coords), axis=0)
155 | normals = np.concatenate((on_surface_normals, off_surface_normals), axis=0)
156 |
157 | return {'coords': torch.from_numpy(coords).float()}, {'sdf': torch.from_numpy(sdf).float(),
158 | 'normals': torch.from_numpy(normals).float()}
159 |
160 |
161 | class OccupancyDataset():
162 | def __init__(self, pc_or_mesh_filename):
163 | self.intersector = None
164 | self.kd_tree = None
165 | self.kd_tree_sp = None
166 | self.mode = None
167 |
168 | if not pc_or_mesh_filename:
169 | return
170 |
171 | print("Dataset: loading mesh")
172 | self.mesh = trimesh.load(pc_or_mesh_filename, process=False, force='mesh', skip_materials=True)
173 |
174 | def normalize_mesh(mesh):
175 | print("Dataset: scaling parameters: ", mesh.bounding_box.extents)
176 | mesh.vertices -= mesh.bounding_box.centroid
177 | mesh.vertices /= np.max(mesh.bounding_box.extents / 2)
178 |
179 | normalize_mesh(self.mesh)
180 |
181 | self.intersector = inside_mesh.MeshIntersector(self.mesh, 2048)
182 | self.mode = 'volume'
183 |
184 | print('Dataset: sampling points on mesh')
185 | samples = trimesh.sample.sample_surface(self.mesh, 20000000)[0]
186 |
187 | self.kd_tree_sp = spKDTree(samples)
188 |
189 | def __len__(self):
190 | return 1
191 |
192 | def evaluate_occupancy(self, pts):
193 | return self.intersector.query(pts).astype(int).reshape(-1, 1)
194 |
195 |
196 | class Camera(Dataset):
197 | def __init__(self, downsample_factor=1):
198 | super().__init__()
199 | self.downsample_factor = downsample_factor
200 | self.img = Image.fromarray(skimage.data.camera())
201 | self.img_channels = 1
202 |
203 | if downsample_factor > 1:
204 | size = (int(512 / downsample_factor),) * 2
205 | self.img_downsampled = self.img.resize(size, Image.ANTIALIAS)
206 |
207 | def __len__(self):
208 | return 1
209 |
210 | def __getitem__(self, idx):
211 | if self.downsample_factor > 1:
212 | return self.img_downsampled
213 | else:
214 | return self.img
215 |
216 |
217 | class ImageFile(Dataset):
218 | def __init__(self, filename, url=None, grayscale=True):
219 | super().__init__()
220 | Image.MAX_IMAGE_PIXELS = 1000000000
221 | file_exists = os.path.isfile(filename)
222 |
223 | if not file_exists:
224 | if url is None:
225 | raise FileNotFoundError(
226 | errno.ENOENT, os.strerror(errno.ENOENT), filename)
227 | else:
228 | print('Downloading image file...')
229 | urllib.request.urlretrieve(url, filename)
230 |
231 | self.img = Image.open(filename)
232 | if grayscale:
233 | self.img = self.img.convert('L')
234 |
235 | self.img_channels = len(self.img.mode)
236 |
237 | def __len__(self):
238 | return 1
239 |
240 | def __getitem__(self, idx):
241 | return self.img
242 |
243 |
244 | class Patch2DWrapperMultiscaleAdaptive(torch.utils.data.Dataset):
245 | def __init__(self, dataset, patch_size=(16, 16), sidelength=None, random_coords=False,
246 | jitter=True, num_workers=0, length=1000, scale_init=3, max_patches=1024):
247 |
248 | self.length = length
249 | if len(sidelength) == 1:
250 | sidelength = 2*sidelength
251 | self.sidelength = sidelength
252 |
253 | for i in range(2):
254 | assert float(sidelength[i]) / float(patch_size[i]) % 1 == 0, 'Resolution not divisible by patch size'
255 | assert float(sidelength[0]) / float(patch_size[0]) == float(sidelength[1]) / float(patch_size[1]), \
256 | 'number of patches must be same along each dim; check values of resolution and patch size'
257 |
258 | self.transform = Compose([
259 | Resize(sidelength),
260 | ToTensor(),
261 | Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
262 | ])
263 |
264 | # initialize quad tree
265 | self.quadtree = QuadTree(sidelength, patch_size)
266 | self.num_scales = self.quadtree.max_quadtree_level - self.quadtree.min_quadtree_level + 1
267 | self.max_patches = max_patches
268 |
269 | # set patches at coarsest level to be active
270 | patches = self.quadtree.get_patches_at_level(scale_init)
271 | for p in patches:
272 | p.activate()
273 |
274 | # handle parallelization
275 | self.num_workers = num_workers
276 |
277 | # make a copy of the tree for each worker
278 | self.quadtrees = []
279 | print('Dataset: preparing dataloaders...')
280 | for idx in tqdm(range(num_workers)):
281 | self.quadtrees.append(copy.deepcopy(self.quadtree))
282 | self.last_active_patches = self.quadtree.get_active_patches()
283 |
284 | # set random patches to be active
285 | # self.quadtree.activate_random()
286 |
287 | self.patch_size = patch_size
288 | self.dataset = dataset
289 | self.img = self.transform(self.dataset[0])
290 | self.jitter = jitter
291 | self.eval = False
292 |
293 | def toggle_eval(self):
294 | if not self.eval:
295 | self.jitter_bak = self.jitter
296 | self.jitter = False
297 | self.eval = True
298 | else:
299 | self.jitter = self.jitter_bak
300 | self.eval = False
301 |
302 | def interpolate_bilinear(self, img, fine_abs_coords, psize):
303 | n_blocks = fine_abs_coords.shape[0]
304 | n_channels = img.shape[0]
305 | fine_abs_coords = fine_abs_coords.reshape(n_blocks, psize[0], psize[1], 2)
306 | x = fine_abs_coords[..., :1]
307 | y = fine_abs_coords[..., 1:]
308 | coords = torch.cat([y, x], dim=-1)
309 |
310 | out = []
311 | for block in coords:
312 | tmp = torch.nn.functional.grid_sample(img[None, ...], block[None, ...],
313 | mode='bilinear',
314 | padding_mode='reflection',
315 | align_corners=False)
316 | out.append(tmp)
317 | out = torch.cat(out, dim=0)
318 | out = out.permute(0, 2, 3, 1)
319 | return out.reshape(n_blocks, np.prod(psize), n_channels)
320 |
321 | def synchronize(self):
322 | self.last_active_patches = self.quadtree.get_active_patches()
323 |
324 | if self.num_workers == 0:
325 | return
326 | else:
327 | for idx in range(self.num_workers):
328 | self.quadtrees[idx].synchronize(self.quadtree)
329 |
330 | def __len__(self):
331 | # return len(self.dataset)
332 | return self.length
333 |
334 | def get_frozen_patches(self):
335 | quadtree = self.quadtree
336 |
337 |         # get fine coords; get_frozen_patches is only called at eval time
338 | fine_rel_coords, fine_abs_coords, vals,\
339 | coord_patch_idx = quadtree.get_frozen_samples()
340 |
341 | return fine_abs_coords, vals
342 |
343 | def __getitem__(self, idx):
344 |
345 | quadtree = self.quadtree
346 | if not self.eval and self.num_workers > 0:
347 | worker_idx = torch.utils.data.get_worker_info().id
348 | quadtree = self.quadtrees[worker_idx]
349 |
350 | # get fine coords
351 | fine_rel_coords, fine_abs_coords, coord_patch_idx = quadtree.get_stratified_samples(self.jitter, eval=self.eval)
352 |
353 | # get block coords
354 | patches = quadtree.get_active_patches()
355 | coords = torch.stack([p.block_coord for p in patches], dim=0)
356 | scales = torch.stack([torch.tensor(p.scale) for p in patches], dim=0)[:, None]
357 | scales = 2*scales / (self.num_scales-1) - 1
358 | coords = torch.cat((coords, scales), dim=-1)
359 |
360 | if self.eval:
361 | coords = coords[coord_patch_idx]
362 |
363 |         # sample the ground-truth image at the fine absolute coordinates
364 | img = self.interpolate_bilinear(self.img, fine_abs_coords, self.patch_size)
365 |
366 | in_dict = {'coords': coords,
367 | 'fine_abs_coords': fine_abs_coords,
368 | 'fine_rel_coords': fine_rel_coords}
369 | gt_dict = {'img': img}
370 |
371 | return in_dict, gt_dict
372 |
373 | def update_patch_err(self, err_per_patch, step):
374 | assert err_per_patch.shape[0] == len(self.last_active_patches), \
375 | f"Trying to update the error in active patches but list of patches and error tensor" \
376 | f" sizes are mismatched: {err_per_patch.shape[0]} vs {len(self.last_active_patches)}"
377 |
378 | for i, p in enumerate(self.last_active_patches):
379 | # Log the history of error
380 | p.update_error(err_per_patch[i], step)
381 |
382 | def update_tiling(self):
383 | return self.quadtree.solve_optim(self.max_patches)
384 |
385 |
386 | class Block3DWrapperMultiscaleAdaptive(torch.utils.data.Dataset):
387 | def __init__(self, dataset, octant_size=16, sidelength=None, random_coords=False,
388 | max_octants=600, jitter=True, num_workers=0, length=1000, scale_init=3):
389 |
390 | self.length = length
391 | if isinstance(sidelength, int):
392 | sidelength = (sidelength, sidelength, sidelength)
393 | self.sidelength = sidelength
394 |
395 |         # initialize oct tree
396 | self.octtree = OctTree(sidelength, octant_size, mesh_kd_tree=dataset.kd_tree_sp)
397 | self.num_scales = self.octtree.max_octtree_level - self.octtree.min_octtree_level + 1
398 |
399 |         # set octants at the coarsest level to be active
400 | octants = self.octtree.get_octants_at_level(scale_init)
401 | for p in octants:
402 | p.activate()
403 |
404 | # handle parallelization
405 | self.num_workers = num_workers
406 |
407 | # make a copy of the tree for each worker
408 | self.octtrees = []
409 | print('Dataset: preparing dataloaders...')
410 | for idx in tqdm(range(num_workers)):
411 | self.octtrees.append(copy.deepcopy(self.octtree))
412 | self.last_active_octants = self.octtree.get_active_octants()
413 |
414 | self.octant_size = octant_size
415 | self.dataset = dataset
416 | self.pointcloud = None
417 | self.jitter = jitter
418 | self.eval = False
419 |
420 | self.max_octants = max_octants
421 |
422 | self.iter = 0
423 |
424 | def toggle_eval(self):
425 | if not self.eval:
426 | self.jitter_bak = self.jitter
427 | self.jitter = False
428 | self.eval = True
429 | else:
430 | self.jitter = self.jitter_bak
431 | self.eval = False
432 |
433 | def synchronize(self):
434 | self.last_active_octants = self.octtree.get_active_octants()
435 | if self.num_workers == 0:
436 | return
437 | else:
438 | for idx in range(self.num_workers):
439 | self.octtrees[idx].synchronize(self.octtree)
440 |
441 | def __len__(self):
442 | return self.length
443 |
444 | def get_frozen_octants(self, oversample):
445 | octtree = self.octtree
446 |
447 |         # get fine coords; get_frozen_octants is only called at eval time
448 | fine_rel_coords, fine_abs_coords, vals,\
449 | coord_patch_idx = octtree.get_frozen_samples(oversample)
450 |
451 | return fine_abs_coords, vals
452 |
453 | def get_eval_samples(self, oversample):
454 | octtree = self.octtree
455 |
456 | # get fine coords
457 | fine_rel_coords, fine_abs_coords, coord_octant_idx, _ = octtree.get_stratified_samples(self.jitter, eval=True, oversample=oversample)
458 |
459 | # get block coords
460 | octants = octtree.get_active_octants()
461 | coords = torch.stack([p.block_coord for p in octants], dim=0)
462 | scales = torch.stack([torch.tensor(p.scale) for p in octants], dim=0)[:, None]
463 | scales = 2*scales / (self.num_scales-1) - 1
464 | coords = torch.cat((coords, scales), dim=-1)
465 |
466 | coords = coords[coord_octant_idx]
467 |
468 | # query for occupancy
469 | sz_b, sz_p, _ = fine_abs_coords.shape
470 |
471 | in_dict = {'coords': coords,
472 | 'fine_abs_coords': fine_abs_coords,
473 | 'fine_rel_coords': fine_rel_coords,
474 | 'coord_octant_idx': torch.tensor(coord_octant_idx, dtype=torch.int)}
475 |
476 | return in_dict
477 |
478 | def __getitem__(self, idx):
479 | assert(not self.eval)
480 |
481 | octtree = self.octtree
482 | if not self.eval and self.num_workers > 0:
483 | worker_idx = torch.utils.data.get_worker_info().id
484 | octtree = self.octtrees[worker_idx]
485 |
486 | # get fine coords
487 | fine_rel_coords, fine_abs_coords, coord_octant_idx, coord_global_idx = octtree.get_stratified_samples(self.jitter, eval=self.eval)
488 |
489 | # get block coords
490 | octants = octtree.get_active_octants()
491 | coords = torch.stack([p.block_coord for p in octants], dim=0)
492 | scales = torch.stack([torch.tensor(p.scale) for p in octants], dim=0)[:, None]
493 | scales = 2*scales / (self.num_scales-1) - 1
494 | coords = torch.cat((coords, scales), dim=-1)
495 |
496 | if self.eval:
497 | coords = coords[coord_octant_idx]
498 |
499 | # query for occupancy
500 | sz_b, sz_p, _ = fine_abs_coords.shape
501 | fine_abs_coords_query = fine_abs_coords.reshape(-1, 3).detach().cpu().numpy()
502 |
503 | if self.eval:
504 | occupancy = np.zeros(fine_abs_coords_query.shape[0])
505 | else:
506 | occupancy = self.dataset.evaluate_occupancy(fine_abs_coords_query) # start-end/num iters
507 | occupancy = torch.from_numpy(occupancy).reshape(sz_b, sz_p, 1)
508 |
509 | self.iter += 1
510 |
511 | in_dict = {'coords': coords,
512 | 'fine_abs_coords': fine_abs_coords,
513 | 'fine_rel_coords': fine_rel_coords}
514 |
515 | if self.eval:
516 | in_dict.update({'coord_octant_idx': torch.tensor(coord_octant_idx, dtype=torch.int)})
517 |
518 | gt_dict = {'occupancy': occupancy}
519 |
520 | return in_dict, gt_dict
521 |
522 | def update_octant_err(self, err_per_octant, step):
523 | assert err_per_octant.shape[0] == len(self.last_active_octants), \
524 |             f"Trying to update the error in active octants but list of octants and error tensor" \
525 |             f" sizes are mismatched: {err_per_octant.shape[0]} vs {len(self.last_active_octants)}" \
526 |             f" (step: {step})"
527 |
528 | for i, p in enumerate(self.last_active_octants):
529 | # Log the history of error
530 | p.update_error(err_per_octant[i], step)
531 |
532 | self.per_octant_error = err_per_octant
533 |
534 | def update_tiling(self):
535 | return self.octtree.solve_optim(self.max_octants)
536 |
--------------------------------------------------------------------------------
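To make the coordinate convention concrete, here is a small usage sketch of `get_mgrid` from above: pixel indices are normalized so that each dimension spans exactly [-1, 1], the same range used for block and sample coordinates throughout the codebase. For a sidelength of 4, indices 0..3 map to -1, -1/3, 1/3, 1, which is easy to verify by hand.

```python
import torch
from dataio import get_mgrid

coords = get_mgrid(4, dim=2)  # flattened grid of shape (16, 2)
print(coords.shape)           # torch.Size([16, 2])
print(coords[:4])             # first image row:
# tensor([[-1.0000, -1.0000],
#         [-1.0000, -0.3333],
#         [-1.0000,  0.3333],
#         [-1.0000,  1.0000]])
```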
/environment.yml:
--------------------------------------------------------------------------------
1 | name: acorn
2 | channels:
3 | - gurobi
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=10.2.89
8 | - cython=0.29.24
9 | - gurobi=9.1.2=py37_0
10 | - matplotlib=3.3.4
11 | - numpy=1.20.3
12 | - pillow=8.3.1
13 | - pip=21.1.3
14 | - python=3.7.10
15 | - pytorch=1.8.1
16 | - scikit-image=0.18.1
17 | - scipy=1.6.2
18 | - tensorboard=2.5.0=py_0
19 | - torchvision=0.9.1=py37_cu102
20 | - tqdm=4.61.2=pyhd3eb1b0_1
21 | - pip:
22 | - configargparse==1.5.1
23 | - mrcfile==1.3.0
24 | - opencv-python==4.5.3.56
25 | - pycollada==0.7.1
26 | - pymcubes==0.1.2
27 | - trimesh==3.9.25
28 |
--------------------------------------------------------------------------------
/experiment_scripts/config_img/config_mars.ini:
--------------------------------------------------------------------------------
1 | experiment_name = mars_acorn_1G
2 | num_workers = 4
3 | res = [19456, 51200]
4 | dataset = mars
5 | max_patches = 16384
6 | skip_logging = true
7 | lr = 0.001
8 | num_iters = 100300
9 | epochs_til_ckpt = 2
10 | steps_til_summary = 500
11 | patch_size = [64, 19, 50]
12 | w0 = 5
13 | steps_til_tiling = 500
14 | model_type = multiscale
15 | scale_init = 3
16 | logging_root = ../logs
17 | hidden_features = 2048
18 | hidden_layers = 8
19 |
--------------------------------------------------------------------------------
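Every key in these `.ini` files corresponds to a `--`-prefixed option in the training scripts, and `configargparse` reads bracketed values such as `res = [19456, 51200]` as lists for `nargs='+'` arguments. A minimal sketch of that round trip, using a reduced option set and a hypothetical config path:

```python
import configargparse

p = configargparse.ArgumentParser()
p.add('-c', '--config', required=False, is_config_file=True)
p.add_argument('--experiment_name', type=str)
p.add_argument('--lr', type=float, default=1e-3)
p.add_argument('--res', nargs='+', type=int, default=[512])
p.add_argument('--patch_size', nargs='+', type=int, default=[32])

# values from the .ini file act as defaults; the command line wins on conflict
opt = p.parse_args(['-c', 'config_img/config_mars.ini', '--lr', '5e-4'])
print(opt.experiment_name, opt.res, opt.patch_size, opt.lr)
```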
/experiment_scripts/config_img/config_pluto_acorn_1k.ini:
--------------------------------------------------------------------------------
1 | experiment_name = pluto_acorn_1K
2 | num_workers = 12
3 | res = [1024]
4 | dataset = pluto
5 | steps_til_tiling = 500
6 | max_patches = 512
7 | scale_init = 3
8 | skip_logging = False
9 | lr = 0.001
10 | num_iters = 100300
11 | epochs_til_ckpt = 2
12 | steps_til_summary = 500
13 | patch_size = [16, 8, 8]
14 | w0 = 5
15 | model_type = multiscale
16 | logging_root = ../logs
17 | hidden_features = 256
18 | hidden_layers = 4
19 |
--------------------------------------------------------------------------------
/experiment_scripts/config_img/config_pluto_acorn_4k.ini:
--------------------------------------------------------------------------------
1 | experiment_name = pluto_acorn_4K
2 | epochs_til_ckpt = 1
3 | num_workers = 8
4 | res = [4096]
5 | dataset = pluto
6 | steps_til_tiling = 500
7 | max_patches = 256
8 | skip_logging = false
9 | lr = 0.001
10 | num_iters = 100300
11 | steps_til_summary = 500
12 | patch_size = [16, 32, 32]
13 | w0 = 50
14 | model_type = multiscale
15 | scale_init = 3
16 | logging_root = ../logs
17 | hidden_features = 512
18 | hidden_layers = 4
19 |
--------------------------------------------------------------------------------
/experiment_scripts/config_img/config_pluto_acorn_8k.ini:
--------------------------------------------------------------------------------
1 | experiment_name = pluto_acorn_8K
2 | num_workers = 12
3 | res = 8192
4 | dataset = pluto
5 | max_patches = 1024
6 | skip_logging = false
7 | lr = 0.001
8 | num_iters = 100300
9 | epochs_til_ckpt = 2
10 | steps_til_summary = 500
11 | patch_size = [16, 32, 32]
12 | w0 = 5
13 | steps_til_tiling = 500
14 | model_type = multiscale
15 | scale_init = 3
16 | logging_root = ../logs
17 | hidden_features = 512
18 | hidden_layers = 4
19 |
--------------------------------------------------------------------------------
/experiment_scripts/config_img/config_pluto_pe_8k.ini:
--------------------------------------------------------------------------------
1 | experiment_name = pluto_pe_8K
2 | num_workers = 8
3 | res = 8192
4 | dataset = pluto
5 | model_type = pe
6 | scale_init = 4
7 | skip_logging = true
8 | lr = 0.001
9 | num_iters = 100300
10 | epochs_til_ckpt = 2
11 | steps_til_summary = 500
12 | patch_size = [0, 32, 32]
13 | w0 = 5
14 | steps_til_tiling = 500
15 | max_patches = 1024
16 | logging_root = ../logs
17 | hidden_features = 1536
18 | hidden_layers = 4
19 |
--------------------------------------------------------------------------------
/experiment_scripts/config_img/config_pluto_siren_8k.ini:
--------------------------------------------------------------------------------
1 | experiment_name = pluto_siren_8K
2 | lr = 1e-05
3 | num_workers = 8
4 | w0 = 50
5 | res = 8192
6 | dataset = pluto
7 | model_type = siren
8 | scale_init = 4
9 | skip_logging = true
10 | num_iters = 100300
11 | epochs_til_ckpt = 2
12 | steps_til_summary = 500
13 | patch_size = [0, 32, 32]
14 | steps_til_tiling = 500
15 | max_patches = 1024
16 | logging_root = ../logs
17 | hidden_features = 1536
18 | hidden_layers = 4
19 |
--------------------------------------------------------------------------------
/experiment_scripts/config_img/config_tokyo.ini:
--------------------------------------------------------------------------------
1 | experiment_name = tokyo_acorn_1G
2 | num_workers = 4
3 | res = [19456, 51200]
4 | dataset = tokyo
5 | max_patches = 16384
6 | skip_logging = true
7 | lr = 0.001
8 | num_iters = 100300
9 | epochs_til_ckpt = 2
10 | steps_til_summary = 500
11 | patch_size = [64, 19, 50]
12 | w0 = 5
13 | steps_til_tiling = 500
14 | model_type = multiscale
15 | scale_init = 3
16 | logging_root = ../logs
17 | hidden_features = 2048
18 | hidden_layers = 8
19 |
--------------------------------------------------------------------------------
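The gigapixel configs pick the spatial patch size so that the divisibility assertions in `Patch2DWrapperMultiscaleAdaptive` hold: the 19456 x 51200 image divides evenly into the same number of patches along each dimension. A quick sanity check mirroring those asserts:

```python
res = [19456, 51200]
patch_size = [19, 50]  # spatial part of patch_size = [64, 19, 50]

for r, p in zip(res, patch_size):
    assert r % p == 0, 'Resolution not divisible by patch size'

n_patches = [r // p for r, p in zip(res, patch_size)]
assert n_patches[0] == n_patches[1], 'number of patches must be same along each dim'
print(n_patches)  # [1024, 1024] patches along each dimension at the finest level
```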
/experiment_scripts/config_occupancy/config_dragon_acorn.ini:
--------------------------------------------------------------------------------
1 | experiment_name = dragon
2 | num_workers = 8
3 | lr = 0.001
4 | num_epochs = 10000
5 | epochs_til_pruning = 4
6 | epochs_til_ckpt = 4
7 | steps_til_summary = 4000
8 | octant_size = 4
9 | res = 512
10 | steps_til_tiling = 1000
11 | max_octants = 1024
12 | logging_root = ../logs
13 | feature_grid_size = [18, 12, 12, 12]
14 | pc_filepath = ../data/dragon.obj
15 | pruning_threshold = -10
16 | hidden_features = 512
17 | hidden_layers = 4
18 |
--------------------------------------------------------------------------------
/experiment_scripts/config_occupancy/config_engine_acorn.ini:
--------------------------------------------------------------------------------
1 | experiment_name = engine
2 | lr = 0.001
3 | num_workers = 8
4 | octant_size = 4
5 | steps_til_tiling = 1000
6 | pruning_threshold = -10.0
7 | num_epochs = 10000
8 | epochs_til_ckpt = 4
9 | steps_til_summary = 4000
10 | res = 512
11 | max_octants = 1024
12 | logging_root = ../logs
13 | feature_grid_size = [18, 12, 12, 12]
14 | pc_filepath = ../data/engine.obj
15 | epochs_til_pruning = 4
16 | hidden_features = 512
17 | hidden_layers = 4
18 |
--------------------------------------------------------------------------------
/experiment_scripts/config_occupancy/config_lucy_acorn.ini:
--------------------------------------------------------------------------------
1 | experiment_name = lucy
2 | num_workers = 6
3 | lr = 0.001
4 | num_epochs = 10000
5 | epochs_til_pruning = 4
6 | epochs_til_ckpt = 4
7 | steps_til_summary = 4000
8 | octant_size = 4
9 | res = 512
10 | steps_til_tiling = 1000
11 | max_octants = 1024
12 | logging_root = ../logs
13 | feature_grid_size = [18, 12, 12, 12]
14 | pc_filepath = ../data/lucy.ply
15 | pruning_threshold = -10
16 | hidden_features = 512
17 | hidden_layers = 4
18 |
--------------------------------------------------------------------------------
/experiment_scripts/config_occupancy/config_lucy_small_acorn.ini:
--------------------------------------------------------------------------------
1 | experiment_name = lucy_small
2 | num_workers = 12
3 | lr = 0.001
4 | num_epochs = 10000
5 | epochs_til_pruning = 4
6 | epochs_til_ckpt = 4
7 | steps_til_summary = 4000
8 | octant_size = 4
9 | res = 256
10 | steps_til_tiling = 1000
11 | max_octants = 256
12 | logging_root = ../logs
13 | feature_grid_size = [18, 12, 12, 12]
14 | pc_filepath = ../data/lucy.ply
15 | pruning_threshold = -10
16 | hidden_features = 256
17 | hidden_layers = 4
18 | scale_init = 2
19 |
--------------------------------------------------------------------------------
/experiment_scripts/config_occupancy/config_thai_acorn.ini:
--------------------------------------------------------------------------------
1 | experiment_name = thai_statue
2 | num_workers = 8
3 | lr = 0.001
4 | num_epochs = 10000
5 | epochs_til_ckpt = 4
6 | steps_til_summary = 4000
7 | octant_size = 4
8 | res = 512
9 | steps_til_tiling = 1000
10 | max_octants = 1024
11 | logging_root = ../logs
12 | feature_grid_size = [18, 12, 12, 12]
13 | pc_filepath = ../data/thai_statue.ply
14 | pruning_threshold = -10
15 | epochs_til_pruning = 4
16 | hidden_features = 512
17 | hidden_layers = 4
18 |
--------------------------------------------------------------------------------
/experiment_scripts/train_img.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 | import dataio
6 | import utils
7 | import training
8 | import loss_functions
9 | import pruning_functions
10 | import modules
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 | from functools import partial
14 | import numpy as np
15 | import skimage
16 | import cv2
17 | from tqdm import tqdm
18 | import matplotlib.pyplot as plt
19 | import re
20 | import torch
21 | from time import time
22 |
23 |
24 | p = configargparse.ArgumentParser()
25 | p.add('-c', '--config', required=False, is_config_file=True, help='Path to config file.')
26 |
27 | # General training options
28 | p.add_argument('--lr', type=float, default=1e-3, help='learning rate. default=1e-3')
29 | p.add_argument('--num_iters', type=int, default=100300,
30 | help='Number of iterations to train for.')
31 | p.add_argument('--num_workers', type=int, default=0,
32 | help='number of dataloader workers.')
33 | p.add_argument('--skip_logging', action='store_true', default=False,
34 | help="don't use summary function, only save loss and models")
35 | p.add_argument('--eval', action='store_true', default=False,
36 | help='run evaluation')
37 | p.add_argument('--resume', nargs=2, type=str, default=None,
38 | help='resume training, specify path to directory where model is stored and the iteration of ckpt.')
39 | p.add_argument('--gpu', type=int, default=0,
40 | help='GPU ID to use')
41 |
42 | # logging options
43 | p.add_argument('--experiment_name', type=str, required=True,
44 | help='path to directory where checkpoints & tensorboard events will be saved.')
45 | p.add_argument('--epochs_til_ckpt', type=int, default=2,
46 | help='Epochs until checkpoint is saved.')
47 | p.add_argument('--steps_til_summary', type=int, default=500,
48 | help='Number of iterations until tensorboard summary is saved.')
49 |
50 | # dataset options
51 | p.add_argument('--res', nargs='+', type=int, default=[512],
52 | help='image resolution.')
53 | p.add_argument('--dataset', type=str, default='camera', choices=['camera', 'pluto', 'tokyo', 'mars'],
54 | help='which dataset to use')
55 | p.add_argument('--grayscale', action='store_true', default=False,
56 | help='whether to use grayscale')
57 |
58 | # model options
59 | p.add_argument('--patch_size', nargs='+', type=int, default=[32],
60 | help='patch size.')
61 | p.add_argument('--hidden_features', type=int, default=512,
62 | help='hidden features in network')
63 | p.add_argument('--hidden_layers', type=int, default=4,
64 | help='hidden layers in network')
65 | p.add_argument('--w0', type=int, default=5,
66 | help='w0 for the siren model.')
67 | p.add_argument('--steps_til_tiling', type=int, default=500,
68 | help='How often to recompute the tiling, also defines number of steps per epoch.')
69 | p.add_argument('--max_patches', type=int, default=1024,
70 | help='maximum number of patches in the optimization')
71 | p.add_argument('--model_type', type=str, default='multiscale', required=False, choices=['multiscale', 'siren', 'pe'],
72 | help='Type of model to evaluate, default is multiscale.')
73 | p.add_argument('--scale_init', type=int, default=3,
74 | help='which scale to initialize active patches in the quadtree')
75 |
76 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
77 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging')
78 | opt = p.parse_args()
79 |
80 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
81 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
82 |
83 | for k, v in opt.__dict__.items():
84 | print(k, v)
85 |
86 |
87 | def main():
88 | if opt.dataset == 'camera':
89 | img_dataset = dataio.Camera()
90 | elif opt.dataset == 'pluto':
91 | pluto_url = "https://upload.wikimedia.org/wikipedia/commons/e/ef/Pluto_in_True_Color_-_High-Res.jpg"
92 | img_dataset = dataio.ImageFile('../data/pluto.jpg', url=pluto_url, grayscale=opt.grayscale)
93 | elif opt.dataset == 'tokyo':
94 | img_dataset = dataio.ImageFile('../data/tokyo.tif', grayscale=opt.grayscale)
95 | elif opt.dataset == 'mars':
96 | img_dataset = dataio.ImageFile('../data/mars.tif', grayscale=opt.grayscale)
97 |
98 | if len(opt.patch_size) == 1:
99 | opt.patch_size = 3*opt.patch_size
100 |
101 | # set up dataset
102 | coord_dataset = dataio.Patch2DWrapperMultiscaleAdaptive(img_dataset,
103 | sidelength=opt.res,
104 | patch_size=opt.patch_size[1:], jitter=True,
105 | num_workers=opt.num_workers, length=opt.steps_til_tiling,
106 | scale_init=opt.scale_init, max_patches=opt.max_patches)
107 |
108 | opt.num_epochs = opt.num_iters // coord_dataset.__len__()
109 |
110 | image_resolution = (opt.res, opt.res)
111 |
112 | dataloader = DataLoader(coord_dataset, shuffle=False, batch_size=1, pin_memory=True,
113 | num_workers=opt.num_workers)
114 |
115 | if opt.resume is not None:
116 |         path, ckpt_iter = opt.resume
117 |         ckpt_iter = int(ckpt_iter)
118 | assert(os.path.isdir(path))
119 | assert opt.config is not None, 'Specify config file'
120 |
121 | # Define the model.
122 | if opt.grayscale:
123 | out_features = 1
124 | else:
125 | out_features = 3
126 |
127 | if opt.model_type == 'multiscale':
128 | model = modules.ImplicitAdaptivePatchNet(in_features=3, out_features=out_features,
129 | num_hidden_layers=opt.hidden_layers,
130 | hidden_features=opt.hidden_features,
131 | feature_grid_size=(opt.patch_size[0], opt.patch_size[1], opt.patch_size[2]),
132 | sidelength=opt.res,
133 | num_encoding_functions=10,
134 | patch_size=opt.patch_size[1:])
135 |
136 | elif opt.model_type == 'siren':
137 | model = modules.ImplicitNet(opt.res, in_features=2,
138 | out_features=out_features,
139 | num_hidden_layers=4,
140 | hidden_features=1536,
141 | mode='siren', w0=opt.w0)
142 | elif opt.model_type == 'pe':
143 | model = modules.ImplicitNet(opt.res, in_features=2,
144 | out_features=out_features,
145 | num_hidden_layers=4,
146 | hidden_features=1536,
147 | mode='pe')
148 | else:
149 | raise NotImplementedError('Only model types multiscale, siren, and pe are implemented')
150 |
151 | model.cuda()
152 |
153 | # print number of model parameters
154 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
155 | params = sum([np.prod(p.size()) for p in model_parameters])
156 | print(f'Num. Parameters: {params}')
157 |
158 | # Define the loss
159 | loss_fn = partial(loss_functions.image_mse,
160 | tiling_every=opt.steps_til_tiling,
161 | dataset=coord_dataset,
162 | model_type=opt.model_type)
163 | summary_fn = partial(utils.write_image_patch_multiscale_summary, image_resolution, opt.patch_size[1:], coord_dataset, model_type=opt.model_type, skip=opt.skip_logging)
164 |
165 | # Define the pruning function
166 | pruning_fn = partial(pruning_functions.no_pruning,
167 | pruning_every=1)
168 |
169 | # if we are resuming from a saved checkpoint
170 | if opt.resume is not None:
171 | print('Loading checkpoints')
172 |         model_dict = torch.load(path + '/checkpoints/' + f'model_{ckpt_iter:06d}.pth')
173 | model.load_state_dict(model_dict)
174 |
175 | # load optimizers
176 | try:
177 | resume_checkpoint = {}
178 |             optim_dict = torch.load(path + '/checkpoints/' + f'optim_{ckpt_iter:06d}.pth')
179 | for g in optim_dict['optimizer_state_dict']['param_groups']:
180 | g['lr'] = opt.lr
181 | resume_checkpoint['optimizer_state_dict'] = optim_dict['optimizer_state_dict']
182 | resume_checkpoint['total_steps'] = optim_dict['total_steps']
183 | resume_checkpoint['epoch'] = optim_dict['epoch']
184 |
185 | # initialize model state_dict
186 | print('Initializing models')
187 | coord_dataset.quadtree.__load__(optim_dict['quadtree'])
188 | coord_dataset.synchronize()
189 |
190 | except FileNotFoundError:
191 | print('Unable to load optimizer checkpoints')
192 | else:
193 | resume_checkpoint = {}
194 |
195 | if opt.eval:
196 | run_eval(model, coord_dataset)
197 | else:
198 | # Save command-line parameters log directory.
199 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
200 | utils.cond_mkdir(root_path)
201 | p.write_config_file(opt, [os.path.join(root_path, 'config.ini')])
202 |
203 | # Save text summary of model into log directory.
204 | with open(os.path.join(root_path, "model.txt"), "w") as out_file:
205 | out_file.write(str(model))
206 |
207 | objs_to_save = {'quadtree': coord_dataset.quadtree}
208 |
209 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
210 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
211 | model_dir=root_path, loss_fn=loss_fn, pruning_fn=pruning_fn, summary_fn=summary_fn, objs_to_save=objs_to_save,
212 | resume_checkpoint=resume_checkpoint)
213 |
214 |
215 | # evaluate PSNR at saved checkpoints and save model outputs
216 | def run_eval(model, coord_dataset):
217 | # get checkpoint directory
218 | checkpoint_dir = os.path.join(os.path.dirname(opt.config), 'checkpoints')
219 |
220 | # make eval directory
221 | eval_dir = os.path.join(os.path.dirname(opt.config), 'eval')
222 | utils.cond_mkdir(eval_dir)
223 |
224 | # get model & optim files
225 | model_files = sorted([f for f in os.listdir(checkpoint_dir) if re.search(r'model_[0-9]+.pth', f)], reverse=True)
226 | optim_files = sorted([f for f in os.listdir(checkpoint_dir) if re.search(r'optim_[0-9]+.pth', f)], reverse=True)
227 |
228 | # extract iterations
229 | iters = [int(re.search(r'[0-9]+', f)[0]) for f in model_files]
230 |
231 | # append beginning of path
232 | model_files = [os.path.join(checkpoint_dir, f) for f in model_files]
233 | optim_files = [os.path.join(checkpoint_dir, f) for f in optim_files]
234 |
235 | # iterate through model and optim files
236 | metrics = {}
237 | saved_gt = False
238 | for curr_iter, model_path, optim_path in zip(tqdm(iters), model_files, optim_files):
239 |
240 | # load model and optimizer files
241 | print('Loading models')
242 | model_dict = torch.load(model_path)
243 | optim_dict = torch.load(optim_path)
244 |
245 | # initialize model state_dict
246 | print('Initializing models')
247 | model.load_state_dict(model_dict)
248 | coord_dataset.quadtree.__load__(optim_dict['quadtree'])
249 | coord_dataset.synchronize()
250 |
251 | # save image and calculate psnr
252 | coord_dataset.toggle_eval()
253 | model_input, gt = coord_dataset[0]
254 | coord_dataset.toggle_eval()
255 |
256 |         # add a batch dimension (tensors stay on the CPU here)
257 | tmp = {}
258 | for key, value in model_input.items():
259 | if isinstance(value, torch.Tensor):
260 | tmp.update({key: value[None, ...].cpu()})
261 | else:
262 | tmp.update({key: value})
263 | model_input = tmp
264 |
265 | tmp = {}
266 | for key, value in gt.items():
267 | if isinstance(value, torch.Tensor):
268 | tmp.update({key: value[None, ...].cpu()})
269 | else:
270 | tmp.update({key: value})
271 | gt = tmp
272 |
273 | # run the model on uniform samples
274 | print('Running forward pass')
275 | n_channels = gt['img'].shape[-1]
276 | start = time()
277 | with torch.no_grad():
278 | pred_img = utils.process_batch_in_chunks(model_input, model, max_chunk_size=512)['model_out']['output']
279 | torch.cuda.synchronize()
280 | print(f'Model: {time() - start:.02f}')
281 |
282 | # get pixel idx for each coordinate
283 | start = time()
284 | coords = model_input['fine_abs_coords'].detach().cpu().numpy()
285 | pixel_idx = np.zeros_like(coords).astype(np.int32)
286 | pixel_idx[..., 0] = np.round((coords[..., 0] + 1.)/2. * (coord_dataset.sidelength[0]-1)).astype(np.int32)
287 | pixel_idx[..., 1] = np.round((coords[..., 1] + 1.)/2. * (coord_dataset.sidelength[1]-1)).astype(np.int32)
288 | pixel_idx = pixel_idx.reshape(-1, 2)
289 |
290 | # assign predicted image values into a new array
291 | # need to use numpy since it supports index assignment
292 | pred_img = pred_img.detach().cpu().numpy().reshape(-1, n_channels)
293 | display_pred = np.zeros((*coord_dataset.sidelength, n_channels))
294 | display_pred[[pixel_idx[:, 0]], [pixel_idx[:, 1]]] = pred_img
295 | display_pred = torch.tensor(display_pred)[None, ...]
296 | display_pred = display_pred.permute(0, 3, 1, 2)
297 |
298 | if not saved_gt:
299 | gt_img = gt['img'].detach().cpu().numpy().reshape(-1, n_channels)
300 | display_gt = np.zeros((*coord_dataset.sidelength, n_channels))
301 | display_gt[[pixel_idx[:, 0]], [pixel_idx[:, 1]]] = gt_img
302 | display_gt = torch.tensor(display_gt)[None, ...]
303 | display_gt = display_gt.permute(0, 3, 1, 2)
304 | print(f'Reshape: {time() - start:.02f}')
305 |
306 | # record metrics
307 | start = time()
308 | psnr, ssim = get_metrics(display_pred, display_gt)
309 | metrics.update({curr_iter: {'psnr': psnr, 'ssim': ssim}})
310 | print(f'Metrics: {time() - start:.02f}')
311 | print(f'Iter: {curr_iter}, PSNR: {psnr:.02f}')
312 |
313 | # save images
314 | pred_out = np.clip((display_pred.squeeze().numpy()/2.) + 0.5, a_min=0., a_max=1.).transpose(1, 2, 0)*255
315 | pred_out = pred_out.astype(np.uint8)
316 | pred_fname = os.path.join(eval_dir, f'pred_{curr_iter:06d}.png')
317 | print('Saving image')
318 | cv2.imwrite(pred_fname, cv2.cvtColor(pred_out, cv2.COLOR_RGB2BGR))
319 |
320 | if not saved_gt:
321 | print('Saving gt')
322 | gt_out = np.clip((display_gt.squeeze().numpy()/2.) + 0.5, a_min=0., a_max=1.).transpose(1, 2, 0)*255
323 | gt_out = gt_out.astype(np.uint8)
324 | gt_fname = os.path.join(eval_dir, 'gt.png')
325 | cv2.imwrite(gt_fname, cv2.cvtColor(gt_out, cv2.COLOR_RGB2BGR))
326 | saved_gt = True
327 |
328 | # save tiling
329 | tiling_fname = os.path.join(eval_dir, f'tiling_{curr_iter:06d}.pdf')
330 | coord_dataset.quadtree.draw()
331 | plt.savefig(tiling_fname)
332 |
333 | # save metrics
334 | metric_fname = os.path.join(eval_dir, f'metrics_{curr_iter:06d}.npy')
335 | np.save(metric_fname, metrics)
336 |
337 |
338 | def get_metrics(pred_img, gt_img):
339 | pred_img = pred_img.detach().cpu().numpy().squeeze()
340 | gt_img = gt_img.detach().cpu().numpy().squeeze()
341 |
342 | p = pred_img.transpose(1, 2, 0)
343 | trgt = gt_img.transpose(1, 2, 0)
344 |
345 | p = (p / 2.) + 0.5
346 | p = np.clip(p, a_min=0., a_max=1.)
347 |
348 | trgt = (trgt / 2.) + 0.5
349 |
350 | ssim = skimage.metrics.structural_similarity(p, trgt, multichannel=True, data_range=1)
351 | psnr = skimage.metrics.peak_signal_noise_ratio(p, trgt, data_range=1)
352 |
353 | return psnr, ssim
354 |
355 |
356 | if __name__ == '__main__':
357 | main()
358 |
--------------------------------------------------------------------------------
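`get_metrics` above maps model outputs from the [-1, 1] training range back to [0, 1] before scoring with scikit-image. Below is a self-contained sketch of the same PSNR/SSIM calls on synthetic data; the image size and noise level are arbitrary, chosen only to produce finite scores.

```python
import numpy as np
import skimage.metrics

rng = np.random.default_rng(0)
gt = rng.random((64, 64, 3))                    # ground truth in [0, 1]
pred = np.clip(gt + 0.05 * rng.standard_normal(gt.shape), 0., 1.)

psnr = skimage.metrics.peak_signal_noise_ratio(gt, pred, data_range=1)
ssim = skimage.metrics.structural_similarity(gt, pred, multichannel=True, data_range=1)
print(f'PSNR: {psnr:.02f} dB, SSIM: {ssim:.03f}')
```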
/experiment_scripts/train_occupancy.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 | import dataio
6 | import utils
7 | import training
8 | import loss_functions
9 | import pruning_functions
10 | import modules
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 | from functools import partial
14 | import torch
15 | import re
16 | import mcubes
17 |
18 |
19 | p = configargparse.ArgumentParser()
20 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
21 |
22 | # General training options
23 | p.add_argument('--lr', type=float, default=1e-3, help='learning rate. default=1e-3')
24 | p.add_argument('--num_epochs', type=int, default=10000,
25 | help='Number of epochs to train for.')
26 | p.add_argument('--num_workers', type=int, default=0,
27 | help='number of dataloader workers.')
28 | p.add_argument('--pc_filepath', type=str, default='', help='path to the input point cloud or mesh')
29 | p.add_argument('--load', type=str, default=None, help='logging directory to resume from')
30 | p.add_argument('--gpu', type=int, default=0,
31 | help='GPU ID to use')
32 |
33 | # logging options
34 | p.add_argument('--experiment_name', type=str, required=True,
35 | help='path to directory where checkpoints & tensorboard events will be saved.')
36 | p.add_argument('--skip_logging', action='store_true', default=False,
37 | help="don't use summary function, only save loss and models")
38 | p.add_argument('--epochs_til_ckpt', type=int, default=4,
39 |                help='Number of epochs between checkpoints.')
40 | p.add_argument('--steps_til_summary', type=int, default=500,
41 |                help='Number of iterations between tensorboard summaries.')
42 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
43 | p.add_argument('--logging_root', type=str, default='../logs', help='root for logging')
44 |
45 | # dataset options
46 | p.add_argument('--res', type=int, default=512,
47 |                help='volume resolution.')
48 | p.add_argument('--octant_size', type=int, default=4,
49 |                help='octant size.')
50 | p.add_argument('--max_octants', type=int, default=1024)
51 | p.add_argument('--scale_init', type=int, default=3,
52 | help='which scale to initialize active octants in the octree')
53 |
54 | # model options
55 | p.add_argument('--steps_til_tiling', type=int, default=1000,
56 | help='how often to recompute the tiling')
57 | p.add_argument('--hidden_features', type=int, default=512,
58 | help='hidden features in network')
59 | p.add_argument('--hidden_layers', type=int, default=4,
60 | help='hidden layers in network')
61 | p.add_argument('--feature_grid_size', nargs='+', type=int, default=(18, 12, 12, 12))
62 | p.add_argument('--pruning_threshold', type=float, default=-10)
63 | p.add_argument('--epochs_til_pruning', type=int, default=4)
64 |
65 | # export options
66 | p.add_argument('--export', action='store_true', default=False,
67 | help='export mesh from checkpoint (requires load flag)')
68 | p.add_argument('--upsample', type=int, default=2,
69 |                help='how much to upsample the occupancy grid used to generate the output mesh')
70 | p.add_argument('--mc_threshold', type=float, default=0.005, help='threshold for marching cubes')
71 |
72 |
73 | opt = p.parse_args()
74 |
75 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
76 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
77 |
78 | feature_grid_size = tuple(opt.feature_grid_size)
79 |
80 | assert opt.pc_filepath, "Must specify dataset input"
81 |
82 | for k, v in opt.__dict__.items():
83 | print(k, v)
84 | print()
85 |
86 |
87 | def main():
88 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
89 | utils.cond_mkdir(root_path)
90 |
91 | point_cloud_dataset = dataio.OccupancyDataset(opt.pc_filepath)
92 |
93 | coord_dataset = dataio.Block3DWrapperMultiscaleAdaptive(point_cloud_dataset,
94 | sidelength=opt.res,
95 | octant_size=opt.octant_size,
96 | jitter=True, max_octants=opt.max_octants,
97 | num_workers=opt.num_workers,
98 | length=opt.steps_til_tiling,
99 | scale_init=opt.scale_init)
100 |
101 | model = modules.ImplicitAdaptiveOctantNet(in_features=3+1, out_features=1,
102 | num_hidden_layers=opt.hidden_layers,
103 | hidden_features=opt.hidden_features,
104 | feature_grid_size=feature_grid_size,
105 | octant_size=opt.octant_size)
106 | model.cuda()
107 |
108 | resume_checkpoint = {}
109 | if opt.load is not None:
110 | resume_checkpoint = load_from_checkpoint(opt.load, model, coord_dataset)
111 |
112 | if opt.export:
113 | assert opt.load is not None, 'Need to specify which model to export with --load'
114 |
115 | export_mesh(model, coord_dataset, opt.upsample, opt.mc_threshold)
116 | return
117 |
118 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
119 | print(f"\n\nTrainable Parameters: {num_params}\n\n")
120 |
121 | dataloader = DataLoader(coord_dataset, shuffle=False, batch_size=1, pin_memory=True,
122 | num_workers=opt.num_workers)
123 |
124 | # Define the loss
125 | loss_fn = partial(loss_functions.occupancy_bce,
126 | tiling_every=opt.steps_til_tiling,
127 | dataset=coord_dataset)
128 |
129 | summary_fn = partial(utils.write_occupancy_multiscale_summary, (opt.res, opt.res, opt.res),
130 | coord_dataset, output_mrc=f'{opt.experiment_name}.mrc',
131 | skip=opt.skip_logging)
132 |
133 | # Define the pruning
134 | pruning_fn = partial(pruning_functions.pruning_occupancy,
135 | threshold=opt.pruning_threshold)
136 |
137 |     # Save command-line parameters to the log directory.
138 | p.write_config_file(opt, [os.path.join(root_path, 'config.ini')])
139 |
140 | # Save text summary of model into log directory.
141 | with open(os.path.join(root_path, "model.txt"), "w") as out_file:
142 | out_file.write(str(model))
143 |
144 | objs_to_save = {'octtree': coord_dataset.octtree}
145 |
146 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
147 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
148 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, objs_to_save=objs_to_save,
149 | pruning_fn=pruning_fn, epochs_til_pruning=opt.epochs_til_pruning,
150 | resume_checkpoint=resume_checkpoint)
151 |
152 |
153 | def load_from_checkpoint(experiment_dir, model, coord_dataset):
154 | checkpoint_dir = os.path.join(experiment_dir, 'checkpoints')
155 |     model_files = sorted([f for f in os.listdir(checkpoint_dir) if re.search(r'model_[0-9]+\.pth', f)])
156 |     optim_files = sorted([f for f in os.listdir(checkpoint_dir) if re.search(r'optim_[0-9]+\.pth', f)])
157 |
158 |     # prepend the checkpoint directory to each filename
159 | model_files = [os.path.join(checkpoint_dir, f) for f in model_files]
160 | optim_files = [os.path.join(checkpoint_dir, f) for f in optim_files]
161 | model_path = model_files[-1]
162 | optim_path = optim_files[-1]
163 |
164 | print("MODEL PATH: ", model_path)
165 |
166 | # load model and octree
167 | print('Loading models')
168 | model_dict = torch.load(model_path)
169 | optim_dict = torch.load(optim_path)
170 |
171 | # initialize model state_dict
172 | print('Initializing models')
173 | model.load_state_dict(model_dict)
174 | coord_dataset.octtree.__load__(optim_dict['octtree'])
175 | coord_dataset.synchronize()
176 |
177 | resume_checkpoint = {}
178 | resume_checkpoint['optimizer_state_dict'] = optim_dict['optimizer_state_dict']
179 | resume_checkpoint['total_steps'] = optim_dict['total_steps']
180 | resume_checkpoint['epoch'] = optim_dict['epoch']
181 |
182 | return resume_checkpoint
183 |
184 |
185 | def export_mesh(model, dataset, upsample, mcubes_threshold=0.005):
186 | res = 3*(upsample*opt.res,)
187 | model.octant_size = model.octant_size * upsample
188 |
189 | print('Export: calculating occupancy...')
190 | mrc_fname = os.path.join(opt.logging_root, opt.experiment_name, f"{opt.experiment_name}.mrc")
191 | occupancy = utils.write_occupancy_multiscale_summary(res, dataset, model,
192 | None, None, None, None, None,
193 | output_mrc=mrc_fname,
194 | oversample=upsample,
195 | mode='hq')
196 |
197 | print('Export: running marching cubes...')
198 | vertices, faces = mcubes.marching_cubes(occupancy, mcubes_threshold)
199 |
200 | print('Export: exporting mesh...')
201 | out_fname = os.path.join(opt.logging_root, opt.experiment_name, f"{opt.experiment_name}.dae")
202 | mcubes.export_mesh(vertices, faces, out_fname)
203 |
204 |
205 | if __name__ == '__main__':
206 | main()
207 |
--------------------------------------------------------------------------------
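The export path above hands a dense occupancy volume to PyMCubes. A minimal, self-contained sketch of those two calls (the sphere volume here is a hypothetical stand-in for the grid produced by `utils.write_occupancy_multiscale_summary`):

```python
import numpy as np
import mcubes  # PyMCubes

# Hypothetical stand-in volume: a 64^3 grid containing a solid sphere.
N = 64
x, y, z = np.mgrid[:N, :N, :N]
occupancy = ((x - N/2)**2 + (y - N/2)**2 + (z - N/2)**2 < (N/4)**2).astype(np.float32)

# Same calls as export_mesh() above: threshold the volume, then write a COLLADA mesh.
vertices, faces = mcubes.marching_cubes(occupancy, 0.5)
mcubes.export_mesh(vertices, faces, "sphere.dae")
```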
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/ACORN/67817daaf20d38840f043e81704e7a6ea5da584e/img/teaser.png
--------------------------------------------------------------------------------
/inside_mesh/inside_mesh.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright 2019 Lars Mescheder, Michael Oechsle, Michael Niemeyer, Andreas Geiger, Sebastian Nowozin
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5 |
6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7 |
8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9 |
10 | https://github.com/autonomousvision/occupancy_networks/tree/ddb2908f96de9c0c5a30c093f2a701878ffc1f4a/im2mesh/utils/libmesh
11 | '''
12 |
13 |
14 | import numpy as np
15 | from .triangle_hash import TriangleHash as _TriangleHash
16 |
17 |
18 | def check_mesh_contains(mesh, points, hash_resolution=512):
19 | intersector = MeshIntersector(mesh, hash_resolution)
20 | contains = intersector.query(points)
21 | return contains
22 |
23 |
24 | class MeshIntersector:
25 | def __init__(self, mesh, resolution=512):
26 | triangles = mesh.vertices[mesh.faces].astype(np.float64)
27 | n_tri = triangles.shape[0]
28 |
29 | self.resolution = resolution
30 | self.bbox_min = triangles.reshape(3 * n_tri, 3).min(axis=0)
31 | self.bbox_max = triangles.reshape(3 * n_tri, 3).max(axis=0)
32 |         # Translate and scale it to [0.5, self.resolution - 0.5]^3
33 | self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min)
34 | self.translate = 0.5 - self.scale * self.bbox_min
35 |
36 | self._triangles = triangles = self.rescale(triangles)
37 | # assert(np.allclose(triangles.reshape(-1, 3).min(0), 0.5))
38 | # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5))
39 |
40 | triangles2d = triangles[:, :, :2]
41 | self._tri_intersector2d = TriangleIntersector2d(
42 | triangles2d, resolution)
43 |
44 | def query(self, points):
45 | # Rescale points
46 | points = self.rescale(points)
47 |
48 | # placeholder result with no hits we'll fill in later
49 |         contains = np.zeros(len(points), dtype=bool)
50 |
51 | # cull points outside of the axis aligned bounding box
52 | # this avoids running ray tests unless points are close
53 | inside_aabb = np.all(
54 | (0 <= points) & (points <= self.resolution), axis=1)
55 | if not inside_aabb.any():
56 | return contains
57 |
58 | # Only consider points inside bounding box
59 | mask = inside_aabb
60 | points = points[mask]
61 |
62 | # Compute intersection depth and check order
63 | points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2])
64 |
65 | triangles_intersect = self._triangles[tri_indices]
66 | points_intersect = points[points_indices]
67 |
68 | depth_intersect, abs_n_2 = self.compute_intersection_depth(
69 | points_intersect, triangles_intersect)
70 |
71 | # Count number of intersections in both directions
72 | smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
73 | bigger_depth = depth_intersect < points_intersect[:, 2] * abs_n_2
74 | points_indices_0 = points_indices[smaller_depth]
75 | points_indices_1 = points_indices[bigger_depth]
76 |
77 | nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0])
78 | nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0])
79 |
80 | # Check if point contained in mesh
81 | contains1 = (np.mod(nintersect0, 2) == 1)
82 | contains2 = (np.mod(nintersect1, 2) == 1)
83 | # if (contains1 != contains2).any():
84 | # print('Warning: contains1 != contains2 for some points.')
85 | contains[mask] = (contains1 & contains2)
86 | return contains
87 |
88 | def compute_intersection_depth(self, points, triangles):
89 | t1 = triangles[:, 0, :]
90 | t2 = triangles[:, 1, :]
91 | t3 = triangles[:, 2, :]
92 |
93 | v1 = t3 - t1
94 | v2 = t2 - t1
95 | # v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
96 | # v2 = v2 / np.linalg.norm(v2, axis=-1, keepdims=True)
97 |
98 | normals = np.cross(v1, v2)
99 | alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1)
100 |
101 | n_2 = normals[:, 2]
102 | t1_2 = t1[:, 2]
103 | s_n_2 = np.sign(n_2)
104 | abs_n_2 = np.abs(n_2)
105 |
106 | mask = (abs_n_2 != 0)
107 |
108 | depth_intersect = np.full(points.shape[0], np.nan)
109 | depth_intersect[mask] = \
110 | t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask]
111 |
112 | # Test the depth:
113 | # TODO: remove and put into tests
114 | # points_new = np.concatenate([points[:, :2], depth_intersect[:, None]], axis=1)
115 | # alpha = (normals * t1).sum(-1)
116 | # mask = (depth_intersect == depth_intersect)
117 | # assert(np.allclose((points_new[mask] * normals[mask]).sum(-1),
118 | # alpha[mask]))
119 | return depth_intersect, abs_n_2
120 |
121 | def rescale(self, array):
122 | array = self.scale * array + self.translate
123 | return array
124 |
125 |
126 | class TriangleIntersector2d:
127 | def __init__(self, triangles, resolution=128):
128 | self.triangles = triangles
129 | self.tri_hash = _TriangleHash(triangles, resolution)
130 |
131 | def query(self, points):
132 | point_indices, tri_indices = self.tri_hash.query(points)
133 | point_indices = np.array(point_indices, dtype=np.int64)
134 | tri_indices = np.array(tri_indices, dtype=np.int64)
135 | points = points[point_indices]
136 | triangles = self.triangles[tri_indices]
137 | mask = self.check_triangles(points, triangles)
138 | point_indices = point_indices[mask]
139 | tri_indices = tri_indices[mask]
140 | return point_indices, tri_indices
141 |
142 | def check_triangles(self, points, triangles):
143 |         contains = np.zeros(points.shape[0], dtype=bool)
144 | A = triangles[:, :2] - triangles[:, 2:]
145 | A = A.transpose([0, 2, 1])
146 | y = points - triangles[:, 2]
147 |
148 | detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0]
149 |
150 | mask = (np.abs(detA) != 0.)
151 | A = A[mask]
152 | y = y[mask]
153 | detA = detA[mask]
154 |
155 | s_detA = np.sign(detA)
156 | abs_detA = np.abs(detA)
157 |
158 | u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA
159 | v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA
160 |
161 | sum_uv = u + v
162 | contains[mask] = (
163 | (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA)
164 | & (0 < sum_uv) & (sum_uv < abs_detA)
165 | )
166 | return contains
167 |
--------------------------------------------------------------------------------
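A minimal usage sketch for the module above (assuming the Cython extension has been built, see `setup.py` next; the mesh path is hypothetical): load a watertight mesh with trimesh and test which random points fall inside it.

```python
import numpy as np
import trimesh
from inside_mesh.inside_mesh import check_mesh_contains

mesh = trimesh.load("path/to/watertight_mesh.ply", force='mesh')  # hypothetical path
points = np.random.uniform(mesh.bounds[0], mesh.bounds[1], size=(1000, 3))
inside = check_mesh_contains(mesh, points)  # boolean array, one entry per query point
print(f"{inside.sum()} of {len(points)} points are inside the mesh")
```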
/inside_mesh/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | import numpy
4 |
5 | setup(
6 | ext_modules = cythonize("triangle_hash.pyx"),
7 | include_dirs = [numpy.get_include()]
8 | )
--------------------------------------------------------------------------------
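Usage note: running `python setup.py build_ext --inplace` from the `inside_mesh` directory compiles `triangle_hash.pyx` into the `triangle_hash` extension module imported by `inside_mesh.py` and `metrics.py`.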
/inside_mesh/triangle_hash.pyx:
--------------------------------------------------------------------------------
1 | # distutils: language=c++
2 | '''
3 | Copyright 2019 Lars Mescheder, Michael Oechsle, Michael Niemeyer, Andreas Geiger, Sebastian Nowozin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | https://github.com/autonomousvision/occupancy_networks/tree/ddb2908f96de9c0c5a30c093f2a701878ffc1f4a/im2mesh/utils/libmesh
12 | '''
13 |
14 |
15 | import numpy as np
16 | cimport numpy as np
17 | cimport cython
18 | from libcpp.vector cimport vector
19 | from libc.math cimport floor, ceil
20 |
21 | cdef class TriangleHash:
22 | cdef vector[vector[int]] spatial_hash
23 | cdef int resolution
24 |
25 | def __cinit__(self, double[:, :, :] triangles, int resolution):
26 | self.spatial_hash.resize(resolution * resolution)
27 | self.resolution = resolution
28 | self._build_hash(triangles)
29 |
30 | @cython.boundscheck(False) # Deactivate bounds checking
31 | @cython.wraparound(False) # Deactivate negative indexing.
32 | cdef int _build_hash(self, double[:, :, :] triangles):
33 | assert(triangles.shape[1] == 3)
34 | assert(triangles.shape[2] == 2)
35 |
36 | cdef int n_tri = triangles.shape[0]
37 | cdef int bbox_min[2]
38 | cdef int bbox_max[2]
39 |
40 | cdef int i_tri, j, x, y
41 | cdef int spatial_idx
42 |
43 | for i_tri in range(n_tri):
44 | # Compute bounding box
45 | for j in range(2):
46 | bbox_min[j] = min(
47 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
48 | )
49 | bbox_max[j] = max(
50 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
51 | )
52 | bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1)
53 | bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1)
54 |
55 | # Find all voxels where bounding box intersects
56 | for x in range(bbox_min[0], bbox_max[0] + 1):
57 | for y in range(bbox_min[1], bbox_max[1] + 1):
58 | spatial_idx = self.resolution * x + y
59 | self.spatial_hash[spatial_idx].push_back(i_tri)
60 |
61 | @cython.boundscheck(False) # Deactivate bounds checking
62 | @cython.wraparound(False) # Deactivate negative indexing.
63 | cpdef query(self, double[:, :] points):
64 | assert(points.shape[1] == 2)
65 | cdef int n_points = points.shape[0]
66 |
67 | cdef vector[int] points_indices
68 | cdef vector[int] tri_indices
69 | # cdef int[:] points_indices_np
70 | # cdef int[:] tri_indices_np
71 |
72 | cdef int i_point, k, x, y
73 | cdef int spatial_idx
74 |
75 | for i_point in range(n_points):
76 | x = int(points[i_point, 0])
77 | y = int(points[i_point, 1])
78 | if not (0 <= x < self.resolution and 0 <= y < self.resolution):
79 | continue
80 |
81 | spatial_idx = self.resolution * x + y
82 | for i_tri in self.spatial_hash[spatial_idx]:
83 | points_indices.push_back(i_point)
84 | tri_indices.push_back(i_tri)
85 |
86 | points_indices_np = np.zeros(points_indices.size(), dtype=np.int32)
87 | tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32)
88 |
89 | cdef int[:] points_indices_view = points_indices_np
90 | cdef int[:] tri_indices_view = tri_indices_np
91 |
92 | for k in range(points_indices.size()):
93 | points_indices_view[k] = points_indices[k]
94 |
95 | for k in range(tri_indices.size()):
96 | tri_indices_view[k] = tri_indices[k]
97 |
98 | return points_indices_np, tri_indices_np
99 |
--------------------------------------------------------------------------------
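To make the indexing scheme above concrete, here is a pure-Python sketch of the same 2D spatial hash (illustration only; the compiled Cython version is what the code actually uses). Each triangle is registered in every grid cell its xy bounding box touches, so a query point only has to be tested against the triangles in its own cell.

```python
import numpy as np

def build_hash(triangles2d, resolution):
    # triangles2d: (n_tri, 3, 2) array of 2D triangle vertices
    cells = [[] for _ in range(resolution * resolution)]
    for i, tri in enumerate(triangles2d):
        lo = np.clip(np.floor(tri.min(axis=0)).astype(int), 0, resolution - 1)
        hi = np.clip(np.floor(tri.max(axis=0)).astype(int), 0, resolution - 1)
        for x in range(lo[0], hi[0] + 1):       # all cells the bbox touches
            for y in range(lo[1], hi[1] + 1):
                cells[resolution * x + y].append(i)
    return cells

def query(cells, resolution, points2d):
    # return candidate (point, triangle) index pairs, as in TriangleHash.query
    point_idx, tri_idx = [], []
    for j, (x, y) in enumerate(np.floor(points2d).astype(int)):
        if 0 <= x < resolution and 0 <= y < resolution:
            for i in cells[resolution * x + y]:
                point_idx.append(j)
                tri_idx.append(i)
    return np.array(point_idx), np.array(tri_idx)
```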
/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def image_mse(model_output, gt, step, tiling_every=100, dataset=None, model_type='multiscale', retile=True):
5 | img_loss = (model_output['model_out']['output'] - gt['img'])**2
6 |
7 | if model_type == 'multiscale':
8 | per_patch_loss = torch.mean(img_loss, dim=(-1, -2)).squeeze(0).detach().cpu().numpy()
9 |
10 | dataset.update_patch_err(per_patch_loss, step)
11 | if step % tiling_every == tiling_every-1 and retile:
12 | tiling_stats = dataset.update_tiling()
13 | if tiling_stats['merged'] != 0 or tiling_stats['splits'] != 0:
14 | dataset.synchronize()
15 |
16 | return {'img_loss': img_loss.mean()}
17 |
18 |
19 | def occupancy_bce(model_output, gt, step, tiling_every=100, dataset=None,
20 | model_type='multiscale', pruning_fn=None, retile=True):
21 | occupancy_loss = torch.nn.BCEWithLogitsLoss(reduction='none')(model_output['model_out']['output'], gt['occupancy'].float())
22 |
23 | if model_type == 'multiscale':
24 | per_octant_loss = torch.mean(occupancy_loss, dim=(-1, -2)).squeeze(0).detach().cpu().numpy()
25 |
26 | dataset.update_octant_err(per_octant_loss, step)
27 | if step % tiling_every == tiling_every-1 and retile:
28 | tiling_stats = dataset.update_tiling()
29 | if tiling_stats['merged'] != 0 or tiling_stats['splits'] != 0:
30 | dataset.synchronize()
31 |
32 | return {'occupancy_loss': occupancy_loss.mean()}
33 |
--------------------------------------------------------------------------------
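The key detail in both losses above is `reduction='none'`: the loss is kept per-sample so it can be averaged down to one error per patch/octant and fed back to the tiling. A minimal sketch with hypothetical shapes:

```python
import torch

logits = torch.randn(1, 16, 64, 1)                 # (batch, octants, samples, channels)
target = torch.randint(0, 2, (1, 16, 64, 1)).float()

loss = torch.nn.BCEWithLogitsLoss(reduction='none')(logits, target)
per_octant = loss.mean(dim=(-1, -2)).squeeze(0)    # one error per octant, shape (16,)
print(per_octant.shape)
```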
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import trimesh
4 | from scipy.spatial import cKDTree as KDTree
5 | from inside_mesh.triangle_hash import TriangleHash as _TriangleHash
6 |
7 | '''
8 |
9 | Some code included from 'inside_mesh' library of Occupancy Networks
10 | https://github.com/autonomousvision/occupancy_networks
11 |
12 | '''
13 |
14 |
15 | def define_grid_3d(N, voxel_origin=[-1, -1, -1], voxel_size=None):
16 | ''' define NxNxN coordinate grid across [-1, 1]
17 | voxel_origin is the (bottom, left, down) corner, not the middle '''
18 |
19 | if not voxel_size:
20 | voxel_size = 2.0 / (N - 1)
21 |
22 | # initialize empty tensors
23 | overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor())
24 | grid = torch.zeros(N ** 3, 3)
25 |
26 | # transform first 3 columns to be x, y, z voxel index
27 |     # every possible combination of [0..N, 0..N, 0..N]
28 | grid[:, 2] = overall_index % N # [0,1,2,...,N-1,N,0,1,2,...,N]
29 | grid[:, 1] = (overall_index.long() // N) % N # [N [N 0's, ..., N N's]]
30 | grid[:, 0] = ((overall_index.long() // N) // N) % N # [N*N 0's,...,N*N N's]
31 |
32 | # transform first 3 columns: voxel indices --> voxel coordinates
33 | grid[:, 0] = (grid[:, 0] * voxel_size) + voxel_origin[2]
34 | grid[:, 1] = (grid[:, 1] * voxel_size) + voxel_origin[1]
35 | grid[:, 2] = (grid[:, 2] * voxel_size) + voxel_origin[0]
36 |
37 | return grid
38 |
39 |
40 | def compute_iou(path_gt, path_pr, N=128, sphere=False, sphere_radius=0.25):
41 | ''' compute iou score
42 | parameters
43 | path_gt: path to ground-truth mesh (.ply or .obj)
44 | path_pr: path to predicted mesh (.ply or .obj)
45 | N: NxNxN grid resolution at which to compute iou '''
46 |
47 | # define NxNxN coordinate grid across [-1,1]
48 | grid = np.array(define_grid_3d(N))
49 |
50 | # load mesh
51 | occ_pr = MeshDataset(path_pr)
52 |
53 | # compute occupancy at specified grid points
54 | if sphere:
55 | occ_gt = torch.from_numpy(np.linalg.norm(grid, axis=-1) <= sphere_radius)
56 | else:
57 | occ_gt = MeshDataset(path_gt)
58 | occ_gt = torch.tensor(check_mesh_contains(occ_gt.mesh, grid))
59 |
60 | occ_pr = torch.tensor(check_mesh_contains(occ_pr.mesh, grid))
61 |
62 | # compute iou
63 | area_union = torch.sum((occ_gt | occ_pr).float())
64 | area_intersect = torch.sum((occ_gt & occ_pr).float())
65 | iou = area_intersect / area_union
66 |
67 | return iou.item()
68 |
69 |
70 | def compute_trimesh_chamfer(mesh1, mesh2, num_mesh_samples=300000):
71 |     """
72 |     Computes a symmetric chamfer distance, i.e. the sum of both one-sided chamfer distances.
73 |     mesh1: trimesh.base.Trimesh of the output mesh from whichever reconstruction
74 |     method is being evaluated (see compute_metrics.py for more documentation)
75 |     mesh2: trimesh.base.Trimesh of the ground-truth mesh; both surfaces are
76 |     sampled with num_mesh_samples points each
77 |     """
78 |
79 | gen_points_sampled = trimesh.sample.sample_surface(mesh1, num_mesh_samples)[0]
80 | gt_points_np = trimesh.sample.sample_surface(mesh2, num_mesh_samples)[0]
81 |
82 | # one direction
83 | gen_points_kd_tree = KDTree(gen_points_sampled)
84 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points_np)
85 | gt_to_gen_chamfer = np.mean(np.square(one_distances))
86 |
87 | # other direction
88 | gt_points_kd_tree = KDTree(gt_points_np)
89 | two_distances, two_vertex_ids = gt_points_kd_tree.query(gen_points_sampled)
90 | gen_to_gt_chamfer = np.mean(np.square(two_distances))
91 |
92 | chamfer_dist = gt_to_gen_chamfer + gen_to_gt_chamfer
93 | return chamfer_dist
94 |
95 |
96 | class MeshDataset():
97 | def __init__(self, path_mesh, sample=False, num_pts=0):
98 |
99 | if not path_mesh:
100 | return
101 |
102 | self.mesh = trimesh.load(path_mesh, process=False,
103 | force='mesh', skip_materials=True)
104 |
105 |
106 | def check_mesh_contains(mesh, points, hash_resolution=512):
107 | intersector = MeshIntersector(mesh, hash_resolution)
108 | contains = intersector.query(points)
109 | return contains
110 |
111 |
112 | class MeshIntersector:
113 | def __init__(self, mesh, resolution=512):
114 | triangles = mesh.vertices[mesh.faces].astype(np.float64)
115 | n_tri = triangles.shape[0]
116 |
117 | self.resolution = resolution
118 | self.bbox_min = triangles.reshape(3 * n_tri, 3).min(axis=0)
119 | self.bbox_max = triangles.reshape(3 * n_tri, 3).max(axis=0)
120 |
121 |         # Translate and scale it to [0.5, self.resolution - 0.5]^3
122 | self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min)
123 | self.translate = 0.5 - self.scale * self.bbox_min
124 | self._triangles = triangles = self.rescale(triangles)
125 |
126 | triangles2d = triangles[:, :, :2]
127 | self._tri_intersector2d = TriangleIntersector2d(
128 | triangles2d, resolution)
129 |
130 | def query(self, points):
131 | # Rescale points
132 | points = self.rescale(points)
133 |
134 | # placeholder result with no hits we'll fill in later
135 |         contains = np.zeros(len(points), dtype=bool)
136 |
137 | # cull points outside of the axis aligned bounding box
138 | # this avoids running ray tests unless points are close
139 | inside_aabb = np.all(
140 | (0 <= points) & (points <= self.resolution), axis=1)
141 | if not inside_aabb.any():
142 | return contains
143 |
144 | # Only consider points inside bounding box
145 | mask = inside_aabb
146 | points = points[mask]
147 |
148 | # Compute intersection depth and check order
149 | points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2])
150 |
151 | triangles_intersect = self._triangles[tri_indices]
152 | points_intersect = points[points_indices]
153 |
154 | depth_intersect, abs_n_2 = self.compute_intersection_depth(
155 | points_intersect, triangles_intersect)
156 |
157 | # Count number of intersections in both directions
158 | smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
159 | bigger_depth = depth_intersect < points_intersect[:, 2] * abs_n_2
160 | points_indices_0 = points_indices[smaller_depth]
161 | points_indices_1 = points_indices[bigger_depth]
162 |
163 | nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0])
164 | nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0])
165 |
166 | # Check if point contained in mesh
167 | contains1 = (np.mod(nintersect0, 2) == 1)
168 | contains2 = (np.mod(nintersect1, 2) == 1)
169 | contains[mask] = (contains1 & contains2)
170 | return contains
171 |
172 | def compute_intersection_depth(self, points, triangles):
173 | t1 = triangles[:, 0, :]
174 | t2 = triangles[:, 1, :]
175 | t3 = triangles[:, 2, :]
176 |
177 | v1 = t3 - t1
178 | v2 = t2 - t1
179 |
180 | normals = np.cross(v1, v2)
181 | alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1)
182 |
183 | n_2 = normals[:, 2]
184 | t1_2 = t1[:, 2]
185 | s_n_2 = np.sign(n_2)
186 | abs_n_2 = np.abs(n_2)
187 |
188 | mask = (abs_n_2 != 0)
189 |
190 | depth_intersect = np.full(points.shape[0], np.nan)
191 | depth_intersect[mask] = \
192 | t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask]
193 |
194 | return depth_intersect, abs_n_2
195 |
196 | def rescale(self, array):
197 | array = self.scale * array + self.translate
198 | return array
199 |
200 |
201 | class TriangleIntersector2d:
202 | def __init__(self, triangles, resolution=128):
203 | self.triangles = triangles
204 | self.tri_hash = _TriangleHash(triangles, resolution)
205 |
206 | def query(self, points):
207 | point_indices, tri_indices = self.tri_hash.query(points)
208 | point_indices = np.array(point_indices, dtype=np.int64)
209 | tri_indices = np.array(tri_indices, dtype=np.int64)
210 | points = points[point_indices]
211 | triangles = self.triangles[tri_indices]
212 | mask = self.check_triangles(points, triangles)
213 | point_indices = point_indices[mask]
214 | tri_indices = tri_indices[mask]
215 | return point_indices, tri_indices
216 |
217 | def check_triangles(self, points, triangles):
218 |         contains = np.zeros(points.shape[0], dtype=bool)
219 | A = triangles[:, :2] - triangles[:, 2:]
220 | A = A.transpose([0, 2, 1])
221 | y = points - triangles[:, 2]
222 |
223 | detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0]
224 |
225 | mask = (np.abs(detA) != 0.)
226 | A = A[mask]
227 | y = y[mask]
228 | detA = detA[mask]
229 |
230 | s_detA = np.sign(detA)
231 | abs_detA = np.abs(detA)
232 |
233 | u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA
234 | v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA
235 |
236 | sum_uv = u + v
237 | contains[mask] = (
238 | (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA)
239 | & (0 < sum_uv) & (sum_uv < abs_detA)
240 | )
241 | return contains
242 |
--------------------------------------------------------------------------------
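As a self-contained illustration of `compute_trimesh_chamfer` above, the same KD-tree pattern on two random point clouds (stand-ins for the mesh surface samples):

```python
import numpy as np
from scipy.spatial import cKDTree as KDTree

pts_a = np.random.rand(1000, 3)
pts_b = pts_a + 0.01 * np.random.randn(1000, 3)   # slightly perturbed copy

d_ab, _ = KDTree(pts_b).query(pts_a)   # nearest neighbor in B for each point in A
d_ba, _ = KDTree(pts_a).query(pts_b)   # and vice versa
chamfer = np.mean(d_ab**2) + np.mean(d_ba**2)
print(chamfer)  # small, since the two clouds nearly coincide
```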
/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import math
5 | from functools import partial
6 |
7 |
8 | class Sine(nn.Module):
9 | def __init__(self, w0=30):
10 | super().__init__()
11 | self.w0 = w0
12 |
13 | def forward(self, input):
14 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
15 | return torch.sin(self.w0 * input)
16 |
17 |
18 | class FCBlock(nn.Module):
19 | '''A fully connected neural network that also allows swapping out the weights when used with a hypernetwork.
20 | Can be used just as a normal neural network though, as well.
21 | '''
22 |
23 | def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
24 | outermost_linear=False, nonlinearity='relu', weight_init=None, w0=30):
25 | super().__init__()
26 |
27 | self.first_layer_init = None
28 |
29 | # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable,
30 | # special first-layer initialization scheme
31 | nls_and_inits = {'sine': (Sine(w0=w0), partial(sine_init, w0=w0), first_layer_sine_init),
32 | 'relu': (nn.ReLU(inplace=True), init_weights_normal, None)}
33 |
34 | nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]
35 |
36 | if weight_init is not None: # Overwrite weight init if passed
37 | self.weight_init = weight_init
38 | else:
39 | self.weight_init = nl_weight_init
40 |
41 | self.net = []
42 | self.net.append(nn.Sequential(
43 | nn.Linear(in_features, hidden_features), nl
44 | ))
45 |
46 | for i in range(num_hidden_layers):
47 | self.net.append(nn.Sequential(
48 | nn.Linear(hidden_features, hidden_features), nl
49 | ))
50 |
51 | if outermost_linear:
52 | self.net.append(nn.Sequential(nn.Linear(hidden_features, out_features)))
53 | else:
54 | self.net.append(nn.Sequential(
55 | nn.Linear(hidden_features, out_features), nl
56 | ))
57 |
58 | self.net = nn.Sequential(*self.net)
59 | if self.weight_init is not None:
60 | self.net.apply(self.weight_init)
61 |
62 | if first_layer_init is not None: # Apply special initialization to first layer, if applicable.
63 | self.net[0].apply(first_layer_init)
64 |
65 | def forward(self, coords):
66 | output = self.net(coords)
67 | return output
68 |
69 |
70 | class PositionalEncoding(nn.Module):
71 | def __init__(self, num_encoding_functions=6, include_input=True, log_sampling=True, normalize=False,
72 | input_dim=3, gaussian_pe=False, gaussian_variance=38):
73 | super().__init__()
74 | self.num_encoding_functions = num_encoding_functions
75 | self.include_input = include_input
76 | self.log_sampling = log_sampling
77 | self.normalize = normalize
78 | self.gaussian_pe = gaussian_pe
79 | self.normalization = None
80 |
81 | if self.gaussian_pe:
82 | # this needs to be registered as a parameter so that it is saved in the model state dict
83 | # and so that it is converted using .cuda(). Doesn't need to be trained though
84 | self.gaussian_weights = nn.Parameter(gaussian_variance * torch.randn(num_encoding_functions, input_dim),
85 | requires_grad=False)
86 | else:
87 | self.frequency_bands = None
88 | if self.log_sampling:
89 | self.frequency_bands = 2.0 ** torch.linspace(
90 | 0.0,
91 | self.num_encoding_functions - 1,
92 | self.num_encoding_functions)
93 | else:
94 | self.frequency_bands = torch.linspace(
95 | 2.0 ** 0.0,
96 | 2.0 ** (self.num_encoding_functions - 1),
97 | self.num_encoding_functions)
98 |
99 | if normalize:
100 | self.normalization = torch.tensor(1/self.frequency_bands)
101 |
102 | def forward(self, tensor) -> torch.Tensor:
103 | r"""Apply positional encoding to the input.
104 |
105 | Args:
106 | tensor (torch.Tensor): Input tensor to be positionally encoded.
107 |             Note: the number of encoding functions, whether to include the raw
108 |             input, and the frequency sampling scheme are fixed in the constructor
109 |             (defaults: 6 encoding functions, include_input=True) rather than
110 |             passed to this method.
111 |
112 | Returns:
113 | (torch.Tensor): Positional encoding of the input tensor.
114 | """
115 |
116 | encoding = [tensor] if self.include_input else []
117 | if self.gaussian_pe:
118 | for func in [torch.sin, torch.cos]:
119 | encoding.append(func(torch.matmul(tensor, self.gaussian_weights.T)))
120 | else:
121 | for idx, freq in enumerate(self.frequency_bands):
122 | for func in [torch.sin, torch.cos]:
123 | if self.normalization is not None:
124 | encoding.append(self.normalization[idx]*func(tensor * freq))
125 | else:
126 | encoding.append(func(tensor * freq))
127 |
128 | # Special case, for no positional encoding
129 | if len(encoding) == 1:
130 | return encoding[0]
131 | else:
132 | return torch.cat(encoding, dim=-1)
133 |
134 |
135 | def init_weights_normal(m):
136 | if type(m) == nn.Linear:
137 | if hasattr(m, 'weight'):
138 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
139 |
140 |
141 | def init_weights_xavier(m):
142 | if type(m) == nn.Linear:
143 | if hasattr(m, 'weight'):
144 | nn.init.xavier_normal_(m.weight)
145 |
146 |
147 | def sine_init(m, w0=30):
148 | with torch.no_grad():
149 | if hasattr(m, 'weight'):
150 | num_input = m.weight.size(-1)
151 | # See supplement Sec. 1.5 for discussion of factor w0
152 | m.weight.uniform_(-np.sqrt(6 / num_input) / w0, np.sqrt(6 / num_input) / w0)
153 |
154 |
155 | def first_layer_sine_init(m):
156 | with torch.no_grad():
157 | if hasattr(m, 'weight'):
158 | num_input = m.weight.size(-1)
159 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
160 | m.weight.uniform_(-1 / num_input, 1 / num_input)
161 |
162 |
163 | class ImplicitAdaptivePatchNet(nn.Module):
164 | def __init__(self, in_features=3, out_features=1, feature_grid_size=(8, 8, 8),
165 | hidden_features=256, num_hidden_layers=3, patch_size=8,
166 | code_dim=8, use_pe=True, num_encoding_functions=6, **kwargs):
167 | super().__init__()
168 | self.in_features = in_features
169 | self.out_features = out_features
170 | self.feature_grid_size = feature_grid_size
171 | self.patch_size = patch_size
172 | self.use_pe = use_pe
173 |
174 | if self.use_pe:
175 | self.positional_encoding = PositionalEncoding(num_encoding_functions=num_encoding_functions)
176 | in_features = 2*in_features*num_encoding_functions + in_features
177 |
178 | self.coord2features_net = FCBlock(in_features=in_features, out_features=np.prod(feature_grid_size),
179 | num_hidden_layers=num_hidden_layers, hidden_features=hidden_features,
180 | outermost_linear=True, nonlinearity='relu')
181 |
182 | self.features2sample_net = FCBlock(in_features=self.feature_grid_size[0], out_features=out_features,
183 | num_hidden_layers=1, hidden_features=64,
184 | outermost_linear=True, nonlinearity='relu')
185 | print(self)
186 |
187 | def forward(self, model_input):
188 |
189 | # Enables us to compute gradients w.r.t. coordinates
190 | coords = model_input['coords'].clone().detach().requires_grad_(True)
191 | fine_coords = model_input['fine_rel_coords'].clone().detach().requires_grad_(True)
192 |
193 | if self.use_pe:
194 | coords = self.positional_encoding(coords)
195 |
196 | features = self.coord2features_net(coords)
197 |
198 |         # features is size (Batch Size, Blocks, prod(feature_grid_size)),
199 |         # but grid_sample's bilinear interpolation only supports one batch
200 |         # dimension, so for now assume that Batch Size == 1
201 | assert features.shape[0] == 1, 'Code currently only supports Batch Size == 1'
202 |
203 | n_channels, dx, dy = self.feature_grid_size
204 | features = features.squeeze(0)
205 | b_size = features.shape[0]
206 |
207 | features_in = features.squeeze().reshape(b_size, n_channels, dx, dy)
208 | sample_coords_out = fine_coords[0, ...].reshape(1, -1, 2)
209 | sample_coords = sample_coords_out.reshape(b_size, self.patch_size[0], self.patch_size[1], 2)
210 |
211 | y = sample_coords[..., :1]
212 | x = sample_coords[..., 1:]
213 | sample_coords = torch.cat([y, x], dim=-1)
214 |
215 | features_out = torch.nn.functional.grid_sample(features_in, sample_coords,
216 | mode='bilinear',
217 | padding_mode='border',
218 | align_corners=True).reshape(b_size, n_channels, np.prod(self.patch_size))
219 |
220 | # permute from (Blocks, feature_grid_size[0], patch_size**2)->(Blocks, patch_size**2, feature_grid_size[0])
221 | # so the network maps features to function output
222 | features_out = features_out.permute(0, 2, 1)
223 |
224 | # for all spatial feature vectors, extract function value
225 | patch_out = self.features2sample_net(features_out)
226 |
227 | # squeeze out last dimension and restore batch dimension
228 | patch_out = patch_out.unsqueeze(0)
229 |
230 | return {'model_in': {'sample_coords_out': sample_coords_out, 'model_in_coarse': coords},
231 | 'model_out': {'output': patch_out, 'codes': None}}
232 |
233 |
234 | class ImplicitAdaptiveOctantNet(nn.Module):
235 | def __init__(self, in_features=4, out_features=1, feature_grid_size=(4, 16, 16, 16),
236 | hidden_features=256, num_hidden_layers=3, octant_size=8,
237 | code_dim=8, use_pe=True, num_encoding_functions=6):
238 | super().__init__()
239 | self.in_features = in_features
240 | self.out_features = out_features
241 | self.feature_grid_size = feature_grid_size
242 | self.octant_size = octant_size
243 | self.use_pe = use_pe
244 |
245 | if self.use_pe:
246 | self.positional_encoding = PositionalEncoding(num_encoding_functions=num_encoding_functions)
247 | in_features = 2*in_features*num_encoding_functions + in_features
248 |
249 | self.coord2features_net = FCBlock(in_features=in_features, out_features=np.prod(feature_grid_size),
250 | num_hidden_layers=num_hidden_layers, hidden_features=hidden_features,
251 | outermost_linear=True, nonlinearity='relu')
252 |
253 | self.features2sample_net = FCBlock(in_features=feature_grid_size[0], out_features=out_features,
254 | num_hidden_layers=1, hidden_features=64,
255 | outermost_linear=True, nonlinearity='relu')
256 |
257 | def forward(self, model_input, oversample=1.0):
258 |
259 | # Enables us to compute gradients w.r.t. coordinates
260 | coords = model_input['coords'].clone().detach().requires_grad_(True)
261 | fine_coords = model_input['fine_rel_coords'].clone().detach().requires_grad_(True)
262 |
263 | if self.use_pe:
264 | coords = self.positional_encoding(coords)
265 |
266 | features = self.coord2features_net(coords)
267 |
268 |         # features is size (Batch Size, Blocks, prod(feature_grid_size)),
269 |         # but grid_sample's bilinear interpolation only supports one batch
270 |         # dimension, so for now assume that Batch Size == 1
271 | assert features.shape[0] == 1, 'Code currently only supports Batch Size == 1'
272 |
273 | n_channels, dx, dy, dz = self.feature_grid_size
274 | features = features.squeeze(0)
275 | b_size = features.shape[0]
276 |
277 | features_in = features.squeeze().reshape(b_size, n_channels, dx, dy, dz)
278 | sample_coords_out = fine_coords[0, ...].reshape(1, -1, 3)
279 | sample_coords = sample_coords_out.reshape(b_size, self.octant_size, self.octant_size, self.octant_size, 3)
280 | features_out = torch.nn.functional.grid_sample(features_in, sample_coords,
281 | mode='bilinear',
282 | padding_mode='border',
283 | align_corners=True).reshape(b_size, n_channels, self.octant_size**3)
284 |
285 |         # permute from (Blocks, feature_grid_size[0], octant_size**3)->(Blocks, octant_size**3, feature_grid_size[0])
286 | # so the network maps features to function output
287 | features_out = features_out.permute(0, 2, 1)
288 |
289 | # for all spatial feature vectors, extract function value
290 | patch_out = self.features2sample_net(features_out)
291 |
292 | # squeeze out last dimension and restore batch dimension
293 | patch_out = patch_out.unsqueeze(0)
294 |
295 | return {'model_in': {'sample_coords_out': sample_coords_out, 'model_in_coarse': coords},
296 | 'model_out': {'output': patch_out, 'codes': None}}
297 |
298 |
299 | class ImplicitNet(nn.Module):
300 | '''A canonical representation network for a BVP.'''
301 |
302 | def __init__(self, sidelength, out_features=1, in_features=2,
303 | mode='pe', hidden_features=256, num_hidden_layers=3, w0=30, **kwargs):
304 |
305 | super().__init__()
306 | self.mode = mode
307 |
308 | if self.mode == 'pe':
309 | nyquist_rate = 1 / (2 * (2 * 1/np.max(sidelength)))
310 | num_encoding_functions = int(math.floor(math.log(nyquist_rate, 2)))
311 |
312 | nonlinearity = 'relu'
313 | self.positional_encoding = PositionalEncoding(num_encoding_functions=num_encoding_functions)
314 | in_features = 2*in_features*num_encoding_functions + in_features
315 |
316 | elif self.mode == 'siren':
317 | nonlinearity = 'sine'
318 | else:
319 | raise NotImplementedError(f'mode=={self.mode} not implemented')
320 |
321 | self.net = FCBlock(in_features=in_features, out_features=out_features, num_hidden_layers=num_hidden_layers,
322 | hidden_features=hidden_features, outermost_linear=True, nonlinearity=nonlinearity, w0=w0)
323 | print(self)
324 |
325 | def forward(self, model_input):
326 |
327 | coords = model_input['fine_abs_coords'][..., :2]
328 |
329 | if self.mode == 'pe':
330 | coords = self.positional_encoding(coords)
331 |
332 | output = self.net(coords)
333 | return {'model_in': {'coords': coords}, 'model_out': {'output': output}}
334 |
--------------------------------------------------------------------------------
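The step shared by both adaptive networks above is worth isolating: the coordinate network predicts a small feature grid per block, `grid_sample` interpolates it bilinearly at the fine sample coordinates in [-1, 1], and a small decoder maps each interpolated feature vector to an output value. A minimal 2D sketch with hypothetical shapes (a plain `nn.Linear` stands in for `features2sample_net`):

```python
import torch
import torch.nn.functional as F

n_blocks, n_channels, dx, dy = 4, 16, 12, 12
features_in = torch.randn(n_blocks, n_channels, dx, dy)   # predicted feature grids
sample_coords = torch.rand(n_blocks, 8, 8, 2) * 2 - 1     # fine coords in [-1, 1]

features_out = F.grid_sample(features_in, sample_coords,
                             mode='bilinear', padding_mode='border',
                             align_corners=True)          # (blocks, channels, 8, 8)

decoder = torch.nn.Linear(n_channels, 1)                  # stand-in for features2sample_net
out = decoder(features_out.reshape(n_blocks, n_channels, -1).permute(0, 2, 1))
print(out.shape)  # (4, 64, 1): one value per sampled location
```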
/pruning_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils
3 |
4 |
5 | def no_pruning(model, dataset, pruning_every=100):
6 | return
7 |
8 |
9 | def pruning_occupancy(model, dataset, threshold=-10):
10 | model_input = dataset.get_eval_samples(1)
11 |
12 | print("Pruning: loading data to cuda...")
13 | tmp = {}
14 | for key, value in model_input.items():
15 | if isinstance(value, torch.Tensor):
16 | tmp.update({key: value[None, ...].cuda()})
17 | else:
18 | tmp.update({key: value})
19 | model_input = tmp
20 |
21 | print("Pruning: evaluating occupancy...")
22 | pred_occupancy = utils.process_batch_in_chunks(model_input, model)['model_out']['output']
23 | pred_occupancy = torch.max(pred_occupancy, dim=-2).values.squeeze()
24 | pred_occupancy_idx = model_input['coord_octant_idx'].squeeze()
25 |
26 | print("Pruning: computing mean and freezing empty octants")
27 | active_octants = dataset.octtree.get_active_octants()
28 |
29 | frozen_octants = 0
30 | for idx, octant in enumerate(active_octants):
31 | max_prediction = torch.max(pred_occupancy[pred_occupancy_idx == idx])
32 | if max_prediction < threshold and octant.err < 1e-3: # Prune if model is confident that everything is empty
33 | octant.frozen = True
34 | frozen_octants += 1
35 | print(f"Pruning: Froze {frozen_octants} octants.")
36 | dataset.synchronize()
37 |
--------------------------------------------------------------------------------
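A minimal sketch of the freezing test in `pruning_occupancy` above, with hypothetical logits and octant indices: an octant is frozen only when even its most-occupied sample is confidently empty (the real function additionally requires the octant's running error to be small).

```python
import torch

pred = torch.randn(1024) - 5           # hypothetical predicted logits for all samples
idx = torch.randint(0, 32, (1024,))    # octant index of each sample
threshold = -10

frozen = [o for o in range(32)
          if (idx == o).any() and pred[idx == o].max() < threshold]
print(f"would freeze {len(frozen)} of 32 octants")
```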
/training.py:
--------------------------------------------------------------------------------
1 | '''Implements a generic training loop.
2 | '''
3 |
4 | import torch
5 | import utils
6 | from torch.utils.tensorboard import SummaryWriter
7 | from tqdm.autonotebook import tqdm
8 | import time
9 | import numpy as np
10 | import os
11 | import shutil
12 |
13 |
14 | def train(model, train_dataloader, epochs, lr, steps_til_summary, epochs_til_checkpoint, model_dir,
15 | loss_fn, pruning_fn, summary_fn, double_precision=False, clip_grad=False,
16 | loss_schedules=None, resume_checkpoint={}, objs_to_save={}, epochs_til_pruning=4):
17 | optim = torch.optim.Adam(lr=lr, params=model.parameters())
18 |
19 | # load optimizer if supplied
20 | if 'optimizer_state_dict' in resume_checkpoint:
21 | optim.load_state_dict(resume_checkpoint['optimizer_state_dict'])
22 |
23 | for g in optim.param_groups:
24 | g['lr'] = lr
25 |
26 | if os.path.exists(os.path.join(model_dir, 'summaries')):
27 | val = input("The model directory %s exists. Overwrite? (y/n)" % model_dir)
28 | if val == 'y':
29 | if os.path.exists(os.path.join(model_dir, 'summaries')):
30 | shutil.rmtree(os.path.join(model_dir, 'summaries'))
31 | if os.path.exists(os.path.join(model_dir, 'checkpoints')):
32 | shutil.rmtree(os.path.join(model_dir, 'checkpoints'))
33 |
34 | os.makedirs(model_dir, exist_ok=True)
35 |
36 | summaries_dir = os.path.join(model_dir, 'summaries')
37 | utils.cond_mkdir(summaries_dir)
38 |
39 | checkpoints_dir = os.path.join(model_dir, 'checkpoints')
40 | utils.cond_mkdir(checkpoints_dir)
41 |
42 | writer = SummaryWriter(summaries_dir)
43 | total_steps = 0
44 | if 'total_steps' in resume_checkpoint:
45 | total_steps = resume_checkpoint['total_steps']
46 |
47 | start_epoch = 0
48 | if 'epoch' in resume_checkpoint:
49 | start_epoch = resume_checkpoint['epoch']
50 |
51 | with tqdm(total=len(train_dataloader) * epochs) as pbar:
52 | pbar.update(total_steps)
53 | train_losses = []
54 | for epoch in range(start_epoch, epochs):
55 | if not epoch % epochs_til_checkpoint and epoch:
56 | torch.save(model.state_dict(),
57 | os.path.join(checkpoints_dir, 'model_%06d.pth' % total_steps))
58 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_%06d.txt' % total_steps),
59 | np.array(train_losses))
60 | save_dict = {'epoch': epoch,
61 | 'total_steps': total_steps,
62 | 'optimizer_state_dict': optim.state_dict()}
63 | save_dict.update(objs_to_save)
64 | torch.save(save_dict, os.path.join(checkpoints_dir, 'optim_%06d.pth' % total_steps))
65 |
66 | # prune
67 | if not epoch % epochs_til_pruning and epoch:
68 | pruning_fn(model, train_dataloader.dataset)
69 |
70 | if not (epoch + 1) % epochs_til_pruning:
71 | retile = False
72 | else:
73 | retile = True
74 |
75 | for step, (model_input, gt) in enumerate(train_dataloader):
76 | start_time = time.time()
77 |
78 | tmp = {}
79 | for key, value in model_input.items():
80 | if isinstance(value, torch.Tensor):
81 | tmp.update({key: value.cuda()})
82 | else:
83 | tmp.update({key: value})
84 | model_input = tmp
85 |
86 | tmp = {}
87 | for key, value in gt.items():
88 | if isinstance(value, torch.Tensor):
89 | tmp.update({key: value.cuda()})
90 | else:
91 | tmp.update({key: value})
92 | gt = tmp
93 |
94 | if double_precision:
95 | model_input = {key: value.double() for key, value in model_input.items()}
96 | gt = {key: value.double() for key, value in gt.items()}
97 |
98 | model_output = model(model_input)
99 | losses = loss_fn(model_output, gt, total_steps, retile=retile)
100 |
101 | train_loss = 0.
102 | for loss_name, loss in losses.items():
103 | single_loss = loss.mean()
104 |
105 | if loss_schedules is not None and loss_name in loss_schedules:
106 | writer.add_scalar(loss_name + "_weight", loss_schedules[loss_name](total_steps), total_steps)
107 | single_loss *= loss_schedules[loss_name](total_steps)
108 |
109 | writer.add_scalar(loss_name, single_loss, total_steps)
110 | train_loss += single_loss
111 |
112 | train_losses.append(train_loss.item())
113 | writer.add_scalar("total_train_loss", train_loss, total_steps)
114 |
115 | optim.zero_grad()
116 | train_loss.backward()
117 |
118 | if clip_grad:
119 | if isinstance(clip_grad, bool):
120 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
121 | else:
122 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
123 |
124 | optim.step()
125 |
126 | pbar.update(1)
127 |
128 | if not total_steps % steps_til_summary:
129 | tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))
130 | summary_fn(model, model_input, gt, model_output, writer, total_steps)
131 |
132 | total_steps += 1
133 |
134 | # after epoch
135 | tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))
136 |
137 | # save model at end of epoch
138 | torch.save(model.state_dict(),
139 | os.path.join(checkpoints_dir, 'model_final_%06d.pth' % total_steps))
140 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final_%06d.txt' % total_steps),
141 | np.array(train_losses))
142 | save_dict = {'epoch': epoch,
143 | 'total_steps': total_steps,
144 | 'optimizer_state_dict': optim.state_dict()}
145 | save_dict.update(objs_to_save)
146 | torch.save(save_dict, os.path.join(checkpoints_dir, 'optim_final_%06d.pth' % total_steps))
147 |
--------------------------------------------------------------------------------
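For reference, a minimal sketch of the checkpoint round-trip between `train()` above and `load_from_checkpoint()` in `train_occupancy.py`: the `optim_*.pth` file bundles the optimizer state with the training counters (plus anything in `objs_to_save`).

```python
import torch

model = torch.nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

# save (as at the end of each checkpoint interval)
save_dict = {'epoch': 3, 'total_steps': 1200,
             'optimizer_state_dict': optim.state_dict()}
torch.save(save_dict, 'optim_001200.pth')

# resume (as in load_from_checkpoint)
ckpt = torch.load('optim_001200.pth')
optim.load_state_dict(ckpt['optimizer_state_dict'])
resume = {'total_steps': ckpt['total_steps'], 'epoch': ckpt['epoch']}
```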
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | from torchvision.utils import make_grid
5 | import skimage.metrics
6 | from tqdm import tqdm
7 | import mrcfile
8 |
9 |
10 | def cond_mkdir(path):
11 | if not os.path.exists(path):
12 | os.makedirs(path)
13 |
14 |
15 | def write_psnr(pred_img, gt_img, writer, iter, prefix):
16 | batch_size = pred_img.shape[0]
17 |
18 | pred_img = pred_img.detach().cpu().numpy()
19 | gt_img = gt_img.detach().cpu().numpy()
20 |
21 | psnrs = list()
22 | for i in range(batch_size):
23 | p = pred_img[i].transpose(1, 2, 0)
24 | trgt = gt_img[i].transpose(1, 2, 0)
25 |
26 | p = (p / 2.) + 0.5
27 | p = np.clip(p, a_min=0., a_max=1.)
28 | trgt = (trgt / 2.) + 0.5
29 |
30 | psnr = skimage.metrics.peak_signal_noise_ratio(p, trgt, data_range=1)
31 | psnrs.append(psnr)
32 |
33 | writer.add_scalar(prefix + "psnr", np.mean(psnrs), iter)
34 |
35 |
36 | def write_occupancy_multiscale_summary(image_resolution, dataset, model, model_input, gt,
37 | model_output, writer, total_steps, prefix='train_',
38 | output_mrc='test.mrc', skip=False,
39 | oversample=1.0, max_chunk_size=1024, mode='binary'):
40 | if skip:
41 | return
42 |
43 | model_input = dataset.get_eval_samples(oversample)
44 |
45 | print("Summary: Write occupancy multiscale summary...")
46 |
47 | # convert to cuda and add batch dimension
48 | tmp = {}
49 | for key, value in model_input.items():
50 | if isinstance(value, torch.Tensor):
51 | tmp.update({key: value[None, ...]})
52 | else:
53 | tmp.update({key: value})
54 | model_input = tmp
55 |
56 | print("Summary: processing...")
57 | pred_occupancy = process_batch_in_chunks(model_input, model, max_chunk_size=max_chunk_size)['model_out']['output']
58 |
59 | # get voxel idx for each coordinate
60 | coords = model_input['fine_abs_coords'].detach().cpu().numpy()
61 | voxel_idx = np.floor((coords + 1.) / 2. * (dataset.sidelength[0] * oversample)).astype(np.int32)
62 | voxel_idx = voxel_idx.reshape(-1, 3)
63 |
64 | # init a new occupancy volume
65 | display_occupancy = -1 * np.ones(image_resolution, dtype=np.float32)
66 |
67 | # assign predicted voxel occupancy values into the array
68 | pred_occupancy = pred_occupancy.reshape(-1, 1).detach().cpu().numpy()
69 | display_occupancy[voxel_idx[:, 0], voxel_idx[:, 1], voxel_idx[:, 2]] = pred_occupancy[..., 0]
70 |
71 | print(f"Summary: write MRC file {image_resolution}")
72 | if mode == 'hq':
73 | print("\tWriting float")
74 | with mrcfile.new_mmap(output_mrc, overwrite=True, shape=image_resolution, mrc_mode=2) as mrc:
75 | mrc.data[voxel_idx[:, 0], voxel_idx[:, 1], voxel_idx[:, 2]] = pred_occupancy[..., 0]
76 | elif mode == 'binary':
77 | print("\tWriting binary")
78 | with mrcfile.new_mmap(output_mrc, overwrite=True, shape=image_resolution) as mrc:
79 | mrc.data[voxel_idx[:, 0], voxel_idx[:, 1], voxel_idx[:, 2]] = pred_occupancy[..., 0] > 0
80 |
81 | if writer is not None:
82 | print("Summary: Draw octtree")
83 | fig = dataset.octtree.draw()
84 | writer.add_figure(prefix + 'tiling', fig, global_step=total_steps)
85 |
86 | return display_occupancy
87 |
88 |
89 | def write_image_patch_multiscale_summary(image_resolution, patch_size, dataset, model, model_input, gt,
90 | model_output, writer, total_steps, prefix='train_',
91 | model_type='multiscale', skip=False):
92 | if skip:
93 | return
94 |
95 | # uniformly sample the image
96 | dataset.toggle_eval()
97 | model_input, gt = dataset[0]
98 | dataset.toggle_eval()
99 |
100 | # convert to cuda and add batch dimension
101 | tmp = {}
102 | for key, value in model_input.items():
103 | if isinstance(value, torch.Tensor):
104 | tmp.update({key: value[None, ...].cpu()})
105 | else:
106 | tmp.update({key: value})
107 | model_input = tmp
108 |
109 | tmp = {}
110 | for key, value in gt.items():
111 | if isinstance(value, torch.Tensor):
112 | tmp.update({key: value[None, ...].cpu()})
113 | else:
114 | tmp.update({key: value})
115 | gt = tmp
116 |
117 | # run the model on uniform samples
118 | n_channels = gt['img'].shape[-1]
119 | pred_img = process_batch_in_chunks(model_input, model)['model_out']['output']
120 |
121 | # get pixel idx for each coordinate
122 | coords = model_input['fine_abs_coords'].detach().cpu().numpy()
123 | pixel_idx = np.zeros_like(coords).astype(np.int32)
124 | pixel_idx[..., 0] = np.round((coords[..., 0] + 1.)/2. * (dataset.sidelength[0]-1)).astype(np.int32)
125 | pixel_idx[..., 1] = np.round((coords[..., 1] + 1.)/2. * (dataset.sidelength[1]-1)).astype(np.int32)
126 | pixel_idx = pixel_idx.reshape(-1, 2)
127 |
128 | # get pixel idx for each coordinate in frozen patches
129 | frozen_coords, frozen_values = dataset.get_frozen_patches()
130 | if frozen_coords is not None:
131 | frozen_coords = frozen_coords.detach().cpu().numpy()
132 | frozen_pixel_idx = np.zeros_like(frozen_coords).astype(np.int32)
133 | frozen_pixel_idx[..., 0] = np.round((frozen_coords[..., 0] + 1.) / 2. * (dataset.sidelength[0] - 1)).astype(np.int32)
134 | frozen_pixel_idx[..., 1] = np.round((frozen_coords[..., 1] + 1.) / 2. * (dataset.sidelength[1] - 1)).astype(np.int32)
135 | frozen_pixel_idx = frozen_pixel_idx.reshape(-1, 2)
136 |
137 | # init a new reconstructed image
138 | display_pred = np.zeros((*dataset.sidelength, n_channels))
139 |
140 | # assign predicted image values into a new array
141 | # need to use numpy since it supports index assignment
142 | pred_img = pred_img.reshape(-1, n_channels).detach().cpu().numpy()
143 | display_pred[[pixel_idx[:, 0]], [pixel_idx[:, 1]]] = pred_img
144 |
145 | # assign frozen image values into the array too
146 | if frozen_coords is not None:
147 | frozen_values = frozen_values.reshape(-1, n_channels).detach().cpu().numpy()
148 | display_pred[[frozen_pixel_idx[:, 0]], [frozen_pixel_idx[:, 1]]] = frozen_values
149 |
150 | # show reconstructed img
151 | display_pred = torch.tensor(display_pred)[None, ...]
152 | display_pred = display_pred.permute(0, 3, 1, 2)
153 |
154 | gt_img = gt['img'].reshape(-1, n_channels).detach().cpu().numpy()
155 | display_gt = np.zeros((*dataset.sidelength, n_channels))
156 | display_gt[[pixel_idx[:, 0]], [pixel_idx[:, 1]]] = gt_img
157 | display_gt = torch.tensor(display_gt)[None, ...]
158 | display_gt = display_gt.permute(0, 3, 1, 2)
159 |
160 | fig = dataset.quadtree.draw()
161 | writer.add_figure(prefix + 'tiling', fig, global_step=total_steps)
162 |
163 | if 'img' in gt:
164 | output_vs_gt = torch.cat((display_gt, display_pred), dim=0)
165 | writer.add_image(prefix + 'gt_vs_pred', make_grid(output_vs_gt, scale_each=False, normalize=True),
166 | global_step=total_steps)
167 | write_psnr(display_pred, display_gt, writer, total_steps, prefix+'img_')
168 |
169 |
170 | def dict2cuda(a_dict):
171 | tmp = {}
172 | for key, value in a_dict.items():
173 | if isinstance(value, torch.Tensor):
174 | tmp.update({key: value.cuda()})
175 | else:
176 | tmp.update({key: value})
177 | return tmp
178 |
179 |
180 | def dict2cpu(a_dict):
181 | tmp = {}
182 | for key, value in a_dict.items():
183 | if isinstance(value, torch.Tensor):
184 | tmp.update({key: value.cpu()})
185 | elif isinstance(value, dict):
186 | tmp.update({key: dict2cpu(value)})
187 | else:
188 | tmp.update({key: value})
189 | return tmp
190 |
191 |
192 | def process_batch_in_chunks(in_dict, model, max_chunk_size=1024, progress=None):
193 |
194 | in_chunked = []
195 | for key in in_dict:
196 | chunks = torch.split(in_dict[key], max_chunk_size, dim=1)
197 | in_chunked.append(chunks)
198 |
199 | list_chunked_batched_in = \
200 | [{k: v for k, v in zip(in_dict.keys(), curr_chunks)} for curr_chunks in zip(*in_chunked)]
201 | del in_chunked
202 |
203 | list_chunked_batched_out_out = {}
204 | list_chunked_batched_out_in = {}
205 | for chunk_batched_in in tqdm(list_chunked_batched_in):
206 | chunk_batched_in = {k: v.cuda() for k, v in chunk_batched_in.items()}
207 | tmp = model(chunk_batched_in)
208 | tmp = dict2cpu(tmp)
209 |
210 | for key in tmp['model_out']:
211 | if tmp['model_out'][key] is None:
212 | continue
213 |
214 | out_ = tmp['model_out'][key].detach().clone().requires_grad_(False)
215 | list_chunked_batched_out_out.setdefault(key, []).append(out_)
216 |
217 | for key in tmp['model_in']:
218 | if tmp['model_in'][key] is None:
219 | continue
220 |
221 | in_ = tmp['model_in'][key].detach().clone().requires_grad_(False)
222 | list_chunked_batched_out_in.setdefault(key, []).append(in_)
223 |
224 | del tmp, chunk_batched_in
225 |
226 | # Reassemble the output chunks in a batch
227 | batched_out = {}
228 | for key in list_chunked_batched_out_out:
229 | batched_out_lin = torch.cat(list_chunked_batched_out_out[key], dim=1)
230 | batched_out[key] = batched_out_lin
231 |
232 | batched_in = {}
233 | for key in list_chunked_batched_out_in:
234 | batched_in_lin = torch.cat(list_chunked_batched_out_in[key], dim=1)
235 | batched_in[key] = batched_in_lin
236 |
237 | return {'model_in': batched_in, 'model_out': batched_out}
238 |
239 |
240 | def subsample_dict(in_dict, num_views, multiscale=False):
241 | if multiscale:
242 | out = {}
243 | for k, v in in_dict.items():
244 | if v.shape[0] == in_dict['octant_coords'].shape[0]:
245 | # this is arranged by blocks
246 | out.update({k: v[0:num_views[0]]})
247 | else:
248 | # arranged by rays
249 | out.update({k: v[0:num_views[1]]})
250 | else:
251 | out = {key: value[0:num_views, ...] for key, value in in_dict.items()}
252 |
253 | return out
254 |
--------------------------------------------------------------------------------
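Finally, the chunking pattern in `process_batch_in_chunks` above reduces to: split every tensor in the batch along the block dimension (dim=1), run the model chunk by chunk, and concatenate the outputs back along the same dimension. A minimal sketch with a stand-in model:

```python
import torch

batch = {'coords': torch.rand(1, 10000, 3)}
chunks = torch.split(batch['coords'], 1024, dim=1)  # 10 chunks of at most 1024

model = torch.nn.Linear(3, 1)                       # stand-in for the real network
outputs = [model(c) for c in chunks]
merged = torch.cat(outputs, dim=1)                  # (1, 10000, 1)
print(merged.shape)
```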