├── LICENSE ├── README.md ├── environment.yml ├── .gitignore ├── train_origami.py ├── hodgeautograd.py ├── train_classification.py ├── train_segmentation.py ├── hodgenet.py ├── square.obj └── meshdata.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Dima Smirnov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## HodgeNet | [Webpage](https://people.csail.mit.edu/smirnov/hodgenet/) | [Paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459797) | [Video](https://youtu.be/juP0PHxvnx8) 2 | 3 | HodgeNet 4 | 5 | **HodgeNet: Learning Spectral Geometry on Triangle Meshes**
6 | Dmitriy Smirnov, Justin Solomon
7 | [SIGGRAPH 2021](https://s2021.siggraph.org/) 8 | 9 | ### Set-up 10 | To install the necessary dependencies, run: 11 | ``` 12 | conda env create -f environment.yml 13 | conda activate HodgeNet 14 | ``` 15 | 16 | ### Training 17 | To train the segmentation model, first download the [Shape COSEG dataset](http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm). Then, run: 18 | ``` 19 | python train_segmentation.py --out out_dir --mesh_path path_to_meshes --seg_path path_to_segs 20 | ``` 21 | 22 | To train the classification model, first download the SHREC 2011 dataset: 23 | ``` 24 | wget -O shrec.tar.gz "https://www.dropbox.com/s/4z4v1x30jsy0uoh/shrec.tar.gz?dl=1" 25 | mkdir -p data && tar -xvf shrec.tar.gz -C data 26 | ``` 27 | Then, run: 28 | ``` 29 | python train_classification.py --out out_dir 30 | ``` 31 | 32 | To train the dihedral angle stress test model, run: 33 | ``` 34 | python train_origami.py --out out_dir 35 | ``` 36 | 37 | To monitor the training, launch a TensorBoard instance with `--logdir out_dir`. 38 | 39 | To fine-tune a segmentation or classification model, add `--fine_tune path_to_checkpoint.pth` to its training command; training resumes from the epoch stored in that checkpoint (`train_origami.py` does not support fine-tuning). 
40 | 41 | ### BibTeX 42 | ``` 43 | @article{smirnov2021hodgenet, 44 | title={{HodgeNet}: Learning Spectral Geometry on Triangle Meshes}, 45 | author={Smirnov, Dmitriy and Solomon, Justin}, 46 | year={2021}, 47 | journal={SIGGRAPH} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: HodgeNet 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1 8 | - _openmp_mutex=4.5 9 | - blas=1.0 10 | - bzip2=1.0.8 11 | - ca-certificates=2021.5.30 12 | - certifi=2021.5.30 13 | - cudatoolkit=11.1.1 14 | - ffmpeg=4.3 15 | - freetype=2.10.4 16 | - gmp=6.2.1 17 | - gnutls=3.6.13 18 | - igl=2.2.1 19 | - intel-openmp=2021.3.0 20 | - jpeg=9b 21 | - lame=3.100 22 | - lcms2=2.12 23 | - ld_impl_linux-64=2.35.1 24 | - libblas=3.9.0 25 | - libcblas=3.9.0 26 | - libffi=3.3 27 | - libgcc-ng=9.3.0 28 | - libgfortran-ng=11.1.0 29 | - libgfortran5=11.1.0 30 | - libgomp=9.3.0 31 | - libiconv=1.16 32 | - liblapack=3.9.0 33 | - libpng=1.6.37 34 | - libstdcxx-ng=9.3.0 35 | - libtiff=4.2.0 36 | - libuv=1.42.0 37 | - libwebp-base=1.2.0 38 | - lz4-c=1.9.3 39 | - mkl=2021.3.0 40 | - mkl-service=2.4.0 41 | - mkl_fft=1.3.0 42 | - mkl_random=1.2.2 43 | - ncurses=6.2 44 | - nettle=3.6 45 | - ninja=1.10.2 46 | - numpy=1.20.3 47 | - numpy-base=1.20.3 48 | - olefile=0.46 49 | - openh264=2.1.1 50 | - openjpeg=2.4.0 51 | - openssl=1.1.1k 52 | - pillow=8.3.1 53 | - pip=21.1.3 54 | - python=3.9.6 55 | - python_abi=3.9 56 | - pytorch=1.9.0 57 | - readline=8.1 58 | - scipy=1.7.0 59 | - setuptools=52.0.0 60 | - six=1.16.0 61 | - sqlite=3.36.0 62 | - tk=8.6.10 63 | - torchvision=0.10.0 64 | - typing_extensions=3.10.0.0 65 | - tzdata=2021a 66 | - wheel=0.36.2 67 | - xz=5.2.5 68 | - zlib=1.2.11 69 | - zstd=1.4.9 70 | - pip: 71 | - absl-py==0.13.0 72 | - cachetools==4.2.2 73 | - charset-normalizer==2.0.4 74 | - 
google-auth==1.34.0 75 | - google-auth-oauthlib==0.4.5 76 | - grpcio==1.39.0 77 | - idna==3.2 78 | - markdown==3.3.4 79 | - oauthlib==3.1.1 80 | - protobuf==3.17.3 81 | - pyasn1==0.4.8 82 | - pyasn1-modules==0.2.8 83 | - requests==2.26.0 84 | - requests-oauthlib==1.3.0 85 | - rsa==4.7.2 86 | - tensorboard==2.5.0 87 | - tensorboard-data-server==0.6.1 88 | - tensorboard-plugin-wit==1.8.0 89 | - tqdm==4.62.0 90 | - trimesh==3.9.25 91 | - urllib3==1.26.6 92 | - werkzeug==2.0.1 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | out 132 | data 133 | -------------------------------------------------------------------------------- /train_origami.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from tqdm import tqdm 13 | 14 | from hodgenet import HodgeNetModel 15 | from meshdata import OrigamiDataset 16 | 17 | 18 | def main(args): 19 | torch.set_default_dtype(torch.float64) # needed for eigenvalue problems 20 | torch.manual_seed(0) 21 | np.random.seed(0) 22 | 23 | dataset = OrigamiDataset( 24 | edge_features_from_vertex_features=['vertices'], 25 | triangle_features_from_vertex_features=['vertices']) 26 | 27 | def mycollate(b): return b 28 | dataloader = DataLoader(dataset, batch_size=args.bs, 29 | num_workers=0, collate_fn=mycollate) 30 | 31 | example = dataset[0] 32 | hodgenet_model = HodgeNetModel( 33 | example['int_edge_features'].shape[1], 34 | example['triangle_features'].shape[1], 35 | num_output_features=args.n_out_features, mesh_feature=True, 36 | num_eigenvectors=args.n_eig, num_extra_eigenvectors=args.n_extra_eig, 37 | resample_to_triangles=False, 38 | 
num_bdry_edge_features=example['bdry_edge_features'].shape[1], 39 | num_vector_dimensions=args.num_vector_dimensions) 40 | 41 | origami_model = nn.Sequential( 42 | hodgenet_model, 43 | nn.Linear(args.n_out_features*args.num_vector_dimensions * 44 | args.num_vector_dimensions, 32), 45 | nn.LayerNorm(32), 46 | nn.LeakyReLU(), 47 | nn.Linear(32, 16), 48 | nn.LayerNorm(16), 49 | nn.LeakyReLU(), 50 | nn.Linear(16, 2)) 51 | 52 | optimizer = optim.AdamW(origami_model.parameters(), lr=args.lr) 53 | 54 | if not os.path.exists(args.out): 55 | os.makedirs(args.out) 56 | 57 | train_writer = SummaryWriter(os.path.join( 58 | args.out, datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')), 59 | flush_secs=1) 60 | 61 | def epoch_loop(dataloader, epochname, epochnum, writer, optimize=True): 62 | epoch_loss, epoch_size = 0, 0 63 | pbar = tqdm(total=len(dataloader), 64 | desc='{} {}'.format(epochname, epochnum)) 65 | for batchnum, batch in enumerate(dataloader): 66 | if optimize: 67 | optimizer.zero_grad() 68 | 69 | batch_loss = 0 70 | 71 | dirs = origami_model(batch) 72 | dirs = F.normalize(dirs, p=2, dim=-1) 73 | for mesh, dir_estimate in zip(batch, dirs): 74 | gt_dir = mesh['dir'].to(dir_estimate.device) 75 | batch_loss += 1 - (gt_dir * dir_estimate).sum(-1) 76 | 77 | batch_loss /= len(batch) 78 | 79 | pbar.set_postfix({ 80 | 'loss': batch_loss.item(), 81 | }) 82 | pbar.update(1) 83 | 84 | epoch_loss += batch_loss.item() 85 | epoch_size += 1 86 | 87 | if optimize: 88 | batch_loss.backward() 89 | nn.utils.clip_grad_norm_(origami_model.parameters(), 1) 90 | optimizer.step() 91 | 92 | writer.add_scalar('Loss', batch_loss.item(), 93 | epochnum*len(dataloader)+batchnum) 94 | 95 | pbar.close() 96 | 97 | for epoch in range(args.n_epochs): 98 | origami_model.train() 99 | epoch_loop(dataloader, 'Epoch', epoch, train_writer) 100 | 101 | torch.save({ 102 | 'origami_model_state_dict': origami_model.state_dict(), 103 | 'opt_state_dict': optimizer.state_dict(), 104 | 'epoch': epoch 105 | }, 
os.path.join(args.out, f'{epoch}.pth')) 106 | 107 | 108 | if __name__ == '__main__': 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('--out', type=str, default='out/origami') 111 | parser.add_argument('--bs', type=int, default=16) 112 | parser.add_argument('--lr', type=float, default=1e-3) 113 | parser.add_argument('--n_epochs', type=int, default=10000) 114 | parser.add_argument('--n_eig', type=int, default=32) 115 | parser.add_argument('--n_extra_eig', type=int, default=32) 116 | parser.add_argument('--n_out_features', type=int, default=32) 117 | parser.add_argument('--num_vector_dimensions', type=int, default=4) 118 | 119 | args = parser.parse_args() 120 | main(args) 121 | -------------------------------------------------------------------------------- /hodgeautograd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import scipy.sparse.linalg 4 | import torch 5 | import torch.multiprocessing 6 | 7 | 8 | def repeat_d_matrix(d, n_repeats): 9 | """Create block diagonal d matrix for vectorial operator.""" 10 | if n_repeats == 1: 11 | return d 12 | 13 | I, J, V = scipy.sparse.find(d) 14 | 15 | bigI = np.concatenate([I*n_repeats+idx for idx in range(n_repeats)]) 16 | bigJ = np.concatenate([J*n_repeats+idx for idx in range(n_repeats)]) 17 | bigV = np.tile(V, n_repeats) 18 | 19 | result = scipy.sparse.csr_matrix((bigV, (bigI, bigJ)), 20 | shape=(d.shape[0]*n_repeats, 21 | d.shape[1]*n_repeats)) 22 | 23 | return result 24 | 25 | 26 | def single_forward(ctx_eigenvalues, ctx_eigenvectors, ctx_dx, star0, star1, d, 27 | n_eig, n_extra_eig, device): 28 | """Compute the eigenvectors and eigenvalues of the generalized eigensystem. 
29 | 30 | L*x = lambda*A*x 31 | where 32 | L = d'*blockdiag(star1)*d 33 | A = blockdiag(star0) 34 | """ 35 | ne, nv = d.shape 36 | nvec = star1.shape[1] 37 | 38 | # make L 39 | star1s = [star1[i].squeeze() for i in range(ne)] 40 | star1mtx = scipy.sparse.block_diag(star1s) 41 | drep = repeat_d_matrix(d, nvec) 42 | L = drep.T @ (star1mtx @ drep) 43 | 44 | # make A 45 | star0s = [star0[i].squeeze() for i in range(nv)] 46 | star0mtx = scipy.sparse.block_diag(star0s) 47 | 48 | # can compute extra eigenvectors beyond n_eig (k total) 49 | # extras will improve quality of derivatives 50 | k = n_eig + n_extra_eig + nvec # adding nvec because all zero eigenvalues 51 | shift = 1e-4 # for numerical stability 52 | 53 | eigenvalues, eigenvectors = scipy.sparse.linalg.eigsh( 54 | L + shift*scipy.sparse.eye(L.shape[0]), 55 | k=k, M=star0mtx, which='LM', sigma=0, tol=1e-3) 56 | eigenvalues -= shift 57 | 58 | # sort eigenvalues/corresponding eigenvectors and make sign consistent 59 | idx = eigenvalues.argsort() 60 | eigenvalues = eigenvalues[idx] 61 | eigenvectors = eigenvectors[:, idx] 62 | eigenvectors = eigenvectors * np.sign(eigenvectors[0]) 63 | eigenvalues[:nvec] = 0 # first Laplacian eigenvalue is always zero 64 | eigenvectors[:, 0] = 1 # will normalize momentarily 65 | 66 | # normalize eigenvectors 67 | vec_norms = np.sqrt( 68 | np.sum(eigenvectors * (star0mtx @ eigenvectors), axis=0)) 69 | eigenvectors = eigenvectors / vec_norms.clip(1e-4) 70 | 71 | reshaped_eigenvectors = np.swapaxes( 72 | np.reshape(eigenvectors, (nv, nvec, k)), 1, 2) 73 | 74 | # differentiate eigenvectors --- useful for the derivative during backprop 75 | d_eig = np.swapaxes(np.reshape(drep @ eigenvectors, (ne, nvec, k)), 1, 2) 76 | 77 | ctx_eigenvalues.copy_(torch.from_numpy(eigenvalues).to(device)) 78 | ctx_eigenvectors.copy_(torch.from_numpy(reshaped_eigenvectors).to(device)) 79 | ctx_dx.copy_(torch.from_numpy(d_eig).to(device)) 80 | 81 | 82 | def single_backward(dx, eigenvalues, eigenvectors, n_eig, 
83 | grad_output_eigenvalues, grad_output_eigenvectors, device): 84 | """Backward pass for the eigenproblem.""" 85 | nvec = eigenvectors.shape[2] 86 | 87 | grad_output_eigenvalues = torch.cat( 88 | [torch.zeros(nvec).to(device), grad_output_eigenvalues]) 89 | dstar1 = torch.einsum('i,eil,eim->elm', 90 | grad_output_eigenvalues, dx[:, :n_eig], 91 | dx[:, :n_eig]) 92 | dstar0 = torch.einsum('i,i,wil,wim->wlm', 93 | -grad_output_eigenvalues, eigenvalues[:n_eig], 94 | eigenvectors[:, :n_eig], eigenvectors[:, :n_eig]) 95 | 96 | grad_output_eigenvectors = torch.cat([ 97 | torch.zeros(eigenvectors.shape[0], nvec, nvec).to(device), 98 | grad_output_eigenvectors], dim=1) 99 | 100 | total_eig = eigenvectors.shape[1] # includes the extra eigenvalues 101 | 102 | M = eigenvalues[:, None].repeat(1, total_eig) 103 | M = 1. / (M - M.t()) 104 | M[np.diag_indices(total_eig)] = 0 105 | M[:nvec, :nvec] = 0 106 | 107 | dstar1 += torch.einsum('vjn,vin,ij,ejl,eim->elm', eigenvectors, 108 | grad_output_eigenvectors, M[:n_eig], dx, 109 | dx[:, :n_eig]) 110 | 111 | N = eigenvalues[:, None].repeat(1, total_eig) 112 | N = N / (N.t() - N) 113 | N[:nvec, :nvec] = 0 114 | N[np.diag_indices(total_eig)] = -.5 115 | 116 | dstar0 += torch.einsum('vjn,vin,ij,wjl,wim->wlm', eigenvectors, 117 | grad_output_eigenvectors, N[:n_eig], eigenvectors, 118 | eigenvectors[:, :n_eig]) 119 | 120 | return dstar0, dstar1 121 | 122 | 123 | class HodgeEigensystem(torch.autograd.Function): 124 | """Autograd class for solving batches of Hodge eigensystems. 125 | 126 | WARNING: Assumes that the Hodge star matrices are symmetric. 
127 | """ 128 | 129 | @staticmethod 130 | def forward(ctx, nb, n_eig, n_extra_eig, *inputs): 131 | ctx.device = inputs[1].device 132 | ctx.nb = nb 133 | ctx.n_eig = n_eig 134 | 135 | eigenvalues = [ 136 | torch.empty( 137 | inputs[3*i+1].shape[1] + n_eig + n_extra_eig 138 | ).to(ctx.device).share_memory_() 139 | for i in range(nb) 140 | ] 141 | eigenvectors = [ 142 | torch.empty( 143 | inputs[3*i+2].shape[1], 144 | inputs[3*i+1].shape[1] + n_eig + n_extra_eig, 145 | inputs[3*i+1].shape[1] 146 | ).to(ctx.device).share_memory_() 147 | for i in range(nb) 148 | ] 149 | dx = [ 150 | torch.empty( 151 | inputs[3*i+2].shape[0], 152 | inputs[3*i+1].shape[1] + n_eig + n_extra_eig, 153 | inputs[3*i+1].shape[1] 154 | ).to(ctx.device).share_memory_() 155 | for i in range(nb) 156 | ] 157 | 158 | processes = [] 159 | for i in range(nb): 160 | star0, star1, d = inputs[3*i:3*i+3] 161 | p = torch.multiprocessing.Process( 162 | target=single_forward, 163 | args=(eigenvalues[i], eigenvectors[i], dx[i], 164 | star0.detach().cpu().numpy(), 165 | star1.detach().cpu().numpy(), 166 | d, n_eig, n_extra_eig, ctx.device)) 167 | p.start() 168 | processes.append(p) 169 | 170 | for p in processes: 171 | p.join() 172 | 173 | ctx.eigenvalues = eigenvalues 174 | ctx.eigenvectors = eigenvectors 175 | ctx.dx = dx 176 | 177 | ret = [] 178 | for i in range(nb): 179 | nvec = inputs[3*i+1].shape[1] 180 | ret.extend([ 181 | ctx.eigenvalues[i][nvec:n_eig], 182 | ctx.eigenvectors[i][:, nvec:n_eig]]) 183 | 184 | return tuple(ret) 185 | 186 | @staticmethod 187 | def backward(ctx, *grad_output): 188 | ret = [None, None, None] 189 | 190 | for i in range(ctx.nb): 191 | # derivative wrt eigenvalues 192 | dstar0, dstar1 = single_backward( 193 | ctx.dx[i], ctx.eigenvalues[i], ctx.eigenvectors[i], ctx.n_eig, 194 | grad_output[2*i], grad_output[2*i+1], ctx.device) 195 | ret.extend([dstar0, dstar1, None]) 196 | 197 | return tuple(ret) 198 | 
-------------------------------------------------------------------------------- /train_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import datetime 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.utils.data import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from tqdm import tqdm 14 | 15 | from hodgenet import HodgeNetModel 16 | from meshdata import HodgenetMeshDataset 17 | 18 | 19 | def main(args): 20 | torch.set_default_dtype(torch.float64) # needed for eigenvalue problems 21 | torch.manual_seed(1) # for repeatability 22 | np.random.seed(1) 23 | random.seed(args.seed) 24 | 25 | mesh_files_train = [] 26 | labels_train = [] 27 | mesh_files_val = [] 28 | labels_val = [] 29 | 30 | labeled_files = defaultdict(list) 31 | with open(os.path.join(args.data, 'labels.txt'), 'r') as labels: 32 | for y in labels: 33 | f, i = y.strip().split() 34 | i = int(i) 35 | labeled_files[i].append(os.path.join(args.data, f)) 36 | 37 | for label, files in labeled_files.items(): 38 | random.shuffle(files) 39 | 40 | mesh_files_train.extend(files[:args.train_size]) 41 | labels_train.extend([label]*args.train_size) 42 | mesh_files_val.extend(files[args.train_size:]) 43 | labels_val.extend([label]*len(files[args.train_size:])) 44 | 45 | features = ['vertices', 'normals'] 46 | 47 | dataset = HodgenetMeshDataset( 48 | mesh_files_train, 49 | decimate_range=None if args.fine_tune is not None else (450, 500), 50 | edge_features_from_vertex_features=features, 51 | triangle_features_from_vertex_features=features, 52 | max_stretch=0 if args.fine_tune is not None else 0.05, 53 | random_rotation=True, mesh_features={'category': labels_train}, 54 | normalize_coords=True) 55 | 56 | validation = HodgenetMeshDataset( 57 | mesh_files_val, decimate_range=None, 58 
| edge_features_from_vertex_features=features, 59 | triangle_features_from_vertex_features=features, max_stretch=0, 60 | random_rotation=False, mesh_features={'category': labels_val}, 61 | normalize_coords=True) 62 | 63 | def mycollate(b): return b 64 | dataloader = DataLoader(dataset, batch_size=args.bs, 65 | num_workers=args.num_workers, shuffle=True, 66 | collate_fn=mycollate) 67 | validationloader = DataLoader( 68 | validation, batch_size=args.bs, 69 | num_workers=args.num_workers, collate_fn=mycollate) 70 | 71 | example = dataset[0] 72 | hodgenet_model = HodgeNetModel( 73 | example['int_edge_features'].shape[1], 74 | example['triangle_features'].shape[1], 75 | num_output_features=args.n_out_features, mesh_feature=True, 76 | num_eigenvectors=args.n_eig, num_extra_eigenvectors=args.n_extra_eig, 77 | num_vector_dimensions=args.num_vector_dimensions) 78 | 79 | model = nn.Sequential(hodgenet_model, 80 | nn.Linear(args.n_out_features * 81 | args.num_vector_dimensions**2, 64), 82 | nn.BatchNorm1d(64), 83 | nn.LeakyReLU(), 84 | nn.Linear(64, 64), 85 | nn.BatchNorm1d(64), 86 | nn.LeakyReLU(), 87 | nn.Linear(64, len(labeled_files))) 88 | 89 | # categorical variables 90 | loss = nn.CrossEntropyLoss() 91 | 92 | # optimization routine 93 | print(sum(x.numel() for x in model.parameters()), 'parameters') 94 | optimizer = optim.AdamW(model.parameters(), lr=args.lr) 95 | 96 | if args.fine_tune is not None: 97 | checkpoint = torch.load(args.fine_tune) 98 | model.load_state_dict( 99 | checkpoint['model_state_dict']) 100 | optimizer.load_state_dict(checkpoint['opt_state_dict']) 101 | starting_epoch = checkpoint['epoch'] + 1 102 | print(f'Fine tuning! 
Starting at epoch {starting_epoch}') 103 | else: 104 | starting_epoch = 0 105 | 106 | if not os.path.exists(args.out): 107 | os.makedirs(args.out) 108 | 109 | train_writer = SummaryWriter(os.path.join( 110 | args.out, datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')), 111 | flush_secs=1) 112 | val_writer = SummaryWriter(os.path.join( 113 | args.out, datetime.datetime.now().strftime('val-orig-%m%d%y-%H%M%S')), 114 | flush_secs=1) 115 | 116 | def epoch_loop(dataloader, epochname, epochnum, writer, optimize=True): 117 | epoch_loss, epoch_acc, epoch_size = 0, 0, 0 118 | pbar = tqdm(total=len(dataloader), desc=f'{epochname} {epochnum}') 119 | for batch in dataloader: 120 | if optimize: 121 | optimizer.zero_grad() 122 | 123 | batch_loss, batch_acc = 0, 0 124 | 125 | class_estimate = model(batch) 126 | labels = torch.tensor([x['category'] for x in batch]) 127 | batch_loss = loss(class_estimate, labels) * len(batch) 128 | batch_acc = (class_estimate.argmax(1) == labels).float().sum() 129 | 130 | epoch_loss += batch_loss.item() 131 | epoch_acc += batch_acc.item() 132 | epoch_size += len(batch) 133 | 134 | batch_loss /= len(batch) 135 | batch_acc /= len(batch) 136 | 137 | pbar.set_postfix({ 138 | 'loss': batch_loss.item(), 139 | 'accuracy': batch_acc.item(), 140 | }) 141 | pbar.update(1) 142 | 143 | if optimize: 144 | batch_loss.backward() 145 | nn.utils.clip_grad_norm_(model.parameters(), 1) 146 | optimizer.step() 147 | 148 | writer.add_scalar('loss', epoch_loss / epoch_size, epochnum) 149 | writer.add_scalar('accuracy', epoch_acc / epoch_size, epochnum) 150 | 151 | pbar.close() 152 | 153 | for epoch in range(starting_epoch, starting_epoch+args.n_epochs+1): 154 | model.train() 155 | epoch_loop(dataloader, 'epoch', epoch, train_writer) 156 | 157 | # compute validation score 158 | if epoch % 5 == 0: 159 | model.eval() 160 | with torch.no_grad(): 161 | epoch_loop(validationloader, 'validation', epoch, val_writer, 162 | optimize=False) 163 | 164 | torch.save({ 165 | 
'model_state_dict': model.state_dict(), 166 | 'opt_state_dict': optimizer.state_dict(), 167 | 'epoch': epoch 168 | }, os.path.join(args.out, 169 | f'{epoch}_finetune.pth' 170 | if args.fine_tune is not None else f'{epoch}.pth')) 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('--out', type=str, default='out/shrec16') 176 | parser.add_argument('--data', type=str, default='data/shrec') 177 | parser.add_argument('--bs', type=int, default=16) 178 | parser.add_argument('--lr', type=float, default=1e-4) 179 | parser.add_argument('--n_epochs', type=int, default=100) 180 | parser.add_argument('--n_eig', type=int, default=32) 181 | parser.add_argument('--n_extra_eig', type=int, default=32) 182 | parser.add_argument('--n_out_features', type=int, default=32) 183 | parser.add_argument('--fine_tune', type=str, default=None) 184 | parser.add_argument('--num_workers', type=int, default=0) 185 | parser.add_argument('--num_vector_dimensions', type=int, default=4) 186 | parser.add_argument('--train_size', type=int, default=16) 187 | parser.add_argument('--seed', type=int, default=123) 188 | 189 | args = parser.parse_args() 190 | main(args) 191 | -------------------------------------------------------------------------------- /train_segmentation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | from tqdm import tqdm 12 | 13 | from hodgenet import HodgeNetModel 14 | from meshdata import HodgenetMeshDataset 15 | 16 | 17 | def main(args): 18 | torch.set_default_dtype(torch.float64) # needed for eigenvalue problems 19 | torch.manual_seed(args.seed) 20 | np.random.seed(args.seed) 21 | 22 | mesh_files_train = [] 23 | seg_files_train = [] 
24 | mesh_files_val = [] 25 | seg_files_val = [] 26 | 27 | files = sorted([f.split('.')[0] for f in os.listdir(args.mesh_path)]) 28 | cutoff = round(0.85 * len(files) + 0.49) 29 | 30 | for i in files[:cutoff]: 31 | mesh_files_train.append(os.path.join(args.mesh_path, f'{i}.off')) 32 | seg_files_train.append(os.path.join(args.seg_path, f'{i}.seg')) 33 | for i in files[cutoff:]: 34 | mesh_files_val.append(os.path.join(args.mesh_path, f'{i}.off')) 35 | seg_files_val.append(os.path.join(args.seg_path, f'{i}.seg')) 36 | 37 | features = ['vertices'] if args.no_normals else ['vertices', 'normals'] 38 | 39 | dataset = HodgenetMeshDataset( 40 | mesh_files_train, 41 | decimate_range=None if args.fine_tune is not None else (1000, 99999), 42 | edge_features_from_vertex_features=features, 43 | triangle_features_from_vertex_features=features, 44 | max_stretch=0 if args.fine_tune is not None else 0.05, 45 | random_rotation=False, segmentation_files=seg_files_train, 46 | normalize_coords=True) 47 | 48 | validation = HodgenetMeshDataset( 49 | mesh_files_val, decimate_range=None, 50 | edge_features_from_vertex_features=features, 51 | triangle_features_from_vertex_features=features, max_stretch=0, 52 | random_rotation=False, segmentation_files=seg_files_val, 53 | normalize_coords=True) 54 | 55 | def mycollate(b): return b 56 | dataloader = DataLoader(dataset, batch_size=args.bs, 57 | num_workers=args.num_workers, shuffle=True, 58 | collate_fn=mycollate) 59 | validationloader = DataLoader(validation, batch_size=args.bs, 60 | num_workers=args.num_workers, 61 | collate_fn=mycollate) 62 | 63 | example = dataset[0] 64 | hodgenet = HodgeNetModel( 65 | example['int_edge_features'].shape[1], 66 | example['triangle_features'].shape[1], 67 | num_output_features=args.n_out_features, mesh_feature=False, 68 | num_eigenvectors=args.n_eig, num_extra_eigenvectors=args.n_extra_eig, 69 | resample_to_triangles=True, 70 | num_vector_dimensions=args.num_vector_dimensions) 71 | 72 | model = nn.Sequential( 
73 | hodgenet, 74 | nn.Linear(args.n_out_features*args.num_vector_dimensions * 75 | args.num_vector_dimensions, 32), 76 | nn.BatchNorm1d(32), 77 | nn.LeakyReLU(), 78 | nn.Linear(32, 32), 79 | nn.BatchNorm1d(32), 80 | nn.LeakyReLU(), 81 | nn.Linear(32, dataset.n_seg_categories)) 82 | 83 | # categorical variables 84 | loss = nn.CrossEntropyLoss() 85 | 86 | # optimization routine 87 | print(sum(x.numel() for x in model.parameters()), 'parameters') 88 | optimizer = optim.AdamW(model.parameters(), lr=args.lr) 89 | 90 | if not os.path.exists(args.out): 91 | os.makedirs(args.out) 92 | 93 | if args.fine_tune is not None: 94 | checkpoint = torch.load(os.path.join(args.fine_tune)) 95 | model.load_state_dict(checkpoint['model_state_dict']) 96 | optimizer.load_state_dict(checkpoint['opt_state_dict']) 97 | starting_epoch = checkpoint['epoch'] + 1 98 | print(f'Fine tuning! Starting at epoch {starting_epoch}') 99 | else: 100 | starting_epoch = 0 101 | 102 | train_writer = SummaryWriter(os.path.join( 103 | args.out, datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')), 104 | flush_secs=1) 105 | val_writer = SummaryWriter(os.path.join( 106 | args.out, datetime.datetime.now().strftime('val-%m%d%y-%H%M%S')), 107 | flush_secs=1) 108 | 109 | def epoch_loop(dataloader, epochname, epochnum, writer, optimize=True): 110 | epoch_loss, epoch_acc, epoch_acc_weighted, epoch_size = 0, 0, 0, 0 111 | pbar = tqdm(total=len(dataloader), desc=f'{epochname} {epochnum}') 112 | for batch in dataloader: 113 | if optimize: 114 | optimizer.zero_grad() 115 | 116 | batch_loss, batch_acc, batch_acc_weighted = 0, 0, 0 117 | 118 | seg_estimates = torch.split(model(batch), [m['triangles'].shape[0] 119 | for m in batch], dim=0) 120 | for mesh, seg_estimate in zip(batch, seg_estimates): 121 | gt_segs = mesh['segmentation'].squeeze(-1) 122 | areas = mesh['areas'] 123 | batch_loss += loss(seg_estimate, gt_segs) 124 | batch_acc += (seg_estimate.argmax(1) == gt_segs).float().mean() 125 | batch_acc_weighted += 
((seg_estimate.argmax(1) == gt_segs) 126 | * areas).sum() / areas.sum() 127 | 128 | epoch_loss += batch_loss.item() 129 | epoch_acc += batch_acc.item() 130 | epoch_acc_weighted += batch_acc_weighted.item() 131 | epoch_size += len(batch) 132 | 133 | batch_loss /= len(batch) 134 | batch_acc /= len(batch) 135 | batch_acc_weighted /= len(batch) 136 | 137 | pbar.set_postfix({ 138 | 'loss': batch_loss.item(), 139 | 'accuracy': batch_acc.item(), 140 | 'accuracy_weighted': batch_acc_weighted.item(), 141 | }) 142 | pbar.update(1) 143 | 144 | if optimize: 145 | batch_loss.backward() 146 | nn.utils.clip_grad_norm_(model.parameters(), 1) 147 | optimizer.step() 148 | 149 | writer.add_scalar('loss', epoch_loss / epoch_size, epochnum) 150 | writer.add_scalar('accuracy', epoch_acc / epoch_size, epochnum) 151 | writer.add_scalar('accuracy_weighted', 152 | epoch_acc_weighted / epoch_size, epochnum) 153 | 154 | pbar.close() 155 | 156 | for epoch in range(starting_epoch, starting_epoch+args.n_epochs+1): 157 | model.train() 158 | epoch_loop(dataloader, 'Epoch', epoch, train_writer) 159 | 160 | # compute validation score 161 | if epoch % 5 == 0: 162 | model.eval() 163 | with torch.no_grad(): 164 | epoch_loop(validationloader, 'Validation', 165 | epoch, val_writer, optimize=False) 166 | 167 | torch.save({ 168 | 'model_state_dict': model.state_dict(), 169 | 'opt_state_dict': optimizer.state_dict(), 170 | 'epoch': epoch 171 | }, os.path.join(args.out, 172 | f'{epoch}_finetune.pth' 173 | if args.fine_tune is not None else f'{epoch}.pth')) 174 | 175 | 176 | if __name__ == '__main__': 177 | parser = argparse.ArgumentParser() 178 | parser.add_argument('--out', type=str, default='out/vase') 179 | parser.add_argument('--mesh_path', type=str, default='data/coseg_vase') 180 | parser.add_argument('--seg_path', type=str, default='data/coseg_vase_gt') 181 | parser.add_argument('--bs', type=int, default=16) 182 | parser.add_argument('--lr', type=float, default=1e-4) 183 | 
parser.add_argument('--n_epochs', type=int, default=100) 184 | parser.add_argument('--n_eig', type=int, default=32) 185 | parser.add_argument('--n_extra_eig', type=int, default=32) 186 | parser.add_argument('--n_out_features', type=int, default=32) 187 | parser.add_argument('--fine_tune', type=str, default=None) 188 | parser.add_argument('--num_workers', type=int, default=0) 189 | parser.add_argument('--num_vector_dimensions', type=int, default=4) 190 | parser.add_argument('--seed', type=int, default=123) 191 | parser.add_argument('--no_normals', action='store_true', default=False) 192 | 193 | args = parser.parse_args() 194 | main(args) 195 | -------------------------------------------------------------------------------- /hodgenet.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import scipy.sparse.linalg 3 | import torch 4 | import torch.nn as nn 5 | 6 | from hodgeautograd import HodgeEigensystem 7 | 8 | 9 | class HodgeNetModel(nn.Module): 10 | """Main HodgeNet model. 11 | 12 | The model inputs a batch of meshes and outputs features per vertex or 13 | pooled to faces or the entire mesh. 
14 | """ 15 | def __init__(self, num_edge_features, num_triangle_features, 16 | num_output_features=32, num_eigenvectors=64, 17 | num_extra_eigenvectors=16, mesh_feature=False, min_star=1e-2, 18 | resample_to_triangles=False, num_bdry_edge_features=None, 19 | num_vector_dimensions=1): 20 | super(HodgeNetModel, self).__init__() 21 | 22 | self.num_triangle_features = num_triangle_features 23 | self.hodgefunc = HodgeEigensystem.apply 24 | self.num_eigenvectors = num_eigenvectors 25 | self.num_extra_eigenvectors = num_extra_eigenvectors 26 | self.num_output_features = num_output_features 27 | self.min_star = min_star 28 | self.resample_to_triangles = resample_to_triangles 29 | self.mesh_feature = mesh_feature 30 | self.num_vector_dimensions = num_vector_dimensions 31 | 32 | self.to_star1 = nn.Sequential( 33 | nn.Linear(num_edge_features, 32), 34 | nn.BatchNorm1d(32), 35 | nn.LeakyReLU(), 36 | nn.Linear(32, 32), 37 | nn.BatchNorm1d(32), 38 | nn.LeakyReLU(), 39 | nn.Linear(32, 32), 40 | nn.BatchNorm1d(32), 41 | nn.LeakyReLU(), 42 | nn.Linear(32, 32), 43 | nn.BatchNorm1d(32), 44 | nn.LeakyReLU(), 45 | nn.Linear(32, self.num_vector_dimensions**2) 46 | ) 47 | 48 | if num_bdry_edge_features is not None: 49 | self.to_star1_bdry = nn.Sequential( 50 | nn.Linear(num_bdry_edge_features, 32), 51 | nn.BatchNorm1d(32), 52 | nn.LeakyReLU(), 53 | nn.Linear(32, 32), 54 | nn.BatchNorm1d(32), 55 | nn.LeakyReLU(), 56 | nn.Linear(32, 32), 57 | nn.BatchNorm1d(32), 58 | nn.LeakyReLU(), 59 | nn.Linear(32, 32), 60 | nn.BatchNorm1d(32), 61 | nn.LeakyReLU(), 62 | nn.Linear(32, self.num_vector_dimensions**2) 63 | ) 64 | else: 65 | self.to_star1_bdry = None 66 | 67 | self.to_star0_tri = nn.Sequential( 68 | nn.Linear(num_triangle_features, 32), 69 | nn.BatchNorm1d(32), 70 | nn.LeakyReLU(), 71 | nn.Linear(32, 32), 72 | nn.BatchNorm1d(32), 73 | nn.LeakyReLU(), 74 | nn.Linear(32, 32), 75 | nn.BatchNorm1d(32), 76 | nn.LeakyReLU(), 77 | nn.Linear(32, 32), 78 | nn.BatchNorm1d(32), 79 | nn.LeakyReLU(), 80 
| nn.Linear(32, self.num_vector_dimensions * 81 | self.num_vector_dimensions) 82 | ) 83 | 84 | self.eigenvalue_to_matrix = nn.Sequential( 85 | nn.Linear(1, num_output_features), 86 | nn.BatchNorm1d(num_output_features), 87 | nn.LeakyReLU(), 88 | nn.Linear(num_output_features, num_output_features), 89 | nn.BatchNorm1d(num_output_features), 90 | nn.LeakyReLU(), 91 | nn.Linear(num_output_features, num_output_features), 92 | nn.BatchNorm1d(num_output_features), 93 | nn.LeakyReLU(), 94 | nn.Linear(num_output_features, num_output_features), 95 | nn.BatchNorm1d(num_output_features), 96 | nn.LeakyReLU(), 97 | nn.Linear(num_output_features, num_output_features) 98 | ) 99 | 100 | def gather_star0(self, mesh, star0_tri): 101 | """Compute star0 matrix per vertex by gathering from triangles.""" 102 | star0 = torch.zeros(mesh['vertices'].shape[0], 103 | star0_tri.shape[1]).to(star0_tri) 104 | star0.index_add_(0, mesh['triangles'][:, 0], star0_tri) 105 | star0.index_add_(0, mesh['triangles'][:, 1], star0_tri) 106 | star0.index_add_(0, mesh['triangles'][:, 2], star0_tri) 107 | 108 | star0 = star0.view(-1, self.num_vector_dimensions, 109 | self.num_vector_dimensions) 110 | 111 | # square the tensor to be semidefinite 112 | star0 = torch.einsum('ijk,ilk->ijl', star0, star0) 113 | 114 | # add min star down the diagonal 115 | star0 += torch.eye(self.num_vector_dimensions)[None].to(star0) * \ 116 | self.min_star 117 | 118 | return star0 119 | 120 | def compute_mesh_eigenfunctions(self, mesh, star0, star1, bdry=False): 121 | """Compute eigenvectors and eigenvalues of the learned operator.""" 122 | nb = len(mesh) 123 | 124 | inputs = [] 125 | for m, s0, s1 in zip(mesh, star0, star1): 126 | d = m['int_d01'] 127 | if bdry: 128 | d = scipy.sparse.vstack([d, m['bdry_d01']]) 129 | inputs.extend([s0, s1, d]) 130 | 131 | eigenvalues, eigenvectors = [], [] 132 | outputs = self.hodgefunc(nb, self.num_eigenvectors, 133 | self.num_extra_eigenvectors, *inputs) 134 | for i in range(nb): 135 | 
eigenvalues.append(outputs[2*i]) 136 | eigenvectors.append(outputs[2*i+1]) 137 | 138 | return eigenvalues, eigenvectors 139 | 140 | def forward(self, batch): 141 | nb = len(batch) 142 | 143 | all_star0_tri = self.to_star0_tri( 144 | torch.cat([mesh['triangle_features'] for mesh in batch], dim=0)) 145 | star0_tri_split = torch.split( 146 | all_star0_tri, [mesh['triangles'].shape[0] for mesh in batch], 147 | dim=0) 148 | star0_split = [self.gather_star0(mesh, star0_tri) 149 | for mesh, star0_tri in zip(batch, star0_tri_split)] 150 | 151 | all_star1 = self.to_star1(torch.cat([mesh['int_edge_features'] 152 | for mesh in batch], dim=0)) 153 | all_star1 = all_star1.view(-1, self.num_vector_dimensions, 154 | self.num_vector_dimensions) 155 | all_star1 = torch.einsum('ijk,ilk->ijl', all_star1, all_star1) 156 | all_star1 += torch.eye( 157 | self.num_vector_dimensions)[None].to(all_star1) * \ 158 | self.min_star 159 | star1_split = list(torch.split(all_star1, [mesh['int_d01'].shape[0] 160 | for mesh in batch], dim=0)) 161 | 162 | if self.to_star1_bdry is not None: 163 | all_star1_bdry = self.to_star1_bdry( 164 | torch.cat([mesh['bdry_edge_features'] for mesh in batch], 165 | dim=0)) 166 | all_star1_bdry = all_star1_bdry.view( 167 | -1, self.num_vector_dimensions, self.num_vector_dimensions) 168 | all_star1_bdry = torch.einsum( 169 | 'ijk,ilk->ijl', all_star1_bdry, all_star1_bdry) 170 | all_star1_bdry += torch.eye( 171 | self.num_vector_dimensions)[None].to(all_star1_bdry) * \ 172 | self.min_star 173 | star1_bdry_split = torch.split( 174 | all_star1_bdry, 175 | [mesh['bdry_d01'].shape[0] for mesh in batch], dim=0) 176 | 177 | for i in range(nb): 178 | star1_split[i] = torch.cat( 179 | [star1_split[i], star1_bdry_split[i]], dim=0) 180 | 181 | eigenvalues, eigenvectors = self.compute_mesh_eigenfunctions( 182 | batch, star0_split, star1_split, 183 | bdry=self.to_star1_bdry is not None) 184 | 185 | # glue the eigenvalues back together and run through the nonlinearity 186 | 
all_processed_eigenvalues = self.eigenvalue_to_matrix( 187 | torch.stack(eigenvalues).view(-1, 1)).view( 188 | nb, -1, self.num_output_features) 189 | 190 | # post-multiply the set of eigenvectors by the learned matrix that's a 191 | # function of eigenvalues (similar to HKS, WKS) 192 | outer_products = [torch.einsum( 193 | 'ijk,ijl->ijkl', eigenvectors[i], eigenvectors[i]) 194 | for i in range(nb)] # take outer product of vectors 195 | 196 | result = [torch.einsum( 197 | 'ijkp,jl->ilkp', outer_products[i], all_processed_eigenvalues[i]) 198 | for i in range(nb)] # multiply by learned matrix 199 | 200 | result = [result[i].flatten(start_dim=1) for i in range(nb)] 201 | 202 | if self.resample_to_triangles: 203 | result = [result[i][batch[i]['triangles']].max( 204 | 1)[0] for i in range(nb)] 205 | 206 | if self.mesh_feature: 207 | result = [f.max(0, keepdim=True)[0] for f in result] 208 | 209 | return torch.cat(result, dim=0) 210 | -------------------------------------------------------------------------------- /square.obj: -------------------------------------------------------------------------------- 1 | v -0.500000 -0.000000 0.500000 2 | v -0.400000 -0.000000 0.500000 3 | v -0.300000 -0.000000 0.500000 4 | v -0.200000 -0.000000 0.500000 5 | v -0.100000 -0.000000 0.500000 6 | v 0.000000 -0.000000 0.500000 7 | v 0.100000 -0.000000 0.500000 8 | v 0.200000 -0.000000 0.500000 9 | v 0.300000 -0.000000 0.500000 10 | v 0.400000 -0.000000 0.500000 11 | v 0.500000 -0.000000 0.500000 12 | v -0.500000 -0.000000 0.400000 13 | v -0.400000 -0.000000 0.400000 14 | v -0.300000 -0.000000 0.400000 15 | v -0.200000 -0.000000 0.400000 16 | v -0.100000 -0.000000 0.400000 17 | v 0.000000 -0.000000 0.400000 18 | v 0.100000 -0.000000 0.400000 19 | v 0.200000 -0.000000 0.400000 20 | v 0.300000 -0.000000 0.400000 21 | v 0.400000 -0.000000 0.400000 22 | v 0.500000 -0.000000 0.400000 23 | v -0.500000 -0.000000 0.300000 24 | v -0.400000 -0.000000 0.300000 25 | v -0.300000 -0.000000 0.300000 26 
| v -0.200000 -0.000000 0.300000 27 | v -0.100000 -0.000000 0.300000 28 | v 0.000000 -0.000000 0.300000 29 | v 0.100000 -0.000000 0.300000 30 | v 0.200000 -0.000000 0.300000 31 | v 0.300000 -0.000000 0.300000 32 | v 0.400000 -0.000000 0.300000 33 | v 0.500000 -0.000000 0.300000 34 | v -0.500000 -0.000000 0.200000 35 | v -0.400000 -0.000000 0.200000 36 | v -0.300000 -0.000000 0.200000 37 | v -0.200000 -0.000000 0.200000 38 | v -0.100000 -0.000000 0.200000 39 | v 0.000000 -0.000000 0.200000 40 | v 0.100000 -0.000000 0.200000 41 | v 0.200000 -0.000000 0.200000 42 | v 0.300000 -0.000000 0.200000 43 | v 0.400000 -0.000000 0.200000 44 | v 0.500000 -0.000000 0.200000 45 | v -0.500000 -0.000000 0.100000 46 | v -0.400000 -0.000000 0.100000 47 | v -0.300000 -0.000000 0.100000 48 | v -0.200000 -0.000000 0.100000 49 | v -0.100000 -0.000000 0.100000 50 | v 0.000000 -0.000000 0.100000 51 | v 0.100000 -0.000000 0.100000 52 | v 0.200000 -0.000000 0.100000 53 | v 0.300000 -0.000000 0.100000 54 | v 0.400000 -0.000000 0.100000 55 | v 0.500000 -0.000000 0.100000 56 | v -0.500000 0.000000 0.000000 57 | v -0.400000 0.000000 0.000000 58 | v -0.300000 0.000000 0.000000 59 | v -0.200000 0.000000 0.000000 60 | v -0.100000 0.000000 0.000000 61 | v 0.000000 0.000000 0.000000 62 | v 0.100000 0.000000 0.000000 63 | v 0.200000 0.000000 0.000000 64 | v 0.300000 0.000000 0.000000 65 | v 0.400000 0.000000 0.000000 66 | v 0.500000 0.000000 0.000000 67 | v -0.500000 0.000000 -0.100000 68 | v -0.400000 0.000000 -0.100000 69 | v -0.300000 0.000000 -0.100000 70 | v -0.200000 0.000000 -0.100000 71 | v -0.100000 0.000000 -0.100000 72 | v 0.000000 0.000000 -0.100000 73 | v 0.100000 0.000000 -0.100000 74 | v 0.200000 0.000000 -0.100000 75 | v 0.300000 0.000000 -0.100000 76 | v 0.400000 0.000000 -0.100000 77 | v 0.500000 0.000000 -0.100000 78 | v -0.500000 0.000000 -0.200000 79 | v -0.400000 0.000000 -0.200000 80 | v -0.300000 0.000000 -0.200000 81 | v -0.200000 0.000000 -0.200000 82 | v -0.100000 0.000000 
-0.200000 83 | v 0.000000 0.000000 -0.200000 84 | v 0.100000 0.000000 -0.200000 85 | v 0.200000 0.000000 -0.200000 86 | v 0.300000 0.000000 -0.200000 87 | v 0.400000 0.000000 -0.200000 88 | v 0.500000 0.000000 -0.200000 89 | v -0.500000 0.000000 -0.300000 90 | v -0.400000 0.000000 -0.300000 91 | v -0.300000 0.000000 -0.300000 92 | v -0.200000 0.000000 -0.300000 93 | v -0.100000 0.000000 -0.300000 94 | v 0.000000 0.000000 -0.300000 95 | v 0.100000 0.000000 -0.300000 96 | v 0.200000 0.000000 -0.300000 97 | v 0.300000 0.000000 -0.300000 98 | v 0.400000 0.000000 -0.300000 99 | v 0.500000 0.000000 -0.300000 100 | v -0.500000 0.000000 -0.400000 101 | v -0.400000 0.000000 -0.400000 102 | v -0.300000 0.000000 -0.400000 103 | v -0.200000 0.000000 -0.400000 104 | v -0.100000 0.000000 -0.400000 105 | v 0.000000 0.000000 -0.400000 106 | v 0.100000 0.000000 -0.400000 107 | v 0.200000 0.000000 -0.400000 108 | v 0.300000 0.000000 -0.400000 109 | v 0.400000 0.000000 -0.400000 110 | v 0.500000 0.000000 -0.400000 111 | v -0.500000 0.000000 -0.500000 112 | v -0.400000 0.000000 -0.500000 113 | v -0.300000 0.000000 -0.500000 114 | v -0.200000 0.000000 -0.500000 115 | v -0.100000 0.000000 -0.500000 116 | v 0.000000 0.000000 -0.500000 117 | v 0.100000 0.000000 -0.500000 118 | v 0.200000 0.000000 -0.500000 119 | v 0.300000 0.000000 -0.500000 120 | v 0.400000 0.000000 -0.500000 121 | v 0.500000 0.000000 -0.500000 122 | f 1/1 2/2 12/12 123 | f 12/12 2/2 13/13 124 | f 2/2 3/3 13/13 125 | f 13/13 3/3 14/14 126 | f 3/3 4/4 14/14 127 | f 14/14 4/4 15/15 128 | f 4/4 5/5 15/15 129 | f 15/15 5/5 16/16 130 | f 5/5 6/6 16/16 131 | f 16/16 6/6 17/17 132 | f 6/6 7/7 17/17 133 | f 17/17 7/7 18/18 134 | f 7/7 8/8 18/18 135 | f 18/18 8/8 19/19 136 | f 8/8 9/9 19/19 137 | f 19/19 9/9 20/20 138 | f 9/9 10/10 20/20 139 | f 20/20 10/10 21/21 140 | f 10/10 11/11 21/21 141 | f 21/21 11/11 22/22 142 | f 12/12 13/13 23/23 143 | f 23/23 13/13 24/24 144 | f 13/13 14/14 24/24 145 | f 24/24 14/14 25/25 146 | f 14/14 
15/15 25/25 147 | f 25/25 15/15 26/26 148 | f 15/15 16/16 26/26 149 | f 26/26 16/16 27/27 150 | f 16/16 17/17 27/27 151 | f 27/27 17/17 28/28 152 | f 17/17 18/18 28/28 153 | f 28/28 18/18 29/29 154 | f 18/18 19/19 29/29 155 | f 29/29 19/19 30/30 156 | f 19/19 20/20 30/30 157 | f 30/30 20/20 31/31 158 | f 20/20 21/21 31/31 159 | f 31/31 21/21 32/32 160 | f 21/21 22/22 32/32 161 | f 32/32 22/22 33/33 162 | f 23/23 24/24 34/34 163 | f 34/34 24/24 35/35 164 | f 24/24 25/25 35/35 165 | f 35/35 25/25 36/36 166 | f 25/25 26/26 36/36 167 | f 36/36 26/26 37/37 168 | f 26/26 27/27 37/37 169 | f 37/37 27/27 38/38 170 | f 27/27 28/28 38/38 171 | f 38/38 28/28 39/39 172 | f 28/28 29/29 39/39 173 | f 39/39 29/29 40/40 174 | f 29/29 30/30 40/40 175 | f 40/40 30/30 41/41 176 | f 30/30 31/31 41/41 177 | f 41/41 31/31 42/42 178 | f 31/31 32/32 42/42 179 | f 42/42 32/32 43/43 180 | f 32/32 33/33 43/43 181 | f 43/43 33/33 44/44 182 | f 34/34 35/35 45/45 183 | f 45/45 35/35 46/46 184 | f 35/35 36/36 46/46 185 | f 46/46 36/36 47/47 186 | f 36/36 37/37 47/47 187 | f 47/47 37/37 48/48 188 | f 37/37 38/38 48/48 189 | f 48/48 38/38 49/49 190 | f 38/38 39/39 49/49 191 | f 49/49 39/39 50/50 192 | f 39/39 40/40 50/50 193 | f 50/50 40/40 51/51 194 | f 40/40 41/41 51/51 195 | f 51/51 41/41 52/52 196 | f 41/41 42/42 52/52 197 | f 52/52 42/42 53/53 198 | f 42/42 43/43 53/53 199 | f 53/53 43/43 54/54 200 | f 43/43 44/44 54/54 201 | f 54/54 44/44 55/55 202 | f 45/45 46/46 56/56 203 | f 56/56 46/46 57/57 204 | f 46/46 47/47 57/57 205 | f 57/57 47/47 58/58 206 | f 47/47 48/48 58/58 207 | f 58/58 48/48 59/59 208 | f 48/48 49/49 59/59 209 | f 59/59 49/49 60/60 210 | f 49/49 50/50 60/60 211 | f 60/60 50/50 61/61 212 | f 50/50 51/51 61/61 213 | f 61/61 51/51 62/62 214 | f 51/51 52/52 62/62 215 | f 62/62 52/52 63/63 216 | f 52/52 53/53 63/63 217 | f 63/63 53/53 64/64 218 | f 53/53 54/54 64/64 219 | f 64/64 54/54 65/65 220 | f 54/54 55/55 65/65 221 | f 65/65 55/55 66/66 222 | f 56/56 57/57 67/67 223 | f 
67/67 57/57 68/68 224 | f 57/57 58/58 68/68 225 | f 68/68 58/58 69/69 226 | f 58/58 59/59 69/69 227 | f 69/69 59/59 70/70 228 | f 59/59 60/60 70/70 229 | f 70/70 60/60 71/71 230 | f 60/60 61/61 71/71 231 | f 71/71 61/61 72/72 232 | f 61/61 62/62 72/72 233 | f 72/72 62/62 73/73 234 | f 62/62 63/63 73/73 235 | f 73/73 63/63 74/74 236 | f 63/63 64/64 74/74 237 | f 74/74 64/64 75/75 238 | f 64/64 65/65 75/75 239 | f 75/75 65/65 76/76 240 | f 65/65 66/66 76/76 241 | f 76/76 66/66 77/77 242 | f 67/67 68/68 78/78 243 | f 78/78 68/68 79/79 244 | f 68/68 69/69 79/79 245 | f 79/79 69/69 80/80 246 | f 69/69 70/70 80/80 247 | f 80/80 70/70 81/81 248 | f 70/70 71/71 81/81 249 | f 81/81 71/71 82/82 250 | f 71/71 72/72 82/82 251 | f 82/82 72/72 83/83 252 | f 72/72 73/73 83/83 253 | f 83/83 73/73 84/84 254 | f 73/73 74/74 84/84 255 | f 84/84 74/74 85/85 256 | f 74/74 75/75 85/85 257 | f 85/85 75/75 86/86 258 | f 75/75 76/76 86/86 259 | f 86/86 76/76 87/87 260 | f 76/76 77/77 87/87 261 | f 87/87 77/77 88/88 262 | f 78/78 79/79 89/89 263 | f 89/89 79/79 90/90 264 | f 79/79 80/80 90/90 265 | f 90/90 80/80 91/91 266 | f 80/80 81/81 91/91 267 | f 91/91 81/81 92/92 268 | f 81/81 82/82 92/92 269 | f 92/92 82/82 93/93 270 | f 82/82 83/83 93/93 271 | f 93/93 83/83 94/94 272 | f 83/83 84/84 94/94 273 | f 94/94 84/84 95/95 274 | f 84/84 85/85 95/95 275 | f 95/95 85/85 96/96 276 | f 85/85 86/86 96/96 277 | f 96/96 86/86 97/97 278 | f 86/86 87/87 97/97 279 | f 97/97 87/87 98/98 280 | f 87/87 88/88 98/98 281 | f 98/98 88/88 99/99 282 | f 89/89 90/90 100/100 283 | f 100/100 90/90 101/101 284 | f 90/90 91/91 101/101 285 | f 101/101 91/91 102/102 286 | f 91/91 92/92 102/102 287 | f 102/102 92/92 103/103 288 | f 92/92 93/93 103/103 289 | f 103/103 93/93 104/104 290 | f 93/93 94/94 104/104 291 | f 104/104 94/94 105/105 292 | f 94/94 95/95 105/105 293 | f 105/105 95/95 106/106 294 | f 95/95 96/96 106/106 295 | f 106/106 96/96 107/107 296 | f 96/96 97/97 107/107 297 | f 107/107 97/97 108/108 298 | f 
97/97 98/98 108/108 299 | f 108/108 98/98 109/109 300 | f 98/98 99/99 109/109 301 | f 109/109 99/99 110/110 302 | f 100/100 101/101 111/111 303 | f 111/111 101/101 112/112 304 | f 101/101 102/102 112/112 305 | f 112/112 102/102 113/113 306 | f 102/102 103/103 113/113 307 | f 113/113 103/103 114/114 308 | f 103/103 104/104 114/114 309 | f 114/114 104/104 115/115 310 | f 104/104 105/105 115/115 311 | f 115/115 105/105 116/116 312 | f 105/105 106/106 116/116 313 | f 116/116 106/106 117/117 314 | f 106/106 107/107 117/117 315 | f 117/117 107/107 118/118 316 | f 107/107 108/108 118/118 317 | f 118/118 108/108 119/119 318 | f 108/108 109/109 119/119 319 | f 119/119 109/109 120/120 320 | f 109/109 110/110 120/120 321 | f 120/120 110/110 121/121 322 | -------------------------------------------------------------------------------- /meshdata.py: -------------------------------------------------------------------------------- 1 | import igl 2 | import numpy as np 3 | import scipy 4 | import scipy.io 5 | import torch 6 | import trimesh 7 | 8 | 9 | def random_rotation_matrix(): 10 | """Generate a random 3D rotation matrix.""" 11 | Q, _ = np.linalg.qr(np.random.normal(size=(3, 3))) 12 | return Q 13 | 14 | 15 | def random_scale_matrix(max_stretch): 16 | """Generate a random 3D anisotropic scaling matrix.""" 17 | return np.diag(1 + (np.random.rand(3)*2 - 1) * max_stretch) 18 | 19 | 20 | def d01(v, e): 21 | """Compute d01 operator from 0-froms to 1-forms.""" 22 | row = np.tile(np.arange(e.shape[0]), 2) 23 | col = e.T.flatten() 24 | data = np.concatenate([np.ones(e.shape[0]), -np.ones(e.shape[0])], axis=0) 25 | d = scipy.sparse.csr_matrix( 26 | (data, (row, col)), dtype=np.double, shape=(e.shape[0], v.shape[0])) 27 | return d 28 | 29 | 30 | def flip(f, ev, ef, bdry=False): 31 | """Compute vertices of flipped edges.""" 32 | # glue together the triangle vertices adjacent to each edge 33 | duplicate = f[ef].reshape(ef.shape[0], -1) 34 | duplicate[(duplicate == ev[:, 0, None]) 35 | | 
(duplicate == ev[:, 1, None])] = -1 # remove edge vertices 36 | 37 | # find the two remaining verts (not -1) in an orientation-preserving way 38 | idxs = (-duplicate).argsort(1)[:, :(1 if bdry else 2)] 39 | idxs.sort(1) # preserve orientation by doing them in index order 40 | result = np.take_along_axis(duplicate, idxs, axis=1) 41 | 42 | return result 43 | 44 | 45 | class HodgenetMeshDataset(torch.utils.data.Dataset): 46 | """Dataset of meshes with labels.""" 47 | 48 | def __init__(self, mesh_files, mesh_features={}, decimate_range=None, 49 | random_rotation=True, max_stretch=0.1, 50 | edge_features_from_vertex_features=['vertices'], 51 | triangle_features_from_vertex_features=['vertices'], 52 | center_vertices=True, normalize_coords=True, 53 | segmentation_files=None): 54 | self.mesh_files = mesh_files 55 | self.mesh_features = mesh_features 56 | self.decimate_range = decimate_range 57 | self.random_rotation = random_rotation 58 | self.max_stretch = max_stretch 59 | self.edge_features_from_vertex_features = \ 60 | edge_features_from_vertex_features 61 | self.triangle_features_from_vertex_features = \ 62 | triangle_features_from_vertex_features 63 | self.center_vertices = center_vertices 64 | self.segmentation_files = segmentation_files 65 | self.normalize_coords = normalize_coords 66 | 67 | self.min_category = float('inf') 68 | self.n_seg_categories = 0 69 | if self.segmentation_files is not None: 70 | for f in segmentation_files: 71 | triangle_data = np.loadtxt(f, dtype=np.int64) 72 | self.n_seg_categories = max( 73 | self.n_seg_categories, triangle_data.max()) 74 | self.min_category = min(self.min_category, triangle_data.min()) 75 | 76 | self.n_seg_categories = self.n_seg_categories-self.min_category+1 77 | 78 | def __len__(self): 79 | return len(self.mesh_files) 80 | 81 | def __getitem__(self, idx): 82 | mesh = trimesh.load(self.mesh_files[idx], process=False) 83 | v_orig, f_orig = mesh.vertices, mesh.faces 84 | 85 | if self.segmentation_files is not None: 86 | 
face_segmentation = np.loadtxt(self.segmentation_files[idx], 87 | dtype=int) 88 | face_segmentation -= self.min_category 89 | 90 | # decimate mesh to desired number of faces in provided range 91 | if self.decimate_range is not None: 92 | while True: 93 | nfaces = np.random.randint( 94 | min(f_orig.shape[0], self.decimate_range[0]), 95 | min(f_orig.shape[0], self.decimate_range[1]) + 1) 96 | 97 | _, v, f, decimated_f_idxs, _ = igl.decimate( 98 | v_orig, f_orig, nfaces) 99 | if igl.is_edge_manifold(f): 100 | break 101 | 102 | if self.segmentation_files is not None: 103 | face_segmentation = face_segmentation[decimated_f_idxs] 104 | else: 105 | v = v_orig 106 | f = f_orig 107 | 108 | if self.center_vertices: 109 | v -= v.mean(0) 110 | 111 | if self.normalize_coords: 112 | v /= np.linalg.norm(v, axis=1).max() 113 | 114 | # random rotation/scaling 115 | if self.random_rotation: 116 | v = v @ random_rotation_matrix() 117 | if self.max_stretch != 0: 118 | v = v @ random_scale_matrix(self.max_stretch) 119 | if self.random_rotation: 120 | v = v @ random_rotation_matrix() 121 | 122 | areas = igl.doublearea(v, f) 123 | 124 | ev, fe, ef = igl.edge_topology(v, f) 125 | 126 | bdry_idxs = (ef == -1).any(1) 127 | if bdry_idxs.sum() > 0: 128 | bdry_ev = ev[bdry_idxs] 129 | bdry_ef = ef[bdry_idxs] 130 | bdry_ef.sort(1) 131 | bdry_ef = bdry_ef[:, 1, None] 132 | bdry_flipped = flip(f, bdry_ev, bdry_ef, bdry=True) 133 | 134 | int_ev = ev[~bdry_idxs] 135 | int_ef = ef[~bdry_idxs] 136 | int_flipped = flip(f, int_ev, int_ef) 137 | 138 | # normals 139 | n = igl.per_vertex_normals(v, f) 140 | n[np.isnan(n)] = 0 141 | 142 | result = { 143 | 'vertices': torch.from_numpy(v), 144 | 'faces': torch.from_numpy(f), 145 | 'areas': torch.from_numpy(areas), 146 | 'int_d01': d01(v, int_ev), 147 | 'triangles': torch.from_numpy(f.astype(np.int64)), 148 | 'normals': torch.from_numpy(n), 149 | 'mesh': self.mesh_files[idx], 150 | 'int_ev': torch.from_numpy(int_ev.astype(np.int64)), 151 | 'int_flipped': 
torch.from_numpy(int_flipped.astype(np.int64)), 152 | 'f': torch.from_numpy(f.astype(np.int64)) 153 | } 154 | 155 | if self.segmentation_files is not None: 156 | result['segmentation'] = torch.from_numpy(face_segmentation) 157 | if bdry_idxs.sum() > 0: 158 | result['bdry_d01'] = d01(v, bdry_ev) 159 | 160 | # feature per mesh (e.g. label of the mesh) 161 | for key in self.mesh_features: 162 | result[key] = self.mesh_features[key][idx] 163 | 164 | # gather vertex features to edges from list of keys 165 | result['int_edge_features'] = torch.from_numpy( 166 | np.concatenate([ 167 | np.concatenate([ 168 | result[key][torch.from_numpy(int_ev.astype(np.int64))], 169 | result[key][torch.from_numpy(int_flipped.astype(np.int64))] 170 | ], axis=1).reshape(int_ev.shape[0], -1) 171 | for key in self.edge_features_from_vertex_features 172 | ], axis=1)) 173 | 174 | if bdry_idxs.sum() > 0: 175 | result['bdry_edge_features'] = torch.from_numpy( 176 | np.concatenate([ 177 | np.concatenate([ 178 | result[key][torch.from_numpy( 179 | bdry_ev.astype(np.int64))], 180 | result[key][torch.from_numpy( 181 | bdry_flipped.astype(np.int64))] 182 | ], axis=1).reshape(bdry_ev.shape[0], -1) 183 | for key in self.edge_features_from_vertex_features 184 | ], axis=1)) 185 | 186 | # gather vertex features to triangles from list of keys 187 | result['triangle_features'] = torch.from_numpy( 188 | np.concatenate( 189 | [result[key][f].reshape(f.shape[0], -1) 190 | for key in self.triangle_features_from_vertex_features], 191 | axis=1)) 192 | 193 | return result 194 | 195 | 196 | def get_rot(theta): 197 | """Get 3D rotation matrix for a given angle.""" 198 | return np.array([ 199 | [1, 0, 0], 200 | [0, np.cos(theta), -np.sin(theta)], 201 | [0, np.sin(theta), np.cos(theta)] 202 | ]) 203 | 204 | 205 | class OrigamiDataset(torch.utils.data.Dataset): 206 | """Dataset of square mesh with random crease down the center.""" 207 | 208 | def __init__(self, edge_features_from_vertex_features=['vertices'], 209 | 
triangle_features_from_vertex_features=['vertices']): 210 | self.edge_features_from_vertex_features = \ 211 | edge_features_from_vertex_features 212 | self.triangle_features_from_vertex_features = \ 213 | triangle_features_from_vertex_features 214 | 215 | self.v, self.f = igl.read_triangle_mesh('square.obj') 216 | 217 | def __len__(self): 218 | return 5000 219 | 220 | def __getitem__(self, _): 221 | v, f = self.v, self.f 222 | 223 | v1 = v[:55] 224 | v2 = v[55:] 225 | 226 | theta = np.random.rand() * 2 * np.pi 227 | v2_ = v2 @ get_rot(np.pi - theta) 228 | v = np.concatenate([v1, v2_], axis=0) 229 | 230 | ev, _, ef = igl.edge_topology(v, f) 231 | 232 | bdry_idxs = (ef == -1).any(1) 233 | bdry_ev = ev[bdry_idxs] 234 | bdry_ef = ef[bdry_idxs] 235 | bdry_ef.sort(1) 236 | bdry_ef = bdry_ef[:, 1, None] 237 | 238 | bdry_flipped = flip(f, bdry_ev, bdry_ef, bdry=True) 239 | 240 | int_ev = ev[~bdry_idxs] 241 | int_ef = ef[~bdry_idxs] 242 | 243 | int_flipped = flip(f, int_ev, int_ef) 244 | 245 | # normals 246 | n = igl.per_vertex_normals(v, f) 247 | 248 | result = { 249 | 'vertices': torch.from_numpy(v), 250 | 'int_d01': d01(v, int_ev), 251 | 'bdry_d01': d01(v, bdry_ev), 252 | 'triangles': torch.from_numpy(f.astype(np.int64)), 253 | 'normals': torch.from_numpy(n), 254 | 'dir': torch.tensor([np.cos(theta), np.sin(theta)]), 255 | } 256 | 257 | # gather vertex features to edges from list of keys 258 | result['int_edge_features'] = torch.from_numpy( 259 | np.concatenate([ 260 | np.concatenate([ 261 | result[key][torch.from_numpy(int_ev)], 262 | result[key][torch.from_numpy(int_flipped)] 263 | ], axis=1).reshape(int_ev.shape[0], -1) 264 | for key in self.edge_features_from_vertex_features 265 | ], axis=1)) 266 | result['bdry_edge_features'] = torch.from_numpy( 267 | np.concatenate([ 268 | np.concatenate([ 269 | result[key][torch.from_numpy(bdry_ev)], 270 | result[key][torch.from_numpy(bdry_flipped)] 271 | ], axis=1).reshape(bdry_ev.shape[0], -1) 272 | for key in 
self.edge_features_from_vertex_features 273 | ], axis=1)) 274 | 275 | # gather vertex features to triangles from list of keys 276 | result['triangle_features'] = torch.from_numpy(np.concatenate( 277 | [result[key][f].reshape(f.shape[0], -1) 278 | for key in self.triangle_features_from_vertex_features], axis=1)) 279 | 280 | return result 281 | --------------------------------------------------------------------------------