├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── S3DIS.md ├── S3DIS_fix.diff ├── Semantic3D.md ├── learning ├── __init__.py ├── custom_dataset.py ├── ecc │ ├── GraphConvInfo.py │ ├── GraphConvModule.py │ ├── GraphPoolInfo.py │ ├── GraphPoolModule.py │ ├── __init__.py │ ├── cuda_kernels.py │ ├── test_GraphConvModule.py │ ├── test_GraphPoolModule.py │ └── utils.py ├── evaluate.py ├── graphnet.py ├── main.py ├── metrics.py ├── modules.py ├── pointnet.py ├── s3dis_dataset.py ├── sema3d_dataset.py ├── spg.py └── vkitti_dataset.py ├── partition ├── __init__.py ├── graphs.py ├── partition.py ├── ply_c │ ├── CMakeLists.txt │ ├── FindNumPy.cmake │ ├── __init__.py │ ├── connected_components.cpp │ ├── ply_c.cpp │ └── random_subgraph.cpp ├── provider.py ├── visualize.py └── write_Semantic3d.py ├── supervized_partition ├── __init__.py ├── evaluate_partition.py ├── folderhierarchy.py ├── generate_partition.py ├── graph_processing.py ├── losses.py └── supervized_partition.py └── vKITTI3D.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *DS_Store* 3 | *.vscode 4 | *.so 5 | partition/ply_c/build -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "partition/cut-pursuit"] 2 | path = partition/cut-pursuit 3 | url = ../cut-pursuit.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Loic Landrieu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
3 |
4 | ## ⚠️ This repo is no longer maintained! Please check out our brand new [*SuperPoint Transformer*](https://github.com/drprojects/superpoint_transformer), which does everything better! ⚠️
5 |
6 |
7 |
8 | This is the official PyTorch implementation of the papers:
9 |
10 | [*Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs*](http://arxiv.org/abs/1711.09869)
11 |
12 | by Loic Landrieu and Martin Simonovsky (CVPR2018),
13 |
14 | and
15 |
16 | [*Point Cloud Oversegmentation with Graph-Structured Deep Metric Learning*](https://arxiv.org/pdf/1904.02113).
17 |
18 | by Loic Landrieu and Mohamed Boussaha (CVPR2019),
19 |
20 |
21 |
22 |
23 |
24 | ## Code structure
25 | * `./partition/*` - Partition code (geometric partitioning and superpoint graph construction using handcrafted features)
26 | * `./supervized_partition/*` - Supervised partition code (partitioning with learned features)
27 | * `./learning/*` - Learning code (superpoint embedding and contextual segmentation).
28 |
29 | To switch to the stable branch with only SPG, switch to [release](https://github.com/loicland/superpoint_graph/tree/release).
30 |
31 | ## Disclaimer
32 | Our partition method is inherently stochastic. Hence, even if we provide the trained weights, it is possible that the results that you obtain differ slightly from the ones presented in the paper.
33 |
34 | ## Requirements
35 | *0.* Download the current version of the repository. We recommend using the `--recurse-submodules` option to make sure the [cut pursuit](https://github.com/loicland/cut-pursuit) module used in `/partition` is downloaded in the process. If you did not use the following command, please refer to point 4:
36 | ``` 37 | git clone --recurse-submodules https://github.com/loicland/superpoint_graph 38 | ``` 39 | 40 | *1.* Install [PyTorch](https://pytorch.org) and [torchnet](https://github.com/pytorch/tnt). 41 | ``` 42 | pip install git+https://github.com/pytorch/tnt.git@master 43 | ``` 44 | 45 | *2.* Install additional Python packages: 46 | ``` 47 | pip install future igraph tqdm transforms3d pynvrtc fastrlock cupy h5py sklearn plyfile scipy pandas 48 | ``` 49 | 50 | *3.* Install Boost (1.63.0 or newer) and Eigen3, in Conda:
51 | ``` 52 | conda install -c anaconda boost; conda install -c omnia eigen3; conda install eigen; conda install -c r libiconv 53 | ``` 54 | 55 | *4.* Make sure that cut pursuit was downloaded. Otherwise, clone [this repository](https://github.com/loicland/cut-pursuit) or add it as a submodule in `/partition`:
56 | ``` 57 | cd partition 58 | git submodule init 59 | git submodule update --remote cut-pursuit 60 | ``` 61 | 62 | *5.* Compile the ```libply_c``` and ```libcp``` libraries: 63 | ``` 64 | CONDAENV=YOUR_CONDA_ENVIRONMENT_LOCATION 65 | cd partition/ply_c 66 | cmake . -DPYTHON_LIBRARY=$CONDAENV/lib/libpython3.6m.so -DPYTHON_INCLUDE_DIR=$CONDAENV/include/python3.6m -DBOOST_INCLUDEDIR=$CONDAENV/include -DEIGEN3_INCLUDE_DIR=$CONDAENV/include/eigen3 67 | make 68 | cd .. 69 | cd cut-pursuit 70 | mkdir build 71 | cd build 72 | cmake .. -DPYTHON_LIBRARY=$CONDAENV/lib/libpython3.6m.so -DPYTHON_INCLUDE_DIR=$CONDAENV/include/python3.6m -DBOOST_INCLUDEDIR=$CONDAENV/include -DEIGEN3_INCLUDE_DIR=$CONDAENV/include/eigen3 73 | make 74 | ``` 75 | *6.* (optional) Install [Pytorch Geometric](https://github.com/rusty1s/pytorch_geometric) 76 | 77 | The code was tested on Ubuntu 14 and 16 with Python 3.5 to 3.8 and PyTorch 0.2 to 1.3. 78 | 79 | ### Troubleshooting 80 | 81 | Common sources of errors and how to fix them: 82 | - $CONDAENV is not well defined : define it or replace $CONDAENV by the absolute path of your conda environment (find it with ```locate anaconda```) 83 | - anaconda uses a different version of python than 3.6m : adapt it in the command. Find which version of python conda is using with ```locate anaconda3/lib/libpython``` 84 | - you are using boost 1.62 or older: update it 85 | - cut pursuit did not download: manually clone it in the ```partition``` folder or add it as a submodule as proposed in the requirements, point 4. 86 | - error in make: `'numpy/ndarrayobject.h' file not found`: set symbolic link to python site-package with `sudo ln -s $CONDAENV/lib/python3.7/site-packages/numpy/core/include/numpy $CONDAENV/include/numpy` 87 | 88 | 89 | ## Running the code 90 | 91 | To run our code or retrain from scratch on different datasets, see the corresponding readme files. 
92 | Currently supported datasets are as follows:
93 |
94 | | Dataset | handcrafted partition | learned partition |
95 | | ---------- | --------------------- | ------------------|
96 | | S3DIS | yes | yes |
97 | | Semantic3D | yes | to come soon |
98 | | vKITTI3D | no | yes |
99 | | ScanNet | to come soon | to come soon |
100 |
101 | To use pytorch-geometric graph convolutions instead of our own, use the option `--use_pyg 1` in `./learning/main.py`. Their code is more stable and just as fast. Otherwise, use `--use_pyg 0`.
102 |
103 | #### Evaluation
104 |
105 | To quantitatively evaluate a trained model, use (for S3DIS and vKITTI3D only):
106 | ```
107 | python learning/evaluate.py --dataset s3dis --odir results/s3dis/best --cvfold 123456
108 | ```
109 |
110 | To visualize the results and all intermediary steps, use the visualize function in partition (for S3DIS, vKITTI3D, and Semantic3D). For example:
111 | ```
112 | python partition/visualize.py --dataset s3dis --ROOT_PATH $S3DIS_DIR --res_file results/s3dis/pretrained/cv1/predictions_test --file_path Area_1/conferenceRoom_1 --output_type igfpres
113 | ```
114 |
115 | ```output_type``` is defined as follows:
116 | - ```'i'``` = input rgb point cloud
117 | - ```'g'``` = ground truth (if available), with the predefined class to color mapping
118 | - ```'f'``` = geometric feature with color code: red = linearity, green = planarity, blue = verticality
119 | - ```'p'``` = partition, with a random color for each superpoint
120 | - ```'r'``` = result cloud, with the predefined class to color mapping
121 | - ```'e'``` = error cloud, with green/red hue for correct/faulty prediction
122 | - ```'s'``` = superedge structure of the superpoint (toggle wireframe on meshlab to view it)
123 |
124 | Add option ```--upsample 1``` if you want the prediction file to be on the original, unpruned data (this can take a long time).
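The letters of `--output_type` can be combined freely, as in `igfpres` in the example command. A minimal sketch of how such a string maps to the outputs listed above, assuming a hypothetical helper `decode_output_type` that is not part of this repository:

```python
# Hypothetical helper, not part of the repository: maps each letter of an
# --output_type string to the short description given in the README list.
OUTPUT_TYPES = {
    "i": "input rgb point cloud",
    "g": "ground truth",
    "f": "geometric features",
    "p": "partition",
    "r": "result cloud",
    "e": "error cloud",
    "s": "superedge structure",
}


def decode_output_type(output_type):
    """Return the requested outputs in order, rejecting unknown letters."""
    unknown = [c for c in output_type if c not in OUTPUT_TYPES]
    if unknown:
        raise ValueError("unknown output_type letters: " + "".join(unknown))
    return [OUTPUT_TYPES[c] for c in output_type]


# One description per letter, in the order the letters were given.
for name in decode_output_type("igfpres"):
    print(name)
```

Rejecting unknown letters early is a design choice of this sketch; how `partition/visualize.py` itself handles them is not shown here.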
125 |
126 | # Other data sets
127 |
128 | You can apply SPG on your own data set with minimal changes:
129 | - adapt references to ```custom_dataset``` in ```/partition/partition.py```
130 | - you will need to create the function ```read_custom_format``` in ```/partition/provider.py``` which outputs xyz and rgb values, as well as semantic labels if available (already implemented for ply and las files)
131 | - adapt the template function ```/learning/custom_dataset.py``` to your architecture and design choices
132 | - adapt references to ```custom_dataset``` in ```/learning/main.py```
133 | - add your data set colormap to ```get_color_from_label``` in ```/partition/provider.py```
134 | - adapt line 212 of `learning/spg.py` to reflect the missing or extra point features
135 | - change ```--model_config``` to ```gru_10,f_K``` with ```K``` as the number of classes in your dataset, or ```gru_10_0,f_K``` to use matrix edge filters instead of vectors (only use matrices when your data set is quite large, and with many different point clouds, like S3DIS).
136 |
137 | # Datasets without RGB
138 | If your data does not have RGB values you can easily use SPG. You will need to follow the instructions in ```partition/partition.py``` regarding the pruning.
139 | You will need to adapt the ```/learning/custom_dataset.py``` file so that it does not refer to RGB values.
140 | You should absolutely not use a model pretrained on values with RGB. Instead, retrain a model from scratch using the ```--pc_attribs xyzelpsv``` option to remove RGB from the shape embedding input.
141 |
142 | # Citation
143 | If you use the semantic segmentation module (code in `/learning`), please cite:
144 | *Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs*, Loic Landrieu and Martin Simonovsky, CVPR, 2018.
145 |
146 | If you use the learned partition module (code in `/supervized_partition`), please cite:
147 | *Point Cloud Oversegmentation with Graph-Structured Deep Metric Learning*, Loic Landrieu and Mohamed Boussaha, CVPR, 2019.
148 |
149 | To refer to the handcrafted partition (code in `/partition`) step specifically, refer to:
150 | *Weakly Supervised Segmentation-Aided Classification of Urban Scenes from 3D LiDAR Point Clouds*, Stéphane Guinard and Loic Landrieu. ISPRS Workshop, 2017. 151 | 152 | To refer to the L0-cut pursuit algorithm (code in `github.com/loicland/cut-pursuit`) specifically, refer to:
153 | *Cut Pursuit: Fast Algorithms to Learn Piecewise Constant Functions on General Weighted Graphs*, Loic Landrieu and Guillaume Obozinski, SIAM Journal on Imaging Sciences, 2017
154 |
155 | To refer to the pytorch geometric implementation, see their bibtex in [their repo](https://github.com/rusty1s/pytorch_geometric).
156 |
157 |
158 |
--------------------------------------------------------------------------------
/S3DIS.md:
--------------------------------------------------------------------------------
1 | # S3DIS
2 |
3 | Download [S3DIS Dataset](http://buildingparser.stanford.edu/dataset.html) and extract `Stanford3dDataset_v1.2_Aligned_Version.zip` to `$S3DIS_DIR/data`, where `$S3DIS_DIR` is set to the dataset directory.
4 |
5 | To fix some issues with the dataset as reported in issue [#29](https://github.com/loicland/superpoint_graph/issues/29), apply the patch `S3DIS_fix.diff` with:
6 | ```
7 | cp S3DIS_fix.diff $S3DIS_DIR/data; cd $S3DIS_DIR/data; git apply S3DIS_fix.diff; rm S3DIS_fix.diff; cd -
8 | ```
9 | Define $S3DIS_DIR as the location of the folder containing `/data`
10 |
11 | ## SPG with Handcrafted Partition
12 |
13 | To compute the partition with handcrafted features run:
14 | ```
15 | python partition/partition.py --dataset s3dis --ROOT_PATH $S3DIS_DIR --voxel_width 0.03 --reg_strength 0.03
16 | ```
17 |
18 | Then, reorganize point clouds into superpoints by:
19 | ```
20 | python learning/s3dis_dataset.py --S3DIS_PATH $S3DIS_DIR
21 | ```
22 |
23 | To train from scratch on all 6 folds on the handcrafted partition, run:
24 | ```
25 | for FOLD in 1 2 3 4 5 6; do \
26 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset s3dis --S3DIS_PATH $S3DIS_DIR --cvfold $FOLD --epochs 350 \
27 | --lr_steps '[275,320]' --test_nth_epoch 50 --model_config 'gru_10_0,f_13' --ptn_nfeat_stn 14 --nworkers 2 \
28 | --pc_attribs xyzrgbelpsvXYZ --odir "results/s3dis/best/cv${FOLD}" --nworkers 4; \
29 | done
30 | ```
31 |
32 | Our trained networks can be downloaded
[here](http://imagine.enpc.fr/~simonovm/largescale/models_s3dis.zip). Unzip the folder (but not the model.pth.tar themselves) and place them in the code folder `results/s3dis/pretrained/`. 33 | 34 | To test these networks on the full test set, run: 35 | ``` 36 | for FOLD in 1 2 3 4 5 6; do \ 37 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset s3dis --S3DIS_PATH $S3DIS_DIR --cvfold $FOLD --epochs -1 --lr_steps '[275,320]' \ 38 | --test_nth_epoch 50 --model_config 'gru_10_0,f_13' --ptn_nfeat_stn 14 --nworkers 2 --pc_attribs xyzrgbelpsvXYZ --odir "results/s3dis/pretrained/cv${FOLD}" --resume RESUME; \ 39 | done 40 | ``` 41 | 42 | ## SSP+SPG: SPG with learned partition 43 | 44 | To learn the partition from scratch run: 45 | ``` 46 | python supervized_partition/graph_processing.py --ROOT_PATH $S3DIS_DIR --dataset s3dis --voxel_width 0.03; \ 47 | 48 | for FOLD in 1 2 3 4 5 6; do \ 49 | python ./supervized_partition/supervized_partition.py --ROOT_PATH $S3DIS_DIR --cvfold $FOLD \ 50 | --odir results_partition/s3dis/best --epochs 50 --reg_strength 0.1 --spatial_emb 0.2 \ 51 | --global_feat eXYrgb --CP_cutoff 25; \ 52 | done 53 | ``` 54 | Or download our trained weights [here](http://recherche.ign.fr/llandrieu/SPG/S3DIS/pretrained.zip) in the folder `results_partition/s3dis/pretrained`, unzipped and run the following code: 55 | 56 | ``` 57 | for FOLD in 1 2 3 4 5 6; do \ 58 | python ./supervized_partition/supervized_partition.py --ROOT_PATH $S3DIS_DIR --cvfold $FOLD --epochs -1 \ 59 | --odir results_partition/s3dis/pretrained --reg_strength 0.1 --spatial_emb 0.2 --global_feat eXYrgb \ 60 | --CP_cutoff 25 --resume RESUME; \ 61 | done 62 | ``` 63 | 64 | To evaluate the quality of the partition, run: 65 | ``` 66 | python supervized_partition/evaluate_partition.py --dataset s3dis --folder pretrained --cvfold 123456 67 | ``` 68 | 69 | Then, reorganize point clouds into superpoints with: 70 | ``` 71 | python learning/s3dis_dataset.py --S3DIS_PATH $S3DIS_DIR 
--supervized_partition 1 --plane_model_elevation 1
72 | ```
73 |
74 | Then to learn the SPG models from scratch, run:
75 | ```
76 | for FOLD in 1 2 3 4 5 6; do \
77 | CUDA_VISIBLE_DEVICES=0 python ./learning/main.py --dataset s3dis --S3DIS_PATH $S3DIS_DIR --batch_size 5 \
78 | --cvfold $FOLD --epochs 250 --lr_steps '[150,200]' --model_config "gru_10_0,f_13" --ptn_nfeat_stn 10 \
79 | --nworkers 2 --spg_augm_order 5 --pc_attribs xyzXYZrgbe --spg_augm_hardcutoff 768 --ptn_minpts 50 \
80 | --use_val_set 1 --odir results/s3dis/best/cv$FOLD; \
81 | done;
82 | ```
83 |
84 | Or use our [trained weights](http://recherche.ign.fr/llandrieu/SPG/S3DIS/pretrained_SSP.zip) with `--epochs -1` and `--resume RESUME`:
85 | ```
86 | for FOLD in 1 2 3 4 5 6; do \
87 | CUDA_VISIBLE_DEVICES=0 python ./learning/main.py --dataset s3dis --S3DIS_PATH $S3DIS_DIR --batch_size 5 \
88 | --cvfold $FOLD --epochs -1 --lr_steps '[150,200]' --model_config "gru_10_0,f_13" --ptn_nfeat_stn 10 \
89 | --nworkers 2 --spg_augm_order 5 --pc_attribs xyzXYZrgbe --spg_augm_hardcutoff 768 --ptn_minpts 50 \
90 | --use_val_set 1 --odir results/s3dis/pretrained_SSP/cv$FOLD --resume RESUME; \
91 | done;
92 | ```
93 | Note that these weights are specifically adapted to the pretrained model for the learned partition. Any change to the partition might decrease their performance.
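Within each loop in this file, only the fold number changes between iterations. As an illustration, the commands of the last loop can be generated from Python. This is a sketch only: the helper `ssp_test_command` and the path `/path/to/s3dis` are hypothetical, the flags are copied from the loop above, and the script prints the commands without running them.

```python
import shlex


def ssp_test_command(s3dis_dir, fold):
    """Build the argument list of the pretrained SSP+SPG test command for one fold."""
    return [
        "python", "./learning/main.py",
        "--dataset", "s3dis", "--S3DIS_PATH", s3dis_dir, "--batch_size", "5",
        "--cvfold", str(fold), "--epochs", "-1", "--lr_steps", "[150,200]",
        "--model_config", "gru_10_0,f_13", "--ptn_nfeat_stn", "10",
        "--nworkers", "2", "--spg_augm_order", "5", "--pc_attribs", "xyzXYZrgbe",
        "--spg_augm_hardcutoff", "768", "--ptn_minpts", "50",
        "--use_val_set", "1",
        "--odir", "results/s3dis/pretrained_SSP/cv{}".format(fold),
        "--resume", "RESUME",
    ]


# Dry run: print one shell-quoted command per fold instead of executing it.
for fold in range(1, 7):
    print(" ".join(shlex.quote(arg) for arg in ssp_test_command("/path/to/s3dis", fold)))
```

Keeping the arguments as a list avoids quoting mistakes with values such as `[150,200]`; the same list could be passed to `subprocess.run` once the printed commands look right.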
94 | -------------------------------------------------------------------------------- /S3DIS_fix.diff: -------------------------------------------------------------------------------- 1 | diff --git a/Area_3/hallway_2/hallway_2.txt b/Area_3/hallway_2/hallway_2.txt 2 | index 02f32b8..870566e 100644 3 | --- a/Area_3/hallway_2/hallway_2.txt 4 | +++ b/Area_3/hallway_2/hallway_2.txt 5 | @@ -926334,7 +926334,7 @@ 6 | 19.237 -9.161 1.561 141 131 96 7 | 19.248 -9.160 1.768 136 129 103 8 | 19.276 -9.160 1.684 139 130 99 9 | -19.302 -9.10 1.785 146 137 106 10 | +19.302 -9.1 0 1.785 146 137 106 11 | 19.242 -9.160 1.790 146 134 108 12 | 19.271 -9.160 1.679 140 129 99 13 | 19.278 -9.160 1.761 133 123 98 14 | diff --git a/Area_5/hallway_6/Annotations/ceiling_1.txt b/Area_5/hallway_6/Annotations/ceiling_1.txt 15 | index 62e563d..3a9087b 100644 16 | --- a/Area_5/hallway_6/Annotations/ceiling_1.txt 17 | +++ b/Area_5/hallway_6/Annotations/ceiling_1.txt 18 | @@ -180386,7 +180386,7 @@ 19 | 22.383 6.858 3.050 155 155 165 20 | 22.275 6.643 3.048 192 194 191 21 | 22.359 6.835 3.050 152 152 162 22 | -22.350 6.692 3.048 185187 182 23 | +22.350 6.692 3.048 185 187 182 24 | 22.314 6.638 3.048 170 171 175 25 | 22.481 6.818 3.049 149 149 159 26 | 22.328 6.673 3.048 190 195 191 27 | diff --git a/Area_6/copyRoom_1/copy_Room_1.txt b/Area_6/copyRoom_1/copyRoom_1.txt 28 | similarity index 100% 29 | rename from Area_6/copyRoom_1/copy_Room_1.txt 30 | rename to Area_6/copyRoom_1/copyRoom_1.txt 31 | -------------------------------------------------------------------------------- /Semantic3D.md: -------------------------------------------------------------------------------- 1 | # Semantic3D 2 | 3 | Download all point clouds and labels from [Semantic3D Dataset](http://www.semantic3d.net/) and place extracted training files to `$SEMA3D_DIR/data/train`, reduced test files into `$SEMA3D_DIR/data/test_reduced`, and full test files into `$SEMA3D_DIR/data/test_full`, where `$SEMA3D_DIR` is set to dataset 
directory. The label files of the training files must be put in the same directory as the .txt files.
4 |
5 | ## Handcrafted Partition
6 |
7 | To compute the partition with handcrafted features run:
8 | ```
9 | python partition/partition.py --dataset sema3d --ROOT_PATH $SEMA3D_DIR --voxel_width 0.05 --reg_strength 0.8 --ver_batch 5000000
10 | ```
11 | It is recommended that you have at least 24GB of RAM to run this code. Otherwise, increase the ```voxel_width``` parameter to increase pruning.
12 |
13 | Then, reorganize point clouds into superpoints by:
14 | ```
15 | python learning/sema3d_dataset.py --SEMA3D_PATH $SEMA3D_DIR
16 | ```
17 |
18 | To train on the whole publicly available data and test on the reduced test set, run:
19 | ```
20 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset sema3d --SEMA3D_PATH $SEMA3D_DIR --db_test_name testred --db_train_name trainval \
21 | --epochs 500 --lr_steps '[350, 400, 450]' --test_nth_epoch 100 --model_config 'gru_10,f_8' --ptn_nfeat_stn 11 \
22 | --nworkers 2 --pc_attribs xyzrgbelpsv --odir "results/sema3d/trainval_best"
23 | ```
24 | The trained network can be downloaded [here](http://imagine.enpc.fr/~simonovm/largescale/model_sema3d_trainval.pth.tar) and loaded with the `--resume` argument. Rename the file to ```model.pth.tar``` (do not try to unzip it!) and place it in the directory ```results/sema3d/trainval_best```.
25 |
26 | To test this network on the full test set, run:
27 | ```
28 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset sema3d --SEMA3D_PATH $SEMA3D_DIR --db_test_name testfull --db_train_name trainval \
29 | --epochs -1 --lr_steps '[350, 400, 450]' --test_nth_epoch 100 --model_config 'gru_10,f_8' --ptn_nfeat_stn 11 \
30 | --nworkers 2 --pc_attribs xyzrgbelpsv --odir "results/sema3d/trainval_best" --resume RESUME
31 | ```
32 | We validated our configuration on a custom split of 11 and 4 clouds.
The network is trained as such:
33 | ```
34 | CUDA_VISIBLE_DEVICES=0 python learning/main.py --dataset sema3d --SEMA3D_PATH $SEMA3D_DIR --epochs 450 --lr_steps '[350, 400]' --test_nth_epoch 100 \
35 | --model_config 'gru_10,f_8' --pc_attribs xyzrgbelpsv --ptn_nfeat_stn 11 --nworkers 2 --odir "results/sema3d/best"
36 | ```
37 |
38 | #### Learned Partition
39 |
40 | Not yet available.
41 |
42 | #### Visualization
43 |
44 | To upsample the prediction to the unpruned data and write the .labels files for the reduced test set, run (quite slow):
45 | ```
46 | python partition/write_Semantic3d.py --SEMA3D_PATH $SEMA3D_DIR --odir "results/sema3d/trainval_best" --db_test_name testred
47 | ```
48 |
49 | To visualize the results and intermediary steps (on the subsampled graph), use the visualize function in partition. For example:
50 | ```
51 | python partition/visualize.py --dataset sema3d --ROOT_PATH $SEMA3D_DIR --res_file 'results/sema3d/trainval_best/prediction_testred' --file_path 'test_reduced/MarketplaceFeldkirch_Station4' --output_type ifprs
52 | ```
53 | Avoid ```--upsample 1``` as it can take a very long time on the largest clouds.
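In the commands of this file, the value of `--ptn_nfeat_stn` (11) equals the number of letters in the point attribute string `xyzrgbelpsv`, and the S3DIS commands follow the same pattern. This is an observation on the listed commands rather than a documented rule. A small sketch, with a hypothetical helper `check_attribs` that is not part of the repository, makes the check explicit:

```python
# Hypothetical sanity check, not part of the repository.
def check_attribs(pc_attribs, ptn_nfeat_stn):
    """Return True when the number of attribute letters matches ptn_nfeat_stn."""
    return len(pc_attribs) == ptn_nfeat_stn


# Values copied from the commands in Semantic3D.md and S3DIS.md.
print(check_attribs("xyzrgbelpsv", 11))     # Semantic3D
print(check_attribs("xyzrgbelpsvXYZ", 14))  # S3DIS, handcrafted partition
print(check_attribs("xyzXYZrgbe", 10))      # S3DIS, learned partition
```

When changing the attribute string for a custom dataset, running such a check before training is a cheap way to catch a mismatch between the two options.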
54 |
--------------------------------------------------------------------------------
/learning/__init__.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 |
3 | DIR_PATH = os.path.dirname(os.path.realpath(__file__))
4 | sys.path.insert(0, DIR_PATH)
--------------------------------------------------------------------------------
/learning/custom_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Mar 20 16:16:14 2018
5 |
6 | @author: landrieuloic
7 | """"""
8 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs
9 | http://arxiv.org/abs/1711.09869
10 | 2017 Loic Landrieu, Martin Simonovsky
11 | Template file for processing custom datasets
12 | """
13 | from __future__ import division
14 | from __future__ import print_function
15 | from builtins import range
16 |
17 | import random
18 | import numpy as np
19 | import os
20 | import functools
21 | import torch
22 | import torchnet as tnt
23 | import h5py
24 | import spg
25 |
26 |
27 | def get_datasets(args, test_seed_offset=0):
28 |     """build training and testing set"""
29 |
30 |     #for a simple train/test organization (names are kept without '.h5', which is re-added below)
31 |     trainset = ['train/' + os.path.splitext(f)[0] for f in os.listdir(args.CUSTOM_SET_PATH + '/superpoint_graphs/train')]
32 |     testset = ['test/' + os.path.splitext(f)[0] for f in os.listdir(args.CUSTOM_SET_PATH + '/superpoint_graphs/test')]
33 |
34 |     # Load superpoints graphs
35 |     testlist, trainlist = [], []
36 |     for n in trainset:
37 |         trainlist.append(spg.spg_reader(args, args.CUSTOM_SET_PATH + '/superpoint_graphs/' + n + '.h5', True))
38 |     for n in testset:
39 |         testlist.append(spg.spg_reader(args, args.CUSTOM_SET_PATH + '/superpoint_graphs/' + n + '.h5', True))
40 |
41 |     # Normalize edge features
42 |     if args.spg_attribs01:
43 |         trainlist, testlist, validlist, scaler = spg.scaler01(trainlist, testlist)
44 |
45 |     return
tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in trainlist],
46 |                                     functools.partial(spg.loader, train=True, args=args, db_path=args.CUSTOM_SET_PATH)), \
47 |            tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in testlist],
48 |                                     functools.partial(spg.loader, train=False, args=args, db_path=args.CUSTOM_SET_PATH, test_seed_offset=test_seed_offset)) ,\
49 |            scaler
50 |
51 | def get_info(args):
52 |     edge_feats = 0
53 |     for attrib in args.edge_attribs.split(','):
54 |         a = attrib.split('/')[0]
55 |         if a in ['delta_avg', 'delta_std', 'xyz']:
56 |             edge_feats += 3
57 |         else:
58 |             edge_feats += 1
59 |
60 |     return {
61 |         'node_feats': 11 if args.pc_attribs=='' else len(args.pc_attribs),
62 |         'edge_feats': edge_feats,
63 |         'classes': 10, #CHANGE TO YOUR NUMBER OF CLASSES
64 |         'inv_class_map': {0:'class_A', 1:'class_B'}, #etc...
65 |     }
66 |
67 | def preprocess_pointclouds(SEMA3D_PATH):
68 |     """ Preprocesses data by splitting them by components and normalizing."""
69 |
70 |     for n in ['train', 'test_reduced', 'test_full']:
71 |         pathP = '{}/parsed/{}/'.format(SEMA3D_PATH, n)
72 |         pathD = '{}/features/{}/'.format(SEMA3D_PATH, n)
73 |         pathC = '{}/superpoint_graphs/{}/'.format(SEMA3D_PATH, n)
74 |         if not os.path.exists(pathP):
75 |             os.makedirs(pathP)
76 |         random.seed(0)
77 |
78 |         for file in os.listdir(pathC):
79 |             print(file)
80 |             if file.endswith(".h5"):
81 |                 f = h5py.File(pathD + file, 'r')
82 |                 xyz = f['xyz'][:]
83 |                 rgb = f['rgb'][:].astype(np.float64)
84 |                 elpsv = np.stack([ f['xyz'][:,2][:], f['linearity'][:], f['planarity'][:], f['scattering'][:], f['verticality'][:] ], axis=1)
85 |
86 |                 # rescale to [-0.5,0.5]; keep xyz
87 |                 #warning - to use the trained model, make sure the elevation is comparable
88 |                 #to the set they were trained on
89 |                 #i.e.
~0 for roads and ~0.2-0.3 for buildings for sema3d
90 |                 # and -0.5 for floor and 0.5 for ceiling for s3dis
91 |                 elpsv[:,0] /= 100 # (rough guess) #adapt
92 |                 elpsv[:,1:] -= 0.5
93 |                 rgb = rgb/255.0 - 0.5
94 |
95 |                 P = np.concatenate([xyz, rgb, elpsv], axis=1)
96 |
97 |                 f = h5py.File(pathC + file, 'r')
98 |                 numc = len(f['components'].keys())
99 |
100 |                 with h5py.File(pathP + file, 'w') as hf:
101 |                     for c in range(numc):
102 |                         idx = f['components/{:d}'.format(c)][:].flatten()
103 |                         if idx.size > 10000: # trim extra large segments, just for speed-up of loading time
104 |                             ii = random.sample(range(idx.size), k=10000)
105 |                             idx = idx[ii]
106 |
107 |                         hf.create_dataset(name='{:d}'.format(c), data=P[idx,...])
108 |
109 | if __name__ == "__main__":
110 |     import argparse
111 |     parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
112 |     parser.add_argument('--CUSTOM_SET_PATH', default='datasets/custom_set')
113 |     args = parser.parse_args()
114 |     preprocess_pointclouds(args.CUSTOM_SET_PATH)
115 |
116 |
117 |
--------------------------------------------------------------------------------
/learning/ecc/GraphConvInfo.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs
3 | https://github.com/mys007/ecc
4 | https://arxiv.org/abs/1704.02901
5 | 2017 Martin Simonovsky
6 | """
7 | from __future__ import division
8 | from __future__ import print_function
9 | from builtins import range
10 |
11 | import igraph
12 | import torch
13 | from collections import defaultdict
14 | import numpy as np
15 |
16 | class GraphConvInfo(object):
17 |     """ Holds information about the structure of graph(s) in a vectorized form useful to `GraphConvModule`.
18 |
19 |     We assume that the node feature tensor (given to `GraphConvModule` as input) is ordered by igraph vertex id, e.g. the fifth row corresponds to vertex with id=4.
Batch processing is realized by concatenating all graphs into a large graph of disconnected components (and all node feature tensors into a large tensor). 20 | 21 | The class requires problem-specific `edge_feat_func` function, which receives dict of edge attributes and returns Tensor of edge features and LongTensor of inverse indices if edge compaction was performed (less unique edge features than edges so some may be reused). 22 | """ 23 | 24 | def __init__(self, *args, **kwargs): 25 | self._idxn = None #indices into input tensor of convolution (node features) 26 | self._idxe = None #indices into edge features tensor (or None if it would be linear, i.e. no compaction) 27 | self._degrees = None #in-degrees of output nodes (slices _idxn and _idxe) 28 | self._degrees_gpu = None 29 | self._edgefeats = None #edge features tensor (to be processed by feature-generating network) 30 | if len(args)>0 or len(kwargs)>0: 31 | self.set_batch(*args, **kwargs) 32 | 33 | def set_batch(self, graphs, edge_feat_func): 34 | """ Creates a representation of a given batch of graphs. 35 | 36 | Parameters: 37 | graphs: single graph or a list/tuple of graphs. 38 | edge_feat_func: see class description. 
39 | """ 40 | 41 | graphs = graphs if isinstance(graphs,(list,tuple)) else [graphs] 42 | p = 0 43 | idxn = [] 44 | degrees = [] 45 | edge_indexes = [] 46 | edgeattrs = defaultdict(list) 47 | 48 | for G in graphs: 49 | E = np.array(G.get_edgelist()) 50 | idx = E[:,1].argsort() # sort by target 51 | 52 | idxn.append(p + E[idx,0]) 53 | edgeseq = G.es[idx.tolist()] 54 | for a in G.es.attributes(): 55 | edgeattrs[a] += edgeseq.get_attribute_values(a) 56 | degrees += G.indegree(G.vs, loops=True) 57 | edge_indexes.append(np.asarray(p + E[idx])) 58 | p += G.vcount() 59 | 60 | self._edgefeats, self._idxe = edge_feat_func(edgeattrs) 61 | 62 | self._idxn = torch.LongTensor(np.concatenate(idxn)) 63 | if self._idxe is not None: 64 | assert self._idxe.numel() == self._idxn.numel() 65 | 66 | self._degrees = torch.LongTensor(degrees) 67 | self._degrees_gpu = None 68 | 69 | self._edge_indexes = torch.LongTensor(np.concatenate(edge_indexes).T) 70 | 71 | def cuda(self): 72 | self._idxn = self._idxn.cuda() 73 | if self._idxe is not None: self._idxe = self._idxe.cuda() 74 | self._degrees_gpu = self._degrees.cuda() 75 | self._edgefeats = self._edgefeats.cuda() 76 | self._edge_indexes = self._edge_indexes.cuda() 77 | 78 | def get_buffers(self): 79 | """ Provides data to `GraphConvModule`. 80 | """ 81 | return self._idxn, self._idxe, self._degrees, self._degrees_gpu, self._edgefeats 82 | 83 | def get_pyg_buffers(self): 84 | """ Provides data to `GraphConvModule`. 
85 | """ 86 | return self._edge_indexes -------------------------------------------------------------------------------- /learning/ecc/GraphConvModule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs 3 | https://github.com/mys007/ecc 4 | https://arxiv.org/abs/1704.02901 5 | 2017 Martin Simonovsky 6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from builtins import range 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable, Function 14 | from .GraphConvInfo import GraphConvInfo 15 | from . import cuda_kernels 16 | from . import utils 17 | 18 | 19 | class GraphConvFunction(Function): 20 | """Computes operations for each edge and averages the results over respective nodes. 21 | The operation is either matrix-vector multiplication (for 3D weight tensors) or element-wise 22 | vector-vector multiplication (for 2D weight tensors). The evaluation is computed in blocks of 23 | size `edge_mem_limit` to reduce peak memory load. See `GraphConvInfo` for info on `idxn, idxe, degs`. 
24 | """ 25 | def init(self, in_channels, out_channels, idxn, idxe, degs, degs_gpu, edge_mem_limit=1e20): 26 | self._in_channels = in_channels 27 | self._out_channels = out_channels 28 | self._idxn = idxn 29 | self._idxe = idxe 30 | self._degs = degs 31 | self._degs_gpu = degs_gpu 32 | self._shards = utils.get_edge_shards(degs, edge_mem_limit) 33 | 34 | def _multiply(ctx, a, b, out, f_a=None, f_b=None): 35 | """Performs operation on edge weights and node signal""" 36 | if ctx._full_weight_mat: 37 | # weights are full in_channels x out_channels matrices -> mm 38 | torch.bmm(f_a(a) if f_a else a, f_b(b) if f_b else b, out=out) 39 | else: 40 | # weights represent diagonal matrices -> mul 41 | torch.mul(a, b.expand_as(a), out=out) 42 | 43 | @staticmethod 44 | def forward(ctx, input, weights, in_channels, out_channels, idxn, idxe, degs, degs_gpu, edge_mem_limit=1e20): 45 | 46 | ctx.save_for_backward(input, weights) 47 | ctx._in_channels = in_channels 48 | ctx._out_channels = out_channels 49 | ctx._idxn = idxn 50 | ctx._idxe = idxe 51 | ctx._degs = degs 52 | ctx._degs_gpu = degs_gpu 53 | ctx._shards = utils.get_edge_shards(degs, edge_mem_limit) 54 | 55 | ctx._full_weight_mat = weights.dim() == 3 56 | assert ctx._full_weight_mat or ( 57 | in_channels == out_channels and weights.size(1) == in_channels) 58 | 59 | output = input.new(degs.numel(), out_channels) 60 | 61 | # loop over blocks of output nodes 62 | startd, starte = 0, 0 63 | for numd, nume in ctx._shards: 64 | 65 | # select sequence of matching pairs of node and edge weights 66 | sel_input = torch.index_select(input, 0, idxn.narrow(0, starte, nume)) 67 | 68 | if ctx._idxe is not None: 69 | sel_weights = torch.index_select(weights, 0, idxe.narrow(0, starte, nume)) 70 | else: 71 | sel_weights = weights.narrow(0, starte, nume) 72 | 73 | # compute matrix-vector products 74 | products = input.new() 75 | GraphConvFunction._multiply(ctx, sel_input, sel_weights, products, lambda a: a.unsqueeze(1)) 76 | 77 | # average over 
nodes 78 | if ctx._idxn.is_cuda: 79 | cuda_kernels.conv_aggregate_fw(output.narrow(0, startd, numd), products.view(-1, ctx._out_channels), 80 | ctx._degs_gpu.narrow(0, startd, numd)) 81 | else: 82 | k = 0 83 | for i in range(startd, startd + numd): 84 | if ctx._degs[i] > 0: 85 | torch.mean(products.narrow(0, k, ctx._degs[i]), 0, out=output[i]) 86 | else: 87 | output[i].fill_(0) 88 | k = k + ctx._degs[i] 89 | 90 | startd += numd 91 | starte += nume 92 | del sel_input, sel_weights, products 93 | 94 | return output 95 | 96 | @staticmethod 97 | def backward(ctx, grad_output): 98 | input, weights = ctx.saved_tensors 99 | 100 | grad_input = input.new(input.size()).fill_(0) 101 | grad_weights = weights.new(weights.size()) 102 | if ctx._idxe is not None: grad_weights.fill_(0) 103 | 104 | # loop over blocks of output nodes 105 | startd, starte = 0, 0 106 | for numd, nume in ctx._shards: 107 | 108 | grad_products, tmp = input.new(nume, ctx._out_channels), input.new() 109 | 110 | if ctx._idxn.is_cuda: 111 | cuda_kernels.conv_aggregate_bw(grad_products, grad_output.narrow(0, startd, numd), 112 | ctx._degs_gpu.narrow(0, startd, numd)) 113 | else: 114 | k = 0 115 | for i in range(startd, startd + numd): 116 | if ctx._degs[i] > 0: 117 | torch.div(grad_output[i], ctx._degs[i], out=grad_products[k]) 118 | if ctx._degs[i] > 1: 119 | grad_products.narrow(0, k + 1, ctx._degs[i] - 1).copy_( 120 | grad_products[k].expand(ctx._degs[i] - 1, 1, ctx._out_channels).squeeze(1)) 121 | k = k + ctx._degs[i] 122 | 123 | # grad wrt weights 124 | sel_input = torch.index_select(input, 0, ctx._idxn.narrow(0, starte, nume)) 125 | 126 | if ctx._idxe is not None: 127 | GraphConvFunction._multiply(ctx, sel_input, grad_products, tmp, 128 | lambda a: a.unsqueeze(1).transpose_(2, 1), 129 | lambda b: b.unsqueeze(1)) 130 | grad_weights.index_add_(0, ctx._idxe.narrow(0, starte, nume), tmp) 131 | else: 132 | GraphConvFunction._multiply(ctx, sel_input, grad_products, grad_weights.narrow(0, starte, nume), 133 | 
lambda a: a.unsqueeze(1).transpose_(2, 1), lambda b: b.unsqueeze(1)) 134 | 135 | # grad wrt input 136 | if ctx._idxe is not None: 137 | torch.index_select(weights, 0, ctx._idxe.narrow(0, starte, nume), out=tmp) 138 | GraphConvFunction._multiply(ctx, grad_products, tmp, sel_input, lambda a: a.unsqueeze(1), 139 | lambda b: b.transpose_(2, 1)) 140 | del tmp 141 | else: 142 | GraphConvFunction._multiply(ctx, grad_products, weights.narrow(0, starte, nume), sel_input, 143 | lambda a: a.unsqueeze(1), 144 | lambda b: b.transpose_(2, 1)) 145 | 146 | grad_input.index_add_(0, ctx._idxn.narrow(0, starte, nume), sel_input) 147 | 148 | startd += numd 149 | starte += nume 150 | del grad_products, sel_input 151 | 152 | return grad_input, grad_weights, None, None, None, None, None, None, None 153 | 154 | 155 | 156 | class GraphConvModule(nn.Module): 157 | """ Computes graph convolution using filter weights obtained from a filter generating network (`filter_net`). 158 | The input should be a 2D tensor of size (# nodes, `in_channels`). Multiple graphs can be concatenated in the same tensor (minibatch). 159 | 160 | Parameters: 161 | in_channels: number of input channels 162 | out_channels: number of output channels 163 | filter_net: filter-generating network transforming a 2D tensor (# edges, # edge features) to (# edges, in_channels*out_channels) or (# edges, in_channels) 164 | gc_info: GraphConvInfo object containing graph(s) structure information, can be also set with `set_info()` method. 165 | edge_mem_limit: block size (number of evaluated edges in parallel) for convolution evaluation, a low value reduces peak memory. 
166 | """ 167 | 168 | def __init__(self, in_channels, out_channels, filter_net, gc_info=None, edge_mem_limit=1e20): 169 | super(GraphConvModule, self).__init__() 170 | 171 | self._in_channels = in_channels 172 | self._out_channels = out_channels 173 | self._fnet = filter_net 174 | self._edge_mem_limit = edge_mem_limit 175 | 176 | self.set_info(gc_info) 177 | 178 | def set_info(self, gc_info): 179 | self._gci = gc_info 180 | 181 | def forward(self, input): 182 | # get graph structure information tensors 183 | idxn, idxe, degs, degs_gpu, edgefeats = self._gci.get_buffers() 184 | edgefeats = Variable(edgefeats, requires_grad=False) 185 | 186 | # evaluate and reshape filter weights 187 | weights = self._fnet(edgefeats) 188 | assert input.dim()==2 and weights.dim()==2 and (weights.size(1) == self._in_channels*self._out_channels or 189 | (self._in_channels == self._out_channels and weights.size(1) == self._in_channels)) 190 | if weights.size(1) == self._in_channels*self._out_channels: 191 | weights = weights.view(-1, self._in_channels, self._out_channels) 192 | 193 | return GraphConvFunction.apply(input, weights, self._in_channels, self._out_channels, idxn, idxe, degs, degs_gpu, self._edge_mem_limit) # forward/backward are static, so the function is invoked through apply() 194 | 195 | 196 | 197 | 198 | 199 | 200 | class GraphConvModulePureAutograd(nn.Module): 201 | """ 202 | Autograd-only equivalent of `GraphConvModule` + `GraphConvFunction`. Unfortunately, autograd needs to store intermediate products, which makes the module work only for very small graphs. The module is kept for didactic purposes only. 
203 | """ 204 | 205 | def __init__(self, in_channels, out_channels, filter_net, gc_info=None): 206 | super(GraphConvModulePureAutograd, self).__init__() 207 | 208 | self._in_channels = in_channels 209 | self._out_channels = out_channels 210 | self._fnet = filter_net 211 | 212 | self.set_info(gc_info) 213 | 214 | def set_info(self, gc_info): 215 | self._gci = gc_info 216 | 217 | def forward(self, input): 218 | # get graph structure information tensors 219 | idxn, idxe, degs, _, edgefeats = self._gci.get_buffers() # get_buffers() also returns degs_gpu, unused here 220 | idxn = Variable(idxn, requires_grad=False) 221 | edgefeats = Variable(edgefeats, requires_grad=False) 222 | 223 | # evaluate and reshape filter weights 224 | weights = self._fnet(edgefeats) 225 | assert input.dim()==2 and weights.dim()==2 and weights.size(1) == self._in_channels*self._out_channels 226 | weights = weights.view(-1, self._in_channels, self._out_channels) 227 | 228 | # select sequence of matching pairs of node and edge weights 229 | if idxe is not None: 230 | idxe = Variable(idxe, requires_grad=False) 231 | weights = torch.index_select(weights, 0, idxe) 232 | 233 | sel_input = torch.index_select(input, 0, idxn) 234 | 235 | # compute matrix-vector products 236 | products = torch.bmm(sel_input.view(-1,1,self._in_channels), weights) 237 | 238 | output = Variable(input.data.new(len(degs), self._out_channels)) 239 | 240 | # average over nodes 241 | k = 0 242 | for i in range(len(degs)): 243 | if degs[i]>0: 244 | output.index_copy_(0, Variable(torch.Tensor([i]).type_as(idxn.data)), torch.mean(products.narrow(0,k,degs[i]), 0).view(1,-1)) 245 | else: 246 | output.index_fill_(0, Variable(torch.Tensor([i]).type_as(idxn.data)), 0) 247 | k = k + degs[i] 248 | 249 | return output 250 | 251 | -------------------------------------------------------------------------------- /learning/ecc/GraphPoolInfo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Edge-Conditioned Filters in Convolutional Neural
Networks on Graphs 3 | https://github.com/mys007/ecc 4 | https://arxiv.org/abs/1704.02901 5 | 2017 Martin Simonovsky 6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from builtins import range 10 | 11 | import torch 12 | 13 | 14 | class GraphPoolInfo(object): 15 | """ Holds information about pooling in a vectorized form useful to `GraphPoolModule`. 16 | 17 | We assume that the node feature tensor (given to `GraphPoolModule` as input) is ordered by igraph vertex id, e.g. the fifth row corresponds to vertex with id=4. Batch processing is realized by concatenating all graphs into a large graph of disconnected components (and all node feature tensors into a large tensor). 18 | """ 19 | 20 | def __init__(self, *args, **kwargs): 21 | self._idxn = None #indices into input tensor of convolution (node features) 22 | self._degrees = None #in-degrees of output nodes (slices _idxn) 23 | self._degrees_gpu = None 24 | if len(args)>0 or len(kwargs)>0: 25 | self.set_batch(*args, **kwargs) 26 | 27 | def set_batch(self, poolmaps, graphs_from, graphs_to): 28 | """ Creates a representation of a given batch of graph poolings. 
29 | 30 | Parameters: 31 | poolmaps: dict(s) mapping vertex id in coarsened graph to a list of vertex ids in input graph (defines pooling) 32 | graphs_from: input graph(s) 33 | graphs_to: coarsened graph(s) 34 | """ 35 | 36 | poolmaps = poolmaps if isinstance(poolmaps,(list,tuple)) else [poolmaps] 37 | graphs_from = graphs_from if isinstance(graphs_from,(list,tuple)) else [graphs_from] 38 | graphs_to = graphs_to if isinstance(graphs_to,(list,tuple)) else [graphs_to] 39 | 40 | idxn = [] 41 | degrees = [] 42 | p = 0 43 | 44 | for map, G_from, G_to in zip(poolmaps, graphs_from, graphs_to): 45 | for v in range(G_to.vcount()): 46 | nlist = map.get(v, []) 47 | idxn.extend([n+p for n in nlist]) 48 | degrees.append(len(nlist)) 49 | p += G_from.vcount() 50 | 51 | self._idxn = torch.LongTensor(idxn) 52 | self._degrees = torch.LongTensor(degrees) 53 | self._degrees_gpu = None 54 | 55 | def cuda(self): 56 | self._idxn = self._idxn.cuda() 57 | self._degrees_gpu = self._degrees.cuda() 58 | 59 | def get_buffers(self): 60 | """ Provides data to `GraphPoolModule`. 61 | """ 62 | return self._idxn, self._degrees, self._degrees_gpu 63 | -------------------------------------------------------------------------------- /learning/ecc/GraphPoolModule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs 3 | https://github.com/mys007/ecc 4 | https://arxiv.org/abs/1704.02901 5 | 2017 Martin Simonovsky 6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from builtins import range 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable, Function 14 | from .GraphPoolInfo import GraphPoolInfo 15 | from . import cuda_kernels 16 | from . import utils 17 | 18 | class GraphPoolFunction(Function): 19 | """ Computes node feature aggregation for each node of the coarsened graph. 
The evaluation is computed in blocks of size `edge_mem_limit` to reduce peak memory load. See `GraphPoolInfo` for info on `idxn, degs`. 20 | """ 21 | 22 | AGGR_MEAN = 0 23 | AGGR_MAX = 1 24 | 25 | def __init__(self, idxn, degs, degs_gpu, aggr, edge_mem_limit=1e20): 26 | super(GraphPoolFunction, self).__init__() 27 | self._idxn = idxn 28 | self._degs = degs 29 | self._degs_gpu = degs_gpu 30 | self._aggr = aggr 31 | self._shards = utils.get_edge_shards(degs, edge_mem_limit) 32 | 33 | def forward(self, input): 34 | output = input.new(self._degs.numel(), input.size(1)) 35 | if self._aggr==GraphPoolFunction.AGGR_MAX: 36 | self._max_indices = self._idxn.new(self._degs.numel(), input.size(1)).fill_(-1) 37 | 38 | self._input_size = input.size() 39 | 40 | # loop over blocks of output nodes 41 | startd, starte = 0, 0 42 | for numd, nume in self._shards: 43 | 44 | sel_input = torch.index_select(input, 0, self._idxn.narrow(0,starte,nume)) 45 | 46 | # aggregate over nodes 47 | if self._idxn.is_cuda: 48 | if self._aggr==GraphPoolFunction.AGGR_MEAN: 49 | cuda_kernels.avgpool_fw(output.narrow(0,startd,numd), sel_input, self._degs_gpu.narrow(0,startd,numd)) 50 | elif self._aggr==GraphPoolFunction.AGGR_MAX: 51 | cuda_kernels.maxpool_fw(output.narrow(0,startd,numd), self._max_indices.narrow(0,startd,numd), sel_input, self._degs_gpu.narrow(0,startd,numd)) 52 | else: 53 | k = 0 54 | for i in range(startd, startd+numd): 55 | if self._degs[i]>0: 56 | if self._aggr==GraphPoolFunction.AGGR_MEAN: 57 | torch.mean(sel_input.narrow(0,k,self._degs[i]), 0, out=output[i]) 58 | elif self._aggr==GraphPoolFunction.AGGR_MAX: 59 | torch.max(sel_input.narrow(0,k,self._degs[i]), 0, out=(output[i], self._max_indices[i])) 60 | else: 61 | output[i].fill_(0) 62 | k = k + self._degs[i] 63 | 64 | startd += numd 65 | starte += nume 66 | del sel_input 67 | 68 | return output 69 | 70 | 71 | def backward(self, grad_output): 72 | grad_input = grad_output.new(self._input_size).fill_(0) 73 | 74 | # loop over blocks 
of output nodes 75 | startd, starte = 0, 0 76 | for numd, nume in self._shards: 77 | 78 | grad_sel_input = grad_output.new(nume, grad_output.size(1)) 79 | 80 | # grad wrt input 81 | if self._idxn.is_cuda: 82 | if self._aggr==GraphPoolFunction.AGGR_MEAN: 83 | cuda_kernels.avgpool_bw(grad_input, self._idxn.narrow(0,starte,nume), grad_output.narrow(0,startd,numd), self._degs_gpu.narrow(0,startd,numd)) 84 | elif self._aggr==GraphPoolFunction.AGGR_MAX: 85 | cuda_kernels.maxpool_bw(grad_input, self._idxn.narrow(0,starte,nume), self._max_indices.narrow(0,startd,numd), grad_output.narrow(0,startd,numd), self._degs_gpu.narrow(0,startd,numd)) 86 | else: 87 | k = 0 88 | for i in range(startd, startd+numd): 89 | if self._degs[i]>0: 90 | if self._aggr==GraphPoolFunction.AGGR_MEAN: 91 | torch.div(grad_output[i], self._degs[i], out=grad_sel_input[k]) 92 | if self._degs[i]>1: 93 | grad_sel_input.narrow(0, k+1, self._degs[i]-1).copy_( grad_sel_input[k].expand(self._degs[i]-1,1,grad_output.size(1)) ) 94 | elif self._aggr==GraphPoolFunction.AGGR_MAX: 95 | grad_sel_input.narrow(0, k, self._degs[i]).fill_(0).scatter_(0, self._max_indices[i].view(1,-1), grad_output[i].view(1,-1)) 96 | k = k + self._degs[i] 97 | 98 | grad_input.index_add_(0, self._idxn.narrow(0,starte,nume), grad_sel_input) 99 | 100 | startd += numd 101 | starte += nume 102 | del grad_sel_input 103 | 104 | return grad_input 105 | 106 | 107 | 108 | class GraphPoolModule(nn.Module): 109 | """ Performs graph pooling. 110 | The input should be a 2D tensor of size (# nodes, `in_channels`). Multiple graphs can be concatenated in the same tensor (minibatch). 111 | 112 | Parameters: 113 | aggr: aggregation type (GraphPoolFunction.AGGR_MEAN, GraphPoolFunction.AGGR_MAX) 114 | gp_info: GraphPoolInfo object containing node mapping information, can be also set with `set_info()` method. 115 | edge_mem_limit: block size (number of evaluated edges in parallel), a low value reduces peak memory. 
116 | """ 117 | 118 | def __init__(self, aggr, gp_info=None, edge_mem_limit=1e20): 119 | super(GraphPoolModule, self).__init__() 120 | 121 | self._aggr = aggr 122 | self._edge_mem_limit = edge_mem_limit 123 | self.set_info(gp_info) 124 | 125 | def set_info(self, gp_info): 126 | self._gpi = gp_info 127 | 128 | def forward(self, input): 129 | idxn, degs, degs_gpu = self._gpi.get_buffers() 130 | return GraphPoolFunction(idxn, degs, degs_gpu, self._aggr, self._edge_mem_limit)(input) 131 | 132 | 133 | class GraphAvgPoolModule(GraphPoolModule): 134 | def __init__(self, gp_info=None, edge_mem_limit=1e20): 135 | super(GraphAvgPoolModule, self).__init__(GraphPoolFunction.AGGR_MEAN, gp_info, edge_mem_limit) 136 | 137 | class GraphMaxPoolModule(GraphPoolModule): 138 | def __init__(self, gp_info=None, edge_mem_limit=1e20): 139 | super(GraphMaxPoolModule, self).__init__(GraphPoolFunction.AGGR_MAX, gp_info, edge_mem_limit) -------------------------------------------------------------------------------- /learning/ecc/__init__.py: -------------------------------------------------------------------------------- 1 | from .GraphConvInfo import GraphConvInfo 2 | from .GraphConvModule import GraphConvModule, GraphConvFunction 3 | 4 | from .GraphPoolInfo import GraphPoolInfo 5 | from .GraphPoolModule import GraphAvgPoolModule, GraphMaxPoolModule 6 | 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /learning/ecc/cuda_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs 3 | https://github.com/mys007/ecc 4 | https://arxiv.org/abs/1704.02901 5 | 2017 Martin Simonovsky 6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from builtins import range 10 | 11 | import torch 12 | try: 13 | import cupy.cuda 14 | from pynvrtc.compiler import Program 15 | except: 16 | pass 17 | 
from collections import namedtuple 18 | import numpy as np 19 | 20 | CUDA_NUM_THREADS = 1024 21 | 22 | def GET_BLOCKS(N): 23 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS; 24 | 25 | modules = {} 26 | 27 | def get_dtype(t): 28 | if isinstance(t, torch.cuda.FloatTensor): 29 | return 'float' 30 | elif isinstance(t, torch.cuda.DoubleTensor): 31 | return 'double' 32 | 33 | def get_kernel_func(kname, ksrc, dtype): 34 | if kname+dtype not in modules: 35 | ksrc = ksrc.replace('DTYPE', dtype) 36 | #prog = Program(ksrc.encode('utf-8'), (kname+dtype+'.cu').encode('utf-8')) 37 | #uncomment the line above and comment the line below if it causes the following error: AttributeError: 'Program' object has no attribute '_program' 38 | prog = Program(ksrc, kname+dtype+'.cu') 39 | ptx = prog.compile() 40 | log = prog._interface.nvrtcGetProgramLog(prog._program) 41 | if len(log.strip()) > 0: print(log) 42 | module = cupy.cuda.function.Module() 43 | module.load(bytes(ptx.encode())) 44 | modules[kname+dtype] = module 45 | else: 46 | module = modules[kname+dtype] 47 | 48 | Stream = namedtuple('Stream', ['ptr']) 49 | s = Stream(ptr=torch.cuda.current_stream().cuda_stream) 50 | 51 | return module.get_function(kname), s 52 | 53 | #### 54 | 55 | def conv_aggregate_fw_kernel_v2(**kwargs): 56 | kernel = r''' 57 | extern "C" 58 | __global__ void conv_aggregate_fw_kernel_v2(DTYPE* dest, const DTYPE* src, const long long* lengths, const long long* cslengths, int width, int N, int dest_stridex, int src_stridex, int blockDimy) { 59 | 60 | int x = blockIdx.x * blockDim.x + threadIdx.x; //one thread per feature channel, runs over all nodes 61 | if (x >= width) return; 62 | 63 | int i = blockIdx.y * blockDimy; 64 | int imax = min(N, i + blockDimy); 65 | dest += dest_stridex * i + x; 66 | src += src_stridex * (cslengths[i] - lengths[i]) + x; 67 | 68 | for (; i 0) { 71 | DTYPE sum = 0; 72 | for (int j=0; j= width) return; 95 | 96 | int i = blockIdx.y * blockDimy; 97 | int imax = min(N, i + 
blockDimy); 98 | dest += dest_stridex * (cslengths[i] - lengths[i]) + x; 99 | src += src_stridex * i + x; 100 | 101 | for (; i 0) { 104 | DTYPE val = *src / len; 105 | for (int j=0; j= width) return; 150 | 151 | for (int i=0; i 0) { 153 | long long src_step = lengths[i] * src_stridex; 154 | long long bestjj = -1; 155 | DTYPE best = -1e10; 156 | 157 | for (long long j = x, jj=0; j < src_step; j += src_stridex, ++jj) { 158 | if (src[j] > best) { 159 | best = src[j]; 160 | bestjj = jj; 161 | } 162 | } 163 | 164 | dest[x] = best; 165 | indices[x] = bestjj; 166 | 167 | src += src_step; 168 | } 169 | else { 170 | dest[x] = 0; 171 | indices[x] = -1; 172 | } 173 | 174 | dest += dest_stridex; 175 | indices += dest_stridex; 176 | } 177 | } 178 | ''' 179 | return kernel 180 | 181 | def maxpool_bw_kernel(**kwargs): 182 | kernel = r''' 183 | //also directly scatters results by dest_indices (saves one sparse intermediate buffer) 184 | extern "C" 185 | __global__ void maxpool_bw_kernel(DTYPE* dest, const long long* dest_indices, const long long* max_indices, const DTYPE* src, const long long* lengths, int width, int N, int dest_stridex, int src_stridex) { 186 | 187 | int x = blockIdx.x * blockDim.x + threadIdx.x; //one thread per feature channel, runs over all points 188 | if (x >= width) return; 189 | 190 | for (int i=0; i 0) { 192 | 193 | long long destidx = dest_indices[max_indices[x]]; 194 | dest[x + destidx * dest_stridex] += src[x]; //no need for atomicadd, only one threads cares about each feat 195 | 196 | dest_indices += lengths[i]; 197 | } 198 | 199 | src += src_stridex; 200 | max_indices += src_stridex; 201 | } 202 | } 203 | ''' 204 | return kernel 205 | 206 | 207 | def maxpool_fw(dest, indices, src, degs): 208 | n = degs.numel() 209 | w = src.size(1) 210 | assert n == dest.size(0) and w == dest.size(1) 211 | assert type(src)==type(dest) and isinstance(degs, torch.cuda.LongTensor) and isinstance(indices, torch.cuda.LongTensor) 212 | 213 | function, stream = 
get_kernel_func('maxpool_fw_kernel', maxpool_fw_kernel(), get_dtype(src)) 214 | function(args=[dest.data_ptr(), indices.data_ptr(), src.data_ptr(), degs.data_ptr(), np.int32(w), np.int32(n), np.int32(dest.stride(0)), np.int32(src.stride(0))], 215 | block=(CUDA_NUM_THREADS,1,1), grid=(GET_BLOCKS(w),1,1), stream=stream) 216 | 217 | def maxpool_bw(dest, idxn, indices, src, degs): 218 | n = degs.numel() 219 | w = src.size(1) 220 | assert n == src.size(0) and w == dest.size(1) 221 | assert type(src)==type(dest) and isinstance(degs, torch.cuda.LongTensor) and isinstance(indices, torch.cuda.LongTensor) and isinstance(idxn, torch.cuda.LongTensor) 222 | 223 | function, stream = get_kernel_func('maxpool_bw_kernel', maxpool_bw_kernel(), get_dtype(src)) 224 | function(args=[dest.data_ptr(), idxn.data_ptr(), indices.data_ptr(), src.data_ptr(), degs.data_ptr(), np.int32(w), np.int32(n), np.int32(dest.stride(0)), np.int32(src.stride(0))], 225 | block=(CUDA_NUM_THREADS,1,1), grid=(GET_BLOCKS(w),1,1), stream=stream) 226 | 227 | 228 | 229 | def avgpool_bw_kernel(**kwargs): 230 | kernel = r''' 231 | //also directly scatters results by dest_indices (saves one intermediate buffer) 232 | extern "C" 233 | __global__ void avgpool_bw_kernel(DTYPE* dest, const long long* dest_indices, const DTYPE* src, const long long* lengths, int width, int N, int dest_stridex, int src_stridex) { 234 | 235 | int x = blockIdx.x * blockDim.x + threadIdx.x; //one thread per feature channel, runs over all points 236 | if (x >= width) return; 237 | 238 | for (int i=0; i 0) { 240 | 241 | DTYPE val = src[x] / lengths[i]; 242 | 243 | for (int j = 0; j < lengths[i]; ++j) { 244 | long long destidx = dest_indices[j]; 245 | dest[x + destidx * dest_stridex] += val; //no need for atomicadd, only one threads cares about each feat 246 | } 247 | 248 | dest_indices += lengths[i]; 249 | } 250 | 251 | src += src_stridex; 252 | } 253 | } 254 | ''' 255 | return kernel 256 | 257 | 258 | def avgpool_fw(dest, src, degs): 259 | 
conv_aggregate_fw(dest, src, degs) 260 | 261 | def avgpool_bw(dest, idxn, src, degs): 262 | n = degs.numel() 263 | w = src.size(1) 264 | assert n == src.size(0) and w == dest.size(1) 265 | assert type(src)==type(dest) and isinstance(degs, torch.cuda.LongTensor) and isinstance(idxn, torch.cuda.LongTensor) 266 | 267 | function, stream = get_kernel_func('avgpool_bw_kernel', avgpool_bw_kernel(), get_dtype(src)) 268 | function(args=[dest.data_ptr(), idxn.data_ptr(), src.data_ptr(), degs.data_ptr(), np.int32(w), np.int32(n), np.int32(dest.stride(0)), np.int32(src.stride(0))], 269 | block=(CUDA_NUM_THREADS,1,1), grid=(GET_BLOCKS(w),1,1), stream=stream) 270 | -------------------------------------------------------------------------------- /learning/ecc/test_GraphConvModule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs 3 | https://github.com/mys007/ecc 4 | https://arxiv.org/abs/1704.02901 5 | 2017 Martin Simonovsky 6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from builtins import range 10 | 11 | import unittest 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable, gradcheck 16 | 17 | from .GraphConvModule import * 18 | from .GraphConvInfo import GraphConvInfo 19 | 20 | 21 | class TestGraphConvModule(unittest.TestCase): 22 | 23 | def test_gradcheck(self): 24 | 25 | torch.set_default_tensor_type('torch.DoubleTensor') #necessary for proper numerical gradient 26 | 27 | for cuda in range(0,2): 28 | # without idxe 29 | n,e,in_channels, out_channels = 20,50,10, 15 30 | input = torch.randn(n,in_channels) 31 | weights = torch.randn(e,in_channels,out_channels) 32 | idxn = torch.from_numpy(np.random.randint(n,size=e)) 33 | idxe = None 34 | degs = torch.LongTensor([5, 0, 15, 20, 10]) #strided conv 35 | degs_gpu = degs 36 | edge_mem_limit = 30 # some nodes will 
be combined, some not 37 | if cuda: 38 | input = input.cuda(); weights = weights.cuda(); idxn = idxn.cuda(); degs_gpu = degs_gpu.cuda() 39 | 40 | func = lambda x, w: GraphConvFunction.apply(x, w, in_channels, out_channels, idxn, idxe, degs, degs_gpu, edge_mem_limit) # static Function: call through apply() 41 | data = (Variable(input, requires_grad=True), Variable(weights, requires_grad=True)) 42 | 43 | ok = gradcheck(func, data) 44 | self.assertTrue(ok) 45 | 46 | # with idxe 47 | weights = torch.randn(30,in_channels,out_channels) 48 | idxe = torch.from_numpy(np.random.randint(30,size=e)) 49 | if cuda: 50 | weights = weights.cuda(); idxe = idxe.cuda() 51 | 52 | func = lambda x, w: GraphConvFunction.apply(x, w, in_channels, out_channels, idxn, idxe, degs, degs_gpu, edge_mem_limit) 53 | 54 | ok = gradcheck(func, data) 55 | self.assertTrue(ok) 56 | 57 | torch.set_default_tensor_type('torch.FloatTensor') 58 | 59 | def test_batch_splitting(self): 60 | 61 | n,e,in_channels, out_channels = 20,50,10, 15 62 | input = torch.randn(n,in_channels) 63 | weights = torch.randn(e,in_channels,out_channels) 64 | idxn = torch.from_numpy(np.random.randint(n,size=e)) 65 | idxe = None 66 | degs = torch.LongTensor([5, 0, 15, 20, 10]) #strided conv 67 | 68 | func = lambda x, w: GraphConvFunction.apply(x, w, in_channels, out_channels, idxn, idxe, degs, degs, 1e10) # a single shard 69 | data = (Variable(input, requires_grad=True), Variable(weights, requires_grad=True)) 70 | output1 = func(*data) 71 | 72 | func = lambda x, w: GraphConvFunction.apply(x, w, in_channels, out_channels, idxn, idxe, degs, degs, 1) # smallest possible shards 73 | output2 = func(*data) 74 | 75 | self.assertLess((output1-output2).norm().item(), 1e-6) 76 | 77 | 78 | 79 | if __name__ == '__main__': 80 | unittest.main() -------------------------------------------------------------------------------- /learning/ecc/test_GraphPoolModule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs 3 | https://github.com/mys007/ecc 4 | 
https://arxiv.org/abs/1704.02901 5 | 2017 Martin Simonovsky 6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from builtins import range 10 | 11 | import unittest 12 | import numpy as np 13 | import torch 14 | from torch.autograd import Variable, gradcheck 15 | 16 | from .GraphPoolModule import * 17 | from .GraphPoolInfo import GraphPoolInfo 18 | 19 | 20 | class TestGraphPoolModule(unittest.TestCase): 21 | 22 | def test_gradcheck(self): 23 | 24 | torch.set_default_tensor_type('torch.DoubleTensor') #necessary for proper numerical gradient 25 | 26 | for cuda in range(0,2): 27 | for aggr in range(0,2): 28 | n,in_channels = 20,10 29 | input = torch.randn(n,in_channels) 30 | idxn = torch.from_numpy(np.random.permutation(n)) 31 | degs = torch.LongTensor([2, 0, 3, 10, 5]) 32 | degs_gpu = degs 33 | edge_mem_limit = 30 # some nodes will be combined, some not 34 | if cuda: 35 | input = input.cuda(); idxn = idxn.cuda(); degs_gpu = degs_gpu.cuda() 36 | 37 | func = GraphPoolFunction(idxn, degs, degs_gpu, aggr=aggr, edge_mem_limit=edge_mem_limit) 38 | data = (Variable(input, requires_grad=True),) 39 | 40 | ok = gradcheck(func, data) 41 | self.assertTrue(ok) 42 | 43 | torch.set_default_tensor_type('torch.FloatTensor') 44 | 45 | def test_batch_splitting(self): 46 | n,in_channels = 20,10 47 | input = torch.randn(n,in_channels) 48 | idxn = torch.from_numpy(np.random.permutation(n)) 49 | degs = torch.LongTensor([2, 0, 3, 10, 5]) 50 | 51 | func = GraphPoolFunction(idxn, degs, degs, aggr=GraphPoolFunction.AGGR_MAX, edge_mem_limit=1e10) 52 | data = (Variable(input, requires_grad=True),) 53 | output1 = func(*data) 54 | 55 | func = GraphPoolFunction(idxn, degs, degs, aggr=GraphPoolFunction.AGGR_MAX, edge_mem_limit=1) 56 | output2 = func(*data) 57 | 58 | self.assertLess((output1-output2).norm(), 1e-6) 59 | 60 | if __name__ == '__main__': 61 | unittest.main() --------------------------------------------------------------------------------
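The pooling tests above check gradients and batch splitting, but not the pooled values themselves. As a rough cross-check, the CPU branch of `GraphPoolFunction.forward` can be mirrored in a few lines of NumPy. This is only a sketch and not part of the repository: `reference_graph_pool` is a hypothetical helper, and it assumes that `idxn` lists the input rows group after group, with `degs[i]` giving the size of group `i`, as described in `GraphPoolInfo`.

```python
import numpy as np


def reference_graph_pool(x, idxn, degs, aggr="mean"):
    """Pool rows of `x` into len(degs) output rows (hypothetical reference helper).

    idxn: input row indices, concatenated group after group
    degs: number of input rows feeding each output row; empty groups give zeros
    """
    out = np.zeros((len(degs), x.shape[1]), dtype=x.dtype)
    k = 0  # offset of the current group inside idxn
    for i, d in enumerate(degs):
        if d > 0:
            rows = x[idxn[k:k + d]]
            out[i] = rows.mean(axis=0) if aggr == "mean" else rows.max(axis=0)
        k += d
    return out


# Six input nodes pooled into four output nodes; output node 1 has no inputs.
x = np.arange(12, dtype=np.float64).reshape(6, 2)
idxn = np.array([1, 0, 2, 3, 4, 5])
degs = np.array([2, 0, 3, 1])

pooled_mean = reference_graph_pool(x, idxn, degs, "mean")
pooled_max = reference_graph_pool(x, idxn, degs, "max")
```

Because the helper never splits the work into shards, comparing its output against the module for several `edge_mem_limit` values would exercise the same property as `test_batch_splitting`.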
/learning/ecc/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs 3 | https://github.com/mys007/ecc 4 | https://arxiv.org/abs/1704.02901 5 | 2017 Martin Simonovsky 6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from builtins import range 10 | 11 | import random 12 | import numpy as np 13 | import torch 14 | 15 | import ecc 16 | 17 | def graph_info_collate_classification(batch, edge_func): 18 | """ Collates a list of dataset samples into a single batch. We assume that all samples have the same number of resolutions. 19 | 20 | Each sample is a tuple of following elements: 21 | features: 2D Tensor of node features 22 | classes: LongTensor of class ids 23 | graphs: list of graphs, each for one resolution 24 | pooldata: list of triplets, each for one resolution: (pooling map, finer graph, coarser graph) 25 | """ 26 | features, classes, graphs, pooldata = list(zip(*batch)) 27 | graphs_by_layer = list(zip(*graphs)) 28 | pooldata_by_layer = list(zip(*pooldata)) 29 | 30 | features = torch.cat([torch.from_numpy(f) for f in features]) 31 | if features.dim()==1: features = features.view(-1,1) 32 | 33 | classes = torch.LongTensor(classes) 34 | 35 | GIs, PIs = [], [] 36 | for graphs in graphs_by_layer: 37 | GIs.append( ecc.GraphConvInfo(graphs, edge_func) ) 38 | for pooldata in pooldata_by_layer: 39 | PIs.append( ecc.GraphPoolInfo(*zip(*pooldata)) ) 40 | 41 | return features, classes, GIs, PIs 42 | 43 | 44 | def unique_rows(data): 45 | """ Filters unique rows from a 2D np array and also returns inverse indices. Used for edge feature compaction. 
""" 46 | # https://stackoverflow.com/questions/16970982/find-unique-rows-in-numpy-array 47 | uniq, indices = np.unique(data.view(data.dtype.descr * data.shape[1]), return_inverse=True) 48 | return uniq.view(data.dtype).reshape(-1, data.shape[1]), indices 49 | 50 | def one_hot_discretization(feat, clip_min, clip_max, upweight): 51 | indices = np.clip(np.round(feat), clip_min, clip_max).astype(int).reshape((-1,)) 52 | onehot = np.zeros((feat.shape[0], clip_max - clip_min + 1)) 53 | onehot[np.arange(onehot.shape[0]), indices] = onehot.shape[1] if upweight else 1 54 | return onehot 55 | 56 | def get_edge_shards(degs, edge_mem_limit): 57 | """ Splits iteration over nodes into shards, approximately limited by `edge_mem_limit` edges per shard. 58 | Returns a list of pairs indicating how many output nodes and edges to process in each shard.""" 59 | d = degs if isinstance(degs, np.ndarray) else degs.numpy() 60 | cs = np.cumsum(d) 61 | cse = cs // edge_mem_limit 62 | _, cse_i, cse_c = np.unique(cse, return_index=True, return_counts=True) 63 | 64 | shards = [] 65 | for b in range(len(cse_i)): 66 | numd = cse_c[b] 67 | nume = (cs[-1] if b==len(cse_i)-1 else cs[cse_i[b+1]-1]) - cs[cse_i[b]] + d[cse_i[b]] 68 | shards.append( (int(numd), int(nume)) ) 69 | return shards 70 | -------------------------------------------------------------------------------- /learning/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Jul 4 09:22:55 2019 5 | 6 | @author: landrieuloic 7 | """ 8 | 9 | """ 10 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 11 | http://arxiv.org/abs/1711.09869 12 | 2017 Loic Landrieu, Martin Simonovsky 13 | """ 14 | import argparse 15 | import numpy as np 16 | import sys 17 | sys.path.append("./learning") 18 | from metrics import * 19 | 20 | parser = argparse.ArgumentParser(description='Evaluation function for S3DIS') 21 | 22 
| parser.add_argument('--odir', default='./results/s3dis/best', help='Directory containing the per-fold results to evaluate') 23 | parser.add_argument('--dataset', default='s3dis', help='Dataset name: s3dis|vkitti') 24 | parser.add_argument('--cvfold', default='123456', help='Folds to evaluate, one digit per fold') 25 | 26 | args = parser.parse_args() 27 | 28 | 29 | 30 | if args.dataset == 's3dis': 31 | n_labels = 13 32 | inv_class_map = {0:'ceiling', 1:'floor', 2:'wall', 3:'column', 4:'beam', 5:'window', 6:'door', 7:'table', 8:'chair', 9:'bookcase', 10:'sofa', 11:'board', 12:'clutter'} 33 | base_name = args.odir+'/cv' 34 | elif args.dataset == 'vkitti': 35 | n_labels = 13 36 | inv_class_map = {0:'Terrain', 1:'Tree', 2:'Vegetation', 3:'Building', 4:'Road', 5:'GuardRail', 6:'TrafficSign', 7:'TrafficLight', 8:'Pole', 9:'Misc', 10:'Truck', 11:'Car', 12:'Van'} 37 | base_name = args.odir+'/cv' 38 | 39 | C = ConfusionMatrix(n_labels) 40 | C.confusion_matrix=np.zeros((n_labels, n_labels)) 41 | 42 | 43 | for i_fold in range(len(args.cvfold)): 44 | fold = int(args.cvfold[i_fold]) 45 | cm = ConfusionMatrix(n_labels) 46 | cm.confusion_matrix=np.load(base_name+str(fold) +'/pointwise_cm.npy') 47 | print("Fold %d : \t OA = %3.2f \t mA = %3.2f \t mIoU = %3.2f" % (fold, \ 48 | 100 * ConfusionMatrix.get_overall_accuracy(cm) \ 49 | , 100 * ConfusionMatrix.get_mean_class_accuracy(cm) \ 50 | , 100 * ConfusionMatrix.get_average_intersection_union(cm) 51 | )) 52 | C.confusion_matrix += cm.confusion_matrix 53 | 54 | print("\nOverall accuracy : %3.2f %%" % (100 * (ConfusionMatrix.get_overall_accuracy(C)))) 55 | print("Mean accuracy : %3.2f %%" % (100 * (ConfusionMatrix.get_mean_class_accuracy(C)))) 56 | print("Mean IoU : %3.2f %%\n" % (100 * (ConfusionMatrix.get_average_intersection_union(C)))) 57 | print(" Class : IoU") 58 | for c in range(0,n_labels): 59 | print (" %12s : %6.2f %% \t %.1e points" %(inv_class_map[c],100*ConfusionMatrix.get_intersection_union_per_class(C)[c], ConfusionMatrix.count_gt(C,c))) 60 | 
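The script above sums the per-fold confusion matrices and reads three scores from the total. The sketch below shows the usual definitions of those scores in plain NumPy. It is an illustration under stated assumptions, not the code in `metrics.py`: `segmentation_scores` is a hypothetical helper, rows are assumed to index ground-truth classes and columns predictions, and classes with no points are skipped, which may differ from how `ConfusionMatrix` treats them.

```python
import numpy as np


def segmentation_scores(cm):
    """Return (overall accuracy, mean class accuracy, mean IoU, per-class IoU).

    Assumes cm[g, p] counts points of ground-truth class g predicted as class p.
    """
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)              # correctly labelled points per class
    gt = cm.sum(axis=1)           # ground-truth points per class
    pred = cm.sum(axis=0)         # predicted points per class
    union = gt + pred - tp

    seen = gt > 0                 # classes that occur in the ground truth
    present = union > 0           # classes that occur at all

    oa = tp.sum() / cm.sum()
    mean_acc = np.mean(tp[seen] / gt[seen])
    iou = np.full(len(tp), np.nan)
    iou[present] = tp[present] / union[present]
    return oa, mean_acc, np.nanmean(iou), iou


# Two toy folds with two classes; folds are summed before scoring,
# so larger folds weigh more than smaller ones.
fold_a = np.array([[50, 0], [10, 40]])
fold_b = np.array([[30, 10], [0, 60]])
oa, mean_acc, miou, iou = segmentation_scores(fold_a + fold_b)
```

Summing the matrices first, as the script does, gives point-weighted scores over all folds; averaging the per-fold scores instead would give each fold equal weight regardless of its size.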
-------------------------------------------------------------------------------- /learning/graphnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 3 | http://arxiv.org/abs/1711.09869 4 | 2017 Loic Landrieu, Martin Simonovsky 5 | """ 6 | from __future__ import division 7 | from __future__ import print_function 8 | from builtins import range 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | from learning import ecc 14 | from learning.modules import RNNGraphConvModule, ECC_CRFModule, GRUCellEx, LSTMCellEx 15 | 16 | 17 | def create_fnet(widths, orthoinit, llbias, bnidx=-1): 18 | """ Creates a feature-generating network, a multi-layer perceptron. 19 | Parameters: 20 | widths: list of widths of layers (including input and output widths) 21 | orthoinit: whether to use orthogonal weight initialization 22 | llbias: whether to use bias in the last layer 23 | bnidx: index of batch normalization (-1 if not used) 24 | """ 25 | fnet_modules = [] 26 | for k in range(len(widths)-2): 27 | fnet_modules.append(nn.Linear(widths[k], widths[k+1])) 28 | if orthoinit: init.orthogonal_(fnet_modules[-1].weight, gain=init.calculate_gain('relu')) 29 | if bnidx==k: fnet_modules.append(nn.BatchNorm1d(widths[k+1])) 30 | fnet_modules.append(nn.ReLU(True)) 31 | fnet_modules.append(nn.Linear(widths[-2], widths[-1], bias=llbias)) 32 | if orthoinit: init.orthogonal_(fnet_modules[-1].weight) 33 | if bnidx==len(widths)-1: fnet_modules.append(nn.BatchNorm1d(fnet_modules[-1].weight.size(0))) 34 | return nn.Sequential(*fnet_modules) 35 | 36 | 37 | class GraphNetwork(nn.Module): 38 | """ It is constructed in a flexible way based on the `config` string, which contains a sequence of comma-delimited layer definition tokens layer_arg1_arg2_... See README.md for examples. 
39 | """ 40 | def __init__(self, config, nfeat, fnet_widths, fnet_orthoinit=True, fnet_llbias=True, fnet_bnidx=-1, edge_mem_limit=1e20, use_pyg = True, cuda = True): 41 | super(GraphNetwork, self).__init__() 42 | self.gconvs = [] 43 | 44 | for d, conf in enumerate(config.split(',')): 45 | conf = conf.strip().split('_') 46 | 47 | if conf[0]=='f': #Fully connected layer; args: output_feats 48 | self.add_module(str(d), nn.Linear(nfeat, int(conf[1]))) 49 | nfeat = int(conf[1]) 50 | elif conf[0]=='b': #Batch norm; args: not_affine 51 | self.add_module(str(d), nn.BatchNorm1d(nfeat, eps=1e-5, affine=len(conf)==1)) 52 | elif conf[0]=='r': #ReLU; 53 | self.add_module(str(d), nn.ReLU(True)) 54 | elif conf[0]=='d': #Dropout; args: dropout_prob 55 | self.add_module(str(d), nn.Dropout(p=float(conf[1]), inplace=False)) 56 | 57 | elif conf[0]=='crf': #ECC-CRF; args: repeats 58 | nrepeats = int(conf[1]) 59 | 60 | fnet = create_fnet(fnet_widths + [nfeat*nfeat], fnet_orthoinit, fnet_llbias, fnet_bnidx) 61 | gconv = ecc.GraphConvModule(nfeat, nfeat, fnet, edge_mem_limit=edge_mem_limit) 62 | crf = ECC_CRFModule(gconv, nrepeats) 63 | self.add_module(str(d), crf) 64 | self.gconvs.append(gconv) 65 | 66 | elif conf[0]=='gru' or conf[0]=='lstm': #RNN-ECC args: repeats, mv=False, layernorm=True, ingate=True, cat_all=True 67 | nrepeats = int(conf[1]) 68 | vv = bool(int(conf[2])) if len(conf)>2 else True # whether ECC does matrix-value mult or element-wise mult 69 | layernorm = bool(int(conf[3])) if len(conf)>3 else True 70 | ingate = bool(int(conf[4])) if len(conf)>4 else True 71 | cat_all = bool(int(conf[5])) if len(conf)>5 else True 72 | 73 | fnet = create_fnet(fnet_widths + [nfeat**2 if not vv else nfeat], fnet_orthoinit, fnet_llbias, fnet_bnidx) 74 | if conf[0]=='gru': 75 | cell = GRUCellEx(nfeat, nfeat, bias=True, layernorm=layernorm, ingate=ingate) 76 | else: 77 | cell = LSTMCellEx(nfeat, nfeat, bias=True, layernorm=layernorm, ingate=ingate) 78 | gconv = RNNGraphConvModule(cell, fnet, 
nfeat, vv = vv, nrepeats=nrepeats, cat_all=cat_all, edge_mem_limit=edge_mem_limit, use_pyg = use_pyg, cuda = cuda) 79 | self.add_module(str(d), gconv) 80 | self.gconvs.append(gconv) 81 | if cat_all: nfeat *= nrepeats + 1 82 | 83 | elif len(conf[0])>0: 84 | raise NotImplementedError('Unknown module: ' + conf[0]) 85 | 86 | 87 | def set_info(self, gc_infos, cuda): 88 | """ Provides convolution modules with graph structure information for the current batch. 89 | """ 90 | gc_infos = gc_infos if isinstance(gc_infos,(list,tuple)) else [gc_infos] 91 | for i,gc in enumerate(self.gconvs): 92 | if cuda: gc_infos[i].cuda() 93 | gc.set_info(gc_infos[i]) 94 | 95 | def forward(self, input): 96 | for module in self._modules.values(): 97 | input = module(input) 98 | return input 99 | 100 | -------------------------------------------------------------------------------- /learning/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from builtins import range 4 | 5 | import numpy as np 6 | 7 | # extended official code from http://www.semantic3d.net/scripts/metric.py 8 | class ConfusionMatrix: 9 | """Streaming interface to allow for any source of predictions. 
Initialize it, count predictions one by one, then print confusion matrix and intersection-union score""" 10 | def __init__(self, number_of_labels = 2): 11 | self.number_of_labels = number_of_labels 12 | self.confusion_matrix = np.zeros(shape=(self.number_of_labels,self.number_of_labels)) 13 | def count_predicted(self, ground_truth, predicted, number_of_added_elements=1): 14 | self.confusion_matrix[ground_truth][predicted] += number_of_added_elements 15 | 16 | def count_predicted_batch(self, ground_truth_vec, predicted): # added 17 | for i in range(ground_truth_vec.shape[0]): 18 | self.confusion_matrix[:,predicted[i]] += ground_truth_vec[i,:] 19 | 20 | def count_predicted_batch_hard(self, ground_truth_vec, predicted): # added 21 | for i in range(ground_truth_vec.shape[0]): 22 | self.confusion_matrix[ground_truth_vec[i],predicted[i]] += 1 23 | 24 | """labels are integers from 0 to number_of_labels-1""" 25 | def get_count(self, ground_truth, predicted): 26 | return self.confusion_matrix[ground_truth][predicted] 27 | """returns list of lists of integers; use it as result[ground_truth][predicted] 28 | to know how many samples of class ground_truth were reported as class predicted""" 29 | def get_confusion_matrix(self): 30 | return self.confusion_matrix 31 | """returns list of 64-bit floats""" 32 | def get_intersection_union_per_class(self): 33 | matrix_diagonal = [self.confusion_matrix[i][i] for i in range(self.number_of_labels)] 34 | errors_summed_by_row = [0] * self.number_of_labels 35 | for row in range(self.number_of_labels): 36 | for column in range(self.number_of_labels): 37 | if row != column: 38 | errors_summed_by_row[row] += self.confusion_matrix[row][column] 39 | errors_summed_by_column = [0] * self.number_of_labels 40 | for column in range(self.number_of_labels): 41 | for row in range(self.number_of_labels): 42 | if row != column: 43 | errors_summed_by_column[column] += self.confusion_matrix[row][column] 44 | 45 | divisor = [0] * self.number_of_labels 46 | 
for i in range(self.number_of_labels): 47 | divisor[i] = matrix_diagonal[i] + errors_summed_by_row[i] + errors_summed_by_column[i] 48 | if matrix_diagonal[i] == 0: 49 | divisor[i] = 1 50 | 51 | return [float(matrix_diagonal[i]) / divisor[i] for i in range(self.number_of_labels)] 52 | """returns 64-bit float""" 53 | 54 | def get_overall_accuracy(self): 55 | matrix_diagonal = 0 56 | all_values = 0 57 | for row in range(self.number_of_labels): 58 | for column in range(self.number_of_labels): 59 | all_values += self.confusion_matrix[row][column] 60 | if row == column: 61 | matrix_diagonal += self.confusion_matrix[row][column] 62 | if all_values == 0: 63 | all_values = 1 64 | return float(matrix_diagonal) / all_values 65 | 66 | 67 | def get_average_intersection_union(self): 68 | values = self.get_intersection_union_per_class() 69 | class_seen = ((self.confusion_matrix.sum(1)+self.confusion_matrix.sum(0))!=0).sum() 70 | return sum(values) / class_seen 71 | 72 | def get_mean_class_accuracy(self): # added 73 | re = 0 74 | for i in range(self.number_of_labels): 75 | re = re + self.confusion_matrix[i][i] / max(1,np.sum(self.confusion_matrix[i,:])) 76 | return re/self.number_of_labels 77 | 78 | def count_gt(self, ground_truth): 79 | return self.confusion_matrix[ground_truth,:].sum() 80 | 81 | def compute_predicted_transitions(in_component, edg_source, edg_target): 82 | 83 | pred_transitions = in_component[edg_source] != in_component[edg_target] 84 | return pred_transitions 85 | 86 | #----------------------------------------------------------- 87 | def compute_boundary_recall(is_transition, pred_transitions): 88 | return 100*((is_transition==pred_transitions)*is_transition).sum()/is_transition.sum() 89 | 90 | #----------------------------------------------------------- 91 | def compute_boundary_precision(is_transition, pred_transitions): 92 | return 100*((is_transition==pred_transitions)*pred_transitions).sum()/pred_transitions.sum() 93 | 
#-------------------------------------------- 94 | 95 | def mode(array, only_freq=False): 96 | value, counts = np.unique(array, return_counts=True) 97 | if only_freq: 98 | return np.amax(counts) 99 | else: 100 | return value[np.argmax(counts)], np.amax(counts) 101 | #------------------------------------------------ 102 | def compute_OOA(components, labels): 103 | hard_labels = labels.argmax(1) 104 | correct_labels=0 105 | for comp in components: 106 | dump, freq = mode(hard_labels[comp]) 107 | correct_labels+=freq 108 | return 100*correct_labels/len(labels) 109 | -------------------------------------------------------------------------------- /learning/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 3 | http://arxiv.org/abs/1711.09869 4 | 2017 Loic Landrieu, Martin Simonovsky 5 | """ 6 | from __future__ import division 7 | from __future__ import print_function 8 | from builtins import range 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as nnf 13 | from torch.autograd import Variable 14 | from learning import ecc 15 | 16 | HAS_PYG = False 17 | try: 18 | from torch_geometric.nn.conv import MessagePassing 19 | from torch_geometric.nn.inits import uniform 20 | HAS_PYG = True 21 | except: 22 | pass 23 | 24 | if HAS_PYG: 25 | class NNConv(MessagePassing): 26 | r"""The continuous kernel-based convolutional operator from the 27 | `"Neural Message Passing for Quantum Chemistry" 28 | <https://arxiv.org/abs/1704.01212>`_ paper. 29 | This convolution is also known as the edge-conditioned convolution from the 30 | `"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on 31 | Graphs" <https://arxiv.org/abs/1704.02901>`_ paper (see 32 | :class:`torch_geometric.nn.conv.ECConv` for an alias): 33 | 34 | .. 
math:: 35 | \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i + 36 | \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot 37 | h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}), 38 | 39 | where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* 40 | a MLP. 41 | 42 | Args: 43 | in_channels (int): Size of each input sample. 44 | out_channels (int): Size of each output sample. 45 | nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that 46 | maps edge features :obj:`edge_attr` of shape :obj:`[-1, 47 | num_edge_features]` to shape 48 | :obj:`[-1, in_channels * out_channels]`, *e.g.*, defined by 49 | :class:`torch.nn.Sequential`. 50 | aggr (string, optional): The aggregation scheme to use 51 | (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). 52 | (default: :obj:`"add"`) 53 | root_weight (bool, optional): If set to :obj:`False`, the layer will 54 | not add the transformed root node features to the output. 55 | (default: :obj:`True`) 56 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 57 | an additive bias. (default: :obj:`True`) 58 | **kwargs (optional): Additional arguments of 59 | :class:`torch_geometric.nn.conv.MessagePassing`. 
60 | """ 61 | def __init__(self, 62 | in_channels, 63 | out_channels, 64 | aggr='mean', 65 | root_weight=False, 66 | bias=False, 67 | vv=True, 68 | flow="target_to_source", 69 | negative_slope=0.2, 70 | softmax=False, 71 | **kwargs): 72 | super(NNConv, self).__init__(aggr=aggr, **kwargs) 73 | 74 | self.in_channels = in_channels 75 | self.out_channels = out_channels 76 | self.aggr = aggr 77 | self.vv = vv 78 | self.negative_slope = negative_slope 79 | self.softmax = softmax 80 | 81 | if root_weight: 82 | self.root = nn.Parameter(torch.Tensor(in_channels, out_channels)) 83 | else: 84 | self.register_parameter('root', None) 85 | 86 | if bias: 87 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 88 | else: 89 | self.register_parameter('bias', None) 90 | 91 | self.reset_parameters() 92 | 93 | def reset_parameters(self): 94 | uniform(self.in_channels, self.root) 95 | uniform(self.in_channels, self.bias) 96 | 97 | def forward(self, x, edge_index, weights): 98 | """""" 99 | x = x.unsqueeze(-1) if x.dim() == 1 else x 100 | return self.propagate(edge_index, x=x, weights=weights) 101 | 102 | def message(self, edge_index_i, x_j, size_i, weights): 103 | if not self.vv: 104 | weight = weights.view(-1, self.in_channels, self.out_channels) 105 | if self.softmax: # APPLY A TWO DIMENSIONAL NON-DEPENDENT SPARSE SOFTMAX 106 | weight = nnf.leaky_relu(weight, self.negative_slope) 107 | weight = torch.cat([softmax(weight[:, k, :], edge_index_i, size_i).unsqueeze(1) for k in range(self.out_channels)], dim=1) 108 | return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1) 109 | else: 110 | weight = weights.view(-1, self.in_channels) 111 | if self.softmax: 112 | weight = nnf.leaky_relu(weight, self.negative_slope) 113 | weight = torch.cat([softmax(w.unsqueeze(-1), edge_index_i, size_i).t() for w in weight.t()], dim=0).t() 114 | return x_j * weight 115 | 116 | def update(self, aggr_out, x): 117 | if self.root is not None: 118 | aggr_out = aggr_out + torch.mm(x, self.root) 119 | if self.bias is not 
None: 120 | aggr_out = aggr_out + self.bias 121 | return aggr_out 122 | 123 | def __repr__(self): 124 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 125 | self.out_channels) 126 | 127 | 128 | class RNNGraphConvModule(nn.Module): 129 | """ 130 | Computes recurrent graph convolution using filter weights obtained from a Filter generating network (`filter_net`). 131 | Its result is passed to RNN `cell` and the process is repeated over `nrepeats` iterations. 132 | Weight sharing over iterations is done both in RNN cell and in Filter generating network. 133 | """ 134 | def __init__(self, cell, filter_net, nfeat, vv = True, gc_info=None, nrepeats=1, cat_all=False, edge_mem_limit=1e20, use_pyg = True, cuda = True): 135 | super(RNNGraphConvModule, self).__init__() 136 | self._cell = cell 137 | self._isLSTM = 'LSTM' in type(cell).__name__ 138 | self._fnet = filter_net 139 | self._nrepeats = nrepeats 140 | self._cat_all = cat_all 141 | self._edge_mem_limit = edge_mem_limit 142 | self.set_info(gc_info) 143 | self.use_pyg = use_pyg 144 | if use_pyg: 145 | self.nn = NNConv(nfeat, nfeat, vv = vv) 146 | if cuda: 147 | self.nn = self.nn.cuda() 148 | 149 | def set_info(self, gc_info): 150 | self._gci = gc_info 151 | 152 | def forward(self, hx): 153 | # get graph structure information tensors 154 | idxn, idxe, degs, degs_gpu, edgefeats = self._gci.get_buffers() 155 | 156 | edge_indexes = self._gci.get_pyg_buffers() 157 | ###edgefeats = Variable(edgefeats, requires_grad=False) 158 | 159 | # evalute and reshape filter weights (shared among RNN iterations) 160 | weights = self._fnet(edgefeats) 161 | nc = hx.size(1) 162 | assert hx.dim()==2 and weights.dim()==2 and weights.size(1) in [nc, nc*nc] 163 | if weights.size(1) != nc: 164 | weights = weights.view(-1, nc, nc) 165 | 166 | # repeatedly evaluate RNN cell 167 | hxs = [hx] 168 | if self._isLSTM: 169 | cx = Variable(hx.data.new(hx.size()).fill_(0)) 170 | 171 | for r in range(self._nrepeats): 172 | if 
self.use_pyg: 173 | input = self.nn(hx, edge_indexes, weights) 174 | else: 175 | input = ecc.GraphConvFunction.apply(hx, weights, nc, nc, idxn, idxe, degs, degs_gpu, 176 | self._edge_mem_limit) 177 | if self._isLSTM: 178 | hx, cx = self._cell(input, (hx, cx)) 179 | else: 180 | hx = self._cell(input, hx) 181 | hxs.append(hx) 182 | 183 | return torch.cat(hxs,1) if self._cat_all else hx 184 | 185 | class ECC_CRFModule(nn.Module): 186 | """ 187 | Adapted "Conditional Random Fields as Recurrent Neural Networks" (https://arxiv.org/abs/1502.03240) 188 | `propagation` should be ECC with Filter generating network producing 2D matrix. 189 | """ 190 | def __init__(self, propagation, nrepeats=1): 191 | super(ECC_CRFModule, self).__init__() 192 | self._propagation = propagation 193 | self._nrepeats = nrepeats 194 | 195 | def forward(self, input): 196 | Q = nnf.softmax(input, dim=1) 197 | for i in range(self._nrepeats): 198 | Q = self._propagation(Q) # todo: speedup possible by sharing computation of fnet 199 | Q = input - Q 200 | if i < self._nrepeats-1: 201 | Q = nnf.softmax(Q, dim=1) # last softmax will be part of cross-entropy loss 202 | return Q 203 | 204 | 205 | class GRUCellEx(nn.GRUCell): 206 | """ Usual GRU cell extended with layer normalization and input gate. 
207 | """ 208 | def __init__(self, input_size, hidden_size, bias=True, layernorm=True, ingate=True): 209 | super(GRUCellEx, self).__init__(input_size, hidden_size, bias) 210 | self._layernorm = layernorm 211 | self._ingate = ingate 212 | if layernorm: 213 | self.add_module('ini', nn.InstanceNorm1d(1, eps=1e-5, affine=False, track_running_stats=False)) 214 | self.add_module('inh', nn.InstanceNorm1d(1, eps=1e-5, affine=False, track_running_stats=False)) 215 | if ingate: 216 | self.add_module('ig', nn.Linear(hidden_size, input_size, bias=True)) 217 | 218 | def _normalize(self, gi, gh): 219 | if self._layernorm: # layernorm on input&hidden, as in https://arxiv.org/abs/1607.06450 (Layer Normalization) 220 | gi = self._modules['ini'](gi.unsqueeze(1)).squeeze(1) 221 | gh = self._modules['inh'](gh.unsqueeze(1)).squeeze(1) 222 | return gi, gh 223 | 224 | def forward(self, input, hidden): 225 | if self._ingate: 226 | input = torch.sigmoid(self._modules['ig'](hidden)) * input 227 | 228 | # GRUCell in https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py extended with layer normalization 229 | if input.is_cuda and torch.__version__.split('.')[0]=='0': 230 | gi = nnf.linear(input, self.weight_ih) 231 | gh = nnf.linear(hidden, self.weight_hh) 232 | gi, gh = self._normalize(gi, gh) 233 | state = torch.nn._functions.thnn.rnnFusedPointwise.GRUFused 234 | try: #pytorch >=0.3 235 | return state.apply(gi, gh, hidden) if self.bias_ih is None else state.apply(gi, gh, hidden, self.bias_ih, self.bias_hh) 236 | except: #pytorch <=0.2 237 | return state()(gi, gh, hidden) if self.bias_ih is None else state()(gi, gh, hidden, self.bias_ih, self.bias_hh) 238 | 239 | gi = nnf.linear(input, self.weight_ih) 240 | gh = nnf.linear(hidden, self.weight_hh) 241 | gi, gh = self._normalize(gi, gh) 242 | i_r, i_i, i_n = gi.chunk(3, 1) 243 | h_r, h_i, h_n = gh.chunk(3, 1) 244 | bih_r, bih_i, bih_n = self.bias_ih.chunk(3) 245 | bhh_r, bhh_i, bhh_n = self.bias_hh.chunk(3) 246 | 247 | 
resetgate = torch.sigmoid(i_r + bih_r + h_r + bhh_r) 248 | inputgate = torch.sigmoid(i_i + bih_i + h_i + bhh_i) 249 | newgate = torch.tanh(i_n + bih_n + resetgate * (h_n + bhh_n)) 250 | hy = newgate + inputgate * (hidden - newgate) 251 | return hy 252 | 253 | def __repr__(self): 254 | s = super(GRUCellEx, self).__repr__() + '(' 255 | if self._ingate: 256 | s += 'ingate' 257 | if self._layernorm: 258 | s += ' layernorm' 259 | return s + ')' 260 | 261 | 262 | class LSTMCellEx(nn.LSTMCell): 263 | """ Usual LSTM cell extended with layer normalization and input gate. 264 | """ 265 | def __init__(self, input_size, hidden_size, bias=True, layernorm=True, ingate=True): 266 | super(LSTMCellEx, self).__init__(input_size, hidden_size, bias) 267 | self._layernorm = layernorm 268 | self._ingate = ingate 269 | if layernorm: 270 | self.add_module('ini', nn.InstanceNorm1d(1, eps=1e-5, affine=False, track_running_stats=False)) 271 | self.add_module('inh', nn.InstanceNorm1d(1, eps=1e-5, affine=False, track_running_stats=False)) 272 | if ingate: 273 | self.add_module('ig', nn.Linear(hidden_size, input_size, bias=True)) 274 | 275 | def _normalize(self, gi, gh): 276 | if self._layernorm: # layernorm on input&hidden, as in https://arxiv.org/abs/1607.06450 (Layer Normalization) 277 | gi = self._modules['ini'](gi.unsqueeze(1)).squeeze(1) 278 | gh = self._modules['inh'](gh.unsqueeze(1)).squeeze(1) 279 | return gi, gh 280 | 281 | def forward(self, input, hidden): 282 | if self._ingate: 283 | input = torch.sigmoid(self._modules['ig'](hidden[0])) * input 284 | 285 | # GRUCell in https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py extended with layer normalization 286 | if input.is_cuda and torch.__version__.split('.')[0]=='0': 287 | gi = nnf.linear(input, self.weight_ih) 288 | gh = nnf.linear(hidden[0], self.weight_hh) 289 | gi, gh = self._normalize(gi, gh) 290 | state = torch.nn._functions.thnn.rnnFusedPointwise.LSTMFused 291 | try: #pytorch >=0.3 292 | return 
state.apply(gi, gh, hidden[1]) if self.bias_ih is None else state.apply(gi, gh, hidden[1], self.bias_ih, self.bias_hh) 293 | except: #pytorch <=0.2 294 | return state()(gi, gh, hidden[1]) if self.bias_ih is None else state()(gi, gh, hidden[1], self.bias_ih, self.bias_hh) 295 | 296 | gi = nnf.linear(input, self.weight_ih, self.bias_ih) 297 | gh = nnf.linear(hidden[0], self.weight_hh, self.bias_hh) 298 | gi, gh = self._normalize(gi, gh) 299 | 300 | ingate, forgetgate, cellgate, outgate = (gi+gh).chunk(4, 1) 301 | ingate = torch.sigmoid(ingate) 302 | forgetgate = torch.sigmoid(forgetgate) 303 | cellgate = torch.tanh(cellgate) 304 | outgate = torch.sigmoid(outgate) 305 | 306 | cy = (forgetgate * hidden[1]) + (ingate * cellgate) 307 | hy = outgate * torch.tanh(cy) 308 | return hy, cy 309 | 310 | def __repr__(self): 311 | s = super(LSTMCellEx, self).__repr__() + '(' 312 | if self._ingate: 313 | s += 'ingate' 314 | if self._layernorm: 315 | s += ' layernorm' 316 | return s + ')' 317 | -------------------------------------------------------------------------------- /learning/pointnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 3 | http://arxiv.org/abs/1711.09869 4 | 2017 Loic Landrieu, Martin Simonovsky 5 | """ 6 | from __future__ import division 7 | from __future__ import print_function 8 | from builtins import range 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as nnf 14 | from torch.autograd import Variable 15 | 16 | class STNkD(nn.Module): 17 | """ 18 | Spatial Transformer Net for PointNet, producing a KxK transformation matrix. 
19 | Parameters: 20 | nfeat: number of input features 21 | nf_conv: list of layer widths of point embeddings (before maxpool) 22 | nf_fc: list of layer widths of joint embeddings (after maxpool) 23 | """ 24 | def __init__(self, nfeat, nf_conv, nf_fc, K=2, norm = 'batch', affine = True, n_group = 1): 25 | super(STNkD, self).__init__() 26 | 27 | modules = [] 28 | for i in range(len(nf_conv)): 29 | modules.append(nn.Conv1d(nf_conv[i-1] if i>0 else nfeat, nf_conv[i], 1)) 30 | if norm == 'batch': 31 | modules.append(nn.BatchNorm1d(nf_conv[i])) 32 | elif norm == 'layer': 33 | modules.append(nn.GroupNorm(1,nf_conv[i])) 34 | elif norm == 'group': 35 | modules.append(nn.GroupNorm(n_group,nf_conv[i])) 36 | modules.append(nn.ReLU(True)) 37 | self.convs = nn.Sequential(*modules) 38 | 39 | modules = [] 40 | for i in range(len(nf_fc)): 41 | modules.append(nn.Linear(nf_fc[i-1] if i>0 else nf_conv[-1], nf_fc[i])) 42 | if norm == 'batch': 43 | modules.append(nn.BatchNorm1d(nf_fc[i])) 44 | elif norm == 'layer': 45 | modules.append(nn.GroupNorm(1,nf_fc[i])) 46 | elif norm == 'group': 47 | modules.append(nn.GroupNorm(n_group,nf_fc[i])) 48 | modules.append(nn.ReLU(True)) 49 | self.fcs = nn.Sequential(*modules) 50 | 51 | self.proj = nn.Linear(nf_fc[-1], K*K) 52 | nn.init.constant_(self.proj.weight, 0); nn.init.constant_(self.proj.bias, 0) 53 | self.eye = torch.eye(K).unsqueeze(0) 54 | 55 | def forward(self, input): 56 | self.eye = self.eye.cuda() if input.is_cuda else self.eye 57 | input = self.convs(input) 58 | input = nnf.max_pool1d(input, input.size(2)).squeeze(2) 59 | input = self.fcs(input) 60 | input = self.proj(input) 61 | return input.view(-1,self.eye.size(1),self.eye.size(2)) + Variable(self.eye) 62 | 63 | class PointNet(nn.Module): 64 | """ 65 | PointNet with only one spatial transformer and additional "global" input concatenated after maxpool. 
66 | Parameters: 67 | nf_conv: list of layer widths of point embeddings (before maxpool) 68 | nf_fc: list of layer widths of joint embeddings (after maxpool) 69 | nfeat: number of input features 70 | nf_conv_stn, nf_fc_stn, nfeat_stn: as above but for Spatial transformer 71 | nfeat_global: number of features concatenated after maxpooling 72 | prelast_do: dropout after the pre-last parameteric layer 73 | last_ac: whether to use batch norm and relu after the last parameteric layer 74 | """ 75 | def __init__(self, nf_conv, nf_fc, nf_conv_stn, nf_fc_stn, nfeat, nfeat_stn=2, nfeat_global=1, prelast_do=0.5, last_ac=False, is_res=False, norm = 'batch', affine = True, n_group = 1, last_bn = False): 76 | 77 | super(PointNet, self).__init__() 78 | torch.manual_seed(0) 79 | if nfeat_stn > 0: 80 | self.stn = STNkD(nfeat_stn, nf_conv_stn, nf_fc_stn, norm=norm, n_group = n_group) 81 | self.nfeat_stn = nfeat_stn 82 | 83 | modules = [] 84 | for i in range(len(nf_conv)): 85 | modules.append(nn.Conv1d(nf_conv[i-1] if i>0 else nfeat, nf_conv[i], 1)) 86 | if norm == 'batch': 87 | modules.append(nn.BatchNorm1d(nf_conv[i])) 88 | elif norm == 'layer': 89 | modules.append(nn.GroupNorm(1, nf_conv[i])) 90 | elif norm == 'group': 91 | modules.append(nn.GroupNorm(n_group, nf_conv[i])) 92 | modules.append(nn.ReLU(True)) 93 | 94 | # Initialization of BN parameters. 
95 | 96 | self.convs = nn.Sequential(*modules) 97 | 98 | modules = [] 99 | for i in range(len(nf_fc)): 100 | modules.append(nn.Linear(nf_fc[i-1] if i>0 else nf_conv[-1]+nfeat_global, nf_fc[i])) 101 | if i<len(nf_fc)-1 or last_ac: 102 | if norm == 'batch': 103 | modules.append(nn.BatchNorm1d(nf_fc[i])) 104 | elif norm == 'layer': 105 | modules.append(nn.GroupNorm(1,nf_fc[i])) 106 | elif norm == 'group': 107 | modules.append(nn.GroupNorm(n_group,nf_fc[i])) 108 | modules.append(nn.ReLU(True)) 109 | if i==len(nf_fc)-2 and prelast_do>0: 110 | modules.append(nn.Dropout(prelast_do)) 111 | if is_res: #init with small number so that at first the residual pointnet is close to zero 112 | nn.init.normal_(modules[-1].weight, mean=0, std = 1e-2) 113 | nn.init.normal_(modules[-1].bias, mean=0, std = 1e-2) 114 | 115 | #if last_bn: 116 | #modules.append(nn.BatchNorm1d(nf_fc[-1])) 117 | 118 | self.fcs = nn.Sequential(*modules) 119 | 120 | def forward(self, input, input_global): 121 | if self.nfeat_stn > 0: 122 | T = self.stn(input[:,:self.nfeat_stn,:]) 123 | xy_transf = torch.bmm(input[:,:2,:].transpose(1,2), T).transpose(1,2) 124 | input = torch.cat([xy_transf, input[:,2:,:]], 1) 125 | 126 | input = self.convs(input) 127 | input = nnf.max_pool1d(input, input.size(2)).squeeze(2) 128 | if input_global is not None: 129 | if len(input_global.shape)== 1 or input_global.shape[1]==1: 130 | input = torch.cat([input, input_global.view(-1,1)], 1) 131 | else: 132 | input = torch.cat([input, input_global], 1) 133 | return self.fcs(input) 134 | 135 | 136 | 137 | 138 | class CloudEmbedder(): 139 | """ Evaluates PointNet on superpoints. Too small superpoints are assigned zero embeddings. Can optionally apply memory mongering 140 | (https://arxiv.org/pdf/1604.06174.pdf) to decrease memory usage. 
141 | """ 142 | def __init__(self, args): 143 | self.args = args 144 | self.bw_hook = lambda: None # could be more elegant in the upcoming pytorch release: http://bit.ly/2A8PI7p 145 | self.run = self.run_full_monger if args.ptn_mem_monger else self.run_full 146 | 147 | def run_full(self, model, clouds_meta, clouds_flag, clouds, clouds_global): 148 | """ Simply evaluates all clouds in a differentiable way, assumes that all pointnet's feature maps fit into mem.""" 149 | idx_valid = torch.nonzero(clouds_flag.eq(0)).squeeze() 150 | if self.args.cuda: 151 | clouds, clouds_global, idx_valid = clouds.cuda(), clouds_global.cuda(), idx_valid.cuda() 152 | clouds, clouds_global = Variable(clouds, volatile=not model.training), Variable(clouds_global, volatile=not model.training) 153 | #print('Ptn with', clouds.size(0), 'clouds') 154 | 155 | out = model.ptn(clouds, clouds_global) 156 | descriptors = Variable(out.data.new(clouds_flag.size(0), out.size(1)).fill_(0)) 157 | descriptors.index_copy_(0, Variable(idx_valid), out) 158 | return descriptors 159 | 160 | def run_full_monger(self, model, clouds_meta, clouds_flag, clouds, clouds_global): 161 | """ Evaluates all clouds in forward pass, but uses memory mongering to compute backward pass.""" 162 | idx_valid = torch.nonzero(clouds_flag.eq(0)).squeeze() 163 | if self.args.cuda: 164 | clouds, clouds_global, idx_valid = clouds.cuda(), clouds_global.cuda(), idx_valid.cuda() 165 | #print('Ptn with', clouds.size(0), 'clouds') 166 | with torch.no_grad(): 167 | out = model.ptn(Variable(clouds), (clouds_global)) 168 | if not model.training: 169 | out = Variable(out.data, requires_grad=model.training) # cut autograd 170 | if model.training: 171 | out = Variable(out.data, requires_grad=model.training) 172 | def bw_hook(): 173 | out_v2 = model.ptn(Variable(clouds), Variable(clouds_global)) # re-run fw pass 174 | out_v2.backward(out.grad) 175 | 176 | self.bw_hook = bw_hook 177 | 178 | descriptors = Variable(out.data.new(clouds_flag.size(0), 
out.size(1)).fill_(0)) 179 | descriptors.index_copy_(0, Variable(idx_valid), out) 180 | return descriptors 181 | 182 | class LocalCloudEmbedder(): 183 | """ Local PointNet 184 | """ 185 | def __init__(self, args): 186 | self.nfeat_stn = args.ptn_nfeat_stn 187 | self.stn_as_global = args.stn_as_global 188 | 189 | def run_batch(self, model, clouds, clouds_global, *excess): 190 | """ Evaluates all clouds in a differentiable way, use a batch approach. 191 | Use when embedding many small point clouds with small PointNets at once""" 192 | #cudnn cannot handle arrays larger than 2**16 in one go, uses batch 193 | batch_size = 2**16-1 194 | n_batches = int((clouds.shape[0]-1)/batch_size) 195 | if self.nfeat_stn > 0: 196 | T = model.stn(clouds[:batch_size,:self.nfeat_stn,:]) 197 | for i in range(1,n_batches+1): 198 | T = torch.cat((T,model.stn(clouds[i * batch_size:(i+1) * batch_size,:self.nfeat_stn,:]))) 199 | xy_transf = torch.bmm(clouds[:,:2,:].transpose(1,2), T).transpose(1,2) 200 | clouds = torch.cat([xy_transf, clouds[:,2:,:]], 1) 201 | if self.stn_as_global: 202 | clouds_global = torch.cat([clouds_global, T.view(-1,4)], 1) 203 | 204 | out = model.ptn(clouds[:batch_size,:,:], clouds_global[:batch_size,:]) 205 | for i in range(1,n_batches+1): 206 | out = torch.cat((out,model.ptn(clouds[i * batch_size:(i+1) * batch_size,:,:], clouds_global[i * batch_size:(i+1) * batch_size,:]))) 207 | return nnf.normalize(out) 208 | 209 | def run_batch_cpu(self, model, clouds, clouds_global, *excess): 210 | """ Evaluates the cloud on CPU, but put the values in the CPU as soon as they are computed""" 211 | #cudnn cannot handle arrays larger than 2**16 in one go, uses batch 212 | batch_size = 2**10-1 213 | n_batches = int(clouds.shape[0]/batch_size) 214 | emb_total = self.run_batch(model, clouds[:batch_size,:,:], clouds_global[:batch_size,:]).cpu() 215 | for i in range(1,n_batches+1): 216 | emb = self.run_batch(model, clouds[i * batch_size:(i+1) * batch_size,:,:], clouds_global[i * 
batch_size:(i+1) * batch_size,:]) 217 | emb_total = torch.cat((emb_total,emb.cpu())) 218 | return emb_total 219 | 220 | -------------------------------------------------------------------------------- /learning/s3dis_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 3 | http://arxiv.org/abs/1711.09869 4 | 2017 Loic Landrieu, Martin Simonovsky 5 | """ 6 | from __future__ import division 7 | from __future__ import print_function 8 | from builtins import range 9 | 10 | import sys 11 | sys.path.append("./learning") 12 | 13 | import random 14 | import numpy as np 15 | import os 16 | import functools 17 | import torch 18 | import torchnet as tnt 19 | import h5py 20 | import spg 21 | from sklearn.linear_model import RANSACRegressor 22 | 23 | def get_datasets(args, test_seed_offset=0): 24 | """ Gets training and test datasets. """ 25 | 26 | # Load superpoints graphs 27 | testlist, trainlist, validlist = [], [], [] 28 | valid_names = ['hallway_1.h5', 'hallway_6.h5', 'hallway_11.h5', 'office_1.h5' \ 29 | , 'office_6.h5', 'office_11.h5', 'office_16.h5', 'office_21.h5', 'office_26.h5' \ 30 | , 'office_31.h5', 'office_36.h5'\ 31 | ,'WC_2.h5', 'storage_1.h5', 'storage_5.h5', 'conferenceRoom_2.h5', 'auditorium_1.h5'] 32 | 33 | #if args.db_test_name == 'test' then the test set is the evaluation set 34 | #otherwise it serves as valdiation set to select the best epoch 35 | 36 | for n in range(1,7): 37 | if n != args.cvfold: 38 | path = '{}/superpoint_graphs/Area_{:d}/'.format(args.S3DIS_PATH, n) 39 | for fname in sorted(os.listdir(path)): 40 | if fname.endswith(".h5") and not (args.use_val_set and fname in valid_names): 41 | #training set 42 | trainlist.append(spg.spg_reader(args, path + fname, True)) 43 | if fname.endswith(".h5") and (args.use_val_set and fname in valid_names): 44 | #validation set 45 | validlist.append(spg.spg_reader(args, path + fname, 
True)) 46 | path = '{}/superpoint_graphs/Area_{:d}/'.format(args.S3DIS_PATH, args.cvfold) 47 | 48 | #evaluation set 49 | for fname in sorted(os.listdir(path)): 50 | if fname.endswith(".h5"): 51 | testlist.append(spg.spg_reader(args, path + fname, True)) 52 | 53 | # Normalize edge features 54 | if args.spg_attribs01: 55 | trainlist, testlist, validlist, scaler = spg.scaler01(trainlist, testlist, validlist=validlist) 56 | 57 | return tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in trainlist], 58 | functools.partial(spg.loader, train=True, args=args, db_path=args.S3DIS_PATH)), \ 59 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in testlist], 60 | functools.partial(spg.loader, train=False, args=args, db_path=args.S3DIS_PATH, test_seed_offset=test_seed_offset)), \ 61 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in validlist], 62 | functools.partial(spg.loader, train=False, args=args, db_path=args.S3DIS_PATH, test_seed_offset=test_seed_offset)), \ 63 | scaler 64 | 65 | 66 | def get_info(args): 67 | edge_feats = 0 68 | for attrib in args.edge_attribs.split(','): 69 | a = attrib.split('/')[0] 70 | if a in ['delta_avg', 'delta_std', 'xyz']: 71 | edge_feats += 3 72 | else: 73 | edge_feats += 1 74 | if args.loss_weights == 'none': 75 | weights = np.ones((13,),dtype='f4') 76 | else: 77 | weights = h5py.File(args.S3DIS_PATH + "/parsed/class_count.h5")["class_count"][:].astype('f4') 78 | weights = weights[:,[i for i in range(6) if i != args.cvfold-1]].sum(1) 79 | weights = weights.mean()/weights 80 | if args.loss_weights == 'sqrt': 81 | weights = np.sqrt(weights) 82 | weights = torch.from_numpy(weights).cuda() if args.cuda else torch.from_numpy(weights) 83 | return { 84 | 'node_feats': 14 if args.pc_attribs=='' else len(args.pc_attribs), 85 | 'edge_feats': edge_feats, 86 | 'class_weights': weights, 87 | 'classes': 13, 88 | 'inv_class_map': {0:'ceiling', 1:'floor', 2:'wall', 3:'column', 4:'beam', 5:'window', 6:'door', 7:'table', 
8:'chair', 9:'bookcase', 10:'sofa', 11:'board', 12:'clutter'}, 89 | } 90 | 91 | 92 | 93 | def preprocess_pointclouds(args): 94 | """ Preprocesses data by splitting them by components and normalizing.""" 95 | S3DIS_PATH = args.S3DIS_PATH 96 | class_count = np.zeros((13,6),dtype='int') 97 | for n in range(1,7): 98 | pathP = '{}/parsed/Area_{:d}/'.format(S3DIS_PATH, n) 99 | if args.supervized_partition: 100 | pathD = '{}/features_supervision/Area_{:d}/'.format(S3DIS_PATH, n) 101 | else: 102 | pathD = '{}/features/Area_{:d}/'.format(S3DIS_PATH, n) 103 | pathC = '{}/superpoint_graphs/Area_{:d}/'.format(S3DIS_PATH, n) 104 | if not os.path.exists(pathP): 105 | os.makedirs(pathP) 106 | random.seed(n) 107 | 108 | for file in os.listdir(pathC): 109 | print(file) 110 | if file.endswith(".h5"): 111 | f = h5py.File(pathD + file, 'r') 112 | xyz = f['xyz'][:] 113 | rgb = f['rgb'][:].astype(np.float64) # np.float alias was removed in NumPy 1.24 114 | 115 | labels = f['labels'][:] 116 | hard_labels = np.argmax(labels[:,1:],1) 117 | label_count = np.bincount(hard_labels, minlength=13) 118 | class_count[:,n-1] = class_count[:,n-1] + label_count 119 | 120 | if not args.supervized_partition: 121 | lpsv = f['geof'][:] 122 | lpsv -= 0.5 #normalize 123 | else: 124 | lpsv = np.stack([f["geof"][:] ]).squeeze() 125 | # rescale to [-0.5,0.5]; keep xyz 126 | 127 | if args.plane_model_elevation: 128 | if args.supervized_partition: #already computed 129 | e = f['elevation'][:] 130 | else: #simple plane model 131 | low_points = ((xyz[:,2]-xyz[:,2].min() < 0.5)).nonzero()[0] 132 | reg = RANSACRegressor(random_state=0).fit(xyz[low_points,:2], xyz[low_points,2]) 133 | e = xyz[:,2]-reg.predict(xyz[:,:2]) 134 | else: #compute elevation from zmin 135 | e = xyz[:,2] / 4 - 0.5 # (4m rough guess) 136 | 137 | rgb = rgb/255.0 - 0.5 138 | 139 | room_center = xyz[:,[0,1]].mean(0) #compute distance to room center, useful to detect walls and doors 140 | distance_to_center = np.sqrt(((xyz[:,[0,1]]-room_center)**2).sum(1)) 141 | distance_to_center = 
(distance_to_center - distance_to_center.mean())/distance_to_center.std() 142 | 143 | ma, mi = np.max(xyz,axis=0,keepdims=True), np.min(xyz,axis=0,keepdims=True) 144 | xyzn = (xyz - mi) / (ma - mi + 1e-8) # as in PointNet ("normalized location as to the room (from 0 to 1)") 145 | 146 | P = np.concatenate([xyz, rgb, e[:,np.newaxis], lpsv, xyzn, distance_to_center[:,None]], axis=1) 147 | 148 | f = h5py.File(pathC + file, 'r') 149 | numc = len(f['components'].keys()) 150 | 151 | with h5py.File(pathP + file, 'w') as hf: 152 | hf.create_dataset(name='centroid',data=xyz.mean(0)) 153 | for c in range(numc): 154 | idx = f['components/{:d}'.format(c)][:].flatten() 155 | if idx.size > 10000: # trim extra large segments, just for speed-up of loading time 156 | ii = random.sample(range(idx.size), k=10000) 157 | idx = idx[ii] 158 | hf.create_dataset(name='{:d}'.format(c), data=P[idx,...]) 159 | 160 | path = '{}/parsed/'.format(S3DIS_PATH) 161 | data_file = h5py.File(path+'class_count.h5', 'w') 162 | data_file.create_dataset('class_count', data=class_count, dtype='int') 163 | 164 | if __name__ == "__main__": 165 | import argparse 166 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs') 167 | parser.add_argument('--S3DIS_PATH', default='datasets/s3dis') 168 | parser.add_argument('--supervized_partition', type=int, default=0) 169 | parser.add_argument('--plane_model_elevation', type=int, default=0, help='compute elevation with a simple RANSAC based plane model') 170 | args = parser.parse_args() 171 | preprocess_pointclouds(args) 172 | -------------------------------------------------------------------------------- /learning/sema3d_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 3 | http://arxiv.org/abs/1711.09869 4 | 2017 Loic Landrieu, Martin Simonovsky 5 | """ 6 | from __future__ import 
division 7 | from __future__ import print_function 8 | from builtins import range 9 | 10 | import random 11 | import numpy as np 12 | import os 13 | import functools 14 | import torch 15 | import torchnet as tnt 16 | import h5py 17 | import spg 18 | 19 | 20 | def get_datasets(args, test_seed_offset=0): 21 | 22 | train_names = ['bildstein_station1', 'bildstein_station5', 'domfountain_station1', 'domfountain_station3', 'neugasse_station1', 'sg27_station1', 'sg27_station2', 'sg27_station5', 'sg27_station9', 'sg28_station4', 'untermaederbrunnen_station1'] 23 | valid_names = ['bildstein_station3', 'domfountain_station2', 'sg27_station4', 'untermaederbrunnen_station3'] 24 | 25 | if args.db_train_name == 'train': 26 | trainset = ['train/' + f for f in train_names] 27 | elif args.db_train_name == 'trainval': 28 | trainset = ['train/' + f for f in train_names + valid_names] 29 | 30 | validset = [] 31 | testset = [] 32 | if args.use_val_set: 33 | validset = ['train/' + f for f in valid_names] 34 | if args.db_test_name == 'testred': 35 | testset = ['test_reduced/' + os.path.splitext(f)[0] for f in os.listdir(args.SEMA3D_PATH + '/superpoint_graphs/test_reduced')] 36 | elif args.db_test_name == 'testfull': 37 | testset = ['test_full/' + os.path.splitext(f)[0] for f in os.listdir(args.SEMA3D_PATH + '/superpoint_graphs/test_full')] 38 | 39 | # Load superpoints graphs 40 | testlist, trainlist, validlist = [], [], [] 41 | for n in trainset: 42 | trainlist.append(spg.spg_reader(args, args.SEMA3D_PATH + '/superpoint_graphs/' + n + '.h5', True)) 43 | for n in validset: 44 | validlist.append(spg.spg_reader(args, args.SEMA3D_PATH + '/superpoint_graphs/' + n + '.h5', True)) 45 | for n in testset: 46 | testlist.append(spg.spg_reader(args, args.SEMA3D_PATH + '/superpoint_graphs/' + n + '.h5', True)) 47 | 48 | # Normalize edge features 49 | if args.spg_attribs01: 50 | trainlist, testlist, validlist, scaler = spg.scaler01(trainlist, testlist, validlist=validlist) 51 | 52 | return 
tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in trainlist], 53 | functools.partial(spg.loader, train=True, args=args, db_path=args.SEMA3D_PATH)), \ 54 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in testlist], 55 | functools.partial(spg.loader, train=False, args=args, db_path=args.SEMA3D_PATH, test_seed_offset=test_seed_offset)), \ 56 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in validlist], 57 | functools.partial(spg.loader, train=False, args=args, db_path=args.SEMA3D_PATH, test_seed_offset=test_seed_offset)),\ 58 | scaler 59 | 60 | 61 | def get_info(args): 62 | edge_feats = 0 63 | for attrib in args.edge_attribs.split(','): 64 | a = attrib.split('/')[0] 65 | if a in ['delta_avg', 'delta_std', 'xyz']: 66 | edge_feats += 3 67 | else: 68 | edge_feats += 1 69 | if args.loss_weights == 'none': 70 | weights = np.ones((8,),dtype='f4') 71 | else: 72 | weights = h5py.File(args.SEMA3D_PATH + "/parsed/class_count.h5")["class_count"][:].astype('f4') 73 | weights = weights.mean()/weights 74 | if args.loss_weights == 'sqrt': 75 | weights = np.sqrt(weights) 76 | weights = torch.from_numpy(weights).cuda() if args.cuda else torch.from_numpy(weights) 77 | return { 78 | 'node_feats': 14 if args.pc_attribs=='' else len(args.pc_attribs), 79 | 'edge_feats': edge_feats, 80 | 'class_weights': weights, 81 | 'classes': 8, 82 | 'inv_class_map': {0:'terrain_man', 1:'terrain_nature', 2:'veget_hi', 3:'veget_low', 4:'building', 5:'scape', 6:'artefact', 7:'cars'}, 83 | } 84 | 85 | def preprocess_pointclouds(SEMA3D_PATH): 86 | """ Preprocesses data by splitting them by components and normalizing.""" 87 | class_count = np.zeros((8,),dtype='int') 88 | for n in ['train', 'test_reduced', 'test_full']: 89 | pathP = '{}/parsed/{}/'.format(SEMA3D_PATH, n) 90 | if args.supervised_partition : 91 | pathD = '{}/features_supervision/{}/'.format(SEMA3D_PATH, n) 92 | else: 93 | pathD = '{}/features/{}/'.format(SEMA3D_PATH, n) 94 | pathC = 
'{}/superpoint_graphs/{}/'.format(SEMA3D_PATH, n) 95 | if not os.path.exists(pathP): 96 | os.makedirs(pathP) 97 | random.seed(0) 98 | 99 | for file in os.listdir(pathC): 100 | print(file) 101 | if file.endswith(".h5"): 102 | f = h5py.File(pathD + file, 'r') 103 | 104 | if n == 'train': 105 | labels = f['labels'][:] 106 | hard_labels = np.argmax(labels[:,1:],1) 107 | label_count = np.bincount(hard_labels, minlength=8) 108 | class_count = class_count + label_count 109 | 110 | xyz = f['xyz'][:] 111 | rgb = f['rgb'][:].astype(np.float64) # np.float alias was removed in NumPy 1.24 112 | elpsv = np.concatenate((f['xyz'][:,2][:,None], f['geof'][:]), axis=1) 113 | 114 | # rescale to [-0.5,0.5]; keep xyz 115 | elpsv[:,0] /= 100 # (rough guess) 116 | elpsv[:,1:] -= 0.5 117 | rgb = rgb/255.0 - 0.5 118 | 119 | P = np.concatenate([xyz, rgb, elpsv], axis=1) 120 | 121 | f = h5py.File(pathC + file, 'r') 122 | numc = len(f['components'].keys()) 123 | 124 | with h5py.File(pathP + file, 'w') as hf: 125 | hf.create_dataset(name='centroid',data=xyz.mean(0)) 126 | for c in range(numc): 127 | idx = f['components/{:d}'.format(c)][:].flatten() 128 | if idx.size > 10000: # trim extra large segments, just for speed-up of loading time 129 | ii = random.sample(range(idx.size), k=10000) 130 | idx = idx[ii] 131 | 132 | hf.create_dataset(name='{:d}'.format(c), data=P[idx,...]) 133 | path = '{}/parsed/'.format(SEMA3D_PATH) 134 | data_file = h5py.File(path+'class_count.h5', 'w') 135 | data_file.create_dataset('class_count', data=class_count, dtype='int') 136 | 137 | if __name__ == "__main__": 138 | import argparse 139 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs') 140 | parser.add_argument('--SEMA3D_PATH', default='datasets/semantic3d') 141 | parser.add_argument('--supervised_partition', default=0, type=int, help = 'whether to use supervized partition features') 142 | args = parser.parse_args() 143 | preprocess_pointclouds(args.SEMA3D_PATH) 144 | 
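The class weighting done in `get_info` above (inverse-frequency weights from `class_count`, optionally square-rooted) can be sketched in isolation. This is a minimal sketch under stated assumptions, not code from the repository: the function name `class_weights_from_counts` and its `mode` and `smooth` parameters are hypothetical names introduced for illustration. The `smooth` argument mirrors the `+1` that the vKITTI variant adds to the counts, which avoids a division by zero for classes that never occur.

```python
import numpy as np


def class_weights_from_counts(class_count, mode="proportional", smooth=0.0):
    """Turn per-class point counts into loss weights (illustrative helper).

    mode="none"         -> all weights equal to 1
    mode="proportional" -> mean(count) / count (rare classes weigh more)
    mode="sqrt"         -> square root of the proportional weights
    """
    # Optional additive smoothing, as in the vKITTI variant (smooth=1 there).
    counts = np.asarray(class_count, dtype=np.float32) + smooth
    if mode == "none":
        return np.ones_like(counts)
    weights = counts.mean() / counts
    if mode == "sqrt":
        weights = np.sqrt(weights)
    return weights


if __name__ == "__main__":
    toy_counts = [1, 3]  # made-up counts, for demonstration only
    print(class_weights_from_counts(toy_counts))
    print(class_weights_from_counts(toy_counts, mode="sqrt"))
```

Without smoothing, a class with a zero count yields an infinite weight, so a small `smooth` value may be preferable for datasets where some classes are absent from the training folds.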
-------------------------------------------------------------------------------- /learning/spg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 3 | http://arxiv.org/abs/1711.09869 4 | 2017 Loic Landrieu, Martin Simonovsky 5 | """ 6 | from __future__ import division 7 | from __future__ import print_function 8 | from builtins import range 9 | 10 | import random 11 | import numpy as np 12 | import os 13 | import math 14 | import transforms3d 15 | import torch 16 | import ecc 17 | import h5py 18 | from sklearn import preprocessing 19 | import igraph 20 | 21 | 22 | 23 | def spg_edge_features(edges, node_att, edge_att, args): 24 | """ Assembles edge features from edge attributes and differences of node attributes. """ 25 | columns = [] 26 | for attrib in args.edge_attribs.split(','): 27 | attrib = attrib.split('/') 28 | a, opt = attrib[0], attrib[1].lower() if len(attrib)==2 else '' 29 | 30 | if a in ['delta_avg', 'delta_std']: 31 | columns.append(edge_att[a]) 32 | elif a=='constant': # for isotropic baseline 33 | columns.append(np.ones((edges.shape[0],1), dtype=np.float32)) 34 | elif a in ['nlength','surface','volume', 'size', 'xyz']: 35 | attr = node_att[a] 36 | if opt=='d': # difference 37 | attr = attr[edges[:,0],:] - attr[edges[:,1],:] 38 | elif opt=='ld': # log ratio 39 | attr = np.log(attr + 1e-10) 40 | attr = attr[edges[:,0],:] - attr[edges[:,1],:] 41 | elif opt=='r': # ratio 42 | attr = attr[edges[:,0],:] / (attr[edges[:,1],:] + 1e-10) 43 | else: 44 | raise NotImplementedError 45 | columns.append(attr) 46 | else: 47 | raise NotImplementedError 48 | 49 | return np.concatenate(columns, axis=1).astype(np.float32) 50 | 51 | def scaler01(trainlist, testlist, transform_train=True, validlist = []): 52 | """ Scale edge features to 0 mean 1 stddev """ 53 | edge_feats = np.concatenate([ trainlist[i][3] for i in range(len(trainlist)) ], 0) 54 | scaler = 
preprocessing.StandardScaler().fit(edge_feats) 55 | 56 | if transform_train: 57 | for i in range(len(trainlist)): 58 | scaler.transform(trainlist[i][3], copy=False) 59 | for i in range(len(testlist)): 60 | scaler.transform(testlist[i][3], copy=False) 61 | if len(validlist)>0: 62 | for i in range(len(validlist)): 63 | scaler.transform(validlist[i][3], copy=False) 64 | return trainlist, testlist, validlist, scaler 65 | 66 | def spg_reader(args, fname, incl_dir_in_name=False): 67 | """ Loads a supergraph from H5 file. """ 68 | f = h5py.File(fname,'r') 69 | 70 | if f['sp_labels'].size > 0: 71 | node_gt_size = f['sp_labels'][:].astype(np.int64) # column 0: no of unlabeled points, column 1+: no of labeled points per class 72 | node_gt = np.argmax(node_gt_size[:,1:], 1)[:,None] 73 | node_gt[node_gt_size[:,1:].sum(1)==0,:] = -100 # superpoints without labels are to be ignored in loss computation 74 | else: 75 | N = f['sp_point_count'].shape[0] 76 | node_gt_size = np.concatenate([f['sp_point_count'][:].astype(np.int64), np.zeros((N,8), dtype=np.int64)], 1) 77 | node_gt = np.zeros((N,1), dtype=np.int64) 78 | 79 | node_att = {} 80 | node_att['xyz'] = f['sp_centroids'][:] 81 | node_att['nlength'] = np.maximum(0, f['sp_length'][:]) 82 | node_att['volume'] = np.maximum(0, f['sp_volume'][:] ** 2) 83 | node_att['surface'] = np.maximum(0, f['sp_surface'][:] ** 2) 84 | node_att['size'] = f['sp_point_count'][:] 85 | 86 | edges = np.concatenate([ f['source'][:], f['target'][:] ], axis=1).astype(np.int64) 87 | 88 | edge_att = {} 89 | edge_att['delta_avg'] = f['se_delta_mean'][:] 90 | edge_att['delta_std'] = f['se_delta_std'][:] 91 | 92 | if args.spg_superedge_cutoff > 0: 93 | filtered = np.linalg.norm(edge_att['delta_avg'],axis=1) < args.spg_superedge_cutoff 94 | edges = edges[filtered,:] 95 | edge_att['delta_avg'] = edge_att['delta_avg'][filtered,:] 96 | edge_att['delta_std'] = edge_att['delta_std'][filtered,:] 97 | 98 | edge_feats = spg_edge_features(edges, node_att, edge_att, args) 
99 | 100 | name = os.path.basename(fname)[:-len('.h5')] 101 | if incl_dir_in_name: name = os.path.basename(os.path.dirname(fname)) + '/' + name 102 | 103 | return node_gt, node_gt_size, edges, edge_feats, name 104 | 105 | 106 | def spg_to_igraph(node_gt, node_gt_size, edges, edge_feats, fname): 107 | """ Builds representation of superpoint graph as igraph. """ 108 | targets = np.concatenate([node_gt, node_gt_size], axis=1) 109 | G = igraph.Graph(n=node_gt.shape[0], edges=edges.tolist(), directed=True, 110 | edge_attrs={'f':edge_feats}, 111 | vertex_attrs={'v':list(range(node_gt.shape[0])), 't':targets, 's':node_gt_size.sum(1)}) 112 | return G, fname 113 | 114 | def random_neighborhoods(G, num, order): 115 | """ Samples `num` random neighborhoods of size `order`. 116 | Graph nodes are then treated as a set, i.e. after hardcutoff, neighborhoods may be broken (sort of data augmentation). """ 117 | centers = random.sample(range(G.vcount()), k=num) 118 | neighb = G.neighborhood(centers, order) 119 | subset = [item for sublist in neighb for item in sublist] 120 | subset = sorted(set(subset)) 121 | return G.subgraph(subset) 122 | 123 | def k_big_enough(G, minpts, k): 124 | """ Returns an induced graph on at most k superpoints of size >= minpts (smaller ones are not counted) """ 125 | valid = np.array(G.vs['s']) >= minpts 126 | n = np.argwhere(np.cumsum(valid)<=k)[-1][0]+1 127 | return G.subgraph(range(n)) 128 | 129 | 130 | def loader(entry, train, args, db_path, test_seed_offset=0): 131 | """ Prepares a superpoint graph (potentially subsampled in training) and associated superpoints. 
""" 132 | G, fname = entry 133 | # 1) subset (neighborhood) selection of (permuted) superpoint graph 134 | if train: 135 | if 0 < args.spg_augm_hardcutoff < G.vcount(): 136 | perm = list(range(G.vcount())); random.shuffle(perm) 137 | G = G.permute_vertices(perm) 138 | 139 | if 0 < args.spg_augm_nneigh < G.vcount(): 140 | G = random_neighborhoods(G, args.spg_augm_nneigh, args.spg_augm_order) 141 | 142 | if 0 < args.spg_augm_hardcutoff < G.vcount(): 143 | G = k_big_enough(G, args.ptn_minpts, args.spg_augm_hardcutoff) 144 | 145 | # Only stores graph with edges 146 | if len(G.get_edgelist()) != 0: 147 | # 2) loading clouds for chosen superpoint graph nodes 148 | clouds_meta, clouds_flag = [], [] # meta: textual id of the superpoint; flag: 0/-1 if no cloud because too small 149 | clouds, clouds_global = [], [] # clouds: point cloud arrays; clouds_global: diameters before scaling 150 | 151 | for s in range(G.vcount()): 152 | cloud, diam = load_superpoint(args, db_path + '/parsed/' + fname + '.h5', G.vs[s]['v'], train, test_seed_offset) 153 | if cloud is not None: 154 | clouds_meta.append('{}.{:d}'.format(fname,G.vs[s]['v'])); clouds_flag.append(0) 155 | clouds.append(cloud.T) 156 | clouds_global.append(diam) 157 | else: 158 | clouds_meta.append('{}.{:d}'.format(fname,G.vs[s]['v'])); clouds_flag.append(-1) 159 | 160 | clouds_flag = np.array(clouds_flag) 161 | if len(clouds) != 0: 162 | clouds = np.stack(clouds) 163 | if len(clouds_global) != 0: 164 | clouds_global = np.concatenate(clouds_global) 165 | 166 | return np.array(G.vs['t']), G, clouds_meta, clouds_flag, clouds, clouds_global 167 | 168 | # Don't use the graph if it doesn't have edges. 
169 | else: 170 | target, G, clouds_meta, clouds_flag, clouds, clouds_global = None, None, None, None, None, None 171 | return target, G, clouds_meta, clouds_flag, clouds, clouds_global 172 | 173 | 174 | def cloud_edge_feats(edgeattrs): 175 | edgefeats = np.asarray(edgeattrs['f']) 176 | return torch.from_numpy(edgefeats), None 177 | 178 | def eccpc_collate(batch): 179 | """ Collates a list of dataset samples into a single batch (adapted in ecc.graph_info_collate_classification()) 180 | """ 181 | targets, graphs, clouds_meta, clouds_flag, clouds, clouds_global = list(zip(*batch)) 182 | 183 | targets = torch.cat([torch.from_numpy(t) for t in targets if t is not None], 0).long() 184 | graphs = [graph for graph in graphs if graph is not None] 185 | GIs = [ecc.GraphConvInfo(graphs, cloud_edge_feats)] 186 | 187 | if len(clouds_meta[0]) > 0: 188 | clouds = torch.cat([torch.from_numpy(f) for f in clouds if f is not None], 0) 189 | clouds_global = torch.cat([torch.from_numpy(f) for f in clouds_global if f is not None], 0) 190 | clouds_flag = torch.cat([torch.from_numpy(f) for f in clouds_flag if f is not None], 0) 191 | clouds_meta = [item for sublist in clouds_meta if sublist is not None for item in sublist] 192 | 193 | return targets, GIs, (clouds_meta, clouds_flag, clouds, clouds_global) 194 | 195 | 196 | ############### POINT CLOUD PROCESSING ########## 197 | 198 | def load_superpoint(args, fname, id, train, test_seed_offset): 199 | """ """ 200 | hf = h5py.File(fname,'r') 201 | P = hf['{:d}'.format(id)] 202 | N = P.shape[0] 203 | if N < args.ptn_minpts: # skip if too few pts (this must be consistent at train and test time) 204 | return None, N 205 | P = P[:].astype(np.float32) 206 | 207 | rs = np.random.random.__self__ if train else np.random.RandomState(seed=id+test_seed_offset) # fix seed for test 208 | 209 | if N > args.ptn_npts: # need to subsample 210 | ii = rs.choice(N, args.ptn_npts) 211 | P = P[ii, ...] 
212 | elif N < args.ptn_npts: # need to pad by duplication 213 | ii = rs.choice(N, args.ptn_npts - N) 214 | P = np.concatenate([P, P[ii,...]], 0) 215 | 216 | if args.pc_xyznormalize: 217 | # normalize xyz into unit ball, i.e. in [-0.5,0.5] 218 | diameter = np.max(np.max(P[:,:3],axis=0) - np.min(P[:,:3],axis=0)) 219 | P[:,:3] = (P[:,:3] - np.mean(P[:,:3], axis=0, keepdims=True)) / (diameter + 1e-10) 220 | else: 221 | diameter = 0.0 222 | P[:,:3] = (P[:,:3] - np.mean(P[:,:3], axis=0, keepdims=True)) 223 | 224 | if args.pc_attribs != '': 225 | columns = [] 226 | if 'xyz' in args.pc_attribs: columns.append(P[:,:3]) 227 | if 'rgb' in args.pc_attribs: columns.append(P[:,3:6]) 228 | if 'e' in args.pc_attribs: columns.append(P[:,6,None]) 229 | if 'lpsv' in args.pc_attribs: columns.append(P[:,7:11]) 230 | if 'XYZ' in args.pc_attribs: columns.append(P[:,11:14]) 231 | if 'd' in args.pc_attribs: columns.append(P[:,14,None]) # keep 2-D so it can be concatenated along axis 1 232 | P = np.concatenate(columns, axis=1) 233 | 234 | if train: 235 | P = augment_cloud(P, args) 236 | return P, np.array([diameter], dtype=np.float32) 237 | 238 | 239 | def augment_cloud(P, args): 240 | """ Augmentation on XYZ and jittering of everything """ 241 | M = transforms3d.zooms.zfdir2mat(1) 242 | if args.pc_augm_scale > 1: 243 | s = random.uniform(1/args.pc_augm_scale, args.pc_augm_scale) 244 | M = np.dot(transforms3d.zooms.zfdir2mat(s), M) 245 | if args.pc_augm_rot==1: 246 | angle = random.uniform(0, 2*math.pi) 247 | M = np.dot(transforms3d.axangles.axangle2mat([0,0,1], angle), M) # z=upright assumption 248 | if args.pc_augm_mirror_prob > 0: # mirroring x&y, not z 249 | if random.random() < args.pc_augm_mirror_prob/2: 250 | M = np.dot(transforms3d.zooms.zfdir2mat(-1, [1,0,0]), M) 251 | if random.random() < args.pc_augm_mirror_prob/2: 252 | M = np.dot(transforms3d.zooms.zfdir2mat(-1, [0,1,0]), M) 253 | P[:,:3] = np.dot(P[:,:3], M.T) 254 | 255 | if args.pc_augm_jitter: 256 | sigma, clip= 0.01, 0.05 # 
https://github.com/charlesq34/pointnet/blob/master/provider.py#L74 257 | P = P + np.clip(sigma * np.random.randn(*P.shape), -1*clip, clip).astype(np.float32) 258 | return P 259 | 260 | def global_rotation(P, args): 261 | print("e") 262 | -------------------------------------------------------------------------------- /learning/vkitti_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Nov 6 16:45:16 2018 5 | @author: landrieuloic 6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from builtins import range 10 | 11 | import random 12 | import numpy as np 13 | import os 14 | import functools 15 | import torch 16 | import torchnet as tnt 17 | import h5py 18 | import spg 19 | 20 | def get_datasets(args, test_seed_offset=0): 21 | """ Gets training and test datasets. """ 22 | 23 | # Load superpoints graphs 24 | testlist, trainlist, validlist = [], [], [] 25 | valid_names = ['0001_00000.h5','0001_00085.h5', '0001_00170.h5','0001_00230.h5','0001_00325.h5','0001_00420.h5', \ 26 | '0002_00000.h5','0002_00111.h5','0002_00223.h5','0018_00030.h5','0018_00184.h5','0018_00338.h5',\ 27 | '0020_00080.h5','0020_00262.h5','0020_00444.h5','0020_00542.h5','0020_00692.h5', '0020_00800.h5'] 28 | 29 | for n in range(1,7): 30 | if n != args.cvfold: 31 | path = '{}/superpoint_graphs/0{:d}/'.format(args.VKITTI_PATH, n) 32 | for fname in sorted(os.listdir(path)): 33 | if fname.endswith(".h5") and not (args.use_val_set and fname in valid_names): 34 | #training set 35 | trainlist.append(spg.spg_reader(args, path + fname, True)) 36 | if fname.endswith(".h5") and (args.use_val_set and fname in valid_names): 37 | #validation set 38 | validlist.append(spg.spg_reader(args, path + fname, True)) 39 | path = '{}/superpoint_graphs/0{:d}/'.format(args.VKITTI_PATH, args.cvfold) 40 | #evaluation set 41 | for fname in sorted(os.listdir(path)): 42 | 
if fname.endswith(".h5"): 43 | testlist.append(spg.spg_reader(args, path + fname, True)) 44 | 45 | # Normalize edge features 46 | if args.spg_attribs01: 47 | trainlist, testlist, validlist, scaler = spg.scaler01(trainlist, testlist, validlist=validlist) 48 | 49 | return tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in trainlist], 50 | functools.partial(spg.loader, train=True, args=args, db_path=args.VKITTI_PATH)), \ 51 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in testlist], 52 | functools.partial(spg.loader, train=False, args=args, db_path=args.VKITTI_PATH, test_seed_offset=test_seed_offset)), \ 53 | tnt.dataset.ListDataset([spg.spg_to_igraph(*tlist) for tlist in validlist], 54 | functools.partial(spg.loader, train=False, args=args, db_path=args.VKITTI_PATH, test_seed_offset=test_seed_offset)), \ 55 | scaler 56 | 57 | 58 | def get_info(args): 59 | edge_feats = 0 60 | for attrib in args.edge_attribs.split(','): 61 | a = attrib.split('/')[0] 62 | if a in ['delta_avg', 'delta_std', 'xyz']: 63 | edge_feats += 3 64 | else: 65 | edge_feats += 1 66 | if args.loss_weights == 'none': 67 | weights = np.ones((13,),dtype='f4') 68 | else: 69 | weights = h5py.File(args.VKITTI_PATH + "/parsed/class_count.h5")["class_count"][:].astype('f4') 70 | weights = weights[:,[i for i in range(6) if i != args.cvfold-1]].sum(1) 71 | weights = (weights+1).mean()/(weights+1) 72 | if args.loss_weights == 'sqrt': 73 | weights = np.sqrt(weights) 74 | weights = torch.from_numpy(weights).cuda() if args.cuda else torch.from_numpy(weights) 75 | return { 76 | 'node_feats': 9 if args.pc_attribs=='' else len(args.pc_attribs), 77 | 'edge_feats': edge_feats, 78 | 'classes': 13, 79 | 'class_weights': weights, 80 | 'inv_class_map': {0:'Terrain', 1:'Tree', 2:'Vegetation', 3:'Building', 4:'Road', 5:'GuardRail', 6:'TrafficSign', 7:'TrafficLight', 8:'Pole', 9:'Misc', 10:'Truck', 11:'Car', 12:'Van'}, 81 | } 82 | 83 | def preprocess_pointclouds(VKITTI_PATH): 84 | """ 
Preprocesses data by splitting them by components and normalizing.""" 85 | class_count = np.zeros((13,6),dtype='int') 86 | for n in range(1,7): 87 | pathP = '{}/parsed/0{:d}/'.format(VKITTI_PATH, n) 88 | pathD = '{}/features_supervision/0{:d}/'.format(VKITTI_PATH, n) 89 | pathC = '{}/superpoint_graphs/0{:d}/'.format(VKITTI_PATH, n) 90 | if not os.path.exists(pathP): 91 | os.makedirs(pathP) 92 | random.seed(n) 93 | 94 | for file in os.listdir(pathC): 95 | print(file) 96 | if file.endswith(".h5"): 97 | f = h5py.File(pathD + file, 'r') 98 | xyz = f['xyz'][:] 99 | rgb = f['rgb'][:].astype(np.float64) # np.float alias was removed in NumPy 1.24 100 | 101 | labels = f['labels'][:] 102 | hard_labels = np.argmax(labels[:,1:],1) 103 | label_count = np.bincount(hard_labels, minlength=13) 104 | class_count[:,n-1] = class_count[:,n-1] + label_count 105 | 106 | e = (f['xyz'][:,2][:] - np.min(f['xyz'][:,2]))/ (np.max(f['xyz'][:,2]) - np.min(f['xyz'][:,2]))-0.5 107 | 108 | rgb = rgb/255.0 - 0.5 109 | 110 | xyzn = (xyz - np.array([30,0,0])) / np.array([30,5,3]) 111 | 112 | lpsv = np.zeros((e.shape[0],4)) 113 | 114 | P = np.concatenate([xyz, rgb, e[:,np.newaxis], lpsv, xyzn], axis=1) 115 | 116 | f = h5py.File(pathC + file, 'r') 117 | numc = len(f['components'].keys()) 118 | 119 | with h5py.File(pathP + file, 'w') as hf: 120 | hf.create_dataset(name='centroid',data=xyz.mean(0)) 121 | for c in range(numc): 122 | idx = f['components/{:d}'.format(c)][:].flatten() 123 | if idx.size > 10000: # trim extra large segments, just for speed-up of loading time 124 | ii = random.sample(range(idx.size), k=10000) 125 | idx = idx[ii] 126 | 127 | hf.create_dataset(name='{:d}'.format(c), data=P[idx,...]) 128 | path = '{}/parsed/'.format(VKITTI_PATH) 129 | data_file = h5py.File(path+'class_count.h5', 'w') 130 | data_file.create_dataset('class_count', data=class_count, dtype='int') 131 | 132 | if __name__ == "__main__": 133 | import argparse 134 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with 
Superpoint Graphs') 135 | parser.add_argument('--VKITTI_PATH', default='datasets/s3dis') 136 | args = parser.parse_args() 137 | preprocess_pointclouds(args.VKITTI_PATH) -------------------------------------------------------------------------------- /partition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loicland/superpoint_graph/0209777339327c9b327b6947af6c89b20bb45981/partition/__init__.py -------------------------------------------------------------------------------- /partition/graphs.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | #--------- Graph methods for SuperPoint Graph ------------------------------ 3 | #--------- Loic Landrieu, Dec. 2017 ----------------------------------- 4 | #------------------------------------------------------------------------------ 5 | import numpy as np 6 | from sklearn.neighbors import NearestNeighbors 7 | from scipy.spatial import Delaunay 8 | from numpy import linalg as LA 9 | import numpy.matlib 10 | #------------------------------------------------------------------------------ 11 | def compute_graph_nn(xyz, k_nn): 12 | """compute the knn graph""" 13 | num_ver = xyz.shape[0] 14 | graph = dict([("is_nn", True)]) 15 | nn = NearestNeighbors(n_neighbors=k_nn+1, algorithm='kd_tree').fit(xyz) 16 | distances, neighbors = nn.kneighbors(xyz) 17 | neighbors = neighbors[:, 1:] 18 | distances = distances[:, 1:] 19 | source = np.matlib.repmat(range(0, num_ver), k_nn, 1).flatten(order='F') 20 | #save the graph 21 | graph["source"] = source.flatten().astype('uint32') 22 | graph["target"] = neighbors.flatten().astype('uint32') 23 | graph["distances"] = distances.flatten().astype('float32') 24 | return graph 25 | #------------------------------------------------------------------------------ 26 | def compute_graph_nn_2(xyz, k_nn1, 
k_nn2, voronoi = 0.0): 27 | """compute simultaneously 2 knn structures 28 | only saves target for knn2 29 | assumption : knn1 <= knn2""" 30 | assert k_nn1 <= k_nn2, "knn1 must be smaller than knn2" 31 | n_ver = xyz.shape[0] 32 | #compute nearest neighbors 33 | graph = dict([("is_nn", True)]) 34 | nn = NearestNeighbors(n_neighbors=k_nn2+1, algorithm='kd_tree').fit(xyz) 35 | distances, neighbors = nn.kneighbors(xyz) 36 | del nn 37 | neighbors = neighbors[:, 1:] 38 | distances = distances[:, 1:] 39 | #---knn2--- 40 | target2 = (neighbors.flatten()).astype('uint32') 41 | #---knn1----- 42 | if voronoi>0: 43 | tri = Delaunay(xyz) 44 | graph["source"] = np.hstack((tri.vertices[:,0],tri.vertices[:,0], \ 45 | tri.vertices[:,0], tri.vertices[:,1], tri.vertices[:,1], tri.vertices[:,2])).astype('uint64') 46 | graph["target"]= np.hstack((tri.vertices[:,1],tri.vertices[:,2], \ 47 | tri.vertices[:,3], tri.vertices[:,2], tri.vertices[:,3], tri.vertices[:,3])).astype('uint64') 48 | graph["distances"] = ((xyz[graph["source"],:] - xyz[graph["target"],:])**2).sum(1) 49 | keep_edges = graph["distances"] 1 80 | label_hist = has_labels and len(labels.shape) > 1 and labels.shape[1] > 1 81 | #---compute delaunay triangulation--- 82 | tri = Delaunay(xyz) 83 | #interface selects the edges between different components 84 | #edgx and edgxr convert from tetrahedrons to edges 85 | #done separately for each edge of the tetrahedrons to limit memory impact 86 | interface = in_component[tri.vertices[:, 0]] != in_component[tri.vertices[:, 1]] 87 | edg1 = np.vstack((tri.vertices[interface, 0], tri.vertices[interface, 1])) 88 | edg1r = np.vstack((tri.vertices[interface, 1], tri.vertices[interface, 0])) 89 | interface = in_component[tri.vertices[:, 0]] != in_component[tri.vertices[:, 2]] 90 | edg2 = np.vstack((tri.vertices[interface, 0], tri.vertices[interface, 2])) 91 | edg2r = np.vstack((tri.vertices[interface, 2], tri.vertices[interface, 0])) 92 | interface = in_component[tri.vertices[:, 0]] != 
in_component[tri.vertices[:, 3]] 93 | edg3 = np.vstack((tri.vertices[interface, 0], tri.vertices[interface, 3])) 94 | edg3r = np.vstack((tri.vertices[interface, 3], tri.vertices[interface, 0])) 95 | interface = in_component[tri.vertices[:, 1]] != in_component[tri.vertices[:, 2]] 96 | edg4 = np.vstack((tri.vertices[interface, 1], tri.vertices[interface, 2])) 97 | edg4r = np.vstack((tri.vertices[interface, 2], tri.vertices[interface, 1])) 98 | interface = in_component[tri.vertices[:, 1]] != in_component[tri.vertices[:, 3]] 99 | edg5 = np.vstack((tri.vertices[interface, 1], tri.vertices[interface, 3])) 100 | edg5r = np.vstack((tri.vertices[interface, 3], tri.vertices[interface, 1])) 101 | interface = in_component[tri.vertices[:, 2]] != in_component[tri.vertices[:, 3]] 102 | edg6 = np.vstack((tri.vertices[interface, 2], tri.vertices[interface, 3])) 103 | edg6r = np.vstack((tri.vertices[interface, 3], tri.vertices[interface, 2])) 104 | del tri, interface 105 | edges = np.hstack((edg1, edg2, edg3, edg4 ,edg5, edg6, edg1r, edg2r, 106 | edg3r, edg4r ,edg5r, edg6r)) 107 | del edg1, edg2, edg3, edg4 ,edg5, edg6, edg1r, edg2r, edg3r, edg4r, edg5r, edg6r 108 | edges = np.unique(edges, axis=1) 109 | 110 | if d_max > 0: 111 | dist = np.sqrt(((xyz[edges[0,:]]-xyz[edges[1,:]])**2).sum(1)) 112 | edges = edges[:,dist 1: 203 | graph["se_delta_mean"][i_sedg] = np.mean(delta, axis=0) 204 | graph["se_delta_std"][i_sedg] = np.std(delta, axis=0) 205 | graph["se_delta_norm"][i_sedg] = np.mean(np.sqrt(np.sum(delta ** 2, axis=1))) 206 | else: 207 | graph["se_delta_mean"][i_sedg, :] = delta 208 | graph["se_delta_std"][i_sedg, :] = [0, 0, 0] 209 | graph["se_delta_norm"][i_sedg] = np.sqrt(np.sum(delta ** 2)) 210 | return graph 211 | -------------------------------------------------------------------------------- /partition/partition.py: -------------------------------------------------------------------------------- 1 | """ 2 | Large-scale Point Cloud Semantic Segmentation with Superpoint 
Graphs 3 | http://arxiv.org/abs/1711.09869 4 | 2017 Loic Landrieu, Martin Simonovsky 5 | Script for partitioning into simple shapes 6 | """ 7 | import os.path 8 | import sys 9 | import numpy as np 10 | import argparse 11 | from timeit import default_timer as timer 12 | sys.path.append("./partition/cut-pursuit/build/src") 13 | sys.path.append("./partition/ply_c") 14 | sys.path.append("./partition") 15 | import libcp 16 | import libply_c 17 | from graphs import * 18 | from provider import * 19 | 20 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs') 21 | parser.add_argument('--ROOT_PATH', default='datasets/s3dis') 22 | parser.add_argument('--dataset', default='s3dis', help='s3dis/sema3d/custom_dataset') 23 | parser.add_argument('--k_nn_geof', default=45, type=int, help='number of neighbors for the geometric features') 24 | parser.add_argument('--k_nn_adj', default=10, type=int, help='adjacency structure for the minimal partition') 25 | parser.add_argument('--lambda_edge_weight', default=1., type=float, help='parameter determining the edge weight for the minimal partition') 26 | parser.add_argument('--reg_strength', default=0.1, type=float, help='regularization strength for the minimal partition') 27 | parser.add_argument('--d_se_max', default=0, type=float, help='max length of super edges') 28 | parser.add_argument('--voxel_width', default=0.03, type=float, help='voxel size when subsampling (in m)') 29 | parser.add_argument('--ver_batch', default=0, type=int, help='Batch size for reading large files, 0 to disable batch loading') 30 | parser.add_argument('--overwrite', default=0, type=int, help='Whether to overwrite existing files (1) or read them (0)') 31 | args = parser.parse_args() 32 | 33 | #path to data 34 | root = args.ROOT_PATH+'/' 35 | #list of subfolders to be processed 36 | if args.dataset == 's3dis': 37 | folders = ["Area_1/", "Area_2/", "Area_3/", "Area_4/", "Area_5/", "Area_6/"] 38 | n_labels = 13 39 | elif
args.dataset == 'sema3d': 40 | folders = ["test_reduced/", "test_full/", "train/"] 41 | n_labels = 8 42 | elif args.dataset == 'custom_dataset': 43 | folders = ["train/", "test/"] 44 | n_labels = 10 #number of classes 45 | else: 46 | raise ValueError('%s is an unknown data set' % args.dataset) 47 | 48 | times = [0,0,0] #time for computing: features / partition / spg 49 | 50 | if not os.path.isdir(root + "clouds"): 51 | os.mkdir(root + "clouds") 52 | if not os.path.isdir(root + "features"): 53 | os.mkdir(root + "features") 54 | if not os.path.isdir(root + "superpoint_graphs"): 55 | os.mkdir(root + "superpoint_graphs") 56 | 57 | for folder in folders: 58 | print("=================\n "+folder+"\n=================") 59 | 60 | data_folder = root + "data/" + folder 61 | cloud_folder = root + "clouds/" + folder 62 | fea_folder = root + "features/" + folder 63 | spg_folder = root + "superpoint_graphs/" + folder 64 | if not os.path.isdir(data_folder): 65 | raise ValueError("%s does not exist" % data_folder) 66 | 67 | if not os.path.isdir(cloud_folder): 68 | os.mkdir(cloud_folder) 69 | if not os.path.isdir(fea_folder): 70 | os.mkdir(fea_folder) 71 | if not os.path.isdir(spg_folder): 72 | os.mkdir(spg_folder) 73 | 74 | if args.dataset=='s3dis': 75 | files = [os.path.join(data_folder, o) for o in os.listdir(data_folder) 76 | if os.path.isdir(os.path.join(data_folder,o))] 77 | elif args.dataset=='sema3d': 78 | files = glob.glob(data_folder+"*.txt") 79 | elif args.dataset=='custom_dataset': 80 | #list all ply files in the folder 81 | files = glob.glob(data_folder+"*.ply") 82 | #list all las files in the folder 83 | files = glob.glob(data_folder+"*.las") 84 | 85 | if (len(files) == 0): 86 | raise ValueError('%s is empty' % data_folder) 87 | 88 | n_files = len(files) 89 | i_file = 0 90 | for file in files: 91 | file_name = os.path.splitext(os.path.basename(file))[0] 92 | 93 | if args.dataset=='s3dis': 94 | data_file = data_folder + file_name + '/' + file_name + ".txt" 95 | cloud_file = 
cloud_folder + file_name 96 | fea_file = fea_folder + file_name + '.h5' 97 | spg_file = spg_folder + file_name + '.h5' 98 | elif args.dataset=='sema3d': 99 | file_name_short = '_'.join(file_name.split('_')[:2]) 100 | data_file = data_folder + file_name + ".txt" 101 | label_file = data_folder + file_name_short + ".labels" 102 | cloud_file = cloud_folder+ file_name_short 103 | fea_file = fea_folder + file_name_short + '.h5' 104 | spg_file = spg_folder + file_name_short + '.h5' 105 | elif args.dataset=='custom_dataset': 106 | #adapt to your hierarchy. The following 4 files must be defined 107 | data_file = data_folder + file_name + '.ply' #or .las 108 | cloud_file = cloud_folder + file_name 109 | fea_file = fea_folder + file_name + '.h5' 110 | spg_file = spg_folder + file_name + '.h5' 111 | 112 | i_file = i_file + 1 113 | print(str(i_file) + " / " + str(n_files) + "---> "+file_name) 114 | #--- build the geometric feature file h5 file --- 115 | if os.path.isfile(fea_file) and not args.overwrite: 116 | print(" reading the existing feature file...") 117 | geof, xyz, rgb, graph_nn, labels = read_features(fea_file) 118 | else : 119 | print(" creating the feature file...") 120 | #--- read the data files and compute the labels--- 121 | if args.dataset=='s3dis': 122 | xyz, rgb, labels, objects = read_s3dis_format(data_file) 123 | if args.voxel_width > 0: 124 | xyz, rgb, labels, dump = libply_c.prune(xyz.astype('f4'), args.voxel_width, rgb.astype('uint8'), labels.astype('uint8'), np.zeros(1, dtype='uint8'), n_labels, 0) 125 | elif args.dataset=='sema3d': 126 | label_file = data_folder + file_name + ".labels" 127 | has_labels = (os.path.isfile(label_file)) 128 | if (has_labels): 129 | xyz, rgb, labels = read_semantic3d_format(data_file, n_labels, label_file, args.voxel_width, args.ver_batch) 130 | else: 131 | xyz, rgb = read_semantic3d_format(data_file, 0, '', args.voxel_width, args.ver_batch) 132 | labels = [] 133 | elif args.dataset=='custom_dataset': 134 | #implement in 
provider.py your own read_custom_format outputing xyz, rgb, labels 135 | #example for ply files 136 | xyz, rgb, labels = read_ply(data_file) 137 | #another one for las files without rgb 138 | xyz = read_las(data_file) 139 | if args.voxel_width > 0: 140 | #an example of pruning without labels 141 | xyz, rgb, labels = libply_c.prune(xyz, args.voxel_width, rgb, np.array(1,dtype='u1'), 0) 142 | #another one without rgb information nor labels 143 | xyz = libply_c.prune(xyz, args.voxel_width, np.zeros(xyz.shape,dtype='u1'), np.array(1,dtype='u1'), 0)[0] 144 | #if no labels available simply set here labels = [] 145 | #if no rgb available simply set here rgb = [] and make sure to not use it later on 146 | start = timer() 147 | #---compute 10 nn graph------- 148 | graph_nn, target_fea = compute_graph_nn_2(xyz, args.k_nn_adj, args.k_nn_geof) 149 | #---compute geometric features------- 150 | geof = libply_c.compute_geof(xyz, target_fea, args.k_nn_geof).astype('float32') 151 | end = timer() 152 | times[0] = times[0] + end - start 153 | del target_fea 154 | write_features(fea_file, geof, xyz, rgb, graph_nn, labels) 155 | #--compute the partition------ 156 | sys.stdout.flush() 157 | if os.path.isfile(spg_file) and not args.overwrite: 158 | print(" reading the existing superpoint graph file...") 159 | graph_sp, components, in_component = read_spg(spg_file) 160 | else: 161 | print(" computing the superpoint graph...") 162 | #--- build the spg h5 file -- 163 | start = timer() 164 | if args.dataset=='s3dis': 165 | features = np.hstack((geof, rgb/255.)).astype('float32')#add rgb as a feature for partitioning 166 | features[:,3] = 2. * features[:,3] #increase importance of verticality (heuristic) 167 | elif args.dataset=='sema3d': 168 | features = geof 169 | geof[:,3] = 2. * geof[:, 3] 170 | elif args.dataset=='custom_dataset': 171 | #choose here which features to use for the partition 172 | features = geof 173 | geof[:,3] = 2. 
* geof[:, 3] 174 | 175 | graph_nn["edge_weight"] = np.array(1. / ( args.lambda_edge_weight + graph_nn["distances"] / np.mean(graph_nn["distances"])), dtype = 'float32') 176 | print(" minimal partition...") 177 | components, in_component = libcp.cutpursuit(features, graph_nn["source"], graph_nn["target"] 178 | , graph_nn["edge_weight"], args.reg_strength) 179 | components = np.array(components, dtype = 'object') 180 | end = timer() 181 | times[1] = times[1] + end - start 182 | print(" computation of the SPG...") 183 | start = timer() 184 | graph_sp = compute_sp_graph(xyz, args.d_se_max, in_component, components, labels, n_labels) 185 | end = timer() 186 | times[2] = times[2] + end - start 187 | write_spg(spg_file, graph_sp, components, in_component) 188 | 189 | print("Timer : %5.1f / %5.1f / %5.1f " % (times[0], times[1], times[2])) 190 | -------------------------------------------------------------------------------- /partition/ply_c/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Graph for Cut Pursuit 2 | # author: Loic Landrieu 3 | # date: 2017 4 | 5 | CMAKE_MINIMUM_REQUIRED(VERSION 3.5) 6 | 7 | PROJECT(LIBGEOF) 8 | 9 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -std=c++11 -fopenmp -O3") 10 | 11 | ############################## 12 | ### Find required packages ### 13 | ############################## 14 | set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_MODULE_PATH}) 15 | 16 | find_package(PythonLibs) 17 | find_package(PythonInterp) 18 | find_package(NumPy 1.5 REQUIRED) 19 | 20 | find_package(Boost 1.65.0 COMPONENTS graph REQUIRED) #system filesystem thread serialization 21 | if (${Boost_MINOR_VERSION} LESS 67 ) 22 | find_package(Boost 1.65.0 COMPONENTS numpy${PYTHON_VERSION_MAJOR} REQUIRED) #system filesystem thread serialization 23 | else() 24 | set(PYTHONVERSION ${PYTHON_VERSION_MAJOR}${PYTHON_VERSION_MINOR}) 25 | find_package(Boost 1.67.0 COMPONENTS numpy${PYTHONVERSION} REQUIRED) 26 | endif() 
27 | 28 | include_directories(${Boost_INCLUDE_DIRS}) 29 | link_directories(${Boost_LIBRARY_DIRS}) 30 | 31 | message("Boost includes ARE " ${Boost_INCLUDE_DIRS}) 32 | message("Boost LIBRARIES ARE " ${Boost_LIBRARY_DIRS}) 33 | 34 | find_package(Eigen3 REQUIRED NO_MODULE) 35 | INCLUDE_DIRECTORIES(${EIGEN3_INCLUDE_DIR}) 36 | #LINK_DIRECTORIES(${EIGEN3_LIBRARY_DIRS}) 37 | 38 | #SET(PYTHON_LIBRARIES /usr/lib/x86_64-linux-gnu/libpython2.7.so) 39 | #SET(PYTHON_INCLUDE_DIRS /usr/include/python2.7) 40 | 41 | message("PYTHON LIBRARIES ARE " ${PYTHON_LIBRARIES}) 42 | INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIRS} ${PYTHON_NUMPY_INCLUDE_DIR}) 43 | LINK_DIRECTORIES(${PYTHON_LIBRARY_DIRS}) 44 | ############################## 45 | ### Build target library ### 46 | ############################## 47 | 48 | set(CMAKE_LD_FLAG "${CMAKE_LD_FLAGS} -shared -Wl -fPIC --export-dynamic -fopenmp -O3 -Wall") 49 | 50 | add_library(ply_c SHARED ply_c.cpp) 51 | target_link_libraries(ply_c 52 | ${Boost_LIBRARIES} 53 | ${PYTHON_LIBRARIES} 54 | ) 55 | -------------------------------------------------------------------------------- /partition/ply_c/FindNumPy.cmake: -------------------------------------------------------------------------------- 1 | 2 | # - Try to find the Python module NumPy 3 | # 4 | # This module defines: 5 | # NUMPY_INCLUDE_DIR: include path for arrayobject.h 6 | 7 | # Copyright (c) 2009-2012 Arnaud Barré 8 | # Redistribution and use is allowed according to the terms of the BSD license. 9 | # For details see the accompanying COPYING-CMAKE-SCRIPTS file. 
10 | 11 | if (PYTHON_NUMPY_INCLUDE_DIR) 12 | set(PYTHON_NUMPY_FIND_QUIETLY TRUE) 13 | endif() 14 | 15 | if (NOT PYTHON_EXECUTABLE) 16 | message(FATAL_ERROR "\"PYTHON_EXECUTABLE\" varabile not set before FindNumPy.cmake was run.") 17 | endif() 18 | 19 | # Look for the include path 20 | # WARNING: The variable PYTHON_EXECUTABLE is defined by the script FindPythonInterp.cmake 21 | execute_process(COMMAND "${PYTHON_EXECUTABLE}" -c "import numpy; print (numpy.get_include()); print (numpy.version.version)" 22 | OUTPUT_VARIABLE NUMPY_OUTPUT 23 | ERROR_VARIABLE NUMPY_ERROR) 24 | if (NOT NUMPY_ERROR) 25 | STRING(REPLACE "\n" ";" NUMPY_OUTPUT ${NUMPY_OUTPUT}) 26 | LIST(GET NUMPY_OUTPUT 0 PYTHON_NUMPY_INCLUDE_DIRS) 27 | LIST(GET NUMPY_OUTPUT 1 PYTHON_NUMPY_VERSION) 28 | endif(NOT NUMPY_ERROR) 29 | 30 | include(FindPackageHandleStandardArgs) 31 | find_package_handle_standard_args(NumPy DEFAULT_MSG PYTHON_NUMPY_VERSION PYTHON_NUMPY_INCLUDE_DIRS) 32 | 33 | set(PYTHON_NUMPY_INCLUDE_DIR ${PYTHON_NUMPY_INCLUDE_DIRS} 34 | CACHE PATH "Location of NumPy include files.") 35 | mark_as_advanced(PYTHON_NUMPY_INCLUDE_DIR) -------------------------------------------------------------------------------- /partition/ply_c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loicland/superpoint_graph/0209777339327c9b327b6947af6c89b20bb45981/partition/ply_c/__init__.py -------------------------------------------------------------------------------- /partition/ply_c/connected_components.cpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | using namespace std; 12 | using namespace boost; 13 | 14 | typedef adjacency_list Graph; 15 | typedef typename graph_traits< Graph >::adjacency_iterator adjacency_iterator; 16 | 17 | void connected_components(const uint32_t n_ver, const 
uint32_t n_edg 18 | , const uint32_t * Eu, const uint32_t * Ev, const char * active_edge 19 | , std::vector & in_component, std::vector< std::vector > & components, const uint32_t cutoff) 20 | { //C-style interface 21 | 22 | Graph G(n_ver); 23 | for (uint32_t i_edg = 0; i_edg < n_edg; i_edg++) 24 | { 25 | if (active_edge[i_edg] > 0) 26 | { 27 | add_edge(Eu[i_edg], Ev[i_edg], G); 28 | } 29 | } 30 | 31 | int n_com = connected_components(G, &in_component[0]); 32 | 33 | //cout << "Total number of components: " << n_com << endl; 34 | 35 | std::vector< std::vector > components_tmp(n_com); 36 | for (uint32_t i_ver = 0; i_ver < n_ver; i_ver++) 37 | { 38 | components_tmp[in_component[i_ver]].push_back(i_ver); 39 | } 40 | 41 | //fuse components to preserve cutoff 42 | 43 | G = Graph(n_ver); 44 | for (uint32_t i_edg = 0; i_edg < n_edg; i_edg++) 45 | { 46 | if (active_edge[i_edg] == 0) 47 | { 48 | add_edge(Eu[i_edg], Ev[i_edg], G); 49 | } 50 | } 51 | 52 | typename graph_traits < Graph >::adjacency_iterator nei_ini, nei_end; 53 | boost::property_map::type vertex_index_map = get(boost::vertex_index, G); 54 | std::vector is_fused(n_ver, 0); 55 | 56 | int n_com_final = n_com; 57 | for (int i_com = 0; i_com < n_com; i_com++) 58 | { 59 | if (components_tmp[i_com].size() < cutoff) 60 | {//components is too small 61 | //std::cout << i_com << " of size " << components_tmp[i_com].size() << " / " << cutoff << std::endl; 62 | int largest_neigh_comp_value = 0; 63 | int largest_neigh_comp_index = -1; 64 | for (int i_ver_com = 0; i_ver_com < components_tmp[i_com].size(); i_ver_com++) 65 | { //std::cout << " considering node" << components_tmp[i_com][i_ver_com] << std::endl; 66 | boost::tie(nei_ini, nei_end) = adjacent_vertices(vertex(components_tmp[i_com][i_ver_com], G), G); 67 | for (graph_traits < Graph >::adjacency_iterator nei_ite = nei_ini; nei_ite != nei_end; nei_ite++) 68 | { 69 | int candidate_comp = in_component[vertex_index_map(*nei_ite)]; 70 | if ((candidate_comp == i_com) || 
(is_fused[candidate_comp])) 71 | { 72 | continue; 73 | } 74 | //std::cout << " neighbors " << vertex_index_map(*nei_ite) << " in comp " << candidate_comp << "of size " << components_tmp[candidate_comp].size() << std::endl; 75 | if (components_tmp[candidate_comp].size() > largest_neigh_comp_value) 76 | { 77 | largest_neigh_comp_value = components_tmp[candidate_comp].size() ; 78 | largest_neigh_comp_index = candidate_comp; 79 | } 80 | } 81 | } 82 | if (largest_neigh_comp_index>0) 83 | { 84 | //std::cout << "best comp = " << largest_neigh_comp_index << " of size " << largest_neigh_comp_value << std::endl; 85 | //we now fuse the two connected components 86 | components_tmp[largest_neigh_comp_index].insert(components_tmp[largest_neigh_comp_index].end(), components_tmp[i_com].begin(), components_tmp[i_com].end()); 87 | is_fused[i_com] = 1; 88 | n_com_final--; 89 | } 90 | } 91 | } 92 | 93 | components.resize(n_com_final); 94 | int i_com_index = 0; 95 | for (int i_com = 0; i_com < n_com; i_com++) 96 | { 97 | if (!is_fused[i_com]) 98 | { 99 | components[i_com_index] = components_tmp[i_com]; 100 | for (uint32_t i_ver_com = 0; i_ver_com < components_tmp[i_com].size(); i_ver_com++) 101 | { 102 | in_component[components_tmp[i_com][i_ver_com]] = i_com_index; 103 | } 104 | i_com_index++; 105 | } 106 | } 107 | 108 | 109 | return; 110 | } 111 | -------------------------------------------------------------------------------- /partition/ply_c/random_subgraph.cpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace std; 9 | using namespace boost; 10 | 11 | namespace subgraph { 12 | 13 | typedef adjacency_list Graph; 14 | 15 | typedef typename boost::graph_traits< Graph >::vertex_descriptor VertexDescriptor; 16 | 17 | typedef typename boost::property_map< Graph, boost::vertex_index_t>::type VertexIndexMap; 18 | 19 | typedef typename graph_traits < Graph 
>::adjacency_iterator Adjacency_iterator; 20 | 21 | 22 | void random_subgraph(const int n_ver, const int n_edg, const uint32_t * Eu, const uint32_t * Ev, int subgraph_size 23 | , uint8_t * selected_edges, uint8_t * selected_vertices) 24 | 25 | { //C-style interface 26 | 27 | if (n_ver < subgraph_size) 28 | { 29 | for (uint32_t i_edg = 0; i_edg < n_edg; i_edg++) 30 | { 31 | selected_edges[i_edg] = 1; 32 | } 33 | for (uint32_t i_ver = 0; i_ver < n_ver; i_ver++) 34 | { 35 | selected_vertices[i_ver] = 1; 36 | } 37 | return; 38 | } 39 | 40 | Graph G(n_ver); 41 | 42 | VertexIndexMap vertex_index_map = get(boost::vertex_index, G); 43 | VertexDescriptor ver_current; 44 | Adjacency_iterator ite_ver_adj,ite_ver_adj_end; 45 | int node_seen = 0, seed_index; 46 | queue<VertexDescriptor> ver_queue; 47 | 48 | for (uint32_t i_edg = 0; i_edg < n_edg; i_edg++) 49 | { //building graph 50 | add_edge(vertex(Eu[i_edg],G), vertex(Ev[i_edg],G), G); 51 | } 52 | 53 | while(node_seen < subgraph_size) 54 | { 55 | //add seed vertex 56 | seed_index = rand() % n_ver; 57 | if (selected_vertices[seed_index]) 58 | { 59 | continue; 60 | } 61 | ver_queue.push(vertex(seed_index,G)); 62 | selected_vertices[vertex_index_map(ver_queue.front())] = 1; 63 | node_seen = node_seen + 1; 64 | 65 | while(!ver_queue.empty()) 66 | { 67 | //pop the top of the queue and mark it as seen 68 | ver_current = ver_queue.front(); 69 | ver_queue.pop(); 70 | 71 | //add the neighbors of that vertex 72 | for (tie(ite_ver_adj, ite_ver_adj_end) = adjacent_vertices(ver_current, G); ite_ver_adj != ite_ver_adj_end; ite_ver_adj++) 73 | { 74 | int i_ver_adj = vertex_index_map(*ite_ver_adj); 75 | if ((selected_vertices[i_ver_adj]==0) && (node_seen <= subgraph_size)) 76 | {//vertex not already seen 77 | node_seen = node_seen + 1; 78 | selected_vertices[i_ver_adj] = 1; 79 | ver_queue.push(*ite_ver_adj); 80 | } 81 | 82 | if (node_seen >= subgraph_size) 83 | {//enough vertices 84 | break; 85 | } 86 | 87 | } 88 | } 89 | } 90 | 91 | for (int i_edg = 0; i_edg 
< n_edg; i_edg++) 92 | { //add edges between selected vertices 93 | selected_edges[i_edg] = selected_vertices[vertex_index_map(vertex(Eu[i_edg],G))] 94 | * selected_vertices[vertex_index_map(vertex(Ev[i_edg],G))]; 95 | } 96 | return; 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /partition/visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 3 | http://arxiv.org/abs/1711.09869 4 | 2017 Loic Landrieu, Martin Simonovsky 5 | 6 | this functions outputs varied ply file to visualize the different steps 7 | """ 8 | import os.path 9 | import numpy as np 10 | import argparse 11 | import sys 12 | sys.path.append("./partition/") 13 | from plyfile import PlyData, PlyElement 14 | from provider import * 15 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs') 16 | parser.add_argument('--dataset', default='s3dis', help='dataset name: sema3d|s3dis') 17 | parser.add_argument('--ROOT_PATH', default='/mnt/bigdrive/loic/S3DIS', help='folder containing the ./data folder') 18 | parser.add_argument('--res_file', default='../models/cv1/predictions_val', help='folder containing the results') 19 | parser.add_argument('--supervized_partition',type=int, default=0) 20 | parser.add_argument('--file_path', default='Area_1/conferenceRoom_1', help='file to output (must include the area / set in its path)') 21 | parser.add_argument('--upsample', default=0, type=int, help='if 1, upsample the prediction to the original cloud (if the files is huge it can take a very long and use a lot of memory - avoid on sema3d)') 22 | parser.add_argument('--ver_batch', default=0, type=int, help='Batch size for reading large files') 23 | parser.add_argument('--output_type', default='igfpres', help='which cloud to output: i = input rgb pointcloud \ 24 | , g = ground truth, f = 
geometric features, p = partition, r = prediction result \ 25 | , e = error, s = SPG') 26 | args = parser.parse_args() 27 | #---path to data--------------------------------------------------------------- 28 | #root of the data directory 29 | root = args.ROOT_PATH+'/' 30 | rgb_out = 'i' in args.output_type 31 | gt_out = 'g' in args.output_type 32 | fea_out = 'f' in args.output_type 33 | par_out = 'p' in args.output_type 34 | res_out = 'r' in args.output_type 35 | err_out = 'e' in args.output_type 36 | spg_out = 's' in args.output_type 37 | folder = os.path.split(args.file_path)[0] + '/' 38 | file_name = os.path.split(args.file_path)[1] 39 | 40 | if args.dataset == 's3dis': 41 | n_labels = 13 42 | if args.dataset == 'sema3d': 43 | n_labels = 8 44 | if args.dataset == 'vkitti': 45 | n_labels = 13 46 | if args.dataset == 'custom_dataset': 47 | n_labels = 10 48 | #---load the values------------------------------------------------------------ 49 | fea_file = root + "features/" + folder + file_name + '.h5' 50 | if not os.path.isfile(fea_file) or args.supervized_partition: 51 | fea_file = root + "features_supervision/" + folder + file_name + '.h5' 52 | spg_file = root + "superpoint_graphs/" + folder + file_name + '.h5' 53 | ply_folder = root + "clouds/" + folder 54 | ply_file = ply_folder + file_name 55 | res_file = args.res_file + '.h5' 56 | 57 | if not os.path.isdir(root + "clouds/"): 58 | os.mkdir(root + "clouds/" ) 59 | if not os.path.isdir(ply_folder ): 60 | os.mkdir(ply_folder) 61 | if (not os.path.isfile(fea_file)) : 62 | raise ValueError("%s does not exist and is needed" % fea_file) 63 | 64 | geof, xyz, rgb, graph_nn, labels = read_features(fea_file) 65 | 66 | if (par_out or res_out) and (not os.path.isfile(spg_file)): 67 | raise ValueError("%s does not exist and is needed to output the partition or result ply" % spg_file) 68 | else: 69 | graph_spg, components, in_component = read_spg(spg_file) 70 | if res_out or err_out: 71 | if not os.path.isfile(res_file): 72 | 
raise ValueError("%s does not exist and is needed to output the result ply" % res_file) 73 | try: 74 | pred_red = np.array(h5py.File(res_file, 'r').get(folder + file_name)) 75 | if (len(pred_red) != len(components)): 76 | raise ValueError("It looks like the spg is not adapted to the result file") 77 | pred_full = reduced_labels2full(pred_red, components, len(xyz)) 78 | except OSError: 79 | raise ValueError("%s does not exist in %s" % (folder + file_name, res_file)) 80 | #---write the output clouds---------------------------------------------------- 81 | if rgb_out: 82 | print("writing the RGB file...") 83 | write_ply(ply_file + "_rgb.ply", xyz, rgb) 84 | 85 | if gt_out: 86 | print("writing the GT file...") 87 | prediction2ply(ply_file + "_GT.ply", xyz, labels, n_labels, args.dataset) 88 | 89 | if fea_out: 90 | print("writing the features file...") 91 | geof2ply(ply_file + "_geof.ply", xyz, geof) 92 | 93 | if par_out: 94 | print("writing the partition file...") 95 | partition2ply(ply_file + "_partition.ply", xyz, components) 96 | 97 | if res_out and not bool(args.upsample): 98 | print("writing the prediction file...") 99 | prediction2ply(ply_file + "_pred.ply", xyz, pred_full+1, n_labels, args.dataset) 100 | 101 | if err_out: 102 | print("writing the error file...") 103 | error2ply(ply_file + "_err.ply", xyz, rgb, labels, pred_full+1) 104 | 105 | if spg_out: 106 | print("writing the SPG file...") 107 | spg2ply(ply_file + "_spg.ply", graph_spg) 108 | 109 | if res_out and bool(args.upsample): 110 | if args.dataset=='s3dis': 111 | data_file = root + 'data/' + folder + file_name + '/' + file_name + ".txt" 112 | xyz_up, rgb_up = read_s3dis_format(data_file, False) 113 | elif args.dataset=='sema3d':#really not recommended unless you are very confident in your hardware 114 | data_file = root + 'data/' + folder + file_name + ".txt" 115 | xyz_up, rgb_up = read_semantic3d_format(data_file, 0, '', 0, args.ver_batch) 116 | elif args.dataset=='custom_dataset': 117 | data_file = root + 'data/' + folder 
+ file_name + ".ply" 118 | xyz_up, rgb_up = read_ply(data_file) 119 | del rgb_up 120 | pred_up = interpolate_labels(xyz_up, xyz, pred_full, args.ver_batch) 121 | print("writing the upsampled prediction file...") 122 | prediction2ply(ply_file + "_pred_up.ply", xyz_up, pred_up+1, n_labels, args.dataset) 123 | 124 | -------------------------------------------------------------------------------- /partition/write_Semantic3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 5 | http://arxiv.org/abs/1711.09869 6 | 2017 Loic Landrieu, Martin Simonovsky 7 | 8 | call this function once the partition and inference was made to upsample 9 | the prediction to the original point clouds 10 | """ 11 | import os.path 12 | import glob 13 | import numpy as np 14 | import argparse 15 | from provider import * 16 | parser = argparse.ArgumentParser(description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs') 17 | parser.add_argument('--SEMA3D_PATH', default='datasets/semantic3D') 18 | parser.add_argument('--odir', default='./results/semantic3d', help='Directory to store results') 19 | parser.add_argument('--ver_batch', default=5000000, type=int, help='Batch size for reading large files') 20 | parser.add_argument('--db_test_name', default='testred') 21 | args = parser.parse_args() 22 | #---path to data--------------------------------------------------------------- 23 | #root of the data directory 24 | root = args.SEMA3D_PATH+'/' 25 | #list of subfolders to be processed 26 | if args.db_test_name == 'testred': 27 | area = 'test_reduced/' 28 | elif args.db_test_name == 'testfull': 29 | area = 'test_full/' 30 | #------------------------------------------------------------------------------ 31 | print("=================\n " + area + "\n=================") 32 | data_folder = root + "data/" + area 33 | 
fea_folder = root + "features/" + area 34 | spg_folder = root + "superpoint_graphs/" + area 35 | res_folder = './' + args.odir + '/' 36 | labels_folder = root + "labels/" + area 37 | if not os.path.isdir(data_folder): 38 | raise ValueError("%s does not exist" % data_folder) 39 | if not os.path.isdir(fea_folder): 40 | raise ValueError("%s does not exist" % fea_folder) 41 | if not os.path.isdir(res_folder): 42 | raise ValueError("%s does not exist" % res_folder) 43 | if not os.path.isdir(root + "labels/"): 44 | os.mkdir(root + "labels/") 45 | if not os.path.isdir(labels_folder): 46 | os.mkdir(labels_folder) 47 | try: 48 | res_file = h5py.File(res_folder + 'predictions_' + args.db_test_name + '.h5', 'r') 49 | except OSError: 50 | raise ValueError("%s does not exist" % (res_folder + 'predictions_' + args.db_test_name + '.h5')) 51 | 52 | files = glob.glob(data_folder+"*.txt") 53 | if (len(files) == 0): 54 | raise ValueError('%s is empty' % data_folder) 55 | n_files = len(files) 56 | i_file = 0 57 | for file in files: 58 | file_name = os.path.splitext(os.path.basename(file))[0] 59 | file_name_short = '_'.join(file_name.split('_')[:2]) 60 | data_file = data_folder + file_name + ".txt" 61 | fea_file = fea_folder + file_name_short + '.h5' 62 | spg_file = spg_folder + file_name_short + '.h5' 63 | label_file = labels_folder + file_name_short + ".labels" 64 | i_file = i_file + 1 65 | print(str(i_file) + " / " + str(n_files) + "---> "+file_name_short) 66 | print(" reading the subsampled file...") 67 | geof, xyz, rgb, graph_nn, l = read_features(fea_file) 68 | graph_sp, components, in_component = read_spg(spg_file) 69 | n_ver = xyz.shape[0] 70 | del geof, rgb, graph_nn, l, graph_sp 71 | labels_red = np.array(res_file.get(area + file_name_short)) 72 | print(" upsampling...") 73 | labels_full = reduced_labels2full(labels_red, components, n_ver) 74 | labels_ups = interpolate_labels_batch(data_file, xyz, labels_full, args.ver_batch) 75 | np.savetxt(label_file, labels_ups+1, delimiter=' ', fmt='%d') # one label per line, shifted by +1 76 | 
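The upsampling script above first expands the per-superpoint predictions to the subsampled points with `reduced_labels2full(labels_red, components, n_ver)`. That helper lives in `provider.py`, which is not shown in this section, so the following is only a minimal sketch of what such a step plausibly does, assuming `components[i]` lists the point indices of superpoint `i` and `labels_red[i]` is its predicted label. The name `reduced_labels_to_full` and the toy data are hypothetical and are not taken from the repository.

```python
import numpy as np

def reduced_labels_to_full(labels_red, components, n_ver):
    """Hypothetical sketch: copy each superpoint label to all of its points."""
    labels_full = np.zeros(n_ver, dtype=np.uint8)
    for i_com, point_indices in enumerate(components):
        # every point of superpoint i_com receives that superpoint's label
        idx = np.asarray(point_indices, dtype=np.int64)
        labels_full[idx] = labels_red[i_com]
    return labels_full

# toy example: 5 points split into 2 superpoints
components = [[0, 2], [1, 3, 4]]
labels_red = np.array([5, 1])
labels_full = reduced_labels_to_full(labels_red, components, 5)
print(labels_full)
```

In the script, the resulting per-point labels are then interpolated onto the original, non-subsampled cloud by `interpolate_labels_batch`.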
-------------------------------------------------------------------------------- /supervized_partition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loicland/superpoint_graph/0209777339327c9b327b6947af6c89b20bb45981/supervized_partition/__init__.py -------------------------------------------------------------------------------- /supervized_partition/evaluate_partition.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Oct 11 10:12:49 2018 5 | 6 | @author: landrieuloic 7 | """ 8 | 9 | """ 10 | Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs 11 | http://arxiv.org/abs/1711.09869 12 | 2017 Loic Landrieu, Martin Simonovsky 13 | """ 14 | import glob, os 15 | import argparse 16 | import numpy as np 17 | import sys 18 | import ast 19 | import csv 20 | import h5py 21 | sys.path.append("./learning") 22 | from metrics import * 23 | 24 | parser = argparse.ArgumentParser(description='Evaluation function for S3DIS') 25 | 26 | parser.add_argument('--odir', default='./results_partition/', help='Directory to store results') 27 | parser.add_argument('--folder', default='', help='Directory to store results') 28 | parser.add_argument('--dataset', default='s3dis', help='Directory to store results') 29 | parser.add_argument('--cvfold', default='123456', help='which fold to consider') 30 | 31 | args = parser.parse_args() 32 | args.odir = args.odir + args.dataset + '/' 33 | 34 | 35 | root = args.odir + args.folder + '/' 36 | 37 | if args.dataset == 's3dis': 38 | fold_size = [44,40,23,49,68,48] 39 | files = glob.glob(root + 'cv{}'.format(args.cvfold[0]) + '/res*.h5') 40 | n_classes= 13 41 | elif args.dataset == 'vkitti': 42 | fold_size = [15,15,15,15,15,15] 43 | files = glob.glob(root + '0{}'.format(args.cvfold[0]) + '/res*.h5') 44 | n_classes = 13 45 | 46 | file_result_txt 
= open(args.odir + args.folder + '/results' + '.txt',"w") 47 | file_result_txt.write(" N \t ASA \t BR \t BP\n") 48 | 49 | 50 | C_classes = np.zeros((n_classes,n_classes)) 51 | C_BR = np.zeros((2,2)) 52 | C_BP = np.zeros((2,2)) 53 | N_sp = 0 54 | N_pc = 0 55 | 56 | for i_fold in range(len(args.cvfold)): 57 | fold = int(args.cvfold[i_fold]) 58 | if args.dataset == 's3dis': 59 | base_name = root + 'cv{}'.format(fold) 60 | elif args.dataset == 'vkitti': 61 | base_name = root + '0{}'.format(fold) 62 | 63 | try: 64 | file_name = base_name + '/res.h5' 65 | res_file = h5py.File(file_name, 'r') 66 | except OSError: 67 | raise NameError('Cannot find result file %s' % file_name) 68 | 69 | c_classes = np.array(res_file["confusion_matrix_classes"]) 70 | c_BP = np.array(res_file["confusion_matrix_BP"]) 71 | c_BR = np.array(res_file["confusion_matrix_BR"]) 72 | n_sp = np.array(res_file["n_clusters"]) 73 | print("Fold %d : \t n_sp = %5.1f \t ASA = %3.2f %% \t BR = %3.2f %% \t BP = %3.2f %%" % \ 74 | (fold, n_sp, 100 * c_classes.trace() / c_classes.sum(), 100 * c_BR[1,1] / (c_BR[1,1] + c_BR[1,0]),100 * c_BP[1,1] / (c_BP[1,1] + c_BP[0,1]) )) 75 | C_classes += c_classes 76 | C_BR += c_BR 77 | C_BP += c_BP 78 | N_sp += n_sp * fold_size[fold - 1] 79 | N_pc += fold_size[fold - 1] 80 | 81 | if N_sp>0: 82 | print("\nOverall : \t n_sp = %5.1f \t ASA = %3.2f %% \t BR = %3.2f %% \t BP = %3.2f %%\n" % \ 83 | (N_sp/N_pc, 100 * C_classes.trace() / C_classes.sum(), 100 * C_BR[1,1] / (C_BR[1,1] + C_BR[1,0]),100 * C_BP[1,1] / (C_BP[1,1] + C_BP[0,1]) )) 84 | file_result_txt.write("%4.1f \t %3.2f \t %3.2f \t %3.2f \n" % (N_sp/N_pc, 100 * C_classes.trace() / C_classes.sum(), 100 * C_BR[1,1] / (C_BR[1,1] + C_BR[1,0]),100 * C_BP[1,1] / (C_BP[1,1] + C_BP[0,1]) )) 85 | 86 | file_result_txt.close() -------------------------------------------------------------------------------- /supervized_partition/folderhierarchy.py: -------------------------------------------------------------------------------- 1 |
import os 2 | import sys 3 | 4 | DIR_PATH = os.path.dirname(os.path.realpath(__file__)) 5 | sys.path.insert(0, os.path.join(DIR_PATH, '..')) 6 | 7 | class FolderHierachy: 8 | SPG_FOLDER = "superpoint_graphs" 9 | EMBEDDINGS_FOLDER = "embeddings" 10 | SCALAR_FOLDER = "scalars" 11 | MODEL_FILE = "model.pth.tar" 12 | 13 | def __init__(self,outputdir,dataset_name,root_dir,cv_fold): 14 | self._root = root_dir 15 | if dataset_name=='s3dis': 16 | self._outputdir = os.path.join(outputdir,'cv' + str(cv_fold)) 17 | self._folders = ["Area_1/", "Area_2/", "Area_3/", "Area_4/", "Area_5/", "Area_6/"] 18 | elif dataset_name=='sema3d': 19 | self._outputdir = os.path.join(outputdir,'best') 20 | self._folders = ["train/", "test_reduced/", "test_full/"] 21 | elif dataset_name=='vkitti': 22 | self._outputdir = os.path.join(outputdir, 'cv' + str(cv_fold)) 23 | self._folders = ["01/", "02/", "03/", "04/", "05/", "06/"] 24 | 25 | if not os.path.exists(self._outputdir): 26 | os.makedirs(self._outputdir) 27 | 28 | self._spg_folder = self._create_folder(self.SPG_FOLDER) 29 | self._emb_folder = self._create_folder(self.EMBEDDINGS_FOLDER) 30 | self._scalars = self._create_folder(self.SCALAR_FOLDER) 31 | 32 | @property 33 | def outputdir(self): return self._outputdir 34 | 35 | @property 36 | def emb_folder(self): return self._emb_folder 37 | 38 | @property 39 | def spg_folder(self): return self._spg_folder 40 | 41 | @property 42 | def scalars(self): return self._scalars 43 | 44 | @property 45 | def model_path(self): return os.path.join(self._outputdir, self.MODEL_FILE) 46 | 47 | def _create_folder(self,property_name): 48 | folder = os.path.join(self._root , property_name ) 49 | if not os.path.isdir(folder): 50 | os.mkdir(folder) 51 | return folder -------------------------------------------------------------------------------- /supervized_partition/generate_partition.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | 
import sys 4 | import os 5 | import torch 6 | import glob 7 | import torchnet as tnt 8 | import functools 9 | import tqdm 10 | from multiprocessing import Pool 11 | 12 | DIR_PATH = os.path.dirname(os.path.realpath(__file__)) 13 | sys.path.insert(0, os.path.join(DIR_PATH, '..')) 14 | 15 | from supervized_partition import supervized_partition 16 | from supervized_partition import graph_processing 17 | from supervized_partition import losses 18 | from partition import provider 19 | from partition import graphs 20 | from learning import pointnet 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description='Partition large scale point clouds using cut-pursuit') 25 | 26 | parser.add_argument('--modeldir', help='Folder where the saved model lies', required=True) 27 | parser.add_argument('--cuda', default=0, type=int, help='Bool, use cuda') 28 | parser.add_argument('--input_folder', type=str, 29 | help='Folder containing preprocessed point clouds ready for segmentation', required=True) 30 | parser.add_argument('--output_folder', default="", type=str, help='Folder that will contain the output') 31 | parser.add_argument('--overwrite', default=1, type=int, help='Overwrite existing partition') 32 | parser.add_argument('--nworkers', default=5, type=int, 33 | help='Num subprocesses to use for generating the SPGs') 34 | 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def load_model(model_dir, cuda): 40 | checkpoint = torch.load(os.path.join(model_dir, supervized_partition.FolderHierachy.MODEL_FILE)) 41 | training_args = checkpoint['args'] 42 | training_args.cuda = cuda # override cuda 43 | model = supervized_partition.create_model(training_args) 44 | model.load_state_dict(checkpoint['state_dict']) 45 | model.eval() 46 | return model, training_args 47 | 48 | 49 | def get_dataloader(input_folder, training_args): 50 | file_list = glob.glob(os.path.join(input_folder, '*.h5')) 51 | if not file_list: 52 | raise ValueError("Empty input folder: %s" % 
input_folder) 53 | dataset = tnt.dataset.ListDataset(file_list, 54 | functools.partial(graph_processing.graph_loader, train=False, args=training_args, db_path="")) 55 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, 56 | collate_fn=graph_processing.graph_collate) 57 | return loader 58 | 59 | 60 | def get_embedder(args): 61 | if args.learned_embeddings and args.ptn_embedding == 'ptn': 62 | ptnCloudEmbedder = pointnet.LocalCloudEmbedder(args) 63 | elif 'geof' in args.ver_value: 64 | ptnCloudEmbedder = graph_processing.spatialEmbedder(args) 65 | else: 66 | raise NameError('Do not know model ' + str(args.learned_embeddings)) 67 | return ptnCloudEmbedder 68 | 69 | 70 | def get_num_classes(args): 71 | # Decide on the dataset 72 | if args.dataset == 's3dis': 73 | dbinfo = graph_processing.get_s3dis_info(args) 74 | elif args.dataset == 'sema3d': 75 | dbinfo = graph_processing.get_sema3d_info(args) 76 | elif args.dataset == 'vkitti': 77 | dbinfo = graph_processing.get_vkitti_info(args) 78 | else: 79 | raise NotImplementedError('Unknown dataset ' + args.dataset) 80 | return dbinfo["classes"] 81 | 82 | 83 | def process(data_tuple, model, output_folder, training_args, overwrite): 84 | fname, edg_source, edg_target, is_transition, labels, objects, clouds_data, xyz = data_tuple 85 | spg_file = os.path.join(output_folder, fname[0]) 86 | logging.info("\nGenerating SPG file %s...", spg_file) 87 | if os.path.exists(spg_file) and not overwrite: 88 | logging.info("Already exists, skipping") 89 | return 90 | elif not os.path.exists(os.path.dirname(spg_file)): 91 | os.makedirs(os.path.dirname(spg_file)) 92 | 93 | if training_args.cuda: 94 | is_transition = is_transition.to('cuda', non_blocking=True) 95 | objects = objects.to('cuda', non_blocking=True) 96 | clouds, clouds_global, nei = clouds_data 97 | clouds_data = (clouds.to('cuda', non_blocking=True), clouds_global.to('cuda', non_blocking=True), nei) 98 | 99 | ptnCloudEmbedder = get_embedder(training_args) 100
| num_classes = get_num_classes(training_args) 101 | 102 | embeddings = ptnCloudEmbedder.run_batch(model, *clouds_data, xyz) 103 | 104 | diff = losses.compute_dist(embeddings, edg_source, edg_target, training_args.dist_type) 105 | 106 | pred_components, pred_in_component = losses.compute_partition( 107 | training_args, embeddings, edg_source, edg_target, diff, xyz) 108 | 109 | graph_sp = graphs.compute_sp_graph(xyz, 100, pred_in_component, pred_components, labels, num_classes) 110 | 111 | provider.write_spg(spg_file, graph_sp, pred_components, pred_in_component) 112 | 113 | 114 | def main(): 115 | logging.getLogger().setLevel(logging.INFO) # set to logging.DEBUG to allow for more prints 116 | args = parse_args() 117 | model, training_args = load_model(args.modeldir, args.cuda) 118 | dataloader = get_dataloader(args.input_folder, training_args) 119 | workers = max(args.nworkers, 1) 120 | 121 | output_folder = args.output_folder 122 | if not output_folder: 123 | # By default assumes that it follows the S3DIS folder structure 124 | output_folder = os.path.join(args.input_folder, '../..', supervized_partition.FolderHierachy.SPG_FOLDER) 125 | if not os.path.exists(output_folder): 126 | os.makedirs(output_folder) 127 | 128 | if logging.getLogger().getEffectiveLevel() > logging.DEBUG: 129 | dataloader = tqdm.tqdm(dataloader, ncols=100) 130 | with torch.no_grad(): 131 | processing_function = functools.partial( 132 | process, model=model, output_folder=output_folder, training_args=training_args, overwrite=args.overwrite) 133 | with Pool(workers) as p: 134 | p.map(processing_function, dataloader) 135 | 136 | logging.info("DONE for %s" % args.input_folder) 137 | 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /supervized_partition/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 
Created on Wed Sep 26 13:56:33 2018 5 | 6 | @author: landrieuloic 7 | 8 | """ 9 | import os 10 | import sys 11 | import math 12 | import numpy as np 13 | import torch 14 | 15 | DIR_PATH = os.path.dirname(os.path.realpath(__file__)) 16 | sys.path.insert(0, os.path.join(DIR_PATH, '..')) 17 | sys.path.append(os.path.join(DIR_PATH,"../partition/cut-pursuit/src")) 18 | 19 | from partition.provider import * 20 | from partition.ply_c import libply_c 21 | 22 | import libcp 23 | 24 | def zhang(x, lam, dist_type): 25 | if dist_type == 'euclidian' or dist_type == 'scalar': 26 | beta = 1 27 | elif dist_type == 'intrinsic': 28 | beta = 1.0471975512 29 | return torch.clamp(-lam * x + lam * beta, min = 0) 30 | 31 | def compute_dist(embeddings, edg_source, edg_target, dist_type): 32 | if dist_type == 'euclidian': 33 | dist = ((embeddings[edg_source,:] - embeddings[edg_target,:])**2).sum(1) 34 | elif dist_type == 'intrinsic': 35 | smoothness = 0.999 36 | dist = (torch.acos((embeddings[edg_source,:] * embeddings[edg_target,:]).sum(1) * smoothness)-np.arccos(smoothness)) \ 37 | / (np.arccos(-smoothness)-np.arccos(smoothness)) * 3.141592 38 | elif dist_type == 'scalar': 39 | dist = (embeddings[edg_source,:] * embeddings[edg_target,:]).sum(1)-1 40 | else: 41 | raise ValueError(" %s is an unknown argument of parameter --dist_type" % (dist_type)) 42 | return dist 43 | 44 | def compute_loss(args, diff, is_transition, weights_loss): 45 | intra_edg = is_transition==0 46 | if 'tv' in args.loss: 47 | loss1 = (weights_loss[intra_edg] * (torch.sqrt(diff[intra_edg]+1e-10))).sum() 48 | elif 'laplacian' in args.loss: 49 | loss1 = (weights_loss[intra_edg] * (diff[intra_edg])).sum() 50 | elif 'TVH' in args.loss: 51 | delta = 0.2 52 | loss1 = delta * (weights_loss[intra_edg] * (torch.sqrt(1+diff[intra_edg]/delta**2)-1)).sum() 53 | else: 54 | raise ValueError(" %s is an unknown argument of parameter --loss" % (args.loss)) 55 | 56 | inter_edg = is_transition==1 57 | 58 | if 'zhang' in args.loss: 59 | 
loss2 = (zhang(torch.sqrt(diff[inter_edg]+1e-10), weights_loss[inter_edg], args.dist_type)).sum() 60 | elif 'TVminus' in args.loss: 61 | loss2 = (torch.sqrt(diff[inter_edg]+1e-10) * weights_loss[inter_edg]).sum() 62 | 63 | #return loss1/ weights_loss.sum(), loss2/ weights_loss.sum() 64 | return loss1, loss2 65 | 66 | 67 | def compute_partition(args, embeddings, edg_source, edg_target, diff, xyz=0): 68 | edge_weight = np.ones_like(edg_source).astype('f4') 69 | if args.edge_weight_threshold>0: 70 | edge_weight[diff>1]=args.edge_weight_threshold 71 | if args.edge_weight_threshold<0: 72 | edge_weight = torch.exp(diff * args.edge_weight_threshold).detach().cpu().numpy()/np.exp(args.edge_weight_threshold) 73 | 74 | ver_value = np.zeros((embeddings.shape[0],0), dtype='f4') 75 | use_spatial = 0 76 | ver_value = np.hstack((ver_value,embeddings.detach().cpu().numpy())) 77 | if args.spatial_emb>0: 78 | ver_value = np.hstack((ver_value, args.spatial_emb * xyz))# * math.sqrt(args.reg_strength))) 79 | #ver_value = xyz * args.spatial_emb 80 | use_spatial = 1#!!! 
81 | 82 | pred_components, pred_in_component = libcp.cutpursuit(ver_value, \ 83 | edg_source.astype('uint32'), edg_target.astype('uint32'), edge_weight, \ 84 | args.reg_strength / (4 * args.k_nn_adj), cutoff=args.CP_cutoff, spatial = use_spatial, weight_decay = 0.7) 85 | #emb2 = libcp.cutpursuit2(ver_value, edg_source.astype('uint32'), edg_target.astype('uint32'), edge_weight, args.reg_strength, cutoff=0, spatial =0) 86 | #emb2 = emb2.reshape(ver_value.shape) 87 | #((ver_value-emb2)**2).sum(0) 88 | #cut = pred_in_component[edg_source]!=pred_in_component[edg_target] 89 | return pred_components, pred_in_component 90 | 91 | def compute_weight_loss(args, embeddings, objects, edg_source, edg_target, is_transition, diff, return_partition, xyz=0): 92 | 93 | if args.loss_weight == 'seal' or args.loss_weight == 'crosspartition' or return_partition: 94 | pred_components, pred_in_component = compute_partition(args, embeddings, edg_source, edg_target, diff, xyz) 95 | 96 | if args.loss_weight=='none': 97 | weights_loss = np.ones_like(edg_target).astype('f4') 98 | elif args.loss_weight=='proportional': 99 | weights_loss = np.ones_like(edg_target).astype('f4') * float(len(is_transition)) / (1-is_transition).sum().float() 100 | weights_loss[is_transition.nonzero()] = float(len(is_transition)) / float(is_transition.sum()) * args.transition_factor 101 | weights_loss = weights_loss.cpu().numpy() 102 | elif args.loss_weight=='seal': 103 | weights_loss = compute_weights_SEAL(pred_components, pred_in_component, objects, edg_source, edg_target, is_transition, args.transition_factor) 104 | elif args.loss_weight=='crosspartition': 105 | weights_loss = compute_weights_XPART(pred_components, pred_in_component, objects.cpu().numpy(), edg_source, edg_target, is_transition.cpu().numpy(), args.transition_factor * 2 * args.k_nn_adj, xyz) 106 | else: 107 | raise ValueError(" %s is an unknown argument of parameter --loss_weight" % (args.loss_weight)) 108 | 109 | if args.cuda: 110 | weights_loss =
torch.from_numpy(weights_loss).cuda() 111 | else: 112 | weights_loss = torch.from_numpy(weights_loss) 113 | 114 | if return_partition: 115 | return weights_loss, pred_components, pred_in_component 116 | else: 117 | return weights_loss 118 | 119 | def compute_weights_SEAL(pred_components, pred_in_component, objects, edg_source, edg_target, is_transition, transition_factor): 120 | 121 | SEAL_weights = np.ones((len(edg_source),), dtype='float32') 122 | w_per_component = np.empty((len(pred_components),), dtype='uint32') 123 | for i_com in range(len(pred_components)): 124 | w_per_component[i_com] = len(pred_components[i_com]) - mode(objects[pred_components[i_com]], only_frequency=True) 125 | SEAL_weights[is_transition.nonzero()] += np.stack(\ 126 | (w_per_component[pred_in_component[edg_source[is_transition.nonzero()]]] 127 | , w_per_component[pred_in_component[edg_target[is_transition.nonzero()]]])).max(0) * transition_factor# 1 if not transition 1+w otherwise 128 | return SEAL_weights 129 | 130 | def compute_weights_XPART(pred_components, pred_in_component, objects, edg_source, edg_target, is_transition, transition_factor, xyz): 131 | 132 | SEAGL_weights = np.ones((len(edg_source),), dtype='float32') 133 | pred_transition = pred_in_component[edg_source]!=pred_in_component[edg_target] 134 | components_x, in_component_x = libply_c.connected_comp(pred_in_component.shape[0] \ 135 | , edg_source.astype('uint32'), edg_target.astype('uint32') \ 136 | , (is_transition+pred_transition==0).astype('uint8'), 0) 137 | 138 | edg_transition = is_transition.nonzero()[0] 139 | edg_source_trans = edg_source[edg_transition] 140 | edg_target_trans = edg_target[edg_transition] 141 | 142 | comp_x_weight = [len(c) for c in components_x] 143 | n_compx = len(components_x) 144 | 145 | edg_id = np.min((in_component_x[edg_source_trans],in_component_x[edg_target_trans]),0) * n_compx \ 146 | + np.max((in_component_x[edg_source_trans],in_component_x[edg_target_trans]),0) 147 | 148 | edg_id_unique , 
in_edge_id, sedg_weight = np.unique(edg_id, return_index=True, return_counts=True) 149 | 150 | for i_edg in range(len(in_edge_id)): 151 | i_com_1 = in_component_x[edg_source_trans[in_edge_id[i_edg]]] 152 | i_com_2 = in_component_x[edg_target_trans[in_edge_id[i_edg]]] 153 | weight = min(comp_x_weight[i_com_1], comp_x_weight[i_com_2]) \ 154 | / sedg_weight[i_edg] * transition_factor 155 | corresponding_trans_edg = edg_transition[\ 156 | ((in_component_x[edg_source_trans]==i_com_1) * (in_component_x[edg_target_trans]==i_com_2) \ 157 | + (in_component_x[edg_target_trans]==i_com_1) * (in_component_x[edg_source_trans]==i_com_2))] 158 | SEAGL_weights[corresponding_trans_edg] = SEAGL_weights[corresponding_trans_edg] + weight 159 | 160 | #missed_transition = ((is_transition==1)*(pred_transition==False)+(is_transition==0)*(pred_transition==True)).nonzero()[0] 161 | #missed_transition = ((is_transition==1)*(pred_transition==False)).nonzero()[0] 162 | #SEAGL_weights[missed_transition] = SEAGL_weights[missed_transition] * boosting_factor 163 | #scalar2ply('full_par.ply', xyz,pred_in_component) 164 | #scalar2ply('full_parX.ply', xyz,in_component_x) 165 | #edge_weight2ply2('w.ply', SEAGL_weights, xyz, edg_source, edg_target) 166 | return SEAGL_weights 167 | 168 | def mode(array, only_frequency=False): 169 | """compute the mode and the corresponding frequency of a given distribution""" 170 | u, counts = np.unique(array, return_counts=True) 171 | if only_frequency: return np.amax(counts) 172 | else: 173 | return u[np.argmax(counts)], np.amax(counts) 174 | 175 | def relax_edge_binary(edg_binary, edg_source, edg_target, n_ver, tolerance): 176 | if torch.is_tensor(edg_binary): 177 | relaxed_binary = edg_binary.cpu().numpy().copy() 178 | else: 179 | relaxed_binary = edg_binary.copy() 180 | transition_vertex = np.full((n_ver,), 0, dtype = 'uint8') 181 | for i_tolerance in range(tolerance): 182 | transition_vertex[edg_source[relaxed_binary.nonzero()]] = True 183 | 
transition_vertex[edg_target[relaxed_binary.nonzero()]] = True 184 | relaxed_binary[transition_vertex[edg_source]>0] = True 185 | relaxed_binary[transition_vertex[edg_target]>0] = True 186 | return relaxed_binary 187 | 188 | -------------------------------------------------------------------------------- /vKITTI3D.md: -------------------------------------------------------------------------------- 1 | # vKITTI3D 2 | 3 | Download all point clouds and labels from [the vKITTI3D Dataset](https://github.com/VisualComputingInstitute/vkitti3D-dataset) and place the extracted training files in `$VKITTI3D_DIR/data/`. 4 | 5 | ## Handcrafted Partition 6 | 7 | Not available for this dataset: the sparsity of the acquisition renders the handcrafted geometric features useless. 8 | 9 | ## Learned Partition 10 | 11 | For the learned partition, run: 12 | ``` 13 | python supervized_partition/graph_processing.py --ROOT_PATH $VKITTI3D_DIR --dataset vkitti --voxel_width 0.05 --use_voronoi 1 14 | 15 | for FOLD in 1 2 3 4 5 6; do 16 | python ./supervized_partition/supervized_partition.py --ROOT_PATH $VKITTI3D_DIR --dataset vkitti \ 17 | --epochs 50 --test_nth_epoch 10 --cvfold $FOLD --reg_strength 0.5 --spatial_emb 0.02 --batch_size 15 \ 18 | --global_feat exyrgb --CP_cutoff 10 --odir results_partition/vkitti/best; \ 19 | done; 20 | ``` 21 | or use our [trained weights](http://recherche.ign.fr/llandrieu/SPG/vkitti/results_part/pretrained.zip) and the `--resume RESUME` argument: 22 | 23 | ``` 24 | python supervized_partition/graph_processing.py --ROOT_PATH $VKITTI3D_DIR --dataset vkitti --voxel_width 0.05 --use_voronoi 1 25 | 26 | for FOLD in 1 2 3 4 5 6; do 27 | python ./supervized_partition/supervized_partition.py --ROOT_PATH $VKITTI3D_DIR --dataset vkitti \ 28 | --epochs -1 --test_nth_epoch 10 --cvfold $FOLD --reg_strength 0.5 --spatial_emb 0.02 --batch_size 15 \ 29 | --global_feat exyrgb --CP_cutoff 10 --odir results_partition/vkitti/pretrained --resume RESUME; \ 30 | done; 31 | ``` 32 | 33 | To
evaluate the quality of the partition, run: 34 | ``` 35 | python supervized_partition/evaluate_partition.py --dataset vkitti --cvfold 123456 --folder best 36 | ``` 37 | ### Training 38 | 39 | Then, reorganize point clouds into superpoints by: 40 | ``` 41 | python learning/vkitti_dataset.py --VKITTI_PATH $VKITTI3D_DIR 42 | ``` 43 | 44 | To train from scratch, run: 45 | ``` 46 | for FOLD in 1 2 3 4 5 6; do \ 47 | CUDA_VISIBLE_DEVICES=0 python ./learning/main.py --dataset vkitti --VKITTI_PATH $VKITTI3D_DIR --cvfold $FOLD --epochs 100 \ 48 | --lr_steps "[40, 50, 60, 70, 80]" --test_nth_epoch 10 --model_config gru_10_1_1_1_0,f_13 --pc_attribs xyzXYZrgb \ 49 | --ptn_nfeat_stn 9 --batch_size 4 --ptn_minpts 15 --spg_augm_order 3 --spg_augm_hardcutoff 256 \ 50 | --ptn_widths "[[64,64,128], [64,32,32]]" --ptn_widths_stn "[[32,64], [32,16]]" --loss_weights sqrt \ 51 | --use_val_set 1 --odir results/vkitti/best/cv$FOLD; \ 52 | done;\ 53 | ``` 54 | 55 | or use our [trained weights](http://recherche.ign.fr/llandrieu/SPG/vkitti/results/pretrained.zip) with the `--resume RESUME` argument: 56 | ``` 57 | for FOLD in 1 2 3 4 5 6; do \ 58 | CUDA_VISIBLE_DEVICES=0 python ./learning/main.py --dataset vkitti --VKITTI_PATH $VKITTI3D_DIR --cvfold $FOLD --epochs 100 \ 59 | --lr_steps "[40, 50, 60, 70, 80]" --test_nth_epoch 10 --model_config gru_10_1_1_1_0,f_13 --pc_attribs xyzXYZrgb \ 60 | --ptn_nfeat_stn 9 --batch_size 4 --ptn_minpts 15 --spg_augm_order 3 --spg_augm_hardcutoff 256 \ 61 | --ptn_widths "[[64,64,128], [64,32,32]]" --ptn_widths_stn "[[32,64], [32,16]]" --loss_weights sqrt \ 62 | --use_val_set 1 --odir results/vkitti/pretrained/cv$FOLD --resume RESUME; \ 63 | done;\ 64 | ``` 65 | 66 | Estimate the quality of the semantic segmentation with: 67 | ``` 68 | python learning/evaluate.py --dataset vkitti --odir results/vkitti/best --cvfold 123456 69 | ``` 70 | #### Visualization 71 | 72 | To visualize the results and intermediary steps (on the subsampled graph), use the visualize 
script in `partition`. For example: 73 | ``` 74 | python partition/visualize.py --dataset vkitti --ROOT_PATH $VKITTI3D_DIR --res_file 'results/vkitti/best/cv1/predictions_test' --file_path '01/0001_00000' --output_type ifprs 75 | ``` 76 | --------------------------------------------------------------------------------