├── .gitignore ├── LICENSE ├── README.md ├── configs ├── default.yaml ├── pointcloud │ ├── demo_syn_room.yaml │ ├── pretrained │ │ ├── room_1plane.yaml │ │ ├── room_3plane.yaml │ │ ├── room_combine.yaml │ │ ├── room_grid32.yaml │ │ ├── room_grid64.yaml │ │ ├── room_pointconv.yaml │ │ ├── shapenet_1plane.yaml │ │ ├── shapenet_3plane.yaml │ │ ├── shapenet_3plane_partial.yaml │ │ ├── shapenet_grid32.yaml │ │ └── shapenet_pointconv.yaml │ ├── room_1plane.yaml │ ├── room_3plane.yaml │ ├── room_combine.yaml │ ├── room_grid32.yaml │ ├── room_grid64.yaml │ ├── room_pointconv.yaml │ ├── shapenet_1plane.yaml │ ├── shapenet_3plane.yaml │ ├── shapenet_3plane_partial.yaml │ ├── shapenet_grid32.yaml │ └── shapenet_pointconv.yaml ├── pointcloud_crop │ ├── demo_matterport.yaml │ ├── pretrained │ │ └── room_grid64.yaml │ └── room_grid64.yaml └── voxel │ ├── pretrained │ ├── shapenet_1plane.yaml │ ├── shapenet_3plane.yaml │ └── shapenet_grid32.yaml │ ├── shapenet_1plane.yaml │ ├── shapenet_3plane.yaml │ └── shapenet_grid32.yaml ├── environment.yaml ├── eval_meshes.py ├── generate.py ├── media ├── demo_syn_room.gif └── teaser_matterport.gif ├── scripts ├── dataset_matterport │ └── build_dataset.py ├── dataset_scannet │ ├── SensorData.py │ └── build_dataset.py ├── dataset_synthetic_room │ └── build_dataset.py ├── download_data.sh └── download_demo_data.sh ├── setup.py ├── src ├── __init__.py ├── checkpoints.py ├── common.py ├── config.py ├── conv_onet │ ├── __init__.py │ ├── config.py │ ├── generation.py │ ├── models │ │ ├── __init__.py │ │ └── decoder.py │ └── training.py ├── data │ ├── __init__.py │ ├── core.py │ ├── fields.py │ └── transforms.py ├── encoder │ ├── __init__.py │ ├── pointnet.py │ ├── pointnetpp.py │ ├── unet.py │ ├── unet3d.py │ └── voxels.py ├── eval.py ├── layers.py ├── training.py └── utils │ ├── __init__.py │ ├── binvox_rw.py │ ├── icp.py │ ├── io.py │ ├── libkdtree │ ├── .gitignore │ ├── MANIFEST.in │ ├── README │ ├── README.rst │ ├── __init__.py │ ├── pykdtree │ │ ├── __init__.py │ │ ├── _kdtree_core.c │ │ ├── _kdtree_core.c.mako │ │ ├── kdtree.c │ │ ├── kdtree.pyx │ │ ├── render_template.py │ │ └── test_tree.py │ └── setup.cfg │ ├── libmcubes │ ├── .gitignore │ ├── LICENSE │ ├── README.rst │ ├── __init__.py │ ├── exporter.py │ ├── marchingcubes.cpp │ ├── marchingcubes.h │ ├── mcubes.pyx │ ├── pyarray_symbol.h │ ├── pyarraymodule.h │ ├── pywrapper.cpp │ └── pywrapper.h │ ├── libmesh │ ├── .gitignore │ ├── __init__.py │ ├── inside_mesh.py │ └── triangle_hash.pyx │ ├── libmise │ ├── .gitignore │ ├── __init__.py │ ├── mise.pyx │ └── test.py │ ├── libsimplify │ ├── Simplify.h │ ├── __init__.py │ ├── simplify_mesh.pyx │ └── test.py │ ├── libvoxelize │ ├── .gitignore │ ├── __init__.py │ ├── tribox2.h │ └── voxelize.pyx │ ├── mesh.py │ ├── visualize.py │ └── voxels.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | /out 2 | /data 3 | build 4 | .vscode 5 | .pytest_cache 6 | .cache 7 | *.pyc 8 | *.pyd 9 | *.pt 10 | *.so 11 | *.o 12 | *.prof 13 | *.swp 14 | *.lib 15 | *.obj 16 | *.exp 17 | .nfs* 18 | *.jpg 19 | *.png 20 | *.ply 21 | *.off 22 | *.npz 23 | *.txt 24 | /src/utils/libmcubes/mcubes.cpp 25 | /src/utils/libsimplify/simplify_mesh.cpp 26 | /src/utils/libsimplify/build 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars 
Mescheder, Marc Pollefeys, Andreas Geiger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Occupancy Networks 2 | [**Paper**](https://arxiv.org/pdf/2003.04618.pdf) | [**Supplementary**](http://www.cvlibs.net/publications/Peng2020ECCV_supplementary.pdf) | [**Video**](https://www.youtube.com/watch?v=EmauovgrDSM) | [**Teaser Video**](https://youtu.be/k0monzIcjUo) | [**Project Page**](https://pengsongyou.github.io/conv_onet) | [**Blog Post**](https://autonomousvision.github.io/convolutional-occupancy-networks/)
3 | ![Matterport3D teaser](media/teaser_matterport.gif) 4 | 5 | 6 |
7 | 8 | This repository contains the implementation of the paper: 9 | 10 | Convolutional Occupancy Networks 11 | [Songyou Peng](https://pengsongyou.github.io/), [Michael Niemeyer](https://m-niemeyer.github.io/), [Lars Mescheder](https://is.tuebingen.mpg.de/person/lmescheder), [Marc Pollefeys](https://www.inf.ethz.ch/personal/pomarc/) and [Andreas Geiger](http://www.cvlibs.net/) 12 | **ECCV 2020 (spotlight)** 13 | 14 | If you find our code or paper useful, please consider citing 15 | ```bibtex 16 | @inproceedings{Peng2020ECCV, 17 | author = {Peng, Songyou and Niemeyer, Michael and Mescheder, Lars and Pollefeys, Marc and Geiger, Andreas}, 18 | title = {Convolutional Occupancy Networks}, 19 | booktitle = {European Conference on Computer Vision (ECCV)}, 20 | year = {2020}} 21 | ``` 22 | Contact [Songyou Peng](mailto:songyou.pp@gmail.com) for questions, comments and reporting bugs. 23 | 24 | 25 | ## Installation 26 | First you have to make sure that you have all dependencies in place. 27 | The simplest way to do so, is to use [anaconda](https://www.anaconda.com/). 28 | 29 | You can create an anaconda environment called `conv_onet` using 30 | ``` 31 | conda env create -f environment.yaml 32 | conda activate conv_onet 33 | ``` 34 | **Note**: you might need to install **torch-scatter** mannually following [the official instruction](https://github.com/rusty1s/pytorch_scatter#pytorch-140): 35 | ``` 36 | pip install torch-scatter==2.0.4 -f https://pytorch-geometric.com/whl/torch-1.4.0+cu101.html 37 | ``` 38 | 39 | Next, compile the extension modules. 40 | You can do this via 41 | ``` 42 | python setup.py build_ext --inplace 43 | ``` 44 | 45 | ## Demo 46 | First, run the script to get the demo data: 47 | ``` 48 | bash scripts/download_demo_data.sh 49 | ``` 50 | ### Reconstruct Large-Scale Matterport3D Scene 51 | You can now quickly test our code on the real-world scene shown in the teaser. To this end, simply run: 52 | ``` 53 | python generate.py configs/pointcloud_crop/demo_matterport.yaml 54 | ``` 55 | This script should create a folder `out/demo_matterport/generation` where the output meshes and input point cloud are stored. 56 | 57 | **Note**: This experiment corresponds to our **fully convolutional model**, which we train only on the small crops from our synthetic room dataset. This model can be directly applied to large-scale real-world scenes with real units and generate meshes in a sliding-window manner, as shown in the [teaser](media/teaser_matterport.gif). More details can be found in section 6 of our [supplementary material](http://www.cvlibs.net/publications/Peng2020ECCV_supplementary.pdf). For training, you can use the script `pointcloud_crop/room_grid64.yaml`. 58 | 59 | 60 | ### Reconstruct Synthetic Indoor Scene 61 |
62 | ![Synthetic indoor scene reconstruction demo](media/demo_syn_room.gif)
64 | 65 | You can also test on our synthetic room dataset by running: 66 | ``` 67 | python generate.py configs/pointcloud/demo_syn_room.yaml 68 | ``` 69 | ## Dataset 70 | 71 | To evaluate a pretrained model or train a new model from scratch, you have to obtain the respective dataset. 72 | In this paper, we consider 4 different datasets: 73 | 74 | ### ShapeNet 75 | You can download the dataset (73.4 GB) by running the [script](https://github.com/autonomousvision/occupancy_networks#preprocessed-data) from Occupancy Networks. After, you should have the dataset in `data/ShapeNet` folder. 76 | 77 | ### Synthetic Indoor Scene Dataset 78 | For scene-level reconstruction, we create a synthetic dataset of 5000 79 | scenes with multiple objects from ShapeNet (chair, sofa, lamp, cabinet, table). There are also ground planes and randomly sampled walls. 80 | 81 | You can download our preprocessed data (144 GB) using 82 | 83 | ``` 84 | bash scripts/download_data.sh 85 | ``` 86 | 87 | This script should download and unpack the data automatically into the `data/synthetic_room_dataset` folder. 88 | **Note**: We also provide **point-wise semantic labels** in the dataset, which might be useful. 89 | 90 | 91 | Alternatively, you can also preprocess the dataset yourself. 92 | To this end, you can: 93 | * download the ShapeNet dataset as described above. 94 | * check `scripts/dataset_synthetic_room/build_dataset.py`, modify the path and run the code. 95 | 96 | ### Matterport3D 97 | Download Matterport3D dataset from [the official website](https://niessner.github.io/Matterport/). And then, use `scripts/dataset_matterport/build_dataset.py` to preprocess one of your favorite scenes. Put the processed data into `data/Matterport3D_processed` folder. 98 | 99 | ### ScanNet 100 | Download ScanNet v2 data from the [official ScanNet website](https://github.com/ScanNet/ScanNet). 101 | Then, you can preprocess data with: 102 | `scripts/dataset_scannet/build_dataset.py` and put into `data/ScanNet` folder. 103 | **Note**: Currently, the preprocess script normalizes ScanNet data to a unit cube for the comparison shown in the paper, but you can easily adapt the code to produce data with real-world metric. You can then use our fully convolutional model to run evaluation in a sliding-window manner. 104 | 105 | ## Usage 106 | When you have installed all binary dependencies and obtained the preprocessed data, you are ready to run our pre-trained models and train new models from scratch. 107 | 108 | ### Mesh Generation 109 | To generate meshes using a trained model, use 110 | ``` 111 | python generate.py CONFIG.yaml 112 | ``` 113 | where you replace `CONFIG.yaml` with the correct config file. 114 | 115 | **Use a pre-trained model** 116 | The easiest way is to use a pre-trained model. You can do this by using one of the config files under the `pretrained` folders. 117 | 118 | For example, for 3D reconstruction from noisy point cloud with our 3-plane model on the synthetic room dataset, you can simply run: 119 | ``` 120 | python generate.py configs/pointcloud/pretrained/room_3plane.yaml 121 | ``` 122 | The script will automatically download the pretrained model and run the generation. You can find the outputs in the `out/.../generation_pretrained` folders 123 | 124 | Note that the config files are only for generation, not for training new models: when these configs are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model. 
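For orientation, the following is a condensed sketch of what `generate.py` does under the hood when given such a config; the config path, output file name and the early `break` are only for illustration, and for normal use the command above is all you need:
```
import os
import torch
from src import config
from src.checkpoints import CheckpointIO

cfg = config.load_config('configs/pointcloud/pretrained/room_3plane.yaml',
                         'configs/default.yaml')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the test dataset, the model and the mesh generator from the config.
dataset = config.get_dataset('test', cfg, return_idx=True)
model = config.get_model(cfg, device=device, dataset=dataset)

# In the pretrained configs, test.model_file is a URL, so the released
# weights are downloaded instead of loading a local checkpoint.
checkpoint_io = CheckpointIO(cfg['training']['out_dir'], model=model)
checkpoint_io.load(cfg['test']['model_file'])

generator = config.get_generator(model, cfg, device=device)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

model.eval()
for it, data in enumerate(loader):
    out = generator.generate_mesh(data)
    mesh = out[0] if isinstance(out, tuple) else out  # some generators also return stats
    mesh.export(os.path.join('out', '%d.off' % it))
    break  # first object/scene only, as an example
```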
125 | 126 | 127 | We provide the following pretrained models: 128 | ``` 129 | pointcloud/shapenet_1plane.pt 130 | pointcloud/shapenet_3plane.pt 131 | pointcloud/shapenet_grid32.pt 132 | pointcloud/shapenet_3plane_partial.pt 133 | pointcloud/shapenet_pointconv.pt 134 | pointcloud/room_1plane.pt 135 | pointcloud/room_3plane.pt 136 | pointcloud/room_grid32.pt 137 | pointcloud/room_grid64.pt 138 | pointcloud/room_combine.pt 139 | pointcloud/room_pointconv.pt 140 | pointcloud_crop/room_grid64.pt 141 | voxel/voxel_shapenet_1plane.pt 142 | voxel/voxel_shapenet_3plane.pt 143 | voxel/voxel_shapenet_grid32.pt 144 | ``` 145 | 146 | ### Evaluation 147 | For evaluation of the models, we provide the script `eval_meshes.py`. You can run it using: 148 | ``` 149 | python eval_meshes.py CONFIG.yaml 150 | ``` 151 | The script takes the meshes generated in the previous step and evaluates them using a standardized protocol. The output will be written to `.pkl/.csv` files in the corresponding generation folder which can be processed using [pandas](https://pandas.pydata.org/). 152 | 153 | **Note:** We follow previous works to use "use 1/10 times the maximal edge length of the current object’s bounding box as unit 1" (see [Section 4 - Metrics](http://www.cvlibs.net/publications/Mescheder2019CVPR.pdf)). In practice, this means that we multiply the Chamfer-L1 by a factor of 10 for reporting the numbers in the paper. 154 | 155 | ### Training 156 | Finally, to train a new network from scratch, run: 157 | ``` 158 | python train.py CONFIG.yaml 159 | ``` 160 | For available training options, please take a look at `configs/default.yaml`. 161 | 162 | ## Further Information 163 | Please also check out the following concurrent works that either tackle similar problems or share similar ideas: 164 | - [[CVPR 2020] Jiang et al. - Local Implicit Grid Representations for 3D Scenes](https://arxiv.org/abs/2003.08981) 165 | - [[CVPR 2020] Chibane et al. Implicit Functions in Feature Space for 3D Shape Reconstruction and Completion](https://arxiv.org/abs/2003.01456) 166 | - [[ECCV 2020] Chabra et al. 
- Deep Local Shapes: Learning Local SDF Priors for Detailed 3D Reconstruction](https://arxiv.org/abs/2003.10983) 167 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | dataset: Shapes3D 4 | path: data/ShapeNet 5 | watertight_path: data/watertight 6 | classes: null 7 | input_type: img 8 | train_split: train 9 | val_split: val 10 | test_split: test 11 | dim: 3 12 | points_file: points.npz 13 | points_iou_file: points.npz 14 | multi_files: null 15 | points_subsample: 1024 16 | points_unpackbits: true 17 | model_file: model.off 18 | watertight_file: model_watertight.off 19 | img_folder: img 20 | img_size: 224 21 | img_with_camera: false 22 | img_augment: false 23 | n_views: 24 24 | pointcloud_file: pointcloud.npz 25 | pointcloud_chamfer_file: pointcloud.npz 26 | pointcloud_n: 256 27 | pointcloud_target_n: 1024 28 | pointcloud_noise: 0.05 29 | voxels_file: 'model.binvox' 30 | padding: 0.1 31 | model: 32 | decoder: simple 33 | encoder: resnet18 34 | decoder_kwargs: {} 35 | encoder_kwargs: {} 36 | multi_gpu: false 37 | c_dim: 512 38 | training: 39 | out_dir: out/default 40 | batch_size: 64 41 | print_every: 200 42 | visualize_every: 1000 43 | checkpoint_every: 1000 44 | validate_every: 2000 45 | backup_every: 100000 46 | eval_sample: false 47 | model_selection_metric: loss 48 | model_selection_mode: minimize 49 | n_workers: 4 50 | n_workers_val: 4 51 | test: 52 | threshold: 0.5 53 | eval_mesh: true 54 | eval_pointcloud: true 55 | remove_wall: false 56 | model_file: model_best.pt 57 | generation: 58 | batch_size: 100000 59 | refinement_step: 0 60 | vis_n_outputs: 30 61 | generate_mesh: true 62 | generate_pointcloud: true 63 | generation_dir: generation 64 | use_sampling: false 65 | resolution_0: 32 66 | upsampling_steps: 2 67 | simplify_nfaces: null 68 | copy_groundtruth: false 69 | copy_input: true 70 | latent_number: 4 71 | latent_H: 8 72 | latent_W: 8 73 | latent_ny: 2 74 | latent_nx: 2 75 | latent_repeat: true 76 | sliding_window: False # added for crop generation -------------------------------------------------------------------------------- /configs/pointcloud/demo_syn_room.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/room_combine.yaml 2 | data: 3 | classes: [''] 4 | path: data/demo/synthetic_room_dataset 5 | pointcloud_n: 10000 6 | pointcloud_file: pointcloud 7 | voxels_file: null 8 | points_file: null 9 | points_iou_file: null 10 | training: 11 | out_dir: out/demo_syn_room 12 | test: 13 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/room_combine.pt 14 | generation: 15 | generation_dir: generation 16 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/room_1plane.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/room_1plane.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/room_1plane.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/room_3plane.yaml: -------------------------------------------------------------------------------- 1 | 
inherit_from: configs/pointcloud/room_3plane.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/room_3plane.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/room_combine.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/room_combine.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/room_combine.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/room_grid32.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/room_grid32.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/room_grid32.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/room_grid64.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/room_grid64.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/room_grid64.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/room_pointconv.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/room_pointconv.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/room_pointconv.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/shapenet_1plane.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/shapenet_1plane.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/shapenet_1plane.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/shapenet_3plane.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/shapenet_3plane.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/shapenet_3plane.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/shapenet_3plane_partial.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/shapenet_3plane_partial.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: 
https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/shapenet_3plane_partial.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/shapenet_grid32.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/shapenet_grid32.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/shapenet_grid32.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/pretrained/shapenet_pointconv.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud/shapenet_pointconv.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud/shapenet_pointconv.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud/room_1plane.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 5 | path: data/synthetic_room_dataset 6 | pointcloud_n: 10000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points_iou 10 | points_iou_file: points_iou 11 | pointcloud_file: pointcloud 12 | pointcloud_chamfer_file: pointcloud 13 | multi_files: 10 14 | voxels_file: null 15 | model: 16 | encoder: pointnet_local_pool 17 | encoder_kwargs: 18 | hidden_dim: 32 19 | plane_type: ['xz'] 20 | plane_resolution: 128 21 | unet: True 22 | unet_kwargs: 23 | depth: 5 24 | merge_mode: concat 25 | start_filts: 32 26 | decoder: simple_local 27 | decoder_kwargs: 28 | sample_mode: bilinear # bilinear / nearest 29 | hidden_size: 32 30 | c_dim: 32 31 | training: 32 | out_dir: out/pointcloud/room_1plane 33 | batch_size: 32 34 | model_selection_metric: iou 35 | model_selection_mode: maximize 36 | print_every: 100 37 | visualize_every: 10000 38 | validate_every: 10000 39 | checkpoint_every: 2000 40 | backup_every: 10000 41 | n_workers: 8 42 | n_workers_val: 4 43 | test: 44 | threshold: 0.2 45 | eval_mesh: true 46 | eval_pointcloud: false 47 | remove_wall: true 48 | model_file: model_best.pt 49 | generation: 50 | vis_n_outputs: 2 51 | refine: false 52 | n_x: 128 53 | n_z: 1 54 | -------------------------------------------------------------------------------- /configs/pointcloud/room_3plane.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 5 | path: data/synthetic_room_dataset 6 | pointcloud_n: 10000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points_iou 10 | points_iou_file: points_iou 11 | pointcloud_file: pointcloud 12 | pointcloud_chamfer_file: pointcloud 13 | multi_files: 10 14 | voxels_file: null 15 | model: 16 | encoder: pointnet_local_pool 17 | encoder_kwargs: 18 | hidden_dim: 32 19 | plane_type: ['xz', 'xy', 'yz'] 20 | plane_resolution: 128 21 | unet: True 22 | unet_kwargs: 23 | depth: 5 24 | merge_mode: concat 25 | start_filts: 32 26 | decoder: simple_local 27 | 
decoder_kwargs: 28 | sample_mode: bilinear # bilinear / nearest 29 | hidden_size: 32 30 | c_dim: 32 31 | training: 32 | out_dir: out/pointcloud/room_3plane 33 | batch_size: 32 34 | model_selection_metric: iou 35 | model_selection_mode: maximize 36 | print_every: 100 37 | visualize_every: 10000 38 | validate_every: 10000 39 | checkpoint_every: 2000 40 | backup_every: 10000 41 | n_workers: 8 42 | n_workers_val: 4 43 | test: 44 | threshold: 0.2 45 | eval_mesh: true 46 | eval_pointcloud: false 47 | remove_wall: true 48 | model_file: model_best.pt 49 | generation: 50 | vis_n_outputs: 2 51 | refine: false 52 | n_x: 128 53 | n_z: 1 54 | -------------------------------------------------------------------------------- /configs/pointcloud/room_combine.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 5 | path: data/synthetic_room_dataset 6 | pointcloud_n: 10000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points_iou 10 | points_iou_file: points_iou 11 | pointcloud_file: pointcloud 12 | pointcloud_chamfer_file: pointcloud 13 | multi_files: 10 14 | voxels_file: null 15 | model: 16 | encoder: pointnet_local_pool 17 | encoder_kwargs: 18 | hidden_dim: 32 19 | plane_type: ['xz', 'xy', 'yz', 'grid'] 20 | grid_resolution: 32 21 | unet3d: True 22 | unet3d_kwargs: 23 | num_levels: 3 24 | f_maps: 32 25 | in_channels: 32 26 | out_channels: 32 27 | plane_resolution: 128 28 | unet: True 29 | unet_kwargs: 30 | depth: 5 31 | merge_mode: concat 32 | start_filts: 32 33 | decoder: simple_local 34 | decoder_kwargs: 35 | sample_mode: bilinear # bilinear / nearest 36 | hidden_size: 32 37 | c_dim: 32 38 | training: 39 | out_dir: out/pointcloud/room_combine 40 | batch_size: 24 41 | model_selection_metric: iou 42 | model_selection_mode: maximize 43 | print_every: 100 44 | visualize_every: 10000 45 | validate_every: 10000 46 | checkpoint_every: 2000 47 | backup_every: 10000 48 | n_workers: 8 49 | n_workers_val: 4 50 | test: 51 | threshold: 0.2 52 | eval_mesh: true 53 | eval_pointcloud: false 54 | remove_wall: true 55 | model_file: model_best.pt 56 | generation: 57 | vis_n_outputs: 2 58 | refine: false 59 | n_x: 128 60 | n_z: 1 61 | -------------------------------------------------------------------------------- /configs/pointcloud/room_grid32.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 5 | path: data/synthetic_room_dataset 6 | pointcloud_n: 10000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points_iou 10 | points_iou_file: points_iou 11 | pointcloud_file: pointcloud 12 | pointcloud_chamfer_file: pointcloud 13 | multi_files: 10 14 | voxels_file: null 15 | model: 16 | encoder: pointnet_local_pool 17 | encoder_kwargs: 18 | hidden_dim: 32 19 | plane_type: 'grid' 20 | grid_resolution: 32 21 | unet3d: True 22 | unet3d_kwargs: 23 | num_levels: 3 24 | f_maps: 32 25 | in_channels: 32 26 | out_channels: 32 27 | decoder: simple_local 28 | decoder_kwargs: 29 | sample_mode: bilinear # bilinear / nearest 30 | hidden_size: 32 31 | c_dim: 32 32 | training: 33 | out_dir: out/pointcloud/room_grid32 34 | batch_size: 32 35 | model_selection_metric: iou 36 | model_selection_mode: maximize 37 | print_every: 100 38 | visualize_every: 10000 39 | 
validate_every: 10000 40 | checkpoint_every: 2000 41 | backup_every: 10000 42 | n_workers: 8 43 | n_workers_val: 4 44 | test: 45 | threshold: 0.2 46 | eval_mesh: true 47 | eval_pointcloud: false 48 | remove_wall: true 49 | model_file: model_best.pt 50 | generation: 51 | vis_n_outputs: 2 52 | refine: false 53 | n_x: 128 54 | n_z: 1 55 | -------------------------------------------------------------------------------- /configs/pointcloud/room_grid64.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 5 | path: data/synthetic_room_dataset 6 | pointcloud_n: 10000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points_iou 10 | points_iou_file: points_iou 11 | pointcloud_file: pointcloud 12 | pointcloud_chamfer_file: pointcloud 13 | multi_files: 10 14 | voxels_file: null 15 | model: 16 | encoder: pointnet_local_pool 17 | encoder_kwargs: 18 | hidden_dim: 32 19 | plane_type: 'grid' 20 | grid_resolution: 64 21 | unet3d: True 22 | unet3d_kwargs: 23 | num_levels: 4 24 | f_maps: 32 25 | in_channels: 32 26 | out_channels: 32 27 | decoder: simple_local 28 | decoder_kwargs: 29 | sample_mode: bilinear # bilinear / nearest 30 | hidden_size: 32 31 | c_dim: 32 32 | training: 33 | out_dir: out/pointcloud/room_grid64 34 | batch_size: 6 35 | model_selection_metric: iou 36 | model_selection_mode: maximize 37 | print_every: 100 38 | visualize_every: 10000 39 | validate_every: 10000 40 | checkpoint_every: 2000 41 | backup_every: 10000 42 | n_workers: 8 43 | n_workers_val: 4 44 | test: 45 | threshold: 0.2 46 | eval_mesh: true 47 | eval_pointcloud: false 48 | remove_wall: true 49 | model_file: model_best.pt 50 | generation: 51 | generation_dir: generation_new 52 | vis_n_outputs: 2 53 | refine: false 54 | n_x: 128 55 | n_z: 1 56 | -------------------------------------------------------------------------------- /configs/pointcloud/room_pointconv.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 5 | path: data/synthetic_room_dataset 6 | pointcloud_n: 10000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points_iou 10 | points_iou_file: points_iou 11 | pointcloud_file: pointcloud 12 | pointcloud_chamfer_file: pointcloud 13 | multi_files: 10 14 | voxels_file: null 15 | model: 16 | encoder: pointnet_plus_plus 17 | decoder: simple_local_point 18 | decoder_kwargs: 19 | sample_mode: gaussian 20 | gaussian_val: 0.2 21 | c_dim: 32 22 | training: 23 | out_dir: out/pointcloud/room_pointconv 24 | batch_size: 20 25 | model_selection_metric: iou 26 | model_selection_mode: maximize 27 | print_every: 10 28 | visualize_every: 10000 29 | validate_every: 10000 30 | checkpoint_every: 2000 31 | backup_every: 10000 32 | n_workers: 8 33 | n_workers_val: 4 34 | test: 35 | threshold: 0.2 36 | eval_mesh: true 37 | eval_pointcloud: false 38 | model_file: model_best.pt 39 | generation: 40 | vis_n_outputs: 2 41 | refine: false 42 | n_x: 128 43 | n_z: 1 44 | -------------------------------------------------------------------------------- /configs/pointcloud/shapenet_1plane.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: null 5 | path: data/ShapeNet 6 | 
pointcloud_n: 3000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points.npz 10 | points_iou_file: points.npz 11 | voxels_file: null 12 | model: 13 | encoder: pointnet_local_pool 14 | encoder_kwargs: 15 | hidden_dim: 32 16 | plane_type: ['xz'] 17 | plane_resolution: 64 18 | unet: True 19 | unet_kwargs: 20 | depth: 4 21 | merge_mode: concat 22 | start_filts: 32 23 | decoder: simple_local 24 | decoder_kwargs: 25 | sample_mode: bilinear # bilinear / nearest 26 | hidden_size: 32 27 | c_dim: 32 28 | training: 29 | out_dir: out/pointcloud/shapenet_1plane 30 | batch_size: 32 31 | model_selection_metric: iou 32 | model_selection_mode: maximize 33 | print_every: 1000 34 | visualize_every: 10000 35 | validate_every: 10000 36 | checkpoint_every: 2000 37 | backup_every: 10000 38 | n_workers: 8 39 | n_workers_val: 4 40 | test: 41 | threshold: 0.2 42 | eval_mesh: true 43 | eval_pointcloud: false 44 | model_file: model_best.pt 45 | generation: 46 | vis_n_outputs: 2 47 | refine: false 48 | n_x: 128 49 | n_z: 1 50 | -------------------------------------------------------------------------------- /configs/pointcloud/shapenet_3plane.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: null 5 | path: data/ShapeNet 6 | pointcloud_n: 3000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points.npz 10 | points_iou_file: points.npz 11 | voxels_file: null 12 | model: 13 | encoder: pointnet_local_pool 14 | encoder_kwargs: 15 | hidden_dim: 32 16 | plane_type: ['xz', 'xy', 'yz'] 17 | plane_resolution: 64 18 | unet: True 19 | unet_kwargs: 20 | depth: 4 21 | merge_mode: concat 22 | start_filts: 32 23 | decoder: simple_local 24 | decoder_kwargs: 25 | sample_mode: bilinear # bilinear / nearest 26 | hidden_size: 32 27 | c_dim: 32 28 | training: 29 | out_dir: out/pointcloud/shapenet_3plane 30 | batch_size: 32 31 | model_selection_metric: iou 32 | model_selection_mode: maximize 33 | print_every: 1000 34 | visualize_every: 10000 35 | validate_every: 10000 36 | checkpoint_every: 2000 37 | backup_every: 10000 38 | n_workers: 8 39 | n_workers_val: 4 40 | test: 41 | threshold: 0.2 42 | eval_mesh: true 43 | eval_pointcloud: false 44 | model_file: model_best.pt 45 | generation: 46 | vis_n_outputs: 2 47 | refine: false 48 | n_x: 128 49 | n_z: 1 50 | -------------------------------------------------------------------------------- /configs/pointcloud/shapenet_3plane_partial.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: partial_pointcloud 4 | classes: null 5 | path: data/ShapeNet 6 | pointcloud_n: 3000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points.npz 10 | points_iou_file: points.npz 11 | voxels_file: null 12 | model: 13 | encoder: pointnet_local_pool 14 | encoder_kwargs: 15 | hidden_dim: 32 16 | plane_type: ['xz', 'xy', 'yz'] 17 | plane_resolution: 64 18 | unet: True 19 | unet_kwargs: 20 | depth: 4 21 | merge_mode: concat 22 | start_filts: 32 23 | decoder: simple_local 24 | decoder_kwargs: 25 | sample_mode: bilinear # bilinear / nearest 26 | hidden_size: 32 27 | c_dim: 32 28 | training: 29 | out_dir: out/pointcloud/partial_shapenet 30 | batch_size: 32 31 | model_selection_metric: iou 32 | model_selection_mode: maximize 33 | print_every: 100 34 | visualize_every: 10000 35 | validate_every: 10000 36 | checkpoint_every: 2000 37 | backup_every: 10000 38 | n_workers: 
12 39 | n_workers_val: 12 40 | test: 41 | threshold: 0.2 42 | eval_mesh: true 43 | eval_pointcloud: false 44 | model_file: model_best.pt 45 | generation: 46 | vis_n_outputs: 2 47 | refine: false 48 | n_x: 128 49 | n_z: 1 50 | -------------------------------------------------------------------------------- /configs/pointcloud/shapenet_grid32.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: null 5 | path: data/ShapeNet 6 | pointcloud_n: 3000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points.npz 10 | points_iou_file: points.npz 11 | # points_unpackbits: false 12 | voxels_file: null 13 | model: 14 | encoder: pointnet_local_pool 15 | encoder_kwargs: 16 | hidden_dim: 32 17 | plane_type: 'grid' 18 | grid_resolution: 32 19 | unet3d: True 20 | unet3d_kwargs: 21 | num_levels: 3 22 | f_maps: 32 23 | in_channels: 32 24 | out_channels: 32 25 | decoder: simple_local 26 | decoder_kwargs: 27 | sample_mode: bilinear # bilinear / nearest 28 | hidden_size: 32 29 | c_dim: 32 30 | training: 31 | out_dir: out/pointcloud/shapenet_grid32 32 | batch_size: 32 33 | model_selection_metric: iou 34 | model_selection_mode: maximize 35 | print_every: 100 36 | visualize_every: 10000 37 | validate_every: 10000 38 | checkpoint_every: 2000 39 | backup_every: 10000 40 | n_workers: 8 41 | n_workers_val: 4 42 | test: 43 | threshold: 0.2 44 | eval_mesh: true 45 | eval_pointcloud: false 46 | model_file: model_best.pt 47 | generation: 48 | vis_n_outputs: 2 49 | refine: false 50 | n_x: 128 51 | n_z: 1 52 | -------------------------------------------------------------------------------- /configs/pointcloud/shapenet_pointconv.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud 4 | classes: null 5 | path: data/ShapeNet 6 | pointcloud_n: 3000 7 | pointcloud_noise: 0.005 8 | points_subsample: 2048 9 | points_file: points.npz 10 | points_iou_file: points.npz 11 | voxels_file: null 12 | model: 13 | encoder: pointnet_plus_plus 14 | decoder: simple_local_point 15 | decoder_kwargs: 16 | hidden_size: 32 17 | sample_mode: gaussian 18 | gaussian_val: 0.2 19 | c_dim: 32 20 | training: 21 | out_dir: out/pointcloud/shapenet_pointconv 22 | batch_size: 24 23 | model_selection_metric: iou 24 | model_selection_mode: maximize 25 | print_every: 100 26 | visualize_every: 10000 27 | validate_every: 10000 28 | checkpoint_every: 2000 29 | backup_every: 10000 30 | n_workers: 8 31 | n_workers_val: 4 32 | test: 33 | threshold: 0.2 34 | eval_mesh: true 35 | eval_pointcloud: false 36 | model_file: model_best.pt 37 | generation: 38 | vis_n_outputs: 2 39 | refine: false 40 | n_x: 128 41 | n_z: 1 42 | -------------------------------------------------------------------------------- /configs/pointcloud_crop/demo_matterport.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud_crop/room_grid64.yaml 2 | data: 3 | input_type: pointcloud_crop 4 | classes: [''] 5 | path: data/demo/Matterport3D_processed 6 | pointcloud_n: 200000 7 | pointcloud_noise: 0.0 8 | pointcloud_file: pointcloud.npz 9 | voxels_file: null 10 | points_file: null 11 | points_iou_file: null 12 | multi_files: null 13 | unit_size: 0.02 # define the size of a voxel, in meter 14 | query_vol_size: 90 # query crop in voxel 15 | training: 16 | out_dir: out/demo_matterport 17 | test: 18 | model_file: 
https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud_crop/room_grid64.pt 19 | generation: 20 | generation_dir: generation 21 | sliding_window: True # generate mesh in the sliding-window manner 22 | resolution_0: 128 # resolution for each crop 23 | upsampling_steps: 0 24 | -------------------------------------------------------------------------------- /configs/pointcloud_crop/pretrained/room_grid64.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/pointcloud_crop/room_grid64.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/pointcloud_crop/room_grid64.pt 6 | -------------------------------------------------------------------------------- /configs/pointcloud_crop/room_grid64.yaml: -------------------------------------------------------------------------------- 1 | method: conv_onet 2 | data: 3 | input_type: pointcloud_crop 4 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 5 | path: data/synthetic_room_dataset 6 | pointcloud_n: 10000 7 | pointcloud_noise: 0.005 8 | points_subsample: 1024 9 | points_file: points_iou 10 | points_iou_file: points_iou 11 | pointcloud_file: pointcloud 12 | pointcloud_chamfer_file: pointcloud 13 | voxels_file: null 14 | multi_files: 10 15 | unit_size: 0.005 # size of a voxel 16 | query_vol_size: 25 17 | model: 18 | local_coord: True 19 | encoder: pointnet_crop_local_pool 20 | encoder_kwargs: 21 | hidden_dim: 32 22 | plane_type: ['grid'] 23 | unet3d: True 24 | unet3d_kwargs: 25 | num_levels: 4 # define the receptive field, 3 -> 32, 4 -> 64 26 | f_maps: 32 27 | in_channels: 32 28 | out_channels: 32 29 | decoder: simple_local_crop 30 | decoder_kwargs: 31 | sample_mode: bilinear # bilinear / nearest 32 | hidden_size: 32 33 | c_dim: 32 34 | training: 35 | out_dir: out/pointcloud_crop_training 36 | batch_size: 2 37 | model_selection_metric: iou 38 | model_selection_mode: maximize 39 | print_every: 100 40 | visualize_every: 10000 41 | validate_every: 1000000000 # TODO: validation for crop training 42 | checkpoint_every: 1000 43 | backup_every: 10000 44 | n_workers: 8 45 | n_workers_val: 4 46 | test: 47 | threshold: 0.2 48 | eval_mesh: true 49 | eval_pointcloud: false 50 | model_file: model_best.pt 51 | generation: 52 | generation_dir: generation 53 | vis_n_outputs: 2 54 | sliding_window: True 55 | resolution_0: 32 56 | upsampling_steps: 0 -------------------------------------------------------------------------------- /configs/voxel/pretrained/shapenet_1plane.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/voxel/shapenet_1plane.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/voxel/voxel_shapenet_1plane.pt 6 | -------------------------------------------------------------------------------- /configs/voxel/pretrained/shapenet_3plane.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/voxel/shapenet_3plane.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/voxel/voxel_shapenet_3plane.pt 6 | 
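As a quick sanity check on the `pointcloud_crop` configs above: `unit_size` is the edge length of one voxel and `query_vol_size` is the crop size in voxels (per the inline comments in those files), so the metric extent of a query crop is simply their product; the snippet below is illustrative only:
```
# configs/pointcloud_crop/room_grid64.yaml (training, normalised room units)
unit_size, query_vol_size = 0.005, 25
print(unit_size * query_vol_size)   # 0.125 -> edge length of one query crop

# configs/pointcloud_crop/demo_matterport.yaml (real-world units, metres)
unit_size, query_vol_size = 0.02, 90
print(unit_size * query_vol_size)   # 1.8 -> each sliding-window crop spans 1.8 m
```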
-------------------------------------------------------------------------------- /configs/voxel/pretrained/shapenet_grid32.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/voxel/shapenet_grid32.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/models/voxel/voxel_shapenet_grid32.pt 6 | -------------------------------------------------------------------------------- /configs/voxel/shapenet_1plane.yaml: -------------------------------------------------------------------------------- 1 | method: onet 2 | data: 3 | classes: null 4 | input_type: voxels 5 | path: data/ShapeNet 6 | dim: 3 7 | points_subsample: 1024 8 | model: 9 | encoder: voxel_simple_local 10 | encoder_kwargs: 11 | plane_resolution: 64 12 | plane_type: ['xz'] 13 | unet: True 14 | unet_kwargs: 15 | depth: 4 16 | merge_mode: concat 17 | start_filts: 32 18 | decoder: simple_local 19 | decoder_kwargs: 20 | sample_mode: bilinear # bilinear / nearest 21 | hidden_size: 32 22 | c_dim: 32 23 | training: 24 | out_dir: out/voxels/shapenet_1plane 25 | batch_size: 64 26 | model_selection_metric: iou 27 | model_selection_mode: maximize 28 | print_every: 100 29 | visualize_every: 10000 30 | validate_every: 10000 31 | checkpoint_every: 2000 32 | backup_every: 10000 33 | n_workers: 8 34 | n_workers_val: 4 35 | test: 36 | threshold: 0.2 37 | eval_mesh: true 38 | eval_pointcloud: false 39 | generation: 40 | vis_n_outputs: 2 41 | refine: false 42 | n_x: 128 43 | n_z: 1 44 | -------------------------------------------------------------------------------- /configs/voxel/shapenet_3plane.yaml: -------------------------------------------------------------------------------- 1 | method: onet 2 | data: 3 | classes: null 4 | input_type: voxels 5 | path: data/ShapeNet 6 | dim: 3 7 | points_subsample: 1024 8 | model: 9 | encoder: voxel_simple_local 10 | encoder_kwargs: 11 | plane_resolution: 64 12 | plane_type: ['xz', 'xy', 'yz'] 13 | unet: True 14 | unet_kwargs: 15 | depth: 4 16 | merge_mode: concat 17 | start_filts: 32 18 | decoder: simple_local 19 | decoder_kwargs: 20 | sample_mode: bilinear # bilinear / nearest 21 | hidden_size: 32 22 | c_dim: 32 23 | training: 24 | out_dir: out/voxels/shapenet_3plane 25 | batch_size: 64 26 | model_selection_metric: iou 27 | model_selection_mode: maximize 28 | print_every: 100 29 | visualize_every: 10000 30 | validate_every: 10000 31 | checkpoint_every: 2000 32 | backup_every: 10000 33 | n_workers: 8 34 | n_workers_val: 4 35 | test: 36 | threshold: 0.2 37 | eval_mesh: true 38 | eval_pointcloud: false 39 | generation: 40 | vis_n_outputs: 2 41 | refine: false 42 | n_x: 128 43 | n_z: 1 44 | -------------------------------------------------------------------------------- /configs/voxel/shapenet_grid32.yaml: -------------------------------------------------------------------------------- 1 | method: onet 2 | data: 3 | classes: null 4 | input_type: voxels 5 | path: data/ShapeNet 6 | dim: 3 7 | points_subsample: 1024 8 | model: 9 | encoder: voxel_simple_local 10 | encoder_kwargs: 11 | grid_resolution: 32 12 | plane_type: 'grid' 13 | unet3d: True 14 | unet3d_kwargs: 15 | num_levels: 3 16 | f_maps: 32 17 | in_channels: 32 18 | out_channels: 32 19 | decoder: simple_local 20 | decoder_kwargs: 21 | sample_mode: bilinear # bilinear / nearest 22 | hidden_size: 32 23 | c_dim: 32 24 | training: 25 | out_dir: out/voxels/shapenet_grid32 26 | batch_size: 64 
27 | model_selection_metric: iou 28 | model_selection_mode: maximize 29 | print_every: 100 30 | visualize_every: 10000 31 | validate_every: 10000 32 | checkpoint_every: 2000 33 | backup_every: 10000 34 | n_workers: 8 35 | n_workers_val: 4 36 | test: 37 | threshold: 0.2 38 | eval_mesh: true 39 | eval_pointcloud: false 40 | generation: 41 | vis_n_outputs: 2 42 | refine: false 43 | n_x: 128 44 | n_z: 1 45 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: conv_onet 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - cython=0.29.2 8 | - imageio=2.4.1 9 | - numpy=1.15.4 10 | - numpy-base=1.15.4 11 | - matplotlib=3.0.3 12 | - matplotlib-base=3.0.3 13 | - pandas=0.23.4 14 | - pillow=5.3.0 15 | - pyembree=0.1.4 16 | - pytest=4.0.2 17 | - python=3.6.7 18 | - pytorch=1.4.0 19 | - pyyaml=3.13 20 | - scikit-image=0.14.1 21 | - scipy=1.1.0 22 | - tensorboardx=1.4 23 | - torchvision=0.2.1 24 | - tqdm=4.28.1 25 | - trimesh=2.37.7 26 | - pip: 27 | - h5py==2.9.0 28 | - plyfile==0.7 29 | - torch_scatter==2.0.4 30 | 31 | -------------------------------------------------------------------------------- /eval_meshes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | import pandas as pd 5 | import trimesh 6 | import torch 7 | from src import config, data 8 | from src.eval import MeshEvaluator 9 | from src.utils.io import load_pointcloud 10 | 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Evaluate mesh algorithms.' 14 | ) 15 | parser.add_argument('config', type=str, help='Path to config file.') 16 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 17 | parser.add_argument('--eval_input', action='store_true', 18 | help='Evaluate inputs instead.') 19 | 20 | args = parser.parse_args() 21 | cfg = config.load_config(args.config, 'configs/default.yaml') 22 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 23 | device = torch.device("cuda" if is_cuda else "cpu") 24 | 25 | # Shorthands 26 | out_dir = cfg['training']['out_dir'] 27 | generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir']) 28 | if not args.eval_input: 29 | out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl') 30 | out_file_class = os.path.join(generation_dir, 'eval_meshes.csv') 31 | else: 32 | out_file = os.path.join(generation_dir, 'eval_input_full.pkl') 33 | out_file_class = os.path.join(generation_dir, 'eval_input.csv') 34 | 35 | # Dataset 36 | points_field = data.PointsField( 37 | cfg['data']['points_iou_file'], 38 | unpackbits=cfg['data']['points_unpackbits'], 39 | multi_files=cfg['data']['multi_files'] 40 | ) 41 | pointcloud_field = data.PointCloudField( 42 | cfg['data']['pointcloud_chamfer_file'], 43 | multi_files=cfg['data']['multi_files'] 44 | ) 45 | fields = { 46 | 'points_iou': points_field, 47 | 'pointcloud_chamfer': pointcloud_field, 48 | 'idx': data.IndexField(), 49 | } 50 | 51 | print('Test split: ', cfg['data']['test_split']) 52 | 53 | dataset_folder = cfg['data']['path'] 54 | dataset = data.Shapes3dDataset( 55 | dataset_folder, fields, 56 | cfg['data']['test_split'], 57 | categories=cfg['data']['classes'], 58 | cfg=cfg 59 | ) 60 | 61 | # Evaluator 62 | evaluator = MeshEvaluator(n_points=100000) 63 | 64 | # Loader 65 | test_loader = torch.utils.data.DataLoader( 66 | dataset, batch_size=1, 
num_workers=0, shuffle=False) 67 | 68 | # Evaluate all classes 69 | eval_dicts = [] 70 | print('Evaluating meshes...') 71 | for it, data in enumerate(tqdm(test_loader)): 72 | if data is None: 73 | print('Invalid data.') 74 | continue 75 | 76 | # Output folders 77 | if not args.eval_input: 78 | mesh_dir = os.path.join(generation_dir, 'meshes') 79 | pointcloud_dir = os.path.join(generation_dir, 'pointcloud') 80 | else: 81 | mesh_dir = os.path.join(generation_dir, 'input') 82 | pointcloud_dir = os.path.join(generation_dir, 'input') 83 | 84 | # Get index etc. 85 | idx = data['idx'].item() 86 | 87 | try: 88 | model_dict = dataset.get_model_dict(idx) 89 | except AttributeError: 90 | model_dict = {'model': str(idx), 'category': 'n/a'} 91 | 92 | modelname = model_dict['model'] 93 | category_id = model_dict['category'] 94 | 95 | try: 96 | category_name = dataset.metadata[category_id].get('name', 'n/a') 97 | # for room dataset 98 | if category_name == 'n/a': 99 | category_name = category_id 100 | except AttributeError: 101 | category_name = 'n/a' 102 | 103 | if category_id != 'n/a': 104 | mesh_dir = os.path.join(mesh_dir, category_id) 105 | pointcloud_dir = os.path.join(pointcloud_dir, category_id) 106 | 107 | # Evaluate 108 | pointcloud_tgt = data['pointcloud_chamfer'].squeeze(0).numpy() 109 | normals_tgt = data['pointcloud_chamfer.normals'].squeeze(0).numpy() 110 | points_tgt = data['points_iou'].squeeze(0).numpy() 111 | occ_tgt = data['points_iou.occ'].squeeze(0).numpy() 112 | 113 | # Evaluating mesh and pointcloud 114 | # Start row and put basic informatin inside 115 | eval_dict = { 116 | 'idx': idx, 117 | 'class id': category_id, 118 | 'class name': category_name, 119 | 'modelname': modelname, 120 | } 121 | eval_dicts.append(eval_dict) 122 | 123 | # Evaluate mesh 124 | if cfg['test']['eval_mesh']: 125 | mesh_file = os.path.join(mesh_dir, '%s.off' % modelname) 126 | 127 | if os.path.exists(mesh_file): 128 | try: 129 | mesh = trimesh.load(mesh_file, process=False) 130 | eval_dict_mesh = evaluator.eval_mesh( 131 | mesh, pointcloud_tgt, normals_tgt, points_tgt, occ_tgt, remove_wall=cfg['test']['remove_wall']) 132 | for k, v in eval_dict_mesh.items(): 133 | eval_dict[k + ' (mesh)'] = v 134 | except Exception as e: 135 | print("Error: Could not evaluate mesh: %s" % mesh_file) 136 | else: 137 | print('Warning: mesh does not exist: %s' % mesh_file) 138 | 139 | # Evaluate point cloud 140 | if cfg['test']['eval_pointcloud']: 141 | pointcloud_file = os.path.join( 142 | pointcloud_dir, '%s.ply' % modelname) 143 | 144 | if os.path.exists(pointcloud_file): 145 | pointcloud = load_pointcloud(pointcloud_file) 146 | eval_dict_pcl = evaluator.eval_pointcloud( 147 | pointcloud, pointcloud_tgt) 148 | for k, v in eval_dict_pcl.items(): 149 | eval_dict[k + ' (pcl)'] = v 150 | else: 151 | print('Warning: pointcloud does not exist: %s' 152 | % pointcloud_file) 153 | 154 | # Create pandas dataframe and save 155 | eval_df = pd.DataFrame(eval_dicts) 156 | eval_df.set_index(['idx'], inplace=True) 157 | eval_df.to_pickle(out_file) 158 | 159 | # Create CSV file with main statistics 160 | eval_df_class = eval_df.groupby(by=['class name']).mean() 161 | eval_df_class.to_csv(out_file_class) 162 | 163 | # Print results 164 | eval_df_class.loc['mean'] = eval_df_class.mean() 165 | print(eval_df_class) 166 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import shutil 4 
| import argparse 5 | from tqdm import tqdm 6 | import time 7 | from collections import defaultdict 8 | import pandas as pd 9 | from src import config 10 | from src.checkpoints import CheckpointIO 11 | from src.utils.io import export_pointcloud 12 | from src.utils.visualize import visualize_data 13 | from src.utils.voxels import VoxelGrid 14 | 15 | 16 | parser = argparse.ArgumentParser( 17 | description='Extract meshes from occupancy process.' 18 | ) 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 21 | 22 | args = parser.parse_args() 23 | cfg = config.load_config(args.config, 'configs/default.yaml') 24 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 25 | device = torch.device("cuda" if is_cuda else "cpu") 26 | 27 | out_dir = cfg['training']['out_dir'] 28 | generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir']) 29 | out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl') 30 | out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl') 31 | 32 | input_type = cfg['data']['input_type'] 33 | vis_n_outputs = cfg['generation']['vis_n_outputs'] 34 | if vis_n_outputs is None: 35 | vis_n_outputs = -1 36 | 37 | # Dataset 38 | dataset = config.get_dataset('test', cfg, return_idx=True) 39 | 40 | # Model 41 | model = config.get_model(cfg, device=device, dataset=dataset) 42 | 43 | checkpoint_io = CheckpointIO(out_dir, model=model) 44 | checkpoint_io.load(cfg['test']['model_file']) 45 | 46 | # Generator 47 | generator = config.get_generator(model, cfg, device=device) 48 | 49 | # Determine what to generate 50 | generate_mesh = cfg['generation']['generate_mesh'] 51 | generate_pointcloud = cfg['generation']['generate_pointcloud'] 52 | 53 | if generate_mesh and not hasattr(generator, 'generate_mesh'): 54 | generate_mesh = False 55 | print('Warning: generator does not support mesh generation.') 56 | 57 | if generate_pointcloud and not hasattr(generator, 'generate_pointcloud'): 58 | generate_pointcloud = False 59 | print('Warning: generator does not support pointcloud generation.') 60 | 61 | 62 | # Loader 63 | test_loader = torch.utils.data.DataLoader( 64 | dataset, batch_size=1, num_workers=0, shuffle=False) 65 | 66 | # Statistics 67 | time_dicts = [] 68 | 69 | # Generate 70 | model.eval() 71 | 72 | # Count how many models already created 73 | model_counter = defaultdict(int) 74 | 75 | for it, data in enumerate(tqdm(test_loader)): 76 | # Output folders 77 | mesh_dir = os.path.join(generation_dir, 'meshes') 78 | pointcloud_dir = os.path.join(generation_dir, 'pointcloud') 79 | in_dir = os.path.join(generation_dir, 'input') 80 | generation_vis_dir = os.path.join(generation_dir, 'vis') 81 | 82 | # Get index etc. 
83 |     idx = data['idx'].item()
84 | 
85 |     try:
86 |         model_dict = dataset.get_model_dict(idx)
87 |     except AttributeError:
88 |         model_dict = {'model': str(idx), 'category': 'n/a'}
89 | 
90 |     modelname = model_dict['model']
91 |     category_id = model_dict.get('category', 'n/a')
92 | 
93 |     try:
94 |         category_name = dataset.metadata[category_id].get('name', 'n/a')
95 |     except AttributeError:
96 |         category_name = 'n/a'
97 | 
98 |     if category_id != 'n/a':
99 |         mesh_dir = os.path.join(mesh_dir, str(category_id))
100 |         pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
101 |         in_dir = os.path.join(in_dir, str(category_id))
102 | 
103 |         folder_name = str(category_id)
104 |         if category_name != 'n/a':
105 |             folder_name = str(folder_name) + '_' + category_name.split(',')[0]
106 | 
107 |         generation_vis_dir = os.path.join(generation_vis_dir, folder_name)
108 | 
109 |     # Create directories if necessary
110 |     if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir):
111 |         os.makedirs(generation_vis_dir)
112 | 
113 |     if generate_mesh and not os.path.exists(mesh_dir):
114 |         os.makedirs(mesh_dir)
115 | 
116 |     if generate_pointcloud and not os.path.exists(pointcloud_dir):
117 |         os.makedirs(pointcloud_dir)
118 | 
119 |     if not os.path.exists(in_dir):
120 |         os.makedirs(in_dir)
121 | 
122 |     # Timing dict
123 |     time_dict = {
124 |         'idx': idx,
125 |         'class id': category_id,
126 |         'class name': category_name,
127 |         'modelname': modelname,
128 |     }
129 |     time_dicts.append(time_dict)
130 | 
131 |     # Generate outputs
132 |     out_file_dict = {}
133 | 
134 |     # Also copy ground truth
135 |     if cfg['generation']['copy_groundtruth']:
136 |         modelpath = os.path.join(
137 |             dataset.dataset_folder, category_id, modelname,
138 |             cfg['data']['watertight_file'])
139 |         out_file_dict['gt'] = modelpath
140 | 
141 |     if generate_mesh:
142 |         t0 = time.time()
143 |         if cfg['generation']['sliding_window']:
144 |             if it == 0:
145 |                 print('Process scenes in a sliding-window manner')
146 |             out = generator.generate_mesh_sliding(data)
147 |         else:
148 |             out = generator.generate_mesh(data)
149 |         time_dict['mesh'] = time.time() - t0
150 | 
151 |         # Get statistics
152 |         try:
153 |             mesh, stats_dict = out
154 |         except TypeError:
155 |             mesh, stats_dict = out, {}
156 |         time_dict.update(stats_dict)
157 | 
158 |         # Write output
159 |         mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname)
160 |         mesh.export(mesh_out_file)
161 |         out_file_dict['mesh'] = mesh_out_file
162 | 
163 |     if generate_pointcloud:
164 |         t0 = time.time()
165 |         pointcloud = generator.generate_pointcloud(data)
166 |         time_dict['pcl'] = time.time() - t0
167 |         pointcloud_out_file = os.path.join(
168 |             pointcloud_dir, '%s.ply' % modelname)
169 |         export_pointcloud(pointcloud, pointcloud_out_file)
170 |         out_file_dict['pointcloud'] = pointcloud_out_file
171 | 
172 |     if cfg['generation']['copy_input']:
173 |         # Save inputs
174 |         if input_type == 'voxels':
175 |             inputs_path = os.path.join(in_dir, '%s.off' % modelname)
176 |             inputs = data['inputs'].squeeze(0).cpu()
177 |             voxel_mesh = VoxelGrid(inputs).to_mesh()
178 |             voxel_mesh.export(inputs_path)
179 |             out_file_dict['in'] = inputs_path
180 |         elif input_type == 'pointcloud_crop':
181 |             inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
182 |             inputs = data['inputs'].squeeze(0).cpu().numpy()
183 |             export_pointcloud(inputs, inputs_path, False)
184 |             out_file_dict['in'] = inputs_path
185 |         elif input_type in ('pointcloud', 'partial_pointcloud'):
186 |             inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
187 |             inputs = data['inputs'].squeeze(0).cpu().numpy()
188 |             export_pointcloud(inputs, inputs_path, False)
189 |             out_file_dict['in'] = inputs_path
190 | 
191 |     # Copy to visualization directory for first vis_n_output samples
192 |     c_it = model_counter[category_id]
193 |     if c_it < vis_n_outputs:
194 |         # Save output files
195 |         img_name = '%02d.off' % c_it
196 |         for k, filepath in out_file_dict.items():
197 |             ext = os.path.splitext(filepath)[1]
198 |             out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
199 |                                     % (c_it, k, ext))
200 |             shutil.copyfile(filepath, out_file)
201 | 
202 |     model_counter[category_id] += 1
203 | 
204 | # Create pandas dataframe and save
205 | time_df = pd.DataFrame(time_dicts)
206 | time_df.set_index(['idx'], inplace=True)
207 | time_df.to_pickle(out_time_file)
208 | 
209 | # Create pickle files with main statistics
210 | time_df_class = time_df.groupby(by=['class name']).mean()
211 | time_df_class.to_pickle(out_time_file_class)
212 | 
213 | # Print results
214 | time_df_class.loc['mean'] = time_df_class.mean()
215 | print('Timings [s]:')
216 | print(time_df_class)
217 | 
--------------------------------------------------------------------------------
/media/demo_syn_room.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/convolutional_occupancy_networks/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/media/demo_syn_room.gif
--------------------------------------------------------------------------------
/media/teaser_matterport.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/convolutional_occupancy_networks/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/media/teaser_matterport.gif
--------------------------------------------------------------------------------
/scripts/dataset_matterport/build_dataset.py:
--------------------------------------------------------------------------------
1 | from os import listdir, makedirs, getcwd
2 | from os.path import join, exists, isdir
3 | import json
4 | import trimesh
5 | import numpy as np
6 | from copy import deepcopy
7 | import shutil
8 | import zipfile
9 | from tqdm import tqdm
10 | from src.utils.io import export_pointcloud
11 | 
12 | def create_dir(dir_in):
13 |     if not exists(dir_in):
14 |         makedirs(dir_in)
15 | 
16 | base_path = 'data/Matterport3D/v1/scans'
17 | scene_name = 'JmbYfDe2QKZ'
18 | out_path = 'data/Matterport3D_processed'
19 | scene_path = join(base_path, scene_name, 'region_segmentations')
20 | regions = [join(scene_path, 'region'+str(m)+'.ply')
21 |            for m in range(100) if exists(join(scene_path, 'region'+str(m)+'.ply'))]
22 | outfile = join(out_path, scene_name)
23 | create_dir(outfile)
24 | 
25 | n_pointcloud_points = 500000
26 | dtype = np.float16
27 | cut_mesh = True
28 | save_part_mesh = False
29 | 
30 | mat_permute = np.array([
31 |     [1, 0, 0, 0],
32 |     [0, 0, 1, 0],
33 |     [0, 1, 0, 0],
34 |     [0, 0, 0, 1]])
35 | for idx, r_path in tqdm(enumerate(regions)):
36 |     mesh = trimesh.load(r_path)
37 |     z_max = max(mesh.vertices[:, 2])
38 |     z_range = max(mesh.vertices[:, 2]) - min(mesh.vertices[:, 2])
39 |     x_min = min(mesh.vertices[:, 0])
40 |     y_min = min(mesh.vertices[:, 1])
41 |     # For better visualization, cut the ceilings and parts of walls
42 |     if cut_mesh:
43 |         mesh = trimesh.intersections.slice_mesh_plane(mesh, np.array([0, 0, -1]), np.array([0, 0, z_max - 0.5*z_range]))
44 |         # mesh = trimesh.intersections.slice_mesh_plane(mesh, np.array([0, 1, 0]), np.array([0, y_min
+ 0.5, 0])) 45 | mesh = trimesh.intersections.slice_mesh_plane(mesh, np.array([1, 0, 0]), np.array([x_min + 0.2, 0, 0])) 46 | mesh = deepcopy(mesh) 47 | mesh.apply_transform(mat_permute) 48 | if save_part_mesh == True: 49 | out_file = join(outfile, 'mesh_fused%d.ply'%idx) 50 | mesh.export(out_file) 51 | 52 | if idx == 0: 53 | faces = mesh.faces 54 | vertices = mesh.vertices 55 | else: 56 | faces = np.concatenate([faces, mesh.faces + vertices.shape[0]]) 57 | vertices = np.concatenate([vertices, mesh.vertices]) 58 | 59 | mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) 60 | out_file = join(outfile, 'mesh_fused.ply') 61 | mesh.export(out_file) 62 | 63 | # Sample surface points 64 | pcl, face_idx = mesh.sample(n_pointcloud_points, return_index=True) 65 | normals = mesh.face_normals[face_idx] 66 | 67 | # save surface points 68 | out_file = join(outfile, 'pointcloud.npz') 69 | np.savez(out_file, points=pcl.astype(dtype), normals=normals.astype(dtype)) 70 | export_pointcloud(pcl, join(outfile, 'pointcloud.ply')) 71 | 72 | # create test.lst file 73 | with open(join(out_path, 'test.lst'), "w") as file: 74 | file.write(scene_name) -------------------------------------------------------------------------------- /scripts/dataset_scannet/SensorData.py: -------------------------------------------------------------------------------- 1 | # Original code from ScanNet data exporter: https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/SensorData.py 2 | # Adapted for Python 3: https://github.com/daveredrum/ScanNet/edit/master/SensReader/python/SensorData.py 3 | import os, struct 4 | import numpy as np 5 | import zlib 6 | import imageio 7 | import cv2 8 | from tqdm import tqdm 9 | import torch 10 | import logging 11 | logging.basicConfig(filename='scannet_generation.log',level=logging.DEBUG) 12 | 13 | 14 | COMPRESSION_TYPE_COLOR = {-1:'unknown', 0:'raw', 1:'png', 2:'jpeg'} 15 | COMPRESSION_TYPE_DEPTH = {-1:'unknown', 0:'raw_ushort', 1:'zlib_ushort', 2:'occi_ushort'} 16 | 17 | class RGBDFrame(): 18 | 19 | def load(self, file_handle): 20 | self.camera_to_world = np.asarray(struct.unpack('f'*16, file_handle.read(16*4)), dtype=np.float32).reshape(4, 4) 21 | self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0] 22 | self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0] 23 | self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0] 24 | self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0] 25 | self.color_data = b''.join(struct.unpack('c'*self.color_size_bytes, file_handle.read(self.color_size_bytes))) 26 | self.depth_data = b''.join(struct.unpack('c'*self.depth_size_bytes, file_handle.read(self.depth_size_bytes))) 27 | 28 | 29 | def decompress_depth(self, compression_type): 30 | if compression_type == 'zlib_ushort': 31 | return self.decompress_depth_zlib() 32 | else: 33 | raise ValueError("invalid type") 34 | 35 | 36 | def decompress_depth_zlib(self): 37 | return zlib.decompress(self.depth_data) 38 | 39 | 40 | def decompress_color(self, compression_type): 41 | if compression_type == 'jpeg': 42 | return self.decompress_color_jpeg() 43 | else: 44 | raise ValueError("invalid type") 45 | 46 | 47 | def decompress_color_jpeg(self): 48 | return imageio.imread(self.color_data) 49 | 50 | 51 | class SensorData: 52 | 53 | def __init__(self, filename): 54 | self.version = 4 55 | self.load(filename) 56 | 57 | 58 | def load(self, filename): 59 | with open(filename, 'rb') as f: 60 | version = struct.unpack('I', f.read(4))[0] 61 | assert self.version 
== version 62 | strlen = struct.unpack('Q', f.read(8))[0] 63 | self.sensor_name = b''.join(struct.unpack('c'*strlen, f.read(strlen))) 64 | self.intrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 65 | self.extrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 66 | self.intrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 67 | self.extrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 68 | self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack('i', f.read(4))[0]] 69 | self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack('i', f.read(4))[0]] 70 | self.color_width = struct.unpack('I', f.read(4))[0] 71 | self.color_height = struct.unpack('I', f.read(4))[0] 72 | self.depth_width = struct.unpack('I', f.read(4))[0] 73 | self.depth_height = struct.unpack('I', f.read(4))[0] 74 | self.depth_shift = struct.unpack('f', f.read(4))[0] 75 | num_frames = struct.unpack('Q', f.read(8))[0] 76 | print('Number of frames: %d' % num_frames) 77 | self.frames = [] 78 | print('loading', filename) 79 | for i in tqdm(range(num_frames)): 80 | frame = RGBDFrame() 81 | frame.load(f) 82 | self.frames.append(frame) 83 | 84 | def extract_depth_images(self, image_size=None, frame_skip=1): 85 | depth_list, cam_pose_list = [], [] 86 | cam_intr = self.intrinsic_depth[:3, :3] 87 | print('extracting', len(self.frames)//frame_skip, ' depth maps') 88 | for f in tqdm(range(0, len(self.frames), frame_skip)): 89 | depth_data = self.frames[f].decompress_depth(self.depth_compression_type) 90 | 91 | depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) 92 | depth = depth.astype(float)/1000. 
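            # Raw .sens depth values are uint16 and scaled by depth_shift (typically 1000 for
            # ScanNet, i.e. millimeters); dividing by 1000 converts them to meters before appending.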
93 | depth_list.append(depth) 94 | 95 | cam_pose_list.append(self.frames[f].camera_to_world) 96 | return depth_list, cam_pose_list, cam_intr 97 | 98 | def export_depth_images(self, output_path, image_size=None, frame_skip=1): 99 | if not os.path.exists(output_path): 100 | os.makedirs(output_path) 101 | print('exporting', len(self.frames)//frame_skip, ' depth frames to', output_path) 102 | for f in tqdm(range(0, len(self.frames), frame_skip)): 103 | depth_data = self.frames[f].decompress_depth(self.depth_compression_type) 104 | depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) 105 | if image_size is not None: 106 | depth = cv2.resize(depth, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) 107 | # imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth) 108 | imageio.imwrite(os.path.join(output_path, '%06d.png' % f), depth) 109 | 110 | def export_color_images(self, output_path, image_size=None, frame_skip=1): 111 | if not os.path.exists(output_path): 112 | os.makedirs(output_path) 113 | print('exporting', len(self.frames)//frame_skip, 'color frames to', output_path) 114 | for f in range(0, len(self.frames), frame_skip): 115 | color = self.frames[f].decompress_color(self.color_compression_type) 116 | if image_size is not None: 117 | color = cv2.resize(color, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) 118 | imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color) 119 | 120 | 121 | def save_mat_to_file(self, matrix, filename): 122 | with open(filename, 'w') as f: 123 | for line in matrix: 124 | np.savetxt(f, line[np.newaxis], fmt='%f') 125 | 126 | 127 | def export_poses(self, output_path, frame_skip=1): 128 | if not os.path.exists(output_path): 129 | os.makedirs(output_path) 130 | print('exporting', len(self.frames)//frame_skip, 'camera poses to', output_path) 131 | for f in range(0, len(self.frames), frame_skip): 132 | self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, str(f) + '.txt')) 133 | 134 | 135 | def export_intrinsics(self, output_path): 136 | if not os.path.exists(output_path): 137 | os.makedirs(output_path) 138 | print('exporting camera intrinsics to', output_path) 139 | self.save_mat_to_file(self.intrinsic_color, os.path.join(output_path, 'intrinsic_color.txt')) 140 | self.save_mat_to_file(self.extrinsic_color, os.path.join(output_path, 'extrinsic_color.txt')) 141 | self.save_mat_to_file(self.intrinsic_depth, os.path.join(output_path, 'intrinsic_depth.txt')) 142 | self.save_mat_to_file(self.extrinsic_depth, os.path.join(output_path, 'extrinsic_depth.txt')) 143 | 144 | 145 | def get_scale_mat(self, mesh, padding=0.0): 146 | bbox = mesh.bounds 147 | loc = (bbox[0] + bbox[1]) / 2 148 | scale = (1 - 2 * padding) / (bbox[1] - bbox[0]).max() 149 | S = np.eye(4) * scale 150 | S[:-1, -1] = -scale*loc 151 | S[-1, -1] = 1 152 | S_inv = np.linalg.inv(S) 153 | return S, S_inv 154 | 155 | 156 | def process_camera_dict(self, output_path, scale_matrix, resolution=(480, 640)): 157 | h, w = resolution 158 | 159 | #_, scale_matrix = self.get_scale_mat(mesh) 160 | 161 | out_dict = {} 162 | instrinsic_mat = self.intrinsic_depth 163 | # scale pixels to [-1, 1] 164 | scale_mat = np.array([ 165 | [2. 
/ (w-1), 0, -1, 0],
166 |             [0, 2./(h-1), -1, 0],
167 |             [0, 0, 1, 0],
168 |             [0, 0, 0, 1],
169 |         ])
170 |         camera_mat = scale_mat @ instrinsic_mat
171 |         mask_camera = []
172 |         for f in range(0, len(self.frames), 1):
173 |             out_dict['camera_mat_%d' % f] = camera_mat.astype(np.float32)
174 | 
175 |             world_mat_inv = self.frames[f].camera_to_world
176 |             if np.any(np.isnan(world_mat_inv)) or np.any(np.isinf(world_mat_inv)):
177 |                 logging.warning('inf world mat for %s: %d' % (output_path, f))
178 |                 print('invalid world matrix!')
179 |                 mask_camera.append(f)
180 | 
181 |             try:
182 |                 world_mat = np.linalg.inv(world_mat_inv)
183 |             except np.linalg.LinAlgError:
184 |                 world_mat = np.linalg.pinv(world_mat_inv)
185 | 
186 |             out_dict['world_mat_%d' % f] = world_mat.astype(np.float32)
187 |             out_dict['scale_mat_%d' % f] = scale_matrix.astype(np.float32)
188 |         out_dict['camera_mask'] = mask_camera
189 |         out_file = os.path.join(output_path, 'cameras.npz')
190 |         np.savez(out_file, **out_dict)
191 | 
--------------------------------------------------------------------------------
/scripts/dataset_scannet/build_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import argparse
5 | import trimesh
6 | from SensorData import SensorData
7 | import tqdm
8 | from os.path import join
9 | from os import listdir
10 | import numpy as np
11 | import multiprocessing
12 | 
13 | path_in = '/dir/to/scannet_v2'
14 | path_out = '/dir/to/scannet_out'
15 | 
16 | if not os.path.exists(path_out):
17 |     os.makedirs(path_out)
18 | 
19 | path_out = join(path_out, 'scenes')
20 | if not os.path.exists(path_out):
21 |     os.makedirs(path_out)
22 | 
23 | 
24 | def align_axis(file_name, mesh):
25 |     rotation_matrix = np.array([
26 |         [1., 0., 0., 0.],
27 |         [0., 0., 1., 0.],
28 |         [0., 1., 0., 0.],
29 |         [0., 0., 0., 1.],
30 |     ])
31 |     lines = open(file_name).readlines()
32 |     for line in lines:
33 |         if 'axisAlignment' in line:
34 |             axis_align_matrix = [float(x) for x in line.rstrip().strip('axisAlignment = ').split(' ')]
35 |             break
36 |     axis_align_matrix = np.array(axis_align_matrix).reshape((4,4))
37 |     axis_align_matrix = rotation_matrix @ axis_align_matrix
38 |     mesh.apply_transform(axis_align_matrix)
39 |     return mesh, axis_align_matrix
40 | 
41 | def sample_points(mesh, n_points=100000, p_type=np.float16):
42 |     pcl, idx = mesh.sample(n_points, return_index=True)
43 |     normals = mesh.face_normals[idx]
44 |     out_dict = {
45 |         'points': pcl.astype(p_type),
46 |         'normals': normals.astype(p_type),
47 | 
48 |     }
49 |     return out_dict
50 | 
51 | def scale_to_unit_cube(mesh, y_level=-0.5):
52 |     bbox = mesh.bounds
53 |     loc = (bbox[0] + bbox[1]) / 2
54 |     scale = 1. 
/ (bbox[1] - bbox[0]).max() 55 | vertices_t = (mesh.vertices - loc.reshape(-1, 3)) * scale 56 | y_min = min(vertices_t[:, 1]) 57 | 58 | # create_transform_matrix 59 | S_loc = np.eye(4) 60 | S_loc[:-1, -1] = -loc 61 | # create scale mat 62 | S_scale = np.eye(4) * scale 63 | S_scale[-1, -1] = 1 64 | # create last translate matrix 65 | S_loc2 = np.eye(4) 66 | S_loc2[1, -1] = -y_min + y_level 67 | 68 | S = S_loc2 @ S_scale @ S_loc 69 | mesh.apply_transform(S) 70 | 71 | return mesh, S 72 | 73 | 74 | def process(scene_name): 75 | out_path_cur = os.path.join(path_out, scene_name) 76 | if not os.path.exists(out_path_cur): 77 | os.makedirs(out_path_cur) 78 | 79 | # load mesh 80 | mesh = trimesh.load(os.path.join(path_in, scene_name, scene_name+'_vh_clean.ply'), process=False) 81 | txt_file = os.path.join(path_in, scene_name, '%s.txt' % scene_name) 82 | mesh, align_mat = align_axis(txt_file, mesh) 83 | mesh, scale_mat = scale_to_unit_cube(mesh) 84 | scale_matrix = np.linalg.inv(scale_mat @ align_mat) 85 | 86 | file_cur = os.path.join(path_in, scene_name, scene_name+'.sens') 87 | sd = SensorData(file_cur) 88 | sd.export_depth_images(os.path.join(path_out, scene_name, 'depth'), frame_skip=1) 89 | sd.process_camera_dict(join(path_out, scene_name), scale_matrix) 90 | pcl = sample_points(mesh) 91 | out_file = join(path_out, scene_name, 'pointcloud.npz') 92 | np.savez(out_file, **pcl) 93 | 94 | file_list = listdir(path_in) 95 | file_list.sort() 96 | pbar = tqdm.tqdm() 97 | pool = multiprocessing.Pool(processes=8) 98 | for f in file_list: 99 | pool.apply_async(process, args=(f,), callback=lambda _: pbar.update()) 100 | pool.close() 101 | pool.join() 102 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p data 3 | cd data 4 | echo "Start downloading ..." 5 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/data/synthetic_room_dataset.zip 6 | unzip synthetic_room_dataset.zip 7 | echo "Done!" -------------------------------------------------------------------------------- /scripts/download_demo_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p data 3 | cd data 4 | echo "Downloading demo data..." 5 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/data/demo_data.zip 6 | unzip demo_data.zip 7 | echo "Done!" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup 3 | except ImportError: 4 | from distutils.core import setup 5 | from distutils.extension import Extension 6 | from Cython.Build import cythonize 7 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 8 | import numpy 9 | 10 | 11 | # Get the numpy include directory. 
12 | numpy_include_dir = numpy.get_include() 13 | 14 | # Extensions 15 | # pykdtree (kd tree) 16 | pykdtree = Extension( 17 | 'src.utils.libkdtree.pykdtree.kdtree', 18 | sources=[ 19 | 'src/utils/libkdtree/pykdtree/kdtree.c', 20 | 'src/utils/libkdtree/pykdtree/_kdtree_core.c' 21 | ], 22 | language='c', 23 | extra_compile_args=['-std=c99', '-O3', '-fopenmp'], 24 | extra_link_args=['-lgomp'], 25 | include_dirs=[numpy_include_dir] 26 | ) 27 | 28 | # mcubes (marching cubes algorithm) 29 | mcubes_module = Extension( 30 | 'src.utils.libmcubes.mcubes', 31 | sources=[ 32 | 'src/utils/libmcubes/mcubes.pyx', 33 | 'src/utils/libmcubes/pywrapper.cpp', 34 | 'src/utils/libmcubes/marchingcubes.cpp' 35 | ], 36 | language='c++', 37 | extra_compile_args=['-std=c++11'], 38 | include_dirs=[numpy_include_dir] 39 | ) 40 | 41 | # triangle hash (efficient mesh intersection) 42 | triangle_hash_module = Extension( 43 | 'src.utils.libmesh.triangle_hash', 44 | sources=[ 45 | 'src/utils/libmesh/triangle_hash.pyx' 46 | ], 47 | libraries=['m'], # Unix-like specific 48 | include_dirs=[numpy_include_dir] 49 | ) 50 | 51 | # mise (efficient mesh extraction) 52 | mise_module = Extension( 53 | 'src.utils.libmise.mise', 54 | sources=[ 55 | 'src/utils/libmise/mise.pyx' 56 | ], 57 | ) 58 | 59 | # simplify (efficient mesh simplification) 60 | simplify_mesh_module = Extension( 61 | 'src.utils.libsimplify.simplify_mesh', 62 | sources=[ 63 | 'src/utils/libsimplify/simplify_mesh.pyx' 64 | ], 65 | include_dirs=[numpy_include_dir] 66 | ) 67 | 68 | # voxelization (efficient mesh voxelization) 69 | voxelize_module = Extension( 70 | 'src.utils.libvoxelize.voxelize', 71 | sources=[ 72 | 'src/utils/libvoxelize/voxelize.pyx' 73 | ], 74 | libraries=['m'] # Unix-like specific 75 | ) 76 | 77 | # Gather all extension modules 78 | ext_modules = [ 79 | pykdtree, 80 | mcubes_module, 81 | triangle_hash_module, 82 | mise_module, 83 | simplify_mesh_module, 84 | voxelize_module, 85 | ] 86 | 87 | setup( 88 | ext_modules=cythonize(ext_modules), 89 | cmdclass={ 90 | 'build_ext': BuildExtension 91 | } 92 | ) 93 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/convolutional_occupancy_networks/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/__init__.py -------------------------------------------------------------------------------- /src/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import torch 4 | from torch.utils import model_zoo 5 | 6 | 7 | class CheckpointIO(object): 8 | ''' CheckpointIO class. 9 | 10 | It handles saving and loading checkpoints. 11 | 12 | Args: 13 | checkpoint_dir (str): path where checkpoints are saved 14 | ''' 15 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 16 | self.module_dict = kwargs 17 | self.checkpoint_dir = checkpoint_dir 18 | if not os.path.exists(checkpoint_dir): 19 | os.makedirs(checkpoint_dir) 20 | 21 | def register_modules(self, **kwargs): 22 | ''' Registers modules in current module dictionary. 23 | ''' 24 | self.module_dict.update(kwargs) 25 | 26 | def save(self, filename, **kwargs): 27 | ''' Saves the current module dictionary. 
28 | 29 | Args: 30 | filename (str): name of output file 31 | ''' 32 | if not os.path.isabs(filename): 33 | filename = os.path.join(self.checkpoint_dir, filename) 34 | 35 | outdict = kwargs 36 | for k, v in self.module_dict.items(): 37 | outdict[k] = v.state_dict() 38 | torch.save(outdict, filename) 39 | 40 | def load(self, filename): 41 | '''Loads a module dictionary from local file or url. 42 | 43 | Args: 44 | filename (str): name of saved module dictionary 45 | ''' 46 | if is_url(filename): 47 | return self.load_url(filename) 48 | else: 49 | return self.load_file(filename) 50 | 51 | def load_file(self, filename): 52 | '''Loads a module dictionary from file. 53 | 54 | Args: 55 | filename (str): name of saved module dictionary 56 | ''' 57 | 58 | if not os.path.isabs(filename): 59 | filename = os.path.join(self.checkpoint_dir, filename) 60 | 61 | if os.path.exists(filename): 62 | print(filename) 63 | print('=> Loading checkpoint from local file...') 64 | state_dict = torch.load(filename) 65 | scalars = self.parse_state_dict(state_dict) 66 | return scalars 67 | else: 68 | raise FileExistsError 69 | 70 | def load_url(self, url): 71 | '''Load a module dictionary from url. 72 | 73 | Args: 74 | url (str): url to saved model 75 | ''' 76 | print(url) 77 | print('=> Loading checkpoint from url...') 78 | state_dict = model_zoo.load_url(url, progress=True) 79 | scalars = self.parse_state_dict(state_dict) 80 | return scalars 81 | 82 | def parse_state_dict(self, state_dict): 83 | '''Parse state_dict of model and return scalars. 84 | 85 | Args: 86 | state_dict (dict): State dict of model 87 | ''' 88 | 89 | for k, v in self.module_dict.items(): 90 | if k in state_dict: 91 | v.load_state_dict(state_dict[k]) 92 | else: 93 | print('Warning: Could not find %s in checkpoint!' % k) 94 | scalars = {k: v for k, v in state_dict.items() 95 | if k not in self.module_dict} 96 | return scalars 97 | 98 | def is_url(url): 99 | scheme = urllib.parse.urlparse(url).scheme 100 | return scheme in ('http', 'https') -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from torchvision import transforms 3 | from src import data 4 | from src import conv_onet 5 | 6 | 7 | method_dict = { 8 | 'conv_onet': conv_onet 9 | } 10 | 11 | 12 | # General config 13 | def load_config(path, default_path=None): 14 | ''' Loads config file. 15 | 16 | Args: 17 | path (str): path to config file 18 | default_path (bool): whether to use default path 19 | ''' 20 | # Load configuration from file itself 21 | with open(path, 'r') as f: 22 | cfg_special = yaml.load(f) 23 | 24 | # Check if we should inherit from a config 25 | inherit_from = cfg_special.get('inherit_from') 26 | 27 | # If yes, load this config first as default 28 | # If no, use the default_path 29 | if inherit_from is not None: 30 | cfg = load_config(inherit_from, default_path) 31 | elif default_path is not None: 32 | with open(default_path, 'r') as f: 33 | cfg = yaml.load(f) 34 | else: 35 | cfg = dict() 36 | 37 | # Include main configuration 38 | update_recursive(cfg, cfg_special) 39 | 40 | return cfg 41 | 42 | 43 | def update_recursive(dict1, dict2): 44 | ''' Update two config dictionaries recursively. 
45 | 46 | Args: 47 | dict1 (dict): first dictionary to be updated 48 | dict2 (dict): second dictionary which entries should be used 49 | 50 | ''' 51 | for k, v in dict2.items(): 52 | if k not in dict1: 53 | dict1[k] = dict() 54 | if isinstance(v, dict): 55 | update_recursive(dict1[k], v) 56 | else: 57 | dict1[k] = v 58 | 59 | 60 | # Models 61 | def get_model(cfg, device=None, dataset=None): 62 | ''' Returns the model instance. 63 | 64 | Args: 65 | cfg (dict): config dictionary 66 | device (device): pytorch device 67 | dataset (dataset): dataset 68 | ''' 69 | method = cfg['method'] 70 | model = method_dict[method].config.get_model( 71 | cfg, device=device, dataset=dataset) 72 | return model 73 | 74 | 75 | # Trainer 76 | def get_trainer(model, optimizer, cfg, device): 77 | ''' Returns a trainer instance. 78 | 79 | Args: 80 | model (nn.Module): the model which is used 81 | optimizer (optimizer): pytorch optimizer 82 | cfg (dict): config dictionary 83 | device (device): pytorch device 84 | ''' 85 | method = cfg['method'] 86 | trainer = method_dict[method].config.get_trainer( 87 | model, optimizer, cfg, device) 88 | return trainer 89 | 90 | 91 | # Generator for final mesh extraction 92 | def get_generator(model, cfg, device): 93 | ''' Returns a generator instance. 94 | 95 | Args: 96 | model (nn.Module): the model which is used 97 | cfg (dict): config dictionary 98 | device (device): pytorch device 99 | ''' 100 | method = cfg['method'] 101 | generator = method_dict[method].config.get_generator(model, cfg, device) 102 | return generator 103 | 104 | 105 | # Datasets 106 | def get_dataset(mode, cfg, return_idx=False): 107 | ''' Returns the dataset. 108 | 109 | Args: 110 | model (nn.Module): the model which is used 111 | cfg (dict): config dictionary 112 | return_idx (bool): whether to include an ID field 113 | ''' 114 | method = cfg['method'] 115 | dataset_type = cfg['data']['dataset'] 116 | dataset_folder = cfg['data']['path'] 117 | categories = cfg['data']['classes'] 118 | 119 | # Get split 120 | splits = { 121 | 'train': cfg['data']['train_split'], 122 | 'val': cfg['data']['val_split'], 123 | 'test': cfg['data']['test_split'], 124 | } 125 | 126 | split = splits[mode] 127 | 128 | # Create dataset 129 | if dataset_type == 'Shapes3D': 130 | # Dataset fields 131 | # Method specific fields (usually correspond to output) 132 | fields = method_dict[method].config.get_data_fields(mode, cfg) 133 | # Input fields 134 | inputs_field = get_inputs_field(mode, cfg) 135 | if inputs_field is not None: 136 | fields['inputs'] = inputs_field 137 | 138 | if return_idx: 139 | fields['idx'] = data.IndexField() 140 | 141 | dataset = data.Shapes3dDataset( 142 | dataset_folder, fields, 143 | split=split, 144 | categories=categories, 145 | cfg = cfg 146 | ) 147 | else: 148 | raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset']) 149 | 150 | return dataset 151 | 152 | 153 | def get_inputs_field(mode, cfg): 154 | ''' Returns the inputs fields. 
155 | 156 | Args: 157 | mode (str): the mode which is used 158 | cfg (dict): config dictionary 159 | ''' 160 | input_type = cfg['data']['input_type'] 161 | 162 | if input_type is None: 163 | inputs_field = None 164 | elif input_type == 'pointcloud': 165 | transform = transforms.Compose([ 166 | data.SubsamplePointcloud(cfg['data']['pointcloud_n']), 167 | data.PointcloudNoise(cfg['data']['pointcloud_noise']) 168 | ]) 169 | inputs_field = data.PointCloudField( 170 | cfg['data']['pointcloud_file'], transform, 171 | multi_files= cfg['data']['multi_files'] 172 | ) 173 | elif input_type == 'partial_pointcloud': 174 | transform = transforms.Compose([ 175 | data.SubsamplePointcloud(cfg['data']['pointcloud_n']), 176 | data.PointcloudNoise(cfg['data']['pointcloud_noise']) 177 | ]) 178 | inputs_field = data.PartialPointCloudField( 179 | cfg['data']['pointcloud_file'], transform, 180 | multi_files= cfg['data']['multi_files'] 181 | ) 182 | elif input_type == 'pointcloud_crop': 183 | transform = transforms.Compose([ 184 | data.SubsamplePointcloud(cfg['data']['pointcloud_n']), 185 | data.PointcloudNoise(cfg['data']['pointcloud_noise']) 186 | ]) 187 | 188 | inputs_field = data.PatchPointCloudField( 189 | cfg['data']['pointcloud_file'], 190 | transform, 191 | multi_files= cfg['data']['multi_files'], 192 | ) 193 | 194 | elif input_type == 'voxels': 195 | inputs_field = data.VoxelsField( 196 | cfg['data']['voxels_file'] 197 | ) 198 | elif input_type == 'idx': 199 | inputs_field = data.IndexField() 200 | else: 201 | raise ValueError( 202 | 'Invalid input type (%s)' % input_type) 203 | return inputs_field -------------------------------------------------------------------------------- /src/conv_onet/__init__.py: -------------------------------------------------------------------------------- 1 | from src.conv_onet import ( 2 | config, generation, training, models 3 | ) 4 | 5 | __all__ = [ 6 | config, generation, training, models 7 | ] 8 | -------------------------------------------------------------------------------- /src/conv_onet/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as dist 3 | from torch import nn 4 | import os 5 | from src.encoder import encoder_dict 6 | from src.conv_onet import models, training 7 | from src.conv_onet import generation 8 | from src import data 9 | from src import config 10 | from src.common import decide_total_volume_range, update_reso 11 | from torchvision import transforms 12 | import numpy as np 13 | 14 | 15 | def get_model(cfg, device=None, dataset=None, **kwargs): 16 | ''' Return the Occupancy Network model. 
17 | 18 | Args: 19 | cfg (dict): imported yaml config 20 | device (device): pytorch device 21 | dataset (dataset): dataset 22 | ''' 23 | decoder = cfg['model']['decoder'] 24 | encoder = cfg['model']['encoder'] 25 | dim = cfg['data']['dim'] 26 | c_dim = cfg['model']['c_dim'] 27 | decoder_kwargs = cfg['model']['decoder_kwargs'] 28 | encoder_kwargs = cfg['model']['encoder_kwargs'] 29 | padding = cfg['data']['padding'] 30 | 31 | # for pointcloud_crop 32 | try: 33 | encoder_kwargs['unit_size'] = cfg['data']['unit_size'] 34 | decoder_kwargs['unit_size'] = cfg['data']['unit_size'] 35 | except: 36 | pass 37 | # local positional encoding 38 | if 'local_coord' in cfg['model'].keys(): 39 | encoder_kwargs['local_coord'] = cfg['model']['local_coord'] 40 | decoder_kwargs['local_coord'] = cfg['model']['local_coord'] 41 | if 'pos_encoding' in cfg['model']: 42 | encoder_kwargs['pos_encoding'] = cfg['model']['pos_encoding'] 43 | decoder_kwargs['pos_encoding'] = cfg['model']['pos_encoding'] 44 | 45 | # update the feature volume/plane resolution 46 | if cfg['data']['input_type'] == 'pointcloud_crop': 47 | fea_type = cfg['model']['encoder_kwargs']['plane_type'] 48 | if (dataset.split == 'train') or (cfg['generation']['sliding_window']): 49 | recep_field = 2**(cfg['model']['encoder_kwargs']['unet3d_kwargs']['num_levels'] + 2) 50 | reso = cfg['data']['query_vol_size'] + recep_field - 1 51 | if 'grid' in fea_type: 52 | encoder_kwargs['grid_resolution'] = update_reso(reso, dataset.depth) 53 | if bool(set(fea_type) & set(['xz', 'xy', 'yz'])): 54 | encoder_kwargs['plane_resolution'] = update_reso(reso, dataset.depth) 55 | # if dataset.split == 'val': #TODO run validation in room level during training 56 | else: 57 | if 'grid' in fea_type: 58 | encoder_kwargs['grid_resolution'] = dataset.total_reso 59 | if bool(set(fea_type) & set(['xz', 'xy', 'yz'])): 60 | encoder_kwargs['plane_resolution'] = dataset.total_reso 61 | 62 | 63 | decoder = models.decoder_dict[decoder]( 64 | dim=dim, c_dim=c_dim, padding=padding, 65 | **decoder_kwargs 66 | ) 67 | 68 | if encoder == 'idx': 69 | encoder = nn.Embedding(len(dataset), c_dim) 70 | elif encoder is not None: 71 | encoder = encoder_dict[encoder]( 72 | dim=dim, c_dim=c_dim, padding=padding, 73 | **encoder_kwargs 74 | ) 75 | else: 76 | encoder = None 77 | 78 | model = models.ConvolutionalOccupancyNetwork( 79 | decoder, encoder, device=device 80 | ) 81 | 82 | return model 83 | 84 | 85 | def get_trainer(model, optimizer, cfg, device, **kwargs): 86 | ''' Returns the trainer object. 87 | 88 | Args: 89 | model (nn.Module): the Occupancy Network model 90 | optimizer (optimizer): pytorch optimizer object 91 | cfg (dict): imported yaml config 92 | device (device): pytorch device 93 | ''' 94 | threshold = cfg['test']['threshold'] 95 | out_dir = cfg['training']['out_dir'] 96 | vis_dir = os.path.join(out_dir, 'vis') 97 | input_type = cfg['data']['input_type'] 98 | 99 | trainer = training.Trainer( 100 | model, optimizer, 101 | device=device, input_type=input_type, 102 | vis_dir=vis_dir, threshold=threshold, 103 | eval_sample=cfg['training']['eval_sample'], 104 | ) 105 | 106 | return trainer 107 | 108 | 109 | def get_generator(model, cfg, device, **kwargs): 110 | ''' Returns the generator object. 
111 | 112 | Args: 113 | model (nn.Module): Occupancy Network model 114 | cfg (dict): imported yaml config 115 | device (device): pytorch device 116 | ''' 117 | 118 | if cfg['data']['input_type'] == 'pointcloud_crop': 119 | # calculate the volume boundary 120 | query_vol_metric = cfg['data']['padding'] + 1 121 | unit_size = cfg['data']['unit_size'] 122 | recep_field = 2**(cfg['model']['encoder_kwargs']['unet3d_kwargs']['num_levels'] + 2) 123 | if 'unet' in cfg['model']['encoder_kwargs']: 124 | depth = cfg['model']['encoder_kwargs']['unet_kwargs']['depth'] 125 | elif 'unet3d' in cfg['model']['encoder_kwargs']: 126 | depth = cfg['model']['encoder_kwargs']['unet3d_kwargs']['num_levels'] 127 | 128 | vol_info = decide_total_volume_range(query_vol_metric, recep_field, unit_size, depth) 129 | 130 | grid_reso = cfg['data']['query_vol_size'] + recep_field - 1 131 | grid_reso = update_reso(grid_reso, depth) 132 | query_vol_size = cfg['data']['query_vol_size'] * unit_size 133 | input_vol_size = grid_reso * unit_size 134 | # only for the sliding window case 135 | vol_bound = None 136 | if cfg['generation']['sliding_window']: 137 | vol_bound = {'query_crop_size': query_vol_size, 138 | 'input_crop_size': input_vol_size, 139 | 'fea_type': cfg['model']['encoder_kwargs']['plane_type'], 140 | 'reso': grid_reso} 141 | 142 | else: 143 | vol_bound = None 144 | vol_info = None 145 | 146 | generator = generation.Generator3D( 147 | model, 148 | device=device, 149 | threshold=cfg['test']['threshold'], 150 | resolution0=cfg['generation']['resolution_0'], 151 | upsampling_steps=cfg['generation']['upsampling_steps'], 152 | sample=cfg['generation']['use_sampling'], 153 | refinement_step=cfg['generation']['refinement_step'], 154 | simplify_nfaces=cfg['generation']['simplify_nfaces'], 155 | input_type = cfg['data']['input_type'], 156 | padding=cfg['data']['padding'], 157 | vol_info = vol_info, 158 | vol_bound = vol_bound, 159 | ) 160 | return generator 161 | 162 | 163 | def get_data_fields(mode, cfg): 164 | ''' Returns the data fields. 
165 | 166 | Args: 167 | mode (str): the mode which is used 168 | cfg (dict): imported yaml config 169 | ''' 170 | points_transform = data.SubsamplePoints(cfg['data']['points_subsample']) 171 | 172 | input_type = cfg['data']['input_type'] 173 | fields = {} 174 | if cfg['data']['points_file'] is not None: 175 | if input_type != 'pointcloud_crop': 176 | fields['points'] = data.PointsField( 177 | cfg['data']['points_file'], points_transform, 178 | unpackbits=cfg['data']['points_unpackbits'], 179 | multi_files=cfg['data']['multi_files'] 180 | ) 181 | else: 182 | fields['points'] = data.PatchPointsField( 183 | cfg['data']['points_file'], 184 | transform=points_transform, 185 | unpackbits=cfg['data']['points_unpackbits'], 186 | multi_files=cfg['data']['multi_files'] 187 | ) 188 | 189 | 190 | if mode in ('val', 'test'): 191 | points_iou_file = cfg['data']['points_iou_file'] 192 | voxels_file = cfg['data']['voxels_file'] 193 | if points_iou_file is not None: 194 | if input_type == 'pointcloud_crop': 195 | fields['points_iou'] = data.PatchPointsField( 196 | points_iou_file, 197 | unpackbits=cfg['data']['points_unpackbits'], 198 | multi_files=cfg['data']['multi_files'] 199 | ) 200 | else: 201 | fields['points_iou'] = data.PointsField( 202 | points_iou_file, 203 | unpackbits=cfg['data']['points_unpackbits'], 204 | multi_files=cfg['data']['multi_files'] 205 | ) 206 | if voxels_file is not None: 207 | fields['voxels'] = data.VoxelsField(voxels_file) 208 | 209 | return fields 210 | -------------------------------------------------------------------------------- /src/conv_onet/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import distributions as dist 4 | from src.conv_onet.models import decoder 5 | 6 | # Decoder dictionary 7 | decoder_dict = { 8 | 'simple_local': decoder.LocalDecoder, 9 | 'simple_local_crop': decoder.PatchLocalDecoder, 10 | 'simple_local_point': decoder.LocalPointDecoder 11 | } 12 | 13 | 14 | class ConvolutionalOccupancyNetwork(nn.Module): 15 | ''' Occupancy Network class. 16 | 17 | Args: 18 | decoder (nn.Module): decoder network 19 | encoder (nn.Module): encoder network 20 | device (device): torch device 21 | ''' 22 | 23 | def __init__(self, decoder, encoder=None, device=None): 24 | super().__init__() 25 | 26 | self.decoder = decoder.to(device) 27 | 28 | if encoder is not None: 29 | self.encoder = encoder.to(device) 30 | else: 31 | self.encoder = None 32 | 33 | self._device = device 34 | 35 | def forward(self, p, inputs, sample=True, **kwargs): 36 | ''' Performs a forward pass through the network. 37 | 38 | Args: 39 | p (tensor): sampled points 40 | inputs (tensor): conditioning input 41 | sample (bool): whether to sample for z 42 | ''' 43 | ############# 44 | if isinstance(p, dict): 45 | batch_size = p['p'].size(0) 46 | else: 47 | batch_size = p.size(0) 48 | c = self.encode_inputs(inputs) 49 | p_r = self.decode(p, c, **kwargs) 50 | return p_r 51 | 52 | def encode_inputs(self, inputs): 53 | ''' Encodes the input. 54 | 55 | Args: 56 | input (tensor): the input 57 | ''' 58 | 59 | if self.encoder is not None: 60 | c = self.encoder(inputs) 61 | else: 62 | # Return inputs? 63 | c = torch.empty(inputs.size(0), 0) 64 | 65 | return c 66 | 67 | def decode(self, p, c, **kwargs): 68 | ''' Returns occupancy probabilities for the sampled points. 
69 | 70 | Args: 71 | p (tensor): points 72 | c (tensor): latent conditioned code c 73 | ''' 74 | 75 | logits = self.decoder(p, c, **kwargs) 76 | p_r = dist.Bernoulli(logits=logits) 77 | return p_r 78 | 79 | def to(self, device): 80 | ''' Puts the model to the device. 81 | 82 | Args: 83 | device (device): pytorch device 84 | ''' 85 | model = super().to(device) 86 | model._device = device 87 | return model 88 | -------------------------------------------------------------------------------- /src/conv_onet/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import trange 3 | import torch 4 | from torch.nn import functional as F 5 | from torch import distributions as dist 6 | from src.common import ( 7 | compute_iou, make_3d_grid, add_key, 8 | ) 9 | from src.utils import visualize as vis 10 | from src.training import BaseTrainer 11 | 12 | class Trainer(BaseTrainer): 13 | ''' Trainer object for the Occupancy Network. 14 | 15 | Args: 16 | model (nn.Module): Occupancy Network model 17 | optimizer (optimizer): pytorch optimizer object 18 | device (device): pytorch device 19 | input_type (str): input type 20 | vis_dir (str): visualization directory 21 | threshold (float): threshold value 22 | eval_sample (bool): whether to evaluate samples 23 | 24 | ''' 25 | 26 | def __init__(self, model, optimizer, device=None, input_type='pointcloud', 27 | vis_dir=None, threshold=0.5, eval_sample=False): 28 | self.model = model 29 | self.optimizer = optimizer 30 | self.device = device 31 | self.input_type = input_type 32 | self.vis_dir = vis_dir 33 | self.threshold = threshold 34 | self.eval_sample = eval_sample 35 | 36 | if vis_dir is not None and not os.path.exists(vis_dir): 37 | os.makedirs(vis_dir) 38 | 39 | def train_step(self, data): 40 | ''' Performs a training step. 41 | 42 | Args: 43 | data (dict): data dictionary 44 | ''' 45 | self.model.train() 46 | self.optimizer.zero_grad() 47 | loss = self.compute_loss(data) 48 | loss.backward() 49 | self.optimizer.step() 50 | 51 | return loss.item() 52 | 53 | def eval_step(self, data): 54 | ''' Performs an evaluation step. 
55 | 56 | Args: 57 | data (dict): data dictionary 58 | ''' 59 | self.model.eval() 60 | 61 | device = self.device 62 | threshold = self.threshold 63 | eval_dict = {} 64 | 65 | points = data.get('points').to(device) 66 | occ = data.get('points.occ').to(device) 67 | 68 | inputs = data.get('inputs', torch.empty(points.size(0), 0)).to(device) 69 | voxels_occ = data.get('voxels') 70 | 71 | points_iou = data.get('points_iou').to(device) 72 | occ_iou = data.get('points_iou.occ').to(device) 73 | 74 | batch_size = points.size(0) 75 | 76 | kwargs = {} 77 | 78 | # add pre-computed index 79 | inputs = add_key(inputs, data.get('inputs.ind'), 'points', 'index', device=device) 80 | # add pre-computed normalized coordinates 81 | points = add_key(points, data.get('points.normalized'), 'p', 'p_n', device=device) 82 | points_iou = add_key(points_iou, data.get('points_iou.normalized'), 'p', 'p_n', device=device) 83 | 84 | # Compute iou 85 | with torch.no_grad(): 86 | p_out = self.model(points_iou, inputs, 87 | sample=self.eval_sample, **kwargs) 88 | 89 | occ_iou_np = (occ_iou >= 0.5).cpu().numpy() 90 | occ_iou_hat_np = (p_out.probs >= threshold).cpu().numpy() 91 | 92 | iou = compute_iou(occ_iou_np, occ_iou_hat_np).mean() 93 | eval_dict['iou'] = iou 94 | 95 | # Estimate voxel iou 96 | if voxels_occ is not None: 97 | voxels_occ = voxels_occ.to(device) 98 | points_voxels = make_3d_grid( 99 | (-0.5 + 1/64,) * 3, (0.5 - 1/64,) * 3, voxels_occ.shape[1:]) 100 | points_voxels = points_voxels.expand( 101 | batch_size, *points_voxels.size()) 102 | points_voxels = points_voxels.to(device) 103 | with torch.no_grad(): 104 | p_out = self.model(points_voxels, inputs, 105 | sample=self.eval_sample, **kwargs) 106 | 107 | voxels_occ_np = (voxels_occ >= 0.5).cpu().numpy() 108 | occ_hat_np = (p_out.probs >= threshold).cpu().numpy() 109 | iou_voxels = compute_iou(voxels_occ_np, occ_hat_np).mean() 110 | 111 | eval_dict['iou_voxels'] = iou_voxels 112 | 113 | return eval_dict 114 | 115 | def compute_loss(self, data): 116 | ''' Computes the loss. 
117 | 118 | Args: 119 | data (dict): data dictionary 120 | ''' 121 | device = self.device 122 | p = data.get('points').to(device) 123 | occ = data.get('points.occ').to(device) 124 | inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device) 125 | 126 | if 'pointcloud_crop' in data.keys(): 127 | # add pre-computed index 128 | inputs = add_key(inputs, data.get('inputs.ind'), 'points', 'index', device=device) 129 | inputs['mask'] = data.get('inputs.mask').to(device) 130 | # add pre-computed normalized coordinates 131 | p = add_key(p, data.get('points.normalized'), 'p', 'p_n', device=device) 132 | 133 | c = self.model.encode_inputs(inputs) 134 | 135 | kwargs = {} 136 | # General points 137 | logits = self.model.decode(p, c, **kwargs).logits 138 | loss_i = F.binary_cross_entropy_with_logits( 139 | logits, occ, reduction='none') 140 | loss = loss_i.sum(-1).mean() 141 | 142 | return loss 143 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from src.data.core import ( 3 | Shapes3dDataset, collate_remove_none, worker_init_fn 4 | ) 5 | from src.data.fields import ( 6 | IndexField, PointsField, 7 | VoxelsField, PatchPointsField, PointCloudField, PatchPointCloudField, PartialPointCloudField, 8 | ) 9 | from src.data.transforms import ( 10 | PointcloudNoise, SubsamplePointcloud, 11 | SubsamplePoints, 12 | ) 13 | __all__ = [ 14 | # Core 15 | Shapes3dDataset, 16 | collate_remove_none, 17 | worker_init_fn, 18 | # Fields 19 | IndexField, 20 | PointsField, 21 | VoxelsField, 22 | PointCloudField, 23 | PartialPointCloudField, 24 | PatchPointCloudField, 25 | PatchPointsField, 26 | # Transforms 27 | PointcloudNoise, 28 | SubsamplePointcloud, 29 | SubsamplePoints, 30 | ] 31 | -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # Transforms 5 | class PointcloudNoise(object): 6 | ''' Point cloud noise transformation class. 7 | 8 | It adds noise to point cloud data. 9 | 10 | Args: 11 | stddev (int): standard deviation 12 | ''' 13 | 14 | def __init__(self, stddev): 15 | self.stddev = stddev 16 | 17 | def __call__(self, data): 18 | ''' Calls the transformation. 19 | 20 | Args: 21 | data (dictionary): data dictionary 22 | ''' 23 | data_out = data.copy() 24 | points = data[None] 25 | noise = self.stddev * np.random.randn(*points.shape) 26 | noise = noise.astype(np.float32) 27 | data_out[None] = points + noise 28 | return data_out 29 | 30 | class SubsamplePointcloud(object): 31 | ''' Point cloud subsampling transformation class. 32 | 33 | It subsamples the point cloud data. 34 | 35 | Args: 36 | N (int): number of points to be subsampled 37 | ''' 38 | def __init__(self, N): 39 | self.N = N 40 | 41 | def __call__(self, data): 42 | ''' Calls the transformation. 43 | 44 | Args: 45 | data (dict): data dictionary 46 | ''' 47 | data_out = data.copy() 48 | points = data[None] 49 | normals = data['normals'] 50 | 51 | indices = np.random.randint(points.shape[0], size=self.N) 52 | data_out[None] = points[indices, :] 53 | data_out['normals'] = normals[indices, :] 54 | 55 | return data_out 56 | 57 | 58 | class SubsamplePoints(object): 59 | ''' Points subsampling transformation class. 60 | 61 | It subsamples the points data. 
62 | 63 | Args: 64 | N (int): number of points to be subsampled 65 | ''' 66 | def __init__(self, N): 67 | self.N = N 68 | 69 | def __call__(self, data): 70 | ''' Calls the transformation. 71 | 72 | Args: 73 | data (dictionary): data dictionary 74 | ''' 75 | points = data[None] 76 | occ = data['occ'] 77 | 78 | data_out = data.copy() 79 | if isinstance(self.N, int): 80 | idx = np.random.randint(points.shape[0], size=self.N) 81 | data_out.update({ 82 | None: points[idx, :], 83 | 'occ': occ[idx], 84 | }) 85 | else: 86 | Nt_out, Nt_in = self.N 87 | occ_binary = (occ >= 0.5) 88 | points0 = points[~occ_binary] 89 | points1 = points[occ_binary] 90 | 91 | idx0 = np.random.randint(points0.shape[0], size=Nt_out) 92 | idx1 = np.random.randint(points1.shape[0], size=Nt_in) 93 | 94 | points0 = points0[idx0, :] 95 | points1 = points1[idx1, :] 96 | points = np.concatenate([points0, points1], axis=0) 97 | 98 | occ0 = np.zeros(Nt_out, dtype=np.float32) 99 | occ1 = np.ones(Nt_in, dtype=np.float32) 100 | occ = np.concatenate([occ0, occ1], axis=0) 101 | 102 | volume = occ_binary.sum() / len(occ_binary) 103 | volume = volume.astype(np.float32) 104 | 105 | data_out.update({ 106 | None: points, 107 | 'occ': occ, 108 | 'volume': volume, 109 | }) 110 | return data_out 111 | -------------------------------------------------------------------------------- /src/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from src.encoder import ( 2 | pointnet, voxels, pointnetpp 3 | ) 4 | 5 | 6 | encoder_dict = { 7 | 'pointnet_local_pool': pointnet.LocalPoolPointnet, 8 | 'pointnet_crop_local_pool': pointnet.PatchLocalPoolPointnet, 9 | 'pointnet_plus_plus': pointnetpp.PointNetPlusPlus, 10 | 'voxel_simple_local': voxels.LocalVoxelEncoder, 11 | } 12 | -------------------------------------------------------------------------------- /src/encoder/unet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Codes are from: 3 | https://github.com/jaxony/unet-pytorch/blob/master/model.py 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from collections import OrderedDict 11 | from torch.nn import init 12 | import numpy as np 13 | 14 | def conv3x3(in_channels, out_channels, stride=1, 15 | padding=1, bias=True, groups=1): 16 | return nn.Conv2d( 17 | in_channels, 18 | out_channels, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=padding, 22 | bias=bias, 23 | groups=groups) 24 | 25 | def upconv2x2(in_channels, out_channels, mode='transpose'): 26 | if mode == 'transpose': 27 | return nn.ConvTranspose2d( 28 | in_channels, 29 | out_channels, 30 | kernel_size=2, 31 | stride=2) 32 | else: 33 | # out_channels is always going to be the same 34 | # as in_channels 35 | return nn.Sequential( 36 | nn.Upsample(mode='bilinear', scale_factor=2), 37 | conv1x1(in_channels, out_channels)) 38 | 39 | def conv1x1(in_channels, out_channels, groups=1): 40 | return nn.Conv2d( 41 | in_channels, 42 | out_channels, 43 | kernel_size=1, 44 | groups=groups, 45 | stride=1) 46 | 47 | 48 | class DownConv(nn.Module): 49 | """ 50 | A helper Module that performs 2 convolutions and 1 MaxPool. 51 | A ReLU activation follows each convolution. 
52 | """ 53 | def __init__(self, in_channels, out_channels, pooling=True): 54 | super(DownConv, self).__init__() 55 | 56 | self.in_channels = in_channels 57 | self.out_channels = out_channels 58 | self.pooling = pooling 59 | 60 | self.conv1 = conv3x3(self.in_channels, self.out_channels) 61 | self.conv2 = conv3x3(self.out_channels, self.out_channels) 62 | 63 | if self.pooling: 64 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 65 | 66 | def forward(self, x): 67 | x = F.relu(self.conv1(x)) 68 | x = F.relu(self.conv2(x)) 69 | before_pool = x 70 | if self.pooling: 71 | x = self.pool(x) 72 | return x, before_pool 73 | 74 | 75 | class UpConv(nn.Module): 76 | """ 77 | A helper Module that performs 2 convolutions and 1 UpConvolution. 78 | A ReLU activation follows each convolution. 79 | """ 80 | def __init__(self, in_channels, out_channels, 81 | merge_mode='concat', up_mode='transpose'): 82 | super(UpConv, self).__init__() 83 | 84 | self.in_channels = in_channels 85 | self.out_channels = out_channels 86 | self.merge_mode = merge_mode 87 | self.up_mode = up_mode 88 | 89 | self.upconv = upconv2x2(self.in_channels, self.out_channels, 90 | mode=self.up_mode) 91 | 92 | if self.merge_mode == 'concat': 93 | self.conv1 = conv3x3( 94 | 2*self.out_channels, self.out_channels) 95 | else: 96 | # num of input channels to conv2 is same 97 | self.conv1 = conv3x3(self.out_channels, self.out_channels) 98 | self.conv2 = conv3x3(self.out_channels, self.out_channels) 99 | 100 | 101 | def forward(self, from_down, from_up): 102 | """ Forward pass 103 | Arguments: 104 | from_down: tensor from the encoder pathway 105 | from_up: upconv'd tensor from the decoder pathway 106 | """ 107 | from_up = self.upconv(from_up) 108 | if self.merge_mode == 'concat': 109 | x = torch.cat((from_up, from_down), 1) 110 | else: 111 | x = from_up + from_down 112 | x = F.relu(self.conv1(x)) 113 | x = F.relu(self.conv2(x)) 114 | return x 115 | 116 | 117 | class UNet(nn.Module): 118 | """ `UNet` class is based on https://arxiv.org/abs/1505.04597 119 | 120 | The U-Net is a convolutional encoder-decoder neural network. 121 | Contextual spatial information (from the decoding, 122 | expansive pathway) about an input tensor is merged with 123 | information representing the localization of details 124 | (from the encoding, compressive pathway). 125 | 126 | Modifications to the original paper: 127 | (1) padding is used in 3x3 convolutions to prevent loss 128 | of border pixels 129 | (2) merging outputs does not require cropping due to (1) 130 | (3) residual connections can be used by specifying 131 | UNet(merge_mode='add') 132 | (4) if non-parametric upsampling is used in the decoder 133 | pathway (specified by upmode='upsample'), then an 134 | additional 1x1 2d convolution occurs after upsampling 135 | to reduce channel dimensionality by a factor of 2. 136 | This channel halving happens with the convolution in 137 | the tranpose convolution (specified by upmode='transpose') 138 | """ 139 | 140 | def __init__(self, num_classes, in_channels=3, depth=5, 141 | start_filts=64, up_mode='transpose', 142 | merge_mode='concat', **kwargs): 143 | """ 144 | Arguments: 145 | in_channels: int, number of channels in the input tensor. 146 | Default is 3 for RGB images. 147 | depth: int, number of MaxPools in the U-Net. 148 | start_filts: int, number of convolutional filters for the 149 | first conv. 150 | up_mode: string, type of upconvolution. Choices: 'transpose' 151 | for transpose convolution or 'upsample' for nearest neighbour 152 | upsampling. 
153 | """ 154 | super(UNet, self).__init__() 155 | 156 | if up_mode in ('transpose', 'upsample'): 157 | self.up_mode = up_mode 158 | else: 159 | raise ValueError("\"{}\" is not a valid mode for " 160 | "upsampling. Only \"transpose\" and " 161 | "\"upsample\" are allowed.".format(up_mode)) 162 | 163 | if merge_mode in ('concat', 'add'): 164 | self.merge_mode = merge_mode 165 | else: 166 | raise ValueError("\"{}\" is not a valid mode for" 167 | "merging up and down paths. " 168 | "Only \"concat\" and " 169 | "\"add\" are allowed.".format(up_mode)) 170 | 171 | # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' 172 | if self.up_mode == 'upsample' and self.merge_mode == 'add': 173 | raise ValueError("up_mode \"upsample\" is incompatible " 174 | "with merge_mode \"add\" at the moment " 175 | "because it doesn't make sense to use " 176 | "nearest neighbour to reduce " 177 | "depth channels (by half).") 178 | 179 | self.num_classes = num_classes 180 | self.in_channels = in_channels 181 | self.start_filts = start_filts 182 | self.depth = depth 183 | 184 | self.down_convs = [] 185 | self.up_convs = [] 186 | 187 | # create the encoder pathway and add to a list 188 | for i in range(depth): 189 | ins = self.in_channels if i == 0 else outs 190 | outs = self.start_filts*(2**i) 191 | pooling = True if i < depth-1 else False 192 | 193 | down_conv = DownConv(ins, outs, pooling=pooling) 194 | self.down_convs.append(down_conv) 195 | 196 | # create the decoder pathway and add to a list 197 | # - careful! decoding only requires depth-1 blocks 198 | for i in range(depth-1): 199 | ins = outs 200 | outs = ins // 2 201 | up_conv = UpConv(ins, outs, up_mode=up_mode, 202 | merge_mode=merge_mode) 203 | self.up_convs.append(up_conv) 204 | 205 | # add the list of modules to current module 206 | self.down_convs = nn.ModuleList(self.down_convs) 207 | self.up_convs = nn.ModuleList(self.up_convs) 208 | 209 | self.conv_final = conv1x1(outs, self.num_classes) 210 | 211 | self.reset_params() 212 | 213 | @staticmethod 214 | def weight_init(m): 215 | if isinstance(m, nn.Conv2d): 216 | init.xavier_normal_(m.weight) 217 | init.constant_(m.bias, 0) 218 | 219 | 220 | def reset_params(self): 221 | for i, m in enumerate(self.modules()): 222 | self.weight_init(m) 223 | 224 | 225 | def forward(self, x): 226 | encoder_outs = [] 227 | # encoder pathway, save outputs for merging 228 | for i, module in enumerate(self.down_convs): 229 | x, before_pool = module(x) 230 | encoder_outs.append(before_pool) 231 | for i, module in enumerate(self.up_convs): 232 | before_pool = encoder_outs[-(i+2)] 233 | x = module(before_pool, x) 234 | 235 | # No softmax is used. This means you need to use 236 | # nn.CrossEntropyLoss is your training script, 237 | # as this module includes a softmax already. 
238 | x = self.conv_final(x) 239 | return x 240 | 241 | if __name__ == "__main__": 242 | """ 243 | testing 244 | """ 245 | model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32) 246 | print(model) 247 | print(sum(p.numel() for p in model.parameters())) 248 | 249 | reso = 176 250 | x = np.zeros((1, 1, reso, reso)) 251 | x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan 252 | x = torch.FloatTensor(x) 253 | 254 | out = model(x) 255 | print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso))) 256 | 257 | # loss = torch.sum(out) 258 | # loss.backward() 259 | -------------------------------------------------------------------------------- /src/encoder/voxels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_scatter import scatter_mean 5 | from src.encoder.unet import UNet 6 | from src.encoder.unet3d import UNet3D 7 | from src.common import coordinate2index, normalize_coordinate, normalize_3d_coordinate 8 | 9 | 10 | class LocalVoxelEncoder(nn.Module): 11 | ''' 3D-convolutional encoder network for voxel input. 12 | 13 | Args: 14 | dim (int): input dimension 15 | c_dim (int): dimension of latent code c 16 | hidden_dim (int): hidden dimension of the network 17 | unet (bool): weather to use U-Net 18 | unet_kwargs (str): U-Net parameters 19 | unet3d (bool): weather to use 3D U-Net 20 | unet3d_kwargs (str): 3D U-Net parameters 21 | plane_resolution (int): defined resolution for plane feature 22 | grid_resolution (int): defined resolution for grid feature 23 | plane_type (str): 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume 24 | kernel_size (int): kernel size for the first layer of CNN 25 | padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 26 | 27 | ''' 28 | 29 | def __init__(self, dim=3, c_dim=128, unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, 30 | plane_resolution=512, grid_resolution=None, plane_type='xz', kernel_size=3, padding=0.1): 31 | super().__init__() 32 | self.actvn = F.relu 33 | if kernel_size == 1: 34 | self.conv_in = nn.Conv3d(1, c_dim, 1) 35 | else: 36 | self.conv_in = nn.Conv3d(1, c_dim, kernel_size, padding=1) 37 | 38 | if unet: 39 | self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) 40 | else: 41 | self.unet = None 42 | 43 | if unet3d: 44 | self.unet3d = UNet3D(**unet3d_kwargs) 45 | else: 46 | self.unet3d = None 47 | 48 | self.c_dim = c_dim 49 | 50 | self.reso_plane = plane_resolution 51 | self.reso_grid = grid_resolution 52 | 53 | self.plane_type = plane_type 54 | self.padding = padding 55 | 56 | def generate_plane_features(self, p, c, plane='xz'): 57 | # acquire indices of features in plane 58 | xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) 59 | index = coordinate2index(xy, self.reso_plane) 60 | 61 | # scatter plane features from points 62 | fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2) 63 | c = c.permute(0, 2, 1) 64 | fea_plane = scatter_mean(c, index, out=fea_plane) 65 | fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) 66 | 67 | # process the plane features with UNet 68 | if self.unet is not None: 69 | fea_plane = self.unet(fea_plane) 70 | 71 | return fea_plane 72 | 73 | def generate_grid_features(self, p, c): 74 | p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding) 75 | index = coordinate2index(p_nor, self.reso_grid, 
coord_type='3d') 76 | # scatter grid features from points 77 | fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3) 78 | c = c.permute(0, 2, 1) 79 | fea_grid = scatter_mean(c, index, out=fea_grid) 80 | fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) 81 | 82 | if self.unet3d is not None: 83 | fea_grid = self.unet3d(fea_grid) 84 | 85 | return fea_grid 86 | 87 | 88 | def forward(self, x): 89 | batch_size = x.size(0) 90 | device = x.device 91 | n_voxel = x.size(1) * x.size(2) * x.size(3) 92 | 93 | # voxel 3D coordintates 94 | coord1 = torch.linspace(-0.5, 0.5, x.size(1)).to(device) 95 | coord2 = torch.linspace(-0.5, 0.5, x.size(2)).to(device) 96 | coord3 = torch.linspace(-0.5, 0.5, x.size(3)).to(device) 97 | 98 | coord1 = coord1.view(1, -1, 1, 1).expand_as(x) 99 | coord2 = coord2.view(1, 1, -1, 1).expand_as(x) 100 | coord3 = coord3.view(1, 1, 1, -1).expand_as(x) 101 | p = torch.stack([coord1, coord2, coord3], dim=4) 102 | p = p.view(batch_size, n_voxel, -1) 103 | 104 | # Acquire voxel-wise feature 105 | x = x.unsqueeze(1) 106 | c = self.actvn(self.conv_in(x)).view(batch_size, self.c_dim, -1) 107 | c = c.permute(0, 2, 1) 108 | 109 | fea = {} 110 | if 'grid' in self.plane_type: 111 | fea['grid'] = self.generate_grid_features(p, c) 112 | else: 113 | if 'xz' in self.plane_type: 114 | fea['xz'] = self.generate_plane_features(p, c, plane='xz') 115 | if 'xy' in self.plane_type: 116 | fea['xy'] = self.generate_plane_features(p, c, plane='xy') 117 | if 'yz' in self.plane_type: 118 | fea['yz'] = self.generate_plane_features(p, c, plane='yz') 119 | return fea 120 | 121 | class VoxelEncoder(nn.Module): 122 | ''' 3D-convolutional encoder network for voxel input. 123 | 124 | Args: 125 | dim (int): input dimension 126 | c_dim (int): output dimension 127 | ''' 128 | 129 | def __init__(self, dim=3, c_dim=128): 130 | super().__init__() 131 | self.actvn = F.relu 132 | 133 | self.conv_in = nn.Conv3d(1, 32, 3, padding=1) 134 | 135 | self.conv_0 = nn.Conv3d(32, 64, 3, padding=1, stride=2) 136 | self.conv_1 = nn.Conv3d(64, 128, 3, padding=1, stride=2) 137 | self.conv_2 = nn.Conv3d(128, 256, 3, padding=1, stride=2) 138 | self.conv_3 = nn.Conv3d(256, 512, 3, padding=1, stride=2) 139 | self.fc = nn.Linear(512 * 2 * 2 * 2, c_dim) 140 | 141 | def forward(self, x): 142 | batch_size = x.size(0) 143 | 144 | x = x.unsqueeze(1) 145 | net = self.conv_in(x) 146 | net = self.conv_0(self.actvn(net)) 147 | net = self.conv_1(self.actvn(net)) 148 | net = self.conv_2(self.actvn(net)) 149 | net = self.conv_3(self.actvn(net)) 150 | 151 | hidden = net.view(batch_size, 512 * 2 * 2 * 2) 152 | c = self.fc(self.actvn(hidden)) 153 | 154 | return c -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import trimesh 4 | # from scipy.spatial import cKDTree 5 | from src.utils.libkdtree import KDTree 6 | from src.utils.libmesh import check_mesh_contains 7 | from src.common import compute_iou 8 | 9 | # Maximum values for bounding box [-0.5, 0.5]^3 10 | EMPTY_PCL_DICT = { 11 | 'completeness': np.sqrt(3), 12 | 'accuracy': np.sqrt(3), 13 | 'completeness2': 3, 14 | 'accuracy2': 3, 15 | 'chamfer': 6, 16 | } 17 | 18 | EMPTY_PCL_DICT_NORMALS = { 19 | 'normals completeness': -1., 20 | 'normals accuracy': -1., 21 | 'normals': -1., 22 | } 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class MeshEvaluator(object): 
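    # Typical use (hypothetical variable names):
    #     evaluator = MeshEvaluator(n_points=100000)
    #     metrics = evaluator.eval_mesh(mesh, pointcloud_tgt, normals_tgt, points_iou, occ_tgt)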
28 | ''' Mesh evaluation class. 29 | 30 | It handles the mesh evaluation process. 31 | 32 | Args: 33 | n_points (int): number of points to be used for evaluation 34 | ''' 35 | 36 | def __init__(self, n_points=100000): 37 | self.n_points = n_points 38 | 39 | def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, 40 | points_iou, occ_tgt, remove_wall=False): 41 | ''' Evaluates a mesh. 42 | 43 | Args: 44 | mesh (trimesh): mesh which should be evaluated 45 | pointcloud_tgt (numpy array): target point cloud 46 | normals_tgt (numpy array): target normals 47 | points_iou (numpy_array): points tensor for IoU evaluation 48 | occ_tgt (numpy_array): GT occupancy values for IoU points 49 | ''' 50 | if len(mesh.vertices) != 0 and len(mesh.faces) != 0: 51 | if remove_wall: #! Remove walls and floors 52 | pointcloud, idx = mesh.sample(2*self.n_points, return_index=True) 53 | eps = 0.007 54 | x_max, x_min = pointcloud_tgt[:, 0].max(), pointcloud_tgt[:, 0].min() 55 | y_max, y_min = pointcloud_tgt[:, 1].max(), pointcloud_tgt[:, 1].min() 56 | z_max, z_min = pointcloud_tgt[:, 2].max(), pointcloud_tgt[:, 2].min() 57 | 58 | # add small offsets 59 | x_max, x_min = x_max + eps, x_min - eps 60 | y_max, y_min = y_max + eps, y_min - eps 61 | z_max, z_min = z_max + eps, z_min - eps 62 | 63 | mask_x = (pointcloud[:, 0] <= x_max) & (pointcloud[:, 0] >= x_min) 64 | mask_y = (pointcloud[:, 1] >= y_min) # floor 65 | mask_z = (pointcloud[:, 2] <= z_max) & (pointcloud[:, 2] >= z_min) 66 | 67 | mask = mask_x & mask_y & mask_z 68 | pointcloud_new = pointcloud[mask] 69 | # Subsample 70 | idx_new = np.random.randint(pointcloud_new.shape[0], size=self.n_points) 71 | pointcloud = pointcloud_new[idx_new] 72 | idx = idx[mask][idx_new] 73 | else: 74 | pointcloud, idx = mesh.sample(self.n_points, return_index=True) 75 | 76 | pointcloud = pointcloud.astype(np.float32) 77 | normals = mesh.face_normals[idx] 78 | else: 79 | pointcloud = np.empty((0, 3)) 80 | normals = np.empty((0, 3)) 81 | 82 | out_dict = self.eval_pointcloud( 83 | pointcloud, pointcloud_tgt, normals, normals_tgt) 84 | 85 | if len(mesh.vertices) != 0 and len(mesh.faces) != 0: 86 | occ = check_mesh_contains(mesh, points_iou) 87 | out_dict['iou'] = compute_iou(occ, occ_tgt) 88 | else: 89 | out_dict['iou'] = 0. 90 | 91 | return out_dict 92 | 93 | def eval_pointcloud(self, pointcloud, pointcloud_tgt, 94 | normals=None, normals_tgt=None, 95 | thresholds=np.linspace(1./1000, 1, 1000)): 96 | ''' Evaluates a point cloud. 
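        Completeness and accuracy are nearest-neighbour distances measured from the
        target to the prediction and from the prediction to the target, respectively;
        Chamfer-L1/-L2 are their means, and the F-score is the harmonic mean of
        precision and recall at each distance threshold.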
97 | 
98 |         Args:
99 |             pointcloud (numpy array): predicted point cloud
100 |             pointcloud_tgt (numpy array): target point cloud
101 |             normals (numpy array): predicted normals
102 |             normals_tgt (numpy array): target normals
103 |             thresholds (numpy array): threshold values for the F-score calculation
104 |         '''
105 |         # Return maximum losses if pointcloud is empty
106 |         if pointcloud.shape[0] == 0:
107 |             logger.warning('Empty pointcloud / mesh detected!')
108 |             out_dict = EMPTY_PCL_DICT.copy()
109 |             if normals is not None and normals_tgt is not None:
110 |                 out_dict.update(EMPTY_PCL_DICT_NORMALS)
111 |             return out_dict
112 | 
113 |         pointcloud = np.asarray(pointcloud)
114 |         pointcloud_tgt = np.asarray(pointcloud_tgt)
115 | 
116 |         # Completeness: how far are the points of the target point cloud
117 |         # from the predicted point cloud
118 |         completeness, completeness_normals = distance_p2p(
119 |             pointcloud_tgt, normals_tgt, pointcloud, normals
120 |         )
121 |         recall = get_threshold_percentage(completeness, thresholds)
122 |         completeness2 = completeness**2
123 | 
124 |         completeness = completeness.mean()
125 |         completeness2 = completeness2.mean()
126 |         completeness_normals = completeness_normals.mean()
127 | 
128 |         # Accuracy: how far are the points of the predicted pointcloud
129 |         # from the target pointcloud
130 |         accuracy, accuracy_normals = distance_p2p(
131 |             pointcloud, normals, pointcloud_tgt, normals_tgt
132 |         )
133 |         precision = get_threshold_percentage(accuracy, thresholds)
134 |         accuracy2 = accuracy**2
135 | 
136 |         accuracy = accuracy.mean()
137 |         accuracy2 = accuracy2.mean()
138 |         accuracy_normals = accuracy_normals.mean()
139 | 
140 |         # Chamfer distance
141 |         chamferL2 = 0.5 * (completeness2 + accuracy2)
142 |         normals_correctness = (
143 |             0.5 * completeness_normals + 0.5 * accuracy_normals
144 |         )
145 |         chamferL1 = 0.5 * (completeness + accuracy)
146 | 
147 |         # F-Score
148 |         F = [
149 |             2 * precision[i] * recall[i] / (precision[i] + recall[i])
150 |             for i in range(len(precision))
151 |         ]
152 | 
153 |         out_dict = {
154 |             'completeness': completeness,
155 |             'accuracy': accuracy,
156 |             'normals completeness': completeness_normals,
157 |             'normals accuracy': accuracy_normals,
158 |             'normals': normals_correctness,
159 |             'completeness2': completeness2,
160 |             'accuracy2': accuracy2,
161 |             'chamfer-L2': chamferL2,
162 |             'chamfer-L1': chamferL1,
163 |             'f-score': F[9], # threshold = 1.0%
164 |             'f-score-15': F[14], # threshold = 1.5%
165 |             'f-score-20': F[19], # threshold = 2.0%
166 |         }
167 | 
168 |         return out_dict
169 | 
170 | 
171 | def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
172 |     ''' Computes minimal distances of each point in points_src to points_tgt.
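    Returns the per-point nearest-neighbour distances together with the absolute
    dot products of the corresponding unit normals (NaN where normals are missing).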
173 | 174 | Args: 175 | points_src (numpy array): source points 176 | normals_src (numpy array): source normals 177 | points_tgt (numpy array): target points 178 | normals_tgt (numpy array): target normals 179 | ''' 180 | kdtree = KDTree(points_tgt) 181 | dist, idx = kdtree.query(points_src) 182 | 183 | if normals_src is not None and normals_tgt is not None: 184 | normals_src = \ 185 | normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) 186 | normals_tgt = \ 187 | normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True) 188 | 189 | normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1) 190 | # Handle normals that point into wrong direction gracefully 191 | # (mostly due to mehtod not caring about this in generation) 192 | normals_dot_product = np.abs(normals_dot_product) 193 | else: 194 | normals_dot_product = np.array( 195 | [np.nan] * points_src.shape[0], dtype=np.float32) 196 | return dist, normals_dot_product 197 | 198 | 199 | def distance_p2m(points, mesh): 200 | ''' Compute minimal distances of each point in points to mesh. 201 | 202 | Args: 203 | points (numpy array): points array 204 | mesh (trimesh): mesh 205 | 206 | ''' 207 | _, dist, _ = trimesh.proximity.closest_point(mesh, points) 208 | return dist 209 | 210 | def get_threshold_percentage(dist, thresholds): 211 | ''' Evaluates a point cloud. 212 | 213 | Args: 214 | dist (numpy array): calculated distance 215 | thresholds (numpy array): threshold values for the F-score calculation 216 | ''' 217 | in_threshold = [ 218 | (dist <= t).mean() for t in thresholds 219 | ] 220 | return in_threshold 221 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Resnet Blocks 6 | class ResnetBlockFC(nn.Module): 7 | ''' Fully connected ResNet Block class. 8 | 9 | Args: 10 | size_in (int): input dimension 11 | size_out (int): output dimension 12 | size_h (int): hidden dimension 13 | ''' 14 | 15 | def __init__(self, size_in, size_out=None, size_h=None): 16 | super().__init__() 17 | # Attributes 18 | if size_out is None: 19 | size_out = size_in 20 | 21 | if size_h is None: 22 | size_h = min(size_in, size_out) 23 | 24 | self.size_in = size_in 25 | self.size_h = size_h 26 | self.size_out = size_out 27 | # Submodules 28 | self.fc_0 = nn.Linear(size_in, size_h) 29 | self.fc_1 = nn.Linear(size_h, size_out) 30 | self.actvn = nn.ReLU() 31 | 32 | if size_in == size_out: 33 | self.shortcut = None 34 | else: 35 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 36 | # Initialization 37 | nn.init.zeros_(self.fc_1.weight) 38 | 39 | def forward(self, x): 40 | net = self.fc_0(self.actvn(x)) 41 | dx = self.fc_1(self.actvn(net)) 42 | 43 | if self.shortcut is not None: 44 | x_s = self.shortcut(x) 45 | else: 46 | x_s = x 47 | 48 | return x_s + dx -------------------------------------------------------------------------------- /src/training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | 5 | 6 | class BaseTrainer(object): 7 | ''' Base trainer class. 8 | ''' 9 | 10 | def evaluate(self, val_loader): 11 | ''' Performs an evaluation. 
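        Runs eval_step on every batch of the loader and returns the mean of each
        reported metric.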
12 | Args: 13 | val_loader (dataloader): pytorch dataloader 14 | ''' 15 | eval_list = defaultdict(list) 16 | 17 | for data in tqdm(val_loader): 18 | eval_step_dict = self.eval_step(data) 19 | 20 | for k, v in eval_step_dict.items(): 21 | eval_list[k].append(v) 22 | 23 | eval_dict = {k: np.mean(v) for k, v in eval_list.items()} 24 | return eval_dict 25 | 26 | def train_step(self, *args, **kwargs): 27 | ''' Performs a training step. 28 | ''' 29 | raise NotImplementedError 30 | 31 | def eval_step(self, *args, **kwargs): 32 | ''' Performs an evaluation step. 33 | ''' 34 | raise NotImplementedError 35 | 36 | def visualize(self, *args, **kwargs): 37 | ''' Performs visualization. 38 | ''' 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/convolutional_occupancy_networks/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/icp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.neighbors import NearestNeighbors 3 | 4 | 5 | def best_fit_transform(A, B): 6 | ''' 7 | Calculates the least-squares best-fit transform that maps corresponding 8 | points A to B in m spatial dimensions 9 | Input: 10 | A: Nxm numpy array of corresponding points 11 | B: Nxm numpy array of corresponding points 12 | Returns: 13 | T: (m+1)x(m+1) homogeneous transformation matrix that maps A on to B 14 | R: mxm rotation matrix 15 | t: mx1 translation vector 16 | ''' 17 | 18 | assert A.shape == B.shape 19 | 20 | # get number of dimensions 21 | m = A.shape[1] 22 | 23 | # translate points to their centroids 24 | centroid_A = np.mean(A, axis=0) 25 | centroid_B = np.mean(B, axis=0) 26 | AA = A - centroid_A 27 | BB = B - centroid_B 28 | 29 | # rotation matrix 30 | H = np.dot(AA.T, BB) 31 | U, S, Vt = np.linalg.svd(H) 32 | R = np.dot(Vt.T, U.T) 33 | 34 | # special reflection case 35 | if np.linalg.det(R) < 0: 36 | Vt[m-1,:] *= -1 37 | R = np.dot(Vt.T, U.T) 38 | 39 | # translation 40 | t = centroid_B.T - np.dot(R,centroid_A.T) 41 | 42 | # homogeneous transformation 43 | T = np.identity(m+1) 44 | T[:m, :m] = R 45 | T[:m, m] = t 46 | 47 | return T, R, t 48 | 49 | 50 | def nearest_neighbor(src, dst): 51 | ''' 52 | Find the nearest (Euclidean) neighbor in dst for each point in src 53 | Input: 54 | src: Nxm array of points 55 | dst: Nxm array of points 56 | Output: 57 | distances: Euclidean distances of the nearest neighbor 58 | indices: dst indices of the nearest neighbor 59 | ''' 60 | 61 | assert src.shape == dst.shape 62 | 63 | neigh = NearestNeighbors(n_neighbors=1) 64 | neigh.fit(dst) 65 | distances, indices = neigh.kneighbors(src, return_distance=True) 66 | return distances.ravel(), indices.ravel() 67 | 68 | 69 | def icp(A, B, init_pose=None, max_iterations=20, tolerance=0.001): 70 | ''' 71 | The Iterative Closest Point method: finds best-fit transform that maps 72 | points A on to points B 73 | Input: 74 | A: Nxm numpy array of source mD points 75 | B: Nxm numpy array of destination mD point 76 | init_pose: (m+1)x(m+1) homogeneous transformation 77 | max_iterations: exit algorithm after max_iterations 78 | tolerance: convergence criteria 79 | Output: 80 | T: final homogeneous transformation that maps A on to B 81 | distances: Euclidean distances 
(errors) of the nearest neighbor 82 | i: number of iterations to converge 83 | ''' 84 | 85 | assert A.shape == B.shape 86 | 87 | # get number of dimensions 88 | m = A.shape[1] 89 | 90 | # make points homogeneous, copy them to maintain the originals 91 | src = np.ones((m+1,A.shape[0])) 92 | dst = np.ones((m+1,B.shape[0])) 93 | src[:m,:] = np.copy(A.T) 94 | dst[:m,:] = np.copy(B.T) 95 | 96 | # apply the initial pose estimation 97 | if init_pose is not None: 98 | src = np.dot(init_pose, src) 99 | 100 | prev_error = 0 101 | 102 | for i in range(max_iterations): 103 | # find the nearest neighbors between the current source and destination points 104 | distances, indices = nearest_neighbor(src[:m,:].T, dst[:m,:].T) 105 | 106 | # compute the transformation between the current source and nearest destination points 107 | T,_,_ = best_fit_transform(src[:m,:].T, dst[:m,indices].T) 108 | 109 | # update the current source 110 | src = np.dot(T, src) 111 | 112 | # check error 113 | mean_error = np.mean(distances) 114 | if np.abs(prev_error - mean_error) < tolerance: 115 | break 116 | prev_error = mean_error 117 | 118 | # calculate final transformation 119 | T,_,_ = best_fit_transform(A, src[:m,:].T) 120 | 121 | return T, distances, i 122 | -------------------------------------------------------------------------------- /src/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from plyfile import PlyElement, PlyData 3 | import numpy as np 4 | 5 | 6 | def export_pointcloud(vertices, out_file, as_text=True): 7 | assert(vertices.shape[1] == 3) 8 | vertices = vertices.astype(np.float32) 9 | vertices = np.ascontiguousarray(vertices) 10 | vector_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 11 | vertices = vertices.view(dtype=vector_dtype).flatten() 12 | plyel = PlyElement.describe(vertices, 'vertex') 13 | plydata = PlyData([plyel], text=as_text) 14 | plydata.write(out_file) 15 | 16 | 17 | def load_pointcloud(in_file): 18 | plydata = PlyData.read(in_file) 19 | vertices = np.stack([ 20 | plydata['vertex']['x'], 21 | plydata['vertex']['y'], 22 | plydata['vertex']['z'] 23 | ], axis=1) 24 | return vertices 25 | 26 | 27 | def read_off(file): 28 | """ 29 | Reads vertices and faces from an off file. 30 | 31 | :param file: path to file to read 32 | :type file: str 33 | :return: vertices and faces as lists of tuples 34 | :rtype: [(float)], [(int)] 35 | """ 36 | 37 | assert os.path.exists(file), 'file %s not found' % file 38 | 39 | with open(file, 'r') as fp: 40 | lines = fp.readlines() 41 | lines = [line.strip() for line in lines] 42 | 43 | # Fix for ModelNet bug were 'OFF' and the number of vertices and faces 44 | # are all in the first line. 45 | if len(lines[0]) > 3: 46 | assert lines[0][:3] == 'OFF' or lines[0][:3] == 'off', \ 47 | 'invalid OFF file %s' % file 48 | 49 | parts = lines[0][3:].split(' ') 50 | assert len(parts) == 3 51 | 52 | num_vertices = int(parts[0]) 53 | assert num_vertices > 0 54 | 55 | num_faces = int(parts[1]) 56 | assert num_faces > 0 57 | 58 | start_index = 1 59 | # This is the regular case! 
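        # (illustrative) a well-formed file starts with the keyword on its own line,
        # e.g. "OFF" followed by "8 6 0", while the buggy ModelNet variant glues the
        # counts onto the keyword, e.g. "OFF8 6 0".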
60 | else: 61 | assert lines[0] == 'OFF' or lines[0] == 'off', \ 62 | 'invalid OFF file %s' % file 63 | 64 | parts = lines[1].split(' ') 65 | assert len(parts) == 3 66 | 67 | num_vertices = int(parts[0]) 68 | assert num_vertices > 0 69 | 70 | num_faces = int(parts[1]) 71 | assert num_faces > 0 72 | 73 | start_index = 2 74 | 75 | vertices = [] 76 | for i in range(num_vertices): 77 | vertex = lines[start_index + i].split(' ') 78 | vertex = [float(point.strip()) for point in vertex if point != ''] 79 | assert len(vertex) == 3 80 | 81 | vertices.append(vertex) 82 | 83 | faces = [] 84 | for i in range(num_faces): 85 | face = lines[start_index + num_vertices + i].split(' ') 86 | face = [index.strip() for index in face if index != ''] 87 | 88 | # check to be sure 89 | for index in face: 90 | assert index != '', \ 91 | 'found empty vertex index: %s (%s)' \ 92 | % (lines[start_index + num_vertices + i], file) 93 | 94 | face = [int(index) for index in face] 95 | 96 | assert face[0] == len(face) - 1, \ 97 | 'face should have %d vertices but as %d (%s)' \ 98 | % (face[0], len(face) - 1, file) 99 | assert face[0] == 3, \ 100 | 'only triangular meshes supported (%s)' % file 101 | for index in face: 102 | assert index >= 0 and index < num_vertices, \ 103 | 'vertex %d (of %d vertices) does not exist (%s)' \ 104 | % (index, num_vertices, file) 105 | 106 | assert len(face) > 1 107 | 108 | faces.append(face) 109 | 110 | return vertices, faces 111 | 112 | assert False, 'could not open %s' % file 113 | -------------------------------------------------------------------------------- /src/utils/libkdtree/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /src/utils/libkdtree/MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude pykdtree/render_template.py 2 | include LICENSE.txt 3 | -------------------------------------------------------------------------------- /src/utils/libkdtree/README: -------------------------------------------------------------------------------- 1 | README.rst -------------------------------------------------------------------------------- /src/utils/libkdtree/README.rst: -------------------------------------------------------------------------------- 1 | .. image:: https://travis-ci.org/storpipfugl/pykdtree.svg?branch=master 2 | :target: https://travis-ci.org/storpipfugl/pykdtree 3 | .. image:: https://ci.appveyor.com/api/projects/status/ubo92368ktt2d25g/branch/master 4 | :target: https://ci.appveyor.com/project/storpipfugl/pykdtree 5 | 6 | ======== 7 | pykdtree 8 | ======== 9 | 10 | Objective 11 | --------- 12 | pykdtree is a kd-tree implementation for fast nearest neighbour search in Python. 13 | The aim is to be the fastest implementation around for common use cases (low dimensions and low number of neighbours) for both tree construction and queries. 14 | 15 | The implementation is based on scipy.spatial.cKDTree and libANN by combining the best features from both and focus on implementation efficiency. 16 | 17 | The interface is similar to that of scipy.spatial.cKDTree except only Euclidean distance measure is supported. 18 | 19 | Queries are optionally multithreaded using OpenMP. 20 | 21 | Installation 22 | ------------ 23 | Default build of pykdtree with OpenMP enabled queries using libgomp 24 | 25 | .. 
code-block:: bash 26 | 27 | $ cd 28 | $ python setup.py install 29 | 30 | If it fails with undefined compiler flags or you want to use another OpenMP implementation please modify setup.py at the indicated point to match your system. 31 | 32 | Building without OpenMP support is controlled by the USE_OMP environment variable 33 | 34 | .. code-block:: bash 35 | 36 | $ cd 37 | $ export USE_OMP=0 38 | $ python setup.py install 39 | 40 | Note evironment variables are by default not exported when using sudo so in this case do 41 | 42 | .. code-block:: bash 43 | 44 | $ USE_OMP=0 sudo -E python setup.py install 45 | 46 | Usage 47 | ----- 48 | The usage of pykdtree is similar to scipy.spatial.cKDTree so for now refer to its documentation 49 | 50 | >>> from pykdtree.kdtree import KDTree 51 | >>> kd_tree = KDTree(data_pts) 52 | >>> dist, idx = kd_tree.query(query_pts, k=8) 53 | 54 | The number of threads to be used in OpenMP enabled queries can be controlled with the standard OpenMP environment variable OMP_NUM_THREADS. 55 | 56 | The **leafsize** argument (number of data points per leaf) for the tree creation can be used to control the memory overhead of the kd-tree. pykdtree uses a default **leafsize=16**. 57 | Increasing **leafsize** will reduce the memory overhead and construction time but increase query time. 58 | 59 | pykdtree accepts data in double precision (numpy.float64) or single precision (numpy.float32) floating point. If data of another type is used an internal copy in double precision is made resulting in a memory overhead. If the kd-tree is constructed on single precision data the query points must be single precision as well. 60 | 61 | Benchmarks 62 | ---------- 63 | Comparison with scipy.spatial.cKDTree and libANN. This benchmark is on geospatial 3D data with 10053632 data points and 4276224 query points. The results are indexed relative to the construction time of scipy.spatial.cKDTree. A leafsize of 10 (scipy.spatial.cKDTree default) is used. 64 | 65 | Note: libANN is *not* thread safe. In this benchmark libANN is compiled with "-O3 -funroll-loops -ffast-math -fprefetch-loop-arrays" in order to achieve optimum performance. 66 | 67 | ================== ===================== ====== ======== ================== 68 | Operation scipy.spatial.cKDTree libANN pykdtree pykdtree 4 threads 69 | ------------------ --------------------- ------ -------- ------------------ 70 | 71 | Construction 100 304 96 96 72 | 73 | query 1 neighbour 1267 294 223 70 74 | 75 | Total 1 neighbour 1367 598 319 166 76 | 77 | query 8 neighbours 2193 625 449 143 78 | 79 | Total 8 neighbours 2293 929 545 293 80 | ================== ===================== ====== ======== ================== 81 | 82 | Looking at the combined construction and query this gives the following performance improvement relative to scipy.spatial.cKDTree 83 | 84 | ========== ====== ======== ================== 85 | Neighbours libANN pykdtree pykdtree 4 threads 86 | ---------- ------ -------- ------------------ 87 | 1 129% 329% 723% 88 | 89 | 8 147% 320% 682% 90 | ========== ====== ======== ================== 91 | 92 | Note: mileage will vary with the dataset at hand and computer architecture. 93 | 94 | Test 95 | ---- 96 | Run the unit tests using nosetest 97 | 98 | .. 
code-block:: bash 99 | 100 | $ cd 101 | $ python setup.py nosetests 102 | 103 | Installing on AppVeyor 104 | ---------------------- 105 | 106 | Pykdtree requires the "stdint.h" header file which is not available on certain 107 | versions of Windows or certain Windows compilers including those on the 108 | continuous integration platform AppVeyor. To get around this the header file(s) 109 | can be downloaded and placed in the correct "include" directory. This can 110 | be done by adding the `anaconda/missing-headers.ps1` script to your repository 111 | and running it the install step of `appveyor.yml`: 112 | 113 | # install missing headers that aren't included with MSVC 2008 114 | # https://github.com/omnia-md/conda-recipes/pull/524 115 | - "powershell ./appveyor/missing-headers.ps1" 116 | 117 | In addition to this, AppVeyor does not support OpenMP so this feature must be 118 | turned off by adding the following to `appveyor.yml` in the 119 | `environment` section: 120 | 121 | environment: 122 | global: 123 | # Don't build with openmp because it isn't supported in appveyor's compilers 124 | USE_OMP: "0" 125 | 126 | Changelog 127 | --------- 128 | v1.3.1 : Fix masking in the "query" method introduced in 1.3.0 129 | 130 | v1.3.0 : Keyword argument "mask" added to "query" method. OpenMP compilation now works for MS Visual Studio compiler 131 | 132 | v1.2.2 : Build process fixes 133 | 134 | v1.2.1 : Fixed OpenMP thread safety issue introduced in v1.2.0 135 | 136 | v1.2.0 : 64 and 32 bit MSVC Windows support added 137 | 138 | v1.1.1 : Same as v1.1 release due to incorrect pypi release 139 | 140 | v1.1 : Build process improvements. Add data attribute to kdtree class for scipy interface compatibility 141 | 142 | v1.0 : Switched license from GPLv3 to LGPLv3 143 | 144 | v0.3 : Avoid zipping of installed egg 145 | 146 | v0.2 : Reduced memory footprint. Can now handle single precision data internally avoiding copy conversion to double precision. Default leafsize changed from 10 to 16 as this reduces the memory footprint and makes it a cache line multiplum (negligible if any query performance observed in benchmarks). Reduced memory allocation for leaf nodes. Applied patch for building on OS X. 147 | 148 | v0.1 : Initial version. 
149 | -------------------------------------------------------------------------------- /src/utils/libkdtree/__init__.py: -------------------------------------------------------------------------------- 1 | from .pykdtree.kdtree import KDTree 2 | 3 | 4 | __all__ = [ 5 | KDTree 6 | ] 7 | -------------------------------------------------------------------------------- /src/utils/libkdtree/pykdtree/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/convolutional_occupancy_networks/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/utils/libkdtree/pykdtree/__init__.py -------------------------------------------------------------------------------- /src/utils/libkdtree/pykdtree/render_template.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from mako.template import Template 4 | 5 | mytemplate = Template(filename='_kdtree_core.c.mako') 6 | with open('_kdtree_core.c', 'w') as fp: 7 | fp.write(mytemplate.render()) 8 | -------------------------------------------------------------------------------- /src/utils/libkdtree/setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_rpm] 2 | requires=numpy 3 | release=1 4 | 5 | 6 | -------------------------------------------------------------------------------- /src/utils/libmcubes/.gitignore: -------------------------------------------------------------------------------- 1 | PyMCubes.egg-info 2 | build 3 | -------------------------------------------------------------------------------- /src/utils/libmcubes/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2015, P. M. Neila 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /src/utils/libmcubes/README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | PyMCubes 3 | ======== 4 | 5 | PyMCubes is an implementation of the marching cubes algorithm to extract 6 | isosurfaces from volumetric data. The volumetric data can be given as a 7 | three-dimensional NumPy array or as a Python function ``f(x, y, z)``. The first 8 | option is much faster, but it requires more memory and becomes unfeasible for 9 | very large volumes. 10 | 11 | PyMCubes also provides a function to export the results of the marching cubes as 12 | COLLADA ``(.dae)`` files. This requires the 13 | `PyCollada `_ library. 14 | 15 | Installation 16 | ============ 17 | 18 | Just as any standard Python package, clone or download the project 19 | and run:: 20 | 21 | $ cd path/to/PyMCubes 22 | $ python setup.py build 23 | $ python setup.py install 24 | 25 | If you do not have write permission on the directory of Python packages, 26 | install with the ``--user`` option:: 27 | 28 | $ python setup.py install --user 29 | 30 | Example 31 | ======= 32 | 33 | The following example creates a data volume with spherical isosurfaces and 34 | extracts one of them (i.e., a sphere) with PyMCubes. The result is exported as 35 | ``sphere.dae``:: 36 | 37 | >>> import numpy as np 38 | >>> import mcubes 39 | 40 | # Create a data volume (30 x 30 x 30) 41 | >>> X, Y, Z = np.mgrid[:30, :30, :30] 42 | >>> u = (X-15)**2 + (Y-15)**2 + (Z-15)**2 - 8**2 43 | 44 | # Extract the 0-isosurface 45 | >>> vertices, triangles = mcubes.marching_cubes(u, 0) 46 | 47 | # Export the result to sphere.dae 48 | >>> mcubes.export_mesh(vertices, triangles, "sphere.dae", "MySphere") 49 | 50 | The second example is very similar to the first one, but it uses a function 51 | to represent the volume instead of a NumPy array:: 52 | 53 | >>> import numpy as np 54 | >>> import mcubes 55 | 56 | # Create the volume 57 | >>> f = lambda x, y, z: x**2 + y**2 + z**2 58 | 59 | # Extract the 16-isosurface 60 | >>> vertices, triangles = mcubes.marching_cubes_func((-10,-10,-10), (10,10,10), 61 | ... 100, 100, 100, f, 16) 62 | 63 | # Export the result to sphere2.dae 64 | >>> mcubes.export_mesh(vertices, triangles, "sphere2.dae", "MySphere") 65 | -------------------------------------------------------------------------------- /src/utils/libmcubes/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.libmcubes.mcubes import ( 2 | marching_cubes, marching_cubes_func 3 | ) 4 | from src.utils.libmcubes.exporter import ( 5 | export_mesh, export_obj, export_off 6 | ) 7 | 8 | 9 | __all__ = [ 10 | marching_cubes, marching_cubes_func, 11 | export_mesh, export_obj, export_off 12 | ] 13 | -------------------------------------------------------------------------------- /src/utils/libmcubes/exporter.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | def export_obj(vertices, triangles, filename): 6 | """ 7 | Exports a mesh in the (.obj) format. 8 | """ 9 | 10 | with open(filename, 'w') as fh: 11 | 12 | for v in vertices: 13 | fh.write("v {} {} {}\n".format(*v)) 14 | 15 | for f in triangles: 16 | fh.write("f {} {} {}\n".format(*(f + 1))) 17 | 18 | 19 | def export_off(vertices, triangles, filename): 20 | """ 21 | Exports a mesh in the (.off) format. 
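
    The line after the 'OFF' keyword lists the vertex, face and edge counts (the
    edge count is written as 0); each face row then starts with its number of
    vertices, which is always 3 for the triangles written here.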
22 | """ 23 | 24 | with open(filename, 'w') as fh: 25 | fh.write('OFF\n') 26 | fh.write('{} {} 0\n'.format(len(vertices), len(triangles))) 27 | 28 | for v in vertices: 29 | fh.write("{} {} {}\n".format(*v)) 30 | 31 | for f in triangles: 32 | fh.write("3 {} {} {}\n".format(*f)) 33 | 34 | 35 | def export_mesh(vertices, triangles, filename, mesh_name="mcubes_mesh"): 36 | """ 37 | Exports a mesh in the COLLADA (.dae) format. 38 | 39 | Needs PyCollada (https://github.com/pycollada/pycollada). 40 | """ 41 | 42 | import collada 43 | 44 | mesh = collada.Collada() 45 | 46 | vert_src = collada.source.FloatSource("verts-array", vertices, ('X','Y','Z')) 47 | geom = collada.geometry.Geometry(mesh, "geometry0", mesh_name, [vert_src]) 48 | 49 | input_list = collada.source.InputList() 50 | input_list.addInput(0, 'VERTEX', "#verts-array") 51 | 52 | triset = geom.createTriangleSet(np.copy(triangles), input_list, "") 53 | geom.primitives.append(triset) 54 | mesh.geometries.append(geom) 55 | 56 | geomnode = collada.scene.GeometryNode(geom, []) 57 | node = collada.scene.Node(mesh_name, children=[geomnode]) 58 | 59 | myscene = collada.scene.Scene("mcubes_scene", [node]) 60 | mesh.scenes.append(myscene) 61 | mesh.scene = myscene 62 | 63 | mesh.write(filename) 64 | -------------------------------------------------------------------------------- /src/utils/libmcubes/mcubes.pyx: -------------------------------------------------------------------------------- 1 | 2 | # distutils: language = c++ 3 | # cython: embedsignature = True 4 | 5 | # from libcpp.vector cimport vector 6 | import numpy as np 7 | 8 | # Define PY_ARRAY_UNIQUE_SYMBOL 9 | cdef extern from "pyarray_symbol.h": 10 | pass 11 | 12 | cimport numpy as np 13 | 14 | np.import_array() 15 | 16 | cdef extern from "pywrapper.h": 17 | cdef object c_marching_cubes "marching_cubes"(np.ndarray, double) except + 18 | cdef object c_marching_cubes2 "marching_cubes2"(np.ndarray, double) except + 19 | cdef object c_marching_cubes3 "marching_cubes3"(np.ndarray, double) except + 20 | cdef object c_marching_cubes_func "marching_cubes_func"(tuple, tuple, int, int, int, object, double) except + 21 | 22 | def marching_cubes(np.ndarray volume, float isovalue): 23 | 24 | verts, faces = c_marching_cubes(volume, isovalue) 25 | verts.shape = (-1, 3) 26 | faces.shape = (-1, 3) 27 | return verts, faces 28 | 29 | def marching_cubes2(np.ndarray volume, float isovalue): 30 | 31 | verts, faces = c_marching_cubes2(volume, isovalue) 32 | verts.shape = (-1, 3) 33 | faces.shape = (-1, 3) 34 | return verts, faces 35 | 36 | def marching_cubes3(np.ndarray volume, float isovalue): 37 | 38 | verts, faces = c_marching_cubes3(volume, isovalue) 39 | verts.shape = (-1, 3) 40 | faces.shape = (-1, 3) 41 | return verts, faces 42 | 43 | def marching_cubes_func(tuple lower, tuple upper, int numx, int numy, int numz, object f, double isovalue): 44 | 45 | verts, faces = c_marching_cubes_func(lower, upper, numx, numy, numz, f, isovalue) 46 | verts.shape = (-1, 3) 47 | faces.shape = (-1, 3) 48 | return verts, faces 49 | -------------------------------------------------------------------------------- /src/utils/libmcubes/pyarray_symbol.h: -------------------------------------------------------------------------------- 1 | 2 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API 3 | -------------------------------------------------------------------------------- /src/utils/libmcubes/pyarraymodule.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _EXTMODULE_H 3 | 
#define _EXTMODULE_H 4 | 5 | #include 6 | #include 7 | 8 | // #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 9 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API 10 | #define NO_IMPORT_ARRAY 11 | #include "numpy/arrayobject.h" 12 | 13 | #include 14 | 15 | template 16 | struct numpy_typemap; 17 | 18 | #define define_numpy_type(ctype, dtype) \ 19 | template<> \ 20 | struct numpy_typemap \ 21 | {static const int type = dtype;}; 22 | 23 | define_numpy_type(bool, NPY_BOOL); 24 | define_numpy_type(char, NPY_BYTE); 25 | define_numpy_type(short, NPY_SHORT); 26 | define_numpy_type(int, NPY_INT); 27 | define_numpy_type(long, NPY_LONG); 28 | define_numpy_type(long long, NPY_LONGLONG); 29 | define_numpy_type(unsigned char, NPY_UBYTE); 30 | define_numpy_type(unsigned short, NPY_USHORT); 31 | define_numpy_type(unsigned int, NPY_UINT); 32 | define_numpy_type(unsigned long, NPY_ULONG); 33 | define_numpy_type(unsigned long long, NPY_ULONGLONG); 34 | define_numpy_type(float, NPY_FLOAT); 35 | define_numpy_type(double, NPY_DOUBLE); 36 | define_numpy_type(long double, NPY_LONGDOUBLE); 37 | define_numpy_type(std::complex, NPY_CFLOAT); 38 | define_numpy_type(std::complex, NPY_CDOUBLE); 39 | define_numpy_type(std::complex, NPY_CLONGDOUBLE); 40 | 41 | template 42 | T PyArray_SafeGet(const PyArrayObject* aobj, const npy_intp* indaux) 43 | { 44 | // HORROR. 45 | npy_intp* ind = const_cast(indaux); 46 | void* ptr = PyArray_GetPtr(const_cast(aobj), ind); 47 | switch(PyArray_TYPE(aobj)) 48 | { 49 | case NPY_BOOL: 50 | return static_cast(*reinterpret_cast(ptr)); 51 | case NPY_BYTE: 52 | return static_cast(*reinterpret_cast(ptr)); 53 | case NPY_SHORT: 54 | return static_cast(*reinterpret_cast(ptr)); 55 | case NPY_INT: 56 | return static_cast(*reinterpret_cast(ptr)); 57 | case NPY_LONG: 58 | return static_cast(*reinterpret_cast(ptr)); 59 | case NPY_LONGLONG: 60 | return static_cast(*reinterpret_cast(ptr)); 61 | case NPY_UBYTE: 62 | return static_cast(*reinterpret_cast(ptr)); 63 | case NPY_USHORT: 64 | return static_cast(*reinterpret_cast(ptr)); 65 | case NPY_UINT: 66 | return static_cast(*reinterpret_cast(ptr)); 67 | case NPY_ULONG: 68 | return static_cast(*reinterpret_cast(ptr)); 69 | case NPY_ULONGLONG: 70 | return static_cast(*reinterpret_cast(ptr)); 71 | case NPY_FLOAT: 72 | return static_cast(*reinterpret_cast(ptr)); 73 | case NPY_DOUBLE: 74 | return static_cast(*reinterpret_cast(ptr)); 75 | case NPY_LONGDOUBLE: 76 | return static_cast(*reinterpret_cast(ptr)); 77 | default: 78 | throw std::runtime_error("data type not supported"); 79 | } 80 | } 81 | 82 | template 83 | T PyArray_SafeSet(PyArrayObject* aobj, const npy_intp* indaux, const T& value) 84 | { 85 | // HORROR. 
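    // Note: the const_cast below only drops constness from the index pointer so
    // that PyArray_GetPtr, which takes non-const pointers, can be called; the
    // index values themselves are never modified.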
86 | npy_intp* ind = const_cast(indaux); 87 | void* ptr = PyArray_GetPtr(aobj, ind); 88 | switch(PyArray_TYPE(aobj)) 89 | { 90 | case NPY_BOOL: 91 | *reinterpret_cast(ptr) = static_cast(value); 92 | break; 93 | case NPY_BYTE: 94 | *reinterpret_cast(ptr) = static_cast(value); 95 | break; 96 | case NPY_SHORT: 97 | *reinterpret_cast(ptr) = static_cast(value); 98 | break; 99 | case NPY_INT: 100 | *reinterpret_cast(ptr) = static_cast(value); 101 | break; 102 | case NPY_LONG: 103 | *reinterpret_cast(ptr) = static_cast(value); 104 | break; 105 | case NPY_LONGLONG: 106 | *reinterpret_cast(ptr) = static_cast(value); 107 | break; 108 | case NPY_UBYTE: 109 | *reinterpret_cast(ptr) = static_cast(value); 110 | break; 111 | case NPY_USHORT: 112 | *reinterpret_cast(ptr) = static_cast(value); 113 | break; 114 | case NPY_UINT: 115 | *reinterpret_cast(ptr) = static_cast(value); 116 | break; 117 | case NPY_ULONG: 118 | *reinterpret_cast(ptr) = static_cast(value); 119 | break; 120 | case NPY_ULONGLONG: 121 | *reinterpret_cast(ptr) = static_cast(value); 122 | break; 123 | case NPY_FLOAT: 124 | *reinterpret_cast(ptr) = static_cast(value); 125 | break; 126 | case NPY_DOUBLE: 127 | *reinterpret_cast(ptr) = static_cast(value); 128 | break; 129 | case NPY_LONGDOUBLE: 130 | *reinterpret_cast(ptr) = static_cast(value); 131 | break; 132 | default: 133 | throw std::runtime_error("data type not supported"); 134 | } 135 | } 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /src/utils/libmcubes/pywrapper.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "pywrapper.h" 3 | 4 | #include "marchingcubes.h" 5 | 6 | #include 7 | 8 | struct PythonToCFunc 9 | { 10 | PyObject* func; 11 | PythonToCFunc(PyObject* func) {this->func = func;} 12 | double operator()(double x, double y, double z) 13 | { 14 | PyObject* res = PyObject_CallFunction(func, "(d,d,d)", x, y, z); // py::extract(func(x,y,z)); 15 | if(res == NULL) 16 | return 0.0; 17 | 18 | double result = PyFloat_AsDouble(res); 19 | Py_DECREF(res); 20 | return result; 21 | } 22 | }; 23 | 24 | PyObject* marching_cubes_func(PyObject* lower, PyObject* upper, 25 | int numx, int numy, int numz, PyObject* f, double isovalue) 26 | { 27 | std::vector vertices; 28 | std::vector polygons; 29 | 30 | // Copy the lower and upper coordinates to a C array. 31 | double lower_[3]; 32 | double upper_[3]; 33 | for(int i=0; i<3; ++i) 34 | { 35 | PyObject* l = PySequence_GetItem(lower, i); 36 | if(l == NULL) 37 | throw std::runtime_error("error"); 38 | PyObject* u = PySequence_GetItem(upper, i); 39 | if(u == NULL) 40 | { 41 | Py_DECREF(l); 42 | throw std::runtime_error("error"); 43 | } 44 | 45 | lower_[i] = PyFloat_AsDouble(l); 46 | upper_[i] = PyFloat_AsDouble(u); 47 | 48 | Py_DECREF(l); 49 | Py_DECREF(u); 50 | if(lower_[i]==-1.0 || upper_[i]==-1.0) 51 | { 52 | if(PyErr_Occurred()) 53 | throw std::runtime_error("error"); 54 | } 55 | } 56 | 57 | // Marching cubes. 58 | mc::marching_cubes(lower_, upper_, numx, numy, numz, PythonToCFunc(f), isovalue, vertices, polygons); 59 | 60 | // Copy the result to two Python ndarrays. 
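    // (both arrays are returned flat; the Cython wrapper in mcubes.pyx reshapes
    // them to (-1, 3) on the Python side)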
61 | npy_intp size_vertices = vertices.size(); 62 | npy_intp size_polygons = polygons.size(); 63 | PyArrayObject* verticesarr = reinterpret_cast(PyArray_SimpleNew(1, &size_vertices, PyArray_DOUBLE)); 64 | PyArrayObject* polygonsarr = reinterpret_cast(PyArray_SimpleNew(1, &size_polygons, PyArray_ULONG)); 65 | 66 | std::vector::const_iterator it = vertices.begin(); 67 | for(int i=0; it!=vertices.end(); ++i, ++it) 68 | *reinterpret_cast(PyArray_GETPTR1(verticesarr, i)) = *it; 69 | std::vector::const_iterator it2 = polygons.begin(); 70 | for(int i=0; it2!=polygons.end(); ++i, ++it2) 71 | *reinterpret_cast(PyArray_GETPTR1(polygonsarr, i)) = *it2; 72 | 73 | PyObject* res = Py_BuildValue("(O,O)", verticesarr, polygonsarr); 74 | Py_XDECREF(verticesarr); 75 | Py_XDECREF(polygonsarr); 76 | return res; 77 | } 78 | 79 | struct PyArrayToCFunc 80 | { 81 | PyArrayObject* arr; 82 | PyArrayToCFunc(PyArrayObject* arr) {this->arr = arr;} 83 | double operator()(int x, int y, int z) 84 | { 85 | npy_intp c[3] = {x,y,z}; 86 | return PyArray_SafeGet(arr, c); 87 | } 88 | }; 89 | 90 | PyObject* marching_cubes(PyArrayObject* arr, double isovalue) 91 | { 92 | if(PyArray_NDIM(arr) != 3) 93 | throw std::runtime_error("Only three-dimensional arrays are supported."); 94 | 95 | // Prepare data. 96 | npy_intp* shape = PyArray_DIMS(arr); 97 | double lower[3] = {0,0,0}; 98 | double upper[3] = {shape[0]-1, shape[1]-1, shape[2]-1}; 99 | long numx = upper[0] - lower[0] + 1; 100 | long numy = upper[1] - lower[1] + 1; 101 | long numz = upper[2] - lower[2] + 1; 102 | std::vector vertices; 103 | std::vector polygons; 104 | 105 | // Marching cubes. 106 | mc::marching_cubes(lower, upper, numx, numy, numz, PyArrayToCFunc(arr), isovalue, 107 | vertices, polygons); 108 | 109 | // Copy the result to two Python ndarrays. 110 | npy_intp size_vertices = vertices.size(); 111 | npy_intp size_polygons = polygons.size(); 112 | PyArrayObject* verticesarr = reinterpret_cast(PyArray_SimpleNew(1, &size_vertices, PyArray_DOUBLE)); 113 | PyArrayObject* polygonsarr = reinterpret_cast(PyArray_SimpleNew(1, &size_polygons, PyArray_ULONG)); 114 | 115 | std::vector::const_iterator it = vertices.begin(); 116 | for(int i=0; it!=vertices.end(); ++i, ++it) 117 | *reinterpret_cast(PyArray_GETPTR1(verticesarr, i)) = *it; 118 | std::vector::const_iterator it2 = polygons.begin(); 119 | for(int i=0; it2!=polygons.end(); ++i, ++it2) 120 | *reinterpret_cast(PyArray_GETPTR1(polygonsarr, i)) = *it2; 121 | 122 | PyObject* res = Py_BuildValue("(O,O)", verticesarr, polygonsarr); 123 | Py_XDECREF(verticesarr); 124 | Py_XDECREF(polygonsarr); 125 | 126 | return res; 127 | } 128 | 129 | PyObject* marching_cubes2(PyArrayObject* arr, double isovalue) 130 | { 131 | if(PyArray_NDIM(arr) != 3) 132 | throw std::runtime_error("Only three-dimensional arrays are supported."); 133 | 134 | // Prepare data. 135 | npy_intp* shape = PyArray_DIMS(arr); 136 | double lower[3] = {0,0,0}; 137 | double upper[3] = {shape[0]-1, shape[1]-1, shape[2]-1}; 138 | long numx = upper[0] - lower[0] + 1; 139 | long numy = upper[1] - lower[1] + 1; 140 | long numz = upper[2] - lower[2] + 1; 141 | std::vector vertices; 142 | std::vector polygons; 143 | 144 | // Marching cubes. 145 | mc::marching_cubes2(lower, upper, numx, numy, numz, PyArrayToCFunc(arr), isovalue, 146 | vertices, polygons); 147 | 148 | // Copy the result to two Python ndarrays. 
149 | npy_intp size_vertices = vertices.size(); 150 | npy_intp size_polygons = polygons.size(); 151 | PyArrayObject* verticesarr = reinterpret_cast(PyArray_SimpleNew(1, &size_vertices, PyArray_DOUBLE)); 152 | PyArrayObject* polygonsarr = reinterpret_cast(PyArray_SimpleNew(1, &size_polygons, PyArray_ULONG)); 153 | 154 | std::vector::const_iterator it = vertices.begin(); 155 | for(int i=0; it!=vertices.end(); ++i, ++it) 156 | *reinterpret_cast(PyArray_GETPTR1(verticesarr, i)) = *it; 157 | std::vector::const_iterator it2 = polygons.begin(); 158 | for(int i=0; it2!=polygons.end(); ++i, ++it2) 159 | *reinterpret_cast(PyArray_GETPTR1(polygonsarr, i)) = *it2; 160 | 161 | PyObject* res = Py_BuildValue("(O,O)", verticesarr, polygonsarr); 162 | Py_XDECREF(verticesarr); 163 | Py_XDECREF(polygonsarr); 164 | 165 | return res; 166 | } 167 | 168 | PyObject* marching_cubes3(PyArrayObject* arr, double isovalue) 169 | { 170 | if(PyArray_NDIM(arr) != 3) 171 | throw std::runtime_error("Only three-dimensional arrays are supported."); 172 | 173 | // Prepare data. 174 | npy_intp* shape = PyArray_DIMS(arr); 175 | double lower[3] = {0,0,0}; 176 | double upper[3] = {shape[0]-1, shape[1]-1, shape[2]-1}; 177 | long numx = upper[0] - lower[0] + 1; 178 | long numy = upper[1] - lower[1] + 1; 179 | long numz = upper[2] - lower[2] + 1; 180 | std::vector vertices; 181 | std::vector polygons; 182 | 183 | // Marching cubes. 184 | mc::marching_cubes3(lower, upper, numx, numy, numz, PyArrayToCFunc(arr), isovalue, 185 | vertices, polygons); 186 | 187 | // Copy the result to two Python ndarrays. 188 | npy_intp size_vertices = vertices.size(); 189 | npy_intp size_polygons = polygons.size(); 190 | PyArrayObject* verticesarr = reinterpret_cast(PyArray_SimpleNew(1, &size_vertices, PyArray_DOUBLE)); 191 | PyArrayObject* polygonsarr = reinterpret_cast(PyArray_SimpleNew(1, &size_polygons, PyArray_ULONG)); 192 | 193 | std::vector::const_iterator it = vertices.begin(); 194 | for(int i=0; it!=vertices.end(); ++i, ++it) 195 | *reinterpret_cast(PyArray_GETPTR1(verticesarr, i)) = *it; 196 | std::vector::const_iterator it2 = polygons.begin(); 197 | for(int i=0; it2!=polygons.end(); ++i, ++it2) 198 | *reinterpret_cast(PyArray_GETPTR1(polygonsarr, i)) = *it2; 199 | 200 | PyObject* res = Py_BuildValue("(O,O)", verticesarr, polygonsarr); 201 | Py_XDECREF(verticesarr); 202 | Py_XDECREF(polygonsarr); 203 | 204 | return res; 205 | } -------------------------------------------------------------------------------- /src/utils/libmcubes/pywrapper.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _PYWRAPPER_H 3 | #define _PYWRAPPER_H 4 | 5 | #include 6 | #include "pyarraymodule.h" 7 | 8 | #include 9 | 10 | PyObject* marching_cubes(PyArrayObject* arr, double isovalue); 11 | PyObject* marching_cubes2(PyArrayObject* arr, double isovalue); 12 | PyObject* marching_cubes3(PyArrayObject* arr, double isovalue); 13 | PyObject* marching_cubes_func(PyObject* lower, PyObject* upper, 14 | int numx, int numy, int numz, PyObject* f, double isovalue); 15 | 16 | #endif // _PYWRAPPER_H 17 | -------------------------------------------------------------------------------- /src/utils/libmesh/.gitignore: -------------------------------------------------------------------------------- 1 | triangle_hash.cpp 2 | build 3 | -------------------------------------------------------------------------------- /src/utils/libmesh/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.inside_mesh import ( 2 | check_mesh_contains, MeshIntersector, TriangleIntersector2d 3 | ) 4 | 5 | 6 | __all__ = [ 7 | check_mesh_contains, MeshIntersector, TriangleIntersector2d 8 | ] 9 | -------------------------------------------------------------------------------- /src/utils/libmesh/inside_mesh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .triangle_hash import TriangleHash as _TriangleHash 3 | 4 | 5 | def check_mesh_contains(mesh, points, hash_resolution=512): 6 | intersector = MeshIntersector(mesh, hash_resolution) 7 | contains = intersector.query(points) 8 | return contains 9 | 10 | 11 | class MeshIntersector: 12 | def __init__(self, mesh, resolution=512): 13 | triangles = mesh.vertices[mesh.faces].astype(np.float64) 14 | n_tri = triangles.shape[0] 15 | 16 | self.resolution = resolution 17 | self.bbox_min = triangles.reshape(3 * n_tri, 3).min(axis=0) 18 | self.bbox_max = triangles.reshape(3 * n_tri, 3).max(axis=0) 19 | # Tranlate and scale it to [0.5, self.resolution - 0.5]^3 20 | self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min) 21 | self.translate = 0.5 - self.scale * self.bbox_min 22 | 23 | self._triangles = triangles = self.rescale(triangles) 24 | # assert(np.allclose(triangles.reshape(-1, 3).min(0), 0.5)) 25 | # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5)) 26 | 27 | triangles2d = triangles[:, :, :2] 28 | self._tri_intersector2d = TriangleIntersector2d( 29 | triangles2d, resolution) 30 | 31 | def query(self, points): 32 | # Rescale points 33 | points = self.rescale(points) 34 | 35 | # placeholder result with no hits we'll fill in later 36 | contains = np.zeros(len(points), dtype=np.bool) 37 | 38 | # cull points outside of the axis aligned bounding box 39 | # this avoids running ray tests unless points are close 40 | inside_aabb = np.all( 41 | (0 <= points) & (points <= self.resolution), axis=1) 42 | if not inside_aabb.any(): 43 | return contains 44 | 45 | # Only consider points inside bounding box 46 | mask = inside_aabb 47 | points = points[mask] 48 | 49 | # Compute intersection depth and check order 50 | points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2]) 51 | 52 | triangles_intersect = self._triangles[tri_indices] 53 | points_intersect = points[points_indices] 54 | 55 | depth_intersect, abs_n_2 = self.compute_intersection_depth( 56 | points_intersect, triangles_intersect) 57 | 58 | # Count number of intersections in both directions 59 | smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2 60 | bigger_depth = depth_intersect < points_intersect[:, 2] * abs_n_2 61 | points_indices_0 = points_indices[smaller_depth] 62 | points_indices_1 = points_indices[bigger_depth] 63 | 64 | nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0]) 65 | nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0]) 66 | 67 | # Check if point contained in mesh 68 | contains1 = (np.mod(nintersect0, 2) == 1) 69 | contains2 = (np.mod(nintersect1, 2) == 1) 70 | if (contains1 != contains2).any(): 71 | print('Warning: contains1 != contains2 for some points.') 72 | contains[mask] = (contains1 & contains2) 73 | return contains 74 | 75 | def compute_intersection_depth(self, points, triangles): 76 | t1 = triangles[:, 0, :] 77 | t2 = triangles[:, 1, :] 78 | t3 = triangles[:, 2, :] 79 | 80 | v1 = t3 - t1 81 | v2 = t2 - t1 82 | # v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True) 83 | # v2 = v2 / np.linalg.norm(v2, axis=-1, 
keepdims=True) 84 | 85 | normals = np.cross(v1, v2) 86 | alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1) 87 | 88 | n_2 = normals[:, 2] 89 | t1_2 = t1[:, 2] 90 | s_n_2 = np.sign(n_2) 91 | abs_n_2 = np.abs(n_2) 92 | 93 | mask = (abs_n_2 != 0) 94 | 95 | depth_intersect = np.full(points.shape[0], np.nan) 96 | depth_intersect[mask] = \ 97 | t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask] 98 | 99 | # Test the depth: 100 | # TODO: remove and put into tests 101 | # points_new = np.concatenate([points[:, :2], depth_intersect[:, None]], axis=1) 102 | # alpha = (normals * t1).sum(-1) 103 | # mask = (depth_intersect == depth_intersect) 104 | # assert(np.allclose((points_new[mask] * normals[mask]).sum(-1), 105 | # alpha[mask])) 106 | return depth_intersect, abs_n_2 107 | 108 | def rescale(self, array): 109 | array = self.scale * array + self.translate 110 | return array 111 | 112 | 113 | class TriangleIntersector2d: 114 | def __init__(self, triangles, resolution=128): 115 | self.triangles = triangles 116 | self.tri_hash = _TriangleHash(triangles, resolution) 117 | 118 | def query(self, points): 119 | point_indices, tri_indices = self.tri_hash.query(points) 120 | point_indices = np.array(point_indices, dtype=np.int64) 121 | tri_indices = np.array(tri_indices, dtype=np.int64) 122 | points = points[point_indices] 123 | triangles = self.triangles[tri_indices] 124 | mask = self.check_triangles(points, triangles) 125 | point_indices = point_indices[mask] 126 | tri_indices = tri_indices[mask] 127 | return point_indices, tri_indices 128 | 129 | def check_triangles(self, points, triangles): 130 | contains = np.zeros(points.shape[0], dtype=np.bool) 131 | A = triangles[:, :2] - triangles[:, 2:] 132 | A = A.transpose([0, 2, 1]) 133 | y = points - triangles[:, 2] 134 | 135 | detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0] 136 | 137 | mask = (np.abs(detA) != 0.) 138 | A = A[mask] 139 | y = y[mask] 140 | detA = detA[mask] 141 | 142 | s_detA = np.sign(detA) 143 | abs_detA = np.abs(detA) 144 | 145 | u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA 146 | v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA 147 | 148 | sum_uv = u + v 149 | contains[mask] = ( 150 | (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) 151 | & (0 < sum_uv) & (sum_uv < abs_detA) 152 | ) 153 | return contains 154 | 155 | -------------------------------------------------------------------------------- /src/utils/libmesh/triangle_hash.pyx: -------------------------------------------------------------------------------- 1 | 2 | # distutils: language=c++ 3 | import numpy as np 4 | cimport numpy as np 5 | cimport cython 6 | from libcpp.vector cimport vector 7 | from libc.math cimport floor, ceil 8 | 9 | cdef class TriangleHash: 10 | cdef vector[vector[int]] spatial_hash 11 | cdef int resolution 12 | 13 | def __cinit__(self, double[:, :, :] triangles, int resolution): 14 | self.spatial_hash.resize(resolution * resolution) 15 | self.resolution = resolution 16 | self._build_hash(triangles) 17 | 18 | @cython.boundscheck(False) # Deactivate bounds checking 19 | @cython.wraparound(False) # Deactivate negative indexing. 
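    # Each cell of the resolution x resolution grid stores the indices of all
    # triangles whose 2D bounding box overlaps that cell.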
20 | cdef int _build_hash(self, double[:, :, :] triangles): 21 | assert(triangles.shape[1] == 3) 22 | assert(triangles.shape[2] == 2) 23 | 24 | cdef int n_tri = triangles.shape[0] 25 | cdef int bbox_min[2] 26 | cdef int bbox_max[2] 27 | 28 | cdef int i_tri, j, x, y 29 | cdef int spatial_idx 30 | 31 | for i_tri in range(n_tri): 32 | # Compute bounding box 33 | for j in range(2): 34 | bbox_min[j] = min( 35 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] 36 | ) 37 | bbox_max[j] = max( 38 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] 39 | ) 40 | bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1) 41 | bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1) 42 | 43 | # Find all voxels where bounding box intersects 44 | for x in range(bbox_min[0], bbox_max[0] + 1): 45 | for y in range(bbox_min[1], bbox_max[1] + 1): 46 | spatial_idx = self.resolution * x + y 47 | self.spatial_hash[spatial_idx].push_back(i_tri) 48 | 49 | @cython.boundscheck(False) # Deactivate bounds checking 50 | @cython.wraparound(False) # Deactivate negative indexing. 51 | cpdef query(self, double[:, :] points): 52 | assert(points.shape[1] == 2) 53 | cdef int n_points = points.shape[0] 54 | 55 | cdef vector[int] points_indices 56 | cdef vector[int] tri_indices 57 | # cdef int[:] points_indices_np 58 | # cdef int[:] tri_indices_np 59 | 60 | cdef int i_point, k, x, y 61 | cdef int spatial_idx 62 | 63 | for i_point in range(n_points): 64 | x = int(points[i_point, 0]) 65 | y = int(points[i_point, 1]) 66 | if not (0 <= x < self.resolution and 0 <= y < self.resolution): 67 | continue 68 | 69 | spatial_idx = self.resolution * x + y 70 | for i_tri in self.spatial_hash[spatial_idx]: 71 | points_indices.push_back(i_point) 72 | tri_indices.push_back(i_tri) 73 | 74 | points_indices_np = np.zeros(points_indices.size(), dtype=np.int32) 75 | tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32) 76 | 77 | cdef int[:] points_indices_view = points_indices_np 78 | cdef int[:] tri_indices_view = tri_indices_np 79 | 80 | for k in range(points_indices.size()): 81 | points_indices_view[k] = points_indices[k] 82 | 83 | for k in range(tri_indices.size()): 84 | tri_indices_view[k] = tri_indices[k] 85 | 86 | return points_indices_np, tri_indices_np 87 | -------------------------------------------------------------------------------- /src/utils/libmise/.gitignore: -------------------------------------------------------------------------------- 1 | mise.c 2 | mise.cpp 3 | mise.html 4 | -------------------------------------------------------------------------------- /src/utils/libmise/__init__.py: -------------------------------------------------------------------------------- 1 | from .mise import MISE 2 | 3 | __all__ = [ 4 | MISE 5 | ] 6 | -------------------------------------------------------------------------------- /src/utils/libmise/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mise import MISE 3 | import time 4 | 5 | t0 = time.time() 6 | extractor = MISE(1, 2, 0.) 
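# Note: the constructor arguments here are presumably the initial grid
# resolution, the number of upsampling steps and the iso threshold, i.e.
# MISE(resolution_0=1, upsampling_steps=2, threshold=0.). The loop below
# alternates query() (grid points whose values are still unknown) and
# update() (feeding back values of a toy occupancy function) until no
# active points remain, then prints the resulting dense grid.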
7 | 8 | p = extractor.query() 9 | i = 0 10 | 11 | while p.shape[0] != 0: 12 | print(i) 13 | print(p) 14 | v = 2 * (p.sum(axis=-1) > 2).astype(np.float64) - 1 15 | extractor.update(p, v) 16 | p = extractor.query() 17 | i += 1 18 | if (i >= 8): 19 | break 20 | 21 | print(extractor.to_dense()) 22 | # p, v = extractor.get_points() 23 | # print(p) 24 | # print(v) 25 | print('Total time: %f' % (time.time() - t0)) 26 | -------------------------------------------------------------------------------- /src/utils/libsimplify/__init__.py: -------------------------------------------------------------------------------- 1 | from .simplify_mesh import ( 2 | mesh_simplify 3 | ) 4 | import trimesh 5 | 6 | 7 | def simplify_mesh(mesh, f_target=10000, agressiveness=7.): 8 | vertices = mesh.vertices 9 | faces = mesh.faces 10 | 11 | vertices, faces = mesh_simplify(vertices, faces, f_target, agressiveness) 12 | 13 | mesh_simplified = trimesh.Trimesh(vertices, faces, process=False) 14 | 15 | return mesh_simplified 16 | -------------------------------------------------------------------------------- /src/utils/libsimplify/simplify_mesh.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | from libcpp.vector cimport vector 3 | import numpy as np 4 | cimport numpy as np 5 | 6 | 7 | cdef extern from "Simplify.h": 8 | cdef struct vec3f: 9 | double x, y, z 10 | 11 | cdef cppclass SymetricMatrix: 12 | SymetricMatrix() except + 13 | 14 | 15 | cdef extern from "Simplify.h" namespace "Simplify": 16 | cdef struct Triangle: 17 | int v[3] 18 | double err[4] 19 | int deleted, dirty, attr 20 | vec3f uvs[3] 21 | int material 22 | 23 | cdef struct Vertex: 24 | vec3f p 25 | int tstart, tcount 26 | SymetricMatrix q 27 | int border 28 | 29 | cdef vector[Triangle] triangles 30 | cdef vector[Vertex] vertices 31 | cdef void simplify_mesh(int, double) 32 | 33 | 34 | cpdef mesh_simplify(double[:, ::1] vertices_in, long[:, ::1] triangles_in, 35 | int f_target, double agressiveness=7.) 
except +: 36 | vertices.clear() 37 | triangles.clear() 38 | 39 | # Read in vertices and triangles 40 | cdef Vertex v 41 | for iv in range(vertices_in.shape[0]): 42 | v = Vertex() 43 | v.p.x = vertices_in[iv, 0] 44 | v.p.y = vertices_in[iv, 1] 45 | v.p.z = vertices_in[iv, 2] 46 | vertices.push_back(v) 47 | 48 | cdef Triangle t 49 | for it in range(triangles_in.shape[0]): 50 | t = Triangle() 51 | t.v[0] = triangles_in[it, 0] 52 | t.v[1] = triangles_in[it, 1] 53 | t.v[2] = triangles_in[it, 2] 54 | triangles.push_back(t) 55 | 56 | # Simplify 57 | # print('Simplify...') 58 | simplify_mesh(f_target, agressiveness) 59 | 60 | # Only use triangles that are not deleted 61 | cdef vector[Triangle] triangles_notdel 62 | triangles_notdel.reserve(triangles.size()) 63 | 64 | for t in triangles: 65 | if not t.deleted: 66 | triangles_notdel.push_back(t) 67 | 68 | # Read out triangles 69 | vertices_out = np.empty((vertices.size(), 3), dtype=np.float64) 70 | triangles_out = np.empty((triangles_notdel.size(), 3), dtype=np.int64) 71 | 72 | cdef double[:, :] vertices_out_view = vertices_out 73 | cdef long[:, :] triangles_out_view = triangles_out 74 | 75 | for iv in range(vertices.size()): 76 | vertices_out_view[iv, 0] = vertices[iv].p.x 77 | vertices_out_view[iv, 1] = vertices[iv].p.y 78 | vertices_out_view[iv, 2] = vertices[iv].p.z 79 | 80 | for it in range(triangles_notdel.size()): 81 | triangles_out_view[it, 0] = triangles_notdel[it].v[0] 82 | triangles_out_view[it, 1] = triangles_notdel[it].v[1] 83 | triangles_out_view[it, 2] = triangles_notdel[it].v[2] 84 | 85 | # Clear vertices and triangles 86 | vertices.clear() 87 | triangles.clear() 88 | 89 | return vertices_out, triangles_out -------------------------------------------------------------------------------- /src/utils/libsimplify/test.py: -------------------------------------------------------------------------------- 1 | from simplify_mesh import mesh_simplify 2 | import numpy as np 3 | 4 | v = np.random.rand(100, 3) 5 | f = np.random.choice(range(100), (50, 3)) 6 | 7 | mesh_simplify(v, f, 50) -------------------------------------------------------------------------------- /src/utils/libvoxelize/.gitignore: -------------------------------------------------------------------------------- 1 | voxelize.c 2 | voxelize.html 3 | build 4 | -------------------------------------------------------------------------------- /src/utils/libvoxelize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/convolutional_occupancy_networks/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/utils/libvoxelize/__init__.py -------------------------------------------------------------------------------- /src/utils/libvoxelize/tribox2.h: -------------------------------------------------------------------------------- 1 | /********************************************************/ 2 | /* AABB-triangle overlap test code */ 3 | /* by Tomas Akenine-M�ller */ 4 | /* Function: int triBoxOverlap(float boxcenter[3], */ 5 | /* float boxhalfsize[3],float triverts[3][3]); */ 6 | /* History: */ 7 | /* 2001-03-05: released the code in its first version */ 8 | /* 2001-06-18: changed the order of the tests, faster */ 9 | /* */ 10 | /* Acknowledgement: Many thanks to Pierre Terdiman for */ 11 | /* suggestions and discussions on how to optimize code. */ 12 | /* Thanks to David Hunt for finding a ">="-bug! 
*/ 13 | /********************************************************/ 14 | #include 15 | #include 16 | 17 | #define X 0 18 | #define Y 1 19 | #define Z 2 20 | 21 | #define CROSS(dest,v1,v2) \ 22 | dest[0]=v1[1]*v2[2]-v1[2]*v2[1]; \ 23 | dest[1]=v1[2]*v2[0]-v1[0]*v2[2]; \ 24 | dest[2]=v1[0]*v2[1]-v1[1]*v2[0]; 25 | 26 | #define DOT(v1,v2) (v1[0]*v2[0]+v1[1]*v2[1]+v1[2]*v2[2]) 27 | 28 | #define SUB(dest,v1,v2) \ 29 | dest[0]=v1[0]-v2[0]; \ 30 | dest[1]=v1[1]-v2[1]; \ 31 | dest[2]=v1[2]-v2[2]; 32 | 33 | #define FINDMINMAX(x0,x1,x2,min,max) \ 34 | min = max = x0; \ 35 | if(x1max) max=x1;\ 37 | if(x2max) max=x2; 39 | 40 | int planeBoxOverlap(float normal[3],float d, float maxbox[3]) 41 | { 42 | int q; 43 | float vmin[3],vmax[3]; 44 | for(q=X;q<=Z;q++) 45 | { 46 | if(normal[q]>0.0f) 47 | { 48 | vmin[q]=-maxbox[q]; 49 | vmax[q]=maxbox[q]; 50 | } 51 | else 52 | { 53 | vmin[q]=maxbox[q]; 54 | vmax[q]=-maxbox[q]; 55 | } 56 | } 57 | if(DOT(normal,vmin)+d>0.0f) return 0; 58 | if(DOT(normal,vmax)+d>=0.0f) return 1; 59 | 60 | return 0; 61 | } 62 | 63 | 64 | /*======================== X-tests ========================*/ 65 | #define AXISTEST_X01(a, b, fa, fb) \ 66 | p0 = a*v0[Y] - b*v0[Z]; \ 67 | p2 = a*v2[Y] - b*v2[Z]; \ 68 | if(p0rad || max<-rad) return 0; 71 | 72 | #define AXISTEST_X2(a, b, fa, fb) \ 73 | p0 = a*v0[Y] - b*v0[Z]; \ 74 | p1 = a*v1[Y] - b*v1[Z]; \ 75 | if(p0rad || max<-rad) return 0; 78 | 79 | /*======================== Y-tests ========================*/ 80 | #define AXISTEST_Y02(a, b, fa, fb) \ 81 | p0 = -a*v0[X] + b*v0[Z]; \ 82 | p2 = -a*v2[X] + b*v2[Z]; \ 83 | if(p0rad || max<-rad) return 0; 86 | 87 | #define AXISTEST_Y1(a, b, fa, fb) \ 88 | p0 = -a*v0[X] + b*v0[Z]; \ 89 | p1 = -a*v1[X] + b*v1[Z]; \ 90 | if(p0rad || max<-rad) return 0; 93 | 94 | /*======================== Z-tests ========================*/ 95 | 96 | #define AXISTEST_Z12(a, b, fa, fb) \ 97 | p1 = a*v1[X] - b*v1[Y]; \ 98 | p2 = a*v2[X] - b*v2[Y]; \ 99 | if(p2rad || max<-rad) return 0; 102 | 103 | #define AXISTEST_Z0(a, b, fa, fb) \ 104 | p0 = a*v0[X] - b*v0[Y]; \ 105 | p1 = a*v1[X] - b*v1[Y]; \ 106 | if(p0rad || max<-rad) return 0; 109 | 110 | int triBoxOverlap(float boxcenter[3],float boxhalfsize[3],float tri0[3], float tri1[3], float tri2[3]) 111 | { 112 | 113 | /* use separating axis theorem to test overlap between triangle and box */ 114 | /* need to test for overlap in these directions: */ 115 | /* 1) the {x,y,z}-directions (actually, since we use the AABB of the triangle */ 116 | /* we do not even need to test these) */ 117 | /* 2) normal of the triangle */ 118 | /* 3) crossproduct(edge from tri, {x,y,z}-directin) */ 119 | /* this gives 3x3=9 more tests */ 120 | float v0[3],v1[3],v2[3]; 121 | float min,max,d,p0,p1,p2,rad,fex,fey,fez; 122 | float normal[3],e0[3],e1[3],e2[3]; 123 | 124 | /* This is the fastest branch on Sun */ 125 | /* move everything so that the boxcenter is in (0,0,0) */ 126 | SUB(v0, tri0, boxcenter); 127 | SUB(v1, tri1, boxcenter); 128 | SUB(v2, tri2, boxcenter); 129 | 130 | /* compute triangle edges */ 131 | SUB(e0,v1,v0); /* tri edge 0 */ 132 | SUB(e1,v2,v1); /* tri edge 1 */ 133 | SUB(e2,v0,v2); /* tri edge 2 */ 134 | 135 | /* Bullet 3: */ 136 | /* test the 9 tests first (this was faster) */ 137 | fex = fabs(e0[X]); 138 | fey = fabs(e0[Y]); 139 | fez = fabs(e0[Z]); 140 | AXISTEST_X01(e0[Z], e0[Y], fez, fey); 141 | AXISTEST_Y02(e0[Z], e0[X], fez, fex); 142 | AXISTEST_Z12(e0[Y], e0[X], fey, fex); 143 | 144 | fex = fabs(e1[X]); 145 | fey = fabs(e1[Y]); 146 | fez = fabs(e1[Z]); 147 | 
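/* Separating-axis tests for the cross products of triangle edge e1 with the
   box axes (the "9 tests" mentioned above); as with e0, any single failed
   axis test proves box and triangle are disjoint and returns 0 early. */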
AXISTEST_X01(e1[Z], e1[Y], fez, fey); 148 | AXISTEST_Y02(e1[Z], e1[X], fez, fex); 149 | AXISTEST_Z0(e1[Y], e1[X], fey, fex); 150 | 151 | fex = fabs(e2[X]); 152 | fey = fabs(e2[Y]); 153 | fez = fabs(e2[Z]); 154 | AXISTEST_X2(e2[Z], e2[Y], fez, fey); 155 | AXISTEST_Y1(e2[Z], e2[X], fez, fex); 156 | AXISTEST_Z12(e2[Y], e2[X], fey, fex); 157 | 158 | /* Bullet 1: */ 159 | /* first test overlap in the {x,y,z}-directions */ 160 | /* find min, max of the triangle each direction, and test for overlap in */ 161 | /* that direction -- this is equivalent to testing a minimal AABB around */ 162 | /* the triangle against the AABB */ 163 | 164 | /* test in X-direction */ 165 | FINDMINMAX(v0[X],v1[X],v2[X],min,max); 166 | if(min>boxhalfsize[X] || max<-boxhalfsize[X]) return 0; 167 | 168 | /* test in Y-direction */ 169 | FINDMINMAX(v0[Y],v1[Y],v2[Y],min,max); 170 | if(min>boxhalfsize[Y] || max<-boxhalfsize[Y]) return 0; 171 | 172 | /* test in Z-direction */ 173 | FINDMINMAX(v0[Z],v1[Z],v2[Z],min,max); 174 | if(min>boxhalfsize[Z] || max<-boxhalfsize[Z]) return 0; 175 | 176 | /* Bullet 2: */ 177 | /* test if the box intersects the plane of the triangle */ 178 | /* compute plane equation of triangle: normal*x+d=0 */ 179 | CROSS(normal,e0,e1); 180 | d=-DOT(normal,v0); /* plane eq: normal.x+d=0 */ 181 | if(!planeBoxOverlap(normal,d,boxhalfsize)) return 0; 182 | 183 | return 1; /* box and triangle overlaps */ 184 | } 185 | -------------------------------------------------------------------------------- /src/utils/libvoxelize/voxelize.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from libc.math cimport floor, ceil 3 | from cython.view cimport array as cvarray 4 | 5 | cdef extern from "tribox2.h": 6 | int triBoxOverlap(float boxcenter[3], float boxhalfsize[3], 7 | float tri0[3], float tri1[3], float tri2[3]) 8 | 9 | 10 | @cython.boundscheck(False) # Deactivate bounds checking 11 | @cython.wraparound(False) # Deactivate negative indexing. 12 | cpdef int voxelize_mesh_(bint[:, :, :] occ, float[:, :, ::1] faces): 13 | assert(faces.shape[1] == 3) 14 | assert(faces.shape[2] == 3) 15 | 16 | n_faces = faces.shape[0] 17 | cdef int i 18 | for i in range(n_faces): 19 | voxelize_triangle_(occ, faces[i]) 20 | 21 | 22 | @cython.boundscheck(False) # Deactivate bounds checking 23 | @cython.wraparound(False) # Deactivate negative indexing. 
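# voxelize_triangle_ marks every voxel whose unit cell (centre at
# (i + 0.5, j + 0.5, k + 0.5), half-size 0.5) overlaps the given triangle.
# Only cells inside the triangle's clamped integer bounding box are tested,
# each with the triangle/AABB separating-axis test from tribox2.h.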
24 | cpdef int voxelize_triangle_(bint[:, :, :] occupancies, float[:, ::1] triverts): 25 | cdef int bbox_min[3] 26 | cdef int bbox_max[3] 27 | cdef int i, j, k 28 | cdef float boxhalfsize[3] 29 | cdef float boxcenter[3] 30 | cdef bint intersection 31 | 32 | boxhalfsize[:] = (0.5, 0.5, 0.5) 33 | 34 | for i in range(3): 35 | bbox_min[i] = ( 36 | min(triverts[0, i], triverts[1, i], triverts[2, i]) 37 | ) 38 | bbox_min[i] = min(max(bbox_min[i], 0), occupancies.shape[i] - 1) 39 | 40 | for i in range(3): 41 | bbox_max[i] = ( 42 | max(triverts[0, i], triverts[1, i], triverts[2, i]) 43 | ) 44 | bbox_max[i] = min(max(bbox_max[i], 0), occupancies.shape[i] - 1) 45 | 46 | for i in range(bbox_min[0], bbox_max[0] + 1): 47 | for j in range(bbox_min[1], bbox_max[1] + 1): 48 | for k in range(bbox_min[2], bbox_max[2] + 1): 49 | boxcenter[:] = (i + 0.5, j + 0.5, k + 0.5) 50 | intersection = triBoxOverlap(&boxcenter[0], &boxhalfsize[0], 51 | &triverts[0, 0], &triverts[1, 0], &triverts[2, 0]) 52 | occupancies[i, j, k] |= intersection 53 | 54 | 55 | @cython.boundscheck(False) # Deactivate bounds checking 56 | @cython.wraparound(False) # Deactivate negative indexing. 57 | cdef int test_triangle_aabb(float[::1] boxcenter, float[::1] boxhalfsize, float[:, ::1] triverts): 58 | assert(boxcenter.shape[0] == 3) 59 | assert(boxhalfsize.shape[0] == 3) 60 | assert(triverts.shape[0] == triverts.shape[1] == 3) 61 | 62 | # print(triverts) 63 | # Call functions 64 | cdef int result = triBoxOverlap(&boxcenter[0], &boxhalfsize[0], 65 | &triverts[0, 0], &triverts[1, 0], &triverts[2, 0]) 66 | return result 67 | -------------------------------------------------------------------------------- /src/utils/mesh.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import Delaunay 2 | from itertools import combinations 3 | import numpy as np 4 | from src.utils import voxels 5 | 6 | 7 | class MultiGridExtractor(object): 8 | def __init__(self, resolution0, threshold): 9 | # Attributes 10 | self.resolution = resolution0 11 | self.threshold = threshold 12 | 13 | # Voxels are active or inactive, 14 | # values live on the space between voxels and are either 15 | # known exactly or guessed by interpolation (unknown) 16 | shape_voxels = (resolution0,) * 3 17 | shape_values = (resolution0 + 1,) * 3 18 | self.values = np.empty(shape_values) 19 | self.value_known = np.full(shape_values, False) 20 | self.voxel_active = np.full(shape_voxels, True) 21 | 22 | def query(self): 23 | # Query locations in grid that are active but unkown 24 | idx1, idx2, idx3 = np.where( 25 | ~self.value_known & self.value_active 26 | ) 27 | points = np.stack([idx1, idx2, idx3], axis=-1) 28 | return points 29 | 30 | def update(self, points, values): 31 | # Update locations and set known status to true 32 | idx0, idx1, idx2 = points.transpose() 33 | self.values[idx0, idx1, idx2] = values 34 | self.value_known[idx0, idx1, idx2] = True 35 | 36 | # Update activity status of voxels accordings to new values 37 | self.voxel_active = ~self.voxel_empty 38 | # ( 39 | # # self.voxel_active & 40 | # self.voxel_known & ~self.voxel_empty 41 | # ) 42 | 43 | def increase_resolution(self): 44 | self.resolution = 2 * self.resolution 45 | shape_values = (self.resolution + 1,) * 3 46 | 47 | value_known = np.full(shape_values, False) 48 | value_known[::2, ::2, ::2] = self.value_known 49 | values = upsample3d_nn(self.values) 50 | values = values[:-1, :-1, :-1] 51 | 52 | self.values = values 53 | self.value_known = value_known 54 | 
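        # After doubling the resolution, every previously active voxel stays
        # active through all eight of its children, while the newly created
        # in-between corners keep a nearest-neighbour guess of their value
        # and remain flagged as unknown until query()/update() fills them in.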
self.voxel_active = upsample3d_nn(self.voxel_active) 55 | 56 | @property 57 | def occupancies(self): 58 | return (self.values < self.threshold) 59 | 60 | @property 61 | def value_active(self): 62 | value_active = np.full(self.values.shape, False) 63 | # Active if adjacent to active voxel 64 | value_active[:-1, :-1, :-1] |= self.voxel_active 65 | value_active[:-1, :-1, 1:] |= self.voxel_active 66 | value_active[:-1, 1:, :-1] |= self.voxel_active 67 | value_active[:-1, 1:, 1:] |= self.voxel_active 68 | value_active[1:, :-1, :-1] |= self.voxel_active 69 | value_active[1:, :-1, 1:] |= self.voxel_active 70 | value_active[1:, 1:, :-1] |= self.voxel_active 71 | value_active[1:, 1:, 1:] |= self.voxel_active 72 | 73 | return value_active 74 | 75 | @property 76 | def voxel_known(self): 77 | value_known = self.value_known 78 | voxel_known = voxels.check_voxel_occupied(value_known) 79 | return voxel_known 80 | 81 | @property 82 | def voxel_empty(self): 83 | occ = self.occupancies 84 | return ~voxels.check_voxel_boundary(occ) 85 | 86 | 87 | def upsample3d_nn(x): 88 | xshape = x.shape 89 | yshape = (2*xshape[0], 2*xshape[1], 2*xshape[2]) 90 | 91 | y = np.zeros(yshape, dtype=x.dtype) 92 | y[::2, ::2, ::2] = x 93 | y[::2, ::2, 1::2] = x 94 | y[::2, 1::2, ::2] = x 95 | y[::2, 1::2, 1::2] = x 96 | y[1::2, ::2, ::2] = x 97 | y[1::2, ::2, 1::2] = x 98 | y[1::2, 1::2, ::2] = x 99 | y[1::2, 1::2, 1::2] = x 100 | 101 | return y 102 | 103 | 104 | class DelauneyMeshExtractor(object): 105 | """Algorithm for extacting meshes from implicit function using 106 | delauney triangulation and random sampling.""" 107 | def __init__(self, points, values, threshold=0.): 108 | self.points = points 109 | self.values = values 110 | self.delaunay = Delaunay(self.points) 111 | self.threshold = threshold 112 | 113 | def update(self, points, values, reduce_to_active=True): 114 | # Find all active points 115 | if reduce_to_active: 116 | active_simplices = self.active_simplices() 117 | active_point_idx = np.unique(active_simplices.flatten()) 118 | self.points = self.points[active_point_idx] 119 | self.values = self.values[active_point_idx] 120 | 121 | self.points = np.concatenate([self.points, points], axis=0) 122 | self.values = np.concatenate([self.values, values], axis=0) 123 | self.delaunay = Delaunay(self.points) 124 | 125 | def extract_mesh(self): 126 | threshold = self.threshold 127 | vertices = [] 128 | triangles = [] 129 | vertex_dict = dict() 130 | 131 | active_simplices = self.active_simplices() 132 | active_simplices.sort(axis=1) 133 | for simplex in active_simplices: 134 | new_vertices = [] 135 | for i1, i2 in combinations(simplex, 2): 136 | assert(i1 < i2) 137 | v1 = self.values[i1] 138 | v2 = self.values[i2] 139 | if (v1 < threshold) ^ (v2 < threshold): 140 | # Subdivide edge 141 | vertex_idx = vertex_dict.get((i1, i2), len(vertices)) 142 | vertex_idx = len(vertices) 143 | if vertex_idx == len(vertices): 144 | tau = (threshold - v1) / (v2 - v1) 145 | assert(0 <= tau <= 1) 146 | p = (1 - tau) * self.points[i1] + tau * self.points[i2] 147 | vertices.append(p) 148 | vertex_dict[i1, i2] = vertex_idx 149 | new_vertices.append(vertex_idx) 150 | 151 | assert(len(new_vertices) in (3, 4)) 152 | p0 = self.points[simplex[0]] 153 | v0 = self.values[simplex[0]] 154 | if len(new_vertices) == 3: 155 | i1, i2, i3 = new_vertices 156 | p1, p2, p3 = vertices[i1], vertices[i2], vertices[i3] 157 | vol = get_tetrahedon_volume(np.asarray([p0, p1, p2, p3])) 158 | if vol * (v0 - threshold) <= 0: 159 | triangles.append((i1, i2, i3)) 160 | else: 
161 | triangles.append((i1, i3, i2)) 162 | elif len(new_vertices) == 4: 163 | i1, i2, i3, i4 = new_vertices 164 | p1, p2, p3, p4 = \ 165 | vertices[i1], vertices[i2], vertices[i3], vertices[i4] 166 | vol = get_tetrahedon_volume(np.asarray([p0, p1, p2, p3])) 167 | if vol * (v0 - threshold) <= 0: 168 | triangles.append((i1, i2, i3)) 169 | else: 170 | triangles.append((i1, i3, i2)) 171 | 172 | vol = get_tetrahedon_volume(np.asarray([p0, p2, p3, p4])) 173 | if vol * (v0 - threshold) <= 0: 174 | triangles.append((i2, i3, i4)) 175 | else: 176 | triangles.append((i2, i4, i3)) 177 | 178 | vertices = np.asarray(vertices, dtype=np.float32) 179 | triangles = np.asarray(triangles, dtype=np.int32) 180 | 181 | return vertices, triangles 182 | 183 | def query(self, size): 184 | active_simplices = self.active_simplices() 185 | active_simplices_points = self.points[active_simplices] 186 | new_points = sample_tetraheda(active_simplices_points, size=size) 187 | return new_points 188 | 189 | def active_simplices(self): 190 | occ = (self.values >= self.threshold) 191 | simplices = self.delaunay.simplices 192 | simplices_occ = occ[simplices] 193 | 194 | active = ( 195 | np.any(simplices_occ, axis=1) & np.any(~simplices_occ, axis=1) 196 | ) 197 | 198 | simplices = self.delaunay.simplices[active] 199 | return simplices 200 | 201 | 202 | def sample_tetraheda(tetraheda_points, size): 203 | N_tetraheda = tetraheda_points.shape[0] 204 | volume = np.abs(get_tetrahedon_volume(tetraheda_points)) 205 | probs = volume / volume.sum() 206 | 207 | tetraheda_rnd = np.random.choice(range(N_tetraheda), p=probs, size=size) 208 | tetraheda_rnd_points = tetraheda_points[tetraheda_rnd] 209 | weights_rnd = np.random.dirichlet([1, 1, 1, 1], size=size) 210 | weights_rnd = weights_rnd.reshape(size, 4, 1) 211 | points_rnd = (weights_rnd * tetraheda_rnd_points).sum(axis=1) 212 | # points_rnd = tetraheda_rnd_points.mean(1) 213 | 214 | return points_rnd 215 | 216 | 217 | def get_tetrahedon_volume(points): 218 | vectors = points[..., :3, :] - points[..., 3:, :] 219 | volume = 1/6 * np.linalg.det(vectors) 220 | return volume 221 | -------------------------------------------------------------------------------- /src/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import src.common as common 5 | 6 | 7 | def visualize_data(data, data_type, out_file): 8 | r''' Visualizes the data with regard to its type. 9 | 10 | Args: 11 | data (tensor): batch of data 12 | data_type (string): data type (img, voxels or pointcloud) 13 | out_file (string): output file 14 | ''' 15 | if data_type == 'voxels': 16 | visualize_voxels(data, out_file=out_file) 17 | elif data_type == 'pointcloud': 18 | visualize_pointcloud(data, out_file=out_file) 19 | elif data_type is None or data_type == 'idx': 20 | pass 21 | else: 22 | raise ValueError('Invalid data_type "%s"' % data_type) 23 | 24 | 25 | def visualize_voxels(voxels, out_file=None, show=False): 26 | r''' Visualizes voxel data. 
27 | 28 | Args: 29 | voxels (tensor): voxel data 30 | out_file (string): output file 31 | show (bool): whether the plot should be shown 32 | ''' 33 | # Use numpy 34 | voxels = np.asarray(voxels) 35 | # Create plot 36 | fig = plt.figure() 37 | ax = fig.gca(projection=Axes3D.name) 38 | voxels = voxels.transpose(2, 0, 1) 39 | ax.voxels(voxels, edgecolor='k') 40 | ax.set_xlabel('Z') 41 | ax.set_ylabel('X') 42 | ax.set_zlabel('Y') 43 | ax.view_init(elev=30, azim=45) 44 | if out_file is not None: 45 | plt.savefig(out_file) 46 | if show: 47 | plt.show() 48 | plt.close(fig) 49 | 50 | 51 | def visualize_pointcloud(points, normals=None, 52 | out_file=None, show=False): 53 | r''' Visualizes point cloud data. 54 | 55 | Args: 56 | points (tensor): point data 57 | normals (tensor): normal data (if existing) 58 | out_file (string): output file 59 | show (bool): whether the plot should be shown 60 | ''' 61 | # Use numpy 62 | points = np.asarray(points) 63 | # Create plot 64 | fig = plt.figure() 65 | ax = fig.gca(projection=Axes3D.name) 66 | ax.scatter(points[:, 2], points[:, 0], points[:, 1]) 67 | if normals is not None: 68 | ax.quiver( 69 | points[:, 2], points[:, 0], points[:, 1], 70 | normals[:, 2], normals[:, 0], normals[:, 1], 71 | length=0.1, color='k' 72 | ) 73 | ax.set_xlabel('Z') 74 | ax.set_ylabel('X') 75 | ax.set_zlabel('Y') 76 | ax.set_xlim(-0.5, 0.5) 77 | ax.set_ylim(-0.5, 0.5) 78 | ax.set_zlim(-0.5, 0.5) 79 | ax.view_init(elev=30, azim=45) 80 | if out_file is not None: 81 | plt.savefig(out_file) 82 | if show: 83 | plt.show() 84 | plt.close(fig) 85 | 86 | -------------------------------------------------------------------------------- /src/utils/voxels.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import trimesh 4 | from scipy import ndimage 5 | from skimage.measure import block_reduce 6 | from src.utils.libvoxelize.voxelize import voxelize_mesh_ 7 | from src.utils.libmesh import check_mesh_contains 8 | from src.common import make_3d_grid 9 | 10 | 11 | class VoxelGrid: 12 | def __init__(self, data, loc=(0., 0., 0.), scale=1): 13 | assert(data.shape[0] == data.shape[1] == data.shape[2]) 14 | data = np.asarray(data, dtype=np.bool) 15 | loc = np.asarray(loc) 16 | self.data = data 17 | self.loc = loc 18 | self.scale = scale 19 | 20 | @classmethod 21 | def from_mesh(cls, mesh, resolution, loc=None, scale=None, method='ray'): 22 | bounds = mesh.bounds 23 | # Default location is center 24 | if loc is None: 25 | loc = (bounds[0] + bounds[1]) / 2 26 | 27 | # Default scale, scales the mesh to [-0.45, 0.45]^3 28 | if scale is None: 29 | scale = (bounds[1] - bounds[0]).max()/0.9 30 | 31 | loc = np.asarray(loc) 32 | scale = float(scale) 33 | 34 | # Transform mesh 35 | mesh = mesh.copy() 36 | mesh.apply_translation(-loc) 37 | mesh.apply_scale(1/scale) 38 | 39 | # Apply method 40 | if method == 'ray': 41 | voxel_data = voxelize_ray(mesh, resolution) 42 | elif method == 'fill': 43 | voxel_data = voxelize_fill(mesh, resolution) 44 | 45 | voxels = cls(voxel_data, loc, scale) 46 | return voxels 47 | 48 | def down_sample(self, factor=2): 49 | if not (self.resolution % factor) == 0: 50 | raise ValueError('Resolution must be divisible by factor.') 51 | new_data = block_reduce(self.data, (factor,) * 3, np.max) 52 | return VoxelGrid(new_data, self.loc, self.scale) 53 | 54 | def to_mesh(self): 55 | # Shorthand 56 | occ = self.data 57 | 58 | # Shape of voxel grid 59 | nx, ny, nz = occ.shape 60 | # Shape of corresponding occupancy grid 61 | 
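        # Corner vertices live on this (nx + 1, ny + 1, nz + 1) grid. Padding
        # the occupancies below lets the f*_r / f*_l masks compare each voxel
        # with its neighbour along every axis, so a quad is emitted wherever
        # occupancy flips, including on the outer boundary of the grid.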
grid_shape = (nx + 1, ny + 1, nz + 1) 62 | 63 | # Convert values to occupancies 64 | occ = np.pad(occ, 1, 'constant') 65 | 66 | # Determine if face present 67 | f1_r = (occ[:-1, 1:-1, 1:-1] & ~occ[1:, 1:-1, 1:-1]) 68 | f2_r = (occ[1:-1, :-1, 1:-1] & ~occ[1:-1, 1:, 1:-1]) 69 | f3_r = (occ[1:-1, 1:-1, :-1] & ~occ[1:-1, 1:-1, 1:]) 70 | 71 | f1_l = (~occ[:-1, 1:-1, 1:-1] & occ[1:, 1:-1, 1:-1]) 72 | f2_l = (~occ[1:-1, :-1, 1:-1] & occ[1:-1, 1:, 1:-1]) 73 | f3_l = (~occ[1:-1, 1:-1, :-1] & occ[1:-1, 1:-1, 1:]) 74 | 75 | f1 = f1_r | f1_l 76 | f2 = f2_r | f2_l 77 | f3 = f3_r | f3_l 78 | 79 | assert(f1.shape == (nx + 1, ny, nz)) 80 | assert(f2.shape == (nx, ny + 1, nz)) 81 | assert(f3.shape == (nx, ny, nz + 1)) 82 | 83 | # Determine if vertex present 84 | v = np.full(grid_shape, False) 85 | 86 | v[:, :-1, :-1] |= f1 87 | v[:, :-1, 1:] |= f1 88 | v[:, 1:, :-1] |= f1 89 | v[:, 1:, 1:] |= f1 90 | 91 | v[:-1, :, :-1] |= f2 92 | v[:-1, :, 1:] |= f2 93 | v[1:, :, :-1] |= f2 94 | v[1:, :, 1:] |= f2 95 | 96 | v[:-1, :-1, :] |= f3 97 | v[:-1, 1:, :] |= f3 98 | v[1:, :-1, :] |= f3 99 | v[1:, 1:, :] |= f3 100 | 101 | # Calculate indices for vertices 102 | n_vertices = v.sum() 103 | v_idx = np.full(grid_shape, -1) 104 | v_idx[v] = np.arange(n_vertices) 105 | 106 | # Vertices 107 | v_x, v_y, v_z = np.where(v) 108 | v_x = v_x / nx - 0.5 109 | v_y = v_y / ny - 0.5 110 | v_z = v_z / nz - 0.5 111 | vertices = np.stack([v_x, v_y, v_z], axis=1) 112 | 113 | # Face indices 114 | f1_l_x, f1_l_y, f1_l_z = np.where(f1_l) 115 | f2_l_x, f2_l_y, f2_l_z = np.where(f2_l) 116 | f3_l_x, f3_l_y, f3_l_z = np.where(f3_l) 117 | 118 | f1_r_x, f1_r_y, f1_r_z = np.where(f1_r) 119 | f2_r_x, f2_r_y, f2_r_z = np.where(f2_r) 120 | f3_r_x, f3_r_y, f3_r_z = np.where(f3_r) 121 | 122 | faces_1_l = np.stack([ 123 | v_idx[f1_l_x, f1_l_y, f1_l_z], 124 | v_idx[f1_l_x, f1_l_y, f1_l_z + 1], 125 | v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1], 126 | v_idx[f1_l_x, f1_l_y + 1, f1_l_z], 127 | ], axis=1) 128 | 129 | faces_1_r = np.stack([ 130 | v_idx[f1_r_x, f1_r_y, f1_r_z], 131 | v_idx[f1_r_x, f1_r_y + 1, f1_r_z], 132 | v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1], 133 | v_idx[f1_r_x, f1_r_y, f1_r_z + 1], 134 | ], axis=1) 135 | 136 | faces_2_l = np.stack([ 137 | v_idx[f2_l_x, f2_l_y, f2_l_z], 138 | v_idx[f2_l_x + 1, f2_l_y, f2_l_z], 139 | v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1], 140 | v_idx[f2_l_x, f2_l_y, f2_l_z + 1], 141 | ], axis=1) 142 | 143 | faces_2_r = np.stack([ 144 | v_idx[f2_r_x, f2_r_y, f2_r_z], 145 | v_idx[f2_r_x, f2_r_y, f2_r_z + 1], 146 | v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1], 147 | v_idx[f2_r_x + 1, f2_r_y, f2_r_z], 148 | ], axis=1) 149 | 150 | faces_3_l = np.stack([ 151 | v_idx[f3_l_x, f3_l_y, f3_l_z], 152 | v_idx[f3_l_x, f3_l_y + 1, f3_l_z], 153 | v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z], 154 | v_idx[f3_l_x + 1, f3_l_y, f3_l_z], 155 | ], axis=1) 156 | 157 | faces_3_r = np.stack([ 158 | v_idx[f3_r_x, f3_r_y, f3_r_z], 159 | v_idx[f3_r_x + 1, f3_r_y, f3_r_z], 160 | v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z], 161 | v_idx[f3_r_x, f3_r_y + 1, f3_r_z], 162 | ], axis=1) 163 | 164 | faces = np.concatenate([ 165 | faces_1_l, faces_1_r, 166 | faces_2_l, faces_2_r, 167 | faces_3_l, faces_3_r, 168 | ], axis=0) 169 | 170 | vertices = self.loc + self.scale * vertices 171 | mesh = trimesh.Trimesh(vertices, faces, process=False) 172 | return mesh 173 | 174 | @property 175 | def resolution(self): 176 | assert(self.data.shape[0] == self.data.shape[1] == self.data.shape[2]) 177 | return self.data.shape[0] 178 | 179 | def contains(self, points): 180 | nx = self.resolution 181 | 182 
| # Rescale bounding box to [-0.5, 0.5]^3 183 | points = (points - self.loc) / self.scale 184 | # Discretize points to [0, nx-1]^3 185 | points_i = ((points + 0.5) * nx).astype(np.int32) 186 | # i1, i2, i3 have sizes (batch_size, T) 187 | i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2] 188 | # Only use indices inside bounding box 189 | mask = ( 190 | (i1 >= 0) & (i2 >= 0) & (i3 >= 0) 191 | & (nx > i1) & (nx > i2) & (nx > i3) 192 | ) 193 | # Prevent out of bounds error 194 | i1 = i1[mask] 195 | i2 = i2[mask] 196 | i3 = i3[mask] 197 | 198 | # Compute values, default value outside box is 0 199 | occ = np.zeros(points.shape[:-1], dtype=np.bool) 200 | occ[mask] = self.data[i1, i2, i3] 201 | 202 | return occ 203 | 204 | 205 | def voxelize_ray(mesh, resolution): 206 | occ_surface = voxelize_surface(mesh, resolution) 207 | # TODO: use surface voxels here? 208 | occ_interior = voxelize_interior(mesh, resolution) 209 | occ = (occ_interior | occ_surface) 210 | return occ 211 | 212 | 213 | def voxelize_fill(mesh, resolution): 214 | bounds = mesh.bounds 215 | if (np.abs(bounds) >= 0.5).any(): 216 | raise ValueError('voxelize fill is only supported if mesh is inside [-0.5, 0.5]^3/') 217 | 218 | occ = voxelize_surface(mesh, resolution) 219 | occ = ndimage.morphology.binary_fill_holes(occ) 220 | return occ 221 | 222 | 223 | def voxelize_surface(mesh, resolution): 224 | vertices = mesh.vertices 225 | faces = mesh.faces 226 | 227 | vertices = (vertices + 0.5) * resolution 228 | 229 | face_loc = vertices[faces] 230 | occ = np.full((resolution,) * 3, 0, dtype=np.int32) 231 | face_loc = face_loc.astype(np.float32) 232 | 233 | voxelize_mesh_(occ, face_loc) 234 | occ = (occ != 0) 235 | 236 | return occ 237 | 238 | 239 | def voxelize_interior(mesh, resolution): 240 | shape = (resolution,) * 3 241 | bb_min = (0.5,) * 3 242 | bb_max = (resolution - 0.5,) * 3 243 | # Create points. 
Add noise to break symmetry 244 | points = make_3d_grid(bb_min, bb_max, shape=shape).numpy() 245 | points = points + 0.1 * (np.random.rand(*points.shape) - 0.5) 246 | points = (points / resolution - 0.5) 247 | occ = check_mesh_contains(mesh, points) 248 | occ = occ.reshape(shape) 249 | return occ 250 | 251 | 252 | def check_voxel_occupied(occupancy_grid): 253 | occ = occupancy_grid 254 | 255 | occupied = ( 256 | occ[..., :-1, :-1, :-1] 257 | & occ[..., :-1, :-1, 1:] 258 | & occ[..., :-1, 1:, :-1] 259 | & occ[..., :-1, 1:, 1:] 260 | & occ[..., 1:, :-1, :-1] 261 | & occ[..., 1:, :-1, 1:] 262 | & occ[..., 1:, 1:, :-1] 263 | & occ[..., 1:, 1:, 1:] 264 | ) 265 | return occupied 266 | 267 | 268 | def check_voxel_unoccupied(occupancy_grid): 269 | occ = occupancy_grid 270 | 271 | unoccupied = ~( 272 | occ[..., :-1, :-1, :-1] 273 | | occ[..., :-1, :-1, 1:] 274 | | occ[..., :-1, 1:, :-1] 275 | | occ[..., :-1, 1:, 1:] 276 | | occ[..., 1:, :-1, :-1] 277 | | occ[..., 1:, :-1, 1:] 278 | | occ[..., 1:, 1:, :-1] 279 | | occ[..., 1:, 1:, 1:] 280 | ) 281 | return unoccupied 282 | 283 | 284 | def check_voxel_boundary(occupancy_grid): 285 | occupied = check_voxel_occupied(occupancy_grid) 286 | unoccupied = check_voxel_unoccupied(occupancy_grid) 287 | return ~occupied & ~unoccupied 288 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from tensorboardX import SummaryWriter 4 | import numpy as np 5 | import os 6 | import argparse 7 | import time, datetime 8 | import matplotlib; matplotlib.use('Agg') 9 | from src import config, data 10 | from src.checkpoints import CheckpointIO 11 | from collections import defaultdict 12 | import shutil 13 | 14 | 15 | # Arguments 16 | parser = argparse.ArgumentParser( 17 | description='Train a 3D reconstruction model.' 
18 | ) 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 21 | parser.add_argument('--exit-after', type=int, default=-1, 22 | help='Checkpoint and exit after specified number of seconds' 23 | 'with exit code 2.') 24 | 25 | args = parser.parse_args() 26 | cfg = config.load_config(args.config, 'configs/default.yaml') 27 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 28 | device = torch.device("cuda" if is_cuda else "cpu") 29 | # Set t0 30 | t0 = time.time() 31 | 32 | # Shorthands 33 | out_dir = cfg['training']['out_dir'] 34 | batch_size = cfg['training']['batch_size'] 35 | backup_every = cfg['training']['backup_every'] 36 | vis_n_outputs = cfg['generation']['vis_n_outputs'] 37 | exit_after = args.exit_after 38 | 39 | model_selection_metric = cfg['training']['model_selection_metric'] 40 | if cfg['training']['model_selection_mode'] == 'maximize': 41 | model_selection_sign = 1 42 | elif cfg['training']['model_selection_mode'] == 'minimize': 43 | model_selection_sign = -1 44 | else: 45 | raise ValueError('model_selection_mode must be ' 46 | 'either maximize or minimize.') 47 | 48 | # Output directory 49 | if not os.path.exists(out_dir): 50 | os.makedirs(out_dir) 51 | 52 | shutil.copyfile(args.config, os.path.join(out_dir, 'config.yaml')) 53 | 54 | # Dataset 55 | train_dataset = config.get_dataset('train', cfg) 56 | val_dataset = config.get_dataset('val', cfg, return_idx=True) 57 | 58 | train_loader = torch.utils.data.DataLoader( 59 | train_dataset, batch_size=batch_size, num_workers=cfg['training']['n_workers'], shuffle=True, 60 | collate_fn=data.collate_remove_none, 61 | worker_init_fn=data.worker_init_fn) 62 | 63 | val_loader = torch.utils.data.DataLoader( 64 | val_dataset, batch_size=1, num_workers=cfg['training']['n_workers_val'], shuffle=False, 65 | collate_fn=data.collate_remove_none, 66 | worker_init_fn=data.worker_init_fn) 67 | 68 | # For visualizations 69 | vis_loader = torch.utils.data.DataLoader( 70 | val_dataset, batch_size=1, shuffle=False, 71 | collate_fn=data.collate_remove_none, 72 | worker_init_fn=data.worker_init_fn) 73 | model_counter = defaultdict(int) 74 | data_vis_list = [] 75 | 76 | # Build a data dictionary for visualization 77 | iterator = iter(vis_loader) 78 | for i in range(len(vis_loader)): 79 | data_vis = next(iterator) 80 | idx = data_vis['idx'].item() 81 | model_dict = val_dataset.get_model_dict(idx) 82 | category_id = model_dict.get('category', 'n/a') 83 | category_name = val_dataset.metadata[category_id].get('name', 'n/a') 84 | category_name = category_name.split(',')[0] 85 | if category_name == 'n/a': 86 | category_name = category_id 87 | 88 | c_it = model_counter[category_id] 89 | if c_it < vis_n_outputs: 90 | data_vis_list.append({'category': category_name, 'it': c_it, 'data': data_vis}) 91 | 92 | model_counter[category_id] += 1 93 | 94 | # Model 95 | model = config.get_model(cfg, device=device, dataset=train_dataset) 96 | 97 | # Generator 98 | generator = config.get_generator(model, cfg, device=device) 99 | 100 | # Intialize training 101 | optimizer = optim.Adam(model.parameters(), lr=1e-4) 102 | # optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) 103 | trainer = config.get_trainer(model, optimizer, cfg, device=device) 104 | 105 | checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer) 106 | try: 107 | load_dict = checkpoint_io.load('model.pt') 108 | except FileExistsError: 109 | load_dict = dict() 110 | 
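# Resume bookkeeping: if no checkpoint was found, load_dict stays empty, so
# training starts from epoch/iteration 0 and the best validation metric is
# initialised to +inf or -inf depending on whether it is minimized or
# maximized.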
epoch_it = load_dict.get('epoch_it', 0) 111 | it = load_dict.get('it', 0) 112 | metric_val_best = load_dict.get( 113 | 'loss_val_best', -model_selection_sign * np.inf) 114 | 115 | if metric_val_best == np.inf or metric_val_best == -np.inf: 116 | metric_val_best = -model_selection_sign * np.inf 117 | print('Current best validation metric (%s): %.8f' 118 | % (model_selection_metric, metric_val_best)) 119 | logger = SummaryWriter(os.path.join(out_dir, 'logs')) 120 | 121 | # Shorthands 122 | print_every = cfg['training']['print_every'] 123 | checkpoint_every = cfg['training']['checkpoint_every'] 124 | validate_every = cfg['training']['validate_every'] 125 | visualize_every = cfg['training']['visualize_every'] 126 | 127 | # Print model 128 | nparameters = sum(p.numel() for p in model.parameters()) 129 | print('Total number of parameters: %d' % nparameters) 130 | 131 | print('output path: ', cfg['training']['out_dir']) 132 | 133 | while True: 134 | epoch_it += 1 135 | 136 | for batch in train_loader: 137 | it += 1 138 | loss = trainer.train_step(batch) 139 | logger.add_scalar('train/loss', loss, it) 140 | 141 | # Print output 142 | if print_every > 0 and (it % print_every) == 0: 143 | t = datetime.datetime.now() 144 | print('[Epoch %02d] it=%03d, loss=%.4f, time: %.2fs, %02d:%02d' 145 | % (epoch_it, it, loss, time.time() - t0, t.hour, t.minute)) 146 | 147 | # Visualize output 148 | if visualize_every > 0 and (it % visualize_every) == 0: 149 | print('Visualizing') 150 | for data_vis in data_vis_list: 151 | if cfg['generation']['sliding_window']: 152 | out = generator.generate_mesh_sliding(data_vis['data']) 153 | else: 154 | out = generator.generate_mesh(data_vis['data']) 155 | # Get statistics 156 | try: 157 | mesh, stats_dict = out 158 | except TypeError: 159 | mesh, stats_dict = out, {} 160 | 161 | mesh.export(os.path.join(out_dir, 'vis', '{}_{}_{}.off'.format(it, data_vis['category'], data_vis['it']))) 162 | 163 | 164 | # Save checkpoint 165 | if (checkpoint_every > 0 and (it % checkpoint_every) == 0): 166 | print('Saving checkpoint') 167 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 168 | loss_val_best=metric_val_best) 169 | 170 | # Backup if necessary 171 | if (backup_every > 0 and (it % backup_every) == 0): 172 | print('Backup checkpoint') 173 | checkpoint_io.save('model_%d.pt' % it, epoch_it=epoch_it, it=it, 174 | loss_val_best=metric_val_best) 175 | # Run validation 176 | if validate_every > 0 and (it % validate_every) == 0: 177 | eval_dict = trainer.evaluate(val_loader) 178 | metric_val = eval_dict[model_selection_metric] 179 | print('Validation metric (%s): %.4f' 180 | % (model_selection_metric, metric_val)) 181 | 182 | for k, v in eval_dict.items(): 183 | logger.add_scalar('val/%s' % k, v, it) 184 | 185 | if model_selection_sign * (metric_val - metric_val_best) > 0: 186 | metric_val_best = metric_val 187 | print('New best model (loss %.4f)' % metric_val_best) 188 | checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it, 189 | loss_val_best=metric_val_best) 190 | 191 | # Exit if necessary 192 | if exit_after > 0 and (time.time() - t0) >= exit_after: 193 | print('Time limit reached. Exiting.') 194 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 195 | loss_val_best=metric_val_best) 196 | exit(3) 197 | --------------------------------------------------------------------------------