├── SAP ├── LICENSE ├── README.md ├── configs │ ├── default.yaml │ ├── learning_based │ │ ├── demo_large_noise.yaml │ │ ├── demo_outlier.yaml │ │ ├── noise_large │ │ │ ├── ours.yaml │ │ │ └── ours_pretrained.yaml │ │ ├── noise_small │ │ │ ├── ours.yaml │ │ │ └── ours_pretrained.yaml │ │ └── outlier │ │ │ ├── ours_1x.yaml │ │ │ ├── ours_1x_pretrained.yaml │ │ │ ├── ours_3plane.yaml │ │ │ ├── ours_3x.yaml │ │ │ ├── ours_3x_pretrained.yaml │ │ │ ├── ours_5x.yaml │ │ │ ├── ours_5x_pretrained.yaml │ │ │ ├── ours_7x.yaml │ │ │ └── ours_7x_pretrained.yaml │ └── optim_based │ │ ├── dfaust.yaml │ │ ├── dgp.yaml │ │ ├── teaser.yaml │ │ ├── thingi.yaml │ │ └── thingi_noisy.yaml ├── eval_meshes.py ├── generate.py ├── optim.py ├── optim_hierarchy.py ├── scripts │ ├── __pycache__ │ │ └── easy_mesh_vtk.cpython-38.pyc │ ├── download_demo_data.sh │ ├── download_optim_data.sh │ ├── download_shapenet.sh │ ├── easy_mesh_vtk.py │ ├── prcocess_tooth.py │ ├── process_shapenet.py │ └── process_tooth.py └── src │ ├── __init__.py │ ├── __pycache__ │ ├── config.cpython-38.pyc │ ├── dpsr.cpython-38.pyc │ ├── generation.cpython-38.pyc │ ├── model.cpython-38.pyc │ ├── training.cpython-38.pyc │ └── utils.cpython-38.pyc │ ├── config.py │ ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── core.cpython-38.pyc │ │ ├── fields.cpython-38.pyc │ │ └── transforms.cpython-38.pyc │ ├── core.py │ ├── fields.py │ └── transforms.py │ ├── data_loader.py │ ├── dpsr.py │ ├── eval.py │ ├── generation.py │ ├── model.py │ ├── model_rgb.py │ ├── network │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── decoder.cpython-38.pyc │ │ ├── encoder.cpython-38.pyc │ │ ├── unet.cpython-38.pyc │ │ ├── unet3d.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── decoder.py │ ├── encoder.py │ ├── net_rgb.py │ ├── unet.py │ ├── unet3d.py │ └── utils.py │ ├── optimization.py │ ├── training.py │ ├── utils.py │ └── visualize.py ├── cfgs ├── Tooth_models │ └── PoinTr.yaml └── 
dataset_configs │ └── Tooth.yaml ├── data └── dental │ └── crown │ ├── test.txt │ └── train.txt ├── datasets ├── __init__.py ├── __pycache__ │ ├── KITTIDataset.cpython-38.pyc │ ├── PCNDataset.cpython-38.pyc │ ├── ShapeNet55Dataset.cpython-38.pyc │ ├── Wrapping_Python_vtk_util_numpy_support.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── build.cpython-38.pyc │ ├── crowndataset.cpython-38.pyc │ ├── data_transforms.cpython-38.pyc │ ├── easy_mesh_vtk.cpython-38.pyc │ └── io.cpython-38.pyc ├── build.py ├── crowndataset.py └── easy_mesh_vtk.py ├── extensions ├── chamfer_dist │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-38.pyc │ ├── build │ │ ├── lib.linux-x86_64-3.8 │ │ │ └── chamfer.cpython-38-x86_64-linux-gnu.so │ │ └── temp.linux-x86_64-3.8 │ │ │ ├── build.ninja │ │ │ ├── chamfer.o │ │ │ └── chamfer_cuda.o │ ├── chamfer.cu │ ├── chamfer.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ └── top_level.txt │ ├── chamfer_cuda.cpp │ ├── dist │ │ └── chamfer-2.0.0-py3.8-linux-x86_64.egg │ ├── setup.py │ └── test.py ├── cubic_feature_sampling │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-38.pyc │ ├── build │ │ ├── lib.linux-x86_64-3.8 │ │ │ └── cubic_feature_sampling.cpython-38-x86_64-linux-gnu.so │ │ └── temp.linux-x86_64-3.8 │ │ │ ├── build.ninja │ │ │ ├── cubic_feature_sampling.o │ │ │ └── cubic_feature_sampling_cuda.o │ ├── cubic_feature_sampling.cu │ ├── cubic_feature_sampling.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ └── top_level.txt │ ├── cubic_feature_sampling_cuda.cpp │ ├── dist │ │ └── cubic_feature_sampling-1.1.0-py3.8-linux-x86_64.egg │ ├── setup.py │ └── test.py ├── gridding │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-38.pyc │ ├── build │ │ ├── lib.linux-x86_64-3.8 │ │ │ └── gridding.cpython-38-x86_64-linux-gnu.so │ │ └── temp.linux-x86_64-3.8 │ │ │ ├── build.ninja │ │ │ ├── gridding.o │ │ │ ├── gridding_cuda.o │ │ │ └── gridding_reverse.o │ ├── 
dist │ │ └── gridding-2.1.0-py3.8-linux-x86_64.egg │ ├── gridding.cu │ ├── gridding.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ └── top_level.txt │ ├── gridding_cuda.cpp │ ├── gridding_reverse.cu │ ├── setup.py │ └── test.py └── gridding_loss │ ├── __init__.py │ ├── __pycache__ │ └── __init__.cpython-38.pyc │ ├── build │ ├── lib.linux-x86_64-3.8 │ │ └── gridding_distance.cpython-38-x86_64-linux-gnu.so │ └── temp.linux-x86_64-3.8 │ │ ├── build.ninja │ │ ├── gridding_distance.o │ │ └── gridding_distance_cuda.o │ ├── dist │ └── gridding_distance-1.0.0-py3.8-linux-x86_64.egg │ ├── gridding_distance.cu │ ├── gridding_distance.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt │ ├── gridding_distance_cuda.cpp │ └── setup.py ├── install.sh ├── main.py ├── models ├── PoinTr.py ├── Transformer.py ├── __init__.py ├── __pycache__ │ ├── FoldingNet.cpython-38.pyc │ ├── GRNet.cpython-38.pyc │ ├── PCN.cpython-38.pyc │ ├── PoinTr.cpython-38.pyc │ ├── TopNet.cpython-38.pyc │ ├── Transformer.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── build.cpython-38.pyc │ ├── dgcnn_group.cpython-38.pyc │ └── easy_mesh_vtk.cpython-38.pyc ├── build.py ├── dgcnn_group.py └── easy_mesh_vtk.py ├── readme.txt ├── requirements.txt ├── tools ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── builder.cpython-38.pyc │ └── runner.cpython-38.pyc ├── builder.py └── runner.py └── utils ├── AverageMeter.py ├── __pycache__ ├── AverageMeter.cpython-38.pyc ├── config.cpython-38.pyc ├── dist_utils.cpython-38.pyc ├── logger.cpython-38.pyc ├── metrics.cpython-38.pyc ├── misc.cpython-38.pyc ├── parser.cpython-38.pyc └── registry.cpython-38.pyc ├── config.py ├── dist_utils.py ├── logger.py ├── metrics.py ├── misc.py ├── parser.py └── registry.py /SAP/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 autonomousvision 4 | 5 | Permission is 
hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /SAP/README.md: -------------------------------------------------------------------------------- 1 | # Shape As Points (SAP) 2 | 3 | ### [**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (12 min)**](https://youtu.be/TgR0NvYty0A)
4 | 5 | ![](./media/teaser_wheel.gif) 6 | 7 | This repository contains the implementation of the paper: 8 | 9 | Shape As Points: A Differentiable Poisson Solver 10 | [Songyou Peng](https://pengsongyou.github.io/), [Chiyu "Max" Jiang](https://www.maxjiang.ml/), [Yiyi Liao](https://yiyiliao.github.io/), [Michael Niemeyer](https://m-niemeyer.github.io/), [Marc Pollefeys](https://www.inf.ethz.ch/personal/pomarc/) and [Andreas Geiger](http://www.cvlibs.net/) 11 | **NeurIPS 2021 (Oral)** 12 | 13 | 14 | If you find our code or paper useful, please consider citing 15 | ```bibtex 16 | @inproceedings{Peng2021SAP, 17 | author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas}, 18 | title = {Shape As Points: A Differentiable Poisson Solver}, 19 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 20 | year = {2021}} 21 | ``` 22 | 23 | 24 | ## Installation 25 | First, you have to make sure that you have all dependencies in place. 26 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/). 27 | 28 | You can create an anaconda environment called `sap` using 29 | ``` 30 | conda env create -f environment.yaml 31 | conda activate sap 32 | ``` 33 | 34 | Next, you should install [PyTorch3D](https://pytorch3d.org/) (**>=0.5**) yourself following the [official instructions](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux). 35 | 36 | And install [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter): 37 | ```sh 38 | conda install pytorch-scatter -c pyg 39 | ``` 40 | 41 | 42 | ## Demo - Quick Start 43 | 44 | First, run the script to get the demo data: 45 | 46 | ```bash 47 | bash scripts/download_demo_data.sh 48 | ``` 49 | 50 | ### Optimization-based 3D Surface Reconstruction 51 | 52 | You can now quickly test our code on the data shown in the teaser.
To this end, simply run: 53 | 54 | ```python 55 | python optim_hierarchy.py configs/optim_based/teaser.yaml 56 | ``` 57 | This script should create a folder `out/demo_optim` where the output meshes and the optimized oriented point clouds at different grid resolutions are stored. 58 | 59 | To visualize the optimization process on the fly, you can set `o3d_show: True` in [`configs/optim_based/teaser.yaml`](https://github.com/autonomousvision/shape_as_points/tree/main/configs/optim_based/teaser.yaml). 60 | 61 | ### Learning-based 3D Surface Reconstruction 62 | You can also test SAP on another application where we can reconstruct from unoriented point clouds with either **large noise** or **outliers** using a learned network. 63 | 64 | ![](./media/results_large_noise.gif) 65 | 66 | For the point clouds with large noise as shown above, you can run: 67 | ```python 68 | python generate.py configs/learning_based/demo_large_noise.yaml 69 | ``` 70 | The results can be found at `out/demo_shapenet_large_noise/generation/vis`. 71 | 72 | ![](./media/results_outliers.gif) 73 | As for the point clouds with outliers, you can run: 74 | ```python 75 | python generate.py configs/learning_based/demo_outlier.yaml 76 | ``` 77 | You can find the reconstruction in `out/demo_shapenet_outlier/generation/vis`. 78 | 79 | 80 | ## Dataset 81 | 82 | We have different datasets for our optimization-based and learning-based settings. 83 | 84 | ### Dataset for Optimization-based Reconstruction 85 | Here we consider the following datasets: 86 | - [Thingi10K](https://arxiv.org/abs/1605.04797) (synthetic) 87 | - [Surface Reconstruction Benchmark (SRB)](https://github.com/fwilliams/deep-geometric-prior) (real scans) 88 | - [MPI Dynamic FAUST](https://dfaust.is.tue.mpg.de/) (real scans) 89 | 90 | Please cite the corresponding papers if you use the data.
91 | 92 | You can download the processed dataset (~200 MB) by running: 93 | ```bash 94 | bash scripts/download_optim_data.sh 95 | ``` 96 | 97 | ### Dataset for Learning-based Reconstruction 98 | We train and evaluate on [ShapeNet](https://shapenet.org/). 99 | You can download the processed dataset (~220 GB) by running: 100 | ```bash 101 | bash scripts/download_shapenet.sh 102 | ``` 103 | Afterwards, you should have the dataset in the `data/shapenet_psr` folder. 104 | 105 | Alternatively, you can also preprocess the dataset yourself. To this end, you can: 106 | * first download the preprocessed dataset (73.4 GB) by running [the script](https://github.com/autonomousvision/occupancy_networks#preprocessed-data) from Occupancy Networks. 107 | * check [`scripts/process_shapenet.py`](https://github.com/autonomousvision/shape_as_points/tree/main/scripts/process_shapenet.py), modify the base path, and run the code. 108 | 109 | 110 | ## Usage for Optimization-based 3D Reconstruction 111 | 112 | For our optimization-based setting, you can consider running with a coarse-to-fine strategy: 113 | ```python 114 | python optim_hierarchy.py configs/optim_based/CONFIG.yaml 115 | ``` 116 | We start from a grid resolution of 32^3 and increase it to 64^3, 128^3 and finally 256^3. 117 | 118 | Alternatively, you can also run on a single resolution with: 119 | 120 | ```python 121 | python optim.py configs/optim_based/CONFIG.yaml 122 | ``` 123 | You might need to modify the `CONFIG.yaml` accordingly. 124 | 125 | ## Usage for Learning-based 3D Reconstruction 126 | 127 | ### Mesh Generation 128 | To generate meshes using a trained model, use 129 | ```python 130 | python generate.py configs/learning_based/CONFIG.yaml 131 | ``` 132 | where you replace `CONFIG.yaml` with the correct config file. 133 | 134 | #### Use a pre-trained model 135 | The easiest way is to use a pre-trained model. You can do this by using one of the config files with the suffix `_pretrained`.
136 | 137 | For example, for 3D reconstruction from point clouds with outliers using our model with 7x offsets, you can simply run: 138 | ```python 139 | python generate.py configs/learning_based/outlier/ours_7x_pretrained.yaml 140 | ``` 141 | 142 | The script will automatically download the pretrained model and run the generation. You can find the outputs in the `out/.../generation_pretrained` folders. 143 | 144 | **Note:** the `_pretrained` config files are only for generation, not for training new models: when these configs are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model. 145 | 146 | We provide the following pretrained models: 147 | ``` 148 | noise_small/ours.pt 149 | noise_large/ours.pt 150 | outlier/ours_1x.pt 151 | outlier/ours_3x.pt 152 | outlier/ours_5x.pt 153 | outlier/ours_7x.pt 154 | outlier/ours_3plane.pt 155 | ``` 156 | 157 | 158 | ### Evaluation 159 | To evaluate a trained model, we provide the script [`eval_meshes.py`](https://github.com/autonomousvision/shape_as_points/blob/main/eval_meshes.py). You can run it using: 160 | ```python 161 | python eval_meshes.py configs/learning_based/CONFIG.yaml 162 | ``` 163 | The script takes the meshes generated in the previous step and evaluates them using a standardized protocol. The output will be written to `.pkl` and `.csv` files in the corresponding generation folder that can be processed using [pandas](https://pandas.pydata.org/). 164 | 165 | ### Training 166 | 167 | Finally, to train a new network from scratch, simply run: 168 | ```python 169 | python train.py configs/learning_based/CONFIG.yaml 170 | ``` 171 | For available training options, please take a look at `configs/default.yaml`.
172 | 173 | -------------------------------------------------------------------------------- /SAP/configs/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: Shapes3D 3 | path: data/ShapeNet 4 | class: null 5 | data_type: img 6 | input_type: pointcloud 7 | dim: 3 8 | num_points: 1000 9 | num_gt_points: 1000 10 | num_offset: 1 11 | img_size: null 12 | n_views_input: 20 13 | n_views_per_iter: 2 14 | pointcloud_noise: 0 15 | pointcloud_file: pointcloud.npz 16 | pointcloud_outlier_ratio: 0 17 | fixed_scale: 0 18 | train_split: train 19 | val_split: val 20 | test_split: test 21 | points_file: null 22 | points_iou_file: points.npz 23 | points_unpackbits: true 24 | padding: 0.1 25 | multi_files: null 26 | gt_mesh: null 27 | zero_level: 0 28 | only_single: False 29 | sample_only_floor: False 30 | model: 31 | apply_sigmoid: True 32 | grid_res: 128 # poisson grid resolution 33 | psr_sigma: 0 34 | psr_tanh: False 35 | normal_normalize: False 36 | raster: {} 37 | renderer: {} 38 | encoder: null 39 | predict_normal: True 40 | predict_offset: True 41 | s_offset: 0.001 42 | local_coord: True 43 | encoder_kwargs: {} 44 | unet3d: False 45 | unet3d_kwargs: {} 46 | multi_gpu: false 47 | rotate_matrix: false 48 | c_dim: 512 49 | sphere_radius: 0.2 50 | train: 51 | lr: 1e-3 52 | lr_pcl: 2e-2 53 | input_mesh: '' 54 | out_dir: out/default 55 | subsample_vertex: False 56 | batch_size: 4 57 | n_grow_points: 0 58 | resample_every: 0 59 | l_weight: {} 60 | w_reg_point: 0 61 | w_psr: 0 62 | w_raw: 0 # train with raw point cloud 63 | gauss_weight: 0 64 | n_sup_point: 0 65 | w_normals: 0 66 | total_epochs: 10 67 | print_every: 1 68 | visualize_every: 1 69 | save_every: 1 70 | vis_vert_color: True 71 | o3d_show: False 72 | o3d_vis_pcl: True 73 | o3d_window_size: 540 74 | vis_rendering: False 75 | vis_psr: False 76 | save_video: False 77 | exp_mesh: True 78 | exp_pcl: True 79 | checkpoint_every: 1 80 | validate_every: 1 81 | 
backup_every: 1 82 | timestamp: False # add timestamp to out_dir name 83 | model_selection_metric: loss 84 | model_selection_mode: minimize 85 | n_workers: 0 86 | n_workers_val: 0 87 | test: 88 | threshold: 0.5 89 | eval_mesh: true 90 | eval_pointcloud: false 91 | model_file: model_best.pt 92 | generation: 93 | batch_size: 4 94 | exp_gt: False 95 | exp_oracle: false 96 | exp_input: False 97 | vis_n_outputs: 10 98 | generate_mesh: true 99 | generate_pointcloud: true 100 | generation_dir: generation 101 | copy_input: true 102 | use_sampling: false 103 | psr_resolution: 0 104 | psr_sigma: 0 105 | -------------------------------------------------------------------------------- /SAP/configs/learning_based/demo_large_noise.yaml: -------------------------------------------------------------------------------- 1 | 2 | inherit_from: configs/learning_based/noise_large/ours.yaml 3 | data: 4 | class: [''] 5 | path: data/demo/shapenet_chair 6 | train: 7 | out_dir: out/demo_shapenet_large_noise 8 | test: 9 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt -------------------------------------------------------------------------------- /SAP/configs/learning_based/demo_outlier.yaml: -------------------------------------------------------------------------------- 1 | 2 | inherit_from: configs/learning_based/outlier/ours_7x.yaml 3 | data: 4 | class: [''] 5 | path: data/demo/shapenet_lamp 6 | train: 7 | out_dir: out/demo_shapenet_outlier 8 | test: 9 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt -------------------------------------------------------------------------------- /SAP/configs/learning_based/noise_large/ours.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: null 3 | data_type: psr_full 4 | input_type: pointcloud 5 | path: data/shapenet_psr 6 | num_gt_points: 10000 7 | num_offset: 7 8 | pointcloud_n: 
3000 9 | pointcloud_noise: 0.025 10 | model: 11 | grid_res: 128 # poisson grid resolution 12 | psr_sigma: 2 13 | psr_tanh: True 14 | normal_normalize: False 15 | predict_normal: True 16 | predict_offset: True 17 | c_dim: 32 18 | s_offset: 0.001 19 | encoder: local_pool_pointnet 20 | encoder_kwargs: 21 | hidden_dim: 32 22 | plane_type: 'grid' 23 | grid_resolution: 32 24 | unet3d: True 25 | unet3d_kwargs: 26 | num_levels: 3 27 | f_maps: 32 28 | in_channels: 32 29 | out_channels: 32 30 | decoder: simple_local 31 | decoder_kwargs: 32 | sample_mode: bilinear # bilinear / nearest 33 | hidden_size: 32 34 | train: 35 | batch_size: 32 36 | lr: 5e-4 37 | out_dir: out/shapenet/noise_025_ours 38 | w_psr: 1 39 | model_selection_metric: psr_l2 40 | print_every: 100 41 | checkpoint_every: 200 42 | validate_every: 5000 43 | backup_every: 10000 44 | total_epochs: 400000 45 | visualize_every: 5000 46 | exp_pcl: True 47 | exp_mesh: True 48 | n_workers: 8 49 | n_workers_val: 0 50 | generation: 51 | exp_gt: False 52 | exp_input: True 53 | psr_resolution: 128 54 | psr_sigma: 2 55 | -------------------------------------------------------------------------------- /SAP/configs/learning_based/noise_large/ours_pretrained.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/learning_based/noise_large/ours.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt -------------------------------------------------------------------------------- /SAP/configs/learning_based/noise_small/ours.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: null 3 | data_type: psr_full 4 | input_type: pointcloud 5 | dim: 3 6 | path: data/shapenet_psr 7 | num_gt_points: 10000 8 | num_offset: 7 9 | pointcloud_n: 3000 10 | pointcloud_noise: 0 11 | padding: 0.1 12 | model: 13 | 
grid_res: 128 # poisson grid resolution 14 | psr_sigma: 2 15 | psr_tanh: True 16 | normal_normalize: False 17 | predict_normal: True 18 | predict_offset: True 19 | c_dim: 32 20 | s_offset: 0.001 21 | local_coord: True 22 | encoder: local_pool_pointnet 23 | encoder_kwargs: 24 | hidden_dim: 32 25 | plane_type: 'grid' 26 | grid_resolution: 32 27 | unet3d: True 28 | unet3d_kwargs: 29 | num_levels: 3 30 | f_maps: 32 31 | in_channels: 32 32 | out_channels: 32 33 | decoder: simple_local 34 | decoder_kwargs: 35 | sample_mode: bilinear # bilinear / nearest 36 | hidden_size: 32 37 | train: 38 | batch_size: 4 39 | lr: 5e-4 40 | out_dir: out/shapenet/noise_0_ours_pointr_100ep 41 | w_psr: 1 42 | model_selection_metric: psr_l2 43 | print_every: 1 44 | checkpoint_every: 1 45 | validate_every: 1 46 | backup_every: 1 47 | total_epochs: 100 48 | visualize_every: 1 49 | exp_pcl: True 50 | exp_mesh: True 51 | n_workers: 8 52 | n_workers_val: 0 53 | generation: 54 | exp_gt: False 55 | exp_input: True 56 | psr_resolution: 128 57 | psr_sigma: 2 58 | -------------------------------------------------------------------------------- /SAP/configs/learning_based/noise_small/ours_pretrained.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/learning_based/noise_small/ours.yaml 2 | generation: 3 | generation_dir: generation_pretrained_poinTr 4 | test: 5 | model_file: C:/Users/Golriz/OneDrive - polymtl.ca/Desktop/SAP/shape_as_points-main/shape_as_points-main/models/model_best.pt -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_1x.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: null 3 | data_type: psr_full 4 | input_type: pointcloud 5 | path: data/shapenet_psr 6 | num_gt_points: 10000 7 | num_offset: 1 8 | pointcloud_n: 3000 9 | pointcloud_noise: 0.005 10 | pointcloud_outlier_ratio: 0.5 11 | model: 12 | 
grid_res: 128 # poisson grid resolution 13 | psr_sigma: 2 14 | psr_tanh: True 15 | normal_normalize: False 16 | predict_normal: True 17 | predict_offset: True 18 | c_dim: 32 19 | s_offset: 0.001 20 | encoder: local_pool_pointnet 21 | encoder_kwargs: 22 | hidden_dim: 32 23 | plane_type: 'grid' 24 | grid_resolution: 32 25 | unet3d: True 26 | unet3d_kwargs: 27 | num_levels: 3 28 | f_maps: 32 29 | in_channels: 32 30 | out_channels: 32 31 | decoder: simple_local 32 | decoder_kwargs: 33 | sample_mode: bilinear # bilinear / nearest 34 | hidden_size: 32 35 | train: 36 | batch_size: 32 37 | lr: 5e-4 38 | out_dir: out/shapenet/outlier_ours_1x 39 | w_psr: 1 40 | model_selection_metric: psr_l2 41 | print_every: 100 42 | checkpoint_every: 200 43 | validate_every: 5000 44 | backup_every: 10000 45 | total_epochs: 400000 46 | visualize_every: 5000 47 | exp_pcl: True 48 | exp_mesh: True 49 | n_workers: 8 50 | n_workers_val: 0 51 | generation: 52 | exp_gt: False 53 | exp_input: True 54 | psr_resolution: 128 55 | psr_sigma: 2 56 | -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_1x_pretrained.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/learning_based/outlier/ours_1x.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_1x.pt -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_3plane.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: null 3 | data_type: psr_full 4 | input_type: pointcloud 5 | path: data/shapenet_psr 6 | num_gt_points: 10000 7 | num_offset: 5 8 | pointcloud_n: 3000 9 | pointcloud_noise: 0.005 10 | pointcloud_outlier_ratio: 0.5 11 | model: 12 | grid_res: 128 # poisson grid
resolution 13 | psr_sigma: 2 14 | psr_tanh: True 15 | normal_normalize: False 16 | predict_normal: True 17 | predict_offset: True 18 | c_dim: 32 19 | s_offset: 0.001 20 | encoder: local_pool_pointnet 21 | encoder_kwargs: 22 | hidden_dim: 32 23 | plane_type: ['xz', 'xy', 'yz'] 24 | plane_resolution: 64 25 | unet: True 26 | unet_kwargs: 27 | depth: 4 28 | merge_mode: concat 29 | start_filts: 32 30 | decoder: simple_local 31 | decoder_kwargs: 32 | sample_mode: bilinear # bilinear / nearest 33 | hidden_size: 32 34 | train: 35 | batch_size: 32 36 | lr: 5e-4 37 | out_dir: out/shapenet/outlier_ours_3plane 38 | w_psr: 1 39 | model_selection_metric: psr_l2 40 | print_every: 100 41 | checkpoint_every: 200 42 | validate_every: 5000 43 | backup_every: 10000 44 | total_epochs: 400000 45 | visualize_every: 5000 46 | exp_pcl: True 47 | exp_mesh: True 48 | n_workers: 8 49 | n_workers_val: 0 50 | generation: 51 | exp_gt: False 52 | exp_input: True 53 | psr_resolution: 128 54 | psr_sigma: 2 55 | -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_3x.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: null 3 | data_type: psr_full 4 | input_type: pointcloud 5 | path: data/shapenet_psr 6 | num_gt_points: 10000 7 | num_offset: 3 8 | pointcloud_n: 3000 9 | pointcloud_noise: 0.005 10 | pointcloud_outlier_ratio: 0.5 11 | model: 12 | grid_res: 128 # poisson grid resolution 13 | psr_sigma: 2 14 | psr_tanh: True 15 | normal_normalize: False 16 | predict_normal: True 17 | predict_offset: True 18 | c_dim: 32 19 | s_offset: 0.001 20 | encoder: local_pool_pointnet 21 | encoder_kwargs: 22 | hidden_dim: 32 23 | plane_type: 'grid' 24 | grid_resolution: 32 25 | unet3d: True 26 | unet3d_kwargs: 27 | num_levels: 3 28 | f_maps: 32 29 | in_channels: 32 30 | out_channels: 32 31 | decoder: simple_local 32 | decoder_kwargs: 33 | sample_mode: bilinear # bilinear / nearest 34 | 
hidden_size: 32 35 | train: 36 | batch_size: 32 37 | lr: 5e-4 38 | out_dir: out/shapenet/outlier_ours_3x 39 | w_psr: 1 40 | model_selection_metric: psr_l2 41 | print_every: 100 42 | checkpoint_every: 200 43 | validate_every: 5000 44 | backup_every: 10000 45 | total_epochs: 400000 46 | visualize_every: 5000 47 | exp_pcl: True 48 | exp_mesh: True 49 | n_workers: 8 50 | n_workers_val: 0 51 | generation: 52 | exp_gt: False 53 | exp_input: True 54 | psr_resolution: 128 55 | psr_sigma: 2 56 | -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_3x_pretrained.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/learning_based/outlier/ours_3x.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_3x.pt -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_5x.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: null 3 | data_type: psr_full 4 | input_type: pointcloud 5 | path: data/shapenet_psr 6 | num_gt_points: 10000 7 | num_offset: 5 8 | pointcloud_n: 3000 9 | pointcloud_noise: 0.005 10 | pointcloud_outlier_ratio: 0.5 11 | model: 12 | grid_res: 128 # poisson grid resolution 13 | psr_sigma: 2 14 | psr_tanh: True 15 | normal_normalize: False 16 | predict_normal: True 17 | predict_offset: True 18 | c_dim: 32 19 | s_offset: 0.001 20 | encoder: local_pool_pointnet 21 | encoder_kwargs: 22 | hidden_dim: 32 23 | plane_type: 'grid' 24 | grid_resolution: 32 25 | unet3d: True 26 | unet3d_kwargs: 27 | num_levels: 3 28 | f_maps: 32 29 | in_channels: 32 30 | out_channels: 32 31 | decoder: simple_local 32 | decoder_kwargs: 33 | sample_mode: bilinear # bilinear / nearest 34 | hidden_size: 32 35 | train: 36 |
batch_size: 32 37 | lr: 5e-4 38 | out_dir: out/shapenet/outlier_ours_5x 39 | w_psr: 1 40 | model_selection_metric: psr_l2 41 | print_every: 100 42 | checkpoint_every: 200 43 | validate_every: 5000 44 | backup_every: 10000 45 | total_epochs: 400000 46 | visualize_every: 5000 47 | exp_pcl: True 48 | exp_mesh: True 49 | n_workers: 8 50 | n_workers_val: 0 51 | generation: 52 | exp_gt: False 53 | exp_input: True 54 | psr_resolution: 128 55 | psr_sigma: 2 56 | -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_5x_pretrained.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/learning_based/outlier/ours_5x.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_5x.pt -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_7x.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: null 3 | data_type: psr_full 4 | input_type: pointcloud 5 | path: data/shapenet_psr 6 | num_gt_points: 10000 7 | num_offset: 7 8 | pointcloud_n: 3000 9 | pointcloud_noise: 0.005 10 | pointcloud_outlier_ratio: 0.5 11 | model: 12 | grid_res: 128 # poisson grid resolution 13 | psr_sigma: 2 14 | psr_tanh: True 15 | normal_normalize: False 16 | predict_normal: True 17 | predict_offset: True 18 | c_dim: 32 19 | s_offset: 0.001 20 | encoder: local_pool_pointnet 21 | encoder_kwargs: 22 | hidden_dim: 32 23 | plane_type: 'grid' 24 | grid_resolution: 32 25 | unet3d: True 26 | unet3d_kwargs: 27 | num_levels: 3 28 | f_maps: 32 29 | in_channels: 32 30 | out_channels: 32 31 | decoder: simple_local 32 | decoder_kwargs: 33 | sample_mode: bilinear # bilinear / nearest 34 | hidden_size: 32 35 | train: 36 | batch_size: 32 37 | lr: 5e-4 38 |
out_dir: out/shapenet/outlier_ours_7x 39 | w_psr: 1 40 | model_selection_metric: psr_l2 41 | print_every: 100 42 | checkpoint_every: 200 43 | validate_every: 5000 44 | backup_every: 10000 45 | total_epochs: 400000 46 | visualize_every: 5000 47 | exp_pcl: True 48 | exp_mesh: True 49 | n_workers: 8 50 | n_workers_val: 0 51 | generation: 52 | exp_gt: False 53 | exp_input: True 54 | psr_resolution: 128 55 | psr_sigma: 2 56 | -------------------------------------------------------------------------------- /SAP/configs/learning_based/outlier/ours_7x_pretrained.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/learning_based/outlier/ours_7x.yaml 2 | generation: 3 | generation_dir: generation_pretrained 4 | test: 5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt -------------------------------------------------------------------------------- /SAP/configs/optim_based/dfaust.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: 'only_pcl' 3 | data_type: point 4 | data_path: 'data/dfaust/*.ply' 5 | object_id: 0 6 | num_points: 20000 7 | model: 8 | sphere_radius: 0.2 9 | grid_res: 256 # poisson grid resolution 10 | psr_sigma: 2 11 | train: 12 | schedule: 13 | pcl: 14 | initial: 1e-2 15 | interval: 700 16 | factor: 0.5 17 | final: 1e-3 18 | out_dir: out/dfaust 19 | w_chamfer: 1 20 | n_sup_point: 20000 21 | batch_size: 1 22 | n_grow_points: 2000 23 | resample_every: 200 24 | subsample_vertex: False 25 | total_epochs: 1600 26 | print_every: 10 27 | visualize_every: 2 28 | checkpoint_every: 500 29 | save_every: 100 30 | exp_pcl: True 31 | exp_mesh: True 32 | o3d_show: False 33 | o3d_vis_pcl: True 34 | o3d_window_size: 540 35 | vis_rendering: False 36 | vis_vert_color: False 37 | n_workers: 0 38 | -------------------------------------------------------------------------------- 
/SAP/configs/optim_based/dgp.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: 'only_pcl' 3 | data_type: point 4 | data_path: 'data/deep_geometric_prior_data/*.ply' 5 | object_id: 0 6 | num_points: 20000 7 | model: 8 | sphere_radius: 0.2 9 | grid_res: 256 # poisson grid resolution 10 | psr_sigma: 2 11 | train: 12 | schedule: 13 | pcl: 14 | initial: 1e-2 15 | interval: 700 16 | factor: 0.5 17 | final: 1e-3 18 | out_dir: out/dgp 19 | w_reg_point: 0 20 | w_chamfer: 1 21 | n_sup_point: 20000 22 | batch_size: 1 23 | n_grow_points: 2000 24 | resample_every: 200 25 | subsample_vertex: False 26 | total_epochs: 1600 27 | print_every: 10 28 | visualize_every: 2 29 | checkpoint_every: 500 30 | save_every: 100 31 | exp_pcl: True 32 | exp_mesh: True 33 | o3d_show: False 34 | o3d_vis_pcl: True 35 | o3d_window_size: 540 36 | vis_rendering: False 37 | vis_vert_color: False 38 | n_workers: 0 39 | -------------------------------------------------------------------------------- /SAP/configs/optim_based/teaser.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: 'only_pcl' 3 | data_type: point 4 | data_path: 'data/demo/wheel.ply' 5 | object_id: 0 6 | num_points: 20000 7 | model: 8 | sphere_radius: 0.2 9 | grid_res: 128 # poisson grid resolution 10 | psr_sigma: 2 11 | train: 12 | schedule: 13 | pcl: 14 | initial: 1e-2 15 | interval: 700 16 | factor: 0.5 17 | final: 1e-3 18 | out_dir: out/demo_optim 19 | w_chamfer: 1 20 | n_sup_point: 20000 21 | batch_size: 1 22 | n_grow_points: 2000 23 | resample_every: 200 24 | subsample_vertex: False 25 | total_epochs: 1600 26 | print_every: 10 27 | visualize_every: 2 28 | checkpoint_every: 500 29 | save_every: 100 30 | exp_pcl: True 31 | exp_mesh: True 32 | o3d_show: False 33 | o3d_vis_pcl: True 34 | o3d_window_size: 540 35 | vis_rendering: False 36 | vis_vert_color: False 37 | n_workers: 0 
-------------------------------------------------------------------------------- /SAP/configs/optim_based/thingi.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: 'only_pcl' 3 | data_type: point 4 | data_path: 'data/thingi/*.ply' 5 | object_id: 0 6 | num_points: 20000 7 | model: 8 | sphere_radius: 0.2 9 | grid_res: 128 # poisson grid resolution 10 | psr_sigma: 2 11 | train: 12 | # lr_pcl: 2e-2 13 | schedule: 14 | pcl: 15 | initial: 1e-2 16 | interval: 700 17 | factor: 0.5 18 | final: 1e-3 19 | out_dir: out/thingi 20 | w_reg_point: 0 21 | w_chamfer: 1 22 | n_sup_point: 20000 23 | batch_size: 1 24 | n_grow_points: 2000 25 | resample_every: 200 26 | subsample_vertex: False 27 | total_epochs: 1600 28 | print_every: 10 29 | visualize_every: 2 30 | checkpoint_every: 500 31 | save_every: 100 32 | exp_pcl: True 33 | exp_mesh: True 34 | o3d_show: False 35 | o3d_vis_pcl: True 36 | o3d_window_size: 540 37 | vis_rendering: False 38 | vis_vert_color: False 39 | n_workers: 0 40 | -------------------------------------------------------------------------------- /SAP/configs/optim_based/thingi_noisy.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class: 'only_pcl' 3 | data_type: point 4 | data_path: 'data/thingi_noisy/*.ply' 5 | object_id: 0 6 | num_points: 20000 7 | model: 8 | sphere_radius: 0.2 9 | grid_res: 128 # poisson grid resolution 10 | psr_sigma: 2 11 | train: 12 | # lr_pcl: 2e-2 13 | schedule: 14 | pcl: 15 | initial: 1e-2 16 | interval: 700 17 | factor: 0.5 18 | final: 1e-3 19 | out_dir: out/thingi_noisy 20 | w_reg_point: 0 21 | w_chamfer: 1 22 | n_sup_point: 20000 23 | batch_size: 1 24 | n_grow_points: 2000 25 | resample_every: 200 26 | subsample_vertex: False 27 | total_epochs: 1600 28 | print_every: 10 29 | visualize_every: 2 30 | checkpoint_every: 500 31 | save_every: 100 32 | exp_pcl: True 33 | exp_mesh: True 34 | o3d_show: False 35 | o3d_vis_pcl: True 36 | 
o3d_window_size: 540 37 | vis_rendering: False 38 | vis_vert_color: False 39 | n_workers: 0 40 | -------------------------------------------------------------------------------- /SAP/eval_meshes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import trimesh 3 | from torch.utils.data import Dataset, DataLoader 4 | import numpy as np; np.set_printoptions(precision=4) 5 | import shutil, argparse, time, os 6 | import pandas as pd 7 | from src.data import collate_remove_none, collate_stack_together, worker_init_fn 8 | from src.training import Trainer 9 | from src.model import Encode2Points 10 | from src.data import PointCloudField, IndexField, Shapes3dDataset 11 | from src.utils import load_config, load_pointcloud 12 | from src.eval import MeshEvaluator 13 | from tqdm import tqdm 14 | from pdb import set_trace as st 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser(description='MNIST toy experiment') 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | parser.add_argument('--no_cuda', action='store_true', default=False, 21 | help='disables CUDA training') 22 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') 23 | parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.') 24 | 25 | args = parser.parse_args() 26 | cfg = load_config(args.config, 'configs/default.yaml') 27 | use_cuda = not args.no_cuda and torch.cuda.is_available() 28 | device = torch.device("cuda" if use_cuda else "cpu") 29 | data_type = cfg['data']['data_type'] 30 | # Shorthands 31 | out_dir = cfg['train']['out_dir'] 32 | generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir']) 33 | 34 | if cfg['generation'].get('iter', 0)!=0: 35 | generation_dir += '_%04d'%cfg['generation']['iter'] 36 | elif args.iter is not None: 37 | generation_dir += '_%04d'%args.iter 38 | 39 | print('Evaluate meshes under 
%s'%generation_dir) 40 | 41 | out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl') 42 | out_file_class = os.path.join(generation_dir, 'eval_meshes.csv') 43 | 44 | # PYTORCH VERSION > 1.0.0 45 | assert(float(torch.__version__.split('.')[-3]) > 0) 46 | 47 | pointcloud_field = PointCloudField(cfg['data']['pointcloud_file']) 48 | fields = { 49 | 'pointcloud': pointcloud_field, 50 | 'idx': IndexField(), 51 | } 52 | 53 | print('Test split: ', cfg['data']['test_split']) 54 | 55 | dataset_folder = cfg['data']['path'] 56 | dataset = Shapes3dDataset( 57 | dataset_folder, fields, 58 | cfg['data']['test_split'], 59 | categories=cfg['data']['class'], cfg=cfg) 60 | 61 | # Loader 62 | test_loader = torch.utils.data.DataLoader( 63 | dataset, batch_size=1, num_workers=0, shuffle=False) 64 | 65 | # Evaluator 66 | evaluator = MeshEvaluator(n_points=100000) 67 | 68 | eval_dicts = [] 69 | print('Evaluating meshes...') 70 | for it, data in enumerate(tqdm(test_loader)): 71 | 72 | if data is None: 73 | print('Invalid data.') 74 | continue 75 | 76 | mesh_dir = os.path.join(generation_dir, 'meshes') 77 | pointcloud_dir = os.path.join(generation_dir, 'pointcloud') 78 | 79 | 80 | # Get index etc. 
81 | idx = data['idx'].item() 82 | try: 83 | model_dict = dataset.get_model_dict(idx) 84 | except AttributeError: 85 | model_dict = {'model': str(idx), 'category': 'n/a'} 86 | 87 | modelname = model_dict['model'] 88 | category_id = model_dict['category'] 89 | 90 | try: 91 | category_name = dataset.metadata[category_id].get('name', 'n/a') 92 | except AttributeError: 93 | category_name = 'n/a' 94 | 95 | if category_id != 'n/a': 96 | mesh_dir = os.path.join(mesh_dir, category_id) 97 | pointcloud_dir = os.path.join(pointcloud_dir, category_id) 98 | 99 | # Evaluate 100 | pointcloud_tgt = data['pointcloud'].squeeze(0).numpy() 101 | normals_tgt = data['pointcloud.normals'].squeeze(0).numpy() 102 | 103 | 104 | eval_dict = { 105 | 'idx': idx, 106 | 'class id': category_id, 107 | 'class name': category_name, 108 | 'modelname':modelname, 109 | } 110 | eval_dicts.append(eval_dict) 111 | 112 | # Evaluate mesh 113 | if cfg['test']['eval_mesh']: 114 | mesh_file = os.path.join(mesh_dir, '%s.off' % modelname) 115 | 116 | if os.path.exists(mesh_file): 117 | mesh = trimesh.load(mesh_file, process=False) 118 | eval_dict_mesh = evaluator.eval_mesh( 119 | mesh, pointcloud_tgt, normals_tgt) 120 | for k, v in eval_dict_mesh.items(): 121 | eval_dict[k + ' (mesh)'] = v 122 | else: 123 | print('Warning: mesh does not exist: %s' % mesh_file) 124 | 125 | # Evaluate point cloud 126 | if cfg['test']['eval_pointcloud']: 127 | pointcloud_file = os.path.join( 128 | pointcloud_dir, '%s.ply' % modelname) 129 | 130 | if os.path.exists(pointcloud_file): 131 | pointcloud = load_pointcloud(pointcloud_file).astype(np.float32) 132 | eval_dict_pcl = evaluator.eval_pointcloud( 133 | pointcloud, pointcloud_tgt) 134 | for k, v in eval_dict_pcl.items(): 135 | eval_dict[k + ' (pcl)'] = v 136 | else: 137 | print('Warning: pointcloud does not exist: %s' 138 | % pointcloud_file) 139 | 140 | 141 | # Create pandas dataframe and save 142 | eval_df = pd.DataFrame(eval_dicts) 143 | eval_df.set_index(['idx'], 
inplace=True) 144 | eval_df.to_pickle(out_file) 145 | 146 | # Create CSV file with main statistics 147 | eval_df_class = eval_df.groupby(by=['class name']).mean() 148 | eval_df_class.loc['mean'] = eval_df_class.mean() 149 | eval_df_class.to_csv(out_file_class) 150 | 151 | # Print results 152 | print(eval_df_class) 153 | 154 | if __name__ == '__main__': 155 | main() -------------------------------------------------------------------------------- /SAP/optim_hierarchy.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import argparse 3 | from src.utils import load_config 4 | import subprocess 5 | os.environ['MKL_THREADING_LAYER'] = 'GNU' 6 | 7 | def main(): 8 | 9 | parser = argparse.ArgumentParser(description='MNIST toy experiment') 10 | parser.add_argument('config', type=str, help='Path to config file.') 11 | parser.add_argument('--start_res', type=int, default=-1, help='Resolution to start with.') 12 | parser.add_argument('--object_id', type=int, default=-1, help='Object index.') 13 | 14 | args, unknown = parser.parse_known_args() 15 | cfg = load_config(args.config, 'configs/default.yaml') 16 | 17 | resolutions=[32, 64, 128, 256] 18 | iterations=[1000, 1000, 1000, 200] 19 | lrs=[2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)] # reduce lr 20 | for idx,(res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)): 21 | 22 | if rescfg['model']['grid_res']: 26 | continue 27 | 28 | psr_sigma= 2 if res<=128 else 3 29 | 30 | if res > 128: 31 | psr_sigma = 5 if 'thingi_noisy' in args.config else 3 32 | 33 | if args.object_id != -1: 34 | out_dir = os.path.join(cfg['train']['out_dir'], 'object_%02d'%args.object_id, 'res_%d'%res) 35 | else: 36 | out_dir = os.path.join(cfg['train']['out_dir'], 'res_%d'%res) 37 | 38 | # sample from mesh when resampling is enabled, otherwise reuse the pointcloud 39 | init_shape='mesh' if cfg['train']['resample_every']>0 else 'pointcloud' 40 | 41 | 42 | if args.object_id != -1: 43 | 
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'], 44 | 'object_%02d'%args.object_id, 'res_%d' % (resolutions[idx-1]), 45 | 'vis', init_shape, '%04d.ply' % (iterations[idx-1])) 46 | else: 47 | input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'], 48 | 'res_%d' % (resolutions[idx-1]), 49 | 'vis', init_shape, '%04d.ply' % (iterations[idx-1])) 50 | 51 | 52 | cmd = 'export MKL_SERVICE_FORCE_INTEL=1 && ' 53 | cmd += "python optim.py %s --model:grid_res %d --model:psr_sigma %d \ 54 | --train:input_mesh %s --train:total_epochs %d \ 55 | --train:out_dir %s --train:lr_pcl %f \ 56 | --data:object_id %d" % ( 57 | args.config, 58 | res, 59 | psr_sigma, 60 | input_mesh, 61 | iteration, 62 | out_dir, 63 | lr, 64 | args.object_id) 65 | print(cmd) 66 | os.system(cmd) 67 | 68 | if __name__=="__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /SAP/scripts/__pycache__/easy_mesh_vtk.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/scripts/__pycache__/easy_mesh_vtk.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/scripts/download_demo_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p data 3 | cd data 4 | echo "Start downloading ..." 5 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/demo.zip 6 | unzip demo.zip 7 | rm demo.zip 8 | echo "Done!" 
-------------------------------------------------------------------------------- /SAP/scripts/download_optim_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p data 3 | cd data 4 | echo "Start downloading data for optimization-based setting (~200 MB)" 5 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/optim_data.zip 6 | unzip optim_data.zip 7 | rm optim_data.zip 8 | echo "Done!" -------------------------------------------------------------------------------- /SAP/scripts/download_shapenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p data 3 | cd data 4 | echo "Start downloading preprocessed ShapeNet data (~220G)" 5 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/shapenet_psr.zip 6 | unzip shapenet_psr.zip 7 | rm shapenet_psr.zip 8 | echo "Done!" -------------------------------------------------------------------------------- /SAP/scripts/prcocess_tooth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import open3d as o3d 3 | import torch 4 | import time 5 | import multiprocessing 6 | import numpy as np 7 | from tqdm import tqdm 8 | from src.dpsr import DPSR 9 | #from easy_mesh_vtk import Easy_Mesh 10 | import pyvista as pv 11 | import os 12 | from pymeshfix._meshfix import PyTMesh 13 | from pymeshfix import MeshFix 14 | 15 | #data_path = 'C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\1a9e1fb2a51ffd065b07a27512172330' # path for ShapeNet from ONet 16 | data_path_mesh='C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\data-watertightshell' 17 | data_pred_points="C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\test-sap-watertight-5epochs\\test-pointr\\A0_6014-36\\A0_6014-36.npy" 18 | base = 'data' # output base directory 19 | dataset_name = 'shapenet_psr' 20 | multiprocess = True 21 | njobs = 8 22 | 
save_pointcloud = True 23 | save_psr_field = True 24 | resolution = 128 25 | zero_level = 0.0 26 | num_points = 100000 27 | padding = 1.2 28 | out_path_cur_obj='C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\test-sap-watertight-5epochs\\test-pointr\\A0_6014-36' 29 | 30 | 31 | dpsr = DPSR(res=(resolution, resolution, resolution), sig=0) 32 | """"" 33 | gt_path = os.path.join(data_path, 'pointcloud.npz') 34 | data = np.load(gt_path) 35 | points = data['points'] 36 | normals = data['normals'] 37 | 38 | # normalize the point to [0, 1) 39 | points = points / padding + 0.5 40 | # to scale back during inference, we should: 41 | # ! p = (p - 0.5) * padding 42 | pcd = o3d.geometry.PointCloud() 43 | pcd.points = o3d.utility.Vector3dVector(points) 44 | pcd.normals= o3d.utility.Vector3dVector(normals) 45 | pcd.paint_uniform_color([0, 0.45, 0]) 46 | 47 | o3d.visualization.draw_geometries([pcd], point_show_normal=True) 48 | 49 | if save_pointcloud: 50 | outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz') 51 | # np.savez(outdir, points=points, normals=normals) 52 | np.savez(outdir, points=data['points'], normals=data['normals']) 53 | # return 54 | 55 | if save_psr_field: 56 | psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None], 57 | torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16) 58 | 59 | outdir = os.path.join(out_path_cur_obj, 'psr.npz') 60 | np.savez(outdir, psr=psr_gt) 61 | 62 | """"" 63 | #read the meshgt 64 | #for entry in os.listdir(data_path): 65 | 66 | #Open_mesh = pv.read(os.path.join(data_path_mesh,'shell26_registered.ply')) 67 | #meshfix = MeshFix(Open_mesh) 68 | #holes = meshfix.extract_holes() 69 | #meshfix.repair(verbose=True) 70 | #meshfix.save(os.path.join(data_path_mesh,'shell26_registered_watertight.ply' )) 71 | 72 | """"" 73 | for cases in os.listdir(data_path_mesh): 74 | for i in os.listdir(os.path.join(data_path_mesh,cases)): 75 | if 'shell' in i: 76 | print('cases:',cases) 77 | mesh = 
o3d.io.read_triangle_mesh(os.path.join(data_path_mesh, cases, i)) 78 | mesh = mesh.subdivide_loop(number_of_iterations=3) 79 | # mesh = Easy_Mesh(os.path.join(data_path_mesh, '45shell_registered.ply')) 80 | points = mesh.vertices 81 | print('points', np.asarray(points)) 82 | Npoints = np.asarray(points) 83 | normals = mesh.compute_vertex_normals() 84 | Nnormals = np.asarray(normals.vertex_normals) 85 | print("normals", np.asarray(normals.vertex_normals)) 86 | # randomple sample 100000 87 | points_sample = 100000 88 | positive_mesh_idx = np.arange(len(Npoints)) 89 | try: 90 | positive_selected_mesh_idx = np.random.choice(positive_mesh_idx, size=points_sample, replace=False) 91 | except ValueError: 92 | positive_selected_mesh_idx = np.random.choice(positive_mesh_idx, size=points_sample, replace=True) 93 | # mesh_with_newpoints = np.zeros([points_sample, Npoints.shape[1]], dtype='float32') 94 | Npoints = Npoints[positive_selected_mesh_idx, :] 95 | print('shape', Npoints.shape) 96 | Nnormals = Nnormals[positive_selected_mesh_idx, :] 97 | 98 | # normalize the point to [0, 1) 99 | Npoints = (Npoints - np.min(Npoints)) / (np.max(Npoints) + 1 - np.min(Npoints)) 100 | pcd = o3d.geometry.PointCloud() 101 | pcd.points = o3d.utility.Vector3dVector(Npoints) 102 | pcd.normals= o3d.utility.Vector3dVector(Nnormals) 103 | pcd.paint_uniform_color([0, 0.45, 0]) 104 | 105 | o3d.visualization.draw_geometries([pcd], point_show_normal=True) 106 | # mean/std 107 | # Npoints_mean=np.mean(Npoints) 108 | # Npoints_std=np.std(Npoints) 109 | # Npoints=(Npoints-Npoints_mean) / Npoints_std 110 | # Npoints = np.asarray(Npoints) / padding + 0.5 111 | # to scale back during inference, we should: 112 | # ! 
p = (p - 0.5) * padding 113 | 114 | if save_pointcloud: 115 | outdir = os.path.join(out_path_cur_obj, cases, 'pointcloud.npz') 116 | # np.savez(outdir, points=points, normals=normals) 117 | np.savez(outdir, points=Npoints, normals=np.asarray(Nnormals)) 118 | # return 119 | 120 | if save_psr_field: 121 | psr_gt = dpsr(torch.from_numpy(Npoints.astype(np.float32))[None], 122 | torch.from_numpy(Nnormals[None].astype(np.float32))).squeeze().cpu().numpy().astype( 123 | np.float16) 124 | 125 | outdir = os.path.join(out_path_cur_obj, cases, 'psr.npz') 126 | np.savez(outdir, psr=psr_gt) 127 | 128 | """"" 129 | 130 | points=np.load(data_pred_points) 131 | # normalize the point to [0, 1) 132 | Npoints = (points - np.min(points)) / (np.max(points) + 1 - np.min(points)) 133 | pcd = o3d.geometry.PointCloud() 134 | pcd.points = o3d.utility.Vector3dVector(Npoints) 135 | pcd.estimate_normals() 136 | pcd_normals = np.asarray(pcd.normals) 137 | #pcd.normals = o3d.utility.Vector3dVector(Nnormals) 138 | pcd.paint_uniform_color([0, 0.45, 0]) 139 | 140 | #o3d.visualization.draw_geometries([pcd], point_show_normal=True) 141 | 142 | 143 | if save_pointcloud: 144 | outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz') 145 | # np.savez(outdir, points=points, normals=normals) 146 | np.savez(outdir, points=Npoints, normals=np.asarray(pcd_normals)) 147 | # return 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /SAP/scripts/process_shapenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | import multiprocessing 5 | import numpy as np 6 | from tqdm import tqdm 7 | from src.dpsr import DPSR 8 | 9 | data_path = 'data/ShapeNet' # path for ShapeNet from ONet 10 | base = 'data' # output base directory 11 | dataset_name = 'shapenet_psr' 12 | multiprocess = True 13 | njobs = 8 14 | save_pointcloud = True 15 | 
save_psr_field = True 16 | resolution = 128 17 | zero_level = 0.0 18 | num_points = 100000 19 | padding = 1.2 20 | 21 | dpsr = DPSR(res=(resolution, resolution, resolution), sig=0) 22 | 23 | def process_one(obj): 24 | 25 | obj_name = obj.split('/')[-1] 26 | c = obj.split('/')[-2] 27 | 28 | # create new for the current object 29 | out_path_cur = os.path.join(base, dataset_name, c) 30 | out_path_cur_obj = os.path.join(out_path_cur, obj_name) 31 | os.makedirs(out_path_cur_obj, exist_ok=True) 32 | 33 | gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz') 34 | data = np.load(gt_path) 35 | points = data['points'] 36 | normals = data['normals'] 37 | 38 | # normalize the point to [0, 1) 39 | points = points / padding + 0.5 40 | # to scale back during inference, we should: 41 | #! p = (p - 0.5) * padding 42 | 43 | if save_pointcloud: 44 | outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz') 45 | # np.savez(outdir, points=points, normals=normals) 46 | np.savez(outdir, points=data['points'], normals=data['normals']) 47 | # return 48 | 49 | if save_psr_field: 50 | psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None], 51 | torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16) 52 | 53 | outdir = os.path.join(out_path_cur_obj, 'psr.npz') 54 | np.savez(outdir, psr=psr_gt) 55 | 56 | 57 | def main(c): 58 | 59 | print('---------------------------------------') 60 | print('Processing {} {}'.format(c, split)) 61 | print('---------------------------------------') 62 | 63 | for split in ['train', 'val', 'test']: 64 | fname = os.path.join(data_path, c, split+'.lst') 65 | with open(fname, 'r') as f: 66 | obj_list = f.read().splitlines() 67 | 68 | obj_list = [c+'/'+s for s in obj_list] 69 | 70 | if multiprocess: 71 | # multiprocessing.set_start_method('spawn', force=True) 72 | pool = multiprocessing.Pool(njobs) 73 | try: 74 | for _ in tqdm(pool.imap_unordered(process_one, obj_list), total=len(obj_list)): 75 | pass 76 | # 
pool.map_async(process_one, obj_list).get() 77 | except KeyboardInterrupt: 78 | # Allow ^C to interrupt from any thread. 79 | exit() 80 | pool.close() 81 | else: 82 | for obj in tqdm(obj_list): 83 | process_one(obj) 84 | 85 | print('Done Processing {} {}!'.format(c, split)) 86 | 87 | 88 | if __name__ == "__main__": 89 | 90 | classes = ['02691156', '02828884', '02933112', 91 | '02958343', '03211117', '03001627', 92 | '03636649', '03691459', '04090263', 93 | '04256520', '04379243', '04401088', '04530566'] 94 | 95 | 96 | t_start = time.time() 97 | for c in classes: 98 | main(c) 99 | 100 | t_end = time.time() 101 | print('Total processing time: ', t_end - t_start) 102 | -------------------------------------------------------------------------------- /SAP/scripts/process_tooth.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import numpy as np 3 | from tqdm import tqdm 4 | from src.dpsr import DPSR 5 | #from easy_mesh_vtk import Easy_Mesh 6 | import pyvista as pv 7 | import os 8 | from pymeshfix._meshfix import PyTMesh 9 | from pymeshfix import MeshFix 10 | import open3d as o3d 11 | import torch 12 | 13 | data_path = 'C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\datasap' # path for ShapeNet from ONet 14 | data_path_mesh='C:/Users/Golriz/OneDrive - polymtl.ca/Desktop/mesh' 15 | base = 'data' # output base directory 16 | dataset_name = 'shapenet_psr' 17 | multiprocess = True 18 | njobs = 8 19 | save_pointcloud = True 20 | save_psr_field = True 21 | resolution = 128 22 | zero_level = 0.0 23 | num_points = 100000 24 | padding = 1.2 25 | out_path_cur_obj='C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\datasap' 26 | dpsr = DPSR(res=(resolution, resolution, resolution), sig=0) 27 | 28 | 29 | 30 | 31 | 32 | mesh=o3d.io.read_triangle_mesh(os.path.join(data_path_mesh, 'shell26_registered_watertight.ply')) 33 | mesh=mesh.subdivide_loop(number_of_iterations=3) 34 | #mesh = 
Easy_Mesh(os.path.join(data_path_mesh, '45shell_registered.ply')) 35 | points = mesh.vertices 36 | print('points',np.asarray(points)) 37 | Npoints = np.asarray(points) 38 | normals = mesh.compute_vertex_normals() 39 | Nnormals = np.asarray(normals.vertex_normals) 40 | print("normals",np.asarray(normals.vertex_normals)) 41 | #randomple sample 100000 42 | points_sample = 100000 43 | positive_mesh_idx = np.arange(len(Npoints)) 44 | try: 45 | positive_selected_mesh_idx = np.random.choice(positive_mesh_idx, size=points_sample, replace=False) 46 | except ValueError: 47 | positive_selected_mesh_idx = np.random.choice(positive_mesh_idx, size=points_sample, replace=True) 48 | #mesh_with_newpoints = np.zeros([points_sample, Npoints.shape[1]], dtype='float32') 49 | Npoints = Npoints[positive_selected_mesh_idx, :] 50 | print('shape',Npoints.shape) 51 | Nnormals= Nnormals[positive_selected_mesh_idx, :] 52 | 53 | 54 | # normalize the point to [0, 1) 55 | Npoints=(Npoints-np.min(Npoints))/(np.max(Npoints)+1-np.min(Npoints)) 56 | 57 | #mean/std 58 | #Npoints_mean=np.mean(Npoints) 59 | #Npoints_std=np.std(Npoints) 60 | #Npoints=(Npoints-Npoints_mean) / Npoints_std 61 | #Npoints = np.asarray(Npoints) / padding + 0.5 62 | # to scale back during inference, we should: 63 | # ! 
p = (p - 0.5) * padding 64 | 65 | 66 | if save_pointcloud: 67 | outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz') 68 | # np.savez(outdir, points=points, normals=normals) 69 | np.savez(outdir, points=Npoints, normals=np.asarray(Nnormals)) 70 | # return 71 | 72 | if save_psr_field: 73 | psr_gt = dpsr(torch.from_numpy(Npoints.astype(np.float32))[None], 74 | torch.from_numpy(Nnormals[None].astype(np.float32))).squeeze().cpu().numpy().astype(np.float16) 75 | 76 | outdir = os.path.join(out_path_cur_obj, 'psr.npz') 77 | np.savez(outdir, psr=psr_gt) -------------------------------------------------------------------------------- /SAP/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__init__.py -------------------------------------------------------------------------------- /SAP/src/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/__pycache__/dpsr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/dpsr.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/__pycache__/generation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/generation.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/__pycache__/model.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/__pycache__/training.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/training.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from torchvision import transforms 3 | from ..src import data, generation 4 | from ..src.dpsr import DPSR 5 | from ipdb import set_trace as st 6 | 7 | 8 | # Generator for final mesh extraction 9 | def get_generator(model, cfg, device, **kwargs): 10 | ''' Returns the generator object. 
11 | 12 | Args: 13 | model (nn.Module): Occupancy Network model 14 | cfg (dict): imported yaml config 15 | device (device): pytorch device 16 | ''' 17 | 18 | if cfg['generation']['psr_resolution'] == 0: 19 | psr_res = cfg['model']['grid_res'] 20 | psr_sigma = cfg['model']['psr_sigma'] 21 | else: 22 | psr_res = cfg['generation']['psr_resolution'] 23 | psr_sigma = cfg['generation']['psr_sigma'] 24 | 25 | dpsr = DPSR(res=(psr_res, psr_res, psr_res), 26 | sig= psr_sigma).to(device) 27 | 28 | 29 | generator = generation.Generator3D( 30 | model, 31 | device=device, 32 | threshold=cfg['data']['zero_level'], 33 | sample=cfg['generation']['use_sampling'], 34 | input_type = cfg['data']['input_type'], 35 | padding=cfg['data']['padding'], 36 | dpsr=dpsr, 37 | psr_tanh=cfg['model']['psr_tanh'] 38 | ) 39 | return generator 40 | 41 | # Datasets 42 | def get_dataset(mode, cfg, return_idx=False): 43 | ''' Returns the dataset. 44 | 45 | Args: 46 | model (nn.Module): the model which is used 47 | cfg (dict): config dictionary 48 | return_idx (bool): whether to include an ID field 49 | ''' 50 | dataset_type = cfg['data']['dataset'] 51 | dataset_folder = cfg['data']['path'] 52 | categories = cfg['data']['class'] 53 | 54 | # Get split 55 | splits = { 56 | 'train': cfg['data']['train_split'], 57 | 'val': cfg['data']['val_split'], 58 | 'test': cfg['data']['test_split'], 59 | 'vis': cfg['data']['val_split'], 60 | } 61 | 62 | split = splits[mode] 63 | 64 | # Create dataset 65 | if dataset_type == 'Shapes3D': 66 | fields = get_data_fields(mode, cfg) 67 | # Input fields 68 | inputs_field = get_inputs_field(mode, cfg) 69 | if inputs_field is not None: 70 | fields['inputs'] = inputs_field 71 | 72 | if return_idx: 73 | fields['idx'] = data.IndexField() 74 | 75 | dataset = data.Shapes3dDataset( 76 | dataset_folder, fields, 77 | split=split, 78 | categories=categories, 79 | cfg = cfg 80 | ) 81 | else: 82 | raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset']) 83 | 84 | return dataset 85 
| 86 | 87 | def get_inputs_field(mode, cfg): 88 | ''' Returns the inputs fields. 89 | 90 | Args: 91 | mode (str): the mode which is used 92 | cfg (dict): config dictionary 93 | ''' 94 | input_type = cfg['data']['input_type'] 95 | 96 | if input_type is None: 97 | inputs_field = None 98 | elif input_type == 'pointcloud': 99 | noise_level = cfg['data']['pointcloud_noise'] 100 | if cfg['data']['pointcloud_outlier_ratio']>0: 101 | transform = transforms.Compose([ 102 | data.SubsamplePointcloud(cfg['data']['pointcloud_n']), 103 | data.PointcloudNoise(noise_level), 104 | data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio']) 105 | ]) 106 | else: 107 | transform = transforms.Compose([ 108 | data.SubsamplePointcloud(cfg['data']['pointcloud_n']), 109 | data.PointcloudNoise(noise_level) 110 | ]) 111 | 112 | data_type = cfg['data']['data_type'] 113 | inputs_field = data.PointCloudField( 114 | cfg['data']['pointcloud_file'], data_type, transform, 115 | multi_files= cfg['data']['multi_files'] 116 | ) 117 | else: 118 | raise ValueError( 119 | 'Invalid input type (%s)' % input_type) 120 | return inputs_field 121 | 122 | def get_data_fields(mode, cfg): 123 | ''' Returns the data fields. 
124 | 125 | Args: 126 | mode (str): the mode which is used 127 | cfg (dict): imported yaml config 128 | ''' 129 | data_type = cfg['data']['data_type'] 130 | fields = {} 131 | 132 | if (mode in ('val', 'test')): 133 | transform = data.SubsamplePointcloud(100000) 134 | else: 135 | transform = data.SubsamplePointcloud(cfg['data']['num_gt_points']) 136 | 137 | data_name = cfg['data']['pointcloud_file'] 138 | fields['gt_points'] = data.PointCloudField(data_name, 139 | transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files']) 140 | if data_type == 'psr_full': 141 | if mode != 'test': 142 | fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files']) 143 | else: 144 | raise ValueError('Invalid data type (%s)' % data_type) 145 | 146 | return fields -------------------------------------------------------------------------------- /SAP/src/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ..data.core import ( 3 | Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together 4 | ) 5 | from ..data.fields import ( 6 | IndexField, PointCloudField, FullPSRField 7 | ) 8 | from ..data.transforms import ( 9 | PointcloudNoise, SubsamplePointcloud, 10 | PointcloudOutliers, 11 | ) 12 | __all__ = [ 13 | # Core 14 | Shapes3dDataset, 15 | collate_remove_none, 16 | worker_init_fn, 17 | # Fields 18 | IndexField, 19 | PointCloudField, 20 | FullPSRField, 21 | # Transforms 22 | PointcloudNoise, 23 | SubsamplePointcloud, 24 | PointcloudOutliers, 25 | ] 26 | -------------------------------------------------------------------------------- /SAP/src/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/data/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /SAP/src/data/__pycache__/core.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/data/__pycache__/core.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/data/__pycache__/fields.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/data/__pycache__/fields.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/data/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/data/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/data/fields.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import random 5 | from PIL import Image 6 | import numpy as np 7 | import trimesh 8 | from ..data.core import Field 9 | from pdb import set_trace as st 10 | 11 | 12 | class IndexField(Field): 13 | ''' Basic index field.''' 14 | def load(self, model_path, idx, category): 15 | ''' Loads the index field. 16 | 17 | Args: 18 | model_path (str): path to model 19 | idx (int): ID of data point 20 | category (int): index of category 21 | ''' 22 | return idx 23 | 24 | def check_complete(self, files): 25 | ''' Check if field is complete. 
26 | 27 | Args: 28 | files: files 29 | ''' 30 | return True 31 | 32 | class FullPSRField(Field): 33 | def __init__(self, transform=None, multi_files=None): 34 | self.transform = transform 35 | # self.unpackbits = unpackbits 36 | self.multi_files = multi_files 37 | 38 | def load(self, model_path, idx, category): 39 | 40 | # try: 41 | # t0 = time.time() 42 | if self.multi_files is not None: 43 | psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx)) 44 | else: 45 | psr_path = os.path.join(model_path, 'psr.npz') 46 | psr_dict = np.load(psr_path) 47 | # t1 = time.time() 48 | psr = psr_dict['psr'] 49 | psr = psr.astype(np.float32) 50 | # t2 = time.time() 51 | # print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0)) 52 | data = {None: psr} 53 | 54 | if self.transform is not None: 55 | data = self.transform(data) 56 | 57 | return data 58 | 59 | class PointCloudField(Field): 60 | ''' Point cloud field. 61 | 62 | It provides the field used for point cloud data. These are the points 63 | randomly sampled on the mesh. 64 | 65 | Args: 66 | file_name (str): file name 67 | transform (list): list of transformations applied to data points 68 | multi_files (callable): number of files 69 | ''' 70 | def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2): 71 | self.file_name = file_name 72 | self.data_type = data_type # to make sure the range of input is correct 73 | self.transform = transform 74 | self.multi_files = multi_files 75 | self.padding = padding 76 | self.scale = scale 77 | 78 | def load(self, model_path, idx, category): 79 | ''' Loads the data point. 
80 | 81 | Args: 82 | model_path (str): path to model 83 | idx (int): ID of data point 84 | category (int): index of category 85 | ''' 86 | if self.multi_files is None: 87 | file_path = os.path.join(model_path, self.file_name) 88 | else: 89 | # num = np.random.randint(self.multi_files) 90 | # file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num)) 91 | file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx)) 92 | 93 | pointcloud_dict = np.load(file_path) 94 | 95 | points = pointcloud_dict['points'].astype(np.float32) 96 | normals = pointcloud_dict['normals'].astype(np.float32) 97 | 98 | data = { 99 | None: points, 100 | 'normals': normals, 101 | } 102 | if self.transform is not None: 103 | data = self.transform(data) 104 | 105 | if self.data_type == 'psr_full': 106 | print("psr_full") 107 | #scale the point cloud to the range of (0, 1) 108 | #data[None] = data[None] / self.scale + 0.5 # we already scale data 109 | 110 | return data 111 | 112 | def check_complete(self, files): 113 | ''' Check if field is complete. 114 | 115 | Args: 116 | files: files 117 | ''' 118 | complete = (self.file_name in files) 119 | return complete 120 | -------------------------------------------------------------------------------- /SAP/src/data/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # Transforms 5 | class PointcloudNoise(object): 6 | ''' Point cloud noise transformation class. 7 | 8 | It adds noise to point cloud data. 9 | 10 | Args: 11 | stddev (int): standard deviation 12 | ''' 13 | 14 | def __init__(self, stddev): 15 | self.stddev = stddev 16 | 17 | def __call__(self, data): 18 | ''' Calls the transformation. 
19 | 20 | Args: 21 | data (dictionary): data dictionary 22 | ''' 23 | data_out = data.copy() 24 | points = data[None] 25 | noise = self.stddev * np.random.randn(*points.shape) 26 | noise = noise.astype(np.float32) 27 | data_out[None] = points + noise 28 | return data_out 29 | 30 | class PointcloudOutliers(object): 31 | ''' Point cloud outlier transformation class. 32 | 33 | It adds outliers to point cloud data. 34 | 35 | Args: 36 | ratio (int): outlier percentage to the entire point cloud 37 | ''' 38 | 39 | def __init__(self, ratio): 40 | self.ratio = ratio 41 | 42 | def __call__(self, data): 43 | ''' Calls the transformation. 44 | 45 | Args: 46 | data (dictionary): data dictionary 47 | ''' 48 | data_out = data.copy() 49 | points = data[None] 50 | n_points = points.shape[0] 51 | n_outlier_points = int(n_points*self.ratio) 52 | ind = np.random.randint(0, n_points, n_outlier_points) 53 | 54 | outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3)) 55 | outliers = outliers.astype(np.float32) 56 | points[ind] = outliers 57 | data_out[None] = points 58 | return data_out 59 | 60 | class SubsamplePointcloud(object): 61 | ''' Point cloud subsampling transformation class. 62 | 63 | It subsamples the point cloud data. 64 | 65 | Args: 66 | N (int): number of points to be subsampled 67 | ''' 68 | def __init__(self, N): 69 | self.N = N 70 | 71 | def __call__(self, data): 72 | ''' Calls the transformation. 
73 | 74 | Args: 75 | data (dict): data dictionary 76 | ''' 77 | data_out = data.copy() 78 | points = data[None] 79 | 80 | indices = np.random.randint(points.shape[0], size=self.N) 81 | data_out[None] = points[indices, :] 82 | if 'normals' in data.keys(): 83 | normals = data['normals'] 84 | data_out['normals'] = normals[indices, :] 85 | 86 | return data_out -------------------------------------------------------------------------------- /SAP/src/dpsr.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from ..src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize 5 | import numpy as np 6 | import torch.fft 7 | 8 | class DPSR(nn.Module): 9 | def __init__(self, res, sig=10, scale=True, shift=True): 10 | """ 11 | :param res: tuple of output field resolution. eg., (128,128) 12 | :param sig: degree of gaussian smoothing 13 | """ 14 | super(DPSR, self).__init__() 15 | self.res = res 16 | self.sig = sig 17 | self.dim = len(res) 18 | self.denom = np.prod(res) 19 | G = spec_gaussian_filter(res=res, sig=sig).float() 20 | # self.G.requires_grad = False # True, if we also make sig a learnable parameter 21 | self.omega = fftfreqs(res, dtype=torch.float32) 22 | self.scale = scale 23 | self.shift = shift 24 | self.register_buffer("G", G) 25 | 26 | def forward(self, V, N): 27 | """ 28 | :param V: (batch, nv, 2 or 3) tensor for point cloud coordinates 29 | :param N: (batch, nv, 2 or 3) tensor for point normals 30 | :return phi: (batch, res, res, ...) 
tensor of output indicator function field 31 | """ 32 | assert(V.shape == N.shape) # [b, nv, ndims] 33 | ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2] 34 | 35 | ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4)) 36 | ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1])) 37 | N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1] 38 | 39 | omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1] 40 | omega *= 2 * np.pi # normalize frequencies 41 | omega = omega.to(V.device) 42 | 43 | DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2) 44 | 45 | Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1] 46 | Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2] 47 | Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b] 48 | Phi[tuple([0] * self.dim)] = 0 49 | Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2] 50 | 51 | phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3)) 52 | 53 | if self.shift or self.scale: 54 | # ensure values at points are zero 55 | fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv] 56 | if self.shift: # offset points to have mean of 0 57 | offset = torch.mean(fv, dim=-1) # [b,] 58 | phi -= offset.view(*tuple([-1] + [1] * self.dim)) 59 | 60 | phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]])) 61 | fv0 = phi[tuple([0] * self.dim)] # [b,] 62 | phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))])) 63 | 64 | if self.scale: 65 | phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5 66 | return phi -------------------------------------------------------------------------------- /SAP/src/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import trimesh 4 | from pykdtree.kdtree import KDTree 5 
| 6 | EMPTY_PCL_DICT = { 7 | 'completeness': np.sqrt(3), 8 | 'accuracy': np.sqrt(3), 9 | 'completeness2': 3, 10 | 'accuracy2': 3, 11 | 'chamfer': 6, 12 | } 13 | 14 | EMPTY_PCL_DICT_NORMALS = { 15 | 'normals completeness': -1., 16 | 'normals accuracy': -1., 17 | 'normals': -1., 18 | } 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class MeshEvaluator(object): 24 | ''' Mesh evaluation class. 25 | It handles the mesh evaluation process. 26 | Args: 27 | n_points (int): number of points to be used for evaluation 28 | ''' 29 | 30 | def __init__(self, n_points=100000): 31 | self.n_points = n_points 32 | 33 | def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)): 34 | ''' Evaluates a mesh. 35 | Args: 36 | mesh (trimesh): mesh which should be evaluated 37 | pointcloud_tgt (numpy array): target point cloud 38 | normals_tgt (numpy array): target normals 39 | thresholds (numpy arry): for F-Score 40 | ''' 41 | if len(mesh.vertices) != 0 and len(mesh.faces) != 0: 42 | pointcloud, idx = mesh.sample(self.n_points, return_index=True) 43 | 44 | pointcloud = pointcloud.astype(np.float32) 45 | normals = mesh.face_normals[idx] 46 | else: 47 | pointcloud = np.empty((0, 3)) 48 | normals = np.empty((0, 3)) 49 | 50 | out_dict = self.eval_pointcloud( 51 | pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds) 52 | 53 | return out_dict 54 | 55 | def eval_pointcloud(self, pointcloud, pointcloud_tgt, 56 | normals=None, normals_tgt=None, 57 | thresholds=np.linspace(1./1000, 1, 1000)): 58 | ''' Evaluates a point cloud. 
59 | Args: 60 | pointcloud (numpy array): predicted point cloud 61 | pointcloud_tgt (numpy array): target point cloud 62 | normals (numpy array): predicted normals 63 | normals_tgt (numpy array): target normals 64 | thresholds (numpy array): threshold values for the F-score calculation 65 | ''' 66 | # Return maximum losses if pointcloud is empty 67 | if pointcloud.shape[0] == 0: 68 | logger.warn('Empty pointcloud / mesh detected!') 69 | out_dict = EMPTY_PCL_DICT.copy() 70 | if normals is not None and normals_tgt is not None: 71 | out_dict.update(EMPTY_PCL_DICT_NORMALS) 72 | return out_dict 73 | 74 | pointcloud = np.asarray(pointcloud) 75 | pointcloud_tgt = np.asarray(pointcloud_tgt) 76 | 77 | 78 | # Completeness: how far are the points of the target point cloud 79 | # from thre predicted point cloud 80 | completeness, completeness_normals = distance_p2p( 81 | pointcloud_tgt, normals_tgt, pointcloud, normals 82 | ) 83 | recall = get_threshold_percentage(completeness, thresholds) 84 | completeness2 = completeness**2 85 | 86 | completeness = completeness.mean() 87 | completeness2 = completeness2.mean() 88 | completeness_normals = completeness_normals.mean() 89 | 90 | # Accuracy: how far are th points of the predicted pointcloud 91 | # from the target pointcloud 92 | accuracy, accuracy_normals = distance_p2p( 93 | pointcloud, normals, pointcloud_tgt, normals_tgt 94 | ) 95 | precision = get_threshold_percentage(accuracy, thresholds) 96 | accuracy2 = accuracy**2 97 | 98 | accuracy = accuracy.mean() 99 | accuracy2 = accuracy2.mean() 100 | accuracy_normals = accuracy_normals.mean() 101 | 102 | # Chamfer distance 103 | chamferL2 = 0.5 * (completeness2 + accuracy2) 104 | normals_correctness = ( 105 | 0.5 * completeness_normals + 0.5 * accuracy_normals 106 | ) 107 | chamferL1 = 0.5 * (completeness + accuracy) 108 | 109 | # F-Score 110 | F = [ 111 | 2 * precision[i] * recall[i] / (precision[i] + recall[i]) 112 | for i in range(len(precision)) 113 | ] 114 | 115 | out_dict = { 
116 | 'completeness': completeness, 117 | 'accuracy': accuracy, 118 | 'normals completeness': completeness_normals, 119 | 'normals accuracy': accuracy_normals, 120 | 'normals': normals_correctness, 121 | 'completeness2': completeness2, 122 | 'accuracy2': accuracy2, 123 | 'chamfer-L2': chamferL2, 124 | 'chamfer-L1': chamferL1, 125 | 'f-score': F[9], # threshold = 1.0% 126 | 'f-score-15': F[14], # threshold = 1.5% 127 | 'f-score-20': F[19], # threshold = 2.0% 128 | } 129 | 130 | return out_dict 131 | 132 | 133 | def distance_p2p(points_src, normals_src, points_tgt, normals_tgt): 134 | ''' Computes minimal distances of each point in points_src to points_tgt. 135 | Args: 136 | points_src (numpy array): source points 137 | normals_src (numpy array): source normals 138 | points_tgt (numpy array): target points 139 | normals_tgt (numpy array): target normals 140 | ''' 141 | kdtree = KDTree(points_tgt) 142 | dist, idx = kdtree.query(points_src) 143 | 144 | if normals_src is not None and normals_tgt is not None: 145 | normals_src = \ 146 | normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) 147 | normals_tgt = \ 148 | normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True) 149 | 150 | normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1) 151 | # Handle normals that point into wrong direction gracefully 152 | # (mostly due to mehtod not caring about this in generation) 153 | normals_dot_product = np.abs(normals_dot_product) 154 | else: 155 | normals_dot_product = np.array( 156 | [np.nan] * points_src.shape[0], dtype=np.float32) 157 | return dist, normals_dot_product 158 | 159 | def get_threshold_percentage(dist, thresholds): 160 | ''' Evaluates a point cloud. 
161 | Args: 162 | dist (numpy array): calculated distance 163 | thresholds (numpy array): threshold values for the F-score calculation 164 | ''' 165 | in_threshold = [ 166 | (dist <= t).mean() for t in thresholds 167 | ] 168 | return in_threshold -------------------------------------------------------------------------------- /SAP/src/generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import trimesh 4 | import numpy as np 5 | from ..src.utils import mc_from_psr 6 | 7 | class Generator3D(object): 8 | ''' Generator class for Occupancy Networks. 9 | 10 | It provides functions to generate the final mesh as well refining options. 11 | 12 | Args: 13 | model (nn.Module): trained Occupancy Network model 14 | points_batch_size (int): batch size for points evaluation 15 | threshold (float): threshold value 16 | device (device): pytorch device 17 | padding (float): how much padding should be used for MISE 18 | sample (bool): whether z should be sampled 19 | input_type (str): type of input 20 | ''' 21 | 22 | def __init__(self, model, points_batch_size=100000, 23 | threshold=0.5, device=None, padding=0.1, 24 | sample=False, input_type = None, dpsr=None, psr_tanh=True): 25 | self.model = model.to(device) 26 | self.points_batch_size = points_batch_size 27 | self.threshold = threshold 28 | self.device = device 29 | self.input_type = input_type 30 | self.padding = padding 31 | self.sample = sample 32 | self.dpsr = dpsr 33 | self.psr_tanh = psr_tanh 34 | 35 | def generate_mesh(self, data, return_stats=True): 36 | ''' Generates the output mesh. 
37 | 38 | Args: 39 | data (tensor): data tensor 40 | return_stats (bool): whether stats should be returned 41 | ''' 42 | self.model.eval() 43 | device = self.device 44 | stats_dict = {} 45 | 46 | p = data.get('inputs', torch.empty(1, 0)).to(device) 47 | 48 | t0 = time.time() 49 | points, normals = self.model(p) 50 | t1 = time.time() 51 | psr_grid = self.dpsr(points, normals) 52 | t2 = time.time() 53 | v, f, _ = mc_from_psr(psr_grid, 54 | zero_level=self.threshold) 55 | stats_dict['pcl'] = t1 - t0 56 | stats_dict['dpsr'] = t2 - t1 57 | stats_dict['mc'] = time.time() - t2 58 | stats_dict['total'] = time.time() - t0 59 | 60 | if return_stats: 61 | return v, f, points, normals, stats_dict 62 | else: 63 | return v, f, points, normals -------------------------------------------------------------------------------- /SAP/src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import numpy as np 4 | import time 5 | from ..src.utils import point_rasterize, grid_interp, mc_from_psr, \ 6 | calc_inters_points 7 | from ..src.dpsr import DPSR 8 | import torch.nn as nn 9 | from ..src.network import encoder_dict, decoder_dict 10 | from ..src.network.utils import map2local 11 | from .utils import load_config 12 | 13 | class PSR2Mesh(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, psr_grid): 16 | """ 17 | In the forward pass we receive a Tensor containing the input and return 18 | a Tensor containing the output. ctx is a context object that can be used 19 | to stash information for backward computation. You can cache arbitrary 20 | objects for use in the backward pass using the ctx.save_for_backward method. 
21 | """ 22 | verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True) 23 | verts = verts.unsqueeze(0) 24 | faces = faces.unsqueeze(0) 25 | normals = normals.unsqueeze(0) 26 | 27 | res = torch.tensor(psr_grid.detach().shape[2]) 28 | ctx.save_for_backward(verts, normals, res) 29 | 30 | return verts, faces, normals 31 | 32 | @staticmethod 33 | def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals): 34 | """ 35 | In the backward pass we receive a Tensor containing the gradient of the loss 36 | with respect to the output, and we need to compute the gradient of the loss 37 | with respect to the input. 38 | """ 39 | vert_pts, normals, res = ctx.saved_tensors 40 | res = (res.item(), res.item(), res.item()) 41 | # matrix multiplication between dL/dV and dV/dPSR 42 | # dV/dPSR = - normals 43 | grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0)) 44 | grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res 45 | 46 | return grad_grid 47 | 48 | class PSR2SurfacePoints(torch.autograd.Function): 49 | @staticmethod 50 | def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample): 51 | verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True) 52 | verts = verts * 2. - 1. 
# within the range of [-1, 1] 53 | 54 | 55 | p_all, n_all, mask_all = [], [], [] 56 | 57 | for i in range(len(poses)): 58 | pose = poses[i] 59 | if mask_sample is not None: 60 | p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i]) 61 | else: 62 | p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size) 63 | 64 | n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze() 65 | p_all.append(p_inters) 66 | n_all.append(n_inters) 67 | mask_all.append(mask) 68 | p_inters_all = torch.cat(p_all, dim=0) 69 | n_inters_all = torch.cat(n_all, dim=0) 70 | mask_visible = torch.stack(mask_all, dim=0) 71 | 72 | 73 | res = torch.tensor(psr_grid.detach().shape[2]) 74 | ctx.save_for_backward(p_inters_all, n_inters_all, res) 75 | 76 | return p_inters_all, mask_visible 77 | 78 | @staticmethod 79 | def backward(ctx, dL_dp, dL_dmask): 80 | pts, pts_n, res = ctx.saved_tensors 81 | res = (res.item(), res.item(), res.item()) 82 | 83 | # grad from the p_inters via MLP renderer 84 | grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None]) 85 | grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res 86 | 87 | return grad_grid_pts, None, None, None, None, None 88 | 89 | class Encode2Points(nn.Module): 90 | def __init__(self, cfg): 91 | super().__init__() 92 | 93 | cfg = load_config(cfg) 94 | self.cfg = cfg 95 | 96 | encoder = cfg['model']['encoder'] 97 | decoder = cfg['model']['decoder'] 98 | dim = cfg['data']['dim'] # input dim 99 | c_dim = cfg['model']['c_dim'] 100 | encoder_kwargs = cfg['model']['encoder_kwargs'] 101 | if encoder_kwargs == None: 102 | encoder_kwargs = {} 103 | decoder_kwargs = cfg['model']['decoder_kwargs'] 104 | padding = cfg['data']['padding'] 105 | self.predict_normal = cfg['model']['predict_normal'] 106 | self.predict_offset = cfg['model']['predict_offset'] 107 | 108 | out_dim = 3 109 | out_dim_offset = 3 110 | num_offset = 
cfg['data']['num_offset'] 111 | # each point predict more than one offset to add output points 112 | if num_offset > 1: 113 | out_dim_offset = out_dim * num_offset 114 | self.num_offset = num_offset 115 | 116 | # local mapping 117 | self.map2local = None 118 | if cfg['model']['local_coord']: 119 | if 'unet' in encoder_kwargs.keys(): 120 | unit_size = 1 / encoder_kwargs['plane_resolution'] 121 | else: 122 | unit_size = 1 / encoder_kwargs['grid_resolution'] 123 | 124 | local_mapping = map2local(unit_size) 125 | 126 | self.encoder = encoder_dict[encoder]( 127 | dim=dim, c_dim=c_dim, map2local=local_mapping, 128 | **encoder_kwargs 129 | ) 130 | 131 | if self.predict_normal: 132 | # decoder for normal prediction 133 | self.decoder_normal = decoder_dict[decoder]( 134 | dim=dim, c_dim=c_dim, out_dim=out_dim, 135 | **decoder_kwargs) 136 | if self.predict_offset: 137 | # decoder for offset prediction 138 | self.decoder_offset = decoder_dict[decoder]( 139 | dim=dim, c_dim=c_dim, out_dim=out_dim_offset, 140 | map2local=local_mapping, 141 | **decoder_kwargs) 142 | 143 | self.s_off = cfg['model']['s_offset'] 144 | 145 | def forward(self, p): 146 | ''' Performs a forward pass through the network. 
147 | 148 | Args: 149 | p (tensor): input unoriented points 150 | ''' 151 | 152 | time_dict = {} 153 | mask = None 154 | 155 | batch_size = p.size(0) 156 | points = p.clone() 157 | 158 | # encode the input point cloud to a feature volume 159 | t0 = time.perf_counter() 160 | c = self.encoder(p) 161 | t1 = time.perf_counter() 162 | if self.predict_offset: 163 | offset = self.decoder_offset(p, c) 164 | # more than one offset is predicted per-point 165 | if self.num_offset > 1: 166 | points = points.repeat(1, 1, self.num_offset).reshape(batch_size, -1, 3) 167 | points = points + self.s_off * offset 168 | else: 169 | points = p 170 | 171 | if self.predict_normal: 172 | normals = self.decoder_normal(points, c) 173 | t2 = time.perf_counter() 174 | 175 | time_dict['encode'] = t1 - t0 176 | time_dict['predict'] = t2 - t1 177 | 178 | points = torch.clamp(points, 0.0, 0.99) 179 | if self.cfg['model']['normal_normalize']: 180 | normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8) 181 | 182 | 183 | return points, normals 184 | 185 | -------------------------------------------------------------------------------- /SAP/src/model_rgb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..src.network.net_rgb import RenderingNetwork 3 | from ..src.utils import approx_psr_grad 4 | from pytorch3d.renderer import ( 5 | RasterizationSettings, 6 | PerspectiveCameras, 7 | MeshRenderer, 8 | MeshRasterizer, 9 | SoftSilhouetteShader) 10 | from pytorch3d.structures import Meshes 11 | 12 | 13 | 14 | def approx_psr_grad(psr_grid, res, normalize=True): 15 | delta_x = delta_y = delta_z = 1/res 16 | psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze() 17 | 18 | grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x 19 | grad_y = (psr_pad[:, 2:, :] - psr_pad[:, :-2, :]) / 2 / delta_y 20 | grad_z = (psr_pad[:, :, 2:] - psr_pad[:, :, :-2]) / 2 / delta_z 21 | grad_x = grad_x[:, 1:-1, 1:-1] 22 | grad_y = grad_y[1:-1, :, 
1:-1] 23 | grad_z = grad_z[1:-1, 1:-1, :] 24 | 25 | psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=3) # [res_x, res_y, res_z, 3] 26 | if normalize: 27 | psr_grad = psr_grad / (psr_grad.norm(dim=3, keepdim=True) + 1e-12) 28 | 29 | return psr_grad 30 | 31 | 32 | class SAP2Image(nn.Module): 33 | def __init__(self, cfg, img_size): 34 | super().__init__() 35 | 36 | self.psr2sur = PSR2SurfacePoints.apply 37 | self.psr2mesh = PSR2Mesh.apply 38 | # initialize DPSR 39 | self.dpsr = DPSR(res=(cfg['model']['grid_res'], 40 | cfg['model']['grid_res'], 41 | cfg['model']['grid_res']), 42 | sig=cfg['model']['psr_sigma']) 43 | self.cfg = cfg 44 | if cfg['train']['l_weight']['rgb'] != 0.: 45 | self.rendering_network = RenderingNetwork(**cfg['model']['renderer']) 46 | 47 | if cfg['train']['l_weight']['mask'] != 0.: 48 | # initialize rasterizer 49 | sigma = 1e-4 50 | raster_settings_soft = RasterizationSettings( 51 | image_size=img_size, 52 | blur_radius=np.log(1. / 1e-4 - 1.)*sigma, 53 | faces_per_pixel=150, 54 | perspective_correct=False 55 | ) 56 | 57 | # initialize silhouette renderer 58 | self.mesh_rasterizer = MeshRenderer( 59 | rasterizer=MeshRasterizer( 60 | raster_settings=raster_settings_soft 61 | ), 62 | shader=SoftSilhouetteShader() 63 | ) 64 | 65 | self.cfg = cfg 66 | self.img_size = img_size 67 | 68 | def forward(self, inputs, data): 69 | points, normals = inputs[...,:3], inputs[...,3:] 70 | points = torch.sigmoid(points) 71 | normals = normals / normals.norm(dim=-1, keepdim=True) 72 | 73 | # DPSR to get grid 74 | psr_grid = self.dpsr(points, normals).unsqueeze(1) 75 | psr_grid = torch.tanh(psr_grid) 76 | 77 | return self.render_img(psr_grid, data) 78 | 79 | def render_img(self, psr_grid, data): 80 | 81 | n_views = len(data['masks']) 82 | n_views_per_iter = self.cfg['data']['n_views_per_iter'] 83 | 84 | rgb_render_mode = self.cfg['model']['renderer']['mode'] 85 | uv = data['uv'] 86 | 87 | idx = np.random.randint(0, n_views, n_views_per_iter) 88 | pose = 
[data['poses'][i] for i in idx] 89 | rgb = data['rgbs'][idx] 90 | mask_gt = data['masks'][idx] 91 | ray = None 92 | pred_rgb = None 93 | pred_mask = None 94 | 95 | if self.cfg['train']['l_weight']['rgb'] != 0.: 96 | psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res']) 97 | p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None) 98 | n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2) 99 | fea_interp = None 100 | if 'rays' in data.keys(): 101 | ray = data['rays'].squeeze()[idx][visible_mask] 102 | pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp) 103 | 104 | # silhouette loss 105 | if self.cfg['train']['l_weight']['mask'] != 0.: 106 | # build mesh 107 | v, f, _ = self.psr2mesh(psr_grid) 108 | v = v * 2. - 1 # within the range of [-1, 1] 109 | # ! Fast but more GPU usage 110 | mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()]) 111 | if True: 112 | #! PyTorch3D silhouette loss 113 | # build pose 114 | R = torch.cat([p.R for p in pose], dim=0) 115 | T = torch.cat([p.T for p in pose], dim=0) 116 | focal = torch.cat([p.focal_length for p in pose], dim=0) 117 | pp = torch.cat([p.principal_point for p in pose], dim=0) 118 | pose_cur = PerspectiveCameras( 119 | focal_length=focal, 120 | principal_point=pp, 121 | R=R, T=T, 122 | device=R.device) 123 | pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3] 124 | else: 125 | pred_mask = [] 126 | # ! Slow but less GPU usage 127 | for i in range(n_views_per_iter): 128 | #! 
PyTorch3D silhouette loss 129 | pred_mask.append(self.mesh_rasterizer(mesh, cameras=pose[i])[..., 3]) 130 | pred_mask = torch.cat(pred_mask, dim=0) 131 | 132 | output = { 133 | 'rgb': pred_rgb, 134 | 'rgb_gt': rgb, 135 | 'mask': pred_mask, 136 | 'mask_gt': mask_gt, 137 | 'vis_mask': visible_mask, 138 | } 139 | 140 | return output -------------------------------------------------------------------------------- /SAP/src/network/__init__.py: -------------------------------------------------------------------------------- 1 | from ..network import encoder, decoder 2 | 3 | encoder_dict = { 4 | 'local_pool_pointnet': encoder.LocalPoolPointnet, 5 | } 6 | decoder_dict = { 7 | 'simple_local': decoder.LocalDecoder, 8 | } -------------------------------------------------------------------------------- /SAP/src/network/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/network/__pycache__/decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/decoder.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/network/__pycache__/encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/encoder.cpython-38.pyc -------------------------------------------------------------------------------- /SAP/src/network/__pycache__/unet.cpython-38.pyc: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/__pycache__/unet3d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/unet3d.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/decoder.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F
  4 | import numpy as np
  5 | from ..network.utils import normalize_3d_coordinate, ResnetBlockFC, \
  6 |     normalize_coordinate
  7 |
  8 |
  9 | class LocalDecoder(nn.Module):
 10 |     ''' Decoder.
 11 |         Instead of conditioning on a global feature, conditions on plane/volume local features.
 12 |
 13 |     Args:
 14 |         dim (int): input dimension
 15 |         c_dim (int): dimension of latent conditioned code c (0 disables conditioning)
 16 |         out_dim (int): output dimension; if larger than 3 the output is reshaped to (batch, -1, 3)
 17 |         hidden_size (int): hidden size of Decoder network
 18 |         n_blocks (int): number of blocks ResNetBlockFC layers
 19 |         leaky (bool): whether to use leaky ReLUs
 20 |         sample_mode (str): sampling feature strategy, bilinear|nearest
 21 |         padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
 22 |         map2local (callable, optional): applied to the query points before the first linear layer
 23 |     '''
 24 |
 25 |     def __init__(self, dim=3, c_dim=128, out_dim=3,
 26 |                  hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None):
 27 |         super().__init__()
 28 |         self.c_dim = c_dim
 29 |         self.n_blocks = n_blocks
 30 |
 31 |         if c_dim != 0:
 32 |             self.fc_c = nn.ModuleList([
 33 |                 nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
 34 |             ])
 35 |
 36 |         self.fc_p = nn.Linear(dim, hidden_size)
 37 |
 38 |         self.blocks = nn.ModuleList([
 39 |             ResnetBlockFC(hidden_size) for i in range(n_blocks)
 40 |         ])
 41 |
 42 |         self.fc_out = nn.Linear(hidden_size, out_dim)
 43 |
 44 |         if not leaky:
 45 |             self.actvn = F.relu
 46 |         else:
 47 |             self.actvn = lambda x: F.leaky_relu(x, 0.2)
 48 |
 49 |         self.sample_mode = sample_mode
 50 |         self.padding = padding  # stored only; not read anywhere in this class
 51 |         self.map2local = map2local
 52 |         self.out_dim = out_dim
 53 |
 54 |     def sample_plane_feature(self, p, c, plane='xz'):
 55 |         ''' Sample the 2D feature map c at the projection of p onto `plane`. '''
 56 |         xy = normalize_coordinate(p.clone(), plane=plane)  # normalize to the range of (0, 1)
 57 |         xy = xy[:, :, None].float()
 58 |         vgrid = 2.0 * xy - 1.0  # normalize to (-1, 1)
 59 |         c = F.grid_sample(c, vgrid, padding_mode='border',
 60 |                           align_corners=True,
 61 |                           mode=self.sample_mode).squeeze(-1)
 62 |         return c
 63 |
 64 |     def sample_grid_feature(self, p, c):
 65 |         ''' Sample the 3D feature volume c at the points p. '''
 66 |         p_nor = normalize_3d_coordinate(p.clone())
 67 |         p_nor = p_nor[:, :, None, None].float()
 68 |         vgrid = 2.0 * p_nor - 1.0  # normalize to (-1, 1)
 69 |         # actually trilinear interpolation if mode = 'bilinear'
 70 |         c = F.grid_sample(c, vgrid, padding_mode='border',
 71 |                           align_corners=True,
 72 |                           mode=self.sample_mode).squeeze(-1).squeeze(-1)
 73 |         return c
 74 |
 75 |     def forward(self, p, c_plane, **kwargs):
 76 |         ''' Decode the query points p conditioned on the features in c_plane.
 77 |
 78 |         Args:
 79 |             p (tensor): query points; p.shape[0] is the batch size
 80 |             c_plane (dict): feature maps keyed by 'grid', 'xz', 'xy' and/or 'yz'
 81 |
 82 |         Raises:
 83 |             ValueError: if conditioning is enabled (c_dim != 0) but c_plane
 84 |                 contains none of those keys
 85 |         '''
 86 |         batch_size = p.shape[0]
 87 |
 88 |         c = None
 89 |         if self.c_dim != 0:
 90 |             # features sampled from every available plane / grid are summed
 91 |             c = 0
 92 |             if 'grid' in c_plane:
 93 |                 c += self.sample_grid_feature(p, c_plane['grid'])
 94 |             if 'xz' in c_plane:
 95 |                 c += self.sample_plane_feature(p, c_plane['xz'], plane='xz')
 96 |             if 'xy' in c_plane:
 97 |                 c += self.sample_plane_feature(p, c_plane['xy'], plane='xy')
 98 |             if 'yz' in c_plane:
 99 |                 c += self.sample_plane_feature(p, c_plane['yz'], plane='yz')
100 |             if not torch.is_tensor(c):
101 |                 raise ValueError(
102 |                     "c_plane must contain at least one of 'grid', 'xz', 'xy', 'yz'")
103 |             c = c.transpose(1, 2)
104 |
105 |         p = p.float()
106 |
107 |         if self.map2local:
108 |             p = self.map2local(p)
109 |
110 |         net = self.fc_p(p)
111 |
112 |         for i in range(self.n_blocks):
113 |             if self.c_dim != 0:
114 |                 net = net + self.fc_c[i](c)
115 |
116 |             net = self.blocks[i](net)
117 |
118 |         out = self.fc_out(self.actvn(net))
119 |
120 |         if self.out_dim > 3:
121 |             out = out.reshape(batch_size, -1, 3)
122 |
123 |         return out
--------------------------------------------------------------------------------
/SAP/src/network/utils.py:
--------------------------------------------------------------------------------
  1 | """ Network utilities.
  2 |
  3 | Positional encoding embedding (code was taken from https://github.com/bmild/nerf),
  4 | coordinate normalization / indexing helpers, map2local and a fully connected
  5 | ResNet block.
  6 | """
  7 |
  8 | import torch
  9 | import torch.nn as nn
 10 |
 11 |
 12 | class Embedder:
 13 |     ''' Positional encoding configured by the keyword arguments
 14 |     input_dims, include_input, max_freq_log2, num_freqs, log_sampling and
 15 |     periodic_fns: the output concatenates the input (if include_input) with
 16 |     every periodic function applied to the input scaled by every frequency band.
 17 |     '''
 18 |     def __init__(self, **kwargs):
 19 |         self.kwargs = kwargs
 20 |         self.create_embedding_fn()
 21 |
 22 |     def create_embedding_fn(self):
 23 |         ''' Build self.embed_fns and the resulting output dimension self.out_dim. '''
 24 |         embed_fns = []
 25 |         d = self.kwargs['input_dims']
 26 |         out_dim = 0
 27 |         if self.kwargs['include_input']:
 28 |             embed_fns.append(lambda x: x)
 29 |             out_dim += d
 30 |
 31 |         max_freq = self.kwargs['max_freq_log2']
 32 |         N_freqs = self.kwargs['num_freqs']
 33 |
 34 |         if self.kwargs['log_sampling']:
 35 |             freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
 36 |         else:
 37 |             freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
 38 |
 39 |         for freq in freq_bands:
 40 |             for p_fn in self.kwargs['periodic_fns']:
 41 |                 # p_fn / freq are bound as defaults to avoid late-binding closures
 42 |                 embed_fns.append(lambda x, p_fn=p_fn,
 43 |                                  freq=freq: p_fn(x * freq))
 44 |                 out_dim += d
 45 |
 46 |         self.embed_fns = embed_fns
 47 |         self.out_dim = out_dim
 48 |
 49 |     def embed(self, inputs):
 50 |         ''' Apply every embedding function and concatenate along the last dim. '''
 51 |         return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
 52 |
 53 |
 54 | def get_embedder(multires, d_in=3):
 55 |     ''' Return (embed_fn, out_dim) for a sin/cos encoding of d_in-dim inputs
 56 |     with `multires` log-spaced frequencies 2^0 ... 2^(multires-1).
 57 |     '''
 58 |     embed_kwargs = {
 59 |         'include_input': True,
 60 |         'input_dims': d_in,
 61 |         'max_freq_log2': multires-1,
 62 |         'num_freqs': multires,
 63 |         'log_sampling': True,
 64 |         'periodic_fns': [torch.sin, torch.cos],
 65 |     }
 66 |
 67 |     embedder_obj = Embedder(**embed_kwargs)
 68 |     def embed(x, eo=embedder_obj): return eo.embed(x)
 69 |     return embed, embedder_obj.out_dim
 70 |
 71 |
 72 | def normalize_coordinate(p, plane='xz'):
 73 |     ''' Select the two coordinates of `plane` and clamp them to [0, 1)
 74 |     for unit cube experiments.
 75 |
 76 |     Args:
 77 |         p (tensor): points, indexed as p[:, :, axis]
 78 |         plane (str): plane feature type, ['xz', 'xy', 'yz'];
 79 |             any other value selects 'yz'
 80 |
 81 |     Returns:
 82 |         tensor: the clamped 2D coordinates. The list indexing below makes
 83 |         a copy, so p itself is not modified.
 84 |     '''
 85 |     if plane == 'xz':
 86 |         xy = p[:, :, [0, 2]]
 87 |     elif plane == 'xy':
 88 |         xy = p[:, :, [0, 1]]
 89 |     else:
 90 |         xy = p[:, :, [1, 2]]
 91 |
 92 |     xy_new = xy
 93 |     # if there are outliers out of the range, clamp them into [0, 1)
 94 |     if xy_new.max() >= 1:
 95 |         xy_new[xy_new >= 1] = 1 - 10e-6
 96 |     if xy_new.min() < 0:
 97 |         xy_new[xy_new < 0] = 0.0
 98 |     return xy_new
 99 |
100 |
101 | def normalize_3d_coordinate(p):
102 |     ''' Clamp 3D coordinates to [0, 1) for unit cube experiments.
103 |
104 |     Note: p is modified in place and also returned.
105 |     '''
106 |     if p.max() >= 1:
107 |         p[p >= 1] = 1 - 10e-6
108 |     if p.min() < 0:
109 |         p[p < 0] = 0.0
110 |     return p
111 |
112 |
113 | def coordinate2index(x, reso, coord_type='2d'):
114 |     ''' Convert coordinates normalized to [0, 1) into flat cell indices.
115 |
116 |     Args:
117 |         x (tensor): normalized coordinate, indexed as x[:, :, axis]
118 |         reso (int): defined resolution
119 |         coord_type (str): coordinate type, '2d' (plane) or '3d' (grid)
120 |
121 |     Returns:
122 |         tensor: flat indices with a singleton dim inserted at position 1
123 |
124 |     Raises:
125 |         ValueError: if coord_type is neither '2d' nor '3d'
126 |     '''
127 |     x = (x * reso).long()
128 |     if coord_type == '2d':  # plane
129 |         index = x[:, :, 0] + reso * x[:, :, 1]
130 |     elif coord_type == '3d':  # grid
131 |         index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
132 |     else:
133 |         # any other coord_type would otherwise leave `index` undefined
134 |         raise ValueError(f"coord_type must be '2d' or '3d', got {coord_type!r}")
135 |     index = index[:, None, :]
136 |     return index
137 |
138 |
139 | class map2local(object):
140 |     ''' Map points to local coordinates within voxels of size s,
141 |     i.e. (p mod s) / s.
142 |
143 |     Args:
144 |         s (float): the defined voxel size
145 |         pos_encoding (str): method for the positional encoding, linear|sin_cos
146 |             (accepted but currently unused)
147 |     '''
148 |     def __init__(self, s, pos_encoding='linear'):
149 |         super().__init__()
150 |         self.s = s
151 |
152 |     def __call__(self, p):
153 |         # Tensor `%` follows torch.remainder (sign of the divisor), so the
154 |         # result is non-negative for s > 0; the clamp is only a safeguard.
155 |         p = (p % self.s) / self.s
156 |         p[p < 0] = 0.0
157 |         return p
158 |
159 |
160 | # Resnet Blocks
161 | class ResnetBlockFC(nn.Module):
162 |     ''' Fully connected ResNet Block class.
163 |
164 |     Computes shortcut(x) + fc_1(relu(fc_0(relu(x)))), where the shortcut is
165 |     the identity when size_in == size_out and a bias-free linear layer otherwise.
166 |
167 |     Args:
168 |         size_in (int): input dimension
169 |         size_out (int): output dimension (defaults to size_in)
170 |         size_h (int): hidden dimension (defaults to min(size_in, size_out))
171 |         siren (bool): unused; kept for signature compatibility
172 |     '''
173 |
174 |     def __init__(self, size_in, size_out=None, size_h=None, siren=False):
175 |         super().__init__()
176 |         # Attributes
177 |         if size_out is None:
178 |             size_out = size_in
179 |
180 |         if size_h is None:
181 |             size_h = min(size_in, size_out)
182 |
183 |         self.size_in = size_in
184 |         self.size_h = size_h
185 |         self.size_out = size_out
186 |         # Submodules
187 |         self.fc_0 = nn.Linear(size_in, size_h)
188 |         self.fc_1 = nn.Linear(size_h, size_out)
189 |         self.actvn = nn.ReLU()
190 |
191 |         if size_in == size_out:
192 |             self.shortcut = None
193 |         else:
194 |             self.shortcut = nn.Linear(size_in, size_out, bias=False)
195 |         # Initialization: with fc_1's weight zeroed, the residual branch
196 |         # initially outputs only fc_1's bias.
197 |         nn.init.zeros_(self.fc_1.weight)
198 |
199 |     def forward(self, x):
200 |         net = self.fc_0(self.actvn(x))
201 |         dx = self.fc_1(self.actvn(net))
202 |
203 |         if self.shortcut is not None:
204 |             x_s = self.shortcut(x)
205 |         else:
206 |             x_s = x
207 |
208 |         return x_s + dx
--------------------------------------------------------------------------------
/cfgs/Tooth_models/PoinTr.yaml:
--------------------------------------------------------------------------------
 1 | optimizer : {
 2 |   type: AdamW,
 3 |   kwargs: {
 4 |   lr : 0.0005,
 5 |   weight_decay : 0.0005
 6 | }}
 7 | scheduler: {
 8 |   type: LambdaLR,
 9 |   kwargs: {
10 |   decay_step: 21,
11 |   lr_decay: 0.76,
12 |   lowest_decay: 0.02 # min lr = lowest_decay * lr
13 | }}
14 | bnmscheduler: {
15 |   type: Lambda,
16 |   kwargs: {
17 |   decay_step: 21,
18 |   bn_decay: 0.5,
19 |   bn_momentum: 0.9,
20 |   lowest_decay: 0.01
21 | }}
22 |
23 | dataset : {
24 |   train : { _base_: cfgs/dataset_configs/Tooth.yaml,
25 |             others: {subset: 'train'}},
26 |   val : { _base_: cfgs/dataset_configs/Tooth.yaml,
27 |             others: {subset: 'test'}},
28 |   test : { _base_: cfgs/dataset_configs/Tooth.yaml,
29 |             others: {subset: 'test'}}}
30 | model : {
31 |   NAME: PoinTr, num_pred: 1568, num_query:
96, knn_layer: 1, trans_dim: 384} 32 | total_bs : 1 33 | step_per_update : 1 34 | max_epoch : 2 35 | 36 | consider_metric: CDL2 37 | 38 | 39 | -------------------------------------------------------------------------------- /cfgs/dataset_configs/Tooth.yaml: -------------------------------------------------------------------------------- 1 | NAME: crown 2 | DATA_PATH: data/dental/crown 3 | N_POINTS: 9999 4 | PC_PATH: C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\e 5 | #C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\back-updata-sap\\data-pointr\\data-pointr 6 | -------------------------------------------------------------------------------- /data/dental/crown/test.txt: -------------------------------------------------------------------------------- 1 | 6581-36 -------------------------------------------------------------------------------- /data/dental/crown/train.txt: -------------------------------------------------------------------------------- 1 | 6581-36 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_dataset_from_cfg 2 | import datasets.crowndataset -------------------------------------------------------------------------------- /datasets/__pycache__/KITTIDataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/KITTIDataset.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/PCNDataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/PCNDataset.cpython-38.pyc 
-------------------------------------------------------------------------------- /datasets/__pycache__/ShapeNet55Dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/ShapeNet55Dataset.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/Wrapping_Python_vtk_util_numpy_support.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/Wrapping_Python_vtk_util_numpy_support.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/crowndataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/crowndataset.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_transforms.cpython-38.pyc: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/data_transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/easy_mesh_vtk.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/easy_mesh_vtk.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/io.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/io.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/build.py:
--------------------------------------------------------------------------------
 1 | from utils import registry
 2 |
 3 |
 4 | DATASETS = registry.Registry('dataset')
 5 |
 6 |
 7 | def build_dataset_from_cfg(cfg, default_args=None):
 8 |     """
 9 |     Build a dataset from its config through the DATASETS registry.
10 |
11 |     Args:
12 |         cfg (eDICT): dataset config, passed to DATASETS.build
13 |         default_args (dict, optional): forwarded unchanged to DATASETS.build
14 |
15 |     Returns:
16 |         Dataset: the dataset constructed by DATASETS.build
17 |     """
18 |     return DATASETS.build(cfg, default_args=default_args)
--------------------------------------------------------------------------------
/extensions/chamfer_dist/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # @Author: Thibault GROUEIX
 3 | # @Date: 2019-08-07 20:54:24
 4 | # @Last Modified by: Haozhe Xie
 5 | # @Last Modified time: 2019-12-18 15:06:25
 6 | # @Email: cshzxie@gmail.com
 7 |
 8 | import torch
 9 |
10 | import chamfer
11 |
12 |
13 | class ChamferFunction(torch.autograd.Function):
14 |     ''' Autograd wrapper around the compiled `chamfer` extension.
15 |
16 |     forward returns (dist1, dist2) from chamfer.forward; the index tensors
17 |     it also returns are saved for the backward pass.
18 |     '''
19 |     @staticmethod
20 |     def forward(ctx, xyz1, xyz2):
21 |         dist1, dist2, idx1, idx2 = chamfer.forward(xyz1, xyz2)
22 |         ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
23 |
24 |         return dist1, dist2
25 |
26 |     @staticmethod
27 |     def backward(ctx, grad_dist1, grad_dist2):
28 |         xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
29 |         grad_xyz1, grad_xyz2 = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
30 |         return grad_xyz1, grad_xyz2
31 |
32 |
33 | def _drop_zero_points(xyz1, xyz2, ignore_zeros):
34 |     ''' Remove all-zero (padding) points from both clouds if ignore_zeros is set.
35 |
36 |     Only applied when the batch size is 1: boolean-mask indexing flattens the
37 |     batch dimension, which is only valid for a single cloud.
38 |     '''
39 |     if xyz1.size(0) == 1 and ignore_zeros:
40 |         # A point is padding only if every coordinate is zero; testing the
41 |         # coordinate *sum* would also drop real points such as (1, -1, 0).
42 |         non_zeros1 = xyz1.ne(0).any(dim=2)
43 |         non_zeros2 = xyz2.ne(0).any(dim=2)
44 |         xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
45 |         xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
46 |     return xyz1, xyz2
47 |
48 |
49 | class ChamferDistanceL2(torch.nn.Module):
50 |     ''' Chamfer Distance L2: mean(dist1) + mean(dist2), where dist1/dist2 are
51 |     the per-point distances returned by the `chamfer` extension (presumably
52 |     squared, since ChamferDistanceL1 takes their square root).
53 |
54 |     Args:
55 |         ignore_zeros (bool): drop all-zero padding points (batch size 1 only)
56 |     '''
57 |     def __init__(self, ignore_zeros=False):
58 |         super().__init__()
59 |         self.ignore_zeros = ignore_zeros
60 |
61 |     def forward(self, xyz1, xyz2):
62 |         xyz1, xyz2 = _drop_zero_points(xyz1, xyz2, self.ignore_zeros)
63 |         dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
64 |         return torch.mean(dist1) + torch.mean(dist2)
65 |
66 |
67 | class ChamferDistanceL2_split(torch.nn.Module):
68 |     ''' Chamfer Distance L2 with the two directional terms returned
69 |     separately as (mean(dist1), mean(dist2)).
70 |
71 |     Args:
72 |         ignore_zeros (bool): drop all-zero padding points (batch size 1 only)
73 |     '''
74 |     def __init__(self, ignore_zeros=False):
75 |         super().__init__()
76 |         self.ignore_zeros = ignore_zeros
77 |
78 |     def forward(self, xyz1, xyz2):
79 |         xyz1, xyz2 = _drop_zero_points(xyz1, xyz2, self.ignore_zeros)
80 |         dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
81 |         return torch.mean(dist1), torch.mean(dist2)
82 |
83 |
84 | class ChamferDistanceL1(torch.nn.Module):
85 |     ''' Chamfer Distance L1: (mean(sqrt(dist1)) + mean(sqrt(dist2))) / 2.
86 |
87 |     Args:
88 |         ignore_zeros (bool): drop all-zero padding points (batch size 1 only)
89 |     '''
90 |     def __init__(self, ignore_zeros=False):
91 |         super().__init__()
92 |         self.ignore_zeros = ignore_zeros
93 |
94 |     def forward(self, xyz1, xyz2):
95 |         xyz1, xyz2 = _drop_zero_points(xyz1, xyz2, self.ignore_zeros)
96 |         dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
97 |         dist1 = torch.sqrt(dist1)
98 |         dist2 = torch.sqrt(dist2)
99 |         return (torch.mean(dist1) + torch.mean(dist2)) / 2
--------------------------------------------------------------------------------
/extensions/chamfer_dist/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/extensions/chamfer_dist/build/lib.linux-x86_64-3.8/chamfer.cpython-38-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/build/lib.linux-x86_64-3.8/chamfer.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/build.ninja:
-------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/bin/nvcc 4 | 5 | cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -fPIC -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=chamfer -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14 7 | cuda_cflags = -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ 
--expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=chamfer -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/chamfer_dist/chamfer.cu 24 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer_cuda.o: compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/chamfer_dist/chamfer_cuda.cpp 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer.o -------------------------------------------------------------------------------- /extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer_cuda.o -------------------------------------------------------------------------------- 
/extensions/chamfer_dist/chamfer.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
 1 | Metadata-Version: 2.1
 2 | Name: chamfer
 3 | Version: 2.0.0
 4 | Summary: UNKNOWN
 5 | License: UNKNOWN
 6 | Platform: UNKNOWN
 7 |
 8 | UNKNOWN
 9 |
10 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | chamfer.cu
2 | chamfer_cuda.cpp
3 | setup.py
4 | chamfer.egg-info/PKG-INFO
5 | chamfer.egg-info/SOURCES.txt
6 | chamfer.egg-info/dependency_links.txt
7 | chamfer.egg-info/top_level.txt
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | chamfer
2 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer_cuda.cpp:
--------------------------------------------------------------------------------
 1 | /*
 2 |  * @Author: Haozhe Xie
 3 |  * @Date: 2019-08-07 20:54:24
 4 |  * @Last Modified by: Haozhe Xie
 5 |  * @Last Modified time: 2019-12-10 10:33:50
 6 |  * @Email: cshzxie@gmail.com
 7 |  */
 8 |
 9 | #include <torch/extension.h>
10 | #include <vector>
11 |
12 | std::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,
13 |                                                 torch::Tensor xyz2);
14 |
15 | std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
16 |                                                  torch::Tensor xyz2,
17 |                                                  torch::Tensor idx1,
18 |                                                  torch::Tensor idx2,
19 |                                                  torch::Tensor grad_dist1,
20 |                                                  torch::Tensor grad_dist2);
21 |
22 | std::vector<torch::Tensor> chamfer_forward(torch::Tensor xyz1,
23 |                                            torch::Tensor xyz2) {
24 |   return chamfer_cuda_forward(xyz1, xyz2);
25 | }
26 |
27 | std::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,
28 |                                             torch::Tensor xyz2,
29 |                                             torch::Tensor idx1,
30 |                                             torch::Tensor idx2,
31 |                                             torch::Tensor grad_dist1,
32 |                                             torch::Tensor grad_dist2) {
33 |   return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
34 | }
35 |
36 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
37 |   m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)");
38 |   m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)");
39 | }
40 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/dist/chamfer-2.0.0-py3.8-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/dist/chamfer-2.0.0-py3.8-linux-x86_64.egg
--------------------------------------------------------------------------------
/extensions/chamfer_dist/setup.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # @Author: Haozhe Xie
 3 | # @Date: 2019-08-07 20:54:24
 4 | # @Last Modified by: Haozhe Xie
 5 | # @Last Modified time: 2019-12-10 10:04:25
 6 | # @Email: cshzxie@gmail.com
 7 |
 8 | from setuptools import setup
 9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
10 |
11 | setup(name='chamfer',
12 |       version='2.0.0',
13 |       ext_modules=[
14 |           CUDAExtension('chamfer', [
15 |               'chamfer_cuda.cpp',
16 |               'chamfer.cu',
17 |           ]),
18 |       ],
19 |       cmdclass={'build_ext': BuildExtension})
20 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/test.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # @Author: Haozhe Xie
 3 | # @Date: 2019-12-10 10:38:01
 4 | # @Last Modified by: Haozhe Xie
 5 | # @Last Modified time: 2019-12-26 14:21:36
 6 | # @Email: cshzxie@gmail.com
 7 | #
 8 | # Note:
 9 | # - Replace float -> double, kFloat -> kDouble in chamfer.cu
10 | #   before running: gradcheck expects double-precision inputs.
11 |
12 | import os
13 | import sys
14 | import torch
15 | import unittest
16 |
17 |
18 | from torch.autograd import gradcheck
19 |
20 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
21 | from extensions.chamfer_dist import ChamferFunction
22 |
23 |
24 | class ChamferDistanceTestCase(unittest.TestCase):
25 |     def test_chamfer_dist(self):
26 |         x = torch.rand(4, 64, 3).double()
27 |         y = torch.rand(4, 128, 3).double()
28 |         x.requires_grad = True
29 |         y.requires_grad = True
30 |         self.assertTrue(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()]))
31 |
32 |
33 | if __name__ == '__main__':
34 |     unittest.main()
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # @Author: Haozhe Xie
 3 | # @Date: 2019-12-19 16:55:15
 4 | # @Last Modified by: Haozhe Xie
 5 | # @Last Modified time: 2019-12-26 13:15:14
 6 | # @Email: cshzxie@gmail.com
 7 |
 8 | import torch
 9 |
10 | import cubic_feature_sampling
11 |
12 |
13 | class CubicFeatureSamplingFunction(torch.autograd.Function):
14 |     ''' Autograd wrapper around the compiled `cubic_feature_sampling` extension. '''
15 |     @staticmethod
16 |     def forward(ctx, ptcloud, cubic_features, neighborhood_size=1):
17 |         scale = cubic_features.size(2)
18 |         point_features, grid_pt_indexes = cubic_feature_sampling.forward(scale, neighborhood_size, ptcloud,
19 |                                                                          cubic_features)
20 |         # the int arguments are wrapped in tensors so they can be stored
21 |         # with save_for_backward alongside grid_pt_indexes
22 |         ctx.save_for_backward(torch.Tensor([scale]), torch.Tensor([neighborhood_size]), grid_pt_indexes)
23 |         return point_features
24 |
25 |     @staticmethod
26 |     def backward(ctx, grad_point_features):
27 |         scale, neighborhood_size, grid_pt_indexes = ctx.saved_tensors
28 |         scale = int(scale.item())
29 |         neighborhood_size = int(neighborhood_size.item())
30 |         grad_point_features = grad_point_features.contiguous()
31 |         grad_ptcloud, grad_cubic_features = cubic_feature_sampling.backward(scale, neighborhood_size,
32 |                                                                             grad_point_features, grid_pt_indexes)
33 |         # no gradient w.r.t. the integer neighborhood_size argument
34 |         return grad_ptcloud, grad_cubic_features, None
35 |
36 |
37 | class CubicFeatureSampling(torch.nn.Module):
38 |     ''' Module wrapper: rescales ptcloud by h = cubic_features.size(2) / 2
39 |     (x -> x * h + h, which maps [-1, 1] onto [0, size(2)]) and then applies
40 |     CubicFeatureSamplingFunction.
41 |     '''
42 |     def __init__(self):
43 |         super(CubicFeatureSampling, self).__init__()
44 |
45 |     def forward(self, ptcloud, cubic_features, neighborhood_size=1):
46 |         h_scale = cubic_features.size(2) / 2
47 |         ptcloud = ptcloud * h_scale + h_scale
48 |         return CubicFeatureSamplingFunction.apply(ptcloud, cubic_features, neighborhood_size)
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/build/lib.linux-x86_64-3.8/cubic_feature_sampling.cpython-38-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/build/lib.linux-x86_64-3.8/cubic_feature_sampling.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/build.ninja:
--------------------------------------------------------------------------------
1 | ninja_required_version = 1.3
2 | cxx = c++
3 | nvcc = /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/bin/nvcc
4 |
5 | cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g
-fwrapv -O3 -Wall -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -fPIC -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=cubic_feature_sampling -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14 7 | cuda_cflags = -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=cubic_feature_sampling -D_GLIBCXX_USE_CXX11_ABI=1 
-gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/cubic_feature_sampling/cubic_feature_sampling.cu 24 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling_cuda.o: compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/cubic_feature_sampling/cubic_feature_sampling_cuda.cpp 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling.o -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling_cuda.o -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/cubic_feature_sampling.egg-info/PKG-INFO: 
-------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: cubic-feature-sampling 3 | Version: 1.1.0 4 | Summary: UNKNOWN 5 | License: UNKNOWN 6 | Platform: UNKNOWN 7 | 8 | UNKNOWN 9 | 10 | -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/cubic_feature_sampling.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | cubic_feature_sampling.cu 2 | cubic_feature_sampling_cuda.cpp 3 | setup.py 4 | cubic_feature_sampling.egg-info/PKG-INFO 5 | cubic_feature_sampling.egg-info/SOURCES.txt 6 | cubic_feature_sampling.egg-info/dependency_links.txt 7 | cubic_feature_sampling.egg-info/top_level.txt -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/cubic_feature_sampling.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/cubic_feature_sampling.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | cubic_feature_sampling 2 | -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/cubic_feature_sampling_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * @Author: Haozhe Xie 3 | * @Date: 2019-12-19 17:04:38 4 | * @Last Modified by: Haozhe Xie 5 | * @Last Modified time: 2020-06-17 14:50:22 6 | * @Email: cshzxie@gmail.com 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
13 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) \ 15 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 16 | #define CHECK_INPUT(x) \ 17 | CHECK_CUDA(x); \ 18 | CHECK_CONTIGUOUS(x) 19 | 20 | std::vector cubic_feature_sampling_cuda_forward( 21 | int scale, 22 | int neighborhood_size, 23 | torch::Tensor ptcloud, 24 | torch::Tensor cubic_features, 25 | cudaStream_t stream); 26 | 27 | std::vector cubic_feature_sampling_cuda_backward( 28 | int scale, 29 | int neighborhood_size, 30 | torch::Tensor grad_point_features, 31 | torch::Tensor grid_pt_indexes, 32 | cudaStream_t stream); 33 | 34 | std::vector cubic_feature_sampling_forward( 35 | int scale, 36 | int neighborhood_size, 37 | torch::Tensor ptcloud, 38 | torch::Tensor cubic_features) { 39 | CHECK_INPUT(ptcloud); 40 | CHECK_INPUT(cubic_features); 41 | 42 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 43 | return cubic_feature_sampling_cuda_forward(scale, neighborhood_size, ptcloud, 44 | cubic_features, stream); 45 | } 46 | 47 | std::vector cubic_feature_sampling_backward( 48 | int scale, 49 | int neighborhood_size, 50 | torch::Tensor grad_point_features, 51 | torch::Tensor grid_pt_indexes) { 52 | CHECK_INPUT(grad_point_features); 53 | CHECK_INPUT(grid_pt_indexes); 54 | 55 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 56 | return cubic_feature_sampling_cuda_backward( 57 | scale, neighborhood_size, grad_point_features, grid_pt_indexes, stream); 58 | } 59 | 60 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 61 | m.def("forward", &cubic_feature_sampling_forward, 62 | "Cubic Feature Sampling forward (CUDA)"); 63 | m.def("backward", &cubic_feature_sampling_backward, 64 | "Cubic Feature Sampling backward (CUDA)"); 65 | } -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/dist/cubic_feature_sampling-1.1.0-py3.8-linux-x86_64.egg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/dist/cubic_feature_sampling-1.1.0-py3.8-linux-x86_64.egg -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-12-19 17:03:06 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-26 14:02:06 6 | # @Email: cshzxie@gmail.com 7 | 8 | from setuptools import setup 9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 10 | 11 | setup(name='cubic_feature_sampling', 12 | version='1.1.0', 13 | ext_modules=[ 14 | CUDAExtension('cubic_feature_sampling', ['cubic_feature_sampling_cuda.cpp', 'cubic_feature_sampling.cu']), 15 | ], 16 | cmdclass={'build_ext': BuildExtension}) 17 | -------------------------------------------------------------------------------- /extensions/cubic_feature_sampling/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-12-20 11:50:50 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-26 13:52:33 6 | # @Email: cshzxie@gmail.com 7 | # 8 | # Note: 9 | # - Replace float -> double, kFloat -> kDouble in cubic_feature_sampling.cu 10 | 11 | import os 12 | import sys 13 | import torch 14 | import unittest 15 | 16 | from torch.autograd import gradcheck 17 | 18 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) 19 | from extensions.cubic_feature_sampling import CubicFeatureSamplingFunction 20 | 21 | 22 | class CubicFeatureSamplingTestCase(unittest.TestCase): 23 | def test_neighborhood_size_1(self): 24 | ptcloud = 
torch.rand(2, 64, 3) * 2 - 1 25 | cubic_features = torch.rand(2, 4, 8, 8, 8) 26 | ptcloud.requires_grad = True 27 | cubic_features.requires_grad = True 28 | self.assertTrue( 29 | gradcheck(CubicFeatureSamplingFunction.apply, 30 | [ptcloud.double().cuda(), cubic_features.double().cuda()])) 31 | 32 | def test_neighborhood_size_2(self): 33 | ptcloud = torch.rand(2, 32, 3) * 2 - 1 34 | cubic_features = torch.rand(2, 2, 8, 8, 8) 35 | ptcloud.requires_grad = True 36 | cubic_features.requires_grad = True 37 | self.assertTrue( 38 | gradcheck(CubicFeatureSamplingFunction.apply, 39 | [ptcloud.double().cuda(), cubic_features.double().cuda(), 2])) 40 | 41 | def test_neighborhood_size_3(self): 42 | ptcloud = torch.rand(1, 32, 3) * 2 - 1 43 | cubic_features = torch.rand(1, 2, 16, 16, 16) 44 | ptcloud.requires_grad = True 45 | cubic_features.requires_grad = True 46 | self.assertTrue( 47 | gradcheck(CubicFeatureSamplingFunction.apply, 48 | [ptcloud.double().cuda(), cubic_features.double().cuda(), 3])) 49 | 50 | 51 | if __name__ == '__main__': 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /extensions/gridding/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-11-15 20:33:52 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-30 09:55:53 6 | # @Email: cshzxie@gmail.com 7 | 8 | import torch 9 | 10 | import gridding 11 | 12 | 13 | class GriddingFunction(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, scale, ptcloud): 16 | grid, grid_pt_weights, grid_pt_indexes = gridding.forward(-scale, scale - 1, -scale, scale - 1, -scale, 17 | scale - 1, ptcloud) 18 | # print(grid.size()) # torch.Size(batch_size, n_grid_vertices) 19 | # print(grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3) 20 | # print(grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8) 21 | 
ctx.save_for_backward(grid_pt_weights, grid_pt_indexes) 22 | 23 | return grid 24 | 25 | @staticmethod 26 | def backward(ctx, grad_grid): 27 | grid_pt_weights, grid_pt_indexes = ctx.saved_tensors 28 | grad_ptcloud = gridding.backward(grid_pt_weights, grid_pt_indexes, grad_grid) 29 | # print(grad_ptcloud.size()) # torch.Size(batch_size, n_pts, 3) 30 | 31 | return None, grad_ptcloud 32 | 33 | 34 | class Gridding(torch.nn.Module): 35 | def __init__(self, scale=1): 36 | super(Gridding, self).__init__() 37 | self.scale = scale // 2 38 | 39 | def forward(self, ptcloud): 40 | ptcloud = ptcloud * self.scale 41 | _ptcloud = torch.split(ptcloud, 1, dim=0) 42 | grids = [] 43 | for p in _ptcloud: 44 | non_zeros = torch.sum(p, dim=2).ne(0) 45 | p = p[non_zeros].unsqueeze(dim=0) 46 | grids.append(GriddingFunction.apply(self.scale, p.contiguous())) 47 | 48 | return torch.cat(grids, dim=0).contiguous() 49 | 50 | 51 | class GriddingReverseFunction(torch.autograd.Function): 52 | @staticmethod 53 | def forward(ctx, scale, grid): 54 | ptcloud = gridding.rev_forward(scale, grid) 55 | ctx.save_for_backward(torch.Tensor([scale]), grid, ptcloud) 56 | return ptcloud 57 | 58 | @staticmethod 59 | def backward(ctx, grad_ptcloud): 60 | scale, grid, ptcloud = ctx.saved_tensors 61 | scale = int(scale.item()) 62 | grad_grid = gridding.rev_backward(ptcloud, grid, grad_ptcloud) 63 | grad_grid = grad_grid.view(-1, scale, scale, scale) 64 | return None, grad_grid 65 | 66 | 67 | class GriddingReverse(torch.nn.Module): 68 | def __init__(self, scale=1): 69 | super(GriddingReverse, self).__init__() 70 | self.scale = scale 71 | 72 | def forward(self, grid): 73 | ptcloud = GriddingReverseFunction.apply(self.scale, grid) 74 | return ptcloud / self.scale * 2 75 | -------------------------------------------------------------------------------- /extensions/gridding/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /extensions/gridding/build/lib.linux-x86_64-3.8/gridding.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/build/lib.linux-x86_64-3.8/gridding.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /extensions/gridding/build/temp.linux-x86_64-3.8/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/bin/nvcc 4 | 5 | cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -fPIC -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=gridding -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14 7 | cuda_cflags = 
-I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=gridding -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/build/temp.linux-x86_64-3.8/gridding.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/gridding.cu 24 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_cuda.o: compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/gridding_cuda.cpp 25 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_reverse.o: cuda_compile 
/lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/gridding_reverse.cu 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /extensions/gridding/build/temp.linux-x86_64-3.8/gridding.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/build/temp.linux-x86_64-3.8/gridding.o -------------------------------------------------------------------------------- /extensions/gridding/build/temp.linux-x86_64-3.8/gridding_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_cuda.o -------------------------------------------------------------------------------- /extensions/gridding/build/temp.linux-x86_64-3.8/gridding_reverse.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_reverse.o -------------------------------------------------------------------------------- /extensions/gridding/dist/gridding-2.1.0-py3.8-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/dist/gridding-2.1.0-py3.8-linux-x86_64.egg -------------------------------------------------------------------------------- /extensions/gridding/gridding.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: gridding 3 | Version: 2.1.0 4 | Summary: UNKNOWN 5 | License: UNKNOWN 6 | Platform: UNKNOWN 7 | 8 | UNKNOWN 9 
| 10 | -------------------------------------------------------------------------------- /extensions/gridding/gridding.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | gridding.cu 2 | gridding_cuda.cpp 3 | gridding_reverse.cu 4 | setup.py 5 | gridding.egg-info/PKG-INFO 6 | gridding.egg-info/SOURCES.txt 7 | gridding.egg-info/dependency_links.txt 8 | gridding.egg-info/top_level.txt -------------------------------------------------------------------------------- /extensions/gridding/gridding.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /extensions/gridding/gridding.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | gridding 2 | -------------------------------------------------------------------------------- /extensions/gridding/gridding_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * @Author: Haozhe Xie 3 | * @Date: 2019-11-13 10:52:53 4 | * @Last Modified by: Haozhe Xie 5 | * @Last Modified time: 2020-06-17 14:52:32 6 | * @Email: cshzxie@gmail.com 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
13 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) \ 15 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 16 | #define CHECK_INPUT(x) \ 17 | CHECK_CUDA(x); \ 18 | CHECK_CONTIGUOUS(x) 19 | 20 | std::vector gridding_cuda_forward(float min_x, 21 | float max_x, 22 | float min_y, 23 | float max_y, 24 | float min_z, 25 | float max_z, 26 | torch::Tensor ptcloud, 27 | cudaStream_t stream); 28 | 29 | torch::Tensor gridding_cuda_backward(torch::Tensor grid_pt_weights, 30 | torch::Tensor grid_pt_indexes, 31 | torch::Tensor grad_grid, 32 | cudaStream_t stream); 33 | 34 | torch::Tensor gridding_reverse_cuda_forward(int scale, 35 | torch::Tensor grid, 36 | cudaStream_t stream); 37 | 38 | torch::Tensor gridding_reverse_cuda_backward(torch::Tensor ptcloud, 39 | torch::Tensor grid, 40 | torch::Tensor grad_ptcloud, 41 | cudaStream_t stream); 42 | 43 | std::vector gridding_forward(float min_x, 44 | float max_x, 45 | float min_y, 46 | float max_y, 47 | float min_z, 48 | float max_z, 49 | torch::Tensor ptcloud) { 50 | CHECK_INPUT(ptcloud); 51 | 52 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 53 | return gridding_cuda_forward(min_x, max_x, min_y, max_y, min_z, max_z, 54 | ptcloud, stream); 55 | } 56 | 57 | torch::Tensor gridding_backward(torch::Tensor grid_pt_weights, 58 | torch::Tensor grid_pt_indexes, 59 | torch::Tensor grad_grid) { 60 | CHECK_INPUT(grid_pt_weights); 61 | CHECK_INPUT(grid_pt_indexes); 62 | CHECK_INPUT(grad_grid); 63 | 64 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 65 | return gridding_cuda_backward(grid_pt_weights, grid_pt_indexes, grad_grid, 66 | stream); 67 | } 68 | 69 | torch::Tensor gridding_reverse_forward(int scale, torch::Tensor grid) { 70 | CHECK_INPUT(grid); 71 | 72 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 73 | return gridding_reverse_cuda_forward(scale, grid, stream); 74 | } 75 | 76 | torch::Tensor gridding_reverse_backward(torch::Tensor ptcloud, 
77 | torch::Tensor grid, 78 | torch::Tensor grad_ptcloud) { 79 | CHECK_INPUT(ptcloud); 80 | CHECK_INPUT(grid); 81 | CHECK_INPUT(grad_ptcloud); 82 | 83 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 84 | return gridding_reverse_cuda_backward(ptcloud, grid, grad_ptcloud, stream); 85 | } 86 | 87 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 88 | m.def("forward", &gridding_forward, "Gridding forward (CUDA)"); 89 | m.def("backward", &gridding_backward, "Gridding backward (CUDA)"); 90 | m.def("rev_forward", &gridding_reverse_forward, 91 | "Gridding Reverse forward (CUDA)"); 92 | m.def("rev_backward", &gridding_reverse_backward, 93 | "Gridding Reverse backward (CUDA)"); 94 | } 95 | -------------------------------------------------------------------------------- /extensions/gridding/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-11-13 10:51:33 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-02 17:02:16 6 | # @Email: cshzxie@gmail.com 7 | 8 | from setuptools import setup 9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 10 | 11 | setup(name='gridding', 12 | version='2.1.0', 13 | ext_modules=[ 14 | CUDAExtension('gridding', ['gridding_cuda.cpp', 'gridding.cu', 'gridding_reverse.cu']), 15 | ], 16 | cmdclass={'build_ext': BuildExtension}) 17 | -------------------------------------------------------------------------------- /extensions/gridding/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-12-10 10:48:55 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-26 14:20:42 6 | # @Email: cshzxie@gmail.com 7 | # 8 | # Note: 9 | # - Replace float -> double, kFloat -> kDouble in gridding.cu and gridding_reverse.cu 10 | 11 | import os 12 | import sys 13 | import torch 14 | import unittest 
15 | 16 | from torch.autograd import gradcheck 17 | 18 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) 19 | from extensions.gridding import GriddingFunction, GriddingReverseFunction 20 | 21 | 22 | class GriddingTestCase(unittest.TestCase): 23 | def test_gridding_reverse_function_4(self): 24 | x = torch.rand(2, 4, 4, 4) 25 | x.requires_grad = True 26 | self.assertTrue(gradcheck(GriddingReverseFunction.apply, [4, x.double().cuda()])) 27 | 28 | def test_gridding_reverse_function_8(self): 29 | x = torch.rand(4, 8, 8, 8) 30 | x.requires_grad = True 31 | self.assertTrue(gradcheck(GriddingReverseFunction.apply, [8, x.double().cuda()])) 32 | 33 | def test_gridding_reverse_function_16(self): 34 | x = torch.rand(1, 16, 16, 16) 35 | x.requires_grad = True 36 | self.assertTrue(gradcheck(GriddingReverseFunction.apply, [16, x.double().cuda()])) 37 | 38 | def test_gridding_function_32pts(self): 39 | x = torch.rand(1, 32, 3) 40 | x.requires_grad = True 41 | self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda()])) 42 | 43 | 44 | if __name__ == '__main__': 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /extensions/gridding_loss/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-12-30 09:56:06 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-02-22 19:19:43 6 | # @Email: cshzxie@gmail.com 7 | 8 | import torch 9 | 10 | import gridding_distance 11 | 12 | 13 | class GriddingDistanceFunction(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, min_x, max_x, min_y, max_y, min_z, max_z, pred_cloud, gt_cloud): 16 | pred_grid, pred_grid_pt_weights, pred_grid_pt_indexes = gridding_distance.forward( 17 | min_x, max_x, min_y, max_y, min_z, max_z, pred_cloud) 18 | # print(pred_grid.size()) # torch.Size(batch_size, 
n_grid_vertices, 8) 19 | # print(pred_grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3) 20 | # print(pred_grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8) 21 | gt_grid, gt_grid_pt_weights, gt_grid_pt_indexes = gridding_distance.forward( 22 | min_x, max_x, min_y, max_y, min_z, max_z, gt_cloud) 23 | # print(gt_grid.size()) # torch.Size(batch_size, n_grid_vertices, 8) 24 | # print(gt_grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3) 25 | # print(gt_grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8) 26 | 27 | ctx.save_for_backward(pred_grid_pt_weights, pred_grid_pt_indexes, gt_grid_pt_weights, gt_grid_pt_indexes) 28 | return pred_grid, gt_grid 29 | 30 | @staticmethod 31 | def backward(ctx, grad_pred_grid, grad_gt_grid): 32 | pred_grid_pt_weights, pred_grid_pt_indexes, gt_grid_pt_weights, gt_grid_pt_indexes = ctx.saved_tensors 33 | 34 | grad_pred_cloud = gridding_distance.backward(pred_grid_pt_weights, pred_grid_pt_indexes, grad_pred_grid) 35 | # print(grad_pred_cloud.size()) # torch.Size(batch_size, n_pts, 3) 36 | grad_gt_cloud = gridding_distance.backward(gt_grid_pt_weights, gt_grid_pt_indexes, grad_gt_grid) 37 | # print(grad_gt_cloud.size()) # torch.Size(batch_size, n_pts, 3) 38 | 39 | return None, None, None, None, None, None, grad_pred_cloud, grad_gt_cloud 40 | 41 | 42 | class GriddingDistance(torch.nn.Module): 43 | def __init__(self, scale=1): 44 | super(GriddingDistance, self).__init__() 45 | self.scale = scale 46 | 47 | def forward(self, pred_cloud, gt_cloud): 48 | ''' 49 | pred_cloud(b, n_pts1, 3) 50 | gt_cloud(b, n_pts2, 3) 51 | ''' 52 | pred_cloud = pred_cloud * self.scale / 2 53 | gt_cloud = gt_cloud * self.scale / 2 54 | 55 | min_pred_x = torch.min(pred_cloud[:, :, 0]) 56 | max_pred_x = torch.max(pred_cloud[:, :, 0]) 57 | min_pred_y = torch.min(pred_cloud[:, :, 1]) 58 | max_pred_y = torch.max(pred_cloud[:, :, 1]) 59 | min_pred_z = torch.min(pred_cloud[:, :, 2]) 60 | max_pred_z = torch.max(pred_cloud[:, :, 2]) 61 | 
62 | min_gt_x = torch.min(gt_cloud[:, :, 0]) 63 | max_gt_x = torch.max(gt_cloud[:, :, 0]) 64 | min_gt_y = torch.min(gt_cloud[:, :, 1]) 65 | max_gt_y = torch.max(gt_cloud[:, :, 1]) 66 | min_gt_z = torch.min(gt_cloud[:, :, 2]) 67 | max_gt_z = torch.max(gt_cloud[:, :, 2]) 68 | 69 | min_x = torch.floor(torch.min(min_pred_x, min_gt_x)) - 1 70 | max_x = torch.ceil(torch.max(max_pred_x, max_gt_x)) + 1 71 | min_y = torch.floor(torch.min(min_pred_y, min_gt_y)) - 1 72 | max_y = torch.ceil(torch.max(max_pred_y, max_gt_y)) + 1 73 | min_z = torch.floor(torch.min(min_pred_z, min_gt_z)) - 1 74 | max_z = torch.ceil(torch.max(max_pred_z, max_gt_z)) + 1 75 | 76 | _pred_clouds = torch.split(pred_cloud, 1, dim=0) 77 | _gt_clouds = torch.split(gt_cloud, 1, dim=0) 78 | pred_grids = [] 79 | gt_grids = [] 80 | for pc, gc in zip(_pred_clouds, _gt_clouds): 81 | non_zeros = torch.sum(pc, dim=2).ne(0) 82 | pc = pc[non_zeros].unsqueeze(dim=0) 83 | non_zeros = torch.sum(gc, dim=2).ne(0) 84 | gc = gc[non_zeros].unsqueeze(dim=0) 85 | pred_grid, gt_grid = GriddingDistanceFunction.apply(min_x, max_x, min_y, max_y, min_z, max_z, pc, gc) 86 | pred_grids.append(pred_grid) 87 | gt_grids.append(gt_grid) 88 | 89 | return torch.cat(pred_grids, dim=0).contiguous(), torch.cat(gt_grids, dim=0).contiguous() 90 | 91 | 92 | class GriddingLoss(torch.nn.Module): 93 | def __init__(self, scales=[], alphas=[]): 94 | super(GriddingLoss, self).__init__() 95 | self.scales = scales 96 | self.alphas = alphas 97 | self.gridding_dists = [GriddingDistance(scale=s) for s in scales] 98 | self.l1_loss = torch.nn.L1Loss() 99 | 100 | def forward(self, pred_cloud, gt_cloud): 101 | gridding_loss = None 102 | n_dists = len(self.scales) 103 | 104 | for i in range(n_dists): 105 | alpha = self.alphas[i] 106 | gdist = self.gridding_dists[i] 107 | pred_grid, gt_grid = gdist(pred_cloud, gt_cloud) 108 | 109 | if gridding_loss is None: 110 | gridding_loss = alpha * self.l1_loss(pred_grid, gt_grid) 111 | else: 112 | gridding_loss += alpha * 
self.l1_loss(pred_grid, gt_grid) 113 | 114 | return gridding_loss 115 | -------------------------------------------------------------------------------- /extensions/gridding_loss/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /extensions/gridding_loss/build/lib.linux-x86_64-3.8/gridding_distance.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/build/lib.linux-x86_64-3.8/gridding_distance.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /extensions/gridding_loss/build/temp.linux-x86_64-3.8/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/bin/nvcc 4 | 5 | cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -fPIC -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include 
-I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=gridding_distance -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14 7 | cuda_cflags = -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=gridding_distance -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding_loss/gridding_distance.cu 24 | build 
/lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance_cuda.o: compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding_loss/gridding_distance_cuda.cpp 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance.o -------------------------------------------------------------------------------- /extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance_cuda.o -------------------------------------------------------------------------------- /extensions/gridding_loss/dist/gridding_distance-1.0.0-py3.8-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/dist/gridding_distance-1.0.0-py3.8-linux-x86_64.egg -------------------------------------------------------------------------------- /extensions/gridding_loss/gridding_distance.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: gridding-distance 3 | Version: 1.0.0 4 | Summary: UNKNOWN 5 | License: UNKNOWN 6 | Platform: UNKNOWN 7 | 8 | UNKNOWN 9 | 10 | -------------------------------------------------------------------------------- 
/extensions/gridding_loss/gridding_distance.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | gridding_distance.cu 2 | gridding_distance_cuda.cpp 3 | setup.py 4 | gridding_distance.egg-info/PKG-INFO 5 | gridding_distance.egg-info/SOURCES.txt 6 | gridding_distance.egg-info/dependency_links.txt 7 | gridding_distance.egg-info/top_level.txt -------------------------------------------------------------------------------- /extensions/gridding_loss/gridding_distance.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /extensions/gridding_loss/gridding_distance.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | gridding_distance 2 | -------------------------------------------------------------------------------- /extensions/gridding_loss/gridding_distance_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * @Author: Haozhe Xie 3 | * @Date: 2019-12-30 10:59:31 4 | * @Last Modified by: Haozhe Xie 5 | * @Last Modified time: 2020-06-17 14:52:52 6 | * @Email: cshzxie@gmail.com 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
13 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) \ 15 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 16 | #define CHECK_INPUT(x) \ 17 | CHECK_CUDA(x); \ 18 | CHECK_CONTIGUOUS(x) 19 | 20 | std::vector gridding_distance_cuda_forward(float min_x, 21 | float max_x, 22 | float min_y, 23 | float max_y, 24 | float min_z, 25 | float max_z, 26 | torch::Tensor ptcloud, 27 | cudaStream_t stream); 28 | 29 | torch::Tensor gridding_distance_cuda_backward(torch::Tensor grid_pt_weights, 30 | torch::Tensor grid_pt_indexes, 31 | torch::Tensor grad_grid, 32 | cudaStream_t stream); 33 | 34 | std::vector gridding_distance_forward(float min_x, 35 | float max_x, 36 | float min_y, 37 | float max_y, 38 | float min_z, 39 | float max_z, 40 | torch::Tensor ptcloud) { 41 | CHECK_INPUT(ptcloud); 42 | 43 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 44 | return gridding_distance_cuda_forward(min_x, max_x, min_y, max_y, min_z, 45 | max_z, ptcloud, stream); 46 | } 47 | 48 | torch::Tensor gridding_distance_backward(torch::Tensor grid_pt_weights, 49 | torch::Tensor grid_pt_indexes, 50 | torch::Tensor grad_grid) { 51 | CHECK_INPUT(grid_pt_weights); 52 | CHECK_INPUT(grid_pt_indexes); 53 | CHECK_INPUT(grad_grid); 54 | 55 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 56 | return gridding_distance_cuda_backward(grid_pt_weights, grid_pt_indexes, 57 | grad_grid, stream); 58 | } 59 | 60 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 61 | m.def("forward", &gridding_distance_forward, 62 | "Gridding Distance Forward (CUDA)"); 63 | m.def("backward", &gridding_distance_backward, 64 | "Gridding Distance Backward (CUDA)"); 65 | } 66 | -------------------------------------------------------------------------------- /extensions/gridding_loss/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-12-30 11:03:55 4 | 
# @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-30 11:13:39 6 | # @Email: cshzxie@gmail.com 7 | 8 | from setuptools import setup 9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 10 | 11 | setup(name='gridding_distance', 12 | version='1.0.0', 13 | ext_modules=[ 14 | CUDAExtension('gridding_distance', ['gridding_distance_cuda.cpp', 'gridding_distance.cu']), 15 | ], 16 | cmdclass={'build_ext': BuildExtension}) 17 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --gres=gpu:1 # Request GPU "generic resources" 3 | 4 | 5 | 6 | #SBATCH --cpus-per-task=6 # Cores proportional to GPUs: 6 on Cedar, 16 on Graham. 7 | 8 | 9 | 10 | #SBATCH --mem=32000M # Memory proportional to GPUs: 32000 Cedar, 64000 Graham. 11 | 12 | 13 | 14 | #SBATCH --time=0-03:00 15 | source /home/golriz/projects/def-guibault/golriz/env/bin/activate 16 | HOME=`pwd` 17 | 18 | # Chamfer Distance 19 | cd $HOME/extensions/chamfer_dist 20 | python setup.py install 21 | 22 | # NOTE: For GRNet 23 | 24 | # Cubic Feature Sampling 25 | cd $HOME/extensions/cubic_feature_sampling 26 | python setup.py install 27 | 28 | # Gridding & Gridding Reverse 29 | cd $HOME/extensions/gridding 30 | python setup.py install 31 | 32 | # Gridding Loss 33 | cd $HOME/extensions/gridding_loss 34 | python setup.py install 35 | 36 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from tools import run_net 2 | from tools import test_net 3 | from utils import parser, dist_utils, misc 4 | from utils.logger import * 5 | from utils.config import * 6 | import time 7 | import os 8 | import torch 9 | from tensorboardX import SummaryWriter 10 | 11 | 12 | def main(): 13 | # args 14 | print(os.getcwd()) 15 | print('start main') 16 | 
args = parser.get_args() 17 | args.config = 'cfgs/Tooth_models/PoinTr.yaml' 18 | args.config_SAP = 'SAP/configs/learning_based/noise_small/ours.yaml' 19 | #args.test = 'true' 20 | #args.ckpts = "./pretrained/ckpt-last.pth" 21 | #args.mode ="easy" 22 | args.exp_name = "train-20-03-23" 23 | 24 | # CUDA 25 | args.use_gpu = torch.cuda.is_available() 26 | if args.use_gpu: 27 | torch.backends.cudnn.benchmark = True 28 | # init distributed env first, since logger depends on the dist info. 29 | if args.launcher == 'none': 30 | args.distributed = False 31 | else: 32 | args.distributed = True 33 | dist_utils.init_dist(args.launcher) 34 | # re-set gpu_ids with distributed training mode 35 | _, world_size = dist_utils.get_dist_info() 36 | args.world_size = world_size 37 | print('in the main function') 38 | # logger 39 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 40 | log_file = os.path.join(args.experiment_path, f'{timestamp}.log') 41 | logger = get_root_logger(log_file=log_file, name=args.log_name) 42 | # define the tensorboard writer 43 | if not args.test: 44 | if args.local_rank == 0: 45 | train_writer = SummaryWriter(os.path.join(args.tfboard_path, 'train')) 46 | val_writer = SummaryWriter(os.path.join(args.tfboard_path, 'test')) 47 | else: 48 | train_writer = None 49 | val_writer = None 50 | # config 51 | config = get_config(args, logger = logger) 52 | config_SAP = load_config(args.config_SAP) 53 | # batch size 54 | if args.distributed: 55 | assert config.total_bs % world_size == 0 56 | config.dataset.train.others.bs = config.total_bs // world_size 57 | else: 58 | config.dataset.train.others.bs = config.total_bs 59 | # log 60 | log_args_to_file(args, 'args', logger = logger) 61 | log_config_to_file(config, 'config', logger = logger) 62 | # exit() 63 | logger.info(f'Distributed training: {args.distributed}') 64 | # set random seeds 65 | if args.seed is not None: 66 | logger.info(f'Set random seed to {args.seed}, ' 67 | f'deterministic: 
{args.deterministic}') 68 | misc.set_random_seed(args.seed + args.local_rank, deterministic=args.deterministic) # seed + rank, for augmentation 69 | if args.distributed: 70 | assert args.local_rank == torch.distributed.get_rank() 71 | 72 | # run 73 | if args.test: 74 | test_net(args,config) 75 | else: 76 | run_net(args, config,config_SAP, train_writer, val_writer) 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /models/PoinTr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from pointnet2_ops import pointnet2_utils 5 | from extensions.chamfer_dist import ChamferDistanceL1 6 | from .Transformer import PCTransformer 7 | from .build import MODELS 8 | from SAP.src.model import PSR2Mesh 9 | from SAP.src.dpsr import DPSR 10 | import argparse 11 | from SAP.src import utils 12 | from SAP.src.model import Encode2Points 13 | from SAP.src.model import PSR2Mesh 14 | def fps(pc, num,device): 15 | pc = pc.to(device) 16 | fps_idx = pointnet2_utils.furthest_point_sample(pc, num) 17 | sub_pc = pointnet2_utils.gather_operation(pc.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous() 18 | return sub_pc 19 | 20 | 21 | class Fold(nn.Module): 22 | def __init__(self, in_channel , step , hidden_dim = 512): 23 | super().__init__() 24 | 25 | self.in_channel = in_channel 26 | self.step = step 27 | #self.psr2mesh = PSR2Mesh.apply 28 | a = torch.linspace(-1., 1., steps=step, dtype=torch.float).view(1, step).expand(step, step).reshape(1, -1) 29 | b = torch.linspace(-1., 1., steps=step, dtype=torch.float).view(step, 1).expand(step, step).reshape(1, -1) 30 | self.folding_seed = torch.cat([a, b], dim=0).cuda() 31 | 32 | self.folding1 = nn.Sequential( 33 | nn.Conv1d(in_channel + 2, hidden_dim, 1), 34 | nn.BatchNorm1d(hidden_dim), 35 | nn.ReLU(inplace=True), 36 
| nn.Conv1d(hidden_dim, hidden_dim//2, 1), 37 | nn.BatchNorm1d(hidden_dim//2), 38 | nn.ReLU(inplace=True), 39 | nn.Conv1d(hidden_dim//2, 3, 1), 40 | ) 41 | 42 | self.folding2 = nn.Sequential( 43 | nn.Conv1d(in_channel + 3, hidden_dim, 1), 44 | nn.BatchNorm1d(hidden_dim), 45 | nn.ReLU(inplace=True), 46 | nn.Conv1d(hidden_dim, hidden_dim//2, 1), 47 | nn.BatchNorm1d(hidden_dim//2), 48 | nn.ReLU(inplace=True), 49 | nn.Conv1d(hidden_dim//2, 3, 1), 50 | ) 51 | 52 | def forward(self, x): 53 | num_sample = self.step * self.step 54 | bs = x.size(0) 55 | features = x.view(bs, self.in_channel, 1).expand(bs, self.in_channel, num_sample) 56 | seed = self.folding_seed.view(1, 2, num_sample).expand(bs, 2, num_sample).to(x.device) 57 | 58 | x = torch.cat([seed, features], dim=1) 59 | fd1 = self.folding1(x) 60 | x = torch.cat([fd1, features], dim=1) 61 | fd2 = self.folding2(x) 62 | 63 | return fd2 64 | 65 | @MODELS.register_module() 66 | class PoinTr(nn.Module): 67 | def __init__(self,config, **kwargs): 68 | super().__init__() 69 | self.trans_dim = config.trans_dim 70 | self.knn_layer = config.knn_layer 71 | self.num_pred = config.num_pred 72 | self.num_query = config.num_query 73 | self.model_sap = Encode2Points("SAP/configs/learning_based/noise_small/ours.yaml") 74 | self.psr2mesh = PSR2Mesh.apply 75 | self.fold_step = int(pow(self.num_pred//self.num_query, 0.5) + 0.5) 76 | self.base_model = PCTransformer(in_chans = 3, embed_dim = self.trans_dim, depth = [6, 8], drop_rate = 0., num_query = self.num_query, knn_layer = self.knn_layer) 77 | self.foldingnet = Fold(self.trans_dim, step = self.fold_step, hidden_dim = 256) # rebuild a cluster point 78 | self.dpsr = DPSR(res=(128,128, 128), sig = 2) 79 | 80 | 81 | 82 | self.increase_dim = nn.Sequential( 83 | nn.Conv1d(self.trans_dim, 1024, 1), 84 | nn.BatchNorm1d(1024), 85 | nn.LeakyReLU(negative_slope=0.2), 86 | nn.Conv1d(1024, 1024, 1) 87 | ) 88 | self.reduce_map = nn.Linear(self.trans_dim + 1027, self.trans_dim) 89 | 
self.build_loss_func() 90 | 91 | def build_loss_func(self): 92 | self.loss_func = ChamferDistanceL1() 93 | 94 | def get_loss(self, ret, gt): 95 | loss_fine = self.loss_func(ret, gt) 96 | #loss_fine = self.loss_func(ret[1], gt) 97 | return loss_fine 98 | 99 | 100 | def forward(self,xyz,min_gt,max_gt,value_std_pc,value_centroid): 101 | 102 | 103 | q, coarse_point_cloud = self.base_model(xyz) # B M C and B M 3 104 | 105 | B, M ,C = q.shape 106 | 107 | global_feature = self.increase_dim(q.transpose(1,2)).transpose(1,2) # B M 1024 108 | global_feature = torch.max(global_feature, dim=1)[0] # B 1024 109 | #print(global_feature.shape) 110 | rebuild_feature = torch.cat([ 111 | global_feature.unsqueeze(-2).expand(-1, M, -1), 112 | q, 113 | coarse_point_cloud], dim=-1) # B M 1027 + C 114 | 115 | rebuild_feature = self.reduce_map(rebuild_feature.reshape(B*M, -1)) # BM C 116 | # # NOTE: try to rebuild pc 117 | # coarse_point_cloud = self.refine_coarse(rebuild_feature).reshape(B, M, 3) 118 | #print(rebuild_feature.shape) 119 | # NOTE: foldingNet 120 | relative_xyz = self.foldingnet(rebuild_feature).reshape(B, M, 3, -1) # B M 3 S 121 | #print(relative_xyz.shape) 122 | rebuild_points = (relative_xyz + coarse_point_cloud.unsqueeze(-1)).transpose(2,3).reshape(B, -1, 3) # B N 3 123 | ##print(rebuild_points.shape) 124 | 125 | 126 | # cat the input 127 | #inp_sparse = fps(xyz, self.num_query) 128 | 129 | # denormalize the data based on mean and std 130 | value_std_points=value_std_pc.view((rebuild_points.shape[0],1,3)) 131 | value_centroid_points=value_centroid.view((rebuild_points.shape[0],1,3)) 132 | De_point = torch.multiply(rebuild_points, value_std_points) + value_centroid_points 133 | 134 | #Normalize data to min and max on gt 135 | min_depoint=min_gt.view(rebuild_points.shape[0],1,1) 136 | max_depoint = max_gt.view(rebuild_points.shape[0],1,1) 137 | 138 | Npoints = torch.div(torch.subtract(De_point,min_depoint),torch.subtract((max_depoint + 1),min_depoint)) 139 | 140 | #SAP 141 
| out = self.model_sap(Npoints) 142 | points, normals = out 143 | psr_grid= self.dpsr(points, normals) 144 | 145 | #return psr_grid,points,rebuild_points,min_depoint,max_depoint 146 | return psr_grid,rebuild_points 147 | 148 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model_from_cfg 2 | import models.PoinTr 3 | import SAP.src.model 4 | import SAP.src.dpsr -------------------------------------------------------------------------------- /models/__pycache__/FoldingNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/FoldingNet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/GRNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/GRNet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/PCN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/PCN.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/PoinTr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/PoinTr.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/TopNet.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/TopNet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/Transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/Transformer.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/dgcnn_group.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/dgcnn_group.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/easy_mesh_vtk.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/easy_mesh_vtk.cpython-38.pyc 
-------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | from utils import registry 2 | 3 | 4 | MODELS = registry.Registry('models') 5 | 6 | 7 | def build_model_from_cfg(cfg, **kwargs): 8 | """ 9 | Build a dataset, defined by `dataset_name`. 10 | Args: 11 | cfg (eDICT): 12 | Returns: 13 | Dataset: a constructed dataset specified by dataset_name. 14 | """ 15 | return MODELS.build(cfg, **kwargs) 16 | 17 | 18 | -------------------------------------------------------------------------------- /models/dgcnn_group.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from pointnet2_ops import pointnet2_utils 4 | from knn_cuda import KNN 5 | knn = KNN(k=16, transpose_mode=False) 6 | 7 | 8 | class DGCNN_Grouper(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | ''' 12 | K has to be 16 13 | ''' 14 | self.input_trans = nn.Conv1d(3, 8, 1) 15 | 16 | self.layer1 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=1, bias=False), 17 | nn.GroupNorm(4, 32), 18 | nn.LeakyReLU(negative_slope=0.2) 19 | ) 20 | 21 | self.layer2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), 22 | nn.GroupNorm(4, 64), 23 | nn.LeakyReLU(negative_slope=0.2) 24 | ) 25 | 26 | self.layer3 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=1, bias=False), 27 | nn.GroupNorm(4, 64), 28 | nn.LeakyReLU(negative_slope=0.2) 29 | ) 30 | 31 | self.layer4 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=1, bias=False), 32 | nn.GroupNorm(4, 128), 33 | nn.LeakyReLU(negative_slope=0.2) 34 | ) 35 | 36 | 37 | @staticmethod 38 | def fps_downsample(coor, x, num_group): 39 | xyz = coor.transpose(1, 2).contiguous() # b, n, 3 40 | fps_idx = pointnet2_utils.furthest_point_sample(xyz, num_group) 41 | 42 | combined_x = torch.cat([coor, x], dim=1) 43 | 44 | new_combined_x = ( 45 | pointnet2_utils.gather_operation( 
46 | combined_x, fps_idx 47 | ) 48 | ) 49 | 50 | new_coor = new_combined_x[:, :3] 51 | new_x = new_combined_x[:, 3:] 52 | 53 | return new_coor, new_x 54 | 55 | @staticmethod 56 | def get_graph_feature(coor_q, x_q, coor_k, x_k): 57 | 58 | # coor: bs, 3, np, x: bs, c, np 59 | 60 | k = 16 61 | batch_size = x_k.size(0) 62 | num_points_k = x_k.size(2) 63 | num_points_q = x_q.size(2) 64 | 65 | with torch.no_grad(): 66 | _, idx = knn(coor_k, coor_q) # bs k np 67 | assert idx.shape[1] == k 68 | idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k 69 | idx = idx + idx_base 70 | idx = idx.view(-1) 71 | num_dims = x_k.size(1) 72 | x_k = x_k.transpose(2, 1).contiguous() 73 | feature = x_k.view(batch_size * num_points_k, -1)[idx, :] 74 | feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous() 75 | x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k) 76 | feature = torch.cat((feature - x_q, x_q), dim=1) 77 | return feature 78 | 79 | def forward(self, x): 80 | 81 | # x: bs, 3, np 82 | 83 | # bs 3 N(128) bs C(224)128 N(128) 84 | coor = x 85 | f = self.input_trans(x) 86 | 87 | f = self.get_graph_feature(coor, f, coor, f) 88 | f = self.layer1(f) 89 | f = f.max(dim=-1, keepdim=False)[0] 90 | 91 | f = self.get_graph_feature(coor, f, coor, f) 92 | f = self.layer2(f) 93 | f = f.max(dim=-1, keepdim=False)[0] 94 | 95 | f = self.get_graph_feature(coor, f, coor, f) 96 | f = self.layer3(f) 97 | f = f.max(dim=-1, keepdim=False)[0] 98 | 99 | f = self.get_graph_feature(coor, f, coor, f) 100 | f = self.layer4(f) 101 | f = f.max(dim=-1, keepdim=False)[0] 102 | 103 | return coor, f -------------------------------------------------------------------------------- /readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | argparse 2 | easydict 3 | h5py 4 | matplotlib 5 | numpy 6 | open3d==0.9 7 | opencv-python 8 | pyyaml 9 | scipy 10 | tensorboardX 11 | timm==0.4.5 12 | tqdm 13 | transforms3d 14 | gcc,cuda 15 | igl 16 | trimesh tensorboard plyfile open3d scikit-image python-mnist opencv-python av pykdtree ipdb 17 | Pip install https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl 18 | pip install --no-index pytorch3d 19 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+${CUDA}.html 20 | Pip install "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib" 21 | Pip install –no-index torch torchvision torchaudio -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .runner import run_net 2 | from .runner import test_net -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/tools/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/tools/__pycache__/builder.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/runner.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/tools/__pycache__/runner.cpython-38.pyc -------------------------------------------------------------------------------- /tools/builder.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | # online package 3 | import torch 4 | # optimizer 5 | import torch.optim as optim 6 | # dataloader 7 | from datasets import build_dataset_from_cfg 8 | from models import build_model_from_cfg 9 | # utils 10 | from utils.logger import * 11 | from utils.misc import * 12 | from collections import OrderedDict 13 | 14 | def dataset_builder(args, config): 15 | dataset = build_dataset_from_cfg(config._base_, config.others) 16 | shuffle = config.others.subset == 'train' 17 | if args.distributed: 18 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle = shuffle) 19 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = config.others.bs if shuffle else 1, 20 | num_workers = int(args.num_workers), 21 | drop_last = config.others.subset == 'train', 22 | worker_init_fn = worker_init_fn, 23 | sampler = sampler) 24 | else: 25 | sampler = None 26 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.others.bs if shuffle else 1, 27 | shuffle = shuffle, 28 | drop_last = config.others.subset == 'train', 29 | num_workers = int(args.num_workers), 30 | worker_init_fn=worker_init_fn) 31 | return sampler, dataloader 32 | 33 | def model_builder(config): 34 | model = build_model_from_cfg(config) 35 | return model 36 | 37 | def build_opti_sche(base_model, config): 38 | opti_config = config.optimizer 39 | if opti_config.type == 'AdamW': 40 | optimizer = optim.AdamW(base_model.parameters(), **opti_config.kwargs) 41 | elif opti_config.type == 'Adam': 42 | optimizer = optim.Adam(base_model.parameters(), **opti_config.kwargs) 43 | elif opti_config.type == 'SGD': 44 | optimizer = optim.SGD(base_model.parameters(), 
nesterov=True, **opti_config.kwargs) 45 | else: 46 | raise NotImplementedError() 47 | 48 | sche_config = config.scheduler 49 | if sche_config.type == 'LambdaLR': 50 | scheduler = build_lambda_sche(optimizer, sche_config.kwargs) # misc.py 51 | elif sche_config.type == 'StepLR': 52 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **sche_config.kwargs) 53 | else: 54 | raise NotImplementedError() 55 | 56 | if config.get('bnmscheduler') is not None: 57 | bnsche_config = config.bnmscheduler 58 | if bnsche_config.type == 'Lambda': 59 | bnscheduler = build_lambda_bnsche(base_model, bnsche_config.kwargs) # misc.py 60 | scheduler = [scheduler, bnscheduler] 61 | 62 | return optimizer, scheduler 63 | 64 | def resume_model(base_model, args, logger = None): 65 | ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth') 66 | if not os.path.exists(ckpt_path): 67 | print_log(f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger = logger) 68 | return 0, 0 69 | print_log(f'[RESUME INFO] Loading model weights from {ckpt_path}...', logger = logger ) 70 | 71 | # load state dict 72 | map_location = {'cuda:%d' % 0: 'cuda:%d' % args.local_rank} 73 | state_dict = torch.load(ckpt_path, map_location=map_location) 74 | # parameter resume of base model 75 | # if args.local_rank == 0: 76 | base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['base_model'].items()} 77 | base_model.load_state_dict(base_ckpt) 78 | 79 | # parameter 80 | start_epoch = state_dict['epoch'] + 1 81 | best_metrics = state_dict['best_metrics'] 82 | if not isinstance(best_metrics, dict): 83 | best_metrics = best_metrics.state_dict() 84 | # print(best_metrics) 85 | 86 | print_log(f'[RESUME INFO] resume ckpts @ {start_epoch - 1} epoch( best_metrics = {str(best_metrics):s})', logger = logger) 87 | return start_epoch, best_metrics 88 | 89 | def resume_optimizer(optimizer, args, logger = None): 90 | ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth') 91 | if not 
os.path.exists(ckpt_path): 92 | print_log(f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger = logger) 93 | return 0, 0, 0 94 | print_log(f'[RESUME INFO] Loading optimizer from {ckpt_path}...', logger = logger ) 95 | # load state dict 96 | state_dict = torch.load(ckpt_path, map_location='cpu') 97 | # optimizer 98 | optimizer.load_state_dict(state_dict['optimizer']) 99 | 100 | def save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, prefix, args, logger = None): 101 | if args.local_rank == 0: 102 | torch.save({ 103 | 'base_model' : base_model.module.state_dict() if args.distributed else base_model.state_dict(), 104 | 'optimizer' : optimizer.state_dict(), 105 | 'epoch' : epoch, 106 | 'metrics' : metrics if metrics is not None else dict(), 107 | 'best_metrics' : best_metrics if best_metrics is not None else dict(), 108 | }, os.path.join(args.experiment_path, prefix + '.pth')) 109 | print_log(f"Save checkpoint at {os.path.join(args.experiment_path, prefix + '.pth')}", logger = logger) 110 | 111 | def load_model(base_model, ckpt_path, logger = None): 112 | if not os.path.exists(ckpt_path): 113 | raise NotImplementedError('no checkpoint file from path %s...' 
% ckpt_path) 114 | print_log(f'Loading weights from {ckpt_path}...', logger = logger ) 115 | 116 | # load state dict 117 | state_dict = torch.load(ckpt_path, map_location='cpu') 118 | load_model_manual(state_dict['base_model'], base_model) 119 | # parameter resume of base model 120 | if state_dict.get('model') is not None: 121 | base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['model'].items()} 122 | elif state_dict.get('base_model') is not None: 123 | base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['base_model'].items()} 124 | else: 125 | raise RuntimeError('mismatch of ckpt weight') 126 | #base_model.load_state_dict(base_ckpt) 127 | 128 | epoch = -1 129 | if state_dict.get('epoch') is not None: 130 | epoch = state_dict['epoch'] 131 | if state_dict.get('metrics') is not None: 132 | metrics = state_dict['metrics'] 133 | if not isinstance(metrics, dict): 134 | metrics = metrics 135 | # metrics = metrics.state_dict() 136 | else: 137 | metrics = 'No Metrics' 138 | print_log(f'ckpts @ {epoch} epoch( performance = {str(metrics):s})', logger = logger) 139 | return 140 | 141 | def load_model_manual(state_dict, model): 142 | new_state_dict = OrderedDict() 143 | is_model_parallel = isinstance(model, torch.nn.DataParallel) 144 | for k, v in state_dict.items(): 145 | if k.startswith('module.') != is_model_parallel: 146 | if k.startswith('module.'): 147 | # remove module 148 | k = k[7:] 149 | else: 150 | # add module 151 | k = 'module.' 
+ k 152 | 153 | new_state_dict[k]=v 154 | 155 | model.load_state_dict(new_state_dict,strict=False) -------------------------------------------------------------------------------- /utils/AverageMeter.py: -------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(object): 3 | def __init__(self, items=None): 4 | self.items = items 5 | self.n_items = 1 if items is None else len(items) 6 | self.reset() 7 | 8 | def reset(self): 9 | self._val = [0] * self.n_items 10 | self._sum = [0] * self.n_items 11 | self._count = [0] * self.n_items 12 | 13 | def update(self, values): 14 | if type(values).__name__ == 'list': 15 | for idx, v in enumerate(values): 16 | self._val[idx] = v 17 | self._sum[idx] += v 18 | self._count[idx] += 1 19 | else: 20 | self._val[0] = values 21 | self._sum[0] += values 22 | self._count[0] += 1 23 | 24 | def val(self, idx=None): 25 | if idx is None: 26 | return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)] 27 | else: 28 | return self._val[idx] 29 | 30 | def count(self, idx=None): 31 | if idx is None: 32 | return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)] 33 | else: 34 | return self._count[idx] 35 | 36 | def avg(self, idx=None): 37 | if idx is None: 38 | return self._sum[0] / self._count[0] if self.items is None else [ 39 | self._sum[i] / self._count[i] for i in range(self.n_items) 40 | ] 41 | else: 42 | return self._sum[idx] / self._count[idx] -------------------------------------------------------------------------------- /utils/__pycache__/AverageMeter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/AverageMeter.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dist_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/dist_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/parser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/parser.cpython-38.pyc -------------------------------------------------------------------------------- 
/utils/__pycache__/registry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/registry.cpython-38.pyc -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from easydict import EasyDict 3 | import os 4 | from .logger import print_log 5 | 6 | def log_args_to_file(args, pre='args', logger=None): 7 | for key, val in args.__dict__.items(): 8 | print_log(f'{pre}.{key} : {val}', logger = logger) 9 | 10 | def log_config_to_file(cfg, pre='cfg', logger=None): 11 | for key, val in cfg.items(): 12 | if isinstance(cfg[key], EasyDict): 13 | print_log(f'{pre}.{key} = edict()', logger = logger) 14 | log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger) 15 | continue 16 | print_log(f'{pre}.{key} : {val}', logger = logger) 17 | 18 | def merge_new_config(config, new_config): 19 | for key, val in new_config.items(): 20 | if not isinstance(val, dict): 21 | if key == '_base_': 22 | with open(new_config['_base_'], 'r') as f: 23 | try: 24 | val = yaml.load(f, Loader=yaml.FullLoader) 25 | except: 26 | val = yaml.load(f) 27 | config[key] = EasyDict() 28 | merge_new_config(config[key], val) 29 | else: 30 | config[key] = val 31 | continue 32 | if key not in config: 33 | config[key] = EasyDict() 34 | merge_new_config(config[key], val) 35 | return config 36 | 37 | def cfg_from_yaml_file(cfg_file): 38 | config = EasyDict() 39 | with open(cfg_file, 'r') as f: 40 | try: 41 | new_config = yaml.load(f, Loader=yaml.FullLoader) 42 | except: 43 | new_config = yaml.load(f) 44 | merge_new_config(config=config, new_config=new_config) 45 | return config 46 | 47 | def get_config(args, logger=None): 48 | if args.resume: 49 | cfg_path = os.path.join(args.experiment_path, 'config.yaml') 50 | if not 
os.path.exists(cfg_path): 51 | print_log("Failed to resume", logger = logger) 52 | raise FileNotFoundError() 53 | print_log(f'Resume yaml from {cfg_path}', logger = logger) 54 | args.config = cfg_path 55 | config = cfg_from_yaml_file(args.config) 56 | if not args.resume and args.local_rank == 0: 57 | save_experiment_config(args, config, logger) 58 | return config 59 | 60 | def load_config(path, default_path=None): 61 | ''' Loads config file. 62 | 63 | Args: 64 | path (str): path to config file 65 | default_path (bool): whether to use default path 66 | ''' 67 | # Load configuration from file itself 68 | with open(path, 'r') as f: 69 | cfg_special = yaml.load(f, Loader=yaml.Loader) 70 | 71 | # Check if we should inherit from a config 72 | inherit_from = cfg_special.get('inherit_from') 73 | 74 | # If yes, load this config first as default 75 | # If no, use the default_path 76 | if inherit_from is not None: 77 | cfg = load_config(inherit_from, default_path) 78 | elif default_path is not None: 79 | with open(default_path, 'r') as f: 80 | cfg = yaml.load(f, Loader=yaml.Loader) 81 | else: 82 | cfg = dict() 83 | 84 | # Include main configuration 85 | update_recursive(cfg, cfg_special) 86 | 87 | return cfg 88 | 89 | def update_recursive(dict1, dict2): 90 | ''' Update two config dictionaries recursively. 
91 | 92 | Args: 93 | dict1 (dict): first dictionary to be updated 94 | dict2 (dict): second dictionary which entries should be used 95 | 96 | ''' 97 | for k, v in dict2.items(): 98 | if k not in dict1: 99 | dict1[k] = dict() 100 | if isinstance(v, dict): 101 | update_recursive(dict1[k], v) 102 | else: 103 | dict1[k] = v 104 | 105 | def save_experiment_config(args, config, logger = None): 106 | config_path = os.path.join(args.experiment_path, 'config.yaml') 107 | os.system('cp %s %s' % (args.config, config_path)) 108 | print_log(f'Copy the Config file from {args.config} to {config_path}',logger = logger ) -------------------------------------------------------------------------------- /utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.multiprocessing as mp 5 | from torch import distributed as dist 6 | 7 | 8 | 9 | def init_dist(launcher, backend='nccl', **kwargs): 10 | if mp.get_start_method(allow_none=True) is None: 11 | mp.set_start_method('spawn') 12 | if launcher == 'pytorch': 13 | _init_dist_pytorch(backend, **kwargs) 14 | else: 15 | raise ValueError(f'Invalid launcher type: {launcher}') 16 | 17 | 18 | def _init_dist_pytorch(backend, **kwargs): 19 | # TODO: use local_rank instead of rank % num_gpus 20 | rank = int(os.environ['RANK']) 21 | num_gpus = torch.cuda.device_count() 22 | torch.cuda.set_device(rank % num_gpus) 23 | dist.init_process_group(backend=backend, **kwargs) 24 | print(f'init distributed in rank {torch.distributed.get_rank()}') 25 | 26 | 27 | def get_dist_info(): 28 | if dist.is_available(): 29 | initialized = dist.is_initialized() 30 | else: 31 | initialized = False 32 | if initialized: 33 | rank = dist.get_rank() 34 | world_size = dist.get_world_size() 35 | else: 36 | rank = 0 37 | world_size = 1 38 | return rank, world_size 39 | 40 | 41 | def reduce_tensor(tensor, args): 42 | ''' 43 | for acc kind, get the mean in each gpu 44 | ''' 45 | rt = 
tensor.clone() 46 | torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM) 47 | rt /= args.world_size 48 | return rt 49 | 50 | def gather_tensor(tensor, args): 51 | output_tensors = [tensor.clone() for _ in range(args.world_size)] 52 | torch.distributed.all_gather(output_tensors, tensor) 53 | concat = torch.cat(output_tensors, dim=0) 54 | return concat 55 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. Defaults to 'mmdet3d'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # add a logging filter 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. 
During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. 
Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 
114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-08 14:31:30 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-05-25 09:13:32 6 | # @Email: cshzxie@gmail.com 7 | 8 | import logging 9 | import open3d 10 | 11 | from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2 12 | 13 | 14 | class Metrics(object): 15 | ITEMS = [{ 16 | 'name': 'F-Score', 17 | 'enabled': True, 18 | 'eval_func': 'cls._get_f_score', 19 | 'is_greater_better': True, 20 | 'init_value': 0 21 | }, { 22 | 'name': 'CDL1', 23 | 'enabled': True, 24 | 'eval_func': 'cls._get_chamfer_distancel1', 25 | 'eval_object': ChamferDistanceL1(ignore_zeros=True), 26 | 'is_greater_better': False, 27 | 'init_value': 32767 28 | }, { 29 | 'name': 'CDL2', 30 | 'enabled': True, 31 | 'eval_func': 'cls._get_chamfer_distancel2', 32 | 'eval_object': ChamferDistanceL2(ignore_zeros=True), 33 | 'is_greater_better': False, 34 | 'init_value': 32767 35 | }] 36 | 37 | @classmethod 38 | def get(cls, pred, gt): 39 | _items = cls.items() 40 | _values = [0] * len(_items) 41 | for i, item in enumerate(_items): 42 | eval_func = eval(item['eval_func']) 43 | _values[i] = eval_func(pred, gt) 44 | 45 | return _values 46 | 47 | @classmethod 48 | def items(cls): 49 | return [i for i in cls.ITEMS if i['enabled']] 50 | 51 | @classmethod 52 | def names(cls): 53 | _items = cls.items() 54 | return 
[i['name'] for i in _items] 55 | 56 | @classmethod 57 | def _get_f_score(cls, pred, gt, th=0.01): 58 | 59 | """References: https://github.com/lmb-freiburg/what3d/blob/master/util.py""" 60 | b = pred.size(0) 61 | assert pred.size(0) == gt.size(0) 62 | if b != 1: 63 | f_score_list = [] 64 | for idx in range(b): 65 | f_score_list.append(cls._get_f_score(pred[idx:idx+1], gt[idx:idx+1])) 66 | return sum(f_score_list)/len(f_score_list) 67 | else: 68 | pred = cls._get_open3d_ptcloud(pred) 69 | gt = cls._get_open3d_ptcloud(gt) 70 | 71 | dist1 = pred.compute_point_cloud_distance(gt) 72 | dist2 = gt.compute_point_cloud_distance(pred) 73 | 74 | recall = float(sum(d < th for d in dist2)) / float(len(dist2)) 75 | precision = float(sum(d < th for d in dist1)) / float(len(dist1)) 76 | return 2 * recall * precision / (recall + precision) if recall + precision else 0 77 | 78 | @classmethod 79 | def _get_open3d_ptcloud(cls, tensor): 80 | """pred and gt bs is 1""" 81 | tensor = tensor.squeeze().cpu().numpy() 82 | ptcloud = open3d.geometry.PointCloud() 83 | ptcloud.points = open3d.utility.Vector3dVector(tensor) 84 | 85 | return ptcloud 86 | 87 | @classmethod 88 | def _get_chamfer_distancel1(cls, pred, gt): 89 | chamfer_distance = cls.ITEMS[1]['eval_object'] 90 | return chamfer_distance(pred, gt).item() * 1000 91 | 92 | @classmethod 93 | def _get_chamfer_distancel2(cls, pred, gt): 94 | chamfer_distance = cls.ITEMS[2]['eval_object'] 95 | return chamfer_distance(pred, gt).item() * 1000 96 | 97 | def __init__(self, metric_name, values): 98 | self._items = Metrics.items() 99 | self._values = [item['init_value'] for item in self._items] 100 | self.metric_name = metric_name 101 | 102 | if type(values).__name__ == 'list': 103 | self._values = values 104 | elif type(values).__name__ == 'dict': 105 | metric_indexes = {} 106 | for idx, item in enumerate(self._items): 107 | item_name = item['name'] 108 | metric_indexes[item_name] = idx 109 | for k, v in values.items(): 110 | if k not in 
metric_indexes: 111 | logging.warn('Ignore Metric[Name=%s] due to disability.' % k) 112 | continue 113 | self._values[metric_indexes[k]] = v 114 | else: 115 | raise Exception('Unsupported value type: %s' % type(values)) 116 | 117 | def state_dict(self): 118 | _dict = dict() 119 | for i in range(len(self._items)): 120 | item = self._items[i]['name'] 121 | value = self._values[i] 122 | _dict[item] = value 123 | 124 | return _dict 125 | 126 | def __repr__(self): 127 | return str(self.state_dict()) 128 | 129 | def better_than(self, other): 130 | if other is None: 131 | return True 132 | 133 | _index = -1 134 | for i, _item in enumerate(self._items): 135 | if _item['name'] == self.metric_name: 136 | _index = i 137 | break 138 | if _index == -1: 139 | raise Exception('Invalid metric name to compare.') 140 | 141 | _metric = self._items[i] 142 | _value = self._values[_index] 143 | other_value = other._values[_index] 144 | return _value > other_value if _metric['is_greater_better'] else _value < other_value 145 | -------------------------------------------------------------------------------- /utils/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | '--config', 9 | type = str, 10 | default='cfgs/Tooth_models/PoinTr.yaml', 11 | help = 'yaml config file') 12 | parser.add_argument( 13 | '--config_SAP', 14 | type=str, 15 | default='SAP/configs/learning_based/noise_small/ours.yaml', 16 | help='yaml config file') 17 | parser.add_argument( 18 | '--launcher', 19 | choices=['none', 'pytorch'], 20 | default='none', 21 | help='job launcher') 22 | parser.add_argument('--local_rank', type=int, default=0) 23 | parser.add_argument('--num_workers', type=int, default=4) 24 | # seed 25 | parser.add_argument('--seed', type=int, default=0, help='random seed') 26 | parser.add_argument( 27 | 
'--deterministic', 28 | action='store_true', 29 | help='whether to set deterministic options for CUDNN backend.') 30 | # bn 31 | parser.add_argument( 32 | '--sync_bn', 33 | action='store_true', 34 | default=False, 35 | help='whether to use sync bn') 36 | # some args 37 | parser.add_argument('--exp_name', type = str, default='default', help = 'experiment name') 38 | parser.add_argument('--start_ckpts', type = str, default=None, help = 'reload used ckpt path') 39 | parser.add_argument('--ckpts', type = str, default=None, help = 'test used ckpt path') 40 | parser.add_argument('--val_freq', type = int, default=1, help = 'test freq') 41 | parser.add_argument( 42 | '--resume', 43 | action='store_true', 44 | default=False, 45 | help = 'autoresume training (interrupted by accident)') 46 | parser.add_argument( 47 | '--test', 48 | action='store_true', 49 | default=False, 50 | help = 'test mode for certain ckpt') 51 | parser.add_argument( 52 | '--mode', 53 | choices=['easy', 'median', 'hard', None], 54 | default=None, 55 | help = 'difficulty mode for shapenet') 56 | args = parser.parse_args() 57 | 58 | if args.test and args.resume: 59 | raise ValueError( 60 | '--test and --resume cannot be both activate') 61 | 62 | if args.resume and args.start_ckpts is not None: 63 | raise ValueError( 64 | '--resume and --start_ckpts cannot be both activate') 65 | 66 | if args.test and args.ckpts is None: 67 | raise ValueError( 68 | 'ckpts shouldnt be None while test mode') 69 | 70 | if 'LOCAL_RANK' not in os.environ: 71 | os.environ['LOCAL_RANK'] = str(args.local_rank) 72 | 73 | if args.test: 74 | args.exp_name = 'test_' + args.exp_name 75 | if args.mode is not None: 76 | args.exp_name = args.exp_name + '_' +args.mode 77 | args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, args.exp_name) 78 | args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem,'TFBoard' ,args.exp_name) 79 | 
args.log_name = Path(args.config).stem 80 | create_experiment_dir(args) 81 | return args 82 | 83 | def create_experiment_dir(args): 84 | if not os.path.exists(args.experiment_path): 85 | os.makedirs(args.experiment_path) 86 | print('Create experiment path successfully at %s' % args.experiment_path) 87 | if not os.path.exists(args.tfboard_path): 88 | os.makedirs(args.tfboard_path) 89 | print('Create TFBoard path successfully at %s' % args.tfboard_path) 90 | 91 | --------------------------------------------------------------------------------