├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── S3DIS.md
├── S3DIS_fix.diff
├── Semantic3D.md
├── learning
│   ├── __init__.py
│   ├── custom_dataset.py
│   ├── ecc
│   │   ├── GraphConvInfo.py
│   │   ├── GraphConvModule.py
│   │   ├── GraphPoolInfo.py
│   │   ├── GraphPoolModule.py
│   │   ├── __init__.py
│   │   ├── cuda_kernels.py
│   │   ├── test_GraphConvModule.py
│   │   ├── test_GraphPoolModule.py
│   │   └── utils.py
│   ├── evaluate.py
│   ├── graphnet.py
│   ├── main.py
│   ├── metrics.py
│   ├── modules.py
│   ├── pointnet.py
│   ├── s3dis_dataset.py
│   ├── sema3d_dataset.py
│   ├── spg.py
│   └── vkitti_dataset.py
├── partition
│   ├── __init__.py
│   ├── graphs.py
│   ├── partition.py
│   ├── ply_c
│   │   ├── CMakeLists.txt
│   │   ├── FindNumPy.cmake
│   │   ├── __init__.py
│   │   ├── connected_components.cpp
│   │   ├── ply_c.cpp
│   │   └── random_subgraph.cpp
│   ├── provider.py
│   ├── visualize.py
│   └── write_Semantic3d.py
├── supervized_partition
│   ├── __init__.py
│   ├── evaluate_partition.py
│   ├── folderhierarchy.py
│   ├── generate_partition.py
│   ├── graph_processing.py
│   ├── losses.py
│   └── supervized_partition.py
└── vKITTI3D.md
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *DS_Store*
3 | *.vscode
4 | *.so
5 | partition/ply_c/build
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "partition/cut-pursuit"]
2 | path = partition/cut-pursuit
3 | url = ../cut-pursuit.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Loic Landrieu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 |
4 | ## ⚠️ This repo is no longer maintained! Please check out our brand new [*SuperPoint Transformer*](https://github.com/drprojects/superpoint_transformer), which does everything better! ⚠️
5 |
6 |
7 |
8 | This is the official PyTorch implementation of the papers:
9 |
10 | [*Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs*](http://arxiv.org/abs/1711.09869)
11 |
12 | by Loic Landrieu and Martin Simonovsky (CVPR 2018),
13 |
14 | and
15 |
16 | [*Point Cloud Oversegmentation with Graph-Structured Deep Metric Learning*](https://arxiv.org/pdf/1904.02113)
17 |
18 | by Loic Landrieu and Mohamed Boussaha (CVPR 2019).
19 |
20 |
21 |
22 |
23 |
24 | ## Code structure
25 | * `./partition/*` - Partition code (geometric partitioning and superpoint graph construction using handcrafted features)
26 | * `./supervized_partition/*` - Supervised partition code (partitioning with learned features)
27 | * `./learning/*` - Learning code (superpoint embedding and contextual segmentation).
28 |
29 | For the stable branch containing only SPG, switch to the [release](https://github.com/loicland/superpoint_graph/tree/release) branch.
30 |
31 | ## Disclaimer
32 | Our partition method is inherently stochastic. Hence, even if we provide the trained weights, it is possible that the results that you obtain differ slightly from the ones presented in the paper.
33 |
34 | ## Requirements
35 | *0.* Download the current version of the repository. We recommend using the `--recurse-submodules` option to make sure the [cut pursuit](https://github.com/loicland/cut-pursuit) module used in `/partition` is downloaded in the process. If you did not use the following command, refer to point 4:
36 | ```
37 | git clone --recurse-submodules https://github.com/loicland/superpoint_graph
38 | ```
39 |
40 | *1.* Install [PyTorch](https://pytorch.org) and [torchnet](https://github.com/pytorch/tnt).
41 | ```
42 | pip install git+https://github.com/pytorch/tnt.git@master
43 | ```
44 |
45 | *2.* Install additional Python packages:
46 | ```
47 | pip install future igraph tqdm transforms3d pynvrtc fastrlock cupy h5py scikit-learn plyfile scipy pandas
48 | ```
49 |
50 | *3.* Install Boost (1.63.0 or newer) and Eigen3, for example with Conda:
51 | ```
52 | conda install -c anaconda boost; conda install -c omnia eigen3; conda install eigen; conda install -c r libiconv
53 | ```
54 |
55 | *4.* Make sure that cut pursuit was downloaded. Otherwise, clone [this repository](https://github.com/loicland/cut-pursuit) or add it as a submodule in `/partition`:
56 | ```
57 | cd partition
58 | git submodule init
59 | git submodule update --remote cut-pursuit
60 | ```
61 |
62 | *5.* Compile the ```libply_c``` and ```libcp``` libraries:
63 | ```
64 | CONDAENV=YOUR_CONDA_ENVIRONMENT_LOCATION
65 | cd partition/ply_c
66 | cmake . -DPYTHON_LIBRARY=$CONDAENV/lib/libpython3.6m.so -DPYTHON_INCLUDE_DIR=$CONDAENV/include/python3.6m -DBOOST_INCLUDEDIR=$CONDAENV/include -DEIGEN3_INCLUDE_DIR=$CONDAENV/include/eigen3
67 | make
68 | cd ..
69 | cd cut-pursuit
70 | mkdir build
71 | cd build
72 | cmake .. -DPYTHON_LIBRARY=$CONDAENV/lib/libpython3.6m.so -DPYTHON_INCLUDE_DIR=$CONDAENV/include/python3.6m -DBOOST_INCLUDEDIR=$CONDAENV/include -DEIGEN3_INCLUDE_DIR=$CONDAENV/include/eigen3
73 | make
74 | ```
75 | *6.* (optional) Install [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric)
76 |
77 | The code was tested on Ubuntu 14 and 16 with Python 3.5 to 3.8 and PyTorch 0.2 to 1.3.
78 |
79 | ### Troubleshooting
80 |
81 | Common sources of errors and how to fix them:
82 | - $CONDAENV is not defined: define it, or replace $CONDAENV with the absolute path of your conda environment (find it with ```locate anaconda```)
83 | - your conda environment uses a Python version other than 3.6m: adapt the version in the cmake commands. Find which version of Python conda is using with ```locate anaconda3/lib/libpython```
84 | - you are using Boost 1.62 or older: update it
85 | - cut pursuit did not download: manually clone it into the ```partition``` folder or add it as a submodule as described in the requirements, point 4.
86 | - error in make: `'numpy/ndarrayobject.h' file not found`: create a symbolic link to the NumPy headers in your Python site-packages with `sudo ln -s $CONDAENV/lib/python3.7/site-packages/numpy/core/include/numpy $CONDAENV/include/numpy` (adapt `python3.7` to your Python version)
87 |
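If you are unsure which paths to pass to `-DPYTHON_LIBRARY` and `-DPYTHON_INCLUDE_DIR`, or where the NumPy headers live, the active interpreter can report them. This is a minimal sketch using only the standard `sysconfig` module (plus NumPy if it is installed); the helper name `cmake_python_hints` is hypothetical, and on some builds `LDLIBRARY` names a static `.a` library, so check that the printed file exists before using it.

```python
import os
import sysconfig


def cmake_python_hints():
    """Collect the paths the cmake commands above expect (hypothetical helper)."""
    libdir = sysconfig.get_config_var("LIBDIR") or ""
    ldlibrary = sysconfig.get_config_var("LDLIBRARY") or ""
    hints = {
        "PYTHON_LIBRARY": os.path.join(libdir, ldlibrary),
        "PYTHON_INCLUDE_DIR": sysconfig.get_paths()["include"],
    }
    try:
        import numpy  # only needed for the 'numpy/ndarrayobject.h' error
        # NumPy headers sit in <get_include()>/numpy; use that as the symlink source
        hints["NUMPY_INCLUDE_DIR"] = numpy.get_include()
    except ImportError:
        hints["NUMPY_INCLUDE_DIR"] = None
    return hints


if __name__ == "__main__":
    for key, value in cmake_python_hints().items():
        print(f"{key}: {value}")
```

Run it with the same `python` as your conda environment so the reported paths match the interpreter you compile against.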
88 |
89 | ## Running the code
90 |
91 | To run our code or retrain from scratch on different datasets, see the corresponding readme files.
92 | Currently supported datasets are as follows:
93 |
94 | | Dataset | handcrafted partition | learned partition |
95 | | ---------- | --------------------- | ------------------|
96 | | S3DIS | yes | yes |
97 | | Semantic3D | yes | to come soon |
98 | | vKITTI3D | no | yes |
99 | | ScanNet | to come soon | to come soon |
100 |
101 | To use PyTorch Geometric graph convolutions instead of our own, use the option `--use_pyg 1` in `./learning/main.py`; their code is more stable and just as fast. Otherwise, use `--use_pyg 0`.
102 |
103 | #### Evaluation
104 |
105 | To quantitatively evaluate a trained model (S3DIS and vKITTI3D only), use:
106 | ```
107 | python learning/evaluate.py --dataset s3dis --odir results/s3dis/best --cvfold 123456
108 | ```
109 |
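With `--cvfold 123456` the scores cover all folds together. One common way to do this is to sum the per-fold confusion matrices before computing overall accuracy and mean IoU. The NumPy sketch below shows that pooling; it is not a copy of `learning/evaluate.py`, and the function name and the convention (rows = ground truth, columns = prediction) are assumptions.

```python
import numpy as np


def pooled_scores(confusion_matrices):
    """Sum per-fold confusion matrices (rows: ground truth, cols: prediction)
    and return (overall accuracy, mean IoU over classes that occur)."""
    cm = np.sum(np.stack(confusion_matrices), axis=0).astype(np.float64)
    true_pos = np.diag(cm)
    overall_acc = true_pos.sum() / cm.sum()
    # IoU_c = TP_c / (TP_c + FP_c + FN_c) = TP_c / (col_c + row_c - TP_c)
    union = cm.sum(axis=0) + cm.sum(axis=1) - true_pos
    valid = union > 0  # skip classes absent from both labels and predictions
    miou = np.mean(true_pos[valid] / union[valid])
    return overall_acc, miou


if __name__ == "__main__":
    fold1 = np.array([[5, 1], [0, 4]])
    fold2 = np.array([[3, 0], [2, 5]])
    print(pooled_scores([fold1, fold2]))
```

Pooling the matrices weights each fold by its number of points, which differs from averaging the per-fold mIoU values.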
110 | To visualize the results and all intermediate steps, use the visualize function in partition (for S3DIS, vKITTI3D, and Semantic3D). For example:
111 | ```
112 | python partition/visualize.py --dataset s3dis --ROOT_PATH $S3DIS_DIR --res_file results/s3dis/pretrained/cv1/predictions_test --file_path Area_1/conferenceRoom_1 --output_type igfpres
113 | ```
114 |
115 | ```output_type``` is defined as follows:
116 | - ```'i'``` = input rgb point cloud
117 | - ```'g'``` = ground truth (if available), with the predefined class to color mapping
118 | - ```'f'``` = geometric feature with color code: red = linearity, green = planarity, blue = verticality
119 | - ```'p'``` = partition, with a random color for each superpoint
120 | - ```'r'``` = result cloud, with the predefined class to color mapping
121 | - ```'e'``` = error cloud, with green/red hue for correct/faulty prediction
122 | - ```'s'``` = superedge structure of the superpoint (toggle wireframe on meshlab to view it)
123 |
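The example commands combine several letters in one string (e.g. `igfpres`, `ifprs`). The hypothetical helper below is not part of `partition/visualize.py`; it only decodes such a string against the list above, which can catch a typo before a long visualization run.

```python
OUTPUT_TYPES = {
    "i": "input rgb point cloud",
    "g": "ground truth",
    "f": "geometric features",
    "p": "partition",
    "r": "result (predictions)",
    "e": "error cloud",
    "s": "superedge structure",
}


def describe_output_type(output_type):
    """Map each letter of an --output_type string to its meaning; reject unknown letters."""
    unknown = sorted(set(output_type) - set(OUTPUT_TYPES))
    if unknown:
        raise ValueError(f"unknown output_type letters: {unknown}")
    return [OUTPUT_TYPES[letter] for letter in output_type]


if __name__ == "__main__":
    print(describe_output_type("igfpres"))
```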
124 | Add the option ```--upsample 1``` if you want the prediction file to be on the original, unpruned data (this can take a long time).
125 |
126 | # Other data sets
127 |
128 | You can apply SPG on your own data set with minimal changes:
129 | - adapt references to ```custom_dataset``` in ```/partition/partition.py```
130 | - you will need to create the function ```read_custom_format``` in ```/partition/provider.py``` which outputs xyz and rgb values, as well as semantic labels if available (already implemented for ply and las files)
131 | - adapt the template file ```/learning/custom_dataset.py``` to your architecture and design choices
132 | - adapt references to ```custom_dataset``` in ```/learning/main.py```
133 | - add your data set colormap to ```get_color_from_label``` in ```/partition/provider.py```
134 | - adapt line 212 of `learning/spg.py` to reflect the missing or extra point features
135 | - change ```--model_config``` to ```gru_10,f_K``` with ```K``` as the number of classes in your dataset, or ```gru_10_0,f_K``` to use matrix edge filters instead of vectors (only use matrices when your data set is quite large, and with many different point clouds, like S3DIS).
136 |
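As a starting point for `read_custom_format`, here is a minimal sketch assuming a whitespace-separated ASCII file with one point per line: `x y z r g b` followed by an optional integer label. The file layout, the signature, the dtypes, and the return convention (`labels` is `None` when absent) are all assumptions; match them to what `partition/partition.py` expects from the existing ply and las readers.

```python
import numpy as np


def read_custom_format(filename):
    """Hypothetical reader for 'x y z r g b [label]' ASCII point clouds.

    Returns xyz (float32, N x 3), rgb (uint8, N x 3) and labels (uint8, N) or None.
    """
    data = np.loadtxt(filename, ndmin=2)  # ndmin=2 keeps single-line files 2D
    if data.shape[1] not in (6, 7):
        raise ValueError(f"expected 6 or 7 columns, got {data.shape[1]}")
    xyz = np.ascontiguousarray(data[:, 0:3], dtype=np.float32)
    rgb = np.ascontiguousarray(data[:, 3:6], dtype=np.uint8)
    labels = data[:, 6].astype(np.uint8) if data.shape[1] == 7 else None
    return xyz, rgb, labels
```

For large clouds, a faster parser such as `pandas.read_csv(..., sep=r"\s+", header=None)` can replace `np.loadtxt` without changing the rest of the function.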
137 | # Datasets without RGB
138 | If your data does not have RGB values, you can still use SPG. You will need to follow the instructions in ```partition/partition.py``` regarding the pruning.
139 | You will need to adapt the ```/learning/custom_dataset.py``` file so that it does not refer to RGB values.
140 | You should absolutely not use a model pretrained on data with RGB values. Instead, retrain a model from scratch using the ```--pc_attribs xyzelpsv``` option to remove RGB from the shape embedding input.
141 |
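Each letter of `--pc_attribs` selects one per-point input feature, and the template `get_info` in `learning/custom_dataset.py` sets `node_feats` to `len(args.pc_attribs)`. The hypothetical helper below (not in the repository) strips the RGB letters from an attribute string and reports the resulting feature count, reproducing the switch from `xyzrgbelpsv` to `xyzelpsv` described above.

```python
RGB_LETTERS = set("rgb")


def drop_rgb(pc_attribs):
    """Remove the r, g, b letters from a --pc_attribs string and return
    the new string together with the number of point features it implies."""
    kept = "".join(letter for letter in pc_attribs if letter not in RGB_LETTERS)
    return kept, len(kept)


if __name__ == "__main__":
    print(drop_rgb("xyzrgbelpsv"))  # ('xyzelpsv', 8)
```

Uppercase letters such as the `XYZ` used in the S3DIS commands are left untouched, since only lowercase `r`, `g`, `b` are removed.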
142 | # Citation
143 | If you use the semantic segmentation module (code in `/learning`), please cite:
144 | *Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs*, Loic Landrieu and Martin Simonovsky, CVPR, 2018.
145 |
146 | If you use the learned partition module (code in `/supervized_partition`), please cite:
147 | *Point Cloud Oversegmentation with Graph-Structured Deep Metric Learning*, Loic Landrieu and Mohamed Boussaha, CVPR, 2019.
148 |
149 | To refer to the handcrafted partition (code in `/partition`) step specifically, refer to:
150 | *Weakly Supervised Segmentation-Aided Classification of Urban Scenes from 3D LiDAR Point Clouds*, Stéphane Guinard and Loic Landrieu. ISPRS Workshop, 2017.
151 |
152 | To refer to the L0-cut pursuit algorithm (code in `github.com/loicland/cut-pursuit`) specifically, refer to:
153 | *Cut Pursuit: Fast Algorithms to Learn Piecewise Constant Functions on General Weighted Graphs*, Loic Landrieu and Guillaume Obozinski, SIAM Journal on Imaging Sciences, 2017.
154 |
155 | To cite the PyTorch Geometric implementation, see the BibTeX entry in [their repo](https://github.com/rusty1s/pytorch_geometric).
156 |
157 |
158 |
--------------------------------------------------------------------------------
/S3DIS.md:
--------------------------------------------------------------------------------
1 | # S3DIS
2 |
3 | Download [S3DIS Dataset](http://buildingparser.stanford.edu/dataset.html) and extract `Stanford3dDataset_v1.2_Aligned_Version.zip` to `$S3DIS_DIR/data`, where `$S3DIS_DIR` is the dataset directory.
4 |
5 | To fix some issues with the dataset reported in issue [#29](https://github.com/loicland/superpoint_graph/issues/29), apply the patch `S3DIS_fix.diff` with:
6 | ```
7 | cp S3DIS_fix.diff $S3DIS_DIR/data; cd $S3DIS_DIR/data; git apply S3DIS_fix.diff; rm S3DIS_fix.diff; cd -
8 | ```
9 | Define $S3DIS_DIR as the location of the folder containing `/data`
10 |
11 | ## SPG with Handcrafted Partition
12 |
13 | To compute the partition with handcrafted features run:
14 | ```
15 | python partition/partition.py --dataset s3dis --ROOT_PATH $S3DIS_DIR --voxel_width 0.03 --reg_strength 0.03
16 | ```
17 |
18 | Then, reorganize point clouds into superpoints by:
19 | ```
20 | python learning/s3dis_dataset.py --S3DIS_PATH $S3DIS_DIR
21 | ```
22 |
23 | To train from scratch on all 6 folds with the handcrafted partition, run:
24 | ```
25 | for FOLD in 1 2 3 4 5 6; do \
26 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset s3dis --S3DIS_PATH $S3DIS_DIR --cvfold $FOLD --epochs 350 \
27 | --lr_steps '[275,320]' --test_nth_epoch 50 --model_config 'gru_10_0,f_13' --ptn_nfeat_stn 14 \
28 | --pc_attribs xyzrgbelpsvXYZ --odir "results/s3dis/best/cv${FOLD}" --nworkers 4; \
29 | done
30 | ```
31 |
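The loop above trains one model per cross-validation fold. Assuming the standard S3DIS protocol, in which fold k holds out `Area_k` for testing and trains on the other five areas (check `learning/s3dis_dataset.py` for the exact split used here), the split can be sketched as follows; the function name is hypothetical.

```python
def s3dis_fold_split(cvfold, n_areas=6):
    """Return (train_areas, test_area) for a 1-based cross-validation fold."""
    if not 1 <= cvfold <= n_areas:
        raise ValueError(f"cvfold must be in [1, {n_areas}], got {cvfold}")
    train_areas = [f"Area_{k}" for k in range(1, n_areas + 1) if k != cvfold]
    return train_areas, f"Area_{cvfold}"


if __name__ == "__main__":
    for fold in range(1, 7):
        print(fold, s3dis_fold_split(fold))
```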
32 | Our trained networks can be downloaded [here](http://imagine.enpc.fr/~simonovm/largescale/models_s3dis.zip). Unzip the archive (but not the `model.pth.tar` files themselves) and place its contents in the code folder `results/s3dis/pretrained/`.
33 |
34 | To test these networks on the full test set, run:
35 | ```
36 | for FOLD in 1 2 3 4 5 6; do \
37 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset s3dis --S3DIS_PATH $S3DIS_DIR --cvfold $FOLD --epochs -1 --lr_steps '[275,320]' \
38 | --test_nth_epoch 50 --model_config 'gru_10_0,f_13' --ptn_nfeat_stn 14 --nworkers 2 --pc_attribs xyzrgbelpsvXYZ --odir "results/s3dis/pretrained/cv${FOLD}" --resume RESUME; \
39 | done
40 | ```
41 |
42 | ## SSP+SPG: SPG with learned partition
43 |
44 | To learn the partition from scratch run:
45 | ```
46 | python supervized_partition/graph_processing.py --ROOT_PATH $S3DIS_DIR --dataset s3dis --voxel_width 0.03; \
47 |
48 | for FOLD in 1 2 3 4 5 6; do \
49 | python ./supervized_partition/supervized_partition.py --ROOT_PATH $S3DIS_DIR --cvfold $FOLD \
50 | --odir results_partition/s3dis/best --epochs 50 --reg_strength 0.1 --spatial_emb 0.2 \
51 | --global_feat eXYrgb --CP_cutoff 25; \
52 | done
53 | ```
54 | Alternatively, download our trained weights [here](http://recherche.ign.fr/llandrieu/SPG/S3DIS/pretrained.zip), unzip them into the folder `results_partition/s3dis/pretrained`, and run the following code:
55 |
56 | ```
57 | for FOLD in 1 2 3 4 5 6; do \
58 | python ./supervized_partition/supervized_partition.py --ROOT_PATH $S3DIS_DIR --cvfold $FOLD --epochs -1 \
59 | --odir results_partition/s3dis/pretrained --reg_strength 0.1 --spatial_emb 0.2 --global_feat eXYrgb \
60 | --CP_cutoff 25 --resume RESUME; \
61 | done
62 | ```
63 |
64 | To evaluate the quality of the partition, run:
65 | ```
66 | python supervized_partition/evaluate_partition.py --dataset s3dis --folder pretrained --cvfold 123456
67 | ```
68 |
69 | Then, reorganize point clouds into superpoints with:
70 | ```
71 | python learning/s3dis_dataset.py --S3DIS_PATH $S3DIS_DIR --supervized_partition 1 --plane_model_elevation 1
72 | ```
73 |
74 | Then to learn the SPG models from scratch, run:
75 | ```
76 | for FOLD in 1 2 3 4 5 6; do \
77 | CUDA_VISIBLE_DEVICES=0 python ./learning/main.py --dataset s3dis --S3DIS_PATH $S3DIS_DIR --batch_size 5 \
78 | --cvfold $FOLD --epochs 250 --lr_steps '[150,200]' --model_config "gru_10_0,f_13" --ptn_nfeat_stn 10 \
79 | --nworkers 2 --spg_augm_order 5 --pc_attribs xyzXYZrgbe --spg_augm_hardcutoff 768 --ptn_minpts 50 \
80 | --use_val_set 1 --odir results/s3dis/best/cv$FOLD; \
81 | done;
82 | ```
83 |
84 | Or use our [trained weights](http://recherche.ign.fr/llandrieu/SPG/S3DIS/pretrained_SSP.zip) with `--epochs -1` and `--resume RESUME`:
85 | ```
86 | for FOLD in 1 2 3 4 5 6; do \
87 | CUDA_VISIBLE_DEVICES=0 python ./learning/main.py --dataset s3dis --S3DIS_PATH $S3DIS_DIR --batch_size 5 \
88 | --cvfold $FOLD --epochs -1 --lr_steps '[150,200]' --model_config "gru_10_0,f_13" --ptn_nfeat_stn 10 \
89 | --nworkers 2 --spg_augm_order 5 --pc_attribs xyzXYZrgbe --spg_augm_hardcutoff 768 --ptn_minpts 50 \
90 | --use_val_set 1 --odir results/s3dis/pretrained_SSP/cv$FOLD --resume RESUME; \
91 | done;
92 | ```
93 | Note that these weights are specifically adapted to the pretrained model for the learned partition. Any change to the partition might decrease their performance.
94 |
--------------------------------------------------------------------------------
/S3DIS_fix.diff:
--------------------------------------------------------------------------------
1 | diff --git a/Area_3/hallway_2/hallway_2.txt b/Area_3/hallway_2/hallway_2.txt
2 | index 02f32b8..870566e 100644
3 | --- a/Area_3/hallway_2/hallway_2.txt
4 | +++ b/Area_3/hallway_2/hallway_2.txt
5 | @@ -926334,7 +926334,7 @@
6 | 19.237 -9.161 1.561 141 131 96
7 | 19.248 -9.160 1.768 136 129 103
8 | 19.276 -9.160 1.684 139 130 99
9 | -19.302 -9.10 1.785 146 137 106
10 | +19.302 -9.1 0 1.785 146 137 106
11 | 19.242 -9.160 1.790 146 134 108
12 | 19.271 -9.160 1.679 140 129 99
13 | 19.278 -9.160 1.761 133 123 98
14 | diff --git a/Area_5/hallway_6/Annotations/ceiling_1.txt b/Area_5/hallway_6/Annotations/ceiling_1.txt
15 | index 62e563d..3a9087b 100644
16 | --- a/Area_5/hallway_6/Annotations/ceiling_1.txt
17 | +++ b/Area_5/hallway_6/Annotations/ceiling_1.txt
18 | @@ -180386,7 +180386,7 @@
19 | 22.383 6.858 3.050 155 155 165
20 | 22.275 6.643 3.048 192 194 191
21 | 22.359 6.835 3.050 152 152 162
22 | -22.350 6.692 3.048 185 187 182
23 | +22.350 6.692 3.048 185 187 182
24 | 22.314 6.638 3.048 170 171 175
25 | 22.481 6.818 3.049 149 149 159
26 | 22.328 6.673 3.048 190 195 191
27 | diff --git a/Area_6/copyRoom_1/copy_Room_1.txt b/Area_6/copyRoom_1/copyRoom_1.txt
28 | similarity index 100%
29 | rename from Area_6/copyRoom_1/copy_Room_1.txt
30 | rename to Area_6/copyRoom_1/copyRoom_1.txt
31 |
--------------------------------------------------------------------------------
/Semantic3D.md:
--------------------------------------------------------------------------------
1 | # Semantic3D
2 |
3 | Download all point clouds and labels from [Semantic3D Dataset](http://www.semantic3d.net/) and place the extracted training files in `$SEMA3D_DIR/data/train`, the reduced test files in `$SEMA3D_DIR/data/test_reduced`, and the full test files in `$SEMA3D_DIR/data/test_full`, where `$SEMA3D_DIR` is the dataset directory. The label files of the training set must be placed in the same directory as the corresponding .txt files.
4 |
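A quick way to check this layout before running the partition is sketched below. The helper name is hypothetical, and it only assumes what the paragraph above states: three folders under `$SEMA3D_DIR/data`, plus one `.labels` file next to each training `.txt` file with the same base name.

```python
from pathlib import Path


def check_sema3d_layout(sema3d_dir):
    """Return a list of problems found in $SEMA3D_DIR/data (empty if none)."""
    data = Path(sema3d_dir) / "data"
    problems = []
    for split in ("train", "test_reduced", "test_full"):
        if not (data / split).is_dir():
            problems.append(f"missing folder: {data / split}")
    train_dir = data / "train"
    if train_dir.is_dir():
        # every training cloud needs a matching .labels file in the same folder
        for cloud in sorted(train_dir.glob("*.txt")):
            if not cloud.with_suffix(".labels").exists():
                problems.append(f"no labels for {cloud.name}")
    return problems
```

For example, `print(check_sema3d_layout("/path/to/sema3d"))` should print an empty list once the files are in place.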
5 | ## Handcrafted Partition
6 |
7 | To compute the partition with handcrafted features run:
8 | ```
9 | python partition/partition.py --dataset sema3d --ROOT_PATH $SEMA3D_DIR --voxel_width 0.05 --reg_strength 0.8 --ver_batch 5000000
10 | ```
11 | It is recommended that you have at least 24GB of RAM to run this code. Otherwise, increase the ```voxel_width``` parameter to increase pruning.
12 |
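`voxel_width` sets how aggressively the cloud is subsampled before partitioning: points falling in the same cubic cell of side `voxel_width` are merged, so a larger width keeps fewer points and uses less memory. The NumPy sketch below illustrates this voxel-grid averaging on coordinates only. It is not the repository's implementation, which also handles colors and labels, and the function name is hypothetical.

```python
import numpy as np


def voxel_grid_average(xyz, voxel_width):
    """Merge points that share a voxel of side `voxel_width` into their centroid."""
    voxel_idx = np.floor(xyz / voxel_width).astype(np.int64)
    # inverse maps each point to the index of its (unique) voxel
    _, inverse = np.unique(voxel_idx, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)  # inverse shape differs across some NumPy versions
    n_voxels = inverse.max() + 1
    sums = np.zeros((n_voxels, xyz.shape[1]), dtype=np.float64)
    np.add.at(sums, inverse, xyz)  # unbuffered per-voxel sum of coordinates
    counts = np.bincount(inverse, minlength=n_voxels)
    return sums / counts[:, None]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    cloud = rng.uniform(0.0, 1.0, size=(10000, 3))
    for width in (0.05, 0.1, 0.2):
        print(width, len(voxel_grid_average(cloud, width)))
```

Doubling the width divides the number of occupied voxels by up to eight on a dense cloud, which is why a larger `voxel_width` lowers the memory requirement.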
13 | Then, reorganize point clouds into superpoints by:
14 | ```
15 | python learning/sema3d_dataset.py --SEMA3D_PATH $SEMA3D_DIR
16 | ```
17 |
18 | To train on the whole publicly available data and test on the reduced test set, run:
19 | ```
20 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset sema3d --SEMA3D_PATH $SEMA3D_DIR --db_test_name testred --db_train_name trainval \
21 | --epochs 500 --lr_steps '[350, 400, 450]' --test_nth_epoch 100 --model_config 'gru_10,f_8' --ptn_nfeat_stn 11 \
22 | --nworkers 2 --pc_attrib xyzrgbelpsv --odir "results/sema3d/trainval_best"
23 | ```
24 | The trained network can be downloaded [here](http://imagine.enpc.fr/~simonovm/largescale/model_sema3d_trainval.pth.tar) and loaded with the `--resume` argument. Rename the file to ```model.pth.tar``` (do not try to unzip it!) and place it in the directory ```results/sema3d/trainval_best```.
25 |
26 | To test this network on the full test set, run:
27 | ```
28 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset sema3d --SEMA3D_PATH $SEMA3D_DIR --db_test_name testfull --db_train_name trainval \
29 | --epochs -1 --lr_steps '[350, 400, 450]' --test_nth_epoch 100 --model_config 'gru_10,f_8' --ptn_nfeat_stn 11 \
30 | --nworkers 2 --pc_attrib xyzrgbelpsv --odir "results/sema3d/trainval_best" --resume RESUME
31 | ```
32 | We validated our configuration on a custom split of 11 and 4 clouds. The network is trained as such:
33 | ```
34 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset sema3d --SEMA3D_PATH $SEMA3D_DIR --epochs 450 --lr_steps '[350, 400]' --test_nth_epoch 100 \
35 | --model_config 'gru_10,f_8' --pc_attrib xyzrgbelpsv --ptn_nfeat_stn 11 --nworkers 2 --odir "results/sema3d/best"
36 | ```
37 |
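`--lr_steps` lists the epochs at which the learning rate is decayed. A step schedule of this kind multiplies the rate by a fixed factor at each listed epoch; the sketch below computes the resulting rate per epoch, assuming the decay takes effect at the listed epoch as with PyTorch's `MultiStepLR`. The base rate and the factor `gamma` are placeholders, not the defaults of `learning/main.py` (check `python learning/main.py --help` for those).

```python
import bisect


def lr_at_epoch(epoch, base_lr, lr_steps, gamma):
    """Learning rate after multiplying base_lr by gamma at each milestone in lr_steps."""
    n_decays = bisect.bisect_right(sorted(lr_steps), epoch)
    return base_lr * gamma ** n_decays


if __name__ == "__main__":
    steps = [350, 400]  # as in --lr_steps '[350, 400]'
    for epoch in (0, 349, 350, 399, 400, 449):
        print(epoch, lr_at_epoch(epoch, base_lr=1e-2, lr_steps=steps, gamma=0.7))
```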
38 | #### Learned Partition
39 |
40 | Not yet available.
41 |
42 | #### Visualization
43 |
44 | To upsample the prediction to the unpruned data and write the .labels files for the reduced test set, run (quite slow):
45 | ```
46 | python partition/write_Semantic3d.py --SEMA3D_PATH $SEMA3D_DIR --odir "results/sema3d/trainval_best" --db_test_name testred
47 | ```
48 |
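Upsampling transfers the labels predicted on the pruned cloud back to every original point. One simple way to do this is to give each original point the label of its nearest pruned point, sketched here with SciPy's `cKDTree` (SciPy is among the listed requirements). `partition/write_Semantic3d.py` may use a different interpolation, and the function name is hypothetical. Querying in chunks keeps memory bounded on very large clouds.

```python
import numpy as np
from scipy.spatial import cKDTree


def upsample_labels(pruned_xyz, pruned_labels, full_xyz, chunk_size=1_000_000):
    """Assign each point of full_xyz the label of its nearest neighbor in pruned_xyz."""
    tree = cKDTree(pruned_xyz)
    full_labels = np.empty(len(full_xyz), dtype=pruned_labels.dtype)
    for start in range(0, len(full_xyz), chunk_size):
        stop = start + chunk_size
        # k=1 returns one neighbor index per query point
        _, nn_idx = tree.query(full_xyz[start:stop], k=1)
        full_labels[start:stop] = pruned_labels[nn_idx]
    return full_labels
```

The result can then be written with one integer label per line, e.g. `np.savetxt("cloud.labels", labels, fmt="%d")`.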
49 | To visualize the results and intermediary steps (on the subsampled graph), use the visualize function in partition. For example:
50 | ```
51 | python partition/visualize.py --dataset sema3d --ROOT_PATH $SEMA3D_DIR --res_file 'results/sema3d/trainval_best/prediction_testred' --file_path 'test_reduced/MarketplaceFeldkirch_Station4' --output_type ifprs
52 | ```
53 | Avoid ```--upsample 1```, as it can take a very long time on the largest clouds.
54 |
--------------------------------------------------------------------------------
/learning/__init__.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 |
3 | DIR_PATH = os.path.dirname(os.path.realpath(__file__))
4 | sys.path.insert(0, DIR_PATH)
--------------------------------------------------------------------------------
/learning/custom_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Mar 20 16:16:14 2018
5 |
6 | @author: landrieuloic
7 |
8 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
9 | http://arxiv.org/abs/1711.09869
10 | 2017 Loic Landrieu, Martin Simonovsky
11 | Template file for processing custom datasets
12 | """
13 | from __future__ import division
14 | from __future__ import print_function
15 | from builtins import range
16 |
17 | import random
18 | import numpy as np
19 | import os
20 | import functools
21 | import torch
22 | import torchnet as tnt
23 | import h5py
24 | import spg
25 |
26 |
27 | def get_datasets(args, test_seed_offset=0):
28 | """build training and testing set"""
29 |
30 | #for a simple train/test organization
31 | trainset = ['train/' + f for f in os.listdir(args.CUSTOM_SET_PATH + '/superpoint_graphs/train') if f.endswith('.h5')]
32 | testset = ['test/' + f for f in os.listdir(args.CUSTOM_SET_PATH + '/superpoint_graphs/test') if f.endswith('.h5')]
33 |
34 | # Load superpoint graphs (file names already include the .h5 extension)
35 | testlist, trainlist = [], []
36 | for n in trainset:
37 | trainlist.append(spg.spg_reader(args, args.CUSTOM_SET_PATH + '/superpoint_graphs/' + n, True))
38 | for n in testset:
39 | testlist.append(spg.spg_reader(args, args.CUSTOM_SET_PATH + '/superpoint_graphs/' + n, True))
40 |
41 | scaler = None # edge features are normalized below only if spg_attribs01 is set
42 | if args.spg_attribs01:
43 | trainlist, testlist, validlist, scaler = spg.scaler01(trainlist, testlist)
44 |
45 | return tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in trainlist],
46 | functools.partial(spg.loader, train=True, args=args, db_path=args.CUSTOM_SET_PATH)), \
47 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in testlist],
48 | functools.partial(spg.loader, train=False, args=args, db_path=args.CUSTOM_SET_PATH, test_seed_offset=test_seed_offset)) ,\
49 | scaler
50 |
51 | def get_info(args):
52 | edge_feats = 0
53 | for attrib in args.edge_attribs.split(','):
54 | a = attrib.split('/')[0]
55 | if a in ['delta_avg', 'delta_std', 'xyz']:
56 | edge_feats += 3
57 | else:
58 | edge_feats += 1
59 |
60 | return {
61 | 'node_feats': 11 if args.pc_attribs=='' else len(args.pc_attribs),
62 | 'edge_feats': edge_feats,
63 | 'classes': 10, #CHANGE TO YOUR NUMBER OF CLASSES
64 | 'inv_class_map': {0:'class_A', 1:'class_B'}, #etc...
65 | }
66 |
67 | def preprocess_pointclouds(CUSTOM_SET_PATH):
68 | """ Preprocesses data by splitting them by components and normalizing."""
69 |
70 | for n in ['train', 'test']:
71 | pathP = '{}/parsed/{}/'.format(CUSTOM_SET_PATH, n)
72 | pathD = '{}/features/{}/'.format(CUSTOM_SET_PATH, n)
73 | pathC = '{}/superpoint_graphs/{}/'.format(CUSTOM_SET_PATH, n)
74 | if not os.path.exists(pathP):
75 | os.makedirs(pathP)
76 | random.seed(0)
77 |
78 | for file in os.listdir(pathC):
79 | print(file)
80 | if file.endswith(".h5"):
81 | f = h5py.File(pathD + file, 'r')
82 | xyz = f['xyz'][:]
83 | rgb = f['rgb'][:].astype(np.float64) # np.float was removed in NumPy 1.24
84 | elpsv = np.stack([ f['xyz'][:,2][:], f['linearity'][:], f['planarity'][:], f['scattering'][:], f['verticality'][:] ], axis=1)
85 |
86 | # rescale to [-0.5,0.5]; keep xyz
87 | #warning - to use the trained model, make sure the elevation is comparable
88 | #to the set they were trained on
89 | #i.e. ~0 for roads and ~0.2-0.3 for buildings for sema3d
90 | # and -0.5 for floor and 0.5 for ceiling for s3dis
91 | elpsv[:,0] /= 100 # (rough guess) #adapt
92 | elpsv[:,1:] -= 0.5
93 | rgb = rgb/255.0 - 0.5
94 |
95 | P = np.concatenate([xyz, rgb, elpsv], axis=1)
96 |
97 | f = h5py.File(pathC + file, 'r')
98 | numc = len(f['components'].keys())
99 |
100 | with h5py.File(pathP + file, 'w') as hf:
101 | for c in range(numc):
102 | idx = f['components/{:d}'.format(c)][:].flatten()
103 | if idx.size > 10000: # trim extra large segments, just for speed-up of loading time
104 | ii = random.sample(range(idx.size), k=10000)
105 | idx = idx[ii]
106 |
107 | hf.create_dataset(name='{:d}'.format(c), data=P[idx,...])
108 |
109 | if __name__ == "__main__":
110 | import argparse
111 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
112 | parser.add_argument('--CUSTOM_SET_PATH', default='datasets/custom_set')
113 | args = parser.parse_args()
114 | preprocess_pointclouds(args.CUSTOM_SET_PATH)
115 |
116 |
117 |
--------------------------------------------------------------------------------
/learning/ecc/GraphConvInfo.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import igraph
12 | import torch
13 | from collections import defaultdict
14 | import numpy as np
15 |
16 | class GraphConvInfo(object):
17 | """ Holds information about the structure of graph(s) in a vectorized form useful to `GraphConvModule`.
18 |
19 | We assume that the node feature tensor (given to `GraphConvModule` as input) is ordered by igraph vertex id, e.g. the fifth row corresponds to vertex with id=4. Batch processing is realized by concatenating all graphs into a large graph of disconnected components (and all node feature tensors into a large tensor).
20 |
21 | The class requires a problem-specific `edge_feat_func` function, which receives a dict of edge attributes and returns a Tensor of edge features and a LongTensor of inverse indices if edge compaction was performed (fewer unique edge features than edges, so some may be reused).
22 | """
23 |
24 | def __init__(self, *args, **kwargs):
25 | self._idxn = None #indices into input tensor of convolution (node features)
26 | self._idxe = None #indices into edge features tensor (or None if it would be linear, i.e. no compaction)
27 | self._degrees = None #in-degrees of output nodes (slices _idxn and _idxe)
28 | self._degrees_gpu = None
29 | self._edgefeats = None #edge features tensor (to be processed by feature-generating network)
30 | if len(args)>0 or len(kwargs)>0:
31 | self.set_batch(*args, **kwargs)
32 |
33 | def set_batch(self, graphs, edge_feat_func):
34 | """ Creates a representation of a given batch of graphs.
35 |
36 | Parameters:
37 | graphs: single graph or a list/tuple of graphs.
38 | edge_feat_func: see class description.
39 | """
40 |
41 | graphs = graphs if isinstance(graphs,(list,tuple)) else [graphs]
42 | p = 0
43 | idxn = []
44 | degrees = []
45 | edge_indexes = []
46 | edgeattrs = defaultdict(list)
47 |
48 | for G in graphs:
49 | E = np.array(G.get_edgelist())
50 | idx = E[:,1].argsort() # sort by target
51 |
52 | idxn.append(p + E[idx,0])
53 | edgeseq = G.es[idx.tolist()]
54 | for a in G.es.attributes():
55 | edgeattrs[a] += edgeseq.get_attribute_values(a)
56 | degrees += G.indegree(G.vs, loops=True)
57 | edge_indexes.append(np.asarray(p + E[idx]))
58 | p += G.vcount()
59 |
60 | self._edgefeats, self._idxe = edge_feat_func(edgeattrs)
61 |
62 | self._idxn = torch.LongTensor(np.concatenate(idxn))
63 | if self._idxe is not None:
64 | assert self._idxe.numel() == self._idxn.numel()
65 |
66 | self._degrees = torch.LongTensor(degrees)
67 | self._degrees_gpu = None
68 |
69 | self._edge_indexes = torch.LongTensor(np.concatenate(edge_indexes).T)
70 |
71 | def cuda(self):
72 | self._idxn = self._idxn.cuda()
73 | if self._idxe is not None: self._idxe = self._idxe.cuda()
74 | self._degrees_gpu = self._degrees.cuda()
75 | self._edgefeats = self._edgefeats.cuda()
76 | self._edge_indexes = self._edge_indexes.cuda()
77 |
78 | def get_buffers(self):
79 | """ Provides data to `GraphConvModule`.
80 | """
81 | return self._idxn, self._idxe, self._degrees, self._degrees_gpu, self._edgefeats
82 |
83 | def get_pyg_buffers(self):
84 | """ Provides the edge index tensor (2 x num_edges: source row, then target row) used by PyTorch Geometric convolutions.
85 | """
86 | return self._edge_indexes
--------------------------------------------------------------------------------
/learning/ecc/GraphConvModule.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable, Function
14 | from .GraphConvInfo import GraphConvInfo
15 | from . import cuda_kernels
16 | from . import utils
17 |
18 |
19 | class GraphConvFunction(Function):
20 | """Computes operations for each edge and averages the results over respective nodes.
21 | The operation is either matrix-vector multiplication (for 3D weight tensors) or element-wise
22 | vector-vector multiplication (for 2D weight tensors). The evaluation is computed in blocks of
23 | size `edge_mem_limit` to reduce peak memory load. See `GraphConvInfo` for info on `idxn, idxe, degs`.
24 | """
25 | def init(self, in_channels, out_channels, idxn, idxe, degs, degs_gpu, edge_mem_limit=1e20):
26 | self._in_channels = in_channels
27 | self._out_channels = out_channels
28 | self._idxn = idxn
29 | self._idxe = idxe
30 | self._degs = degs
31 | self._degs_gpu = degs_gpu
32 | self._shards = utils.get_edge_shards(degs, edge_mem_limit)
33 |
34 | def _multiply(ctx, a, b, out, f_a=None, f_b=None):
35 | """Performs operation on edge weights and node signal"""
36 | if ctx._full_weight_mat:
37 | # weights are full in_channels x out_channels matrices -> mm
38 | torch.bmm(f_a(a) if f_a else a, f_b(b) if f_b else b, out=out)
39 | else:
40 | # weights represent diagonal matrices -> mul
41 | torch.mul(a, b.expand_as(a), out=out)
42 |
43 | @staticmethod
44 | def forward(ctx, input, weights, in_channels, out_channels, idxn, idxe, degs, degs_gpu, edge_mem_limit=1e20):
45 |
46 | ctx.save_for_backward(input, weights)
47 | ctx._in_channels = in_channels
48 | ctx._out_channels = out_channels
49 | ctx._idxn = idxn
50 | ctx._idxe = idxe
51 | ctx._degs = degs
52 | ctx._degs_gpu = degs_gpu
53 | ctx._shards = utils.get_edge_shards(degs, edge_mem_limit)
54 |
55 | ctx._full_weight_mat = weights.dim() == 3
56 | assert ctx._full_weight_mat or (
57 | in_channels == out_channels and weights.size(1) == in_channels)
58 |
59 | output = input.new(degs.numel(), out_channels)
60 |
61 | # loop over blocks of output nodes
62 | startd, starte = 0, 0
63 | for numd, nume in ctx._shards:
64 |
65 | # select sequence of matching pairs of node and edge weights
66 | sel_input = torch.index_select(input, 0, idxn.narrow(0, starte, nume))
67 |
68 | if ctx._idxe is not None:
69 | sel_weights = torch.index_select(weights, 0, idxe.narrow(0, starte, nume))
70 | else:
71 | sel_weights = weights.narrow(0, starte, nume)
72 |
73 | # compute matrix-vector products
74 | products = input.new()
75 | GraphConvFunction._multiply(ctx, sel_input, sel_weights, products, lambda a: a.unsqueeze(1))
76 |
77 | # average over nodes
78 | if ctx._idxn.is_cuda:
79 | cuda_kernels.conv_aggregate_fw(output.narrow(0, startd, numd), products.view(-1, ctx._out_channels),
80 | ctx._degs_gpu.narrow(0, startd, numd))
81 | else:
82 | k = 0
83 | for i in range(startd, startd + numd):
84 | if ctx._degs[i] > 0:
85 | torch.mean(products.narrow(0, k, ctx._degs[i]), 0, out=output[i])
86 | else:
87 | output[i].fill_(0)
88 | k = k + ctx._degs[i]
89 |
90 | startd += numd
91 | starte += nume
92 | del sel_input, sel_weights, products
93 |
94 | return output
95 |
96 | @staticmethod
97 | def backward(ctx, grad_output):
98 | input, weights = ctx.saved_tensors
99 |
100 | grad_input = input.new(input.size()).fill_(0)
101 | grad_weights = weights.new(weights.size())
102 | if ctx._idxe is not None: grad_weights.fill_(0)
103 |
104 | # loop over blocks of output nodes
105 | startd, starte = 0, 0
106 | for numd, nume in ctx._shards:
107 |
108 | grad_products, tmp = input.new(nume, ctx._out_channels), input.new()
109 |
110 | if ctx._idxn.is_cuda:
111 | cuda_kernels.conv_aggregate_bw(grad_products, grad_output.narrow(0, startd, numd),
112 | ctx._degs_gpu.narrow(0, startd, numd))
113 | else:
114 | k = 0
115 | for i in range(startd, startd + numd):
116 | if ctx._degs[i] > 0:
117 | torch.div(grad_output[i], ctx._degs[i], out=grad_products[k])
118 | if ctx._degs[i] > 1:
119 | grad_products.narrow(0, k + 1, ctx._degs[i] - 1).copy_(
120 | grad_products[k].expand(ctx._degs[i] - 1, 1, ctx._out_channels).squeeze(1))
121 | k = k + ctx._degs[i]
122 |
123 | # grad wrt weights
124 | sel_input = torch.index_select(input, 0, ctx._idxn.narrow(0, starte, nume))
125 |
126 | if ctx._idxe is not None:
127 | GraphConvFunction._multiply(ctx, sel_input, grad_products, tmp,
128 | lambda a: a.unsqueeze(1).transpose_(2, 1),
129 | lambda b: b.unsqueeze(1))
130 | grad_weights.index_add_(0, ctx._idxe.narrow(0, starte, nume), tmp)
131 | else:
132 | GraphConvFunction._multiply(ctx, sel_input, grad_products, grad_weights.narrow(0, starte, nume),
133 | lambda a: a.unsqueeze(1).transpose_(2, 1), lambda b: b.unsqueeze(1))
134 |
135 | # grad wrt input
136 | if ctx._idxe is not None:
137 | torch.index_select(weights, 0, ctx._idxe.narrow(0, starte, nume), out=tmp)
138 | GraphConvFunction._multiply(ctx, grad_products, tmp, sel_input, lambda a: a.unsqueeze(1),
139 | lambda b: b.transpose_(2, 1))
140 | del tmp
141 | else:
142 | GraphConvFunction._multiply(ctx, grad_products, weights.narrow(0, starte, nume), sel_input,
143 | lambda a: a.unsqueeze(1),
144 | lambda b: b.transpose_(2, 1))
145 |
146 | grad_input.index_add_(0, ctx._idxn.narrow(0, starte, nume), sel_input)
147 |
148 | startd += numd
149 | starte += nume
150 | del grad_products, sel_input
151 |
152 | return grad_input, grad_weights, None, None, None, None, None, None, None
153 |
154 |
155 |
156 | class GraphConvModule(nn.Module):
157 | """ Computes graph convolution using filter weights obtained from a filter generating network (`filter_net`).
158 | The input should be a 2D tensor of size (# nodes, `in_channels`). Multiple graphs can be concatenated in the same tensor (minibatch).
159 |
160 | Parameters:
161 | in_channels: number of input channels
162 | out_channels: number of output channels
163 | filter_net: filter-generating network transforming a 2D tensor (# edges, # edge features) to (# edges, in_channels*out_channels) or (# edges, in_channels)
164 | gc_info: GraphConvInfo object containing graph(s) structure information, can be also set with `set_info()` method.
165 | edge_mem_limit: block size (number of evaluated edges in parallel) for convolution evaluation, a low value reduces peak memory.
166 | """
167 |
168 | def __init__(self, in_channels, out_channels, filter_net, gc_info=None, edge_mem_limit=1e20):
169 | super(GraphConvModule, self).__init__()
170 |
171 | self._in_channels = in_channels
172 | self._out_channels = out_channels
173 | self._fnet = filter_net
174 | self._edge_mem_limit = edge_mem_limit
175 |
176 | self.set_info(gc_info)
177 |
178 | def set_info(self, gc_info):
179 | self._gci = gc_info
180 |
181 | def forward(self, input):
182 | # get graph structure information tensors
183 | idxn, idxe, degs, degs_gpu, edgefeats = self._gci.get_buffers()
184 | edgefeats = Variable(edgefeats, requires_grad=False)
185 |
186 | # evaluate and reshape filter weights
187 | weights = self._fnet(edgefeats)
188 | assert input.dim()==2 and weights.dim()==2 and (weights.size(1) == self._in_channels*self._out_channels or
189 | (self._in_channels == self._out_channels and weights.size(1) == self._in_channels))
190 | if weights.size(1) == self._in_channels*self._out_channels:
191 | weights = weights.view(-1, self._in_channels, self._out_channels)
192 |
193 | return GraphConvFunction.apply(input, weights, self._in_channels, self._out_channels, idxn, idxe, degs, degs_gpu, self._edge_mem_limit)
194 |
195 |
196 |
197 |
198 |
199 |
200 | class GraphConvModulePureAutograd(nn.Module):
201 | """
202 | Autograd-only equivalent of `GraphConvModule` + `GraphConvFunction`. Unfortunately, autograd needs to store intermediate products, which makes the module work only for very small graphs. The module is kept for didactic purposes only.
203 | """
204 |
205 | def __init__(self, in_channels, out_channels, filter_net, gc_info=None):
206 | super(GraphConvModulePureAutograd, self).__init__()
207 |
208 | self._in_channels = in_channels
209 | self._out_channels = out_channels
210 | self._fnet = filter_net
211 |
212 | self.set_info(gc_info)
213 |
214 | def set_info(self, gc_info):
215 | self._gci = gc_info
216 |
217 | def forward(self, input):
218 | # get graph structure information tensors
219 | idxn, idxe, degs, _, edgefeats = self._gci.get_buffers() # degs_gpu is not needed here
220 | idxn = Variable(idxn, requires_grad=False)
221 | edgefeats = Variable(edgefeats, requires_grad=False)
222 |
223 | # evaluate and reshape filter weights
224 | weights = self._fnet(edgefeats)
225 | assert input.dim()==2 and weights.dim()==2 and weights.size(1) == self._in_channels*self._out_channels
226 | weights = weights.view(-1, self._in_channels, self._out_channels)
227 |
228 | # select sequence of matching pairs of node and edge weights
229 | if idxe is not None:
230 | idxe = Variable(idxe, requires_grad=False)
231 | weights = torch.index_select(weights, 0, idxe)
232 |
233 | sel_input = torch.index_select(input, 0, idxn)
234 |
235 | # compute matrix-vector products
236 | products = torch.bmm(sel_input.view(-1,1,self._in_channels), weights)
237 |
238 | output = Variable(input.data.new(len(degs), self._out_channels))
239 |
240 | # average over nodes
241 | k = 0
242 | for i in range(len(degs)):
243 | if degs[i]>0:
244 | output.index_copy_(0, Variable(torch.Tensor([i]).type_as(idxn.data)), torch.mean(products.narrow(0,k,degs[i]), 0).view(1,-1))
245 | else:
246 | output.index_fill_(0, Variable(torch.Tensor([i]).type_as(idxn.data)), 0)
247 | k = k + degs[i]
248 |
249 | return output
250 |
251 |
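As a cross-check of what `GraphConvFunction` computes, here is a minimal NumPy sketch of the same forward and backward passes. It assumes the case without `idxe` (one full `in_channels x out_channels` weight matrix per edge) and plain CPU arrays; `ecc_forward` and `ecc_backward` are hypothetical helper names, not part of the module. Like the CPU branch of `forward`, it averages the per-edge products over each output node's consecutive block of `degs[i]` edges and writes zeros for nodes without edges.

```python
import numpy as np


def ecc_forward(x, weights, idxn, degs):
    """Mean-aggregated edge-conditioned convolution (sketch, no idxe).

    x:       (num_nodes, c_in) node features
    weights: (num_edges, c_in, c_out) one filter matrix per edge
    idxn:    (num_edges,) input node of each edge, edges grouped by output node
    degs:    (num_out,) number of consecutive edges belonging to each output node
    """
    products = np.einsum('ei,eio->eo', x[idxn], weights)  # x_src @ W_e for every edge
    out = np.zeros((len(degs), weights.shape[2]))
    k = 0
    for i, d in enumerate(degs):
        if d > 0:
            out[i] = products[k:k + d].mean(axis=0)
        k += d  # nodes without edges keep a zero row
    return out


def ecc_backward(x, weights, idxn, degs, grad_out):
    """Gradients of sum(ecc_forward(...) * grad_out) w.r.t. x and weights."""
    # every edge receives its output node's gradient divided by the node's degree
    per_node = grad_out / np.maximum(degs, 1)[:, None]
    grad_products = np.repeat(per_node, degs, axis=0)
    # weight gradient: outer product of the gathered input row and the edge gradient
    grad_w = np.einsum('ei,eo->eio', x[idxn], grad_products)
    # input gradient: W_e @ edge gradient, scattered back to the input nodes
    grad_x = np.zeros_like(x)
    np.add.at(grad_x, idxn, np.einsum('eio,eo->ei', weights, grad_products))
    return grad_x, grad_w


rng = np.random.default_rng(0)
degs = np.array([2, 0, 3, 1])           # 4 output nodes, 6 edges in total
x = rng.standard_normal((5, 3))          # 5 input nodes, 3 channels
w = rng.standard_normal((degs.sum(), 3, 4))
idxn = rng.integers(0, 5, size=degs.sum())
y = ecc_forward(x, w, idxn, degs)
print(y.shape)  # (4, 4); row 1 is all zeros because output node 1 has no edges
```

The backward mirrors the CPU branch of `GraphConvFunction.backward`: each edge gets its output node's gradient divided by the degree, the weight gradient is the outer product with the gathered input row, and the input gradient is scattered back with an index-add (`np.add.at` here, `index_add_` in the module).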
--------------------------------------------------------------------------------
/learning/ecc/GraphPoolInfo.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import torch
12 |
13 |
14 | class GraphPoolInfo(object):
15 | """ Holds information about pooling in a vectorized form useful to `GraphPoolModule`.
16 |
17 | We assume that the node feature tensor (given to `GraphPoolModule` as input) is ordered by igraph vertex id, e.g. the fifth row corresponds to vertex with id=4. Batch processing is realized by concatenating all graphs into a large graph of disconnected components (and all node feature tensors into a large tensor).
18 | """
19 |
20 | def __init__(self, *args, **kwargs):
21 | self._idxn = None #indices into input tensor of pooling (node features)
22 | self._degrees = None #in-degrees of output nodes (slices _idxn)
23 | self._degrees_gpu = None
24 | if len(args)>0 or len(kwargs)>0:
25 | self.set_batch(*args, **kwargs)
26 |
27 | def set_batch(self, poolmaps, graphs_from, graphs_to):
28 | """ Creates a representation of a given batch of graph poolings.
29 |
30 | Parameters:
31 | poolmaps: dict(s) mapping vertex id in coarsened graph to a list of vertex ids in input graph (defines pooling)
32 | graphs_from: input graph(s)
33 | graphs_to: coarsened graph(s)
34 | """
35 |
36 | poolmaps = poolmaps if isinstance(poolmaps,(list,tuple)) else [poolmaps]
37 | graphs_from = graphs_from if isinstance(graphs_from,(list,tuple)) else [graphs_from]
38 | graphs_to = graphs_to if isinstance(graphs_to,(list,tuple)) else [graphs_to]
39 |
40 | idxn = []
41 | degrees = []
42 | p = 0
43 |
44 | for map, G_from, G_to in zip(poolmaps, graphs_from, graphs_to):
45 | for v in range(G_to.vcount()):
46 | nlist = map.get(v, [])
47 | idxn.extend([n+p for n in nlist])
48 | degrees.append(len(nlist))
49 | p += G_from.vcount()
50 |
51 | self._idxn = torch.LongTensor(idxn)
52 | self._degrees = torch.LongTensor(degrees)
53 | self._degrees_gpu = None
54 |
55 | def cuda(self):
56 | self._idxn = self._idxn.cuda()
57 | self._degrees_gpu = self._degrees.cuda()
58 |
59 | def get_buffers(self):
60 | """ Provides data to `GraphPoolModule`.
61 | """
62 | return self._idxn, self._degrees, self._degrees_gpu
63 |
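The offset bookkeeping in `set_batch` is easiest to see on a tiny batch. The sketch below restates it with plain vertex counts in place of igraph objects (an assumption made so the example runs without igraph); `flatten_poolmaps` is a hypothetical name. Each graph's vertex ids are shifted by the number of fine vertices in the graphs before it, so they index rows of the concatenated feature tensor.

```python
def flatten_poolmaps(poolmaps, n_from, n_to):
    """Hypothetical stand-alone version of the indexing in GraphPoolInfo.set_batch.

    poolmaps: list of dicts {coarse vertex id: [fine vertex ids]}, one per graph
    n_from:   number of vertices of each fine (input) graph
    n_to:     number of vertices of each coarse (output) graph
    """
    idxn, degrees = [], []
    offset = 0  # row of the current graph's first vertex in the batched tensor
    for pmap, nf, nt in zip(poolmaps, n_from, n_to):
        for v in range(nt):
            members = pmap.get(v, [])  # coarse vertices absent from the map pool nothing
            idxn.extend(m + offset for m in members)
            degrees.append(len(members))
        offset += nf
    return idxn, degrees


idxn, degrees = flatten_poolmaps(
    [{0: [0, 1], 1: [2]}, {0: [1, 0]}],  # two graphs: 3 -> 2 and 2 -> 1 vertices
    n_from=[3, 2], n_to=[2, 1])
print(idxn)     # [0, 1, 2, 4, 3]
print(degrees)  # [2, 1, 2]
```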
--------------------------------------------------------------------------------
/learning/ecc/GraphPoolModule.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable, Function
14 | from .GraphPoolInfo import GraphPoolInfo
15 | from . import cuda_kernels
16 | from . import utils
17 |
18 | class GraphPoolFunction(Function):
19 | """ Computes node feature aggregation for each node of the coarsened graph. The evaluation is computed in blocks of size `edge_mem_limit` to reduce peak memory load. See `GraphPoolInfo` for info on `idxn, degs`.
20 | """
21 |
22 | AGGR_MEAN = 0
23 | AGGR_MAX = 1
24 |
25 | def __init__(self, idxn, degs, degs_gpu, aggr, edge_mem_limit=1e20):
26 | super(GraphPoolFunction, self).__init__()
27 | self._idxn = idxn
28 | self._degs = degs
29 | self._degs_gpu = degs_gpu
30 | self._aggr = aggr
31 | self._shards = utils.get_edge_shards(degs, edge_mem_limit)
32 |
33 | def forward(self, input):
34 | output = input.new(self._degs.numel(), input.size(1))
35 | if self._aggr==GraphPoolFunction.AGGR_MAX:
36 | self._max_indices = self._idxn.new(self._degs.numel(), input.size(1)).fill_(-1)
37 |
38 | self._input_size = input.size()
39 |
40 | # loop over blocks of output nodes
41 | startd, starte = 0, 0
42 | for numd, nume in self._shards:
43 |
44 | sel_input = torch.index_select(input, 0, self._idxn.narrow(0,starte,nume))
45 |
46 | # aggregate over nodes
47 | if self._idxn.is_cuda:
48 | if self._aggr==GraphPoolFunction.AGGR_MEAN:
49 | cuda_kernels.avgpool_fw(output.narrow(0,startd,numd), sel_input, self._degs_gpu.narrow(0,startd,numd))
50 | elif self._aggr==GraphPoolFunction.AGGR_MAX:
51 | cuda_kernels.maxpool_fw(output.narrow(0,startd,numd), self._max_indices.narrow(0,startd,numd), sel_input, self._degs_gpu.narrow(0,startd,numd))
52 | else:
53 | k = 0
54 | for i in range(startd, startd+numd):
55 | if self._degs[i]>0:
56 | if self._aggr==GraphPoolFunction.AGGR_MEAN:
57 | torch.mean(sel_input.narrow(0,k,self._degs[i]), 0, out=output[i])
58 | elif self._aggr==GraphPoolFunction.AGGR_MAX:
59 | torch.max(sel_input.narrow(0,k,self._degs[i]), 0, out=(output[i], self._max_indices[i]))
60 | else:
61 | output[i].fill_(0)
62 | k = k + self._degs[i]
63 |
64 | startd += numd
65 | starte += nume
66 | del sel_input
67 |
68 | return output
69 |
70 |
71 | def backward(self, grad_output):
72 | grad_input = grad_output.new(self._input_size).fill_(0)
73 |
74 | # loop over blocks of output nodes
75 | startd, starte = 0, 0
76 | for numd, nume in self._shards:
77 |
78 | grad_sel_input = grad_output.new(nume, grad_output.size(1))
79 |
80 | # grad wrt input
81 | if self._idxn.is_cuda:
82 | if self._aggr==GraphPoolFunction.AGGR_MEAN:
83 | cuda_kernels.avgpool_bw(grad_input, self._idxn.narrow(0,starte,nume), grad_output.narrow(0,startd,numd), self._degs_gpu.narrow(0,startd,numd))
84 | elif self._aggr==GraphPoolFunction.AGGR_MAX:
85 | cuda_kernels.maxpool_bw(grad_input, self._idxn.narrow(0,starte,nume), self._max_indices.narrow(0,startd,numd), grad_output.narrow(0,startd,numd), self._degs_gpu.narrow(0,startd,numd))
86 | else:
87 | k = 0
88 | for i in range(startd, startd+numd):
89 | if self._degs[i]>0:
90 | if self._aggr==GraphPoolFunction.AGGR_MEAN:
91 | torch.div(grad_output[i], self._degs[i], out=grad_sel_input[k])
92 | if self._degs[i]>1:
93 | grad_sel_input.narrow(0, k+1, self._degs[i]-1).copy_( grad_sel_input[k].expand(self._degs[i]-1,1,grad_output.size(1)).squeeze(1) )
94 | elif self._aggr==GraphPoolFunction.AGGR_MAX:
95 | grad_sel_input.narrow(0, k, self._degs[i]).fill_(0).scatter_(0, self._max_indices[i].view(1,-1), grad_output[i].view(1,-1))
96 | k = k + self._degs[i]
97 |
98 | grad_input.index_add_(0, self._idxn.narrow(0,starte,nume), grad_sel_input)
99 |
100 | startd += numd
101 | starte += nume
102 | del grad_sel_input
103 |
104 | return grad_input
105 |
106 |
107 |
108 | class GraphPoolModule(nn.Module):
109 | """ Performs graph pooling.
110 | The input should be a 2D tensor of size (# nodes, `in_channels`). Multiple graphs can be concatenated in the same tensor (minibatch).
111 |
112 | Parameters:
113 | aggr: aggregation type (GraphPoolFunction.AGGR_MEAN, GraphPoolFunction.AGGR_MAX)
114 | gp_info: GraphPoolInfo object containing node mapping information, can be also set with `set_info()` method.
115 | edge_mem_limit: block size (number of evaluated edges in parallel), a low value reduces peak memory.
116 | """
117 |
118 | def __init__(self, aggr, gp_info=None, edge_mem_limit=1e20):
119 | super(GraphPoolModule, self).__init__()
120 |
121 | self._aggr = aggr
122 | self._edge_mem_limit = edge_mem_limit
123 | self.set_info(gp_info)
124 |
125 | def set_info(self, gp_info):
126 | self._gpi = gp_info
127 |
128 | def forward(self, input):
129 | idxn, degs, degs_gpu = self._gpi.get_buffers()
130 | return GraphPoolFunction(idxn, degs, degs_gpu, self._aggr, self._edge_mem_limit)(input)
131 |
132 |
133 | class GraphAvgPoolModule(GraphPoolModule):
134 | def __init__(self, gp_info=None, edge_mem_limit=1e20):
135 | super(GraphAvgPoolModule, self).__init__(GraphPoolFunction.AGGR_MEAN, gp_info, edge_mem_limit)
136 |
137 | class GraphMaxPoolModule(GraphPoolModule):
138 | def __init__(self, gp_info=None, edge_mem_limit=1e20):
139 | super(GraphMaxPoolModule, self).__init__(GraphPoolFunction.AGGR_MAX, gp_info, edge_mem_limit)
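For the max-aggregation path, the following is a NumPy restatement of the CPU branches of `GraphPoolFunction.forward` and `backward`, assuming dense arrays: the forward stores, per output node and channel, the position of the winning row inside the node's block (or -1 for an empty node), and the backward routes the gradient only to that row. The function names are hypothetical, and tie-breaking in `np.argmax` (first occurrence) may differ from `torch.max`.

```python
import numpy as np


def max_pool_fw(x, idxn, degs):
    """Per-channel max over each output node's block of gathered rows."""
    sel = x[idxn]
    out = np.zeros((len(degs), x.shape[1]))
    arg = np.full((len(degs), x.shape[1]), -1, dtype=np.int64)  # -1 marks empty nodes
    k = 0
    for i, d in enumerate(degs):
        if d > 0:
            block = sel[k:k + d]
            out[i] = block.max(axis=0)
            arg[i] = block.argmax(axis=0)  # winning position inside the block, per channel
        k += d
    return out, arg


def max_pool_bw(grad_out, idxn, degs, arg, n_in):
    """Route each channel's gradient to the input row that won the max."""
    grad_x = np.zeros((n_in, grad_out.shape[1]))
    channels = np.arange(grad_out.shape[1])
    k = 0
    for i, d in enumerate(degs):
        if d > 0:
            rows = idxn[k + arg[i]]  # input row that produced the max, per channel
            grad_x[rows, channels] += grad_out[i]
        k += d
    return grad_x


x = np.array([[1., 5.], [3., 2.], [0., 4.], [7., 7.]])
idxn = np.array([2, 0, 1, 3])  # gathered order of input rows
degs = np.array([3, 0, 1])     # output node 0 pools 3 rows, node 1 none, node 2 one
out, arg = max_pool_fw(x, idxn, degs)
# out -> [[3., 5.], [0., 0.], [7., 7.]]
grad_x = max_pool_bw(np.ones_like(out), idxn, degs, arg, n_in=4)
# grad_x -> [[0., 1.], [1., 0.], [0., 0.], [1., 1.]]
```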
--------------------------------------------------------------------------------
/learning/ecc/__init__.py:
--------------------------------------------------------------------------------
1 | from .GraphConvInfo import GraphConvInfo
2 | from .GraphConvModule import GraphConvModule, GraphConvFunction
3 |
4 | from .GraphPoolInfo import GraphPoolInfo
5 | from .GraphPoolModule import GraphAvgPoolModule, GraphMaxPoolModule
6 |
7 | from .utils import *
8 |
--------------------------------------------------------------------------------
/learning/ecc/cuda_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import torch
12 | try:
13 | import cupy.cuda
14 | from pynvrtc.compiler import Program
15 | except:
16 | pass
17 | from collections import namedtuple
18 | import numpy as np
19 |
20 | CUDA_NUM_THREADS = 1024
21 |
22 | def GET_BLOCKS(N):
23 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS
24 |
25 | modules = {}
26 |
27 | def get_dtype(t):
28 | if isinstance(t, torch.cuda.FloatTensor):
29 | return 'float'
30 | elif isinstance(t, torch.cuda.DoubleTensor):
31 | return 'double'
32 |
33 | def get_kernel_func(kname, ksrc, dtype):
34 | if kname+dtype not in modules:
35 | ksrc = ksrc.replace('DTYPE', dtype)
36 | #prog = Program(ksrc.encode('utf-8'), (kname+dtype+'.cu').encode('utf-8'))
37 | #uncomment the line above and comment the line below if it causes the following error: AttributeError: 'Program' object has no attribute '_program'
38 | prog = Program(ksrc, kname+dtype+'.cu')
39 | ptx = prog.compile()
40 | log = prog._interface.nvrtcGetProgramLog(prog._program)
41 | if len(log.strip()) > 0: print(log)
42 | module = cupy.cuda.function.Module()
43 | module.load(bytes(ptx.encode()))
44 | modules[kname+dtype] = module
45 | else:
46 | module = modules[kname+dtype]
47 |
48 | Stream = namedtuple('Stream', ['ptr'])
49 | s = Stream(ptr=torch.cuda.current_stream().cuda_stream)
50 |
51 | return module.get_function(kname), s
52 |
53 | ####
54 |
55 | def conv_aggregate_fw_kernel_v2(**kwargs):
56 | kernel = r'''
57 | extern "C"
58 | __global__ void conv_aggregate_fw_kernel_v2(DTYPE* dest, const DTYPE* src, const long long* lengths, const long long* cslengths, int width, int N, int dest_stridex, int src_stridex, int blockDimy) {
59 |
60 | int x = blockIdx.x * blockDim.x + threadIdx.x; //one thread per feature channel, runs over all nodes
61 | if (x >= width) return;
62 |
63 | int i = blockIdx.y * blockDimy;
64 | int imax = min(N, i + blockDimy);
65 | dest += dest_stridex * i + x;
66 | src += src_stridex * (cslengths[i] - lengths[i]) + x;
67 |
68 | for (; i<imax; ++i) {
69 | int len = lengths[i];
70 | if (len > 0) {
71 | DTYPE sum = 0;
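72 | for (int j=0; j<len; j++, src += src_stridex) {
73 | sum += *src;
74 | }
75 |
76 | *dest = sum / len;
77 | }
78 | else {
79 | *dest = 0;
80 | }
81 |
82 | dest += dest_stridex;
83 | }
84 | }
85 | '''
86 | return kernel
87 |
88 | def conv_aggregate_bw_kernel_v2(**kwargs):
89 | kernel = r'''
90 | extern "C"
91 | __global__ void conv_aggregate_bw_kernel_v2(DTYPE* dest, const DTYPE* src, const long long* lengths, const long long* cslengths, int width, int N, int dest_stridex, int src_stridex, int blockDimy) {
92 |
93 | int x = blockIdx.x * blockDim.x + threadIdx.x; //one thread per feature channel, runs over all nodes
94 | if (x >= width) return;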
95 |
96 | int i = blockIdx.y * blockDimy;
97 | int imax = min(N, i + blockDimy);
98 | dest += dest_stridex * (cslengths[i] - lengths[i]) + x;
99 | src += src_stridex * i + x;
100 |
101 | for (; i<imax; ++i) {
102 | int len = lengths[i];
103 | if (len > 0) {
104 | DTYPE val = *src / len;
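105 | for (int j=0; j<len; j++, dest += dest_stridex) {
106 | *dest = val;
107 | }
108 | }
109 |
110 | src += src_stridex;
111 | }
112 | }
113 | '''
114 | return kernel
…
143 | def maxpool_fw_kernel(**kwargs):
144 | kernel = r'''
145 | extern "C"
146 | __global__ void maxpool_fw_kernel(DTYPE* dest, long long* indices, const DTYPE* src, const long long* lengths, int width, int N, int dest_stridex, int src_stridex) {
147 |
148 | int x = blockIdx.x * blockDim.x + threadIdx.x; //one thread per feature channel, runs over all points
149 | if (x >= width) return;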
150 |
151 | for (int i=0; i<N; ++i) {
152 | if (lengths[i] > 0) {
153 | long long src_step = lengths[i] * src_stridex;
154 | long long bestjj = -1;
155 | DTYPE best = -1e10;
156 |
157 | for (long long j = x, jj=0; j < src_step; j += src_stridex, ++jj) {
158 | if (src[j] > best) {
159 | best = src[j];
160 | bestjj = jj;
161 | }
162 | }
163 |
164 | dest[x] = best;
165 | indices[x] = bestjj;
166 |
167 | src += src_step;
168 | }
169 | else {
170 | dest[x] = 0;
171 | indices[x] = -1;
172 | }
173 |
174 | dest += dest_stridex;
175 | indices += dest_stridex;
176 | }
177 | }
178 | '''
179 | return kernel
180 |
181 | def maxpool_bw_kernel(**kwargs):
182 | kernel = r'''
183 | //also directly scatters results by dest_indices (saves one sparse intermediate buffer)
184 | extern "C"
185 | __global__ void maxpool_bw_kernel(DTYPE* dest, const long long* dest_indices, const long long* max_indices, const DTYPE* src, const long long* lengths, int width, int N, int dest_stridex, int src_stridex) {
186 |
187 | int x = blockIdx.x * blockDim.x + threadIdx.x; //one thread per feature channel, runs over all points
188 | if (x >= width) return;
189 |
190 | for (int i=0; i<N; ++i) {
191 | if (lengths[i] > 0) {
192 |
193 | long long destidx = dest_indices[max_indices[x]];
194 | dest[x + destidx * dest_stridex] += src[x]; //no need for atomicadd, only one thread handles each feature channel
195 |
196 | dest_indices += lengths[i];
197 | }
198 |
199 | src += src_stridex;
200 | max_indices += src_stridex;
201 | }
202 | }
203 | '''
204 | return kernel
205 |
206 |
207 | def maxpool_fw(dest, indices, src, degs):
208 | n = degs.numel()
209 | w = src.size(1)
210 | assert n == dest.size(0) and w == dest.size(1)
211 | assert type(src)==type(dest) and isinstance(degs, torch.cuda.LongTensor) and isinstance(indices, torch.cuda.LongTensor)
212 |
213 | function, stream = get_kernel_func('maxpool_fw_kernel', maxpool_fw_kernel(), get_dtype(src))
214 | function(args=[dest.data_ptr(), indices.data_ptr(), src.data_ptr(), degs.data_ptr(), np.int32(w), np.int32(n), np.int32(dest.stride(0)), np.int32(src.stride(0))],
215 | block=(CUDA_NUM_THREADS,1,1), grid=(GET_BLOCKS(w),1,1), stream=stream)
216 |
217 | def maxpool_bw(dest, idxn, indices, src, degs):
218 | n = degs.numel()
219 | w = src.size(1)
220 | assert n == src.size(0) and w == dest.size(1)
221 | assert type(src)==type(dest) and isinstance(degs, torch.cuda.LongTensor) and isinstance(indices, torch.cuda.LongTensor) and isinstance(idxn, torch.cuda.LongTensor)
222 |
223 | function, stream = get_kernel_func('maxpool_bw_kernel', maxpool_bw_kernel(), get_dtype(src))
224 | function(args=[dest.data_ptr(), idxn.data_ptr(), indices.data_ptr(), src.data_ptr(), degs.data_ptr(), np.int32(w), np.int32(n), np.int32(dest.stride(0)), np.int32(src.stride(0))],
225 | block=(CUDA_NUM_THREADS,1,1), grid=(GET_BLOCKS(w),1,1), stream=stream)
226 |
227 |
228 |
229 | def avgpool_bw_kernel(**kwargs):
230 | kernel = r'''
231 | //also directly scatters results by dest_indices (saves one intermediate buffer)
232 | extern "C"
233 | __global__ void avgpool_bw_kernel(DTYPE* dest, const long long* dest_indices, const DTYPE* src, const long long* lengths, int width, int N, int dest_stridex, int src_stridex) {
234 |
235 | int x = blockIdx.x * blockDim.x + threadIdx.x; //one thread per feature channel, runs over all points
236 | if (x >= width) return;
237 |
238 | for (int i=0; i<N; ++i) {
239 | if (lengths[i] > 0) {
240 |
241 | DTYPE val = src[x] / lengths[i];
242 |
243 | for (int j = 0; j < lengths[i]; ++j) {
244 | long long destidx = dest_indices[j];
245 | dest[x + destidx * dest_stridex] += val; //no need for atomicadd, only one thread handles each feature channel
246 | }
247 |
248 | dest_indices += lengths[i];
249 | }
250 |
251 | src += src_stridex;
252 | }
253 | }
254 | '''
255 | return kernel
256 |
257 |
258 | def avgpool_fw(dest, src, degs):
259 | conv_aggregate_fw(dest, src, degs)
260 |
261 | def avgpool_bw(dest, idxn, src, degs):
262 | n = degs.numel()
263 | w = src.size(1)
264 | assert n == src.size(0) and w == dest.size(1)
265 | assert type(src)==type(dest) and isinstance(degs, torch.cuda.LongTensor) and isinstance(idxn, torch.cuda.LongTensor)
266 |
267 | function, stream = get_kernel_func('avgpool_bw_kernel', avgpool_bw_kernel(), get_dtype(src))
268 | function(args=[dest.data_ptr(), idxn.data_ptr(), src.data_ptr(), degs.data_ptr(), np.int32(w), np.int32(n), np.int32(dest.stride(0)), np.int32(src.stride(0))],
269 | block=(CUDA_NUM_THREADS,1,1), grid=(GET_BLOCKS(w),1,1), stream=stream)
270 |
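The kernels above locate each output node's block of rows through `cslengths[i] - lengths[i]`, i.e. the exclusive prefix sum of the degrees, and the pooling kernels launch `GET_BLOCKS(width)` blocks of `CUDA_NUM_THREADS` threads so that each thread owns one feature channel. A CPU reference for the per-node averaging done by `conv_aggregate_fw_kernel_v2` can be sketched as follows; `segment_mean` and `get_blocks` are hypothetical names.

```python
import numpy as np

CUDA_NUM_THREADS = 1024


def get_blocks(n, threads=CUDA_NUM_THREADS):
    # ceil(n / threads): number of thread blocks needed to cover n channels
    return (n + threads - 1) // threads


def segment_mean(src, lengths):
    """CPU reference: mean of each node's consecutive rows, zero for empty nodes."""
    lengths = np.asarray(lengths)
    starts = np.cumsum(lengths) - lengths  # cslengths[i] - lengths[i]
    out = np.zeros((len(lengths), src.shape[1]), dtype=src.dtype)
    for i, (s, n) in enumerate(zip(starts, lengths)):
        if n > 0:
            out[i] = src[s:s + n].mean(axis=0)
    return out


src = np.arange(12, dtype=np.float64).reshape(6, 2)  # 6 edge rows, 2 channels
res = segment_mean(src, [2, 0, 3, 1])
# res rows -> [1, 2], [0, 0], [6, 7], [10, 11]
print(get_blocks(2), get_blocks(1025))  # 1 2
```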
--------------------------------------------------------------------------------
/learning/ecc/test_GraphConvModule.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import unittest
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | from torch.autograd import Variable, gradcheck
16 |
17 | from .GraphConvModule import *
18 | from .GraphConvInfo import GraphConvInfo
19 |
20 |
21 | class TestGraphConvModule(unittest.TestCase):
22 |
23 | def test_gradcheck(self):
24 |
25 | torch.set_default_tensor_type('torch.DoubleTensor') #necessary for proper numerical gradient
26 |
27 | for cuda in range(0,2):
28 | # without idxe
29 | n,e,in_channels, out_channels = 20,50,10, 15
30 | input = torch.randn(n,in_channels)
31 | weights = torch.randn(e,in_channels,out_channels)
32 | idxn = torch.from_numpy(np.random.randint(n,size=e))
33 | idxe = None
34 | degs = torch.LongTensor([5, 0, 15, 20, 10]) #strided conv
35 | degs_gpu = degs
36 | edge_mem_limit = 30 # some nodes will be combined, some not
37 | if cuda:
38 | input = input.cuda(); weights = weights.cuda(); idxn = idxn.cuda(); degs_gpu = degs_gpu.cuda()
39 |
40 | func = lambda inp, w: GraphConvFunction.apply(inp, w, in_channels, out_channels, idxn, idxe, degs, degs_gpu, edge_mem_limit)
41 | data = (Variable(input, requires_grad=True), Variable(weights, requires_grad=True))
42 |
43 | ok = gradcheck(func, data)
44 | self.assertTrue(ok)
45 |
46 | # with idxe
47 | weights = torch.randn(30,in_channels,out_channels)
48 | idxe = torch.from_numpy(np.random.randint(30,size=e))
49 | if cuda:
50 | weights = weights.cuda(); idxe = idxe.cuda()
51 | data = (Variable(input, requires_grad=True), Variable(weights, requires_grad=True))
52 | func = lambda inp, w: GraphConvFunction.apply(inp, w, in_channels, out_channels, idxn, idxe, degs, degs_gpu, edge_mem_limit)
53 |
54 | ok = gradcheck(func, data)
55 | self.assertTrue(ok)
56 |
57 | torch.set_default_tensor_type('torch.FloatTensor')
58 |
59 | def test_batch_splitting(self):
60 |
61 | n,e,in_channels, out_channels = 20,50,10, 15
62 | input = torch.randn(n,in_channels)
63 | weights = torch.randn(e,in_channels,out_channels)
64 | idxn = torch.from_numpy(np.random.randint(n,size=e))
65 | idxe = None
66 | degs = torch.LongTensor([5, 0, 15, 20, 10]) #strided conv
67 |
68 | func = lambda inp, w: GraphConvFunction.apply(inp, w, in_channels, out_channels, idxn, idxe, degs, degs, 1e10)
69 | data = (Variable(input, requires_grad=True), Variable(weights, requires_grad=True))
70 | output1 = func(*data)
71 |
72 | func = lambda inp, w: GraphConvFunction.apply(inp, w, in_channels, out_channels, idxn, idxe, degs, degs, 1)
73 | output2 = func(*data)
74 |
75 | self.assertLess((output1-output2).norm().item(), 1e-6)
76 |
77 |
78 |
79 | if __name__ == '__main__':
80 | unittest.main()
--------------------------------------------------------------------------------
/learning/ecc/test_GraphPoolModule.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import unittest
12 | import numpy as np
13 | import torch
14 | from torch.autograd import Variable, gradcheck
15 |
16 | from .GraphPoolModule import *
17 | from .GraphPoolInfo import GraphPoolInfo
18 |
19 |
20 | class TestGraphPoolModule(unittest.TestCase):
21 |
22 | def test_gradcheck(self):
23 |
24 | torch.set_default_tensor_type('torch.DoubleTensor') #necessary for proper numerical gradient
25 |
26 | for cuda in range(0,2):
27 | for aggr in range(0,2):
28 | n,in_channels = 20,10
29 | input = torch.randn(n,in_channels)
30 | idxn = torch.from_numpy(np.random.permutation(n))
31 | degs = torch.LongTensor([2, 0, 3, 10, 5])
32 | degs_gpu = degs
33 | edge_mem_limit = 30 # some nodes will be combined, some not
34 | if cuda:
35 | input = input.cuda(); idxn = idxn.cuda(); degs_gpu = degs_gpu.cuda()
36 |
37 | func = GraphPoolFunction(idxn, degs, degs_gpu, aggr=aggr, edge_mem_limit=edge_mem_limit)
38 | data = (Variable(input, requires_grad=True),)
39 |
40 | ok = gradcheck(func, data)
41 | self.assertTrue(ok)
42 |
43 | torch.set_default_tensor_type('torch.FloatTensor')
44 |
45 | def test_batch_splitting(self):
46 | n,in_channels = 20,10
47 | input = torch.randn(n,in_channels)
48 | idxn = torch.from_numpy(np.random.permutation(n))
49 | degs = torch.LongTensor([2, 0, 3, 10, 5])
50 |
51 | func = GraphPoolFunction(idxn, degs, degs, aggr=GraphPoolFunction.AGGR_MAX, edge_mem_limit=1e10)
52 | data = (Variable(input, requires_grad=True),)
53 | output1 = func(*data)
54 |
55 | func = GraphPoolFunction(idxn, degs, degs, aggr=GraphPoolFunction.AGGR_MAX, edge_mem_limit=1)
56 | output2 = func(*data)
57 |
58 | self.assertLess((output1-output2).norm(), 1e-6)
59 |
60 | if __name__ == '__main__':
61 | unittest.main()
--------------------------------------------------------------------------------
/learning/ecc/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import random
12 | import numpy as np
13 | import torch
14 |
15 | import ecc
16 |
17 | def graph_info_collate_classification(batch, edge_func):
18 | """ Collates a list of dataset samples into a single batch. We assume that all samples have the same number of resolutions.
19 |
20 | Each sample is a tuple of following elements:
21 | features: 2D Tensor of node features
22 | classes: LongTensor of class ids
23 | graphs: list of graphs, each for one resolution
24 | pooldata: list of triplets, each for one resolution: (pooling map, finer graph, coarser graph)
25 | """
26 | features, classes, graphs, pooldata = list(zip(*batch))
27 | graphs_by_layer = list(zip(*graphs))
28 | pooldata_by_layer = list(zip(*pooldata))
29 |
30 | features = torch.cat([torch.from_numpy(f) for f in features])
31 | if features.dim()==1: features = features.view(-1,1)
32 |
33 | classes = torch.LongTensor(classes)
34 |
35 | GIs, PIs = [], []
36 | for graphs in graphs_by_layer:
37 | GIs.append( ecc.GraphConvInfo(graphs, edge_func) )
38 | for pooldata in pooldata_by_layer:
39 | PIs.append( ecc.GraphPoolInfo(*zip(*pooldata)) )
40 |
41 | return features, classes, GIs, PIs
42 |
43 |
44 | def unique_rows(data):
45 | """ Filters unique rows from a 2D np array and also returns inverse indices. Used for edge feature compaction. """
46 | # https://stackoverflow.com/questions/16970982/find-unique-rows-in-numpy-array
47 | uniq, indices = np.unique(data.view(data.dtype.descr * data.shape[1]), return_inverse=True)
48 | return uniq.view(data.dtype).reshape(-1, data.shape[1]), indices
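The structured-dtype view in `unique_rows` is the workaround from the linked Stack Overflow answer; on NumPy 1.13 or newer, `np.unique(..., axis=0, return_inverse=True)` gives the same kind of compaction: the distinct edge-feature rows plus, for every edge, the index of its row. The toy `edge_feats` array below is illustrative only.

```python
import numpy as np

# toy edge features: 5 edges, only 3 distinct rows
edge_feats = np.array([[1., 0.], [0., 1.], [1., 0.], [0., 1.], [2., 2.]])

uniq, inv = np.unique(edge_feats, axis=0, return_inverse=True)
inv = inv.reshape(-1)  # flatten defensively: the inverse's shape has varied across NumPy versions

print(uniq)  # distinct rows, sorted lexicographically: [0, 1], [1, 0], [2, 2]
print(inv)   # row of `uniq` for each edge: [1 0 1 0 2]
assert np.array_equal(uniq[inv], edge_feats)  # every edge's features are recovered
```

With this layout a filter network only needs to run on `uniq`, and `inv` plays the role of a per-edge index into its output.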
49 |
50 | def one_hot_discretization(feat, clip_min, clip_max, upweight):
51 | indices = np.clip(np.round(feat), clip_min, clip_max).astype(int).reshape((-1,)) - clip_min # column index relative to clip_min
52 | onehot = np.zeros((feat.shape[0], clip_max - clip_min + 1))
53 | onehot[np.arange(onehot.shape[0]), indices] = onehot.shape[1] if upweight else 1
54 | return onehot
55 |
56 | def get_edge_shards(degs, edge_mem_limit):
57 | """ Splits iteration over nodes into shards, approximately limited by `edge_mem_limit` edges per shard.
58 | Returns a list of pairs indicating how many output nodes and edges to process in each shard."""
59 | d = degs if isinstance(degs, np.ndarray) else degs.numpy()
60 | cs = np.cumsum(d)
61 | cse = cs // edge_mem_limit
62 | _, cse_i, cse_c = np.unique(cse, return_index=True, return_counts=True)
63 |
64 | shards = []
65 | for b in range(len(cse_i)):
66 | numd = cse_c[b]
67 | nume = (cs[-1] if b==len(cse_i)-1 else cs[cse_i[b+1]-1]) - cs[cse_i[b]] + d[cse_i[b]]
68 | shards.append( (int(numd), int(nume)) )
69 | return shards
70 |
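The sharding rule in `get_edge_shards` can be checked in isolation. The sketch below is a hypothetical standalone re-implementation (the name `shard_by_edge_budget` is ours, not part of the `ecc` package). Like the original, it puts each node in bucket `cumsum(degs) // edge_mem_limit` and reports `(nodes, edges)` per bucket. It assumes non-negative degrees, so every bucket is a contiguous run of nodes.

```python
import numpy as np


def shard_by_edge_budget(degs, edge_mem_limit):
    """Hypothetical standalone version of the grouping used by get_edge_shards.

    Each node goes to bucket cumsum(degs) // edge_mem_limit; with non-negative
    degrees the buckets are contiguous runs of nodes. Returns a list of
    (number_of_nodes, number_of_edges) pairs, one per non-empty bucket.
    """
    degs = np.asarray(degs)
    cs = np.cumsum(degs)                          # edges up to and including each node
    bucket = cs // edge_mem_limit
    _, first, counts = np.unique(bucket, return_index=True, return_counts=True)
    shards = []
    for start, n_nodes in zip(first, counts):
        end = start + n_nodes                     # exclusive end of the run
        edges_before = cs[start - 1] if start > 0 else 0
        shards.append((int(n_nodes), int(cs[end - 1] - edges_before)))
    return shards


degs = [3, 2, 4, 1, 5]
shards = shard_by_edge_budget(degs, 5)
print(shards)  # [(1, 3), (2, 6), (1, 1), (1, 5)]
```

With a budget of 5 edges, the second shard holds 6 edges. This is why the limit is only approximate: a shard closes only when the cumulative edge count crosses the next multiple of `edge_mem_limit`. A node whose degree alone exceeds the budget always produces a shard larger than the budget.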
--------------------------------------------------------------------------------
/learning/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Thu Jul 4 09:22:55 2019
5 |
6 | @author: landrieuloic
7 | """
8 |
9 | """
10 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
11 | http://arxiv.org/abs/1711.09869
12 | 2017 Loic Landrieu, Martin Simonovsky
13 | """
14 | import argparse
15 | import numpy as np
16 | import sys
17 | sys.path.append("./learning")
18 | from metrics import *
19 |
20 | parser = argparse.ArgumentParser(description='Evaluation function for S3DIS and vKITTI')
21 |
22 | parser.add_argument('--odir', default='./results/s3dis/best', help='Directory containing the per-fold results (cv<fold>/pointwise_cm.npy)')
23 | parser.add_argument('--dataset', default='s3dis', help='Dataset name: s3dis or vkitti')
24 | parser.add_argument('--cvfold', default='123456', help='Folds to evaluate, given as a string of fold digits')
25 |
26 | args = parser.parse_args()
27 |
28 |
29 |
30 | if args.dataset == 's3dis':
31 | n_labels = 13
32 | inv_class_map = {0:'ceiling', 1:'floor', 2:'wall', 3:'column', 4:'beam', 5:'window', 6:'door', 7:'table', 8:'chair', 9:'bookcase', 10:'sofa', 11:'board', 12:'clutter'}
33 | base_name = args.odir+'/cv'
34 | elif args.dataset == 'vkitti':
35 | n_labels = 13
36 | inv_class_map = {0:'Terrain', 1:'Tree', 2:'Vegetation', 3:'Building', 4:'Road', 5:'GuardRail', 6:'TrafficSign', 7:'TrafficLight', 8:'Pole', 9:'Misc', 10:'Truck', 11:'Car', 12:'Van'}
37 | base_name = args.odir+'/cv'
38 | else: raise ValueError('Unknown dataset: ' + args.dataset)
39 | C = ConfusionMatrix(n_labels)
40 | C.confusion_matrix=np.zeros((n_labels, n_labels))
41 |
42 |
43 | for i_fold in range(len(args.cvfold)):
44 | fold = int(args.cvfold[i_fold])
45 | cm = ConfusionMatrix(n_labels)
46 | cm.confusion_matrix=np.load(base_name+str(fold) +'/pointwise_cm.npy')
47 | print("Fold %d : \t OA = %3.2f \t mA = %3.2f \t mIoU = %3.2f" % (fold, \
48 | 100 * ConfusionMatrix.get_overall_accuracy(cm) \
49 | , 100 * ConfusionMatrix.get_mean_class_accuracy(cm) \
50 | , 100 * ConfusionMatrix.get_average_intersection_union(cm)
51 | ))
52 | C.confusion_matrix += cm.confusion_matrix
53 |
54 | print("\nOverall accuracy : %3.2f %%" % (100 * (ConfusionMatrix.get_overall_accuracy(C))))
55 | print("Mean accuracy : %3.2f %%" % (100 * (ConfusionMatrix.get_mean_class_accuracy(C))))
56 | print("Mean IoU : %3.2f %%\n" % (100 * (ConfusionMatrix.get_average_intersection_union(C))))
57 | print("        Class : IoU")
58 | for c in range(0,n_labels):
59 | print (" %12s : %6.2f %% \t %.1e points" %(inv_class_map[c],100*ConfusionMatrix.get_intersection_union_per_class(C)[c], ConfusionMatrix.count_gt(C,c)))
60 |
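The script pools the folds by summing their confusion matrices before computing metrics, rather than averaging per-fold scores. The NumPy sketch below shows that the two aggregations differ when folds have different sizes. The helpers are meant to mirror the `ConfusionMatrix` definitions (rows are ground truth, columns are predictions; IoU is averaged over classes that occur). The helper names and the two toy matrices are illustrative and are not taken from the repository.

```python
import numpy as np


def overall_accuracy(cm):
    # correctly classified points over all points (rows: ground truth, columns: prediction)
    return np.trace(cm) / max(cm.sum(), 1)


def mean_iou(cm):
    # per-class IoU = TP / (TP + FN + FP), averaged over classes that occur at all
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    denom = tp + fn + fp
    seen = denom > 0
    return (tp[seen] / denom[seen]).mean()


# two illustrative folds of very different size
folds = [np.array([[9, 1], [1, 9]]),      # 20 points
         np.array([[1, 0], [0, 99]])]     # 100 points
pooled = sum(folds)                       # summing matrices, as done with pointwise_cm.npy

print('pooled OA  :', overall_accuracy(pooled))
print('mean of OA :', np.mean([overall_accuracy(f) for f in folds]))
print('pooled mIoU:', mean_iou(pooled))
```

Here the pooled overall accuracy is 118/120 (about 0.983), while the unweighted mean of the two fold accuracies is 0.95. Pooling gives each point, not each fold, equal weight.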
--------------------------------------------------------------------------------
/learning/graphnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 | http://arxiv.org/abs/1711.09869
4 | 2017 Loic Landrieu, Martin Simonovsky
5 | """
6 | from __future__ import division
7 | from __future__ import print_function
8 | from builtins import range
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.init as init
13 | from learning import ecc
14 | from learning.modules import RNNGraphConvModule, ECC_CRFModule, GRUCellEx, LSTMCellEx
15 |
16 |
17 | def create_fnet(widths, orthoinit, llbias, bnidx=-1):
18 | """ Creates feature-generating network, a multi-layer perceptron.
19 | Parameters:
20 | widths: list of widths of layers (including input and output widths)
21 | orthoinit: whether to use orthogonal weight initialization
22 | llbias: whether to use bias in the last layer
23 | bnidx: index of batch normalization (-1 if not used)
24 | """
25 | fnet_modules = []
26 | for k in range(len(widths)-2):
27 | fnet_modules.append(nn.Linear(widths[k], widths[k+1]))
28 | if orthoinit: init.orthogonal_(fnet_modules[-1].weight, gain=init.calculate_gain('relu'))
29 | if bnidx==k: fnet_modules.append(nn.BatchNorm1d(widths[k+1]))
30 | fnet_modules.append(nn.ReLU(True))
31 | fnet_modules.append(nn.Linear(widths[-2], widths[-1], bias=llbias))
32 | if orthoinit: init.orthogonal_(fnet_modules[-1].weight)
33 | if bnidx==len(widths)-1: fnet_modules.append(nn.BatchNorm1d(fnet_modules[-1].weight.size(0)))
34 | return nn.Sequential(*fnet_modules)
35 |
36 |
37 | class GraphNetwork(nn.Module):
38 |     """ Graph network built from the `config` string: a sequence of comma-delimited layer-definition tokens of the form layer_arg1_arg2_... See README.md for examples.
39 | """
40 | def __init__(self, config, nfeat, fnet_widths, fnet_orthoinit=True, fnet_llbias=True, fnet_bnidx=-1, edge_mem_limit=1e20, use_pyg = True, cuda = True):
41 | super(GraphNetwork, self).__init__()
42 | self.gconvs = []
43 |
44 | for d, conf in enumerate(config.split(',')):
45 | conf = conf.strip().split('_')
46 |
47 | if conf[0]=='f': #Fully connected layer; args: output_feats
48 | self.add_module(str(d), nn.Linear(nfeat, int(conf[1])))
49 | nfeat = int(conf[1])
50 | elif conf[0]=='b': #Batch norm; args: not_affine
51 | self.add_module(str(d), nn.BatchNorm1d(nfeat, eps=1e-5, affine=len(conf)==1))
52 | elif conf[0]=='r': #ReLU;
53 | self.add_module(str(d), nn.ReLU(True))
54 | elif conf[0]=='d': #Dropout; args: dropout_prob
55 | self.add_module(str(d), nn.Dropout(p=float(conf[1]), inplace=False))
56 |
57 | elif conf[0]=='crf': #ECC-CRF; args: repeats
58 | nrepeats = int(conf[1])
59 |
60 | fnet = create_fnet(fnet_widths + [nfeat*nfeat], fnet_orthoinit, fnet_llbias, fnet_bnidx)
61 | gconv = ecc.GraphConvModule(nfeat, nfeat, fnet, edge_mem_limit=edge_mem_limit)
62 | crf = ECC_CRFModule(gconv, nrepeats)
63 | self.add_module(str(d), crf)
64 | self.gconvs.append(gconv)
65 |
66 |         elif conf[0]=='gru' or conf[0]=='lstm': #RNN-ECC args: repeats, vv=True, layernorm=True, ingate=True, cat_all=True
67 |             nrepeats = int(conf[1])
68 |             vv = bool(int(conf[2])) if len(conf)>2 else True # True: element-wise filters (fnet outputs nfeat values); False: full nfeat x nfeat matrix-vector filters
69 | layernorm = bool(int(conf[3])) if len(conf)>3 else True
70 | ingate = bool(int(conf[4])) if len(conf)>4 else True
71 | cat_all = bool(int(conf[5])) if len(conf)>5 else True
72 |
73 | fnet = create_fnet(fnet_widths + [nfeat**2 if not vv else nfeat], fnet_orthoinit, fnet_llbias, fnet_bnidx)
74 | if conf[0]=='gru':
75 | cell = GRUCellEx(nfeat, nfeat, bias=True, layernorm=layernorm, ingate=ingate)
76 | else:
77 | cell = LSTMCellEx(nfeat, nfeat, bias=True, layernorm=layernorm, ingate=ingate)
78 | gconv = RNNGraphConvModule(cell, fnet, nfeat, vv = vv, nrepeats=nrepeats, cat_all=cat_all, edge_mem_limit=edge_mem_limit, use_pyg = use_pyg, cuda = cuda)
79 | self.add_module(str(d), gconv)
80 | self.gconvs.append(gconv)
81 | if cat_all: nfeat *= nrepeats + 1
82 |
83 | elif len(conf[0])>0:
84 | raise NotImplementedError('Unknown module: ' + conf[0])
85 |
86 |
87 | def set_info(self, gc_infos, cuda):
88 | """ Provides convolution modules with graph structure information for the current batch.
89 | """
90 | gc_infos = gc_infos if isinstance(gc_infos,(list,tuple)) else [gc_infos]
91 | for i,gc in enumerate(self.gconvs):
92 | if cuda: gc_infos[i].cuda()
93 | gc.set_info(gc_infos[i])
94 |
95 | def forward(self, input):
96 | for module in self._modules.values():
97 | input = module(input)
98 | return input
99 |
100 |
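`GraphNetwork` tracks the feature width as it reads the config tokens. Only two kinds of token change the width: `f_<n>`, and the recurrent tokens with `cat_all` enabled (the default), which multiply it by `nrepeats + 1`. The pure-Python sketch below replays those rules to predict each layer's output width without building any torch module. The helper `describe_config` is hypothetical, and the config string and input width are illustrative.

```python
def describe_config(config, nfeat):
    """Hypothetical helper: replays GraphNetwork's width bookkeeping for a
    config string and returns (layer_type, output_width) pairs."""
    layers = []
    for token in config.split(','):
        conf = token.strip().split('_')
        kind = conf[0]
        if kind == 'f':                                   # f_<out>: linear layer
            nfeat = int(conf[1])
        elif kind in ('gru', 'lstm'):                     # gru_<repeats>[_vv_ln_ig_catall]
            nrepeats = int(conf[1])
            cat_all = bool(int(conf[5])) if len(conf) > 5 else True
            if cat_all:                                   # input plus every hidden state
                nfeat *= nrepeats + 1
        elif kind not in ('b', 'r', 'd', 'crf', ''):      # these keep the width
            raise NotImplementedError('Unknown module: ' + kind)
        if kind:
            layers.append((kind, nfeat))
    return layers


print(describe_config('gru_10_0,f_13', nfeat=32))  # [('gru', 352), ('f', 13)]
```

For `gru_10_0,f_13` with 32 input features, the recurrent block concatenates the input and its 10 hidden states, producing 32 x 11 = 352 features. The final `f_13` layer receives this width.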
--------------------------------------------------------------------------------
/learning/metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from builtins import range
4 |
5 | import numpy as np
6 |
7 | # extended official code from http://www.semantic3d.net/scripts/metric.py
8 | class ConfusionMatrix:
9 | """Streaming interface to allow for any source of predictions. Initialize it, count predictions one by one, then print confusion matrix and intersection-union score"""
10 | def __init__(self, number_of_labels = 2):
11 | self.number_of_labels = number_of_labels
12 | self.confusion_matrix = np.zeros(shape=(self.number_of_labels,self.number_of_labels))
13 | def count_predicted(self, ground_truth, predicted, number_of_added_elements=1):
14 | self.confusion_matrix[ground_truth][predicted] += number_of_added_elements
15 |
16 | def count_predicted_batch(self, ground_truth_vec, predicted): # added
17 | for i in range(ground_truth_vec.shape[0]):
18 | self.confusion_matrix[:,predicted[i]] += ground_truth_vec[i,:]
19 |
20 | def count_predicted_batch_hard(self, ground_truth_vec, predicted): # added
21 | for i in range(ground_truth_vec.shape[0]):
22 | self.confusion_matrix[ground_truth_vec[i],predicted[i]] += 1
23 |
24 |     # labels are integers from 0 to number_of_labels-1
25 | def get_count(self, ground_truth, predicted):
26 | return self.confusion_matrix[ground_truth][predicted]
27 |     # returns the number_of_labels x number_of_labels numpy array; use it as result[ground_truth][predicted]
28 |     # to know how many samples of class ground_truth were reported as class predicted
29 | def get_confusion_matrix(self):
30 | return self.confusion_matrix
31 |     # returns a list of 64-bit floats (per-class IoU)
32 | def get_intersection_union_per_class(self):
33 | matrix_diagonal = [self.confusion_matrix[i][i] for i in range(self.number_of_labels)]
34 | errors_summed_by_row = [0] * self.number_of_labels
35 | for row in range(self.number_of_labels):
36 | for column in range(self.number_of_labels):
37 | if row != column:
38 | errors_summed_by_row[row] += self.confusion_matrix[row][column]
39 | errors_summed_by_column = [0] * self.number_of_labels
40 | for column in range(self.number_of_labels):
41 | for row in range(self.number_of_labels):
42 | if row != column:
43 | errors_summed_by_column[column] += self.confusion_matrix[row][column]
44 |
45 | divisor = [0] * self.number_of_labels
46 | for i in range(self.number_of_labels):
47 | divisor[i] = matrix_diagonal[i] + errors_summed_by_row[i] + errors_summed_by_column[i]
48 | if matrix_diagonal[i] == 0:
49 | divisor[i] = 1
50 |
51 | return [float(matrix_diagonal[i]) / divisor[i] for i in range(self.number_of_labels)]
52 |     # returns a 64-bit float
53 |
54 | def get_overall_accuracy(self):
55 | matrix_diagonal = 0
56 | all_values = 0
57 | for row in range(self.number_of_labels):
58 | for column in range(self.number_of_labels):
59 | all_values += self.confusion_matrix[row][column]
60 | if row == column:
61 | matrix_diagonal += self.confusion_matrix[row][column]
62 | if all_values == 0:
63 | all_values = 1
64 | return float(matrix_diagonal) / all_values
65 |
66 |
67 | def get_average_intersection_union(self):
68 | values = self.get_intersection_union_per_class()
69 | class_seen = ((self.confusion_matrix.sum(1)+self.confusion_matrix.sum(0))!=0).sum()
70 | return sum(values) / class_seen
71 |
72 | def get_mean_class_accuracy(self): # added
73 | re = 0
74 | for i in range(self.number_of_labels):
75 | re = re + self.confusion_matrix[i][i] / max(1,np.sum(self.confusion_matrix[i,:]))
76 | return re/self.number_of_labels
77 |
78 | def count_gt(self, ground_truth):
79 | return self.confusion_matrix[ground_truth,:].sum()
80 |
81 | def compute_predicted_transitions(in_component, edg_source, edg_target):
82 |
83 | pred_transitions = in_component[edg_source] != in_component[edg_target]
84 | return pred_transitions
85 |
86 | #-----------------------------------------------------------
87 | def compute_boundary_recall(is_transition, pred_transitions):
88 | return 100*((is_transition==pred_transitions)*is_transition).sum()/is_transition.sum()
89 |
90 | #-----------------------------------------------------------
91 | def compute_boundary_precision(is_transition, pred_transitions):
92 | return 100*((is_transition==pred_transitions)*pred_transitions).sum()/pred_transitions.sum()
93 | #--------------------------------------------
94 |
95 | def mode(array, only_freq=False):
96 | value, counts = np.unique(array, return_counts=True)
97 | if only_freq:
98 | return np.amax(counts)
99 | else:
100 | return value[np.argmax(counts)], np.amax(counts)
101 | #------------------------------------------------
102 | def compute_OOA(components, labels):
103 | hard_labels = labels.argmax(1)
104 | correct_labels=0
105 | for comp in components:
106 | dump, freq = mode(hard_labels[comp])
107 | correct_labels+=freq
108 | return 100*correct_labels/len(labels)
109 |
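The partition metrics at the end of this file score an over-segmentation against ground-truth labels on the adjacency graph. An edge is a true transition when its endpoints carry different labels, and a predicted transition when they fall in different components. `compute_OOA` labels each component with its majority label. The NumPy sketch below restates these definitions on a toy six-point chain. The data and helper names are illustrative. Unlike `compute_OOA`, which takes a 2D label array and applies `argmax(1)`, the sketch takes hard labels directly.

```python
import numpy as np


def predicted_transitions(in_component, edg_source, edg_target):
    # an edge is a predicted transition when its endpoints lie in different components
    return in_component[edg_source] != in_component[edg_target]


def boundary_recall(is_transition, pred):
    # percentage of ground-truth transitions that are also predicted
    return 100 * np.logical_and(is_transition, pred).sum() / is_transition.sum()


def boundary_precision(is_transition, pred):
    # percentage of predicted transitions that are ground-truth transitions
    return 100 * np.logical_and(is_transition, pred).sum() / pred.sum()


def oracle_overall_accuracy(components, hard_labels):
    # give every component its majority label and count the points it gets right
    correct = sum(np.unique(hard_labels[c], return_counts=True)[1].max()
                  for c in components)
    return 100 * correct / len(hard_labels)


# toy chain of six points with edges (0,1), (1,2), ..., (4,5)
labels = np.array([0, 0, 0, 1, 1, 1])
edg_source = np.arange(5)
edg_target = np.arange(1, 6)
is_transition = labels[edg_source] != labels[edg_target]

# partition that splits the second label region in two
in_component = np.array([0, 0, 0, 1, 1, 2])
components = [np.flatnonzero(in_component == k) for k in range(3)]
pred = predicted_transitions(in_component, edg_source, edg_target)

recall = boundary_recall(is_transition, pred)
precision = boundary_precision(is_transition, pred)
ooa = oracle_overall_accuracy(components, labels)
print(recall, precision, ooa)  # 100.0 50.0 100.0
```

Splitting the second label region into two components lowers precision to 50% but leaves recall and oracle accuracy at 100%. Over-segmentation costs little on recall and OOA and is penalized only through precision.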
--------------------------------------------------------------------------------
/learning/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 | http://arxiv.org/abs/1711.09869
4 | 2017 Loic Landrieu, Martin Simonovsky
5 | """
6 | from __future__ import division
7 | from __future__ import print_function
8 | from builtins import range
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as nnf
13 | from torch.autograd import Variable
14 | from learning import ecc
15 |
16 | HAS_PYG = False
17 | try:
18 |     from torch_geometric.nn.conv import MessagePassing
19 |     from torch_geometric.nn.inits import uniform
20 |     from torch_geometric.utils import softmax
21 |     HAS_PYG = True
22 | except ImportError:
23 |     pass
24 | if HAS_PYG:
25 | class NNConv(MessagePassing):
26 | r"""The continuous kernel-based convolutional operator from the
27 |         "Neural Message Passing for Quantum Chemistry"
28 |         paper.
29 |         This convolution is also known as the edge-conditioned convolution from the
30 |         `"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on
31 |         Graphs" <https://arxiv.org/abs/1704.02901>`_ paper (see
32 | :class:`torch_geometric.nn.conv.ECConv` for an alias):
33 |
34 | .. math::
35 | \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
36 | \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
37 | h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),
38 |
39 |         where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.*
40 | a MLP.
41 |
42 | Args:
43 | in_channels (int): Size of each input sample.
44 | out_channels (int): Size of each output sample.
45 |             vv (bool, optional): If :obj:`True`, the per-edge filter :obj:`weights`
46 |                 passed to :meth:`forward` (produced outside this layer by the
47 |                 filter network :math:`h_{\mathbf{\Theta}}`) have shape
48 |                 :obj:`[-1, in_channels]` and are applied element-wise; otherwise they
49 |                 have shape :obj:`[-1, in_channels * out_channels]`. (default: :obj:`True`)
50 |             aggr (string, optional): The aggregation scheme to use
51 |                 (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
52 |                 (default: :obj:`"mean"`)
53 |             root_weight (bool, optional): If set to :obj:`False`, the layer will
54 |                 not add the transformed root node features to the output.
55 |                 (default: :obj:`False`)
56 |             bias (bool, optional): If set to :obj:`False`, the layer will not learn
57 |                 an additive bias. (default: :obj:`False`)
58 | **kwargs (optional): Additional arguments of
59 | :class:`torch_geometric.nn.conv.MessagePassing`.
60 | """
61 | def __init__(self,
62 | in_channels,
63 | out_channels,
64 | aggr='mean',
65 | root_weight=False,
66 | bias=False,
67 | vv=True,
68 | flow="target_to_source",
69 | negative_slope=0.2,
70 | softmax=False,
71 | **kwargs):
72 | super(NNConv, self).__init__(aggr=aggr, **kwargs)
73 |
74 | self.in_channels = in_channels
75 | self.out_channels = out_channels
76 | self.aggr = aggr
77 | self.vv = vv
78 | self.negative_slope = negative_slope
79 | self.softmax = softmax
80 |
81 | if root_weight:
82 |                 self.root = nn.Parameter(torch.Tensor(in_channels, out_channels))
83 | else:
84 | self.register_parameter('root', None)
85 |
86 | if bias:
87 |                 self.bias = nn.Parameter(torch.Tensor(out_channels))
88 | else:
89 | self.register_parameter('bias', None)
90 |
91 | self.reset_parameters()
92 |
93 | def reset_parameters(self):
94 | uniform(self.in_channels, self.root)
95 | uniform(self.in_channels, self.bias)
96 |
97 | def forward(self, x, edge_index, weights):
98 | """"""
99 | x = x.unsqueeze(-1) if x.dim() == 1 else x
100 | return self.propagate(edge_index, x=x, weights=weights)
101 |
102 | def message(self, edge_index_i, x_j, size_i, weights):
103 | if not self.vv:
104 | weight = weights.view(-1, self.in_channels, self.out_channels)
105 | if self.softmax: # APPLY A TWO DIMENSIONAL NON-DEPENDENT SPARSE SOFTMAX
106 |                     weight = nnf.leaky_relu(weight, self.negative_slope)
107 | weight = torch.cat([softmax(weight[:, k, :], edge_index_i, size_i).unsqueeze(1) for k in range(self.out_channels)], dim=1)
108 | return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)
109 | else:
110 | weight = weights.view(-1, self.in_channels)
111 | if self.softmax:
112 |                     weight = nnf.leaky_relu(weight, self.negative_slope)
113 | weight = torch.cat([softmax(w.unsqueeze(-1), edge_index_i, size_i).t() for w in weight.t()], dim=0).t()
114 | return x_j * weight
115 |
116 | def update(self, aggr_out, x):
117 | if self.root is not None:
118 | aggr_out = aggr_out + torch.mm(x, self.root)
119 | if self.bias is not None:
120 | aggr_out = aggr_out + self.bias
121 | return aggr_out
122 |
123 | def __repr__(self):
124 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
125 | self.out_channels)
126 |
127 |
128 | class RNNGraphConvModule(nn.Module):
129 | """
130 | Computes recurrent graph convolution using filter weights obtained from a Filter generating network (`filter_net`).
131 | Its result is passed to RNN `cell` and the process is repeated over `nrepeats` iterations.
132 | Weight sharing over iterations is done both in RNN cell and in Filter generating network.
133 | """
134 | def __init__(self, cell, filter_net, nfeat, vv = True, gc_info=None, nrepeats=1, cat_all=False, edge_mem_limit=1e20, use_pyg = True, cuda = True):
135 | super(RNNGraphConvModule, self).__init__()
136 | self._cell = cell
137 | self._isLSTM = 'LSTM' in type(cell).__name__
138 | self._fnet = filter_net
139 | self._nrepeats = nrepeats
140 | self._cat_all = cat_all
141 | self._edge_mem_limit = edge_mem_limit
142 | self.set_info(gc_info)
143 | self.use_pyg = use_pyg
144 | if use_pyg:
145 | self.nn = NNConv(nfeat, nfeat, vv = vv)
146 | if cuda:
147 | self.nn = self.nn.cuda()
148 |
149 | def set_info(self, gc_info):
150 | self._gci = gc_info
151 |
152 | def forward(self, hx):
153 | # get graph structure information tensors
154 | idxn, idxe, degs, degs_gpu, edgefeats = self._gci.get_buffers()
155 |
156 | edge_indexes = self._gci.get_pyg_buffers()
157 | ###edgefeats = Variable(edgefeats, requires_grad=False)
158 |
159 |         # evaluate and reshape filter weights (shared among RNN iterations)
160 | weights = self._fnet(edgefeats)
161 | nc = hx.size(1)
162 | assert hx.dim()==2 and weights.dim()==2 and weights.size(1) in [nc, nc*nc]
163 | if weights.size(1) != nc:
164 | weights = weights.view(-1, nc, nc)
165 |
166 | # repeatedly evaluate RNN cell
167 | hxs = [hx]
168 | if self._isLSTM:
169 | cx = Variable(hx.data.new(hx.size()).fill_(0))
170 |
171 | for r in range(self._nrepeats):
172 | if self.use_pyg:
173 | input = self.nn(hx, edge_indexes, weights)
174 | else:
175 | input = ecc.GraphConvFunction.apply(hx, weights, nc, nc, idxn, idxe, degs, degs_gpu,
176 | self._edge_mem_limit)
177 | if self._isLSTM:
178 | hx, cx = self._cell(input, (hx, cx))
179 | else:
180 | hx = self._cell(input, hx)
181 | hxs.append(hx)
182 |
183 | return torch.cat(hxs,1) if self._cat_all else hx
184 |
185 | class ECC_CRFModule(nn.Module):
186 | """
187 | Adapted "Conditional Random Fields as Recurrent Neural Networks" (https://arxiv.org/abs/1502.03240)
188 | `propagation` should be ECC with Filter generating network producing 2D matrix.
189 | """
190 | def __init__(self, propagation, nrepeats=1):
191 | super(ECC_CRFModule, self).__init__()
192 | self._propagation = propagation
193 | self._nrepeats = nrepeats
194 |
195 | def forward(self, input):
196 |         Q = nnf.softmax(input, dim=1)
197 | for i in range(self._nrepeats):
198 | Q = self._propagation(Q) # todo: speedup possible by sharing computation of fnet
199 | Q = input - Q
200 | if i < self._nrepeats-1:
201 |                 Q = nnf.softmax(Q, dim=1) # last softmax will be part of cross-entropy loss
202 | return Q
203 |
204 |
205 | class GRUCellEx(nn.GRUCell):
206 | """ Usual GRU cell extended with layer normalization and input gate.
207 | """
208 | def __init__(self, input_size, hidden_size, bias=True, layernorm=True, ingate=True):
209 | super(GRUCellEx, self).__init__(input_size, hidden_size, bias)
210 | self._layernorm = layernorm
211 | self._ingate = ingate
212 | if layernorm:
213 | self.add_module('ini', nn.InstanceNorm1d(1, eps=1e-5, affine=False, track_running_stats=False))
214 | self.add_module('inh', nn.InstanceNorm1d(1, eps=1e-5, affine=False, track_running_stats=False))
215 | if ingate:
216 | self.add_module('ig', nn.Linear(hidden_size, input_size, bias=True))
217 |
218 | def _normalize(self, gi, gh):
219 | if self._layernorm: # layernorm on input&hidden, as in https://arxiv.org/abs/1607.06450 (Layer Normalization)
220 | gi = self._modules['ini'](gi.unsqueeze(1)).squeeze(1)
221 | gh = self._modules['inh'](gh.unsqueeze(1)).squeeze(1)
222 | return gi, gh
223 |
224 | def forward(self, input, hidden):
225 | if self._ingate:
226 | input = torch.sigmoid(self._modules['ig'](hidden)) * input
227 |
228 | # GRUCell in https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py extended with layer normalization
229 | if input.is_cuda and torch.__version__.split('.')[0]=='0':
230 | gi = nnf.linear(input, self.weight_ih)
231 | gh = nnf.linear(hidden, self.weight_hh)
232 | gi, gh = self._normalize(gi, gh)
233 | state = torch.nn._functions.thnn.rnnFusedPointwise.GRUFused
234 | try: #pytorch >=0.3
235 | return state.apply(gi, gh, hidden) if self.bias_ih is None else state.apply(gi, gh, hidden, self.bias_ih, self.bias_hh)
236 | except: #pytorch <=0.2
237 | return state()(gi, gh, hidden) if self.bias_ih is None else state()(gi, gh, hidden, self.bias_ih, self.bias_hh)
238 |
239 | gi = nnf.linear(input, self.weight_ih)
240 | gh = nnf.linear(hidden, self.weight_hh)
241 | gi, gh = self._normalize(gi, gh)
242 | i_r, i_i, i_n = gi.chunk(3, 1)
243 | h_r, h_i, h_n = gh.chunk(3, 1)
244 | bih_r, bih_i, bih_n = self.bias_ih.chunk(3)
245 | bhh_r, bhh_i, bhh_n = self.bias_hh.chunk(3)
246 |
247 | resetgate = torch.sigmoid(i_r + bih_r + h_r + bhh_r)
248 | inputgate = torch.sigmoid(i_i + bih_i + h_i + bhh_i)
249 | newgate = torch.tanh(i_n + bih_n + resetgate * (h_n + bhh_n))
250 | hy = newgate + inputgate * (hidden - newgate)
251 | return hy
252 |
253 | def __repr__(self):
254 | s = super(GRUCellEx, self).__repr__() + '('
255 | if self._ingate:
256 | s += 'ingate'
257 | if self._layernorm:
258 | s += ' layernorm'
259 | return s + ')'
260 |
261 |
262 | class LSTMCellEx(nn.LSTMCell):
263 | """ Usual LSTM cell extended with layer normalization and input gate.
264 | """
265 | def __init__(self, input_size, hidden_size, bias=True, layernorm=True, ingate=True):
266 | super(LSTMCellEx, self).__init__(input_size, hidden_size, bias)
267 | self._layernorm = layernorm
268 | self._ingate = ingate
269 | if layernorm:
270 | self.add_module('ini', nn.InstanceNorm1d(1, eps=1e-5, affine=False, track_running_stats=False))
271 | self.add_module('inh', nn.InstanceNorm1d(1, eps=1e-5, affine=False, track_running_stats=False))
272 | if ingate:
273 | self.add_module('ig', nn.Linear(hidden_size, input_size, bias=True))
274 |
275 | def _normalize(self, gi, gh):
276 | if self._layernorm: # layernorm on input&hidden, as in https://arxiv.org/abs/1607.06450 (Layer Normalization)
277 | gi = self._modules['ini'](gi.unsqueeze(1)).squeeze(1)
278 | gh = self._modules['inh'](gh.unsqueeze(1)).squeeze(1)
279 | return gi, gh
280 |
281 | def forward(self, input, hidden):
282 | if self._ingate:
283 | input = torch.sigmoid(self._modules['ig'](hidden[0])) * input
284 |
285 |         # LSTMCell in https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py extended with layer normalization
286 | if input.is_cuda and torch.__version__.split('.')[0]=='0':
287 | gi = nnf.linear(input, self.weight_ih)
288 | gh = nnf.linear(hidden[0], self.weight_hh)
289 | gi, gh = self._normalize(gi, gh)
290 | state = torch.nn._functions.thnn.rnnFusedPointwise.LSTMFused
291 | try: #pytorch >=0.3
292 | return state.apply(gi, gh, hidden[1]) if self.bias_ih is None else state.apply(gi, gh, hidden[1], self.bias_ih, self.bias_hh)
293 | except: #pytorch <=0.2
294 | return state()(gi, gh, hidden[1]) if self.bias_ih is None else state()(gi, gh, hidden[1], self.bias_ih, self.bias_hh)
295 |
296 | gi = nnf.linear(input, self.weight_ih, self.bias_ih)
297 | gh = nnf.linear(hidden[0], self.weight_hh, self.bias_hh)
298 | gi, gh = self._normalize(gi, gh)
299 |
300 | ingate, forgetgate, cellgate, outgate = (gi+gh).chunk(4, 1)
301 | ingate = torch.sigmoid(ingate)
302 | forgetgate = torch.sigmoid(forgetgate)
303 | cellgate = torch.tanh(cellgate)
304 | outgate = torch.sigmoid(outgate)
305 |
306 | cy = (forgetgate * hidden[1]) + (ingate * cellgate)
307 | hy = outgate * torch.tanh(cy)
308 | return hy, cy
309 |
310 | def __repr__(self):
311 | s = super(LSTMCellEx, self).__repr__() + '('
312 | if self._ingate:
313 | s += 'ingate'
314 | if self._layernorm:
315 | s += ' layernorm'
316 | return s + ')'
317 |
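`NNConv.message` handles a neighbour feature in one of two ways. With `vv=True` it scales the feature element-wise by a filter of width `in_channels`; otherwise it multiplies the feature by a full `in_channels x out_channels` matrix. The layer then averages messages per node (`aggr='mean'`). With the defaults `root_weight=False` and `bias=False`, the output is just these averaged messages. The NumPy sketch below reproduces both message types and the mean aggregation without PyTorch Geometric. It rests on three assumptions: square filters (`in_channels == out_channels`, as used by `RNNGraphConvModule`), messages that travel from `edge_index[0]` to `edge_index[1]` (PyG's default source-to-target convention), and zeros for nodes without incoming edges. The function name is hypothetical.

```python
import numpy as np


def ecc_mean_aggregate(x, edge_index, weights, vv=True):
    """Hypothetical NumPy sketch of edge-conditioned mean aggregation.

    x:          (N, C) node features
    edge_index: (2, E) integer array; row 0 = source, row 1 = target (assumed)
    weights:    (E, C) filters if vv else (E, C*C) filters
    """
    n, c = x.shape
    src, dst = edge_index
    x_j = x[src]                                  # features of the sending nodes
    if vv:
        msg = x_j * weights                       # element-wise filter per edge
    else:
        w = weights.reshape(-1, c, c)             # one C x C matrix per edge
        msg = np.einsum('ec,ecd->ed', x_j, w)     # row vector times matrix
    out = np.zeros((n, c))
    np.add.at(out, dst, msg)                      # sum messages per target node
    deg = np.bincount(dst, minlength=n)           # number of incoming edges
    return out / np.maximum(deg, 1)[:, None]      # mean; isolated nodes stay 0


x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
edge_index = np.array([[0, 1], [2, 2]])           # edges 0->2 and 1->2
vv_out = ecc_mean_aggregate(x, edge_index, np.array([[1.0, 1.0], [0.5, 0.5]]))
identity = np.tile(np.eye(2).reshape(1, 4), (2, 1))
mv_out = ecc_mean_aggregate(x, edge_index, identity, vv=False)
print(vv_out[2], mv_out[2])  # [1.25 2.  ] [2. 3.]
```

The element-wise variant needs C filter values per edge instead of C squared. This keeps the filter network's output layer much smaller for wide hidden states.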
--------------------------------------------------------------------------------
/learning/pointnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 | http://arxiv.org/abs/1711.09869
4 | 2017 Loic Landrieu, Martin Simonovsky
5 | """
6 | from __future__ import division
7 | from __future__ import print_function
8 | from builtins import range
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as nnf
14 | from torch.autograd import Variable
15 |
16 | class STNkD(nn.Module):
17 | """
18 | Spatial Transformer Net for PointNet, producing a KxK transformation matrix.
19 | Parameters:
20 | nfeat: number of input features
21 | nf_conv: list of layer widths of point embeddings (before maxpool)
22 | nf_fc: list of layer widths of joint embeddings (after maxpool)
23 | """
24 | def __init__(self, nfeat, nf_conv, nf_fc, K=2, norm = 'batch', affine = True, n_group = 1):
25 | super(STNkD, self).__init__()
26 |
27 | modules = []
28 | for i in range(len(nf_conv)):
29 | modules.append(nn.Conv1d(nf_conv[i-1] if i>0 else nfeat, nf_conv[i], 1))
30 | if norm == 'batch':
31 | modules.append(nn.BatchNorm1d(nf_conv[i]))
32 | elif norm == 'layer':
33 | modules.append(nn.GroupNorm(1,nf_conv[i]))
34 | elif norm == 'group':
35 | modules.append(nn.GroupNorm(n_group,nf_conv[i]))
36 | modules.append(nn.ReLU(True))
37 | self.convs = nn.Sequential(*modules)
38 |
39 | modules = []
40 | for i in range(len(nf_fc)):
41 | modules.append(nn.Linear(nf_fc[i-1] if i>0 else nf_conv[-1], nf_fc[i]))
42 | if norm == 'batch':
43 | modules.append(nn.BatchNorm1d(nf_fc[i]))
44 | elif norm == 'layer':
45 | modules.append(nn.GroupNorm(1,nf_fc[i]))
46 | elif norm == 'group':
47 | modules.append(nn.GroupNorm(n_group,nf_fc[i]))
48 | modules.append(nn.ReLU(True))
49 | self.fcs = nn.Sequential(*modules)
50 |
51 | self.proj = nn.Linear(nf_fc[-1], K*K)
52 | nn.init.constant_(self.proj.weight, 0); nn.init.constant_(self.proj.bias, 0)
53 | self.eye = torch.eye(K).unsqueeze(0)
54 |
55 | def forward(self, input):
56 | self.eye = self.eye.cuda() if input.is_cuda else self.eye
57 | input = self.convs(input)
58 | input = nnf.max_pool1d(input, input.size(2)).squeeze(2)
59 | input = self.fcs(input)
60 | input = self.proj(input)
61 | return input.view(-1,self.eye.size(1),self.eye.size(2)) + Variable(self.eye)
62 |
63 | class PointNet(nn.Module):
64 | """
65 | PointNet with only one spatial transformer and additional "global" input concatenated after maxpool.
66 | Parameters:
67 | nf_conv: list of layer widths of point embeddings (before maxpool)
68 | nf_fc: list of layer widths of joint embeddings (after maxpool)
69 | nfeat: number of input features
70 | nf_conv_stn, nf_fc_stn, nfeat_stn: as above but for Spatial transformer
71 | nfeat_global: number of features concatenated after maxpooling
72 |     prelast_do: dropout after the pre-last parametric layer
73 |     last_ac: whether to apply normalization and ReLU after the last parametric layer
74 | """
75 | def __init__(self, nf_conv, nf_fc, nf_conv_stn, nf_fc_stn, nfeat, nfeat_stn=2, nfeat_global=1, prelast_do=0.5, last_ac=False, is_res=False, norm = 'batch', affine = True, n_group = 1, last_bn = False):
76 |
77 | super(PointNet, self).__init__()
78 | torch.manual_seed(0)
79 | if nfeat_stn > 0:
80 | self.stn = STNkD(nfeat_stn, nf_conv_stn, nf_fc_stn, norm=norm, n_group = n_group)
81 | self.nfeat_stn = nfeat_stn
82 |
83 | modules = []
84 | for i in range(len(nf_conv)):
85 | modules.append(nn.Conv1d(nf_conv[i-1] if i>0 else nfeat, nf_conv[i], 1))
86 | if norm == 'batch':
87 | modules.append(nn.BatchNorm1d(nf_conv[i]))
88 | elif norm == 'layer':
89 | modules.append(nn.GroupNorm(1, nf_conv[i]))
90 | elif norm == 'group':
91 | modules.append(nn.GroupNorm(n_group, nf_conv[i]))
92 | modules.append(nn.ReLU(True))
93 |
94 | # Initialization of BN parameters.
95 |
96 | self.convs = nn.Sequential(*modules)
97 |
98 | modules = []
99 | for i in range(len(nf_fc)):
100 | modules.append(nn.Linear(nf_fc[i-1] if i>0 else nf_conv[-1]+nfeat_global, nf_fc[i]))
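101 |             if i<len(nf_fc)-1 or last_ac:
102 |                 if norm == 'batch':
103 |                     modules.append(nn.BatchNorm1d(nf_fc[i]))
104 |                 elif norm == 'layer':
105 |                     modules.append(nn.GroupNorm(1, nf_fc[i]))
106 |                 elif norm == 'group':
107 |                     modules.append(nn.GroupNorm(n_group, nf_fc[i]))
108 |                 modules.append(nn.ReLU(True))
109 |             if i==len(nf_fc)-2 and prelast_do>0: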
110 | modules.append(nn.Dropout(prelast_do))
111 | if is_res: #init with small number so that at first the residual pointnet is close to zero
112 | nn.init.normal_(modules[-1].weight, mean=0, std = 1e-2)
113 | nn.init.normal_(modules[-1].bias, mean=0, std = 1e-2)
114 |
115 | #if last_bn:
116 | #modules.append(nn.BatchNorm1d(nf_fc[-1]))
117 |
118 | self.fcs = nn.Sequential(*modules)
119 |
120 | def forward(self, input, input_global):
121 | if self.nfeat_stn > 0:
122 | T = self.stn(input[:,:self.nfeat_stn,:])
123 | xy_transf = torch.bmm(input[:,:2,:].transpose(1,2), T).transpose(1,2)
124 | input = torch.cat([xy_transf, input[:,2:,:]], 1)
125 |
126 | input = self.convs(input)
127 | input = nnf.max_pool1d(input, input.size(2)).squeeze(2)
128 | if input_global is not None:
129 | if len(input_global.shape)== 1 or input_global.shape[1]==1:
130 | input = torch.cat([input, input_global.view(-1,1)], 1)
131 | else:
132 | input = torch.cat([input, input_global], 1)
133 | return self.fcs(input)
134 |
135 |
136 |
137 |
138 | class CloudEmbedder():
139 | """ Evaluates PointNet on superpoints. Too small superpoints are assigned zero embeddings. Can optionally apply memory mongering
140 | (https://arxiv.org/pdf/1604.06174.pdf) to decrease memory usage.
141 | """
142 | def __init__(self, args):
143 | self.args = args
144 | self.bw_hook = lambda: None # could be more elegant in the upcoming pytorch release: http://bit.ly/2A8PI7p
145 | self.run = self.run_full_monger if args.ptn_mem_monger else self.run_full
146 |
147 | def run_full(self, model, clouds_meta, clouds_flag, clouds, clouds_global):
148 | """ Simply evaluates all clouds in a differentiable way, assumes that all pointnet's feature maps fit into mem."""
149 | idx_valid = torch.nonzero(clouds_flag.eq(0)).squeeze()
150 | if self.args.cuda:
151 | clouds, clouds_global, idx_valid = clouds.cuda(), clouds_global.cuda(), idx_valid.cuda()
152 | clouds, clouds_global = Variable(clouds, volatile=not model.training), Variable(clouds_global, volatile=not model.training)
153 | #print('Ptn with', clouds.size(0), 'clouds')
154 |
155 | out = model.ptn(clouds, clouds_global)
156 | descriptors = Variable(out.data.new(clouds_flag.size(0), out.size(1)).fill_(0))
157 | descriptors.index_copy_(0, Variable(idx_valid), out)
158 | return descriptors
159 |
160 | def run_full_monger(self, model, clouds_meta, clouds_flag, clouds, clouds_global):
161 | """ Evaluates all clouds in forward pass, but uses memory mongering to compute backward pass."""
162 | idx_valid = torch.nonzero(clouds_flag.eq(0)).squeeze()
163 | if self.args.cuda:
164 | clouds, clouds_global, idx_valid = clouds.cuda(), clouds_global.cuda(), idx_valid.cuda()
165 | #print('Ptn with', clouds.size(0), 'clouds')
166 | with torch.no_grad():
167 | out = model.ptn(Variable(clouds), (clouds_global))
168 | if not model.training:
169 | out = Variable(out.data, requires_grad=model.training) # cut autograd
170 | if model.training:
171 | out = Variable(out.data, requires_grad=model.training)
172 | def bw_hook():
173 | out_v2 = model.ptn(Variable(clouds), Variable(clouds_global)) # re-run fw pass
174 | out_v2.backward(out.grad)
175 |
176 | self.bw_hook = bw_hook
177 |
178 | descriptors = Variable(out.data.new(clouds_flag.size(0), out.size(1)).fill_(0))
179 | descriptors.index_copy_(0, Variable(idx_valid), out)
180 | return descriptors
181 |
182 | class LocalCloudEmbedder():
183 | """ Local PointNet
184 | """
185 | def __init__(self, args):
186 | self.nfeat_stn = args.ptn_nfeat_stn
187 | self.stn_as_global = args.stn_as_global
188 |
189 | def run_batch(self, model, clouds, clouds_global, *excess):
190 | """ Evaluates all clouds in a differentiable way, in batches.
191 | Use when embedding many small point clouds with small PointNets at once."""
192 | #cuDNN cannot handle arrays larger than 2**16 in one go, so the clouds are processed in batches
193 | batch_size = 2**16-1
194 | n_batches = int((clouds.shape[0]-1)/batch_size)
195 | if self.nfeat_stn > 0:
196 | T = model.stn(clouds[:batch_size,:self.nfeat_stn,:])
197 | for i in range(1,n_batches+1):
198 | T = torch.cat((T,model.stn(clouds[i * batch_size:(i+1) * batch_size,:self.nfeat_stn,:])))
199 | xy_transf = torch.bmm(clouds[:,:2,:].transpose(1,2), T).transpose(1,2)
200 | clouds = torch.cat([xy_transf, clouds[:,2:,:]], 1)
201 | if self.stn_as_global:
202 | clouds_global = torch.cat([clouds_global, T.view(-1,4)], 1)
203 |
204 | out = model.ptn(clouds[:batch_size,:,:], clouds_global[:batch_size,:])
205 | for i in range(1,n_batches+1):
206 | out = torch.cat((out,model.ptn(clouds[i * batch_size:(i+1) * batch_size,:,:], clouds_global[i * batch_size:(i+1) * batch_size,:])))
207 | return nnf.normalize(out)
208 |
209 | def run_batch_cpu(self, model, clouds, clouds_global, *excess):
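210 | """ Evaluates the clouds in batches with run_batch and moves each batch's embeddings to the CPU as soon as they are computed"""
211 | #evaluate 2**10-1 clouds at a time and move their embeddings to the CPU immediately
212 | batch_size = 2**10-1
213 | n_batches = int((clouds.shape[0]-1)/batch_size) # as in run_batch: no empty trailing batch when the size is a multiple of batch_size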
214 | emb_total = self.run_batch(model, clouds[:batch_size,:,:], clouds_global[:batch_size,:]).cpu()
215 | for i in range(1,n_batches+1):
216 | emb = self.run_batch(model, clouds[i * batch_size:(i+1) * batch_size,:,:], clouds_global[i * batch_size:(i+1) * batch_size,:])
217 | emb_total = torch.cat((emb_total,emb.cpu()))
218 | return emb_total
219 |
220 |
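Both `run_batch` and `run_batch_cpu` split their input along the first dimension into consecutive slices of at most `batch_size` clouds, evaluate each slice, and concatenate the results. The framework-free helper below is a hypothetical sketch of that slicing pattern; the name `run_in_chunks` is not part of this repository. Like `run_batch`, it uses `(n - 1) // batch_size` as the index of the last slice, so an input whose length is an exact multiple of `batch_size` does not produce an empty trailing slice.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def run_in_chunks(fn: Callable[[Sequence[T]], List[R]], items: Sequence[T], batch_size: int) -> List[R]:
    """Apply `fn` to consecutive slices of at most `batch_size` items and concatenate the results.

    Hypothetical helper illustrating the slicing done by LocalCloudEmbedder.run_batch.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    if len(items) == 0:
        return []
    # Index of the last slice: (n - 1) // batch_size, so that an input whose length
    # is an exact multiple of batch_size does not produce an empty trailing slice.
    n_batches = (len(items) - 1) // batch_size
    results: List[R] = []
    for i in range(n_batches + 1):
        results.extend(fn(items[i * batch_size:(i + 1) * batch_size]))
    return results


if __name__ == "__main__":
    sizes = []

    def identity(chunk):
        sizes.append(len(chunk))  # record how many items each call receives
        return list(chunk)

    run_in_chunks(identity, list(range(9)), 4)
    print(sizes)  # [4, 4, 1]
```

With `batch_size=4`, 8 items are processed as two slices of 4. 9 items give slices of 4, 4 and 1.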
--------------------------------------------------------------------------------
/learning/s3dis_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 | http://arxiv.org/abs/1711.09869
4 | 2017 Loic Landrieu, Martin Simonovsky
5 | """
6 | from __future__ import division
7 | from __future__ import print_function
8 | from builtins import range
9 |
10 | import sys
11 | sys.path.append("./learning")
12 |
13 | import random
14 | import numpy as np
15 | import os
16 | import functools
17 | import torch
18 | import torchnet as tnt
19 | import h5py
20 | import spg
21 | from sklearn.linear_model import RANSACRegressor
22 |
23 | def get_datasets(args, test_seed_offset=0):
24 | """ Gets training and test datasets. """
25 |
26 | # Load superpoints graphs
27 | testlist, trainlist, validlist = [], [], []
28 | valid_names = ['hallway_1.h5', 'hallway_6.h5', 'hallway_11.h5', 'office_1.h5' \
29 | , 'office_6.h5', 'office_11.h5', 'office_16.h5', 'office_21.h5', 'office_26.h5' \
30 | , 'office_31.h5', 'office_36.h5'\
31 | ,'WC_2.h5', 'storage_1.h5', 'storage_5.h5', 'conferenceRoom_2.h5', 'auditorium_1.h5']
32 |
33 | #if args.db_test_name == 'test' then the test set is the evaluation set
34 | #otherwise it serves as validation set to select the best epoch
35 |
36 | for n in range(1,7):
37 | if n != args.cvfold:
38 | path = '{}/superpoint_graphs/Area_{:d}/'.format(args.S3DIS_PATH, n)
39 | for fname in sorted(os.listdir(path)):
40 | if fname.endswith(".h5") and not (args.use_val_set and fname in valid_names):
41 | #training set
42 | trainlist.append(spg.spg_reader(args, path + fname, True))
43 | if fname.endswith(".h5") and (args.use_val_set and fname in valid_names):
44 | #validation set
45 | validlist.append(spg.spg_reader(args, path + fname, True))
46 | path = '{}/superpoint_graphs/Area_{:d}/'.format(args.S3DIS_PATH, args.cvfold)
47 |
48 | #evaluation set
49 | for fname in sorted(os.listdir(path)):
50 | if fname.endswith(".h5"):
51 | testlist.append(spg.spg_reader(args, path + fname, True))
52 |
53 | scaler = None # Normalize edge features (scaler stays None if spg_attribs01 is off)
54 | if args.spg_attribs01:
55 | trainlist, testlist, validlist, scaler = spg.scaler01(trainlist, testlist, validlist=validlist)
56 |
57 | return tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in trainlist],
58 | functools.partial(spg.loader, train=True, args=args, db_path=args.S3DIS_PATH)), \
59 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in testlist],
60 | functools.partial(spg.loader, train=False, args=args, db_path=args.S3DIS_PATH, test_seed_offset=test_seed_offset)), \
61 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in validlist],
62 | functools.partial(spg.loader, train=False, args=args, db_path=args.S3DIS_PATH, test_seed_offset=test_seed_offset)), \
63 | scaler
64 |
65 |
66 | def get_info(args):
67 | edge_feats = 0
68 | for attrib in args.edge_attribs.split(','):
69 | a = attrib.split('/')[0]
70 | if a in ['delta_avg', 'delta_std', 'xyz']:
71 | edge_feats += 3
72 | else:
73 | edge_feats += 1
74 | if args.loss_weights == 'none':
75 | weights = np.ones((13,),dtype='f4')
76 | else:
77 | weights = h5py.File(args.S3DIS_PATH + "/parsed/class_count.h5", 'r')["class_count"][:].astype('f4')
78 | weights = weights[:,[i for i in range(6) if i != args.cvfold-1]].sum(1)
79 | weights = weights.mean()/weights
80 | if args.loss_weights == 'sqrt':
81 | weights = np.sqrt(weights)
82 | weights = torch.from_numpy(weights).cuda() if args.cuda else torch.from_numpy(weights)
83 | return {
84 | 'node_feats': 14 if args.pc_attribs=='' else len(args.pc_attribs),
85 | 'edge_feats': edge_feats,
86 | 'class_weights': weights,
87 | 'classes': 13,
88 | 'inv_class_map': {0:'ceiling', 1:'floor', 2:'wall', 3:'column', 4:'beam', 5:'window', 6:'door', 7:'table', 8:'chair', 9:'bookcase', 10:'sofa', 11:'board', 12:'clutter'},
89 | }
90 |
91 |
92 |
93 | def preprocess_pointclouds(args):
94 | """ Preprocesses data by splitting them by components and normalizing."""
95 | S3DIS_PATH = args.S3DIS_PATH
96 | class_count = np.zeros((13,6),dtype='int')
97 | for n in range(1,7):
98 | pathP = '{}/parsed/Area_{:d}/'.format(S3DIS_PATH, n)
99 | if args.supervized_partition:
100 | pathD = '{}/features_supervision/Area_{:d}/'.format(S3DIS_PATH, n)
101 | else:
102 | pathD = '{}/features/Area_{:d}/'.format(S3DIS_PATH, n)
103 | pathC = '{}/superpoint_graphs/Area_{:d}/'.format(S3DIS_PATH, n)
104 | if not os.path.exists(pathP):
105 | os.makedirs(pathP)
106 | random.seed(n)
107 |
108 | for file in os.listdir(pathC):
109 | print(file)
110 | if file.endswith(".h5"):
111 | f = h5py.File(pathD + file, 'r')
112 | xyz = f['xyz'][:]
113 | rgb = f['rgb'][:].astype(np.float64)
114 |
115 | labels = f['labels'][:]
116 | hard_labels = np.argmax(labels[:,1:],1)
117 | label_count = np.bincount(hard_labels, minlength=13)
118 | class_count[:,n-1] = class_count[:,n-1] + label_count
119 |
120 | if not args.supervized_partition:
121 | lpsv = f['geof'][:]
122 | lpsv -= 0.5 #normalize
123 | else:
124 | lpsv = np.stack([f["geof"][:] ]).squeeze()
125 | # rescale to [-0.5,0.5]; keep xyz
126 |
127 | if args.plane_model_elevation:
128 | if args.supervized_partition: #already computed
129 | e = f['elevation'][:]
130 | else: #simple plane model
131 | low_points = ((xyz[:,2]-xyz[:,2].min() < 0.5)).nonzero()[0]
132 | reg = RANSACRegressor(random_state=0).fit(xyz[low_points,:2], xyz[low_points,2])
133 | e = xyz[:,2]-reg.predict(xyz[:,:2])
134 | else: #no plane model: rescale raw z as a rough elevation
135 | e = xyz[:,2] / 4 - 0.5 # (4m rough guess)
136 |
137 | rgb = rgb/255.0 - 0.5
138 |
139 | room_center = xyz[:,[0,1]].mean(0) #compute distance to room center, useful to detect walls and doors
140 | distance_to_center = np.sqrt(((xyz[:,[0,1]]-room_center)**2).sum(1))
141 | distance_to_center = (distance_to_center - distance_to_center.mean())/distance_to_center.std()
142 |
143 | ma, mi = np.max(xyz,axis=0,keepdims=True), np.min(xyz,axis=0,keepdims=True)
144 | xyzn = (xyz - mi) / (ma - mi + 1e-8) # as in PointNet ("normalized location as to the room (from 0 to 1)")
145 |
146 | P = np.concatenate([xyz, rgb, e[:,np.newaxis], lpsv, xyzn, distance_to_center[:,None]], axis=1)
147 |
148 | f = h5py.File(pathC + file, 'r')
149 | numc = len(f['components'].keys())
150 |
151 | with h5py.File(pathP + file, 'w') as hf:
152 | hf.create_dataset(name='centroid',data=xyz.mean(0))
153 | for c in range(numc):
154 | idx = f['components/{:d}'.format(c)][:].flatten()
155 | if idx.size > 10000: # trim extra large segments, just for speed-up of loading time
156 | ii = random.sample(range(idx.size), k=10000)
157 | idx = idx[ii]
158 | hf.create_dataset(name='{:d}'.format(c), data=P[idx,...])
159 |
160 | path = '{}/parsed/'.format(S3DIS_PATH)
161 | data_file = h5py.File(path+'class_count.h5', 'w')
162 | data_file.create_dataset('class_count', data=class_count, dtype='int')
163 |
164 | if __name__ == "__main__":
165 | import argparse
166 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
167 | parser.add_argument('--S3DIS_PATH', default='datasets/s3dis')
168 | parser.add_argument('--supervized_partition', type=int, default=0)
169 | parser.add_argument('--plane_model_elevation', type=int, default=0, help='compute elevation with a simple RANSAC based plane model')
170 | args = parser.parse_args()
171 | preprocess_pointclouds(args)
172 |
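`get_info` turns the per-area point counts written by `preprocess_pointclouds` (`class_count`, 13 classes × 6 areas) into loss weights. It drops the held-out area `cvfold`, sums the remaining counts per class, and sets each weight to the mean count divided by the class count, optionally followed by a square root. The NumPy sketch below reproduces that arithmetic on an in-memory array instead of reading `parsed/class_count.h5`. The function name, the default value `"inverse"` and the generalization to any number of areas are assumptions made for illustration. Here `mode` plays the role of `args.loss_weights`.

```python
import numpy as np


def inverse_frequency_weights(class_count, cvfold, mode="inverse"):
    """Per-class loss weights from a (n_classes, n_areas) table of point counts.

    Hypothetical helper mirroring the arithmetic of get_info:
    - mode="none": all weights are 1;
    - otherwise the held-out area `cvfold` (1-based) is dropped, the remaining
      counts are summed per class and weight = mean(count) / count;
    - mode="sqrt" additionally takes the square root of those weights.
    """
    n_classes, n_areas = class_count.shape
    if mode == "none":
        return np.ones((n_classes,), dtype="f4")
    train_areas = [i for i in range(n_areas) if i != cvfold - 1]
    counts = class_count[:, train_areas].sum(1).astype("f4")
    weights = counts.mean() / counts
    if mode == "sqrt":
        weights = np.sqrt(weights)
    return weights


if __name__ == "__main__":
    counts = np.array([[10, 30], [20, 20]])  # 2 classes x 2 areas
    print(inverse_frequency_weights(counts, cvfold=2))  # only area 1 is used
```

As in `get_info`, a class with zero training points would receive an infinite weight here. The vKITTI variant avoids this by adding 1 to every count before computing the ratio.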
--------------------------------------------------------------------------------
/learning/sema3d_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 | http://arxiv.org/abs/1711.09869
4 | 2017 Loic Landrieu, Martin Simonovsky
5 | """
6 | from __future__ import division
7 | from __future__ import print_function
8 | from builtins import range
9 |
10 | import random
11 | import numpy as np
12 | import os
13 | import functools
14 | import torch
15 | import torchnet as tnt
16 | import h5py
17 | import spg
18 |
19 |
20 | def get_datasets(args, test_seed_offset=0):
21 |
22 | train_names = ['bildstein_station1', 'bildstein_station5', 'domfountain_station1', 'domfountain_station3', 'neugasse_station1', 'sg27_station1', 'sg27_station2', 'sg27_station5', 'sg27_station9', 'sg28_station4', 'untermaederbrunnen_station1']
23 | valid_names = ['bildstein_station3', 'domfountain_station2', 'sg27_station4', 'untermaederbrunnen_station3']
24 |
25 | if args.db_train_name == 'train':
26 | trainset = ['train/' + f for f in train_names]
27 | elif args.db_train_name == 'trainval':
28 | trainset = ['train/' + f for f in train_names + valid_names]
29 |
30 | validset = []
31 | testset = []
32 | if args.use_val_set:
33 | validset = ['train/' + f for f in valid_names]
34 | if args.db_test_name == 'testred':
35 | testset = ['test_reduced/' + os.path.splitext(f)[0] for f in os.listdir(args.SEMA3D_PATH + '/superpoint_graphs/test_reduced')]
36 | elif args.db_test_name == 'testfull':
37 | testset = ['test_full/' + os.path.splitext(f)[0] for f in os.listdir(args.SEMA3D_PATH + '/superpoint_graphs/test_full')]
38 |
39 | # Load superpoints graphs
40 | testlist, trainlist, validlist = [], [], []
41 | for n in trainset:
42 | trainlist.append(spg.spg_reader(args, args.SEMA3D_PATH + '/superpoint_graphs/' + n + '.h5', True))
43 | for n in validset:
44 | validlist.append(spg.spg_reader(args, args.SEMA3D_PATH + '/superpoint_graphs/' + n + '.h5', True))
45 | for n in testset:
46 | testlist.append(spg.spg_reader(args, args.SEMA3D_PATH + '/superpoint_graphs/' + n + '.h5', True))
47 |
48 | scaler = None # Normalize edge features (scaler stays None if spg_attribs01 is off)
49 | if args.spg_attribs01:
50 | trainlist, testlist, validlist, scaler = spg.scaler01(trainlist, testlist, validlist=validlist)
51 |
52 | return tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in trainlist],
53 | functools.partial(spg.loader, train=True, args=args, db_path=args.SEMA3D_PATH)), \
54 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in testlist],
55 | functools.partial(spg.loader, train=False, args=args, db_path=args.SEMA3D_PATH, test_seed_offset=test_seed_offset)), \
56 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in validlist],
57 | functools.partial(spg.loader, train=False, args=args, db_path=args.SEMA3D_PATH, test_seed_offset=test_seed_offset)),\
58 | scaler
59 |
60 |
61 | def get_info(args):
62 | edge_feats = 0
63 | for attrib in args.edge_attribs.split(','):
64 | a = attrib.split('/')[0]
65 | if a in ['delta_avg', 'delta_std', 'xyz']:
66 | edge_feats += 3
67 | else:
68 | edge_feats += 1
69 | if args.loss_weights == 'none':
70 | weights = np.ones((8,),dtype='f4')
71 | else:
72 | weights = h5py.File(args.SEMA3D_PATH + "/parsed/class_count.h5", 'r')["class_count"][:].astype('f4')
73 | weights = weights.mean()/weights
74 | if args.loss_weights == 'sqrt':
75 | weights = np.sqrt(weights)
76 | weights = torch.from_numpy(weights).cuda() if args.cuda else torch.from_numpy(weights)
77 | return {
78 | 'node_feats': 14 if args.pc_attribs=='' else len(args.pc_attribs),
79 | 'edge_feats': edge_feats,
80 | 'class_weights': weights,
81 | 'classes': 8,
82 | 'inv_class_map': {0:'terrain_man', 1:'terrain_nature', 2:'veget_hi', 3:'veget_low', 4:'building', 5:'scape', 6:'artefact', 7:'cars'},
83 | }
84 |
85 | def preprocess_pointclouds(SEMA3D_PATH):
86 | """ Preprocesses data by splitting them by components and normalizing."""
87 | class_count = np.zeros((8,),dtype='int')
88 | for n in ['train', 'test_reduced', 'test_full']:
89 | pathP = '{}/parsed/{}/'.format(SEMA3D_PATH, n)
90 | if args.supervised_partition :
91 | pathD = '{}/features_supervision/{}/'.format(SEMA3D_PATH, n)
92 | else:
93 | pathD = '{}/features/{}/'.format(SEMA3D_PATH, n)
94 | pathC = '{}/superpoint_graphs/{}/'.format(SEMA3D_PATH, n)
95 | if not os.path.exists(pathP):
96 | os.makedirs(pathP)
97 | random.seed(0)
98 |
99 | for file in os.listdir(pathC):
100 | print(file)
101 | if file.endswith(".h5"):
102 | f = h5py.File(pathD + file, 'r')
103 |
104 | if n == 'train':
105 | labels = f['labels'][:]
106 | hard_labels = np.argmax(labels[:,1:],1)
107 | label_count = np.bincount(hard_labels, minlength=8)
108 | class_count = class_count + label_count
109 |
110 | xyz = f['xyz'][:]
111 | rgb = f['rgb'][:].astype(np.float64)
112 | elpsv = np.concatenate((f['xyz'][:,2][:,None], f['geof'][:]), axis=1)
113 |
114 | # rescale to [-0.5,0.5]; keep xyz
115 | elpsv[:,0] /= 100 # (rough guess)
116 | elpsv[:,1:] -= 0.5
117 | rgb = rgb/255.0 - 0.5
118 |
119 | P = np.concatenate([xyz, rgb, elpsv], axis=1)
120 |
121 | f = h5py.File(pathC + file, 'r')
122 | numc = len(f['components'].keys())
123 |
124 | with h5py.File(pathP + file, 'w') as hf:
125 | hf.create_dataset(name='centroid',data=xyz.mean(0))
126 | for c in range(numc):
127 | idx = f['components/{:d}'.format(c)][:].flatten()
128 | if idx.size > 10000: # trim extra large segments, just for speed-up of loading time
129 | ii = random.sample(range(idx.size), k=10000)
130 | idx = idx[ii]
131 |
132 | hf.create_dataset(name='{:d}'.format(c), data=P[idx,...])
133 | path = '{}/parsed/'.format(SEMA3D_PATH)
134 | data_file = h5py.File(path+'class_count.h5', 'w')
135 | data_file.create_dataset('class_count', data=class_count, dtype='int')
136 |
137 | if __name__ == "__main__":
138 | import argparse
139 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
140 | parser.add_argument('--SEMA3D_PATH', default='datasets/semantic3d')
141 | parser.add_argument('--supervised_partition', default=0, type=int, help = 'whether to use supervised partition features')
142 | args = parser.parse_args()
143 | preprocess_pointclouds(args.SEMA3D_PATH)
144 |
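For Semantic3D, `preprocess_pointclouds` stores each point as `[xyz, rgb, elevation, geof]`:

- xyz is kept unchanged;
- colors are mapped from [0, 255] to [-0.5, 0.5];
- the elevation column is the raw z coordinate divided by 100 (a rough scale);
- the geometric features are shifted by -0.5.

The sketch below isolates that column assembly on in-memory arrays, without the H5 I/O. The function name is hypothetical, and four geometric features per point are assumed to match the `lpsv` columns used elsewhere in this code.

```python
import numpy as np


def semantic3d_point_features(xyz, rgb, geof):
    """Stack per-point features in the column order used for Semantic3D parsed files.

    Columns: xyz (unchanged), rgb in [-0.5, 0.5], elevation z / 100, geof - 0.5.
    Hypothetical helper; the original script reads these arrays from features/*.h5.
    """
    xyz = np.asarray(xyz, dtype=np.float64)
    rgb = np.asarray(rgb, dtype=np.float64) / 255.0 - 0.5
    elevation = xyz[:, 2:3] / 100.0  # rough scale, as in the original script
    geof = np.asarray(geof, dtype=np.float64) - 0.5
    return np.concatenate([xyz, rgb, elevation, geof], axis=1)


if __name__ == "__main__":
    P = semantic3d_point_features([[1.0, 2.0, 50.0]], [[0, 255, 51]], [[0.5, 1.0, 0.0, 0.25]])
    print(P.shape)  # (1, 11)
```

The resulting 11 columns (3 + 3 + 1 + 4) are what `load_superpoint` later slices per superpoint.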
--------------------------------------------------------------------------------
/learning/spg.py:
--------------------------------------------------------------------------------
1 | """
2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 | http://arxiv.org/abs/1711.09869
4 | 2017 Loic Landrieu, Martin Simonovsky
5 | """
6 | from __future__ import division
7 | from __future__ import print_function
8 | from builtins import range
9 |
10 | import random
11 | import numpy as np
12 | import os
13 | import math
14 | import transforms3d
15 | import torch
16 | import ecc
17 | import h5py
18 | from sklearn import preprocessing
19 | import igraph
20 |
21 |
22 |
23 | def spg_edge_features(edges, node_att, edge_att, args):
24 | """ Assembles edge features from edge attributes and differences of node attributes. """
25 | columns = []
26 | for attrib in args.edge_attribs.split(','):
27 | attrib = attrib.split('/')
28 | a, opt = attrib[0], attrib[1].lower() if len(attrib)==2 else ''
29 |
30 | if a in ['delta_avg', 'delta_std']:
31 | columns.append(edge_att[a])
32 | elif a=='constant': # for isotropic baseline
33 | columns.append(np.ones((edges.shape[0],1), dtype=np.float32))
34 | elif a in ['nlength','surface','volume', 'size', 'xyz']:
35 | attr = node_att[a]
36 | if opt=='d': # difference
37 | attr = attr[edges[:,0],:] - attr[edges[:,1],:]
38 | elif opt=='ld': # log ratio
39 | attr = np.log(attr + 1e-10)
40 | attr = attr[edges[:,0],:] - attr[edges[:,1],:]
41 | elif opt=='r': # ratio
42 | attr = attr[edges[:,0],:] / (attr[edges[:,1],:] + 1e-10)
43 | else:
44 | raise NotImplementedError
45 | columns.append(attr)
46 | else:
47 | raise NotImplementedError
48 |
49 | return np.concatenate(columns, axis=1).astype(np.float32)
50 |
51 | def scaler01(trainlist, testlist, transform_train=True, validlist = []):
52 | """ Scale edge features to 0 mean 1 stddev """
53 | edge_feats = np.concatenate([ trainlist[i][3] for i in range(len(trainlist)) ], 0)
54 | scaler = preprocessing.StandardScaler().fit(edge_feats)
55 |
56 | if transform_train:
57 | for i in range(len(trainlist)):
58 | scaler.transform(trainlist[i][3], copy=False)
59 | for i in range(len(testlist)):
60 | scaler.transform(testlist[i][3], copy=False)
61 | if len(validlist)>0:
62 | for i in range(len(validlist)):
63 | scaler.transform(validlist[i][3], copy=False)
64 | return trainlist, testlist, validlist, scaler
65 |
66 | def spg_reader(args, fname, incl_dir_in_name=False):
67 | """ Loads a supergraph from H5 file. """
68 | f = h5py.File(fname,'r')
69 |
70 | if f['sp_labels'].size > 0:
71 | node_gt_size = f['sp_labels'][:].astype(np.int64) # column 0: no of unlabeled points, column 1+: no of labeled points per class
72 | node_gt = np.argmax(node_gt_size[:,1:], 1)[:,None]
73 | node_gt[node_gt_size[:,1:].sum(1)==0,:] = -100 # superpoints without labels are to be ignored in loss computation
74 | else:
75 | N = f['sp_point_count'].shape[0]
76 | node_gt_size = np.concatenate([f['sp_point_count'][:].astype(np.int64), np.zeros((N,8), dtype=np.int64)], 1)
77 | node_gt = np.zeros((N,1), dtype=np.int64)
78 |
79 | node_att = {}
80 | node_att['xyz'] = f['sp_centroids'][:]
81 | node_att['nlength'] = np.maximum(0, f['sp_length'][:])
82 | node_att['volume'] = np.maximum(0, f['sp_volume'][:] ** 2)
83 | node_att['surface'] = np.maximum(0, f['sp_surface'][:] ** 2)
84 | node_att['size'] = f['sp_point_count'][:]
85 |
86 | edges = np.concatenate([ f['source'][:], f['target'][:] ], axis=1).astype(np.int64)
87 |
88 | edge_att = {}
89 | edge_att['delta_avg'] = f['se_delta_mean'][:]
90 | edge_att['delta_std'] = f['se_delta_std'][:]
91 |
92 | if args.spg_superedge_cutoff > 0:
93 | filtered = np.linalg.norm(edge_att['delta_avg'],axis=1) < args.spg_superedge_cutoff
94 | edges = edges[filtered,:]
95 | edge_att['delta_avg'] = edge_att['delta_avg'][filtered,:]
96 | edge_att['delta_std'] = edge_att['delta_std'][filtered,:]
97 |
98 | edge_feats = spg_edge_features(edges, node_att, edge_att, args)
99 |
100 | name = os.path.basename(fname)[:-len('.h5')]
101 | if incl_dir_in_name: name = os.path.basename(os.path.dirname(fname)) + '/' + name
102 |
103 | return node_gt, node_gt_size, edges, edge_feats, name
104 |
105 |
106 | def spg_to_igraph(node_gt, node_gt_size, edges, edge_feats, fname):
107 | """ Builds representation of superpoint graph as igraph. """
108 | targets = np.concatenate([node_gt, node_gt_size], axis=1)
109 | G = igraph.Graph(n=node_gt.shape[0], edges=edges.tolist(), directed=True,
110 | edge_attrs={'f':edge_feats},
111 | vertex_attrs={'v':list(range(node_gt.shape[0])), 't':targets, 's':node_gt_size.sum(1)})
112 | return G, fname
113 |
114 | def random_neighborhoods(G, num, order):
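115 | """ Samples `num` random neighborhoods of order `order` (all nodes within `order` hops of a random center) and returns the subgraph induced by their union.
116 | The union is treated as a plain node set, so a later hard cutoff may break neighborhoods (a mild form of data augmentation). """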
117 | centers = random.sample(range(G.vcount()), k=num)
118 | neighb = G.neighborhood(centers, order)
119 | subset = [item for sublist in neighb for item in sublist]
120 | subset = sorted(set(subset))
121 | return G.subgraph(subset)
122 |
123 | def k_big_enough(G, minpts, k):
124 | """ Returns the subgraph induced by the first nodes of G, keeping at most k superpoints of size >= minpts (smaller ones are included but not counted) """
125 | valid = np.array(G.vs['s']) >= minpts
126 | n = np.argwhere(np.cumsum(valid)<=k)[-1][0]+1
127 | return G.subgraph(range(n))
128 |
129 |
130 | def loader(entry, train, args, db_path, test_seed_offset=0):
131 | """ Prepares a superpoint graph (potentially subsampled in training) and associated superpoints. """
132 | G, fname = entry
133 | # 1) subset (neighborhood) selection of (permuted) superpoint graph
134 | if train:
135 | if 0 < args.spg_augm_hardcutoff < G.vcount():
136 | perm = list(range(G.vcount())); random.shuffle(perm)
137 | G = G.permute_vertices(perm)
138 |
139 | if 0 < args.spg_augm_nneigh < G.vcount():
140 | G = random_neighborhoods(G, args.spg_augm_nneigh, args.spg_augm_order)
141 |
142 | if 0 < args.spg_augm_hardcutoff < G.vcount():
143 | G = k_big_enough(G, args.ptn_minpts, args.spg_augm_hardcutoff)
144 |
145 | # Only stores graph with edges
146 | if len(G.get_edgelist()) != 0:
147 | # 2) loading clouds for chosen superpoint graph nodes
148 | clouds_meta, clouds_flag = [], [] # meta: textual id of the superpoint; flag: 0 if its cloud was loaded, -1 if it is too small
149 | clouds, clouds_global = [], [] # clouds: point cloud arrays; clouds_global: diameters before scaling
150 |
151 | for s in range(G.vcount()):
152 | cloud, diam = load_superpoint(args, db_path + '/parsed/' + fname + '.h5', G.vs[s]['v'], train, test_seed_offset)
153 | if cloud is not None:
154 | clouds_meta.append('{}.{:d}'.format(fname,G.vs[s]['v'])); clouds_flag.append(0)
155 | clouds.append(cloud.T)
156 | clouds_global.append(diam)
157 | else:
158 | clouds_meta.append('{}.{:d}'.format(fname,G.vs[s]['v'])); clouds_flag.append(-1)
159 |
160 | clouds_flag = np.array(clouds_flag)
161 | if len(clouds) != 0:
162 | clouds = np.stack(clouds)
163 | if len(clouds_global) != 0:
164 | clouds_global = np.concatenate(clouds_global)
165 |
166 | return np.array(G.vs['t']), G, clouds_meta, clouds_flag, clouds, clouds_global
167 |
168 | # Don't use the graph if it doesn't have edges.
169 | else:
170 | target, G, clouds_meta, clouds_flag, clouds, clouds_global = None, None, None, None, None, None
171 | return target, G, clouds_meta, clouds_flag, clouds, clouds_global
172 |
173 |
174 | def cloud_edge_feats(edgeattrs):
175 | edgefeats = np.asarray(edgeattrs['f'])
176 | return torch.from_numpy(edgefeats), None
177 |
178 | def eccpc_collate(batch):
179 | """ Collates a list of dataset samples into a single batch (adapted in ecc.graph_info_collate_classification())
180 | """
181 | targets, graphs, clouds_meta, clouds_flag, clouds, clouds_global = list(zip(*batch))
182 |
183 | targets = torch.cat([torch.from_numpy(t) for t in targets if t is not None], 0).long()
184 | graphs = [graph for graph in graphs if graph is not None]
185 | GIs = [ecc.GraphConvInfo(graphs, cloud_edge_feats)]
186 |
187 | if len(clouds_meta[0]) > 0:
188 | clouds = torch.cat([torch.from_numpy(f) for f in clouds if f is not None], 0)
189 | clouds_global = torch.cat([torch.from_numpy(f) for f in clouds_global if f is not None], 0)
190 | clouds_flag = torch.cat([torch.from_numpy(f) for f in clouds_flag if f is not None], 0)
191 | clouds_meta = [item for sublist in clouds_meta if sublist is not None for item in sublist]
192 |
193 | return targets, GIs, (clouds_meta, clouds_flag, clouds, clouds_global)
194 |
195 |
196 | ############### POINT CLOUD PROCESSING ##########
197 |
198 | def load_superpoint(args, fname, id, train, test_seed_offset):
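199 | """ Loads superpoint `id` from a parsed H5 file, resamples it to args.ptn_npts points, centers (and optionally scales) xyz, selects the args.pc_attribs columns and, at training time, augments it. Returns (P, diameter), or (None, N) if it has fewer than args.ptn_minpts points. """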
200 | hf = h5py.File(fname,'r')
201 | P = hf['{:d}'.format(id)]
202 | N = P.shape[0]
203 | if N < args.ptn_minpts: # skip if too few pts (this must be consistent at train and test time)
204 | return None, N
205 | P = P[:].astype(np.float32)
206 |
207 | rs = np.random.random.__self__ if train else np.random.RandomState(seed=id+test_seed_offset) # fix seed for test
208 |
209 | if N > args.ptn_npts: # need to subsample
210 | ii = rs.choice(N, args.ptn_npts)
211 | P = P[ii, ...]
212 | elif N < args.ptn_npts: # need to pad by duplication
213 | ii = rs.choice(N, args.ptn_npts - N)
214 | P = np.concatenate([P, P[ii,...]], 0)
215 |
216 | if args.pc_xyznormalize:
217 | # center xyz on its mean and divide by the largest axis-aligned extent (so each axis spans at most 1)
218 | diameter = np.max(np.max(P[:,:3],axis=0) - np.min(P[:,:3],axis=0))
219 | P[:,:3] = (P[:,:3] - np.mean(P[:,:3], axis=0, keepdims=True)) / (diameter + 1e-10)
220 | else:
221 | diameter = 0.0
222 | P[:,:3] = (P[:,:3] - np.mean(P[:,:3], axis=0, keepdims=True))
223 |
224 | if args.pc_attribs != '':
225 | columns = []
226 | if 'xyz' in args.pc_attribs: columns.append(P[:,:3])
227 | if 'rgb' in args.pc_attribs: columns.append(P[:,3:6])
228 | if 'e' in args.pc_attribs: columns.append(P[:,6,None])
229 | if 'lpsv' in args.pc_attribs: columns.append(P[:,7:11])
230 | if 'XYZ' in args.pc_attribs: columns.append(P[:,11:14])
231 | if 'd' in args.pc_attribs: columns.append(P[:,14,None])
232 | P = np.concatenate(columns, axis=1)
233 |
234 | if train:
235 | P = augment_cloud(P, args)
236 | return P, np.array([diameter], dtype=np.float32)
237 |
238 |
239 | def augment_cloud(P, args):
240 | """ Augments XYZ (random scaling, rotation about the z axis, mirroring of x/y) and jitters all features """
241 | M = transforms3d.zooms.zfdir2mat(1)
242 | if args.pc_augm_scale > 1:
243 | s = random.uniform(1/args.pc_augm_scale, args.pc_augm_scale)
244 | M = np.dot(transforms3d.zooms.zfdir2mat(s), M)
245 | if args.pc_augm_rot==1:
246 | angle = random.uniform(0, 2*math.pi)
247 | M = np.dot(transforms3d.axangles.axangle2mat([0,0,1], angle), M) # z=upright assumption
248 | if args.pc_augm_mirror_prob > 0: # mirroring x&y, not z
249 | if random.random() < args.pc_augm_mirror_prob/2:
250 | M = np.dot(transforms3d.zooms.zfdir2mat(-1, [1,0,0]), M)
251 | if random.random() < args.pc_augm_mirror_prob/2:
252 | M = np.dot(transforms3d.zooms.zfdir2mat(-1, [0,1,0]), M)
253 | P[:,:3] = np.dot(P[:,:3], M.T)
254 |
255 | if args.pc_augm_jitter:
256 | sigma, clip= 0.01, 0.05 # https://github.com/charlesq34/pointnet/blob/master/provider.py#L74
257 | P = P + np.clip(sigma * np.random.randn(*P.shape), -1*clip, clip).astype(np.float32)
258 | return P
259 |
260 | def global_rotation(P, args):
261 | print("e")
262 |
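`load_superpoint` brings every superpoint to a fixed size of `args.ptn_npts` points before it reaches PointNet:

1. Larger clouds are subsampled with `rs.choice`, which draws with replacement by default.
2. Smaller clouds are padded by duplicating randomly chosen points.
3. xyz is centered on its mean and, when `pc_xyznormalize` is set, divided by the largest axis-aligned extent.

The NumPy sketch below reproduces those steps on an in-memory array. Its name and signature are hypothetical. It leaves out the H5 loading, the `pc_attribs` column selection and the augmentation.

```python
import numpy as np


def resample_and_normalize(P, npts, rs, xyz_normalize=True):
    """Bring a superpoint to exactly `npts` points and center (optionally scale) its xyz.

    Hypothetical helper mirroring load_superpoint. P is (N, F) with xyz in columns 0-2;
    rs is a numpy RandomState. Returns (resampled array, diameter), where diameter is
    the largest axis-aligned extent used for scaling, or 0.0 when xyz_normalize is False.
    """
    P = np.asarray(P, dtype=np.float32).copy()
    N = P.shape[0]
    if N > npts:  # subsample; rs.choice draws with replacement by default
        P = P[rs.choice(N, npts)]
    elif N < npts:  # pad by duplicating randomly chosen points
        P = np.concatenate([P, P[rs.choice(N, npts - N)]], 0)

    xyz = P[:, :3]
    centered = xyz - xyz.mean(axis=0, keepdims=True)  # new array, xyz is left untouched
    if xyz_normalize:
        diameter = float(np.max(xyz.max(axis=0) - xyz.min(axis=0)))
        P[:, :3] = centered / (diameter + 1e-10)
    else:
        diameter = 0.0
        P[:, :3] = centered
    return P, diameter


if __name__ == "__main__":
    rs = np.random.RandomState(0)  # fixed seed, like the test-time branch of load_superpoint
    cloud = rs.rand(3, 4)
    out, diameter = resample_and_normalize(cloud, 8, rs)
    print(out.shape)  # (8, 4)
```

Seeding the `RandomState` per superpoint at test time, as `load_superpoint` does with `id + test_seed_offset`, makes this resampling reproducible from one evaluation run to the next.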
--------------------------------------------------------------------------------
/learning/vkitti_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Nov 6 16:45:16 2018
5 | @author: landrieuloic
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import random
12 | import numpy as np
13 | import os
14 | import functools
15 | import torch
16 | import torchnet as tnt
17 | import h5py
18 | import spg
19 |
20 | def get_datasets(args, test_seed_offset=0):
21 | """ Gets training and test datasets. """
22 |
23 | # Load superpoints graphs
24 | testlist, trainlist, validlist = [], [], []
25 | valid_names = ['0001_00000.h5','0001_00085.h5', '0001_00170.h5','0001_00230.h5','0001_00325.h5','0001_00420.h5', \
26 | '0002_00000.h5','0002_00111.h5','0002_00223.h5','0018_00030.h5','0018_00184.h5','0018_00338.h5',\
27 | '0020_00080.h5','0020_00262.h5','0020_00444.h5','0020_00542.h5','0020_00692.h5', '0020_00800.h5']
28 |
29 | for n in range(1,7):
30 | if n != args.cvfold:
31 | path = '{}/superpoint_graphs/0{:d}/'.format(args.VKITTI_PATH, n)
32 | for fname in sorted(os.listdir(path)):
33 | if fname.endswith(".h5") and not (args.use_val_set and fname in valid_names):
34 | #training set
35 | trainlist.append(spg.spg_reader(args, path + fname, True))
36 | if fname.endswith(".h5") and (args.use_val_set and fname in valid_names):
37 | #validation set
38 | validlist.append(spg.spg_reader(args, path + fname, True))
39 | path = '{}/superpoint_graphs/0{:d}/'.format(args.VKITTI_PATH, args.cvfold)
40 | #evaluation set
41 | for fname in sorted(os.listdir(path)):
42 | if fname.endswith(".h5"):
43 | testlist.append(spg.spg_reader(args, path + fname, True))
44 |
45 | scaler = None # Normalize edge features (scaler stays None if spg_attribs01 is off)
46 | if args.spg_attribs01:
47 | trainlist, testlist, validlist, scaler = spg.scaler01(trainlist, testlist, validlist=validlist)
48 |
49 | return tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in trainlist],
50 | functools.partial(spg.loader, train=True, args=args, db_path=args.VKITTI_PATH)), \
51 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in testlist],
52 | functools.partial(spg.loader, train=False, args=args, db_path=args.VKITTI_PATH, test_seed_offset=test_seed_offset)), \
53 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in validlist],
54 | functools.partial(spg.loader, train=False, args=args, db_path=args.VKITTI_PATH, test_seed_offset=test_seed_offset)), \
55 | scaler
56 |
57 |
58 | def get_info(args):
59 | edge_feats = 0
60 | for attrib in args.edge_attribs.split(','):
61 | a = attrib.split('/')[0]
62 | if a in ['delta_avg', 'delta_std', 'xyz']:
63 | edge_feats += 3
64 | else:
65 | edge_feats += 1
66 | if args.loss_weights == 'none':
67 | weights = np.ones((13,),dtype='f4')
68 | else:
69 | weights = h5py.File(args.VKITTI_PATH + "/parsed/class_count.h5", 'r')["class_count"][:].astype('f4')
70 | weights = weights[:,[i for i in range(6) if i != args.cvfold-1]].sum(1)
71 | weights = (weights+1).mean()/(weights+1)
72 | if args.loss_weights == 'sqrt':
73 | weights = np.sqrt(weights)
74 | weights = torch.from_numpy(weights).cuda() if args.cuda else torch.from_numpy(weights)
75 | return {
76 | 'node_feats': 9 if args.pc_attribs=='' else len(args.pc_attribs),
77 | 'edge_feats': edge_feats,
78 | 'classes': 13,
79 | 'class_weights': weights,
80 | 'inv_class_map': {0:'Terrain', 1:'Tree', 2:'Vegetation', 3:'Building', 4:'Road', 5:'GuardRail', 6:'TrafficSign', 7:'TrafficLight', 8:'Pole', 9:'Misc', 10:'Truck', 11:'Car', 12:'Van'},
81 | }
82 |
83 | def preprocess_pointclouds(VKITTI_PATH):
84 | """ Preprocesses data by splitting them by components and normalizing."""
85 | class_count = np.zeros((13,6),dtype='int')
86 | for n in range(1,7):
87 | pathP = '{}/parsed/0{:d}/'.format(VKITTI_PATH, n)
88 | pathD = '{}/features_supervision/0{:d}/'.format(VKITTI_PATH, n)
89 | pathC = '{}/superpoint_graphs/0{:d}/'.format(VKITTI_PATH, n)
90 | if not os.path.exists(pathP):
91 | os.makedirs(pathP)
92 | random.seed(n)
93 |
94 | for file in os.listdir(pathC):
95 | print(file)
96 | if file.endswith(".h5"):
97 | f = h5py.File(pathD + file, 'r')
98 | xyz = f['xyz'][:]
99 | rgb = f['rgb'][:].astype(np.float64)
100 |
101 | labels = f['labels'][:]
102 | hard_labels = np.argmax(labels[:,1:],1)
103 | label_count = np.bincount(hard_labels, minlength=13)
104 | class_count[:,n-1] = class_count[:,n-1] + label_count
105 |
106 | e = (xyz[:,2] - np.min(xyz[:,2])) / (np.max(xyz[:,2]) - np.min(xyz[:,2])) - 0.5
107 |
108 | rgb = rgb/255.0 - 0.5
109 |
110 | xyzn = (xyz - np.array([30,0,0])) / np.array([30,5,3])
111 |
112 | lpsv = np.zeros((e.shape[0],4))
113 |
114 | P = np.concatenate([xyz, rgb, e[:,np.newaxis], lpsv, xyzn], axis=1)
115 |
116 | f = h5py.File(pathC + file, 'r')
117 | numc = len(f['components'].keys())
118 |
119 | with h5py.File(pathP + file, 'w') as hf:
120 | hf.create_dataset(name='centroid',data=xyz.mean(0))
121 | for c in range(numc):
122 | idx = f['components/{:d}'.format(c)][:].flatten()
123 | if idx.size > 10000: # trim extra large segments, just for speed-up of loading time
124 | ii = random.sample(range(idx.size), k=10000)
125 | idx = idx[ii]
126 |
127 | hf.create_dataset(name='{:d}'.format(c), data=P[idx,...])
128 | path = '{}/parsed/'.format(VKITTI_PATH)
129 | data_file = h5py.File(path+'class_count.h5', 'w')
130 | data_file.create_dataset('class_count', data=class_count, dtype='int')
131 |
132 | if __name__ == "__main__":
133 | import argparse
134 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
135 | parser.add_argument('--VKITTI_PATH', default='datasets/s3dis')
136 | args = parser.parse_args()
137 | preprocess_pointclouds(args.VKITTI_PATH)
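The class weights in `get_info` follow an inverse-frequency scheme: point counts are summed over the training folds (every fold except `args.cvfold`), smoothed by +1, and divided into their mean, with an extra square root when `--loss_weights sqrt`. The sketch below isolates that computation; the helper name `inverse_frequency_weights` and its `mode` parameter are hypothetical, and it assumes `class_count` has the `(n_classes, n_folds)` layout written by `preprocess_pointclouds`.

```python
import numpy as np


def inverse_frequency_weights(class_count, test_fold, mode="sqrt"):
    """Sketch of the class-weight logic in get_info (hypothetical helper).

    class_count: (n_classes, n_folds) array of per-fold point counts.
    test_fold:   1-based index of the held-out fold (like args.cvfold).
    mode:        'none' -> uniform weights, 'sqrt' -> square-rooted weights,
                 any other value -> plain inverse-frequency weights.
    """
    class_count = np.asarray(class_count)
    n_classes, n_folds = class_count.shape
    if mode == "none":
        return np.ones((n_classes,), dtype="f4")
    # Sum the counts over the training folds only.
    train_folds = [i for i in range(n_folds) if i != test_fold - 1]
    counts = class_count[:, train_folds].astype("f4").sum(1)
    # +1 avoids a division by zero for classes absent from the training folds.
    weights = (counts + 1).mean() / (counts + 1)
    if mode == "sqrt":
        weights = np.sqrt(weights)
    return weights


counts = np.array([[100, 300], [10, 30], [0, 0]])  # 3 classes, 2 folds
print(inverse_frequency_weights(counts, test_fold=2, mode="linear"))
```

Rare classes receive large weights; the square-root variant tempers that effect so that a nearly empty class does not dominate the loss.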
--------------------------------------------------------------------------------
/partition/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loicland/superpoint_graph/0209777339327c9b327b6947af6c89b20bb45981/partition/__init__.py
--------------------------------------------------------------------------------
/partition/graphs.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | #--------- Graph methods for SuperPoint Graph ------------------------------
3 | #--------- Loic Landrieu, Dec. 2017 -----------------------------------
4 | #------------------------------------------------------------------------------
5 | import numpy as np
6 | from sklearn.neighbors import NearestNeighbors
7 | from scipy.spatial import Delaunay
8 | from numpy import linalg as LA
9 | import numpy.matlib
10 | #------------------------------------------------------------------------------
11 | def compute_graph_nn(xyz, k_nn):
12 | """compute the knn graph"""
13 | num_ver = xyz.shape[0]
14 | graph = dict([("is_nn", True)])
15 | nn = NearestNeighbors(n_neighbors=k_nn+1, algorithm='kd_tree').fit(xyz)
16 | distances, neighbors = nn.kneighbors(xyz)
17 | neighbors = neighbors[:, 1:]
18 | distances = distances[:, 1:]
19 | source = np.repeat(np.arange(num_ver), k_nn)
20 | #save the graph
21 | graph["source"] = source.flatten().astype('uint32')
22 | graph["target"] = neighbors.flatten().astype('uint32')
23 | graph["distances"] = distances.flatten().astype('float32')
24 | return graph
25 | #------------------------------------------------------------------------------
26 | def compute_graph_nn_2(xyz, k_nn1, k_nn2, voronoi = 0.0):
27 | """compute simultaneously 2 knn structures
28 | only saves target for knn2
29 | assumption : knn1 <= knn2"""
30 | assert k_nn1 <= k_nn2, "k_nn1 must be smaller than or equal to k_nn2"
31 | n_ver = xyz.shape[0]
32 | #compute nearest neighbors
33 | graph = dict([("is_nn", True)])
34 | nn = NearestNeighbors(n_neighbors=k_nn2+1, algorithm='kd_tree').fit(xyz)
35 | distances, neighbors = nn.kneighbors(xyz)
36 | del nn
37 | neighbors = neighbors[:, 1:]
38 | distances = distances[:, 1:]
39 | #---knn2---
40 | target2 = (neighbors.flatten()).astype('uint32')
41 | #---knn1-----
42 | if voronoi>0:
43 | tri = Delaunay(xyz)
44 | graph["source"] = np.hstack((tri.simplices[:,0],tri.simplices[:,0], \
45 | tri.simplices[:,0], tri.simplices[:,1], tri.simplices[:,1], tri.simplices[:,2])).astype('uint64')
46 | graph["target"]= np.hstack((tri.simplices[:,1],tri.simplices[:,2], \
47 | tri.simplices[:,3], tri.simplices[:,2], tri.simplices[:,3], tri.simplices[:,3])).astype('uint64')
48 | graph["distances"] = ((xyz[graph["source"],:] - xyz[graph["target"],:])**2).sum(1)
49 | keep_edges = graph["distances"] 1
80 | label_hist = has_labels and len(labels.shape) > 1 and labels.shape[1] > 1
81 | #---compute delaunay triangulation---
82 | tri = Delaunay(xyz)
83 | #interface selects the edges between different components
84 | #edgx and edgxr convert from tetrahedra to edges
85 | #done separately for each edge of the tetrahedra to limit memory impact
86 | interface = in_component[tri.simplices[:, 0]] != in_component[tri.simplices[:, 1]]
87 | edg1 = np.vstack((tri.simplices[interface, 0], tri.simplices[interface, 1]))
88 | edg1r = np.vstack((tri.simplices[interface, 1], tri.simplices[interface, 0]))
89 | interface = in_component[tri.simplices[:, 0]] != in_component[tri.simplices[:, 2]]
90 | edg2 = np.vstack((tri.simplices[interface, 0], tri.simplices[interface, 2]))
91 | edg2r = np.vstack((tri.simplices[interface, 2], tri.simplices[interface, 0]))
92 | interface = in_component[tri.simplices[:, 0]] != in_component[tri.simplices[:, 3]]
93 | edg3 = np.vstack((tri.simplices[interface, 0], tri.simplices[interface, 3]))
94 | edg3r = np.vstack((tri.simplices[interface, 3], tri.simplices[interface, 0]))
95 | interface = in_component[tri.simplices[:, 1]] != in_component[tri.simplices[:, 2]]
96 | edg4 = np.vstack((tri.simplices[interface, 1], tri.simplices[interface, 2]))
97 | edg4r = np.vstack((tri.simplices[interface, 2], tri.simplices[interface, 1]))
98 | interface = in_component[tri.simplices[:, 1]] != in_component[tri.simplices[:, 3]]
99 | edg5 = np.vstack((tri.simplices[interface, 1], tri.simplices[interface, 3]))
100 | edg5r = np.vstack((tri.simplices[interface, 3], tri.simplices[interface, 1]))
101 | interface = in_component[tri.simplices[:, 2]] != in_component[tri.simplices[:, 3]]
102 | edg6 = np.vstack((tri.simplices[interface, 2], tri.simplices[interface, 3]))
103 | edg6r = np.vstack((tri.simplices[interface, 3], tri.simplices[interface, 2]))
104 | del tri, interface
105 | edges = np.hstack((edg1, edg2, edg3, edg4 ,edg5, edg6, edg1r, edg2r,
106 | edg3r, edg4r ,edg5r, edg6r))
107 | del edg1, edg2, edg3, edg4 ,edg5, edg6, edg1r, edg2r, edg3r, edg4r, edg5r, edg6r
108 | edges = np.unique(edges, axis=1)
109 |
110 | if d_max > 0:
111 | dist = np.sqrt(((xyz[edges[0,:]]-xyz[edges[1,:]])**2).sum(1))
112 | edges = edges[:,dist 1:
203 | graph["se_delta_mean"][i_sedg] = np.mean(delta, axis=0)
204 | graph["se_delta_std"][i_sedg] = np.std(delta, axis=0)
205 | graph["se_delta_norm"][i_sedg] = np.mean(np.sqrt(np.sum(delta ** 2, axis=1)))
206 | else:
207 | graph["se_delta_mean"][i_sedg, :] = delta
208 | graph["se_delta_std"][i_sedg, :] = [0, 0, 0]
209 | graph["se_delta_norm"][i_sedg] = np.sqrt(np.sum(delta ** 2))
210 | return graph
211 |
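`compute_graph_nn` stores the k-nearest-neighbour graph as flat `source`/`target`/`distances` arrays, and `compute_sp_graph` extracts the Delaunay edges whose endpoints lie in different components by examining the six edges of every tetrahedron. The sketch below reproduces both steps under stated assumptions: the helper names `knn_edges` and `interface_edges` are hypothetical, SciPy's `cKDTree` stands in for scikit-learn's `NearestNeighbors`, `np.repeat` replaces `numpy.matlib.repmat` (same ordering), the tetrahedra come from `Delaunay.simplices`, and points are assumed distinct so that the first neighbour returned is the query point itself.

```python
import numpy as np
from scipy.spatial import Delaunay, cKDTree


def knn_edges(xyz, k_nn):
    """Flat (source, target, distance) arrays of the k-nn graph, without self-loops."""
    distances, neighbors = cKDTree(xyz).query(xyz, k=k_nn + 1)
    # Column 0 is the query point itself (assuming no duplicate points).
    source = np.repeat(np.arange(xyz.shape[0]), k_nn).astype("uint32")
    target = neighbors[:, 1:].flatten().astype("uint32")
    return source, target, distances[:, 1:].flatten().astype("float32")


def interface_edges(xyz, in_component):
    """Delaunay edges joining two different components, in both directions, deduplicated."""
    simplices = Delaunay(xyz).simplices  # (n_tetrahedra, 4) vertex indices
    edges = []
    # Build the interface mask separately for each of the six tetrahedron edges.
    for a, b in [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]:
        u, v = simplices[:, a], simplices[:, b]
        cut = in_component[u] != in_component[v]
        edges.append(np.vstack((u[cut], v[cut])))
        edges.append(np.vstack((v[cut], u[cut])))
    return np.unique(np.hstack(edges), axis=1)


rng = np.random.default_rng(0)
pts = rng.random((60, 3))
src, tgt, dist = knn_edges(pts, 5)
comp = (pts[:, 0] > 0.5).astype(int)  # two toy "components"
edges = interface_edges(pts, comp)
```

Keeping both edge directions mirrors the `edgX` / `edgXr` pairs above, so every superedge is later seen from each of its two superpoints.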
--------------------------------------------------------------------------------
/partition/partition.py:
--------------------------------------------------------------------------------
1 | """
2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 | http://arxiv.org/abs/1711.09869
4 | 2017 Loic Landrieu, Martin Simonovsky
5 | Script for partitioning into simple shapes
6 | """
7 | import os.path
8 | import sys
9 | import numpy as np
10 | import argparse
11 | from timeit import default_timer as timer
12 | sys.path.append("./partition/cut-pursuit/build/src")
13 | sys.path.append("./partition/ply_c")
14 | sys.path.append("./partition")
15 | import libcp
16 | import libply_c
17 | from graphs import *
18 | from provider import *
19 |
20 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
21 | parser.add_argument('--ROOT_PATH', default='datasets/s3dis')
22 | parser.add_argument('--dataset', default='s3dis', help='s3dis/sema3d/your_dataset')
23 | parser.add_argument('--k_nn_geof', default=45, type=int, help='number of neighbors for the geometric features')
24 | parser.add_argument('--k_nn_adj', default=10, type=int, help='adjacency structure for the minimal partition')
25 | parser.add_argument('--lambda_edge_weight', default=1., type=float, help='parameter determining the edge weight for the minimal partition')
26 | parser.add_argument('--reg_strength', default=0.1, type=float, help='regularization strength for the minimal partition')
27 | parser.add_argument('--d_se_max', default=0, type=float, help='max length of super edges')
28 | parser.add_argument('--voxel_width', default=0.03, type=float, help='voxel size when subsampling (in m)')
29 | parser.add_argument('--ver_batch', default=0, type=int, help='Batch size for reading large files, 0 to disable batch loading')
30 | parser.add_argument('--overwrite', default=0, type=int, help='Whether to read existing files or overwrite them')
31 | args = parser.parse_args()
32 |
33 | #path to data
34 | root = args.ROOT_PATH+'/'
35 | #list of subfolders to be processed
36 | if args.dataset == 's3dis':
37 | folders = ["Area_1/", "Area_2/", "Area_3/", "Area_4/", "Area_5/", "Area_6/"]
38 | n_labels = 13
39 | elif args.dataset == 'sema3d':
40 | folders = ["test_reduced/", "test_full/", "train/"]
41 | n_labels = 8
42 | elif args.dataset == 'custom_dataset':
43 | folders = ["train/", "test/"]
44 | n_labels = 10 #number of classes
45 | else:
46 | raise ValueError('%s is an unknown data set' % args.dataset)
47 |
48 | times = [0,0,0] #time for computing: features / partition / spg
49 |
50 | if not os.path.isdir(root + "clouds"):
51 | os.mkdir(root + "clouds")
52 | if not os.path.isdir(root + "features"):
53 | os.mkdir(root + "features")
54 | if not os.path.isdir(root + "superpoint_graphs"):
55 | os.mkdir(root + "superpoint_graphs")
56 |
57 | for folder in folders:
58 | print("=================\n "+folder+"\n=================")
59 |
60 | data_folder = root + "data/" + folder
61 | cloud_folder = root + "clouds/" + folder
62 | fea_folder = root + "features/" + folder
63 | spg_folder = root + "superpoint_graphs/" + folder
64 | if not os.path.isdir(data_folder):
65 | raise ValueError("%s does not exist" % data_folder)
66 |
67 | if not os.path.isdir(cloud_folder):
68 | os.mkdir(cloud_folder)
69 | if not os.path.isdir(fea_folder):
70 | os.mkdir(fea_folder)
71 | if not os.path.isdir(spg_folder):
72 | os.mkdir(spg_folder)
73 |
74 | if args.dataset=='s3dis':
75 | files = [os.path.join(data_folder, o) for o in os.listdir(data_folder)
76 | if os.path.isdir(os.path.join(data_folder,o))]
77 | elif args.dataset=='sema3d':
78 | files = glob.glob(data_folder+"*.txt")
79 | elif args.dataset=='custom_dataset':
80 | #list all ply files in the folder
81 | files = glob.glob(data_folder+"*.ply")
82 | #alternatively, list all las files in the folder
83 | #files = glob.glob(data_folder+"*.las")
84 |
85 | if (len(files) == 0):
86 | raise ValueError('%s is empty' % data_folder)
87 |
88 | n_files = len(files)
89 | i_file = 0
90 | for file in files:
91 | file_name = os.path.splitext(os.path.basename(file))[0]
92 |
93 | if args.dataset=='s3dis':
94 | data_file = data_folder + file_name + '/' + file_name + ".txt"
95 | cloud_file = cloud_folder + file_name
96 | fea_file = fea_folder + file_name + '.h5'
97 | spg_file = spg_folder + file_name + '.h5'
98 | elif args.dataset=='sema3d':
99 | file_name_short = '_'.join(file_name.split('_')[:2])
100 | data_file = data_folder + file_name + ".txt"
101 | label_file = data_folder + file_name_short + ".labels"
102 | cloud_file = cloud_folder+ file_name_short
103 | fea_file = fea_folder + file_name_short + '.h5'
104 | spg_file = spg_folder + file_name_short + '.h5'
105 | elif args.dataset=='custom_dataset':
106 | #adapt to your hierarchy. The following 4 files must be defined
107 | data_file = data_folder + file_name + '.ply' #or .las
108 | cloud_file = cloud_folder + file_name
109 | fea_file = fea_folder + file_name + '.h5'
110 | spg_file = spg_folder + file_name + '.h5'
111 |
112 | i_file = i_file + 1
113 | print(str(i_file) + " / " + str(n_files) + "---> "+file_name)
114 | #--- build the geometric feature file h5 file ---
115 | if os.path.isfile(fea_file) and not args.overwrite:
116 | print(" reading the existing feature file...")
117 | geof, xyz, rgb, graph_nn, labels = read_features(fea_file)
118 | else :
119 | print(" creating the feature file...")
120 | #--- read the data files and compute the labels---
121 | if args.dataset=='s3dis':
122 | xyz, rgb, labels, objects = read_s3dis_format(data_file)
123 | if args.voxel_width > 0:
124 | xyz, rgb, labels, dump = libply_c.prune(xyz.astype('f4'), args.voxel_width, rgb.astype('uint8'), labels.astype('uint8'), np.zeros(1, dtype='uint8'), n_labels, 0)
125 | elif args.dataset=='sema3d':
126 | label_file = data_folder + file_name + ".labels"
127 | has_labels = (os.path.isfile(label_file))
128 | if (has_labels):
129 | xyz, rgb, labels = read_semantic3d_format(data_file, n_labels, label_file, args.voxel_width, args.ver_batch)
130 | else:
131 | xyz, rgb = read_semantic3d_format(data_file, 0, '', args.voxel_width, args.ver_batch)
132 | labels = []
133 | elif args.dataset=='custom_dataset':
134 | #implement in provider.py your own read_custom_format outputing xyz, rgb, labels
135 | #example for ply files
136 | xyz, rgb, labels = read_ply(data_file)
137 | #alternatively, for las files without rgb:
138 | #xyz = read_las(data_file)
139 | if args.voxel_width > 0:
140 | #an example of pruning without labels (same argument order as the s3dis call above)
141 | xyz, rgb, labels, dump = libply_c.prune(xyz.astype('f4'), args.voxel_width, rgb.astype('uint8'), np.array(1,dtype='u1'), np.zeros(1, dtype='u1'), 0, 0)
142 | #alternatively, without rgb information or labels:
143 | #xyz = libply_c.prune(xyz.astype('f4'), args.voxel_width, np.zeros(xyz.shape,dtype='u1'), np.array(1,dtype='u1'), np.zeros(1, dtype='u1'), 0, 0)[0]
144 | #if no labels available simply set here labels = []
145 | #if no rgb available simply set here rgb = [] and make sure to not use it later on
146 | start = timer()
147 | #---compute the k_nn_adj and k_nn_geof nearest-neighbor graphs-------
148 | graph_nn, target_fea = compute_graph_nn_2(xyz, args.k_nn_adj, args.k_nn_geof)
149 | #---compute geometric features-------
150 | geof = libply_c.compute_geof(xyz, target_fea, args.k_nn_geof).astype('float32')
151 | end = timer()
152 | times[0] = times[0] + end - start
153 | del target_fea
154 | write_features(fea_file, geof, xyz, rgb, graph_nn, labels)
155 | #--compute the partition------
156 | sys.stdout.flush()
157 | if os.path.isfile(spg_file) and not args.overwrite:
158 | print(" reading the existing superpoint graph file...")
159 | graph_sp, components, in_component = read_spg(spg_file)
160 | else:
161 | print(" computing the superpoint graph...")
162 | #--- build the spg h5 file --
163 | start = timer()
164 | if args.dataset=='s3dis':
165 | features = np.hstack((geof, rgb/255.)).astype('float32')#add rgb as a feature for partitioning
166 | features[:,3] = 2. * features[:,3] #increase importance of verticality (heuristic)
167 | elif args.dataset=='sema3d':
168 | features = geof
169 | geof[:,3] = 2. * geof[:, 3]
170 | elif args.dataset=='custom_dataset':
171 | #choose here which features to use for the partition
172 | features = geof
173 | geof[:,3] = 2. * geof[:, 3]
174 |
175 | graph_nn["edge_weight"] = np.array(1. / ( args.lambda_edge_weight + graph_nn["distances"] / np.mean(graph_nn["distances"])), dtype = 'float32')
176 | print(" minimal partition...")
177 | components, in_component = libcp.cutpursuit(features, graph_nn["source"], graph_nn["target"]
178 | , graph_nn["edge_weight"], args.reg_strength)
179 | components = np.array(components, dtype = 'object')
180 | end = timer()
181 | times[1] = times[1] + end - start
182 | print(" computation of the SPG...")
183 | start = timer()
184 | graph_sp = compute_sp_graph(xyz, args.d_se_max, in_component, components, labels, n_labels)
185 | end = timer()
186 | times[2] = times[2] + end - start
187 | write_spg(spg_file, graph_sp, components, in_component)
188 |
189 | print("Timer : %5.1f / %5.1f / %5.1f " % (times[0], times[1], times[2]))
190 |
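Before calling `libcp.cutpursuit`, the script weights each edge of the `k_nn_adj` graph as `w_ij = 1 / (lambda_edge_weight + d_ij / mean(d))`. Short edges therefore get larger weights and, in the usual cut-pursuit energy where the penalty sums the weights of cut edges, are more expensive to cut; a larger `lambda_edge_weight` flattens the difference. A minimal sketch of that formula (the function name `edge_weights` is hypothetical):

```python
import numpy as np


def edge_weights(distances, lambda_edge_weight=1.0):
    """Edge weights as computed in partition.py from the k-nn edge lengths."""
    distances = np.asarray(distances, dtype="float64")
    # Lengths are normalized by their mean, so the weights do not depend on the unit.
    return (1.0 / (lambda_edge_weight + distances / distances.mean())).astype("float32")


print(edge_weights([1.0, 2.0, 3.0]))         # lambda = 1
print(edge_weights([1.0, 2.0, 3.0], 10.0))   # larger lambda, flatter weights
```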
--------------------------------------------------------------------------------
/partition/ply_c/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Graph for Cut Pursuit
2 | # author: Loic Landrieu
3 | # date: 2017
4 |
5 | CMAKE_MINIMUM_REQUIRED(VERSION 3.5)
6 |
7 | PROJECT(LIBGEOF)
8 |
9 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -std=c++11 -fopenmp -O3")
10 |
11 | ##############################
12 | ### Find required packages ###
13 | ##############################
14 | set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_MODULE_PATH})
15 |
16 | find_package(PythonLibs)
17 | find_package(PythonInterp)
18 | find_package(NumPy 1.5 REQUIRED)
19 |
20 | find_package(Boost 1.65.0 COMPONENTS graph REQUIRED) #system filesystem thread serialization
21 | if (${Boost_MINOR_VERSION} LESS 67 )
22 | find_package(Boost 1.65.0 COMPONENTS numpy${PYTHON_VERSION_MAJOR} REQUIRED) #system filesystem thread serialization
23 | else()
24 | set(PYTHONVERSION ${PYTHON_VERSION_MAJOR}${PYTHON_VERSION_MINOR})
25 | find_package(Boost 1.67.0 COMPONENTS numpy${PYTHONVERSION} REQUIRED)
26 | endif()
27 |
28 | include_directories(${Boost_INCLUDE_DIRS})
29 | link_directories(${Boost_LIBRARY_DIRS})
30 |
31 | message("Boost includes ARE " ${Boost_INCLUDE_DIRS})
32 | message("Boost LIBRARIES ARE " ${Boost_LIBRARY_DIRS})
33 |
34 | find_package(Eigen3 REQUIRED NO_MODULE)
35 | INCLUDE_DIRECTORIES(${EIGEN3_INCLUDE_DIR})
36 | #LINK_DIRECTORIES(${EIGEN3_LIBRARY_DIRS})
37 |
38 | #SET(PYTHON_LIBRARIES /usr/lib/x86_64-linux-gnu/libpython2.7.so)
39 | #SET(PYTHON_INCLUDE_DIRS /usr/include/python2.7)
40 |
41 | message("PYTHON LIBRARIES ARE " ${PYTHON_LIBRARIES})
42 | INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIRS} ${PYTHON_NUMPY_INCLUDE_DIR})
43 | LINK_DIRECTORIES(${PYTHON_LIBRARY_DIRS})
44 | ##############################
45 | ### Build target library ###
46 | ##############################
47 |
48 | set(CMAKE_LD_FLAG "${CMAKE_LD_FLAGS} -shared -Wl -fPIC --export-dynamic -fopenmp -O3 -Wall")
49 |
50 | add_library(ply_c SHARED ply_c.cpp)
51 | target_link_libraries(ply_c
52 | ${Boost_LIBRARIES}
53 | ${PYTHON_LIBRARIES}
54 | )
55 |
--------------------------------------------------------------------------------
/partition/ply_c/FindNumPy.cmake:
--------------------------------------------------------------------------------
1 |
2 | # - Try to find the Python module NumPy
3 | #
4 | # This module defines:
5 | # NUMPY_INCLUDE_DIR: include path for arrayobject.h
6 |
7 | # Copyright (c) 2009-2012 Arnaud Barré
8 | # Redistribution and use is allowed according to the terms of the BSD license.
9 | # For details see the accompanying COPYING-CMAKE-SCRIPTS file.
10 |
11 | if (PYTHON_NUMPY_INCLUDE_DIR)
12 | set(PYTHON_NUMPY_FIND_QUIETLY TRUE)
13 | endif()
14 |
15 | if (NOT PYTHON_EXECUTABLE)
16 | message(FATAL_ERROR "\"PYTHON_EXECUTABLE\" variable not set before FindNumPy.cmake was run.")
17 | endif()
18 |
19 | # Look for the include path
20 | # WARNING: The variable PYTHON_EXECUTABLE is defined by the script FindPythonInterp.cmake
21 | execute_process(COMMAND "${PYTHON_EXECUTABLE}" -c "import numpy; print (numpy.get_include()); print (numpy.version.version)"
22 | OUTPUT_VARIABLE NUMPY_OUTPUT
23 | ERROR_VARIABLE NUMPY_ERROR)
24 | if (NOT NUMPY_ERROR)
25 | STRING(REPLACE "\n" ";" NUMPY_OUTPUT ${NUMPY_OUTPUT})
26 | LIST(GET NUMPY_OUTPUT 0 PYTHON_NUMPY_INCLUDE_DIRS)
27 | LIST(GET NUMPY_OUTPUT 1 PYTHON_NUMPY_VERSION)
28 | endif(NOT NUMPY_ERROR)
29 |
30 | include(FindPackageHandleStandardArgs)
31 | find_package_handle_standard_args(NumPy DEFAULT_MSG PYTHON_NUMPY_VERSION PYTHON_NUMPY_INCLUDE_DIRS)
32 |
33 | set(PYTHON_NUMPY_INCLUDE_DIR ${PYTHON_NUMPY_INCLUDE_DIRS}
34 | CACHE PATH "Location of NumPy include files.")
35 | mark_as_advanced(PYTHON_NUMPY_INCLUDE_DIR)
--------------------------------------------------------------------------------
/partition/ply_c/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loicland/superpoint_graph/0209777339327c9b327b6947af6c89b20bb45981/partition/ply_c/__init__.py
--------------------------------------------------------------------------------
/partition/ply_c/connected_components.cpp:
--------------------------------------------------------------------------------
1 | #pragma once
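2 | #include <cstdint>
3 | #include <iostream>
4 | #include <vector>
5 | #include <boost/graph/adjacency_list.hpp>
6 | #include <boost/graph/connected_components.hpp>
7 | #include <boost/graph/graph_traits.hpp>
8 | #include <boost/property_map/property_map.hpp>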
9 |
10 |
11 | using namespace std;
12 | using namespace boost;
13 |
14 | typedef adjacency_list<vecS, vecS, undirectedS> Graph;
15 | typedef typename graph_traits< Graph >::adjacency_iterator adjacency_iterator;
16 |
17 | void connected_components(const uint32_t n_ver, const uint32_t n_edg
18 | , const uint32_t * Eu, const uint32_t * Ev, const char * active_edge
19 | , std::vector<uint32_t> & in_component, std::vector< std::vector<uint32_t> > & components, const uint32_t cutoff)
20 | { //C-style interface
21 |
22 | Graph G(n_ver);
23 | for (uint32_t i_edg = 0; i_edg < n_edg; i_edg++)
24 | {
25 | if (active_edge[i_edg] > 0)
26 | {
27 | add_edge(Eu[i_edg], Ev[i_edg], G);
28 | }
29 | }
30 |
31 | int n_com = connected_components(G, &in_component[0]);
32 |
33 | //cout << "Total number of components: " << n_com << endl;
34 |
35 | std::vector< std::vector<uint32_t> > components_tmp(n_com);
36 | for (uint32_t i_ver = 0; i_ver < n_ver; i_ver++)
37 | {
38 | components_tmp[in_component[i_ver]].push_back(i_ver);
39 | }
40 |
41 | //fuse components to preserve cutoff
42 |
43 | G = Graph(n_ver);
44 | for (uint32_t i_edg = 0; i_edg < n_edg; i_edg++)
45 | {
46 | if (active_edge[i_edg] == 0)
47 | {
48 | add_edge(Eu[i_edg], Ev[i_edg], G);
49 | }
50 | }
51 |
52 | typename graph_traits < Graph >::adjacency_iterator nei_ini, nei_end;
53 | boost::property_map<Graph, boost::vertex_index_t>::type vertex_index_map = get(boost::vertex_index, G);
54 | std::vector<bool> is_fused(n_ver, false);
55 |
56 | int n_com_final = n_com;
57 | for (int i_com = 0; i_com < n_com; i_com++)
58 | {
59 | if (components_tmp[i_com].size() < cutoff)
60 | {//components is too small
61 | //std::cout << i_com << " of size " << components_tmp[i_com].size() << " / " << cutoff << std::endl;
62 | int largest_neigh_comp_value = 0;
63 | int largest_neigh_comp_index = -1;
64 | for (int i_ver_com = 0; i_ver_com < components_tmp[i_com].size(); i_ver_com++)
65 | { //std::cout << " considering node" << components_tmp[i_com][i_ver_com] << std::endl;
66 | boost::tie(nei_ini, nei_end) = adjacent_vertices(vertex(components_tmp[i_com][i_ver_com], G), G);
67 | for (graph_traits < Graph >::adjacency_iterator nei_ite = nei_ini; nei_ite != nei_end; nei_ite++)
68 | {
69 | int candidate_comp = in_component[vertex_index_map(*nei_ite)];
70 | if ((candidate_comp == i_com) || (is_fused[candidate_comp]))
71 | {
72 | continue;
73 | }
74 | //std::cout << " neighbors " << vertex_index_map(*nei_ite) << " in comp " << candidate_comp << "of size " << components_tmp[candidate_comp].size() << std::endl;
75 | if (components_tmp[candidate_comp].size() > largest_neigh_comp_value)
76 | {
77 | largest_neigh_comp_value = components_tmp[candidate_comp].size() ;
78 | largest_neigh_comp_index = candidate_comp;
79 | }
80 | }
81 | }
82 | if (largest_neigh_comp_index >= 0)
83 | {
84 | //std::cout << "best comp = " << largest_neigh_comp_index << " of size " << largest_neigh_comp_value << std::endl;
85 | //we now fuse the two connected components
86 | components_tmp[largest_neigh_comp_index].insert(components_tmp[largest_neigh_comp_index].end(), components_tmp[i_com].begin(), components_tmp[i_com].end());
87 | is_fused[i_com] = 1;
88 | n_com_final--;
89 | }
90 | }
91 | }
92 |
93 | components.resize(n_com_final);
94 | int i_com_index = 0;
95 | for (int i_com = 0; i_com < n_com; i_com++)
96 | {
97 | if (!is_fused[i_com])
98 | {
99 | components[i_com_index] = components_tmp[i_com];
100 | for (uint32_t i_ver_com = 0; i_ver_com < components_tmp[i_com].size(); i_ver_com++)
101 | {
102 | in_component[components_tmp[i_com][i_ver_com]] = i_com_index;
103 | }
104 | i_com_index++;
105 | }
106 | }
107 |
108 |
109 | return;
110 | }
111 |
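`connected_components` first labels the components of the graph restricted to active edges, then fuses every component smaller than `cutoff` into its largest neighbouring component (adjacency taken through the inactive edges) that has not itself been fused; a single pass is made, so a small component whose only neighbours were already fused stays on its own. Below is a Python sketch of the same procedure, assuming SciPy's `scipy.sparse.csgraph.connected_components` for the labelling; the name `components_with_cutoff` is hypothetical, and component numbering may differ from the Boost version.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components


def components_with_cutoff(n_ver, eu, ev, active_edge, cutoff):
    """Components over active edges; those below `cutoff` vertices are fused into a neighbour."""
    eu, ev = np.asarray(eu), np.asarray(ev)
    active = np.asarray(active_edge) > 0
    adj = coo_matrix((np.ones(active.sum()), (eu[active], ev[active])), shape=(n_ver, n_ver))
    n_com, in_component = connected_components(adj, directed=False)
    comps = [list(np.flatnonzero(in_component == c)) for c in range(n_com)]
    # Adjacency through the inactive edges, used to find fusion candidates.
    neighbours = [[] for _ in range(n_ver)]
    for u, v in zip(eu[~active], ev[~active]):
        neighbours[u].append(v)
        neighbours[v].append(u)
    is_fused = np.zeros(n_com, dtype=bool)
    for c in range(n_com):
        if len(comps[c]) >= cutoff:
            continue
        best, best_size = -1, 0  # -1 means no eligible neighbouring component
        for i in comps[c]:
            for j in neighbours[i]:
                cand = in_component[j]
                if cand == c or is_fused[cand]:
                    continue
                if len(comps[cand]) > best_size:
                    best, best_size = cand, len(comps[cand])
        if best >= 0:
            comps[best].extend(comps[c])
            is_fused[c] = True
    kept = [comps[c] for c in range(n_com) if not is_fused[c]]
    new_in_component = np.empty(n_ver, dtype=np.int64)
    for new_id, comp in enumerate(kept):
        new_in_component[comp] = new_id
    return kept, new_in_component


# 6 vertices: active edges 0-1, 1-2, 3-4; inactive edges 2-3, 4-5
eu, ev, active = [0, 1, 3, 2, 4], [1, 2, 4, 3, 5], [1, 1, 1, 0, 0]
kept, in_component = components_with_cutoff(6, eu, ev, active, cutoff=2)
```

Here the singleton `{5}` is below the cutoff and is absorbed by `{3, 4}`, its only neighbour through an inactive edge.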
--------------------------------------------------------------------------------
/partition/ply_c/random_subgraph.cpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <cstdint>
3 | #include <cstdlib>
4 | #include <queue>
5 | #include <boost/graph/adjacency_list.hpp>
6 | #include <boost/graph/graph_traits.hpp>
7 |
8 | using namespace std;
9 | using namespace boost;
10 |
11 | namespace subgraph {
12 |
13 | typedef adjacency_list<vecS, vecS, undirectedS> Graph;
14 |
15 | typedef typename boost::graph_traits< Graph >::vertex_descriptor VertexDescriptor;
16 |
17 | typedef typename boost::property_map< Graph, boost::vertex_index_t>::type VertexIndexMap;
18 |
19 | typedef typename graph_traits < Graph >::adjacency_iterator Adjacency_iterator;
20 |
21 |
22 | void random_subgraph(const int n_ver, const int n_edg, const uint32_t * Eu, const uint32_t * Ev, int subgraph_size
23 | , uint8_t * selected_edges, uint8_t * selected_vertices)
24 |
25 | { //C-style interface
26 |
27 | if (n_ver < subgraph_size)
28 | {
29 | for (uint32_t i_edg = 0; i_edg < n_edg; i_edg++)
30 | {
31 | selected_edges[i_edg] = 1;
32 | }
33 | for (uint32_t i_ver = 0; i_ver < n_ver; i_ver++)
34 | {
35 | selected_vertices[i_ver] = 1;
36 | }
37 | return;
38 | }
39 |
40 | Graph G(n_ver);
41 |
42 | VertexIndexMap vertex_index_map = get(boost::vertex_index, G);
43 | VertexDescriptor ver_current;
44 | Adjacency_iterator ite_ver_adj,ite_ver_adj_end;
45 | int node_seen = 0, seed_index;
46 | queue<VertexDescriptor> ver_queue;
47 |
48 | for (uint32_t i_edg = 0; i_edg < n_edg; i_edg++)
49 | { //building graph
50 | add_edge(vertex(Eu[i_edg],G), vertex(Ev[i_edg],G), G);
51 | }
52 |
53 | while(node_seen < subgraph_size)
54 | {
55 | //add seed vertex
56 | seed_index = rand() % n_ver;
57 | if (selected_vertices[seed_index])
58 | {
59 | continue;
60 | }
61 | ver_queue.push(vertex(seed_index,G));
62 | selected_vertices[vertex_index_map(ver_queue.front())] = 1;
63 | node_seen = node_seen + 1;
64 |
65 | while(!ver_queue.empty())
66 | {
67 | //pop the top of the queue and mark it as seen
68 | ver_current = ver_queue.front();
69 | ver_queue.pop();
70 |
71 | //add the neighbors of that vertex
72 | for (tie(ite_ver_adj, ite_ver_adj_end) = adjacent_vertices(ver_current, G); ite_ver_adj != ite_ver_adj_end; ite_ver_adj++)
73 | {
74 | int i_ver_adj = vertex_index_map(*ite_ver_adj);
75 | if ((selected_vertices[i_ver_adj]==0) && (node_seen <= subgraph_size))
76 | {//vertex not already seen
77 | node_seen = node_seen + 1;
78 | selected_vertices[i_ver_adj] = 1;
79 | ver_queue.push(*ite_ver_adj);
80 | }
81 |
82 | if (node_seen >= subgraph_size)
83 | {//enough vertices
84 | break;
85 | }
86 |
87 | }
88 | }
89 | }
90 |
91 | for (int i_edg = 0; i_edg < n_edg; i_edg++)
92 | { //add edges between selected vertices
93 | selected_edges[i_edg] = selected_vertices[vertex_index_map(vertex(Eu[i_edg],G))]
94 | * selected_vertices[vertex_index_map(vertex(Ev[i_edg],G))];
95 | }
96 | return;
97 | }
98 | }
99 |
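`random_subgraph` draws random seed vertices and grows breadth-first balls around them until about `subgraph_size` vertices are selected, then keeps only the edges whose two endpoints are selected. The following Python sketch illustrates that idea; the function name is hypothetical, Python's `random` module stands in for `rand()`, and this version stops as soon as exactly `subgraph_size` vertices are selected.

```python
import random
from collections import deque


def random_subgraph(n_ver, eu, ev, subgraph_size, rng=random):
    """Select about `subgraph_size` vertices by BFS from random seeds (sketch).

    Returns (selected_edges, selected_vertices) as 0/1 lists.
    """
    if n_ver < subgraph_size:  # graph too small: keep everything
        return [1] * len(eu), [1] * n_ver
    adjacency = [[] for _ in range(n_ver)]
    for u, v in zip(eu, ev):
        adjacency[u].append(v)
        adjacency[v].append(u)
    selected = [0] * n_ver
    n_seen = 0
    while n_seen < subgraph_size:
        seed = rng.randrange(n_ver)
        if selected[seed]:
            continue  # already selected: draw another seed
        selected[seed] = 1
        n_seen += 1
        queue = deque([seed])
        while queue and n_seen < subgraph_size:
            current = queue.popleft()
            for nb in adjacency[current]:
                if n_seen >= subgraph_size:
                    break
                if not selected[nb]:
                    selected[nb] = 1
                    n_seen += 1
                    queue.append(nb)
    # Keep the edges whose two endpoints were selected.
    selected_edges = [selected[u] * selected[v] for u, v in zip(eu, ev)]
    return selected_edges, selected


eu, ev = list(range(9)), list(range(1, 10))  # path graph 0-1-...-9
edges_sel, verts_sel = random_subgraph(10, eu, ev, 4, rng=random.Random(0))
```

Growing connected balls rather than sampling vertices independently keeps the sampled subgraph's edges, which is what makes it usable as a training crop of the graph.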
--------------------------------------------------------------------------------
/partition/visualize.py:
--------------------------------------------------------------------------------
1 | """
2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 | http://arxiv.org/abs/1711.09869
4 | 2017 Loic Landrieu, Martin Simonovsky
5 |
6 | This script outputs various ply files to visualize the different steps
7 | """
8 | import os.path
9 | import numpy as np
10 | import argparse
11 | import sys
12 | sys.path.append("./partition/")
13 | from plyfile import PlyData, PlyElement
14 | from provider import *
15 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
16 | parser.add_argument('--dataset', default='s3dis', help='dataset name: sema3d|s3dis|vkitti|custom_dataset')
17 | parser.add_argument('--ROOT_PATH', default='/mnt/bigdrive/loic/S3DIS', help='folder containing the ./data folder')
18 | parser.add_argument('--res_file', default='../models/cv1/predictions_val', help='folder containing the results')
19 | parser.add_argument('--supervized_partition',type=int, default=0)
20 | parser.add_argument('--file_path', default='Area_1/conferenceRoom_1', help='file to output (must include the area / set in its path)')
21 | parser.add_argument('--upsample', default=0, type=int, help='if 1, upsample the prediction to the original cloud (if the file is huge, this can take a very long time and use a lot of memory - avoid on sema3d)')
22 | parser.add_argument('--ver_batch', default=0, type=int, help='Batch size for reading large files')
23 | parser.add_argument('--output_type', default='igfpres', help='which cloud to output: i = input rgb pointcloud \
24 | , g = ground truth, f = geometric features, p = partition, r = prediction result \
25 | , e = error, s = SPG')
26 | args = parser.parse_args()
27 | #---path to data---------------------------------------------------------------
28 | #root of the data directory
29 | root = args.ROOT_PATH+'/'
30 | rgb_out = 'i' in args.output_type
31 | gt_out = 'g' in args.output_type
32 | fea_out = 'f' in args.output_type
33 | par_out = 'p' in args.output_type
34 | res_out = 'r' in args.output_type
35 | err_out = 'e' in args.output_type
36 | spg_out = 's' in args.output_type
37 | folder = os.path.split(args.file_path)[0] + '/'
38 | file_name = os.path.split(args.file_path)[1]
39 |
40 | if args.dataset == 's3dis':
41 | n_labels = 13
42 | if args.dataset == 'sema3d':
43 | n_labels = 8
44 | if args.dataset == 'vkitti':
45 | n_labels = 13
46 | if args.dataset == 'custom_dataset':
47 | n_labels = 10
48 | #---load the values------------------------------------------------------------
49 | fea_file = root + "features/" + folder + file_name + '.h5'
50 | if not os.path.isfile(fea_file) or args.supervized_partition:
51 | fea_file = root + "features_supervision/" + folder + file_name + '.h5'
52 | spg_file = root + "superpoint_graphs/" + folder + file_name + '.h5'
53 | ply_folder = root + "clouds/" + folder
54 | ply_file = ply_folder + file_name
55 | res_file = args.res_file + '.h5'
56 |
57 | if not os.path.isdir(root + "clouds/"):
58 | os.mkdir(root + "clouds/" )
59 | if not os.path.isdir(ply_folder ):
60 | os.mkdir(ply_folder)
61 | if (not os.path.isfile(fea_file)) :
62 | raise ValueError("%s does not exist and is needed" % fea_file)
63 |
64 | geof, xyz, rgb, graph_nn, labels = read_features(fea_file)
65 |
66 | if (par_out or res_out or err_out or spg_out) and (not os.path.isfile(spg_file)):
67 | raise ValueError("%s does not exist and is needed to output the partition, result, error or SPG ply" % spg_file)
68 | elif os.path.isfile(spg_file):
69 | graph_spg, components, in_component = read_spg(spg_file)
70 | if res_out or err_out:
71 | if not os.path.isfile(res_file):
72 | raise ValueError("%s does not exist and is needed to output the result ply" % res_file)
73 | try:
74 | pred_red = np.array(h5py.File(res_file, 'r').get(folder + file_name))
75 | if (len(pred_red) != len(components)):
76 | raise ValueError("It looks like the spg is not adapted to the result file")
77 | pred_full = reduced_labels2full(pred_red, components, len(xyz))
78 | except OSError:
79 | raise ValueError("%s does not exist in %s" % (folder + file_name, res_file))
80 | #---write the output clouds----------------------------------------------------
81 | if rgb_out:
82 | print("writing the RGB file...")
83 | write_ply(ply_file + "_rgb.ply", xyz, rgb)
84 |
85 | if gt_out:
86 | print("writing the GT file...")
87 | prediction2ply(ply_file + "_GT.ply", xyz, labels, n_labels, args.dataset)
88 |
89 | if fea_out:
90 | print("writing the features file...")
91 | geof2ply(ply_file + "_geof.ply", xyz, geof)
92 |
93 | if par_out:
94 | print("writing the partition file...")
95 | partition2ply(ply_file + "_partition.ply", xyz, components)
96 |
97 | if res_out and not bool(args.upsample):
98 | print("writing the prediction file...")
99 | prediction2ply(ply_file + "_pred.ply", xyz, pred_full+1, n_labels, args.dataset)
100 |
101 | if err_out:
102 | print("writing the error file...")
103 | error2ply(ply_file + "_err.ply", xyz, rgb, labels, pred_full+1)
104 |
105 | if spg_out:
106 | print("writing the SPG file...")
107 | spg2ply(ply_file + "_spg.ply", graph_spg)
108 |
109 | if res_out and bool(args.upsample):
110 | if args.dataset=='s3dis':
111 | data_file = root + 'data/' + folder + file_name + '/' + file_name + ".txt"
112 | xyz_up, rgb_up = read_s3dis_format(data_file, False)
113 | elif args.dataset=='sema3d':#really not recommended unless you are very confident in your hardware
114 | data_file = root + 'data/' + folder + file_name + ".txt"
115 | xyz_up, rgb_up = read_semantic3d_format(data_file, 0, '', 0, args.ver_batch)
116 | elif args.dataset=='custom_dataset':
117 | data_file = root + 'data/' + folder + file_name + ".ply"
118 | xyz_up, rgb_up = read_ply(data_file)
119 | del rgb_up
120 | pred_up = interpolate_labels(xyz_up, xyz, pred_full, args.ver_batch)
121 | print("writing the upsampled prediction file...")
122 | prediction2ply(ply_file + "_pred_up.ply", xyz_up, pred_up+1, n_labels, args.dataset)
123 |
124 |
--------------------------------------------------------------------------------
/partition/write_Semantic3d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
5 | http://arxiv.org/abs/1711.09869
6 | 2017 Loic Landrieu, Martin Simonovsky
7 |
8 | Call this script once the partition and inference have been run, to upsample
9 | the predictions to the original point clouds.
10 | """
11 | import os.path
12 | import glob
13 | import numpy as np
14 | import argparse
15 | from provider import *
16 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
17 | parser.add_argument('--SEMA3D_PATH', default='datasets/semantic3D')
18 | parser.add_argument('--odir', default='./results/semantic3d', help='Directory to store results')
19 | parser.add_argument('--ver_batch', default=5000000, type=int, help='Batch size for reading large files')
20 | parser.add_argument('--db_test_name', default='testred')
21 | args = parser.parse_args()
22 | #---path to data---------------------------------------------------------------
23 | #root of the data directory
24 | root = args.SEMA3D_PATH+'/'
25 | #list of subfolders to be processed
26 | if args.db_test_name == 'testred':
27 | area = 'test_reduced/'
28 | elif args.db_test_name == 'testfull':
29 | area = 'test_full/'
30 | #------------------------------------------------------------------------------
31 | print("=================\n " + area + "\n=================")
32 | data_folder = root + "data/" + area
33 | fea_folder = root + "features/" + area
34 | spg_folder = root + "superpoint_graphs/" + area
35 | res_folder = './' + args.odir + '/'
36 | labels_folder = root + "labels/" + area
37 | if not os.path.isdir(data_folder):
38 | raise ValueError("%s does not exist" % data_folder)
39 | if not os.path.isdir(fea_folder):
40 | raise ValueError("%s does not exist" % fea_folder)
41 | if not os.path.isdir(res_folder):
42 | raise ValueError("%s does not exist" % res_folder)
43 | if not os.path.isdir(root + "labels/"):
44 | os.mkdir(root + "labels/")
45 | if not os.path.isdir(labels_folder):
46 | os.mkdir(labels_folder)
47 | try:
48 | res_file = h5py.File(res_folder + 'predictions_' + args.db_test_name + '.h5', 'r')
49 | except OSError:
50 | raise ValueError("%s does not exist" % (res_folder + 'predictions_' + args.db_test_name + '.h5'))
51 |
52 | files = glob.glob(data_folder+"*.txt")
53 | if (len(files) == 0):
54 | raise ValueError('%s is empty' % data_folder)
55 | n_files = len(files)
56 | i_file = 0
57 | for file in files:
58 | file_name = os.path.splitext(os.path.basename(file))[0]
59 | file_name_short = '_'.join(file_name.split('_')[:2])
60 | data_file = data_folder + file_name + ".txt"
61 | fea_file = fea_folder + file_name_short + '.h5'
62 | spg_file = spg_folder + file_name_short + '.h5'
63 | label_file = labels_folder + file_name_short + ".labels"
64 | i_file = i_file + 1
65 | print(str(i_file) + " / " + str(n_files) + "---> "+file_name_short)
66 | print(" reading the subsampled file...")
67 | geof, xyz, rgb, graph_nn, l = read_features(fea_file)
68 | graph_sp, components, in_component = read_spg(spg_file)
69 | n_ver = xyz.shape[0]
70 | del geof, rgb, graph_nn, l, graph_sp
71 | labels_red = np.array(res_file.get(area + file_name_short))
72 | print(" upsampling...")
73 | labels_full = reduced_labels2full(labels_red, components, n_ver)
74 | labels_ups = interpolate_labels_batch(data_file, xyz, labels_full, args.ver_batch)
75 | np.savetxt(label_file, labels_ups+1, delimiter=' ', fmt='%d') # shift to 1-based Semantic3D class ids (0 = unlabeled)
76 |
--------------------------------------------------------------------------------
/supervized_partition/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loicland/superpoint_graph/0209777339327c9b327b6947af6c89b20bb45981/supervized_partition/__init__.py
--------------------------------------------------------------------------------
/supervized_partition/evaluate_partition.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Thu Oct 11 10:12:49 2018
5 |
6 | @author: landrieuloic
7 | """
8 |
9 | """
10 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
11 | http://arxiv.org/abs/1711.09869
12 | 2017 Loic Landrieu, Martin Simonovsky
13 | """
14 | import glob, os
15 | import argparse
16 | import numpy as np
17 | import sys
18 | import ast
19 | import csv
20 | import h5py
21 | sys.path.append("./learning")
22 | from metrics import *
23 |
24 | parser = argparse.ArgumentParser(description='Evaluation of the learned partition (S3DIS or vKITTI3D)')
25 |
26 | parser.add_argument('--odir', default='./results_partition/', help='Root directory of the partition results')
27 | parser.add_argument('--folder', default='', help='Subfolder of odir/<dataset>/ containing the results, e.g. best')
28 | parser.add_argument('--dataset', default='s3dis', help='Dataset name: s3dis or vkitti')
29 | parser.add_argument('--cvfold', default='123456', help='Folds to evaluate, as a string of fold indices (e.g. 123456)')
30 |
31 | args = parser.parse_args()
32 | args.odir = args.odir + args.dataset + '/'
33 |
34 |
35 | root = args.odir + args.folder + '/'
36 |
37 | if args.dataset == 's3dis':
38 | fold_size = [44,40,23,49,68,48]
39 | files = glob.glob(root + 'cv{}'.format(args.cvfold[0]) + '/res*.h5')
40 | n_classes= 13
41 | elif args.dataset == 'vkitti':
42 | fold_size = [15,15,15,15,15,15]
43 | files = glob.glob(root + '0{}'.format(args.cvfold[0]) + '/res*.h5')
44 | n_classes = 13
45 |
46 | file_result_txt = open(args.odir + args.folder + '/results' + '.txt',"w")
47 | file_result_txt.write(" N \t ASA \t BR \t BP\n")
48 |
49 |
50 | C_classes = np.zeros((n_classes,n_classes))
51 | C_BR = np.zeros((2,2))
52 | C_BP = np.zeros((2,2))
53 | N_sp = 0
54 | N_pc = 0
55 |
56 | for i_fold in range(len(args.cvfold)):
57 | fold = int(args.cvfold[i_fold])
58 | if args.dataset == 's3dis':
59 | base_name = root + 'cv{}'.format(fold)
60 | elif args.dataset == 'vkitti':
61 | base_name = root + '0{}'.format(fold)
62 |
63 | try:
64 | file_name = base_name + '/res.h5'
65 | res_file = h5py.File(file_name, 'r')
66 | except OSError:
67 | raise NameError("Can't find result file %s" % file_name)
68 |
69 | c_classes = np.array(res_file["confusion_matrix_classes"])
70 | c_BP = np.array(res_file["confusion_matrix_BP"])
71 | c_BR = np.array(res_file["confusion_matrix_BR"])
72 | n_sp = np.array(res_file["n_clusters"])
73 | print("Fold %d : \t n_sp = %5.1f \t ASA = %3.2f %% \t BR = %3.2f %% \t BP = %3.2f %%" % \
74 | (fold, n_sp, 100 * c_classes.trace() / c_classes.sum(), 100 * c_BR[1,1] / (c_BR[1,1] + c_BR[1,0]),100 * c_BP[1,1] / (c_BP[1,1] + c_BP[0,1]) ))
75 | C_classes += c_classes
76 | C_BR += c_BR
77 | C_BP += c_BP
78 | N_sp += n_sp * fold_size[i_fold]
79 | N_pc += fold_size[i_fold]
80 |
81 | if N_sp>0:
82 | print("\nOverall : \t n_sp = %5.1f \t ASA = %3.2f %% \t BR = %3.2f %% \t BP = %3.2f %%\n" % \
83 | (N_sp/N_pc, 100 * C_classes.trace() / C_classes.sum(), 100 * C_BR[1,1] / (C_BR[1,1] + C_BR[1,0]),100 * C_BP[1,1] / (C_BP[1,1] + C_BP[0,1]) ))
84 | file_result_txt.write("%4.1f \t %3.2f \t %3.2f \t %3.2f \n" % (N_sp/N_pc, 100 * C_classes.trace() / C_classes.sum(), 100 * C_BR[1,1] / (C_BR[1,1] + C_BR[1,0]),100 * C_BP[1,1] / (C_BP[1,1] + C_BP[0,1]) ))
85 |
86 | file_result_txt.close()
--------------------------------------------------------------------------------
/supervized_partition/folderhierarchy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | DIR_PATH = os.path.dirname(os.path.realpath(__file__))
5 | sys.path.insert(0, os.path.join(DIR_PATH, '..'))
6 |
7 | class FolderHierachy:
8 | SPG_FOLDER = "superpoint_graphs"
9 | EMBEDDINGS_FOLDER = "embeddings"
10 | SCALAR_FOLDER = "scalars"
11 | MODEL_FILE = "model.pth.tar"
12 |
13 | def __init__(self,outputdir,dataset_name,root_dir,cv_fold):
14 | self._root = root_dir
15 | if dataset_name=='s3dis':
16 | self._outputdir = os.path.join(outputdir,'cv' + str(cv_fold))
17 | self._folders = ["Area_1/", "Area_2/", "Area_3/", "Area_4/", "Area_5/", "Area_6/"]
18 | elif dataset_name=='sema3d':
19 | self._outputdir = os.path.join(outputdir,'best')
20 | self._folders = ["train/", "test_reduced/", "test_full/"]
21 | elif dataset_name=='vkitti':
22 | self._outputdir = os.path.join(outputdir, 'cv' + str(cv_fold))
23 | self._folders = ["01/", "02/", "03/", "04/", "05/", "06/"]
24 |
25 | if not os.path.exists(self._outputdir):
26 | os.makedirs(self._outputdir)
27 |
28 | self._spg_folder = self._create_folder(self.SPG_FOLDER)
29 | self._emb_folder = self._create_folder(self.EMBEDDINGS_FOLDER)
30 | self._scalars = self._create_folder(self.SCALAR_FOLDER)
31 |
32 | @property
33 | def outputdir(self): return self._outputdir
34 |
35 | @property
36 | def emb_folder(self): return self._emb_folder
37 |
38 | @property
39 | def spg_folder(self): return self._spg_folder
40 |
41 | @property
42 | def scalars(self): return self._scalars
43 |
44 | @property
45 | def model_path(self): return os.path.join(self._outputdir, self.MODEL_FILE)
46 |
47 | def _create_folder(self,property_name):
48 | folder = os.path.join(self._root , property_name )
49 | if not os.path.isdir(folder):
50 | os.mkdir(folder)
51 | return folder
--------------------------------------------------------------------------------
/supervized_partition/generate_partition.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | import sys
4 | import os
5 | import torch
6 | import glob
7 | import torchnet as tnt
8 | import functools
9 | import tqdm
10 | from multiprocessing import Pool
11 |
12 | DIR_PATH = os.path.dirname(os.path.realpath(__file__))
13 | sys.path.insert(0, os.path.join(DIR_PATH, '..'))
14 |
15 | from supervized_partition import supervized_partition
16 | from supervized_partition import graph_processing
17 | from supervized_partition import losses
18 | from partition import provider
19 | from partition import graphs
20 | from learning import pointnet
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser(description='Partition large scale point clouds using cut-pursuit')
25 |
26 | parser.add_argument('--modeldir', help='Folder where the saved model lies', required=True)
27 | parser.add_argument('--cuda', default=0, type=int, help='Bool, use cuda')
28 | parser.add_argument('--input_folder', type=str,
29 | help='Folder containing preprocessed point clouds ready for segmentation', required=True)
30 | parser.add_argument('--output_folder', default="", type=str, help='Folder that will contain the output')
31 | parser.add_argument('--overwrite', default=1, type=int, help='Overwrite existing partition')
32 | parser.add_argument('--nworkers', default=5, type=int,
33 | help='Num subprocesses to use for generating the SPGs')
34 |
35 | args = parser.parse_args()
36 | return args
37 |
38 |
39 | def load_model(model_dir, cuda):
40 | checkpoint = torch.load(os.path.join(model_dir, supervized_partition.FolderHierachy.MODEL_FILE))
41 | training_args = checkpoint['args']
42 | training_args.cuda = cuda # override cuda
43 | model = supervized_partition.create_model(training_args)
44 | model.load_state_dict(checkpoint['state_dict'])
45 | model.eval()
46 | return model, training_args
47 |
48 |
49 | def get_dataloader(input_folder, training_args):
50 | file_list = glob.glob(os.path.join(input_folder, '*.h5'))
51 | if not file_list:
52 | raise ValueError("Empty input folder: %s" % input_folder)
53 | dataset = tnt.dataset.ListDataset(file_list,
54 | functools.partial(graph_processing.graph_loader, train=False, args=training_args, db_path=""))
55 | loader = torch.utils.data.DataLoader(dataset, batch_size=1,
56 | collate_fn=graph_processing.graph_collate)
57 | return loader
58 |
59 |
60 | def get_embedder(args):
61 | if args.learned_embeddings and args.ptn_embedding == 'ptn':
62 | ptnCloudEmbedder = pointnet.LocalCloudEmbedder(args)
63 | elif 'geof' in args.ver_value:
64 | ptnCloudEmbedder = graph_processing.spatialEmbedder(args)
65 | else:
66 | raise NameError('Do not know model ' + str(args.learned_embeddings))
67 | return ptnCloudEmbedder
68 |
69 |
70 | def get_num_classes(args):
71 | # Decide on the dataset
72 | if args.dataset == 's3dis':
73 | dbinfo = graph_processing.get_s3dis_info(args)
74 | elif args.dataset == 'sema3d':
75 | dbinfo = graph_processing.get_sema3d_info(args)
76 | elif args.dataset == 'vkitti':
77 | dbinfo = graph_processing.get_vkitti_info(args)
78 | else:
79 | raise NotImplementedError('Unknown dataset ' + args.dataset)
80 | return dbinfo["classes"]
81 |
82 |
83 | def process(data_tuple, model, output_folder, training_args, overwrite):
84 | fname, edg_source, edg_target, is_transition, labels, objects, clouds_data, xyz = data_tuple
85 | spg_file = os.path.join(output_folder, fname[0])
86 | logging.info("\nGenerating SPG file %s...", spg_file)
87 | if os.path.exists(spg_file) and not overwrite:
88 | logging.info("Already exists, skipping")
89 | return
90 | elif not os.path.exists(os.path.dirname(spg_file)):
91 | os.makedirs(os.path.dirname(spg_file))
92 |
93 | if training_args.cuda:
94 | is_transition = is_transition.to('cuda', non_blocking=True)
95 | objects = objects.to('cuda', non_blocking=True)
96 | clouds, clouds_global, nei = clouds_data
97 | clouds_data = (clouds.to('cuda', non_blocking=True), clouds_global.to('cuda', non_blocking=True), nei)
98 |
99 | ptnCloudEmbedder = get_embedder(training_args)
100 | num_classes = get_num_classes(training_args)
101 |
102 | embeddings = ptnCloudEmbedder.run_batch(model, *clouds_data, xyz)
103 |
104 | diff = losses.compute_dist(embeddings, edg_source, edg_target, training_args.dist_type)
105 |
106 | pred_components, pred_in_component = losses.compute_partition(
107 | training_args, embeddings, edg_source, edg_target, diff, xyz)
108 |
109 | graph_sp = graphs.compute_sp_graph(xyz, 100, pred_in_component, pred_components, labels, num_classes)
110 |
111 | provider.write_spg(spg_file, graph_sp, pred_components, pred_in_component)
112 |
113 |
114 | def main():
115 | logging.getLogger().setLevel(logging.INFO) # set to logging.DEBUG to allow for more prints
116 | args = parse_args()
117 | model, training_args = load_model(args.modeldir, args.cuda)
118 | dataloader = get_dataloader(args.input_folder, training_args)
119 | workers = max(args.nworkers, 1)
120 |
121 | output_folder = args.output_folder
122 | if not output_folder:
123 | # By default assumes that it follows the S3DIS folder structure
124 | output_folder = os.path.join(args.input_folder, '../..', supervized_partition.FolderHierachy.SPG_FOLDER)
125 | if not os.path.exists(output_folder):
126 | os.makedirs(output_folder)
127 |
128 | if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
129 | dataloader = tqdm.tqdm(dataloader, ncols=100)
130 | with torch.no_grad():
131 | processing_function = functools.partial(
132 | process, model=model, output_folder=output_folder, training_args=training_args, overwrite=args.overwrite)
133 | with Pool(workers) as p:
134 | p.map(processing_function, dataloader)
135 |
136 | logging.info("DONE for %s" % args.input_folder)
137 |
138 |
139 | if __name__ == "__main__":
140 | main()
141 |
--------------------------------------------------------------------------------
/supervized_partition/losses.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Sep 26 13:56:33 2018
5 |
6 | @author: landrieuloic
7 |
8 | """
9 | import os
10 | import sys
11 | import math
12 | import numpy as np
13 | import torch
14 |
15 | DIR_PATH = os.path.dirname(os.path.realpath(__file__))
16 | sys.path.insert(0, os.path.join(DIR_PATH, '..'))
17 | sys.path.append(os.path.join(DIR_PATH,"../partition/cut-pursuit/src"))
18 |
19 | from partition.provider import *
20 | from partition.ply_c import libply_c
21 |
22 | import libcp
23 |
24 | def zhang(x, lam, dist_type):
25 | if dist_type == 'euclidian' or dist_type == 'scalar':
26 | beta = 1
27 | elif dist_type == 'intrinsic':
28 | beta = 1.0471975512
29 | return torch.clamp(-lam * x + lam * beta, min = 0)
30 |
31 | def compute_dist(embeddings, edg_source, edg_target, dist_type):
32 | if dist_type == 'euclidian':
33 | dist = ((embeddings[edg_source,:] - embeddings[edg_target,:])**2).sum(1)
34 | elif dist_type == 'intrinsic':
35 | smoothness = 0.999
36 | dist = (torch.acos((embeddings[edg_source,:] * embeddings[edg_target,:]).sum(1) * smoothness)-np.arccos(smoothness)) \
37 | / (np.arccos(-smoothness)-np.arccos(smoothness)) * 3.141592
38 | elif dist_type == 'scalar':
39 | dist = (embeddings[edg_source,:] * embeddings[edg_target,:]).sum(1)-1
40 | else:
41 | raise ValueError(" %s is an unknown argument of parameter --dist_type" % (dist_type))
42 | return dist
43 |
44 | def compute_loss(args, diff, is_transition, weights_loss):
45 | intra_edg = is_transition==0
46 | if 'tv' in args.loss:
47 | loss1 = (weights_loss[intra_edg] * (torch.sqrt(diff[intra_edg]+1e-10))).sum()
48 | elif 'laplacian' in args.loss:
49 | loss1 = (weights_loss[intra_edg] * (diff[intra_edg])).sum()
50 | elif 'TVH' in args.loss:
51 | delta = 0.2
52 | loss1 = delta * (weights_loss[intra_edg] * (torch.sqrt(1+diff[intra_edg]/delta**2)-1)).sum()
53 | else:
54 | raise ValueError(" %s is an unknown argument of parameter --loss" % (args.loss))
55 |
56 | inter_edg = is_transition==1
57 |
58 | if 'zhang' in args.loss:
59 | loss2 = (zhang(torch.sqrt(diff[inter_edg]+1e-10), weights_loss[inter_edg], args.dist_type)).sum()
60 | elif 'TVminus' in args.loss:
61 | loss2 = (torch.sqrt(diff[inter_edg]+1e-10) * weights_loss[inter_edg]).sum()
62 |
63 | #return loss1/ weights_loss.sum(), loss2/ weights_loss.sum()
64 | return loss1, loss2
65 |
66 |
67 | def compute_partition(args, embeddings, edg_source, edg_target, diff, xyz=0):
68 | edge_weight = np.ones_like(edg_source).astype('f4')
69 | if args.edge_weight_threshold>0:
70 | edge_weight[diff>1]=args.edge_weight_threshold
71 | if args.edge_weight_threshold<0:
72 | edge_weight = torch.exp(diff * args.edge_weight_threshold).detach().cpu().numpy()/np.exp(args.edge_weight_threshold)
73 |
74 | ver_value = np.zeros((embeddings.shape[0],0), dtype='f4')
75 | use_spatial = 0
76 | ver_value = np.hstack((ver_value,embeddings.detach().cpu().numpy()))
77 | if args.spatial_emb>0:
78 | ver_value = np.hstack((ver_value, args.spatial_emb * xyz))# * math.sqrt(args.reg_strength)))
79 | #ver_value = xyz * args.spatial_emb
80 | use_spatial = 1
81 |
82 | pred_components, pred_in_component = libcp.cutpursuit(ver_value, \
83 | edg_source.astype('uint32'), edg_target.astype('uint32'), edge_weight, \
84 | args.reg_strength / (4 * args.k_nn_adj), cutoff=args.CP_cutoff, spatial = use_spatial, weight_decay = 0.7)
85 | #emb2 = libcp.cutpursuit2(ver_value, edg_source.astype('uint32'), edg_target.astype('uint32'), edge_weight, args.reg_strength, cutoff=0, spatial =0)
86 | #emb2 = emb2.reshape(ver_value.shape)
87 | #((ver_value-emb2)**2).sum(0)
88 | #cut = pred_in_component[edg_source]!=pred_in_component[edg_target]
89 | return pred_components, pred_in_component
90 |
91 | def compute_weight_loss(args, embeddings, objects, edg_source, edg_target, is_transition, diff, return_partition, xyz=0):
92 |
93 | if args.loss_weight == 'seal' or args.loss_weight == 'crosspartition' or return_partition:
94 | pred_components, pred_in_component = compute_partition(args, embeddings, edg_source, edg_target, diff, xyz)
95 |
96 | if args.loss_weight=='none':
97 | weights_loss = np.ones_like(edg_target).astype('f4')
98 | elif args.loss_weight=='proportional':
99 | weights_loss = np.ones_like(edg_target).astype('f4') * float(len(is_transition)) / (1-is_transition).sum().float()
100 | weights_loss[is_transition.nonzero()] = float(len(is_transition)) / float(is_transition.sum()) * args.transition_factor
101 | weights_loss = weights_loss.cpu().numpy()
102 | elif args.loss_weight=='seal':
103 | weights_loss = compute_weights_SEAL(pred_components, pred_in_component, objects, edg_source, edg_target, is_transition, args.transition_factor)
104 | elif args.loss_weight=='crosspartition':
105 | weights_loss = compute_weights_XPART(pred_components, pred_in_component, objects.cpu().numpy(), edg_source, edg_target, is_transition.cpu().numpy(), args.transition_factor * 2 * args.k_nn_adj, xyz)
106 | else:
107 | raise ValueError(" %s is an unknown argument of parameter --loss_weight" % (args.loss_weight))
108 |
109 | if args.cuda:
110 | weights_loss = torch.from_numpy(weights_loss).cuda()
111 | else:
112 | weights_loss = torch.from_numpy(weights_loss)
113 |
114 | if return_partition:
115 | return weights_loss, pred_components, pred_in_component
116 | else:
117 | return weights_loss
118 |
119 | def compute_weights_SEAL(pred_components, pred_in_component, objects, edg_source, edg_target, is_transition, transition_factor):
120 |
121 | SEAL_weights = np.ones((len(edg_source),), dtype='float32')
122 | w_per_component = np.empty((len(pred_components),), dtype='uint32')
123 | for i_com in range(len(pred_components)):
124 | w_per_component[i_com] = len(pred_components[i_com]) - mode(objects[pred_components[i_com]], only_frequency=True)
125 | SEAL_weights[is_transition.nonzero()] += np.stack(\
126 | (w_per_component[pred_in_component[edg_source[is_transition.nonzero()]]]
127 | , w_per_component[pred_in_component[edg_target[is_transition.nonzero()]]])).max(0) * transition_factor# 1 if not transition 1+w otherwise
128 | return SEAL_weights
129 |
130 | def compute_weights_XPART(pred_components, pred_in_component, objects, edg_source, edg_target, is_transition, transition_factor, xyz):
131 |
132 | SEAGL_weights = np.ones((len(edg_source),), dtype='float32')
133 | pred_transition = pred_in_component[edg_source]!=pred_in_component[edg_target]
134 | components_x, in_component_x = libply_c.connected_comp(pred_in_component.shape[0] \
135 | , edg_source.astype('uint32'), edg_target.astype('uint32') \
136 | , (is_transition+pred_transition==0).astype('uint8'), 0)
137 |
138 | edg_transition = is_transition.nonzero()[0]
139 | edg_source_trans = edg_source[edg_transition]
140 | edg_target_trans = edg_target[edg_transition]
141 |
142 | comp_x_weight = [len(c) for c in components_x]
143 | n_compx = len(components_x)
144 |
145 | edg_id = np.min((in_component_x[edg_source_trans],in_component_x[edg_target_trans]),0) * n_compx \
146 | + np.max((in_component_x[edg_source_trans],in_component_x[edg_target_trans]),0)
147 |
148 | edg_id_unique , in_edge_id, sedg_weight = np.unique(edg_id, return_index=True, return_counts=True)
149 |
150 | for i_edg in range(len(in_edge_id)):
151 | i_com_1 = in_component_x[edg_source_trans[in_edge_id[i_edg]]]
152 | i_com_2 = in_component_x[edg_target_trans[in_edge_id[i_edg]]]
153 | weight = min(comp_x_weight[i_com_1], comp_x_weight[i_com_2]) \
154 | / sedg_weight[i_edg] * transition_factor
155 | corresponding_trans_edg = edg_transition[\
156 | ((in_component_x[edg_source_trans]==i_com_1) * (in_component_x[edg_target_trans]==i_com_2) \
157 | + (in_component_x[edg_target_trans]==i_com_1) * (in_component_x[edg_source_trans]==i_com_2))]
158 | SEAGL_weights[corresponding_trans_edg] = SEAGL_weights[corresponding_trans_edg] + weight
159 |
160 | #missed_transition = ((is_transition==1)*(pred_transition==False)+(is_transition==0)*(pred_transition==True)).nonzero()[0]
161 | #missed_transition = ((is_transition==1)*(pred_transition==False)).nonzero()[0]
162 | #SEAGL_weights[missed_transition] = SEAGL_weights[missed_transition] * boosting_factor
163 | #scalar2ply('full_par.ply', xyz,pred_in_component)
164 | #scalar2ply('full_parX.ply', xyz,in_component_x)
165 | #edge_weight2ply2('w.ply', SEAGL_weights, xyz, edg_source, edg_target)
166 | return SEAGL_weights
167 |
168 | def mode(array, only_frequency=False):
169 | """compute the mode and the corresponding frequency of a given distribution"""
170 | u, counts = np.unique(array, return_counts=True)
171 | if only_frequency: return np.amax(counts)
172 | else:
173 | return u[np.argmax(counts)], np.amax(counts)
174 |
175 | def relax_edge_binary(edg_binary, edg_source, edg_target, n_ver, tolerance):
176 | if torch.is_tensor(edg_binary):
177 | relaxed_binary = edg_binary.cpu().numpy().copy()
178 | else:
179 | relaxed_binary = edg_binary.copy()
180 | transition_vertex = np.full((n_ver,), 0, dtype = 'uint8')
181 | for i_tolerance in range(tolerance):
182 | transition_vertex[edg_source[relaxed_binary.nonzero()]] = True
183 | transition_vertex[edg_target[relaxed_binary.nonzero()]] = True
184 | relaxed_binary[transition_vertex[edg_source]>0] = True
185 | relaxed_binary[transition_vertex[edg_target]>0] = True
186 | return relaxed_binary
187 |
188 |
--------------------------------------------------------------------------------
/vKITTI3D.md:
--------------------------------------------------------------------------------
1 | # vKITTI3D
2 |
3 | Download all point clouds and labels from [the vKITTI3D Dataset](https://github.com/VisualComputingInstitute/vkitti3D-dataset) and place extracted training files to `$VKITTI3D_DIR/data/`.
4 |
5 | ## Handcrafted Partition
6 |
7 | Not available for this dataset: the sparsity of the acquisition renders the handcrafted geometric features useless.
8 |
9 | ## Learned Partition
10 |
11 | For the learned partition, run:
12 | ```
13 | python supervized_partition/graph_processing.py --ROOT_PATH $VKITTI3D_DIR --dataset vkitti --voxel_width 0.05 --use_voronoi 1
14 |
15 | for FOLD in 1 2 3 4 5 6; do
16 | python ./supervized_partition/supervized_partition.py --ROOT_PATH $VKITTI3D_DIR --dataset vkitti \
17 | --epochs 50 --test_nth_epoch 10 --cvfold $FOLD --reg_strength 0.5 --spatial_emb 0.02 --batch_size 15 \
18 | --global_feat exyrgb --CP_cutoff 10 --odir results_partition/vkitti/best; \
19 | done;
20 | ```
21 | or use our [trained weights](http://recherche.ign.fr/llandrieu/SPG/vkitti/results_part/pretrained.zip) and the `--resume RESUME` argument:
22 |
23 | ```
24 | python supervized_partition/graph_processing.py --ROOT_PATH $VKITTI3D_DIR --dataset vkitti --voxel_width 0.05 --use_voronoi 1
25 |
26 | for FOLD in 1 2 3 4 5 6; do
27 | python ./supervized_partition/supervized_partition.py --ROOT_PATH $VKITTI3D_DIR --dataset vkitti \
28 | --epochs -1 --test_nth_epoch 10 --cvfold $FOLD --reg_strength 0.5 --spatial_emb 0.02 --batch_size 15 \
29 | --global_feat exyrgb --CP_cutoff 10 --odir results_partition/vkitti/pretrained --resume RESUME; \
30 | done;
31 | ```
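The `--spatial_emb` and `--reg_strength` flags in the commands above feed the cut-pursuit step. Judging from `compute_partition` in `supervized_partition/losses.py`, each point's learned embedding is concatenated with its coordinates scaled by `spatial_emb`, and cut pursuit receives the penalty `reg_strength / (4 * k_nn_adj)`. The sketch below reproduces only that feature assembly with NumPy; it does not call `libcp`, and the helper names as well as the `k_nn_adj = 5` used in the example are hypothetical.

```python
import numpy as np


def assemble_vertex_values(embeddings, xyz, spatial_emb):
    """Concatenate per-point embeddings with scaled coordinates (hypothetical helper).

    Mirrors the `ver_value` construction in compute_partition and returns the
    (n_points, n_features) float32 array together with the `spatial` flag.
    """
    ver_value = np.asarray(embeddings, dtype='f4')
    use_spatial = 0
    if spatial_emb > 0:
        # scale the coordinates by spatial_emb before appending them as extra features
        ver_value = np.hstack((ver_value, spatial_emb * np.asarray(xyz, dtype='f4')))
        use_spatial = 1
    return ver_value, use_spatial


def cut_pursuit_penalty(reg_strength, k_nn_adj):
    """Regularization strength handed to cut pursuit: reg_strength / (4 * k_nn_adj)."""
    return reg_strength / (4 * k_nn_adj)


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(1000, 4)).astype('f4')          # hypothetical 4-D embeddings
    xyz = rng.uniform(0, 10, size=(1000, 3)).astype('f4')  # hypothetical coordinates
    values, spatial = assemble_vertex_values(emb, xyz, spatial_emb=0.02)
    print(values.shape, spatial)                           # (1000, 7) 1
    print(cut_pursuit_penalty(0.5, k_nn_adj=5))            # 0.025
```

A small `spatial_emb` keeps the coordinates from dominating the learned embedding in the cut-pursuit objective.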
32 |
33 | To evaluate the quality of the partition, run:
34 | ```
35 | python supervized_partition/evaluate_partition.py --dataset vkitti --cvfold 123456 --folder best
36 | ```
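`evaluate_partition.py` reads three confusion matrices per fold from `res.h5` and reports the achievable segmentation accuracy (ASA), boundary recall (BR) and boundary precision (BP); the "Overall" line sums the matrices over folds before taking the ratios. The sketch below restates those formulas with the same matrix indexing as the script. The function names are hypothetical and the matrices in the example are made up.

```python
import numpy as np


def partition_scores(c_classes, c_br, c_bp):
    """Return (ASA, BR, BP) in percent, using the indexing of evaluate_partition.py."""
    asa = 100 * np.trace(c_classes) / c_classes.sum()
    br = 100 * c_br[1, 1] / (c_br[1, 1] + c_br[1, 0])  # normalised by row 1
    bp = 100 * c_bp[1, 1] / (c_bp[1, 1] + c_bp[0, 1])  # normalised by column 1
    return asa, br, bp


def overall_scores(folds):
    """Pool folds by summing their matrices first, as the 'Overall' line does.

    `folds` is a list of (c_classes, c_br, c_bp) tuples of NumPy arrays.
    """
    c_classes = sum(f[0] for f in folds)
    c_br = sum(f[1] for f in folds)
    c_bp = sum(f[2] for f in folds)
    return partition_scores(c_classes, c_br, c_bp)


if __name__ == '__main__':
    fold = (np.array([[8., 2.], [0., 10.]]),    # made-up class confusion matrix
            np.array([[90., 10.], [5., 15.]]),  # made-up boundary-recall matrix
            np.array([[90., 10.], [5., 15.]]))  # made-up boundary-precision matrix
    print("ASA %.1f BR %.1f BP %.1f" % partition_scores(*fold))  # ASA 90.0 BR 75.0 BP 60.0
```

Summing matrices before dividing weights each fold by its number of points and edges rather than averaging per-fold percentages.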
37 | ### Training
38 |
39 | Then, reorganize point clouds into superpoints by:
40 | ```
41 | python learning/vkitti_dataset.py --VKITTI_PATH $VKITTI3D_DIR
42 | ```
43 |
44 | To train from scratch, run:
45 | ```
46 | for FOLD in 1 2 3 4 5 6; do \
47 | CUDA_VISIBLE_DEVICES=0 python ./learning/main.py --dataset vkitti --VKITTI_PATH $VKITTI3D_DIR --cvfold $FOLD --epochs 100 \
48 | --lr_steps "[40, 50, 60, 70, 80]" --test_nth_epoch 10 --model_config gru_10_1_1_1_0,f_13 --pc_attribs xyzXYZrgb \
49 | --ptn_nfeat_stn 9 --batch_size 4 --ptn_minpts 15 --spg_augm_order 3 --spg_augm_hardcutoff 256 \
50 | --ptn_widths "[[64,64,128], [64,32,32]]" --ptn_widths_stn "[[32,64], [32,16]]" --loss_weights sqrt \
51 | --use_val_set 1 --odir results/vkitti/best/cv$FOLD; \
52 | done;
53 | ```
54 |
55 | or use our [trained weights](http://recherche.ign.fr/llandrieu/SPG/vkitti/results/pretrained.zip) with the `--resume RESUME` argument:
56 | ```
57 | for FOLD in 1 2 3 4 5 6; do \
58 | CUDA_VISIBLE_DEVICES=0 python ./learning/main.py --dataset vkitti --VKITTI_PATH $VKITTI3D_DIR --cvfold $FOLD --epochs 100 \
59 | --lr_steps "[40, 50, 60, 70, 80]" --test_nth_epoch 10 --model_config gru_10_1_1_1_0,f_13 --pc_attribs xyzXYZrgb \
60 | --ptn_nfeat_stn 9 --batch_size 4 --ptn_minpts 15 --spg_augm_order 3 --spg_augm_hardcutoff 256 \
61 | --ptn_widths "[[64,64,128], [64,32,32]]" --ptn_widths_stn "[[32,64], [32,16]]" --loss_weights sqrt \
62 | --use_val_set 1 --odir results/vkitti/pretrained/cv$FOLD --resume RESUME; \
63 | done;
64 | ```
65 |
66 | Estimate the quality of the semantic segmentation with:
67 | ```
68 | python learning/evaluate.py --dataset vkitti --odir results/vkitti/best --cvfold 123456
69 | ```
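The exact metrics printed by `learning/evaluate.py` are not shown in this document. For reference, the sketch below computes the standard overall accuracy (OA), mean class accuracy and mean intersection-over-union (mIoU) from a single confusion matrix. It is not the script's own code; it assumes rows index ground-truth classes and columns predicted classes, and it skips classes absent from the ground truth.

```python
import numpy as np


def segmentation_scores(cm):
    """Overall accuracy, mean class accuracy and mean IoU, in percent."""
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)              # correctly classified points per class
    gt = cm.sum(axis=1)           # ground-truth points per class (rows)
    pred = cm.sum(axis=0)         # predicted points per class (columns)
    oa = 100 * tp.sum() / cm.sum()
    valid = gt > 0                # ignore classes absent from the ground truth
    macc = 100 * np.mean(tp[valid] / gt[valid])
    union = gt + pred - tp
    miou = 100 * np.mean(tp[valid] / union[valid])
    return oa, macc, miou


if __name__ == '__main__':
    cm = [[5, 1],
          [2, 2]]  # made-up 2-class confusion matrix
    print("OA %.2f mAcc %.2f mIoU %.2f" % segmentation_scores(cm))  # OA 70.00 mAcc 66.67 mIoU 51.25
```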
70 | #### Visualization
71 |
72 | To visualize the results and intermediate steps (on the subsampled graph), use `partition/visualize.py`. For example:
73 | ```
74 | python partition/visualize.py --dataset vkitti --ROOT_PATH $VKITTI3D_DIR --res_file 'results/vkitti/best/cv1/predictions_test' --file_path '01/0001_00000' --output_type ifprs
75 | ```
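Each letter of `--output_type` enables one PLY file in `partition/visualize.py`: `i` RGB, `g` ground truth, `f` geometric features, `p` partition, `r` prediction, `e` error and `s` superpoint graph, all written to `$ROOT_PATH/clouds/<folder>/`. The hypothetical helper below only lists the file names a given flag string should produce; it does not read or write any data. With `--upsample` set, the prediction is written as `_pred_up.ply` instead.

```python
import os

# (flag letter, file suffix) pairs, in the order the flags are defined in visualize.py
SUFFIXES = [
    ('i', '_rgb.ply'),
    ('g', '_GT.ply'),
    ('f', '_geof.ply'),
    ('p', '_partition.ply'),
    ('r', '_pred.ply'),
    ('e', '_err.ply'),
    ('s', '_spg.ply'),
]


def expected_outputs(root_path, file_path, output_type, upsample=False):
    """List the PLY files visualize.py would write for these arguments (hypothetical helper)."""
    folder, file_name = os.path.split(file_path)
    base = os.path.join(root_path, 'clouds', folder, file_name)
    outputs = []
    for letter, suffix in SUFFIXES:
        if letter not in output_type:
            continue
        if letter == 'r' and upsample:
            suffix = '_pred_up.ply'  # upsampled prediction replaces the subsampled one
        outputs.append(base + suffix)
    return outputs


if __name__ == '__main__':
    for path in expected_outputs('vkitti3d', '01/0001_00000', 'ifprs'):
        print(path)
```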
76 |
--------------------------------------------------------------------------------