├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── saved_models └── model_state_dict.pth ├── src ├── data_utils.py ├── generate_local_points_dataset.py ├── knn.py ├── losses.py ├── main_generate_mesh.py ├── main_train_model.py ├── mesh_utils.py ├── mini_mlp.py ├── point_tri_net.py ├── train_utils.py ├── utils.py └── world.py └── teaser.gif /.gitignore: -------------------------------------------------------------------------------- 1 | # Our things 2 | .polyscope.ini 3 | imgui.ini 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Nicholas Sharp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Source code & pretrained model for "[PointTriNet: Learned Triangulation of 3D Point Sets](https://nmwsharp.com/research/learned-triangulation/)", by [Nicholas Sharp](https://nmwsharp.com/) and [Maks Ovsjanikov](http://www.lix.polytechnique.fr/~maks/) at ECCV 2020. 2 | 3 | - PDF: [link](https://nmwsharp.com/media/papers/learned-triangulation/learned_triangulation.pdf) 4 | - Project: [link](https://nmwsharp.com/research/learned-triangulation/) 5 | - Talk: [link](https://www.youtube.com/watch?v=PoNT0u_wz4Y) 6 | 7 | 8 | ![demo gif](https://github.com/nmwsharp/learned-triangulation/blob/master/teaser.gif) 9 | 10 | ## Example: Generate a mesh 11 | 12 | 13 | The script `main_generate_mesh.py` applies a trained model to triangulate a point set. A set of pretrained weights are included in `saved_models/` 14 | 15 | ```sh 16 | python src/main_generate_mesh.py saved_models/model_state_dict.pth path/to/points.ply 17 | ``` 18 | 19 | Check out the `--help` flag on the script for arguments. In particular, the script can either take a point cloud directly as input, or take a mesh as input and uniformly sample points with `--sample_cloud`. 20 | 21 | Note that by default, the script opens up a GUI (using [Polyscope](http://polyscope.run/)) to show results. 
To skip the GUI and just write out the resulting mesh, use: 22 | 23 | ```sh 24 | python src/main_generate_mesh.py saved_models/model_state_dict.pth path_to_your_cloud_or_mesh.ply --output result 25 | ``` 26 | 27 | ## Example: Integrating with code 28 | 29 | If you want to integrate PointTriNet into your own codebase, the `PointTriNet_Mesher` from `point_tri_net.py` encapsulates all the functionality of the method. It's a `torch.nn.Module`, so you can make it a member of other modules, load weights, etc. 30 | 31 | To create the model, load weights, and triangulate a point set, just call: 32 | 33 | ```python 34 | 35 | model = PointTriNet_Mesher() 36 | model.load_state_dict(torch.load(some_path)) 37 | model.eval() 38 | 39 | samples = ... # your (B,V,3) torch tensor of point positions 40 | 41 | with torch.no_grad(): 42 | candidate_triangles, candidate_probs = model.predict_mesh(samples) 43 | # candidate_triangles is a (B, F, 3) index tensor, predicted triangles 44 | # candidate_probs is a (B, F) float tensor of [0,1] probabilities for each triangle 45 | 46 | # You are probably interested in only the high-probability triangles. For example, 47 | # get the high-probability triangles from the 0th batch entry like 48 | b = 0 49 | prob_thresh = 0.9 50 | high_prob_faces = candidate_triangles[b, candidate_probs[b,:] > prob_thresh, :] 51 | 52 | 53 | ``` 54 | 55 | ## Example: Generate data & train the model 56 | 57 | **Prerequisite**: a collection of shapes to train on; we use the training set (all classes) of ShapeNet v2, which you can download on your own. Note that we _do not_ train PointTriNet to match the triangulation of existing meshes; we're just using meshes as a convenient data source from which to sample point cloud patches.
58 | 59 | **Step 1** Sample point cloud patches as training (and validation) data 60 | 61 | ```shell 62 | python src/generate_local_points_dataset.py --input_dir=/path/to/train_meshes/ --output_dir=data/train/ --n_samples=20000 63 | 64 | python src/generate_local_points_dataset.py --input_dir=/path/to/val_meshes/ --output_dir=data/val/ --n_samples=5000 65 | ``` 66 | 67 | **Step 2** Train the model 68 | 69 | ```sh 70 | python src/main_train_model.py 71 | ``` 72 | 73 | With default parameters, this will train for 3 epochs on the dataset above, using < 8GB gpu memory and taking ~6hrs on an RTX 2070 GPU. Checkpoints will be saved in `./training_runs`, along with tensorboard logging. 74 | 75 | Note that this script has paths at the top relative to the expected directory layout of this repo. If you want to use a different directory layout, you can update the paths. 76 | 77 | ## Dependencies 78 | 79 | Depends on `pytorch`, `torch-scatter`, `libigl`, and `polyscope`, along with some other typical numerical components. The code is pretty standard, and there shouldn't be any particularly strict version requirements on these dependencies; any recent version should work fine. 80 | 81 | For completeness, an `environment.yml` file is included (which is a superset of the required packages). 
82 | 83 | ## Citation 84 | 85 | If this code contributes to academic work, please cite: 86 | 87 | ```bib 88 | @inproceedings{sharp2020ptn, 89 | title={"PointTriNet: Learned Triangulation of 3D Point Sets"}, 90 | author={Sharp, Nicholas and Ovsjanikov, Maks}, 91 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 92 | pages={}, 93 | year={2020} 94 | } 95 | ``` 96 | 97 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: learned_tri_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - blas=1.0 8 | - ca-certificates=2020.6.20 9 | - certifi=2020.6.20 10 | - igl=0.4.1 11 | - intel-openmp=2020.2 12 | - ld_impl_linux-64=2.33.1 13 | - libblas=3.9.0 14 | - libcblas=3.9.0 15 | - libedit=3.1.20191231 16 | - libffi=3.3 17 | - libgcc-ng=9.1.0 18 | - libgfortran-ng=7.5.0 19 | - libgfortran4=7.5.0 20 | - liblapack=3.9.0 21 | - libstdcxx-ng=9.1.0 22 | - mkl=2020.2 23 | - mkl-service=2.3.0 24 | - mkl_fft=1.2.0 25 | - mkl_random=1.1.1 26 | - ncurses=6.2 27 | - numpy=1.19.2 28 | - numpy-base=1.19.2 29 | - openssl=1.1.1h 30 | - pip=20.2.4 31 | - python=3.6.12 32 | - python_abi=3.6 33 | - readline=8.0 34 | - scipy=1.5.2 35 | - setuptools=50.3.0 36 | - six=1.15.0 37 | - sqlite=3.33.0 38 | - tk=8.6.10 39 | - wheel=0.35.1 40 | - xz=5.2.5 41 | - zlib=1.2.11 42 | - pip: 43 | - absl-py==0.10.0 44 | - cachetools==4.1.1 45 | - chardet==3.0.4 46 | - google-auth==1.22.1 47 | - google-auth-oauthlib==0.4.1 48 | - grpcio==1.33.1 49 | - idna==2.10 50 | - importlib-metadata==2.0.0 51 | - joblib==0.17.0 52 | - markdown==3.3.3 53 | - meshio==4.3.1 54 | - oauthlib==3.1.0 55 | - pillow==8.0.1 56 | - plyfile==0.7.2 57 | - polyscope==0.1.3 58 | - protobuf==3.13.0 59 | - pyasn1==0.4.8 60 | - pyasn1-modules==0.2.8 61 | - requests==2.24.0 62 | - requests-oauthlib==1.3.0 63 | - rsa==4.6 64 | - 
scikit-learn==0.23.2 65 | - sklearn==0.0 66 | - tensorboard==2.1.0 67 | - threadpoolctl==2.1.0 68 | - torch==1.4.0 69 | - torch-cluster==1.4.5 70 | - torch-scatter==2.0.4 71 | - torchvision==0.5.0 72 | - urllib3==1.25.11 73 | - werkzeug==1.0.1 74 | - zipp==3.4.0 75 | -------------------------------------------------------------------------------- /saved_models/model_state_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/learned-triangulation/12d970d9ce87a973b8aeb0d4ea47562ada578a45/saved_models/model_state_dict.pth -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import torch 5 | import numpy as np 6 | 7 | # import polyscope 8 | 9 | import world 10 | import utils 11 | from utils import * 12 | import mesh_utils 13 | 14 | 15 | class PointSurfaceDataset(torch.utils.data.Dataset): 16 | 17 | def __init__(self, dir_with_meshes=None, transforms=[]): 18 | super(PointSurfaceDataset, self).__init__() 19 | 20 | # Members 21 | self.mesh_paths = None 22 | self.transforms = None 23 | 24 | # Constructor 25 | if dir_with_meshes is not None: 26 | 27 | # Wrap the string if we just got a single directory 28 | if isinstance(dir_with_meshes, str): 29 | dir_with_meshes = [dir_with_meshes] 30 | 31 | # Parse out all of the paths 32 | self.mesh_paths = [] 33 | for d in dir_with_meshes: 34 | # Just load from a single directory 35 | for f in os.listdir(d): 36 | _, ext = os.path.splitext(f) 37 | fullpath = os.path.join(d, f) 38 | self.mesh_paths.append(fullpath) 39 | 40 | 41 | # Validate that all of the paths are valid, so we fail fast if there's a mistake 42 | for p in self.mesh_paths: 43 | if not os.path.isfile(p): 44 | raise ValueError("Dataset load error: could not find file " + str(p)) 45 | 46 | # Save other options 47 | self.transforms 
= transforms 48 | 49 | print("\n== PointSurfaceDataset: loaded dataset with {} surfaces .\n".format(len(self.mesh_paths))) 50 | 51 | def __len__(self): 52 | return len(self.mesh_paths) 53 | 54 | def __getitem__(self, idx): 55 | if torch.is_tensor(idx): 56 | idx = idx.tolist() 57 | 58 | # Read the mesh 59 | # (always loads on CPU) 60 | fullpath = self.mesh_paths[idx] 61 | record = np.load(fullpath, allow_pickle=True) 62 | 63 | vert_pos = torch.tensor(record['vert_pos'], dtype=world.dtype, device='cpu') 64 | surf_pos = torch.tensor(record['surf_pos'], dtype=world.dtype, device='cpu') 65 | 66 | if record['vert_normal'] is None: 67 | vert_normal = torch.zeros((0,3), dtype=world.dtype, device='cpu') 68 | else: 69 | vert_normal = torch.tensor(record['vert_normal'], dtype=world.dtype, device='cpu') 70 | 71 | if record['surf_normal'] is None: 72 | surf_normal = torch.zeros((0,3), dtype=world.dtype, device='cpu') 73 | else: 74 | surf_normal = torch.tensor(record['surf_normal'], dtype=world.dtype, device='cpu') 75 | 76 | # Apply transformations 77 | for transform in self.transforms: 78 | vert_pos, _, _ = transform(verts=vert_pos) 79 | 80 | return {'vert_pos': vert_pos, 'vert_normal': vert_normal, 'surf_pos' : surf_pos, 'surf_normal' : surf_normal, 'path': fullpath} 81 | 82 | -------------------------------------------------------------------------------- /src/generate_local_points_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | import argparse 4 | import numpy as np 5 | import sys 6 | import os 7 | import gc 8 | 9 | import utils 10 | 11 | from scipy.io import loadmat 12 | from scipy import spatial 13 | import meshio 14 | from plyfile import PlyData 15 | 16 | 17 | """ 18 | Generate training data in the form of points for meshes in local neighborhoods. 
19 | """ 20 | 21 | sys.setrecursionlimit(10000) 22 | 23 | def ensure_dir_exists(d): 24 | if not os.path.exists(d): 25 | os.makedirs(d) 26 | 27 | 28 | def generate_sample_counts(entries, total_count): 29 | 30 | counts = np.zeros(len(entries), dtype=int) 31 | for i in range(total_count): 32 | ind = np.random.randint(len(entries)) 33 | counts[ind] += 1 34 | 35 | return counts 36 | 37 | def area_normals(verts, faces): 38 | coords = verts[faces] 39 | vec_A = coords[:, 1, :] - coords[:, 0, :] 40 | vec_B = coords[:, 2, :] - coords[:, 0, :] 41 | raw_normal = np.cross(vec_A, vec_B) 42 | return raw_normal 43 | 44 | 45 | def uniform_sample_surface(verts, faces, n_pts): 46 | 47 | areaN = area_normals(verts, faces) 48 | face_areas = 0.5 * np.linalg.norm(areaN, axis=-1) 49 | 50 | # chose which faces 51 | face_inds = np.random.choice(faces.shape[0], size=(n_pts,), replace=True, p=face_areas/np.sum(face_areas)) 52 | 53 | # Get barycoords for each sample 54 | r1_sqrt = np.sqrt(np.random.rand(n_pts)) 55 | r2 = np.random.rand(n_pts) 56 | bary_vals = np.zeros((n_pts, 3)) 57 | bary_vals[:, 0] = 1. - r1_sqrt 58 | bary_vals[:, 1] = r1_sqrt * (1. 
- r2) 59 | bary_vals[:, 2] = r1_sqrt * r2 60 | 61 | return face_inds, bary_vals 62 | 63 | def get_samples(verts, faces, n_pts): 64 | 65 | face_inds, bary_vals = uniform_sample_surface(verts, faces, n_pts) 66 | # face_normals = igl.per_face_normals(verts, faces, np.array((0., 0., 0.,))) 67 | areaN = area_normals(verts, faces) 68 | face_normals = areaN / np.linalg.norm(areaN, axis=-1)[:,np.newaxis] 69 | 70 | positions = np.sum(bary_vals[:,:,np.newaxis] * verts[faces[face_inds, :]], axis=1) 71 | normals = face_normals[face_inds] 72 | 73 | return positions, normals 74 | 75 | def main(): 76 | 77 | parser = argparse.ArgumentParser() 78 | 79 | # Build arguments 80 | parser.add_argument('--input_dir', type=str, required=True, help='path to the files') 81 | parser.add_argument('--output_dir', type=str, required=True, help='where to put results') 82 | parser.add_argument('--n_samples', type=int, required=True, help='number of neighborhoods to sample') 83 | 84 | parser.add_argument('--neigh_size', type=int, default=256, help='number of vertices to sample in each region') 85 | parser.add_argument('--surface_size', type=int, default=1024, help='number of points to use to represent the surface') 86 | parser.add_argument('--model_frac', type=float, default=0.25, help='what fraction of the shape each neighborhood should be') 87 | 88 | parser.add_argument('--n_add', type=float, default=0.0, help='fraction of noise points to add') 89 | parser.add_argument('--on_surface_dev', type=float, default=0.02, help='') 90 | 91 | parser.add_argument('--polyscope', action='store_true', help='viz') 92 | 93 | # Parse arguments 94 | args = parser.parse_args() 95 | 96 | ensure_dir_exists(args.output_dir) 97 | 98 | # Load the list of meshes 99 | meshes = [] 100 | for f in os.listdir(args.input_dir): 101 | meshes.append(os.path.join(args.input_dir, f)) 102 | 103 | print("Found {} mesh files".format(len(meshes))) 104 | random.shuffle(meshes) 105 | counts = generate_sample_counts(meshes, 
args.n_samples) 106 | i_sample = 0 107 | 108 | 109 | def process_file(i_mesh, f): 110 | nonlocal i_sample 111 | 112 | # Read the mesh 113 | 114 | # libigl loader seems to leak memory in loop? 115 | # verts, faces = utils.read_mesh(f) 116 | 117 | plydata = PlyData.read(f) 118 | verts = np.vstack(( 119 | plydata['vertex']['x'], 120 | plydata['vertex']['y'], 121 | plydata['vertex']['z'] 122 | )).T 123 | tri_data = plydata['face'].data['vertex_indices'] 124 | faces = np.vstack(tri_data) 125 | 126 | 127 | # Compute total sample counts 128 | n_vert_sample_tot = int(args.neigh_size / args.model_frac * (1. - args.n_add)) 129 | n_surf_sample_tot = int(args.surface_size / (args.model_frac)) 130 | 131 | # sample points 132 | vert_sample_pos, vert_sample_normal = get_samples(verts, faces, n_vert_sample_tot) 133 | 134 | if(args.n_add > 0): 135 | n_vert_sample_noise = int(args.neigh_size / args.model_frac * (args.n_add)) 136 | vert_sample_noise_pos, vert_sample_noise_normal = get_samples(verts, faces, n_vert_sample_noise) 137 | vert_sample_noise_pos += np.random.randn(n_vert_sample_noise, 3) * args.on_surface_dev 138 | 139 | vert_sample_pos = np.concatenate((vert_sample_pos, vert_sample_noise_pos), axis=0) 140 | vert_sample_normal = np.concatenate((vert_sample_normal, vert_sample_noise_normal), axis=0) 141 | 142 | 143 | surf_sample_pos, surf_sample_normal = get_samples(verts, faces, n_surf_sample_tot) 144 | 145 | # Build nearest-neighbor structure 146 | kd_tree_vert = spatial.KDTree(vert_sample_pos) 147 | kd_tree_surf = spatial.KDTree(surf_sample_pos) 148 | 149 | # Randomly sample vertices 150 | last_sample = i_sample + counts[i_mesh] 151 | while i_sample < last_sample: 152 | 153 | print("generating sample {} / {} on mesh {}".format(i_sample, args.n_samples, f)) 154 | 155 | # Random vertex 156 | ind = np.random.randint(vert_sample_pos.shape[0]) 157 | center = surf_sample_pos[ind, :] 158 | 159 | _, neigh_vert = kd_tree_vert.query(center, k=args.neigh_size) 160 | _, neigh_surf = 
kd_tree_surf.query(center, k=args.surface_size) 161 | 162 | result_vert_pos = vert_sample_pos[neigh_vert, :] 163 | result_vert_normal = vert_sample_normal[neigh_vert, :] 164 | result_surf_pos = surf_sample_pos[neigh_surf, :] 165 | result_surf_normal = surf_sample_normal[neigh_surf, :] 166 | 167 | # Write out the result 168 | out_filename = os.path.join(args.output_dir, "neighborhood_points_{:06d}.npz".format(i_sample)) 169 | np.savez(out_filename, vert_pos=result_vert_pos, vert_normal=result_vert_normal, surf_pos= result_surf_pos, surf_normal=result_surf_normal) 170 | 171 | i_sample = i_sample + 1 172 | 173 | 174 | for i_mesh, f in enumerate(meshes): 175 | process_file(i_mesh, f) 176 | 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /src/knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | #import igl 3 | import numpy as np 4 | import sklearn.neighbors 5 | 6 | import world 7 | import utils 8 | from utils import * 9 | 10 | import torch_cluster 11 | 12 | 13 | # Finds the k nearest neighbors of source on target. 14 | # Return is two tensors (distances, indices). Returned points will be sorted in increasing order of distance. 
15 | def find_knn(points_source, points_target, k, largest=False, omit_diagonal=False, method='brute', prebuilt_tree=None): 16 | 17 | if omit_diagonal and points_source.shape[0] != points_target.shape[0]: 18 | raise ValueError("omit_diagonal can only be used when source and target are same shape") 19 | 20 | if method != 'cpu_kd' and points_source.shape[0] * points_target.shape[0] > 1e8: 21 | method = 'cpu_kd' 22 | print("switching to cpu_kd knn") 23 | 24 | if method == 'brute': 25 | 26 | # Expand so both are NxMx3 tensor 27 | points_source_expand = points_source.unsqueeze(1) 28 | points_source_expand = points_source_expand.expand(-1, points_target.shape[0], -1) 29 | points_target_expand = points_target.unsqueeze(0) 30 | points_target_expand = points_target_expand.expand(points_source.shape[0], -1, -1) 31 | 32 | diff_mat = points_source_expand - points_target_expand 33 | dist_mat = norm(diff_mat) 34 | 35 | if omit_diagonal: 36 | torch.diagonal(dist_mat)[:] = float('inf') 37 | 38 | result = torch.topk(dist_mat, k=k, largest=largest, sorted=True) 39 | return result 40 | 41 | elif method == 'cpu_kd': 42 | 43 | if largest: 44 | raise ValueError("can't do largest with cpu_kd") 45 | 46 | points_source_np = toNP(points_source) 47 | 48 | # Build the tree 49 | if prebuilt_tree is not None: 50 | kd_tree = prebuilt_tree 51 | else: 52 | points_target_np = toNP(points_target) 53 | kd_tree = sklearn.neighbors.KDTree(points_target_np) 54 | 55 | k_search = k+1 if omit_diagonal else k 56 | _, neighbors = kd_tree.query(points_source_np, k=k_search) 57 | 58 | if omit_diagonal: 59 | # Mask out self element 60 | mask = neighbors != np.arange(neighbors.shape[0])[:, np.newaxis] 61 | 62 | # make sure we mask out exactly one element in each row, in rare case of many duplicate points 63 | mask[np.sum(mask, axis=1) == mask.shape[1], -1] = False 64 | 65 | neighbors = neighbors[mask].reshape((neighbors.shape[0], neighbors.shape[1]-1)) 66 | 67 | inds = torch.tensor(neighbors, 
device=points_source.device, dtype=torch.int64) 68 | dists = norm(points_source.unsqueeze(1).expand(-1, k, -1) - points_target[inds]) 69 | 70 | return dists, inds 71 | 72 | else: 73 | raise ValueError("unrecognized method") 74 | 75 | # For each point in `source`, compute the distance to the nearest point in `target`. Returns the mean of these distances. 76 | def point_cloud_nearest_dist(points_source, points_target): 77 | 78 | # dummy batch IDs 79 | # source_ids = torch.zeros(points_source.shape[0], dtype=torch.int64, device=points_source.device) 80 | # target_ids = torch.zeros(points_target.shape[0], dtype=torch.int64, device=points_source.device) 81 | 82 | # get the nearest point in target 83 | # nearest_ind = torch_cluster.nearest(points_source, points_tarege, source_ids, target_ids) 84 | nearest_ind = torch_cluster.nearest(points_source, points_target) 85 | 86 | # compute the distances themselves 87 | nearest_pos = points_target[nearest_ind, :] 88 | dists = utils.norm(points_source - nearest_pos) 89 | return dists 90 | 91 | 92 | # For face in faces, find the k nearest neighbors in target_points. 
93 | # Implemented by first finding the nearest neighbors of the barycenter of the triangle, then discarding the triangles vertices 94 | def face_neighbors(face_centers, faces, target_points, k=10, alternate_centers=None): 95 | 96 | # Get nearby points for each face by nearest neighbor from barycenter, discard neighbors which are the vertices of the triangle 97 | # (note: the way this is written if the triangle vertices did not end up on the neighbors list, some other neighbors will be 98 | # discarded so the output tensor has fixed size) 99 | _, neighbors = find_knn(face_centers, target_points, k+3) # +3 because we discard below 100 | 101 | # If using alternate centers, includes them here 102 | if alternate_centers is not None: 103 | _, alt_neighbors = find_knn(alternate_centers, target_points, k) 104 | neighbors = torch.cat([neighbors, alt_neighbors], dim=-1) 105 | 106 | # Discard the faces's vertices 107 | # Build a mask of the indices we want to keep, by masking out neighbors equal to one of the faces vertices 108 | mask = torch.ones_like(neighbors, dtype=torch.bool) 109 | for i in range(3): 110 | ith_vert = faces[:, i] 111 | ith_mask = (neighbors - ith_vert.unsqueeze(1)) == 0 112 | # mask = mask & ~ith_mask # old version 113 | mask = mask.type(torch.uint8) & (~ith_mask).type(torch.uint8) 114 | 115 | # Make sure there are exactly 3 False entries in each mask, by adding Falses to end (which drops farthest points) 116 | 117 | # Note: this uses byte tensors, since where() isn't implemented for bool. 
But support was added days ago: https://github.com/pytorch/pytorch/pull/26430 118 | # once that pull makes it in to release, we can just use bool tensors 119 | mask = mask.to(torch.uint8) 120 | 121 | n_neigh = k + 3 122 | if alternate_centers is not None: 123 | n_neigh += k 124 | 125 | n_rounds = 3 if alternate_centers is None else 6 126 | for i in range(n_rounds): 127 | false_counts = n_neigh - torch.sum(mask, dim=-1) 128 | 129 | # Alternate version with element from back set to False 130 | false_back = mask.clone() 131 | false_back[:, n_neigh-i-1] = False 132 | 133 | replace_row = (false_counts != 3).unsqueeze(-1) 134 | mask = torch.where(replace_row, false_back, mask) 135 | 136 | # see note above 137 | mask = mask.to(torch.bool) 138 | 139 | # Must be 0 or reshape below will fail. Luckily, 3 iterations of the loop above will always be enough to make this 0 140 | # false_counts = n_neigh - torch.sum(mask, dim=-1) 141 | # print("n bad = " + str(torch.sum(false_counts != 3))) 142 | # print_info(neighbors[mask], "neigh mask") 143 | 144 | # Take only the masked elements 145 | result_neighbors = neighbors[mask].reshape((faces.shape[0], -1)) 146 | 147 | return result_neighbors 148 | 149 | 150 | # Generate all neighbors for a triangle, excluding its three vertices 151 | def all_face_neighbors(faces, n_target): 152 | n_face = faces.shape[0] 153 | 154 | if world.debug_checks: 155 | utils.check_faces_for_duplicates(faces) 156 | 157 | neighbors = torch.arange(n_target, device=world.device).unsqueeze(0).expand((n_face, -1)) 158 | 159 | # Remove the three triangle vertices 160 | # Build a mask of the indices we want to keep, by masking out neighbors equal to one of the faces vertices 161 | mask = torch.ones_like(neighbors, dtype=torch.bool) 162 | for i in range(3): 163 | ith_vert = faces[:, i] 164 | ith_mask = (neighbors - ith_vert.unsqueeze(1)) == 0 165 | # mask = mask & ~ith_mask # old version 166 | mask = mask.type(torch.uint8) & (~ith_mask).type(torch.uint8) 167 | 168 | # 
see note above 169 | mask = mask.to(torch.bool) 170 | 171 | # Take only the masked elements 172 | result_neighbors = neighbors[mask].reshape((faces.shape[0], -1)) 173 | 174 | return result_neighbors 175 | 176 | 177 | # Inputs: 178 | # interp_points: (N,D) locations at which to sample values 179 | # source_points: (M,D) locations at which values are defined 180 | # source_values: (M,V) value at each location in source_points 181 | # weight_fn: weighting strategy; only 'inv_dist' (weights 1 / (dist^3 + eps), normalized to sum to 1) is supported, anything else raises ValueError 182 | # 183 | # Outputs: 184 | # (N,V) interpolated values at interp_points 185 | def interpolate_nearby(interp_points, source_points, source_values, weight_fn='inv_dist', eps=1e-6): 186 | N = interp_points.shape[0] 187 | M = source_points.shape[0] 188 | 189 | # TODO could do this with knn for better scaling (currently materializes dense N x M expanded tensors) 190 | 191 | # Expand so both point sets are NxMxD tensor, and value set is NxMxV 192 | interp_points_expand = interp_points.unsqueeze(1) 193 | interp_points_expand = interp_points_expand.expand(-1, M, -1) 194 | 195 | source_points_expand = source_points.unsqueeze(0) 196 | source_points_expand = source_points_expand.expand(N, -1, -1) 197 | 198 | source_values_expand = source_values.unsqueeze(0) 199 | source_values_expand = source_values_expand.expand(N, -1, -1) 200 | 201 | # Evaluate weight function 202 | # afterwards, `weights` is an (N, M) tensor of weights 203 | if weight_fn == 'inv_dist': 204 | 205 | diff_mat = interp_points_expand - source_points_expand 206 | dist_mat = norm(diff_mat, highdim=True) 207 | weights = 1.0 / (torch.pow(dist_mat, 3) + eps) 208 | 209 | else: 210 | raise ValueError("unrecognized weight function: {}".format(weight_fn)) 211 | 212 | # Divide by weight sum to get interpolation coefficients 213 | weights = weights / torch.sum(weights, dim=-1, keepdim=True) 214 | 215 | # Interpolate 216 | interp_vals = torch.sum(source_values_expand * weights.unsqueeze(-1), dim=1) 217 | 218 | return interp_vals 219 | 220 | 221 |
-------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import world 5 | import utils 6 | import knn 7 | from utils import * 8 | import mesh_utils 9 | 10 | import torch_scatter 11 | 12 | 13 | # Distance from a known surface to a generated triangulation with probabilities 14 | def dist_surface_to_triangle_probs(gen_verts, gen_faces, gen_face_probs, n_sample_pts=None, mesh=None, surf_samples=None): 15 | if gen_faces.shape[0] == 0: 16 | return torch.tensor(0., device=surf_verts.device) 17 | 18 | if surf_samples is not None: 19 | if mesh is not None or n_sample_pts is not None: 20 | raise ValueError("bad args!") 21 | 22 | if mesh is not None: 23 | surf_verts, surf_faces = mesh 24 | if surf_samples is not None: 25 | raise ValueError("bad args!") 26 | 27 | # Sample points on the known surfacse 28 | surf_samples = mesh_utils.sample_points_on_surface(surf_verts, surf_faces, n_sample_pts) 29 | 30 | 31 | 32 | # get a characteristic length 33 | char_len = utils.norm(surf_samples - torch.mean(surf_samples , dim=0, keepdim=True)).mean() 34 | 35 | # Find the distance to all triangles in the generated surface 36 | tri_dists = mesh_utils.point_triangle_distance(surf_samples, gen_verts, gen_faces) 37 | 38 | # Sort distances 39 | k_val = min(32, tri_dists.shape[-1]) 40 | tri_dists_sorted, sorted_inds = torch.topk(tri_dists, largest=False, k=k_val, dim=-1) 41 | 42 | # Compute the likelihoods that each triangle is the nearest for that sample 43 | sorted_probs = gen_face_probs[sorted_inds] 44 | 45 | prob_none_closer = torch.cat(( # shift to the right, put 1 in first col 46 | torch.ones_like(sorted_probs)[:,:1], 47 | torch.cumprod(1. 
- sorted_probs, dim=-1)[:, :-1] 48 | ), dim=-1) 49 | 50 | prob_is_closest = prob_none_closer * sorted_probs 51 | 52 | # Append a last distance very far away, so you get high loss values if nothing is close 53 | last_prob = 1.0 - torch.sum(prob_is_closest, dim=-1) 54 | last_dist = char_len * torch.ones(tri_dists.shape[0], dtype=tri_dists.dtype, device=tri_dists.device) 55 | prob_is_closest = torch.cat((prob_is_closest, last_prob.unsqueeze(-1)), dim=-1) 56 | prob_is_closest = torch.clamp(prob_is_closest, 0., 1.) # for floating point reasons 57 | tri_dists_sorted = torch.cat((tri_dists_sorted, last_dist.unsqueeze(-1)), dim=-1) 58 | 59 | 60 | 61 | # Use these likelihoods to get expected distance 62 | expected_dist = torch.sum(prob_is_closest * tri_dists_sorted, dim=-1) 63 | 64 | result = torch.mean(expected_dist / char_len) 65 | return result 66 | 67 | 68 | 69 | # Distance from generated triangulation with probabilities to a known surface 70 | def dist_triangle_probs_to_sampled_surface(surf_pos, surf_normals, gen_verts, gen_faces, gen_face_probs, n_sample_pts=5000): 71 | if gen_faces.shape[0] == 0: 72 | return torch.tensor(0., device=surf_verts.device) 73 | 74 | # get a characteristic length 75 | char_len = utils.norm(surf_pos - torch.mean(surf_pos, dim=0, keepdim=True)).mean() 76 | 77 | # Sample points on the generated triangulation 78 | samples, face_inds, _ = mesh_utils.sample_points_on_surface( 79 | gen_verts, gen_faces, n_sample_pts, return_inds_and_bary=True) 80 | 81 | # Likelihoods associated with each point 82 | point_probs = gen_face_probs[face_inds] 83 | 84 | # Measure the distance to the surface 85 | knn_dist, neigh = knn.find_knn(samples, surf_pos, k=1) 86 | neigh_pos = surf_pos[neigh.squeeze(1), :] 87 | 88 | if len(surf_normals) == 0 : 89 | dists = knn_dist 90 | else: 91 | neigh_normal = surf_normals[neigh.squeeze(1), :] 92 | vecs = neigh_pos - samples 93 | dists = torch.abs(utils.dot(vecs, neigh_normal)) 94 | 95 | # Expected distance integral 96 | 
exp_dist = torch.mean(dists * point_probs) 97 | 98 | return exp_dist / char_len 99 | 100 | 101 | # Penalize overcomplete triangulations which overlap themselves by evaluating a spatial kernel at sampled surface 102 | def overlap_kernel(gen_verts, gen_faces, gen_face_probs, n_sample_pts=5000): 103 | if gen_faces.shape[0] == 0: 104 | return torch.tensor(0., device=gen_verts.device) 105 | 106 | # Sample points on the generated triangulation 107 | samples, face_inds, _ = mesh_utils.sample_points_on_surface( 108 | gen_verts, gen_faces, n_sample_pts, face_probs=(gen_face_probs), return_inds_and_bary=True) 109 | 110 | # Evaluate kernel 111 | sample_tri_kvals = mesh_utils.triangle_kernel(samples, gen_verts, gen_faces, kernel_height=0.5) 112 | 113 | # Incorporate weights and sum 114 | sample_tri_kvals_weight = sample_tri_kvals * gen_face_probs.unsqueeze(0) 115 | 116 | # Ideally, all samples should all have one entry with value 1 and 0 for all other entries, so 117 | # we ask that there be no kernel contribution from any other triangles. 
118 | kernel_sums = torch.sum(sample_tri_kvals_weight, dim=-1) 119 | kernel_max = torch.max(sample_tri_kvals_weight, dim=-1).values 120 | scores = (kernel_sums - 1.)**2 + (kernel_max - 1.)**2 121 | 122 | # note that this corresponds to a normalization by the expected area of the surface 123 | return torch.mean(scores) 124 | 125 | 126 | def expected_watertight(gen_verts, gen_faces, gen_face_probs): 127 | if gen_faces.shape[0] == 0: 128 | return torch.tensor(0., device=gen_verts.device) 129 | 130 | V = gen_verts.shape[0] 131 | 132 | # NOTE V^2 for now 133 | 134 | # Build a list of all 3V halfedges, and the probabilities associated with them 135 | ind_keys = [] 136 | key_probs = [] 137 | for i in range(3): 138 | 139 | indA = gen_faces[:,i] 140 | indB = gen_faces[:,(i+1)%3] 141 | 142 | ind_min = torch.min(indA, indB) 143 | ind_max = torch.max(indA, indB) 144 | 145 | ind_key = ind_min * V + ind_max 146 | 147 | ind_keys.append(ind_key) 148 | key_probs.append(gen_face_probs) 149 | 150 | ind_keys_vv = torch.cat(ind_keys, dim=0) 151 | key_probs = torch.cat(key_probs, dim=0) 152 | 153 | # compute unique dense edge keys 154 | _, ind_keys = torch.unique(ind_keys_vv, return_inverse=True) 155 | 156 | # warning: this has all kinds of numerical stability pitfalls 157 | EPS = 1e-3 158 | pi = (1. - EPS) * key_probs + EPS * .5 # pull slightly towards .5 to mitigate 159 | qi = 1. 
- pi 160 | 161 | # compute probability that there are exactly two incident triangles 162 | # prod_qi_edge = torch_scatter.scatter_mul(qi, ind_keys) 163 | prod_qi_edge = torch.exp(torch_scatter.scatter_add(torch.log(qi), ind_keys)) 164 | prob1_this = pi * prod_qi_edge[ind_keys] / qi # probability that this halfedge is the only one incident on tri 165 | prob1_edge = torch_scatter.scatter_add(prob1_this, ind_keys) # probability that each edge has exactly one tri incident 166 | prob1_other = (prob1_edge[ind_keys] - prob1_this) / qi # probabilty that there is exactly one tri other than this one 167 | 168 | # expected halfedges without unique twin 169 | loss = torch.sum(pi * (1. - prob1_other)) / (torch.sum(pi) + 1e-4) 170 | 171 | return loss 172 | 173 | 174 | 175 | def match_predictions(candA, predA, candB, predB): 176 | 177 | candA, predA = mesh_utils.uniqueify_triangle_prob_batch(candA.unsqueeze(0), predA.unsqueeze(0)) 178 | candB, predB = mesh_utils.uniqueify_triangle_prob_batch(candB.unsqueeze(0), predB.unsqueeze(0)) 179 | candA = candA.squeeze(0) 180 | predA = predA.squeeze(0) 181 | candB = candB.squeeze(0) 182 | predB = predB.squeeze(0) 183 | 184 | # form combined list 185 | cands = torch.cat((candA, candB), dim=0) 186 | preds = torch.cat((predA, predB), dim=0) 187 | 188 | # unique an compute average 189 | u_cands, u_inds = torch.unique(cands, dim=0, return_inverse=True) 190 | u_mean = torch_scatter.scatter_mean(preds, u_inds) 191 | n_shared = candA.shape[0] + candB.shape[0] - u_cands.shape[0] 192 | 193 | # difference from mean 194 | diffs = preds - u_mean[u_inds] 195 | 196 | return torch.sum(diffs**2) / (n_shared + 1e-3) 197 | 198 | 199 | def build_loss(args, vert_pos, candidates, candidate_probs, surf_pos=None, surf_normal=None, gt_tris=None, gt_probs=None, n_sample=None, dist_surf_subsample_factor=10): 200 | 201 | loss_terms = {} 202 | 203 | if hasattr(args, "w_dist_surf_tri") and args.w_dist_surf_tri > 0.: 204 | loss_terms['dist_surf_tri'] = 
args.w_dist_surf_tri * \ 205 | dist_surface_to_triangle_probs(vert_pos, candidates, candidate_probs, surf_samples=surf_pos[...,::dist_surf_subsample_factor,:]) 206 | 207 | if hasattr(args, "w_dist_tri_surf") and args.w_dist_tri_surf > 0.: 208 | loss_terms['dist_tri_surf'] = args.w_dist_tri_surf * \ 209 | dist_triangle_probs_to_sampled_surface(surf_pos, surf_normal, vert_pos, candidates, candidate_probs, n_sample_pts=n_sample) 210 | 211 | if hasattr(args, "w_overlap_kernel") and args.w_overlap_kernel> 0.: 212 | loss_terms['overlap_kernel'] = args.w_overlap_kernel* \ 213 | overlap_kernel(vert_pos, candidates, candidate_probs, n_sample_pts=n_sample) 214 | 215 | if hasattr(args, "w_watertight") and args.w_watertight > 0.: 216 | loss_terms['watertight'] = args.w_watertight * \ 217 | expected_watertight(vert_pos, candidates, candidate_probs) 218 | 219 | return loss_terms 220 | -------------------------------------------------------------------------------- /src/main_generate_mesh.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import sys 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import igl 10 | import plyfile 11 | import polyscope 12 | 13 | import utils 14 | import mesh_utils 15 | from utils import * 16 | from point_tri_net import PointTriNet_Mesher 17 | 18 | def write_ply_points(filename, points): 19 | vertex = np.core.records.fromarrays(points.transpose(), names='x, y, z', formats = 'f8, f8, f8') 20 | el = plyfile.PlyElement.describe(vertex, 'vertex') 21 | plyfile.PlyData([el]).write(filename) 22 | 23 | def main(): 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | 28 | parser.add_argument('model_weights_path', type=str, help='path to the model checkpoint') 29 | parser.add_argument('input_path', type=str, help='path to the input') 30 | 31 | parser.add_argument('--disable_cuda', action='store_true', help='disable cuda') 32 | 33 | parser.add_argument('--sample_cloud', 
type=int, help='run on sampled points') 34 | 35 | parser.add_argument('--n_rounds', type=int, default=5, help='number of rounds') 36 | parser.add_argument('--prob_thresh', type=float, default=.9, help='threshold for final surface') 37 | 38 | parser.add_argument('--output', type=str, help='path to save the resulting high prob mesh to. also disables viz') 39 | parser.add_argument('--output_trim_unused', action='store_true', help='trim unused vertices when outputting') 40 | 41 | # Parse arguments 42 | args = parser.parse_args() 43 | set_args_defaults(args) 44 | 45 | viz = not args.output 46 | args.polyscope = False 47 | 48 | # Initialize polyscope 49 | if viz: 50 | polyscope.init() 51 | 52 | # === Load the input 53 | 54 | if args.input_path.endswith(".npz"): 55 | record = np.load(args.input_path) 56 | verts = torch.tensor(record['vert_pos'], dtype=args.dtype, device=args.device) 57 | surf_samples = torch.tensor(record['surf_pos'], dtype=args.dtype, device=args.device) 58 | 59 | samples = verts.clone() 60 | faces = torch.zeros((0,3), dtype=torch.int64, device=args.device) 61 | 62 | polyscope.register_point_cloud("surf samples", toNP(surf_samples)) 63 | 64 | if args.input_path.endswith(".xyz"): 65 | raw_pts = np.loadtxt(args.input_path) 66 | verts = torch.tensor(raw_pts, dtype=args.dtype, device=args.device) 67 | 68 | samples = verts.clone() 69 | faces = torch.zeros((0,3), dtype=torch.int64, device=args.device) 70 | 71 | polyscope.register_point_cloud("surf samples", toNP(verts)) 72 | 73 | else: 74 | print("reading mesh") 75 | verts, faces = utils.read_mesh(args.input_path) 76 | print(" {} verts {} faces".format(verts.shape[0], faces.shape[0])) 77 | verts = torch.tensor(verts, dtype=args.dtype, device=args.device) 78 | faces = torch.tensor(faces, dtype=torch.int64, device=args.device) 79 | 80 | # verts = verts[::10,:] 81 | 82 | if args.sample_cloud: 83 | samples = mesh_utils.sample_points_on_surface(verts, faces, args.sample_cloud) 84 | else: 85 | samples = 
verts.clone() 86 | 87 | 88 | # For very large inputs, leave the data on the CPU and only use the device for NN evaluation 89 | if samples.shape[0] > 50000: 90 | print("Large input: leaving data on CPU") 91 | samples = samples.cpu() 92 | 93 | # === Load the model 94 | 95 | print("loading model weights") 96 | model = PointTriNet_Mesher() 97 | model.load_state_dict(torch.load(args.model_weights_path)) 98 | 99 | model.eval() 100 | 101 | with torch.no_grad(): 102 | 103 | # Sample lots of faces from the vertices 104 | print("predicting") 105 | candidate_triangles, candidate_probs = model.predict_mesh(samples.unsqueeze(0), n_rounds=args.n_rounds) 106 | candidate_triangles = candidate_triangles.squeeze(0) 107 | candidate_probs = candidate_probs.squeeze(0) 108 | print("done predicting") 109 | 110 | # Visualize 111 | high_prob = args.prob_thresh 112 | high_faces = candidate_triangles[candidate_probs > high_prob] 113 | closed_faces = mesh_utils.fill_holes_greedy(high_faces) 114 | 115 | if viz: 116 | polyscope.register_point_cloud("input points", toNP(samples)) 117 | 118 | spmesh = polyscope.register_surface_mesh("all faces", toNP(samples), toNP(candidate_triangles), enabled=False) 119 | spmesh.add_scalar_quantity("probs", toNP(candidate_probs), defined_on='faces') 120 | 121 | spmesh = polyscope.register_surface_mesh("high prob mesh " + str(high_prob), toNP(samples), toNP(high_faces)) 122 | spmesh.add_scalar_quantity("probs", toNP(candidate_probs[candidate_probs > high_prob]), defined_on='faces') 123 | 124 | spmesh = polyscope.register_surface_mesh("hole-closed mesh " + str(high_prob), toNP(samples), toNP(closed_faces), enabled=False) 125 | 126 | polyscope.show() 127 | 128 | 129 | # Save output 130 | if args.output: 131 | 132 | high_prob = args.prob_thresh 133 | out_verts = toNP(samples) 134 | out_faces = toNP(high_faces) 135 | out_faces_closed = toNP(closed_faces) 136 | 137 | if args.output_trim_unused: 138 | out_verts, out_faces, _, _ = igl.remove_unreferenced(out_verts, 
out_faces) 139 | 140 | igl.write_triangle_mesh(args.output + "_mesh.ply", out_verts, out_faces) 141 | write_ply_points(args.output + "_samples.ply", toNP(samples)) 142 | 143 | igl.write_triangle_mesh(args.output + "_pred_mesh.ply", out_verts, out_faces) 144 | igl.write_triangle_mesh(args.output + "_pred_mesh_closed.ply", out_verts, out_faces_closed) 145 | write_ply_points(args.output + "_samples.ply", toNP(samples)) 146 | 147 | 148 | 149 | 150 | if __name__ == "__main__": 151 | main() 152 | -------------------------------------------------------------------------------- /src/main_train_model.py: -------------------------------------------------------------------------------- 1 | import sys, os, datetime 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | # add the path to the files from this project 8 | sys.path.append(os.path.dirname(__file__)) 9 | 10 | import world 11 | import utils 12 | import losses 13 | import mesh_utils 14 | from utils import * 15 | import data_utils 16 | import train_utils 17 | from point_tri_net import PointTriNet_Mesher 18 | 19 | ### Experiment options 20 | 21 | args = world.ArgsObject() 22 | world.args = args 23 | 24 | # System parameters 25 | args.experiment_name = "fit_model" 26 | args.run_name = datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S") 27 | experiments_dir = os.path.join(os.path.dirname(__file__), "..") 28 | args.run_dir = os.path.join(os.path.dirname(__file__), "../training_runs", args.run_name) 29 | args.log_dir = os.path.join(args.run_dir, "logs") 30 | args.dataset_dir = os.path.join(experiments_dir, "..", "data") 31 | args.debug_checks = False 32 | args.disable_cuda = False 33 | 34 | # set some defaults 35 | set_args_defaults(args) 36 | world.device = args.device 37 | world.dtype = args.dtype 38 | 39 | # Experiment parameters 40 | args.load_weights = None 41 | 42 | args.w_watertight = 1.0 43 | args.w_dist_surf_tri = 1.0 44 | args.w_dist_tri_surf = 1.0 45 | 
args.w_overlap_kernel = 0.01 46 | 47 | # Algorithm parameters 48 | args.n_mesh_rounds = 5 49 | 50 | # Training parameters 51 | args.epochs = 3 # Number of epochs to train for 52 | args.lr = 1e-4 # "Learning rate" 53 | args.lr_decay = .5 # How much to decrease the learning rate by, applied every decay_step samples (lr = lr * lr_decay) 54 | args.decay_step = 10e99 # Decay lr after processing this many samples (_not_ batches, samples) 55 | args.batch_size = 1 # Batch size (note accum parameter below) 56 | args.batch_accum = 8 # Accumulate over this many batches before stepping gradients 57 | args.eval_every = 2048 # Evaluate on the validation set after this many samples 58 | args.eval_size = 512 # Use this much of the validation set to evaluate (if less than full validation set size) 59 | 60 | print("Beginning run " + str(args.run_name)) 61 | 62 | # Ensure the run directory exists 63 | ensure_dir_exists(args.run_dir) 64 | 65 | ## Initialize the tensorboard writer 66 | world.tb_writer = SummaryWriter(log_dir=args.log_dir) 67 | 68 | # Load the dataset 69 | T = [] 70 | train_dataset = data_utils.PointSurfaceDataset(dir_with_meshes="data/train/", transforms=T) 71 | val_dataset = data_utils.PointSurfaceDataset(dir_with_meshes="data/val/", transforms=T) 72 | 73 | train_loader = torch.utils.data.DataLoader( 74 | train_dataset, 75 | batch_size=world.args.batch_size, 76 | shuffle=True, 77 | num_workers=4, 78 | pin_memory=True, 79 | drop_last=True, 80 | # collate_fn=train_dataset.collate_fn 81 | ) 82 | val_loader = torch.utils.data.DataLoader( 83 | val_dataset, 84 | batch_size=world.args.batch_size, 85 | shuffle=True, 86 | num_workers=4, 87 | pin_memory=True, 88 | drop_last=True, 89 | # collate_fn=val_dataset.collate_fn 90 | ) 91 | 92 | # Construct a model 93 | model = PointTriNet_Mesher() 94 | 95 | if args.load_weights: 96 | weights_path = os.path.join(world.args.load_weights,"model_state_dict.pth") 97 | model.load_state_dict(torch.load(weights_path)) 98 | 99 | # Calls the 
model on a batch 100 | def call_model_fn(model, batch, trainer=None): 101 | points = batch['vert_pos'].to(world.device) 102 | candidate_triangles, candidate_probs, proposal_triangles, proposal_probs = model.predict_mesh(points, n_rounds=args.n_mesh_rounds, sample_last=True) 103 | return { 104 | "candidates" : candidate_triangles, "probs": candidate_probs, 105 | "proposals" : proposal_triangles, "proposal_probs" : proposal_probs 106 | } 107 | 108 | # Construct a loss 109 | def loss_fn(batch, model_outputs, viz_extra=False, trainer=None): 110 | B = batch['vert_pos'].shape[0] 111 | 112 | # Data from the sample 113 | vert_pos_batch = batch['vert_pos'].to(world.device) 114 | surf_pos_batch = batch['surf_pos'].to(world.device) 115 | surf_normal_batch = batch['surf_normal'].to(world.device) 116 | 117 | # Outputs from model 118 | all_candidates = model_outputs['candidates'] 119 | all_candidate_probs = model_outputs['probs'] 120 | all_proposals = model_outputs['proposals'] 121 | all_proposal_probs = model_outputs['proposal_probs'] 122 | 123 | # Accumulate loss 124 | need_grad = all_candidate_probs.requires_grad 125 | total_loss = torch.tensor(0.0, dtype=vert_pos_batch.dtype, device=vert_pos_batch.device, requires_grad=need_grad) 126 | 127 | # Evaluate loss one batch entry at a time 128 | for b in range(B): 129 | 130 | vert_pos = vert_pos_batch[b, :] 131 | candidates = all_candidates[b,:,:] 132 | candidate_probs = all_candidate_probs[b, :] 133 | proposals = all_proposals[b,:,:] 134 | proposal_probs = all_proposal_probs[b, :] 135 | 136 | surf_pos = surf_pos_batch[b, :] 137 | surf_normal = surf_normal_batch[b, :] 138 | 139 | # Add all the terms 140 | loss_terms = losses.build_loss(args, vert_pos, candidates, candidate_probs, surf_pos=surf_pos, surf_normal=surf_normal, n_sample=1000) 141 | 142 | loss_terms["proposal_match"] = losses.match_predictions( 143 | candidates, candidate_probs.detach(), 144 | proposals, proposal_probs) 145 | 146 | this_loss = torch.tensor(0.0, 
dtype=vert_pos_batch.dtype, device=vert_pos_batch.device, requires_grad=need_grad) 147 | for t in loss_terms: 148 | this_loss = this_loss + loss_terms[t] 149 | 150 | 151 | # Log some stats 152 | if trainer is not None: 153 | if trainer.training: 154 | prefix = "train_" 155 | it = trainer.curr_iter + b 156 | else: 157 | prefix = "val_" 158 | it = trainer.eval_iter + b 159 | 160 | # log less 161 | if it % 10 == 0: 162 | 163 | for t in loss_terms: 164 | world.tb_writer.add_scalar(prefix + t, loss_terms[t].item(), it) 165 | 166 | world.tb_writer.add_scalar(prefix + "sample loss", this_loss.item(), it) 167 | 168 | if it % 1000 == 0: 169 | world.tb_writer.add_histogram(prefix + 'triangle_probs', candidate_probs.detach(), it) 170 | world.tb_writer.add_histogram(prefix + 'triangle_proposal_probs', proposal_probs.detach(), it) 171 | 172 | world.tb_writer.add_scalar(prefix + "prob mean", torch.mean(candidate_probs).item(), it) 173 | world.tb_writer.add_scalar(prefix + "prob stddev", torch.std(candidate_probs).item(), it) 174 | 175 | if not trainer.training: 176 | trainer.add_eval_stat_entry("prob mean", torch.mean(candidate_probs).item()) 177 | trainer.add_eval_stat_entry("prob std", torch.std(candidate_probs).item()) 178 | for t in loss_terms: 179 | trainer.add_eval_stat_entry(t, loss_terms[t].item()) 180 | 181 | total_loss = total_loss + this_loss 182 | 183 | return total_loss / B 184 | 185 | # Train 186 | with torch.autograd.set_detect_anomaly(world.debug_checks): 187 | 188 | trainer = train_utils.MyTrainer( 189 | args=world.args, 190 | model=model, 191 | call_model_fn=call_model_fn, 192 | loss_fn=loss_fn, 193 | train_loader=train_loader, 194 | val_loader=val_loader, 195 | ) 196 | 197 | trainer.train() 198 | -------------------------------------------------------------------------------- /src/mesh_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.distributions.categorical import 
Categorical 5 | # import igl 6 | import numpy as np 7 | # import polyscope 8 | 9 | import torch_scatter 10 | 11 | import world 12 | import utils 13 | import knn 14 | from utils import * 15 | 16 | 17 | # Cyclically permute the indices of each triangle such that the smallest index comes first 18 | def roll_faces_to_canonical(faces, in_place=False): 19 | 20 | # Don't modify input 21 | if not in_place: 22 | faces = faces.clone() 23 | 24 | min_index = torch.argmin(faces, dim=-1) 25 | for i in range(3): 26 | mask = min_index == i 27 | faces[mask] = faces[mask].roll(-i, dims=-1) 28 | 29 | return faces 30 | 31 | 32 | # Sort the indices of each triangle 33 | def sort_faces_to_canonical(faces, in_place=False): 34 | 35 | # Don't modify input 36 | if not in_place: 37 | faces = faces.clone() 38 | 39 | roll_faces_to_canonical(faces, in_place=True) 40 | 41 | # Swap last two indices so largest index comes last 42 | max_index = torch.argmax(faces, dim=-1) 43 | mask = max_index == 1 44 | faces_opp_orient = torch.index_select(faces, 1, torch.tensor([0, 2, 1], device=world.device)) 45 | faces[mask, :] = faces_opp_orient[mask, :] 46 | 47 | return faces 48 | 49 | 50 | # Return only unique triangles. 51 | # Note, the indices within each face may have been rearranged in the result 52 | # Also, modifies input. 
53 | # 54 | # If oriented = True, assumes that the triagnles have been canonically oriented, so 55 | # - will not modify the orientation 56 | # - treats differently oriented triangles as different, aka [1,3,5] != [1,5,3] 57 | # if oriented = False, just treats trangles as sets of three indices, and may modify orientation 58 | def uniqueify_faces(faces, oriented=False): 59 | 60 | if oriented: 61 | roll_faces_to_canonical(faces, in_place=True) 62 | else: 63 | sort_faces_to_canonical(faces, in_place=True) 64 | 65 | # Now, simply take smallest 66 | return torch.unique(faces, sorted=False, dim=0) 67 | 68 | 69 | def sample_points_on_surface(verts, faces, n_pts, return_inds_and_bary=False, face_probs=None): 70 | 71 | # Choose faces 72 | face_areas = utils.face_area(verts, faces) 73 | if face_probs is None: 74 | # if no probs, just weight directly by areas to uniformly sample surface 75 | sample_probs = face_areas 76 | sample_probs = torch.clamp(sample_probs, 1e-30, float('inf')) # avoid -eps area 77 | face_distrib = Categorical(sample_probs) 78 | else: 79 | # if we have face probs, weight by those so we are more likely to sample more probable faces 80 | sample_probs = face_areas * face_probs 81 | sample_probs = torch.clamp(sample_probs, 1e-30, float('inf')) # avoid -eps area 82 | face_distrib = Categorical(sample_probs) 83 | 84 | face_inds = face_distrib.sample(sample_shape=(n_pts,)) 85 | 86 | # Get barycoords for each sample 87 | r1_sqrt = torch.sqrt(torch.rand(n_pts, device=verts.device)) 88 | r2 = torch.rand(n_pts, device=verts.device) 89 | bary_vals = torch.zeros((n_pts, 3), device=verts.device) 90 | bary_vals[:, 0] = 1. - r1_sqrt 91 | bary_vals[:, 1] = r1_sqrt * (1. 
- r2) 92 | bary_vals[:, 2] = r1_sqrt * r2 93 | 94 | # Get position in face for each sample 95 | coords = utils.face_coords(verts, faces) 96 | sample_coords = coords[face_inds, :, :] 97 | sample_pos = torch.sum(bary_vals.unsqueeze(-1) * sample_coords, dim=1) 98 | 99 | if return_inds_and_bary: 100 | return sample_pos, face_inds, bary_vals 101 | else: 102 | return sample_pos 103 | 104 | # For each point in pointsA, return distances to each point in pointsB 105 | # pointsA: (A, 3) coords 106 | # pointsB: (B, 3) coords 107 | # return: (A, B) dists 108 | def point_point_distances(pointsA, pointsB): 109 | 110 | # Expand so both are NxMx3 tensor 111 | pointsA_expand = pointsA.unsqueeze(1) 112 | pointsA_expand = pointsA_expand.expand(-1, points_target.shape[0], -1) 113 | pointsB_expand = pointsB.unsqueeze(0) 114 | pointsB_expand = pointsB_expand.expand(pointsA.shape[0], -1, -1) 115 | 116 | diff_mat = pointsA_expand - pointsB_expand 117 | dist_mat = utils.norm(diff_mat) 118 | 119 | return dist_mat 120 | 121 | 122 | # For each point in points, returns the distance^2 to each line segment 123 | # points: (N, 3) coords 124 | # linesA: (L, 3) coords 125 | # linesB: (L, 3) coords 126 | # return: (N, L) dists 127 | def point_line_segment_distances2(points, linesA, linesB): 128 | n_p = points.shape[0] 129 | n_l = linesA.shape[0] 130 | 131 | dir_line = utils.normalize(linesB - linesA).unsqueeze(0).expand(n_p, -1, -1) 132 | vecA = points.unsqueeze(1).expand(-1, n_l, -1) - linesA.unsqueeze(0).expand(n_p, -1, -1) 133 | vecB = points.unsqueeze(1).expand(-1, n_l, -1) - linesB.unsqueeze(0).expand(n_p, -1, -1) 134 | 135 | # Distances to first endpoint 136 | dists2 = utils.norm2(vecA) 137 | 138 | # Distances to second endpoint 139 | dists2 = torch.min(dists2, utils.norm2(vecB)) 140 | 141 | # Points within segment 142 | in_line = (utils.dot(dir_line, vecA) > 0) & (utils.dot(dir_line, vecB) < 0) 143 | 144 | # Distances to line 145 | line_dists2 = 
utils.norm2(utils.project_to_tangent(vecA[in_line], dir_line[in_line])) 146 | dists2[in_line] = line_dists2 147 | 148 | return dists2 149 | 150 | 151 | # For each point in points, returns the distance to each face in faces 152 | # points: (N, 3) coords 153 | # verts: (V, 3) coords 154 | # faces: (F, 3) inds 155 | # return: (N, F) 156 | def point_triangle_distance(points, verts, faces): 157 | n_p = points.shape[0] 158 | n_f = faces.shape[0] 159 | 160 | # make sure everything is contiguous 161 | points = points.contiguous() 162 | verts = verts.contiguous() 163 | faces = faces.contiguous() 164 | 165 | points_expand = points.unsqueeze(1).expand(-1, n_f, -1) 166 | face_normals = utils.face_normals(verts, faces) 167 | 168 | # Accumulate distances 169 | dists2 = float('inf') * torch.ones((n_p, n_f), dtype=points.dtype, device=points.device) 170 | 171 | # True if point projects inside of face in plane 172 | inside_face = torch.ones((n_p, n_f), dtype=torch.bool, device=points.device) 173 | 174 | for i in range(3): 175 | 176 | # Distance to each of the three edges 177 | lineA = verts[faces[:, i]] 178 | lineB = verts[faces[:, (i+1) % 3]] 179 | dists2 = torch.min(dists2, point_line_segment_distances2(points, lineA, lineB)) 180 | 181 | # Edge perp vec (not normalized) 182 | e_perp = utils.cross(face_normals, lineB - lineA) 183 | inside_edge = utils.dot(e_perp.unsqueeze(0).expand(n_p, -1, -1), points_expand - lineA.unsqueeze(0).expand(n_p, -1, -1)) > 0 184 | inside_face = inside_face & inside_edge 185 | 186 | 187 | 188 | dists = torch.sqrt(dists2) 189 | 190 | # For points inside, distance is just normal distance 191 | point_in_face = verts[faces[:, 0]].unsqueeze(0).expand(n_p, -1, -1) 192 | inside_face_dist = torch.abs(utils.dot( 193 | face_normals.unsqueeze(0).expand(n_p, -1, -1)[inside_face], 194 | points_expand[inside_face] - point_in_face[inside_face] 195 | )) 196 | 197 | # dists[inside_face] = inside_face_dist 198 | inside_face_dist_full = torch.zeros_like(dists) 199 | 
inside_face_dist_full[inside_face] = inside_face_dist 200 | dists = torch.where(inside_face, inside_face_dist_full, dists) 201 | 202 | if False: 203 | polyscope.remove_all_structures() 204 | 205 | samp = polyscope.register_point_cloud("points", toNP(points)) 206 | samp.add_scalar_quantity("dist", toNP(dists[:,0])) 207 | samp.add_scalar_quantity("inside face", toNP(inside_face[:,0].float())) 208 | 209 | tri = polyscope.register_surface_mesh("tri", toNP(verts), toNP(faces[0,:].unsqueeze(0))) 210 | tri.add_face_vector_quantity("e perp", toNP(e_perp[0,:].unsqueeze(0))) 211 | tri.add_face_vector_quantity("e vec", toNP((lineB - lineA)[0,:].unsqueeze(0))) 212 | tri.add_face_vector_quantity("N", toNP((face_normals)[0,:].unsqueeze(0))) 213 | 214 | polyscope.show() 215 | 216 | return dists 217 | 218 | # For each point, returns the triangle zone kernel evaluated over all faces at that point 219 | # points: (N, 3) coords 220 | # verts: (V, 3) coords 221 | # faces: (F, 3) inds 222 | # kernel_height: height of the kernel, as a fraction of the triangle's longest edge 223 | # return: (N, F) 224 | def triangle_kernel(points, verts, faces, kernel_height=1.0): 225 | n_p = points.shape[0] 226 | n_f = faces.shape[0] 227 | 228 | points_expand = points.unsqueeze(1).expand(-1, n_f, -1) 229 | face_normals = utils.face_normals(verts, faces) 230 | 231 | # Longest edge in each triangle 232 | # longest_edge = torch.zeros(n_f, dtype=verts.dtype, device=verts.device) 233 | 234 | # True if point projects inside of face in plane 235 | min_edge_dist = torch.ones((n_p, n_f), dtype=verts.dtype, device=points.device) * float('inf') 236 | 237 | for i in range(3): 238 | 239 | lineA = verts[faces[:, i]] 240 | lineB = verts[faces[:, (i+1) % 3]] 241 | 242 | # Update longest edge 243 | # longest_edge = torch.max(longest_edge, torch.norm(lineA - lineB)) 244 | 245 | # Edge perp vec 246 | e_perp = utils.normalize(utils.cross(face_normals, lineB - lineA)) 247 | edge_inside_dist = 
utils.dot(e_perp.unsqueeze(0).expand(n_p, -1, -1), points_expand - lineA.unsqueeze(0).expand(n_p, -1, -1)) 248 | # edge_inside_dist = torch.max(dge_inside_dist, torch.tensor(0., device=verts.device)) 249 | min_edge_dist = torch.min(min_edge_dist, edge_inside_dist) 250 | 251 | # normal distance 252 | point_in_face = verts[faces[:, 0]].unsqueeze(0).expand(n_p, -1, -1) 253 | normal_face_dist = torch.abs(utils.dot( 254 | face_normals.unsqueeze(0).expand(n_p, -1, -1), 255 | points_expand - point_in_face 256 | )) 257 | 258 | EPS = 1e-8 259 | k_val = torch.max(torch.tensor(0., device=verts.device), 1. - (normal_face_dist / (min_edge_dist * kernel_height + EPS))) 260 | k_val = torch.where(min_edge_dist < 0., torch.zeros_like(k_val), k_val) 261 | 262 | return k_val 263 | 264 | # Different batches will have different numbers of unique triangles, so will also cull extra triangles from some batches, preferring to cull lowest-prob 265 | # Input: 266 | # candidate_triangles (B, C, 3) 267 | # candidate_probs (B, C) 268 | def uniqueify_triangle_prob_batch(candidate_triangles, candidate_probs): 269 | B = candidate_triangles.shape[0] 270 | C = candidate_triangles.shape[1] 271 | 272 | # TODO some forum posts have indicated that sort becomes faster on the CPU pretty early, might be worth benchmarking transfer to CPU 273 | 274 | # NOTE These loop over batch dimension, as many routines (esp. 
unique) don't broadcast like we need 275 | 276 | if world.debug_checks: 277 | for b in range(B): 278 | utils.check_faces_for_duplicates(candidate_triangles[b,:,:], check_rows=False) 279 | 280 | # Sort indices within each triangle 281 | candidate_triangles = torch.sort(candidate_triangles, dim=-1).values 282 | 283 | candidate_triangles_list = [] 284 | candidate_probs_list = [] 285 | for b in range(B): 286 | 287 | # Identify unique triangles 288 | if candidate_triangles.is_cuda: 289 | _, inverse_inds = torch.unique(candidate_triangles[b], dim=0, return_inverse=True) 290 | else: 291 | # the CPU implementation of torch.unique(inverse_inds=True) has some scaling problems; 292 | _, inverse_inds = np.unique(toNP(candidate_triangles[b]).astype(np.int32), axis=0, return_inverse=True) 293 | inverse_inds = torch.tensor(inverse_inds, device=candidate_triangles.device, dtype=torch.long) 294 | 295 | # find the largest prob and first entry for each repeat group 296 | max_probs, max_entry = torch_scatter.scatter_max(candidate_probs[b, :], inverse_inds) 297 | max_entry = max_entry.detach() # needed due to a torch_scatter bug? 
298 | ind_of_max = max_entry[inverse_inds] 299 | is_max = torch.arange(inverse_inds.shape[0], device=inverse_inds.device) == ind_of_max 300 | del inverse_inds 301 | del max_probs 302 | del max_entry 303 | 304 | # set repeat probs to -1, keeping only the largest prob for repeated triangles 305 | zeroed_candidate_probs = torch.where(is_max, candidate_probs[b,:], -torch.ones_like(candidate_probs[b,:])) 306 | 307 | # Sort by probabilities, so repeats are now at bottom 308 | sorted_zeroed_candidate_probs, sort_inds = torch.sort(zeroed_candidate_probs, descending=True) 309 | candidate_triangles_list.append(candidate_triangles[b, sort_inds, :]) 310 | candidate_probs_list.append(sorted_zeroed_candidate_probs) 311 | 312 | # Identify the first index of a repeat in any colum 313 | repeat_count = max([torch.sum(candidate_probs_list[b] == -1., dim=-1).item() for b in range(B)]) 314 | 315 | # Clip out the last repeat_count entries 316 | u = C-repeat_count 317 | for b in range(B): 318 | candidate_triangles_list[b] = candidate_triangles_list[b][:u, :] 319 | candidate_probs_list[b] = candidate_probs_list[b][:u] 320 | 321 | candidate_triangles = torch.stack(candidate_triangles_list) 322 | candidate_probs = torch.stack(candidate_probs_list) 323 | 324 | 325 | if world.debug_checks: 326 | for b in range(B): 327 | utils.check_faces_for_duplicates(candidate_triangles[b,:,:], check_rows=True) 328 | 329 | 330 | return candidate_triangles, candidate_probs 331 | 332 | # Different batches will have different numbers of unique triangles, so will also cull extra triangles from some batches 333 | # Input: 334 | # candidate_triangles (B, C, 3) 335 | def uniqueify_triangle_batch(candidate_triangles): 336 | B = candidate_triangles.shape[0] 337 | C = candidate_triangles.shape[1] 338 | 339 | # TODO some forum posts have indicated that sort becomes faster on the CPU pretty early, might be worth benchmarking transfer to CPU 340 | 341 | # NOTE These loop over batch dimension, as many routines (esp. 
unique) don't broadcast like we need 342 | 343 | if world.debug_checks: 344 | for b in range(B): 345 | utils.check_faces_for_duplicates(candidate_triangles[b,:,:], check_rows=False) 346 | 347 | # Sort indices within each triangle 348 | candidate_triangles = torch.sort(candidate_triangles, dim=-1).values 349 | 350 | candidate_triangles_list = [] 351 | min_length = float('inf') 352 | for b in range(B): 353 | unique_triangles = uniqueify_faces(candidate_triangles[b], oriented=False) 354 | candidate_triangles_list.append(unique_triangles) 355 | min_length = min((min_length), unique_triangles.shape[0]) 356 | 357 | # Truncate all in batch to the length of the shortest unique list 358 | for b in range(B): 359 | candidate_triangles_list[b] = candidate_triangles[b][:min_length,:] 360 | 361 | candidate_triangles = torch.stack(candidate_triangles_list) 362 | 363 | if world.debug_checks: 364 | for b in range(B): 365 | utils.check_faces_for_duplicates(candidate_triangles[b,:,:], check_rows=True) 366 | 367 | return candidate_triangles 368 | 369 | 370 | 371 | # Input: 372 | # candidate_triangles (B, C, 3) 373 | # candidate_probs (B, C) 374 | def filter_low_prob_triangles(candidate_triangles, candidate_probs, n_keep): 375 | B = candidate_triangles.shape[0] 376 | C = candidate_triangles.shape[1] 377 | u = min(n_keep, C) 378 | 379 | # TODO some forum posts have indicated that sort becomes faster on the CPU pretty early, might be worth benchmarking transfer to CPU 380 | 381 | # For each group of unique triangles 382 | candidate_triangles_list = [] 383 | candidate_probs_list = [] 384 | for b in range(B): 385 | 386 | # Sort by probabilities 387 | sort_inds = torch.argsort(candidate_probs[b,:], descending=True) 388 | 389 | # Clip out all repeated triangles (and possibly some extras from other arrays) 390 | candidate_triangles_list.append(candidate_triangles[b, sort_inds[:u], :]) 391 | candidate_probs_list.append(candidate_probs[b, sort_inds[:u]]) 392 | 393 | candidate_triangles = 
torch.stack(candidate_triangles_list) 394 | candidate_probs = torch.stack(candidate_probs_list) 395 | 396 | return candidate_triangles, candidate_probs 397 | 398 | 399 | 400 | # Generate V triangles via nearest neighbors 401 | # Note that this will contain lots of duplicates, but that's fine 402 | # verts (B, V, 3) 403 | def generate_seed_triangles(verts): 404 | B = verts.shape[0] 405 | V = verts.shape[1] 406 | 407 | gen_tris = torch.zeros((B,V,3), device=verts.device, dtype=torch.long) 408 | 409 | for b in range(B): 410 | _, inds = knn.find_knn(verts[b,:], verts[b,:], 2, omit_diagonal=True) 411 | gen_tris[b, ...] = torch.cat((torch.arange(V, device=verts.device).unsqueeze(-1), inds), dim = -1) 412 | 413 | # gen_probs = torch.rand((B,V), device=verts.device, dtype=verts.dtype) 414 | gen_probs = 0.5 * torch.ones((B,V), device=verts.device, dtype=verts.dtype) 415 | 416 | return gen_tris, gen_probs 417 | 418 | 419 | def fill_holes_greedy(in_faces): 420 | faces = toNP(in_faces).tolist() 421 | 422 | def edge_key(a,b): 423 | return (min(a,b), max(a,b)) 424 | def face_key(f): 425 | return tuple(sorted(f)) 426 | 427 | edge_count = {} 428 | neighbors = {} 429 | all_faces = set() 430 | def add_edge(a,b): 431 | if a not in neighbors: 432 | neighbors[a] = set() 433 | if b not in neighbors: 434 | neighbors[b] = set() 435 | key = edge_key(a,b) 436 | if key not in edge_count: 437 | edge_count[key] = 0 438 | 439 | neighbors[a].add(b) 440 | neighbors[b].add(a) 441 | edge_count[key] += 1 442 | 443 | def add_face(f): 444 | for i in range(3): 445 | a = f[i] 446 | b = f[(i+1)%3] 447 | add_edge(a,b) 448 | 449 | all_faces.add(face_key(f)) 450 | 451 | def face_exists(f): 452 | return face_key(f) in all_faces 453 | 454 | for f in faces: 455 | add_face(f) 456 | 457 | # repeated passes (inefficient) 458 | any_changed = True 459 | while(any_changed): 460 | any_changed = False 461 | new_faces = [] 462 | 463 | start_edges = [e for e in edge_count] 464 | 465 | for e in start_edges: 466 | if 
edge_count[e] == 1: 467 | a,b = e 468 | found = False 469 | 470 | # Look single triangle holes 471 | for s in [a,b]: # one of the verts in this edge 472 | if found: break # quit once found 473 | o = b if s == a else a # the other vert in this edge 474 | for n in neighbors[s]: # a candidate third vertex 475 | if found: break # quit once found 476 | if n == o: continue # must not be same as edge 477 | if face_exists([a,b,n]): continue # face must not exist 478 | if (edge_count[edge_key(s,n)] == 1) and (edge_key(o,n) in edge_count) and (edge_count[edge_key(o,n)] == 1): # must be single hole 479 | 480 | # accept the new face 481 | found = True 482 | new_f = [a,b,n] 483 | new_faces.append(new_f) 484 | add_face(new_f) 485 | 486 | if any_changed: 487 | # if we found a single hole, look for more 488 | continue 489 | 490 | for e in start_edges: 491 | if edge_count[e] == 1: 492 | a,b = e 493 | found = False 494 | 495 | # Look for matching edge 496 | for s in [a,b]: # one of the verts in this edge 497 | if found: break # quit once found 498 | o = b if s == a else a # the other vert in this edge 499 | for n in neighbors[s]: # a candidate third vertex 500 | if found: break # quit once found 501 | if n == o: continue # must not be same as edge 502 | if face_exists([a,b,n]): continue # face must not exist 503 | if edge_count[edge_key(s,n)] == 1: # must be boundary edge 504 | 505 | # accept the new face 506 | found = True 507 | new_f = [a,b,n] 508 | new_faces.append(new_f) 509 | add_face(new_f) 510 | 511 | 512 | 513 | faces.extend(new_faces) 514 | 515 | return torch.tensor(faces, dtype=in_faces.dtype, device=in_faces.device) 516 | 517 | 518 | 519 | 520 | -------------------------------------------------------------------------------- /src/mini_mlp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | import utils 8 | import world 9 | 10 | 11 | class 
BatchNormLastDim(nn.Module): 12 | def __init__(self, s): 13 | super(BatchNormLastDim, self).__init__() 14 | self.s = s 15 | self.bn = nn.BatchNorm1d(s) 16 | 17 | def forward(self, x): 18 | init_dim = x.shape 19 | if init_dim[-1] != self.s: 20 | raise ValueError("batch norm last dim does not have right shape. should be {}, but is {}".format(self.s, init_dim[-1])) 21 | 22 | x_flat = x.view((-1, self.s)) 23 | bn_flat = self.bn(x_flat) 24 | return bn_flat.view(*init_dim) 25 | 26 | class MiniMLP(nn.Sequential): 27 | 28 | def __init__( 29 | self, 30 | layer_sizes, 31 | name='miniMLP', 32 | activation=nn.ReLU, 33 | batch_norm=True, 34 | skip_last_norm=False, 35 | layer_norm=False, 36 | dropout=False, 37 | skip_first_dropout=False, 38 | ): 39 | super(MiniMLP, self).__init__() 40 | 41 | for i in range(len(layer_sizes) - 1): 42 | 43 | is_last = (i+2 == len(layer_sizes)) 44 | 45 | if dropout: 46 | 47 | if i > 0 or not skip_first_dropout: 48 | 49 | self.add_module( 50 | name + "_mlp_layer_dropout_{:03d}".format(i), 51 | nn.Dropout() 52 | ) 53 | 54 | 55 | # Affine map 56 | self.add_module( 57 | name + "_mlp_layer_{:03d}".format(i), 58 | nn.Linear( 59 | layer_sizes[i], 60 | layer_sizes[i + 1], 61 | ), 62 | ) 63 | 64 | # Maybe batch_norm 65 | # (but maybe not on the last layer) 66 | if batch_norm: 67 | if (not skip_last_norm) or (not is_last): 68 | self.add_module( 69 | name + "_mlp_batch_norm_{:03d}".format(i), 70 | BatchNormLastDim(layer_sizes[i+1]) 71 | ) 72 | 73 | # Maybe layer norm 74 | # (but maybe not on the last layer) 75 | if layer_norm: 76 | if (not skip_last_norm) or (not is_last): 77 | self.add_module( 78 | name + "_mlp_layer_norm_{:03d}".format(i), 79 | nn.LayerNorm(layer_sizes[i+1]) 80 | ) 81 | 82 | # Nonlinearity 83 | # (but not on the last layer) 84 | if activation is not None and not is_last: 85 | self.add_module( 86 | name + "_mlp_act_{:03d}".format(i), 87 | activation() 88 | ) 89 | 90 | 
-------------------------------------------------------------------------------- /src/point_tri_net.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import sklearn.neighbors 8 | 9 | import mesh_utils 10 | import utils 11 | import knn 12 | 13 | from mini_mlp import MiniMLP 14 | 15 | 16 | # Encode points with respect to a triangle. For each input point generates (x,y,z,u,v,w) with respect to triangle. 17 | # Inputs: 18 | # - points_pos (B, Q, K, 3) positions 19 | # - query_triangles_pos (B, Q, 3, 3) corner positions 20 | # Outputs: 21 | # (B, Q, N, 6) 22 | def generate_coords(points_pos, query_triangles_pos): 23 | 24 | EPS = 1e-6 25 | 26 | # First, compute and remove the normal component 27 | area_normals = 0.5 * torch.cross( 28 | query_triangles_pos[:, :, 1, :] - query_triangles_pos[:, :, 0, :], 29 | query_triangles_pos[:, :, 2, :] - query_triangles_pos[:, :, 0, :], dim=-1) 30 | 31 | areas = utils.norm(area_normals) + EPS # (B, Q) 32 | normals = area_normals / areas.unsqueeze(-1) # (B, Q, 3) 33 | barycenters = torch.mean(query_triangles_pos, dim=2) # (B, Q, 3) 34 | centered_neighborhood = points_pos - barycenters.unsqueeze(2) 35 | normal_comp = utils.dot(normals.unsqueeze(2), centered_neighborhood) 36 | neighborhood_planar = points_pos - normals.unsqueeze(2) * normal_comp.unsqueeze(-1) 37 | 38 | # Compute barycentric coordinates in plane 39 | def coords_i(i): 40 | point_area = 0.5 * utils.dot( 41 | normals.unsqueeze(2), 42 | torch.cross( 43 | query_triangles_pos[:, :, (i+1) % 3, :].unsqueeze(2) - neighborhood_planar, 44 | query_triangles_pos[:, :, (i+2) % 3, :].unsqueeze(2) - neighborhood_planar, 45 | dim=-1) 46 | ) 47 | 48 | area_frac = (point_area + EPS / 3.) / areas.unsqueeze(-1) 49 | return area_frac 50 | 51 | BARY_MAX = 5. 
52 | u = torch.clamp(coords_i(0), -BARY_MAX, BARY_MAX) 53 | v = torch.clamp(coords_i(1), -BARY_MAX, BARY_MAX) 54 | w = torch.clamp(coords_i(2), -BARY_MAX, BARY_MAX) 55 | 56 | # Compute cartesian coordinates with the x-axis along the i --> j edge 57 | basisX = utils.normalize(query_triangles_pos[:, :, 1, :] - query_triangles_pos[:, :, 0, :]) 58 | basisY = utils.normalize(torch.cross(normals, basisX)) 59 | x_comp = utils.dot(basisX.unsqueeze(2), centered_neighborhood) 60 | y_comp = utils.dot(basisY.unsqueeze(2), centered_neighborhood) 61 | 62 | coords = torch.stack((x_comp, y_comp, normal_comp, u, v, w), dim=-1) 63 | 64 | return coords 65 | 66 | 67 | # Inputs: 68 | # - query_triangles_pos (B, Q, 3, 3) 69 | # - nearby_points_pos (B, Q, K, 3) 70 | # - nearby_triangle_pos (B, Q, K_T, 3, 3) 71 | # - nearby_triangle_probs (B, Q, K_T) 72 | def encode_points_and_triangles(query_triangles_pos, nearby_points_pos, 73 | nearby_triangles_pos=None, nearby_triangle_probs=None): 74 | 75 | B = query_triangles_pos.shape[0] 76 | Q = query_triangles_pos.shape[1] 77 | K = nearby_points_pos.shape[2] 78 | 79 | have_triangles = (nearby_triangles_pos is not None) 80 | if have_triangles: 81 | K_T = nearby_triangles_pos.shape[2] 82 | 83 | # Normalize neighborhood (translation won't matter, but unit scale is nice) 84 | # note that we normalize vs. the triangle, not vs. 
the points 85 | neigh_centers = torch.mean(query_triangles_pos, dim=2) # (B, Q, 3) 86 | neigh_scales = torch.mean(utils.norm(query_triangles_pos - neigh_centers.unsqueeze(2)), dim=-1) + 1e-5 # (B, Q) 87 | nearby_points_pos = nearby_points_pos.clone() / neigh_scales.unsqueeze(-1).unsqueeze(-1) 88 | query_triangles_pos = query_triangles_pos.clone() / neigh_scales.unsqueeze(-1).unsqueeze(-1) 89 | if have_triangles: 90 | nearby_triangles_pos = nearby_triangles_pos.clone() / neigh_scales.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 91 | 92 | # Encode the nearby points 93 | point_coords = generate_coords(nearby_points_pos, query_triangles_pos) 94 | 95 | # Encode the nearby triangles 96 | if have_triangles: 97 | tri_coords = generate_coords(nearby_triangles_pos.view(B, Q, K_T*3, 3), query_triangles_pos).view(B, Q, K_T, 3, 6) 98 | max_vals = torch.max(tri_coords, dim=3).values # (B, Q, K_T, 6) 99 | min_vals = torch.min(tri_coords, dim=3).values # (B, Q, K_T, 6) 100 | triangle_coords = torch.cat((min_vals, max_vals, nearby_triangle_probs.unsqueeze(-1)), dim=-1) 101 | 102 | if have_triangles: 103 | return point_coords, triangle_coords 104 | else: 105 | return point_coords 106 | 107 | # Given probabilities for points 108 | # Input: 109 | # - `query_triangles` (B, Q, 3) indices for the three vertices of triangle 110 | # - `neighborhoods` (B, Q, K) indices of the neighbors for each query triangle 111 | # - `point_probs` (B, Q, K) in [0,1] probs for each point connecting to the side 112 | # Output: 113 | # - (B, Q, O, 3) triangle inds in to points list 114 | # - (B, Q, O) probs for each triangle 115 | def sample_neighbor_tris_from_point(query_triangles, neighborhoods, point_probs, n_output_per_side, random_rate=0.): 116 | B = point_probs.shape[0] 117 | Q = point_probs.shape[1] 118 | K = point_probs.shape[2] 119 | O = n_output_per_side 120 | 121 | # Zero probs for points which appear in the triangles 122 | zeros = torch.zeros_like(point_probs) 123 | sample_probs = point_probs + 
.0001 # avoid zeros in multinomial 124 | sample_probs = random_rate * torch.mean(sample_probs, dim=-1, keepdim=True) + (1. - random_rate) * sample_probs # support random_rate option 125 | for i in range(3): 126 | remove_pts = query_triangles[:, :, i] 127 | mask = (neighborhoods == remove_pts.unsqueeze(-1)) 128 | sample_probs = torch.where(mask, zeros, sample_probs) 129 | 130 | # indexed local to this neighborhood 131 | new_neigh_inds = torch.zeros((B, Q, O), dtype=neighborhoods.dtype, device=neighborhoods.device) 132 | 133 | # loop, multinomial doesn't broadcast (could use view?) 134 | for b in range(B): 135 | new_neigh_inds[b,:,:] = torch.multinomial(sample_probs[b,:,:], num_samples=O, replacement=False) 136 | 137 | # Global-index the sampled neighbors and gather data 138 | new_inds = torch.gather(neighborhoods, 2, new_neigh_inds) # (B, Q, O) global indexed 139 | 140 | tri_verts = torch.stack(( 141 | query_triangles[:, :, 0].unsqueeze(2).expand(-1, -1, O), 142 | query_triangles[:, :, 1].unsqueeze(2).expand(-1, -1, O), 143 | new_inds 144 | ), dim=3) 145 | 146 | # Likelihood for the points that were sampled (use actual original likelihood, rather than sample likelihood) 147 | new_probs = torch.gather(point_probs[:,:,:], 2, new_neigh_inds) 148 | 149 | 150 | # pull sliiiightly away from 0/1 for numerical stability downstream 151 | EPS = 1e-4 152 | new_probs = (1. 
- EPS) * new_probs + EPS * 0.5 153 | 154 | return tri_verts, new_probs 155 | 156 | 157 | 158 | class PointTriNet(torch.nn.Module): 159 | def __init__(self, input_dim=3): 160 | super(PointTriNet, self).__init__() 161 | 162 | use_batch_norm = False 163 | use_layer_norm = False 164 | activation = nn.ReLU 165 | 166 | # == The MLPs for the pointnet 167 | 168 | self.input_dim = input_dim 169 | 170 | # classification net 171 | self.neigh_point_feat_class_net = MiniMLP([6 + (self.input_dim-3), 64, 128, 1024], activation=activation, batch_norm=False, layer_norm=False) 172 | g_dim = 1024 173 | self.neigh_tri_feat_class_net = MiniMLP([13, 64, 128, 1024], activation=activation, batch_norm=False, layer_norm=False) 174 | g_dim += 1024 175 | 176 | self.global_feat_class_net = MiniMLP([g_dim, 512, 256, 1], activation=activation, batch_norm=use_batch_norm, layer_norm=use_layer_norm, skip_last_norm=True, dropout=True) 177 | 178 | # suggestion net 179 | self.neigh_point_feat_sugg_net = MiniMLP([6 + (self.input_dim-3), 64, 64, 128], activation=activation, batch_norm=False, layer_norm=False) 180 | self.point_sugg_net = MiniMLP([6 + (self.input_dim-3) + 128, 128, 64, 64, 1], activation=activation, batch_norm=use_batch_norm, layer_norm=use_layer_norm, skip_last_norm=True) 181 | 182 | # Input: 183 | # - `verts` (B, V, 3) ALL vertices for each shape 184 | # - `all_triangle_pos` (B, F, 3, 3) positions for the three vertices of ALL triangles 185 | # - `all_triangle_prob` (B, F) current probs for ALL triangles 186 | # - `query_triangle_ind` (B, Q, 3) indices for the three vertices of triangle, ordered arbitrarily 187 | # - `query_triangle_prob` (B, Q) current probs for the triangles used as queries 188 | # - `point_neighbor_ind` (B, Q, K) indices of the neighboring points for each query triangle 189 | # - `face_neighbor_ind` (B, Q, K_T) indices of the neighboring triangles for each query triangle 190 | # 191 | # Output: 192 | def forward(self, verts, all_triangle_pos, all_triangle_prob, 
query_triangle_pos, query_triangle_ind, query_triangle_prob, point_neighbor_ind, face_neighbor_ind, preds_per_side): 193 | 194 | if not hasattr(self, 'input_dim'): 195 | self.input_dim = 3 196 | 197 | D = query_triangle_ind.device 198 | DT = verts.dtype 199 | B = point_neighbor_ind.shape[0] 200 | Q = point_neighbor_ind.shape[1] 201 | K = point_neighbor_ind.shape[2] 202 | K_T = face_neighbor_ind.shape[2] 203 | 204 | ## Gather data about neighborhood 205 | 206 | point_neighbor_pos = torch.gather( 207 | verts.unsqueeze(-2).expand(-1, -1, K, -1), 1, 208 | point_neighbor_ind.unsqueeze(-1).expand(-1, -1, -1, self.input_dim) 209 | ) # (B, Q, K, 3) 210 | 211 | face_neighbor_probs = torch.gather( 212 | all_triangle_prob.unsqueeze(-1).expand(-1, -1, K_T), 1, 213 | face_neighbor_ind 214 | ) # (B, Q, K_T) 215 | 216 | face_neighbor_pos = torch.gather( 217 | all_triangle_pos.unsqueeze(2).expand(-1, -1, K_T, -1, -1), 1, 218 | face_neighbor_ind.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, 3, 3) 219 | ) # (B, Q, K_T, 3, 3) 220 | 221 | # ========================== 222 | # === Classification 223 | # ========================== 224 | 225 | # Generate coordinates 226 | point_neighbor_coords, face_neighbor_coords = \ 227 | encode_points_and_triangles(query_triangle_pos, point_neighbor_pos[...,:3], face_neighbor_pos, face_neighbor_probs) 228 | 229 | point_neighbor_coords = torch.cat((point_neighbor_coords, point_neighbor_pos[...,3:]), dim=-1) # restore optional latent data 230 | 231 | 232 | # Evaluate the pointnet for the point neighbors of each query 233 | point_features = self.neigh_point_feat_class_net(point_neighbor_coords) # (B, Q, K, feat) 234 | point_features_max = torch.max(point_features, dim=2).values # (B, Q, feat) 235 | 236 | # Evaluate the pointnet for the face neighbors of each query 237 | tri_features = self.neigh_tri_feat_class_net(face_neighbor_coords) # (B, Q, K_T, feat) 238 | tri_features_max = torch.max(tri_features, dim=2).values # (B, Q, feat) 239 | 240 | # Combine 
and take the max 241 | max_features = torch.cat((point_features_max, tri_features_max), dim=-1) #(B, Q, 2*feat) 242 | 243 | # get global features for each output (from both heads) 244 | global_features_class = self.global_feat_class_net(max_features) # (B, Q, 1) 245 | 246 | # probabilities from classification head 247 | output_probs = torch.sigmoid(global_features_class.squeeze(-1)) 248 | 249 | # pull sliiiightly away from 0/1 for numerical stability downstream 250 | EPS = 1e-4 251 | output_probs = (1. - EPS) * output_probs + EPS * 0.5 252 | 253 | # happens very rarely due to geometric degeneracies 254 | output_probs = torch.where(torch.isnan(output_probs), 255 | torch.mean(output_probs[~torch.isnan(output_probs)]), output_probs) 256 | 257 | # ========================== 258 | # === Classification 259 | # ========================== 260 | 261 | # repeat for each of the 3 orientaitons 262 | gen_tris_list = [] 263 | gen_probs_list = [] 264 | for i in range(3): 265 | 266 | # permute the encoded values, if needed 267 | # (for the first iteraiton we can just reuse the values from classification prediction) 268 | if i != 0: 269 | query_triangle_pos = query_triangle_pos.roll(1, dims=2) 270 | query_triangle_ind = query_triangle_ind.roll(1, dims=2) 271 | 272 | # Generate coordinates 273 | point_neighbor_coords = encode_points_and_triangles(query_triangle_pos, point_neighbor_pos[...,:3]) 274 | point_neighbor_coords = torch.cat((point_neighbor_coords, point_neighbor_pos[...,3:]), dim=-1) # restore optional latent data 275 | 276 | 277 | # Evaluate the pointnet for the point neighbors of each query 278 | point_features = self.neigh_point_feat_sugg_net(point_neighbor_coords) # (B, Q, K, feat) 279 | point_features_max = torch.max(point_features, dim=2).values # (B, Q, feat) 280 | 281 | # use these global features to make point predictions 282 | point_select_inputs = torch.cat(( 283 | point_neighbor_coords, 284 | point_features_max.unsqueeze(2).expand(-1, -1, K, -1) 285 | ), dim=-1) 
# (B, Q, K, 6 + g_feat) 286 | 287 | # generate selection scores at each point 288 | point_probs = torch.sigmoid(self.point_sugg_net(point_select_inputs).squeeze(-1)) # (B, Q, K) 289 | 290 | 291 | # Sample new triangles from the vertex likelihoods 292 | random_rate = 0.25 if self.training else 0. # during training, chose random with small prob to get more divesity 293 | gen_tris, gen_probs = sample_neighbor_tris_from_point( 294 | query_triangle_ind, point_neighbor_ind, point_probs, preds_per_side, random_rate) 295 | 296 | 297 | # modulate point probs by the triangle which generated them 298 | gen_probs *= output_probs.unsqueeze(-1) 299 | 300 | gen_tris_list.append(gen_tris) 301 | gen_probs_list.append(gen_probs) 302 | 303 | gen_tris = torch.cat(gen_tris_list, dim=2) 304 | gen_probs = torch.cat(gen_probs_list, dim=2) 305 | 306 | # Collapse all of the new candidates from all of the query triangles 307 | gen_tris = gen_tris.view(B, -1, 3) 308 | gen_probs = gen_probs.view(B, -1) 309 | 310 | return output_probs, gen_tris, gen_probs 311 | 312 | 313 | # Input: 314 | # - query_triangle_ind: (B, Q, 3) indices 315 | # - verts: (B, V, 3) positions 316 | # - query_triangle_ind: (B, Q) in [0,1] 317 | 318 | # - as an optimization, neighbors can be passed in if we already have them 319 | # 320 | # Output: 321 | # - (T) \in [0,1] new likelihood for each face 322 | def apply_to_candidates(self, query_triangle_ind, verts, query_probs, new_verts_per_edge, k_neigh=64, neighbors_method='generate', return_list=False, split_size=1024*4): 323 | B = query_triangle_ind.shape[0] 324 | Q = query_triangle_ind.shape[1] 325 | V_D = verts.shape[-1] 326 | D = verts[0].device 327 | K = k_neigh 328 | K_T = min(k_neigh, Q-1) 329 | 330 | query_triangles_pos = torch.gather( 331 | verts[...,:3].unsqueeze(-2).expand(-1, -1, 3, -1), 1, 332 | query_triangle_ind.unsqueeze(-1).expand(-1, -1, -1, 3) 333 | ) # (B, Q, 3, 3) 334 | 335 | barycenters = torch.mean(query_triangles_pos, dim=2) 336 | 337 | 338 | # 
Manage devices in the case where we are leaving data CPU-side 339 | input_device = verts.device 340 | model_device = next(self.parameters()).device 341 | query_triangles_pos_d = query_triangles_pos.to(model_device) 342 | query_probs_d = query_probs.to(model_device) 343 | method = 'brute' if (verts[0].is_cuda and Q < 4096) else 'cpu_kd' 344 | if method == 'cpu_kd': 345 | # pre-build a tree just once for CPU lookups 346 | kd_tree_verts = [sklearn.neighbors.KDTree(utils.toNP(verts[b,...,:3])) for b in range(B)] 347 | kd_tree_bary = [sklearn.neighbors.KDTree(utils.toNP(barycenters[b,...])) for b in range(B)] 348 | else: 349 | kd_tree_verts = [None for b in range(B)] 350 | kd_tree_bary = [None for b in range(B)] 351 | 352 | # (during training, this should hopefully leave a single chunk, so we get batch statistics 353 | query_triangle_ind_chunks = torch.split(query_triangle_ind, split_size, dim=1) 354 | query_triangle_pos_chunks = torch.split(query_triangles_pos, split_size, dim=1) 355 | query_triangle_prob_chunks = torch.split(query_probs, split_size, dim=1) 356 | 357 | # Apply the model 358 | pred_chunks = [] 359 | gen_tri_chunks = [] 360 | gen_pred_chunks= [] 361 | for i_chunk in range(len(query_triangle_ind_chunks)): 362 | if(len(query_triangle_ind_chunks) > 1): 363 | print("chunk {}/{}".format(i_chunk, len(query_triangle_ind_chunks))) 364 | 365 | query_triangle_ind_chunk = query_triangle_ind_chunks[i_chunk] 366 | query_triangle_pos_chunk = query_triangle_pos_chunks[i_chunk] 367 | query_triangle_prob_chunk = query_triangle_prob_chunks[i_chunk] 368 | 369 | Q_C = query_triangle_ind_chunk.shape[1] 370 | barycenters_chunk = torch.mean(query_triangle_pos_chunk, dim=2) 371 | 372 | # Gather neighborhoods of each candidate face 373 | 374 | # Build out neighbors 375 | point_neighbor_inds = torch.zeros((B, Q_C, K), device=D, dtype=query_triangle_ind.dtype) 376 | face_neighbor_inds = torch.zeros((B, Q_C, K_T), device=D, dtype=query_triangle_ind.dtype) 377 | 378 | for b in 
range(B): 379 | 380 | _, point_neighbor_inds_this = knn.find_knn(barycenters_chunk[b,...], verts[b,...,:3], k=K, method=method, prebuilt_tree=kd_tree_verts[b]) 381 | point_neighbor_inds[b,...] = point_neighbor_inds_this 382 | 383 | _, face_neighbor_inds_this = knn.find_knn(barycenters_chunk[b,...], barycenters[b,...], k=K_T+1, method=method, omit_diagonal=False, prebuilt_tree=kd_tree_bary[b]) 384 | face_neighbor_inds_this = face_neighbor_inds_this[...,1:] # remove self overlap 385 | face_neighbor_inds[b,...] = face_neighbor_inds_this 386 | 387 | 388 | # Invoke the model 389 | 390 | output_preds_chunk, gen_tri_chunk, gen_pred_chunk = \ 391 | self(verts.to(model_device), 392 | query_triangles_pos_d, 393 | query_probs_d, 394 | query_triangle_pos_chunk.to(model_device), 395 | query_triangle_ind_chunk.to(model_device), 396 | query_triangle_prob_chunk.to(model_device), 397 | point_neighbor_inds.to(model_device), 398 | face_neighbor_inds.to(model_device), 399 | new_verts_per_edge) 400 | 401 | 402 | output_preds_chunk = output_preds_chunk.to(input_device) 403 | gen_tri_chunk = gen_tri_chunk.to(input_device) 404 | gen_pred_chunk = gen_pred_chunk.to(input_device) 405 | 406 | pred_chunks.append(output_preds_chunk) 407 | gen_tri_chunks.append(gen_tri_chunk) 408 | gen_pred_chunks.append(gen_pred_chunk) 409 | 410 | preds = torch.cat(pred_chunks, dim=1) 411 | gen_tris = torch.cat(gen_tri_chunks, dim=1) 412 | gen_preds = torch.cat(gen_pred_chunks, dim=1) 413 | 414 | return preds, gen_tris, gen_preds 415 | 416 | 417 | 418 | 419 | # Iteratively applies the PointTriNet to generate predictions 420 | class PointTriNet_Mesher(torch.nn.Module): 421 | def __init__(self, input_dim=3): 422 | super(PointTriNet_Mesher, self).__init__() 423 | 424 | self.net = PointTriNet(input_dim=input_dim) 425 | 426 | 427 | def predict_mesh(self, verts, verts_latent = None, n_rounds=5, keep_faces_per_vert=12, new_verts_per_edge=4, sample_last=False, return_all=False): 428 | 429 | B = verts.shape[0] 430 | 
n_keep = keep_faces_per_vert * verts.shape[1] 431 | 432 | # Seed with some random initial triangles 433 | candidate_triangles, candidate_probs = mesh_utils.generate_seed_triangles(verts[...,:3]) 434 | candidate_triangles, candidate_probs = mesh_utils.uniqueify_triangle_prob_batch(candidate_triangles, candidate_probs) 435 | 436 | if return_all: 437 | working_tris = [] 438 | working_probs = [] 439 | proposal_tris = [] 440 | proposal_probs = [] 441 | 442 | for iter in range(n_rounds): 443 | is_last = iter == n_rounds-1 444 | 445 | # only take gradients on last iter 446 | with (utils.fake_context() if is_last else torch.autograd.no_grad()): 447 | 448 | # Classify triangles & generate new ones 449 | new_candidate_probs, gen_tris, gen_probs = self.net.apply_to_candidates(candidate_triangles, verts, candidate_probs, new_verts_per_edge) 450 | candidate_probs = new_candidate_probs 451 | 452 | if return_all: 453 | working_tris.append(candidate_triangles) 454 | working_probs.append(candidate_probs) 455 | proposal_tris.append(gen_tris) 456 | proposal_probs.append(gen_probs) 457 | 458 | if (not is_last): 459 | 460 | # Union new candidates 461 | candidate_triangles = torch.cat((candidate_triangles, gen_tris), dim=1) 462 | candidate_probs = torch.cat((candidate_probs, gen_probs), dim=1) 463 | 464 | # Prune out repeats 465 | candidate_triangles, candidate_probs = mesh_utils.uniqueify_triangle_prob_batch(candidate_triangles, candidate_probs) 466 | 467 | # Cull low-probability triangles 468 | candidate_triangles, candidate_probs = mesh_utils.filter_low_prob_triangles( 469 | candidate_triangles, candidate_probs, n_keep) 470 | 471 | if return_all: 472 | return working_tris, working_probs, proposal_tris, proposal_probs 473 | 474 | if sample_last: 475 | # Prune out repeats amongst last samples 476 | gen_tris, gen_probs = mesh_utils.uniqueify_triangle_prob_batch(gen_tris, gen_probs) 477 | 478 | return candidate_triangles, candidate_probs, gen_tris, gen_probs 479 | else: 480 | return 
candidate_triangles, candidate_probs 481 | -------------------------------------------------------------------------------- /src/train_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import gc 4 | 5 | import numpy as np 6 | import torch 7 | 8 | import world 9 | import utils 10 | import data_utils 11 | from utils import toNP 12 | 13 | class MyTrainer(object): 14 | 15 | def __init__( 16 | self, args, model, call_model_fn, loss_fn, train_loader, val_loader, collate_fn=None 17 | ): 18 | 19 | # Copy parameters 20 | self.args = args 21 | self.model = model 22 | self.loss_fn = loss_fn 23 | self.call_model_fn = call_model_fn 24 | self.train_loader = train_loader 25 | self.val_loader = val_loader 26 | self.collate_fn = collate_fn 27 | 28 | # Some extra training state 29 | self.best_loss = float("inf") 30 | self.curr_epoch = 0 31 | self.curr_iter = 0 32 | self.training = True # false == eval 33 | self.eval_iter = 0 # just used as a number for tensorboard logs 34 | 35 | # Stats 36 | self.running_train_loss = 0. 
37 | self.running_train_loss_count = 0 38 | self.eval_stats = {} # a dictionary holding lists of stats to track for the eval loop 39 | 40 | # === Utilities 41 | def run_dir(self): 42 | return os.path.join(self.args.run_dir) 43 | 44 | def save_dir(self): 45 | return os.path.join(self.run_dir(), "saved") 46 | 47 | # === Save subroutine 48 | def save_training_state(self, opt=None, suffix=""): 49 | 50 | # Paths 51 | this_save_dir = os.path.join(self.save_dir(), suffix) 52 | print(" --> saving model to {}".format(this_save_dir)) 53 | utils.ensure_dir_exists(this_save_dir) 54 | 55 | # Serialize all the things 56 | torch.save(self.model, os.path.join(this_save_dir, "model.pth")) 57 | torch.save( 58 | self.model.state_dict(), os.path.join(this_save_dir, "model_state_dict.pth") 59 | ) 60 | if opt is not None: 61 | torch.save( 62 | opt.state_dict(), os.path.join(this_save_dir, "opt_state_dict.pth") 63 | ) 64 | torch.save(self.args, os.path.join(this_save_dir, "args.pth")) 65 | torch.save( 66 | { 67 | 'curr_iter': self.curr_iter, 68 | 'curr_epoch': self.curr_epoch, 69 | 'best_loss': self.best_loss, 70 | }, 71 | os.path.join(this_save_dir, "train_state.pth")) 72 | 73 | with open(os.path.join(this_save_dir, "args.txt"), "w") as text_file: 74 | text_file.write(world.args_to_str(world.args)) 75 | 76 | def train(self): 77 | 78 | # Make sure the model is where it belongs 79 | if world.device == torch.device("cpu"): 80 | self.model.cpu() 81 | else: 82 | self.model.cuda(world.device) 83 | 84 | # === Basic default optimization parameters 85 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr) 86 | 87 | # Learning rate schedule 88 | def lr_lbmd(it): 89 | lr_clip = 1e-5 90 | return max( 91 | self.args.lr_decay ** (int(it / self.args.decay_step)), 92 | lr_clip / self.args.lr, 93 | ) 94 | 95 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 96 | optimizer, lr_lambda=lr_lbmd, last_epoch=-1 97 | ) 98 | 99 | # === Epoch loop 100 | while self.curr_epoch < 
self.args.epochs: 101 | total_train_loss = 0. 102 | total_train_loss_count = 0 103 | 104 | print("\n\n=== Epoch {} / {}".format(self.curr_epoch, self.args.epochs)) 105 | 106 | ib_count = 0 107 | for batch in self.train_loader: 108 | 109 | self.model.train() 110 | self.training = True 111 | 112 | # Zero gradients 113 | if ib_count >= self.args.batch_accum: 114 | optimizer.zero_grad() 115 | ib_count = 0 116 | ib_count += 1 117 | 118 | # Invoke the model 119 | model_outputs = self.call_model_fn(self.model, batch, trainer=self) 120 | 121 | # Evaluate loss 122 | loss = self.eval_loss(model_outputs, batch) 123 | 124 | # Get gradients 125 | loss.backward() 126 | 127 | # Step the optimizer 128 | if ib_count >= self.args.batch_accum: 129 | optimizer.step() 130 | 131 | # Step schedules 132 | if lr_scheduler is not None: 133 | lr_scheduler.step(self.curr_iter) 134 | 135 | # Evaluate 136 | if self.curr_iter % self.args.eval_every == 0 and self.curr_iter > 0: 137 | self.evaluate_on_val() 138 | 139 | this_train_loss = loss.item() 140 | total_train_loss += this_train_loss 141 | total_train_loss_count += 1 142 | 143 | self.curr_iter += self.args.batch_size 144 | 145 | # Update states 146 | self.running_train_loss += this_train_loss 147 | self.running_train_loss_count += 1 148 | 149 | 150 | # Always evaluate at end of epoch 151 | epoch_loss = self.evaluate_on_val(viz=True) 152 | mean_train_loss = total_train_loss / total_train_loss_count 153 | print("\n") 154 | print(" epoch {} [it: {}]: eval loss = {} train loss = {}".format(self.curr_epoch, self.curr_iter, epoch_loss.item(), mean_train_loss)) 155 | print(" parameters: lr = {0:.10f}".format(lr_lbmd(self.curr_iter) * self.args.lr)) 156 | 157 | self.save_training_state(opt=optimizer, suffix="epoch{:03d}".format(self.curr_epoch)) 158 | 159 | self.curr_epoch += 1 160 | 161 | # Evaluate the loss for a batch, distributing over batches independently 162 | def eval_loss(self, model_outputs, batch, viz_extra=False): 163 | 164 | # Evaluate 
loss 165 | batch_loss = torch.tensor(0.0, device=world.device) 166 | batch_size = self.args.batch_size 167 | 168 | # Iterate through the batch, invoking the loss 169 | mean_batch_loss = self.loss_fn(batch, model_outputs, viz_extra=viz_extra, trainer=self) 170 | 171 | return mean_batch_loss 172 | 173 | # Add an entry to the statistic sets tracked during evaluation 174 | # value should be a plain float 175 | def add_eval_stat_entry(self, name, value): 176 | if name not in self.eval_stats: 177 | self.eval_stats[name] = [] 178 | 179 | self.eval_stats[name].append(value) 180 | 181 | # Evaluate loss over (a subset of) the validation dataset and report results 182 | def evaluate_on_val(self, viz=False): 183 | self.model.eval() 184 | self.training = False 185 | self.eval_iter = self.curr_iter 186 | self.eval_stats = {} # clear it 187 | 188 | with torch.no_grad(): 189 | 190 | total_loss = torch.tensor(0.0, device=world.device) 191 | 192 | eval_count = 0 193 | for batch_ind, batch in enumerate(self.val_loader): 194 | 195 | # Only evaluate on the first eval_size entries 196 | if batch_ind * self.args.batch_size >= self.args.eval_size: 197 | break 198 | 199 | # Invoke the model 200 | model_outputs = self.call_model_fn(self.model, batch, trainer=self) 201 | 202 | # Evaluate loss 203 | total_loss += self.eval_loss(model_outputs, batch, viz_extra=viz) 204 | eval_count += 1 205 | self.eval_iter += self.args.batch_size 206 | 207 | total_loss /= eval_count 208 | 209 | world.tb_writer.add_scalar("evaluate loss", total_loss, self.curr_iter) 210 | print(" evaluation [it: {}]: loss = {} train loss since last = {}".format( 211 | self.curr_iter, total_loss.item(), self.running_train_loss / (self.running_train_loss_count+1e-4))) 212 | 213 | # Print the mean of any tracked statistics 214 | for name in self.eval_stats: 215 | val = np.mean(self.eval_stats[name]) 216 | print(" {} : {}".format(name, val)) 217 | 218 | 219 | self.running_train_loss = 0. 
220 | self.running_train_loss_count = 0 221 | 222 | if total_loss < self.best_loss: 223 | self.best_loss = total_loss 224 | self.save_training_state(suffix="best") 225 | 226 | self.save_training_state(suffix="last") 227 | 228 | return total_loss 229 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch 6 | import numpy as np 7 | import meshio 8 | import igl 9 | 10 | # === Argument management helpers 11 | 12 | def set_args_defaults(args): 13 | 14 | # Manage cuda config 15 | if (not args.disable_cuda and not torch.cuda.is_available()): 16 | print("!!! WARNING: CUDA requested but not available!") 17 | 18 | if (not args.disable_cuda and torch.cuda.is_available()): 19 | args.device = torch.device('cuda:0') 20 | args.dtype = torch.float32 21 | torch.set_default_dtype(args.dtype) 22 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 23 | print("CUDA enabled :)") 24 | else: 25 | args.device = torch.device('cpu') 26 | args.dtype = torch.float32 27 | torch.set_default_dtype(args.dtype) 28 | torch.set_default_tensor_type(torch.FloatTensor) 29 | print("CUDA disabled :(") 30 | 31 | 32 | # === Misc value conversion 33 | 34 | # Really, definitely convert a torch tensor to a numpy array 35 | def toNP(x): 36 | return x.detach().to(torch.device('cpu')).numpy() 37 | 38 | class fake_context(): 39 | def __enter__(self): 40 | return None 41 | def __exit__(self, _1, _2, _3): 42 | return False 43 | 44 | 45 | # === File helpers 46 | def ensure_dir_exists(d): 47 | if not os.path.exists(d): 48 | os.makedirs(d) 49 | 50 | def read_mesh(f): 51 | return igl.read_triangle_mesh(f) 52 | 53 | # === Geometric helpers in pytorch 54 | 55 | 56 | # Computes norm of an array of vectors. 
# Given (shape, d), returns (shape) after taking the norm along the last dimension.


def _check_vector_array(x, highdim, fn_name):
    """Validate the input of the vector-array helpers below.

    Raises ValueError when ``x`` is a bare 1-D vector, or -- unless
    ``highdim`` is True -- when its last dimension is larger than 4 (the
    guard treats that as a probable mistake, hence the "are you sure?").
    ``fn_name`` is the public helper's name, so the message names the real
    caller.
    """
    if len(x.shape) == 1:
        raise ValueError("called " + fn_name + "() on single vector of dim " + str(x.shape) + " are you sure?")
    if not highdim and x.shape[-1] > 4:
        raise ValueError("called " + fn_name + "() with large last dimension " + str(x.shape) + " are you sure?")


def norm(x, highdim=False):
    """Euclidean norm of an array of vectors along the last dimension.

    Args:
        x: tensor of shape (shape, d); the last axis holds vector components.
        highdim: allow a last dimension larger than 4.

    Returns:
        Tensor of shape (shape).

    Raises:
        ValueError: if ``x`` is 1-D, or its last dimension exceeds 4 while
            ``highdim`` is False.
    """
    _check_vector_array(x, highdim, "norm")
    return torch.norm(x, dim=-1)


def norm2(x, highdim=False):
    """Squared Euclidean norm along the last dimension (``dot(x, x)``).

    Same argument checks and ValueError conditions as ``norm``; error
    messages now correctly name ``norm2()``.
    """
    _check_vector_array(x, highdim, "norm2")
    return dot(x, x)


# Normalizes an array of vectors along the last dimension
def normalize(x, divide_eps=1e-6, highdim=False):
    """Scale each vector along the last dimension to (approximately) unit length.

    Computes ``x / (|x| + divide_eps)``; the epsilon keeps zero vectors from
    dividing by zero (they map to zero vectors).

    Raises:
        ValueError: same conditions as ``norm``.
    """
    _check_vector_array(x, highdim, "normalize")
    return x / (norm(x, highdim=highdim) + divide_eps).unsqueeze(-1)


def face_coords(verts, faces):
    """Gather the vertex positions of every face via ``verts[faces]``.

    With an integer index tensor ``faces`` of shape (F, 3) and ``verts`` of
    shape (V, d), the result has shape (F, 3, d).
    """
    return verts[faces]


def face_barycenters(verts, faces):
    """Mean of each face's vertex positions (average over the per-face vertex axis)."""
    return torch.mean(face_coords(verts, faces), dim=-2)


def cross(vec_A, vec_B):
    """Cross product along the last dimension."""
    return torch.cross(vec_A, vec_B, dim=-1)


def dot(vec_A, vec_B):
    """Dot product along the last dimension (elementwise product, then sum)."""
    return torch.sum(vec_A * vec_B, dim=-1)


# Given (..., 3) vectors and normals, projects out any components of vecs which lies in the direction of normals. Normals are assumed to be unit.
def project_to_tangent(vecs, unit_normals):
    """Remove from ``vecs`` the component lying along ``unit_normals``.

    Intended for (..., 3) vectors and normals. The normals must already be
    unit length: the projection uses the raw dot product without dividing
    by the normal's squared length.
    """
    normal_part = dot(vecs, unit_normals).unsqueeze(-1) * unit_normals
    return vecs - normal_part


def face_normals(verts, faces, normalized=True):
    """Per-face normal from the cross product of the two edges leaving corner 0.

    Corners are gathered with ``face_coords`` and indexed as ``[:, k, :]``.
    The normal follows the corner order (edge 0->1 crossed with edge 0->2).
    When ``normalized`` is False, the raw cross product is returned; its
    length is twice the face area.
    """
    corners = face_coords(verts, faces)
    origin = corners[:, 0, :]
    raw_normal = cross(corners[:, 1, :] - origin, corners[:, 2, :] - origin)
    return normalize(raw_normal) if normalized else raw_normal


def face_area(verts, faces):
    """Per-face area: half the length of the unnormalized face normal."""
    return 0.5 * norm(face_normals(verts, faces, normalized=False))


# --------------------------------------------------------------------------------
# /src/world.py:
# --------------------------------------------------------------------------------
import sys
import torch

# Globals and other world state management.
# Other modules read and assign these as module attributes
# (e.g. ``world.device``, ``world.tb_writer``).

args = []  # global argument cache

device = None  # global torch default device
dtype = None

debug_checks = False  # debug checks toggle

tb_writer = None  # Tensorboard logger
tb_tick = 0

train_state = None


class ArgsObject:
    """Plain attribute container with no behaviour of its own."""
    pass


def args_to_str(args):
    """Render each attribute of ``args`` as an ``"attr: value"`` line, newline-joined."""
    return "\n".join(
        name + ": " + str(val)
        for name, val in args.__dict__.items()
        if name != ""
    )

# --------------------------------------------------------------------------------
# /teaser.gif:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/nmwsharp/learned-triangulation/12d970d9ce87a973b8aeb0d4ea47562ada578a45/teaser.gif
--------------------------------------------------------------------------------