├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── assets ├── Demo.gif ├── Leaderboard.jpg └── VisualizeSingleFrame.jpg ├── common ├── __init__.py ├── dataset │ └── kitti │ │ ├── __init__.py │ │ ├── parser.py │ │ └── utils.py ├── laserscan.py ├── laserscanvis.py ├── logger.py ├── summary.py ├── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ └── replicate.py ├── visualization.py └── warmupLR.py ├── config ├── data_preparing.yaml ├── kitti_road_mos.md ├── labels │ ├── semantic-kitti-all.yaml │ ├── semantic-kitti-mos.raw.yaml │ ├── semantic-kitti-mos.yaml │ └── semantic-kitti.yaml ├── post-processing.yaml └── train_split_dynamic_pointnumber.txt ├── environment.yml ├── infer.py ├── modules ├── BaseBlocks.py ├── KNN.py ├── MFMOS.py ├── PointRefine │ ├── PointMLP.py │ ├── spvcnn.py │ └── spvcnn_lite.py ├── SalsaNext.py ├── SalsaNextWithMotionAttention.py ├── __init__.py ├── loss │ ├── DiceLoss.py │ ├── Lovasz_Softmax.py │ └── __init__.py ├── tools.py ├── trainer.py ├── trainer_refine.py ├── user.py └── user_refine.py ├── script ├── dist_train.sh ├── evaluate.sh ├── train_siem.sh ├── valid.sh └── visualize.sh ├── train.py ├── train_2stage.py ├── train_yaml ├── ddp_mos_coarse_stage.yml ├── mos_coarse_stage.yml └── mos_pointrefine_stage.yml └── utils ├── auto_gen_residual_images.py ├── auto_gen_residual_images_mp.py ├── auxiliary ├── __init__.py ├── camera.py ├── glow.py ├── laserscan.py ├── laserscanvis.py ├── np_ioueval.py ├── shaders │ ├── check_uniforms.vert │ ├── draw_pose.geom │ ├── draw_voxels.frag │ ├── draw_voxels.vert │ ├── empty.frag │ ├── empty.vert │ └── passthrough.frag └── torch_ioueval.py ├── combine_semantics.py ├── concat_residual_image.py ├── download_kitti_road.sh ├── evaluate_mos.py ├── gen_residual_images.py ├── kitti_mos_statistical_analysis.py ├── kitti_utils.py ├── scan_cleaner.py ├── utils.py ├── viewfile.json ├── visualize_mos.py ├── viz_concate_residuals.py ├── viz_mos_result_2d.py ├── viz_mos_result_o3d.py ├── viz_range_depth_img.py └── viz_seqVideo.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | .idea/ 11 | log/ 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | # Cython debug symbols 142 | cython_debug/ 143 | 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICRA2024] MF-MOS: A Motion-Focused Model for Moving Object Segmentation 2 | 3 |
9 | 10 | **🎉MF-MOS achieved a leading IoU of **_76.7%_** on [the SemanticKITTI MOS leaderboard](https://codalab.lisn.upsaclay.fr/competitions/7088) upon submission, demonstrating state-of-the-art (SOTA) performance at the time.** 11 | 12 |
13 |
14 |
16 | The MOS leaderboard 17 |
18 | 19 |
20 |
21 |
23 | Video demo 24 |
25 | 26 | ## 📖How to use 27 | ### 📦Pretrained model 28 | Our pretrained model (best on validation, with an IoU of **_76.12%_**) can be downloaded from [Google Drive](https://drive.google.com/file/d/1KGPwMr9v9GWdIB0zEGAJ8Wi0k3dvXbZt/view?usp=sharing). 29 | ### 📚Dataset 30 | Download the SemanticKITTI dataset from [SemanticKITTI](http://www.semantic-kitti.org/dataset.html#download) (including the **Velodyne point clouds**, **calibration data** and **label data**). 31 | #### Preprocessing 32 | After downloading the dataset, generate the residual maps that serve as the model's input during training. 33 | Run [auto_gen_residual_images.py](./utils/auto_gen_residual_images.py) or [auto_gen_residual_images_mp.py](./utils/auto_gen_residual_images_mp.py) (the multiprocessing version), 34 | and check that the paths are correct before running. 35 | 36 | The structure of one sequence folder in the dataset is as follows: 37 | ``` 38 | DATAROOT 39 | └── sequences 40 | ├── 00 41 | │ ├── poses.txt 42 | │ ├── calib.txt 43 | │ ├── times.txt 44 | │ ├── labels 45 | │ ├── residual_images_1 46 | │ ├── residual_images_10 47 | │ ├── residual_images_11 48 | │ ├── residual_images_13 49 | │ ├── residual_images_15 50 | │ ├── residual_images_16 51 | │ ├── residual_images_19 52 | │ ├── residual_images_2 53 | │ ├── residual_images_22 54 | │ ├── residual_images_3 55 | │ ├── residual_images_4 56 | │ ├── residual_images_5 57 | │ ├── residual_images_6 58 | │ ├── residual_images_7 59 | │ ├── residual_images_8 60 | │ ├── residual_images_9 61 | │ └── velodyne 62 | ... 63 | ``` 64 | If you do not need augmentation for the residual maps, you only need the folders numbered [1, 2, 3, 4, 5, 6, 7, 8]. 65 | 66 | ### 💾Environment 67 | Our environment: Ubuntu 18.04, CUDA 11.2. 68 | 69 | Use conda to create the environment and activate it: 70 | ```shell 71 | conda env create -f environment.yml 72 | conda activate mfmos 73 | ``` 74 | #### TorchSparse 75 | Install torchsparse, which is used by [SIEM](./modules/PointRefine/spvcnn.py), with the following commands: 76 | ```shell 77 | sudo apt install libsparsehash-dev 78 | pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0 79 | ``` 80 | 81 | ### 📈Training 82 | Check the paths in [dist_train.sh](./script/dist_train.sh), then run it to train: 83 | ```shell 84 | bash script/dist_train.sh 85 | ``` 86 | You can change the number of GPUs as well as the GPU IDs to suit your setup. 87 | #### Train the SIEM 88 | Once the first training stage above is complete, you can continue with SIEM training to further improve performance. 89 | 90 | Check the paths in [train_siem.sh](./script/train_siem.sh) and run it to train the SIEM **(only available on a single GPU)**: 91 | ```shell 92 | bash script/train_siem.sh 93 | ``` 94 | 95 | ### 📝Validation and Evaluation 96 | Check the paths in [valid.sh](./script/valid.sh) and [evaluate.sh](./script/evaluate.sh). 97 | 98 | Then run them to obtain the predicted results and the IoU reported in the paper, respectively: 99 | ```shell 100 | bash script/valid.sh 101 | # evaluation after validation 102 | bash script/evaluate.sh 103 | ``` 104 | You can also use the pretrained model provided above to validate its performance. 105 | 106 | 107 | ### 👀Visualization 108 | #### Single-frame visualization 109 | Check the path in [visualize.sh](./script/visualize.sh), and run it to visualize the results in 2D and 3D: 110 | ```shell 111 | bash script/visualize.sh 112 | ``` 113 | If `-p` is empty, only the ground truth is visualized; if `-p` is set to the path of the predictions, both the ground truth and the predictions are visualized (a sketch of such a call is shown below).
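A minimal sketch of a direct single-frame visualization call is shown below. It assumes that `visualize.sh` wraps `utils/visualize_mos.py` with LiDAR-MOS-style flags (`-d` for the dataset root, `-s` for the sequence, `-p` for the prediction folder); only `-p` is confirmed by the text above, so treat the other flags and all paths as placeholders and check [visualize.sh](./script/visualize.sh) for the exact arguments used in this repo:
```shell
# Hypothetical invocation: flag names (except -p) and every path below are placeholders.
# -d: dataset root containing sequences/   -s: sequence id   -p: folder with predicted labels
python utils/visualize_mos.py -d /path/to/DATAROOT -s 08 -p /path/to/log/Valid/predictions
# Omit -p to visualize only the ground truth labels.
```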
114 | 115 | 116 | 117 | #### Get the sequence video 118 | Check the path in [viz_seqVideo.py](./utils/viz_seqVideo.py), and run it to visualize an entire sequence as a video. 119 | 120 | 121 | ## 👏Acknowledgment 122 | This repo is based on [MotionSeg3D](https://github.com/haomo-ai/MotionSeg3D) and [LiDAR-MOS](https://github.com/PRBonn/LiDAR-MOS); we are very grateful for their excellent work. 123 | Besides, excellent works like 4DMOS[[paper](https://www.ipb.uni-bonn.de/wp-content/papercite-data/pdf/mersch2022ral.pdf), [code](https://github.com/PRBonn/4DMOS)] and MapMOS[[paper](https://www.ipb.uni-bonn.de/wp-content/papercite-data/pdf/mersch2023ral.pdf), [code](https://github.com/PRBonn/MapMOS)] have not only demonstrated excellent dynamic object segmentation capabilities on the SemanticKITTI-MOS benchmark but have also shown good generalization to new datasets, which MF-MOS does not achieve. We appreciate their contributions to MOS and highly recommend using their excellent, publicly available code. 124 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | TRAIN_PATH = "./" 4 | DEPLOY_PATH = "../../../deploy" 5 | sys.path.insert(0, TRAIN_PATH) 6 | -------------------------------------------------------------------------------- /assets/Demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/MF-MOS/0c702445a39b978efc107cf7d0a2a33246f857ba/assets/Demo.gif -------------------------------------------------------------------------------- /assets/Leaderboard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/MF-MOS/0c702445a39b978efc107cf7d0a2a33246f857ba/assets/Leaderboard.jpg -------------------------------------------------------------------------------- /assets/VisualizeSingleFrame.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/MF-MOS/0c702445a39b978efc107cf7d0a2a33246f857ba/assets/VisualizeSingleFrame.jpg -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/MF-MOS/0c702445a39b978efc107cf7d0a2a33246f857ba/common/__init__.py -------------------------------------------------------------------------------- /common/dataset/kitti/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/MF-MOS/0c702445a39b978efc107cf7d0a2a33246f857ba/common/dataset/kitti/__init__.py -------------------------------------------------------------------------------- /common/laserscanvis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project.
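# --- Descriptive note (added; not part of the original file) ---
# LaserScanVis below opens two vispy SceneCanvas windows: a 3D turntable view of the
# point cloud (plus optional semantic/instance views) and a 2D range-image strip.
# Keyboard controls handled in key_press(): 'N' = next scan, 'B' = previous scan,
# 'Q' / 'Esc' = quit.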
3 | 4 | import vispy 5 | from vispy.scene import visuals, SceneCanvas 6 | import numpy as np 7 | from matplotlib import pyplot as plt 8 | 9 | 10 | class LaserScanVis: 11 | """Class that creates and handles a visualizer for a pointcloud""" 12 | 13 | def __init__(self, scan, scan_names, label_names, offset=0, 14 | semantics=True, instances=False): 15 | self.scan = scan 16 | self.scan_names = scan_names 17 | self.label_names = label_names 18 | self.offset = offset 19 | self.semantics = semantics 20 | self.instances = instances 21 | # sanity check 22 | if not self.semantics and self.instances: 23 | print("Instances are only allowed in when semantics=True") 24 | raise ValueError 25 | 26 | self.reset() 27 | self.update_scan() 28 | 29 | def reset(self): 30 | """ Reset. """ 31 | # last key press (it should have a mutex, but visualization is not 32 | # safety critical, so let's do things wrong) 33 | self.action = "no" # no, next, back, quit are the possibilities 34 | 35 | # new canvas prepared for visualizing data 36 | self.canvas = SceneCanvas(keys='interactive', show=True) 37 | # interface (n next, b back, q quit, very simple) 38 | self.canvas.events.key_press.connect(self.key_press) 39 | self.canvas.events.draw.connect(self.draw) 40 | # grid 41 | self.grid = self.canvas.central_widget.add_grid() 42 | 43 | # laserscan part 44 | self.scan_view = vispy.scene.widgets.ViewBox( 45 | border_color='white', parent=self.canvas.scene) 46 | self.grid.add_widget(self.scan_view, 0, 0) 47 | self.scan_vis = visuals.Markers() 48 | self.scan_view.camera = 'turntable' 49 | self.scan_view.add(self.scan_vis) 50 | visuals.XYZAxis(parent=self.scan_view.scene) 51 | 52 | # add semantics 53 | if self.semantics: 54 | print("Using semantics in visualizer") 55 | self.sem_view = vispy.scene.widgets.ViewBox( 56 | border_color='white', parent=self.canvas.scene) 57 | self.grid.add_widget(self.sem_view, 0, 1) 58 | self.sem_vis = visuals.Markers() 59 | self.sem_view.camera = 'turntable' 60 | self.sem_view.add(self.sem_vis) 61 | visuals.XYZAxis(parent=self.sem_view.scene) 62 | # self.sem_view.camera.link(self.scan_view.camera) 63 | 64 | if self.instances: 65 | print("Using instances in visualizer") 66 | self.inst_view = vispy.scene.widgets.ViewBox( 67 | border_color='white', parent=self.canvas.scene) 68 | self.grid.add_widget(self.inst_view, 0, 2) 69 | self.inst_vis = visuals.Markers() 70 | self.inst_view.camera = 'turntable' 71 | self.inst_view.add(self.inst_vis) 72 | visuals.XYZAxis(parent=self.inst_view.scene) 73 | # self.inst_view.camera.link(self.scan_view.camera) 74 | 75 | # img canvas size 76 | self.multiplier = 1 77 | self.canvas_W = 1024 78 | self.canvas_H = 64 79 | if self.semantics: 80 | self.multiplier += 1 81 | if self.instances: 82 | self.multiplier += 1 83 | 84 | # new canvas for img 85 | self.img_canvas = SceneCanvas(keys='interactive', show=True, 86 | size=(self.canvas_W, self.canvas_H * self.multiplier)) 87 | # grid 88 | self.img_grid = self.img_canvas.central_widget.add_grid() 89 | # interface (n next, b back, q quit, very simple) 90 | self.img_canvas.events.key_press.connect(self.key_press) 91 | self.img_canvas.events.draw.connect(self.draw) 92 | 93 | # add a view for the depth 94 | self.img_view = vispy.scene.widgets.ViewBox( 95 | border_color='white', parent=self.img_canvas.scene) 96 | self.img_grid.add_widget(self.img_view, 0, 0) 97 | self.img_vis = visuals.Image(cmap='viridis') 98 | self.img_view.add(self.img_vis) 99 | 100 | # add semantics 101 | if self.semantics: 102 | self.sem_img_view = 
vispy.scene.widgets.ViewBox( 103 | border_color='white', parent=self.img_canvas.scene) 104 | self.img_grid.add_widget(self.sem_img_view, 1, 0) 105 | self.sem_img_vis = visuals.Image(cmap='viridis') 106 | self.sem_img_view.add(self.sem_img_vis) 107 | 108 | # add instances 109 | if self.instances: 110 | self.inst_img_view = vispy.scene.widgets.ViewBox( 111 | border_color='white', parent=self.img_canvas.scene) 112 | self.img_grid.add_widget(self.inst_img_view, 2, 0) 113 | self.inst_img_vis = visuals.Image(cmap='viridis') 114 | self.inst_img_view.add(self.inst_img_vis) 115 | 116 | def get_mpl_colormap(self, cmap_name): 117 | cmap = plt.get_cmap(cmap_name) 118 | 119 | # Initialize the matplotlib color map 120 | sm = plt.cm.ScalarMappable(cmap=cmap) 121 | 122 | # Obtain linear color range 123 | color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1] 124 | 125 | return color_range.reshape(256, 3).astype(np.float32) / 255.0 126 | 127 | def update_scan(self): 128 | # first open data 129 | self.scan.open_scan(self.scan_names[self.offset]) 130 | if self.semantics: 131 | self.scan.open_label(self.label_names[self.offset]) 132 | self.scan.colorize() 133 | 134 | # then change names 135 | title = "scan " + str(self.offset) + " of " + str(len(self.scan_names)-1) 136 | self.canvas.title = title 137 | self.img_canvas.title = title 138 | 139 | # then do all the point cloud stuff 140 | 141 | # plot scan 142 | power = 16 143 | # print() 144 | range_data = np.copy(self.scan.unproj_range) 145 | # print(range_data.max(), range_data.min()) 146 | range_data = range_data**(1 / power) 147 | # print(range_data.max(), range_data.min()) 148 | viridis_range = ((range_data - range_data.min()) / 149 | (range_data.max() - range_data.min()) * 255).astype(np.uint8) 150 | viridis_map = self.get_mpl_colormap("viridis") 151 | viridis_colors = viridis_map[viridis_range] 152 | self.scan_vis.set_data(self.scan.points, 153 | face_color=viridis_colors[..., ::-1], 154 | edge_color=viridis_colors[..., ::-1], 155 | size=1) 156 | 157 | # plot semantics 158 | if self.semantics: 159 | self.sem_vis.set_data(self.scan.points, 160 | face_color=self.scan.sem_label_color[..., ::-1], 161 | edge_color=self.scan.sem_label_color[..., ::-1], 162 | size=1) 163 | 164 | # plot instances 165 | if self.instances: 166 | self.inst_vis.set_data(self.scan.points, 167 | face_color=self.scan.inst_label_color[..., ::-1], 168 | edge_color=self.scan.inst_label_color[..., ::-1], 169 | size=1) 170 | 171 | # now do all the range image stuff 172 | # plot range image 173 | data = np.copy(self.scan.proj_range) 174 | # print(data[data > 0].max(), data[data > 0].min()) 175 | data[data > 0] = data[data > 0]**(1 / power) 176 | data[data < 0] = data[data > 0].min() 177 | # print(data.max(), data.min()) 178 | data = (data - data[data > 0].min()) / \ 179 | (data.max() - data[data > 0].min()) 180 | # print(data.max(), data.min()) 181 | self.img_vis.set_data(data) 182 | self.img_vis.update() 183 | 184 | if self.semantics: 185 | self.sem_img_vis.set_data(self.scan.proj_sem_color[..., ::-1]) 186 | self.sem_img_vis.update() 187 | 188 | if self.instances: 189 | self.inst_img_vis.set_data(self.scan.proj_inst_color[..., ::-1]) 190 | self.inst_img_vis.update() 191 | 192 | # interface 193 | def key_press(self, event): 194 | self.canvas.events.key_press.block() 195 | self.img_canvas.events.key_press.block() 196 | if event.key == 'N': 197 | self.offset += 1 198 | if self.offset >= len(self.scan_names): 199 | self.offset = 0 200 | self.update_scan() 201 | elif event.key == 
'B': 202 | self.offset -= 1 203 | if self.offset <= 0: 204 | self.offset = len(self.scan_names)-1 205 | self.update_scan() 206 | elif event.key == 'Q' or event.key == 'Escape': 207 | self.destroy() 208 | 209 | def draw(self, event): 210 | if self.canvas.events.key_press.blocked(): 211 | self.canvas.events.key_press.unblock() 212 | if self.img_canvas.events.key_press.blocked(): 213 | self.img_canvas.events.key_press.unblock() 214 | 215 | def destroy(self): 216 | # destroy the visualization 217 | self.canvas.close() 218 | self.img_canvas.close() 219 | vispy.app.quit() 220 | 221 | def run(self): 222 | vispy.app.run() 223 | -------------------------------------------------------------------------------- /common/logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | 3 | import numpy as np 4 | import scipy.misc 5 | import tensorflow as tf 6 | from torch.utils.tensorboard import SummaryWriter 7 | import torch 8 | try: 9 | from StringIO import StringIO # Python 2.7 10 | except ImportError: 11 | from io import BytesIO # Python 3.x 12 | 13 | 14 | class Logger(object): 15 | 16 | def __init__(self, log_dir): 17 | """Create a summary writer logging to log_dir.""" 18 | self.writer = tf.summary.create_file_writer(log_dir) 19 | 20 | def scalar_summary(self, tag, value, step): 21 | """Log a scalar variable.""" 22 | with self.writer.as_default(): 23 | tf.summary.scalar(name=tag, data=value, step=step) 24 | self.writer.flush() 25 | 26 | def image_summary(self, tag, images, step): 27 | """Log a list of images.""" 28 | 29 | img_summaries = [] 30 | for i, img in enumerate(images): 31 | # Write the image to a string 32 | try: 33 | s = StringIO() 34 | except: 35 | s = BytesIO() 36 | scipy.misc.toimage(img).save(s, format="png") 37 | 38 | # Create an Image object 39 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 40 | height=img.shape[0], 41 | width=img.shape[1]) 42 | # Create a Summary value 43 | img_summaries.append(tf.Summary.Value( 44 | tag='%s/%d' % (tag, i), image=img_sum)) 45 | 46 | # Create and write Summary 47 | summary = tf.Summary(value=img_summaries) 48 | self.writer.add_summary(summary, step) 49 | self.writer.flush() 50 | 51 | def histo_summary(self, tag, values, step, bins=1000): 52 | """Log a histogram of the tensor of values.""" 53 | 54 | # Create a histogram using numpy 55 | counts, bin_edges = np.histogram(values, bins=bins) 56 | 57 | # Fill the fields of the histogram proto 58 | hist = tf.HistogramProto() 59 | hist.min = float(np.min(values)) 60 | hist.max = float(np.max(values)) 61 | hist.num = int(np.prod(values.shape)) 62 | hist.sum = float(np.sum(values)) 63 | hist.sum_squares = float(np.sum(values ** 2)) 64 | 65 | # Drop the start of the first bin 66 | bin_edges = bin_edges[1:] 67 | 68 | # Add bin edges and counts 69 | for edge in bin_edges: 70 | hist.bucket_limit.append(edge) 71 | for c in counts: 72 | hist.bucket.append(c) 73 | 74 | # Create and write Summary 75 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 76 | self.writer.add_summary(summary, step) 77 | self.writer.flush() 78 | -------------------------------------------------------------------------------- /common/summary.py: -------------------------------------------------------------------------------- 1 | ### https://github.com/sksq96/pytorch-summary/blob/master/torchsummary/torchsummary.py 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import 
Variable 5 | 6 | from collections import OrderedDict 7 | import numpy as np 8 | 9 | 10 | def summary(model, input_size, batch_size=-1, device="cuda"): 11 | def register_hook(module): 12 | 13 | def hook(module, input, output): 14 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 15 | module_idx = len(summary) 16 | 17 | m_key = "%s-%i" % (class_name, module_idx + 1) 18 | summary[m_key] = OrderedDict() 19 | summary[m_key]["input_shape"] = list(input[0].size()) 20 | summary[m_key]["input_shape"][0] = batch_size 21 | if isinstance(output, (list, tuple)): 22 | summary[m_key]["output_shape"] = [[-1] + list(o.size())[1:] for o in output] 23 | else: 24 | summary[m_key]["output_shape"] = list(output.size()) 25 | summary[m_key]["output_shape"][0] = batch_size 26 | 27 | params = 0 28 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 29 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 30 | summary[m_key]["trainable"] = module.weight.requires_grad 31 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 32 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 33 | summary[m_key]["nb_params"] = params 34 | 35 | if (not isinstance(module, nn.Sequential) 36 | and not isinstance(module, nn.ModuleList) 37 | and not (module == model)): 38 | hooks.append(module.register_forward_hook(hook)) 39 | 40 | device = device.lower() 41 | assert device in ["cuda", "cpu",], "Input device is not valid, please specify 'cuda' or 'cpu'" 42 | 43 | if device == "cuda" and torch.cuda.is_available(): 44 | dtype = torch.cuda.FloatTensor 45 | else: 46 | dtype = torch.FloatTensor 47 | 48 | # multiple inputs to the network 49 | if isinstance(input_size, tuple): 50 | input_size = [input_size] 51 | 52 | # batch_size of 2 for batchnorm 53 | x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] 54 | # message +=type(x[0])) 55 | 56 | # create properties 57 | summary = OrderedDict() 58 | hooks = [] 59 | 60 | # register hook 61 | model.apply(register_hook) 62 | 63 | # make a forward pass 64 | # message +=x.shape) 65 | model(*x) 66 | 67 | # remove these hooks 68 | for h in hooks: 69 | h.remove() 70 | message = "" 71 | message += "----------------------------------------------------------------\n" 72 | line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 73 | message += line_new + "\n" 74 | message += "================================================================\n" 75 | total_params = 0 76 | total_output = 0 77 | trainable_params = 0 78 | for layer in summary: 79 | # input_shape, output_shape, trainable, nb_params 80 | line_new = "{:>20} {:>25} {:>15}".format( 81 | layer, 82 | str(summary[layer]["output_shape"]), 83 | "{0:,}".format(summary[layer]["nb_params"]), 84 | ) 85 | total_params += summary[layer]["nb_params"] 86 | total_output += np.prod(summary[layer]["output_shape"]) 87 | if "trainable" in summary[layer]: 88 | if summary[layer]["trainable"] == True: 89 | trainable_params += summary[layer]["nb_params"] 90 | message += line_new + "\n" 91 | 92 | # assume 4 bytes/number (float on cuda). 93 | total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) 94 | total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients 95 | total_params_size = abs(total_params.numpy() * 4. 
/ (1024 ** 2.)) 96 | total_size = total_params_size + total_output_size + total_input_size 97 | 98 | message += "================================================================\n" 99 | message += "Total params: {0:,}\n".format(total_params) 100 | message += "Trainable params: {0:,}\n".format(trainable_params) 101 | message += "Non-trainable params: {0:,}\n".format(total_params - trainable_params) 102 | message += "----------------------------------------------------------------\n" 103 | message += "Input size (MB): %0.2f\n" % total_input_size 104 | message += "Forward/backward pass size (MB): %0.2f\n" % total_output_size 105 | message += "Params size (MB): %0.2f\n" % total_params_size 106 | message += "Estimated Total Size (MB): %0.2f\n" % total_size 107 | message += "----------------------------------------------------------------\n" 108 | return message 109 | -------------------------------------------------------------------------------- /common/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/MF-MOS/0c702445a39b978efc107cf7d0a2a33246f857ba/common/sync_batchnorm/__init__.py -------------------------------------------------------------------------------- /common/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | import queue 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple( 44 | '_SlavePipeBase', ['identifier', 'queue', 'result']) 45 | 46 | 47 | class SlavePipe(_SlavePipeBase): 48 | """Pipe for master-slave communication.""" 49 | 50 | def run_slave(self, msg): 51 | self.queue.put((self.identifier, msg)) 52 | ret = self.result.get() 53 | self.queue.put(True) 54 | return ret 55 | 56 | 57 | class SyncMaster(object): 58 | """An abstract `SyncMaster` object. 59 | 60 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 61 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 62 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 63 | and passed to a registered callback. 64 | - After receiving the messages, the master device should gather the information and determine to message passed 65 | back to each slave devices. 
66 | """ 67 | 68 | def __init__(self, master_callback): 69 | """ 70 | 71 | Args: 72 | master_callback: a callback to be invoked after having collected messages from slave devices. 73 | """ 74 | self._master_callback = master_callback 75 | self._queue = queue.Queue() 76 | self._registry = collections.OrderedDict() 77 | self._activated = False 78 | 79 | def __getstate__(self): 80 | return {'master_callback': self._master_callback} 81 | 82 | def __setstate__(self, state): 83 | self.__init__(state['master_callback']) 84 | 85 | def register_slave(self, identifier): 86 | """ 87 | Register an slave device. 88 | 89 | Args: 90 | identifier: an identifier, usually is the device id. 91 | 92 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 93 | 94 | """ 95 | if self._activated: 96 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 97 | self._activated = False 98 | self._registry.clear() 99 | future = FutureResult() 100 | self._registry[identifier] = _MasterRegistry(future) 101 | return SlavePipe(identifier, self._queue, future) 102 | 103 | def run_master(self, master_msg): 104 | """ 105 | Main entry for the master device in each forward pass. 106 | The messages were first collected from each devices (including the master device), and then 107 | an callback will be invoked to compute the message to be sent back to each devices 108 | (including the master device). 109 | 110 | Args: 111 | master_msg: the message that the master want to send to itself. This will be placed as the first 112 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 113 | 114 | Returns: the message to be sent back to the master device. 115 | 116 | """ 117 | self._activated = True 118 | 119 | intermediates = [(0, master_msg)] 120 | for i in range(self.nr_slaves): 121 | intermediates.append(self._queue.get()) 122 | 123 | results = self._master_callback(intermediates) 124 | assert results[0][0] == 0, 'The first result should belongs to the master.' 125 | 126 | for i, res in results: 127 | if i == 0: 128 | continue 129 | self._registry[i].result.put(res) 130 | 131 | for i in range(self.nr_slaves): 132 | assert self._queue.get() is True 133 | 134 | return results[0][1] 135 | 136 | @property 137 | def nr_slaves(self): 138 | return len(self._registry) 139 | -------------------------------------------------------------------------------- /common/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, 66 | self).replicate(module, device_ids) 67 | execute_replication_callbacks(modules) 68 | return modules 69 | 70 | 71 | def patch_replication_callback(data_parallel): 72 | """ 73 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 74 | Useful when you have customized `DataParallel` implementation. 
75 | 76 | Examples: 77 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 78 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 79 | > patch_replication_callback(sync_bn) 80 | # this is equivalent to 81 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | """ 84 | 85 | assert isinstance(data_parallel, DataParallel) 86 | 87 | old_replicate = data_parallel.replicate 88 | 89 | @functools.wraps(old_replicate) 90 | def new_replicate(module, device_ids): 91 | modules = old_replicate(module, device_ids) 92 | execute_replication_callbacks(modules) 93 | return modules 94 | 95 | data_parallel.replicate = new_replicate 96 | -------------------------------------------------------------------------------- /common/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | import numpy as np 5 | import pykitti 6 | 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | import yaml 10 | 11 | basedir = '' 12 | sequence = '' 13 | uncerts = '' 14 | preds = '' 15 | gt = '' 16 | img = '' 17 | lidar = '' 18 | projected_uncert = '' 19 | projected_preds = '' 20 | 21 | dataset = pykitti.odometry(basedir, sequence) 22 | 23 | EXTENSIONS_LABEL = ['.label'] 24 | EXTENSIONS_LIDAR = ['.bin'] 25 | EXTENSIONS_IMG = ['.png'] 26 | 27 | 28 | def is_label(filename): 29 | return any(filename.endswith(ext) for ext in EXTENSIONS_LABEL) 30 | 31 | 32 | def is_lidar(filename): 33 | return any(filename.endswith(ext) for ext in EXTENSIONS_LIDAR) 34 | 35 | 36 | def is_img(filename): 37 | return any(filename.endswith(ext) for ext in EXTENSIONS_IMG) 38 | 39 | 40 | def get_mpl_colormap(cmap_name): 41 | cmap = plt.get_cmap(cmap_name) 42 | 43 | # Initialize the matplotlib color map 44 | sm = plt.cm.ScalarMappable(cmap=cmap) 45 | 46 | # Obtain linear color range 47 | color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1] 48 | 49 | return color_range.reshape(256, 3).astype(np.float32) / 255.0 50 | 51 | 52 | path = os.path.join(basedir + 'sequences/' + sequence + uncerts) 53 | 54 | scan_uncert = [os.path.join(dp, f) for dp, dn, fn in os.walk( 55 | os.path.expanduser(path)) for f in fn if is_label(f)] 56 | scan_uncert.sort() 57 | path = os.path.join(basedir + 'sequences/' + sequence + preds) 58 | scan_preds = [os.path.join(dp, f) for dp, dn, fn in os.walk( 59 | os.path.expanduser(path)) for f in fn if is_label(f)] 60 | scan_preds.sort() 61 | 62 | path = os.path.join(basedir + 'sequences/' + sequence + gt) 63 | scan_gt = [os.path.join(dp, f) for dp, dn, fn in os.walk( 64 | os.path.expanduser(path)) for f in fn if is_label(f)] 65 | scan_gt.sort() 66 | 67 | color_map_dict = yaml.safe_load(open("color_map.yml"))['color_map'] 68 | learning_map = yaml.safe_load(open("color_map.yml"))['learning_map'] 69 | color_map = {} 70 | uncert_mean = np.zeros(20) 71 | total_points_per_class = np.zeros(20) 72 | for key, value in color_map_dict.items(): 73 | color_map[key] = np.array(value, np.float32) / 255.0 74 | 75 | 76 | def plot_and_save(label_uncert, label_name, lidar_name, cam2_image_name): 77 | labels = np.fromfile(label_name, dtype=np.int32).reshape((-1)) 78 | uncerts = np.fromfile(label_uncert, dtype=np.float32).reshape((-1)) 79 | velo_points = np.fromfile(lidar_name, dtype=np.float32).reshape(-1, 4) 80 | try: 81 | cam2_image = plt.imread(cam2_image_name) 82 | except IOError: 83 | print('detect error img %s' % label_name) 84 | 85 | 
plt.imshow(cam2_image) 86 | 87 | if True: 88 | 89 | # Project points to camera. 90 | cam2_points = dataset.calib.T_cam2_velo.dot(velo_points.T).T 91 | 92 | # Filter out points behind camera 93 | idx = cam2_points[:, 2] > 0 94 | print(idx) 95 | # velo_points_projected = velo_points[idx] 96 | cam2_points = cam2_points[idx] 97 | labels_projected = labels[idx] 98 | uncert_projected = uncerts[idx] 99 | 100 | # Remove homogeneous z. 101 | cam2_points = cam2_points[:, :3] / cam2_points[:, 2:3] 102 | 103 | # Apply instrinsics. 104 | intrinsic_cam2 = dataset.calib.K_cam2 105 | cam2_points = intrinsic_cam2.dot(cam2_points.T).T[:, [1, 0]] 106 | cam2_points = cam2_points.astype(int) 107 | 108 | for i in range(0, cam2_points.shape[0]): 109 | u, v = cam2_points[i, :] 110 | label = labels_projected[i] 111 | uncert = uncert_projected[i] 112 | if label > 0 and v > 0 and v < 1241 and u > 0 and u < 376: 113 | uncert_mean[learning_map[label]] += uncert 114 | total_points_per_class[learning_map[label]] += 1 115 | m_circle = plt.Circle((v, u), 1, 116 | color=matplotlib.cm.viridis(uncert), 117 | alpha=0.4, 118 | # color=color_map[label][..., ::-1] 119 | ) 120 | plt.gcf().gca().add_artist(m_circle) 121 | 122 | plt.axis('off') 123 | path = os.path.join(basedir + 'sequences/' + sequence + projected_uncert) 124 | plt.savefig(path + label_name.split('/')[-1].split('.')[0] + '.png', bbox_inches='tight', transparent=True, 125 | pad_inches=0) 126 | 127 | # with futures.ProcessPoolExecutor() as pool: 128 | for label_uncert, label_name, lidar_name, cam2_image_name in zip(scan_uncert, scan_preds, dataset.velo_files, 129 | dataset.cam2_files): 130 | print(label_name.split('/')[-1]) 131 | # if label_name == '/SPACE/DATA/SemanticKITTI/dataset/sequences/13/predictions/preds/001032.label': 132 | plot_and_save(label_uncert, label_name, lidar_name, cam2_image_name) 133 | print(total_points_per_class) 134 | print(uncert_mean) 135 | 136 | if __name__ == "__main__": 137 | pass 138 | -------------------------------------------------------------------------------- /common/warmupLR.py: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | 3 | import torch.optim.lr_scheduler as toptim 4 | 5 | 6 | class warmupLR(toptim._LRScheduler): 7 | """ Warmup learning rate scheduler. 8 | Initially, increases the learning rate from 0 to the final value, in a 9 | certain number of steps. After this number of steps, each step decreases 10 | LR exponentially. 
11 | """ 12 | 13 | def __init__(self, optimizer, lr, warmup_steps, momentum, decay): 14 | # cyclic params 15 | self.optimizer = optimizer 16 | self.lr = lr 17 | self.warmup_steps = warmup_steps 18 | self.momentum = momentum 19 | self.decay = decay 20 | 21 | # cap to one 22 | if self.warmup_steps < 1: 23 | self.warmup_steps = 1 24 | 25 | # cyclic lr 26 | self.initial_scheduler = toptim.CyclicLR(self.optimizer, 27 | base_lr=0, 28 | max_lr=self.lr, 29 | step_size_up=self.warmup_steps, 30 | step_size_down=self.warmup_steps, 31 | cycle_momentum=False, 32 | base_momentum=self.momentum, 33 | max_momentum=self.momentum) 34 | 35 | # our params 36 | self.last_epoch = -1 # fix for pytorch 1.1 and below 37 | self.finished = False # am i done 38 | super().__init__(optimizer) 39 | 40 | def get_lr(self): 41 | return [self.lr * (self.decay ** self.last_epoch) for lr in self.base_lrs] 42 | 43 | def step(self, epoch=None): 44 | if self.finished or self.initial_scheduler.last_epoch >= self.warmup_steps: 45 | if not self.finished: 46 | self.base_lrs = [self.lr for lr in self.base_lrs] 47 | self.finished = True 48 | return super(warmupLR, self).step(epoch) 49 | else: 50 | return self.initial_scheduler.step(epoch) 51 | -------------------------------------------------------------------------------- /config/data_preparing.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | # Developed by: Xieyuanli Chen 3 | # Configuration for preparing residual images (specifying all the paths) 4 | # -------------------------------------------------------------------- 5 | 6 | # General parameters 7 | # number of frames for training, -1 uses all frames 8 | num_frames: -1 9 | # plot images 10 | debug: False 11 | # normalize/scale the difference with corresponding range value 12 | normalize: True 13 | # use the last n frame to calculate the difference image 14 | num_last_n: 8 15 | residual_aug: True 16 | 17 | # Inputs 18 | # the folder of raw LiDAR scans 19 | scan_folder: 'data/sequences/08/velodyne' 20 | # ground truth poses file 21 | pose_file: 'data/sequences/08/poses.txt' 22 | # calibration file 23 | calib_file: 'data/sequences/08/calib.txt' 24 | 25 | # Outputs 26 | # the suffix should be the same as num_last_n! 27 | residual_image_folder: 'data/sequences/08/residual_images_8' 28 | visualize: False 29 | visualization_folder: 'data/sequences/08/visualization_8' 30 | 31 | # range image parameters 32 | range_image: 33 | height: 64 34 | width: 2048 35 | fov_up: 3.0 36 | fov_down: -25.0 37 | max_range: 50.0 38 | min_range: 2.0 39 | 40 | -------------------------------------------------------------------------------- /config/kitti_road_mos.md: -------------------------------------------------------------------------------- 1 | # KITTI-Road-MOS 2 | 3 | To enrich the dataset in the moving object segmentation (MOS) task and to reduce the gap of different data distributions between the validation and test sets in the existing SemanticKITTI-MOS dataset, we automatically annotated and manually corrected the [KITTI-Road](http://www.cvlibs.net/datasets/kitti/raw_data.php?type=road) dataset. 4 | 5 | More specifically, we first use auto-mos labeling method ([link](https://arxiv.org/pdf/2201.04501.pdf)) to automatically generate the MOS labels for KITTI-Road data. We then use a point labeler ([link](https://github.com/jbehley/point_labeler)) to manually refined the labels. 
6 | 7 | We follow semantic SLAM [SuMa++](https://github.com/PRBonn/semantic_suma) to rename the sequences of KITTI-Road data as follows. 8 | 9 | ``` 10 | raw_id -> seq_id 11 | 2011_09_26_drive_0015 -> 30 12 | 2011_09_26_drive_0027 -> 31 13 | 2011_09_26_drive_0028 -> 32 14 | 2011_09_26_drive_0029 -> 33 15 | 2011_09_26_drive_0032 -> 34 16 | 2011_09_26_drive_0052 -> 35 17 | 2011_09_26_drive_0070 -> 36 18 | 2011_09_26_drive_0101 -> 37 19 | 2011_09_29_drive_0004 -> 38 20 | 2011_09_30_drive_0016 -> 39 21 | 2011_10_03_drive_0042 -> 40 22 | 2011_10_03_drive_0047 -> 41 23 | ``` 24 | We provide a simple download and conversion script [utils/download_kitti_road.sh](../utils/download_kitti_road.sh), please modify the `DATA_ROOT` path and manually move the result folder `sequences` to the target folder. 25 | And you need to download the KITTI-Road-MOS label data annotated by us, the pose and calib files from [here](https://drive.google.com/file/d/131tKKhJiNeSiJpnlrXS43bHgZJHh9tug/view?usp=sharing) (6.4 MB) [Remap the label to 9 and 251, consistent with the SemanticKITTI-MOS benchmark]. ~~[old version here](https://drive.google.com/file/d/1pdpcGReJHOJp01pbgXUbcGROWOBd_2kj/view?usp=sharing) (6.1 MB)~~. 26 | 27 | We organize our proposed KITTI-Road-MOS using the same setup and data structure used in SemanticKITTI-MOS: 28 | 29 | ``` 30 | DATAROOT 31 | ├── sequences 32 | │ └── 30 33 | │ ├── calib.txt # calibration file provided by KITTI 34 | │ ├── poses.txt # ground truth poses file provided by KITTI 35 | │ ├── velodyne # velodyne 64 LiDAR scans provided by KITTI 36 | │ │ ├── 000000.bin 37 | │ │ ├── 000001.bin 38 | │ │ └── ... 39 | │ ├── labels # ground truth labels from us 40 | │ │ ├── 000000.label 41 | │ │ ├── 000001.label 42 | │ │ └── ... 43 | │ └── residual_images_1 # the proposed residual images 44 | │ ├── 000000.npy 45 | │ ├── 000001.npy 46 | │ └── ... 47 | ``` 48 | -------------------------------------------------------------------------------- /config/labels/semantic-kitti-all.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
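# --- Note (added for readability; not part of the original config) ---
# This label config defines, in order: the raw SemanticKITTI label ids and names,
# a BGR color_map, per-class point ratios ("content"), a learning_map that collapses
# raw ids into training ids (in this "-all" variant the moving classes keep their own
# training ids 20-25), its inverse (learning_map_inv), per-class ignore flags
# (learning_ignore), and the train/valid/test sequence split.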
2 | name: "kitti" 3 | labels: 4 | 0: "unlabeled" 5 | 1: "outlier" 6 | 10: "car" 7 | 11: "bicycle" 8 | 13: "bus" 9 | 15: "motorcycle" 10 | 16: "on-rails" 11 | 18: "truck" 12 | 20: "other-vehicle" 13 | 30: "person" 14 | 31: "bicyclist" 15 | 32: "motorcyclist" 16 | 40: "road" 17 | 44: "parking" 18 | 48: "sidewalk" 19 | 49: "other-ground" 20 | 50: "building" 21 | 51: "fence" 22 | 52: "other-structure" 23 | 60: "lane-marking" 24 | 70: "vegetation" 25 | 71: "trunk" 26 | 72: "terrain" 27 | 80: "pole" 28 | 81: "traffic-sign" 29 | 99: "other-object" 30 | 252: "moving-car" 31 | 253: "moving-bicyclist" 32 | 254: "moving-person" 33 | 255: "moving-motorcyclist" 34 | 256: "moving-on-rails" 35 | 257: "moving-bus" 36 | 258: "moving-truck" 37 | 259: "moving-other-vehicle" 38 | color_map: # bgr 39 | 0: [0, 0, 0] 40 | 1: [0, 0, 255] 41 | 10: [245, 150, 100] 42 | 11: [245, 230, 100] 43 | 13: [250, 80, 100] 44 | 15: [150, 60, 30] 45 | 16: [255, 0, 0] 46 | 18: [180, 30, 80] 47 | 20: [255, 0, 0] 48 | 30: [30, 30, 255] 49 | 31: [200, 40, 255] 50 | 32: [90, 30, 150] 51 | 40: [255, 0, 255] 52 | 44: [255, 150, 255] 53 | 48: [75, 0, 75] 54 | 49: [75, 0, 175] 55 | 50: [0, 200, 255] 56 | 51: [50, 120, 255] 57 | 52: [0, 150, 255] 58 | 60: [170, 255, 150] 59 | 70: [0, 175, 0] 60 | 71: [0, 60, 135] 61 | 72: [80, 240, 150] 62 | 80: [150, 240, 255] 63 | 81: [0, 0, 255] 64 | 99: [255, 255, 50] 65 | 252: [245, 150, 100] 66 | 256: [255, 0, 0] 67 | 253: [200, 40, 255] 68 | 254: [30, 30, 255] 69 | 255: [90, 30, 150] 70 | 257: [250, 80, 100] 71 | 258: [180, 30, 80] 72 | 259: [255, 0, 0] 73 | content: # as a ratio with the total number of points 74 | 0: 0.018889854628292943 75 | 1: 0.0002937197336781505 76 | 10: 0.040818519255974316 77 | 11: 0.00016609538710764618 78 | 13: 2.7879693665067774e-05 79 | 15: 0.00039838616015114444 80 | 16: 0.0 81 | 18: 0.0020633612104619787 82 | 20: 0.0016218197275284021 83 | 30: 0.00017698551338515307 84 | 31: 1.1065903904919655e-08 85 | 32: 5.532951952459828e-09 86 | 40: 0.1987493871255525 87 | 44: 0.014717169549888214 88 | 48: 0.14392298360372 89 | 49: 0.0039048553037472045 90 | 50: 0.1326861944777486 91 | 51: 0.0723592229456223 92 | 52: 0.002395131480328884 93 | 60: 4.7084144280367186e-05 94 | 70: 0.26681502148037506 95 | 71: 0.006035012012626033 96 | 72: 0.07814222006271769 97 | 80: 0.002855498193863172 98 | 81: 0.0006155958086189918 99 | 99: 0.009923127583046915 100 | 252: 0.001789309418528068 101 | 253: 0.00012709999297008662 102 | 254: 0.00016059776092534436 103 | 255: 3.745553104802113e-05 104 | 256: 0.0 105 | 257: 0.00011351574470342043 106 | 258: 0.00010157861367183268 107 | 259: 4.3840131989471124e-05 108 | # classes that are indistinguishable from single scan or inconsistent in 109 | # ground truth are mapped to their closest equivalent 110 | learning_map: 111 | 0: 0 # "unlabeled" 112 | 1: 0 # "outlier" mapped to "unlabeled" --------------------------mapped 113 | 10: 1 # "car" 114 | 11: 2 # "bicycle" 115 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 116 | 15: 3 # "motorcycle" 117 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 118 | 18: 4 # "truck" 119 | 20: 5 # "other-vehicle" 120 | 30: 6 # "person" 121 | 31: 7 # "bicyclist" 122 | 32: 8 # "motorcyclist" 123 | 40: 9 # "road" 124 | 44: 10 # "parking" 125 | 48: 11 # "sidewalk" 126 | 49: 12 # "other-ground" 127 | 50: 13 # "building" 128 | 51: 14 # "fence" 129 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 130 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 131 | 70: 15 # "vegetation" 132 | 71: 16 # "trunk" 133 | 72: 17 # "terrain" 134 | 80: 18 # "pole" 135 | 81: 19 # "traffic-sign" 136 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 137 | 252: 20 # "moving-car" 138 | 253: 21 # "moving-bicyclist" 139 | 254: 22 # "moving-person" 140 | 255: 23 # "moving-motorcyclist" 141 | 256: 24 # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped 142 | 257: 24 # "moving-bus" mapped to "moving-other-vehicle" -----------mapped 143 | 258: 25 # "moving-truck" 144 | 259: 24 # "moving-other-vehicle" 145 | learning_map_inv: # inverse of previous map 146 | 0: 0 # "unlabeled", and others ignored 147 | 1: 10 # "car" 148 | 2: 11 # "bicycle" 149 | 3: 15 # "motorcycle" 150 | 4: 18 # "truck" 151 | 5: 20 # "other-vehicle" 152 | 6: 30 # "person" 153 | 7: 31 # "bicyclist" 154 | 8: 32 # "motorcyclist" 155 | 9: 40 # "road" 156 | 10: 44 # "parking" 157 | 11: 48 # "sidewalk" 158 | 12: 49 # "other-ground" 159 | 13: 50 # "building" 160 | 14: 51 # "fence" 161 | 15: 70 # "vegetation" 162 | 16: 71 # "trunk" 163 | 17: 72 # "terrain" 164 | 18: 80 # "pole" 165 | 19: 81 # "traffic-sign" 166 | 20: 252 # "moving-car" 167 | 21: 253 # "moving-bicyclist" 168 | 22: 254 # "moving-person" 169 | 23: 255 # "moving-motorcyclist" 170 | 24: 259 # "moving-other-vehicle" 171 | 25: 258 # "moving-truck" 172 | learning_ignore: # Ignore classes 173 | 0: True # "unlabeled", and others ignored 174 | 1: False # "car" 175 | 2: False # "bicycle" 176 | 3: False # "motorcycle" 177 | 4: False # "truck" 178 | 5: False # "other-vehicle" 179 | 6: False # "person" 180 | 7: False # "bicyclist" 181 | 8: False # "motorcyclist" 182 | 9: False # "road" 183 | 10: False # "parking" 184 | 11: False # "sidewalk" 185 | 12: False # "other-ground" 186 | 13: False # "building" 187 | 14: False # "fence" 188 | 15: False # "vegetation" 189 | 16: False # "trunk" 190 | 17: False # "terrain" 191 | 18: False # "pole" 192 | 19: False # "traffic-sign" 193 | 20: False # "moving-car" 194 | 21: False # "moving-bicyclist" 195 | 22: False # "moving-person" 196 | 23: False # "moving-motorcyclist" 197 | 24: False # "moving-other-vehicle" 198 | 25: False # "moving-truck" 199 | split: # sequence numbers 200 | train: 201 | - 0 202 | - 1 203 | - 2 204 | - 3 205 | - 4 206 | - 5 207 | - 6 208 | - 7 209 | - 9 210 | - 10 211 | valid: 212 | - 8 213 | test: 214 | - 11 215 | - 12 216 | - 13 217 | - 14 218 | - 15 219 | - 16 220 | - 17 221 | - 18 222 | - 19 223 | - 20 224 | - 21 225 | -------------------------------------------------------------------------------- /config/labels/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | name: "kitti" 3 | labels: 4 | 0: "unlabeled" 5 | 1: "outlier" 6 | 10: "car" 7 | 11: "bicycle" 8 | 13: "bus" 9 | 15: "motorcycle" 10 | 16: "on-rails" 11 | 18: "truck" 12 | 20: "other-vehicle" 13 | 30: "person" 14 | 31: "bicyclist" 15 | 32: "motorcyclist" 16 | 40: "road" 17 | 44: "parking" 18 | 48: "sidewalk" 19 | 49: "other-ground" 20 | 50: "building" 21 | 51: "fence" 22 | 52: "other-structure" 23 | 60: "lane-marking" 24 | 70: "vegetation" 25 | 71: "trunk" 26 | 72: "terrain" 27 | 80: "pole" 28 | 81: "traffic-sign" 29 | 99: "other-object" 30 | 252: "moving-car" 31 | 253: "moving-bicyclist" 32 | 254: "moving-person" 33 | 255: "moving-motorcyclist" 34 | 256: "moving-on-rails" 35 | 257: "moving-bus" 36 | 258: "moving-truck" 37 | 259: "moving-other-vehicle" 38 | color_map: # bgr 39 | 0: [0, 0, 0] 40 | 1: [0, 0, 255] 41 | 10: [245, 150, 100] 42 | 11: [245, 230, 100] 43 | 13: [250, 80, 100] 44 | 15: [150, 60, 30] 45 | 16: [255, 0, 0] 46 | 18: [180, 30, 80] 47 | 20: [255, 0, 0] 48 | 30: [30, 30, 255] 49 | 31: [200, 40, 255] 50 | 32: [90, 30, 150] 51 | 40: [255, 0, 255] 52 | 44: [255, 150, 255] 53 | 48: [75, 0, 75] 54 | 49: [75, 0, 175] 55 | 50: [0, 200, 255] 56 | 51: [50, 120, 255] 57 | 52: [0, 150, 255] 58 | 60: [170, 255, 150] 59 | 70: [0, 175, 0] 60 | 71: [0, 60, 135] 61 | 72: [80, 240, 150] 62 | 80: [150, 240, 255] 63 | 81: [0, 0, 255] 64 | 99: [255, 255, 50] 65 | 252: [245, 150, 100] 66 | 256: [255, 0, 0] 67 | 253: [200, 40, 255] 68 | 254: [30, 30, 255] 69 | 255: [90, 30, 150] 70 | 257: [250, 80, 100] 71 | 258: [180, 30, 80] 72 | 259: [255, 0, 0] 73 | content: # as a ratio with the total number of points 74 | 0: 0.018889854628292943 75 | 1: 0.0002937197336781505 76 | 10: 0.040818519255974316 77 | 11: 0.00016609538710764618 78 | 13: 2.7879693665067774e-05 79 | 15: 0.00039838616015114444 80 | 16: 0.0 81 | 18: 0.0020633612104619787 82 | 20: 0.0016218197275284021 83 | 30: 0.00017698551338515307 84 | 31: 1.1065903904919655e-08 85 | 32: 5.532951952459828e-09 86 | 40: 0.1987493871255525 87 | 44: 0.014717169549888214 88 | 48: 0.14392298360372 89 | 49: 0.0039048553037472045 90 | 50: 0.1326861944777486 91 | 51: 0.0723592229456223 92 | 52: 0.002395131480328884 93 | 60: 4.7084144280367186e-05 94 | 70: 0.26681502148037506 95 | 71: 0.006035012012626033 96 | 72: 0.07814222006271769 97 | 80: 0.002855498193863172 98 | 81: 0.0006155958086189918 99 | 99: 0.009923127583046915 100 | 252: 0.001789309418528068 101 | 253: 0.00012709999297008662 102 | 254: 0.00016059776092534436 103 | 255: 3.745553104802113e-05 104 | 256: 0.0 105 | 257: 0.00011351574470342043 106 | 258: 0.00010157861367183268 107 | 259: 4.3840131989471124e-05 108 | # classes that are indistinguishable from single scan or inconsistent in 109 | # ground truth are mapped to their closest equivalent 110 | learning_map: 111 | 0: 0 # "unlabeled" 112 | 1: 0 # "outlier" mapped to "unlabeled" --------------------------mapped 113 | 10: 1 # "car" 114 | 11: 2 # "bicycle" 115 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 116 | 15: 3 # "motorcycle" 117 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 118 | 18: 4 # "truck" 119 | 20: 5 # "other-vehicle" 120 | 30: 6 # "person" 121 | 31: 7 # "bicyclist" 122 | 32: 8 # "motorcyclist" 123 | 40: 9 # "road" 124 | 44: 10 # "parking" 125 | 48: 11 # "sidewalk" 126 | 49: 12 # "other-ground" 127 | 50: 13 # "building" 128 | 51: 14 # "fence" 129 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 130 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 131 | 70: 15 # "vegetation" 132 | 71: 16 # "trunk" 133 | 72: 17 # "terrain" 134 | 80: 18 # "pole" 135 | 81: 19 # "traffic-sign" 136 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 137 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 138 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 139 | 254: 6 # "moving-person" to "person" ------------------------------mapped 140 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 141 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 142 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 143 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 144 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 145 | learning_map_inv: # inverse of previous map 146 | 0: 0 # "unlabeled", and others ignored 147 | 1: 10 # "car" 148 | 2: 11 # "bicycle" 149 | 3: 15 # "motorcycle" 150 | 4: 18 # "truck" 151 | 5: 20 # "other-vehicle" 152 | 6: 30 # "person" 153 | 7: 31 # "bicyclist" 154 | 8: 32 # "motorcyclist" 155 | 9: 40 # "road" 156 | 10: 44 # "parking" 157 | 11: 48 # "sidewalk" 158 | 12: 49 # "other-ground" 159 | 13: 50 # "building" 160 | 14: 51 # "fence" 161 | 15: 70 # "vegetation" 162 | 16: 71 # "trunk" 163 | 17: 72 # "terrain" 164 | 18: 80 # "pole" 165 | 19: 81 # "traffic-sign" 166 | learning_ignore: # Ignore classes 167 | 0: True # "unlabeled", and others ignored 168 | 1: False # "car" 169 | 2: False # "bicycle" 170 | 3: False # "motorcycle" 171 | 4: False # "truck" 172 | 5: False # "other-vehicle" 173 | 6: False # "person" 174 | 7: False # "bicyclist" 175 | 8: False # "motorcyclist" 176 | 9: False # "road" 177 | 10: False # "parking" 178 | 11: False # "sidewalk" 179 | 12: False # "other-ground" 180 | 13: False # "building" 181 | 14: False # "fence" 182 | 15: False # "vegetation" 183 | 16: False # "trunk" 184 | 17: False # "terrain" 185 | 18: False # "pole" 186 | 19: False # "traffic-sign" 187 | split: # sequence numbers 188 | train: 189 | - 0 190 | - 1 191 | - 2 192 | - 3 193 | - 4 194 | - 5 195 | - 6 196 | - 7 197 | - 9 198 | - 10 199 | valid: 200 | - 8 201 | test: 202 | - 11 203 | - 12 204 | - 13 205 | - 14 206 | - 15 207 | - 16 208 | - 17 209 | - 18 210 | - 19 211 | - 20 212 | - 21 213 | -------------------------------------------------------------------------------- /config/post-processing.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | # Developed by: Xieyuanli Chen 3 | # Configuration (specifying all the paths) 4 | # -------------------------------------------------------------------- 5 | 6 | # Inputs 7 | # the root of raw LiDAR scans 8 | scan_root: 'DATAROOT' 9 | 10 | # the root of mos predictions 11 | mos_pred_root: "./log/Valid/predictions" 12 | 13 | # the root of semantic predictions 14 | semantic_pred_root: './log/Valid/predictions' 15 | 16 | # Outputs 17 | split: valid # choose from (train, valid, test) 18 | combined_results_root: 'outputs/LiDAR_MOS_Prediction/moving_object_seg_semantic_filtered' 19 | clean_scan_root: 'clean_scan' 20 | 21 | 22 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mfmos 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - backcall=0.2.0=pyhd3eb1b0_0 9 | - blas=1.0=mkl 10 | - ca-certificates=2023.08.22=h06a4308_0 11 | - certifi=2022.12.7=py37h06a4308_0 12 | - cudatoolkit=11.0.221=h6bb024c_0 13 | - dbus=1.13.18=hb2f20db_0 14 | - decorator=5.1.1=pyhd3eb1b0_0 15 | - expat=2.4.9=h6a678d5_0 16 | - flit-core=3.6.0=pyhd3eb1b0_0 17 | - fontconfig=2.14.1=h52c9d5c_1 18 | - freetype=2.12.1=h4a9f257_0 19 | - giflib=5.2.1=h5eee18b_3 20 | - glib=2.69.1=h4ff587b_1 21 | - gst-plugins-base=1.14.1=h6a678d5_1 22 | - gstreamer=1.14.1=h5eee18b_1 23 | - icu=58.2=he6710b0_3 24 | - importlib-metadata=4.11.3=py37h06a4308_0 25 | - importlib_metadata=4.11.3=hd3eb1b0_0 26 | - intel-openmp=2021.4.0=h06a4308_3561 27 | - ipython=7.31.1=py37h06a4308_1 28 | - jedi=0.18.1=py37h06a4308_1 29 | - jpeg=9e=h5eee18b_1 30 | - lcms2=2.12=h3be6417_0 31 | - ld_impl_linux-64=2.38=h1181459_1 32 | - lerc=3.0=h295c915_0 33 | - libdeflate=1.17=h5eee18b_0 34 | - libffi=3.3=he6710b0_2 35 | - libgcc-ng=11.2.0=h1234567_1 36 | - libgomp=11.2.0=h1234567_1 37 | - libllvm11=11.1.0=h9e868ea_6 38 | - libpng=1.6.39=h5eee18b_0 39 | - libstdcxx-ng=11.2.0=h1234567_1 40 | - libtiff=4.5.1=h6a678d5_0 41 | - libuuid=1.41.5=h5eee18b_0 42 | - libuv=1.44.2=h5eee18b_0 43 | - libwebp=1.2.4=h11a3e52_1 44 | - libwebp-base=1.2.4=h5eee18b_1 45 | - libxcb=1.15=h7f8727e_0 46 | - libxml2=2.9.14=h74e7548_0 47 | - llvmlite=0.39.1=py37he621ea3_0 48 | - lz4-c=1.9.4=h6a678d5_0 49 | - matplotlib-inline=0.1.6=py37h06a4308_0 50 | - mkl=2021.4.0=h06a4308_640 51 | - mkl-service=2.4.0=py37h7f8727e_0 52 | - mkl_fft=1.3.1=py37hd3c417c_0 53 | - mkl_random=1.2.2=py37h51133e4_0 54 | - ncurses=6.4=h6a678d5_0 55 | - ninja=1.10.2=h06a4308_5 56 | - ninja-base=1.10.2=hd09550d_5 57 | - numba=0.56.4=py37h417a72b_0 58 | - openssl=1.1.1w=h7f8727e_0 59 | - parso=0.8.3=pyhd3eb1b0_0 60 | - pcre=8.45=h295c915_0 61 | - pexpect=4.8.0=pyhd3eb1b0_3 62 | - pickleshare=0.7.5=pyhd3eb1b0_1003 63 | - pip=22.3.1=py37h06a4308_0 64 | - prompt-toolkit=3.0.36=py37h06a4308_0 65 | - ptyprocess=0.7.0=pyhd3eb1b0_2 66 | - pygments=2.11.2=pyhd3eb1b0_0 67 | - pyqt=5.9.2=py37h05f1152_2 68 | - python=3.7.7=hcff3b4d_5 69 | - pytorch=1.7.0=py3.7_cuda11.0.221_cudnn8.0.3_0 70 | - qt=5.9.7=h5867ecd_1 71 | - readline=8.2=h5eee18b_0 72 | - setuptools=65.6.3=py37h06a4308_0 73 | - sip=4.19.8=py37hf484d3e_0 74 | - six=1.16.0=pyhd3eb1b0_1 75 | - sqlite=3.41.2=h5eee18b_0 76 | - tbb=2021.8.0=hdb19cb5_0 77 | - tk=8.6.12=h1ccaba5_0 78 | - torchvision=0.8.0=py37_cu110 79 | - traitlets=5.7.1=py37h06a4308_0 80 | - typing_extensions=4.4.0=py37h06a4308_0 81 | - wcwidth=0.2.5=pyhd3eb1b0_0 82 | - 
wheel=0.38.4=py37h06a4308_0 83 | - xz=5.4.2=h5eee18b_0 84 | - zipp=3.11.0=py37h06a4308_0 85 | - zlib=1.2.13=h5eee18b_0 86 | - zstd=1.5.5=hc292b87_0 87 | - pip: 88 | - absl-py==1.4.0 89 | - addict==2.4.0 90 | - asttokens==2.2.1 91 | - astunparse==1.6.3 92 | - attrs==23.1.0 93 | - cachetools==4.2.4 94 | - charset-normalizer==3.1.0 95 | - click==8.1.3 96 | - colorama==0.4.6 97 | - configargparse==1.5.3 98 | - cycler==0.11.0 99 | - cython==0.29.26 100 | - dash==2.9.3 101 | - dash-core-components==2.0.0 102 | - dash-html-components==2.0.0 103 | - dash-table==5.0.0 104 | - dataclasses==0.6 105 | - debugpy==1.6.7 106 | - easydict==1.9 107 | - entrypoints==0.4 108 | - executing==1.2.0 109 | - fastjsonschema==2.16.3 110 | - flask==2.2.5 111 | - flatbuffers==23.5.26 112 | - fonttools==4.38.0 113 | - freetype-py==2.4.0 114 | - future==0.18.3 115 | - gast==0.4.0 116 | - google-auth==1.35.0 117 | - google-auth-oauthlib==0.4.6 118 | - google-pasta==0.2.0 119 | - grpcio==1.54.2 120 | - h5py==3.8.0 121 | - hsluv==5.0.3 122 | - icecream==2.1.3 123 | - idna==3.4 124 | - importlib-resources==5.12.0 125 | - ipykernel==6.16.2 126 | - ipywidgets==8.0.6 127 | - itsdangerous==2.1.2 128 | - jinja2==3.1.2 129 | - joblib==1.2.0 130 | - jsonschema==4.17.3 131 | - jupyter-client==7.4.9 132 | - jupyter-core==4.12.0 133 | - jupyterlab-widgets==3.0.7 134 | - keras==2.11.0 135 | - kiwisolver==1.4.4 136 | - libclang==16.0.6 137 | - markdown==3.4.3 138 | - markupsafe==2.1.2 139 | - matplotlib==3.5.3 140 | - nbformat==5.7.0 141 | - nest-asyncio==1.5.6 142 | - nose==1.3.7 143 | - numpy==1.21.6 144 | - oauthlib==3.2.2 145 | - open3d==0.17.0 146 | - opencv-contrib-python==4.5.1.48 147 | - opencv-python==4.5.1.48 148 | - opt-einsum==3.3.0 149 | - packaging==23.1 150 | - pandas==1.3.5 151 | - pillow==9.5.0 152 | - pkgutil-resolve-name==1.3.10 153 | - plotly==5.14.1 154 | - protobuf==3.19.6 155 | - psutil==5.9.5 156 | - pyasn1==0.5.0 157 | - pyasn1-modules==0.3.0 158 | - pyparsing==3.0.9 159 | - pyquaternion==0.9.9 160 | - pyrsistent==0.19.3 161 | - python-dateutil==2.8.2 162 | - pytz==2023.3 163 | - pyyaml==6.0 164 | - pyzmq==25.0.2 165 | - requests==2.30.0 166 | - requests-oauthlib==1.3.1 167 | - rsa==4.9 168 | - scikit-learn==0.24.2 169 | - scipy==1.7.3 170 | - strictyaml==1.4.4 171 | - tenacity==8.2.2 172 | - tensorboard==2.11.2 173 | - tensorboard-data-server==0.6.1 174 | - tensorboard-plugin-wit==1.8.1 175 | - tensorboardx==2.1 176 | - tensorflow==2.11.0 177 | - tensorflow-estimator==2.11.0 178 | - tensorflow-io-gcs-filesystem==0.33.0 179 | - termcolor==2.3.0 180 | - threadpoolctl==3.1.0 181 | - torchinfo==1.7.2 182 | - tornado==6.2 183 | - tqdm==4.65.0 184 | - urllib3==2.0.2 185 | - vispy==0.7.0 186 | - werkzeug==2.2.3 187 | - widgetsnbextension==4.0.7 188 | - wrapt==1.15.0 189 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 
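The environment above can be recreated with conda (conda env create -f environment.yml) and pins Python 3.7 with PyTorch 1.7.0 built against CUDA 11.0. Note that torchsparse, which modules/user_refine.py and the PointRefine/SIEM stage import, does not appear in this list and apparently has to be installed separately.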
3 | 4 | import os 5 | from utils.utils import * 6 | from modules.user import * 7 | from modules.user_refine import * 8 | 9 | if __name__ == '__main__': 10 | 11 | parser = get_args(flags="infer") 12 | FLAGS, unparsed = parser.parse_known_args() 13 | 14 | print("----------") 15 | print("INTERFACE:") 16 | print(" dataset", FLAGS.dataset) 17 | print(" log", FLAGS.log) 18 | print(" model", FLAGS.model) 19 | print(" infering", FLAGS.split) 20 | print(" pointrefine", FLAGS.pointrefine) 21 | print(" save movable", FLAGS.movable) 22 | print("----------\n") 23 | 24 | # open arch / data config file 25 | ARCH = load_yaml(FLAGS.model + "/arch_cfg.yaml") 26 | DATA = load_yaml(FLAGS.model + "/data_cfg.yaml") 27 | 28 | make_predictions_dir(FLAGS, DATA, save_movable=FLAGS.movable) # create predictions file folder 29 | check_model_dir(FLAGS.model) # does model folder exist? 30 | 31 | # create user and infer dataset 32 | if not FLAGS.pointrefine: 33 | user = User(ARCH, DATA, datadir=FLAGS.dataset, outputdir=FLAGS.log, 34 | modeldir=FLAGS.model, split=FLAGS.split, save_movable=FLAGS.movable) 35 | else: 36 | user = UserRefine(ARCH, DATA, datadir=FLAGS.dataset, outputdir=FLAGS.log, 37 | modeldir=FLAGS.model, split=FLAGS.split, save_movable=FLAGS.movable) 38 | user.infer() 39 | -------------------------------------------------------------------------------- /modules/KNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def get_gaussian_kernel(kernel_size=3, sigma=2, channels=1): 12 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 13 | x_coord = torch.arange(kernel_size) 14 | x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) 15 | y_grid = x_grid.t() 16 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 17 | 18 | mean = (kernel_size - 1) / 2. 19 | variance = sigma ** 2. 20 | 21 | # Calculate the 2-dimensional gaussian kernel which is 22 | # the product of two gaussian distributions for two different 23 | # variables (in this case called x and y) 24 | gaussian_kernel = (1. / (2. * math.pi * variance)) * \ 25 | torch.exp(-torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance)) 26 | 27 | # Make sure sum of values in gaussian kernel equals 1. 28 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 29 | 30 | # Reshape to 2d depthwise convolutional weight 31 | gaussian_kernel = gaussian_kernel.view(kernel_size, kernel_size) 32 | 33 | return gaussian_kernel 34 | 35 | 36 | class KNN(nn.Module): 37 | def __init__(self, params, nclasses): 38 | super().__init__() 39 | print("*" * 80) 40 | print("Cleaning point-clouds with kNN post-processing") 41 | self.knn = params["knn"] 42 | self.search = params["search"] 43 | self.sigma = params["sigma"] 44 | self.cutoff = params["cutoff"] 45 | self.nclasses = nclasses 46 | print("kNN parameters:") 47 | print("knn:", self.knn) 48 | print("search:", self.search) 49 | print("sigma:", self.sigma) 50 | print("cutoff:", self.cutoff) 51 | print("nclasses:", self.nclasses) 52 | print("*" * 80) 53 | 54 | def forward(self, proj_range, unproj_range, proj_argmax, px, py): 55 | ''' Warning! Only works for un-batched pointclouds. 
56 | If they come batched we need to iterate over the batch dimension or do 57 | something REALLY smart to handle unaligned number of points in memory 58 | ''' 59 | # get device 60 | if proj_range.is_cuda: 61 | device = torch.device("cuda") 62 | else: 63 | device = torch.device("cpu") 64 | 65 | # sizes of projection scan 66 | H, W = proj_range.shape 67 | 68 | # number of points 69 | P = unproj_range.shape 70 | 71 | # check if size of kernel is odd and complain 72 | if (self.search % 2 == 0): 73 | raise ValueError("Nearest neighbor kernel must be odd number") 74 | 75 | # calculate padding 76 | pad = int((self.search - 1) / 2) 77 | 78 | # unfold neighborhood to get nearest neighbors for each pixel (range image) 79 | proj_unfold_k_rang = F.unfold(proj_range[None, None, ...], 80 | kernel_size=(self.search, self.search), 81 | padding=(pad, pad)) 82 | 83 | # index with px, py to get ALL the pcld points 84 | idx_list = py * W + px 85 | unproj_unfold_k_rang = proj_unfold_k_rang[:, :, idx_list] 86 | 87 | # WARNING, THIS IS A HACK 88 | # Make non valid (<0) range points extremely big so that there is no screwing 89 | # up the nn self.search 90 | unproj_unfold_k_rang[unproj_unfold_k_rang < 0] = float("inf") 91 | 92 | # now the matrix is unfolded TOTALLY, replace the middle points with the actual range points 93 | center = int(((self.search * self.search) - 1) / 2) 94 | unproj_unfold_k_rang[:, center, :] = unproj_range 95 | 96 | # now compare range 97 | k2_distances = torch.abs(unproj_unfold_k_rang - unproj_range) 98 | 99 | # make a kernel to weigh the ranges according to distance in (x,y) 100 | # I make this 1 - kernel because I want distances that are close in (x,y) 101 | # to matter more 102 | inv_gauss_k = ( 103 | 1 - get_gaussian_kernel(self.search, self.sigma, 1)).view(1, -1, 1) 104 | inv_gauss_k = inv_gauss_k.to(device).type(proj_range.type()) 105 | 106 | # apply weighing 107 | k2_distances = k2_distances * inv_gauss_k 108 | 109 | # find nearest neighbors 110 | _, knn_idx = k2_distances.topk( 111 | self.knn, dim=1, largest=False, sorted=False) 112 | 113 | # do the same unfolding with the argmax 114 | proj_unfold_1_argmax = F.unfold(proj_argmax[None, None, ...].float(), 115 | kernel_size=(self.search, self.search), 116 | padding=(pad, pad)).long() 117 | unproj_unfold_1_argmax = proj_unfold_1_argmax[:, :, idx_list] 118 | 119 | # get the top k predictions from the knn at each pixel 120 | knn_argmax = torch.gather( 121 | input=unproj_unfold_1_argmax, dim=1, index=knn_idx) 122 | 123 | # fake an invalid argmax of classes + 1 for all cutoff items 124 | if self.cutoff > 0: 125 | knn_distances = torch.gather(input=k2_distances, dim=1, index=knn_idx) 126 | knn_invalid_idx = knn_distances > self.cutoff 127 | knn_argmax[knn_invalid_idx] = self.nclasses 128 | 129 | # now vote 130 | # argmax onehot has an extra class for objects after cutoff 131 | knn_argmax_onehot = torch.zeros( 132 | (1, self.nclasses + 1, P[0]), device=device).type(proj_range.type()) 133 | ones = torch.ones_like(knn_argmax).type(proj_range.type()) 134 | knn_argmax_onehot = knn_argmax_onehot.scatter_add_(1, knn_argmax, ones) 135 | 136 | # now vote (as a sum over the onehot shit) (don't let it choose unlabeled OR invalid) 137 | knn_argmax_out = knn_argmax_onehot[:, 1:-1].argmax(dim=1) + 1 138 | 139 | # reshape again 140 | knn_argmax_out = knn_argmax_out.view(P) 141 | 142 | return knn_argmax_out 143 | -------------------------------------------------------------------------------- /modules/PointRefine/PointMLP.py: 
-------------------------------------------------------------------------------- 1 | # A simple MLP network structure for point clouds, 2 | # 3 | # Added by Jiadai Sun 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class PointRefine(nn.Module): 11 | 12 | def __init__(self, n_class=3, 13 | in_fea_dim=35, 14 | out_point_fea_dim=64): 15 | super(PointRefine, self).__init__() 16 | 17 | self.n_class = n_class 18 | self.PPmodel = nn.Sequential( 19 | nn.BatchNorm1d(in_fea_dim), 20 | 21 | nn.Linear(in_fea_dim, 64), 22 | nn.BatchNorm1d(64), 23 | nn.ReLU(), 24 | 25 | nn.Linear(64, 128), 26 | nn.BatchNorm1d(128), 27 | nn.ReLU(), 28 | 29 | nn.Linear(128, 256), 30 | nn.BatchNorm1d(256), 31 | nn.ReLU(), 32 | 33 | nn.Linear(256, out_point_fea_dim) 34 | ) 35 | 36 | self.logits = nn.Sequential( 37 | nn.Linear(out_point_fea_dim, self.n_class) 38 | ) 39 | 40 | def forward(self, point_fea): 41 | # the point_fea need with size (b, N, c) e.g. torch.Size([1, 121722, 35]) 42 | # process feature 43 | # torch.Size([124668, 9]) --> torch.Size([124668, 256]) 44 | processed_point_fea = self.PPmodel(point_fea) 45 | logits = self.logits(processed_point_fea) 46 | point_predict = F.softmax(logits, dim=1) 47 | return point_predict 48 | 49 | 50 | if __name__ == '__main__': 51 | 52 | import time 53 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 54 | model = PointRefine() 55 | model.train() 56 | 57 | # t0 = time.time() 58 | # pred = model(cloud) 59 | # t1 = time.time() 60 | # print(t1-t0) 61 | 62 | total = sum([param.nelement() for param in model.parameters()]) 63 | print("Number of PointRefine parameter: %.2fM" % (total/1e6)) 64 | # Number of PointRefine parameter: 0.04M 65 | -------------------------------------------------------------------------------- /modules/SalsaNext.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 
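One note on the backbone defined below: with the MOS settings used by the train_yaml configs later in this repository (residual: True, n_input_scans: 8), the input depth computed in SalsaNext.__init__ works out as follows (range, x, y, z and signal are the five base channels listed in those configs):

    # residual: True   ->  input_size = 5 + n_input_scans = 5 + 8 = 13  (5-channel image + 8 residual images)
    # residual: False  ->  input_size = 5 * n_input_scans = 5 * 8 = 40  (8 stacked 5-channel images)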
3 | import imp 4 | 5 | import __init__ as booger 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class ResContextBlock(nn.Module): 11 | def __init__(self, in_filters, out_filters): 12 | super(ResContextBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), stride=1) 14 | self.act1 = nn.LeakyReLU() 15 | 16 | self.conv2 = nn.Conv2d(out_filters, out_filters, (3,3), padding=1) 17 | self.act2 = nn.LeakyReLU() 18 | self.bn1 = nn.BatchNorm2d(out_filters) 19 | 20 | self.conv3 = nn.Conv2d(out_filters, out_filters, (3,3),dilation=2, padding=2) 21 | self.act3 = nn.LeakyReLU() 22 | self.bn2 = nn.BatchNorm2d(out_filters) 23 | 24 | 25 | def forward(self, x): 26 | 27 | shortcut = self.conv1(x) 28 | shortcut = self.act1(shortcut) 29 | 30 | resA = self.conv2(shortcut) 31 | resA = self.act2(resA) 32 | resA1 = self.bn1(resA) 33 | 34 | resA = self.conv3(resA1) 35 | resA = self.act3(resA) 36 | resA2 = self.bn2(resA) 37 | 38 | output = shortcut + resA2 39 | return output 40 | 41 | 42 | class ResBlock(nn.Module): 43 | def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3), stride=1, 44 | pooling=True, drop_out=True): 45 | super(ResBlock, self).__init__() 46 | self.pooling = pooling 47 | self.drop_out = drop_out 48 | self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), stride=stride) 49 | self.act1 = nn.LeakyReLU() 50 | 51 | self.conv2 = nn.Conv2d(in_filters, out_filters, kernel_size=(3,3), padding=1) 52 | self.act2 = nn.LeakyReLU() 53 | self.bn1 = nn.BatchNorm2d(out_filters) 54 | 55 | self.conv3 = nn.Conv2d(out_filters, out_filters, kernel_size=(3,3),dilation=2, padding=2) 56 | self.act3 = nn.LeakyReLU() 57 | self.bn2 = nn.BatchNorm2d(out_filters) 58 | 59 | self.conv4 = nn.Conv2d(out_filters, out_filters, kernel_size=(2, 2), dilation=2, padding=1) 60 | self.act4 = nn.LeakyReLU() 61 | self.bn3 = nn.BatchNorm2d(out_filters) 62 | 63 | self.conv5 = nn.Conv2d(out_filters*3, out_filters, kernel_size=(1, 1)) 64 | self.act5 = nn.LeakyReLU() 65 | self.bn4 = nn.BatchNorm2d(out_filters) 66 | 67 | if pooling: 68 | self.dropout = nn.Dropout2d(p=dropout_rate) 69 | self.pool = nn.AvgPool2d(kernel_size=kernel_size, stride=2, padding=1) 70 | else: 71 | self.dropout = nn.Dropout2d(p=dropout_rate) 72 | 73 | def forward(self, x): 74 | shortcut = self.conv1(x) 75 | shortcut = self.act1(shortcut) 76 | 77 | resA = self.conv2(x) 78 | resA = self.act2(resA) 79 | resA1 = self.bn1(resA) 80 | 81 | resA = self.conv3(resA1) 82 | resA = self.act3(resA) 83 | resA2 = self.bn2(resA) 84 | 85 | resA = self.conv4(resA2) 86 | resA = self.act4(resA) 87 | resA3 = self.bn3(resA) 88 | 89 | concat = torch.cat((resA1,resA2,resA3),dim=1) 90 | resA = self.conv5(concat) 91 | resA = self.act5(resA) 92 | resA = self.bn4(resA) 93 | resA = shortcut + resA 94 | 95 | 96 | if self.pooling: 97 | if self.drop_out: 98 | resB = self.dropout(resA) 99 | else: 100 | resB = resA 101 | resB = self.pool(resB) 102 | 103 | return resB, resA 104 | else: 105 | if self.drop_out: 106 | resB = self.dropout(resA) 107 | else: 108 | resB = resA 109 | return resB 110 | 111 | 112 | class UpBlock(nn.Module): 113 | def __init__(self, in_filters, out_filters, dropout_rate, drop_out=True): 114 | super(UpBlock, self).__init__() 115 | self.drop_out = drop_out 116 | self.in_filters = in_filters 117 | self.out_filters = out_filters 118 | 119 | self.dropout1 = nn.Dropout2d(p=dropout_rate) 120 | 121 | self.dropout2 = nn.Dropout2d(p=dropout_rate) 122 | 123 | self.conv1 = 
nn.Conv2d(in_filters//4 + 2*out_filters, out_filters, (3,3), padding=1) 124 | self.act1 = nn.LeakyReLU() 125 | self.bn1 = nn.BatchNorm2d(out_filters) 126 | 127 | self.conv2 = nn.Conv2d(out_filters, out_filters, (3,3),dilation=2, padding=2) 128 | self.act2 = nn.LeakyReLU() 129 | self.bn2 = nn.BatchNorm2d(out_filters) 130 | 131 | self.conv3 = nn.Conv2d(out_filters, out_filters, (2,2), dilation=2,padding=1) 132 | self.act3 = nn.LeakyReLU() 133 | self.bn3 = nn.BatchNorm2d(out_filters) 134 | 135 | 136 | self.conv4 = nn.Conv2d(out_filters*3,out_filters,kernel_size=(1,1)) 137 | self.act4 = nn.LeakyReLU() 138 | self.bn4 = nn.BatchNorm2d(out_filters) 139 | 140 | self.dropout3 = nn.Dropout2d(p=dropout_rate) 141 | 142 | def forward(self, x, skip): 143 | upA = nn.PixelShuffle(2)(x) 144 | if self.drop_out: 145 | upA = self.dropout1(upA) 146 | 147 | upB = torch.cat((upA,skip),dim=1) 148 | if self.drop_out: 149 | upB = self.dropout2(upB) 150 | 151 | upE = self.conv1(upB) 152 | upE = self.act1(upE) 153 | upE1 = self.bn1(upE) 154 | 155 | upE = self.conv2(upE1) 156 | upE = self.act2(upE) 157 | upE2 = self.bn2(upE) 158 | 159 | upE = self.conv3(upE2) 160 | upE = self.act3(upE) 161 | upE3 = self.bn3(upE) 162 | 163 | concat = torch.cat((upE1,upE2,upE3),dim=1) 164 | upE = self.conv4(concat) 165 | upE = self.act4(upE) 166 | upE = self.bn4(upE) 167 | if self.drop_out: 168 | upE = self.dropout3(upE) 169 | 170 | return upE 171 | 172 | 173 | class SalsaNext(nn.Module): 174 | def __init__(self, nclasses, params): 175 | super(SalsaNext, self).__init__() 176 | self.nclasses = nclasses 177 | 178 | ### mos modification 179 | if params['train']['residual']: 180 | self.input_size = 5 + params['train']['n_input_scans'] 181 | 182 | else: 183 | self.input_size = 5 * params['train']['n_input_scans'] 184 | 185 | print("Depth of backbone input = ", self.input_size) 186 | ### 187 | 188 | self.downCntx = ResContextBlock(self.input_size, 32) 189 | self.downCntx2 = ResContextBlock(32, 32) 190 | self.downCntx3 = ResContextBlock(32, 32) 191 | 192 | self.resBlock1 = ResBlock(32, 2 * 32, 0.2, pooling=True, drop_out=False) 193 | self.resBlock2 = ResBlock(2 * 32, 2 * 2 * 32, 0.2, pooling=True) 194 | self.resBlock3 = ResBlock(2 * 2 * 32, 2 * 4 * 32, 0.2, pooling=True) 195 | self.resBlock4 = ResBlock(2 * 4 * 32, 2 * 4 * 32, 0.2, pooling=True) 196 | self.resBlock5 = ResBlock(2 * 4 * 32, 2 * 4 * 32, 0.2, pooling=False) 197 | 198 | self.upBlock1 = UpBlock(2 * 4 * 32, 4 * 32, 0.2) 199 | self.upBlock2 = UpBlock(4 * 32, 4 * 32, 0.2) 200 | self.upBlock3 = UpBlock(4 * 32, 2 * 32, 0.2) 201 | self.upBlock4 = UpBlock(2 * 32, 32, 0.2, drop_out=False) 202 | 203 | self.logits = nn.Conv2d(32, nclasses, kernel_size=(1, 1)) 204 | 205 | def forward(self, x): 206 | downCntx = self.downCntx(x) 207 | downCntx = self.downCntx2(downCntx) 208 | downCntx = self.downCntx3(downCntx) 209 | 210 | down0c, down0b = self.resBlock1(downCntx) 211 | down1c, down1b = self.resBlock2(down0c) 212 | down2c, down2b = self.resBlock3(down1c) 213 | down3c, down3b = self.resBlock4(down2c) 214 | down5c = self.resBlock5(down3c) 215 | 216 | up4e = self.upBlock1(down5c,down3b) 217 | up3e = self.upBlock2(up4e, down2b) 218 | up2e = self.upBlock3(up3e, down1b) 219 | up1e = self.upBlock4(up2e, down0b) 220 | logits = self.logits(up1e) 221 | 222 | logits = logits 223 | logits = F.softmax(logits, dim=1) 224 | return logits -------------------------------------------------------------------------------- /modules/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/MF-MOS/0c702445a39b978efc107cf7d0a2a33246f857ba/modules/__init__.py -------------------------------------------------------------------------------- /modules/loss/DiceLoss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | import numpy 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | #PyTorch 10 | # class DiceLoss(nn.Module): 11 | # def __init__(self, weight=None, size_average=True): 12 | # super(DiceLoss, self).__init__() 13 | 14 | # def forward(self, inputs, targets, smooth=1): 15 | 16 | # #comment out if your model contains a sigmoid or equivalent activation layer 17 | # inputs = F.sigmoid(inputs) 18 | 19 | # #flatten label and prediction tensors 20 | # inputs = inputs.view(-1) 21 | # targets = targets.view(-1) 22 | 23 | # intersection = (inputs * targets).sum() 24 | # dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 25 | 26 | # return 1 - dice 27 | 28 | # https://smp.readthedocs.io/en/latest/losses.html 29 | # https://github.com/pytorch/pytorch/issues/1249 30 | # https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch#Dice-Loss 31 | # https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/losses/dice.html 32 | 33 | 34 | # based on: 35 | # https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py 36 | 37 | class DiceLoss(nn.Module): 38 | r"""Criterion that computes Sørensen-Dice Coefficient loss. 39 | 40 | According to [1], we compute the Sørensen-Dice Coefficient as follows: 41 | 42 | .. math:: 43 | 44 | \text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|} 45 | 46 | where: 47 | - :math:`X` expects to be the scores of each class. 48 | - :math:`Y` expects to be the one-hot tensor with the class labels. 49 | 50 | the loss, is finally computed as: 51 | 52 | .. math:: 53 | 54 | \text{loss}(x, class) = 1 - \text{Dice}(x, class) 55 | 56 | [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient 57 | 58 | Shape: 59 | - Input: :math:`(N, C, H, W)` where C = number of classes. 60 | - Target: :math:`(N, H, W)` where each value is 61 | :math:`0 ≤ targets[i] ≤ C−1`. 62 | 63 | Examples: 64 | >>> N = 5 # num_classes 65 | >>> loss = tgm.losses.DiceLoss() 66 | >>> input = torch.randn(1, N, 3, 5, requires_grad=True) 67 | >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N) 68 | >>> output = loss(input, target) 69 | >>> output.backward() 70 | """ 71 | 72 | def __init__(self) -> None: 73 | super(DiceLoss, self).__init__() 74 | self.eps: float = 1e-6 75 | 76 | def forward(self, input: torch.Tensor, 77 | target: torch.Tensor) -> torch.Tensor: 78 | if not torch.is_tensor(input): 79 | raise TypeError("Input type is not a torch.Tensor. Got {}" 80 | .format(type(input))) 81 | if not len(input.shape) == 4: 82 | raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}" 83 | .format(input.shape)) 84 | if not input.shape[-2:] == target.shape[-2:]: 85 | raise ValueError("input and target shapes must be the same. Got: {}" 86 | .format(input.shape, input.shape)) 87 | if not input.device == target.device: 88 | raise ValueError( 89 | "input and target must be in the same device. 
Got: {}" .format( 90 | input.device, target.device)) 91 | # compute softmax over the classes axis 92 | # input_soft = F.softmax(input, dim=1) # have done is network last layer 93 | 94 | # create the labels one hot tensor 95 | # target_one_hot = one_hot(target, num_classes=input.shape[1], 96 | # device=input.device, dtype=input.dtype) 97 | target_one_hot = F.one_hot(target, num_classes=input.shape[1]).permute(0, 3, 1, 2) 98 | 99 | # compute the actual dice score 100 | dims = (1, 2, 3) 101 | # intersection = torch.sum(input_soft * target_one_hot, dims) 102 | # cardinality = torch.sum(input_soft + target_one_hot, dims) 103 | 104 | ## if we need to ignore the class=0 105 | input_filter = input[:, 1:, :, :] 106 | target_one_hot_filter = input[:, 1:, :, :] 107 | intersection = torch.sum(input_filter * target_one_hot_filter, dims) 108 | cardinality = torch.sum(input_filter + target_one_hot_filter, dims) 109 | 110 | dice_score = 2. * intersection / (cardinality + self.eps) 111 | return torch.mean(1. - dice_score) 112 | 113 | 114 | 115 | ###################### 116 | # functional interface 117 | ###################### 118 | 119 | 120 | def dice_loss(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 121 | r"""Function that computes Sørensen-Dice Coefficient loss. 122 | 123 | See :class:`~torchgeometry.losses.DiceLoss` for details. 124 | """ 125 | return DiceLoss()(input, target) 126 | 127 | -------------------------------------------------------------------------------- /modules/loss/Lovasz_Softmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | MIT License 4 | 5 | Copyright (c) 2018 Maxim Berman 6 | Copyright (c) 2020 Tiago Cortinhal, George Tzelepis and Eren Erdal Aksoy 7 | 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | """ 20 | import torch 21 | import torch.nn as nn 22 | from torch.autograd import Variable 23 | 24 | 25 | try: 26 | from itertools import ifilterfalse 27 | except ImportError: 28 | from itertools import filterfalse as ifilterfalse 29 | 30 | 31 | def isnan(x): 32 | return x != x 33 | 34 | 35 | def mean(l, ignore_nan=False, empty=0): 36 | """ 37 | nanmean compatible with generators. 38 | """ 39 | l = iter(l) 40 | if ignore_nan: 41 | l = ifilterfalse(isnan, l) 42 | try: 43 | n = 1 44 | acc = next(l) 45 | except StopIteration: 46 | if empty == 'raise': 47 | raise ValueError('Empty mean') 48 | return empty 49 | for n, v in enumerate(l, 2): 50 | acc += v 51 | if n == 1: 52 | return acc 53 | return acc / n 54 | 55 | 56 | def lovasz_grad(gt_sorted): 57 | """ 58 | Computes gradient of the Lovasz extension w.r.t sorted errors 59 | See Alg. 1 in paper 60 | """ 61 | p = len(gt_sorted) 62 | gts = gt_sorted.sum() 63 | intersection = gts - gt_sorted.float().cumsum(0) 64 | union = gts + (1 - gt_sorted).float().cumsum(0) 65 | jaccard = 1. 
- intersection / union 66 | if p > 1: # cover 1-pixel case 67 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 68 | return jaccard 69 | 70 | 71 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 72 | """ 73 | Multi-class Lovasz-Softmax loss 74 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 75 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 76 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 77 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 78 | per_image: compute the loss per image instead of per batch 79 | ignore: void class labels 80 | """ 81 | if per_image: 82 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 83 | for prob, lab in zip(probas, labels)) 84 | else: 85 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 86 | return loss 87 | 88 | 89 | def lovasz_softmax_flat(probas, labels, classes='present'): 90 | """ 91 | Multi-class Lovasz-Softmax loss 92 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 93 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 94 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 95 | """ 96 | if probas.numel() == 0: 97 | # only void pixels, the gradients should be 0 98 | return probas * 0. 99 | C = probas.size(1) 100 | losses = [] 101 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 102 | for c in class_to_sum: 103 | fg = (labels == c).float() # foreground for class c 104 | if (classes == 'present' and fg.sum() == 0): 105 | continue 106 | if C == 1: 107 | if len(classes) > 1: 108 | raise ValueError('Sigmoid output possible only with 1 class') 109 | class_pred = probas[:, 0] 110 | else: 111 | class_pred = probas[:, c] 112 | errors = (Variable(fg) - class_pred).abs() 113 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 114 | perm = perm.data 115 | fg_sorted = fg[perm] 116 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 117 | return mean(losses) 118 | 119 | 120 | def flatten_probas(probas, labels, ignore=None): 121 | """ 122 | Flattens predictions in the batch 123 | """ 124 | if probas.dim() == 3: 125 | # assumes output of a sigmoid layer 126 | B, H, W = probas.size() 127 | probas = probas.view(B, 1, H, W) 128 | B, C, H, W = probas.size() 129 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 130 | labels = labels.view(-1) 131 | if ignore is None: 132 | return probas, labels 133 | valid = (labels != ignore) 134 | vprobas = probas[valid.nonzero(as_tuple=False).squeeze()] 135 | vlabels = labels[valid] 136 | return vprobas, vlabels 137 | 138 | 139 | class Lovasz_softmax(nn.Module): 140 | def __init__(self, classes='present', per_image=False, ignore=None): 141 | super(Lovasz_softmax, self).__init__() 142 | self.classes = classes 143 | self.per_image = per_image 144 | self.ignore = ignore 145 | 146 | def forward(self, probas, labels): 147 | return lovasz_softmax(probas, labels, self.classes, self.per_image, self.ignore) 148 | 149 | 150 | # Used to calculate Lovasz Loss with point cloud as input 151 | # Add by Jiadai Sun 152 | class Lovasz_softmax_PointCloud(nn.Module): 153 | def __init__(self, classes='present', ignore=None): 154 | super(Lovasz_softmax_PointCloud, self).__init__() 155 | 
self.classes = classes 156 | self.ignore = ignore 157 | 158 | def forward(self, probas, labels): 159 | 160 | B, C, N = probas.size() 161 | probas = probas.permute(0, 2, 1).contiguous().view(-1, C) 162 | labels = labels.view(-1) 163 | if self.ignore is not None: 164 | valid = (labels != self.ignore) 165 | vprobas = probas[valid.nonzero(as_tuple=False).squeeze()] 166 | vlabels = labels[valid] 167 | return lovasz_softmax_flat(vprobas, vlabels, classes=self.classes) 168 | else: 169 | return lovasz_softmax_flat(probas, labels, classes=self.classes) 170 | -------------------------------------------------------------------------------- /modules/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/MF-MOS/0c702445a39b978efc107cf7d0a2a33246f857ba/modules/loss/__init__.py -------------------------------------------------------------------------------- /modules/tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | import torch 5 | import numpy as np 6 | import cv2 7 | from matplotlib import pyplot as plt 8 | 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | # def one_hot_pred_from_label(y_pred, labels): 30 | # y_true = torch.zeros_like(y_pred) 31 | # ones = torch.ones_like(y_pred) 32 | # indexes = [l for l in labels] 33 | # y_true[torch.arange(labels.size(0)), indexes] = ones[torch.arange( 34 | # labels.size(0)), indexes] 35 | # return y_true 36 | 37 | 38 | # def keep_variance_fn(x): 39 | # return x + 1e-3 40 | 41 | 42 | # class SoftmaxHeteroscedasticLoss(torch.nn.Module): 43 | # def __init__(self): 44 | # super(SoftmaxHeteroscedasticLoss, self).__init__() 45 | # self.adf_softmax = adf.Softmax( 46 | # dim=1, keep_variance_fn=keep_variance_fn) 47 | 48 | # def forward(self, outputs, targets, eps=1e-5): 49 | # mean, var = self.adf_softmax(*outputs) 50 | # targets = torch.nn.functional.one_hot( 51 | # targets, num_classes=20).permute(0, 3, 1, 2).float() 52 | 53 | # precision = 1 / (var + eps) 54 | # return torch.mean(0.5 * precision * (targets - mean) ** 2 + 0.5 * torch.log(var + eps)) 55 | 56 | 57 | def save_to_txtlog(logdir, logfile, message): 58 | f = open(logdir + '/' + logfile, "a") 59 | f.write(message + '\n') 60 | f.close() 61 | return 62 | 63 | 64 | def save_checkpoint(to_save, logdir, suffix=""): 65 | # Save the weights 66 | torch.save(to_save, logdir + 67 | "/MFMOS" + suffix) 68 | 69 | 70 | def get_mpl_colormap(cmap_name): 71 | cmap = plt.get_cmap(cmap_name) 72 | # Initialize the matplotlib color map 73 | sm = plt.cm.ScalarMappable(cmap=cmap) 74 | # Obtain linear color range 75 | color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1] 76 | return color_range.reshape(256, 1, 3) 77 | 78 | 79 | def make_log_img(depth, mask, pred, gt, color_fn, movable=False): 80 | # input should be [depth, pred, gt] 81 | # make range image (normalized to 0,1 for saving) 82 | depth = (cv2.normalize(depth, None, alpha=0, beta=1, 83 | norm_type=cv2.NORM_MINMAX, 84 | dtype=cv2.CV_32F) * 255.0).astype(np.uint8) 85 | 
out_img = cv2.applyColorMap( 86 | depth, get_mpl_colormap('viridis')) * mask[..., None] 87 | # make label prediction 88 | pred_color = color_fn((pred * mask).astype(np.int32), movable=movable) 89 | out_img = np.concatenate([out_img, pred_color], axis=0) 90 | # make label gt 91 | gt_color = color_fn(gt, movable=movable) 92 | out_img = np.concatenate([out_img, gt_color], axis=0) 93 | return (out_img).astype(np.uint8) 94 | 95 | def show_scans_in_training(proj_mask, in_vol, argmax, proj_labels, color_fn, movable=False): 96 | # get the first scan in batch and project points 97 | mask_np = proj_mask[0].cpu().numpy() 98 | depth_np = in_vol[0][0].cpu().numpy() 99 | pred_np = argmax[0].cpu().numpy() 100 | gt_np = proj_labels[0].cpu().numpy() 101 | out = make_log_img(depth_np, mask_np, pred_np, gt_np, color_fn, movable=movable) 102 | 103 | mask_np = proj_mask[1].cpu().numpy() 104 | depth_np = in_vol[1][0].cpu().numpy() 105 | pred_np = argmax[1].cpu().numpy() 106 | gt_np = proj_labels[1].cpu().numpy() 107 | out2 = make_log_img(depth_np, mask_np, pred_np, gt_np, color_fn, movable=movable) 108 | 109 | out = np.concatenate([out, out2], axis=0) 110 | 111 | cv2.imshow("sample_training", out) 112 | cv2.waitKey(1) 113 | 114 | 115 | class iouEval: 116 | def __init__(self, n_classes, device, ignore=None): 117 | self.n_classes = n_classes 118 | self.device = device 119 | # if ignore is larger than n_classes, consider no ignoreIndex 120 | self.ignore = torch.tensor(ignore).long() 121 | self.include = torch.tensor( 122 | [n for n in range(self.n_classes) if n not in self.ignore]).long() 123 | print("[IOU EVAL] IGNORE: ", self.ignore) 124 | print("[IOU EVAL] INCLUDE: ", self.include) 125 | self.reset() 126 | 127 | def num_classes(self): 128 | return self.n_classes 129 | 130 | def reset(self): 131 | self.conf_matrix = torch.zeros( 132 | (self.n_classes, self.n_classes), device=self.device).long() 133 | self.ones = None 134 | self.last_scan_size = None # for when variable scan size is used 135 | 136 | def addBatch(self, x, y): # x=preds, y=targets 137 | # if numpy, pass to pytorch 138 | # to tensor 139 | if isinstance(x, np.ndarray): 140 | x = torch.from_numpy(np.array(x)).long().to(self.device) 141 | if isinstance(y, np.ndarray): 142 | y = torch.from_numpy(np.array(y)).long().to(self.device) 143 | 144 | # sizes should be "batch_size x H x W" 145 | x_row = x.reshape(-1) # de-batchify 146 | y_row = y.reshape(-1) # de-batchify 147 | 148 | # idxs are labels and predictions 149 | idxs = torch.stack([x_row, y_row], dim=0) 150 | 151 | # ones is what I want to add to conf when I 152 | if self.ones is None or self.last_scan_size != idxs.shape[-1]: 153 | self.ones = torch.ones((idxs.shape[-1]), device=self.device).long() 154 | self.last_scan_size = idxs.shape[-1] 155 | 156 | # make confusion matrix (cols = gt, rows = pred) 157 | self.conf_matrix = self.conf_matrix.index_put_( 158 | tuple(idxs), self.ones, accumulate=True) 159 | 160 | def getStats(self): 161 | # remove fp and fn from confusion on the ignore classes cols and rows 162 | conf = self.conf_matrix.clone().double() 163 | conf[self.ignore] = 0 164 | conf[:, self.ignore] = 0 165 | 166 | # get the clean stats 167 | tp = conf.diag() 168 | fp = conf.sum(dim=1) - tp 169 | fn = conf.sum(dim=0) - tp 170 | return tp, fp, fn 171 | 172 | def getIoU(self): 173 | tp, fp, fn = self.getStats() 174 | intersection = tp 175 | union = tp + fp + fn + 1e-15 176 | iou = intersection / union 177 | iou_mean = (intersection[self.include] / union[self.include]).mean() 178 | return iou_mean, 
iou # returns "iou mean", "iou per class" ALL CLASSES 179 | 180 | def getacc(self): 181 | tp, fp, fn = self.getStats() 182 | total_tp = tp.sum() 183 | total = tp[self.include].sum() + fp[self.include].sum() + 1e-15 184 | acc_mean = total_tp / total 185 | return acc_mean # returns "acc mean" 186 | -------------------------------------------------------------------------------- /modules/user_refine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | import os 5 | import imp 6 | import time 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.backends.cudnn as cudnn 13 | import __init__ as booger 14 | 15 | from tqdm import tqdm 16 | from modules.user import User 17 | # from modules.SalsaNextWithMotionAttention import * 18 | 19 | # from modules.PointRefine.spvcnn import SPVCNN 20 | # from modules.PointRefine.spvcnn_lite import SPVCNN 21 | from torchsparse.utils.quantize import sparse_quantize 22 | from torchsparse.utils.collate import sparse_collate 23 | from torchsparse import SparseTensor 24 | 25 | 26 | class UserRefine(User): 27 | def __init__(self, ARCH, DATA, datadir, outputdir, modeldir, split, save_movable=False): 28 | 29 | super(UserRefine, self).__init__(ARCH, DATA, datadir, outputdir, modeldir, split, 30 | point_refine=True, save_movable=save_movable) 31 | 32 | def infer(self): 33 | coarse, reproj, refine = [], [], [] 34 | 35 | if self.split == 'valid': 36 | self.infer_subset(loader=self.parser.get_valid_set(), 37 | to_orig_fn=self.parser.to_original, 38 | coarse=coarse, reproj=reproj, refine=refine) 39 | elif self.split == 'train': 40 | self.infer_subset(loader=self.parser.get_train_set(), 41 | to_orig_fn=self.parser.to_original, 42 | coarse=coarse, reproj=reproj, refine=refine) 43 | elif self.split == 'test': 44 | self.infer_subset(loader=self.parser.get_test_set(), 45 | to_orig_fn=self.parser.to_original, 46 | coarse=coarse, reproj=reproj, refine=refine) 47 | elif self.split == None: 48 | self.infer_subset(loader=self.parser.get_train_set(), 49 | to_orig_fn=self.parser.to_original, 50 | coarse=coarse, reproj=reproj, refine=refine) 51 | self.infer_subset(loader=self.parser.get_valid_set(), 52 | to_orig_fn=self.parser.to_original, 53 | coarse=coarse, reproj=reproj, refine=refine) 54 | self.infer_subset(loader=self.parser.get_test_set(), 55 | to_orig_fn=self.parser.to_original, 56 | coarse=coarse, reproj=reproj, refine=refine) 57 | else: 58 | raise NotImplementedError 59 | 60 | print(f"Mean Coarse inference time:{'%.8f'%np.mean(coarse)}\t std:{'%.8f'%np.std(coarse)}") 61 | print(f"Mean Reproject inference time:{'%.8f'%np.mean(reproj)}\t std:{'%.8f'%np.std(reproj)}") 62 | print(f"Mean Refine inference time:{'%.8f'%np.mean(refine)}\t std:{'%.8f'%np.std(refine)}") 63 | print(f"Total Frames: {len(coarse)}") 64 | print("Finished Infering") 65 | 66 | return 67 | 68 | def infer_subset(self, loader, to_orig_fn, coarse, reproj, refine): 69 | 70 | # switch to evaluate mode 71 | self.model.eval() 72 | self.refine_module.eval() 73 | 74 | # empty the cache to infer in high res 75 | if self.gpu: 76 | torch.cuda.empty_cache() 77 | 78 | with torch.no_grad(): 79 | 80 | end = time.time() 81 | 82 | for i, (proj_in, proj_mask, _, _, path_seq, path_name, 83 | p_x, p_y, proj_range, unproj_range, _, unproj_xyz, _, _, npoints)\ 84 | in enumerate(tqdm(loader, ncols=80)): 85 | 86 | # first cut to 
rela size (batch size one allows it) 87 | p_x = p_x[0, :npoints] 88 | p_y = p_y[0, :npoints] 89 | proj_range = proj_range[0, :npoints] 90 | unproj_range = unproj_range[0, :npoints] 91 | path_seq = path_seq[0] 92 | path_name = path_name[0] 93 | points_xyz = unproj_xyz[0, :npoints] 94 | 95 | if self.gpu: 96 | proj_in = proj_in.cuda() 97 | p_x = p_x.cuda() 98 | p_y = p_y.cuda() 99 | if self.post: 100 | proj_range = proj_range.cuda() 101 | unproj_range = unproj_range.cuda() 102 | 103 | end = time.time() 104 | # compute output 105 | proj_output, last_feature, movable_proj_output, _ = self.model(proj_in) 106 | 107 | 108 | if torch.cuda.is_available(): 109 | torch.cuda.synchronize() 110 | res = time.time() - end 111 | coarse.append(res) 112 | 113 | if self.save_movable: 114 | movable_proj_argmax = movable_proj_output[0].argmax(dim=0) 115 | if self.post: 116 | movable_unproj_argmax = self.post(proj_range, unproj_range, 117 | movable_proj_argmax, p_x, p_y) 118 | else: 119 | movable_unproj_argmax = movable_proj_argmax[p_y, p_x] 120 | 121 | end = time.time() 122 | # print(f"CoarseModule seq {path_seq} scan {path_name} in {res} sec") 123 | 124 | """ Reproject 2D features to 3D based on indices and form sparse Tensor""" 125 | points_feature = last_feature[0, :, p_y, p_x] 126 | coords = np.round(points_xyz[:, :3].cpu().numpy() / 0.05) 127 | coords -= coords.min(0, keepdims=1) 128 | coords, indices, inverse = sparse_quantize(coords, return_index=True, return_inverse=True) 129 | coords = torch.tensor(coords, dtype=torch.int, device='cuda') 130 | feats = points_feature.permute(1,0)[indices] #torch.tensor(, dtype=torch.float) 131 | inputs = SparseTensor(coords=coords, feats=feats) 132 | inputs = sparse_collate([inputs]).cuda() 133 | """""""""""""""""""""""" 134 | 135 | # measure elapsed time 136 | if torch.cuda.is_available(): 137 | torch.cuda.synchronize() 138 | res = time.time() - end 139 | reproj.append(res) 140 | end = time.time() 141 | # print(f"DataConvert seq {path_seq} scan {path_name} in {res} sec") 142 | 143 | """ Input to PointHead, refine prediction """ 144 | predict = self.refine_module(inputs) 145 | 146 | if torch.cuda.is_available(): 147 | torch.cuda.synchronize() 148 | res = time.time() - end 149 | refine.append(res) 150 | # print(f"RefineModule seq {path_seq} scan {path_name} in {res} sec") 151 | 152 | predict = predict[inverse] #.permute(1,0) 153 | unproj_argmax = predict.argmax(dim=1) 154 | 155 | # save scan # get the first scan in batch and project scan 156 | pred_np = unproj_argmax.cpu().numpy() 157 | pred_np = pred_np.reshape((-1)).astype(np.int32) 158 | 159 | # map to original label 160 | pred_np = to_orig_fn(pred_np) 161 | 162 | path = os.path.join(self.outputdir, "sequences", path_seq, "predictions", path_name) 163 | pred_np.tofile(path) 164 | 165 | if self.save_movable: 166 | movable_pred_np = movable_unproj_argmax.cpu().numpy() 167 | movable_pred_np = movable_pred_np.reshape((-1)).astype(np.int32) 168 | 169 | # map to original label 170 | movable_pred_np = to_orig_fn(movable_pred_np, movable=True) 171 | path = os.path.join(self.outputdir, "sequences", path_seq, "predictions_movable", path_name) 172 | movable_pred_np.tofile(path) 173 | 174 | movable_pred_np[np.where(pred_np == 251)] = 251 175 | path = os.path.join(self.outputdir, "sequences", path_seq, "predictions_fuse", path_name) 176 | movable_pred_np.tofile(path) 177 | -------------------------------------------------------------------------------- /script/dist_train.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DatasetPath=DATAROOT 4 | ArchConfig=./train_yaml/ddp_mos_coarse_stage.yml 5 | DataConfig=./config/labels/semantic-kitti-mos.raw.yaml 6 | LogPath=./log/Train 7 | 8 | export CUDA_VISIBLE_DEVICES=0,1 && python3 -m torch.distributed.launch --nproc_per_node=2 \ 9 | ./train.py -d $DatasetPath \ 10 | -ac $ArchConfig \ 11 | -dc $DataConfig \ 12 | -l $LogPath -------------------------------------------------------------------------------- /script/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DatasetPath=DATAROOT 4 | PredictionsPath=./log/Valid/predictions/ 5 | DataConfig=./config/labels/semantic-kitti-mos.raw.yaml 6 | 7 | python3 utils/evaluate_mos.py -d $DatasetPath \ 8 | -p $PredictionsPath \ 9 | --dc $DataConfig 10 | -------------------------------------------------------------------------------- /script/train_siem.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DatasetPath=DATAROOT 4 | ArchConfig=./train_yaml/mos_pointrefine_stage.yml 5 | DataConfig=./config/labels/semantic-kitti-mos.raw.yaml 6 | LogPath=./log/TrainWithSIEM 7 | FirstStageModelPath=FirstStageModelPath 8 | 9 | export CUDA_VISIBLE_DEVICES=0 && python train_2stage.py -d $DatasetPath \ 10 | -ac $ArchConfig \ 11 | -dc $DataConfig \ 12 | -l $LogPath \ 13 | -p $FirstStageModelPath -------------------------------------------------------------------------------- /script/valid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DatasetPath=DATAROOT 4 | ModelPath=MODELPATH 5 | SavePath=./log/Valid/predictions/ 6 | SPLIT=valid # valid or test 7 | 8 | # If you want to use SIEM, set pointrefine on 9 | export CUDA_VISIBLE_DEVICES=0 && python3 infer.py -d $DatasetPath \ 10 | -m $ModelPath \ 11 | -l $SavePath \ 12 | -s $SPLIT \ 13 | --pointrefine \ 14 | --movable # Whether to save the label of movable objects 15 | -------------------------------------------------------------------------------- /script/visualize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DatasetPath=DATAROOT 4 | Seq=08 5 | DataConfig=./config/labels/semantic-kitti-mos.raw.yaml 6 | Version=fuse # Version in ["moving", "movable", "fuse"] for predictions 7 | #PredictionPath=./log/valid/predictions 8 | 9 | python3 utils/visualize_mos.py -d $DatasetPath \ 10 | -s $Seq \ 11 | -c $DataConfig \ 12 | -v $Version \ 13 | # -p $PredictionPath 14 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 
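The training entry point that follows initialises torch.distributed at import time (dist.init_process_group a few lines below), so it is meant to be started through torch.distributed.launch as in script/dist_train.sh above; a bare python train.py invocation will typically fail because the default env:// rendezvous needs the RANK, WORLD_SIZE and MASTER_ADDR variables that the launcher exports. For a single GPU the same launcher can be used with --nproc_per_node=1.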
3 | 4 | import torch 5 | print("device count: ", torch.cuda.device_count()) 6 | from torch import distributed as dist 7 | dist.init_process_group(backend="nccl") 8 | print("world_size: ", dist.get_world_size()) 9 | 10 | import random 11 | import numpy as np 12 | import __init__ as booger 13 | 14 | from modules.trainer import Trainer 15 | # from modules.SalsaNextWithMotionAttention import * 16 | from modules.MFMOS import * 17 | 18 | def set_seed(seed=1024): 19 | random.seed(seed) 20 | # os.environ['PYTHONHASHSEED'] = str(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 25 | 26 | # torch.backends.cudnn.benchmark = False 27 | # torch.backends.cudnn.deterministic = True 28 | # torch.backends.cudnn.enabled = False 29 | # If we need to reproduce the results, increase the training speed 30 | # set benchmark = False 31 | # If we don’t need to reproduce the results, improve the network performance as much as possible 32 | # set benchmark = True 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = get_args(flags="train") 37 | FLAGS, unparsed = parser.parse_known_args() 38 | local_rank = FLAGS.local_rank 39 | torch.cuda.set_device(local_rank) 40 | 41 | FLAGS.log = os.path.join(FLAGS.log, datetime.now().strftime("%Y-%-m-%d-%H:%M") + FLAGS.name) 42 | print(FLAGS.log) 43 | # open arch / data config file 44 | ARCH = load_yaml(FLAGS.arch_cfg) 45 | DATA = load_yaml(FLAGS.data_cfg) 46 | 47 | params = MFMOS(nclasses=3, params=ARCH, movable_nclasses=3) 48 | pytorch_total_params = sum(p.numel() for p in params.parameters() if p.requires_grad) 49 | del params 50 | 51 | if local_rank == 0: 52 | make_logdir(FLAGS=FLAGS, resume_train=False) # create log folder 53 | check_pretrained_dir(FLAGS.pretrained) # does model folder exist? 54 | backup_to_logdir(FLAGS=FLAGS) # backup code and config files to logdir 55 | 56 | set_seed() 57 | # create trainer and start the training 58 | trainer = Trainer(ARCH, DATA, FLAGS.dataset, FLAGS.log, FLAGS.pretrained, local_rank=local_rank) 59 | 60 | if local_rank == 0: 61 | print("----------") 62 | print("INTERFACE:") 63 | print(" dataset:", FLAGS.dataset) 64 | print(" arch_cfg:", FLAGS.arch_cfg) 65 | print(" data_cfg:", FLAGS.data_cfg) 66 | print(" Total of Trainable Parameters: {}".format(millify(pytorch_total_params, 2))) 67 | print(" log:", FLAGS.log) 68 | print(" pretrained:", FLAGS.pretrained) 69 | print(" Augmentation for residual: {}, interval in validation: {}".format(ARCH["train"]["residual_aug"], 70 | ARCH["train"]["valid_residual_delta_t"])) 71 | print("----------\n") 72 | 73 | trainer.train() 74 | -------------------------------------------------------------------------------- /train_2stage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | import torch 8 | import __init__ as booger 9 | 10 | from datetime import datetime 11 | from utils.utils import * 12 | from modules.trainer_refine import TrainerRefine 13 | # from modules.SalsaNextWithMotionAttention import * 14 | 15 | 16 | def set_seed(seed=1024): 17 | random.seed(seed) 18 | # os.environ['PYTHONHASHSEED'] = str(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
23 | # torch.backends.cudnn.benchmark = False 24 | # torch.backends.cudnn.deterministic = True 25 | # torch.backends.cudnn.enabled = False 26 | # If we need to reproduce the results, increase the training speed 27 | # set benchmark = False 28 | # If we don’t need to reproduce the results, improve the network performance as much as possible 29 | # set benchmark = True 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = get_args(flags="train") 34 | FLAGS, unparsed = parser.parse_known_args() 35 | FLAGS.log = os.path.join(FLAGS.log, datetime.now().strftime("%Y-%-m-%d-%H:%M") + FLAGS.name) 36 | 37 | # open arch / data config file 38 | ARCH = load_yaml(FLAGS.arch_cfg) 39 | DATA = load_yaml(FLAGS.data_cfg) 40 | 41 | # params = SalsaNextWithMotionAttention(nclasses=3, params=ARCH) 42 | # pytorch_total_params = sum(p.numel() for p in params.parameters() if p.requires_grad) 43 | # del params 44 | 45 | make_logdir(FLAGS=FLAGS, resume_train=False) # create log folder 46 | check_pretrained_dir(FLAGS.pretrained) # does model folder exist? 47 | backup_to_logdir(FLAGS=FLAGS, pretrain_model=True) # backup code and config files to logdir 48 | 49 | set_seed() 50 | # create trainer and start the training 51 | trainer = TrainerRefine(ARCH, DATA, FLAGS.dataset, FLAGS.log, FLAGS.pretrained) 52 | 53 | print("----------") 54 | print("INTERFACE:") 55 | print(" dataset:", FLAGS.dataset) 56 | print(" arch_cfg:", FLAGS.arch_cfg) 57 | print(" data_cfg:", FLAGS.data_cfg) 58 | print(" log:", FLAGS.log) 59 | print(" pretrained:", FLAGS.pretrained) 60 | print(" Augmentation for residual: {}, interval in validation: {}".format(ARCH["train"]["residual_aug"], 61 | ARCH["train"]["valid_residual_delta_t"])) 62 | print("----------\n") 63 | 64 | trainer.train() 65 | -------------------------------------------------------------------------------- /train_yaml/ddp_mos_coarse_stage.yml: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # training parameters 3 | ################################################################################ 4 | train: 5 | loss: "xentropy" # must be either xentropy or iou 6 | max_epochs: 150 7 | lr: 0.002 # sgd learning rate 8 | wup_epochs: 1 # warmup during first XX epochs (can be float) 9 | momentum: 0.9 # sgd momentum 10 | lr_decay: 0.99 # learning rate decay per epoch after initial cycle (from min lr) 11 | w_decay: 0.0001 # weight decay 12 | batch_size: 4 # batch size 13 | report_batch: 10 # every x batches, report loss 14 | report_epoch: 1 # every x epochs, report validation set 15 | epsilon_w: 0.001 # class weight w = 1 / (content + epsilon_w) 16 | save_summary: False # Summary of weight histograms for tensorboard 17 | save_scans: False # False doesn't save anything, True saves some sample images 18 | # (one per batch of the last calculated batch) in log folder 19 | show_scans: False # show scans during training 20 | workers: 8 # number of threads to get data 21 | 22 | # for mos 23 | residual: True # This needs to be the same as in the dataset params below! 24 | residual_aug: True 25 | valid_residual_delta_t: 3 26 | n_input_scans: 8 # This needs to be the same as in the dataset params below! 
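The epsilon_w entry in the train: block above is the smoothing term of the inverse-frequency class weights described in its comment (w = 1 / (content + epsilon_w)). A minimal sketch of that computation, using the content statistics and learning_map from config/labels/semantic-kitti-all.yaml shown earlier (the MOS label files are assumed to follow the same structure):

    import yaml

    cfg = yaml.safe_load(open("config/labels/semantic-kitti-all.yaml"))
    epsilon_w = 0.001

    # Accumulate the per-class point ratios onto the remapped training ids.
    freq = {train_id: 0.0 for train_id in cfg["learning_map_inv"]}
    for raw_id, ratio in cfg["content"].items():
        freq[cfg["learning_map"][raw_id]] += ratio

    weights = {c: 1.0 / (f + epsilon_w) for c, f in freq.items()}
    # Classes flagged True in learning_ignore (class 0 here) are usually zeroed out afterwards.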
27 | 28 | ################################################################################ 29 | # postproc parameters 30 | ################################################################################ 31 | post: 32 | CRF: 33 | use: False 34 | train: True 35 | params: False # this should be a dict when in use 36 | KNN: 37 | use: True # This parameter default is false 38 | params: 39 | knn: 5 40 | search: 5 41 | sigma: 1.0 42 | cutoff: 1.0 43 | 44 | ################################################################################ 45 | # classification head parameters 46 | ################################################################################ 47 | # dataset (to find parser) 48 | dataset: 49 | labels: "kitti" 50 | scans: "kitti" 51 | max_points: 150000 # max of any scan in dataset 52 | sensor: 53 | name: "HDL64" 54 | type: "spherical" # projective 55 | fov_up: 3 56 | fov_down: -25 57 | img_prop: 58 | width: 2048 59 | height: 64 60 | img_means: #range,x,y,z,signal 61 | - 12.12 62 | - 10.88 63 | - 0.23 64 | - -1.04 65 | - 0.21 66 | img_stds: #range,x,y,z,signal 67 | - 12.32 68 | - 11.47 69 | - 6.91 70 | - 0.86 71 | - 0.16 72 | 73 | # for mos 74 | n_input_scans: 8 # This needs to be the same as in the backbone params above! 75 | residual: True # This needs to be the same as in the backbone params above! 76 | transform: False # tranform the last n_input_scans - 1 frames before concatenation 77 | use_normal: False # if use normal vector as channels of range image 78 | -------------------------------------------------------------------------------- /train_yaml/mos_coarse_stage.yml: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # training parameters 3 | ################################################################################ 4 | train: 5 | loss: "xentropy" # must be either xentropy or iou 6 | max_epochs: 150 7 | lr: 0.008 # sgd learning rate 8 | wup_epochs: 1 # warmup during first XX epochs (can be float) 9 | momentum: 0.9 # sgd momentum 10 | lr_decay: 0.99 # learning rate decay per epoch after initial cycle (from min lr) 11 | w_decay: 0.0001 # weight decay 12 | batch_size: 16 # batch size 13 | report_batch: 10 # every x batches, report loss 14 | report_epoch: 1 # every x epochs, report validation set 15 | epsilon_w: 0.001 # class weight w = 1 / (content + epsilon_w) 16 | save_summary: False # Summary of weight histograms for tensorboard 17 | save_scans: False # False doesn't save anything, True saves some sample images 18 | # (one per batch of the last calculated batch) in log folder 19 | show_scans: False # show scans during training 20 | workers: 8 # number of threads to get data 21 | 22 | # for mos 23 | residual: True # This needs to be the same as in the dataset params below! 24 | residual_aug: True 25 | valid_residual_delta_t: 3 26 | n_input_scans: 8 # This needs to be the same as in the dataset params below! 
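The post: block that follows configures the kNN clean-up implemented in modules/KNN.py earlier in this listing. A rough usage sketch with made-up tensors, run from the repository root (the real wiring presumably happens in modules/user.py during inference, as the call in modules/user_refine.py suggests):

    import torch
    from modules.KNN import KNN

    params = {"knn": 5, "search": 5, "sigma": 1.0, "cutoff": 1.0}  # the post: KNN: params below
    post = KNN(params=params, nclasses=3)

    H, W, N = 64, 2048, 120000                    # range-image size and a made-up point count
    proj_range   = torch.rand(H, W) * 80          # per-pixel range of the projected scan
    unproj_range = torch.rand(N) * 80             # per-point range
    proj_argmax  = torch.randint(0, 3, (H, W))    # per-pixel class prediction
    p_x = torch.randint(0, W, (N,))               # range-image column of every point
    p_y = torch.randint(0, H, (N,))               # range-image row of every point

    unproj_argmax = post(proj_range, unproj_range, proj_argmax, p_x, p_y)  # (N,) cleaned per-point labels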
27 |
28 | ################################################################################
29 | # postproc parameters
30 | ################################################################################
31 | post:
32 |   CRF:
33 |     use: False
34 |     train: True
35 |     params: False  # this should be a dict when in use
36 |   KNN:
37 |     use: True  # this parameter defaults to False
38 |     params:
39 |       knn: 5
40 |       search: 5
41 |       sigma: 1.0
42 |       cutoff: 1.0
43 |
44 | ################################################################################
45 | # dataset parameters
46 | ################################################################################
47 | # dataset (to find parser)
48 | dataset:
49 |   labels: "kitti"
50 |   scans: "kitti"
51 |   max_points: 150000  # max of any scan in dataset
52 |   sensor:
53 |     name: "HDL64"
54 |     type: "spherical"  # projective
55 |     fov_up: 3  # field of view up, in degrees
56 |     fov_down: -25  # field of view down, in degrees
57 |     img_prop:
58 |       width: 2048
59 |       height: 64
60 |     img_means:  # range, x, y, z, signal
61 |       - 12.12
62 |       - 10.88
63 |       - 0.23
64 |       - -1.04
65 |       - 0.21
66 |     img_stds:  # range, x, y, z, signal
67 |       - 12.32
68 |       - 11.47
69 |       - 6.91
70 |       - 0.86
71 |       - 0.16
72 |
73 |     # for mos
74 |     n_input_scans: 8  # This needs to be the same as in the backbone params above!
75 |     residual: True  # This needs to be the same as in the backbone params above!
76 |     transform: False  # transform the last n_input_scans - 1 frames before concatenation
77 |     use_normal: False  # whether to use normal vectors as channels of the range image
78 |
--------------------------------------------------------------------------------
/train_yaml/mos_pointrefine_stage.yml:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # training parameters
3 | ################################################################################
4 | train:
5 |   loss: "xentropy"  # must be either xentropy or iou
6 |   max_epochs: 10
7 |   lr: 0.001  # sgd learning rate
8 |   wup_epochs: 1  # warmup during first XX epochs (can be float)
9 |   momentum: 0.9  # sgd momentum
10 |   lr_decay: 0.99  # learning rate decay per epoch after initial cycle (from min lr)
11 |   w_decay: 0.0001  # weight decay
12 |   batch_size: 1  # batch size
13 |   report_batch: 10  # every x batches, report loss
14 |   report_epoch: 1  # every x epochs, report validation set
15 |   epsilon_w: 0.001  # class weight w = 1 / (content + epsilon_w)
16 |   save_summary: False  # Summary of weight histograms for tensorboard
17 |   save_scans: False  # False doesn't save anything, True saves some sample images
18 |                      # (one per batch of the last calculated batch) in log folder
19 |   show_scans: False  # show scans during training
20 |   workers: 8  # number of threads to get data
21 |
22 |   # for mos
23 |   residual: True  # This needs to be the same as in the dataset params below!
24 |   residual_aug: True
25 |   valid_residual_delta_t: 3
26 |   n_input_scans: 8  # This needs to be the same as in the dataset params below!
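The `sensor` block repeated in each of these configs (fov_up/fov_down in degrees, a 64 x 2048 image) describes a spherical range-image projection. Below is a minimal NumPy sketch of the standard RangeNet++-style mapping those parameters imply; it is illustrative only, and the project's own projection code (e.g. `common/laserscan.py`) additionally handles remissions, point ordering, and index maps. The `img_means`/`img_stds` lists are then used to normalize each input channel.

```python
import numpy as np


def project_to_range_image(points, fov_up_deg=3.0, fov_down_deg=-25.0,
                           width=2048, height=64):
    """Illustrative spherical projection of an (N, 3) point cloud to an (H, W) range image."""
    fov_up = np.deg2rad(fov_up_deg)
    fov_down = np.deg2rad(fov_down_deg)
    fov = abs(fov_up) + abs(fov_down)

    depth = np.linalg.norm(points, axis=1)
    keep = depth > 0
    points, depth = points[keep], depth[keep]

    yaw = -np.arctan2(points[:, 1], points[:, 0])
    pitch = np.arcsin(points[:, 2] / depth)

    # Normalize angles to [0, 1], then scale to pixel coordinates.
    u = 0.5 * (yaw / np.pi + 1.0) * width
    v = (1.0 - (pitch + abs(fov_down)) / fov) * height
    u = np.clip(np.floor(u), 0, width - 1).astype(np.int64)
    v = np.clip(np.floor(v), 0, height - 1).astype(np.int64)

    proj_range = np.full((height, width), -1.0, dtype=np.float32)
    proj_range[v, u] = depth  # points falling on the same pixel simply overwrite one another here
    return proj_range
```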
27 |
28 | ################################################################################
29 | # postproc parameters
30 | ################################################################################
31 | post:
32 |   CRF:
33 |     use: False
34 |     train: True
35 |     params: False  # this should be a dict when in use
36 |   KNN:
37 |     use: True  # this parameter defaults to False
38 |     params:
39 |       knn: 5
40 |       search: 5
41 |       sigma: 1.0
42 |       cutoff: 1.0
43 |
44 | ################################################################################
45 | # dataset parameters
46 | ################################################################################
47 | # dataset (to find parser)
48 | dataset:
49 |   labels: "kitti"
50 |   scans: "kitti"
51 |   max_points: 150000  # max of any scan in dataset
52 |   sensor:
53 |     name: "HDL64"
54 |     type: "spherical"  # projective
55 |     fov_up: 3  # field of view up, in degrees
56 |     fov_down: -25  # field of view down, in degrees
57 |     img_prop:
58 |       width: 2048
59 |       height: 64
60 |     img_means:  # range, x, y, z, signal
61 |       - 12.12
62 |       - 10.88
63 |       - 0.23
64 |       - -1.04
65 |       - 0.21
66 |     img_stds:  # range, x, y, z, signal
67 |       - 12.32
68 |       - 11.47
69 |       - 6.91
70 |       - 0.86
71 |       - 0.16
72 |
73 |     # for mos
74 |     n_input_scans: 8  # This needs to be the same as in the backbone params above!
75 |     residual: True  # This needs to be the same as in the backbone params above!
76 |     transform: False  # transform the last n_input_scans - 1 frames before concatenation
77 |     use_normal: False  # whether to use normal vectors as channels of the range image
78 |
--------------------------------------------------------------------------------
/utils/auto_gen_residual_images.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Jiadai Sun
3 | # and the main function 'prosess_one_seq' refers to Xieyuanli Chen's gen_residual_images.py
4 | # This file is covered by the LICENSE file in the root of this project.
5 | # Brief: This script generates residual images
6 |
7 | import os
8 | os.environ["OMP_NUM_THREADS"] = "4"
9 | import yaml
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 |
13 | from tqdm import tqdm
14 | from icecream import ic
15 | from kitti_utils import load_poses, load_calib, load_files, load_vertex
16 |
17 | try:
18 |     from c_gen_virtual_scan import gen_virtual_scan as range_projection
19 | except:
20 |     print("Using clib by $export PYTHONPATH=$PYTHONPATH: