├── SAP
├── LICENSE
├── README.md
├── configs
│ ├── default.yaml
│ ├── learning_based
│ │ ├── demo_large_noise.yaml
│ │ ├── demo_outlier.yaml
│ │ ├── noise_large
│ │ │ ├── ours.yaml
│ │ │ └── ours_pretrained.yaml
│ │ ├── noise_small
│ │ │ ├── ours.yaml
│ │ │ └── ours_pretrained.yaml
│ │ └── outlier
│ │ │ ├── ours_1x.yaml
│ │ │ ├── ours_1x_pretrained.yaml
│ │ │ ├── ours_3plane.yaml
│ │ │ ├── ours_3x.yaml
│ │ │ ├── ours_3x_pretrained.yaml
│ │ │ ├── ours_5x.yaml
│ │ │ ├── ours_5x_pretrained.yaml
│ │ │ ├── ours_7x.yaml
│ │ │ └── ours_7x_pretrained.yaml
│ └── optim_based
│ │ ├── dfaust.yaml
│ │ ├── dgp.yaml
│ │ ├── teaser.yaml
│ │ ├── thingi.yaml
│ │ └── thingi_noisy.yaml
├── eval_meshes.py
├── generate.py
├── optim.py
├── optim_hierarchy.py
├── scripts
│ ├── __pycache__
│ │ └── easy_mesh_vtk.cpython-38.pyc
│ ├── download_demo_data.sh
│ ├── download_optim_data.sh
│ ├── download_shapenet.sh
│ ├── easy_mesh_vtk.py
│ ├── prcocess_tooth.py
│ ├── process_shapenet.py
│ └── process_tooth.py
└── src
│ ├── __init__.py
│ ├── __pycache__
│ ├── config.cpython-38.pyc
│ ├── dpsr.cpython-38.pyc
│ ├── generation.cpython-38.pyc
│ ├── model.cpython-38.pyc
│ ├── training.cpython-38.pyc
│ └── utils.cpython-38.pyc
│ ├── config.py
│ ├── data
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── core.cpython-38.pyc
│ │ ├── fields.cpython-38.pyc
│ │ └── transforms.cpython-38.pyc
│ ├── core.py
│ ├── fields.py
│ └── transforms.py
│ ├── data_loader.py
│ ├── dpsr.py
│ ├── eval.py
│ ├── generation.py
│ ├── model.py
│ ├── model_rgb.py
│ ├── network
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── decoder.cpython-38.pyc
│ │ ├── encoder.cpython-38.pyc
│ │ ├── unet.cpython-38.pyc
│ │ ├── unet3d.cpython-38.pyc
│ │ └── utils.cpython-38.pyc
│ ├── decoder.py
│ ├── encoder.py
│ ├── net_rgb.py
│ ├── unet.py
│ ├── unet3d.py
│ └── utils.py
│ ├── optimization.py
│ ├── training.py
│ ├── utils.py
│ └── visualize.py
├── cfgs
├── Tooth_models
│ └── PoinTr.yaml
└── dataset_configs
│ └── Tooth.yaml
├── data
└── dental
│ └── crown
│ ├── test.txt
│ └── train.txt
├── datasets
├── __init__.py
├── __pycache__
│ ├── KITTIDataset.cpython-38.pyc
│ ├── PCNDataset.cpython-38.pyc
│ ├── ShapeNet55Dataset.cpython-38.pyc
│ ├── Wrapping_Python_vtk_util_numpy_support.cpython-38.pyc
│ ├── __init__.cpython-38.pyc
│ ├── build.cpython-38.pyc
│ ├── crowndataset.cpython-38.pyc
│ ├── data_transforms.cpython-38.pyc
│ ├── easy_mesh_vtk.cpython-38.pyc
│ └── io.cpython-38.pyc
├── build.py
├── crowndataset.py
└── easy_mesh_vtk.py
├── extensions
├── chamfer_dist
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-38.pyc
│ ├── build
│ │ ├── lib.linux-x86_64-3.8
│ │ │ └── chamfer.cpython-38-x86_64-linux-gnu.so
│ │ └── temp.linux-x86_64-3.8
│ │ │ ├── build.ninja
│ │ │ ├── chamfer.o
│ │ │ └── chamfer_cuda.o
│ ├── chamfer.cu
│ ├── chamfer.egg-info
│ │ ├── PKG-INFO
│ │ ├── SOURCES.txt
│ │ ├── dependency_links.txt
│ │ └── top_level.txt
│ ├── chamfer_cuda.cpp
│ ├── dist
│ │ └── chamfer-2.0.0-py3.8-linux-x86_64.egg
│ ├── setup.py
│ └── test.py
├── cubic_feature_sampling
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-38.pyc
│ ├── build
│ │ ├── lib.linux-x86_64-3.8
│ │ │ └── cubic_feature_sampling.cpython-38-x86_64-linux-gnu.so
│ │ └── temp.linux-x86_64-3.8
│ │ │ ├── build.ninja
│ │ │ ├── cubic_feature_sampling.o
│ │ │ └── cubic_feature_sampling_cuda.o
│ ├── cubic_feature_sampling.cu
│ ├── cubic_feature_sampling.egg-info
│ │ ├── PKG-INFO
│ │ ├── SOURCES.txt
│ │ ├── dependency_links.txt
│ │ └── top_level.txt
│ ├── cubic_feature_sampling_cuda.cpp
│ ├── dist
│ │ └── cubic_feature_sampling-1.1.0-py3.8-linux-x86_64.egg
│ ├── setup.py
│ └── test.py
├── gridding
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-38.pyc
│ ├── build
│ │ ├── lib.linux-x86_64-3.8
│ │ │ └── gridding.cpython-38-x86_64-linux-gnu.so
│ │ └── temp.linux-x86_64-3.8
│ │ │ ├── build.ninja
│ │ │ ├── gridding.o
│ │ │ ├── gridding_cuda.o
│ │ │ └── gridding_reverse.o
│ ├── dist
│ │ └── gridding-2.1.0-py3.8-linux-x86_64.egg
│ ├── gridding.cu
│ ├── gridding.egg-info
│ │ ├── PKG-INFO
│ │ ├── SOURCES.txt
│ │ ├── dependency_links.txt
│ │ └── top_level.txt
│ ├── gridding_cuda.cpp
│ ├── gridding_reverse.cu
│ ├── setup.py
│ └── test.py
└── gridding_loss
│ ├── __init__.py
│ ├── __pycache__
│ └── __init__.cpython-38.pyc
│ ├── build
│ ├── lib.linux-x86_64-3.8
│ │ └── gridding_distance.cpython-38-x86_64-linux-gnu.so
│ └── temp.linux-x86_64-3.8
│ │ ├── build.ninja
│ │ ├── gridding_distance.o
│ │ └── gridding_distance_cuda.o
│ ├── dist
│ └── gridding_distance-1.0.0-py3.8-linux-x86_64.egg
│ ├── gridding_distance.cu
│ ├── gridding_distance.egg-info
│ ├── PKG-INFO
│ ├── SOURCES.txt
│ ├── dependency_links.txt
│ └── top_level.txt
│ ├── gridding_distance_cuda.cpp
│ └── setup.py
├── install.sh
├── main.py
├── models
├── PoinTr.py
├── Transformer.py
├── __init__.py
├── __pycache__
│ ├── FoldingNet.cpython-38.pyc
│ ├── GRNet.cpython-38.pyc
│ ├── PCN.cpython-38.pyc
│ ├── PoinTr.cpython-38.pyc
│ ├── TopNet.cpython-38.pyc
│ ├── Transformer.cpython-38.pyc
│ ├── __init__.cpython-38.pyc
│ ├── build.cpython-38.pyc
│ ├── dgcnn_group.cpython-38.pyc
│ └── easy_mesh_vtk.cpython-38.pyc
├── build.py
├── dgcnn_group.py
└── easy_mesh_vtk.py
├── readme.txt
├── requirements.txt
├── tools
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ ├── builder.cpython-38.pyc
│ └── runner.cpython-38.pyc
├── builder.py
└── runner.py
└── utils
├── AverageMeter.py
├── __pycache__
├── AverageMeter.cpython-38.pyc
├── config.cpython-38.pyc
├── dist_utils.cpython-38.pyc
├── logger.cpython-38.pyc
├── metrics.cpython-38.pyc
├── misc.cpython-38.pyc
├── parser.cpython-38.pyc
└── registry.cpython-38.pyc
├── config.py
├── dist_utils.py
├── logger.py
├── metrics.py
├── misc.py
├── parser.py
└── registry.py
/SAP/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 autonomousvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/SAP/README.md:
--------------------------------------------------------------------------------
1 | # Shape As Points (SAP)
2 |
3 | ### [**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (12 min)**](https://youtu.be/TgR0NvYty0A)
4 |
5 | 
6 |
7 | This repository contains the implementation of the paper:
8 |
9 | Shape As Points: A Differentiable Poisson Solver
10 | [Songyou Peng](https://pengsongyou.github.io/), [Chiyu "Max" Jiang](https://www.maxjiang.ml/), [Yiyi Liao](https://yiyiliao.github.io/), [Michael Niemeyer](https://m-niemeyer.github.io/), [Marc Pollefeys](https://www.inf.ethz.ch/personal/pomarc/) and [Andreas Geiger](http://www.cvlibs.net/)
11 | **NeurIPS 2021 (Oral)**
12 |
13 |
14 | If you find our code or paper useful, please consider citing
15 | ```bibtex
16 | @inproceedings{Peng2021SAP,
17 | author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas},
18 | title = {Shape As Points: A Differentiable Poisson Solver},
19 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
20 | year = {2021}}
21 | ```
22 |
23 |
24 | ## Installation
25 | First you have to make sure that you have all dependencies in place.
26 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/).
27 |
28 | You can create an anaconda environment called `sap` using
29 | ```
30 | conda env create -f environment.yaml
31 | conda activate sap
32 | ```
33 |
34 | Next, you should install [PyTorch3D](https://pytorch3d.org/) (**>=0.5**) yourself by following the [official instructions](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux).
35 |
36 | And install [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter):
37 | ```sh
38 | conda install pytorch-scatter -c pyg
39 | ```
40 |
41 |
42 | ## Demo - Quick Start
43 |
44 | First, run the script to get the demo data:
45 |
46 | ```bash
47 | bash scripts/download_demo_data.sh
48 | ```
49 |
50 | ### Optimization-based 3D Surface Reconstruction
51 |
52 | You can now quickly test our code on the data shown in the teaser. To this end, simply run:
53 |
54 | ```python
55 | python optim_hierarchy.py configs/optim_based/teaser.yaml
56 | ```
57 | This script should create a folder `out/demo_optim` where the output meshes and the optimized oriented point clouds at different grid resolutions are stored.
58 |
59 | To visualize the optimization process on the fly, you can set `o3d_show: True` in [`configs/optim_based/teaser.yaml`](https://github.com/autonomousvision/shape_as_points/tree/main/configs/optim_based/teaser.yaml).
60 |
61 | ### Learning-based 3D Surface Reconstruction
62 | You can also test SAP on another application, where we reconstruct surfaces from unoriented point clouds with either **large noise** or **outliers** using a learned network.
63 |
64 | 
65 |
66 | For the point clouds with large noise as shown above, you can run:
67 | ```python
68 | python generate.py configs/learning_based/demo_large_noise.yaml
69 | ```
70 | The results can be found at `out/demo_shapenet_large_noise/generation/vis`.
71 |
72 | 
73 | As for the point clouds with outliers, you can run:
74 | ```python
75 | python generate.py configs/learning_based/demo_outlier.yaml
76 | ```
77 | You can find the reconstruction in `out/demo_shapenet_outlier/generation/vis`.
78 |
79 |
80 | ## Dataset
81 |
82 | We use different datasets for our optimization-based and learning-based settings.
83 |
84 | ### Dataset for Optimization-based Reconstruction
85 | Here we consider the following datasets:
86 | - [Thingi10K](https://arxiv.org/abs/1605.04797) (synthetic)
87 | - [Surface Reconstruction Benchmark (SRB)](https://github.com/fwilliams/deep-geometric-prior) (real scans)
88 | - [MPI Dynamic FAUST](https://dfaust.is.tue.mpg.de/) (real scans)
89 |
90 | Please cite the corresponding papers if you use the data.
91 |
92 | You can download the processed dataset (~200 MB) by running:
93 | ```bash
94 | bash scripts/download_optim_data.sh
95 | ```
96 |
97 | ### Dataset for Learning-based Reconstruction
98 | We train and evaluate on [ShapeNet](https://shapenet.org/).
99 | You can download the processed dataset (~220 GB) by running:
100 | ```bash
101 | bash scripts/download_shapenet.sh
102 | ```
103 | Afterwards, you should have the dataset in the `data/shapenet_psr` folder.
104 |
105 | Alternatively, you can also preprocess the dataset yourself. To this end, you can:
106 | * first download the preprocessed dataset (73.4 GB) by running [the script](https://github.com/autonomousvision/occupancy_networks#preprocessed-data) from Occupancy Networks.
107 | * check [`scripts/process_shapenet.py`](https://github.com/autonomousvision/shape_as_points/tree/main/scripts/process_shapenet.py), modify the base path and run the code
108 |
109 |
110 | ## Usage for Optimization-based 3D Reconstruction
111 |
112 | For our optimization-based setting, you can consider running with a coarse-to-fine strategy:
113 | ```python
114 | python optim_hierarchy.py configs/optim_based/CONFIG.yaml
115 | ```
116 | We start from a grid resolution of 32^3, and increase to 64^3, 128^3 and finally 256^3.
117 |
118 | Alternatively, you can also run on a single resolution with:
119 |
120 | ```python
121 | python optim.py configs/optim_based/CONFIG.yaml
122 | ```
123 | You might need to modify the `CONFIG.yaml` accordingly.
124 |
125 | ## Usage for Learning-based 3D Reconstruction
126 |
127 | ### Mesh Generation
128 | To generate meshes using a trained model, use
129 | ```python
130 | python generate.py configs/learning_based/CONFIG.yaml
131 | ```
132 | where you replace `CONFIG.yaml` with the correct config file.
133 |
134 | #### Use a pre-trained model
135 | The easiest way is to use a pre-trained model. You can do this by using one of the config files with postfix `_pretrained`.
136 |
137 | For example, for 3D reconstruction from point clouds with outliers using our model with 7x offsets, you can simply run:
138 | ```python
139 | python generate.py configs/learning_based/outlier/ours_7x_pretrained.yaml
140 | ```
141 |
142 | The script will automatically download the pretrained model and run the generation. You can find the outputs in the `out/.../generation_pretrained` folders.
143 |
144 | **Note**: the `_pretrained` config files are only meant for generation, not for training new models. When these configs are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model.
145 |
146 | We provide the following pretrained models:
147 | ```
148 | noise_small/ours.pt
149 | noise_large/ours.pt
150 | outlier/ours_1x.pt
151 | outlier/ours_3x.pt
152 | outlier/ours_5x.pt
153 | outlier/ours_7x.pt
154 | outlier/ours_3plane.pt
155 | ```
156 |
157 |
158 | ### Evaluation
159 | To evaluate a trained model, we provide the script [`eval_meshes.py`](https://github.com/autonomousvision/shape_as_points/blob/main/eval_meshes.py). You can run it using:
160 | ```python
161 | python eval_meshes.py configs/learning_based/CONFIG.yaml
162 | ```
163 | The script takes the meshes generated in the previous step and evaluates them using a standardized protocol. The output will be written to `.pkl` and `.csv` files in the corresponding generation folder that can be processed using [pandas](https://pandas.pydata.org/).
164 |
165 | ### Training
166 |
167 | Finally, to train a new network from scratch, simply run:
168 | ```python
169 | python train.py configs/learning_based/CONFIG.yaml
170 | ```
171 | For available training options, please take a look at `configs/default.yaml`.
172 |
173 |
--------------------------------------------------------------------------------
/SAP/configs/default.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: Shapes3D
3 | path: data/ShapeNet
4 | class: null
5 | data_type: img
6 | input_type: pointcloud
7 | dim: 3
8 | num_points: 1000
9 | num_gt_points: 1000
10 | num_offset: 1
11 | img_size: null
12 | n_views_input: 20
13 | n_views_per_iter: 2
14 | pointcloud_noise: 0
15 | pointcloud_file: pointcloud.npz
16 | pointcloud_outlier_ratio: 0
17 | fixed_scale: 0
18 | train_split: train
19 | val_split: val
20 | test_split: test
21 | points_file: null
22 | points_iou_file: points.npz
23 | points_unpackbits: true
24 | padding: 0.1
25 | multi_files: null
26 | gt_mesh: null
27 | zero_level: 0
28 | only_single: False
29 | sample_only_floor: False
30 | model:
31 | apply_sigmoid: True
32 | grid_res: 128 # poisson grid resolution
33 | psr_sigma: 0
34 | psr_tanh: False
35 | normal_normalize: False
36 | raster: {}
37 | renderer: {}
38 | encoder: null
39 | predict_normal: True
40 | predict_offset: True
41 | s_offset: 0.001
42 | local_coord: True
43 | encoder_kwargs: {}
44 | unet3d: False
45 | unet3d_kwargs: {}
46 | multi_gpu: false
47 | rotate_matrix: false
48 | c_dim: 512
49 | sphere_radius: 0.2
50 | train:
51 | lr: 1e-3
52 | lr_pcl: 2e-2
53 | input_mesh: ''
54 | out_dir: out/default
55 | subsample_vertex: False
56 | batch_size: 4
57 | n_grow_points: 0
58 | resample_every: 0
59 | l_weight: {}
60 | w_reg_point: 0
61 | w_psr: 0
62 | w_raw: 0 # train with raw point cloud
63 | gauss_weight: 0
64 | n_sup_point: 0
65 | w_normals: 0
66 | total_epochs: 10
67 | print_every: 1
68 | visualize_every: 1
69 | save_every: 1
70 | vis_vert_color: True
71 | o3d_show: False
72 | o3d_vis_pcl: True
73 | o3d_window_size: 540
74 | vis_rendering: False
75 | vis_psr: False
76 | save_video: False
77 | exp_mesh: True
78 | exp_pcl: True
79 | checkpoint_every: 1
80 | validate_every: 1
81 | backup_every: 1
82 | timestamp: False # add timestamp to out_dir name
83 | model_selection_metric: loss
84 | model_selection_mode: minimize
85 | n_workers: 0
86 | n_workers_val: 0
87 | test:
88 | threshold: 0.5
89 | eval_mesh: true
90 | eval_pointcloud: false
91 | model_file: model_best.pt
92 | generation:
93 | batch_size: 4
94 | exp_gt: False
95 | exp_oracle: false
96 | exp_input: False
97 | vis_n_outputs: 10
98 | generate_mesh: true
99 | generate_pointcloud: true
100 | generation_dir: generation
101 | copy_input: true
102 | use_sampling: false
103 | psr_resolution: 0
104 | psr_sigma: 0
105 |
--------------------------------------------------------------------------------
/SAP/configs/learning_based/demo_large_noise.yaml:
--------------------------------------------------------------------------------
1 |
2 | inherit_from: configs/learning_based/noise_large/ours.yaml
3 | data:
4 | class: ['']
5 | path: data/demo/shapenet_chair
6 | train:
7 | out_dir: out/demo_shapenet_large_noise
8 | test:
9 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt
--------------------------------------------------------------------------------
/SAP/configs/learning_based/demo_outlier.yaml:
--------------------------------------------------------------------------------
1 |
2 | inherit_from: configs/learning_based/outlier/ours_7x.yaml
3 | data:
4 | class: ['']
5 | path: data/demo/shapenet_lamp
6 | train:
7 | out_dir: out/demo_shapenet_outlier
8 | test:
9 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt
--------------------------------------------------------------------------------
/SAP/configs/learning_based/noise_large/ours.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: null
3 | data_type: psr_full
4 | input_type: pointcloud
5 | path: data/shapenet_psr
6 | num_gt_points: 10000
7 | num_offset: 7
8 | pointcloud_n: 3000
9 | pointcloud_noise: 0.025
10 | model:
11 | grid_res: 128 # poisson grid resolution
12 | psr_sigma: 2
13 | psr_tanh: True
14 | normal_normalize: False
15 | predict_normal: True
16 | predict_offset: True
17 | c_dim: 32
18 | s_offset: 0.001
19 | encoder: local_pool_pointnet
20 | encoder_kwargs:
21 | hidden_dim: 32
22 | plane_type: 'grid'
23 | grid_resolution: 32
24 | unet3d: True
25 | unet3d_kwargs:
26 | num_levels: 3
27 | f_maps: 32
28 | in_channels: 32
29 | out_channels: 32
30 | decoder: simple_local
31 | decoder_kwargs:
32 | sample_mode: bilinear # bilinear / nearest
33 | hidden_size: 32
34 | train:
35 | batch_size: 32
36 | lr: 5e-4
37 | out_dir: out/shapenet/noise_025_ours
38 | w_psr: 1
39 | model_selection_metric: psr_l2
40 | print_every: 100
41 | checkpoint_every: 200
42 | validate_every: 5000
43 | backup_every: 10000
44 | total_epochs: 400000
45 | visualize_every: 5000
46 | exp_pcl: True
47 | exp_mesh: True
48 | n_workers: 8
49 | n_workers_val: 0
50 | generation:
51 | exp_gt: False
52 | exp_input: True
53 | psr_resolution: 128
54 | psr_sigma: 2
55 |
--------------------------------------------------------------------------------
/SAP/configs/learning_based/noise_large/ours_pretrained.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/learning_based/noise_large/ours.yaml
2 | generation:
3 | generation_dir: generation_pretrained
4 | test:
5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt
--------------------------------------------------------------------------------
/SAP/configs/learning_based/noise_small/ours.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: null
3 | data_type: psr_full
4 | input_type: pointcloud
5 | dim: 3
6 | path: data/shapenet_psr
7 | num_gt_points: 10000
8 | num_offset: 7
9 | pointcloud_n: 3000
10 | pointcloud_noise: 0
11 | padding: 0.1
12 | model:
13 | grid_res: 128 # poisson grid resolution
14 | psr_sigma: 2
15 | psr_tanh: True
16 | normal_normalize: False
17 | predict_normal: True
18 | predict_offset: True
19 | c_dim: 32
20 | s_offset: 0.001
21 | local_coord: True
22 | encoder: local_pool_pointnet
23 | encoder_kwargs:
24 | hidden_dim: 32
25 | plane_type: 'grid'
26 | grid_resolution: 32
27 | unet3d: True
28 | unet3d_kwargs:
29 | num_levels: 3
30 | f_maps: 32
31 | in_channels: 32
32 | out_channels: 32
33 | decoder: simple_local
34 | decoder_kwargs:
35 | sample_mode: bilinear # bilinear / nearest
36 | hidden_size: 32
37 | train:
38 | batch_size: 4
39 | lr: 5e-4
40 | out_dir: out/shapenet/noise_0_ours_pointr_100ep
41 | w_psr: 1
42 | model_selection_metric: psr_l2
43 | print_every: 1
44 | checkpoint_every: 1
45 | validate_every: 1
46 | backup_every: 1
47 | total_epochs: 100
48 | visualize_every: 1
49 | exp_pcl: True
50 | exp_mesh: True
51 | n_workers: 8
52 | n_workers_val: 0
53 | generation:
54 | exp_gt: False
55 | exp_input: True
56 | psr_resolution: 128
57 | psr_sigma: 2
58 |
--------------------------------------------------------------------------------
/SAP/configs/learning_based/noise_small/ours_pretrained.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/learning_based/noise_small/ours.yaml
2 | generation:
3 | generation_dir: generation_pretrained_poinTr
4 | test:
5 | model_file: C:/Users/Golriz/OneDrive - polymtl.ca/Desktop/SAP/shape_as_points-main/shape_as_points-main/models/model_best.pt
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_1x.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: null
3 | data_type: psr_full
4 | input_type: pointcloud
5 | path: data/shapenet_psr
6 | num_gt_points: 10000
7 | num_offset: 1
8 | pointcloud_n: 3000
9 | pointcloud_noise: 0.005
10 | pointcloud_outlier_ratio: 0.5
11 | model:
12 | grid_res: 128 # poisson grid resolution
13 | psr_sigma: 2
14 | psr_tanh: True
15 | normal_normalize: False
16 | predict_normal: True
17 | predict_offset: True
18 | c_dim: 32
19 | s_offset: 0.001
20 | encoder: local_pool_pointnet
21 | encoder_kwargs:
22 | hidden_dim: 32
23 | plane_type: 'grid'
24 | grid_resolution: 32
25 | unet3d: True
26 | unet3d_kwargs:
27 | num_levels: 3
28 | f_maps: 32
29 | in_channels: 32
30 | out_channels: 32
31 | decoder: simple_local
32 | decoder_kwargs:
33 | sample_mode: bilinear # bilinear / nearest
34 | hidden_size: 32
35 | train:
36 | batch_size: 32
37 | lr: 5e-4
38 | out_dir: out/shapenet/outlier_ours_1x
39 | w_psr: 1
40 | model_selection_metric: psr_l2
41 | print_every: 100
42 | checkpoint_every: 200
43 | validate_every: 5000
44 | backup_every: 10000
45 | total_epochs: 400000
46 | visualize_every: 5000
47 | exp_pcl: True
48 | exp_mesh: True
49 | n_workers: 8
50 | n_workers_val: 0
51 | generation:
52 | exp_gt: False
53 | exp_input: True
54 | psr_resolution: 128
55 | psr_sigma: 2
56 |
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_1x_pretrained.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/learning_based/outlier/ours_1x.yaml
2 | generation:
3 | generation_dir: generation_pretrained
4 | test:
5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_1x.pt
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_3plane.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: null
3 | data_type: psr_full
4 | input_type: pointcloud
5 | path: data/shapenet_psr
6 | num_gt_points: 10000
7 | num_offset: 5
8 | pointcloud_n: 3000
9 | pointcloud_noise: 0.005
10 | pointcloud_outlier_ratio: 0.5
11 | model:
12 | grid_res: 128 # poisson grid resolution
13 | psr_sigma: 2
14 | psr_tanh: True
15 | normal_normalize: False
16 | predict_normal: True
17 | predict_offset: True
18 | c_dim: 32
19 | s_offset: 0.001
20 | encoder: local_pool_pointnet
21 | encoder_kwargs:
22 | hidden_dim: 32
23 | plane_type: ['xz', 'xy', 'yz']
24 | plane_resolution: 64
25 | unet: True
26 | unet_kwargs:
27 | depth: 4
28 | merge_mode: concat
29 | start_filts: 32
30 | decoder: simple_local
31 | decoder_kwargs:
32 | sample_mode: bilinear # bilinear / nearest
33 | hidden_size: 32
34 | train:
35 | batch_size: 32
36 | lr: 5e-4
37 | out_dir: out/shapenet/outlier_ours_3plane
38 | w_psr: 1
39 | model_selection_metric: psr_l2
40 | print_every: 100
41 | checkpoint_every: 200
42 | validate_every: 5000
43 | backup_every: 10000
44 | total_epochs: 400000
45 | visualize_every: 5000
46 | exp_pcl: True
47 | exp_mesh: True
48 | n_workers: 8
49 | n_workers_val: 0
50 | generation:
51 | exp_gt: False
52 | exp_input: True
53 | psr_resolution: 128
54 | psr_sigma: 2
55 |
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_3x.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: null
3 | data_type: psr_full
4 | input_type: pointcloud
5 | path: data/shapenet_psr
6 | num_gt_points: 10000
7 | num_offset: 3
8 | pointcloud_n: 3000
9 | pointcloud_noise: 0.005
10 | pointcloud_outlier_ratio: 0.5
11 | model:
12 | grid_res: 128 # poisson grid resolution
13 | psr_sigma: 2
14 | psr_tanh: True
15 | normal_normalize: False
16 | predict_normal: True
17 | predict_offset: True
18 | c_dim: 32
19 | s_offset: 0.001
20 | encoder: local_pool_pointnet
21 | encoder_kwargs:
22 | hidden_dim: 32
23 | plane_type: 'grid'
24 | grid_resolution: 32
25 | unet3d: True
26 | unet3d_kwargs:
27 | num_levels: 3
28 | f_maps: 32
29 | in_channels: 32
30 | out_channels: 32
31 | decoder: simple_local
32 | decoder_kwargs:
33 | sample_mode: bilinear # bilinear / nearest
34 | hidden_size: 32
35 | train:
36 | batch_size: 32
37 | lr: 5e-4
38 | out_dir: out/shapenet/outlier_ours_3x
39 | w_psr: 1
40 | model_selection_metric: psr_l2
41 | print_every: 100
42 | checkpoint_every: 200
43 | validate_every: 5000
44 | backup_every: 10000
45 | total_epochs: 400000
46 | visualize_every: 5000
47 | exp_pcl: True
48 | exp_mesh: True
49 | n_workers: 8
50 | n_workers_val: 0
51 | generation:
52 | exp_gt: False
53 | exp_input: True
54 | psr_resolution: 128
55 | psr_sigma: 2
56 |
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_3x_pretrained.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/learning_based/outlier/ours_3x.yaml
2 | generation:
3 | generation_dir: generation_pretrained
4 | test:
5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_3x.pt
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_5x.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: null
3 | data_type: psr_full
4 | input_type: pointcloud
5 | path: data/shapenet_psr
6 | num_gt_points: 10000
7 | num_offset: 5
8 | pointcloud_n: 3000
9 | pointcloud_noise: 0.005
10 | pointcloud_outlier_ratio: 0.5
11 | model:
12 | grid_res: 128 # poisson grid resolution
13 | psr_sigma: 2
14 | psr_tanh: True
15 | normal_normalize: False
16 | predict_normal: True
17 | predict_offset: True
18 | c_dim: 32
19 | s_offset: 0.001
20 | encoder: local_pool_pointnet
21 | encoder_kwargs:
22 | hidden_dim: 32
23 | plane_type: 'grid'
24 | grid_resolution: 32
25 | unet3d: True
26 | unet3d_kwargs:
27 | num_levels: 3
28 | f_maps: 32
29 | in_channels: 32
30 | out_channels: 32
31 | decoder: simple_local
32 | decoder_kwargs:
33 | sample_mode: bilinear # bilinear / nearest
34 | hidden_size: 32
35 | train:
36 | batch_size: 32
37 | lr: 5e-4
38 | out_dir: out/shapenet/outlier_ours_5x
39 | w_psr: 1
40 | model_selection_metric: psr_l2
41 | print_every: 100
42 | checkpoint_every: 200
43 | validate_every: 5000
44 | backup_every: 10000
45 | total_epochs: 400000
46 | visualize_every: 5000
47 | exp_pcl: True
48 | exp_mesh: True
49 | n_workers: 8
50 | n_workers_val: 0
51 | generation:
52 | exp_gt: False
53 | exp_input: True
54 | psr_resolution: 128
55 | psr_sigma: 2
56 |
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_5x_pretrained.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/learning_based/outlier/ours_5x.yaml
2 | generation:
3 | generation_dir: generation_pretrained
4 | test:
5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_5x.pt
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_7x.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: null
3 | data_type: psr_full
4 | input_type: pointcloud
5 | path: data/shapenet_psr
6 | num_gt_points: 10000
7 | num_offset: 7
8 | pointcloud_n: 3000
9 | pointcloud_noise: 0.005
10 | pointcloud_outlier_ratio: 0.5
11 | model:
12 | grid_res: 128 # poisson grid resolution
13 | psr_sigma: 2
14 | psr_tanh: True
15 | normal_normalize: False
16 | predict_normal: True
17 | predict_offset: True
18 | c_dim: 32
19 | s_offset: 0.001
20 | encoder: local_pool_pointnet
21 | encoder_kwargs:
22 | hidden_dim: 32
23 | plane_type: 'grid'
24 | grid_resolution: 32
25 | unet3d: True
26 | unet3d_kwargs:
27 | num_levels: 3
28 | f_maps: 32
29 | in_channels: 32
30 | out_channels: 32
31 | decoder: simple_local
32 | decoder_kwargs:
33 | sample_mode: bilinear # bilinear / nearest
34 | hidden_size: 32
35 | train:
36 | batch_size: 32
37 | lr: 5e-4
38 | out_dir: out/shapenet/outlier_ours_7x
39 | w_psr: 1
40 | model_selection_metric: psr_l2
41 | print_every: 100
42 | checkpoint_every: 200
43 | validate_every: 5000
44 | backup_every: 10000
45 | total_epochs: 400000
46 | visualize_every: 5000
47 | exp_pcl: True
48 | exp_mesh: True
49 | n_workers: 8
50 | n_workers_val: 0
51 | generation:
52 | exp_gt: False
53 | exp_input: True
54 | psr_resolution: 128
55 | psr_sigma: 2
56 |
--------------------------------------------------------------------------------
/SAP/configs/learning_based/outlier/ours_7x_pretrained.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/learning_based/outlier/ours_7x.yaml
2 | generation:
3 | generation_dir: generation_pretrained
4 | test:
5 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt
--------------------------------------------------------------------------------
/SAP/configs/optim_based/dfaust.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: 'only_pcl'
3 | data_type: point
4 | data_path: 'data/dfaust/*.ply'
5 | object_id: 0
6 | num_points: 20000
7 | model:
8 | sphere_radius: 0.2
9 | grid_res: 256 # poisson grid resolution
10 | psr_sigma: 2
11 | train:
12 | schedule:
13 | pcl:
14 | initial: 1e-2
15 | interval: 700
16 | factor: 0.5
17 | final: 1e-3
18 | out_dir: out/dfaust
19 | w_chamfer: 1
20 | n_sup_point: 20000
21 | batch_size: 1
22 | n_grow_points: 2000
23 | resample_every: 200
24 | subsample_vertex: False
25 | total_epochs: 1600
26 | print_every: 10
27 | visualize_every: 2
28 | checkpoint_every: 500
29 | save_every: 100
30 | exp_pcl: True
31 | exp_mesh: True
32 | o3d_show: False
33 | o3d_vis_pcl: True
34 | o3d_window_size: 540
35 | vis_rendering: False
36 | vis_vert_color: False
37 | n_workers: 0
38 |
--------------------------------------------------------------------------------
/SAP/configs/optim_based/dgp.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: 'only_pcl'
3 | data_type: point
4 | data_path: 'data/deep_geometric_prior_data/*.ply'
5 | object_id: 0
6 | num_points: 20000
7 | model:
8 | sphere_radius: 0.2
9 | grid_res: 256 # poisson grid resolution
10 | psr_sigma: 2
11 | train:
12 | schedule:
13 | pcl:
14 | initial: 1e-2
15 | interval: 700
16 | factor: 0.5
17 | final: 1e-3
18 | out_dir: out/dgp
19 | w_reg_point: 0
20 | w_chamfer: 1
21 | n_sup_point: 20000
22 | batch_size: 1
23 | n_grow_points: 2000
24 | resample_every: 200
25 | subsample_vertex: False
26 | total_epochs: 1600
27 | print_every: 10
28 | visualize_every: 2
29 | checkpoint_every: 500
30 | save_every: 100
31 | exp_pcl: True
32 | exp_mesh: True
33 | o3d_show: False
34 | o3d_vis_pcl: True
35 | o3d_window_size: 540
36 | vis_rendering: False
37 | vis_vert_color: False
38 | n_workers: 0
39 |
--------------------------------------------------------------------------------
/SAP/configs/optim_based/teaser.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: 'only_pcl'
3 | data_type: point
4 | data_path: 'data/demo/wheel.ply'
5 | object_id: 0
6 | num_points: 20000
7 | model:
8 | sphere_radius: 0.2
9 | grid_res: 128 # poisson grid resolution
10 | psr_sigma: 2
11 | train:
12 | schedule:
13 | pcl:
14 | initial: 1e-2
15 | interval: 700
16 | factor: 0.5
17 | final: 1e-3
18 | out_dir: out/demo_optim
19 | w_chamfer: 1
20 | n_sup_point: 20000
21 | batch_size: 1
22 | n_grow_points: 2000
23 | resample_every: 200
24 | subsample_vertex: False
25 | total_epochs: 1600
26 | print_every: 10
27 | visualize_every: 2
28 | checkpoint_every: 500
29 | save_every: 100
30 | exp_pcl: True
31 | exp_mesh: True
32 | o3d_show: False
33 | o3d_vis_pcl: True
34 | o3d_window_size: 540
35 | vis_rendering: False
36 | vis_vert_color: False
37 | n_workers: 0
--------------------------------------------------------------------------------
/SAP/configs/optim_based/thingi.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: 'only_pcl'
3 | data_type: point
4 | data_path: 'data/thingi/*.ply'
5 | object_id: 0
6 | num_points: 20000
7 | model:
8 | sphere_radius: 0.2
9 | grid_res: 128 # poisson grid resolution
10 | psr_sigma: 2
11 | train:
12 | # lr_pcl: 2e-2
13 | schedule:
14 | pcl:
15 | initial: 1e-2
16 | interval: 700
17 | factor: 0.5
18 | final: 1e-3
19 | out_dir: out/thingi
20 | w_reg_point: 0
21 | w_chamfer: 1
22 | n_sup_point: 20000
23 | batch_size: 1
24 | n_grow_points: 2000
25 | resample_every: 200
26 | subsample_vertex: False
27 | total_epochs: 1600
28 | print_every: 10
29 | visualize_every: 2
30 | checkpoint_every: 500
31 | save_every: 100
32 | exp_pcl: True
33 | exp_mesh: True
34 | o3d_show: False
35 | o3d_vis_pcl: True
36 | o3d_window_size: 540
37 | vis_rendering: False
38 | vis_vert_color: False
39 | n_workers: 0
40 |
--------------------------------------------------------------------------------
/SAP/configs/optim_based/thingi_noisy.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class: 'only_pcl'
3 | data_type: point
4 | data_path: 'data/thingi_noisy/*.ply'
5 | object_id: 0
6 | num_points: 20000
7 | model:
8 | sphere_radius: 0.2
9 | grid_res: 128 # poisson grid resolution
10 | psr_sigma: 2
11 | train:
12 | # lr_pcl: 2e-2
13 | schedule:
14 | pcl:
15 | initial: 1e-2
16 | interval: 700
17 | factor: 0.5
18 | final: 1e-3
19 | out_dir: out/thingi_noisy
20 | w_reg_point: 0
21 | w_chamfer: 1
22 | n_sup_point: 20000
23 | batch_size: 1
24 | n_grow_points: 2000
25 | resample_every: 200
26 | subsample_vertex: False
27 | total_epochs: 1600
28 | print_every: 10
29 | visualize_every: 2
30 | checkpoint_every: 500
31 | save_every: 100
32 | exp_pcl: True
33 | exp_mesh: True
34 | o3d_show: False
35 | o3d_vis_pcl: True
36 | o3d_window_size: 540
37 | vis_rendering: False
38 | vis_vert_color: False
39 | n_workers: 0
40 |
--------------------------------------------------------------------------------
/SAP/eval_meshes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import trimesh
3 | from torch.utils.data import Dataset, DataLoader
4 | import numpy as np; np.set_printoptions(precision=4)
5 | import shutil, argparse, time, os
6 | import pandas as pd
7 | from src.data import collate_remove_none, collate_stack_together, worker_init_fn
8 | from src.training import Trainer
9 | from src.model import Encode2Points
10 | from src.data import PointCloudField, IndexField, Shapes3dDataset
11 | from src.utils import load_config, load_pointcloud
12 | from src.eval import MeshEvaluator
13 | from tqdm import tqdm
14 | from pdb import set_trace as st
15 |
16 |
def main():
    """Score previously generated meshes / point clouds on the test split.

    Reads the config named on the command line and looks for generated
    outputs under ``<train.out_dir>/<generation.generation_dir>[_<iter>]``.
    For every item of the test split it evaluates, when enabled in
    ``cfg['test']``, ``meshes/[<category>/]<model>.off`` and
    ``pointcloud/[<category>/]<model>.ply`` with ``MeshEvaluator``.

    Per-item results are pickled to ``eval_meshes_full.pkl``; per-class means
    plus an overall ``mean`` row are written to ``eval_meshes.csv`` and
    printed.

    Raises:
        RuntimeError: if the installed PyTorch major version is below 1.
    """
    parser = argparse.ArgumentParser(description='MNIST toy experiment')
    parser.add_argument('config', type=str, help='Path to config file.')
    # --no_cuda and --seed are still accepted on the command line, but this
    # script does not use their values.
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--iter', type=int, metavar='S',
                        help='the training iteration to be evaluated.')

    args = parser.parse_args()
    cfg = load_config(args.config, 'configs/default.yaml')

    # Shorthands
    out_dir = cfg['train']['out_dir']
    generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])

    # An iteration set in the config takes precedence over --iter.
    if cfg['generation'].get('iter', 0) != 0:
        generation_dir += '_%04d' % cfg['generation']['iter']
    elif args.iter is not None:
        generation_dir += '_%04d' % args.iter

    print('Evaluate meshes under %s' % generation_dir)

    out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl')
    out_file_class = os.path.join(generation_dir, 'eval_meshes.csv')

    # PYTORCH VERSION >= 1.0.0.  Look at the leading (major) component so that
    # four-part versions such as "2.0.0.dev..." are handled correctly, and
    # raise explicitly because asserts are stripped under ``python -O``.
    torch_major = int(torch.__version__.split('.')[0])
    if torch_major < 1:
        raise RuntimeError(
            'PyTorch >= 1.0.0 is required, found %s' % torch.__version__)

    pointcloud_field = PointCloudField(cfg['data']['pointcloud_file'])
    fields = {
        'pointcloud': pointcloud_field,
        'idx': IndexField(),
    }

    print('Test split: ', cfg['data']['test_split'])

    dataset_folder = cfg['data']['path']
    dataset = Shapes3dDataset(
        dataset_folder, fields,
        cfg['data']['test_split'],
        categories=cfg['data']['class'], cfg=cfg)

    # Loader
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, num_workers=0, shuffle=False)

    # Evaluator
    evaluator = MeshEvaluator(n_points=100000)

    eval_dicts = []
    print('Evaluating meshes...')
    for data in tqdm(test_loader):

        if data is None:
            print('Invalid data.')
            continue

        mesh_dir = os.path.join(generation_dir, 'meshes')
        pointcloud_dir = os.path.join(generation_dir, 'pointcloud')

        # Get index etc.
        idx = data['idx'].item()
        try:
            model_dict = dataset.get_model_dict(idx)
        except AttributeError:
            # Dataset without per-model metadata: fall back to the index.
            model_dict = {'model': str(idx), 'category': 'n/a'}

        modelname = model_dict['model']
        category_id = model_dict['category']

        try:
            category_name = dataset.metadata[category_id].get('name', 'n/a')
        except AttributeError:
            category_name = 'n/a'

        # Outputs are grouped in one sub-folder per category when known.
        if category_id != 'n/a':
            mesh_dir = os.path.join(mesh_dir, category_id)
            pointcloud_dir = os.path.join(pointcloud_dir, category_id)

        # Ground truth (batch dimension of size 1 removed)
        pointcloud_tgt = data['pointcloud'].squeeze(0).numpy()
        normals_tgt = data['pointcloud.normals'].squeeze(0).numpy()

        # The dict is appended first and filled in below, so items whose
        # outputs are missing still show up in the result table.
        eval_dict = {
            'idx': idx,
            'class id': category_id,
            'class name': category_name,
            'modelname': modelname,
        }
        eval_dicts.append(eval_dict)

        # Evaluate mesh
        if cfg['test']['eval_mesh']:
            mesh_file = os.path.join(mesh_dir, '%s.off' % modelname)

            if os.path.exists(mesh_file):
                mesh = trimesh.load(mesh_file, process=False)
                eval_dict_mesh = evaluator.eval_mesh(
                    mesh, pointcloud_tgt, normals_tgt)
                for k, v in eval_dict_mesh.items():
                    eval_dict[k + ' (mesh)'] = v
            else:
                print('Warning: mesh does not exist: %s' % mesh_file)

        # Evaluate point cloud
        if cfg['test']['eval_pointcloud']:
            pointcloud_file = os.path.join(
                pointcloud_dir, '%s.ply' % modelname)

            if os.path.exists(pointcloud_file):
                pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
                eval_dict_pcl = evaluator.eval_pointcloud(
                    pointcloud, pointcloud_tgt)
                for k, v in eval_dict_pcl.items():
                    eval_dict[k + ' (pcl)'] = v
            else:
                print('Warning: pointcloud does not exist: %s'
                      % pointcloud_file)

    # Create pandas dataframe and save
    eval_df = pd.DataFrame(eval_dicts)
    eval_df.set_index(['idx'], inplace=True)
    eval_df.to_pickle(out_file)

    # Create CSV file with main statistics.  ``numeric_only=True`` keeps the
    # string columns ('class id', 'modelname') out of the mean; without it
    # pandas >= 2.0 raises TypeError instead of silently dropping them.
    eval_df_class = eval_df.groupby(by=['class name']).mean(numeric_only=True)
    eval_df_class.loc['mean'] = eval_df_class.mean()
    eval_df_class.to_csv(out_file_class)

    # Print results
    print(eval_df_class)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/SAP/optim_hierarchy.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import argparse
3 | from src.utils import load_config
4 | import subprocess
5 | os.environ['MKL_THREADING_LAYER'] = 'GNU'
6 |
def main():
    """Run ``optim.py`` coarse-to-fine over a fixed ladder of grid resolutions.

    For each level (32, 64, 128, 256) one ``optim.py`` process is started
    with that level's grid resolution, PSR sigma, iteration count and
    point-cloud learning rate.  Every level after the first is given the
    ``.ply`` that the previous level is expected to have written under
    ``<out>/res_<prev>/vis/<mesh|pointcloud>/<prev iterations>.ply``.

    Levels below ``--start_res`` and levels above ``cfg['model']['grid_res']``
    are skipped.  With ``--object_id`` set, all outputs live under an extra
    ``object_XX`` directory.
    """
    parser = argparse.ArgumentParser(description='MNIST toy experiment')
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('--start_res', type=int, default=-1,
                        help='Resolution to start with.')
    parser.add_argument('--object_id', type=int, default=-1,
                        help='Object index.')

    # Unrecognised arguments are tolerated (and ignored), as before.
    args, _unknown = parser.parse_known_args()
    cfg = load_config(args.config, 'configs/default.yaml')

    resolutions = [32, 64, 128, 256]
    iterations = [1000, 1000, 1000, 200]
    lrs = [2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)]  # reduce lr

    # Root of all per-resolution output directories.
    base_dir = cfg['train']['out_dir']
    if args.object_id != -1:
        base_dir = os.path.join(base_dir, 'object_%02d' % args.object_id)

    # sample from mesh when resampling is enabled, otherwise reuse the pointcloud
    init_shape = 'mesh' if cfg['train']['resample_every'] > 0 else 'pointcloud'

    # Passed through the child environment instead of a shell ``export`` so
    # the launch does not depend on a POSIX shell.
    child_env = dict(os.environ, MKL_SERVICE_FORCE_INTEL='1')

    for idx, (res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):

        # Start at the requested resolution ...
        if res < args.start_res:
            continue

        # ... and never go finer than the grid configured for the model.
        if res > cfg['model']['grid_res']:
            continue

        if res <= 128:
            psr_sigma = 2
        else:
            psr_sigma = 5 if 'thingi_noisy' in args.config else 3

        out_dir = os.path.join(base_dir, 'res_%d' % res)

        # The first level starts from scratch; optim.py receives the literal
        # string 'None' in that case.
        if idx == 0:
            input_mesh = 'None'
        else:
            input_mesh = os.path.join(
                base_dir, 'res_%d' % resolutions[idx - 1],
                'vis', init_shape, '%04d.ply' % iterations[idx - 1])

        # Argument list (no shell): paths containing spaces stay intact.
        cmd = [
            'python', 'optim.py', args.config,
            '--model:grid_res', '%d' % res,
            '--model:psr_sigma', '%d' % psr_sigma,
            '--train:input_mesh', input_mesh,
            '--train:total_epochs', '%d' % iteration,
            '--train:out_dir', out_dir,
            '--train:lr_pcl', '%f' % lr,
            '--data:object_id', '%d' % args.object_id,
        ]
        print(' '.join(cmd))
        # The exit status is not checked, matching the previous os.system call.
        subprocess.run(cmd, env=child_env)


if __name__ == "__main__":
    main()
70 |
--------------------------------------------------------------------------------
/SAP/scripts/__pycache__/easy_mesh_vtk.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/scripts/__pycache__/easy_mesh_vtk.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/scripts/download_demo_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Download the demo archive and unpack it into ./data.
#
# Abort on the first failing command so that a failed `cd` or download is not
# followed by unzip/rm and a misleading "Done!".
set -euo pipefail

mkdir -p data
cd data
echo "Start downloading ..."
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/demo.zip
unzip demo.zip
rm demo.zip
echo "Done!"
--------------------------------------------------------------------------------
/SAP/scripts/download_optim_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Download the data for the optimization-based setting and unpack it into ./data.
#
# Abort on the first failing command so that a failed `cd` or download is not
# followed by unzip/rm and a misleading "Done!".
set -euo pipefail

mkdir -p data
cd data
echo "Start downloading data for optimization-based setting (~200 MB)"
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/optim_data.zip
unzip optim_data.zip
rm optim_data.zip
echo "Done!"
--------------------------------------------------------------------------------
/SAP/scripts/download_shapenet.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Download the preprocessed ShapeNet archive and unpack it into ./data.
#
# Abort on the first failing command so that a failed `cd` or download is not
# followed by unzip/rm and a misleading "Done!".
set -euo pipefail

mkdir -p data
cd data
echo "Start downloading preprocessed ShapeNet data (~220G)"
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/shapenet_psr.zip
unzip shapenet_psr.zip
rm shapenet_psr.zip
echo "Done!"
--------------------------------------------------------------------------------
/SAP/scripts/prcocess_tooth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import open3d as o3d
3 | import torch
4 | import time
5 | import multiprocessing
6 | import numpy as np
7 | from tqdm import tqdm
8 | from src.dpsr import DPSR
9 | #from easy_mesh_vtk import Easy_Mesh
10 | import pyvista as pv
11 | import os
12 | from pymeshfix._meshfix import PyTMesh
13 | from pymeshfix import MeshFix
14 |
# ---------------------------------------------------------------------------
# Script configuration.
# All locations are absolute, machine-specific paths; adjust before running.
# ---------------------------------------------------------------------------

# Locations
#data_path = 'C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\1a9e1fb2a51ffd065b07a27512172330' # path for ShapeNet from ONet
data_path_mesh = 'C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\data-watertightshell'
# .npy file with the points that the live part of this script loads
data_pred_points = "C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\test-sap-watertight-5epochs\\test-pointr\\A0_6014-36\\A0_6014-36.npy"
# directory the resulting .npz files are written to
out_path_cur_obj = 'C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\test-sap-watertight-5epochs\\test-pointr\\A0_6014-36'
base = 'data'  # output base directory
dataset_name = 'shapenet_psr'

# Execution switches
multiprocess = True
njobs = 8
save_pointcloud = True
save_psr_field = True

# Numeric settings
resolution = 128
zero_level = 0.0
num_points = 100000
padding = 1.2

# Poisson solver on a cubic grid of side `resolution`, constructed with sig=0.
dpsr = DPSR(res=(resolution,) * 3, sig=0)
32 | """""
33 | gt_path = os.path.join(data_path, 'pointcloud.npz')
34 | data = np.load(gt_path)
35 | points = data['points']
36 | normals = data['normals']
37 |
38 | # normalize the point to [0, 1)
39 | points = points / padding + 0.5
40 | # to scale back during inference, we should:
41 | # ! p = (p - 0.5) * padding
42 | pcd = o3d.geometry.PointCloud()
43 | pcd.points = o3d.utility.Vector3dVector(points)
44 | pcd.normals= o3d.utility.Vector3dVector(normals)
45 | pcd.paint_uniform_color([0, 0.45, 0])
46 |
47 | o3d.visualization.draw_geometries([pcd], point_show_normal=True)
48 |
49 | if save_pointcloud:
50 | outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
51 | # np.savez(outdir, points=points, normals=normals)
52 | np.savez(outdir, points=data['points'], normals=data['normals'])
53 | # return
54 |
55 | if save_psr_field:
56 | psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None],
57 | torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16)
58 |
59 | outdir = os.path.join(out_path_cur_obj, 'psr.npz')
60 | np.savez(outdir, psr=psr_gt)
61 |
62 | """""
63 | #read the meshgt
64 | #for entry in os.listdir(data_path):
65 |
66 | #Open_mesh = pv.read(os.path.join(data_path_mesh,'shell26_registered.ply'))
67 | #meshfix = MeshFix(Open_mesh)
68 | #holes = meshfix.extract_holes()
69 | #meshfix.repair(verbose=True)
70 | #meshfix.save(os.path.join(data_path_mesh,'shell26_registered_watertight.ply' ))
71 |
72 | """""
73 | for cases in os.listdir(data_path_mesh):
74 | for i in os.listdir(os.path.join(data_path_mesh,cases)):
75 | if 'shell' in i:
76 | print('cases:',cases)
77 | mesh = o3d.io.read_triangle_mesh(os.path.join(data_path_mesh, cases, i))
78 | mesh = mesh.subdivide_loop(number_of_iterations=3)
79 | # mesh = Easy_Mesh(os.path.join(data_path_mesh, '45shell_registered.ply'))
80 | points = mesh.vertices
81 | print('points', np.asarray(points))
82 | Npoints = np.asarray(points)
83 | normals = mesh.compute_vertex_normals()
84 | Nnormals = np.asarray(normals.vertex_normals)
85 | print("normals", np.asarray(normals.vertex_normals))
86 | # randomple sample 100000
87 | points_sample = 100000
88 | positive_mesh_idx = np.arange(len(Npoints))
89 | try:
90 | positive_selected_mesh_idx = np.random.choice(positive_mesh_idx, size=points_sample, replace=False)
91 | except ValueError:
92 | positive_selected_mesh_idx = np.random.choice(positive_mesh_idx, size=points_sample, replace=True)
93 | # mesh_with_newpoints = np.zeros([points_sample, Npoints.shape[1]], dtype='float32')
94 | Npoints = Npoints[positive_selected_mesh_idx, :]
95 | print('shape', Npoints.shape)
96 | Nnormals = Nnormals[positive_selected_mesh_idx, :]
97 |
98 | # normalize the point to [0, 1)
99 | Npoints = (Npoints - np.min(Npoints)) / (np.max(Npoints) + 1 - np.min(Npoints))
100 | pcd = o3d.geometry.PointCloud()
101 | pcd.points = o3d.utility.Vector3dVector(Npoints)
102 | pcd.normals= o3d.utility.Vector3dVector(Nnormals)
103 | pcd.paint_uniform_color([0, 0.45, 0])
104 |
105 | o3d.visualization.draw_geometries([pcd], point_show_normal=True)
106 | # mean/std
107 | # Npoints_mean=np.mean(Npoints)
108 | # Npoints_std=np.std(Npoints)
109 | # Npoints=(Npoints-Npoints_mean) / Npoints_std
110 | # Npoints = np.asarray(Npoints) / padding + 0.5
111 | # to scale back during inference, we should:
112 | # ! p = (p - 0.5) * padding
113 |
114 | if save_pointcloud:
115 | outdir = os.path.join(out_path_cur_obj, cases, 'pointcloud.npz')
116 | # np.savez(outdir, points=points, normals=normals)
117 | np.savez(outdir, points=Npoints, normals=np.asarray(Nnormals))
118 | # return
119 |
120 | if save_psr_field:
121 | psr_gt = dpsr(torch.from_numpy(Npoints.astype(np.float32))[None],
122 | torch.from_numpy(Nnormals[None].astype(np.float32))).squeeze().cpu().numpy().astype(
123 | np.float16)
124 |
125 | outdir = os.path.join(out_path_cur_obj, cases, 'psr.npz')
126 | np.savez(outdir, psr=psr_gt)
127 |
128 | """""
129 |
# ---------------------------------------------------------------------------
# Live part of the script: load the points stored in `data_pred_points`,
# rescale them, estimate normals with Open3D and store both as pointcloud.npz.
# ---------------------------------------------------------------------------
points = np.load(data_pred_points)

# normalize the point to [0, 1): one global min / max over all coordinates,
# with the span enlarged by 1 so the maximum maps strictly below 1
_lowest = np.min(points)
_highest = np.max(points)
Npoints = (points - _lowest) / (_highest + 1 - _lowest)

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(Npoints)
# No normals come with the loaded points, so Open3D estimates them.
pcd.estimate_normals()
pcd_normals = np.asarray(pcd.normals)
#pcd.normals = o3d.utility.Vector3dVector(Nnormals)
pcd.paint_uniform_color([0, 0.45, 0])

#o3d.visualization.draw_geometries([pcd], point_show_normal=True)

if save_pointcloud:
    outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
    # np.savez(outdir, points=points, normals=normals)
    np.savez(outdir, points=Npoints, normals=np.asarray(pcd_normals))
147 | # return
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
--------------------------------------------------------------------------------
/SAP/scripts/process_shapenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import time
4 | import multiprocessing
5 | import numpy as np
6 | from tqdm import tqdm
7 | from src.dpsr import DPSR
8 |
# ---------------------------------------------------------------------------
# Script configuration
# ---------------------------------------------------------------------------

# Locations
data_path = 'data/ShapeNet'  # path for ShapeNet from ONet
base = 'data'  # output base directory
dataset_name = 'shapenet_psr'  # sub-directory of `base` that receives the output

# Execution switches
multiprocess = True  # process the objects of a split in a worker pool
njobs = 8  # size of that worker pool
save_pointcloud = True  # write pointcloud.npz per object
save_psr_field = True  # write psr.npz per object

# Numeric settings
resolution = 128  # side length of the cubic PSR grid
zero_level = 0.0
num_points = 100000
padding = 1.2  # points are divided by this before being shifted by 0.5

# Poisson solver on a cubic grid of side `resolution`, constructed with sig=0.
dpsr = DPSR(res=(resolution,) * 3, sig=0)
22 |
def process_one(obj):
    """Convert one ShapeNet object into the pointcloud / PSR files.

    Args:
        obj (str): ``'<class>/<object name>'``; only the last two
            ``/``-separated components are used.

    Reads ``<data_path>/<class>/<object>/pointcloud.npz`` and writes, into
    ``<base>/<dataset_name>/<class>/<object>/``:

    * ``pointcloud.npz`` -- the points and normals exactly as loaded
      (when ``save_pointcloud`` is set);
    * ``psr.npz`` -- the ``dpsr`` output for the rescaled points, stored as
      float16 (when ``save_psr_field`` is set).
    """
    parts = obj.split('/')
    obj_name = parts[-1]
    c = parts[-2]

    # output directory of the current object
    target_dir = os.path.join(base, dataset_name, c, obj_name)
    os.makedirs(target_dir, exist_ok=True)

    source_file = os.path.join(data_path, c, obj_name, 'pointcloud.npz')
    data = np.load(source_file)
    raw_points = data['points']
    normals = data['normals']

    # normalize the point to [0, 1)
    # to scale back during inference, we should:
    #! p = (p - 0.5) * padding
    scaled_points = raw_points / padding + 0.5

    if save_pointcloud:
        # The saved copy holds the points as loaded, not the rescaled ones.
        np.savez(os.path.join(target_dir, 'pointcloud.npz'),
                 points=raw_points, normals=normals)

    if save_psr_field:
        # Both inputs get a leading batch dimension before the solver call.
        points_t = torch.from_numpy(scaled_points.astype(np.float32))[None]
        normals_t = torch.from_numpy(normals.astype(np.float32))[None]
        psr_gt = dpsr(points_t, normals_t).squeeze().cpu().numpy()
        psr_gt = psr_gt.astype(np.float16)

        np.savez(os.path.join(target_dir, 'psr.npz'), psr=psr_gt)
55 |
56 |
def main(c):
    """Process every object of class ``c`` for the train, val and test splits.

    For each split, ``<data_path>/<c>/<split>.lst`` is read (one object name
    per line) and ``process_one`` is applied to every ``'<c>/<name>'`` entry,
    either in a pool of ``njobs`` workers or sequentially, depending on the
    module-level ``multiprocess`` switch.

    Args:
        c (str): class directory name below ``data_path``.
    """
    for split in ['train', 'val', 'test']:
        # The banner names the split, so it has to be printed inside the
        # loop; before the loop ``split`` is not bound yet.
        print('---------------------------------------')
        print('Processing {} {}'.format(c, split))
        print('---------------------------------------')

        fname = os.path.join(data_path, c, split + '.lst')
        with open(fname, 'r') as f:
            obj_list = f.read().splitlines()

        obj_list = [c + '/' + s for s in obj_list]

        if multiprocess:
            # multiprocessing.set_start_method('spawn', force=True)
            # The context manager shuts the workers down on every exit path,
            # including the interrupt below.
            with multiprocessing.Pool(njobs) as pool:
                try:
                    # Draining the iterator waits for all objects; tqdm only
                    # reports progress.
                    for _ in tqdm(pool.imap_unordered(process_one, obj_list),
                                  total=len(obj_list)):
                        pass
                except KeyboardInterrupt:
                    # Allow ^C to interrupt from any thread.
                    raise SystemExit
        else:
            for obj in tqdm(obj_list):
                process_one(obj)

        print('Done Processing {} {}!'.format(c, split))
86 |
87 |
if __name__ == "__main__":

    # Class directory names below `data_path`, processed in this order.
    classes = [
        '02691156',
        '02828884',
        '02933112',
        '02958343',
        '03211117',
        '03001627',
        '03636649',
        '03691459',
        '04090263',
        '04256520',
        '04379243',
        '04401088',
        '04530566',
    ]

    # Wall-clock time over all classes.
    t_start = time.time()
    for c in classes:
        main(c)
    t_end = time.time()

    print('Total processing time: ', t_end - t_start)
102 |
--------------------------------------------------------------------------------
/SAP/scripts/process_tooth.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import numpy as np
3 | from tqdm import tqdm
4 | from src.dpsr import DPSR
5 | #from easy_mesh_vtk import Easy_Mesh
6 | import pyvista as pv
7 | import os
8 | from pymeshfix._meshfix import PyTMesh
9 | from pymeshfix import MeshFix
10 | import open3d as o3d
11 | import torch
12 |
# ---------------------------------------------------------------------------
# Script configuration.
# All locations are absolute, machine-specific paths; adjust before running.
# ---------------------------------------------------------------------------

# Locations
data_path = 'C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\datasap'  # path for ShapeNet from ONet
data_path_mesh = 'C:/Users/Golriz/OneDrive - polymtl.ca/Desktop/mesh'
out_path_cur_obj = 'C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\datasap'
base = 'data'  # output base directory
dataset_name = 'shapenet_psr'

# Execution switches
multiprocess = True
njobs = 8
save_pointcloud = True
save_psr_field = True

# Numeric settings
resolution = 128
zero_level = 0.0
num_points = 100000
padding = 1.2

# Poisson solver on a cubic grid of side `resolution`, constructed with sig=0.
dpsr = DPSR(res=(resolution,) * 3, sig=0)


# ---------------------------------------------------------------------------
# Load the mesh, refine it and read vertices + vertex normals
# ---------------------------------------------------------------------------
mesh = o3d.io.read_triangle_mesh(
    os.path.join(data_path_mesh, 'shell26_registered_watertight.ply'))
mesh = mesh.subdivide_loop(number_of_iterations=3)
#mesh = Easy_Mesh(os.path.join(data_path_mesh, '45shell_registered.ply'))

points = mesh.vertices
print('points', np.asarray(points))
Npoints = np.asarray(points)

# The object returned by compute_vertex_normals() is the one whose
# `vertex_normals` are read below.
normals = mesh.compute_vertex_normals()
Nnormals = np.asarray(normals.vertex_normals)
print("normals", np.asarray(normals.vertex_normals))

# ---------------------------------------------------------------------------
# Randomly sample 100000 vertex indices
# ---------------------------------------------------------------------------
points_sample = 100000
positive_mesh_idx = np.arange(len(Npoints))
try:
    # Prefer distinct vertices ...
    positive_selected_mesh_idx = np.random.choice(
        positive_mesh_idx, size=points_sample, replace=False)
except ValueError:
    # ... and fall back to drawing with replacement if that is rejected.
    positive_selected_mesh_idx = np.random.choice(
        positive_mesh_idx, size=points_sample, replace=True)
#mesh_with_newpoints = np.zeros([points_sample, Npoints.shape[1]], dtype='float32')

# Points and normals are indexed with the same selection so they stay paired.
Npoints = Npoints[positive_selected_mesh_idx, :]
print('shape', Npoints.shape)
Nnormals = Nnormals[positive_selected_mesh_idx, :]

# ---------------------------------------------------------------------------
# normalize the point to [0, 1): one global min / max over all coordinates,
# with the span enlarged by 1 so the maximum maps strictly below 1
# ---------------------------------------------------------------------------
_lowest = np.min(Npoints)
_highest = np.max(Npoints)
Npoints = (Npoints - _lowest) / (_highest + 1 - _lowest)

# Alternative normalisations that are currently disabled:
#mean/std
#Npoints_mean=np.mean(Npoints)
#Npoints_std=np.std(Npoints)
#Npoints=(Npoints-Npoints_mean) / Npoints_std
#Npoints = np.asarray(Npoints) / padding + 0.5
# to scale back during inference, we should:
# ! p = (p - 0.5) * padding

# ---------------------------------------------------------------------------
# Write the outputs
# ---------------------------------------------------------------------------
if save_pointcloud:
    outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
    # np.savez(outdir, points=points, normals=normals)
    np.savez(outdir, points=Npoints, normals=np.asarray(Nnormals))

if save_psr_field:
    # Both inputs get a leading batch dimension; the result is stored as float16.
    _points_t = torch.from_numpy(Npoints.astype(np.float32))[None]
    _normals_t = torch.from_numpy(Nnormals[None].astype(np.float32))
    psr_gt = dpsr(_points_t, _normals_t).squeeze().cpu().numpy().astype(np.float16)

    outdir = os.path.join(out_path_cur_obj, 'psr.npz')
    np.savez(outdir, psr=psr_gt)
--------------------------------------------------------------------------------
/SAP/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__init__.py
--------------------------------------------------------------------------------
/SAP/src/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/__pycache__/dpsr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/dpsr.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/__pycache__/generation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/generation.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/__pycache__/training.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/training.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from torchvision import transforms
3 | from ..src import data, generation
4 | from ..src.dpsr import DPSR
5 | from ipdb import set_trace as st
6 |
7 |
8 | # Generator for final mesh extraction
def get_generator(model, cfg, device, **kwargs):
    ''' Builds the Generator3D used for final mesh extraction.

    The PSR grid resolution and sigma come from the ``generation`` section
    of the config, unless ``generation.psr_resolution`` is 0, in which case
    the values of the ``model`` section are used.

    Args:
        model (nn.Module): model handed to the generator
        cfg (dict): imported yaml config
        device (device): pytorch device
        **kwargs: accepted but not used
    '''
    gen_cfg = cfg['generation']

    if gen_cfg['psr_resolution'] != 0:
        # explicit generation-time settings
        psr_res = gen_cfg['psr_resolution']
        psr_sigma = gen_cfg['psr_sigma']
    else:
        # 0 means: reuse the settings of the model
        psr_res = cfg['model']['grid_res']
        psr_sigma = cfg['model']['psr_sigma']

    # cubic grid: the same resolution along all three axes
    grid_shape = (psr_res, psr_res, psr_res)
    dpsr = DPSR(res=grid_shape, sig=psr_sigma).to(device)

    return generation.Generator3D(
        model,
        device=device,
        threshold=cfg['data']['zero_level'],
        sample=cfg['generation']['use_sampling'],
        input_type=cfg['data']['input_type'],
        padding=cfg['data']['padding'],
        dpsr=dpsr,
        psr_tanh=cfg['model']['psr_tanh'],
    )
40 |
41 | # Datasets
def get_dataset(mode, cfg, return_idx=False):
    ''' Builds the dataset for the requested mode.

    Args:
        mode (str): one of 'train', 'val', 'test' or 'vis'
            ('vis' uses the validation split)
        cfg (dict): config dictionary
        return_idx (bool): whether to include an ID field

    Raises:
        KeyError: if mode is not one of the names above
        ValueError: if cfg['data']['dataset'] is not 'Shapes3D'
    '''
    data_cfg = cfg['data']
    dataset_type = data_cfg['dataset']
    dataset_folder = data_cfg['path']
    categories = data_cfg['class']

    # Map the mode to the split name configured for it
    split_for_mode = dict(
        train=data_cfg['train_split'],
        val=data_cfg['val_split'],
        test=data_cfg['test_split'],
        vis=data_cfg['val_split'],
    )
    split = split_for_mode[mode]

    # Only one dataset type is supported
    if dataset_type != 'Shapes3D':
        raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])

    fields = get_data_fields(mode, cfg)

    # Input fields
    inputs_field = get_inputs_field(mode, cfg)
    if inputs_field is not None:
        fields['inputs'] = inputs_field

    if return_idx:
        fields['idx'] = data.IndexField()

    return data.Shapes3dDataset(
        dataset_folder, fields,
        split=split,
        categories=categories,
        cfg=cfg,
    )
85 |
86 |
def get_inputs_field(mode, cfg):
    ''' Builds the field that provides the network inputs.

    Args:
        mode (str): the mode which is used (not read by this function)
        cfg (dict): config dictionary

    Returns:
        None when cfg['data']['input_type'] is None, otherwise a
        PointCloudField whose transform subsamples the cloud, adds noise
        and, if the outlier ratio is positive, adds outliers.

    Raises:
        ValueError: for any input type other than None or 'pointcloud'
    '''
    data_cfg = cfg['data']
    input_type = data_cfg['input_type']

    if input_type is None:
        return None

    if input_type != 'pointcloud':
        raise ValueError(
            'Invalid input type (%s)' % input_type)

    noise_level = data_cfg['pointcloud_noise']
    outlier_ratio = data_cfg['pointcloud_outlier_ratio']

    # Subsampling and noise are always applied; outliers only on request
    steps = [
        data.SubsamplePointcloud(data_cfg['pointcloud_n']),
        data.PointcloudNoise(noise_level),
    ]
    if outlier_ratio > 0:
        steps.append(data.PointcloudOutliers(outlier_ratio))
    transform = transforms.Compose(steps)

    data_type = data_cfg['data_type']
    return data.PointCloudField(
        data_cfg['pointcloud_file'], data_type, transform,
        multi_files=data_cfg['multi_files']
    )
121 |
def get_data_fields(mode, cfg):
    ''' Builds the ground-truth fields of the dataset.

    Args:
        mode (str): the mode which is used
        cfg (dict): imported yaml config

    Returns:
        dict with a 'gt_points' field and, unless mode is 'test', a
        'gt_psr' field.

    Raises:
        ValueError: if cfg['data']['data_type'] is not 'psr_full'
    '''
    data_cfg = cfg['data']
    data_type = data_cfg['data_type']

    # Validation and test use a fixed number of ground-truth points;
    # the other modes take the number from the config.
    if mode in ('val', 'test'):
        n_gt_points = 100000
    else:
        n_gt_points = data_cfg['num_gt_points']
    transform = data.SubsamplePointcloud(n_gt_points)

    data_name = data_cfg['pointcloud_file']
    fields = {
        'gt_points': data.PointCloudField(
            data_name, transform=transform, data_type=data_type,
            multi_files=data_cfg['multi_files']),
    }

    if data_type != 'psr_full':
        raise ValueError('Invalid data type (%s)' % data_type)
    if mode != 'test':
        fields['gt_psr'] = data.FullPSRField(
            multi_files=data_cfg['multi_files'])

    return fields
--------------------------------------------------------------------------------
/SAP/src/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from ..data.core import (
3 | Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together
4 | )
5 | from ..data.fields import (
6 | IndexField, PointCloudField, FullPSRField
7 | )
8 | from ..data.transforms import (
9 | PointcloudNoise, SubsamplePointcloud,
10 | PointcloudOutliers,
11 | )
# Public names of the data package. The entries must be strings:
# `from <package> import *` raises TypeError when __all__ contains the
# objects themselves.
__all__ = [
    # Core
    'Shapes3dDataset',
    'collate_remove_none',
    'worker_init_fn',
    # Fields
    'IndexField',
    'PointCloudField',
    'FullPSRField',
    # Transforms
    'PointcloudNoise',
    'SubsamplePointcloud',
    'PointcloudOutliers',
]
26 |
--------------------------------------------------------------------------------
/SAP/src/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/data/__pycache__/core.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/data/__pycache__/core.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/data/__pycache__/fields.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/data/__pycache__/fields.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/data/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/data/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/data/fields.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import time
4 | import random
5 | from PIL import Image
6 | import numpy as np
7 | import trimesh
8 | from ..data.core import Field
9 | from pdb import set_trace as st
10 |
11 |
class IndexField(Field):
    ''' Field that exposes the index of a data point.'''

    def load(self, model_path, idx, category):
        ''' Returns the index of the data point unchanged.

        Args:
            model_path (str): path to model (not used)
            idx (int): ID of data point
            category (int): index of category (not used)
        '''
        return idx

    def check_complete(self, files):
        ''' Reports the field as complete; it needs no files.

        Args:
            files: files (not inspected)
        '''
        return True
31 |
class FullPSRField(Field):
    ''' Field that loads a PSR grid stored under the key 'psr' in an npz file.

    Args:
        transform (callable): optional transformation applied to the
            loaded data dictionary
        multi_files: when not None, the grid is read from
            '<model_path>/psr/psr_<idx>.npz' instead of
            '<model_path>/psr.npz'
    '''
    def __init__(self, transform=None, multi_files=None):
        self.transform = transform
        self.multi_files = multi_files

    def load(self, model_path, idx, category):
        ''' Loads the PSR grid of one data point.

        Args:
            model_path (str): path to model
            idx (int): ID of data point, used in the file name when
                multi_files is set
            category (int): index of category (not used)

        Returns:
            dict: {None: psr} with psr cast to float32, after the
            optional transform.
        '''
        if self.multi_files is not None:
            psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx))
        else:
            psr_path = os.path.join(model_path, 'psr.npz')

        # Close the archive as soon as the array is read; astype() returns
        # a fresh array, so nothing refers to the closed file afterwards.
        with np.load(psr_path) as psr_dict:
            psr = psr_dict['psr'].astype(np.float32)

        data = {None: psr}
        if self.transform is not None:
            data = self.transform(data)

        return data
58 |
class PointCloudField(Field):
    ''' Point cloud field.

    It provides the field used for point cloud data. These are the points
    randomly sampled on the mesh.

    Args:
        file_name (str): file name (a directory name when multi_files is set)
        data_type (str): data type tag; 'psr_full' is the only value
            checked in load
        transform (callable): transformation applied to the data dictionary
        multi_files: when not None, the cloud is read from
            '<model_path>/<file_name>/pointcloud_<idx>.npz'
        padding (float): stored on the instance, not read in this class
        scale (float): stored on the instance, not read in this class
    '''
    def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
        self.file_name = file_name
        self.data_type = data_type  # to make sure the range of input is correct
        self.transform = transform
        self.multi_files = multi_files
        self.padding = padding
        self.scale = scale

    def load(self, model_path, idx, category):
        ''' Loads the data point.

        Args:
            model_path (str): path to model
            idx (int): ID of data point
            category (int): index of category

        Returns:
            dict: {None: points, 'normals': normals}, both float32, after
            the optional transform.
        '''
        if self.multi_files is None:
            file_path = os.path.join(model_path, self.file_name)
        else:
            file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx))

        # Close the archive once both arrays are read; astype() copies, so
        # the arrays stay valid after the file is closed.
        with np.load(file_path) as pointcloud_dict:
            points = pointcloud_dict['points'].astype(np.float32)
            normals = pointcloud_dict['normals'].astype(np.float32)

        data = {
            None: points,
            'normals': normals,
        }
        if self.transform is not None:
            data = self.transform(data)

        if self.data_type == 'psr_full':
            # NOTE(review): the rescaling to (0, 1) that used to live here
            # is disabled ("we already scale data"); only this print is
            # left and it fires on every load -- confirm it is still wanted.
            print("psr_full")

        return data

    def check_complete(self, files):
        ''' Check if field is complete.

        Args:
            files: files
        '''
        complete = (self.file_name in files)
        return complete
120 |
--------------------------------------------------------------------------------
/SAP/src/data/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | # Transforms
class PointcloudNoise(object):
    ''' Transformation that perturbs a point cloud with Gaussian noise.

    Args:
        stddev (float): standard deviation of the noise
    '''

    def __init__(self, stddev):
        self.stddev = stddev

    def __call__(self, data):
        ''' Returns a shallow copy of data with noisy points under None.

        Args:
            data (dictionary): data dictionary; data[None] holds the points
        '''
        clean = data[None]
        # One normal sample per coordinate, scaled, then cast to float32.
        jitter = (self.stddev * np.random.randn(*clean.shape)).astype(np.float32)

        result = data.copy()
        result[None] = clean + jitter
        return result
29 |
class PointcloudOutliers(object):
    ''' Point cloud outlier transformation class.

    It replaces randomly chosen points of the cloud by outliers drawn
    uniformly from [-0.55, 0.55) in each of the three coordinates.

    Args:
        ratio (float): fraction of the point count used as the number of
            replacement draws
    '''

    def __init__(self, ratio):
        self.ratio = ratio

    def __call__(self, data):
        ''' Calls the transformation.

        The input array is left untouched; the outliers are written into
        a copy.

        Args:
            data (dictionary): data dictionary; data[None] holds the points

        Returns:
            dictionary: shallow copy of data with the modified points
            under None.
        '''
        data_out = data.copy()
        # dict.copy() is shallow, so the array itself has to be copied
        # before the in-place assignment below; otherwise the caller's
        # point cloud would be modified as well.
        points = data[None].copy()
        n_points = points.shape[0]
        n_outlier_points = int(n_points * self.ratio)

        # Nothing to replace (also covers an empty cloud, for which
        # randint(0, 0, ...) would raise).
        if n_outlier_points > 0:
            # Indices are drawn with replacement, so fewer than
            # n_outlier_points distinct points may end up replaced.
            ind = np.random.randint(0, n_points, n_outlier_points)
            outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
            points[ind] = outliers.astype(np.float32)

        data_out[None] = points
        return data_out
59 |
class SubsamplePointcloud(object):
    ''' Transformation that draws a fixed number of points from a cloud.

    Indices are drawn with replacement, so a point can be picked twice.

    Args:
        N (int): number of points to be subsampled
    '''
    def __init__(self, N):
        self.N = N

    def __call__(self, data):
        ''' Returns a shallow copy of data restricted to N random rows.

        Args:
            data (dict): data dictionary; data[None] holds the points and
                the optional key 'normals' the matching normals
        '''
        cloud = data[None]
        picked = np.random.randint(cloud.shape[0], size=self.N)

        result = data.copy()
        result[None] = cloud[picked, :]
        # Normals, when present, are subsampled with the same indices.
        if 'normals' in data:
            result['normals'] = data['normals'][picked, :]
        return result
--------------------------------------------------------------------------------
/SAP/src/dpsr.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from ..src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
5 | import numpy as np
6 | import torch.fft
7 |
class DPSR(nn.Module):
    # Module that turns an oriented point cloud (points V, normals N) into a
    # grid of indicator-function values by solving in the frequency domain.
    def __init__(self, res, sig=10, scale=True, shift=True):
        """
        :param res: tuple of output field resolution. eg., (128,128)
        :param sig: degree of gaussian smoothing
        :param scale: if True, forward rescales the grid by the magnitude
            of its value at grid index (0, ..., 0)
        :param shift: if True, forward subtracts the mean of the grid
            values interpolated at the input points
        """
        super(DPSR, self).__init__()
        self.res = res
        self.sig = sig
        self.dim = len(res)
        self.denom = np.prod(res)  # stored but not read in this class
        G = spec_gaussian_filter(res=res, sig=sig).float()
        # self.G.requires_grad = False # True, if we also make sig a learnable parameter
        # NOTE: forward recomputes the frequencies with fftfreqs instead of
        # reading self.omega.
        self.omega = fftfreqs(res, dtype=torch.float32)
        self.scale = scale
        self.shift = shift
        # Registered as a buffer so that G follows the module across devices.
        self.register_buffer("G", G)

    def forward(self, V, N):
        """
        :param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
        :param N: (batch, nv, 2 or 3) tensor for point normals
        :return phi: (batch, res, res, ...) tensor of output indicator function field

        NOTE(review): the FFT axes below are hard-coded as (2,3,4) and
        (1,2,3), which matches a 3-D grid; confirm before using a 2-D res.
        """
        assert(V.shape == N.shape) # [b, nv, ndims]
        # Rasterize the normals onto the grid at the point locations.
        ras_p = point_rasterize(V, N, self.res)  # [b, n_dim, dim0, dim1, dim2]

        ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
        # Move the n_dim axis from position 1 to the end.
        ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
        # Multiply the spectrum by the Gaussian filter G.
        N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]

        omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
        omega *= 2 * np.pi  # normalize frequencies
        omega = omega.to(V.device)

        # Frequency-domain divergence of the smoothed normal field:
        # sum over the n_dim axis of -img(N_) * omega.
        DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)

        Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
        # The 1e-6 keeps the division finite where Lap is zero.
        Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
        # Batch axis is moved last so that the frequency entry at index
        # (0, ..., 0) can be zeroed for every batch item at once.
        Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]]))  # [dim0, dim1, dim2/2+1, 2, b]
        Phi[tuple([0] * self.dim)] = 0
        Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))]))  # [b, dim0, dim1, dim2/2+1, 2]

        # Back to the spatial domain.
        phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))

        if self.shift or self.scale:
            # ensure values at points are zero
            fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
            if self.shift: # offset points to have mean of 0
                offset = torch.mean(fv, dim=-1)  # [b,]
                phi -= offset.view(*tuple([-1] + [1] * self.dim))

            # Read the (already shifted) grid value at index (0, ..., 0)
            # of every batch item; the permutes only expose the batch axis.
            phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
            fv0 = phi[tuple([0] * self.dim)]  # [b,]
            phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))

            if self.scale:
                # Flip the sign and rescale so that the value at index
                # (0, ..., 0) has magnitude 0.5.
                phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
        return phi
--------------------------------------------------------------------------------
/SAP/src/eval.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import trimesh
4 | from pykdtree.kdtree import KDTree
5 |
# Metric values reported for an empty predicted point cloud.
# NOTE(review): this uses the key 'chamfer', while the regular result of
# MeshEvaluator.eval_pointcloud uses 'chamfer-L2', 'chamfer-L1' and the
# 'f-score*' keys -- confirm that consumers cope with both key sets.
EMPTY_PCL_DICT = {
    'completeness': np.sqrt(3),
    'accuracy': np.sqrt(3),
    'completeness2': 3,
    'accuracy2': 3,
    'chamfer': 6,
}

# Normal-consistency values added when both normal sets were supplied.
EMPTY_PCL_DICT_NORMALS = dict.fromkeys(
    ('normals completeness', 'normals accuracy', 'normals'), -1.)

logger = logging.getLogger(__name__)
21 |
22 |
class MeshEvaluator(object):
    ''' Mesh evaluation class.
    It handles the mesh evaluation process.
    Args:
        n_points (int): number of points to be used for evaluation
    '''

    def __init__(self, n_points=100000):
        self.n_points = n_points

    def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)):
        ''' Evaluates a mesh.

        Samples n_points points (with their face normals) from the mesh
        and evaluates them with eval_pointcloud. A mesh without vertices
        or faces is evaluated as an empty point cloud.

        Args:
            mesh (trimesh): mesh which should be evaluated
            pointcloud_tgt (numpy array): target point cloud
            normals_tgt (numpy array): target normals
            thresholds (numpy array): for F-Score
        '''
        if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
            pointcloud, idx = mesh.sample(self.n_points, return_index=True)

            pointcloud = pointcloud.astype(np.float32)
            normals = mesh.face_normals[idx]
        else:
            pointcloud = np.empty((0, 3))
            normals = np.empty((0, 3))

        out_dict = self.eval_pointcloud(
            pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)

        return out_dict

    def eval_pointcloud(self, pointcloud, pointcloud_tgt,
                        normals=None, normals_tgt=None,
                        thresholds=np.linspace(1./1000, 1, 1000)):
        ''' Evaluates a point cloud.

        Args:
            pointcloud (numpy array): predicted point cloud
            pointcloud_tgt (numpy array): target point cloud
            normals (numpy array): predicted normals
            normals_tgt (numpy array): target normals
            thresholds (numpy array): threshold values for the F-score
                calculation; the reported F-scores read entries 9, 14 and
                19, so at least 20 thresholds are required

        Returns:
            dict: metric name -> value. For an empty prediction a copy of
            EMPTY_PCL_DICT is returned (extended by EMPTY_PCL_DICT_NORMALS
            when both normal sets are given).
        '''
        # Convert first so that array-likes (e.g. lists) can be checked
        # for emptiness through .shape as well.
        pointcloud = np.asarray(pointcloud)
        pointcloud_tgt = np.asarray(pointcloud_tgt)

        # Return maximum losses if pointcloud is empty
        if pointcloud.shape[0] == 0:
            # Logger.warn is a deprecated alias that newer Python versions
            # no longer provide; warning() is the supported spelling.
            logger.warning('Empty pointcloud / mesh detected!')
            out_dict = EMPTY_PCL_DICT.copy()
            if normals is not None and normals_tgt is not None:
                out_dict.update(EMPTY_PCL_DICT_NORMALS)
            return out_dict

        # Completeness: how far are the points of the target point cloud
        # from the predicted point cloud
        completeness, completeness_normals = distance_p2p(
            pointcloud_tgt, normals_tgt, pointcloud, normals
        )
        recall = get_threshold_percentage(completeness, thresholds)
        completeness2 = completeness**2

        completeness = completeness.mean()
        completeness2 = completeness2.mean()
        completeness_normals = completeness_normals.mean()

        # Accuracy: how far are the points of the predicted pointcloud
        # from the target pointcloud
        accuracy, accuracy_normals = distance_p2p(
            pointcloud, normals, pointcloud_tgt, normals_tgt
        )
        precision = get_threshold_percentage(accuracy, thresholds)
        accuracy2 = accuracy**2

        accuracy = accuracy.mean()
        accuracy2 = accuracy2.mean()
        accuracy_normals = accuracy_normals.mean()

        # Chamfer distance
        chamferL2 = 0.5 * (completeness2 + accuracy2)
        normals_correctness = (
            0.5 * completeness_normals + 0.5 * accuracy_normals
        )
        chamferL1 = 0.5 * (completeness + accuracy)

        # F-Score, one value per threshold
        F = [
            2 * p * r / (p + r)
            for p, r in zip(precision, recall)
        ]

        out_dict = {
            'completeness': completeness,
            'accuracy': accuracy,
            'normals completeness': completeness_normals,
            'normals accuracy': accuracy_normals,
            'normals': normals_correctness,
            'completeness2': completeness2,
            'accuracy2': accuracy2,
            'chamfer-L2': chamferL2,
            'chamfer-L1': chamferL1,
            'f-score': F[9],  # threshold = 1.0%
            'f-score-15': F[14],  # threshold = 1.5%
            'f-score-20': F[19],  # threshold = 2.0%
        }

        return out_dict
131 |
132 |
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
    ''' Nearest-neighbour distance from every source point to the target set.

    Args:
        points_src (numpy array): source points
        normals_src (numpy array): source normals
        points_tgt (numpy array): target points
        normals_tgt (numpy array): target normals

    Returns:
        tuple: (dist, normals_dot_product). The second entry holds, per
        source point, the absolute dot product between its unit normal and
        the unit normal of its nearest target point, or NaN everywhere
        when either normal set is None.
    '''
    tree = KDTree(points_tgt)
    dist, nearest = tree.query(points_src)

    if normals_src is None or normals_tgt is None:
        no_normals = np.full(points_src.shape[0], np.nan, dtype=np.float32)
        return dist, no_normals

    unit_src = normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
    unit_tgt = normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)

    # The absolute value makes normals that point in the opposite
    # direction count as aligned.
    alignment = np.abs((unit_tgt[nearest] * unit_src).sum(axis=-1))
    return dist, alignment
158 |
def get_threshold_percentage(dist, thresholds):
    ''' Fraction of distances that fall within each threshold.

    Args:
        dist (numpy array): calculated distance
        thresholds (numpy array): threshold values for the F-score calculation

    Returns:
        list: for every threshold t, in order, the mean of (dist <= t)
    '''
    fractions = []
    for limit in thresholds:
        within = dist <= limit
        fractions.append(within.mean())
    return fractions
--------------------------------------------------------------------------------
/SAP/src/generation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import trimesh
4 | import numpy as np
5 | from ..src.utils import mc_from_psr
6 |
class Generator3D(object):
    ''' Mesh generator built around a point-predicting model and DPSR.

    The model maps the input to points and normals, dpsr turns them into
    a PSR grid, and mc_from_psr extracts the mesh from that grid.

    Args:
        model (nn.Module): trained model; it is moved to device
        points_batch_size (int): stored, not read by generate_mesh
        threshold (float): level passed to mc_from_psr as zero_level
        device (device): pytorch device
        padding (float): stored, not read by generate_mesh
        sample (bool): stored, not read by generate_mesh
        input_type (str): type of input (stored only)
        dpsr (callable): maps (points, normals) to a PSR grid
        psr_tanh (bool): stored, not read by generate_mesh
    '''

    def __init__(self, model, points_batch_size=100000,
                 threshold=0.5, device=None, padding=0.1,
                 sample=False, input_type = None, dpsr=None, psr_tanh=True):
        self.device = device
        self.model = model.to(device)
        self.dpsr = dpsr
        self.threshold = threshold
        self.psr_tanh = psr_tanh
        self.points_batch_size = points_batch_size
        self.input_type = input_type
        self.padding = padding
        self.sample = sample

    def generate_mesh(self, data, return_stats=True):
        ''' Generates the output mesh.

        Args:
            data (dict): batch; the entry 'inputs' is fed to the model
                (an empty tensor is used when it is missing)
            return_stats (bool): whether stats should be returned

        Returns:
            tuple: (v, f, points, normals), followed by a dict of timings
            in seconds ('pcl', 'dpsr', 'mc', 'total') when return_stats
            is True.
        '''
        self.model.eval()
        inputs = data.get('inputs', torch.empty(1, 0)).to(self.device)

        started = time.time()
        points, normals = self.model(inputs)
        model_done = time.time()
        psr_grid = self.dpsr(points, normals)
        dpsr_done = time.time()
        v, f, _ = mc_from_psr(psr_grid,
                              zero_level=self.threshold)
        stats_dict = {
            'pcl': model_done - started,
            'dpsr': dpsr_done - model_done,
            'mc': time.time() - dpsr_done,
            'total': time.time() - started,
        }

        result = (v, f, points, normals)
        if return_stats:
            return result + (stats_dict,)
        return result
--------------------------------------------------------------------------------
/SAP/src/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import numpy as np
4 | import time
5 | from ..src.utils import point_rasterize, grid_interp, mc_from_psr, \
6 | calc_inters_points
7 | from ..src.dpsr import DPSR
8 | import torch.nn as nn
9 | from ..src.network import encoder_dict, decoder_dict
10 | from ..src.network.utils import map2local
11 | from .utils import load_config
12 |
class PSR2Mesh(torch.autograd.Function):
    """
    Autograd function that extracts a mesh from a PSR grid with mc_from_psr
    and defines a custom gradient from the mesh vertices back to the grid.
    """

    @staticmethod
    def forward(ctx, psr_grid):
        """
        Extract vertices, faces and normals from psr_grid and add a leading
        axis of size 1 to each of them. The vertices, the normals and the
        grid resolution (size of axis 2 of psr_grid) are saved on ctx for
        the backward pass.
        """
        verts, faces, normals = (
            t.unsqueeze(0) for t in mc_from_psr(psr_grid, pytorchify=True)
        )

        res = torch.tensor(psr_grid.detach().shape[2])
        ctx.save_for_backward(verts, normals, res)

        return verts, faces, normals

    @staticmethod
    def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
        """
        Turn the gradient w.r.t. the vertices into a gradient w.r.t. the
        PSR grid. Only dL_dVertex is used; the gradients w.r.t. faces and
        normals are ignored.
        """
        vert_pts, normals, res = ctx.saved_tensors
        side = res.item()
        grid_shape = (side, side, side)

        # dV/dPSR = - normals, so each vertex gradient is contracted with
        # the negated normal of that vertex.
        neg_normals = -normals.permute(1, 2, 0)
        grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), neg_normals)

        # Scatter the per-vertex values onto the grid.
        grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), grid_shape)  # b x 1 x res x res x res

        return grad_grid
47 |
class PSR2SurfacePoints(torch.autograd.Function):
    """
    Autograd function that extracts a mesh from a PSR grid, intersects it
    per camera pose via calc_inters_points and defines a custom gradient
    from the intersection points back to the grid.
    """

    @staticmethod
    def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
        """
        Returns the intersection points of all poses concatenated along
        axis 0 and the stacked per-pose masks. uv is accepted but not used.
        mask_sample, when given, supplies one mask_gt per pose.
        """
        verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
        verts = verts * 2. - 1.  # within the range of [-1, 1]

        p_all, n_all, mask_all = [], [], []
        for i, pose in enumerate(poses):
            extra = {} if mask_sample is None else {'mask_gt': mask_sample[i]}
            p_inters, mask, _, _ = calc_inters_points(
                verts, faces, pose, img_size, **extra)

            # Grid gradient interpolated at the intersection points, which
            # are mapped from [-1, 1] back to [0, 1] first.
            query = (p_inters[None].detach() + 1) / 2
            n_inters = grid_interp(psr_grad[None], query).squeeze()

            p_all.append(p_inters)
            n_all.append(n_inters)
            mask_all.append(mask)

        p_inters_all = torch.cat(p_all, dim=0)
        n_inters_all = torch.cat(n_all, dim=0)
        mask_visible = torch.stack(mask_all, dim=0)

        res = torch.tensor(psr_grid.detach().shape[2])
        ctx.save_for_backward(p_inters_all, n_inters_all, res)

        return p_inters_all, mask_visible

    @staticmethod
    def backward(ctx, dL_dp, dL_dmask):
        """
        Turn the gradient w.r.t. the intersection points into a gradient
        w.r.t. the PSR grid; the other five inputs receive no gradient.
        """
        pts, pts_n, res = ctx.saved_tensors
        side = res.item()
        grid_shape = (side, side, side)

        # grad from the p_inters via MLP renderer, contracted with the
        # negated interpolated gradient at each point
        grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
        unit_pts = (pts[None]+1)/2
        grad_grid_pts = point_rasterize(unit_pts, grad_pts.permute(1, 0, 2), grid_shape)  # b x 1 x res x res x res

        return grad_grid_pts, None, None, None, None, None
88 |
class Encode2Points(nn.Module):
    ''' Network that encodes a point cloud and predicts, per input point,
    offsets and/or normals.

    Args:
        cfg: configuration, passed through load_config before use
    '''
    def __init__(self, cfg):
        super().__init__()

        cfg = load_config(cfg)
        self.cfg = cfg

        encoder = cfg['model']['encoder']
        decoder = cfg['model']['decoder']
        dim = cfg['data']['dim']  # input dim
        c_dim = cfg['model']['c_dim']
        # Missing kwargs sections come back as None from the config.
        encoder_kwargs = cfg['model']['encoder_kwargs']
        if encoder_kwargs is None:
            encoder_kwargs = {}
        decoder_kwargs = cfg['model']['decoder_kwargs']
        if decoder_kwargs is None:
            decoder_kwargs = {}
        self.predict_normal = cfg['model']['predict_normal']
        self.predict_offset = cfg['model']['predict_offset']

        out_dim = 3
        out_dim_offset = 3
        num_offset = cfg['data']['num_offset']
        # each point predict more than one offset to add output points
        if num_offset > 1:
            out_dim_offset = out_dim * num_offset
        self.num_offset = num_offset

        # local mapping; stays None when local coordinates are disabled so
        # that the encoder / decoder constructors below always receive a
        # defined value
        self.map2local = None
        local_mapping = None
        if cfg['model']['local_coord']:
            if 'unet' in encoder_kwargs:
                unit_size = 1 / encoder_kwargs['plane_resolution']
            else:
                unit_size = 1 / encoder_kwargs['grid_resolution']

            local_mapping = map2local(unit_size)

        self.encoder = encoder_dict[encoder](
            dim=dim, c_dim=c_dim, map2local=local_mapping,
            **encoder_kwargs
        )

        if self.predict_normal:
            # decoder for normal prediction
            self.decoder_normal = decoder_dict[decoder](
                dim=dim, c_dim=c_dim, out_dim=out_dim,
                **decoder_kwargs)
        if self.predict_offset:
            # decoder for offset prediction
            self.decoder_offset = decoder_dict[decoder](
                dim=dim, c_dim=c_dim, out_dim=out_dim_offset,
                map2local=local_mapping,
                **decoder_kwargs)

            self.s_off = cfg['model']['s_offset']

    def forward(self, p):
        ''' Performs a forward pass through the network.

        Args:
            p (tensor): input unoriented points

        Returns:
            tuple: (points, normals). points are the (optionally offset)
            input points clamped to [0, 0.99]; normals is None when normal
            prediction is disabled.
        '''
        batch_size = p.size(0)

        # encode the input point cloud to a feature volume
        c = self.encoder(p)

        points = p
        if self.predict_offset:
            offset = self.decoder_offset(p, c)
            # more than one offset is predicted per-point
            if self.num_offset > 1:
                points = points.repeat(1, 1, self.num_offset).reshape(batch_size, -1, 3)
            points = points + self.s_off * offset

        normals = None
        if self.predict_normal:
            normals = self.decoder_normal(points, c)

        points = torch.clamp(points, 0.0, 0.99)
        if normals is not None and self.cfg['model']['normal_normalize']:
            normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8)

        return points, normals
184 |
185 |
--------------------------------------------------------------------------------
/SAP/src/model_rgb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ..src.network.net_rgb import RenderingNetwork
3 | from ..src.utils import approx_psr_grad
4 | from pytorch3d.renderer import (
5 | RasterizationSettings,
6 | PerspectiveCameras,
7 | MeshRenderer,
8 | MeshRasterizer,
9 | SoftSilhouetteShader)
10 | from pytorch3d.structures import Meshes
11 |
12 |
13 |
def approx_psr_grad(psr_grid, res, normalize=True):
    ''' Central-difference gradient of a PSR grid.

    The grid is padded by one cell on every side with replicated border
    values and squeezed, so the difference at a border cell uses the
    border value twice.

    NOTE: the module also imports a function of this name from the utils
    module; this definition takes its place.

    Args:
        psr_grid (tensor): PSR grid (presumably [1, 1, res, res, res] --
            confirm against callers)
        res (int): grid resolution; the cell size is taken as 1/res
        normalize (bool): whether to scale each gradient vector to unit
            length (with 1e-12 added to the norm)

    Returns:
        tensor: gradient with the three components stacked on axis 3
    '''
    step = 1/res
    padded = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()

    # Difference along one axis; the padding ring is cut away on the
    # other two axes.
    d_x = (padded[2:, 1:-1, 1:-1] - padded[:-2, 1:-1, 1:-1]) / 2 / step
    d_y = (padded[1:-1, 2:, 1:-1] - padded[1:-1, :-2, 1:-1]) / 2 / step
    d_z = (padded[1:-1, 1:-1, 2:] - padded[1:-1, 1:-1, :-2]) / 2 / step

    psr_grad = torch.stack([d_x, d_y, d_z], dim=3)  # [res_x, res_y, res_z, 3]
    if not normalize:
        return psr_grad
    return psr_grad / (psr_grad.norm(dim=3, keepdim=True) + 1e-12)
30 |
31 |
class SAP2Image(torch.nn.Module):
    ''' Renders RGB values and silhouettes from oriented points via DPSR.

    Args:
        cfg (dict): config dictionary
        img_size: image size handed to the rasterizer and to the
            surface-point extraction
    '''
    def __init__(self, cfg, img_size):
        super().__init__()

        # These names are not provided by the module-level imports of this
        # file, so they are imported where they are needed.
        import numpy as np
        from .dpsr import DPSR
        from .model import PSR2Mesh, PSR2SurfacePoints

        self.cfg = cfg
        self.img_size = img_size

        self.psr2sur = PSR2SurfacePoints.apply
        self.psr2mesh = PSR2Mesh.apply
        # initialize DPSR
        self.dpsr = DPSR(res=(cfg['model']['grid_res'],
                              cfg['model']['grid_res'],
                              cfg['model']['grid_res']),
                         sig=cfg['model']['psr_sigma'])
        if cfg['train']['l_weight']['rgb'] != 0.:
            self.rendering_network = RenderingNetwork(**cfg['model']['renderer'])

        if cfg['train']['l_weight']['mask'] != 0.:
            # initialize rasterizer
            sigma = 1e-4
            raster_settings_soft = RasterizationSettings(
                image_size=img_size,
                blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
                faces_per_pixel=150,
                perspective_correct=False
            )

            # initialize silhouette renderer
            self.mesh_rasterizer = MeshRenderer(
                rasterizer=MeshRasterizer(
                    raster_settings=raster_settings_soft
                ),
                shader=SoftSilhouetteShader()
            )

    def forward(self, inputs, data):
        ''' Maps raw point parameters to rendered outputs.

        Args:
            inputs (tensor): last axis holds the point coordinates in
                entries 0-2 and the normals in the remaining entries
            data (dict): view data, see render_img
        '''
        points, normals = inputs[...,:3], inputs[...,3:]
        points = torch.sigmoid(points)
        normals = normals / normals.norm(dim=-1, keepdim=True)

        # DPSR to get grid
        psr_grid = self.dpsr(points, normals).unsqueeze(1)
        psr_grid = torch.tanh(psr_grid)

        return self.render_img(psr_grid, data)

    def render_img(self, psr_grid, data):
        ''' Renders RGB and/or silhouettes for randomly drawn views.

        Args:
            psr_grid (tensor): PSR grid
            data (dict): reads 'masks', 'uv', 'poses', 'rgbs' and,
                optionally, 'rays'

        Returns:
            dict: 'rgb', 'rgb_gt', 'mask', 'mask_gt' and 'vis_mask'.
            'rgb' and 'vis_mask' are None when the rgb loss weight is 0;
            'mask' is None when the mask loss weight is 0.
        '''
        import numpy as np
        from .utils import grid_interp

        n_views = len(data['masks'])
        n_views_per_iter = self.cfg['data']['n_views_per_iter']

        uv = data['uv']

        # Views are drawn with replacement.
        idx = np.random.randint(0, n_views, n_views_per_iter)
        pose = [data['poses'][i] for i in idx]
        rgb = data['rgbs'][idx]
        mask_gt = data['masks'][idx]
        ray = None
        pred_rgb = None
        pred_mask = None
        # Defined up front so that the output dict can be built when the
        # rgb branch is skipped.
        visible_mask = None

        if self.cfg['train']['l_weight']['rgb'] != 0.:
            psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res'])
            p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
            n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
            fea_interp = None
            if 'rays' in data.keys():
                ray = data['rays'].squeeze()[idx][visible_mask]
            pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp)

        # silhouette loss
        if self.cfg['train']['l_weight']['mask'] != 0.:
            # build mesh
            v, f, _ = self.psr2mesh(psr_grid)
            v = v * 2. - 1  # within the range of [-1, 1]
            mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])

            # All drawn views are rendered in one batched call (fast but
            # more GPU usage); the per-view loop that used to sit in an
            # unreachable else-branch was dropped.
            R = torch.cat([p.R for p in pose], dim=0)
            T = torch.cat([p.T for p in pose], dim=0)
            focal = torch.cat([p.focal_length for p in pose], dim=0)
            pp = torch.cat([p.principal_point for p in pose], dim=0)
            pose_cur = PerspectiveCameras(
                focal_length=focal,
                principal_point=pp,
                R=R, T=T,
                device=R.device)
            pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]

        output = {
            'rgb': pred_rgb,
            'rgb_gt': rgb,
            'mask': pred_mask,
            'mask_gt': mask_gt,
            'vis_mask': visible_mask,
        }

        return output
--------------------------------------------------------------------------------
/SAP/src/network/__init__.py:
--------------------------------------------------------------------------------
1 | from ..network import encoder, decoder
2 |
# Lookup table from a string identifier to the encoder class it names.
encoder_dict = dict(
    local_pool_pointnet=encoder.LocalPoolPointnet,
)
# Lookup table from a string identifier to the decoder class it names.
decoder_dict = dict(
    simple_local=decoder.LocalDecoder,
)
--------------------------------------------------------------------------------
/SAP/src/network/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/__pycache__/decoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/decoder.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/__pycache__/encoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/encoder.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/__pycache__/unet3d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/unet3d.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/SAP/src/network/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/SAP/src/network/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from ipdb import set_trace as st
6 | from ..network.utils import normalize_3d_coordinate, ResnetBlockFC, \
7 | normalize_coordinate
8 |
9 |
class LocalDecoder(nn.Module):
    ''' Decoder.
        Instead of conditioning on global features, on plane/volume local features.
    Args:
        dim (int): input dimension
        c_dim (int): dimension of latent conditioned code c; 0 disables conditioning
        out_dim (int): output dimension of the final linear layer
        hidden_size (int): hidden size of Decoder network
        n_blocks (int): number of blocks ResNetBlockFC layers
        leaky (bool): whether to use leaky ReLUs
        sample_mode (str): sampling feature strategy, bilinear|nearest
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        map2local (callable): optional transform applied to the points before the first linear layer
    '''

    def __init__(self, dim=3, c_dim=128, out_dim=3,
                 hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None):
        super().__init__()
        self.c_dim = c_dim
        self.n_blocks = n_blocks

        # One projection of the conditioning code per residual block.
        if c_dim != 0:
            self.fc_c = nn.ModuleList([
                nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
            ])

        self.fc_p = nn.Linear(dim, hidden_size)

        self.blocks = nn.ModuleList([
            ResnetBlockFC(hidden_size) for i in range(n_blocks)
        ])

        self.fc_out = nn.Linear(hidden_size, out_dim)

        if not leaky:
            self.actvn = F.relu
        else:
            self.actvn = lambda x: F.leaky_relu(x, 0.2)

        self.sample_mode = sample_mode
        self.padding = padding
        self.map2local = map2local
        self.out_dim = out_dim

    def sample_plane_feature(self, p, c, plane='xz'):
        ''' Samples the 2D feature plane ``c`` at the projection of ``p`` onto ``plane``. '''
        xy = normalize_coordinate(p.clone(), plane=plane)  # normalize to the range of (0, 1)
        xy = xy[:, :, None].float()
        vgrid = 2.0 * xy - 1.0  # normalize to (-1, 1)
        c = F.grid_sample(c, vgrid, padding_mode='border',
                          align_corners=True,
                          mode=self.sample_mode).squeeze(-1)
        return c

    def sample_grid_feature(self, p, c):
        ''' Samples the 3D feature volume ``c`` at the points ``p``. '''
        p_nor = normalize_3d_coordinate(p.clone())
        p_nor = p_nor[:, :, None, None].float()
        vgrid = 2.0 * p_nor - 1.0  # normalize to (-1, 1)
        # actually trilinear interpolation if mode = 'bilinear'
        c = F.grid_sample(c, vgrid, padding_mode='border',
                          align_corners=True,
                          mode=self.sample_mode).squeeze(-1).squeeze(-1)
        return c

    def _gather_features(self, p, c_plane):
        ''' Sums the features sampled from every known entry of ``c_plane``. '''
        c = None
        if 'grid' in c_plane:
            c = self.sample_grid_feature(p, c_plane['grid'])
        for plane in ('xz', 'xy', 'yz'):
            if plane in c_plane:
                feat = self.sample_plane_feature(p, c_plane[plane], plane=plane)
                c = feat if c is None else c + feat
        if c is None:
            raise ValueError(
                "c_plane must contain at least one of 'grid', 'xz', 'xy', "
                "'yz' when c_dim != 0")
        return c.transpose(1, 2)

    def forward(self, p, c_plane, **kwargs):
        ''' Decodes the points ``p`` conditioned on the features in ``c_plane``.

        Args:
            p (tensor): query points; ``p.shape[0]`` is used as the batch size
            c_plane (dict): feature tensors keyed by 'grid', 'xz', 'xy', 'yz'
        '''
        batch_size = p.shape[0]

        # Without a conditioning code there are no fc_c layers, so sampling
        # features would be wasted work (and would fail on an empty c_plane).
        c = None
        if self.c_dim != 0:
            c = self._gather_features(p, c_plane)

        p = p.float()

        if self.map2local:
            p = self.map2local(p)

        net = self.fc_p(p)

        for i in range(self.n_blocks):
            if c is not None:
                net = net + self.fc_c[i](c)

            net = self.blocks[i](net)

        out = self.fc_out(self.actvn(net))

        # Wider outputs are regrouped into triples.
        if self.out_dim > 3:
            out = out.reshape(batch_size, -1, 3)

        return out
--------------------------------------------------------------------------------
/SAP/src/network/utils.py:
--------------------------------------------------------------------------------
1 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
2 |
3 | import torch
4 | import torch.nn as nn
5 |
class Embedder:
    """Builds a list of embedding functions from keyword settings.

    Keys read from ``kwargs``: ``input_dims``, ``include_input``,
    ``max_freq_log2``, ``num_freqs``, ``log_sampling``, ``periodic_fns``.
    After construction ``embed_fns`` holds the functions and ``out_dim``
    the width of their concatenated output.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Populate ``self.embed_fns`` and ``self.out_dim``."""
        cfg = self.kwargs
        fns = []
        in_dims = cfg['input_dims']
        width = 0

        # Optionally pass the raw input through unchanged.
        if cfg['include_input']:
            fns.append(lambda x: x)
            width += in_dims

        top = cfg['max_freq_log2']
        count = cfg['num_freqs']

        if cfg['log_sampling']:
            bands = 2. ** torch.linspace(0., top, count)
        else:
            bands = torch.linspace(2.**0., 2.**top, count)

        def scaled(fn, band):
            # Factory so that each closure keeps its own (fn, band) pair.
            return lambda x: fn(x * band)

        for band in bands:
            for fn in cfg['periodic_fns']:
                fns.append(scaled(fn, band))
                width += in_dims

        self.embed_fns = fns
        self.out_dim = width

    def embed(self, inputs):
        """Apply every embedding function and concatenate on the last axis."""
        pieces = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(pieces, -1)
38 |
def get_embedder(multires, d_in=3):
    """Create a sin/cos positional embedder with ``multires`` frequencies.

    Args:
        multires (int): number of frequency bands; the largest exponent is
            ``multires - 1``.
        d_in (int): dimensionality of the input.

    Returns:
        tuple: ``(embed, out_dim)`` where ``embed(x)`` applies the embedding
        and ``out_dim`` is the size of its last axis.
    """
    embedder_obj = Embedder(
        include_input=True,
        input_dims=d_in,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
52 |
def normalize_coordinate(p, plane='xz'):
    ''' Selects the two coordinates of a plane and clips them into [0, 1).

    The input ``p`` is left untouched: list indexing produces a copy, and
    only that copy is modified.

    Args:
        p (tensor): points, indexed as ``p[:, :, axis]``
        plane (str): plane feature type, ['xz', 'xy', 'yz']; any other value
            is treated as 'yz'

    Returns:
        tensor: the two selected coordinates, with values >= 1 replaced by
        ``1 - 10e-6`` and negative values replaced by 0
    '''
    if plane == 'xz':
        axes = [0, 2]
    elif plane == 'xy':
        axes = [0, 1]
    else:
        axes = [1, 2]

    xy = p[:, :, axes]
    # Clip outliers. The masks are applied unconditionally: a max()/min()
    # pre-check adds nothing and raises on an empty point set.
    xy[xy >= 1] = 1 - 10e-6  # note: 10e-6 == 1e-5
    xy[xy < 0] = 0.0
    return xy
75 |
76 |
def normalize_3d_coordinate(p):
    ''' Clips coordinates into [0, 1) for unit cube experiments.

    The tensor is modified IN PLACE and also returned; pass a clone if the
    original values are still needed.

    Args:
        p (tensor): coordinates

    Returns:
        tensor: ``p`` itself, with values >= 1 replaced by ``1 - 10e-6`` and
        negative values replaced by 0
    '''
    # Applied unconditionally: a max()/min() pre-check adds nothing and
    # raises on an empty tensor.
    p[p >= 1] = 1 - 10e-6  # note: 10e-6 == 1e-5
    p[p < 0] = 0.0
    return p
85 |
def coordinate2index(x, reso, coord_type='2d'):
    ''' Converts coordinates into flattened cell indices of a regular grid.

    Each coordinate is multiplied by ``reso`` and truncated to an integer;
    the per-axis integers are then combined into one index with the first
    axis varying fastest.

    Args:
        x (tensor): coordinates, indexed as ``x[:, :, axis]``
            (presumably already normalized to [0, 1) -- confirm with callers)
        reso (int): defined resolution
        coord_type (str): coordinate type, '2d' (plane) or '3d' (grid)

    Returns:
        tensor: long indices with a singleton axis inserted at position 1

    Raises:
        ValueError: if ``coord_type`` is neither '2d' nor '3d'
    '''
    x = (x * reso).long()
    if coord_type == '2d':  # plane
        index = x[:, :, 0] + reso * x[:, :, 1]
    elif coord_type == '3d':  # grid
        index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
    else:
        raise ValueError(
            "coord_type must be '2d' or '3d', got %r" % (coord_type,))
    index = index[:, None, :]
    return index
102 |
103 |
class map2local(object):
    ''' Maps coordinates to their relative position inside a voxel.

    Args:
        s (float): the defined voxel size
        pos_encoding (str): accepted but currently unused
    '''

    def __init__(self, s, pos_encoding='linear'):
        super().__init__()
        self.s = s

    def __call__(self, p):
        voxel = self.s
        # Offset within the voxel, rescaled by the voxel size.
        local = (p % voxel) / voxel
        negative = local < 0
        local[negative] = 0.0
        return local
123 |
124 | # Resnet Blocks
class ResnetBlockFC(nn.Module):
    ''' Fully connected residual block: ``shortcut(x) + fc_1(relu(fc_0(relu(x))))``.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension, defaults to ``size_in``
        size_h (int): hidden dimension, defaults to ``min(size_in, size_out)``
        siren (bool): accepted but currently unused
    '''

    def __init__(self, size_in, size_out=None, size_h=None, siren=False):
        super().__init__()
        size_out = size_in if size_out is None else size_out
        size_h = min(size_in, size_out) if size_h is None else size_h

        self.size_in, self.size_h, self.size_out = size_in, size_h, size_out

        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        # A linear projection is only needed when the widths differ.
        needs_projection = size_in != size_out
        self.shortcut = nn.Linear(size_in, size_out, bias=False) if needs_projection else None

        # Zero the last layer's weights so the residual branch starts out
        # contributing only its bias.
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        hidden = self.fc_0(self.actvn(x))
        residual = self.fc_1(self.actvn(hidden))
        identity = x if self.shortcut is None else self.shortcut(x)
        return identity + residual
--------------------------------------------------------------------------------
/cfgs/Tooth_models/PoinTr.yaml:
--------------------------------------------------------------------------------
1 | optimizer : {
2 | type: AdamW,
3 | kwargs: {
4 | lr : 0.0005,
5 | weight_decay : 0.0005
6 | }}
7 | scheduler: {
8 | type: LambdaLR,
9 | kwargs: {
10 | decay_step: 21,
11 | lr_decay: 0.76,
12 | lowest_decay: 0.02 # min lr = lowest_decay * lr
13 | }}
14 | bnmscheduler: {
15 | type: Lambda,
16 | kwargs: {
17 | decay_step: 21,
18 | bn_decay: 0.5,
19 | bn_momentum: 0.9,
20 | lowest_decay: 0.01
21 | }}
22 |
23 | dataset : {
24 | train : { _base_: cfgs/dataset_configs/Tooth.yaml,
25 | others: {subset: 'train'}},
26 | val : { _base_: cfgs/dataset_configs/Tooth.yaml,
27 | others: {subset: 'test'}},
28 | test : { _base_: cfgs/dataset_configs/Tooth.yaml,
29 | others: {subset: 'test'}}}
30 | model : {
31 | NAME: PoinTr, num_pred: 1568, num_query: 96, knn_layer: 1, trans_dim: 384}
32 | total_bs : 1
33 | step_per_update : 1
34 | max_epoch : 2
35 |
36 | consider_metric: CDL2
37 |
38 |
39 |
--------------------------------------------------------------------------------
/cfgs/dataset_configs/Tooth.yaml:
--------------------------------------------------------------------------------
1 | NAME: crown
2 | DATA_PATH: data/dental/crown
3 | N_POINTS: 9999
4 | PC_PATH: C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\e
5 | #C:\\Users\\Golriz\\OneDrive - polymtl.ca\\Desktop\\back-updata-sap\\data-pointr\\data-pointr
6 |
--------------------------------------------------------------------------------
/data/dental/crown/test.txt:
--------------------------------------------------------------------------------
1 | 6581-36
--------------------------------------------------------------------------------
/data/dental/crown/train.txt:
--------------------------------------------------------------------------------
1 | 6581-36
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_dataset_from_cfg
2 | import datasets.crowndataset
--------------------------------------------------------------------------------
/datasets/__pycache__/KITTIDataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/KITTIDataset.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/PCNDataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/PCNDataset.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/ShapeNet55Dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/ShapeNet55Dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/Wrapping_Python_vtk_util_numpy_support.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/Wrapping_Python_vtk_util_numpy_support.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/build.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/build.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/crowndataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/crowndataset.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/data_transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/data_transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/easy_mesh_vtk.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/easy_mesh_vtk.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/io.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/datasets/__pycache__/io.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/build.py:
--------------------------------------------------------------------------------
1 | from utils import registry
2 |
3 |
# Registry instance that dataset classes are built from.
DATASETS = registry.Registry('dataset')


def build_dataset_from_cfg(cfg, default_args=None):
    """Build a dataset through the ``DATASETS`` registry.

    Args:
        cfg: dataset configuration, forwarded to ``DATASETS.build``.
        default_args: optional extra arguments, forwarded unchanged.

    Returns:
        Whatever ``DATASETS.build`` returns for ``cfg``.
    """
    dataset = DATASETS.build(cfg, default_args=default_args)
    return dataset
16 |
17 |
18 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Thibault GROUEIX
3 | # @Date: 2019-08-07 20:54:24
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-18 15:06:25
6 | # @Email: cshzxie@gmail.com
7 |
8 | import torch
9 |
10 | import chamfer
11 |
12 |
class ChamferFunction(torch.autograd.Function):
    """Autograd bridge to the compiled ``chamfer`` extension module."""

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        """Return ``(dist1, dist2)`` from ``chamfer.forward``.

        The index tensors it also returns are kept for the backward pass.
        """
        results = chamfer.forward(xyz1, xyz2)
        dist1, dist2, idx1, idx2 = results
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2

    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        """Return the gradients for ``xyz1`` and ``xyz2``."""
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        grads = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
        grad_xyz1, grad_xyz2 = grads
        return grad_xyz1, grad_xyz2
26 |
27 |
class ChamferDistanceL2(torch.nn.Module):
    ''' Chamfer Distance L2.

    Returns ``mean(dist1) + mean(dist2)`` of the two distance tensors
    produced by ``ChamferFunction``.

    Args:
        ignore_zeros (bool): if set, points whose coordinates sum to zero are
            removed from both clouds first. Only applied when the batch size
            is 1.
    '''
    def __init__(self, ignore_zeros=False):
        super().__init__()
        self.ignore_zeros = ignore_zeros

    def forward(self, xyz1, xyz2):
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            # NOTE(review): the filter tests the coordinate SUM, so a point
            # such as (1, -1, 0) is dropped as well -- confirm this is
            # intended.
            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        return torch.mean(dist1) + torch.mean(dist2)
45 |
class ChamferDistanceL2_split(torch.nn.Module):
    ''' Chamfer Distance L2, with the two directions reported separately.

    Returns the tuple ``(mean(dist1), mean(dist2))`` of the two distance
    tensors produced by ``ChamferFunction``.

    Args:
        ignore_zeros (bool): if set, points whose coordinates sum to zero are
            removed from both clouds first. Only applied when the batch size
            is 1.
    '''
    def __init__(self, ignore_zeros=False):
        super().__init__()
        self.ignore_zeros = ignore_zeros

    def forward(self, xyz1, xyz2):
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            # NOTE(review): the filter tests the coordinate SUM, so a point
            # such as (1, -1, 0) is dropped as well -- confirm this is
            # intended.
            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        return torch.mean(dist1), torch.mean(dist2)
63 |
class ChamferDistanceL1(torch.nn.Module):
    ''' Chamfer Distance L1.

    Takes the square root of the two distance tensors produced by
    ``ChamferFunction`` and returns the average of their means.

    Args:
        ignore_zeros (bool): if set, points whose coordinates sum to zero are
            removed from both clouds first. Only applied when the batch size
            is 1.
    '''
    def __init__(self, ignore_zeros=False):
        super().__init__()
        self.ignore_zeros = ignore_zeros

    def forward(self, xyz1, xyz2):
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            # NOTE(review): the filter tests the coordinate SUM, so a point
            # such as (1, -1, 0) is dropped as well -- confirm this is
            # intended.
            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        dist1 = torch.sqrt(dist1)
        dist2 = torch.sqrt(dist2)
        return (torch.mean(dist1) + torch.mean(dist2))/2
85 |
86 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/extensions/chamfer_dist/build/lib.linux-x86_64-3.8/chamfer.cpython-38-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/build/lib.linux-x86_64-3.8/chamfer.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/build.ninja:
--------------------------------------------------------------------------------
1 | ninja_required_version = 1.3
2 | cxx = c++
3 | nvcc = /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/bin/nvcc
4 |
5 | cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -fPIC -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c
6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=chamfer -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
7 | cuda_cflags = -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c
8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=chamfer -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++14
9 | ldflags =
10 |
11 | rule compile
12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
13 | depfile = $out.d
14 | deps = gcc
15 |
16 | rule cuda_compile
17 | depfile = $out.d
18 | deps = gcc
19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
20 |
21 |
22 |
23 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/chamfer_dist/chamfer.cu
24 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer_cuda.o: compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/chamfer_dist/chamfer_cuda.cpp
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer.o
--------------------------------------------------------------------------------
/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/build/temp.linux-x86_64-3.8/chamfer_cuda.o
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: chamfer
3 | Version: 2.0.0
4 | Summary: UNKNOWN
5 | License: UNKNOWN
6 | Platform: UNKNOWN
7 |
8 | UNKNOWN
9 |
10 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | chamfer.cu
2 | chamfer_cuda.cpp
3 | setup.py
4 | chamfer.egg-info/PKG-INFO
5 | chamfer.egg-info/SOURCES.txt
6 | chamfer.egg-info/dependency_links.txt
7 | chamfer.egg-info/top_level.txt
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | chamfer
2 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * @Author: Haozhe Xie
3 | * @Date: 2019-08-07 20:54:24
4 | * @Last Modified by: Haozhe Xie
5 | * @Last Modified time: 2019-12-10 10:33:50
6 | * @Email: cshzxie@gmail.com
7 | */
8 |
#include <torch/extension.h>
#include <vector>
11 |
12 | std::vector chamfer_cuda_forward(torch::Tensor xyz1,
13 | torch::Tensor xyz2);
14 |
15 | std::vector chamfer_cuda_backward(torch::Tensor xyz1,
16 | torch::Tensor xyz2,
17 | torch::Tensor idx1,
18 | torch::Tensor idx2,
19 | torch::Tensor grad_dist1,
20 | torch::Tensor grad_dist2);
21 |
22 | std::vector chamfer_forward(torch::Tensor xyz1,
23 | torch::Tensor xyz2) {
24 | return chamfer_cuda_forward(xyz1, xyz2);
25 | }
26 |
27 | std::vector chamfer_backward(torch::Tensor xyz1,
28 | torch::Tensor xyz2,
29 | torch::Tensor idx1,
30 | torch::Tensor idx2,
31 | torch::Tensor grad_dist1,
32 | torch::Tensor grad_dist2) {
33 | return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
34 | }
35 |
// Python bindings: registers chamfer_forward / chamfer_backward as
// `forward` and `backward` on the module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)");
  m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)");
}
40 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/dist/chamfer-2.0.0-py3.8-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/chamfer_dist/dist/chamfer-2.0.0-py3.8-linux-x86_64.egg
--------------------------------------------------------------------------------
/extensions/chamfer_dist/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-08-07 20:54:24
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-10 10:04:25
6 | # @Email: cshzxie@gmail.com
7 |
8 | from setuptools import setup
9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
10 |
# Extension module built from the C++ binding file and the CUDA kernel file.
chamfer_extension = CUDAExtension('chamfer', [
    'chamfer_cuda.cpp',
    'chamfer.cu',
])

setup(
    name='chamfer',
    version='2.0.0',
    ext_modules=[chamfer_extension],
    cmdclass={'build_ext': BuildExtension},
)
20 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-12-10 10:38:01
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-26 14:21:36
6 | # @Email: cshzxie@gmail.com
7 | #
8 | # Note:
9 | # - Replace float -> double, kFloat -> kDouble in chamfer.cu
10 |
11 | import os
12 | import sys
13 | import torch
14 | import unittest
15 |
16 |
17 | from torch.autograd import gradcheck
18 |
19 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
20 | from extensions.chamfer_dist import ChamferFunction
21 |
22 |
class ChamferDistanceTestCase(unittest.TestCase):
    """Gradient check for ``ChamferFunction`` on random double-precision clouds."""

    def test_chamfer_dist(self):
        # Draw x before y so the random stream is consumed in the same order.
        x = torch.rand(4, 64, 3).double().requires_grad_(True)
        y = torch.rand(4, 128, 3).double().requires_grad_(True)
        inputs = [x.cuda(), y.cuda()]
        passed = gradcheck(ChamferFunction.apply, inputs)
        print(passed)
30 |
31 |
32 |
if __name__ == '__main__':
    # Run the test case when the file is executed directly. A leftover
    # pdb.set_trace() used to stand here instead, which blocked
    # non-interactive runs and never executed the tests.
    unittest.main()
39 |
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-12-19 16:55:15
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-26 13:15:14
6 | # @Email: cshzxie@gmail.com
7 |
8 | import torch
9 |
10 | import cubic_feature_sampling
11 |
12 |
class CubicFeatureSamplingFunction(torch.autograd.Function):
    """Autograd bridge to the compiled ``cubic_feature_sampling`` extension."""

    @staticmethod
    def forward(ctx, ptcloud, cubic_features, neighborhood_size=1):
        """Return the point features sampled from ``cubic_features``.

        ``scale`` is taken from ``cubic_features.size(2)``.
        """
        scale = cubic_features.size(2)
        outputs = cubic_feature_sampling.forward(scale, neighborhood_size, ptcloud,
                                                 cubic_features)
        point_features, grid_pt_indexes = outputs
        # The two integers are wrapped in tensors so they can be stored
        # alongside the index tensor for the backward pass.
        scale_t = torch.Tensor([scale])
        size_t = torch.Tensor([neighborhood_size])
        ctx.save_for_backward(scale_t, size_t, grid_pt_indexes)
        return point_features

    @staticmethod
    def backward(ctx, grad_point_features):
        """Return gradients for ``ptcloud`` and ``cubic_features``.

        ``neighborhood_size`` receives no gradient.
        """
        scale_t, size_t, grid_pt_indexes = ctx.saved_tensors
        scale = int(scale_t.item())
        neighborhood_size = int(size_t.item())
        grad_in = grad_point_features.contiguous()
        grads = cubic_feature_sampling.backward(scale, neighborhood_size,
                                                grad_in, grid_pt_indexes)
        grad_ptcloud, grad_cubic_features = grads
        return grad_ptcloud, grad_cubic_features, None
31 |
32 |
class CubicFeatureSampling(torch.nn.Module):
    """Module wrapper around ``CubicFeatureSamplingFunction``."""

    def __init__(self):
        super().__init__()

    def forward(self, ptcloud, cubic_features, neighborhood_size=1):
        """Rescale ``ptcloud`` and sample ``cubic_features`` at the result.

        With ``h = cubic_features.size(2) / 2`` the points are mapped to
        ``ptcloud * h + h`` before sampling (presumably from [-1, 1] to
        grid coordinates -- confirm with callers).
        """
        half = cubic_features.size(2) / 2
        shifted = ptcloud * half + half
        return CubicFeatureSamplingFunction.apply(shifted, cubic_features, neighborhood_size)
41 |
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/build/lib.linux-x86_64-3.8/cubic_feature_sampling.cpython-38-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/build/lib.linux-x86_64-3.8/cubic_feature_sampling.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/build.ninja:
--------------------------------------------------------------------------------
1 | ninja_required_version = 1.3
2 | cxx = c++
3 | nvcc = /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/bin/nvcc
4 |
5 | cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -fPIC -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c
6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=cubic_feature_sampling -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
7 | cuda_cflags = -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c
8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=cubic_feature_sampling -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++14
9 | ldflags =
10 |
11 | rule compile
12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
13 | depfile = $out.d
14 | deps = gcc
15 |
16 | rule cuda_compile
17 | depfile = $out.d
18 | deps = gcc
19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
20 |
21 |
22 |
23 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/cubic_feature_sampling/cubic_feature_sampling.cu
24 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling_cuda.o: compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/cubic_feature_sampling/cubic_feature_sampling_cuda.cpp
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling.o
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/build/temp.linux-x86_64-3.8/cubic_feature_sampling_cuda.o
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/cubic_feature_sampling.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: cubic-feature-sampling
3 | Version: 1.1.0
4 | Summary: UNKNOWN
5 | License: UNKNOWN
6 | Platform: UNKNOWN
7 |
8 | UNKNOWN
9 |
10 |
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/cubic_feature_sampling.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | cubic_feature_sampling.cu
2 | cubic_feature_sampling_cuda.cpp
3 | setup.py
4 | cubic_feature_sampling.egg-info/PKG-INFO
5 | cubic_feature_sampling.egg-info/SOURCES.txt
6 | cubic_feature_sampling.egg-info/dependency_links.txt
7 | cubic_feature_sampling.egg-info/top_level.txt
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/cubic_feature_sampling.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/cubic_feature_sampling.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | cubic_feature_sampling
2 |
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/cubic_feature_sampling_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * @Author: Haozhe Xie
3 | * @Date: 2019-12-19 17:04:38
4 | * @Last Modified by: Haozhe Xie
5 | * @Last Modified time: 2020-06-17 14:50:22
6 | * @Email: cshzxie@gmail.com
7 | */
8 |
9 | #include <ATen/cuda/CUDAContext.h>
10 | #include <torch/extension.h>
11 |
12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
13 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
14 | #define CHECK_CONTIGUOUS(x) \
15 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
16 | #define CHECK_INPUT(x) \
17 | CHECK_CUDA(x); \
18 | CHECK_CONTIGUOUS(x)
19 |
20 | std::vector<torch::Tensor> cubic_feature_sampling_cuda_forward(  // declared only; not defined in this file
21 |     int scale,
22 |     int neighborhood_size,
23 |     torch::Tensor ptcloud,
24 |     torch::Tensor cubic_features,
25 |     cudaStream_t stream);
26 |
27 | std::vector<torch::Tensor> cubic_feature_sampling_cuda_backward(  // declared only; not defined in this file
28 |     int scale,
29 |     int neighborhood_size,
30 |     torch::Tensor grad_point_features,
31 |     torch::Tensor grid_pt_indexes,
32 |     cudaStream_t stream);
33 |
34 | std::vector<torch::Tensor> cubic_feature_sampling_forward(  // bound to Python as "forward"
35 |     int scale,
36 |     int neighborhood_size,
37 |     torch::Tensor ptcloud,
38 |     torch::Tensor cubic_features) {
39 |   CHECK_INPUT(ptcloud);  // each input must be a contiguous CUDA tensor
40 |   CHECK_INPUT(cubic_features);
41 |
42 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
43 |   return cubic_feature_sampling_cuda_forward(scale, neighborhood_size, ptcloud,
44 |                                              cubic_features, stream);
45 | }
46 |
47 | std::vector<torch::Tensor> cubic_feature_sampling_backward(  // bound to Python as "backward"
48 |     int scale,
49 |     int neighborhood_size,
50 |     torch::Tensor grad_point_features,
51 |     torch::Tensor grid_pt_indexes) {
52 |   CHECK_INPUT(grad_point_features);  // each input must be a contiguous CUDA tensor
53 |   CHECK_INPUT(grid_pt_indexes);
54 |
55 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
56 |   return cubic_feature_sampling_cuda_backward(
57 |       scale, neighborhood_size, grad_point_features, grid_pt_indexes, stream);
58 | }
59 |
60 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
61 | m.def("forward", &cubic_feature_sampling_forward,
62 | "Cubic Feature Sampling forward (CUDA)");
63 | m.def("backward", &cubic_feature_sampling_backward,
64 | "Cubic Feature Sampling backward (CUDA)");
65 | }
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/dist/cubic_feature_sampling-1.1.0-py3.8-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/cubic_feature_sampling/dist/cubic_feature_sampling-1.1.0-py3.8-linux-x86_64.egg
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-12-19 17:03:06
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-26 14:02:06
6 | # @Email: cshzxie@gmail.com
7 |
8 | from setuptools import setup
9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
10 |
11 | setup(name='cubic_feature_sampling',
12 | version='1.1.0',
13 | ext_modules=[
14 | CUDAExtension('cubic_feature_sampling', ['cubic_feature_sampling_cuda.cpp', 'cubic_feature_sampling.cu']),
15 | ],
16 | cmdclass={'build_ext': BuildExtension})
17 |
--------------------------------------------------------------------------------
/extensions/cubic_feature_sampling/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-12-20 11:50:50
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-26 13:52:33
6 | # @Email: cshzxie@gmail.com
7 | #
8 | # Note:
9 | # - Replace float -> double, kFloat -> kDouble in cubic_feature_sampling.cu
10 |
11 | import os
12 | import sys
13 | import torch
14 | import unittest
15 |
16 | from torch.autograd import gradcheck
17 |
18 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
19 | from extensions.cubic_feature_sampling import CubicFeatureSamplingFunction
20 |
21 |
22 | class CubicFeatureSamplingTestCase(unittest.TestCase):
23 | def test_neighborhood_size_1(self):
24 | ptcloud = torch.rand(2, 64, 3) * 2 - 1
25 | cubic_features = torch.rand(2, 4, 8, 8, 8)
26 | ptcloud.requires_grad = True
27 | cubic_features.requires_grad = True
28 | self.assertTrue(
29 | gradcheck(CubicFeatureSamplingFunction.apply,
30 | [ptcloud.double().cuda(), cubic_features.double().cuda()]))
31 |
32 | def test_neighborhood_size_2(self):
33 | ptcloud = torch.rand(2, 32, 3) * 2 - 1
34 | cubic_features = torch.rand(2, 2, 8, 8, 8)
35 | ptcloud.requires_grad = True
36 | cubic_features.requires_grad = True
37 | self.assertTrue(
38 | gradcheck(CubicFeatureSamplingFunction.apply,
39 | [ptcloud.double().cuda(), cubic_features.double().cuda(), 2]))
40 |
41 | def test_neighborhood_size_3(self):
42 | ptcloud = torch.rand(1, 32, 3) * 2 - 1
43 | cubic_features = torch.rand(1, 2, 16, 16, 16)
44 | ptcloud.requires_grad = True
45 | cubic_features.requires_grad = True
46 | self.assertTrue(
47 | gradcheck(CubicFeatureSamplingFunction.apply,
48 | [ptcloud.double().cuda(), cubic_features.double().cuda(), 3]))
49 |
50 |
51 | if __name__ == '__main__':
52 | unittest.main()
53 |
--------------------------------------------------------------------------------
/extensions/gridding/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-11-15 20:33:52
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-30 09:55:53
6 | # @Email: cshzxie@gmail.com
7 |
8 | import torch
9 |
10 | import gridding
11 |
12 |
13 | class GriddingFunction(torch.autograd.Function):
14 | @staticmethod
15 | def forward(ctx, scale, ptcloud):
16 | grid, grid_pt_weights, grid_pt_indexes = gridding.forward(-scale, scale - 1, -scale, scale - 1, -scale,
17 | scale - 1, ptcloud)
18 | # print(grid.size()) # torch.Size(batch_size, n_grid_vertices)
19 | # print(grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3)
20 | # print(grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8)
21 | ctx.save_for_backward(grid_pt_weights, grid_pt_indexes)
22 |
23 | return grid
24 |
25 | @staticmethod
26 | def backward(ctx, grad_grid):
27 | grid_pt_weights, grid_pt_indexes = ctx.saved_tensors
28 | grad_ptcloud = gridding.backward(grid_pt_weights, grid_pt_indexes, grad_grid)
29 | # print(grad_ptcloud.size()) # torch.Size(batch_size, n_pts, 3)
30 |
31 | return None, grad_ptcloud
32 |
33 |
34 | class Gridding(torch.nn.Module):
35 | def __init__(self, scale=1):
36 | super(Gridding, self).__init__()
37 | self.scale = scale // 2
38 |
39 | def forward(self, ptcloud):
40 | ptcloud = ptcloud * self.scale
41 | _ptcloud = torch.split(ptcloud, 1, dim=0)
42 | grids = []
43 | for p in _ptcloud:
44 | non_zeros = torch.sum(p, dim=2).ne(0)
45 | p = p[non_zeros].unsqueeze(dim=0)
46 | grids.append(GriddingFunction.apply(self.scale, p.contiguous()))
47 |
48 | return torch.cat(grids, dim=0).contiguous()
49 |
50 |
51 | class GriddingReverseFunction(torch.autograd.Function):
52 | @staticmethod
53 | def forward(ctx, scale, grid):
54 | ptcloud = gridding.rev_forward(scale, grid)
55 | ctx.save_for_backward(torch.Tensor([scale]), grid, ptcloud)
56 | return ptcloud
57 |
58 | @staticmethod
59 | def backward(ctx, grad_ptcloud):
60 | scale, grid, ptcloud = ctx.saved_tensors
61 | scale = int(scale.item())
62 | grad_grid = gridding.rev_backward(ptcloud, grid, grad_ptcloud)
63 | grad_grid = grad_grid.view(-1, scale, scale, scale)
64 | return None, grad_grid
65 |
66 |
67 | class GriddingReverse(torch.nn.Module):
68 | def __init__(self, scale=1):
69 | super(GriddingReverse, self).__init__()
70 | self.scale = scale
71 |
72 | def forward(self, grid):
73 | ptcloud = GriddingReverseFunction.apply(self.scale, grid)
74 | return ptcloud / self.scale * 2
75 |
--------------------------------------------------------------------------------
/extensions/gridding/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/extensions/gridding/build/lib.linux-x86_64-3.8/gridding.cpython-38-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/build/lib.linux-x86_64-3.8/gridding.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/extensions/gridding/build/temp.linux-x86_64-3.8/build.ninja:
--------------------------------------------------------------------------------
1 | ninja_required_version = 1.3
2 | cxx = c++
3 | nvcc = /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/bin/nvcc
4 |
5 | cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -fPIC -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c
6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=gridding -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
7 | cuda_cflags = -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c
8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=gridding -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++14
9 | ldflags =
10 |
11 | rule compile
12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
13 | depfile = $out.d
14 | deps = gcc
15 |
16 | rule cuda_compile
17 | depfile = $out.d
18 | deps = gcc
19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
20 |
21 |
22 |
23 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/build/temp.linux-x86_64-3.8/gridding.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/gridding.cu
24 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_cuda.o: compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/gridding_cuda.cpp
25 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_reverse.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding/gridding_reverse.cu
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/extensions/gridding/build/temp.linux-x86_64-3.8/gridding.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/build/temp.linux-x86_64-3.8/gridding.o
--------------------------------------------------------------------------------
/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_cuda.o
--------------------------------------------------------------------------------
/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_reverse.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/build/temp.linux-x86_64-3.8/gridding_reverse.o
--------------------------------------------------------------------------------
/extensions/gridding/dist/gridding-2.1.0-py3.8-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding/dist/gridding-2.1.0-py3.8-linux-x86_64.egg
--------------------------------------------------------------------------------
/extensions/gridding/gridding.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: gridding
3 | Version: 2.1.0
4 | Summary: UNKNOWN
5 | License: UNKNOWN
6 | Platform: UNKNOWN
7 |
8 | UNKNOWN
9 |
10 |
--------------------------------------------------------------------------------
/extensions/gridding/gridding.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | gridding.cu
2 | gridding_cuda.cpp
3 | gridding_reverse.cu
4 | setup.py
5 | gridding.egg-info/PKG-INFO
6 | gridding.egg-info/SOURCES.txt
7 | gridding.egg-info/dependency_links.txt
8 | gridding.egg-info/top_level.txt
--------------------------------------------------------------------------------
/extensions/gridding/gridding.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/extensions/gridding/gridding.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | gridding
2 |
--------------------------------------------------------------------------------
/extensions/gridding/gridding_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * @Author: Haozhe Xie
3 | * @Date: 2019-11-13 10:52:53
4 | * @Last Modified by: Haozhe Xie
5 | * @Last Modified time: 2020-06-17 14:52:32
6 | * @Email: cshzxie@gmail.com
7 | */
8 |
9 | #include <ATen/cuda/CUDAContext.h>
10 | #include <torch/extension.h>
11 |
12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
13 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
14 | #define CHECK_CONTIGUOUS(x) \
15 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
16 | #define CHECK_INPUT(x) \
17 | CHECK_CUDA(x); \
18 | CHECK_CONTIGUOUS(x)
19 |
20 | std::vector<torch::Tensor> gridding_cuda_forward(float min_x,  // declared only; not defined in this file
21 |                                                  float max_x,
22 |                                                  float min_y,
23 |                                                  float max_y,
24 |                                                  float min_z,
25 |                                                  float max_z,
26 |                                                  torch::Tensor ptcloud,
27 |                                                  cudaStream_t stream);
28 |
29 | torch::Tensor gridding_cuda_backward(torch::Tensor grid_pt_weights,
30 | torch::Tensor grid_pt_indexes,
31 | torch::Tensor grad_grid,
32 | cudaStream_t stream);
33 |
34 | torch::Tensor gridding_reverse_cuda_forward(int scale,
35 | torch::Tensor grid,
36 | cudaStream_t stream);
37 |
38 | torch::Tensor gridding_reverse_cuda_backward(torch::Tensor ptcloud,
39 | torch::Tensor grid,
40 | torch::Tensor grad_ptcloud,
41 | cudaStream_t stream);
42 |
43 | std::vector<torch::Tensor> gridding_forward(float min_x,  // bound to Python as "forward"
44 |                                             float max_x,
45 |                                             float min_y,
46 |                                             float max_y,
47 |                                             float min_z,
48 |                                             float max_z,
49 |                                             torch::Tensor ptcloud) {
50 |   CHECK_INPUT(ptcloud);  // must be a contiguous CUDA tensor
51 |
52 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
53 |   return gridding_cuda_forward(min_x, max_x, min_y, max_y, min_z, max_z,
54 |                                ptcloud, stream);
55 | }
56 |
57 | torch::Tensor gridding_backward(torch::Tensor grid_pt_weights,
58 | torch::Tensor grid_pt_indexes,
59 | torch::Tensor grad_grid) {
60 | CHECK_INPUT(grid_pt_weights);
61 | CHECK_INPUT(grid_pt_indexes);
62 | CHECK_INPUT(grad_grid);
63 |
64 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
65 | return gridding_cuda_backward(grid_pt_weights, grid_pt_indexes, grad_grid,
66 | stream);
67 | }
68 |
69 | torch::Tensor gridding_reverse_forward(int scale, torch::Tensor grid) {
70 | CHECK_INPUT(grid);
71 |
72 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
73 | return gridding_reverse_cuda_forward(scale, grid, stream);
74 | }
75 |
76 | torch::Tensor gridding_reverse_backward(torch::Tensor ptcloud,
77 | torch::Tensor grid,
78 | torch::Tensor grad_ptcloud) {
79 | CHECK_INPUT(ptcloud);
80 | CHECK_INPUT(grid);
81 | CHECK_INPUT(grad_ptcloud);
82 |
83 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
84 | return gridding_reverse_cuda_backward(ptcloud, grid, grad_ptcloud, stream);
85 | }
86 |
87 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
88 | m.def("forward", &gridding_forward, "Gridding forward (CUDA)");
89 | m.def("backward", &gridding_backward, "Gridding backward (CUDA)");
90 | m.def("rev_forward", &gridding_reverse_forward,
91 | "Gridding Reverse forward (CUDA)");
92 | m.def("rev_backward", &gridding_reverse_backward,
93 | "Gridding Reverse backward (CUDA)");
94 | }
95 |
--------------------------------------------------------------------------------
/extensions/gridding/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-11-13 10:51:33
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-02 17:02:16
6 | # @Email: cshzxie@gmail.com
7 |
8 | from setuptools import setup
9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
10 |
11 | setup(name='gridding',
12 | version='2.1.0',
13 | ext_modules=[
14 | CUDAExtension('gridding', ['gridding_cuda.cpp', 'gridding.cu', 'gridding_reverse.cu']),
15 | ],
16 | cmdclass={'build_ext': BuildExtension})
17 |
--------------------------------------------------------------------------------
/extensions/gridding/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-12-10 10:48:55
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-26 14:20:42
6 | # @Email: cshzxie@gmail.com
7 | #
8 | # Note:
9 | # - Replace float -> double, kFloat -> kDouble in gridding.cu and gridding_reverse.cu
10 |
11 | import os
12 | import sys
13 | import torch
14 | import unittest
15 |
16 | from torch.autograd import gradcheck
17 |
18 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
19 | from extensions.gridding import GriddingFunction, GriddingReverseFunction
20 |
21 |
22 | class GriddingTestCase(unittest.TestCase):
23 | def test_gridding_reverse_function_4(self):
24 | x = torch.rand(2, 4, 4, 4)
25 | x.requires_grad = True
26 | self.assertTrue(gradcheck(GriddingReverseFunction.apply, [4, x.double().cuda()]))
27 |
28 | def test_gridding_reverse_function_8(self):
29 | x = torch.rand(4, 8, 8, 8)
30 | x.requires_grad = True
31 | self.assertTrue(gradcheck(GriddingReverseFunction.apply, [8, x.double().cuda()]))
32 |
33 | def test_gridding_reverse_function_16(self):
34 | x = torch.rand(1, 16, 16, 16)
35 | x.requires_grad = True
36 | self.assertTrue(gradcheck(GriddingReverseFunction.apply, [16, x.double().cuda()]))
37 |
38 | def test_gridding_function_32pts(self):
39 | x = torch.rand(1, 32, 3)
40 | x.requires_grad = True
41 | self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda()]))
42 |
43 |
44 | if __name__ == '__main__':
45 | unittest.main()
46 |
--------------------------------------------------------------------------------
/extensions/gridding_loss/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-12-30 09:56:06
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2020-02-22 19:19:43
6 | # @Email: cshzxie@gmail.com
7 |
8 | import torch
9 |
10 | import gridding_distance
11 |
12 |
13 | class GriddingDistanceFunction(torch.autograd.Function):
14 | @staticmethod
15 | def forward(ctx, min_x, max_x, min_y, max_y, min_z, max_z, pred_cloud, gt_cloud):
16 | pred_grid, pred_grid_pt_weights, pred_grid_pt_indexes = gridding_distance.forward(
17 | min_x, max_x, min_y, max_y, min_z, max_z, pred_cloud)
18 | # print(pred_grid.size()) # torch.Size(batch_size, n_grid_vertices, 8)
19 | # print(pred_grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3)
20 | # print(pred_grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8)
21 | gt_grid, gt_grid_pt_weights, gt_grid_pt_indexes = gridding_distance.forward(
22 | min_x, max_x, min_y, max_y, min_z, max_z, gt_cloud)
23 | # print(gt_grid.size()) # torch.Size(batch_size, n_grid_vertices, 8)
24 | # print(gt_grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3)
25 | # print(gt_grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8)
26 |
27 | ctx.save_for_backward(pred_grid_pt_weights, pred_grid_pt_indexes, gt_grid_pt_weights, gt_grid_pt_indexes)
28 | return pred_grid, gt_grid
29 |
30 | @staticmethod
31 | def backward(ctx, grad_pred_grid, grad_gt_grid):
32 | pred_grid_pt_weights, pred_grid_pt_indexes, gt_grid_pt_weights, gt_grid_pt_indexes = ctx.saved_tensors
33 |
34 | grad_pred_cloud = gridding_distance.backward(pred_grid_pt_weights, pred_grid_pt_indexes, grad_pred_grid)
35 | # print(grad_pred_cloud.size()) # torch.Size(batch_size, n_pts, 3)
36 | grad_gt_cloud = gridding_distance.backward(gt_grid_pt_weights, gt_grid_pt_indexes, grad_gt_grid)
37 | # print(grad_gt_cloud.size()) # torch.Size(batch_size, n_pts, 3)
38 |
39 | return None, None, None, None, None, None, grad_pred_cloud, grad_gt_cloud
40 |
41 |
42 | class GriddingDistance(torch.nn.Module):
43 | def __init__(self, scale=1):
44 | super(GriddingDistance, self).__init__()
45 | self.scale = scale
46 |
47 | def forward(self, pred_cloud, gt_cloud):
48 | '''
49 | pred_cloud(b, n_pts1, 3)
50 | gt_cloud(b, n_pts2, 3)
51 | '''
52 | pred_cloud = pred_cloud * self.scale / 2
53 | gt_cloud = gt_cloud * self.scale / 2
54 |
55 | min_pred_x = torch.min(pred_cloud[:, :, 0])
56 | max_pred_x = torch.max(pred_cloud[:, :, 0])
57 | min_pred_y = torch.min(pred_cloud[:, :, 1])
58 | max_pred_y = torch.max(pred_cloud[:, :, 1])
59 | min_pred_z = torch.min(pred_cloud[:, :, 2])
60 | max_pred_z = torch.max(pred_cloud[:, :, 2])
61 |
62 | min_gt_x = torch.min(gt_cloud[:, :, 0])
63 | max_gt_x = torch.max(gt_cloud[:, :, 0])
64 | min_gt_y = torch.min(gt_cloud[:, :, 1])
65 | max_gt_y = torch.max(gt_cloud[:, :, 1])
66 | min_gt_z = torch.min(gt_cloud[:, :, 2])
67 | max_gt_z = torch.max(gt_cloud[:, :, 2])
68 |
69 | min_x = torch.floor(torch.min(min_pred_x, min_gt_x)) - 1
70 | max_x = torch.ceil(torch.max(max_pred_x, max_gt_x)) + 1
71 | min_y = torch.floor(torch.min(min_pred_y, min_gt_y)) - 1
72 | max_y = torch.ceil(torch.max(max_pred_y, max_gt_y)) + 1
73 | min_z = torch.floor(torch.min(min_pred_z, min_gt_z)) - 1
74 | max_z = torch.ceil(torch.max(max_pred_z, max_gt_z)) + 1
75 |
76 | _pred_clouds = torch.split(pred_cloud, 1, dim=0)
77 | _gt_clouds = torch.split(gt_cloud, 1, dim=0)
78 | pred_grids = []
79 | gt_grids = []
80 | for pc, gc in zip(_pred_clouds, _gt_clouds):
81 | non_zeros = torch.sum(pc, dim=2).ne(0)
82 | pc = pc[non_zeros].unsqueeze(dim=0)
83 | non_zeros = torch.sum(gc, dim=2).ne(0)
84 | gc = gc[non_zeros].unsqueeze(dim=0)
85 | pred_grid, gt_grid = GriddingDistanceFunction.apply(min_x, max_x, min_y, max_y, min_z, max_z, pc, gc)
86 | pred_grids.append(pred_grid)
87 | gt_grids.append(gt_grid)
88 |
89 | return torch.cat(pred_grids, dim=0).contiguous(), torch.cat(gt_grids, dim=0).contiguous()
90 |
91 |
class GriddingLoss(torch.nn.Module):
    """Weighted sum of L1 distances between gridded point clouds.

    One ``GriddingDistance`` is built per entry of ``scales``; the L1 distance
    between the predicted and ground-truth grids it returns is weighted by the
    entry of ``alphas`` at the same position.

    Args:
        scales: Iterable of scale values, each forwarded as
            ``GriddingDistance(scale=s)``. ``None`` (the default) means no
            scales.
        alphas: Per-scale weights; must provide at least one weight per scale.
            ``None`` (the default) means no weights.

    Raises:
        ValueError: If fewer weights than scales are supplied.
    """

    def __init__(self, scales=None, alphas=None):
        super(GriddingLoss, self).__init__()
        # ``None`` sentinels replace the former ``[]`` defaults: a mutable
        # default is one shared list, so mutating ``loss.scales`` on one
        # instance used to leak into every later default-constructed instance.
        self.scales = [] if scales is None else scales
        self.alphas = [] if alphas is None else alphas
        # Fail at construction time instead of with an IndexError in the
        # middle of the first forward pass.
        if len(self.alphas) < len(self.scales):
            raise ValueError(
                'GriddingLoss needs one alpha per scale, got %d alphas for %d scales'
                % (len(self.alphas), len(self.scales)))
        self.gridding_dists = [GriddingDistance(scale=s) for s in self.scales]
        self.l1_loss = torch.nn.L1Loss()

    def forward(self, pred_cloud, gt_cloud):
        """Return the weighted multi-scale gridding loss.

        Returns ``None`` when the module was built without any scale.
        """
        gridding_loss = None

        # zip stops after the last gridding distance, so surplus alphas are
        # ignored exactly as before.
        for alpha, gdist in zip(self.alphas, self.gridding_dists):
            pred_grid, gt_grid = gdist(pred_cloud, gt_cloud)
            term = alpha * self.l1_loss(pred_grid, gt_grid)
            gridding_loss = term if gridding_loss is None else gridding_loss + term

        return gridding_loss
115 |
--------------------------------------------------------------------------------
/extensions/gridding_loss/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/extensions/gridding_loss/build/lib.linux-x86_64-3.8/gridding_distance.cpython-38-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/build/lib.linux-x86_64-3.8/gridding_distance.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/extensions/gridding_loss/build/temp.linux-x86_64-3.8/build.ninja:
--------------------------------------------------------------------------------
1 | ninja_required_version = 1.3
2 | cxx = c++
3 | nvcc = /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/bin/nvcc
4 |
5 | cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=core-avx2 -fno-math-errno -fPIC -fPIC -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c
6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=gridding_distance -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
7 | cuda_cflags = -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/TH -I/lustre06/project/6006041/golriz/env/lib/python3.8/site-packages/torch/include/THC -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.4.2/include -I/lustre06/project/6006041/golriz/env/include -I/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.8.10/include/python3.8 -c
8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=gridding_distance -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++14
9 | ldflags =
10 |
11 | rule compile
12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
13 | depfile = $out.d
14 | deps = gcc
15 |
16 | rule cuda_compile
17 | depfile = $out.d
18 | deps = gcc
19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
20 |
21 |
22 |
23 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance.o: cuda_compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding_loss/gridding_distance.cu
24 | build /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance_cuda.o: compile /lustre06/project/6006041/golriz/PoinTr-master/extensions/gridding_loss/gridding_distance_cuda.cpp
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance.o
--------------------------------------------------------------------------------
/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/build/temp.linux-x86_64-3.8/gridding_distance_cuda.o
--------------------------------------------------------------------------------
/extensions/gridding_loss/dist/gridding_distance-1.0.0-py3.8-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/extensions/gridding_loss/dist/gridding_distance-1.0.0-py3.8-linux-x86_64.egg
--------------------------------------------------------------------------------
/extensions/gridding_loss/gridding_distance.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: gridding-distance
3 | Version: 1.0.0
4 | Summary: UNKNOWN
5 | License: UNKNOWN
6 | Platform: UNKNOWN
7 |
8 | UNKNOWN
9 |
10 |
--------------------------------------------------------------------------------
/extensions/gridding_loss/gridding_distance.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | gridding_distance.cu
2 | gridding_distance_cuda.cpp
3 | setup.py
4 | gridding_distance.egg-info/PKG-INFO
5 | gridding_distance.egg-info/SOURCES.txt
6 | gridding_distance.egg-info/dependency_links.txt
7 | gridding_distance.egg-info/top_level.txt
--------------------------------------------------------------------------------
/extensions/gridding_loss/gridding_distance.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/extensions/gridding_loss/gridding_distance.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | gridding_distance
2 |
--------------------------------------------------------------------------------
/extensions/gridding_loss/gridding_distance_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * @Author: Haozhe Xie
3 | * @Date: 2019-12-30 10:59:31
4 | * @Last Modified by: Haozhe Xie
5 | * @Last Modified time: 2020-06-17 14:52:52
6 | * @Email: cshzxie@gmail.com
7 | */
8 |
9 | #include
10 | #include
11 |
12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
13 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
14 | #define CHECK_CONTIGUOUS(x) \
15 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
16 | #define CHECK_INPUT(x) \
17 | CHECK_CUDA(x); \
18 | CHECK_CONTIGUOUS(x)
19 |
20 | std::vector gridding_distance_cuda_forward(float min_x,
21 | float max_x,
22 | float min_y,
23 | float max_y,
24 | float min_z,
25 | float max_z,
26 | torch::Tensor ptcloud,
27 | cudaStream_t stream);
28 |
29 | torch::Tensor gridding_distance_cuda_backward(torch::Tensor grid_pt_weights,
30 | torch::Tensor grid_pt_indexes,
31 | torch::Tensor grad_grid,
32 | cudaStream_t stream);
33 |
34 | std::vector gridding_distance_forward(float min_x,
35 | float max_x,
36 | float min_y,
37 | float max_y,
38 | float min_z,
39 | float max_z,
40 | torch::Tensor ptcloud) {
41 | CHECK_INPUT(ptcloud);
42 |
43 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
44 | return gridding_distance_cuda_forward(min_x, max_x, min_y, max_y, min_z,
45 | max_z, ptcloud, stream);
46 | }
47 |
48 | torch::Tensor gridding_distance_backward(torch::Tensor grid_pt_weights,
49 | torch::Tensor grid_pt_indexes,
50 | torch::Tensor grad_grid) {
51 | CHECK_INPUT(grid_pt_weights);
52 | CHECK_INPUT(grid_pt_indexes);
53 | CHECK_INPUT(grad_grid);
54 |
55 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
56 | return gridding_distance_cuda_backward(grid_pt_weights, grid_pt_indexes,
57 | grad_grid, stream);
58 | }
59 |
60 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
61 | m.def("forward", &gridding_distance_forward,
62 | "Gridding Distance Forward (CUDA)");
63 | m.def("backward", &gridding_distance_backward,
64 | "Gridding Distance Backward (CUDA)");
65 | }
66 |
--------------------------------------------------------------------------------
/extensions/gridding_loss/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-12-30 11:03:55
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-30 11:13:39
6 | # @Email: cshzxie@gmail.com
7 |
8 | from setuptools import setup
9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
10 |
# One CUDA extension built from the C++ binding file and the CUDA source.
gridding_distance_extension = CUDAExtension(
    'gridding_distance',
    ['gridding_distance_cuda.cpp', 'gridding_distance.cu'],
)

setup(
    name='gridding_distance',
    version='1.0.0',
    ext_modules=[gridding_distance_extension],
    # BuildExtension is torch's build_ext command for C++/CUDA extensions.
    cmdclass={'build_ext': BuildExtension},
)
17 |
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#SBATCH --gres=gpu:1              # Request GPU "generic resources"
#SBATCH --cpus-per-task=6         # Cores proportional to GPUs: 6 on Cedar, 16 on Graham.
#SBATCH --mem=32000M              # Memory proportional to GPUs: 32000 Cedar, 64000 Graham.
#SBATCH --time=0-03:00

# Builds and installs the extensions under ./extensions.
# Must be launched from the repository root.

source /home/golriz/projects/def-guibault/golriz/env/bin/activate

# Remember the repository root in a dedicated variable. The script used to
# overwrite HOME for this, which redirects every "~" lookup (config files,
# caches) of the tools started below into the repository.
REPO_ROOT="$(pwd)"

# install_extension NAME
# Runs "python setup.py install" inside extensions/NAME. If the directory
# cannot be entered the extension is skipped with a message, instead of
# running setup.py in whatever directory happened to be current.
install_extension() {
    local ext_dir="$REPO_ROOT/extensions/$1"
    if ! cd "$ext_dir"; then
        echo "install.sh: cannot enter $ext_dir, skipping" >&2
        return 1
    fi
    python setup.py install
}

# Chamfer Distance
install_extension chamfer_dist

# NOTE: For GRNet

# Cubic Feature Sampling
install_extension cubic_feature_sampling

# Gridding & Gridding Reverse
install_extension gridding

# Gridding Loss
install_extension gridding_loss
35 |
36 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from tools import run_net
2 | from tools import test_net
3 | from utils import parser, dist_utils, misc
4 | from utils.logger import *
5 | from utils.config import *
6 | import time
7 | import os
8 | import torch
9 | from tensorboardX import SummaryWriter
10 |
11 |
def main():
    """Parse arguments, set up logging/config/seeding and start a run.

    Calls ``test_net`` when ``args.test`` is truthy, otherwise ``run_net``.
    """
    # args
    print(os.getcwd())
    print('start main')
    args = parser.get_args()
    # Hard-coded overrides of whatever was given on the command line.
    args.config = 'cfgs/Tooth_models/PoinTr.yaml'
    args.config_SAP = 'SAP/configs/learning_based/noise_small/ours.yaml'
    # (args.test, args.ckpts and args.mode were previously toggled here by
    # hand as well; those overrides are currently disabled.)
    args.exp_name = "train-20-03-23"

    # CUDA
    args.use_gpu = torch.cuda.is_available()
    if args.use_gpu:
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    args.distributed = args.launcher != 'none'
    if args.distributed:
        dist_utils.init_dist(args.launcher)
        # re-set gpu_ids with distributed training mode
        _, world_size = dist_utils.get_dist_info()
        args.world_size = world_size
    print('in the main function')

    # logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(
        log_file=os.path.join(args.experiment_path, f'{timestamp}.log'),
        name=args.log_name,
    )

    # tensorboard writers: only needed for training, and only on local rank 0
    if not args.test:
        on_rank_zero = args.local_rank == 0
        train_writer = SummaryWriter(os.path.join(args.tfboard_path, 'train')) if on_rank_zero else None
        val_writer = SummaryWriter(os.path.join(args.tfboard_path, 'test')) if on_rank_zero else None

    # config
    config = get_config(args, logger=logger)
    config_SAP = load_config(args.config_SAP)

    # batch size: the total batch is split evenly across processes
    if args.distributed:
        assert config.total_bs % world_size == 0
        train_bs = config.total_bs // world_size
    else:
        train_bs = config.total_bs
    config.dataset.train.others.bs = train_bs

    # log
    log_args_to_file(args, 'args', logger=logger)
    log_config_to_file(config, 'config', logger=logger)
    logger.info(f'Distributed training: {args.distributed}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        # seed + rank, for augmentation
        misc.set_random_seed(args.seed + args.local_rank, deterministic=args.deterministic)
    if args.distributed:
        assert args.local_rank == torch.distributed.get_rank()

    # run
    if args.test:
        test_net(args, config)
        return
    run_net(args, config, config_SAP, train_writer, val_writer)


if __name__ == '__main__':
    main()
81 |
82 |
83 |
84 |
--------------------------------------------------------------------------------
/models/PoinTr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from pointnet2_ops import pointnet2_utils
5 | from extensions.chamfer_dist import ChamferDistanceL1
6 | from .Transformer import PCTransformer
7 | from .build import MODELS
8 | from SAP.src.model import PSR2Mesh
9 | from SAP.src.dpsr import DPSR
10 | import argparse
11 | from SAP.src import utils
12 | from SAP.src.model import Encode2Points
13 | from SAP.src.model import PSR2Mesh
def fps(pc, num, device):
    """Sub-sample ``pc`` to ``num`` points with furthest point sampling.

    ``pc`` is moved to ``device`` first. The result is laid out like the
    input, i.e. with the point axis at dim 1.
    (assumes ``pc`` is (batch, points, 3) -- TODO confirm against callers)
    """
    cloud = pc.to(device)
    sample_idx = pointnet2_utils.furthest_point_sample(cloud, num)
    # gather_operation is fed the cloud with dims 1 and 2 swapped; swap them
    # back afterwards.
    channels_first = cloud.transpose(1, 2).contiguous()
    gathered = pointnet2_utils.gather_operation(channels_first, sample_idx)
    return gathered.transpose(1, 2).contiguous()
19 |
20 |
class Fold(nn.Module):
    """Two-stage folding decoder.

    A fixed ``step x step`` 2D grid over ``[-1, 1]^2`` is concatenated with a
    per-sample feature vector and mapped to 3 channels by ``folding1``; the
    result is concatenated with the same features again and refined by
    ``folding2``.

    Args:
        in_channel: Length of the per-sample feature vector.
        step: Grid resolution per axis; ``step * step`` points are produced.
        hidden_dim: Width of the first hidden layer of both folding MLPs.
    """

    def __init__(self, in_channel, step, hidden_dim=512):
        super().__init__()

        self.in_channel = in_channel
        self.step = step

        a = torch.linspace(-1., 1., steps=step, dtype=torch.float).view(1, step).expand(step, step).reshape(1, -1)
        b = torch.linspace(-1., 1., steps=step, dtype=torch.float).view(step, 1).expand(step, step).reshape(1, -1)
        # Registered as a buffer instead of calling ``.cuda()`` here: the
        # module can now be built on CPU-only machines and the seed follows
        # ``.to()`` / ``.cuda()``. ``persistent=False`` keeps it out of the
        # state dict, so existing checkpoints still load unchanged.
        self.register_buffer('folding_seed', torch.cat([a, b], dim=0), persistent=False)

        self.folding1 = nn.Sequential(
            nn.Conv1d(in_channel + 2, hidden_dim, 1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden_dim, hidden_dim // 2, 1),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden_dim // 2, 3, 1),
        )

        self.folding2 = nn.Sequential(
            nn.Conv1d(in_channel + 3, hidden_dim, 1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden_dim, hidden_dim // 2, 1),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden_dim // 2, 3, 1),
        )

    def forward(self, x):
        """Fold the grid with the features in ``x``.

        Args:
            x: Tensor that can be viewed as ``(bs, in_channel, 1)``.

        Returns:
            Tensor of shape ``(bs, 3, step * step)``.
        """
        num_sample = self.step * self.step
        bs = x.size(0)
        features = x.view(bs, self.in_channel, 1).expand(bs, self.in_channel, num_sample)
        seed = self.folding_seed.view(1, 2, num_sample).expand(bs, 2, num_sample).to(x.device)

        # First fold: 2D grid coordinates + features -> 3 channels.
        x = torch.cat([seed, features], dim=1)
        fd1 = self.folding1(x)
        # Second fold: first-stage output + features -> 3 channels.
        x = torch.cat([fd1, features], dim=1)
        fd2 = self.folding2(x)

        return fd2
64 |
@MODELS.register_module()
class PoinTr(nn.Module):
    """Point-cloud completion network followed by an SAP / DPSR stage.

    ``forward`` rebuilds a dense cloud from the transformer queries with a
    folding decoder, rescales it, and passes it through ``Encode2Points`` and
    ``DPSR``.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.trans_dim = config.trans_dim
        self.knn_layer = config.knn_layer
        self.num_pred = config.num_pred
        self.num_query = config.num_query
        self.model_sap = Encode2Points("SAP/configs/learning_based/noise_small/ours.yaml")
        self.psr2mesh = PSR2Mesh.apply
        # Rounded square root of the number of points rebuilt per query.
        self.fold_step = int((self.num_pred // self.num_query) ** 0.5 + 0.5)
        self.base_model = PCTransformer(
            in_chans=3,
            embed_dim=self.trans_dim,
            depth=[6, 8],
            drop_rate=0.,
            num_query=self.num_query,
            knn_layer=self.knn_layer,
        )
        # rebuild a cluster point
        self.foldingnet = Fold(self.trans_dim, step=self.fold_step, hidden_dim=256)
        self.dpsr = DPSR(res=(128, 128, 128), sig=2)

        self.increase_dim = nn.Sequential(
            nn.Conv1d(self.trans_dim, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(1024, 1024, 1),
        )
        # 1027 = 1024 global feature channels + 3 coarse coordinates.
        self.reduce_map = nn.Linear(self.trans_dim + 1027, self.trans_dim)
        self.build_loss_func()

    def build_loss_func(self):
        """Create the loss used by ``get_loss``."""
        self.loss_func = ChamferDistanceL1()

    def get_loss(self, ret, gt):
        """Return ``loss_func(ret, gt)``."""
        return self.loss_func(ret, gt)

    def forward(self, xyz, min_gt, max_gt, value_std_pc, value_centroid):
        """Complete ``xyz`` and compute the PSR grid of the completed cloud.

        Args:
            xyz: Input handed to the transformer backbone.
            min_gt, max_gt: One value per sample (viewed as ``(B, 1, 1)``)
                used to rescale the de-normalised points.
            value_std_pc, value_centroid: Three values per sample (viewed as
                ``(B, 1, 3)``) used to undo the input normalisation.

        Returns:
            ``(psr_grid, rebuild_points)``; ``rebuild_points`` is the
            completed cloud before de-normalisation.
        """
        # B M C and B M 3
        q, coarse_point_cloud = self.base_model(xyz)
        B, M, _ = q.shape

        # B M 1024, then max over the query axis -> B 1024
        lifted = self.increase_dim(q.transpose(1, 2)).transpose(1, 2)
        global_feature = lifted.max(dim=1)[0]

        # B M 1027 + C
        tiled_global = global_feature.unsqueeze(-2).expand(-1, M, -1)
        rebuild_feature = torch.cat([tiled_global, q, coarse_point_cloud], dim=-1)
        # BM C
        rebuild_feature = self.reduce_map(rebuild_feature.reshape(B * M, -1))

        # NOTE: foldingNet -> offsets relative to each coarse point, B M 3 S
        relative_xyz = self.foldingnet(rebuild_feature).reshape(B, M, 3, -1)
        absolute_xyz = relative_xyz + coarse_point_cloud.unsqueeze(-1)
        # B N 3
        rebuild_points = absolute_xyz.transpose(2, 3).reshape(B, -1, 3)

        batch = rebuild_points.shape[0]

        # denormalize the data based on mean and std
        std = value_std_pc.view((batch, 1, 3))
        centroid = value_centroid.view((batch, 1, 3))
        de_point = rebuild_points * std + centroid

        # Normalize data to min and max on gt
        lower = min_gt.view(batch, 1, 1)
        upper = max_gt.view(batch, 1, 1)
        n_points = (de_point - lower) / ((upper + 1) - lower)

        # SAP
        points, normals = self.model_sap(n_points)
        psr_grid = self.dpsr(points, normals)

        return psr_grid, rebuild_points
147 |
148 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_model_from_cfg
2 | import models.PoinTr
3 | import SAP.src.model
4 | import SAP.src.dpsr
--------------------------------------------------------------------------------
/models/__pycache__/FoldingNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/FoldingNet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/GRNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/GRNet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/PCN.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/PCN.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/PoinTr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/PoinTr.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/TopNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/TopNet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/Transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/Transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/build.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/build.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/dgcnn_group.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/dgcnn_group.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/easy_mesh_vtk.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/models/__pycache__/easy_mesh_vtk.cpython-38.pyc
--------------------------------------------------------------------------------
/models/build.py:
--------------------------------------------------------------------------------
1 | from utils import registry
2 |
3 |
4 | MODELS = registry.Registry('models')
5 |
6 |
def build_model_from_cfg(cfg, **kwargs):
    """Build a model through the ``MODELS`` registry.

    Args:
        cfg: Model configuration, passed unchanged to ``MODELS.build``.
        **kwargs: Extra keyword arguments passed unchanged to
            ``MODELS.build``.

    Returns:
        Whatever ``MODELS.build(cfg, **kwargs)`` returns.
    """
    return MODELS.build(cfg, **kwargs)
16 |
17 |
18 |
--------------------------------------------------------------------------------
/models/dgcnn_group.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from pointnet2_ops import pointnet2_utils
4 | from knn_cuda import KNN
5 | knn = KNN(k=16, transpose_mode=False)
6 |
7 |
class DGCNN_Grouper(nn.Module):
    """Four stacked kNN edge-feature layers applied to a point cloud.

    The neighbourhood size is fixed to 16 and must match the module-level
    ``knn`` object.
    """

    def __init__(self):
        super().__init__()
        # K has to be 16
        self.input_trans = nn.Conv1d(3, 8, 1)

        # Each layer's input width is twice the previous feature width because
        # get_graph_feature concatenates (neighbour - centre, centre).
        self.layer1 = self._edge_layer(16, 32)
        self.layer2 = self._edge_layer(64, 64)
        self.layer3 = self._edge_layer(128, 64)
        self.layer4 = self._edge_layer(128, 128)

    @staticmethod
    def _edge_layer(in_channels, out_channels):
        """1x1 conv (no bias) + GroupNorm(4 groups) + LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.GroupNorm(4, out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

    @staticmethod
    def fps_downsample(coor, x, num_group):
        """Keep ``num_group`` furthest-point samples of ``coor`` and ``x``.

        Returns ``(new_coor, new_x)``: the first three channels and the
        remaining channels of the gathered tensor.
        """
        xyz = coor.transpose(1, 2).contiguous()  # b, n, 3
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, num_group)

        # Gather coordinates and features together so both use the same index.
        stacked = torch.cat([coor, x], dim=1)
        picked = pointnet2_utils.gather_operation(stacked, fps_idx)

        return picked[:, :3], picked[:, 3:]

    @staticmethod
    def get_graph_feature(coor_q, x_q, coor_k, x_k):
        """Build edge features between each query point and its 16 neighbours.

        coor: bs, 3, np; x: bs, c, np. The result concatenates
        ``neighbour - centre`` and ``centre`` along the channel axis.
        """
        k = 16
        batch_size = x_k.size(0)
        num_points_k = x_k.size(2)
        num_points_q = x_q.size(2)

        # Neighbour indices carry no gradient.
        with torch.no_grad():
            _, idx = knn(coor_k, coor_q)  # bs k np
            assert idx.shape[1] == k
            # Offset per-sample indices so they address the flattened batch.
            batch_offset = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k
            flat_idx = (idx + batch_offset).view(-1)

        num_dims = x_k.size(1)
        flat_keys = x_k.transpose(2, 1).contiguous().view(batch_size * num_points_k, -1)
        neighbours = flat_keys[flat_idx, :]
        neighbours = neighbours.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous()
        centre = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k)
        return torch.cat((neighbours - centre, centre), dim=1)

    def forward(self, x):
        """Return ``(coor, f)``: the unchanged input and its features.

        x: bs, 3, np
        """
        coor = x
        f = self.input_trans(x)

        # Every stage: build edge features, transform, max over neighbours.
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            f = layer(self.get_graph_feature(coor, f, coor, f))
            f = f.max(dim=-1, keepdim=False)[0]

        return coor, f
--------------------------------------------------------------------------------
/readme.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse
2 | easydict
3 | h5py
4 | matplotlib
5 | numpy
6 | open3d==0.9
7 | opencv-python
8 | pyyaml
9 | scipy
10 | tensorboardX
11 | timm==0.4.5
12 | tqdm
13 | transforms3d
14 | gcc,cuda
15 | igl
16 | trimesh tensorboard plyfile open3d scikit-image python-mnist opencv-python av pykdtree ipdb
17 | pip install https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
18 | pip install --no-index pytorch3d
19 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+${CUDA}.html
20 | pip install "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
21 | pip install --no-index torch torchvision torchaudio
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .runner import run_net
2 | from .runner import test_net
--------------------------------------------------------------------------------
/tools/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/tools/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/builder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/tools/__pycache__/builder.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/runner.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/tools/__pycache__/runner.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/builder.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | # online package
3 | import torch
4 | # optimizer
5 | import torch.optim as optim
6 | # dataloader
7 | from datasets import build_dataset_from_cfg
8 | from models import build_model_from_cfg
9 | # utils
10 | from utils.logger import *
11 | from utils.misc import *
12 | from collections import OrderedDict
13 |
def dataset_builder(args, config):
    """Build the dataset described by ``config`` and a DataLoader over it.

    The 'train' subset is shuffled, batched with ``config.others.bs`` and
    drops its last incomplete batch; every other subset is read in order with
    batch size 1. When ``args.distributed`` is set, shuffling is delegated to
    a ``DistributedSampler``.

    Returns:
        ``(sampler, dataloader)``; ``sampler`` is ``None`` when not
        distributed.
    """
    dataset = build_dataset_from_cfg(config._base_, config.others)
    is_train = config.others.subset == 'train'

    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train)
    else:
        sampler = None

    loader_kwargs = dict(
        batch_size=config.others.bs if is_train else 1,
        num_workers=int(args.num_workers),
        drop_last=is_train,
        worker_init_fn=worker_init_fn,
    )
    # The loader shuffles itself only when no sampler does it.
    if sampler is None:
        loader_kwargs['shuffle'] = is_train
    else:
        loader_kwargs['sampler'] = sampler

    dataloader = torch.utils.data.DataLoader(dataset, **loader_kwargs)
    return sampler, dataloader
32 |
def model_builder(config):
    """Return the model that ``build_model_from_cfg`` builds for ``config``."""
    return build_model_from_cfg(config)
36 |
def build_opti_sche(base_model, config):
    """Create the optimizer and learning-rate scheduler(s) for ``base_model``.

    Args:
        base_model: Module whose ``parameters()`` are optimized.
        config: Configuration with ``optimizer`` and ``scheduler`` entries
            (each exposing ``type`` and ``kwargs``) and an optional
            ``bnmscheduler`` entry read through ``config.get``.

    Returns:
        ``(optimizer, scheduler)``. When ``bnmscheduler`` is configured,
        ``scheduler`` is the list ``[lr_scheduler, bn_scheduler]``.

    Raises:
        NotImplementedError: If the optimizer, scheduler or bnmscheduler type
            is not one of the supported names.
    """
    opti_config = config.optimizer
    if opti_config.type == 'AdamW':
        optimizer = optim.AdamW(base_model.parameters(), **opti_config.kwargs)
    elif opti_config.type == 'Adam':
        optimizer = optim.Adam(base_model.parameters(), **opti_config.kwargs)
    elif opti_config.type == 'SGD':
        optimizer = optim.SGD(base_model.parameters(), nesterov=True, **opti_config.kwargs)
    else:
        raise NotImplementedError()

    sche_config = config.scheduler
    if sche_config.type == 'LambdaLR':
        scheduler = build_lambda_sche(optimizer, sche_config.kwargs)  # misc.py
    elif sche_config.type == 'StepLR':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **sche_config.kwargs)
    else:
        raise NotImplementedError()

    if config.get('bnmscheduler') is not None:
        bnsche_config = config.bnmscheduler
        if bnsche_config.type == 'Lambda':
            bnscheduler = build_lambda_bnsche(base_model, bnsche_config.kwargs)  # misc.py
        else:
            # Fail the same way as the optimizer/scheduler branches instead of
            # leaving ``bnscheduler`` unbuilt for an unknown type.
            raise NotImplementedError()
        scheduler = [scheduler, bnscheduler]

    return optimizer, scheduler
63 |
def resume_model(base_model, args, logger = None):
    """Restore model weights from ``<experiment_path>/ckpt-last.pth``.

    Args:
        base_model: module that receives the stored ``'base_model'`` weights.
        args: run arguments; ``experiment_path`` and ``local_rank`` are read.
        logger: forwarded to ``print_log``.

    Returns:
        tuple: ``(start_epoch, best_metrics)``; ``(0, 0)`` when the checkpoint
        file does not exist.
    """
    ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')
    if not os.path.exists(ckpt_path):
        print_log(f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger = logger)
        return 0, 0
    print_log(f'[RESUME INFO] Loading model weights from {ckpt_path}...', logger = logger )

    # Tensors stored for cuda:0 are remapped onto this process's device.
    device_map = {'cuda:%d' % 0: 'cuda:%d' % args.local_rank}
    state_dict = torch.load(ckpt_path, map_location=device_map)

    # Drop any "module." fragments from parameter names before loading.
    base_ckpt = {}
    for param_name, tensor in state_dict['base_model'].items():
        base_ckpt[param_name.replace("module.", "")] = tensor
    base_model.load_state_dict(base_ckpt)

    # Training resumes at the epoch following the stored one.
    start_epoch = state_dict['epoch'] + 1
    best_metrics = state_dict['best_metrics']
    if not isinstance(best_metrics, dict):
        best_metrics = best_metrics.state_dict()

    print_log(f'[RESUME INFO] resume ckpts @ {start_epoch - 1} epoch( best_metrics = {str(best_metrics):s})', logger = logger)
    return start_epoch, best_metrics
88 |
def resume_optimizer(optimizer, args, logger = None):
    """Restore optimizer state from ``<experiment_path>/ckpt-last.pth``.

    Args:
        optimizer: optimizer that receives the stored ``'optimizer'`` state.
        args: run arguments; ``experiment_path`` is read.
        logger: forwarded to ``print_log``.

    Returns:
        ``(0, 0, 0)`` when the checkpoint file does not exist, otherwise
        ``None`` (the optimizer is updated in place).
    """
    ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')
    checkpoint_missing = not os.path.exists(ckpt_path)
    if checkpoint_missing:
        print_log(f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger = logger)
        return 0, 0, 0
    print_log(f'[RESUME INFO] Loading optimizer from {ckpt_path}...', logger = logger )

    # The checkpoint is read onto the CPU.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    optimizer.load_state_dict(checkpoint['optimizer'])
99 |
def save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, prefix, args, logger = None):
    """Write a training checkpoint to ``<experiment_path>/<prefix>.pth``.

    Only the process with ``args.local_rank == 0`` writes; other processes
    return without doing anything.

    Args:
        base_model: model to save; in distributed mode the wrapped
            ``base_model.module`` is saved instead.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: stored under ``'epoch'``.
        metrics: stored under ``'metrics'`` (an empty dict when ``None``).
        best_metrics: stored under ``'best_metrics'`` (an empty dict when ``None``).
        prefix: checkpoint file name without the ``.pth`` suffix.
        args: run arguments; ``local_rank``, ``distributed`` and
            ``experiment_path`` are read.
        logger: forwarded to ``print_log``.
    """
    if args.local_rank != 0:
        return

    if args.distributed:
        model_state = base_model.module.state_dict()
    else:
        model_state = base_model.state_dict()

    payload = {
        'base_model' : model_state,
        'optimizer' : optimizer.state_dict(),
        'epoch' : epoch,
        'metrics' : metrics if metrics is not None else dict(),
        'best_metrics' : best_metrics if best_metrics is not None else dict(),
    }
    torch.save(payload, os.path.join(args.experiment_path, prefix + '.pth'))
    print_log(f"Save checkpoint at {os.path.join(args.experiment_path, prefix + '.pth')}", logger = logger)
110 |
def load_model(base_model, ckpt_path, logger = None):
    """Load model weights from ``ckpt_path`` into ``base_model``.

    The weights are taken from the checkpoint's ``'base_model'`` entry when
    present, otherwise from ``'model'``. Key prefixes are reconciled by
    ``load_model_manual``.

    Args:
        base_model: module that receives the weights.
        ckpt_path: path of the checkpoint file.
        logger: forwarded to ``print_log``.

    Raises:
        NotImplementedError: if ``ckpt_path`` does not exist (exception type
            kept for backward compatibility).
        RuntimeError: if the checkpoint has neither ``'base_model'`` nor
            ``'model'`` weights.
    """
    if not os.path.exists(ckpt_path):
        raise NotImplementedError('no checkpoint file from path %s...' % ckpt_path)
    print_log(f'Loading weights from {ckpt_path}...', logger = logger )

    # The checkpoint is read onto the CPU.
    state_dict = torch.load(ckpt_path, map_location='cpu')

    # Choose the weight dictionary before using it. The previous code indexed
    # 'base_model' unconditionally, so a checkpoint holding only 'model'
    # raised KeyError before the intended fallback could run.
    if state_dict.get('base_model') is not None:
        weights = state_dict['base_model']
    elif state_dict.get('model') is not None:
        weights = state_dict['model']
    else:
        raise RuntimeError('mismatch of ckpt weight')
    load_model_manual(weights, base_model)

    epoch = -1
    if state_dict.get('epoch') is not None:
        epoch = state_dict['epoch']
    if state_dict.get('metrics') is not None:
        metrics = state_dict['metrics']
    else:
        metrics = 'No Metrics'
    print_log(f'ckpts @ {epoch} epoch( performance = {str(metrics):s})', logger = logger)
    return
140 |
def load_model_manual(state_dict, model):
    """Load ``state_dict`` into ``model``, fixing the ``module.`` key prefix.

    Keys gain a ``module.`` prefix when ``model`` is a
    ``torch.nn.DataParallel`` and lose it otherwise, so checkpoints saved
    with or without the wrapper can be loaded either way. Loading is
    non-strict.

    Args:
        state_dict: mapping from parameter name to value.
        model: module that receives the weights.
    """
    prefix = 'module.'
    wants_prefix = isinstance(model, torch.nn.DataParallel)

    remapped = OrderedDict()
    for key, value in state_dict.items():
        has_prefix = key.startswith(prefix)
        if has_prefix and not wants_prefix:
            key = key[len(prefix):]
        elif wants_prefix and not has_prefix:
            key = prefix + key
        remapped[key] = value

    model.load_state_dict(remapped, strict=False)
--------------------------------------------------------------------------------
/utils/AverageMeter.py:
--------------------------------------------------------------------------------
1 |
class AverageMeter(object):
    """Track the latest value, running sum and update count of one or more series.

    With ``items=None`` the meter tracks a single series and the accessors
    return scalars. Otherwise it tracks ``len(items)`` series and the
    accessors return lists unless an index is given.
    """

    def __init__(self, items=None):
        """Create a meter with one slot per entry of ``items`` (one slot when ``None``)."""
        self.items = items
        self.n_items = 1 if items is None else len(items)
        self.reset()

    def reset(self):
        """Clear all latest values, sums and counts."""
        self._val = [0] * self.n_items
        self._sum = [0] * self.n_items
        self._count = [0] * self.n_items

    def update(self, values):
        """Record a new observation.

        Args:
            values: a list or tuple with one value per slot, or a single
                value that is recorded in slot 0.
        """
        # Tuples are accepted too; previously they fell into the scalar
        # branch and failed on ``0 + tuple``.
        if isinstance(values, (list, tuple)):
            for idx, v in enumerate(values):
                self._val[idx] = v
                self._sum[idx] += v
                self._count[idx] += 1
        else:
            self._val[0] = values
            self._sum[0] += values
            self._count[0] += 1

    def val(self, idx=None):
        """Return the latest value of slot ``idx`` (all slots when ``None``)."""
        if idx is None:
            return self._val[0] if self.items is None else list(self._val[:self.n_items])
        return self._val[idx]

    def count(self, idx=None):
        """Return the update count of slot ``idx`` (all slots when ``None``)."""
        if idx is None:
            return self._count[0] if self.items is None else list(self._count[:self.n_items])
        return self._count[idx]

    def _mean(self, idx):
        """Return the mean of slot ``idx``, or 0 if the slot was never updated."""
        n = self._count[idx]
        return self._sum[idx] / n if n else 0

    def avg(self, idx=None):
        """Return the mean of slot ``idx`` (all slots when ``None``).

        A slot with no updates yields 0 instead of raising ZeroDivisionError.
        """
        if idx is None:
            if self.items is None:
                return self._mean(0)
            return [self._mean(i) for i in range(self.n_items)]
        return self._mean(idx)
--------------------------------------------------------------------------------
/utils/__pycache__/AverageMeter.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/AverageMeter.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dist_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/dist_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/misc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/misc.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/parser.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/parser.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/registry.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Golriz-code/DMC/6957e96eb84057408b89f507666d6f644f6d8c43/utils/__pycache__/registry.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from easydict import EasyDict
3 | import os
4 | from .logger import print_log
5 |
def log_args_to_file(args, pre='args', logger=None):
    """Log every attribute of ``args`` as ``<pre>.<name> : <value>``."""
    for key, val in vars(args).items():
        print_log(f'{pre}.{key} : {val}', logger = logger)
9 |
def log_config_to_file(cfg, pre='cfg', logger=None):
    """Log a config tree, recursing into nested ``EasyDict`` sections.

    Each leaf is logged as ``<pre>.<key> : <value>``. A nested section is
    announced as ``<pre>.<key> = edict()`` and then logged with the prefix
    extended by its key.
    """
    for key, val in cfg.items():
        if isinstance(cfg[key], EasyDict):
            print_log(f'{pre}.{key} = edict()', logger = logger)
            log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger)
        else:
            print_log(f'{pre}.{key} : {val}', logger = logger)
17 |
def merge_new_config(config, new_config):
    """Recursively merge ``new_config`` into ``config`` and return ``config``.

    Non-dict values are copied as-is. Dict values are merged into a nested
    section, which is created as an ``EasyDict`` when missing. A non-dict
    ``_base_`` value is treated as the path of a YAML file whose contents are
    loaded and merged under ``config['_base_']``.

    Args:
        config: mapping that is updated in place.
        new_config: mapping with the values to merge in.

    Returns:
        The updated ``config``.
    """
    for key, val in new_config.items():
        if not isinstance(val, dict):
            if key != '_base_':
                config[key] = val
                continue
            # NOTE(review): FullLoader is not meant for untrusted YAML; use
            # yaml.safe_load if config files can come from outside.
            with open(new_config['_base_'], 'r') as f:
                try:
                    val = yaml.load(f, Loader=yaml.FullLoader)
                except AttributeError:
                    # Old PyYAML without FullLoader. Only this case is caught,
                    # so real YAML errors propagate instead of being hidden
                    # by a retry on an already consumed stream.
                    val = yaml.load(f)
            config[key] = EasyDict()
            merge_new_config(config[key], val)
            continue
        if key not in config:
            config[key] = EasyDict()
        merge_new_config(config[key], val)
    return config
36 |
def cfg_from_yaml_file(cfg_file):
    """Load ``cfg_file`` and return its contents merged into a new ``EasyDict``.

    Args:
        cfg_file: path of the YAML config file.

    Returns:
        EasyDict: the merged configuration (see ``merge_new_config``).
    """
    config = EasyDict()
    # NOTE(review): FullLoader is not meant for untrusted YAML; use
    # yaml.safe_load if config files can come from outside.
    with open(cfg_file, 'r') as f:
        try:
            new_config = yaml.load(f, Loader=yaml.FullLoader)
        except AttributeError:
            # Old PyYAML without FullLoader. Only this case is caught, so
            # real YAML errors propagate instead of being hidden by a retry
            # on an already consumed stream.
            new_config = yaml.load(f)
    merge_new_config(config=config, new_config=new_config)
    return config
46 |
def get_config(args, logger=None):
    """Load the run configuration, honouring ``args.resume``.

    When resuming, ``args.config`` is redirected to the ``config.yaml`` saved
    in ``args.experiment_path``. For a fresh run on local rank 0, the config
    file is copied into the experiment directory.

    Args:
        args: run arguments; ``resume``, ``experiment_path``, ``config`` and
            ``local_rank`` are read, and ``config`` may be overwritten.
        logger: forwarded to ``print_log``.

    Returns:
        The config produced by ``cfg_from_yaml_file``.

    Raises:
        FileNotFoundError: when resuming and the saved config does not exist.
    """
    if args.resume:
        saved_cfg = os.path.join(args.experiment_path, 'config.yaml')
        if not os.path.exists(saved_cfg):
            print_log("Failed to resume", logger = logger)
            raise FileNotFoundError()
        print_log(f'Resume yaml from {saved_cfg}', logger = logger)
        args.config = saved_cfg

    config = cfg_from_yaml_file(args.config)

    is_fresh_run = not args.resume
    if is_fresh_run and args.local_rank == 0:
        save_experiment_config(args, config, logger)
    return config
59 |
def load_config(path, default_path=None):
    ''' Loads a config file, resolving its inheritance chain.

    The base config is the file named by the ``inherit_from`` key (loaded
    recursively), else the file at ``default_path``, else an empty dict.
    The values of ``path`` are then applied on top of it.

    Args:
        path (str): path to config file
        default_path (str): path of the fallback config used when the file
            declares no ``inherit_from``; may be None

    Returns:
        dict: the merged configuration
    '''
    # NOTE(review): yaml.Loader can build arbitrary Python objects; only use
    # it on trusted config files.
    with open(path, 'r') as f:
        cfg_special = yaml.load(f, Loader=yaml.Loader)

    parent_path = cfg_special.get('inherit_from')

    if parent_path is None and default_path is None:
        cfg = dict()
    elif parent_path is None:
        with open(default_path, 'r') as f:
            cfg = yaml.load(f, Loader=yaml.Loader)
    else:
        cfg = load_config(parent_path, default_path)

    # Values from the file itself take precedence over the inherited ones.
    update_recursive(cfg, cfg_special)

    return cfg
88 |
def update_recursive(dict1, dict2):
    ''' Update one config dictionary with another, recursively.

    Nested dictionaries are merged key by key; every other value in
    ``dict2`` replaces the value in ``dict1``. A dict in ``dict2`` also
    replaces an existing non-dict value (e.g. ``None`` or a scalar) in
    ``dict1``, which previously raised an error.

    Args:
        dict1 (dict): dictionary to be updated in place
        dict2 (dict): dictionary whose entries should be used

    '''
    for k, v in dict2.items():
        if not isinstance(v, dict):
            dict1[k] = v
            continue
        target = dict1.get(k)
        if not isinstance(target, dict):
            # Missing key, or a non-dict placeholder: start from an empty
            # section so the nested values can be merged in.
            target = dict1[k] = dict()
        update_recursive(target, v)
104 |
def save_experiment_config(args, config, logger = None):
    """Copy the config file of this run into the experiment directory.

    The copy is best-effort: a failure is logged and does not raise.

    Args:
        args: run arguments; ``experiment_path`` and ``config`` are read.
        config: unused; kept for interface compatibility.
        logger: forwarded to ``print_log``.
    """
    # Imported locally so this block does not depend on a module-level import.
    import shutil

    config_path = os.path.join(args.experiment_path, 'config.yaml')
    try:
        # shutil.copy replaces a shell `cp`, which was non-portable and broke
        # on paths containing spaces or shell metacharacters.
        shutil.copy(args.config, config_path)
    except OSError as err:
        print_log(f'Failed to copy the Config file from {args.config} to {config_path}: {err}', logger = logger)
        return
    print_log(f'Copy the Config file from {args.config} to {config_path}',logger = logger )
--------------------------------------------------------------------------------
/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.multiprocessing as mp
5 | from torch import distributed as dist
6 |
7 |
8 |
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialise distributed training for the given launcher.

    The multiprocessing start method is set to ``'spawn'`` if none has been
    chosen yet.

    Args:
        launcher: job launcher name; only ``'pytorch'`` is supported.
        backend: backend passed to the process-group initialisation.
        **kwargs: forwarded to ``_init_dist_pytorch``.

    Raises:
        ValueError: if ``launcher`` is not ``'pytorch'``.
    """
    start_method = mp.get_start_method(allow_none=True)
    if start_method is None:
        mp.set_start_method('spawn')

    if launcher != 'pytorch':
        raise ValueError(f'Invalid launcher type: {launcher}')
    _init_dist_pytorch(backend, **kwargs)
16 |
17 |
def _init_dist_pytorch(backend, **kwargs):
    """Select this process's CUDA device and initialise the process group.

    The rank is read from the ``RANK`` environment variable and the device
    index is that rank modulo the number of visible CUDA devices.

    Args:
        backend: backend passed to ``dist.init_process_group``.
        **kwargs: forwarded to ``dist.init_process_group``.
    """
    # TODO: use local_rank instead of rank % num_gpus
    global_rank = int(os.environ['RANK'])
    device_index = global_rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
    dist.init_process_group(backend=backend, **kwargs)
    print(f'init distributed in rank {torch.distributed.get_rank()}')
25 |
26 |
def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when the distributed package is unavailable or
    no process group has been initialised.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
39 |
40 |
def reduce_tensor(tensor, args):
    '''
    Average ``tensor`` over all processes.

    A copy of ``tensor`` is summed across processes with ``all_reduce`` and
    divided by ``args.world_size``; the input tensor is not modified.
    '''
    averaged = tensor.clone()
    torch.distributed.all_reduce(averaged, op=torch.distributed.ReduceOp.SUM)
    averaged /= args.world_size
    return averaged
49 |
def gather_tensor(tensor, args):
    """Gather ``tensor`` from every process and concatenate along dim 0.

    Args:
        tensor: local tensor contributed by this process.
        args: run arguments; ``world_size`` is read.

    Returns:
        The concatenation of the ``args.world_size`` gathered tensors.
    """
    # One receive buffer per process, shaped like the local tensor.
    buffers = []
    for _ in range(args.world_size):
        buffers.append(tensor.clone())
    torch.distributed.all_gather(buffers, tensor)
    return torch.cat(buffers, dim=0)
55 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.distributed as dist
3 |
# Registry of logger names already configured by get_logger(); a name found
# here (or prefixed by a name found here) is returned without re-adding handlers.
logger_initialized = {}
5 |
def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
    """Get the project's root logger.

    The logger is initialised by ``get_logger`` if that has not happened yet.

    Args:
        log_file (str, optional): File path of log. Defaults to None.
        log_level (int, optional): The level of logger.
            Defaults to logging.INFO.
        name (str, optional): The name of the root logger.
            Defaults to 'main'.
    Returns:
        :obj:`logging.Logger`: The obtained logger
    """
    # The earlier version also built a keyword filter here but never attached
    # it to the logger, and its predicate called ``.find`` on a LogRecord,
    # which would have raised. That dead code is removed; behaviour is
    # unchanged.
    return get_logger(name=name, log_file=log_file, log_level=log_level)
27 |
28 |
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.

    A logger that is already initialized, or whose name starts with the name
    of an initialized logger, is returned unchanged. Otherwise a
    StreamHandler is added and, on rank 0 with ``log_file`` given, a
    FileHandler as well.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger (rank 0 only).
        log_level (int): Level for the handlers, and for the logger on
            rank 0. Other ranks set the logger level to ERROR.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)

    # Already configured directly, or covered by a configured logger whose
    # name is a prefix of this one (e.g. "a" covers "a.b").
    if name in logger_initialized:
        return logger
    if any(name.startswith(known) for known in logger_initialized):
        return logger

    # PyTorch DDP (>= 1.8.0) attaches a NOTSET StreamHandler to the root
    # logger; with propagation on, rank > 0 messages would reach the console.
    # Raise such root stream handlers to ERROR to avoid the duplicates.
    for root_handler in logger.root.handlers:
        if type(root_handler) is logging.StreamHandler:
            root_handler.setLevel(logging.ERROR)

    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    is_main_process = rank == 0

    handlers = [logging.StreamHandler()]
    if is_main_process and log_file is not None:
        # file_mode defaults to 'w' here, unlike logging.FileHandler's own
        # default of 'a'; callers can pass 'a' to append.
        handlers.append(logging.FileHandler(log_file, file_mode))

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level if is_main_process else logging.ERROR)

    logger_initialized[name] = True
    return logger
101 |
102 |
def print_log(msg, logger=None, level=logging.INFO):
    """Print a log message.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): The logger to be used.
            - None: the message is printed with ``print()``.
            - a ``logging.Logger``: the message is logged on it.
            - "silent": the message is discarded.
            - other str: the logger obtained with ``get_logger(logger)``.
        level (int): Logging level. Only used when a logger is involved.

    Raises:
        TypeError: if ``logger`` is none of the accepted kinds.
    """
    if logger is None:
        print(msg)
        return
    if isinstance(logger, logging.Logger):
        logger.log(level, msg)
        return
    if logger == 'silent':
        return
    if isinstance(logger, str):
        get_logger(logger).log(level, msg)
        return
    raise TypeError(
        'logger should be either a logging.Logger object, str, '
        f'"silent" or None, but got {type(logger)}')
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-08-08 14:31:30
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2020-05-25 09:13:32
6 | # @Email: cshzxie@gmail.com
7 |
8 | import logging
9 | import open3d
10 |
11 | from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2
12 |
13 |
class Metrics(object):
    """Evaluation metrics (F-Score, Chamfer L1/L2) and a container for their values.

    The class methods compute metric values for a prediction / ground-truth
    pair. An instance stores one value per enabled metric and can be compared
    with another instance on the metric named ``metric_name``.
    """

    # Metric registry. 'eval_func' names the class method that computes the
    # metric. NOTE: the chamfer helpers below fetch their 'eval_object' by
    # position (ITEMS[1], ITEMS[2]), so the order of entries matters.
    ITEMS = [{
        'name': 'F-Score',
        'enabled': True,
        'eval_func': 'cls._get_f_score',
        'is_greater_better': True,
        'init_value': 0
    }, {
        'name': 'CDL1',
        'enabled': True,
        'eval_func': 'cls._get_chamfer_distancel1',
        'eval_object': ChamferDistanceL1(ignore_zeros=True),
        'is_greater_better': False,
        'init_value': 32767
    }, {
        'name': 'CDL2',
        'enabled': True,
        'eval_func': 'cls._get_chamfer_distancel2',
        'eval_object': ChamferDistanceL2(ignore_zeros=True),
        'is_greater_better': False,
        'init_value': 32767
    }]

    @classmethod
    def get(cls, pred, gt):
        """Return the list of enabled metric values for ``pred`` against ``gt``."""
        _values = []
        for item in cls.items():
            # 'cls._get_xxx' -> the class method of that name, resolved with
            # getattr instead of eval().
            eval_func = getattr(cls, item['eval_func'].rpartition('.')[2])
            _values.append(eval_func(pred, gt))

        return _values

    @classmethod
    def items(cls):
        """Return the registry entries whose 'enabled' flag is set."""
        return [i for i in cls.ITEMS if i['enabled']]

    @classmethod
    def names(cls):
        """Return the names of the enabled metrics."""
        return [i['name'] for i in cls.items()]

    @classmethod
    def _get_f_score(cls, pred, gt, th=0.01):
        """Return the F-Score of ``pred`` against ``gt`` at distance threshold ``th``.

        For a batch (first dimension != 1) the per-sample scores are averaged.

        References: https://github.com/lmb-freiburg/what3d/blob/master/util.py
        """
        b = pred.size(0)
        assert pred.size(0) == gt.size(0)
        if b != 1:
            # Pass `th` on: the batched path previously dropped the caller's
            # threshold and always scored samples with the default.
            f_score_list = [
                cls._get_f_score(pred[idx:idx+1], gt[idx:idx+1], th)
                for idx in range(b)
            ]
            return sum(f_score_list)/len(f_score_list)
        else:
            pred = cls._get_open3d_ptcloud(pred)
            gt = cls._get_open3d_ptcloud(gt)

            dist1 = pred.compute_point_cloud_distance(gt)
            dist2 = gt.compute_point_cloud_distance(pred)

            recall = float(sum(d < th for d in dist2)) / float(len(dist2))
            precision = float(sum(d < th for d in dist1)) / float(len(dist1))
            return 2 * recall * precision / (recall + precision) if recall + precision else 0

    @classmethod
    def _get_open3d_ptcloud(cls, tensor):
        """Convert a batch-size-1 tensor into an open3d point cloud."""
        tensor = tensor.squeeze().cpu().numpy()
        ptcloud = open3d.geometry.PointCloud()
        ptcloud.points = open3d.utility.Vector3dVector(tensor)

        return ptcloud

    @classmethod
    def _get_chamfer_distancel1(cls, pred, gt):
        """Return the Chamfer L1 distance between ``pred`` and ``gt``, scaled by 1000."""
        chamfer_distance = cls.ITEMS[1]['eval_object']
        return chamfer_distance(pred, gt).item() * 1000

    @classmethod
    def _get_chamfer_distancel2(cls, pred, gt):
        """Return the Chamfer L2 distance between ``pred`` and ``gt``, scaled by 1000."""
        chamfer_distance = cls.ITEMS[2]['eval_object']
        return chamfer_distance(pred, gt).item() * 1000

    def __init__(self, metric_name, values):
        """Store metric values.

        Args:
            metric_name: name of the metric used by ``better_than``.
            values: a list with one value per enabled metric, or a dict
                mapping metric names to values (unknown names are ignored
                with a warning; missing ones keep their 'init_value').

        Raises:
            Exception: if ``values`` is neither a list nor a dict.
        """
        self._items = Metrics.items()
        self._values = [item['init_value'] for item in self._items]
        self.metric_name = metric_name

        if isinstance(values, list):
            self._values = values
        elif isinstance(values, dict):
            metric_indexes = {
                item['name']: idx for idx, item in enumerate(self._items)
            }
            for k, v in values.items():
                if k not in metric_indexes:
                    # logging.warn is a deprecated alias that was removed in
                    # Python 3.13.
                    logging.warning('Ignore Metric[Name=%s] due to disability.' % k)
                    continue
                self._values[metric_indexes[k]] = v
        else:
            raise Exception('Unsupported value type: %s' % type(values))

    def state_dict(self):
        """Return a dict mapping each enabled metric name to its stored value."""
        _dict = dict()
        for i, item in enumerate(self._items):
            _dict[item['name']] = self._values[i]

        return _dict

    def __repr__(self):
        return str(self.state_dict())

    def better_than(self, other):
        """Return True if this result beats ``other`` on ``self.metric_name``.

        ``other=None`` always compares as worse. The direction of the
        comparison follows the metric's 'is_greater_better' flag.

        Raises:
            Exception: if ``self.metric_name`` is not an enabled metric.
        """
        if other is None:
            return True

        _index = -1
        for i, _item in enumerate(self._items):
            if _item['name'] == self.metric_name:
                _index = i
                break
        if _index == -1:
            raise Exception('Invalid metric name to compare.')

        # Index with _index rather than the leaked loop variable.
        _metric = self._items[_index]
        _value = self._values[_index]
        other_value = other._values[_index]
        return _value > other_value if _metric['is_greater_better'] else _value < other_value
145 |
--------------------------------------------------------------------------------
/utils/parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from pathlib import Path
4 |
def get_args():
    """Parse the command line and derive the experiment paths.

    Besides parsing, this validates flag combinations, exports
    ``LOCAL_RANK`` to the environment when it is not already set, derives
    ``experiment_path``, ``tfboard_path`` and ``log_name`` from the config
    path and experiment name, and creates the experiment directories.

    Returns:
        argparse.Namespace: the parsed and augmented arguments.

    Raises:
        ValueError: on incompatible flag combinations.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type = str,
                        default='cfgs/Tooth_models/PoinTr.yaml',
                        help = 'yaml config file')
    parser.add_argument('--config_SAP', type=str,
                        default='SAP/configs/learning_based/noise_small/ours.yaml',
                        help='yaml config file')
    parser.add_argument('--launcher', choices=['none', 'pytorch'],
                        default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=4)
    # seed
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--deterministic', action='store_true',
                        help='whether to set deterministic options for CUDNN backend.')
    # bn
    parser.add_argument('--sync_bn', action='store_true', default=False,
                        help='whether to use sync bn')
    # some args
    parser.add_argument('--exp_name', type = str, default='default', help = 'experiment name')
    parser.add_argument('--start_ckpts', type = str, default=None, help = 'reload used ckpt path')
    parser.add_argument('--ckpts', type = str, default=None, help = 'test used ckpt path')
    parser.add_argument('--val_freq', type = int, default=1, help = 'test freq')
    parser.add_argument('--resume', action='store_true', default=False,
                        help = 'autoresume training (interrupted by accident)')
    parser.add_argument('--test', action='store_true', default=False,
                        help = 'test mode for certain ckpt')
    parser.add_argument('--mode', choices=['easy', 'median', 'hard', None],
                        default=None, help = 'difficulty mode for shapenet')
    args = parser.parse_args()

    # Reject flag combinations that contradict each other.
    if args.test and args.resume:
        raise ValueError(
            '--test and --resume cannot be both activate')
    if args.resume and args.start_ckpts is not None:
        raise ValueError(
            '--resume and --start_ckpts cannot be both activate')
    if args.test and args.ckpts is None:
        raise ValueError(
            'ckpts shouldnt be None while test mode')

    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.test:
        args.exp_name = 'test_' + args.exp_name
    if args.mode is not None:
        args.exp_name = args.exp_name + '_' + args.mode

    # Outputs live under ./experiments/<config stem>/<config folder stem>/...
    config_file = Path(args.config)
    base_parts = ('./experiments', config_file.stem, config_file.parent.stem)
    args.experiment_path = os.path.join(*base_parts, args.exp_name)
    args.tfboard_path = os.path.join(*base_parts, 'TFBoard', args.exp_name)
    args.log_name = config_file.stem
    create_experiment_dir(args)
    return args
82 |
def create_experiment_dir(args):
    """Create the experiment and TFBoard directories if they do not exist.

    Args:
        args: run arguments; ``experiment_path`` and ``tfboard_path`` are read.
    """
    targets = (
        ('experiment', args.experiment_path),
        ('TFBoard', args.tfboard_path),
    )
    for label, path in targets:
        if not os.path.exists(path):
            # exist_ok avoids FileExistsError when several processes (e.g.
            # one per rank) pass the check above at the same time.
            os.makedirs(path, exist_ok=True)
            print('Create %s path successfully at %s' % (label, path))
90 |
91 |
--------------------------------------------------------------------------------