├── LICENSE
├── README.md
├── config
│   ├── parameters(PCPNet-Semantic).yaml
│   └── parameters.yaml
├── figs
│   ├── motivation.png
│   ├── overall_architecture.png
│   └── predictions.gif
├── pcpnet
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   └── datasets.cpython-38.pyc
│   │   └── datasets.py
│   ├── models
│   │   ├── PCPNet.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── PCPNet.cpython-38.pyc
│   │   │   ├── PPT.cpython-38.pyc
│   │   │   ├── TCNet.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── base.cpython-38.pyc
│   │   │   ├── layers.cpython-38.pyc
│   │   │   ├── loss.cpython-38.pyc
│   │   │   └── transformer.cpython-38.pyc
│   │   ├── base.py
│   │   ├── layers.py
│   │   └── loss.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── logger.cpython-38.pyc
│       │   ├── preprocess_data.cpython-38.pyc
│       │   ├── projection.cpython-38.pyc
│       │   ├── utils.cpython-38.pyc
│       │   └── visualization.cpython-38.pyc
│       ├── logger.py
│       ├── preprocess_data.py
│       ├── projection.py
│       ├── utils.py
│       └── visualization.py
├── poetry.lock
├── pyTorchChamferDistance
│   ├── LICENSE.md
│   ├── README.md
│   └── chamfer_distance
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   └── chamfer_distance.cpython-38.pyc
│       ├── chamfer_distance.cpp
│       ├── chamfer_distance.cu
│       └── chamfer_distance.py
├── pyproject.toml
├── semantic_net
│   └── rangenet
│       ├── README.md
│       ├── __init__.py
│       ├── __pycache__
│       │   └── __init__.cpython-38.pyc
│       ├── backbones
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-38.pyc
│       │   │   ├── squeezeseg.cpython-38.pyc
│       │   │   └── squeezesegV2.cpython-38.pyc
│       │   ├── darknet.py
│       │   ├── squeezeseg.py
│       │   └── squeezesegV2.py
│       ├── common
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-38.pyc
│       │   │   └── laserscan.cpython-38.pyc
│       │   └── laserscan.py
│       ├── requirements.txt
│       └── tasks
│           ├── __init__.py
│           ├── __pycache__
│           │   └── __init__.cpython-38.pyc
│           └── semantic
│               ├── README.md
│               ├── __init__.py
│               ├── __pycache__
│               │   └── __init__.cpython-38.pyc
│               ├── config
│               │   └── labels
│               │       ├── semantic-kitti-all.yaml
│               │       └── semantic-kitti.yaml
│               ├── decoders
│               │   ├── __init__.py
│               │   ├── __pycache__
│               │   │   ├── __init__.cpython-38.pyc
│               │   │   ├── squeezeseg.cpython-38.pyc
│               │   │   └── squeezesegV2.cpython-38.pyc
│               │   ├── darknet.py
│               │   ├── squeezeseg.py
│               │   └── squeezesegV2.py
│               ├── modules
│               │   ├── __init__.py
│               │   ├── __pycache__
│               │   │   ├── __init__.cpython-38.pyc
│               │   │   └── segmentator.cpython-38.pyc
│               │   └── segmentator.py
│               ├── postproc
│               │   ├── CRF.py
│               │   ├── KNN.py
│               │   ├── __init__.py
│               │   ├── __pycache__
│               │   │   ├── CRF.cpython-38.pyc
│               │   │   └── __init__.cpython-38.pyc
│               │   └── borderMask.py
│               └── readme.md
├── test.py
├── train.py
└── visualize.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Zhen Luo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PCPNet: An Efficient and Semantic-Enhanced Transformer Network for Point Cloud Prediction
2 |
3 | Accepted by IEEE RA-L.
4 |
5 | Developed by [Zhen Luo](https://github.com/Blurryface0814) and [Junyi Ma](https://github.com/BIT-MJY).
6 |
7 |
8 | ![](figs/motivation.png)
9 | *PCPNet predicts future range images based on past range image sequences. The semantic information of sequential range images is extracted for auxiliary training, making the outputs of PCPNet closer to the ground truth in semantics.*
10 |
11 |
12 | ![](figs/predictions.gif)
13 | *Ground truth and future range images predicted by PCPNet.*
14 |
15 | ## Contents
16 | 1. [Publication](#Publication)
17 | 2. [Dataset](#Dataset)
18 | 3. [Installation](#Installation)
19 | 4. [Training](#Training)
20 | 5. [Semantic-Auxiliary-Training](#Semantic-Auxiliary-Training)
21 | 6. [Testing](#Testing)
22 | 7. [Visualization](#Visualization)
23 | 8. [Download](#Download)
24 | 9. [Acknowledgment](#Acknowledgment)
25 | 10. [License](#License)
26 |
27 | ![](figs/overall_architecture.png)
28 | *Overall architecture of our proposed PCPNet. The input range images are first downsampled and compressed along the height and width dimensions respectively to generate the sentence-like features for the following Transformer blocks. The enhanced features are then combined and upsampled to the predicted range images and mask images. Semantic auxiliary training is used to improve the practical value of predicted point clouds.*
29 |
30 | ## Publication
31 | If you use our code in your academic work, please cite the corresponding [paper](https://ieeexplore.ieee.org/abstract/document/10141631?casa_token=VCXSYRIkT88AAAAA:-LLz-ZSNJVSLCYSjXjzpV_DrwtBggRvOKW_1dWxUDNa1IE4VzREdHovp-PyD1zA9rVlRZblXpQu1qfQ):
32 |
33 | ```latex
34 | @ARTICLE{10141631,
35 | author={Luo, Zhen and Ma, Junyi and Zhou, Zijie and Xiong, Guangming},
36 | journal={IEEE Robotics and Automation Letters},
37 | title={PCPNet: An Efficient and Semantic-Enhanced Transformer Network for Point Cloud Prediction},
38 | year={2023},
39 | volume={8},
40 | number={7},
41 | pages={4267-4274},
42 | doi={10.1109/LRA.2023.3281937}}
43 | ```
44 |
45 | ## Dataset
46 | In this repo, we use the KITTI Odometry dataset to train and evaluate PCPNet. You can download it [here](http://www.cvlibs.net/datasets/kitti/eval_odometry.php).
47 |
48 |
49 | Besides, we use SemanticKITTI for semantic auxiliary training, which you can download [here](http://semantic-kitti.org/dataset.html#download).
50 |
51 | ## Installation
52 |
53 | ### Source Code
54 | Clone this repository and run
55 | ```bash
56 | cd PCPNet
57 | git submodule update --init
58 | ```
59 | to install the Chamfer distance submodule, which is originally taken from [here](https://github.com/chrdiller/pyTorchChamferDistance) with some modifications to make it usable as a submodule.
60 |
61 | All parameters are stored in ```config/parameters.yaml```.
62 |
63 | ### Dependencies
64 | In this project, we use CUDA 11.4, PyTorch 1.8.0, and PyTorch Lightning 1.5.0. All other dependencies are managed with Python Poetry and can be found in the ```poetry.lock``` file. If you want to use Python Poetry, install it with:
65 | ```bash
66 | curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python -
67 | ```
68 |
69 | Install Python dependencies with Python Poetry
70 | ```bash
71 | poetry install
72 | ```
73 |
74 | and activate the virtual environment in the shell with
75 | ```bash
76 | poetry shell
77 | ```
78 |
79 | ## Training
80 | We process the data in advance to speed up training. The preprocessing is automatically done if ```GENERATE_FILES``` is set to True in ```config/parameters.yaml```.
81 |
82 | If you have not pre-processed the data yet, please set ```GENERATE_FILES: True``` in ```config/parameters.yaml```, and run the training script by
83 | ```bash
84 | python train.py --rawdata /PATH/TO/RAW/KITTI/dataset/sequences --dataset /PATH/TO/PROCESSED/dataset/
85 | ```
86 | in which ```--rawdata``` points to the directory containing the raw train/val/test sequences specified in the config file and ```--dataset``` points to the directory containing the pre-processed train/val/test sequences.
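For reference, the two layouts roughly look like this (the ```processed``` sub-folders follow ```pcpnet/datasets/datasets.py```; the raw layout is the standard KITTI Odometry structure, and the file names below are only illustrative):
```bash
/PATH/TO/RAW/KITTI/dataset/sequences/00/velodyne/000000.bin   # raw LiDAR scans
/PATH/TO/PROCESSED/dataset/00/processed/range/000000.npy      # pre-processed range images
/PATH/TO/PROCESSED/dataset/00/processed/xyz/000000.npy        # per-pixel x, y, z coordinates
/PATH/TO/PROCESSED/dataset/00/processed/intensity/000000.npy  # per-pixel intensities
/PATH/TO/PROCESSED/dataset/00/processed/labels/000000.npy     # per-pixel semantic labels (from SemanticKITTI)
```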
87 |
88 | If you have already pre-processed the data, please set ```GENERATE_FILES: False``` to skip this step, and run the training script by
89 | ```bash
90 | python train.py --dataset /PATH/TO/PROCESSED/dataset/
91 | ```
92 |
93 | To resume from a checkpoint, please run the training script by
94 | ```bash
95 | python train.py --dataset /PATH/TO/PROCESSED/dataset/ --resume /PATH/TO/YOUR/MODEL/
96 | ```
97 | You can also use the flag ```--weights``` to initialize the weights from a pre-trained model. Pass the flag ```--help``` if you want to see more options.
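For example, a run whose weights are initialized from an existing model instead of resuming it (the flag is used analogously to ```--resume```; paths are placeholders):
```bash
python train.py --dataset /PATH/TO/PROCESSED/dataset/ --weights /PATH/TO/YOUR/MODEL/
```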
98 |
99 | A directory will be created in ```runs``` which stores everything related to the run, such as the model files, the used config, logs, and checkpoints.
100 |
101 | ## Semantic-Auxiliary-Training
102 | If you want to use our proposed semantic auxiliary training strategy, you first need to pre-train a semantic segmentation model. We provide code for semantic auxiliary training with RangeNet++ in ```semantic_net/rangenet```. To use it, please first clone the [official code](https://github.com/PRBonn/lidar-bonnetal) of RangeNet++ and train a semantic segmentation model.
103 |
104 | *Note that, to match the code we currently provide, we recommend training RangeNet++ with the squeezesegV2 backbone, without CRF, and with only ```range``` in the ```input_depth``` option.* If you want to use other backbones, please make the corresponding modifications to ```class loss_semantic``` in ```pcpnet/models/loss.py```.
105 |
106 | Once you have completed the pre-training, you need to copy the folder containing the pre-trained model to ```semantic_net/rangenet/model/``` and modify ```SEMANTIC_PRETRAINED_MODEL``` in ```config/parameters.yaml``` to the folder name.
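For reference, a rough sketch of the expected layout (the folder name ```squeezesegV2``` is just an example matching the default config; ```arch_cfg.yaml``` is the file that ```pcpnet/models/loss.py``` reads from this folder):
```
semantic_net/rangenet/model/squeezesegV2/
├── arch_cfg.yaml   # RangeNet++ architecture config
└── ...             # remaining pre-trained RangeNet++ checkpoint files
```
with the matching entry in ```config/parameters.yaml```:
```yaml
MODEL:
  SEMANTIC_PRETRAINED_MODEL: squeezesegV2   # folder name under semantic_net/rangenet/model/
```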
107 |
108 | After completing the above steps, you can start semantic auxiliary training by running
109 | ```bash
110 | python train.py --dataset /PATH/TO/PROCESSED/dataset/
111 | ```
112 | *Note* that you need to set ```LOSS_WEIGHT_SEMANTIC``` in ```config/parameters.yaml``` to the weight you want (we recommend 1.0) instead of 0.0 before running the training script.
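In other words, before starting the run, ```config/parameters.yaml``` should contain something like the following (the provided ```config/parameters(PCPNet-Semantic).yaml``` already uses this setting):
```yaml
TRAIN:
  LOSS_WEIGHT_SEMANTIC: 1.0   # 0.0 disables semantic auxiliary training
```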
113 |
114 | ## Testing
115 | You can evaluate the performance of the model by running
116 | ```bash
117 | python test.py --dataset /PATH/TO/PROCESSED/dataset/ --model /PATH/TO/YOUR/MODEL/
118 | ```
119 | *Note*: Please use the flag ```-s``` if you want to save the predicted point clouds for visualization, and ```-l``` if you want to test the model on a smaller amount of data. With the flag ```-o```, the predicted point clouds are saved without computing any losses, which speeds up saving.
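For example, a run that evaluates a model and also saves its predictions for later visualization (paths are placeholders):
```bash
python test.py --dataset /PATH/TO/PROCESSED/dataset/ --model /PATH/TO/YOUR/MODEL/ -s
```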
120 |
121 | ## Visualization
122 | After passing the ```-s``` flag or the ```-o``` flag to the testing script, the predicted range images will be saved as .png files in ```runs/MODEL_NAME/test_TIME/range_view_predictions```. The predicted point clouds are saved to ```runs/MODEL_NAME/test_TIME/point_clouds```. You can visualize the predicted point clouds by running
123 | ```bash
124 | python visualize.py --path runs/MODEL_NAME/test_TIME/point_clouds
125 | ```
126 | Please download the car model from [here](https://drive.google.com/drive/folders/1bmBMdfJaN2ptJVh7gHv1Gy8L1aLMOslr?usp=share_link) and put it into ```./car_model/``` to display the car during the visualization process.
127 |
128 |
129 | ## Download
130 | You can download our pre-trained model from this [link](https://drive.google.com/drive/folders/1p9q_SoXOsigi8vB_bXxWNZbU-ypSm19J?usp=share_link). Just extract the zip file into ```runs```.
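A minimal example, assuming the downloaded archive is named ```PCPNet.zip``` (the actual file name may differ):
```bash
unzip PCPNet.zip -d runs/
```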
131 |
132 | ## Acknowledgment
133 | We would like to thank Benedikt Mersch, Andres Milioto, and Christian Diller et al. for their significant contributions to the field of point cloud processing. Some of the code in this repo is borrowed from [TCNet](https://github.com/PRBonn/point-cloud-prediction), [RangeNet++](https://github.com/PRBonn/lidar-bonnetal), and [pyTorchChamferDistance](https://github.com/chrdiller/pyTorchChamferDistance). We thank all the authors for their awesome projects.
134 |
135 | ## License
136 | This project is free software made available under the MIT License. For details see the LICENSE file.
137 |
--------------------------------------------------------------------------------
/config/parameters(PCPNet-Semantic).yaml:
--------------------------------------------------------------------------------
1 | EXPERIMENT:
2 | ID: PCPNet-Semantic # Give your experiment a unique ID which is used in the log
3 |
4 | DATA_CONFIG:
5 | DATASET_NAME: KITTIOdometry
6 | GENERATE_FILES: False # If true, the data will be pre-processed
7 | COMPUTE_MEAN_AND_STD: False # If true, the mean and std of the training data will be computed to use it in MEAN and STD.
8 | RANDOM_SEED: 1 # Set random seed for torch, numpy and python
9 | DATALOADER:
10 | NUM_WORKER: 4
11 | SHUFFLE: True
12 | SPLIT: # Specify the sequences here. We follow the KITTI format.
13 | TRAIN:
14 | - 0
15 | - 1
16 | - 2
17 | - 3
18 | - 4
19 | - 5
20 | VAL:
21 | - 6
22 | - 7
23 | TEST:
24 | - 8
25 | - 9
26 | - 10
27 |
28 | HEIGHT: 64 # Height of range images
29 | WIDTH: 2048 # Width of range images
30 | FOV_UP: 3.0 # Depends on the used LiDAR sensor
31 | FOV_DOWN: -25.0 # Depends on the used LiDAR sensor
32 | MAX_RANGE: 85.0 # Average max value in training set
33 | MIN_RANGE: 1.0 # Average min value in training set
34 | MEAN:
35 | - 10.839 # Range
36 | - 0.005 # X
37 | - 0.494 # Y
38 | - -1.13 # Z
39 | - 0.287 # Intensity
40 |
41 | STD:
42 | - 9.314 # Range
43 | - 11.521 # X
44 | - 8.262 # Y
45 | - 0.828 # Z
46 | - 0.14 # Intensity
47 |
48 | MODEL:
49 | N_PAST_STEPS: 5 # Number of input range images
50 | N_FUTURE_STEPS: 5 # Number of predicted future range images
51 | MASK_THRESHOLD: 0.5 # Threshold for valid point mask classification
52 | USE:
53 | XYZ: False # If true: x,y, and z coordinates will be used as additional input channels
54 | INTENSITY: False # If true: intensity will be used as additional input channel
55 | CHANNELS: # Number of channels in pre-encoder and post-decoder, respectively.
56 | - 16
57 | - 32
58 | - 64
59 | SKIP_IF_CHANNEL_SIZE: # Adds a skip connection between pre-encoder and post-decoder at desired channels
60 | - 32
61 | - 64
62 | TRANSFORMER_H_LAYERS: 4 # Number of layers in transformer_H
63 | TRANSFORMER_W_LAYERS: 4 # Number of layers in transformer_W
64 | CIRCULAR_PADDING: True
65 | NORM: batch # batch, group, none, instance
66 | N_CHANNELS_PER_GROUP: 16
67 | SEMANTIC_NET: rangenet
68 | SEMANTIC_DATA_CONFIG: semantic-kitti.yaml
69 | SEMANTIC_PRETRAINED_MODEL: squeezesegV2
70 |
71 | TRAIN:
72 | LR: 0.001
73 | LR_EPOCH: 1
74 | LR_DECAY: 0.99
75 | MAX_EPOCH: 50
76 | BATCH_SIZE: 2
77 | BATCH_ACC: 4
78 | N_GPUS: 1
79 | LOG_EVERY_N_STEPS: 10
80 | LOSS_WEIGHT_CHAMFER_DISTANCE: 0.0
81 | LOSS_WEIGHT_RANGE_VIEW: 1.0
82 | LOSS_WEIGHT_MASK: 1.0
83 | LOSS_WEIGHT_SEMANTIC: 1.0
84 | CHAMFER_DISTANCE_EVERY_N_VAL_EPOCH: 5 # Log chamfer distance every N val epoch.
85 |
86 | VALIDATION:
87 | SELECTED_SEQUENCE_AND_FRAME: # Only log point clouds for selected validation sequence and frame.
88 | 6:
89 | - 10
90 |
91 | TEST:
92 | N_BATCHES_TO_SAVE: -1 # If set to -1 and SAVE_POINT_CLOUDS is true, all batches of the test set will be saved.
93 | SAVE_POINT_CLOUDS: False
94 | ONLY_SAVE_POINT_CLOUDS: False # Only save point clouds, not compute loss
95 | SEMANTIC_SIMILARITY: True # compute semantic similarity
96 | N_DOWNSAMPLED_POINTS_CD: -1 # Can evaluate the CD on downsampled point clouds. Set -1 to evaluate on full point clouds.
97 | SELECTED_SEQUENCE_AND_FRAME: # Only log point clouds for selected test sequence and frame.
98 | 8:
99 | - 120
100 |
--------------------------------------------------------------------------------
/config/parameters.yaml:
--------------------------------------------------------------------------------
1 | EXPERIMENT:
2 | ID: PCPNet # Give your experiment a unique ID which is used in the log
3 |
4 | DATA_CONFIG:
5 | DATASET_NAME: KITTIOdometry
6 | GENERATE_FILES: False # If true, the data will be pre-processed
7 | COMPUTE_MEAN_AND_STD: False # If true, the mean and std of the training data will be computed to use it in MEAN and STD.
8 | RANDOM_SEED: 1 # Set random seed for torch, numpy and python
9 | DATALOADER:
10 | NUM_WORKER: 4
11 | SHUFFLE: True
12 | SPLIT: # Specify the sequences here. We follow the KITTI format.
13 | TRAIN:
14 | - 0
15 | - 1
16 | - 2
17 | - 3
18 | - 4
19 | - 5
20 | VAL:
21 | - 6
22 | - 7
23 | TEST:
24 | - 8
25 | - 9
26 | - 10
27 |
28 | HEIGHT: 64 # Height of range images
29 | WIDTH: 2048 # Width of range images
30 | FOV_UP: 3.0 # Depends on the used LiDAR sensor
31 | FOV_DOWN: -25.0 # Depends on the used LiDAR sensor
32 | MAX_RANGE: 85.0 # Average max value in training set
33 | MIN_RANGE: 1.0 # Average min value in training set
34 | MEAN:
35 | - 10.839 # Range
36 | - 0.005 # X
37 | - 0.494 # Y
38 | - -1.13 # Z
39 | - 0.287 # Intensity
40 |
41 | STD:
42 | - 9.314 # Range
43 | - 11.521 # X
44 | - 8.262 # Y
45 | - 0.828 # Z
46 | - 0.14 # Intensity
47 |
48 | MODEL:
49 | N_PAST_STEPS: 5 # Number of input range images
50 | N_FUTURE_STEPS: 5 # Number of predicted future range images
51 | MASK_THRESHOLD: 0.5 # Threshold for valid point mask classification
52 | USE:
53 | XYZ: False # If true: x,y, and z coordinates will be used as additional input channels
54 | INTENSITY: False # If true: intensity will be used as additional input channel
55 | CHANNELS: # Number of channels in pre-encoder and post-decoder, respectively.
56 | - 16
57 | - 32
58 | - 64
59 | SKIP_IF_CHANNEL_SIZE: # Adds a skip connection between pre-encoder and post-decoder at desired channels
60 | - 32
61 | - 64
62 | TRANSFORMER_H_LAYERS: 4 # Number of layers in transformer_H
63 | TRANSFORMER_W_LAYERS: 4 # Number of layers in transformer_W
64 | CIRCULAR_PADDING: True
65 | NORM: batch # batch, group, none, instance
66 | N_CHANNELS_PER_GROUP: 16
67 | SEMANTIC_NET: rangenet
68 | SEMANTIC_DATA_CONFIG: semantic-kitti.yaml
69 | SEMANTIC_PRETRAINED_MODEL: squeezesegV2
70 |
71 | TRAIN:
72 | LR: 0.001
73 | LR_EPOCH: 1
74 | LR_DECAY: 0.99
75 | MAX_EPOCH: 50
76 | BATCH_SIZE: 2
77 | BATCH_ACC: 4
78 | N_GPUS: 1
79 | LOG_EVERY_N_STEPS: 10
80 | LOSS_WEIGHT_CHAMFER_DISTANCE: 0.0
81 | LOSS_WEIGHT_RANGE_VIEW: 1.0
82 | LOSS_WEIGHT_MASK: 1.0
83 | LOSS_WEIGHT_SEMANTIC: 0.0
84 | CHAMFER_DISTANCE_EVERY_N_VAL_EPOCH: 5 # Log chamfer distance every N val epoch.
85 |
86 | VALIDATION:
87 | SELECTED_SEQUENCE_AND_FRAME: # Only log point clouds for selected validation sequence and frame.
88 | 6:
89 | - 10
90 |
91 | TEST:
92 | N_BATCHES_TO_SAVE: -1 # If set to -1 and SAVE_POINT_CLOUDS is true, all batches of the test set will be saved.
93 | SAVE_POINT_CLOUDS: False
94 | ONLY_SAVE_POINT_CLOUDS: False # Only save point clouds, not compute loss
95 | SEMANTIC_SIMILARITY: False # compute semantic similarity
96 | N_DOWNSAMPLED_POINTS_CD: -1 # Can evaluate the CD on downsampled point clouds. Set -1 to evaluate on full point clouds.
97 | SELECTED_SEQUENCE_AND_FRAME: # Only log point clouds for selected test sequence and frame.
98 | 8:
99 | - 120
100 |
--------------------------------------------------------------------------------
/figs/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/figs/motivation.png
--------------------------------------------------------------------------------
/figs/overall_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/figs/overall_architecture.png
--------------------------------------------------------------------------------
/figs/predictions.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/figs/predictions.gif
--------------------------------------------------------------------------------
/pcpnet/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0"
2 |
--------------------------------------------------------------------------------
/pcpnet/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/datasets/__init__.py
--------------------------------------------------------------------------------
/pcpnet/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/datasets/__pycache__/datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/datasets/__pycache__/datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Dataset Implementation for KITTI Odometry
6 | import os
7 | import yaml
8 | import torch
9 | import numpy as np
10 | from torch.utils.data import Dataset, DataLoader
11 | from pytorch_lightning import LightningDataModule
12 |
13 | from pcpnet.utils.projection import projection
14 | from pcpnet.utils.preprocess_data import prepare_data, compute_mean_and_std
15 | from pcpnet.utils.utils import load_files
16 |
17 |
18 | class KittiOdometryModule(LightningDataModule):
19 | """A Pytorch Lightning module for KITTI Odometry"""
20 |
21 | def __init__(self, cfg, dataset_path, rawdata_path=None):
22 |         """Method to initialize the KITTI Odometry dataset class
23 |
24 | Args:
25 | cfg: config dict
26 |
27 | Returns:
28 | None
29 | """
30 | super(KittiOdometryModule, self).__init__()
31 | self.cfg = cfg
32 | self.dataset_path = dataset_path
33 | self.rawdata_path = rawdata_path
34 |
35 | def prepare_data(self):
36 | """Call prepare_data method to generate npy range images from raw LiDAR data"""
37 | if self.cfg["DATA_CONFIG"]["GENERATE_FILES"]:
38 | prepare_data(self.cfg, self.dataset_path, self.rawdata_path)
39 |
40 | def setup(self, stage=None):
41 | """Dataloader and iterators for training, validation and test data"""
42 | ########## Point dataset splits
43 | train_set = KittiOdometryRaw(self.cfg, split="train", dataset_path=self.dataset_path)
44 |
45 | val_set = KittiOdometryRaw(self.cfg, split="val", dataset_path=self.dataset_path)
46 |
47 | test_set = KittiOdometryRaw(self.cfg, split="test", dataset_path=self.dataset_path)
48 |
49 | ########## Generate dataloaders and iterables
50 |
51 | self.train_loader = DataLoader(
52 | dataset=train_set,
53 | batch_size=self.cfg["TRAIN"]["BATCH_SIZE"],
54 | shuffle=self.cfg["DATA_CONFIG"]["DATALOADER"]["SHUFFLE"],
55 | num_workers=self.cfg["DATA_CONFIG"]["DATALOADER"]["NUM_WORKER"],
56 | pin_memory=True,
57 | drop_last=False,
58 | timeout=0,
59 | )
60 | self.train_iter = iter(self.train_loader)
61 |
62 | self.valid_loader = DataLoader(
63 | dataset=val_set,
64 | batch_size=self.cfg["TRAIN"]["BATCH_SIZE"],
65 | shuffle=False,
66 | num_workers=self.cfg["DATA_CONFIG"]["DATALOADER"]["NUM_WORKER"],
67 | pin_memory=True,
68 | drop_last=False,
69 | timeout=0,
70 | )
71 | self.valid_iter = iter(self.valid_loader)
72 |
73 | self.test_loader = DataLoader(
74 | dataset=test_set,
75 | batch_size=self.cfg["TRAIN"]["BATCH_SIZE"],
76 | shuffle=False,
77 | num_workers=self.cfg["DATA_CONFIG"]["DATALOADER"]["NUM_WORKER"],
78 | pin_memory=True,
79 | drop_last=False,
80 | timeout=0,
81 | )
82 | self.test_iter = iter(self.test_loader)
83 |
84 | # Optionally compute statistics of training data
85 | if self.cfg["DATA_CONFIG"]["COMPUTE_MEAN_AND_STD"]:
86 | compute_mean_and_std(self.cfg, self.train_loader)
87 |
88 | print(
89 | "Loaded {:d} training, {:d} validation and {:d} test samples.".format(
90 | len(train_set), len(val_set), (len(test_set))
91 | )
92 | )
93 |
94 | def train_dataloader(self):
95 | return self.train_loader
96 |
97 | def val_dataloader(self):
98 | return self.valid_loader
99 |
100 | def test_dataloader(self):
101 | return self.test_loader
102 |
103 |
104 | class KittiOdometryRaw(Dataset):
105 | """Dataset class for range image-based point cloud prediction"""
106 |
107 | def __init__(self, cfg, split, dataset_path=None):
108 | """Read parameters and scan data
109 |
110 | Args:
111 | cfg (dict): Config parameters
112 | split (str): Data split
113 |
114 | Raises:
115 | Exception: [description]
116 | """
117 | self.cfg = cfg
118 | self.root_dir = dataset_path
119 | self.height = self.cfg["DATA_CONFIG"]["HEIGHT"]
120 | self.width = self.cfg["DATA_CONFIG"]["WIDTH"]
121 | self.n_channels = 5 + 1
122 |
123 | self.n_past_steps = self.cfg["MODEL"]["N_PAST_STEPS"]
124 | self.n_future_steps = self.cfg["MODEL"]["N_FUTURE_STEPS"]
125 |
126 | # Projection class for mapping from range image to 3D point cloud
127 | self.projection = projection(self.cfg)
128 |
129 | if split == "train":
130 | self.sequences = self.cfg["DATA_CONFIG"]["SPLIT"]["TRAIN"]
131 | elif split == "val":
132 | self.sequences = self.cfg["DATA_CONFIG"]["SPLIT"]["VAL"]
133 | elif split == "test":
134 | self.sequences = self.cfg["DATA_CONFIG"]["SPLIT"]["TEST"]
135 | else:
136 | raise Exception("Split must be train/val/test")
137 |
138 | # Create a dict filenames that maps from a sequence number to a list of files in the dataset
139 | self.filenames_range = {}
140 | self.filenames_xyz = {}
141 | self.filenames_intensity = {}
142 | self.filenames_label = {}
143 |
144 | # Create a dict idx_mapper that maps from a dataset idx to a sequence number and the index of the current scan
145 | self.dataset_size = 0
146 | self.idx_mapper = {}
147 | idx = 0
148 | for seq in self.sequences:
149 | seqstr = "{0:02d}".format(int(seq))
150 |
151 | scan_path_range = os.path.join(self.root_dir, seqstr, "processed", "range")
152 | self.filenames_range[seq] = load_files(scan_path_range)
153 |
154 | scan_path_xyz = os.path.join(self.root_dir, seqstr, "processed", "xyz")
155 | self.filenames_xyz[seq] = load_files(scan_path_xyz)
156 | assert len(self.filenames_range[seq]) == len(self.filenames_xyz[seq])
157 |
158 | scan_path_intensity = os.path.join(self.root_dir, seqstr, "processed", "intensity")
159 | self.filenames_intensity[seq] = load_files(scan_path_intensity)
160 | assert len(self.filenames_range[seq]) == len(self.filenames_intensity[seq])
161 |
162 | scan_path_label = os.path.join(self.root_dir, seqstr, "processed", "labels")
163 | self.filenames_label[seq] = load_files(scan_path_label)
164 | assert len(self.filenames_range[seq]) == len(self.filenames_label[seq])
165 |
166 | # Get number of sequences based on number of past and future steps
167 | n_samples_sequence = max(
168 | 0,
169 | len(self.filenames_range[seq])
170 | - self.n_past_steps
171 | - self.n_future_steps
172 | + 1,
173 | )
174 |
175 | # Add to idx mapping
176 | for sample_idx in range(n_samples_sequence):
177 | scan_idx = self.n_past_steps + sample_idx - 1
178 | self.idx_mapper[idx] = (seq, scan_idx)
179 | idx += 1
180 | self.dataset_size += n_samples_sequence
181 |
182 | def __len__(self):
183 | return self.dataset_size
184 |
185 | def __getitem__(self, idx):
186 | """Load and concatenate range image channels
187 |
188 | Args:
189 | idx (int): Sample index
190 |
191 | Returns:
192 | item: Dataset dictionary item
193 | """
194 | seq, scan_idx = self.idx_mapper[idx]
195 |
196 | # Load past data
197 | past_data = torch.empty(
198 | [self.n_channels, self.n_past_steps, self.height, self.width]
199 | )
200 |
201 | from_idx = scan_idx - self.n_past_steps + 1
202 | to_idx = scan_idx
203 | past_filenames_range = self.filenames_range[seq][from_idx : to_idx + 1]
204 | past_filenames_xyz = self.filenames_xyz[seq][from_idx : to_idx + 1]
205 | past_filenames_intensity = self.filenames_intensity[seq][from_idx : to_idx + 1]
206 | past_filenames_label = self.filenames_label[seq][from_idx: to_idx + 1]
207 |
208 | for t in range(self.n_past_steps):
209 | past_data[0, t, :, :] = self.load_range(past_filenames_range[t])
210 | past_data[1:4, t, :, :] = self.load_xyz(past_filenames_xyz[t])
211 | past_data[4, t, :, :] = self.load_intensity(past_filenames_intensity[t])
212 | past_data[5, t, :, :] = self.load_label(past_filenames_label[t])
213 |
214 | # Load future data
215 | fut_data = torch.empty(
216 | [self.n_channels, self.n_future_steps, self.height, self.width]
217 | )
218 |
219 | from_idx = scan_idx + 1
220 | to_idx = scan_idx + self.n_future_steps
221 | fut_filenames_range = self.filenames_range[seq][from_idx : to_idx + 1]
222 | fut_filenames_xyz = self.filenames_xyz[seq][from_idx : to_idx + 1]
223 | fut_filenames_intensity = self.filenames_intensity[seq][from_idx : to_idx + 1]
224 | fut_filenames_label = self.filenames_label[seq][from_idx: to_idx + 1]
225 |
226 | for t in range(self.n_future_steps):
227 | fut_data[0, t, :, :] = self.load_range(fut_filenames_range[t])
228 | fut_data[1:4, t, :, :] = self.load_xyz(fut_filenames_xyz[t])
229 | fut_data[4, t, :, :] = self.load_intensity(fut_filenames_intensity[t])
230 | fut_data[5, t, :, :] = self.load_label(fut_filenames_label[t])
231 |
232 | item = {"past_data": past_data, "fut_data": fut_data, "meta": (seq, scan_idx)}
233 | return item
234 |
235 |
236 | def load_range(self, filename):
237 | """Load .npy range image as (1,height,width) tensor"""
238 | rv = torch.Tensor(np.load(filename)).float()
239 | return rv
240 |
241 | def load_xyz(self, filename):
242 | """Load .npy xyz values as (3,height,width) tensor"""
243 | xyz = torch.Tensor(np.load(filename)).float()[:, :, :3]
244 | xyz = xyz.permute(2, 0, 1)
245 | return xyz
246 |
247 | def load_intensity(self, filename):
248 | """Load .npy intensity values as (1,height,width) tensor"""
249 | intensity = torch.Tensor(np.load(filename)).float()
250 | return intensity
251 |
252 | def load_label(self, filename):
253 | """Load .npy label values as (1,height,width) tensor"""
254 | label = torch.Tensor(np.load(filename)).int()
255 | return label
256 |
257 |
258 | if __name__ == "__main__":
259 | config_filename = "./config/parameters.yaml"
260 | cfg = yaml.safe_load(open(config_filename))
261 | dataset_path = ""
262 | rawdata_path = ""
263 | data = KittiOdometryModule(cfg, dataset_path, rawdata_path)
264 | data.prepare_data()
265 | data.setup()
266 |
267 | item = data.valid_loader.dataset.__getitem__(1)
268 |
269 |     def normalize(image):
270 |         min_value = np.min(image)
271 |         max_value = np.max(image)
272 |         normalized_image = (image - min_value) / (max_value - min_value)
273 |         return normalized_image
274 |
275 | import matplotlib.pyplot as plt
276 |
277 | fig, axs = plt.subplots(6, 1, sharex=True, figsize=(30, 30 * 6 * 64 / 2048))
278 |
279 | axs[0].imshow(normalize(item["fut_data"][0, 0, :, :].numpy()))
280 | axs[0].set_title("Range")
281 | axs[1].imshow(normalize(item["fut_data"][1, 0, :, :].numpy()))
282 | axs[1].set_title("X")
283 | axs[2].imshow(normalize(item["fut_data"][2, 0, :, :].numpy()))
284 | axs[2].set_title("Y")
285 | axs[3].imshow(normalize(item["fut_data"][3, 0, :, :].numpy()))
286 | axs[3].set_title("Z")
287 | axs[4].imshow(normalize(item["fut_data"][4, 0, :, :].numpy()))
288 | axs[4].set_title("Intensity")
289 | axs[5].imshow(normalize(item["fut_data"][5, 0, :, :].numpy()))
290 | axs[5].set_title("Label")
291 |
292 | plt.show()
293 |
--------------------------------------------------------------------------------
/pcpnet/models/PCPNet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Main Network Architecture of PCPNet
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import yaml
10 | import math
11 |
12 | from pcpnet.models.base import BasePredictionModel
13 | from pcpnet.models.layers import CustomConv3d, DownBlock, UpBlock, Transformer_W, Transformer_H, eca_block
14 |
15 |
16 | class PCPNet(BasePredictionModel):
17 | def __init__(self, cfg):
18 | """Init all layers needed for range image-based point cloud prediction"""
19 | super().__init__(cfg)
20 | self.channels = self.cfg["MODEL"]["CHANNELS"]
21 | self.skip_if_channel_size = self.cfg["MODEL"]["SKIP_IF_CHANNEL_SIZE"]
22 | self.circular_padding = self.cfg["MODEL"]["CIRCULAR_PADDING"]
23 | self.transformer_h_layers = self.cfg["MODEL"]["TRANSFORMER_H_LAYERS"]
24 | self.transformer_w_layers = self.cfg["MODEL"]["TRANSFORMER_W_LAYERS"]
25 |
26 | self.input_layer = CustomConv3d(
27 | self.n_inputs,
28 | self.channels[0],
29 | kernel_size=(1, 1, 1),
30 | stride=(1, 1, 1),
31 | bias=True,
32 | circular_padding=self.circular_padding,
33 | )
34 |
35 | self.DownLayers = nn.ModuleList()
36 | for i in range(len(self.channels) - 1):
37 | if self.channels[i + 1] in self.skip_if_channel_size:
38 | self.DownLayers.append(
39 | DownBlock(
40 | self.cfg,
41 | self.channels[i],
42 | self.channels[i + 1],
43 | down_stride_H=2,
44 | down_stride_W=4,
45 | skip=True,
46 | )
47 | )
48 | else:
49 | self.DownLayers.append(
50 | DownBlock(
51 | self.cfg,
52 | self.channels[i],
53 | self.channels[i + 1],
54 | down_stride_H=2,
55 | down_stride_W=4,
56 | skip=False,
57 | )
58 | )
59 |
60 | self.Transformer_H = Transformer_H(self.cfg, self.channels[-1],
61 | layers=self.transformer_h_layers, skip=True)
62 | self.Transformer_W = Transformer_W(self.cfg, self.channels[-1],
63 | layers=self.transformer_w_layers, skip=True)
64 |
65 | self.channel_attention = eca_block(2 * self.channels[-1], b=1, gamma=2)
66 |
67 | self.mid_layer = CustomConv3d(
68 | 2 * self.channels[-1],
69 | self.channels[-1],
70 | kernel_size=(1, 1, 1),
71 | stride=(1, 1, 1),
72 | bias=True,
73 | circular_padding=self.circular_padding,
74 | )
75 |
76 | self.UpLayers = nn.ModuleList()
77 | for i in reversed(range(len(self.channels) - 1)):
78 | if self.channels[i + 1] in self.skip_if_channel_size:
79 | self.UpLayers.append(
80 | UpBlock(
81 | self.cfg,
82 | self.channels[i + 1],
83 | self.channels[i],
84 | up_stride_H=2,
85 | up_stride_W=4,
86 | skip=True,
87 | )
88 | )
89 | else:
90 | self.UpLayers.append(
91 | UpBlock(
92 | self.cfg,
93 | self.channels[i + 1],
94 | self.channels[i],
95 | up_stride_H=2,
96 | up_stride_W=4,
97 | skip=False,
98 | )
99 | )
100 |
101 | self.n_outputs = 2
102 | self.output_layer = CustomConv3d(
103 | self.channels[0],
104 | self.n_outputs,
105 | kernel_size=(1, 1, 1),
106 | stride=(1, 1, 1),
107 | bias=True,
108 | circular_padding=self.circular_padding,
109 | )
110 |
111 | def forward(self, x):
112 | """Forward range image-based point cloud prediction
113 |
114 | Args:
115 | x (torch.tensor): Input tensor of concatenated, unnormalize range images
116 |
117 | Returns:
118 | dict: Containing the predicted range tensor and mask logits
119 | """
120 | # Only select inputs specified in base model
121 | x = x[:, self.inputs, :, :, :]
122 | batch_size, n_inputs, n_past_steps, H, W = x.size()
123 | assert n_inputs == self.n_inputs
124 |
125 | # Get mask of valid points
126 | past_mask = x != -1.0
127 |
128 | # Standardization and set invalid points to zero
129 | mean = self.mean[None, self.inputs, None, None, None]
130 | std = self.std[None, self.inputs, None, None, None]
131 | x = torch.true_divide(x - mean, std)
132 | x = x * past_mask
133 |
134 | skip_list = []
135 | x = x.view(batch_size, n_inputs, n_past_steps, H, W) # [B, C, T, H, W]
136 |
137 | x = self.input_layer(x)
138 | for layer in self.DownLayers:
139 | x = layer(x)
140 | if layer.skip:
141 | skip_list.append(x.clone())
142 |
143 | # [B, C, T, H=16, W=128]
144 | x_h = self.Transformer_H(x)
145 | x_w = self.Transformer_W(x)
146 | x = torch.cat((x_h, x_w), dim=1)
147 | x = self.channel_attention(x)
148 | x = self.mid_layer(x)
149 |
150 | for layer in self.UpLayers:
151 | if layer.skip:
152 | x = layer(x, skip_list.pop())
153 | else:
154 | x = layer(x)
155 |
156 | x = self.output_layer(x)
157 |
158 | output = {}
159 | output["rv"] = self.min_range + nn.Sigmoid()(x[:, 0, :, :, :]) * (
160 | self.max_range - self.min_range
161 | )
162 | output["mask_logits"] = x[:, 1, :, :, :]
163 |
164 | return output
165 |
--------------------------------------------------------------------------------
/pcpnet/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__init__.py
--------------------------------------------------------------------------------
/pcpnet/models/__pycache__/PCPNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/PCPNet.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/models/__pycache__/PPT.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/PPT.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/models/__pycache__/TCNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/TCNet.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/models/__pycache__/base.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/base.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/models/__pycache__/layers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/layers.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/models/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/models/__pycache__/transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/models/__pycache__/transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/models/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Base Lightning Module of PCPNet
6 | import os
7 | import time
8 | import yaml
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from pytorch_lightning.core.lightning import LightningModule
15 | from pcpnet.models.loss import Loss
16 | from pcpnet.utils.projection import projection
17 | from pcpnet.utils.logger import log_point_clouds, save_range_mask_and_semantic, save_point_clouds
18 |
19 |
20 | class BasePredictionModel(LightningModule):
21 | """Pytorch Lightning base model for point cloud prediction"""
22 |
23 | def __init__(self, cfg):
24 | """Init base model
25 |
26 | Args:
27 | cfg (dict): Config parameters
28 | """
29 | super(BasePredictionModel, self).__init__()
30 | self.cfg = cfg
31 | self.save_hyperparameters(self.cfg)
32 |
33 | self.height = self.cfg["DATA_CONFIG"]["HEIGHT"]
34 | self.width = self.cfg["DATA_CONFIG"]["WIDTH"]
35 | self.min_range = self.cfg["DATA_CONFIG"]["MIN_RANGE"]
36 | self.max_range = self.cfg["DATA_CONFIG"]["MAX_RANGE"]
37 | self.register_buffer("mean", torch.Tensor(self.cfg["DATA_CONFIG"]["MEAN"]))
38 | self.register_buffer("std", torch.Tensor(self.cfg["DATA_CONFIG"]["STD"]))
39 |
40 | self.n_past_steps = self.cfg["MODEL"]["N_PAST_STEPS"]
41 | self.n_future_steps = self.cfg["MODEL"]["N_FUTURE_STEPS"]
42 | self.use_xyz = self.cfg["MODEL"]["USE"]["XYZ"]
43 | self.use_intensity = self.cfg["MODEL"]["USE"]["INTENSITY"]
44 |
45 | # Create list of index used in input
46 | self.inputs = [0]
47 | if self.use_xyz:
48 | self.inputs.append(1)
49 | self.inputs.append(2)
50 | self.inputs.append(3)
51 | if self.use_intensity:
52 | self.inputs.append(4)
53 | self.n_inputs = len(self.inputs)
54 |
55 | # Init loss
56 | self.loss = Loss(self.cfg)
57 |
58 |         # Init projection class for re-projecting from range images to 3D point clouds
59 | self.projection = projection(self.cfg)
60 |
61 | self.chamfer_distances_tensor = torch.zeros(self.n_future_steps, 1)
62 |
63 | def forward(self, x):
64 | pass
65 |
66 | def configure_optimizers(self):
67 | """Optimizers"""
68 | # optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg["TRAIN"]["LR"])
69 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()),
70 | lr=self.cfg["TRAIN"]["LR"])
71 | scheduler = torch.optim.lr_scheduler.StepLR(
72 | optimizer,
73 | step_size=self.cfg["TRAIN"]["LR_EPOCH"],
74 | gamma=self.cfg["TRAIN"]["LR_DECAY"],
75 | )
76 | return [optimizer], [scheduler]
77 |
78 | def on_load_checkpoint(self, checkpoint):
79 |         # Re-insert the frozen semantic network weights (removed on save) so the checkpoint loads cleanly
80 | for param_name in self.loss.state_dict():
81 | checkpoint['state_dict']['loss.' + param_name] = self.loss.state_dict()[param_name]
82 |
83 | def on_save_checkpoint(self, checkpoint):
84 |         # Drop the frozen semantic network weights so they are not stored in the checkpoint
85 | for param_name in self.loss.state_dict():
86 | del checkpoint['state_dict']['loss.' + param_name]
87 |
88 | def training_step(self, batch, batch_idx):
89 | """Pytorch Lightning training step including logging
90 |
91 | Args:
92 | batch (dict): A dict with a batch of training samples
93 | batch_idx (int): Index of batch in dataset
94 |
95 | Returns:
96 | loss (dict): Multiple loss components
97 | """
98 | past = batch["past_data"]
99 | future = batch["fut_data"]
100 | output = self.forward(past)
101 | loss = self.loss(output, future, "train", self.current_epoch)
102 |
103 | self.log("train/loss", loss["loss"])
104 | self.log("train/mean_chamfer_distance", loss["mean_chamfer_distance"])
105 | self.log("train/final_chamfer_distance", loss["final_chamfer_distance"])
106 | self.log("train/loss_range_view", loss["loss_range_view"])
107 | self.log("train/loss_mask", loss["loss_mask"])
108 | self.log("train/loss_semantic", loss["loss_semantic"])
109 |
110 | return loss
111 |
112 | def validation_step(self, batch, batch_idx):
113 | """Pytorch Lightning validation step including logging
114 |
115 | Args:
116 | batch (dict): A dict with a batch of validation samples
117 | batch_idx (int): Index of batch in dataset
118 |
119 | Returns:
120 | None
121 | """
122 | past = batch["past_data"]
123 | future = batch["fut_data"]
124 | output = self.forward(past)
125 |
126 | loss, learning_map_inv, color_map, output_argmax, target_argmax, label, \
127 | rand_t, similarity_output, similarity_target = self.loss(
128 | output, future, "val", self.current_epoch
129 | )
130 |
131 | self.log("val/loss", loss["loss"], on_epoch=True)
132 | self.log(
133 | "val/mean_chamfer_distance", loss["mean_chamfer_distance"], on_epoch=True
134 | )
135 | self.log(
136 | "val/final_chamfer_distance", loss["final_chamfer_distance"], on_epoch=True
137 | )
138 | self.log("val/loss_range_view", loss["loss_range_view"], on_epoch=True)
139 | self.log("val/loss_mask", loss["loss_mask"], on_epoch=True)
140 | self.log("val/loss_semantic", loss["loss_semantic"], on_epoch=True)
141 |
142 | selected_sequence_and_frame = self.cfg["VALIDATION"]["SELECTED_SEQUENCE_AND_FRAME"]
143 | sequence_batch, frame_batch = batch["meta"]
144 | for sample_idx in range(frame_batch.shape[0]):
145 | sequence = sequence_batch[sample_idx].item()
146 | frame = frame_batch[sample_idx].item()
147 | if sequence in selected_sequence_and_frame.keys():
148 | if frame in selected_sequence_and_frame[sequence]:
149 | log_point_clouds(
150 | self.logger.experiment,
151 | self.projection,
152 | self.current_epoch,
153 | batch,
154 | output,
155 | sample_idx,
156 | sequence,
157 | frame,
158 | )
159 | save_range_mask_and_semantic(
160 | self.cfg,
161 | self.projection,
162 | batch,
163 | output,
164 | sample_idx,
165 | sequence,
166 | frame,
167 | learning_map_inv,
168 | color_map,
169 | output_argmax,
170 | target_argmax,
171 | label,
172 | rand_t
173 | )
174 |
175 | def test_step(self, batch, batch_idx):
176 | """Pytorch Lightning test step including logging
177 |
178 | Args:
179 | batch (dict): A dict with a batch of test samples
180 | batch_idx (int): Index of batch in dataset
181 |
182 | Returns:
183 | loss (dict): Multiple loss components
184 | """
185 | past = batch["past_data"]
186 | future = batch["fut_data"]
187 |
188 |         batch_size, n_inputs, n_past_steps, H, W = past.shape
189 |
190 | start = time.time()
191 | output = self.forward(past)
192 | inference_time = (time.time() - start) / batch_size
193 | self.log("test/inference_time", inference_time, on_epoch=True)
194 |
195 | if not self.cfg["TEST"]["ONLY_SAVE_POINT_CLOUDS"]:
196 | loss, learning_map_inv, color_map, output_argmax, target_argmax, label, \
197 | rand_t, similarity_output, similarity_target = self.loss(
198 | output, future, "test", self.current_epoch
199 | )
200 |
201 | self.log("test/loss_range_view", loss["loss_range_view"], on_epoch=True)
202 | self.log("test/loss_mask", loss["loss_mask"], on_epoch=True)
203 | self.log("test/loss_semantic", loss["loss_semantic"], on_epoch=True)
204 | self.log("test/similarity_output", similarity_output.detach(), on_epoch=True)
205 | self.log("test/similarity_target", similarity_target.detach(), on_epoch=True)
206 |
207 | for step, value in loss["chamfer_distance"].items():
208 | self.log("test/chamfer_distance_{:d}".format(step), value, on_epoch=True)
209 |
210 | self.log(
211 | "test/mean_chamfer_distance", loss["mean_chamfer_distance"], on_epoch=True
212 | )
213 | self.log(
214 | "test/final_chamfer_distance", loss["final_chamfer_distance"], on_epoch=True
215 | )
216 |
217 | self.chamfer_distances_tensor = torch.cat(
218 | (self.chamfer_distances_tensor, loss["chamfer_distances_tensor"]), dim=1
219 | )
220 |
221 | selected_sequence_and_frame = self.cfg["TEST"]["SELECTED_SEQUENCE_AND_FRAME"]
222 | sequence_batch, frame_batch = batch["meta"]
223 | for sample_idx in range(frame_batch.shape[0]):
224 | sequence = sequence_batch[sample_idx].item()
225 | frame = frame_batch[sample_idx].item()
226 | if sequence in selected_sequence_and_frame.keys():
227 | if frame in selected_sequence_and_frame[sequence]:
228 | if self.logger:
229 | log_point_clouds(
230 | self.logger.experiment,
231 | self.projection,
232 | self.current_epoch,
233 | batch,
234 | output,
235 | sample_idx,
236 | sequence,
237 | frame,
238 | )
239 | save_range_mask_and_semantic(
240 | self.cfg,
241 | self.projection,
242 | batch,
243 | output,
244 | sample_idx,
245 | sequence,
246 | frame,
247 | learning_map_inv,
248 | color_map,
249 | output_argmax,
250 | target_argmax,
251 | label,
252 | rand_t,
253 | test=True
254 | )
255 | else:
256 | loss = None
257 |
258 | if self.cfg["TEST"]["SAVE_POINT_CLOUDS"] or self.cfg["TEST"]["ONLY_SAVE_POINT_CLOUDS"]:
259 | save_point_clouds(self.cfg, self.projection, batch, output)
260 |
261 | return loss
262 |
263 | def test_epoch_end(self, outputs):
264 | # Remove first row since it was initialized with zero
265 | self.chamfer_distances_tensor = self.chamfer_distances_tensor[:, 1:]
266 | n_steps, _ = self.chamfer_distances_tensor.shape
267 | mean = torch.mean(self.chamfer_distances_tensor, dim=1)
268 | std = torch.std(self.chamfer_distances_tensor, dim=1)
269 | q = torch.tensor([0.25, 0.5, 0.75])
270 | quantile = torch.quantile(self.chamfer_distances_tensor, q, dim=1)
271 |
272 | chamfer_distances = []
273 | for s in range(n_steps):
274 | chamfer_distances.append(self.chamfer_distances_tensor[s, :].tolist())
275 | print("Final size of CD: ", self.chamfer_distances_tensor.shape)
276 | print("Mean :", mean)
277 | print("Std :", std)
278 | print("Quantile :", quantile)
279 |
280 | testdir = os.path.join(self.cfg["LOG_DIR"], self.cfg["TEST"]["DIR_NAME"])
281 | if not os.path.exists(testdir):
282 | os.makedirs(testdir)
283 |
284 | filename = os.path.join(
285 | testdir, "stats" + ".yaml"
286 | )
287 |
288 | log_to_save = {
289 | "mean": mean.tolist(),
290 | "std": std.tolist(),
291 | "quantile": quantile.tolist(),
292 | "chamfer_distances": chamfer_distances,
293 | }
294 | with open(filename, "w") as yaml_file:
295 | yaml.dump(log_to_save, yaml_file, default_flow_style=False)
296 |
--------------------------------------------------------------------------------
/pcpnet/models/loss.py:
--------------------------------------------------------------------------------
1 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
2 | # This file is covered by the LICENSE file in the root of the project PCPNet:
3 | # https://github.com/Blurryface0814/PCPNet
4 | # Brief: Implementation of the Loss Modules
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | import yaml
10 | import math
11 | import random
12 |
13 | from pyTorchChamferDistance.chamfer_distance import ChamferDistance
14 | from pcpnet.utils.projection import projection
15 | from pcpnet.utils.logger import map
16 | import importlib
17 | import numpy as np
18 |
19 |
20 | class Loss(nn.Module):
21 | """Combined loss for point cloud prediction"""
22 |
23 | def __init__(self, cfg):
24 | """Init"""
25 | super().__init__()
26 | self.cfg = cfg
27 | self.n_future_steps = self.cfg["MODEL"]["N_FUTURE_STEPS"]
28 | self.loss_weight_cd = self.cfg["TRAIN"]["LOSS_WEIGHT_CHAMFER_DISTANCE"]
29 | self.loss_weight_rv = self.cfg["TRAIN"]["LOSS_WEIGHT_RANGE_VIEW"]
30 | self.loss_weight_mask = self.cfg["TRAIN"]["LOSS_WEIGHT_MASK"]
31 | self.loss_weight_semantic = self.cfg["TRAIN"]["LOSS_WEIGHT_SEMANTIC"]
32 | self.cd_every_n_val_epoch = self.cfg["TRAIN"]["CHAMFER_DISTANCE_EVERY_N_VAL_EPOCH"]
33 |
34 | self.loss_range = loss_range(self.cfg)
35 | self.chamfer_distance = chamfer_distance(self.cfg)
36 | self.loss_mask = loss_mask(self.cfg)
37 | self.loss_semantic = loss_semantic(self.cfg)
38 |
39 | def forward(self, output, target, mode, current_epoch=None):
40 | """Forward pass with multiple loss components
41 |
42 | Args:
43 | output (dict): Predicted mask logits and ranges
44 | target (torch.tensor): Target range image
45 | mode (str): Mode (train,val,test)
46 |
47 | Returns:
48 | dict: Dict with loss components
49 | """
50 | cd_flag = ((current_epoch + 1) % self.cd_every_n_val_epoch == 0)
51 |
52 | target_range_image = target[:, 0, :, :, :]
53 | target_label = target[:, 5, :, :, :]
54 |
55 | # Range view
56 | loss_range_view = self.loss_range(output, target_range_image)
57 |
58 | # Mask
59 | loss_mask = self.loss_mask(output, target_range_image)
60 |
61 | # Semantic
62 | # if self.loss_weight_semantic > 0.0 or mode == "val" or mode == "test":
63 | if self.loss_weight_semantic > 0.0:
64 | loss_semantic, output_argmax, target_argmax, label, rand_t, \
65 | similarity_output, similarity_target = self.loss_semantic(
66 | output, target_range_image, target_label, mode
67 | )
68 | else:
69 | loss_semantic = torch.zeros(1).type_as(target_range_image)
70 | B, T, H, W = target_range_image.shape
71 | output_argmax = torch.zeros((B, 1, H, W)).type_as(target_range_image)
72 | target_argmax = output_argmax
73 | label = output_argmax
74 | rand_t = "none"
75 | similarity_output = torch.zeros(1).type_as(target_range_image)
76 | similarity_target = torch.zeros(1).type_as(target_range_image)
77 |
78 | # Chamfer Distance
79 | if self.loss_weight_cd > 0.0 or (mode == "val" and cd_flag) or mode == "test":
80 | chamfer_distance, chamfer_distances_tensor = self.chamfer_distance(
81 | output, target, self.cfg["TEST"]["N_DOWNSAMPLED_POINTS_CD"]
82 | )
83 | loss_chamfer_distance = sum([cd for cd in chamfer_distance.values()]) / len(
84 | chamfer_distance
85 | )
86 | detached_chamfer_distance = {
87 | step: cd.detach() for step, cd in chamfer_distance.items()
88 | }
89 | else:
90 | chamfer_distance = dict(
91 | (step, torch.zeros(1).type_as(target_range_image))
92 | for step in range(self.n_future_steps)
93 | )
94 | chamfer_distances_tensor = torch.zeros(self.n_future_steps, 1)
95 | loss_chamfer_distance = torch.zeros_like(loss_range_view)
96 | detached_chamfer_distance = chamfer_distance
97 |
98 | loss = (
99 | self.loss_weight_cd * loss_chamfer_distance
100 | + self.loss_weight_rv * loss_range_view
101 | + self.loss_weight_mask * loss_mask
102 | + self.loss_weight_semantic * loss_semantic
103 | )
104 |
105 | loss_dict = {
106 | "loss": loss,
107 | "chamfer_distance": detached_chamfer_distance,
108 | "chamfer_distances_tensor": chamfer_distances_tensor.detach(),
109 | "mean_chamfer_distance": loss_chamfer_distance.detach(),
110 | "final_chamfer_distance": chamfer_distance[
111 | self.n_future_steps - 1
112 | ].detach(),
113 | "loss_range_view": loss_range_view.detach(),
114 | "loss_mask": loss_mask.detach(),
115 | "loss_semantic": loss_semantic.detach()
116 | }
117 |
118 | if mode == "val" or mode == "test":
119 | return loss_dict, \
120 | self.loss_semantic.learning_map_inv, \
121 | self.loss_semantic.color_map, \
122 | output_argmax, \
123 | target_argmax, \
124 | label, \
125 | str(rand_t), \
126 | similarity_output, \
127 | similarity_target
128 |
129 | return loss_dict
130 |
131 |
132 | class loss_mask(nn.Module):
133 | """Binary cross entropy loss for prediction of valid mask"""
134 |
135 | def __init__(self, cfg):
136 | super().__init__()
137 | self.cfg = cfg
138 | self.loss = nn.BCEWithLogitsLoss(reduction="mean")
139 | self.projection = projection(self.cfg)
140 |
141 | def forward(self, output, target_range_view):
142 | target_mask = self.projection.get_target_mask_from_range_view(target_range_view)
143 | loss = self.loss(output["mask_logits"], target_mask)
144 | return loss
145 |
146 |
147 | class loss_range(nn.Module):
148 | """L1 loss for range image prediction"""
149 |
150 | def __init__(self, cfg):
151 | super().__init__()
152 | self.cfg = cfg
153 | self.loss = nn.L1Loss(reduction="mean")
154 |
155 | def forward(self, output, target_range_image):
156 | # Do not count L1 loss for invalid GT points
157 | gt_masked_output = output["rv"].clone()
158 | gt_masked_output[target_range_image == -1.0] = -1.0
159 |
160 | loss = self.loss(gt_masked_output, target_range_image)
161 | return loss
162 |
163 |
164 | class loss_semantic(nn.Module):
165 | """Semantic loss for prediction of rv"""
166 |
167 | def __init__(self, cfg):
168 | super().__init__()
169 | self.cfg = cfg
170 | self.semantic_net = self.cfg["MODEL"]["SEMANTIC_NET"]
171 | self.semantic_data_config = self.cfg["MODEL"]["SEMANTIC_DATA_CONFIG"]
172 | self.semantic_pretrained = self.cfg["MODEL"]["SEMANTIC_PRETRAINED_MODEL"]
173 | self.projection = projection(self.cfg)
174 |
175 | pretrained = "semantic_net/" + self.semantic_net + "/model/" + self.semantic_pretrained
176 | config_base_path = "semantic_net/" + self.semantic_net + "/tasks/semantic/config/"
177 | arch_config = pretrained + "/arch_cfg.yaml"
178 | data_config = config_base_path + "labels/" + self.semantic_data_config
179 |
180 | self.DATA = yaml.safe_load(open(data_config, 'r'))
181 | self.n_classes = len(self.DATA["learning_map_inv"])
182 | self.learning_map_inv = self.DATA["learning_map_inv"]
183 | self.learning_map = self.DATA["learning_map"]
184 | self.color_map = self.DATA["color_map"]
185 |
186 | module = "semantic_net." + self.semantic_net + ".tasks.semantic.modules.segmentator"
187 | segmentator_module = importlib.import_module(module)
188 |
189 | self.loss = nn.L1Loss(reduction="mean")
190 |
191 | if cfg["TRAIN"]["LOSS_WEIGHT_SEMANTIC"] > 0.0:
192 | self.ARCH = yaml.safe_load(open(arch_config, 'r'))
193 | self.sensor = self.ARCH["dataset"]["sensor"]
194 | self.sensor_img_means = torch.tensor(self.sensor["img_means"], dtype=torch.float)
195 | self.sensor_img_stds = torch.tensor(self.sensor["img_stds"], dtype=torch.float)
196 | with torch.no_grad():
197 | self.semantic_model = segmentator_module.Segmentator(self.ARCH, self.n_classes, pretrained)
198 | self.semantic_model.eval()
199 |
200 | # Semantic Similarity
201 | self.criterion = nn.NLLLoss()
202 | self.semantic_similarity = self.cfg["TEST"]["SEMANTIC_SIMILARITY"]
203 |
204 | def forward(self, output, target_range_image, target_label, mode=None):
205 | B, T, H, W = output["rv"].shape
206 | sensor_img_means = self.sensor_img_means.type_as(target_range_image)
207 | sensor_img_stds = self.sensor_img_stds.type_as(target_range_image)
208 |
209 | label = target_label
210 | output = output["rv"].unsqueeze(2) # [B, T, 1, H, W]
211 | target = target_range_image.unsqueeze(2)
212 |
213 | output_mask = output != -1.0
214 | target_mask = target != -1.0
215 |
216 | output_masked = (output - sensor_img_means[None, None, 0:1, None, None]
217 | ) / sensor_img_stds[None, None, 0:1, None, None]
218 |
219 | target_masked = (target - sensor_img_means[None, None, 0:1, None, None]
220 | ) / sensor_img_stds[None, None, 0:1, None, None]
221 |
222 | output_masked = output_masked * output_mask
223 | target_masked = target_masked * target_mask
224 |
225 | rand_t = random.randint(0, T-1)
226 | output_t = self.semantic_model(output_masked[:, rand_t, :, :, :]) # [B, 20, H, W]
227 | target_t = self.semantic_model(target_masked[:, rand_t, :, :, :])
228 | label = label[:, rand_t, :, :]
229 |
230 | # get labels from semantic model output
231 | output_argmax = output_t.argmax(dim=1) # [B, H, W]
232 | target_argmax = target_t.argmax(dim=1)
233 | label = map(label.cpu().numpy().astype(np.int32), self.learning_map)
234 | label = torch.tensor(label).type_as(target_label)
235 |
236 | # Do not count L1 loss for invalid GT points
237 | target_mask = target[:, rand_t, :, :, :].repeat(1, 20, 1, 1)
238 | output_last = output_t.clone()
239 | target_last = target_t.clone()
240 | output_last[target_mask == -1.0] = -1.0
241 | target_last[target_mask == -1.0] = -1.0
242 |
243 | # set argmax image for visualization
244 | output_argmax[target_last[:, 0, :, :] == -1.0] = 0.0
245 | target_argmax[target_last[:, 0, :, :] == -1.0] = 0.0
246 |
247 | # Semantic Similarity
248 | if mode == "test" and self.semantic_similarity:
249 | similarity_output = self.n_classes / self.criterion(torch.log(output_t.clamp(min=1e-8)), label.long())
250 | similarity_target = self.n_classes / self.criterion(torch.log(target_t.clamp(min=1e-8)), label.long())
251 | else:
252 | similarity_output = torch.zeros(1).type_as(output_t)
253 | similarity_target = torch.zeros(1).type_as(output_t)
254 |
255 | loss = self.loss(output_last, target_last)
256 |
257 | return loss, output_argmax.cpu().numpy(), target_argmax.cpu().numpy(), \
258 | label.cpu().numpy(), rand_t, similarity_output, similarity_target
259 |
260 |
261 | class chamfer_distance(nn.Module):
262 | """Chamfer distance loss. Additionally, the implementation allows the evaluation
263 | on downsampled point clouds (this is only for comparison to other methods and is not recommended,
264 | because it is a poor approximation of the real Chamfer distance).
265 | """
266 |
267 | def __init__(self, cfg):
268 | super().__init__()
269 | self.cfg = cfg
270 | self.loss = ChamferDistance()
271 | self.projection = projection(self.cfg)
272 |
273 | def forward(self, output, target, n_samples):
274 | batch_size, n_future_steps, H, W = output["rv"].shape
275 | masked_output = self.projection.get_masked_range_view(output)
276 | chamfer_distances = {}
277 | chamfer_distances_tensor = torch.zeros(n_future_steps, batch_size)
278 | for s in range(n_future_steps):
279 | chamfer_distances[s] = 0
280 | for b in range(batch_size):
281 | output_points = self.projection.get_valid_points_from_range_view(
282 | masked_output[b, s, :, :]
283 | ).view(1, -1, 3)
284 | target_points = target[b, 1:4, s, :, :].permute(1, 2, 0)
285 | target_points = target_points[target[b, 0, s, :, :] > 0.0].view(
286 | 1, -1, 3
287 | )
288 |
289 | if n_samples != -1:
290 | n_output_points = output_points.shape[1]
291 | n_target_points = target_points.shape[1]
292 |
293 | sampled_output_indices = random.sample(
294 | range(n_output_points), n_samples
295 | )
296 | sampled_target_indices = random.sample(
297 | range(n_target_points), n_samples
298 | )
299 |
300 | output_points = output_points[:, sampled_output_indices, :]
301 | target_points = target_points[:, sampled_target_indices, :]
302 |
303 | dist1, dist2 = self.loss(output_points, target_points)
304 | dist_combined = torch.mean(dist1) + torch.mean(dist2)
305 | chamfer_distances[s] += dist_combined
306 | chamfer_distances_tensor[s, b] = dist_combined
307 | chamfer_distances[s] = chamfer_distances[s] / batch_size
308 | return chamfer_distances, chamfer_distances_tensor
309 |
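# --- Illustrative sketch (not part of the original file) ---
# Minimal example of the -1.0 masking convention used by loss_range above:
# pixels whose ground-truth range is -1.0 (no LiDAR return) are copied into
# the prediction before the L1 loss, so they contribute zero error.
# The tensors and values below are made up purely for illustration.
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    example_target = torch.tensor([[10.0, -1.0], [5.0, 20.0]])  # -1.0 marks invalid pixels
    example_pred = torch.tensor([[11.0, 7.0], [5.0, 18.0]])

    gt_masked_pred = example_pred.clone()
    gt_masked_pred[example_target == -1.0] = -1.0  # ignore invalid GT pixels

    example_loss = nn.L1Loss(reduction="mean")(gt_masked_pred, example_target)
    # Masked pixel contributes 0; mean over all 4 pixels: (1 + 0 + 0 + 2) / 4 = 0.75
    print(example_loss.item())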
--------------------------------------------------------------------------------
/pcpnet/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__init__.py
--------------------------------------------------------------------------------
/pcpnet/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/utils/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/utils/__pycache__/preprocess_data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/preprocess_data.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/utils/__pycache__/projection.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/projection.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/utils/__pycache__/visualization.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pcpnet/utils/__pycache__/visualization.cpython-38.pyc
--------------------------------------------------------------------------------
/pcpnet/utils/preprocess_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Preprocessing point cloud to range images
6 | import os
7 | import numpy as np
8 | import torch
9 |
10 | from pcpnet.utils.utils import load_files, range_projection
11 |
12 |
13 | def prepare_data(cfg, dataset_path, rawdata_path):
14 | """Loads point clouds and labels and pre-processes them into range images
15 |
16 | Args:
17 | cfg (dict): Config
18 | """
19 | sequences = (
20 | cfg["DATA_CONFIG"]["SPLIT"]["TRAIN"]
21 | + cfg["DATA_CONFIG"]["SPLIT"]["VAL"]
22 | + cfg["DATA_CONFIG"]["SPLIT"]["TEST"]
23 | )
24 |
25 | proj_H = cfg["DATA_CONFIG"]["HEIGHT"]
26 | proj_W = cfg["DATA_CONFIG"]["WIDTH"]
27 |
28 | for seq in sequences:
29 | seqstr = "{0:02d}".format(int(seq))
30 | scan_folder = os.path.join(rawdata_path, seqstr, "velodyne")
31 | label_folder = os.path.join(rawdata_path, seqstr, "labels")
32 | dst_folder = os.path.join(dataset_path, seqstr, "processed")
33 | if not os.path.exists(dst_folder):
34 | os.makedirs(dst_folder)
35 |
36 | # Load LiDAR scan files and label files
37 | scan_paths = load_files(scan_folder)
38 | label_paths = load_files(label_folder)
39 |
40 | if len(scan_paths) != len(label_paths):
41 | print("Points files: ", len(scan_paths))
42 | print("Label files: ", len(label_paths))
43 | raise ValueError("Scan and Label don't contain same number of files")
44 |
45 | # Iterate over all scan files and label files
46 | for idx in range(len(scan_paths)):
47 | print(
48 | "Processing file {:d}/{:d} of sequence {:d}".format(
49 | idx, len(scan_paths), seq
50 | )
51 | )
52 |
53 | # Load and project a point cloud
54 | current_vertex = np.fromfile(scan_paths[idx], dtype=np.float32)
55 | current_vertex = current_vertex.reshape((-1, 4))
56 | label = np.fromfile(label_paths[idx], dtype=np.int32)
57 | label = label.reshape((-1))
58 | # only fill in attribute if the right size
59 | if current_vertex.shape[0] != label.shape[0]:
60 | print("Points shape: ", current_vertex.shape[0])
61 | print("Label shape: ", label.shape[0])
62 | raise ValueError("Scan and Label don't contain same number of points")
63 |
64 | proj_range, proj_vertex, proj_intensity, proj_idx = range_projection(
65 | current_vertex,
66 | fov_up=cfg["DATA_CONFIG"]["FOV_UP"],
67 | fov_down=cfg["DATA_CONFIG"]["FOV_DOWN"],
68 | proj_H=cfg["DATA_CONFIG"]["HEIGHT"],
69 | proj_W=cfg["DATA_CONFIG"]["WIDTH"],
70 | max_range=cfg["DATA_CONFIG"]["MAX_RANGE"],
71 | )
72 |
73 | # Save range
74 | dst_path_range = os.path.join(dst_folder, "range")
75 | if not os.path.exists(dst_path_range):
76 | os.makedirs(dst_path_range)
77 | file_path = os.path.join(dst_path_range, str(idx).zfill(6))
78 | np.save(file_path, proj_range)
79 |
80 | # Save xyz
81 | dst_path_xyz = os.path.join(dst_folder, "xyz")
82 | if not os.path.exists(dst_path_xyz):
83 | os.makedirs(dst_path_xyz)
84 | file_path = os.path.join(dst_path_xyz, str(idx).zfill(6))
85 | np.save(file_path, proj_vertex)
86 |
87 | # Save intensity
88 | dst_path_intensity = os.path.join(dst_folder, "intensity")
89 | if not os.path.exists(dst_path_intensity):
90 | os.makedirs(dst_path_intensity)
91 | file_path = os.path.join(dst_path_intensity, str(idx).zfill(6))
92 | np.save(file_path, proj_intensity)
93 |
94 | # Save label
95 | sem_label = label & 0xFFFF # semantic label in lower half
96 | inst_label = label >> 16 # instance id in upper half
97 | # sanity check
98 | assert ((sem_label + (inst_label << 16) == label).all())
99 |
100 | proj_sem_label = np.zeros((proj_H, proj_W), dtype=np.int32) # [H,W] label
101 | mask = proj_idx >= 0
102 | proj_sem_label[mask] = sem_label[proj_idx[mask]]
103 |
104 | dst_path_label = os.path.join(dst_folder, "labels")
105 | if not os.path.exists(dst_path_label):
106 | os.makedirs(dst_path_label)
107 | file_path = os.path.join(dst_path_label, str(idx).zfill(6))
108 | np.save(file_path, proj_sem_label)
109 |
110 |
111 | def compute_mean_and_std(cfg, train_loader):
112 | """Compute training data statistics
113 |
114 | Args:
115 | cfg (dict): Config
116 | train_loader (DataLoader): Pytorch DataLoader to access training data
117 | """
118 | n_channels = train_loader.dataset.n_channels
119 | mean = [0] * n_channels
120 | std = [0] * n_channels
121 | max = [0] * n_channels
122 | min = [0] * n_channels
123 | for i, data in enumerate(train_loader):
124 | past = data["past_data"]
125 | batch_size, n_channels, frames, H, W = past.shape
126 |
127 | for j in range(n_channels):
128 | channel = past[:, j, :, :, :].view(batch_size, 1, frames, H, W)
129 | mean[j] += torch.mean(channel[channel != -1.0]) / len(train_loader)
130 | std[j] += torch.std(channel[channel != -1.0]) / len(train_loader)
131 | max[j] += torch.max(channel[channel != -1.0]) / len(train_loader)
132 | min[j] += torch.min(channel[channel != -1.0]) / len(train_loader)
133 |
134 | print("Mean and standard deviation of training data:")
135 | for j in range(n_channels):
136 | print(
137 | "Input {:d}: Mean {:.3f}, std {:.3f}, min {:.3f}, max {:.3f}".format(
138 | j, mean[j], std[j], min[j], max[j]
139 | )
140 | )
141 |
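# --- Illustrative sketch (not part of the original file) ---
# Minimal example of reading back the arrays written by prepare_data() above.
# The directory layout mirrors the np.save calls in prepare_data(); the
# dataset path, sequence, and frame index below are placeholders.
if __name__ == "__main__":
    dataset_path = "/path/to/processed/dataset"  # hypothetical
    dst_folder = os.path.join(dataset_path, "00", "processed")
    frame = str(0).zfill(6)

    if os.path.isdir(dst_folder):
        proj_range = np.load(os.path.join(dst_folder, "range", frame + ".npy"))   # (H, W), -1 = no data
        proj_vertex = np.load(os.path.join(dst_folder, "xyz", frame + ".npy"))    # (H, W, 4), (x, y, z, 1)
        proj_label = np.load(os.path.join(dst_folder, "labels", frame + ".npy"))  # (H, W) semantic labels
        print(proj_range.shape, proj_vertex.shape, proj_label.shape)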
--------------------------------------------------------------------------------
/pcpnet/utils/projection.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Get a 3D point cloud from a given range image projection
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 |
10 |
11 | class projection:
12 | """Projection class for getting a 3D point cloud from range images"""
13 |
14 | def __init__(self, cfg):
15 | """Init
16 |
17 | Args:
18 | cfg (dict): Parameters
19 | """
20 | self.cfg = cfg
21 |
22 | fov_up = (
23 | self.cfg["DATA_CONFIG"]["FOV_UP"] / 180.0 * np.pi
24 | ) # field of view up in radians
25 | fov_down = (
26 | self.cfg["DATA_CONFIG"]["FOV_DOWN"] / 180.0 * np.pi
27 | ) # field of view down in radians
28 | fov = abs(fov_down) + abs(fov_up) # get field of view total in radians
29 | W = self.cfg["DATA_CONFIG"]["WIDTH"]
30 | H = self.cfg["DATA_CONFIG"]["HEIGHT"]
31 |
32 | h = torch.arange(0, H).view(H, 1).repeat(1, W)
33 | w = torch.arange(0, W).view(1, W).repeat(H, 1)
34 | yaw = np.pi * (1.0 - 2 * torch.true_divide(w, W))
35 | pitch = (1.0 - torch.true_divide(h, H)) * fov - abs(fov_down)
36 | self.x_fac = torch.cos(pitch) * torch.cos(yaw)
37 | self.y_fac = torch.cos(pitch) * torch.sin(yaw)
38 | self.z_fac = torch.sin(pitch)
39 |
40 | def get_valid_points_from_range_view(self, range_view):
41 | """Reproject from range image to valid 3D points
42 |
43 | Args:
44 | range_view (torch.tensor): Range image with size (H,W)
45 |
46 | Returns:
47 | torch.tensor: Valid 3D points with size (N,3)
48 | """
49 | H, W = range_view.shape
50 | points = torch.zeros(H, W, 3).type_as(range_view)
51 | points[:, :, 0] = range_view * self.x_fac.type_as(range_view)
52 | points[:, :, 1] = range_view * self.y_fac.type_as(range_view)
53 | points[:, :, 2] = range_view * self.z_fac.type_as(range_view)
54 | return points[range_view > 0.0]
55 |
56 | def get_xyz_from_range_view(self, range_view):
57 | """Reproject from range image to xyz
58 |
59 | Args:
60 | range_view (torch.tensor): Range image with size (B, T, H, W)
61 |
62 | Returns:
63 | torch.tensor: Valid 3D points with size (B, T, 3, H, W)
64 | """
65 | B, T, H, W = range_view.shape
66 | points = torch.zeros(B, T, H, W, 3).type_as(range_view)
67 | points[:, :, :, :, 0] = range_view * self.x_fac.type_as(range_view)
68 | points[:, :, :, :, 1] = range_view * self.y_fac.type_as(range_view)
69 | points[:, :, :, :, 2] = range_view * self.z_fac.type_as(range_view)
70 | return points.permute(0, 1, 4, 2, 3)
71 |
72 | def get_mask_from_output(self, output):
73 | """Get mask from logits
74 |
75 | Args:
76 | output (dict): Output dict with mask_logits as key
77 |
78 | Returns:
79 | mask: Predicted mask containing per-point probabilities
80 | """
81 | mask = nn.Sigmoid()(output["mask_logits"])
82 | return mask
83 |
84 | def get_target_mask_from_range_view(self, range_view):
85 | """Ground truth mask
86 |
87 | Args:
88 | range_view (torch.tensor): Range image of size (H,W)
89 |
90 | Returns:
91 | torch.tensor: Target mask of valid points
92 | """
93 | target_mask = torch.zeros(range_view.shape).type_as(range_view)
94 | target_mask[range_view != -1.0] = 1.0
95 | return target_mask
96 |
97 | def get_masked_range_view(self, output):
98 | """Get predicted masked range image
99 |
100 | Args:
101 | output (dict): Dictionary containing predicted mask logits and ranges
102 |
103 | Returns:
104 | torch.tensor: Masked range image in which invalid points are mapped to -1.0
105 | """
106 | mask = self.get_mask_from_output(output)
107 | masked_range_view = output["rv"].clone()
108 |
109 | # Set invalid points to -1.0 according to mask
110 | masked_range_view[mask < self.cfg["MODEL"]["MASK_THRESHOLD"]] = -1.0
111 | return masked_range_view
112 |
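# --- Illustrative sketch (not part of the original file) ---
# Minimal example of back-projecting a range image with the projection class
# above. The config values mirror the range_projection() defaults in
# pcpnet/utils/utils.py (fov_up=3.0, fov_down=-25.0, 64x900 image) and are
# used here for illustration only.
if __name__ == "__main__":
    example_cfg = {
        "DATA_CONFIG": {"FOV_UP": 3.0, "FOV_DOWN": -25.0, "HEIGHT": 64, "WIDTH": 900}
    }
    proj = projection(example_cfg)

    range_view = torch.full((64, 900), -1.0)  # -1.0 = no data
    range_view[32, 450] = 10.0                # a single valid return at 10 m

    points = proj.get_valid_points_from_range_view(range_view)  # (N, 3), here N = 1
    print(points.shape, points.norm(dim=1))  # the point's norm recovers the 10 m range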
--------------------------------------------------------------------------------
/pcpnet/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Project a point cloud into a range image
6 | import os
7 | import math
8 | import numpy as np
9 | import random
10 | import torch
11 |
12 | def set_seed(seed):
13 | """
14 | Set random seed for torch, numpy and python
15 | """
16 | random.seed(seed)
17 | np.random.seed(seed)
18 | torch.manual_seed(seed)
19 | if torch.cuda.is_available():
20 | torch.cuda.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 |
23 | torch.backends.cudnn.benchmark = False
24 |
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | def load_poses(pose_path):
29 | """Load ground truth poses (T_w_cam0) from file.
30 |
31 | Args:
32 | pose_path: (Complete) filename for the pose file
33 |
34 | Returns:
35 | A numpy array of size nx4x4 with n poses as 4x4 transformation
36 | matrices
37 | """
38 | # Read and parse the poses
39 | poses = []
40 | try:
41 | if ".txt" in pose_path:
42 | with open(pose_path, "r") as f:
43 | lines = f.readlines()
44 | for line in lines:
45 | T_w_cam0 = np.fromstring(line, dtype=float, sep=" ")
46 | T_w_cam0 = T_w_cam0.reshape(3, 4)
47 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1]))
48 | poses.append(T_w_cam0)
49 | else:
50 | poses = np.load(pose_path)["arr_0"]
51 |
52 | except FileNotFoundError:
53 | print("Ground truth poses are not avaialble.")
54 |
55 | return np.array(poses)
56 |
57 |
58 | def load_calib(calib_path):
59 | """Load calibrations (T_cam_velo) from file."""
60 | # Read and parse the calibrations
61 | T_cam_velo = []
62 | try:
63 | with open(calib_path, "r") as f:
64 | lines = f.readlines()
65 | for line in lines:
66 | if "Tr:" in line:
67 | line = line.replace("Tr:", "")
68 | T_cam_velo = np.fromstring(line, dtype=float, sep=" ")
69 | T_cam_velo = T_cam_velo.reshape(3, 4)
70 | T_cam_velo = np.vstack((T_cam_velo, [0, 0, 0, 1]))
71 |
72 | except FileNotFoundError:
73 | print("Calibrations are not avaialble.")
74 |
75 | return np.array(T_cam_velo)
76 |
77 |
78 | def range_projection(
79 | current_vertex, fov_up=3.0, fov_down=-25.0, proj_H=64, proj_W=900, max_range=50
80 | ):
81 | """Project a pointcloud into a spherical projection, range image.
82 |
83 | Args:
84 | current_vertex: raw point clouds
85 |
86 | Returns:
87 | proj_range: projected range image with depth, each pixel contains the corresponding depth
88 | proj_vertex: each pixel contains the corresponding point (x, y, z, 1)
89 | proj_intensity: each pixel contains the corresponding intensity
90 | proj_idx: each pixel contains the corresponding index of the point in the raw point cloud
91 | """
92 | # laser parameters
93 | fov_up = fov_up / 180.0 * np.pi # field of view up in radians
94 | fov_down = fov_down / 180.0 * np.pi # field of view down in radians
95 | fov = abs(fov_down) + abs(fov_up) # get field of view total in radians
96 |
97 | # get depth of all points
98 | depth = np.linalg.norm(current_vertex[:, :3], 2, axis=1)
99 |
100 | # # we use a maximum range threshold
101 | # current_vertex = current_vertex[(depth > 0) & (depth < max_range)] # get rid of [0, 0, 0] points
102 | # depth = depth[(depth > 0) & (depth < max_range)]
103 |
104 | # get scan components
105 | scan_x = current_vertex[:, 0]
106 | scan_y = current_vertex[:, 1]
107 | scan_z = current_vertex[:, 2]
108 | intensity = current_vertex[:, 3]
109 |
110 | # get angles of all points
111 | yaw = -np.arctan2(scan_y, scan_x)
112 | pitch = np.arcsin(scan_z / depth)
113 |
114 | # get projections in image coords
115 | proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0]
116 | proj_y = 1.0 - (pitch + abs(fov_down)) / fov # in [0.0, 1.0]
117 |
118 | # scale to image size using angular resolution
119 | proj_x *= proj_W # in [0.0, W]
120 | proj_y *= proj_H # in [0.0, H]
121 |
122 | # round and clamp for use as index
123 | proj_x = np.floor(proj_x)
124 | proj_x = np.minimum(proj_W - 1, proj_x)
125 | proj_x = np.maximum(0, proj_x).astype(np.int32) # in [0,W-1]
126 |
127 | proj_y = np.floor(proj_y)
128 | proj_y = np.minimum(proj_H - 1, proj_y)
129 | proj_y = np.maximum(0, proj_y).astype(np.int32) # in [0,H-1]
130 |
131 | # order in decreasing depth
132 | order = np.argsort(depth)[::-1]
133 | depth = depth[order]
134 | intensity = intensity[order]
135 | proj_y = proj_y[order]
136 | proj_x = proj_x[order]
137 |
138 | scan_x = scan_x[order]
139 | scan_y = scan_y[order]
140 | scan_z = scan_z[order]
141 |
142 | indices = np.arange(depth.shape[0])
143 | indices = indices[order]
144 |
145 | proj_range = np.full(
146 | (proj_H, proj_W), -1, dtype=np.float32
147 | ) # [H,W] range (-1 is no data)
148 | proj_vertex = np.full(
149 | (proj_H, proj_W, 4), -1, dtype=np.float32
150 | ) # [H,W,4] vertex (-1 is no data)
151 | proj_idx = np.full(
152 | (proj_H, proj_W), -1, dtype=np.int32
153 | ) # [H,W] index (-1 is no data)
154 | proj_intensity = np.full(
155 | (proj_H, proj_W), -1, dtype=np.float32
156 | ) # [H,W] intensity (-1 is no data)
157 |
158 | proj_range[proj_y, proj_x] = depth
159 | proj_vertex[proj_y, proj_x] = np.array(
160 | [scan_x, scan_y, scan_z, np.ones(len(scan_x))]
161 | ).T
162 | proj_idx[proj_y, proj_x] = indices
163 | proj_intensity[proj_y, proj_x] = intensity
164 |
165 | return proj_range, proj_vertex, proj_intensity, proj_idx
166 |
167 |
168 | def gen_normal_map(current_range, current_vertex, proj_H=64, proj_W=900):
169 | """Generate a normal image given the range projection of a point cloud.
170 |
171 | Args:
172 | current_range: range projection of a point cloud, each pixel contains the corresponding depth
173 | current_vertex: range projection of a point cloud,
174 | each pixel contains the corresponding point (x, y, z, 1)
175 |
176 | Returns:
177 | normal_data: each pixel contains the corresponding normal
178 | """
179 | normal_data = np.full((proj_H, proj_W, 3), -1, dtype=np.float32)
180 |
181 | # iterate over all pixels in the range image
182 | for x in range(proj_W):
183 | for y in range(proj_H - 1):
184 | p = current_vertex[y, x][:3]
185 | depth = current_range[y, x]
186 |
187 | if depth > 0:
188 | wrap_x = wrap(x + 1, proj_W)
189 | u = current_vertex[y, wrap_x][:3]
190 | u_depth = current_range[y, wrap_x]
191 | if u_depth <= 0:
192 | continue
193 |
194 | v = current_vertex[y + 1, x][:3]
195 | v_depth = current_range[y + 1, x]
196 | if v_depth <= 0:
197 | continue
198 |
199 | u_norm = (u - p) / np.linalg.norm(u - p)
200 | v_norm = (v - p) / np.linalg.norm(v - p)
201 |
202 | w = np.cross(v_norm, u_norm)
203 | norm = np.linalg.norm(w)
204 | if norm > 0:
205 | normal = w / norm
206 | normal_data[y, x] = normal
207 |
208 | return normal_data
209 |
210 |
211 | def wrap(x, dim):
212 | """Wrap the boarder of the range image."""
213 | value = x
214 | if value >= dim:
215 | value = value - dim
216 | if value < 0:
217 | value = value + dim
218 | return value
219 |
220 |
221 | def euler_angles_from_rotation_matrix(R):
222 | """From the paper by Gregory G. Slabaugh,
223 | Computing Euler angles from a rotation matrix
224 | psi, theta, phi = roll pitch yaw (x, y, z)
225 |
226 | Args:
227 | R: rotation matrix, a 3x3 numpy array
228 |
229 | Returns:
230 | a tuple with the 3 values psi, theta, phi in radians
231 | """
232 |
233 | def isclose(x, y, rtol=1.0e-5, atol=1.0e-8):
234 | return abs(x - y) <= atol + rtol * abs(y)
235 |
236 | phi = 0.0
237 | if isclose(R[2, 0], -1.0):
238 | theta = math.pi / 2.0
239 | psi = math.atan2(R[0, 1], R[0, 2])
240 | elif isclose(R[2, 0], 1.0):
241 | theta = -math.pi / 2.0
242 | psi = math.atan2(-R[0, 1], -R[0, 2])
243 | else:
244 | theta = -math.asin(R[2, 0])
245 | cos_theta = math.cos(theta)
246 | psi = math.atan2(R[2, 1] / cos_theta, R[2, 2] / cos_theta)
247 | phi = math.atan2(R[1, 0] / cos_theta, R[0, 0] / cos_theta)
248 | return psi, theta, phi
249 |
250 |
251 | def load_vertex(scan_path):
252 | """Load 3D points of a scan. The fileformat is the .bin format used in
253 | the KITTI dataset.
254 |
255 | Args:
256 | scan_path: the (full) filename of the scan file
257 |
258 | Returns:
259 | A nx4 numpy array of homogeneous points (x, y, z, 1).
260 | """
261 | current_vertex = np.fromfile(scan_path, dtype=np.float32)
262 | current_vertex = current_vertex.reshape((-1, 4))
263 | current_points = current_vertex[:, 0:3]
264 | current_vertex = np.ones((current_points.shape[0], current_points.shape[1] + 1))
265 | current_vertex[:, :-1] = current_points
266 | return current_vertex
267 |
268 |
269 | def load_files(folder):
270 | """Load all files in a folder and sort."""
271 | file_paths = [
272 | os.path.join(dp, f)
273 | for dp, dn, fn in os.walk(os.path.expanduser(folder))
274 | for f in fn
275 | ]
276 | file_paths.sort()
277 | return file_paths
278 |
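# --- Illustrative sketch (not part of the original file) ---
# Minimal example of projecting a tiny synthetic point cloud with
# range_projection() above. Input is an (N, 4) array of (x, y, z, intensity);
# the points and projection parameters below are illustrative only.
if __name__ == "__main__":
    example_points = np.array(
        [
            [10.0, 0.0, 0.0, 0.5],    # straight ahead, 10 m
            [0.0, 10.0, 0.0, 0.3],    # to the left, 10 m
            [-10.0, 0.0, -1.0, 0.1],  # behind and slightly below
        ],
        dtype=np.float32,
    )
    proj_range, proj_vertex, proj_intensity, proj_idx = range_projection(
        example_points, fov_up=3.0, fov_down=-25.0, proj_H=64, proj_W=900, max_range=50
    )
    print(proj_range.shape)       # (64, 900), -1 where no point fell into a pixel
    print((proj_idx >= 0).sum())  # number of occupied pixels (3 here)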
--------------------------------------------------------------------------------
/pcpnet/utils/visualization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Visualization of Point Cloud Predictions with Open3D
6 | from functools import partial
7 | import os
8 | import glob
9 | import time
10 | import open3d as o3d
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 |
14 | WIDTH = 1000
15 | HEIGHT = 1000
16 | POINTSIZE = 1.5
17 | SLEEPTIME = 0.3
18 |
19 |
20 | def get_car_model(filename):
21 | """Car model for visualization
22 |
23 | Args:
24 | filename (str): filename of mesh
25 |
26 | Returns:
27 | mesh: open3D mesh
28 | """
29 | mesh = o3d.io.read_triangle_mesh(filename)
30 | mesh.scale(0.001, (-0.5, 0.4, -1.2))
31 | R = mesh.get_rotation_matrix_from_quaternion(np.array([0, 0, 1, 1]).T)
32 | mesh.rotate(R)
33 | mesh.compute_triangle_normals()
34 | mesh.compute_vertex_normals()
35 | return mesh
36 |
37 |
38 | def get_filename(path, idx):
39 | filenames = sorted(glob.glob(path + "*.ply"))
40 | return int(filenames[idx].split(".")[0].split("/")[-1])
41 |
42 |
43 | def last_file(path):
44 | return get_filename(path, -1)
45 |
46 |
47 | def first_file(path):
48 | return get_filename(path, 0)
49 |
50 |
51 | class Visualization:
52 | """Visualization of point cloud predictions with open3D"""
53 |
54 | def __init__(
55 | self,
56 | path,
57 | sequence,
58 | start,
59 | end,
60 | capture=False,
61 | path_to_car_model=None,
62 | sleep_time=5e-3,
63 | ):
64 | """Init
65 |
66 | Args:
67 | path (str): path to data should be
68 | .
69 | ├── sequence
70 | │ ├── gt
71 | | | ├──frame.ply
72 | │ ├─── pred
73 | | | ├── frame
74 | | │ | ├─── (frame+1).ply
75 |
76 | sequence (int): Sequence to visualize
77 | start (int): Start at specific frame
78 | end (int): End at specific frame
79 | capture (bool, optional): Save to file at each frame. Defaults to False.
80 | path_to_car_model (str, optional): Path to car model. Defaults to None.
81 | sleep_time (float, optional): Sleep time between frames. Defaults to 5e-3.
82 | """
83 | self.vis = o3d.visualization.VisualizerWithKeyCallback()
84 | self.vis.create_window(width=WIDTH, height=HEIGHT)
85 | self.render_options = self.vis.get_render_option()
86 | self.render_options.point_size = POINTSIZE
87 | self.capture = capture
88 |
89 | # Load car model
90 | if path_to_car_model:
91 | self.car_mesh = get_car_model(path_to_car_model)
92 | else:
93 | self.car_mesh = None
94 |
95 | # Path and sequence to visualize
96 | self.path = path
97 | self.sequence = sequence
98 |
99 | # Frames to visualize
100 | self.start = start
101 | self.end = end
102 |
103 | # Set flag
104 | self.gt_flag = True
105 | self.pred_flag = True
106 | self.car_flag = True
107 |
108 | # Init
109 | self.current_frame = self.start
110 | self.current_step = 1
111 | self.n_pred_steps = 5
112 |
113 | # Save last view
114 | self.ctr = self.vis.get_view_control()
115 | self.camera = self.ctr.convert_to_pinhole_camera_parameters()
116 | self.viewpoint_path = os.path.join(self.path, "viewpoint.json")
117 |
118 | self.print_help()
119 | self.update(self.vis)
120 |
121 | # Continuous time plot
122 | self.stop = False
123 | self.sleep_time = sleep_time
124 |
125 | # Initialize the default callbacks
126 | self._register_key_callbacks()
127 |
128 | self.last_time_key_pressed = time.time()
129 |
130 | def prev_frame(self, vis):
131 | if time.time() - self.last_time_key_pressed > SLEEPTIME:
132 | self.last_time_key_pressed = time.time()
133 | self.current_frame = max(self.start, self.current_frame - 1)
134 | self.update(vis)
135 | return False
136 |
137 | def next_frame(self, vis):
138 | if time.time() - self.last_time_key_pressed > SLEEPTIME:
139 | self.last_time_key_pressed = time.time()
140 | self.current_frame = min(self.end, self.current_frame + 1)
141 | self.update(vis)
142 | return False
143 |
144 | def prev_prediction_step(self, vis):
145 | if time.time() - self.last_time_key_pressed > SLEEPTIME:
146 | self.last_time_key_pressed = time.time()
147 | self.current_step = max(1, self.current_step - 1)
148 | self.update(vis)
149 | return False
150 |
151 | def next_prediction_step(self, vis):
152 | if time.time() - self.last_time_key_pressed > SLEEPTIME:
153 | self.last_time_key_pressed = time.time()
154 | self.current_step = min(self.n_pred_steps, self.current_step + 1)
155 | self.update(vis)
156 | return False
157 |
158 | def show_or_hide_gt(self, vis):
159 | if time.time() - self.last_time_key_pressed > SLEEPTIME:
160 | self.last_time_key_pressed = time.time()
161 | if self.gt_flag:
162 | self.gt_flag = False
163 | else:
164 | self.gt_flag = True
165 | self.update(vis)
166 | return False
167 |
168 | def show_or_hide_pred(self, vis):
169 | if time.time() - self.last_time_key_pressed > SLEEPTIME:
170 | self.last_time_key_pressed = time.time()
171 | if self.pred_flag:
172 | self.pred_flag = False
173 | else:
174 | self.pred_flag = True
175 | self.update(vis)
176 | return False
177 |
178 | def show_or_hide_car(self, vis):
179 | if time.time() - self.last_time_key_pressed > SLEEPTIME:
180 | self.last_time_key_pressed = time.time()
181 | if self.car_flag:
182 | self.car_flag = False
183 | else:
184 | self.car_flag = True
185 | self.update(vis)
186 | return False
187 |
188 | def play_sequence(self, vis):
189 | self.stop = False
190 | while not self.stop:
191 | self.next_frame(vis)
192 | time.sleep(self.sleep_time)
193 |
194 | def stop_sequence(self, vis):
195 | self.stop = True
196 |
197 | def toggle_capture_mode(self, vis):
198 | if self.capture:
199 | self.capture = False
200 | else:
201 | self.capture = True
202 | print(
203 | "Current frame: {:d}/{:d}, Prediction step {:d}/{:d}, capture_mode: {:b}".format(
204 | self.current_frame,
205 | self.end,
206 | self.current_step,
207 | self.n_pred_steps,
208 | self.capture,
209 | ),
210 | end="\r",
211 | )
212 |
213 | def update(self, vis):
214 | """Get point clouds and visualize"""
215 | print(
216 | "Current frame: {:d}/{:d}, Prediction step {:d}/{:d}, capture_mode: {:b}".format(
217 | self.current_frame,
218 | self.end,
219 | self.current_step,
220 | self.n_pred_steps,
221 | self.capture,
222 | ),
223 | end="\r",
224 | )
225 | gt_pcd, pred_pcd = self.get_gt_and_predictions(
226 | self.path, self.sequence, self.current_frame, self.current_step
227 | )
228 |
229 | geometry_list = []
230 |
231 | if self.gt_flag:
232 | geometry_list.append(gt_pcd)
233 | if self.pred_flag:
234 | geometry_list.append(pred_pcd)
235 | if self.car_mesh:
236 | if self.car_flag:
237 | geometry_list.append(self.car_mesh)
238 | self.vis_update_geometries(vis, geometry_list)
239 |
240 | if self.capture:
241 | self.capture_frame()
242 |
243 | def get_gt_and_predictions(self, path, sequence, current_frame, step):
244 | """Load GT and predictions from path
245 |
246 | Args:
247 | path (str): Path to files
248 | sequence (int): Sequence to visualize
249 | current_frame (int): Last received frame for prediction
250 | step (int): Prediction step to visualize
251 |
252 | Returns:
253 | o3d.point_cloud: GT and predicted point clouds
254 | """
255 | pred_path = os.path.join(
256 | path,
257 | sequence,
258 | "pred",
259 | str(current_frame).zfill(6),
260 | str(current_frame + step).zfill(6) + ".ply",
261 | )
262 | pred_pcd = o3d.io.read_point_cloud(pred_path)
263 | pred_pcd.paint_uniform_color([0, 0, 1])
264 |
265 | gt_path = os.path.join(
266 | path, sequence, "gt", str(current_frame + step).zfill(6) + ".ply"
267 | )
268 | gt_pcd = o3d.io.read_point_cloud(gt_path)
269 | gt_pcd.paint_uniform_color([1, 0, 0])
270 | return gt_pcd, pred_pcd
271 |
272 | def vis_update_geometries(self, vis, geometries):
273 | """Save camera pose and update point clouds"""
274 | # Save current camera pose
275 | self.camera = self.ctr.convert_to_pinhole_camera_parameters()
276 |
277 | vis.clear_geometries()
278 | for geometry in geometries:
279 | vis.add_geometry(geometry)
280 |
281 | # Set to last view
282 | self.ctr.convert_from_pinhole_camera_parameters(self.camera)
283 |
284 | self.vis.poll_events()
285 | self.vis.update_renderer()
286 |
287 | def set_render_options(self, **kwargs):
288 | for key, value in kwargs.items():
289 | setattr(self.render_options, key, value)
290 |
291 | def register_key_callback(self, key, callback):
292 | self.vis.register_key_callback(ord(str(key)), partial(callback))
293 |
294 | def set_white_background(self, vis):
295 | """Change backround between white and white"""
296 | self.render_options.background_color = [1.0, 1.0, 1.0]
297 |
298 | def set_black_background(self, vis):
299 | """Change backround between white and black"""
300 | self.render_options.background_color = [0.0, 0.0, 0.0]
301 |
302 | def save_viewpoint(self, vis):
303 | """Saves viewpoint"""
304 | self.camera = self.ctr.convert_to_pinhole_camera_parameters()
305 | o3d.io.write_pinhole_camera_parameters(self.viewpoint_path, self.camera)
306 |
307 | def load_viewpoint(self, vis):
308 | """Loads viewpoint"""
309 | self.camera = o3d.io.read_pinhole_camera_parameters(self.viewpoint_path)
310 | self.ctr.convert_from_pinhole_camera_parameters(self.camera)
311 |
312 | def _register_key_callbacks(self):
313 | self.register_key_callback("L", self.next_frame)
314 | self.register_key_callback("H", self.prev_frame)
315 | self.register_key_callback("K", self.next_prediction_step)
316 | self.register_key_callback("J", self.prev_prediction_step)
317 | self.register_key_callback("G", self.show_or_hide_gt)
318 | self.register_key_callback("F", self.show_or_hide_pred)
319 | self.register_key_callback("D", self.show_or_hide_car)
320 | self.register_key_callback("S", self.play_sequence)
321 | self.register_key_callback("X", self.stop_sequence)
322 | self.register_key_callback("C", self.toggle_capture_mode)
323 | self.register_key_callback("W", self.set_white_background)
324 | self.register_key_callback("B", self.set_black_background)
325 | self.register_key_callback("V", self.save_viewpoint)
326 | self.register_key_callback("Q", self.load_viewpoint)
327 |
328 | def print_help(self):
329 | print("L: next frame")
330 | print("H: previous frame")
331 | print("K: next prediction step")
332 | print("J: previous prediction step")
333 | print("G: show or hide gt point cloud")
334 | print("F: show or hide pred point cloud")
335 | print("D: show or hide car model")
336 | print("S: start")
337 | print("X: stop")
338 | print("C: Toggle capture mode")
339 | print("W: white background")
340 | print("B: black background")
341 | print("V: save viewpoint")
342 | print("Q: set to saved viewpoint")
343 | print("ESC: quit")
344 |
345 | def capture_frame(self):
346 | """Save view from current viewpoint"""
347 | image = self.vis.capture_screen_float_buffer(False)
348 | path = os.path.join(self.path, self.sequence, "images")
349 | if not os.path.exists(path):
350 | os.makedirs(path)
351 |
352 | filename = os.path.join(
353 | path,
354 | "step_{:1d}_prediction_from_frame_{:05d}.png".format(
355 | self.current_step, self.current_frame
356 | ),
357 | )
358 | print("Capture image ", filename)
359 | plt.imsave(filename, np.asarray(image), dpi=1)
360 |
361 | def run(self):
362 | self.vis.run()
363 | self.vis.destroy_window()
364 |
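# --- Illustrative sketch (not part of the original file) ---
# Minimal example of launching the viewer above on exported predictions.
# The path, sequence, and frame range are placeholders; the directory is
# expected to follow the layout documented in Visualization.__init__.
if __name__ == "__main__":
    viewer = Visualization(
        path="/path/to/test/results",  # hypothetical export directory
        sequence="08",
        start=0,
        end=100,
        capture=False,
    )
    viewer.run()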
--------------------------------------------------------------------------------
/pyTorchChamferDistance/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [year] [fullname]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyTorchChamferDistance/README.md:
--------------------------------------------------------------------------------
1 | # Chamfer Distance for pyTorch
2 |
3 | This is an implementation of the Chamfer Distance as a module for pyTorch. It is written as a custom C++/CUDA extension.
4 |
5 | As it is using pyTorch's [JIT compilation](https://pytorch.org/tutorials/advanced/cpp_extension.html), there are no additional prerequisite steps that have to be taken. Simply import the module as shown below; CUDA and C++ code will be compiled on the first run.
6 |
7 | ### Usage
8 | ```python
9 | from chamfer_distance import ChamferDistance
10 | chamfer_dist = ChamferDistance()
11 |
12 | #...
13 | # points and points_reconstructed are n_points x 3 matrices
14 |
15 | dist1, dist2 = chamfer_dist(points, points_reconstructed)
16 | loss = (torch.mean(dist1)) + (torch.mean(dist2))
17 |
18 |
19 | #...
20 | ```
21 |
22 | ### Integration
23 | This code has been integrated into the [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) library for 3D Deep Learning by NVIDIAGameWorks. You should probably take a look at it if you are working on anything 3D :)
24 |
--------------------------------------------------------------------------------
/pyTorchChamferDistance/chamfer_distance/__init__.py:
--------------------------------------------------------------------------------
1 | from .chamfer_distance import ChamferDistance
2 |
--------------------------------------------------------------------------------
/pyTorchChamferDistance/chamfer_distance/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pyTorchChamferDistance/chamfer_distance/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/pyTorchChamferDistance/chamfer_distance/__pycache__/chamfer_distance.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/pyTorchChamferDistance/chamfer_distance/__pycache__/chamfer_distance.cpython-38.pyc
--------------------------------------------------------------------------------
/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | // CUDA forward declarations
4 | int ChamferDistanceKernelLauncher(
5 | const int b, const int n,
6 | const float* xyz,
7 | const int m,
8 | const float* xyz2,
9 | float* result,
10 | int* result_i,
11 | float* result2,
12 | int* result2_i);
13 |
14 | int ChamferDistanceGradKernelLauncher(
15 | const int b, const int n,
16 | const float* xyz1,
17 | const int m,
18 | const float* xyz2,
19 | const float* grad_dist1,
20 | const int* idx1,
21 | const float* grad_dist2,
22 | const int* idx2,
23 | float* grad_xyz1,
24 | float* grad_xyz2);
25 |
26 |
27 | void chamfer_distance_forward_cuda(
28 | const at::Tensor xyz1,
29 | const at::Tensor xyz2,
30 | const at::Tensor dist1,
31 | const at::Tensor dist2,
32 | const at::Tensor idx1,
33 | const at::Tensor idx2)
34 | {
35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(),
36 | xyz2.size(1), xyz2.data(),
37 | dist1.data(), idx1.data(),
38 | dist2.data(), idx2.data());
39 | }
40 |
41 | void chamfer_distance_backward_cuda(
42 | const at::Tensor xyz1,
43 | const at::Tensor xyz2,
44 | at::Tensor gradxyz1,
45 | at::Tensor gradxyz2,
46 | at::Tensor graddist1,
47 | at::Tensor graddist2,
48 | at::Tensor idx1,
49 | at::Tensor idx2)
50 | {
51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(),
52 | xyz2.size(1), xyz2.data(),
53 | graddist1.data(), idx1.data(),
54 | graddist2.data(), idx2.data(),
55 | gradxyz1.data(), gradxyz2.data());
56 | }
57 |
58 |
59 | void nnsearch(
60 | const int b, const int n, const int m,
61 | const float* xyz1,
62 | const float* xyz2,
63 | float* dist,
64 | int* idx)
65 | {
66 | for (int i = 0; i < b; i++) {
67 | for (int j = 0; j < n; j++) {
68 | const float x1 = xyz1[(i*n+j)*3+0];
69 | const float y1 = xyz1[(i*n+j)*3+1];
70 | const float z1 = xyz1[(i*n+j)*3+2];
71 | double best = 0;
72 | int besti = 0;
73 | for (int k = 0; k < m; k++) {
74 | const float x2 = xyz2[(i*m+k)*3+0] - x1;
75 | const float y2 = xyz2[(i*m+k)*3+1] - y1;
76 | const float z2 = xyz2[(i*m+k)*3+2] - z1;
77 | const double d=x2*x2+y2*y2+z2*z2;
78 | if (k==0 || d < best){
79 | best = d;
80 | besti = k;
81 | }
82 | }
83 | dist[i*n+j] = best;
84 | idx[i*n+j] = besti;
85 | }
86 | }
87 | }
88 |
89 |
90 | void chamfer_distance_forward(
91 | const at::Tensor xyz1,
92 | const at::Tensor xyz2,
93 | const at::Tensor dist1,
94 | const at::Tensor dist2,
95 | const at::Tensor idx1,
96 | const at::Tensor idx2)
97 | {
98 | const int batchsize = xyz1.size(0);
99 | const int n = xyz1.size(1);
100 | const int m = xyz2.size(1);
101 |
102 | const float* xyz1_data = xyz1.data();
103 | const float* xyz2_data = xyz2.data();
104 | float* dist1_data = dist1.data();
105 | float* dist2_data = dist2.data();
106 | int* idx1_data = idx1.data();
107 | int* idx2_data = idx2.data();
108 |
109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
111 | }
112 |
113 |
114 | void chamfer_distance_backward(
115 | const at::Tensor xyz1,
116 | const at::Tensor xyz2,
117 | at::Tensor gradxyz1,
118 | at::Tensor gradxyz2,
119 | at::Tensor graddist1,
120 | at::Tensor graddist2,
121 | at::Tensor idx1,
122 | at::Tensor idx2)
123 | {
124 | const int b = xyz1.size(0);
125 | const int n = xyz1.size(1);
126 | const int m = xyz2.size(1);
127 |
128 | const float* xyz1_data = xyz1.data();
129 | const float* xyz2_data = xyz2.data();
130 | float* gradxyz1_data = gradxyz1.data();
131 | float* gradxyz2_data = gradxyz2.data();
132 | float* graddist1_data = graddist1.data();
133 | float* graddist2_data = graddist2.data();
134 | const int* idx1_data = idx1.data();
135 | const int* idx2_data = idx2.data();
136 |
137 | for (int i = 0; i < b*n*3; i++)
138 | gradxyz1_data[i] = 0;
139 | for (int i = 0; i < b*m*3; i++)
140 | gradxyz2_data[i] = 0;
141 | for (int i = 0;i < b; i++) {
142 | for (int j = 0; j < n; j++) {
143 | const float x1 = xyz1_data[(i*n+j)*3+0];
144 | const float y1 = xyz1_data[(i*n+j)*3+1];
145 | const float z1 = xyz1_data[(i*n+j)*3+2];
146 | const int j2 = idx1_data[i*n+j];
147 |
148 | const float x2 = xyz2_data[(i*m+j2)*3+0];
149 | const float y2 = xyz2_data[(i*m+j2)*3+1];
150 | const float z2 = xyz2_data[(i*m+j2)*3+2];
151 | const float g = graddist1_data[i*n+j]*2;
152 |
153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
159 | }
160 | for (int j = 0; j < m; j++) {
161 | const float x1 = xyz2_data[(i*m+j)*3+0];
162 | const float y1 = xyz2_data[(i*m+j)*3+1];
163 | const float z1 = xyz2_data[(i*m+j)*3+2];
164 | const int j2 = idx2_data[i*m+j];
165 | const float x2 = xyz1_data[(i*n+j2)*3+0];
166 | const float y2 = xyz1_data[(i*n+j2)*3+1];
167 | const float z2 = xyz1_data[(i*n+j2)*3+2];
168 | const float g = graddist2_data[i*m+j]*2;
169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
175 | }
176 | }
177 | }
178 |
179 |
180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
185 | }
186 |
--------------------------------------------------------------------------------
/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <cuda.h>
4 | #include <cuda_runtime.h>
5 |
6 | __global__
7 | void ChamferDistanceKernel(
8 | int b,
9 | int n,
10 | const float* xyz,
11 | int m,
12 | const float* xyz2,
13 | float* result,
14 | int* result_i)
15 | {
16 | const int batch=512;
17 | __shared__ float buf[batch*3];
18 | for (int i=blockIdx.x;ibest){
130 | result[(i*n+j)]=best;
131 | result_i[(i*n+j)]=best_i;
132 | }
133 | }
134 | __syncthreads();
135 | }
136 | }
137 | }
138 |
139 | void ChamferDistanceKernelLauncher(
140 | const int b, const int n,
141 | const float* xyz,
142 | const int m,
143 | const float* xyz2,
144 | float* result,
145 | int* result_i,
146 | float* result2,
147 | int* result2_i)
148 | {
149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i);
150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i);
151 |
152 | cudaError_t err = cudaGetLastError();
153 | if (err != cudaSuccess)
154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
155 | }
156 |
157 |
158 | __global__
159 | void ChamferDistanceGradKernel(
160 | int b, int n,
161 | const float* xyz1,
162 | int m,
163 | const float* xyz2,
164 | const float* grad_dist1,
165 | const int* idx1,
166 | float* grad_xyz1,
167 | float* grad_xyz2)
168 | {
169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);
205 |
206 | cudaError_t err = cudaGetLastError();
207 | if (err != cudaSuccess)
208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
209 | }
210 |
--------------------------------------------------------------------------------
/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import os
4 |
5 | from torch.utils.cpp_extension import load
6 | cd = load(name="cd",
7 | sources=[os.path.abspath("pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp"),
8 | os.path.abspath("pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu")])
9 |
10 | class ChamferDistanceFunction(torch.autograd.Function):
11 | @staticmethod
12 | def forward(ctx, xyz1, xyz2):
13 | batchsize, n, _ = xyz1.size()
14 | _, m, _ = xyz2.size()
15 | xyz1 = xyz1.contiguous()
16 | xyz2 = xyz2.contiguous()
17 | dist1 = torch.zeros(batchsize, n)
18 | dist2 = torch.zeros(batchsize, m)
19 |
20 | idx1 = torch.zeros(batchsize, n, dtype=torch.int)
21 | idx2 = torch.zeros(batchsize, m, dtype=torch.int)
22 |
23 | if not xyz1.is_cuda:
24 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
25 | else:
26 | dist1 = dist1.cuda()
27 | dist2 = dist2.cuda()
28 | idx1 = idx1.cuda()
29 | idx2 = idx2.cuda()
30 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)
31 |
32 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
33 |
34 | return dist1, dist2
35 |
36 | @staticmethod
37 | def backward(ctx, graddist1, graddist2):
38 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
39 |
40 | graddist1 = graddist1.contiguous()
41 | graddist2 = graddist2.contiguous()
42 |
43 | gradxyz1 = torch.zeros(xyz1.size())
44 | gradxyz2 = torch.zeros(xyz2.size())
45 |
46 | if not graddist1.is_cuda:
47 | cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
48 | else:
49 | gradxyz1 = gradxyz1.cuda()
50 | gradxyz2 = gradxyz2.cuda()
51 | cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
52 |
53 | return gradxyz1, gradxyz2
54 |
55 |
56 | class ChamferDistance(torch.nn.Module):
57 | def forward(self, xyz1, xyz2):
58 | return ChamferDistanceFunction.apply(xyz1, xyz2)
59 |
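# --- Illustrative sketch (not part of the original file) ---
# Minimal example of using the ChamferDistance module above, following the
# usage shown in the package README. The point counts are illustrative; the
# module accepts (batch, n_points, 3) tensors on CPU or GPU and returns
# squared nearest-neighbor distances in both directions.
if __name__ == "__main__":
    points_a = torch.rand(1, 128, 3)
    points_b = torch.rand(1, 256, 3)

    chamfer_dist = ChamferDistance()
    dist1, dist2 = chamfer_dist(points_a, points_b)
    loss = torch.mean(dist1) + torch.mean(dist2)  # symmetric Chamfer distance
    print(loss.item())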
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "point-cloud-prediction"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["Benedikt Mersch "]
6 | packages = [
7 | {include = "./pcf"},
8 | {include = "./pyTorchChamferDistance"},
9 | ]
10 |
11 | [tool.poetry.dependencies]
12 | python = "^3.8"
13 | PyYAML = "^6.0"
14 | matplotlib = "^3.4.3"
15 | open3d = "^0.13.0"
16 | pytorch-lightning = "^1.5.0"
17 | ninja = "^1.10.2"
18 |
19 |
20 | [tool.poetry.dev-dependencies]
21 |
22 | [build-system]
23 | requires = ["poetry-core>=1.0.0"]
24 | build-backend = "poetry.core.masonry.api"
25 |
--------------------------------------------------------------------------------
/semantic_net/rangenet/README.md:
--------------------------------------------------------------------------------
1 | # LiDAR-Bonnetal Training
2 |
3 | This part of the framework deals with the training of segmentation networks for point cloud data using range images.
4 |
5 | ## Tasks
6 |
7 | - [Semantic Segmentation](tasks/semantic).
8 | - [Panoptic Segmentation](tasks/panoptic) \[Soon\].
9 |
10 | ## Dependencies
11 |
12 | First you need to install the nvidia driver and CUDA, so have fun!
13 |
14 | - CUDA Installation guide: [link](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
15 |
16 | - System dependencies:
17 |
18 | ```sh
19 | $ sudo apt-get update
20 | $ sudo apt-get install -yqq build-essential ninja-build \
21 | python3-dev python3-pip apt-utils curl git cmake unzip autoconf autogen \
22 | libtool mlocate zlib1g-dev python3-numpy python3-wheel wget \
23 | software-properties-common openjdk-8-jdk libpng-dev \
24 | libxft-dev ffmpeg python3-pyqt5.qtopengl
25 | $ sudo updatedb
26 | ```
27 |
28 | - Python dependencies
29 |
30 | ```sh
31 | $ sudo pip3 install -r requirements.txt
32 | ```
--------------------------------------------------------------------------------
/semantic_net/rangenet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/__init__.py
--------------------------------------------------------------------------------
/semantic_net/rangenet/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/backbones/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/backbones/__init__.py
--------------------------------------------------------------------------------
/semantic_net/rangenet/backbones/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/backbones/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/backbones/__pycache__/squeezeseg.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/backbones/__pycache__/squeezeseg.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/backbones/__pycache__/squeezesegV2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/backbones/__pycache__/squeezesegV2.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/backbones/darknet.py:
--------------------------------------------------------------------------------
1 | # This file was modified from https://github.com/BobLiu20/YOLOv3_PyTorch
2 | # It needed to be modified in order to accommodate different strides in the
3 |
4 | import torch.nn as nn
5 | from collections import OrderedDict
6 | import torch.nn.functional as F
7 |
8 |
9 | class BasicBlock(nn.Module):
10 | def __init__(self, inplanes, planes, bn_d=0.1):
11 | super(BasicBlock, self).__init__()
12 | self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1,
13 | stride=1, padding=0, bias=False)
14 | self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d)
15 | self.relu1 = nn.LeakyReLU(0.1)
16 | self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3,
17 | stride=1, padding=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d)
19 | self.relu2 = nn.LeakyReLU(0.1)
20 |
21 | def forward(self, x):
22 | residual = x
23 |
24 | out = self.conv1(x)
25 | out = self.bn1(out)
26 | out = self.relu1(out)
27 |
28 | out = self.conv2(out)
29 | out = self.bn2(out)
30 | out = self.relu2(out)
31 |
32 | out += residual
33 | return out
34 |
35 |
36 | # ******************************************************************************
37 |
38 | # number of layers per model
39 | model_blocks = {
40 | 21: [1, 1, 2, 2, 1],
41 | 53: [1, 2, 8, 8, 4],
42 | }
43 |
44 |
45 | class Backbone(nn.Module):
46 | """
47 | Class for DarknetSeg. Subclasses PyTorch's own "nn" module
48 | """
49 |
50 | def __init__(self, params):
51 | super(Backbone, self).__init__()
52 | self.use_range = params["input_depth"]["range"]
53 | self.use_xyz = params["input_depth"]["xyz"]
54 | self.use_remission = params["input_depth"]["remission"]
55 | self.drop_prob = params["dropout"]
56 | self.bn_d = params["bn_d"]
57 | self.OS = params["OS"]
58 | self.layers = params["extra"]["layers"]
59 | print("Using DarknetNet" + str(self.layers) + " Backbone")
60 |
61 | # input depth calc
62 | self.input_depth = 0
63 | self.input_idxs = []
64 | if self.use_range:
65 | self.input_depth += 1
66 | self.input_idxs.append(0)
67 | if self.use_xyz:
68 | self.input_depth += 3
69 | self.input_idxs.extend([1, 2, 3])
70 | if self.use_remission:
71 | self.input_depth += 1
72 | self.input_idxs.append(4)
73 | print("Depth of backbone input = ", self.input_depth)
74 |
75 | # stride play
76 | self.strides = [2, 2, 2, 2, 2]
77 | # check current stride
78 | current_os = 1
79 | for s in self.strides:
80 | current_os *= s
81 | print("Original OS: ", current_os)
82 |
83 | # make the new stride
84 | if self.OS > current_os:
85 | print("Can't do OS, ", self.OS,
86 | " because it is bigger than original ", current_os)
87 | else:
88 | # redo strides according to needed stride
89 | for i, stride in enumerate(reversed(self.strides), 0):
90 | if int(current_os) != self.OS:
91 | if stride == 2:
92 | current_os /= 2
93 | self.strides[-1 - i] = 1
94 | if int(current_os) == self.OS:
95 | break
96 | print("New OS: ", int(current_os))
97 | print("Strides: ", self.strides)
98 |
99 | # check that darknet exists
100 | assert self.layers in model_blocks.keys()
101 |
102 | # generate layers depending on darknet type
103 | self.blocks = model_blocks[self.layers]
104 |
105 | # input layer
106 | self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=3,
107 | stride=1, padding=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(32, momentum=self.bn_d)
109 | self.relu1 = nn.LeakyReLU(0.1)
110 |
111 | # encoder
112 | self.enc1 = self._make_enc_layer(BasicBlock, [32, 64], self.blocks[0],
113 | stride=self.strides[0], bn_d=self.bn_d)
114 | self.enc2 = self._make_enc_layer(BasicBlock, [64, 128], self.blocks[1],
115 | stride=self.strides[1], bn_d=self.bn_d)
116 | self.enc3 = self._make_enc_layer(BasicBlock, [128, 256], self.blocks[2],
117 | stride=self.strides[2], bn_d=self.bn_d)
118 | self.enc4 = self._make_enc_layer(BasicBlock, [256, 512], self.blocks[3],
119 | stride=self.strides[3], bn_d=self.bn_d)
120 | self.enc5 = self._make_enc_layer(BasicBlock, [512, 1024], self.blocks[4],
121 | stride=self.strides[4], bn_d=self.bn_d)
122 |
123 | # for a bit of fun
124 | self.dropout = nn.Dropout2d(self.drop_prob)
125 |
126 | # last channels
127 | self.last_channels = 1024
128 |
129 | # make layer useful function
130 | def _make_enc_layer(self, block, planes, blocks, stride, bn_d=0.1):
131 | layers = []
132 |
133 | # downsample
134 | layers.append(("conv", nn.Conv2d(planes[0], planes[1],
135 | kernel_size=3,
136 | stride=[1, stride], dilation=1,
137 | padding=1, bias=False)))
138 | layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d)))
139 | layers.append(("relu", nn.LeakyReLU(0.1)))
140 |
141 | # blocks
142 | inplanes = planes[1]
143 | for i in range(0, blocks):
144 | layers.append(("residual_{}".format(i),
145 | block(inplanes, planes, bn_d)))
146 |
147 | return nn.Sequential(OrderedDict(layers))
148 |
149 | def run_layer(self, x, layer, skips, os):
150 | y = layer(x)
151 | if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]:
152 | skips[os] = x.detach()
153 | os *= 2
154 | x = y
155 | return x, skips, os
156 |
157 | def forward(self, x):
158 | # filter input
159 | x = x[:, self.input_idxs]
160 |
161 | # run cnn
162 | # store for skip connections
163 | skips = {}
164 | os = 1
165 |
166 | # first layer
167 | x, skips, os = self.run_layer(x, self.conv1, skips, os)
168 | x, skips, os = self.run_layer(x, self.bn1, skips, os)
169 | x, skips, os = self.run_layer(x, self.relu1, skips, os)
170 |
171 | # all encoder blocks with intermediate dropouts
172 | x, skips, os = self.run_layer(x, self.enc1, skips, os)
173 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
174 | x, skips, os = self.run_layer(x, self.enc2, skips, os)
175 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
176 | x, skips, os = self.run_layer(x, self.enc3, skips, os)
177 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
178 | x, skips, os = self.run_layer(x, self.enc4, skips, os)
179 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
180 | x, skips, os = self.run_layer(x, self.enc5, skips, os)
181 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
182 |
183 | return x, skips
184 |
185 | def get_last_depth(self):
186 | return self.last_channels
187 |
188 | def get_input_depth(self):
189 | return self.input_depth
190 |
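
A quick orientation for the class above: `Backbone` reads everything it needs from one nested `params` dictionary and returns both the bottleneck features and a dictionary of skip tensors keyed by output stride. The following is a minimal usage sketch (an addition, not part of the repository); the dictionary keys simply mirror what `__init__` reads, and the import path assumes the package layout shown in this dump.

```python
# Minimal sketch: build the darknet-21 backbone and push a dummy range image through it.
import torch
from semantic_net.rangenet.backbones.darknet import Backbone  # path as laid out in this repo

params = {
    "input_depth": {"range": True, "xyz": True, "remission": True},  # 1 + 3 + 1 = 5 channels
    "dropout": 0.01,
    "bn_d": 0.01,
    "OS": 32,                     # requested output stride (applied to the width only)
    "extra": {"layers": 21},      # 21 or 53, see model_blocks above
}

backbone = Backbone(params)
x = torch.randn(1, 5, 64, 1024)   # [B, range+xyz+remission, H, W] range image
feats, skips = backbone(x)
print(feats.shape)                # torch.Size([1, 1024, 64, 32]) -> width divided by 32
print(sorted(skips.keys()))       # [1, 2, 4, 8, 16]: features saved before each downsampling
```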
--------------------------------------------------------------------------------
/semantic_net/rangenet/backbones/squeezeseg.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/BichenWuUCB/SqueezeSeg
2 | from __future__ import print_function
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class Fire(nn.Module):
9 |
10 | def __init__(self, inplanes, squeeze_planes,
11 | expand1x1_planes, expand3x3_planes):
12 | super(Fire, self).__init__()
13 | self.inplanes = inplanes
14 | self.activation = nn.ReLU(inplace=True)
15 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
16 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
17 | kernel_size=1)
18 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
19 | kernel_size=3, padding=1)
20 |
21 | def forward(self, x):
22 | x = self.activation(self.squeeze(x))
23 | return torch.cat([
24 | self.activation(self.expand1x1(x)),
25 | self.activation(self.expand3x3(x))
26 | ], 1)
27 |
28 |
29 | # ******************************************************************************
30 |
31 | class Backbone(nn.Module):
32 | """
33 | Class for Squeezeseg. Subclasses PyTorch's own "nn" module
34 | """
35 |
36 | def __init__(self, params):
37 | # Call the super constructor
38 | super(Backbone, self).__init__()
39 | print("Using SqueezeNet Backbone")
40 | self.use_range = params["input_depth"]["range"]
41 | self.use_xyz = params["input_depth"]["xyz"]
42 | self.use_remission = params["input_depth"]["remission"]
43 | self.drop_prob = params["dropout"]
44 | self.OS = params["OS"]
45 |
46 | # input depth calc
47 | self.input_depth = 0
48 | self.input_idxs = []
49 | if self.use_range:
50 | self.input_depth += 1
51 | self.input_idxs.append(0)
52 | if self.use_xyz:
53 | self.input_depth += 3
54 | self.input_idxs.extend([1, 2, 3])
55 | if self.use_remission:
56 | self.input_depth += 1
57 | self.input_idxs.append(4)
58 | print("Depth of backbone input = ", self.input_depth)
59 |
60 | # stride play
61 | self.strides = [2, 2, 2, 2]
62 | # check current stride
63 | current_os = 1
64 | for s in self.strides:
65 | current_os *= s
66 | print("Original OS: ", current_os)
67 |
68 | # make the new stride
69 | if self.OS > current_os:
70 | print("Can't do OS, ", self.OS,
71 | " because it is bigger than original ", current_os)
72 | else:
73 | # redo strides according to needed stride
74 | for i, stride in enumerate(reversed(self.strides), 0):
75 | if int(current_os) != self.OS:
76 | if stride == 2:
77 | current_os /= 2
78 | self.strides[-1 - i] = 1
79 | if int(current_os) == self.OS:
80 | break
81 | print("New OS: ", int(current_os))
82 | print("Strides: ", self.strides)
83 |
84 | # encoder
85 | self.conv1a = nn.Sequential(nn.Conv2d(self.input_depth, 64, kernel_size=3,
86 | stride=[1, self.strides[0]],
87 | padding=1),
88 | nn.ReLU(inplace=True))
89 | self.conv1b = nn.Conv2d(self.input_depth, 64, kernel_size=1,
90 | stride=1, padding=0)
91 | self.fire23 = nn.Sequential(nn.MaxPool2d(kernel_size=3,
92 | stride=[1, self.strides[1]],
93 | padding=1),
94 | Fire(64, 16, 64, 64),
95 | Fire(128, 16, 64, 64))
96 | self.fire45 = nn.Sequential(nn.MaxPool2d(kernel_size=3,
97 | stride=[1, self.strides[2]],
98 | padding=1),
99 | Fire(128, 32, 128, 128),
100 | Fire(256, 32, 128, 128))
101 | self.fire6789 = nn.Sequential(nn.MaxPool2d(kernel_size=3,
102 | stride=[1, self.strides[3]],
103 | padding=1),
104 | Fire(256, 48, 192, 192),
105 | Fire(384, 48, 192, 192),
106 | Fire(384, 64, 256, 256),
107 | Fire(512, 64, 256, 256))
108 |
109 | # output
110 | self.dropout = nn.Dropout2d(self.drop_prob)
111 |
112 | # last channels
113 | self.last_channels = 512
114 |
115 | def run_layer(self, x, layer, skips, os):
116 | y = layer(x)
117 | if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]:
118 | skips[os] = x.detach()
119 | os *= 2
120 | x = y
121 | return x, skips, os
122 |
123 | def forward(self, x):
124 | # filter input
125 | x = x[:, self.input_idxs]
126 |
127 | # run cnn
128 | # store for skip connections
129 | skips = {}
130 | os = 1
131 |
132 | # encoder
133 | skip_in = self.conv1b(x)
134 | x = self.conv1a(x)
135 | # first skip done manually
136 | skips[1] = skip_in.detach()
137 | os *= 2
138 |
139 | x, skips, os = self.run_layer(x, self.fire23, skips, os)
140 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
141 | x, skips, os = self.run_layer(x, self.fire45, skips, os)
142 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
143 | x, skips, os = self.run_layer(x, self.fire6789, skips, os)
144 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
145 |
146 | return x, skips
147 |
148 | def get_last_depth(self):
149 | return self.last_channels
150 |
151 | def get_input_depth(self):
152 | return self.input_depth
153 |
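
To make the Fire block concrete, here is a short shape check (an addition, not part of the repository): the input is first squeezed to `squeeze_planes` channels, then expanded by parallel 1x1 and 3x3 convolutions whose outputs are concatenated, so the output channel count is `expand1x1_planes + expand3x3_planes`.

```python
# Minimal sketch: Fire(64, 16, 64, 64) keeps the spatial size and emits 64 + 64 channels.
import torch
from semantic_net.rangenet.backbones.squeezeseg import Fire  # path as laid out in this repo

fire = Fire(inplanes=64, squeeze_planes=16, expand1x1_planes=64, expand3x3_planes=64)
x = torch.randn(1, 64, 64, 512)
y = fire(x)
print(y.shape)   # torch.Size([1, 128, 64, 512])
```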
--------------------------------------------------------------------------------
/semantic_net/rangenet/backbones/squeezesegV2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # This file is covered by the LICENSE file in the root of this project.
3 |
4 | from __future__ import print_function
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class Fire(nn.Module):
11 | def __init__(self, inplanes, squeeze_planes,
12 | expand1x1_planes, expand3x3_planes, bn_d=0.1):
13 | super(Fire, self).__init__()
14 | self.inplanes = inplanes
15 | self.bn_d = bn_d
16 | self.activation = nn.ReLU(inplace=True)
17 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
18 | self.squeeze_bn = nn.BatchNorm2d(squeeze_planes, momentum=self.bn_d)
19 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
20 | kernel_size=1)
21 | self.expand1x1_bn = nn.BatchNorm2d(expand1x1_planes, momentum=self.bn_d)
22 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
23 | kernel_size=3, padding=1)
24 | self.expand3x3_bn = nn.BatchNorm2d(expand3x3_planes, momentum=self.bn_d)
25 |
26 | def forward(self, x):
27 | x = self.activation(self.squeeze_bn(self.squeeze(x)))
28 | return torch.cat([
29 | self.activation(self.expand1x1_bn(self.expand1x1(x))),
30 | self.activation(self.expand3x3_bn(self.expand3x3(x)))
31 | ], 1)
32 |
33 |
34 | class CAM(nn.Module):
35 |
36 | def __init__(self, inplanes, bn_d=0.1):
37 | super(CAM, self).__init__()
38 | self.inplanes = inplanes
39 | self.bn_d = bn_d
40 | self.pool = nn.MaxPool2d(7, 1, 3)
41 | self.squeeze = nn.Conv2d(inplanes, inplanes // 16,
42 | kernel_size=1, stride=1)
43 | self.squeeze_bn = nn.BatchNorm2d(inplanes // 16, momentum=self.bn_d)
44 | self.relu = nn.ReLU(inplace=True)
45 | self.unsqueeze = nn.Conv2d(inplanes // 16, inplanes,
46 | kernel_size=1, stride=1)
47 | self.unsqueeze_bn = nn.BatchNorm2d(inplanes, momentum=self.bn_d)
48 | self.sigmoid = nn.Sigmoid()
49 |
50 | def forward(self, x):
51 | # 7x7 pooling
52 | y = self.pool(x)
53 | # squeezing and relu
54 | y = self.relu(self.squeeze_bn(self.squeeze(y)))
55 | # unsqueezing
56 | y = self.sigmoid(self.unsqueeze_bn(self.unsqueeze(y)))
57 | # attention
58 | return y * x
59 |
60 | # ******************************************************************************
61 |
62 |
63 | class Backbone(nn.Module):
64 | """
65 | Class for Squeezeseg. Subclasses PyTorch's own "nn" module
66 | """
67 |
68 | def __init__(self, params):
69 | # Call the super constructor
70 | super(Backbone, self).__init__()
71 | print("Using SqueezeNet Backbone")
72 | self.use_range = params["input_depth"]["range"]
73 | self.use_xyz = params["input_depth"]["xyz"]
74 | self.use_remission = params["input_depth"]["remission"]
75 | self.bn_d = params["bn_d"]
76 | self.drop_prob = params["dropout"]
77 | self.OS = params["OS"]
78 |
79 | # input depth calc
80 | self.input_depth = 0
81 | self.input_idxs = []
82 | if self.use_range:
83 | self.input_depth += 1
84 | self.input_idxs.append(0)
85 | if self.use_xyz:
86 | self.input_depth += 3
87 | self.input_idxs.extend([1, 2, 3])
88 | if self.use_remission:
89 | self.input_depth += 1
90 | self.input_idxs.append(4)
91 | print("Depth of backbone input = ", self.input_depth)
92 |
93 | # stride play
94 | self.strides = [2, 2, 2, 2]
95 | # check current stride
96 | current_os = 1
97 | for s in self.strides:
98 | current_os *= s
99 | print("Original OS: ", current_os)
100 |
101 | # make the new stride
102 | if self.OS > current_os:
103 | print("Can't do OS, ", self.OS,
104 | " because it is bigger than original ", current_os)
105 | else:
106 | # redo strides according to needed stride
107 | for i, stride in enumerate(reversed(self.strides), 0):
108 | if int(current_os) != self.OS:
109 | if stride == 2:
110 | current_os /= 2
111 | self.strides[-1 - i] = 1
112 | if int(current_os) == self.OS:
113 | break
114 | print("New OS: ", int(current_os))
115 | print("Strides: ", self.strides)
116 |
117 | # encoder
118 | self.conv1a = nn.Sequential(nn.Conv2d(self.input_depth, 64, kernel_size=3,
119 | stride=[1, self.strides[0]],
120 | padding=1),
121 | nn.BatchNorm2d(64, momentum=self.bn_d),
122 | nn.ReLU(inplace=True),
123 | CAM(64, bn_d=self.bn_d))
124 | self.conv1b = nn.Sequential(nn.Conv2d(self.input_depth, 64, kernel_size=1,
125 | stride=1, padding=0),
126 | nn.BatchNorm2d(64, momentum=self.bn_d))
127 | self.fire23 = nn.Sequential(nn.MaxPool2d(kernel_size=3,
128 | stride=[1, self.strides[1]],
129 | padding=1),
130 | Fire(64, 16, 64, 64, bn_d=self.bn_d),
131 | CAM(128, bn_d=self.bn_d),
132 | Fire(128, 16, 64, 64, bn_d=self.bn_d),
133 | CAM(128, bn_d=self.bn_d))
134 | self.fire45 = nn.Sequential(nn.MaxPool2d(kernel_size=3,
135 | stride=[1, self.strides[2]],
136 | padding=1),
137 | Fire(128, 32, 128, 128, bn_d=self.bn_d),
138 | Fire(256, 32, 128, 128, bn_d=self.bn_d))
139 | self.fire6789 = nn.Sequential(nn.MaxPool2d(kernel_size=3,
140 | stride=[1, self.strides[3]],
141 | padding=1),
142 | Fire(256, 48, 192, 192, bn_d=self.bn_d),
143 | Fire(384, 48, 192, 192, bn_d=self.bn_d),
144 | Fire(384, 64, 256, 256, bn_d=self.bn_d),
145 | Fire(512, 64, 256, 256, bn_d=self.bn_d))
146 |
147 | # output
148 | self.dropout = nn.Dropout2d(self.drop_prob)
149 |
150 | # last channels
151 | self.last_channels = 512
152 |
153 | def run_layer(self, x, layer, skips, os):
154 | y = layer(x)
155 | if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]:
156 | skips[os] = x.detach()
157 | os *= 2
158 | x = y
159 | return x, skips, os
160 |
161 | def forward(self, x):
162 | # filter input
163 | x = x[:, self.input_idxs]
164 |
165 | # run cnn
166 | # store for skip connections
167 | skips = {}
168 | os = 1
169 |
170 | # encoder
171 | skip_in = self.conv1b(x)
172 | x = self.conv1a(x)
173 | # first skip done manually
174 | skips[1] = skip_in.detach()
175 | os *= 2
176 |
177 | x, skips, os = self.run_layer(x, self.fire23, skips, os)
178 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
179 | x, skips, os = self.run_layer(x, self.fire45, skips, os)
180 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
181 | x, skips, os = self.run_layer(x, self.fire6789, skips, os)
182 | x, skips, os = self.run_layer(x, self.dropout, skips, os)
183 |
184 | return x, skips
185 |
186 | def get_last_depth(self):
187 | return self.last_channels
188 |
189 | def get_input_depth(self):
190 | return self.input_depth
191 |
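
The main difference to the plain SqueezeSeg backbone is the `CAM` block, a context/channel attention module: it pools the input with a 7x7 window, squeezes the channels by a factor of 16, and produces a sigmoid gate that rescales the original activations. A tiny shape check (an addition, not part of the repository) illustrates that it only reweights and never resamples:

```python
# Minimal sketch: CAM preserves the input shape exactly.
import torch
from semantic_net.rangenet.backbones.squeezesegV2 import CAM  # path as laid out in this repo

cam = CAM(inplanes=64)
x = torch.randn(1, 64, 64, 512)
y = cam(x)
print(y.shape == x.shape)   # True: the attention map multiplies x element-wise
```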
--------------------------------------------------------------------------------
/semantic_net/rangenet/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/common/__init__.py
--------------------------------------------------------------------------------
/semantic_net/rangenet/common/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/common/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/common/__pycache__/laserscan.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/common/__pycache__/laserscan.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/common/laserscan.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # This file is covered by the LICENSE file in the root of this project.
3 | import numpy as np
4 |
5 |
6 | class LaserScan:
7 | """Class that contains LaserScan with x,y,z,r"""
8 | EXTENSIONS_SCAN = ['.bin']
9 |
10 | def __init__(self, project=False, H=64, W=1024, fov_up=3.0, fov_down=-25.0):
11 | self.project = project
12 | self.proj_H = H
13 | self.proj_W = W
14 | self.proj_fov_up = fov_up
15 | self.proj_fov_down = fov_down
16 | self.reset()
17 |
18 | def reset(self):
19 | """ Reset scan members. """
20 | self.points = np.zeros((0, 3), dtype=np.float32) # [m, 3]: x, y, z
21 | self.remissions = np.zeros((0, 1), dtype=np.float32) # [m ,1]: remission
22 |
23 | # projected range image - [H,W] range (-1 is no data)
24 | self.proj_range = np.full((self.proj_H, self.proj_W), -1,
25 | dtype=np.float32)
26 |
27 | # unprojected range (list of depths for each point)
28 | self.unproj_range = np.zeros((0, 1), dtype=np.float32)
29 |
30 | # projected point cloud xyz - [H,W,3] xyz coord (-1 is no data)
31 | self.proj_xyz = np.full((self.proj_H, self.proj_W, 3), -1,
32 | dtype=np.float32)
33 |
34 | # projected remission - [H,W] intensity (-1 is no data)
35 | self.proj_remission = np.full((self.proj_H, self.proj_W), -1,
36 | dtype=np.float32)
37 |
38 | # projected index (for each pixel, what I am in the pointcloud)
39 | # [H,W] index (-1 is no data)
40 | self.proj_idx = np.full((self.proj_H, self.proj_W), -1,
41 | dtype=np.int32)
42 |
43 | # for each point, where it is in the range image
44 | self.proj_x = np.zeros((0, 1), dtype=np.int32) # [m, 1]: x
45 | self.proj_y = np.zeros((0, 1), dtype=np.int32) # [m, 1]: y
46 |
47 | # mask containing for each pixel, if it contains a point or not
48 | self.proj_mask = np.zeros((self.proj_H, self.proj_W),
49 | dtype=np.int32) # [H,W] mask
50 |
51 | def size(self):
52 | """ Return the size of the point cloud. """
53 | return self.points.shape[0]
54 |
55 | def __len__(self):
56 | return self.size()
57 |
58 | def open_scan(self, filename):
59 | """ Open raw scan and fill in attributes
60 | """
61 | # reset just in case there was an open structure
62 | self.reset()
63 |
64 | # check filename is string
65 | if not isinstance(filename, str):
66 | raise TypeError("Filename should be string type, "
67 | "but was {type}".format(type=str(type(filename))))
68 |
69 | # check extension is a laserscan
70 | if not any(filename.endswith(ext) for ext in self.EXTENSIONS_SCAN):
71 | raise RuntimeError("Filename extension is not valid scan file.")
72 |
73 | # if all goes well, open pointcloud
74 | scan = np.fromfile(filename, dtype=np.float32)
75 | scan = scan.reshape((-1, 4))
76 |
77 | # put in attribute
78 | points = scan[:, 0:3] # get xyz
79 | remissions = scan[:, 3] # get remission
80 | self.set_points(points, remissions)
81 |
82 | def set_points(self, points, remissions=None):
83 | """ Set scan attributes (instead of opening from file)
84 | """
85 | # reset just in case there was an open structure
86 | self.reset()
87 |
88 | # check scan makes sense
89 | if not isinstance(points, np.ndarray):
90 | raise TypeError("Scan should be numpy array")
91 |
92 | # check remission makes sense
93 | if remissions is not None and not isinstance(remissions, np.ndarray):
94 | raise TypeError("Remissions should be numpy array")
95 |
96 | # put in attribute
97 | self.points = points # get xyz
98 | if remissions is not None:
99 | self.remissions = remissions # get remission
100 | else:
101 | self.remissions = np.zeros((points.shape[0]), dtype=np.float32)
102 |
103 | # if projection is wanted, then do it and fill in the structure
104 | if self.project:
105 | self.do_range_projection()
106 |
107 | def do_range_projection(self):
108 | """ Project a pointcloud into a spherical projection image.projection.
109 | Function takes no arguments because it can be also called externally
110 | if the value of the constructor was not set (in case you change your
111 | mind about wanting the projection)
112 | """
113 | # laser parameters
114 | fov_up = self.proj_fov_up / 180.0 * np.pi # field of view up in rad
115 | fov_down = self.proj_fov_down / 180.0 * np.pi # field of view down in rad
116 | fov = abs(fov_down) + abs(fov_up) # get field of view total in rad
117 |
118 | # get depth of all points
119 | depth = np.linalg.norm(self.points, 2, axis=1)
120 |
121 | # get scan components
122 | scan_x = self.points[:, 0]
123 | scan_y = self.points[:, 1]
124 | scan_z = self.points[:, 2]
125 |
126 | # get angles of all points
127 | yaw = -np.arctan2(scan_y, scan_x)
128 | pitch = np.arcsin(scan_z / depth)
129 |
130 | # get projections in image coords
131 | proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0]
132 | proj_y = 1.0 - (pitch + abs(fov_down)) / fov # in [0.0, 1.0]
133 |
134 | # scale to image size using angular resolution
135 | proj_x *= self.proj_W # in [0.0, W]
136 | proj_y *= self.proj_H # in [0.0, H]
137 |
138 | # round and clamp for use as index
139 | proj_x = np.floor(proj_x)
140 | proj_x = np.minimum(self.proj_W - 1, proj_x)
141 | proj_x = np.maximum(0, proj_x).astype(np.int32) # in [0,W-1]
142 | self.proj_x = np.copy(proj_x) # store a copy in orig order
143 |
144 | proj_y = np.floor(proj_y)
145 | proj_y = np.minimum(self.proj_H - 1, proj_y)
146 | proj_y = np.maximum(0, proj_y).astype(np.int32) # in [0,H-1]
147 | self.proj_y = np.copy(proj_y)  # store a copy in original order
148 |
149 | # copy of depth in original order
150 | self.unproj_range = np.copy(depth)
151 |
152 | # order in decreasing depth
153 | indices = np.arange(depth.shape[0])
154 | order = np.argsort(depth)[::-1]
155 | depth = depth[order]
156 | indices = indices[order]
157 | points = self.points[order]
158 | remission = self.remissions[order]
159 | proj_y = proj_y[order]
160 | proj_x = proj_x[order]
161 |
162 | # assign to images
163 | self.proj_range[proj_y, proj_x] = depth
164 | self.proj_xyz[proj_y, proj_x] = points
165 | self.proj_remission[proj_y, proj_x] = remission
166 | self.proj_idx[proj_y, proj_x] = indices
167 | self.proj_mask = (self.proj_idx > 0).astype(np.int32)
168 |
169 |
170 | class SemLaserScan(LaserScan):
171 | """Class that contains LaserScan with x,y,z,r,sem_label,sem_color_label,inst_label,inst_color_label"""
172 | EXTENSIONS_LABEL = ['.label']
173 |
174 | def __init__(self, sem_color_dict=None, project=False, H=64, W=1024, fov_up=3.0, fov_down=-25.0, max_classes=300):
175 | super(SemLaserScan, self).__init__(project, H, W, fov_up, fov_down)
176 | self.reset()
177 |
178 | # make semantic colors
179 | if sem_color_dict:
180 | # if I have a dict, make it
181 | max_sem_key = 0
182 | for key, data in sem_color_dict.items():
183 | if key + 1 > max_sem_key:
184 | max_sem_key = key + 1
185 | self.sem_color_lut = np.zeros((max_sem_key + 100, 3), dtype=np.float32)
186 | for key, value in sem_color_dict.items():
187 | self.sem_color_lut[key] = np.array(value, np.float32) / 255.0
188 | else:
189 | # otherwise make random
190 | max_sem_key = max_classes
191 | self.sem_color_lut = np.random.uniform(low=0.0,
192 | high=1.0,
193 | size=(max_sem_key, 3))
194 | # force zero to a gray-ish color
195 | self.sem_color_lut[0] = np.full((3), 0.1)
196 |
197 | # make instance colors
198 | max_inst_id = 100000
199 | self.inst_color_lut = np.random.uniform(low=0.0,
200 | high=1.0,
201 | size=(max_inst_id, 3))
202 | # force zero to a gray-ish color
203 | self.inst_color_lut[0] = np.full((3), 0.1)
204 |
205 | def reset(self):
206 | """ Reset scan members. """
207 | super(SemLaserScan, self).reset()
208 |
209 | # semantic labels
210 | self.sem_label = np.zeros((0, 1), dtype=np.int32) # [m, 1]: label
211 | self.sem_label_color = np.zeros((0, 3), dtype=np.float32) # [m ,3]: color
212 |
213 | # instance labels
214 | self.inst_label = np.zeros((0, 1), dtype=np.int32) # [m, 1]: label
215 | self.inst_label_color = np.zeros((0, 3), dtype=np.float32) # [m ,3]: color
216 |
217 | # projection color with semantic labels
218 | self.proj_sem_label = np.zeros((self.proj_H, self.proj_W),
219 | dtype=np.int32) # [H,W] label
220 | self.proj_sem_color = np.zeros((self.proj_H, self.proj_W, 3),
221 | dtype=float) # [H,W,3] color
222 |
223 | # projection color with instance labels
224 | self.proj_inst_label = np.zeros((self.proj_H, self.proj_W),
225 | dtype=np.int32) # [H,W] label
226 | self.proj_inst_color = np.zeros((self.proj_H, self.proj_W, 3),
227 | dtype=float) # [H,W,3] color
228 |
229 | def open_label(self, filename):
230 | """ Open raw scan and fill in attributes
231 | """
232 | # check filename is string
233 | if not isinstance(filename, str):
234 | raise TypeError("Filename should be string type, "
235 | "but was {type}".format(type=str(type(filename))))
236 |
237 | # check extension is a laserscan
238 | if not any(filename.endswith(ext) for ext in self.EXTENSIONS_LABEL):
239 | raise RuntimeError("Filename extension is not valid label file.")
240 |
241 | # if all goes well, open label
242 | label = np.fromfile(filename, dtype=np.int32)
243 | label = label.reshape((-1))
244 |
245 | # set it
246 | self.set_label(label)
247 |
248 | def set_label(self, label):
249 | """ Set points for label not from file but from np
250 | """
251 | # check label makes sense
252 | if not isinstance(label, np.ndarray):
253 | raise TypeError("Label should be numpy array")
254 |
255 | # only fill in attribute if the right size
256 | if label.shape[0] == self.points.shape[0]:
257 | self.sem_label = label & 0xFFFF # semantic label in lower half
258 | self.inst_label = label >> 16 # instance id in upper half
259 | else:
260 | print("Points shape: ", self.points.shape)
261 | print("Label shape: ", label.shape)
262 | raise ValueError("Scan and Label don't contain same number of points")
263 |
264 | # sanity check
265 | assert((self.sem_label + (self.inst_label << 16) == label).all())
266 |
267 | if self.project:
268 | self.do_label_projection()
269 |
270 | def colorize(self):
271 | """ Colorize pointcloud with the color of each semantic label
272 | """
273 | self.sem_label_color = self.sem_color_lut[self.sem_label]
274 | self.sem_label_color = self.sem_label_color.reshape((-1, 3))
275 |
276 | self.inst_label_color = self.inst_color_lut[self.inst_label]
277 | self.inst_label_color = self.inst_label_color.reshape((-1, 3))
278 |
279 | def do_label_projection(self):
280 | # only map colors to labels that exist
281 | mask = self.proj_idx >= 0
282 |
283 | # semantics
284 | self.proj_sem_label[mask] = self.sem_label[self.proj_idx[mask]]
285 | self.proj_sem_color[mask] = self.sem_color_lut[self.sem_label[self.proj_idx[mask]]]
286 |
287 | # instances
288 | self.proj_inst_label[mask] = self.inst_label[self.proj_idx[mask]]
289 | self.proj_inst_color[mask] = self.inst_color_lut[self.inst_label[self.proj_idx[mask]]]
290 |
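
As a rough illustration of how this class is meant to be used, the sketch below (an addition, not part of the repository) projects a random point cloud into the 64 x 1024 range image that the backbones above consume; the import path assumes the package layout shown in this dump.

```python
# Minimal sketch: spherical projection of a random point cloud.
import numpy as np
from semantic_net.rangenet.common.laserscan import LaserScan  # path as laid out in this repo

scan = LaserScan(project=True, H=64, W=1024, fov_up=3.0, fov_down=-25.0)
points = np.random.uniform(-50.0, 50.0, size=(1000, 3)).astype(np.float32)
remissions = np.random.uniform(0.0, 1.0, size=(1000,)).astype(np.float32)
scan.set_points(points, remissions)   # project=True triggers do_range_projection()

print(scan.proj_range.shape)       # (64, 1024): per-pixel depth, -1 where no point landed
print(scan.proj_xyz.shape)         # (64, 1024, 3)
print(int(scan.proj_mask.sum()))   # number of valid pixels according to proj_mask
```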
--------------------------------------------------------------------------------
/semantic_net/rangenet/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.14.0
2 | scipy==0.19.1
3 | torch==1.1.0
4 | tensorflow==1.13.1
5 | vispy==0.5.3
6 | torchvision==0.2.2.post3
7 | opencv_contrib_python==4.1.0.25
8 | matplotlib==2.2.3
9 | Pillow==6.1.0
10 | PyYAML==5.1.1
11 |
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/__init__.py
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/README.md:
--------------------------------------------------------------------------------
1 | # LiDAR-Bonnetal Semantic Segmentation Training
2 |
3 | This part of the framework deals with the training of semantic segmentation networks for point cloud data using range images. This code allows you to reproduce the experiments from the [RangeNet++](http://www.ipb.uni-bonn.de/wp-content/papercite-data/pdf/milioto2019iros.pdf) paper.
4 |
5 | _Examples of segmentation results from [SemanticKITTI](http://semantic-kitti.org) dataset:_
6 | 
7 | 
8 |
9 | ## Configuration files
10 |
11 | Architecture configuration files are located at [config/arch](config/arch/)
12 | Dataset configuration files are located at [config/labels](config/labels/)
13 |
14 | ## Apps
15 |
16 | `ALL SCRIPTS CAN BE INVOKED WITH -h TO GET EXTRA HELP ON HOW TO RUN THEM`
17 |
18 | ### Visualization
19 |
20 | To visualize the data (in this example sequence 00):
21 |
22 | ```sh
23 | $ ./visualize.py -d /path/to/dataset/ -s 00
24 | ```
25 |
26 | To visualize the predictions (in this example sequence 00):
27 |
28 | ```sh
29 | $ ./visualize.py -d /path/to/dataset/ -p /path/to/predictions/ -s 00
30 | ```
31 |
32 | ### Training
33 |
34 | To train a network (from scratch):
35 |
36 | ```sh
37 | $ ./train.py -d /path/to/dataset -ac /config/arch/CHOICE.yaml -l /path/to/log
38 | ```
39 |
40 | To train a network (from pretrained model):
41 |
42 | ```sh
43 | $ ./train.py -d /path/to/dataset -ac /config/arch/CHOICE.yaml -dc /config/labels/CHOICE.yaml -l /path/to/log -p /path/to/pretrained
44 | ```
45 |
46 | This will generate a tensorboard log, which can be visualized by running:
47 |
48 | ```sh
49 | $ cd /path/to/log
50 | $ tensorboard --logdir=. --port 5555
51 | ```
52 |
53 | And accessing [http://localhost:5555](http://localhost:5555) in your browser.
54 |
55 | ### Inference
56 |
57 | To infer the predictions for the entire dataset:
58 |
59 | ```sh
60 | $ ./infer.py -d /path/to/dataset/ -l /path/for/predictions -m /path/to/model
61 | ```
62 |
63 | ### Evaluation
64 |
65 | To evaluate the overall IoU of the point clouds for a specific split (in SemanticKITTI this can only be `train` or `valid`, since the test labels are held out and evaluated on the benchmark server):
66 |
67 | ```sh
68 | $ ./evaluate_iou.py -d /path/to/dataset -p /path/to/predictions/ --split valid
69 | ```
70 |
71 | To evaluate the border IoU of the point clouds (introduced in RangeNet++ paper):
72 |
73 | ```sh
74 | $ ./evaluate_biou.py -d /path/to/dataset -p /path/to/predictions/ --split valid --border 1 --conn 4
75 | ```
76 |
77 | ## Pre-trained Models
78 |
79 | ### [SemanticKITTI](http://semantic-kitti.org)
80 |
81 | - [squeezeseg](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/squeezeseg.tar.gz)
82 | - [squeezeseg + crf](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/squeezeseg-crf.tar.gz)
83 | - [squeezesegV2](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/squeezesegV2.tar.gz)
84 | - [squeezesegV2 + crf](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/squeezesegV2-crf.tar.gz)
85 | - [darknet21](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/darknet21.tar.gz)
86 | - [darknet53](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/darknet53.tar.gz)
87 | - [darknet53-1024](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/darknet53-1024.tar.gz)
88 | - [darknet53-512](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/models/darknet53-512.tar.gz)
89 |
90 | To enable kNN post-processing, set the corresponding boolean parameter to `True` in the `arch_cfg.yaml` file inside the model directory.
91 |
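If you prefer to flip this switch programmatically, the hedged sketch below shows one way to do it; note that the exact key path inside `arch_cfg.yaml` is an assumption here (written as `post -> KNN -> use`), so check the file shipped with the downloaded model and adapt the keys accordingly.

```python
# Hedged sketch: toggle the kNN post-processing flag in a model's arch_cfg.yaml.
import yaml

cfg_path = "/path/to/model/arch_cfg.yaml"    # illustrative path
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

cfg["post"]["KNN"]["use"] = True             # assumed key path, verify against your file

with open(cfg_path, "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)  # keep the original key order
```
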
92 | ## Predictions from Models
93 |
94 | ### [SemanticKITTI](http://semantic-kitti.org)
95 |
96 | These are the predictions for the train, validation, and test sets. The performance can be evaluated for the training and validation set, but for test set evaluation a submission to the benchmark needs to be made (labels are not public).
97 |
98 | No post-processing:
99 | - [squeezeseg](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezeseg.tar.gz)
100 | - [squeezeseg + crf](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezeseg-crf.tar.gz)
101 | - [squeezesegV2](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezesegV2.tar.gz)
102 | - [squeezesegV2 + crf](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezesegV2-crf.tar.gz)
103 | - [darknet21](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet21.tar.gz)
104 | - [darknet53](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53.tar.gz)
105 | - [darknet53-1024](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-1024.tar.gz)
106 | - [darknet53-512](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-512.tar.gz)
107 |
108 | With k-NN processing:
109 | - [squeezeseg](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezeseg-knn.tar.gz)
110 | - [squeezesegV2](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/squeezesegV2-knn.tar.gz)
111 | - [darknet53](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-knn.tar.gz)
112 | - [darknet21](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet21-knn.tar.gz)
113 | - [darknet53-1024](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-1024-knn.tar.gz)
114 | - [darknet53-512](http://www.ipb.uni-bonn.de/html/projects/bonnetal/lidar/semantic/predictions/darknet53-512-knn.tar.gz)
115 |
116 | ## Citations
117 |
118 | If you use our framework, model, or predictions for any academic work, please cite the original [paper](http://www.ipb.uni-bonn.de/wp-content/papercite-data/pdf/milioto2019iros.pdf), and the [dataset](http://semantic-kitti.org).
119 |
120 | ```
121 | @inproceedings{milioto2019iros,
122 | author = {A. Milioto and I. Vizzo and J. Behley and C. Stachniss},
123 | title = {{RangeNet++: Fast and Accurate LiDAR Semantic Segmentation}},
124 | booktitle = {IEEE/RSJ Intl.~Conf.~on Intelligent Robots and Systems (IROS)},
125 | year = 2019,
126 | codeurl = {https://github.com/PRBonn/lidar-bonnetal},
127 | videourl = {https://youtu.be/wuokg7MFZyU},
128 | }
129 | ```
130 |
131 | ```
132 | @inproceedings{behley2019iccv,
133 | author = {J. Behley and M. Garbade and A. Milioto and J. Quenzel and S. Behnke and C. Stachniss and J. Gall},
134 | title = {{SemanticKITTI: A Dataset for Semantic Scene Understanding of LiDAR Sequences}},
135 | booktitle = {Proc. of the IEEE/CVF International Conf.~on Computer Vision (ICCV)},
136 | year = {2019}
137 | }
138 | ```
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | TRAIN_PATH = "../../"
3 | DEPLOY_PATH = "../../../deploy"
4 | sys.path.insert(0, TRAIN_PATH)
5 |
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/config/labels/semantic-kitti-all.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
2 | name: "kitti"
3 | labels:
4 | 0 : "unlabeled"
5 | 1 : "outlier"
6 | 10: "car"
7 | 11: "bicycle"
8 | 13: "bus"
9 | 15: "motorcycle"
10 | 16: "on-rails"
11 | 18: "truck"
12 | 20: "other-vehicle"
13 | 30: "person"
14 | 31: "bicyclist"
15 | 32: "motorcyclist"
16 | 40: "road"
17 | 44: "parking"
18 | 48: "sidewalk"
19 | 49: "other-ground"
20 | 50: "building"
21 | 51: "fence"
22 | 52: "other-structure"
23 | 60: "lane-marking"
24 | 70: "vegetation"
25 | 71: "trunk"
26 | 72: "terrain"
27 | 80: "pole"
28 | 81: "traffic-sign"
29 | 99: "other-object"
30 | 252: "moving-car"
31 | 253: "moving-bicyclist"
32 | 254: "moving-person"
33 | 255: "moving-motorcyclist"
34 | 256: "moving-on-rails"
35 | 257: "moving-bus"
36 | 258: "moving-truck"
37 | 259: "moving-other-vehicle"
38 | color_map: # bgr
39 | 0 : [0, 0, 0]
40 | 1 : [0, 0, 255]
41 | 10: [245, 150, 100]
42 | 11: [245, 230, 100]
43 | 13: [250, 80, 100]
44 | 15: [150, 60, 30]
45 | 16: [255, 0, 0]
46 | 18: [180, 30, 80]
47 | 20: [255, 0, 0]
48 | 30: [30, 30, 255]
49 | 31: [200, 40, 255]
50 | 32: [90, 30, 150]
51 | 40: [255, 0, 255]
52 | 44: [255, 150, 255]
53 | 48: [75, 0, 75]
54 | 49: [75, 0, 175]
55 | 50: [0, 200, 255]
56 | 51: [50, 120, 255]
57 | 52: [0, 150, 255]
58 | 60: [170, 255, 150]
59 | 70: [0, 175, 0]
60 | 71: [0, 60, 135]
61 | 72: [80, 240, 150]
62 | 80: [150, 240, 255]
63 | 81: [0, 0, 255]
64 | 99: [255, 255, 50]
65 | 252: [245, 150, 100]
66 | 256: [255, 0, 0]
67 | 253: [200, 40, 255]
68 | 254: [30, 30, 255]
69 | 255: [90, 30, 150]
70 | 257: [250, 80, 100]
71 | 258: [180, 30, 80]
72 | 259: [255, 0, 0]
73 | content: # as a ratio with the total number of points
74 | 0: 0.018889854628292943
75 | 1: 0.0002937197336781505
76 | 10: 0.040818519255974316
77 | 11: 0.00016609538710764618
78 | 13: 2.7879693665067774e-05
79 | 15: 0.00039838616015114444
80 | 16: 0.0
81 | 18: 0.0020633612104619787
82 | 20: 0.0016218197275284021
83 | 30: 0.00017698551338515307
84 | 31: 1.1065903904919655e-08
85 | 32: 5.532951952459828e-09
86 | 40: 0.1987493871255525
87 | 44: 0.014717169549888214
88 | 48: 0.14392298360372
89 | 49: 0.0039048553037472045
90 | 50: 0.1326861944777486
91 | 51: 0.0723592229456223
92 | 52: 0.002395131480328884
93 | 60: 4.7084144280367186e-05
94 | 70: 0.26681502148037506
95 | 71: 0.006035012012626033
96 | 72: 0.07814222006271769
97 | 80: 0.002855498193863172
98 | 81: 0.0006155958086189918
99 | 99: 0.009923127583046915
100 | 252: 0.001789309418528068
101 | 253: 0.00012709999297008662
102 | 254: 0.00016059776092534436
103 | 255: 3.745553104802113e-05
104 | 256: 0.0
105 | 257: 0.00011351574470342043
106 | 258: 0.00010157861367183268
107 | 259: 4.3840131989471124e-05
108 | # classes that are indistinguishable from single scan or inconsistent in
109 | # ground truth are mapped to their closest equivalent
110 | learning_map:
111 | 0 : 0 # "unlabeled"
112 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped
113 | 10: 1 # "car"
114 | 11: 2 # "bicycle"
115 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped
116 | 15: 3 # "motorcycle"
117 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped
118 | 18: 4 # "truck"
119 | 20: 5 # "other-vehicle"
120 | 30: 6 # "person"
121 | 31: 7 # "bicyclist"
122 | 32: 8 # "motorcyclist"
123 | 40: 9 # "road"
124 | 44: 10 # "parking"
125 | 48: 11 # "sidewalk"
126 | 49: 12 # "other-ground"
127 | 50: 13 # "building"
128 | 51: 14 # "fence"
129 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped
130 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped
131 | 70: 15 # "vegetation"
132 | 71: 16 # "trunk"
133 | 72: 17 # "terrain"
134 | 80: 18 # "pole"
135 | 81: 19 # "traffic-sign"
136 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped
137 | 252: 20 # "moving-car"
138 | 253: 21 # "moving-bicyclist"
139 | 254: 22 # "moving-person"
140 | 255: 23 # "moving-motorcyclist"
141 | 256: 24 # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped
142 | 257: 24 # "moving-bus" mapped to "moving-other-vehicle" -----------mapped
143 | 258: 25 # "moving-truck"
144 | 259: 24 # "moving-other-vehicle"
145 | learning_map_inv: # inverse of previous map
146 | 0: 0 # "unlabeled", and others ignored
147 | 1: 10 # "car"
148 | 2: 11 # "bicycle"
149 | 3: 15 # "motorcycle"
150 | 4: 18 # "truck"
151 | 5: 20 # "other-vehicle"
152 | 6: 30 # "person"
153 | 7: 31 # "bicyclist"
154 | 8: 32 # "motorcyclist"
155 | 9: 40 # "road"
156 | 10: 44 # "parking"
157 | 11: 48 # "sidewalk"
158 | 12: 49 # "other-ground"
159 | 13: 50 # "building"
160 | 14: 51 # "fence"
161 | 15: 70 # "vegetation"
162 | 16: 71 # "trunk"
163 | 17: 72 # "terrain"
164 | 18: 80 # "pole"
165 | 19: 81 # "traffic-sign"
166 | 20: 252 # "moving-car"
167 | 21: 253 # "moving-bicyclist"
168 | 22: 254 # "moving-person"
169 | 23: 255 # "moving-motorcyclist"
170 | 24: 259 # "moving-other-vehicle"
171 | 25: 258 # "moving-truck"
172 | learning_ignore: # Ignore classes
173 | 0: True # "unlabeled", and others ignored
174 | 1: False # "car"
175 | 2: False # "bicycle"
176 | 3: False # "motorcycle"
177 | 4: False # "truck"
178 | 5: False # "other-vehicle"
179 | 6: False # "person"
180 | 7: False # "bicyclist"
181 | 8: False # "motorcyclist"
182 | 9: False # "road"
183 | 10: False # "parking"
184 | 11: False # "sidewalk"
185 | 12: False # "other-ground"
186 | 13: False # "building"
187 | 14: False # "fence"
188 | 15: False # "vegetation"
189 | 16: False # "trunk"
190 | 17: False # "terrain"
191 | 18: False # "pole"
192 | 19: False # "traffic-sign"
193 | 20: False # "moving-car"
194 | 21: False # "moving-bicyclist"
195 | 22: False # "moving-person"
196 | 23: False # "moving-motorcyclist"
197 | 24: False # "moving-other-vehicle"
198 | 25: False # "moving-truck"
199 | split: # sequence numbers
200 | train:
201 | - 0
202 | - 1
203 | - 2
204 | - 3
205 | - 4
206 | - 5
207 | - 6
208 | - 7
209 | - 9
210 | - 10
211 | valid:
212 | - 8
213 | test:
214 | - 11
215 | - 12
216 | - 13
217 | - 14
218 | - 15
219 | - 16
220 | - 17
221 | - 18
222 | - 19
223 | - 20
224 | - 21
225 |
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/config/labels/semantic-kitti.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
2 | name: "kitti"
3 | labels:
4 | 0 : "unlabeled"
5 | 1 : "outlier"
6 | 10: "car"
7 | 11: "bicycle"
8 | 13: "bus"
9 | 15: "motorcycle"
10 | 16: "on-rails"
11 | 18: "truck"
12 | 20: "other-vehicle"
13 | 30: "person"
14 | 31: "bicyclist"
15 | 32: "motorcyclist"
16 | 40: "road"
17 | 44: "parking"
18 | 48: "sidewalk"
19 | 49: "other-ground"
20 | 50: "building"
21 | 51: "fence"
22 | 52: "other-structure"
23 | 60: "lane-marking"
24 | 70: "vegetation"
25 | 71: "trunk"
26 | 72: "terrain"
27 | 80: "pole"
28 | 81: "traffic-sign"
29 | 99: "other-object"
30 | 252: "moving-car"
31 | 253: "moving-bicyclist"
32 | 254: "moving-person"
33 | 255: "moving-motorcyclist"
34 | 256: "moving-on-rails"
35 | 257: "moving-bus"
36 | 258: "moving-truck"
37 | 259: "moving-other-vehicle"
38 | color_map: # bgr
39 | 0 : [0, 0, 0]
40 | 1 : [0, 0, 255]
41 | 10: [245, 150, 100]
42 | 11: [245, 230, 100]
43 | 13: [250, 80, 100]
44 | 15: [150, 60, 30]
45 | 16: [255, 0, 0]
46 | 18: [180, 30, 80]
47 | 20: [255, 0, 0]
48 | 30: [30, 30, 255]
49 | 31: [200, 40, 255]
50 | 32: [90, 30, 150]
51 | 40: [255, 0, 255]
52 | 44: [255, 150, 255]
53 | 48: [75, 0, 75]
54 | 49: [75, 0, 175]
55 | 50: [0, 200, 255]
56 | 51: [50, 120, 255]
57 | 52: [0, 150, 255]
58 | 60: [170, 255, 150]
59 | 70: [0, 175, 0]
60 | 71: [0, 60, 135]
61 | 72: [80, 240, 150]
62 | 80: [150, 240, 255]
63 | 81: [0, 0, 255]
64 | 99: [255, 255, 50]
65 | 252: [245, 150, 100]
66 | 256: [255, 0, 0]
67 | 253: [200, 40, 255]
68 | 254: [30, 30, 255]
69 | 255: [90, 30, 150]
70 | 257: [250, 80, 100]
71 | 258: [180, 30, 80]
72 | 259: [255, 0, 0]
73 | content: # as a ratio with the total number of points
74 | 0: 0.018889854628292943
75 | 1: 0.0002937197336781505
76 | 10: 0.040818519255974316
77 | 11: 0.00016609538710764618
78 | 13: 2.7879693665067774e-05
79 | 15: 0.00039838616015114444
80 | 16: 0.0
81 | 18: 0.0020633612104619787
82 | 20: 0.0016218197275284021
83 | 30: 0.00017698551338515307
84 | 31: 1.1065903904919655e-08
85 | 32: 5.532951952459828e-09
86 | 40: 0.1987493871255525
87 | 44: 0.014717169549888214
88 | 48: 0.14392298360372
89 | 49: 0.0039048553037472045
90 | 50: 0.1326861944777486
91 | 51: 0.0723592229456223
92 | 52: 0.002395131480328884
93 | 60: 4.7084144280367186e-05
94 | 70: 0.26681502148037506
95 | 71: 0.006035012012626033
96 | 72: 0.07814222006271769
97 | 80: 0.002855498193863172
98 | 81: 0.0006155958086189918
99 | 99: 0.009923127583046915
100 | 252: 0.001789309418528068
101 | 253: 0.00012709999297008662
102 | 254: 0.00016059776092534436
103 | 255: 3.745553104802113e-05
104 | 256: 0.0
105 | 257: 0.00011351574470342043
106 | 258: 0.00010157861367183268
107 | 259: 4.3840131989471124e-05
108 | # classes that are indistinguishable from single scan or inconsistent in
109 | # ground truth are mapped to their closest equivalent
110 | learning_map:
111 | 0 : 0 # "unlabeled"
112 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped
113 | 10: 1 # "car"
114 | 11: 2 # "bicycle"
115 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped
116 | 15: 3 # "motorcycle"
117 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped
118 | 18: 4 # "truck"
119 | 20: 5 # "other-vehicle"
120 | 30: 6 # "person"
121 | 31: 7 # "bicyclist"
122 | 32: 8 # "motorcyclist"
123 | 40: 9 # "road"
124 | 44: 10 # "parking"
125 | 48: 11 # "sidewalk"
126 | 49: 12 # "other-ground"
127 | 50: 13 # "building"
128 | 51: 14 # "fence"
129 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped
130 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped
131 | 70: 15 # "vegetation"
132 | 71: 16 # "trunk"
133 | 72: 17 # "terrain"
134 | 80: 18 # "pole"
135 | 81: 19 # "traffic-sign"
136 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped
137 | 252: 1 # "moving-car" to "car" ------------------------------------mapped
138 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped
139 | 254: 6 # "moving-person" to "person" ------------------------------mapped
140 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped
141 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped
142 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped
143 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped
144 | 259: 5 # "moving-other-vehicle" to "other-vehicle" ----------------mapped
145 | learning_map_inv: # inverse of previous map
146 | 0: 0 # "unlabeled", and others ignored
147 | 1: 10 # "car"
148 | 2: 11 # "bicycle"
149 | 3: 15 # "motorcycle"
150 | 4: 18 # "truck"
151 | 5: 20 # "other-vehicle"
152 | 6: 30 # "person"
153 | 7: 31 # "bicyclist"
154 | 8: 32 # "motorcyclist"
155 | 9: 40 # "road"
156 | 10: 44 # "parking"
157 | 11: 48 # "sidewalk"
158 | 12: 49 # "other-ground"
159 | 13: 50 # "building"
160 | 14: 51 # "fence"
161 | 15: 70 # "vegetation"
162 | 16: 71 # "trunk"
163 | 17: 72 # "terrain"
164 | 18: 80 # "pole"
165 | 19: 81 # "traffic-sign"
166 | learning_ignore: # Ignore classes
167 | 0: True # "unlabeled", and others ignored
168 | 1: False # "car"
169 | 2: False # "bicycle"
170 | 3: False # "motorcycle"
171 | 4: False # "truck"
172 | 5: False # "other-vehicle"
173 | 6: False # "person"
174 | 7: False # "bicyclist"
175 | 8: False # "motorcyclist"
176 | 9: False # "road"
177 | 10: False # "parking"
178 | 11: False # "sidewalk"
179 | 12: False # "other-ground"
180 | 13: False # "building"
181 | 14: False # "fence"
182 | 15: False # "vegetation"
183 | 16: False # "trunk"
184 | 17: False # "terrain"
185 | 18: False # "pole"
186 | 19: False # "traffic-sign"
187 | split: # sequence numbers
188 | train:
189 | - 0
190 | - 1
191 | - 2
192 | - 3
193 | - 4
194 | - 5
195 | - 6
196 | - 7
197 | - 9
198 | - 10
199 | valid:
200 | - 8
201 | test:
202 | - 11
203 | - 12
204 | - 13
205 | - 14
206 | - 15
207 | - 16
208 | - 17
209 | - 18
210 | - 19
211 | - 20
212 | - 21
213 |
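
To make the role of `learning_map` / `learning_map_inv` concrete, the sketch below (an addition, not part of the repository) builds a numpy look-up table from this file and remaps a few raw SemanticKITTI label ids to the training ids (0-19) defined above; the file path and loading code are illustrative only.

```python
# Minimal sketch: remap raw SemanticKITTI labels with learning_map as a LUT.
import numpy as np
import yaml

with open("semantic-kitti.yaml") as f:       # this configuration file
    cfg = yaml.safe_load(f)

learning_map = cfg["learning_map"]
lut = np.zeros(max(learning_map.keys()) + 1, dtype=np.int32)
for raw_id, train_id in learning_map.items():
    lut[raw_id] = train_id

# Note: .label files pack an instance id into the upper 16 bits; strip it first
# (label & 0xFFFF), exactly as laserscan.py does in set_label().
raw_labels = np.array([10, 252, 40, 60, 0])  # car, moving-car, road, lane-marking, unlabeled
print(lut[raw_labels])                       # [1 1 9 9 0]
```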
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/decoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/decoders/__init__.py
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/squeezeseg.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/squeezeseg.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/squeezesegV2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/decoders/__pycache__/squeezesegV2.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/decoders/darknet.py:
--------------------------------------------------------------------------------
1 | # This file was modified from https://github.com/BobLiu20/YOLOv3_PyTorch
2 | # It needed to be modified in order to accommodate different strides.
3 |
4 | import torch.nn as nn
5 | from collections import OrderedDict
6 | import torch.nn.functional as F
7 |
8 |
9 | class BasicBlock(nn.Module):
10 | def __init__(self, inplanes, planes, bn_d=0.1):
11 | super(BasicBlock, self).__init__()
12 | self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1,
13 | stride=1, padding=0, bias=False)
14 | self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d)
15 | self.relu1 = nn.LeakyReLU(0.1)
16 | self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3,
17 | stride=1, padding=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d)
19 | self.relu2 = nn.LeakyReLU(0.1)
20 |
21 | def forward(self, x):
22 | residual = x
23 |
24 | out = self.conv1(x)
25 | out = self.bn1(out)
26 | out = self.relu1(out)
27 |
28 | out = self.conv2(out)
29 | out = self.bn2(out)
30 | out = self.relu2(out)
31 |
32 | out += residual
33 | return out
34 |
35 |
36 | # ******************************************************************************
37 |
38 | class Decoder(nn.Module):
39 | """
40 | Class for DarknetSeg. Subclasses PyTorch's own "nn" module
41 | """
42 |
43 | def __init__(self, params, stub_skips, OS=32, feature_depth=1024):
44 | super(Decoder, self).__init__()
45 | self.backbone_OS = OS
46 | self.backbone_feature_depth = feature_depth
47 | self.drop_prob = params["dropout"]
48 | self.bn_d = params["bn_d"]
49 |
50 | # stride play
51 | self.strides = [2, 2, 2, 2, 2]
52 | # check current stride
53 | current_os = 1
54 | for s in self.strides:
55 | current_os *= s
56 | print("Decoder original OS: ", int(current_os))
57 | # redo strides according to needed stride
58 | for i, stride in enumerate(self.strides):
59 | if int(current_os) != self.backbone_OS:
60 | if stride == 2:
61 | current_os /= 2
62 | self.strides[i] = 1
63 | if int(current_os) == self.backbone_OS:
64 | break
65 | print("Decoder new OS: ", int(current_os))
66 | print("Decoder strides: ", self.strides)
67 |
68 | # decoder
69 | self.dec5 = self._make_dec_layer(BasicBlock,
70 | [self.backbone_feature_depth, 512],
71 | bn_d=self.bn_d,
72 | stride=self.strides[0])
73 | self.dec4 = self._make_dec_layer(BasicBlock, [512, 256], bn_d=self.bn_d,
74 | stride=self.strides[1])
75 | self.dec3 = self._make_dec_layer(BasicBlock, [256, 128], bn_d=self.bn_d,
76 | stride=self.strides[2])
77 | self.dec2 = self._make_dec_layer(BasicBlock, [128, 64], bn_d=self.bn_d,
78 | stride=self.strides[3])
79 | self.dec1 = self._make_dec_layer(BasicBlock, [64, 32], bn_d=self.bn_d,
80 | stride=self.strides[4])
81 |
82 | # layer list to execute with skips
83 | self.layers = [self.dec5, self.dec4, self.dec3, self.dec2, self.dec1]
84 |
85 | # for a bit of fun
86 | self.dropout = nn.Dropout2d(self.drop_prob)
87 |
88 | # last channels
89 | self.last_channels = 32
90 |
91 | def _make_dec_layer(self, block, planes, bn_d=0.1, stride=2):
92 | layers = []
93 |
94 | # downsample
95 | if stride == 2:
96 | layers.append(("upconv", nn.ConvTranspose2d(planes[0], planes[1],
97 | kernel_size=[1, 4], stride=[1, 2],
98 | padding=[0, 1])))
99 | else:
100 | layers.append(("conv", nn.Conv2d(planes[0], planes[1],
101 | kernel_size=3, padding=1)))
102 | layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d)))
103 | layers.append(("relu", nn.LeakyReLU(0.1)))
104 |
105 | # blocks
106 | layers.append(("residual", block(planes[1], planes, bn_d)))
107 |
108 | return nn.Sequential(OrderedDict(layers))
109 |
110 | def run_layer(self, x, layer, skips, os):
111 | feats = layer(x) # up
112 | if feats.shape[-1] > x.shape[-1]:
113 | os //= 2 # match skip
114 | feats = feats + skips[os].detach() # add skip
115 | x = feats
116 | return x, skips, os
117 |
118 | def forward(self, x, skips):
119 | os = self.backbone_OS
120 |
121 | # run layers
122 | x, skips, os = self.run_layer(x, self.dec5, skips, os)
123 | x, skips, os = self.run_layer(x, self.dec4, skips, os)
124 | x, skips, os = self.run_layer(x, self.dec3, skips, os)
125 | x, skips, os = self.run_layer(x, self.dec2, skips, os)
126 | x, skips, os = self.run_layer(x, self.dec1, skips, os)
127 |
128 | x = self.dropout(x)
129 |
130 | return x
131 |
132 | def get_last_depth(self):
133 | return self.last_channels
134 |
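A minimal usage sketch for the decoder above (not part of the repository): it shows how run_layer consumes a skip dictionary keyed by output stride, adding the matching skip after each upsampling step. The import path, parameter values, and tensor shapes below are illustrative assumptions and require the repository root on PYTHONPATH.

import torch
from semantic_net.rangenet.tasks.semantic.decoders.darknet import Decoder

# hypothetical decoder config; the real values come from the ARCH yaml
params = {"dropout": 0.01, "bn_d": 0.01}
decoder = Decoder(params, stub_skips=None, OS=32, feature_depth=1024)

# skip features keyed by output stride (os), as run_layer() looks them up
skips = {16: torch.zeros(1, 512, 64, 128),
         8: torch.zeros(1, 256, 64, 256),
         4: torch.zeros(1, 128, 64, 512),
         2: torch.zeros(1, 64, 64, 1024),
         1: torch.zeros(1, 32, 64, 2048)}

x = torch.zeros(1, 1024, 64, 64)   # bottleneck features of a 64 x 2048 range image at OS=32
out = decoder(x, skips)
print(out.shape)                   # expected: torch.Size([1, 32, 64, 2048])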
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/decoders/squeezeseg.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/BichenWuUCB/SqueezeSeg
2 | from __future__ import print_function
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class FireUp(nn.Module):
9 |
10 | def __init__(self, inplanes, squeeze_planes,
11 | expand1x1_planes, expand3x3_planes, stride):
12 | super(FireUp, self).__init__()
13 | self.inplanes = inplanes
14 | self.stride = stride
15 | self.activation = nn.ReLU(inplace=True)
16 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
17 | if self.stride == 2:
18 | self.upconv = nn.ConvTranspose2d(squeeze_planes, squeeze_planes,
19 | kernel_size=[1, 4], stride=[1, 2],
20 | padding=[0, 1])
21 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
22 | kernel_size=1)
23 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
24 | kernel_size=3, padding=1)
25 |
26 | def forward(self, x):
27 | x = self.activation(self.squeeze(x))
28 | if self.stride == 2:
29 | x = self.activation(self.upconv(x))
30 | return torch.cat([
31 | self.activation(self.expand1x1(x)),
32 | self.activation(self.expand3x3(x))
33 | ], 1)
34 |
35 |
36 | # ******************************************************************************
37 |
38 | class Decoder(nn.Module):
39 | """
40 |     Class for SqueezeSeg decoder. Subclasses PyTorch's own "nn" module
41 | """
42 |
43 | def __init__(self, params, stub_skips, OS=32, feature_depth=512):
44 | super(Decoder, self).__init__()
45 | self.backbone_OS = OS
46 | self.backbone_feature_depth = feature_depth
47 | self.drop_prob = params["dropout"]
48 |
49 | # stride play
50 | self.strides = [2, 2, 2, 2]
51 | # check current stride
52 | current_os = 1
53 | for s in self.strides:
54 | current_os *= s
55 | print("Decoder original OS: ", int(current_os))
56 | # redo strides according to needed stride
57 | for i, stride in enumerate(self.strides):
58 | if int(current_os) != self.backbone_OS:
59 | if stride == 2:
60 | current_os /= 2
61 | self.strides[i] = 1
62 | if int(current_os) == self.backbone_OS:
63 | break
64 | print("Decoder new OS: ", int(current_os))
65 | print("Decoder strides: ", self.strides)
66 |
67 | # decoder
68 |
69 | self.firedec10 = FireUp(self.backbone_feature_depth, 64, 128, 128,
70 | stride=self.strides[0])
71 | self.firedec11 = FireUp(256, 32, 64, 64,
72 | stride=self.strides[1])
73 | self.firedec12 = FireUp(128, 16, 32, 32,
74 | stride=self.strides[2])
75 | self.firedec13 = FireUp(64, 16, 32, 32,
76 | stride=self.strides[3])
77 |
78 | # layer list to execute with skips
79 | self.layers = [self.firedec10, self.firedec11,
80 | self.firedec12, self.firedec13]
81 |
82 | # for a bit of fun
83 | self.dropout = nn.Dropout2d(self.drop_prob)
84 |
85 | # last channels
86 | self.last_channels = 64
87 |
88 | def run_layer(self, x, layer, skips, os):
89 | feats = layer(x) # up
90 | if feats.shape[-1] > x.shape[-1]:
91 | os //= 2 # match skip
92 | feats = feats + skips[os].detach() # add skip
93 | x = feats
94 | return x, skips, os
95 |
96 | def forward(self, x, skips):
97 | os = self.backbone_OS
98 |
99 | # run layers
100 | x, skips, os = self.run_layer(x, self.firedec10, skips, os)
101 | x, skips, os = self.run_layer(x, self.firedec11, skips, os)
102 | x, skips, os = self.run_layer(x, self.firedec12, skips, os)
103 | x, skips, os = self.run_layer(x, self.firedec13, skips, os)
104 |
105 | x = self.dropout(x)
106 |
107 | return x
108 |
109 | def get_last_depth(self):
110 | return self.last_channels
111 |
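A small sanity-check sketch for FireUp above (assumes the repository root is importable): its output is the channel-wise concatenation of the expand1x1 and expand3x3 branches, and stride 2 doubles the width via the transposed convolution. Values are illustrative.

import torch
from semantic_net.rangenet.tasks.semantic.decoders.squeezeseg import FireUp

# mirrors firedec10 above: squeeze 512 -> 64, expand to 128 + 128 = 256 channels
fire = FireUp(inplanes=512, squeeze_planes=64,
              expand1x1_planes=128, expand3x3_planes=128, stride=2)
y = fire(torch.zeros(1, 512, 64, 64))
print(y.shape)   # torch.Size([1, 256, 64, 128]): concatenated branches, width doubled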
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/decoders/squeezesegV2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # This file is covered by the LICENSE file in the root of this project.
3 |
4 | from __future__ import print_function
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class FireUp(nn.Module):
11 |
12 | def __init__(self, inplanes, squeeze_planes,
13 | expand1x1_planes, expand3x3_planes, bn_d, stride):
14 | super(FireUp, self).__init__()
15 | self.inplanes = inplanes
16 | self.bn_d = bn_d
17 | self.stride = stride
18 | self.activation = nn.ReLU(inplace=True)
19 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
20 | self.squeeze_bn = nn.BatchNorm2d(squeeze_planes, momentum=self.bn_d)
21 | if self.stride == 2:
22 | self.upconv = nn.ConvTranspose2d(squeeze_planes, squeeze_planes,
23 | kernel_size=[1, 4], stride=[1, 2],
24 | padding=[0, 1])
25 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
26 | kernel_size=1)
27 | self.expand1x1_bn = nn.BatchNorm2d(expand1x1_planes, momentum=self.bn_d)
28 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
29 | kernel_size=3, padding=1)
30 | self.expand3x3_bn = nn.BatchNorm2d(expand3x3_planes, momentum=self.bn_d)
31 |
32 | def forward(self, x):
33 | x = self.activation(self.squeeze_bn(self.squeeze(x)))
34 | if self.stride == 2:
35 | x = self.activation(self.upconv(x))
36 | return torch.cat([
37 | self.activation(self.expand1x1_bn(self.expand1x1(x))),
38 | self.activation(self.expand3x3_bn(self.expand3x3(x)))
39 | ], 1)
40 |
41 |
42 | # ******************************************************************************
43 |
44 | class Decoder(nn.Module):
45 | """
46 |     Class for SqueezeSegV2 decoder. Subclasses PyTorch's own "nn" module
47 | """
48 |
49 | def __init__(self, params, stub_skips, OS=32, feature_depth=512):
50 | super(Decoder, self).__init__()
51 | self.backbone_OS = OS
52 | self.backbone_feature_depth = feature_depth
53 | self.drop_prob = params["dropout"]
54 | self.bn_d = params["bn_d"]
55 |
56 | # stride play
57 | self.strides = [2, 2, 2, 2]
58 | # check current stride
59 | current_os = 1
60 | for s in self.strides:
61 | current_os *= s
62 | print("Decoder original OS: ", int(current_os))
63 | # redo strides according to needed stride
64 | for i, stride in enumerate(self.strides):
65 | if int(current_os) != self.backbone_OS:
66 | if stride == 2:
67 | current_os /= 2
68 | self.strides[i] = 1
69 | if int(current_os) == self.backbone_OS:
70 | break
71 | print("Decoder new OS: ", int(current_os))
72 | print("Decoder strides: ", self.strides)
73 |
74 | # decoder
75 |
76 | self.firedec10 = FireUp(self.backbone_feature_depth,
77 | 64, 128, 128, bn_d=self.bn_d,
78 | stride=self.strides[0])
79 | self.firedec11 = FireUp(256, 32, 64, 64, bn_d=self.bn_d,
80 | stride=self.strides[1])
81 | self.firedec12 = FireUp(128, 16, 32, 32, bn_d=self.bn_d,
82 | stride=self.strides[2])
83 | self.firedec13 = FireUp(64, 16, 32, 32, bn_d=self.bn_d,
84 | stride=self.strides[3])
85 |
86 | # for a bit of fun
87 | self.dropout = nn.Dropout2d(self.drop_prob)
88 |
89 | # last channels
90 | self.last_channels = 64
91 |
92 | def run_layer(self, x, layer, skips, os):
93 | feats = layer(x) # up
94 | if feats.shape[-1] > x.shape[-1]:
95 | os //= 2 # match skip
96 | feats = feats + skips[os].detach() # add skip
97 | x = feats
98 | return x, skips, os
99 |
100 | def forward(self, x, skips):
101 | os = self.backbone_OS
102 |
103 | # run layers
104 | x, skips, os = self.run_layer(x, self.firedec10, skips, os)
105 | x, skips, os = self.run_layer(x, self.firedec11, skips, os)
106 | x, skips, os = self.run_layer(x, self.firedec12, skips, os)
107 | x, skips, os = self.run_layer(x, self.firedec13, skips, os)
108 |
109 | x = self.dropout(x)
110 |
111 | return x
112 |
113 | def get_last_depth(self):
114 | return self.last_channels
115 |
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/modules/__init__.py
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/modules/__pycache__/segmentator.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/modules/__pycache__/segmentator.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/modules/segmentator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # This file is covered by the LICENSE file in the root of this project.
3 |
4 | # import imp  # unused: module loading is done with importlib below
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from semantic_net.rangenet.tasks.semantic.postproc.CRF import CRF
9 | import importlib
10 | # import __init__ as booger
11 |
12 |
13 | class Segmentator(nn.Module):
14 | def __init__(self, ARCH, nclasses, path=None, path_append="", strict=False):
15 | super().__init__()
16 | self.ARCH = ARCH
17 | self.nclasses = nclasses
18 | self.path = path
19 | self.path_append = path_append
20 | self.strict = False
21 |
22 | # get the model
23 | # bboneModule = imp.load_source("bboneModule",
24 | # booger.TRAIN_PATH + '/backbones/' +
25 | # self.ARCH["backbone"]["name"] + '.py')
26 | backbone_path = "semantic_net.rangenet.backbones." + self.ARCH["backbone"]["name"]
27 | bboneModule = importlib.import_module(backbone_path)
28 | self.backbone = bboneModule.Backbone(params=self.ARCH["backbone"])
29 |
30 | # do a pass of the backbone to initialize the skip connections
31 | stub = torch.zeros((1,
32 | self.backbone.get_input_depth(),
33 | self.ARCH["dataset"]["sensor"]["img_prop"]["height"],
34 | self.ARCH["dataset"]["sensor"]["img_prop"]["width"]))
35 |
36 | if torch.cuda.is_available():
37 | stub = stub.cuda()
38 | self.backbone.cuda()
39 | _, stub_skips = self.backbone(stub)
40 |
41 | # decoderModule = imp.load_source("decoderModule",
42 | # booger.TRAIN_PATH + '/tasks/semantic/decoders/' +
43 | # self.ARCH["decoder"]["name"] + '.py')
44 | decoder_path = "semantic_net.rangenet.tasks.semantic.decoders." + self.ARCH["decoder"]["name"]
45 | decoderModule = importlib.import_module(decoder_path)
46 | self.decoder = decoderModule.Decoder(params=self.ARCH["decoder"],
47 | stub_skips=stub_skips,
48 | OS=self.ARCH["backbone"]["OS"],
49 | feature_depth=self.backbone.get_last_depth())
50 |
51 | self.head = nn.Sequential(nn.Dropout2d(p=ARCH["head"]["dropout"]),
52 | nn.Conv2d(self.decoder.get_last_depth(),
53 | self.nclasses, kernel_size=3,
54 | stride=1, padding=1))
55 |
56 | if self.ARCH["post"]["CRF"]["use"]:
57 | self.CRF = CRF(self.ARCH["post"]["CRF"]["params"], self.nclasses)
58 | else:
59 | self.CRF = None
60 |
61 |     ### Do not train any parameter in RangeNet++ (by Zhen Luo)
62 | # train backbone?
63 | # if not self.ARCH["backbone"]["train"]:
64 | for w in self.backbone.parameters():
65 | w.requires_grad = False
66 |
67 | # train decoder?
68 | # if not self.ARCH["decoder"]["train"]:
69 | for w in self.decoder.parameters():
70 | w.requires_grad = False
71 |
72 | # train head?
73 | # if not self.ARCH["head"]["train"]:
74 | for w in self.head.parameters():
75 | w.requires_grad = False
76 |
77 | # train CRF?
78 | # if self.CRF and not self.ARCH["post"]["CRF"]["train"]:
79 | if self.CRF:
80 | for w in self.CRF.parameters():
81 | w.requires_grad = False
82 |
83 | # print number of parameters and the ones requiring gradients
84 |
85 | weights_total = sum(p.numel() for p in self.parameters())
86 | weights_grad = sum(p.numel() for p in self.parameters() if p.requires_grad)
87 | print("Total number of parameters: ", weights_total)
88 | print("Total number of parameters requires_grad: ", weights_grad)
89 |
90 | # breakdown by layer
91 | weights_enc = sum(p.numel() for p in self.backbone.parameters())
92 | weights_dec = sum(p.numel() for p in self.decoder.parameters())
93 | weights_head = sum(p.numel() for p in self.head.parameters())
94 | print("Param encoder ", weights_enc)
95 | print("Param decoder ", weights_dec)
96 | print("Param head ", weights_head)
97 | if self.CRF:
98 | weights_crf = sum(p.numel() for p in self.CRF.parameters())
99 | print("Param CRF ", weights_crf)
100 |
101 | # get weights
102 | if path is not None:
103 | # try backbone
104 | try:
105 | w_dict = torch.load(path + "/backbone" + path_append,
106 | map_location=lambda storage, loc: storage)
107 | self.backbone.load_state_dict(w_dict, strict=True)
108 | print("Successfully loaded model backbone weights")
109 | except Exception as e:
110 | print()
111 | print("Couldn't load backbone, using random weights. Error: ", e)
112 | if strict:
113 | print("I'm in strict mode and failure to load weights blows me up :)")
114 | raise e
115 |
116 | # try decoder
117 | try:
118 | w_dict = torch.load(path + "/segmentation_decoder" + path_append,
119 | map_location=lambda storage, loc: storage)
120 | self.decoder.load_state_dict(w_dict, strict=True)
121 | print("Successfully loaded model decoder weights")
122 | except Exception as e:
123 | print("Couldn't load decoder, using random weights. Error: ", e)
124 | if strict:
125 | print("I'm in strict mode and failure to load weights blows me up :)")
126 | raise e
127 |
128 | # try head
129 | try:
130 | w_dict = torch.load(path + "/segmentation_head" + path_append,
131 | map_location=lambda storage, loc: storage)
132 | self.head.load_state_dict(w_dict, strict=True)
133 | print("Successfully loaded model head weights")
134 | except Exception as e:
135 | print("Couldn't load head, using random weights. Error: ", e)
136 | if strict:
137 | print("I'm in strict mode and failure to load weights blows me up :)")
138 | raise e
139 |
140 | # try CRF
141 | if self.CRF:
142 | try:
143 | w_dict = torch.load(path + "/segmentation_CRF" + path_append,
144 | map_location=lambda storage, loc: storage)
145 | self.CRF.load_state_dict(w_dict, strict=True)
146 | print("Successfully loaded model CRF weights")
147 | except Exception as e:
148 | print("Couldn't load CRF, using random weights. Error: ", e)
149 | if strict:
150 | print("I'm in strict mode and failure to load weights blows me up :)")
151 | raise e
152 | else:
153 | print("No path to pretrained, using random init.")
154 |
155 | def forward(self, x, mask=None):
156 | y, skips = self.backbone(x)
157 | y = self.decoder(y, skips)
158 | y = self.head(y)
159 | y = F.softmax(y, dim=1)
160 | if self.CRF:
161 | assert(mask is not None)
162 | y = self.CRF(x, y, mask)
163 | return y
164 |
165 | def save_checkpoint(self, logdir, suffix=""):
166 | # Save the weights
167 | torch.save(self.backbone.state_dict(), logdir +
168 | "/backbone" + suffix)
169 | torch.save(self.decoder.state_dict(), logdir +
170 | "/segmentation_decoder" + suffix)
171 | torch.save(self.head.state_dict(), logdir +
172 | "/segmentation_head" + suffix)
173 | if self.CRF:
174 | torch.save(self.CRF.state_dict(), logdir +
175 | "/segmentation_CRF" + suffix)
176 |
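The loops over parameters() above are what keep RangeNet++ frozen inside PCPNet: every backbone, decoder, head, and CRF weight is excluded from gradient updates. A standalone sketch of the same pattern (the Conv2d is a stand-in, not the real network):

import torch.nn as nn

module = nn.Conv2d(5, 32, kernel_size=3, padding=1)   # stand-in for backbone/decoder/head
for w in module.parameters():
    w.requires_grad = False                           # exclude from gradient updates

trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
print("Trainable parameters:", trainable)             # 0: the module is frozen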
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/postproc/CRF.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # This file is covered by the LICENSE file in the root of this project.
3 |
4 | import numpy as np
5 | from scipy import signal
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | # import __init__ as booger
10 |
11 |
12 | class LocallyConnectedXYZLayer(nn.Module):
13 | def __init__(self, h, w, sigma, nclasses):
14 | super().__init__()
15 | # size of window
16 | self.h = h
17 | self.padh = h//2
18 | self.w = w
19 | self.padw = w//2
20 | assert(self.h % 2 == 1 and self.w % 2 == 1) # window must be odd
21 | self.sigma = sigma
22 | self.gauss_den = 2 * self.sigma**2
23 | self.nclasses = nclasses
24 |
25 | def forward(self, xyz, softmax, mask):
26 | # softmax size
27 | N, C, H, W = softmax.shape
28 |
29 |     # make softmax zero everywhere input is invalid
30 | softmax = softmax * mask.unsqueeze(1).float()
31 |
32 | # get x,y,z for distance (shape N,1,H,W)
33 | x = xyz[:, 0].unsqueeze(1)
34 | y = xyz[:, 1].unsqueeze(1)
35 | z = xyz[:, 2].unsqueeze(1)
36 |
37 | # im2col in size of window of input (x,y,z separately)
38 | window_x = F.unfold(x, kernel_size=(self.h, self.w),
39 | padding=(self.padh, self.padw))
40 | center_x = F.unfold(x, kernel_size=(1, 1),
41 | padding=(0, 0))
42 | window_y = F.unfold(y, kernel_size=(self.h, self.w),
43 | padding=(self.padh, self.padw))
44 | center_y = F.unfold(y, kernel_size=(1, 1),
45 | padding=(0, 0))
46 | window_z = F.unfold(z, kernel_size=(self.h, self.w),
47 | padding=(self.padh, self.padw))
48 | center_z = F.unfold(z, kernel_size=(1, 1),
49 | padding=(0, 0))
50 |
51 | # sq distance to center (center distance is zero)
52 | unravel_dist2 = (window_x - center_x)**2 + \
53 | (window_y - center_y)**2 + \
54 | (window_z - center_z)**2
55 |
56 | # weight input distance by gaussian weights
57 | unravel_gaussian = torch.exp(- unravel_dist2 / self.gauss_den)
58 |
59 | # im2col in size of window of softmax to reweight by gaussian weights from input
60 | cloned_softmax = softmax.clone()
61 | for i in range(self.nclasses):
62 | # get the softmax for this class
63 | c_softmax = softmax[:, i].unsqueeze(1)
64 | # unfold this class to weigh it by the proper gaussian weights
65 | unravel_softmax = F.unfold(c_softmax,
66 | kernel_size=(self.h, self.w),
67 | padding=(self.padh, self.padw))
68 | unravel_w_softmax = unravel_softmax * unravel_gaussian
69 |       # add dimension 1 to obtain the new softmax for this class
70 | unravel_added_softmax = unravel_w_softmax.sum(dim=1).unsqueeze(1)
71 | # fold it and put it in new tensor
72 | added_softmax = unravel_added_softmax.view(N, H, W)
73 | cloned_softmax[:, i] = added_softmax
74 |
75 | return cloned_softmax
76 |
77 |
78 | class CRF(nn.Module):
79 | def __init__(self, params, nclasses):
80 | super().__init__()
81 | self.params = params
82 | self.iter = torch.nn.Parameter(torch.tensor(params["iter"]),
83 | requires_grad=False)
84 | self.lcn_size = torch.nn.Parameter(torch.tensor([params["lcn_size"]["h"],
85 | params["lcn_size"]["w"]]),
86 | requires_grad=False)
87 | self.xyz_coef = torch.nn.Parameter(torch.tensor(params["xyz_coef"]),
88 | requires_grad=False).float()
89 | self.xyz_sigma = torch.nn.Parameter(torch.tensor(params["xyz_sigma"]),
90 | requires_grad=False).float()
91 |
92 | self.nclasses = nclasses
93 | print("Using CRF!")
94 |
95 | # define layers here
96 | # compat init
97 | self.compat_kernel_init = np.reshape(np.ones((self.nclasses, self.nclasses)) -
98 | np.identity(self.nclasses),
99 | [self.nclasses, self.nclasses, 1, 1])
100 |
101 |     # bilateral compatibility matrices
102 | self.compat_conv = nn.Conv2d(self.nclasses, self.nclasses, 1)
103 | self.compat_conv.weight = torch.nn.Parameter(torch.from_numpy(
104 | self.compat_kernel_init).float() * self.xyz_coef, requires_grad=True)
105 |
106 | # locally connected layer for message passing
107 | self.local_conn_xyz = LocallyConnectedXYZLayer(params["lcn_size"]["h"],
108 | params["lcn_size"]["w"],
109 | params["xyz_coef"],
110 | self.nclasses)
111 |
112 | def forward(self, input, softmax, mask):
113 | # use xyz
114 | xyz = input[:, 1:4]
115 |
116 | # iteratively
117 | for iter in range(self.iter):
118 | # message passing as locally connected layer
119 | locally_connected = self.local_conn_xyz(xyz, softmax, mask)
120 |
121 | # reweigh with the 1x1 convolution
122 | reweight_softmax = self.compat_conv(locally_connected)
123 |
124 | # add the new values to the original softmax
125 | reweight_softmax = reweight_softmax + softmax
126 |
127 | # lastly, renormalize
128 | softmax = F.softmax(reweight_softmax, dim=1)
129 |
130 | return softmax
131 |
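The locally connected layer above weights each neighbor's softmax by a Gaussian of its squared 3D distance to the window center, exp(-d^2 / (2 * sigma^2)). A standalone sketch of that weighting with illustrative values:

import torch

sigma = 0.7                               # stands in for the layer's sigma
gauss_den = 2 * sigma ** 2                # same denominator as in __init__ above
d2 = torch.tensor([0.0, 0.25, 1.0, 4.0])  # squared distances to the window center
w = torch.exp(-d2 / gauss_den)
print(w)                                  # the center gets weight 1, far points decay toward 0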
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/postproc/KNN.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # This file is covered by the LICENSE file in the root of this project.
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | # import __init__ as booger
9 |
10 |
11 | def get_gaussian_kernel(kernel_size=3, sigma=2, channels=1):
12 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
13 | x_coord = torch.arange(kernel_size)
14 | x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
15 | y_grid = x_grid.t()
16 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
17 |
18 | mean = (kernel_size - 1) / 2.
19 | variance = sigma**2.
20 |
21 | # Calculate the 2-dimensional gaussian kernel which is
22 | # the product of two gaussian distributions for two different
23 | # variables (in this case called x and y)
24 | gaussian_kernel = (1. / (2. * math.pi * variance)) *\
25 | torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / (2 * variance))
26 |
27 | # Make sure sum of values in gaussian kernel equals 1.
28 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
29 |
30 | # Reshape to 2d depthwise convolutional weight
31 | gaussian_kernel = gaussian_kernel.view(kernel_size, kernel_size)
32 |
33 | return gaussian_kernel
34 |
35 |
36 | class KNN(nn.Module):
37 | def __init__(self, params, nclasses):
38 | super().__init__()
39 | print("*"*80)
40 | print("Cleaning point-clouds with kNN post-processing")
41 | self.knn = params["knn"]
42 | self.search = params["search"]
43 | self.sigma = params["sigma"]
44 | self.cutoff = params["cutoff"]
45 | self.nclasses = nclasses
46 | print("kNN parameters:")
47 | print("knn:", self.knn)
48 | print("search:", self.search)
49 | print("sigma:", self.sigma)
50 | print("cutoff:", self.cutoff)
51 | print("nclasses:", self.nclasses)
52 | print("*"*80)
53 |
54 | def forward(self, proj_range, unproj_range, proj_argmax, px, py):
55 | ''' Warning! Only works for un-batched pointclouds.
56 | If they come batched we need to iterate over the batch dimension or do
57 | something REALLY smart to handle unaligned number of points in memory
58 | '''
59 | # get device
60 | if proj_range.is_cuda:
61 | device = torch.device("cuda")
62 | else:
63 | device = torch.device("cpu")
64 |
65 | # sizes of projection scan
66 | H, W = proj_range.shape
67 |
68 | # number of points
69 | P = unproj_range.shape
70 |
71 | # check if size of kernel is odd and complain
72 | if (self.search % 2 == 0):
73 |       raise ValueError("Nearest neighbor kernel must be an odd number")
74 |
75 | # calculate padding
76 | pad = int((self.search - 1) / 2)
77 |
78 | # unfold neighborhood to get nearest neighbors for each pixel (range image)
79 | proj_unfold_k_rang = F.unfold(proj_range[None, None, ...],
80 | kernel_size=(self.search, self.search),
81 | padding=(pad, pad))
82 |
83 | # index with px, py to get ALL the pcld points
84 | idx_list = py * W + px
85 | unproj_unfold_k_rang = proj_unfold_k_rang[:, :, idx_list]
86 |
87 | # WARNING, THIS IS A HACK
88 |     # Make invalid (<0) range points extremely large so that they cannot
89 |     # interfere with the nearest-neighbor search
90 | unproj_unfold_k_rang[unproj_unfold_k_rang < 0] = float("inf")
91 |
92 | # now the matrix is unfolded TOTALLY, replace the middle points with the actual range points
93 | center = int(((self.search * self.search) - 1) / 2)
94 | unproj_unfold_k_rang[:, center, :] = unproj_range
95 |
96 | # now compare range
97 | k2_distances = torch.abs(unproj_unfold_k_rang - unproj_range)
98 |
99 | # make a kernel to weigh the ranges according to distance in (x,y)
100 | # I make this 1 - kernel because I want distances that are close in (x,y)
101 | # to matter more
102 | inv_gauss_k = (
103 | 1 - get_gaussian_kernel(self.search, self.sigma, 1)).view(1, -1, 1)
104 | inv_gauss_k = inv_gauss_k.to(device).type(proj_range.type())
105 |
106 | # apply weighing
107 | k2_distances = k2_distances * inv_gauss_k
108 |
109 | # find nearest neighbors
110 | _, knn_idx = k2_distances.topk(
111 | self.knn, dim=1, largest=False, sorted=False)
112 |
113 | # do the same unfolding with the argmax
114 | proj_unfold_1_argmax = F.unfold(proj_argmax[None, None, ...].float(),
115 | kernel_size=(self.search, self.search),
116 | padding=(pad, pad)).long()
117 | unproj_unfold_1_argmax = proj_unfold_1_argmax[:, :, idx_list]
118 |
119 | # get the top k predictions from the knn at each pixel
120 | knn_argmax = torch.gather(
121 | input=unproj_unfold_1_argmax, dim=1, index=knn_idx)
122 |
123 | # fake an invalid argmax of classes + 1 for all cutoff items
124 | if self.cutoff > 0:
125 | knn_distances = torch.gather(input=k2_distances, dim=1, index=knn_idx)
126 | knn_invalid_idx = knn_distances > self.cutoff
127 | knn_argmax[knn_invalid_idx] = self.nclasses
128 |
129 | # now vote
130 | # argmax onehot has an extra class for objects after cutoff
131 | knn_argmax_onehot = torch.zeros(
132 | (1, self.nclasses + 1, P[0]), device=device).type(proj_range.type())
133 | ones = torch.ones_like(knn_argmax).type(proj_range.type())
134 | knn_argmax_onehot = knn_argmax_onehot.scatter_add_(1, knn_argmax, ones)
135 |
136 |     # now vote (as a sum over the one-hot votes); don't let it choose unlabeled OR invalid
137 | knn_argmax_out = knn_argmax_onehot[:, 1:-1].argmax(dim=1) + 1
138 |
139 | # reshape again
140 | knn_argmax_out = knn_argmax_out.view(P)
141 |
142 | return knn_argmax_out
143 |
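The (1 - Gaussian) weighting above makes range differences from spatially distant neighbors count more, so topk(largest=False) prefers neighbors that are close in image space. A standalone sketch of that effect with illustrative numbers:

import torch

# one row of a normalized Gaussian kernel: the center pixel has the largest weight
gauss = torch.tensor([0.05, 0.24, 0.42, 0.24, 0.05])
inv_gauss = 1.0 - gauss                                  # center gets the smallest multiplier

range_diff = torch.full((5,), 0.3)                       # equal raw range differences
weighted = range_diff * inv_gauss
print(weighted.argmin())                                 # tensor(2): the spatial center wins the kNN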
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/postproc/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | TRAIN_PATH = "../"
3 | DEPLOY_PATH = "../../deploy"
4 | sys.path.insert(0, TRAIN_PATH)
5 |
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/postproc/__pycache__/CRF.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/postproc/__pycache__/CRF.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/postproc/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/postproc/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/semantic_net/rangenet/tasks/semantic/readme.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blurryface0814/PCPNet/005c6ed0f55d1ba12290bc022d5ada36d4824ba7/semantic_net/rangenet/tasks/semantic/readme.md
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Test script for range-image-based point cloud prediction
6 | import os
7 | import time
8 | import argparse
9 | import random
10 | import yaml
11 | from pytorch_lightning import Trainer
12 | from pytorch_lightning import loggers as pl_loggers
13 |
14 | import pcpnet.datasets.datasets as datasets
15 | import pcpnet.models.PCPNet as PCPNet
16 | from pcpnet.utils.utils import set_seed
17 |
18 | if __name__ == "__main__":
19 | parser = argparse.ArgumentParser("./test.py")
20 | parser.add_argument(
21 | "--model",
22 | "-m",
23 | type=str,
24 | default=None,
25 | help="Model to be tested"
26 | )
27 | parser.add_argument(
28 | "--limit_test_batches",
29 | "-l",
30 | type=float,
31 | default=1.0,
32 | help="Percentage of test data to be tested",
33 | )
34 | parser.add_argument(
35 | "-d",
36 | "--dataset",
37 | type=str,
38 | default=None,
39 | required=True,
40 | help="The path of processed KITTI odometry and SemanticKITTI dataset",
41 | )
42 | parser.add_argument(
43 | "--save",
44 | "-s",
45 | action="store_true",
46 | help="Save point clouds"
47 | )
48 | parser.add_argument(
49 | "--only_save_pc",
50 | "-o",
51 | action="store_true",
52 | help="Only save point clouds, not compute loss"
53 | )
54 | parser.add_argument(
55 | "--cd_downsample",
56 | type=int,
57 | default=-1,
58 | help="Number of downsampled points for evaluating Chamfer Distance",
59 | )
60 | parser.add_argument("--path", "-p", type=str, default=None, help="Path to data")
61 | parser.add_argument(
62 | "-seq",
63 | "--sequence",
64 | type=int,
65 | nargs="+",
66 | default=None,
67 | help="Sequence to be tested",
68 | )
69 |
70 | args, unparsed = parser.parse_known_args()
71 | dataset_path = args.dataset
72 | if dataset_path:
73 | pass
74 | else:
75 |         raise Exception("Please enter the path of the dataset")
76 |
77 | # load config file
78 | config_filename = os.path.dirname(os.path.dirname(os.path.dirname(args.model))) + "/hparams.yaml"
79 | cfg = yaml.safe_load(open(config_filename))
80 | print("Starting testing model ", cfg["LOG_NAME"])
81 | """Manually set these"""
82 | cfg["DATA_CONFIG"]["COMPUTE_MEAN_AND_STD"] = False
83 | cfg["DATA_CONFIG"]["GENERATE_FILES"] = False
84 |
85 | if args.only_save_pc:
86 | cfg["TEST"]["ONLY_SAVE_POINT_CLOUDS"] = args.only_save_pc
87 | print("Only save point clouds")
88 | else:
89 | cfg["TEST"]["SAVE_POINT_CLOUDS"] = args.save
90 |
91 | cfg["TEST"]["N_DOWNSAMPLED_POINTS_CD"] = args.cd_downsample
92 | print("Evaluating CD on ", cfg["TEST"]["N_DOWNSAMPLED_POINTS_CD"], " points.")
93 |
94 | if args.sequence:
95 | cfg["DATA_CONFIG"]["SPLIT"]["TEST"] = args.sequence
96 | cfg["DATA_CONFIG"]["SPLIT"]["TRAIN"] = args.sequence
97 | cfg["DATA_CONFIG"]["SPLIT"]["VAL"] = args.sequence
98 |
99 | ###### Set random seed for torch, numpy and python
100 | set_seed(cfg["DATA_CONFIG"]["RANDOM_SEED"])
101 | print("Random seed is ", cfg["DATA_CONFIG"]["RANDOM_SEED"])
102 |
103 | data = datasets.KittiOdometryModule(cfg, dataset_path)
104 | data.setup()
105 |
106 | checkpoint_path = args.model
107 | cfg["TEST"]["USED_CHECKPOINT"] = checkpoint_path
108 | test_dir_name = "test_" + time.strftime("%Y%m%d_%H%M%S")
109 | cfg["TEST"]["DIR_NAME"] = test_dir_name
110 |
111 | model = PCPNet.PCPNet.load_from_checkpoint(checkpoint_path, cfg=cfg)
112 |
113 | # Only log if test is done on full data
114 | if args.limit_test_batches == 1.0 and not args.only_save_pc:
115 | logger = pl_loggers.TensorBoardLogger(
116 | save_dir=cfg["LOG_DIR"],
117 | default_hp_metric=False,
118 | name=test_dir_name,
119 | version=""
120 | )
121 | else:
122 | logger = False
123 |
124 | trainer = Trainer(
125 | limit_test_batches=args.limit_test_batches,
126 | gpus=cfg["TRAIN"]["N_GPUS"],
127 | logger=logger,
128 | )
129 |
130 | results = trainer.test(model, data.test_dataloader())
131 |
132 | if logger:
133 | filename = os.path.join(
134 | cfg["LOG_DIR"], cfg["TEST"]["DIR_NAME"], "results" + ".yaml"
135 | )
136 | log_to_save = {**{"results": results}, **vars(args), **cfg}
137 | with open(filename, "w") as yaml_file:
138 | yaml.dump(log_to_save, yaml_file, default_flow_style=False)
139 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Train script for range-image-based point cloud prediction
6 | import os
7 | import time
8 | import argparse
9 | import yaml
10 | from pytorch_lightning import Trainer
11 | from pytorch_lightning.loggers import TensorBoardLogger
12 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
13 |
14 | from pcpnet.datasets.datasets import KittiOdometryModule
15 | from pcpnet.models.PCPNet import PCPNet
16 | from pcpnet.utils.utils import set_seed
17 |
18 | if __name__ == "__main__":
19 | parser = argparse.ArgumentParser("./train.py")
20 | parser.add_argument(
21 | "--comment", "-c", type=str, default="", help="Add a comment to the LOG ID."
22 | )
23 | parser.add_argument(
24 | "-res",
25 | "--resume",
26 | type=str,
27 | default=None,
28 | help="Resume training from specified model.",
29 | )
30 | parser.add_argument(
31 | "-w",
32 | "--weights",
33 | type=str,
34 | default=None,
35 | help="Init model with weights from specified model",
36 | )
37 | parser.add_argument(
38 | "-d",
39 | "--dataset",
40 | type=str,
41 | default=None,
42 | required=True,
43 | help="The path of processed KITTI odometry and SemanticKITTI dataset",
44 | )
45 | parser.add_argument(
46 | "-raw",
47 | "--rawdata",
48 | type=str,
49 | default=None,
50 | help="The path of raw KITTI odometry and SemanticKITTI dataset",
51 | )
52 | parser.add_argument(
53 | "-r",
54 | "--range",
55 | type=float,
56 | default=None,
57 | help="Change weight of range image loss.",
58 | )
59 | parser.add_argument(
60 | "-m", "--mask", type=float, default=None, help="Change weight of mask loss."
61 | )
62 | parser.add_argument(
63 | "-cd",
64 | "--chamfer",
65 | type=float,
66 | default=None,
67 | help="Change weight of Chamfer distance loss.",
68 | )
69 | parser.add_argument(
70 | "-s",
71 | "--semantic",
72 | type=float,
73 | default=None,
74 | help="Change weight of semantic loss.",
75 | )
76 | parser.add_argument(
77 | "-e", "--epochs", type=int, default=None, help="Number of training epochs."
78 | )
79 | parser.add_argument(
80 | "-seq",
81 | "--sequence",
82 | type=int,
83 | nargs="+",
84 | default=None,
85 | help="Sequences for training.",
86 | )
87 | parser.add_argument(
88 | "-u",
89 | "--update-cfg",
90 |         action="store_true",  # was "type=bool": argparse parses any non-empty string as True
91 | default=False,
92 | help="Update config file.",
93 | )
94 | args, unparsed = parser.parse_known_args()
95 |
96 | dataset_path = args.dataset
97 | if dataset_path:
98 | pass
99 | else:
100 |         raise Exception("Please enter the path of the dataset")
101 |
102 | model_path = args.resume if args.resume else args.weights
103 | if model_path and not args.update_cfg:
104 | ###### Load config and update parameters
105 | checkpoint_path = model_path
106 | config_filename = os.path.dirname(model_path)
107 | if os.path.basename(config_filename) == "val":
108 | config_filename = os.path.dirname(config_filename)
109 | config_filename = os.path.dirname(config_filename) + "/hparams.yaml"
110 |
111 | cfg = yaml.safe_load(open(config_filename))
112 |
113 | if args.weights and not args.comment:
114 | args.comment = "_pretrained"
115 |
116 | cfg["LOG_DIR"] = cfg["LOG_DIR"] + args.comment
117 | cfg["LOG_NAME"] = cfg["LOG_NAME"] + args.comment
118 | print("New log name is ", cfg["LOG_DIR"])
119 |
120 | """Manually set these"""
121 | cfg["DATA_CONFIG"]["COMPUTE_MEAN_AND_STD"] = False
122 | cfg["DATA_CONFIG"]["GENERATE_FILES"] = False
123 |
124 | if args.epochs:
125 | cfg["TRAIN"]["MAX_EPOCH"] = args.epochs
126 | print("Set max_epochs to ", args.epochs)
127 | if args.range:
128 | cfg["TRAIN"]["LOSS_WEIGHT_RANGE_VIEW"] = args.range
129 | print("Overwriting LOSS_WEIGHT_RANGE_VIEW =", args.range)
130 | if args.mask:
131 | cfg["TRAIN"]["LOSS_WEIGHT_MASK"] = args.mask
132 | print("Overwriting LOSS_WEIGHT_MASK =", args.mask)
133 | if args.chamfer:
134 | cfg["TRAIN"]["LOSS_WEIGHT_CHAMFER_DISTANCE"] = args.chamfer
135 | print("Overwriting LOSS_WEIGHT_CHAMFER_DISTANCE =", args.chamfer)
136 | if args.semantic:
137 | cfg["TRAIN"]["LOSS_WEIGHT_SEMANTIC"] = args.semantic
138 | print("Overwriting LOSS_WEIGHT_SEMANTIC =", args.semantic)
139 | if args.sequence:
140 | cfg["DATA_CONFIG"]["SPLIT"]["TRAIN"] = args.sequence
141 | print("Training on sequences ", args.sequence)
142 | else:
143 | ###### Create new log
144 | resume_from_checkpoint = None
145 | config_filename = "config/parameters.yaml"
146 | cfg = yaml.safe_load(open(config_filename))
147 | if args.update_cfg:
148 | checkpoint_path = model_path
149 | print("Updated config file manually")
150 | if args.comment:
151 | cfg["EXPERIMENT"]["ID"] = args.comment
152 | cfg["LOG_NAME"] = cfg["EXPERIMENT"]["ID"] + "_" + time.strftime("%Y%m%d_%H%M%S")
153 | cfg["LOG_DIR"] = os.path.join("./runs", cfg["LOG_NAME"])
154 | if not os.path.exists(cfg["LOG_DIR"]):
155 | os.makedirs(cfg["LOG_DIR"])
156 | print("Starting experiment with log name", cfg["LOG_NAME"])
157 |
158 | model_file_path = "./pcpnet/models"
159 | os.system('cp -r %s %s' % (model_file_path, cfg["LOG_DIR"]))
160 |
161 | ###### Set random seed for torch, numpy and python
162 | set_seed(cfg["DATA_CONFIG"]["RANDOM_SEED"])
163 |
164 | ###### Logger
165 | tb_logger = TensorBoardLogger(
166 | save_dir=cfg["LOG_DIR"], default_hp_metric=False, name="", version=""
167 | )
168 |
169 | ###### Dataset
170 | data = KittiOdometryModule(cfg, dataset_path, args.rawdata)
171 |
172 | ###### Model
173 | model = PCPNet(cfg)
174 |
175 | ###### Load checkpoint
176 | if args.resume:
177 | resume_from_checkpoint = checkpoint_path
178 | print("Resuming from checkpoint ", checkpoint_path)
179 | elif args.weights:
180 | model = model.load_from_checkpoint(checkpoint_path, cfg=cfg)
181 | resume_from_checkpoint = None
182 | print("Loading weigths from ", checkpoint_path)
183 |
184 | ###### Callbacks
185 | lr_monitor = LearningRateMonitor(logging_interval="step")
186 | checkpoint = ModelCheckpoint(
187 | monitor="val/loss",
188 | dirpath=os.path.join(cfg["LOG_DIR"], "checkpoints"),
189 | filename="{val/loss:.3f}-{epoch:02d}",
190 | mode="min",
191 | save_top_k=5,
192 | save_last=True
193 | )
194 |
195 | ###### Trainer
196 | trainer = Trainer(
197 | gpus=cfg["TRAIN"]["N_GPUS"],
198 | logger=tb_logger,
199 | accumulate_grad_batches=cfg["TRAIN"]["BATCH_ACC"],
200 | max_epochs=cfg["TRAIN"]["MAX_EPOCH"],
201 | log_every_n_steps=cfg["TRAIN"][
202 | "LOG_EVERY_N_STEPS"
203 | ], # times accumulate_grad_batches
204 | callbacks=[lr_monitor, checkpoint],
205 | )
206 |
207 | ###### Training
208 | trainer.fit(model, data, ckpt_path=resume_from_checkpoint)
209 |
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Zhen Luo, Junyi Ma, and Zijie Zhou
3 | # This file is covered by the LICENSE file in the root of the project PCPNet:
4 | # https://github.com/Blurryface0814/PCPNet
5 | # Brief: Visualize script for range-image-based point cloud prediction
6 | import os
7 | import argparse
8 |
9 | from pcpnet.utils.visualization import *
10 |
11 | if __name__ == "__main__":
12 | parser = argparse.ArgumentParser("./visualize.py")
13 | parser.add_argument(
14 | "--path", "-p", type=str, default=None, help="Path to point clouds"
15 | )
16 | parser.add_argument(
17 | "--sequence", "-s", type=str, default="08", help="Sequence to visualize"
18 | )
19 | parser.add_argument("--start", type=int, default=0, help="Start frame")
20 | parser.add_argument("--end", type=int, default=None, help="End frame")
21 | parser.add_argument("--capture", "-c", action="store_true", help="Capture frames")
22 | args, unparsed = parser.parse_known_args()
23 |
24 | gt_path = os.path.join(args.path, args.sequence, "gt/")
25 | end = last_file(gt_path) if not args.end else args.end
26 | start = first_file(gt_path) if not args.start else args.start
27 | assert end > start
28 | print("\nRendering scans [{s},{e}] from:{d}\n".format(s=start, e=end, d=gt_path))
29 |
30 | path_to_car_model = "car_model/bus_ply.ply"
31 | vis = Visualization(
32 | path=args.path,
33 | sequence=args.sequence,
34 | start=start,
35 | end=end,
36 | capture=args.capture,
37 | path_to_car_model=path_to_car_model,
38 | )
39 | vis.set_render_options(
40 | mesh_show_wireframe=False,
41 | mesh_show_back_face=False,
42 | show_coordinate_frame=False,
43 | )
44 | vis.run()
45 |
--------------------------------------------------------------------------------