├── LICENSE ├── README.md ├── animations ├── input_full.gif ├── output_full.gif └── teaser.png ├── configs ├── default.yaml ├── demo.yaml └── noflow │ ├── lpdc_completion.yaml │ ├── lpdc_completion_pretrained.yaml │ ├── lpdc_even.yaml │ ├── lpdc_even_pretrained.yaml │ ├── lpdc_uneven.yaml │ └── lpdc_uneven_pretrained.yaml ├── data └── Humans │ ├── D-FAUST │ ├── overfit.lst │ ├── test.lst │ ├── test_new_individual.lst │ ├── train.lst │ └── val.lst │ └── Demo │ ├── 50002_light_hopping_loose │ └── pcl_seq │ │ ├── 00000080.npz │ │ ├── 00000081.npz │ │ ├── 00000082.npz │ │ ├── 00000083.npz │ │ ├── 00000084.npz │ │ ├── 00000085.npz │ │ ├── 00000086.npz │ │ ├── 00000087.npz │ │ ├── 00000088.npz │ │ ├── 00000089.npz │ │ ├── 00000090.npz │ │ ├── 00000091.npz │ │ ├── 00000092.npz │ │ ├── 00000093.npz │ │ ├── 00000094.npz │ │ ├── 00000095.npz │ │ └── 00000096.npz │ └── 50004_punching │ └── pcl_seq │ ├── 00000166.npz │ ├── 00000167.npz │ ├── 00000168.npz │ ├── 00000169.npz │ ├── 00000170.npz │ ├── 00000171.npz │ ├── 00000172.npz │ ├── 00000173.npz │ ├── 00000174.npz │ ├── 00000175.npz │ ├── 00000176.npz │ ├── 00000177.npz │ ├── 00000178.npz │ ├── 00000179.npz │ ├── 00000180.npz │ ├── 00000181.npz │ └── 00000182.npz ├── docs └── index.html ├── environment.yaml ├── eval.py ├── generate.py ├── im2mesh ├── __init__.py ├── __init__.pyc ├── checkpoints.py ├── checkpoints.pyc ├── common.py ├── config.py ├── config.pyc ├── data │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── core.cpython-36.pyc │ │ ├── fields.cpython-36.pyc │ │ ├── subseq_dataset.cpython-36.pyc │ │ └── transforms.cpython-36.pyc │ ├── core.py │ ├── core.pyc │ ├── fields.py │ ├── fields.pyc │ ├── subseq_dataset.py │ ├── subseq_dataset.pyc │ ├── transforms.py │ └── transforms.pyc ├── encoder │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── conv.cpython-36.pyc │ │ ├── pointnet.cpython-36.pyc │ │ ├── pointnet_unet.cpython-36.pyc │ 
│ ├── unet.cpython-36.pyc │ │ └── unet3d.cpython-36.pyc │ ├── conv.pyc │ └── pointnet.py ├── eval.py ├── layers.py ├── lpdc │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── config.cpython-36.pyc │ │ ├── generation.cpython-36.pyc │ │ └── training.cpython-36.pyc │ ├── config.py │ ├── generation.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── decoder.cpython-36.pyc │ │ │ ├── decoder_unet.cpython-36.pyc │ │ │ └── displacement.cpython-36.pyc │ │ ├── decoder.py │ │ └── displacement.py │ └── training.py ├── lpdc_uneven │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── config.cpython-36.pyc │ │ ├── generation.cpython-36.pyc │ │ └── training.cpython-36.pyc │ ├── config.py │ ├── generation.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── decoder.cpython-36.pyc │ │ │ ├── decoder_unet.cpython-36.pyc │ │ │ └── displacement.cpython-36.pyc │ │ ├── decoder.py │ │ └── displacement.py │ └── training.py ├── training.py └── utils │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── grad.cpython-36.pyc │ ├── io.cpython-36.pyc │ ├── onet_generator.cpython-36.pyc │ └── visualize.cpython-36.pyc │ ├── binvox_rw.py │ ├── grad.py │ ├── icp.py │ ├── io.py │ ├── libkdtree │ ├── .gitignore │ ├── LICENSE.txt │ ├── MANIFEST.in │ ├── README │ ├── README.rst │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ ├── pykdtree │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ ├── _kdtree_core.c │ │ ├── _kdtree_core.c.mako │ │ ├── kdtree.c │ │ ├── kdtree.cpython-36m-x86_64-linux-gnu.so │ │ ├── kdtree.pyx │ │ ├── render_template.py │ │ └── test_tree.py │ ├── setup.cfg │ └── setup.py │ ├── libmcubes │ ├── .gitignore │ ├── LICENSE │ ├── README.rst │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── exporter.cpython-36.pyc │ ├── exporter.py │ ├── 
marchingcubes.cpp │ ├── marchingcubes.h │ ├── mcubes.cpp │ ├── mcubes.cpython-36m-x86_64-linux-gnu.so │ ├── mcubes.pyx │ ├── pyarray_symbol.h │ ├── pyarraymodule.h │ ├── pywrapper.cpp │ ├── pywrapper.h │ └── setup.py │ ├── libmesh │ ├── .gitignore │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── inside_mesh.cpython-36.pyc │ ├── inside_mesh.py │ ├── inside_mesh.pyc │ ├── setup.py │ ├── triangle_hash.cpython-36m-x86_64-linux-gnu.so │ └── triangle_hash.pyx │ ├── libmise │ ├── .gitignore │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ ├── mise.cpython-36m-x86_64-linux-gnu.so │ ├── mise.pyx │ ├── setup.py │ └── test.py │ ├── libsimplify │ ├── Simplify.h │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ ├── setup.py │ ├── simplify_mesh.cpp │ ├── simplify_mesh.cpython-36m-x86_64-linux-gnu.so │ ├── simplify_mesh.pyx │ └── test.py │ ├── libvoxelize │ ├── .gitignore │ ├── __init__.py │ ├── setup.py │ ├── tribox2.h │ ├── voxelize.cpython-36m-x86_64-linux-gnu.so │ └── voxelize.pyx │ ├── mesh.py │ ├── onet_generator.py │ ├── torchdiffeq │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── assets │ │ ├── ode_demo.gif │ │ ├── odenet_0_viz.png │ │ └── resnet_0_viz.png │ ├── examples │ │ ├── README.md │ │ ├── latent_ode.py │ │ ├── ode_demo.py │ │ └── odenet_mnist.py │ ├── setup.py │ ├── tests │ │ ├── api_tests.py │ │ ├── gradient_tests.py │ │ ├── odeint_tests.py │ │ ├── problems.py │ │ └── run_all.py │ └── torchdiffeq │ │ ├── __init__.py │ │ └── _impl │ │ ├── __init__.py │ │ ├── adams.py │ │ ├── adjoint.py │ │ ├── dopri5.py │ │ ├── fixed_adams.py │ │ ├── fixed_grid.py │ │ ├── interp.py │ │ ├── misc.py │ │ ├── odeint.py │ │ ├── rk_common.py │ │ ├── solvers.py │ │ └── tsit5.py │ ├── visualize.py │ └── voxels.py ├── scripts ├── build_dataset.sh ├── build_dataset_incomplete.sh ├── compute_incomplete.py ├── config.sh ├── download_data.sh ├── install_dataset.sh ├── migrate_dfaust.sh ├── sample_mesh.py 
└── split_files │ ├── overfit.lst │ ├── test.lst │ ├── test_new_individual.lst │ ├── train.lst │ ├── train_generative.lst │ └── val.lst ├── setup.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jiapeng Tang, Dan Xu, Kui Jia, Lei Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LPDC-Net 2 | [Homepage](https://tangjiapeng.github.io/LPDC-Net) | [Paper-Pdf](https://arxiv.org/pdf/2103.16341.pdf) | [Video](https://youtu.be/dhmuuzfRpNs) 3 | 4 | This repository contains the code for the project **LPDC-Net - Learning Parallel Dense Correspondence from Spatio-Temporal Descriptors 5 | for Efficient and Robust 4D Reconstruction** 6 | 7 | Below you can find detailed usage instructions for training your own models and for using our [pretrained models](https://drive.google.com/drive/folders/1jPrkxd9GYKtSsQt_q4poIYMuxeXlYRZ9?usp=sharing). 8 | 9 | ## Installation 10 | First you have to make sure that you have all dependencies in place. You can create and activate an anaconda environment called `lpdc` using 11 | 12 | ``` 13 | conda env create -f environment.yaml 14 | conda activate lpdc 15 | ``` 16 | Next, compile the extension modules. You can do this via 17 | ``` 18 | python setup.py build_ext --inplace 19 | ``` 20 | 21 | ## Demo 22 | 23 | You can test our code on the provided input point cloud sequences in the `data/Humans/Demo/` folder. To this end, simply run 24 | ``` 25 | python generate.py configs/demo.yaml 26 | ``` 27 | This script should create a folder `out/demo/` where the output is stored. 28 | 29 | ## Dataset 30 | 31 | ### Point-based Data 32 | To train a new model from scratch, you have to download the full dataset. 33 | You can download the pre-processed data (~42 GB) using 34 | 35 | ``` 36 | bash scripts/download_data.sh 37 | ``` 38 | 39 | The script will download the point-based data for the [Dynamic FAUST (D-FAUST)](http://dfaust.is.tue.mpg.de/) dataset to the `data/` folder.
40 | 41 | ### Mesh Data 42 | 43 | Please follow the instructions on [D-FAUST homepage](http://dfaust.is.tue.mpg.de/) to download the "female and male registrations" as well as "scripts to load / parse the data". 44 | Next, follow their instructions in the `scripts/README.txt` file to extract the obj-files of the sequences. Once completed, you should have a folder with the following structure: 45 | ___ 46 | your_dfaust_folder/ 47 | | 50002_chicken_wings/ 48 |     | 00000.obj 49 |     | 00001.obj 50 |     | ... 51 |     | 000215.obj 52 | | 50002_hips/ 53 |     | 00000.obj 54 |     | ... 55 | | ... 56 | | 50027_shake_shoulders/ 57 |     | 00000.obj 58 |     | ... 59 | ___ 60 | You can now run 61 | ``` 62 | bash scripts/migrate_dfaust.sh path/to/your_dfaust_folder 63 | ``` 64 | to copy the mesh data to the dataset folder. 65 | The argument has to be the folder to which you have extracted the mesh data (the `your_dfaust_folder` from the directory tree above). 66 | 67 | ### Incomplete Point Cloud Sequence 68 | 69 | You can now run 70 | ``` 71 | bash scripts/build_dataset_incomplete.sh 72 | ``` 73 | to create incomplete point cloud sequences for the experiment of 4D Shape Completion. 74 | 75 | ## Usage 76 | 77 | When you have installed all dependencies and obtained the preprocessed data, you are ready to run our pre-trained models and train new models from scratch. 78 | 79 | ### Generation 80 | 81 |
82 | 83 | 84 |
85 | 86 | To start the normal mesh generation process using a trained model, use 87 | 88 | ``` 89 | python generate.py configs/CONFIG.yaml 90 | ``` 91 | where you replace `CONFIG.yaml` with the name of the configuration file you want to use. 92 | 93 | The easiest way is to use a pretrained model. You can do this by using one of the config files 94 | 95 | ``` 96 | configs/noflow/lpdc_even_pretrained.yaml 97 | configs/noflow/lpdc_uneven_pretrained.yaml 98 | configs/noflow/lpdc_completion_pretrained.yaml 99 | ``` 100 | 101 | Our script will automatically download the model checkpoints and run the generation. 102 | You can find the outputs in the `out/pointcloud` folder. 103 | 104 | Please note that the `*_pretrained.yaml` config files are only meant for generation, not for training new models: when these configs are used for training, the model 105 | will be trained from scratch, but during inference our code will still use the pretrained model. 106 | 107 | ### Evaluation 108 | 109 | You can evaluate the generated output of a model on the test set using 110 | 111 | ``` 112 | python eval.py configs/CONFIG.yaml 113 | ``` 114 | The evaluation results will be saved to pickle and csv files. 115 | 116 | ### Training 117 | 118 | Finally, to train a new network from scratch, run 119 | ``` 120 | python train.py configs/CONFIG.yaml 121 | ``` 122 | You can monitor the training process on http://localhost:6006 using tensorboard: 123 | ``` 124 | cd OUTPUT_DIR 125 | tensorboard --logdir ./logs --port 6006 126 | ``` 127 | where you replace `OUTPUT_DIR` with the respective output directory. For available training options, please have a look at `configs/default.yaml`. 128 | 129 | 130 | ## Acknowledgements 131 | 132 | Most of the code is borrowed from [Occupancy Flow](https://github.com/autonomousvision/occupancy_flow). We thank Michael Niemeyer for his great work and repositories.
133 | 134 | ## Citation 135 | 136 | If you find our code or paper useful, please consider citing 137 | 138 | @inproceedings{tang2021learning, 139 | title={Learning Parallel Dense Correspondence from Spatio-Temporal Descriptors for Efficient and Robust 4D Reconstruction}, 140 | author={Tang, Jiapeng and Xu, Dan and Jia, Kui and Zhang, Lei}, 141 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 142 | pages={6022--6031}, 143 | year={2021} 144 | } 145 | -------------------------------------------------------------------------------- /animations/input_full.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/animations/input_full.gif -------------------------------------------------------------------------------- /animations/output_full.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/animations/output_full.gif -------------------------------------------------------------------------------- /animations/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/animations/teaser.png -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | method: lpdc 2 | data: 3 | path: ./data/Humans 4 | dataset: Humans 5 | input_type: pcl_seq 6 | classes: ['D-FAUST'] 7 | train_split: train 8 | val_split: val 9 | test_split: test 10 | dim: 3 11 | n_training_points: 512 12 | points_unpackbits: true 13 | n_training_pcl_points: 100 14 | input_pointcloud_n: 300 15 | input_pointcloud_noise: 0.001 16 | 
input_pointcloud_corresponding: true 17 | n_views: 24 18 | img_size: 224 19 | img_with_camera: false 20 | img_augment: false 21 | length_sequence: 17 # 22 | select_steps: null #create uneven input 23 | offset_sequence: 15 24 | n_files_per_sequence: -1 25 | n_intervals: 1 26 | points_file: points.npz 27 | mesh_seq_folder: mesh_seq 28 | points_iou_seq_folder: points_seq 29 | pointcloud_seq_folder: pcl_seq 30 | img_seq_folder: img 31 | completion: false #if point cloud completion 32 | pointcloud_seq_incomplete_folder: pcl_incomp_seq #incomplete pointcloud folder 33 | model: 34 | encoder: pointnet_resnet 35 | encoder_temporal: pointnet_resnet 36 | decoder: cbatchnorm 37 | velocity_field: concat 38 | encoder_latent: null 39 | encoder_latent_temporal: null 40 | decoder_kwargs: {} 41 | encoder_kwargs: {} 42 | encoder_latent_kwargs: {} 43 | encoder_temporal_kwargs: {} 44 | velocity_field_kwargs: {} 45 | encoder_latent_temporal_kwargs: {} 46 | learn_embedding: false 47 | ode_solver: dopri5 48 | ode_step_size: null 49 | use_adjoint: true 50 | rtol: 0.001 51 | atol: 0.00001 52 | vae_beta: 0.0001 53 | loss_corr: false 54 | loss_corr_bw: false 55 | loss_recon: true 56 | loss_transform_forward: false 57 | initialize_from: null 58 | initialization_file_name: model_best.pt 59 | c_dim: 512 60 | z_dim: 0 61 | use_camera: False 62 | training: 63 | out_dir: out/00 64 | model_selection_metric: iou 65 | model_selection_mode: maximize 66 | n_eval_points: 5000 67 | batch_size: 16 68 | batch_size_vis: 1 69 | batch_size_val: 1 70 | print_every: 5 71 | visualize_every: 999999999 72 | checkpoint_every: 200 73 | validate_every: 2000 74 | backup_every: 100000 75 | eval_sample: true 76 | learning_rate: 0.0001 77 | test: 78 | threshold: 0.3 79 | eval_mesh: true 80 | eval_pointcloud: false 81 | project_to_final_mesh: false 82 | eval_mesh_correspondences: true 83 | eval_mesh_iou: true 84 | eval_pointcloud_correspondences: true 85 | eval_only_end_time_steps: false 86 | model_file: model_best.pt 87 | 
generation: 88 | generate_pointcloud: false 89 | generate_mesh: true 90 | resolution_0: 32 91 | upsampling_steps: 2 92 | refinement_step: 0 93 | simplify_nfaces: null 94 | padding: 0.1 95 | vis_n_outputs: 20 96 | n_time_steps: 17 97 | mesh_color: true 98 | interpolate: false 99 | only_end_time_points: false 100 | fix_z: false 101 | fix_zt: False 102 | shuffle_generation: true 103 | rand_seed: 12345 104 | batch_size: 1000000 105 | generation_dir: generation 106 | use_sampling: false 107 | copy_input: false 108 | -------------------------------------------------------------------------------- /configs/demo.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/noflow/lpdc_even_pretrained.yaml 2 | data: 3 | classes: ['Demo'] 4 | test_split: null 5 | offset_sequence: 0 6 | training: 7 | out_dir: out/demo 8 | test: 9 | model_file: lpdc_even_pretrained.pt -------------------------------------------------------------------------------- /configs/noflow/lpdc_completion.yaml: -------------------------------------------------------------------------------- 1 | method: lpdc_uneven 2 | data: 3 | test_split: test #test_new_individual 4 | n_intervals: 1 5 | select_steps: [0,10,19,29,44,49] 6 | length_sequence: 50 7 | pointcloud_seq_incomplete_folder: incompcl_seq 8 | completion: true 9 | model: 10 | encoder: 11 | encoder_kwargs: 12 | hidden_dim: 128 13 | encoder_temporal: pointnet_spatiotemporal2 14 | encoder_temporal_kwargs: 15 | hidden_dim: 128 16 | decoder: simple_local 17 | decoder_kwargs: 18 | hidden_size: 128 19 | velocity_field: concat 20 | velocity_field_kwargs: 21 | hidden_size: 128 22 | c_dim: 128 23 | loss_corr: true 24 | loss_transform_forward: true 25 | training: 26 | model_selection_metric: iou 27 | model_selection_mode: maximize 28 | batch_size: 16 29 | validate_every: 10000 30 | backup_every: 2000 ## 31 | learning_rate: 0.0001 32 | out_dir: out/pointcloud/lpdc_completion 
-------------------------------------------------------------------------------- /configs/noflow/lpdc_completion_pretrained.yaml: -------------------------------------------------------------------------------- 1 | method: lpdc_uneven 2 | data: 3 | test_split: test #test_new_individual 4 | n_intervals: 1 5 | select_steps: [0,10,19,29,44,49] 6 | length_sequence: 50 7 | pointcloud_seq_incomplete_folder: incompcl_seq 8 | completion: true 9 | model: 10 | encoder: 11 | encoder_kwargs: 12 | hidden_dim: 128 13 | encoder_temporal: pointnet_spatiotemporal2 14 | encoder_temporal_kwargs: 15 | hidden_dim: 128 16 | decoder: simple_local 17 | decoder_kwargs: 18 | hidden_size: 128 19 | velocity_field: concat 20 | velocity_field_kwargs: 21 | hidden_size: 128 22 | c_dim: 128 23 | loss_corr: true 24 | loss_transform_forward: true 25 | training: 26 | out_dir: out/pointcloud/lpdc_completion_pretrained 27 | test: 28 | model_file: lpdc_completion_pretrained.pt -------------------------------------------------------------------------------- /configs/noflow/lpdc_even.yaml: -------------------------------------------------------------------------------- 1 | method: lpdc 2 | data: 3 | test_split: test #test_new_individual 4 | n_intervals: 1 5 | model: 6 | encoder: 7 | encoder_kwargs: 8 | hidden_dim: 128 9 | encoder_temporal: pointnet_spatiotemporal 10 | encoder_temporal_kwargs: 11 | hidden_dim: 128 12 | decoder: simple_local 13 | decoder_kwargs: 14 | hidden_size: 128 15 | velocity_field: concat 16 | velocity_field_kwargs: 17 | hidden_size: 128 18 | c_dim: 128 19 | loss_corr: true 20 | loss_transform_forward: true 21 | training: 22 | model_selection_metric: iou 23 | model_selection_mode: maximize 24 | batch_size: 16 25 | validate_every: 10000 26 | backup_every: 2000 ## 27 | learning_rate: 0.0001 28 | out_dir: out/pointcloud/lpdc_even 29 | -------------------------------------------------------------------------------- /configs/noflow/lpdc_even_pretrained.yaml: 
-------------------------------------------------------------------------------- 1 | method: lpdc 2 | data: 3 | test_split: test #test_new_individual 4 | n_intervals: 1 5 | model: 6 | encoder: 7 | encoder_kwargs: 8 | hidden_dim: 128 9 | encoder_temporal: pointnet_spatiotemporal 10 | encoder_temporal_kwargs: 11 | hidden_dim: 128 12 | decoder: simple_local 13 | decoder_kwargs: 14 | hidden_size: 128 15 | velocity_field: concat 16 | velocity_field_kwargs: 17 | hidden_size: 128 18 | c_dim: 128 19 | loss_corr: true 20 | loss_transform_forward: true 21 | training: 22 | out_dir: out/pointcloud/lpdc_even_pretrained 23 | test: 24 | model_file: lpdc_even_pretrained.pt 25 | -------------------------------------------------------------------------------- /configs/noflow/lpdc_uneven.yaml: -------------------------------------------------------------------------------- 1 | method: lpdc_uneven 2 | data: 3 | test_split: test #test_new_individual 4 | n_intervals: 1 5 | select_steps: [0,10,19,29,44,49] 6 | length_sequence: 50 7 | model: 8 | encoder: 9 | encoder_kwargs: 10 | hidden_dim: 128 11 | encoder_temporal: pointnet_spatiotemporal 12 | encoder_temporal_kwargs: 13 | hidden_dim: 128 14 | decoder: simple_local 15 | decoder_kwargs: 16 | hidden_size: 128 17 | velocity_field: concat 18 | velocity_field_kwargs: 19 | hidden_size: 128 20 | c_dim: 128 21 | loss_corr: true 22 | loss_transform_forward: true 23 | training: 24 | model_selection_metric: iou 25 | model_selection_mode: maximize 26 | batch_size: 16 27 | validate_every: 10000 28 | backup_every: 2000 ## 29 | learning_rate: 0.0001 30 | out_dir: out/pointcloud/lpdc_uneven -------------------------------------------------------------------------------- /configs/noflow/lpdc_uneven_pretrained.yaml: -------------------------------------------------------------------------------- 1 | method: lpdc_uneven 2 | data: 3 | test_split: test #test_new_individual 4 | n_intervals: 1 5 | select_steps: [0,10,19,29,44,49] 6 | length_sequence: 50 7 | 
model: 8 | encoder: 9 | encoder_kwargs: 10 | hidden_dim: 128 11 | encoder_temporal: pointnet_spatiotemporal 12 | encoder_temporal_kwargs: 13 | hidden_dim: 128 14 | decoder: simple_local 15 | decoder_kwargs: 16 | hidden_size: 128 17 | velocity_field: concat 18 | velocity_field_kwargs: 19 | hidden_size: 128 20 | c_dim: 128 21 | loss_corr: true 22 | loss_transform_forward: true 23 | training: 24 | out_dir: out/pointcloud/lpdc_uneven_pretrained 25 | test: 26 | model_file: lpdc_uneven_pretrained.pt -------------------------------------------------------------------------------- /data/Humans/D-FAUST/overfit.lst: -------------------------------------------------------------------------------- 1 | 50026_one_leg_jump 2 | -------------------------------------------------------------------------------- /data/Humans/D-FAUST/test.lst: -------------------------------------------------------------------------------- 1 | 50002_light_hopping_loose 2 | 50004_punching 3 | 50007_shake_shoulders 4 | 50009_chicken_wings 5 | 50020_chicken_wings 6 | 50022_light_hopping_loose 7 | 50025_light_hopping_loose 8 | 50026_shake_arms 9 | 50027_shake_shoulders 10 | -------------------------------------------------------------------------------- /data/Humans/D-FAUST/test_new_individual.lst: -------------------------------------------------------------------------------- 1 | 50021_chicken_wings 2 | 50021_knees 3 | 50021_one_leg_jump 4 | 50021_punching 5 | 50021_shake_arms 6 | 50021_shake_shoulders 7 | 50021_hips 8 | 50021_light_hopping_stiff 9 | 50021_one_leg_loose 10 | 50021_running_on_spot 11 | 50021_shake_hips 12 | -------------------------------------------------------------------------------- /data/Humans/D-FAUST/train.lst: -------------------------------------------------------------------------------- 1 | 50002_one_leg_loose 2 | 50025_shake_shoulders 3 | 50007_one_leg_jump 4 | 50002_knees 5 | 50002_light_hopping_stiff 6 | 50004_jiggle_on_toes 7 | 50004_shake_hips 8 | 50026_chicken_wings 9 | 
50007_punching 10 | 50022_one_leg_jump 11 | 50009_light_hopping_stiff 12 | 50025_light_hopping_stiff 13 | 50007_shake_arms 14 | 50026_running_on_spot 15 | 50025_chicken_wings 16 | 50020_shake_hips 17 | 50026_shake_hips 18 | 50027_light_hopping_stiff 19 | 50009_shake_hips 20 | 50009_light_hopping_loose 21 | 50020_jiggle_on_toes 22 | 50025_one_leg_loose 23 | 50009_punching 24 | 50027_hips 25 | 50002_running_on_spot 26 | 50026_light_hopping_stiff 27 | 50026_jiggle_on_toes 28 | 50020_shake_shoulders 29 | 50007_light_hopping_stiff 30 | 50007_jiggle_on_toes 31 | 50027_punching 32 | 50009_running_on_spot 33 | 50002_shake_arms 34 | 50022_knees 35 | 50007_jumping_jacks 36 | 50027_running_on_spot 37 | 50022_running_on_spot 38 | 50004_knees 39 | 50027_one_leg_loose 40 | 50009_one_leg_jump 41 | 50026_jumping_jacks 42 | 50009_one_leg_loose 43 | 50027_jiggle_on_toes 44 | 50020_knees 45 | 50027_light_hopping_loose 46 | 50026_knees 47 | 50004_jumping_jacks 48 | 50026_one_leg_jump 49 | 50004_hips 50 | 50027_shake_arms 51 | 50026_hips 52 | 50020_punching 53 | 50025_jiggle_on_toes 54 | 50022_punching 55 | 50004_light_hopping_loose 56 | 50009_jumping_jacks 57 | 50002_one_leg_jump 58 | 50007_knees 59 | 50027_jumping_jacks 60 | 50022_shake_arms 61 | 50002_jiggle_on_toes 62 | 50002_hips 63 | 50009_jiggle_on_toes 64 | 50022_jiggle_on_toes 65 | 50007_shake_hips 66 | 50022_hips 67 | 50026_punching 68 | 50026_one_leg_loose 69 | 50025_running_on_spot 70 | 50025_knees 71 | 50025_hips 72 | 50020_light_hopping_stiff 73 | 50026_shake_shoulders 74 | 50002_jumping_jacks 75 | 50022_light_hopping_stiff 76 | 50027_one_leg_jump 77 | 50002_punching 78 | 50022_shake_shoulders 79 | 50004_running_on_spot 80 | 50020_light_hopping_loose 81 | 50022_shake_hips 82 | 50004_one_leg_loose 83 | 50022_jumping_jacks 84 | 50002_chicken_wings 85 | 50022_one_leg_loose 86 | 50027_knees 87 | 50004_shake_arms 88 | 50007_chicken_wings 89 | 50002_shake_hips 90 | 50007_one_leg_loose 91 | 50004_shake_shoulders 92 | 50009_hips 
93 | 50007_running_on_spot 94 | 50025_shake_hips 95 | 50002_shake_shoulders 96 | 50020_shake_arms 97 | 50027_shake_hips 98 | 50026_light_hopping_loose 99 | 50025_one_leg_jump 100 | 50020_one_leg_loose 101 | 50004_one_leg_jump 102 | 50025_punching 103 | 50020_one_leg_jump 104 | 50004_light_hopping_stiff 105 | -------------------------------------------------------------------------------- /data/Humans/D-FAUST/val.lst: -------------------------------------------------------------------------------- 1 | 50004_chicken_wings 2 | 50020_hips 3 | 50025_shake_arms 4 | 50020_running_on_spot 5 | 50007_light_hopping_loose 6 | -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000080.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000080.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000081.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000081.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000082.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000082.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000083.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000083.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000084.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000084.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000085.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000085.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000086.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000086.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000087.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000087.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000088.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000088.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000089.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000089.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000090.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000090.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000091.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000091.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000092.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000092.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000093.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000093.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000094.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000094.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000095.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000095.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000096.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000096.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000166.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000166.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000167.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000167.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000168.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000168.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000169.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000169.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000170.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000170.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000171.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000171.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000172.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000172.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000173.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000173.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000174.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000174.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000175.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000175.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000176.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000176.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000177.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000177.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000178.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000178.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000179.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000179.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000180.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000180.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000181.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000181.npz -------------------------------------------------------------------------------- /data/Humans/Demo/50004_punching/pcl_seq/00000182.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000182.npz -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: lpdc 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - cython=0.29.2 8 | - imageio=2.4.1 9 | - numpy=1.15.4 10 | - numpy-base=1.15.4 11 | - matplotlib=3.0.3 12 | - matplotlib-base=3.0.3 13 | - pandas=0.23.4 14 | - pillow=5.3.0 15 | - pyembree=0.1.4 16 | - pytest=4.0.2 17 | - python=3.6.7 18 | - pytorch=1.0.0 19 | - pyyaml=3.13 20 | - scikit-image=0.14.1 21 | - scikit-learn=0.21.3 22 | - scipy=1.1.0 23 | - tensorboardx=1.4 24 | - torchvision=0.2.1 25 | - tqdm=4.28.1 26 | - trimesh=2.37.7 27 | - pip=19.1.1 28 | - pip: 29 | - h5py==2.9.0 30 | - plyfile==0.7 31 | -------------------------------------------------------------------------------- /im2mesh/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/__init__.py -------------------------------------------------------------------------------- /im2mesh/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/__init__.pyc -------------------------------------------------------------------------------- /im2mesh/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import torch 4 | from torch.utils import model_zoo 5 | 6 | 7 | class CheckpointIO(object): 8 | ''' CheckpointIO class. 9 | 10 | It handles saving and loading checkpoints. 
11 | 12 | Args: 13 | checkpoint_dir (str): path where checkpoints are saved 14 | ''' 15 | def __init__(self, checkpoint_dir='./chkpts', initialize_from=None, 16 | initialization_file_name='model_best.pt', **kwargs): 17 | self.module_dict = kwargs 18 | self.checkpoint_dir = checkpoint_dir 19 | self.initialize_from = initialize_from 20 | self.initialization_file_name = initialization_file_name 21 | if not os.path.exists(checkpoint_dir): 22 | os.makedirs(checkpoint_dir) 23 | 24 | def register_modules(self, **kwargs): 25 | ''' Registers modules in current module dictionary. 26 | ''' 27 | self.module_dict.update(kwargs) 28 | 29 | def save(self, filename, **kwargs): 30 | ''' Saves the current module dictionary. 31 | 32 | Args: 33 | filename (str): name of output file 34 | ''' 35 | if not os.path.isabs(filename): 36 | filename = os.path.join(self.checkpoint_dir, filename) 37 | 38 | outdict = kwargs 39 | print(self.module_dict.keys()) 40 | for k, v in self.module_dict.items(): 41 | outdict[k] = v.state_dict() 42 | torch.save(outdict, filename) 43 | 44 | def load(self, filename): 45 | '''Loads a module dictionary from local file or url. 46 | 47 | Args: 48 | filename (str): name of saved module dictionary 49 | ''' 50 | if is_url(filename): 51 | return self.load_url(filename) 52 | else: 53 | return self.load_file(filename) 54 | 55 | def load_file(self, filename): 56 | '''Loads a module dictionary from file. 
57 | 58 | Args: 59 | filename (str): name of saved module dictionary 60 | ''' 61 | 62 | if not os.path.isabs(filename): 63 | filename = os.path.join(self.checkpoint_dir, filename) 64 | 65 | if os.path.exists(filename): 66 | print(filename) 67 | print('=> Loading checkpoint from local file...') 68 | state_dict = torch.load(filename) 69 | scalars = self.parse_state_dict(state_dict) 70 | return scalars 71 | else: 72 | if self.initialize_from is not None: 73 | self.initialize_weights() 74 | raise FileExistsError 75 | 76 | def load_url(self, url): 77 | '''Load a module dictionary from url. 78 | 79 | Args: 80 | url (str): url to saved model 81 | ''' 82 | print(url) 83 | print('=> Loading checkpoint from url...') 84 | state_dict = model_zoo.load_url(url, progress=True) 85 | scalars = self.parse_state_dict(state_dict) 86 | return scalars 87 | 88 | def parse_state_dict(self, state_dict): 89 | '''Parse state_dict of model and return scalars. 90 | 91 | Args: 92 | state_dict (dict): State dict of model 93 | ''' 94 | 95 | for k, v in self.module_dict.items(): 96 | if k in state_dict: 97 | v.load_state_dict(state_dict[k]) 98 | else: 99 | print('Warning: Could not find %s in checkpoint!' % k) 100 | scalars = {k: v for k, v in state_dict.items() 101 | if k not in self.module_dict} 102 | return scalars 103 | 104 | def initialize_weights(self): 105 | ''' Initializes the model weights from another model file. 
106 | ''' 107 | 108 | print('Intializing weights from model %s' % self.initialize_from) 109 | filename_in = os.path.join( 110 | self.initialize_from, self.initialization_file_name) 111 | 112 | model_state_dict = self.module_dict.get('model').state_dict() 113 | model_dict = self.module_dict.get('model').state_dict() 114 | model_keys = set([k for (k, v) in model_dict.items()]) 115 | 116 | init_model_dict = torch.load(filename_in)['model'] 117 | init_model_k = set([k for (k, v) in init_model_dict.items()]) 118 | 119 | for k in model_keys: 120 | if ((k in init_model_k) and (model_state_dict[k].shape == 121 | init_model_dict[k].shape)): 122 | model_state_dict[k] = init_model_dict[k] 123 | self.module_dict.get('model').load_state_dict(model_state_dict) 124 | 125 | 126 | def is_url(url): 127 | ''' Checks if input is url.''' 128 | scheme = urllib.parse.urlparse(url).scheme 129 | return scheme in ('http', 'https') 130 | -------------------------------------------------------------------------------- /im2mesh/checkpoints.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/checkpoints.pyc -------------------------------------------------------------------------------- /im2mesh/config.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/config.pyc -------------------------------------------------------------------------------- /im2mesh/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from im2mesh.data.core import ( 3 | Shapes3dDataset, collate_remove_none, worker_init_fn 4 | ) 5 | from im2mesh.data.subseq_dataset import ( 6 | HumansDataset 7 | ) 8 | from im2mesh.data.fields import ( 9 | IndexField, CategoryField, 10 | PointsSubseqField, 
ImageSubseqField, 11 | PointCloudSubseqField, MeshSubseqField, 12 | ) 13 | 14 | from im2mesh.data.transforms import ( 15 | PointcloudNoise, 16 | SubsamplePointcloud, 17 | SubsamplePoints, 18 | # Temporal transforms 19 | SubsamplePointsSeq, SubsamplePointcloudSeq, 20 | ) 21 | 22 | 23 | __all__ = [ 24 | # Core 25 | Shapes3dDataset, 26 | collate_remove_none, 27 | worker_init_fn, 28 | # Humans Dataset 29 | HumansDataset, 30 | # Fields 31 | IndexField, 32 | CategoryField, 33 | PointsSubseqField, 34 | PointCloudSubseqField, 35 | ImageSubseqField, 36 | MeshSubseqField, 37 | # Transforms 38 | PointcloudNoise, 39 | SubsamplePointcloud, 40 | SubsamplePoints, 41 | # Temporal Transforms 42 | SubsamplePointsSeq, 43 | SubsamplePointcloudSeq, 44 | ] 45 | -------------------------------------------------------------------------------- /im2mesh/data/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__init__.pyc -------------------------------------------------------------------------------- /im2mesh/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/data/__pycache__/core.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/core.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/data/__pycache__/fields.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/fields.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/data/__pycache__/subseq_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/subseq_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/data/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/data/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from torch.utils import data 4 | import numpy as np 5 | import yaml 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | # Fields 12 | class Field(object): 13 | ''' Data fields class. 14 | ''' 15 | 16 | def load(self, data_path, idx, category): 17 | ''' Loads a data point. 18 | 19 | Args: 20 | data_path (str): path to data file 21 | idx (int): index of data point 22 | category (int): index of category 23 | ''' 24 | raise NotImplementedError 25 | 26 | def check_complete(self, files): 27 | ''' Checks if set is complete. 28 | 29 | Args: 30 | files: files 31 | ''' 32 | raise NotImplementedError 33 | 34 | 35 | class Shapes3dDataset(data.Dataset): 36 | ''' 3D Shapes dataset class. 37 | ''' 38 | 39 | def __init__(self, dataset_folder, fields, split=None, 40 | categories=None, no_except=True, transform=None): 41 | ''' Initialization of the the 3D shape dataset. 
42 | 43 | Args: 44 | dataset_folder (str): dataset folder 45 | fields (dict): dictionary of fields 46 | split (str): which split is used 47 | categories (list): list of categories to use 48 | no_except (bool): no exception 49 | transform (callable): transformation applied to data points 50 | ''' 51 | # Attributes 52 | self.dataset_folder = dataset_folder 53 | self.fields = fields 54 | self.no_except = no_except 55 | self.transform = transform 56 | 57 | # If categories is None, use all subfolders 58 | if categories is None: 59 | categories = os.listdir(dataset_folder) 60 | categories = [c for c in categories 61 | if os.path.isdir(os.path.join(dataset_folder, c))] 62 | 63 | # Read metadata file 64 | metadata_file = os.path.join(dataset_folder, 'metadata.yaml') 65 | 66 | if os.path.exists(metadata_file): 67 | with open(metadata_file, 'r') as f: 68 | self.metadata = yaml.load(f) 69 | else: 70 | self.metadata = { 71 | c: {'id': c, 'name': 'n/a'} for c in categories 72 | } 73 | 74 | # Set index 75 | for c_idx, c in enumerate(categories): 76 | self.metadata[c]['idx'] = c_idx 77 | 78 | # Get all models 79 | self.models = [] 80 | for c_idx, c in enumerate(categories): 81 | subpath = os.path.join(dataset_folder, c) 82 | if not os.path.isdir(subpath): 83 | logger.warning('Category %s does not exist in dataset.' % c) 84 | 85 | split_file = os.path.join(subpath, split + '.lst') 86 | with open(split_file, 'r') as f: 87 | models_c = f.read().split('\n') 88 | 89 | self.models += [ 90 | {'category': c, 'model': m} 91 | for m in models_c 92 | ] 93 | 94 | def __len__(self): 95 | ''' Returns the length of the dataset. 96 | ''' 97 | return len(self.models) 98 | 99 | def __getitem__(self, idx): 100 | ''' Returns an item of the dataset. 
101 | 102 | Args: 103 | idx (int): ID of data point 104 | ''' 105 | category = self.models[idx]['category'] 106 | model = self.models[idx]['model'] 107 | c_idx = self.metadata[category]['idx'] 108 | 109 | model_path = os.path.join(self.dataset_folder, category, model) 110 | data = {} 111 | 112 | for field_name, field in self.fields.items(): 113 | try: 114 | field_data = field.load(model_path, idx, c_idx) 115 | except Exception: 116 | if self.no_except: 117 | logger.warn( 118 | 'Error occured when loading field %s of model %s' 119 | % (field_name, model) 120 | ) 121 | return None 122 | else: 123 | raise 124 | 125 | if isinstance(field_data, dict): 126 | for k, v in field_data.items(): 127 | if k is None: 128 | data[field_name] = v 129 | else: 130 | data['%s.%s' % (field_name, k)] = v 131 | else: 132 | data[field_name] = field_data 133 | 134 | if self.transform is not None: 135 | data = self.transform(data) 136 | 137 | return data 138 | 139 | def get_model_dict(self, idx): 140 | return self.models[idx] 141 | 142 | def test_model_complete(self, category, model): 143 | ''' Tests if model is complete. 144 | 145 | Args: 146 | model (str): modelname 147 | ''' 148 | model_path = os.path.join(self.dataset_folder, category, model) 149 | files = os.listdir(model_path) 150 | for field_name, field in self.fields.items(): 151 | if not field.check_complete(files): 152 | logger.warn('Field "%s" is incomplete: %s' 153 | % (field_name, model_path)) 154 | return False 155 | 156 | return True 157 | 158 | 159 | def collate_remove_none(batch): 160 | ''' Collater that puts each data field into a tensor with outer dimension 161 | batch size. 162 | 163 | Args: 164 | batch: batch 165 | ''' 166 | 167 | batch = list(filter(lambda x: x is not None, batch)) 168 | return data.dataloader.default_collate(batch) 169 | 170 | 171 | def worker_init_fn(worker_id): 172 | ''' Worker init function to ensure true randomness. 
173 | ''' 174 | random_data = os.urandom(4) 175 | base_seed = int.from_bytes(random_data, byteorder="big") 176 | np.random.seed(base_seed + worker_id) 177 | -------------------------------------------------------------------------------- /im2mesh/data/core.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/core.pyc -------------------------------------------------------------------------------- /im2mesh/data/fields.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/fields.pyc -------------------------------------------------------------------------------- /im2mesh/data/subseq_dataset.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/subseq_dataset.pyc -------------------------------------------------------------------------------- /im2mesh/data/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # Transforms 5 | class PointcloudNoise(object): 6 | ''' Point cloud noise transformation class. 7 | 8 | It adds noise to point cloud data. 9 | 10 | Args: 11 | stddev (int): standard deviation 12 | ''' 13 | 14 | def __init__(self, stddev): 15 | self.stddev = stddev 16 | 17 | def __call__(self, data): 18 | ''' Calls the transformation. 
19 | 20 | Args: 21 | data (dictionary): data dictionary 22 | ''' 23 | data_out = data.copy() 24 | points = data[None] 25 | noise = self.stddev * np.random.randn(*points.shape) 26 | noise = noise.astype(np.float32) 27 | data_out[None] = points + noise 28 | return data_out 29 | 30 | 31 | class SubsamplePointcloud(object): 32 | ''' Point cloud subsampling transformation class. 33 | 34 | It subsamples the point cloud data. 35 | 36 | Args: 37 | N (int): number of points to be subsampled 38 | ''' 39 | 40 | def __init__(self, N): 41 | self.N = N 42 | 43 | def __call__(self, data): 44 | ''' Calls the transformation. 45 | 46 | Args: 47 | data (dict): data dictionary 48 | ''' 49 | data_out = data.copy() 50 | points = data[None] 51 | 52 | indices = np.random.randint(points.shape[0], size=self.N) 53 | data_out[None] = points[indices, :] 54 | 55 | if 'normals' in data.keys(): 56 | normals = data['normals'] 57 | data_out['normals'] = normals[indices, :] 58 | 59 | return data_out 60 | 61 | 62 | class SubsamplePoints(object): 63 | ''' Points subsampling transformation class. 64 | It subsamples the points data. 65 | Args: 66 | N (int): number of points to be subsampled 67 | ''' 68 | 69 | def __init__(self, N): 70 | self.N = N 71 | 72 | def __call__(self, data): 73 | ''' Calls the transformation. 
74 | Args: 75 | data (dictionary): data dictionary 76 | ''' 77 | points = data[None] 78 | occ = data['occ'] 79 | 80 | data_out = data.copy() 81 | if isinstance(self.N, int): 82 | idx = np.random.randint(points.shape[0], size=self.N) 83 | data_out.update({ 84 | None: points[idx, :], 85 | 'occ': occ[idx], 86 | }) 87 | else: 88 | Nt_out, Nt_in = self.N 89 | occ_binary = (occ >= 0.5) 90 | points0 = points[~occ_binary] 91 | points1 = points[occ_binary] 92 | 93 | idx0 = np.random.randint(points0.shape[0], size=Nt_out) 94 | idx1 = np.random.randint(points1.shape[0], size=Nt_in) 95 | 96 | points0 = points0[idx0, :] 97 | points1 = points1[idx1, :] 98 | points = np.concatenate([points0, points1], axis=0) 99 | 100 | occ0 = np.zeros(Nt_out, dtype=np.float32) 101 | occ1 = np.ones(Nt_in, dtype=np.float32) 102 | occ = np.concatenate([occ0, occ1], axis=0) 103 | 104 | volume = occ_binary.sum() / len(occ_binary) 105 | volume = volume.astype(np.float32) 106 | 107 | data_out.update({ 108 | None: points, 109 | 'occ': occ, 110 | 'volume': volume, 111 | }) 112 | return data_out 113 | 114 | 115 | class SubsamplePointcloudSeq(object): 116 | ''' Point cloud sequence subsampling transformation class. 117 | 118 | It subsamples the point cloud sequence data. 119 | 120 | Args: 121 | N (int): number of points to be subsampled 122 | connected_samples (bool): whether to obtain connected samples 123 | random (bool): whether to sub-sample randomly 124 | ''' 125 | 126 | def __init__(self, N, connected_samples=False, random=True): 127 | self.N = N 128 | self.connected_samples = connected_samples 129 | self.random = random 130 | 131 | def __call__(self, data): 132 | ''' Calls the transformation. 
133 | 134 | Args: 135 | data (dictionary): data dictionary 136 | ''' 137 | data_out = data.copy() 138 | points = data[None] # n_steps x T x 3 139 | n_steps, T, dim = points.shape 140 | N_max = min(self.N, T) 141 | if self.connected_samples or not self.random: 142 | indices = (np.random.randint(T, size=self.N) if self.random else 143 | np.arange(N_max)) 144 | data_out[None] = points[:, indices, :] 145 | else: 146 | indices = np.random.randint(T, size=(n_steps, self.N)) 147 | data_out[None] = \ 148 | points[np.arange(n_steps).reshape(-1, 1), indices, :] 149 | return data_out 150 | 151 | class SubsamplePointsSeq(object): 152 | ''' Points sequence subsampling transformation class. 153 | 154 | It subsamples the points sequence data. 155 | 156 | Args: 157 | N (int): number of points to be subsampled 158 | connected_samples (bool): whether to obtain connected samples 159 | random (bool): whether to sub-sample randomly 160 | ''' 161 | 162 | def __init__(self, N, connected_samples=False, random=True): 163 | self.N = N 164 | self.connected_samples = connected_samples 165 | self.random = random 166 | 167 | def __call__(self, data): 168 | ''' Calls the transformation. 
169 | 170 | Args: 171 | data (dictionary): data dictionary 172 | ''' 173 | points = data[None] 174 | occ = data['occ'] 175 | data_out = data.copy() 176 | n_steps, T, dim = points.shape 177 | 178 | N_max = min(self.N, T) 179 | 180 | if self.connected_samples or not self.random: 181 | indices = (np.random.randint(T, size=self.N) if self.random 182 | else np.arange(N_max)) 183 | data_out.update({ 184 | None: points[:, indices], 185 | 'occ': occ[:, indices], 186 | }) 187 | else: 188 | indices = np.random.randint(T, size=(n_steps, self.N)) 189 | help_arr = np.arange(n_steps).reshape(-1, 1) 190 | data_out.update({ 191 | None: points[help_arr, indices, :], 192 | 'occ': occ[help_arr, indices, :] 193 | }) 194 | return data_out 195 | -------------------------------------------------------------------------------- /im2mesh/data/transforms.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/transforms.pyc -------------------------------------------------------------------------------- /im2mesh/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from im2mesh.encoder import pointnet 2 | 3 | 4 | encoder_dict = { 5 | 'pointnet_simple': pointnet.SimplePointnet, 6 | } 7 | 8 | encoder_temporal_dict = { 9 | 'pointnet_spatiotemporal': pointnet.SpatioTemporalResnetPointnet, 10 | 'pointnet_spatiotemporal2': pointnet.SpatioTemporalResnetPointnet2, 11 | } -------------------------------------------------------------------------------- /im2mesh/encoder/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__init__.pyc -------------------------------------------------------------------------------- 
/im2mesh/encoder/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/encoder/__pycache__/conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/conv.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/encoder/__pycache__/pointnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/pointnet.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/encoder/__pycache__/pointnet_unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/pointnet_unet.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/encoder/__pycache__/unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/unet.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/encoder/__pycache__/unet3d.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/unet3d.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/encoder/conv.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/conv.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc/__init__.py: -------------------------------------------------------------------------------- 1 | from im2mesh.lpdc import ( 2 | config, generation, training, models 3 | ) 4 | 5 | __all__ = [ 6 | config, generation, training, models 7 | ] -------------------------------------------------------------------------------- /im2mesh/lpdc/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc/__pycache__/generation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/__pycache__/generation.cpython-36.pyc -------------------------------------------------------------------------------- 
/im2mesh/lpdc/__pycache__/training.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/__pycache__/training.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc/models/__pycache__/decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/models/__pycache__/decoder.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc/models/__pycache__/decoder_unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/models/__pycache__/decoder_unet.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc/models/__pycache__/displacement.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/models/__pycache__/displacement.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc/models/decoder.py: -------------------------------------------------------------------------------- 
1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from im2mesh.layers import ( 4 | ResnetBlockFC, CResnetBlockConv1d, 5 | CBatchNorm1d, CBatchNorm1d_legacy 6 | ) 7 | 8 | 9 | class Decoder(nn.Module): 10 | ''' Basic Decoder network for OFlow class. 11 | 12 | The decoder network maps points together with latent conditioned codes 13 | c and z to log probabilities of occupancy for the points. This basic 14 | decoder does not use batch normalization. 15 | 16 | Args: 17 | dim (int): dimension of input points 18 | z_dim (int): dimension of latent code z 19 | c_dim (int): dimension of latent conditioned code c 20 | hidden_size (int): dimension of hidden size 21 | leaky (bool): whether to use leaky ReLUs as activation 22 | ''' 23 | 24 | def __init__(self, dim=3, z_dim=128, c_dim=128, 25 | hidden_size=128, leaky=False, **kwargs): 26 | super().__init__() 27 | self.z_dim = z_dim 28 | self.c_dim = c_dim 29 | self.dim = dim 30 | 31 | # Submodules 32 | self.fc_p = nn.Linear(dim, hidden_size) 33 | 34 | if not z_dim == 0: 35 | self.fc_z = nn.Linear(z_dim, hidden_size) 36 | if not c_dim == 0: 37 | self.fc_c = nn.Linear(c_dim, hidden_size) 38 | 39 | self.block0 = ResnetBlockFC(hidden_size) 40 | self.block1 = ResnetBlockFC(hidden_size) 41 | self.block2 = ResnetBlockFC(hidden_size) 42 | self.block3 = ResnetBlockFC(hidden_size) 43 | self.block4 = ResnetBlockFC(hidden_size) 44 | 45 | self.fc_out = nn.Linear(hidden_size, 1) 46 | 47 | if not leaky: 48 | self.actvn = F.relu 49 | else: 50 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 51 | 52 | def forward(self, p, z=None, c=None, **kwargs): 53 | ''' Performs a forward pass through the network. 
54 | 55 | Args: 56 | p (tensor): points tensor 57 | z (tensor): latent code z 58 | c (tensor): latent conditioned code c 59 | ''' 60 | batch_size = p.shape[0] 61 | p = p.view(batch_size, -1, self.dim) 62 | net = self.fc_p(p) 63 | 64 | if self.z_dim != 0: 65 | net_z = self.fc_z(z).unsqueeze(1) 66 | net = net + net_z 67 | 68 | if self.c_dim != 0: 69 | net_c = self.fc_c(c).unsqueeze(1) 70 | net = net + net_c 71 | 72 | net = self.block0(net) 73 | net = self.block1(net) 74 | net = self.block2(net) 75 | net = self.block3(net) 76 | net = self.block4(net) 77 | 78 | out = self.fc_out(self.actvn(net)) 79 | out = out.squeeze(-1) 80 | 81 | return out 82 | 83 | 84 | class DecoderCBatchNorm(nn.Module): 85 | ''' Conditioned Batch Norm Decoder network for OFlow class. 86 | 87 | The decoder network maps points together with latent conditioned codes 88 | c and z to log probabilities of occupancy for the points. This decoder 89 | uses conditioned batch normalization to inject the latent codes. 90 | 91 | Args: 92 | dim (int): dimension of input points 93 | z_dim (int): dimension of latent code z 94 | c_dim (int): dimension of latent conditioned code c 95 | hidden_size (int): dimension of hidden size 96 | leaky (bool): whether to use leaky ReLUs as activation 97 | 98 | ''' 99 | 100 | def __init__(self, dim=3, z_dim=128, c_dim=128, 101 | hidden_size=256, leaky=False, legacy=False): 102 | super().__init__() 103 | self.z_dim = z_dim 104 | self.dim = dim 105 | if not z_dim == 0: 106 | self.fc_z = nn.Linear(z_dim, hidden_size) 107 | 108 | self.fc_p = nn.Conv1d(dim, hidden_size, 1) 109 | self.block0 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 110 | self.block1 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 111 | self.block2 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 112 | self.block3 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 113 | self.block4 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 114 | 115 | if not legacy: 116 | self.bn = 
CBatchNorm1d(c_dim, hidden_size) 117 | else: 118 | self.bn = CBatchNorm1d_legacy(c_dim, hidden_size) 119 | 120 | self.fc_out = nn.Conv1d(hidden_size, 1, 1) 121 | 122 | if not leaky: 123 | self.actvn = F.relu 124 | else: 125 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 126 | 127 | def forward(self, p, z, c, **kwargs): 128 | ''' Performs a forward pass through the network. 129 | 130 | Args: 131 | p (tensor): points tensor 132 | z (tensor): latent code z 133 | c (tensor): latent conditioned code c 134 | ''' 135 | p = p.transpose(1, 2) 136 | batch_size, D, T = p.size() 137 | net = self.fc_p(p) 138 | 139 | net = self.block0(net, c) 140 | net = self.block1(net, c) 141 | net = self.block2(net, c) 142 | net = self.block3(net, c) 143 | net = self.block4(net, c) 144 | 145 | out = self.fc_out(self.actvn(self.bn(net, c))) 146 | out = out.squeeze(1) 147 | 148 | return out 149 | -------------------------------------------------------------------------------- /im2mesh/lpdc/models/displacement.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from im2mesh.layers import ResnetBlockFC 5 | 6 | class DisplacementDecoder(nn.Module): 7 | ''' DisplacementDecoder network class. 8 | 9 | It maps input points and time values together with (optional) conditioned 10 | codes c and latent codes z to the respective motion vectors. 
11 | 12 | Args: 13 | in_dim (int): input dimension of points concatenated with the time axis 14 | out_dim (int): output dimension of motion vectors 15 | z_dim (int): dimension of latent code z 16 | c_dim (int): dimension of latent conditioned code c 17 | hidden_size (int): size of the hidden dimension 18 | leaky (bool): whether to use leaky ReLUs as activation 19 | n_blocks (int): number of ResNet-based blocks 20 | ''' 21 | 22 | def __init__(self, in_dim=3, out_dim=3, c_dim=128, 23 | hidden_size=512, leaky=False, n_blocks=5, **kwargs): 24 | super().__init__() 25 | self.c_dim = c_dim 26 | self.in_dim = in_dim 27 | self.out_dim = out_dim 28 | self.n_blocks = n_blocks 29 | # Submodules 30 | self.fc_p = nn.Linear(in_dim, hidden_size) 31 | 32 | self.fc_in = nn.Linear(c_dim * 2, c_dim) 33 | if c_dim != 0: 34 | self.fc_c = nn.ModuleList([ 35 | nn.Linear(c_dim, hidden_size) for i in range(n_blocks)]) 36 | 37 | self.blocks = nn.ModuleList([ 38 | ResnetBlockFC(hidden_size) for i in range(n_blocks) 39 | ]) 40 | 41 | self.fc_out = nn.Linear(hidden_size, self.out_dim) 42 | 43 | if not leaky: 44 | self.actvn = F.relu 45 | else: 46 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 47 | 48 | 49 | def forward(self, p, cur_t, fuc_t, c): 50 | batch_size, nsteps, c_dim = c.shape 51 | _, npoints, dim = p.shape 52 | cur_t = torch.clamp(cur_t[:, :, None] * nsteps, 0, nsteps-1).expand(batch_size, 1, c_dim).type(torch.LongTensor).to(p.device) 53 | fuc_t = torch.clamp(fuc_t[:, :, None] * nsteps, 0, nsteps-1).expand(batch_size, 1, c_dim).type(torch.LongTensor).to(p.device) 54 | cur_c = torch.gather(c, 1, cur_t) 55 | fuc_c = torch.gather(c, 1, fuc_t) 56 | # glo_c = torch.mean(c, dim=1) 57 | # concat_c = torch.cat([cur_c.squeeze(1), fuc_c.squeeze(1), glo_c], dim=1) 58 | concat_c = torch.cat([cur_c.squeeze(1), fuc_c.squeeze(1)], dim=1) 59 | concat_c = self.fc_in(concat_c) 60 | net = self.fc_p(p) 61 | 62 | # Layer loop 63 | for i in range(self.n_blocks): 64 | if self.c_dim != 0: 65 | net_c = 
self.fc_c[i](concat_c).unsqueeze(1) 66 | net = net + net_c 67 | net = self.blocks[i](net) 68 | 69 | out = self.fc_out(self.actvn(net)) 70 | return out -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/__init__.py: -------------------------------------------------------------------------------- 1 | from im2mesh.lpdc_uneven import ( 2 | config, generation, training, models 3 | ) 4 | 5 | __all__ = [ 6 | config, generation, training, models 7 | ] -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/__pycache__/generation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/__pycache__/generation.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/__pycache__/training.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/__pycache__/training.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/models/__pycache__/decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/models/__pycache__/decoder.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/models/__pycache__/decoder_unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/models/__pycache__/decoder_unet.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/models/__pycache__/displacement.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/models/__pycache__/displacement.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 
3 | from im2mesh.layers import ( 4 | ResnetBlockFC, CResnetBlockConv1d, 5 | CBatchNorm1d, CBatchNorm1d_legacy 6 | ) 7 | 8 | 9 | class Decoder(nn.Module): 10 | ''' Basic Decoder network for OFlow class. 11 | 12 | The decoder network maps points together with latent conditioned codes 13 | c and z to log probabilities of occupancy for the points. This basic 14 | decoder does not use batch normalization. 15 | 16 | Args: 17 | dim (int): dimension of input points 18 | z_dim (int): dimension of latent code z 19 | c_dim (int): dimension of latent conditioned code c 20 | hidden_size (int): dimension of hidden size 21 | leaky (bool): whether to use leaky ReLUs as activation 22 | ''' 23 | 24 | def __init__(self, dim=3, z_dim=128, c_dim=128, 25 | hidden_size=128, leaky=False, **kwargs): 26 | super().__init__() 27 | self.z_dim = z_dim 28 | self.c_dim = c_dim 29 | self.dim = dim 30 | 31 | # Submodules 32 | self.fc_p = nn.Linear(dim, hidden_size) 33 | 34 | if not z_dim == 0: 35 | self.fc_z = nn.Linear(z_dim, hidden_size) 36 | if not c_dim == 0: 37 | self.fc_c = nn.Linear(c_dim, hidden_size) 38 | 39 | self.block0 = ResnetBlockFC(hidden_size) 40 | self.block1 = ResnetBlockFC(hidden_size) 41 | self.block2 = ResnetBlockFC(hidden_size) 42 | self.block3 = ResnetBlockFC(hidden_size) 43 | self.block4 = ResnetBlockFC(hidden_size) 44 | 45 | self.fc_out = nn.Linear(hidden_size, 1) 46 | 47 | if not leaky: 48 | self.actvn = F.relu 49 | else: 50 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 51 | 52 | def forward(self, p, z=None, c=None, **kwargs): 53 | ''' Performs a forward pass through the network. 
54 | 55 | Args: 56 | p (tensor): points tensor 57 | z (tensor): latent code z 58 | c (tensor): latent conditioned code c 59 | ''' 60 | batch_size = p.shape[0] 61 | p = p.view(batch_size, -1, self.dim) 62 | net = self.fc_p(p) 63 | 64 | if self.z_dim != 0: 65 | net_z = self.fc_z(z).unsqueeze(1) 66 | net = net + net_z 67 | 68 | if self.c_dim != 0: 69 | net_c = self.fc_c(c).unsqueeze(1) 70 | net = net + net_c 71 | 72 | net = self.block0(net) 73 | net = self.block1(net) 74 | net = self.block2(net) 75 | net = self.block3(net) 76 | net = self.block4(net) 77 | 78 | out = self.fc_out(self.actvn(net)) 79 | out = out.squeeze(-1) 80 | 81 | return out 82 | 83 | 84 | class DecoderCBatchNorm(nn.Module): 85 | ''' Conditioned Batch Norm Decoder network for OFlow class. 86 | 87 | The decoder network maps points together with latent conditioned codes 88 | c and z to log probabilities of occupancy for the points. This decoder 89 | uses conditioned batch normalization to inject the latent codes. 90 | 91 | Args: 92 | dim (int): dimension of input points 93 | z_dim (int): dimension of latent code z 94 | c_dim (int): dimension of latent conditioned code c 95 | hidden_size (int): dimension of hidden size 96 | leaky (bool): whether to use leaky ReLUs as activation 97 | 98 | ''' 99 | 100 | def __init__(self, dim=3, z_dim=128, c_dim=128, 101 | hidden_size=256, leaky=False, legacy=False): 102 | super().__init__() 103 | self.z_dim = z_dim 104 | self.dim = dim 105 | if not z_dim == 0: 106 | self.fc_z = nn.Linear(z_dim, hidden_size) 107 | 108 | self.fc_p = nn.Conv1d(dim, hidden_size, 1) 109 | self.block0 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 110 | self.block1 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 111 | self.block2 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 112 | self.block3 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 113 | self.block4 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy) 114 | 115 | if not legacy: 116 | self.bn = 
CBatchNorm1d(c_dim, hidden_size) 117 | else: 118 | self.bn = CBatchNorm1d_legacy(c_dim, hidden_size) 119 | 120 | self.fc_out = nn.Conv1d(hidden_size, 1, 1) 121 | 122 | if not leaky: 123 | self.actvn = F.relu 124 | else: 125 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 126 | 127 | def forward(self, p, z, c, **kwargs): 128 | ''' Performs a forward pass through the network. 129 | 130 | Args: 131 | p (tensor): points tensor 132 | z (tensor): latent code z 133 | c (tensor): latent conditioned code c 134 | ''' 135 | p = p.transpose(1, 2) 136 | batch_size, D, T = p.size() 137 | net = self.fc_p(p) 138 | 139 | net = self.block0(net, c) 140 | net = self.block1(net, c) 141 | net = self.block2(net, c) 142 | net = self.block3(net, c) 143 | net = self.block4(net, c) 144 | 145 | out = self.fc_out(self.actvn(self.bn(net, c))) 146 | out = out.squeeze(1) 147 | 148 | return out 149 | -------------------------------------------------------------------------------- /im2mesh/lpdc_uneven/models/displacement.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from im2mesh.layers import ResnetBlockFC 5 | 6 | class DisplacementDecoder(nn.Module): 7 | ''' DisplacementDecoder network class. 8 | 9 | It maps input points and time values together with (optional) conditioned 10 | codes c and latent codes z to the respective motion vectors. 
11 | 12 | Args: 13 | in_dim (int): input dimension of points concatenated with the time axis 14 | out_dim (int): output dimension of motion vectors 15 | z_dim (int): dimension of latent code z 16 | c_dim (int): dimension of latent conditioned code c 17 | hidden_size (int): size of the hidden dimension 18 | leaky (bool): whether to use leaky ReLUs as activation 19 | n_blocks (int): number of ResNet-based blocks 20 | ''' 21 | 22 | def __init__(self, in_dim=3, out_dim=3, c_dim=128, 23 | hidden_size=512, leaky=False, n_blocks=5, **kwargs): 24 | super().__init__() 25 | self.c_dim = c_dim 26 | self.in_dim = in_dim 27 | self.out_dim = out_dim 28 | self.n_blocks = n_blocks 29 | # Submodules 30 | self.fc_p = nn.Linear(in_dim, hidden_size) 31 | 32 | self.fc_in = nn.Linear(c_dim * 2, c_dim) 33 | if c_dim != 0: 34 | self.fc_c = nn.ModuleList([ 35 | nn.Linear(c_dim, hidden_size) for i in range(n_blocks)]) 36 | 37 | self.blocks = nn.ModuleList([ 38 | ResnetBlockFC(hidden_size) for i in range(n_blocks) 39 | ]) 40 | 41 | self.fc_out = nn.Linear(hidden_size, self.out_dim) 42 | 43 | if not leaky: 44 | self.actvn = F.relu 45 | else: 46 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 47 | 48 | 49 | def forward(self, p, cur_t, fuc_t, c): 50 | batch_size, nsteps, c_dim = c.shape 51 | _, npoints, dim = p.shape 52 | cur_t = torch.clamp(cur_t[:, :, None] * nsteps, 0, nsteps-1).expand(batch_size, 1, c_dim).type(torch.LongTensor).to(p.device) 53 | fuc_t = torch.clamp(fuc_t[:, :, None] * nsteps, 0, nsteps-1).expand(batch_size, 1, c_dim).type(torch.LongTensor).to(p.device) 54 | cur_c = torch.gather(c, 1, cur_t) 55 | fuc_c = torch.gather(c, 1, fuc_t) 56 | #glo_c = torch.mean(c, dim=1) 57 | #concat_c = torch.cat([cur_c.squeeze(1), fuc_c.squeeze(1), glo_c], dim=1) 58 | concat_c = torch.cat([cur_c.squeeze(1), fuc_c.squeeze(1)], dim=1) 59 | concat_c = self.fc_in(concat_c) 60 | net = self.fc_p(p) 61 | 62 | # Layer loop 63 | for i in range(self.n_blocks): 64 | if self.c_dim != 0: 65 | net_c = 
self.fc_c[i](concat_c).unsqueeze(1) 66 | net = net + net_c 67 | net = self.blocks[i](net) 68 | 69 | out = self.fc_out(self.actvn(net)) 70 | return out 71 | -------------------------------------------------------------------------------- /im2mesh/training.py: -------------------------------------------------------------------------------- 1 | # from im2mesh import icp 2 | import numpy as np 3 | from collections import defaultdict 4 | from tqdm import tqdm 5 | 6 | 7 | class BaseTrainer(object): 8 | ''' Base trainer class. 9 | ''' 10 | 11 | def evaluate(self, val_loader): 12 | ''' Performs an evaluation. 13 | Args: 14 | val_loader (dataloader): Pytorch dataloader 15 | ''' 16 | eval_list = defaultdict(list) 17 | 18 | for data in tqdm(val_loader): 19 | eval_step_dict = self.eval_step(data) 20 | 21 | for k, v in eval_step_dict.items(): 22 | eval_list[k].append(v) 23 | 24 | eval_dict = {k: np.mean(v) for k, v in eval_list.items()} 25 | return eval_dict 26 | 27 | def train_step(self, *args, **kwargs): 28 | ''' Performs a training step. 29 | ''' 30 | raise NotImplementedError 31 | 32 | def eval_step(self, *args, **kwargs): 33 | ''' Performs an evaluation step. 34 | ''' 35 | raise NotImplementedError 36 | 37 | def visualize(self, *args, **kwargs): 38 | ''' Performs visualization. 
39 | ''' 40 | raise NotImplementedError 41 | -------------------------------------------------------------------------------- /im2mesh/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__init__.py -------------------------------------------------------------------------------- /im2mesh/utils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__init__.pyc -------------------------------------------------------------------------------- /im2mesh/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/__pycache__/grad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/grad.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/__pycache__/io.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/io.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/__pycache__/onet_generator.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/onet_generator.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/__pycache__/visualize.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/visualize.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/grad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | from torch.autograd import grad 5 | 6 | def gradient(inputs, outputs): 7 | d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device) 8 | points_grad = grad( 9 | outputs=outputs, 10 | inputs=inputs, 11 | grad_outputs=d_points, 12 | create_graph=True, 13 | retain_graph=True, 14 | only_inputs=True)[0][:, -3:] 15 | return points_grad 16 | 17 | 18 | def gradient2(inputs, outputs): 19 | d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device) 20 | points_grad = grad( 21 | outputs=outputs, 22 | inputs=inputs, 23 | grad_outputs=d_points, 24 | create_graph=True, 25 | retain_graph=True, 26 | only_inputs=True)[0][:, : -3:] 27 | return points_grad -------------------------------------------------------------------------------- /im2mesh/utils/icp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.neighbors import NearestNeighbors 3 | 4 | 5 | def best_fit_transform(A, B): 6 | ''' 7 | Calculates the least-squares best-fit transform that maps corresponding 8 | points A to B in m spatial dimensions 9 | Input: 10 | A: Nxm numpy array of corresponding points 11 | B: Nxm numpy array of corresponding 
points 12 | Returns: 13 | T: (m+1)x(m+1) homogeneous transformation matrix that maps A on to B 14 | R: mxm rotation matrix 15 | t: mx1 translation vector 16 | ''' 17 | 18 | assert A.shape == B.shape 19 | 20 | # get number of dimensions 21 | m = A.shape[1] 22 | 23 | # translate points to their centroids 24 | centroid_A = np.mean(A, axis=0) 25 | centroid_B = np.mean(B, axis=0) 26 | AA = A - centroid_A 27 | BB = B - centroid_B 28 | 29 | # rotation matrix 30 | H = np.dot(AA.T, BB) 31 | U, S, Vt = np.linalg.svd(H) 32 | R = np.dot(Vt.T, U.T) 33 | 34 | # special reflection case 35 | if np.linalg.det(R) < 0: 36 | Vt[m-1,:] *= -1 37 | R = np.dot(Vt.T, U.T) 38 | 39 | # translation 40 | t = centroid_B.T - np.dot(R,centroid_A.T) 41 | 42 | # homogeneous transformation 43 | T = np.identity(m+1) 44 | T[:m, :m] = R 45 | T[:m, m] = t 46 | 47 | return T, R, t 48 | 49 | 50 | def nearest_neighbor(src, dst): 51 | ''' 52 | Find the nearest (Euclidean) neighbor in dst for each point in src 53 | Input: 54 | src: Nxm array of points 55 | dst: Nxm array of points 56 | Output: 57 | distances: Euclidean distances of the nearest neighbor 58 | indices: dst indices of the nearest neighbor 59 | ''' 60 | 61 | assert src.shape == dst.shape 62 | 63 | neigh = NearestNeighbors(n_neighbors=1) 64 | neigh.fit(dst) 65 | distances, indices = neigh.kneighbors(src, return_distance=True) 66 | return distances.ravel(), indices.ravel() 67 | 68 | 69 | def icp(A, B, init_pose=None, max_iterations=20, tolerance=0.001): 70 | ''' 71 | The Iterative Closest Point method: finds best-fit transform that maps 72 | points A on to points B 73 | Input: 74 | A: Nxm numpy array of source mD points 75 | B: Nxm numpy array of destination mD point 76 | init_pose: (m+1)x(m+1) homogeneous transformation 77 | max_iterations: exit algorithm after max_iterations 78 | tolerance: convergence criteria 79 | Output: 80 | T: final homogeneous transformation that maps A on to B 81 | distances: Euclidean distances (errors) of the nearest 
neighbor 82 | i: number of iterations to converge 83 | ''' 84 | 85 | assert A.shape == B.shape 86 | 87 | # get number of dimensions 88 | m = A.shape[1] 89 | 90 | # make points homogeneous, copy them to maintain the originals 91 | src = np.ones((m+1,A.shape[0])) 92 | dst = np.ones((m+1,B.shape[0])) 93 | src[:m,:] = np.copy(A.T) 94 | dst[:m,:] = np.copy(B.T) 95 | 96 | # apply the initial pose estimation 97 | if init_pose is not None: 98 | src = np.dot(init_pose, src) 99 | 100 | prev_error = 0 101 | 102 | for i in range(max_iterations): 103 | # find the nearest neighbors between the current source and destination points 104 | distances, indices = nearest_neighbor(src[:m,:].T, dst[:m,:].T) 105 | 106 | # compute the transformation between the current source and nearest destination points 107 | T,_,_ = best_fit_transform(src[:m,:].T, dst[:m,indices].T) 108 | 109 | # update the current source 110 | src = np.dot(T, src) 111 | 112 | # check error 113 | mean_error = np.mean(distances) 114 | if np.abs(prev_error - mean_error) < tolerance: 115 | break 116 | prev_error = mean_error 117 | 118 | # calculate final transformation 119 | T,_,_ = best_fit_transform(A, src[:m,:].T) 120 | 121 | return T, distances, i 122 | -------------------------------------------------------------------------------- /im2mesh/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from plyfile import PlyElement, PlyData 3 | import numpy as np 4 | from trimesh.util import array_to_string 5 | import trimesh 6 | 7 | def export_pointcloud(vertices, out_file, as_text=True): 8 | assert(vertices.shape[1] == 3) 9 | vertices = vertices.astype(np.float32) 10 | vertices = np.ascontiguousarray(vertices) 11 | vector_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 12 | vertices = vertices.view(dtype=vector_dtype).flatten() 13 | plyel = PlyElement.describe(vertices, 'vertex') 14 | plydata = PlyData([plyel], text=as_text) 15 | plydata.write(out_file) 16 | 17 | 
18 | def load_pointcloud(in_file): 19 | plydata = PlyData.read(in_file) 20 | vertices = np.stack([ 21 | plydata['vertex']['x'], 22 | plydata['vertex']['y'], 23 | plydata['vertex']['z'] 24 | ], axis=1) 25 | return vertices 26 | 27 | 28 | def read_off(file): 29 | """ 30 | Reads vertices and faces from an off file. 31 | 32 | :param file: path to file to read 33 | :type file: str 34 | :return: vertices and faces as lists of tuples 35 | :rtype: [(float)], [(int)] 36 | """ 37 | 38 | assert os.path.exists(file), 'file %s not found' % file 39 | 40 | with open(file, 'r') as fp: 41 | lines = fp.readlines() 42 | lines = [line.strip() for line in lines] 43 | 44 | # Fix for ModelNet bug were 'OFF' and the number of vertices and faces 45 | # are all in the first line. 46 | if len(lines[0]) > 3: 47 | assert lines[0][:3] == 'OFF' or lines[0][:3] == 'off', \ 48 | 'invalid OFF file %s' % file 49 | 50 | parts = lines[0][3:].split(' ') 51 | assert len(parts) == 3 52 | 53 | num_vertices = int(parts[0]) 54 | assert num_vertices > 0 55 | 56 | num_faces = int(parts[1]) 57 | assert num_faces > 0 58 | 59 | start_index = 1 60 | # This is the regular case! 
61 | else: 62 | assert lines[0] == 'OFF' or lines[0] == 'off', \ 63 | 'invalid OFF file %s' % file 64 | 65 | parts = lines[1].split(' ') 66 | assert len(parts) == 3 67 | 68 | num_vertices = int(parts[0]) 69 | assert num_vertices > 0 70 | 71 | num_faces = int(parts[1]) 72 | assert num_faces > 0 73 | 74 | start_index = 2 75 | 76 | vertices = [] 77 | for i in range(num_vertices): 78 | vertex = lines[start_index + i].split(' ') 79 | vertex = [float(point.strip()) for point in vertex if point != ''] 80 | assert len(vertex) == 3 81 | 82 | vertices.append(vertex) 83 | 84 | faces = [] 85 | for i in range(num_faces): 86 | face = lines[start_index + num_vertices + i].split(' ') 87 | face = [index.strip() for index in face if index != ''] 88 | 89 | # check to be sure 90 | for index in face: 91 | assert index != '', \ 92 | 'found empty vertex index: %s (%s)' \ 93 | % (lines[start_index + num_vertices + i], file) 94 | 95 | face = [int(index) for index in face] 96 | 97 | assert face[0] == len(face) - 1, \ 98 | 'face should have %d vertices but as %d (%s)' \ 99 | % (face[0], len(face) - 1, file) 100 | assert face[0] == 3, \ 101 | 'only triangular meshes supported (%s)' % file 102 | for index in face: 103 | assert index >= 0 and index < num_vertices, \ 104 | 'vertex %d (of %d vertices) does not exist (%s)' \ 105 | % (index, num_vertices, file) 106 | 107 | assert len(face) > 1 108 | 109 | faces.append(face) 110 | 111 | return vertices, faces 112 | 113 | assert False, 'could not open %s' % file 114 | 115 | 116 | def save_mesh(mesh, out_file, digits=10, face_colors=None): 117 | digits = int(digits) 118 | # prepend a 3 (face count) to each face 119 | if face_colors is None: 120 | faces_stacked = np.column_stack(( 121 | np.ones(len(mesh.faces)) * 3, mesh.faces)).astype(np.int64) 122 | else: 123 | mesh.visual.face_colors = face_colors 124 | assert(mesh.visual.face_colors.shape[0] == mesh.faces.shape[0]) 125 | faces_stacked = np.column_stack(( 126 | np.ones(len(mesh.faces)) * 3, 
mesh.faces, 127 | mesh.visual.face_colors[:, :3])).astype(np.int64) 128 | export = 'OFF\n' 129 | # the header is vertex count, face count, edge number 130 | export += str(len(mesh.vertices)) + ' ' + str(len(mesh.faces)) + ' 0\n' 131 | export += array_to_string( 132 | mesh.vertices, col_delim=' ', row_delim='\n', digits=digits) + '\n' 133 | export += array_to_string(faces_stacked, col_delim=' ', row_delim='\n') 134 | 135 | with open(out_file, 'w') as f: 136 | f.write(export) 137 | 138 | return mesh 139 | 140 | 141 | def load_mesh(mesh_file): 142 | with open(mesh_file, 'r') as f: 143 | str_file = f.read().split('\n') 144 | n_vertices, n_faces, _ = list( 145 | map(lambda x: int(x), str_file[1].split(' '))) 146 | str_file = str_file[2:] # Remove first 2 lines 147 | 148 | v = [l.split(' ') for l in str_file[:n_vertices]] 149 | f = [l.split(' ') for l in str_file[n_vertices:]] 150 | 151 | v = np.array(v).astype(np.float32) 152 | f = np.array(f).astype(np.uint64)[:, 1:4] 153 | 154 | mesh = trimesh.Trimesh(vertices=v, faces=f, process=False) 155 | 156 | return mesh -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude pykdtree/render_template.py 2 | include LICENSE.txt 3 | -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/__init__.py: -------------------------------------------------------------------------------- 1 | from .pykdtree.kdtree import KDTree 2 | 3 | 4 | __all__ = [ 5 | KDTree 6 | ] 7 | -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libkdtree/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/pykdtree/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libkdtree/pykdtree/__init__.py -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/pykdtree/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libkdtree/pykdtree/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/pykdtree/kdtree.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libkdtree/pykdtree/kdtree.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/pykdtree/render_template.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from mako.template import Template 4 | 5 | mytemplate = Template(filename='_kdtree_core.c.mako') 6 | with open('_kdtree_core.c', 'w') as fp: 7 | fp.write(mytemplate.render()) 8 | -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/setup.cfg: 
-------------------------------------------------------------------------------- 1 | [bdist_rpm] 2 | requires=numpy 3 | release=1 4 | 5 | 6 | -------------------------------------------------------------------------------- /im2mesh/utils/libkdtree/setup.py: -------------------------------------------------------------------------------- 1 | #pykdtree, Fast kd-tree implementation with OpenMP-enabled queries 2 | # 3 | #Copyright (C) 2013 - present Esben S. Nielsen 4 | # 5 | # This program is free software: you can redistribute it and/or modify it under 6 | # the terms of the GNU Lesser General Public License as published by the Free 7 | # Software Foundation, either version 3 of the License, or 8 | #(at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, but WITHOUT 11 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 12 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 13 | # details. 14 | # 15 | # You should have received a copy of the GNU Lesser General Public License along 16 | # with this program. If not, see . 17 | 18 | import os 19 | from setuptools import setup, Extension 20 | from setuptools.command.build_ext import build_ext 21 | 22 | # Get OpenMP setting from environment 23 | try: 24 | use_omp = int(os.environ['USE_OMP']) 25 | except KeyError: 26 | use_omp = True 27 | 28 | 29 | def set_builtin(name, value): 30 | if isinstance(__builtins__, dict): 31 | __builtins__[name] = value 32 | else: 33 | setattr(__builtins__, name, value) 34 | 35 | 36 | # Custom builder to handler compiler flags. Edit if needed. 
37 | class build_ext_subclass(build_ext): 38 | def build_extensions(self): 39 | comp = self.compiler.compiler_type 40 | if comp in ('unix', 'cygwin', 'mingw32'): 41 | # Check if build is with OpenMP 42 | if use_omp: 43 | extra_compile_args = ['-std=c99', '-O3', '-fopenmp'] 44 | extra_link_args=['-lgomp'] 45 | else: 46 | extra_compile_args = ['-std=c99', '-O3'] 47 | extra_link_args = [] 48 | elif comp == 'msvc': 49 | extra_compile_args = ['/Ox'] 50 | extra_link_args = [] 51 | if use_omp: 52 | extra_compile_args.append('/openmp') 53 | else: 54 | # Add support for more compilers here 55 | raise ValueError('Compiler flags undefined for %s. Please modify setup.py and add compiler flags' 56 | % comp) 57 | self.extensions[0].extra_compile_args = extra_compile_args 58 | self.extensions[0].extra_link_args = extra_link_args 59 | build_ext.build_extensions(self) 60 | 61 | def finalize_options(self): 62 | ''' 63 | In order to avoid premature import of numpy before it gets installed as a dependency 64 | get numpy include directories during the extensions building process 65 | http://stackoverflow.com/questions/19919905/how-to-bootstrap-numpy-installation-in-setup-py 66 | ''' 67 | build_ext.finalize_options(self) 68 | # Prevent numpy from thinking it is still in its setup process: 69 | set_builtin('__NUMPY_SETUP__', False) 70 | import numpy 71 | self.include_dirs.append(numpy.get_include()) 72 | 73 | 74 | setup( 75 | name='pykdtree', 76 | version='1.3.1', 77 | description='Fast kd-tree implementation with OpenMP-enabled queries', 78 | author='Esben S. 
Nielsen', 79 | author_email='storpipfugl@gmail.com', 80 | packages = ['pykdtree'], 81 | python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*', 82 | install_requires=['numpy'], 83 | setup_requires=['numpy'], 84 | tests_require=['nose'], 85 | zip_safe=False, 86 | test_suite = 'nose.collector', 87 | ext_modules = [Extension('pykdtree.kdtree', 88 | ['pykdtree/kdtree.c', 'pykdtree/_kdtree_core.c'])], 89 | cmdclass = {'build_ext': build_ext_subclass }, 90 | classifiers=[ 91 | 'Development Status :: 5 - Production/Stable', 92 | ('License :: OSI Approved :: ' 93 | 'GNU Lesser General Public License v3 (LGPLv3)'), 94 | 'Programming Language :: Python', 95 | 'Operating System :: OS Independent', 96 | 'Intended Audience :: Science/Research', 97 | 'Topic :: Scientific/Engineering' 98 | ] 99 | ) 100 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/.gitignore: -------------------------------------------------------------------------------- 1 | PyMCubes.egg-info 2 | build 3 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2015, P. M. Neila 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 
13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | PyMCubes 3 | ======== 4 | 5 | PyMCubes is an implementation of the marching cubes algorithm to extract 6 | isosurfaces from volumetric data. The volumetric data can be given as a 7 | three-dimensional NumPy array or as a Python function ``f(x, y, z)``. The first 8 | option is much faster, but it requires more memory and becomes unfeasible for 9 | very large volumes. 10 | 11 | PyMCubes also provides a function to export the results of the marching cubes as 12 | COLLADA ``(.dae)`` files. This requires the 13 | `PyCollada <https://github.com/pycollada/pycollada>`_ library.
14 | 15 | Installation 16 | ============ 17 | 18 | Just as any standard Python package, clone or download the project 19 | and run:: 20 | 21 | $ cd path/to/PyMCubes 22 | $ python setup.py build 23 | $ python setup.py install 24 | 25 | If you do not have write permission on the directory of Python packages, 26 | install with the ``--user`` option:: 27 | 28 | $ python setup.py install --user 29 | 30 | Example 31 | ======= 32 | 33 | The following example creates a data volume with spherical isosurfaces and 34 | extracts one of them (i.e., a sphere) with PyMCubes. The result is exported as 35 | ``sphere.dae``:: 36 | 37 | >>> import numpy as np 38 | >>> import mcubes 39 | 40 | # Create a data volume (30 x 30 x 30) 41 | >>> X, Y, Z = np.mgrid[:30, :30, :30] 42 | >>> u = (X-15)**2 + (Y-15)**2 + (Z-15)**2 - 8**2 43 | 44 | # Extract the 0-isosurface 45 | >>> vertices, triangles = mcubes.marching_cubes(u, 0) 46 | 47 | # Export the result to sphere.dae 48 | >>> mcubes.export_mesh(vertices, triangles, "sphere.dae", "MySphere") 49 | 50 | The second example is very similar to the first one, but it uses a function 51 | to represent the volume instead of a NumPy array:: 52 | 53 | >>> import numpy as np 54 | >>> import mcubes 55 | 56 | # Create the volume 57 | >>> f = lambda x, y, z: x**2 + y**2 + z**2 58 | 59 | # Extract the 16-isosurface 60 | >>> vertices, triangles = mcubes.marching_cubes_func((-10,-10,-10), (10,10,10), 61 | ... 
100, 100, 100, f, 16) 62 | 63 | # Export the result to sphere2.dae 64 | >>> mcubes.export_mesh(vertices, triangles, "sphere2.dae", "MySphere") 65 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/__init__.py: -------------------------------------------------------------------------------- 1 | from im2mesh.utils.libmcubes.mcubes import ( 2 | marching_cubes, marching_cubes_func 3 | ) 4 | from im2mesh.utils.libmcubes.exporter import ( 5 | export_mesh, export_obj, export_off 6 | ) 7 | 8 | 9 | __all__ = [ 10 | marching_cubes, marching_cubes_func, 11 | export_mesh, export_obj, export_off 12 | ] 13 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmcubes/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/__pycache__/exporter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmcubes/__pycache__/exporter.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/exporter.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | def export_obj(vertices, triangles, filename): 6 | """ 7 | Exports a mesh in the (.obj) format. 
8 | """ 9 | 10 | with open(filename, 'w') as fh: 11 | 12 | for v in vertices: 13 | fh.write("v {} {} {}\n".format(*v)) 14 | 15 | for f in triangles: 16 | fh.write("f {} {} {}\n".format(*(f + 1))) 17 | 18 | 19 | def export_off(vertices, triangles, filename): 20 | """ 21 | Exports a mesh in the (.off) format. 22 | """ 23 | 24 | with open(filename, 'w') as fh: 25 | fh.write('OFF\n') 26 | fh.write('{} {} 0\n'.format(len(vertices), len(triangles))) 27 | 28 | for v in vertices: 29 | fh.write("{} {} {}\n".format(*v)) 30 | 31 | for f in triangles: 32 | fh.write("3 {} {} {}\n".format(*f)) 33 | 34 | 35 | def export_mesh(vertices, triangles, filename, mesh_name="mcubes_mesh"): 36 | """ 37 | Exports a mesh in the COLLADA (.dae) format. 38 | 39 | Needs PyCollada (https://github.com/pycollada/pycollada). 40 | """ 41 | 42 | import collada 43 | 44 | mesh = collada.Collada() 45 | 46 | vert_src = collada.source.FloatSource("verts-array", vertices, ('X','Y','Z')) 47 | geom = collada.geometry.Geometry(mesh, "geometry0", mesh_name, [vert_src]) 48 | 49 | input_list = collada.source.InputList() 50 | input_list.addInput(0, 'VERTEX', "#verts-array") 51 | 52 | triset = geom.createTriangleSet(np.copy(triangles), input_list, "") 53 | geom.primitives.append(triset) 54 | mesh.geometries.append(geom) 55 | 56 | geomnode = collada.scene.GeometryNode(geom, []) 57 | node = collada.scene.Node(mesh_name, children=[geomnode]) 58 | 59 | myscene = collada.scene.Scene("mcubes_scene", [node]) 60 | mesh.scenes.append(myscene) 61 | mesh.scene = myscene 62 | 63 | mesh.write(filename) 64 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/mcubes.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmcubes/mcubes.cpython-36m-x86_64-linux-gnu.so 
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/mcubes.pyx:
--------------------------------------------------------------------------------

# distutils: language = c++
# cython: embedsignature = True

# from libcpp.vector cimport vector
import numpy as np

# Define PY_ARRAY_UNIQUE_SYMBOL
cdef extern from "pyarray_symbol.h":
    pass

cimport numpy as np

np.import_array()

cdef extern from "pywrapper.h":
    cdef object c_marching_cubes "marching_cubes"(np.ndarray, double) except +
    cdef object c_marching_cubes2 "marching_cubes2"(np.ndarray, double) except +
    cdef object c_marching_cubes3 "marching_cubes3"(np.ndarray, double) except +
    cdef object c_marching_cubes_func "marching_cubes_func"(tuple, tuple, int, int, int, object, double) except +


cdef tuple _as_triangle_mesh(verts, faces):
    # Give the arrays returned by the C++ wrappers an (n, 3) shape in place:
    # one row per vertex position and one row per triangle.
    verts.shape = (-1, 3)
    faces.shape = (-1, 3)
    return verts, faces


# The isovalue parameters are C doubles (not floats) so the value reaches the
# C++ wrappers, which take double, without a round-trip through single
# precision; this also matches marching_cubes_func.

def marching_cubes(np.ndarray volume, double isovalue):
    """Extract the ``isovalue`` isosurface of ``volume``; returns
    ``(verts, faces)`` as (n, 3) arrays."""
    verts, faces = c_marching_cubes(volume, isovalue)
    return _as_triangle_mesh(verts, faces)

def marching_cubes2(np.ndarray volume, double isovalue):
    """Like :func:`marching_cubes`, using the C++ ``marching_cubes2``."""
    verts, faces = c_marching_cubes2(volume, isovalue)
    return _as_triangle_mesh(verts, faces)

def marching_cubes3(np.ndarray volume, double isovalue):
    """Like :func:`marching_cubes`, using the C++ ``marching_cubes3``."""
    verts, faces = c_marching_cubes3(volume, isovalue)
    return _as_triangle_mesh(verts, faces)

def marching_cubes_func(tuple lower, tuple upper, int numx, int numy, int numz, object f, double isovalue):
    """Extract the ``isovalue`` isosurface of the function ``f(x, y, z)``
    sampled on a ``numx x numy x numz`` grid spanning ``lower``..``upper``."""
    verts, faces = c_marching_cubes_func(lower, upper, numx, numy, numz, f, isovalue)
    return _as_triangle_mesh(verts, faces)

--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/pyarray_symbol.h:
-------------------------------------------------------------------------------- 1 | 2 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API 3 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/pyarraymodule.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _EXTMODULE_H 3 | #define _EXTMODULE_H 4 | 5 | #include 6 | #include 7 | 8 | // #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 9 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API 10 | #define NO_IMPORT_ARRAY 11 | #include "numpy/arrayobject.h" 12 | 13 | #include 14 | 15 | template 16 | struct numpy_typemap; 17 | 18 | #define define_numpy_type(ctype, dtype) \ 19 | template<> \ 20 | struct numpy_typemap \ 21 | {static const int type = dtype;}; 22 | 23 | define_numpy_type(bool, NPY_BOOL); 24 | define_numpy_type(char, NPY_BYTE); 25 | define_numpy_type(short, NPY_SHORT); 26 | define_numpy_type(int, NPY_INT); 27 | define_numpy_type(long, NPY_LONG); 28 | define_numpy_type(long long, NPY_LONGLONG); 29 | define_numpy_type(unsigned char, NPY_UBYTE); 30 | define_numpy_type(unsigned short, NPY_USHORT); 31 | define_numpy_type(unsigned int, NPY_UINT); 32 | define_numpy_type(unsigned long, NPY_ULONG); 33 | define_numpy_type(unsigned long long, NPY_ULONGLONG); 34 | define_numpy_type(float, NPY_FLOAT); 35 | define_numpy_type(double, NPY_DOUBLE); 36 | define_numpy_type(long double, NPY_LONGDOUBLE); 37 | define_numpy_type(std::complex, NPY_CFLOAT); 38 | define_numpy_type(std::complex, NPY_CDOUBLE); 39 | define_numpy_type(std::complex, NPY_CLONGDOUBLE); 40 | 41 | template 42 | T PyArray_SafeGet(const PyArrayObject* aobj, const npy_intp* indaux) 43 | { 44 | // HORROR. 
45 | npy_intp* ind = const_cast(indaux); 46 | void* ptr = PyArray_GetPtr(const_cast(aobj), ind); 47 | switch(PyArray_TYPE(aobj)) 48 | { 49 | case NPY_BOOL: 50 | return static_cast(*reinterpret_cast(ptr)); 51 | case NPY_BYTE: 52 | return static_cast(*reinterpret_cast(ptr)); 53 | case NPY_SHORT: 54 | return static_cast(*reinterpret_cast(ptr)); 55 | case NPY_INT: 56 | return static_cast(*reinterpret_cast(ptr)); 57 | case NPY_LONG: 58 | return static_cast(*reinterpret_cast(ptr)); 59 | case NPY_LONGLONG: 60 | return static_cast(*reinterpret_cast(ptr)); 61 | case NPY_UBYTE: 62 | return static_cast(*reinterpret_cast(ptr)); 63 | case NPY_USHORT: 64 | return static_cast(*reinterpret_cast(ptr)); 65 | case NPY_UINT: 66 | return static_cast(*reinterpret_cast(ptr)); 67 | case NPY_ULONG: 68 | return static_cast(*reinterpret_cast(ptr)); 69 | case NPY_ULONGLONG: 70 | return static_cast(*reinterpret_cast(ptr)); 71 | case NPY_FLOAT: 72 | return static_cast(*reinterpret_cast(ptr)); 73 | case NPY_DOUBLE: 74 | return static_cast(*reinterpret_cast(ptr)); 75 | case NPY_LONGDOUBLE: 76 | return static_cast(*reinterpret_cast(ptr)); 77 | default: 78 | throw std::runtime_error("data type not supported"); 79 | } 80 | } 81 | 82 | template 83 | T PyArray_SafeSet(PyArrayObject* aobj, const npy_intp* indaux, const T& value) 84 | { 85 | // HORROR. 
86 | npy_intp* ind = const_cast(indaux); 87 | void* ptr = PyArray_GetPtr(aobj, ind); 88 | switch(PyArray_TYPE(aobj)) 89 | { 90 | case NPY_BOOL: 91 | *reinterpret_cast(ptr) = static_cast(value); 92 | break; 93 | case NPY_BYTE: 94 | *reinterpret_cast(ptr) = static_cast(value); 95 | break; 96 | case NPY_SHORT: 97 | *reinterpret_cast(ptr) = static_cast(value); 98 | break; 99 | case NPY_INT: 100 | *reinterpret_cast(ptr) = static_cast(value); 101 | break; 102 | case NPY_LONG: 103 | *reinterpret_cast(ptr) = static_cast(value); 104 | break; 105 | case NPY_LONGLONG: 106 | *reinterpret_cast(ptr) = static_cast(value); 107 | break; 108 | case NPY_UBYTE: 109 | *reinterpret_cast(ptr) = static_cast(value); 110 | break; 111 | case NPY_USHORT: 112 | *reinterpret_cast(ptr) = static_cast(value); 113 | break; 114 | case NPY_UINT: 115 | *reinterpret_cast(ptr) = static_cast(value); 116 | break; 117 | case NPY_ULONG: 118 | *reinterpret_cast(ptr) = static_cast(value); 119 | break; 120 | case NPY_ULONGLONG: 121 | *reinterpret_cast(ptr) = static_cast(value); 122 | break; 123 | case NPY_FLOAT: 124 | *reinterpret_cast(ptr) = static_cast(value); 125 | break; 126 | case NPY_DOUBLE: 127 | *reinterpret_cast(ptr) = static_cast(value); 128 | break; 129 | case NPY_LONGDOUBLE: 130 | *reinterpret_cast(ptr) = static_cast(value); 131 | break; 132 | default: 133 | throw std::runtime_error("data type not supported"); 134 | } 135 | } 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/pywrapper.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _PYWRAPPER_H 3 | #define _PYWRAPPER_H 4 | 5 | #include 6 | #include "pyarraymodule.h" 7 | 8 | #include 9 | 10 | PyObject* marching_cubes(PyArrayObject* arr, double isovalue); 11 | PyObject* marching_cubes2(PyArrayObject* arr, double isovalue); 12 | PyObject* marching_cubes3(PyArrayObject* arr, double isovalue); 13 | PyObject* 
marching_cubes_func(PyObject* lower, PyObject* upper, 14 | int numx, int numy, int numz, PyObject* f, double isovalue); 15 | 16 | #endif // _PYWRAPPER_H 17 | -------------------------------------------------------------------------------- /im2mesh/utils/libmcubes/setup.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | 3 | try: 4 | from setuptools import setup 5 | except ImportError: 6 | from distutils.core import setup 7 | 8 | from Cython.Build import cythonize 9 | 10 | import numpy 11 | from distutils.extension import Extension 12 | 13 | # Get the version number. 14 | numpy_include_dir = numpy.get_include() 15 | 16 | mcubes_module = Extension( 17 | "mcubes", 18 | [ 19 | "mcubes.pyx", 20 | "pywrapper.cpp", 21 | "marchingcubes.cpp" 22 | ], 23 | language="c++", 24 | extra_compile_args=['-std=c++11'], 25 | include_dirs=[numpy_include_dir] 26 | ) 27 | 28 | setup(name="PyMCubes", 29 | version="0.0.6", 30 | description="Marching cubes for Python", 31 | author="Pablo Márquez Neila", 32 | author_email="pablo.marquezneila@epfl.ch", 33 | url="https://github.com/pmneila/PyMCubes", 34 | license="BSD 3-clause", 35 | long_description=""" 36 | Marching cubes for Python 37 | """, 38 | classifiers=[ 39 | "Development Status :: 4 - Beta", 40 | "Environment :: Console", 41 | "Intended Audience :: Developers", 42 | "Intended Audience :: Science/Research", 43 | "License :: OSI Approved :: BSD License", 44 | "Natural Language :: English", 45 | "Operating System :: OS Independent", 46 | "Programming Language :: C++", 47 | "Programming Language :: Python", 48 | "Topic :: Multimedia :: Graphics :: 3D Modeling", 49 | "Topic :: Scientific/Engineering :: Image Recognition", 50 | ], 51 | packages=["mcubes"], 52 | ext_modules=cythonize(mcubes_module), 53 | requires=['numpy', 'Cython', 'PyCollada'], 54 | setup_requires=['numpy', 'Cython'] 55 | ) 56 | 
-------------------------------------------------------------------------------- /im2mesh/utils/libmesh/.gitignore: -------------------------------------------------------------------------------- 1 | triangle_hash.cpp 2 | build 3 | -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/__init__.py: -------------------------------------------------------------------------------- 1 | from .inside_mesh import ( 2 | check_mesh_contains, MeshIntersector, TriangleIntersector2d 3 | ) 4 | 5 | 6 | __all__ = [ 7 | check_mesh_contains, MeshIntersector, TriangleIntersector2d 8 | ] 9 | -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/__init__.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/__pycache__/inside_mesh.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/__pycache__/inside_mesh.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/inside_mesh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .triangle_hash import 
TriangleHash as _TriangleHash 3 | 4 | 5 | def check_mesh_contains(mesh, points, hash_resolution=512): 6 | intersector = MeshIntersector(mesh, hash_resolution) 7 | contains = intersector.query(points) 8 | return contains 9 | 10 | 11 | class MeshIntersector: 12 | def __init__(self, mesh, resolution=512): 13 | triangles = mesh.vertices[mesh.faces].astype(np.float64) 14 | n_tri = triangles.shape[0] 15 | 16 | self.resolution = resolution 17 | self.bbox_min = triangles.reshape(3 * n_tri, 3).min(axis=0) 18 | self.bbox_max = triangles.reshape(3 * n_tri, 3).max(axis=0) 19 | # Tranlate and scale it to [0.5, self.resolution - 0.5]^3 20 | self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min) 21 | self.translate = 0.5 - self.scale * self.bbox_min 22 | 23 | self._triangles = triangles = self.rescale(triangles) 24 | # assert(np.allclose(triangles.reshape(-1, 3).min(0), 0.5)) 25 | # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5)) 26 | 27 | triangles2d = triangles[:, :, :2] 28 | self._tri_intersector2d = TriangleIntersector2d( 29 | triangles2d, resolution) 30 | 31 | def query(self, points): 32 | # Rescale points 33 | points = self.rescale(points) 34 | 35 | # placeholder result with no hits we'll fill in later 36 | contains = np.zeros(len(points), dtype=np.bool) 37 | 38 | # cull points outside of the axis aligned bounding box 39 | # this avoids running ray tests unless points are close 40 | inside_aabb = np.all( 41 | (0 <= points) & (points <= self.resolution), axis=1) 42 | if not inside_aabb.any(): 43 | return contains 44 | 45 | # Only consider points inside bounding box 46 | mask = inside_aabb 47 | points = points[mask] 48 | 49 | # Compute intersection depth and check order 50 | points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2]) 51 | 52 | triangles_intersect = self._triangles[tri_indices] 53 | points_intersect = points[points_indices] 54 | 55 | depth_intersect, abs_n_2 = self.compute_intersection_depth( 56 | 
points_intersect, triangles_intersect) 57 | 58 | # Count number of intersections in both directions 59 | smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2 60 | bigger_depth = depth_intersect < points_intersect[:, 2] * abs_n_2 61 | points_indices_0 = points_indices[smaller_depth] 62 | points_indices_1 = points_indices[bigger_depth] 63 | 64 | nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0]) 65 | nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0]) 66 | 67 | # Check if point contained in mesh 68 | contains1 = (np.mod(nintersect0, 2) == 1) 69 | contains2 = (np.mod(nintersect1, 2) == 1) 70 | if (contains1 != contains2).any(): 71 | print('Warning: contains1 != contains2 for some points.') 72 | contains[mask] = (contains1 & contains2) 73 | return contains 74 | 75 | def compute_intersection_depth(self, points, triangles): 76 | t1 = triangles[:, 0, :] 77 | t2 = triangles[:, 1, :] 78 | t3 = triangles[:, 2, :] 79 | 80 | v1 = t3 - t1 81 | v2 = t2 - t1 82 | # v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True) 83 | # v2 = v2 / np.linalg.norm(v2, axis=-1, keepdims=True) 84 | 85 | normals = np.cross(v1, v2) 86 | alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1) 87 | 88 | n_2 = normals[:, 2] 89 | t1_2 = t1[:, 2] 90 | s_n_2 = np.sign(n_2) 91 | abs_n_2 = np.abs(n_2) 92 | 93 | mask = (abs_n_2 != 0) 94 | 95 | depth_intersect = np.full(points.shape[0], np.nan) 96 | depth_intersect[mask] = \ 97 | t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask] 98 | 99 | # Test the depth: 100 | # TODO: remove and put into tests 101 | # points_new = np.concatenate([points[:, :2], depth_intersect[:, None]], axis=1) 102 | # alpha = (normals * t1).sum(-1) 103 | # mask = (depth_intersect == depth_intersect) 104 | # assert(np.allclose((points_new[mask] * normals[mask]).sum(-1), 105 | # alpha[mask])) 106 | return depth_intersect, abs_n_2 107 | 108 | def rescale(self, array): 109 | array = self.scale * array + self.translate 110 
| return array 111 | 112 | 113 | class TriangleIntersector2d: 114 | def __init__(self, triangles, resolution=128): 115 | self.triangles = triangles 116 | self.tri_hash = _TriangleHash(triangles, resolution) 117 | 118 | def query(self, points): 119 | point_indices, tri_indices = self.tri_hash.query(points) 120 | point_indices = np.array(point_indices, dtype=np.int64) 121 | tri_indices = np.array(tri_indices, dtype=np.int64) 122 | points = points[point_indices] 123 | triangles = self.triangles[tri_indices] 124 | mask = self.check_triangles(points, triangles) 125 | point_indices = point_indices[mask] 126 | tri_indices = tri_indices[mask] 127 | return point_indices, tri_indices 128 | 129 | def check_triangles(self, points, triangles): 130 | contains = np.zeros(points.shape[0], dtype=np.bool) 131 | A = triangles[:, :2] - triangles[:, 2:] 132 | A = A.transpose([0, 2, 1]) 133 | y = points - triangles[:, 2] 134 | 135 | detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0] 136 | 137 | mask = (np.abs(detA) != 0.) 
138 | A = A[mask] 139 | y = y[mask] 140 | detA = detA[mask] 141 | 142 | s_detA = np.sign(detA) 143 | abs_detA = np.abs(detA) 144 | 145 | u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA 146 | v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA 147 | 148 | sum_uv = u + v 149 | contains[mask] = ( 150 | (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) 151 | & (0 < sum_uv) & (sum_uv < abs_detA) 152 | ) 153 | return contains 154 | 155 | -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/inside_mesh.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/inside_mesh.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | 5 | 6 | ext_modules = [ 7 | Extension("triangle_hash", 8 | sources=["triangle_hash.pyx"], 9 | libraries=["m"] # Unix-like specific 10 | ) 11 | ] 12 | 13 | setup( 14 | ext_modules=cythonize(ext_modules) 15 | ) 16 | -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/triangle_hash.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/triangle_hash.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /im2mesh/utils/libmesh/triangle_hash.pyx: -------------------------------------------------------------------------------- 1 | 2 | # distutils: language=c++ 3 | import numpy as np 4 | 
cimport numpy as np 5 | cimport cython 6 | from libcpp.vector cimport vector 7 | from libc.math cimport floor, ceil 8 | 9 | cdef class TriangleHash: 10 | cdef vector[vector[int]] spatial_hash 11 | cdef int resolution 12 | 13 | def __cinit__(self, double[:, :, :] triangles, int resolution): 14 | self.spatial_hash.resize(resolution * resolution) 15 | self.resolution = resolution 16 | self._build_hash(triangles) 17 | 18 | @cython.boundscheck(False) # Deactivate bounds checking 19 | @cython.wraparound(False) # Deactivate negative indexing. 20 | cdef int _build_hash(self, double[:, :, :] triangles): 21 | assert(triangles.shape[1] == 3) 22 | assert(triangles.shape[2] == 2) 23 | 24 | cdef int n_tri = triangles.shape[0] 25 | cdef int bbox_min[2] 26 | cdef int bbox_max[2] 27 | 28 | cdef int i_tri, j, x, y 29 | cdef int spatial_idx 30 | 31 | for i_tri in range(n_tri): 32 | # Compute bounding box 33 | for j in range(2): 34 | bbox_min[j] = min( 35 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] 36 | ) 37 | bbox_max[j] = max( 38 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] 39 | ) 40 | bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1) 41 | bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1) 42 | 43 | # Find all voxels where bounding box intersects 44 | for x in range(bbox_min[0], bbox_max[0] + 1): 45 | for y in range(bbox_min[1], bbox_max[1] + 1): 46 | spatial_idx = self.resolution * x + y 47 | self.spatial_hash[spatial_idx].push_back(i_tri) 48 | 49 | @cython.boundscheck(False) # Deactivate bounds checking 50 | @cython.wraparound(False) # Deactivate negative indexing. 
51 | cpdef query(self, double[:, :] points): 52 | assert(points.shape[1] == 2) 53 | cdef int n_points = points.shape[0] 54 | 55 | cdef vector[int] points_indices 56 | cdef vector[int] tri_indices 57 | # cdef int[:] points_indices_np 58 | # cdef int[:] tri_indices_np 59 | 60 | cdef int i_point, k, x, y 61 | cdef int spatial_idx 62 | 63 | for i_point in range(n_points): 64 | x = int(points[i_point, 0]) 65 | y = int(points[i_point, 1]) 66 | if not (0 <= x < self.resolution and 0 <= y < self.resolution): 67 | continue 68 | 69 | spatial_idx = self.resolution * x + y 70 | for i_tri in self.spatial_hash[spatial_idx]: 71 | points_indices.push_back(i_point) 72 | tri_indices.push_back(i_tri) 73 | 74 | points_indices_np = np.zeros(points_indices.size(), dtype=np.int32) 75 | tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32) 76 | 77 | cdef int[:] points_indices_view = points_indices_np 78 | cdef int[:] tri_indices_view = tri_indices_np 79 | 80 | for k in range(points_indices.size()): 81 | points_indices_view[k] = points_indices[k] 82 | 83 | for k in range(tri_indices.size()): 84 | tri_indices_view[k] = tri_indices[k] 85 | 86 | return points_indices_np, tri_indices_np 87 | -------------------------------------------------------------------------------- /im2mesh/utils/libmise/.gitignore: -------------------------------------------------------------------------------- 1 | mise.c 2 | mise.cpp 3 | mise.html 4 | -------------------------------------------------------------------------------- /im2mesh/utils/libmise/__init__.py: -------------------------------------------------------------------------------- 1 | from .mise import MISE 2 | 3 | 4 | __all__ = [ 5 | MISE 6 | ] 7 | -------------------------------------------------------------------------------- /im2mesh/utils/libmise/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmise/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libmise/mise.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmise/mise.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /im2mesh/utils/libmise/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | 5 | 6 | ext_modules = [ 7 | Extension("mise", 8 | sources=["mise.pyx"], 9 | ) 10 | ] 11 | 12 | setup( 13 | ext_modules=cythonize(ext_modules) 14 | ) 15 | -------------------------------------------------------------------------------- /im2mesh/utils/libmise/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mise import MISE 3 | import time 4 | 5 | t0 = time.time() 6 | extractor = MISE(1, 2, 0.) 
7 | 8 | p = extractor.query() 9 | i = 0 10 | 11 | while p.shape[0] != 0: 12 | print(i) 13 | print(p) 14 | v = 2 * (p.sum(axis=-1) > 2).astype(np.float64) - 1 15 | extractor.update(p, v) 16 | p = extractor.query() 17 | i += 1 18 | if (i >= 8): 19 | break 20 | 21 | print(extractor.to_dense()) 22 | # p, v = extractor.get_points() 23 | # print(p) 24 | # print(v) 25 | print('Total time: %f' % (time.time() - t0)) 26 | -------------------------------------------------------------------------------- /im2mesh/utils/libsimplify/__init__.py: -------------------------------------------------------------------------------- 1 | from .simplify_mesh import ( 2 | mesh_simplify 3 | ) 4 | import trimesh 5 | 6 | 7 | def simplify_mesh(mesh, f_target=10000, agressiveness=7.): 8 | vertices = mesh.vertices 9 | faces = mesh.faces 10 | 11 | vertices, faces = mesh_simplify(vertices, faces, f_target, agressiveness) 12 | 13 | mesh_simplified = trimesh.Trimesh(vertices, faces, process=False) 14 | 15 | return mesh_simplified 16 | -------------------------------------------------------------------------------- /im2mesh/utils/libsimplify/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libsimplify/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /im2mesh/utils/libsimplify/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | 5 | 6 | ext_modules = [ 7 | Extension("simplify_mesh", 8 | sources=["simplify_mesh.pyx"] 9 | ) 10 | ] 11 | 12 | setup( 13 | ext_modules=cythonize(ext_modules) 14 | ) 15 | -------------------------------------------------------------------------------- 
/im2mesh/utils/libsimplify/simplify_mesh.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libsimplify/simplify_mesh.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /im2mesh/utils/libsimplify/simplify_mesh.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | from libcpp.vector cimport vector 3 | import numpy as np 4 | cimport numpy as np 5 | 6 | 7 | cdef extern from "Simplify.h": 8 | cdef struct vec3f: 9 | double x, y, z 10 | 11 | cdef cppclass SymetricMatrix: 12 | SymetricMatrix() except + 13 | 14 | 15 | cdef extern from "Simplify.h" namespace "Simplify": 16 | cdef struct Triangle: 17 | int v[3] 18 | double err[4] 19 | int deleted, dirty, attr 20 | vec3f uvs[3] 21 | int material 22 | 23 | cdef struct Vertex: 24 | vec3f p 25 | int tstart, tcount 26 | SymetricMatrix q 27 | int border 28 | 29 | cdef vector[Triangle] triangles 30 | cdef vector[Vertex] vertices 31 | cdef void simplify_mesh(int, double) 32 | 33 | 34 | cpdef mesh_simplify(double[:, ::1] vertices_in, long[:, ::1] triangles_in, 35 | int f_target, double agressiveness=7.) 
except +: 36 | vertices.clear() 37 | triangles.clear() 38 | 39 | # Read in vertices and triangles 40 | cdef Vertex v 41 | for iv in range(vertices_in.shape[0]): 42 | v = Vertex() 43 | v.p.x = vertices_in[iv, 0] 44 | v.p.y = vertices_in[iv, 1] 45 | v.p.z = vertices_in[iv, 2] 46 | vertices.push_back(v) 47 | 48 | cdef Triangle t 49 | for it in range(triangles_in.shape[0]): 50 | t = Triangle() 51 | t.v[0] = triangles_in[it, 0] 52 | t.v[1] = triangles_in[it, 1] 53 | t.v[2] = triangles_in[it, 2] 54 | triangles.push_back(t) 55 | 56 | # Simplify 57 | # print('Simplify...') 58 | simplify_mesh(f_target, agressiveness) 59 | 60 | # Only use triangles that are not deleted 61 | cdef vector[Triangle] triangles_notdel 62 | triangles_notdel.reserve(triangles.size()) 63 | 64 | for t in triangles: 65 | if not t.deleted: 66 | triangles_notdel.push_back(t) 67 | 68 | # Read out triangles 69 | vertices_out = np.empty((vertices.size(), 3), dtype=np.float64) 70 | triangles_out = np.empty((triangles_notdel.size(), 3), dtype=np.int64) 71 | 72 | cdef double[:, :] vertices_out_view = vertices_out 73 | cdef long[:, :] triangles_out_view = triangles_out 74 | 75 | for iv in range(vertices.size()): 76 | vertices_out_view[iv, 0] = vertices[iv].p.x 77 | vertices_out_view[iv, 1] = vertices[iv].p.y 78 | vertices_out_view[iv, 2] = vertices[iv].p.z 79 | 80 | for it in range(triangles_notdel.size()): 81 | triangles_out_view[it, 0] = triangles_notdel[it].v[0] 82 | triangles_out_view[it, 1] = triangles_notdel[it].v[1] 83 | triangles_out_view[it, 2] = triangles_notdel[it].v[2] 84 | 85 | # Clear vertices and triangles 86 | vertices.clear() 87 | triangles.clear() 88 | 89 | return vertices_out, triangles_out -------------------------------------------------------------------------------- /im2mesh/utils/libsimplify/test.py: -------------------------------------------------------------------------------- 1 | from simplify_mesh import mesh_simplify 2 | import numpy as np 3 | 4 | v = np.random.rand(100, 3) 5 | f = 
np.random.choice(range(100), (50, 3)) 6 | 7 | mesh_simplify(v, f, 50) -------------------------------------------------------------------------------- /im2mesh/utils/libvoxelize/.gitignore: -------------------------------------------------------------------------------- 1 | voxelize.c 2 | voxelize.html 3 | build 4 | -------------------------------------------------------------------------------- /im2mesh/utils/libvoxelize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libvoxelize/__init__.py -------------------------------------------------------------------------------- /im2mesh/utils/libvoxelize/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | 5 | ext_modules = [ 6 | Extension("voxelize", 7 | sources=["voxelize.pyx"], 8 | libraries=["m"] # Unix-like specific 9 | ) 10 | ] 11 | 12 | setup( 13 | ext_modules=cythonize(ext_modules) 14 | ) 15 | -------------------------------------------------------------------------------- /im2mesh/utils/libvoxelize/voxelize.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libvoxelize/voxelize.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /im2mesh/utils/libvoxelize/voxelize.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from libc.math cimport floor, ceil 3 | from cython.view cimport array as cvarray 4 | 5 | cdef extern from "tribox2.h": 6 | int triBoxOverlap(float boxcenter[3], float boxhalfsize[3], 7 | 
float tri0[3], float tri1[3], float tri2[3]) 8 | 9 | 10 | @cython.boundscheck(False) # Deactivate bounds checking 11 | @cython.wraparound(False) # Deactivate negative indexing. 12 | cpdef int voxelize_mesh_(bint[:, :, :] occ, float[:, :, ::1] faces): 13 | assert(faces.shape[1] == 3) 14 | assert(faces.shape[2] == 3) 15 | 16 | n_faces = faces.shape[0] 17 | cdef int i 18 | for i in range(n_faces): 19 | voxelize_triangle_(occ, faces[i]) 20 | 21 | 22 | @cython.boundscheck(False) # Deactivate bounds checking 23 | @cython.wraparound(False) # Deactivate negative indexing. 24 | cpdef int voxelize_triangle_(bint[:, :, :] occupancies, float[:, ::1] triverts): 25 | cdef int bbox_min[3] 26 | cdef int bbox_max[3] 27 | cdef int i, j, k 28 | cdef float boxhalfsize[3] 29 | cdef float boxcenter[3] 30 | cdef bint intersection 31 | 32 | boxhalfsize[:] = (0.5, 0.5, 0.5) 33 | 34 | for i in range(3): 35 | bbox_min[i] = ( 36 | min(triverts[0, i], triverts[1, i], triverts[2, i]) 37 | ) 38 | bbox_min[i] = min(max(bbox_min[i], 0), occupancies.shape[i] - 1) 39 | 40 | for i in range(3): 41 | bbox_max[i] = ( 42 | max(triverts[0, i], triverts[1, i], triverts[2, i]) 43 | ) 44 | bbox_max[i] = min(max(bbox_max[i], 0), occupancies.shape[i] - 1) 45 | 46 | for i in range(bbox_min[0], bbox_max[0] + 1): 47 | for j in range(bbox_min[1], bbox_max[1] + 1): 48 | for k in range(bbox_min[2], bbox_max[2] + 1): 49 | boxcenter[:] = (i + 0.5, j + 0.5, k + 0.5) 50 | intersection = triBoxOverlap(&boxcenter[0], &boxhalfsize[0], 51 | &triverts[0, 0], &triverts[1, 0], &triverts[2, 0]) 52 | occupancies[i, j, k] |= intersection 53 | 54 | 55 | @cython.boundscheck(False) # Deactivate bounds checking 56 | @cython.wraparound(False) # Deactivate negative indexing. 
57 | cdef int test_triangle_aabb(float[::1] boxcenter, float[::1] boxhalfsize, float[:, ::1] triverts): 58 | assert(boxcenter.shape[0] == 3) 59 | assert(boxhalfsize.shape[0] == 3) 60 | assert(triverts.shape[0] == triverts.shape[1] == 3) 61 | 62 | # print(triverts) 63 | # Call functions 64 | cdef int result = triBoxOverlap(&boxcenter[0], &boxhalfsize[0], 65 | &triverts[0, 0], &triverts[1, 0], &triverts[2, 0]) 66 | return result 67 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | .installed.cfg 3 | *.egg 4 | *__pycache__* 5 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ricky Tian Qi Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Differentiable ODE Solvers 2 | 3 | This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through all solvers is supported using the adjoint method. For usage of ODE solvers in deep learning applications, see [1]. 4 | 5 | As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU. 6 | 7 | --- 8 | 9 |

10 | Discrete-depth network 11 | Continuous-depth network 12 |

13 | 14 | ## Installation 15 | ``` 16 | git clone https://github.com/rtqichen/torchdiffeq.git 17 | cd torchdiffeq 18 | pip install -e . 19 | ``` 20 | 21 | ## Examples 22 | Examples are placed in the [`examples`](./examples) directory. 23 | 24 | We encourage those who are interested in using this library to take a look at [`examples/ode_demo.py`](./examples/ode_demo.py) for understanding how to use `torchdiffeq` to fit a simple spiral ODE. 25 | 26 |

27 | ODE Demo 28 |

29 | 30 | ## Basic usage 31 | This library provides one main interface `odeint` which contains general-purpose algorithms for solving initial value problems (IVP), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value, 32 | ``` 33 | dy/dt = f(t, y) y(t_0) = y_0. 34 | ``` 35 | The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition. 36 | 37 | To solve an IVP using the default solver: 38 | ``` 39 | from torchdiffeq import odeint 40 | 41 | odeint(func, y0, t) 42 | ``` 43 | where `func` is any callable implementing the ordinary differential equation `f(t, y)`, `y0` is an _any_-D Tensor or a tuple of _any_-D Tensors representing the initial values, and `t` is a 1-D Tensor containing the evaluation points. The initial time is taken to be `t[0]`. 44 | 45 | Backpropagation through `odeint` goes through the internals of the solver, but this is not supported for all solvers. Instead, we encourage the use of the adjoint method explained in [1], which will allow solving with as many steps as necessary due to O(1) memory usage. 46 | 47 | To use the adjoint method: 48 | ``` 49 | from torchdiffeq import odeint_adjoint as odeint 50 | 51 | odeint(func, y0, t) 52 | ``` 53 | `odeint_adjoint` simply wraps around `odeint`, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call. 54 | 55 | The biggest **gotcha** is that `func` must be a `nn.Module` when using the adjoint method. This is used to collect parameters of the differential equation. 56 | 57 | ### Keyword Arguments 58 | - `rtol` Relative tolerance. 59 | - `atol` Absolute tolerance. 60 | - `method` One of the solvers listed below. 61 | 62 | #### List of ODE Solvers: 63 | 64 | Adaptive-step: 65 | - `dopri5` Runge-Kutta 4(5) [default]. 66 | - `adams` Adaptive-order implicit Adams. 67 | 68 | Fixed-step: 69 | - `euler` Euler method. 70 | - `midpoint` Midpoint method.
71 | - `rk4` Fourth-order Runge-Kutta with 3/8 rule. 72 | - `explicit_adams` Explicit Adams. 73 | - `fixed_adams` Implicit Adams. 74 | 75 | ### References 76 | [1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018. [[arxiv]](https://arxiv.org/abs/1806.07366) 77 | 78 | --- 79 | 80 | If you found this library useful in your research, please consider citing 81 | ``` 82 | @article{chen2018neural, 83 | title={Neural Ordinary Differential Equations}, 84 | author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David}, 85 | journal={Advances in Neural Information Processing Systems}, 86 | year={2018} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/assets/ode_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/torchdiffeq/assets/ode_demo.gif -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/assets/odenet_0_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/torchdiffeq/assets/odenet_0_viz.png -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/assets/resnet_0_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/torchdiffeq/assets/resnet_0_viz.png -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/examples/README.md:
-------------------------------------------------------------------------------- 1 | # Overview of Examples 2 | 3 | This `examples` directory contains cleaned-up code regarding the usage of adaptive ODE solvers in machine learning. The scripts in this directory assume that `torchdiffeq` is installed following the instructions from the main directory. 4 | 5 | ## Demo 6 | The `ode_demo.py` file contains a short implementation of learning a dynamics model to mimic a spiral ODE. 7 | 8 | To visualize the training progress, run 9 | ``` 10 | python ode_demo.py --viz 11 | ``` 12 | The training should look similar to this: 13 | 14 |

15 | ODE Demo 16 |

17 | 18 | ## ODEnet for MNIST 19 | The `odenet_mnist.py` file contains a reproduction of the MNIST experiments in our Neural ODE paper. Notably, not only the architecture but also the ODE solver library and integration method differ from our original experiments, though the results are similar to those we report in the paper. 20 | 21 | We can use an adaptive ODE solver to approximate our continuous-depth network while still backpropagating through the network. 22 | ``` 23 | python odenet_mnist.py --network odenet 24 | ``` 25 | However, the memory requirements for this will blow up very fast, especially for more complex problems where the number of function evaluations can reach nearly a thousand. 26 | 27 | For applications that require solving complex trajectories, we recommend using the adjoint method. 28 | ``` 29 | python odenet_mnist.py --network odenet --adjoint True 30 | ``` 31 | The adjoint method can be slower when using an adaptive ODE solver as it involves another solve in the backward pass with a much larger system, so experimenting on small systems with direct backpropagation first is recommended. 32 | 33 | Thankfully, it is extremely easy to write code for both adjoint and non-adjoint backpropagation, as they use the same interface. 34 | ``` 35 | if adjoint: 36 | from torchdiffeq import odeint_adjoint as odeint 37 | else: 38 | from torchdiffeq import odeint 39 | ``` 40 | The main gotcha is that `odeint_adjoint` requires implementing the dynamics network as a `nn.Module` while `odeint` can work with any callable in Python. 41 | 42 | ## Continuous Normalizing Flows 43 | Code for continuous normalizing flows (CNF) has its own public repository. Tools for training, evaluating, and visualizing CNF for reversible generative modeling are provided along with FFJORD, a linear-cost stochastic approximation of CNF. 44 | 45 | Find the code in https://github.com/rtqichen/ffjord. This code contains some advanced tricks for `torchdiffeq`.
46 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="torchdiffeq", 5 | version="0.0.1", 6 | author="Ricky Tian Qi Chen", 7 | author_email="rtqichen@cs.toronto.edu", 8 | description="ODE solvers and adjoint sensitivity analysis in PyTorch.", 9 | url="https://github.com/rtqichen/torchdiffeq", 10 | packages=['torchdiffeq', 'torchdiffeq._impl'], 11 | install_requires=['torch>=0.4.1'], 12 | classifiers=( 13 | "Programming Language :: Python :: 3"),) 14 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/tests/api_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchdiffeq 4 | 5 | from problems import construct_problem 6 | 7 | eps = 1e-12 8 | 9 | torch.set_default_dtype(torch.float64) 10 | TEST_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | def max_abs(tensor): 14 | return torch.max(torch.abs(tensor)) 15 | 16 | 17 | class TestCollectionState(unittest.TestCase): 18 | 19 | def test_dopri5(self): 20 | f, y0, t_points, sol = construct_problem(TEST_DEVICE) 21 | 22 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1])) 23 | tuple_y0 = (y0, y0) 24 | 25 | tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='dopri5') 26 | max_error0 = (sol - tuple_y[0]).max() 27 | max_error1 = (sol - tuple_y[1]).max() 28 | self.assertLess(max_error0, eps) 29 | self.assertLess(max_error1, eps) 30 | 31 | def test_dopri5_gradient(self): 32 | f, y0, t_points, sol = construct_problem(TEST_DEVICE) 33 | 34 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1])) 35 | 36 | for i in range(2): 37 | func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='dopri5')[i] 38 | 
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points))) 39 | 40 | def test_adams(self): 41 | f, y0, t_points, sol = construct_problem(TEST_DEVICE) 42 | 43 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1])) 44 | tuple_y0 = (y0, y0) 45 | 46 | tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='adams') 47 | max_error0 = (sol - tuple_y[0]).max() 48 | max_error1 = (sol - tuple_y[1]).max() 49 | self.assertLess(max_error0, eps) 50 | self.assertLess(max_error1, eps) 51 | 52 | def test_adams_gradient(self): 53 | f, y0, t_points, sol = construct_problem(TEST_DEVICE) 54 | 55 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1])) 56 | 57 | for i in range(2): 58 | func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='adams')[i] 59 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points))) 60 | 61 | 62 | if __name__ == '__main__': 63 | unittest.main() 64 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/tests/gradient_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchdiffeq 4 | 5 | from problems import construct_problem 6 | 7 | eps = 1e-12 8 | 9 | torch.set_default_dtype(torch.float64) 10 | TEST_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | def max_abs(tensor): 14 | return torch.max(torch.abs(tensor)) 15 | 16 | 17 | class TestGradient(unittest.TestCase): 18 | 19 | def test_midpoint(self): 20 | 21 | f, y0, t_points, _ = construct_problem(TEST_DEVICE) 22 | 23 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='midpoint') 24 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points))) 25 | 26 | def test_rk4(self): 27 | 28 | f, y0, t_points, _ = construct_problem(TEST_DEVICE) 29 | 30 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='rk4') 31 | self.assertTrue(torch.autograd.gradcheck(func, 
(y0, t_points))) 32 | 33 | def test_dopri5(self): 34 | f, y0, t_points, _ = construct_problem(TEST_DEVICE) 35 | 36 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5') 37 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points))) 38 | 39 | def test_adams(self): 40 | f, y0, t_points, _ = construct_problem(TEST_DEVICE) 41 | 42 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='adams') 43 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points))) 44 | 45 | def test_adjoint(self): 46 | """ 47 | Test against dopri5 48 | """ 49 | f, y0, t_points, _ = construct_problem(TEST_DEVICE) 50 | 51 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5') 52 | ys = func(y0, t_points) 53 | torch.manual_seed(0) 54 | gradys = torch.rand_like(ys) 55 | ys.backward(gradys) 56 | 57 | # reg_y0_grad = y0.grad 58 | reg_t_grad = t_points.grad 59 | reg_a_grad = f.a.grad 60 | reg_b_grad = f.b.grad 61 | 62 | f, y0, t_points, _ = construct_problem(TEST_DEVICE) 63 | 64 | func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5') 65 | ys = func(y0, t_points) 66 | ys.backward(gradys) 67 | 68 | # adj_y0_grad = y0.grad 69 | adj_t_grad = t_points.grad 70 | adj_a_grad = f.a.grad 71 | adj_b_grad = f.b.grad 72 | 73 | # self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps) 74 | self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps) 75 | self.assertLess(max_abs(reg_a_grad - adj_a_grad), eps) 76 | self.assertLess(max_abs(reg_b_grad - adj_b_grad), eps) 77 | 78 | 79 | class TestCompareAdjointGradient(unittest.TestCase): 80 | 81 | def problem(self): 82 | 83 | class Odefunc(torch.nn.Module): 84 | 85 | def __init__(self): 86 | super(Odefunc, self).__init__() 87 | self.A = torch.nn.Parameter(torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])) 88 | self.unused_module = torch.nn.Linear(2, 5) 89 | 90 | def forward(self, t, y): 91 | return torch.mm(y**3, self.A) 92 | 93 | y0 = torch.tensor([[2., 
0.]]).to(TEST_DEVICE).requires_grad_(True) 94 | t_points = torch.linspace(0., 25., 10).to(TEST_DEVICE).requires_grad_(True) 95 | func = Odefunc().to(TEST_DEVICE) 96 | return func, y0, t_points 97 | 98 | def test_dopri5_adjoint_against_dopri5(self): 99 | func, y0, t_points = self.problem() 100 | ys = torchdiffeq.odeint_adjoint(func, y0, t_points, method='dopri5') 101 | gradys = torch.rand_like(ys) * 0.1 102 | ys.backward(gradys) 103 | 104 | adj_y0_grad = y0.grad 105 | adj_t_grad = t_points.grad 106 | adj_A_grad = func.A.grad 107 | self.assertEqual(max_abs(func.unused_module.weight.grad), 0) 108 | self.assertEqual(max_abs(func.unused_module.bias.grad), 0) 109 | 110 | func, y0, t_points = self.problem() 111 | ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5') 112 | ys.backward(gradys) 113 | 114 | self.assertLess(max_abs(y0.grad - adj_y0_grad), 3e-4) 115 | self.assertLess(max_abs(t_points.grad - adj_t_grad), 1e-4) 116 | self.assertLess(max_abs(func.A.grad - adj_A_grad), 2e-3) 117 | 118 | def test_adams_adjoint_against_dopri5(self): 119 | func, y0, t_points = self.problem() 120 | ys_ = torchdiffeq.odeint_adjoint(func, y0, t_points, method='adams') 121 | gradys = torch.rand_like(ys_) * 0.1 122 | ys_.backward(gradys) 123 | 124 | adj_y0_grad = y0.grad 125 | adj_t_grad = t_points.grad 126 | adj_A_grad = func.A.grad 127 | self.assertEqual(max_abs(func.unused_module.weight.grad), 0) 128 | self.assertEqual(max_abs(func.unused_module.bias.grad), 0) 129 | 130 | func, y0, t_points = self.problem() 131 | ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5') 132 | ys.backward(gradys) 133 | 134 | self.assertLess(max_abs(y0.grad - adj_y0_grad), 5e-2) 135 | self.assertLess(max_abs(t_points.grad - adj_t_grad), 5e-4) 136 | self.assertLess(max_abs(func.A.grad - adj_A_grad), 2e-2) 137 | 138 | 139 | if __name__ == '__main__': 140 | unittest.main() 141 | -------------------------------------------------------------------------------- 
/im2mesh/utils/torchdiffeq/tests/odeint_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchdiffeq 4 | 5 | import problems 6 | 7 | error_tol = 1e-4 8 | 9 | torch.set_default_dtype(torch.float64) 10 | TEST_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | def max_abs(tensor): 14 | return torch.max(torch.abs(tensor)) 15 | 16 | 17 | def rel_error(true, estimate): 18 | return max_abs((true - estimate) / true) 19 | 20 | 21 | class TestSolverError(unittest.TestCase): 22 | 23 | def test_euler(self): 24 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE) 25 | 26 | y = torchdiffeq.odeint(f, y0, t_points, method='euler') 27 | self.assertLess(rel_error(sol, y), error_tol) 28 | 29 | def test_midpoint(self): 30 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE) 31 | 32 | y = torchdiffeq.odeint(f, y0, t_points, method='midpoint') 33 | self.assertLess(rel_error(sol, y), error_tol) 34 | 35 | def test_rk4(self): 36 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE) 37 | 38 | y = torchdiffeq.odeint(f, y0, t_points, method='rk4') 39 | self.assertLess(rel_error(sol, y), error_tol) 40 | 41 | def test_explicit_adams(self): 42 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE) 43 | 44 | y = torchdiffeq.odeint(f, y0, t_points, method='explicit_adams') 45 | self.assertLess(rel_error(sol, y), error_tol) 46 | 47 | def test_adams(self): 48 | for ode in problems.PROBLEMS.keys(): 49 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode) 50 | y = torchdiffeq.odeint(f, y0, t_points, method='adams') 51 | with self.subTest(ode=ode): 52 | self.assertLess(rel_error(sol, y), error_tol) 53 | 54 | def test_dopri5(self): 55 | for ode in problems.PROBLEMS.keys(): 56 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode) 57 | y = torchdiffeq.odeint(f, y0, t_points, method='dopri5') 58 | 
with self.subTest(ode=ode): 59 | self.assertLess(rel_error(sol, y), error_tol) 60 | 61 | def test_adjoint(self): 62 | for ode in problems.PROBLEMS.keys(): 63 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 64 | 65 | y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5') 66 | with self.subTest(ode=ode): 67 | self.assertLess(rel_error(sol, y), error_tol) 68 | 69 | 70 | class TestSolverBackwardsInTimeError(unittest.TestCase): 71 | 72 | def test_euler(self): 73 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 74 | 75 | y = torchdiffeq.odeint(f, y0, t_points, method='euler') 76 | self.assertLess(rel_error(sol, y), error_tol) 77 | 78 | def test_midpoint(self): 79 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 80 | 81 | y = torchdiffeq.odeint(f, y0, t_points, method='midpoint') 82 | self.assertLess(rel_error(sol, y), error_tol) 83 | 84 | def test_rk4(self): 85 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 86 | 87 | y = torchdiffeq.odeint(f, y0, t_points, method='rk4') 88 | self.assertLess(rel_error(sol, y), error_tol) 89 | 90 | def test_explicit_adams(self): 91 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 92 | 93 | y = torchdiffeq.odeint(f, y0, t_points, method='explicit_adams') 94 | self.assertLess(rel_error(sol, y), error_tol) 95 | 96 | def test_adams(self): 97 | for ode in problems.PROBLEMS.keys(): 98 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 99 | 100 | y = torchdiffeq.odeint(f, y0, t_points, method='adams') 101 | with self.subTest(ode=ode): 102 | self.assertLess(rel_error(sol, y), error_tol) 103 | 104 | def test_dopri5(self): 105 | for ode in problems.PROBLEMS.keys(): 106 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 107 | 108 | y = torchdiffeq.odeint(f, y0, t_points, method='dopri5') 109 | with self.subTest(ode=ode): 110 | 
self.assertLess(rel_error(sol, y), error_tol) 111 | 112 | def test_adjoint(self): 113 | for ode in problems.PROBLEMS.keys(): 114 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 115 | 116 | y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5') 117 | with self.subTest(ode=ode): 118 | self.assertLess(rel_error(sol, y), error_tol) 119 | 120 | 121 | class TestNoIntegration(unittest.TestCase): 122 | 123 | def test_midpoint(self): 124 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 125 | 126 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='midpoint') 127 | self.assertLess(max_abs(sol[0] - y), error_tol) 128 | 129 | def test_rk4(self): 130 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 131 | 132 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='rk4') 133 | self.assertLess(max_abs(sol[0] - y), error_tol) 134 | 135 | def test_explicit_adams(self): 136 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 137 | 138 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='explicit_adams') 139 | self.assertLess(max_abs(sol[0] - y), error_tol) 140 | 141 | def test_adams(self): 142 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 143 | 144 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='adams') 145 | self.assertLess(max_abs(sol[0] - y), error_tol) 146 | 147 | def test_dopri5(self): 148 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True) 149 | 150 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='dopri5') 151 | self.assertLess(max_abs(sol[0] - y), error_tol) 152 | 153 | 154 | if __name__ == '__main__': 155 | unittest.main() 156 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/tests/problems.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 
import scipy.linalg 4 | import torch 5 | 6 | 7 | class ConstantODE(torch.nn.Module): 8 | 9 | def __init__(self, device): 10 | super(ConstantODE, self).__init__() 11 | self.a = torch.nn.Parameter(torch.tensor(0.2).to(device)) 12 | self.b = torch.nn.Parameter(torch.tensor(3.0).to(device)) 13 | 14 | def forward(self, t, y): 15 | return self.a + (y - (self.a * t + self.b))**5 16 | 17 | def y_exact(self, t): 18 | return self.a * t + self.b 19 | 20 | 21 | class SineODE(torch.nn.Module): 22 | 23 | def __init__(self, device): 24 | super(SineODE, self).__init__() 25 | 26 | def forward(self, t, y): 27 | return 2 * y / t + t**4 * torch.sin(2 * t) - t**2 + 4 * t**3 28 | 29 | def y_exact(self, t): 30 | return -0.5 * t**4 * torch.cos(2 * t) + 0.5 * t**3 * torch.sin(2 * t) + 0.25 * t**2 * torch.cos( 31 | 2 * t 32 | ) - t**3 + 2 * t**4 + (math.pi - 0.25) * t**2 33 | 34 | 35 | class LinearODE(torch.nn.Module): 36 | 37 | def __init__(self, device, dim=10): 38 | super(LinearODE, self).__init__() 39 | self.dim = dim 40 | U = torch.randn(dim, dim).to(device) * 0.1 41 | A = 2 * U - (U + U.transpose(0, 1)) 42 | self.A = torch.nn.Parameter(A) 43 | self.initial_val = np.ones((dim, 1)) 44 | 45 | def forward(self, t, y): 46 | return torch.mm(self.A, y.reshape(self.dim, 1)).reshape(-1) 47 | 48 | def y_exact(self, t): 49 | t = t.detach().cpu().numpy() 50 | A_np = self.A.detach().cpu().numpy() 51 | ans = [] 52 | for t_i in t: 53 | ans.append(np.matmul(scipy.linalg.expm(A_np * t_i), self.initial_val)) 54 | return torch.stack([torch.tensor(ans_) for ans_ in ans]).reshape(len(t), self.dim) 55 | 56 | 57 | PROBLEMS = {'constant': ConstantODE, 'linear': LinearODE, 'sine': SineODE} 58 | 59 | 60 | def construct_problem(device, npts=10, ode='constant', reverse=False): 61 | 62 | f = PROBLEMS[ode](device) 63 | 64 | t_points = torch.linspace(1, 8, npts).to(device).requires_grad_(True) 65 | sol = f.y_exact(t_points) 66 | 67 | def _flip(x, dim): 68 | indices = [slice(None)] * x.dim() 69 | indices[dim] = 
torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) 70 | return x[tuple(indices)] 71 | 72 | if reverse: 73 | t_points = _flip(t_points, 0).clone().detach() 74 | sol = _flip(sol, 0).clone().detach() 75 | 76 | return f, sol[0].detach(), t_points, sol 77 | 78 | 79 | if __name__ == '__main__': 80 | f = SineODE('cpu') 81 | t_points = torch.linspace(1, 8, 100).to('cpu').requires_grad_(True) 82 | sol = f.y_exact(t_points) 83 | 84 | import matplotlib.pyplot as plt 85 | plt.plot(t_points.detach().cpu().numpy(), sol.detach().cpu().numpy()) 86 | plt.show() 87 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/tests/run_all.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from api_tests import * 3 | from gradient_tests import * 4 | from odeint_tests import * 5 | 6 | if __name__ == '__main__': 7 | unittest.main() 8 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/__init__.py: -------------------------------------------------------------------------------- 1 | from ._impl import odeint 2 | from ._impl import odeint_adjoint 3 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/_impl/__init__.py: -------------------------------------------------------------------------------- 1 | from .odeint import odeint 2 | from .adjoint import odeint_adjoint 3 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/_impl/adjoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . 
import odeint 4 | from .misc import _flatten, _flatten_convert_none_to_zeros 5 | 6 | 7 | class OdeintAdjointMethod(torch.autograd.Function): 8 | 9 | @staticmethod 10 | def forward(ctx, *args): 11 | assert len(args) >= 9, 'Internal error: all arguments required.' 12 | y0, func, t, flat_params, rtol, atol, method, options, f_options = \ 13 | args[:-8], args[-8], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1] 14 | 15 | ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options, ctx.f_options = func, rtol, atol, method, options, f_options 16 | 17 | with torch.no_grad(): 18 | ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, f_options=f_options) 19 | ctx.save_for_backward(t, flat_params, *ans) 20 | return ans 21 | 22 | @staticmethod 23 | def backward(ctx, *grad_output): 24 | t, flat_params, *ans = ctx.saved_tensors 25 | ans = tuple(ans) 26 | func, rtol, atol, method, options, f_options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options, ctx.f_options 27 | n_tensors = len(ans) 28 | f_params = tuple(func.parameters()) 29 | 30 | # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives. 31 | def augmented_dynamics(t, y_aug, **f_options): 32 | # Dynamics of the original system augmented with 33 | # the adjoint wrt y, and an integrator wrt t and args. 34 | y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] # Ignore adj_time and adj_params. 35 | 36 | with torch.set_grad_enabled(True): 37 | t = t.to(y[0].device).detach().requires_grad_(True) 38 | y = tuple(y_.detach().requires_grad_(True) for y_ in y) 39 | func_eval = func(t, y, **f_options) 40 | vjp_t, *vjp_y_and_params = torch.autograd.grad( 41 | func_eval, (t,) + y + f_params, 42 | tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True 43 | ) 44 | vjp_y = vjp_y_and_params[:n_tensors] 45 | vjp_params = vjp_y_and_params[n_tensors:] 46 | 47 | # autograd.grad returns None if no gradient, set to zero. 
48 | vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t 49 | vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y)) 50 | vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params) 51 | 52 | if len(f_params) == 0: 53 | vjp_params = torch.tensor(0.).to(vjp_y[0]) 54 | return (*func_eval, *vjp_y, vjp_t, vjp_params) 55 | 56 | T = ans[0].shape[0] 57 | with torch.no_grad(): 58 | adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output) 59 | adj_params = torch.zeros_like(flat_params) 60 | adj_time = torch.tensor(0.).to(t) 61 | time_vjps = [] 62 | for i in range(T - 1, 0, -1): 63 | 64 | ans_i = tuple(ans_[i] for ans_ in ans) 65 | grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output) 66 | func_i = func(t[i], ans_i, **f_options) 67 | 68 | # Compute the effect of moving the current time measurement point. 69 | dLd_cur_t = sum( 70 | torch.dot(func_i_.view(-1), grad_output_i_.view(-1)).view(1) 71 | for func_i_, grad_output_i_ in zip(func_i, grad_output_i) 72 | ) 73 | adj_time = adj_time - dLd_cur_t 74 | time_vjps.append(dLd_cur_t) 75 | 76 | # Run the augmented system backwards in time. 77 | if len(adj_params) == 0: 78 | adj_params = torch.tensor(0.).to(adj_y[0]) 79 | aug_y0 = (*ans_i, *adj_y, adj_time, adj_params) 80 | aug_ans = odeint( 81 | augmented_dynamics, aug_y0, 82 | torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options, f_options=f_options 83 | ) 84 | 85 | # Unpack aug_ans. 
86 | adj_y = aug_ans[n_tensors:2 * n_tensors] 87 | adj_time = aug_ans[2 * n_tensors] 88 | adj_params = aug_ans[2 * n_tensors + 1] 89 | 90 | adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y) 91 | if len(adj_time) > 0: adj_time = adj_time[1] 92 | if len(adj_params) > 0: adj_params = adj_params[1] 93 | 94 | adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output)) 95 | 96 | del aug_y0, aug_ans 97 | 98 | time_vjps.append(adj_time) 99 | time_vjps = torch.cat(time_vjps[::-1]) 100 | 101 | return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None) 102 | 103 | 104 | def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None, f_options=None): 105 | 106 | # We need this in order to access the variables inside this module, 107 | # since we have no other way of getting variables along the execution path. 108 | if not isinstance(func, nn.Module): 109 | raise ValueError('func is required to be an instance of nn.Module.') 110 | 111 | tensor_input = False 112 | if torch.is_tensor(y0): 113 | 114 | class TupleFunc(nn.Module): 115 | 116 | def __init__(self, base_func): 117 | super(TupleFunc, self).__init__() 118 | self.base_func = base_func 119 | 120 | def forward(self, t, y, **f_options): 121 | return (self.base_func(t, y[0], **f_options),) 122 | 123 | tensor_input = True 124 | y0 = (y0,) 125 | func = TupleFunc(func) 126 | 127 | flat_params = _flatten(func.parameters()) 128 | ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options, f_options) 129 | 130 | if tensor_input: 131 | ys = ys[0] 132 | return ys 133 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/_impl/dopri5.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate 2 | import torch 3 | from .misc 
import ( 4 | _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable, 5 | _optimal_step_size, _compute_error_ratio 6 | ) 7 | from .solvers import AdaptiveStepsizeODESolver 8 | from .interp import _interp_fit, _interp_evaluate 9 | from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step 10 | 11 | _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( 12 | alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], 13 | beta=[ 14 | [1 / 5], 15 | [3 / 40, 9 / 40], 16 | [44 / 45, -56 / 15, 32 / 9], 17 | [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], 18 | [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], 19 | [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], 20 | ], 21 | c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], 22 | c_error=[ 23 | 35 / 384 - 1951 / 21600, 24 | 0, 25 | 500 / 1113 - 22642 / 50085, 26 | 125 / 192 - 451 / 720, 27 | -2187 / 6784 - -12231 / 42400, 28 | 11 / 84 - 649 / 6300, 29 | -1. 
/ 60., 30 | ], 31 | ) 32 | 33 | DPS_C_MID = [ 34 | 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, 35 | 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 36 | ] 37 | 38 | 39 | def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU): 40 | """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" 41 | dt = dt.type_as(y0[0]) 42 | y_mid = tuple(y0_ + _scaled_dot_product(dt, DPS_C_MID, k_) for y0_, k_ in zip(y0, k)) 43 | f0 = tuple(k_[0] for k_ in k) 44 | f1 = tuple(k_[-1] for k_ in k) 45 | return _interp_fit(y0, y1, y_mid, f0, f1, dt) 46 | 47 | 48 | def _abs_square(x): 49 | return torch.mul(x, x) 50 | 51 | 52 | def _ta_append(list_of_tensors, value): 53 | """Append a value to the end of a list of PyTorch tensors.""" 54 | list_of_tensors.append(value) 55 | return list_of_tensors 56 | 57 | 58 | class Dopri5Solver(AdaptiveStepsizeODESolver): 59 | 60 | def __init__( 61 | self, func, y0, f_options, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, 62 | **unused_kwargs 63 | ): 64 | _handle_unused_kwargs(self, unused_kwargs) 65 | del unused_kwargs 66 | 67 | self.func = func 68 | self.y0 = y0 69 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 70 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0) 71 | self.f_options = f_options 72 | self.first_step = first_step 73 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 74 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 75 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 76 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) 77 | 78 | def before_integrate(self, t): 79 | f0 = self.func(t[0].type_as(self.y0[0]), self.y0, **self.f_options) 80 | if self.first_step is None: 81 | first_step 
= _select_initial_step(self.func, t[0], self.y0, self.f_options, 4, self.rtol[0], self.atol[0], f0=f0).to(t) 82 | else: 83 | first_step = _convert_to_tensor(0.01, dtype=t.dtype, device=t.device) 84 | self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5) 85 | 86 | def advance(self, next_t): 87 | """Interpolate through the next time point, integrating as necessary.""" 88 | n_steps = 0 89 | while next_t > self.rk_state.t1: 90 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) 91 | self.rk_state = self._adaptive_dopri5_step(self.rk_state) 92 | n_steps += 1 93 | return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) 94 | 95 | def _adaptive_dopri5_step(self, rk_state): 96 | """Take an adaptive Runge-Kutta step to integrate the ODE.""" 97 | y0, f0, _, t0, dt, interp_coeff = rk_state 98 | ######################################################## 99 | # Assertions # 100 | ######################################################## 101 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) 102 | for y0_ in y0: 103 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) 104 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU, f_options=self.f_options) 105 | 106 | ######################################################## 107 | # Error Ratio # 108 | ######################################################## 109 | mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) 110 | accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all() 111 | 112 | ######################################################## 113 | # Update RK State # 114 | ######################################################## 115 | y_next = y1 if accept_step else y0 116 | f_next = f1 if accept_step else f0 117 | t_next = t0 + dt if accept_step 
else t0 118 | interp_coeff = _interp_fit_dopri5(y0, y1, k, dt) if accept_step else interp_coeff 119 | dt_next = _optimal_step_size( 120 | dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5 121 | ) 122 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) 123 | return rk_state 124 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/_impl/fixed_grid.py: -------------------------------------------------------------------------------- 1 | from .solvers import FixedGridODESolver 2 | from . import rk_common 3 | 4 | 5 | class Euler(FixedGridODESolver): 6 | 7 | def step_func(self, func, t, dt, y): 8 | return tuple(dt * f_ for f_ in func(t, y)) 9 | 10 | @property 11 | def order(self): 12 | return 1 13 | 14 | 15 | class Midpoint(FixedGridODESolver): 16 | 17 | def step_func(self, func, t, dt, y): 18 | y_mid = tuple(y_ + f_ * dt / 2 for y_, f_ in zip(y, func(t, y))) 19 | return tuple(dt * f_ for f_ in func(t + dt / 2, y_mid)) 20 | 21 | @property 22 | def order(self): 23 | return 2 24 | 25 | 26 | class RK4(FixedGridODESolver): 27 | 28 | def step_func(self, func, t, dt, y, f_options): 29 | return rk_common.rk4_alt_step_func(func, t, dt, y, f_options) 30 | 31 | @property 32 | def order(self): 33 | return 4 34 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/_impl/interp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .misc import _convert_to_tensor, _dot_product 3 | 4 | 5 | def _interp_fit(y0, y1, y_mid, f0, f1, dt): 6 | """Fit coefficients for 4th order polynomial interpolation. 7 | 8 | Args: 9 | y0: function value at the start of the interval. 10 | y1: function value at the end of the interval. 11 | y_mid: function value at the mid-point of the interval. 
12 | f0: derivative value at the start of the interval. 13 | f1: derivative value at the end of the interval. 14 | dt: width of the interval. 15 | 16 | Returns: 17 | List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial 18 | `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x` 19 | between 0 (start of interval) and 1 (end of interval). 20 | """ 21 | a = tuple( 22 | _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0_, f1_, y0_, y1_, y_mid_]) 23 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) 24 | ) 25 | b = tuple( 26 | _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0_, f1_, y0_, y1_, y_mid_]) 27 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) 28 | ) 29 | c = tuple( 30 | _dot_product([-4 * dt, dt, -11, -5, 16], [f0_, f1_, y0_, y1_, y_mid_]) 31 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) 32 | ) 33 | d = tuple(dt * f0_ for f0_ in f0) 34 | e = y0 35 | return [a, b, c, d, e] 36 | 37 | 38 | def _interp_evaluate(coefficients, t0, t1, t): 39 | """Evaluate polynomial interpolation at the given time point. 40 | 41 | Args: 42 | coefficients: list of Tensor coefficients as created by `interp_fit`. 43 | t0: scalar float64 Tensor giving the start of the interval. 44 | t1: scalar float64 Tensor giving the end of the interval. 45 | t: scalar float64 Tensor giving the desired interpolation point. 46 | 47 | Returns: 48 | Polynomial interpolation of the coefficients at time `t`. 
49 | """ 50 | 51 | dtype = coefficients[0][0].dtype 52 | device = coefficients[0][0].device 53 | 54 | t0 = _convert_to_tensor(t0, dtype=dtype, device=device) 55 | t1 = _convert_to_tensor(t1, dtype=dtype, device=device) 56 | t = _convert_to_tensor(t, dtype=dtype, device=device) 57 | 58 | assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1) 59 | x = ((t - t0) / (t1 - t0)).type(dtype).to(device) 60 | 61 | xs = [torch.tensor(1).type(dtype).to(device), x] 62 | for _ in range(2, len(coefficients)): 63 | xs.append(xs[-1] * x) 64 | 65 | return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients)) 66 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/_impl/odeint.py: -------------------------------------------------------------------------------- 1 | from .tsit5 import Tsit5Solver 2 | from .dopri5 import Dopri5Solver 3 | from .fixed_grid import Euler, Midpoint, RK4 4 | from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton 5 | from .adams import VariableCoefficientAdamsBashforth 6 | from .misc import _check_inputs 7 | 8 | SOLVERS = { 9 | 'explicit_adams': AdamsBashforth, 10 | 'fixed_adams': AdamsBashforthMoulton, 11 | 'adams': VariableCoefficientAdamsBashforth, 12 | 'tsit5': Tsit5Solver, 13 | 'dopri5': Dopri5Solver, 14 | 'euler': Euler, 15 | 'midpoint': Midpoint, 16 | 'rk4': RK4, 17 | } 18 | 19 | 20 | def odeint(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None, f_options=None): 21 | """Integrate a system of ordinary differential equations. 22 | 23 | Solves the initial value problem for a non-stiff system of first order ODEs: 24 | ``` 25 | dy/dt = func(t, y), y(t[0]) = y0 26 | ``` 27 | where y is a Tensor of any shape. 28 | 29 | Output dtypes and numerical precision are based on the dtypes of the inputs `y0`. 
30 | 31 | Args: 32 | func: Function that maps a Tensor holding the state `y` and a scalar Tensor 33 | `t` into a Tensor of state derivatives with respect to time. 34 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May 35 | have any floating point or complex dtype. 36 | t: 1-D Tensor holding a sequence of time points for which to solve for 37 | `y`. The initial time point should be the first element of this sequence, 38 | and each time must be larger than the previous time. May have any floating 39 | point dtype. Converted to a Tensor with float64 dtype. 40 | rtol: optional float64 Tensor specifying an upper bound on relative error, 41 | per element of `y`. 42 | atol: optional float64 Tensor specifying an upper bound on absolute error, 43 | per element of `y`. 44 | method: optional string indicating the integration method to use. 45 | options: optional dict of configuring options for the indicated integration 46 | method. Can only be provided if a `method` is explicitly set. 47 | f_options: optional dict of additional arguments for func 48 | 49 | Returns: 50 | y: Tensor, where the first dimension corresponds to different 51 | time points. Contains the solved value of y for each desired time point in 52 | `t`, with the initial value `y0` being the first element along the first 53 | dimension. 54 | 55 | Raises: 56 | ValueError: if an invalid `method` is provided. 57 | TypeError: if `options` is supplied without `method`, or if `t` or `y0` has 58 | an invalid dtype. 
59 | """ 60 | tensor_input, func, y0, t = _check_inputs(func, y0, t, f_options) 61 | 62 | if options is None: 63 | options = {} 64 | elif method is None: 65 | raise ValueError('cannot supply `options` without specifying `method`') 66 | 67 | if method is None: 68 | method = 'dopri5' 69 | 70 | if f_options is None: 71 | f_options = {} 72 | solver = SOLVERS[method](func, y0, f_options, rtol=rtol, atol=atol, **options) 73 | solution = solver.integrate(t) 74 | 75 | if tensor_input: 76 | solution = solution[0] 77 | return solution 78 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/_impl/rk_common.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate 2 | import collections 3 | from .misc import _scaled_dot_product, _convert_to_tensor 4 | 5 | _ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha beta c_sol c_error') 6 | 7 | 8 | class _RungeKuttaState(collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')): 9 | """Saved state of the Runge Kutta solver. 10 | 11 | Attributes: 12 | y1: Tensor giving the function value at the end of the last time step. 13 | f1: Tensor giving derivative at the end of the last time step. 14 | t0: scalar float64 Tensor giving start of the last time step. 15 | t1: scalar float64 Tensor giving end of the last time step. 16 | dt: scalar float64 Tensor giving the size for the next time step. 17 | interp_coef: list of Tensors giving coefficients for polynomial 18 | interpolation between `t0` and `t1`. 19 | """ 20 | 21 | 22 | def _runge_kutta_step(func, y0, f0, t0, dt, tableau, f_options): 23 | """Take an arbitrary Runge-Kutta step and estimate error. 24 | 25 | Args: 26 | func: Function to evaluate like `func(t, y)` to compute the time derivative 27 | of `y`. 28 | y0: Tensor initial value for the state. 
29 | f0: Tensor initial value for the derivative, computed from `func(t0, y0)`. 30 | t0: float64 scalar Tensor giving the initial time. 31 | dt: float64 scalar Tensor giving the size of the desired time step. 32 | tableau: optional _ButcherTableau describing how to take the Runge-Kutta 33 | step. 34 | f_options: additional keyworded arguments for func 35 | name: optional name for the operation. 36 | 37 | Returns: 38 | Tuple `(y1, f1, y1_error, k)` giving the estimated function value after 39 | the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`, 40 | estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for 41 | calculating these terms. 42 | """ 43 | dtype = y0[0].dtype 44 | device = y0[0].device 45 | 46 | t0 = _convert_to_tensor(t0, dtype=dtype, device=device) 47 | dt = _convert_to_tensor(dt, dtype=dtype, device=device) 48 | 49 | k = tuple(map(lambda x: [x], f0)) 50 | for alpha_i, beta_i in zip(tableau.alpha, tableau.beta): 51 | ti = t0 + alpha_i * dt 52 | yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k)) 53 | tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi, **f_options))) 54 | 55 | if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]): 56 | # This property (true for Dormand-Prince) lets us save a few FLOPs. 
57 | yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k)) 58 | 59 | y1 = yi 60 | f1 = tuple(k_[-1] for k_ in k) 61 | y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k) 62 | return (y1, f1, y1_error, k) 63 | 64 | 65 | def rk4_step_func(func, t, dt, y, f_options, k1=None): 66 | if k1 is None: k1 = func(t, y, **f_options) 67 | k2 = func(t + dt / 2, tuple(y_ + dt * k1_ / 2 for y_, k1_ in zip(y, k1)), **f_options) 68 | k3 = func(t + dt / 2, tuple(y_ + dt * k2_ / 2 for y_, k2_ in zip(y, k2)), **f_options) 69 | k4 = func(t + dt, tuple(y_ + dt * k3_ for y_, k3_ in zip(y, k3)), **f_options) 70 | return tuple((k1_ + 2 * k2_ + 2 * k3_ + k4_) * (dt / 6) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4)) 71 | 72 | 73 | def rk4_alt_step_func(func, t, dt, y, f_options, k1=None): 74 | """Smaller error with slightly more compute.""" 75 | if k1 is None: k1 = func(t, y, **f_options) 76 | k2 = func(t + dt / 3, tuple(y_ + dt * k1_ / 3 for y_, k1_ in zip(y, k1)), **f_options) 77 | k3 = func(t + dt * 2 / 3, tuple(y_ + dt * (k1_ / -3 + k2_) for y_, k1_, k2_ in zip(y, k1, k2)), **f_options) 78 | k4 = func(t + dt, tuple(y_ + dt * (k1_ - k2_ + k3_) for y_, k1_, k2_, k3_ in zip(y, k1, k2, k3)), **f_options) 79 | return tuple((k1_ + 3 * k2_ + 3 * k3_ + k4_) * (dt / 8) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4)) 80 | -------------------------------------------------------------------------------- /im2mesh/utils/torchdiffeq/torchdiffeq/_impl/solvers.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | from .misc import _assert_increasing, _handle_unused_kwargs 4 | 5 | 6 | class AdaptiveStepsizeODESolver(object): 7 | __metaclass__ = abc.ABCMeta 8 | 9 | def __init__(self, func, y0, f_options, atol=0.01, rtol=0.001, **unused_kwargs): 10 | _handle_unused_kwargs(self, unused_kwargs) 11 | del unused_kwargs 12 | 13 | self.func = func 14 | self.y0 = y0 15 | self.f_options = 
f_options 16 | 17 | self.atol = atol 18 | self.rtol = rtol 19 | 20 | def before_integrate(self, t): 21 | pass 22 | 23 | @abc.abstractmethod 24 | def advance(self, next_t): 25 | raise NotImplementedError 26 | 27 | def integrate(self, t): 28 | _assert_increasing(t) 29 | solution = [self.y0] 30 | t = t.to(self.y0[0].device, torch.float64) 31 | self.before_integrate(t) 32 | for i in range(1, len(t)): 33 | y = self.advance(t[i]) 34 | solution.append(y) 35 | return tuple(map(torch.stack, tuple(zip(*solution)))) 36 | 37 | 38 | class FixedGridODESolver(object): 39 | __metaclass__ = abc.ABCMeta 40 | 41 | def __init__(self, func, y0, f_options, step_size=None, grid_constructor=None, **unused_kwargs): 42 | unused_kwargs.pop('rtol', None) 43 | unused_kwargs.pop('atol', None) 44 | _handle_unused_kwargs(self, unused_kwargs) 45 | del unused_kwargs 46 | 47 | self.func = func 48 | self.y0 = y0 49 | self.f_options = f_options 50 | if step_size is not None and grid_constructor is None: 51 | self.grid_constructor = self._grid_constructor_from_step_size(step_size) 52 | elif grid_constructor is None: 53 | self.grid_constructor = lambda f, y0, t: t 54 | else: 55 | raise ValueError("step_size and grid_constructor are exclusive arguments.") 56 | 57 | def _grid_constructor_from_step_size(self, step_size): 58 | 59 | def _grid_constructor(func, y0, t): 60 | start_time = t[0] 61 | end_time = t[-1] 62 | 63 | niters = torch.ceil((end_time - start_time) / step_size + 1).item() 64 | t_infer = torch.arange(0, niters).to(t) * step_size + start_time 65 | ''' 66 | if t_infer[-1] > t[-1]: 67 | t_infer[-1] = t[-1] 68 | ''' 69 | if t_infer[-1] != t[-1]: 70 | t_infer[-1] = t[-1] 71 | 72 | return t_infer 73 | 74 | return _grid_constructor 75 | 76 | @property 77 | @abc.abstractmethod 78 | def order(self): 79 | pass 80 | 81 | @abc.abstractmethod 82 | def step_func(self, func, t, dt, y, f_options): 83 | pass 84 | 85 | def integrate(self, t): 86 | _assert_increasing(t) 87 | t = t.type_as(self.y0[0]) 88 | 
time_grid = self.grid_constructor(self.func, self.y0, t) 89 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1] 90 | time_grid = time_grid.to(self.y0[0]) 91 | 92 | solution = [self.y0] 93 | 94 | j = 1 95 | y0 = self.y0 96 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]): 97 | dy = self.step_func(self.func, t0, t1 - t0, y0, self.f_options) 98 | y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy)) 99 | y0 = y1 100 | 101 | while j < len(t) and t1 >= t[j]: 102 | solution.append(self._linear_interp(t0, t1, y0, y1, t[j])) 103 | j += 1 104 | 105 | return tuple(map(torch.stack, tuple(zip(*solution)))) 106 | 107 | def _linear_interp(self, t0, t1, y0, y1, t): 108 | if t == t0: 109 | return y0 110 | if t == t1: 111 | return y1 112 | t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0]) 113 | slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1)) 114 | return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope)) 115 | -------------------------------------------------------------------------------- /im2mesh/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use('agg') 4 | from matplotlib import pyplot as plt 5 | from mpl_toolkits.mplot3d import Axes3D 6 | from torchvision.utils import save_image 7 | import im2mesh.common as common 8 | 9 | 10 | def visualize_data(data, data_type, out_file): 11 | r''' Visualizes the data with regard to its type. 
12 | 13 | Args: 14 | data (tensor): batch of data 15 | data_type (string): data type (img, voxels or pointcloud) 16 | out_file (string): output file 17 | ''' 18 | if data_type == 'img': 19 | if data.dim() == 3: 20 | data = data.unsqueeze(0) 21 | save_image(data, out_file, nrow=4) 22 | elif data_type == 'voxels': 23 | visualize_voxels(data, out_file=out_file) 24 | elif data_type == 'pointcloud': 25 | visualize_pointcloud(data, out_file=out_file) 26 | elif data_type is None or data_type == 'idx': 27 | pass 28 | else: 29 | raise ValueError('Invalid data_type "%s"' % data_type) 30 | 31 | 32 | def visualize_voxels(voxels, out_file=None, show=False): 33 | r''' Visualizes voxel data. 34 | 35 | Args: 36 | voxels (tensor): voxel data 37 | out_file (string): output file 38 | show (bool): whether the plot should be shown 39 | ''' 40 | # Use numpy 41 | voxels = np.asarray(voxels) 42 | # Create plot 43 | fig = plt.figure() 44 | ax = fig.gca(projection=Axes3D.name) 45 | voxels = voxels.transpose(2, 0, 1) 46 | ax.voxels(voxels, edgecolor='k') 47 | ax.set_xlabel('Z') 48 | ax.set_ylabel('X') 49 | ax.set_zlabel('Y') 50 | ax.view_init(elev=30, azim=45) 51 | if out_file is not None: 52 | plt.savefig(out_file) 53 | if show: 54 | plt.show() 55 | plt.close(fig) 56 | 57 | 58 | def visualize_pointcloud(points, normals=None, 59 | out_file=None, show=False): 60 | r''' Visualizes point cloud data. 
61 | 62 | Args: 63 | points (tensor): point data 64 | normals (tensor): normal data (if existing) 65 | out_file (string): output file 66 | show (bool): whether the plot should be shown 67 | ''' 68 | # Use numpy 69 | points = np.asarray(points) 70 | # Create plot 71 | fig = plt.figure() 72 | ax = fig.gca(projection=Axes3D.name) 73 | ax.scatter(points[:, 2], points[:, 0], points[:, 1]) 74 | if normals is not None: 75 | ax.quiver( 76 | points[:, 2], points[:, 0], points[:, 1], 77 | normals[:, 2], normals[:, 0], normals[:, 1], 78 | length=0.1, color='k' 79 | ) 80 | ax.set_xlabel('Z') 81 | ax.set_ylabel('X') 82 | ax.set_zlabel('Y') 83 | ax.set_xlim(-0.5, 0.5) 84 | ax.set_ylim(-0.5, 0.5) 85 | ax.set_zlim(-0.5, 0.5) 86 | ax.view_init(elev=30, azim=45) 87 | if out_file is not None: 88 | plt.savefig(out_file) 89 | if show: 90 | plt.show() 91 | plt.close(fig) 92 | 93 | 94 | def visualise_projection( 95 | self, points, world_mat, camera_mat, img, output_file='out.png'): 96 | r''' Visualizes the transformation and projection to image plane. 97 | 98 | The first points of the batch are transformed and projected to the 99 | respective image. After performing the relevant transformations, the 100 | visualization is saved in the provided output_file path. 
101 | 102 | Arguments: 103 | points (tensor): batch of point cloud points 104 | world_mat (tensor): batch of matrices to rotate pc to camera-based 105 | coordinates 106 | camera_mat (tensor): batch of camera matrices to project to 2D image 107 | plane 108 | img (tensor): tensor of batch GT image files 109 | output_file (string): where the output should be saved 110 | ''' 111 | points_transformed = common.transform_points(points, world_mat) 112 | points_img = common.project_to_camera(points_transformed, camera_mat) 113 | pimg2 = points_img[0].detach().cpu().numpy() 114 | image = img[0].cpu().numpy() 115 | plt.imshow(image.transpose(1, 2, 0)) 116 | plt.plot( 117 | (pimg2[:, 0] + 1)*image.shape[1]/2, 118 | (pimg2[:, 1] + 1) * image.shape[2]/2, 'x') 119 | plt.savefig(output_file) 120 | -------------------------------------------------------------------------------- /scripts/build_dataset.sh: -------------------------------------------------------------------------------- 1 | source config.sh 2 | 3 | # Make output directories 4 | mkdir -p $BUILD_PATH 5 | 6 | mkdir -p $BUILD_PATH/0_points \ 7 | $BUILD_PATH/0_pointcloud \ 8 | $BUILD_PATH/0_meshes #\ 9 | # $BUILD_PATH/0_render 10 | 11 | echo " Building Human Dataset for Occupancy Flow Project." 12 | echo " Input Path: $INPUT_PATH" 13 | echo " Build Path: $BUILD_PATH" 14 | 15 | echo "Sample points ..." 16 | python sample_mesh.py $INPUT_PATH \ 17 | --n_proc $NPROC --resize \ 18 | --points_folder $BUILD_PATH/0_points \ 19 | --overwrite --float16 --packbits 20 | echo "done!" 21 | 22 | echo "Sample pointcloud" 23 | python sample_mesh.py $INPUT_PATH \ 24 | --n_proc $NPROC --resize \ 25 | --pointcloud_folder $BUILD_PATH/0_pointcloud \ 26 | --overwrite --float16 27 | echo "done" 28 | 29 | echo "Copy mesh data." 
30 | inputs=$(lsfilter $INPUT_PATH) 31 | for m in ${inputs[@]}; do 32 | m_path="$INPUT_PATH/$m" 33 | mesh_files=$(lsfilter $m_path) 34 | out_path="$BUILD_PATH/0_meshes/$m" 35 | mkdir -p "$out_path" 36 | echo "Copy for model $m" 37 | for f in ${mesh_files[@]}; do 38 | mesh_file="$m_path/$f" 39 | out_file="$out_path/$f" 40 | cp "$mesh_file" "$out_file" 41 | done 42 | done 43 | echo "done" 44 | # Rendering step below is disabled; kept for reference only. 45 | # echo "Render sequence (camera on circle)" 46 | # inputs=$(lsfilter $INPUT_PATH $BUILD_PATH/0_render /camera.npz) # 47 | # for f in ${inputs[@]}; do 48 | # # lsfilter $INPUT_PATH $BUILD_PATH/0_render /camera.npz | parallel -P $NPROC \ 49 | # blender --background --python render_blender.py -- \ 50 | # --output_folder $BUILD_PATH/0_render \ 51 | # --views $N_VIEWS \ 52 | # --camera circle \ 53 | # $INPUT_PATH/$f 54 | # done 55 | 56 | -------------------------------------------------------------------------------- /scripts/build_dataset_incomplete.sh: -------------------------------------------------------------------------------- 1 | source config.sh 2 | 3 | # Make output directories 4 | mkdir -p $BUILD_PATH 5 | 6 | 7 | echo " Building Human Dataset for Occupancy Flow Project." 8 | echo " Input Path: $INPUT_PATH" 9 | echo " Build Path: $BUILD_PATH" 10 | 11 | # Worker count comes from NPROC in config.sh, as in build_dataset.sh 12 | echo "Sample points sdf and pointcloud ..." 13 | python compute_incomplete.py $INPUT_PATH \ 14 | --n_proc $NPROC --resize \ 15 | --partial_pointcloud_size 30000 --start 0.0 --end 1.0 --float16 --radius 0.1 16 | echo "done!"
17 | -------------------------------------------------------------------------------- /scripts/config.sh: -------------------------------------------------------------------------------- 1 | ROOT=/data2/lab-tang.jiapeng 2 | export HDF5_USE_FILE_LOCKING=FALSE # Workaround for NFS mounts 3 | 4 | INPUT_PATH=$ROOT/D-FAUST/scripts/registration_meshes 5 | 6 | BUILD_PATH=$ROOT/Humans.build 7 | OUTPUT_PATH=$ROOT/Humans/D-FAUST 8 | 9 | NPROC=6 10 | TIMEOUT=180 11 | #N_VIEWS=4 12 | 13 | # Utility functions. lsfilter FOLDER [OTHER_FOLDER] [EXT]: print each entry name in FOLDER for which OTHER_FOLDER/<name><EXT> is neither a file nor a directory 14 | lsfilter() { 15 | folder=$1 16 | other_folder=$2 17 | ext=$3 18 | 19 | for f in $folder/*; do 20 | filename=$(basename $f) 21 | if [ ! -f $other_folder/$filename$ext ] && [ ! -d $other_folder/$filename$ext ]; then 22 | echo $filename 23 | fi 24 | done 25 | } 26 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | cd data 3 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/occupancy_flow/data/Humans.zip 4 | unzip Humans.zip -------------------------------------------------------------------------------- /scripts/install_dataset.sh: -------------------------------------------------------------------------------- 1 | source config.sh 2 | 3 | # Function for processing a single model 4 | organize_model() { 5 | filename=$(basename -- $3) 6 | modelname="${filename%.*}" 7 | output_path="$2/$modelname" 8 | build_path=$1 9 | 10 | points_folder="$build_path/0_points/$modelname" 11 | points_out_folder="$output_path/points_seq" 12 | 13 | # img_folder="$build_path/0_render/$modelname" 14 | # img_out_folder="$output_path/img" 15 | 16 | pointcloud_folder="$build_path/0_pointcloud/$modelname" 17 | pointcloud_out_folder="$output_path/pcl_seq" 18 | 19 | mesh_folder="$build_path/0_meshes/$modelname" 20 | mesh_out_folder="$output_path/mesh_seq" 21 | 22 | # if [ -d $points_folder ] \ 23 | # && [ -d $pointcloud_folder ] \ 24 | #
&& [ -d $img_folder ] \ 25 | # && [ -d $mesh_folder ] \ 26 | if [ -d $points_folder ] \ 27 | && [ -d $pointcloud_folder ] \ 28 | && [ -d $mesh_folder ] \ 29 | ; then 30 | echo "Copying model $output_path" 31 | mkdir -p "$output_path" 32 | 33 | cp -rT $points_folder $points_out_folder 34 | cp -rT $pointcloud_folder $pointcloud_out_folder 35 | # cp -rT $img_folder $img_out_folder 36 | cp -rT $mesh_folder $mesh_out_folder 37 | fi 38 | } 39 | 40 | echo "Installing Humans Dataset for Occupancy Flow" 41 | echo "Output Directory: $OUTPUT_PATH" 42 | 43 | export -f organize_model 44 | 45 | # Make output directories 46 | mkdir -p $OUTPUT_PATH 47 | 48 | # Run install 49 | ls $INPUT_PATH | parallel -P $NPROC \ 50 | organize_model $BUILD_PATH $OUTPUT_PATH {} 51 | 52 | # Copy Split Files 53 | cp split_files/* $OUTPUT_PATH/ 54 | 55 | echo "done!" 56 | -------------------------------------------------------------------------------- /scripts/migrate_dfaust.sh: -------------------------------------------------------------------------------- 1 | # This script migrates D-FAUST mesh data and the provided point and point cloud data. 2 | # Usage: bash migrate_dfaust.sh IN_PATH (IN_PATH holds one mesh folder per model) 3 | IN_PATH=$1 4 | OUT_PATH=/data2/lab-tang.jiapeng/Humans/D-FAUST 5 | mesh_folder_name='mesh_seq' 6 | 7 | model_files=$(ls $IN_PATH) 8 | 9 | echo "Copying mesh data from $IN_PATH to dataset in $OUT_PATH ..." 10 | for m in ${model_files[@]} 11 | do 12 | echo "Processing model $m ..." 13 | model_folder_in=$IN_PATH/$m 14 | model_folder_out=$OUT_PATH/$m/$mesh_folder_name 15 | cp -R $model_folder_in $model_folder_out 16 | echo "done (model)!" 17 | done 18 | echo "done (dataset)!"
19 | -------------------------------------------------------------------------------- /scripts/split_files/overfit.lst: -------------------------------------------------------------------------------- 1 | 50026_one_leg_jump 2 | -------------------------------------------------------------------------------- /scripts/split_files/test.lst: -------------------------------------------------------------------------------- 1 | 50002_light_hopping_loose 2 | 50004_punching 3 | 50007_shake_shoulders 4 | 50009_chicken_wings 5 | 50020_chicken_wings 6 | 50022_light_hopping_loose 7 | 50025_light_hopping_loose 8 | 50026_shake_arms 9 | 50027_shake_shoulders 10 | -------------------------------------------------------------------------------- /scripts/split_files/test_new_individual.lst: -------------------------------------------------------------------------------- 1 | 50021_chicken_wings 2 | 50021_knees 3 | 50021_one_leg_jump 4 | 50021_punching 5 | 50021_shake_arms 6 | 50021_shake_shoulders 7 | 50021_hips 8 | 50021_light_hopping_stiff 9 | 50021_one_leg_loose 10 | 50021_running_on_spot 11 | 50021_shake_hips 12 | -------------------------------------------------------------------------------- /scripts/split_files/train.lst: -------------------------------------------------------------------------------- 1 | 50002_one_leg_loose 2 | 50025_shake_shoulders 3 | 50007_one_leg_jump 4 | 50002_knees 5 | 50002_light_hopping_stiff 6 | 50004_jiggle_on_toes 7 | 50004_shake_hips 8 | 50026_chicken_wings 9 | 50007_punching 10 | 50022_one_leg_jump 11 | 50009_light_hopping_stiff 12 | 50025_light_hopping_stiff 13 | 50007_shake_arms 14 | 50026_running_on_spot 15 | 50025_chicken_wings 16 | 50020_shake_hips 17 | 50026_shake_hips 18 | 50027_light_hopping_stiff 19 | 50009_shake_hips 20 | 50009_light_hopping_loose 21 | 50020_jiggle_on_toes 22 | 50025_one_leg_loose 23 | 50009_punching 24 | 50027_hips 25 | 50002_running_on_spot 26 | 50026_light_hopping_stiff 27 | 50026_jiggle_on_toes 28 | 
50020_shake_shoulders 29 | 50007_light_hopping_stiff 30 | 50007_jiggle_on_toes 31 | 50027_punching 32 | 50009_running_on_spot 33 | 50002_shake_arms 34 | 50022_knees 35 | 50007_jumping_jacks 36 | 50027_running_on_spot 37 | 50022_running_on_spot 38 | 50004_knees 39 | 50027_one_leg_loose 40 | 50009_one_leg_jump 41 | 50026_jumping_jacks 42 | 50009_one_leg_loose 43 | 50027_jiggle_on_toes 44 | 50020_knees 45 | 50027_light_hopping_loose 46 | 50026_knees 47 | 50004_jumping_jacks 48 | 50026_one_leg_jump 49 | 50004_hips 50 | 50027_shake_arms 51 | 50026_hips 52 | 50020_punching 53 | 50025_jiggle_on_toes 54 | 50022_punching 55 | 50004_light_hopping_loose 56 | 50009_jumping_jacks 57 | 50002_one_leg_jump 58 | 50007_knees 59 | 50027_jumping_jacks 60 | 50022_shake_arms 61 | 50002_jiggle_on_toes 62 | 50002_hips 63 | 50009_jiggle_on_toes 64 | 50022_jiggle_on_toes 65 | 50007_shake_hips 66 | 50022_hips 67 | 50026_punching 68 | 50026_one_leg_loose 69 | 50025_running_on_spot 70 | 50025_knees 71 | 50025_hips 72 | 50020_light_hopping_stiff 73 | 50026_shake_shoulders 74 | 50002_jumping_jacks 75 | 50022_light_hopping_stiff 76 | 50027_one_leg_jump 77 | 50002_punching 78 | 50022_shake_shoulders 79 | 50004_running_on_spot 80 | 50020_light_hopping_loose 81 | 50022_shake_hips 82 | 50004_one_leg_loose 83 | 50022_jumping_jacks 84 | 50002_chicken_wings 85 | 50022_one_leg_loose 86 | 50027_knees 87 | 50004_shake_arms 88 | 50007_chicken_wings 89 | 50002_shake_hips 90 | 50007_one_leg_loose 91 | 50004_shake_shoulders 92 | 50009_hips 93 | 50007_running_on_spot 94 | 50025_shake_hips 95 | 50002_shake_shoulders 96 | 50020_shake_arms 97 | 50027_shake_hips 98 | 50026_light_hopping_loose 99 | 50025_one_leg_jump 100 | 50020_one_leg_loose 101 | 50004_one_leg_jump 102 | 50025_punching 103 | 50020_one_leg_jump 104 | 50004_light_hopping_stiff 105 | -------------------------------------------------------------------------------- /scripts/split_files/train_generative.lst: 
-------------------------------------------------------------------------------- 1 | 50002_one_leg_loose 2 | 50025_shake_shoulders 3 | 50007_one_leg_jump 4 | 50002_knees 5 | 50002_light_hopping_stiff 6 | 50004_jiggle_on_toes 7 | 50004_shake_hips 8 | 50026_chicken_wings 9 | 50007_punching 10 | 50022_one_leg_jump 11 | 50009_light_hopping_stiff 12 | 50025_light_hopping_stiff 13 | 50007_shake_arms 14 | 50026_running_on_spot 15 | 50025_chicken_wings 16 | 50020_shake_hips 17 | 50026_shake_hips 18 | 50027_light_hopping_stiff 19 | 50009_shake_hips 20 | 50009_light_hopping_loose 21 | 50020_jiggle_on_toes 22 | 50025_one_leg_loose 23 | 50009_punching 24 | 50027_hips 25 | 50002_running_on_spot 26 | 50026_light_hopping_stiff 27 | 50026_jiggle_on_toes 28 | 50020_shake_shoulders 29 | 50007_light_hopping_stiff 30 | 50007_jiggle_on_toes 31 | 50027_punching 32 | 50009_running_on_spot 33 | 50002_shake_arms 34 | 50022_knees 35 | 50007_jumping_jacks 36 | 50027_running_on_spot 37 | 50022_running_on_spot 38 | 50004_knees 39 | 50027_one_leg_loose 40 | 50009_one_leg_jump 41 | 50026_jumping_jacks 42 | 50009_one_leg_loose 43 | 50027_jiggle_on_toes 44 | 50020_knees 45 | 50027_light_hopping_loose 46 | 50026_knees 47 | 50004_jumping_jacks 48 | 50026_one_leg_jump 49 | 50004_hips 50 | 50027_shake_arms 51 | 50026_hips 52 | 50020_punching 53 | 50025_jiggle_on_toes 54 | 50022_punching 55 | 50004_light_hopping_loose 56 | 50009_jumping_jacks 57 | 50002_one_leg_jump 58 | 50007_knees 59 | 50027_jumping_jacks 60 | 50022_shake_arms 61 | 50002_jiggle_on_toes 62 | 50002_hips 63 | 50009_jiggle_on_toes 64 | 50022_jiggle_on_toes 65 | 50007_shake_hips 66 | 50022_hips 67 | 50026_punching 68 | 50026_one_leg_loose 69 | 50025_running_on_spot 70 | 50025_knees 71 | 50025_hips 72 | 50020_light_hopping_stiff 73 | 50026_shake_shoulders 74 | 50002_jumping_jacks 75 | 50022_light_hopping_stiff 76 | 50027_one_leg_jump 77 | 50002_punching 78 | 50022_shake_shoulders 79 | 50004_running_on_spot 80 | 50020_light_hopping_loose 81 
| 50022_shake_hips 82 | 50004_one_leg_loose 83 | 50022_jumping_jacks 84 | 50002_chicken_wings 85 | 50022_one_leg_loose 86 | 50027_knees 87 | 50004_shake_arms 88 | 50007_chicken_wings 89 | 50002_shake_hips 90 | 50007_one_leg_loose 91 | 50004_shake_shoulders 92 | 50009_hips 93 | 50007_running_on_spot 94 | 50025_shake_hips 95 | 50002_shake_shoulders 96 | 50020_shake_arms 97 | 50027_shake_hips 98 | 50026_light_hopping_loose 99 | 50025_one_leg_jump 100 | 50020_one_leg_loose 101 | 50004_one_leg_jump 102 | 50025_punching 103 | 50020_one_leg_jump 104 | 50004_light_hopping_stiff 105 | 50002_light_hopping_loose 106 | 50004_punching 107 | 50007_shake_shoulders 108 | 50009_chicken_wings 109 | 50020_chicken_wings 110 | 50022_light_hopping_loose 111 | 50025_light_hopping_loose 112 | 50026_shake_arms 113 | 50027_shake_shoulders 114 | 50021_chicken_wings 115 | 50021_knees 116 | 50021_one_leg_jump 117 | 50021_punching 118 | 50021_shake_arms 119 | 50021_shake_shoulders 120 | 50021_hips 121 | 50021_light_hopping_stiff 122 | 50021_one_leg_loose 123 | 50021_running_on_spot 124 | 50021_shake_hips 125 | -------------------------------------------------------------------------------- /scripts/split_files/val.lst: -------------------------------------------------------------------------------- 1 | 50004_chicken_wings 2 | 50020_hips 3 | 50025_shake_arms 4 | 50020_running_on_spot 5 | 50007_light_hopping_loose 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup 3 | except ImportError: 4 | from distutils.core import setup 5 | from distutils.extension import Extension 6 | from Cython.Build import cythonize 7 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 8 | import numpy 9 | 10 | 11 | # Get the numpy include directory. 
12 | numpy_include_dir = numpy.get_include() 13 | 14 | # Extensions 15 | # pykdtree (kd tree) 16 | pykdtree = Extension( 17 | 'im2mesh.utils.libkdtree.pykdtree.kdtree', 18 | sources=[ 19 | 'im2mesh/utils/libkdtree/pykdtree/kdtree.c', 20 | 'im2mesh/utils/libkdtree/pykdtree/_kdtree_core.c' 21 | ], 22 | language='c', 23 | extra_compile_args=['-std=c99', '-O3', '-fopenmp'], 24 | extra_link_args=['-lgomp'], 25 | include_dirs=[numpy_include_dir] 26 | ) 27 | 28 | # mcubes (marching cubes algorithm) 29 | mcubes_module = Extension( 30 | 'im2mesh.utils.libmcubes.mcubes', 31 | sources=[ 32 | 'im2mesh/utils/libmcubes/mcubes.pyx', 33 | 'im2mesh/utils/libmcubes/pywrapper.cpp', 34 | 'im2mesh/utils/libmcubes/marchingcubes.cpp' 35 | ], 36 | language='c++', 37 | extra_compile_args=['-std=c++11'], 38 | include_dirs=[numpy_include_dir] 39 | ) 40 | 41 | # triangle hash (efficient mesh intersection) 42 | triangle_hash_module = Extension( 43 | 'im2mesh.utils.libmesh.triangle_hash', 44 | sources=[ 45 | 'im2mesh/utils/libmesh/triangle_hash.pyx' 46 | ], 47 | libraries=['m'], # Unix-like specific 48 | include_dirs=[numpy_include_dir] 49 | ) 50 | 51 | # mise (efficient mesh extraction) 52 | mise_module = Extension( 53 | 'im2mesh.utils.libmise.mise', 54 | sources=[ 55 | 'im2mesh/utils/libmise/mise.pyx' 56 | ], 57 | ) 58 | 59 | # simplify (efficient mesh simplification) 60 | simplify_mesh_module = Extension( 61 | 'im2mesh.utils.libsimplify.simplify_mesh', 62 | sources=[ 63 | 'im2mesh/utils/libsimplify/simplify_mesh.pyx' 64 | ], 65 | include_dirs=[numpy_include_dir] 66 | ) 67 | 68 | # voxelization (efficient mesh voxelization) 69 | voxelize_module = Extension( 70 | 'im2mesh.utils.libvoxelize.voxelize', 71 | sources=[ 72 | 'im2mesh/utils/libvoxelize/voxelize.pyx' 73 | ], 74 | libraries=['m'], # Unix-like specific 75 | include_dirs=[numpy_include_dir] 76 | ) 77 | ''' 78 | # DMC extensions 79 | dmc_pred2mesh_module = CppExtension( 80 | 'im2mesh.dmc.ops.cpp_modules.pred2mesh', 81 | sources=[ 82 | 
'im2mesh/dmc/ops/cpp_modules/pred_to_mesh_.cpp', 83 | ], 84 | include_dirs=[numpy_include_dir] 85 | ) 86 | 87 | dmc_cuda_module = CUDAExtension( 88 | 'im2mesh.dmc.ops._cuda_ext', 89 | sources=[ 90 | 'im2mesh/dmc/ops/src/extension.cpp', 91 | 'im2mesh/dmc/ops/src/curvature_constraint_kernel.cu', 92 | 'im2mesh/dmc/ops/src/grid_pooling_kernel.cu', 93 | 'im2mesh/dmc/ops/src/occupancy_to_topology_kernel.cu', 94 | 'im2mesh/dmc/ops/src/occupancy_connectivity_kernel.cu', 95 | 'im2mesh/dmc/ops/src/point_triangle_distance_kernel.cu', 96 | ] 97 | ) 98 | ''' 99 | 100 | # Gather all extension modules 101 | ext_modules = [ 102 | pykdtree, 103 | mcubes_module, 104 | triangle_hash_module, 105 | mise_module, 106 | simplify_mesh_module, 107 | voxelize_module, 108 | #dmc_pred2mesh_module, 109 | #dmc_cuda_module, 110 | ] 111 | 112 | setup( 113 | ext_modules=cythonize(ext_modules), 114 | cmdclass={ 115 | 'build_ext': BuildExtension 116 | } 117 | ) 118 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from im2mesh.checkpoints import CheckpointIO 2 | from im2mesh import config, data 3 | import torch 4 | import torch.optim as optim 5 | from tensorboardX import SummaryWriter 6 | import numpy as np 7 | import os 8 | import argparse 9 | import time 10 | 11 | 12 | # Arguments 13 | parser = argparse.ArgumentParser( 14 | description='Train a 4D model.' 
15 | ) 16 | parser.add_argument('config', type=str, help='Path to config file.') 17 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 18 | parser.add_argument('--exit-after', type=int, default=-1, 19 | help='Checkpoint and exit after specified number of ' 20 | 'seconds with exit code 2.') 21 | 22 | args = parser.parse_args() 23 | cfg = config.load_config(args.config, 'configs/default.yaml') 24 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 25 | device = torch.device("cuda" if is_cuda else "cpu") 26 | 27 | # Set t0 28 | t0 = time.time() 29 | 30 | # Shorthands 31 | out_dir = cfg['training']['out_dir'] 32 | batch_size = cfg['training']['batch_size'] 33 | batch_size_vis = cfg['training']['batch_size_vis'] 34 | batch_size_val = cfg['training']['batch_size_val'] 35 | backup_every = cfg['training']['backup_every'] 36 | exit_after = args.exit_after 37 | lr = cfg['training']['learning_rate'] 38 | 39 | model_selection_metric = cfg['training']['model_selection_metric'] 40 | if cfg['training']['model_selection_mode'] == 'maximize': 41 | model_selection_sign = 1 42 | elif cfg['training']['model_selection_mode'] == 'minimize': 43 | model_selection_sign = -1 44 | else: 45 | raise ValueError('model_selection_mode must be ' 46 | 'either maximize or minimize.') 47 | 48 | # Output directory 49 | if not os.path.exists(out_dir): 50 | os.makedirs(out_dir) 51 | 52 | # Dataset 53 | train_dataset = config.get_dataset('train', cfg) 54 | val_dataset = config.get_dataset('val', cfg) 55 | 56 | # Dataloader 57 | train_loader = torch.utils.data.DataLoader( 58 | train_dataset, batch_size=batch_size, num_workers=4, shuffle=True, 59 | collate_fn=data.collate_remove_none, 60 | worker_init_fn=data.worker_init_fn) 61 | val_loader = torch.utils.data.DataLoader( 62 | val_dataset, batch_size=batch_size_val, num_workers=1, shuffle=False, 63 | collate_fn=data.collate_remove_none, 64 | worker_init_fn=data.worker_init_fn) 65 | 66 | # For visualizations 67 | 
vis_loader = torch.utils.data.DataLoader( 68 | val_dataset, batch_size=batch_size_vis, shuffle=True, 69 | collate_fn=data.collate_remove_none, 70 | worker_init_fn=data.worker_init_fn) 71 | data_vis = next(iter(vis_loader)) 72 | 73 | # Model 74 | model = config.get_model(cfg, device=device, dataset=train_dataset) 75 | 76 | # Get optimizer and trainer 77 | optimizer = optim.Adam(model.parameters(), lr=lr) 78 | trainer = config.get_trainer(model, optimizer, cfg, device=device) 79 | 80 | # Load pre-trained model is existing 81 | kwargs = { 82 | 'model': model, 83 | 'optimizer': optimizer, 84 | } 85 | checkpoint_io = CheckpointIO( 86 | out_dir, initialize_from=cfg['model']['initialize_from'], 87 | initialization_file_name=cfg['model']['initialization_file_name'], 88 | **kwargs) 89 | try: 90 | load_dict = checkpoint_io.load('model.pt') 91 | except FileExistsError: 92 | load_dict = dict() 93 | epoch_it = load_dict.get('epoch_it', -1) 94 | it = load_dict.get('it', -1) 95 | metric_val_best = load_dict.get( 96 | 'loss_val_best', -model_selection_sign * np.inf) 97 | 98 | if metric_val_best == np.inf or metric_val_best == -np.inf: 99 | metric_val_best = -model_selection_sign * np.inf 100 | 101 | print('Current best validation metric (%s): %.8f' 102 | % (model_selection_metric, metric_val_best)) 103 | 104 | logger = SummaryWriter(os.path.join(out_dir, 'logs')) 105 | 106 | # Shorthands 107 | print_every = cfg['training']['print_every'] 108 | checkpoint_every = cfg['training']['checkpoint_every'] 109 | validate_every = cfg['training']['validate_every'] 110 | visualize_every = cfg['training']['visualize_every'] 111 | 112 | # Print model 113 | nparameters = sum(p.numel() for p in model.parameters()) 114 | print(model) 115 | print('Total number of parameters: %d' % nparameters) 116 | 117 | # Training loop 118 | while True: 119 | epoch_it += 1 120 | 121 | for batch in train_loader: 122 | it += 1 123 | loss = trainer.train_step(batch) 124 | logger.add_scalar('train/loss', loss, it) 
125 | 126 | # Print output 127 | if print_every > 0 and (it % print_every) == 0: 128 | print('[Epoch %02d] it=%03d, loss=%.4f' 129 | % (epoch_it, it, loss)) 130 | 131 | # Visualize output 132 | if visualize_every > 0 and (it % visualize_every) == 0: 133 | print('Visualizing') 134 | trainer.visualize(data_vis) 135 | 136 | # Save checkpoint 137 | if (checkpoint_every > 0 and (it % checkpoint_every) == 0): 138 | print('Saving checkpoint') 139 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 140 | loss_val_best=metric_val_best) 141 | 142 | # Backup if necessary 143 | if (backup_every > 0 and (it % backup_every) == 0): 144 | print('Backup checkpoint') 145 | checkpoint_io.save('model_%d.pt' % it, epoch_it=epoch_it, it=it, 146 | loss_val_best=metric_val_best) 147 | # Run validation 148 | if validate_every > 0 and (it % validate_every) == 0: 149 | eval_dict = trainer.evaluate(val_loader) 150 | metric_val = eval_dict[model_selection_metric] 151 | print('Validation metric (%s): %.4f' 152 | % (model_selection_metric, metric_val)) 153 | 154 | for k, v in eval_dict.items(): 155 | logger.add_scalar('val/%s' % k, v, it) 156 | 157 | if model_selection_sign * (metric_val - metric_val_best) > 0: 158 | metric_val_best = metric_val 159 | print('New best model (loss %.4f)' % metric_val_best) 160 | checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it, 161 | loss_val_best=metric_val_best) 162 | 163 | # Exit if necessary 164 | if exit_after > 0 and (time.time() - t0) >= exit_after: 165 | print('Time limit reached. Exiting.') 166 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 167 | loss_val_best=metric_val_best) 168 | exit(3) 169 | --------------------------------------------------------------------------------