├── .gitignore ├── LICENSE ├── README.md ├── configs ├── default.yaml └── lig │ ├── lig_eval.yaml │ └── lig_pretrained.yaml ├── main └── run_lig.py ├── pipelines ├── __init__.py ├── config.py ├── lig │ ├── __init__.py │ ├── config.py │ ├── generation.py │ └── models │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── layers.py │ │ ├── method.py │ │ └── optimizer.py └── utils │ ├── eval_utils.py │ ├── libchamfer │ ├── __init__.py │ ├── chamfer.cp37-win_amd64.pyd │ ├── chamfer.cu │ ├── chamfer_cuda.cpp │ └── dist_chamfer.py │ ├── lr_schedulers.py │ ├── point_utils.py │ └── postprocess_utils.py ├── pretrained_models └── lig │ └── model_best.pt ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | **/*.so 3 | **/__pycache__/git * 4 | **/__pycache__/* 5 | *.out 6 | out 7 | **/*.ply 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 wnbzhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A PyTorch reimplementation of Local Implicit Grid Representations for 3D Scenes 2 | 3 | This project is a PyTorch implementation of [LIG](http://maxjiang.ml/proj/lig). 
4 | The code is based on the authors' TensorFlow implementation [here](https://github.com/tensorflow/graphics/tree/master/tensorflow_graphics/projects/local_implicit_grid).
5 | 
6 | ## Prepare Environment
7 | ```
8 | pip install -r requirements.txt
9 | python setup.py build_ext --inplace
10 | ```
11 | 
12 | ## Perform Reconstruction
13 | ```
14 | python main/run_lig.py --input_ply demo_data/living_room_33_1000_per_m2.ply --output_ply test.ply
15 | ```
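## Programmatic Usage

The same pipeline can also be driven from Python. Below is a minimal sketch that mirrors `main/run_lig.py` (the input path is the demo file above and is illustrative):

```python
import numpy as np
import torch

from pipelines import config
from pipelines.utils.point_utils import read_point_ply

cfg = config.load_config('configs/lig/lig_pretrained.yaml', 'configs/default.yaml')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the model and load the pretrained weights
# (run_lig.py filters the state dict so that only matching keys are updated).
model = config.get_model(cfg, device=device)
state = torch.load('pretrained_models/lig/model_best.pt', map_location=device)
model.load_state_dict({k: v for k, v in state.items() if k in model.state_dict()}, strict=False)

generator = config.get_generator(model, cfg, device=device, output_path='./')

v, n = read_point_ply('demo_data/living_room_33_1000_per_m2.ply')  # points, normals
mesh = generator.generate_single_obj_mesh(v.astype(np.float32), n.astype(np.float32))
mesh.export('test.ply')
```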
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | method: onet
2 | data:
3 |   dataset: Shapes3D
4 |   path: /media/data3/occupancy_networks/data/ShapeNet/
5 |   classes: null
6 |   input_type: img
7 |   train_split: train
8 |   val_split: val
9 |   test_split: test
10 |   dim: 3
11 |   points_file: points.npz
12 |   points_iou_file: points.npz
13 |   points_subsample: 1024
14 |   points_unpackbits: true
15 |   model_file: model.off
16 |   watertight_file: model_watertight.off
17 |   img_folder: img_choy2016
18 |   img_size: 224
19 |   img_with_camera: false
20 |   img_augment: false
21 |   n_views: 24
22 |   pointcloud_file: pointcloud.npz
23 |   pointcloud_chamfer_file: pointcloud.npz
24 |   pointcloud_n: 256
25 |   pointcloud_target_n: 1024
26 |   pointcloud_noise: 0.05
27 |   voxels_file: 'model.binvox'
28 |   with_transforms: false
29 | model:
30 |   decoder: simple
31 |   encoder: resnet18
32 |   encoder_latent: null
33 |   decoder_kwargs: {}
34 |   encoder_kwargs: {}
35 |   encoder_latent_kwargs: {}
36 |   multi_gpu: false
37 |   c_dim: 512
38 |   z_dim: 64
39 |   use_camera: false
40 |   dmc_weight_prior: 10.
41 | training:
42 |   out_dir: out/default
43 |   batch_size: 64
44 |   print_every: 10
45 |   visualize_every: 2000
46 |   checkpoint_every: 1000
47 |   validate_every: 2000
48 |   backup_every: 100000
49 |   eval_sample: false
50 |   model_selection_metric: loss
51 |   model_selection_mode: minimize
52 | test:
53 |   threshold: 0.5
54 |   eval_mesh: true
55 |   eval_pointcloud: true
56 |   model_file: model_best.pt
57 |   eval_npoints: 100000
58 | generation:
59 |   batch_size: 100000
60 |   refinement_step: 0
61 |   vis_n_outputs: 30
62 |   generate_mesh: true
63 |   generate_pointcloud: true
64 |   generation_dir: generation
65 |   use_sampling: false
66 |   resolution_0: 32
67 |   upsampling_steps: 2
68 |   simplify_nfaces: null
69 |   copy_groundtruth: false
70 |   copy_input: true
71 |   latent_number: 4
72 |   latent_H: 8
73 |   latent_W: 8
74 |   latent_ny: 2
75 |   latent_nx: 2
76 |   latent_repeat: true
77 | preprocessor:
78 |   type: null
79 |   config: ""
80 |   model_file: null
--------------------------------------------------------------------------------
/configs/lig/lig_eval.yaml:
--------------------------------------------------------------------------------
1 | method: lig
2 | exp_name: lig_object_level
3 | model:
4 |   encoder: null
5 |   encoder_kwargs: null
6 |   decoder: imnet
7 |   decoder_kwargs:
8 |     dim: 3
9 |     in_features: 32
10 |     out_features: 1
11 |     num_filters: 32
12 |   overlap: true  # if false, must set indep_pt_loss==false
13 |   part_size: 0.24
14 |   res_per_part: 0
15 | test:
16 |   threshold: 0.2
17 |   eval_mesh: true
18 |   eval_pointcloud: false
19 | generation:
20 |   out_dir: .
21 |   points_batch: 20000
22 |   conservative: true
23 |   postprocess: true
24 |   indep_pt_loss: true
25 |   optimizer_kwargs:
26 |     latent_size: 32
27 |     alpha_lat: 0.01
28 |     num_optim_samples: 10000
29 |     init_std: 0.02
30 |     learning_rate: 0.001
31 |     optim_steps: 10000
32 |     print_every_n_steps: 1000
33 | 
--------------------------------------------------------------------------------
/configs/lig/lig_pretrained.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/lig/lig_eval.yaml
2 | generation:
3 |   generation_dir: local_implicit_grid
--------------------------------------------------------------------------------
/main/run_lig.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | import numpy as np
5 | import sys
6 | sys.path.append('./')
7 | from pipelines import config
8 | from pipelines.utils.point_utils import read_point_ply
9 | 
10 | parser = argparse.ArgumentParser(description='Reconstruct a mesh from an input point cloud with Local Implicit Grids.')
11 | parser.add_argument('--config', default='configs/lig/lig_pretrained.yaml', type=str, help='Path to config file.')
12 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
13 | parser.add_argument('--input_ply', type=str, help='Input object file')
14 | parser.add_argument('--output_ply', type=str, help='Output object file')
15 | parser.add_argument('--gen', action='store_true', help='generation only: extract the mesh without latent optimization')
16 | parser.add_argument('--continue_training', action='store_true', help='whether to continue training')
17 | parser.add_argument('--model', type=str, default='pretrained_models/lig/model_best.pt', help='pretrained model path')
18 | parser.add_argument('--debug', action='store_true', help='whether it is debug mode')
19 | parser.add_argument('--normalized', action='store_true', help='whether to normalize the input')
20 | args = parser.parse_args()
21 | print(str(args))
22 | 
23 | cfg = config.load_config(args.config, 'configs/default.yaml')
24 | assert not (args.gen and args.continue_training), "Cannot be in generation mode and training mode at the same time"
25 | if args.gen:
26 |     assert args.model != '', "Pretrained model path shouldn't be empty in generation mode"
27 |     cfg['generation']['optimizer_kwargs']['optim_steps'] = 0
28 | if args.continue_training:
29 |     assert args.model != '', "Pretrained model path shouldn't be empty in continue training mode"
30 |     cfg['generation']['optimizer_kwargs']['continue_training'] = args.continue_training
31 | is_cuda = (torch.cuda.is_available() and not args.no_cuda)
32 | device = torch.device("cuda" if is_cuda else "cpu")
33 | 
34 | # Fix seeds of numpy and torch to make results reproducible
35 | np.random.seed(0)
36 | torch.manual_seed(0)
37 | 
38 | # Model
39 | model = config.get_model(cfg, device=device)
40 | 
41 | print('!!! Model Built !!!')
42 | out_dir = cfg['generation']['out_dir']
43 | 
44 | # Initialize generation directory
45 | file_path, file_name = os.path.split(args.input_ply)
46 | obj_name, ext = os.path.splitext(file_name)
47 | generation_dir = os.path.join(out_dir, obj_name + '_debug' if args.debug else obj_name)
48 | if not os.path.exists(generation_dir):
49 |     os.makedirs(generation_dir)
50 | 
51 | output_path = './'
52 | # Set pretrained path and output path of the model
53 | model_path = os.path.join(output_path, 'model')
54 | cfg['model']['model_path'] = model_path
55 | cfg['model']['pretrained_path'] = args.model
56 | 
57 | # Load pretrained weights
58 | if args.model != '':
59 |     pretrained_dict = torch.load(args.model, map_location=device)
60 |     model_dict = model.state_dict()
61 |     update_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
62 |     model_dict.update(update_dict)
63 |     model.load_state_dict(model_dict)
64 |     print(f"Total {len(model_dict)} parameters, updated {len(update_dict)} parameters")
65 | 
66 | # Generator
67 | generator = config.get_generator(model, cfg, device=device, output_path=output_path)
68 | 
69 | v, n = read_point_ply(args.input_ply)
70 | v = v.astype(np.float32)
71 | n = n.astype(np.float32)
72 | 
73 | # Normalize to unit sphere
74 | if args.normalized:
75 |     v = v - v.mean(axis=0)
76 |     v = v / (np.linalg.norm(v, ord=2, axis=1).max() + 1e-12)
77 | 
78 | mesh = generator.generate_single_obj_mesh(v, n)
79 | 
80 | # Write output
81 | mesh.export(args.output_ply)
--------------------------------------------------------------------------------
/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wnbzhao/Local-Implicit-Grid-Pytorch/d45da37beda52653f0066f9ba0f0500c54402e13/pipelines/__init__.py
--------------------------------------------------------------------------------
/pipelines/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | 
3 | from pipelines import lig
4 | 
5 | method_dict = {
6 |     'lig': lig,
7 | }
8 | 
9 | 
10 | # General config
11 | def load_config(path, default_path=None):
12 |     ''' Loads config file.
13 | 
14 |     Args:
15 |         path (str): path to config file
16 |         default_path (str): path to the default config file
17 |     '''
18 |     # Load configuration from file itself
19 |     with open(path, 'r') as f:
20 |         cfg_special = yaml.load(f, Loader=yaml.FullLoader)
21 | 
22 |     # Check if we should inherit from a config
23 |     inherit_from = cfg_special.get('inherit_from')
24 | 
25 |     # If yes, load this config first as default
26 |     # If no, use the default_path
27 |     if inherit_from is not None:
28 |         cfg = load_config(inherit_from, default_path)
29 |     elif default_path is not None:
30 |         with open(default_path, 'r') as f:
31 |             cfg = yaml.load(f, Loader=yaml.FullLoader)
32 |     else:
33 |         cfg = dict()
34 | 
35 |     # Include main configuration
36 |     update_recursive(cfg, cfg_special)
37 | 
38 |     return cfg
39 | 
40 | 
41 | def update_recursive(dict1, dict2):
42 |     ''' Updates dict1 with the entries of dict2 recursively.
43 | 
44 |     Args:
45 |         dict1 (dict): first dictionary to be updated
46 |         dict2 (dict): second dictionary whose entries should be used
47 | 
48 |     '''
49 |     for k, v in dict2.items():
50 |         if k not in dict1:
51 |             dict1[k] = dict()
52 |         if isinstance(v, dict):
53 |             update_recursive(dict1[k], v)
54 |         else:
55 |             dict1[k] = v
56 | 
57 | 
58 | # Models
59 | def get_model(cfg, device=None, dataset=None):
60 |     ''' Returns the model instance.
61 | 
62 |     Args:
63 |         cfg (dict): config dictionary
64 |         device (device): pytorch device
65 |         dataset (dataset): dataset
66 |     '''
67 |     method = cfg['method']
68 |     model = method_dict[method].config.get_model(cfg, device=device, dataset=dataset)
69 |     return model
70 | 
71 | 
72 | # Generator for final mesh extraction
73 | def get_generator(model, cfg, device, **kwargs):
74 |     ''' Returns a generator instance.
75 |     Args:
76 |         model (nn.Module): the model which is used
77 |         cfg (dict): config dictionary
78 |         device (device): pytorch device
79 |     '''
80 |     method = cfg['method']
81 |     generator = method_dict[method].config.get_generator(model, cfg, device, **kwargs)
82 |     return generator
83 | 
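The `inherit_from` chain resolved by `load_config` bottoms out in `update_recursive`, which merges nested dicts key by key instead of overwriting whole subtrees. A toy illustration (the values are made up):

```python
base = {'model': {'decoder': 'simple', 'c_dim': 512}, 'test': {'threshold': 0.5}}
special = {'model': {'decoder': 'imnet'}}  # e.g. parsed from a lig config

update_recursive(base, special)
# base is now {'model': {'decoder': 'imnet', 'c_dim': 512}, 'test': {'threshold': 0.5}}
# sibling keys from the default config survive the merge
```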
--------------------------------------------------------------------------------
/pipelines/lig/__init__.py:
--------------------------------------------------------------------------------
1 | from pipelines.lig import (
2 |     config, generation, models
3 | )
4 | 
5 | __all__ = [
6 |     'config', 'generation', 'models'
7 | ]
8 | 
--------------------------------------------------------------------------------
/pipelines/lig/config.py:
--------------------------------------------------------------------------------
1 | from pipelines.lig import models, generation
2 | from pipelines.lig.models import LocalImplicitGrid
3 | from pipelines.lig.models.layers import GridInterpolationLayer
4 | from pipelines.lig.models.optimizer import LIGOptimizer
5 | 
6 | 
7 | def get_model(cfg, device=None, dataset=None, **kwargs):
8 |     ''' Returns the Local Implicit Grid Network model.
9 | 
10 |     Args:
11 |         cfg (dict): imported yaml config
12 |         device (device): pytorch device
13 |         dataset (dataset): dataset
14 |     '''
15 |     encoder = cfg['model']['encoder']
16 |     encoder_kwargs = cfg['model']['encoder_kwargs']
17 |     decoder = cfg['model']['decoder']
18 |     decoder_kwargs = cfg['model']['decoder_kwargs']
19 |     if encoder is not None:
20 |         encoder = models.encoder_dict[encoder](**encoder_kwargs).to(device)
21 |     decoder = models.decoder_dict[decoder](**decoder_kwargs).to(device)
22 |     grid_interp_layer = GridInterpolationLayer()
23 |     method = 'linear' if cfg['model']['overlap'] else 'nn'
24 |     x_location_max = 1.0 if cfg['model']['overlap'] else 2.0
25 |     interp = not cfg['generation']['indep_pt_loss']
26 | 
27 |     model = LocalImplicitGrid(encoder, decoder, grid_interp_layer,
28 |                               method, x_location_max, interp, device)
29 | 
30 |     return model
31 | 
32 | 
33 | def get_generator(model, cfg, device, **kwargs):
34 |     ''' Returns the generator object.
35 | 
36 |     Args:
37 |         model (nn.Module): trained Local Implicit Grid model
38 |         cfg (dict): imported yaml config
39 |         device (device): pytorch device
40 |     '''
41 |     # Optimizer parameters
42 |     optimizer_kwargs = cfg['generation']['optimizer_kwargs']
43 |     latent_size = optimizer_kwargs['latent_size']
44 |     alpha_lat = optimizer_kwargs['alpha_lat']
45 |     num_optim_samples = optimizer_kwargs['num_optim_samples']
46 |     init_std = optimizer_kwargs['init_std']
47 |     learning_rate = optimizer_kwargs['learning_rate']
48 |     optim_steps = optimizer_kwargs['optim_steps']
49 |     print_every_n_steps = optimizer_kwargs['print_every_n_steps']
50 |     indep_pt_loss = cfg['generation']['indep_pt_loss']
51 | 
52 |     optimizer = LIGOptimizer(model, latent_size=latent_size, alpha_lat=alpha_lat,
53 |                              num_optim_samples=num_optim_samples, init_std=init_std,
54 |                              learning_rate=learning_rate, optim_steps=optim_steps,
55 |                              print_every_n_steps=print_every_n_steps,
56 |                              indep_pt_loss=indep_pt_loss, device=device)
57 | 
58 |     # Generator parameters
59 |     part_size = cfg['model']['part_size']
60 |     res_per_part = cfg['model']['res_per_part']
61 |     overlap = cfg['model']['overlap']
62 |     points_batch = cfg['generation']['points_batch']
63 |     conservative = cfg['generation']['conservative']
64 |     postprocess = cfg['generation']['postprocess']
65 | 
66 |     generator = generation.Generator3D(
67 |         model, optimizer, part_size=part_size, num_optim_samples=num_optim_samples,
68 |         res_per_part=res_per_part, overlap=overlap, device=device,
69 |         points_batch=points_batch, conservative=conservative,
70 |         postprocess=postprocess
71 |     )
72 |     return generator
--------------------------------------------------------------------------------
/pipelines/lig/generation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch import autograd
4 | import numpy as np
5 | from tqdm import trange
6 | import trimesh
7 | from skimage import measure
8 | import warnings
9 | import time
10 | from pipelines.utils.point_utils import sample_points_from_ray, np_get_occupied_idx, occupancy_sparse_to_dense
11 | from pipelines.utils.postprocess_utils import remove_backface
12 | 
13 | 
14 | class Generator3D(object):
15 |     ''' Generator class for Local Implicit Grid networks.
16 | 
17 |     It provides functions to generate the final mesh as well as refining options.
18 | 
19 |     Args:
20 |         model (nn.Module): trained Local implicit grid model
21 |         optimizer (object): optimization utility class for optimizing latent grid
22 |         part_size (float): size of a part
23 |         num_optim_samples (int): number of points to sample at each optimization step
24 |         res_per_part (int): output resolution within each part (0 means use the default of 64 * part_size)
25 |         overlap (bool): whether we use overlapping grids
26 |         device (device): pytorch device
27 |         points_batch (int): number of points we evaluate sdf values each time
28 |         conservative (bool): if True, evaluate a cell when any of its 8 neighbors contains points; if False, require all 8 (see get_eval_inputs)
29 |         postprocess (bool): whether to use post process to remove back faces
30 |     '''
31 |     def __init__(self,
32 |                  model,
33 |                  optimizer,
34 |                  part_size=0.25,
35 |                  num_optim_samples=2048,
36 |                  res_per_part=0,
37 |                  overlap=True,
38 |                  device=None,
39 |                  points_batch=20000,
40 |                  conservative=False,
41 |                  postprocess=True):
42 |         self.model = model.to(device)
43 |         self.optimizer = optimizer
44 |         self.part_size = part_size
45 |         self.num_optim_samples = num_optim_samples
46 |         if res_per_part == 0:
47 |             self.res_per_part = int(64 * self.part_size)
48 |         else:
49 |             self.res_per_part = res_per_part
50 |         self.overlap = overlap
51 |         self.device = device
52 |         self.points_batch = points_batch
53 |         self.conservative = conservative
54 |         self.postprocess = postprocess
55 | 
56 |     def generate_mesh(self, data, return_stats=True):
57 |         ''' Generates the output mesh from inputs loaded from the dataset.
58 | 
59 |         Args:
60 |             data (tensor): data tensor
61 |             return_stats (bool): whether stats should be returned
62 |         '''
63 |         stats_dict = {}
64 | 
65 |         v = data.get('inputs', torch.empty(1, 0)).squeeze(0).cpu().numpy()
66 |         n = data.get('inputs.normals', torch.empty(1, 0)).squeeze(0).cpu().numpy()
67 |         mesh = self.generate_single_obj_mesh(v, n)
68 |         return mesh
69 | 
70 |     def generate_single_obj_mesh(self, v, n):
71 |         ''' Generates the output mesh of a user specified single object.
72 | 
73 |         Args:
74 |             v (numpy array): [#v, 3], input point cloud.
75 |             n (numpy array): [#v, 3], normals of the input point cloud.
76 |         Returns:
77 |             mesh (trimesh.Trimesh obj): output mesh object.
78 |         '''
79 |         device = self.device
80 | 
81 |         surface_points = np.concatenate([v, n], axis=1)
82 | 
83 |         xmin = np.min(v, axis=0)
84 |         xmax = np.max(v, axis=0)
85 | 
86 |         # check if part size is too large
87 |         min_bb = np.min(xmax - xmin)
88 |         if self.part_size > 0.25 * min_bb:
89 |             warnings.warn(
90 |                 'WARNING: part_size seems too large. 
Recommend using a part_size < ' 91 | '{:.2f} for this shape.'.format(0.25 * min_bb), UserWarning) 92 | 93 | # add some extra slack to xmin and xmax 94 | xmin -= self.part_size 95 | xmax += self.part_size 96 | 97 | ######################################################################### 98 | # generate sdf samples from pc 99 | point_samples, sdf_values = sample_points_from_ray(v, n, sample_factor=10, std=0.01) 100 | 101 | # shuffle 102 | shuffle_index = np.random.permutation(point_samples.shape[0]) 103 | point_samples = point_samples[shuffle_index] 104 | sdf_values = sdf_values[shuffle_index] 105 | 106 | ######################################################################### 107 | ################### only evaluated at sparse grid location ############## 108 | ######################################################################### 109 | # get valid girds (we only evaluate on sparse locations) 110 | # _.shape==(total_ncrops, ntarget, v.shape[1]) points within voxel 111 | # occ_idx.shape==(total_ncrops, 3) index of each voxel 112 | # grid_shape == (rr[0], rr[1], rr[2]) 113 | _, occ_idx, grid_shape = np_get_occupied_idx( 114 | point_samples[:100000, :3], 115 | # point_samples[:, :3], 116 | xmin=xmin - 0.5 * self.part_size, 117 | xmax=xmax + 0.5 * self.part_size, 118 | crop_size=self.part_size, 119 | ntarget=1, # we do not require `point_crops` (i.e. `_` in returns), so we set it to 1 120 | overlap=self.overlap, 121 | normalize_crops=False, 122 | return_shape=True) 123 | 124 | print('LIG shape: {}'.format(grid_shape)) 125 | 126 | ######################################################################### 127 | # treat as one batch 128 | point_samples = torch.from_numpy(point_samples).to(device) 129 | sdf_values = torch.from_numpy(sdf_values).to(device) 130 | occ_idx_tensor = torch.from_numpy(occ_idx).to(device) 131 | point_samples = point_samples.unsqueeze(0) # shape==(1, npoints, 3) 132 | sdf_values = sdf_values.unsqueeze(0) # shape==(1, npoints, 1) 133 | occ_idx_tensor = occ_idx_tensor.unsqueeze(0) # shape==(1, total_ncrops, 3) 134 | 135 | # set range for computation 136 | true_shape = ((np.array(grid_shape) - 1) / (2.0 if self.overlap else 1.0)).astype(np.int32) 137 | self.model.set_xrange(xmin=xmin, xmax=xmin + true_shape * self.part_size) 138 | 139 | # Clip the point position 140 | xmin_ = self.model.grid_interp_layer.xmin 141 | xmax_ = self.model.grid_interp_layer.xmax 142 | x = point_samples[:, :, 0].clamp(xmin_[0], xmax_[0]) 143 | y = point_samples[:, :, 1].clamp(xmin_[1], xmax_[1]) 144 | z = point_samples[:, :, 2].clamp(xmin_[2], xmax_[2]) 145 | point_samples = torch.stack([x, y, z], dim=2) 146 | 147 | # get label (inside==-1, outside==+1) 148 | point_values = torch.sign(sdf_values) 149 | 150 | ######################################################################### 151 | ###################### Build/Optimize latent grid ####################### 152 | ######################################################################### 153 | # optimize latent grids, shape==(1, *grid_shape, code_len) 154 | print('Optimizing latent codes in LIG...') 155 | latent_grid = self.optimizer.optimize_latent_code(point_samples, point_values, occ_idx_tensor, grid_shape) 156 | 157 | ######################################################################### 158 | ##################### Evaluation (Marching Cubes) ####################### 159 | ######################################################################### 160 | # sparse occ index to dense occ grids 161 | # (total_ncrops, 3) --> (*grid_shape, ) bool 
162 | occ_mask = occupancy_sparse_to_dense(occ_idx, grid_shape) 163 | 164 | # points shape to be evaluated 165 | output_grid_shape = list(self.res_per_part * true_shape) 166 | # output_grid is ones, shape==(?, ) 167 | # xyz is points to be evaluated (dense, shape==(?, 3)) 168 | output_grid, xyz = self.get_eval_grid(xmin=xmin, 169 | xmax=xmin + true_shape * self.part_size, 170 | output_grid_shape=output_grid_shape) 171 | 172 | # we only evaluate eval_points 173 | # out_mask is for xyz, i.e. eval_points = xyz[occ_mask] 174 | eval_points, out_mask = self.get_eval_inputs(xyz, xmin, occ_mask) 175 | eval_points = torch.from_numpy(eval_points).to(device) 176 | 177 | # evaluate dense grid for marching cubes (on sparse grids) 178 | output_grid = self.generate_occ_grid(latent_grid, eval_points, output_grid, out_mask) 179 | output_grid = output_grid.reshape(*output_grid_shape) 180 | 181 | v, f, _, _ = measure.marching_cubes_lewiner(output_grid, 0) # logits==0 182 | v *= (self.part_size / float(self.res_per_part) * (np.array(output_grid.shape, dtype=np.float32) / 183 | (np.array(output_grid.shape, dtype=np.float32) - 1))) 184 | v += xmin 185 | 186 | # Create mesh 187 | mesh = trimesh.Trimesh(v, f) 188 | 189 | # Post-process the generated mesh to prevent artifacts 190 | if self.postprocess: 191 | print('Postprocessing generated mesh...') 192 | mesh = remove_backface(mesh, surface_points) 193 | 194 | return mesh 195 | 196 | def get_eval_grid(self, xmin, xmax, output_grid_shape): 197 | """Initialize the eval output grid and its corresponding grid points. 198 | 199 | Args: 200 | xmin (numpy array): [3], minimum xyz values of the entire space. 201 | xmax (numpy array): [3], maximum xyz values of the entire space. 202 | output_grid_shape (list): [3], latent grid shape. 203 | Returns: 204 | output_grid (numpy array): [d*h*w] output grid sdf values. 205 | xyz (numpy array): [d*h*w, 3] grid point xyz coordinates. 206 | """ 207 | # setup grid 208 | eps = 1e-6 209 | l = [np.linspace(xmin[i] + eps, xmax[i] - eps, output_grid_shape[i]) for i in range(3)] 210 | xyz = np.stack(np.meshgrid(l[0], l[1], l[2], indexing='ij'), axis=-1).astype(np.float32) 211 | 212 | output_grid = np.ones(output_grid_shape, dtype=np.float32) 213 | xyz = xyz.reshape(-1, 3) 214 | output_grid = output_grid.reshape(-1) 215 | 216 | return output_grid, xyz 217 | 218 | def get_eval_inputs(self, xyz, xmin, occ_mask): 219 | """Gathers the points within the grids that any/all of its 8 neighbors 220 | contains points. 221 | 222 | If self.conservative is True, gathers the points within the grids that any of its 8 neighbors 223 | contains points. 224 | If self.conservative is False, gathers the points within the grids that all of its 8 neighbors 225 | contains points. 226 | Returns the points need to be evaluate and the mask of the points and the output grid. 227 | 228 | Args: 229 | xyz (numpy array): [h*w*d, 3] 230 | xmin (numpy array): [3] minimum value of the entire space. 231 | occ_mask (numpy array): latent grid occupancy mask. 232 | Returns: 233 | eval_points (numpy array): [neval, 3], points to be evaluated. 234 | out_mask (numpy array): [h*w*d], 0 1 value eval mask of the final sdf grid. 
235 | """ 236 | mask = occ_mask.astype(np.bool) 237 | if self.overlap: 238 | mask = np.stack([ 239 | mask[:-1, :-1, :-1], mask[:-1, :-1, 1:], mask[:-1, 1:, :-1], mask[:-1, 1:, 1:], mask[1:, :-1, :-1], 240 | mask[1:, :-1, 1:], mask[1:, 1:, :-1], mask[1:, 1:, 1:] 241 | ], 242 | axis=-1) 243 | if self.conservative: 244 | mask = np.any(mask, axis=-1) 245 | else: 246 | mask = np.all(mask, axis=-1) 247 | 248 | g = np.stack(np.meshgrid(np.arange(mask.shape[0]), 249 | np.arange(mask.shape[1]), 250 | np.arange(mask.shape[2]), 251 | indexing='ij'), 252 | axis=-1).reshape(-1, 3) 253 | g = g[:, 0] * (mask.shape[1] * mask.shape[2]) + g[:, 1] * mask.shape[2] + g[:, 2] 254 | g_valid = g[mask.ravel()] # valid grid index 255 | 256 | if self.overlap: 257 | ijk = np.floor((xyz - xmin) / self.part_size * 2).astype(np.int32) 258 | else: 259 | ijk = np.floor((xyz - xmin + 0.5 * self.part_size) / self.part_size).astype(np.int32) 260 | ijk_idx = (ijk[:, 0] * (mask.shape[1] * mask.shape[2]) + ijk[:, 1] * mask.shape[2] + ijk[:, 2]) 261 | out_mask = np.isin(ijk_idx, g_valid) 262 | eval_points = xyz[out_mask] 263 | return eval_points, out_mask 264 | 265 | def generate_occ_grid(self, latent_grid, eval_points, output_grid, out_mask): 266 | """Gets the final output occ grid. 267 | 268 | Args: 269 | latent_grid (tensor): [1, *grid_shape, latent_size], optimized latent grid. 270 | eval_points (tensor): [neval, 3], points to be evaluated. 271 | output_grid (numpy array): [d*h*w], final output occ grid. 272 | out_mask (numpy array): [d*h*w], mask indicating the grids evaluated. 273 | Returns: 274 | output_grid (numpy array): [d*h*w], final output occ grid flattened. 275 | """ 276 | interp_old = self.model.interp 277 | self.model.interp = True 278 | 279 | split = int(np.ceil(eval_points.shape[0] / self.points_batch)) 280 | occ_val_list = [] 281 | self.model.eval() 282 | with torch.no_grad(): 283 | for s in range(split): 284 | sid = s * self.points_batch 285 | eid = min((s + 1) * self.points_batch, eval_points.shape[0]) 286 | eval_points_slice = eval_points[sid:eid, :] 287 | occ_vals = self.model.decode(latent_grid, eval_points_slice.unsqueeze(0)) 288 | occ_vals = occ_vals.squeeze(0).squeeze(1).cpu().numpy() 289 | occ_val_list.append(occ_vals) 290 | occ_vals = np.concatenate(occ_val_list, axis=0) 291 | output_grid[out_mask] = occ_vals 292 | 293 | self.model.interp = interp_old 294 | return output_grid 295 | -------------------------------------------------------------------------------- /pipelines/lig/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pipelines.lig.models import encoder, decoder 5 | 6 | encoder_dict = { 7 | 'unet3d': encoder.UNet3D, 8 | } 9 | 10 | decoder_dict = { 11 | 'imnet': decoder.IMNet, 12 | } 13 | 14 | 15 | class LocalImplicitGrid(nn.Module): 16 | def __init__( 17 | self, 18 | encoder, # 19 | decoder, # e.g. 
imnet 20 | grid_interp_layer, 21 | method, 22 | x_location_max, 23 | interp, 24 | device, 25 | ): 26 | super(LocalImplicitGrid, self).__init__() 27 | if encoder is not None: 28 | self.encoder = encoder.to(device) 29 | self.decoder = decoder.to(device) 30 | self.grid_interp_layer = grid_interp_layer 31 | self.method = method 32 | self.x_location_max = x_location_max 33 | self.interp = interp 34 | self.device = device 35 | # Print warning if x_location_max and method do not match 36 | if not ((x_location_max == 1 and method == "linear") or (x_location_max == 2 and method == "nn")): 37 | raise ValueError("Bad combination of x_location_max and method.") 38 | 39 | def forward(self, inputs, pts): 40 | grid = self.encoder(inputs) 41 | values = self.decode(grid, pts) 42 | return values 43 | 44 | def set_xrange(self, xmin, xmax): 45 | """Sets the xyz range during inference. 46 | 47 | Args: 48 | xmin (numpy array): minimum xyz values of input points. 49 | xmax (numpy array): maximum xyz values of input points. 50 | """ 51 | setattr(self.grid_interp_layer, 'xmin', torch.from_numpy(xmin).type(torch.cuda.FloatTensor)) 52 | setattr(self.grid_interp_layer, 'xmax', torch.from_numpy(xmax).type(torch.cuda.FloatTensor)) 53 | 54 | def decode(self, grid, pts): 55 | lat, weights, xloc = self.grid_interp_layer(grid, pts) 56 | xloc = xloc * self.x_location_max 57 | if self.method == "linear": 58 | input_features = torch.cat([xloc, lat], dim=3) # bs*npoints*nneighbors*c 59 | values = self.decoder(input_features) 60 | if self.interp: 61 | values = (values * weights.unsqueeze(3)).sum(dim=2) # bs*npoints*1 62 | else: 63 | values = (values, weights) 64 | else: 65 | # nearest neighbor 66 | bs, npoints, nneighbors, c = lat.size() 67 | nearest_neighbor_idxs = weights.max(dim=2, keepdim=True)[1] 68 | lat = torch.gather(lat, dim=2, index=nearest_neighbor_idxs.unsqueeze(3).expand(bs, npoints, 1, 69 | c)) # bs*npoints*1*c 70 | lat = lat.squeeze(2) # bs*npoints*c 71 | xloc = torch.gather(xloc, dim=2, index=nearest_neighbor_idxs.unsqueeze(3).expand(bs, npoints, 1, 3)) 72 | xloc = xloc.squeeze(2) # bs*npoints*3 73 | input_features = torch.cat([xloc, lat], dim=2) 74 | values = self.decoder(input_features) 75 | 76 | return values 77 | -------------------------------------------------------------------------------- /pipelines/lig/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class IMNet(nn.Module): 6 | """ImNet layer py-torch implementation.""" 7 | 8 | def __init__(self, dim=3, in_features=128, out_features=1, num_filters=128, activation=nn.LeakyReLU(0.2)): 9 | """Initialization. 10 | 11 | Args: 12 | dim: int, dimension of input points. 13 | in_features: int, length of input features (i.e., latent code). 14 | out_features: number of output features. 15 | num_filters: int, width of the second to last layer. 16 | activation: activation function. 
17 | """ 18 | super(IMNet, self).__init__() 19 | self.dim = dim 20 | self.in_features = in_features 21 | self.dimz = dim + in_features 22 | self.out_features = out_features 23 | self.num_filters = num_filters 24 | self.activ = activation 25 | self.fc0 = nn.Linear(self.dimz, num_filters * 16) 26 | self.fc1 = nn.Linear(self.dimz + num_filters * 16, num_filters * 8) 27 | self.fc2 = nn.Linear(self.dimz + num_filters * 8, num_filters * 4) 28 | self.fc3 = nn.Linear(self.dimz + num_filters * 4, num_filters * 2) 29 | self.fc4 = nn.Linear(self.dimz + num_filters * 2, num_filters * 1) 30 | self.fc5 = nn.Linear(num_filters * 1, out_features) 31 | self.fc = [self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5] 32 | 33 | def forward(self, x): 34 | """Forward method. 35 | 36 | Args: 37 | x: `[batch_size, dim+in_features]` tensor, inputs to decode. 38 | Returns: 39 | x_: output through this layer. 40 | """ 41 | x_ = x 42 | for dense in self.fc[:4]: 43 | x_ = self.activ(dense(x_)) 44 | x_ = torch.cat([x_, x], dim=-1) 45 | x_ = self.activ(self.fc4(x_)) 46 | x_ = self.fc5(x_) 47 | return x_ 48 | -------------------------------------------------------------------------------- /pipelines/lig/models/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from pipelines.lig.models.layers import ResBlock3D 5 | 6 | 7 | class UNet3D(nn.Module): 8 | """UNet that inputs even dimension grid and outputs even dimension grid.""" 9 | 10 | def __init__(self, 11 | dim=3, 12 | in_grid_res=32, 13 | out_grid_res=16, 14 | num_filters=16, 15 | max_filters=512, 16 | out_features=32): 17 | """Initialization. 18 | 19 | Args: 20 | in_grid_res: int, input grid resolution, must be powers of 2. 21 | out_grid_res: int, output grid resolution, must be powers of 2. 22 | num_filters: int, number of feature layers at smallest grid resolution. 23 | max_filters: int, max number of feature layers at any resolution. 24 | out_features: int, number of output feature channels. 25 | 26 | Raises: 27 | ValueError: if in_grid_res or out_grid_res is not powers of 2. 28 | """ 29 | super(UNet3D, self).__init__() 30 | self.in_grid_res = in_grid_res 31 | self.out_grid_res = out_grid_res 32 | self.num_filters = num_filters 33 | self.max_filters = max_filters 34 | self.out_features = out_features 35 | 36 | # assert dimensions acceptable 37 | if math.log(out_grid_res, 2) % 1 != 0 or math.log(in_grid_res, 2) % 1 != 0: 38 | raise ValueError('in_grid_res and out_grid_res must be 2**n.') 39 | 40 | self.num_in_level = math.log(self.in_grid_res, 2) 41 | self.num_out_level = math.log(self.out_grid_res, 2) 42 | self.num_in_level = int(self.num_in_level) # number of input levels 43 | self.num_out_level = int(self.num_out_level) # number of output levels 44 | 45 | self._create_layers() 46 | 47 | def _create_layers(self): 48 | num_filter_down = [ 49 | self.num_filters * (2 ** (i + 1)) for i in range(self.num_in_level) 50 | ] 51 | # num. 
--------------------------------------------------------------------------------
/pipelines/lig/models/encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from pipelines.lig.models.layers import ResBlock3D
5 | 
6 | 
7 | class UNet3D(nn.Module):
8 |     """UNet that inputs even dimension grid and outputs even dimension grid."""
9 | 
10 |     def __init__(self,
11 |                  dim=3,
12 |                  in_grid_res=32,
13 |                  out_grid_res=16,
14 |                  num_filters=16,
15 |                  max_filters=512,
16 |                  out_features=32):
17 |         """Initialization.
18 | 
19 |         Args:
20 |             in_grid_res: int, input grid resolution, must be powers of 2.
21 |             out_grid_res: int, output grid resolution, must be powers of 2.
22 |             num_filters: int, number of feature layers at smallest grid resolution.
23 |             max_filters: int, max number of feature layers at any resolution.
24 |             out_features: int, number of output feature channels.
25 | 
26 |         Raises:
27 |             ValueError: if in_grid_res or out_grid_res is not powers of 2.
28 |         """
29 |         super(UNet3D, self).__init__()
30 |         self.in_grid_res = in_grid_res
31 |         self.out_grid_res = out_grid_res
32 |         self.num_filters = num_filters
33 |         self.max_filters = max_filters
34 |         self.out_features = out_features
35 | 
36 |         # assert dimensions acceptable
37 |         if math.log(out_grid_res, 2) % 1 != 0 or math.log(in_grid_res, 2) % 1 != 0:
38 |             raise ValueError('in_grid_res and out_grid_res must be 2**n.')
39 | 
40 |         self.num_in_level = math.log(self.in_grid_res, 2)
41 |         self.num_out_level = math.log(self.out_grid_res, 2)
42 |         self.num_in_level = int(self.num_in_level)  # number of input levels
43 |         self.num_out_level = int(self.num_out_level)  # number of output levels
44 | 
45 |         self._create_layers()
46 | 
47 |     def _create_layers(self):
48 |         num_filter_down = [
49 |             self.num_filters * (2 ** (i + 1)) for i in range(self.num_in_level)
50 |         ]
51 |         # num. features in downward path
52 |         num_filter_down = [
53 |             n if n <= self.max_filters else self.max_filters
54 |             for n in num_filter_down
55 |         ]
56 |         num_filter_up = num_filter_down[::-1][:self.num_out_level]
57 |         self.num_filter_down = num_filter_down
58 |         self.num_filter_up = num_filter_up
59 |         # ResBlock3D takes (in_channels, neck_channels, out_channels);
60 |         # conv_in assumes the input grid carries num_filters feature channels
61 |         self.conv_in = ResBlock3D(self.num_filters, self.num_filters, self.num_filters)
62 |         self.conv_out = ResBlock3D(
63 |             self.out_features, self.out_features, self.out_features, final_relu=False)
64 |         # wrap in ModuleList so that the submodule parameters are registered
65 |         self.down_modules = nn.ModuleList([ResBlock3D(int(n / 2), int(n / 2), n) for n in num_filter_down])
66 |         self.up_modules = nn.ModuleList([ResBlock3D(n, n, n) for n in num_filter_up])
67 |         self.dnpool = nn.MaxPool3d(2, stride=2)
68 |         self.upsamp = nn.Upsample(scale_factor=2, mode='trilinear')
69 |         self.up_final = nn.Upsample(scale_factor=2, mode='trilinear')
70 | 
71 |     def forward(self, x):
72 |         """Forward method.
73 | 
74 |         Args:
75 |             x: `[batch, in_grid_res, in_grid_res, in_grid_res, in_features]` tensor,
76 |                 input voxel grid.
77 | 
78 |         Returns:
79 |             `[batch, out_grid_res, out_grid_res, out_grid_res, out_features]` tensor,
80 |             output voxel grid.
81 |         """
82 |         x = self.conv_in(x)
83 |         x_dns = [x]
84 |         for mod in self.down_modules:
85 |             x_ = self.dnpool(mod(x_dns[-1]))
86 |             x_dns.append(x_)
87 | 
88 |         x_ups = [x_dns.pop(-1)]
89 |         for mod in self.up_modules:
90 |             x_ = torch.cat([self.upsamp(x_ups[-1]), x_dns.pop(-1)], dim=1)  # concat along channels (NCDHW)
91 |             x_ = mod(x_)
92 |             x_ups.append(x_)
93 | 
94 |         x = self.conv_out(x_ups[-1])
95 |         return x
--------------------------------------------------------------------------------
/pipelines/lig/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | 
7 | class ResBlock3D(nn.Module):
8 |     """3D convolutional Residual Block Layer.
9 | 
10 |     Maintains same resolution.
11 |     """
12 |     def __init__(self, dim, neck_channels, out_channels, final_relu=True):
13 |         """Initialization.
14 | 
15 |         Args:
16 |             dim (int): input feature dim
17 |             neck_channels (int): number of channels in bottleneck layer.
18 |             out_channels (int): number of output channels.
19 |             final_relu (bool): whether to add relu to the last layer.
20 | """ 21 | super(ResBlock3D, self).__init__() 22 | self.neck_channels = neck_channels 23 | self.out_channels = out_channels 24 | self.conv1 = nn.Conv3D(dim, neck_channels, kernel_size=1, stride=1) 25 | self.conv2 = nn.Conv3D(neck_channels, neck_channels, kernel_size=3, stride=1, padding=1) 26 | self.conv3 = nn.Conv3D(neck_channels, out_channels, kernel_size=1, stride=1) 27 | self.bn1 = nn.BatchNorm3d(neck_channels) 28 | self.bn2 = nn.BatchNorm3d(neck_channels) 29 | self.bn3 = nn.BatchNorm3d(out_channels) 30 | 31 | self.shortcut = nn.Conv3D(dim, out_channels, kernel_size=1, stride=1) 32 | self.final_relu = final_relu 33 | 34 | def forward(self, x): 35 | # x.shape == (N, C, D, W, H) 36 | 37 | identity = x 38 | 39 | x = self.conv1(x) 40 | x = self.bn1(x) 41 | x = F.relu(x) 42 | 43 | x = self.conv2(x) 44 | x = self.bn2(x) 45 | x = F.relu(x) 46 | 47 | x = self.conv3(x) 48 | x = self.bn3(x) 49 | x += self.shortcut(identity) 50 | if self.final_relu: 51 | x = F.relu(x) 52 | 53 | return x 54 | 55 | 56 | class GridInterpolationLayer(nn.Module): 57 | def __init__(self, xmin=(0, 0, 0), xmax=(1, 1, 1)): 58 | """ 59 | Args: 60 | xmin (tuple): the min vertex of bbox of the scene 61 | xmax (tuple): the max vertex of bbox of the scene 62 | """ 63 | super(GridInterpolationLayer, self).__init__() 64 | self.xmin = torch.cuda.FloatTensor(xmin) 65 | self.xmax = torch.cuda.FloatTensor(xmax) 66 | 67 | def forward(self, grid, pts): 68 | """ Forward pass of grid interpolation layer. 69 | Returning trilinear interpolation neighbor latent codes, weights, and relative coordinates 70 | 71 | Args: 72 | grid (tensor): latent grid | shape==(bs, d, h, w, code_len) | `bs` is scenes batch num (not grids batch num) 73 | pts (tensor): query point, should be xmin<=pts<=xmax | shape==(bs, npoints, 3) 74 | Returns: 75 | lat (tensor): neighbors' latent codes | shape==(bs, npoints, 8, code_len) 76 | weight (tensor): trilinear interpolation weight | shape==(bs, npoints, 8) 77 | xloc (tensor): relative coordinate in local grid, it is normalized into (-1, 1) | shape==(bs, npoints, 8, 3) 78 | """ 79 | # get dimensions 80 | bs, npoints, _ = pts.shape 81 | xmin = self.xmin.reshape([1, 1, -1]) 82 | xmax = self.xmax.reshape([1, 1, -1]) 83 | size = torch.cuda.FloatTensor(list(grid.shape[1:-1])) 84 | cube_size = 1 / (size - 1) 85 | 86 | # normalize coords for interpolation 87 | pts = (pts - xmin) / (xmax - xmin) # normalize to 0 ~ 1 88 | pts = pts.clamp(min=1e-6, max=1 - 1e-6) 89 | ind0 = (pts / cube_size.reshape([1, 1, -1])).floor() # grid index (bs, npoints, 3) 90 | 91 | # get 8 neighbors 92 | offset = torch.Tensor([0, 1]) 93 | grid_x, grid_y, grid_z = torch.meshgrid(*tuple([offset] * 3)) 94 | neighbor_offsets = torch.stack([grid_x, grid_y, grid_z], dim=-1) 95 | neighbor_offsets = neighbor_offsets.reshape(-1, 3) # 8*3 96 | nneighbors = neighbor_offsets.shape[0] 97 | neighbor_offsets = neighbor_offsets.type(torch.cuda.FloatTensor) # shape==(8, 3) 98 | 99 | # get neighbor 8 latent codes 100 | neighbor_indices = ind0.unsqueeze(2) + neighbor_offsets[None, None, :, :] # (bs, npoints, 8, 3) 101 | neighbor_indices = neighbor_indices.type(torch.cuda.LongTensor) 102 | neighbor_indices = neighbor_indices.reshape(bs, -1, 3) # (bs, npoints*8, 3) 103 | d, h, w = neighbor_indices[:, :, 0], neighbor_indices[:, :, 1], neighbor_indices[:, :, 2] # (bs, npoints*8) 104 | batch_idxs = torch.arange(bs).type(torch.cuda.LongTensor) 105 | batch_idxs = batch_idxs.unsqueeze(1).expand(bs, npoints * nneighbors) # bs, 8*npoints 106 | lat = grid[batch_idxs, d, h, w, 
107 |         lat = lat.reshape(bs, npoints, nneighbors, -1)
108 | 
109 |         # get the tri-linear interpolation weights for each point
110 |         xyz0 = ind0 * cube_size.reshape([1, 1, -1])  # (bs, npoints, 3)
111 |         xyz0_expand = xyz0.unsqueeze(2).expand(bs, npoints, nneighbors, 3)  # (bs, npoints, nneighbors, 3)
112 |         xyz_neighbors = xyz0_expand + neighbor_offsets[None, None, :, :] * cube_size
113 | 
114 |         neighbor_offsets_oppo = 1 - neighbor_offsets
115 |         xyz_neighbors_oppo = xyz0.unsqueeze(2) + neighbor_offsets_oppo[None,
116 |                                                                        None, :, :] * cube_size  # bs, npoints, 8, 3
117 |         dxyz = (pts.unsqueeze(2) - xyz_neighbors_oppo).abs() / cube_size
118 |         weight = dxyz[:, :, :, 0] * dxyz[:, :, :, 1] * dxyz[:, :, :, 2]
119 | 
120 |         # relative coordinates inside the grid (-1 ~ 1, e.g. [0~1,0~1,0~1] for min vertex, [-1~0,-1~0,-1~0] for max vertex)
121 |         xloc = (pts.unsqueeze(2) - xyz_neighbors) / cube_size[None, None, None, :]
122 | 
123 |         return lat, weight, xloc
124 | 
125 | 
126 | if __name__ == '__main__':
127 |     grid_interp_layer = GridInterpolationLayer()
128 |     grid = torch.randn(1, 5, 5, 5, 3).type(torch.cuda.FloatTensor)
129 |     # cube_size 0.25
130 |     pts = torch.cuda.FloatTensor([
131 |         [0.125, 0.125, 0.125],  # [0.5, 0.5, 0.5]
132 |         [0.1875, 0.125, 0.125],  # [0.75, 0.5, 0.5]
133 |         [0.05, 0.2, 0.05],  # [0.2, 0.8, 0.2]
134 |         [0.3, 0.575, 0.925],  # [1.2, 2.3, 3.7]
135 |         [0.2, 0.8, 0.5],  # [0.8, 3.2, 2.0]
136 |         [0.975, 0.275, 0.65]  # [3.9, 1.1, 2.6]
137 |     ])
138 |     pts = pts.unsqueeze(0)
139 | 
140 |     lat, weight, loc = grid_interp_layer(grid, pts)
141 | 
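A small sanity check of the interpolation layer above: the eight trilinear weights of every query point should sum to one. A sketch (requires a CUDA device, since the layer allocates `torch.cuda` tensors internally; sizes are illustrative):

```python
import torch
from pipelines.lig.models.layers import GridInterpolationLayer

layer = GridInterpolationLayer()            # unit cube: xmin=(0,0,0), xmax=(1,1,1)
grid = torch.randn(1, 5, 5, 5, 32).cuda()   # (bs, d, h, w, code_len)
pts = torch.rand(1, 100, 3).cuda()
lat, weight, xloc = layer(grid, pts)
print(lat.shape, weight.shape, xloc.shape)  # (1,100,8,32) (1,100,8) (1,100,8,3)
print(weight.sum(dim=2))                    # ~1.0 everywhere
```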
41 | """ 42 | def __init__(self, decoder, interp=True): 43 | super(Linear, self).__init__() 44 | self.decoder = decoder 45 | self.interp = interp 46 | 47 | def forward(self, lat, weights, xloc): 48 | """Forward pass process of Nearest Neighbor Module. 49 | 50 | Args: 51 | lat (tensor): neighbors' latent codes | shape==(bs, npoints, 8, code_len) 52 | weights (tensor): trilinear interpolation weight | shape==(bs, npoints, 8) 53 | xloc (tensor): relative coordinate in local grid, it is normalized into (-1, 1) | shape==(bs, npoints, 8, 3) 54 | Returns: 55 | values (tensor): interpolated value | shape==(bs, npoints, 1) 56 | """ 57 | input_features = torch.cat([xloc, lat], dim=3) # shape==(bs, npoints, 8, 3+code_len) 58 | values = self.decoder(input_features) 59 | if self.interp: 60 | values = (values * weights.unsqueeze(3)).sum(dim=2) # bs*npoints*1 61 | return values 62 | else: 63 | return (values, weights) 64 | -------------------------------------------------------------------------------- /pipelines/lig/models/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import numpy as np 5 | 6 | 7 | class LIGOptimizer(object): 8 | """Utility class for optimizing the input latent code at inference phase. 9 | 10 | Attributes: 11 | model (nn.Module): `GridInterpolationLayer` module which inputs xyz and latent code 12 | and outputs sdf values. 13 | latent_size (int): latent code length. 14 | grid_shape (tuple or torch.Size): grid shape of latent code grid. 15 | alpha_lat (float): loss weight of latent code norm loss during optimization process. 16 | num_optim_samples (int): number of points sampled at each step of optimization. 17 | init_std (float): standard deviation for initializing random latent codes. 18 | learning_rate (float): learning rate of the optimizer. 19 | optim_steps (int): total steps for optimizing the latent codes. 20 | print_every_n_steps (int): frequency of printing the loss information. 21 | """ 22 | def __init__(self, 23 | model, 24 | latent_size=32, 25 | alpha_lat=1e-2, 26 | num_optim_samples=2048, 27 | init_std=1e-2, 28 | learning_rate=1e-3, 29 | optim_steps=10000, 30 | print_every_n_steps=1000, 31 | indep_pt_loss=True, 32 | device=None): 33 | super(LIGOptimizer, self).__init__() 34 | self.model = model.to(device) 35 | self.latent_size = latent_size 36 | self.alpha_lat = alpha_lat 37 | self.num_optim_samples = num_optim_samples 38 | self.init_std = init_std 39 | self.lr = learning_rate 40 | self.optim_steps = optim_steps 41 | self.print_every_n_steps = print_every_n_steps 42 | self.indep_pt_loss = indep_pt_loss 43 | self.device = device 44 | self.loss_fn = nn.BCEWithLogitsLoss(reduction='none') 45 | 46 | def optimize_latent_code(self, points, point_values, occ_idxs, grid_shape): 47 | """Optimizes the latent code for each part of the grid. 48 | 49 | Args: 50 | points (tensor): bs*npoints*3, point samples near the mesh surface. 51 | point_values (tensor): bs*npoints*1, occupancy labels of the corresponding points. # (-1 / +1) 52 | occ_idxs (tensor): bs*noccupied*3, indices of the occupied grid, 53 | i.e. indices of grids to be optimized. 54 | grid_shape (list or tuple): [3], latent grid shape. 55 | Returns: 56 | latent_grid (tensor): bs*d*h*w*c, optimized latent grid. 
57 | """ 58 | device = self.device 59 | 60 | # Get latent code grid for optimization process 61 | bs, npoints, _ = points.shape 62 | noccupied = occ_idxs.shape[1] 63 | si, sj, sk = grid_shape 64 | occ_idxs_flatten = occ_idxs[:, :, 0] * (sj * sk) + occ_idxs[:, :, 1] * sk + occ_idxs[:, :, 2] # bs*npoints 65 | random_latents = torch.randn(bs, noccupied, self.latent_size).type(torch.cuda.FloatTensor) * self.init_std 66 | latent_grid = torch.zeros(bs, (si * sj * sk), self.latent_size).type(torch.cuda.FloatTensor) 67 | occ_idxs_flatten_expanded = occ_idxs_flatten.unsqueeze(2).expand(bs, noccupied, self.latent_size).type(torch.cuda.LongTensor) 68 | latent_grid.scatter_(dim=1, index=occ_idxs_flatten_expanded, src=random_latents) 69 | latent_grid = latent_grid.reshape(bs, si, sj, sk, self.latent_size) 70 | 71 | latent_grid.requires_grad = True 72 | optimizer = optim.Adam([latent_grid], lr=self.lr) 73 | 74 | # Randomly shuffle the points before optimizing 75 | shuffled_idxs = np.random.permutation(npoints) 76 | points = points[:, shuffled_idxs, :] 77 | point_values = point_values[:, shuffled_idxs, :] 78 | 79 | self.model.train() # ???? 80 | for s in range(self.optim_steps): 81 | loss, acc = self.optimize_step(optimizer, latent_grid, points, point_values) 82 | if s % self.print_every_n_steps == 0: 83 | print('Step [{:6d}] Acc: {:5.4f} Loss: {:5.4f}'.format(s, acc.item(), loss.item())) 84 | 85 | return latent_grid 86 | 87 | def optimize_step(self, optimizer, latent_grid, points, point_values): 88 | """Performs an optimize step. 89 | 90 | In-place version of optimizing the input latent grid. 91 | 92 | Args: 93 | optimizer (torch.optim): py-torch optimizer 94 | latent_grid (tensor): 1*h*w*d*c, input random latent grid. 95 | points (tensor): 1*npoints*3, input query points. 96 | point_values (tensor): 1*npoints*1, sign of point sdf values 97 | Returns: 98 | loss (tensor): [1] loss in this optimization step. 99 | acc (tensor): [1] predicted sdf values' sign accuracy. 
100 | """ 101 | optimizer.zero_grad() 102 | point_samples, point_val_samples = self.random_point_sample(points, point_values) 103 | if self.indep_pt_loss: 104 | # 1*npoints*nneighbors*1 1*npoints*nneighbors 105 | pred, weights = self.model.decode(latent_grid, point_samples) 106 | pred_interp = (pred * weights.unsqueeze(3)).sum(dim=2, keepdim=True) 107 | pred = torch.cat([pred, pred_interp], dim=2) # 1*npoints*9*1 108 | point_val_samples = point_val_samples.unsqueeze(2).expand(*pred.size()) # 1*npoints*9*1 109 | else: 110 | pred = self.model.decode(latent_grid, point_samples) 111 | 112 | binary_labels = (point_val_samples + 1) / 2 # 0 / 1 113 | pred_flatten = pred.reshape(-1, 1) 114 | binary_labels = binary_labels.reshape(-1, 1) 115 | loss = self.loss_fn(pred_flatten, binary_labels).mean() 116 | all_norm = torch.norm(latent_grid, dim=4).reshape(-1) 117 | loss_lat = all_norm[torch.abs(all_norm) > 1e-7].mean() * self.alpha_lat 118 | loss = loss + loss_lat 119 | 120 | if self.indep_pt_loss: 121 | both_pos = (pred[:, :, -1, :].sign() > 0) & (point_val_samples[:, :, -1, :].sign() > 0) 122 | both_neg = (pred[:, :, -1, :].sign() < 0) & (point_val_samples[:, :, -1, :].sign() < 0) 123 | else: 124 | both_pos = (pred.sign() > 0) & (point_val_samples.sign() > 0) 125 | both_neg = (pred.sign() < 0) & (point_val_samples.sign() < 0) 126 | correct = (both_pos | both_neg).sum().float() 127 | 128 | bs, nsamples = point_val_samples.shape[0], point_val_samples.shape[1] 129 | acc = correct / (bs * nsamples) 130 | 131 | loss.backward() 132 | optimizer.step() 133 | return loss, acc 134 | 135 | def random_point_sample(self, points, point_vals): 136 | """Samples point-occ pairs randomly. 137 | 138 | Args: 139 | points (tensor): bs*npoints*3 140 | point_vals (tensor): bs*npoints*1 141 | Returns: 142 | points_samples (tensor): bs*self.num_optim_samples*3 143 | point_val_samples (tensor): bs*self.num_optim_samples*1 144 | """ 145 | self.num_optim_samples = min(self.num_optim_samples, points.shape[1]) 146 | start_idx = np.random.choice(points.shape[1] - self.num_optim_samples + 1) 147 | end_idx = start_idx + self.num_optim_samples 148 | point_samples = points[:, start_idx:end_idx, :] 149 | point_val_samples = point_vals[:, start_idx:end_idx, :] 150 | return point_samples, point_val_samples 151 | -------------------------------------------------------------------------------- /pipelines/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | from scipy.spatial import cKDTree 4 | from pipelines.utils.libmesh import check_mesh_contains 5 | 6 | class AverageValueMeter(object): 7 | """Computes and stores the average and current value""" 8 | 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0.0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | def compute_iou(occ1, occ2): 25 | ''' Computes the Intersection over Union (IoU) value for two sets of 26 | occupancy values. 
27 | 
28 |     Args:
29 |         occ1 (tensor): first set of occupancy values
30 |         occ2 (tensor): second set of occupancy values
31 |     '''
32 |     occ1 = np.asarray(occ1)
33 |     occ2 = np.asarray(occ2)
34 | 
35 |     # Put all data in second dimension
36 |     # Also works for 1-dimensional data
37 |     if occ1.ndim >= 2:
38 |         occ1 = occ1.reshape(occ1.shape[0], -1)
39 |     if occ2.ndim >= 2:
40 |         occ2 = occ2.reshape(occ2.shape[0], -1)
41 | 
42 |     # Convert to boolean values
43 |     occ1 = (occ1 >= 0.5)
44 |     occ2 = (occ2 >= 0.5)
45 | 
46 |     # Compute IOU
47 |     area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1)
48 |     area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1)
49 | 
50 |     iou = (area_intersect / area_union)
51 | 
52 |     return iou
53 | 
54 | 
55 | def eval_mesh(mesh, pointcloud_tgt, normals_tgt,
56 |               points_iou, occ_tgt, n_points=100000):
57 |     ''' Evaluates a mesh.
58 | 
59 |     Args:
60 |         mesh (trimesh): mesh which should be evaluated
61 |         pointcloud_tgt (numpy array): target point cloud
62 |         normals_tgt (numpy array): target normals
63 |         points_iou (numpy_array): points tensor for IoU evaluation
64 |         occ_tgt (numpy_array): GT occupancy values for IoU points
65 |     '''
66 |     if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
67 |         pointcloud, idx = mesh.sample(n_points, return_index=True)
68 |         pointcloud = pointcloud.astype(np.float32)
69 |         normals = mesh.face_normals[idx]
70 |     else:
71 |         pointcloud = np.empty((0, 3))
72 |         normals = np.empty((0, 3))
73 | 
74 |     out_dict = eval_pointcloud(
75 |         pointcloud, pointcloud_tgt, normals, normals_tgt)
76 | 
77 |     if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
78 |         occ = check_mesh_contains(mesh, points_iou)
79 |         out_dict['iou'] = compute_iou(occ, occ_tgt)
80 |     else:
81 |         out_dict['iou'] = 0.
82 | 
83 |     return out_dict
84 | 
85 | def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
86 |     ''' Computes minimal distances of each point in points_src to points_tgt.
87 | 
88 |     Args:
89 |         points_src (numpy array): source points
90 |         normals_src (numpy array): source normals
91 |         points_tgt (numpy array): target points
92 |         normals_tgt (numpy array): target normals
93 |     '''
94 |     kdtree = cKDTree(points_tgt)
95 |     dist, idx = kdtree.query(points_src, n_jobs=32)
96 | 
97 |     if normals_src is not None and normals_tgt is not None:
98 |         normals_src = \
99 |             normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
100 |         normals_tgt = \
101 |             normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
102 | 
103 |         normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
104 |         # Handle normals that point in the wrong direction gracefully
105 |         # (mostly due to the method not caring about this in generation)
106 |         normals_dot_product = np.abs(normals_dot_product)
107 |     else:
108 |         normals_dot_product = np.array(
109 |             [np.nan] * points_src.shape[0], dtype=np.float32)
110 |     return dist, normals_dot_product
111 | 
112 | 
113 | def eval_pointcloud(pointcloud, pointcloud_tgt, normals=None, normals_tgt=None):
114 |     ''' Evaluates a point cloud.
115 | 
116 |     Args:
117 |         pointcloud (numpy array): predicted point cloud
118 |         pointcloud_tgt (numpy array): target point cloud
119 |         normals (numpy array): predicted normals
120 |         normals_tgt (numpy array): target normals
121 |     '''
122 |     pointcloud = np.asarray(pointcloud)
123 |     pointcloud_tgt = np.asarray(pointcloud_tgt)
124 | 
125 |     # Completeness: how far are the points of the target point cloud
126 |     # from the predicted point cloud
127 |     dist_backward, completeness_normals = distance_p2p(
128 |         pointcloud_tgt, normals_tgt, pointcloud, normals
129 |     )
130 |     dist_backward2 = dist_backward**2
131 | 
132 |     completeness = dist_backward.mean()
133 |     completeness2 = dist_backward2.mean()
134 |     completeness_normals = completeness_normals.mean()
135 | 
136 |     # Accuracy: how far are the points of the predicted pointcloud
137 |     # from the target pointcloud
138 |     dist_forward, accuracy_normals = distance_p2p(
139 |         pointcloud, normals, pointcloud_tgt, normals_tgt
140 |     )
141 |     dist_forward2 = dist_forward**2
142 | 
143 |     accuracy = dist_forward.mean()
144 |     accuracy2 = dist_forward2.mean()
145 |     accuracy_normals = accuracy_normals.mean()
146 | 
147 |     # Chamfer distance
148 |     chamferL2 = 0.5 * (completeness2 + accuracy2)
149 |     normals_correctness = (
150 |         0.5 * completeness_normals + 0.5 * accuracy_normals
151 |     )
152 |     chamferL1 = 0.5 * (completeness + accuracy)
153 | 
154 |     out_dict = {
155 |         'completeness': completeness,
156 |         'accuracy': accuracy,
157 |         'normals completeness': completeness_normals,
158 |         'normals accuracy': accuracy_normals,
159 |         'normals': normals_correctness,
160 |         'completeness2': completeness2,
161 |         'accuracy2': accuracy2,
162 |         'chamfer-L2': chamferL2,
163 |         'chamfer-L1': chamferL1,
164 |     }
165 | 
166 |     # F-score
167 |     # percentage_list_1 = np.arange(0.0001, 0.001, 0.0001).astype(np.float32)
168 |     # percentage_list_2 = np.arange(0.001, 0.01, 0.001).astype(np.float32)
169 |     # percentage_list_3 = np.arange(0.01, 0.11, 0.01).astype(np.float32)
170 |     # thres_percentage_list = np.concatenate([percentage_list_1, percentage_list_2, percentage_list_3], axis=0)
171 |     # thres_percentage_list = np.sort(thres_percentage_list)
172 |     # xmax = pointcloud_tgt.max(axis=0)
173 |     # xmin = pointcloud_tgt.min(axis=0)
174 |     # bbox_length = np.linalg.norm(xmax - xmin)
175 |     # threshold_list = bbox_length * thres_percentage_list
176 |     threshold_list = np.array([0.005]).astype(np.float32)
177 |     for i in range(threshold_list.shape[0]):
178 |         threshold = threshold_list[i]
179 | 
180 |         pre_sum_val = np.sum(np.less(dist_forward, threshold))
181 |         rec_sum_val = np.sum(np.less(dist_backward, threshold))
182 |         fprecision = pre_sum_val / dist_forward.shape[0]
183 |         frecall = rec_sum_val / dist_backward.shape[0]
184 |         fscore = 2 * (fprecision * frecall) / (fprecision + frecall + 1e-6)
185 |         out_dict['f_score_{:.4}'.format(threshold)] = fscore
186 |         out_dict['precision_{:.4}'.format(threshold)] = fprecision
187 |         out_dict['recall_{:.4}'.format(threshold)] = frecall
188 | 
189 |     return out_dict
--------------------------------------------------------------------------------
/pipelines/utils/libchamfer/__init__.py:
--------------------------------------------------------------------------------
1 | from pipelines.utils.libchamfer.dist_chamfer import chamferDist
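Once the extension is compiled (`python setup.py build_ext --inplace`), `chamferDist` returns squared nearest-neighbor distances in both directions plus the matching indices. A minimal sketch (GPU tensors only; shapes are illustrative):

```python
import torch
from pipelines.utils.libchamfer import chamferDist

cd = chamferDist()
a = torch.rand(1, 1024, 3).cuda()
b = torch.rand(1, 2048, 3).cuda()
dist1, dist2, idx1, idx2 = cd(a, b)   # squared distances a->b and b->a
chamfer_l2 = dist1.mean() + dist2.mean()
```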
https://raw.githubusercontent.com/wnbzhao/Local-Implicit-Grid-Pytorch/d45da37beda52653f0066f9ba0f0500c54402e13/pipelines/utils/libchamfer/chamfer.cp37-win_amd64.pyd
--------------------------------------------------------------------------------
/pipelines/utils/libchamfer/chamfer.cu:
--------------------------------------------------------------------------------
1 | 
2 | #include <stdio.h>
3 | #include <ATen/ATen.h>
4 | 
5 | #include <cuda.h>
6 | #include <cuda_runtime.h>
7 | 
8 | #include <vector>
9 | 
10 | 
11 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
12 |     const int batch=512;
13 |     __shared__ float buf[batch*3];
14 |     for (int i=blockIdx.x;ibest){
126 |                 result[(i*n+j)]=best;
127 |                 result_i[(i*n+j)]=best_i;
128 |             }
129 |         }
130 |         __syncthreads();
131 |     }
132 |     }
133 | }
134 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
135 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
136 | 
137 |     const auto batch_size = xyz1.size(0);
138 |     const auto n = xyz1.size(1);  //num_points point cloud A
139 |     const auto m = xyz2.size(1);  //num_points point cloud B
140 | 
141 |     NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
142 |     NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
143 | 
144 |     cudaError_t err = cudaGetLastError();
145 |     if (err != cudaSuccess) {
146 |         printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
147 |         //THError("aborting");
148 |         return 0;
149 |     }
150 |     return 1;
151 | 
152 | 
153 | }
154 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
155 |     for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data());
184 |     NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
185 | 
186 |     cudaError_t err = cudaGetLastError();
187 |     if (err != cudaSuccess) {
188 |         printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
189 |         //THError("aborting");
190 |         return 0;
191 |     }
192 |     return 1;
193 | 
194 | }
195 | 
196 | 
--------------------------------------------------------------------------------
/pipelines/utils/libchamfer/chamfer_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <vector>
3 | 
4 | ///TMP
5 | //#include "common.h"
6 | /// NOT TMP
7 | 
8 | 
9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
10 | 
11 | 
12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
13 | 
14 | 
15 | 
16 | 
17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
18 |     return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
19 | }
20 | 
21 | 
22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
23 |                      at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
24 | 
25 |     return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2,
                                 graddist1, graddist2, idx1, idx2);
26 | }
27 | 
28 | 
29 | 
30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31 |     m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
32 |     m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
33 | }
--------------------------------------------------------------------------------
/pipelines/utils/libchamfer/dist_chamfer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Function
4 | from . import chamfer
5 | 
6 | 
7 | # Chamfer's distance module @thibaultgroueix
8 | # GPU tensors only
9 | class chamferFunction(Function):
10 |     @staticmethod
11 |     def forward(ctx, xyz1, xyz2):
12 |         batchsize, n, _ = xyz1.size()
13 |         _, m, _ = xyz2.size()
14 | 
15 |         dist1 = torch.zeros(batchsize, n)
16 |         dist2 = torch.zeros(batchsize, m)
17 | 
18 |         idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
19 |         idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
20 | 
21 |         dist1 = dist1.cuda()
22 |         dist2 = dist2.cuda()
23 |         idx1 = idx1.cuda()
24 |         idx2 = idx2.cuda()
25 | 
26 |         chamfer.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
27 |         ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
28 |         return dist1, dist2, idx1, idx2
29 | 
30 |     @staticmethod
31 |     def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
32 |         xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
33 |         graddist1 = graddist1.contiguous()
34 |         graddist2 = graddist2.contiguous()
35 | 
36 |         gradxyz1 = torch.zeros(xyz1.size())
37 |         gradxyz2 = torch.zeros(xyz2.size())
38 | 
39 |         gradxyz1 = gradxyz1.cuda()
40 |         gradxyz2 = gradxyz2.cuda()
41 |         chamfer.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
42 |         return gradxyz1, gradxyz2
43 | 
44 | 
45 | class chamferDist(nn.Module):
46 |     def __init__(self):
47 |         super(chamferDist, self).__init__()
48 | 
49 |     def forward(self, input1, input2):
50 |         return chamferFunction.apply(input1, input2)
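# Usage sketch (hypothetical shapes; inputs must be GPU tensors, and
# dist1/dist2 hold squared nearest-neighbor distances in each direction):
#
#   dist = chamferDist()
#   p1 = torch.rand(8, 2048, 3).cuda()
#   p2 = torch.rand(8, 1024, 3).cuda()
#   dist1, dist2, idx1, idx2 = dist(p1, p2)
#   loss = dist1.mean() + dist2.mean()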
--------------------------------------------------------------------------------
/pipelines/utils/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | """Code adapted from https://github.com/facebookresearch/DeepSDF/blob/master/train_deep_sdf.py"""
2 | 
3 | class LearningRateSchedule:
4 |     def get_learning_rate(self, epoch):
5 |         pass
6 | 
7 | 
8 | class ConstantLearningRateSchedule(LearningRateSchedule):
9 |     def __init__(self, value):
10 |         self.value = value
11 | 
12 |     def get_learning_rate(self, epoch):
13 |         return self.value
14 | 
15 | 
16 | class StepLearningRateSchedule(LearningRateSchedule):
17 |     def __init__(self, initial, interval, factor):
18 |         self.initial = initial
19 |         self.interval = interval
20 |         self.factor = factor
21 | 
22 |     def get_learning_rate(self, epoch):
23 | 
24 |         return self.initial * (self.factor ** (epoch // self.interval))
25 | 
26 | 
27 | class WarmupLearningRateSchedule(LearningRateSchedule):
28 |     def __init__(self, initial, warmed_up, length):
29 |         self.initial = initial
30 |         self.warmed_up = warmed_up
31 |         self.length = length
32 | 
33 |     def get_learning_rate(self, epoch):
34 |         if epoch > self.length:
35 |             return self.warmed_up
36 |         return self.initial + (self.warmed_up - self.initial) * epoch / self.length
37 | 
38 | 
39 | def get_learning_rate_schedules(specs):
40 | 
41 |     schedule_specs = specs["LearningRateSchedule"]
42 | 
43 |     schedules = []
44 | 
45 |     for schedule_spec in schedule_specs:
46 | 
47 |         if schedule_spec["Type"] == "Step":
48 |             schedules.append(
49 |                 StepLearningRateSchedule(
50 |                     schedule_spec["Initial"],
51 |                     schedule_spec["Interval"],
52 |                     schedule_spec["Factor"],
53 |                 )
54 |             )
55 |         elif schedule_spec["Type"] == "Warmup":
56 |             schedules.append(
57 |                 WarmupLearningRateSchedule(
58 |                     schedule_spec["Initial"],
59 |                     schedule_spec["Final"],
60 |                     schedule_spec["Length"],
61 |                 )
62 |             )
63 |         elif schedule_spec["Type"] == "Constant":
64 |             schedules.append(ConstantLearningRateSchedule(schedule_spec["Value"]))
65 | 
66 |         else:
67 |             raise Exception(
68 |                 'no known learning rate schedule of type "{}"'.format(
69 |                     schedule_spec["Type"]
70 |                 )
71 |             )
72 |     return schedules
73 | 
74 | 
75 | def adjust_learning_rate(lr_schedules, optimizer, epoch):
76 |     for i, param_group in enumerate(optimizer.param_groups):
77 |         param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
--------------------------------------------------------------------------------
/pipelines/utils/point_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from plyfile import PlyData
4 | from plyfile import PlyElement
5 | from scipy.spatial import cKDTree
6 | import ipdb
7 | 
8 | 
9 | def read_point_ply(filename):
10 |     """Load point cloud from ply file.
11 | 
12 |     Args:
13 |         filename: str, filename for ply file to load.
14 |     Returns:
15 |         v: np.array of shape [#v, 3], vertex coordinates
16 |         n: np.array of shape [#v, 3], vertex normals
17 |     """
18 |     pd = PlyData.read(filename)['vertex']
19 |     try:
20 |         v = np.array(np.stack([pd[i] for i in ['x', 'y', 'z']], axis=-1))
21 |         n = np.array(np.stack([pd[i] for i in ['nx', 'ny', 'nz']], axis=-1))
22 |     except (KeyError, ValueError):
23 |         v = np.array(np.stack([pd[i] for i in ['x', 'y', 'z']], axis=-1))
24 |         n = np.zeros_like(v).astype(np.float32)
25 |     return v, n
26 | 
27 | 
28 | def write_point_ply(filename, v, n):
29 |     """Write point cloud to ply file.
30 | 
31 |     Args:
32 |         filename: str, filename for ply file to write.
33 |         v: np.array of shape [#v, 3], vertex coordinates
34 |         n: np.array of shape [#v, 3], vertex normals
35 |     """
36 |     vn = np.concatenate([v, n], axis=1)
37 |     vn = [tuple(vn[i]) for i in range(vn.shape[0])]
38 |     vn = np.array(vn, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4')])
39 |     el = PlyElement.describe(vn, 'vertex')
40 |     PlyData([el]).write(filename)
41 | 
42 | 
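# Round-trip sketch for the two helpers above (hypothetical file names):
#
#   v, n = read_point_ply('scene.ply')       # n is all zeros if the ply stores no normals
#   write_point_ply('scene_copy.ply', v, n)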
43 | def sample_points_from_ray(points, normals, sample_factor=10, std=0.01):
44 |     """Sample points along each surface point's normal ray.
45 | 
46 |     Args:
47 |         points (numpy array): [npts, 3], xyz coordinate of points on the mesh surface.
48 |         normals (numpy array): [npts, 3], normals of points on the mesh surface.
49 |         sample_factor (int): number of samples to pick per surface point.
50 |         std (float): std of samples to generate.
51 |     Returns:
52 |         point_samples (numpy array): [npts*sample_factor, 3], xyz coordinates of the
53 |             sampled points near the mesh surface.
54 |         sdf_values (numpy array): [npts*sample_factor, 1], sdf values of the sampled
55 |             points (the signed offset along the normal).
56 |     """
57 |     normals = normals / (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-8)
58 |     npoints = points.shape[0]
59 |     offsets = np.random.randn(npoints, sample_factor, 1) * std
60 |     point_samples = points[:, np.newaxis, :] + offsets * normals[:, np.newaxis, :]
61 |     point_samples = point_samples.reshape(-1, points.shape[1])
62 |     sdf_values = offsets.reshape(-1, 1)
63 |     point_samples = point_samples.astype(np.float32)
64 |     sdf_values = sdf_values.astype(np.float32)
65 |     return point_samples, sdf_values
66 | 
67 | 
68 | def np_pad_points(points, ntarget):
69 |     """Pad point cloud to required size.
70 | 
71 |     If number of points is larger than ntarget, take ntarget random samples.
72 |     If number of points is smaller than ntarget, pad by randomly repeating existing points.
73 |     Args:
74 |         points: `[npoints, nchannel]` np array, where first 3 channels are xyz.
75 |         ntarget: int, number of target points.
76 |     Returns:
77 |         result: `[ntarget, nchannel]` np array, padded points to ntarget numbers.
78 |     """
79 |     if points.shape[0] < ntarget:
80 |         mult = np.ceil(float(ntarget) / float(points.shape[0])) - 1
81 |         rand_pool = np.tile(points, [int(mult), 1])
82 |         nextra = ntarget - points.shape[0]
83 |         extra_idx = np.random.choice(rand_pool.shape[0], nextra, replace=False)
84 |         extra_pts = rand_pool[extra_idx]
85 |         points_out = np.concatenate([points, extra_pts], axis=0)
86 |     else:
87 |         idx_choice = np.random.choice(points.shape[0], size=ntarget, replace=False)
88 |         points_out = points[idx_choice]
89 | 
90 |     return points_out
91 | 
92 | 
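# Usage sketch chaining the two helpers above (hypothetical inputs):
#
#   samples, sdf = sample_points_from_ray(v, n, sample_factor=10, std=0.01)
#   fixed = np_pad_points(samples, ntarget=2048)  # always returns (2048, 3)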
93 | def np_gather_ijk_index(arr, index):
94 |     """Gather the features of given index from the feature grid.
95 | 
96 |     Args:
97 |         arr (numpy array): h*w*d*c, feature grid.
98 |         index (numpy array): nx*3, index of the feature grid
99 |     Returns:
100 |         nx*c, features at given index of the feature grid.
101 |     """
102 |     arr_flat = arr.reshape(-1, arr.shape[-1])
103 |     _, j, k, _ = arr.shape
104 |     index_transform = index[:, 0] * j * k + index[:, 1] * k + index[:, 2]
105 |     return arr_flat[index_transform]
106 | 
107 | 
108 | def np_shifted_crop(v, idx_grid, shift, crop_size, ntarget):
109 |     """Crop points into grid cells shifted by `shift * crop_size / 2`, padding each cell to ntarget points."""
110 |     nchannels = v.shape[1]
111 |     vxyz = v[:, :3] - shift * crop_size * 0.5
112 |     vall = v.copy()
113 |     point_idxs = np.arange(v.shape[0])
114 |     point_grid_idx = np.floor(vxyz / crop_size).astype(np.int32)
115 |     valid_mask = np.ones(point_grid_idx.shape[0]).astype(bool)
116 |     for i in range(3):
117 |         valid_mask = np.logical_and(valid_mask, point_grid_idx[:, i] >= 0)
118 |         valid_mask = np.logical_and(valid_mask, point_grid_idx[:, i] < idx_grid.shape[i])
119 |     point_grid_idx = point_grid_idx[valid_mask]
120 |     # translate to global grid index
121 |     point_grid_idx = np_gather_ijk_index(idx_grid, point_grid_idx)
122 | 
123 |     vall = vall[valid_mask]
124 |     point_idxs = point_idxs[valid_mask]
125 |     crop_indices, revidx = np.unique(point_grid_idx, axis=0, return_inverse=True)
126 |     ncrops = crop_indices.shape[0]
127 |     sortarr = np.argsort(revidx)
128 |     revidx_sorted = revidx[sortarr]
129 |     vall_sorted = vall[sortarr]
130 |     point_idxs_sorted = point_idxs[sortarr]
131 |     bins = np.searchsorted(revidx_sorted, np.arange(ncrops))
132 |     bins = list(bins) + [v.shape[0]]
133 |     sid = bins[0:-1]
134 |     eid = bins[1:]
135 |     # initialize outputs
136 |     point_crops = np.zeros([ncrops, ntarget, nchannels]).astype(np.float32)
137 |     crop_point_idxs = []
138 |     # extract crops and pad
139 |     for i, (s, e) in enumerate(zip(sid, eid)):
140 |         cropped_points = vall_sorted[s:e]
141 |         crop_point_idx = point_idxs_sorted[s:e]
142 |         crop_point_idxs.append(crop_point_idx)
143 |         padded_points = np_pad_points(cropped_points, ntarget=ntarget)
144 |         point_crops[i] = padded_points
145 |     return point_crops, crop_indices, crop_point_idxs
146 | 
147 | 
148 | def np_get_occupied_idx(v,
149 |                         xmin=(0., 0., 0.),
150 |                         xmax=(1., 1., 1.),
151 |                         crop_size=.125,
152 |                         ntarget=2048,
153 |                         overlap=True,
154 |                         normalize_crops=False,
155 |                         return_shape=False,
156 |                         return_crop_point_idxs=False):
157 |     """Get crop indices for point clouds."""
158 |     # v = v.copy() - xmin
159 |     v = v.copy()
160 |     v[:, :3] = v[:, :3] - xmin
161 |     xmin = np.array(xmin)
162 |     xmax = np.array(xmax)
163 |     r = (xmax - xmin) / crop_size
164 |     r = np.ceil(r)
165 |     rr = r.astype(np.int32) if not overlap else (2 * r - 1).astype(np.int32)
166 |     # create index grid
167 |     idx_grid = np.stack(np.meshgrid(np.arange(rr[0]), np.arange(rr[1]), np.arange(rr[2]), indexing='ij'), axis=-1)
168 |     # [rr[0], rr[1], rr[2], 3]
169 | 
170 |     shift_idxs = np.stack(np.meshgrid(np.arange(int(overlap) + 1),
171 |                                       np.arange(int(overlap) + 1),
172 |                                       np.arange(int(overlap) + 1),
173 |                                       indexing='ij'),
174 |                           axis=-1)
175 |     shift_idxs = np.reshape(shift_idxs, [-1, 3])
176 |     point_crops = []
177 |     crop_indices = []
178 |     crop_point_idxs = []
179 |     for i in range(shift_idxs.shape[0]):
180 |         sft = shift_idxs[i]
181 |         skp = int(overlap) + 1
182 |         idg = idx_grid[sft[0]::skp, sft[1]::skp, sft[2]::skp]
183 |         pc, ci, cpidx = np_shifted_crop(v, idg, sft, crop_size=crop_size, ntarget=ntarget)
184 |         point_crops.append(pc)
185 |         crop_indices.append(ci)
186 |         crop_point_idxs += cpidx
187 |     point_crops = np.concatenate(point_crops, axis=0)  # [ncrops, nsurface, 6]
188 |     crop_indices = np.concatenate(crop_indices, axis=0)  # [ncrops, 3]
189 | 
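    # Note on the normalization below: with overlap=True the cells are laid out
    # at half-cell spacing (2r - 1 cells per axis, built from the eight shifted
    # subgrids above), so a cell corner sits at crop_indices * crop_size / 2;
    # without overlap, cells tile the volume at full crop_size spacing.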
190 |     if normalize_crops:
191 |         # normalize each crop
192 |         if overlap:
193 |             crop_corners = crop_indices * 0.5 * crop_size
194 |             crop_centers = crop_corners + 0.5 * crop_size  # [ncrops, 3]
195 |         else:
196 |             # non-overlapping cells are spaced one full crop_size apart
197 |             crop_corners = crop_indices * crop_size
198 |             crop_centers = crop_corners + 0.5 * crop_size  # [ncrops, 3]
199 | 
200 |         crop_centers = crop_centers[:, np.newaxis, :]  # [ncrops, 1, 3]
201 |         point_crops[..., :3] = point_crops[..., :3] - crop_centers
202 |         point_crops[..., :3] = point_crops[..., :3] / crop_size * 2
203 | 
204 |     outputs = [point_crops, crop_indices]
205 |     if return_shape: outputs += [idx_grid.shape[:3]]
206 |     if return_crop_point_idxs:
207 |         outputs += [crop_point_idxs]
208 |     return tuple(outputs)
209 | 
210 | 
211 | def sample_uniform_from_occupied_grids(v, crop_indices, xmin, crop_size, samples_per_grid=2048,
212 |                                        overlap=True, dist_threshold=0.03):
213 |     ncrops = crop_indices.shape[0]
214 |     if overlap:
215 |         crop_corners = xmin + crop_indices * 0.5 * crop_size  # [ncrops, 3]
216 |     else:
217 |         crop_corners = xmin + crop_indices * crop_size
218 | 
219 |     # Sample points uniformly from a grid
220 |     uniform_samples = crop_corners[:, np.newaxis, :] + np.random.uniform(size=(ncrops, samples_per_grid, 3)) * crop_size
221 |     uniform_samples = uniform_samples.reshape(-1, 3)
222 |     tree = cKDTree(v, balanced_tree=False)
223 |     dists, idxs = tree.query(uniform_samples, k=1, n_jobs=32)
224 |     all_idxs = np.arange(uniform_samples.shape[0])
225 |     target_idxs = all_idxs[dists > dist_threshold]
226 |     target_samples = uniform_samples[target_idxs, :]
227 |     target_samples = target_samples.astype(np.float32)
228 | 
229 |     return target_samples  # [?, 3]
230 | 
231 | 
232 | def pca_normal_estimation(points, k=10):
233 |     tree = cKDTree(points, balanced_tree=False)
234 |     dists, idxs = tree.query(points, k=k + 1, n_jobs=32)
235 |     idxs = idxs[:, 1:]
236 |     npoints, k = idxs.shape
237 |     neighbors = points[idxs.reshape(-1), :]
238 |     neighbors = neighbors.reshape(npoints, k, 3)
239 |     vectors = neighbors - points[:, np.newaxis, :]  # npoints*k*3
240 |     vectors_trans = np.transpose(vectors, (0, 2, 1))  # npoints*3*k
241 |     cov_matrix = np.matmul(vectors_trans, vectors)  # npoints*3*3
242 |     u, s, v = np.linalg.svd(cov_matrix)
243 |     est_normals = v[:, -1, :]  # npoints*3
244 |     est_normals = est_normals / (np.linalg.norm(est_normals, axis=1, keepdims=True) + 1e-12)
245 |     return est_normals
246 | 
247 | 
248 | def occupancy_sparse_to_dense(occ_idx, grid_shape):
249 |     dense = np.zeros(grid_shape, dtype=bool).ravel()
250 |     occ_idx_f = (occ_idx[:, 0] * grid_shape[1] * grid_shape[2] + occ_idx[:, 1] * grid_shape[2] + occ_idx[:, 2])
251 |     dense[occ_idx_f] = True
252 |     dense = np.reshape(dense, grid_shape)
253 |     return dense
254 | 
255 | 
256 | def fit_sphere_through_points(points):
257 |     points_mean = np.mean(points, axis=0, keepdims=True)
258 |     N = points.shape[0]
259 |     points_exp = np.tile(points[:, :, np.newaxis], [1, 1, 3])
260 |     points_exp = points_exp.reshape(N, 9)
261 |     delta = points - points_mean
262 |     delta_exp = np.tile(delta, [1, 3])
263 |     A = 2 * (points_exp * delta_exp).mean(axis=0).reshape(3, 3)
264 | 
265 |     points_squared = (points ** 2).sum(axis=1, keepdims=True)
266 |     B = (points_squared * delta).mean(axis=0, keepdims=True).T
267 | 
268 |     AT_A = np.dot(A.T, A)
269 |     AT_B = np.dot(A.T, B)
270 |     try:
271 |         center_T = np.dot(np.linalg.inv(AT_A), AT_B)
272 |     except np.linalg.LinAlgError:
273 |         return None, None
274 |     center = center_T.T
275 | 
276 |     radius = np.sqrt(((points - center) ** 2).sum(axis=1).mean())
277 | 
278 |     return center, radius
279 | 
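# Quick sanity check for the least-squares sphere fit above (synthetic points
# on a made-up sphere; the recovered center/radius should match almost exactly):
#
#   pts = np.random.randn(1000, 3)
#   pts /= np.linalg.norm(pts, axis=1, keepdims=True)   # unit sphere
#   pts = pts * 2.0 + np.array([1.0, -2.0, 0.5])        # radius 2, shifted center
#   center, radius = fit_sphere_through_points(pts)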
280 | def get_occupied_grid_fitting_sphere(v, xmin, xmax, crop_size, ntarget=2048, overlap=True,
281 |                                      normalize_crops=False, return_shape=False, return_crop_point_idxs=False):
282 |     point_crops, crop_indices, grid_shape, crop_point_idxs = np_get_occupied_idx(
283 |         v=v,
284 |         xmin=xmin,
285 |         xmax=xmax,
286 |         crop_size=crop_size,
287 |         ntarget=ntarget,
288 |         overlap=overlap,
289 |         normalize_crops=normalize_crops,
290 |         return_shape=True,
291 |         return_crop_point_idxs=True
292 |     )
293 | 
294 |     center_list = []
295 |     radius_list = []
296 |     valid_crop_idxs = []
297 |     ncrops = len(crop_point_idxs)
298 |     for n in range(ncrops):
299 |         grid_points = v[crop_point_idxs[n]]
300 | 
301 |         if normalize_crops:
302 |             # Normalize the grid points
303 |             if overlap:
304 |                 grid_corner = xmin + crop_indices[n] * 0.5 * crop_size  # [3]
305 |             else:
306 |                 grid_corner = xmin + crop_indices[n] * crop_size  # [3]
307 |             grid_center = grid_corner + 0.5 * crop_size
308 |             grid_points = (grid_points - grid_center) / (0.5 * crop_size)
309 |         else:
310 |             grid_points = grid_points - xmin
311 | 
312 |         center, radius = fit_sphere_through_points(grid_points)
313 |         if center is not None:
314 |             center_list.append(center)
315 |             radius_list.append(radius)
316 |             valid_crop_idxs.append(n)
317 |     point_crops = point_crops[valid_crop_idxs, :, :]
318 |     crop_indices = crop_indices[valid_crop_idxs, :]
319 |     crop_point_idxs = [crop_point_idxs[idx] for idx in valid_crop_idxs]
320 | 
321 |     centers = np.concatenate(center_list).astype(np.float32)
322 |     radius = np.array(radius_list).astype(np.float32)[:, np.newaxis]
323 |     outputs = [point_crops, crop_indices, centers, radius]
324 |     if return_shape: outputs += [grid_shape]
325 |     if return_crop_point_idxs:
326 |         outputs += [crop_point_idxs]
327 |     return tuple(outputs)
328 | 
329 | 
330 | def clip_radius_np(point_crops, centers, radius, min_radius=10.):
331 |     clip_idxs = np.where(radius < min_radius)[0]
332 |     clip_points = point_crops[clip_idxs, :, :]
333 |     clip_points_centers = np.mean(clip_points, axis=1)  # gravity center
334 |     clip_centers = centers[clip_idxs, :]  # x0
335 |     vector = (clip_centers - clip_points_centers)
336 |     vector_normed = vector / (np.linalg.norm(vector, axis=1, keepdims=True) + 1e-12)
337 |     clip_points_new_centers = clip_points_centers + vector_normed * min_radius  # xc + r * (x0 - xc)
338 |     centers[clip_idxs, :] = clip_points_new_centers
339 |     radius[clip_idxs, :] = min_radius
340 |     return centers, radius
341 | 
342 | 
343 | def clip_radius_torch(point_crops, centers, radius, min_radius=10.):
344 |     clip_idxs = torch.nonzero(radius < min_radius)[:, 0]
345 |     clip_points = point_crops[clip_idxs, :, :]
346 |     clip_points_centers = clip_points.mean(dim=1)  # gravity center
347 |     clip_centers = centers[clip_idxs, :]  # x0
348 |     vector = (clip_centers - clip_points_centers)
349 |     vector_normed = vector / (torch.norm(vector, dim=1, keepdim=True) + 1e-12)
350 |     clip_points_new_centers = clip_points_centers + vector_normed * min_radius  # xc + r * (x0 - xc)
351 |     centers[clip_idxs, :] = clip_points_new_centers
352 |     radius[clip_idxs, :] = min_radius
353 |     return centers, radius
354 | 
355 | 
356 | def sample_from_overlapping_area(crop_indices, xmin, crop_size, grid_shape, samples_per_grid=2048, overlap=True):
357 |     ncrops = crop_indices.shape[0]
358 |     if overlap:
359 |         crop_corners = xmin + crop_indices * 0.5 * crop_size  # [ncrops, 3]
360 |     else:
361 |         crop_corners = xmin + crop_indices * crop_size
362 | 
363 |     # Sample points uniformly from a grid
364 |     uniform_samples = crop_corners[:, np.newaxis, :] + np.random.uniform(size=(ncrops, samples_per_grid, 3)) * crop_size
365 | 
366 |     # Get neighbor indices of an occupied grid and mask out the neighboring grids without points or out of bound
367 |     offset_x = np.array([0, 1])[:, np.newaxis, np.newaxis]
368 |     offset_x = np.tile(offset_x, [1, 2, 2])
369 |     offset_y = np.array([0, 1])[np.newaxis, :, np.newaxis]
370 |     offset_y = np.tile(offset_y, [2, 1, 2])
371 |     offset_z = np.array([0, 1])[np.newaxis, np.newaxis, :]
372 |     offset_z = np.tile(offset_z, [2, 2, 1])
373 |     neighbor_idx_offset = np.stack([offset_x, offset_y, offset_z], axis=3)
374 |     neighbor_idx_offset = neighbor_idx_offset.reshape(-1, 3)
375 |     neighbor_idxs = crop_indices[:, np.newaxis, :] + neighbor_idx_offset[np.newaxis, :, :]  # [ncrops, nneighbors, 3]
376 |     _, nneighbors, _ = neighbor_idxs.shape
377 |     neighbor_idxs = neighbor_idxs.reshape(-1, 3)
378 | 
379 |     d, h, w = grid_shape
380 |     neighbor_idxs_flatten = neighbor_idxs[:, 0] * (h * w) + neighbor_idxs[:, 1] * w + neighbor_idxs[:, 2]
381 |     crop_indices_flatten = crop_indices[:, 0] * (h * w) + crop_indices[:, 1] * w + crop_indices[:, 2]
382 |     mask = np.isin(neighbor_idxs_flatten, crop_indices_flatten)  # [ncrops * nneighbors,]
383 |     mask = mask.reshape(ncrops, nneighbors)
384 | 
385 |     # Mask out the grids with no neighbors containing points
386 |     mask_with_neighbors = (mask.sum(axis=1) > 1)
387 |     uniform_samples = uniform_samples[mask_with_neighbors, :, :]
388 |     mask = mask[mask_with_neighbors, :].astype(np.float32)
389 |     uniform_samples = uniform_samples.astype(np.float32)
390 | 
391 |     return uniform_samples, mask  # [?, samples_per_grid, 3] [?, 8]
392 | 
393 | 
394 | def sample_from_grids_with_neighbors(crop_indices, xmin, crop_size, grid_shape, samples_per_grid=2048, overlap=True):
395 |     d, h, w = grid_shape
396 |     d_idxs = np.arange(d)
397 |     h_idxs = np.arange(h)
398 |     w_idxs = np.arange(w)
399 |     dd, hh, ww = np.meshgrid(d_idxs, h_idxs, w_idxs)
400 |     all_grid_idxs = np.stack([dd, hh, ww], axis=3)
401 |     all_grid_idxs = all_grid_idxs.reshape(-1, 3)
402 | 
403 |     xx, yy, zz = np.meshgrid(*list(([0, 1],) * 3))
404 |     neighbor_idx_offsets = np.stack([xx, yy, zz], axis=3)
405 |     neighbor_idx_offsets = neighbor_idx_offsets.reshape(-1, 3)
406 |     all_grid_neighbor_idxs = all_grid_idxs[:, np.newaxis, :] + neighbor_idx_offsets[np.newaxis, :, :]  # [ngrids, 8, 3], ngrids==d*h*w
407 | 
408 |     # create neighbor mask for every crop grid idx
409 |     ngrids, nneighbors, _ = all_grid_neighbor_idxs.shape
410 |     all_grid_neighbor_idxs_flatten = all_grid_neighbor_idxs.reshape(-1, 3)
411 |     all_grid_neighbor_idxs_flatten = all_grid_neighbor_idxs_flatten[:, 0] * (h * w) + all_grid_neighbor_idxs_flatten[:, 1] * w + all_grid_neighbor_idxs_flatten[:, 2]
412 |     crop_indices_flatten = crop_indices[:, 0] * (h * w) + crop_indices[:, 1] * w + crop_indices[:, 2]
413 |     mask_flatten = np.isin(all_grid_neighbor_idxs_flatten, crop_indices_flatten)  # [ngrids * 8,]
414 |     mask = mask_flatten.reshape(ngrids, nneighbors)
415 |     mask_with_neighbor = np.any(mask, axis=1)
416 | 
417 |     grid_with_neighbor_idxs = all_grid_idxs[mask_with_neighbor, :]  # [?, 3]
418 |     grid_with_neighbor_corners = xmin + grid_with_neighbor_idxs * 0.5 * crop_size
419 | 
420 |     grid_with_neighbor_samples = grid_with_neighbor_corners[:, np.newaxis, :] + np.random.uniform(size=(grid_with_neighbor_corners.shape[0], samples_per_grid, 3)) * crop_size  # [?, 3]
421 |     grid_with_neighbor_samples = grid_with_neighbor_samples.reshape(-1, 3)
422 |     grid_with_neighbor_samples = grid_with_neighbor_samples.astype(np.float32)
423 | 
424 |     return grid_with_neighbor_samples
425 | 
426 | 
427 | def farthest_point_sample(xyz, nsamples):  # bs*npoints*3
428 |     centroids = np.zeros((xyz.shape[0], nsamples)).astype(np.int64)  # bs*nsamples
429 |     distance = np.ones((xyz.shape[0], xyz.shape[1])).astype(np.float32) * 1e10  # bs*npoints
430 |     farthest = np.random.randint(0, xyz.shape[1], size=(xyz.shape[0],))  # bs
431 |     batch_idxs = np.arange(xyz.shape[0])  # bs
432 |     for i in range(nsamples):
433 |         centroids[:, i] = farthest
434 |         centroid = np.expand_dims(xyz[batch_idxs, farthest, :], axis=1)  # bs*1*3
435 |         dist = ((xyz - centroid) ** 2).sum(axis=2)  # bs*npoints
436 |         mask = (dist < distance)
437 |         distance[mask] = dist[mask]
438 |         farthest = np.argmax(distance, axis=1)
439 |     batch_idxs_exp = np.tile(batch_idxs[:, np.newaxis], [1, nsamples])
440 |     query_points = xyz[batch_idxs_exp, centroids]  # bs*nsamples*3
441 |     return query_points  # bs*nsamples*3
442 | 
443 | 
444 | def random_point_sampling(xyz, nsamples):  # bs*npoints*3
445 |     bs, npoints, _ = xyz.shape
446 |     rand_idxs = np.random.choice(npoints, size=(bs, nsamples), replace=False)
447 |     batch_idxs = np.tile(np.arange(bs)[:, np.newaxis], [1, nsamples])
448 |     rand_samples = xyz[batch_idxs, rand_idxs, :]  # bs*nsamples*3
449 |     return rand_samples  # bs*nsamples*3
450 | 
451 | 
452 | if __name__ == "__main__":
453 |     input_ply = 'shapenet_chair.ply'
454 |     v, n = read_point_ply(input_ply)
455 | 
456 |     query_points = farthest_point_sample(v[np.newaxis, :, :], 10000)
457 |     rand_samples = random_point_sampling(v[np.newaxis, :, :], 10000)
458 | 
459 |     np.savetxt("fps_samples.txt", query_points[0], fmt='%f', delimiter=';')
460 |     np.savetxt("rand_samples.txt", rand_samples[0], fmt='%f', delimiter=';')
461 | 
--------------------------------------------------------------------------------
/pipelines/utils/postprocess_utils.py:
--------------------------------------------------------------------------------
1 | """Post-process to remove interior backface from reconstruction artifact."""
2 | import numpy as np
3 | from scipy import sparse
4 | from scipy import spatial
5 | import trimesh
6 | 
7 | 
8 | def merge_meshes(mesh_list):
9 |     """Merge a list of individual meshes into a single mesh."""
10 |     verts = []
11 |     faces = []
12 |     nv = 0
13 |     for m in mesh_list:
14 |         verts.append(m.vertices)
15 |         faces.append(m.faces + nv)
16 |         nv += m.vertices.shape[0]
17 |     v = np.concatenate(verts, axis=0)
18 |     f = np.concatenate(faces, axis=0)
19 |     merged_mesh = trimesh.Trimesh(v, f)
20 |     return merged_mesh
21 | 
22 | 
23 | def average_onto_vertex(mesh, per_face_attrib):
24 |     """Average per-face attribute onto vertices."""
25 |     assert per_face_attrib.shape[0] == mesh.faces.shape[0]
26 |     assert len(per_face_attrib.shape) == 1
27 |     c = np.concatenate([[0], per_face_attrib], axis=0)
28 |     v2f_orig = mesh.vertex_faces.copy()
29 |     v2f = v2f_orig.copy()
30 |     v2f += 1
31 |     per_vert_sum = np.sum(c[v2f], axis=1)
32 |     per_vert_count = np.sum(np.logical_not(v2f == 0), axis=1)
33 |     per_vert_attrib = per_vert_sum / per_vert_count
34 |     return per_vert_attrib
35 | 
36 | 
37 | def average_onto_face(mesh, per_vert_attrib):
38 |     """Average per-vert attribute onto faces."""
39 |     assert per_vert_attrib.shape[0] == mesh.vertices.shape[0]
40 |     assert len(per_vert_attrib.shape) == 1
41 |     per_face_attrib = per_vert_attrib[mesh.faces]
42 |     per_face_attrib = np.mean(per_face_attrib, axis=1)
43 |     return per_face_attrib
44 | 
45 | 
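# Usage sketch for the cleanup step defined below (hypothetical file names;
# `pc` packs the original input points and normals side by side):
#
#   v, n = read_point_ply('input_points.ply')   # from pipelines.utils.point_utils
#   pc = np.concatenate([v, n], axis=1)         # [n, 6]
#   mesh = trimesh.load('lig_recon.ply')
#   mesh_clean = remove_backface(mesh, pc, verbose=True)
#   mesh_clean.export('lig_recon_clean.ply')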
46 | def remove_backface(mesh, pc, k=3, lap_iter=50, lap_val=0.50,
47 |                     area_threshold=1, verbose=False):
48 |     """Remove the interior backface resulting from reconstruction artifacts.
49 | 
50 |     Args:
51 |         mesh: trimesh instance, mesh reconstructed by LIG that may contain interior backfaces.
52 |         pc: np.array of shape [n, 6], original input point cloud.
53 |         k: int, number of nearest neighbor for pooling sign.
54 |         lap_iter: int, number of laplacian smoothing iterations.
55 |         lap_val: float, lambda value for laplacian smoothing of cosine distance.
56 |         area_threshold: float, minimum area connected components to preserve.
57 |         verbose: bool, verbose print.
58 |     Returns:
59 |         mesh_new: trimesh instance. new mesh with backface removed.
60 |     """
61 |     mesh.remove_degenerate_faces()
62 | 
63 |     v, n = pc[:, :3], pc[:, 3:]
64 | 
65 |     # build cKDTree to accelerate nearest point search
66 |     if verbose: print("Building KDTree...")
67 |     tree_pc = spatial.cKDTree(data=v)
68 | 
69 |     # for each vertex, find nearest point in input point cloud
70 |     if verbose: print("{}-nearest neighbor search...".format(k))
71 |     _, idx = tree_pc.query(mesh.vertices, k=k, n_jobs=-1)
72 | 
73 |     # slice out the nn points
74 |     n_nn = n[idx]  # shape: [#v_query, k, dim]
75 | 
76 |     # find dot products.
77 |     if verbose: print("Computing norm alignment...")
78 |     n_v = mesh.vertex_normals[:, None, :]  # shape: [#v_query, 1, dim]
79 |     per_vert_norm_alignment = np.sum(n_nn * n_v, axis=-1)  # [#v_query, k]
80 |     per_vert_norm_alignment = np.mean(per_vert_norm_alignment, axis=-1)
81 | 
82 |     # laplacian smoothing of per vertex normal alignment
83 |     if verbose: print("Computing laplacian smoothing...")
84 |     lap = trimesh.smoothing.laplacian_calculation(mesh)
85 |     dlap = lap.shape[0]
86 |     op = sparse.eye(dlap) + lap_val * (lap - sparse.eye(dlap))
87 |     for _ in range(lap_iter):
88 |         per_vert_norm_alignment = op.dot(per_vert_norm_alignment)
89 | 
90 |     # average onto face
91 |     per_face_norm_alignment = average_onto_face(mesh, per_vert_norm_alignment)
92 | 
93 |     # remove faces whose normal alignment falls below the cutoff (-0.75)
94 |     if verbose: print("Removing backfaces...")
95 |     ff = mesh.faces[per_face_norm_alignment > -0.75]
96 |     mesh_new = trimesh.Trimesh(mesh.vertices, ff)
97 |     mesh_new.remove_unreferenced_vertices()
98 | 
99 |     if verbose: print("Cleaning up...")
100 |     mesh_list = mesh_new.split(only_watertight=False)
101 |     # filter out small floating junk from backface
102 |     areas = [m.area for m in mesh_list]
103 |     threshold = min(np.max(areas) / 5, area_threshold)
104 |     mesh_list = [m for m in mesh_list if m.area > threshold]
105 |     mesh_new = merge_meshes(mesh_list)
106 | 
107 |     # fill small holes
108 |     mesh_new.fill_holes()
109 | 
110 |     return mesh_new
111 | 
--------------------------------------------------------------------------------
/pretrained_models/lig/model_best.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wnbzhao/Local-Implicit-Grid-Pytorch/d45da37beda52653f0066f9ba0f0500c54402e13/pretrained_models/lig/model_best.pt
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | 
2 | torch>=1.6
3 | torchvision>=0.8.1
4 | tensorboardX
5 | tqdm
6 | ipdb
7 | PyYAML
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | try:
2 |     from setuptools import setup
3 | except ImportError:
4 |     from distutils.core import setup
5 | from
torch.utils.cpp_extension import BuildExtension, CUDAExtension
6 | 
7 | setup(
8 |     ext_modules=[
9 |         CUDAExtension('pipelines.utils.libchamfer.chamfer', [
10 |             'pipelines/utils/libchamfer/chamfer_cuda.cpp',
11 |             'pipelines/utils/libchamfer/chamfer.cu',
12 |         ]),
13 |     ],
14 |     cmdclass={
15 |         'build_ext': BuildExtension
16 |     })
--------------------------------------------------------------------------------
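
For reference, the schedule factory in `pipelines/utils/lr_schedulers.py` above is driven by a DeepSDF-style specs dictionary. A minimal sketch (the key names follow the code; the concrete values are illustrative only):
```
from pipelines.utils.lr_schedulers import get_learning_rate_schedules, adjust_learning_rate

specs = {
    "LearningRateSchedule": [
        {"Type": "Step", "Initial": 5e-4, "Interval": 500, "Factor": 0.5},
        {"Type": "Warmup", "Initial": 1e-5, "Final": 1e-3, "Length": 100},
    ]
}
schedules = get_learning_rate_schedules(specs)  # one schedule per optimizer param group
# inside the training loop:
# adjust_learning_rate(schedules, optimizer, epoch)
```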