├── LICENSE
├── README.md
├── animations
│   ├── input_full.gif
│   ├── output_full.gif
│   └── teaser.png
├── configs
│   ├── default.yaml
│   ├── demo.yaml
│   └── noflow
│       ├── lpdc_completion.yaml
│       ├── lpdc_completion_pretrained.yaml
│       ├── lpdc_even.yaml
│       ├── lpdc_even_pretrained.yaml
│       ├── lpdc_uneven.yaml
│       └── lpdc_uneven_pretrained.yaml
├── data
│   └── Humans
│       ├── D-FAUST
│       │   ├── overfit.lst
│       │   ├── test.lst
│       │   ├── test_new_individual.lst
│       │   ├── train.lst
│       │   └── val.lst
│       └── Demo
│           ├── 50002_light_hopping_loose
│           │   └── pcl_seq
│           │       ├── 00000080.npz
│           │       ├── 00000081.npz
│           │       ├── 00000082.npz
│           │       ├── 00000083.npz
│           │       ├── 00000084.npz
│           │       ├── 00000085.npz
│           │       ├── 00000086.npz
│           │       ├── 00000087.npz
│           │       ├── 00000088.npz
│           │       ├── 00000089.npz
│           │       ├── 00000090.npz
│           │       ├── 00000091.npz
│           │       ├── 00000092.npz
│           │       ├── 00000093.npz
│           │       ├── 00000094.npz
│           │       ├── 00000095.npz
│           │       └── 00000096.npz
│           └── 50004_punching
│               └── pcl_seq
│                   ├── 00000166.npz
│                   ├── 00000167.npz
│                   ├── 00000168.npz
│                   ├── 00000169.npz
│                   ├── 00000170.npz
│                   ├── 00000171.npz
│                   ├── 00000172.npz
│                   ├── 00000173.npz
│                   ├── 00000174.npz
│                   ├── 00000175.npz
│                   ├── 00000176.npz
│                   ├── 00000177.npz
│                   ├── 00000178.npz
│                   ├── 00000179.npz
│                   ├── 00000180.npz
│                   ├── 00000181.npz
│                   └── 00000182.npz
├── docs
│   └── index.html
├── environment.yaml
├── eval.py
├── generate.py
├── im2mesh
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── checkpoints.py
│   ├── checkpoints.pyc
│   ├── common.py
│   ├── config.py
│   ├── config.pyc
│   ├── data
│   │   ├── __init__.py
│   │   ├── __init__.pyc
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── core.cpython-36.pyc
│   │   │   ├── fields.cpython-36.pyc
│   │   │   ├── subseq_dataset.cpython-36.pyc
│   │   │   └── transforms.cpython-36.pyc
│   │   ├── core.py
│   │   ├── core.pyc
│   │   ├── fields.py
│   │   ├── fields.pyc
│   │   ├── subseq_dataset.py
│   │   ├── subseq_dataset.pyc
│   │   ├── transforms.py
│   │   └── transforms.pyc
│   ├── encoder
│   │   ├── __init__.py
│   │   ├── __init__.pyc
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── conv.cpython-36.pyc
│   │   │   ├── pointnet.cpython-36.pyc
│   │   │   ├── pointnet_unet.cpython-36.pyc
│   │   │   ├── unet.cpython-36.pyc
│   │   │   └── unet3d.cpython-36.pyc
│   │   ├── conv.pyc
│   │   └── pointnet.py
│   ├── eval.py
│   ├── layers.py
│   ├── lpdc
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── config.cpython-36.pyc
│   │   │   ├── generation.cpython-36.pyc
│   │   │   └── training.cpython-36.pyc
│   │   ├── config.py
│   │   ├── generation.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── decoder.cpython-36.pyc
│   │   │   │   ├── decoder_unet.cpython-36.pyc
│   │   │   │   └── displacement.cpython-36.pyc
│   │   │   ├── decoder.py
│   │   │   └── displacement.py
│   │   └── training.py
│   ├── lpdc_uneven
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── config.cpython-36.pyc
│   │   │   ├── generation.cpython-36.pyc
│   │   │   └── training.cpython-36.pyc
│   │   ├── config.py
│   │   ├── generation.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── decoder.cpython-36.pyc
│   │   │   │   ├── decoder_unet.cpython-36.pyc
│   │   │   │   └── displacement.cpython-36.pyc
│   │   │   ├── decoder.py
│   │   │   └── displacement.py
│   │   └── training.py
│   ├── training.py
│   └── utils
│       ├── __init__.py
│       ├── __init__.pyc
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── grad.cpython-36.pyc
│       │   ├── io.cpython-36.pyc
│       │   ├── onet_generator.cpython-36.pyc
│       │   └── visualize.cpython-36.pyc
│       ├── binvox_rw.py
│       ├── grad.py
│       ├── icp.py
│       ├── io.py
│       ├── libkdtree
│       │   ├── .gitignore
│       │   ├── LICENSE.txt
│       │   ├── MANIFEST.in
│       │   ├── README
│       │   ├── README.rst
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-36.pyc
│       │   ├── pykdtree
│       │   │   ├── __init__.py
│       │   │   ├── __pycache__
│       │   │   │   └── __init__.cpython-36.pyc
│       │   │   ├── _kdtree_core.c
│       │   │   ├── _kdtree_core.c.mako
│       │   │   ├── kdtree.c
│       │   │   ├── kdtree.cpython-36m-x86_64-linux-gnu.so
│       │   │   ├── kdtree.pyx
│       │   │   ├── render_template.py
│       │   │   └── test_tree.py
│       │   ├── setup.cfg
│       │   └── setup.py
│       ├── libmcubes
│       │   ├── .gitignore
│       │   ├── LICENSE
│       │   ├── README.rst
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   └── exporter.cpython-36.pyc
│       │   ├── exporter.py
│       │   ├── marchingcubes.cpp
│       │   ├── marchingcubes.h
│       │   ├── mcubes.cpp
│       │   ├── mcubes.cpython-36m-x86_64-linux-gnu.so
│       │   ├── mcubes.pyx
│       │   ├── pyarray_symbol.h
│       │   ├── pyarraymodule.h
│       │   ├── pywrapper.cpp
│       │   ├── pywrapper.h
│       │   └── setup.py
│       ├── libmesh
│       │   ├── .gitignore
│       │   ├── __init__.py
│       │   ├── __init__.pyc
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   └── inside_mesh.cpython-36.pyc
│       │   ├── inside_mesh.py
│       │   ├── inside_mesh.pyc
│       │   ├── setup.py
│       │   ├── triangle_hash.cpython-36m-x86_64-linux-gnu.so
│       │   └── triangle_hash.pyx
│       ├── libmise
│       │   ├── .gitignore
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-36.pyc
│       │   ├── mise.cpython-36m-x86_64-linux-gnu.so
│       │   ├── mise.pyx
│       │   ├── setup.py
│       │   └── test.py
│       ├── libsimplify
│       │   ├── Simplify.h
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-36.pyc
│       │   ├── setup.py
│       │   ├── simplify_mesh.cpp
│       │   ├── simplify_mesh.cpython-36m-x86_64-linux-gnu.so
│       │   ├── simplify_mesh.pyx
│       │   └── test.py
│       ├── libvoxelize
│       │   ├── .gitignore
│       │   ├── __init__.py
│       │   ├── setup.py
│       │   ├── tribox2.h
│       │   ├── voxelize.cpython-36m-x86_64-linux-gnu.so
│       │   └── voxelize.pyx
│       ├── mesh.py
│       ├── onet_generator.py
│       ├── torchdiffeq
│       │   ├── .gitignore
│       │   ├── LICENSE
│       │   ├── README.md
│       │   ├── assets
│       │   │   ├── ode_demo.gif
│       │   │   ├── odenet_0_viz.png
│       │   │   └── resnet_0_viz.png
│       │   ├── examples
│       │   │   ├── README.md
│       │   │   ├── latent_ode.py
│       │   │   ├── ode_demo.py
│       │   │   └── odenet_mnist.py
│       │   ├── setup.py
│       │   ├── tests
│       │   │   ├── api_tests.py
│       │   │   ├── gradient_tests.py
│       │   │   ├── odeint_tests.py
│       │   │   ├── problems.py
│       │   │   └── run_all.py
│       │   └── torchdiffeq
│       │       ├── __init__.py
│       │       └── _impl
│       │           ├── __init__.py
│       │           ├── adams.py
│       │           ├── adjoint.py
│       │           ├── dopri5.py
│       │           ├── fixed_adams.py
│       │           ├── fixed_grid.py
│       │           ├── interp.py
│       │           ├── misc.py
│       │           ├── odeint.py
│       │           ├── rk_common.py
│       │           ├── solvers.py
│       │           └── tsit5.py
│       ├── visualize.py
│       └── voxels.py
├── scripts
│   ├── build_dataset.sh
│   ├── build_dataset_incomplete.sh
│   ├── compute_incomplete.py
│   ├── config.sh
│   ├── download_data.sh
│   ├── install_dataset.sh
│   ├── migrate_dfaust.sh
│   ├── sample_mesh.py
│   └── split_files
│       ├── overfit.lst
│       ├── test.lst
│       ├── test_new_individual.lst
│       ├── train.lst
│       ├── train_generative.lst
│       └── val.lst
├── setup.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jiapeng Tang, Dan Xu, Kui Jia, Lei Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LPDC-Net
2 | [Homepage](https://tangjiapeng.github.io/LPDC-Net) | [Paper-Pdf](https://arxiv.org/pdf/2103.16341.pdf) | [Video](https://youtu.be/dhmuuzfRpNs)
3 |
4 | This repository contains the code for the CVPR 2021 paper **LPDC-Net: Learning Parallel Dense Correspondence from Spatio-Temporal Descriptors
5 | for Efficient and Robust 4D Reconstruction**.
6 |
7 | Below you can find detailed usage instructions for training your own models and for using the [pretrained models](https://drive.google.com/drive/folders/1jPrkxd9GYKtSsQt_q4poIYMuxeXlYRZ9?usp=sharing).
8 |
9 | ## Installation
10 | First you have to make sure that you have all dependencies in place. You can create and activate an anaconda environment called `lpdc` using
11 |
12 | ```
13 | conda env create -f environment.yaml
14 | conda activate lpdc
15 | ```
16 | Next, compile the extension modules. You can do this via
17 | ```
18 | python setup.py build_ext --inplace
19 | ```
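
If you want to verify that the compiled extensions are importable, a quick optional check can be run from the repository root (the module paths follow this repository's layout):

```
# Optional sanity check that the Cython extensions built correctly.
import im2mesh.utils.libmcubes   # marching cubes
import im2mesh.utils.libmise     # multiresolution iso-surface extraction
import im2mesh.utils.libkdtree   # kd-tree nearest-neighbour queries
print('Extension modules imported successfully.')
```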
20 |
21 | ## Demo
22 |
23 | You can test our code on the provided input point cloud sequences in the `data/Humans/Demo/` folder. To this end, simply run
24 | ```
25 | python generate.py configs/demo.yaml
26 | ```
27 | This script should create a folder `out/demo/` where the output is stored.
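
To inspect the generated meshes programmatically, a minimal sketch like the following can be used (the exact file layout and mesh format under `out/demo/` are assumptions; adjust the pattern to what the script actually writes):

```
import glob
import trimesh

# Load every generated mesh and print its size (the .off extension is an assumption).
for path in sorted(glob.glob('out/demo/**/*.off', recursive=True)):
    mesh = trimesh.load(path)
    print(path, 'vertices:', len(mesh.vertices), 'faces:', len(mesh.faces))
```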
28 |
29 | ## Dataset
30 |
31 | ### Point-based Data
32 | To train a new model from scratch, you have to download the full dataset.
33 | You can download the pre-processed data (~42 GB) using
34 |
35 | ```
36 | bash scripts/download_data.sh
37 | ```
38 |
39 | The script will download the point-based data for the [Dynamic FAUST (D-FAUST)](http://dfaust.is.tue.mpg.de/) dataset to the `data/` folder.
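
Each frame of a point cloud sequence is stored as an `.npz` archive (see `data/Humans/Demo/` for examples). To see what a frame contains, you can simply list its arrays; the key names are not documented here, so we print whatever is stored:

```
import numpy as np

# Inspect one frame of the demo point cloud sequence.
frame = np.load('data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000080.npz')
for key in frame.files:
    print(key, frame[key].shape, frame[key].dtype)
```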
40 |
41 | ### Mesh Data
42 |
43 | Please follow the instructions on the [D-FAUST homepage](http://dfaust.is.tue.mpg.de/) to download the "female and male registrations" as well as "scripts to load / parse the data".
44 | Next, follow their instructions in the `scripts/README.txt` file to extract the obj-files of the sequences. Once completed, you should have a folder with the following structure:
45 | ```
46 | your_dfaust_folder/
47 | | 50002_chicken_wings/
48 | |   00000.obj
49 | |   00001.obj
50 | |   ...
51 | |   00215.obj
52 | | 50002_hips/
53 | |   00000.obj
54 | |   ...
55 | | ...
56 | | 50027_shake_shoulders/
57 | |   00000.obj
58 | |   ...
59 | ```
60 | You can now run
61 | ```
62 | bash scripts/migrate_dfaust.sh path/to/your_dfaust_folder
63 | ```
64 | to copy the mesh data to the dataset folder.
65 | The argument has to be the folder to which you have extracted the mesh data (the `your_dfaust_folder` from the directory tree above).
66 |
67 | ### Incomplete Point Cloud Sequence
68 |
69 | You can now run
70 | ```
71 | bash scripts/build_dataset_incomplete.sh
72 | ```
73 | to create the incomplete point cloud sequences used in the 4D shape completion experiment.
74 |
75 | ## Usage
76 |
77 | When you have installed all dependencies and obtained the preprocessed data, you are ready to run our pre-trained models and train new models from scratch.
78 |
79 | ### Generation
80 |
81 | ![input](animations/input_full.gif)
82 |
83 | ![output](animations/output_full.gif)
84 |
85 |
86 | To start the mesh generation process with a trained model, run
87 |
88 | ```
89 | python generate.py configs/CONFIG.yaml
90 | ```
91 | where you replace `CONFIG.yaml` with the name of the configuration file you want to use.
92 |
93 | The easiest way is to use a pretrained model. You can do this by using one of the following config files:
94 |
95 | ```
96 | configs/noflow/lpdc_even_pretrained.yaml
97 | configs/noflow/lpdc_uneven_pretrained.yaml
98 | configs/noflow/lpdc_completion_pretrained.yaml
99 | ```
100 |
101 | Our script will automatically download the model checkpoints and run the generation.
102 | You can find the outputs in the `out/pointcloud` folder.
103 |
104 | Please note that the `*_pretrained.yaml` config files are intended for generation only, not for training new models: if you use them for training, the model
105 | will be trained from scratch, while during inference the code will still load the pretrained checkpoint.
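
Config files can also inherit from one another via the `inherit_from` key (see `configs/demo.yaml`). The actual resolution logic lives in `im2mesh/config.py`; as a rough sketch (not the repository's exact implementation), such inheritance is typically resolved like this:

```
import yaml

def deep_update(dst, src):
    # Recursively overwrite entries of dst with entries of src.
    for k, v in src.items():
        if isinstance(v, dict) and isinstance(dst.get(k), dict):
            deep_update(dst[k], v)
        else:
            dst[k] = v

def load_config(path):
    # Minimal sketch of `inherit_from` resolution.
    with open(path) as f:
        cfg = yaml.safe_load(f)
    parent = cfg.pop('inherit_from', None)
    if parent is not None:
        base = load_config(parent)  # load parent defaults first...
        deep_update(base, cfg)      # ...then apply the child's overrides
        return base
    return cfg
```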
106 |
107 | ### Evaluation
108 |
109 | You can evaluate the generated output of a model on the test set using
110 |
111 | ```
112 | python eval.py configs/CONFIG.yaml
113 | ```
114 | The evaluation results will be saved to pickle and csv files.
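
For a quick summary of the metrics you can load the csv with pandas (the file name below is a placeholder; check the model's output directory for the actual path):

```
import pandas as pd

df = pd.read_csv('out/pointcloud/lpdc_even/eval_full.csv')  # placeholder path
print(df.mean(numeric_only=True))  # average metrics over the test set
```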
115 |
116 | ### Training
117 |
118 | Finally, to train a new network from scratch, run
119 | ```
120 | python train.py configs/CONFIG.yaml
121 | ```
122 | You can monitor the training process on http://localhost:6006 using TensorBoard:
123 | ```
124 | cd OUTPUT_DIR
125 | tensorboard --logdir ./logs --port 6006
126 | ```
127 | where you replace `OUTPUT_DIR` with the respective output directory. For available training options, please have a look at `configs/default.yaml`.
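
For example, you can list all available sections and training options directly from the defaults:

```
import yaml

with open('configs/default.yaml') as f:
    defaults = yaml.safe_load(f)
print(sorted(defaults))              # top-level sections
print(sorted(defaults['training']))  # all available training options
```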
128 |
129 |
130 | ## Acknowledgements
131 |
132 | Most of the code is borrowed from [Occupancy Flow](https://github.com/autonomousvision/occupancy_flow). We thank Michael Niemeyer for his great work and repositories.
133 |
134 | ## Citation
135 |
136 | If you find our code or paper useful, please consider citing
137 |
138 | ```
139 | @inproceedings{tang2021learning,
140 |   title={Learning Parallel Dense Correspondence from Spatio-Temporal Descriptors for Efficient and Robust 4D Reconstruction},
141 |   author={Tang, Jiapeng and Xu, Dan and Jia, Kui and Zhang, Lei},
142 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
143 |   pages={6022--6031},
144 |   year={2021}
145 | }
146 | ```
145 |
--------------------------------------------------------------------------------
/animations/input_full.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/animations/input_full.gif
--------------------------------------------------------------------------------
/animations/output_full.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/animations/output_full.gif
--------------------------------------------------------------------------------
/animations/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/animations/teaser.png
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | method: lpdc
2 | data:
3 | path: ./data/Humans
4 | dataset: Humans
5 | input_type: pcl_seq
6 | classes: ['D-FAUST']
7 | train_split: train
8 | val_split: val
9 | test_split: test
10 | dim: 3
11 | n_training_points: 512
12 | points_unpackbits: true
13 | n_training_pcl_points: 100
14 | input_pointcloud_n: 300
15 | input_pointcloud_noise: 0.001
16 | input_pointcloud_corresponding: true
17 | n_views: 24
18 | img_size: 224
19 | img_with_camera: false
20 | img_augment: false
21 | length_sequence: 17
22 | select_steps: null #create uneven input
23 | offset_sequence: 15
24 | n_files_per_sequence: -1
25 | n_intervals: 1
26 | points_file: points.npz
27 | mesh_seq_folder: mesh_seq
28 | points_iou_seq_folder: points_seq
29 | pointcloud_seq_folder: pcl_seq
30 | img_seq_folder: img
31 | completion: false # whether to train/test on incomplete point clouds
32 | pointcloud_seq_incomplete_folder: pcl_incomp_seq # incomplete point cloud folder
33 | model:
34 | encoder: pointnet_resnet
35 | encoder_temporal: pointnet_resnet
36 | decoder: cbatchnorm
37 | velocity_field: concat
38 | encoder_latent: null
39 | encoder_latent_temporal: null
40 | decoder_kwargs: {}
41 | encoder_kwargs: {}
42 | encoder_latent_kwargs: {}
43 | encoder_temporal_kwargs: {}
44 | velocity_field_kwargs: {}
45 | encoder_latent_temporal_kwargs: {}
46 | learn_embedding: false
47 | ode_solver: dopri5
48 | ode_step_size: null
49 | use_adjoint: true
50 | rtol: 0.001
51 | atol: 0.00001
52 | vae_beta: 0.0001
53 | loss_corr: false
54 | loss_corr_bw: false
55 | loss_recon: true
56 | loss_transform_forward: false
57 | initialize_from: null
58 | initialization_file_name: model_best.pt
59 | c_dim: 512
60 | z_dim: 0
61 | use_camera: false
62 | training:
63 | out_dir: out/00
64 | model_selection_metric: iou
65 | model_selection_mode: maximize
66 | n_eval_points: 5000
67 | batch_size: 16
68 | batch_size_vis: 1
69 | batch_size_val: 1
70 | print_every: 5
71 | visualize_every: 999999999
72 | checkpoint_every: 200
73 | validate_every: 2000
74 | backup_every: 100000
75 | eval_sample: true
76 | learning_rate: 0.0001
77 | test:
78 | threshold: 0.3
79 | eval_mesh: true
80 | eval_pointcloud: false
81 | project_to_final_mesh: false
82 | eval_mesh_correspondences: true
83 | eval_mesh_iou: true
84 | eval_pointcloud_correspondences: true
85 | eval_only_end_time_steps: false
86 | model_file: model_best.pt
87 | generation:
88 | generate_pointcloud: false
89 | generate_mesh: true
90 | resolution_0: 32
91 | upsampling_steps: 2
92 | refinement_step: 0
93 | simplify_nfaces: null
94 | padding: 0.1
95 | vis_n_outputs: 20
96 | n_time_steps: 17
97 | mesh_color: true
98 | interpolate: false
99 | only_end_time_points: false
100 | fix_z: false
101 | fix_zt: false
102 | shuffle_generation: true
103 | rand_seed: 12345
104 | batch_size: 1000000
105 | generation_dir: generation
106 | use_sampling: false
107 | copy_input: false
108 |
--------------------------------------------------------------------------------
/configs/demo.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/noflow/lpdc_even_pretrained.yaml
2 | data:
3 | classes: ['Demo']
4 | test_split: null
5 | offset_sequence: 0
6 | training:
7 | out_dir: out/demo
8 | test:
9 | model_file: lpdc_even_pretrained.pt
--------------------------------------------------------------------------------
/configs/noflow/lpdc_completion.yaml:
--------------------------------------------------------------------------------
1 | method: lpdc_uneven
2 | data:
3 | test_split: test #test_new_individual
4 | n_intervals: 1
5 | select_steps: [0,10,19,29,44,49]
6 | length_sequence: 50
7 | pointcloud_seq_incomplete_folder: incompcl_seq
8 | completion: true
9 | model:
10 | encoder:
11 | encoder_kwargs:
12 | hidden_dim: 128
13 | encoder_temporal: pointnet_spatiotemporal2
14 | encoder_temporal_kwargs:
15 | hidden_dim: 128
16 | decoder: simple_local
17 | decoder_kwargs:
18 | hidden_size: 128
19 | velocity_field: concat
20 | velocity_field_kwargs:
21 | hidden_size: 128
22 | c_dim: 128
23 | loss_corr: true
24 | loss_transform_forward: true
25 | training:
26 | model_selection_metric: iou
27 | model_selection_mode: maximize
28 | batch_size: 16
29 | validate_every: 10000
30 | backup_every: 2000
31 | learning_rate: 0.0001
32 | out_dir: out/pointcloud/lpdc_completion
--------------------------------------------------------------------------------
/configs/noflow/lpdc_completion_pretrained.yaml:
--------------------------------------------------------------------------------
1 | method: lpdc_uneven
2 | data:
3 | test_split: test #test_new_individual
4 | n_intervals: 1
5 | select_steps: [0,10,19,29,44,49]
6 | length_sequence: 50
7 | pointcloud_seq_incomplete_folder: incompcl_seq
8 | completion: true
9 | model:
10 | encoder:
11 | encoder_kwargs:
12 | hidden_dim: 128
13 | encoder_temporal: pointnet_spatiotemporal2
14 | encoder_temporal_kwargs:
15 | hidden_dim: 128
16 | decoder: simple_local
17 | decoder_kwargs:
18 | hidden_size: 128
19 | velocity_field: concat
20 | velocity_field_kwargs:
21 | hidden_size: 128
22 | c_dim: 128
23 | loss_corr: true
24 | loss_transform_forward: true
25 | training:
26 | out_dir: out/pointcloud/lpdc_completion_pretrained
27 | test:
28 | model_file: lpdc_completion_pretrained.pt
--------------------------------------------------------------------------------
/configs/noflow/lpdc_even.yaml:
--------------------------------------------------------------------------------
1 | method: lpdc
2 | data:
3 | test_split: test #test_new_individual
4 | n_intervals: 1
5 | model:
6 | encoder:
7 | encoder_kwargs:
8 | hidden_dim: 128
9 | encoder_temporal: pointnet_spatiotemporal
10 | encoder_temporal_kwargs:
11 | hidden_dim: 128
12 | decoder: simple_local
13 | decoder_kwargs:
14 | hidden_size: 128
15 | velocity_field: concat
16 | velocity_field_kwargs:
17 | hidden_size: 128
18 | c_dim: 128
19 | loss_corr: true
20 | loss_transform_forward: true
21 | training:
22 | model_selection_metric: iou
23 | model_selection_mode: maximize
24 | batch_size: 16
25 | validate_every: 10000
26 | backup_every: 2000
27 | learning_rate: 0.0001
28 | out_dir: out/pointcloud/lpdc_even
29 |
--------------------------------------------------------------------------------
/configs/noflow/lpdc_even_pretrained.yaml:
--------------------------------------------------------------------------------
1 | method: lpdc
2 | data:
3 | test_split: test #test_new_individual
4 | n_intervals: 1
5 | model:
6 | encoder:
7 | encoder_kwargs:
8 | hidden_dim: 128
9 | encoder_temporal: pointnet_spatiotemporal
10 | encoder_temporal_kwargs:
11 | hidden_dim: 128
12 | decoder: simple_local
13 | decoder_kwargs:
14 | hidden_size: 128
15 | velocity_field: concat
16 | velocity_field_kwargs:
17 | hidden_size: 128
18 | c_dim: 128
19 | loss_corr: true
20 | loss_transform_forward: true
21 | training:
22 | out_dir: out/pointcloud/lpdc_even_pretrained
23 | test:
24 | model_file: lpdc_even_pretrained.pt
25 |
--------------------------------------------------------------------------------
/configs/noflow/lpdc_uneven.yaml:
--------------------------------------------------------------------------------
1 | method: lpdc_uneven
2 | data:
3 | test_split: test #test_new_individual
4 | n_intervals: 1
5 | select_steps: [0,10,19,29,44,49]
6 | length_sequence: 50
7 | model:
8 | encoder:
9 | encoder_kwargs:
10 | hidden_dim: 128
11 | encoder_temporal: pointnet_spatiotemporal
12 | encoder_temporal_kwargs:
13 | hidden_dim: 128
14 | decoder: simple_local
15 | decoder_kwargs:
16 | hidden_size: 128
17 | velocity_field: concat
18 | velocity_field_kwargs:
19 | hidden_size: 128
20 | c_dim: 128
21 | loss_corr: true
22 | loss_transform_forward: true
23 | training:
24 | model_selection_metric: iou
25 | model_selection_mode: maximize
26 | batch_size: 16
27 | validate_every: 10000
28 | backup_every: 2000
29 | learning_rate: 0.0001
30 | out_dir: out/pointcloud/lpdc_uneven
--------------------------------------------------------------------------------
/configs/noflow/lpdc_uneven_pretrained.yaml:
--------------------------------------------------------------------------------
1 | method: lpdc_uneven
2 | data:
3 | test_split: test #test_new_individual
4 | n_intervals: 1
5 | select_steps: [0,10,19,29,44,49]
6 | length_sequence: 50
7 | model:
8 | encoder:
9 | encoder_kwargs:
10 | hidden_dim: 128
11 | encoder_temporal: pointnet_spatiotemporal
12 | encoder_temporal_kwargs:
13 | hidden_dim: 128
14 | decoder: simple_local
15 | decoder_kwargs:
16 | hidden_size: 128
17 | velocity_field: concat
18 | velocity_field_kwargs:
19 | hidden_size: 128
20 | c_dim: 128
21 | loss_corr: true
22 | loss_transform_forward: true
23 | training:
24 | out_dir: out/pointcloud/lpdc_uneven_pretrained
25 | test:
26 | model_file: lpdc_uneven_pretrained.pt
--------------------------------------------------------------------------------
/data/Humans/D-FAUST/overfit.lst:
--------------------------------------------------------------------------------
1 | 50026_one_leg_jump
2 |
--------------------------------------------------------------------------------
/data/Humans/D-FAUST/test.lst:
--------------------------------------------------------------------------------
1 | 50002_light_hopping_loose
2 | 50004_punching
3 | 50007_shake_shoulders
4 | 50009_chicken_wings
5 | 50020_chicken_wings
6 | 50022_light_hopping_loose
7 | 50025_light_hopping_loose
8 | 50026_shake_arms
9 | 50027_shake_shoulders
10 |
--------------------------------------------------------------------------------
/data/Humans/D-FAUST/test_new_individual.lst:
--------------------------------------------------------------------------------
1 | 50021_chicken_wings
2 | 50021_knees
3 | 50021_one_leg_jump
4 | 50021_punching
5 | 50021_shake_arms
6 | 50021_shake_shoulders
7 | 50021_hips
8 | 50021_light_hopping_stiff
9 | 50021_one_leg_loose
10 | 50021_running_on_spot
11 | 50021_shake_hips
12 |
--------------------------------------------------------------------------------
/data/Humans/D-FAUST/train.lst:
--------------------------------------------------------------------------------
1 | 50002_one_leg_loose
2 | 50025_shake_shoulders
3 | 50007_one_leg_jump
4 | 50002_knees
5 | 50002_light_hopping_stiff
6 | 50004_jiggle_on_toes
7 | 50004_shake_hips
8 | 50026_chicken_wings
9 | 50007_punching
10 | 50022_one_leg_jump
11 | 50009_light_hopping_stiff
12 | 50025_light_hopping_stiff
13 | 50007_shake_arms
14 | 50026_running_on_spot
15 | 50025_chicken_wings
16 | 50020_shake_hips
17 | 50026_shake_hips
18 | 50027_light_hopping_stiff
19 | 50009_shake_hips
20 | 50009_light_hopping_loose
21 | 50020_jiggle_on_toes
22 | 50025_one_leg_loose
23 | 50009_punching
24 | 50027_hips
25 | 50002_running_on_spot
26 | 50026_light_hopping_stiff
27 | 50026_jiggle_on_toes
28 | 50020_shake_shoulders
29 | 50007_light_hopping_stiff
30 | 50007_jiggle_on_toes
31 | 50027_punching
32 | 50009_running_on_spot
33 | 50002_shake_arms
34 | 50022_knees
35 | 50007_jumping_jacks
36 | 50027_running_on_spot
37 | 50022_running_on_spot
38 | 50004_knees
39 | 50027_one_leg_loose
40 | 50009_one_leg_jump
41 | 50026_jumping_jacks
42 | 50009_one_leg_loose
43 | 50027_jiggle_on_toes
44 | 50020_knees
45 | 50027_light_hopping_loose
46 | 50026_knees
47 | 50004_jumping_jacks
48 | 50026_one_leg_jump
49 | 50004_hips
50 | 50027_shake_arms
51 | 50026_hips
52 | 50020_punching
53 | 50025_jiggle_on_toes
54 | 50022_punching
55 | 50004_light_hopping_loose
56 | 50009_jumping_jacks
57 | 50002_one_leg_jump
58 | 50007_knees
59 | 50027_jumping_jacks
60 | 50022_shake_arms
61 | 50002_jiggle_on_toes
62 | 50002_hips
63 | 50009_jiggle_on_toes
64 | 50022_jiggle_on_toes
65 | 50007_shake_hips
66 | 50022_hips
67 | 50026_punching
68 | 50026_one_leg_loose
69 | 50025_running_on_spot
70 | 50025_knees
71 | 50025_hips
72 | 50020_light_hopping_stiff
73 | 50026_shake_shoulders
74 | 50002_jumping_jacks
75 | 50022_light_hopping_stiff
76 | 50027_one_leg_jump
77 | 50002_punching
78 | 50022_shake_shoulders
79 | 50004_running_on_spot
80 | 50020_light_hopping_loose
81 | 50022_shake_hips
82 | 50004_one_leg_loose
83 | 50022_jumping_jacks
84 | 50002_chicken_wings
85 | 50022_one_leg_loose
86 | 50027_knees
87 | 50004_shake_arms
88 | 50007_chicken_wings
89 | 50002_shake_hips
90 | 50007_one_leg_loose
91 | 50004_shake_shoulders
92 | 50009_hips
93 | 50007_running_on_spot
94 | 50025_shake_hips
95 | 50002_shake_shoulders
96 | 50020_shake_arms
97 | 50027_shake_hips
98 | 50026_light_hopping_loose
99 | 50025_one_leg_jump
100 | 50020_one_leg_loose
101 | 50004_one_leg_jump
102 | 50025_punching
103 | 50020_one_leg_jump
104 | 50004_light_hopping_stiff
105 |
--------------------------------------------------------------------------------
/data/Humans/D-FAUST/val.lst:
--------------------------------------------------------------------------------
1 | 50004_chicken_wings
2 | 50020_hips
3 | 50025_shake_arms
4 | 50020_running_on_spot
5 | 50007_light_hopping_loose
6 |
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000080.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000080.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000081.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000081.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000082.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000082.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000083.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000083.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000084.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000084.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000085.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000085.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000086.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000086.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000087.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000087.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000088.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000088.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000089.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000089.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000090.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000090.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000091.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000091.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000092.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000092.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000093.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000093.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000094.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000094.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000095.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000095.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000096.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50002_light_hopping_loose/pcl_seq/00000096.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000166.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000166.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000167.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000167.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000168.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000168.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000169.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000169.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000170.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000170.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000171.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000171.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000172.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000172.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000173.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000173.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000174.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000174.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000175.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000175.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000176.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000176.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000177.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000177.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000178.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000178.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000179.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000179.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000180.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000180.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000181.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000181.npz
--------------------------------------------------------------------------------
/data/Humans/Demo/50004_punching/pcl_seq/00000182.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/data/Humans/Demo/50004_punching/pcl_seq/00000182.npz
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: lpdc
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - cython=0.29.2
8 | - imageio=2.4.1
9 | - numpy=1.15.4
10 | - numpy-base=1.15.4
11 | - matplotlib=3.0.3
12 | - matplotlib-base=3.0.3
13 | - pandas=0.23.4
14 | - pillow=5.3.0
15 | - pyembree=0.1.4
16 | - pytest=4.0.2
17 | - python=3.6.7
18 | - pytorch=1.0.0
19 | - pyyaml=3.13
20 | - scikit-image=0.14.1
21 | - scikit-learn=0.21.3
22 | - scipy=1.1.0
23 | - tensorboardx=1.4
24 | - torchvision=0.2.1
25 | - tqdm=4.28.1
26 | - trimesh=2.37.7
27 | - pip=19.1.1
28 | - pip:
29 | - h5py==2.9.0
30 | - plyfile==0.7
31 |
--------------------------------------------------------------------------------
/im2mesh/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/__init__.py
--------------------------------------------------------------------------------
/im2mesh/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/__init__.pyc
--------------------------------------------------------------------------------
/im2mesh/checkpoints.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib.parse
3 | import torch
4 | from torch.utils import model_zoo
5 |
6 |
7 | class CheckpointIO(object):
8 | ''' CheckpointIO class.
9 |
10 | It handles saving and loading checkpoints.
11 |
12 | Args:
13 | checkpoint_dir (str): path where checkpoints are saved
14 | '''
15 | def __init__(self, checkpoint_dir='./chkpts', initialize_from=None,
16 | initialization_file_name='model_best.pt', **kwargs):
17 | self.module_dict = kwargs
18 | self.checkpoint_dir = checkpoint_dir
19 | self.initialize_from = initialize_from
20 | self.initialization_file_name = initialization_file_name
21 | if not os.path.exists(checkpoint_dir):
22 | os.makedirs(checkpoint_dir)
23 |
24 | def register_modules(self, **kwargs):
25 | ''' Registers modules in current module dictionary.
26 | '''
27 | self.module_dict.update(kwargs)
28 |
29 | def save(self, filename, **kwargs):
30 | ''' Saves the current module dictionary.
31 |
32 | Args:
33 | filename (str): name of output file
34 | '''
35 | if not os.path.isabs(filename):
36 | filename = os.path.join(self.checkpoint_dir, filename)
37 |
38 | outdict = kwargs
39 | print(self.module_dict.keys())
40 | for k, v in self.module_dict.items():
41 | outdict[k] = v.state_dict()
42 | torch.save(outdict, filename)
43 |
44 | def load(self, filename):
45 | '''Loads a module dictionary from local file or url.
46 |
47 | Args:
48 | filename (str): name of saved module dictionary
49 | '''
50 | if is_url(filename):
51 | return self.load_url(filename)
52 | else:
53 | return self.load_file(filename)
54 |
55 | def load_file(self, filename):
56 | '''Loads a module dictionary from file.
57 |
58 | Args:
59 | filename (str): name of saved module dictionary
60 | '''
61 |
62 | if not os.path.isabs(filename):
63 | filename = os.path.join(self.checkpoint_dir, filename)
64 |
65 | if os.path.exists(filename):
66 | print(filename)
67 | print('=> Loading checkpoint from local file...')
68 | state_dict = torch.load(filename)
69 | scalars = self.parse_state_dict(state_dict)
70 | return scalars
71 | else:
72 | if self.initialize_from is not None:
73 | self.initialize_weights()
74 | raise FileExistsError  # signals that no checkpoint file was found
75 |
76 | def load_url(self, url):
77 | '''Load a module dictionary from url.
78 |
79 | Args:
80 | url (str): url to saved model
81 | '''
82 | print(url)
83 | print('=> Loading checkpoint from url...')
84 | state_dict = model_zoo.load_url(url, progress=True)
85 | scalars = self.parse_state_dict(state_dict)
86 | return scalars
87 |
88 | def parse_state_dict(self, state_dict):
89 | '''Parse state_dict of model and return scalars.
90 |
91 | Args:
92 | state_dict (dict): State dict of model
93 | '''
94 |
95 | for k, v in self.module_dict.items():
96 | if k in state_dict:
97 | v.load_state_dict(state_dict[k])
98 | else:
99 | print('Warning: Could not find %s in checkpoint!' % k)
100 | scalars = {k: v for k, v in state_dict.items()
101 | if k not in self.module_dict}
102 | return scalars
103 |
104 | def initialize_weights(self):
105 | ''' Initializes the model weights from another model file.
106 | '''
107 |
108 | print('Initializing weights from model %s' % self.initialize_from)
109 | filename_in = os.path.join(
110 | self.initialize_from, self.initialization_file_name)
111 |
112 | model_state_dict = self.module_dict.get('model').state_dict()
113 | model_dict = self.module_dict.get('model').state_dict()
114 | model_keys = set([k for (k, v) in model_dict.items()])
115 |
116 | init_model_dict = torch.load(filename_in)['model']
117 | init_model_k = set([k for (k, v) in init_model_dict.items()])
118 |
119 | for k in model_keys:
120 | if ((k in init_model_k) and (model_state_dict[k].shape ==
121 | init_model_dict[k].shape)):
122 | model_state_dict[k] = init_model_dict[k]
123 | self.module_dict.get('model').load_state_dict(model_state_dict)
124 |
125 |
126 | def is_url(url):
127 | ''' Checks if input is url.'''
128 | scheme = urllib.parse.urlparse(url).scheme
129 | return scheme in ('http', 'https')
130 |
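131 | # Usage sketch (hypothetical variable names, for illustration only):
132 | #   checkpoint_io = CheckpointIO('out/demo', model=model, optimizer=optimizer)
133 | #   checkpoint_io.save('model.pt', epoch_it=epoch, it=it)
134 | #   scalars = checkpoint_io.load('model_best.pt')  # local file name or URL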
--------------------------------------------------------------------------------
/im2mesh/checkpoints.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/checkpoints.pyc
--------------------------------------------------------------------------------
/im2mesh/config.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/config.pyc
--------------------------------------------------------------------------------
/im2mesh/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from im2mesh.data.core import (
3 | Shapes3dDataset, collate_remove_none, worker_init_fn
4 | )
5 | from im2mesh.data.subseq_dataset import (
6 | HumansDataset
7 | )
8 | from im2mesh.data.fields import (
9 | IndexField, CategoryField,
10 | PointsSubseqField, ImageSubseqField,
11 | PointCloudSubseqField, MeshSubseqField,
12 | )
13 |
14 | from im2mesh.data.transforms import (
15 | PointcloudNoise,
16 | SubsamplePointcloud,
17 | SubsamplePoints,
18 | # Temporal transforms
19 | SubsamplePointsSeq, SubsamplePointcloudSeq,
20 | )
21 |
22 |
23 | __all__ = [
24 | # Core
25 | 'Shapes3dDataset',
26 | 'collate_remove_none',
27 | 'worker_init_fn',
28 | # Humans Dataset
29 | 'HumansDataset',
30 | # Fields
31 | 'IndexField',
32 | 'CategoryField',
33 | 'PointsSubseqField',
34 | 'PointCloudSubseqField',
35 | 'ImageSubseqField',
36 | 'MeshSubseqField',
37 | # Transforms
38 | 'PointcloudNoise',
39 | 'SubsamplePointcloud',
40 | 'SubsamplePoints',
41 | # Temporal Transforms
42 | 'SubsamplePointsSeq',
43 | 'SubsamplePointcloudSeq',
44 | ]
45 |
--------------------------------------------------------------------------------
/im2mesh/data/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__init__.pyc
--------------------------------------------------------------------------------
/im2mesh/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/data/__pycache__/core.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/core.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/data/__pycache__/fields.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/fields.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/data/__pycache__/subseq_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/subseq_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/data/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/data/core.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from torch.utils import data
4 | import numpy as np
5 | import yaml
6 |
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | # Fields
12 | class Field(object):
13 | ''' Data fields class.
14 | '''
15 |
16 | def load(self, data_path, idx, category):
17 | ''' Loads a data point.
18 |
19 | Args:
20 | data_path (str): path to data file
21 | idx (int): index of data point
22 | category (int): index of category
23 | '''
24 | raise NotImplementedError
25 |
26 | def check_complete(self, files):
27 | ''' Checks if set is complete.
28 |
29 | Args:
30 | files: files
31 | '''
32 | raise NotImplementedError
33 |
34 |
35 | class Shapes3dDataset(data.Dataset):
36 | ''' 3D Shapes dataset class.
37 | '''
38 |
39 | def __init__(self, dataset_folder, fields, split=None,
40 | categories=None, no_except=True, transform=None):
41 | ''' Initialization of the the 3D shape dataset.
42 |
43 | Args:
44 | dataset_folder (str): dataset folder
45 | fields (dict): dictionary of fields
46 | split (str): which split is used
47 | categories (list): list of categories to use
48 | no_except (bool): no exception
49 | transform (callable): transformation applied to data points
50 | '''
51 | # Attributes
52 | self.dataset_folder = dataset_folder
53 | self.fields = fields
54 | self.no_except = no_except
55 | self.transform = transform
56 |
57 | # If categories is None, use all subfolders
58 | if categories is None:
59 | categories = os.listdir(dataset_folder)
60 | categories = [c for c in categories
61 | if os.path.isdir(os.path.join(dataset_folder, c))]
62 |
63 | # Read metadata file
64 | metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
65 |
66 | if os.path.exists(metadata_file):
67 | with open(metadata_file, 'r') as f:
68 | self.metadata = yaml.safe_load(f)
69 | else:
70 | self.metadata = {
71 | c: {'id': c, 'name': 'n/a'} for c in categories
72 | }
73 |
74 | # Set index
75 | for c_idx, c in enumerate(categories):
76 | self.metadata[c]['idx'] = c_idx
77 |
78 | # Get all models
79 | self.models = []
80 | for c_idx, c in enumerate(categories):
81 | subpath = os.path.join(dataset_folder, c)
82 | if not os.path.isdir(subpath):
83 | logger.warning('Category %s does not exist in dataset.' % c)
84 | continue  # skip missing categories instead of failing on the split file
85 |
86 | split_file = os.path.join(subpath, split + '.lst')
86 | with open(split_file, 'r') as f:
87 | models_c = f.read().split('\n')
88 | models_c = list(filter(lambda x: len(x) > 0, models_c))  # drop empty lines
88 |
89 | self.models += [
90 | {'category': c, 'model': m}
91 | for m in models_c
92 | ]
93 |
94 | def __len__(self):
95 | ''' Returns the length of the dataset.
96 | '''
97 | return len(self.models)
98 |
99 | def __getitem__(self, idx):
100 | ''' Returns an item of the dataset.
101 |
102 | Args:
103 | idx (int): ID of data point
104 | '''
105 | category = self.models[idx]['category']
106 | model = self.models[idx]['model']
107 | c_idx = self.metadata[category]['idx']
108 |
109 | model_path = os.path.join(self.dataset_folder, category, model)
110 | data = {}
111 |
112 | for field_name, field in self.fields.items():
113 | try:
114 | field_data = field.load(model_path, idx, c_idx)
115 | except Exception:
116 | if self.no_except:
117 | logger.warning(
118 | 'Error occurred when loading field %s of model %s'
119 | % (field_name, model)
120 | )
121 | return None
122 | else:
123 | raise
124 |
125 | if isinstance(field_data, dict):
126 | for k, v in field_data.items():
127 | if k is None:
128 | data[field_name] = v
129 | else:
130 | data['%s.%s' % (field_name, k)] = v
131 | else:
132 | data[field_name] = field_data
133 |
134 | if self.transform is not None:
135 | data = self.transform(data)
136 |
137 | return data
138 |
139 | def get_model_dict(self, idx):
140 | return self.models[idx]
141 |
142 | def test_model_complete(self, category, model):
143 | ''' Tests if model is complete.
144 |
145 | Args:
146 | model (str): modelname
147 | '''
148 | model_path = os.path.join(self.dataset_folder, category, model)
149 | files = os.listdir(model_path)
150 | for field_name, field in self.fields.items():
151 | if not field.check_complete(files):
152 | logger.warning('Field "%s" is incomplete: %s'
153 | % (field_name, model_path))
154 | return False
155 |
156 | return True
157 |
158 |
159 | def collate_remove_none(batch):
160 | ''' Collater that puts each data field into a tensor with outer dimension
161 | batch size.
162 |
163 | Args:
164 | batch: batch
165 | '''
166 |
167 | batch = list(filter(lambda x: x is not None, batch))
168 | return data.dataloader.default_collate(batch)
169 |
170 |
171 | def worker_init_fn(worker_id):
172 | ''' Worker init function to ensure true randomness.
173 | '''
174 | random_data = os.urandom(4)
175 | base_seed = int.from_bytes(random_data, byteorder="big")
176 | np.random.seed(base_seed + worker_id)
177 |
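178 | # Usage sketch (field and folder names are hypothetical, for illustration):
179 | #   fields = {'pointcloud': PointCloudSubseqField('pcl_seq')}
180 | #   dataset = Shapes3dDataset('data/Humans', fields, split='train',
181 | #                             categories=['D-FAUST'])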
--------------------------------------------------------------------------------
/im2mesh/data/core.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/core.pyc
--------------------------------------------------------------------------------
/im2mesh/data/fields.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/fields.pyc
--------------------------------------------------------------------------------
/im2mesh/data/subseq_dataset.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/subseq_dataset.pyc
--------------------------------------------------------------------------------
/im2mesh/data/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | # Transforms
5 | class PointcloudNoise(object):
6 | ''' Point cloud noise transformation class.
7 |
8 | It adds noise to point cloud data.
9 |
10 | Args:
11 | stddev (float): standard deviation of the added Gaussian noise
12 | '''
13 |
14 | def __init__(self, stddev):
15 | self.stddev = stddev
16 |
17 | def __call__(self, data):
18 | ''' Calls the transformation.
19 |
20 | Args:
21 | data (dictionary): data dictionary
22 | '''
23 | data_out = data.copy()
24 | points = data[None]
25 | noise = self.stddev * np.random.randn(*points.shape)
26 | noise = noise.astype(np.float32)
27 | data_out[None] = points + noise
28 | return data_out
29 |
30 |
31 | class SubsamplePointcloud(object):
32 | ''' Point cloud subsampling transformation class.
33 |
34 | It subsamples the point cloud data.
35 |
36 | Args:
37 | N (int): number of points to be subsampled
38 | '''
39 |
40 | def __init__(self, N):
41 | self.N = N
42 |
43 | def __call__(self, data):
44 | ''' Calls the transformation.
45 |
46 | Args:
47 | data (dict): data dictionary
48 | '''
49 | data_out = data.copy()
50 | points = data[None]
51 |
52 | indices = np.random.randint(points.shape[0], size=self.N)
53 | data_out[None] = points[indices, :]
54 |
55 | if 'normals' in data.keys():
56 | normals = data['normals']
57 | data_out['normals'] = normals[indices, :]
58 |
59 | return data_out
60 |
61 |
62 | class SubsamplePoints(object):
63 | ''' Points subsampling transformation class.
64 | It subsamples the points data.
65 | Args:
66 | N (int): number of points to be subsampled
67 | '''
68 |
69 | def __init__(self, N):
70 | self.N = N
71 |
72 | def __call__(self, data):
73 | ''' Calls the transformation.
74 | Args:
75 | data (dictionary): data dictionary
76 | '''
77 | points = data[None]
78 | occ = data['occ']
79 |
80 | data_out = data.copy()
81 | if isinstance(self.N, int):
82 | idx = np.random.randint(points.shape[0], size=self.N)
83 | data_out.update({
84 | None: points[idx, :],
85 | 'occ': occ[idx],
86 | })
87 | else:
88 | Nt_out, Nt_in = self.N
89 | occ_binary = (occ >= 0.5)
90 | points0 = points[~occ_binary]
91 | points1 = points[occ_binary]
92 |
93 | idx0 = np.random.randint(points0.shape[0], size=Nt_out)
94 | idx1 = np.random.randint(points1.shape[0], size=Nt_in)
95 |
96 | points0 = points0[idx0, :]
97 | points1 = points1[idx1, :]
98 | points = np.concatenate([points0, points1], axis=0)
99 |
100 | occ0 = np.zeros(Nt_out, dtype=np.float32)
101 | occ1 = np.ones(Nt_in, dtype=np.float32)
102 | occ = np.concatenate([occ0, occ1], axis=0)
103 |
104 | volume = occ_binary.sum() / len(occ_binary)
105 | volume = volume.astype(np.float32)
106 |
107 | data_out.update({
108 | None: points,
109 | 'occ': occ,
110 | 'volume': volume,
111 | })
112 | return data_out
113 |
114 |
115 | class SubsamplePointcloudSeq(object):
116 | ''' Point cloud sequence subsampling transformation class.
117 |
118 | It subsamples the point cloud sequence data.
119 |
120 | Args:
121 | N (int): number of points to be subsampled
122 | connected_samples (bool): whether to obtain connected samples
123 | random (bool): whether to sub-sample randomly
124 | '''
125 |
126 | def __init__(self, N, connected_samples=False, random=True):
127 | self.N = N
128 | self.connected_samples = connected_samples
129 | self.random = random
130 |
131 | def __call__(self, data):
132 | ''' Calls the transformation.
133 |
134 | Args:
135 | data (dictionary): data dictionary
136 | '''
137 | data_out = data.copy()
138 | points = data[None] # n_steps x T x 3
139 | n_steps, T, dim = points.shape
140 | N_max = min(self.N, T)
141 | if self.connected_samples or not self.random:
142 | indices = (np.random.randint(T, size=self.N) if self.random else
143 | np.arange(N_max))
144 | data_out[None] = points[:, indices, :]
145 | else:
146 | indices = np.random.randint(T, size=(n_steps, self.N))
147 | data_out[None] = \
148 | points[np.arange(n_steps).reshape(-1, 1), indices, :]
149 | return data_out
150 |
151 | class SubsamplePointsSeq(object):
152 | ''' Points sequence subsampling transformation class.
153 |
154 | It subsamples the points sequence data.
155 |
156 | Args:
157 | N (int): number of points to be subsampled
158 | connected_samples (bool): whether to obtain connected samples
159 | random (bool): whether to sub-sample randomly
160 | '''
161 |
162 | def __init__(self, N, connected_samples=False, random=True):
163 | self.N = N
164 | self.connected_samples = connected_samples
165 | self.random = random
166 |
167 | def __call__(self, data):
168 | ''' Calls the transformation.
169 |
170 | Args:
171 | data (dictionary): data dictionary
172 | '''
173 | points = data[None]
174 | occ = data['occ']
175 | data_out = data.copy()
176 | n_steps, T, dim = points.shape
177 |
178 | N_max = min(self.N, T)
179 |
180 | if self.connected_samples or not self.random:
181 | indices = (np.random.randint(T, size=self.N) if self.random
182 | else np.arange(N_max))
183 | data_out.update({
184 | None: points[:, indices],
185 | 'occ': occ[:, indices],
186 | })
187 | else:
188 | indices = np.random.randint(T, size=(n_steps, self.N))
189 | help_arr = np.arange(n_steps).reshape(-1, 1)
190 | data_out.update({
191 | None: points[help_arr, indices, :],
192 |                 'occ': occ[help_arr, indices]
193 | })
194 | return data_out
195 |
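
A short composition example (added for illustration): the transforms operate on plain dictionaries whose raw points are stored under the None key, so they chain naturally.

    import numpy as np
    # from im2mesh.data.transforms import SubsamplePointcloud, PointcloudNoise

    data = {
        None: np.random.rand(2048, 3).astype(np.float32),
        'normals': np.random.rand(2048, 3).astype(np.float32),
    }
    data = SubsamplePointcloud(300)(data)  # keeps 300 random points and their normals
    data = PointcloudNoise(0.005)(data)    # adds Gaussian noise with stddev 0.005
    assert data[None].shape == (300, 3)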
--------------------------------------------------------------------------------
/im2mesh/data/transforms.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/data/transforms.pyc
--------------------------------------------------------------------------------
/im2mesh/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from im2mesh.encoder import pointnet
2 |
3 |
4 | encoder_dict = {
5 | 'pointnet_simple': pointnet.SimplePointnet,
6 | }
7 |
8 | encoder_temporal_dict = {
9 | 'pointnet_spatiotemporal': pointnet.SpatioTemporalResnetPointnet,
10 | 'pointnet_spatiotemporal2': pointnet.SpatioTemporalResnetPointnet2,
11 | }
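
These dictionaries act as a registry from config strings to encoder classes. A hedged sketch of the usual lookup follows; the constructor call is left commented because the classes' actual signatures are not shown here.

    # from im2mesh.encoder import encoder_temporal_dict

    encoder_name = 'pointnet_spatiotemporal'           # e.g. read from a YAML config
    encoder_cls = encoder_temporal_dict[encoder_name]  # class, not an instance
    # encoder = encoder_cls(...)                       # instantiate with the class's own kwargs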
--------------------------------------------------------------------------------
/im2mesh/encoder/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__init__.pyc
--------------------------------------------------------------------------------
/im2mesh/encoder/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/encoder/__pycache__/conv.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/conv.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/encoder/__pycache__/pointnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/pointnet.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/encoder/__pycache__/pointnet_unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/pointnet_unet.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/encoder/__pycache__/unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/unet.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/encoder/__pycache__/unet3d.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/__pycache__/unet3d.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/encoder/conv.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/encoder/conv.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/__init__.py:
--------------------------------------------------------------------------------
1 | from im2mesh.lpdc import (
2 | config, generation, training, models
3 | )
4 |
5 | __all__ = [
6 |     'config', 'generation', 'training', 'models'
7 | ]
--------------------------------------------------------------------------------
/im2mesh/lpdc/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/__pycache__/generation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/__pycache__/generation.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/__pycache__/training.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/__pycache__/training.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/models/__pycache__/decoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/models/__pycache__/decoder.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/models/__pycache__/decoder_unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/models/__pycache__/decoder_unet.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/models/__pycache__/displacement.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc/models/__pycache__/displacement.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc/models/decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from im2mesh.layers import (
4 | ResnetBlockFC, CResnetBlockConv1d,
5 | CBatchNorm1d, CBatchNorm1d_legacy
6 | )
7 |
8 |
9 | class Decoder(nn.Module):
10 | ''' Basic Decoder network for OFlow class.
11 |
12 | The decoder network maps points together with latent conditioned codes
13 | c and z to log probabilities of occupancy for the points. This basic
14 | decoder does not use batch normalization.
15 |
16 | Args:
17 | dim (int): dimension of input points
18 | z_dim (int): dimension of latent code z
19 | c_dim (int): dimension of latent conditioned code c
20 |         hidden_size (int): size of the hidden layers
21 | leaky (bool): whether to use leaky ReLUs as activation
22 | '''
23 |
24 | def __init__(self, dim=3, z_dim=128, c_dim=128,
25 | hidden_size=128, leaky=False, **kwargs):
26 | super().__init__()
27 | self.z_dim = z_dim
28 | self.c_dim = c_dim
29 | self.dim = dim
30 |
31 | # Submodules
32 | self.fc_p = nn.Linear(dim, hidden_size)
33 |
34 |         if z_dim != 0:
35 |             self.fc_z = nn.Linear(z_dim, hidden_size)
36 |         if c_dim != 0:
37 | self.fc_c = nn.Linear(c_dim, hidden_size)
38 |
39 | self.block0 = ResnetBlockFC(hidden_size)
40 | self.block1 = ResnetBlockFC(hidden_size)
41 | self.block2 = ResnetBlockFC(hidden_size)
42 | self.block3 = ResnetBlockFC(hidden_size)
43 | self.block4 = ResnetBlockFC(hidden_size)
44 |
45 | self.fc_out = nn.Linear(hidden_size, 1)
46 |
47 | if not leaky:
48 | self.actvn = F.relu
49 | else:
50 | self.actvn = lambda x: F.leaky_relu(x, 0.2)
51 |
52 | def forward(self, p, z=None, c=None, **kwargs):
53 | ''' Performs a forward pass through the network.
54 |
55 | Args:
56 | p (tensor): points tensor
57 | z (tensor): latent code z
58 | c (tensor): latent conditioned code c
59 | '''
60 | batch_size = p.shape[0]
61 | p = p.view(batch_size, -1, self.dim)
62 | net = self.fc_p(p)
63 |
64 | if self.z_dim != 0:
65 | net_z = self.fc_z(z).unsqueeze(1)
66 | net = net + net_z
67 |
68 | if self.c_dim != 0:
69 | net_c = self.fc_c(c).unsqueeze(1)
70 | net = net + net_c
71 |
72 | net = self.block0(net)
73 | net = self.block1(net)
74 | net = self.block2(net)
75 | net = self.block3(net)
76 | net = self.block4(net)
77 |
78 | out = self.fc_out(self.actvn(net))
79 | out = out.squeeze(-1)
80 |
81 | return out
82 |
83 |
84 | class DecoderCBatchNorm(nn.Module):
85 | ''' Conditioned Batch Norm Decoder network for OFlow class.
86 |
87 | The decoder network maps points together with latent conditioned codes
88 | c and z to log probabilities of occupancy for the points. This decoder
89 | uses conditioned batch normalization to inject the latent codes.
90 |
91 | Args:
92 | dim (int): dimension of input points
93 | z_dim (int): dimension of latent code z
94 | c_dim (int): dimension of latent conditioned code c
95 |         hidden_size (int): size of the hidden layers
96 | leaky (bool): whether to use leaky ReLUs as activation
97 |
98 | '''
99 |
100 | def __init__(self, dim=3, z_dim=128, c_dim=128,
101 | hidden_size=256, leaky=False, legacy=False):
102 | super().__init__()
103 | self.z_dim = z_dim
104 | self.dim = dim
105 |         if z_dim != 0:
106 | self.fc_z = nn.Linear(z_dim, hidden_size)
107 |
108 | self.fc_p = nn.Conv1d(dim, hidden_size, 1)
109 | self.block0 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
110 | self.block1 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
111 | self.block2 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
112 | self.block3 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
113 | self.block4 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
114 |
115 | if not legacy:
116 | self.bn = CBatchNorm1d(c_dim, hidden_size)
117 | else:
118 | self.bn = CBatchNorm1d_legacy(c_dim, hidden_size)
119 |
120 | self.fc_out = nn.Conv1d(hidden_size, 1, 1)
121 |
122 | if not leaky:
123 | self.actvn = F.relu
124 | else:
125 | self.actvn = lambda x: F.leaky_relu(x, 0.2)
126 |
127 | def forward(self, p, z, c, **kwargs):
128 | ''' Performs a forward pass through the network.
129 |
130 | Args:
131 | p (tensor): points tensor
132 | z (tensor): latent code z
133 | c (tensor): latent conditioned code c
134 | '''
135 | p = p.transpose(1, 2)
136 | batch_size, D, T = p.size()
137 | net = self.fc_p(p)
138 |
139 | net = self.block0(net, c)
140 | net = self.block1(net, c)
141 | net = self.block2(net, c)
142 | net = self.block3(net, c)
143 | net = self.block4(net, c)
144 |
145 | out = self.fc_out(self.actvn(self.bn(net, c)))
146 | out = out.squeeze(1)
147 |
148 | return out
149 |
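
A minimal smoke test for the basic Decoder (added for illustration; the tensor shapes follow the forward() docstring above):

    import torch
    # from im2mesh.lpdc.models.decoder import Decoder

    dec = Decoder(dim=3, z_dim=128, c_dim=128, hidden_size=128)
    p = torch.rand(4, 1000, 3)  # 1000 query points per batch element
    z = torch.rand(4, 128)      # latent code z
    c = torch.rand(4, 128)      # latent conditioned code c
    logits = dec(p, z, c)       # (4, 1000) occupancy logits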
--------------------------------------------------------------------------------
/im2mesh/lpdc/models/displacement.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from im2mesh.layers import ResnetBlockFC
5 |
6 | class DisplacementDecoder(nn.Module):
7 | ''' DisplacementDecoder network class.
8 |
9 | It maps input points and time values together with (optional) conditioned
10 |     codes c to the respective motion vectors.
11 |
12 | Args:
13 | in_dim (int): input dimension of points concatenated with the time axis
14 | out_dim (int): output dimension of motion vectors
15 |         z_dim (int): dimension of latent code z (absorbed by **kwargs; unused)
16 | c_dim (int): dimension of latent conditioned code c
17 | hidden_size (int): size of the hidden dimension
18 | leaky (bool): whether to use leaky ReLUs as activation
19 | n_blocks (int): number of ResNet-based blocks
20 | '''
21 |
22 | def __init__(self, in_dim=3, out_dim=3, c_dim=128,
23 | hidden_size=512, leaky=False, n_blocks=5, **kwargs):
24 | super().__init__()
25 | self.c_dim = c_dim
26 | self.in_dim = in_dim
27 | self.out_dim = out_dim
28 | self.n_blocks = n_blocks
29 | # Submodules
30 | self.fc_p = nn.Linear(in_dim, hidden_size)
31 |
32 | self.fc_in = nn.Linear(c_dim * 2, c_dim)
33 | if c_dim != 0:
34 | self.fc_c = nn.ModuleList([
35 | nn.Linear(c_dim, hidden_size) for i in range(n_blocks)])
36 |
37 | self.blocks = nn.ModuleList([
38 | ResnetBlockFC(hidden_size) for i in range(n_blocks)
39 | ])
40 |
41 | self.fc_out = nn.Linear(hidden_size, self.out_dim)
42 |
43 | if not leaky:
44 | self.actvn = F.relu
45 | else:
46 | self.actvn = lambda x: F.leaky_relu(x, 0.2)
47 |
48 |
49 | def forward(self, p, cur_t, fuc_t, c):
50 | batch_size, nsteps, c_dim = c.shape
51 | _, npoints, dim = p.shape
52 | cur_t = torch.clamp(cur_t[:, :, None] * nsteps, 0, nsteps-1).expand(batch_size, 1, c_dim).type(torch.LongTensor).to(p.device)
53 | fuc_t = torch.clamp(fuc_t[:, :, None] * nsteps, 0, nsteps-1).expand(batch_size, 1, c_dim).type(torch.LongTensor).to(p.device)
54 |         cur_c = torch.gather(c, 1, cur_t)  # per-step code at the current time
55 |         fuc_c = torch.gather(c, 1, fuc_t)  # per-step code at the future time
56 | # glo_c = torch.mean(c, dim=1)
57 | # concat_c = torch.cat([cur_c.squeeze(1), fuc_c.squeeze(1), glo_c], dim=1)
58 | concat_c = torch.cat([cur_c.squeeze(1), fuc_c.squeeze(1)], dim=1)
59 | concat_c = self.fc_in(concat_c)
60 | net = self.fc_p(p)
61 |
62 | # Layer loop
63 | for i in range(self.n_blocks):
64 | if self.c_dim != 0:
65 | net_c = self.fc_c[i](concat_c).unsqueeze(1)
66 | net = net + net_c
67 | net = self.blocks[i](net)
68 |
69 | out = self.fc_out(self.actvn(net))
70 | return out
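
A shape walkthrough for the gather-based lookup in forward() (an added sketch; the sizes are arbitrary). c holds one latent code per time step, while cur_t and fuc_t are normalized times in [0, 1] of shape (batch, 1) that forward() converts to integer step indices.

    import torch
    # from im2mesh.lpdc.models.displacement import DisplacementDecoder

    dec = DisplacementDecoder(in_dim=3, out_dim=3, c_dim=128, hidden_size=512)
    p = torch.rand(2, 500, 3)         # query points
    c = torch.rand(2, 17, 128)        # one code per time step (n_steps = 17)
    cur_t = torch.zeros(2, 1)         # time 0.0 -> step index 0
    fuc_t = torch.full((2, 1), 0.5)   # time 0.5 -> step index 8
    motion = dec(p, cur_t, fuc_t, c)  # (2, 500, 3) displacement vectors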
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/__init__.py:
--------------------------------------------------------------------------------
1 | from im2mesh.lpdc_uneven import (
2 | config, generation, training, models
3 | )
4 |
5 | __all__ = [
6 |     'config', 'generation', 'training', 'models'
7 | ]
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/__pycache__/generation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/__pycache__/generation.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/__pycache__/training.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/__pycache__/training.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/models/__pycache__/decoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/models/__pycache__/decoder.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/models/__pycache__/decoder_unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/models/__pycache__/decoder_unet.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/models/__pycache__/displacement.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/lpdc_uneven/models/__pycache__/displacement.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/models/decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from im2mesh.layers import (
4 | ResnetBlockFC, CResnetBlockConv1d,
5 | CBatchNorm1d, CBatchNorm1d_legacy
6 | )
7 |
8 |
9 | class Decoder(nn.Module):
10 | ''' Basic Decoder network for OFlow class.
11 |
12 | The decoder network maps points together with latent conditioned codes
13 | c and z to log probabilities of occupancy for the points. This basic
14 | decoder does not use batch normalization.
15 |
16 | Args:
17 | dim (int): dimension of input points
18 | z_dim (int): dimension of latent code z
19 | c_dim (int): dimension of latent conditioned code c
20 |         hidden_size (int): size of the hidden layers
21 | leaky (bool): whether to use leaky ReLUs as activation
22 | '''
23 |
24 | def __init__(self, dim=3, z_dim=128, c_dim=128,
25 | hidden_size=128, leaky=False, **kwargs):
26 | super().__init__()
27 | self.z_dim = z_dim
28 | self.c_dim = c_dim
29 | self.dim = dim
30 |
31 | # Submodules
32 | self.fc_p = nn.Linear(dim, hidden_size)
33 |
34 |         if z_dim != 0:
35 |             self.fc_z = nn.Linear(z_dim, hidden_size)
36 |         if c_dim != 0:
37 | self.fc_c = nn.Linear(c_dim, hidden_size)
38 |
39 | self.block0 = ResnetBlockFC(hidden_size)
40 | self.block1 = ResnetBlockFC(hidden_size)
41 | self.block2 = ResnetBlockFC(hidden_size)
42 | self.block3 = ResnetBlockFC(hidden_size)
43 | self.block4 = ResnetBlockFC(hidden_size)
44 |
45 | self.fc_out = nn.Linear(hidden_size, 1)
46 |
47 | if not leaky:
48 | self.actvn = F.relu
49 | else:
50 | self.actvn = lambda x: F.leaky_relu(x, 0.2)
51 |
52 | def forward(self, p, z=None, c=None, **kwargs):
53 | ''' Performs a forward pass through the network.
54 |
55 | Args:
56 | p (tensor): points tensor
57 | z (tensor): latent code z
58 | c (tensor): latent conditioned code c
59 | '''
60 | batch_size = p.shape[0]
61 | p = p.view(batch_size, -1, self.dim)
62 | net = self.fc_p(p)
63 |
64 | if self.z_dim != 0:
65 | net_z = self.fc_z(z).unsqueeze(1)
66 | net = net + net_z
67 |
68 | if self.c_dim != 0:
69 | net_c = self.fc_c(c).unsqueeze(1)
70 | net = net + net_c
71 |
72 | net = self.block0(net)
73 | net = self.block1(net)
74 | net = self.block2(net)
75 | net = self.block3(net)
76 | net = self.block4(net)
77 |
78 | out = self.fc_out(self.actvn(net))
79 | out = out.squeeze(-1)
80 |
81 | return out
82 |
83 |
84 | class DecoderCBatchNorm(nn.Module):
85 | ''' Conditioned Batch Norm Decoder network for OFlow class.
86 |
87 | The decoder network maps points together with latent conditioned codes
88 | c and z to log probabilities of occupancy for the points. This decoder
89 | uses conditioned batch normalization to inject the latent codes.
90 |
91 | Args:
92 | dim (int): dimension of input points
93 | z_dim (int): dimension of latent code z
94 | c_dim (int): dimension of latent conditioned code c
95 |         hidden_size (int): size of the hidden layers
96 | leaky (bool): whether to use leaky ReLUs as activation
97 |
98 | '''
99 |
100 | def __init__(self, dim=3, z_dim=128, c_dim=128,
101 | hidden_size=256, leaky=False, legacy=False):
102 | super().__init__()
103 | self.z_dim = z_dim
104 | self.dim = dim
105 |         if z_dim != 0:
106 | self.fc_z = nn.Linear(z_dim, hidden_size)
107 |
108 | self.fc_p = nn.Conv1d(dim, hidden_size, 1)
109 | self.block0 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
110 | self.block1 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
111 | self.block2 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
112 | self.block3 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
113 | self.block4 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
114 |
115 | if not legacy:
116 | self.bn = CBatchNorm1d(c_dim, hidden_size)
117 | else:
118 | self.bn = CBatchNorm1d_legacy(c_dim, hidden_size)
119 |
120 | self.fc_out = nn.Conv1d(hidden_size, 1, 1)
121 |
122 | if not leaky:
123 | self.actvn = F.relu
124 | else:
125 | self.actvn = lambda x: F.leaky_relu(x, 0.2)
126 |
127 | def forward(self, p, z, c, **kwargs):
128 | ''' Performs a forward pass through the network.
129 |
130 | Args:
131 | p (tensor): points tensor
132 | z (tensor): latent code z
133 | c (tensor): latent conditioned code c
134 | '''
135 | p = p.transpose(1, 2)
136 | batch_size, D, T = p.size()
137 | net = self.fc_p(p)
138 |
139 | net = self.block0(net, c)
140 | net = self.block1(net, c)
141 | net = self.block2(net, c)
142 | net = self.block3(net, c)
143 | net = self.block4(net, c)
144 |
145 | out = self.fc_out(self.actvn(self.bn(net, c)))
146 | out = out.squeeze(1)
147 |
148 | return out
149 |
--------------------------------------------------------------------------------
/im2mesh/lpdc_uneven/models/displacement.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from im2mesh.layers import ResnetBlockFC
5 |
6 | class DisplacementDecoder(nn.Module):
7 | ''' DisplacementDecoder network class.
8 |
9 | It maps input points and time values together with (optional) conditioned
10 |     codes c to the respective motion vectors.
11 |
12 | Args:
13 | in_dim (int): input dimension of points concatenated with the time axis
14 | out_dim (int): output dimension of motion vectors
15 |         z_dim (int): dimension of latent code z (absorbed by **kwargs; unused)
16 | c_dim (int): dimension of latent conditioned code c
17 | hidden_size (int): size of the hidden dimension
18 | leaky (bool): whether to use leaky ReLUs as activation
19 | n_blocks (int): number of ResNet-based blocks
20 | '''
21 |
22 | def __init__(self, in_dim=3, out_dim=3, c_dim=128,
23 | hidden_size=512, leaky=False, n_blocks=5, **kwargs):
24 | super().__init__()
25 | self.c_dim = c_dim
26 | self.in_dim = in_dim
27 | self.out_dim = out_dim
28 | self.n_blocks = n_blocks
29 | # Submodules
30 | self.fc_p = nn.Linear(in_dim, hidden_size)
31 |
32 | self.fc_in = nn.Linear(c_dim * 2, c_dim)
33 | if c_dim != 0:
34 | self.fc_c = nn.ModuleList([
35 | nn.Linear(c_dim, hidden_size) for i in range(n_blocks)])
36 |
37 | self.blocks = nn.ModuleList([
38 | ResnetBlockFC(hidden_size) for i in range(n_blocks)
39 | ])
40 |
41 | self.fc_out = nn.Linear(hidden_size, self.out_dim)
42 |
43 | if not leaky:
44 | self.actvn = F.relu
45 | else:
46 | self.actvn = lambda x: F.leaky_relu(x, 0.2)
47 |
48 |
49 | def forward(self, p, cur_t, fuc_t, c):
50 | batch_size, nsteps, c_dim = c.shape
51 | _, npoints, dim = p.shape
52 | cur_t = torch.clamp(cur_t[:, :, None] * nsteps, 0, nsteps-1).expand(batch_size, 1, c_dim).type(torch.LongTensor).to(p.device)
53 | fuc_t = torch.clamp(fuc_t[:, :, None] * nsteps, 0, nsteps-1).expand(batch_size, 1, c_dim).type(torch.LongTensor).to(p.device)
54 |         cur_c = torch.gather(c, 1, cur_t)  # per-step code at the current time
55 |         fuc_c = torch.gather(c, 1, fuc_t)  # per-step code at the future time
56 | #glo_c = torch.mean(c, dim=1)
57 | #concat_c = torch.cat([cur_c.squeeze(1), fuc_c.squeeze(1), glo_c], dim=1)
58 | concat_c = torch.cat([cur_c.squeeze(1), fuc_c.squeeze(1)], dim=1)
59 | concat_c = self.fc_in(concat_c)
60 | net = self.fc_p(p)
61 |
62 | # Layer loop
63 | for i in range(self.n_blocks):
64 | if self.c_dim != 0:
65 | net_c = self.fc_c[i](concat_c).unsqueeze(1)
66 | net = net + net_c
67 | net = self.blocks[i](net)
68 |
69 | out = self.fc_out(self.actvn(net))
70 | return out
71 |
--------------------------------------------------------------------------------
/im2mesh/training.py:
--------------------------------------------------------------------------------
1 | # from im2mesh import icp
2 | import numpy as np
3 | from collections import defaultdict
4 | from tqdm import tqdm
5 |
6 |
7 | class BaseTrainer(object):
8 | ''' Base trainer class.
9 | '''
10 |
11 | def evaluate(self, val_loader):
12 | ''' Performs an evaluation.
13 | Args:
14 | val_loader (dataloader): Pytorch dataloader
15 | '''
16 | eval_list = defaultdict(list)
17 |
18 | for data in tqdm(val_loader):
19 | eval_step_dict = self.eval_step(data)
20 |
21 | for k, v in eval_step_dict.items():
22 | eval_list[k].append(v)
23 |
24 | eval_dict = {k: np.mean(v) for k, v in eval_list.items()}
25 | return eval_dict
26 |
27 | def train_step(self, *args, **kwargs):
28 | ''' Performs a training step.
29 | '''
30 | raise NotImplementedError
31 |
32 | def eval_step(self, *args, **kwargs):
33 | ''' Performs an evaluation step.
34 | '''
35 | raise NotImplementedError
36 |
37 | def visualize(self, *args, **kwargs):
38 | ''' Performs visualization.
39 | '''
40 | raise NotImplementedError
41 |
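
A hedged sketch of a concrete subclass (added; the loss is a placeholder, not the actual LPDC objective). Once train_step and eval_step exist, the evaluate() loop above works unchanged.

    import torch
    # from im2mesh.training import BaseTrainer

    class ToyTrainer(BaseTrainer):
        def __init__(self, model, optimizer):
            self.model = model
            self.optimizer = optimizer

        def train_step(self, data):
            self.model.train()
            self.optimizer.zero_grad()
            loss = self.model(data)  # assumes the model returns a scalar loss
            loss.backward()
            self.optimizer.step()
            return loss.item()

        def eval_step(self, data):
            self.model.eval()
            with torch.no_grad():
                return {'loss': self.model(data).item()}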
--------------------------------------------------------------------------------
/im2mesh/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__init__.py
--------------------------------------------------------------------------------
/im2mesh/utils/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__init__.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/__pycache__/grad.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/grad.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/__pycache__/io.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/io.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/__pycache__/onet_generator.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/onet_generator.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/__pycache__/visualize.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/__pycache__/visualize.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/grad.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch
4 | from torch.autograd import grad
5 |
6 | def gradient(inputs, outputs):
7 | d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
8 | points_grad = grad(
9 | outputs=outputs,
10 | inputs=inputs,
11 | grad_outputs=d_points,
12 | create_graph=True,
13 | retain_graph=True,
14 |         only_inputs=True)[0][:, -3:]  # gradient w.r.t. the last three input dims
15 | return points_grad
16 |
17 |
18 | def gradient2(inputs, outputs):
19 | d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
20 | points_grad = grad(
21 | outputs=outputs,
22 | inputs=inputs,
23 | grad_outputs=d_points,
24 | create_graph=True,
25 | retain_graph=True,
26 |         only_inputs=True)[0][:, :-3]  # gradient w.r.t. all but the last three dims
27 | return points_grad
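
Example added for illustration; note the (inputs, outputs) argument order. For f(x) = ||x||^2 the returned gradient should equal 2x:

    import torch
    # from im2mesh.utils.grad import gradient

    x = torch.rand(8, 3, requires_grad=True)
    y = (x ** 2).sum(dim=1, keepdim=True)  # scalar field per row
    g = gradient(x, y)                     # (8, 3)
    assert torch.allclose(g, 2 * x)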
--------------------------------------------------------------------------------
/im2mesh/utils/icp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.neighbors import NearestNeighbors
3 |
4 |
5 | def best_fit_transform(A, B):
6 | '''
7 | Calculates the least-squares best-fit transform that maps corresponding
8 | points A to B in m spatial dimensions
9 | Input:
10 | A: Nxm numpy array of corresponding points
11 | B: Nxm numpy array of corresponding points
12 | Returns:
13 | T: (m+1)x(m+1) homogeneous transformation matrix that maps A on to B
14 | R: mxm rotation matrix
15 | t: mx1 translation vector
16 | '''
17 |
18 | assert A.shape == B.shape
19 |
20 | # get number of dimensions
21 | m = A.shape[1]
22 |
23 | # translate points to their centroids
24 | centroid_A = np.mean(A, axis=0)
25 | centroid_B = np.mean(B, axis=0)
26 | AA = A - centroid_A
27 | BB = B - centroid_B
28 |
29 | # rotation matrix
30 | H = np.dot(AA.T, BB)
31 | U, S, Vt = np.linalg.svd(H)
32 | R = np.dot(Vt.T, U.T)
33 |
34 | # special reflection case
35 | if np.linalg.det(R) < 0:
36 | Vt[m-1,:] *= -1
37 | R = np.dot(Vt.T, U.T)
38 |
39 | # translation
40 | t = centroid_B.T - np.dot(R,centroid_A.T)
41 |
42 | # homogeneous transformation
43 | T = np.identity(m+1)
44 | T[:m, :m] = R
45 | T[:m, m] = t
46 |
47 | return T, R, t
48 |
49 |
50 | def nearest_neighbor(src, dst):
51 | '''
52 | Find the nearest (Euclidean) neighbor in dst for each point in src
53 | Input:
54 | src: Nxm array of points
55 | dst: Nxm array of points
56 | Output:
57 | distances: Euclidean distances of the nearest neighbor
58 | indices: dst indices of the nearest neighbor
59 | '''
60 |
61 | assert src.shape == dst.shape
62 |
63 | neigh = NearestNeighbors(n_neighbors=1)
64 | neigh.fit(dst)
65 | distances, indices = neigh.kneighbors(src, return_distance=True)
66 | return distances.ravel(), indices.ravel()
67 |
68 |
69 | def icp(A, B, init_pose=None, max_iterations=20, tolerance=0.001):
70 | '''
71 | The Iterative Closest Point method: finds best-fit transform that maps
72 | points A on to points B
73 | Input:
74 | A: Nxm numpy array of source mD points
75 |         B: Nxm numpy array of destination mD points
76 | init_pose: (m+1)x(m+1) homogeneous transformation
77 | max_iterations: exit algorithm after max_iterations
78 |         tolerance: convergence criterion (minimum change in mean error)
79 | Output:
80 | T: final homogeneous transformation that maps A on to B
81 | distances: Euclidean distances (errors) of the nearest neighbor
82 | i: number of iterations to converge
83 | '''
84 |
85 | assert A.shape == B.shape
86 |
87 | # get number of dimensions
88 | m = A.shape[1]
89 |
90 | # make points homogeneous, copy them to maintain the originals
91 | src = np.ones((m+1,A.shape[0]))
92 | dst = np.ones((m+1,B.shape[0]))
93 | src[:m,:] = np.copy(A.T)
94 | dst[:m,:] = np.copy(B.T)
95 |
96 | # apply the initial pose estimation
97 | if init_pose is not None:
98 | src = np.dot(init_pose, src)
99 |
100 | prev_error = 0
101 |
102 | for i in range(max_iterations):
103 | # find the nearest neighbors between the current source and destination points
104 | distances, indices = nearest_neighbor(src[:m,:].T, dst[:m,:].T)
105 |
106 | # compute the transformation between the current source and nearest destination points
107 | T,_,_ = best_fit_transform(src[:m,:].T, dst[:m,indices].T)
108 |
109 | # update the current source
110 | src = np.dot(T, src)
111 |
112 | # check error
113 | mean_error = np.mean(distances)
114 | if np.abs(prev_error - mean_error) < tolerance:
115 | break
116 | prev_error = mean_error
117 |
118 | # calculate final transformation
119 | T,_,_ = best_fit_transform(A, src[:m,:].T)
120 |
121 | return T, distances, i
122 |
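
A small self-check, added for illustration: rotate and translate a random cloud, then let icp() recover a transform mapping A onto B.

    import numpy as np
    # from im2mesh.utils.icp import icp

    theta = np.pi / 6  # 30 degrees about the z-axis
    R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta),  np.cos(theta), 0.0],
                  [0.0,            0.0,           1.0]])
    A = np.random.rand(100, 3)
    B = A @ R.T + np.array([0.1, -0.2, 0.05])
    T, distances, n_iter = icp(A, B, max_iterations=50, tolerance=1e-8)
    print(distances.mean(), n_iter)  # mean error should be close to 0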
--------------------------------------------------------------------------------
/im2mesh/utils/io.py:
--------------------------------------------------------------------------------
1 | import os
2 | from plyfile import PlyElement, PlyData
3 | import numpy as np
4 | from trimesh.util import array_to_string
5 | import trimesh
6 |
7 | def export_pointcloud(vertices, out_file, as_text=True):
8 | assert(vertices.shape[1] == 3)
9 | vertices = vertices.astype(np.float32)
10 | vertices = np.ascontiguousarray(vertices)
11 | vector_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
12 | vertices = vertices.view(dtype=vector_dtype).flatten()
13 | plyel = PlyElement.describe(vertices, 'vertex')
14 | plydata = PlyData([plyel], text=as_text)
15 | plydata.write(out_file)
16 |
17 |
18 | def load_pointcloud(in_file):
19 | plydata = PlyData.read(in_file)
20 | vertices = np.stack([
21 | plydata['vertex']['x'],
22 | plydata['vertex']['y'],
23 | plydata['vertex']['z']
24 | ], axis=1)
25 | return vertices
26 |
27 |
28 | def read_off(file):
29 | """
30 | Reads vertices and faces from an off file.
31 |
32 | :param file: path to file to read
33 | :type file: str
34 | :return: vertices and faces as lists of tuples
35 | :rtype: [(float)], [(int)]
36 | """
37 |
38 | assert os.path.exists(file), 'file %s not found' % file
39 |
40 | with open(file, 'r') as fp:
41 | lines = fp.readlines()
42 | lines = [line.strip() for line in lines]
43 |
44 |     # Fix for ModelNet bug where 'OFF' and the number of vertices and faces
45 | # are all in the first line.
46 | if len(lines[0]) > 3:
47 | assert lines[0][:3] == 'OFF' or lines[0][:3] == 'off', \
48 | 'invalid OFF file %s' % file
49 |
50 | parts = lines[0][3:].split(' ')
51 | assert len(parts) == 3
52 |
53 | num_vertices = int(parts[0])
54 | assert num_vertices > 0
55 |
56 | num_faces = int(parts[1])
57 | assert num_faces > 0
58 |
59 | start_index = 1
60 | # This is the regular case!
61 | else:
62 | assert lines[0] == 'OFF' or lines[0] == 'off', \
63 | 'invalid OFF file %s' % file
64 |
65 | parts = lines[1].split(' ')
66 | assert len(parts) == 3
67 |
68 | num_vertices = int(parts[0])
69 | assert num_vertices > 0
70 |
71 | num_faces = int(parts[1])
72 | assert num_faces > 0
73 |
74 | start_index = 2
75 |
76 | vertices = []
77 | for i in range(num_vertices):
78 | vertex = lines[start_index + i].split(' ')
79 | vertex = [float(point.strip()) for point in vertex if point != '']
80 | assert len(vertex) == 3
81 |
82 | vertices.append(vertex)
83 |
84 | faces = []
85 | for i in range(num_faces):
86 | face = lines[start_index + num_vertices + i].split(' ')
87 | face = [index.strip() for index in face if index != '']
88 |
89 | # check to be sure
90 | for index in face:
91 | assert index != '', \
92 | 'found empty vertex index: %s (%s)' \
93 | % (lines[start_index + num_vertices + i], file)
94 |
95 | face = [int(index) for index in face]
96 |
97 | assert face[0] == len(face) - 1, \
98 |             'face should have %d vertices but has %d (%s)' \
99 | % (face[0], len(face) - 1, file)
100 | assert face[0] == 3, \
101 | 'only triangular meshes supported (%s)' % file
102 | for index in face:
103 | assert index >= 0 and index < num_vertices, \
104 | 'vertex %d (of %d vertices) does not exist (%s)' \
105 | % (index, num_vertices, file)
106 |
107 | assert len(face) > 1
108 |
109 | faces.append(face)
110 |
111 | return vertices, faces
112 |
113 | assert False, 'could not open %s' % file
114 |
115 |
116 | def save_mesh(mesh, out_file, digits=10, face_colors=None):
117 | digits = int(digits)
118 | # prepend a 3 (face count) to each face
119 | if face_colors is None:
120 | faces_stacked = np.column_stack((
121 | np.ones(len(mesh.faces)) * 3, mesh.faces)).astype(np.int64)
122 | else:
123 | mesh.visual.face_colors = face_colors
124 | assert(mesh.visual.face_colors.shape[0] == mesh.faces.shape[0])
125 | faces_stacked = np.column_stack((
126 | np.ones(len(mesh.faces)) * 3, mesh.faces,
127 | mesh.visual.face_colors[:, :3])).astype(np.int64)
128 | export = 'OFF\n'
129 | # the header is vertex count, face count, edge number
130 | export += str(len(mesh.vertices)) + ' ' + str(len(mesh.faces)) + ' 0\n'
131 | export += array_to_string(
132 | mesh.vertices, col_delim=' ', row_delim='\n', digits=digits) + '\n'
133 | export += array_to_string(faces_stacked, col_delim=' ', row_delim='\n')
134 |
135 | with open(out_file, 'w') as f:
136 | f.write(export)
137 |
138 | return mesh
139 |
140 |
141 | def load_mesh(mesh_file):
142 | with open(mesh_file, 'r') as f:
143 | str_file = f.read().split('\n')
144 | n_vertices, n_faces, _ = list(
145 | map(lambda x: int(x), str_file[1].split(' ')))
146 | str_file = str_file[2:] # Remove first 2 lines
147 |
148 | v = [l.split(' ') for l in str_file[:n_vertices]]
149 | f = [l.split(' ') for l in str_file[n_vertices:]]
150 |
151 | v = np.array(v).astype(np.float32)
152 | f = np.array(f).astype(np.uint64)[:, 1:4]
153 |
154 | mesh = trimesh.Trimesh(vertices=v, faces=f, process=False)
155 |
156 | return mesh
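
Round-trip sketch for the PLY helpers above (added for illustration; binary mode keeps the round trip exact, while as_text=True may round the coordinates):

    import numpy as np
    # from im2mesh.utils.io import export_pointcloud, load_pointcloud

    pts = np.random.rand(1000, 3).astype(np.float32)
    export_pointcloud(pts, 'cloud.ply', as_text=False)
    back = load_pointcloud('cloud.ply')  # (1000, 3)
    assert np.allclose(pts, back)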
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 |
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/MANIFEST.in:
--------------------------------------------------------------------------------
1 | exclude pykdtree/render_template.py
2 | include LICENSE.txt
3 |
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/__init__.py:
--------------------------------------------------------------------------------
1 | from .pykdtree.kdtree import KDTree
2 |
3 |
4 | __all__ = [
5 |     'KDTree'
6 | ]
7 |
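
The wrapper re-exports pykdtree's KDTree; a typical nearest-neighbour query (added for illustration):

    import numpy as np
    # from im2mesh.utils.libkdtree import KDTree

    data_pts = np.random.rand(1000, 3)
    query_pts = np.random.rand(10, 3)
    tree = KDTree(data_pts)
    dist, idx = tree.query(query_pts, k=1)  # distance and index of the nearest data point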
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libkdtree/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/pykdtree/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libkdtree/pykdtree/__init__.py
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/pykdtree/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libkdtree/pykdtree/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/pykdtree/kdtree.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libkdtree/pykdtree/kdtree.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/pykdtree/render_template.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from mako.template import Template
4 |
5 | mytemplate = Template(filename='_kdtree_core.c.mako')
6 | with open('_kdtree_core.c', 'w') as fp:
7 | fp.write(mytemplate.render())
8 |
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_rpm]
2 | requires=numpy
3 | release=1
4 |
5 |
6 |
--------------------------------------------------------------------------------
/im2mesh/utils/libkdtree/setup.py:
--------------------------------------------------------------------------------
1 | #pykdtree, Fast kd-tree implementation with OpenMP-enabled queries
2 | #
3 | #Copyright (C) 2013 - present Esben S. Nielsen
4 | #
5 | # This program is free software: you can redistribute it and/or modify it under
6 | # the terms of the GNU Lesser General Public License as published by the Free
7 | # Software Foundation, either version 3 of the License, or
8 | #(at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful, but WITHOUT
11 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
13 | # details.
14 | #
15 | # You should have received a copy of the GNU Lesser General Public License along
16 | # with this program. If not, see <http://www.gnu.org/licenses/>.
17 |
18 | import os
19 | from setuptools import setup, Extension
20 | from setuptools.command.build_ext import build_ext
21 |
22 | # Get OpenMP setting from environment
23 | try:
24 | use_omp = int(os.environ['USE_OMP'])
25 | except KeyError:
26 | use_omp = True
27 |
28 |
29 | def set_builtin(name, value):
30 | if isinstance(__builtins__, dict):
31 | __builtins__[name] = value
32 | else:
33 | setattr(__builtins__, name, value)
34 |
35 |
36 | # Custom builder to handle compiler flags. Edit if needed.
37 | class build_ext_subclass(build_ext):
38 | def build_extensions(self):
39 | comp = self.compiler.compiler_type
40 | if comp in ('unix', 'cygwin', 'mingw32'):
41 | # Check if build is with OpenMP
42 | if use_omp:
43 | extra_compile_args = ['-std=c99', '-O3', '-fopenmp']
44 | extra_link_args=['-lgomp']
45 | else:
46 | extra_compile_args = ['-std=c99', '-O3']
47 | extra_link_args = []
48 | elif comp == 'msvc':
49 | extra_compile_args = ['/Ox']
50 | extra_link_args = []
51 | if use_omp:
52 | extra_compile_args.append('/openmp')
53 | else:
54 | # Add support for more compilers here
55 | raise ValueError('Compiler flags undefined for %s. Please modify setup.py and add compiler flags'
56 | % comp)
57 | self.extensions[0].extra_compile_args = extra_compile_args
58 | self.extensions[0].extra_link_args = extra_link_args
59 | build_ext.build_extensions(self)
60 |
61 | def finalize_options(self):
62 | '''
63 | In order to avoid premature import of numpy before it gets installed as a dependency
64 | get numpy include directories during the extensions building process
65 | http://stackoverflow.com/questions/19919905/how-to-bootstrap-numpy-installation-in-setup-py
66 | '''
67 | build_ext.finalize_options(self)
68 | # Prevent numpy from thinking it is still in its setup process:
69 | set_builtin('__NUMPY_SETUP__', False)
70 | import numpy
71 | self.include_dirs.append(numpy.get_include())
72 |
73 |
74 | setup(
75 | name='pykdtree',
76 | version='1.3.1',
77 | description='Fast kd-tree implementation with OpenMP-enabled queries',
78 | author='Esben S. Nielsen',
79 | author_email='storpipfugl@gmail.com',
80 | packages = ['pykdtree'],
81 | python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*',
82 | install_requires=['numpy'],
83 | setup_requires=['numpy'],
84 | tests_require=['nose'],
85 | zip_safe=False,
86 | test_suite = 'nose.collector',
87 | ext_modules = [Extension('pykdtree.kdtree',
88 | ['pykdtree/kdtree.c', 'pykdtree/_kdtree_core.c'])],
89 | cmdclass = {'build_ext': build_ext_subclass },
90 | classifiers=[
91 | 'Development Status :: 5 - Production/Stable',
92 | ('License :: OSI Approved :: '
93 | 'GNU Lesser General Public License v3 (LGPLv3)'),
94 | 'Programming Language :: Python',
95 | 'Operating System :: OS Independent',
96 | 'Intended Audience :: Science/Research',
97 | 'Topic :: Scientific/Engineering'
98 | ]
99 | )
100 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/.gitignore:
--------------------------------------------------------------------------------
1 | PyMCubes.egg-info
2 | build
3 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2012-2015, P. M. Neila
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | * Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/README.rst:
--------------------------------------------------------------------------------
1 | ========
2 | PyMCubes
3 | ========
4 |
5 | PyMCubes is an implementation of the marching cubes algorithm to extract
6 | isosurfaces from volumetric data. The volumetric data can be given as a
7 | three-dimensional NumPy array or as a Python function ``f(x, y, z)``. The first
8 | option is much faster, but it requires more memory and becomes unfeasible for
9 | very large volumes.
10 |
11 | PyMCubes also provides a function to export the results of the marching cubes as
12 | COLLADA ``(.dae)`` files. This requires the
13 | `PyCollada <https://github.com/pycollada/pycollada>`_ library.
14 |
15 | Installation
16 | ============
17 |
18 | Just as any standard Python package, clone or download the project
19 | and run::
20 |
21 | $ cd path/to/PyMCubes
22 | $ python setup.py build
23 | $ python setup.py install
24 |
25 | If you do not have write permission on the directory of Python packages,
26 | install with the ``--user`` option::
27 |
28 | $ python setup.py install --user
29 |
30 | Example
31 | =======
32 |
33 | The following example creates a data volume with spherical isosurfaces and
34 | extracts one of them (i.e., a sphere) with PyMCubes. The result is exported as
35 | ``sphere.dae``::
36 |
37 | >>> import numpy as np
38 | >>> import mcubes
39 |
40 | # Create a data volume (30 x 30 x 30)
41 | >>> X, Y, Z = np.mgrid[:30, :30, :30]
42 | >>> u = (X-15)**2 + (Y-15)**2 + (Z-15)**2 - 8**2
43 |
44 | # Extract the 0-isosurface
45 | >>> vertices, triangles = mcubes.marching_cubes(u, 0)
46 |
47 | # Export the result to sphere.dae
48 | >>> mcubes.export_mesh(vertices, triangles, "sphere.dae", "MySphere")
49 |
50 | The second example is very similar to the first one, but it uses a function
51 | to represent the volume instead of a NumPy array::
52 |
53 | >>> import numpy as np
54 | >>> import mcubes
55 |
56 | # Create the volume
57 | >>> f = lambda x, y, z: x**2 + y**2 + z**2
58 |
59 | # Extract the 16-isosurface
60 | >>> vertices, triangles = mcubes.marching_cubes_func((-10,-10,-10), (10,10,10),
61 | ... 100, 100, 100, f, 16)
62 |
63 | # Export the result to sphere2.dae
64 | >>> mcubes.export_mesh(vertices, triangles, "sphere2.dae", "MySphere")
65 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/__init__.py:
--------------------------------------------------------------------------------
1 | from im2mesh.utils.libmcubes.mcubes import (
2 | marching_cubes, marching_cubes_func
3 | )
4 | from im2mesh.utils.libmcubes.exporter import (
5 | export_mesh, export_obj, export_off
6 | )
7 |
8 |
9 | __all__ = [
10 |     'marching_cubes', 'marching_cubes_func',
11 |     'export_mesh', 'export_obj', 'export_off'
12 | ]
13 |
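
An added end-to-end sketch, assuming the compiled mcubes extension is importable: extract the zero isosurface of a sphere-like field and write it with the re-exported OFF exporter.

    import numpy as np
    # from im2mesh.utils.libmcubes import marching_cubes, export_off

    X, Y, Z = np.mgrid[:32, :32, :32]
    u = (X - 16.0) ** 2 + (Y - 16.0) ** 2 + (Z - 16.0) ** 2 - 10.0 ** 2
    vertices, triangles = marching_cubes(u, 0.0)  # 0-isosurface of the field
    export_off(vertices, triangles, 'sphere.off')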
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmcubes/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/__pycache__/exporter.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmcubes/__pycache__/exporter.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/exporter.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 |
4 |
5 | def export_obj(vertices, triangles, filename):
6 | """
7 | Exports a mesh in the (.obj) format.
8 | """
9 |
10 | with open(filename, 'w') as fh:
11 |
12 | for v in vertices:
13 | fh.write("v {} {} {}\n".format(*v))
14 |
15 | for f in triangles:
16 | fh.write("f {} {} {}\n".format(*(f + 1)))
17 |
18 |
19 | def export_off(vertices, triangles, filename):
20 | """
21 | Exports a mesh in the (.off) format.
22 | """
23 |
24 | with open(filename, 'w') as fh:
25 | fh.write('OFF\n')
26 | fh.write('{} {} 0\n'.format(len(vertices), len(triangles)))
27 |
28 | for v in vertices:
29 | fh.write("{} {} {}\n".format(*v))
30 |
31 | for f in triangles:
32 | fh.write("3 {} {} {}\n".format(*f))
33 |
34 |
35 | def export_mesh(vertices, triangles, filename, mesh_name="mcubes_mesh"):
36 | """
37 | Exports a mesh in the COLLADA (.dae) format.
38 |
39 | Needs PyCollada (https://github.com/pycollada/pycollada).
40 | """
41 |
42 | import collada
43 |
44 | mesh = collada.Collada()
45 |
46 | vert_src = collada.source.FloatSource("verts-array", vertices, ('X','Y','Z'))
47 | geom = collada.geometry.Geometry(mesh, "geometry0", mesh_name, [vert_src])
48 |
49 | input_list = collada.source.InputList()
50 | input_list.addInput(0, 'VERTEX', "#verts-array")
51 |
52 | triset = geom.createTriangleSet(np.copy(triangles), input_list, "")
53 | geom.primitives.append(triset)
54 | mesh.geometries.append(geom)
55 |
56 | geomnode = collada.scene.GeometryNode(geom, [])
57 | node = collada.scene.Node(mesh_name, children=[geomnode])
58 |
59 | myscene = collada.scene.Scene("mcubes_scene", [node])
60 | mesh.scenes.append(myscene)
61 | mesh.scene = myscene
62 |
63 | mesh.write(filename)
64 |
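A small sketch exercising the exporters above on a toy one-triangle mesh; `export_obj` writes 1-indexed faces and `export_off` 0-indexed ones, as the code shows, and the COLLADA path additionally needs PyCollada:

    import numpy as np
    from im2mesh.utils.libmcubes.exporter import export_obj, export_off

    # One triangle in the xy-plane
    vertices = np.array([[0.0, 0.0, 0.0],
                         [1.0, 0.0, 0.0],
                         [0.0, 1.0, 0.0]])
    triangles = np.array([[0, 1, 2]])

    export_obj(vertices, triangles, 'tri.obj')
    export_off(vertices, triangles, 'tri.off')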
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/mcubes.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmcubes/mcubes.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/mcubes.pyx:
--------------------------------------------------------------------------------
1 |
2 | # distutils: language = c++
3 | # cython: embedsignature = True
4 |
5 | # from libcpp.vector cimport vector
6 | import numpy as np
7 |
8 | # Define PY_ARRAY_UNIQUE_SYMBOL
9 | cdef extern from "pyarray_symbol.h":
10 | pass
11 |
12 | cimport numpy as np
13 |
14 | np.import_array()
15 |
16 | cdef extern from "pywrapper.h":
17 | cdef object c_marching_cubes "marching_cubes"(np.ndarray, double) except +
18 | cdef object c_marching_cubes2 "marching_cubes2"(np.ndarray, double) except +
19 | cdef object c_marching_cubes3 "marching_cubes3"(np.ndarray, double) except +
20 | cdef object c_marching_cubes_func "marching_cubes_func"(tuple, tuple, int, int, int, object, double) except +
21 |
22 | def marching_cubes(np.ndarray volume, float isovalue):
23 |
24 | verts, faces = c_marching_cubes(volume, isovalue)
25 | verts.shape = (-1, 3)
26 | faces.shape = (-1, 3)
27 | return verts, faces
28 |
29 | def marching_cubes2(np.ndarray volume, float isovalue):
30 |
31 | verts, faces = c_marching_cubes2(volume, isovalue)
32 | verts.shape = (-1, 3)
33 | faces.shape = (-1, 3)
34 | return verts, faces
35 |
36 | def marching_cubes3(np.ndarray volume, float isovalue):
37 |
38 | verts, faces = c_marching_cubes3(volume, isovalue)
39 | verts.shape = (-1, 3)
40 | faces.shape = (-1, 3)
41 | return verts, faces
42 |
43 | def marching_cubes_func(tuple lower, tuple upper, int numx, int numy, int numz, object f, double isovalue):
44 |
45 | verts, faces = c_marching_cubes_func(lower, upper, numx, numy, numz, f, isovalue)
46 | verts.shape = (-1, 3)
47 | faces.shape = (-1, 3)
48 | return verts, faces
49 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/pyarray_symbol.h:
--------------------------------------------------------------------------------
1 |
2 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API
3 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/pyarraymodule.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef _EXTMODULE_H
3 | #define _EXTMODULE_H
4 |
5 | #include <Python.h>
6 | #include <stdexcept>
7 |
8 | // #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
9 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API
10 | #define NO_IMPORT_ARRAY
11 | #include "numpy/arrayobject.h"
12 |
13 | #include <complex>
14 |
15 | template<class T>
16 | struct numpy_typemap;
17 |
18 | #define define_numpy_type(ctype, dtype) \
19 |     template<> \
20 |     struct numpy_typemap<ctype> \
21 |     {static const int type = dtype;};
22 |
23 | define_numpy_type(bool, NPY_BOOL);
24 | define_numpy_type(char, NPY_BYTE);
25 | define_numpy_type(short, NPY_SHORT);
26 | define_numpy_type(int, NPY_INT);
27 | define_numpy_type(long, NPY_LONG);
28 | define_numpy_type(long long, NPY_LONGLONG);
29 | define_numpy_type(unsigned char, NPY_UBYTE);
30 | define_numpy_type(unsigned short, NPY_USHORT);
31 | define_numpy_type(unsigned int, NPY_UINT);
32 | define_numpy_type(unsigned long, NPY_ULONG);
33 | define_numpy_type(unsigned long long, NPY_ULONGLONG);
34 | define_numpy_type(float, NPY_FLOAT);
35 | define_numpy_type(double, NPY_DOUBLE);
36 | define_numpy_type(long double, NPY_LONGDOUBLE);
37 | define_numpy_type(std::complex<float>, NPY_CFLOAT);
38 | define_numpy_type(std::complex<double>, NPY_CDOUBLE);
39 | define_numpy_type(std::complex<long double>, NPY_CLONGDOUBLE);
40 |
41 | template<class T>
42 | T PyArray_SafeGet(const PyArrayObject* aobj, const npy_intp* indaux)
43 | {
44 |     // HORROR.
45 |     npy_intp* ind = const_cast<npy_intp*>(indaux);
46 |     void* ptr = PyArray_GetPtr(const_cast<PyArrayObject*>(aobj), ind);
47 |     switch(PyArray_TYPE(aobj))
48 |     {
49 |     case NPY_BOOL:
50 |         return static_cast<T>(*reinterpret_cast<bool*>(ptr));
51 |     case NPY_BYTE:
52 |         return static_cast<T>(*reinterpret_cast<char*>(ptr));
53 |     case NPY_SHORT:
54 |         return static_cast<T>(*reinterpret_cast<short*>(ptr));
55 |     case NPY_INT:
56 |         return static_cast<T>(*reinterpret_cast<int*>(ptr));
57 |     case NPY_LONG:
58 |         return static_cast<T>(*reinterpret_cast<long*>(ptr));
59 |     case NPY_LONGLONG:
60 |         return static_cast<T>(*reinterpret_cast<long long*>(ptr));
61 |     case NPY_UBYTE:
62 |         return static_cast<T>(*reinterpret_cast<unsigned char*>(ptr));
63 |     case NPY_USHORT:
64 |         return static_cast<T>(*reinterpret_cast<unsigned short*>(ptr));
65 |     case NPY_UINT:
66 |         return static_cast<T>(*reinterpret_cast<unsigned int*>(ptr));
67 |     case NPY_ULONG:
68 |         return static_cast<T>(*reinterpret_cast<unsigned long*>(ptr));
69 |     case NPY_ULONGLONG:
70 |         return static_cast<T>(*reinterpret_cast<unsigned long long*>(ptr));
71 |     case NPY_FLOAT:
72 |         return static_cast<T>(*reinterpret_cast<float*>(ptr));
73 |     case NPY_DOUBLE:
74 |         return static_cast<T>(*reinterpret_cast<double*>(ptr));
75 |     case NPY_LONGDOUBLE:
76 |         return static_cast<T>(*reinterpret_cast<long double*>(ptr));
77 |     default:
78 |         throw std::runtime_error("data type not supported");
79 |     }
80 | }
81 |
82 | template<class T>
83 | void PyArray_SafeSet(PyArrayObject* aobj, const npy_intp* indaux, const T& value)
84 | {
85 |     // HORROR.
86 |     npy_intp* ind = const_cast<npy_intp*>(indaux);
87 |     void* ptr = PyArray_GetPtr(aobj, ind);
88 |     switch(PyArray_TYPE(aobj))
89 |     {
90 |     case NPY_BOOL:
91 |         *reinterpret_cast<bool*>(ptr) = static_cast<bool>(value);
92 |         break;
93 |     case NPY_BYTE:
94 |         *reinterpret_cast<char*>(ptr) = static_cast<char>(value);
95 |         break;
96 |     case NPY_SHORT:
97 |         *reinterpret_cast<short*>(ptr) = static_cast<short>(value);
98 |         break;
99 |     case NPY_INT:
100 |         *reinterpret_cast<int*>(ptr) = static_cast<int>(value);
101 |         break;
102 |     case NPY_LONG:
103 |         *reinterpret_cast<long*>(ptr) = static_cast<long>(value);
104 |         break;
105 |     case NPY_LONGLONG:
106 |         *reinterpret_cast<long long*>(ptr) = static_cast<long long>(value);
107 |         break;
108 |     case NPY_UBYTE:
109 |         *reinterpret_cast<unsigned char*>(ptr) = static_cast<unsigned char>(value);
110 |         break;
111 |     case NPY_USHORT:
112 |         *reinterpret_cast<unsigned short*>(ptr) = static_cast<unsigned short>(value);
113 |         break;
114 |     case NPY_UINT:
115 |         *reinterpret_cast<unsigned int*>(ptr) = static_cast<unsigned int>(value);
116 |         break;
117 |     case NPY_ULONG:
118 |         *reinterpret_cast<unsigned long*>(ptr) = static_cast<unsigned long>(value);
119 |         break;
120 |     case NPY_ULONGLONG:
121 |         *reinterpret_cast<unsigned long long*>(ptr) = static_cast<unsigned long long>(value);
122 |         break;
123 |     case NPY_FLOAT:
124 |         *reinterpret_cast<float*>(ptr) = static_cast<float>(value);
125 |         break;
126 |     case NPY_DOUBLE:
127 |         *reinterpret_cast<double*>(ptr) = static_cast<double>(value);
128 |         break;
129 |     case NPY_LONGDOUBLE:
130 |         *reinterpret_cast<long double*>(ptr) = static_cast<long double>(value);
131 | break;
132 | default:
133 | throw std::runtime_error("data type not supported");
134 | }
135 | }
136 |
137 | #endif
138 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/pywrapper.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef _PYWRAPPER_H
3 | #define _PYWRAPPER_H
4 |
5 | #include <Python.h>
6 | #include "pyarraymodule.h"
7 |
8 | #include <vector>
9 |
10 | PyObject* marching_cubes(PyArrayObject* arr, double isovalue);
11 | PyObject* marching_cubes2(PyArrayObject* arr, double isovalue);
12 | PyObject* marching_cubes3(PyArrayObject* arr, double isovalue);
13 | PyObject* marching_cubes_func(PyObject* lower, PyObject* upper,
14 | int numx, int numy, int numz, PyObject* f, double isovalue);
15 |
16 | #endif // _PYWRAPPER_H
17 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmcubes/setup.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 |
3 | try:
4 | from setuptools import setup
5 | except ImportError:
6 | from distutils.core import setup
7 |
8 | from Cython.Build import cythonize
9 |
10 | import numpy
11 | from distutils.extension import Extension
12 |
13 | # Get the numpy include directory.
14 | numpy_include_dir = numpy.get_include()
15 |
16 | mcubes_module = Extension(
17 | "mcubes",
18 | [
19 | "mcubes.pyx",
20 | "pywrapper.cpp",
21 | "marchingcubes.cpp"
22 | ],
23 | language="c++",
24 | extra_compile_args=['-std=c++11'],
25 | include_dirs=[numpy_include_dir]
26 | )
27 |
28 | setup(name="PyMCubes",
29 | version="0.0.6",
30 | description="Marching cubes for Python",
31 | author="Pablo Márquez Neila",
32 | author_email="pablo.marquezneila@epfl.ch",
33 | url="https://github.com/pmneila/PyMCubes",
34 | license="BSD 3-clause",
35 | long_description="""
36 | Marching cubes for Python
37 | """,
38 | classifiers=[
39 | "Development Status :: 4 - Beta",
40 | "Environment :: Console",
41 | "Intended Audience :: Developers",
42 | "Intended Audience :: Science/Research",
43 | "License :: OSI Approved :: BSD License",
44 | "Natural Language :: English",
45 | "Operating System :: OS Independent",
46 | "Programming Language :: C++",
47 | "Programming Language :: Python",
48 | "Topic :: Multimedia :: Graphics :: 3D Modeling",
49 | "Topic :: Scientific/Engineering :: Image Recognition",
50 | ],
51 | packages=["mcubes"],
52 | ext_modules=cythonize(mcubes_module),
53 | requires=['numpy', 'Cython', 'PyCollada'],
54 | setup_requires=['numpy', 'Cython']
55 | )
56 |
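As with the other Cython extensions under im2mesh/utils, this module is typically compiled in place with the standard setuptools invocation `python setup.py build_ext --inplace`; that is generic setuptools usage, not a command taken from this repository's documentation.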
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/.gitignore:
--------------------------------------------------------------------------------
1 | triangle_hash.cpp
2 | build
3 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/__init__.py:
--------------------------------------------------------------------------------
1 | from .inside_mesh import (
2 | check_mesh_contains, MeshIntersector, TriangleIntersector2d
3 | )
4 |
5 |
6 | __all__ = [
7 |     'check_mesh_contains', 'MeshIntersector', 'TriangleIntersector2d'
8 | ]
9 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/__init__.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/__pycache__/inside_mesh.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/__pycache__/inside_mesh.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/inside_mesh.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .triangle_hash import TriangleHash as _TriangleHash
3 |
4 |
5 | def check_mesh_contains(mesh, points, hash_resolution=512):
6 | intersector = MeshIntersector(mesh, hash_resolution)
7 | contains = intersector.query(points)
8 | return contains
9 |
10 |
11 | class MeshIntersector:
12 | def __init__(self, mesh, resolution=512):
13 | triangles = mesh.vertices[mesh.faces].astype(np.float64)
14 | n_tri = triangles.shape[0]
15 |
16 | self.resolution = resolution
17 | self.bbox_min = triangles.reshape(3 * n_tri, 3).min(axis=0)
18 | self.bbox_max = triangles.reshape(3 * n_tri, 3).max(axis=0)
19 |         # Translate and scale it to [0.5, self.resolution - 0.5]^3
20 | self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min)
21 | self.translate = 0.5 - self.scale * self.bbox_min
22 |
23 | self._triangles = triangles = self.rescale(triangles)
24 | # assert(np.allclose(triangles.reshape(-1, 3).min(0), 0.5))
25 | # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5))
26 |
27 | triangles2d = triangles[:, :, :2]
28 | self._tri_intersector2d = TriangleIntersector2d(
29 | triangles2d, resolution)
30 |
31 | def query(self, points):
32 | # Rescale points
33 | points = self.rescale(points)
34 |
35 | # placeholder result with no hits we'll fill in later
36 |         contains = np.zeros(len(points), dtype=bool)
37 |
38 | # cull points outside of the axis aligned bounding box
39 | # this avoids running ray tests unless points are close
40 | inside_aabb = np.all(
41 | (0 <= points) & (points <= self.resolution), axis=1)
42 | if not inside_aabb.any():
43 | return contains
44 |
45 | # Only consider points inside bounding box
46 | mask = inside_aabb
47 | points = points[mask]
48 |
49 | # Compute intersection depth and check order
50 | points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2])
51 |
52 | triangles_intersect = self._triangles[tri_indices]
53 | points_intersect = points[points_indices]
54 |
55 | depth_intersect, abs_n_2 = self.compute_intersection_depth(
56 | points_intersect, triangles_intersect)
57 |
58 | # Count number of intersections in both directions
59 | smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
60 | bigger_depth = depth_intersect < points_intersect[:, 2] * abs_n_2
61 | points_indices_0 = points_indices[smaller_depth]
62 | points_indices_1 = points_indices[bigger_depth]
63 |
64 | nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0])
65 | nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0])
66 |
67 | # Check if point contained in mesh
68 | contains1 = (np.mod(nintersect0, 2) == 1)
69 | contains2 = (np.mod(nintersect1, 2) == 1)
70 | if (contains1 != contains2).any():
71 | print('Warning: contains1 != contains2 for some points.')
72 | contains[mask] = (contains1 & contains2)
73 | return contains
74 |
75 | def compute_intersection_depth(self, points, triangles):
76 | t1 = triangles[:, 0, :]
77 | t2 = triangles[:, 1, :]
78 | t3 = triangles[:, 2, :]
79 |
80 | v1 = t3 - t1
81 | v2 = t2 - t1
82 | # v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
83 | # v2 = v2 / np.linalg.norm(v2, axis=-1, keepdims=True)
84 |
85 | normals = np.cross(v1, v2)
86 | alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1)
87 |
88 | n_2 = normals[:, 2]
89 | t1_2 = t1[:, 2]
90 | s_n_2 = np.sign(n_2)
91 | abs_n_2 = np.abs(n_2)
92 |
93 | mask = (abs_n_2 != 0)
94 |
95 | depth_intersect = np.full(points.shape[0], np.nan)
96 | depth_intersect[mask] = \
97 | t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask]
98 |
99 | # Test the depth:
100 | # TODO: remove and put into tests
101 | # points_new = np.concatenate([points[:, :2], depth_intersect[:, None]], axis=1)
102 | # alpha = (normals * t1).sum(-1)
103 | # mask = (depth_intersect == depth_intersect)
104 | # assert(np.allclose((points_new[mask] * normals[mask]).sum(-1),
105 | # alpha[mask]))
106 | return depth_intersect, abs_n_2
107 |
108 | def rescale(self, array):
109 | array = self.scale * array + self.translate
110 | return array
111 |
112 |
113 | class TriangleIntersector2d:
114 | def __init__(self, triangles, resolution=128):
115 | self.triangles = triangles
116 | self.tri_hash = _TriangleHash(triangles, resolution)
117 |
118 | def query(self, points):
119 | point_indices, tri_indices = self.tri_hash.query(points)
120 | point_indices = np.array(point_indices, dtype=np.int64)
121 | tri_indices = np.array(tri_indices, dtype=np.int64)
122 | points = points[point_indices]
123 | triangles = self.triangles[tri_indices]
124 | mask = self.check_triangles(points, triangles)
125 | point_indices = point_indices[mask]
126 | tri_indices = tri_indices[mask]
127 | return point_indices, tri_indices
128 |
129 | def check_triangles(self, points, triangles):
130 |         contains = np.zeros(points.shape[0], dtype=bool)
131 | A = triangles[:, :2] - triangles[:, 2:]
132 | A = A.transpose([0, 2, 1])
133 | y = points - triangles[:, 2]
134 |
135 | detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0]
136 |
137 | mask = (np.abs(detA) != 0.)
138 | A = A[mask]
139 | y = y[mask]
140 | detA = detA[mask]
141 |
142 | s_detA = np.sign(detA)
143 | abs_detA = np.abs(detA)
144 |
145 | u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA
146 | v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA
147 |
148 | sum_uv = u + v
149 | contains[mask] = (
150 | (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA)
151 | & (0 < sum_uv) & (sum_uv < abs_detA)
152 | )
153 | return contains
154 |
155 |
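A usage sketch for `check_mesh_contains`, assuming the compiled `triangle_hash` extension has been built (see setup.py below) and using `trimesh`, which this repository already imports elsewhere; the box primitive and point count are illustrative:

    import numpy as np
    import trimesh
    from im2mesh.utils.libmesh import check_mesh_contains

    # Unit cube centered at the origin
    mesh = trimesh.creation.box(extents=(1.0, 1.0, 1.0))

    # Random query points in [-1, 1]^3; roughly 1/8 should land inside
    points = np.random.uniform(-1.0, 1.0, size=(10000, 3))
    occupancy = check_mesh_contains(mesh, points)
    print('fraction inside:', occupancy.mean())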
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/inside_mesh.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/inside_mesh.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from distutils.extension import Extension
3 | from Cython.Build import cythonize
4 |
5 |
6 | ext_modules = [
7 | Extension("triangle_hash",
8 | sources=["triangle_hash.pyx"],
9 | libraries=["m"] # Unix-like specific
10 | )
11 | ]
12 |
13 | setup(
14 | ext_modules=cythonize(ext_modules)
15 | )
16 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/triangle_hash.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmesh/triangle_hash.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/im2mesh/utils/libmesh/triangle_hash.pyx:
--------------------------------------------------------------------------------
1 |
2 | # distutils: language=c++
3 | import numpy as np
4 | cimport numpy as np
5 | cimport cython
6 | from libcpp.vector cimport vector
7 | from libc.math cimport floor, ceil
8 |
9 | cdef class TriangleHash:
10 | cdef vector[vector[int]] spatial_hash
11 | cdef int resolution
12 |
13 | def __cinit__(self, double[:, :, :] triangles, int resolution):
14 | self.spatial_hash.resize(resolution * resolution)
15 | self.resolution = resolution
16 | self._build_hash(triangles)
17 |
18 | @cython.boundscheck(False) # Deactivate bounds checking
19 | @cython.wraparound(False) # Deactivate negative indexing.
20 | cdef int _build_hash(self, double[:, :, :] triangles):
21 | assert(triangles.shape[1] == 3)
22 | assert(triangles.shape[2] == 2)
23 |
24 | cdef int n_tri = triangles.shape[0]
25 | cdef int bbox_min[2]
26 | cdef int bbox_max[2]
27 |
28 | cdef int i_tri, j, x, y
29 | cdef int spatial_idx
30 |
31 | for i_tri in range(n_tri):
32 | # Compute bounding box
33 | for j in range(2):
34 | bbox_min[j] = min(
35 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
36 | )
37 | bbox_max[j] = max(
38 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
39 | )
40 | bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1)
41 | bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1)
42 |
43 | # Find all voxels where bounding box intersects
44 | for x in range(bbox_min[0], bbox_max[0] + 1):
45 | for y in range(bbox_min[1], bbox_max[1] + 1):
46 | spatial_idx = self.resolution * x + y
47 | self.spatial_hash[spatial_idx].push_back(i_tri)
48 |
49 | @cython.boundscheck(False) # Deactivate bounds checking
50 | @cython.wraparound(False) # Deactivate negative indexing.
51 | cpdef query(self, double[:, :] points):
52 | assert(points.shape[1] == 2)
53 | cdef int n_points = points.shape[0]
54 |
55 | cdef vector[int] points_indices
56 | cdef vector[int] tri_indices
57 | # cdef int[:] points_indices_np
58 | # cdef int[:] tri_indices_np
59 |
60 |         cdef int i_point, i_tri, k, x, y
61 | cdef int spatial_idx
62 |
63 | for i_point in range(n_points):
64 | x = int(points[i_point, 0])
65 | y = int(points[i_point, 1])
66 | if not (0 <= x < self.resolution and 0 <= y < self.resolution):
67 | continue
68 |
69 | spatial_idx = self.resolution * x + y
70 | for i_tri in self.spatial_hash[spatial_idx]:
71 | points_indices.push_back(i_point)
72 | tri_indices.push_back(i_tri)
73 |
74 | points_indices_np = np.zeros(points_indices.size(), dtype=np.int32)
75 | tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32)
76 |
77 | cdef int[:] points_indices_view = points_indices_np
78 | cdef int[:] tri_indices_view = tri_indices_np
79 |
80 | for k in range(points_indices.size()):
81 | points_indices_view[k] = points_indices[k]
82 |
83 | for k in range(tri_indices.size()):
84 | tri_indices_view[k] = tri_indices[k]
85 |
86 | return points_indices_np, tri_indices_np
87 |
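TriangleHash registers each triangle in every grid cell overlapped by its 2D bounding box, so a query only has to test triangles that share a cell with the point. A pure-Python sketch of the same indexing scheme (illustrative only; the Cython class above is what the library actually builds):

    import numpy as np

    def build_hash(triangles2d, resolution):
        # One bucket per (x, y) cell; clamp bounding boxes to the grid
        buckets = [[] for _ in range(resolution * resolution)]
        for i, tri in enumerate(triangles2d):
            lo = np.clip(tri.min(axis=0).astype(int), 0, resolution - 1)
            hi = np.clip(tri.max(axis=0).astype(int), 0, resolution - 1)
            for x in range(lo[0], hi[0] + 1):
                for y in range(lo[1], hi[1] + 1):
                    buckets[resolution * x + y].append(i)
        return buckets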
--------------------------------------------------------------------------------
/im2mesh/utils/libmise/.gitignore:
--------------------------------------------------------------------------------
1 | mise.c
2 | mise.cpp
3 | mise.html
4 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmise/__init__.py:
--------------------------------------------------------------------------------
1 | from .mise import MISE
2 |
3 |
4 | __all__ = [
5 |     'MISE'
6 | ]
7 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmise/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmise/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libmise/mise.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libmise/mise.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/im2mesh/utils/libmise/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from distutils.extension import Extension
3 | from Cython.Build import cythonize
4 |
5 |
6 | ext_modules = [
7 | Extension("mise",
8 | sources=["mise.pyx"],
9 | )
10 | ]
11 |
12 | setup(
13 | ext_modules=cythonize(ext_modules)
14 | )
15 |
--------------------------------------------------------------------------------
/im2mesh/utils/libmise/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mise import MISE
3 | import time
4 |
5 | t0 = time.time()
6 | extractor = MISE(1, 2, 0.)
7 |
8 | p = extractor.query()
9 | i = 0
10 |
11 | while p.shape[0] != 0:
12 | print(i)
13 | print(p)
14 | v = 2 * (p.sum(axis=-1) > 2).astype(np.float64) - 1
15 | extractor.update(p, v)
16 | p = extractor.query()
17 | i += 1
18 | if (i >= 8):
19 | break
20 |
21 | print(extractor.to_dense())
22 | # p, v = extractor.get_points()
23 | # print(p)
24 | # print(v)
25 | print('Total time: %f' % (time.time() - t0))
26 |
--------------------------------------------------------------------------------
/im2mesh/utils/libsimplify/__init__.py:
--------------------------------------------------------------------------------
1 | from .simplify_mesh import (
2 | mesh_simplify
3 | )
4 | import trimesh
5 |
6 |
7 | def simplify_mesh(mesh, f_target=10000, agressiveness=7.):
8 | vertices = mesh.vertices
9 | faces = mesh.faces
10 |
11 | vertices, faces = mesh_simplify(vertices, faces, f_target, agressiveness)
12 |
13 | mesh_simplified = trimesh.Trimesh(vertices, faces, process=False)
14 |
15 | return mesh_simplified
16 |
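A usage sketch for `simplify_mesh`, assuming the compiled `simplify_mesh` extension is available; the icosphere and face target are illustrative:

    import trimesh
    from im2mesh.utils.libsimplify import simplify_mesh

    # Dense icosphere (~5k faces) reduced to roughly 1000 faces
    mesh = trimesh.creation.icosphere(subdivisions=4)
    simplified = simplify_mesh(mesh, f_target=1000)
    print(len(mesh.faces), '->', len(simplified.faces))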
--------------------------------------------------------------------------------
/im2mesh/utils/libsimplify/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libsimplify/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/im2mesh/utils/libsimplify/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from distutils.extension import Extension
3 | from Cython.Build import cythonize
4 |
5 |
6 | ext_modules = [
7 | Extension("simplify_mesh",
8 | sources=["simplify_mesh.pyx"]
9 | )
10 | ]
11 |
12 | setup(
13 | ext_modules=cythonize(ext_modules)
14 | )
15 |
--------------------------------------------------------------------------------
/im2mesh/utils/libsimplify/simplify_mesh.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libsimplify/simplify_mesh.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/im2mesh/utils/libsimplify/simplify_mesh.pyx:
--------------------------------------------------------------------------------
1 | # distutils: language = c++
2 | from libcpp.vector cimport vector
3 | import numpy as np
4 | cimport numpy as np
5 |
6 |
7 | cdef extern from "Simplify.h":
8 | cdef struct vec3f:
9 | double x, y, z
10 |
11 | cdef cppclass SymetricMatrix:
12 | SymetricMatrix() except +
13 |
14 |
15 | cdef extern from "Simplify.h" namespace "Simplify":
16 | cdef struct Triangle:
17 | int v[3]
18 | double err[4]
19 | int deleted, dirty, attr
20 | vec3f uvs[3]
21 | int material
22 |
23 | cdef struct Vertex:
24 | vec3f p
25 | int tstart, tcount
26 | SymetricMatrix q
27 | int border
28 |
29 | cdef vector[Triangle] triangles
30 | cdef vector[Vertex] vertices
31 | cdef void simplify_mesh(int, double)
32 |
33 |
34 | cpdef mesh_simplify(double[:, ::1] vertices_in, long[:, ::1] triangles_in,
35 | int f_target, double agressiveness=7.) except +:
36 | vertices.clear()
37 | triangles.clear()
38 |
39 | # Read in vertices and triangles
40 | cdef Vertex v
41 | for iv in range(vertices_in.shape[0]):
42 | v = Vertex()
43 | v.p.x = vertices_in[iv, 0]
44 | v.p.y = vertices_in[iv, 1]
45 | v.p.z = vertices_in[iv, 2]
46 | vertices.push_back(v)
47 |
48 | cdef Triangle t
49 | for it in range(triangles_in.shape[0]):
50 | t = Triangle()
51 | t.v[0] = triangles_in[it, 0]
52 | t.v[1] = triangles_in[it, 1]
53 | t.v[2] = triangles_in[it, 2]
54 | triangles.push_back(t)
55 |
56 | # Simplify
57 | # print('Simplify...')
58 | simplify_mesh(f_target, agressiveness)
59 |
60 | # Only use triangles that are not deleted
61 | cdef vector[Triangle] triangles_notdel
62 | triangles_notdel.reserve(triangles.size())
63 |
64 | for t in triangles:
65 | if not t.deleted:
66 | triangles_notdel.push_back(t)
67 |
68 | # Read out triangles
69 | vertices_out = np.empty((vertices.size(), 3), dtype=np.float64)
70 | triangles_out = np.empty((triangles_notdel.size(), 3), dtype=np.int64)
71 |
72 | cdef double[:, :] vertices_out_view = vertices_out
73 | cdef long[:, :] triangles_out_view = triangles_out
74 |
75 | for iv in range(vertices.size()):
76 | vertices_out_view[iv, 0] = vertices[iv].p.x
77 | vertices_out_view[iv, 1] = vertices[iv].p.y
78 | vertices_out_view[iv, 2] = vertices[iv].p.z
79 |
80 | for it in range(triangles_notdel.size()):
81 | triangles_out_view[it, 0] = triangles_notdel[it].v[0]
82 | triangles_out_view[it, 1] = triangles_notdel[it].v[1]
83 | triangles_out_view[it, 2] = triangles_notdel[it].v[2]
84 |
85 | # Clear vertices and triangles
86 | vertices.clear()
87 | triangles.clear()
88 |
89 | return vertices_out, triangles_out
--------------------------------------------------------------------------------
/im2mesh/utils/libsimplify/test.py:
--------------------------------------------------------------------------------
1 | from simplify_mesh import mesh_simplify
2 | import numpy as np
3 |
4 | v = np.random.rand(100, 3)
5 | f = np.random.choice(range(100), (50, 3))
6 |
7 | mesh_simplify(v, f, 50)
--------------------------------------------------------------------------------
/im2mesh/utils/libvoxelize/.gitignore:
--------------------------------------------------------------------------------
1 | voxelize.c
2 | voxelize.html
3 | build
4 |
--------------------------------------------------------------------------------
/im2mesh/utils/libvoxelize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libvoxelize/__init__.py
--------------------------------------------------------------------------------
/im2mesh/utils/libvoxelize/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from distutils.extension import Extension
3 | from Cython.Build import cythonize
4 |
5 | ext_modules = [
6 | Extension("voxelize",
7 | sources=["voxelize.pyx"],
8 | libraries=["m"] # Unix-like specific
9 | )
10 | ]
11 |
12 | setup(
13 | ext_modules=cythonize(ext_modules)
14 | )
15 |
--------------------------------------------------------------------------------
/im2mesh/utils/libvoxelize/voxelize.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/libvoxelize/voxelize.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/im2mesh/utils/libvoxelize/voxelize.pyx:
--------------------------------------------------------------------------------
1 | cimport cython
2 | from libc.math cimport floor, ceil
3 | from cython.view cimport array as cvarray
4 |
5 | cdef extern from "tribox2.h":
6 | int triBoxOverlap(float boxcenter[3], float boxhalfsize[3],
7 | float tri0[3], float tri1[3], float tri2[3])
8 |
9 |
10 | @cython.boundscheck(False) # Deactivate bounds checking
11 | @cython.wraparound(False) # Deactivate negative indexing.
12 | cpdef int voxelize_mesh_(bint[:, :, :] occ, float[:, :, ::1] faces):
13 | assert(faces.shape[1] == 3)
14 | assert(faces.shape[2] == 3)
15 |
16 | n_faces = faces.shape[0]
17 | cdef int i
18 | for i in range(n_faces):
19 | voxelize_triangle_(occ, faces[i])
20 |
21 |
22 | @cython.boundscheck(False) # Deactivate bounds checking
23 | @cython.wraparound(False) # Deactivate negative indexing.
24 | cpdef int voxelize_triangle_(bint[:, :, :] occupancies, float[:, ::1] triverts):
25 | cdef int bbox_min[3]
26 | cdef int bbox_max[3]
27 | cdef int i, j, k
28 | cdef float boxhalfsize[3]
29 | cdef float boxcenter[3]
30 | cdef bint intersection
31 |
32 | boxhalfsize[:] = (0.5, 0.5, 0.5)
33 |
34 | for i in range(3):
35 | bbox_min[i] = (
36 | min(triverts[0, i], triverts[1, i], triverts[2, i])
37 | )
38 | bbox_min[i] = min(max(bbox_min[i], 0), occupancies.shape[i] - 1)
39 |
40 | for i in range(3):
41 | bbox_max[i] = (
42 | max(triverts[0, i], triverts[1, i], triverts[2, i])
43 | )
44 | bbox_max[i] = min(max(bbox_max[i], 0), occupancies.shape[i] - 1)
45 |
46 | for i in range(bbox_min[0], bbox_max[0] + 1):
47 | for j in range(bbox_min[1], bbox_max[1] + 1):
48 | for k in range(bbox_min[2], bbox_max[2] + 1):
49 | boxcenter[:] = (i + 0.5, j + 0.5, k + 0.5)
50 | intersection = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],
51 | &triverts[0, 0], &triverts[1, 0], &triverts[2, 0])
52 | occupancies[i, j, k] |= intersection
53 |
54 |
55 | @cython.boundscheck(False) # Deactivate bounds checking
56 | @cython.wraparound(False) # Deactivate negative indexing.
57 | cdef int test_triangle_aabb(float[::1] boxcenter, float[::1] boxhalfsize, float[:, ::1] triverts):
58 | assert(boxcenter.shape[0] == 3)
59 | assert(boxhalfsize.shape[0] == 3)
60 | assert(triverts.shape[0] == triverts.shape[1] == 3)
61 |
62 | # print(triverts)
63 | # Call functions
64 | cdef int result = triBoxOverlap(&boxcenter[0], &boxhalfsize[0],
65 | &triverts[0, 0], &triverts[1, 0], &triverts[2, 0])
66 | return result
67 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info/
2 | .installed.cfg
3 | *.egg
4 | *__pycache__*
5 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Ricky Tian Qi Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Implementation of Differentiable ODE Solvers
2 |
3 | This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through all solvers is supported using the adjoint method. For usage of ODE solvers in deep learning applications, see [1].
4 |
5 | As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU.
6 |
7 | ---
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## Installation
15 | ```
16 | git clone https://github.com/rtqichen/torchdiffeq.git
17 | cd torchdiffeq
18 | pip install -e .
19 | ```
20 |
21 | ## Examples
22 | Examples are placed in the [`examples`](./examples) directory.
23 |
24 | We encourage those who are interested in using this library to take a look at [`examples/ode_demo.py`](./examples/ode_demo.py) for understanding how to use `torchdiffeq` to fit a simple spiral ODE.
25 |
26 |
27 |
28 |
29 |
30 | ## Basic usage
31 | This library provides one main interface `odeint` which contains general-purpose algorithms for solving initial value problems (IVP), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value,
32 | ```
33 | dy/dt = f(t, y),    y(t_0) = y_0.
34 | ```
35 | The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition.
36 |
37 | To solve an IVP using the default solver:
38 | ```
39 | from torchdiffeq import odeint
40 |
41 | odeint(func, y0, t)
42 | ```
43 | where `func` is any callable implementing the ordinary differential equation `f(t, y)`, `y0` is an _any_-D Tensor or a tuple of _any_-D Tensors representing the initial values, and `t` is a 1-D Tensor containing the evaluation points. The initial time is taken to be `t[0]`.
44 |
45 | Backpropagation through `odeint` goes through the internals of the solver, but this is not supported for all solvers. Instead, we encourage the use of the adjoint method explained in [1], which will allow solving with as many steps as necessary due to O(1) memory usage.
46 |
47 | To use the adjoint method:
48 | ```
49 | from torchdiffeq import odeint_adjoint as odeint
50 |
51 | odeint(func, y0, t)
52 | ```
53 | `odeint_adjoint` simply wraps around `odeint`, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call.
54 |
55 | The biggest **gotcha** is that `func` must be a `nn.Module` when using the adjoint method. This is used to collect parameters of the differential equation.
56 |
57 | ### Keyword Arguments
58 | - `rtol` Relative tolerance.
59 | - `atol` Absolute tolerance.
60 | - `method` One of the solvers listed below.
61 |
62 | #### List of ODE Solvers:
63 |
64 | Adaptive-step:
65 | - `dopri5` Runge-Kutta 4(5) [default].
66 | - `adams` Adaptive-order implicit Adams.
67 |
68 | Fixed-step:
69 | - `euler` Euler method.
70 | - `midpoint` Midpoint method.
71 | - `rk4` Fourth-order Runge-Kutta with 3/8 rule.
72 | - `explicit_adams` Explicit Adams.
73 | - `fixed_adams` Implicit Adams.
74 |
75 | ### References
76 | [1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018. [[arxiv]](https://arxiv.org/abs/1806.07366)
77 |
78 | ---
79 |
80 | If you found this library useful in your research, please consider citing
81 | ```
82 | @article{chen2018neural,
83 | title={Neural Ordinary Differential Equations},
84 | author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David},
85 | journal={Advances in Neural Information Processing Systems},
86 | year={2018}
87 | }
88 | ```
89 |
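A self-contained sketch of the interface described above, solving dy/dt = -y from y(0) = 1 and comparing against the exact solution exp(-t); the solver choice and tolerances are illustrative:

    import torch
    from torchdiffeq import odeint

    # Any callable f(t, y) works for the non-adjoint interface
    func = lambda t, y: -y

    y0 = torch.tensor([1.0])
    t = torch.linspace(0.0, 2.0, 20)

    ys = odeint(func, y0, t, rtol=1e-6, atol=1e-8, method='dopri5')
    print(torch.max(torch.abs(ys.squeeze() - torch.exp(-t))))  # should be tiny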
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/assets/ode_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/torchdiffeq/assets/ode_demo.gif
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/assets/odenet_0_viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/torchdiffeq/assets/odenet_0_viz.png
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/assets/resnet_0_viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/LPDC-Net/ba69083cd5343348d377dfbb163d2c15dd6516ab/im2mesh/utils/torchdiffeq/assets/resnet_0_viz.png
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/examples/README.md:
--------------------------------------------------------------------------------
1 | # Overview of Examples
2 |
3 | This `examples` directory contains cleaned up code regarding the usage of adaptive ODE solvers in machine learning. The scripts in this directory assume that `torchdiffeq` is installed following instructions from the main directory.
4 |
5 | ## Demo
6 | The `ode_demo.py` file contains a short implementation of learning a dynamics model to mimic a spiral ODE.
7 |
8 | To visualize the training progress, run
9 | ```
10 | python ode_demo.py --viz
11 | ```
12 | The training should look similar to this:
13 |
14 | ![ode_demo](../assets/ode_demo.gif)
15 |
16 |
17 |
18 | ## ODEnet for MNIST
19 | The `odenet_mnist.py` file contains a reproduction of the MNIST experiments in our Neural ODE paper. Notably, not just the architecture but also the ODE solver library and integration method differ from our original experiments, though the results are similar to those we report in the paper.
20 |
21 | We can use an adaptive ODE solver to approximate our continuous-depth network while still backpropagating through the network.
22 | ```
23 | python odenet_mnist.py --network odenet
24 | ```
25 | However, the memory requirements for this will blow up very fast, especially for more complex problems where the number of function evaluations can reach nearly a thousand.
26 |
27 | For applications that require solving complex trajectories, we recommend using the adjoint method.
28 | ```
29 | python odenet_mnist.py --network odenet --adjoint True
30 | ```
31 | The adjoint method can be slower when using an adaptive ODE solver as it involves another solve in the backward pass with a much larger system, so experimenting on small systems with direct backpropagation first is recommended.
32 |
33 | Thankfully, it is extremely easy to write code for both adjoint and non-adjoint backpropagation, as they use the same interface.
34 | ```
35 | if adjoint:
36 | from torchdiffeq import odeint_adjoint as odeint
37 | else:
38 | from torchdiffeq import odeint
39 | ```
40 | The main gotcha is that `odeint_adjoint` requires implementing the dynamics network as a `nn.Module` while `odeint` can work with any callable in Python.
41 |
42 | ## Continuous Normalizing Flows
43 | Code for continuous normalizing flows (CNF) has its own public repository. Tools for training, evaluating, and visualizing CNFs for reversible generative modeling are provided along with FFJORD, a linear-cost stochastic approximation of CNFs.
44 |
45 | Find the code in https://github.com/rtqichen/ffjord. This code contains some advanced tricks for `torchdiffeq`.
46 |
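Since `odeint_adjoint` requires the dynamics to be an `nn.Module` (the gotcha noted above), a minimal sketch of that pattern, modeled on the test problems elsewhere in this package; the linear dynamics are illustrative:

    import torch
    from torchdiffeq import odeint_adjoint as odeint

    class Dynamics(torch.nn.Module):
        # dy/dt = y A; parameters are collected because this is an nn.Module
        def __init__(self):
            super().__init__()
            self.A = torch.nn.Parameter(torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]))

        def forward(self, t, y):
            return y @ self.A

    func = Dynamics()
    y0 = torch.tensor([[2.0, 0.0]], requires_grad=True)
    t = torch.linspace(0.0, 1.0, 10)

    ys = odeint(func, y0, t)
    ys.sum().backward()  # gradients reach func.A and y0 via the adjoint ODE
    print(func.A.grad)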
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup(
4 | name="torchdiffeq",
5 | version="0.0.1",
6 | author="Ricky Tian Qi Chen",
7 | author_email="rtqichen@cs.toronto.edu",
8 | description="ODE solvers and adjoint sensitivity analysis in PyTorch.",
9 | url="https://github.com/rtqichen/torchdiffeq",
10 | packages=['torchdiffeq', 'torchdiffeq._impl'],
11 | install_requires=['torch>=0.4.1'],
12 | classifiers=(
13 | "Programming Language :: Python :: 3"),)
14 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/tests/api_tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchdiffeq
4 |
5 | from problems import construct_problem
6 |
7 | eps = 1e-12
8 |
9 | torch.set_default_dtype(torch.float64)
10 | TEST_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11 |
12 |
13 | def max_abs(tensor):
14 | return torch.max(torch.abs(tensor))
15 |
16 |
17 | class TestCollectionState(unittest.TestCase):
18 |
19 | def test_dopri5(self):
20 | f, y0, t_points, sol = construct_problem(TEST_DEVICE)
21 |
22 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
23 | tuple_y0 = (y0, y0)
24 |
25 | tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='dopri5')
26 | max_error0 = (sol - tuple_y[0]).max()
27 | max_error1 = (sol - tuple_y[1]).max()
28 | self.assertLess(max_error0, eps)
29 | self.assertLess(max_error1, eps)
30 |
31 | def test_dopri5_gradient(self):
32 | f, y0, t_points, sol = construct_problem(TEST_DEVICE)
33 |
34 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
35 |
36 | for i in range(2):
37 | func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='dopri5')[i]
38 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
39 |
40 | def test_adams(self):
41 | f, y0, t_points, sol = construct_problem(TEST_DEVICE)
42 |
43 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
44 | tuple_y0 = (y0, y0)
45 |
46 | tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='adams')
47 | max_error0 = (sol - tuple_y[0]).max()
48 | max_error1 = (sol - tuple_y[1]).max()
49 | self.assertLess(max_error0, eps)
50 | self.assertLess(max_error1, eps)
51 |
52 | def test_adams_gradient(self):
53 | f, y0, t_points, sol = construct_problem(TEST_DEVICE)
54 |
55 | tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
56 |
57 | for i in range(2):
58 | func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='adams')[i]
59 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
60 |
61 |
62 | if __name__ == '__main__':
63 | unittest.main()
64 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/tests/gradient_tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchdiffeq
4 |
5 | from problems import construct_problem
6 |
7 | eps = 1e-12
8 |
9 | torch.set_default_dtype(torch.float64)
10 | TEST_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11 |
12 |
13 | def max_abs(tensor):
14 | return torch.max(torch.abs(tensor))
15 |
16 |
17 | class TestGradient(unittest.TestCase):
18 |
19 | def test_midpoint(self):
20 |
21 | f, y0, t_points, _ = construct_problem(TEST_DEVICE)
22 |
23 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='midpoint')
24 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
25 |
26 | def test_rk4(self):
27 |
28 | f, y0, t_points, _ = construct_problem(TEST_DEVICE)
29 |
30 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='rk4')
31 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
32 |
33 | def test_dopri5(self):
34 | f, y0, t_points, _ = construct_problem(TEST_DEVICE)
35 |
36 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
37 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
38 |
39 | def test_adams(self):
40 | f, y0, t_points, _ = construct_problem(TEST_DEVICE)
41 |
42 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='adams')
43 | self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
44 |
45 | def test_adjoint(self):
46 | """
47 | Test against dopri5
48 | """
49 | f, y0, t_points, _ = construct_problem(TEST_DEVICE)
50 |
51 | func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
52 | ys = func(y0, t_points)
53 | torch.manual_seed(0)
54 | gradys = torch.rand_like(ys)
55 | ys.backward(gradys)
56 |
57 | # reg_y0_grad = y0.grad
58 | reg_t_grad = t_points.grad
59 | reg_a_grad = f.a.grad
60 | reg_b_grad = f.b.grad
61 |
62 | f, y0, t_points, _ = construct_problem(TEST_DEVICE)
63 |
64 | func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
65 | ys = func(y0, t_points)
66 | ys.backward(gradys)
67 |
68 | # adj_y0_grad = y0.grad
69 | adj_t_grad = t_points.grad
70 | adj_a_grad = f.a.grad
71 | adj_b_grad = f.b.grad
72 |
73 | # self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps)
74 | self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps)
75 | self.assertLess(max_abs(reg_a_grad - adj_a_grad), eps)
76 | self.assertLess(max_abs(reg_b_grad - adj_b_grad), eps)
77 |
78 |
79 | class TestCompareAdjointGradient(unittest.TestCase):
80 |
81 | def problem(self):
82 |
83 | class Odefunc(torch.nn.Module):
84 |
85 | def __init__(self):
86 | super(Odefunc, self).__init__()
87 | self.A = torch.nn.Parameter(torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]))
88 | self.unused_module = torch.nn.Linear(2, 5)
89 |
90 | def forward(self, t, y):
91 | return torch.mm(y**3, self.A)
92 |
93 | y0 = torch.tensor([[2., 0.]]).to(TEST_DEVICE).requires_grad_(True)
94 | t_points = torch.linspace(0., 25., 10).to(TEST_DEVICE).requires_grad_(True)
95 | func = Odefunc().to(TEST_DEVICE)
96 | return func, y0, t_points
97 |
98 | def test_dopri5_adjoint_against_dopri5(self):
99 | func, y0, t_points = self.problem()
100 | ys = torchdiffeq.odeint_adjoint(func, y0, t_points, method='dopri5')
101 | gradys = torch.rand_like(ys) * 0.1
102 | ys.backward(gradys)
103 |
104 | adj_y0_grad = y0.grad
105 | adj_t_grad = t_points.grad
106 | adj_A_grad = func.A.grad
107 | self.assertEqual(max_abs(func.unused_module.weight.grad), 0)
108 | self.assertEqual(max_abs(func.unused_module.bias.grad), 0)
109 |
110 | func, y0, t_points = self.problem()
111 | ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5')
112 | ys.backward(gradys)
113 |
114 | self.assertLess(max_abs(y0.grad - adj_y0_grad), 3e-4)
115 | self.assertLess(max_abs(t_points.grad - adj_t_grad), 1e-4)
116 | self.assertLess(max_abs(func.A.grad - adj_A_grad), 2e-3)
117 |
118 | def test_adams_adjoint_against_dopri5(self):
119 | func, y0, t_points = self.problem()
120 | ys_ = torchdiffeq.odeint_adjoint(func, y0, t_points, method='adams')
121 | gradys = torch.rand_like(ys_) * 0.1
122 | ys_.backward(gradys)
123 |
124 | adj_y0_grad = y0.grad
125 | adj_t_grad = t_points.grad
126 | adj_A_grad = func.A.grad
127 | self.assertEqual(max_abs(func.unused_module.weight.grad), 0)
128 | self.assertEqual(max_abs(func.unused_module.bias.grad), 0)
129 |
130 | func, y0, t_points = self.problem()
131 | ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5')
132 | ys.backward(gradys)
133 |
134 | self.assertLess(max_abs(y0.grad - adj_y0_grad), 5e-2)
135 | self.assertLess(max_abs(t_points.grad - adj_t_grad), 5e-4)
136 | self.assertLess(max_abs(func.A.grad - adj_A_grad), 2e-2)
137 |
138 |
139 | if __name__ == '__main__':
140 | unittest.main()
141 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/tests/odeint_tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchdiffeq
4 |
5 | import problems
6 |
7 | error_tol = 1e-4
8 |
9 | torch.set_default_dtype(torch.float64)
10 | TEST_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11 |
12 |
13 | def max_abs(tensor):
14 | return torch.max(torch.abs(tensor))
15 |
16 |
17 | def rel_error(true, estimate):
18 | return max_abs((true - estimate) / true)
19 |
20 |
21 | class TestSolverError(unittest.TestCase):
22 |
23 | def test_euler(self):
24 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
25 |
26 | y = torchdiffeq.odeint(f, y0, t_points, method='euler')
27 | self.assertLess(rel_error(sol, y), error_tol)
28 |
29 | def test_midpoint(self):
30 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
31 |
32 | y = torchdiffeq.odeint(f, y0, t_points, method='midpoint')
33 | self.assertLess(rel_error(sol, y), error_tol)
34 |
35 | def test_rk4(self):
36 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
37 |
38 | y = torchdiffeq.odeint(f, y0, t_points, method='rk4')
39 | self.assertLess(rel_error(sol, y), error_tol)
40 |
41 | def test_explicit_adams(self):
42 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
43 |
44 | y = torchdiffeq.odeint(f, y0, t_points, method='explicit_adams')
45 | self.assertLess(rel_error(sol, y), error_tol)
46 |
47 | def test_adams(self):
48 | for ode in problems.PROBLEMS.keys():
49 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
50 | y = torchdiffeq.odeint(f, y0, t_points, method='adams')
51 | with self.subTest(ode=ode):
52 | self.assertLess(rel_error(sol, y), error_tol)
53 |
54 | def test_dopri5(self):
55 | for ode in problems.PROBLEMS.keys():
56 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
57 | y = torchdiffeq.odeint(f, y0, t_points, method='dopri5')
58 | with self.subTest(ode=ode):
59 | self.assertLess(rel_error(sol, y), error_tol)
60 |
61 | def test_adjoint(self):
62 | for ode in problems.PROBLEMS.keys():
63 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
64 |
65 | y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
66 | with self.subTest(ode=ode):
67 | self.assertLess(rel_error(sol, y), error_tol)
68 |
69 |
70 | class TestSolverBackwardsInTimeError(unittest.TestCase):
71 |
72 | def test_euler(self):
73 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
74 |
75 | y = torchdiffeq.odeint(f, y0, t_points, method='euler')
76 | self.assertLess(rel_error(sol, y), error_tol)
77 |
78 | def test_midpoint(self):
79 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
80 |
81 | y = torchdiffeq.odeint(f, y0, t_points, method='midpoint')
82 | self.assertLess(rel_error(sol, y), error_tol)
83 |
84 | def test_rk4(self):
85 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
86 |
87 | y = torchdiffeq.odeint(f, y0, t_points, method='rk4')
88 | self.assertLess(rel_error(sol, y), error_tol)
89 |
90 | def test_explicit_adams(self):
91 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
92 |
93 | y = torchdiffeq.odeint(f, y0, t_points, method='explicit_adams')
94 | self.assertLess(rel_error(sol, y), error_tol)
95 |
96 | def test_adams(self):
97 | for ode in problems.PROBLEMS.keys():
98 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
99 |
100 | y = torchdiffeq.odeint(f, y0, t_points, method='adams')
101 | with self.subTest(ode=ode):
102 | self.assertLess(rel_error(sol, y), error_tol)
103 |
104 | def test_dopri5(self):
105 | for ode in problems.PROBLEMS.keys():
106 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
107 |
108 | y = torchdiffeq.odeint(f, y0, t_points, method='dopri5')
109 | with self.subTest(ode=ode):
110 | self.assertLess(rel_error(sol, y), error_tol)
111 |
112 | def test_adjoint(self):
113 | for ode in problems.PROBLEMS.keys():
114 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
115 |
116 | y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
117 | with self.subTest(ode=ode):
118 | self.assertLess(rel_error(sol, y), error_tol)
119 |
120 |
121 | class TestNoIntegration(unittest.TestCase):
122 |
123 | def test_midpoint(self):
124 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
125 |
126 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='midpoint')
127 | self.assertLess(max_abs(sol[0] - y), error_tol)
128 |
129 | def test_rk4(self):
130 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
131 |
132 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='rk4')
133 | self.assertLess(max_abs(sol[0] - y), error_tol)
134 |
135 | def test_explicit_adams(self):
136 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
137 |
138 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='explicit_adams')
139 | self.assertLess(max_abs(sol[0] - y), error_tol)
140 |
141 | def test_adams(self):
142 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
143 |
144 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='adams')
145 | self.assertLess(max_abs(sol[0] - y), error_tol)
146 |
147 | def test_dopri5(self):
148 | f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
149 |
150 | y = torchdiffeq.odeint(f, y0, t_points[0:1], method='dopri5')
151 | self.assertLess(max_abs(sol[0] - y), error_tol)
152 |
153 |
154 | if __name__ == '__main__':
155 | unittest.main()
156 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/tests/problems.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import scipy.linalg
4 | import torch
5 |
6 |
7 | class ConstantODE(torch.nn.Module):
8 |
9 | def __init__(self, device):
10 | super(ConstantODE, self).__init__()
11 | self.a = torch.nn.Parameter(torch.tensor(0.2).to(device))
12 | self.b = torch.nn.Parameter(torch.tensor(3.0).to(device))
13 |
14 | def forward(self, t, y):
15 | return self.a + (y - (self.a * t + self.b))**5
16 |
17 | def y_exact(self, t):
18 | return self.a * t + self.b
19 |
20 |
21 | class SineODE(torch.nn.Module):
22 |
23 | def __init__(self, device):
24 | super(SineODE, self).__init__()
25 |
26 | def forward(self, t, y):
27 | return 2 * y / t + t**4 * torch.sin(2 * t) - t**2 + 4 * t**3
28 |
29 | def y_exact(self, t):
30 | return -0.5 * t**4 * torch.cos(2 * t) + 0.5 * t**3 * torch.sin(2 * t) + 0.25 * t**2 * torch.cos(
31 | 2 * t
32 | ) - t**3 + 2 * t**4 + (math.pi - 0.25) * t**2
33 |
34 |
35 | class LinearODE(torch.nn.Module):
36 |
37 | def __init__(self, device, dim=10):
38 | super(LinearODE, self).__init__()
39 | self.dim = dim
40 | U = torch.randn(dim, dim).to(device) * 0.1
41 |         A = 2 * U - (U + U.transpose(0, 1))  # = U - U^T, skew-symmetric, so solutions stay bounded
42 | self.A = torch.nn.Parameter(A)
43 | self.initial_val = np.ones((dim, 1))
44 |
45 | def forward(self, t, y):
46 | return torch.mm(self.A, y.reshape(self.dim, 1)).reshape(-1)
47 |
48 | def y_exact(self, t):
49 | t = t.detach().cpu().numpy()
50 | A_np = self.A.detach().cpu().numpy()
51 | ans = []
52 | for t_i in t:
53 | ans.append(np.matmul(scipy.linalg.expm(A_np * t_i), self.initial_val))
54 | return torch.stack([torch.tensor(ans_) for ans_ in ans]).reshape(len(t), self.dim)
55 |
56 |
57 | PROBLEMS = {'constant': ConstantODE, 'linear': LinearODE, 'sine': SineODE}
58 |
59 |
60 | def construct_problem(device, npts=10, ode='constant', reverse=False):
61 |
62 | f = PROBLEMS[ode](device)
63 |
64 | t_points = torch.linspace(1, 8, npts).to(device).requires_grad_(True)
65 | sol = f.y_exact(t_points)
66 |
67 | def _flip(x, dim):
68 | indices = [slice(None)] * x.dim()
69 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
70 | return x[tuple(indices)]
71 |
72 | if reverse:
73 | t_points = _flip(t_points, 0).clone().detach()
74 | sol = _flip(sol, 0).clone().detach()
75 |
76 | return f, sol[0].detach(), t_points, sol
77 |
78 |
79 | if __name__ == '__main__':
80 | f = SineODE('cpu')
81 | t_points = torch.linspace(1, 8, 100).to('cpu').requires_grad_(True)
82 | sol = f.y_exact(t_points)
83 |
84 | import matplotlib.pyplot as plt
85 | plt.plot(t_points.detach().cpu().numpy(), sol.detach().cpu().numpy())
86 | plt.show()
87 |
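Note: each problem pairs a vector field forward(t, y) with a closed-form y_exact(t), which is what the solver tests compare against. A quick sanity sketch for ConstantODE (assuming it is run from this tests directory): on the exact trajectory the quintic perturbation term vanishes, so the field reduces to dy/dt = a, the derivative of y_exact(t) = a*t + b.

    import torch
    from problems import ConstantODE

    ode = ConstantODE(device='cpu')
    t = torch.linspace(1, 8, 50)
    y = ode.y_exact(t)

    # (y - (a*t + b))**5 == 0 along the exact solution, leaving dy/dt = a.
    assert torch.allclose(ode(t, y), ode.a.expand_as(t))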
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/tests/run_all.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from api_tests import *
3 | from gradient_tests import *
4 | from odeint_tests import *
5 |
6 | if __name__ == '__main__':
7 | unittest.main()
8 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/__init__.py:
--------------------------------------------------------------------------------
1 | from ._impl import odeint
2 | from ._impl import odeint_adjoint
3 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/_impl/__init__.py:
--------------------------------------------------------------------------------
1 | from .odeint import odeint
2 | from .adjoint import odeint_adjoint
3 |
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/_impl/adjoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from . import odeint
4 | from .misc import _flatten, _flatten_convert_none_to_zeros
5 |
6 |
7 | class OdeintAdjointMethod(torch.autograd.Function):
8 |
9 | @staticmethod
10 | def forward(ctx, *args):
11 | assert len(args) >= 9, 'Internal error: all arguments required.'
12 | y0, func, t, flat_params, rtol, atol, method, options, f_options = \
13 | args[:-8], args[-8], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]
14 |
15 | ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options, ctx.f_options = func, rtol, atol, method, options, f_options
16 |
17 | with torch.no_grad():
18 | ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, f_options=f_options)
19 | ctx.save_for_backward(t, flat_params, *ans)
20 | return ans
21 |
22 | @staticmethod
23 | def backward(ctx, *grad_output):
24 | t, flat_params, *ans = ctx.saved_tensors
25 | ans = tuple(ans)
26 | func, rtol, atol, method, options, f_options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options, ctx.f_options
27 | n_tensors = len(ans)
28 | f_params = tuple(func.parameters())
29 |
30 | # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
31 | def augmented_dynamics(t, y_aug, **f_options):
32 | # Dynamics of the original system augmented with
33 | # the adjoint wrt y, and an integrator wrt t and args.
34 | y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] # Ignore adj_time and adj_params.
35 |
36 | with torch.set_grad_enabled(True):
37 | t = t.to(y[0].device).detach().requires_grad_(True)
38 | y = tuple(y_.detach().requires_grad_(True) for y_ in y)
39 | func_eval = func(t, y, **f_options)
40 | vjp_t, *vjp_y_and_params = torch.autograd.grad(
41 | func_eval, (t,) + y + f_params,
42 | tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
43 | )
44 | vjp_y = vjp_y_and_params[:n_tensors]
45 | vjp_params = vjp_y_and_params[n_tensors:]
46 |
47 | # autograd.grad returns None if no gradient, set to zero.
48 | vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
49 | vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
50 | vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)
51 |
52 | if len(f_params) == 0:
53 | vjp_params = torch.tensor(0.).to(vjp_y[0])
54 | return (*func_eval, *vjp_y, vjp_t, vjp_params)
55 |
56 | T = ans[0].shape[0]
57 | with torch.no_grad():
58 | adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
59 | adj_params = torch.zeros_like(flat_params)
60 | adj_time = torch.tensor(0.).to(t)
61 | time_vjps = []
62 | for i in range(T - 1, 0, -1):
63 |
64 | ans_i = tuple(ans_[i] for ans_ in ans)
65 | grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
66 | func_i = func(t[i], ans_i, **f_options)
67 |
68 | # Compute the effect of moving the current time measurement point.
69 | dLd_cur_t = sum(
70 | torch.dot(func_i_.view(-1), grad_output_i_.view(-1)).view(1)
71 | for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
72 | )
73 | adj_time = adj_time - dLd_cur_t
74 | time_vjps.append(dLd_cur_t)
75 |
76 | # Run the augmented system backwards in time.
77 | if len(adj_params) == 0:
78 | adj_params = torch.tensor(0.).to(adj_y[0])
79 | aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
80 | aug_ans = odeint(
81 | augmented_dynamics, aug_y0,
82 | torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options, f_options=f_options
83 | )
84 |
85 | # Unpack aug_ans.
86 | adj_y = aug_ans[n_tensors:2 * n_tensors]
87 | adj_time = aug_ans[2 * n_tensors]
88 | adj_params = aug_ans[2 * n_tensors + 1]
89 |
90 | adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
91 | if len(adj_time) > 0: adj_time = adj_time[1]
92 | if len(adj_params) > 0: adj_params = adj_params[1]
93 |
94 | adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))
95 |
96 | del aug_y0, aug_ans
97 |
98 | time_vjps.append(adj_time)
99 | time_vjps = torch.cat(time_vjps[::-1])
100 |
101 | return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None)
102 |
103 |
104 | def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None, f_options=None):
105 |     f_options = {} if f_options is None else f_options  # backward() unpacks **f_options
106 | # We need this in order to access the variables inside this module,
107 | # since we have no other way of getting variables along the execution path.
108 | if not isinstance(func, nn.Module):
109 | raise ValueError('func is required to be an instance of nn.Module.')
110 |
111 | tensor_input = False
112 | if torch.is_tensor(y0):
113 |
114 | class TupleFunc(nn.Module):
115 |
116 | def __init__(self, base_func):
117 | super(TupleFunc, self).__init__()
118 | self.base_func = base_func
119 |
120 | def forward(self, t, y, **f_options):
121 | return (self.base_func(t, y[0], **f_options),)
122 |
123 | tensor_input = True
124 | y0 = (y0,)
125 | func = TupleFunc(func)
126 |
127 | flat_params = _flatten(func.parameters())
128 | ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options, f_options)
129 |
130 | if tensor_input:
131 | ys = ys[0]
132 | return ys
133 |
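Note: odeint_adjoint never backpropagates through the solver's internal steps; backward() above solves the augmented ODE in reverse, so memory cost is independent of the number of solver steps. A minimal training-style sketch (the Dynamics module and loss are illustrative, not part of this repo; assumes the vendored package is importable):

    import torch
    import torch.nn as nn
    from torchdiffeq import odeint_adjoint

    class Dynamics(nn.Module):
        # func must be an nn.Module so the adjoint pass can reach its parameters.
        def __init__(self):
            super(Dynamics, self).__init__()
            self.net = nn.Linear(2, 2)

        def forward(self, t, y):
            return self.net(y)

    func = Dynamics()
    y0 = torch.tensor([[1.0, 0.0]])
    t = torch.linspace(0.0, 1.0, 10)

    y = odeint_adjoint(func, y0, t, method='dopri5')   # shape (10, 1, 2)
    loss = y[-1].pow(2).sum()
    loss.backward()   # gradients reach func.net via the backward ODE solve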
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/_impl/dopri5.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate
2 | import torch
3 | from .misc import (
4 | _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable,
5 | _optimal_step_size, _compute_error_ratio
6 | )
7 | from .solvers import AdaptiveStepsizeODESolver
8 | from .interp import _interp_fit, _interp_evaluate
9 | from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step
10 |
11 | _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau(
12 | alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.],
13 | beta=[
14 | [1 / 5],
15 | [3 / 40, 9 / 40],
16 | [44 / 45, -56 / 15, 32 / 9],
17 | [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
18 | [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
19 | [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
20 | ],
21 | c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0],
22 | c_error=[
23 | 35 / 384 - 1951 / 21600,
24 | 0,
25 | 500 / 1113 - 22642 / 50085,
26 | 125 / 192 - 451 / 720,
27 | -2187 / 6784 - -12231 / 42400,
28 | 11 / 84 - 649 / 6300,
29 | -1. / 60.,
30 | ],
31 | )
32 |
33 | DPS_C_MID = [
34 | 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2,
35 | 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2
36 | ]
37 |
38 |
39 | def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU):
40 | """Fit an interpolating polynomial to the results of a Runge-Kutta step."""
41 | dt = dt.type_as(y0[0])
42 | y_mid = tuple(y0_ + _scaled_dot_product(dt, DPS_C_MID, k_) for y0_, k_ in zip(y0, k))
43 | f0 = tuple(k_[0] for k_ in k)
44 | f1 = tuple(k_[-1] for k_ in k)
45 | return _interp_fit(y0, y1, y_mid, f0, f1, dt)
46 |
47 |
48 | def _abs_square(x):
49 | return torch.mul(x, x)
50 |
51 |
52 | def _ta_append(list_of_tensors, value):
53 | """Append a value to the end of a list of PyTorch tensors."""
54 | list_of_tensors.append(value)
55 | return list_of_tensors
56 |
57 |
58 | class Dopri5Solver(AdaptiveStepsizeODESolver):
59 |
60 | def __init__(
61 | self, func, y0, f_options, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
62 | **unused_kwargs
63 | ):
64 | _handle_unused_kwargs(self, unused_kwargs)
65 | del unused_kwargs
66 |
67 | self.func = func
68 | self.y0 = y0
69 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
70 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
71 | self.f_options = f_options
72 | self.first_step = first_step
73 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
74 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
75 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
76 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)
77 |
78 | def before_integrate(self, t):
79 | f0 = self.func(t[0].type_as(self.y0[0]), self.y0, **self.f_options)
80 | if self.first_step is None:
81 | first_step = _select_initial_step(self.func, t[0], self.y0, self.f_options, 4, self.rtol[0], self.atol[0], f0=f0).to(t)
82 | else:
83 |             first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
84 | self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5)
85 |
86 | def advance(self, next_t):
87 | """Interpolate through the next time point, integrating as necessary."""
88 | n_steps = 0
89 | while next_t > self.rk_state.t1:
90 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
91 | self.rk_state = self._adaptive_dopri5_step(self.rk_state)
92 | n_steps += 1
93 | return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t)
94 |
95 | def _adaptive_dopri5_step(self, rk_state):
96 | """Take an adaptive Runge-Kutta step to integrate the ODE."""
97 | y0, f0, _, t0, dt, interp_coeff = rk_state
98 | ########################################################
99 | # Assertions #
100 | ########################################################
101 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
102 | for y0_ in y0:
103 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
104 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU, f_options=self.f_options)
105 |
106 | ########################################################
107 | # Error Ratio #
108 | ########################################################
109 | mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
110 | accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all()
111 |
112 | ########################################################
113 | # Update RK State #
114 | ########################################################
115 | y_next = y1 if accept_step else y0
116 | f_next = f1 if accept_step else f0
117 | t_next = t0 + dt if accept_step else t0
118 | interp_coeff = _interp_fit_dopri5(y0, y1, k, dt) if accept_step else interp_coeff
119 | dt_next = _optimal_step_size(
120 | dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5
121 | )
122 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
123 | return rk_state
124 |
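Note: c_error holds the difference between the 5th-order solution weights and the embedded 4th-order weights, so dotting it with the stage derivatives yields the local error estimate directly. A consistency sketch (assuming the vendored package is importable): each beta row sums to its alpha node, and both embedded weight sets sum to one, so c_error sums to zero.

    from torchdiffeq._impl.dopri5 import _DORMAND_PRINCE_SHAMPINE_TABLEAU as tab

    # Each stage's beta row sums to its alpha node (the stage time offset).
    for alpha_i, beta_i in zip(tab.alpha, tab.beta):
        assert abs(sum(beta_i) - alpha_i) < 1e-12
    # 5th-order weights sum to one; c_error = (5th - 4th weights) sums to zero.
    assert abs(sum(tab.c_sol) - 1.0) < 1e-12
    assert abs(sum(tab.c_error)) < 1e-12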
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/_impl/fixed_grid.py:
--------------------------------------------------------------------------------
1 | from .solvers import FixedGridODESolver
2 | from . import rk_common
3 |
4 |
5 | class Euler(FixedGridODESolver):
6 |
7 |     def step_func(self, func, t, dt, y, f_options):
8 |         return tuple(dt * f_ for f_ in func(t, y, **f_options))
9 |
10 | @property
11 | def order(self):
12 | return 1
13 |
14 |
15 | class Midpoint(FixedGridODESolver):
16 |
17 |     def step_func(self, func, t, dt, y, f_options):
18 |         y_mid = tuple(y_ + f_ * dt / 2 for y_, f_ in zip(y, func(t, y, **f_options)))
19 |         return tuple(dt * f_ for f_ in func(t + dt / 2, y_mid, **f_options))
20 |
21 | @property
22 | def order(self):
23 | return 2
24 |
25 |
26 | class RK4(FixedGridODESolver):
27 |
28 | def step_func(self, func, t, dt, y, f_options):
29 | return rk_common.rk4_alt_step_func(func, t, dt, y, f_options)
30 |
31 | @property
32 | def order(self):
33 | return 4
34 |
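Note: these steppers are selected through odeint's method argument, with step_size forwarded to FixedGridODESolver via options. A minimal sketch on dy/dt = y, whose exact solution at t = 1 is e:

    import math
    import torch
    from torchdiffeq import odeint

    f = lambda t, y: y                  # dy/dt = y, y(0) = 1
    y0 = torch.tensor(1.0)
    t = torch.tensor([0.0, 1.0])

    y_euler = odeint(f, y0, t, method='euler', options={'step_size': 1e-3})
    y_rk4 = odeint(f, y0, t, method='rk4', options={'step_size': 1e-2})

    # RK4 is 4th order, so it reaches comparable accuracy with far larger steps.
    print(abs(y_euler[-1].item() - math.e), abs(y_rk4[-1].item() - math.e))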
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/_impl/interp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .misc import _convert_to_tensor, _dot_product
3 |
4 |
5 | def _interp_fit(y0, y1, y_mid, f0, f1, dt):
6 | """Fit coefficients for 4th order polynomial interpolation.
7 |
8 | Args:
9 | y0: function value at the start of the interval.
10 | y1: function value at the end of the interval.
11 | y_mid: function value at the mid-point of the interval.
12 | f0: derivative value at the start of the interval.
13 | f1: derivative value at the end of the interval.
14 | dt: width of the interval.
15 |
16 | Returns:
17 | List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
18 | `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
19 | between 0 (start of interval) and 1 (end of interval).
20 | """
21 | a = tuple(
22 | _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0_, f1_, y0_, y1_, y_mid_])
23 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
24 | )
25 | b = tuple(
26 | _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0_, f1_, y0_, y1_, y_mid_])
27 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
28 | )
29 | c = tuple(
30 | _dot_product([-4 * dt, dt, -11, -5, 16], [f0_, f1_, y0_, y1_, y_mid_])
31 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
32 | )
33 | d = tuple(dt * f0_ for f0_ in f0)
34 | e = y0
35 | return [a, b, c, d, e]
36 |
37 |
38 | def _interp_evaluate(coefficients, t0, t1, t):
39 | """Evaluate polynomial interpolation at the given time point.
40 |
41 | Args:
42 | coefficients: list of Tensor coefficients as created by `interp_fit`.
43 | t0: scalar float64 Tensor giving the start of the interval.
44 | t1: scalar float64 Tensor giving the end of the interval.
45 | t: scalar float64 Tensor giving the desired interpolation point.
46 |
47 | Returns:
48 | Polynomial interpolation of the coefficients at time `t`.
49 | """
50 |
51 | dtype = coefficients[0][0].dtype
52 | device = coefficients[0][0].device
53 |
54 | t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
55 | t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
56 | t = _convert_to_tensor(t, dtype=dtype, device=device)
57 |
58 | assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)
59 | x = ((t - t0) / (t1 - t0)).type(dtype).to(device)
60 |
61 | xs = [torch.tensor(1).type(dtype).to(device), x]
62 | for _ in range(2, len(coefficients)):
63 | xs.append(xs[-1] * x)
64 |
65 | return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients))
66 |
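Note: the coefficient formulas above are the closed-form solution of five conditions on the normalized variable x = (t - t0) / dt: p(0) = y0, p(1/2) = y_mid, p(1) = y1, p'(0) = dt*f0 and p'(1) = dt*f1 (slopes pick up a factor dt from the change of variables). A standalone check on plain floats (arbitrary illustrative values):

    dt, y0, y1, y_mid, f0, f1 = 0.3, 1.0, 2.5, 1.6, -0.7, 4.2

    a = -2 * dt * f0 + 2 * dt * f1 - 8 * y0 - 8 * y1 + 16 * y_mid
    b = 5 * dt * f0 - 3 * dt * f1 + 18 * y0 + 14 * y1 - 32 * y_mid
    c = -4 * dt * f0 + dt * f1 - 11 * y0 - 5 * y1 + 16 * y_mid
    d = dt * f0
    e = y0

    p = lambda x: a * x**4 + b * x**3 + c * x**2 + d * x + e
    dp = lambda x: 4 * a * x**3 + 3 * b * x**2 + 2 * c * x + d

    assert abs(p(0.0) - y0) < 1e-12
    assert abs(p(0.5) - y_mid) < 1e-12
    assert abs(p(1.0) - y1) < 1e-12
    assert abs(dp(0.0) - dt * f0) < 1e-12   # slope in x units is dt * dy/dt
    assert abs(dp(1.0) - dt * f1) < 1e-12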
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/_impl/odeint.py:
--------------------------------------------------------------------------------
1 | from .tsit5 import Tsit5Solver
2 | from .dopri5 import Dopri5Solver
3 | from .fixed_grid import Euler, Midpoint, RK4
4 | from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
5 | from .adams import VariableCoefficientAdamsBashforth
6 | from .misc import _check_inputs
7 |
8 | SOLVERS = {
9 | 'explicit_adams': AdamsBashforth,
10 | 'fixed_adams': AdamsBashforthMoulton,
11 | 'adams': VariableCoefficientAdamsBashforth,
12 | 'tsit5': Tsit5Solver,
13 | 'dopri5': Dopri5Solver,
14 | 'euler': Euler,
15 | 'midpoint': Midpoint,
16 | 'rk4': RK4,
17 | }
18 |
19 |
20 | def odeint(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None, f_options=None):
21 | """Integrate a system of ordinary differential equations.
22 |
23 | Solves the initial value problem for a non-stiff system of first order ODEs:
24 | ```
25 | dy/dt = func(t, y), y(t[0]) = y0
26 | ```
27 | where y is a Tensor of any shape.
28 |
29 | Output dtypes and numerical precision are based on the dtypes of the inputs `y0`.
30 |
31 | Args:
32 | func: Function that maps a Tensor holding the state `y` and a scalar Tensor
33 | `t` into a Tensor of state derivatives with respect to time.
34 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
35 | have any floating point or complex dtype.
36 | t: 1-D Tensor holding a sequence of time points for which to solve for
37 | `y`. The initial time point should be the first element of this sequence,
38 |             and the times must be strictly monotone (a decreasing sequence is integrated backwards in time). May have any floating
39 | point dtype. Converted to a Tensor with float64 dtype.
40 | rtol: optional float64 Tensor specifying an upper bound on relative error,
41 | per element of `y`.
42 | atol: optional float64 Tensor specifying an upper bound on absolute error,
43 | per element of `y`.
44 | method: optional string indicating the integration method to use.
45 | options: optional dict of configuring options for the indicated integration
46 | method. Can only be provided if a `method` is explicitly set.
47 | f_options: optional dict of additional arguments for func
48 |
49 | Returns:
50 | y: Tensor, where the first dimension corresponds to different
51 | time points. Contains the solved value of y for each desired time point in
52 | `t`, with the initial value `y0` being the first element along the first
53 | dimension.
54 |
55 | Raises:
56 |         ValueError: if an invalid `method` is provided, or if `options` is
57 |             supplied without `method`.
58 |         TypeError: if `t` or `y0` has an invalid dtype.
59 | """
60 | tensor_input, func, y0, t = _check_inputs(func, y0, t, f_options)
61 |
62 | if options is None:
63 | options = {}
64 | elif method is None:
65 | raise ValueError('cannot supply `options` without specifying `method`')
66 |
67 | if method is None:
68 | method = 'dopri5'
69 |
70 | if f_options is None:
71 | f_options = {}
72 | solver = SOLVERS[method](func, y0, f_options, rtol=rtol, atol=atol, **options)
73 | solution = solver.integrate(t)
74 |
75 | if tensor_input:
76 | solution = solution[0]
77 | return solution
78 |
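Note: a minimal end-to-end use of the interface documented above; the decay problem and tolerances are illustrative.

    import torch
    from torchdiffeq import odeint

    class Decay(torch.nn.Module):
        def forward(self, t, y):
            return -0.5 * y             # dy/dt = -0.5 * y

    y0 = torch.tensor([2.0])
    t = torch.linspace(0.0, 4.0, 9)

    y = odeint(Decay(), y0, t)          # method defaults to 'dopri5'
    y_true = y0 * torch.exp(-0.5 * t).unsqueeze(-1)
    print((y - y_true).abs().max())     # error on the order of rtol/atol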
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/_impl/rk_common.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate
2 | import collections
3 | from .misc import _scaled_dot_product, _convert_to_tensor
4 |
5 | _ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha beta c_sol c_error')
6 |
7 |
8 | class _RungeKuttaState(collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
9 | """Saved state of the Runge Kutta solver.
10 |
11 | Attributes:
12 | y1: Tensor giving the function value at the end of the last time step.
13 | f1: Tensor giving derivative at the end of the last time step.
14 | t0: scalar float64 Tensor giving start of the last time step.
15 | t1: scalar float64 Tensor giving end of the last time step.
16 | dt: scalar float64 Tensor giving the size for the next time step.
17 |     interp_coeff: list of Tensors giving coefficients for polynomial
18 | interpolation between `t0` and `t1`.
19 | """
20 |
21 |
22 | def _runge_kutta_step(func, y0, f0, t0, dt, tableau, f_options):
23 | """Take an arbitrary Runge-Kutta step and estimate error.
24 |
25 | Args:
26 | func: Function to evaluate like `func(t, y)` to compute the time derivative
27 | of `y`.
28 | y0: Tensor initial value for the state.
29 | f0: Tensor initial value for the derivative, computed from `func(t0, y0)`.
30 | t0: float64 scalar Tensor giving the initial time.
31 | dt: float64 scalar Tensor giving the size of the desired time step.
32 | tableau: optional _ButcherTableau describing how to take the Runge-Kutta
33 | step.
34 |     f_options: dict of additional keyword arguments passed through to `func`
35 |       at every evaluation.
36 |
37 | Returns:
38 | Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
39 | the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
40 | estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
41 | calculating these terms.
42 | """
43 | dtype = y0[0].dtype
44 | device = y0[0].device
45 |
46 | t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
47 | dt = _convert_to_tensor(dt, dtype=dtype, device=device)
48 |
49 | k = tuple(map(lambda x: [x], f0))
50 | for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
51 | ti = t0 + alpha_i * dt
52 | yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
53 | tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi, **f_options)))
54 |
55 | if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
56 | # This property (true for Dormand-Prince) lets us save a few FLOPs.
57 | yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))
58 |
59 | y1 = yi
60 | f1 = tuple(k_[-1] for k_ in k)
61 | y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
62 | return (y1, f1, y1_error, k)
63 |
64 |
65 | def rk4_step_func(func, t, dt, y, f_options, k1=None):
66 | if k1 is None: k1 = func(t, y, **f_options)
67 | k2 = func(t + dt / 2, tuple(y_ + dt * k1_ / 2 for y_, k1_ in zip(y, k1)), **f_options)
68 | k3 = func(t + dt / 2, tuple(y_ + dt * k2_ / 2 for y_, k2_ in zip(y, k2)), **f_options)
69 | k4 = func(t + dt, tuple(y_ + dt * k3_ for y_, k3_ in zip(y, k3)), **f_options)
70 | return tuple((k1_ + 2 * k2_ + 2 * k3_ + k4_) * (dt / 6) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))
71 |
72 |
73 | def rk4_alt_step_func(func, t, dt, y, f_options, k1=None):
74 | """Smaller error with slightly more compute."""
75 | if k1 is None: k1 = func(t, y, **f_options)
76 | k2 = func(t + dt / 3, tuple(y_ + dt * k1_ / 3 for y_, k1_ in zip(y, k1)), **f_options)
77 | k3 = func(t + dt * 2 / 3, tuple(y_ + dt * (k1_ / -3 + k2_) for y_, k1_, k2_ in zip(y, k1, k2)), **f_options)
78 | k4 = func(t + dt, tuple(y_ + dt * (k1_ - k2_ + k3_) for y_, k1_, k2_, k3_ in zip(y, k1, k2, k3)), **f_options)
79 | return tuple((k1_ + 3 * k2_ + 3 * k3_ + k4_) * (dt / 8) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))
80 |
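Note: rk4_step_func is the classical 4th-order Runge-Kutta scheme and rk4_alt_step_func is Kutta's 3/8 rule; both have O(dt**5) local error. A scalar sketch of the classical update, stripped of the tuple plumbing above:

    import math

    def rk4_step(f, t, dt, y):
        # Classical 4th-order Runge-Kutta increment, as in rk4_step_func.
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt * k1 / 2)
        k3 = f(t + dt / 2, y + dt * k2 / 2)
        k4 = f(t + dt, y + dt * k3)
        return (k1 + 2 * k2 + 2 * k3 + k4) * (dt / 6)

    # One step on dy/dt = y approximates e**dt to O(dt**5).
    dt = 0.1
    y1 = 1.0 + rk4_step(lambda t, y: y, 0.0, dt, 1.0)
    assert abs(y1 - math.exp(dt)) < dt**5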
--------------------------------------------------------------------------------
/im2mesh/utils/torchdiffeq/torchdiffeq/_impl/solvers.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | from .misc import _assert_increasing, _handle_unused_kwargs
4 |
5 |
6 | class AdaptiveStepsizeODESolver(object):
7 | __metaclass__ = abc.ABCMeta
8 |
9 | def __init__(self, func, y0, f_options, atol=0.01, rtol=0.001, **unused_kwargs):
10 | _handle_unused_kwargs(self, unused_kwargs)
11 | del unused_kwargs
12 |
13 | self.func = func
14 | self.y0 = y0
15 | self.f_options = f_options
16 |
17 | self.atol = atol
18 | self.rtol = rtol
19 |
20 | def before_integrate(self, t):
21 | pass
22 |
23 | @abc.abstractmethod
24 | def advance(self, next_t):
25 | raise NotImplementedError
26 |
27 | def integrate(self, t):
28 | _assert_increasing(t)
29 | solution = [self.y0]
30 | t = t.to(self.y0[0].device, torch.float64)
31 | self.before_integrate(t)
32 | for i in range(1, len(t)):
33 | y = self.advance(t[i])
34 | solution.append(y)
35 | return tuple(map(torch.stack, tuple(zip(*solution))))
36 |
37 |
38 | class FixedGridODESolver(object):
39 | __metaclass__ = abc.ABCMeta
40 |
41 | def __init__(self, func, y0, f_options, step_size=None, grid_constructor=None, **unused_kwargs):
42 | unused_kwargs.pop('rtol', None)
43 | unused_kwargs.pop('atol', None)
44 | _handle_unused_kwargs(self, unused_kwargs)
45 | del unused_kwargs
46 |
47 | self.func = func
48 | self.y0 = y0
49 | self.f_options = f_options
50 |         if step_size is not None and grid_constructor is not None:
51 |             raise ValueError("step_size and grid_constructor are exclusive arguments.")
52 |         elif step_size is not None:
53 |             self.grid_constructor = self._grid_constructor_from_step_size(step_size)
54 |         else:
55 |             self.grid_constructor = grid_constructor if grid_constructor is not None else (lambda f, y0, t: t)
56 |
57 | def _grid_constructor_from_step_size(self, step_size):
58 |
59 | def _grid_constructor(func, y0, t):
60 | start_time = t[0]
61 | end_time = t[-1]
62 |
63 | niters = torch.ceil((end_time - start_time) / step_size + 1).item()
64 | t_infer = torch.arange(0, niters).to(t) * step_size + start_time
65 |             # The ceil above can overshoot the requested end time, and
66 |             # floating point round-off can leave the final grid point
67 |             # slightly off t[-1]; clamp it exactly onto the requested
68 |             # end time.
69 | if t_infer[-1] != t[-1]:
70 | t_infer[-1] = t[-1]
71 |
72 | return t_infer
73 |
74 | return _grid_constructor
75 |
76 | @property
77 | @abc.abstractmethod
78 | def order(self):
79 | pass
80 |
81 | @abc.abstractmethod
82 | def step_func(self, func, t, dt, y, f_options):
83 | pass
84 |
85 | def integrate(self, t):
86 | _assert_increasing(t)
87 | t = t.type_as(self.y0[0])
88 | time_grid = self.grid_constructor(self.func, self.y0, t)
89 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
90 | time_grid = time_grid.to(self.y0[0])
91 |
92 | solution = [self.y0]
93 |
94 | j = 1
95 | y0 = self.y0
96 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
97 | dy = self.step_func(self.func, t0, t1 - t0, y0, self.f_options)
98 | y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))
99 |             while j < len(t) and t1 >= t[j]:
100 |                 solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
101 |                 j += 1
102 |             # Advance the left endpoint only after interpolating outputs.
103 |             y0 = y1
104 |
105 | return tuple(map(torch.stack, tuple(zip(*solution))))
106 |
107 | def _linear_interp(self, t0, t1, y0, y1, t):
108 | if t == t0:
109 | return y0
110 | if t == t1:
111 | return y1
112 | t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0])
113 | slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1))
114 | return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope))
115 |
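Note: integrate decouples the solver grid from the requested output times. It steps on time_grid and, whenever one or more requested times fall inside the step just taken, fills them by the same linear interpolation as _linear_interp. A float-only sketch of that bookkeeping (all values are illustrative):

    # Solver steps on a coarse grid; outputs are requested at finer times.
    time_grid = [0.0, 0.4, 0.8, 1.2]
    t_request = [0.0, 0.3, 0.6, 0.9, 1.2]
    y_grid = {0.0: 0.0, 0.4: 0.16, 0.8: 0.64, 1.2: 1.44}  # pretend y = t**2

    solution, j = [y_grid[time_grid[0]]], 1
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        y0, y1 = y_grid[t0], y_grid[t1]
        while j < len(t_request) and t1 >= t_request[j]:
            # Same linear interpolation as _linear_interp above.
            slope = (y1 - y0) / (t1 - t0)
            solution.append(y0 + slope * (t_request[j] - t0))
            j += 1

    print(solution)  # approximately [0.0, 0.12, 0.4, 0.84, 1.44]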
--------------------------------------------------------------------------------
/im2mesh/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('agg')
4 | from matplotlib import pyplot as plt
5 | from mpl_toolkits.mplot3d import Axes3D
6 | from torchvision.utils import save_image
7 | import im2mesh.common as common
8 |
9 |
10 | def visualize_data(data, data_type, out_file):
11 | r''' Visualizes the data with regard to its type.
12 |
13 | Args:
14 | data (tensor): batch of data
15 | data_type (string): data type (img, voxels or pointcloud)
16 | out_file (string): output file
17 | '''
18 | if data_type == 'img':
19 | if data.dim() == 3:
20 | data = data.unsqueeze(0)
21 | save_image(data, out_file, nrow=4)
22 | elif data_type == 'voxels':
23 | visualize_voxels(data, out_file=out_file)
24 | elif data_type == 'pointcloud':
25 | visualize_pointcloud(data, out_file=out_file)
26 | elif data_type is None or data_type == 'idx':
27 | pass
28 | else:
29 | raise ValueError('Invalid data_type "%s"' % data_type)
30 |
31 |
32 | def visualize_voxels(voxels, out_file=None, show=False):
33 | r''' Visualizes voxel data.
34 |
35 | Args:
36 | voxels (tensor): voxel data
37 | out_file (string): output file
38 | show (bool): whether the plot should be shown
39 | '''
40 | # Use numpy
41 | voxels = np.asarray(voxels)
42 | # Create plot
43 | fig = plt.figure()
44 | ax = fig.gca(projection=Axes3D.name)
45 | voxels = voxels.transpose(2, 0, 1)
46 | ax.voxels(voxels, edgecolor='k')
47 | ax.set_xlabel('Z')
48 | ax.set_ylabel('X')
49 | ax.set_zlabel('Y')
50 | ax.view_init(elev=30, azim=45)
51 | if out_file is not None:
52 | plt.savefig(out_file)
53 | if show:
54 | plt.show()
55 | plt.close(fig)
56 |
57 |
58 | def visualize_pointcloud(points, normals=None,
59 | out_file=None, show=False):
60 | r''' Visualizes point cloud data.
61 |
62 | Args:
63 | points (tensor): point data
64 | normals (tensor): normal data (if existing)
65 | out_file (string): output file
66 | show (bool): whether the plot should be shown
67 | '''
68 | # Use numpy
69 | points = np.asarray(points)
70 | # Create plot
71 | fig = plt.figure()
72 | ax = fig.gca(projection=Axes3D.name)
73 | ax.scatter(points[:, 2], points[:, 0], points[:, 1])
74 | if normals is not None:
75 | ax.quiver(
76 | points[:, 2], points[:, 0], points[:, 1],
77 | normals[:, 2], normals[:, 0], normals[:, 1],
78 | length=0.1, color='k'
79 | )
80 | ax.set_xlabel('Z')
81 | ax.set_ylabel('X')
82 | ax.set_zlabel('Y')
83 | ax.set_xlim(-0.5, 0.5)
84 | ax.set_ylim(-0.5, 0.5)
85 | ax.set_zlim(-0.5, 0.5)
86 | ax.view_init(elev=30, azim=45)
87 | if out_file is not None:
88 | plt.savefig(out_file)
89 | if show:
90 | plt.show()
91 | plt.close(fig)
92 |
93 |
94 | def visualise_projection(
95 |         points, world_mat, camera_mat, img, output_file='out.png'):
96 | r''' Visualizes the transformation and projection to image plane.
97 |
98 | The first points of the batch are transformed and projected to the
99 | respective image. After performing the relevant transformations, the
100 | visualization is saved in the provided output_file path.
101 |
102 | Arguments:
103 | points (tensor): batch of point cloud points
104 | world_mat (tensor): batch of matrices to rotate pc to camera-based
105 | coordinates
106 | camera_mat (tensor): batch of camera matrices to project to 2D image
107 | plane
108 | img (tensor): tensor of batch GT image files
109 | output_file (string): where the output should be saved
110 | '''
111 | points_transformed = common.transform_points(points, world_mat)
112 | points_img = common.project_to_camera(points_transformed, camera_mat)
113 | pimg2 = points_img[0].detach().cpu().numpy()
114 | image = img[0].cpu().numpy()
115 | plt.imshow(image.transpose(1, 2, 0))
116 | plt.plot(
117 | (pimg2[:, 0] + 1)*image.shape[1]/2,
118 | (pimg2[:, 1] + 1) * image.shape[2]/2, 'x')
119 | plt.savefig(output_file)
120 |
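Note: a quick usage sketch for the two standalone helpers (random inputs and output paths are illustrative):

    import numpy as np
    from im2mesh.utils.visualize import visualize_pointcloud, visualize_voxels

    points = np.random.rand(1024, 3) - 0.5       # cloud centered in the unit cube
    visualize_pointcloud(points, out_file='pcl.png')

    voxels = np.random.rand(16, 16, 16) > 0.9    # sparse random occupancy grid
    visualize_voxels(voxels, out_file='voxels.png')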
--------------------------------------------------------------------------------
/scripts/build_dataset.sh:
--------------------------------------------------------------------------------
1 | source config.sh
2 |
3 | # Make output directories
4 | mkdir -p $BUILD_PATH
5 |
6 | mkdir -p $BUILD_PATH/0_points \
7 | $BUILD_PATH/0_pointcloud \
8 | $BUILD_PATH/0_meshes #\
9 | # $BUILD_PATH/0_render
10 |
11 | echo " Building Human Dataset for Occupancy Flow Project."
12 | echo " Input Path: $INPUT_PATH"
13 | echo " Build Path: $BUILD_PATH"
14 |
15 | echo "Sample points ..."
16 | python sample_mesh.py $INPUT_PATH \
17 | --n_proc $NPROC --resize \
18 | --points_folder $BUILD_PATH/0_points \
19 | --overwrite --float16 --packbits
20 | echo "done!"
21 |
22 | echo "Sample pointcloud"
23 | python sample_mesh.py $INPUT_PATH \
24 | --n_proc $NPROC --resize \
25 | --pointcloud_folder $BUILD_PATH/0_pointcloud \
26 | --overwrite --float16
27 | echo "done"
28 |
29 | echo "Copy mesh data."
30 | inputs=$(lsfilter $INPUT_PATH)
31 | for m in ${inputs[@]}; do
32 | m_path="$INPUT_PATH/$m"
33 | mesh_files=$(lsfilter $m_path)
34 | out_path="$BUILD_PATH/0_meshes/$m"
35 | mkdir -p $out_path
36 | echo "Copy for model $m"
37 | for f in ${mesh_files[@]}; do
38 | mesh_file="$m_path/$f"
39 | out_file="$out_path/$f"
40 | cp $mesh_file $out_file
41 | done
42 | done
43 | echo "done"
44 |
45 | # echo "Render sequence (camera on circle)"
46 | # inputs=$(lsfilter $INPUT_PATH $BUILD_PATH/0_render /camera.npz) #
47 | # for f in ${inputs[@]}; do
48 | # # lsfilter $INPUT_PATH $BUILD_PATH/0_render /camera.npz | parallel -P $NPROC \
49 | # blender --background --python render_blender.py -- \
50 | # --output_folder $BUILD_PATH/0_render \
51 | # --views $N_VIEWS \
52 | # --camera circle \
53 | # $INPUT_PATH/$f
54 | # done
55 |
56 |
--------------------------------------------------------------------------------
/scripts/build_dataset_incomplete.sh:
--------------------------------------------------------------------------------
1 | source config.sh
2 |
3 | # Make output directories
4 | mkdir -p $BUILD_PATH
5 |
6 |
7 | echo " Building Human Dataset for Occupancy Flow Project."
8 | echo " Input Path: $INPUT_PATH"
9 | echo " Build Path: $BUILD_PATH"
10 |
11 |
12 | echo "Sample points sdf and pointcloud ..."
13 | python compute_incomplete.py $INPUT_PATH \
14 | --n_proc 8 --resize \
15 | --partial_pointcloud_size 30000 --start 0.0 --end 1.0 --float16 --radius 0.1
16 | echo "done!"
17 |
--------------------------------------------------------------------------------
/scripts/config.sh:
--------------------------------------------------------------------------------
1 | ROOT=/data2/lab-tang.jiapeng
2 | export HDF5_USE_FILE_LOCKING=FALSE # Workaround for NFS mounts
3 |
4 | INPUT_PATH=$ROOT/D-FAUST/scripts/registration_meshes
5 |
6 | BUILD_PATH=$ROOT/Humans.build
7 | OUTPUT_PATH=$ROOT/Humans/D-FAUST
8 |
9 | NPROC=6
10 | TIMEOUT=180
11 | #N_VIEWS=4
12 |
13 | # Utility functions
14 | lsfilter() {
15 | folder=$1
16 | other_folder=$2
17 | ext=$3
18 |
19 | for f in $folder/*; do
20 | filename=$(basename $f)
21 | if [ ! -f $other_folder/$filename$ext ] && [ ! -d $other_folder/$filename$ext ]; then
22 | echo $filename
23 | fi
24 | done
25 | }
26 |
--------------------------------------------------------------------------------
/scripts/download_data.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data
2 | cd data
3 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/occupancy_flow/data/Humans.zip
4 | unzip Humans.zip
--------------------------------------------------------------------------------
/scripts/install_dataset.sh:
--------------------------------------------------------------------------------
1 | source config.sh
2 |
3 | # Function for processing a single model
4 | organize_model() {
5 | filename=$(basename -- $3)
6 | modelname="${filename%.*}"
7 | output_path="$2/$modelname"
8 | build_path=$1
9 |
10 | points_folder="$build_path/0_points/$modelname"
11 | points_out_folder="$output_path/points_seq"
12 |
13 | # img_folder="$build_path/0_render/$modelname"
14 | # img_out_folder="$output_path/img"
15 |
16 | pointcloud_folder="$build_path/0_pointcloud/$modelname"
17 | pointcloud_out_folder="$output_path/pcl_seq"
18 |
19 | mesh_folder="$build_path/0_meshes/$modelname"
20 | mesh_out_folder="$output_path/mesh_seq"
21 |
22 | # if [ -d $points_folder ] \
23 | # && [ -d $pointcloud_folder ] \
24 | # && [ -d $img_folder ] \
25 | # && [ -d $mesh_folder ] \
26 | if [ -d $points_folder ] \
27 | && [ -d $pointcloud_folder ] \
28 | && [ -d $mesh_folder ] \
29 | ; then
30 | echo "Copying model $output_path"
31 | mkdir -p "$output_path"
32 |
33 | cp -rT $points_folder $points_out_folder
34 | cp -rT $pointcloud_folder $pointcloud_out_folder
35 |         # cp -rT $img_folder $img_out_folder
36 | cp -rT $mesh_folder $mesh_out_folder
37 | fi
38 | }
39 |
40 | echo "Installing Humans Dataset for Occupancy Flow"
41 | echo "Output Directory: $OUTPUT_PATH"
42 |
43 | export -f organize_model
44 |
45 | # Make output directories
46 | mkdir -p $OUTPUT_PATH
47 |
48 | # Run install
49 | ls $INPUT_PATH | parallel -P $NPROC \
50 | organize_model $BUILD_PATH $OUTPUT_PATH {}
51 |
52 | # Copy Split Files
53 | cp split_files/* $OUTPUT_PATH/
54 |
55 | echo "done!"
56 |
--------------------------------------------------------------------------------
/scripts/migrate_dfaust.sh:
--------------------------------------------------------------------------------
1 | # This script migrates D-FAUST mesh data and the provided point and point cloud data.
2 |
3 | IN_PATH=$1
4 | OUT_PATH=/data2/lab-tang.jiapeng/Humans/D-FAUST
5 | mesh_folder_name='mesh_seq'
6 |
7 | model_files=$(ls $IN_PATH)
8 |
9 | echo "Copying mesh data from $IN_PATH to dataset in $OUT_PATH ..."
10 | for m in ${model_files[@]}
11 | do
12 | echo "Processing model $m ..."
13 | model_folder_in=$IN_PATH/$m
14 | model_folder_out=$OUT_PATH/$m/$mesh_folder_name
15 | cp -R $model_folder_in $model_folder_out
16 | echo "done (model)!"
17 | done
18 | echo "done (dataset)!"
19 |
--------------------------------------------------------------------------------
/scripts/split_files/overfit.lst:
--------------------------------------------------------------------------------
1 | 50026_one_leg_jump
2 |
--------------------------------------------------------------------------------
/scripts/split_files/test.lst:
--------------------------------------------------------------------------------
1 | 50002_light_hopping_loose
2 | 50004_punching
3 | 50007_shake_shoulders
4 | 50009_chicken_wings
5 | 50020_chicken_wings
6 | 50022_light_hopping_loose
7 | 50025_light_hopping_loose
8 | 50026_shake_arms
9 | 50027_shake_shoulders
10 |
--------------------------------------------------------------------------------
/scripts/split_files/test_new_individual.lst:
--------------------------------------------------------------------------------
1 | 50021_chicken_wings
2 | 50021_knees
3 | 50021_one_leg_jump
4 | 50021_punching
5 | 50021_shake_arms
6 | 50021_shake_shoulders
7 | 50021_hips
8 | 50021_light_hopping_stiff
9 | 50021_one_leg_loose
10 | 50021_running_on_spot
11 | 50021_shake_hips
12 |
--------------------------------------------------------------------------------
/scripts/split_files/train.lst:
--------------------------------------------------------------------------------
1 | 50002_one_leg_loose
2 | 50025_shake_shoulders
3 | 50007_one_leg_jump
4 | 50002_knees
5 | 50002_light_hopping_stiff
6 | 50004_jiggle_on_toes
7 | 50004_shake_hips
8 | 50026_chicken_wings
9 | 50007_punching
10 | 50022_one_leg_jump
11 | 50009_light_hopping_stiff
12 | 50025_light_hopping_stiff
13 | 50007_shake_arms
14 | 50026_running_on_spot
15 | 50025_chicken_wings
16 | 50020_shake_hips
17 | 50026_shake_hips
18 | 50027_light_hopping_stiff
19 | 50009_shake_hips
20 | 50009_light_hopping_loose
21 | 50020_jiggle_on_toes
22 | 50025_one_leg_loose
23 | 50009_punching
24 | 50027_hips
25 | 50002_running_on_spot
26 | 50026_light_hopping_stiff
27 | 50026_jiggle_on_toes
28 | 50020_shake_shoulders
29 | 50007_light_hopping_stiff
30 | 50007_jiggle_on_toes
31 | 50027_punching
32 | 50009_running_on_spot
33 | 50002_shake_arms
34 | 50022_knees
35 | 50007_jumping_jacks
36 | 50027_running_on_spot
37 | 50022_running_on_spot
38 | 50004_knees
39 | 50027_one_leg_loose
40 | 50009_one_leg_jump
41 | 50026_jumping_jacks
42 | 50009_one_leg_loose
43 | 50027_jiggle_on_toes
44 | 50020_knees
45 | 50027_light_hopping_loose
46 | 50026_knees
47 | 50004_jumping_jacks
48 | 50026_one_leg_jump
49 | 50004_hips
50 | 50027_shake_arms
51 | 50026_hips
52 | 50020_punching
53 | 50025_jiggle_on_toes
54 | 50022_punching
55 | 50004_light_hopping_loose
56 | 50009_jumping_jacks
57 | 50002_one_leg_jump
58 | 50007_knees
59 | 50027_jumping_jacks
60 | 50022_shake_arms
61 | 50002_jiggle_on_toes
62 | 50002_hips
63 | 50009_jiggle_on_toes
64 | 50022_jiggle_on_toes
65 | 50007_shake_hips
66 | 50022_hips
67 | 50026_punching
68 | 50026_one_leg_loose
69 | 50025_running_on_spot
70 | 50025_knees
71 | 50025_hips
72 | 50020_light_hopping_stiff
73 | 50026_shake_shoulders
74 | 50002_jumping_jacks
75 | 50022_light_hopping_stiff
76 | 50027_one_leg_jump
77 | 50002_punching
78 | 50022_shake_shoulders
79 | 50004_running_on_spot
80 | 50020_light_hopping_loose
81 | 50022_shake_hips
82 | 50004_one_leg_loose
83 | 50022_jumping_jacks
84 | 50002_chicken_wings
85 | 50022_one_leg_loose
86 | 50027_knees
87 | 50004_shake_arms
88 | 50007_chicken_wings
89 | 50002_shake_hips
90 | 50007_one_leg_loose
91 | 50004_shake_shoulders
92 | 50009_hips
93 | 50007_running_on_spot
94 | 50025_shake_hips
95 | 50002_shake_shoulders
96 | 50020_shake_arms
97 | 50027_shake_hips
98 | 50026_light_hopping_loose
99 | 50025_one_leg_jump
100 | 50020_one_leg_loose
101 | 50004_one_leg_jump
102 | 50025_punching
103 | 50020_one_leg_jump
104 | 50004_light_hopping_stiff
105 |
--------------------------------------------------------------------------------
/scripts/split_files/train_generative.lst:
--------------------------------------------------------------------------------
1 | 50002_one_leg_loose
2 | 50025_shake_shoulders
3 | 50007_one_leg_jump
4 | 50002_knees
5 | 50002_light_hopping_stiff
6 | 50004_jiggle_on_toes
7 | 50004_shake_hips
8 | 50026_chicken_wings
9 | 50007_punching
10 | 50022_one_leg_jump
11 | 50009_light_hopping_stiff
12 | 50025_light_hopping_stiff
13 | 50007_shake_arms
14 | 50026_running_on_spot
15 | 50025_chicken_wings
16 | 50020_shake_hips
17 | 50026_shake_hips
18 | 50027_light_hopping_stiff
19 | 50009_shake_hips
20 | 50009_light_hopping_loose
21 | 50020_jiggle_on_toes
22 | 50025_one_leg_loose
23 | 50009_punching
24 | 50027_hips
25 | 50002_running_on_spot
26 | 50026_light_hopping_stiff
27 | 50026_jiggle_on_toes
28 | 50020_shake_shoulders
29 | 50007_light_hopping_stiff
30 | 50007_jiggle_on_toes
31 | 50027_punching
32 | 50009_running_on_spot
33 | 50002_shake_arms
34 | 50022_knees
35 | 50007_jumping_jacks
36 | 50027_running_on_spot
37 | 50022_running_on_spot
38 | 50004_knees
39 | 50027_one_leg_loose
40 | 50009_one_leg_jump
41 | 50026_jumping_jacks
42 | 50009_one_leg_loose
43 | 50027_jiggle_on_toes
44 | 50020_knees
45 | 50027_light_hopping_loose
46 | 50026_knees
47 | 50004_jumping_jacks
48 | 50026_one_leg_jump
49 | 50004_hips
50 | 50027_shake_arms
51 | 50026_hips
52 | 50020_punching
53 | 50025_jiggle_on_toes
54 | 50022_punching
55 | 50004_light_hopping_loose
56 | 50009_jumping_jacks
57 | 50002_one_leg_jump
58 | 50007_knees
59 | 50027_jumping_jacks
60 | 50022_shake_arms
61 | 50002_jiggle_on_toes
62 | 50002_hips
63 | 50009_jiggle_on_toes
64 | 50022_jiggle_on_toes
65 | 50007_shake_hips
66 | 50022_hips
67 | 50026_punching
68 | 50026_one_leg_loose
69 | 50025_running_on_spot
70 | 50025_knees
71 | 50025_hips
72 | 50020_light_hopping_stiff
73 | 50026_shake_shoulders
74 | 50002_jumping_jacks
75 | 50022_light_hopping_stiff
76 | 50027_one_leg_jump
77 | 50002_punching
78 | 50022_shake_shoulders
79 | 50004_running_on_spot
80 | 50020_light_hopping_loose
81 | 50022_shake_hips
82 | 50004_one_leg_loose
83 | 50022_jumping_jacks
84 | 50002_chicken_wings
85 | 50022_one_leg_loose
86 | 50027_knees
87 | 50004_shake_arms
88 | 50007_chicken_wings
89 | 50002_shake_hips
90 | 50007_one_leg_loose
91 | 50004_shake_shoulders
92 | 50009_hips
93 | 50007_running_on_spot
94 | 50025_shake_hips
95 | 50002_shake_shoulders
96 | 50020_shake_arms
97 | 50027_shake_hips
98 | 50026_light_hopping_loose
99 | 50025_one_leg_jump
100 | 50020_one_leg_loose
101 | 50004_one_leg_jump
102 | 50025_punching
103 | 50020_one_leg_jump
104 | 50004_light_hopping_stiff
105 | 50002_light_hopping_loose
106 | 50004_punching
107 | 50007_shake_shoulders
108 | 50009_chicken_wings
109 | 50020_chicken_wings
110 | 50022_light_hopping_loose
111 | 50025_light_hopping_loose
112 | 50026_shake_arms
113 | 50027_shake_shoulders
114 | 50021_chicken_wings
115 | 50021_knees
116 | 50021_one_leg_jump
117 | 50021_punching
118 | 50021_shake_arms
119 | 50021_shake_shoulders
120 | 50021_hips
121 | 50021_light_hopping_stiff
122 | 50021_one_leg_loose
123 | 50021_running_on_spot
124 | 50021_shake_hips
125 |
--------------------------------------------------------------------------------
/scripts/split_files/val.lst:
--------------------------------------------------------------------------------
1 | 50004_chicken_wings
2 | 50020_hips
3 | 50025_shake_arms
4 | 50020_running_on_spot
5 | 50007_light_hopping_loose
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | try:
2 | from setuptools import setup
3 | except ImportError:
4 | from distutils.core import setup
5 | from distutils.extension import Extension
6 | from Cython.Build import cythonize
7 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
8 | import numpy
9 |
10 |
11 | # Get the numpy include directory.
12 | numpy_include_dir = numpy.get_include()
13 |
14 | # Extensions
15 | # pykdtree (kd tree)
16 | pykdtree = Extension(
17 | 'im2mesh.utils.libkdtree.pykdtree.kdtree',
18 | sources=[
19 | 'im2mesh/utils/libkdtree/pykdtree/kdtree.c',
20 | 'im2mesh/utils/libkdtree/pykdtree/_kdtree_core.c'
21 | ],
22 | language='c',
23 | extra_compile_args=['-std=c99', '-O3', '-fopenmp'],
24 | extra_link_args=['-lgomp'],
25 | include_dirs=[numpy_include_dir]
26 | )
27 |
28 | # mcubes (marching cubes algorithm)
29 | mcubes_module = Extension(
30 | 'im2mesh.utils.libmcubes.mcubes',
31 | sources=[
32 | 'im2mesh/utils/libmcubes/mcubes.pyx',
33 | 'im2mesh/utils/libmcubes/pywrapper.cpp',
34 | 'im2mesh/utils/libmcubes/marchingcubes.cpp'
35 | ],
36 | language='c++',
37 | extra_compile_args=['-std=c++11'],
38 | include_dirs=[numpy_include_dir]
39 | )
40 |
41 | # triangle hash (efficient mesh intersection)
42 | triangle_hash_module = Extension(
43 | 'im2mesh.utils.libmesh.triangle_hash',
44 | sources=[
45 | 'im2mesh/utils/libmesh/triangle_hash.pyx'
46 | ],
47 | libraries=['m'], # Unix-like specific
48 | include_dirs=[numpy_include_dir]
49 | )
50 |
51 | # mise (efficient mesh extraction)
52 | mise_module = Extension(
53 | 'im2mesh.utils.libmise.mise',
54 | sources=[
55 | 'im2mesh/utils/libmise/mise.pyx'
56 | ],
57 | )
58 |
59 | # simplify (efficient mesh simplification)
60 | simplify_mesh_module = Extension(
61 | 'im2mesh.utils.libsimplify.simplify_mesh',
62 | sources=[
63 | 'im2mesh/utils/libsimplify/simplify_mesh.pyx'
64 | ],
65 | include_dirs=[numpy_include_dir]
66 | )
67 |
68 | # voxelization (efficient mesh voxelization)
69 | voxelize_module = Extension(
70 | 'im2mesh.utils.libvoxelize.voxelize',
71 | sources=[
72 | 'im2mesh/utils/libvoxelize/voxelize.pyx'
73 | ],
74 | libraries=['m'], # Unix-like specific
75 | include_dirs=[numpy_include_dir]
76 | )
77 | '''
78 | # DMC extensions
79 | dmc_pred2mesh_module = CppExtension(
80 | 'im2mesh.dmc.ops.cpp_modules.pred2mesh',
81 | sources=[
82 | 'im2mesh/dmc/ops/cpp_modules/pred_to_mesh_.cpp',
83 | ],
84 | include_dirs=[numpy_include_dir]
85 | )
86 |
87 | dmc_cuda_module = CUDAExtension(
88 | 'im2mesh.dmc.ops._cuda_ext',
89 | sources=[
90 | 'im2mesh/dmc/ops/src/extension.cpp',
91 | 'im2mesh/dmc/ops/src/curvature_constraint_kernel.cu',
92 | 'im2mesh/dmc/ops/src/grid_pooling_kernel.cu',
93 | 'im2mesh/dmc/ops/src/occupancy_to_topology_kernel.cu',
94 | 'im2mesh/dmc/ops/src/occupancy_connectivity_kernel.cu',
95 | 'im2mesh/dmc/ops/src/point_triangle_distance_kernel.cu',
96 | ]
97 | )
98 | '''
99 |
100 | # Gather all extension modules
101 | ext_modules = [
102 | pykdtree,
103 | mcubes_module,
104 | triangle_hash_module,
105 | mise_module,
106 | simplify_mesh_module,
107 | voxelize_module,
108 | #dmc_pred2mesh_module,
109 | #dmc_cuda_module,
110 | ]
111 |
112 | setup(
113 | ext_modules=cythonize(ext_modules),
114 | cmdclass={
115 | 'build_ext': BuildExtension
116 | }
117 | )
118 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from im2mesh.checkpoints import CheckpointIO
2 | from im2mesh import config, data
3 | import torch
4 | import torch.optim as optim
5 | from tensorboardX import SummaryWriter
6 | import numpy as np
7 | import os
8 | import argparse
9 | import time
10 |
11 |
12 | # Arguments
13 | parser = argparse.ArgumentParser(
14 | description='Train a 4D model.'
15 | )
16 | parser.add_argument('config', type=str, help='Path to config file.')
17 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
18 | parser.add_argument('--exit-after', type=int, default=-1,
19 | help='Checkpoint and exit after specified number of '
20 |                     help='Checkpoint and exit after specified number of '
21 |
22 | args = parser.parse_args()
23 | cfg = config.load_config(args.config, 'configs/default.yaml')
24 | is_cuda = (torch.cuda.is_available() and not args.no_cuda)
25 | device = torch.device("cuda" if is_cuda else "cpu")
26 |
27 | # Set t0
28 | t0 = time.time()
29 |
30 | # Shorthands
31 | out_dir = cfg['training']['out_dir']
32 | batch_size = cfg['training']['batch_size']
33 | batch_size_vis = cfg['training']['batch_size_vis']
34 | batch_size_val = cfg['training']['batch_size_val']
35 | backup_every = cfg['training']['backup_every']
36 | exit_after = args.exit_after
37 | lr = cfg['training']['learning_rate']
38 |
39 | model_selection_metric = cfg['training']['model_selection_metric']
40 | if cfg['training']['model_selection_mode'] == 'maximize':
41 | model_selection_sign = 1
42 | elif cfg['training']['model_selection_mode'] == 'minimize':
43 | model_selection_sign = -1
44 | else:
45 | raise ValueError('model_selection_mode must be '
46 | 'either maximize or minimize.')
47 |
48 | # Output directory
49 | if not os.path.exists(out_dir):
50 | os.makedirs(out_dir)
51 |
52 | # Dataset
53 | train_dataset = config.get_dataset('train', cfg)
54 | val_dataset = config.get_dataset('val', cfg)
55 |
56 | # Dataloader
57 | train_loader = torch.utils.data.DataLoader(
58 | train_dataset, batch_size=batch_size, num_workers=4, shuffle=True,
59 | collate_fn=data.collate_remove_none,
60 | worker_init_fn=data.worker_init_fn)
61 | val_loader = torch.utils.data.DataLoader(
62 | val_dataset, batch_size=batch_size_val, num_workers=1, shuffle=False,
63 | collate_fn=data.collate_remove_none,
64 | worker_init_fn=data.worker_init_fn)
65 |
66 | # For visualizations
67 | vis_loader = torch.utils.data.DataLoader(
68 | val_dataset, batch_size=batch_size_vis, shuffle=True,
69 | collate_fn=data.collate_remove_none,
70 | worker_init_fn=data.worker_init_fn)
71 | data_vis = next(iter(vis_loader))
72 |
73 | # Model
74 | model = config.get_model(cfg, device=device, dataset=train_dataset)
75 |
76 | # Get optimizer and trainer
77 | optimizer = optim.Adam(model.parameters(), lr=lr)
78 | trainer = config.get_trainer(model, optimizer, cfg, device=device)
79 |
80 | # Load pre-trained model if it exists
81 | kwargs = {
82 | 'model': model,
83 | 'optimizer': optimizer,
84 | }
85 | checkpoint_io = CheckpointIO(
86 | out_dir, initialize_from=cfg['model']['initialize_from'],
87 | initialization_file_name=cfg['model']['initialization_file_name'],
88 | **kwargs)
89 | try:
90 | load_dict = checkpoint_io.load('model.pt')
91 | except FileExistsError:  # raised by CheckpointIO when no checkpoint is present
92 | load_dict = dict()
93 | epoch_it = load_dict.get('epoch_it', -1)
94 | it = load_dict.get('it', -1)
95 | metric_val_best = load_dict.get(
96 | 'loss_val_best', -model_selection_sign * np.inf)
97 |
98 | if metric_val_best == np.inf or metric_val_best == -np.inf:
99 | metric_val_best = -model_selection_sign * np.inf
100 |
101 | print('Current best validation metric (%s): %.8f'
102 | % (model_selection_metric, metric_val_best))
103 |
104 | logger = SummaryWriter(os.path.join(out_dir, 'logs'))
105 |
106 | # Shorthands
107 | print_every = cfg['training']['print_every']
108 | checkpoint_every = cfg['training']['checkpoint_every']
109 | validate_every = cfg['training']['validate_every']
110 | visualize_every = cfg['training']['visualize_every']
111 |
112 | # Print model
113 | nparameters = sum(p.numel() for p in model.parameters())
114 | print(model)
115 | print('Total number of parameters: %d' % nparameters)
116 |
117 | # Training loop
118 | while True:
119 | epoch_it += 1
120 |
121 | for batch in train_loader:
122 | it += 1
123 | loss = trainer.train_step(batch)
124 | logger.add_scalar('train/loss', loss, it)
125 |
126 | # Print output
127 | if print_every > 0 and (it % print_every) == 0:
128 | print('[Epoch %02d] it=%03d, loss=%.4f'
129 | % (epoch_it, it, loss))
130 |
131 | # Visualize output
132 | if visualize_every > 0 and (it % visualize_every) == 0:
133 | print('Visualizing')
134 | trainer.visualize(data_vis)
135 |
136 | # Save checkpoint
137 | if (checkpoint_every > 0 and (it % checkpoint_every) == 0):
138 | print('Saving checkpoint')
139 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
140 | loss_val_best=metric_val_best)
141 |
142 | # Backup if necessary
143 | if (backup_every > 0 and (it % backup_every) == 0):
144 | print('Backup checkpoint')
145 | checkpoint_io.save('model_%d.pt' % it, epoch_it=epoch_it, it=it,
146 | loss_val_best=metric_val_best)
147 | # Run validation
148 | if validate_every > 0 and (it % validate_every) == 0:
149 | eval_dict = trainer.evaluate(val_loader)
150 | metric_val = eval_dict[model_selection_metric]
151 | print('Validation metric (%s): %.4f'
152 | % (model_selection_metric, metric_val))
153 |
154 | for k, v in eval_dict.items():
155 | logger.add_scalar('val/%s' % k, v, it)
156 |
157 | if model_selection_sign * (metric_val - metric_val_best) > 0:
158 | metric_val_best = metric_val
159 | print('New best model (loss %.4f)' % metric_val_best)
160 | checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it,
161 | loss_val_best=metric_val_best)
162 |
163 | # Exit if necessary
164 | if exit_after > 0 and (time.time() - t0) >= exit_after:
165 | print('Time limit reached. Exiting.')
166 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
167 | loss_val_best=metric_val_best)
168 | exit(3)
169 |
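Note: model_selection_sign folds 'maximize' and 'minimize' into the single comparison used above: sign * (metric - best) > 0 means improvement in either mode. A tiny sketch:

    def is_improvement(metric_val, metric_val_best, model_selection_sign):
        # sign = +1 (maximize): improvement iff metric_val > metric_val_best
        # sign = -1 (minimize): improvement iff metric_val < metric_val_best
        return model_selection_sign * (metric_val - metric_val_best) > 0

    assert is_improvement(0.9, 0.8, +1)   # higher IoU is better
    assert is_improvement(0.1, 0.2, -1)   # lower Chamfer distance is better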
--------------------------------------------------------------------------------