├── LICENSE ├── README.md ├── config ├── parameters(PCPNet-Semantic).yaml └── parameters.yaml ├── figs ├── motivation.png ├── overall_architecture.png └── predictions.gif ├── pcpnet ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── datasets.cpython-38.pyc │ └── datasets.py ├── models │ ├── PCPNet.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── PCPNet.cpython-38.pyc │ │ ├── PPT.cpython-38.pyc │ │ ├── TCNet.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── base.cpython-38.pyc │ │ ├── layers.cpython-38.pyc │ │ ├── loss.cpython-38.pyc │ │ └── transformer.cpython-38.pyc │ ├── base.py │ ├── layers.py │ └── loss.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── logger.cpython-38.pyc │ ├── preprocess_data.cpython-38.pyc │ ├── projection.cpython-38.pyc │ ├── utils.cpython-38.pyc │ └── visualization.cpython-38.pyc │ ├── logger.py │ ├── preprocess_data.py │ ├── projection.py │ ├── utils.py │ └── visualization.py ├── poetry.lock ├── pyTorchChamferDistance ├── LICENSE.md ├── README.md └── chamfer_distance │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── chamfer_distance.cpython-38.pyc │ ├── chamfer_distance.cpp │ ├── chamfer_distance.cu │ └── chamfer_distance.py ├── pyproject.toml ├── semantic_net └── rangenet │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ └── __init__.cpython-38.pyc │ ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── squeezeseg.cpython-38.pyc │ │ └── squeezesegV2.cpython-38.pyc │ ├── darknet.py │ ├── squeezeseg.py │ └── squeezesegV2.py │ ├── common │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── laserscan.cpython-38.pyc │ └── laserscan.py │ ├── requirements.txt │ └── tasks │ ├── __init__.py │ ├── __pycache__ │ └── __init__.cpython-38.pyc │ └── semantic │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ └── __init__.cpython-38.pyc │ ├── config │ └── labels │ │ ├── semantic-kitti-all.yaml │ │ └── semantic-kitti.yaml │ ├── decoders │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── squeezeseg.cpython-38.pyc │ │ └── squeezesegV2.cpython-38.pyc │ ├── darknet.py │ ├── squeezeseg.py │ └── squeezesegV2.py │ ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── segmentator.cpython-38.pyc │ └── segmentator.py │ ├── postproc │ ├── CRF.py │ ├── KNN.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── CRF.cpython-38.pyc │ │ └── __init__.cpython-38.pyc │ └── borderMask.py │ └── readme.md ├── test.py ├── train.py └── visualize.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zhen Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PCPNet: An Efficient and Semantic-Enhanced Transformer Network for Point Cloud Prediction 2 | 3 | Accepted by IEEE RA-L. 4 | 5 | Developed by [Zhen Luo](https://github.com/Blurryface0814) and [Junyi Ma](https://github.com/BIT-MJY). 6 | 7 | 8 | 9 | *PCPNet predicts future range images based on past range image sequences. The semantic information of sequential range images is extracted for auxiliary training, making the outputs of PCPNet closer to the ground truth in semantics.* 10 | 11 | 12 | 13 | *Ground truth and future range images predicted by PCPNet.* 14 | 15 | ## Contents 16 | 1. [Publication](#Publication) 17 | 2. [Dataset](#Dataset) 18 | 3. [Installation](#Installation) 19 | 4. [Training](#Training) 20 | 5. [Semantic-Auxiliary-Training](#Semantic-Auxiliary-Training) 21 | 6. [Testing](#Testing) 22 | 7. [Visualization](#Visualization) 23 | 8. [Download](#Download) 24 | 9. [Acknowledgment](#Acknowledgment) 25 | 10. [License](#License) 26 | 27 | ![](figs/overall_architecture.png) 28 | *Overall architecture of our proposed PCPNet. The input range images are first downsampled and compressed along the height and width dimensions respectively to generate the sentence-like features for the following Transformer blocks. The enhanced features are then combined and upsampled to the predicted range images and mask images. Semantic auxiliary training is used to improve the practical value of predicted point clouds.* 29 | 30 | ## Publication 31 | If you use our code in your academic work, please cite the corresponding [paper](https://ieeexplore.ieee.org/abstract/document/10141631?casa_token=VCXSYRIkT88AAAAA:-LLz-ZSNJVSLCYSjXjzpV_DrwtBggRvOKW_1dWxUDNa1IE4VzREdHovp-PyD1zA9rVlRZblXpQu1qfQ): 32 | 33 | ```latex 34 | @ARTICLE{10141631, 35 | author={Luo, Zhen and Ma, Junyi and Zhou, Zijie and Xiong, Guangming}, 36 | journal={IEEE Robotics and Automation Letters}, 37 | title={PCPNet: An Efficient and Semantic-Enhanced Transformer Network for Point Cloud Prediction}, 38 | year={2023}, 39 | volume={8}, 40 | number={7}, 41 | pages={4267-4274}, 42 | doi={10.1109/LRA.2023.3281937}} 43 | ``` 44 | 45 | ## Dataset 46 | We use the KITTI Odometry dataset to train and evaluate PCPNet in this repo, which you can download [here](http://www.cvlibs.net/datasets/kitti/eval_odometry.php). 47 | 48 | 49 | In addition, we use SemanticKITTI for semantic auxiliary training, which you can download [here](http://semantic-kitti.org/dataset.html#download). 50 | 51 | ## Installation 52 | 53 | ### Source Code 54 | Clone this repository and run 55 | ```bash 56 | cd PCPNet 57 | git submodule update --init 58 | ``` 59 | to install the Chamfer distance submodule. The Chamfer distance submodule is originally taken from [here](https://github.com/chrdiller/pyTorchChamferDistance) with some modifications to use it as a submodule.
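For convenience, the clone and the submodule setup can also be combined into a single step. This is just an equivalent sketch of the commands above, assuming the repository URL given in the source file headers:

```bash
# Clone PCPNet and fetch the Chamfer distance submodule in one step
git clone --recurse-submodules https://github.com/Blurryface0814/PCPNet.git
cd PCPNet

# If the repository was already cloned without submodules, fetch them afterwards
git submodule update --init
```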
60 | 61 | All parameters are stored in ```config/parameters.yaml```. 62 | 63 | ### Dependencies 64 | In this project, we use CUDA 11.4, PyTorch 1.8.0, and PyTorch Lightning 1.5.0. All other dependencies are managed with Python Poetry and can be found in the ```poetry.lock``` file. If you want to use Python Poetry, install it with: 65 | ```bash 66 | curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python - 67 | ``` 68 | 69 | Install the Python dependencies with Python Poetry 70 | ```bash 71 | poetry install 72 | ``` 73 | 74 | and activate the virtual environment in the shell with 75 | ```bash 76 | poetry shell 77 | ``` 78 | 79 | ## Training 80 | We process the data in advance to speed up training. The preprocessing is done automatically if ```GENERATE_FILES``` is set to True in ```config/parameters.yaml```. 81 | 82 | If you have not pre-processed the data yet, please set ```GENERATE_FILES: True``` in ```config/parameters.yaml```, and run the training script by 83 | ```bash 84 | python train.py --rawdata /PATH/TO/RAW/KITTI/dataset/sequences --dataset /PATH/TO/PROCESSED/dataset/ 85 | ``` 86 | in which ```--rawdata``` points to the directory containing the raw train/val/test sequences specified in the config file and ```--dataset``` points to the directory containing the processed train/val/test sequences. 87 | 88 | If you have already pre-processed the data, please set ```GENERATE_FILES: False``` to skip this step, and run the training script by 89 | ```bash 90 | python train.py --dataset /PATH/TO/PROCESSED/dataset/ 91 | ``` 92 | 93 | To resume from a checkpoint, please run the training script by 94 | ```bash 95 | python train.py --dataset /PATH/TO/PROCESSED/dataset/ --resume /PATH/TO/YOUR/MODEL/ 96 | ``` 97 | You can also use the flag ```--weights``` to initialize the weights from a pre-trained model. Pass the flag ```--help``` if you want to see more options. 98 | 99 | A directory will be created in ```runs``` that stores everything, including the model files, the config used, logs, and checkpoints. 100 | 101 | ## Semantic-Auxiliary-Training 102 | If you want to use our proposed semantic auxiliary training strategy, you need to first pre-train a semantic segmentation model. We provide code for semantic auxiliary training using RangeNet++ in ```semantic_net/rangenet```. To use it, please first clone the [official code](https://github.com/PRBonn/lidar-bonnetal) of RangeNet++ and train a semantic segmentation model. 103 | 104 | *Note that, to match the code we currently provide, we recommend using the squeezesegV2 backbone without CRF and using only ```range``` in the ```input_depth``` option while training RangeNet++.* If you want to use other backbones, please make corresponding modifications to ```class loss_semantic``` in ```pcpnet/models/loss.py```. 105 | 106 | Once you have completed the pre-training, copy the folder containing the pre-trained model to ```semantic_net/rangenet/model/``` and set ```SEMANTIC_PRETRAINED_MODEL``` in ```config/parameters.yaml``` to the folder name. 107 | 108 | After completing the above steps, you can start semantic auxiliary training by running the training script 109 | ```bash 110 | python train.py --dataset /PATH/TO/PROCESSED/dataset/ 111 | ``` 112 | *Note* that you need to set ```LOSS_WEIGHT_SEMANTIC``` in ```config/parameters.yaml``` to the weight you want (we recommend 1.0) instead of 0.0 before running the training script, as shown in the snippet below.
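For reference, the semantic auxiliary training settings in ```config/parameters.yaml``` then look roughly as follows. The values simply mirror the provided ```config/parameters(PCPNet-Semantic).yaml```; ```squeezesegV2``` is only an example folder name and should match whatever you copied into ```semantic_net/rangenet/model/```:

```yaml
MODEL:
  SEMANTIC_NET: rangenet                     # Semantic segmentation network used for auxiliary training
  SEMANTIC_DATA_CONFIG: semantic-kitti.yaml  # Label config used by the pre-trained segmentation model
  SEMANTIC_PRETRAINED_MODEL: squeezesegV2    # Folder name under semantic_net/rangenet/model/

TRAIN:
  LOSS_WEIGHT_SEMANTIC: 1.0                  # 0.0 disables semantic auxiliary training
```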
113 | 114 | ## Testing 115 | You can evaluate the performance of the model by running 116 | ```bash 117 | python test.py --dataset /PATH/TO/PROCESSED/dataset/ --model /PATH/TO/YOUR/MODEL/ 118 | ``` 119 | *Note*: Please use the flag ```-s``` if you want to save the predicted point clouds for visualization, and ```-l``` if you want to test the model on a smaller amount of data. By using the flag ```-o```, you can save only the predicted point clouds without computing the loss, which speeds up saving. 120 | 121 | ## Visualization 122 | After passing the ```-s``` flag or the ```-o``` flag to the testing script, the predicted range images will be saved as .png files in ```runs/MODEL_NAME/test_TIME/range_view_predictions```. The predicted point clouds are saved to ```runs/MODEL_NAME/test_TIME/point_clouds```. You can visualize the predicted point clouds by running 123 | ```bash 124 | python visualize.py --path runs/MODEL_NAME/test_TIME/point_clouds 125 | ``` 126 | Please download the car model from [here](https://drive.google.com/drive/folders/1bmBMdfJaN2ptJVh7gHv1Gy8L1aLMOslr?usp=share_link) and put it into ```./car_model/``` to display the car during the visualization process. 127 | 128 | 129 | ## Download 130 | You can download our pre-trained model from this [link](https://drive.google.com/drive/folders/1p9q_SoXOsigi8vB_bXxWNZbU-ypSm19J?usp=share_link). Just extract the zip file into ```runs```. 131 | 132 | ## Acknowledgment 133 | We would like to thank Benedikt Mersch, Andres Milioto, and Christian Diller et al. for their significant contributions to the field of point cloud processing. Some of the code in this repo is borrowed from [TCNet](https://github.com/PRBonn/point-cloud-prediction), [RangeNet++](https://github.com/PRBonn/lidar-bonnetal), and [pyTorchChamferDistance](https://github.com/chrdiller/pyTorchChamferDistance). We thank all the authors for their awesome projects. 134 | 135 | ## License 136 | This project is free software made available under the MIT License. For details, see the LICENSE file. 137 | -------------------------------------------------------------------------------- /config/parameters(PCPNet-Semantic).yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | ID: PCPNet-Semantic # Give your experiment a unique ID which is used in the log 3 | 4 | DATA_CONFIG: 5 | DATASET_NAME: KITTIOdometry 6 | GENERATE_FILES: False # If true, the data will be pre-processed 7 | COMPUTE_MEAN_AND_STD: False # If true, the mean and std of the training data will be computed to use it in MEAN and STD. 8 | RANDOM_SEED: 1 # Set random seed for torch, numpy and python 9 | DATALOADER: 10 | NUM_WORKER: 4 11 | SHUFFLE: True 12 | SPLIT: # Specify the sequences here. We follow the KITTI format.
13 | TRAIN: 14 | - 0 15 | - 1 16 | - 2 17 | - 3 18 | - 4 19 | - 5 20 | VAL: 21 | - 6 22 | - 7 23 | TEST: 24 | - 8 25 | - 9 26 | - 10 27 | 28 | HEIGHT: 64 # Height of range images 29 | WIDTH: 2048 # Width of range images 30 | FOV_UP: 3.0 # Depends on the used LiDAR sensor 31 | FOV_DOWN: -25.0 # Depends on the used LiDAR sensor 32 | MAX_RANGE: 85.0 # Average max value in training set 33 | MIN_RANGE: 1.0 # Average min value in training set 34 | MEAN: 35 | - 10.839 # Range 36 | - 0.005 # X 37 | - 0.494 # Y 38 | - -1.13 # Z 39 | - 0.287 # Intensity 40 | 41 | STD: 42 | - 9.314 # Range 43 | - 11.521 # X 44 | - 8.262 # Y 45 | - 0.828 # Z 46 | - 0.14 # Intensity 47 | 48 | MODEL: 49 | N_PAST_STEPS: 5 # Number of input range images 50 | N_FUTURE_STEPS: 5 # Number of predicted future range images 51 | MASK_THRESHOLD: 0.5 # Threshold for valid point mask classification 52 | USE: 53 | XYZ: False # If true: x,y, and z coordinates will be used as additional input channels 54 | INTENSITY: False # If true: intensity will be used as additional input channel 55 | CHANNELS: # Number of channels in pre-encoder and post-decoder, respectively. 56 | - 16 57 | - 32 58 | - 64 59 | SKIP_IF_CHANNEL_SIZE: # Adds a skip connection between pre-encoder and post-decoder at desired channels 60 | - 32 61 | - 64 62 | TRANSFORMER_H_LAYERS: 4 # Number of layers in transfomer_H 63 | TRANSFORMER_W_LAYERS: 4 # Number of layers in transfomer_W 64 | CIRCULAR_PADDING: True 65 | NORM: batch # batch, group, none, instance 66 | N_CHANNELS_PER_GROUP: 16 67 | SEMANTIC_NET: rangenet 68 | SEMANTIC_DATA_CONFIG: semantic-kitti.yaml 69 | SEMANTIC_PRETRAINED_MODEL: squeezesegV2 70 | 71 | TRAIN: 72 | LR: 0.001 73 | LR_EPOCH: 1 74 | LR_DECAY: 0.99 75 | MAX_EPOCH: 50 76 | BATCH_SIZE: 2 77 | BATCH_ACC: 4 78 | N_GPUS: 1 79 | LOG_EVERY_N_STEPS: 10 80 | LOSS_WEIGHT_CHAMFER_DISTANCE: 0.0 81 | LOSS_WEIGHT_RANGE_VIEW: 1.0 82 | LOSS_WEIGHT_MASK: 1.0 83 | LOSS_WEIGHT_SEMANTIC: 1.0 84 | CHAMFER_DISTANCE_EVERY_N_VAL_EPOCH: 5 # Log chamfer distance every N val epoch. 85 | 86 | VALIDATION: 87 | SELECTED_SEQUENCE_AND_FRAME: # Only log point clouds for selected validation sequence and frame. 88 | 6: 89 | - 10 90 | 91 | TEST: 92 | N_BATCHES_TO_SAVE: -1 # If set to -1 and SAVE_POINT_CLOUDS is true, all batches of the test set will be saved. 93 | SAVE_POINT_CLOUDS: False 94 | ONLY_SAVE_POINT_CLOUDS: False # Only save point clouds, not compute loss 95 | SEMANTIC_SIMILARITY: True # compute semantic similarity 96 | N_DOWNSAMPLED_POINTS_CD: -1 # Can evaluate the CD on downsampled point clouds. Set -1 to evaluate on full point clouds. 97 | SELECTED_SEQUENCE_AND_FRAME: # Only log point clouds for selected test sequence and frame. 98 | 8: 99 | - 120 100 | -------------------------------------------------------------------------------- /config/parameters.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | ID: PCPNet # Give your experiment a unique ID which is used in the log 3 | 4 | DATA_CONFIG: 5 | DATASET_NAME: KITTIOdometry 6 | GENERATE_FILES: False # If true, the data will be pre-processed 7 | COMPUTE_MEAN_AND_STD: False # If true, the mean and std of the training data will be computed to use it in MEAN and STD. 8 | RANDOM_SEED: 1 # Set random seed for torch, numpy and python 9 | DATALOADER: 10 | NUM_WORKER: 4 11 | SHUFFLE: True 12 | SPLIT: # Specify the sequences here. We follow the KITTI format. 
13 | TRAIN: 14 | - 0 15 | - 1 16 | - 2 17 | - 3 18 | - 4 19 | - 5 20 | VAL: 21 | - 6 22 | - 7 23 | TEST: 24 | - 8 25 | - 9 26 | - 10 27 | 28 | HEIGHT: 64 # Height of range images 29 | WIDTH: 2048 # Width of range images 30 | FOV_UP: 3.0 # Depends on the used LiDAR sensor 31 | FOV_DOWN: -25.0 # Depends on the used LiDAR sensor 32 | MAX_RANGE: 85.0 # Average max value in training set 33 | MIN_RANGE: 1.0 # Average min value in training set 34 | MEAN: 35 | - 10.839 # Range 36 | - 0.005 # X 37 | - 0.494 # Y 38 | - -1.13 # Z 39 | - 0.287 # Intensity 40 | 41 | STD: 42 | - 9.314 # Range 43 | - 11.521 # X 44 | - 8.262 # Y 45 | - 0.828 # Z 46 | - 0.14 # Intensity 47 | 48 | MODEL: 49 | N_PAST_STEPS: 5 # Number of input range images 50 | N_FUTURE_STEPS: 5 # Number of predicted future range images 51 | MASK_THRESHOLD: 0.5 # Threshold for valid point mask classification 52 | USE: 53 | XYZ: False # If true: x,y, and z coordinates will be used as additional input channels 54 | INTENSITY: False # If true: intensity will be used as additional input channel 55 | CHANNELS: # Number of channels in pre-encoder and post-decoder, respectively. 56 | - 16 57 | - 32 58 | - 64 59 | SKIP_IF_CHANNEL_SIZE: # Adds a skip connection between pre-encoder and post-decoder at desired channels 60 | - 32 61 | - 64 62 | TRANSFORMER_H_LAYERS: 4 # Number of layers in transfomer_H 63 | TRANSFORMER_W_LAYERS: 4 # Number of layers in transfomer_W 64 | CIRCULAR_PADDING: True 65 | NORM: batch # batch, group, none, instance 66 | N_CHANNELS_PER_GROUP: 16 67 | SEMANTIC_NET: rangenet 68 | SEMANTIC_DATA_CONFIG: semantic-kitti.yaml 69 | SEMANTIC_PRETRAINED_MODEL: squeezesegV2 70 | 71 | TRAIN: 72 | LR: 0.001 73 | LR_EPOCH: 1 74 | LR_DECAY: 0.99 75 | MAX_EPOCH: 50 76 | BATCH_SIZE: 2 77 | BATCH_ACC: 4 78 | N_GPUS: 1 79 | LOG_EVERY_N_STEPS: 10 80 | LOSS_WEIGHT_CHAMFER_DISTANCE: 0.0 81 | LOSS_WEIGHT_RANGE_VIEW: 1.0 82 | LOSS_WEIGHT_MASK: 1.0 83 | LOSS_WEIGHT_SEMANTIC: 0.0 84 | CHAMFER_DISTANCE_EVERY_N_VAL_EPOCH: 5 # Log chamfer distance every N val epoch. 85 | 86 | VALIDATION: 87 | SELECTED_SEQUENCE_AND_FRAME: # Only log point clouds for selected validation sequence and frame. 88 | 6: 89 | - 10 90 | 91 | TEST: 92 | N_BATCHES_TO_SAVE: -1 # If set to -1 and SAVE_POINT_CLOUDS is true, all batches of the test set will be saved. 93 | SAVE_POINT_CLOUDS: False 94 | ONLY_SAVE_POINT_CLOUDS: False # Only save point clouds, not compute loss 95 | SEMANTIC_SIMILARITY: False # compute semantic similarity 96 | N_DOWNSAMPLED_POINTS_CD: -1 # Can evaluate the CD on downsampled point clouds. Set -1 to evaluate on full point clouds. 97 | SELECTED_SEQUENCE_AND_FRAME: # Only log point clouds for selected test sequence and frame. 
98 | 8: 99 | - 120 100 | -------------------------------------------------------------------------------- /figs/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/figs/motivation.png -------------------------------------------------------------------------------- /figs/overall_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/figs/overall_architecture.png -------------------------------------------------------------------------------- /figs/predictions.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/figs/predictions.gif -------------------------------------------------------------------------------- /pcpnet/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /pcpnet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/datasets/__init__.py -------------------------------------------------------------------------------- /pcpnet/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/datasets/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/datasets/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Dataset Implementation for KITTI Odometry 6 | import os 7 | import yaml 8 | import torch 9 | import numpy as np 10 | from torch.utils.data import Dataset, DataLoader 11 | from pytorch_lightning import LightningDataModule 12 | 13 | from pcpnet.utils.projection import projection 14 | from pcpnet.utils.preprocess_data import prepare_data, compute_mean_and_std 15 | from pcpnet.utils.utils import load_files 16 | 17 | 18 | class KittiOdometryModule(LightningDataModule): 19 | """A Pytorch Lightning module for KITTI Odometry""" 20 | 21 | def __init__(self, cfg, dataset_path, 
rawdata_path=None): 22 | """Method to initizalize the Kitti Odometry dataset class 23 | 24 | Args: 25 | cfg: config dict 26 | 27 | Returns: 28 | None 29 | """ 30 | super(KittiOdometryModule, self).__init__() 31 | self.cfg = cfg 32 | self.dataset_path = dataset_path 33 | self.rawdata_path = rawdata_path 34 | 35 | def prepare_data(self): 36 | """Call prepare_data method to generate npy range images from raw LiDAR data""" 37 | if self.cfg["DATA_CONFIG"]["GENERATE_FILES"]: 38 | prepare_data(self.cfg, self.dataset_path, self.rawdata_path) 39 | 40 | def setup(self, stage=None): 41 | """Dataloader and iterators for training, validation and test data""" 42 | ########## Point dataset splits 43 | train_set = KittiOdometryRaw(self.cfg, split="train", dataset_path=self.dataset_path) 44 | 45 | val_set = KittiOdometryRaw(self.cfg, split="val", dataset_path=self.dataset_path) 46 | 47 | test_set = KittiOdometryRaw(self.cfg, split="test", dataset_path=self.dataset_path) 48 | 49 | ########## Generate dataloaders and iterables 50 | 51 | self.train_loader = DataLoader( 52 | dataset=train_set, 53 | batch_size=self.cfg["TRAIN"]["BATCH_SIZE"], 54 | shuffle=self.cfg["DATA_CONFIG"]["DATALOADER"]["SHUFFLE"], 55 | num_workers=self.cfg["DATA_CONFIG"]["DATALOADER"]["NUM_WORKER"], 56 | pin_memory=True, 57 | drop_last=False, 58 | timeout=0, 59 | ) 60 | self.train_iter = iter(self.train_loader) 61 | 62 | self.valid_loader = DataLoader( 63 | dataset=val_set, 64 | batch_size=self.cfg["TRAIN"]["BATCH_SIZE"], 65 | shuffle=False, 66 | num_workers=self.cfg["DATA_CONFIG"]["DATALOADER"]["NUM_WORKER"], 67 | pin_memory=True, 68 | drop_last=False, 69 | timeout=0, 70 | ) 71 | self.valid_iter = iter(self.valid_loader) 72 | 73 | self.test_loader = DataLoader( 74 | dataset=test_set, 75 | batch_size=self.cfg["TRAIN"]["BATCH_SIZE"], 76 | shuffle=False, 77 | num_workers=self.cfg["DATA_CONFIG"]["DATALOADER"]["NUM_WORKER"], 78 | pin_memory=True, 79 | drop_last=False, 80 | timeout=0, 81 | ) 82 | self.test_iter = iter(self.test_loader) 83 | 84 | # Optionally compute statistics of training data 85 | if self.cfg["DATA_CONFIG"]["COMPUTE_MEAN_AND_STD"]: 86 | compute_mean_and_std(self.cfg, self.train_loader) 87 | 88 | print( 89 | "Loaded {:d} training, {:d} validation and {:d} test samples.".format( 90 | len(train_set), len(val_set), (len(test_set)) 91 | ) 92 | ) 93 | 94 | def train_dataloader(self): 95 | return self.train_loader 96 | 97 | def val_dataloader(self): 98 | return self.valid_loader 99 | 100 | def test_dataloader(self): 101 | return self.test_loader 102 | 103 | 104 | class KittiOdometryRaw(Dataset): 105 | """Dataset class for range image-based point cloud prediction""" 106 | 107 | def __init__(self, cfg, split, dataset_path=None): 108 | """Read parameters and scan data 109 | 110 | Args: 111 | cfg (dict): Config parameters 112 | split (str): Data split 113 | 114 | Raises: 115 | Exception: [description] 116 | """ 117 | self.cfg = cfg 118 | self.root_dir = dataset_path 119 | self.height = self.cfg["DATA_CONFIG"]["HEIGHT"] 120 | self.width = self.cfg["DATA_CONFIG"]["WIDTH"] 121 | self.n_channels = 5 + 1 122 | 123 | self.n_past_steps = self.cfg["MODEL"]["N_PAST_STEPS"] 124 | self.n_future_steps = self.cfg["MODEL"]["N_FUTURE_STEPS"] 125 | 126 | # Projection class for mapping from range image to 3D point cloud 127 | self.projection = projection(self.cfg) 128 | 129 | if split == "train": 130 | self.sequences = self.cfg["DATA_CONFIG"]["SPLIT"]["TRAIN"] 131 | elif split == "val": 132 | self.sequences = self.cfg["DATA_CONFIG"]["SPLIT"]["VAL"] 133 
| elif split == "test": 134 | self.sequences = self.cfg["DATA_CONFIG"]["SPLIT"]["TEST"] 135 | else: 136 | raise Exception("Split must be train/val/test") 137 | 138 | # Create a dict filenames that maps from a sequence number to a list of files in the dataset 139 | self.filenames_range = {} 140 | self.filenames_xyz = {} 141 | self.filenames_intensity = {} 142 | self.filenames_label = {} 143 | 144 | # Create a dict idx_mapper that maps from a dataset idx to a sequence number and the index of the current scan 145 | self.dataset_size = 0 146 | self.idx_mapper = {} 147 | idx = 0 148 | for seq in self.sequences: 149 | seqstr = "{0:02d}".format(int(seq)) 150 | 151 | scan_path_range = os.path.join(self.root_dir, seqstr, "processed", "range") 152 | self.filenames_range[seq] = load_files(scan_path_range) 153 | 154 | scan_path_xyz = os.path.join(self.root_dir, seqstr, "processed", "xyz") 155 | self.filenames_xyz[seq] = load_files(scan_path_xyz) 156 | assert len(self.filenames_range[seq]) == len(self.filenames_xyz[seq]) 157 | 158 | scan_path_intensity = os.path.join(self.root_dir, seqstr, "processed", "intensity") 159 | self.filenames_intensity[seq] = load_files(scan_path_intensity) 160 | assert len(self.filenames_range[seq]) == len(self.filenames_intensity[seq]) 161 | 162 | scan_path_label = os.path.join(self.root_dir, seqstr, "processed", "labels") 163 | self.filenames_label[seq] = load_files(scan_path_label) 164 | assert len(self.filenames_range[seq]) == len(self.filenames_label[seq]) 165 | 166 | # Get number of sequences based on number of past and future steps 167 | n_samples_sequence = max( 168 | 0, 169 | len(self.filenames_range[seq]) 170 | - self.n_past_steps 171 | - self.n_future_steps 172 | + 1, 173 | ) 174 | 175 | # Add to idx mapping 176 | for sample_idx in range(n_samples_sequence): 177 | scan_idx = self.n_past_steps + sample_idx - 1 178 | self.idx_mapper[idx] = (seq, scan_idx) 179 | idx += 1 180 | self.dataset_size += n_samples_sequence 181 | 182 | def __len__(self): 183 | return self.dataset_size 184 | 185 | def __getitem__(self, idx): 186 | """Load and concatenate range image channels 187 | 188 | Args: 189 | idx (int): Sample index 190 | 191 | Returns: 192 | item: Dataset dictionary item 193 | """ 194 | seq, scan_idx = self.idx_mapper[idx] 195 | 196 | # Load past data 197 | past_data = torch.empty( 198 | [self.n_channels, self.n_past_steps, self.height, self.width] 199 | ) 200 | 201 | from_idx = scan_idx - self.n_past_steps + 1 202 | to_idx = scan_idx 203 | past_filenames_range = self.filenames_range[seq][from_idx : to_idx + 1] 204 | past_filenames_xyz = self.filenames_xyz[seq][from_idx : to_idx + 1] 205 | past_filenames_intensity = self.filenames_intensity[seq][from_idx : to_idx + 1] 206 | past_filenames_label = self.filenames_label[seq][from_idx: to_idx + 1] 207 | 208 | for t in range(self.n_past_steps): 209 | past_data[0, t, :, :] = self.load_range(past_filenames_range[t]) 210 | past_data[1:4, t, :, :] = self.load_xyz(past_filenames_xyz[t]) 211 | past_data[4, t, :, :] = self.load_intensity(past_filenames_intensity[t]) 212 | past_data[5, t, :, :] = self.load_label(past_filenames_label[t]) 213 | 214 | # Load future data 215 | fut_data = torch.empty( 216 | [self.n_channels, self.n_future_steps, self.height, self.width] 217 | ) 218 | 219 | from_idx = scan_idx + 1 220 | to_idx = scan_idx + self.n_future_steps 221 | fut_filenames_range = self.filenames_range[seq][from_idx : to_idx + 1] 222 | fut_filenames_xyz = self.filenames_xyz[seq][from_idx : to_idx + 1] 223 | fut_filenames_intensity = 
self.filenames_intensity[seq][from_idx : to_idx + 1] 224 | fut_filenames_label = self.filenames_label[seq][from_idx: to_idx + 1] 225 | 226 | for t in range(self.n_future_steps): 227 | fut_data[0, t, :, :] = self.load_range(fut_filenames_range[t]) 228 | fut_data[1:4, t, :, :] = self.load_xyz(fut_filenames_xyz[t]) 229 | fut_data[4, t, :, :] = self.load_intensity(fut_filenames_intensity[t]) 230 | fut_data[5, t, :, :] = self.load_label(fut_filenames_label[t]) 231 | 232 | item = {"past_data": past_data, "fut_data": fut_data, "meta": (seq, scan_idx)} 233 | return item 234 | 235 | 236 | def load_range(self, filename): 237 | """Load .npy range image as (1,height,width) tensor""" 238 | rv = torch.Tensor(np.load(filename)).float() 239 | return rv 240 | 241 | def load_xyz(self, filename): 242 | """Load .npy xyz values as (3,height,width) tensor""" 243 | xyz = torch.Tensor(np.load(filename)).float()[:, :, :3] 244 | xyz = xyz.permute(2, 0, 1) 245 | return xyz 246 | 247 | def load_intensity(self, filename): 248 | """Load .npy intensity values as (1,height,width) tensor""" 249 | intensity = torch.Tensor(np.load(filename)).float() 250 | return intensity 251 | 252 | def load_label(self, filename): 253 | """Load .npy label values as (1,height,width) tensor""" 254 | label = torch.Tensor(np.load(filename)).int() 255 | return label 256 | 257 | 258 | if __name__ == "__main__": 259 | config_filename = "./config/parameters.yaml" 260 | cfg = yaml.safe_load(open(config_filename)) 261 | dataset_path = "" 262 | rawdata_path = "" 263 | data = KittiOdometryModule(cfg, dataset_path, rawdata_path) 264 | data.prepare_data() 265 | data.setup() 266 | 267 | item = data.valid_loader.dataset.__getitem__(1) 268 | 269 | def normalize(image): 270 | min = np.min(image) 271 | max = np.max(image) 272 | normalized_image = (image - min) / (max - min) 273 | return normalized_image 274 | 275 | import matplotlib.pyplot as plt 276 | 277 | fig, axs = plt.subplots(6, 1, sharex=True, figsize=(30, 30 * 6 * 64 / 2048)) 278 | 279 | axs[0].imshow(normalize(item["fut_data"][0, 0, :, :].numpy())) 280 | axs[0].set_title("Range") 281 | axs[1].imshow(normalize(item["fut_data"][1, 0, :, :].numpy())) 282 | axs[1].set_title("X") 283 | axs[2].imshow(normalize(item["fut_data"][2, 0, :, :].numpy())) 284 | axs[2].set_title("Y") 285 | axs[3].imshow(normalize(item["fut_data"][3, 0, :, :].numpy())) 286 | axs[3].set_title("Z") 287 | axs[4].imshow(normalize(item["fut_data"][4, 0, :, :].numpy())) 288 | axs[4].set_title("Intensity") 289 | axs[5].imshow(normalize(item["fut_data"][5, 0, :, :].numpy())) 290 | axs[5].set_title("Label") 291 | 292 | plt.show() 293 | -------------------------------------------------------------------------------- /pcpnet/models/PCPNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Main Network Architecture of PCPNet 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import yaml 10 | import math 11 | 12 | from pcpnet.models.base import BasePredictionModel 13 | from pcpnet.models.layers import CustomConv3d, DownBlock, UpBlock, Transformer_W, Transformer_H, eca_block 14 | 15 | 16 | class PCPNet(BasePredictionModel): 17 | def __init__(self, cfg): 18 | """Init all layers needed for range image-based point cloud prediction""" 19 | super().__init__(cfg) 20 
| self.channels = self.cfg["MODEL"]["CHANNELS"] 21 | self.skip_if_channel_size = self.cfg["MODEL"]["SKIP_IF_CHANNEL_SIZE"] 22 | self.circular_padding = self.cfg["MODEL"]["CIRCULAR_PADDING"] 23 | self.transformer_h_layers = self.cfg["MODEL"]["TRANSFORMER_H_LAYERS"] 24 | self.transformer_w_layers = self.cfg["MODEL"]["TRANSFORMER_W_LAYERS"] 25 | 26 | self.input_layer = CustomConv3d( 27 | self.n_inputs, 28 | self.channels[0], 29 | kernel_size=(1, 1, 1), 30 | stride=(1, 1, 1), 31 | bias=True, 32 | circular_padding=self.circular_padding, 33 | ) 34 | 35 | self.DownLayers = nn.ModuleList() 36 | for i in range(len(self.channels) - 1): 37 | if self.channels[i + 1] in self.skip_if_channel_size: 38 | self.DownLayers.append( 39 | DownBlock( 40 | self.cfg, 41 | self.channels[i], 42 | self.channels[i + 1], 43 | down_stride_H=2, 44 | down_stride_W=4, 45 | skip=True, 46 | ) 47 | ) 48 | else: 49 | self.DownLayers.append( 50 | DownBlock( 51 | self.cfg, 52 | self.channels[i], 53 | self.channels[i + 1], 54 | down_stride_H=2, 55 | down_stride_W=4, 56 | skip=False, 57 | ) 58 | ) 59 | 60 | self.Transformer_H = Transformer_H(self.cfg, self.channels[-1], 61 | layers=self.transformer_h_layers, skip=True) 62 | self.Transformer_W = Transformer_W(self.cfg, self.channels[-1], 63 | layers=self.transformer_w_layers, skip=True) 64 | 65 | self.channel_attention = eca_block(2 * self.channels[-1], b=1, gamma=2) 66 | 67 | self.mid_layer = CustomConv3d( 68 | 2 * self.channels[-1], 69 | self.channels[-1], 70 | kernel_size=(1, 1, 1), 71 | stride=(1, 1, 1), 72 | bias=True, 73 | circular_padding=self.circular_padding, 74 | ) 75 | 76 | self.UpLayers = nn.ModuleList() 77 | for i in reversed(range(len(self.channels) - 1)): 78 | if self.channels[i + 1] in self.skip_if_channel_size: 79 | self.UpLayers.append( 80 | UpBlock( 81 | self.cfg, 82 | self.channels[i + 1], 83 | self.channels[i], 84 | up_stride_H=2, 85 | up_stride_W=4, 86 | skip=True, 87 | ) 88 | ) 89 | else: 90 | self.UpLayers.append( 91 | UpBlock( 92 | self.cfg, 93 | self.channels[i + 1], 94 | self.channels[i], 95 | up_stride_H=2, 96 | up_stride_W=4, 97 | skip=False, 98 | ) 99 | ) 100 | 101 | self.n_outputs = 2 102 | self.output_layer = CustomConv3d( 103 | self.channels[0], 104 | self.n_outputs, 105 | kernel_size=(1, 1, 1), 106 | stride=(1, 1, 1), 107 | bias=True, 108 | circular_padding=self.circular_padding, 109 | ) 110 | 111 | def forward(self, x): 112 | """Forward range image-based point cloud prediction 113 | 114 | Args: 115 | x (torch.tensor): Input tensor of concatenated, unnormalize range images 116 | 117 | Returns: 118 | dict: Containing the predicted range tensor and mask logits 119 | """ 120 | # Only select inputs specified in base model 121 | x = x[:, self.inputs, :, :, :] 122 | batch_size, n_inputs, n_past_steps, H, W = x.size() 123 | assert n_inputs == self.n_inputs 124 | 125 | # Get mask of valid points 126 | past_mask = x != -1.0 127 | 128 | # Standardization and set invalid points to zero 129 | mean = self.mean[None, self.inputs, None, None, None] 130 | std = self.std[None, self.inputs, None, None, None] 131 | x = torch.true_divide(x - mean, std) 132 | x = x * past_mask 133 | 134 | skip_list = [] 135 | x = x.view(batch_size, n_inputs, n_past_steps, H, W) # [B, C, T, H, W] 136 | 137 | x = self.input_layer(x) 138 | for layer in self.DownLayers: 139 | x = layer(x) 140 | if layer.skip: 141 | skip_list.append(x.clone()) 142 | 143 | # [B, C, T, H=16, W=128] 144 | x_h = self.Transformer_H(x) 145 | x_w = self.Transformer_W(x) 146 | x = torch.cat((x_h, x_w), dim=1) 147 | 
x = self.channel_attention(x) 148 | x = self.mid_layer(x) 149 | 150 | for layer in self.UpLayers: 151 | if layer.skip: 152 | x = layer(x, skip_list.pop()) 153 | else: 154 | x = layer(x) 155 | 156 | x = self.output_layer(x) 157 | 158 | output = {} 159 | output["rv"] = self.min_range + nn.Sigmoid()(x[:, 0, :, :, :]) * ( 160 | self.max_range - self.min_range 161 | ) 162 | output["mask_logits"] = x[:, 1, :, :, :] 163 | 164 | return output 165 | -------------------------------------------------------------------------------- /pcpnet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__init__.py -------------------------------------------------------------------------------- /pcpnet/models/__pycache__/PCPNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/PCPNet.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/models/__pycache__/PPT.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/PPT.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/models/__pycache__/TCNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/TCNet.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/models/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/models/__pycache__/layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/layers.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/models/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/models/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/models/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Base Lightning Module of PCPNet 6 | import os 7 | import time 8 | import yaml 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from pytorch_lightning.core.lightning import LightningModule 15 | from pcpnet.models.loss import Loss 16 | from pcpnet.utils.projection import projection 17 | from pcpnet.utils.logger import log_point_clouds, save_range_mask_and_semantic, save_point_clouds 18 | 19 | 20 | class BasePredictionModel(LightningModule): 21 | """Pytorch Lightning base model for point cloud prediction""" 22 | 23 | def __init__(self, cfg): 24 | """Init base model 25 | 26 | Args: 27 | cfg (dict): Config parameters 28 | """ 29 | super(BasePredictionModel, self).__init__() 30 | self.cfg = cfg 31 | self.save_hyperparameters(self.cfg) 32 | 33 | self.height = self.cfg["DATA_CONFIG"]["HEIGHT"] 34 | self.width = self.cfg["DATA_CONFIG"]["WIDTH"] 35 | self.min_range = self.cfg["DATA_CONFIG"]["MIN_RANGE"] 36 | self.max_range = self.cfg["DATA_CONFIG"]["MAX_RANGE"] 37 | self.register_buffer("mean", torch.Tensor(self.cfg["DATA_CONFIG"]["MEAN"])) 38 | self.register_buffer("std", torch.Tensor(self.cfg["DATA_CONFIG"]["STD"])) 39 | 40 | self.n_past_steps = self.cfg["MODEL"]["N_PAST_STEPS"] 41 | self.n_future_steps = self.cfg["MODEL"]["N_FUTURE_STEPS"] 42 | self.use_xyz = self.cfg["MODEL"]["USE"]["XYZ"] 43 | self.use_intensity = self.cfg["MODEL"]["USE"]["INTENSITY"] 44 | 45 | # Create list of index used in input 46 | self.inputs = [0] 47 | if self.use_xyz: 48 | self.inputs.append(1) 49 | self.inputs.append(2) 50 | self.inputs.append(3) 51 | if self.use_intensity: 52 | self.inputs.append(4) 53 | self.n_inputs = len(self.inputs) 54 | 55 | # Init loss 56 | self.loss = Loss(self.cfg) 57 | 58 | # Init projection class for re-projcecting from range images to 3D point clouds 59 | self.projection = projection(self.cfg) 60 | 61 | self.chamfer_distances_tensor = torch.zeros(self.n_future_steps, 1) 62 | 63 | def forward(self, x): 64 | pass 65 | 66 | def configure_optimizers(self): 67 | """Optimizers""" 68 | # optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg["TRAIN"]["LR"]) 69 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), 70 | lr=self.cfg["TRAIN"]["LR"]) 71 | scheduler = torch.optim.lr_scheduler.StepLR( 72 | optimizer, 73 | step_size=self.cfg["TRAIN"]["LR_EPOCH"], 74 | gamma=self.cfg["TRAIN"]["LR_DECAY"], 75 | ) 76 | return [optimizer], [scheduler] 77 | 78 | def on_load_checkpoint(self, checkpoint): 79 | # load the semantic weights here 80 | for param_name in self.loss.state_dict(): 81 | checkpoint['state_dict']['loss.' + param_name] = self.loss.state_dict()[param_name] 82 | 83 | def on_save_checkpoint(self, checkpoint): 84 | # pop the semantic weights here 85 | for param_name in self.loss.state_dict(): 86 | del checkpoint['state_dict']['loss.' 
+ param_name] 87 | 88 | def training_step(self, batch, batch_idx): 89 | """Pytorch Lightning training step including logging 90 | 91 | Args: 92 | batch (dict): A dict with a batch of training samples 93 | batch_idx (int): Index of batch in dataset 94 | 95 | Returns: 96 | loss (dict): Multiple loss components 97 | """ 98 | past = batch["past_data"] 99 | future = batch["fut_data"] 100 | output = self.forward(past) 101 | loss = self.loss(output, future, "train", self.current_epoch) 102 | 103 | self.log("train/loss", loss["loss"]) 104 | self.log("train/mean_chamfer_distance", loss["mean_chamfer_distance"]) 105 | self.log("train/final_chamfer_distance", loss["final_chamfer_distance"]) 106 | self.log("train/loss_range_view", loss["loss_range_view"]) 107 | self.log("train/loss_mask", loss["loss_mask"]) 108 | self.log("train/loss_semantic", loss["loss_semantic"]) 109 | 110 | return loss 111 | 112 | def validation_step(self, batch, batch_idx): 113 | """Pytorch Lightning validation step including logging 114 | 115 | Args: 116 | batch (dict): A dict with a batch of validation samples 117 | batch_idx (int): Index of batch in dataset 118 | 119 | Returns: 120 | None 121 | """ 122 | past = batch["past_data"] 123 | future = batch["fut_data"] 124 | output = self.forward(past) 125 | 126 | loss, learning_map_inv, color_map, output_argmax, target_argmax, label, \ 127 | rand_t, similarity_output, similarity_target = self.loss( 128 | output, future, "val", self.current_epoch 129 | ) 130 | 131 | self.log("val/loss", loss["loss"], on_epoch=True) 132 | self.log( 133 | "val/mean_chamfer_distance", loss["mean_chamfer_distance"], on_epoch=True 134 | ) 135 | self.log( 136 | "val/final_chamfer_distance", loss["final_chamfer_distance"], on_epoch=True 137 | ) 138 | self.log("val/loss_range_view", loss["loss_range_view"], on_epoch=True) 139 | self.log("val/loss_mask", loss["loss_mask"], on_epoch=True) 140 | self.log("val/loss_semantic", loss["loss_semantic"], on_epoch=True) 141 | 142 | selected_sequence_and_frame = self.cfg["VALIDATION"]["SELECTED_SEQUENCE_AND_FRAME"] 143 | sequence_batch, frame_batch = batch["meta"] 144 | for sample_idx in range(frame_batch.shape[0]): 145 | sequence = sequence_batch[sample_idx].item() 146 | frame = frame_batch[sample_idx].item() 147 | if sequence in selected_sequence_and_frame.keys(): 148 | if frame in selected_sequence_and_frame[sequence]: 149 | log_point_clouds( 150 | self.logger.experiment, 151 | self.projection, 152 | self.current_epoch, 153 | batch, 154 | output, 155 | sample_idx, 156 | sequence, 157 | frame, 158 | ) 159 | save_range_mask_and_semantic( 160 | self.cfg, 161 | self.projection, 162 | batch, 163 | output, 164 | sample_idx, 165 | sequence, 166 | frame, 167 | learning_map_inv, 168 | color_map, 169 | output_argmax, 170 | target_argmax, 171 | label, 172 | rand_t 173 | ) 174 | 175 | def test_step(self, batch, batch_idx): 176 | """Pytorch Lightning test step including logging 177 | 178 | Args: 179 | batch (dict): A dict with a batch of test samples 180 | batch_idx (int): Index of batch in dataset 181 | 182 | Returns: 183 | loss (dict): Multiple loss components 184 | """ 185 | past = batch["past_data"] 186 | future = batch["fut_data"] 187 | 188 | batch_size, n_inputs, n_future_steps, H, W = past.shape 189 | 190 | start = time.time() 191 | output = self.forward(past) 192 | inference_time = (time.time() - start) / batch_size 193 | self.log("test/inference_time", inference_time, on_epoch=True) 194 | 195 | if not self.cfg["TEST"]["ONLY_SAVE_POINT_CLOUDS"]: 196 | loss, 
learning_map_inv, color_map, output_argmax, target_argmax, label, \ 197 | rand_t, similarity_output, similarity_target = self.loss( 198 | output, future, "test", self.current_epoch 199 | ) 200 | 201 | self.log("test/loss_range_view", loss["loss_range_view"], on_epoch=True) 202 | self.log("test/loss_mask", loss["loss_mask"], on_epoch=True) 203 | self.log("test/loss_semantic", loss["loss_semantic"], on_epoch=True) 204 | self.log("test/similarity_output", similarity_output.detach(), on_epoch=True) 205 | self.log("test/similarity_target", similarity_target.detach(), on_epoch=True) 206 | 207 | for step, value in loss["chamfer_distance"].items(): 208 | self.log("test/chamfer_distance_{:d}".format(step), value, on_epoch=True) 209 | 210 | self.log( 211 | "test/mean_chamfer_distance", loss["mean_chamfer_distance"], on_epoch=True 212 | ) 213 | self.log( 214 | "test/final_chamfer_distance", loss["final_chamfer_distance"], on_epoch=True 215 | ) 216 | 217 | self.chamfer_distances_tensor = torch.cat( 218 | (self.chamfer_distances_tensor, loss["chamfer_distances_tensor"]), dim=1 219 | ) 220 | 221 | selected_sequence_and_frame = self.cfg["TEST"]["SELECTED_SEQUENCE_AND_FRAME"] 222 | sequence_batch, frame_batch = batch["meta"] 223 | for sample_idx in range(frame_batch.shape[0]): 224 | sequence = sequence_batch[sample_idx].item() 225 | frame = frame_batch[sample_idx].item() 226 | if sequence in selected_sequence_and_frame.keys(): 227 | if frame in selected_sequence_and_frame[sequence]: 228 | if self.logger: 229 | log_point_clouds( 230 | self.logger.experiment, 231 | self.projection, 232 | self.current_epoch, 233 | batch, 234 | output, 235 | sample_idx, 236 | sequence, 237 | frame, 238 | ) 239 | save_range_mask_and_semantic( 240 | self.cfg, 241 | self.projection, 242 | batch, 243 | output, 244 | sample_idx, 245 | sequence, 246 | frame, 247 | learning_map_inv, 248 | color_map, 249 | output_argmax, 250 | target_argmax, 251 | label, 252 | rand_t, 253 | test=True 254 | ) 255 | else: 256 | loss = None 257 | 258 | if self.cfg["TEST"]["SAVE_POINT_CLOUDS"] or self.cfg["TEST"]["ONLY_SAVE_POINT_CLOUDS"]: 259 | save_point_clouds(self.cfg, self.projection, batch, output) 260 | 261 | return loss 262 | 263 | def test_epoch_end(self, outputs): 264 | # Remove first row since it was initialized with zero 265 | self.chamfer_distances_tensor = self.chamfer_distances_tensor[:, 1:] 266 | n_steps, _ = self.chamfer_distances_tensor.shape 267 | mean = torch.mean(self.chamfer_distances_tensor, dim=1) 268 | std = torch.std(self.chamfer_distances_tensor, dim=1) 269 | q = torch.tensor([0.25, 0.5, 0.75]) 270 | quantile = torch.quantile(self.chamfer_distances_tensor, q, dim=1) 271 | 272 | chamfer_distances = [] 273 | for s in range(n_steps): 274 | chamfer_distances.append(self.chamfer_distances_tensor[s, :].tolist()) 275 | print("Final size of CD: ", self.chamfer_distances_tensor.shape) 276 | print("Mean :", mean) 277 | print("Std :", std) 278 | print("Quantile :", quantile) 279 | 280 | testdir = os.path.join(self.cfg["LOG_DIR"], self.cfg["TEST"]["DIR_NAME"]) 281 | if not os.path.exists(testdir): 282 | os.makedirs(testdir) 283 | 284 | filename = os.path.join( 285 | testdir, "stats" + ".yaml" 286 | ) 287 | 288 | log_to_save = { 289 | "mean": mean.tolist(), 290 | "std": std.tolist(), 291 | "quantile": quantile.tolist(), 292 | "chamfer_distances": chamfer_distances, 293 | } 294 | with open(filename, "w") as yaml_file: 295 | yaml.dump(log_to_save, yaml_file, default_flow_style=False) 296 | 
-------------------------------------------------------------------------------- /pcpnet/models/loss.py: -------------------------------------------------------------------------------- 1 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 2 | # This file is covered by the LICENSE file in the root of the project PCPNet: 3 | # https://github.com/Blurryface0814/PCPNet 4 | # Brief: Implementation of the Loss Modules 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | import yaml 10 | import math 11 | import random 12 | 13 | from pyTorchChamferDistance.chamfer_distance import ChamferDistance 14 | from pcpnet.utils.projection import projection 15 | from pcpnet.utils.logger import map 16 | import importlib 17 | import numpy as np 18 | 19 | 20 | class Loss(nn.Module): 21 | """Combined loss for point cloud prediction""" 22 | 23 | def __init__(self, cfg): 24 | """Init""" 25 | super().__init__() 26 | self.cfg = cfg 27 | self.n_future_steps = self.cfg["MODEL"]["N_FUTURE_STEPS"] 28 | self.loss_weight_cd = self.cfg["TRAIN"]["LOSS_WEIGHT_CHAMFER_DISTANCE"] 29 | self.loss_weight_rv = self.cfg["TRAIN"]["LOSS_WEIGHT_RANGE_VIEW"] 30 | self.loss_weight_mask = self.cfg["TRAIN"]["LOSS_WEIGHT_MASK"] 31 | self.loss_weight_semantic = self.cfg["TRAIN"]["LOSS_WEIGHT_SEMANTIC"] 32 | self.cd_every_n_val_epoch = self.cfg["TRAIN"]["CHAMFER_DISTANCE_EVERY_N_VAL_EPOCH"] 33 | 34 | self.loss_range = loss_range(self.cfg) 35 | self.chamfer_distance = chamfer_distance(self.cfg) 36 | self.loss_mask = loss_mask(self.cfg) 37 | self.loss_semantic = loss_semantic(self.cfg) 38 | 39 | def forward(self, output, target, mode, current_epoch=None): 40 | """Forward pass with multiple loss components 41 | 42 | Args: 43 | output (dict): Predicted mask logits and ranges 44 | target (torch.tensor): Target range image 45 | mode (str): Mode (train,val,test) 46 | 47 | Returns: 48 | dict: Dict with loss components 49 | """ 50 | cd_flag = ((current_epoch + 1) % self.cd_every_n_val_epoch == 0) 51 | 52 | target_range_image = target[:, 0, :, :, :] 53 | target_label = target[:, 5, :, :, :] 54 | 55 | # Range view 56 | loss_range_view = self.loss_range(output, target_range_image) 57 | 58 | # Mask 59 | loss_mask = self.loss_mask(output, target_range_image) 60 | 61 | # Semantic 62 | # if self.loss_weight_semantic > 0.0 or mode == "val" or mode == "test": 63 | if self.loss_weight_semantic > 0.0: 64 | loss_semantic, output_argmax, target_argmax, label, rand_t, \ 65 | similarity_output, similarity_target = self.loss_semantic( 66 | output, target_range_image, target_label, mode 67 | ) 68 | else: 69 | loss_semantic = torch.zeros(1).type_as(target_range_image) 70 | B, T, H, W = target_range_image.shape 71 | output_argmax = torch.zeros((B, 1, H, W)).type_as(target_range_image) 72 | target_argmax = output_argmax 73 | label = output_argmax 74 | rand_t = "none" 75 | similarity_output = torch.zeros(1).type_as(target_range_image) 76 | similarity_target = torch.zeros(1).type_as(target_range_image) 77 | 78 | # Chamfer Distance 79 | if self.loss_weight_cd > 0.0 or (mode == "val" and cd_flag) or mode == "test": 80 | chamfer_distance, chamfer_distances_tensor = self.chamfer_distance( 81 | output, target, self.cfg["TEST"]["N_DOWNSAMPLED_POINTS_CD"] 82 | ) 83 | loss_chamfer_distance = sum([cd for cd in chamfer_distance.values()]) / len( 84 | chamfer_distance 85 | ) 86 | detached_chamfer_distance = { 87 | step: cd.detach() for step, cd in chamfer_distance.items() 88 | } 89 | else: 90 | chamfer_distance = dict( 91 | (step, 
torch.zeros(1).type_as(target_range_image)) 92 | for step in range(self.n_future_steps) 93 | ) 94 | chamfer_distances_tensor = torch.zeros(self.n_future_steps, 1) 95 | loss_chamfer_distance = torch.zeros_like(loss_range_view) 96 | detached_chamfer_distance = chamfer_distance 97 | 98 | loss = ( 99 | self.loss_weight_cd * loss_chamfer_distance 100 | + self.loss_weight_rv * loss_range_view 101 | + self.loss_weight_mask * loss_mask 102 | + self.loss_weight_semantic * loss_semantic 103 | ) 104 | 105 | loss_dict = { 106 | "loss": loss, 107 | "chamfer_distance": detached_chamfer_distance, 108 | "chamfer_distances_tensor": chamfer_distances_tensor.detach(), 109 | "mean_chamfer_distance": loss_chamfer_distance.detach(), 110 | "final_chamfer_distance": chamfer_distance[ 111 | self.n_future_steps - 1 112 | ].detach(), 113 | "loss_range_view": loss_range_view.detach(), 114 | "loss_mask": loss_mask.detach(), 115 | "loss_semantic": loss_semantic.detach() 116 | } 117 | 118 | if mode == "val" or mode == "test": 119 | return loss_dict, \ 120 | self.loss_semantic.learning_map_inv, \ 121 | self.loss_semantic.color_map, \ 122 | output_argmax, \ 123 | target_argmax, \ 124 | label, \ 125 | str(rand_t), \ 126 | similarity_output, \ 127 | similarity_target 128 | 129 | return loss_dict 130 | 131 | 132 | class loss_mask(nn.Module): 133 | """Binary cross entropy loss for prediction of valid mask""" 134 | 135 | def __init__(self, cfg): 136 | super().__init__() 137 | self.cfg = cfg 138 | self.loss = nn.BCEWithLogitsLoss(reduction="mean") 139 | self.projection = projection(self.cfg) 140 | 141 | def forward(self, output, target_range_view): 142 | target_mask = self.projection.get_target_mask_from_range_view(target_range_view) 143 | loss = self.loss(output["mask_logits"], target_mask) 144 | return loss 145 | 146 | 147 | class loss_range(nn.Module): 148 | """L1 loss for range image prediction""" 149 | 150 | def __init__(self, cfg): 151 | super().__init__() 152 | self.cfg = cfg 153 | self.loss = nn.L1Loss(reduction="mean") 154 | 155 | def forward(self, output, target_range_image): 156 | # Do not count L1 loss for invalid GT points 157 | gt_masked_output = output["rv"].clone() 158 | gt_masked_output[target_range_image == -1.0] = -1.0 159 | 160 | loss = self.loss(gt_masked_output, target_range_image) 161 | return loss 162 | 163 | 164 | class loss_semantic(nn.Module): 165 | """Semantic loss for prediction of rv""" 166 | 167 | def __init__(self, cfg): 168 | super().__init__() 169 | self.cfg = cfg 170 | self.semantic_net = self.cfg["MODEL"]["SEMANTIC_NET"] 171 | self.semantic_data_config = self.cfg["MODEL"]["SEMANTIC_DATA_CONFIG"] 172 | self.semantic_pretrained = self.cfg["MODEL"]["SEMANTIC_PRETRAINED_MODEL"] 173 | self.projection = projection(self.cfg) 174 | 175 | pretrained = "semantic_net/" + self.semantic_net + "/model/" + self.semantic_pretrained 176 | config_base_path = "semantic_net/" + self.semantic_net + "/tasks/semantic/config/" 177 | arch_config = pretrained + "/arch_cfg.yaml" 178 | data_config = config_base_path + "labels/" + self.semantic_data_config 179 | 180 | self.DATA = yaml.safe_load(open(data_config, 'r')) 181 | self.n_classes = len(self.DATA["learning_map_inv"]) 182 | self.learning_map_inv = self.DATA["learning_map_inv"] 183 | self.learning_map = self.DATA["learning_map"] 184 | self.color_map = self.DATA["color_map"] 185 | 186 | module = "semantic_net." 
+ self.semantic_net + ".tasks.semantic.modules.segmentator" 187 | segmentator_module = importlib.import_module(module) 188 | 189 | self.loss = nn.L1Loss(reduction="mean") 190 | 191 | if cfg["TRAIN"]["LOSS_WEIGHT_SEMANTIC"] > 0.0: 192 | self.ARCH = yaml.safe_load(open(arch_config, 'r')) 193 | self.sensor = self.ARCH["dataset"]["sensor"] 194 | self.sensor_img_means = torch.tensor(self.sensor["img_means"], dtype=torch.float) 195 | self.sensor_img_stds = torch.tensor(self.sensor["img_stds"], dtype=torch.float) 196 | with torch.no_grad(): 197 | self.semantic_model = segmentator_module.Segmentator(self.ARCH, self.n_classes, pretrained) 198 | self.semantic_model.eval() 199 | 200 | # Semantic Similarity 201 | self.criterion = nn.NLLLoss() 202 | self.semantic_similarity = self.cfg["TEST"]["SEMANTIC_SIMILARITY"] 203 | 204 | def forward(self, output, target_range_image, target_label, mode=None): 205 | B, T, H, W = output["rv"].shape 206 | sensor_img_means = self.sensor_img_means.type_as(target_range_image) 207 | sensor_img_stds = self.sensor_img_stds.type_as(target_range_image) 208 | 209 | label = target_label 210 | output = output["rv"].unsqueeze(2) # [B, T, 1, H, W] 211 | target = target_range_image.unsqueeze(2) 212 | 213 | output_mask = output != -1.0 214 | target_mask = target != -1.0 215 | 216 | output_masked = (output - sensor_img_means[None, None, 0:1, None, None] 217 | ) / sensor_img_stds[None, None, 0:1, None, None] 218 | 219 | target_masked = (target - sensor_img_means[None, None, 0:1, None, None] 220 | ) / sensor_img_stds[None, None, 0:1, None, None] 221 | 222 | output_masked = output_masked * output_mask 223 | target_masked = target_masked * target_mask 224 | 225 | rand_t = random.randint(0, T-1) 226 | output_t = self.semantic_model(output_masked[:, rand_t, :, :, :]) # [B, 20, H, W] 227 | target_t = self.semantic_model(target_masked[:, rand_t, :, :, :]) 228 | label = label[:, rand_t, :, :] 229 | 230 | # get labels from semantic model output 231 | output_argmax = output_t.argmax(dim=1) # [B, H, W] 232 | target_argmax = target_t.argmax(dim=1) 233 | label = map(label.cpu().numpy().astype(np.int32), self.learning_map) 234 | label = torch.tensor(label).type_as(target_label) 235 | 236 | # Do not count L1 loss for invalid GT points 237 | target_mask = target[:, rand_t, :, :, :].repeat(1, 20, 1, 1) 238 | output_last = output_t.clone() 239 | target_last = target_t.clone() 240 | output_last[target_mask == -1.0] = -1.0 241 | target_last[target_mask == -1.0] = -1.0 242 | 243 | # set argmax image for visualization 244 | output_argmax[target_last[:, 0, :, :] == -1.0] = 0.0 245 | target_argmax[target_last[:, 0, :, :] == -1.0] = 0.0 246 | 247 | # Semantic Similarity 248 | if mode == "test" and self.semantic_similarity: 249 | similarity_output = self.n_classes / self.criterion(torch.log(output_t.clamp(min=1e-8)), label.long()) 250 | similarity_target = self.n_classes / self.criterion(torch.log(target_t.clamp(min=1e-8)), label.long()) 251 | else: 252 | similarity_output = torch.zeros(1).type_as(output_t) 253 | similarity_target = torch.zeros(1).type_as(output_t) 254 | 255 | loss = self.loss(output_last, target_last) 256 | 257 | return loss, output_argmax.cpu().numpy(), target_argmax.cpu().numpy(), \ 258 | label.cpu().numpy(), rand_t, similarity_output, similarity_target 259 | 260 | 261 | class chamfer_distance(nn.Module): 262 | """Chamfer distance loss. 
Additionally, the implementation allows the evaluation 263 | on downsampled point cloud (this is only for comparison to other methods but not recommended, 264 | because it is a bad approximation of the real Chamfer distance. 265 | """ 266 | 267 | def __init__(self, cfg): 268 | super().__init__() 269 | self.cfg = cfg 270 | self.loss = ChamferDistance() 271 | self.projection = projection(self.cfg) 272 | 273 | def forward(self, output, target, n_samples): 274 | batch_size, n_future_steps, H, W = output["rv"].shape 275 | masked_output = self.projection.get_masked_range_view(output) 276 | chamfer_distances = {} 277 | chamfer_distances_tensor = torch.zeros(n_future_steps, batch_size) 278 | for s in range(n_future_steps): 279 | chamfer_distances[s] = 0 280 | for b in range(batch_size): 281 | output_points = self.projection.get_valid_points_from_range_view( 282 | masked_output[b, s, :, :] 283 | ).view(1, -1, 3) 284 | target_points = target[b, 1:4, s, :, :].permute(1, 2, 0) 285 | target_points = target_points[target[b, 0, s, :, :] > 0.0].view( 286 | 1, -1, 3 287 | ) 288 | 289 | if n_samples != -1: 290 | n_output_points = output_points.shape[1] 291 | n_target_points = target_points.shape[1] 292 | 293 | sampled_output_indices = random.sample( 294 | range(n_output_points), n_samples 295 | ) 296 | sampled_target_indices = random.sample( 297 | range(n_target_points), n_samples 298 | ) 299 | 300 | output_points = output_points[:, sampled_output_indices, :] 301 | target_points = target_points[:, sampled_target_indices, :] 302 | 303 | dist1, dist2 = self.loss(output_points, target_points) 304 | dist_combined = torch.mean(dist1) + torch.mean(dist2) 305 | chamfer_distances[s] += dist_combined 306 | chamfer_distances_tensor[s, b] = dist_combined 307 | chamfer_distances[s] = chamfer_distances[s] / batch_size 308 | return chamfer_distances, chamfer_distances_tensor 309 | -------------------------------------------------------------------------------- /pcpnet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__init__.py -------------------------------------------------------------------------------- /pcpnet/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/utils/__pycache__/preprocess_data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/preprocess_data.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/utils/__pycache__/projection.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/projection.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/utils/__pycache__/visualization.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/visualization.cpython-38.pyc -------------------------------------------------------------------------------- /pcpnet/utils/preprocess_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Preprocessing point cloud to range images 6 | import os 7 | import numpy as np 8 | import torch 9 | 10 | from pcpnet.utils.utils import load_files, range_projection 11 | 12 | 13 | def prepare_data(cfg, dataset_path, rawdata_path): 14 | """Loads point clouds and labels and pre-processes them into range images 15 | 16 | Args: 17 | cfg (dict): Config 18 | """ 19 | sequences = ( 20 | cfg["DATA_CONFIG"]["SPLIT"]["TRAIN"] 21 | + cfg["DATA_CONFIG"]["SPLIT"]["VAL"] 22 | + cfg["DATA_CONFIG"]["SPLIT"]["TEST"] 23 | ) 24 | 25 | proj_H = cfg["DATA_CONFIG"]["HEIGHT"] 26 | proj_W = cfg["DATA_CONFIG"]["WIDTH"] 27 | 28 | for seq in sequences: 29 | seqstr = "{0:02d}".format(int(seq)) 30 | scan_folder = os.path.join(rawdata_path, seqstr, "velodyne") 31 | label_folder = os.path.join(rawdata_path, seqstr, "labels") 32 | dst_folder = os.path.join(dataset_path, seqstr, "processed") 33 | if not os.path.exists(dst_folder): 34 | os.makedirs(dst_folder) 35 | 36 | # Load LiDAR scan files and label files 37 | scan_paths = load_files(scan_folder) 38 | label_paths = load_files(label_folder) 39 | 40 | if len(scan_paths) != len(label_paths): 41 | print("Points files: ", len(scan_paths)) 42 | print("Label files: ", len(label_paths)) 43 | raise ValueError("Scan and Label don't contain same number of files") 44 | 45 | # Iterate over all scan files and label files 46 | for idx in range(len(scan_paths)): 47 | print( 48 | "Processing file {:d}/{:d} of sequence {:d}".format( 49 | idx, len(scan_paths), seq 50 | ) 51 | ) 52 | 53 | # Load and project a point cloud 54 | current_vertex = np.fromfile(scan_paths[idx], dtype=np.float32) 55 | current_vertex = current_vertex.reshape((-1, 4)) 56 | label = np.fromfile(label_paths[idx], dtype=np.int32) 57 | label = label.reshape((-1)) 58 | # only fill in attribute if the right size 59 | if current_vertex.shape[0] != label.shape[0]: 60 | print("Points shape: ", current_vertex.shape[0]) 61 | print("Label shape: ", label.shape[0]) 62 | raise ValueError("Scan and Label don't contain same number of points") 63 | 64 | proj_range, proj_vertex, proj_intensity, proj_idx = range_projection( 65 | current_vertex, 66 | fov_up=cfg["DATA_CONFIG"]["FOV_UP"], 67 | fov_down=cfg["DATA_CONFIG"]["FOV_DOWN"], 68 | 
proj_H=cfg["DATA_CONFIG"]["HEIGHT"], 69 | proj_W=cfg["DATA_CONFIG"]["WIDTH"], 70 | max_range=cfg["DATA_CONFIG"]["MAX_RANGE"], 71 | ) 72 | 73 | # Save range 74 | dst_path_range = os.path.join(dst_folder, "range") 75 | if not os.path.exists(dst_path_range): 76 | os.makedirs(dst_path_range) 77 | file_path = os.path.join(dst_path_range, str(idx).zfill(6)) 78 | np.save(file_path, proj_range) 79 | 80 | # Save xyz 81 | dst_path_xyz = os.path.join(dst_folder, "xyz") 82 | if not os.path.exists(dst_path_xyz): 83 | os.makedirs(dst_path_xyz) 84 | file_path = os.path.join(dst_path_xyz, str(idx).zfill(6)) 85 | np.save(file_path, proj_vertex) 86 | 87 | # Save intensity 88 | dst_path_intensity = os.path.join(dst_folder, "intensity") 89 | if not os.path.exists(dst_path_intensity): 90 | os.makedirs(dst_path_intensity) 91 | file_path = os.path.join(dst_path_intensity, str(idx).zfill(6)) 92 | np.save(file_path, proj_intensity) 93 | 94 | # Save label 95 | sem_label = label & 0xFFFF # semantic label in lower half 96 | inst_label = label >> 16 # instance id in upper half 97 | # sanity check 98 | assert ((sem_label + (inst_label << 16) == label).all()) 99 | 100 | proj_sem_label = np.zeros((proj_H, proj_W), dtype=np.int32) # [H,W] label 101 | mask = proj_idx >= 0 102 | proj_sem_label[mask] = sem_label[proj_idx[mask]] 103 | 104 | dst_path_label = os.path.join(dst_folder, "labels") 105 | if not os.path.exists(dst_path_label): 106 | os.makedirs(dst_path_label) 107 | file_path = os.path.join(dst_path_label, str(idx).zfill(6)) 108 | np.save(file_path, proj_sem_label) 109 | 110 | 111 | def compute_mean_and_std(cfg, train_loader): 112 | """Compute training data statistics 113 | 114 | Args: 115 | cfg (dict): Config 116 | train_loader (DataLoader): Pytorch DataLoader to access training data 117 | """ 118 | n_channels = train_loader.dataset.n_channels 119 | mean = [0] * n_channels 120 | std = [0] * n_channels 121 | max = [0] * n_channels 122 | min = [0] * n_channels 123 | for i, data in enumerate(train_loader): 124 | past = data["past_data"] 125 | batch_size, n_channels, frames, H, W = past.shape 126 | 127 | for j in range(n_channels): 128 | channel = past[:, j, :, :, :].view(batch_size, 1, frames, H, W) 129 | mean[j] += torch.mean(channel[channel != -1.0]) / len(train_loader) 130 | std[j] += torch.std(channel[channel != -1.0]) / len(train_loader) 131 | max[j] += torch.max(channel[channel != -1.0]) / len(train_loader) 132 | min[j] += torch.min(channel[channel != -1.0]) / len(train_loader) 133 | 134 | print("Mean and standard deviation of training data:") 135 | for j in range(n_channels): 136 | print( 137 | "Input {:d}: Mean {:.3f}, std {:.3f}, min {:.3f}, max {:.3f}".format( 138 | j, mean[j], std[j], min[j], max[j] 139 | ) 140 | ) 141 | -------------------------------------------------------------------------------- /pcpnet/utils/projection.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Get a 3D point cloud from a given range image projection 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | 11 | class projection: 12 | """Projection class for getting a 3D point cloud from range images""" 13 | 14 | def __init__(self, cfg): 15 | """Init 16 | 17 | Args: 18 | cfg (dict): Parameters 19 | """ 20 | self.cfg = cfg 21 | 22 | fov_up = ( 23 | 
self.cfg["DATA_CONFIG"]["FOV_UP"] / 180.0 * np.pi 24 | ) # field of view up in radians 25 | fov_down = ( 26 | self.cfg["DATA_CONFIG"]["FOV_DOWN"] / 180.0 * np.pi 27 | ) # field of view down in radians 28 | fov = abs(fov_down) + abs(fov_up) # get field of view total in radian 29 | W = self.cfg["DATA_CONFIG"]["WIDTH"] 30 | H = self.cfg["DATA_CONFIG"]["HEIGHT"] 31 | 32 | h = torch.arange(0, H).view(H, 1).repeat(1, W) 33 | w = torch.arange(0, W).view(1, W).repeat(H, 1) 34 | yaw = np.pi * (1.0 - 2 * torch.true_divide(w, W)) 35 | pitch = (1.0 - torch.true_divide(h, H)) * fov - abs(fov_down) 36 | self.x_fac = torch.cos(pitch) * torch.cos(yaw) 37 | self.y_fac = torch.cos(pitch) * torch.sin(yaw) 38 | self.z_fac = torch.sin(pitch) 39 | 40 | def get_valid_points_from_range_view(self, range_view): 41 | """Reproject from range image to valid 3D points 42 | 43 | Args: 44 | range_view (torch.tensor): Range image with size (H,W) 45 | 46 | Returns: 47 | torch.tensor: Valid 3D points with size (N,3) 48 | """ 49 | H, W = range_view.shape 50 | points = torch.zeros(H, W, 3).type_as(range_view) 51 | points[:, :, 0] = range_view * self.x_fac.type_as(range_view) 52 | points[:, :, 1] = range_view * self.y_fac.type_as(range_view) 53 | points[:, :, 2] = range_view * self.z_fac.type_as(range_view) 54 | return points[range_view > 0.0] 55 | 56 | def get_xyz_from_range_view(self, range_view): 57 | """Reproject from range image to xyz 58 | 59 | Args: 60 | range_view (torch.tensor): Range image with size (B, T, H, W) 61 | 62 | Returns: 63 | torch.tensor: Valid 3D points with size (B, T, 3, H, W) 64 | """ 65 | B, T, H, W = range_view.shape 66 | points = torch.zeros(B, T, H, W, 3).type_as(range_view) 67 | points[:, :, :, :, 0] = range_view * self.x_fac.type_as(range_view) 68 | points[:, :, :, :, 1] = range_view * self.y_fac.type_as(range_view) 69 | points[:, :, :, :, 2] = range_view * self.z_fac.type_as(range_view) 70 | return points.permute(0, 1, 4, 2, 3) 71 | 72 | def get_mask_from_output(self, output): 73 | """Get mask from logits 74 | 75 | Args: 76 | output (dict): Output dict with mask_logits as key 77 | 78 | Returns: 79 | mask: Predicted mask containing per-point probabilities 80 | """ 81 | mask = nn.Sigmoid()(output["mask_logits"]) 82 | return mask 83 | 84 | def get_target_mask_from_range_view(self, range_view): 85 | """Ground truth mask 86 | 87 | Args: 88 | range_view (torch.tensor): Range image of size (H,W) 89 | 90 | Returns: 91 | torch.tensor: Target mask of valid points 92 | """ 93 | target_mask = torch.zeros(range_view.shape).type_as(range_view) 94 | target_mask[range_view != -1.0] = 1.0 95 | return target_mask 96 | 97 | def get_masked_range_view(self, output): 98 | """Get predicted masked range image 99 | 100 | Args: 101 | output (dict): Dictionary containing predicted mask logits and ranges 102 | 103 | Returns: 104 | torch.tensor: Maskes range image in which invalid points are mapped to -1.0 105 | """ 106 | mask = self.get_mask_from_output(output) 107 | masked_range_view = output["rv"].clone() 108 | 109 | # Set invalid points to -1.0 according to mask 110 | masked_range_view[mask < self.cfg["MODEL"]["MASK_THRESHOLD"]] = -1.0 111 | return masked_range_view 112 | -------------------------------------------------------------------------------- /pcpnet/utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # 
https://github.com/Blurryface0814/PCPNet 5 | # Brief: Project a point cloud into a range image 6 | import os 7 | import math 8 | import numpy as np 9 | import random 10 | import torch 11 | 12 | def set_seed(seed): 13 | """ 14 | Set random seed for torch, numpy and python 15 | """ 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | if torch.cuda.is_available(): 20 | torch.cuda.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | torch.backends.cudnn.benchmark = False 24 | 25 | torch.backends.cudnn.deterministic = True 26 | 27 | 28 | def load_poses(pose_path): 29 | """Load ground truth poses (T_w_cam0) from file. 30 | 31 | Args: 32 | pose_path: (Complete) filename for the pose file 33 | 34 | Returns: 35 | A numpy array of size nx4x4 with n poses as 4x4 transformation 36 | matrices 37 | """ 38 | # Read and parse the poses 39 | poses = [] 40 | try: 41 | if ".txt" in pose_path: 42 | with open(pose_path, "r") as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | T_w_cam0 = np.fromstring(line, dtype=float, sep=" ") 46 | T_w_cam0 = T_w_cam0.reshape(3, 4) 47 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1])) 48 | poses.append(T_w_cam0) 49 | else: 50 | poses = np.load(pose_path)["arr_0"] 51 | 52 | except FileNotFoundError: 53 | print("Ground truth poses are not avaialble.") 54 | 55 | return np.array(poses) 56 | 57 | 58 | def load_calib(calib_path): 59 | """Load calibrations (T_cam_velo) from file.""" 60 | # Read and parse the calibrations 61 | T_cam_velo = [] 62 | try: 63 | with open(calib_path, "r") as f: 64 | lines = f.readlines() 65 | for line in lines: 66 | if "Tr:" in line: 67 | line = line.replace("Tr:", "") 68 | T_cam_velo = np.fromstring(line, dtype=float, sep=" ") 69 | T_cam_velo = T_cam_velo.reshape(3, 4) 70 | T_cam_velo = np.vstack((T_cam_velo, [0, 0, 0, 1])) 71 | 72 | except FileNotFoundError: 73 | print("Calibrations are not avaialble.") 74 | 75 | return np.array(T_cam_velo) 76 | 77 | 78 | def range_projection( 79 | current_vertex, fov_up=3.0, fov_down=-25.0, proj_H=64, proj_W=900, max_range=50 80 | ): 81 | """Project a pointcloud into a spherical projection, range image. 
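Points are binned by their yaw and pitch angles and written in order of decreasing depth, so when several points fall into the same pixel the closest one is kept.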
82 | 83 | Args: 84 | current_vertex: raw point clouds 85 | 86 | Returns: 87 | proj_range: projected range image with depth, each pixel contains the corresponding depth 88 | proj_vertex: each pixel contains the corresponding point (x, y, z, 1) 89 | proj_intensity: each pixel contains the corresponding intensity 90 | proj_idx: each pixel contains the corresponding index of the point in the raw point cloud 91 | """ 92 | # laser parameters 93 | fov_up = fov_up / 180.0 * np.pi # field of view up in radians 94 | fov_down = fov_down / 180.0 * np.pi # field of view down in radians 95 | fov = abs(fov_down) + abs(fov_up) # get field of view total in radians 96 | 97 | # get depth of all points 98 | depth = np.linalg.norm(current_vertex[:, :3], 2, axis=1) 99 | 100 | # # we use a maximum range threshold 101 | # current_vertex = current_vertex[(depth > 0) & (depth < max_range)] # get rid of [0, 0, 0] points 102 | # depth = depth[(depth > 0) & (depth < max_range)] 103 | 104 | # get scan components 105 | scan_x = current_vertex[:, 0] 106 | scan_y = current_vertex[:, 1] 107 | scan_z = current_vertex[:, 2] 108 | intensity = current_vertex[:, 3] 109 | 110 | # get angles of all points 111 | yaw = -np.arctan2(scan_y, scan_x) 112 | pitch = np.arcsin(scan_z / depth) 113 | 114 | # get projections in image coords 115 | proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0] 116 | proj_y = 1.0 - (pitch + abs(fov_down)) / fov # in [0.0, 1.0] 117 | 118 | # scale to image size using angular resolution 119 | proj_x *= proj_W # in [0.0, W] 120 | proj_y *= proj_H # in [0.0, H] 121 | 122 | # round and clamp for use as index 123 | proj_x = np.floor(proj_x) 124 | proj_x = np.minimum(proj_W - 1, proj_x) 125 | proj_x = np.maximum(0, proj_x).astype(np.int32) # in [0,W-1] 126 | 127 | proj_y = np.floor(proj_y) 128 | proj_y = np.minimum(proj_H - 1, proj_y) 129 | proj_y = np.maximum(0, proj_y).astype(np.int32) # in [0,H-1] 130 | 131 | # order in decreasing depth 132 | order = np.argsort(depth)[::-1] 133 | depth = depth[order] 134 | intensity = intensity[order] 135 | proj_y = proj_y[order] 136 | proj_x = proj_x[order] 137 | 138 | scan_x = scan_x[order] 139 | scan_y = scan_y[order] 140 | scan_z = scan_z[order] 141 | 142 | indices = np.arange(depth.shape[0]) 143 | indices = indices[order] 144 | 145 | proj_range = np.full( 146 | (proj_H, proj_W), -1, dtype=np.float32 147 | ) # [H,W] range (-1 is no data) 148 | proj_vertex = np.full( 149 | (proj_H, proj_W, 4), -1, dtype=np.float32 150 | ) # [H,W] index (-1 is no data) 151 | proj_idx = np.full( 152 | (proj_H, proj_W), -1, dtype=np.int32 153 | ) # [H,W] index (-1 is no data) 154 | proj_intensity = np.full( 155 | (proj_H, proj_W), -1, dtype=np.float32 156 | ) # [H,W] index (-1 is no data) 157 | 158 | proj_range[proj_y, proj_x] = depth 159 | proj_vertex[proj_y, proj_x] = np.array( 160 | [scan_x, scan_y, scan_z, np.ones(len(scan_x))] 161 | ).T 162 | proj_idx[proj_y, proj_x] = indices 163 | proj_intensity[proj_y, proj_x] = intensity 164 | 165 | return proj_range, proj_vertex, proj_intensity, proj_idx 166 | 167 | 168 | def gen_normal_map(current_range, current_vertex, proj_H=64, proj_W=900): 169 | """Generate a normal image given the range projection of a point cloud. 
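Normals are estimated per pixel from the cross product of the unit vectors towards the neighbor one row below and the horizontally wrapped right neighbor.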
170 | 171 | Args: 172 | current_range: range projection of a point cloud, each pixel contains the corresponding depth 173 | current_vertex: range projection of a point cloud, 174 | each pixel contains the corresponding point (x, y, z, 1) 175 | 176 | Returns: 177 | normal_data: each pixel contains the corresponding normal 178 | """ 179 | normal_data = np.full((proj_H, proj_W, 3), -1, dtype=np.float32) 180 | 181 | # iterate over all pixels in the range image 182 | for x in range(proj_W): 183 | for y in range(proj_H - 1): 184 | p = current_vertex[y, x][:3] 185 | depth = current_range[y, x] 186 | 187 | if depth > 0: 188 | wrap_x = wrap(x + 1, proj_W) 189 | u = current_vertex[y, wrap_x][:3] 190 | u_depth = current_range[y, wrap_x] 191 | if u_depth <= 0: 192 | continue 193 | 194 | v = current_vertex[y + 1, x][:3] 195 | v_depth = current_range[y + 1, x] 196 | if v_depth <= 0: 197 | continue 198 | 199 | u_norm = (u - p) / np.linalg.norm(u - p) 200 | v_norm = (v - p) / np.linalg.norm(v - p) 201 | 202 | w = np.cross(v_norm, u_norm) 203 | norm = np.linalg.norm(w) 204 | if norm > 0: 205 | normal = w / norm 206 | normal_data[y, x] = normal 207 | 208 | return normal_data 209 | 210 | 211 | def wrap(x, dim): 212 | """Wrap the boarder of the range image.""" 213 | value = x 214 | if value >= dim: 215 | value = value - dim 216 | if value < 0: 217 | value = value + dim 218 | return value 219 | 220 | 221 | def euler_angles_from_rotation_matrix(R): 222 | """From the paper by Gregory G. Slabaugh, 223 | Computing Euler angles from a rotation matrix 224 | psi, theta, phi = roll pitch yaw (x, y, z) 225 | 226 | Args: 227 | R: rotation matrix, a 3x3 numpy array 228 | 229 | Returns: 230 | a tuple with the 3 values psi, theta, phi in radians 231 | """ 232 | 233 | def isclose(x, y, rtol=1.0e-5, atol=1.0e-8): 234 | return abs(x - y) <= atol + rtol * abs(y) 235 | 236 | phi = 0.0 237 | if isclose(R[2, 0], -1.0): 238 | theta = math.pi / 2.0 239 | psi = math.atan2(R[0, 1], R[0, 2]) 240 | elif isclose(R[2, 0], 1.0): 241 | theta = -math.pi / 2.0 242 | psi = math.atan2(-R[0, 1], -R[0, 2]) 243 | else: 244 | theta = -math.asin(R[2, 0]) 245 | cos_theta = math.cos(theta) 246 | psi = math.atan2(R[2, 1] / cos_theta, R[2, 2] / cos_theta) 247 | phi = math.atan2(R[1, 0] / cos_theta, R[0, 0] / cos_theta) 248 | return psi, theta, phi 249 | 250 | 251 | def load_vertex(scan_path): 252 | """Load 3D points of a scan. The fileformat is the .bin format used in 253 | the KITTI dataset. 254 | 255 | Args: 256 | scan_path: the (full) filename of the scan file 257 | 258 | Returns: 259 | A nx4 numpy array of homogeneous points (x, y, z, 1). 
260 | """ 261 | current_vertex = np.fromfile(scan_path, dtype=np.float32) 262 | current_vertex = current_vertex.reshape((-1, 4)) 263 | current_points = current_vertex[:, 0:3] 264 | current_vertex = np.ones((current_points.shape[0], current_points.shape[1] + 1)) 265 | current_vertex[:, :-1] = current_points 266 | return current_vertex 267 | 268 | 269 | def load_files(folder): 270 | """Load all files in a folder and sort.""" 271 | file_paths = [ 272 | os.path.join(dp, f) 273 | for dp, dn, fn in os.walk(os.path.expanduser(folder)) 274 | for f in fn 275 | ] 276 | file_paths.sort() 277 | return file_paths 278 | -------------------------------------------------------------------------------- /pcpnet/utils/visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Visualization of Point Cloud Predictions with Open3D 6 | from functools import partial 7 | import os 8 | import glob 9 | import time 10 | import open3d as o3d 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | 14 | WIDTH = 1000 15 | HEIGHT = 1000 16 | POINTSIZE = 1.5 17 | SLEEPTIME = 0.3 18 | 19 | 20 | def get_car_model(filename): 21 | """Car model for visualization 22 | 23 | Args: 24 | filename (str): filename of mesh 25 | 26 | Returns: 27 | mesh: open3D mesh 28 | """ 29 | mesh = o3d.io.read_triangle_mesh(filename) 30 | mesh.scale(0.001, (-0.5, 0.4, -1.2)) 31 | R = mesh.get_rotation_matrix_from_quaternion(np.array([0, 0, 1, 1]).T) 32 | mesh.rotate(R) 33 | mesh.compute_triangle_normals() 34 | mesh.compute_vertex_normals() 35 | return mesh 36 | 37 | 38 | def get_filename(path, idx): 39 | filenames = sorted(glob.glob(path + "*.ply")) 40 | return int(filenames[idx].split(".")[0].split("/")[-1]) 41 | 42 | 43 | def last_file(path): 44 | return get_filename(path, -1) 45 | 46 | 47 | def first_file(path): 48 | return get_filename(path, 0) 49 | 50 | 51 | class Visualization: 52 | """Visualization of point cloud predictions with open3D""" 53 | 54 | def __init__( 55 | self, 56 | path, 57 | sequence, 58 | start, 59 | end, 60 | capture=False, 61 | path_to_car_model=None, 62 | sleep_time=5e-3, 63 | ): 64 | """Init 65 | 66 | Args: 67 | path (str): path to data should be 68 | . 69 | ├── sequence 70 | │ ├── gt 71 | | | ├──frame.ply 72 | │ ├─── pred 73 | | | ├── frame 74 | | │ | ├─── (frame+1).ply 75 | 76 | sequence (int): Sequence to visualize 77 | start (int): Start at specific frame 78 | end (itn): End at specific frame 79 | capture (bool, optional): Save to file at each frame. Defaults to False. 80 | path_to_car_model (str, optional): Path to car model. Defaults to None. 81 | sleep_time (float, optional): Sleep time between frames. Defaults to 5e-3. 
82 | """ 83 | self.vis = o3d.visualization.VisualizerWithKeyCallback() 84 | self.vis.create_window(width=WIDTH, height=HEIGHT) 85 | self.render_options = self.vis.get_render_option() 86 | self.render_options.point_size = POINTSIZE 87 | self.capture = capture 88 | 89 | # Load car model 90 | if path_to_car_model: 91 | self.car_mesh = get_car_model(path_to_car_model) 92 | else: 93 | self.car_mesh = None 94 | 95 | # Path and sequence to visualize 96 | self.path = path 97 | self.sequence = sequence 98 | 99 | # Frames to visualize 100 | self.start = start 101 | self.end = end 102 | 103 | # Set flag 104 | self.gt_flag = True 105 | self.pred_flag = True 106 | self.car_flag = True 107 | 108 | # Init 109 | self.current_frame = self.start 110 | self.current_step = 1 111 | self.n_pred_steps = 5 112 | 113 | # Save last view 114 | self.ctr = self.vis.get_view_control() 115 | self.camera = self.ctr.convert_to_pinhole_camera_parameters() 116 | self.viewpoint_path = os.path.join(self.path, "viewpoint.json") 117 | 118 | self.print_help() 119 | self.update(self.vis) 120 | 121 | # Continuous time plot 122 | self.stop = False 123 | self.sleep_time = sleep_time 124 | 125 | # Initialize the default callbacks 126 | self._register_key_callbacks() 127 | 128 | self.last_time_key_pressed = time.time() 129 | 130 | def prev_frame(self, vis): 131 | if time.time() - self.last_time_key_pressed > SLEEPTIME: 132 | self.last_time_key_pressed = time.time() 133 | self.current_frame = max(self.start, self.current_frame - 1) 134 | self.update(vis) 135 | return False 136 | 137 | def next_frame(self, vis): 138 | if time.time() - self.last_time_key_pressed > SLEEPTIME: 139 | self.last_time_key_pressed = time.time() 140 | self.current_frame = min(self.end, self.current_frame + 1) 141 | self.update(vis) 142 | return False 143 | 144 | def prev_prediction_step(self, vis): 145 | if time.time() - self.last_time_key_pressed > SLEEPTIME: 146 | self.last_time_key_pressed = time.time() 147 | self.current_step = max(1, self.current_step - 1) 148 | self.update(vis) 149 | return False 150 | 151 | def next_prediction_step(self, vis): 152 | if time.time() - self.last_time_key_pressed > SLEEPTIME: 153 | self.last_time_key_pressed = time.time() 154 | self.current_step = min(self.n_pred_steps, self.current_step + 1) 155 | self.update(vis) 156 | return False 157 | 158 | def show_or_hide_gt(self, vis): 159 | if time.time() - self.last_time_key_pressed > SLEEPTIME: 160 | self.last_time_key_pressed = time.time() 161 | if self.gt_flag: 162 | self.gt_flag = False 163 | else: 164 | self.gt_flag = True 165 | self.update(vis) 166 | return False 167 | 168 | def show_or_hide_pred(self, vis): 169 | if time.time() - self.last_time_key_pressed > SLEEPTIME: 170 | self.last_time_key_pressed = time.time() 171 | if self.pred_flag: 172 | self.pred_flag = False 173 | else: 174 | self.pred_flag = True 175 | self.update(vis) 176 | return False 177 | 178 | def show_or_hide_car(self, vis): 179 | if time.time() - self.last_time_key_pressed > SLEEPTIME: 180 | self.last_time_key_pressed = time.time() 181 | if self.car_flag: 182 | self.car_flag = False 183 | else: 184 | self.car_flag = True 185 | self.update(vis) 186 | return False 187 | 188 | def play_sequence(self, vis): 189 | self.stop = False 190 | while not self.stop: 191 | self.next_frame(vis) 192 | time.sleep(self.sleep_time) 193 | 194 | def stop_sequence(self, vis): 195 | self.stop = True 196 | 197 | def toggle_capture_mode(self, vis): 198 | if self.capture: 199 | self.capture = False 200 | else: 201 | self.capture = 
True 202 | print( 203 | "Current frame: {:d}/{:d}, Prediction step {:d}/{:d}, capture_mode: {:b}".format( 204 | self.current_frame, 205 | self.end, 206 | self.current_step, 207 | self.n_pred_steps, 208 | self.capture, 209 | ), 210 | end="\r", 211 | ) 212 | 213 | def update(self, vis): 214 | """Get point clouds and visualize""" 215 | print( 216 | "Current frame: {:d}/{:d}, Prediction step {:d}/{:d}, capture_mode: {:b}".format( 217 | self.current_frame, 218 | self.end, 219 | self.current_step, 220 | self.n_pred_steps, 221 | self.capture, 222 | ), 223 | end="\r", 224 | ) 225 | gt_pcd, pred_pcd = self.get_gt_and_predictions( 226 | self.path, self.sequence, self.current_frame, self.current_step 227 | ) 228 | 229 | geometry_list = [] 230 | 231 | if self.gt_flag: 232 | geometry_list.append(gt_pcd) 233 | if self.pred_flag: 234 | geometry_list.append(pred_pcd) 235 | if self.car_mesh: 236 | if self.car_flag: 237 | geometry_list.append(self.car_mesh) 238 | self.vis_update_geometries(vis, geometry_list) 239 | 240 | if self.capture: 241 | self.capture_frame() 242 | 243 | def get_gt_and_predictions(self, path, sequence, current_frame, step): 244 | """Load GT and predictions from path 245 | 246 | Args: 247 | path (str): Path to files 248 | sequence (int): Sequence to visualize 249 | current_frame (int): Last received frame for prediction 250 | step (int): Prediction step to visualize 251 | 252 | Returns: 253 | o3d.point_cloud: GT and predicted point clouds 254 | """ 255 | pred_path = os.path.join( 256 | path, 257 | sequence, 258 | "pred", 259 | str(current_frame).zfill(6), 260 | str(current_frame + step).zfill(6) + ".ply", 261 | ) 262 | pred_pcd = o3d.io.read_point_cloud(pred_path) 263 | pred_pcd.paint_uniform_color([0, 0, 1]) 264 | 265 | gt_path = os.path.join( 266 | path, sequence, "gt", str(current_frame + step).zfill(6) + ".ply" 267 | ) 268 | gt_pcd = o3d.io.read_point_cloud(gt_path) 269 | gt_pcd.paint_uniform_color([1, 0, 0]) 270 | return gt_pcd, pred_pcd 271 | 272 | def vis_update_geometries(self, vis, geometries): 273 | """Save camera pose and update point clouds""" 274 | # Save current camera pose 275 | self.camera = self.ctr.convert_to_pinhole_camera_parameters() 276 | 277 | vis.clear_geometries() 278 | for geometry in geometries: 279 | vis.add_geometry(geometry) 280 | 281 | # Set to last view 282 | self.ctr.convert_from_pinhole_camera_parameters(self.camera) 283 | 284 | self.vis.poll_events() 285 | self.vis.update_renderer() 286 | 287 | def set_render_options(self, **kwargs): 288 | for key, value in kwargs.items(): 289 | setattr(self.render_options, key, value) 290 | 291 | def register_key_callback(self, key, callback): 292 | self.vis.register_key_callback(ord(str(key)), partial(callback)) 293 | 294 | def set_white_background(self, vis): 295 | """Change backround between white and white""" 296 | self.render_options.background_color = [1.0, 1.0, 1.0] 297 | 298 | def set_black_background(self, vis): 299 | """Change backround between white and black""" 300 | self.render_options.background_color = [0.0, 0.0, 0.0] 301 | 302 | def save_viewpoint(self, vis): 303 | """Saves viewpoint""" 304 | self.camera = self.ctr.convert_to_pinhole_camera_parameters() 305 | o3d.io.write_pinhole_camera_parameters(self.viewpoint_path, self.camera) 306 | 307 | def load_viewpoint(self, vis): 308 | """Loads viewpoint""" 309 | self.camera = o3d.io.read_pinhole_camera_parameters(self.viewpoint_path) 310 | self.ctr.convert_from_pinhole_camera_parameters(self.camera) 311 | 312 | def _register_key_callbacks(self): 313 | 
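# Bindings below go through register_key_callback, which converts each single-character key to its ord() code for Open3D, so every shortcut here must be exactly one printable character.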
self.register_key_callback("L", self.next_frame) 314 | self.register_key_callback("H", self.prev_frame) 315 | self.register_key_callback("K", self.next_prediction_step) 316 | self.register_key_callback("J", self.prev_prediction_step) 317 | self.register_key_callback("G", self.show_or_hide_gt) 318 | self.register_key_callback("F", self.show_or_hide_pred) 319 | self.register_key_callback("D", self.show_or_hide_car) 320 | self.register_key_callback("S", self.play_sequence) 321 | self.register_key_callback("X", self.stop_sequence) 322 | self.register_key_callback("C", self.toggle_capture_mode) 323 | self.register_key_callback("W", self.set_white_background) 324 | self.register_key_callback("B", self.set_black_background) 325 | self.register_key_callback("V", self.save_viewpoint) 326 | self.register_key_callback("Q", self.load_viewpoint) 327 | 328 | def print_help(self): 329 | print("L: next frame") 330 | print("H: previous frame") 331 | print("K: next prediction step") 332 | print("J: previous prediction step") 333 | print("G: show or hide gt point cloud") 334 | print("F: show or hide pred point cloud") 335 | print("D: show or hide car model") 336 | print("S: start") 337 | print("X: stop") 338 | print("C: Toggle capture mode") 339 | print("W: white background") 340 | print("B: black background") 341 | print("V: save viewpoint") 342 | print("Q: set to saved viewpoint") 343 | print("ESC: quit") 344 | 345 | def capture_frame(self): 346 | """Save view from current viewpoint""" 347 | image = self.vis.capture_screen_float_buffer(False) 348 | path = os.path.join(self.path, self.sequence, "images") 349 | if not os.path.exists(path): 350 | os.makedirs(path) 351 | 352 | filename = os.path.join( 353 | path, 354 | "step_{:1d}_prediction_from_frame_{:05d}.png".format( 355 | self.current_step, self.current_frame 356 | ), 357 | ) 358 | print("Capture image ", filename) 359 | plt.imsave(filename, np.asarray(image), dpi=1) 360 | 361 | def run(self): 362 | self.vis.run() 363 | self.vis.destroy_window() 364 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/README.md: -------------------------------------------------------------------------------- 1 | # Chamfer Distance for pyTorch 2 | 3 | This is an implementation of the Chamfer Distance as a module for pyTorch. It is written as a custom C++/CUDA extension. 4 | 5 | As it is using pyTorch's [JIT compilation](https://pytorch.org/tutorials/advanced/cpp_extension.html), there are no additional prerequisite steps that have to be taken. Simply import the module as shown below; CUDA and C++ code will be compiled on the first run. 6 | 7 | ### Usage 8 | ```python 9 | from chamfer_distance import ChamferDistance 10 | chamfer_dist = ChamferDistance() 11 | 12 | #... 13 | # points and points_reconstructed are n_points x 3 matrices 14 | 15 | dist1, dist2 = chamfer_dist(points, points_reconstructed) 16 | loss = (torch.mean(dist1)) + (torch.mean(dist2)) 17 | 18 | 19 | #... 20 | ``` 21 | 22 | ### Integration 23 | This code has been integrated into the [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) library for 3D Deep Learning by NVIDIAGameWorks. You should probably take a look at it if you are working on anything 3D :) 24 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance import ChamferDistance 2 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pyTorchChamferDistance/chamfer_distance/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/__pycache__/chamfer_distance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pyTorchChamferDistance/chamfer_distance/__pycache__/chamfer_distance.cpython-38.pyc -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | int ChamferDistanceKernelLauncher( 5 | const int b, const int n, 6 | const float* xyz, 7 | const int m, 8 | const float* xyz2, 9 | float* result, 10 | int* result_i, 11 | float* result2, 12 | int* result2_i); 13 | 14 | int ChamferDistanceGradKernelLauncher( 15 | const int b, const int n, 16 | const float* xyz1, 17 | const int m, 18 | const float* xyz2, 19 | const float* grad_dist1, 20 | const int* idx1, 21 | const float* grad_dist2, 22 | const int* idx2, 23 | float* grad_xyz1, 24 | float* grad_xyz2); 25 | 26 | 27 | void chamfer_distance_forward_cuda( 28 | const at::Tensor xyz1, 29 | const at::Tensor xyz2, 30 | const at::Tensor dist1, 31 | const at::Tensor dist2, 32 | const at::Tensor idx1, 33 | const at::Tensor idx2) 34 | { 35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 36 | xyz2.size(1), xyz2.data(), 37 | dist1.data(), idx1.data(), 38 | dist2.data(), idx2.data()); 39 | } 40 | 41 | 
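// Wrapper that forwards the backward pass to the CUDA kernel launcher declared above.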
void chamfer_distance_backward_cuda( 42 | const at::Tensor xyz1, 43 | const at::Tensor xyz2, 44 | at::Tensor gradxyz1, 45 | at::Tensor gradxyz2, 46 | at::Tensor graddist1, 47 | at::Tensor graddist2, 48 | at::Tensor idx1, 49 | at::Tensor idx2) 50 | { 51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 52 | xyz2.size(1), xyz2.data(), 53 | graddist1.data(), idx1.data(), 54 | graddist2.data(), idx2.data(), 55 | gradxyz1.data(), gradxyz2.data()); 56 | } 57 | 58 | 59 | void nnsearch( 60 | const int b, const int n, const int m, 61 | const float* xyz1, 62 | const float* xyz2, 63 | float* dist, 64 | int* idx) 65 | { 66 | for (int i = 0; i < b; i++) { 67 | for (int j = 0; j < n; j++) { 68 | const float x1 = xyz1[(i*n+j)*3+0]; 69 | const float y1 = xyz1[(i*n+j)*3+1]; 70 | const float z1 = xyz1[(i*n+j)*3+2]; 71 | double best = 0; 72 | int besti = 0; 73 | for (int k = 0; k < m; k++) { 74 | const float x2 = xyz2[(i*m+k)*3+0] - x1; 75 | const float y2 = xyz2[(i*m+k)*3+1] - y1; 76 | const float z2 = xyz2[(i*m+k)*3+2] - z1; 77 | const double d=x2*x2+y2*y2+z2*z2; 78 | if (k==0 || d < best){ 79 | best = d; 80 | besti = k; 81 | } 82 | } 83 | dist[i*n+j] = best; 84 | idx[i*n+j] = besti; 85 | } 86 | } 87 | } 88 | 89 | 90 | void chamfer_distance_forward( 91 | const at::Tensor xyz1, 92 | const at::Tensor xyz2, 93 | const at::Tensor dist1, 94 | const at::Tensor dist2, 95 | const at::Tensor idx1, 96 | const at::Tensor idx2) 97 | { 98 | const int batchsize = xyz1.size(0); 99 | const int n = xyz1.size(1); 100 | const int m = xyz2.size(1); 101 | 102 | const float* xyz1_data = xyz1.data(); 103 | const float* xyz2_data = xyz2.data(); 104 | float* dist1_data = dist1.data(); 105 | float* dist2_data = dist2.data(); 106 | int* idx1_data = idx1.data(); 107 | int* idx2_data = idx2.data(); 108 | 109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 111 | } 112 | 113 | 114 | void chamfer_distance_backward( 115 | const at::Tensor xyz1, 116 | const at::Tensor xyz2, 117 | at::Tensor gradxyz1, 118 | at::Tensor gradxyz2, 119 | at::Tensor graddist1, 120 | at::Tensor graddist2, 121 | at::Tensor idx1, 122 | at::Tensor idx2) 123 | { 124 | const int b = xyz1.size(0); 125 | const int n = xyz1.size(1); 126 | const int m = xyz2.size(1); 127 | 128 | const float* xyz1_data = xyz1.data(); 129 | const float* xyz2_data = xyz2.data(); 130 | float* gradxyz1_data = gradxyz1.data(); 131 | float* gradxyz2_data = gradxyz2.data(); 132 | float* graddist1_data = graddist1.data(); 133 | float* graddist2_data = graddist2.data(); 134 | const int* idx1_data = idx1.data(); 135 | const int* idx2_data = idx2.data(); 136 | 137 | for (int i = 0; i < b*n*3; i++) 138 | gradxyz1_data[i] = 0; 139 | for (int i = 0; i < b*m*3; i++) 140 | gradxyz2_data[i] = 0; 141 | for (int i = 0;i < b; i++) { 142 | for (int j = 0; j < n; j++) { 143 | const float x1 = xyz1_data[(i*n+j)*3+0]; 144 | const float y1 = xyz1_data[(i*n+j)*3+1]; 145 | const float z1 = xyz1_data[(i*n+j)*3+2]; 146 | const int j2 = idx1_data[i*n+j]; 147 | 148 | const float x2 = xyz2_data[(i*m+j2)*3+0]; 149 | const float y2 = xyz2_data[(i*m+j2)*3+1]; 150 | const float z2 = xyz2_data[(i*m+j2)*3+2]; 151 | const float g = graddist1_data[i*n+j]*2; 152 | 153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); 154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); 155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); 156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); 157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); 
158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); 159 | } 160 | for (int j = 0; j < m; j++) { 161 | const float x1 = xyz2_data[(i*m+j)*3+0]; 162 | const float y1 = xyz2_data[(i*m+j)*3+1]; 163 | const float z1 = xyz2_data[(i*m+j)*3+2]; 164 | const int j2 = idx2_data[i*m+j]; 165 | const float x2 = xyz1_data[(i*n+j2)*3+0]; 166 | const float y2 = xyz1_data[(i*n+j2)*3+1]; 167 | const float z2 = xyz1_data[(i*n+j2)*3+2]; 168 | const float g = graddist2_data[i*m+j]*2; 169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); 170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); 171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); 172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); 173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); 174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); 175 | } 176 | } 177 | } 178 | 179 | 180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); 182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); 183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); 184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); 185 | } 186 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import os 4 | 5 | from torch.utils.cpp_extension import load 6 | cd = load(name="cd", 7 | sources=[os.path.abspath("pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp"), 8 | 
os.path.abspath("pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu")]) 9 | 10 | class ChamferDistanceFunction(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, xyz1, xyz2): 13 | batchsize, n, _ = xyz1.size() 14 | _, m, _ = xyz2.size() 15 | xyz1 = xyz1.contiguous() 16 | xyz2 = xyz2.contiguous() 17 | dist1 = torch.zeros(batchsize, n) 18 | dist2 = torch.zeros(batchsize, m) 19 | 20 | idx1 = torch.zeros(batchsize, n, dtype=torch.int) 21 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 22 | 23 | if not xyz1.is_cuda: 24 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 25 | else: 26 | dist1 = dist1.cuda() 27 | dist2 = dist2.cuda() 28 | idx1 = idx1.cuda() 29 | idx2 = idx2.cuda() 30 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 31 | 32 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 33 | 34 | return dist1, dist2 35 | 36 | @staticmethod 37 | def backward(ctx, graddist1, graddist2): 38 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 39 | 40 | graddist1 = graddist1.contiguous() 41 | graddist2 = graddist2.contiguous() 42 | 43 | gradxyz1 = torch.zeros(xyz1.size()) 44 | gradxyz2 = torch.zeros(xyz2.size()) 45 | 46 | if not graddist1.is_cuda: 47 | cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 48 | else: 49 | gradxyz1 = gradxyz1.cuda() 50 | gradxyz2 = gradxyz2.cuda() 51 | cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 52 | 53 | return gradxyz1, gradxyz2 54 | 55 | 56 | class ChamferDistance(torch.nn.Module): 57 | def forward(self, xyz1, xyz2): 58 | return ChamferDistanceFunction.apply(xyz1, xyz2) 59 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "point-cloud-prediction" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Benedikt Mersch "] 6 | packages = [ 7 | {include = "./pcf"}, 8 | {include = "./pyTorchChamferDistance"}, 9 | ] 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.8" 13 | PyYAML = "^6.0" 14 | matplotlib = "^3.4.3" 15 | open3d = "^0.13.0" 16 | pytorch-lightning = "^1.5.0" 17 | ninja = "^1.10.2" 18 | 19 | 20 | [tool.poetry.dev-dependencies] 21 | 22 | [build-system] 23 | requires = ["poetry-core>=1.0.0"] 24 | build-backend = "poetry.core.masonry.api" 25 | -------------------------------------------------------------------------------- /semantic_net/rangenet/README.md: -------------------------------------------------------------------------------- 1 | # LiDAR-Bonnetal Training 2 | 3 | This part of the framework deals with the training of segmentation networks for point cloud data using range images. 4 | 5 | ## Tasks 6 | 7 | - [Semantic Segmentation](tasks/semantic). 8 | - [Panoptic Segmentation](tasks/panoptic) \[Soon\]. 9 | 10 | ## Dependencies 11 | 12 | First you need to install the nvidia driver and CUDA, so have fun! 
13 | 14 | - CUDA Installation guide: [link](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) 15 | 16 | - System dependencies: 17 | 18 | ```sh 19 | $ sudo apt-get update 20 | $ sudo apt-get install -yqq build-essential ninja-build \ 21 | python3-dev python3-pip apt-utils curl git cmake unzip autoconf autogen \ 22 | libtool mlocate zlib1g-dev python3-numpy python3-wheel wget \ 23 | software-properties-common openjdk-8-jdk libpng-dev \ 24 | libxft-dev ffmpeg python3-pyqt5.qtopengl 25 | $ sudo updatedb 26 | ``` 27 | 28 | - Python dependencies 29 | 30 | ```sh 31 | $ sudo pip3 install -r requirements.txt 32 | ``` -------------------------------------------------------------------------------- /semantic_net/rangenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/__init__.py -------------------------------------------------------------------------------- /semantic_net/rangenet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/backbones/__init__.py -------------------------------------------------------------------------------- /semantic_net/rangenet/backbones/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/backbones/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/backbones/__pycache__/squeezeseg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/backbones/__pycache__/squeezeseg.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/backbones/__pycache__/squeezesegV2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/backbones/__pycache__/squeezesegV2.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/backbones/darknet.py: -------------------------------------------------------------------------------- 1 | # This file was modified from https://github.com/BobLiu20/YOLOv3_PyTorch 2 | # It needed to be modified in order to accomodate for different strides in the 3 | 4 | import torch.nn as nn 5 | from collections import OrderedDict 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | def __init__(self, inplanes, planes, bn_d=0.1): 11 | super(BasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes[0], 
kernel_size=1, 13 | stride=1, padding=0, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d) 15 | self.relu1 = nn.LeakyReLU(0.1) 16 | self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3, 17 | stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d) 19 | self.relu2 = nn.LeakyReLU(0.1) 20 | 21 | def forward(self, x): 22 | residual = x 23 | 24 | out = self.conv1(x) 25 | out = self.bn1(out) 26 | out = self.relu1(out) 27 | 28 | out = self.conv2(out) 29 | out = self.bn2(out) 30 | out = self.relu2(out) 31 | 32 | out += residual 33 | return out 34 | 35 | 36 | # ****************************************************************************** 37 | 38 | # number of layers per model 39 | model_blocks = { 40 | 21: [1, 1, 2, 2, 1], 41 | 53: [1, 2, 8, 8, 4], 42 | } 43 | 44 | 45 | class Backbone(nn.Module): 46 | """ 47 | Class for DarknetSeg. Subclasses PyTorch's own "nn" module 48 | """ 49 | 50 | def __init__(self, params): 51 | super(Backbone, self).__init__() 52 | self.use_range = params["input_depth"]["range"] 53 | self.use_xyz = params["input_depth"]["xyz"] 54 | self.use_remission = params["input_depth"]["remission"] 55 | self.drop_prob = params["dropout"] 56 | self.bn_d = params["bn_d"] 57 | self.OS = params["OS"] 58 | self.layers = params["extra"]["layers"] 59 | print("Using DarknetNet" + str(self.layers) + " Backbone") 60 | 61 | # input depth calc 62 | self.input_depth = 0 63 | self.input_idxs = [] 64 | if self.use_range: 65 | self.input_depth += 1 66 | self.input_idxs.append(0) 67 | if self.use_xyz: 68 | self.input_depth += 3 69 | self.input_idxs.extend([1, 2, 3]) 70 | if self.use_remission: 71 | self.input_depth += 1 72 | self.input_idxs.append(4) 73 | print("Depth of backbone input = ", self.input_depth) 74 | 75 | # stride play 76 | self.strides = [2, 2, 2, 2, 2] 77 | # check current stride 78 | current_os = 1 79 | for s in self.strides: 80 | current_os *= s 81 | print("Original OS: ", current_os) 82 | 83 | # make the new stride 84 | if self.OS > current_os: 85 | print("Can't do OS, ", self.OS, 86 | " because it is bigger than original ", current_os) 87 | else: 88 | # redo strides according to needed stride 89 | for i, stride in enumerate(reversed(self.strides), 0): 90 | if int(current_os) != self.OS: 91 | if stride == 2: 92 | current_os /= 2 93 | self.strides[-1 - i] = 1 94 | if int(current_os) == self.OS: 95 | break 96 | print("New OS: ", int(current_os)) 97 | print("Strides: ", self.strides) 98 | 99 | # check that darknet exists 100 | assert self.layers in model_blocks.keys() 101 | 102 | # generate layers depending on darknet type 103 | self.blocks = model_blocks[self.layers] 104 | 105 | # input layer 106 | self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=3, 107 | stride=1, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(32, momentum=self.bn_d) 109 | self.relu1 = nn.LeakyReLU(0.1) 110 | 111 | # encoder 112 | self.enc1 = self._make_enc_layer(BasicBlock, [32, 64], self.blocks[0], 113 | stride=self.strides[0], bn_d=self.bn_d) 114 | self.enc2 = self._make_enc_layer(BasicBlock, [64, 128], self.blocks[1], 115 | stride=self.strides[1], bn_d=self.bn_d) 116 | self.enc3 = self._make_enc_layer(BasicBlock, [128, 256], self.blocks[2], 117 | stride=self.strides[2], bn_d=self.bn_d) 118 | self.enc4 = self._make_enc_layer(BasicBlock, [256, 512], self.blocks[3], 119 | stride=self.strides[3], bn_d=self.bn_d) 120 | self.enc5 = self._make_enc_layer(BasicBlock, [512, 1024], self.blocks[4], 121 | stride=self.strides[4], bn_d=self.bn_d) 122 | 123 
| # for a bit of fun 124 | self.dropout = nn.Dropout2d(self.drop_prob) 125 | 126 | # last channels 127 | self.last_channels = 1024 128 | 129 | # make layer useful function 130 | def _make_enc_layer(self, block, planes, blocks, stride, bn_d=0.1): 131 | layers = [] 132 | 133 | # downsample 134 | layers.append(("conv", nn.Conv2d(planes[0], planes[1], 135 | kernel_size=3, 136 | stride=[1, stride], dilation=1, 137 | padding=1, bias=False))) 138 | layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d))) 139 | layers.append(("relu", nn.LeakyReLU(0.1))) 140 | 141 | # blocks 142 | inplanes = planes[1] 143 | for i in range(0, blocks): 144 | layers.append(("residual_{}".format(i), 145 | block(inplanes, planes, bn_d))) 146 | 147 | return nn.Sequential(OrderedDict(layers)) 148 | 149 | def run_layer(self, x, layer, skips, os): 150 | y = layer(x) 151 | if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]: 152 | skips[os] = x.detach() 153 | os *= 2 154 | x = y 155 | return x, skips, os 156 | 157 | def forward(self, x): 158 | # filter input 159 | x = x[:, self.input_idxs] 160 | 161 | # run cnn 162 | # store for skip connections 163 | skips = {} 164 | os = 1 165 | 166 | # first layer 167 | x, skips, os = self.run_layer(x, self.conv1, skips, os) 168 | x, skips, os = self.run_layer(x, self.bn1, skips, os) 169 | x, skips, os = self.run_layer(x, self.relu1, skips, os) 170 | 171 | # all encoder blocks with intermediate dropouts 172 | x, skips, os = self.run_layer(x, self.enc1, skips, os) 173 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 174 | x, skips, os = self.run_layer(x, self.enc2, skips, os) 175 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 176 | x, skips, os = self.run_layer(x, self.enc3, skips, os) 177 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 178 | x, skips, os = self.run_layer(x, self.enc4, skips, os) 179 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 180 | x, skips, os = self.run_layer(x, self.enc5, skips, os) 181 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 182 | 183 | return x, skips 184 | 185 | def get_last_depth(self): 186 | return self.last_channels 187 | 188 | def get_input_depth(self): 189 | return self.input_depth 190 | -------------------------------------------------------------------------------- /semantic_net/rangenet/backbones/squeezeseg.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/BichenWuUCB/SqueezeSeg 2 | from __future__ import print_function 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Fire(nn.Module): 9 | 10 | def __init__(self, inplanes, squeeze_planes, 11 | expand1x1_planes, expand3x3_planes): 12 | super(Fire, self).__init__() 13 | self.inplanes = inplanes 14 | self.activation = nn.ReLU(inplace=True) 15 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 16 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 17 | kernel_size=1) 18 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 19 | kernel_size=3, padding=1) 20 | 21 | def forward(self, x): 22 | x = self.activation(self.squeeze(x)) 23 | return torch.cat([ 24 | self.activation(self.expand1x1(x)), 25 | self.activation(self.expand3x3(x)) 26 | ], 1) 27 | 28 | 29 | # ****************************************************************************** 30 | 31 | class Backbone(nn.Module): 32 | """ 33 | Class for Squeezeseg. 
Subclasses PyTorch's own "nn" module 34 | """ 35 | 36 | def __init__(self, params): 37 | # Call the super constructor 38 | super(Backbone, self).__init__() 39 | print("Using SqueezeNet Backbone") 40 | self.use_range = params["input_depth"]["range"] 41 | self.use_xyz = params["input_depth"]["xyz"] 42 | self.use_remission = params["input_depth"]["remission"] 43 | self.drop_prob = params["dropout"] 44 | self.OS = params["OS"] 45 | 46 | # input depth calc 47 | self.input_depth = 0 48 | self.input_idxs = [] 49 | if self.use_range: 50 | self.input_depth += 1 51 | self.input_idxs.append(0) 52 | if self.use_xyz: 53 | self.input_depth += 3 54 | self.input_idxs.extend([1, 2, 3]) 55 | if self.use_remission: 56 | self.input_depth += 1 57 | self.input_idxs.append(4) 58 | print("Depth of backbone input = ", self.input_depth) 59 | 60 | # stride play 61 | self.strides = [2, 2, 2, 2] 62 | # check current stride 63 | current_os = 1 64 | for s in self.strides: 65 | current_os *= s 66 | print("Original OS: ", current_os) 67 | 68 | # make the new stride 69 | if self.OS > current_os: 70 | print("Can't do OS, ", self.OS, 71 | " because it is bigger than original ", current_os) 72 | else: 73 | # redo strides according to needed stride 74 | for i, stride in enumerate(reversed(self.strides), 0): 75 | if int(current_os) != self.OS: 76 | if stride == 2: 77 | current_os /= 2 78 | self.strides[-1 - i] = 1 79 | if int(current_os) == self.OS: 80 | break 81 | print("New OS: ", int(current_os)) 82 | print("Strides: ", self.strides) 83 | 84 | # encoder 85 | self.conv1a = nn.Sequential(nn.Conv2d(self.input_depth, 64, kernel_size=3, 86 | stride=[1, self.strides[0]], 87 | padding=1), 88 | nn.ReLU(inplace=True)) 89 | self.conv1b = nn.Conv2d(self.input_depth, 64, kernel_size=1, 90 | stride=1, padding=0) 91 | self.fire23 = nn.Sequential(nn.MaxPool2d(kernel_size=3, 92 | stride=[1, self.strides[1]], 93 | padding=1), 94 | Fire(64, 16, 64, 64), 95 | Fire(128, 16, 64, 64)) 96 | self.fire45 = nn.Sequential(nn.MaxPool2d(kernel_size=3, 97 | stride=[1, self.strides[2]], 98 | padding=1), 99 | Fire(128, 32, 128, 128), 100 | Fire(256, 32, 128, 128)) 101 | self.fire6789 = nn.Sequential(nn.MaxPool2d(kernel_size=3, 102 | stride=[1, self.strides[3]], 103 | padding=1), 104 | Fire(256, 48, 192, 192), 105 | Fire(384, 48, 192, 192), 106 | Fire(384, 64, 256, 256), 107 | Fire(512, 64, 256, 256)) 108 | 109 | # output 110 | self.dropout = nn.Dropout2d(self.drop_prob) 111 | 112 | # last channels 113 | self.last_channels = 512 114 | 115 | def run_layer(self, x, layer, skips, os): 116 | y = layer(x) 117 | if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]: 118 | skips[os] = x.detach() 119 | os *= 2 120 | x = y 121 | return x, skips, os 122 | 123 | def forward(self, x): 124 | # filter input 125 | x = x[:, self.input_idxs] 126 | 127 | # run cnn 128 | # store for skip connections 129 | skips = {} 130 | os = 1 131 | 132 | # encoder 133 | skip_in = self.conv1b(x) 134 | x = self.conv1a(x) 135 | # first skip done manually 136 | skips[1] = skip_in.detach() 137 | os *= 2 138 | 139 | x, skips, os = self.run_layer(x, self.fire23, skips, os) 140 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 141 | x, skips, os = self.run_layer(x, self.fire45, skips, os) 142 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 143 | x, skips, os = self.run_layer(x, self.fire6789, skips, os) 144 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 145 | 146 | return x, skips 147 | 148 | def get_last_depth(self): 149 | return self.last_channels 150 | 151 | 
def get_input_depth(self): 152 | return self.input_depth 153 | -------------------------------------------------------------------------------- /semantic_net/rangenet/backbones/squeezesegV2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | from __future__ import print_function 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Fire(nn.Module): 11 | def __init__(self, inplanes, squeeze_planes, 12 | expand1x1_planes, expand3x3_planes, bn_d=0.1): 13 | super(Fire, self).__init__() 14 | self.inplanes = inplanes 15 | self.bn_d = bn_d 16 | self.activation = nn.ReLU(inplace=True) 17 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 18 | self.squeeze_bn = nn.BatchNorm2d(squeeze_planes, momentum=self.bn_d) 19 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 20 | kernel_size=1) 21 | self.expand1x1_bn = nn.BatchNorm2d(expand1x1_planes, momentum=self.bn_d) 22 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 23 | kernel_size=3, padding=1) 24 | self.expand3x3_bn = nn.BatchNorm2d(expand3x3_planes, momentum=self.bn_d) 25 | 26 | def forward(self, x): 27 | x = self.activation(self.squeeze_bn(self.squeeze(x))) 28 | return torch.cat([ 29 | self.activation(self.expand1x1_bn(self.expand1x1(x))), 30 | self.activation(self.expand3x3_bn(self.expand3x3(x))) 31 | ], 1) 32 | 33 | 34 | class CAM(nn.Module): 35 | 36 | def __init__(self, inplanes, bn_d=0.1): 37 | super(CAM, self).__init__() 38 | self.inplanes = inplanes 39 | self.bn_d = bn_d 40 | self.pool = nn.MaxPool2d(7, 1, 3) 41 | self.squeeze = nn.Conv2d(inplanes, inplanes // 16, 42 | kernel_size=1, stride=1) 43 | self.squeeze_bn = nn.BatchNorm2d(inplanes // 16, momentum=self.bn_d) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.unsqueeze = nn.Conv2d(inplanes // 16, inplanes, 46 | kernel_size=1, stride=1) 47 | self.unsqueeze_bn = nn.BatchNorm2d(inplanes, momentum=self.bn_d) 48 | self.sigmoid = nn.Sigmoid() 49 | 50 | def forward(self, x): 51 | # 7x7 pooling 52 | y = self.pool(x) 53 | # squeezing and relu 54 | y = self.relu(self.squeeze_bn(self.squeeze(y))) 55 | # unsqueezing 56 | y = self.sigmoid(self.unsqueeze_bn(self.unsqueeze(y))) 57 | # attention 58 | return y * x 59 | 60 | # ****************************************************************************** 61 | 62 | 63 | class Backbone(nn.Module): 64 | """ 65 | Class for Squeezeseg. 
Subclasses PyTorch's own "nn" module 66 | """ 67 | 68 | def __init__(self, params): 69 | # Call the super constructor 70 | super(Backbone, self).__init__() 71 | print("Using SqueezeNet Backbone") 72 | self.use_range = params["input_depth"]["range"] 73 | self.use_xyz = params["input_depth"]["xyz"] 74 | self.use_remission = params["input_depth"]["remission"] 75 | self.bn_d = params["bn_d"] 76 | self.drop_prob = params["dropout"] 77 | self.OS = params["OS"] 78 | 79 | # input depth calc 80 | self.input_depth = 0 81 | self.input_idxs = [] 82 | if self.use_range: 83 | self.input_depth += 1 84 | self.input_idxs.append(0) 85 | if self.use_xyz: 86 | self.input_depth += 3 87 | self.input_idxs.extend([1, 2, 3]) 88 | if self.use_remission: 89 | self.input_depth += 1 90 | self.input_idxs.append(4) 91 | print("Depth of backbone input = ", self.input_depth) 92 | 93 | # stride play 94 | self.strides = [2, 2, 2, 2] 95 | # check current stride 96 | current_os = 1 97 | for s in self.strides: 98 | current_os *= s 99 | print("Original OS: ", current_os) 100 | 101 | # make the new stride 102 | if self.OS > current_os: 103 | print("Can't do OS, ", self.OS, 104 | " because it is bigger than original ", current_os) 105 | else: 106 | # redo strides according to needed stride 107 | for i, stride in enumerate(reversed(self.strides), 0): 108 | if int(current_os) != self.OS: 109 | if stride == 2: 110 | current_os /= 2 111 | self.strides[-1 - i] = 1 112 | if int(current_os) == self.OS: 113 | break 114 | print("New OS: ", int(current_os)) 115 | print("Strides: ", self.strides) 116 | 117 | # encoder 118 | self.conv1a = nn.Sequential(nn.Conv2d(self.input_depth, 64, kernel_size=3, 119 | stride=[1, self.strides[0]], 120 | padding=1), 121 | nn.BatchNorm2d(64, momentum=self.bn_d), 122 | nn.ReLU(inplace=True), 123 | CAM(64, bn_d=self.bn_d)) 124 | self.conv1b = nn.Sequential(nn.Conv2d(self.input_depth, 64, kernel_size=1, 125 | stride=1, padding=0), 126 | nn.BatchNorm2d(64, momentum=self.bn_d)) 127 | self.fire23 = nn.Sequential(nn.MaxPool2d(kernel_size=3, 128 | stride=[1, self.strides[1]], 129 | padding=1), 130 | Fire(64, 16, 64, 64, bn_d=self.bn_d), 131 | CAM(128, bn_d=self.bn_d), 132 | Fire(128, 16, 64, 64, bn_d=self.bn_d), 133 | CAM(128, bn_d=self.bn_d)) 134 | self.fire45 = nn.Sequential(nn.MaxPool2d(kernel_size=3, 135 | stride=[1, self.strides[2]], 136 | padding=1), 137 | Fire(128, 32, 128, 128, bn_d=self.bn_d), 138 | Fire(256, 32, 128, 128, bn_d=self.bn_d)) 139 | self.fire6789 = nn.Sequential(nn.MaxPool2d(kernel_size=3, 140 | stride=[1, self.strides[3]], 141 | padding=1), 142 | Fire(256, 48, 192, 192, bn_d=self.bn_d), 143 | Fire(384, 48, 192, 192, bn_d=self.bn_d), 144 | Fire(384, 64, 256, 256, bn_d=self.bn_d), 145 | Fire(512, 64, 256, 256, bn_d=self.bn_d)) 146 | 147 | # output 148 | self.dropout = nn.Dropout2d(self.drop_prob) 149 | 150 | # last channels 151 | self.last_channels = 512 152 | 153 | def run_layer(self, x, layer, skips, os): 154 | y = layer(x) 155 | if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]: 156 | skips[os] = x.detach() 157 | os *= 2 158 | x = y 159 | return x, skips, os 160 | 161 | def forward(self, x): 162 | # filter input 163 | x = x[:, self.input_idxs] 164 | 165 | # run cnn 166 | # store for skip connections 167 | skips = {} 168 | os = 1 169 | 170 | # encoder 171 | skip_in = self.conv1b(x) 172 | x = self.conv1a(x) 173 | # first skip done manually 174 | skips[1] = skip_in.detach() 175 | os *= 2 176 | 177 | x, skips, os = self.run_layer(x, self.fire23, skips, os) 178 | x, skips, os = 
self.run_layer(x, self.dropout, skips, os) 179 | x, skips, os = self.run_layer(x, self.fire45, skips, os) 180 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 181 | x, skips, os = self.run_layer(x, self.fire6789, skips, os) 182 | x, skips, os = self.run_layer(x, self.dropout, skips, os) 183 | 184 | return x, skips 185 | 186 | def get_last_depth(self): 187 | return self.last_channels 188 | 189 | def get_input_depth(self): 190 | return self.input_depth 191 | -------------------------------------------------------------------------------- /semantic_net/rangenet/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/common/__init__.py -------------------------------------------------------------------------------- /semantic_net/rangenet/common/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/common/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/common/__pycache__/laserscan.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/common/__pycache__/laserscan.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/common/laserscan.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | import numpy as np 4 | 5 | 6 | class LaserScan: 7 | """Class that contains LaserScan with x,y,z,r""" 8 | EXTENSIONS_SCAN = ['.bin'] 9 | 10 | def __init__(self, project=False, H=64, W=1024, fov_up=3.0, fov_down=-25.0): 11 | self.project = project 12 | self.proj_H = H 13 | self.proj_W = W 14 | self.proj_fov_up = fov_up 15 | self.proj_fov_down = fov_down 16 | self.reset() 17 | 18 | def reset(self): 19 | """ Reset scan members. 
""" 20 | self.points = np.zeros((0, 3), dtype=np.float32) # [m, 3]: x, y, z 21 | self.remissions = np.zeros((0, 1), dtype=np.float32) # [m ,1]: remission 22 | 23 | # projected range image - [H,W] range (-1 is no data) 24 | self.proj_range = np.full((self.proj_H, self.proj_W), -1, 25 | dtype=np.float32) 26 | 27 | # unprojected range (list of depths for each point) 28 | self.unproj_range = np.zeros((0, 1), dtype=np.float32) 29 | 30 | # projected point cloud xyz - [H,W,3] xyz coord (-1 is no data) 31 | self.proj_xyz = np.full((self.proj_H, self.proj_W, 3), -1, 32 | dtype=np.float32) 33 | 34 | # projected remission - [H,W] intensity (-1 is no data) 35 | self.proj_remission = np.full((self.proj_H, self.proj_W), -1, 36 | dtype=np.float32) 37 | 38 | # projected index (for each pixel, what I am in the pointcloud) 39 | # [H,W] index (-1 is no data) 40 | self.proj_idx = np.full((self.proj_H, self.proj_W), -1, 41 | dtype=np.int32) 42 | 43 | # for each point, where it is in the range image 44 | self.proj_x = np.zeros((0, 1), dtype=np.int32) # [m, 1]: x 45 | self.proj_y = np.zeros((0, 1), dtype=np.int32) # [m, 1]: y 46 | 47 | # mask containing for each pixel, if it contains a point or not 48 | self.proj_mask = np.zeros((self.proj_H, self.proj_W), 49 | dtype=np.int32) # [H,W] mask 50 | 51 | def size(self): 52 | """ Return the size of the point cloud. """ 53 | return self.points.shape[0] 54 | 55 | def __len__(self): 56 | return self.size() 57 | 58 | def open_scan(self, filename): 59 | """ Open raw scan and fill in attributes 60 | """ 61 | # reset just in case there was an open structure 62 | self.reset() 63 | 64 | # check filename is string 65 | if not isinstance(filename, str): 66 | raise TypeError("Filename should be string type, " 67 | "but was {type}".format(type=str(type(filename)))) 68 | 69 | # check extension is a laserscan 70 | if not any(filename.endswith(ext) for ext in self.EXTENSIONS_SCAN): 71 | raise RuntimeError("Filename extension is not valid scan file.") 72 | 73 | # if all goes well, open pointcloud 74 | scan = np.fromfile(filename, dtype=np.float32) 75 | scan = scan.reshape((-1, 4)) 76 | 77 | # put in attribute 78 | points = scan[:, 0:3] # get xyz 79 | remissions = scan[:, 3] # get remission 80 | self.set_points(points, remissions) 81 | 82 | def set_points(self, points, remissions=None): 83 | """ Set scan attributes (instead of opening from file) 84 | """ 85 | # reset just in case there was an open structure 86 | self.reset() 87 | 88 | # check scan makes sense 89 | if not isinstance(points, np.ndarray): 90 | raise TypeError("Scan should be numpy array") 91 | 92 | # check remission makes sense 93 | if remissions is not None and not isinstance(remissions, np.ndarray): 94 | raise TypeError("Remissions should be numpy array") 95 | 96 | # put in attribute 97 | self.points = points # get xyz 98 | if remissions is not None: 99 | self.remissions = remissions # get remission 100 | else: 101 | self.remissions = np.zeros((points.shape[0]), dtype=np.float32) 102 | 103 | # if projection is wanted, then do it and fill in the structure 104 | if self.project: 105 | self.do_range_projection() 106 | 107 | def do_range_projection(self): 108 | """ Project a pointcloud into a spherical projection image.projection. 
109 | Function takes no arguments because it can be also called externally 110 | if the value of the constructor was not set (in case you change your 111 | mind about wanting the projection) 112 | """ 113 | # laser parameters 114 | fov_up = self.proj_fov_up / 180.0 * np.pi # field of view up in rad 115 | fov_down = self.proj_fov_down / 180.0 * np.pi # field of view down in rad 116 | fov = abs(fov_down) + abs(fov_up) # get field of view total in rad 117 | 118 | # get depth of all points 119 | depth = np.linalg.norm(self.points, 2, axis=1) 120 | 121 | # get scan components 122 | scan_x = self.points[:, 0] 123 | scan_y = self.points[:, 1] 124 | scan_z = self.points[:, 2] 125 | 126 | # get angles of all points 127 | yaw = -np.arctan2(scan_y, scan_x) 128 | pitch = np.arcsin(scan_z / depth) 129 | 130 | # get projections in image coords 131 | proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0] 132 | proj_y = 1.0 - (pitch + abs(fov_down)) / fov # in [0.0, 1.0] 133 | 134 | # scale to image size using angular resolution 135 | proj_x *= self.proj_W # in [0.0, W] 136 | proj_y *= self.proj_H # in [0.0, H] 137 | 138 | # round and clamp for use as index 139 | proj_x = np.floor(proj_x) 140 | proj_x = np.minimum(self.proj_W - 1, proj_x) 141 | proj_x = np.maximum(0, proj_x).astype(np.int32) # in [0,W-1] 142 | self.proj_x = np.copy(proj_x) # store a copy in orig order 143 | 144 | proj_y = np.floor(proj_y) 145 | proj_y = np.minimum(self.proj_H - 1, proj_y) 146 | proj_y = np.maximum(0, proj_y).astype(np.int32) # in [0,H-1] 147 | self.proj_y = np.copy(proj_y) # stope a copy in original order 148 | 149 | # copy of depth in original order 150 | self.unproj_range = np.copy(depth) 151 | 152 | # order in decreasing depth 153 | indices = np.arange(depth.shape[0]) 154 | order = np.argsort(depth)[::-1] 155 | depth = depth[order] 156 | indices = indices[order] 157 | points = self.points[order] 158 | remission = self.remissions[order] 159 | proj_y = proj_y[order] 160 | proj_x = proj_x[order] 161 | 162 | # assing to images 163 | self.proj_range[proj_y, proj_x] = depth 164 | self.proj_xyz[proj_y, proj_x] = points 165 | self.proj_remission[proj_y, proj_x] = remission 166 | self.proj_idx[proj_y, proj_x] = indices 167 | self.proj_mask = (self.proj_idx > 0).astype(np.int32) 168 | 169 | 170 | class SemLaserScan(LaserScan): 171 | """Class that contains LaserScan with x,y,z,r,sem_label,sem_color_label,inst_label,inst_color_label""" 172 | EXTENSIONS_LABEL = ['.label'] 173 | 174 | def __init__(self, sem_color_dict=None, project=False, H=64, W=1024, fov_up=3.0, fov_down=-25.0, max_classes=300): 175 | super(SemLaserScan, self).__init__(project, H, W, fov_up, fov_down) 176 | self.reset() 177 | 178 | # make semantic colors 179 | if sem_color_dict: 180 | # if I have a dict, make it 181 | max_sem_key = 0 182 | for key, data in sem_color_dict.items(): 183 | if key + 1 > max_sem_key: 184 | max_sem_key = key + 1 185 | self.sem_color_lut = np.zeros((max_sem_key + 100, 3), dtype=np.float32) 186 | for key, value in sem_color_dict.items(): 187 | self.sem_color_lut[key] = np.array(value, np.float32) / 255.0 188 | else: 189 | # otherwise make random 190 | max_sem_key = max_classes 191 | self.sem_color_lut = np.random.uniform(low=0.0, 192 | high=1.0, 193 | size=(max_sem_key, 3)) 194 | # force zero to a gray-ish color 195 | self.sem_color_lut[0] = np.full((3), 0.1) 196 | 197 | # make instance colors 198 | max_inst_id = 100000 199 | self.inst_color_lut = np.random.uniform(low=0.0, 200 | high=1.0, 201 | size=(max_inst_id, 3)) 202 | # force zero to a 
gray-ish color 203 | self.inst_color_lut[0] = np.full((3), 0.1) 204 | 205 | def reset(self): 206 | """ Reset scan members. """ 207 | super(SemLaserScan, self).reset() 208 | 209 | # semantic labels 210 | self.sem_label = np.zeros((0, 1), dtype=np.int32) # [m, 1]: label 211 | self.sem_label_color = np.zeros((0, 3), dtype=np.float32) # [m ,3]: color 212 | 213 | # instance labels 214 | self.inst_label = np.zeros((0, 1), dtype=np.int32) # [m, 1]: label 215 | self.inst_label_color = np.zeros((0, 3), dtype=np.float32) # [m ,3]: color 216 | 217 | # projection color with semantic labels 218 | self.proj_sem_label = np.zeros((self.proj_H, self.proj_W), 219 | dtype=np.int32) # [H,W] label 220 | self.proj_sem_color = np.zeros((self.proj_H, self.proj_W, 3), 221 | dtype=np.float) # [H,W,3] color 222 | 223 | # projection color with instance labels 224 | self.proj_inst_label = np.zeros((self.proj_H, self.proj_W), 225 | dtype=np.int32) # [H,W] label 226 | self.proj_inst_color = np.zeros((self.proj_H, self.proj_W, 3), 227 | dtype=np.float) # [H,W,3] color 228 | 229 | def open_label(self, filename): 230 | """ Open raw scan and fill in attributes 231 | """ 232 | # check filename is string 233 | if not isinstance(filename, str): 234 | raise TypeError("Filename should be string type, " 235 | "but was {type}".format(type=str(type(filename)))) 236 | 237 | # check extension is a laserscan 238 | if not any(filename.endswith(ext) for ext in self.EXTENSIONS_LABEL): 239 | raise RuntimeError("Filename extension is not valid label file.") 240 | 241 | # if all goes well, open label 242 | label = np.fromfile(filename, dtype=np.int32) 243 | label = label.reshape((-1)) 244 | 245 | # set it 246 | self.set_label(label) 247 | 248 | def set_label(self, label): 249 | """ Set points for label not from file but from np 250 | """ 251 | # check label makes sense 252 | if not isinstance(label, np.ndarray): 253 | raise TypeError("Label should be numpy array") 254 | 255 | # only fill in attribute if the right size 256 | if label.shape[0] == self.points.shape[0]: 257 | self.sem_label = label & 0xFFFF # semantic label in lower half 258 | self.inst_label = label >> 16 # instance id in upper half 259 | else: 260 | print("Points shape: ", self.points.shape) 261 | print("Label shape: ", label.shape) 262 | raise ValueError("Scan and Label don't contain same number of points") 263 | 264 | # sanity check 265 | assert((self.sem_label + (self.inst_label << 16) == label).all()) 266 | 267 | if self.project: 268 | self.do_label_projection() 269 | 270 | def colorize(self): 271 | """ Colorize pointcloud with the color of each semantic label 272 | """ 273 | self.sem_label_color = self.sem_color_lut[self.sem_label] 274 | self.sem_label_color = self.sem_label_color.reshape((-1, 3)) 275 | 276 | self.inst_label_color = self.inst_color_lut[self.inst_label] 277 | self.inst_label_color = self.inst_label_color.reshape((-1, 3)) 278 | 279 | def do_label_projection(self): 280 | # only map colors to labels that exist 281 | mask = self.proj_idx >= 0 282 | 283 | # semantics 284 | self.proj_sem_label[mask] = self.sem_label[self.proj_idx[mask]] 285 | self.proj_sem_color[mask] = self.sem_color_lut[self.sem_label[self.proj_idx[mask]]] 286 | 287 | # instances 288 | self.proj_inst_label[mask] = self.inst_label[self.proj_idx[mask]] 289 | self.proj_inst_color[mask] = self.inst_color_lut[self.inst_label[self.proj_idx[mask]]] 290 | -------------------------------------------------------------------------------- /semantic_net/rangenet/requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy==1.14.0 2 | scipy==0.19.1 3 | torch==1.1.0 4 | tensorflow==1.13.1 5 | vispy==0.5.3 6 | torchvision==0.2.2.post3 7 | opencv_contrib_python==4.1.0.25 8 | matplotlib==2.2.3 9 | Pillow==6.1.0 10 | PyYAML==5.1.1 11 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/__init__.py -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/README.md: -------------------------------------------------------------------------------- 1 | # LiDAR-Bonnetal Semantic Segmentation Training 2 | 3 | This part of the framework deals with the training of semantic segmentation networks for point cloud data using range images. This code allows you to reproduce the experiments from the [RangeNet++](http://www.ipb.uni-bonn.de/wp-content/papercite-data/pdf/milioto2019iros.pdf) paper. 4 | 5 | _Examples of segmentation results from [SemanticKITTI](http://semantic-kitti.org) dataset:_ 6 | ![ptcl](../../../pics/semantic-ptcl.gif) 7 | ![ptcl](../../../pics/semantic-proj.gif) 8 | 9 | ## Configuration files 10 | 11 | Architecture configuration files are located at [config/arch](config/arch/) 12 | Dataset configuration files are located at [config/labels](config/labels/) 13 | 14 | ## Apps 15 | 16 | `ALL SCRIPTS CAN BE INVOKED WITH -h TO GET EXTRA HELP ON HOW TO RUN THEM` 17 | 18 | ### Visualization 19 | 20 | To visualize the data (in this example sequence 00): 21 | 22 | ```sh 23 | $ ./visualize.py -d /path/to/dataset/ -s 00 24 | ``` 25 | 26 | To visualize the predictions (in this example sequence 00): 27 | 28 | ```sh 29 | $ ./visualize.py -d /path/to/dataset/ -p /path/to/predictions/ -s 00 30 | ``` 31 | 32 | ### Training 33 | 34 | To train a network (from scratch): 35 | 36 | ```sh 37 | $ ./train.py -d /path/to/dataset -ac /config/arch/CHOICE.yaml -l /path/to/log 38 | ``` 39 | 40 | To train a network (from a pretrained model): 41 | 42 | ```sh 43 | $ ./train.py -d /path/to/dataset -ac /config/arch/CHOICE.yaml -dc /config/labels/CHOICE.yaml -l /path/to/log -p /path/to/pretrained 44 | ``` 45 | 46 | This will generate a TensorBoard log, which can be visualized by running: 47 | 48 | ```sh 49 | $ cd /path/to/log 50 | $ tensorboard --logdir=. --port 5555 51 | ``` 52 | 53 | Then access [http://localhost:5555](http://localhost:5555) in your browser.
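For instance, a complete fine-tuning invocation might look like the following. This is only an illustrative sketch: the dataset path, log directory, and the `darknet53` architecture config and pretrained-model locations are placeholder names, not values prescribed by this README.

```sh
# Illustrative example: fine-tune from a downloaded pretrained model.
# All paths and the darknet53 config name are assumptions for this sketch.
$ ./train.py -d ~/datasets/semantic-kitti \
             -ac config/arch/darknet53.yaml \
             -dc config/labels/semantic-kitti.yaml \
             -l ~/logs/darknet53-finetune \
             -p ~/models/darknet53
```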
54 | 55 | ### Inference 56 | 57 | To infer the predictions for the entire dataset: 58 | 59 | ```sh 60 | $ ./infer.py -d /path/to/dataset/ -l /path/for/predictions -m /path/to/model 61 | ``` 62 | 63 | ### Evaluation 64 | 65 | To evaluate the overall IoU of the point clouds (for a specific split, which for SemanticKITTI can only be train or valid, since the test split is evaluated only on our benchmark server): 66 | 67 | ```sh 68 | $ ./evaluate_iou.py -d /path/to/dataset -p /path/to/predictions/ --split valid 69 | ``` 70 | 71 | To evaluate the border IoU of the point clouds (introduced in the RangeNet++ paper): 72 | 73 | ```sh 74 | $ ./evaluate_biou.py -d /path/to/dataset -p /path/to/predictions/ --split valid --border 1 --conn 4 75 | ``` 76 | 77 | ## Pre-trained Models 78 | 79 | ### [SemanticKITTI](http://semantic-kitti.org) 80 | 81 | - [squeezeseg](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/squeezeseg.tar.gz) 82 | - [squeezeseg + crf](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/squeezeseg-crf.tar.gz) 83 | - [squeezesegV2](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/squeezesegV2.tar.gz) 84 | - [squeezesegV2 + crf](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/squeezesegV2-crf.tar.gz) 85 | - [darknet21](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/darknet21.tar.gz) 86 | - [darknet53](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/darknet53.tar.gz) 87 | - [darknet53-1024](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/darknet53-1024.tar.gz) 88 | - [darknet53-512](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/darknet53-512.tar.gz) 89 | 90 | To enable kNN post-processing, set the corresponding boolean parameter to `True` in the `arch_cfg.yaml` file inside the model directory. 91 | 92 | ## Predictions from Models 93 | 94 | ### [SemanticKITTI](http://semantic-kitti.org) 95 | 96 | These are the predictions for the train, validation, and test sets. The performance can be evaluated for the training and validation sets, but test-set evaluation requires a submission to the benchmark (labels are not public).
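As a quick sanity check of a download, one of the prediction archives listed below can be extracted and scored on the validation split with the evaluation script from the Evaluation section. This is a minimal sketch: the extraction and dataset paths are placeholders, and `-p` must point at the folder that holds the extracted predictions.

```sh
# Illustrative example: evaluate downloaded darknet53-knn predictions on the valid split.
# The /path/to/... locations are placeholders for this sketch.
$ tar -xzvf darknet53-knn.tar.gz -C /path/to/predictions/
$ ./evaluate_iou.py -d /path/to/dataset -p /path/to/predictions/ --split valid
```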
97 | 98 | No post-processing: 99 | - [squeezeseg](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezeseg.tar.gz) 100 | - [squeezeseg + crf](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezeseg-crf.tar.gz) 101 | - [squeezesegV2](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezesegV2.tar.gz) 102 | - [squeezesegV2 + crf](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezesegV2-crf.tar.gz) 103 | - [darknet21](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet21.tar.gz) 104 | - [darknet53](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53.tar.gz) 105 | - [darknet53-1024](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-1024.tar.gz) 106 | - [darknet53-512](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-512.tar.gz) 107 | 108 | With k-NN processing: 109 | - [squeezeseg](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezeseg-knn.tar.gz) 110 | - [squeezesegV2](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezesegV2-knn.tar.gz) 111 | - [darknet53](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-knn.tar.gz) 112 | - [darknet21](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet21-knn.tar.gz) 113 | - [darknet53-1024](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-1024-knn.tar.gz) 114 | - [darknet53-512](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-512-knn.tar.gz) 115 | 116 | ## Citations 117 | 118 | If you use our framework, model, or predictions for any academic work, please cite the original [paper](http://www.ipb.uni-bonn.de/wp-content/papercite-data/pdf/milioto2019iros.pdf), and the [dataset](http://semantic-kitti.org). 119 | 120 | ``` 121 | @inproceedings{milioto2019iros, 122 | author = {A. Milioto and I. Vizzo and J. Behley and C. Stachniss}, 123 | title = {{RangeNet++: Fast and Accurate LiDAR Semantic Segmentation}}, 124 | booktitle = {IEEE/RSJ Intl.~Conf.~on Intelligent Robots and Systems (IROS)}, 125 | year = 2019, 126 | codeurl = {https://github.com/PRBonn/lidar-bonnetal}, 127 | videourl = {https://youtu.be/wuokg7MFZyU}, 128 | } 129 | ``` 130 | 131 | ``` 132 | @inproceedings{behley2019iccv, 133 | author = {J. Behley and M. Garbade and A. Milioto and J. Quenzel and S. Behnke and C. Stachniss and J. Gall}, 134 | title = {{SemanticKITTI: A Dataset for Semantic Scene Understanding of LiDAR Sequences}}, 135 | booktitle = {Proc. 
of the IEEE/CVF International Conf.~on Computer Vision (ICCV)}, 136 | year = {2019} 137 | } 138 | ``` -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | TRAIN_PATH = "../../" 3 | DEPLOY_PATH = "../../../deploy" 4 | sys.path.insert(0, TRAIN_PATH) 5 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/config/labels/semantic-kitti-all.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | name: "kitti" 3 | labels: 4 | 0 : "unlabeled" 5 | 1 : "outlier" 6 | 10: "car" 7 | 11: "bicycle" 8 | 13: "bus" 9 | 15: "motorcycle" 10 | 16: "on-rails" 11 | 18: "truck" 12 | 20: "other-vehicle" 13 | 30: "person" 14 | 31: "bicyclist" 15 | 32: "motorcyclist" 16 | 40: "road" 17 | 44: "parking" 18 | 48: "sidewalk" 19 | 49: "other-ground" 20 | 50: "building" 21 | 51: "fence" 22 | 52: "other-structure" 23 | 60: "lane-marking" 24 | 70: "vegetation" 25 | 71: "trunk" 26 | 72: "terrain" 27 | 80: "pole" 28 | 81: "traffic-sign" 29 | 99: "other-object" 30 | 252: "moving-car" 31 | 253: "moving-bicyclist" 32 | 254: "moving-person" 33 | 255: "moving-motorcyclist" 34 | 256: "moving-on-rails" 35 | 257: "moving-bus" 36 | 258: "moving-truck" 37 | 259: "moving-other-vehicle" 38 | color_map: # bgr 39 | 0 : [0, 0, 0] 40 | 1 : [0, 0, 255] 41 | 10: [245, 150, 100] 42 | 11: [245, 230, 100] 43 | 13: [250, 80, 100] 44 | 15: [150, 60, 30] 45 | 16: [255, 0, 0] 46 | 18: [180, 30, 80] 47 | 20: [255, 0, 0] 48 | 30: [30, 30, 255] 49 | 31: [200, 40, 255] 50 | 32: [90, 30, 150] 51 | 40: [255, 0, 255] 52 | 44: [255, 150, 255] 53 | 48: [75, 0, 75] 54 | 49: [75, 0, 175] 55 | 50: [0, 200, 255] 56 | 51: [50, 120, 255] 57 | 52: [0, 150, 255] 58 | 60: [170, 255, 150] 59 | 70: [0, 175, 0] 60 | 71: [0, 60, 135] 61 | 72: [80, 240, 150] 62 | 80: [150, 240, 255] 63 | 81: [0, 0, 255] 64 | 99: [255, 255, 50] 65 | 252: [245, 150, 100] 66 | 256: [255, 0, 0] 67 | 253: [200, 40, 255] 68 | 254: [30, 30, 255] 69 | 255: [90, 30, 150] 70 | 257: [250, 80, 100] 71 | 258: [180, 30, 80] 72 | 259: [255, 0, 0] 73 | content: # as a ratio with the total number of points 74 | 0: 0.018889854628292943 75 | 1: 0.0002937197336781505 76 | 10: 0.040818519255974316 77 | 11: 0.00016609538710764618 78 | 13: 2.7879693665067774e-05 79 | 15: 0.00039838616015114444 80 | 16: 0.0 81 | 18: 0.0020633612104619787 82 | 20: 0.0016218197275284021 83 | 30: 0.00017698551338515307 84 | 31: 1.1065903904919655e-08 85 | 32: 5.532951952459828e-09 86 | 40: 0.1987493871255525 87 | 44: 0.014717169549888214 88 | 48: 0.14392298360372 89 | 49: 0.0039048553037472045 90 | 50: 0.1326861944777486 91 | 51: 0.0723592229456223 92 | 52: 0.002395131480328884 93 | 60: 4.7084144280367186e-05 94 | 70: 0.26681502148037506 95 | 71: 0.006035012012626033 96 | 72: 0.07814222006271769 97 | 80: 0.002855498193863172 98 | 81: 0.0006155958086189918 99 | 99: 
0.009923127583046915 100 | 252: 0.001789309418528068 101 | 253: 0.00012709999297008662 102 | 254: 0.00016059776092534436 103 | 255: 3.745553104802113e-05 104 | 256: 0.0 105 | 257: 0.00011351574470342043 106 | 258: 0.00010157861367183268 107 | 259: 4.3840131989471124e-05 108 | # classes that are indistinguishable from single scan or inconsistent in 109 | # ground truth are mapped to their closest equivalent 110 | learning_map: 111 | 0 : 0 # "unlabeled" 112 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 113 | 10: 1 # "car" 114 | 11: 2 # "bicycle" 115 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 116 | 15: 3 # "motorcycle" 117 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 118 | 18: 4 # "truck" 119 | 20: 5 # "other-vehicle" 120 | 30: 6 # "person" 121 | 31: 7 # "bicyclist" 122 | 32: 8 # "motorcyclist" 123 | 40: 9 # "road" 124 | 44: 10 # "parking" 125 | 48: 11 # "sidewalk" 126 | 49: 12 # "other-ground" 127 | 50: 13 # "building" 128 | 51: 14 # "fence" 129 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 130 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped 131 | 70: 15 # "vegetation" 132 | 71: 16 # "trunk" 133 | 72: 17 # "terrain" 134 | 80: 18 # "pole" 135 | 81: 19 # "traffic-sign" 136 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 137 | 252: 20 # "moving-car" 138 | 253: 21 # "moving-bicyclist" 139 | 254: 22 # "moving-person" 140 | 255: 23 # "moving-motorcyclist" 141 | 256: 24 # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped 142 | 257: 24 # "moving-bus" mapped to "moving-other-vehicle" -----------mapped 143 | 258: 25 # "moving-truck" 144 | 259: 24 # "moving-other-vehicle" 145 | learning_map_inv: # inverse of previous map 146 | 0: 0 # "unlabeled", and others ignored 147 | 1: 10 # "car" 148 | 2: 11 # "bicycle" 149 | 3: 15 # "motorcycle" 150 | 4: 18 # "truck" 151 | 5: 20 # "other-vehicle" 152 | 6: 30 # "person" 153 | 7: 31 # "bicyclist" 154 | 8: 32 # "motorcyclist" 155 | 9: 40 # "road" 156 | 10: 44 # "parking" 157 | 11: 48 # "sidewalk" 158 | 12: 49 # "other-ground" 159 | 13: 50 # "building" 160 | 14: 51 # "fence" 161 | 15: 70 # "vegetation" 162 | 16: 71 # "trunk" 163 | 17: 72 # "terrain" 164 | 18: 80 # "pole" 165 | 19: 81 # "traffic-sign" 166 | 20: 252 # "moving-car" 167 | 21: 253 # "moving-bicyclist" 168 | 22: 254 # "moving-person" 169 | 23: 255 # "moving-motorcyclist" 170 | 24: 259 # "moving-other-vehicle" 171 | 25: 258 # "moving-truck" 172 | learning_ignore: # Ignore classes 173 | 0: True # "unlabeled", and others ignored 174 | 1: False # "car" 175 | 2: False # "bicycle" 176 | 3: False # "motorcycle" 177 | 4: False # "truck" 178 | 5: False # "other-vehicle" 179 | 6: False # "person" 180 | 7: False # "bicyclist" 181 | 8: False # "motorcyclist" 182 | 9: False # "road" 183 | 10: False # "parking" 184 | 11: False # "sidewalk" 185 | 12: False # "other-ground" 186 | 13: False # "building" 187 | 14: False # "fence" 188 | 15: False # "vegetation" 189 | 16: False # "trunk" 190 | 17: False # "terrain" 191 | 18: False # "pole" 192 | 19: False # "traffic-sign" 193 | 20: False # "moving-car" 194 | 21: False # "moving-bicyclist" 195 | 22: False # "moving-person" 196 | 23: False # "moving-motorcyclist" 197 | 24: False # "moving-other-vehicle" 198 | 25: False # "moving-truck" 199 | split: # sequence numbers 200 | train: 201 | - 0 202 | - 1 203 | - 2 204 | - 3 205 | - 4 206 | - 5 207 | - 6 208 | - 7 209 | - 9 210 | - 10 211 | valid: 
212 | - 8 213 | test: 214 | - 11 215 | - 12 216 | - 13 217 | - 14 218 | - 15 219 | - 16 220 | - 17 221 | - 18 222 | - 19 223 | - 20 224 | - 21 225 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/config/labels/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | name: "kitti" 3 | labels: 4 | 0 : "unlabeled" 5 | 1 : "outlier" 6 | 10: "car" 7 | 11: "bicycle" 8 | 13: "bus" 9 | 15: "motorcycle" 10 | 16: "on-rails" 11 | 18: "truck" 12 | 20: "other-vehicle" 13 | 30: "person" 14 | 31: "bicyclist" 15 | 32: "motorcyclist" 16 | 40: "road" 17 | 44: "parking" 18 | 48: "sidewalk" 19 | 49: "other-ground" 20 | 50: "building" 21 | 51: "fence" 22 | 52: "other-structure" 23 | 60: "lane-marking" 24 | 70: "vegetation" 25 | 71: "trunk" 26 | 72: "terrain" 27 | 80: "pole" 28 | 81: "traffic-sign" 29 | 99: "other-object" 30 | 252: "moving-car" 31 | 253: "moving-bicyclist" 32 | 254: "moving-person" 33 | 255: "moving-motorcyclist" 34 | 256: "moving-on-rails" 35 | 257: "moving-bus" 36 | 258: "moving-truck" 37 | 259: "moving-other-vehicle" 38 | color_map: # bgr 39 | 0 : [0, 0, 0] 40 | 1 : [0, 0, 255] 41 | 10: [245, 150, 100] 42 | 11: [245, 230, 100] 43 | 13: [250, 80, 100] 44 | 15: [150, 60, 30] 45 | 16: [255, 0, 0] 46 | 18: [180, 30, 80] 47 | 20: [255, 0, 0] 48 | 30: [30, 30, 255] 49 | 31: [200, 40, 255] 50 | 32: [90, 30, 150] 51 | 40: [255, 0, 255] 52 | 44: [255, 150, 255] 53 | 48: [75, 0, 75] 54 | 49: [75, 0, 175] 55 | 50: [0, 200, 255] 56 | 51: [50, 120, 255] 57 | 52: [0, 150, 255] 58 | 60: [170, 255, 150] 59 | 70: [0, 175, 0] 60 | 71: [0, 60, 135] 61 | 72: [80, 240, 150] 62 | 80: [150, 240, 255] 63 | 81: [0, 0, 255] 64 | 99: [255, 255, 50] 65 | 252: [245, 150, 100] 66 | 256: [255, 0, 0] 67 | 253: [200, 40, 255] 68 | 254: [30, 30, 255] 69 | 255: [90, 30, 150] 70 | 257: [250, 80, 100] 71 | 258: [180, 30, 80] 72 | 259: [255, 0, 0] 73 | content: # as a ratio with the total number of points 74 | 0: 0.018889854628292943 75 | 1: 0.0002937197336781505 76 | 10: 0.040818519255974316 77 | 11: 0.00016609538710764618 78 | 13: 2.7879693665067774e-05 79 | 15: 0.00039838616015114444 80 | 16: 0.0 81 | 18: 0.0020633612104619787 82 | 20: 0.0016218197275284021 83 | 30: 0.00017698551338515307 84 | 31: 1.1065903904919655e-08 85 | 32: 5.532951952459828e-09 86 | 40: 0.1987493871255525 87 | 44: 0.014717169549888214 88 | 48: 0.14392298360372 89 | 49: 0.0039048553037472045 90 | 50: 0.1326861944777486 91 | 51: 0.0723592229456223 92 | 52: 0.002395131480328884 93 | 60: 4.7084144280367186e-05 94 | 70: 0.26681502148037506 95 | 71: 0.006035012012626033 96 | 72: 0.07814222006271769 97 | 80: 0.002855498193863172 98 | 81: 0.0006155958086189918 99 | 99: 0.009923127583046915 100 | 252: 0.001789309418528068 101 | 253: 0.00012709999297008662 102 | 254: 0.00016059776092534436 103 | 255: 3.745553104802113e-05 104 | 256: 0.0 105 | 257: 0.00011351574470342043 106 | 258: 0.00010157861367183268 107 | 259: 4.3840131989471124e-05 108 | # classes that are indistinguishable from single scan or inconsistent in 109 | # ground truth are mapped to their closest equivalent 110 | learning_map: 111 | 0 : 0 # "unlabeled" 112 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 113 | 10: 1 # "car" 114 | 11: 2 # "bicycle" 115 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 116 | 15: 3 # "motorcycle" 117 | 16: 5 # "on-rails" 
mapped to "other-vehicle" ---------------------mapped 118 | 18: 4 # "truck" 119 | 20: 5 # "other-vehicle" 120 | 30: 6 # "person" 121 | 31: 7 # "bicyclist" 122 | 32: 8 # "motorcyclist" 123 | 40: 9 # "road" 124 | 44: 10 # "parking" 125 | 48: 11 # "sidewalk" 126 | 49: 12 # "other-ground" 127 | 50: 13 # "building" 128 | 51: 14 # "fence" 129 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 130 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped 131 | 70: 15 # "vegetation" 132 | 71: 16 # "trunk" 133 | 72: 17 # "terrain" 134 | 80: 18 # "pole" 135 | 81: 19 # "traffic-sign" 136 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 137 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 138 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 139 | 254: 6 # "moving-person" to "person" ------------------------------mapped 140 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 141 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 142 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 143 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 144 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 145 | learning_map_inv: # inverse of previous map 146 | 0: 0 # "unlabeled", and others ignored 147 | 1: 10 # "car" 148 | 2: 11 # "bicycle" 149 | 3: 15 # "motorcycle" 150 | 4: 18 # "truck" 151 | 5: 20 # "other-vehicle" 152 | 6: 30 # "person" 153 | 7: 31 # "bicyclist" 154 | 8: 32 # "motorcyclist" 155 | 9: 40 # "road" 156 | 10: 44 # "parking" 157 | 11: 48 # "sidewalk" 158 | 12: 49 # "other-ground" 159 | 13: 50 # "building" 160 | 14: 51 # "fence" 161 | 15: 70 # "vegetation" 162 | 16: 71 # "trunk" 163 | 17: 72 # "terrain" 164 | 18: 80 # "pole" 165 | 19: 81 # "traffic-sign" 166 | learning_ignore: # Ignore classes 167 | 0: True # "unlabeled", and others ignored 168 | 1: False # "car" 169 | 2: False # "bicycle" 170 | 3: False # "motorcycle" 171 | 4: False # "truck" 172 | 5: False # "other-vehicle" 173 | 6: False # "person" 174 | 7: False # "bicyclist" 175 | 8: False # "motorcyclist" 176 | 9: False # "road" 177 | 10: False # "parking" 178 | 11: False # "sidewalk" 179 | 12: False # "other-ground" 180 | 13: False # "building" 181 | 14: False # "fence" 182 | 15: False # "vegetation" 183 | 16: False # "trunk" 184 | 17: False # "terrain" 185 | 18: False # "pole" 186 | 19: False # "traffic-sign" 187 | split: # sequence numbers 188 | train: 189 | - 0 190 | - 1 191 | - 2 192 | - 3 193 | - 4 194 | - 5 195 | - 6 196 | - 7 197 | - 9 198 | - 10 199 | valid: 200 | - 8 201 | test: 202 | - 11 203 | - 12 204 | - 13 205 | - 14 206 | - 15 207 | - 16 208 | - 17 209 | - 18 210 | - 19 211 | - 20 212 | - 21 213 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/decoders/__init__.py -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/decoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/decoders/__pycache__/squeezeseg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/squeezeseg.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/decoders/__pycache__/squeezesegV2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/squeezesegV2.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/decoders/darknet.py: -------------------------------------------------------------------------------- 1 | # This file was modified from https://github.com/BobLiu20/YOLOv3_PyTorch 2 | # It needed to be modified in order to accomodate for different strides in the 3 | 4 | import torch.nn as nn 5 | from collections import OrderedDict 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | def __init__(self, inplanes, planes, bn_d=0.1): 11 | super(BasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1, 13 | stride=1, padding=0, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d) 15 | self.relu1 = nn.LeakyReLU(0.1) 16 | self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3, 17 | stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d) 19 | self.relu2 = nn.LeakyReLU(0.1) 20 | 21 | def forward(self, x): 22 | residual = x 23 | 24 | out = self.conv1(x) 25 | out = self.bn1(out) 26 | out = self.relu1(out) 27 | 28 | out = self.conv2(out) 29 | out = self.bn2(out) 30 | out = self.relu2(out) 31 | 32 | out += residual 33 | return out 34 | 35 | 36 | # ****************************************************************************** 37 | 38 | class Decoder(nn.Module): 39 | """ 40 | Class for DarknetSeg. 
Subclasses PyTorch's own "nn" module 41 | """ 42 | 43 | def __init__(self, params, stub_skips, OS=32, feature_depth=1024): 44 | super(Decoder, self).__init__() 45 | self.backbone_OS = OS 46 | self.backbone_feature_depth = feature_depth 47 | self.drop_prob = params["dropout"] 48 | self.bn_d = params["bn_d"] 49 | 50 | # stride play 51 | self.strides = [2, 2, 2, 2, 2] 52 | # check current stride 53 | current_os = 1 54 | for s in self.strides: 55 | current_os *= s 56 | print("Decoder original OS: ", int(current_os)) 57 | # redo strides according to needed stride 58 | for i, stride in enumerate(self.strides): 59 | if int(current_os) != self.backbone_OS: 60 | if stride == 2: 61 | current_os /= 2 62 | self.strides[i] = 1 63 | if int(current_os) == self.backbone_OS: 64 | break 65 | print("Decoder new OS: ", int(current_os)) 66 | print("Decoder strides: ", self.strides) 67 | 68 | # decoder 69 | self.dec5 = self._make_dec_layer(BasicBlock, 70 | [self.backbone_feature_depth, 512], 71 | bn_d=self.bn_d, 72 | stride=self.strides[0]) 73 | self.dec4 = self._make_dec_layer(BasicBlock, [512, 256], bn_d=self.bn_d, 74 | stride=self.strides[1]) 75 | self.dec3 = self._make_dec_layer(BasicBlock, [256, 128], bn_d=self.bn_d, 76 | stride=self.strides[2]) 77 | self.dec2 = self._make_dec_layer(BasicBlock, [128, 64], bn_d=self.bn_d, 78 | stride=self.strides[3]) 79 | self.dec1 = self._make_dec_layer(BasicBlock, [64, 32], bn_d=self.bn_d, 80 | stride=self.strides[4]) 81 | 82 | # layer list to execute with skips 83 | self.layers = [self.dec5, self.dec4, self.dec3, self.dec2, self.dec1] 84 | 85 | # for a bit of fun 86 | self.dropout = nn.Dropout2d(self.drop_prob) 87 | 88 | # last channels 89 | self.last_channels = 32 90 | 91 | def _make_dec_layer(self, block, planes, bn_d=0.1, stride=2): 92 | layers = [] 93 | 94 | # downsample 95 | if stride == 2: 96 | layers.append(("upconv", nn.ConvTranspose2d(planes[0], planes[1], 97 | kernel_size=[1, 4], stride=[1, 2], 98 | padding=[0, 1]))) 99 | else: 100 | layers.append(("conv", nn.Conv2d(planes[0], planes[1], 101 | kernel_size=3, padding=1))) 102 | layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d))) 103 | layers.append(("relu", nn.LeakyReLU(0.1))) 104 | 105 | # blocks 106 | layers.append(("residual", block(planes[1], planes, bn_d))) 107 | 108 | return nn.Sequential(OrderedDict(layers)) 109 | 110 | def run_layer(self, x, layer, skips, os): 111 | feats = layer(x) # up 112 | if feats.shape[-1] > x.shape[-1]: 113 | os //= 2 # match skip 114 | feats = feats + skips[os].detach() # add skip 115 | x = feats 116 | return x, skips, os 117 | 118 | def forward(self, x, skips): 119 | os = self.backbone_OS 120 | 121 | # run layers 122 | x, skips, os = self.run_layer(x, self.dec5, skips, os) 123 | x, skips, os = self.run_layer(x, self.dec4, skips, os) 124 | x, skips, os = self.run_layer(x, self.dec3, skips, os) 125 | x, skips, os = self.run_layer(x, self.dec2, skips, os) 126 | x, skips, os = self.run_layer(x, self.dec1, skips, os) 127 | 128 | x = self.dropout(x) 129 | 130 | return x 131 | 132 | def get_last_depth(self): 133 | return self.last_channels 134 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/decoders/squeezeseg.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/BichenWuUCB/SqueezeSeg 2 | from __future__ import print_function 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class 
FireUp(nn.Module): 9 | 10 | def __init__(self, inplanes, squeeze_planes, 11 | expand1x1_planes, expand3x3_planes, stride): 12 | super(FireUp, self).__init__() 13 | self.inplanes = inplanes 14 | self.stride = stride 15 | self.activation = nn.ReLU(inplace=True) 16 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 17 | if self.stride == 2: 18 | self.upconv = nn.ConvTranspose2d(squeeze_planes, squeeze_planes, 19 | kernel_size=[1, 4], stride=[1, 2], 20 | padding=[0, 1]) 21 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 22 | kernel_size=1) 23 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 24 | kernel_size=3, padding=1) 25 | 26 | def forward(self, x): 27 | x = self.activation(self.squeeze(x)) 28 | if self.stride == 2: 29 | x = self.activation(self.upconv(x)) 30 | return torch.cat([ 31 | self.activation(self.expand1x1(x)), 32 | self.activation(self.expand3x3(x)) 33 | ], 1) 34 | 35 | 36 | # ****************************************************************************** 37 | 38 | class Decoder(nn.Module): 39 | """ 40 | Class for DarknetSeg. Subclasses PyTorch's own "nn" module 41 | """ 42 | 43 | def __init__(self, params, stub_skips, OS=32, feature_depth=512): 44 | super(Decoder, self).__init__() 45 | self.backbone_OS = OS 46 | self.backbone_feature_depth = feature_depth 47 | self.drop_prob = params["dropout"] 48 | 49 | # stride play 50 | self.strides = [2, 2, 2, 2] 51 | # check current stride 52 | current_os = 1 53 | for s in self.strides: 54 | current_os *= s 55 | print("Decoder original OS: ", int(current_os)) 56 | # redo strides according to needed stride 57 | for i, stride in enumerate(self.strides): 58 | if int(current_os) != self.backbone_OS: 59 | if stride == 2: 60 | current_os /= 2 61 | self.strides[i] = 1 62 | if int(current_os) == self.backbone_OS: 63 | break 64 | print("Decoder new OS: ", int(current_os)) 65 | print("Decoder strides: ", self.strides) 66 | 67 | # decoder 68 | # decoder 69 | self.firedec10 = FireUp(self.backbone_feature_depth, 64, 128, 128, 70 | stride=self.strides[0]) 71 | self.firedec11 = FireUp(256, 32, 64, 64, 72 | stride=self.strides[1]) 73 | self.firedec12 = FireUp(128, 16, 32, 32, 74 | stride=self.strides[2]) 75 | self.firedec13 = FireUp(64, 16, 32, 32, 76 | stride=self.strides[3]) 77 | 78 | # layer list to execute with skips 79 | self.layers = [self.firedec10, self.firedec11, 80 | self.firedec12, self.firedec13] 81 | 82 | # for a bit of fun 83 | self.dropout = nn.Dropout2d(self.drop_prob) 84 | 85 | # last channels 86 | self.last_channels = 64 87 | 88 | def run_layer(self, x, layer, skips, os): 89 | feats = layer(x) # up 90 | if feats.shape[-1] > x.shape[-1]: 91 | os //= 2 # match skip 92 | feats = feats + skips[os].detach() # add skip 93 | x = feats 94 | return x, skips, os 95 | 96 | def forward(self, x, skips): 97 | os = self.backbone_OS 98 | 99 | # run layers 100 | x, skips, os = self.run_layer(x, self.firedec10, skips, os) 101 | x, skips, os = self.run_layer(x, self.firedec11, skips, os) 102 | x, skips, os = self.run_layer(x, self.firedec12, skips, os) 103 | x, skips, os = self.run_layer(x, self.firedec13, skips, os) 104 | 105 | x = self.dropout(x) 106 | 107 | return x 108 | 109 | def get_last_depth(self): 110 | return self.last_channels 111 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/decoders/squeezesegV2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is 
covered by the LICENSE file in the root of this project. 3 | 4 | from __future__ import print_function 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class FireUp(nn.Module): 11 | 12 | def __init__(self, inplanes, squeeze_planes, 13 | expand1x1_planes, expand3x3_planes, bn_d, stride): 14 | super(FireUp, self).__init__() 15 | self.inplanes = inplanes 16 | self.bn_d = bn_d 17 | self.stride = stride 18 | self.activation = nn.ReLU(inplace=True) 19 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 20 | self.squeeze_bn = nn.BatchNorm2d(squeeze_planes, momentum=self.bn_d) 21 | if self.stride == 2: 22 | self.upconv = nn.ConvTranspose2d(squeeze_planes, squeeze_planes, 23 | kernel_size=[1, 4], stride=[1, 2], 24 | padding=[0, 1]) 25 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 26 | kernel_size=1) 27 | self.expand1x1_bn = nn.BatchNorm2d(expand1x1_planes, momentum=self.bn_d) 28 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 29 | kernel_size=3, padding=1) 30 | self.expand3x3_bn = nn.BatchNorm2d(expand3x3_planes, momentum=self.bn_d) 31 | 32 | def forward(self, x): 33 | x = self.activation(self.squeeze_bn(self.squeeze(x))) 34 | if self.stride == 2: 35 | x = self.activation(self.upconv(x)) 36 | return torch.cat([ 37 | self.activation(self.expand1x1_bn(self.expand1x1(x))), 38 | self.activation(self.expand3x3_bn(self.expand3x3(x))) 39 | ], 1) 40 | 41 | 42 | # ****************************************************************************** 43 | 44 | class Decoder(nn.Module): 45 | """ 46 | Class for DarknetSeg. Subclasses PyTorch's own "nn" module 47 | """ 48 | 49 | def __init__(self, params, stub_skips, OS=32, feature_depth=512): 50 | super(Decoder, self).__init__() 51 | self.backbone_OS = OS 52 | self.backbone_feature_depth = feature_depth 53 | self.drop_prob = params["dropout"] 54 | self.bn_d = params["bn_d"] 55 | 56 | # stride play 57 | self.strides = [2, 2, 2, 2] 58 | # check current stride 59 | current_os = 1 60 | for s in self.strides: 61 | current_os *= s 62 | print("Decoder original OS: ", int(current_os)) 63 | # redo strides according to needed stride 64 | for i, stride in enumerate(self.strides): 65 | if int(current_os) != self.backbone_OS: 66 | if stride == 2: 67 | current_os /= 2 68 | self.strides[i] = 1 69 | if int(current_os) == self.backbone_OS: 70 | break 71 | print("Decoder new OS: ", int(current_os)) 72 | print("Decoder strides: ", self.strides) 73 | 74 | # decoder 75 | # decoder 76 | self.firedec10 = FireUp(self.backbone_feature_depth, 77 | 64, 128, 128, bn_d=self.bn_d, 78 | stride=self.strides[0]) 79 | self.firedec11 = FireUp(256, 32, 64, 64, bn_d=self.bn_d, 80 | stride=self.strides[1]) 81 | self.firedec12 = FireUp(128, 16, 32, 32, bn_d=self.bn_d, 82 | stride=self.strides[2]) 83 | self.firedec13 = FireUp(64, 16, 32, 32, bn_d=self.bn_d, 84 | stride=self.strides[3]) 85 | 86 | # for a bit of fun 87 | self.dropout = nn.Dropout2d(self.drop_prob) 88 | 89 | # last channels 90 | self.last_channels = 64 91 | 92 | def run_layer(self, x, layer, skips, os): 93 | feats = layer(x) # up 94 | if feats.shape[-1] > x.shape[-1]: 95 | os //= 2 # match skip 96 | feats = feats + skips[os].detach() # add skip 97 | x = feats 98 | return x, skips, os 99 | 100 | def forward(self, x, skips): 101 | os = self.backbone_OS 102 | 103 | # run layers 104 | x, skips, os = self.run_layer(x, self.firedec10, skips, os) 105 | x, skips, os = self.run_layer(x, self.firedec11, skips, os) 106 | x, skips, os = self.run_layer(x, 
self.firedec12, skips, os) 107 | x, skips, os = self.run_layer(x, self.firedec13, skips, os) 108 | 109 | x = self.dropout(x) 110 | 111 | return x 112 | 113 | def get_last_depth(self): 114 | return self.last_channels 115 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/modules/__init__.py -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/modules/__pycache__/segmentator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/modules/__pycache__/segmentator.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/modules/segmentator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | import imp 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from semantic_net.rangenet.tasks.semantic.postproc.CRF import CRF 9 | import importlib 10 | # import __init__ as booger 11 | 12 | 13 | class Segmentator(nn.Module): 14 | def __init__(self, ARCH, nclasses, path=None, path_append="", strict=False): 15 | super().__init__() 16 | self.ARCH = ARCH 17 | self.nclasses = nclasses 18 | self.path = path 19 | self.path_append = path_append 20 | self.strict = False 21 | 22 | # get the model 23 | # bboneModule = imp.load_source("bboneModule", 24 | # booger.TRAIN_PATH + '/backbones/' + 25 | # self.ARCH["backbone"]["name"] + '.py') 26 | backbone_path = "semantic_net.rangenet.backbones." + self.ARCH["backbone"]["name"] 27 | bboneModule = importlib.import_module(backbone_path) 28 | self.backbone = bboneModule.Backbone(params=self.ARCH["backbone"]) 29 | 30 | # do a pass of the backbone to initialize the skip connections 31 | stub = torch.zeros((1, 32 | self.backbone.get_input_depth(), 33 | self.ARCH["dataset"]["sensor"]["img_prop"]["height"], 34 | self.ARCH["dataset"]["sensor"]["img_prop"]["width"])) 35 | 36 | if torch.cuda.is_available(): 37 | stub = stub.cuda() 38 | self.backbone.cuda() 39 | _, stub_skips = self.backbone(stub) 40 | 41 | # decoderModule = imp.load_source("decoderModule", 42 | # booger.TRAIN_PATH + '/tasks/semantic/decoders/' + 43 | # self.ARCH["decoder"]["name"] + '.py') 44 | decoder_path = "semantic_net.rangenet.tasks.semantic.decoders." 
+ self.ARCH["decoder"]["name"] 45 | decoderModule = importlib.import_module(decoder_path) 46 | self.decoder = decoderModule.Decoder(params=self.ARCH["decoder"], 47 | stub_skips=stub_skips, 48 | OS=self.ARCH["backbone"]["OS"], 49 | feature_depth=self.backbone.get_last_depth()) 50 | 51 | self.head = nn.Sequential(nn.Dropout2d(p=ARCH["head"]["dropout"]), 52 | nn.Conv2d(self.decoder.get_last_depth(), 53 | self.nclasses, kernel_size=3, 54 | stride=1, padding=1)) 55 | 56 | if self.ARCH["post"]["CRF"]["use"]: 57 | self.CRF = CRF(self.ARCH["post"]["CRF"]["params"], self.nclasses) 58 | else: 59 | self.CRF = None 60 | 61 | ### Do not Train any parameter in RangeNet++ (by Zhen Luo) 62 | # train backbone? 63 | # if not self.ARCH["backbone"]["train"]: 64 | for w in self.backbone.parameters(): 65 | w.requires_grad = False 66 | 67 | # train decoder? 68 | # if not self.ARCH["decoder"]["train"]: 69 | for w in self.decoder.parameters(): 70 | w.requires_grad = False 71 | 72 | # train head? 73 | # if not self.ARCH["head"]["train"]: 74 | for w in self.head.parameters(): 75 | w.requires_grad = False 76 | 77 | # train CRF? 78 | # if self.CRF and not self.ARCH["post"]["CRF"]["train"]: 79 | if self.CRF: 80 | for w in self.CRF.parameters(): 81 | w.requires_grad = False 82 | 83 | # print number of parameters and the ones requiring gradients 84 | # print number of parameters and the ones requiring gradients 85 | weights_total = sum(p.numel() for p in self.parameters()) 86 | weights_grad = sum(p.numel() for p in self.parameters() if p.requires_grad) 87 | print("Total number of parameters: ", weights_total) 88 | print("Total number of parameters requires_grad: ", weights_grad) 89 | 90 | # breakdown by layer 91 | weights_enc = sum(p.numel() for p in self.backbone.parameters()) 92 | weights_dec = sum(p.numel() for p in self.decoder.parameters()) 93 | weights_head = sum(p.numel() for p in self.head.parameters()) 94 | print("Param encoder ", weights_enc) 95 | print("Param decoder ", weights_dec) 96 | print("Param head ", weights_head) 97 | if self.CRF: 98 | weights_crf = sum(p.numel() for p in self.CRF.parameters()) 99 | print("Param CRF ", weights_crf) 100 | 101 | # get weights 102 | if path is not None: 103 | # try backbone 104 | try: 105 | w_dict = torch.load(path + "/backbone" + path_append, 106 | map_location=lambda storage, loc: storage) 107 | self.backbone.load_state_dict(w_dict, strict=True) 108 | print("Successfully loaded model backbone weights") 109 | except Exception as e: 110 | print() 111 | print("Couldn't load backbone, using random weights. Error: ", e) 112 | if strict: 113 | print("I'm in strict mode and failure to load weights blows me up :)") 114 | raise e 115 | 116 | # try decoder 117 | try: 118 | w_dict = torch.load(path + "/segmentation_decoder" + path_append, 119 | map_location=lambda storage, loc: storage) 120 | self.decoder.load_state_dict(w_dict, strict=True) 121 | print("Successfully loaded model decoder weights") 122 | except Exception as e: 123 | print("Couldn't load decoder, using random weights. Error: ", e) 124 | if strict: 125 | print("I'm in strict mode and failure to load weights blows me up :)") 126 | raise e 127 | 128 | # try head 129 | try: 130 | w_dict = torch.load(path + "/segmentation_head" + path_append, 131 | map_location=lambda storage, loc: storage) 132 | self.head.load_state_dict(w_dict, strict=True) 133 | print("Successfully loaded model head weights") 134 | except Exception as e: 135 | print("Couldn't load head, using random weights. 
Error: ", e) 136 | if strict: 137 | print("I'm in strict mode and failure to load weights blows me up :)") 138 | raise e 139 | 140 | # try CRF 141 | if self.CRF: 142 | try: 143 | w_dict = torch.load(path + "/segmentation_CRF" + path_append, 144 | map_location=lambda storage, loc: storage) 145 | self.CRF.load_state_dict(w_dict, strict=True) 146 | print("Successfully loaded model CRF weights") 147 | except Exception as e: 148 | print("Couldn't load CRF, using random weights. Error: ", e) 149 | if strict: 150 | print("I'm in strict mode and failure to load weights blows me up :)") 151 | raise e 152 | else: 153 | print("No path to pretrained, using random init.") 154 | 155 | def forward(self, x, mask=None): 156 | y, skips = self.backbone(x) 157 | y = self.decoder(y, skips) 158 | y = self.head(y) 159 | y = F.softmax(y, dim=1) 160 | if self.CRF: 161 | assert(mask is not None) 162 | y = self.CRF(x, y, mask) 163 | return y 164 | 165 | def save_checkpoint(self, logdir, suffix=""): 166 | # Save the weights 167 | torch.save(self.backbone.state_dict(), logdir + 168 | "/backbone" + suffix) 169 | torch.save(self.decoder.state_dict(), logdir + 170 | "/segmentation_decoder" + suffix) 171 | torch.save(self.head.state_dict(), logdir + 172 | "/segmentation_head" + suffix) 173 | if self.CRF: 174 | torch.save(self.CRF.state_dict(), logdir + 175 | "/segmentation_CRF" + suffix) 176 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/postproc/CRF.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | import numpy as np 5 | from scipy import signal 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | # import __init__ as booger 10 | 11 | 12 | class LocallyConnectedXYZLayer(nn.Module): 13 | def __init__(self, h, w, sigma, nclasses): 14 | super().__init__() 15 | # size of window 16 | self.h = h 17 | self.padh = h//2 18 | self.w = w 19 | self.padw = w//2 20 | assert(self.h % 2 == 1 and self.w % 2 == 1) # window must be odd 21 | self.sigma = sigma 22 | self.gauss_den = 2 * self.sigma**2 23 | self.nclasses = nclasses 24 | 25 | def forward(self, xyz, softmax, mask): 26 | # softmax size 27 | N, C, H, W = softmax.shape 28 | 29 | # make sofmax zero everywhere input is invalid 30 | softmax = softmax * mask.unsqueeze(1).float() 31 | 32 | # get x,y,z for distance (shape N,1,H,W) 33 | x = xyz[:, 0].unsqueeze(1) 34 | y = xyz[:, 1].unsqueeze(1) 35 | z = xyz[:, 2].unsqueeze(1) 36 | 37 | # im2col in size of window of input (x,y,z separately) 38 | window_x = F.unfold(x, kernel_size=(self.h, self.w), 39 | padding=(self.padh, self.padw)) 40 | center_x = F.unfold(x, kernel_size=(1, 1), 41 | padding=(0, 0)) 42 | window_y = F.unfold(y, kernel_size=(self.h, self.w), 43 | padding=(self.padh, self.padw)) 44 | center_y = F.unfold(y, kernel_size=(1, 1), 45 | padding=(0, 0)) 46 | window_z = F.unfold(z, kernel_size=(self.h, self.w), 47 | padding=(self.padh, self.padw)) 48 | center_z = F.unfold(z, kernel_size=(1, 1), 49 | padding=(0, 0)) 50 | 51 | # sq distance to center (center distance is zero) 52 | unravel_dist2 = (window_x - center_x)**2 + \ 53 | (window_y - center_y)**2 + \ 54 | (window_z - center_z)**2 55 | 56 | # weight input distance by gaussian weights 57 | unravel_gaussian = torch.exp(- unravel_dist2 / self.gauss_den) 58 | 59 | # im2col in size of window of softmax to reweight by gaussian weights 
from input 60 | cloned_softmax = softmax.clone() 61 | for i in range(self.nclasses): 62 | # get the softmax for this class 63 | c_softmax = softmax[:, i].unsqueeze(1) 64 | # unfold this class to weigh it by the proper gaussian weights 65 | unravel_softmax = F.unfold(c_softmax, 66 | kernel_size=(self.h, self.w), 67 | padding=(self.padh, self.padw)) 68 | unravel_w_softmax = unravel_softmax * unravel_gaussian 69 | # add dimenssion 1 to obtain the new softmax for this class 70 | unravel_added_softmax = unravel_w_softmax.sum(dim=1).unsqueeze(1) 71 | # fold it and put it in new tensor 72 | added_softmax = unravel_added_softmax.view(N, H, W) 73 | cloned_softmax[:, i] = added_softmax 74 | 75 | return cloned_softmax 76 | 77 | 78 | class CRF(nn.Module): 79 | def __init__(self, params, nclasses): 80 | super().__init__() 81 | self.params = params 82 | self.iter = torch.nn.Parameter(torch.tensor(params["iter"]), 83 | requires_grad=False) 84 | self.lcn_size = torch.nn.Parameter(torch.tensor([params["lcn_size"]["h"], 85 | params["lcn_size"]["w"]]), 86 | requires_grad=False) 87 | self.xyz_coef = torch.nn.Parameter(torch.tensor(params["xyz_coef"]), 88 | requires_grad=False).float() 89 | self.xyz_sigma = torch.nn.Parameter(torch.tensor(params["xyz_sigma"]), 90 | requires_grad=False).float() 91 | 92 | self.nclasses = nclasses 93 | print("Using CRF!") 94 | 95 | # define layers here 96 | # compat init 97 | self.compat_kernel_init = np.reshape(np.ones((self.nclasses, self.nclasses)) - 98 | np.identity(self.nclasses), 99 | [self.nclasses, self.nclasses, 1, 1]) 100 | 101 | # bilateral compatibility matrixes 102 | self.compat_conv = nn.Conv2d(self.nclasses, self.nclasses, 1) 103 | self.compat_conv.weight = torch.nn.Parameter(torch.from_numpy( 104 | self.compat_kernel_init).float() * self.xyz_coef, requires_grad=True) 105 | 106 | # locally connected layer for message passing 107 | self.local_conn_xyz = LocallyConnectedXYZLayer(params["lcn_size"]["h"], 108 | params["lcn_size"]["w"], 109 | params["xyz_coef"], 110 | self.nclasses) 111 | 112 | def forward(self, input, softmax, mask): 113 | # use xyz 114 | xyz = input[:, 1:4] 115 | 116 | # iteratively 117 | for iter in range(self.iter): 118 | # message passing as locally connected layer 119 | locally_connected = self.local_conn_xyz(xyz, softmax, mask) 120 | 121 | # reweigh with the 1x1 convolution 122 | reweight_softmax = self.compat_conv(locally_connected) 123 | 124 | # add the new values to the original softmax 125 | reweight_softmax = reweight_softmax + softmax 126 | 127 | # lastly, renormalize 128 | softmax = F.softmax(reweight_softmax, dim=1) 129 | 130 | return softmax 131 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/postproc/KNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import __init__ as booger 9 | 10 | 11 | def get_gaussian_kernel(kernel_size=3, sigma=2, channels=1): 12 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 13 | x_coord = torch.arange(kernel_size) 14 | x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) 15 | y_grid = x_grid.t() 16 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 17 | 18 | mean = (kernel_size - 1) / 2. 19 | variance = sigma**2. 
20 | 21 | # Calculate the 2-dimensional gaussian kernel which is 22 | # the product of two gaussian distributions for two different 23 | # variables (in this case called x and y) 24 | gaussian_kernel = (1. / (2. * math.pi * variance)) *\ 25 | torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / (2 * variance)) 26 | 27 | # Make sure sum of values in gaussian kernel equals 1. 28 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 29 | 30 | # Reshape to 2d depthwise convolutional weight 31 | gaussian_kernel = gaussian_kernel.view(kernel_size, kernel_size) 32 | 33 | return gaussian_kernel 34 | 35 | 36 | class KNN(nn.Module): 37 | def __init__(self, params, nclasses): 38 | super().__init__() 39 | print("*"*80) 40 | print("Cleaning point-clouds with kNN post-processing") 41 | self.knn = params["knn"] 42 | self.search = params["search"] 43 | self.sigma = params["sigma"] 44 | self.cutoff = params["cutoff"] 45 | self.nclasses = nclasses 46 | print("kNN parameters:") 47 | print("knn:", self.knn) 48 | print("search:", self.search) 49 | print("sigma:", self.sigma) 50 | print("cutoff:", self.cutoff) 51 | print("nclasses:", self.nclasses) 52 | print("*"*80) 53 | 54 | def forward(self, proj_range, unproj_range, proj_argmax, px, py): 55 | ''' Warning! Only works for un-batched pointclouds. 56 | If they come batched we need to iterate over the batch dimension or do 57 | something REALLY smart to handle unaligned number of points in memory 58 | ''' 59 | # get device 60 | if proj_range.is_cuda: 61 | device = torch.device("cuda") 62 | else: 63 | device = torch.device("cpu") 64 | 65 | # sizes of projection scan 66 | H, W = proj_range.shape 67 | 68 | # number of points 69 | P = unproj_range.shape 70 | 71 | # check if size of kernel is odd and complain 72 | if (self.search % 2 == 0): 73 | raise ValueError("Nearest neighbor kernel must be odd number") 74 | 75 | # calculate padding 76 | pad = int((self.search - 1) / 2) 77 | 78 | # unfold neighborhood to get nearest neighbors for each pixel (range image) 79 | proj_unfold_k_rang = F.unfold(proj_range[None, None, ...], 80 | kernel_size=(self.search, self.search), 81 | padding=(pad, pad)) 82 | 83 | # index with px, py to get ALL the pcld points 84 | idx_list = py * W + px 85 | unproj_unfold_k_rang = proj_unfold_k_rang[:, :, idx_list] 86 | 87 | # WARNING, THIS IS A HACK 88 | # Make non valid (<0) range points extremely big so that there is no screwing 89 | # up the nn self.search 90 | unproj_unfold_k_rang[unproj_unfold_k_rang < 0] = float("inf") 91 | 92 | # now the matrix is unfolded TOTALLY, replace the middle points with the actual range points 93 | center = int(((self.search * self.search) - 1) / 2) 94 | unproj_unfold_k_rang[:, center, :] = unproj_range 95 | 96 | # now compare range 97 | k2_distances = torch.abs(unproj_unfold_k_rang - unproj_range) 98 | 99 | # make a kernel to weigh the ranges according to distance in (x,y) 100 | # I make this 1 - kernel because I want distances that are close in (x,y) 101 | # to matter more 102 | inv_gauss_k = ( 103 | 1 - get_gaussian_kernel(self.search, self.sigma, 1)).view(1, -1, 1) 104 | inv_gauss_k = inv_gauss_k.to(device).type(proj_range.type()) 105 | 106 | # apply weighing 107 | k2_distances = k2_distances * inv_gauss_k 108 | 109 | # find nearest neighbors 110 | _, knn_idx = k2_distances.topk( 111 | self.knn, dim=1, largest=False, sorted=False) 112 | 113 | # do the same unfolding with the argmax 114 | proj_unfold_1_argmax = F.unfold(proj_argmax[None, None, ...].float(), 115 | kernel_size=(self.search, 
self.search), 116 | padding=(pad, pad)).long() 117 | unproj_unfold_1_argmax = proj_unfold_1_argmax[:, :, idx_list] 118 | 119 | # get the top k predictions from the knn at each pixel 120 | knn_argmax = torch.gather( 121 | input=unproj_unfold_1_argmax, dim=1, index=knn_idx) 122 | 123 | # fake an invalid argmax of classes + 1 for all cutoff items 124 | if self.cutoff > 0: 125 | knn_distances = torch.gather(input=k2_distances, dim=1, index=knn_idx) 126 | knn_invalid_idx = knn_distances > self.cutoff 127 | knn_argmax[knn_invalid_idx] = self.nclasses 128 | 129 | # now vote 130 | # argmax onehot has an extra class for objects after cutoff 131 | knn_argmax_onehot = torch.zeros( 132 | (1, self.nclasses + 1, P[0]), device=device).type(proj_range.type()) 133 | ones = torch.ones_like(knn_argmax).type(proj_range.type()) 134 | knn_argmax_onehot = knn_argmax_onehot.scatter_add_(1, knn_argmax, ones) 135 | 136 | # now vote (as a sum over the onehot shit) (don't let it choose unlabeled OR invalid) 137 | knn_argmax_out = knn_argmax_onehot[:, 1:-1].argmax(dim=1) + 1 138 | 139 | # reshape again 140 | knn_argmax_out = knn_argmax_out.view(P) 141 | 142 | return knn_argmax_out 143 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/postproc/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | TRAIN_PATH = "../" 3 | DEPLOY_PATH = "../../deploy" 4 | sys.path.insert(0, TRAIN_PATH) 5 | -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/postproc/__pycache__/CRF.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/postproc/__pycache__/CRF.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/postproc/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/postproc/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /semantic_net/rangenet/tasks/semantic/readme.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/readme.md -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Test script for range-image-based point cloud prediction 6 | import os 7 | import time 8 | import argparse 9 | import random 10 | import yaml 11 | from pytorch_lightning import Trainer 12 | from pytorch_lightning import loggers as pl_loggers 13 | 14 | import pcpnet.datasets.datasets as datasets 15 | import pcpnet.models.PCPNet as PCPNet 16 | from pcpnet.utils.utils import set_seed 17 | 18 | if __name__ == "__main__": 19 | parser = 
argparse.ArgumentParser("./test.py") 20 | parser.add_argument( 21 | "--model", 22 | "-m", 23 | type=str, 24 | default=None, 25 | help="Model to be tested" 26 | ) 27 | parser.add_argument( 28 | "--limit_test_batches", 29 | "-l", 30 | type=float, 31 | default=1.0, 32 | help="Percentage of test data to be tested", 33 | ) 34 | parser.add_argument( 35 | "-d", 36 | "--dataset", 37 | type=str, 38 | default=None, 39 | required=True, 40 | help="The path of processed KITTI odometry and SemanticKITTI dataset", 41 | ) 42 | parser.add_argument( 43 | "--save", 44 | "-s", 45 | action="store_true", 46 | help="Save point clouds" 47 | ) 48 | parser.add_argument( 49 | "--only_save_pc", 50 | "-o", 51 | action="store_true", 52 | help="Only save point clouds, not compute loss" 53 | ) 54 | parser.add_argument( 55 | "--cd_downsample", 56 | type=int, 57 | default=-1, 58 | help="Number of downsampled points for evaluating Chamfer Distance", 59 | ) 60 | parser.add_argument("--path", "-p", type=str, default=None, help="Path to data") 61 | parser.add_argument( 62 | "-seq", 63 | "--sequence", 64 | type=int, 65 | nargs="+", 66 | default=None, 67 | help="Sequence to be tested", 68 | ) 69 | 70 | args, unparsed = parser.parse_known_args() 71 | dataset_path = args.dataset 72 | if dataset_path: 73 | pass 74 | else: 75 | raise Exception("Please enter the path of dataset") 76 | 77 | # load config file 78 | config_filename = os.path.dirname(os.path.dirname(os.path.dirname(args.model))) + "/hparams.yaml" 79 | cfg = yaml.safe_load(open(config_filename)) 80 | print("Starting testing model ", cfg["LOG_NAME"]) 81 | """Manually set these""" 82 | cfg["DATA_CONFIG"]["COMPUTE_MEAN_AND_STD"] = False 83 | cfg["DATA_CONFIG"]["GENERATE_FILES"] = False 84 | 85 | if args.only_save_pc: 86 | cfg["TEST"]["ONLY_SAVE_POINT_CLOUDS"] = args.only_save_pc 87 | print("Only save point clouds") 88 | else: 89 | cfg["TEST"]["SAVE_POINT_CLOUDS"] = args.save 90 | 91 | cfg["TEST"]["N_DOWNSAMPLED_POINTS_CD"] = args.cd_downsample 92 | print("Evaluating CD on ", cfg["TEST"]["N_DOWNSAMPLED_POINTS_CD"], " points.") 93 | 94 | if args.sequence: 95 | cfg["DATA_CONFIG"]["SPLIT"]["TEST"] = args.sequence 96 | cfg["DATA_CONFIG"]["SPLIT"]["TRAIN"] = args.sequence 97 | cfg["DATA_CONFIG"]["SPLIT"]["VAL"] = args.sequence 98 | 99 | ###### Set random seed for torch, numpy and python 100 | set_seed(cfg["DATA_CONFIG"]["RANDOM_SEED"]) 101 | print("Random seed is ", cfg["DATA_CONFIG"]["RANDOM_SEED"]) 102 | 103 | data = datasets.KittiOdometryModule(cfg, dataset_path) 104 | data.setup() 105 | 106 | checkpoint_path = args.model 107 | cfg["TEST"]["USED_CHECKPOINT"] = checkpoint_path 108 | test_dir_name = "test_" + time.strftime("%Y%m%d_%H%M%S") 109 | cfg["TEST"]["DIR_NAME"] = test_dir_name 110 | 111 | model = PCPNet.PCPNet.load_from_checkpoint(checkpoint_path, cfg=cfg) 112 | 113 | # Only log if test is done on full data 114 | if args.limit_test_batches == 1.0 and not args.only_save_pc: 115 | logger = pl_loggers.TensorBoardLogger( 116 | save_dir=cfg["LOG_DIR"], 117 | default_hp_metric=False, 118 | name=test_dir_name, 119 | version="" 120 | ) 121 | else: 122 | logger = False 123 | 124 | trainer = Trainer( 125 | limit_test_batches=args.limit_test_batches, 126 | gpus=cfg["TRAIN"]["N_GPUS"], 127 | logger=logger, 128 | ) 129 | 130 | results = trainer.test(model, data.test_dataloader()) 131 | 132 | if logger: 133 | filename = os.path.join( 134 | cfg["LOG_DIR"], cfg["TEST"]["DIR_NAME"], "results" + ".yaml" 135 | ) 136 | log_to_save = {**{"results": results}, **vars(args), **cfg} 137 | with 
open(filename, "w") as yaml_file: 138 | yaml.dump(log_to_save, yaml_file, default_flow_style=False) 139 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Train script for range-image-based point cloud prediction 6 | import os 7 | import time 8 | import argparse 9 | import yaml 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning.loggers import TensorBoardLogger 12 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 13 | 14 | from pcpnet.datasets.datasets import KittiOdometryModule 15 | from pcpnet.models.PCPNet import PCPNet 16 | from pcpnet.utils.utils import set_seed 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser("./train.py") 20 | parser.add_argument( 21 | "--comment", "-c", type=str, default="", help="Add a comment to the LOG ID." 22 | ) 23 | parser.add_argument( 24 | "-res", 25 | "--resume", 26 | type=str, 27 | default=None, 28 | help="Resume training from specified model.", 29 | ) 30 | parser.add_argument( 31 | "-w", 32 | "--weights", 33 | type=str, 34 | default=None, 35 | help="Init model with weights from specified model", 36 | ) 37 | parser.add_argument( 38 | "-d", 39 | "--dataset", 40 | type=str, 41 | default=None, 42 | required=True, 43 | help="The path of processed KITTI odometry and SemanticKITTI dataset", 44 | ) 45 | parser.add_argument( 46 | "-raw", 47 | "--rawdata", 48 | type=str, 49 | default=None, 50 | help="The path of raw KITTI odometry and SemanticKITTI dataset", 51 | ) 52 | parser.add_argument( 53 | "-r", 54 | "--range", 55 | type=float, 56 | default=None, 57 | help="Change weight of range image loss.", 58 | ) 59 | parser.add_argument( 60 | "-m", "--mask", type=float, default=None, help="Change weight of mask loss." 61 | ) 62 | parser.add_argument( 63 | "-cd", 64 | "--chamfer", 65 | type=float, 66 | default=None, 67 | help="Change weight of Chamfer distance loss.", 68 | ) 69 | parser.add_argument( 70 | "-s", 71 | "--semantic", 72 | type=float, 73 | default=None, 74 | help="Change weight of semantic loss.", 75 | ) 76 | parser.add_argument( 77 | "-e", "--epochs", type=int, default=None, help="Number of training epochs." 
78 | ) 79 | parser.add_argument( 80 | "-seq", 81 | "--sequence", 82 | type=int, 83 | nargs="+", 84 | default=None, 85 | help="Sequences for training.", 86 | ) 87 | parser.add_argument( 88 | "-u", 89 | "--update-cfg", 90 | type=bool, 91 | default=False, 92 | help="Update config file.", 93 | ) 94 | args, unparsed = parser.parse_known_args() 95 | 96 | dataset_path = args.dataset 97 | if dataset_path: 98 | pass 99 | else: 100 | raise Exception("Please enter the path of dataset") 101 | 102 | model_path = args.resume if args.resume else args.weights 103 | if model_path and not args.update_cfg: 104 | ###### Load config and update parameters 105 | checkpoint_path = model_path 106 | config_filename = os.path.dirname(model_path) 107 | if os.path.basename(config_filename) == "val": 108 | config_filename = os.path.dirname(config_filename) 109 | config_filename = os.path.dirname(config_filename) + "/hparams.yaml" 110 | 111 | cfg = yaml.safe_load(open(config_filename)) 112 | 113 | if args.weights and not args.comment: 114 | args.comment = "_pretrained" 115 | 116 | cfg["LOG_DIR"] = cfg["LOG_DIR"] + args.comment 117 | cfg["LOG_NAME"] = cfg["LOG_NAME"] + args.comment 118 | print("New log name is ", cfg["LOG_DIR"]) 119 | 120 | """Manually set these""" 121 | cfg["DATA_CONFIG"]["COMPUTE_MEAN_AND_STD"] = False 122 | cfg["DATA_CONFIG"]["GENERATE_FILES"] = False 123 | 124 | if args.epochs: 125 | cfg["TRAIN"]["MAX_EPOCH"] = args.epochs 126 | print("Set max_epochs to ", args.epochs) 127 | if args.range: 128 | cfg["TRAIN"]["LOSS_WEIGHT_RANGE_VIEW"] = args.range 129 | print("Overwriting LOSS_WEIGHT_RANGE_VIEW =", args.range) 130 | if args.mask: 131 | cfg["TRAIN"]["LOSS_WEIGHT_MASK"] = args.mask 132 | print("Overwriting LOSS_WEIGHT_MASK =", args.mask) 133 | if args.chamfer: 134 | cfg["TRAIN"]["LOSS_WEIGHT_CHAMFER_DISTANCE"] = args.chamfer 135 | print("Overwriting LOSS_WEIGHT_CHAMFER_DISTANCE =", args.chamfer) 136 | if args.semantic: 137 | cfg["TRAIN"]["LOSS_WEIGHT_SEMANTIC"] = args.semantic 138 | print("Overwriting LOSS_WEIGHT_SEMANTIC =", args.semantic) 139 | if args.sequence: 140 | cfg["DATA_CONFIG"]["SPLIT"]["TRAIN"] = args.sequence 141 | print("Training on sequences ", args.sequence) 142 | else: 143 | ###### Create new log 144 | resume_from_checkpoint = None 145 | config_filename = "config/parameters.yaml" 146 | cfg = yaml.safe_load(open(config_filename)) 147 | if args.update_cfg: 148 | checkpoint_path = model_path 149 | print("Updated config file manually") 150 | if args.comment: 151 | cfg["EXPERIMENT"]["ID"] = args.comment 152 | cfg["LOG_NAME"] = cfg["EXPERIMENT"]["ID"] + "_" + time.strftime("%Y%m%d_%H%M%S") 153 | cfg["LOG_DIR"] = os.path.join("./runs", cfg["LOG_NAME"]) 154 | if not os.path.exists(cfg["LOG_DIR"]): 155 | os.makedirs(cfg["LOG_DIR"]) 156 | print("Starting experiment with log name", cfg["LOG_NAME"]) 157 | 158 | model_file_path = "./pcpnet/models" 159 | os.system('cp -r %s %s' % (model_file_path, cfg["LOG_DIR"])) 160 | 161 | ###### Set random seed for torch, numpy and python 162 | set_seed(cfg["DATA_CONFIG"]["RANDOM_SEED"]) 163 | 164 | ###### Logger 165 | tb_logger = TensorBoardLogger( 166 | save_dir=cfg["LOG_DIR"], default_hp_metric=False, name="", version="" 167 | ) 168 | 169 | ###### Dataset 170 | data = KittiOdometryModule(cfg, dataset_path, args.rawdata) 171 | 172 | ###### Model 173 | model = PCPNet(cfg) 174 | 175 | ###### Load checkpoint 176 | if args.resume: 177 | resume_from_checkpoint = checkpoint_path 178 | print("Resuming from checkpoint ", checkpoint_path) 179 | elif args.weights: 180 | 
model = model.load_from_checkpoint(checkpoint_path, cfg=cfg) 181 | resume_from_checkpoint = None 182 | print("Loading weigths from ", checkpoint_path) 183 | 184 | ###### Callbacks 185 | lr_monitor = LearningRateMonitor(logging_interval="step") 186 | checkpoint = ModelCheckpoint( 187 | monitor="val/loss", 188 | dirpath=os.path.join(cfg["LOG_DIR"], "checkpoints"), 189 | filename="{val/loss:.3f}-{epoch:02d}", 190 | mode="min", 191 | save_top_k=5, 192 | save_last=True 193 | ) 194 | 195 | ###### Trainer 196 | trainer = Trainer( 197 | gpus=cfg["TRAIN"]["N_GPUS"], 198 | logger=tb_logger, 199 | accumulate_grad_batches=cfg["TRAIN"]["BATCH_ACC"], 200 | max_epochs=cfg["TRAIN"]["MAX_EPOCH"], 201 | log_every_n_steps=cfg["TRAIN"][ 202 | "LOG_EVERY_N_STEPS" 203 | ], # times accumulate_grad_batches 204 | callbacks=[lr_monitor, checkpoint], 205 | ) 206 | 207 | ###### Training 208 | trainer.fit(model, data, ckpt_path=resume_from_checkpoint) 209 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou 3 | # This file is covered by the LICENSE file in the root of the project PCPNet: 4 | # https://github.com/Blurryface0814/PCPNet 5 | # Brief: Visualize script for range-image-based point cloud prediction 6 | import os 7 | import argparse 8 | 9 | from pcpnet.utils.visualization import * 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser("./visualize.py") 13 | parser.add_argument( 14 | "--path", "-p", type=str, default=None, help="Path to point clouds" 15 | ) 16 | parser.add_argument( 17 | "--sequence", "-s", type=str, default="08", help="Sequence to visualize" 18 | ) 19 | parser.add_argument("--start", type=int, default=0, help="Start frame") 20 | parser.add_argument("--end", type=int, default=None, help="End frame") 21 | parser.add_argument("--capture", "-c", action="store_true", help="Capture frames") 22 | args, unparsed = parser.parse_known_args() 23 | 24 | gt_path = os.path.join(args.path, args.sequence, "gt/") 25 | end = last_file(gt_path) if not args.end else args.end 26 | start = first_file(gt_path) if not args.start else args.start 27 | assert end > start 28 | print("\nRendering scans [{s},{e}] from:{d}\n".format(s=start, e=end, d=gt_path)) 29 | 30 | path_to_car_model = "car_model/bus_ply.ply" 31 | vis = Visualization( 32 | path=args.path, 33 | sequence=args.sequence, 34 | start=start, 35 | end=end, 36 | capture=args.capture, 37 | path_to_car_model=path_to_car_model, 38 | ) 39 | vis.set_render_options( 40 | mesh_show_wireframe=False, 41 | mesh_show_back_face=False, 42 | show_coordinate_frame=False, 43 | ) 44 | vis.run() 45 | --------------------------------------------------------------------------------
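
The command-line entry points dumped above (`train.py`, `test.py`, and `visualize.py`) can be invoked roughly as sketched below. This is a minimal, illustrative sketch only: the `/path/to/...` locations are placeholders, and the flags shown are limited to the ones visible in each script's argparse section, not a complete or authoritative usage guide.

```bash
# Train PCPNet on the processed dataset, tagging the run with a comment
# (-raw points at the raw KITTI/SemanticKITTI data used for preprocessing)
python train.py -d /path/to/processed_dataset -raw /path/to/raw_kitti -c my_experiment

# Resume training from a checkpoint and override the semantic loss weight
python train.py -d /path/to/processed_dataset -res /path/to/runs/LOG_NAME/checkpoints/last.ckpt -s 1.0

# Evaluate a trained model and save the predicted point clouds (-s enables saving)
python test.py -d /path/to/processed_dataset -m /path/to/runs/LOG_NAME/checkpoints/best.ckpt -s

# Visualize ground truth and predicted point clouds saved by test.py for sequence 08
python visualize.py -p /path/to/saved_point_clouds -s 08
```

Note that `test.py` reads its `hparams.yaml` relative to the checkpoint location, so the checkpoint should be passed from inside the run's log directory as produced by `train.py`.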