├── LICENSE
├── README.md
├── environment.yml
├── .gitignore
├── train_origami.py
├── hodgeautograd.py
├── train_classification.py
├── train_segmentation.py
├── hodgenet.py
├── square.obj
└── meshdata.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Dima Smirnov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## HodgeNet | [Webpage](https://people.csail.mit.edu/smirnov/hodgenet/) | [Paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459797) | [Video](https://youtu.be/juP0PHxvnx8)
2 |
3 |
4 |
5 | **HodgeNet: Learning Spectral Geometry on Triangle Meshes**
6 | Dmitriy Smirnov, Justin Solomon
7 | [SIGGRAPH 2021](https://s2021.siggraph.org/)
8 |
9 | ### Set-up
10 | To install the necessary dependencies, run:
11 | ```
12 | conda env create -f environment.yml
13 | conda activate HodgeNet
14 | ```
15 |
16 | ### Training
17 | To train the segmentation model, first download the [Shape COSEG dataset](http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm). Then, run:
18 | ```
19 | python train_segmentation.py --out out_dir --mesh_path path_to_meshes --seg_path path_to_segs
20 | ```
21 |
22 | To train the classification model, first download the SHREC 2011 dataset:
23 | ```
24 | wget -O shrec.tar.gz https://www.dropbox.com/s/4z4v1x30jsy0uoh/shrec.tar.gz?dl=0
25 | mkdir -p data && tar -xvf shrec.tar.gz -C data
26 | ```
27 | Then, run:
28 | ```
29 | python train_classification.py --out out_dir
30 | ```
31 |
32 | To train the dihedral angle stress test model, run:
33 | ```
34 | python train_origami.py --out out_dir
35 | ```
36 |
37 | To monitor the training, launch a TensorBoard instance with `--logdir out_dir`.
38 |
39 | To finetune a model, add the flag `--fine_tune` to the above training commands.
40 |
41 | ### BibTeX
42 | ```
43 | @article{smirnov2021hodgenet,
44 | title={{HodgeNet}: Learning Spectral Geometry on Triangle Meshes},
45 | author={Smirnov, Dmitriy and Solomon, Justin},
46 | year={2021},
47 | journal={SIGGRAPH}
48 | }
49 | ```
50 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: HodgeNet
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1
8 | - _openmp_mutex=4.5
9 | - blas=1.0
10 | - bzip2=1.0.8
11 | - ca-certificates=2021.5.30
12 | - certifi=2021.5.30
13 | - cudatoolkit=11.1.1
14 | - ffmpeg=4.3
15 | - freetype=2.10.4
16 | - gmp=6.2.1
17 | - gnutls=3.6.13
18 | - igl=2.2.1
19 | - intel-openmp=2021.3.0
20 | - jpeg=9b
21 | - lame=3.100
22 | - lcms2=2.12
23 | - ld_impl_linux-64=2.35.1
24 | - libblas=3.9.0
25 | - libcblas=3.9.0
26 | - libffi=3.3
27 | - libgcc-ng=9.3.0
28 | - libgfortran-ng=11.1.0
29 | - libgfortran5=11.1.0
30 | - libgomp=9.3.0
31 | - libiconv=1.16
32 | - liblapack=3.9.0
33 | - libpng=1.6.37
34 | - libstdcxx-ng=9.3.0
35 | - libtiff=4.2.0
36 | - libuv=1.42.0
37 | - libwebp-base=1.2.0
38 | - lz4-c=1.9.3
39 | - mkl=2021.3.0
40 | - mkl-service=2.4.0
41 | - mkl_fft=1.3.0
42 | - mkl_random=1.2.2
43 | - ncurses=6.2
44 | - nettle=3.6
45 | - ninja=1.10.2
46 | - numpy=1.20.3
47 | - numpy-base=1.20.3
48 | - olefile=0.46
49 | - openh264=2.1.1
50 | - openjpeg=2.4.0
51 | - openssl=1.1.1k
52 | - pillow=8.3.1
53 | - pip=21.1.3
54 | - python=3.9.6
55 | - python_abi=3.9
56 | - pytorch=1.9.0
57 | - readline=8.1
58 | - scipy=1.7.0
59 | - setuptools=52.0.0
60 | - six=1.16.0
61 | - sqlite=3.36.0
62 | - tk=8.6.10
63 | - torchvision=0.10.0
64 | - typing_extensions=3.10.0.0
65 | - tzdata=2021a
66 | - wheel=0.36.2
67 | - xz=5.2.5
68 | - zlib=1.2.11
69 | - zstd=1.4.9
70 | - pip:
71 | - absl-py==0.13.0
72 | - cachetools==4.2.2
73 | - charset-normalizer==2.0.4
74 | - google-auth==1.34.0
75 | - google-auth-oauthlib==0.4.5
76 | - grpcio==1.39.0
77 | - idna==3.2
78 | - markdown==3.3.4
79 | - oauthlib==3.1.1
80 | - protobuf==3.17.3
81 | - pyasn1==0.4.8
82 | - pyasn1-modules==0.2.8
83 | - requests==2.26.0
84 | - requests-oauthlib==1.3.0
85 | - rsa==4.7.2
86 | - tensorboard==2.5.0
87 | - tensorboard-data-server==0.6.1
88 | - tensorboard-plugin-wit==1.8.0
89 | - tqdm==4.62.0
90 | - trimesh==3.9.25
91 | - urllib3==1.26.6
92 | - werkzeug==2.0.1
93 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | out
132 | data
133 |
--------------------------------------------------------------------------------
/train_origami.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from torch.utils.data import DataLoader
11 | from torch.utils.tensorboard import SummaryWriter
12 | from tqdm import tqdm
13 |
14 | from hodgenet import HodgeNetModel
15 | from meshdata import OrigamiDataset
16 |
17 |
def main(args):
    """Train the dihedral-angle ("origami") stress-test model.

    Builds an OrigamiDataset, wraps a HodgeNetModel with a small MLP head
    that predicts a 2D direction per mesh, and minimizes a cosine loss
    against each mesh's ground-truth 'dir' vector, checkpointing after
    every epoch.

    Args:
        args: argparse.Namespace with the flags declared in __main__
            (out, bs, lr, n_epochs, n_eig, n_extra_eig, n_out_features,
            num_vector_dimensions).
    """
    torch.set_default_dtype(torch.float64)  # needed for eigenvalue problems
    torch.manual_seed(0)
    np.random.seed(0)

    # Only vertex positions are lifted to edge/triangle features here.
    dataset = OrigamiDataset(
        edge_features_from_vertex_features=['vertices'],
        triangle_features_from_vertex_features=['vertices'])

    # Identity collate: a batch stays a plain list of per-mesh dicts,
    # since meshes have varying sizes and cannot be stacked into tensors.
    def mycollate(b): return b
    dataloader = DataLoader(dataset, batch_size=args.bs,
                            num_workers=0, collate_fn=mycollate)

    # Probe one example to size the model's input feature dimensions.
    example = dataset[0]
    hodgenet_model = HodgeNetModel(
        example['int_edge_features'].shape[1],
        example['triangle_features'].shape[1],
        num_output_features=args.n_out_features, mesh_feature=True,
        num_eigenvectors=args.n_eig, num_extra_eigenvectors=args.n_extra_eig,
        resample_to_triangles=False,
        num_bdry_edge_features=example['bdry_edge_features'].shape[1],
        num_vector_dimensions=args.num_vector_dimensions)

    # MLP head mapping pooled per-mesh HodgeNet features to a 2D direction.
    origami_model = nn.Sequential(
        hodgenet_model,
        nn.Linear(args.n_out_features*args.num_vector_dimensions *
                  args.num_vector_dimensions, 32),
        nn.LayerNorm(32),
        nn.LeakyReLU(),
        nn.Linear(32, 16),
        nn.LayerNorm(16),
        nn.LeakyReLU(),
        nn.Linear(16, 2))

    optimizer = optim.AdamW(origami_model.parameters(), lr=args.lr)

    if not os.path.exists(args.out):
        os.makedirs(args.out)

    train_writer = SummaryWriter(os.path.join(
        args.out, datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')),
        flush_secs=1)

    def epoch_loop(dataloader, epochname, epochnum, writer, optimize=True):
        # One pass over dataloader; updates weights when optimize=True.
        # Logs the per-batch loss to TensorBoard at a per-batch step index.
        epoch_loss, epoch_size = 0, 0
        pbar = tqdm(total=len(dataloader),
                    desc='{} {}'.format(epochname, epochnum))
        for batchnum, batch in enumerate(dataloader):
            if optimize:
                optimizer.zero_grad()

            batch_loss = 0

            dirs = origami_model(batch)
            dirs = F.normalize(dirs, p=2, dim=-1)
            # Cosine loss: 1 - <gt, estimate>, summed then averaged below.
            for mesh, dir_estimate in zip(batch, dirs):
                gt_dir = mesh['dir'].to(dir_estimate.device)
                batch_loss += 1 - (gt_dir * dir_estimate).sum(-1)

            batch_loss /= len(batch)

            pbar.set_postfix({
                'loss': batch_loss.item(),
            })
            pbar.update(1)

            # NOTE(review): epoch_loss/epoch_size are accumulated but never
            # read afterwards; only the per-batch loss is logged.
            epoch_loss += batch_loss.item()
            epoch_size += 1

            if optimize:
                batch_loss.backward()
                # Clip gradients to stabilize backprop through the
                # eigen-decomposition.
                nn.utils.clip_grad_norm_(origami_model.parameters(), 1)
                optimizer.step()

            writer.add_scalar('Loss', batch_loss.item(),
                              epochnum*len(dataloader)+batchnum)

        pbar.close()

    for epoch in range(args.n_epochs):
        origami_model.train()
        epoch_loop(dataloader, 'Epoch', epoch, train_writer)

        # Checkpoint model + optimizer state every epoch.
        torch.save({
            'origami_model_state_dict': origami_model.state_dict(),
            'opt_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }, os.path.join(args.out, f'{epoch}.pth'))
106 |
107 |
if __name__ == '__main__':
    # CLI for the origami stress-test trainer; flag order and defaults
    # are unchanged, so `--help` output is identical.
    parser = argparse.ArgumentParser()
    for flag, flag_type, flag_default in [
            ('--out', str, 'out/origami'),
            ('--bs', int, 16),
            ('--lr', float, 1e-3),
            ('--n_epochs', int, 10000),
            ('--n_eig', int, 32),
            ('--n_extra_eig', int, 32),
            ('--n_out_features', int, 32),
            ('--num_vector_dimensions', int, 4)]:
        parser.add_argument(flag, type=flag_type, default=flag_default)

    main(parser.parse_args())
121 |
--------------------------------------------------------------------------------
/hodgeautograd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import scipy.sparse.linalg
4 | import torch
5 | import torch.multiprocessing
6 |
7 |
def repeat_d_matrix(d, n_repeats):
    """Expand d into an interleaved block-diagonal operator.

    Each scalar entry d[i, j] becomes an n_repeats x n_repeats identity
    block scaled by d[i, j]: the returned sparse matrix of shape
    (rows*n_repeats, cols*n_repeats) applies d independently to each of
    the n_repeats interleaved vector components.
    """
    if n_repeats == 1:
        return d

    rows, cols, vals = scipy.sparse.find(d)
    offsets = np.arange(n_repeats)

    # Entry (i, j) is replicated at (i*n + r, j*n + r) for each offset r.
    big_rows = (rows[None, :] * n_repeats + offsets[:, None]).ravel()
    big_cols = (cols[None, :] * n_repeats + offsets[:, None]).ravel()
    big_vals = np.tile(vals, n_repeats)

    return scipy.sparse.csr_matrix(
        (big_vals, (big_rows, big_cols)),
        shape=(d.shape[0] * n_repeats, d.shape[1] * n_repeats))
24 |
25 |
def single_forward(ctx_eigenvalues, ctx_eigenvectors, ctx_dx, star0, star1, d,
                   n_eig, n_extra_eig, device):
    """Compute the eigenvectors and eigenvalues of the generalized eigensystem.

    L*x = lambda*A*x
    where
        L = d'*blockdiag(star1)*d
        A = blockdiag(star0)

    Results are written in place into the pre-allocated shared-memory
    tensors (ctx_eigenvalues, ctx_eigenvectors, ctx_dx) so this function
    can run inside a torch.multiprocessing worker (see HodgeEigensystem).

    Args:
        ctx_eigenvalues: output tensor, shape (k,).
        ctx_eigenvectors: output tensor, shape (nv, k, nvec).
        ctx_dx: output tensor, shape (ne, k, nvec); holds d applied to each
            eigenvector, cached for the backward pass.
        star0: per-vertex Hodge star blocks (numpy), first dim nv.
        star1: per-edge Hodge star blocks (numpy), first dim ne; star1.shape[1]
            is nvec, the per-element vector dimension.
        d: scipy sparse (ne, nv) differential matrix.
        n_eig: number of eigenpairs consumed downstream.
        n_extra_eig: additional eigenpairs computed only to improve the
            quality of the backward-pass derivatives.
        device: torch device the outputs should live on.
    """
    ne, nv = d.shape
    nvec = star1.shape[1]

    # make L
    star1s = [star1[i].squeeze() for i in range(ne)]
    star1mtx = scipy.sparse.block_diag(star1s)
    drep = repeat_d_matrix(d, nvec)
    L = drep.T @ (star1mtx @ drep)

    # make A
    star0s = [star0[i].squeeze() for i in range(nv)]
    star0mtx = scipy.sparse.block_diag(star0s)

    # can compute extra eigenvectors beyond n_eig (k total)
    # extras will improve quality of derivatives
    k = n_eig + n_extra_eig + nvec  # adding nvec because all zero eigenvalues
    shift = 1e-4  # for numerical stability

    # Shift-invert mode (sigma=0, which='LM') targets the eigenvalues
    # closest to zero; the small diagonal shift keeps the factorized
    # matrix nonsingular despite L's zero eigenvalues.
    eigenvalues, eigenvectors = scipy.sparse.linalg.eigsh(
        L + shift*scipy.sparse.eye(L.shape[0]),
        k=k, M=star0mtx, which='LM', sigma=0, tol=1e-3)
    eigenvalues -= shift

    # sort eigenvalues/corresponding eigenvectors and make sign consistent
    idx = eigenvalues.argsort()
    eigenvalues = eigenvalues[idx]
    eigenvectors = eigenvectors[:, idx]
    eigenvectors = eigenvectors * np.sign(eigenvectors[0])
    eigenvalues[:nvec] = 0  # first Laplacian eigenvalue is always zero
    eigenvectors[:, 0] = 1  # will normalize momentarily

    # normalize eigenvectors (unit norm under the star0 inner product;
    # clip guards against division by ~0 for degenerate vectors)
    vec_norms = np.sqrt(
        np.sum(eigenvectors * (star0mtx @ eigenvectors), axis=0))
    eigenvectors = eigenvectors / vec_norms.clip(1e-4)

    # Reorder the flat (nv*nvec, k) solution into (nv, k, nvec).
    reshaped_eigenvectors = np.swapaxes(
        np.reshape(eigenvectors, (nv, nvec, k)), 1, 2)

    # differentiate eigenvectors --- useful for the derivative during backprop
    d_eig = np.swapaxes(np.reshape(drep @ eigenvectors, (ne, nvec, k)), 1, 2)

    ctx_eigenvalues.copy_(torch.from_numpy(eigenvalues).to(device))
    ctx_eigenvectors.copy_(torch.from_numpy(reshaped_eigenvectors).to(device))
    ctx_dx.copy_(torch.from_numpy(d_eig).to(device))
81 |
def single_backward(dx, eigenvalues, eigenvectors, n_eig,
                    grad_output_eigenvalues, grad_output_eigenvectors, device):
    """Backward pass for the eigenproblem.

    Uses first-order perturbation formulas for the generalized symmetric
    eigenproblem to map gradients w.r.t. the returned eigenvalues and
    eigenvectors back to gradients w.r.t. the per-edge (star1) and
    per-vertex (star0) Hodge star blocks.

    Args:
        dx: (ne, total_eig, nvec) cached d @ eigenvectors from the forward.
        eigenvalues: (total_eig,) all computed eigenvalues, incl. extras.
        eigenvectors: (nv, total_eig, nvec) all computed eigenvectors.
        n_eig: index bound of the eigenpairs that were actually returned.
        grad_output_eigenvalues: gradient w.r.t. the returned eigenvalues.
        grad_output_eigenvectors: gradient w.r.t. the returned eigenvectors.
        device: torch device for intermediate allocations.

    Returns:
        (dstar0, dstar1) gradient tensors with per-vertex / per-edge blocks.
    """
    nvec = eigenvectors.shape[2]

    # The leading nvec (zero-eigenvalue) pairs were stripped from the
    # forward output, so pad the incoming gradients back to full length.
    grad_output_eigenvalues = torch.cat(
        [torch.zeros(nvec).to(device), grad_output_eigenvalues])
    # d lambda_i / d star1: outer products of (d x_i) per edge block.
    dstar1 = torch.einsum('i,eil,eim->elm',
                          grad_output_eigenvalues, dx[:, :n_eig],
                          dx[:, :n_eig])
    # d lambda_i / d star0: -lambda_i * (x_i x_i^T) per vertex block.
    dstar0 = torch.einsum('i,i,wil,wim->wlm',
                          -grad_output_eigenvalues, eigenvalues[:n_eig],
                          eigenvectors[:, :n_eig], eigenvectors[:, :n_eig])

    grad_output_eigenvectors = torch.cat([
        torch.zeros(eigenvectors.shape[0], nvec, nvec).to(device),
        grad_output_eigenvectors], dim=1)

    total_eig = eigenvectors.shape[1]  # includes the extra eigenvalues

    # M[i, j] = 1 / (lambda_i - lambda_j): eigenvector perturbation
    # weights. The division by zero on the diagonal (and in the degenerate
    # zero-eigenvalue block) produces infs that are overwritten below.
    M = eigenvalues[:, None].repeat(1, total_eig)
    M = 1. / (M - M.t())
    M[np.diag_indices(total_eig)] = 0
    M[:nvec, :nvec] = 0

    # Eigenvector-gradient contribution to dstar1, mixing all computed
    # eigenpairs (including the extras, which improve this term).
    dstar1 += torch.einsum('vjn,vin,ij,ejl,eim->elm', eigenvectors,
                           grad_output_eigenvectors, M[:n_eig], dx,
                           dx[:, :n_eig])

    # N[i, j] = lambda_i / (lambda_j - lambda_i); the -0.5 diagonal is the
    # term from differentiating the star0-normalization of each eigenvector.
    N = eigenvalues[:, None].repeat(1, total_eig)
    N = N / (N.t() - N)
    N[:nvec, :nvec] = 0
    N[np.diag_indices(total_eig)] = -.5

    dstar0 += torch.einsum('vjn,vin,ij,wjl,wim->wlm', eigenvectors,
                           grad_output_eigenvectors, N[:n_eig], eigenvectors,
                           eigenvectors[:, :n_eig])

    return dstar0, dstar1
121 |
122 |
class HodgeEigensystem(torch.autograd.Function):
    """Autograd class for solving batches of Hodge eigensystems.

    forward() receives one (star0, star1, d) triple per mesh, flattened
    into *inputs, solves each generalized eigenproblem L x = lambda A x in
    its own worker process (see single_forward), and returns a flat tuple
    of (eigenvalues, eigenvectors) pairs. backward() maps the incoming
    gradients back to the star matrices (see single_backward).

    WARNING: Assumes that the Hodge star matrices are symmetric.
    """

    @staticmethod
    def forward(ctx, nb, n_eig, n_extra_eig, *inputs):
        # inputs holds nb flattened triples: (star0_i, star1_i, d_i) at
        # positions 3*i .. 3*i+2. d_i is a scipy sparse (ne, nv) matrix;
        # star1_i.shape[1] is nvec, the per-element vector dimension.
        ctx.device = inputs[1].device
        ctx.nb = nb
        ctx.n_eig = n_eig

        # Pre-allocate shared-memory output tensors so the worker
        # processes can write their results in place. Per mesh we compute
        # k = nvec + n_eig + n_extra_eig eigenpairs.
        eigenvalues = [
            torch.empty(
                inputs[3*i+1].shape[1] + n_eig + n_extra_eig
            ).to(ctx.device).share_memory_()
            for i in range(nb)
        ]
        eigenvectors = [
            torch.empty(
                inputs[3*i+2].shape[1],
                inputs[3*i+1].shape[1] + n_eig + n_extra_eig,
                inputs[3*i+1].shape[1]
            ).to(ctx.device).share_memory_()
            for i in range(nb)
        ]
        dx = [
            torch.empty(
                inputs[3*i+2].shape[0],
                inputs[3*i+1].shape[1] + n_eig + n_extra_eig,
                inputs[3*i+1].shape[1]
            ).to(ctx.device).share_memory_()
            for i in range(nb)
        ]

        # One worker process per mesh; star matrices are detached and
        # moved to numpy for the scipy solve.
        processes = []
        for i in range(nb):
            star0, star1, d = inputs[3*i:3*i+3]
            p = torch.multiprocessing.Process(
                target=single_forward,
                args=(eigenvalues[i], eigenvectors[i], dx[i],
                      star0.detach().cpu().numpy(),
                      star1.detach().cpu().numpy(),
                      d, n_eig, n_extra_eig, ctx.device))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Cache the full spectra (including the extra eigenpairs) for the
        # backward pass.
        ctx.eigenvalues = eigenvalues
        ctx.eigenvectors = eigenvectors
        ctx.dx = dx

        ret = []
        for i in range(nb):
            nvec = inputs[3*i+1].shape[1]
            # Return only indices nvec..n_eig-1: the nvec trivial
            # (zero-eigenvalue) pairs and the extra eigenpairs are dropped.
            ret.extend([
                ctx.eigenvalues[i][nvec:n_eig],
                ctx.eigenvectors[i][:, nvec:n_eig]])

        return tuple(ret)

    @staticmethod
    def backward(ctx, *grad_output):
        # Gradients match forward's argument list: None for each of
        # (nb, n_eig, n_extra_eig), then (dstar0, dstar1, None) per mesh
        # -- the differential matrix d is treated as a constant.
        ret = [None, None, None]

        for i in range(ctx.nb):
            # derivative wrt eigenvalues
            dstar0, dstar1 = single_backward(
                ctx.dx[i], ctx.eigenvalues[i], ctx.eigenvectors[i], ctx.n_eig,
                grad_output[2*i], grad_output[2*i+1], ctx.device)
            ret.extend([dstar0, dstar1, None])

        return tuple(ret)
198 |
--------------------------------------------------------------------------------
/train_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | import datetime
4 | import os
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | from torch.utils.data import DataLoader
12 | from torch.utils.tensorboard import SummaryWriter
13 | from tqdm import tqdm
14 |
15 | from hodgenet import HodgeNetModel
16 | from meshdata import HodgenetMeshDataset
17 |
18 |
def main(args):
    """Train the SHREC mesh-classification model.

    Splits labeled meshes into train/validation per category, trains a
    HodgeNet encoder with an MLP classifier head under cross-entropy loss,
    and checkpoints after every epoch; validation runs every 5th epoch.

    Args:
        args: argparse.Namespace with the flags declared in __main__.
    """
    torch.set_default_dtype(torch.float64)  # needed for eigenvalue problems
    torch.manual_seed(1)  # for repeatability
    np.random.seed(1)
    # NOTE(review): torch/numpy seeds are hard-coded to 1; args.seed only
    # controls the random train/val split below.
    random.seed(args.seed)

    mesh_files_train = []
    labels_train = []
    mesh_files_val = []
    labels_val = []

    # labels.txt contains "filename label" pairs; group files by label.
    labeled_files = defaultdict(list)
    with open(os.path.join(args.data, 'labels.txt'), 'r') as labels:
        for y in labels:
            f, i = y.strip().split()
            i = int(i)
            labeled_files[i].append(os.path.join(args.data, f))

    # Per-category shuffle, then the first train_size files go to training
    # and the remainder to validation.
    for label, files in labeled_files.items():
        random.shuffle(files)

        mesh_files_train.extend(files[:args.train_size])
        labels_train.extend([label]*args.train_size)
        mesh_files_val.extend(files[args.train_size:])
        labels_val.extend([label]*len(files[args.train_size:]))

    features = ['vertices', 'normals']

    # Training set: decimation + stretch augmentation, unless fine-tuning
    # (then the original meshes are used unaugmented).
    dataset = HodgenetMeshDataset(
        mesh_files_train,
        decimate_range=None if args.fine_tune is not None else (450, 500),
        edge_features_from_vertex_features=features,
        triangle_features_from_vertex_features=features,
        max_stretch=0 if args.fine_tune is not None else 0.05,
        random_rotation=True, mesh_features={'category': labels_train},
        normalize_coords=True)

    # Validation set: no augmentation.
    validation = HodgenetMeshDataset(
        mesh_files_val, decimate_range=None,
        edge_features_from_vertex_features=features,
        triangle_features_from_vertex_features=features, max_stretch=0,
        random_rotation=False, mesh_features={'category': labels_val},
        normalize_coords=True)

    # Identity collate: a batch stays a list of per-mesh dicts, since
    # meshes have varying sizes and cannot be stacked into tensors.
    def mycollate(b): return b
    dataloader = DataLoader(dataset, batch_size=args.bs,
                            num_workers=args.num_workers, shuffle=True,
                            collate_fn=mycollate)
    validationloader = DataLoader(
        validation, batch_size=args.bs,
        num_workers=args.num_workers, collate_fn=mycollate)

    # Probe one example to size the model's input feature dimensions.
    example = dataset[0]
    hodgenet_model = HodgeNetModel(
        example['int_edge_features'].shape[1],
        example['triangle_features'].shape[1],
        num_output_features=args.n_out_features, mesh_feature=True,
        num_eigenvectors=args.n_eig, num_extra_eigenvectors=args.n_extra_eig,
        num_vector_dimensions=args.num_vector_dimensions)

    # Classifier head on top of pooled per-mesh HodgeNet features; one
    # output logit per category found in labels.txt.
    model = nn.Sequential(hodgenet_model,
                          nn.Linear(args.n_out_features *
                                    args.num_vector_dimensions**2, 64),
                          nn.BatchNorm1d(64),
                          nn.LeakyReLU(),
                          nn.Linear(64, 64),
                          nn.BatchNorm1d(64),
                          nn.LeakyReLU(),
                          nn.Linear(64, len(labeled_files)))

    # categorical variables
    loss = nn.CrossEntropyLoss()

    # optimization routine
    print(sum(x.numel() for x in model.parameters()), 'parameters')
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    # Optionally resume model/optimizer state from a checkpoint.
    if args.fine_tune is not None:
        checkpoint = torch.load(args.fine_tune)
        model.load_state_dict(
            checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['opt_state_dict'])
        starting_epoch = checkpoint['epoch'] + 1
        print(f'Fine tuning! Starting at epoch {starting_epoch}')
    else:
        starting_epoch = 0

    if not os.path.exists(args.out):
        os.makedirs(args.out)

    train_writer = SummaryWriter(os.path.join(
        args.out, datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')),
        flush_secs=1)
    val_writer = SummaryWriter(os.path.join(
        args.out, datetime.datetime.now().strftime('val-orig-%m%d%y-%H%M%S')),
        flush_secs=1)

    def epoch_loop(dataloader, epochname, epochnum, writer, optimize=True):
        # One pass over dataloader; updates weights when optimize=True.
        epoch_loss, epoch_acc, epoch_size = 0, 0, 0
        pbar = tqdm(total=len(dataloader), desc=f'{epochname} {epochnum}')
        for batch in dataloader:
            if optimize:
                optimizer.zero_grad()

            batch_loss, batch_acc = 0, 0

            class_estimate = model(batch)
            labels = torch.tensor([x['category'] for x in batch])
            # CrossEntropyLoss averages over the batch; scale back up so
            # epoch_loss accumulates a per-mesh total.
            batch_loss = loss(class_estimate, labels) * len(batch)
            batch_acc = (class_estimate.argmax(1) == labels).float().sum()

            epoch_loss += batch_loss.item()
            epoch_acc += batch_acc.item()
            epoch_size += len(batch)

            # Back to per-mesh averages for the optimizer step and display.
            batch_loss /= len(batch)
            batch_acc /= len(batch)

            pbar.set_postfix({
                'loss': batch_loss.item(),
                'accuracy': batch_acc.item(),
            })
            pbar.update(1)

            if optimize:
                batch_loss.backward()
                # Clip gradients to stabilize backprop through the
                # eigen-decomposition.
                nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()

        # Log per-mesh epoch averages once per epoch.
        writer.add_scalar('loss', epoch_loss / epoch_size, epochnum)
        writer.add_scalar('accuracy', epoch_acc / epoch_size, epochnum)

        pbar.close()

    # Inclusive range: runs n_epochs+1 training epochs in total.
    for epoch in range(starting_epoch, starting_epoch+args.n_epochs+1):
        model.train()
        epoch_loop(dataloader, 'epoch', epoch, train_writer)

        # compute validation score
        if epoch % 5 == 0:
            model.eval()
            with torch.no_grad():
                epoch_loop(validationloader, 'validation', epoch, val_writer,
                           optimize=False)

        # Checkpoint model + optimizer state every epoch.
        torch.save({
            'model_state_dict': model.state_dict(),
            'opt_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }, os.path.join(args.out,
                        f'{epoch}_finetune.pth'
                        if args.fine_tune is not None else f'{epoch}.pth'))
171 |
172 |
if __name__ == '__main__':
    # CLI for the SHREC classification trainer; flag order and defaults
    # are unchanged, so `--help` output is identical.
    parser = argparse.ArgumentParser()
    for flag, flag_type, flag_default in [
            ('--out', str, 'out/shrec16'),
            ('--data', str, 'data/shrec'),
            ('--bs', int, 16),
            ('--lr', float, 1e-4),
            ('--n_epochs', int, 100),
            ('--n_eig', int, 32),
            ('--n_extra_eig', int, 32),
            ('--n_out_features', int, 32),
            ('--fine_tune', str, None),
            ('--num_workers', int, 0),
            ('--num_vector_dimensions', int, 4),
            ('--train_size', int, 16),
            ('--seed', int, 123)]:
        parser.add_argument(flag, type=flag_type, default=flag_default)

    main(parser.parse_args())
191 |
--------------------------------------------------------------------------------
/train_segmentation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 | from torch.utils.tensorboard import SummaryWriter
11 | from tqdm import tqdm
12 |
13 | from hodgenet import HodgeNetModel
14 | from meshdata import HodgenetMeshDataset
15 |
16 |
def main(args):
    """Train the COSEG mesh-segmentation model.

    Splits meshes into train/validation (~85/15 by sorted filename), trains
    a HodgeNet encoder with a per-triangle MLP head under cross-entropy
    loss, and checkpoints after every epoch; validation runs every 5th
    epoch.

    Args:
        args: argparse.Namespace with the flags declared in __main__.
    """
    torch.set_default_dtype(torch.float64)  # needed for eigenvalue problems
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    mesh_files_train = []
    seg_files_train = []
    mesh_files_val = []
    seg_files_val = []

    # Pair each mesh (.off) with its segmentation file (.seg) by basename.
    files = sorted([f.split('.')[0] for f in os.listdir(args.mesh_path)])
    # round(x + 0.49) acts as a ceiling here: at least 85% go to training.
    cutoff = round(0.85 * len(files) + 0.49)

    for i in files[:cutoff]:
        mesh_files_train.append(os.path.join(args.mesh_path, f'{i}.off'))
        seg_files_train.append(os.path.join(args.seg_path, f'{i}.seg'))
    for i in files[cutoff:]:
        mesh_files_val.append(os.path.join(args.mesh_path, f'{i}.off'))
        seg_files_val.append(os.path.join(args.seg_path, f'{i}.seg'))

    features = ['vertices'] if args.no_normals else ['vertices', 'normals']

    # Training set: decimation + stretch augmentation, unless fine-tuning
    # (then the original meshes are used unaugmented).
    dataset = HodgenetMeshDataset(
        mesh_files_train,
        decimate_range=None if args.fine_tune is not None else (1000, 99999),
        edge_features_from_vertex_features=features,
        triangle_features_from_vertex_features=features,
        max_stretch=0 if args.fine_tune is not None else 0.05,
        random_rotation=False, segmentation_files=seg_files_train,
        normalize_coords=True)

    # Validation set: no augmentation.
    validation = HodgenetMeshDataset(
        mesh_files_val, decimate_range=None,
        edge_features_from_vertex_features=features,
        triangle_features_from_vertex_features=features, max_stretch=0,
        random_rotation=False, segmentation_files=seg_files_val,
        normalize_coords=True)

    # Identity collate: a batch stays a list of per-mesh dicts, since
    # meshes have varying sizes and cannot be stacked into tensors.
    def mycollate(b): return b
    dataloader = DataLoader(dataset, batch_size=args.bs,
                            num_workers=args.num_workers, shuffle=True,
                            collate_fn=mycollate)
    validationloader = DataLoader(validation, batch_size=args.bs,
                                  num_workers=args.num_workers,
                                  collate_fn=mycollate)

    # Probe one example to size the model's input feature dimensions.
    example = dataset[0]
    hodgenet = HodgeNetModel(
        example['int_edge_features'].shape[1],
        example['triangle_features'].shape[1],
        num_output_features=args.n_out_features, mesh_feature=False,
        num_eigenvectors=args.n_eig, num_extra_eigenvectors=args.n_extra_eig,
        resample_to_triangles=True,
        num_vector_dimensions=args.num_vector_dimensions)

    # Per-triangle MLP head predicting one segmentation class per face.
    model = nn.Sequential(
        hodgenet,
        nn.Linear(args.n_out_features*args.num_vector_dimensions *
                  args.num_vector_dimensions, 32),
        nn.BatchNorm1d(32),
        nn.LeakyReLU(),
        nn.Linear(32, 32),
        nn.BatchNorm1d(32),
        nn.LeakyReLU(),
        nn.Linear(32, dataset.n_seg_categories))

    # categorical variables
    loss = nn.CrossEntropyLoss()

    # optimization routine
    print(sum(x.numel() for x in model.parameters()), 'parameters')
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    if not os.path.exists(args.out):
        os.makedirs(args.out)

    # Optionally resume model/optimizer state from a checkpoint.
    if args.fine_tune is not None:
        checkpoint = torch.load(os.path.join(args.fine_tune))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['opt_state_dict'])
        starting_epoch = checkpoint['epoch'] + 1
        print(f'Fine tuning! Starting at epoch {starting_epoch}')
    else:
        starting_epoch = 0

    train_writer = SummaryWriter(os.path.join(
        args.out, datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')),
        flush_secs=1)
    val_writer = SummaryWriter(os.path.join(
        args.out, datetime.datetime.now().strftime('val-%m%d%y-%H%M%S')),
        flush_secs=1)

    def epoch_loop(dataloader, epochname, epochnum, writer, optimize=True):
        # One pass over dataloader; updates weights when optimize=True.
        epoch_loss, epoch_acc, epoch_acc_weighted, epoch_size = 0, 0, 0, 0
        pbar = tqdm(total=len(dataloader), desc=f'{epochname} {epochnum}')
        for batch in dataloader:
            if optimize:
                optimizer.zero_grad()

            batch_loss, batch_acc, batch_acc_weighted = 0, 0, 0

            # The model emits one row per triangle across the whole batch;
            # split back into per-mesh chunks by each mesh's face count.
            seg_estimates = torch.split(model(batch), [m['triangles'].shape[0]
                                                       for m in batch], dim=0)
            for mesh, seg_estimate in zip(batch, seg_estimates):
                gt_segs = mesh['segmentation'].squeeze(-1)
                areas = mesh['areas']
                batch_loss += loss(seg_estimate, gt_segs)
                batch_acc += (seg_estimate.argmax(1) == gt_segs).float().mean()
                # Area-weighted accuracy: larger triangles count more.
                batch_acc_weighted += ((seg_estimate.argmax(1) == gt_segs)
                                       * areas).sum() / areas.sum()

            epoch_loss += batch_loss.item()
            epoch_acc += batch_acc.item()
            epoch_acc_weighted += batch_acc_weighted.item()
            epoch_size += len(batch)

            # Back to per-mesh averages for the optimizer step and display.
            batch_loss /= len(batch)
            batch_acc /= len(batch)
            batch_acc_weighted /= len(batch)

            pbar.set_postfix({
                'loss': batch_loss.item(),
                'accuracy': batch_acc.item(),
                'accuracy_weighted': batch_acc_weighted.item(),
            })
            pbar.update(1)

            if optimize:
                batch_loss.backward()
                # Clip gradients to stabilize backprop through the
                # eigen-decomposition.
                nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()

        # Log per-mesh epoch averages once per epoch.
        writer.add_scalar('loss', epoch_loss / epoch_size, epochnum)
        writer.add_scalar('accuracy', epoch_acc / epoch_size, epochnum)
        writer.add_scalar('accuracy_weighted',
                          epoch_acc_weighted / epoch_size, epochnum)

        pbar.close()

    # Inclusive range: runs n_epochs+1 training epochs in total.
    for epoch in range(starting_epoch, starting_epoch+args.n_epochs+1):
        model.train()
        epoch_loop(dataloader, 'Epoch', epoch, train_writer)

        # compute validation score
        if epoch % 5 == 0:
            model.eval()
            with torch.no_grad():
                epoch_loop(validationloader, 'Validation',
                           epoch, val_writer, optimize=False)

        # Checkpoint model + optimizer state every epoch.
        torch.save({
            'model_state_dict': model.state_dict(),
            'opt_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }, os.path.join(args.out,
                        f'{epoch}_finetune.pth'
                        if args.fine_tune is not None else f'{epoch}.pth'))
174 |
175 |
if __name__ == '__main__':
    # CLI for the COSEG segmentation trainer; flag order and defaults
    # are unchanged, so `--help` output is identical.
    parser = argparse.ArgumentParser()
    for flag, flag_type, flag_default in [
            ('--out', str, 'out/vase'),
            ('--mesh_path', str, 'data/coseg_vase'),
            ('--seg_path', str, 'data/coseg_vase_gt'),
            ('--bs', int, 16),
            ('--lr', float, 1e-4),
            ('--n_epochs', int, 100),
            ('--n_eig', int, 32),
            ('--n_extra_eig', int, 32),
            ('--n_out_features', int, 32),
            ('--fine_tune', str, None),
            ('--num_workers', int, 0),
            ('--num_vector_dimensions', int, 4),
            ('--seed', int, 123)]:
        parser.add_argument(flag, type=flag_type, default=flag_default)
    # Boolean switch declared separately (no type= for store_true flags).
    parser.add_argument('--no_normals', action='store_true', default=False)

    main(parser.parse_args())
195 |
--------------------------------------------------------------------------------
/hodgenet.py:
--------------------------------------------------------------------------------
1 | import scipy
2 | import scipy.sparse.linalg
3 | import torch
4 | import torch.nn as nn
5 |
6 | from hodgeautograd import HodgeEigensystem
7 |
8 |
class HodgeNetModel(nn.Module):
    """Main HodgeNet model.

    The model inputs a batch of meshes and outputs features per vertex or
    pooled to faces or the entire mesh.

    Pipeline (see ``forward``):
      1. Small MLPs predict per-triangle and per-edge "Hodge star" blocks
         from the input features.
      2. A Laplacian-like operator is assembled per mesh and its lowest
         eigenpairs are computed differentiably via ``HodgeEigensystem``.
      3. Eigenvector outer products are weighted by an MLP-transformed
         function of the eigenvalues (similar to HKS/WKS) to produce the
         output features.
    """
    def __init__(self, num_edge_features, num_triangle_features,
                 num_output_features=32, num_eigenvectors=64,
                 num_extra_eigenvectors=16, mesh_feature=False, min_star=1e-2,
                 resample_to_triangles=False, num_bdry_edge_features=None,
                 num_vector_dimensions=1):
        """Args:
            num_edge_features: size of the per-interior-edge input features.
            num_triangle_features: size of the per-triangle input features.
            num_output_features: size of each output feature vector.
            num_eigenvectors: number of eigenpairs used downstream.
            num_extra_eigenvectors: additional eigenpairs passed to the
                solver (presumably for accuracy/stability — see
                ``HodgeEigensystem``).
            mesh_feature: if True, max-pool the result to one vector per mesh.
            min_star: value added down star-matrix diagonals to keep the
                learned operators positive definite.
            resample_to_triangles: if True, max-pool vertex features onto
                each triangle's corners.
            num_bdry_edge_features: size of per-boundary-edge features, or
                None when meshes have no boundary.
            num_vector_dimensions: side length of each learned star block.
        """
        super().__init__()

        self.num_triangle_features = num_triangle_features
        self.hodgefunc = HodgeEigensystem.apply  # differentiable eigensolver
        self.num_eigenvectors = num_eigenvectors
        self.num_extra_eigenvectors = num_extra_eigenvectors
        self.num_output_features = num_output_features
        self.min_star = min_star
        self.resample_to_triangles = resample_to_triangles
        self.mesh_feature = mesh_feature
        self.num_vector_dimensions = num_vector_dimensions

        # All learned maps share one MLP architecture; only input/output
        # (and, for the eigenvalue map, hidden) widths differ.
        self.to_star1 = self._make_mlp(
            num_edge_features, num_vector_dimensions**2)

        if num_bdry_edge_features is not None:
            self.to_star1_bdry = self._make_mlp(
                num_bdry_edge_features, num_vector_dimensions**2)
        else:
            self.to_star1_bdry = None

        self.to_star0_tri = self._make_mlp(
            num_triangle_features, num_vector_dimensions**2)

        self.eigenvalue_to_matrix = self._make_mlp(
            1, num_output_features, num_hidden=num_output_features)

    @staticmethod
    def _make_mlp(num_in, num_out, num_hidden=32):
        """Build the shared 4-hidden-layer MLP.

        Layout is Linear -> BatchNorm1d -> LeakyReLU repeated four times,
        followed by a final Linear.  The resulting ``nn.Sequential`` has the
        exact same child indices as the original hand-written branches, so
        existing checkpoints (``to_star1.0.weight``, ...) keep loading.
        """
        layers = []
        for i in range(4):
            layers.append(
                nn.Linear(num_in if i == 0 else num_hidden, num_hidden))
            layers.append(nn.BatchNorm1d(num_hidden))
            layers.append(nn.LeakyReLU())
        layers.append(nn.Linear(num_hidden, num_out))
        return nn.Sequential(*layers)

    def _psd_with_floor(self, m):
        """Map a batch of square matrices m to ``m @ m^T + min_star * I``.

        Squaring makes each block symmetric positive semidefinite; the
        ``min_star`` diagonal floor makes it strictly positive definite.
        """
        m = torch.einsum('ijk,ilk->ijl', m, m)  # batched m @ m^T
        return m + torch.eye(self.num_vector_dimensions)[None].to(m) * \
            self.min_star

    def gather_star0(self, mesh, star0_tri):
        """Compute star0 matrix per vertex by gathering from triangles.

        Each triangle's predicted block is summed into its three corner
        vertices, then the per-vertex sums are made positive definite.
        """
        star0 = torch.zeros(mesh['vertices'].shape[0],
                            star0_tri.shape[1]).to(star0_tri)
        star0.index_add_(0, mesh['triangles'][:, 0], star0_tri)
        star0.index_add_(0, mesh['triangles'][:, 1], star0_tri)
        star0.index_add_(0, mesh['triangles'][:, 2], star0_tri)

        star0 = star0.view(-1, self.num_vector_dimensions,
                           self.num_vector_dimensions)

        return self._psd_with_floor(star0)

    def compute_mesh_eigenfunctions(self, mesh, star0, star1, bdry=False):
        """Compute eigenvectors and eigenvalues of the learned operator.

        Args:
            mesh: list of mesh dictionaries (provides the sparse d01 maps).
            star0: list of per-vertex star blocks, one tensor per mesh.
            star1: list of per-edge star blocks, one tensor per mesh
                (interior edges first, then boundary edges when bdry=True).
            bdry: whether to stack the boundary d01 under the interior d01.

        Returns:
            (eigenvalues, eigenvectors) lists with one entry per mesh.
        """
        nb = len(mesh)

        # Interleave (star0, star1, d) triples the way the autograd
        # function expects its *inputs.
        inputs = []
        for m, s0, s1 in zip(mesh, star0, star1):
            d = m['int_d01']
            if bdry:
                d = scipy.sparse.vstack([d, m['bdry_d01']])
            inputs.extend([s0, s1, d])

        eigenvalues, eigenvectors = [], []
        outputs = self.hodgefunc(nb, self.num_eigenvectors,
                                 self.num_extra_eigenvectors, *inputs)
        # Outputs come back flattened as (values_0, vectors_0, values_1, ...).
        for i in range(nb):
            eigenvalues.append(outputs[2*i])
            eigenvectors.append(outputs[2*i+1])

        return eigenvalues, eigenvectors

    def forward(self, batch):
        """Run HodgeNet on a batch (list) of mesh dictionaries.

        Returns one tensor with the per-vertex (or per-triangle / per-mesh,
        depending on configuration) features of all meshes concatenated
        along dim 0.
        """
        nb = len(batch)

        # Per-triangle star0 contributions, computed for the whole batch at
        # once and split back per mesh.
        all_star0_tri = self.to_star0_tri(
            torch.cat([mesh['triangle_features'] for mesh in batch], dim=0))
        star0_tri_split = torch.split(
            all_star0_tri, [mesh['triangles'].shape[0] for mesh in batch],
            dim=0)
        star0_split = [self.gather_star0(mesh, star0_tri)
                       for mesh, star0_tri in zip(batch, star0_tri_split)]

        # Per-interior-edge star1 blocks, made positive definite.
        all_star1 = self.to_star1(torch.cat([mesh['int_edge_features']
                                             for mesh in batch], dim=0))
        all_star1 = self._psd_with_floor(
            all_star1.view(-1, self.num_vector_dimensions,
                           self.num_vector_dimensions))
        star1_split = list(torch.split(all_star1, [mesh['int_d01'].shape[0]
                                                   for mesh in batch], dim=0))

        # Optional boundary-edge star1 blocks, appended after the interior
        # ones to match the stacked [int_d01; bdry_d01] operator.
        if self.to_star1_bdry is not None:
            all_star1_bdry = self.to_star1_bdry(
                torch.cat([mesh['bdry_edge_features'] for mesh in batch],
                          dim=0))
            all_star1_bdry = self._psd_with_floor(
                all_star1_bdry.view(-1, self.num_vector_dimensions,
                                    self.num_vector_dimensions))
            star1_bdry_split = torch.split(
                all_star1_bdry,
                [mesh['bdry_d01'].shape[0] for mesh in batch], dim=0)

            for i in range(nb):
                star1_split[i] = torch.cat(
                    [star1_split[i], star1_bdry_split[i]], dim=0)

        eigenvalues, eigenvectors = self.compute_mesh_eigenfunctions(
            batch, star0_split, star1_split,
            bdry=self.to_star1_bdry is not None)

        # glue the eigenvalues back together and run through the nonlinearity
        all_processed_eigenvalues = self.eigenvalue_to_matrix(
            torch.stack(eigenvalues).view(-1, 1)).view(
                nb, -1, self.num_output_features)

        # post-multiply the set of eigenvectors by the learned matrix that's a
        # function of eigenvalues (similar to HKS, WKS)
        outer_products = [torch.einsum(
            'ijk,ijl->ijkl', eigenvectors[i], eigenvectors[i])
            for i in range(nb)]  # take outer product of vectors

        result = [torch.einsum(
            'ijkp,jl->ilkp', outer_products[i], all_processed_eigenvalues[i])
            for i in range(nb)]  # multiply by learned matrix

        result = [result[i].flatten(start_dim=1) for i in range(nb)]

        if self.resample_to_triangles:
            # Max-pool vertex features onto each triangle's three corners.
            result = [result[i][batch[i]['triangles']].max(
                1)[0] for i in range(nb)]

        if self.mesh_feature:
            # Global max-pool to a single feature vector per mesh.
            result = [f.max(0, keepdim=True)[0] for f in result]

        return torch.cat(result, dim=0)
210 |
--------------------------------------------------------------------------------
/square.obj:
--------------------------------------------------------------------------------
1 | v -0.500000 -0.000000 0.500000
2 | v -0.400000 -0.000000 0.500000
3 | v -0.300000 -0.000000 0.500000
4 | v -0.200000 -0.000000 0.500000
5 | v -0.100000 -0.000000 0.500000
6 | v 0.000000 -0.000000 0.500000
7 | v 0.100000 -0.000000 0.500000
8 | v 0.200000 -0.000000 0.500000
9 | v 0.300000 -0.000000 0.500000
10 | v 0.400000 -0.000000 0.500000
11 | v 0.500000 -0.000000 0.500000
12 | v -0.500000 -0.000000 0.400000
13 | v -0.400000 -0.000000 0.400000
14 | v -0.300000 -0.000000 0.400000
15 | v -0.200000 -0.000000 0.400000
16 | v -0.100000 -0.000000 0.400000
17 | v 0.000000 -0.000000 0.400000
18 | v 0.100000 -0.000000 0.400000
19 | v 0.200000 -0.000000 0.400000
20 | v 0.300000 -0.000000 0.400000
21 | v 0.400000 -0.000000 0.400000
22 | v 0.500000 -0.000000 0.400000
23 | v -0.500000 -0.000000 0.300000
24 | v -0.400000 -0.000000 0.300000
25 | v -0.300000 -0.000000 0.300000
26 | v -0.200000 -0.000000 0.300000
27 | v -0.100000 -0.000000 0.300000
28 | v 0.000000 -0.000000 0.300000
29 | v 0.100000 -0.000000 0.300000
30 | v 0.200000 -0.000000 0.300000
31 | v 0.300000 -0.000000 0.300000
32 | v 0.400000 -0.000000 0.300000
33 | v 0.500000 -0.000000 0.300000
34 | v -0.500000 -0.000000 0.200000
35 | v -0.400000 -0.000000 0.200000
36 | v -0.300000 -0.000000 0.200000
37 | v -0.200000 -0.000000 0.200000
38 | v -0.100000 -0.000000 0.200000
39 | v 0.000000 -0.000000 0.200000
40 | v 0.100000 -0.000000 0.200000
41 | v 0.200000 -0.000000 0.200000
42 | v 0.300000 -0.000000 0.200000
43 | v 0.400000 -0.000000 0.200000
44 | v 0.500000 -0.000000 0.200000
45 | v -0.500000 -0.000000 0.100000
46 | v -0.400000 -0.000000 0.100000
47 | v -0.300000 -0.000000 0.100000
48 | v -0.200000 -0.000000 0.100000
49 | v -0.100000 -0.000000 0.100000
50 | v 0.000000 -0.000000 0.100000
51 | v 0.100000 -0.000000 0.100000
52 | v 0.200000 -0.000000 0.100000
53 | v 0.300000 -0.000000 0.100000
54 | v 0.400000 -0.000000 0.100000
55 | v 0.500000 -0.000000 0.100000
56 | v -0.500000 0.000000 0.000000
57 | v -0.400000 0.000000 0.000000
58 | v -0.300000 0.000000 0.000000
59 | v -0.200000 0.000000 0.000000
60 | v -0.100000 0.000000 0.000000
61 | v 0.000000 0.000000 0.000000
62 | v 0.100000 0.000000 0.000000
63 | v 0.200000 0.000000 0.000000
64 | v 0.300000 0.000000 0.000000
65 | v 0.400000 0.000000 0.000000
66 | v 0.500000 0.000000 0.000000
67 | v -0.500000 0.000000 -0.100000
68 | v -0.400000 0.000000 -0.100000
69 | v -0.300000 0.000000 -0.100000
70 | v -0.200000 0.000000 -0.100000
71 | v -0.100000 0.000000 -0.100000
72 | v 0.000000 0.000000 -0.100000
73 | v 0.100000 0.000000 -0.100000
74 | v 0.200000 0.000000 -0.100000
75 | v 0.300000 0.000000 -0.100000
76 | v 0.400000 0.000000 -0.100000
77 | v 0.500000 0.000000 -0.100000
78 | v -0.500000 0.000000 -0.200000
79 | v -0.400000 0.000000 -0.200000
80 | v -0.300000 0.000000 -0.200000
81 | v -0.200000 0.000000 -0.200000
82 | v -0.100000 0.000000 -0.200000
83 | v 0.000000 0.000000 -0.200000
84 | v 0.100000 0.000000 -0.200000
85 | v 0.200000 0.000000 -0.200000
86 | v 0.300000 0.000000 -0.200000
87 | v 0.400000 0.000000 -0.200000
88 | v 0.500000 0.000000 -0.200000
89 | v -0.500000 0.000000 -0.300000
90 | v -0.400000 0.000000 -0.300000
91 | v -0.300000 0.000000 -0.300000
92 | v -0.200000 0.000000 -0.300000
93 | v -0.100000 0.000000 -0.300000
94 | v 0.000000 0.000000 -0.300000
95 | v 0.100000 0.000000 -0.300000
96 | v 0.200000 0.000000 -0.300000
97 | v 0.300000 0.000000 -0.300000
98 | v 0.400000 0.000000 -0.300000
99 | v 0.500000 0.000000 -0.300000
100 | v -0.500000 0.000000 -0.400000
101 | v -0.400000 0.000000 -0.400000
102 | v -0.300000 0.000000 -0.400000
103 | v -0.200000 0.000000 -0.400000
104 | v -0.100000 0.000000 -0.400000
105 | v 0.000000 0.000000 -0.400000
106 | v 0.100000 0.000000 -0.400000
107 | v 0.200000 0.000000 -0.400000
108 | v 0.300000 0.000000 -0.400000
109 | v 0.400000 0.000000 -0.400000
110 | v 0.500000 0.000000 -0.400000
111 | v -0.500000 0.000000 -0.500000
112 | v -0.400000 0.000000 -0.500000
113 | v -0.300000 0.000000 -0.500000
114 | v -0.200000 0.000000 -0.500000
115 | v -0.100000 0.000000 -0.500000
116 | v 0.000000 0.000000 -0.500000
117 | v 0.100000 0.000000 -0.500000
118 | v 0.200000 0.000000 -0.500000
119 | v 0.300000 0.000000 -0.500000
120 | v 0.400000 0.000000 -0.500000
121 | v 0.500000 0.000000 -0.500000
122 | f 1/1 2/2 12/12
123 | f 12/12 2/2 13/13
124 | f 2/2 3/3 13/13
125 | f 13/13 3/3 14/14
126 | f 3/3 4/4 14/14
127 | f 14/14 4/4 15/15
128 | f 4/4 5/5 15/15
129 | f 15/15 5/5 16/16
130 | f 5/5 6/6 16/16
131 | f 16/16 6/6 17/17
132 | f 6/6 7/7 17/17
133 | f 17/17 7/7 18/18
134 | f 7/7 8/8 18/18
135 | f 18/18 8/8 19/19
136 | f 8/8 9/9 19/19
137 | f 19/19 9/9 20/20
138 | f 9/9 10/10 20/20
139 | f 20/20 10/10 21/21
140 | f 10/10 11/11 21/21
141 | f 21/21 11/11 22/22
142 | f 12/12 13/13 23/23
143 | f 23/23 13/13 24/24
144 | f 13/13 14/14 24/24
145 | f 24/24 14/14 25/25
146 | f 14/14 15/15 25/25
147 | f 25/25 15/15 26/26
148 | f 15/15 16/16 26/26
149 | f 26/26 16/16 27/27
150 | f 16/16 17/17 27/27
151 | f 27/27 17/17 28/28
152 | f 17/17 18/18 28/28
153 | f 28/28 18/18 29/29
154 | f 18/18 19/19 29/29
155 | f 29/29 19/19 30/30
156 | f 19/19 20/20 30/30
157 | f 30/30 20/20 31/31
158 | f 20/20 21/21 31/31
159 | f 31/31 21/21 32/32
160 | f 21/21 22/22 32/32
161 | f 32/32 22/22 33/33
162 | f 23/23 24/24 34/34
163 | f 34/34 24/24 35/35
164 | f 24/24 25/25 35/35
165 | f 35/35 25/25 36/36
166 | f 25/25 26/26 36/36
167 | f 36/36 26/26 37/37
168 | f 26/26 27/27 37/37
169 | f 37/37 27/27 38/38
170 | f 27/27 28/28 38/38
171 | f 38/38 28/28 39/39
172 | f 28/28 29/29 39/39
173 | f 39/39 29/29 40/40
174 | f 29/29 30/30 40/40
175 | f 40/40 30/30 41/41
176 | f 30/30 31/31 41/41
177 | f 41/41 31/31 42/42
178 | f 31/31 32/32 42/42
179 | f 42/42 32/32 43/43
180 | f 32/32 33/33 43/43
181 | f 43/43 33/33 44/44
182 | f 34/34 35/35 45/45
183 | f 45/45 35/35 46/46
184 | f 35/35 36/36 46/46
185 | f 46/46 36/36 47/47
186 | f 36/36 37/37 47/47
187 | f 47/47 37/37 48/48
188 | f 37/37 38/38 48/48
189 | f 48/48 38/38 49/49
190 | f 38/38 39/39 49/49
191 | f 49/49 39/39 50/50
192 | f 39/39 40/40 50/50
193 | f 50/50 40/40 51/51
194 | f 40/40 41/41 51/51
195 | f 51/51 41/41 52/52
196 | f 41/41 42/42 52/52
197 | f 52/52 42/42 53/53
198 | f 42/42 43/43 53/53
199 | f 53/53 43/43 54/54
200 | f 43/43 44/44 54/54
201 | f 54/54 44/44 55/55
202 | f 45/45 46/46 56/56
203 | f 56/56 46/46 57/57
204 | f 46/46 47/47 57/57
205 | f 57/57 47/47 58/58
206 | f 47/47 48/48 58/58
207 | f 58/58 48/48 59/59
208 | f 48/48 49/49 59/59
209 | f 59/59 49/49 60/60
210 | f 49/49 50/50 60/60
211 | f 60/60 50/50 61/61
212 | f 50/50 51/51 61/61
213 | f 61/61 51/51 62/62
214 | f 51/51 52/52 62/62
215 | f 62/62 52/52 63/63
216 | f 52/52 53/53 63/63
217 | f 63/63 53/53 64/64
218 | f 53/53 54/54 64/64
219 | f 64/64 54/54 65/65
220 | f 54/54 55/55 65/65
221 | f 65/65 55/55 66/66
222 | f 56/56 57/57 67/67
223 | f 67/67 57/57 68/68
224 | f 57/57 58/58 68/68
225 | f 68/68 58/58 69/69
226 | f 58/58 59/59 69/69
227 | f 69/69 59/59 70/70
228 | f 59/59 60/60 70/70
229 | f 70/70 60/60 71/71
230 | f 60/60 61/61 71/71
231 | f 71/71 61/61 72/72
232 | f 61/61 62/62 72/72
233 | f 72/72 62/62 73/73
234 | f 62/62 63/63 73/73
235 | f 73/73 63/63 74/74
236 | f 63/63 64/64 74/74
237 | f 74/74 64/64 75/75
238 | f 64/64 65/65 75/75
239 | f 75/75 65/65 76/76
240 | f 65/65 66/66 76/76
241 | f 76/76 66/66 77/77
242 | f 67/67 68/68 78/78
243 | f 78/78 68/68 79/79
244 | f 68/68 69/69 79/79
245 | f 79/79 69/69 80/80
246 | f 69/69 70/70 80/80
247 | f 80/80 70/70 81/81
248 | f 70/70 71/71 81/81
249 | f 81/81 71/71 82/82
250 | f 71/71 72/72 82/82
251 | f 82/82 72/72 83/83
252 | f 72/72 73/73 83/83
253 | f 83/83 73/73 84/84
254 | f 73/73 74/74 84/84
255 | f 84/84 74/74 85/85
256 | f 74/74 75/75 85/85
257 | f 85/85 75/75 86/86
258 | f 75/75 76/76 86/86
259 | f 86/86 76/76 87/87
260 | f 76/76 77/77 87/87
261 | f 87/87 77/77 88/88
262 | f 78/78 79/79 89/89
263 | f 89/89 79/79 90/90
264 | f 79/79 80/80 90/90
265 | f 90/90 80/80 91/91
266 | f 80/80 81/81 91/91
267 | f 91/91 81/81 92/92
268 | f 81/81 82/82 92/92
269 | f 92/92 82/82 93/93
270 | f 82/82 83/83 93/93
271 | f 93/93 83/83 94/94
272 | f 83/83 84/84 94/94
273 | f 94/94 84/84 95/95
274 | f 84/84 85/85 95/95
275 | f 95/95 85/85 96/96
276 | f 85/85 86/86 96/96
277 | f 96/96 86/86 97/97
278 | f 86/86 87/87 97/97
279 | f 97/97 87/87 98/98
280 | f 87/87 88/88 98/98
281 | f 98/98 88/88 99/99
282 | f 89/89 90/90 100/100
283 | f 100/100 90/90 101/101
284 | f 90/90 91/91 101/101
285 | f 101/101 91/91 102/102
286 | f 91/91 92/92 102/102
287 | f 102/102 92/92 103/103
288 | f 92/92 93/93 103/103
289 | f 103/103 93/93 104/104
290 | f 93/93 94/94 104/104
291 | f 104/104 94/94 105/105
292 | f 94/94 95/95 105/105
293 | f 105/105 95/95 106/106
294 | f 95/95 96/96 106/106
295 | f 106/106 96/96 107/107
296 | f 96/96 97/97 107/107
297 | f 107/107 97/97 108/108
298 | f 97/97 98/98 108/108
299 | f 108/108 98/98 109/109
300 | f 98/98 99/99 109/109
301 | f 109/109 99/99 110/110
302 | f 100/100 101/101 111/111
303 | f 111/111 101/101 112/112
304 | f 101/101 102/102 112/112
305 | f 112/112 102/102 113/113
306 | f 102/102 103/103 113/113
307 | f 113/113 103/103 114/114
308 | f 103/103 104/104 114/114
309 | f 114/114 104/104 115/115
310 | f 104/104 105/105 115/115
311 | f 115/115 105/105 116/116
312 | f 105/105 106/106 116/116
313 | f 116/116 106/106 117/117
314 | f 106/106 107/107 117/117
315 | f 117/117 107/107 118/118
316 | f 107/107 108/108 118/118
317 | f 118/118 108/108 119/119
318 | f 108/108 109/109 119/119
319 | f 119/119 109/109 120/120
320 | f 109/109 110/110 120/120
321 | f 120/120 110/110 121/121
322 |
--------------------------------------------------------------------------------
/meshdata.py:
--------------------------------------------------------------------------------
1 | import igl
2 | import numpy as np
3 | import scipy
4 | import scipy.io
5 | import torch
6 | import trimesh
7 |
8 |
def random_rotation_matrix():
    """Generate a random 3D rotation matrix (uniform over SO(3), det = +1).

    Raw QR of a Gaussian matrix can return a reflection (det = -1) and is
    not Haar-uniform because QR's sign convention is implementation-defined.
    Following Mezzadri's construction, multiply Q's columns by the signs of
    R's diagonal, then flip one column if a reflection remains so the result
    is always a proper rotation.
    """
    Q, R = np.linalg.qr(np.random.normal(size=(3, 3)))
    Q = Q * np.sign(np.diag(R))  # fix QR sign ambiguity -> Haar-uniform
    if np.linalg.det(Q) < 0:
        Q[:, 0] *= -1  # ensure a rotation rather than a reflection
    return Q
13 |
14 |
def random_scale_matrix(max_stretch):
    """Generate a random 3D anisotropic scaling matrix.

    Each axis gets an independent factor drawn uniformly from
    [1 - max_stretch, 1 + max_stretch).
    """
    factors = np.random.rand(3)                     # U[0, 1) per axis
    factors = 1 + (factors * 2 - 1) * max_stretch   # -> U[1-s, 1+s)
    return np.diag(factors)
18 |
19 |
def d01(v, e):
    """Compute the d01 operator from 0-forms to 1-forms.

    Builds the signed edge-vertex incidence matrix of shape
    (num_edges, num_vertices): row i holds +1 at edge i's first endpoint
    and -1 at its second.
    """
    n_edges = e.shape[0]
    rows = np.concatenate([np.arange(n_edges), np.arange(n_edges)])
    cols = np.concatenate([e[:, 0], e[:, 1]])
    vals = np.concatenate([np.ones(n_edges), -np.ones(n_edges)], axis=0)
    return scipy.sparse.csr_matrix(
        (vals, (rows, cols)), dtype=np.double,
        shape=(n_edges, v.shape[0]))
28 |
29 |
def flip(f, ev, ef, bdry=False):
    """Compute vertices of flipped edges.

    For each edge, finds the vertex/vertices *opposite* the edge in its
    adjacent triangle(s): two vertices per interior edge, one per boundary
    edge (``bdry=True``, where ``ef`` carries a single face column — see the
    callers, which slice ``bdry_ef[:, 1, None]``).

    Args:
        f: (num_faces, 3) array of triangle vertex indices.
        ev: (num_edges, 2) array of edge endpoint vertex indices.
        ef: (num_edges, 2) adjacent-face indices, or (num_edges, 1) if bdry.
        bdry: whether these are boundary edges with one adjacent face.

    Returns:
        (num_edges, 1 if bdry else 2) array of opposite vertex indices.
    """
    # glue together the triangle vertices adjacent to each edge
    # (fancy indexing f[ef] copies, so the input f is never mutated)
    duplicate = f[ef].reshape(ef.shape[0], -1)
    duplicate[(duplicate == ev[:, 0, None])
              | (duplicate == ev[:, 1, None])] = -1  # remove edge vertices

    # find the two remaining verts (not -1) in an orientation-preserving way
    # After negation the -1 markers become +1, the strict maximum (all real
    # vertex indices map to values <= 0), so argsort pushes the markers to
    # the end and the leading columns point at the opposite vertices.
    idxs = (-duplicate).argsort(1)[:, :(1 if bdry else 2)]
    idxs.sort(1)  # preserve orientation by doing them in index order
    result = np.take_along_axis(duplicate, idxs, axis=1)

    return result
43 |
44 |
class HodgenetMeshDataset(torch.utils.data.Dataset):
    """Dataset of meshes with labels.

    Loads each mesh from disk, optionally decimates and augments it, and
    returns a dictionary of per-vertex / per-edge / per-triangle tensors
    plus the sparse d01 operators HodgeNet consumes.
    """

    def __init__(self, mesh_files, mesh_features=None, decimate_range=None,
                 random_rotation=True, max_stretch=0.1,
                 edge_features_from_vertex_features=None,
                 triangle_features_from_vertex_features=None,
                 center_vertices=True, normalize_coords=True,
                 segmentation_files=None):
        """Args:
            mesh_files: list of mesh file paths.
            mesh_features: optional dict of name -> per-mesh values (indexed
                by mesh index) copied into every sample.
            decimate_range: optional (lo, hi) face-count range; each mesh is
                randomly decimated into it (retrying until edge-manifold).
            random_rotation: apply random rotation augmentation.
            max_stretch: magnitude of random anisotropic scaling; 0 disables.
            edge_features_from_vertex_features: keys of per-vertex features
                gathered onto edges (default: ['vertices']).
            triangle_features_from_vertex_features: keys of per-vertex
                features gathered onto triangles (default: ['vertices']).
            center_vertices: subtract the vertex centroid.
            normalize_coords: scale vertices into the unit ball.
            segmentation_files: optional list of per-face label files,
                parallel to mesh_files.
        """
        # Defaults are None rather than {} / ['vertices'] to avoid the
        # shared-mutable-default-argument pitfall; behavior is unchanged.
        self.mesh_files = mesh_files
        self.mesh_features = {} if mesh_features is None else mesh_features
        self.decimate_range = decimate_range
        self.random_rotation = random_rotation
        self.max_stretch = max_stretch
        self.edge_features_from_vertex_features = (
            ['vertices'] if edge_features_from_vertex_features is None
            else edge_features_from_vertex_features)
        self.triangle_features_from_vertex_features = (
            ['vertices'] if triangle_features_from_vertex_features is None
            else triangle_features_from_vertex_features)
        self.center_vertices = center_vertices
        self.segmentation_files = segmentation_files
        self.normalize_coords = normalize_coords

        # Scan all label files once to find the label range so that labels
        # can be shifted to start at 0 in __getitem__.
        self.min_category = float('inf')
        self.n_seg_categories = 0
        if self.segmentation_files is not None:
            for f in segmentation_files:
                triangle_data = np.loadtxt(f, dtype=np.int64)
                self.n_seg_categories = max(
                    self.n_seg_categories, triangle_data.max())
                self.min_category = min(self.min_category, triangle_data.min())

            self.n_seg_categories = self.n_seg_categories-self.min_category+1

    def __len__(self):
        """Number of meshes in the dataset."""
        return len(self.mesh_files)

    def __getitem__(self, idx):
        mesh = trimesh.load(self.mesh_files[idx], process=False)
        v_orig, f_orig = mesh.vertices, mesh.faces

        if self.segmentation_files is not None:
            # np.int64 for consistency with the __init__ label scan.
            face_segmentation = np.loadtxt(self.segmentation_files[idx],
                                           dtype=np.int64)
            face_segmentation -= self.min_category  # shift labels to 0-based

        # decimate mesh to desired number of faces in provided range
        if self.decimate_range is not None:
            while True:
                nfaces = np.random.randint(
                    min(f_orig.shape[0], self.decimate_range[0]),
                    min(f_orig.shape[0], self.decimate_range[1]) + 1)

                _, v, f, decimated_f_idxs, _ = igl.decimate(
                    v_orig, f_orig, nfaces)
                # Retry until decimation yields an edge-manifold mesh, which
                # the edge-topology computations below require.
                if igl.is_edge_manifold(f):
                    break

            if self.segmentation_files is not None:
                # Carry each surviving face's label through the decimation.
                face_segmentation = face_segmentation[decimated_f_idxs]
        else:
            v = v_orig
            f = f_orig

        if self.center_vertices:
            v -= v.mean(0)

        if self.normalize_coords:
            v /= np.linalg.norm(v, axis=1).max()

        # random rotation/scaling: rotate, stretch, rotate again so the
        # stretch axes are themselves random
        if self.random_rotation:
            v = v @ random_rotation_matrix()
        if self.max_stretch != 0:
            v = v @ random_scale_matrix(self.max_stretch)
        if self.random_rotation:
            v = v @ random_rotation_matrix()

        areas = igl.doublearea(v, f)

        ev, fe, ef = igl.edge_topology(v, f)

        # Split edges into boundary (one adjacent face marked -1) and
        # interior, and find the opposite vertices for each group.
        bdry_idxs = (ef == -1).any(1)
        if bdry_idxs.sum() > 0:
            bdry_ev = ev[bdry_idxs]
            bdry_ef = ef[bdry_idxs]
            bdry_ef.sort(1)
            bdry_ef = bdry_ef[:, 1, None]  # keep the single valid face index
            bdry_flipped = flip(f, bdry_ev, bdry_ef, bdry=True)

        int_ev = ev[~bdry_idxs]
        int_ef = ef[~bdry_idxs]
        int_flipped = flip(f, int_ev, int_ef)

        # normals
        n = igl.per_vertex_normals(v, f)
        n[np.isnan(n)] = 0  # zero out degenerate normals

        result = {
            'vertices': torch.from_numpy(v),
            'faces': torch.from_numpy(f),
            'areas': torch.from_numpy(areas),
            'int_d01': d01(v, int_ev),
            'triangles': torch.from_numpy(f.astype(np.int64)),
            'normals': torch.from_numpy(n),
            'mesh': self.mesh_files[idx],
            'int_ev': torch.from_numpy(int_ev.astype(np.int64)),
            'int_flipped': torch.from_numpy(int_flipped.astype(np.int64)),
            'f': torch.from_numpy(f.astype(np.int64))
        }

        if self.segmentation_files is not None:
            result['segmentation'] = torch.from_numpy(face_segmentation)
        if bdry_idxs.sum() > 0:
            result['bdry_d01'] = d01(v, bdry_ev)

        # feature per mesh (e.g. label of the mesh)
        for key in self.mesh_features:
            result[key] = self.mesh_features[key][idx]

        # gather vertex features to edges from list of keys: for each edge,
        # concatenate the features of its endpoints and opposite vertices
        result['int_edge_features'] = torch.from_numpy(
            np.concatenate([
                np.concatenate([
                    result[key][torch.from_numpy(int_ev.astype(np.int64))],
                    result[key][torch.from_numpy(int_flipped.astype(np.int64))]
                ], axis=1).reshape(int_ev.shape[0], -1)
                for key in self.edge_features_from_vertex_features
            ], axis=1))

        if bdry_idxs.sum() > 0:
            result['bdry_edge_features'] = torch.from_numpy(
                np.concatenate([
                    np.concatenate([
                        result[key][torch.from_numpy(
                            bdry_ev.astype(np.int64))],
                        result[key][torch.from_numpy(
                            bdry_flipped.astype(np.int64))]
                    ], axis=1).reshape(bdry_ev.shape[0], -1)
                    for key in self.edge_features_from_vertex_features
                ], axis=1))

        # gather vertex features to triangles from list of keys
        result['triangle_features'] = torch.from_numpy(
            np.concatenate(
                [result[key][f].reshape(f.shape[0], -1)
                 for key in self.triangle_features_from_vertex_features],
                axis=1))

        return result
194 |
195 |
def get_rot(theta):
    """Get the 3D rotation matrix about the x-axis for the given angle."""
    c = np.cos(theta)
    s = np.sin(theta)
    return np.array([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c]
    ])
203 |
204 |
class OrigamiDataset(torch.utils.data.Dataset):
    """Dataset of square mesh with random crease down the center.

    Each sample folds half of a fixed square grid (square.obj) about its
    center line by a random angle theta and returns the folded mesh with
    gathered edge/triangle features; 'dir' holds (cos theta, sin theta),
    presumably the regression target for the fold direction.
    """

    def __init__(self, edge_features_from_vertex_features=None,
                 triangle_features_from_vertex_features=None):
        """Args:
            edge_features_from_vertex_features: keys of per-vertex features
                gathered onto edges (default: ['vertices']).
            triangle_features_from_vertex_features: keys of per-vertex
                features gathered onto triangles (default: ['vertices']).
        """
        # Defaults are None rather than ['vertices'] to avoid the
        # shared-mutable-default-argument pitfall; behavior is unchanged.
        self.edge_features_from_vertex_features = (
            ['vertices'] if edge_features_from_vertex_features is None
            else edge_features_from_vertex_features)
        self.triangle_features_from_vertex_features = (
            ['vertices'] if triangle_features_from_vertex_features is None
            else triangle_features_from_vertex_features)

        self.v, self.f = igl.read_triangle_mesh('square.obj')

    def __len__(self):
        # Samples are generated on the fly, so the epoch length is arbitrary.
        return 5000

    def __getitem__(self, _):
        v, f = self.v, self.f

        # Split the grid at the crease: indices 55+ hold the center row and
        # the half of the square that gets folded (see square.obj layout).
        v1 = v[:55]
        v2 = v[55:]

        # Fold by a random angle about the x-axis; points on the crease
        # (y = z = 0) are fixed by the rotation.
        theta = np.random.rand() * 2 * np.pi
        v2_ = v2 @ get_rot(np.pi - theta)
        v = np.concatenate([v1, v2_], axis=0)

        ev, _, ef = igl.edge_topology(v, f)

        # Split edges into boundary (one adjacent face marked -1) and
        # interior, and find the opposite vertices for each group.
        bdry_idxs = (ef == -1).any(1)
        bdry_ev = ev[bdry_idxs]
        bdry_ef = ef[bdry_idxs]
        bdry_ef.sort(1)
        bdry_ef = bdry_ef[:, 1, None]  # keep the single valid face index

        bdry_flipped = flip(f, bdry_ev, bdry_ef, bdry=True)

        int_ev = ev[~bdry_idxs]
        int_ef = ef[~bdry_idxs]

        int_flipped = flip(f, int_ev, int_ef)

        # normals
        n = igl.per_vertex_normals(v, f)
        n[np.isnan(n)] = 0  # zero degenerate normals, as HodgenetMeshDataset does

        result = {
            'vertices': torch.from_numpy(v),
            'int_d01': d01(v, int_ev),
            'bdry_d01': d01(v, bdry_ev),
            'triangles': torch.from_numpy(f.astype(np.int64)),
            'normals': torch.from_numpy(n),
            'dir': torch.tensor([np.cos(theta), np.sin(theta)]),
        }

        # gather vertex features to edges from list of keys: for each edge,
        # concatenate the features of its endpoints and opposite vertices
        result['int_edge_features'] = torch.from_numpy(
            np.concatenate([
                np.concatenate([
                    result[key][torch.from_numpy(int_ev)],
                    result[key][torch.from_numpy(int_flipped)]
                ], axis=1).reshape(int_ev.shape[0], -1)
                for key in self.edge_features_from_vertex_features
            ], axis=1))
        result['bdry_edge_features'] = torch.from_numpy(
            np.concatenate([
                np.concatenate([
                    result[key][torch.from_numpy(bdry_ev)],
                    result[key][torch.from_numpy(bdry_flipped)]
                ], axis=1).reshape(bdry_ev.shape[0], -1)
                for key in self.edge_features_from_vertex_features
            ], axis=1))

        # gather vertex features to triangles from list of keys
        result['triangle_features'] = torch.from_numpy(np.concatenate(
            [result[key][f].reshape(f.shape[0], -1)
             for key in self.triangle_features_from_vertex_features], axis=1))

        return result
281 |
--------------------------------------------------------------------------------