├── .gitignore ├── LICENSE ├── README.md ├── deform_3_obj.py ├── doc ├── tb_losses.png ├── tb_meshes.png └── teaser.png ├── download_data.sh ├── download_pretrained.sh ├── exp1_visualize_results.ipynb ├── exp2_animate_smpl.ipynb ├── exp3_parametric_cad.ipynb ├── paper_figures ├── anim │ ├── linear.npz │ ├── linear.png │ ├── linear │ │ ├── mesh_0.ply │ │ ├── mesh_1.ply │ │ ├── mesh_10.ply │ │ ├── mesh_11.ply │ │ ├── mesh_12.ply │ │ ├── mesh_13.ply │ │ ├── mesh_14.ply │ │ ├── mesh_15.ply │ │ ├── mesh_16.ply │ │ ├── mesh_17.ply │ │ ├── mesh_18.ply │ │ ├── mesh_19.ply │ │ ├── mesh_2.ply │ │ ├── mesh_20.ply │ │ ├── mesh_3.ply │ │ ├── mesh_4.ply │ │ ├── mesh_5.ply │ │ ├── mesh_6.ply │ │ ├── mesh_7.ply │ │ ├── mesh_8.ply │ │ └── mesh_9.ply │ ├── shapeflow_divfree_edge0.npz │ ├── shapeflow_divfree_edge0.png │ ├── shapeflow_divfree_edge0 │ │ ├── mesh_0.ply │ │ ├── mesh_1.ply │ │ ├── mesh_10.ply │ │ ├── mesh_11.ply │ │ ├── mesh_12.ply │ │ ├── mesh_13.ply │ │ ├── mesh_14.ply │ │ ├── mesh_15.ply │ │ ├── mesh_16.ply │ │ ├── mesh_17.ply │ │ ├── mesh_18.ply │ │ ├── mesh_19.ply │ │ ├── mesh_2.ply │ │ ├── mesh_20.ply │ │ ├── mesh_3.ply │ │ ├── mesh_4.ply │ │ ├── mesh_5.ply │ │ ├── mesh_6.ply │ │ ├── mesh_7.ply │ │ ├── mesh_8.ply │ │ └── mesh_9.ply │ ├── shapeflow_divfree_edge2.npz │ ├── shapeflow_divfree_edge2.png │ ├── shapeflow_divfree_edge2 │ │ ├── mesh_0.ply │ │ ├── mesh_1.ply │ │ ├── mesh_10.ply │ │ ├── mesh_11.ply │ │ ├── mesh_12.ply │ │ ├── mesh_13.ply │ │ ├── mesh_14.ply │ │ ├── mesh_15.ply │ │ ├── mesh_16.ply │ │ ├── mesh_17.ply │ │ ├── mesh_18.ply │ │ ├── mesh_19.ply │ │ ├── mesh_2.ply │ │ ├── mesh_20.ply │ │ ├── mesh_3.ply │ │ ├── mesh_4.ply │ │ ├── mesh_5.ply │ │ ├── mesh_6.ply │ │ ├── mesh_7.ply │ │ ├── mesh_8.ply │ │ └── mesh_9.ply │ ├── shapeflow_edge0.npz │ ├── shapeflow_edge0.png │ ├── shapeflow_edge0 │ │ ├── mesh_0.ply │ │ ├── mesh_1.ply │ │ ├── mesh_10.ply │ │ ├── mesh_11.ply │ │ ├── mesh_12.ply │ │ ├── mesh_13.ply │ │ ├── mesh_14.ply │ │ ├── mesh_15.ply │ │ ├── 
mesh_16.ply │ │ ├── mesh_17.ply │ │ ├── mesh_18.ply │ │ ├── mesh_19.ply │ │ ├── mesh_2.ply │ │ ├── mesh_20.ply │ │ ├── mesh_3.ply │ │ ├── mesh_4.ply │ │ ├── mesh_5.ply │ │ ├── mesh_6.ply │ │ ├── mesh_7.ply │ │ ├── mesh_8.ply │ │ └── mesh_9.ply │ └── volume_change.pdf ├── anim_debug │ ├── linear.npz │ ├── linear.png │ ├── shapeflow_divfree_edge0.npz │ ├── shapeflow_divfree_edge0.png │ ├── shapeflow_divfree_edge2.npz │ ├── shapeflow_divfree_edge2.png │ ├── shapeflow_edge0.npz │ ├── shapeflow_edge0.png │ └── volume_change.pdf ├── parametric_cad │ ├── deformed.png │ └── groundtruth.png ├── shapegen │ ├── 0_dmc.jpeg │ ├── 0_gt.jpeg │ ├── 0_input.jpeg │ ├── 0_onet.jpeg │ ├── 0_psgn.jpeg │ ├── 0_r2n2.jpeg │ ├── 0_retrieved.jpeg │ ├── 0_shapeflow.jpeg │ ├── 1_dmc.jpeg │ ├── 1_gt.jpeg │ ├── 1_input.jpeg │ ├── 1_onet.jpeg │ ├── 1_psgn.jpeg │ ├── 1_r2n2.jpeg │ ├── 1_retrieved.jpeg │ ├── 1_shapeflow.jpeg │ ├── 2_dmc.jpeg │ ├── 2_gt.jpeg │ ├── 2_input.jpeg │ ├── 2_onet.jpeg │ ├── 2_psgn.jpeg │ ├── 2_r2n2.jpeg │ ├── 2_retrieved.jpeg │ ├── 2_shapeflow.jpeg │ ├── old │ │ ├── 0_dmc.jpeg │ │ ├── 0_gt.jpeg │ │ ├── 0_input.jpeg │ │ ├── 0_onet.jpeg │ │ ├── 0_psgn.jpeg │ │ ├── 0_r2n2.jpeg │ │ ├── 0_shapeflow.jpeg │ │ ├── 1_dmc.jpeg │ │ ├── 1_gt.jpeg │ │ ├── 1_input.jpeg │ │ ├── 1_onet.jpeg │ │ ├── 1_psgn.jpeg │ │ ├── 1_r2n2.jpeg │ │ ├── 1_shapeflow.jpeg │ │ ├── 2_dmc.jpeg │ │ ├── 2_gt.jpeg │ │ ├── 2_input.jpeg │ │ ├── 2_onet.jpeg │ │ ├── 2_psgn.jpeg │ │ ├── 2_r2n2.jpeg │ │ └── 2_shapeflow.jpeg │ └── shapegen.zip └── teaser │ └── flow.pdf ├── render_example.ipynb ├── requirements.txt ├── scripts ├── cache_shapenet_pointcloud.py ├── render_thumbnail.py └── simplify.py ├── shapeflow ├── __init__.py ├── layers │ ├── __init__.py │ ├── chamfer_layer.py │ ├── deformation_layer.py │ ├── pde_layer.py │ ├── pointnet_layer.py │ └── shared_definition.py └── utils │ ├── __init__.py │ └── train_utils.py ├── shapenet_dataloader.py ├── shapenet_embedding.py ├── shapenet_generation.sh ├── 
shapenet_reconstruct.ipynb ├── shapenet_reconstruct.py ├── shapenet_train.py ├── shapenet_train.sh ├── utils └── render.py └── visualize_deformer.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Meshes 132 | *.ply 133 | 134 | # VSCode 135 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Chiyu Max Jiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ShapeFlow 2 | [NeurIPS 20'] ShapeFlow: Learnable Deformations Among 3D Shapes. 3 | 4 | 5 | By: [Chiyu "Max" Jiang*](http://maxjiang.ml/), [Jingwei Huang*](http://stanford.edu/~jingweih/), [Andrea Tagliasacchi](http://gfx.uvic.ca/people/ataiya/), [Leonidas Guibas](https://geometry.stanford.edu/member/guibas/) 6 | 7 | \[[Project Website]()\] \[[Paper](https://arxiv.org/abs/2006.07982)\] 8 | 9 | deepdeform_teaser 10 | 11 | 12 | ## 1. Introduction 13 | We present ShapeFlow, a flow-based model for learning a deformation space for entire classes of 3D shapes with large intra-class variations. ShapeFlow allows learning a multi-template deformation space that is agnostic to shape topology, yet preserves fine geometric details. Different from a generative space where a latent vector is directly decoded into a shape, a deformation space decodes a vector into a continuous flow that can advect a source shape towards a target. Such a space naturally allows the disentanglement of geometric style (coming from the source) and structural pose (conforming to the target). 
We parametrize the deformation between geometries as a learned continuous flow field via a neural network and show that such deformations can be guaranteed to have desirable properties, such as bijectivity, freedom from self-intersections, or volume preservation. We illustrate the effectiveness of this learned deformation space for various downstream applications, including shape generation via deformation, geometric style transfer, unsupervised learning of a consistent parameterization for entire classes of shapes, and shape interpolation. 14 | 15 | ## 2. Getting Started 16 | ### Installing dependencies 17 | We recommend using pip to install all required dependencies with ease. 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ### Optional dependency (rendering) 23 | We strongly suggest installing the optional dependencies for rendering meshes, so that you can visualize the results using interactive notebooks. 24 | `pyrender` can be installed via pip: 25 | ``` 26 | pip install pyrender 27 | ``` 28 | 29 | Additionally, to run the notebook renderings on a headless server, follow the [instructions](https://pyrender.readthedocs.io/en/latest/install/#python-installation) for installing `OSMesa`. 30 | 31 | ### Download and unpack data 32 | To download and unpack the data used in the experiments, please use the provided utility script. 33 | ``` 34 | bash download_data.sh 35 | ``` 36 | 37 | ## 3. Reproduce results 38 | ### Run experiments 39 | Please use our provided launch script to start training the shape deformation model. 40 | ``` 41 | bash shapenet_train.sh 42 | ``` 43 | 44 | The training will launch on all available GPUs. Mask GPUs accordingly if you want to use only a subset of all GPUs. The initial tests were done on NVIDIA Volta V100 GPUs; therefore, `batch_size_per_gpu=16` may need to be adjusted for GPUs with smaller or larger memory, e.g. if an out-of-memory error is triggered.
45 | 46 | ### Load and visualize pretrained checkpoint 47 | First download the pretrained checkpoint. 48 | ``` 49 | wget island.me.berkeley.edu/files/pretrained_ckpt.zip 50 | mkdir -p runs 51 | mv pretrained_ckpt.zip runs 52 | cd runs; unzip pretrained_ckpt.zip; rm pretrained_ckpt.zip; cd .. 53 | ``` 54 | 55 | Next, run through the cells in `visualize_deformer.ipynb`. 56 | 57 | ### Monitor training 58 | After launching the training script, a `runs` directory will be created, with each run stored in its own subfolder. To monitor the training process based on text logs, use 59 | ``` 60 | tail -f runs/RUN_NAME/log.txt 61 | ``` 62 | 63 | To monitor the training process using tensorboard, do: 64 | ``` 65 | # if you are running this on a remote server via ssh 66 | ssh my_favorite_machine -L 6006:localhost:6006 67 | 68 | # go to the directory containing the tensorboard log 69 | cd path/to/ShapeFlow/runs/RUN_NAME/tensorboard 70 | 71 | # launch tensorboard 72 | tensorboard --logdir . --port 6006 73 | ``` 74 | Tensorboard allows tracking of deformation losses, as well as visualizing the source / target / deformed meshes. The deformed meshes are colored by the per-vertex distance with respect to the target shape. 75 | 76 | tensorboard_losses 77 | tensorboard_meshes 78 | 79 | ### Citation 80 | If you find our code useful for your work, please consider citing our paper: 81 | ``` 82 | @inproceedings{jiang2020shapeflow, 83 | title={ShapeFlow: Learnable Deformations Among 3D Shapes}, 84 | author={Jiang, Chiyu and Huang, Jingwei and Tagliasacchi, Andrea and Guibas, Leonidas}, 85 | booktitle={Advances in Neural Information Processing Systems}, 86 | year={2020} 87 | } 88 | ``` 89 | 90 | ### Contact 91 | Please contact [Max Jiang](mailto:maxjiang93@gmail.com) if you have further questions!
92 | -------------------------------------------------------------------------------- /deform_3_obj.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | from shapeflow.layers.chamfer_layer import ChamferDistKDTree 4 | from shapeflow.layers.deformation_layer import NeuralFlowDeformer 5 | from shapeflow.layers.pointnet_layer import PointNetEncoder 6 | 7 | import torch 8 | import numpy as np 9 | from time import time 10 | import trimesh 11 | from glob import glob 12 | 13 | 14 | files = sorted(glob("data/shapenet_watertight/val/03001627/*/*.ply")) 15 | m1 = trimesh.load(files[1]) 16 | m2 = trimesh.load(files[6]) 17 | m3 = trimesh.load(files[7]) 18 | device = torch.device("cuda:0") 19 | 20 | chamfer_dist = ChamferDistKDTree(reduction="mean", njobs=1).to(device) 21 | criterion = torch.nn.MSELoss() 22 | 23 | latent_size = 3 24 | 25 | deformer = NeuralFlowDeformer( 26 | latent_size=latent_size, 27 | f_nlayers=6, 28 | f_width=100, 29 | s_nlayers=2, 30 | s_width=5, 31 | method="dopri5", 32 | nonlinearity="elu", 33 | arch="imnet", 34 | adjoint=True, 35 | atol=1e-4, 36 | rtol=1e-4, 37 | ).to(device) 38 | encoder = PointNetEncoder( 39 | nf=16, out_features=latent_size, dropout_prob=0.0 40 | ).to(device) 41 | 42 | # this is an awkward workaround to get gradients for encoder via adjoint solver 43 | deformer.add_encoder(encoder) 44 | deformer.to(device) 45 | encoder = deformer.net.encoder 46 | 47 | optimizer = optim.Adam(list(deformer.parameters()), lr=1e-3) 48 | 49 | niter = 1000 50 | npts = 5000 51 | 52 | V1 = torch.tensor(m1.vertices.astype(np.float32)).to(device) # .unsqueeze(0) 53 | V2 = torch.tensor(m2.vertices.astype(np.float32)).to(device) # .unsqueeze(0) 54 | V3 = torch.tensor(m3.vertices.astype(np.float32)).to(device) # .unsqueeze(0) 55 | 56 | loss_min = 1e30 57 | tic = time() 58 | encoder.train() 59 | 60 | for it in range(0, niter): 61 | optimizer.zero_grad() 62 | 63 | seq1 = 
torch.randperm(V1.shape[0], device=device)[:npts] 64 | seq2 = torch.randperm(V2.shape[0], device=device)[:npts] 65 | seq3 = torch.randperm(V3.shape[0], device=device)[:npts] 66 | V1_samp = V1[seq1] 67 | V2_samp = V2[seq2] 68 | V3_samp = V3[seq3] 69 | 70 | V_src = torch.stack( 71 | [V1_samp, V1_samp, V2_samp], dim=0 72 | ) # [batch, npoints, 3] 73 | V_tar = torch.stack( 74 | [V2_samp, V3_samp, V3_samp], dim=0 75 | ) # [batch, npoints, 3] 76 | 77 | V_src_tar = torch.cat([V_src, V_tar], dim=0) 78 | V_tar_src = torch.cat([V_tar, V_src], dim=0) 79 | 80 | batch_latent_src_tar = encoder(V_src_tar) 81 | batch_latent_tar_src = torch.cat( 82 | [batch_latent_src_tar[3:], batch_latent_src_tar[:3]] 83 | ) 84 | 85 | V_deform = deformer(V_src_tar, batch_latent_src_tar, batch_latent_tar_src) 86 | 87 | _, _, dist = chamfer_dist(V_deform, V_tar_src) 88 | 89 | loss = criterion(dist, torch.zeros_like(dist)) 90 | 91 | loss.backward() 92 | optimizer.step() 93 | 94 | if it % 100 == 0 or True: 95 | print(f"iter={it}, loss={np.sqrt(loss.item())}") 96 | 97 | toc = time() 98 | print("Time for {} iters: {:.4f} s".format(niter, toc - tic)) 99 | 100 | # save deformed mesh 101 | encoder.eval() 102 | with torch.no_grad(): 103 | V1_latent = encoder(V1.unsqueeze(0)) 104 | V2_latent = encoder(V2.unsqueeze(0)) 105 | V3_latent = encoder(V3.unsqueeze(0)) 106 | 107 | V1_2 = ( 108 | deformer(V1.unsqueeze(0), V1_latent, V2_latent) 109 | .detach() 110 | .cpu() 111 | .numpy()[0] 112 | ) 113 | V2_1 = ( 114 | deformer(V2.unsqueeze(0), V2_latent, V1_latent) 115 | .detach() 116 | .cpu() 117 | .numpy()[0] 118 | ) 119 | V1_3 = ( 120 | deformer(V1.unsqueeze(0), V1_latent, V3_latent) 121 | .detach() 122 | .cpu() 123 | .numpy()[0] 124 | ) 125 | V3_1 = ( 126 | deformer(V3.unsqueeze(0), V3_latent, V1_latent) 127 | .detach() 128 | .cpu() 129 | .numpy()[0] 130 | ) 131 | V2_3 = ( 132 | deformer(V2.unsqueeze(0), V2_latent, V3_latent) 133 | .detach() 134 | .cpu() 135 | .numpy()[0] 136 | ) 137 | V3_2 = ( 138 | 
deformer(V3.unsqueeze(0), V3_latent, V2_latent) 139 | .detach() 140 | .cpu() 141 | .numpy()[0] 142 | ) 143 | trimesh.Trimesh(V1_2, m1.faces).export("demo/output_1_2.obj") 144 | trimesh.Trimesh(V2_1, m2.faces).export("demo/output_2_1.obj") 145 | trimesh.Trimesh(V1_3, m1.faces).export("demo/output_1_3.obj") 146 | trimesh.Trimesh(V3_1, m3.faces).export("demo/output_3_1.obj") 147 | trimesh.Trimesh(V2_3, m2.faces).export("demo/output_2_3.obj") 148 | trimesh.Trimesh(V3_2, m3.faces).export("demo/output_3_2.obj") 149 | 150 | m1.export("demo/output_1.obj") 151 | m2.export("demo/output_2.obj") 152 | m3.export("demo/output_3.obj") 153 | -------------------------------------------------------------------------------- /doc/tb_losses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/doc/tb_losses.png -------------------------------------------------------------------------------- /doc/tb_meshes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/doc/tb_meshes.png -------------------------------------------------------------------------------- /doc/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/doc/teaser.png -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p data 4 | 5 | echo "Downloading and unzipping data. This will take a while, take a break and get a coffee..." 6 | 7 | echo "Downloading Data..." 
8 | 9 | cd data || exit 1 10 | wget island.me.berkeley.edu/shape_flow/shapenet_simplified.zip 11 | wget island.me.berkeley.edu/shape_flow/shapenet_thumbnails.zip 12 | wget island.me.berkeley.edu/shape_flow/shapenet_pointcloud.zip 13 | wget island.me.berkeley.edu/shape_flow/smpl_meshes.zip 14 | wget island.me.berkeley.edu/shape_flow/sparse_inputs.zip 15 | wget island.me.berkeley.edu/shape_flow/parametric_cad.zip 16 | 17 | echo "Unzipping Data..." 18 | unzip shapenet_simplified.zip && rm shapenet_simplified.zip 19 | unzip shapenet_thumbnails.zip && rm shapenet_thumbnails.zip 20 | unzip shapenet_pointcloud.zip && rm shapenet_pointcloud.zip 21 | unzip smpl_meshes.zip && rm smpl_meshes.zip 22 | unzip sparse_inputs.zip && rm sparse_inputs.zip 23 | unzip parametric_cad.zip && rm parametric_cad.zip 24 | 25 | cd .. 26 | -------------------------------------------------------------------------------- /download_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Downloading pretrained checkpoints and inference inputs. This will take a while, take a break and get a coffee..." 4 | 5 | mkdir -p data 6 | cd data || exit 1 7 | wget island.me.berkeley.edu/shape_flow/sparse_inputs.zip 8 | unzip sparse_inputs.zip && rm sparse_inputs.zip 9 | cd .. 10 | 11 | mkdir -p runs 12 | cd runs || exit 1 13 | wget island.me.berkeley.edu/shape_flow/pretrained_chair_symm128.zip 14 | unzip pretrained_chair_symm128.zip && rm pretrained_chair_symm128.zip 15 | cd .. 
16 | -------------------------------------------------------------------------------- /paper_figures/anim/linear.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear.npz -------------------------------------------------------------------------------- /paper_figures/anim/linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear.png -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_0.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_0.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_1.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_10.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_10.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_11.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_11.ply 
-------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_12.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_12.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_13.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_13.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_14.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_14.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_15.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_15.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_16.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_16.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_17.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_17.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_18.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_18.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_19.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_19.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_2.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_2.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_20.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_20.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_3.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_3.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_4.ply: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_4.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_5.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_5.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_6.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_6.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_7.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_7.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_8.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_8.ply -------------------------------------------------------------------------------- /paper_figures/anim/linear/mesh_9.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/linear/mesh_9.ply -------------------------------------------------------------------------------- 
/paper_figures/anim/shapeflow_divfree_edge0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0.npz -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0.png -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_0.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_0.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_1.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_10.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_10.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_11.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_11.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_12.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_12.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_13.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_13.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_14.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_14.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_15.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_15.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_16.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_16.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_17.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_17.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_18.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_18.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_19.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_19.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_2.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_2.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_20.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_20.ply 
-------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_3.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_3.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_4.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_4.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_5.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_5.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_6.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_6.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_7.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_7.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_8.ply: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_8.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge0/mesh_9.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge0/mesh_9.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2.npz -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2.png -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_0.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_0.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_1.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_1.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_10.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_10.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_11.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_11.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_12.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_12.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_13.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_13.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_14.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_14.ply 
-------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_15.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_15.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_16.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_16.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_17.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_17.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_18.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_18.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_19.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_19.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_2.ply: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_2.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_20.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_20.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_3.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_3.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_4.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_4.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_5.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_5.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_6.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_6.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_7.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_7.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_8.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_8.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_divfree_edge2/mesh_9.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_divfree_edge2/mesh_9.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0.npz -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0.png 
-------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_0.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_0.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_1.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_10.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_10.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_11.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_11.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_12.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_12.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_13.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_13.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_14.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_14.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_15.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_15.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_16.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_16.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_17.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_17.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_18.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_18.ply -------------------------------------------------------------------------------- 
/paper_figures/anim/shapeflow_edge0/mesh_19.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_19.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_2.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_2.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_20.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_20.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_3.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_3.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_4.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_4.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_5.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_5.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_6.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_6.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_7.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_7.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_8.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_8.ply -------------------------------------------------------------------------------- /paper_figures/anim/shapeflow_edge0/mesh_9.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/shapeflow_edge0/mesh_9.ply -------------------------------------------------------------------------------- /paper_figures/anim/volume_change.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim/volume_change.pdf -------------------------------------------------------------------------------- 
/paper_figures/anim_debug/linear.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/linear.npz -------------------------------------------------------------------------------- /paper_figures/anim_debug/linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/linear.png -------------------------------------------------------------------------------- /paper_figures/anim_debug/shapeflow_divfree_edge0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_divfree_edge0.npz -------------------------------------------------------------------------------- /paper_figures/anim_debug/shapeflow_divfree_edge0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_divfree_edge0.png -------------------------------------------------------------------------------- /paper_figures/anim_debug/shapeflow_divfree_edge2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_divfree_edge2.npz -------------------------------------------------------------------------------- /paper_figures/anim_debug/shapeflow_divfree_edge2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_divfree_edge2.png -------------------------------------------------------------------------------- /paper_figures/anim_debug/shapeflow_edge0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_edge0.npz -------------------------------------------------------------------------------- /paper_figures/anim_debug/shapeflow_edge0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/shapeflow_edge0.png -------------------------------------------------------------------------------- /paper_figures/anim_debug/volume_change.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/anim_debug/volume_change.pdf -------------------------------------------------------------------------------- /paper_figures/parametric_cad/deformed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/parametric_cad/deformed.png -------------------------------------------------------------------------------- /paper_figures/parametric_cad/groundtruth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/parametric_cad/groundtruth.png -------------------------------------------------------------------------------- 
/paper_figures/shapegen/0_dmc.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_dmc.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/0_gt.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_gt.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/0_input.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_input.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/0_onet.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_onet.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/0_psgn.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_psgn.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/0_r2n2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_r2n2.jpeg -------------------------------------------------------------------------------- 
/paper_figures/shapegen/0_retrieved.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_retrieved.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/0_shapeflow.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/0_shapeflow.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/1_dmc.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_dmc.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/1_gt.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_gt.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/1_input.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_input.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/1_onet.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_onet.jpeg -------------------------------------------------------------------------------- 
/paper_figures/shapegen/1_psgn.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_psgn.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/1_r2n2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_r2n2.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/1_retrieved.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_retrieved.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/1_shapeflow.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/1_shapeflow.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/2_dmc.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_dmc.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/2_gt.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_gt.jpeg -------------------------------------------------------------------------------- 
/paper_figures/shapegen/2_input.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_input.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/2_onet.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_onet.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/2_psgn.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_psgn.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/2_r2n2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_r2n2.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/2_retrieved.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_retrieved.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/2_shapeflow.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/2_shapeflow.jpeg 
-------------------------------------------------------------------------------- /paper_figures/shapegen/old/0_dmc.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_dmc.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/0_gt.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_gt.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/0_input.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_input.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/0_onet.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_onet.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/0_psgn.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_psgn.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/0_r2n2.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_r2n2.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/0_shapeflow.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/0_shapeflow.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/1_dmc.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_dmc.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/1_gt.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_gt.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/1_input.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_input.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/1_onet.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_onet.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/1_psgn.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_psgn.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/1_r2n2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_r2n2.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/1_shapeflow.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/1_shapeflow.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/2_dmc.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_dmc.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/2_gt.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_gt.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/2_input.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_input.jpeg -------------------------------------------------------------------------------- 
/paper_figures/shapegen/old/2_onet.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_onet.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/2_psgn.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_psgn.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/2_r2n2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_r2n2.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/old/2_shapeflow.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/old/2_shapeflow.jpeg -------------------------------------------------------------------------------- /paper_figures/shapegen/shapegen.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/shapegen/shapegen.zip -------------------------------------------------------------------------------- /paper_figures/teaser/flow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/paper_figures/teaser/flow.pdf 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.1.1 2 | numpy>=1.18.1 3 | scipy>=1.4.1 4 | torch>=1.4.0 5 | # torchdiffeq>=0.0.1 # torch 1.5.0 broke backwards compat. fixed in master. 6 | git+https://github.com/rtqichen/torchdiffeq 7 | torchvision>=0.5.0 8 | trimesh>=3.5.20 9 | sympy>=1.6.2 10 | imageio>=2.9.0 11 | tensorboard>=2.3.0 12 | -------------------------------------------------------------------------------- /scripts/cache_shapenet_pointcloud.py: -------------------------------------------------------------------------------- 1 | """Precompute and cache shapenet pointclouds. 2 | """ 3 | 4 | import os 5 | import sys 6 | 7 | WORKING_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") 8 | sys.path.append(WORKING_DIR) 9 | import tqdm # noqa: E402 10 | import argparse # noqa: E402 11 | import glob # noqa: E402 12 | import numpy as np # noqa: E402 13 | import trimesh # noqa: E402 14 | from multiprocessing import Pool # noqa: E402 15 | import pickle # noqa: E402 16 | from collections import OrderedDict # noqa: E402 17 | 18 | 19 | def sample_vertex(filename): 20 | mesh = trimesh.load(filename) 21 | v = np.array(mesh.vertices, dtype=np.float32) 22 | np.random.shuffle(v) 23 | return v[:n_points] 24 | 25 | 26 | def wrapper(arg): 27 | return arg, sample_vertex(arg) 28 | 29 | 30 | def update(args): 31 | filename, point_samp = args 32 | fname = "/".join(filename.split("/")[-4:-1]) 33 | # note: input comes from async `wrapper` 34 | sampled_points[ 35 | fname 36 | ] = point_samp # put answer into correct index of result list 37 | pbar.update() 38 | 39 | 40 | def get_args(): 41 | """Parse command line arguments.""" 42 | parser = argparse.ArgumentParser( 43 | description="Precompute and cache shapenet pointclouds." 
44 | ) 45 | 46 | parser.add_argument( 47 | "--file_pattern", 48 | type=str, 49 | default="**/*.ply", 50 | help="filename pattern for files to be rendered.", 51 | ) 52 | parser.add_argument( 53 | "--input_root", 54 | type=str, 55 | default="data/shapenet_simplified", 56 | help="path to input mesh root.", 57 | ) 58 | parser.add_argument( 59 | "--output_pkl", 60 | type=str, 61 | default="data/shapenet_points.pkl", 62 | help="path to output image root.", 63 | ) 64 | parser.add_argument( 65 | "--n_points", 66 | type=int, 67 | default=4096, 68 | help="Number of points to sample per shape", 69 | ) 70 | parser.add_argument( 71 | "--n_jobs", 72 | type=int, 73 | default=-1, 74 | help="Number of processes to use. Use all if set to -1.", 75 | ) 76 | 77 | args = parser.parse_args() 78 | return args 79 | 80 | 81 | def main(): 82 | args = get_args() 83 | patt = os.path.join(WORKING_DIR, args.input_root, args.file_pattern) 84 | in_files = glob.glob(patt, recursive=True) 85 | global sampled_points 86 | global n_points 87 | global pbar 88 | sampled_points = OrderedDict() 89 | n_points = args.n_points 90 | pbar = tqdm.tqdm(total=len(in_files)) 91 | pool = Pool(processes=None if args.n_jobs == -1 else args.n_jobs) 92 | for fname in in_files: 93 | pool.apply_async(wrapper, args=(fname,), callback=update) 94 | pool.close() 95 | pool.join() 96 | pbar.close() 97 | with open(args.output_pkl, "wb") as fh: 98 | pickle.dump(sampled_points, fh) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /scripts/render_thumbnail.py: -------------------------------------------------------------------------------- 1 | """Render thumbnails for meshes. 
2 | """ 3 | 4 | # flake8: noqa E402 5 | 6 | import os 7 | import sys 8 | 9 | WORKING_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") 10 | sys.path.append(WORKING_DIR) 11 | from utils.render import render_trimesh 12 | from multiprocessing import Pool 13 | import trimesh 14 | import tqdm 15 | import argparse 16 | import glob 17 | import imageio 18 | 19 | 20 | def render_file(in_file): 21 | out_file = in_file.replace(input_root, output_root) 22 | out_dir = os.path.dirname(out_file) 23 | os.makedirs(out_dir, exist_ok=True) 24 | eye = [0.8, 0.4, -0.5] 25 | center = [0, 0, 0] 26 | up = [0, 1, 0] 27 | mesh = trimesh.load(in_file) 28 | image, _, _, _ = render_trimesh( 29 | mesh, eye, center, up, res=(112, 112), light_intensity=6 30 | ) 31 | imageio.imwrite(os.path.join(out_dir, "thumbnail.jpg"), image) 32 | 33 | 34 | def wrapper(arg): 35 | return arg, render_file(arg) 36 | 37 | 38 | def update(args): 39 | pbar.update() 40 | 41 | 42 | def get_args(): 43 | """Parse command line arguments.""" 44 | parser = argparse.ArgumentParser(description="Render thumbnails.") 45 | 46 | parser.add_argument( 47 | "--file_pattern", 48 | type=str, 49 | default="**/*.ply", 50 | help="filename pattern for files to be rendered.", 51 | ) 52 | parser.add_argument( 53 | "--input_root", 54 | type=str, 55 | default="data/shapenet_simplified", 56 | help="path to input mesh root.", 57 | ) 58 | parser.add_argument( 59 | "--output_root", 60 | type=str, 61 | default="data/shapenet_thumbnails", 62 | help="path to output image root.", 63 | ) 64 | parser.add_argument( 65 | "--n_jobs", 66 | type=int, 67 | default=-1, 68 | help="Number of processes to use. 
Use all if set to -1.", 69 | ) 70 | 71 | args = parser.parse_args() 72 | return args 73 | 74 | 75 | def main(): 76 | args = get_args() 77 | in_files = glob.glob( 78 | os.path.join(WORKING_DIR, args.input_root, args.file_pattern), 79 | recursive=True, 80 | ) 81 | print(os.path.join(WORKING_DIR, args.input_root, args.file_pattern)) 82 | global input_root 83 | global output_root 84 | global pbar 85 | input_root = args.input_root 86 | output_root = args.output_root 87 | pbar = tqdm.tqdm(total=len(in_files)) 88 | pool = Pool(processes=None if args.n_jobs == -1 else args.n_jobs) 89 | for fname in in_files: 90 | pool.apply_async(wrapper, args=(fname,), callback=update) 91 | pool.close() 92 | pool.join() 93 | pbar.close() 94 | 95 | 96 | main() 97 | -------------------------------------------------------------------------------- /scripts/simplify.py: -------------------------------------------------------------------------------- 1 | """Script for performing mesh simplification. 2 | """ 3 | # flake8: noqa E402 4 | import sys 5 | 6 | sys.path.append("..") 7 | 8 | import meshutils 9 | import trimesh 10 | import glob 11 | import os 12 | import tqdm 13 | 14 | 15 | files = glob.glob("../data/shapenet_watertight/test/*/*/*.ply") 16 | outroot = "shapenet_simplified" 17 | 18 | for f_in in tqdm.tqdm(files): 19 | f_out = f_in.replace("shapenet_watertight", outroot) 20 | dirname = os.path.dirname(f_out) 21 | os.makedirs(dirname, exist_ok=True) 22 | mesh = trimesh.load(f_in) 23 | v, f = meshutils.fast_simplify(mesh.vertices, mesh.faces, ratio=0.1) 24 | trimesh.Trimesh(v, f).export(f_out) 25 | -------------------------------------------------------------------------------- /shapeflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/shapeflow/__init__.py -------------------------------------------------------------------------------- 
/shapeflow/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/shapeflow/layers/__init__.py -------------------------------------------------------------------------------- /shapeflow/layers/chamfer_layer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from scipy.spatial import cKDTree 5 | import multiprocessing 6 | from .shared_definition import REDUCTIONS 7 | 8 | 9 | class ChamferDist(nn.Module): 10 | """Compute chamfer distance on GPU using O(n^2) dense distance matrix.""" 11 | 12 | def __init__(self, reduction="mean"): 13 | super(ChamferDist, self).__init__() 14 | if not (reduction in list(REDUCTIONS.keys())): 15 | raise ValueError( 16 | f"reduction method ({reduction}) not in list of " 17 | f"accepted values: {list(REDUCTIONS.keys())}" 18 | ) 19 | self.reduce = lambda x: REDUCTIONS[reduction](x, axis=-1) 20 | 21 | def forward(self, tar, src): 22 | """ 23 | Args: 24 | tar: [b, n, 3] points for target 25 | src: [b, m, 3] points for source 26 | Returns: 27 | accuracy, complete, chamfer 28 | """ 29 | tar = tar.unsqueeze(2) 30 | src = src.unsqueeze(1) 31 | diff = tar - src # [b, n, m, 3] 32 | dist = torch.norm(diff, dim=-1) # [b, n, m] 33 | complete = torch.mean(dist.min(2)[0], dim=1) # [b] 34 | accuracy = torch.mean(dist.min(1)[0], dim=1) # [b] 35 | 36 | complete = self.reduce(complete) 37 | accuracy = self.reduce(accuracy) 38 | chamfer = 0.5 * (complete + accuracy) 39 | return accuracy, complete, chamfer 40 | 41 | 42 | def find_nn_id(args): 43 | """Eval distance between point sets. 
44 | Args: 45 | src: [m, 3] np array points for source 46 | tar: [n, 3] np array points for target 47 | Returns: 48 | nn_idx: [m,] np array, index of nearest point in target 49 | """ 50 | src, tar = args 51 | tree = cKDTree(tar) 52 | _, nn_idx = tree.query(src, k=1, n_jobs=1) 53 | 54 | return nn_idx 55 | 56 | 57 | def find_nn_id_parallel(args): 58 | """Eval distance between point sets. 59 | Args: 60 | src: [m, 3] np array points for source 61 | tar: [n, 3] np array points for target 62 | idx: int, batch index 63 | Returns: 64 | nn_idx: [m,] np array, index of nearest point in target 65 | idx 66 | """ 67 | src, tar, idx = args 68 | tree = cKDTree(tar) 69 | _, nn_idx = tree.query(src, k=1, n_jobs=1) 70 | 71 | return idx, nn_idx 72 | 73 | 74 | class ChamferDistKDTree(nn.Module): 75 | """Compute chamfer distances on CPU using KDTree.""" 76 | 77 | def __init__(self, reduction="mean", njobs=1): 78 | """Initialize loss module. 79 | 80 | Args: 81 | reduction: str, reduction method. choice of mean/sum/max/min. 82 | njobs: int, number of parallel workers to use during eval. 83 | """ 84 | super(ChamferDistKDTree, self).__init__() 85 | self.njobs = njobs 86 | 87 | self.set_reduction_method(reduction) 88 | if self.njobs != 1: 89 | self.p = multiprocessing.Pool(njobs) 90 | 91 | def find_batch_nn_id(self, src, tar, njobs): 92 | """Batched eval of distance between point sets. 
93 | Args: 94 | src: [batch, m, 3] np array points for source 95 | tar: [batch, n, 3] np array points for target 96 | Returns: 97 | batch_nn_idx: [batch, m], np array, index of nearest point in target 98 | """ 99 | b = src.shape[0] 100 | if njobs != 1: 101 | src_tar_pairs = tuple(zip(src, tar, range(b))) 102 | result = self.p.map(find_nn_id_parallel, src_tar_pairs) 103 | seq_arr = np.array([r[0] for r in result]) 104 | batch_nn_idx = np.stack([r[1] for r in result], axis=0) 105 | batch_nn_idx = batch_nn_idx[np.argsort(seq_arr)] 106 | else: 107 | batch_nn_idx = np.stack( 108 | [find_nn_id((src[i], tar[i])) for i in range(b)], axis=0 109 | ) 110 | 111 | return batch_nn_idx 112 | 113 | def set_reduction_method(self, reduction): 114 | """Set reduction method. 115 | 116 | Args: 117 | reduction: str, reduction method. choice of mean/sum/max/min. 118 | """ 119 | if not (reduction in list(REDUCTIONS.keys())): 120 | raise ValueError( 121 | f"reduction method ({reduction}) not in list of " 122 | f"accepted values: {list(REDUCTIONS.keys())}" 123 | ) 124 | self.reduce = REDUCTIONS[reduction] 125 | 126 | def forward(self, src, tar): 127 | """ 128 | Args: 129 | src: [batch, m, 3] points for source 130 | tar: [batch, n, 3] points for target 131 | Returns: 132 | accuracy: [batch, m], accuracy measure for each point in source 133 | complete: [batch, n], complete measure for each point in target 134 | chamfer: [batch,], chamfer distance between source and target 135 | """ 136 | bs = src.shape[0] 137 | device = src.device 138 | src_np = src.data.cpu().numpy() 139 | tar_np = tar.data.cpu().numpy() 140 | batch_tar_idx = ( 141 | torch.from_numpy( 142 | self.find_batch_nn_id(src_np, tar_np, njobs=self.njobs) 143 | ) 144 | .type(torch.LongTensor) 145 | .to(device) 146 | ) # [b, m] 147 | batch_src_idx = ( 148 | torch.from_numpy( 149 | self.find_batch_nn_id(tar_np, src_np, njobs=self.njobs) 150 | ) 151 | .type(torch.LongTensor) 152 | .to(device) 153 | ) # [b, n] 154 | batch_tar_idx_b = ( 
155 | torch.arange(bs).view(-1, 1).expand(-1, src.shape[1]) 156 | ) # [b, m, 3] 157 | batch_src_idx_b = ( 158 | torch.arange(bs).view(-1, 1).expand(-1, tar.shape[1]) 159 | ) # [b, n, 3] 160 | 161 | src_to_tar_diff = ( 162 | tar[batch_tar_idx_b, batch_tar_idx] - src 163 | ) # [b, m, 3] 164 | tar_to_src_diff = ( 165 | src[batch_src_idx_b, batch_src_idx] - tar 166 | ) # [b, n, 3] 167 | accuracy = torch.norm(src_to_tar_diff, dim=-1, keepdim=False) # [b, m] 168 | complete = torch.norm(tar_to_src_diff, dim=-1, keepdim=False) # [b, n] 169 | 170 | chamfer = 0.5 * (self.reduce(accuracy) + self.reduce(complete)) 171 | return accuracy, complete, chamfer 172 | -------------------------------------------------------------------------------- /shapeflow/layers/deformation_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchdiffeq import odeint_adjoint 4 | from torchdiffeq import odeint as odeint_regular 5 | from .pde_layer import PDELayer 6 | from .shared_definition import NONLINEARITIES 7 | import numpy as np 8 | 9 | 10 | class ImNet(nn.Module): 11 | """ImNet layer pytorch implementation.""" 12 | 13 | def __init__( 14 | self, 15 | dim=3, 16 | in_features=32, 17 | out_features=4, 18 | nf=32, 19 | nonlinearity="leakyrelu", 20 | ): 21 | """Initialization. 22 | 23 | Args: 24 | dim: int, dimension of input points. 25 | in_features: int, length of input features (i.e., latent code). 26 | out_features: number of output features. 27 | nf: int, width of the second to last layer. 28 | activation: tf activation op. 29 | name: str, name of the layer. 
30 | """ 31 | super(ImNet, self).__init__() 32 | self.dim = dim 33 | self.in_features = in_features 34 | self.dimz = dim + in_features 35 | self.out_features = out_features 36 | self.nf = nf 37 | self.activ = NONLINEARITIES[nonlinearity] 38 | self.fc0 = nn.Linear(self.dimz, nf * 16) 39 | self.fc1 = nn.Linear(nf * 16 + self.dimz, nf * 8) 40 | self.fc2 = nn.Linear(nf * 8 + self.dimz, nf * 4) 41 | self.fc3 = nn.Linear(nf * 4 + self.dimz, nf * 2) 42 | self.fc4 = nn.Linear(nf * 2 + self.dimz, nf * 1) 43 | self.fc5 = nn.Linear(nf * 1, out_features) 44 | self.fc = [self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5] 45 | self.fc = nn.ModuleList(self.fc) 46 | 47 | def forward(self, x): 48 | """Forward method. 49 | 50 | Args: 51 | x: `[batch_size, dim+in_features]` tensor, inputs to decode. 52 | Returns: 53 | output through this layer of shape [batch_size, out_features]. 54 | """ 55 | x_tmp = x 56 | for dense in self.fc[:4]: 57 | x_tmp = self.activ(dense(x_tmp)) 58 | x_tmp = torch.cat([x_tmp, x], dim=-1) 59 | x_tmp = self.activ(self.fc4(x_tmp)) 60 | x_tmp = self.fc5(x_tmp) 61 | return x_tmp 62 | 63 | 64 | class VanillaNet(nn.Module): 65 | """Vanilla mpl pytorch implementation.""" 66 | 67 | def __init__( 68 | self, 69 | dim=3, 70 | in_features=32, 71 | out_features=3, 72 | nf=50, 73 | nlayers=4, 74 | nonlinearity="leakyrelu", 75 | ): 76 | """Initialization. 77 | 78 | Args: 79 | dim: int, dimension of input points. 80 | in_features: int, length of input features (i.e., latent code). 81 | out_features: number of output features. 82 | nf: int, width of the second to last layer. 83 | nlayers: int, number of layers in mlp (inc. input/output layers). 84 | activation: tf activation op. 85 | name: str, name of the layer. 
86 | """ 87 | super(VanillaNet, self).__init__() 88 | self.dim = dim 89 | self.in_features = in_features 90 | self.dimz = dim + in_features 91 | self.out_features = out_features 92 | self.nf = nf 93 | self.nlayers = nlayers 94 | self.activ = NONLINEARITIES[nonlinearity] 95 | modules = [nn.Linear(dim + in_features, nf), self.activ] 96 | assert nlayers >= 2 97 | for i in range(nlayers - 2): 98 | modules += [nn.Linear(nf, nf), self.activ] 99 | modules += [nn.Linear(nf, out_features)] 100 | self.net = nn.Sequential(*modules) 101 | 102 | def forward(self, x): 103 | """Forward method. 104 | 105 | Args: 106 | x: `[batch_size, dim+in_features]` tensor, inputs to decode. 107 | Returns: 108 | output through this layer of shape [batch_size, out_features]. 109 | """ 110 | return self.net(x) 111 | 112 | 113 | def symmetrize(net, latent_vector, points, symm_dim): 114 | """Make network output symmetric.""" 115 | # query both sides of the symmetric dimension 116 | points_pos = points 117 | points_neg = points.clone() 118 | points_neg[..., symm_dim] = -points_neg[..., symm_dim] 119 | y_pos = net(latent_vector, points_pos) 120 | y_neg = net(latent_vector, points_neg) 121 | y_sym = (y_pos + y_neg) / 2 122 | y_sym[..., symm_dim] = (y_pos[..., symm_dim] - y_neg[..., symm_dim]) / 2 123 | return y_sym 124 | 125 | 126 | class DeformationFlowNetwork(nn.Module): 127 | def __init__( 128 | self, 129 | dim=3, 130 | latent_size=1, 131 | nlayers=4, 132 | width=50, 133 | nonlinearity="leakyrelu", 134 | arch="imnet", 135 | divfree=False, 136 | ): 137 | """Intialize deformation flow network. 138 | 139 | Args: 140 | dim: int, physical dimensions. Either 2 for 2d or 3 for 3d. 141 | latent_size: int, size of latent space. >= 1. 142 | nlayers: int, number of neural network layers. >= 2. 143 | width: int, number of neurons per hidden layer. >= 1. 144 | divfree: bool, paramaterize a divergence free flow. 
145 | """ 146 | super(DeformationFlowNetwork, self).__init__() 147 | self.dim = dim 148 | self.latent_size = latent_size 149 | self.nlayers = nlayers 150 | self.width = width 151 | self.nonlinearity = nonlinearity 152 | 153 | self.arch = arch 154 | self.divfree = divfree 155 | 156 | assert arch in ["imnet", "vanilla"] 157 | if arch == "imnet": 158 | self.net = ImNet( 159 | dim=dim, 160 | in_features=latent_size, 161 | out_features=dim, 162 | nf=width, 163 | nonlinearity=nonlinearity, 164 | ) 165 | else: # vanilla 166 | self.net = VanillaNet( 167 | dim=dim, 168 | in_features=latent_size, 169 | out_features=dim, 170 | nf=width, 171 | nlayers=nlayers, 172 | nonlinearity=nonlinearity, 173 | ) 174 | 175 | for m in self.net.modules(): 176 | if isinstance(m, nn.Linear): 177 | nn.init.normal_(m.weight, mean=0, std=1e-1) 178 | nn.init.constant_(m.bias, val=0) 179 | if divfree: 180 | self.curl = self._get_curl_layer() 181 | 182 | def _get_curl_layer(self): 183 | in_vars = "x, y, z" 184 | out_vars = "u, v, w" 185 | eqn_strs = [ 186 | "dif(w, y) - dif(v, z)", 187 | "dif(u, z) - dif(w, x)", 188 | "dif(v, x) - dif(u, y)", 189 | ] 190 | eqn_names = [ 191 | "curl_x", 192 | "curl_y", 193 | "curl_z", 194 | ] # a name/identifier for the equations 195 | curl = PDELayer( 196 | in_vars=in_vars, out_vars=out_vars 197 | ) # initialize the pde layer 198 | for eqn_str, eqn_name in zip(eqn_strs, eqn_names): # add equations 199 | curl.add_equation(eqn_str, eqn_name) 200 | return curl 201 | 202 | def forward(self, latent_vector, points): 203 | """ 204 | Args: 205 | latent_vector: tensor of shape [batch, latent_size], latent code for 206 | each shape 207 | points: tensor of shape [batch, num_points, dim], points representing 208 | each shape 209 | 210 | Returns: 211 | velocities: tensor of shape [batch, num_points, dim], velocity at 212 | each point 213 | """ 214 | latent_vector = latent_vector.unsqueeze(1).expand( 215 | -1, points.shape[1], -1 216 | ) # [batch, num_points, latent_size] 217 | # 
wrapper for pde layer 218 | 219 | def fwd_fn(points): 220 | """Forward function. 221 | 222 | Where inpt[..., 0], inpt[..., 1], inpt[..., 2] correspond to x, y, 223 | z and 224 | out[..., 0], out[..., 1], out[..., 2] correspond to u, v, w 225 | """ 226 | points_latents = torch.cat( 227 | (points, latent_vector), axis=-1 228 | ) # [batch, num_points, dim + latent_size] 229 | b, n, d = points_latents.shape 230 | res = self.net(points_latents.reshape([-1, d])) 231 | res = res.reshape([b, n, self.dim]) 232 | return res 233 | 234 | if self.divfree: 235 | # return the curl of the velocity field instead 236 | self.curl.update_forward_method(fwd_fn) 237 | _, res_dict = self.curl(points) # res are the equation residues 238 | res = torch.cat( 239 | [res_dict["curl_x"], res_dict["curl_y"], res_dict["curl_z"]], 240 | dim=-1, 241 | ) # [batch, num_points, dim] 242 | return res 243 | else: 244 | return fwd_fn(points) 245 | 246 | 247 | class ConformalDeformationFlowNetwork(nn.Module): 248 | def __init__( 249 | self, 250 | dim=3, 251 | latent_size=1, 252 | nlayers=4, 253 | width=50, 254 | nonlinearity="softplus", 255 | output_scalar=False, 256 | arch="imnet", 257 | ): 258 | """Intialize conformal deformation flow network w/ irrotational flow. 259 | 260 | The network produces a scalar field Phi(x,y,z,t), and the velocity 261 | field is represented as the gradient of Phi. 262 | v = \nabla\Phi # noqa: W605 263 | The gradients can be efficiently computed as the Jacobian through 264 | backprop. 265 | 266 | Args: 267 | dim: int, physical dimensions. Either 2 for 2d or 3 for 3d. 268 | latent_size: int, size of latent space. >= 1. 269 | nlayers: int, number of neural network layers. >= 2. 270 | width: int, number of neurons per hidden layer. >= 1. 
271 | """ 272 | super(ConformalDeformationFlowNetwork, self).__init__() 273 | self.dim = dim 274 | self.latent_size = latent_size 275 | self.nlayers = nlayers 276 | self.width = width 277 | self.nonlinearity = nonlinearity 278 | self.output_scalar = output_scalar 279 | 280 | self.scale = nn.Parameter(torch.ones(1) * 1e-1) 281 | 282 | nlin = NONLINEARITIES[nonlinearity] 283 | 284 | modules = [nn.Linear(dim + latent_size, width), nlin] 285 | for i in range(nlayers - 2): 286 | modules += [nn.Linear(width, width), nlin] 287 | modules += [nn.Linear(width, 1)] 288 | self.net = nn.Sequential(*modules) 289 | 290 | for m in self.net.modules(): 291 | if isinstance(m, nn.Linear): 292 | nn.init.normal_(m.weight, mean=0, std=1e-1) 293 | nn.init.constant_(m.bias, val=0) 294 | 295 | def forward(self, latent_vector, points): 296 | """ 297 | Args: 298 | latent_vector: tensor of shape [batch, latent_size], latent code for 299 | each shape 300 | points: tensor of shape [batch, num_points, dim], points representing 301 | each shape 302 | Returns: 303 | velocities: tensor of shape [batch, num_points, dim], velocity at 304 | each point 305 | """ 306 | latent_vector = latent_vector.unsqueeze(1).expand( 307 | -1, points.shape[1], -1 308 | ) # [batch, num_points, latent_size] 309 | b, num_points, lat_size = latent_vector.shape 310 | latent_flat = latent_vector.reshape([-1, lat_size]) 311 | points_flat = points.reshape([-1, self.dim]) 312 | points_flat_ = torch.autograd.Variable(points_flat) 313 | points_flat_.requires_grad = True 314 | points_latents = torch.cat( 315 | (points_flat_, latent_flat), axis=-1 316 | ) # [batch*num_points, dim + latent_size] 317 | phi = self.net(points_latents) * self.scale 318 | vel_flat = torch.autograd.grad( 319 | phi, 320 | points_flat_, 321 | grad_outputs=torch.ones_like(phi), 322 | create_graph=True, 323 | )[0] 324 | vel = vel_flat.reshape(points.shape) 325 | if self.output_scalar: 326 | return vel, phi 327 | else: 328 | return vel 329 | 330 | 331 | class 
DeformationSignNetwork(nn.Module): 332 | def __init__( 333 | self, latent_size=1, nlayers=3, width=20, nonlinearity="tanh" 334 | ): 335 | """Initialize deformation sign network. 336 | Args: 337 | latent_size: int, size of latent space. >= 1. 338 | nlayers: int, number of neural network layers. >= 2. 339 | width: int, number of neurons per hidden layer. >= 1. 340 | """ 341 | super(DeformationSignNetwork, self).__init__() 342 | self.latent_size = latent_size 343 | self.nlayers = nlayers 344 | self.width = width 345 | 346 | nlin = NONLINEARITIES[nonlinearity] 347 | modules = [nn.Linear(latent_size, width, bias=False), nlin] 348 | for i in range(nlayers - 2): 349 | modules += [nn.Linear(width, width, bias=False), nlin] 350 | modules += [nn.Linear(width, 1, bias=False), nlin] 351 | self.net = nn.Sequential(*modules) 352 | 353 | for m in self.net.modules(): 354 | if isinstance(m, nn.Linear): 355 | nn.init.normal_(m.weight, mean=0, std=1e-1) 356 | 357 | def forward(self, dir_vector): 358 | """ 359 | Args: 360 | dir_vector: tensor of shape [batch, latent_size], latent direction. 
361 | Returns: 362 | signs: tensor of shape [batch, 1, 1] 363 | """ 364 | dir_vector = dir_vector / ( 365 | torch.norm(dir_vector, dim=-1, keepdim=True) + 1e-6 366 | ) # normalize 367 | signs = self.net(dir_vector).unsqueeze(-1) 368 | return signs 369 | 370 | 371 | class NeuralFlowModel(nn.Module): 372 | def __init__( 373 | self, 374 | dim=3, 375 | latent_size=1, 376 | f_nlayers=4, 377 | f_width=50, 378 | s_nlayers=3, 379 | s_width=20, 380 | nonlinearity="relu", 381 | conformal=False, 382 | arch="imnet", 383 | no_sign_net=False, 384 | divfree=False, 385 | symm_dim=None, 386 | ): 387 | super(NeuralFlowModel, self).__init__() 388 | if conformal: 389 | model = ConformalDeformationFlowNetwork 390 | 391 | else: 392 | model = DeformationFlowNetwork 393 | self.no_sign_net = no_sign_net 394 | self.flow_net = model( 395 | dim=dim, 396 | latent_size=latent_size, 397 | nlayers=f_nlayers, 398 | width=f_width, 399 | nonlinearity=nonlinearity, 400 | arch=arch, 401 | divfree=divfree, 402 | ) 403 | if not no_sign_net: 404 | self.sign_net = DeformationSignNetwork( 405 | latent_size=latent_size, nlayers=s_nlayers, width=s_width 406 | ) 407 | self.symm_dim = symm_dim 408 | self.latent_source = None 409 | self.latent_target = None 410 | self.latent_updated = False 411 | self.conformal = conformal 412 | self.arch = arch 413 | self.encoder = None 414 | self.lat_params = None 415 | self.scale = nn.Parameter(torch.ones(1) * 1e-3) 416 | 417 | def add_encoder(self, encoder): 418 | self.encoder = encoder 419 | 420 | def add_lat_params(self, lat_params): 421 | self.lat_params = lat_params 422 | 423 | def get_lat_params(self, idx): 424 | assert self.lat_params is not None 425 | return self.lat_params[idx] 426 | 427 | def update_latents(self, latent_sequence): 428 | """ 429 | Args: 430 | latent_sequence: long or float tensor of shape 431 | [batch, nsteps, latent_size]. 432 | sequence of latents along deformation path. 433 | if long, index into self.lat_params to retrieve latents. 
434 | Returns: 435 | latent_waypoint: float tensor of shape [batch, nsteps], interp 436 | coefficient betwene [0, 1] corresponding to each latent code. 437 | """ 438 | bs, ns, d = latent_sequence.shape 439 | dev = latent_sequence.device 440 | self.latent_sequence = latent_sequence 441 | self.latent_seq_len = torch.norm( 442 | self.latent_sequence[:, 1:] - self.latent_sequence[:, :-1], dim=-1 443 | ) # [batch, nsteps-1] 444 | self.latent_seq_len_sum = torch.sum( 445 | self.latent_seq_len, dim=1 446 | ) # [batch] 447 | self.latent_seq_weight = ( 448 | self.latent_seq_len / self.latent_seq_len_sum[:, None] 449 | ) # [batch, nsteps-1] 450 | self.latent_seq_bins = torch.cumsum( 451 | self.latent_seq_weight, dim=1 452 | ) # [batch, nsteps-1] 453 | self.latent_seq_bins = torch.cat( 454 | [torch.zeros([bs, 1], device=dev), self.latent_seq_bins], dim=1 455 | ) # [batch, nsteps] 456 | self.latent_updated = True 457 | 458 | return self.latent_seq_bins 459 | 460 | def latent_at_t(self, t, return_sign=False): 461 | """Helper fn to compute latent at t.""" 462 | t = t.to(self.latent_seq_bins.device) 463 | # find out which bin this t falls into 464 | bin_mask = (t > self.latent_seq_bins[:, :-1]) * ( 465 | t < self.latent_seq_bins[:, 1:] 466 | ) 467 | # logical and 468 | 469 | bin_mask = bin_mask.float() 470 | bin_idx = torch.argmax(bin_mask, dim=1) # [batch,] 471 | batch_idx = torch.arange(bin_idx.shape[0]).to(bin_idx.device) 472 | 473 | # Find the interpolation coefficient between the latents at the two 474 | # ends of the bin 475 | t0 = self.latent_seq_bins[batch_idx, bin_idx] 476 | t1 = self.latent_seq_bins[batch_idx, bin_idx + 1] # [batch] 477 | alpha = (t - t0) / (t1 - t0) # [batch] 478 | latent_t0 = self.latent_sequence[ 479 | batch_idx, bin_idx 480 | ] # [batch, latent_size] 481 | latent_t1 = self.latent_sequence[ 482 | batch_idx, bin_idx + 1 483 | ] # [batch, latent_size] 484 | latent_val = latent_t0 + alpha[:, None] * (latent_t1 - latent_t0) 485 | latent_dir = (latent_t1 
- latent_t0) / torch.norm( 486 | latent_t1 - latent_t0, dim=1, keepdim=True 487 | ) 488 | zeros = torch.zeros_like(latent_t0) 489 | outward = torch.norm(latent_t0 - zeros, dim=1) < 1e-6 # [batch] 490 | sign = (outward.float() - 0.5) * 2 491 | 492 | return latent_val, latent_dir, sign 493 | 494 | def forward(self, t, points): 495 | """ 496 | Args: 497 | t: float, deformation parameter between 0 and 1. 498 | points: [batch, num_points, dim] 499 | Returns: 500 | vel: [batch, num_points, dim] 501 | """ 502 | # Reparametrize eval along latent path as a function of a single 503 | # scalar t 504 | if not self.latent_updated: 505 | raise RuntimeError( 506 | "Latent not updated. " 507 | "Use .update_latents() to update the source and target latents" 508 | ) 509 | 510 | latent_val, latent_dir, sign = self.latent_at_t(t) 511 | sign = sign[:, None, None] * self.scale 512 | if self.symm_dim is None: 513 | flow = self.flow_net(latent_val, points) # [batch, num_pints, dim] 514 | else: 515 | flow = symmetrize(self.flow_net, latent_val, points, self.symm_dim) 516 | # Normalize velocity based on time space proportional to latent 517 | # difference. 518 | flow *= self.latent_seq_len_sum[:, None, None] 519 | if not self.no_sign_net: 520 | sign = self.sign_net(latent_dir) 521 | return flow * sign 522 | 523 | 524 | class NeuralFlowDeformer(nn.Module): 525 | def __init__( 526 | self, 527 | dim=3, 528 | latent_size=1, 529 | f_nlayers=4, 530 | f_width=50, 531 | s_nlayers=3, 532 | s_width=20, 533 | method="dopri5", 534 | nonlinearity="leakyrelu", 535 | arch="imnet", 536 | conformal=False, 537 | adjoint=True, 538 | atol=1e-5, 539 | rtol=1e-5, 540 | via_hub=False, 541 | no_sign_net=False, 542 | return_waypoints=False, 543 | use_latent_waypoints=False, 544 | divfree=False, 545 | symm_dim=None, 546 | ): 547 | """Initialize. The parameters are the parameters for the Deformation 548 | Flow network. 549 | 550 | Args: 551 | dim: int, physical dimensions. Either 2 for 2d or 3 for 3d. 
552 | latent_size: int, size of latent space. >= 1. 553 | f_nlayers: int, number of neural network layers for flow network. 554 | (>= 2). 555 | f_width: int, number of neurons per hidden layer for flow network 556 | (>= 1). 557 | s_nlayers: int, number of neural network layers for sign network 558 | (>= 2). 559 | s_width: int, number of neurons per hidden layer for sign network. 560 | (>= 1). 561 | arch: str, architecture, choice of 'imnet' / 'vanilla' 562 | adjoint: bool, whether to use adjoint solver to backprop gadient 563 | thru odeint. 564 | rtol, atol: float, relative / absolute error tolerence in ode solver. 565 | via_hub: bool, will perform transformation via hub-and-spokes 566 | configuration. Only useful if latent_sequence is torch.long 567 | return_waypoints: bool, return intermediate waypoints along timing. 568 | use_latent_waypoints: bool, use latent waypoints. 569 | symm_dim: int, list of int, or None. Symmetry axis/axes, or None. 570 | """ 571 | super(NeuralFlowDeformer, self).__init__() 572 | self.method = method 573 | self.conformal = conformal 574 | self.arch = arch 575 | self.adjoint = adjoint 576 | self.odeint = odeint_adjoint if adjoint else odeint_regular 577 | self.__timing = torch.from_numpy( 578 | np.array([0.0, 1.0]).astype("float32") 579 | ) 580 | self.return_waypoints = return_waypoints 581 | self.use_latent_waypoints = use_latent_waypoints 582 | self.rtol = rtol 583 | self.atol = atol 584 | self.via_hub = via_hub 585 | self.symm_dim = symm_dim 586 | 587 | self.net = NeuralFlowModel( 588 | dim=dim, 589 | latent_size=latent_size, 590 | f_nlayers=f_nlayers, 591 | f_width=f_width, 592 | s_nlayers=s_nlayers, 593 | s_width=s_width, 594 | arch=arch, 595 | conformal=conformal, 596 | nonlinearity=nonlinearity, 597 | no_sign_net=no_sign_net, 598 | divfree=divfree, 599 | symm_dim=symm_dim, 600 | ) 601 | if symm_dim is not None: 602 | if not (isinstance(symm_dim, int) or isinstance(symm_dim, list)): 603 | raise ValueError( 604 | "symm_dim must be 
int or list of ints, indicating axes of" 605 | "symmetry." 606 | ) 607 | 608 | @property 609 | def adjoint(self): 610 | return self.__adjoint 611 | 612 | @adjoint.setter 613 | def adjoint(self, isadjoint): 614 | assert isinstance(isadjoint, bool) 615 | self.__adjoint = isadjoint 616 | self.odeint = odeint_adjoint if isadjoint else odeint_regular 617 | 618 | @property 619 | def timing(self): 620 | return self.__timing 621 | 622 | @timing.setter 623 | def timing(self, timing): 624 | assert isinstance(timing, torch.Tensor) 625 | assert timing.ndim == 1 626 | self.__timing = timing 627 | 628 | def add_encoder(self, encoder): 629 | self.net.add_encoder(encoder) 630 | 631 | def add_lat_params(self, lat_params): 632 | self.net.add_lat_params(lat_params) 633 | 634 | def get_lat_params(self, idx): 635 | return self.net.get_lat_params(idx) 636 | 637 | def forward(self, points, latent_sequence): 638 | """Forward transformation (source -> latent_path -> target). 639 | 640 | To perform backward transformation, simply switch the order of the lat 641 | codes. 642 | 643 | Args: 644 | points: [batch, num_points, dim] 645 | latent_sequence: float tensor of shape [batch, nsteps, latent_size], 646 | ------- or ------- 647 | long tensor of shape [batch, nsteps] 648 | sequence of latents along deformation path. 649 | if long, index into self.lat_params to retrieve latents. 650 | Returns: 651 | points_transformed: 652 | tensor of shape [batch, num_points, dim] if not 653 | self.return_waypoint. 654 | tensor of shape [nsteps, batch, num_points, dim] if 655 | self.return_waypoint. 
656 | """ 657 | if latent_sequence.dtype == torch.long: 658 | latent_sequence = self.get_lat_params( 659 | latent_sequence 660 | ) # [nsteps, batch, lat_dim] 661 | if self.via_hub: 662 | assert latent_sequence.shape[1] == 2 663 | zeros = torch.zeros_like(latent_sequence[:, :1]) 664 | lat0 = latent_sequence[:, 0:1] 665 | lat1 = latent_sequence[:, 1:2] 666 | latent_sequence = torch.cat( 667 | [lat0, zeros, lat1], dim=1 668 | ) # [batch, nsteps=3, lat_dim] 669 | waypoints = self.net.update_latents(latent_sequence) 670 | if self.use_latent_waypoints: 671 | timing = waypoints[0] 672 | else: 673 | timing = self.timing 674 | points_transformed = self.odeint( 675 | self.net, 676 | points, 677 | timing, 678 | method=self.method, 679 | rtol=self.rtol, 680 | atol=self.atol, 681 | ) 682 | if self.return_waypoints: 683 | return points_transformed 684 | else: 685 | return points_transformed[-1] 686 | -------------------------------------------------------------------------------- /shapeflow/layers/pde_layer.py: -------------------------------------------------------------------------------- 1 | import sympy 2 | import torch 3 | from torch.autograd import grad 4 | from sympy.parsing.sympy_parser import parse_expr 5 | 6 | 7 | # utility functions for parsing equations 8 | torch_diff = lambda y, x: grad( # noqa: E731 9 | y, x, grad_outputs=torch.ones_like(y), create_graph=True, allow_unused=True 10 | )[0] 11 | 12 | 13 | class PDELayer(object): 14 | """PDE Layer for querying values and computing PDE residues.""" 15 | 16 | def __init__(self, in_vars, out_vars): 17 | """Initialize physics layer. 18 | 19 | Args: 20 | in_vars: str, a string of input variable names separated by space. 21 | E.g., 'x y t' for the three variables x, y and t. 22 | out_vars: str, a string of output variable names separated by space. 23 | E.g., 'u v p' for the three variables u, v and p. 
24 | """ 25 | self.in_vars = sympy.symbols(in_vars) 26 | self.out_vars = sympy.symbols(out_vars) 27 | if not isinstance(self.in_vars, tuple): 28 | self.in_vars = (self.in_vars,) 29 | if not isinstance(self.out_vars, tuple): 30 | self.out_vars = (self.out_vars,) 31 | self.n_in = len(self.in_vars) 32 | self.n_out = len(self.out_vars) 33 | self.all_vars = list(self.in_vars) + list(self.out_vars) 34 | self.eqns_raw = {} # raw string equations 35 | self.eqns_fn = {} # lambda function for the equations 36 | self.forward_method = None 37 | 38 | def add_equation(self, eqn_str, eqn_name="", subs_dict=None): 39 | """Add an equation to the physics layer. 40 | 41 | The equation string should represent the expression for computing the 42 | residue of a given equation, rather than representing the equation 43 | itself. Use dif(y,x) for computing the derivate of y with respect to x. 44 | Sign of the expression does not matter. The variable names 45 | **MUST** be the same as the variables in self.in_vars and 46 | self.out_vars. 47 | 48 | E.g., 49 | For the equation 50 | partial(u, x) + partial(v, y) = 3*partial(u, y)*partial(v, x), 51 | write as: 52 | eqn_str = 'dif(u,x)+dif(v,y)-3*dif(u,y)*dif(v,x)' 53 | - or - 54 | eqn_str = '3*dif(u,y)*dif(v,x)-(dif(u,x)+dif(v,y))' 55 | 56 | Args: 57 | eqn_str: str, a string that can be parsed as an experession for 58 | computing the residue of an equation. 59 | eqn_name: str, a name or identifier for this equation entry. E.g., 60 | 'div_free'. If none or empty, use default of eqn_i where i is an 61 | index. 62 | subs_dict: dict, a dictionary where the key (str) is the variable 63 | to subsitute and val (str) is the expression to substitite the = 64 | variable with. useful for scenarios such as normalizations and/or 65 | non-dimensionalizing expressions. 66 | 67 | Raises: 68 | ValueError: when the variables in the eqn_str do not match that of 69 | in_vars and out_vars. 
70 | 71 | """ 72 | if not eqn_name: 73 | eqn_name = "eqn_{i}".format(i=len(self.eqns_raw.keys())) 74 | 75 | # Assert that the equation contains the same vars as in_vars and 76 | # out_vars. 77 | expr = parse_expr(eqn_str) 78 | 79 | # substitute variables in the equation. 80 | if subs_dict: 81 | for key, val in subs_dict.items(): 82 | expr = expr.subs(key, val) 83 | 84 | valid_var = expr.free_symbols <= ( 85 | set(self.in_vars) | set(self.out_vars) 86 | ) 87 | if not valid_var: 88 | raise ValueError( 89 | "Variables in the eqn_str ({}) does not match that of " 90 | "in_vars ({}) and out_vars ({})".format( 91 | expr.free_symbols, set(self.in_vars), set(self.out_vars) 92 | ) 93 | ) 94 | 95 | # convert into lambda functions 96 | fn = sympy.lambdify(self.all_vars, expr, {"dif": torch_diff}) 97 | 98 | # update equations 99 | self.eqns_raw.update({eqn_name: eqn_str}) 100 | self.eqns_fn.update({eqn_name: fn}) 101 | 102 | def update_forward_method(self, forward_method): 103 | """Update forward method. 104 | 105 | Args: 106 | forward_method: a function, such that y = forward_method(x). x is a 107 | tensor of shape (..., n_in) and y is a tensor of shape (..., n_out). 108 | """ 109 | self.forward_method = forward_method 110 | 111 | def eval(self, x): 112 | """Evaluate the output values using forward_method. 113 | 114 | Args: 115 | x: a tensor of shape (..., n_in) 116 | Returns: 117 | a tensor of shape (..., n_out) 118 | """ 119 | if not self.forward_method: 120 | raise RuntimeError( 121 | "forward_method has not been defined." 122 | "Run update_forward_method first." 
123 | ) 124 | y = self.forward_method(x) 125 | if not ((x.shape[-1] == self.n_in) and (y.shape[-1] == self.n_out)): 126 | raise ValueError( 127 | "Input/output dimensions ({}/{}) not equal to the dimensions " 128 | "of defined variables ({}/{}).".format( 129 | x.shape[-1], y.shape[-1], self.n_in, self.n_out 130 | ) 131 | ) 132 | return y 133 | 134 | def __call__(self, x, return_residue=True): 135 | """Compute the forward eval and possibly compute residues from the 136 | previously defined pdes. 137 | 138 | Args: 139 | x: input tensor of shape (..., n_in) 140 | return_residue: bool, whether to return the residue of the pde for 141 | each equation. 142 | Returns: 143 | y: output tensor of shape (..., n_out) 144 | residues (optional): a dictionary containing residue evaluation for 145 | each pde. 146 | """ 147 | 148 | if not return_residue: 149 | y = self.eval(x) 150 | return y 151 | else: 152 | with torch.enable_grad(): 153 | # split into individual channels and set each to require grad. 154 | inputs = [x[..., i: i + 1] for i in range(x.shape[-1])] 155 | for xx in inputs: 156 | if not xx.requires_grad: 157 | xx.requires_grad = True 158 | x_ = torch.cat(inputs, axis=-1) 159 | y = self.eval(x_) 160 | outputs = [y[..., i: i + 1] for i in range(y.shape[-1])] 161 | inputs_outputs = inputs + outputs 162 | residues = {} 163 | for key, fn in self.eqns_fn.items(): 164 | residue = fn(*inputs_outputs) 165 | residues.update({key: residue}) 166 | return y, residues 167 | 168 | @property 169 | def eqn_num(self): 170 | return len(self.eqns_raw) 171 | 172 | @property 173 | def eqn_names(self): 174 | return list(self.eqns_raw.keys()) 175 | -------------------------------------------------------------------------------- /shapeflow/layers/pointnet_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credits: 3 | Originally implemented by Fei Xia. 4 | https://github.com/fxia22/pointnet.pytorch 5 | 6 | with small modifications. 
7 | """ 8 | from __future__ import print_function 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.utils.data 13 | from torch.autograd import Variable 14 | import numpy as np 15 | 16 | from .shared_definition import NORMTYPE, NONLINEARITIES 17 | 18 | 19 | class STN3d(nn.Module): 20 | def __init__(self, norm_type, nonlinearity="relu"): 21 | super(STN3d, self).__init__() 22 | assert norm_type in NORMTYPE.keys() 23 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 24 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 25 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 26 | self.fc1 = torch.nn.Conv1d(1024, 512, 1) 27 | self.fc2 = torch.nn.Conv1d(512, 256, 1) 28 | self.fc3 = torch.nn.Conv1d(256, 9, 1) 29 | self.nl = NONLINEARITIES[nonlinearity] 30 | 31 | self.bn1 = NORMTYPE[norm_type](64) 32 | self.bn2 = NORMTYPE[norm_type](128) 33 | self.bn3 = NORMTYPE[norm_type](1024) 34 | self.bn4 = NORMTYPE[norm_type](512) 35 | self.bn5 = NORMTYPE[norm_type](256) 36 | 37 | def forward(self, x): 38 | batchsize = x.size()[0] 39 | x = self.nl(self.bn1(self.conv1(x))) 40 | x = self.nl(self.bn2(self.conv2(x))) 41 | x = self.nl(self.bn3(self.conv3(x))) 42 | x = torch.max(x, 2, keepdim=True)[0] # [b, c, 1] 43 | 44 | x = self.nl(self.bn4(self.fc1(x))) # [b, c, 1] 45 | x = self.nl(self.bn5(self.fc2(x))) # [b, c, 1] 46 | x = self.fc3(x).squeeze(-1) 47 | 48 | iden = ( 49 | Variable( 50 | torch.from_numpy( 51 | np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32) 52 | ) 53 | ) 54 | .view(1, 9) 55 | .repeat(batchsize, 1) 56 | ) 57 | if x.is_cuda: 58 | iden = iden.cuda() 59 | x = x + iden 60 | x = x.view(-1, 3, 3) 61 | return x 62 | 63 | 64 | class PointNetfeat(nn.Module): 65 | def __init__( 66 | self, 67 | in_features=3, 68 | nf=64, 69 | global_feat=True, 70 | feature_transform=False, 71 | norm_type="batchnorm", 72 | nonlinearity="relu", 73 | ): 74 | super(PointNetfeat, self).__init__() 75 | assert norm_type in NORMTYPE.keys() 76 | self.stn = STN3d(norm_type=norm_type, 
nonlinearity=nonlinearity) 77 | self.conv1 = torch.nn.Conv1d(in_features, nf, 1) 78 | self.conv2 = torch.nn.Conv1d(nf, nf * 2, 1) 79 | self.conv3 = torch.nn.Conv1d(nf * 2, nf * 16, 1) 80 | self.nm1 = NORMTYPE[norm_type](nf) 81 | self.nm2 = NORMTYPE[norm_type](nf * 2) 82 | self.nm3 = NORMTYPE[norm_type](nf * 16) 83 | self.global_feat = global_feat 84 | self.feature_transform = feature_transform 85 | self.nf = nf 86 | self.nl = NONLINEARITIES[nonlinearity] 87 | if self.feature_transform: 88 | self.fstn = STN3d(k=nf) 89 | 90 | def forward(self, x): 91 | n_pts = x.size()[2] 92 | trans = self.stn(x) 93 | x = x.transpose(2, 1) 94 | x = torch.bmm(x, trans) 95 | x = x.transpose(2, 1) 96 | x = self.nl(self.nm1(self.conv1(x))) 97 | 98 | if self.feature_transform: 99 | trans_feat = self.fstn(x) 100 | x = x.transpose(2, 1) 101 | x = torch.bmm(x, trans_feat) 102 | x = x.transpose(2, 1) 103 | else: 104 | trans_feat = None 105 | 106 | pointfeat = x 107 | x = self.nl(self.nm2(self.conv2(x))) 108 | x = self.nm3(self.conv3(x)) 109 | x = torch.max(x, 2, keepdim=True)[0] 110 | x = x.view(-1, self.nf * 16) 111 | if self.global_feat: 112 | return x, trans, trans_feat 113 | else: 114 | x = x.view(-1, self.nf * 16, 1).repeat(1, 1, n_pts) 115 | return torch.cat([x, pointfeat], 1), trans, trans_feat 116 | 117 | 118 | class PointNetEncoder(nn.Module): 119 | def __init__( 120 | self, 121 | nf=64, 122 | in_features=3, 123 | out_features=8, 124 | feature_transform=False, 125 | dropout_prob=0.3, 126 | norm_type="batchnorm", 127 | nonlinearity="relu", 128 | ): 129 | super(PointNetEncoder, self).__init__() 130 | assert norm_type in NORMTYPE.keys() 131 | self.feature_transform = feature_transform 132 | self.dropout_prob = dropout_prob 133 | self.feat = PointNetfeat( 134 | global_feat=True, 135 | in_features=in_features, 136 | nf=nf, 137 | feature_transform=feature_transform, 138 | norm_type=norm_type, 139 | nonlinearity=nonlinearity, 140 | ) 141 | self.fc1 = nn.Conv1d(nf * 16, nf * 8, 1) 142 | 
self.fc2 = nn.Conv1d(nf * 8, nf * 4, 1) 143 | self.fc3 = nn.Conv1d(nf * 4, out_features, 1) 144 | self.dropout = nn.Dropout(p=dropout_prob) 145 | self.nm1 = NORMTYPE[norm_type](nf * 8) 146 | self.nm2 = NORMTYPE[norm_type](nf * 4) 147 | self.nl = NONLINEARITIES[nonlinearity] 148 | self.in_features = in_features 149 | self.out_features = out_features 150 | self.nf = nf 151 | 152 | def forward(self, x): 153 | """ 154 | Args: 155 | x: tensor of shape [batch, npoints, in_features] 156 | Returns: 157 | output: tensor of shape [batch, out_features] 158 | """ 159 | x = x.permute(0, 2, 1) 160 | x, _, _ = self.feat(x) 161 | x = x.unsqueeze(-1) 162 | x = self.nl(self.nm1(self.fc1(x))) 163 | x = self.nl( 164 | self.nm2(self.dropout(self.fc2(x).squeeze(-1)).unsqueeze(-1)) 165 | ) 166 | x = self.fc3(x).squeeze(-1) 167 | return x 168 | 169 | 170 | def feature_transform_regularizer(trans): 171 | d = trans.size()[1] 172 | I = torch.eye(d)[None, :, :] # noqa: E741 173 | if trans.is_cuda: 174 | I = I.cuda() # noqa: E741 175 | loss = torch.mean( 176 | torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)) 177 | ) 178 | return loss 179 | 180 | 181 | if __name__ == "__main__": 182 | # example for using encoder 183 | encoder = PointNetEncoder(output_features=8) 184 | points = torch.rand(16, 100, 3) 185 | latents = encoder(points) 186 | print(latents.shape) 187 | -------------------------------------------------------------------------------- /shapeflow/layers/shared_definition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class NoNorm(nn.Module): 6 | def __init__(self, layers): 7 | super(NoNorm, self).__init__() 8 | 9 | def forward(self, x): 10 | return x 11 | 12 | 13 | class Swish(nn.Module): 14 | def __init__(self): 15 | super(Swish, self).__init__() 16 | self.beta = nn.Parameter(torch.tensor(1.0)) 17 | 18 | def forward(self, x): 19 | return x * torch.sigmoid(self.beta * x) 20 | 
21 | 22 | class Lambda(nn.Module): 23 | def __init__(self, f): 24 | super(Lambda, self).__init__() 25 | self.f = f 26 | 27 | def forward(self, x): 28 | return self.f(x) 29 | 30 | 31 | OPTIMIZERS = { 32 | "sgd": torch.optim.SGD, 33 | "adam": torch.optim.Adam, 34 | "adadelta": torch.optim.Adadelta, 35 | "adagrad": torch.optim.Adagrad, 36 | "rmsprop": torch.optim.RMSprop, 37 | } 38 | 39 | LOSSES = { 40 | "l1": torch.nn.L1Loss(), 41 | "l2": torch.nn.MSELoss(), 42 | "huber": torch.nn.SmoothL1Loss(), 43 | } 44 | 45 | REDUCTIONS = { 46 | "mean": lambda x: torch.mean(x, axis=-1), 47 | "max": lambda x: torch.max(x, axis=-1)[0], 48 | "min": lambda x: torch.min(x, axis=-1)[0], 49 | "sum": lambda x: torch.sum(x, axis=-1), 50 | } 51 | 52 | 53 | NORMTYPE = { 54 | "batchnorm": nn.BatchNorm1d, 55 | "instancenorm": nn.InstanceNorm1d, 56 | "none": NoNorm, 57 | } 58 | 59 | NONLINEARITIES = { 60 | "tanh": nn.Tanh(), 61 | "relu": nn.ReLU(), 62 | "softplus": nn.Softplus(), 63 | "elu": nn.ELU(), 64 | "swish": Swish(), 65 | "square": Lambda(lambda x: x ** 2), 66 | "identity": Lambda(lambda x: x), 67 | "leakyrelu": nn.LeakyReLU(), 68 | "tanh10x": Lambda(lambda x: torch.tanh(10 * x)), 69 | } 70 | -------------------------------------------------------------------------------- /shapeflow/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxjiang93/ShapeFlow/351ee2fe2508eb052b658dddf136e5db0e49c7cb/shapeflow/utils/__init__.py -------------------------------------------------------------------------------- /shapeflow/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | """Utility tools for training the model. 
2 | """ 3 | import logging 4 | import os 5 | import shutil 6 | import torch 7 | import numpy as np 8 | from matplotlib import cm, colors 9 | 10 | # pylint: disable=too-many-arguments 11 | 12 | 13 | def save_checkpoint(state, is_best, epoch, output_folder, filename, logger): 14 | """Save checkpoint. 15 | Args: 16 | state: dict, containing state of the model to save. 17 | is_best: bool, indicate whether this is the best model so far. 18 | epoch: int, epoch number. 19 | output_folder: str, path to output folder. 20 | filename: str, the name to save the model as. 21 | logger: logger object to log progress. 22 | """ 23 | if epoch > 1: 24 | prev_ckpt = ( 25 | output_folder + filename + "_%03d" % (epoch - 1) + ".pth.tar" 26 | ) 27 | if os.path.exists(prev_ckpt): 28 | os.remove(prev_ckpt) 29 | torch.save(state, output_folder + filename + "_%03d" % epoch + ".pth.tar") 30 | print(output_folder + filename + "_%03d" % epoch + ".pth.tar") 31 | if is_best: 32 | if logger is not None: 33 | logger.info("Saving new best model") 34 | 35 | shutil.copyfile( 36 | output_folder + filename + "_%03d" % epoch + ".pth.tar", 37 | output_folder + filename + "_best.pth.tar", 38 | ) 39 | 40 | 41 | def snapshot_files(list_of_filenames, log_dir): 42 | """Snapshot list of files in current run state to the log directory. 43 | Args: 44 | list_of_filenames: list of str. 45 | log_dir: str, log directory to save code snapshots. 46 | """ 47 | snap_dir = os.path.join(log_dir, "snapshots") 48 | os.makedirs(snap_dir, exist_ok=True) 49 | for filename in list_of_filenames: 50 | if filename == os.path.basename(filename): 51 | shutil.copy2(filename, os.path.join(snap_dir, filename)) 52 | else: 53 | subdir = os.path.dirname(filename) 54 | os.makedirs(subdir, exist_ok=True) 55 | shutil.copy2(filename, os.path.join(snap_dir, filename)) 56 | 57 | 58 | def get_logger( 59 | log_dir, name="train", level=logging.DEBUG, log_file_name="log.txt" 60 | ): 61 | """Get a logger that writes a log file in log_dir. 
62 | Args: 63 | log_dir: str, log directory to save logs. 64 | name: str, name of the logger instance. 65 | level: logging level. 66 | log_file_name: str, name of the log file to output. 67 | Returns: 68 | a logger instance 69 | """ 70 | logger = logging.getLogger(name) 71 | logger.setLevel(level) 72 | logger.handlers = [] 73 | stream_handler = logging.StreamHandler() 74 | logger.addHandler(stream_handler) 75 | file_handler = logging.FileHandler( 76 | os.path.join(log_dir, os.path.basename(log_file_name)) 77 | ) 78 | logger.addHandler(file_handler) 79 | return logger 80 | 81 | 82 | def colorize_scalar_tensors( 83 | x, vmin=None, vmax=None, cmap="viridis", out_channel="rgb" 84 | ): 85 | """Colorize scalar field tensors. 86 | Args: 87 | x: torch tensor of shape [H, W]. 88 | vmin: float, min value to normalize the colors to. 89 | vmax: float, max value to normalize the colors to. 90 | cmap: str or Colormap instance, the colormap used to map normalized 91 | data values to RGBA colors. 92 | out_channel: str, 'rgb' or 'rgba'. 93 | Returns: 94 | y: torch tensor of shape [H, W, 3(or 4 if out_channel=='rgbd')], 95 | mapped colors. 96 | """ 97 | if vmin or vmax: 98 | normalizer = colors.Normalize(vmin, vmax) 99 | else: 100 | normalizer = None 101 | assert out_channel in ["rgb", "rgba"] 102 | 103 | mapper = cm.ScalarMappable(norm=normalizer, cmap=cmap) 104 | x_ = x.detach().cpu().numpy() 105 | 106 | y_ = mapper.to_rgba(x_)[..., : len(out_channel)].astype(x_.dtype) 107 | y = torch.tensor(y_, device=x.device) 108 | 109 | return y 110 | 111 | 112 | def batch_colorize_scalar_tensors( 113 | x, vmin=None, vmax=None, cmap="viridis", out_channel="rgb" 114 | ): 115 | """Colorize scalar field tensors. 116 | Args: 117 | x: torch tensor of shape [N, H, W]. 118 | vmin: float, or array of length N. min value to normalize the colors 119 | to. 120 | vmax: float, or array of length N. max value to normalize the colors 121 | to. 
122 | cmap: str or Colormap instance, the colormap used to map normalized 123 | data values to RGBA 124 | colors. 125 | out_channel: str, 'rgb' or 'rgba'. 126 | Returns: 127 | y: torch tensor of shape [N, H, W, 3(or 4 if out_channel=='rgbd')] 128 | """ 129 | 130 | def broadcast_limits(v): 131 | if v: 132 | if not isinstance(v, np.array): 133 | v = np.array(v) 134 | v = np.broadcast_to(v, x.shape[0]) 135 | return v 136 | 137 | vmin = broadcast_limits(vmin) 138 | vmax = broadcast_limits(vmax) 139 | y = torch.zeros(list(x.shape) + [len(out_channel)], device=x.device) 140 | for idx in range(x.shape[0]): 141 | y[idx] = colorize_scalar_tensors(x[idx]) 142 | 143 | return y 144 | 145 | 146 | def symmetric_duplication(points, symm_dim=2): 147 | """Symmetric duplication of points. 148 | 149 | Args: 150 | points: tensor of shape [batch, npoints, 3] 151 | symm_dim: int, direction of symmetry. 152 | Returns: 153 | duplicated points, tensor of shape [batch, 2*npoints, 3] 154 | """ 155 | points_dup = points.clone() 156 | points_dup[..., symm_dim] = -points_dup[..., symm_dim] 157 | points_new = torch.cat([points, points_dup], dim=1) 158 | 159 | return points_new 160 | -------------------------------------------------------------------------------- /shapenet_dataloader.py: -------------------------------------------------------------------------------- 1 | """ShapeNet deformation dataloader""" 2 | import os 3 | import torch 4 | from torch.utils.data import Dataset, Sampler 5 | import numpy as np 6 | import trimesh 7 | import glob 8 | import imageio 9 | import pickle 10 | from collections import OrderedDict 11 | from scipy.spatial import cKDTree 12 | 13 | 14 | synset_to_cat = { 15 | "02691156": "airplane", 16 | "02933112": "cabinet", 17 | "03001627": "chair", 18 | "03636649": "lamp", 19 | "04090263": "rifle", 20 | "04379243": "table", 21 | "04530566": "watercraft", 22 | "02828884": "bench", 23 | "02958343": "car", 24 | "03211117": "display", 25 | "03691459": "speaker", 26 | 
"04256520": "sofa", 27 | "04401088": "telephone", 28 | } 29 | 30 | cat_to_synset = {value: key for key, value in synset_to_cat.items()} 31 | 32 | SPLITS = ["train", "test", "val", "*"] 33 | 34 | 35 | def strip_name(filename): 36 | if len(filename.split("/")) > 3: 37 | return "/".join(filename.split("/")[-4:-1]) 38 | else: 39 | return filename 40 | 41 | 42 | class ShapeNetBase(Dataset): 43 | """Pytorch Dataset base for loading ShapeNet shape pairs.""" 44 | 45 | def __init__(self, data_root, split, category="chair"): 46 | """ 47 | Initialize DataSet 48 | Args: 49 | data_root: str, path to data root that contains the ShapeNet dataset. 50 | split: str, one of 'train'/'val'/'test'/'*'. '*' for all splits. 51 | catetory: 52 | str, name of the category to train on. 'all' for all 13 classes. 53 | Otherwise can be a comma separated string containing multiple names 54 | """ 55 | self.data_root = data_root 56 | self.split = split 57 | 58 | if not (split in SPLITS): 59 | raise ValueError(f"{split} must be one of {SPLITS}") 60 | self.categories = [c.strip() for c in category.split(",")] 61 | cats = list(cat_to_synset.keys()) 62 | if "all" in self.categories: 63 | self.categories = cats 64 | for c in self.categories: 65 | if c not in cats: 66 | raise ValueError( 67 | f"{c} is not in the list of the 13 categories: {cats}" 68 | ) 69 | self.files = self._get_filenames( 70 | self.data_root, self.split, self.categories 71 | ) 72 | self._file_splits = None 73 | 74 | self.thumbnails_dir = None 75 | self.thumbnails = False 76 | self._fname_to_idx_dict = None 77 | 78 | @property 79 | def file_splits(self): 80 | if self._file_splits is None: 81 | self._file_splits = {"train": [], "test": [], "val": []} 82 | for f in self.files: 83 | if "train/" in f: 84 | self._file_splits["train"].append(f) 85 | elif "test/" in f: 86 | self._file_splits["test"].append(f) 87 | else: # val/ 88 | self._file_splits["val"].append(f) 89 | 90 | return self._file_splits 91 | 92 | @staticmethod 93 | def 
_get_filenames(data_root, split, categories): 94 | files = [] 95 | for c in categories: 96 | synset_id = cat_to_synset[c] 97 | if split != "*": 98 | cat_folder = os.path.join(data_root, split, synset_id) 99 | if not os.path.exists(cat_folder): 100 | raise RuntimeError( 101 | f"Datafolder for {synset_id} ({c}) " 102 | f"does not exist at {cat_folder}." 103 | ) 104 | files += glob.glob( 105 | os.path.join(cat_folder, "*/*.ply"), recursive=True 106 | ) 107 | else: 108 | for split in SPLITS[:3]: 109 | cat_folder = os.path.join(data_root, split, synset_id) 110 | if not os.path.exists(cat_folder): 111 | raise RuntimeError( 112 | f"Datafolder for {synset_id} ({c}) does not exist " 113 | f"at {cat_folder}." 114 | ) 115 | files += glob.glob( 116 | os.path.join(cat_folder, "*/*.ply"), recursive=True 117 | ) 118 | return sorted(files) 119 | 120 | def __len__(self): 121 | return self.n_shapes ** 2 122 | 123 | @property 124 | def n_shapes(self): 125 | return len(self.files) 126 | 127 | def restrict_subset(self, indices): 128 | """Restrict data to the subset of data as indicated by the indices. 129 | 130 | Mostly helpful for debugging only. 
131 | 132 | Args: 133 | indices: list or array of ints, to index the original self.files 134 | """ 135 | self.files = [self.files[i] for i in indices] 136 | 137 | @property 138 | def fname_to_idx_dict(self): 139 | """A dict mapping unique mesh names to indicies.""" 140 | if self._fname_to_idx_dict is None: 141 | fnames = ["/".join(f.split("/")[-4:-1]) for f in self.files] 142 | self._fname_to_idx_dict = dict( 143 | zip(fnames, list(range(len(fnames)))) 144 | ) 145 | return self._fname_to_idx_dict 146 | 147 | def idx_to_combinations(self, idx): 148 | """Convert s linear index to a pair of indices.""" 149 | i = np.floor(idx / self.n_shapes) 150 | j = idx - i * self.n_shapes 151 | if hasattr(idx, "__len__"): 152 | i = np.array(i, dtype=int) 153 | j = np.array(j, dtype=int) 154 | else: 155 | i = int(i) 156 | j = int(j) 157 | return i, j 158 | 159 | def combinations_to_idx(self, i, j): 160 | """Convert a pair of indices to a linear index.""" 161 | idx = i * self.n_shapes + j 162 | if hasattr(idx, "__len__"): 163 | idx = np.array(idx, dtype=int) 164 | else: 165 | idx = int(idx) 166 | return idx 167 | 168 | 169 | class ShapeNetVertex(ShapeNetBase): 170 | """Pytorch Dataset for sampling vertices from meshes.""" 171 | 172 | def __init__( 173 | self, data_root, split, category="chair", nsamples=5000, normals=True 174 | ): 175 | """ 176 | Initialize DataSet 177 | Args: 178 | data_root: str, path to data root that contains the ShapeNet dataset. 179 | split: str, one of 'train'/'val'/'test'. 180 | catetory: str, name of the category to train on. 'all' for all 13 181 | classes. Otherwise can be a comma separated string containing 182 | multiple names. 183 | nsamples: int, number of points to sample from each mesh. 184 | normals: bool, whether to add normals to the point features. 
185 | """ 186 | super(ShapeNetVertex, self).__init__( 187 | data_root=data_root, split=split, category=category 188 | ) 189 | self.nsamples = nsamples 190 | self.normals = normals 191 | 192 | @staticmethod 193 | def sample_mesh(mesh_path, nsamples, normals=True): 194 | """Load the mesh from mesh_path and sample nsampels points from its vertices. 195 | 196 | If nsamples < number of vertices on mesh, randomly repeat some 197 | vertices as padding. 198 | 199 | Args: 200 | mesh_path: str, path to load the mesh from. 201 | nsamples: int, number of vertices to sample. 202 | normals: bool, whether to add normals to the point features. 203 | Returns: 204 | v_sample: np array of shape [nsamples, 3 or 6] for sampled points. 205 | """ 206 | mesh = trimesh.load(mesh_path) 207 | v = np.array(mesh.vertices) 208 | nv = v.shape[0] 209 | seq = np.random.permutation(nv)[:nsamples] 210 | if len(seq) < nsamples: 211 | seq_repeat = np.random.choice( 212 | nv, nsamples - len(seq), replace=True 213 | ) 214 | seq = np.concatenate([seq, seq_repeat], axis=0) 215 | v_sample = v[seq] 216 | if normals: 217 | n_sample = np.array(mesh.vertex_normals[seq]) 218 | v_sample = np.concatenate([v_sample, n_sample], axis=-1) 219 | 220 | return v_sample 221 | 222 | def add_thumbnails(self, thumbnails_root): 223 | self.thumbnails = True 224 | self.thumbnails_dir = thumbnails_root 225 | 226 | def _get_one_mesh(self, idx): 227 | verts = self.sample_mesh(self.files[idx], self.nsamples, self.normals) 228 | verts = verts.astype(np.float32) 229 | if self.thumbnails: 230 | thumb_dir = self.files[idx].replace( 231 | self.data_root, self.thumbnails_dir 232 | ) 233 | thumb_dir = os.path.dirname(thumb_dir) 234 | thumb_file = os.path.join(thumb_dir, "thumbnail.jpg") 235 | thumb = np.array(imageio.imread(thumb_file)) 236 | return verts, thumb 237 | else: 238 | return verts 239 | 240 | def __getitem__(self, idx): 241 | """Get a random pair of shapes corresponding to idx. 
242 | Args: 243 | idx: int, index of the shape pair to return. must be smaller than 244 | len(self). 245 | Returns: 246 | verts_i: [npoints, 3 or 6] float tensor for point samples from the 247 | first mesh. 248 | verts_j: [npoints, 3 or 6] float tensor for point samples from the 249 | second mesh. 250 | thumb_i: (optional) [H, W, 3] int8 tensor for thumbnail image for 251 | the first mesh. 252 | thumb_j: (optional) [H, W, 3] int8 tensor for thumbnail image for 253 | the second mesh. 254 | """ 255 | i, j = self.idx_to_combinations(idx) 256 | if self.thumbnails: 257 | verts_i, thumb_i = self._get_one_mesh(i) 258 | verts_j, thumb_j = self._get_one_mesh(j) 259 | return i, j, verts_i, verts_j, thumb_i, thumb_j 260 | else: 261 | verts_i = self._get_one_mesh(i) 262 | verts_j = self._get_one_mesh(j) 263 | return i, j, verts_i, verts_j 264 | 265 | 266 | class ShapeNetMesh(ShapeNetBase): 267 | """Pytorch Dataset for sampling entire meshes.""" 268 | 269 | def __init__(self, data_root, split, category="chair", normals=True): 270 | """ 271 | Initialize DataSet 272 | Args: 273 | data_root: str, path to data root that contains the ShapeNet dataset. 274 | split: str, one of 'train'/'val'/'test'. 275 | catetory: 276 | str, name of the category to train on. 'all' for all 13 classes. 277 | Otherwise can be a comma separated string containing multiple 278 | names. 
279 | """ 280 | super(ShapeNetMesh, self).__init__( 281 | data_root=data_root, split=split, category=category 282 | ) 283 | self.normals = normals 284 | 285 | def get_pairs(self, i, j): 286 | verts_i, faces_i = self.get_single(i) 287 | verts_j, faces_j = self.get_single(j) 288 | 289 | return i, j, verts_i, faces_i, verts_j, faces_j 290 | 291 | def get_single(self, i): 292 | mesh_i = trimesh.load(self.files[i]) 293 | 294 | verts_i = mesh_i.vertices.astype(np.float32) 295 | faces_i = mesh_i.faces.astype(np.int32) 296 | 297 | if self.normals: 298 | norms_i = mesh_i.vertex_normals.astype(np.float32) 299 | verts_i = np.concatenate([verts_i, norms_i], axis=-1) 300 | 301 | verts_i = torch.from_numpy(verts_i) 302 | faces_i = torch.from_numpy(faces_i) 303 | 304 | return verts_i, faces_i 305 | 306 | def __getitem__(self, idx): 307 | """Get a random pair of meshes. 308 | Args: 309 | idx: int, index of the shape pair to return. must be smaller than 310 | len(self). 311 | Returns: 312 | verts_i: [#vi, 3 or 6] float tensor for vertices from the first mesh. 313 | faces_i: [#fi, 3 or 6] int32 tensor for faces from the first mesh. 314 | verts_j: [#vj, 3 or 6] float tensor for vertices from the 2nd mesh. 315 | faces_j: [#fj, 3 or 6] int32 tensor for faces from the 2nd mesh. 
316 | """ 317 | i, j = self.idx_to_combinations(idx) 318 | return self.get_pairs(i, j) 319 | 320 | 321 | class FixedPointsCachedDataset(Dataset): 322 | """Dataset for loading fixed points dataset from cached pickle file.""" 323 | 324 | def __init__(self, pkl_file, npts=1024): 325 | with open(pkl_file, "rb") as fh: 326 | self.data_dict = pickle.load(fh) 327 | self.data_dict = OrderedDict(sorted(self.data_dict.items())) 328 | self.key_list = list(self.data_dict.keys()) 329 | assert npts <= 4096 and npts > 0 330 | self.npts = npts 331 | 332 | def __getitem__(self, idx): 333 | filename = self.key_list[idx] 334 | points = self.data_dict[filename] 335 | rand_seq = np.random.choice(points.shape[0], self.npts, replace=False) 336 | points_ = points[rand_seq] 337 | return filename, idx, points_ 338 | 339 | def __len__(self): 340 | return len(self.data_dict) 341 | 342 | 343 | class PairSamplerBase(Sampler): 344 | """Data sampler base for sampling pairs.""" 345 | 346 | def __init__( 347 | self, dataset, src_split, tar_split, n_samples, replace=False 348 | ): 349 | assert src_split in SPLITS[:3] 350 | assert tar_split in SPLITS[:3] 351 | self.replace = replace 352 | self.n_samples = n_samples 353 | self.src_split = src_split 354 | self.tar_split = tar_split 355 | self.dataset = dataset 356 | self.src_files = self.dataset.file_splits[src_split] 357 | self.tar_files = self.dataset.file_splits[tar_split] 358 | self.src_files = [strip_name(f) for f in self.src_files] 359 | self.tar_files = [strip_name(f) for f in self.tar_files] 360 | self.n_src = len(self.src_files) 361 | self.n_tar = len(self.tar_files) 362 | if not replace: 363 | if not self.n_samples <= self.n_src: 364 | raise RuntimeError( 365 | f"Numer of samples ({len(self.n_samples)}) must be " 366 | f"less than number source shapes ({len(self.n_src)})" 367 | ) 368 | 369 | def __iter__(self): 370 | raise NotImplementedError 371 | 372 | def __len__(self): 373 | raise NotImplementedError 374 | 375 | 376 | class 
RandomPairSampler(PairSamplerBase): 377 | """Data sampler for sampling random pairs.""" 378 | 379 | def __init__( 380 | self, dataset, src_split, tar_split, n_samples, replace=False 381 | ): 382 | super(RandomPairSampler, self).__init__( 383 | dataset, src_split, tar_split, n_samples, replace 384 | ) 385 | 386 | def __iter__(self): 387 | d = self.dataset 388 | if self.replace: 389 | src_names = np.random.choice( 390 | self.src_files, self.n_samples, replace=True 391 | ) 392 | tar_names = np.random.choice( 393 | self.tar_files, self.n_samples, replace=True 394 | ) 395 | else: 396 | src_names = np.random.permutation(self.src_files)[ 397 | : int(self.n_samples) 398 | ] 399 | tar_names = np.random.permutation(self.tar_files)[ 400 | : int(self.n_samples) 401 | ] 402 | src_idxs = np.array( 403 | [d.fname_to_idx_dict[strip_name(f)] for f in src_names] 404 | ) 405 | tar_idxs = np.array( 406 | [d.fname_to_idx_dict[strip_name(f)] for f in tar_names] 407 | ) 408 | combo_ids = self.dataset.combinations_to_idx(src_idxs, tar_idxs) 409 | 410 | return iter(combo_ids) 411 | 412 | def __len__(self): 413 | return self.n_samples 414 | 415 | 416 | class LatentNearestNeighborSampler(PairSamplerBase): 417 | """Data sampler for sampling pairs from top-k nearest latent neighbors.""" 418 | 419 | def __init__( 420 | self, dataset, src_split, tar_split, n_samples, k, replace=False 421 | ): 422 | """Initialize. 423 | 424 | Args: 425 | k: int, top-k neighbors to sample from. 426 | replace: bool, sample with replacement. 427 | if no replace, then must ensure n_samples <= n_shapes 428 | """ 429 | super(LatentNearestNeighborSampler, self).__init__( 430 | dataset, src_split, tar_split, n_samples, replace 431 | ) 432 | self.k = k 433 | self.graph_set = False 434 | 435 | def update_nn_graph(self, src_latent_dict, tar_latent_dict, k=None): 436 | """Update nearest neighbor graph. 437 | 438 | Args: 439 | src_latent_dict: a dict that maps filenames to latent codes for 440 | source set. 
441 | tar_latent_dict: a dict that maps filenames to latent codes for 442 | target set. 443 | """ 444 | if k is not None: 445 | self.k = k 446 | tar_names = list(tar_latent_dict.keys()) 447 | tar_latents = list(tar_latent_dict.values()) 448 | tar_latents = np.stack(tar_latents, axis=0) # [n, lat_dim] 449 | # build kd-tree to accelerate nearest neighbor computation 450 | k = self.k + 1 if self.src_split == self.tar_split else self.k 451 | self._kdtree = cKDTree(tar_latents) 452 | 453 | src_names = list(src_latent_dict.keys()) 454 | src_latents = list(src_latent_dict.values()) 455 | src_latents = np.stack(src_latents, axis=0) # [m, lat_dim] 456 | _, nn_idx = self._kdtree.query(src_latents, k=k) # [m, k] 457 | if nn_idx.ndim == 1: 458 | nn_idx = nn_idx[:, None] 459 | nn_idx = nn_idx[:, -self.k:] 460 | 461 | nn_names = [] 462 | for i in range(nn_idx.shape[0]): 463 | nn_names.append([tar_names[j] for j in nn_idx[i]]) 464 | self._nn_map = dict(zip(src_names, nn_names)) 465 | self.graph_set = True 466 | 467 | @property 468 | def kdtree(self): 469 | return self._kdtree 470 | 471 | @property 472 | def nn_map(self): 473 | return self._nn_map 474 | 475 | def __iter__(self): 476 | if not self.graph_set: 477 | raise RuntimeError( 478 | "Nearest neighbor graph not yet set." 479 | " Run '.update_nn_graph()' to update first." 
480 | ) 481 | d = self.dataset 482 | # return generator 483 | if self.replace: 484 | src_names = np.random.choice( 485 | self.src_files, self.n_samples, replace=True 486 | ) 487 | else: 488 | src_names = np.random.permutation(self.src_files)[ 489 | : int(self.n_samples) 490 | ] 491 | 492 | for src_name in src_names: 493 | tar_name = np.random.choice(self.nn_map[src_name], 1)[0] 494 | src_idx = d.fname_to_idx_dict[strip_name(src_name)] 495 | tar_idx = d.fname_to_idx_dict[strip_name(tar_name)] 496 | combo_id = d.combinations_to_idx(src_idx, tar_idx) 497 | yield combo_id 498 | 499 | def __len__(self): 500 | return self.n_samples 501 | -------------------------------------------------------------------------------- /shapenet_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import cKDTree 3 | import time 4 | import torch 5 | from torch.utils.data import SubsetRandomSampler, DataLoader 6 | 7 | from shapeflow.layers.chamfer_layer import ChamferDistKDTree 8 | from shapeflow.layers.shared_definition import LOSSES, OPTIMIZERS 9 | import shapeflow.utils.train_utils as utils 10 | 11 | 12 | class LatentEmbedder(object): 13 | """Helper class for embedding new observation in deformation latent space. 14 | """ 15 | 16 | def __init__(self, point_dataset, mesh_dataset, deformer, topk=5): 17 | """Initialize embedder. 
18 | 19 | Args: 20 | point_dataset: instance of FixedPointsCachedDataset 21 | mesh_dataset: instance of ShapeNetMesh 22 | deformer: pretrined deformer instance 23 | """ 24 | self.point_dataset = point_dataset 25 | self.mesh_dataset = mesh_dataset 26 | self.deformer = deformer 27 | self.topk = topk 28 | self.tree = cKDTree(self.lat_params.clone().detach().cpu().numpy()) 29 | 30 | @property 31 | def lat_dims(self): 32 | return self.lat_params.shape[1] 33 | 34 | @property 35 | def lat_params(self): 36 | return self.deformer.net.lat_params 37 | 38 | @property 39 | def symm(self): 40 | return self.deformer.symm_dim is not None 41 | 42 | @property 43 | def device(self): 44 | return self.lat_params.device 45 | 46 | def _padded_verts_from_meshes(self, meshes): 47 | verts = [vf[0] for vf in meshes] 48 | faces = [vf[1] for vf in meshes] 49 | nv = [v.shape[0] for v in verts] 50 | max_nv = np.max(nv) 51 | verts_pad = [ 52 | np.pad(verts[i], ((0, max_nv - nv[i]), (0, 0))) 53 | for i in range(len(nv)) 54 | ] 55 | verts_pad = np.stack(verts_pad, 0) # [nmesh, max_nv, 3] 56 | return verts_pad, faces, nv 57 | 58 | def _meshes_from_padded_verts(self, verts_pad, faces, nv): 59 | verts_pad = [v for v in verts_pad] 60 | verts = [v[:n] for v, n in zip(verts_pad, nv)] 61 | meshes = list(zip(verts, faces)) 62 | return meshes 63 | 64 | def embed( 65 | self, 66 | input_points, 67 | optimizer="adam", 68 | lr=1e-3, 69 | seed=0, 70 | embedding_niter=30, 71 | finetune_niter=30, 72 | bs=32, 73 | verbose=False, 74 | matching="two_way", 75 | loss_type="l1", 76 | ): 77 | """Embed inputs points observations into deformation latent space. 78 | 79 | Args: 80 | input_points: tensor of shape [bs_tar, npoints, 3] 81 | optimizer: str, optimizer choice. one of sgd, adam, adadelta, 82 | adagrad, rmsprop. 83 | lr: float, learning rate. 84 | seed: int, random seed. 85 | embedding_niter: int, number of embedding optimization iterations. 86 | finetune_niter: int, number of finetuning optimization iterations. 
87 | bs: int, batch size. 88 | verbose: bool, turn on verbose. 89 | matching: str, matching function. choice of one_way or two_way. 90 | loss_type: str, loss type. choice of l1, l2, huber. 91 | 92 | Returns: 93 | embedded_latents: tensor of shape [batch, lat_dims] 94 | """ 95 | if input_points.shape[0] != 1: 96 | raise NotImplementedError("Code is not ready for batch size > 1.") 97 | torch.manual_seed(seed) 98 | 99 | # Check input validity. 100 | if matching not in ["one_way", "two_way"]: 101 | raise ValueError( 102 | f"matching method must be one of one_way / two_way. Instead " 103 | f"entered {matching}" 104 | ) 105 | if loss_type not in LOSSES.keys(): 106 | raise ValueError( 107 | f"loss_type must be one of {LOSSES.keys()}. " 108 | f"Instead entered {loss_type}" 109 | ) 110 | 111 | criterion = LOSSES[loss_type] 112 | 113 | bs_tar, npts_tar, _ = input_points.shape 114 | # Assign random latent code close to zero. 115 | embedded_latents = torch.nn.Parameter( 116 | torch.randn(bs_tar, self.lat_dims, device=self.device) * 1e-4, 117 | requires_grad=True, 118 | ) 119 | self.deformer.net.tar_latents = embedded_latents 120 | embedded_latents = self.deformer.net.tar_latents 121 | # [bs_tar, lat_dims] 122 | 123 | # Init optimizer. 124 | if optimizer not in OPTIMIZERS.keys(): 125 | raise ValueError(f"optimizer must be one of {OPTIMIZERS.keys()}") 126 | optim = OPTIMIZERS[optimizer]([embedded_latents], lr=lr) 127 | 128 | # Init dataloader. 129 | sampler = SubsetRandomSampler( 130 | np.arange(len(self.point_dataset)).tolist() 131 | ) 132 | point_loader = DataLoader( 133 | self.point_dataset, 134 | batch_size=bs, 135 | sampler=sampler, 136 | shuffle=False, 137 | drop_last=True, 138 | ) 139 | 140 | # Chamfer distance calc. 141 | chamfer_dist = ChamferDistKDTree(reduction="mean", njobs=1) 142 | chamfer_dist.to(self.device) 143 | 144 | def optimize_latent(point_loader, optim, niter): 145 | # Optimize for latents. 
146 | self.deformer.train() 147 | toc = time.time() 148 | 149 | bs_src = point_loader.batch_size 150 | embedded_latents_ = embedded_latents[None].expand( 151 | bs_src, bs_tar, self.lat_dims 152 | ) 153 | # [bs_src, bs_tar, lat_dims] 154 | 155 | # Broadcast and reshape input points. 156 | target_points_ = ( 157 | input_points[None] 158 | .expand(bs_src, bs_tar, npts_tar, 3) 159 | .view(-1, npts_tar, 3) 160 | ) 161 | it = 0 162 | 163 | for batch_idx, (fnames, idxs, source_points) in enumerate( 164 | point_loader 165 | ): 166 | tic = time.time() 167 | # Send tensors to device. 168 | source_points = source_points.to( 169 | self.device 170 | ) # [bs_src, npts_src, 3] 171 | idxs = idxs.to(self.device) 172 | 173 | optim.zero_grad() 174 | 175 | # Deform chosen points to input_points. 176 | # Broadcast src lats to src x tar. 177 | source_latents = self.lat_params[idxs] # [bs_src, lat_dims] 178 | source_latents_ = source_latents[:, None].expand( 179 | bs_src, bs_tar, self.lat_dims 180 | ) 181 | source_latents_ = source_latents_.view(-1, self.lat_dims) 182 | target_latents_ = embedded_latents_.view(-1, self.lat_dims) 183 | zeros = torch.zeros_like(source_latents_) 184 | source_target_latents = torch.stack( 185 | [source_latents_, zeros, target_latents_], dim=1 186 | ) 187 | 188 | deformed_pts = self.deformer( 189 | source_points, 190 | source_target_latents, # [bs_sr*bs_tar, npts_src, 3] 191 | ) # [bs_sr*bs_tar, npts_src, 3] 192 | 193 | # Symmetric pair of matching losses. 
194 | if self.symm: 195 | accu, comp, cham = chamfer_dist( 196 | utils.symmetric_duplication(deformed_pts, symm_dim=2), 197 | utils.symmetric_duplication( 198 | target_points_, symm_dim=2 199 | ), 200 | ) 201 | else: 202 | accu, comp, cham = chamfer_dist( 203 | deformed_pts, target_points_ 204 | ) 205 | 206 | if matching == "one_way": 207 | comp = torch.mean(comp, dim=1) 208 | loss = criterion(comp, torch.zeros_like(comp)) 209 | else: 210 | loss = criterion(cham, torch.zeros_like(cham)) 211 | 212 | # Check amount of deformation. 213 | deform_abs = torch.mean( 214 | torch.norm(deformed_pts - source_points, dim=-1) 215 | ) 216 | 217 | loss.backward() 218 | 219 | # Gradient clipping. 220 | torch.nn.utils.clip_grad_value_(embedded_latents, 1.0) 221 | 222 | optim.step() 223 | 224 | toc = time.time() 225 | if verbose: 226 | if loss_type == "l1": 227 | dist = loss.item() 228 | else: 229 | dist = np.sqrt(loss.item()) 230 | print( 231 | f"Iter: {it}, Loss: {loss.item():.4f}, " 232 | f"Dist: {dist:.4f}, " 233 | f"Deformation Magnitude: {deform_abs.item():.4f}, " 234 | f"Time per iter (s): {toc-tic:.4f}" 235 | ) 236 | it += 1 237 | if batch_idx >= niter: 238 | break 239 | 240 | # Optimize to range. 241 | optimize_latent(point_loader, optim, embedding_niter) 242 | latents_pre_tune = embedded_latents.detach().cpu().numpy() 243 | 244 | # Finetune topk. 245 | dist, idxs = self.tree.query( 246 | embedded_latents.detach().cpu().numpy(), k=self.topk 247 | ) # [batch, k] 248 | bs, k = idxs.shape 249 | idxs_ = idxs.reshape(-1) 250 | 251 | # Change lr. 
252 | for param_group in optim.param_groups: 253 | param_group["lr"] = 1e-3 254 | 255 | sampler = SubsetRandomSampler(idxs_.tolist() * finetune_niter) 256 | point_loader = DataLoader( 257 | self.point_dataset, 258 | batch_size=self.topk, 259 | sampler=sampler, 260 | shuffle=False, 261 | drop_last=True, 262 | ) 263 | 264 | print(f"Finetuning for {finetune_niter} iters...") 265 | optim = OPTIMIZERS[optimizer]( 266 | [embedded_latents] + list(self.deformer.parameters()), lr=1e-3 267 | ) 268 | optimize_latent(point_loader, optim, finetune_niter) 269 | latents_post_tune = embedded_latents.detach().cpu().numpy() 270 | 271 | return latents_pre_tune, latents_post_tune 272 | 273 | def retrieve(self, lat_codes, tar_pts, matching="one_way"): 274 | """Retrieve top 10 nearest neighbors, deform and pick the best one. 275 | 276 | Args: 277 | lat_codes: tensor of shape [batch, lat_dims], latent code targets. 278 | 279 | Returns: 280 | List of len batch of (V, F) tuples. 281 | """ 282 | if lat_codes.shape[0] != 1: 283 | raise NotImplementedError("Code is not ready for batch size > 1.") 284 | dist, idxs = self.tree.query(lat_codes, k=self.topk) # [batch, k] 285 | bs, k = idxs.shape 286 | idxs_ = idxs.reshape(-1) 287 | 288 | if not isinstance(lat_codes, torch.Tensor): 289 | lat_codes = torch.tensor(lat_codes).float().to(self.device) 290 | 291 | src_latent = self.lat_params[idxs_] # [batch*k, lat_dims] 292 | tar_latent = ( 293 | lat_codes[:, None] 294 | .expand(bs, k, self.lat_dims) 295 | .reshape(-1, self.lat_dims) 296 | ) # [batch*k, lat_dims] 297 | zeros = torch.zeros_like(src_latent) 298 | src_tar_latent = torch.stack([src_latent, zeros, tar_latent], dim=1) 299 | 300 | # Retrieve meshes. 
301 | orig_meshes = [ 302 | self.mesh_dataset.get_single(i) for i in idxs_ 303 | ] # [(v1,f1), ..., (vn,fn)] 304 | src_verts, faces, nv = self._padded_verts_from_meshes(orig_meshes) 305 | src_verts = torch.tensor(src_verts).to(self.device) 306 | with torch.no_grad(): 307 | deformed_verts = self.deformer(src_verts, src_tar_latent) 308 | deformed_meshes = self._meshes_from_padded_verts( 309 | deformed_verts, faces, nv 310 | ) 311 | 312 | # Chamfer distance calc. 313 | chamfer_dist = ChamferDistKDTree(reduction="mean", njobs=1) 314 | chamfer_dist.to(self.device) 315 | dist = [] 316 | for i in range(len(deformed_meshes)): 317 | accu, comp, cham = chamfer_dist( 318 | deformed_meshes[i][0][None].to(self.device), 319 | torch.tensor(tar_pts)[None].to(self.device), 320 | ) 321 | if matching == "one_way": 322 | dist.append(torch.mean(comp, dim=1).item()) 323 | else: 324 | dist.append(cham.item()) 325 | 326 | # Reshape the list of (v, f) tuples. 327 | deformed_meshes = [ 328 | (vf[0].detach().cpu().numpy(), vf[1].detach().cpu().numpy()) 329 | for vf in deformed_meshes 330 | ] 331 | 332 | return deformed_meshes, orig_meshes, dist 333 | -------------------------------------------------------------------------------- /shapenet_generation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CONFIG 4 | category=03001627 5 | ckpt=runs/pretrained_chair_symm128/checkpoint_latest.pth.tar_shapeflow_100.pth.tar 6 | 7 | gpuid=$1 # gpu id 8 | sid=$2 # start id 9 | eid=$3 # end id 10 | 11 | export CUDA_VISIBLE_DEVICES=$gpuid 12 | 13 | allfiles=($(ls data/sparse_inputs/$category/*.ply)) 14 | 15 | for i in $(seq $sid $((eid-1))); do 16 | infile=${allfiles[$i]} 17 | python shapenet_reconstruct.py --input_path=$infile --output_dir=out --checkpoint=$ckpt --device=cuda 18 | done 19 | -------------------------------------------------------------------------------- /shapenet_reconstruct.py: 
-------------------------------------------------------------------------------- 1 | """Reconstruct shape from point cloud using learned deformation space. 2 | """ 3 | import os 4 | import sys 5 | import argparse 6 | import json 7 | import trimesh 8 | import torch 9 | import numpy as np 10 | import time 11 | from types import SimpleNamespace 12 | 13 | from shapenet_dataloader import ShapeNetMesh, FixedPointsCachedDataset 14 | from shapeflow.layers.deformation_layer import NeuralFlowDeformer 15 | from shapenet_embedding import LatentEmbedder 16 | 17 | 18 | synset_to_cat = { 19 | "02691156": "airplane", 20 | "02933112": "cabinet", 21 | "03001627": "chair", 22 | "03636649": "lamp", 23 | "04090263": "rifle", 24 | "04379243": "table", 25 | "04530566": "watercraft", 26 | "02828884": "bench", 27 | "02958343": "car", 28 | "03211117": "display", 29 | "03691459": "speaker", 30 | "04256520": "sofa", 31 | "04401088": "telephone", 32 | } 33 | 34 | cat_to_synset = {value: key for key, value in synset_to_cat.items()} 35 | 36 | 37 | def get_args(): 38 | """Parse command line arguments.""" 39 | parser = argparse.ArgumentParser( 40 | description="Generate reconstructions via retrieve and deform." 41 | ) 42 | 43 | parser.add_argument( 44 | "--input_path", 45 | type=str, 46 | required=True, 47 | help="path to input points (.ply file).", 48 | ) 49 | parser.add_argument( 50 | "--output_dir", type=str, required=True, help="path to output meshes." 
51 | ) 52 | parser.add_argument( 53 | "--topk", 54 | type=int, 55 | default=4, 56 | help="top k nearest neighbor to retrieve.", 57 | ) 58 | parser.add_argument( 59 | "-ne", 60 | "--embedding_niter", 61 | type=int, 62 | default=30, 63 | help="number of embedding iterations.", 64 | ) 65 | parser.add_argument( 66 | "-nf", 67 | "--finetune_niter", 68 | type=int, 69 | default=30, 70 | help="number of finetuning iterations.", 71 | ) 72 | parser.add_argument( 73 | "--checkpoint", 74 | type=str, 75 | required=True, 76 | help="path to pretrained checkpoint " 77 | "(params.json must be in the same directory).", 78 | ) 79 | parser.add_argument( 80 | "--device", 81 | type=str, 82 | default="cuda:0", 83 | help="device to run inference on.", 84 | ) 85 | args = parser.parse_args() 86 | return args 87 | 88 | 89 | def main(): 90 | t0 = time.time() 91 | args_eval = get_args() 92 | 93 | device = torch.device(args_eval.device) 94 | 95 | # load training args 96 | run_dir = os.path.dirname(args_eval.checkpoint) 97 | args = SimpleNamespace( 98 | **json.load(open(os.path.join(run_dir, "params.json"), "r")) 99 | ) 100 | 101 | # assert category is correct 102 | syn_id = args_eval.input_path.split("/")[-2] 103 | mesh_name = args_eval.input_path.split("/")[-1] 104 | assert syn_id == cat_to_synset[args.category] 105 | 106 | # output directories 107 | mesh_out_dir = os.path.join(args_eval.output_dir, "meshes", syn_id) 108 | mesh_out_file = os.path.join( 109 | mesh_out_dir, mesh_name.replace(".ply", ".off") 110 | ) 111 | meta_out_dir = os.path.join( 112 | args_eval.output_dir, "meta", syn_id, mesh_name.replace(".ply", "") 113 | ) 114 | orig_dir = os.path.join(meta_out_dir, "original_retrieved") 115 | deformed_dir = os.path.join(meta_out_dir, "deformed") 116 | os.makedirs(mesh_out_dir, exist_ok=True) 117 | os.makedirs(meta_out_dir, exist_ok=True) 118 | os.makedirs(orig_dir, exist_ok=True) 119 | os.makedirs(deformed_dir, exist_ok=True) 120 | 121 | # redirect logging 122 | sys.stdout = 
open(os.path.join(meta_out_dir, "log.txt"), "w") 123 | 124 | # initialize deformer 125 | # input points 126 | points = np.array(trimesh.load(args_eval.input_path).vertices) 127 | 128 | # dataloader 129 | data_root = args.data_root 130 | mesh_dataset = ShapeNetMesh( 131 | data_root=data_root, 132 | split="train", 133 | category=args.category, 134 | normals=False, 135 | ) 136 | point_dataset = FixedPointsCachedDataset( 137 | f"data/shapenet_pointcloud/train/{cat_to_synset[args.category]}.pkl", 138 | npts=300, 139 | ) 140 | 141 | # setup model 142 | deformer = NeuralFlowDeformer( 143 | latent_size=args.lat_dims, 144 | f_width=args.deformer_nf, 145 | s_nlayers=2, 146 | s_width=5, 147 | method=args.solver, 148 | nonlinearity=args.nonlin, 149 | arch="imnet", 150 | adjoint=args.adjoint, 151 | rtol=args.rtol, 152 | atol=args.atol, 153 | via_hub=True, 154 | no_sign_net=(not args.sign_net), 155 | symm_dim=(2 if args.symm else None), 156 | ) 157 | 158 | lat_params = torch.nn.Parameter( 159 | torch.randn(mesh_dataset.n_shapes, args.lat_dims) * 1e-1, 160 | requires_grad=True, 161 | ) 162 | deformer.add_lat_params(lat_params) 163 | deformer.to(device) 164 | 165 | # load checkpoint 166 | resume_dict = torch.load(args_eval.checkpoint) 167 | deformer.load_state_dict(resume_dict["deformer_state_dict"]) 168 | 169 | # embed 170 | embedder = LatentEmbedder(point_dataset, mesh_dataset, deformer, topk=5) 171 | input_pts = torch.tensor(points)[None].to(device) 172 | lat_codes_pre, lat_codes_post = embedder.embed( 173 | input_pts, 174 | matching="two_way", 175 | verbose=True, 176 | lr=1e-2, 177 | embedding_niter=args_eval.embedding_niter, 178 | finetune_niter=args_eval.finetune_niter, 179 | bs=4, 180 | seed=1, 181 | ) 182 | 183 | # retrieve deformed models 184 | deformed_meshes, orig_meshes, dist = embedder.retrieve( 185 | lat_codes_post, tar_pts=points, matching="two_way" 186 | ) 187 | asort = np.argsort(dist) 188 | dist = [dist[i] for i in asort] 189 | deformed_meshes = 
[deformed_meshes[i] for i in asort] 190 | orig_meshes = [orig_meshes[i] for i in asort] 191 | 192 | # output best mehs 193 | vb, fb = deformed_meshes[0] 194 | trimesh.Trimesh(vb, fb).export(mesh_out_file) 195 | 196 | # meta directory 197 | for i in range(len(deformed_meshes)): 198 | vo, fo = orig_meshes[i] 199 | vd, fd = deformed_meshes[i] 200 | trimesh.Trimesh(vo, fo).export(os.path.join(orig_dir, f"{i}.ply")) 201 | trimesh.Trimesh(vd, fd).export(os.path.join(deformed_dir, f"{i}.ply")) 202 | np.save(os.path.join(meta_out_dir, "latent.npy"), lat_codes_pre) 203 | t1 = time.time() 204 | print(f"Total Timelapse: {t1-t0:.4f}") 205 | 206 | 207 | if __name__ == "__main__": 208 | main() 209 | -------------------------------------------------------------------------------- /shapenet_train.py: -------------------------------------------------------------------------------- 1 | """Training script shapenet deformation space experiment. 2 | """ 3 | import argparse 4 | import json 5 | import os 6 | import glob 7 | import numpy as np 8 | import time 9 | import trimesh 10 | 11 | import shapeflow.utils.train_utils as utils 12 | from shapeflow.layers.chamfer_layer import ChamferDistKDTree 13 | from shapeflow.layers.deformation_layer import NeuralFlowDeformer 14 | import shapenet_dataloader as dl 15 | 16 | import torch 17 | import torch.optim as optim 18 | import torch.nn as nn 19 | from torch.utils.data import DataLoader 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | np.set_printoptions(precision=4) 23 | 24 | 25 | # Various choices for losses and optimizers. 
26 | LOSSES = { 27 | "l1": torch.nn.L1Loss(), 28 | "l2": torch.nn.MSELoss(), 29 | "huber": torch.nn.SmoothL1Loss(), 30 | } 31 | 32 | OPTIMIZERS = { 33 | "sgd": optim.SGD, 34 | "adam": optim.Adam, 35 | "adadelta": optim.Adadelta, 36 | "adagrad": optim.Adagrad, 37 | "rmsprop": optim.RMSprop, 38 | } 39 | 40 | SOLVERS = [ 41 | "dopri5", 42 | "adams", 43 | "euler", 44 | "midpoint", 45 | "rk4", 46 | "explicit_adams", 47 | "fixed_adams", 48 | "bosh3", 49 | "adaptive_heun", 50 | "tsit5", 51 | ] 52 | 53 | 54 | def compute_latent_dict(deformer, dataset): 55 | """Map each training-split filename to its learned latent code. 56 | Args: 57 | deformer: the deformer, optionally wrapped in nn.DataParallel, whose net.lat_params rows are read. 58 | dataset: dataset whose file_splits["train"] names are normalized with dl.strip_name. 59 | Returns: 60 | a dict that maps stripped filenames to latent codes (numpy rows of lat_params), paired by position; assumes the train file order matches the lat_params row order (TODO confirm). 61 | """ 62 | # Look up the stored latent of every training shape (no encoding pass is run). 63 | all_filenames = dataset.file_splits["train"] 64 | all_filenames = [dl.strip_name(f) for f in all_filenames] 65 | if isinstance(deformer, nn.DataParallel): 66 | all_latents = deformer.module.net.lat_params.detach().cpu().numpy() 67 | else: 68 | all_latents = deformer.net.lat_params.detach().cpu().numpy() 69 | 70 | return dict(zip(all_filenames, all_latents)) 71 | 72 | 73 | def get_k(epoch): 74 | if epoch < 10: 75 | return 4000 76 | elif epoch < 50: 77 | return 800 78 | elif epoch < 80: 79 | return 100 80 | else: 81 | return 10 82 | 83 | 84 | def train_or_eval( 85 | mode, 86 | args, 87 | deformer, 88 | chamfer_dist, 89 | dataloader, 90 | epoch, 91 | global_step, 92 | device, 93 | logger, 94 | writer, 95 | optimizer, 96 | vis_loader=None, 97 | ): 98 | """Training / Eval function.""" 99 | modes = ["train", "eval"] 100 | if mode not in modes: 101 | raise ValueError(f"mode ({mode}) must be one of {modes}.") 102 | if mode == "train": 103 | deformer.train() 104 | else: 105 | deformer.eval() 106 | tot_loss = 0 107 | count = 0 108 | criterion = LOSSES[args.loss_type] 109 | epoch_images = [] 110 | epoch_latents = [] 111 | 112 | with torch.set_grad_enabled(mode == "train"): 113 | toc = time.time() 114 | 115 | for batch_idx, data_tensors in
enumerate(dataloader): 116 | tic = time.time() 117 | # Send tensors to device. 118 | data_tensors = [t.to(device) for t in data_tensors] 119 | ( 120 | ii, 121 | jj, 122 | source_pts, 123 | target_pts, 124 | source_img, 125 | target_img, 126 | ) = data_tensors 127 | 128 | bs = len(source_pts) 129 | optimizer.zero_grad() 130 | 131 | # Batch together source and target to create two-way loss training. 132 | # Cannot call deformer twice (once for each way) because that 133 | # breaks odeint_ajoint's gradient computation. not sure why. 134 | source_target_points = torch.cat([source_pts, target_pts], dim=0) 135 | target_source_points = torch.cat([target_pts, source_pts], dim=0) 136 | 137 | source_target_latents = torch.cat([ii, jj], dim=0) 138 | target_source_latents = torch.cat([jj, ii], dim=0) 139 | latent_seq = torch.stack( 140 | [source_target_latents, target_source_latents], dim=1 141 | ) 142 | deformed_pts = deformer( 143 | source_target_points[..., :3], latent_seq 144 | ) # Already set to via_hub. 145 | 146 | if mode == "eval": 147 | # Add thumbnail images for visualizing latent embedding. 148 | epoch_images += [torch.cat([source_img, target_img], dim=0)] 149 | source_target_latents = deformer.module.get_lat_params( 150 | source_target_latents 151 | ) 152 | epoch_latents += [source_target_latents] 153 | 154 | # Symmetric pair of matching losses. 155 | if args.symm: 156 | _, _, dist = chamfer_dist( 157 | utils.symmetric_duplication(deformed_pts, symm_dim=2), 158 | utils.symmetric_duplication( 159 | target_source_points[..., :3], symm_dim=2 160 | ), 161 | ) 162 | else: 163 | _, _, dist = chamfer_dist( 164 | deformed_pts, target_source_points[..., :3] 165 | ) 166 | 167 | loss = criterion(dist, torch.zeros_like(dist)) 168 | 169 | # Check amount of deformation. 170 | deform_abs = torch.mean( 171 | torch.norm(deformed_pts - source_target_points, dim=-1) 172 | ) 173 | 174 | if mode == "train": 175 | loss.backward() 176 | 177 | # Gradient clipping. 
178 | torch.nn.utils.clip_grad_value_( 179 | deformer.module.parameters(), args.clip_grad 180 | ) 181 | 182 | optimizer.step() 183 | 184 | tot_loss += loss.item() 185 | count += bs 186 | 187 | if batch_idx % args.log_interval == 0: 188 | # Logger log. 189 | logger.info( 190 | "{} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t" 191 | "Dist Mean: {:.6f}\t" 192 | "Deform Mean: {:.6f}\t" 193 | "DataTime: {:.4f}\tComputeTime: {:.4f}".format( 194 | mode, 195 | epoch, 196 | batch_idx * bs, 197 | len(dataloader) * bs, 198 | 100.0 * batch_idx / len(dataloader), 199 | loss.item(), 200 | np.sqrt(loss.item()), 201 | deform_abs.item(), 202 | tic - toc, 203 | time.time() - tic, 204 | ) 205 | ) 206 | # Tensorboard log. 207 | writer.add_scalar( 208 | f"{mode}/loss_sum", 209 | loss.item(), 210 | global_step=int(global_step), 211 | ) 212 | writer.add_scalar( 213 | f"{mode}/dist_avg", 214 | np.sqrt(loss.item()), 215 | global_step=int(global_step), 216 | ) 217 | writer.add_scalar( 218 | f"{mode}/def_mean", 219 | deform_abs.item(), 220 | global_step=int(global_step), 221 | ) 222 | 223 | if mode == "train": 224 | global_step += 1 225 | toc = time.time() 226 | tot_loss /= count 227 | 228 | # # visualize embeddings 229 | if mode == "eval": 230 | epoch_images = torch.cat(epoch_images, dim=0).permute(0, 3, 1, 2) # \ 231 | # [N,C,H,W] 232 | epoch_images = epoch_images.float() / 255.0 233 | epoch_latents = torch.cat(epoch_latents, dim=0) 234 | writer.add_embedding( 235 | mat=epoch_latents, label_img=epoch_images, global_step=epoch 236 | ) 237 | 238 | # # visualize a few deformation examples in tensorboard 239 | if args.vis_mesh and (vis_loader is not None) and (mode == "eval"): 240 | # add deformation demo 241 | with torch.set_grad_enabled(False): 242 | for ind, data_tensors in enumerate(vis_loader): # batch size = 1 243 | ii = torch.tensor([data_tensors[0]], dtype=torch.long) 244 | jj = torch.tensor([data_tensors[1]], dtype=torch.long) 245 | 246 | source_latents = 
deformer.module.get_lat_params(ii) 247 | target_latents = deformer.module.get_lat_params(jj) 248 | hub_latents = torch.zeros_like(source_latents) 249 | 250 | data_tensors = [ 251 | t.unsqueeze(0).to(device) for t in data_tensors[2:] 252 | ] 253 | vi, fi, vj, fj = data_tensors 254 | vi = vi[0] 255 | fi = fi[0] 256 | vj = vj[0] 257 | fj = fj[0] 258 | vi_j = deformer( 259 | vi[..., :3], 260 | torch.stack( 261 | [source_latents, hub_latents, target_latents], 262 | dim=1, 263 | ), 264 | ) 265 | vj_i = deformer( 266 | vj[..., :3], 267 | torch.stack( 268 | [target_latents, hub_latents, source_latents], 269 | dim=1, 270 | ), 271 | ) 272 | 273 | accu_i, _, _ = chamfer_dist(vi_j, vj) # [1, m] 274 | accu_j, _, _ = chamfer_dist(vj_i, vi) # [1, n] 275 | 276 | # Find the max dist between pairs of original shapes for 277 | # normalizing colors 278 | chamfer_dist.set_reduction_method("max") 279 | _, _, max_dist = chamfer_dist(vi, vj) # [1,] 280 | chamfer_dist.set_reduction_method("mean") 281 | 282 | # Normalize the accuracies wrt. the distance between src 283 | # and tgt meshes. 284 | ci = utils.colorize_scalar_tensors( 285 | accu_i / max_dist, vmin=0.0, vmax=1.0, cmap="coolwarm" 286 | ) 287 | cj = utils.colorize_scalar_tensors( 288 | accu_j / max_dist, vmin=0.0, vmax=1.0, cmap="coolwarm" 289 | ) 290 | ci = (ci * 255.0).int() 291 | cj = (cj * 255.0).int() 292 | 293 | # Save mesh. 
294 | samp_dir = os.path.join(args.log_dir, "deformation_samples") 295 | os.makedirs(samp_dir, exist_ok=True) 296 | trimesh.Trimesh( 297 | vi.detach().cpu().numpy()[0], fi.detach().cpu().numpy()[0] 298 | ).export(os.path.join(samp_dir, f"samp{ind}_src.obj")) 299 | trimesh.Trimesh( 300 | vj.detach().cpu().numpy()[0], fj.detach().cpu().numpy()[0] 301 | ).export(os.path.join(samp_dir, f"samp{ind}_tar.obj")) 302 | trimesh.Trimesh( 303 | vi_j.detach().cpu().numpy()[0], 304 | fi.detach().cpu().numpy()[0], 305 | ).export(os.path.join(samp_dir, f"samp{ind}_src_to_tar.obj")) 306 | trimesh.Trimesh( 307 | vj_i.detach().cpu().numpy()[0], 308 | fj.detach().cpu().numpy()[0], 309 | ).export(os.path.join(samp_dir, f"samp{ind}_tar_to_src.obj")) 310 | 311 | # Add colorized mesh to tensorboard. 312 | writer.add_mesh( 313 | f"samp{ind}/src", 314 | vertices=vi, 315 | faces=fi, 316 | global_step=int(epoch), 317 | ) 318 | writer.add_mesh( 319 | f"samp{ind}/tar", 320 | vertices=vj, 321 | faces=fj, 322 | global_step=int(epoch), 323 | ) 324 | writer.add_mesh( 325 | f"samp{ind}/src_to_tar", 326 | vertices=vi_j, 327 | faces=fi, 328 | colors=ci, 329 | global_step=int(epoch), 330 | ) 331 | writer.add_mesh( 332 | f"samp{ind}/tar_to_src", 333 | vertices=vj_i, 334 | faces=fj, 335 | colors=cj, 336 | global_step=int(epoch), 337 | ) 338 | 339 | return tot_loss 340 | 341 | 342 | def get_args(): 343 | """Parse command line arguments.""" 344 | parser = argparse.ArgumentParser(description="ShapeNet Deformation Space") 345 | 346 | parser.add_argument( 347 | "--batch_size_per_gpu", 348 | type=int, 349 | default=16, 350 | metavar="N", 351 | help="input batch size for training (default: 10)", 352 | ) 353 | parser.add_argument( 354 | "--epochs", 355 | type=int, 356 | default=100, 357 | metavar="N", 358 | help="number of epochs to train (default: 100)", 359 | ) 360 | parser.add_argument( 361 | "--pseudo_train_epoch_size", 362 | type=int, 363 | default=2048, 364 | metavar="N", 365 | help="number of samples in 
an pseudo-epoch. (default: 2048)", 366 | ) 367 | parser.add_argument( 368 | "--pseudo_eval_epoch_size", 369 | type=int, 370 | default=128, 371 | metavar="N", 372 | help="number of samples in an pseudo-epoch. (default: 128)", 373 | ) 374 | parser.add_argument( 375 | "--lr", 376 | type=float, 377 | default=1e-3, 378 | metavar="R", 379 | help="learning rate (default: 0.001)", 380 | ) 381 | parser.add_argument( 382 | "--no_cuda", 383 | action="store_true", 384 | default=False, 385 | help="disables CUDA training", 386 | ) 387 | parser.add_argument( 388 | "--seed", 389 | type=int, 390 | default=1, 391 | metavar="S", 392 | help="random seed (default: 1)", 393 | ) 394 | parser.add_argument( 395 | "--data_root", 396 | type=str, 397 | default="data/shapenet_simplified", 398 | help="path to mesh folder root (default: data/shapenet_simplified)", 399 | ) 400 | parser.add_argument( 401 | "--category", type=str, default="chair", help="the shape category." 402 | ) 403 | parser.add_argument( 404 | "--thumbnails_root", 405 | type=str, 406 | default="data/shapenet_thumbnails", 407 | help="path to thumbnails folder root " 408 | "(default: data/shapenet_thumbnails)", 409 | ) 410 | parser.add_argument( 411 | "--deformer_arch", 412 | type=str, 413 | choices=["imnet", "vanilla"], 414 | default="imnet", 415 | help="deformer architecture. (default: imnet)", 416 | ) 417 | parser.add_argument( 418 | "--solver", 419 | type=str, 420 | choices=SOLVERS, 421 | default="dopri5", 422 | help="ode solver. (default: dopri5)", 423 | ) 424 | parser.add_argument( 425 | "--atol", 426 | type=float, 427 | default=1e-5, 428 | help="absolute error tolerence in ode solver. (default: 1e-5)", 429 | ) 430 | parser.add_argument( 431 | "--rtol", 432 | type=float, 433 | default=1e-5, 434 | help="relative error tolerence in ode solver. 
(default: 1e-5)", 435 | ) 436 | parser.add_argument( 437 | "--log_interval", 438 | type=int, 439 | default=10, 440 | metavar="N", 441 | help="how many batches to wait before logging training status", 442 | ) 443 | parser.add_argument( 444 | "--log_dir", type=str, required=True, help="log directory for run" 445 | ) 446 | parser.add_argument( 447 | "--nonlin", type=str, default="elu", help="type of nonlinearity to use" 448 | ) 449 | parser.add_argument( 450 | "--optim", type=str, default="adam", choices=list(OPTIMIZERS.keys()) 451 | ) 452 | parser.add_argument( 453 | "--loss_type", type=str, default="l2", choices=list(LOSSES.keys()) 454 | ) 455 | parser.add_argument( 456 | "--resume", 457 | type=str, 458 | default=None, 459 | help="path to checkpoint if resume is needed", 460 | ) 461 | parser.add_argument( 462 | "-n", 463 | "--nsamples", 464 | default=2048, 465 | type=int, 466 | help="number of sample points to draw per shape.", 467 | ) 468 | parser.add_argument( 469 | "--lat_dims", default=32, type=int, help="number of latent dimensions." 470 | ) 471 | parser.add_argument( 472 | "--datasubset", 473 | default=0, 474 | type=int, 475 | help="0 to not subset. 
else subset this many examples.", 476 | ) 477 | parser.add_argument( 478 | "--deformer_nf", 479 | default=100, 480 | type=int, 481 | help="number of base number of feature layers in deformer (imnet).", 482 | ) 483 | parser.add_argument( 484 | "--lr_scheduler", dest="lr_scheduler", action="store_true" 485 | ) 486 | parser.add_argument( 487 | "--no_lr_scheduler", dest="lr_scheduler", action="store_false" 488 | ) 489 | parser.set_defaults(lr_scheduler=True) 490 | parser.set_defaults(normals=True) 491 | parser.add_argument( 492 | "--visualize_mesh", 493 | dest="vis_mesh", 494 | action="store_true", 495 | help="visualize deformation for meshes of sample validation data " 496 | "in tensorboard.", 497 | ) 498 | parser.add_argument( 499 | "--no_visualize_mesh", 500 | dest="vis_mesh", 501 | action="store_false", 502 | help="no visualize deformation for meshes of sample validation data " 503 | "in tensorboard.", 504 | ) 505 | parser.set_defaults(vis_mesh=True) 506 | parser.add_argument( 507 | "--adjoint", 508 | dest="adjoint", 509 | action="store_true", 510 | help="use adjoint solver to propagate gradients thru odeint.", 511 | ) 512 | parser.add_argument( 513 | "--no_adjoint", 514 | dest="adjoint", 515 | action="store_false", 516 | help="not use adjoint solver to propagate gradients thru odeint.", 517 | ) 518 | parser.set_defaults(adjoint=True) 519 | parser.add_argument( 520 | "--sign_net", 521 | dest="sign_net", 522 | action="store_true", 523 | help="use sign net.", 524 | ) 525 | parser.add_argument( 526 | "--no_sign_net", 527 | dest="sign_net", 528 | action="store_false", 529 | help="not use sign net.", 530 | ) 531 | parser.set_defaults(sign_net=False) 532 | parser.add_argument( 533 | "--clip_grad", 534 | default=1.0, 535 | type=float, 536 | help="clip gradient to this value. 
large value basically " 537 | "deactivates it.", 538 | ) 539 | parser.add_argument( 540 | "--sampling_method", 541 | type=str, 542 | choices=[ 543 | "nn_replace", 544 | "nn_no_replace", 545 | "all_replace", 546 | "all_no_replace", 547 | ], 548 | default="nn_no_replace", 549 | help="method for sampling pairs of shape to deform.", 550 | ) 551 | parser.add_argument( 552 | "--symm", dest="symm", action="store_true", help="use symmetric flow." 553 | ) 554 | parser.add_argument( 555 | "--no_symm", 556 | dest="symm", 557 | action="store_false", 558 | help="not use symmetric flow.", 559 | ) 560 | parser.set_defaults(symm=False) 561 | args = parser.parse_args() 562 | return args 563 | 564 | 565 | def main(): 566 | args = get_args() 567 | 568 | # Adjust batch size based on the number of gpus available. 569 | args.batch_size = int(torch.cuda.device_count()) * args.batch_size_per_gpu 570 | use_cuda = (not args.no_cuda) and torch.cuda.is_available() 571 | kwargs = ( 572 | {"num_workers": min(12, args.batch_size), "pin_memory": True} 573 | if use_cuda 574 | else {} 575 | ) 576 | device = torch.device("cuda" if use_cuda else "cpu") 577 | 578 | # Log and create snapshots. 579 | filenames_to_snapshot = ( 580 | glob.glob("*.py") + glob.glob("*.sh") + glob.glob("layers/*.py") 581 | ) 582 | utils.snapshot_files(filenames_to_snapshot, args.log_dir) 583 | logger = utils.get_logger(log_dir=args.log_dir) 584 | with open(os.path.join(args.log_dir, "params.json"), "w") as fh: 585 | json.dump(args.__dict__, fh, indent=2) 586 | logger.info("%s", repr(args)) 587 | 588 | args.n_vis = 2 # Number of deformation examples to visualize. 589 | 590 | # Tensorboard writer. 591 | writer = SummaryWriter(log_dir=os.path.join(args.log_dir, "tensorboard")) 592 | 593 | # Random seed for reproducability. 594 | torch.manual_seed(args.seed) 595 | np.random.seed(args.seed) 596 | 597 | # Create dataloaders. 
598 | fullset = dl.ShapeNetVertex( 599 | data_root=args.data_root, 600 | split="train", 601 | category=args.category, 602 | nsamples=args.nsamples, 603 | normals=False, 604 | ) 605 | if args.datasubset > 0: 606 | fullset.restrict_subset(args.datasubset) 607 | 608 | # Return thumbnails (to visualize embedding during eval). 609 | fullset.add_thumbnails(args.thumbnails_root) 610 | 611 | if "nn_" in args.sampling_method: 612 | args.nn_samp = True 613 | replace = True if args.sampling_method == "nn_replace" else False 614 | train_sampler = dl.LatentNearestNeighborSampler( 615 | dataset=fullset, 616 | src_split="train", 617 | tar_split="train", 618 | n_samples=args.pseudo_train_epoch_size, 619 | k=1000, 620 | replace=replace, 621 | ) 622 | eval_sampler = dl.LatentNearestNeighborSampler( 623 | dataset=fullset, 624 | src_split="train", 625 | tar_split="train", 626 | n_samples=args.pseudo_eval_epoch_size, 627 | k=1, 628 | replace=replace, 629 | ) # Pick the closest. 630 | vis_sampler = dl.LatentNearestNeighborSampler( 631 | dataset=fullset, 632 | src_split="train", 633 | tar_split="train", 634 | n_samples=args.n_vis, 635 | k=1, 636 | replace=replace, 637 | ) # Pick the closest. 638 | else: 639 | args.nn_samp = False 640 | replace = True if args.sampling_method == "all_replace" else False 641 | train_sampler = dl.RandomPairSampler( 642 | dataset=fullset, 643 | src_split="train", 644 | tar_split="train", 645 | n_samples=args.pseudo_train_epoch_size, 646 | replace=replace, 647 | ) 648 | eval_sampler = dl.RandomPairSampler( 649 | dataset=fullset, 650 | src_split="train", 651 | tar_split="train", 652 | n_samples=args.pseudo_eval_epoch_size, 653 | replace=replace, 654 | ) 655 | vis_sampler = dl.RandomPairSampler( 656 | dataset=fullset, 657 | src_split="train", 658 | tar_split="train", 659 | n_samples=args.n_vis, 660 | replace=replace, 661 | ) 662 | 663 | # Make sure we are turning off shuffle since we are using samplers! 
664 | train_loader = DataLoader( 665 | fullset, 666 | batch_size=args.batch_size, 667 | shuffle=False, 668 | drop_last=True, 669 | sampler=train_sampler, 670 | **kwargs, 671 | ) 672 | eval_loader = DataLoader( 673 | fullset, 674 | batch_size=args.batch_size, 675 | shuffle=False, 676 | drop_last=False, 677 | sampler=eval_sampler, 678 | **kwargs, 679 | ) 680 | 681 | if args.vis_mesh: 682 | # For loading full meshes for visualization. 683 | simp_data_root = args.data_root 684 | simpset = dl.ShapeNetMesh( 685 | data_root=simp_data_root, 686 | split="train", 687 | category=args.category, 688 | normals=False, 689 | ) 690 | if args.datasubset > 0: 691 | simpset.restrict_subset(args.datasubset) 692 | if not ( 693 | vis_sampler.dataset.fname_to_idx_dict == simpset.fname_to_idx_dict 694 | ): 695 | raise RuntimeError( 696 | f"vis_sampler ({len(vis_sampler.dataset.fname_to_idx_dict)}) " 697 | f"does not match sample set ({len(simpset.fname_to_idx_dict)})" 698 | ) 699 | vis_loader = DataLoader( 700 | simpset, 701 | batch_size=1, 702 | shuffle=False, 703 | drop_last=False, 704 | sampler=vis_sampler, 705 | **kwargs, 706 | ) 707 | else: 708 | vis_loader = None 709 | 710 | # Setup model. 711 | deformer = NeuralFlowDeformer( 712 | latent_size=args.lat_dims, 713 | f_width=args.deformer_nf, 714 | s_nlayers=2, 715 | s_width=5, 716 | method=args.solver, 717 | nonlinearity=args.nonlin, 718 | arch="imnet", 719 | adjoint=args.adjoint, 720 | rtol=args.rtol, 721 | atol=args.atol, 722 | via_hub=True, 723 | no_sign_net=(not args.sign_net), 724 | symm_dim=(2 if args.symm else None), 725 | ) 726 | 727 | # Awkward workaround to get gradients from odeint_adjoint to lat_params. 
728 | lat_params = torch.nn.Parameter( 729 | torch.randn(fullset.n_shapes, args.lat_dims) * 1e-1, requires_grad=True 730 | ) 731 | deformer.add_lat_params(lat_params) 732 | deformer.to(device) 733 | 734 | all_model_params = list(deformer.parameters()) 735 | 736 | optimizer = OPTIMIZERS[args.optim](all_model_params, lr=args.lr) 737 | 738 | start_ep = 0 739 | global_step = np.zeros(1, dtype=np.uint32) 740 | tracked_stats = np.inf 741 | 742 | if args.resume: 743 | logger.info( 744 | "Loading checkpoint {} ================>".format(args.resume) 745 | ) 746 | resume_dict = torch.load(args.resume) 747 | start_ep = resume_dict["epoch"] 748 | global_step = resume_dict["global_step"] 749 | tracked_stats = resume_dict["tracked_stats"] 750 | deformer.load_state_dict(resume_dict["deformer_state_dict"]) 751 | optimizer.load_state_dict(resume_dict["optim_state_dict"]) 752 | for state in optimizer.state.values(): 753 | for k, v in state.items(): 754 | if isinstance(v, torch.Tensor): 755 | state[k] = v.to(device) 756 | logger.info("[!] Successfully loaded checkpoint.") 757 | 758 | # More threads don't seem to help. 759 | chamfer_dist = ChamferDistKDTree(reduction="mean", njobs=1) 760 | chamfer_dist.to(device) 761 | deformer = nn.DataParallel(deformer) 762 | deformer.to(device) 763 | 764 | model_param_count = lambda model: sum( # noqa: E731 765 | x.numel() for x in model.parameters() 766 | ) 767 | logger.info( 768 | f"{model_param_count(deformer)}(deformer) paramerters in total" 769 | ) 770 | 771 | checkpoint_path = os.path.join(args.log_dir, "checkpoint_latest.pth.tar") 772 | 773 | if args.lr_scheduler: 774 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") 775 | 776 | train_latent_dict = compute_latent_dict(deformer, fullset) 777 | # Training loop. 778 | for epoch in range(start_ep + 1, args.epochs + 1): 779 | # Set sampler nn graph before train or eval. 
780 | if args.nn_samp: 781 | train_loader.sampler.update_nn_graph( 782 | train_latent_dict, train_latent_dict, k=get_k(epoch) 783 | ) 784 | eval_loader.sampler.update_nn_graph( 785 | train_latent_dict, train_latent_dict 786 | ) 787 | vis_loader.sampler.update_nn_graph( 788 | train_latent_dict, train_latent_dict 789 | ) 790 | 791 | _ = train_or_eval( 792 | "train", 793 | args, 794 | deformer, 795 | chamfer_dist, 796 | train_loader, 797 | epoch, 798 | global_step, 799 | device, 800 | logger, 801 | writer, 802 | optimizer, 803 | None, 804 | ) 805 | loss_eval = train_or_eval( 806 | "eval", 807 | args, 808 | deformer, 809 | chamfer_dist, 810 | eval_loader, 811 | epoch, 812 | global_step, 813 | device, 814 | logger, 815 | writer, 816 | optimizer, 817 | vis_loader, 818 | ) 819 | 820 | if args.lr_scheduler: 821 | scheduler.step(loss_eval) 822 | if loss_eval < tracked_stats: 823 | tracked_stats = loss_eval 824 | is_best = True 825 | else: 826 | is_best = False 827 | 828 | utils.save_checkpoint( 829 | { 830 | "epoch": epoch, 831 | "deformer_state_dict": deformer.module.state_dict(), 832 | "lat_params": lat_params, 833 | "optim_state_dict": optimizer.state_dict(), 834 | "tracked_stats": tracked_stats, 835 | "global_step": global_step, 836 | }, 837 | is_best, 838 | epoch, 839 | checkpoint_path, 840 | "_shapeflow", 841 | logger, 842 | ) 843 | 844 | 845 | if __name__ == "__main__": 846 | main() 847 | -------------------------------------------------------------------------------- /shapenet_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ##################################### 4 | # Configure Experiment # 5 | ##################################### 6 | run_name=demo 7 | log_dir=runs/$run_name 8 | data_root=data/shapenet_simplified 9 | 10 | # Create run directory if it doesn't exist. 11 | mkdir -p runs 12 | 13 | # Launch training. 
14 | python shapenet_train.py \ 15 | --atol=1e-4 \ 16 | --rtol=1e-4 \ 17 | --data_root=$data_root \ 18 | --pseudo_train_epoch_size=2048 \ 19 | --pseudo_eval_epoch_size=128 \ 20 | --lr=1e-3 \ 21 | --log_dir=$log_dir \ 22 | --lr_scheduler \ 23 | --visualize_mesh \ 24 | --batch_size_per_gpu=32 \ 25 | --log_interval=2 \ 26 | --epochs=100 \ 27 | --no_sign_net \ 28 | --adjoint \ 29 | --solver='dopri5' \ 30 | --deformer_nf=128 \ 31 | --nsamples=512 \ 32 | --lat_dims=128 \ 33 | --nonlin='leakyrelu' \ 34 | --symm \ 35 | --category=${category} \ 36 | --sampling_method='all_no_replace' \ 37 | -------------------------------------------------------------------------------- /utils/render.py: -------------------------------------------------------------------------------- 1 | """Rendering utility functions. 2 | """ 3 | import os 4 | 5 | os.environ["PYOPENGL_PLATFORM"] = "osmesa" 6 | import numpy as np # noqa: E402 7 | import trimesh # noqa: E402 8 | import pyrender # noqa: E402 9 | 10 | 11 | def render_trimesh( 12 | trimesh_mesh, 13 | eye, 14 | center, 15 | world_up, 16 | res=(640, 640), 17 | light_intensity=3.0, 18 | ambient_intensity=None, 19 | **kwargs, 20 | ): 21 | """Render a shapenet mesh using default settings. 22 | 23 | Args: 24 | trimesh_mesh: trimesh mesh instance, or a list of trimesh meshes 25 | (or point clouds). 26 | eye: array with shape [3,] containing the XYZ world 27 | space position of the camera. 28 | center: array with shape [3,] containing a position 29 | along the center of the camera's gaze. 30 | world_up: np.float32 array with shape [3,] specifying the 31 | world's up direction; the output camera will have no tilt with respect 32 | to this direction. 33 | res: 2-tuple of int, [width, height], resolution (in pixels) of output 34 | images. 35 | light_intensity: float, light intensity. 36 | ambient_intensity: float, ambient light intensity. 37 | kwargs: additional flags to pass to pyrender renderer. 38 | Returns: 39 | color_img: [*res, 3] color image. 
40 | depth_img: [*res, 1] depth image. 41 | world_to_cam: [4, 4] camera to world matrix. 42 | projection_matrix: [4, 4] projection matrix, aka cam_to_img matrix. 43 | """ 44 | if not isinstance(trimesh_mesh, list): 45 | trimesh_mesh = [trimesh_mesh] 46 | eye = list2npy(eye).astype(np.float32) 47 | center = list2npy(center).astype(np.float32) 48 | world_up = list2npy(world_up).astype(np.float32) 49 | 50 | # setup camera pose matrix 51 | scene = pyrender.Scene( 52 | ambient_light=ambient_intensity*np.ones([3], dtype=float) 53 | ) 54 | for tmesh in trimesh_mesh: 55 | if not ( 56 | isinstance(tmesh, trimesh.Trimesh) 57 | or isinstance(tmesh, trimesh.PointCloud) 58 | ): 59 | raise NotImplementedError( 60 | "All instances in trimesh_mesh must be either trimesh.Trimesh " 61 | f"or trimesh.PointCloud. Instead it is {type(tmesh)}." 62 | ) 63 | if isinstance(tmesh, trimesh.Trimesh): 64 | mesh = pyrender.Mesh.from_trimesh(tmesh) 65 | elif isinstance(tmesh, trimesh.PointCloud): 66 | if tmesh.colors is not None: 67 | colors = np.array(tmesh.colors) 68 | else: 69 | colors = np.ones_like(tmesh.vertices) 70 | mesh = pyrender.Mesh.from_points( 71 | np.array(tmesh.vertices), colors=colors 72 | ) 73 | scene.add(mesh) 74 | 75 | # Set up the camera -- z-axis away from the scene, x-axis right, y-axis up 76 | camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0) 77 | 78 | world_to_cam = look_at(eye[None], center[None], world_up[None]) 79 | world_to_cam = world_to_cam[0] 80 | cam_pose = np.linalg.inv(world_to_cam) 81 | scene.add(camera, pose=cam_pose) 82 | 83 | # Set up the light -- a single spot light in the same spot as the camera 84 | light = pyrender.SpotLight( 85 | color=np.ones(3, dtype=np.float32), 86 | intensity=light_intensity, 87 | innerConeAngle=np.pi / 16.0, 88 | ) 89 | scene.add(light, pose=cam_pose) 90 | 91 | # Render the scene 92 | r = pyrender.OffscreenRenderer(*res, **kwargs) 93 | color_img, depth_img = r.render(scene) 94 | return ( 95 | color_img, 96 | depth_img, 97 | 
world_to_cam, 98 | camera.get_projection_matrix(*res), 99 | ) 100 | 101 | 102 | def _unproject_points(points, projection_matrix, world_to_cam): 103 | """Unproject points from image space to world space.""" 104 | # pad 105 | depth = points[:, 2] 106 | xy_scale = (depth - projection_matrix[2, 3]) / (-projection_matrix[2, 2]) 107 | points[:, :2] = points[:, :2] * xy_scale[:, None] 108 | points = np.concatenate([points, np.ones_like(points[:, :1])], axis=1) 109 | points[:, 3] = xy_scale 110 | # camera space coordinates 111 | point_cam = ( 112 | np.linalg.inv(projection_matrix) @ (points.T) 113 | ).T # [npoints, 4] 114 | # world space coordinates 115 | cam_to_world = np.linalg.inv(world_to_cam) 116 | point_world = (cam_to_world @ (point_cam.T)).T 117 | point_world = point_world[:, :3] / point_world[:, 3:] 118 | return point_world 119 | 120 | 121 | def _points_from_depth(depth_img, zoffset=0): 122 | """Get image space points from depth image.""" 123 | point_mask = depth_img != 0.0 124 | point_mask_flat = point_mask.reshape(-1) 125 | w, h = depth_img.shape 126 | x, y = np.meshgrid( 127 | np.linspace(-1.0, 1.0, w), np.linspace(-1.0, 1.0, h), indexing="ij" 128 | ) 129 | xy_img = np.stack([y, -x], axis=-1) # [w, h, 2] 130 | xy_flat = xy_img.reshape(-1, 2) # [w*h, 2] 131 | point_img = xy_flat[point_mask_flat] # [npoints, 2] 132 | depth = depth_img.reshape(-1)[point_mask_flat] + zoffset 133 | point_img = np.concatenate([point_img, depth[..., None]], axis=-1) 134 | return point_img 135 | 136 | 137 | def line_meshes(verts, edges, colors=None, poses=None): 138 | """Create pyrender Mesh instance for lines. 139 | 140 | Args: 141 | verts: np.array floats of shape [#v, 3] 142 | edges: np.array ints of shape [#e, 3] 143 | colors: np.array floats of shape [#v, 3] 144 | poses: poses : (x,4,4) 145 | Array of 4x4 transformation matrices for instancing this object. 
146 | """ 147 | prim = pyrender.primitive.Primitive( 148 | positions=verts, 149 | indices=edges, 150 | color_0=colors, 151 | mode=pyrender.constants.GLTF.LINES, 152 | poses=poses) 153 | return pyrender.mesh.Mesh(primitives=[prim], is_visible=True) 154 | 155 | 156 | def unproject_depth_img(depth_img, projection_matrix, world_to_cam): 157 | """Unproject depth image to point cloud in world coordinates. 158 | 159 | Args: 160 | depth_img: array of [width, height] depth image. 161 | projection_matrix: array of [4, 4], projection matrix, aka cam_to_img 162 | matrix. 163 | world_to_cam: array of [4, 4], world to cam matrix, inverse of camera 164 | pose. 165 | 166 | Returns: 167 | point_world: array of [npoints, 3] depth scan point cloud in world 168 | coordinates. 169 | """ 170 | point_img = _points_from_depth(depth_img, zoffset=projection_matrix[2, 3]) 171 | point_world = _unproject_points(point_img, projection_matrix, world_to_cam) 172 | 173 | return point_world 174 | 175 | 176 | def list2npy(array): 177 | return array if isinstance(array, np.ndarray) else np.array(array) 178 | 179 | 180 | def r4pad(array): 181 | """pad [..., 3] array to [..., 4] with ones in last channel.""" 182 | zeros = np.ones_like(array[..., -1:]) 183 | return np.concatenate([array, zeros], axis=-1) 184 | 185 | 186 | def look_at(eye, center, world_up): 187 | """Computes camera viewing matrices (numpy implementation). 188 | 189 | Args: 190 | eye: np.float32 array with shape [batch_size, 3] containing the XYZ world 191 | space position of the camera. 192 | center: np.float32 array with shape [batch_size, 3] containing a position 193 | along the center of the camera's gaze. 194 | world_up: np.float32 array with shape [batch_size, 3] specifying the 195 | world's up direction; the output camera will have no tilt with respect to 196 | this direction. 
197 | 198 | Returns: 199 | A [batch_size, 4, 4] np.float32 array containing a right-handed camera 200 | extrinsics matrix that maps points from world space to points in eye space. 201 | """ 202 | batch_size = center.shape[0] 203 | vector_degeneracy_cutoff = 1e-6 204 | forward = center - eye 205 | forward_norm = np.linalg.norm(forward, axis=1, keepdims=True) 206 | assert np.all(forward_norm > vector_degeneracy_cutoff) 207 | forward /= forward_norm 208 | 209 | to_side = np.cross(forward, world_up) 210 | to_side_norm = np.linalg.norm(to_side, axis=1, keepdims=True) 211 | assert np.all(to_side_norm > vector_degeneracy_cutoff) 212 | to_side /= to_side_norm 213 | cam_up = np.cross(to_side, forward) 214 | 215 | w_column = np.array( 216 | batch_size * [[0.0, 0.0, 0.0, 1.0]], dtype=np.float32 217 | ) # [batch_size, 4] 218 | w_column = w_column.reshape([batch_size, 4, 1]) 219 | view_rotation = np.stack( 220 | [to_side, cam_up, -forward, np.zeros_like(to_side, dtype=np.float32)], 221 | axis=1, 222 | ) # [batch_size, 4, 3] matrix 223 | view_rotation = np.concatenate( 224 | [view_rotation, w_column], axis=2 225 | ) # [batch_size, 4, 4] 226 | 227 | identity_batch = np.tile(np.expand_dims(np.eye(3), 0), [batch_size, 1, 1]) 228 | view_translation = np.concatenate( 229 | [identity_batch, np.expand_dims(-eye, 2)], 2 230 | ) 231 | view_translation = np.concatenate( 232 | [view_translation, w_column.reshape([batch_size, 1, 4])], 1 233 | ) 234 | camera_matrices = np.matmul(view_rotation, view_translation) 235 | return camera_matrices 236 | --------------------------------------------------------------------------------