├── .gitignore
├── LICENSE
├── README.md
├── deform_3_obj.py
├── doc
│   ├── tb_losses.png
│   ├── tb_meshes.png
│   └── teaser.png
├── download_data.sh
├── download_pretrained.sh
├── exp1_visualize_results.ipynb
├── exp2_animate_smpl.ipynb
├── exp3_parametric_cad.ipynb
├── paper_figures
│   ├── anim
│   │   ├── linear.npz
│   │   ├── linear.png
│   │   ├── linear
│   │   │   ├── mesh_0.ply
│   │   │   ├── mesh_1.ply
│   │   │   ├── mesh_10.ply
│   │   │   ├── mesh_11.ply
│   │   │   ├── mesh_12.ply
│   │   │   ├── mesh_13.ply
│   │   │   ├── mesh_14.ply
│   │   │   ├── mesh_15.ply
│   │   │   ├── mesh_16.ply
│   │   │   ├── mesh_17.ply
│   │   │   ├── mesh_18.ply
│   │   │   ├── mesh_19.ply
│   │   │   ├── mesh_2.ply
│   │   │   ├── mesh_20.ply
│   │   │   ├── mesh_3.ply
│   │   │   ├── mesh_4.ply
│   │   │   ├── mesh_5.ply
│   │   │   ├── mesh_6.ply
│   │   │   ├── mesh_7.ply
│   │   │   ├── mesh_8.ply
│   │   │   └── mesh_9.ply
│   │   ├── shapeflow_divfree_edge0.npz
│   │   ├── shapeflow_divfree_edge0.png
│   │   ├── shapeflow_divfree_edge0
│   │   │   ├── mesh_0.ply
│   │   │   ├── mesh_1.ply
│   │   │   ├── mesh_10.ply
│   │   │   ├── mesh_11.ply
│   │   │   ├── mesh_12.ply
│   │   │   ├── mesh_13.ply
│   │   │   ├── mesh_14.ply
│   │   │   ├── mesh_15.ply
│   │   │   ├── mesh_16.ply
│   │   │   ├── mesh_17.ply
│   │   │   ├── mesh_18.ply
│   │   │   ├── mesh_19.ply
│   │   │   ├── mesh_2.ply
│   │   │   ├── mesh_20.ply
│   │   │   ├── mesh_3.ply
│   │   │   ├── mesh_4.ply
│   │   │   ├── mesh_5.ply
│   │   │   ├── mesh_6.ply
│   │   │   ├── mesh_7.ply
│   │   │   ├── mesh_8.ply
│   │   │   └── mesh_9.ply
│   │   ├── shapeflow_divfree_edge2.npz
│   │   ├── shapeflow_divfree_edge2.png
│   │   ├── shapeflow_divfree_edge2
│   │   │   ├── mesh_0.ply
│   │   │   ├── mesh_1.ply
│   │   │   ├── mesh_10.ply
│   │   │   ├── mesh_11.ply
│   │   │   ├── mesh_12.ply
│   │   │   ├── mesh_13.ply
│   │   │   ├── mesh_14.ply
│   │   │   ├── mesh_15.ply
│   │   │   ├── mesh_16.ply
│   │   │   ├── mesh_17.ply
│   │   │   ├── mesh_18.ply
│   │   │   ├── mesh_19.ply
│   │   │   ├── mesh_2.ply
│   │   │   ├── mesh_20.ply
│   │   │   ├── mesh_3.ply
│   │   │   ├── mesh_4.ply
│   │   │   ├── mesh_5.ply
│   │   │   ├── mesh_6.ply
│   │   │   ├── mesh_7.ply
│   │   │   ├── mesh_8.ply
│   │   │   └── mesh_9.ply
│   │   ├── shapeflow_edge0.npz
│   │   ├── shapeflow_edge0.png
│   │   ├── shapeflow_edge0
│   │   │   ├── mesh_0.ply
│   │   │   ├── mesh_1.ply
│   │   │   ├── mesh_10.ply
│   │   │   ├── mesh_11.ply
│   │   │   ├── mesh_12.ply
│   │   │   ├── mesh_13.ply
│   │   │   ├── mesh_14.ply
│   │   │   ├── mesh_15.ply
│   │   │   ├── mesh_16.ply
│   │   │   ├── mesh_17.ply
│   │   │   ├── mesh_18.ply
│   │   │   ├── mesh_19.ply
│   │   │   ├── mesh_2.ply
│   │   │   ├── mesh_20.ply
│   │   │   ├── mesh_3.ply
│   │   │   ├── mesh_4.ply
│   │   │   ├── mesh_5.ply
│   │   │   ├── mesh_6.ply
│   │   │   ├── mesh_7.ply
│   │   │   ├── mesh_8.ply
│   │   │   └── mesh_9.ply
│   │   └── volume_change.pdf
│   ├── anim_debug
│   │   ├── linear.npz
│   │   ├── linear.png
│   │   ├── shapeflow_divfree_edge0.npz
│   │   ├── shapeflow_divfree_edge0.png
│   │   ├── shapeflow_divfree_edge2.npz
│   │   ├── shapeflow_divfree_edge2.png
│   │   ├── shapeflow_edge0.npz
│   │   ├── shapeflow_edge0.png
│   │   └── volume_change.pdf
│   ├── parametric_cad
│   │   ├── deformed.png
│   │   └── groundtruth.png
│   ├── shapegen
│   │   ├── 0_dmc.jpeg
│   │   ├── 0_gt.jpeg
│   │   ├── 0_input.jpeg
│   │   ├── 0_onet.jpeg
│   │   ├── 0_psgn.jpeg
│   │   ├── 0_r2n2.jpeg
│   │   ├── 0_retrieved.jpeg
│   │   ├── 0_shapeflow.jpeg
│   │   ├── 1_dmc.jpeg
│   │   ├── 1_gt.jpeg
│   │   ├── 1_input.jpeg
│   │   ├── 1_onet.jpeg
│   │   ├── 1_psgn.jpeg
│   │   ├── 1_r2n2.jpeg
│   │   ├── 1_retrieved.jpeg
│   │   ├── 1_shapeflow.jpeg
│   │   ├── 2_dmc.jpeg
│   │   ├── 2_gt.jpeg
│   │   ├── 2_input.jpeg
│   │   ├── 2_onet.jpeg
│   │   ├── 2_psgn.jpeg
│   │   ├── 2_r2n2.jpeg
│   │   ├── 2_retrieved.jpeg
│   │   ├── 2_shapeflow.jpeg
│   │   ├── old
│   │   │   ├── 0_dmc.jpeg
│   │   │   ├── 0_gt.jpeg
│   │   │   ├── 0_input.jpeg
│   │   │   ├── 0_onet.jpeg
│   │   │   ├── 0_psgn.jpeg
│   │   │   ├── 0_r2n2.jpeg
│   │   │   ├── 0_shapeflow.jpeg
│   │   │   ├── 1_dmc.jpeg
│   │   │   ├── 1_gt.jpeg
│   │   │   ├── 1_input.jpeg
│   │   │   ├── 1_onet.jpeg
│   │   │   ├── 1_psgn.jpeg
│   │   │   ├── 1_r2n2.jpeg
│   │   │   ├── 1_shapeflow.jpeg
│   │   │   ├── 2_dmc.jpeg
│   │   │   ├── 2_gt.jpeg
│   │   │   ├── 2_input.jpeg
│   │   │   ├── 2_onet.jpeg
│   │   │   ├── 2_psgn.jpeg
│   │   │   ├── 2_r2n2.jpeg
│   │   │   └── 2_shapeflow.jpeg
│   │   └── shapegen.zip
│   └── teaser
│       └── flow.pdf
├── render_example.ipynb
├── requirements.txt
├── scripts
│   ├── cache_shapenet_pointcloud.py
│   ├── render_thumbnail.py
│   └── simplify.py
├── shapeflow
│   ├── __init__.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── chamfer_layer.py
│   │   ├── deformation_layer.py
│   │   ├── pde_layer.py
│   │   ├── pointnet_layer.py
│   │   └── shared_definition.py
│   └── utils
│       ├── __init__.py
│       └── train_utils.py
├── shapenet_dataloader.py
├── shapenet_embedding.py
├── shapenet_generation.sh
├── shapenet_reconstruct.ipynb
├── shapenet_reconstruct.py
├── shapenet_train.py
├── shapenet_train.sh
├── utils
│   └── render.py
└── visualize_deformer.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Meshes
132 | *.ply
133 |
134 | # VSCode
135 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Chiyu Max Jiang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ShapeFlow
2 | [NeurIPS 2020] ShapeFlow: Learnable Deformations Among 3D Shapes.
3 |
4 |
5 | By: [Chiyu "Max" Jiang*](http://maxjiang.ml/), [Jingwei Huang*](http://stanford.edu/~jingweih/), [Andrea Tagliasacchi](http://gfx.uvic.ca/people/ataiya/), [Leonidas Guibas](https://geometry.stanford.edu/member/guibas/)
6 |
7 | \[[Project Website]()\] \[[Paper](https://arxiv.org/abs/2006.07982)\]
8 |
9 |
10 | ![ShapeFlow teaser](doc/teaser.png)
11 |
12 | ## 1. Introduction
13 | We present ShapeFlow, a flow-based model for learning a deformation space for entire classes of 3D shapes with large intra-class variations. ShapeFlow allows learning a multi-template deformation space that is agnostic to shape topology, yet preserves fine geometric details. Different from a generative space where a latent vector is directly decoded into a shape, a deformation space decodes a vector into a continuous flow that can advect a source shape towards a target. Such a space naturally allows the disentanglement of geometric style (coming from the source) and structural pose (conforming to the target). We parametrize the deformation between geometries as a learned continuous flow field via a neural network and show that such deformations can be guaranteed to have desirable properties, such as bijectivity, freedom from self-intersections, or volume preservation. We illustrate the effectiveness of this learned deformation space for various downstream applications, including shape generation via deformation, geometric style transfer, unsupervised learning of a consistent parameterization for entire classes of shapes, and shape interpolation.
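The core mechanism, advecting points through a learned velocity field with an ODE solver, can be sketched in a few lines. The snippet below is illustrative only (it is not the repo's `NeuralFlowDeformer`); it assumes `torchdiffeq`, which provides the `dopri5`/adjoint solver options that `deform_3_obj.py` configures, and omits the latent-code conditioning that a real source-to-target deformer would use:
```
import torch
from torchdiffeq import odeint


class VelocityField(torch.nn.Module):
    """Toy stand-in for the learned flow field v(x)."""

    def __init__(self, width=100):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(3, width), torch.nn.ELU(),
            torch.nn.Linear(width, width), torch.nn.ELU(),
            torch.nn.Linear(width, 3),
        )

    def forward(self, t, x):
        # dx/dt = v(x); time- and latent-conditioning omitted for brevity
        return self.net(x)


points = torch.randn(5000, 3)   # source point cloud
t = torch.tensor([0.0, 1.0])    # integrate the flow from t=0 to t=1
deformed = odeint(VelocityField(), points, t, method="dopri5")[-1]
```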
14 |
15 | ## 2. Getting Started
16 | ### Installing dependencies
17 | We recommend using pip to install all required dependencies with ease.
18 | ```
19 | pip install -r requirements.txt
20 | ```
21 |
22 | ### Optional dependency (rendering)
23 | We strongly suggest installing the optional dependencies for rendering meshes, so that you can visualize the results using interactive notebooks.
24 | `pyrender` can be installed via pip:
25 | ```
26 | pip install pyrender
27 | ```
28 |
29 | Additionally, to run the notebook renderings on a headless server, follow the [instructions](https://pyrender.readthedocs.io/en/latest/install/#python-installation) for installing `OSMesa`.
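For reference, a minimal headless-rendering sketch looks like the following (assuming OSMesa is installed as above; the mesh path is a placeholder). The key detail is that `PYOPENGL_PLATFORM` must be set before `pyrender` is imported:
```
import os
os.environ["PYOPENGL_PLATFORM"] = "osmesa"  # must happen before importing pyrender

import numpy as np
import pyrender
import trimesh

mesh = trimesh.load("path/to/mesh.ply")  # placeholder path
scene = pyrender.Scene()
scene.add(pyrender.Mesh.from_trimesh(mesh))

# simple camera and light, pulled back along +z to look at the object
pose = np.eye(4)
pose[2, 3] = 2.0
scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0), pose=pose)
scene.add(pyrender.DirectionalLight(intensity=3.0), pose=pose)

renderer = pyrender.OffscreenRenderer(viewport_width=640, viewport_height=480)
color, depth = renderer.render(scene)
renderer.delete()
```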
30 |
31 | ### Download and unpack data
32 | To download and unpack the data used in the experiments, please use the provided utility script.
33 | ```
34 | bash download_data.sh
35 | ```
36 |
37 | ## 3. Reproduce results
38 | ### Run experiments
39 | Please use our provided launch script to start training the shape deformation model.
40 | ```
41 | bash shapenet_train.sh
42 | ```
43 |
44 | The training will launch on all available GPUs; mask GPUs accordingly if you only want to use a subset of them. The initial tests were done on NVIDIA Volta V100 GPUs, so the default `batch_size_per_gpu=16` may need to be lowered for GPUs with less memory if an out-of-memory error is triggered, or raised for GPUs with more.
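For example, to train on the first two GPUs only, mask the devices with the standard CUDA environment variable (and, if needed, adjust `batch_size_per_gpu` where the launch script sets it):
```
# use only GPU 0 and GPU 1
CUDA_VISIBLE_DEVICES=0,1 bash shapenet_train.sh
```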
45 |
46 | ### Load and visualize pretrained checkpoint
47 | First download the pretrained checkpoint.
48 | ```
49 | wget island.me.berkeley.edu/files/pretrained_ckpt.zip
50 | mkdir -p runs
51 | mv pretrained_ckpt.zip runs
52 | cd runs; unzip pretrained_ckpt.zip; rm pretrained_ckpt.zip; cd ..
53 | ```
54 |
55 | Next, run through the cells in `visualize_deformer.ipynb`.
56 |
57 | ### Monitor training
58 | After launching the training script, a `runs` directory will be created, with each run in its own subfolder. To monitor the training process via the text logs, use
59 | ```
60 | tail -f runs/<run_name>/log.txt
61 | ```
62 |
63 | To monitor the training process using tensorboard, do:
64 | ```
65 | # if you are running this on a remote server via ssh
66 | ssh my_favorite_machine -L 6006:localhost:6006
67 |
68 | # go to the directory containing the tensorboard log
69 | cd path/to/ShapeFlow/runs/<run_name>/tensorboard
70 |
71 | # launch tensorboard
72 | tensorboard --logdir . --port 6006
73 | ```
74 | Tensorboard allows tracking of deformation losses, as well as visualizing the source / target / deformed meshes. The deformed meshes are colored by per-vertex distance to the target shape (see the sketch below).
75 |
76 | ![Tensorboard losses](doc/tb_losses.png)
77 | ![Tensorboard meshes](doc/tb_meshes.png)
78 |
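The per-vertex coloring mentioned above can be reproduced offline with a sketch like the one below (not the repo's own implementation; assumes `scipy` and `matplotlib`, with placeholder mesh paths):
```
import numpy as np
import trimesh
from matplotlib import cm
from scipy.spatial import cKDTree

deformed = trimesh.load("deformed.ply")  # placeholder paths
target = trimesh.load("target.ply")

# distance from each deformed vertex to its nearest target vertex
dist, _ = cKDTree(target.vertices).query(deformed.vertices)
colors = cm.viridis(dist / (dist.max() + 1e-12))
deformed.visual.vertex_colors = (colors * 255).astype(np.uint8)
deformed.export("deformed_colored.ply")
```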
79 | ### Citation
80 | If you find our code useful for your work, please consider citing our paper:
81 | ```
82 | @inproceedings{jiang2020shapeflow,
83 | title={ShapeFlow: Learnable Deformations Among 3D Shapes},
84 | author={Jiang, Chiyu and Huang, Jingwei and Tagliasacchi, Andrea and Guibas, Leonidas},
85 | booktitle={Advances in Neural Information Processing Systems},
86 | year={2020}
87 | }
88 | ```
89 |
90 | ### Contact
91 | Please contact [Max Jiang](mailto:maxjiang93@gmail.com) if you have further questions!
92 |
--------------------------------------------------------------------------------
/deform_3_obj.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 | from shapeflow.layers.chamfer_layer import ChamferDistKDTree
4 | from shapeflow.layers.deformation_layer import NeuralFlowDeformer
5 | from shapeflow.layers.pointnet_layer import PointNetEncoder
6 |
7 | import torch
8 | import numpy as np
9 | from time import time
10 | import trimesh
11 | from glob import glob
12 |
13 |
14 | files = sorted(glob("data/shapenet_watertight/val/03001627/*/*.ply"))
15 | m1 = trimesh.load(files[1])
16 | m2 = trimesh.load(files[6])
17 | m3 = trimesh.load(files[7])
18 | device = torch.device("cuda:0")
19 |
20 | chamfer_dist = ChamferDistKDTree(reduction="mean", njobs=1).to(device)
21 | criterion = torch.nn.MSELoss()
22 |
23 | latent_size = 3
24 |
25 | deformer = NeuralFlowDeformer(
26 | latent_size=latent_size,
27 | f_nlayers=6,
28 | f_width=100,
29 | s_nlayers=2,
30 | s_width=5,
31 | method="dopri5",
32 | nonlinearity="elu",
33 | arch="imnet",
34 | adjoint=True,
35 | atol=1e-4,
36 | rtol=1e-4,
37 | ).to(device)
38 | encoder = PointNetEncoder(
39 | nf=16, out_features=latent_size, dropout_prob=0.0
40 | ).to(device)
41 |
42 | # this is an awkward workaround to get gradients for encoder via adjoint solver
43 | deformer.add_encoder(encoder)
44 | deformer.to(device)
45 | encoder = deformer.net.encoder
46 |
47 | optimizer = optim.Adam(deformer.parameters(), lr=1e-3)
48 |
49 | niter = 1000
50 | npts = 5000
51 |
52 | V1 = torch.tensor(m1.vertices.astype(np.float32)).to(device) # .unsqueeze(0)
53 | V2 = torch.tensor(m2.vertices.astype(np.float32)).to(device) # .unsqueeze(0)
54 | V3 = torch.tensor(m3.vertices.astype(np.float32)).to(device) # .unsqueeze(0)
55 |
56 | loss_min = 1e30
57 | tic = time()
58 | encoder.train()
59 |
60 | for it in range(0, niter):
61 | optimizer.zero_grad()
62 |
63 | seq1 = torch.randperm(V1.shape[0], device=device)[:npts]
64 | seq2 = torch.randperm(V2.shape[0], device=device)[:npts]
65 | seq3 = torch.randperm(V3.shape[0], device=device)[:npts]
66 | V1_samp = V1[seq1]
67 | V2_samp = V2[seq2]
68 | V3_samp = V3[seq3]
69 |
70 | V_src = torch.stack(
71 | [V1_samp, V1_samp, V2_samp], dim=0
72 | ) # [batch, npoints, 3]
73 | V_tar = torch.stack(
74 | [V2_samp, V3_samp, V3_samp], dim=0
75 | ) # [batch, npoints, 3]
76 |
77 | V_src_tar = torch.cat([V_src, V_tar], dim=0)
78 | V_tar_src = torch.cat([V_tar, V_src], dim=0)
79 |
80 | batch_latent_src_tar = encoder(V_src_tar)
81 | batch_latent_tar_src = torch.cat(
82 | [batch_latent_src_tar[3:], batch_latent_src_tar[:3]]
83 | )
84 |
85 | V_deform = deformer(V_src_tar, batch_latent_src_tar, batch_latent_tar_src)
86 |
87 | _, _, dist = chamfer_dist(V_deform, V_tar_src)
88 |
89 | loss = criterion(dist, torch.zeros_like(dist))
90 |
91 | loss.backward()
92 | optimizer.step()
93 |
94 |     if it % 100 == 0:
95 | print(f"iter={it}, loss={np.sqrt(loss.item())}")
96 |
97 | toc = time()
98 | print("Time for {} iters: {:.4f} s".format(niter, toc - tic))
99 |
100 | # save deformed mesh
101 | encoder.eval()
102 | with torch.no_grad():
103 | V1_latent = encoder(V1.unsqueeze(0))
104 | V2_latent = encoder(V2.unsqueeze(0))
105 | V3_latent = encoder(V3.unsqueeze(0))
106 |
107 | V1_2 = (
108 | deformer(V1.unsqueeze(0), V1_latent, V2_latent)
109 | .detach()
110 | .cpu()
111 | .numpy()[0]
112 | )
113 | V2_1 = (
114 | deformer(V2.unsqueeze(0), V2_latent, V1_latent)
115 | .detach()
116 | .cpu()
117 | .numpy()[0]
118 | )
119 | V1_3 = (
120 | deformer(V1.unsqueeze(0), V1_latent, V3_latent)
121 | .detach()
122 | .cpu()
123 | .numpy()[0]
124 | )
125 | V3_1 = (
126 | deformer(V3.unsqueeze(0), V3_latent, V1_latent)
127 | .detach()
128 | .cpu()
129 | .numpy()[0]
130 | )
131 | V2_3 = (
132 | deformer(V2.unsqueeze(0), V2_latent, V3_latent)
133 | .detach()
134 | .cpu()
135 | .numpy()[0]
136 | )
137 | V3_2 = (
138 | deformer(V3.unsqueeze(0), V3_latent, V2_latent)
139 | .detach()
140 | .cpu()
141 | .numpy()[0]
142 |     )
143 |
144 | # make sure the output directory exists before exporting
145 | os.makedirs("demo", exist_ok=True)
146 | trimesh.Trimesh(V1_2, m1.faces).export("demo/output_1_2.obj")
147 | trimesh.Trimesh(V2_1, m2.faces).export("demo/output_2_1.obj")
148 | trimesh.Trimesh(V1_3, m1.faces).export("demo/output_1_3.obj")
149 | trimesh.Trimesh(V3_1, m3.faces).export("demo/output_3_1.obj")
150 | trimesh.Trimesh(V2_3, m2.faces).export("demo/output_2_3.obj")
151 | trimesh.Trimesh(V3_2, m3.faces).export("demo/output_3_2.obj")
152 |
153 | m1.export("demo/output_1.obj")
154 | m2.export("demo/output_2.obj")
155 | m3.export("demo/output_3.obj")
156 |
--------------------------------------------------------------------------------
/doc/tb_losses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/doc/tb_losses.png
--------------------------------------------------------------------------------
/doc/tb_meshes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/doc/tb_meshes.png
--------------------------------------------------------------------------------
/doc/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/doc/teaser.png
--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mkdir -p data
4 |
5 | echo "Downloading and unzipping data. This will take a while, take a break and get a coffee..."
6 |
7 | echo "Downloading Data..."
8 |
9 | cd data
10 | wget island.me.berkeley.edu/shape_flow/shapenet_simplified.zip
11 | wget island.me.berkeley.edu/shape_flow/shapenet_thumbnails.zip
12 | wget island.me.berkeley.edu/shape_flow/shapenet_pointcloud.zip
13 | wget island.me.berkeley.edu/shape_flow/smpl_meshes.zip
14 | wget island.me.berkeley.edu/shape_flow/sparse_inputs.zip
15 | wget island.me.berkeley.edu/shape_flow/parametric_cad.zip
16 |
17 | echo "Unzipping Data..."
18 | unzip shapenet_simplified.zip && rm shapenet_simplified.zip
19 | unzip shapenet_thumbnails.zip && rm shapenet_thumbnails.zip
20 | unzip shapenet_pointcloud.zip && rm shapenet_pointcloud.zip
21 | unzip smpl_meshes.zip && rm smpl_meshes.zip
22 | unzip sparse_inputs.zip && rm sparse_inputs.zip
23 | unzip parametric_cad.zip && rm parametric_cad.zip
24 |
25 | cd ..
26 |
--------------------------------------------------------------------------------
/download_pretrained.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "Downloading pretrained checkpoints and inference inputs. This will take a while, take a break and get a coffee..."
4 |
5 | mkdir -p data
6 | cd data
7 | wget island.me.berkeley.edu/shape_flow/sparse_inputs.zip
8 | unzip sparse_inputs.zip && rm sparse_inputs.zip
9 | cd ..
10 |
11 | mkdir -p runs
12 | cd runs
13 | wget island.me.berkeley.edu/shape_flow/pretrained_chair_symm128.zip
14 | unzip pretrained_chair_symm128.zip && rm pretrained_chair_symm128.zip
15 | cd ..
16 |
--------------------------------------------------------------------------------
/paper_figures/anim/linear.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear.npz
--------------------------------------------------------------------------------
/paper_figures/anim/linear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear.png
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_0.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_0.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_1.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_1.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_10.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_10.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_11.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_11.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_12.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_12.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_13.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_13.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_14.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_14.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_15.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_15.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_16.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_16.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_17.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_17.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_18.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_18.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_19.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_19.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_2.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_20.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_20.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_3.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_3.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_4.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_4.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_5.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_5.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_6.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_6.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_7.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_7.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_8.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_8.ply
--------------------------------------------------------------------------------
/paper_figures/anim/linear/mesh_9.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_9.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0.npz
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0.png
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_0.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_0.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_1.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_1.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_10.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_10.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_11.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_11.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_12.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_12.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_13.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_13.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_14.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_14.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_15.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_15.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_16.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_16.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_17.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_17.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_18.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_18.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_19.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_19.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_2.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_20.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_20.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_3.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_3.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_4.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_4.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_5.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_5.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_6.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_6.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_7.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_7.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_8.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_8.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge0/mesh_9.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_9.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2.npz
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2.png
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_0.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_0.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_1.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_1.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_10.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_10.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_11.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_11.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_12.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_12.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_13.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_13.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_14.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_14.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_15.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_15.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_16.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_16.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_17.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_17.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_18.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_18.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_19.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_19.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_2.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_20.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_20.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_3.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_3.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_4.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_4.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_5.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_5.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_6.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_6.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_7.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_7.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_8.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_8.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_divfree_edge2/mesh_9.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_9.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0.npz
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0.png
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_0.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_0.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_1.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_1.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_10.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_10.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_11.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_11.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_12.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_12.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_13.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_13.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_14.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_14.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_15.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_15.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_16.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_16.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_17.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_17.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_18.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_18.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_19.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_19.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_2.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_20.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_20.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_3.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_3.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_4.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_4.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_5.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_5.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_6.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_6.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_7.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_7.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_8.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_8.ply
--------------------------------------------------------------------------------
/paper_figures/anim/shapeflow_edge0/mesh_9.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_9.ply
--------------------------------------------------------------------------------
/paper_figures/anim/volume_change.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/volume_change.pdf
--------------------------------------------------------------------------------
/paper_figures/anim_debug/linear.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/linear.npz
--------------------------------------------------------------------------------
/paper_figures/anim_debug/linear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/linear.png
--------------------------------------------------------------------------------
/paper_figures/anim_debug/shapeflow_divfree_edge0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_divfree_edge0.npz
--------------------------------------------------------------------------------
/paper_figures/anim_debug/shapeflow_divfree_edge0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_divfree_edge0.png
--------------------------------------------------------------------------------
/paper_figures/anim_debug/shapeflow_divfree_edge2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_divfree_edge2.npz
--------------------------------------------------------------------------------
/paper_figures/anim_debug/shapeflow_divfree_edge2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_divfree_edge2.png
--------------------------------------------------------------------------------
/paper_figures/anim_debug/shapeflow_edge0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_edge0.npz
--------------------------------------------------------------------------------
/paper_figures/anim_debug/shapeflow_edge0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_edge0.png
--------------------------------------------------------------------------------
/paper_figures/anim_debug/volume_change.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/volume_change.pdf
--------------------------------------------------------------------------------
/paper_figures/parametric_cad/deformed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/parametric_cad/deformed.png
--------------------------------------------------------------------------------
/paper_figures/parametric_cad/groundtruth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/parametric_cad/groundtruth.png
--------------------------------------------------------------------------------
/paper_figures/shapegen/0_dmc.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_dmc.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/0_gt.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_gt.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/0_input.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_input.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/0_onet.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_onet.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/0_psgn.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_psgn.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/0_r2n2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_r2n2.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/0_retrieved.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_retrieved.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/0_shapeflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_shapeflow.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/1_dmc.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_dmc.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/1_gt.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_gt.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/1_input.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_input.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/1_onet.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_onet.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/1_psgn.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_psgn.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/1_r2n2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_r2n2.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/1_retrieved.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_retrieved.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/1_shapeflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_shapeflow.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/2_dmc.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_dmc.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/2_gt.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_gt.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/2_input.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_input.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/2_onet.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_onet.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/2_psgn.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_psgn.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/2_r2n2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_r2n2.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/2_retrieved.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_retrieved.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/2_shapeflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_shapeflow.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/0_dmc.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_dmc.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/0_gt.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_gt.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/0_input.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_input.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/0_onet.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_onet.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/0_psgn.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_psgn.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/0_r2n2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_r2n2.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/0_shapeflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_shapeflow.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/1_dmc.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_dmc.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/1_gt.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_gt.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/1_input.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_input.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/1_onet.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_onet.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/1_psgn.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_psgn.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/1_r2n2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_r2n2.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/1_shapeflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_shapeflow.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/2_dmc.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_dmc.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/2_gt.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_gt.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/2_input.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_input.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/2_onet.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_onet.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/2_psgn.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_psgn.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/2_r2n2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_r2n2.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/old/2_shapeflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_shapeflow.jpeg
--------------------------------------------------------------------------------
/paper_figures/shapegen/shapegen.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/shapegen.zip
--------------------------------------------------------------------------------
/paper_figures/teaser/flow.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/teaser/flow.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.1.1
2 | numpy>=1.18.1
3 | scipy>=1.4.1
4 | torch>=1.4.0
5 | # torchdiffeq>=0.0.1  # torch 1.5.0 broke backward compatibility; fixed on master, so install from git:
6 | git+https://github.com/rtqichen/torchdiffeq
7 | torchvision>=0.5.0
8 | trimesh>=3.5.20
9 | sympy>=1.6.2
10 | imageio>=2.9.0
11 | tensorboard>=2.3.0
12 |
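
A minimal environment-setup sketch (assuming Python 3 and a fresh virtual
environment; not an official install script):

    pip install -r requirements.txt

This pulls torchdiffeq from its git master branch rather than the released
package, per the comment above.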
--------------------------------------------------------------------------------
/scripts/cache_shapenet_pointcloud.py:
--------------------------------------------------------------------------------
1 | """Precompute and cache shapenet pointclouds.
2 | """
3 |
4 | import os
5 | import sys
6 |
7 | WORKING_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
8 | sys.path.append(WORKING_DIR)
9 | import tqdm # noqa: E402
10 | import argparse # noqa: E402
11 | import glob # noqa: E402
12 | import numpy as np # noqa: E402
13 | import trimesh # noqa: E402
14 | from multiprocessing import Pool # noqa: E402
15 | import pickle # noqa: E402
16 | from collections import OrderedDict # noqa: E402
17 |
18 |
19 | def sample_vertex(filename):
20 | mesh = trimesh.load(filename)
21 | v = np.array(mesh.vertices, dtype=np.float32)
22 | np.random.shuffle(v)
23 |     return v[:n_points]  # n_points is set as a module global in main()
24 |
25 |
26 | def wrapper(arg):
27 | return arg, sample_vertex(arg)
28 |
29 |
30 | def update(args):
31 | filename, point_samp = args
32 | fname = "/".join(filename.split("/")[-4:-1])
33 | # note: input comes from async `wrapper`
34 | sampled_points[
35 | fname
36 |     ] = point_samp  # store sampled points keyed by relative shape path
37 | pbar.update()
38 |
39 |
40 | def get_args():
41 | """Parse command line arguments."""
42 | parser = argparse.ArgumentParser(
43 | description="Precompute and cache shapenet pointclouds."
44 | )
45 |
46 | parser.add_argument(
47 | "--file_pattern",
48 | type=str,
49 | default="**/*.ply",
50 | help="filename pattern for files to be rendered.",
51 | )
52 | parser.add_argument(
53 | "--input_root",
54 | type=str,
55 | default="data/shapenet_simplified",
56 | help="path to input mesh root.",
57 | )
58 | parser.add_argument(
59 | "--output_pkl",
60 | type=str,
61 | default="data/shapenet_points.pkl",
62 | help="path to output image root.",
63 | )
64 | parser.add_argument(
65 | "--n_points",
66 | type=int,
67 | default=4096,
68 | help="Number of points to sample per shape",
69 | )
70 | parser.add_argument(
71 | "--n_jobs",
72 | type=int,
73 | default=-1,
74 | help="Number of processes to use. Use all if set to -1.",
75 | )
76 |
77 | args = parser.parse_args()
78 | return args
79 |
80 |
81 | def main():
82 | args = get_args()
83 | patt = os.path.join(WORKING_DIR, args.input_root, args.file_pattern)
84 | in_files = glob.glob(patt, recursive=True)
85 | global sampled_points
86 | global n_points
87 | global pbar
88 | sampled_points = OrderedDict()
89 | n_points = args.n_points
90 | pbar = tqdm.tqdm(total=len(in_files))
91 | pool = Pool(processes=None if args.n_jobs == -1 else args.n_jobs)
92 | for fname in in_files:
93 | pool.apply_async(wrapper, args=(fname,), callback=update)
94 | pool.close()
95 | pool.join()
96 | pbar.close()
97 | with open(args.output_pkl, "wb") as fh:
98 | pickle.dump(sampled_points, fh)
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
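
A typical invocation of this script (paths are illustrative; the flags and
defaults are those defined in get_args above):

    python scripts/cache_shapenet_pointcloud.py \
        --file_pattern "**/*.ply" \
        --input_root data/shapenet_simplified \
        --output_pkl data/shapenet_points.pkl \
        --n_points 4096 --n_jobs -1

The output pickle holds an OrderedDict mapping each shape's relative path to
an array of up to n_points sampled vertices (float32, shape [n_points, 3]).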
--------------------------------------------------------------------------------
/scripts/render_thumbnail.py:
--------------------------------------------------------------------------------
1 | """Render thumbnails for meshes.
2 | """
3 |
4 | # flake8: noqa E402
5 |
6 | import os
7 | import sys
8 |
9 | WORKING_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
10 | sys.path.append(WORKING_DIR)
11 | from utils.render import render_trimesh
12 | from multiprocessing import Pool
13 | import trimesh
14 | import tqdm
15 | import argparse
16 | import glob
17 | import imageio
18 |
19 |
20 | def render_file(in_file):
21 |     out_file = in_file.replace(input_root, output_root)  # globals set in main()
22 | out_dir = os.path.dirname(out_file)
23 | os.makedirs(out_dir, exist_ok=True)
24 | eye = [0.8, 0.4, -0.5]
25 | center = [0, 0, 0]
26 | up = [0, 1, 0]
27 | mesh = trimesh.load(in_file)
28 | image, _, _, _ = render_trimesh(
29 | mesh, eye, center, up, res=(112, 112), light_intensity=6
30 | )
31 | imageio.imwrite(os.path.join(out_dir, "thumbnail.jpg"), image)
32 |
33 |
34 | def wrapper(arg):
35 | return arg, render_file(arg)
36 |
37 |
38 | def update(args):
39 | pbar.update()
40 |
41 |
42 | def get_args():
43 | """Parse command line arguments."""
44 | parser = argparse.ArgumentParser(description="Render thumbnails.")
45 |
46 | parser.add_argument(
47 | "--file_pattern",
48 | type=str,
49 | default="**/*.ply",
50 | help="filename pattern for files to be rendered.",
51 | )
52 | parser.add_argument(
53 | "--input_root",
54 | type=str,
55 | default="data/shapenet_simplified",
56 | help="path to input mesh root.",
57 | )
58 | parser.add_argument(
59 | "--output_root",
60 | type=str,
61 | default="data/shapenet_thumbnails",
62 | help="path to output image root.",
63 | )
64 | parser.add_argument(
65 | "--n_jobs",
66 | type=int,
67 | default=-1,
68 | help="Number of processes to use. Use all if set to -1.",
69 | )
70 |
71 | args = parser.parse_args()
72 | return args
73 |
74 |
75 | def main():
76 | args = get_args()
77 | in_files = glob.glob(
78 | os.path.join(WORKING_DIR, args.input_root, args.file_pattern),
79 | recursive=True,
80 | )
81 | print(os.path.join(WORKING_DIR, args.input_root, args.file_pattern))
82 | global input_root
83 | global output_root
84 | global pbar
85 | input_root = args.input_root
86 | output_root = args.output_root
87 | pbar = tqdm.tqdm(total=len(in_files))
88 | pool = Pool(processes=None if args.n_jobs == -1 else args.n_jobs)
89 | for fname in in_files:
90 | pool.apply_async(wrapper, args=(fname,), callback=update)
91 | pool.close()
92 | pool.join()
93 | pbar.close()
94 |
95 |
96 | if __name__ == "__main__":
97 |     main()
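
A typical invocation (paths are illustrative; flags as defined in get_args
above):

    python scripts/render_thumbnail.py \
        --file_pattern "**/*.ply" \
        --input_root data/shapenet_simplified \
        --output_root data/shapenet_thumbnails \
        --n_jobs -1

Each mesh is rendered from a fixed viewpoint into a 112x112 thumbnail.jpg
placed in the mirrored output directory.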
--------------------------------------------------------------------------------
/scripts/simplify.py:
--------------------------------------------------------------------------------
1 | """Script for performing mesh simplification.
2 | """
3 | # flake8: noqa E402
4 | import sys
5 |
6 | sys.path.append("..")
7 |
8 | import meshutils
9 | import trimesh
10 | import glob
11 | import os
12 | import tqdm
13 |
14 |
15 | files = glob.glob("../data/shapenet_watertight/test/*/*/*.ply")
16 | outroot = "shapenet_simplified"
17 |
18 | for f_in in tqdm.tqdm(files):
19 | f_out = f_in.replace("shapenet_watertight", outroot)
20 | dirname = os.path.dirname(f_out)
21 | os.makedirs(dirname, exist_ok=True)
22 | mesh = trimesh.load(f_in)
23 | v, f = meshutils.fast_simplify(mesh.vertices, mesh.faces, ratio=0.1)
24 | trimesh.Trimesh(v, f).export(f_out)
25 |
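
Since the input glob and output paths above are relative ("../data/..."),
this script is presumably meant to be run from the scripts directory, e.g.:

    cd scripts && python simplify.py

Each watertight test mesh is simplified with ratio=0.1 and written to the
mirrored shapenet_simplified tree.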
--------------------------------------------------------------------------------
/shapeflow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/shapeflow/__init__.py
--------------------------------------------------------------------------------
/shapeflow/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/shapeflow/layers/__init__.py
--------------------------------------------------------------------------------
/shapeflow/layers/chamfer_layer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from scipy.spatial import cKDTree
5 | import multiprocessing
6 | from .shared_definition import REDUCTIONS
7 |
8 |
9 | class ChamferDist(nn.Module):
10 | """Compute chamfer distance on GPU using O(n^2) dense distance matrix."""
11 |
12 | def __init__(self, reduction="mean"):
13 | super(ChamferDist, self).__init__()
14 |         if reduction not in REDUCTIONS:
15 | raise ValueError(
16 | f"reduction method ({reduction}) not in list of "
17 | f"accepted values: {list(REDUCTIONS.keys())}"
18 | )
19 | self.reduce = lambda x: REDUCTIONS[reduction](x, axis=-1)
20 |
21 | def forward(self, tar, src):
22 | """
23 | Args:
24 | tar: [b, n, 3] points for target
25 | src: [b, m, 3] points for source
26 | Returns:
27 | accuracy, complete, chamfer
28 | """
29 | tar = tar.unsqueeze(2)
30 | src = src.unsqueeze(1)
31 | diff = tar - src # [b, n, m, 3]
32 | dist = torch.norm(diff, dim=-1) # [b, n, m]
33 | complete = torch.mean(dist.min(2)[0], dim=1) # [b]
34 | accuracy = torch.mean(dist.min(1)[0], dim=1) # [b]
35 |
36 | complete = self.reduce(complete)
37 | accuracy = self.reduce(accuracy)
38 | chamfer = 0.5 * (complete + accuracy)
39 | return accuracy, complete, chamfer
40 |
41 |
42 | def find_nn_id(args):
43 | """Eval distance between point sets.
44 | Args:
45 | src: [m, 3] np array points for source
46 | tar: [n, 3] np array points for target
47 | Returns:
48 | nn_idx: [m,] np array, index of nearest point in target
49 | """
50 | src, tar = args
51 | tree = cKDTree(tar)
52 | _, nn_idx = tree.query(src, k=1, n_jobs=1)
53 |
54 | return nn_idx
55 |
56 |
57 | def find_nn_id_parallel(args):
58 | """Eval distance between point sets.
59 | Args:
60 | src: [m, 3] np array points for source
61 | tar: [n, 3] np array points for target
62 | idx: int, batch index
63 | Returns:
64 | nn_idx: [m,] np array, index of nearest point in target
65 | idx
66 | """
67 | src, tar, idx = args
68 | tree = cKDTree(tar)
69 | _, nn_idx = tree.query(src, k=1, n_jobs=1)
70 |
71 | return idx, nn_idx
72 |
73 |
74 | class ChamferDistKDTree(nn.Module):
75 | """Compute chamfer distances on CPU using KDTree."""
76 |
77 | def __init__(self, reduction="mean", njobs=1):
78 | """Initialize loss module.
79 |
80 | Args:
81 | reduction: str, reduction method. choice of mean/sum/max/min.
82 | njobs: int, number of parallel workers to use during eval.
83 | """
84 | super(ChamferDistKDTree, self).__init__()
85 | self.njobs = njobs
86 |
87 | self.set_reduction_method(reduction)
88 | if self.njobs != 1:
89 | self.p = multiprocessing.Pool(njobs)
90 |
91 | def find_batch_nn_id(self, src, tar, njobs):
92 | """Batched eval of distance between point sets.
93 | Args:
94 | src: [batch, m, 3] np array points for source
95 | tar: [batch, n, 3] np array points for target
96 | Returns:
97 | batch_nn_idx: [batch, m], np array, index of nearest point in target
98 | """
99 | b = src.shape[0]
100 | if njobs != 1:
101 | src_tar_pairs = tuple(zip(src, tar, range(b)))
102 | result = self.p.map(find_nn_id_parallel, src_tar_pairs)
103 | seq_arr = np.array([r[0] for r in result])
104 | batch_nn_idx = np.stack([r[1] for r in result], axis=0)
105 | batch_nn_idx = batch_nn_idx[np.argsort(seq_arr)]
106 | else:
107 | batch_nn_idx = np.stack(
108 | [find_nn_id((src[i], tar[i])) for i in range(b)], axis=0
109 | )
110 |
111 | return batch_nn_idx
112 |
113 | def set_reduction_method(self, reduction):
114 | """Set reduction method.
115 |
116 | Args:
117 | reduction: str, reduction method. choice of mean/sum/max/min.
118 | """
119 |         if reduction not in REDUCTIONS:
120 | raise ValueError(
121 | f"reduction method ({reduction}) not in list of "
122 | f"accepted values: {list(REDUCTIONS.keys())}"
123 | )
124 | self.reduce = REDUCTIONS[reduction]
125 |
126 | def forward(self, src, tar):
127 | """
128 | Args:
129 | src: [batch, m, 3] points for source
130 | tar: [batch, n, 3] points for target
131 | Returns:
132 | accuracy: [batch, m], accuracy measure for each point in source
133 | complete: [batch, n], complete measure for each point in target
134 | chamfer: [batch,], chamfer distance between source and target
135 | """
136 | bs = src.shape[0]
137 | device = src.device
138 | src_np = src.data.cpu().numpy()
139 | tar_np = tar.data.cpu().numpy()
140 | batch_tar_idx = (
141 | torch.from_numpy(
142 | self.find_batch_nn_id(src_np, tar_np, njobs=self.njobs)
143 | )
144 | .type(torch.LongTensor)
145 | .to(device)
146 | ) # [b, m]
147 | batch_src_idx = (
148 | torch.from_numpy(
149 | self.find_batch_nn_id(tar_np, src_np, njobs=self.njobs)
150 | )
151 | .type(torch.LongTensor)
152 | .to(device)
153 | ) # [b, n]
154 | batch_tar_idx_b = (
155 | torch.arange(bs).view(-1, 1).expand(-1, src.shape[1])
156 |         )  # [b, m]
157 | batch_src_idx_b = (
158 | torch.arange(bs).view(-1, 1).expand(-1, tar.shape[1])
159 |         )  # [b, n]
160 |
161 | src_to_tar_diff = (
162 | tar[batch_tar_idx_b, batch_tar_idx] - src
163 | ) # [b, m, 3]
164 | tar_to_src_diff = (
165 | src[batch_src_idx_b, batch_src_idx] - tar
166 | ) # [b, n, 3]
167 | accuracy = torch.norm(src_to_tar_diff, dim=-1, keepdim=False) # [b, m]
168 | complete = torch.norm(tar_to_src_diff, dim=-1, keepdim=False) # [b, n]
169 |
170 | chamfer = 0.5 * (self.reduce(accuracy) + self.reduce(complete))
171 | return accuracy, complete, chamfer
172 |
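
A minimal usage sketch of the two chamfer modules (batch size and point
counts are arbitrary; note the argument-order difference between them):

    import torch
    from shapeflow.layers.chamfer_layer import ChamferDist, ChamferDistKDTree

    src = torch.rand(2, 1024, 3)  # [batch, m, 3] source points
    tar = torch.rand(2, 2048, 3)  # [batch, n, 3] target points

    # dense O(n^2) version; takes (tar, src)
    acc, comp, chamfer = ChamferDist(reduction="mean")(tar, src)

    # KD-tree version; takes (src, tar) and returns per-point accuracy
    # [batch, m] and completeness [batch, n] before reduction
    acc, comp, chamfer = ChamferDistKDTree(reduction="mean", njobs=1)(src, tar)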
--------------------------------------------------------------------------------
/shapeflow/layers/deformation_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchdiffeq import odeint_adjoint
4 | from torchdiffeq import odeint as odeint_regular
5 | from .pde_layer import PDELayer
6 | from .shared_definition import NONLINEARITIES
7 | import numpy as np
8 |
9 |
10 | class ImNet(nn.Module):
11 | """ImNet layer pytorch implementation."""
12 |
13 | def __init__(
14 | self,
15 | dim=3,
16 | in_features=32,
17 | out_features=4,
18 | nf=32,
19 | nonlinearity="leakyrelu",
20 | ):
21 | """Initialization.
22 |
23 | Args:
24 | dim: int, dimension of input points.
25 | in_features: int, length of input features (i.e., latent code).
26 | out_features: number of output features.
27 | nf: int, width of the second to last layer.
28 |             nonlinearity: str, key for the activation in NONLINEARITIES
29 |                 (e.g., "leakyrelu").
30 | """
31 | super(ImNet, self).__init__()
32 | self.dim = dim
33 | self.in_features = in_features
34 | self.dimz = dim + in_features
35 | self.out_features = out_features
36 | self.nf = nf
37 | self.activ = NONLINEARITIES[nonlinearity]
38 | self.fc0 = nn.Linear(self.dimz, nf * 16)
39 | self.fc1 = nn.Linear(nf * 16 + self.dimz, nf * 8)
40 | self.fc2 = nn.Linear(nf * 8 + self.dimz, nf * 4)
41 | self.fc3 = nn.Linear(nf * 4 + self.dimz, nf * 2)
42 | self.fc4 = nn.Linear(nf * 2 + self.dimz, nf * 1)
43 | self.fc5 = nn.Linear(nf * 1, out_features)
44 | self.fc = [self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5]
45 | self.fc = nn.ModuleList(self.fc)
46 |
47 | def forward(self, x):
48 | """Forward method.
49 |
50 | Args:
51 | x: `[batch_size, dim+in_features]` tensor, inputs to decode.
52 | Returns:
53 | output through this layer of shape [batch_size, out_features].
54 | """
55 | x_tmp = x
56 | for dense in self.fc[:4]:
57 | x_tmp = self.activ(dense(x_tmp))
58 | x_tmp = torch.cat([x_tmp, x], dim=-1)
59 | x_tmp = self.activ(self.fc4(x_tmp))
60 | x_tmp = self.fc5(x_tmp)
61 | return x_tmp
62 |
63 |
64 | class VanillaNet(nn.Module):
65 | """Vanilla mpl pytorch implementation."""
66 |
67 | def __init__(
68 | self,
69 | dim=3,
70 | in_features=32,
71 | out_features=3,
72 | nf=50,
73 | nlayers=4,
74 | nonlinearity="leakyrelu",
75 | ):
76 | """Initialization.
77 |
78 | Args:
79 | dim: int, dimension of input points.
80 | in_features: int, length of input features (i.e., latent code).
81 | out_features: number of output features.
82 | nf: int, width of the second to last layer.
83 | nlayers: int, number of layers in mlp (inc. input/output layers).
84 |             nonlinearity: str, key for the activation in NONLINEARITIES
85 |                 (e.g., "leakyrelu").
86 | """
87 | super(VanillaNet, self).__init__()
88 | self.dim = dim
89 | self.in_features = in_features
90 | self.dimz = dim + in_features
91 | self.out_features = out_features
92 | self.nf = nf
93 | self.nlayers = nlayers
94 | self.activ = NONLINEARITIES[nonlinearity]
95 | modules = [nn.Linear(dim + in_features, nf), self.activ]
96 | assert nlayers >= 2
97 | for i in range(nlayers - 2):
98 | modules += [nn.Linear(nf, nf), self.activ]
99 | modules += [nn.Linear(nf, out_features)]
100 | self.net = nn.Sequential(*modules)
101 |
102 | def forward(self, x):
103 | """Forward method.
104 |
105 | Args:
106 | x: `[batch_size, dim+in_features]` tensor, inputs to decode.
107 | Returns:
108 | output through this layer of shape [batch_size, out_features].
109 | """
110 | return self.net(x)
111 |
112 |
113 | def symmetrize(net, latent_vector, points, symm_dim):
114 | """Make network output symmetric."""
115 | # query both sides of the symmetric dimension
116 | points_pos = points
117 | points_neg = points.clone()
118 | points_neg[..., symm_dim] = -points_neg[..., symm_dim]
119 | y_pos = net(latent_vector, points_pos)
120 | y_neg = net(latent_vector, points_neg)
121 | y_sym = (y_pos + y_neg) / 2
122 | y_sym[..., symm_dim] = (y_pos[..., symm_dim] - y_neg[..., symm_dim]) / 2
123 | return y_sym
124 |
125 |
126 | class DeformationFlowNetwork(nn.Module):
127 | def __init__(
128 | self,
129 | dim=3,
130 | latent_size=1,
131 | nlayers=4,
132 | width=50,
133 | nonlinearity="leakyrelu",
134 | arch="imnet",
135 | divfree=False,
136 | ):
137 | """Intialize deformation flow network.
138 |
139 | Args:
140 | dim: int, physical dimensions. Either 2 for 2d or 3 for 3d.
141 | latent_size: int, size of latent space. >= 1.
142 | nlayers: int, number of neural network layers. >= 2.
143 | width: int, number of neurons per hidden layer. >= 1.
144 |           divfree: bool, parameterize a divergence-free flow.
145 | """
146 | super(DeformationFlowNetwork, self).__init__()
147 | self.dim = dim
148 | self.latent_size = latent_size
149 | self.nlayers = nlayers
150 | self.width = width
151 | self.nonlinearity = nonlinearity
152 |
153 | self.arch = arch
154 | self.divfree = divfree
155 |
156 | assert arch in ["imnet", "vanilla"]
157 | if arch == "imnet":
158 | self.net = ImNet(
159 | dim=dim,
160 | in_features=latent_size,
161 | out_features=dim,
162 | nf=width,
163 | nonlinearity=nonlinearity,
164 | )
165 | else: # vanilla
166 | self.net = VanillaNet(
167 | dim=dim,
168 | in_features=latent_size,
169 | out_features=dim,
170 | nf=width,
171 | nlayers=nlayers,
172 | nonlinearity=nonlinearity,
173 | )
174 |
175 | for m in self.net.modules():
176 | if isinstance(m, nn.Linear):
177 | nn.init.normal_(m.weight, mean=0, std=1e-1)
178 | nn.init.constant_(m.bias, val=0)
179 | if divfree:
180 | self.curl = self._get_curl_layer()
181 |
182 | def _get_curl_layer(self):
183 | in_vars = "x, y, z"
184 | out_vars = "u, v, w"
185 | eqn_strs = [
186 | "dif(w, y) - dif(v, z)",
187 | "dif(u, z) - dif(w, x)",
188 | "dif(v, x) - dif(u, y)",
189 | ]
190 | eqn_names = [
191 | "curl_x",
192 | "curl_y",
193 | "curl_z",
194 | ] # a name/identifier for the equations
195 | curl = PDELayer(
196 | in_vars=in_vars, out_vars=out_vars
197 | ) # initialize the pde layer
198 | for eqn_str, eqn_name in zip(eqn_strs, eqn_names): # add equations
199 | curl.add_equation(eqn_str, eqn_name)
200 | return curl
201 |
202 | def forward(self, latent_vector, points):
203 | """
204 | Args:
205 | latent_vector: tensor of shape [batch, latent_size], latent code for
206 | each shape
207 | points: tensor of shape [batch, num_points, dim], points representing
208 | each shape
209 |
210 | Returns:
211 | velocities: tensor of shape [batch, num_points, dim], velocity at
212 | each point
213 | """
214 | latent_vector = latent_vector.unsqueeze(1).expand(
215 | -1, points.shape[1], -1
216 | ) # [batch, num_points, latent_size]
217 | # wrapper for pde layer
218 |
219 | def fwd_fn(points):
220 | """Forward function.
221 |
222 |             Where points[..., 0], points[..., 1], points[..., 2] correspond
223 |             to x, y, z and
224 |             out[..., 0], out[..., 1], out[..., 2] correspond to u, v, w.
225 | """
226 | points_latents = torch.cat(
227 | (points, latent_vector), axis=-1
228 | ) # [batch, num_points, dim + latent_size]
229 | b, n, d = points_latents.shape
230 | res = self.net(points_latents.reshape([-1, d]))
231 | res = res.reshape([b, n, self.dim])
232 | return res
233 |
234 | if self.divfree:
235 | # return the curl of the velocity field instead
236 | self.curl.update_forward_method(fwd_fn)
237 | _, res_dict = self.curl(points) # res are the equation residues
238 | res = torch.cat(
239 | [res_dict["curl_x"], res_dict["curl_y"], res_dict["curl_z"]],
240 | dim=-1,
241 | ) # [batch, num_points, dim]
242 | return res
243 | else:
244 | return fwd_fn(points)
245 |
246 |
247 | class ConformalDeformationFlowNetwork(nn.Module):
248 | def __init__(
249 | self,
250 | dim=3,
251 | latent_size=1,
252 | nlayers=4,
253 | width=50,
254 | nonlinearity="softplus",
255 | output_scalar=False,
256 | arch="imnet",
257 | ):
258 | """Intialize conformal deformation flow network w/ irrotational flow.
259 |
260 | The network produces a scalar field Phi(x,y,z,t), and the velocity
261 | field is represented as the gradient of Phi.
262 | v = \nabla\Phi # noqa: W605
263 | The gradients can be efficiently computed as the Jacobian through
264 | backprop.
265 |
266 | Args:
267 | dim: int, physical dimensions. Either 2 for 2d or 3 for 3d.
268 | latent_size: int, size of latent space. >= 1.
269 | nlayers: int, number of neural network layers. >= 2.
270 | width: int, number of neurons per hidden layer. >= 1.
271 | """
272 | super(ConformalDeformationFlowNetwork, self).__init__()
273 | self.dim = dim
274 | self.latent_size = latent_size
275 | self.nlayers = nlayers
276 | self.width = width
277 | self.nonlinearity = nonlinearity
278 | self.output_scalar = output_scalar
279 |
280 | self.scale = nn.Parameter(torch.ones(1) * 1e-1)
281 |
282 | nlin = NONLINEARITIES[nonlinearity]
283 |
284 | modules = [nn.Linear(dim + latent_size, width), nlin]
285 | for i in range(nlayers - 2):
286 | modules += [nn.Linear(width, width), nlin]
287 | modules += [nn.Linear(width, 1)]
288 | self.net = nn.Sequential(*modules)
289 |
290 | for m in self.net.modules():
291 | if isinstance(m, nn.Linear):
292 | nn.init.normal_(m.weight, mean=0, std=1e-1)
293 | nn.init.constant_(m.bias, val=0)
294 |
295 | def forward(self, latent_vector, points):
296 | """
297 | Args:
298 | latent_vector: tensor of shape [batch, latent_size], latent code for
299 | each shape
300 | points: tensor of shape [batch, num_points, dim], points representing
301 | each shape
302 | Returns:
303 | velocities: tensor of shape [batch, num_points, dim], velocity at
304 | each point
305 | """
306 | latent_vector = latent_vector.unsqueeze(1).expand(
307 | -1, points.shape[1], -1
308 | ) # [batch, num_points, latent_size]
309 | b, num_points, lat_size = latent_vector.shape
310 | latent_flat = latent_vector.reshape([-1, lat_size])
311 | points_flat = points.reshape([-1, self.dim])
312 | points_flat_ = torch.autograd.Variable(points_flat)
313 | points_flat_.requires_grad = True
314 | points_latents = torch.cat(
315 | (points_flat_, latent_flat), axis=-1
316 | ) # [batch*num_points, dim + latent_size]
317 | phi = self.net(points_latents) * self.scale
318 | vel_flat = torch.autograd.grad(
319 | phi,
320 | points_flat_,
321 | grad_outputs=torch.ones_like(phi),
322 | create_graph=True,
323 | )[0]
324 | vel = vel_flat.reshape(points.shape)
325 | if self.output_scalar:
326 | return vel, phi
327 | else:
328 | return vel
329 |
330 |
331 | class DeformationSignNetwork(nn.Module):
332 | def __init__(
333 | self, latent_size=1, nlayers=3, width=20, nonlinearity="tanh"
334 | ):
335 | """Initialize deformation sign network.
336 | Args:
337 | latent_size: int, size of latent space. >= 1.
338 | nlayers: int, number of neural network layers. >= 2.
339 | width: int, number of neurons per hidden layer. >= 1.
340 | """
341 | super(DeformationSignNetwork, self).__init__()
342 | self.latent_size = latent_size
343 | self.nlayers = nlayers
344 | self.width = width
345 |
346 | nlin = NONLINEARITIES[nonlinearity]
347 | modules = [nn.Linear(latent_size, width, bias=False), nlin]
348 | for i in range(nlayers - 2):
349 | modules += [nn.Linear(width, width, bias=False), nlin]
350 | modules += [nn.Linear(width, 1, bias=False), nlin]
351 | self.net = nn.Sequential(*modules)
352 |
353 | for m in self.net.modules():
354 | if isinstance(m, nn.Linear):
355 | nn.init.normal_(m.weight, mean=0, std=1e-1)
356 |
357 | def forward(self, dir_vector):
358 | """
359 | Args:
360 | dir_vector: tensor of shape [batch, latent_size], latent direction.
361 | Returns:
362 | signs: tensor of shape [batch, 1, 1]
363 | """
364 | dir_vector = dir_vector / (
365 | torch.norm(dir_vector, dim=-1, keepdim=True) + 1e-6
366 | ) # normalize
367 | signs = self.net(dir_vector).unsqueeze(-1)
368 | return signs
369 |
370 |
371 | class NeuralFlowModel(nn.Module):
372 | def __init__(
373 | self,
374 | dim=3,
375 | latent_size=1,
376 | f_nlayers=4,
377 | f_width=50,
378 | s_nlayers=3,
379 | s_width=20,
380 | nonlinearity="relu",
381 | conformal=False,
382 | arch="imnet",
383 | no_sign_net=False,
384 | divfree=False,
385 | symm_dim=None,
386 | ):
387 | super(NeuralFlowModel, self).__init__()
388 |         # select conformal (irrotational) vs. generic flow network
389 |         if conformal:
390 |             model = ConformalDeformationFlowNetwork
391 |         else:
392 |             model = DeformationFlowNetwork
393 | self.no_sign_net = no_sign_net
394 | self.flow_net = model(
395 | dim=dim,
396 | latent_size=latent_size,
397 | nlayers=f_nlayers,
398 | width=f_width,
399 | nonlinearity=nonlinearity,
400 | arch=arch,
401 | divfree=divfree,
402 | )
403 | if not no_sign_net:
404 | self.sign_net = DeformationSignNetwork(
405 | latent_size=latent_size, nlayers=s_nlayers, width=s_width
406 | )
407 | self.symm_dim = symm_dim
408 | self.latent_source = None
409 | self.latent_target = None
410 | self.latent_updated = False
411 | self.conformal = conformal
412 | self.arch = arch
413 | self.encoder = None
414 | self.lat_params = None
415 | self.scale = nn.Parameter(torch.ones(1) * 1e-3)
416 |
417 | def add_encoder(self, encoder):
418 | self.encoder = encoder
419 |
420 | def add_lat_params(self, lat_params):
421 | self.lat_params = lat_params
422 |
423 | def get_lat_params(self, idx):
424 | assert self.lat_params is not None
425 | return self.lat_params[idx]
426 |
427 | def update_latents(self, latent_sequence):
428 | """
429 | Args:
430 | latent_sequence: long or float tensor of shape
431 | [batch, nsteps, latent_size].
432 | sequence of latents along deformation path.
433 | if long, index into self.lat_params to retrieve latents.
434 | Returns:
435 | latent_waypoint: float tensor of shape [batch, nsteps], interp
436 |             coefficient between [0, 1] corresponding to each latent code.
437 | """
438 | bs, ns, d = latent_sequence.shape
439 | dev = latent_sequence.device
440 | self.latent_sequence = latent_sequence
441 | self.latent_seq_len = torch.norm(
442 | self.latent_sequence[:, 1:] - self.latent_sequence[:, :-1], dim=-1
443 | ) # [batch, nsteps-1]
444 | self.latent_seq_len_sum = torch.sum(
445 | self.latent_seq_len, dim=1
446 | ) # [batch]
447 | self.latent_seq_weight = (
448 | self.latent_seq_len / self.latent_seq_len_sum[:, None]
449 | ) # [batch, nsteps-1]
450 | self.latent_seq_bins = torch.cumsum(
451 | self.latent_seq_weight, dim=1
452 | ) # [batch, nsteps-1]
453 | self.latent_seq_bins = torch.cat(
454 | [torch.zeros([bs, 1], device=dev), self.latent_seq_bins], dim=1
455 | ) # [batch, nsteps]
456 | self.latent_updated = True
457 |
458 | return self.latent_seq_bins
459 |
460 | def latent_at_t(self, t, return_sign=False):
461 | """Helper fn to compute latent at t."""
462 | t = t.to(self.latent_seq_bins.device)
463 | # find out which bin this t falls into
464 | bin_mask = (t > self.latent_seq_bins[:, :-1]) * (
465 | t < self.latent_seq_bins[:, 1:]
466 |         )
467 |         # (elementwise product of the two boolean masks == logical AND)
468 |
469 | bin_mask = bin_mask.float()
470 | bin_idx = torch.argmax(bin_mask, dim=1) # [batch,]
471 | batch_idx = torch.arange(bin_idx.shape[0]).to(bin_idx.device)
472 |
473 | # Find the interpolation coefficient between the latents at the two
474 | # ends of the bin
475 | t0 = self.latent_seq_bins[batch_idx, bin_idx]
476 | t1 = self.latent_seq_bins[batch_idx, bin_idx + 1] # [batch]
477 | alpha = (t - t0) / (t1 - t0) # [batch]
478 | latent_t0 = self.latent_sequence[
479 | batch_idx, bin_idx
480 | ] # [batch, latent_size]
481 | latent_t1 = self.latent_sequence[
482 | batch_idx, bin_idx + 1
483 | ] # [batch, latent_size]
484 | latent_val = latent_t0 + alpha[:, None] * (latent_t1 - latent_t0)
485 | latent_dir = (latent_t1 - latent_t0) / torch.norm(
486 | latent_t1 - latent_t0, dim=1, keepdim=True
487 | )
488 | zeros = torch.zeros_like(latent_t0)
489 | outward = torch.norm(latent_t0 - zeros, dim=1) < 1e-6 # [batch]
490 | sign = (outward.float() - 0.5) * 2
491 |
492 | return latent_val, latent_dir, sign
493 |
494 | def forward(self, t, points):
495 | """
496 | Args:
497 | t: float, deformation parameter between 0 and 1.
498 | points: [batch, num_points, dim]
499 | Returns:
500 | vel: [batch, num_points, dim]
501 | """
502 | # Reparametrize eval along latent path as a function of a single
503 | # scalar t
504 | if not self.latent_updated:
505 | raise RuntimeError(
506 | "Latent not updated. "
507 | "Use .update_latents() to update the source and target latents"
508 | )
509 |
510 | latent_val, latent_dir, sign = self.latent_at_t(t)
511 | sign = sign[:, None, None] * self.scale
512 | if self.symm_dim is None:
513 |             flow = self.flow_net(latent_val, points) # [batch, num_points, dim]
514 | else:
515 | flow = symmetrize(self.flow_net, latent_val, points, self.symm_dim)
516 | # Normalize velocity based on time space proportional to latent
517 | # difference.
518 | flow *= self.latent_seq_len_sum[:, None, None]
519 | if not self.no_sign_net:
520 | sign = self.sign_net(latent_dir)
521 | return flow * sign
522 |
523 |
524 | class NeuralFlowDeformer(nn.Module):
525 | def __init__(
526 | self,
527 | dim=3,
528 | latent_size=1,
529 | f_nlayers=4,
530 | f_width=50,
531 | s_nlayers=3,
532 | s_width=20,
533 | method="dopri5",
534 | nonlinearity="leakyrelu",
535 | arch="imnet",
536 | conformal=False,
537 | adjoint=True,
538 | atol=1e-5,
539 | rtol=1e-5,
540 | via_hub=False,
541 | no_sign_net=False,
542 | return_waypoints=False,
543 | use_latent_waypoints=False,
544 | divfree=False,
545 | symm_dim=None,
546 | ):
547 | """Initialize. The parameters are the parameters for the Deformation
548 | Flow network.
549 |
550 | Args:
551 | dim: int, physical dimensions. Either 2 for 2d or 3 for 3d.
552 | latent_size: int, size of latent space. >= 1.
553 | f_nlayers: int, number of neural network layers for flow network.
554 | (>= 2).
555 | f_width: int, number of neurons per hidden layer for flow network
556 | (>= 1).
557 | s_nlayers: int, number of neural network layers for sign network
558 | (>= 2).
559 | s_width: int, number of neurons per hidden layer for sign network.
560 | (>= 1).
561 | arch: str, architecture, choice of 'imnet' / 'vanilla'
562 |           adjoint: bool, whether to use the adjoint solver to backprop
563 |             gradients through odeint.
564 |           rtol, atol: float, relative / absolute error tolerance in ode solver.
565 |           via_hub: bool, route the transformation through a hub-and-spokes
566 |             configuration (source -> zero-latent hub -> target); needs nsteps == 2.
567 | return_waypoints: bool, return intermediate waypoints along timing.
568 | use_latent_waypoints: bool, use latent waypoints.
569 | symm_dim: int, list of int, or None. Symmetry axis/axes, or None.
570 | """
571 | super(NeuralFlowDeformer, self).__init__()
572 | self.method = method
573 | self.conformal = conformal
574 | self.arch = arch
575 | self.adjoint = adjoint
576 | self.odeint = odeint_adjoint if adjoint else odeint_regular
577 | self.__timing = torch.from_numpy(
578 | np.array([0.0, 1.0]).astype("float32")
579 | )
580 | self.return_waypoints = return_waypoints
581 | self.use_latent_waypoints = use_latent_waypoints
582 | self.rtol = rtol
583 | self.atol = atol
584 | self.via_hub = via_hub
585 | self.symm_dim = symm_dim
586 |
587 | self.net = NeuralFlowModel(
588 | dim=dim,
589 | latent_size=latent_size,
590 | f_nlayers=f_nlayers,
591 | f_width=f_width,
592 | s_nlayers=s_nlayers,
593 | s_width=s_width,
594 | arch=arch,
595 | conformal=conformal,
596 | nonlinearity=nonlinearity,
597 | no_sign_net=no_sign_net,
598 | divfree=divfree,
599 | symm_dim=symm_dim,
600 | )
601 | if symm_dim is not None:
602 | if not (isinstance(symm_dim, int) or isinstance(symm_dim, list)):
603 | raise ValueError(
604 | "symm_dim must be int or list of ints, indicating axes of"
605 | "symmetry."
606 | )
607 |
608 | @property
609 | def adjoint(self):
610 | return self.__adjoint
611 |
612 | @adjoint.setter
613 | def adjoint(self, isadjoint):
614 | assert isinstance(isadjoint, bool)
615 | self.__adjoint = isadjoint
616 | self.odeint = odeint_adjoint if isadjoint else odeint_regular
617 |
618 | @property
619 | def timing(self):
620 | return self.__timing
621 |
622 | @timing.setter
623 | def timing(self, timing):
624 | assert isinstance(timing, torch.Tensor)
625 | assert timing.ndim == 1
626 | self.__timing = timing
627 |
628 | def add_encoder(self, encoder):
629 | self.net.add_encoder(encoder)
630 |
631 | def add_lat_params(self, lat_params):
632 | self.net.add_lat_params(lat_params)
633 |
634 | def get_lat_params(self, idx):
635 | return self.net.get_lat_params(idx)
636 |
637 | def forward(self, points, latent_sequence):
638 | """Forward transformation (source -> latent_path -> target).
639 |
640 | To perform backward transformation, simply switch the order of the lat
641 | codes.
642 |
643 | Args:
644 | points: [batch, num_points, dim]
645 | latent_sequence: float tensor of shape [batch, nsteps, latent_size],
646 | ------- or -------
647 | long tensor of shape [batch, nsteps]
648 | sequence of latents along deformation path.
649 | if long, index into self.lat_params to retrieve latents.
650 | Returns:
651 | points_transformed:
652 | tensor of shape [batch, num_points, dim] if not
653 |             self.return_waypoints.
654 |             tensor of shape [nsteps, batch, num_points, dim] if
655 |             self.return_waypoints.
656 | """
657 | if latent_sequence.dtype == torch.long:
658 | latent_sequence = self.get_lat_params(
659 | latent_sequence
660 |             ) # [batch, nsteps, lat_dim]
661 | if self.via_hub:
662 | assert latent_sequence.shape[1] == 2
663 | zeros = torch.zeros_like(latent_sequence[:, :1])
664 | lat0 = latent_sequence[:, 0:1]
665 | lat1 = latent_sequence[:, 1:2]
666 | latent_sequence = torch.cat(
667 | [lat0, zeros, lat1], dim=1
668 | ) # [batch, nsteps=3, lat_dim]
669 | waypoints = self.net.update_latents(latent_sequence)
670 | if self.use_latent_waypoints:
671 | timing = waypoints[0]
672 | else:
673 | timing = self.timing
674 | points_transformed = self.odeint(
675 | self.net,
676 | points,
677 | timing,
678 | method=self.method,
679 | rtol=self.rtol,
680 | atol=self.atol,
681 | )
682 | if self.return_waypoints:
683 | return points_transformed
684 | else:
685 | return points_transformed[-1]
686 |
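
A minimal end-to-end sketch of NeuralFlowDeformer with explicit float
latents (latent size, batch size, and point count are arbitrary; all other
settings follow the constructor defaults above):

    import torch
    from shapeflow.layers.deformation_layer import NeuralFlowDeformer

    deformer = NeuralFlowDeformer(dim=3, latent_size=8)

    points = torch.rand(4, 1024, 3)               # [batch, num_points, dim]
    lat_src = torch.randn(4, 8)                   # source latent codes
    lat_tar = torch.randn(4, 8)                   # target latent codes
    latents = torch.stack([lat_src, lat_tar], 1)  # [batch, nsteps=2, latent_size]

    deformed = deformer(points, latents)          # [batch, num_points, dim]

Per the forward() docstring, swapping lat_src and lat_tar in the sequence
performs the backward transformation.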
--------------------------------------------------------------------------------
/shapeflow/layers/pde_layer.py:
--------------------------------------------------------------------------------
1 | import sympy
2 | import torch
3 | from torch.autograd import grad
4 | from sympy.parsing.sympy_parser import parse_expr
5 |
6 |
7 | # utility functions for parsing equations
8 | torch_diff = lambda y, x: grad( # noqa: E731
9 | y, x, grad_outputs=torch.ones_like(y), create_graph=True, allow_unused=True
10 | )[0]
11 |
12 |
13 | class PDELayer(object):
14 | """PDE Layer for querying values and computing PDE residues."""
15 |
16 | def __init__(self, in_vars, out_vars):
17 | """Initialize physics layer.
18 |
19 | Args:
20 | in_vars: str, a string of input variable names separated by space.
21 | E.g., 'x y t' for the three variables x, y and t.
22 | out_vars: str, a string of output variable names separated by space.
23 | E.g., 'u v p' for the three variables u, v and p.
24 | """
25 | self.in_vars = sympy.symbols(in_vars)
26 | self.out_vars = sympy.symbols(out_vars)
27 | if not isinstance(self.in_vars, tuple):
28 | self.in_vars = (self.in_vars,)
29 | if not isinstance(self.out_vars, tuple):
30 | self.out_vars = (self.out_vars,)
31 | self.n_in = len(self.in_vars)
32 | self.n_out = len(self.out_vars)
33 | self.all_vars = list(self.in_vars) + list(self.out_vars)
34 | self.eqns_raw = {} # raw string equations
35 | self.eqns_fn = {} # lambda function for the equations
36 | self.forward_method = None
37 |
38 | def add_equation(self, eqn_str, eqn_name="", subs_dict=None):
39 | """Add an equation to the physics layer.
40 |
41 | The equation string should represent the expression for computing the
42 | residue of a given equation, rather than representing the equation
43 |         itself. Use dif(y,x) for computing the derivative of y with respect to x.
44 | Sign of the expression does not matter. The variable names
45 | **MUST** be the same as the variables in self.in_vars and
46 | self.out_vars.
47 |
48 | E.g.,
49 | For the equation
50 | partial(u, x) + partial(v, y) = 3*partial(u, y)*partial(v, x),
51 | write as:
52 | eqn_str = 'dif(u,x)+dif(v,y)-3*dif(u,y)*dif(v,x)'
53 | - or -
54 | eqn_str = '3*dif(u,y)*dif(v,x)-(dif(u,x)+dif(v,y))'
55 |
56 | Args:
57 |           eqn_str: str, a string that can be parsed as an expression for
58 | computing the residue of an equation.
59 | eqn_name: str, a name or identifier for this equation entry. E.g.,
60 | 'div_free'. If none or empty, use default of eqn_i where i is an
61 | index.
62 | subs_dict: dict, a dictionary where the key (str) is the variable
63 |             to substitute and val (str) is the expression to substitute the
64 |             variable with. Useful for scenarios such as normalizations and/or
65 | non-dimensionalizing expressions.
66 |
67 | Raises:
68 | ValueError: when the variables in the eqn_str do not match that of
69 | in_vars and out_vars.
70 |
71 | """
72 | if not eqn_name:
73 | eqn_name = "eqn_{i}".format(i=len(self.eqns_raw.keys()))
74 |
75 | # Assert that the equation contains the same vars as in_vars and
76 | # out_vars.
77 | expr = parse_expr(eqn_str)
78 |
79 | # substitute variables in the equation.
80 | if subs_dict:
81 | for key, val in subs_dict.items():
82 | expr = expr.subs(key, val)
83 |
84 | valid_var = expr.free_symbols <= (
85 | set(self.in_vars) | set(self.out_vars)
86 | )
87 | if not valid_var:
88 | raise ValueError(
89 | "Variables in the eqn_str ({}) does not match that of "
90 | "in_vars ({}) and out_vars ({})".format(
91 | expr.free_symbols, set(self.in_vars), set(self.out_vars)
92 | )
93 | )
94 |
95 | # convert into lambda functions
96 | fn = sympy.lambdify(self.all_vars, expr, {"dif": torch_diff})
97 |
98 | # update equations
99 | self.eqns_raw.update({eqn_name: eqn_str})
100 | self.eqns_fn.update({eqn_name: fn})
101 |
102 | def update_forward_method(self, forward_method):
103 | """Update forward method.
104 |
105 | Args:
106 | forward_method: a function, such that y = forward_method(x). x is a
107 | tensor of shape (..., n_in) and y is a tensor of shape (..., n_out).
108 | """
109 | self.forward_method = forward_method
110 |
111 | def eval(self, x):
112 | """Evaluate the output values using forward_method.
113 |
114 | Args:
115 | x: a tensor of shape (..., n_in)
116 | Returns:
117 | a tensor of shape (..., n_out)
118 | """
119 | if not self.forward_method:
120 | raise RuntimeError(
121 | "forward_method has not been defined."
122 | "Run update_forward_method first."
123 | )
124 | y = self.forward_method(x)
125 | if not ((x.shape[-1] == self.n_in) and (y.shape[-1] == self.n_out)):
126 | raise ValueError(
127 | "Input/output dimensions ({}/{}) not equal to the dimensions "
128 | "of defined variables ({}/{}).".format(
129 | x.shape[-1], y.shape[-1], self.n_in, self.n_out
130 | )
131 | )
132 | return y
133 |
134 | def __call__(self, x, return_residue=True):
135 | """Compute the forward eval and possibly compute residues from the
136 | previously defined pdes.
137 |
138 | Args:
139 | x: input tensor of shape (..., n_in)
140 | return_residue: bool, whether to return the residue of the pde for
141 | each equation.
142 | Returns:
143 | y: output tensor of shape (..., n_out)
144 | residues (optional): a dictionary containing residue evaluation for
145 | each pde.
146 | """
147 |
148 | if not return_residue:
149 | y = self.eval(x)
150 | return y
151 | else:
152 | with torch.enable_grad():
153 | # split into individual channels and set each to require grad.
154 | inputs = [x[..., i: i + 1] for i in range(x.shape[-1])]
155 | for xx in inputs:
156 | if not xx.requires_grad:
157 | xx.requires_grad = True
158 | x_ = torch.cat(inputs, axis=-1)
159 | y = self.eval(x_)
160 | outputs = [y[..., i: i + 1] for i in range(y.shape[-1])]
161 | inputs_outputs = inputs + outputs
162 | residues = {}
163 | for key, fn in self.eqns_fn.items():
164 | residue = fn(*inputs_outputs)
165 | residues.update({key: residue})
166 | return y, residues
167 |
168 | @property
169 | def eqn_num(self):
170 | return len(self.eqns_raw)
171 |
172 | @property
173 | def eqn_names(self):
174 | return list(self.eqns_raw.keys())
175 |
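A minimal usage sketch of the equation layer above (illustrative only: `pde_layer` stands for an instance of the class this file defines, assumed constructed with input variables (x, y) and output variables (u, v); `dif` is the autograd differentiation hook bound to torch_diff in add_equation):

    # Hypothetical sketch; pde_layer is an instance of the class above.
    import torch

    # Register the 2D divergence residual du/dx + dv/dy.
    pde_layer.add_equation("dif(u, x) + dif(v, y)", eqn_name="div_free")

    # Any differentiable map from (..., n_in) to (..., n_out) qualifies.
    pde_layer.update_forward_method(lambda p: torch.sin(p))

    points = torch.rand(8, 2)
    values, residues = pde_layer(points)  # forward eval plus residues
    print(residues["div_free"].shape)     # per-point residual of the PDE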
--------------------------------------------------------------------------------
/shapeflow/layers/pointnet_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Credits:
3 | Originally implemented by Fei Xia.
4 | https://github.com/fxia22/pointnet.pytorch
5 |
6 | with small modifications.
7 | """
8 | from __future__ import print_function
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.utils.data
13 | from torch.autograd import Variable
14 | import numpy as np
15 |
16 | from .shared_definition import NORMTYPE, NONLINEARITIES
17 |
18 |
19 | class STN3d(nn.Module):
20 | def __init__(self, norm_type, nonlinearity="relu"):
21 | super(STN3d, self).__init__()
22 | assert norm_type in NORMTYPE.keys()
23 | self.conv1 = torch.nn.Conv1d(3, 64, 1)
24 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
25 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
26 | self.fc1 = torch.nn.Conv1d(1024, 512, 1)
27 | self.fc2 = torch.nn.Conv1d(512, 256, 1)
28 | self.fc3 = torch.nn.Conv1d(256, 9, 1)
29 | self.nl = NONLINEARITIES[nonlinearity]
30 |
31 | self.bn1 = NORMTYPE[norm_type](64)
32 | self.bn2 = NORMTYPE[norm_type](128)
33 | self.bn3 = NORMTYPE[norm_type](1024)
34 | self.bn4 = NORMTYPE[norm_type](512)
35 | self.bn5 = NORMTYPE[norm_type](256)
36 |
37 | def forward(self, x):
38 | batchsize = x.size()[0]
39 | x = self.nl(self.bn1(self.conv1(x)))
40 | x = self.nl(self.bn2(self.conv2(x)))
41 | x = self.nl(self.bn3(self.conv3(x)))
42 | x = torch.max(x, 2, keepdim=True)[0] # [b, c, 1]
43 |
44 | x = self.nl(self.bn4(self.fc1(x))) # [b, c, 1]
45 | x = self.nl(self.bn5(self.fc2(x))) # [b, c, 1]
46 | x = self.fc3(x).squeeze(-1)
47 |
48 | iden = (
49 | Variable(
50 | torch.from_numpy(
51 | np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32)
52 | )
53 | )
54 | .view(1, 9)
55 | .repeat(batchsize, 1)
56 | )
57 | if x.is_cuda:
58 | iden = iden.cuda()
59 | x = x + iden
60 | x = x.view(-1, 3, 3)
61 | return x
62 |
63 |
64 | class PointNetfeat(nn.Module):
65 | def __init__(
66 | self,
67 | in_features=3,
68 | nf=64,
69 | global_feat=True,
70 | feature_transform=False,
71 | norm_type="batchnorm",
72 | nonlinearity="relu",
73 | ):
74 | super(PointNetfeat, self).__init__()
75 | assert norm_type in NORMTYPE.keys()
76 | self.stn = STN3d(norm_type=norm_type, nonlinearity=nonlinearity)
77 | self.conv1 = torch.nn.Conv1d(in_features, nf, 1)
78 | self.conv2 = torch.nn.Conv1d(nf, nf * 2, 1)
79 | self.conv3 = torch.nn.Conv1d(nf * 2, nf * 16, 1)
80 | self.nm1 = NORMTYPE[norm_type](nf)
81 | self.nm2 = NORMTYPE[norm_type](nf * 2)
82 | self.nm3 = NORMTYPE[norm_type](nf * 16)
83 | self.global_feat = global_feat
84 | self.feature_transform = feature_transform
85 | self.nf = nf
86 | self.nl = NONLINEARITIES[nonlinearity]
87 | if self.feature_transform:
88 |             self.fstn = STN3d(k=nf)  # BUG: STN3d takes no k argument; a k-dim STN (cf. STNkd in upstream pointnet.pytorch) is needed for this path.
89 |
90 | def forward(self, x):
91 | n_pts = x.size()[2]
92 | trans = self.stn(x)
93 | x = x.transpose(2, 1)
94 | x = torch.bmm(x, trans)
95 | x = x.transpose(2, 1)
96 | x = self.nl(self.nm1(self.conv1(x)))
97 |
98 | if self.feature_transform:
99 | trans_feat = self.fstn(x)
100 | x = x.transpose(2, 1)
101 | x = torch.bmm(x, trans_feat)
102 | x = x.transpose(2, 1)
103 | else:
104 | trans_feat = None
105 |
106 | pointfeat = x
107 | x = self.nl(self.nm2(self.conv2(x)))
108 | x = self.nm3(self.conv3(x))
109 | x = torch.max(x, 2, keepdim=True)[0]
110 | x = x.view(-1, self.nf * 16)
111 | if self.global_feat:
112 | return x, trans, trans_feat
113 | else:
114 | x = x.view(-1, self.nf * 16, 1).repeat(1, 1, n_pts)
115 | return torch.cat([x, pointfeat], 1), trans, trans_feat
116 |
117 |
118 | class PointNetEncoder(nn.Module):
119 | def __init__(
120 | self,
121 | nf=64,
122 | in_features=3,
123 | out_features=8,
124 | feature_transform=False,
125 | dropout_prob=0.3,
126 | norm_type="batchnorm",
127 | nonlinearity="relu",
128 | ):
129 | super(PointNetEncoder, self).__init__()
130 | assert norm_type in NORMTYPE.keys()
131 | self.feature_transform = feature_transform
132 | self.dropout_prob = dropout_prob
133 | self.feat = PointNetfeat(
134 | global_feat=True,
135 | in_features=in_features,
136 | nf=nf,
137 | feature_transform=feature_transform,
138 | norm_type=norm_type,
139 | nonlinearity=nonlinearity,
140 | )
141 | self.fc1 = nn.Conv1d(nf * 16, nf * 8, 1)
142 | self.fc2 = nn.Conv1d(nf * 8, nf * 4, 1)
143 | self.fc3 = nn.Conv1d(nf * 4, out_features, 1)
144 | self.dropout = nn.Dropout(p=dropout_prob)
145 | self.nm1 = NORMTYPE[norm_type](nf * 8)
146 | self.nm2 = NORMTYPE[norm_type](nf * 4)
147 | self.nl = NONLINEARITIES[nonlinearity]
148 | self.in_features = in_features
149 | self.out_features = out_features
150 | self.nf = nf
151 |
152 | def forward(self, x):
153 | """
154 | Args:
155 | x: tensor of shape [batch, npoints, in_features]
156 | Returns:
157 | output: tensor of shape [batch, out_features]
158 | """
159 | x = x.permute(0, 2, 1)
160 | x, _, _ = self.feat(x)
161 | x = x.unsqueeze(-1)
162 | x = self.nl(self.nm1(self.fc1(x)))
163 | x = self.nl(
164 | self.nm2(self.dropout(self.fc2(x).squeeze(-1)).unsqueeze(-1))
165 | )
166 | x = self.fc3(x).squeeze(-1)
167 | return x
168 |
169 |
170 | def feature_transform_regularizer(trans):
171 | d = trans.size()[1]
172 | I = torch.eye(d)[None, :, :] # noqa: E741
173 | if trans.is_cuda:
174 | I = I.cuda() # noqa: E741
175 | loss = torch.mean(
176 | torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2))
177 | )
178 | return loss
179 |
180 |
181 | if __name__ == "__main__":
182 |     # Example of using the encoder.
183 |     encoder = PointNetEncoder(out_features=8)
184 | points = torch.rand(16, 100, 3)
185 | latents = encoder(points)
186 | print(latents.shape)
187 |
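A short sketch of folding the transform regularizer above into a training loss (illustrative; `task_loss` is a hypothetical stand-in, and the 0.001 weight follows the upstream pointnet.pytorch convention):

    # Illustrative sketch; task_loss is a hypothetical criterion.
    feat = PointNetfeat(global_feat=True, norm_type="batchnorm")
    x = torch.rand(4, 3, 100)                # [batch, 3, npoints]
    global_feat, trans, trans_feat = feat(x)
    loss = task_loss(global_feat)
    loss = loss + 0.001 * feature_transform_regularizer(trans)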
--------------------------------------------------------------------------------
/shapeflow/layers/shared_definition.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class NoNorm(nn.Module):
6 |     def __init__(self, layers):  # layers is unused; kept for signature parity with the norm layers.
7 | super(NoNorm, self).__init__()
8 |
9 | def forward(self, x):
10 | return x
11 |
12 |
13 | class Swish(nn.Module):
14 | def __init__(self):
15 | super(Swish, self).__init__()
16 | self.beta = nn.Parameter(torch.tensor(1.0))
17 |
18 | def forward(self, x):
19 | return x * torch.sigmoid(self.beta * x)
20 |
21 |
22 | class Lambda(nn.Module):
23 | def __init__(self, f):
24 | super(Lambda, self).__init__()
25 | self.f = f
26 |
27 | def forward(self, x):
28 | return self.f(x)
29 |
30 |
31 | OPTIMIZERS = {
32 | "sgd": torch.optim.SGD,
33 | "adam": torch.optim.Adam,
34 | "adadelta": torch.optim.Adadelta,
35 | "adagrad": torch.optim.Adagrad,
36 | "rmsprop": torch.optim.RMSprop,
37 | }
38 |
39 | LOSSES = {
40 | "l1": torch.nn.L1Loss(),
41 | "l2": torch.nn.MSELoss(),
42 | "huber": torch.nn.SmoothL1Loss(),
43 | }
44 |
45 | REDUCTIONS = {
46 | "mean": lambda x: torch.mean(x, axis=-1),
47 | "max": lambda x: torch.max(x, axis=-1)[0],
48 | "min": lambda x: torch.min(x, axis=-1)[0],
49 | "sum": lambda x: torch.sum(x, axis=-1),
50 | }
51 |
52 |
53 | NORMTYPE = {
54 | "batchnorm": nn.BatchNorm1d,
55 | "instancenorm": nn.InstanceNorm1d,
56 | "none": NoNorm,
57 | }
58 |
59 | NONLINEARITIES = {
60 | "tanh": nn.Tanh(),
61 | "relu": nn.ReLU(),
62 | "softplus": nn.Softplus(),
63 | "elu": nn.ELU(),
64 | "swish": Swish(),
65 | "square": Lambda(lambda x: x ** 2),
66 | "identity": Lambda(lambda x: x),
67 | "leakyrelu": nn.LeakyReLU(),
68 | "tanh10x": Lambda(lambda x: torch.tanh(10 * x)),
69 | }
70 |
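These registries are looked up by name throughout the repo; a minimal sketch of the pattern (run with this module's names in scope):

    # Minimal sketch of the name-based lookups used elsewhere in the repo.
    norm = NORMTYPE["batchnorm"](64)        # nn.BatchNorm1d over 64 channels
    act = NONLINEARITIES["swish"]
    optim = OPTIMIZERS["adam"](norm.parameters(), lr=1e-3)
    loss = LOSSES["huber"](torch.rand(10), torch.rand(10))

Note that NONLINEARITIES stores module instances rather than classes, so stateful entries such as Swish share their learnable beta across every layer that retrieves them.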
--------------------------------------------------------------------------------
/shapeflow/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/shapeflow/utils/__init__.py
--------------------------------------------------------------------------------
/shapeflow/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | """Utility tools for training the model.
2 | """
3 | import logging
4 | import os
5 | import shutil
6 | import torch
7 | import numpy as np
8 | from matplotlib import cm, colors
9 |
10 | # pylint: disable=too-many-arguments
11 |
12 |
13 | def save_checkpoint(state, is_best, epoch, output_folder, filename, logger):
14 | """Save checkpoint.
15 | Args:
16 | state: dict, containing state of the model to save.
17 | is_best: bool, indicate whether this is the best model so far.
18 | epoch: int, epoch number.
19 | output_folder: str, path to output folder.
20 | filename: str, the name to save the model as.
21 | logger: logger object to log progress.
22 | """
23 | if epoch > 1:
24 | prev_ckpt = (
25 | output_folder + filename + "_%03d" % (epoch - 1) + ".pth.tar"
26 | )
27 | if os.path.exists(prev_ckpt):
28 | os.remove(prev_ckpt)
29 | torch.save(state, output_folder + filename + "_%03d" % epoch + ".pth.tar")
30 | print(output_folder + filename + "_%03d" % epoch + ".pth.tar")
31 | if is_best:
32 | if logger is not None:
33 | logger.info("Saving new best model")
34 |
35 | shutil.copyfile(
36 | output_folder + filename + "_%03d" % epoch + ".pth.tar",
37 | output_folder + filename + "_best.pth.tar",
38 | )
39 |
40 |
41 | def snapshot_files(list_of_filenames, log_dir):
42 | """Snapshot list of files in current run state to the log directory.
43 | Args:
44 | list_of_filenames: list of str.
45 | log_dir: str, log directory to save code snapshots.
46 | """
47 | snap_dir = os.path.join(log_dir, "snapshots")
48 | os.makedirs(snap_dir, exist_ok=True)
49 | for filename in list_of_filenames:
50 | if filename == os.path.basename(filename):
51 | shutil.copy2(filename, os.path.join(snap_dir, filename))
52 | else:
53 |             subdir = os.path.join(snap_dir, os.path.dirname(filename))
54 |             os.makedirs(subdir, exist_ok=True)
55 | shutil.copy2(filename, os.path.join(snap_dir, filename))
56 |
57 |
58 | def get_logger(
59 | log_dir, name="train", level=logging.DEBUG, log_file_name="log.txt"
60 | ):
61 | """Get a logger that writes a log file in log_dir.
62 | Args:
63 | log_dir: str, log directory to save logs.
64 | name: str, name of the logger instance.
65 | level: logging level.
66 | log_file_name: str, name of the log file to output.
67 | Returns:
68 | a logger instance
69 | """
70 | logger = logging.getLogger(name)
71 | logger.setLevel(level)
72 | logger.handlers = []
73 | stream_handler = logging.StreamHandler()
74 | logger.addHandler(stream_handler)
75 | file_handler = logging.FileHandler(
76 | os.path.join(log_dir, os.path.basename(log_file_name))
77 | )
78 | logger.addHandler(file_handler)
79 | return logger
80 |
81 |
82 | def colorize_scalar_tensors(
83 | x, vmin=None, vmax=None, cmap="viridis", out_channel="rgb"
84 | ):
85 | """Colorize scalar field tensors.
86 | Args:
87 | x: torch tensor of shape [H, W].
88 | vmin: float, min value to normalize the colors to.
89 | vmax: float, max value to normalize the colors to.
90 | cmap: str or Colormap instance, the colormap used to map normalized
91 | data values to RGBA colors.
92 | out_channel: str, 'rgb' or 'rgba'.
93 | Returns:
94 | y: torch tensor of shape [H, W, 3 (or 4 if out_channel=='rgba')],
95 | mapped colors.
96 | """
97 |     if (vmin is not None) or (vmax is not None):
98 | normalizer = colors.Normalize(vmin, vmax)
99 | else:
100 | normalizer = None
101 | assert out_channel in ["rgb", "rgba"]
102 |
103 | mapper = cm.ScalarMappable(norm=normalizer, cmap=cmap)
104 | x_ = x.detach().cpu().numpy()
105 |
106 | y_ = mapper.to_rgba(x_)[..., : len(out_channel)].astype(x_.dtype)
107 | y = torch.tensor(y_, device=x.device)
108 |
109 | return y
110 |
111 |
112 | def batch_colorize_scalar_tensors(
113 | x, vmin=None, vmax=None, cmap="viridis", out_channel="rgb"
114 | ):
115 | """Colorize scalar field tensors.
116 | Args:
117 | x: torch tensor of shape [N, H, W].
118 | vmin: float, or array of length N. min value to normalize the colors
119 | to.
120 | vmax: float, or array of length N. max value to normalize the colors
121 | to.
122 | cmap: str or Colormap instance, the colormap used to map normalized
123 | data values to RGBA
124 | colors.
125 | out_channel: str, 'rgb' or 'rgba'.
126 | Returns:
127 | y: torch tensor of shape [N, H, W, 3 (or 4 if out_channel=='rgba')]
128 | """
129 |
130 | def broadcast_limits(v):
131 |         if v is not None:
132 |             if not isinstance(v, np.ndarray):
133 | v = np.array(v)
134 | v = np.broadcast_to(v, x.shape[0])
135 | return v
136 |
137 | vmin = broadcast_limits(vmin)
138 | vmax = broadcast_limits(vmax)
139 | y = torch.zeros(list(x.shape) + [len(out_channel)], device=x.device)
140 | for idx in range(x.shape[0]):
141 |         y[idx] = colorize_scalar_tensors(x[idx], None if vmin is None else vmin[idx], None if vmax is None else vmax[idx], cmap, out_channel)
142 |
143 | return y
144 |
145 |
146 | def symmetric_duplication(points, symm_dim=2):
147 | """Symmetric duplication of points.
148 |
149 | Args:
150 | points: tensor of shape [batch, npoints, 3]
151 | symm_dim: int, direction of symmetry.
152 | Returns:
153 | duplicated points, tensor of shape [batch, 2*npoints, 3]
154 | """
155 | points_dup = points.clone()
156 | points_dup[..., symm_dim] = -points_dup[..., symm_dim]
157 | points_new = torch.cat([points, points_dup], dim=1)
158 |
159 | return points_new
160 |
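A short sketch of the visualization and symmetry helpers above (shapes are illustrative):

    # Illustrative sketch of colorization and symmetric duplication.
    field = torch.rand(64, 64)                         # [H, W] scalar field
    rgb = colorize_scalar_tensors(field, vmin=0.0, vmax=1.0)  # [H, W, 3]
    fields = torch.rand(5, 64, 64)                     # [N, H, W]
    rgbs = batch_colorize_scalar_tensors(fields)       # [N, H, W, 3]
    pts = torch.rand(2, 100, 3)                        # [batch, npoints, 3]
    pts_symm = symmetric_duplication(pts, symm_dim=2)  # [batch, 200, 3]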
--------------------------------------------------------------------------------
/shapenet_dataloader.py:
--------------------------------------------------------------------------------
1 | """ShapeNet deformation dataloader"""
2 | import os
3 | import torch
4 | from torch.utils.data import Dataset, Sampler
5 | import numpy as np
6 | import trimesh
7 | import glob
8 | import imageio
9 | import pickle
10 | from collections import OrderedDict
11 | from scipy.spatial import cKDTree
12 |
13 |
14 | synset_to_cat = {
15 | "02691156": "airplane",
16 | "02933112": "cabinet",
17 | "03001627": "chair",
18 | "03636649": "lamp",
19 | "04090263": "rifle",
20 | "04379243": "table",
21 | "04530566": "watercraft",
22 | "02828884": "bench",
23 | "02958343": "car",
24 | "03211117": "display",
25 | "03691459": "speaker",
26 | "04256520": "sofa",
27 | "04401088": "telephone",
28 | }
29 |
30 | cat_to_synset = {value: key for key, value in synset_to_cat.items()}
31 |
32 | SPLITS = ["train", "test", "val", "*"]
33 |
34 |
35 | def strip_name(filename):
36 | if len(filename.split("/")) > 3:
37 | return "/".join(filename.split("/")[-4:-1])
38 | else:
39 | return filename
40 |
41 |
42 | class ShapeNetBase(Dataset):
43 | """Pytorch Dataset base for loading ShapeNet shape pairs."""
44 |
45 | def __init__(self, data_root, split, category="chair"):
46 | """
47 | Initialize DataSet
48 | Args:
49 | data_root: str, path to data root that contains the ShapeNet dataset.
50 | split: str, one of 'train'/'val'/'test'/'*'. '*' for all splits.
51 | category:
52 | str, name of the category to train on. 'all' for all 13 classes.
53 | Otherwise can be a comma-separated string containing multiple names.
54 | """
55 | self.data_root = data_root
56 | self.split = split
57 |
58 | if not (split in SPLITS):
59 | raise ValueError(f"{split} must be one of {SPLITS}")
60 | self.categories = [c.strip() for c in category.split(",")]
61 | cats = list(cat_to_synset.keys())
62 | if "all" in self.categories:
63 | self.categories = cats
64 | for c in self.categories:
65 | if c not in cats:
66 | raise ValueError(
67 | f"{c} is not in the list of the 13 categories: {cats}"
68 | )
69 | self.files = self._get_filenames(
70 | self.data_root, self.split, self.categories
71 | )
72 | self._file_splits = None
73 |
74 | self.thumbnails_dir = None
75 | self.thumbnails = False
76 | self._fname_to_idx_dict = None
77 |
78 | @property
79 | def file_splits(self):
80 | if self._file_splits is None:
81 | self._file_splits = {"train": [], "test": [], "val": []}
82 | for f in self.files:
83 | if "train/" in f:
84 | self._file_splits["train"].append(f)
85 | elif "test/" in f:
86 | self._file_splits["test"].append(f)
87 | else: # val/
88 | self._file_splits["val"].append(f)
89 |
90 | return self._file_splits
91 |
92 | @staticmethod
93 | def _get_filenames(data_root, split, categories):
94 | files = []
95 | for c in categories:
96 | synset_id = cat_to_synset[c]
97 | if split != "*":
98 | cat_folder = os.path.join(data_root, split, synset_id)
99 | if not os.path.exists(cat_folder):
100 | raise RuntimeError(
101 | f"Datafolder for {synset_id} ({c}) "
102 | f"does not exist at {cat_folder}."
103 | )
104 | files += glob.glob(
105 | os.path.join(cat_folder, "*/*.ply"), recursive=True
106 | )
107 | else:
108 | for split in SPLITS[:3]:
109 | cat_folder = os.path.join(data_root, split, synset_id)
110 | if not os.path.exists(cat_folder):
111 | raise RuntimeError(
112 | f"Datafolder for {synset_id} ({c}) does not exist "
113 | f"at {cat_folder}."
114 | )
115 | files += glob.glob(
116 | os.path.join(cat_folder, "*/*.ply"), recursive=True
117 | )
118 | return sorted(files)
119 |
120 | def __len__(self):
121 | return self.n_shapes ** 2
122 |
123 | @property
124 | def n_shapes(self):
125 | return len(self.files)
126 |
127 | def restrict_subset(self, indices):
128 | """Restrict data to the subset of data as indicated by the indices.
129 |
130 | Mostly helpful for debugging only.
131 |
132 | Args:
133 | indices: list or array of ints, to index the original self.files
134 | """
135 | self.files = [self.files[i] for i in indices]
136 |
137 | @property
138 | def fname_to_idx_dict(self):
139 |         """A dict mapping unique mesh names to indices."""
140 | if self._fname_to_idx_dict is None:
141 | fnames = ["/".join(f.split("/")[-4:-1]) for f in self.files]
142 | self._fname_to_idx_dict = dict(
143 | zip(fnames, list(range(len(fnames))))
144 | )
145 | return self._fname_to_idx_dict
146 |
147 | def idx_to_combinations(self, idx):
148 |         """Convert a linear index to a pair of indices."""
149 | i = np.floor(idx / self.n_shapes)
150 | j = idx - i * self.n_shapes
151 | if hasattr(idx, "__len__"):
152 | i = np.array(i, dtype=int)
153 | j = np.array(j, dtype=int)
154 | else:
155 | i = int(i)
156 | j = int(j)
157 | return i, j
158 |
159 | def combinations_to_idx(self, i, j):
160 | """Convert a pair of indices to a linear index."""
161 | idx = i * self.n_shapes + j
162 | if hasattr(idx, "__len__"):
163 | idx = np.array(idx, dtype=int)
164 | else:
165 | idx = int(idx)
166 | return idx
167 |
168 |
169 | class ShapeNetVertex(ShapeNetBase):
170 | """Pytorch Dataset for sampling vertices from meshes."""
171 |
172 | def __init__(
173 | self, data_root, split, category="chair", nsamples=5000, normals=True
174 | ):
175 | """
176 | Initialize DataSet
177 | Args:
178 | data_root: str, path to data root that contains the ShapeNet dataset.
179 | split: str, one of 'train'/'val'/'test'.
180 | category: str, name of the category to train on. 'all' for all 13
181 | classes. Otherwise can be a comma-separated string containing
182 | multiple names.
183 | nsamples: int, number of points to sample from each mesh.
184 | normals: bool, whether to add normals to the point features.
185 | """
186 | super(ShapeNetVertex, self).__init__(
187 | data_root=data_root, split=split, category=category
188 | )
189 | self.nsamples = nsamples
190 | self.normals = normals
191 |
192 | @staticmethod
193 | def sample_mesh(mesh_path, nsamples, normals=True):
194 |         """Load the mesh from mesh_path and sample nsamples points from its vertices.
195 |
196 | If the mesh has fewer than nsamples vertices, randomly repeat some
197 | vertices as padding.
198 |
199 | Args:
200 | mesh_path: str, path to load the mesh from.
201 | nsamples: int, number of vertices to sample.
202 | normals: bool, whether to add normals to the point features.
203 | Returns:
204 | v_sample: np array of shape [nsamples, 3 or 6] for sampled points.
205 | """
206 | mesh = trimesh.load(mesh_path)
207 | v = np.array(mesh.vertices)
208 | nv = v.shape[0]
209 | seq = np.random.permutation(nv)[:nsamples]
210 | if len(seq) < nsamples:
211 | seq_repeat = np.random.choice(
212 | nv, nsamples - len(seq), replace=True
213 | )
214 | seq = np.concatenate([seq, seq_repeat], axis=0)
215 | v_sample = v[seq]
216 | if normals:
217 | n_sample = np.array(mesh.vertex_normals[seq])
218 | v_sample = np.concatenate([v_sample, n_sample], axis=-1)
219 |
220 | return v_sample
221 |
222 | def add_thumbnails(self, thumbnails_root):
223 | self.thumbnails = True
224 | self.thumbnails_dir = thumbnails_root
225 |
226 | def _get_one_mesh(self, idx):
227 | verts = self.sample_mesh(self.files[idx], self.nsamples, self.normals)
228 | verts = verts.astype(np.float32)
229 | if self.thumbnails:
230 | thumb_dir = self.files[idx].replace(
231 | self.data_root, self.thumbnails_dir
232 | )
233 | thumb_dir = os.path.dirname(thumb_dir)
234 | thumb_file = os.path.join(thumb_dir, "thumbnail.jpg")
235 | thumb = np.array(imageio.imread(thumb_file))
236 | return verts, thumb
237 | else:
238 | return verts
239 |
240 | def __getitem__(self, idx):
241 | """Get a random pair of shapes corresponding to idx.
242 | Args:
243 | idx: int, index of the shape pair to return. must be smaller than
244 | len(self).
245 | Returns:
246 | verts_i: [npoints, 3 or 6] float tensor for point samples from the
247 | first mesh.
248 | verts_j: [npoints, 3 or 6] float tensor for point samples from the
249 | second mesh.
250 | thumb_i: (optional) [H, W, 3] int8 tensor for thumbnail image for
251 | the first mesh.
252 | thumb_j: (optional) [H, W, 3] int8 tensor for thumbnail image for
253 | the second mesh.
254 | """
255 | i, j = self.idx_to_combinations(idx)
256 | if self.thumbnails:
257 | verts_i, thumb_i = self._get_one_mesh(i)
258 | verts_j, thumb_j = self._get_one_mesh(j)
259 | return i, j, verts_i, verts_j, thumb_i, thumb_j
260 | else:
261 | verts_i = self._get_one_mesh(i)
262 | verts_j = self._get_one_mesh(j)
263 | return i, j, verts_i, verts_j
264 |
265 |
266 | class ShapeNetMesh(ShapeNetBase):
267 | """Pytorch Dataset for sampling entire meshes."""
268 |
269 | def __init__(self, data_root, split, category="chair", normals=True):
270 | """
271 | Initialize DataSet
272 | Args:
273 | data_root: str, path to data root that contains the ShapeNet dataset.
274 | split: str, one of 'train'/'val'/'test'.
275 | category:
276 | str, name of the category to train on. 'all' for all 13 classes.
277 | Otherwise can be a comma separated string containing multiple
278 | names.
279 | """
280 | super(ShapeNetMesh, self).__init__(
281 | data_root=data_root, split=split, category=category
282 | )
283 | self.normals = normals
284 |
285 | def get_pairs(self, i, j):
286 | verts_i, faces_i = self.get_single(i)
287 | verts_j, faces_j = self.get_single(j)
288 |
289 | return i, j, verts_i, faces_i, verts_j, faces_j
290 |
291 | def get_single(self, i):
292 | mesh_i = trimesh.load(self.files[i])
293 |
294 | verts_i = mesh_i.vertices.astype(np.float32)
295 | faces_i = mesh_i.faces.astype(np.int32)
296 |
297 | if self.normals:
298 | norms_i = mesh_i.vertex_normals.astype(np.float32)
299 | verts_i = np.concatenate([verts_i, norms_i], axis=-1)
300 |
301 | verts_i = torch.from_numpy(verts_i)
302 | faces_i = torch.from_numpy(faces_i)
303 |
304 | return verts_i, faces_i
305 |
306 | def __getitem__(self, idx):
307 | """Get a random pair of meshes.
308 | Args:
309 | idx: int, index of the shape pair to return. must be smaller than
310 | len(self).
311 | Returns:
312 | verts_i: [#vi, 3 or 6] float tensor for vertices from the first mesh.
313 | faces_i: [#fi, 3] int32 tensor for faces from the first mesh.
314 | verts_j: [#vj, 3 or 6] float tensor for vertices from the 2nd mesh.
315 | faces_j: [#fj, 3] int32 tensor for faces from the 2nd mesh.
316 | """
317 | i, j = self.idx_to_combinations(idx)
318 | return self.get_pairs(i, j)
319 |
320 |
321 | class FixedPointsCachedDataset(Dataset):
322 | """Dataset for loading fixed points dataset from cached pickle file."""
323 |
324 | def __init__(self, pkl_file, npts=1024):
325 | with open(pkl_file, "rb") as fh:
326 | self.data_dict = pickle.load(fh)
327 | self.data_dict = OrderedDict(sorted(self.data_dict.items()))
328 | self.key_list = list(self.data_dict.keys())
329 | assert npts <= 4096 and npts > 0
330 | self.npts = npts
331 |
332 | def __getitem__(self, idx):
333 | filename = self.key_list[idx]
334 | points = self.data_dict[filename]
335 | rand_seq = np.random.choice(points.shape[0], self.npts, replace=False)
336 | points_ = points[rand_seq]
337 | return filename, idx, points_
338 |
339 | def __len__(self):
340 | return len(self.data_dict)
341 |
342 |
343 | class PairSamplerBase(Sampler):
344 | """Data sampler base for sampling pairs."""
345 |
346 | def __init__(
347 | self, dataset, src_split, tar_split, n_samples, replace=False
348 | ):
349 | assert src_split in SPLITS[:3]
350 | assert tar_split in SPLITS[:3]
351 | self.replace = replace
352 | self.n_samples = n_samples
353 | self.src_split = src_split
354 | self.tar_split = tar_split
355 | self.dataset = dataset
356 | self.src_files = self.dataset.file_splits[src_split]
357 | self.tar_files = self.dataset.file_splits[tar_split]
358 | self.src_files = [strip_name(f) for f in self.src_files]
359 | self.tar_files = [strip_name(f) for f in self.tar_files]
360 | self.n_src = len(self.src_files)
361 | self.n_tar = len(self.tar_files)
362 | if not replace:
363 |             if not self.n_samples <= self.n_src:
364 |                 raise RuntimeError(
365 |                     f"Number of samples ({self.n_samples}) must not "
366 |                     f"exceed the number of source shapes ({self.n_src})."
367 |                 )
368 |
369 | def __iter__(self):
370 | raise NotImplementedError
371 |
372 | def __len__(self):
373 | raise NotImplementedError
374 |
375 |
376 | class RandomPairSampler(PairSamplerBase):
377 | """Data sampler for sampling random pairs."""
378 |
379 | def __init__(
380 | self, dataset, src_split, tar_split, n_samples, replace=False
381 | ):
382 | super(RandomPairSampler, self).__init__(
383 | dataset, src_split, tar_split, n_samples, replace
384 | )
385 |
386 | def __iter__(self):
387 | d = self.dataset
388 | if self.replace:
389 | src_names = np.random.choice(
390 | self.src_files, self.n_samples, replace=True
391 | )
392 | tar_names = np.random.choice(
393 | self.tar_files, self.n_samples, replace=True
394 | )
395 | else:
396 | src_names = np.random.permutation(self.src_files)[
397 | : int(self.n_samples)
398 | ]
399 | tar_names = np.random.permutation(self.tar_files)[
400 | : int(self.n_samples)
401 | ]
402 | src_idxs = np.array(
403 | [d.fname_to_idx_dict[strip_name(f)] for f in src_names]
404 | )
405 | tar_idxs = np.array(
406 | [d.fname_to_idx_dict[strip_name(f)] for f in tar_names]
407 | )
408 | combo_ids = self.dataset.combinations_to_idx(src_idxs, tar_idxs)
409 |
410 | return iter(combo_ids)
411 |
412 | def __len__(self):
413 | return self.n_samples
414 |
415 |
416 | class LatentNearestNeighborSampler(PairSamplerBase):
417 | """Data sampler for sampling pairs from top-k nearest latent neighbors."""
418 |
419 | def __init__(
420 | self, dataset, src_split, tar_split, n_samples, k, replace=False
421 | ):
422 | """Initialize.
423 |
424 | Args:
425 | k: int, top-k neighbors to sample from.
426 | replace: bool, sample with replacement.
427 | if no replace, then must ensure n_samples <= n_shapes
428 | """
429 | super(LatentNearestNeighborSampler, self).__init__(
430 | dataset, src_split, tar_split, n_samples, replace
431 | )
432 | self.k = k
433 | self.graph_set = False
434 |
435 | def update_nn_graph(self, src_latent_dict, tar_latent_dict, k=None):
436 | """Update nearest neighbor graph.
437 |
438 | Args:
439 | src_latent_dict: a dict that maps filenames to latent codes for
440 | source set.
441 | tar_latent_dict: a dict that maps filenames to latent codes for
442 | target set.
443 | """
444 | if k is not None:
445 | self.k = k
446 | tar_names = list(tar_latent_dict.keys())
447 | tar_latents = list(tar_latent_dict.values())
448 | tar_latents = np.stack(tar_latents, axis=0) # [n, lat_dim]
449 | # build kd-tree to accelerate nearest neighbor computation
450 | k = self.k + 1 if self.src_split == self.tar_split else self.k
451 | self._kdtree = cKDTree(tar_latents)
452 |
453 | src_names = list(src_latent_dict.keys())
454 | src_latents = list(src_latent_dict.values())
455 | src_latents = np.stack(src_latents, axis=0) # [m, lat_dim]
456 | _, nn_idx = self._kdtree.query(src_latents, k=k) # [m, k]
457 | if nn_idx.ndim == 1:
458 | nn_idx = nn_idx[:, None]
459 |         nn_idx = nn_idx[:, -self.k:]  # drops the nearest (self) match when src == tar
460 |
461 | nn_names = []
462 | for i in range(nn_idx.shape[0]):
463 | nn_names.append([tar_names[j] for j in nn_idx[i]])
464 | self._nn_map = dict(zip(src_names, nn_names))
465 | self.graph_set = True
466 |
467 | @property
468 | def kdtree(self):
469 | return self._kdtree
470 |
471 | @property
472 | def nn_map(self):
473 | return self._nn_map
474 |
475 | def __iter__(self):
476 | if not self.graph_set:
477 | raise RuntimeError(
478 | "Nearest neighbor graph not yet set."
479 | " Run '.update_nn_graph()' to update first."
480 | )
481 | d = self.dataset
482 | # return generator
483 | if self.replace:
484 | src_names = np.random.choice(
485 | self.src_files, self.n_samples, replace=True
486 | )
487 | else:
488 | src_names = np.random.permutation(self.src_files)[
489 | : int(self.n_samples)
490 | ]
491 |
492 | for src_name in src_names:
493 | tar_name = np.random.choice(self.nn_map[src_name], 1)[0]
494 | src_idx = d.fname_to_idx_dict[strip_name(src_name)]
495 | tar_idx = d.fname_to_idx_dict[strip_name(tar_name)]
496 | combo_id = d.combinations_to_idx(src_idx, tar_idx)
497 | yield combo_id
498 |
499 | def __len__(self):
500 | return self.n_samples
501 |
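A sketch of wiring the dataset and pair sampler into a DataLoader (the data path is a placeholder, and at least n_samples training shapes are assumed):

    # Illustrative sketch; "data/shapenet_simplified" is a placeholder path.
    dset = ShapeNetVertex(
        "data/shapenet_simplified", split="train", category="chair",
        nsamples=512,
    )
    sampler = RandomPairSampler(
        dset, src_split="train", tar_split="train", n_samples=256
    )
    loader = torch.utils.data.DataLoader(
        dset, batch_size=8, sampler=sampler, drop_last=True
    )
    i, j, verts_i, verts_j = next(iter(loader))  # verts_*: [8, 512, 6]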
--------------------------------------------------------------------------------
/shapenet_embedding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial import cKDTree
3 | import time
4 | import torch
5 | from torch.utils.data import SubsetRandomSampler, DataLoader
6 |
7 | from shapeflow.layers.chamfer_layer import ChamferDistKDTree
8 | from shapeflow.layers.shared_definition import LOSSES, OPTIMIZERS
9 | import shapeflow.utils.train_utils as utils
10 |
11 |
12 | class LatentEmbedder(object):
13 |     """Helper class for embedding new observations in the deformation
14 |     latent space."""
15 |
16 | def __init__(self, point_dataset, mesh_dataset, deformer, topk=5):
17 | """Initialize embedder.
18 |
19 | Args:
20 | point_dataset: instance of FixedPointsCachedDataset
21 | mesh_dataset: instance of ShapeNetMesh
22 |             deformer: pretrained deformer instance.
23 | """
24 | self.point_dataset = point_dataset
25 | self.mesh_dataset = mesh_dataset
26 | self.deformer = deformer
27 | self.topk = topk
28 | self.tree = cKDTree(self.lat_params.clone().detach().cpu().numpy())
29 |
30 | @property
31 | def lat_dims(self):
32 | return self.lat_params.shape[1]
33 |
34 | @property
35 | def lat_params(self):
36 | return self.deformer.net.lat_params
37 |
38 | @property
39 | def symm(self):
40 | return self.deformer.symm_dim is not None
41 |
42 | @property
43 | def device(self):
44 | return self.lat_params.device
45 |
46 | def _padded_verts_from_meshes(self, meshes):
47 | verts = [vf[0] for vf in meshes]
48 | faces = [vf[1] for vf in meshes]
49 | nv = [v.shape[0] for v in verts]
50 | max_nv = np.max(nv)
51 | verts_pad = [
52 | np.pad(verts[i], ((0, max_nv - nv[i]), (0, 0)))
53 | for i in range(len(nv))
54 | ]
55 | verts_pad = np.stack(verts_pad, 0) # [nmesh, max_nv, 3]
56 | return verts_pad, faces, nv
57 |
58 | def _meshes_from_padded_verts(self, verts_pad, faces, nv):
59 | verts_pad = [v for v in verts_pad]
60 | verts = [v[:n] for v, n in zip(verts_pad, nv)]
61 | meshes = list(zip(verts, faces))
62 | return meshes
63 |
64 | def embed(
65 | self,
66 | input_points,
67 | optimizer="adam",
68 | lr=1e-3,
69 | seed=0,
70 | embedding_niter=30,
71 | finetune_niter=30,
72 | bs=32,
73 | verbose=False,
74 | matching="two_way",
75 | loss_type="l1",
76 | ):
77 |         """Embed input point observations into the deformation latent space.
78 |
79 | Args:
80 | input_points: tensor of shape [bs_tar, npoints, 3]
81 | optimizer: str, optimizer choice. one of sgd, adam, adadelta,
82 | adagrad, rmsprop.
83 | lr: float, learning rate.
84 | seed: int, random seed.
85 | embedding_niter: int, number of embedding optimization iterations.
86 | finetune_niter: int, number of finetuning optimization iterations.
87 | bs: int, batch size.
88 | verbose: bool, turn on verbose.
89 | matching: str, matching function. choice of one_way or two_way.
90 | loss_type: str, loss type. choice of l1, l2, huber.
91 |
92 | Returns:
93 | embedded_latents: tensor of shape [batch, lat_dims]
94 | """
95 | if input_points.shape[0] != 1:
96 | raise NotImplementedError("Code is not ready for batch size > 1.")
97 | torch.manual_seed(seed)
98 |
99 | # Check input validity.
100 | if matching not in ["one_way", "two_way"]:
101 | raise ValueError(
102 | f"matching method must be one of one_way / two_way. Instead "
103 | f"entered {matching}"
104 | )
105 | if loss_type not in LOSSES.keys():
106 | raise ValueError(
107 | f"loss_type must be one of {LOSSES.keys()}. "
108 | f"Instead entered {loss_type}"
109 | )
110 |
111 | criterion = LOSSES[loss_type]
112 |
113 | bs_tar, npts_tar, _ = input_points.shape
114 | # Assign random latent code close to zero.
115 | embedded_latents = torch.nn.Parameter(
116 | torch.randn(bs_tar, self.lat_dims, device=self.device) * 1e-4,
117 | requires_grad=True,
118 | )
119 | self.deformer.net.tar_latents = embedded_latents
120 | embedded_latents = self.deformer.net.tar_latents
121 | # [bs_tar, lat_dims]
122 |
123 | # Init optimizer.
124 | if optimizer not in OPTIMIZERS.keys():
125 | raise ValueError(f"optimizer must be one of {OPTIMIZERS.keys()}")
126 | optim = OPTIMIZERS[optimizer]([embedded_latents], lr=lr)
127 |
128 | # Init dataloader.
129 | sampler = SubsetRandomSampler(
130 | np.arange(len(self.point_dataset)).tolist()
131 | )
132 | point_loader = DataLoader(
133 | self.point_dataset,
134 | batch_size=bs,
135 | sampler=sampler,
136 | shuffle=False,
137 | drop_last=True,
138 | )
139 |
140 | # Chamfer distance calc.
141 | chamfer_dist = ChamferDistKDTree(reduction="mean", njobs=1)
142 | chamfer_dist.to(self.device)
143 |
144 | def optimize_latent(point_loader, optim, niter):
145 | # Optimize for latents.
146 | self.deformer.train()
147 | toc = time.time()
148 |
149 | bs_src = point_loader.batch_size
150 | embedded_latents_ = embedded_latents[None].expand(
151 | bs_src, bs_tar, self.lat_dims
152 | )
153 | # [bs_src, bs_tar, lat_dims]
154 |
155 | # Broadcast and reshape input points.
156 | target_points_ = (
157 | input_points[None]
158 | .expand(bs_src, bs_tar, npts_tar, 3)
159 | .view(-1, npts_tar, 3)
160 | )
161 | it = 0
162 |
163 | for batch_idx, (fnames, idxs, source_points) in enumerate(
164 | point_loader
165 | ):
166 | tic = time.time()
167 | # Send tensors to device.
168 | source_points = source_points.to(
169 | self.device
170 | ) # [bs_src, npts_src, 3]
171 | idxs = idxs.to(self.device)
172 |
173 | optim.zero_grad()
174 |
175 | # Deform chosen points to input_points.
176 | # Broadcast src lats to src x tar.
177 | source_latents = self.lat_params[idxs] # [bs_src, lat_dims]
178 | source_latents_ = source_latents[:, None].expand(
179 | bs_src, bs_tar, self.lat_dims
180 | )
181 | source_latents_ = source_latents_.view(-1, self.lat_dims)
182 | target_latents_ = embedded_latents_.view(-1, self.lat_dims)
183 | zeros = torch.zeros_like(source_latents_)
184 | source_target_latents = torch.stack(
185 | [source_latents_, zeros, target_latents_], dim=1
186 | )
187 |
188 | deformed_pts = self.deformer(
189 | source_points,
190 | source_target_latents, # [bs_sr*bs_tar, npts_src, 3]
191 | ) # [bs_sr*bs_tar, npts_src, 3]
192 |
193 | # Symmetric pair of matching losses.
194 | if self.symm:
195 | accu, comp, cham = chamfer_dist(
196 | utils.symmetric_duplication(deformed_pts, symm_dim=2),
197 | utils.symmetric_duplication(
198 | target_points_, symm_dim=2
199 | ),
200 | )
201 | else:
202 | accu, comp, cham = chamfer_dist(
203 | deformed_pts, target_points_
204 | )
205 |
206 | if matching == "one_way":
207 | comp = torch.mean(comp, dim=1)
208 | loss = criterion(comp, torch.zeros_like(comp))
209 | else:
210 | loss = criterion(cham, torch.zeros_like(cham))
211 |
212 | # Check amount of deformation.
213 | deform_abs = torch.mean(
214 | torch.norm(deformed_pts - source_points, dim=-1)
215 | )
216 |
217 | loss.backward()
218 |
219 | # Gradient clipping.
220 | torch.nn.utils.clip_grad_value_(embedded_latents, 1.0)
221 |
222 | optim.step()
223 |
224 | toc = time.time()
225 | if verbose:
226 | if loss_type == "l1":
227 | dist = loss.item()
228 | else:
229 | dist = np.sqrt(loss.item())
230 | print(
231 | f"Iter: {it}, Loss: {loss.item():.4f}, "
232 | f"Dist: {dist:.4f}, "
233 | f"Deformation Magnitude: {deform_abs.item():.4f}, "
234 | f"Time per iter (s): {toc-tic:.4f}"
235 | )
236 | it += 1
237 | if batch_idx >= niter:
238 | break
239 |
240 | # Optimize to range.
241 | optimize_latent(point_loader, optim, embedding_niter)
242 | latents_pre_tune = embedded_latents.detach().cpu().numpy()
243 |
244 | # Finetune topk.
245 | dist, idxs = self.tree.query(
246 | embedded_latents.detach().cpu().numpy(), k=self.topk
247 | ) # [batch, k]
248 | bs, k = idxs.shape
249 | idxs_ = idxs.reshape(-1)
250 |
251 | # Change lr.
252 | for param_group in optim.param_groups:
253 | param_group["lr"] = 1e-3
254 |
255 | sampler = SubsetRandomSampler(idxs_.tolist() * finetune_niter)
256 | point_loader = DataLoader(
257 | self.point_dataset,
258 | batch_size=self.topk,
259 | sampler=sampler,
260 | shuffle=False,
261 | drop_last=True,
262 | )
263 |
264 | print(f"Finetuning for {finetune_niter} iters...")
265 | optim = OPTIMIZERS[optimizer](
266 | [embedded_latents] + list(self.deformer.parameters()), lr=1e-3
267 | )
268 | optimize_latent(point_loader, optim, finetune_niter)
269 | latents_post_tune = embedded_latents.detach().cpu().numpy()
270 |
271 | return latents_pre_tune, latents_post_tune
272 |
273 | def retrieve(self, lat_codes, tar_pts, matching="one_way"):
274 |         """Retrieve the top-k nearest neighbors, deform and pick the best.
275 |
276 | Args:
277 |             lat_codes: tensor of shape [batch, lat_dims], latent code targets.
278 |             tar_pts: [npoints, 3] target points. matching: one_way / two_way.
279 |         Returns:
280 |             (deformed_meshes, orig_meshes, dist): deformed and original (V, F) tuples, plus per-candidate distances.
281 | """
282 | if lat_codes.shape[0] != 1:
283 | raise NotImplementedError("Code is not ready for batch size > 1.")
284 | dist, idxs = self.tree.query(lat_codes, k=self.topk) # [batch, k]
285 | bs, k = idxs.shape
286 | idxs_ = idxs.reshape(-1)
287 |
288 | if not isinstance(lat_codes, torch.Tensor):
289 | lat_codes = torch.tensor(lat_codes).float().to(self.device)
290 |
291 | src_latent = self.lat_params[idxs_] # [batch*k, lat_dims]
292 | tar_latent = (
293 | lat_codes[:, None]
294 | .expand(bs, k, self.lat_dims)
295 | .reshape(-1, self.lat_dims)
296 | ) # [batch*k, lat_dims]
297 | zeros = torch.zeros_like(src_latent)
298 | src_tar_latent = torch.stack([src_latent, zeros, tar_latent], dim=1)
299 |
300 | # Retrieve meshes.
301 | orig_meshes = [
302 | self.mesh_dataset.get_single(i) for i in idxs_
303 | ] # [(v1,f1), ..., (vn,fn)]
304 | src_verts, faces, nv = self._padded_verts_from_meshes(orig_meshes)
305 | src_verts = torch.tensor(src_verts).to(self.device)
306 | with torch.no_grad():
307 | deformed_verts = self.deformer(src_verts, src_tar_latent)
308 | deformed_meshes = self._meshes_from_padded_verts(
309 | deformed_verts, faces, nv
310 | )
311 |
312 | # Chamfer distance calc.
313 | chamfer_dist = ChamferDistKDTree(reduction="mean", njobs=1)
314 | chamfer_dist.to(self.device)
315 | dist = []
316 | for i in range(len(deformed_meshes)):
317 | accu, comp, cham = chamfer_dist(
318 | deformed_meshes[i][0][None].to(self.device),
319 | torch.tensor(tar_pts)[None].to(self.device),
320 | )
321 | if matching == "one_way":
322 | dist.append(torch.mean(comp, dim=1).item())
323 | else:
324 | dist.append(cham.item())
325 |
326 | # Reshape the list of (v, f) tuples.
327 | deformed_meshes = [
328 | (vf[0].detach().cpu().numpy(), vf[1].detach().cpu().numpy())
329 | for vf in deformed_meshes
330 | ]
331 |
332 | return deformed_meshes, orig_meshes, dist
333 |
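Condensed, the embed-then-retrieve pipeline looks as follows (a sketch assuming the datasets, pretrained deformer, and input points are set up exactly as in /shapenet_reconstruct.py below):

    # Sketch; see /shapenet_reconstruct.py for the full pipeline.
    embedder = LatentEmbedder(point_dataset, mesh_dataset, deformer, topk=5)
    _, lat_post = embedder.embed(input_pts, matching="two_way", verbose=True)
    deformed, originals, dists = embedder.retrieve(
        lat_post, tar_pts=points, matching="two_way"
    )
    best_verts, best_faces = deformed[int(np.argmin(dists))]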
--------------------------------------------------------------------------------
/shapenet_generation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # CONFIG
4 | category=03001627
5 | ckpt=runs/pretrained_chair_symm128/checkpoint_latest.pth.tar_shapeflow_100.pth.tar
6 |
7 | gpuid=$1 # gpu id
8 | sid=$2 # start id
9 | eid=$3 # end id
10 |
11 | export CUDA_VISIBLE_DEVICES=$gpuid
12 |
13 | allfiles=(data/sparse_inputs/$category/*.ply)
14 |
15 | for i in $(seq $sid $((eid-1))); do
16 |     infile="${allfiles[$i]}"
17 |     python shapenet_reconstruct.py --input_path="$infile" --output_dir=out --checkpoint="$ckpt" --device=cuda
18 | done
19 |
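Example invocation (GPU 0, reconstructing sparse inputs 0 through 99):

    bash shapenet_generation.sh 0 0 100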
--------------------------------------------------------------------------------
/shapenet_reconstruct.py:
--------------------------------------------------------------------------------
1 | """Reconstruct shape from point cloud using learned deformation space.
2 | """
3 | import os
4 | import sys
5 | import argparse
6 | import json
7 | import trimesh
8 | import torch
9 | import numpy as np
10 | import time
11 | from types import SimpleNamespace
12 |
13 | from shapenet_dataloader import ShapeNetMesh, FixedPointsCachedDataset
14 | from shapeflow.layers.deformation_layer import NeuralFlowDeformer
15 | from shapenet_embedding import LatentEmbedder
16 |
17 |
18 | synset_to_cat = {
19 | "02691156": "airplane",
20 | "02933112": "cabinet",
21 | "03001627": "chair",
22 | "03636649": "lamp",
23 | "04090263": "rifle",
24 | "04379243": "table",
25 | "04530566": "watercraft",
26 | "02828884": "bench",
27 | "02958343": "car",
28 | "03211117": "display",
29 | "03691459": "speaker",
30 | "04256520": "sofa",
31 | "04401088": "telephone",
32 | }
33 |
34 | cat_to_synset = {value: key for key, value in synset_to_cat.items()}
35 |
36 |
37 | def get_args():
38 | """Parse command line arguments."""
39 | parser = argparse.ArgumentParser(
40 | description="Generate reconstructions via retrieve and deform."
41 | )
42 |
43 | parser.add_argument(
44 | "--input_path",
45 | type=str,
46 | required=True,
47 | help="path to input points (.ply file).",
48 | )
49 | parser.add_argument(
50 | "--output_dir", type=str, required=True, help="path to output meshes."
51 | )
52 | parser.add_argument(
53 | "--topk",
54 | type=int,
55 | default=4,
56 | help="top k nearest neighbor to retrieve.",
57 | )
58 | parser.add_argument(
59 | "-ne",
60 | "--embedding_niter",
61 | type=int,
62 | default=30,
63 | help="number of embedding iterations.",
64 | )
65 | parser.add_argument(
66 | "-nf",
67 | "--finetune_niter",
68 | type=int,
69 | default=30,
70 | help="number of finetuning iterations.",
71 | )
72 | parser.add_argument(
73 | "--checkpoint",
74 | type=str,
75 | required=True,
76 | help="path to pretrained checkpoint "
77 | "(params.json must be in the same directory).",
78 | )
79 | parser.add_argument(
80 | "--device",
81 | type=str,
82 | default="cuda:0",
83 | help="device to run inference on.",
84 | )
85 | args = parser.parse_args()
86 | return args
87 |
88 |
89 | def main():
90 | t0 = time.time()
91 | args_eval = get_args()
92 |
93 | device = torch.device(args_eval.device)
94 |
95 | # load training args
96 | run_dir = os.path.dirname(args_eval.checkpoint)
97 | args = SimpleNamespace(
98 | **json.load(open(os.path.join(run_dir, "params.json"), "r"))
99 | )
100 |
101 | # assert category is correct
102 | syn_id = args_eval.input_path.split("/")[-2]
103 | mesh_name = args_eval.input_path.split("/")[-1]
104 | assert syn_id == cat_to_synset[args.category]
105 |
106 | # output directories
107 | mesh_out_dir = os.path.join(args_eval.output_dir, "meshes", syn_id)
108 | mesh_out_file = os.path.join(
109 | mesh_out_dir, mesh_name.replace(".ply", ".off")
110 | )
111 | meta_out_dir = os.path.join(
112 | args_eval.output_dir, "meta", syn_id, mesh_name.replace(".ply", "")
113 | )
114 | orig_dir = os.path.join(meta_out_dir, "original_retrieved")
115 | deformed_dir = os.path.join(meta_out_dir, "deformed")
116 | os.makedirs(mesh_out_dir, exist_ok=True)
117 | os.makedirs(meta_out_dir, exist_ok=True)
118 | os.makedirs(orig_dir, exist_ok=True)
119 | os.makedirs(deformed_dir, exist_ok=True)
120 |
121 | # redirect logging
122 | sys.stdout = open(os.path.join(meta_out_dir, "log.txt"), "w")
123 |
124 | # initialize deformer
125 | # input points
126 | points = np.array(trimesh.load(args_eval.input_path).vertices)
127 |
128 | # dataloader
129 | data_root = args.data_root
130 | mesh_dataset = ShapeNetMesh(
131 | data_root=data_root,
132 | split="train",
133 | category=args.category,
134 | normals=False,
135 | )
136 | point_dataset = FixedPointsCachedDataset(
137 | f"data/shapenet_pointcloud/train/{cat_to_synset[args.category]}.pkl",
138 | npts=300,
139 | )
140 |
141 | # setup model
142 | deformer = NeuralFlowDeformer(
143 | latent_size=args.lat_dims,
144 | f_width=args.deformer_nf,
145 | s_nlayers=2,
146 | s_width=5,
147 | method=args.solver,
148 | nonlinearity=args.nonlin,
149 | arch="imnet",
150 | adjoint=args.adjoint,
151 | rtol=args.rtol,
152 | atol=args.atol,
153 | via_hub=True,
154 | no_sign_net=(not args.sign_net),
155 | symm_dim=(2 if args.symm else None),
156 | )
157 |
158 | lat_params = torch.nn.Parameter(
159 | torch.randn(mesh_dataset.n_shapes, args.lat_dims) * 1e-1,
160 | requires_grad=True,
161 | )
162 | deformer.add_lat_params(lat_params)
163 | deformer.to(device)
164 |
165 | # load checkpoint
166 | resume_dict = torch.load(args_eval.checkpoint)
167 | deformer.load_state_dict(resume_dict["deformer_state_dict"])
168 |
169 | # embed
170 |     embedder = LatentEmbedder(point_dataset, mesh_dataset, deformer, topk=args_eval.topk)
171 | input_pts = torch.tensor(points)[None].to(device)
172 | lat_codes_pre, lat_codes_post = embedder.embed(
173 | input_pts,
174 | matching="two_way",
175 | verbose=True,
176 | lr=1e-2,
177 | embedding_niter=args_eval.embedding_niter,
178 | finetune_niter=args_eval.finetune_niter,
179 | bs=4,
180 | seed=1,
181 | )
182 |
183 | # retrieve deformed models
184 | deformed_meshes, orig_meshes, dist = embedder.retrieve(
185 | lat_codes_post, tar_pts=points, matching="two_way"
186 | )
187 | asort = np.argsort(dist)
188 | dist = [dist[i] for i in asort]
189 | deformed_meshes = [deformed_meshes[i] for i in asort]
190 | orig_meshes = [orig_meshes[i] for i in asort]
191 |
192 |     # output best mesh
193 | vb, fb = deformed_meshes[0]
194 | trimesh.Trimesh(vb, fb).export(mesh_out_file)
195 |
196 | # meta directory
197 | for i in range(len(deformed_meshes)):
198 | vo, fo = orig_meshes[i]
199 | vd, fd = deformed_meshes[i]
200 | trimesh.Trimesh(vo, fo).export(os.path.join(orig_dir, f"{i}.ply"))
201 | trimesh.Trimesh(vd, fd).export(os.path.join(deformed_dir, f"{i}.ply"))
202 | np.save(os.path.join(meta_out_dir, "latent.npy"), lat_codes_pre)
203 | t1 = time.time()
204 |     print(f"Total time elapsed (s): {t1-t0:.4f}")
205 |
206 |
207 | if __name__ == "__main__":
208 | main()
209 |
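Example invocation (the paths are placeholders; params.json must sit next to the checkpoint):

    python shapenet_reconstruct.py \
        --input_path data/sparse_inputs/03001627/<model>.ply \
        --output_dir out \
        --checkpoint runs/<run_name>/<checkpoint>.pth.tar \
        --device cuda:0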
--------------------------------------------------------------------------------
/shapenet_train.py:
--------------------------------------------------------------------------------
1 | """Training script for the ShapeNet deformation space experiment.
2 | """
3 | import argparse
4 | import json
5 | import os
6 | import glob
7 | import numpy as np
8 | import time
9 | import trimesh
10 |
11 | import shapeflow.utils.train_utils as utils
12 | from shapeflow.layers.chamfer_layer import ChamferDistKDTree
13 | from shapeflow.layers.deformation_layer import NeuralFlowDeformer
14 | import shapenet_dataloader as dl
15 |
16 | import torch
17 | import torch.optim as optim
18 | import torch.nn as nn
19 | from torch.utils.data import DataLoader
20 | from torch.utils.tensorboard import SummaryWriter
21 |
22 | np.set_printoptions(precision=4)
23 |
24 |
25 | # Various choices for losses and optimizers.
26 | LOSSES = {
27 | "l1": torch.nn.L1Loss(),
28 | "l2": torch.nn.MSELoss(),
29 | "huber": torch.nn.SmoothL1Loss(),
30 | }
31 |
32 | OPTIMIZERS = {
33 | "sgd": optim.SGD,
34 | "adam": optim.Adam,
35 | "adadelta": optim.Adadelta,
36 | "adagrad": optim.Adagrad,
37 | "rmsprop": optim.RMSprop,
38 | }
39 |
40 | SOLVERS = [
41 | "dopri5",
42 | "adams",
43 | "euler",
44 | "midpoint",
45 | "rk4",
46 | "explicit_adams",
47 | "fixed_adams",
48 | "bosh3",
49 | "adaptive_heun",
50 | "tsit5",
51 | ]
52 |
53 |
54 | def compute_latent_dict(deformer, dataset):
55 | """
56 | Args:
57 |         deformer: deformer instance (possibly wrapped in nn.DataParallel).
58 |         dataset: dataset with a file_splits property (e.g. ShapeNetBase).
59 | Returns:
60 | a dict that maps filenames to latent codes.
61 | """
62 | # Encode all shapes from dataloader into latents.
63 | all_filenames = dataset.file_splits["train"]
64 | all_filenames = [dl.strip_name(f) for f in all_filenames]
65 | if isinstance(deformer, nn.DataParallel):
66 | all_latents = deformer.module.net.lat_params.detach().cpu().numpy()
67 | else:
68 | all_latents = deformer.net.lat_params.detach().cpu().numpy()
69 |
70 | return dict(zip(all_filenames, all_latents))
71 |
72 |
73 | def get_k(epoch):  # Top-k schedule that shrinks as training progresses.
74 | if epoch < 10:
75 | return 4000
76 | elif epoch < 50:
77 | return 800
78 | elif epoch < 80:
79 | return 100
80 | else:
81 | return 10
82 |
83 |
84 | def train_or_eval(
85 | mode,
86 | args,
87 | deformer,
88 | chamfer_dist,
89 | dataloader,
90 | epoch,
91 | global_step,
92 | device,
93 | logger,
94 | writer,
95 | optimizer,
96 | vis_loader=None,
97 | ):
98 | """Training / Eval function."""
99 | modes = ["train", "eval"]
100 | if mode not in modes:
101 | raise ValueError(f"mode ({mode}) must be one of {modes}.")
102 | if mode == "train":
103 | deformer.train()
104 | else:
105 | deformer.eval()
106 | tot_loss = 0
107 | count = 0
108 | criterion = LOSSES[args.loss_type]
109 | epoch_images = []
110 | epoch_latents = []
111 |
112 | with torch.set_grad_enabled(mode == "train"):
113 | toc = time.time()
114 |
115 | for batch_idx, data_tensors in enumerate(dataloader):
116 | tic = time.time()
117 | # Send tensors to device.
118 | data_tensors = [t.to(device) for t in data_tensors]
119 | (
120 | ii,
121 | jj,
122 | source_pts,
123 | target_pts,
124 | source_img,
125 | target_img,
126 | ) = data_tensors
127 |
128 | bs = len(source_pts)
129 | optimizer.zero_grad()
130 |
131 | # Batch together source and target to create two-way loss training.
132 | # Cannot call deformer twice (once for each way) because that
133 |             # breaks odeint_adjoint's gradient computation (reason unclear).
134 | source_target_points = torch.cat([source_pts, target_pts], dim=0)
135 | target_source_points = torch.cat([target_pts, source_pts], dim=0)
136 |
137 | source_target_latents = torch.cat([ii, jj], dim=0)
138 | target_source_latents = torch.cat([jj, ii], dim=0)
139 | latent_seq = torch.stack(
140 | [source_target_latents, target_source_latents], dim=1
141 | )
142 | deformed_pts = deformer(
143 | source_target_points[..., :3], latent_seq
144 | ) # Already set to via_hub.
145 |
146 | if mode == "eval":
147 | # Add thumbnail images for visualizing latent embedding.
148 | epoch_images += [torch.cat([source_img, target_img], dim=0)]
149 | source_target_latents = deformer.module.get_lat_params(
150 | source_target_latents
151 | )
152 | epoch_latents += [source_target_latents]
153 |
154 | # Symmetric pair of matching losses.
155 | if args.symm:
156 | _, _, dist = chamfer_dist(
157 | utils.symmetric_duplication(deformed_pts, symm_dim=2),
158 | utils.symmetric_duplication(
159 | target_source_points[..., :3], symm_dim=2
160 | ),
161 | )
162 | else:
163 | _, _, dist = chamfer_dist(
164 | deformed_pts, target_source_points[..., :3]
165 | )
166 |
167 | loss = criterion(dist, torch.zeros_like(dist))
168 |
169 | # Check amount of deformation.
170 | deform_abs = torch.mean(
171 |                 torch.norm(deformed_pts - source_target_points[..., :3], dim=-1)
172 | )
173 |
174 | if mode == "train":
175 | loss.backward()
176 |
177 | # Gradient clipping.
178 | torch.nn.utils.clip_grad_value_(
179 | deformer.module.parameters(), args.clip_grad
180 | )
181 |
182 | optimizer.step()
183 |
184 | tot_loss += loss.item()
185 | count += bs
186 |
187 | if batch_idx % args.log_interval == 0:
188 | # Logger log.
189 | logger.info(
190 | "{} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t"
191 | "Dist Mean: {:.6f}\t"
192 | "Deform Mean: {:.6f}\t"
193 | "DataTime: {:.4f}\tComputeTime: {:.4f}".format(
194 | mode,
195 | epoch,
196 | batch_idx * bs,
197 | len(dataloader) * bs,
198 | 100.0 * batch_idx / len(dataloader),
199 | loss.item(),
200 | np.sqrt(loss.item()),
201 | deform_abs.item(),
202 | tic - toc,
203 | time.time() - tic,
204 | )
205 | )
206 | # Tensorboard log.
207 | writer.add_scalar(
208 | f"{mode}/loss_sum",
209 | loss.item(),
210 | global_step=int(global_step),
211 | )
212 | writer.add_scalar(
213 | f"{mode}/dist_avg",
214 | np.sqrt(loss.item()),
215 | global_step=int(global_step),
216 | )
217 | writer.add_scalar(
218 | f"{mode}/def_mean",
219 | deform_abs.item(),
220 | global_step=int(global_step),
221 | )
222 |
223 | if mode == "train":
224 | global_step += 1
225 | toc = time.time()
226 | tot_loss /= count
227 |
228 |     # Visualize embeddings.
229 |     if mode == "eval":
230 |         # [N, C, H, W]
231 |         epoch_images = torch.cat(epoch_images, dim=0).permute(0, 3, 1, 2)
232 | epoch_images = epoch_images.float() / 255.0
233 | epoch_latents = torch.cat(epoch_latents, dim=0)
234 | writer.add_embedding(
235 | mat=epoch_latents, label_img=epoch_images, global_step=epoch
236 | )
237 |
238 |     # Visualize a few deformation examples in tensorboard.
239 | if args.vis_mesh and (vis_loader is not None) and (mode == "eval"):
240 | # add deformation demo
241 | with torch.set_grad_enabled(False):
242 | for ind, data_tensors in enumerate(vis_loader): # batch size = 1
243 | ii = torch.tensor([data_tensors[0]], dtype=torch.long)
244 | jj = torch.tensor([data_tensors[1]], dtype=torch.long)
245 |
246 | source_latents = deformer.module.get_lat_params(ii)
247 | target_latents = deformer.module.get_lat_params(jj)
248 | hub_latents = torch.zeros_like(source_latents)
249 |
250 | data_tensors = [
251 | t.unsqueeze(0).to(device) for t in data_tensors[2:]
252 | ]
253 | vi, fi, vj, fj = data_tensors
254 | vi = vi[0]
255 | fi = fi[0]
256 | vj = vj[0]
257 | fj = fj[0]
258 | vi_j = deformer(
259 | vi[..., :3],
260 | torch.stack(
261 | [source_latents, hub_latents, target_latents],
262 | dim=1,
263 | ),
264 | )
265 | vj_i = deformer(
266 | vj[..., :3],
267 | torch.stack(
268 | [target_latents, hub_latents, source_latents],
269 | dim=1,
270 | ),
271 | )
272 |
273 | accu_i, _, _ = chamfer_dist(vi_j, vj) # [1, m]
274 | accu_j, _, _ = chamfer_dist(vj_i, vi) # [1, n]
275 |
276 | # Find the max dist between pairs of original shapes for
277 | # normalizing colors
278 | chamfer_dist.set_reduction_method("max")
279 | _, _, max_dist = chamfer_dist(vi, vj) # [1,]
280 | chamfer_dist.set_reduction_method("mean")
281 |
282 | # Normalize the accuracies wrt. the distance between src
283 | # and tgt meshes.
284 | ci = utils.colorize_scalar_tensors(
285 | accu_i / max_dist, vmin=0.0, vmax=1.0, cmap="coolwarm"
286 | )
287 | cj = utils.colorize_scalar_tensors(
288 | accu_j / max_dist, vmin=0.0, vmax=1.0, cmap="coolwarm"
289 | )
290 | ci = (ci * 255.0).int()
291 | cj = (cj * 255.0).int()
292 |
293 | # Save mesh.
294 | samp_dir = os.path.join(args.log_dir, "deformation_samples")
295 | os.makedirs(samp_dir, exist_ok=True)
296 | trimesh.Trimesh(
297 | vi.detach().cpu().numpy()[0], fi.detach().cpu().numpy()[0]
298 | ).export(os.path.join(samp_dir, f"samp{ind}_src.obj"))
299 | trimesh.Trimesh(
300 | vj.detach().cpu().numpy()[0], fj.detach().cpu().numpy()[0]
301 | ).export(os.path.join(samp_dir, f"samp{ind}_tar.obj"))
302 | trimesh.Trimesh(
303 | vi_j.detach().cpu().numpy()[0],
304 | fi.detach().cpu().numpy()[0],
305 | ).export(os.path.join(samp_dir, f"samp{ind}_src_to_tar.obj"))
306 | trimesh.Trimesh(
307 | vj_i.detach().cpu().numpy()[0],
308 | fj.detach().cpu().numpy()[0],
309 | ).export(os.path.join(samp_dir, f"samp{ind}_tar_to_src.obj"))
310 |
311 | # Add colorized mesh to tensorboard.
312 | writer.add_mesh(
313 | f"samp{ind}/src",
314 | vertices=vi,
315 | faces=fi,
316 | global_step=int(epoch),
317 | )
318 | writer.add_mesh(
319 | f"samp{ind}/tar",
320 | vertices=vj,
321 | faces=fj,
322 | global_step=int(epoch),
323 | )
324 | writer.add_mesh(
325 | f"samp{ind}/src_to_tar",
326 | vertices=vi_j,
327 | faces=fi,
328 | colors=ci,
329 | global_step=int(epoch),
330 | )
331 | writer.add_mesh(
332 | f"samp{ind}/tar_to_src",
333 | vertices=vj_i,
334 | faces=fj,
335 | colors=cj,
336 | global_step=int(epoch),
337 | )
338 |
339 | return tot_loss
340 |
341 |
342 | def get_args():
343 | """Parse command line arguments."""
344 | parser = argparse.ArgumentParser(description="ShapeNet Deformation Space")
345 |
346 | parser.add_argument(
347 | "--batch_size_per_gpu",
348 | type=int,
349 | default=16,
350 | metavar="N",
351 | help="input batch size per GPU for training (default: 16)",
352 | )
353 | parser.add_argument(
354 | "--epochs",
355 | type=int,
356 | default=100,
357 | metavar="N",
358 | help="number of epochs to train (default: 100)",
359 | )
360 | parser.add_argument(
361 | "--pseudo_train_epoch_size",
362 | type=int,
363 | default=2048,
364 | metavar="N",
365 | help="number of samples in a pseudo-epoch. (default: 2048)",
366 | )
367 | parser.add_argument(
368 | "--pseudo_eval_epoch_size",
369 | type=int,
370 | default=128,
371 | metavar="N",
372 | help="number of samples in a pseudo-epoch. (default: 128)",
373 | )
374 | parser.add_argument(
375 | "--lr",
376 | type=float,
377 | default=1e-3,
378 | metavar="R",
379 | help="learning rate (default: 0.001)",
380 | )
381 | parser.add_argument(
382 | "--no_cuda",
383 | action="store_true",
384 | default=False,
385 | help="disables CUDA training",
386 | )
387 | parser.add_argument(
388 | "--seed",
389 | type=int,
390 | default=1,
391 | metavar="S",
392 | help="random seed (default: 1)",
393 | )
394 | parser.add_argument(
395 | "--data_root",
396 | type=str,
397 | default="data/shapenet_simplified",
398 | help="path to mesh folder root (default: data/shapenet_simplified)",
399 | )
400 | parser.add_argument(
401 | "--category", type=str, default="chair", help="the shape category."
402 | )
403 | parser.add_argument(
404 | "--thumbnails_root",
405 | type=str,
406 | default="data/shapenet_thumbnails",
407 | help="path to thumbnails folder root "
408 | "(default: data/shapenet_thumbnails)",
409 | )
410 | parser.add_argument(
411 | "--deformer_arch",
412 | type=str,
413 | choices=["imnet", "vanilla"],
414 | default="imnet",
415 | help="deformer architecture. (default: imnet)",
416 | )
417 | parser.add_argument(
418 | "--solver",
419 | type=str,
420 | choices=SOLVERS,
421 | default="dopri5",
422 | help="ode solver. (default: dopri5)",
423 | )
424 | parser.add_argument(
425 | "--atol",
426 | type=float,
427 | default=1e-5,
428 | help="absolute error tolerance in ode solver. (default: 1e-5)",
429 | )
430 | parser.add_argument(
431 | "--rtol",
432 | type=float,
433 | default=1e-5,
434 | help="relative error tolerance in ode solver. (default: 1e-5)",
435 | )
436 | parser.add_argument(
437 | "--log_interval",
438 | type=int,
439 | default=10,
440 | metavar="N",
441 | help="how many batches to wait before logging training status",
442 | )
443 | parser.add_argument(
444 | "--log_dir", type=str, required=True, help="log directory for run"
445 | )
446 | parser.add_argument(
447 | "--nonlin", type=str, default="elu", help="type of nonlinearity to use"
448 | )
449 | parser.add_argument(
450 | "--optim", type=str, default="adam", choices=list(OPTIMIZERS.keys())
451 | )
452 | parser.add_argument(
453 | "--loss_type", type=str, default="l2", choices=list(LOSSES.keys())
454 | )
455 | parser.add_argument(
456 | "--resume",
457 | type=str,
458 | default=None,
459 | help="path to checkpoint if resume is needed",
460 | )
461 | parser.add_argument(
462 | "-n",
463 | "--nsamples",
464 | default=2048,
465 | type=int,
466 | help="number of sample points to draw per shape.",
467 | )
468 | parser.add_argument(
469 | "--lat_dims", default=32, type=int, help="number of latent dimensions."
470 | )
471 | parser.add_argument(
472 | "--datasubset",
473 | default=0,
474 | type=int,
475 | help="0 for no subsetting; else use only this many examples.",
476 | )
477 | parser.add_argument(
478 | "--deformer_nf",
479 | default=100,
480 | type=int,
481 | help="base number of features per layer in deformer (imnet).",
482 | )
483 | parser.add_argument(
484 | "--lr_scheduler", dest="lr_scheduler", action="store_true"
485 | )
486 | parser.add_argument(
487 | "--no_lr_scheduler", dest="lr_scheduler", action="store_false"
488 | )
489 | parser.set_defaults(lr_scheduler=True)
490 | parser.set_defaults(normals=True)
491 | parser.add_argument(
492 | "--visualize_mesh",
493 | dest="vis_mesh",
494 | action="store_true",
495 | help="visualize deformation for meshes of sample validation data "
496 | "in tensorboard.",
497 | )
498 | parser.add_argument(
499 | "--no_visualize_mesh",
500 | dest="vis_mesh",
501 | action="store_false",
502 | help="do not visualize deformation for meshes of sample validation "
503 | "data in tensorboard.",
504 | )
505 | parser.set_defaults(vis_mesh=True)
506 | parser.add_argument(
507 | "--adjoint",
508 | dest="adjoint",
509 | action="store_true",
510 | help="use adjoint solver to propagate gradients through odeint.",
511 | )
512 | parser.add_argument(
513 | "--no_adjoint",
514 | dest="adjoint",
515 | action="store_false",
516 | help="do not use adjoint solver to propagate gradients through odeint.",
517 | )
518 | parser.set_defaults(adjoint=True)
519 | parser.add_argument(
520 | "--sign_net",
521 | dest="sign_net",
522 | action="store_true",
523 | help="use sign net.",
524 | )
525 | parser.add_argument(
526 | "--no_sign_net",
527 | dest="sign_net",
528 | action="store_false",
529 | help="do not use sign net.",
530 | )
531 | parser.set_defaults(sign_net=False)
532 | parser.add_argument(
533 | "--clip_grad",
534 | default=1.0,
535 | type=float,
536 | help="clip gradients to this value; a large value effectively "
537 | "deactivates clipping.",
538 | )
539 | parser.add_argument(
540 | "--sampling_method",
541 | type=str,
542 | choices=[
543 | "nn_replace",
544 | "nn_no_replace",
545 | "all_replace",
546 | "all_no_replace",
547 | ],
548 | default="nn_no_replace",
549 | help="method for sampling pairs of shapes to deform.",
550 | )
551 | parser.add_argument(
552 | "--symm", dest="symm", action="store_true", help="use symmetric flow."
553 | )
554 | parser.add_argument(
555 | "--no_symm",
556 | dest="symm",
557 | action="store_false",
558 | help="do not use symmetric flow.",
559 | )
560 | parser.set_defaults(symm=False)
561 | args = parser.parse_args()
562 | return args
563 |
564 |
565 | def main():
566 | args = get_args()
567 |
568 | # Adjust batch size based on the number of gpus available (min 1 for CPU).
569 | args.batch_size = max(1, torch.cuda.device_count()) * args.batch_size_per_gpu
570 | use_cuda = (not args.no_cuda) and torch.cuda.is_available()
571 | kwargs = (
572 | {"num_workers": min(12, args.batch_size), "pin_memory": True}
573 | if use_cuda
574 | else {}
575 | )
576 | device = torch.device("cuda" if use_cuda else "cpu")
577 |
578 | # Log and create snapshots.
579 | filenames_to_snapshot = (
580 | glob.glob("*.py") + glob.glob("*.sh") + glob.glob("layers/*.py")
581 | )
582 | utils.snapshot_files(filenames_to_snapshot, args.log_dir)
583 | logger = utils.get_logger(log_dir=args.log_dir)
584 | with open(os.path.join(args.log_dir, "params.json"), "w") as fh:
585 | json.dump(args.__dict__, fh, indent=2)
586 | logger.info("%s", repr(args))
587 |
588 | args.n_vis = 2 # Number of deformation examples to visualize.
589 |
590 | # Tensorboard writer.
591 | writer = SummaryWriter(log_dir=os.path.join(args.log_dir, "tensorboard"))
592 |
593 | # Random seed for reproducibility.
594 | torch.manual_seed(args.seed)
595 | np.random.seed(args.seed)
596 |
597 | # Create dataloaders.
598 | fullset = dl.ShapeNetVertex(
599 | data_root=args.data_root,
600 | split="train",
601 | category=args.category,
602 | nsamples=args.nsamples,
603 | normals=False,
604 | )
605 | if args.datasubset > 0:
606 | fullset.restrict_subset(args.datasubset)
607 |
608 | # Return thumbnails (to visualize embedding during eval).
609 | fullset.add_thumbnails(args.thumbnails_root)
610 |
611 | if "nn_" in args.sampling_method:
612 | args.nn_samp = True
613 | replace = args.sampling_method == "nn_replace"
614 | train_sampler = dl.LatentNearestNeighborSampler(
615 | dataset=fullset,
616 | src_split="train",
617 | tar_split="train",
618 | n_samples=args.pseudo_train_epoch_size,
619 | k=1000,
620 | replace=replace,
621 | )
622 | eval_sampler = dl.LatentNearestNeighborSampler(
623 | dataset=fullset,
624 | src_split="train",
625 | tar_split="train",
626 | n_samples=args.pseudo_eval_epoch_size,
627 | k=1,
628 | replace=replace,
629 | ) # Pick the closest.
630 | vis_sampler = dl.LatentNearestNeighborSampler(
631 | dataset=fullset,
632 | src_split="train",
633 | tar_split="train",
634 | n_samples=args.n_vis,
635 | k=1,
636 | replace=replace,
637 | ) # Pick the closest.
638 | else:
639 | args.nn_samp = False
640 | replace = args.sampling_method == "all_replace"
641 | train_sampler = dl.RandomPairSampler(
642 | dataset=fullset,
643 | src_split="train",
644 | tar_split="train",
645 | n_samples=args.pseudo_train_epoch_size,
646 | replace=replace,
647 | )
648 | eval_sampler = dl.RandomPairSampler(
649 | dataset=fullset,
650 | src_split="train",
651 | tar_split="train",
652 | n_samples=args.pseudo_eval_epoch_size,
653 | replace=replace,
654 | )
655 | vis_sampler = dl.RandomPairSampler(
656 | dataset=fullset,
657 | src_split="train",
658 | tar_split="train",
659 | n_samples=args.n_vis,
660 | replace=replace,
661 | )
662 |
663 | # Make sure we are turning off shuffle since we are using samplers!
664 | train_loader = DataLoader(
665 | fullset,
666 | batch_size=args.batch_size,
667 | shuffle=False,
668 | drop_last=True,
669 | sampler=train_sampler,
670 | **kwargs,
671 | )
672 | eval_loader = DataLoader(
673 | fullset,
674 | batch_size=args.batch_size,
675 | shuffle=False,
676 | drop_last=False,
677 | sampler=eval_sampler,
678 | **kwargs,
679 | )
680 |
681 | if args.vis_mesh:
682 | # For loading full meshes for visualization.
683 | simp_data_root = args.data_root
684 | simpset = dl.ShapeNetMesh(
685 | data_root=simp_data_root,
686 | split="train",
687 | category=args.category,
688 | normals=False,
689 | )
690 | if args.datasubset > 0:
691 | simpset.restrict_subset(args.datasubset)
692 | if not (
693 | vis_sampler.dataset.fname_to_idx_dict == simpset.fname_to_idx_dict
694 | ):
695 | raise RuntimeError(
696 | f"vis_sampler ({len(vis_sampler.dataset.fname_to_idx_dict)}) "
697 | f"does not match sample set ({len(simpset.fname_to_idx_dict)})"
698 | )
699 | vis_loader = DataLoader(
700 | simpset,
701 | batch_size=1,
702 | shuffle=False,
703 | drop_last=False,
704 | sampler=vis_sampler,
705 | **kwargs,
706 | )
707 | else:
708 | vis_loader = None
709 |
710 | # Setup model.
711 | deformer = NeuralFlowDeformer(
712 | latent_size=args.lat_dims,
713 | f_width=args.deformer_nf,
714 | s_nlayers=2,
715 | s_width=5,
716 | method=args.solver,
717 | nonlinearity=args.nonlin,
718 | arch="imnet",
719 | adjoint=args.adjoint,
720 | rtol=args.rtol,
721 | atol=args.atol,
722 | via_hub=True,
723 | no_sign_net=(not args.sign_net),
724 | symm_dim=(2 if args.symm else None),
725 | )
726 |
727 | # Awkward workaround to get gradients from odeint_adjoint to lat_params.
728 | lat_params = torch.nn.Parameter(
729 | torch.randn(fullset.n_shapes, args.lat_dims) * 1e-1, requires_grad=True
730 | )
731 | deformer.add_lat_params(lat_params)
732 | deformer.to(device)
733 |
734 | all_model_params = list(deformer.parameters())
735 |
736 | optimizer = OPTIMIZERS[args.optim](all_model_params, lr=args.lr)
737 |
738 | start_ep = 0
739 | global_step = np.zeros(1, dtype=np.uint32)
740 | tracked_stats = np.inf
741 |
742 | if args.resume:
743 | logger.info(
744 | "Loading checkpoint {} ================>".format(args.resume)
745 | )
746 | resume_dict = torch.load(args.resume)
747 | start_ep = resume_dict["epoch"]
748 | global_step = resume_dict["global_step"]
749 | tracked_stats = resume_dict["tracked_stats"]
750 | deformer.load_state_dict(resume_dict["deformer_state_dict"])
751 | optimizer.load_state_dict(resume_dict["optim_state_dict"])
752 | for state in optimizer.state.values():
753 | for k, v in state.items():
754 | if isinstance(v, torch.Tensor):
755 | state[k] = v.to(device)
756 | logger.info("[!] Successfully loaded checkpoint.")
757 |
758 | # More threads don't seem to help.
759 | chamfer_dist = ChamferDistKDTree(reduction="mean", njobs=1)
760 | chamfer_dist.to(device)
761 | deformer = nn.DataParallel(deformer)
762 | deformer.to(device)
763 |
764 | model_param_count = lambda model: sum( # noqa: E731
765 | x.numel() for x in model.parameters()
766 | )
767 | logger.info(
768 | f"{model_param_count(deformer)}(deformer) paramerters in total"
769 | )
770 |
771 | checkpoint_path = os.path.join(args.log_dir, "checkpoint_latest.pth.tar")
772 |
773 | if args.lr_scheduler:
774 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
775 |
776 | train_latent_dict = compute_latent_dict(deformer, fullset)
777 | # Training loop.
778 | for epoch in range(start_ep + 1, args.epochs + 1):
779 | # Set sampler nn graph before train or eval.
780 | if args.nn_samp:
781 | train_loader.sampler.update_nn_graph(
782 | train_latent_dict, train_latent_dict, k=get_k(epoch)
783 | )
784 | eval_loader.sampler.update_nn_graph(
785 | train_latent_dict, train_latent_dict
786 | )
787 | if vis_loader is not None:  # None when --no_visualize_mesh is set.
788 | vis_loader.sampler.update_nn_graph(
789 | train_latent_dict, train_latent_dict)
790 |
791 | _ = train_or_eval(
792 | "train",
793 | args,
794 | deformer,
795 | chamfer_dist,
796 | train_loader,
797 | epoch,
798 | global_step,
799 | device,
800 | logger,
801 | writer,
802 | optimizer,
803 | None,
804 | )
805 | loss_eval = train_or_eval(
806 | "eval",
807 | args,
808 | deformer,
809 | chamfer_dist,
810 | eval_loader,
811 | epoch,
812 | global_step,
813 | device,
814 | logger,
815 | writer,
816 | optimizer,
817 | vis_loader,
818 | )
819 |
820 | if args.lr_scheduler:
821 | scheduler.step(loss_eval)
822 | if loss_eval < tracked_stats:
823 | tracked_stats = loss_eval
824 | is_best = True
825 | else:
826 | is_best = False
827 |
828 | utils.save_checkpoint(
829 | {
830 | "epoch": epoch,
831 | "deformer_state_dict": deformer.module.state_dict(),
832 | "lat_params": lat_params,
833 | "optim_state_dict": optimizer.state_dict(),
834 | "tracked_stats": tracked_stats,
835 | "global_step": global_step,
836 | },
837 | is_best,
838 | epoch,
839 | checkpoint_path,
840 | "_shapeflow",
841 | logger,
842 | )
843 |
844 |
845 | if __name__ == "__main__":
846 | main()
847 |
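848 | # Example invocation (a minimal sketch; shapenet_train.sh carries a
849 | # fuller set of flags):
850 | #   python shapenet_train.py --log_dir=runs/demo --category=chair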
--------------------------------------------------------------------------------
/shapenet_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #####################################
4 | # Configure Experiment #
5 | #####################################
6 | run_name=demo
7 | log_dir=runs/$run_name
8 | data_root=data/shapenet_simplified
9 | category=chair  # Shape category; consumed as ${category} below.
10 | # Create the run directory if it doesn't exist.
11 | mkdir -p "$log_dir"
12 |
13 | # Launch training.
14 | python shapenet_train.py \
15 | --atol=1e-4 \
16 | --rtol=1e-4 \
17 | --data_root=$data_root \
18 | --pseudo_train_epoch_size=2048 \
19 | --pseudo_eval_epoch_size=128 \
20 | --lr=1e-3 \
21 | --log_dir=$log_dir \
22 | --lr_scheduler \
23 | --visualize_mesh \
24 | --batch_size_per_gpu=32 \
25 | --log_interval=2 \
26 | --epochs=100 \
27 | --no_sign_net \
28 | --adjoint \
29 | --solver='dopri5' \
30 | --deformer_nf=128 \
31 | --nsamples=512 \
32 | --lat_dims=128 \
33 | --nonlin='leakyrelu' \
34 | --symm \
35 | --category=${category} \
36 | --sampling_method='all_no_replace'
37 |
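38 | # To resume from the latest checkpoint of a previous run, rerun the
39 | # command above with a --resume flag appended (a sketch; adjust the
40 | # path to your setup):
41 | #   --resume=$log_dir/checkpoint_latest.pth.tar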
--------------------------------------------------------------------------------
/utils/render.py:
--------------------------------------------------------------------------------
1 | """Rendering utility functions.
2 | """
3 | import os
4 |
5 | os.environ["PYOPENGL_PLATFORM"] = "osmesa"
6 | import numpy as np # noqa: E402
7 | import trimesh # noqa: E402
8 | import pyrender # noqa: E402
9 |
10 |
11 | def render_trimesh(
12 | trimesh_mesh,
13 | eye,
14 | center,
15 | world_up,
16 | res=(640, 640),
17 | light_intensity=3.0,
18 | ambient_intensity=None,
19 | **kwargs,
20 | ):
21 | """Render a shapenet mesh using default settings.
22 |
23 | Args:
24 | trimesh_mesh: trimesh mesh instance, or a list of trimesh meshes
25 | (or point clouds).
26 | eye: array with shape [3,] containing the XYZ world
27 | space position of the camera.
28 | center: array with shape [3,] containing a position
29 | along the center of the camera's gaze.
30 | world_up: np.float32 array with shape [3,] specifying the
31 | world's up direction; the output camera will have no tilt with respect
32 | to this direction.
33 | res: 2-tuple of int, [width, height], resolution (in pixels) of output
34 | images.
35 | light_intensity: float, light intensity.
36 | ambient_intensity: float, ambient light intensity.
37 | kwargs: additional flags to pass to pyrender renderer.
38 | Returns:
39 | color_img: [*res, 3] color image.
40 | depth_img: [*res, 1] depth image.
41 | world_to_cam: [4, 4] world to camera (extrinsics) matrix.
42 | projection_matrix: [4, 4] projection matrix, aka cam_to_img matrix.
43 | """
44 | if not isinstance(trimesh_mesh, list):
45 | trimesh_mesh = [trimesh_mesh]
46 | eye = list2npy(eye).astype(np.float32)
47 | center = list2npy(center).astype(np.float32)
48 | world_up = list2npy(world_up).astype(np.float32)
49 |
50 | if ambient_intensity is None:  # Guard the None default.
51 | ambient_intensity = 0.0
52 | scene = pyrender.Scene(
53 | ambient_light=ambient_intensity * np.ones([3], dtype=float))
54 | for tmesh in trimesh_mesh:
55 | if not (
56 | isinstance(tmesh, trimesh.Trimesh)
57 | or isinstance(tmesh, trimesh.PointCloud)
58 | ):
59 | raise NotImplementedError(
60 | "All instances in trimesh_mesh must be either trimesh.Trimesh "
61 | f"or trimesh.PointCloud. Instead it is {type(tmesh)}."
62 | )
63 | if isinstance(tmesh, trimesh.Trimesh):
64 | mesh = pyrender.Mesh.from_trimesh(tmesh)
65 | elif isinstance(tmesh, trimesh.PointCloud):
66 | if tmesh.colors is not None:
67 | colors = np.array(tmesh.colors)
68 | else:
69 | colors = np.ones_like(tmesh.vertices)
70 | mesh = pyrender.Mesh.from_points(
71 | np.array(tmesh.vertices), colors=colors
72 | )
73 | scene.add(mesh)
74 |
75 | # Set up the camera -- z-axis away from the scene, x-axis right, y-axis up
76 | camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
77 |
78 | world_to_cam = look_at(eye[None], center[None], world_up[None])
79 | world_to_cam = world_to_cam[0]
80 | cam_pose = np.linalg.inv(world_to_cam)
81 | scene.add(camera, pose=cam_pose)
82 |
83 | # Set up the light -- a single spot light in the same spot as the camera
84 | light = pyrender.SpotLight(
85 | color=np.ones(3, dtype=np.float32),
86 | intensity=light_intensity,
87 | innerConeAngle=np.pi / 16.0,
88 | )
89 | scene.add(light, pose=cam_pose)
90 |
91 | # Render the scene
92 | r = pyrender.OffscreenRenderer(*res, **kwargs)
93 | color_img, depth_img = r.render(scene)
94 | return (
95 | color_img,
96 | depth_img,
97 | world_to_cam,
98 | camera.get_projection_matrix(*res),
99 | )
100 |
101 |
102 | def _unproject_points(points, projection_matrix, world_to_cam):
103 | """Unproject points from image space to world space."""
104 | # Undo the perspective divide on xy, then lift to homogeneous coords.
105 | depth = points[:, 2]
106 | xy_scale = (depth - projection_matrix[2, 3]) / (-projection_matrix[2, 2])
107 | points[:, :2] = points[:, :2] * xy_scale[:, None]
108 | points = np.concatenate([points, np.ones_like(points[:, :1])], axis=1)
109 | points[:, 3] = xy_scale
110 | # camera space coordinates
111 | point_cam = (
112 | np.linalg.inv(projection_matrix) @ (points.T)
113 | ).T # [npoints, 4]
114 | # world space coordinates
115 | cam_to_world = np.linalg.inv(world_to_cam)
116 | point_world = (cam_to_world @ (point_cam.T)).T
117 | point_world = point_world[:, :3] / point_world[:, 3:]
118 | return point_world
119 |
120 |
121 | def _points_from_depth(depth_img, zoffset=0):
122 | """Get image space points from depth image."""
123 | point_mask = depth_img != 0.0
124 | point_mask_flat = point_mask.reshape(-1)
125 | w, h = depth_img.shape
126 | x, y = np.meshgrid(
127 | np.linspace(-1.0, 1.0, w), np.linspace(-1.0, 1.0, h), indexing="ij"
128 | )
129 | xy_img = np.stack([y, -x], axis=-1) # [w, h, 2]
130 | xy_flat = xy_img.reshape(-1, 2) # [w*h, 2]
131 | point_img = xy_flat[point_mask_flat] # [npoints, 2]
132 | depth = depth_img.reshape(-1)[point_mask_flat] + zoffset
133 | point_img = np.concatenate([point_img, depth[..., None]], axis=-1)
134 | return point_img
135 |
136 |
137 | def line_meshes(verts, edges, colors=None, poses=None):
138 | """Create pyrender Mesh instance for lines.
139 |
140 | Args:
141 | verts: np.array floats of shape [#v, 3]
142 | edges: np.array ints of shape [#e, 2], vertex index pairs per line.
143 | colors: np.array floats of shape [#v, 3]
144 | poses: np.array floats of shape [x, 4, 4], transformation matrices
145 | for instancing this object.
146 | """
147 | prim = pyrender.primitive.Primitive(
148 | positions=verts,
149 | indices=edges,
150 | color_0=colors,
151 | mode=pyrender.constants.GLTF.LINES,
152 | poses=poses)
153 | return pyrender.mesh.Mesh(primitives=[prim], is_visible=True)
154 |
155 |
156 | def unproject_depth_img(depth_img, projection_matrix, world_to_cam):
157 | """Unproject depth image to point cloud in world coordinates.
158 |
159 | Args:
160 | depth_img: array of [width, height] depth image.
161 | projection_matrix: array of [4, 4], projection matrix, aka cam_to_img
162 | matrix.
163 | world_to_cam: array of [4, 4], world to cam matrix, inverse of camera
164 | pose.
165 |
166 | Returns:
167 | point_world: array of [npoints, 3] depth scan point cloud in world
168 | coordinates.
169 | """
170 | point_img = _points_from_depth(depth_img, zoffset=projection_matrix[2, 3])
171 | point_world = _unproject_points(point_img, projection_matrix, world_to_cam)
172 |
173 | return point_world
174 |
175 |
176 | def list2npy(array):
177 | return array if isinstance(array, np.ndarray) else np.array(array)
178 |
179 |
180 | def r4pad(array):
181 | """pad [..., 3] array to [..., 4] with ones in last channel."""
182 | zeros = np.ones_like(array[..., -1:])
183 | return np.concatenate([array, zeros], axis=-1)
184 |
185 |
186 | def look_at(eye, center, world_up):
187 | """Computes camera viewing matrices (numpy implementation).
188 |
189 | Args:
190 | eye: np.float32 array with shape [batch_size, 3] containing the XYZ world
191 | space position of the camera.
192 | center: np.float32 array with shape [batch_size, 3] containing a position
193 | along the center of the camera's gaze.
194 | world_up: np.float32 array with shape [batch_size, 3] specifying the
195 | world's up direction; the output camera will have no tilt with respect to
196 | this direction.
197 |
198 | Returns:
199 | A [batch_size, 4, 4] np.float32 array containing a right-handed camera
200 | extrinsics matrix that maps points from world space to points in eye space.
201 | """
202 | batch_size = center.shape[0]
203 | vector_degeneracy_cutoff = 1e-6
204 | forward = center - eye
205 | forward_norm = np.linalg.norm(forward, axis=1, keepdims=True)
206 | assert np.all(forward_norm > vector_degeneracy_cutoff)
207 | forward /= forward_norm
208 |
209 | to_side = np.cross(forward, world_up)
210 | to_side_norm = np.linalg.norm(to_side, axis=1, keepdims=True)
211 | assert np.all(to_side_norm > vector_degeneracy_cutoff)
212 | to_side /= to_side_norm
213 | cam_up = np.cross(to_side, forward)
214 |
215 | w_column = np.array(
216 | batch_size * [[0.0, 0.0, 0.0, 1.0]], dtype=np.float32
217 | ) # [batch_size, 4]
218 | w_column = w_column.reshape([batch_size, 4, 1])
219 | view_rotation = np.stack(
220 | [to_side, cam_up, -forward, np.zeros_like(to_side, dtype=np.float32)],
221 | axis=1,
222 | ) # [batch_size, 4, 3] matrix
223 | view_rotation = np.concatenate(
224 | [view_rotation, w_column], axis=2
225 | ) # [batch_size, 4, 4]
226 |
227 | identity_batch = np.tile(np.expand_dims(np.eye(3), 0), [batch_size, 1, 1])
228 | view_translation = np.concatenate(
229 | [identity_batch, np.expand_dims(-eye, 2)], 2
230 | )
231 | view_translation = np.concatenate(
232 | [view_translation, w_column.reshape([batch_size, 1, 4])], 1
233 | )
234 | camera_matrices = np.matmul(view_rotation, view_translation)
235 | return camera_matrices
236 |
--------------------------------------------------------------------------------