├── .github
│   ├── ISSUE_TEMPLATE
│   └── PULL_REQUEST_TEMPLATE
├── .gitignore
├── CHANGELOG
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── data
│   └── data.py
├── data_preprocessing
│   ├── preprocess_dataset.m
│   └── utils
│       ├── compute_dist_matrices.m
│       ├── compute_dist_matrix.m
│       ├── compute_triangle_areas.m
│       ├── convert_dataset.m
│       ├── create_remeshed_collection.m
│       ├── create_remeshed_dataset.m
│       ├── read_obj.m
│       ├── read_off.m
│       ├── read_ply.m
│       └── reduce_folder_individual.m
├── figures
│   └── splash.png
├── main_test.py
├── main_train.py
├── model
│   ├── interpolation_net.py
│   ├── layers.py
│   └── layers_onet.py
├── param.py
└── utils
    ├── arap_interpolation.py
    ├── arap_potential.py
    ├── base_tools.py
    ├── interpolation_base.py
    └── shape_utils.py
/.github/ISSUE_TEMPLATE: -------------------------------------------------------------------------------- 1 | Use this to open questions or issues, and provide context here. 2 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE: -------------------------------------------------------------------------------- 1 | Fixes # . 2 | 3 | Description of the changes. 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.mex* 3 | data/checkpoint 4 | data/meshes 5 | data/out -------------------------------------------------------------------------------- /CHANGELOG: -------------------------------------------------------------------------------- 1 | * V1.0: Initial release. 
2 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to NeuroMorph 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `main`. 10 | 2. Make sure your code runs correctly. 11 | 3. Make sure your code lints correctly. 12 | 4. Make sure your code is formatted using [black](https://black.readthedocs.io/). 13 | 5. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## License 26 | By contributing to NeuroMorph, you agree that your contributions will be licensed 27 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Unless otherwise specified, the following license applies: 2 | 3 | MIT License 4 | 5 | Copyright (c) Facebook, Inc. 
and its affiliates. 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | 25 | --------------------------------- 26 | 27 | The following files 28 | 29 | data_preprocessing/utils/read_obj.m 30 | data_preprocessing/utils/read_off.m 31 | data_preprocessing/utils/read_ply.m 32 | 33 | are derived from the Toolbox Graph 34 | (https://uk.mathworks.com/matlabcentral/fileexchange/5355-toolbox-graph) 35 | under the following license: 36 | 37 | 2-Clause BSD License 38 | 39 | Copyright (c) 2009, Gabriel Peyre 40 | All rights reserved. 41 | 42 | Redistribution and use in source and binary forms, with or without 43 | modification, are permitted provided that the following conditions are 44 | met: 45 | 46 | - Redistributions of source code must retain the above copyright 47 | notice, this list of conditions and the following disclaimer. 
48 | 49 | - Redistributions in binary form must reproduce the above copyright 50 | notice, this list of conditions and the following disclaimer in 51 | the documentation and/or other materials provided with the distribution 52 | 53 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 54 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 55 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 56 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 57 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 58 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 59 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 60 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 61 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 62 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 63 | POSSIBILITY OF SUCH DAMAGE. 
64 | 65 | --------------------------------- 66 | 67 | The following files: 68 | 69 | utils/base_tools.py 70 | utils/interpolation_base.py 71 | utils/shape_utils.py 72 | 73 | are derived from the Hamiltonian Dynamics 74 | (https://github.com/marvin-eisenberger/hamiltonian-interpolation/) repository 75 | under the following license: 76 | 77 | MIT License 78 | 79 | Copyright (c) 2021 Marvin Eisenberger 80 | 81 | Permission is hereby granted, free of charge, to any person obtaining a copy 82 | of this software and associated documentation files (the "Software"), to deal 83 | in the Software without restriction, including without limitation the rights 84 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 85 | copies of the Software, and to permit persons to whom the Software is 86 | furnished to do so, subject to the following conditions: 87 | 88 | The above copyright notice and this permission notice shall be included in all 89 | copies or substantial portions of the Software. 90 | 91 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 92 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 93 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 94 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 95 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 96 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 97 | SOFTWARE. 
98 | 99 | ------------------------------ 100 | 101 | The following file: 102 | 103 | model/layers.py 104 | 105 | is derived from the Occupancy Networks 106 | (https://github.com/autonomousvision/occupancy_networks/) repository 107 | under the following license: 108 | 109 | MIT License 110 | 111 | Copyright 2019 Lars Mescheder, Michael Oechsle, Michael Niemeyer, Andreas Geiger, Sebastian Nowozin 112 | 113 | Permission is hereby granted, free of charge, to any person obtaining a 114 | copy of this software and associated documentation files (the "Software"), 115 | to deal in the Software without restriction, including without limitation 116 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 117 | and/or sell copies of the Software, and to permit persons to whom the 118 | Software is furnished to do so, subject to the following conditions: 119 | 120 | The above copyright notice and this permission notice shall be included in 121 | all copies or substantial portions of the Software. 122 | 123 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 124 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 125 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 126 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 127 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 128 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 129 | DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuroMorph: Unsupervised Shape Interpolation and Correspondence in One Go 2 | 3 | ![](figures/splash.png) 4 | 5 | This repository provides our implementation of [NeuroMorph](https://openaccess.thecvf.com/content/CVPR2021/html/Eisenberger_NeuroMorph_Unsupervised_Shape_Interpolation_and_Correspondence_in_One_Go_CVPR_2021_paper.html), our CVPR 2021 paper. Our algorithm produces a smooth interpolation and point-to-point correspondences between two input 3D shapes in one go, i.e., in a single feed-forward pass. It is learned in a self-supervised manner from an unlabelled collection of deformable and heterogeneous shapes. 6 | 7 | If you use our work, please cite: 8 | 9 | ``` 10 | @inproceedings{eisenberger2021neuromorph, 11 | title={NeuroMorph: Unsupervised Shape Interpolation and Correspondence in One Go}, 12 | author={Eisenberger, Marvin and Novotny, David and Kerchenbaum, Gael and Labatut, Patrick and Neverova, Natalia and Cremers, Daniel and Vedaldi, Andrea}, 13 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 14 | pages={7473--7483}, 15 | year={2021} 16 | } 17 | ``` 18 | 19 | ## Requirements 20 | 21 | The code was tested on Python 3.8.10 with PyTorch 1.9.1 and CUDA 10.2. 22 | The code also requires the pytorch-geometric library ([installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)) and [matplotlib](https://matplotlib.org). 23 | Finally, MATLAB with the Statistics and Machine Learning Toolbox is used to pre-process certain datasets (we tested MATLAB versions 2019b and 2021b). 24 | The code should run on Linux, macOS and Windows. 
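As a quick sanity check before running anything, you can verify that the Python-side dependencies are at least importable. This is a minimal sketch, not part of the repository; the helper name `check_requirements` is illustrative, and it only tests importability, not the specific versions listed above:

```python
import importlib.util

# Packages required by the learning code (MATLAB is checked separately).
REQUIRED = ("torch", "torch_geometric", "matplotlib", "scipy", "numpy")


def check_requirements(packages=REQUIRED):
    """Map each package name to True/False depending on whether it is importable."""
    # find_spec() locates a package without actually importing it.
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    for name, found in check_requirements().items():
        print(f"{name}: {'OK' if found else 'MISSING'}")
```

If anything prints `MISSING`, follow the installation steps in the next section first.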
25 | 26 | ## Installing NeuroMorph 27 | 28 | Using Anaconda, you can install the required dependencies as follows: 29 | 30 | ```bash 31 | conda create -n neuromorph python=3.8 32 | conda activate neuromorph 33 | conda install pytorch cudatoolkit=10.2 -c pytorch 34 | conda install matplotlib 35 | conda install pyg -c pyg -c conda-forge 36 | ``` 37 | 38 | ## Running NeuroMorph 39 | 40 | In order to run NeuroMorph: 41 | 42 | * Specify the location of datasets on your device under `data_folder_` in `param.py`. 43 | * To use your own data, create a new dataset in `data/data.py`. 44 | * To train on FAUST remeshed, run the main script `main_train.py`. Modify the script as needed to train on different data. 45 | 46 | For a more detailed tutorial, see the next section. 47 | 48 | ## Reproducing the experiments 49 | 50 | We show below how to reproduce the experiments on the FAUST remeshed data. 51 | 52 | ### Data download 53 | 54 | You can download experimental mesh data from [here](https://nuage.lix.polytechnique.fr/index.php/s/LJFXrsTG22wYCXx), provided by the authors of [Deep Geometric Functional Maps](https://github.com/LIX-shape-analysis/GeomFmaps). 55 | Download the `FAUST_r.zip` file from this site, unzip it, and move the content of the directory to `/data/meshes/FAUST_r`. 56 | 57 | ### Data preprocessing 58 | 59 | Meshes must be subsampled and remeshed (for data augmentation during training), and geodesic distance matrices must be computed before the learning code runs. 60 | For this, we use the `data_preprocessing/preprocess_dataset.m` MATLAB script (we tested versions 2019b and 2021b). 61 | 62 | Start MATLAB and do the following: 63 | 64 | ```matlab 65 | cd /data_preprocessing 66 | preprocess_dataset("../data/meshes/FAUST_r/", ".off") 67 | ``` 68 | 69 | The result should be a list of MATLAB mesh files in a `mat` subfolder (e.g., `data/meshes/FAUST_r/mat`), 70 | plus additional data. 
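For intuition, the conversion step of the preprocessing (`convert_dataset.m`) normalizes each mesh by centering it at the origin and rescaling it so that its total surface area equals a reference value of 0.44. A NumPy sketch of the same computation, for illustration only (the repository performs this step in MATLAB; note the 0-based triangle indices here versus MATLAB's 1-based `triv`):

```python
import numpy as np

REF_AREA = 0.44  # reference total surface area used in convert_dataset.m


def triangle_areas(vert, triv):
    """Per-triangle areas via the cross product, as in compute_triangle_areas.m."""
    e1 = vert[triv[:, 0]] - vert[triv[:, 1]]
    e2 = vert[triv[:, 0]] - vert[triv[:, 2]]
    return 0.5 * np.linalg.norm(np.cross(e1, e2), axis=1)


def normalize_mesh(vert, triv):
    """Center the vertices and scale the mesh to total surface area REF_AREA."""
    vert = vert - vert.mean(axis=0)
    total_area = triangle_areas(vert, triv).sum()
    # Area scales quadratically with the scale factor, hence the square root.
    return vert * np.sqrt(REF_AREA / total_area)
```

This is why meshes from differently scaled datasets end up directly comparable after preprocessing.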
71 | 72 | ### Model training 73 | 74 | If you stored the data in the directory given above, you can train the model by running: 75 | 76 | ```bash 77 | mkdir -p data/{checkpoint,out} 78 | python main_train.py 79 | ``` 80 | 81 | The trained models will be saved in a series of checkpoints at `/data/out/`. 82 | Otherwise, edit `param.py` to change the paths. 83 | 84 | ### Model testing 85 | 86 | Upon completion, evaluate the trained model with `main_test.py`. Specify the checkpoint folder name `` by running: 87 | 88 | ```bash 89 | python main_test.py 90 | ``` 91 | 92 | Here `` is any of the directories saved in `/data/out/`. 93 | This automatically saves correspondences and interpolations on the FAUST remeshed test set to `/data/out/`. 94 | For reference, on FAUST you should expect a validation error around 0.25 after 400 epochs. 95 | 96 | ## Contributing 97 | 98 | See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out. 99 | 100 | ## License 101 | 102 | NeuroMorph is MIT licensed, as described in the [LICENSE](LICENSE) file. 103 | NeuroMorph includes a few files from other open source projects, as further detailed in the same [LICENSE](LICENSE) file. 104 | -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import random 7 | import torch.utils.data 8 | import scipy.io 9 | from utils.shape_utils import * 10 | 11 | 12 | def input_to_batch(mat_dict): 13 | dict_out = dict() 14 | 15 | for attr in ["vert", "triv"]: 16 | if mat_dict[attr][0].dtype.kind in np.typecodes["AllInteger"]: 17 | dict_out[attr] = np.asarray(mat_dict[attr][0], dtype=np.int32) 18 | else: 19 | dict_out[attr] = np.asarray(mat_dict[attr][0], dtype=np.float32) 20 | 21 | return dict_out 22 | 23 | 24 | def batch_to_shape(batch): 25 | shape = Shape(batch["vert"].squeeze().to(device), batch["triv"].squeeze().to(device, torch.long) - 1) 26 | 27 | if "D" in batch: 28 | shape.D = batch["D"].squeeze().to(device) 29 | 30 | if "sub" in batch: 31 | shape.sub = batch["sub"] 32 | for i_s in range(len(shape.sub)): 33 | for i_p in range(len(shape.sub[i_s])): 34 | shape.sub[i_s][i_p] = shape.sub[i_s][i_p].to(device) 35 | 36 | if "idx" in batch: 37 | shape.samples = batch["idx"].squeeze().to(device, torch.long) 38 | 39 | if "vert_full" in batch: 40 | shape.vert_full = batch["vert_full"].squeeze().to(device) 41 | 42 | return shape 43 | 44 | 45 | class ShapeDatasetBase(torch.utils.data.Dataset): 46 | def __init__(self, axis=1): 47 | self.axis = axis 48 | self.num_shapes = None 49 | 50 | def dataset_name_str(self): 51 | raise NotImplementedError() 52 | 53 | def _get_file_from_folder(self, i, folder_path): 54 | shape_files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))] 55 | shape_files.sort() 56 | return os.path.join(folder_path, shape_files[i]) 57 | 58 | 59 | class ShapeDatasetInMemory(ShapeDatasetBase): 60 | def __init__(self, folder_path, num_shapes, axis=1, load_dist_mat=False, load_sub=False): 61 | super().__init__(axis) 62 | self.folder_path = folder_path 63 | self.num_shapes = num_shapes 64 | self.axis = axis 65 | self.load_dist_mat = load_dist_mat 66 | self.load_sub = load_sub 67 | 68 | self.data = [] 69 | 70 | self._init_data() 71 | 72 | def _init_data(self): 73 | for i 
in range(self.num_shapes): 74 | file_name = self._get_file(self._get_index(i)) 75 | load_data = scipy.io.loadmat(file_name) 76 | 77 | data_curr = input_to_batch(load_data["X"][0]) 78 | 79 | print("Loaded file ", file_name, "") 80 | 81 | if self.load_dist_mat: 82 | file_name = self._get_file_from_folder(self._get_index(i), 83 | os.path.join(self.folder_path, "distance_matrix")) 84 | load_dist = scipy.io.loadmat(file_name) 85 | load_dist["D"][load_dist["D"] > 1e2] = 2 86 | data_curr["D"] = np.asarray(load_dist["D"], dtype=np.float32) 87 | print("Loaded file ", file_name, "") 88 | 89 | self.data.append(data_curr) 90 | 91 | def _get_file(self, i): 92 | return self._get_file_from_folder(i, self.folder_path) 93 | 94 | def _get_index(self, i): 95 | return i 96 | 97 | def __getitem__(self, index): 98 | raise NotImplementedError() 99 | 100 | def __len__(self): 101 | raise NotImplementedError() 102 | 103 | def dataset_name_str(self): 104 | raise NotImplementedError() 105 | 106 | 107 | class ShapeDatasetCombine(ShapeDatasetInMemory): 108 | def __init__(self, folder_path, num_shapes, axis=1, load_dist_mat=False, load_sub=False): 109 | super().__init__(folder_path, num_shapes, axis, load_dist_mat, load_sub) 110 | self.num_pairs = num_shapes ** 2 111 | print("loaded", self.dataset_name_str(), "with", self.num_pairs, "pairs") 112 | 113 | def __getitem__(self, index): 114 | i1 = int(index / self.num_shapes) 115 | i2 = int(index % self.num_shapes) 116 | data_curr = dict() 117 | data_curr["X"] = self.data[i1] 118 | data_curr["Y"] = self.data[i2] 119 | data_curr["axis"] = self.axis 120 | return data_curr 121 | 122 | def __len__(self): 123 | return self.num_pairs 124 | 125 | def dataset_name_str(self): 126 | raise NotImplementedError() 127 | 128 | 129 | class ShapeDatasetCombineRemesh(ShapeDatasetBase): 130 | def __init__(self, dataset: ShapeDatasetCombine, remeshing_folder="remeshing_idx"): 131 | super().__init__(dataset.axis) 132 | self.dataset = dataset 133 | self.data = 
self.dataset.data 134 | self.num_shapes = self.dataset.num_shapes 135 | self.idx_arr_arr = None 136 | self.triv_arr_arr = None 137 | self.remeshing_folder = remeshing_folder 138 | self._init_mesh_info() 139 | print("Using the precomputed remeshings of", self.dataset_name_str()) 140 | 141 | def _init_mesh_info(self): 142 | self.idx_arr_arr = [] 143 | self.triv_arr_arr = [] 144 | 145 | for i in range(self.dataset.num_shapes): 146 | remesh_file = self._get_file_from_folder(self.dataset._get_index(i), os.path.join(self.dataset.folder_path, self.remeshing_folder)) 147 | mesh_info = scipy.io.loadmat(remesh_file) 148 | idx_arr = mesh_info["idx_arr"] 149 | triv_arr = mesh_info["triv_arr"] 150 | 151 | print("Loaded file ", remesh_file, "") 152 | 153 | self.idx_arr_arr.append(idx_arr) 154 | self.triv_arr_arr.append(triv_arr) 155 | 156 | def __getitem__(self, index): 157 | data_curr = self.dataset[index] 158 | 159 | i1 = int(index / self.dataset.num_shapes) 160 | i2 = int(index % self.dataset.num_shapes) 161 | 162 | idx_arr_x = self.idx_arr_arr[i1] 163 | idx_arr_y = self.idx_arr_arr[i2] 164 | 165 | triv_arr_x = self.triv_arr_arr[i1] 166 | triv_arr_y = self.triv_arr_arr[i2] 167 | 168 | i_mesh_x = random.randint(0, idx_arr_x.shape[0] - 1) 169 | i_mesh_y = random.randint(0, idx_arr_y.shape[0] - 1) 170 | 171 | data_new = dict() 172 | data_new["X"] = dict() 173 | data_new["Y"] = dict() 174 | 175 | idx_x = idx_arr_x[i_mesh_x][0].astype(np.int64) - 1 176 | idx_y = idx_arr_y[i_mesh_y][0].astype(np.int64) - 1 177 | 178 | data_new["X"]["vert_full"] = data_curr["X"]["vert"] 179 | data_new["Y"]["vert_full"] = data_curr["Y"]["vert"] 180 | data_new["X"]["idx"] = idx_x 181 | data_new["Y"]["idx"] = idx_y 182 | 183 | data_new["X"]["vert"] = data_curr["X"]["vert"][idx_x, :] 184 | data_new["Y"]["vert"] = data_curr["Y"]["vert"][idx_y, :] 185 | data_new["X"]["triv"] = triv_arr_x[i_mesh_x][0].astype(np.int64) 186 | data_new["Y"]["triv"] = triv_arr_y[i_mesh_y][0].astype(np.int64) 187 | 188 | if "D" in 
data_curr["X"]: 189 | idx_x = idx_x.squeeze() 190 | idx_y = idx_y.squeeze() 191 | data_new["X"]["D"] = data_curr["X"]["D"][:, idx_x][idx_x, :] 192 | data_new["Y"]["D"] = data_curr["Y"]["D"][:, idx_y][idx_y, :] 193 | 194 | if "sub" in data_curr["X"]: 195 | data_new["X"]["sub"] = data_curr["X"]["sub"] 196 | data_new["Y"]["sub"] = data_curr["Y"]["sub"] 197 | data_new["X"]["idx"] = idx_x 198 | data_new["Y"]["idx"] = idx_y 199 | 200 | data_new["axis"] = self.axis 201 | 202 | return data_new 203 | 204 | def __len__(self): 205 | return len(self.dataset) 206 | 207 | def dataset_name_str(self): 208 | return self.dataset.dataset_name_str() 209 | 210 | 211 | def get_faust_remeshed_folder(resolution): 212 | if resolution is None: 213 | folder_path = os.path.join(data_folder_faust_remeshed, "full") 214 | else: 215 | folder_path = os.path.join(data_folder_faust_remeshed, "sub_" + str(resolution)) 216 | return folder_path 217 | 218 | 219 | def get_shrec20_folder(resolution): 220 | if resolution is None: 221 | folder_path = os.path.join(data_folder_shrec20, "full") 222 | else: 223 | folder_path = os.path.join(data_folder_shrec20, "sub_" + str(resolution)) 224 | return folder_path 225 | 226 | 227 | class Faust_remeshed_train(ShapeDatasetCombine): 228 | def __init__(self, resolution, num_shapes=80, load_dist_mat=False, load_sub=False): 229 | self.resolution = resolution 230 | super().__init__(get_faust_remeshed_folder(resolution), num_shapes, load_dist_mat=load_dist_mat, load_sub=load_sub) 231 | 232 | def dataset_name_str(self): 233 | return "FAUST_remeshed_" + str(self.resolution) + "_train" 234 | 235 | 236 | class Faust_remeshed_test(ShapeDatasetCombine): 237 | def __init__(self, resolution, num_shapes=20, load_dist_mat=False, load_sub=False): 238 | self.resolution = resolution 239 | super().__init__(get_faust_remeshed_folder(resolution), num_shapes, load_dist_mat=load_dist_mat, load_sub=load_sub) 240 | 241 | def _get_index(self, i): 242 | return i+80 243 | 244 | def 
dataset_name_str(self): 245 | return "FAUST_remeshed_" + str(self.resolution) + "_test" 246 | 247 | 248 | class Mano_train(ShapeDatasetCombine): 249 | def __init__(self, resolution=None, num_shapes=100, load_dist_mat=False, load_sub=False): 250 | super().__init__(data_folder_mano_right, num_shapes, axis=2, load_dist_mat=load_dist_mat, load_sub=load_sub) 251 | 252 | def dataset_name_str(self): 253 | return "Mano_train" 254 | 255 | 256 | class Mano_test(ShapeDatasetCombine): 257 | def __init__(self, resolution=None, num_shapes=20, load_dist_mat=False, load_sub=False): 258 | super().__init__(data_folder_mano_test, num_shapes, axis=2, load_dist_mat=load_dist_mat, load_sub=load_sub) 259 | 260 | def dataset_name_str(self): 261 | return "Mano_test" 262 | 263 | 264 | class Shrec20_full(ShapeDatasetCombine): 265 | def __init__(self, resolution, num_shapes=14, load_dist_mat=False, load_sub=False): 266 | self.resolution = resolution 267 | super().__init__(get_shrec20_folder(resolution), num_shapes, load_dist_mat=load_dist_mat, load_sub=load_sub) 268 | 269 | def dataset_name_str(self): 270 | return "Shrec20_" + str(self.resolution) + "_train" 271 | 272 | 273 | 274 | if __name__ == "__main__": 275 | print("main of data.py") 276 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess_dataset.m: -------------------------------------------------------------------------------- 1 | function preprocess_dataset(dataset_path, shape_file_extension, resolution_sub) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree. 
7 | 8 | addpath(genpath(fileparts(mfilename('fullpath')))); 9 | 10 | if ~exist("shape_file_extension", "var") 11 | shape_file_extension = ".obj"; 12 | end 13 | 14 | if ~exist("resolution_sub", "var") 15 | resolution_sub = 2000; 16 | end 17 | 18 | disp("Converting the shape files to .mat files...") 19 | dataset_path_mat = convert_dataset(dataset_path, shape_file_extension); 20 | 21 | fprintf("Subsampling the shapes to a resolution of %d vertices...\n", resolution_sub) 22 | dataset_path_sub = reduce_folder_individual(dataset_path_mat, resolution_sub); 23 | 24 | disp("Creating the remeshed version of individual shapes...") 25 | create_remeshed_dataset(dataset_path_sub); 26 | 27 | disp("Calculating the geodesic distance matrices...") 28 | compute_dist_matrices(dataset_path_sub); 29 | end 30 | -------------------------------------------------------------------------------- /data_preprocessing/utils/compute_dist_matrices.m: -------------------------------------------------------------------------------- 1 | function compute_dist_matrices(shapes_dir) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree. 
7 | 8 | files = dir(fullfile(shapes_dir, "*.mat")); 9 | matrices_dir = fullfile(shapes_dir, "distance_matrix"); 10 | if ~isfolder(matrices_dir); mkdir(matrices_dir); end 11 | 12 | for i = 1:numel(files) 13 | fprintf(" Processing %d of %d\n", i, numel(files)) 14 | if exist(fullfile(matrices_dir, files(i).name)), continue; end 15 | 16 | S = load(fullfile(shapes_dir, files(i).name)); 17 | D = compute_dist_matrix(S); 18 | D = single(D); 19 | 20 | save(fullfile(matrices_dir, files(i).name), 'D'); 21 | end 22 | 23 | end 24 | -------------------------------------------------------------------------------- /data_preprocessing/utils/compute_dist_matrix.m: -------------------------------------------------------------------------------- 1 | function D = compute_dist_matrix(S, samples) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree. 7 | 8 | M = S.X; 9 | M.n = size(M.vert, 1); 10 | M.m = size(M.triv, 1); 11 | 12 | if nargin == 1 || isempty(samples) 13 | samples = 1:M.n; 14 | end 15 | 16 | if ~exist('fastmarchmex') 17 | % Use precomputed binaries from https://github.com/abbasloo/dnnAuto/ 18 | base_url = "https://github.com/abbasloo/dnnAuto/raw/37ce4320bc90a75b07a7ec1d862484d6576cec4c/preprocessing/isc"; 19 | for ext = {'mexa64', 'mexmaci64', 'mexw32', 'mexw64'} 20 | urlwrite(... 21 | base_url + "/fastmarchmex." + ext, ... 22 | "fastmarchmex." 
+ ext); 23 | end 24 | rehash 25 | end 26 | 27 | % Calls legacy fast marching code 28 | march = fastmarchmex('init', int32(M.triv - 1), double(M.vert(:, 1)), double(M.vert(:, 2)), double(M.vert(:, 3))); 29 | 30 | D = zeros(length(samples)); 31 | 32 | for i = 1:length(samples) 33 | source = inf(M.n, 1); 34 | source(samples(i)) = 0; 35 | d = fastmarchmex('march', march, double(source)); 36 | D(:, i) = d(samples); 37 | end 38 | 39 | fastmarchmex('deinit', march); 40 | 41 | % Ensures that the distance matrix is exactly symmetric 42 | D = 0.5 * (D + D'); 43 | end 44 | -------------------------------------------------------------------------------- /data_preprocessing/utils/compute_triangle_areas.m: -------------------------------------------------------------------------------- 1 | function area = compute_triangle_areas(X) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree. 7 | 8 | edge = cell(3, 1); 9 | 10 | for j = 1:3 11 | edge{j} = X.vert(X.triv(:, j), :); 12 | end 13 | 14 | area = 0.5 .* sqrt(sum(cross(edge{1} - edge{2}, edge{1} - edge{3}).^2, 2)); 15 | end 16 | -------------------------------------------------------------------------------- /data_preprocessing/utils/convert_dataset.m: -------------------------------------------------------------------------------- 1 | function folder_out = convert_dataset(folder_in, shape_file_ext) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree. 
7 | 8 | files = dir(fullfile(folder_in, "*" + shape_file_ext)); 9 | folder_out = fullfile(folder_in, "mat/"); 10 | if ~isfolder(folder_out); mkdir(folder_out); end 11 | 12 | for i = 1:length(files) 13 | 14 | fprintf(" Processing %d of %d\n", i, length(files)); 15 | 16 | file_in = fullfile(folder_in, files(i).name); 17 | [~, ~, file_in_ext] = fileparts(file_in); 18 | 19 | file_out = fullfile(folder_out, "shape_" + string(num2str(i - 1, '%03d')) + ".mat"); 20 | if exist(file_out, 'file'); continue; end 21 | 22 | switch file_in_ext 23 | case '.off' 24 | [X.vert, X.triv] = read_off(file_in); 25 | X.vert = X.vert'; 26 | X.triv = X.triv'; 27 | case '.obj' 28 | [X.vert, X.triv] = read_obj(file_in); X.vert = X.vert'; X.triv = X.triv'; % read_obj also returns 3 x n arrays 29 | case '.ply' 30 | [X.vert, X.triv] = read_ply(file_in); 31 | case '.mat' 32 | load(file_in, 'X'); 33 | end 34 | 35 | X.n = size(X.vert, 1); 36 | X.m = size(X.triv, 1); 37 | 38 | X.vert = X.vert - mean(X.vert, 1); 39 | 40 | refarea = 0.44; 41 | X.vert = X.vert ./ sqrt(sum(compute_triangle_areas(X))) .* sqrt(refarea); 42 | 43 | save(file_out, 'X'); 44 | 45 | end 46 | 47 | end 48 | -------------------------------------------------------------------------------- /data_preprocessing/utils/create_remeshed_collection.m: -------------------------------------------------------------------------------- 1 | function [idx_arr, triv_arr] = create_remeshed_collection(file_in, res_array) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree.
7 | 8 | S = load(file_in); 9 | 10 | if ~exist("res_array", "var") 11 | res_array = 200:2000; 12 | end 13 | 14 | idx_arr = cell(length(res_array), 1); 15 | triv_arr = cell(length(res_array), 1); 16 | 17 | for i_res = 1:length(res_array) 18 | num_vert_reduced = res_array(i_res); 19 | 20 | X_p.vertices = S.X.vert; 21 | X_p.faces = S.X.triv; 22 | 23 | ratio = num_vert_reduced / size(S.X.vert, 1); 24 | 25 | if ratio < 1 26 | X_p = reducepatch(X_p, ratio); 27 | end 28 | 29 | idx_arr{i_res} = knnsearch(S.X.vert, X_p.vertices); 30 | triv_arr{i_res} = X_p.faces; 31 | 32 | X_rec.vert = S.X.vert(idx_arr{i_res}, :); % X_rec is assembled for inspection only; it is not returned 33 | X_rec.triv = triv_arr{i_res}; 34 | end 35 | 36 | end 37 | -------------------------------------------------------------------------------- /data_preprocessing/utils/create_remeshed_dataset.m: -------------------------------------------------------------------------------- 1 | function create_remeshed_dataset(shapes_dir) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree. 7 | 8 | out_dir = fullfile(shapes_dir, "remeshing_idx"); 9 | if ~isfolder(out_dir); mkdir(out_dir); end 10 | 11 | files = dir(fullfile(shapes_dir, "*.mat")); 12 | 13 | for i = 1:length(files) 14 | fprintf(" Processing %d of %d\n", i, length(files)); 15 | file_in = fullfile(shapes_dir, files(i).name); 16 | file_out = fullfile(out_dir, files(i).name); 17 | if exist(file_out, 'file'); continue; end 18 | 19 | [idx_arr, triv_arr] = create_remeshed_collection(file_in); 20 | 21 | save(file_out, "idx_arr", "triv_arr"); 22 | end 23 | 24 | end 25 | -------------------------------------------------------------------------------- /data_preprocessing/utils/read_obj.m: -------------------------------------------------------------------------------- 1 | function [vertex, faces, normal] = read_obj(filename) 2 | 3 | % read_obj - load a .obj file.
4 | % 5 | % [vertex,faces,normal] = read_obj(filename); 6 | % 7 | % faces : list of triangle elements 8 | % vertex : node coordinates 9 | % normal : normal vector list 10 | % 11 | % Copyright (c) 2003 Gabriel Peyré 12 | 13 | fid = fopen(filename); 14 | 15 | if fid < 0 16 | error(['Cannot open ' filename '.']); 17 | end 18 | 19 | frewind(fid); 20 | a = fscanf(fid, '%c', 1); 21 | 22 | if strcmp(a, 'P') 23 | % This is the Montreal Neurological Institute (MNI) specific ASCII triangular mesh data structure. 24 | % For FreeSurfer software, a slightly different data input coding is 25 | % needed. It will be provided upon request. 26 | fscanf(fid, '%f', 5); 27 | n_points = fscanf(fid, '%i', 1); 28 | vertex = fscanf(fid, '%f', [3, n_points]); 29 | normal = fscanf(fid, '%f', [3, n_points]); 30 | n_faces = fscanf(fid, '%i', 1); 31 | fscanf(fid, '%i', 5 + n_faces); 32 | faces = fscanf(fid, '%i', [3, n_faces])' + 1; 33 | fclose(fid); 34 | return; 35 | end 36 | 37 | frewind(fid); 38 | vertex = []; 39 | faces = []; normal = []; % normals are not parsed in this branch, but the output must be defined 40 | 41 | while 1 42 | s = fgetl(fid); 43 | 44 | if ~ischar(s) 45 | break; 46 | end 47 | 48 | if ~isempty(s) && strcmp(s(1), 'f') 49 | % face 50 | faces(:, end + 1) = sscanf(s(3:end), '%d %d %d'); 51 | end 52 | 53 | if ~isempty(s) && strcmp(s(1), 'v') 54 | % vertex 55 | vertex(:, end + 1) = sscanf(s(3:end), '%f %f %f'); 56 | end 57 | 58 | end 59 | 60 | fclose(fid); 61 | 62 | end 63 | -------------------------------------------------------------------------------- /data_preprocessing/utils/read_off.m: -------------------------------------------------------------------------------- 1 | function [vertex, face] = read_off(filename) 2 | 3 | % read_off - read data from OFF file. 4 | % 5 | % [vertex,face] = read_off(filename); 6 | % 7 | % 'vertex' is a 'nb.vert x 3' array specifying the position of the vertices. 8 | % 'face' is a 'nb.face x 3' array specifying the connectivity of the mesh.
9 | % 10 | % Copyright (c) 2003 Gabriel Peyré 11 | 12 | fid = fopen(filename, 'r'); 13 | 14 | if fid == -1 15 | error('Can''t open the file.'); 16 | return; 17 | end 18 | 19 | str = fgets(fid); % -1 if eof 20 | 21 | if ~strcmp(str(1:3), 'OFF') 22 | error('The file is not a valid OFF one.'); 23 | end 24 | 25 | str = fgets(fid); 26 | [a, str] = strtok(str); nvert = str2num(a); 27 | [a, str] = strtok(str); nface = str2num(a); 28 | 29 | [A, cnt] = fscanf(fid, '%f %f %f', 3 * nvert); 30 | 31 | if cnt ~= 3 * nvert 32 | warning('Problem in reading vertices.'); 33 | end 34 | 35 | A = reshape(A, 3, cnt / 3); 36 | vertex = A; 37 | % read Face 1 1088 480 1022 38 | [A, cnt] = fscanf(fid, '%d %d %d %d\n', 4 * nface); 39 | 40 | if cnt ~= 4 * nface 41 | warning('Problem in reading faces.'); 42 | end 43 | 44 | A = reshape(A, 4, cnt / 4); 45 | face = A(2:4, :) + 1; 46 | 47 | fclose(fid); 48 | end 49 | -------------------------------------------------------------------------------- /data_preprocessing/utils/read_ply.m: -------------------------------------------------------------------------------- 1 | function [vertex, face] = read_ply(filename) 2 | 3 | % read_ply - read data from PLY file. 4 | % 5 | % [vertex,face] = read_ply(filename); 6 | % 7 | % 'vertex' is a 'nb.vert x 3' array specifying the position of the vertices. 8 | % 'face' is a 'nb.face x 3' array specifying the connectivity of the mesh. 9 | % 10 | % IMPORTANT: works only for triangular meshes. 11 | % 12 | % Copyright (c) 2003 Gabriel Peyré 13 | 14 | [d, c] = plyread(filename); 15 | 16 | vi = d.face.vertex_indices; 17 | nf = length(vi); 18 | face = zeros(nf, 3); 19 | 20 | for i = 1:nf 21 | face(i, :) = vi{i} + 1; 22 | end 23 | 24 | vertex = [d.vertex.x, d.vertex.y, d.vertex.z]; 25 | end 26 | 27 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 28 | function [Elements, varargout] = plyread(Path, Str) 29 | %PLYREAD Read a PLY 3D data file. 
30 | % [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file 31 | % FILENAME and returns a structure DATA. The fields in this structure 32 | % are defined by the PLY header; each element type is a field and each 33 | % element property is a subfield. If the file contains any comments, 34 | % they are returned in a cell string array COMMENTS. 35 | % 36 | % [TRI,PTS] = PLYREAD(FILENAME,'tri') or 37 | % [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex 38 | % and face data into triangular connectivity and vertex arrays. The 39 | % mesh can then be displayed using the TRISURF command. 40 | % 41 | % Note: This function is slow for large mesh files (+50K faces), 42 | % especially when reading data with list type properties. 43 | % 44 | % Example: 45 | % [Tri,Pts] = PLYREAD('cow.ply','tri'); 46 | % trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); 47 | % colormap(gray); axis equal; 48 | % 49 | % See also: PLYWRITE 50 | 51 | % Pascal Getreuer 2004 52 | 53 | [fid, Msg] = fopen(Path, 'rt'); % open file in read text mode 54 | 55 | if fid == -1, error(Msg); end 56 | 57 | Buf = fscanf(fid, '%s', 1); 58 | 59 | if ~strcmp(Buf, 'ply') 60 | fclose(fid); 61 | error('Not a PLY file.'); 62 | end 63 | 64 | %%% read header %%% 65 | 66 | Position = ftell(fid); 67 | Format = ''; 68 | NumComments = 0; 69 | Comments = {}; % for storing any file comments 70 | NumElements = 0; 71 | NumProperties = 0; 72 | Elements = []; % structure for holding the element data 73 | ElementCount = []; % number of each type of element in file 74 | PropertyTypes = []; % corresponding structure recording property types 75 | ElementNames = {}; % list of element names in the order they are stored in the file 76 | PropertyNames = []; % structure of lists of property names 77 | 78 | while 1 79 | Buf = fgetl(fid); % read one line from file 80 | BufRem = Buf; 81 | Token = {}; 82 | Count = 0; 83 | 84 | while ~isempty(BufRem) % split line into tokens 85 | [tmp, BufRem] = strtok(BufRem); 86 | 87 | if 
~isempty(tmp) 88 | Count = Count + 1; % count tokens 89 | Token{Count} = tmp; 90 | end 91 | 92 | end 93 | 94 | if Count % parse line 95 | 96 | switch lower(Token{1}) 97 | case 'format' % read data format 98 | 99 | if Count >= 2 100 | Format = lower(Token{2}); 101 | 102 | if Count == 3 & ~strcmp(Token{3}, '1.0') 103 | fclose(fid); 104 | error('Only PLY format version 1.0 supported.'); 105 | end 106 | 107 | end 108 | 109 | case 'comment' % read file comment 110 | NumComments = NumComments + 1; 111 | Comments{NumComments} = ''; 112 | 113 | for i = 2:Count 114 | Comments{NumComments} = [Comments{NumComments}, Token{i}, ' ']; 115 | end 116 | 117 | case 'element' % element name 118 | 119 | if Count >= 3 120 | 121 | if isfield(Elements, Token{2}) 122 | fclose(fid); 123 | error(['Duplicate element name, ''', Token{2}, '''.']); 124 | end 125 | 126 | NumElements = NumElements + 1; 127 | NumProperties = 0; 128 | Elements = setfield(Elements, Token{2}, []); 129 | PropertyTypes = setfield(PropertyTypes, Token{2}, []); 130 | ElementNames{NumElements} = Token{2}; 131 | PropertyNames = setfield(PropertyNames, Token{2}, {}); 132 | CurElement = Token{2}; 133 | ElementCount(NumElements) = str2double(Token{3}); 134 | 135 | if isnan(ElementCount(NumElements)) 136 | fclose(fid); 137 | error(['Bad element definition: ', Buf]); 138 | end 139 | 140 | else 141 | error(['Bad element definition: ', Buf]); 142 | end 143 | 144 | case 'property' % element property 145 | 146 | if ~isempty(CurElement) & Count >= 3 147 | NumProperties = NumProperties + 1; 148 | eval(['tmp=isfield(Elements.', CurElement, ',Token{Count});'], ... 149 | 'fclose(fid);error([''Error reading property: '',Buf])'); 150 | 151 | if tmp 152 | error(['Duplicate property name, ''', CurElement, '.', Token{2}, '''.']); 153 | end 154 | 155 | % add property subfield to Elements 156 | eval(['Elements.', CurElement, '.', Token{Count}, '=[];'], ... 
157 | 'fclose(fid);error([''Error reading property: '',Buf])'); 158 | % add property subfield to PropertyTypes and save type 159 | eval(['PropertyTypes.', CurElement, '.', Token{Count}, '={Token{2:Count-1}};'], ... 160 | 'fclose(fid);error([''Error reading property: '',Buf])'); 161 | % record property name order 162 | eval(['PropertyNames.', CurElement, '{NumProperties}=Token{Count};'], ... 163 | 'fclose(fid);error([''Error reading property: '',Buf])'); 164 | else 165 | fclose(fid); 166 | 167 | if isempty(CurElement) 168 | error(['Property definition without element definition: ', Buf]); 169 | else 170 | error(['Bad property definition: ', Buf]); 171 | end 172 | 173 | end 174 | 175 | case 'end_header' % end of header, break from while loop 176 | break; 177 | end 178 | 179 | end 180 | 181 | end 182 | 183 | %%% set reading for specified data format %%% 184 | 185 | if isempty(Format) 186 | warning('Data format unspecified, assuming ASCII.'); 187 | Format = 'ascii'; 188 | end 189 | 190 | switch Format 191 | case 'ascii' 192 | Format = 0; 193 | case 'binary_little_endian' 194 | Format = 1; 195 | case 'binary_big_endian' 196 | Format = 2; 197 | otherwise 198 | fclose(fid); 199 | error(['Data format ''', Format, ''' not supported.']); 200 | end 201 | 202 | if ~Format 203 | Buf = fscanf(fid, '%f'); % read the rest of the file as ASCII data 204 | BufOff = 1; 205 | else 206 | % reopen the file in read binary mode 207 | fclose(fid); 208 | 209 | if Format == 1 210 | fid = fopen(Path, 'r', 'ieee-le.l64'); % little endian 211 | else 212 | fid = fopen(Path, 'r', 'ieee-be.l64'); % big endian 213 | end 214 | 215 | % find the end of the header again (using ftell on the old handle doesn't give the correct position) 216 | BufSize = 8192; 217 | Buf = [blanks(10), char(fread(fid, BufSize, 'uchar')')]; 218 | i = []; 219 | tmp = -11; 220 | 221 | while isempty(i) 222 | i = findstr(Buf, ['end_header', 13, 10]); % look for end_header + CR/LF 223 | i = [i, findstr(Buf, ['end_header', 10])]; % 
look for end_header + LF 224 | 225 | if isempty(i) 226 | tmp = tmp + BufSize; 227 | Buf = [Buf(BufSize + 1:BufSize + 10), char(fread(fid, BufSize, 'uchar')')]; 228 | end 229 | 230 | end 231 | 232 | % seek to just after the line feed 233 | fseek(fid, i + tmp + 11 + (Buf(i + 10) == 13), -1); 234 | end 235 | 236 | %%% read element data %%% 237 | 238 | % PLY and MATLAB data types (for fread) 239 | PlyTypeNames = {'char', 'uchar', 'short', 'ushort', 'int', 'uint', 'float', 'double', ... 240 | 'char8', 'uchar8', 'short16', 'ushort16', 'int32', 'uint32', 'float32', 'double64'}; 241 | MatlabTypeNames = {'schar', 'uchar', 'int16', 'uint16', 'int32', 'uint32', 'single', 'double'}; 242 | SizeOf = [1, 1, 2, 2, 4, 4, 4, 8]; % size in bytes of each type 243 | 244 | for i = 1:NumElements 245 | % get current element property information 246 | eval(['CurPropertyNames=PropertyNames.', ElementNames{i}, ';']); 247 | eval(['CurPropertyTypes=PropertyTypes.', ElementNames{i}, ';']); 248 | NumProperties = size(CurPropertyNames, 2); 249 | 250 | % fprintf('Reading %s...\n',ElementNames{i}); 251 | 252 | if ~Format % % % read ASCII data % % % 253 | 254 | for j = 1:NumProperties 255 | Token = getfield(CurPropertyTypes, CurPropertyNames{j}); 256 | 257 | if strcmpi(Token{1}, 'list') 258 | Type(j) = 1; 259 | else 260 | Type(j) = 0; 261 | end 262 | 263 | end 264 | 265 | % parse buffer 266 | if ~any(Type) 267 | % no list types 268 | Data = reshape(Buf(BufOff:BufOff + ElementCount(i) * NumProperties - 1), NumProperties, ElementCount(i))'; 269 | BufOff = BufOff + ElementCount(i) * NumProperties; 270 | else 271 | ListData = cell(NumProperties, 1); 272 | 273 | for k = 1:NumProperties 274 | ListData{k} = cell(ElementCount(i), 1); 275 | end 276 | 277 | % list type 278 | for j = 1:ElementCount(i) 279 | 280 | for k = 1:NumProperties 281 | 282 | if ~Type(k) 283 | Data(j, k) = Buf(BufOff); 284 | BufOff = BufOff + 1; 285 | else 286 | tmp = Buf(BufOff); 287 | ListData{k}{j} = Buf(BufOff + (1:tmp))'; 288 | 
BufOff = BufOff + tmp + 1; 289 | end 290 | 291 | end 292 | 293 | end 294 | 295 | end 296 | 297 | else %%% read binary data %%% 298 | % translate PLY data type names to MATLAB data type names 299 | ListFlag = 0; % = 1 if there is a list type 300 | SameFlag = 1; % = 1 if all types are the same 301 | 302 | for j = 1:NumProperties 303 | Token = getfield(CurPropertyTypes, CurPropertyNames{j}); 304 | 305 | if ~strcmp(Token{1}, 'list') % non-list type 306 | tmp = rem(strmatch(Token{1}, PlyTypeNames, 'exact') - 1, 8) + 1; 307 | 308 | if ~isempty(tmp) 309 | TypeSize(j) = SizeOf(tmp); 310 | Type{j} = MatlabTypeNames{tmp}; 311 | TypeSize2(j) = 0; 312 | Type2{j} = ''; 313 | 314 | SameFlag = SameFlag & strcmp(Type{1}, Type{j}); 315 | else 316 | fclose(fid); 317 | error(['Unknown property data type, ''', Token{1}, ''', in ', ... 318 | ElementNames{i}, '.', CurPropertyNames{j}, '.']); 319 | end 320 | 321 | else % list type 322 | 323 | if length(Token) == 3 324 | ListFlag = 1; 325 | SameFlag = 0; 326 | tmp = rem(strmatch(Token{2}, PlyTypeNames, 'exact') - 1, 8) + 1; 327 | tmp2 = rem(strmatch(Token{3}, PlyTypeNames, 'exact') - 1, 8) + 1; 328 | 329 | if ~isempty(tmp) & ~isempty(tmp2) 330 | TypeSize(j) = SizeOf(tmp); 331 | Type{j} = MatlabTypeNames{tmp}; 332 | TypeSize2(j) = SizeOf(tmp2); 333 | Type2{j} = MatlabTypeNames{tmp2}; 334 | else 335 | fclose(fid); 336 | error(['Unknown property data type, ''list ', Token{2}, ' ', Token{3}, ''', in ', ... 
337 | ElementNames{i}, '.', CurPropertyNames{j}, '.']); 338 | end 339 | 340 | else 341 | fclose(fid); 342 | error(['Invalid list syntax in ', ElementNames{i}, '.', CurPropertyNames{j}, '.']); 343 | end 344 | 345 | end 346 | 347 | end 348 | 349 | % read file 350 | if ~ListFlag 351 | 352 | if SameFlag 353 | % no list types, all the same type (fast) 354 | Data = fread(fid, [NumProperties, ElementCount(i)], Type{1})'; 355 | else 356 | % no list types, mixed type 357 | Data = zeros(ElementCount(i), NumProperties); 358 | 359 | for j = 1:ElementCount(i) 360 | 361 | for k = 1:NumProperties 362 | Data(j, k) = fread(fid, 1, Type{k}); 363 | end 364 | 365 | end 366 | 367 | end 368 | 369 | else 370 | ListData = cell(NumProperties, 1); 371 | 372 | for k = 1:NumProperties 373 | ListData{k} = cell(ElementCount(i), 1); 374 | end 375 | 376 | if NumProperties == 1 377 | BufSize = 512; 378 | SkipNum = 4; 379 | j = 0; 380 | 381 | % list type, one property (fast if lists are usually the same length) 382 | while j < ElementCount(i) 383 | Position = ftell(fid); 384 | % read in BufSize count values, assuming all counts = SkipNum 385 | [Buf, BufSize] = fread(fid, BufSize, Type{1}, SkipNum * TypeSize2(1)); 386 | Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum 387 | fseek(fid, Position + TypeSize(1), -1); % seek back to after first count 388 | 389 | if isempty(Miss) % all counts are SkipNum 390 | Buf = fread(fid, [SkipNum, BufSize], [int2str(SkipNum), '*', Type2{1}], TypeSize(1))'; 391 | fseek(fid, -TypeSize(1), 0); % undo last skip 392 | 393 | for k = 1:BufSize 394 | ListData{1}{j + k} = Buf(k, :); 395 | end 396 | 397 | j = j + BufSize; 398 | BufSize = floor(1.5 * BufSize); 399 | else 400 | 401 | if Miss(1) > 1 % some counts are SkipNum 402 | Buf2 = fread(fid, [SkipNum, Miss(1) - 1], [int2str(SkipNum), '*', Type2{1}], TypeSize(1))'; 403 | 404 | for k = 1:Miss(1) - 1 405 | ListData{1}{j + k} = Buf2(k, :); 406 | end 407 | 408 | j = j + k; 409 | end 410 | 411 | % read in the 
list with the missed count 412 | SkipNum = Buf(Miss(1)); 413 | j = j + 1; 414 | ListData{1}{j} = fread(fid, [1, SkipNum], Type2{1}); 415 | BufSize = ceil(0.6 * BufSize); 416 | end 417 | 418 | end 419 | 420 | else 421 | % list type(s), multiple properties (slow) 422 | Data = zeros(ElementCount(i), NumProperties); 423 | 424 | for j = 1:ElementCount(i) 425 | 426 | for k = 1:NumProperties 427 | 428 | if isempty(Type2{k}) 429 | Data(j, k) = fread(fid, 1, Type{k}); 430 | else 431 | tmp = fread(fid, 1, Type{k}); 432 | ListData{k}{j} = fread(fid, [1, tmp], Type2{k}); 433 | end 434 | 435 | end 436 | 437 | end 438 | 439 | end 440 | 441 | end 442 | 443 | end 444 | 445 | % put data into Elements structure 446 | for k = 1:NumProperties 447 | 448 | if (~Format & ~Type(k)) | (Format & isempty(Type2{k})) 449 | eval(['Elements.', ElementNames{i}, '.', CurPropertyNames{k}, '=Data(:,k);']); 450 | else 451 | eval(['Elements.', ElementNames{i}, '.', CurPropertyNames{k}, '=ListData{k};']); 452 | end 453 | 454 | end 455 | 456 | end 457 | 458 | clear Data ListData; 459 | fclose(fid); 460 | 461 | if (nargin > 1 & strcmpi(Str, 'Tri')) | nargout > 2 462 | % find vertex element field 463 | Name = {'vertex', 'Vertex', 'point', 'Point', 'pts', 'Pts'}; 464 | Names = []; 465 | 466 | for i = 1:length(Name) 467 | 468 | if any(strcmp(ElementNames, Name{i})) 469 | Names = getfield(PropertyNames, Name{i}); 470 | Name = Name{i}; 471 | break; 472 | end 473 | 474 | end 475 | 476 | if any(strcmp(Names, 'x')) & any(strcmp(Names, 'y')) & any(strcmp(Names, 'z')) 477 | eval(['varargout{1}=[Elements.', Name, '.x,Elements.', Name, '.y,Elements.', Name, '.z];']); 478 | else 479 | varargout{1} = zeros(1, 3); 480 | end 481 | 482 | varargout{2} = Elements; 483 | varargout{3} = Comments; 484 | Elements = []; 485 | 486 | % find face element field 487 | Name = {'face', 'Face', 'poly', 'Poly', 'tri', 'Tri'}; 488 | Names = []; 489 | 490 | for i = 1:length(Name) 491 | 492 | if any(strcmp(ElementNames, Name{i})) 493 | 
Names = getfield(PropertyNames, Name{i}); 494 | Name = Name{i}; 495 | break; 496 | end 497 | 498 | end 499 | 500 | if ~isempty(Names) 501 | % find vertex indices property subfield 502 | PropertyName = {'vertex_indices', 'vertex_indexes', 'vertex_index', 'indices', 'indexes'}; 503 | 504 | for i = 1:length(PropertyName) 505 | 506 | if any(strcmp(Names, PropertyName{i})) 507 | PropertyName = PropertyName{i}; 508 | break; 509 | end 510 | 511 | end 512 | 513 | if ~iscell(PropertyName) 514 | % convert face index lists to triangular connectivity 515 | eval(['FaceIndices=varargout{2}.', Name, '.', PropertyName, ';']); 516 | N = length(FaceIndices); 517 | Elements = zeros(N * 2, 3); 518 | Extra = 0; 519 | 520 | for k = 1:N 521 | Elements(k, :) = FaceIndices{k}(1:3); 522 | 523 | for j = 4:length(FaceIndices{k}) 524 | Extra = Extra + 1; 525 | Elements(N + Extra, :) = [Elements(k, [1, j - 1]), FaceIndices{k}(j)]; 526 | end 527 | 528 | end 529 | 530 | Elements = Elements(1:N + Extra, :) + 1; 531 | end 532 | 533 | end 534 | 535 | else 536 | varargout{1} = Comments; 537 | end 538 | 539 | end 540 | -------------------------------------------------------------------------------- /data_preprocessing/utils/reduce_folder_individual.m: -------------------------------------------------------------------------------- 1 | function folder_out = reduce_folder_individual(shapes_dir, num_vert_reduced) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree. 
7 | 8 | files = dir(fullfile(shapes_dir, "*.mat")); 9 | folder_out = fullfile(shapes_dir, "sub_" + string(num_vert_reduced)); 10 | if ~exist(folder_out, 'file'); mkdir(folder_out); end 11 | 12 | for i = 1:length(files) 13 | fprintf(" Processing %d of %d\n", i, length(files)); 14 | 15 | file_curr = fullfile(shapes_dir, files(i).name); 16 | [~, name, ~] = fileparts(file_curr); 17 | 18 | red_file = fullfile(folder_out, string(name) + ".mat"); 19 | if exist(red_file, 'file'); continue; end 20 | 21 | S = load(file_curr); 22 | 23 | refarea = 0.44; 24 | S.X.vert = S.X.vert - mean(S.X.vert, 1); 25 | S.X.vert = S.X.vert ./ sqrt(sum(compute_triangle_areas(S.X))) .* sqrt(refarea); 26 | 27 | [samples, faces] = subsample_shape(S.X, num_vert_reduced); 28 | 29 | X = struct; 30 | X.triv = faces; 31 | X.vert = S.X.vert(samples, :); 32 | 33 | save(red_file, 'X') 34 | end 35 | 36 | end 37 | 38 | function [samples, faces] = subsample_shape(X, num_vert_reduced) 39 | X_p.vertices = X.vert; 40 | X_p.faces = X.triv; 41 | 42 | ratio = num_vert_reduced / size(X.vert, 1); 43 | 44 | if ratio < 1 45 | X_p = reducepatch(X_p, ratio); 46 | end 47 | 48 | samples = knnsearch(X.vert, X_p.vertices); 49 | faces = X_p.faces; 50 | end 51 | -------------------------------------------------------------------------------- /figures/splash.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/neuromorph/e6e6433537bc42d9150bcad931d6fbd4aea3f751/figures/splash.png -------------------------------------------------------------------------------- /main_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from main_train import * 7 | import sys 8 | import os 9 | 10 | 11 | def save_sequence( 12 | folder_name, file_name, vert_sequence, shape_x, shape_y, time_elapsed=0 13 | ): 14 | """Saves an interpolation sequence to a .mat file""" 15 | 16 | if not os.path.isdir(folder_name): 17 | os.makedirs(folder_name, exist_ok=True) 18 | 19 | vert_x = shape_x.vert.detach().cpu().numpy() 20 | vert_y = shape_y.vert.detach().cpu().numpy() 21 | triv_x = shape_x.triv.detach().cpu().numpy() + 1 22 | triv_y = shape_y.triv.detach().cpu().numpy() + 1 23 | 24 | if type(shape_x.samples) is list: 25 | samples = np.array(shape_x.samples, dtype=np.float32) 26 | else: 27 | samples = shape_x.samples.detach().cpu().numpy() 28 | 29 | vert_sequence = vert_sequence.detach().cpu().numpy() 30 | 31 | if shape_x.mahal_cov_mat is None: 32 | mat_dict = { 33 | "vert_x": vert_x, 34 | "vert_y": vert_y, 35 | "triv_x": triv_x, 36 | "triv_y": triv_y, 37 | "vert_sequence": vert_sequence, 38 | "time_elapsed": time_elapsed, 39 | "samples": samples, 40 | } 41 | else: 42 | shape_x.mahal_cov_mat = shape_x.mahal_cov_mat.detach().cpu().numpy() 43 | mat_dict = { 44 | "vert_x": vert_x, 45 | "vert_y": vert_y, 46 | "triv_x": triv_x, 47 | "triv_y": triv_y, 48 | "vert_sequence": vert_sequence, 49 | "time_elapsed": time_elapsed, 50 | "samples": samples, 51 | "mahal_cov_mat": shape_x.mahal_cov_mat, 52 | } 53 | 54 | scipy.io.savemat(os.path.join(folder_name, file_name), mat_dict) 55 | 56 | 57 | def plot_curr_shape(vert, triv_x): 58 | fig = plt.figure(1) 59 | ax = fig.add_subplot(111, projection="3d") 60 | ax.plot_trisurf( 61 | vert[:, 0], 62 | vert[:, 1], 63 | vert[:, 2], 64 | triangles=triv_x, 65 | cmap="viridis", 66 | linewidths=0.2, 67 | ) 68 | ax.set_xlim(-0.4, 0.4) 69 | ax.set_ylim(-0.4, 0.4) 70 | ax.set_zlim(-0.4, 0.4) 71 | 72 | 73 | def save_seq_collection_hard_correspondences( 74 | interp_module, shape_x_out, shape_y_out, points_out, res_name 75 | ): 76 | """Save test correspondences on a shape""" 77 | 78 | if 
not os.path.isdir(os.path.join(data_folder_out, res_name)): 79 | os.makedirs(os.path.join(data_folder_out, res_name), exist_ok=True) 80 | 81 | if not os.path.isdir(os.path.join(data_folder_out, res_name, "corrs")): 82 | os.makedirs(os.path.join(data_folder_out, res_name, "corrs"), exist_ok=True) 83 | 84 | print("Saving", len(points_out), "sequences in", os.path.join(data_folder_out, res_name), "...") 85 | for i in range(len(points_out)): 86 | vert_x = shape_x_out[i].vert.detach().cpu().numpy() 87 | vert_y = shape_y_out[i].vert.detach().cpu().numpy() 88 | triv_x = shape_x_out[i].triv.detach().cpu().numpy() 89 | triv_y = shape_y_out[i].triv.detach().cpu().numpy() 90 | 91 | plot_curr_shape(vert_x, triv_x) 92 | plt.savefig( 93 | os.path.join( 94 | data_folder_out, 95 | res_name, 96 | "seq_" + str(i).zfill(3) + "_" + str(0).zfill(3) + "_x.png", 97 | ) 98 | ) 99 | plt.clf() 100 | 101 | for j in range(points_out[i].shape[2]): 102 | vert = points_out[i][:, :, j].detach().cpu().numpy() 103 | plot_curr_shape(vert, triv_x) 104 | plt.savefig( 105 | os.path.join( 106 | data_folder_out, 107 | res_name, 108 | "seq_" + str(i).zfill(3) + "_" + str(j + 1).zfill(3) + ".png", 109 | ) 110 | ) 111 | plt.clf() 112 | 113 | plot_curr_shape(vert_y, triv_y) 114 | plt.savefig( 115 | os.path.join( 116 | data_folder_out, 117 | res_name, 118 | "seq_" 119 | + str(i).zfill(3) 120 | + "_" 121 | + str(points_out[i].shape[2] + 1).zfill(3) 122 | + "_y.png", 123 | ) 124 | ) 125 | plt.clf() 126 | 127 | file_name_mat = "seq_" + str(i).zfill(3) + ".mat" 128 | save_sequence( 129 | os.path.join(data_folder_out, res_name), 130 | file_name_mat, 131 | points_out[i], 132 | shape_x_out[i], 133 | shape_y_out[i], 134 | ) 135 | 136 | corr_out = interp_module.match(shape_x_out[i], shape_y_out[i]) 137 | assignment = corr_out.argmax(dim=1).detach().cpu().numpy() 138 | assignmentinv = corr_out.argmax(dim=0).detach().cpu().numpy() 139 | file_name_mat_corr = os.path.join( 140 | data_folder_out, res_name, "corrs", 
"corrs_" + str(i).zfill(3) + ".mat" 141 | ) 142 | scipy.io.savemat( 143 | file_name_mat_corr, 144 | { 145 | "assignment": assignment + 1, 146 | "assignmentinv": assignmentinv + 1, 147 | "X": {"vert": vert_x, "triv": triv_x + 1}, 148 | "Y": {"vert": vert_y, "triv": triv_y + 1}, 149 | }, 150 | ) 151 | 152 | 153 | def run_test(time_stamp_chkpt=None): 154 | time_stamp_arr = [time_stamp_chkpt] 155 | 156 | module_arr = None 157 | 158 | hyp_param = HypParam() 159 | 160 | dataset_val = Faust_remeshed_test(2000) 161 | 162 | hyp_param.rot_mod = 0 163 | 164 | for i_time, time_stamp in enumerate(time_stamp_arr): 165 | 166 | if module_arr is not None: 167 | hyp_param.in_mod = module_arr[i_time] 168 | 169 | print( 170 | "Evaluating time_stamp", 171 | time_stamp, 172 | "with the dataset", 173 | dataset_val.dataset_name_str(), 174 | ) 175 | 176 | interpol = create_interpol( 177 | dataset=dataset_val, 178 | dataset_val=dataset_val, 179 | time_stamp=time_stamp, 180 | hyp_param=hyp_param, 181 | ) 182 | 183 | interpol.load_self(save_path(folder_str=time_stamp)) 184 | 185 | interpol.interp_module.param.num_timesteps = 1 186 | shape_x_out, shape_y_out, points_out = interpol.test(dataset_val) 187 | interpol.interp_module = interpol.interp_module.to(device_cpu) 188 | save_seq_collection_hard_correspondences( 189 | interpol.interp_module, 190 | shape_x_out, 191 | shape_y_out, 192 | points_out, 193 | time_stamp 194 | + "__" 195 | + dataset_val.dataset_name_str() 196 | + "__epoch" 197 | + str(interpol.i_epoch + 1) 198 | + "_steps" 199 | + str(interpol.interp_module.param.num_timesteps), 200 | ) 201 | 202 | 203 | if __name__ == "__main__": 204 | run_test(sys.argv[1]) 205 | -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from model.interpolation_net import * 7 | from utils.arap_interpolation import * 8 | from data.data import * 9 | 10 | 11 | class HypParam(ParamBase): 12 | def __init__(self): 13 | self.increase_thresh = 300 14 | 15 | self.method = "arap" 16 | self.in_mod = get_in_mod() 17 | 18 | self.load_dist_mat = True 19 | self.load_sub = True 20 | 21 | 22 | def get_in_mod(): 23 | in_mod = InterpolationModGeoEC 24 | 25 | return in_mod 26 | 27 | 28 | def create_interpol( 29 | dataset, 30 | dataset_val=None, 31 | folder_weights_load=None, 32 | time_stamp=None, 33 | param=None, 34 | hyp_param=None, 35 | ): 36 | 37 | if time_stamp is None: 38 | time_stamp = get_timestr() 39 | 40 | if param is None: 41 | param = NetParam() 42 | 43 | if hyp_param is None: 44 | hyp_param = HypParam() 45 | 46 | hyp_param.print_self() 47 | 48 | interpol_energy = ArapInterpolationEnergy() 49 | 50 | interpol_module = hyp_param.in_mod(interpol_energy, param).to(device) 51 | 52 | preproc_mods = [] 53 | 54 | settings_module = SettingsFaust(increase_thresh=hyp_param.increase_thresh) 55 | 56 | preproc_mods.append(PreprocessRotateSame(dataset.axis)) 57 | 58 | interpol = InterpolNet( 59 | interpol_module, 60 | dataset, 61 | dataset_val=dataset_val, 62 | time_stamp=time_stamp, 63 | preproc_mods=preproc_mods, 64 | settings_module=settings_module, 65 | ) 66 | 67 | if folder_weights_load is not None: 68 | interpol.load_self(save_path(folder_str=folder_weights_load)) 69 | 70 | interpol.i_epoch = 0 71 | 72 | return interpol 73 | 74 | 75 | def remesh_individual(dataset): 76 | return ShapeDatasetCombineRemesh(dataset) 77 | 78 | 79 | def create_dataset( 80 | dataset_cls, 81 | resolution, 82 | num_shapes=None, 83 | load_dist_mat=True, 84 | remeshing_fct=None, 85 | load_sub=False, 86 | ): 87 | if num_shapes is None: 88 | dataset = dataset_cls( 89 | resolution, 
load_dist_mat=load_dist_mat, load_sub=load_sub 90 | ) 91 | else: 92 | dataset = dataset_cls( 93 | resolution, num_shapes, load_dist_mat=load_dist_mat, load_sub=load_sub 94 | ) 95 | 96 | if remeshing_fct is not None: 97 | dataset = remeshing_fct(dataset) 98 | 99 | return dataset 100 | 101 | 102 | def start_train(dataset, dataset_val=None, folder_weights_load=None): 103 | interpol = create_interpol( 104 | dataset, dataset_val=dataset_val, folder_weights_load=folder_weights_load 105 | ) 106 | 107 | interpol.train() 108 | 109 | return interpol 110 | 111 | 112 | def train_main(): 113 | hyp_param = HypParam() 114 | 115 | # FAUST_remeshed: 116 | dataset = create_dataset( 117 | Faust_remeshed_train, 118 | 2000, 119 | None, 120 | hyp_param.load_dist_mat, 121 | remesh_individual, 122 | hyp_param.load_sub, 123 | ) 124 | dataset_val = create_dataset( 125 | Faust_remeshed_test, 126 | 2000, 127 | None, 128 | hyp_param.load_dist_mat, 129 | remesh_individual, 130 | hyp_param.load_sub, 131 | ) 132 | 133 | start_train(dataset, dataset_val) 134 | 135 | 136 | if __name__ == "__main__": 137 | train_main() 138 | -------------------------------------------------------------------------------- /model/interpolation_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
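`create_dataset` in `main_train.py` builds a dataset and then optionally wraps it with a remeshing function (`remesh_individual`, which returns a `ShapeDatasetCombineRemesh`). A stripped-down sketch of that build-then-wrap factory pattern, using hypothetical `FakeDataset`/`RemeshWrapper` stand-ins rather than the real classes:

```python
class FakeDataset:
    """Hypothetical stand-in for a dataset class such as Faust_remeshed_train."""
    def __init__(self, resolution, load_dist_mat=True, load_sub=False):
        self.resolution = resolution
        self.load_dist_mat = load_dist_mat
        self.load_sub = load_sub

class RemeshWrapper:
    """Hypothetical stand-in for ShapeDatasetCombineRemesh."""
    def __init__(self, base):
        self.base = base

def build_dataset(dataset_cls, resolution, remeshing_fct=None, **kwargs):
    # Build first, then optionally wrap -- mirrors create_dataset above.
    dataset = dataset_cls(resolution, **kwargs)
    if remeshing_fct is not None:
        dataset = remeshing_fct(dataset)
    return dataset

ds = build_dataset(FakeDataset, 2000, remeshing_fct=RemeshWrapper, load_sub=True)
```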
5 | 6 | import torch.nn.functional as F 7 | from utils.arap_interpolation import * 8 | from data.data import * 9 | from model.layers import * 10 | 11 | 12 | class NetParam(ParamBase): 13 | """Base class for hyperparameters of interpolation methods""" 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.lr = 1e-4 18 | self.num_it = 600 19 | self.batch_size = 16 20 | self.num_timesteps = 0 21 | self.hidden_dim = 128 22 | self.lambd = 1 23 | self.lambd_geo = 50 24 | 25 | self.log_freq = 10 26 | self.val_freq = 10 27 | 28 | self.log = True 29 | 30 | 31 | class InterpolationModBase(torch.nn.Module): 32 | def __init__(self, interp_energy: InterpolationEnergy): 33 | super().__init__() 34 | self.interp_energy = interp_energy 35 | 36 | def get_pred(self, shape_x, shape_y): 37 | raise NotImplementedError() 38 | 39 | def compute_loss(self, shape_x, shape_y, point_pred_arr): 40 | raise NotImplementedError() 41 | 42 | def forward(self, shape_x, shape_y): 43 | point_pred_arr = self.get_pred(shape_x, shape_y) 44 | return self.compute_loss(shape_x, shape_y, point_pred_arr) 45 | 46 | 47 | class InterpolationModGeoEC(InterpolationModBase): 48 | def __init__(self, interp_energy: InterpolationEnergy, param=NetParam()): 49 | super().__init__(interp_energy) 50 | self.param = param 51 | param.print_self() 52 | self.rn_ec = ResnetECPos(c_dim=3, dim=7, hidden_dim=param.hidden_dim) 53 | self.feat_module = ResnetECPos( 54 | c_dim=param.hidden_dim, dim=6, hidden_dim=param.hidden_dim 55 | ) 56 | print("Uses module 'InterpolationModGeoEC'") 57 | self.Pi = None 58 | self.Pi_inv = None 59 | 60 | def get_pred(self, shape_x, shape_y, update_corr=True): 61 | if update_corr: 62 | self.match(shape_x, shape_y) 63 | 64 | step_size = 1 / (self.param.num_timesteps + 1) 65 | timesteps = step_size + torch.arange(0, 1, step_size, device=device).unsqueeze( 66 | 1 67 | ).unsqueeze( 68 | 2 69 | ) # [T, 1, 1] 70 | timesteps_up = timesteps * ( 71 | torch.as_tensor([0, 0, 0, 0, 0, 0, 1], device=device, 
dtype=torch.float) 72 | .unsqueeze(0) 73 | .unsqueeze(1) 74 | ) # [T, 1, 7] 75 | 76 | points_in = torch.cat( 77 | ( 78 | shape_x.vert, 79 | torch.mm(self.Pi, shape_y.vert) - shape_x.vert, 80 | my_zeros((shape_x.vert.shape[0], 1)), 81 | ), 82 | dim=1, 83 | ).unsqueeze( 84 | 0 85 | ) # [1, n, 7] 86 | points_in = points_in + timesteps_up 87 | 88 | edge_index = shape_x.get_edge_index() 89 | 90 | displacement = my_zeros([points_in.shape[0], points_in.shape[1], 3]) 91 | for i in range(points_in.shape[0]): 92 | displacement[i, :, :] = self.rn_ec(points_in[i, :, :], edge_index) 93 | # the previous three lines used to support batchwise processing in torch-geometric but are now deprecated: 94 | # displacement = self.rn_ec(points_in, edge_index) # [T, n, 3] 95 | 96 | point_pred_arr = shape_x.vert.unsqueeze(0) + displacement * timesteps 97 | point_pred_arr = point_pred_arr.permute([1, 2, 0]) 98 | return point_pred_arr 99 | 100 | def compute_loss(self, shape_x, shape_y, point_pred_arr, n_normalize=201.0): 101 | 102 | E_x_0 = self.interp_energy.forward_single( 103 | shape_x.vert, point_pred_arr[:, :, 0], shape_x 104 | ) + self.interp_energy.forward_single( 105 | point_pred_arr[:, :, 0], shape_x.vert, shape_x 106 | ) 107 | 108 | lambda_align = n_normalize / shape_x.vert.shape[0] 109 | E_align = ( 110 | lambda_align 111 | * self.param.lambd 112 | * ( 113 | (torch.mm(self.Pi, shape_y.vert) - point_pred_arr[:, :, -1]).norm() ** 2 114 | + ( 115 | shape_y.vert - torch.mm(self.Pi_inv, point_pred_arr[:, :, -1]) 116 | ).norm() 117 | ** 2 118 | ) 119 | ) 120 | 121 | if shape_x.D is None: 122 | E_geo = my_tensor(0) 123 | elif self.param.lambd_geo == 0: 124 | E_geo = my_tensor(0) 125 | else: 126 | E_geo = ( 127 | self.param.lambd_geo 128 | * ( 129 | ( 130 | torch.mm(torch.mm(self.Pi, shape_y.D), self.Pi.transpose(0, 1)) 131 | - shape_x.D 132 | ) 133 | ** 2 134 | ).mean() 135 | ) 136 | 137 | E = E_x_0 + E_align + E_geo 138 | 139 | for i in range(self.param.num_timesteps): 140 | E_x = 
self.interp_energy.forward_single( 141 | point_pred_arr[:, :, i], point_pred_arr[:, :, i + 1], shape_x 142 | ) 143 | E_y = self.interp_energy.forward_single( 144 | point_pred_arr[:, :, i + 1], point_pred_arr[:, :, i], shape_x 145 | ) 146 | 147 | E = E + E_x + E_y 148 | 149 | return E, [E - E_align - E_geo, E_align, E_geo] 150 | 151 | def match(self, shape_x, shape_y): 152 | feat_x = torch.cat((shape_x.vert, shape_x.get_normal()), dim=1) 153 | feat_y = torch.cat((shape_y.vert, shape_y.get_normal()), dim=1) 154 | 155 | feat_x = self.feat_module(feat_x, shape_x.get_edge_index()) 156 | feat_y = self.feat_module(feat_y, shape_y.get_edge_index()) 157 | 158 | feat_x = feat_x / feat_x.norm(dim=1, keepdim=True) 159 | feat_y = feat_y / feat_y.norm(dim=1, keepdim=True) 160 | 161 | D = torch.mm(feat_x, feat_y.transpose(0, 1)) 162 | 163 | sigma = 1e2 164 | self.Pi = F.softmax(D * sigma, dim=1) 165 | self.Pi_inv = F.softmax(D * sigma, dim=0).transpose(0, 1) 166 | 167 | return self.Pi 168 | 169 | 170 | ################################################################################################ 171 | 172 | 173 | class InterpolNet: 174 | def __init__( 175 | self, 176 | interp_module: InterpolationModBase, 177 | dataset, 178 | dataset_val=None, 179 | time_stamp=None, 180 | preproc_mods=[], 181 | settings_module=None, 182 | ): 183 | super().__init__() 184 | self.time_stamp = time_stamp 185 | self.interp_module = interp_module 186 | self.settings_module = settings_module 187 | self.preproc_mods = preproc_mods 188 | self.dataset = dataset 189 | if dataset is not None: 190 | self.train_loader = torch.utils.data.DataLoader( 191 | dataset, batch_size=1, shuffle=True 192 | ) 193 | self.dataset_val = dataset_val 194 | self.i_epoch = 0 195 | self.optimizer = torch.optim.Adam( 196 | self.interp_module.parameters(), lr=self.interp_module.param.lr 197 | ) 198 | 199 | def train(self): 200 | print("start training ...") 201 | 202 | self.interp_module.train() 203 | 204 | while self.i_epoch < 
self.interp_module.param.num_it: 205 | tot_loss = 0 206 | tot_loss_comp = None 207 | 208 | self.update_settings() 209 | 210 | for i, data in enumerate(self.train_loader): 211 | shape_x = batch_to_shape(data["X"]) 212 | shape_y = batch_to_shape(data["Y"]) 213 | 214 | shape_x, shape_y = self.preprocess(shape_x, shape_y) 215 | 216 | loss, loss_comp = self.interp_module(shape_x, shape_y) 217 | 218 | loss.backward() 219 | 220 | if (i + 1) % self.interp_module.param.batch_size == 0 and i < len( 221 | self.train_loader 222 | ) - 1: 223 | self.optimizer.step() 224 | self.optimizer.zero_grad() 225 | 226 | if tot_loss_comp is None: 227 | tot_loss_comp = [ 228 | loss_comp[i].detach() / self.dataset.__len__() 229 | for i in range(len(loss_comp)) 230 | ] 231 | else: 232 | tot_loss_comp = [ 233 | tot_loss_comp[i] 234 | + loss_comp[i].detach() / self.dataset.__len__() 235 | for i in range(len(loss_comp)) 236 | ] 237 | 238 | tot_loss += loss.detach() / self.dataset.__len__() 239 | 240 | self.optimizer.step() 241 | self.optimizer.zero_grad() 242 | 243 | print( 244 | "epoch {:04d}, loss = {:.5f} (arap: {:.5f}, reg: {:.5f}, geo: {:.5f}), reserved memory={}MB".format( 245 | self.i_epoch, 246 | tot_loss, 247 | tot_loss_comp[0], 248 | tot_loss_comp[1], 249 | tot_loss_comp[2], 250 | torch.cuda.memory_reserved(0) // (1024 ** 2), 251 | ) 252 | ) 253 | 254 | if self.time_stamp is not None: 255 | if (self.i_epoch + 1) % self.interp_module.param.log_freq == 0: 256 | self.save_self() 257 | if (self.i_epoch + 1) % self.interp_module.param.val_freq == 0: 258 | self.test(self.dataset_val) 259 | 260 | self.i_epoch += 1 261 | 262 | def test(self, dataset, compute_val_loss=True): 263 | test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False) 264 | shape_x_out = [] 265 | shape_y_out = [] 266 | points_out = [] 267 | 268 | tot_loss_val = 0 269 | 270 | for i, data in enumerate(test_loader): 271 | shape_x = batch_to_shape(data["X"]) 272 | shape_y = batch_to_shape(data["Y"]) 273 | 
274 | shape_x, shape_y = self.preprocess(shape_x, shape_y) 275 | 276 | point_pred = self.interp_module.get_pred(shape_x, shape_y) 277 | 278 | if compute_val_loss: 279 | loss, _ = self.interp_module.compute_loss(shape_x, shape_y, point_pred) 280 | tot_loss_val += loss.detach() / len(dataset) 281 | 282 | shape_x.detach_cpu() 283 | shape_y.detach_cpu() 284 | point_pred = point_pred.detach().cpu() 285 | 286 | points_out.append(point_pred) 287 | shape_x_out.append(shape_x) 288 | shape_y_out.append(shape_y) 289 | 290 | if compute_val_loss: 291 | print("Validation loss = ", tot_loss_val) 292 | 293 | return shape_x_out, shape_y_out, points_out 294 | 295 | def preprocess(self, shape_x, shape_y): 296 | for pre in self.preproc_mods: 297 | shape_x, shape_y = pre.preprocess(shape_x, shape_y) 298 | return shape_x, shape_y 299 | 300 | def update_settings(self): 301 | if self.settings_module is not None: 302 | self.settings_module.update(self.interp_module, self.i_epoch) 303 | 304 | def save_self(self): 305 | folder_path = save_path(self.time_stamp) 306 | 307 | if not os.path.isdir(folder_path): 308 | os.mkdir(folder_path) 309 | 310 | ckpt_last_name = "ckpt_last.pth" 311 | ckpt_last_path = os.path.join(folder_path, ckpt_last_name) 312 | 313 | ckpt_name = "ckpt_ep{}.pth".format(self.i_epoch) 314 | ckpt_path = os.path.join(folder_path, ckpt_name) 315 | 316 | self.save_chkpt(ckpt_path) 317 | self.save_chkpt(ckpt_last_path) 318 | 319 | def save_chkpt(self, ckpt_path): 320 | ckpt = { 321 | "i_epoch": self.i_epoch, 322 | "interp_module": self.interp_module.state_dict(), 323 | "optimizer_state_dict": self.optimizer.state_dict(), 324 | "par": self.interp_module.param.__dict__, 325 | } 326 | 327 | torch.save(ckpt, ckpt_path) 328 | 329 | def load_self(self, folder_path, num_epoch=None): 330 | if num_epoch is None: 331 | ckpt_name = "ckpt_last.pth" 332 | ckpt_path = os.path.join(folder_path, ckpt_name) 333 | else: 334 | ckpt_name = "ckpt_ep{}.pth".format(num_epoch) 335 | ckpt_path = 
os.path.join(folder_path, ckpt_name) 336 | 337 | self.load_chkpt(ckpt_path) 338 | 339 | if num_epoch is None: 340 | print("Loaded model from ", folder_path, " with the latest weights") 341 | else: 342 | print( 343 | "Loaded model from ", 344 | folder_path, 345 | " with the weights from epoch ", 346 | num_epoch, 347 | ) 348 | 349 | def load_chkpt(self, ckpt_path): 350 | ckpt = torch.load(ckpt_path, map_location=device) 351 | 352 | self.i_epoch = ckpt["i_epoch"] 353 | self.interp_module.load_state_dict(ckpt["interp_module"]) 354 | 355 | if "par" in ckpt: 356 | self.interp_module.param.from_dict(ckpt["par"]) 357 | self.interp_module.param.print_self() 358 | 359 | if "optimizer_state_dict" in ckpt: 360 | self.optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 361 | 362 | self.interp_module.train() 363 | 364 | 365 | class SettingsBase: 366 | def update(self, interp_module, i_epoch): 367 | raise NotImplementedError() 368 | 369 | 370 | class SettingsFaust(SettingsBase): 371 | def __init__(self, increase_thresh): 372 | super().__init__() 373 | self.increase_thresh = increase_thresh 374 | print("Uses settings module 'SettingsFaust'") 375 | 376 | def update(self, interp_module, i_epoch): 377 | if i_epoch < self.increase_thresh: # 0 - 300 378 | return 379 | elif i_epoch < self.increase_thresh * 1.5: # 300 - 450 380 | num_t = 1 381 | elif i_epoch < self.increase_thresh * 1.75: # 450 - 525 382 | num_t = 3 383 | else: # > 525 384 | num_t = 7 385 | 386 | interp_module.param.num_timesteps = num_t 387 | print("Set the # of timesteps to ", num_t) 388 | 389 | interp_module.param.lambd_geo = 0 390 | print("Deactivated the geodesic loss") 391 | 392 | 393 | class PreprocessBase: 394 | def preprocess(self, shape_x, shape_y): 395 | raise NotImplementedError() 396 | 397 | 398 | class PreprocessRotateBase(PreprocessBase): 399 | def __init__(self, axis=1): 400 | super().__init__() 401 | self.axis = axis 402 | 403 | def _create_rot_matrix(self, alpha): 404 | return 
create_rotation_matrix(alpha, self.axis) 405 | 406 | def _rand_rot(self): 407 | alpha = torch.rand(1) * 360 408 | return self._create_rot_matrix(alpha) 409 | 410 | def rot_sub(self, shape, r): 411 | if shape.sub is not None: 412 | for i_p in range(len(shape.sub[0])): 413 | shape.sub[0][i_p][0, :, :] = torch.mm(shape.sub[0][i_p][0, :, :], r) 414 | 415 | if shape.vert_full is not None: 416 | shape.vert_full = torch.mm(shape.vert_full, r) 417 | 418 | return shape 419 | 420 | def preprocess(self, shape_x, shape_y): 421 | raise NotImplementedError() 422 | 423 | 424 | class PreprocessRotate(PreprocessRotateBase): 425 | def __init__(self, axis=1): 426 | super().__init__(axis) 427 | print("Uses preprocessing module 'PreprocessRotate'") 428 | 429 | def preprocess(self, shape_x, shape_y): 430 | r_x = self._rand_rot() 431 | r_y = self._rand_rot() 432 | shape_x.vert = torch.mm(shape_x.vert, r_x) 433 | shape_y.vert = torch.mm(shape_y.vert, r_y) 434 | shape_x = self.rot_sub(shape_x, r_x) 435 | shape_y = self.rot_sub(shape_y, r_y) 436 | return shape_x, shape_y 437 | 438 | 439 | class PreprocessRotateSame(PreprocessRotateBase): 440 | def __init__(self, axis=1): 441 | super().__init__(axis) 442 | print("Uses preprocessing module 'PreprocessRotateSame'") 443 | 444 | def preprocess(self, shape_x, shape_y): 445 | r = self._rand_rot() 446 | shape_x.vert = torch.mm(shape_x.vert, r) 447 | shape_y.vert = torch.mm(shape_y.vert, r) 448 | 449 | shape_x = self.rot_sub(shape_x, r) 450 | shape_y = self.rot_sub(shape_y, r) 451 | return shape_x, shape_y 452 | 453 | 454 | class PreprocessRotateAugment(PreprocessRotateBase): 455 | def __init__(self, axis=1, sigma=0.3): 456 | super().__init__(axis) 457 | self.sigma = sigma 458 | print( 459 | "Uses preprocessing module 'PreprocessRotateAugment' with sigma =", 460 | self.sigma, 461 | ) 462 | 463 | def preprocess(self, shape_x, shape_y): 464 | r_x = self._rand_rot_augment() 465 | r_y = self._rand_rot_augment() 466 | shape_x.vert = 
torch.mm(shape_x.vert, r_x) 467 | shape_y.vert = torch.mm(shape_y.vert, r_y) 468 | 469 | shape_x = self.rot_sub(shape_x, r_x) 470 | shape_y = self.rot_sub(shape_y, r_y) 471 | return shape_x, shape_y 472 | 473 | # computes a pair of approximately similar rotation matrices 474 | def _rand_rot_augment(self): 475 | rot = torch.randn( 476 | [3, 3], dtype=torch.float, device=device 477 | ) * self.sigma + my_eye(3) 478 | 479 | U, _, V = torch.svd(rot, compute_uv=True) 480 | 481 | rot = torch.mm(U, V.transpose(0, 1)) 482 | 483 | return rot 484 | 485 | 486 | if __name__ == "__main__": 487 | print("main of interpolation_net.py") 488 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | from torch_geometric.nn import EdgeConv 3 | from model.layers_onet import ResnetBlockFC 4 | from utils.base_tools import * 5 | 6 | 7 | def MLP(channels): 8 | return torch.nn.Sequential(*[ 9 | torch.nn.Sequential(torch.nn.Linear(channels[i - 1], channels[i]), 10 | torch.nn.ReLU(), torch.nn.BatchNorm1d(channels[i])) 11 | for i in range(1, len(channels)) 12 | ]) 13 | 14 | 15 | def maxpool(x, dim=-1, keepdim=False): 16 | out, _ = x.max(dim=dim, keepdim=keepdim) 17 | return out 18 | 19 | 20 | class ResnetECPos(torch.nn.Module): 21 | def __init__(self, c_dim=128, dim=3, hidden_dim=128): 22 | # def __init__(self, c_dim=128, dim=3, hidden_dim=256): 23 | super().__init__() 24 | self.c_dim = c_dim 25 | 26 | self.fc_pos = torch.nn.Linear(dim, 2*hidden_dim) 27 | self.block_0 = EdgeConv(ResnetBlockFC(4*hidden_dim, hidden_dim)) 28 | self.block_1 = EdgeConv(ResnetBlockFC(4*hidden_dim+2*dim, hidden_dim)) 29 | self.block_2 = EdgeConv(ResnetBlockFC(4*hidden_dim+2*dim, hidden_dim)) 30 | self.block_3 = EdgeConv(ResnetBlockFC(4*hidden_dim+2*dim, hidden_dim)) 31 | self.block_4 = EdgeConv(ResnetBlockFC(4*hidden_dim+2*dim, hidden_dim)) 32 | self.fc_c = 
torch.nn.Linear(hidden_dim, c_dim) 33 | 34 | self.actvn = torch.nn.ReLU() 35 | self.pool = maxpool 36 | 37 | def forward(self, p, edge_index): 38 | net = self.fc_pos(p) 39 | net = self.block_0(net, edge_index) 40 | 41 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 42 | net = torch.cat([net, pooled, p], dim=1) 43 | 44 | net = self.block_1(net, edge_index) 45 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 46 | net = torch.cat([net, pooled, p], dim=1) 47 | 48 | net = self.block_2(net, edge_index) 49 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 50 | net = torch.cat([net, pooled, p], dim=1) 51 | 52 | net = self.block_3(net, edge_index) 53 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 54 | net = torch.cat([net, pooled, p], dim=1) 55 | 56 | net = self.block_4(net, edge_index) 57 | 58 | c = self.fc_c(self.actvn(net)) 59 | 60 | return c 61 | 62 | 63 | if __name__ == "__main__": 64 | print("main of layers.py") 65 | -------------------------------------------------------------------------------- /model/layers_onet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | # Resnet Blocks 11 | class ResnetBlockFC(nn.Module): 12 | ''' Fully connected ResNet Block class. 
13 | 14 | Args: 15 | size_in (int): input dimension 16 | size_out (int): output dimension 17 | size_h (int): hidden dimension 18 | ''' 19 | 20 | def __init__(self, size_in, size_out=None, size_h=None): 21 | super().__init__() 22 | # Attributes 23 | if size_out is None: 24 | size_out = size_in 25 | 26 | if size_h is None: 27 | size_h = min(size_in, size_out) 28 | 29 | self.size_in = size_in 30 | self.size_h = size_h 31 | self.size_out = size_out 32 | # Submodules 33 | self.fc_0 = nn.Linear(size_in, size_h) 34 | self.fc_1 = nn.Linear(size_h, size_out) 35 | self.actvn = nn.ReLU() 36 | 37 | if size_in == size_out: 38 | self.shortcut = None 39 | else: 40 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 41 | # Initialization 42 | nn.init.zeros_(self.fc_1.weight) 43 | 44 | def forward(self, x): 45 | net = self.fc_0(self.actvn(x)) 46 | dx = self.fc_1(self.actvn(net)) 47 | 48 | if self.shortcut is not None: 49 | x_s = self.shortcut(x) 50 | else: 51 | x_s = x 52 | 53 | return x_s + dx 54 | 55 | 56 | if __name__ == "__main__": 57 | print("main of layers_onet.py") 58 | -------------------------------------------------------------------------------- /param.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
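`ResnetBlockFC` above zero-initializes `fc_1.weight`, so at initialization the residual branch contributes only the (small, randomly initialized) `fc_1` bias and the block starts close to its skip connection. A NumPy sketch of the same forward pass; if the bias is also zero, the block is exactly the identity whenever `size_in == size_out` (no shortcut projection):

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def resnet_block_fc(x, W0, b0, W1, b1, W_s=None):
    # Mirrors ResnetBlockFC.forward: two pre-activation linear layers plus a skip.
    net = relu(x) @ W0.T + b0
    dx = relu(net) @ W1.T + b1
    x_s = x if W_s is None else x @ W_s.T
    return x_s + dx

# Zeroed second layer (weight AND bias) -> residual branch vanishes.
x = np.array([[0.5, -1.0, 2.0]])
W0 = np.ones((4, 3)); b0 = np.zeros(4)
W1 = np.zeros((3, 4)); b1 = np.zeros(3)
out = resnet_block_fc(x, W0, b0, W1, b1)
```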
5 | 6 | import torch 7 | import pathlib 8 | import os 9 | from datetime import datetime 10 | 11 | path_curr = str(pathlib.Path(__file__).parent.absolute()) 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | device_cpu = torch.device('cpu') 15 | 16 | data_folder_faust_remeshed = "data/meshes/FAUST_r/mat" 17 | data_folder_mano_right = "data/meshes/MANO_right/mat" 18 | data_folder_mano_test = "data/meshes/MANO_test/mat" 19 | data_folder_shrec20 = "data/meshes/SHREC_r/mat" 20 | 21 | chkpt_folder = "data/checkpoint" 22 | data_folder_out = "data/out" 23 | 24 | 25 | def get_timestr(): 26 | now = datetime.now() 27 | time_stamp = now.strftime("%Y_%m_%d__%H_%M_%S") 28 | print("Time stamp: ", time_stamp) 29 | return time_stamp 30 | 31 | 32 | def save_path(folder_str=None): 33 | if folder_str is None: 34 | folder_str = get_timestr() 35 | 36 | folder_path_models = os.path.join(chkpt_folder, folder_str) 37 | print("Checkpoint path: ", folder_path_models) 38 | return folder_path_models 39 | -------------------------------------------------------------------------------- /utils/arap_interpolation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
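`get_timestr` in `param.py` names checkpoint folders with the pattern `%Y_%m_%d__%H_%M_%S`. Because every field is zero-padded, the stamps sort lexicographically in chronological order, and they round-trip through `strptime`:

```python
from datetime import datetime

FMT = "%Y_%m_%d__%H_%M_%S"  # same format string as in get_timestr above

stamp = datetime(2021, 5, 4, 13, 7, 9).strftime(FMT)
parsed = datetime.strptime(stamp, FMT)
```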
5 | 6 | from utils.arap_potential import * 7 | from utils.interpolation_base import * 8 | 9 | 10 | class ArapInterpolationEnergy(InterpolationEnergyHessian): 11 | """The interpolation method based on Sorkine et al., 2007""" 12 | 13 | def __init__(self): 14 | super().__init__() 15 | 16 | # override 17 | def forward_single(self, vert_new, vert_ref, shape_i): 18 | E_arap = arap_energy_exact(vert_new, vert_ref, shape_i.get_neigh()) 19 | return E_arap 20 | 21 | # override 22 | def get_hessian(self, shape_i): 23 | return shape_i.get_neigh_hessian() 24 | 25 | 26 | if __name__ == "__main__": 27 | print("main of arap_interpolation.py") 28 | -------------------------------------------------------------------------------- /utils/arap_potential.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch.nn.functional 7 | from utils.base_tools import * 8 | from param import * 9 | 10 | 11 | def arap_exact(vert_diff_t, vert_diff_0, neigh, n_vert): 12 | S_neigh = torch.bmm(vert_diff_t.unsqueeze(2), vert_diff_0.unsqueeze(1)) 13 | 14 | S = my_zeros([n_vert, 3, 3]) 15 | 16 | S = torch.index_add(S, 0, neigh[:, 0], S_neigh) 17 | S = torch.index_add(S, 0, neigh[:, 1], S_neigh) 18 | 19 | U, _, V = torch.svd(S.cpu(), compute_uv=True) 20 | 21 | U = U.to(device) 22 | V = V.to(device) 23 | 24 | R = torch.bmm(U, V.transpose(1, 2)) 25 | 26 | Sigma = my_ones((R.shape[0], 1, 3)) 27 | Sigma[:, :, 2] = torch.det(R).unsqueeze(1) 28 | 29 | R = torch.bmm(U * Sigma, V.transpose(1, 2)) 30 | 31 | return R 32 | 33 | 34 | def arap_energy_exact(vert_t, vert_0, neigh, lambda_reg_len=1e-6): 35 | n_vert = vert_t.shape[0] 36 | 37 | vert_diff_t = vert_t[neigh[:, 0], :] - vert_t[neigh[:, 1], :] 38 | vert_diff_0 = vert_0[neigh[:, 0], :] - vert_0[neigh[:, 1], :] 39 | 40 | R_t = arap_exact(vert_diff_t, vert_diff_0, neigh, n_vert) 41 | 42 | R_neigh_t = 0.5 * ( 43 | torch.index_select(R_t, 0, neigh[:, 0]) 44 | + torch.index_select(R_t, 0, neigh[:, 1]) 45 | ) 46 | 47 | vert_diff_0_rot = torch.bmm(R_neigh_t, vert_diff_0.unsqueeze(2)).squeeze() 48 | acc_t_neigh = vert_diff_t - vert_diff_0_rot 49 | 50 | E_arap = acc_t_neigh.norm() ** 2 + lambda_reg_len * (vert_t - vert_0).norm() ** 2 51 | 52 | return E_arap 53 | 54 | 55 | if __name__ == "__main__": 56 | print("main of arap_potential.py") 57 | -------------------------------------------------------------------------------- /utils/base_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Marvin Eisenberger. 3 | 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
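`arap_exact` above fits per-vertex rotations with the classical SVD (Kabsch/Procrustes) step: accumulate a covariance from edge differences, take its SVD, and flip the last singular direction when the determinant is negative so the result is a proper rotation rather than a reflection (the role of `Sigma` in the code). A NumPy sketch of that fit for a single rotation, rather than the batched `torch.bmm` version:

```python
import numpy as np

def fit_rotation(P, Q):
    """Best-fit rotation R with Q ~= P @ R.T, via SVD with det correction."""
    H = P.T @ Q                          # covariance of the two point sets
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    D = np.diag([1.0, 1.0, d])           # avoid reflections, like Sigma above
    return Vt.T @ D @ U.T

# Recover a known rotation (30 degrees about the z-axis) from rotated points.
theta = np.pi / 6
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
P = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0], [1.0, 1.0, 1.0]])
Q = P @ R.T
R_est = fit_rotation(P, Q)
```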
6 | 7 | import torch 8 | import math 9 | from param import device 10 | 11 | 12 | triv_to_edge = torch.as_tensor( 13 | [[-1, 1, 0], [0, -1, 1], [1, 0, -1]], dtype=torch.float32, device=device 14 | ) 15 | edge_norm_to_proj = torch.as_tensor( 16 | [[-1, 1, 1], [1, -1, 1], [1, 1, -1]], dtype=torch.float32, device=device 17 | ) 18 | hat_matrix = torch.as_tensor( 19 | [ 20 | [[0, 0, 0], [0, 0, 1], [0, -1, 0]], 21 | [[0, 0, -1], [0, 0, 0], [1, 0, 0]], 22 | [[0, 1, 0], [-1, 0, 0], [0, 0, 0]], 23 | ], 24 | device=device, 25 | dtype=torch.float32, 26 | ) 27 | 28 | 29 | def my_speye(n, offset=1): 30 | V = my_ones([n]) 31 | I = torch.arange(n) 32 | I = torch.cat((I.unsqueeze(0), I.unsqueeze(0)), 0).to( 33 | dtype=torch.long, device=device 34 | ) 35 | V = V * offset 36 | M = torch.sparse.FloatTensor(I, V, (n, n)) 37 | return M 38 | 39 | 40 | def create_rotation_matrix(alpha, axis): 41 | alpha = alpha / 180 * math.pi 42 | c = torch.cos(alpha) 43 | s = torch.sin(alpha) 44 | rot_2d = torch.as_tensor([[c, -s], [s, c]], dtype=torch.float, device=device) 45 | rot_3d = my_eye(3) 46 | idx = [i for i in range(3) if i != axis] 47 | for i in range(len(idx)): 48 | for j in range(len(idx)): 49 | rot_3d[idx[i], idx[j]] = rot_2d[i, j] 50 | return rot_3d 51 | 52 | 53 | def mat_to_rot(m): 54 | u, _, v = torch.svd(m) 55 | rot = torch.mm(u, v.transpose(0, 1)) 56 | s = my_ones([1, 3]) 57 | s[0, -1] = rot.det() 58 | rot = torch.mm(u * s, v.transpose(0, 1)) 59 | return rot 60 | 61 | 62 | def my_ones(shape): 63 | return torch.ones(shape, device=device, dtype=torch.float32) 64 | 65 | 66 | def my_zeros(shape): 67 | return torch.zeros(shape, device=device, dtype=torch.float32) 68 | 69 | 70 | def my_eye(n): 71 | return torch.eye(n, device=device, dtype=torch.float32) 72 | 73 | 74 | def my_tensor(t): 75 | return torch.as_tensor(t, device=device, dtype=torch.float32) 76 | 77 | 78 | def hat_op(v): 79 | assert v.shape[1] == 3, "wrong input dimensions" 80 | 81 | w = my_zeros([3, 3, 3]) 82 | 83 | w[0, 1, 2] = 
-1 84 | w[0, 2, 1] = 1 85 | w[1, 0, 2] = 1 86 | w[1, 2, 0] = -1 87 | w[2, 0, 1] = -1 88 | w[2, 1, 0] = 1 89 | 90 | v = v.transpose(0, 1).unsqueeze(2).unsqueeze(3) 91 | w = w.unsqueeze(1) 92 | 93 | M = v * w 94 | M = M.sum(0) 95 | 96 | return M 97 | 98 | 99 | def cross_prod(u, v): 100 | if len(v.shape) == 2: 101 | v = v.unsqueeze(2) 102 | return torch.bmm(hat_op(u), v) 103 | 104 | 105 | def batch_trace(m): 106 | m = (m * my_eye(m.shape[1]).unsqueeze(0)).sum(dim=(1, 2)) 107 | return m.unsqueeze(1).unsqueeze(2) 108 | 109 | 110 | def soft_relu(m, eps=1e-7): 111 | return torch.relu(m) + eps 112 | 113 | 114 | def dist_mat(x, y, inplace=True): 115 | d = torch.mm(x, y.transpose(0, 1)) 116 | v_x = torch.sum(x ** 2, 1).unsqueeze(1) 117 | v_y = torch.sum(y ** 2, 1).unsqueeze(0) 118 | d *= -2 119 | if inplace: 120 | d += v_x 121 | d += v_y 122 | else: 123 | d = d + v_x 124 | d = d + v_y 125 | 126 | return d 127 | 128 | 129 | if __name__ == "__main__": 130 | print("main of base_tools.py") 131 | -------------------------------------------------------------------------------- /utils/interpolation_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Marvin Eisenberger. 3 | 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
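`dist_mat` in `base_tools.py` computes pairwise squared distances via the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, with optional in-place adds to limit intermediate allocations. A NumPy sketch of the same expansion, checked against hand-computed values:

```python
import numpy as np

def dist_mat_np(x, y):
    """Pairwise squared distances, same expansion as dist_mat above."""
    d = -2.0 * (x @ y.T)                 # cross term
    d += (x ** 2).sum(1)[:, None]        # row norms, broadcast down columns
    d += (y ** 2).sum(1)[None, :]        # column norms, broadcast across rows
    return d

x = np.array([[0.0, 0.0], [1.0, 2.0]])
y = np.array([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]])
D = dist_mat_np(x, y)
```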
6 | 7 | import torch.optim 8 | from torch.nn import Parameter 9 | from utils.shape_utils import * 10 | from param import device_cpu 11 | from scipy.sparse.linalg import spsolve 12 | from scipy.sparse import kron, spdiags, csr_matrix, lil_matrix, eye 13 | from utils.base_tools import * 14 | 15 | 16 | class ParamBase: 17 | """Base class for parameters""" 18 | 19 | def from_dict(self, d): 20 | for key in d: 21 | if hasattr(self, key): 22 | self.__setattr__(key, d[key]) 23 | 24 | def print_self(self): 25 | print("parameters: ") 26 | p_d = self.__dict__ 27 | for k in p_d: 28 | print(k, ": ", p_d[k], " ", end="") 29 | print("") 30 | 31 | 32 | class Param(ParamBase): 33 | """Base class for hyperparameters of interpolation methods""" 34 | 35 | def __init__(self): 36 | self.lr = 0.001 37 | self.num_it = 20 38 | self.scales = [200, 500, 1000, 2000, 12500] 39 | self.odd_multisteps = False 40 | self.num_timesteps = 3 41 | 42 | self.log = True 43 | 44 | 45 | # ----------------------------------------------------------------------------------- 46 | 47 | 48 | class InterpolationEnergy: 49 | """Base class for local distortion interpolation potentials""" 50 | 51 | def __init__(self): 52 | super().__init__() 53 | 54 | def forward_single(self, vert_new, vert_ref, shape_i): 55 | raise NotImplementedError() 56 | 57 | 58 | class InterpolationEnergyHessian(InterpolationEnergy): 59 | """Abstract for interpolation potentials that also allow for second order optimization""" 60 | 61 | def __init__(self): 62 | super().__init__() 63 | 64 | def get_hessian(self, shape_i): 65 | raise NotImplementedError() 66 | 67 | 68 | # ----------------------------------------------------------------------------------- 69 | 70 | 71 | class InterpolationModuleBase(torch.nn.Module): 72 | """Base class for an interpolation method""" 73 | 74 | def __init__(self): 75 | super().__init__() 76 | 77 | def forward(self): 78 | raise NotImplementedError() 79 | 80 | 81 | class 
InterpolationModuleMultiscaleBase(InterpolationModuleBase): 82 | """Base class for multi-scale interpolation methods """ 83 | 84 | def __init__(self): 85 | super().__init__() 86 | 87 | def step_multiscale(self, i_scale): 88 | raise NotImplementedError() 89 | 90 | 91 | class InterpolationModule(InterpolationModuleBase): 92 | """Base class for interpolation methods based on a local distortion potential""" 93 | 94 | def __init__(self, energy: InterpolationEnergy): 95 | super().__init__() 96 | self.energy = energy 97 | 98 | def forward(self): 99 | raise NotImplementedError() 100 | 101 | 102 | class InterpolationModuleMultiscale( 103 | InterpolationModule, InterpolationModuleMultiscaleBase 104 | ): 105 | """Base class for interpolation methods based on a local distortion potential""" 106 | 107 | def __init__(self, energy: InterpolationEnergy): 108 | super().__init__(energy) 109 | 110 | 111 | class InterpolationModuleHessian(InterpolationModule): 112 | """Base class for interpolation methods that allow for second order optimization""" 113 | 114 | def __init__(self, energy: InterpolationEnergyHessian): 115 | super().__init__(energy) 116 | 117 | def mul_with_inv_hessian(self): 118 | raise NotImplementedError() 119 | 120 | 121 | class InterpolationModuleSingle(InterpolationModule): 122 | """Class for interpolation methods that compute shape interpolation 123 | as a weighted combination X^*=sum_i lambda_i*X_i, where 124 | lambda is an n-dim vector prescribing a discrete probability distribution""" 125 | 126 | def __init__(self, energy: InterpolationEnergy, shape_array, interpolation_coeff): 127 | super().__init__(energy) 128 | 129 | assert ( 130 | all([coeff >= 0 for coeff in interpolation_coeff]) 131 | and sum(interpolation_coeff) == 1 132 | ), "interpolation coeffs need to prescribe a discrete probability distribution" 133 | 134 | self.shape_array = shape_array 135 | self.interpolation_coeff = interpolation_coeff 136 | 137 | self.vert_new = Parameter(self.average_vertex(), 
requires_grad=True) 138 | 139 | def average_vertex(self): 140 | num_shapes = len(self.shape_array) 141 | vert_avg = my_zeros(self.shape_array[0].get_vert().shape) 142 | 143 | for i in range(num_shapes): 144 | vert_avg = ( 145 | vert_avg + 1 / num_shapes * self.shape_array[i].get_vert().clone() 146 | ) 147 | 148 | return vert_avg 149 | 150 | def reset_vert_new(self): 151 | self.vert_new.data = self.average_vertex() 152 | 153 | def set_vert_new(self, vert): 154 | self.vert_new.data = vert 155 | 156 | def forward(self): 157 | E_total = 0 158 | 159 | for i in range(len(self.shape_array)): 160 | E_curr = self.interpolation_coeff[i] * self.energy.forward_single( 161 | self.vert_new, self.shape_array[i].get_vert(), self.shape_array[i] 162 | ) 163 | E_total = E_total + E_curr 164 | 165 | return E_total, [E_total] 166 | 167 | 168 | class InterpolationModuleSingleHessian( 169 | InterpolationModuleSingle, InterpolationModuleHessian 170 | ): 171 | """Same as InterpolationModuleSingle, but with second order optimization""" 172 | 173 | def __init__( 174 | self, energy: InterpolationEnergyHessian, shape_array, interpolation_coeff 175 | ): 176 | super().__init__(energy, shape_array, interpolation_coeff) 177 | 178 | def mul_with_inv_hessian(self): 179 | hess = self.energy.get_hessian(self.shape_array[0]) 180 | 181 | grad_hess = spsolve(hess, self.vert_new.grad.to(device_cpu).detach().cpu()) 182 | self.vert_new.grad = torch.tensor(grad_hess, dtype=torch.float32, device=device) 183 | 184 | 185 | class InterpolationModuleBvp(InterpolationModuleMultiscale): 186 | """Class that computes shape interpolations by optimizing for the geometry of 187 | an intermediate sequence of shapes. Two consecutive shapes are required to have 188 | a small local distortion between each other wrt. some energy 'InterpolationEnergy'. 189 | The scheme is initialized with a linear interpolation between shape_x and shape_y. 
190 | Moreover, it allows for a hierarchical set of multi-scale shapes_x with an increasing 191 | resolution and it refines the temporal discretization over time.""" 192 | 193 | def __init__( 194 | self, 195 | energy: InterpolationEnergy, 196 | shape_x: Shape, 197 | shape_y: Shape, 198 | param=Param(), 199 | vertices=None, 200 | ): 201 | super().__init__(energy) 202 | 203 | self.param = param 204 | self.shape_x = shape_x 205 | self.shape_y = shape_y 206 | 207 | if vertices is None: 208 | vertices = self.compute_vertices() 209 | self.vert_sequence = Parameter(vertices, requires_grad=True) 210 | 211 | def forward(self): 212 | num_t = self.param.num_timesteps 213 | 214 | E_x = self.energy.forward_single( 215 | self.vert_sequence[0, ...].detach(), 216 | self.vert_sequence[1, ...], 217 | self.shape_x, 218 | ) 219 | E_y = self.energy.forward_single( 220 | self.vert_sequence[num_t - 2, ...], 221 | self.vert_sequence[num_t - 1, ...].detach(), 222 | self.shape_x, 223 | ) 224 | 225 | E_total = E_x + E_y 226 | 227 | for i in range(1, num_t - 2): 228 | E_curr = self.energy.forward_single( 229 | self.vert_sequence[i, ...], self.vert_sequence[i + 1, ...], self.shape_x 230 | ) 231 | E_total = E_total + E_curr 232 | 233 | E_total = E_total / (num_t - 1) 234 | 235 | return E_total, [E_total] 236 | 237 | def get_vert_sequence(self): 238 | return self.vert_sequence 239 | 240 | def compute_vertices(self): 241 | num_t = self.param.num_timesteps 242 | vertices = my_zeros([num_t, self.shape_x.get_vert_shape()[0], 3]) 243 | for i in range(0, num_t): 244 | lambd = i / (num_t - 1) 245 | vertices[i, ...] 
= ( 246 | 1 - lambd 247 | ) * self.shape_x.get_vert().clone() + lambd * self.shape_y.get_vert().clone() 248 | 249 | return vertices 250 | 251 | def copy_self(self, vertices=None): 252 | return InterpolationModuleBvp( 253 | self.energy, self.shape_x, self.shape_y, self.param, vertices 254 | ) 255 | 256 | def step_multiscale(self, i_scale): 257 | vertices = self.vert_sequence.data 258 | 259 | if i_scale % 2 == 0: 260 | vertices = self.insert_additional_vertices(vertices) 261 | else: 262 | vertices = self.upsample_resolution(vertices) 263 | 264 | return self.copy_self(vertices) 265 | 266 | def upsample_resolution(self, vertices): 267 | num_vert = self.shape_x.next_resolution()[0] 268 | num_t = vertices.shape[0] 269 | 270 | vertices_new = my_zeros([num_t, num_vert, 3]) 271 | 272 | for t in range(1, num_t - 1): 273 | l = (t - 1) / (num_t - 2) 274 | vertices_new[t, ...] = (1 - l) * self.shape_x.apply_upsampling( 275 | vertices[t, ...] 276 | ) + l * self.shape_y.apply_upsampling(vertices[t, ...]) 277 | 278 | self.shape_x.increase_scale_idx() 279 | self.shape_y.increase_scale_idx() 280 | 281 | vertices_new[0, ...] = self.shape_x.vert.clone() 282 | vertices_new[num_t - 1, ...] = self.shape_y.vert.clone() 283 | 284 | return vertices_new 285 | 286 | def insert_additional_vertices(self, vertices): 287 | num_t = self.param.num_timesteps 288 | num_vert = vertices.shape[1] 289 | self.param.num_timesteps = num_t * 2 - 1 290 | 291 | vertices = vertices.unsqueeze(1) 292 | 293 | vertices = vertices * torch.as_tensor( 294 | [1, 0], device=device, dtype=torch.float32 295 | ).unsqueeze(0).unsqueeze(2).unsqueeze(3) 296 | vertices = vertices.reshape([num_t * 2, num_vert, 3]) 297 | vertices = vertices[0 : num_t * 2 - 1, ...] 298 | 299 | for i in range(num_t - 1): 300 | vertices[i * 2 + 1, ...] = ( 301 | 0.5 * (vertices[i * 2, ...] 
+ vertices[i * 2 + 2, ...]).clone() 302 | ) 303 | 304 | self.vert_sequence.data = vertices 305 | 306 | return vertices 307 | 308 | 309 | class InterpolationModuleHesstrans(InterpolationModuleMultiscale): 310 | def __init__( 311 | self, 312 | energy: InterpolationEnergy, 313 | shape_x: Shape, 314 | shape_y: Shape, 315 | param=Param(), 316 | vertices=None, 317 | ): 318 | super().__init__(energy) 319 | 320 | self.param = param 321 | self.shape_x = shape_x 322 | self.shape_y = shape_y 323 | 324 | self.hess = torch.as_tensor( 325 | shape_x.get_neigh_hessian().todense(), dtype=torch.float, device=device 326 | ) 327 | self.hess += 1e-3 * my_eye(self.hess.shape[0]) 328 | self.hess_inv = self.hess.inverse() 329 | 330 | if vertices is None: 331 | vertices = self.compute_vertices() 332 | self.vert_sequence = Parameter(vertices, requires_grad=True) 333 | 334 | def forward(self): 335 | num_t = self.param.num_timesteps 336 | 337 | v_s = self.get_vert_sequence() 338 | 339 | E_x = self.energy.forward_single( 340 | v_s[0, ...].detach(), v_s[1, ...], self.shape_x 341 | ) 342 | E_y = self.energy.forward_single( 343 | v_s[num_t - 2, ...], v_s[num_t - 1, ...].detach(), self.shape_x 344 | ) 345 | 346 | E_total = E_x + E_y 347 | 348 | for i in range(1, num_t - 2): 349 | E_curr = self.energy.forward_single( 350 | v_s[i, ...], v_s[i + 1, ...], self.shape_x 351 | ) 352 | E_total = E_total + E_curr 353 | 354 | E_total = E_total / (num_t - 1) 355 | 356 | return E_total, [E_total] 357 | 358 | def get_vert_sequence(self): 359 | v_s = my_zeros(self.vert_sequence.shape) 360 | for i in range(0, v_s.shape[0]): 361 | v_s[i] = torch.mm(self.hess_inv, self.vert_sequence[i]) 362 | return v_s 363 | 364 | def compute_vertices(self): 365 | num_t = self.param.num_timesteps 366 | vertices = my_zeros([num_t, self.shape_x.get_vert_shape()[0], 3]) 367 | for i in range(0, num_t): 368 | lambd = i / (num_t - 1) 369 | vertices[i, ...] 
= ( 370 | 1 - lambd 371 | ) * self.shape_x.get_vert().clone() + lambd * self.shape_y.get_vert().clone() 372 | vertices[i, ...] = torch.mm(self.hess, vertices[i, ...]) 373 | 374 | return vertices 375 | 376 | def copy_self(self, vertices=None): 377 | return InterpolationModuleHesstrans( 378 | self.energy, self.shape_x, self.shape_y, self.param, vertices 379 | ) 380 | 381 | def step_multiscale(self, i_scale): 382 | vertices = self.vert_sequence.data 383 | 384 | if i_scale % 2 == 0: 385 | vertices = self.insert_additional_vertices(vertices) 386 | else: 387 | vertices = self.upsample_resolution(vertices) 388 | 389 | return self.copy_self(vertices) 390 | 391 | def upsample_resolution(self, vertices): 392 | num_vert = self.shape_x.next_resolution()[0] 393 | num_t = vertices.shape[0] 394 | 395 | vertices_new = my_zeros([num_t, num_vert, 3]) 396 | 397 | for t in range(1, num_t - 1): 398 | l = (t - 1) / (num_t - 2) 399 | vertices_new[t, ...] = (1 - l) * self.shape_x.apply_upsampling( 400 | torch.mm(self.hess_inv, vertices[t, ...]) 401 | ) + l * self.shape_y.apply_upsampling( 402 | torch.mm(self.hess_inv, vertices[t, ...]) 403 | ) 404 | vertices_new[t, ...] = torch.mm(self.hess, vertices_new[t, ...]) 405 | 406 | self.shape_x.increase_scale_idx() 407 | self.shape_y.increase_scale_idx() 408 | 409 | vertices_new[0, ...] = torch.mm(self.hess, self.shape_x.vert.clone()) 410 | vertices_new[num_t - 1, ...] = torch.mm(self.hess, self.shape_y.vert.clone()) 411 | 412 | return vertices_new 413 | 414 | def insert_additional_vertices(self, vertices): 415 | num_t = self.param.num_timesteps 416 | num_vert = vertices.shape[1] 417 | self.param.num_timesteps = num_t * 2 - 1 418 | 419 | vertices = vertices.unsqueeze(1) 420 | 421 | vertices = vertices * torch.as_tensor( 422 | [1, 0], device=device, dtype=torch.float32 423 | ).unsqueeze(0).unsqueeze(2).unsqueeze(3) 424 | vertices = vertices.reshape([num_t * 2, num_vert, 3]) 425 | vertices = vertices[0 : num_t * 2 - 1, ...]
426 | 427 | for i in range(num_t - 1): 428 | vertices[i * 2 + 1, ...] = ( 429 | 0.5 * (vertices[i * 2, ...] + vertices[i * 2 + 2, ...]).clone() 430 | ) 431 | 432 | self.vert_sequence.data = vertices 433 | 434 | return vertices 435 | 436 | 437 | class InterpolationModuleBvpHessian(InterpolationModuleBvp, InterpolationModuleHessian): 438 | """Same as InterpolationModuleBvp, but with second order optimization""" 439 | 440 | def __init__( 441 | self, 442 | energy: InterpolationEnergy, 443 | shape_x: Shape, 444 | shape_y: Shape, 445 | param=Param(), 446 | vertices=None, 447 | ): 448 | super().__init__(energy, shape_x, shape_y, param, vertices=vertices) 449 | 450 | def copy_self(self, vertices=None): 451 | return InterpolationModuleBvpHessian( 452 | self.energy, self.shape_x, self.shape_y, self.param, vertices 453 | ) 454 | 455 | def mul_with_inv_hessian(self): 456 | num_t = self.param.num_timesteps 457 | n_vert = self.shape_x.get_vert_shape()[0] 458 | 459 | hess_1d = self.energy.get_hessian(self.shape_x) 460 | 461 | central_diff_diags = -np.ones([3, num_t]) 462 | central_diff_diags[1, :] = 2 463 | central_diff = lil_matrix( 464 | spdiags(central_diff_diags, np.array([-1, 0, 1]), num_t, num_t) 465 | ) 466 | central_diff[[0, num_t - 1], :] = 0 467 | 468 | boundary_cond = lil_matrix((num_t, num_t)) 469 | 470 | boundary_cond[0, 0] = 1 471 | boundary_cond[num_t - 1, num_t - 1] = 1 472 | 473 | hess = csr_matrix( 474 | kron(central_diff, hess_1d) + kron(boundary_cond, eye(n_vert)) 475 | ) 476 | 477 | grad_hess = spsolve( 478 | hess, self.vert_sequence.grad.view(-1, 3).to(device_cpu).detach().cpu() 479 | ) 480 | self.vert_sequence.grad = torch.tensor( 481 | grad_hess, dtype=torch.float32, device=device 482 | ).view_as(self.vert_sequence) 483 | self.vert_sequence.grad = self.vert_sequence.grad.clone() 484 | 485 | 486 | # ----------------------------------------------------------------------------------- 487 | 488 | 489 | class Interpolation: 490 | """Base class for interpolation 
optimizers for some 491 | interpolation method 'InterpolationModule'""" 492 | 493 | def __init__(self, interp_module: InterpolationModule, param=Param()): 494 | self.interp_module = interp_module 495 | self.param = param 496 | 497 | def interpolate(self, super_idx=-1): 498 | raise NotImplementedError() 499 | 500 | 501 | class InterpolationNewton(Interpolation): 502 | """Shape interpolation with a Newton-like scheme: Adam steps on gradients preconditioned with the inverse Hessian""" 503 | 504 | def __init__(self, interp_module: InterpolationModuleHessian, param=Param()): 505 | super().__init__(interp_module, param) 506 | 507 | def interpolate(self, super_idx=-1): 508 | 509 | lr = self.param.lr 510 | optimizer = torch.optim.Adam(self.interp_module.parameters(), lr=lr) 511 | 512 | self.interp_module.train() 513 | 514 | E = 0 515 | 516 | for it in range(self.param.num_it): 517 | optimizer.zero_grad() 518 | E, Elist = self.interp_module() 519 | E.backward() 520 | 521 | self.interp_module.mul_with_inv_hessian() 522 | 523 | optimizer.step() 524 | 525 | if self.param.log: 526 | 527 | if super_idx >= 0: 528 | print( 529 | "Super {:02d}, It {:03d}, E: {:.5f}".format( 530 | super_idx, it, Elist[0] 531 | ) 532 | ) 533 | else: 534 | print("It {:03d}, E: {:.5f}".format(it, Elist[0])) 535 | 536 | self.energy = self.interp_module.eval() 537 | 538 | return E.detach() 539 | 540 | 541 | class InterpolationLBFGS(Interpolation): 542 | """Shape interpolation with the quasi-Newton scheme LBFGS""" 543 | 544 | def __init__(self, interp_module: InterpolationModule, param=Param()): 545 | super().__init__(interp_module, param) 546 | 547 | def interpolate(self, super_idx=-1): 548 | 549 | lr = self.param.lr 550 | optimizer = torch.optim.LBFGS( 551 | self.interp_module.parameters(), lr=lr, line_search_fn="strong_wolfe" 552 | ) 553 | 554 | self.interp_module.train() 555 | 556 | for it in range(self.param.num_it): 557 | 558 | def closure(): 559 | if torch.is_grad_enabled(): 560 | optimizer.zero_grad() 561 | E, Elist = self.interp_module() 562
| if E.requires_grad: 563 | E.backward() 564 | if self.param.log: 565 | if super_idx >= 0: 566 | print( 567 | "Super {:02d}, It {:03d}, E: {:.5f}".format( 568 | super_idx, it, Elist[0] 569 | ) 570 | ) 571 | else: 572 | print("It {:03d}, E: {:.5f}".format(it, Elist[0])) 573 | return E 574 | 575 | optimizer.step(closure) 576 | 577 | self.energy = self.interp_module.eval() 578 | 579 | 580 | class InterpolationGD(Interpolation): 581 | """Shape interpolation with gradient descent 'Adam'""" 582 | 583 | def __init__(self, interp_module: InterpolationModuleBase, param=Param()): 584 | super().__init__(interp_module, param) 585 | 586 | def interpolate(self, super_idx=-1): 587 | 588 | lr = self.param.lr 589 | optimizer = torch.optim.Adam(self.interp_module.parameters(), lr=lr) 590 | 591 | self.interp_module.train() 592 | 593 | E = 0 594 | 595 | for it in range(self.param.num_it): 596 | optimizer.zero_grad() 597 | E, Elist = self.interp_module() 598 | E.backward() 599 | if self.param.log: 600 | if super_idx >= 0: 601 | print( 602 | "Super {:02d}, It {:03d}, E: {:.5f}".format( 603 | super_idx, it, Elist[0] 604 | ) 605 | ) 606 | else: 607 | print("It {:03d}, E: {:.5f}".format(it, Elist[0])) 608 | 609 | optimizer.step() 610 | 611 | self.energy = self.interp_module.eval() 612 | 613 | return E.detach() 614 | 615 | 616 | class InterpolationMultiscale: 617 | """Shape interpolation with an additional multi-scale strategy""" 618 | 619 | def __init__(self, interp: Interpolation, param=Param(), num_it_super=None): 620 | self.interp = interp 621 | self.param = param 622 | self.E = None 623 | 624 | if num_it_super is None: 625 | if self.param.odd_multisteps: 626 | self.num_it_super = len(self.param.scales) * 2 627 | else: 628 | self.num_it_super = len(self.param.scales) * 2 - 1 629 | else: 630 | self.num_it_super = num_it_super 631 | 632 | def interpolate(self): 633 | 634 | for i in range(self.num_it_super - 1): 635 | self.interp.interpolate(i) 636 | self.interp.interp_module = 
self.interp.interp_module.step_multiscale(i) 637 | self.E = self.interp.interpolate(self.num_it_super - 1) 638 | 639 | return self.interp.interp_module 640 | 641 | 642 | # ----------------------------------------------------------------------------------- 643 | 644 | 645 | if __name__ == "__main__": 646 | print("main of interpolation_base.py") 647 | -------------------------------------------------------------------------------- /utils/shape_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Marvin Eisenberger. 3 | 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from scipy import sparse 9 | import numpy as np 10 | from torch_geometric.nn import fps, knn_graph 11 | import matplotlib.pyplot as plt 12 | from param import * 13 | from utils.base_tools import * 14 | 15 | 16 | def plot_curr_shape(vert, triv_x): 17 | fig = plt.figure(1) 18 | ax = fig.add_subplot(111, projection="3d") 19 | ax.plot_trisurf( 20 | vert[:, 0], 21 | vert[:, 1], 22 | vert[:, 2], 23 | triangles=triv_x, 24 | cmap="viridis", 25 | linewidths=0.2, 26 | ) 27 | ax.set_xlim(-0.4, 0.4) 28 | ax.set_ylim(-0.4, 0.4) 29 | ax.set_zlim(-0.4, 0.4) 30 | 31 | 32 | class Shape: 33 | """Class for shapes. 
(Optional) attributes are: 34 | vert: Vertices in the format nx3 35 | triv: Triangles in the format mx3 36 | samples: Index list of active vertices 37 | neigh: List of 2-tuples encoding the adjacency of vertices 38 | neigh_hessian: Hessian/graph Laplacian of the shape based on 'neigh' 39 | mahal_cov_mat: The covariance matrix of our anisotropic ARAP energy""" 40 | 41 | def __init__(self, vert=None, triv=None): 42 | self.vert = vert 43 | self.triv = triv 44 | self.samples = list(range(vert.shape[0])) if vert is not None else [] 45 | self.neigh = None 46 | self.neigh_hessian = None 47 | self.mahal_cov_mat = None 48 | self.normal = None 49 | self.D = None 50 | self.sub = None 51 | self.vert_full = None 52 | 53 | if self.triv is not None: 54 | self.triv = self.triv.to(dtype=torch.long) 55 | 56 | def subsample_fps(self, goal_vert): 57 | assert ( 58 | goal_vert <= self.vert.shape[0] 59 | ), "you cannot subsample to more vertices than the shape has" 60 | 61 | ratio = goal_vert / self.vert.shape[0] 62 | self.samples = fps(self.vert.detach().to(device_cpu), ratio=ratio).to(device) 63 | self._neigh_knn() 64 | 65 | def reset_sampling(self): 66 | self.gt_sampling(self.vert.shape[0]) 67 | 68 | def gt_sampling(self, n): 69 | self.samples = list(range(n)) 70 | self.neigh = None 71 | 72 | def scale(self, factor, shift=True): 73 | self.vert = self.vert * factor 74 | 75 | if shift: 76 | self.vert = self.vert + (1 - factor) / 2 77 | 78 | def get_bounding_box(self): 79 | max_x, _ = self.vert.max(dim=0) 80 | min_x, _ = self.vert.min(dim=0) 81 | 82 | return min_x, max_x 83 | 84 | def to_box(self, shape_y): 85 | 86 | min_x, max_x = self.get_bounding_box() 87 | min_y, max_y = shape_y.get_bounding_box() 88 | 89 | extent_x = max_x - min_x 90 | extent_y = max_y - min_y 91 | 92 | self.translate(-min_x) 93 | shape_y.translate(-min_y) 94 | 95 | scale_fac = torch.max(torch.cat((extent_x, extent_y), 0)) 96 | scale_fac = 1.0 / scale_fac 97 | 98 | self.scale(scale_fac, shift=False) 99 | shape_y.scale(scale_fac, shift=False) 100 | 101 |
extent_x = scale_fac * extent_x 102 | extent_y = scale_fac * extent_y 103 | 104 | self.translate(0.5 * (1 - extent_x)) 105 | shape_y.translate(0.5 * (1 - extent_y)) 106 | 107 | def translate(self, offset): 108 | self.vert = self.vert + offset.unsqueeze(0) 109 | 110 | def get_vert(self): 111 | return self.vert[self.samples, :] 112 | 113 | def get_vert_shape(self): 114 | return self.get_vert().shape 115 | 116 | def get_triv(self): 117 | return self.triv 118 | 119 | def get_triv_np(self): 120 | return self.triv.detach().cpu().numpy() 121 | 122 | def get_vert_np(self): 123 | return self.vert[self.samples, :].detach().cpu().numpy() 124 | 125 | def get_vert_full_np(self): 126 | return self.vert.detach().cpu().numpy() 127 | 128 | def get_neigh(self, num_knn=5): 129 | if self.neigh is None: 130 | self.compute_neigh(num_knn=num_knn) 131 | 132 | return self.neigh 133 | 134 | def compute_neigh(self, num_knn=5): 135 | if len(self.samples) == self.vert.shape[0]: 136 | self._triv_neigh() 137 | else: 138 | self._neigh_knn(num_knn=num_knn) 139 | 140 | def get_edge_index(self, num_knn=5): 141 | edge_index_one = self.get_neigh(num_knn).t() 142 | edge_index = torch.zeros( 143 | [2, edge_index_one.shape[1] * 2], dtype=torch.long, device=self.vert.device 144 | ) 145 | edge_index[:, : edge_index_one.shape[1]] = edge_index_one 146 | edge_index[0, edge_index_one.shape[1] :] = edge_index_one[1, :] 147 | edge_index[1, edge_index_one.shape[1] :] = edge_index_one[0, :] 148 | return edge_index 149 | 150 | def _triv_neigh(self): 151 | self.neigh = torch.cat( 152 | (self.triv[:, [0, 1]], self.triv[:, [0, 2]], self.triv[:, [1, 2]]), 0 153 | ) 154 | 155 | def _neigh_knn(self, num_knn=5): 156 | vert = self.get_vert().detach() 157 | print("Compute knn....") 158 | self.neigh = ( 159 | knn_graph(vert.to(device_cpu), num_knn, loop=False) 160 | .transpose(0, 1) 161 | .to(device) 162 | ) 163 | 164 | def get_neigh_hessian(self): 165 | if self.neigh_hessian is None: 166 | self.compute_neigh_hessian() 167 | 
168 | return self.neigh_hessian 169 | 170 | def compute_neigh_hessian(self): 171 | 172 | neigh = self.get_neigh() 173 | 174 | n_vert = self.get_vert().shape[0] 175 | 176 | H = sparse.lil_matrix(1e-3 * sparse.identity(n_vert)) 177 | 178 | I = np.array(neigh[:, 0].detach().cpu()) 179 | J = np.array(neigh[:, 1].detach().cpu()) 180 | V = np.ones([neigh.shape[0]]) 181 | U = -V 182 | H = H + sparse.lil_matrix( 183 | sparse.coo_matrix((U, (I, J)), shape=(n_vert, n_vert)) 184 | ) 185 | H = H + sparse.lil_matrix( 186 | sparse.coo_matrix((U, (J, I)), shape=(n_vert, n_vert)) 187 | ) 188 | H = H + sparse.lil_matrix( 189 | sparse.coo_matrix((V, (I, I)), shape=(n_vert, n_vert)) 190 | ) 191 | H = H + sparse.lil_matrix( 192 | sparse.coo_matrix((V, (J, J)), shape=(n_vert, n_vert)) 193 | ) 194 | 195 | self.neigh_hessian = H 196 | 197 | def rotate(self, R): 198 | self.vert = torch.mm(self.vert, R.transpose(0, 1)) 199 | 200 | def to(self, device): 201 | self.vert = self.vert.to(device) 202 | self.triv = self.triv.to(device) 203 | 204 | def detach_cpu(self): 205 | self.vert = self.vert.detach().cpu() 206 | self.triv = self.triv.detach().cpu() 207 | if self.normal is not None: 208 | self.normal = self.normal.detach().cpu() 209 | if self.neigh is not None: 210 | self.neigh = self.neigh.detach().cpu() 211 | if self.D is not None: 212 | self.D = self.D.detach().cpu() 213 | if self.vert_full is not None: 214 | self.vert_full = self.vert_full.detach().cpu() 215 | if self.samples is not None and torch.is_tensor(self.samples): 216 | self.samples = self.samples.detach().cpu() 217 | if self.sub is not None: 218 | for i_s in range(len(self.sub)): 219 | for i_p in range(len(self.sub[i_s])): 220 | self.sub[i_s][i_p] = self.sub[i_s][i_p].detach().cpu() 221 | 222 | def compute_volume(self): 223 | return self.compute_volume_shifted(self.vert) 224 | 225 | def compute_volume_shifted(self, vert_t): 226 | vert_t = vert_t - vert_t.mean(dim=0, keepdim=True) 227 | vert_triv = vert_t[self.triv, 
:].to(device_cpu) 228 | 229 | vol_tetrahedra = (vert_triv.det() / 6).to(device) 230 | 231 | return vol_tetrahedra.sum() 232 | 233 | def get_normal(self): 234 | if self.normal is None: 235 | self._compute_outer_normal() 236 | return self.normal 237 | 238 | def _compute_outer_normal(self): 239 | edge_1 = torch.index_select(self.vert, 0, self.triv[:, 1]) - torch.index_select( 240 | self.vert, 0, self.triv[:, 0] 241 | ) 242 | edge_2 = torch.index_select(self.vert, 0, self.triv[:, 2]) - torch.index_select( 243 | self.vert, 0, self.triv[:, 0] 244 | ) 245 | 246 | face_norm = torch.cross(1e4 * edge_1, 1e4 * edge_2, dim=1) 247 | 248 | normal = my_zeros(self.vert.shape) 249 | for d in range(3): 250 | normal = torch.index_add(normal, 0, self.triv[:, d], face_norm) 251 | self.normal = normal / (1e-5 + normal.norm(dim=1, keepdim=True)) 252 | 253 | 254 | if __name__ == "__main__": 255 | print("main of shape_utils.py") 256 | --------------------------------------------------------------------------------
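The boundary-value interpolation (`InterpolationModuleBvp`) initializes the trajectory with a linear blend between the two endpoint shapes (`compute_vertices`) and, during the multi-scale schedule, doubles the temporal resolution by inserting the midpoint of each consecutive pair of frames (`insert_additional_vertices`, `num_timesteps -> 2 * num_timesteps - 1`). A minimal NumPy sketch of those two steps, detached from the torch modules above and with illustrative helper names:

```python
import numpy as np

def linear_sequence(vert_x, vert_y, num_t):
    # Initialize the trajectory with a linear blend between the two
    # endpoint shapes, as in InterpolationModuleBvp.compute_vertices.
    lambd = np.linspace(0.0, 1.0, num_t)[:, None, None]
    return (1.0 - lambd) * vert_x[None] + lambd * vert_y[None]

def refine_time(vertices):
    # Double the temporal resolution: keep the existing frames on even
    # indices and insert the midpoint of each consecutive pair between
    # them, as in insert_additional_vertices (num_t -> 2 * num_t - 1).
    num_t, num_vert, dim = vertices.shape
    out = np.zeros((2 * num_t - 1, num_vert, dim))
    out[0::2] = vertices
    out[1::2] = 0.5 * (vertices[:-1] + vertices[1:])
    return out

vert_x = np.zeros((4, 3))
vert_y = np.ones((4, 3))
seq = linear_sequence(vert_x, vert_y, 3)   # frames at t = 0, 0.5, 1
seq2 = refine_time(seq)                    # frames at t = 0, 0.25, 0.5, 0.75, 1
```

In the actual module the refined sequence is then re-optimized against the distortion energy; here the midpoints of an already-linear sequence are exact, which makes the behavior easy to check.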
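`Shape.compute_neigh_hessian` assembles a graph Laplacian of the vertex adjacency, plus a small `1e-3 * I` regularizer, as a cheap surrogate Hessian for the second-order optimizers. A dense NumPy sketch of the same assembly (the real code uses `scipy.sparse`; the function name here is illustrative):

```python
import numpy as np

def graph_laplacian(neigh, n_vert, eps=1e-3):
    # For every undirected edge (i, j): add +1 to the diagonal entries
    # (i, i) and (j, j), and -1 to the off-diagonals (i, j) and (j, i).
    # The eps * I term keeps the matrix strictly positive definite
    # (and hence invertible), mirroring compute_neigh_hessian.
    H = eps * np.eye(n_vert)
    for i, j in neigh:
        H[i, i] += 1.0
        H[j, j] += 1.0
        H[i, j] -= 1.0
        H[j, i] -= 1.0
    return H

# Triangle graph 0-1-2-0: every row of a graph Laplacian sums to zero,
# so with the regularizer every row sums to eps.
H = graph_laplacian([(0, 1), (1, 2), (0, 2)], 3)
```

The dense loop is only for clarity; at mesh scale the sparse `coo_matrix` assembly in `compute_neigh_hessian` is the appropriate representation.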