├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── config ├── hrnet_rellis │ ├── seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml │ └── seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-3_wd5e-4_bs_12_epoch484.yaml ├── obstacles.yaml ├── rellis.yaml ├── rellis_to_flexibility.yaml ├── rellis_to_traversability.yaml ├── rellis_traversability.yaml ├── rviz │ ├── cloud_fusion.rviz │ ├── cloud_segm.rviz │ ├── cloud_segm_planner.rviz │ ├── colored_cloud_marv.rviz │ ├── ctu_demo.rviz │ ├── demo_cloud.rviz │ ├── geometric_cloud_segmentation.rviz │ ├── husky.rviz │ ├── husky_demo.rviz │ ├── marv.rviz │ ├── navigation.rviz │ ├── robot_data.rviz │ ├── segmented_cloud.rviz │ ├── semantic_trav.rviz │ ├── spot.rviz │ ├── tradr.rviz │ └── trav_eval.rviz ├── semantickitti19.yaml ├── semantickitti19_to_flexibility.yaml ├── semantickitti19_to_traversability.yaml ├── slam │ ├── icp.yaml │ ├── input_filters.yaml │ └── map_post_filters.yaml └── workspace.repos ├── docker └── jetson │ ├── Dockerfile │ ├── Makefile │ ├── install_ros_melodic.sh │ └── install_ros_noetic.sh ├── docs ├── cloud_flex_gt.png ├── cloud_trav_gt.png ├── colored_pc_demo.png ├── colored_pc_demo_pred.png ├── colored_pc_demo_rgb.png ├── docker.md ├── install.md ├── rellis.md ├── rgb_sem_gt.png ├── rgb_sem_gt4.png ├── segmentation_labels.png ├── segmented_pc.png ├── semantic_traversability_pipeline.png └── trav_data.md ├── launch ├── base_footprint.launch ├── cloud_filter.launch ├── cloud_fusion.launch ├── cloud_ortho_stats.launch ├── cloud_projection.launch ├── color_pc_bagfile_demo.launch ├── ctu_robot.launch ├── demo_cloud.launch ├── fused_traversability.launch ├── generate_points.launch ├── generate_trav_src.launch ├── geometric_cloud_segmentation.launch ├── geometric_traversability.launch ├── husky_robot.launch ├── image_segmentation_dataset_demo.launch ├── marv_robot.launch ├── orient_frame.launch ├── play_bag.launch ├── robot_data.launch ├── semantic_traversability.launch ├── semantic_traversability_tconcord3d.launch ├── show_trav_data.launch ├── slam.launch ├── traversability_bag_demo.launch ├── traversability_dataset_demo.launch ├── traversability_evaluation.launch └── video │ ├── front_overlay_camera.rviz │ ├── graph_pcd_scene.rviz │ ├── rear_overlay_camera.rviz │ ├── record_video.launch │ ├── traversability_final_demo_visualization.launch │ ├── traversability_final_demo_visualization.rviz │ └── traversability_scene.rviz ├── notebooks ├── hrnet_demo.ipynb ├── segmentation_evaluation.py ├── smp_demo.py ├── smp_rellis3d_eval.ipynb ├── smp_rellis3d_train.ipynb └── traversability_learning.ipynb ├── package.xml ├── scripts ├── nodes │ ├── cloud_fusion │ ├── cloud_segmentation │ ├── cloud_segmentation_tconcord3d │ ├── cloud_to_depth │ ├── find_traversed_points │ ├── geometric_cloud_segmentation │ ├── global_map │ ├── latch_sensor_info │ ├── play_tf_static │ ├── robot_data │ ├── segmentation_inference │ ├── stamp_twist │ └── traversability_fusion └── tools │ ├── eval_depth │ ├── generate_flexibility_data │ ├── generate_traversability_data │ ├── legacy_weights │ ├── save_clouds │ ├── save_clouds_from_bag │ ├── test_depth │ ├── test_img │ ├── train_depth │ ├── train_img │ ├── train_smp │ └── traversability_learning ├── setup.py ├── singularity ├── .gitignore ├── build.sh ├── recepie.def └── requirements.txt └── src ├── datasets ├── __init__.py ├── augmentations.py ├── base_dataset.py ├── cwt.py ├── laserscan.py ├── rellis_3d.py ├── semantic.py ├── traversability_cloud.py └── traversability_dataset.py ├── hrnet ├── 
__init__.py ├── config │ ├── __init__.py │ ├── default.py │ ├── hrnet_config.py │ └── models.py ├── core │ ├── __init__.py │ ├── criterion.py │ └── function.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ └── rellis.py ├── models │ ├── __init__.py │ ├── bn_helper.py │ ├── hrnet.py │ ├── seg_hrnet.py │ ├── seg_hrnet_ocr.py │ └── sync_bn │ │ ├── LICENSE │ │ ├── __init__.py │ │ └── inplace_abn │ │ ├── __init__.py │ │ ├── bn.py │ │ ├── functions.py │ │ └── src │ │ ├── common.h │ │ ├── inplace_abn.cpp │ │ ├── inplace_abn.h │ │ ├── inplace_abn_cpu.cpp │ │ └── inplace_abn_cuda.cu └── utils │ ├── __init__.py │ ├── distributed.py │ ├── modelsummary.py │ └── utils.py ├── tconcord3d ├── __init__.py ├── builder │ ├── __init__.py │ ├── data_builder.py │ ├── loss_builder.py │ └── model_builder.py ├── config │ ├── __init__.py │ ├── config.py │ ├── label_mapping │ │ └── sematickitti │ │ │ ├── semantic-kitti-multiscan_ssl_s20_p80.yaml │ │ │ └── semantic-kitti_ssl_s20_p80.yaml │ └── semantickitti │ │ ├── semantickitti_S0_0_T11_33_ssl_s20_p80.yaml │ │ ├── semantickitti_S0_0_test.yaml │ │ ├── semantickitti_T0_0.yaml │ │ ├── semantickitti_T1_1.yaml │ │ ├── semantickitti_T2_2.yaml │ │ └── semantickitti_T3_3.yaml ├── model │ ├── __init__.py │ ├── cylinder_3d.py │ ├── cylinder_feature.py │ └── segment_3d.py └── utils │ ├── __init__.py │ ├── load_save_util.py │ ├── log_util.py │ ├── loss_func.py │ ├── lovasz_losses.py │ ├── metric_util.py │ ├── trainer_function.py │ └── ups.py └── traversability_estimation ├── __init__.py ├── geometry.py ├── ransac.py ├── segmentation.py ├── topic_service_proxy.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | data/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # Pycharm cache 135 | .idea/ 136 | 137 | # Pytorch weights 138 | *.pth 139 | *.zip 140 | *.tar 141 | weights 142 | 143 | config/hrnet_rellis/*.yaml 144 | thirdparty 145 | tb_runs/ 146 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Czech Technical University in Prague 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /config/hrnet_rellis/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml: -------------------------------------------------------------------------------- 1 | CUDNN: 2 | BENCHMARK: true 3 | DETERMINISTIC: false 4 | ENABLED: true 5 | DATASET: 6 | DATASET: rellis 7 | NUM_CLASSES: 19 8 | ROOT: /home/ruslan/workspaces/traversability_ws/src/traversability_estimation/data/Rellis_3D/ 9 | TEST_SET: val.lst 10 | TRAIN_SET: train.lst 11 | GPUS: (0,1) 12 | LOG_DIR: log 13 | LOSS: 14 | BALANCE_WEIGHTS: 15 | - 0.4 16 | - 1 17 | OHEMKEEP: 131072 18 | OHEMTHRES: 0.9 19 | USE_OHEM: false 20 | MODEL: 21 | EXTRA: 22 | FINAL_CONV_KERNEL: 1 23 | STAGE1: 24 | BLOCK: BOTTLENECK 25 | FUSE_METHOD: SUM 26 | NUM_BLOCKS: 27 | - 4 28 | NUM_CHANNELS: 29 | - 64 30 | NUM_MODULES: 1 31 | NUM_RANCHES: 1 32 | STAGE2: 33 | BLOCK: BASIC 34 | FUSE_METHOD: SUM 35 | NUM_BLOCKS: 36 | - 4 37 | - 4 38 | NUM_BRANCHES: 2 39 | NUM_CHANNELS: 40 | - 48 41 | - 96 42 | NUM_MODULES: 1 43 | STAGE3: 44 | BLOCK: BASIC 45 | FUSE_METHOD: SUM 46 | NUM_BLOCKS: 47 | - 4 48 | - 4 49 | - 4 50 | NUM_BRANCHES: 3 51 | NUM_CHANNELS: 52 | - 48 53 | - 96 54 | - 192 55 | NUM_MODULES: 4 56 | STAGE4: 57 | BLOCK: BASIC 58 | FUSE_METHOD: SUM 59 | NUM_BLOCKS: 60 | - 4 61 | - 4 62 | - 4 63 | - 4 64 | NUM_BRANCHES: 4 65 | NUM_CHANNELS: 66 | - 48 67 | - 96 68 | - 192 69 | - 384 70 | NUM_MODULES: 3 71 | NAME: seg_hrnet_ocr 72 | NUM_OUTPUTS: 2 73 | PRETRAINED: /home/ruslan/workspaces/traversability_ws/src/traversability_estimation/config/weights/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484/best.pth 74 | OUTPUT_DIR: output 75 | PRINT_FREQ: 10 76 | TEST: 77 | BASE_SIZE: 1920 78 | BATCH_SIZE_PER_GPU: 4 79 | FLIP_TEST: false 80 | IMAGE_SIZE: 81 | - 1920 82 | - 1200 83 | MULTI_SCALE: false 84 | TRAIN: 85 | BASE_SIZE: 1920 86 | BATCH_SIZE_PER_GPU: 4 87 | BEGIN_EPOCH: 0 88 | DOWNSAMPLERATE: 1 89 | END_EPOCH: 484 90 | FLIP: true 91 | IGNORE_LABEL: 255 92 | IMAGE_SIZE: 93 | - 1024 94 | - 640 95 | LR: 0.01 96 | MOMENTUM: 0.9 97 | MULTI_SCALE: true 98 | NESTEROV: false 99 | OPTIMIZER: sgd 100 | RESUME: true 101 | SCALE_FACTOR: 16 102 | SHUFFLE: true 103 | WD: 0.0005 104 | WORKERS: 4 105 | -------------------------------------------------------------------------------- /config/hrnet_rellis/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-3_wd5e-4_bs_12_epoch484.yaml: -------------------------------------------------------------------------------- 1 | CUDNN: 2 | BENCHMARK: true 3 | DETERMINISTIC: false 4 | ENABLED: true 5 | DATASET: 6 | DATASET: rellis 7 | NUM_CLASSES: 19 8 | ROOT: /home/ruslan/workspaces/traversability_ws/src/traversability_estimation/data/Rellis_3D/ 9 | TEST_SET: test.lst 10 | TRAIN_SET: train.lst 11 | GPUS: (0,) 12 | LOG_DIR: log 13 | LOSS: 14 | BALANCE_WEIGHTS: 15 | - 0.4 16 | - 1 17 | OHEMKEEP: 131072 18 | OHEMTHRES: 0.9 19 | USE_OHEM: false 20 | MODEL: 21 | EXTRA: 22 | FINAL_CONV_KERNEL: 1 23 | STAGE1: 24 | BLOCK: BOTTLENECK 25 | FUSE_METHOD: SUM 26 | NUM_BLOCKS: 27 | - 4 28 | NUM_CHANNELS: 29 | - 64 30 | NUM_MODULES: 1 31 | NUM_RANCHES: 1 32 | STAGE2: 33 | BLOCK: BASIC 34 | FUSE_METHOD: SUM 35 | NUM_BLOCKS: 36 | - 4 37 | - 4 38 | NUM_BRANCHES: 2 39 | NUM_CHANNELS: 40 | - 48 41 | - 96 42 | NUM_MODULES: 1 43 | STAGE3: 44 | BLOCK: BASIC 45 | FUSE_METHOD: SUM 46 | NUM_BLOCKS: 47 | - 4 48 | - 4 49 | - 4 50 | NUM_BRANCHES: 3 51 | NUM_CHANNELS: 52 | - 48 53 | - 96 54 | - 192 55 | NUM_MODULES: 4 56 | STAGE4: 57 | BLOCK: BASIC 58 | FUSE_METHOD: 
SUM 59 | NUM_BLOCKS: 60 | - 4 61 | - 4 62 | - 4 63 | - 4 64 | NUM_BRANCHES: 4 65 | NUM_CHANNELS: 66 | - 48 67 | - 96 68 | - 192 69 | - 384 70 | NUM_MODULES: 3 71 | NAME: seg_hrnet_ocr 72 | NUM_OUTPUTS: 2 73 | PRETRAINED: /home/ruslan/workspaces/traversability_ws/src/traversability_estimation/config/weights/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-3_wd5e-4_bs_12_epoch484/best.pth 74 | OUTPUT_DIR: output 75 | PRINT_FREQ: 10 76 | TEST: 77 | BASE_SIZE: 1024 78 | BATCH_SIZE_PER_GPU: 1 79 | FLIP_TEST: false 80 | IMAGE_SIZE: 81 | - 1024 82 | - 512 83 | MULTI_SCALE: false 84 | TRAIN: 85 | BASE_SIZE: 1920 86 | BATCH_SIZE_PER_GPU: 4 87 | BEGIN_EPOCH: 0 88 | DOWNSAMPLERATE: 1 89 | END_EPOCH: 484 90 | FLIP: true 91 | IGNORE_LABEL: 255 92 | IMAGE_SIZE: 93 | - 1024 94 | - 640 95 | LR: 0.001 96 | MOMENTUM: 0.9 97 | MULTI_SCALE: true 98 | NESTEROV: false 99 | OPTIMIZER: sgd 100 | RESUME: true 101 | SCALE_FACTOR: 16 102 | SHUFFLE: true 103 | WD: 0.0005 104 | WORKERS: 4 105 | -------------------------------------------------------------------------------- /config/obstacles.yaml: -------------------------------------------------------------------------------- 1 | name: "obstacles" 2 | labels: 3 | 0: "traversable" 4 | 1: "obstacle" 5 | color_map: # bgr 6 | 0: [0, 255, 0] 7 | 1: [255, 0, 0] 8 | -------------------------------------------------------------------------------- /config/rellis.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | name: "rellis" 3 | labels: 4 | 0: "void" 5 | 1: "dirt" 6 | 3: "grass" 7 | 4: "tree" 8 | 5: "pole" 9 | 6: "water" 10 | 7: "sky" 11 | 8: "vehicle" 12 | 9: "object" 13 | 10: "asphalt" 14 | 12: "building" 15 | 15: "log" 16 | 17: "person" 17 | 18: "fence" 18 | 19: "bush" 19 | 23: "concrete" 20 | 27: "barrier" 21 | 31: "puddle" 22 | 33: "mud" 23 | 34: "rubble" 24 | color_map: # bgr 25 | 0: [0,0,0] 26 | 1: [108, 64, 20] 27 | 3: [0,102,0] 28 | 4: [0,255,0] 29 | 5: [0,153,153] 30 | 6: [0,128,255] 31 | 7: [0,0,255] 32 | 8: [255,255,0] 33 | 9: [255,0,127] 34 | 10: [64,64,64] 35 | 12: [255,0,0] 36 | 15: [102,0,0] 37 | 17: [204,153,255] 38 | 18: [102, 0, 204] 39 | 19: [255,153,204] 40 | 23: [170,170,170] 41 | 27: [41,121,255] 42 | 31: [134,255,239] 43 | 33: [99,66,34] 44 | 34: [110,22,138] 45 | content: # as a ratio with the total number of points 46 | 0: 447156890 47 | 1: 0 48 | 3: 261005182 49 | 4: 107172982 50 | 5: 22852 51 | 6: 224173 52 | 7: 0 53 | 8: 111345 54 | 9: 2 55 | 10: 479 56 | 12: 10 57 | 15: 554091 58 | 17: 10626325 59 | 18: 1588416 60 | 19: 168764964 61 | 23: 10944799 62 | 27: 3502156 63 | 31: 1493276 64 | 33: 5798200 65 | 34: 3395458 66 | # classes that are indistinguishable from single scan or inconsistent in 67 | # ground truth are mapped to their closest equivalent 68 | learning_map: 69 | 0: 0 #"void" 70 | 1: 0 #"dirt" 71 | 3: 1 #"grass" 72 | 4: 2 #"tree" 73 | 5: 3 #"pole" 74 | 6: 4 #"water" 75 | 7: 0 #"sky" 76 | 8: 5 #"vehicle" 77 | 9: 0 #"object" 78 | 10: 0 #"asphalt" 79 | 12: 0 #"building" 80 | 15: 6 #"log" 81 | 17: 7 #"person" 82 | 18: 8 #"fence" 83 | 19: 9 #"bush" 84 | 23: 10 #"concrete" 85 | 27: 11 #"barrier" 86 | 31: 12 #"puddle" 87 | 33: 13 #"mud" 88 | 34: 14 #"rubble" 89 | learning_map_inv: # inverse of previous map 90 | 0: 0 #"void"#"dirt" 5: 7 #"sky"9 #"object"10 #"asphalt"12 #"building" 91 | 1: 3 #"grass" 92 | 2: 4 #"tree" 93 | 3: 5 #"pole" 94 | 4: 6 #"water" 95 | 5: 8 #"vehicle" 96 | 6: 15 #"log" 97 | 7: 17 #"person" 98 | 8: 18 #"fence" 99 | 
9: 19 #"bush" 100 | 10: 23 #"concrete" 101 | 11: 27 #"barrier" 102 | 12: 31 #"puddle" 103 | 13: 33 #"mud" 104 | 14: 34 #"rubble" 105 | learning_ignore: # Ignore classes 106 | 0: True #"void"#"dirt" 107 | 1: False #"grass" 108 | 2: False #"tree" 109 | 3: False #"pole" 110 | 4: False #"water" 111 | 5: False #"vehicle" 112 | 6: False #"object" 113 | 7: False #"asphalt" 114 | 8: False #"building" 115 | 9: False #"log" 116 | 10: False #"person" 117 | 11: False #"fence" 118 | 12: False #"bush" 119 | 13: False #"concrete" 120 | 14: False #"barrier" 121 | split: # sequence numbers 122 | train: "pt_train.lst" 123 | valid: "pt_val.lst" 124 | test: "pt_test.lst" -------------------------------------------------------------------------------- /config/rellis_to_flexibility.yaml: -------------------------------------------------------------------------------- 1 | '0': 255 # "void" 2 | '1': 0 # "dirt" 3 | '3': 1 # "grass" 4 | '4': 0 # "tree" 5 | '5': 0 # "pole" 6 | '6': 0 # "water" 7 | '7': 1 # "sky" 8 | '8': 0 # "vehicle" 9 | '9': 0 # "object" 10 | '10': 0 # "asphalt" 11 | '12': 0 # "building" 12 | '15': 0 # "log" 13 | '17': 0 # "person" 14 | '18': 0 # "fence" 15 | '19': 1 # "bush" 16 | '23': 0 # "concrete" 17 | '27': 0 # "barrier" 18 | '31': 0 # "puddle" 19 | '33': 0 # "mud" 20 | '34': 0 # "rubble" -------------------------------------------------------------------------------- /config/rellis_to_traversability.yaml: -------------------------------------------------------------------------------- 1 | '0': 255 # "void" 2 | '1': 1 # "dirt" 3 | '3': 0 # "grass" 4 | '4': 1 # "tree" 5 | '5': 1 # "pole" 6 | '6': 1 # "water" 7 | '7': 0 # "sky" 8 | '8': 1 # "vehicle" 9 | '9': 1 # "object" 10 | '10': 1 # "asphalt" 11 | '12': 1 # "building" 12 | '15': 1 # "log" 13 | '17': 1 # "person" 14 | '18': 1 # "fence" 15 | '19': 0 # "bush" 16 | '23': 0 # "concrete" 17 | '27': 1 # "barrier" 18 | '31': 1 # "puddle" 19 | '33': 1 # "mud" 20 | '34': 1 # "rubble" 21 | -------------------------------------------------------------------------------- /config/rellis_traversability.yaml: -------------------------------------------------------------------------------- 1 | labels: 2 | 0: "void" 3 | 1: "dirt" # True 4 | 3: "grass" # True 5 | 4: "tree" # False 6 | 5: "pole" # False 7 | 6: "water" # False 8 | 7: "sky" # False 9 | 8: "vehicle" # False 10 | 9: "object" # False 11 | 10: "asphalt" # True 12 | 12: "building" # False 13 | 15: "log" # False 14 | 17: "person" # False 15 | 18: "fence" # False 16 | 19: "bush" # False 17 | 23: "concrete" # False 18 | 27: "barrier" # False 19 | 31: "puddle" # True 20 | 33: "mud" # True 21 | 34: "rubble" # True 22 | 23 | color_map: # bgr 24 | 0: [ 0, 0, 0 ] 25 | 1: [ 0,255,0 ] 26 | 3: [ 0,255,0 ] 27 | 4: [ 255,0,0 ] 28 | 5: [ 255,0,0 ] 29 | 6: [ 255,0,0 ] 30 | 7: [ 255,0,0 ] 31 | 8: [ 255,0,0 ] 32 | 9: [ 255,0,0 ] 33 | 10: [ 0,255,0 ] 34 | 12: [ 255,0,0 ] 35 | 15: [ 255,0,0 ] 36 | 17: [ 255,0,0 ] 37 | 18: [ 255,0,0 ] 38 | 19: [ 255,0,0 ] 39 | 23: [ 255,0,0 ] 40 | 27: [ 255,0,0 ] 41 | 31: [ 0,255,0 ] 42 | 33: [ 0,255,0 ] 43 | 34: [ 0,255,0 ] 44 | content: # as a ratio with the total number of points 45 | 0: 447156890 46 | 1: 0 47 | 3: 261005182 48 | 4: 107172982 49 | 5: 22852 50 | 6: 224173 51 | 7: 0 52 | 8: 111345 53 | 9: 2 54 | 10: 479 55 | 12: 10 56 | 15: 554091 57 | 17: 10626325 58 | 18: 1588416 59 | 19: 168764964 60 | 23: 10944799 61 | 27: 3502156 62 | 31: 1493276 63 | 33: 5798200 64 | 34: 3395458 65 | # classes that are indistinguishable from single scan or inconsistent in 66 | # ground truth are 
mapped to their closest equivalent 67 | learning_map: 68 | 0: 0 #"void" 69 | 1: 0 #"dirt" 70 | 3: 1 #"grass" 71 | 4: 2 #"tree" 72 | 5: 3 #"pole" 73 | 6: 4 #"water" 74 | 7: 0 #"sky" 75 | 8: 5 #"vehicle" 76 | 9: 0 #"object" 77 | 10: 0 #"asphalt" 78 | 12: 0 #"building" 79 | 15: 6 #"log" 80 | 17: 7 #"person" 81 | 18: 8 #"fence" 82 | 19: 9 #"bush" 83 | 23: 10 #"concrete" 84 | 27: 11 #"barrier" 85 | 31: 12 #"puddle" 86 | 33: 13 #"mud" 87 | 34: 14 #"rubble" 88 | learning_map_inv: # inverse of previous map 89 | 0: 0 #"void"#"dirt" 5: 7 #"sky"9 #"object"10 #"asphalt"12 #"building" 90 | 1: 3 #"grass" 91 | 2: 4 #"tree" 92 | 3: 5 #"pole" 93 | 4: 6 #"water" 94 | 5: 8 #"vehicle" 95 | 6: 15 #"log" 96 | 7: 17 #"person" 97 | 8: 18 #"fence" 98 | 9: 19 #"bush" 99 | 10: 23 #"concrete" 100 | 11: 27 #"barrier" 101 | 12: 31 #"puddle" 102 | 13: 33 #"mud" 103 | 14: 34 #"rubble" 104 | learning_ignore: # Ignore classes 105 | 0: True #"void"#"dirt" 106 | 1: False #"grass" 107 | 2: False #"tree" 108 | 3: False #"pole" 109 | 4: False #"water" 110 | 5: False #"vehicle" 111 | 6: False #"object" 112 | 7: False #"asphalt" 113 | 8: False #"building" 114 | 9: False #"log" 115 | 10: False #"person" 116 | 11: False #"fence" 117 | 12: False #"bush" 118 | 13: False #"concrete" 119 | 14: False #"barrier" 120 | split: # sequence numbers 121 | train: "pt_train.lst" 122 | valid: "pt_val.lst" 123 | test: "pt_test.lst" -------------------------------------------------------------------------------- /config/semantickitti19_to_flexibility.yaml: -------------------------------------------------------------------------------- 1 | 0 : 255 # "unlabeled" 2 | 1 : 255 # "outlier" 3 | 10: 0 # "car" 4 | 11: 0 # "bicycle" 5 | 13: 0 # "bus" 6 | 15: 0 # "motorcycle" 7 | 16: 0 # "on-rails" 8 | 18: 0 # "truck" 9 | 20: 0 # "other-vehicle" 10 | 30: 0 # "person" 11 | 31: 0 # "bicyclist" 12 | 32: 0 # "motorcyclist" 13 | 40: 0 # "road" 14 | 44: 0 # "parking" 15 | 48: 0 # "sidewalk" 16 | 49: 0 # "other-ground" 17 | 50: 0 # "building" 18 | 51: 0 # "fence" 19 | 52: 0 # "other-structure" 20 | 60: 0 # "lane-marking" 21 | 70: 1 # "vegetation" 22 | 71: 0 # "trunk" 23 | 72: 0 # "terrain" 24 | 80: 0 # "pole" 25 | 81: 0 # "traffic-sign" 26 | 99: 0 # "other-object" 27 | 252: 0 # "moving-car" 28 | 253: 0 # "moving-bicyclist" 29 | 254: 0 # "moving-person" 30 | 255: 0 # "moving-motorcyclist" 31 | 256: 0 # "moving-on-rails" 32 | 257: 0 # "moving-bus" 33 | 258: 0 # "moving-truck" 34 | 259: 0 # "moving-other-vehicle" 35 | -------------------------------------------------------------------------------- /config/semantickitti19_to_traversability.yaml: -------------------------------------------------------------------------------- 1 | 0 : 255 # "unlabeled" 2 | 1 : 255 # "outlier" 3 | 10: 1 # "car" 4 | 11: 1 # "bicycle" 5 | 13: 1 # "bus" 6 | 15: 1 # "motorcycle" 7 | 16: 1 # "on-rails" 8 | 18: 1 # "truck" 9 | 20: 1 # "other-vehicle" 10 | 30: 1 # "person" 11 | 31: 1 # "bicyclist" 12 | 32: 1 # "motorcyclist" 13 | 40: 0 # "road" 14 | 44: 0 # "parking" 15 | 48: 0 # "sidewalk" 16 | 49: 0 # "other-ground" 17 | 50: 1 # "building" 18 | 51: 1 # "fence" 19 | 52: 1 # "other-structure" 20 | 60: 0 # "lane-marking" 21 | 70: 1 # "vegetation" 22 | 71: 1 # "trunk" 23 | 72: 0 # "terrain" 24 | 80: 1 # "pole" 25 | 81: 1 # "traffic-sign" 26 | 99: 1 # "other-object" 27 | 252: 1 # "moving-car" 28 | 253: 1 # "moving-bicyclist" 29 | 254: 1 # "moving-person" 30 | 255: 1 # "moving-motorcyclist" 31 | 256: 1 # "moving-on-rails" 32 | 257: 1 # "moving-bus" 33 | 258: 1 # "moving-truck" 34 | 259: 1 # 
"moving-other-vehicle" 35 | -------------------------------------------------------------------------------- /config/slam/icp.yaml: -------------------------------------------------------------------------------- 1 | matcher: 2 | KDTreeMatcher: 3 | knn: 3 4 | maxDist: 10.0 5 | epsilon: 0 6 | 7 | outlierFilters: 8 | - TrimmedDistOutlierFilter: 9 | ratio: 0.80 10 | - SurfaceNormalOutlierFilter: 11 | maxAngle: 1.57 12 | 13 | errorMinimizer: 14 | PointToPlaneErrorMinimizer: 15 | 16 | transformationCheckers: 17 | - DifferentialTransformationChecker: 18 | minDiffRotErr: 0.001 19 | minDiffTransErr: 0.01 20 | smoothLength: 2 21 | - CounterTransformationChecker: 22 | maxIterationCount: 100 23 | - BoundTransformationChecker: 24 | maxRotationNorm: 0.8 25 | maxTranslationNorm: 30.00 26 | 27 | inspector: 28 | NullInspector 29 | 30 | logger: 31 | FileLogger 32 | -------------------------------------------------------------------------------- /config/slam/input_filters.yaml: -------------------------------------------------------------------------------- 1 | - DistanceLimitDataPointsFilter: 2 | dim: -1 3 | dist: 0.5 4 | 5 | - SurfaceNormalDataPointsFilter: 6 | knn: 9 7 | epsilon: 0 8 | keepNormals: 1 9 | 10 | - ObservationDirectionDataPointsFilter 11 | 12 | - OrientNormalsDataPointsFilter: 13 | towardCenter: 1 14 | -------------------------------------------------------------------------------- /config/slam/map_post_filters.yaml: -------------------------------------------------------------------------------- 1 | - SurfaceNormalDataPointsFilter: 2 | knn: 9 3 | epsilon: 0 4 | keepNormals: 1 5 | 6 | - OrientNormalsDataPointsFilter: 7 | towardCenter: 1 8 | -------------------------------------------------------------------------------- /config/workspace.repos: -------------------------------------------------------------------------------- 1 | repositories: 2 | cras_msgs: 3 | type: git 4 | url: git@gitlab.fel.cvut.cz:cras/subt/common/cras_msgs.git 5 | version: master 6 | cras_ouster_description: 7 | type: git 8 | url: git@gitlab.fel.cvut.cz:cras/subt/sensors/ouster/cras_ouster_description.git 9 | version: master 10 | cras_ouster_driver: 11 | type: git 12 | url: git@gitlab.fel.cvut.cz:cras/subt/sensors/ouster/cras_ouster_driver.git 13 | version: master 14 | cras_ouster_msgs: 15 | type: git 16 | url: git@gitlab.fel.cvut.cz:cras/subt/sensors/ouster/cras_ouster_msgs.git 17 | version: master 18 | ctu_mapping: 19 | type: git 20 | url: https://github.com/norlab-ulaval/ctu_mapping.git 21 | version: husky 22 | libpointmatcher_ros: 23 | type: git 24 | url: https://github.com/norlab-ulaval/libpointmatcher_ros.git 25 | version: master 26 | marv_common: 27 | type: git 28 | url: https://gitlab.jettyvision.cz/usar_marv/marv_common.git 29 | version: master 30 | norlab_icp_mapper: 31 | type: git 32 | url: git@github.com:tpet/norlab_icp_mapper.git 33 | version: master 34 | norlab_icp_mapper_ros: 35 | type: git 36 | url: git@github.com:tpet/norlab_icp_mapper_ros.git 37 | version: master 38 | ouster_example: 39 | type: git 40 | url: https://github.com/ouster-lidar/ouster_example.git 41 | version: 47f25ed29ab0b4b6d32c3209fa1b53ea46751d0c 42 | ouster_ray_directions: 43 | type: git 44 | url: git@gitlab.fel.cvut.cz:cras/subt/sensors/ouster/ouster_ray_directions.git 45 | version: master 46 | point_cloud_color: 47 | type: git 48 | url: git@github.com:tpet/point_cloud_color.git 49 | version: master 50 | ros-utils: 51 | type: git 52 | url: git@gitlab.fel.cvut.cz:cras/subt/common/ros-utils.git 53 | version: master 54 | 
traversability_estimation: 55 | type: git 56 | url: git@github.com:RuslanAgishev/traversability_estimation.git 57 | version: main 58 | -------------------------------------------------------------------------------- /docker/jetson/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/l4t-ml:r32.7.1-py3 2 | 3 | 4 | RUN apt-get update && apt-get install -y \ 5 | ffmpeg \ 6 | build-essential \ 7 | libsm6 \ 8 | libxext6 \ 9 | libfontconfig1 \ 10 | libxrender1 \ 11 | libswscale-dev \ 12 | libtbb2 \ 13 | libtbb-dev \ 14 | libjpeg-dev \ 15 | libpng-dev \ 16 | libtiff-dev \ 17 | libavformat-dev \ 18 | libpq-dev \ 19 | libturbojpeg \ 20 | software-properties-common \ 21 | libboost-all-dev \ 22 | libssl-dev \ 23 | libgeos-dev \ 24 | wget \ 25 | nano \ 26 | sudo \ 27 | python3-matplotlib \ 28 | python3-opencv \ 29 | python3-tk \ 30 | && apt-get clean \ 31 | && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 32 | 33 | 34 | RUN pip3 install \ 35 | tqdm==4.62.3 \ 36 | yacs==0.1.6 \ 37 | #open3d==0.11.2 \ 38 | gnupg==2.3.1 \ 39 | configparser==5.2.0 \ 40 | psutil==5.8.0 \ 41 | rospkg \ 42 | empy 43 | 44 | COPY ./install_ros_melodic.sh /tmp/install_ros.sh 45 | RUN /tmp/install_ros.sh 46 | 47 | -------------------------------------------------------------------------------- /docker/jetson/Makefile: -------------------------------------------------------------------------------- 1 | APP_NAME=ml:traversability 2 | CONTAINER_NAME=traversability 3 | 4 | build: 5 | docker build -t $(APP_NAME) -f Dockerfile . 6 | 7 | inference: 8 | docker run -it --rm --runtime nvidia --network host \ 9 | -v /mnt:/mnt \ 10 | -v /home/robot/:/home/robot \ 11 | $(APP_NAME) bash 12 | 13 | stop: ## Stop and remove a running container 14 | docker stop ${CONTAINER_NAME}; docker rm ${CONTAINER_NAME} 15 | 16 | -------------------------------------------------------------------------------- /docker/jetson/install_ros_melodic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install Robot Operating System (ROS) on NVIDIA Jetson Developer Kit 4 | # Maintainer of ARM builds for ROS is http://answers.ros.org/users/1034/ahendrix/ 5 | # Information from: 6 | # http://wiki.ros.org/melodic/Installation/UbuntuARM 7 | 8 | # Red is 1 9 | # Green is 2 10 | # Reset is sgr0 11 | 12 | function usage 13 | { 14 | echo "Usage: ./installROS.sh [[-p package] | [-h]]" 15 | echo "Install ROS Melodic" 16 | echo "Installs ros-melodic-ros-base as default base package; Use -p to override" 17 | echo "-p | --package ROS package to install" 18 | echo " Multiple usage allowed" 19 | echo " Must include one of the following:" 20 | echo " ros-melodic-ros-base" 21 | echo " ros-melodic-desktop" 22 | echo " ros-melodic-desktop-full" 23 | echo "-h | --help This message" 24 | } 25 | 26 | function shouldInstallPackages 27 | { 28 | tput setaf 1 29 | echo "Your package list did not include a recommended base package" 30 | tput sgr0 31 | echo "Please include one of the following:" 32 | echo " ros-melodic-ros-base" 33 | echo " ros-melodic-desktop" 34 | echo " ros-melodic-desktop-full" 35 | echo "" 36 | echo "ROS not installed" 37 | } 38 | 39 | # Iterate through command line inputs 40 | packages=() 41 | while [ "$1" != "" ]; do 42 | case $1 in 43 | -p | --package ) shift 44 | packages+=("$1") 45 | ;; 46 | -h | --help ) usage 47 | exit 48 | ;; 49 | * ) usage 50 | exit 1 51 | esac 52 | shift 53 | done 54 | # Check to see if other packages were specified 55 | # If 
not, set the default base package 56 | if [ ${#packages[@]} -eq 0 ] ; then 57 | packages+="ros-melodic-ros-base" 58 | fi 59 | echo "Packages to install: "${packages[@]} 60 | # Check to see if we have a ROS base kinda thingie 61 | hasBasePackage=false 62 | for package in "${packages[@]}"; do 63 | if [[ $package == "ros-melodic-ros-base" ]]; then 64 | hasBasePackage=true 65 | break 66 | elif [[ $package == "ros-melodic-desktop" ]]; then 67 | hasBasePackage=true 68 | break 69 | elif [[ $package == "ros-melodic-desktop-full" ]]; then 70 | hasBasePackage=true 71 | break 72 | fi 73 | done 74 | if [ $hasBasePackage == false ] ; then 75 | shouldInstallPackages 76 | exit 1 77 | fi 78 | 79 | # Let's start installing! 80 | 81 | tput setaf 2 82 | echo "Adding repository and source list" 83 | tput sgr0 84 | sudo apt-add-repository universe 85 | sudo apt-add-repository multiverse 86 | sudo apt-add-repository restricted 87 | 88 | # Setup sources.lst 89 | sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' 90 | # Setup keys 91 | sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 92 | # If you experience issues connecting to the keyserver, you can try substituting hkp://pgp.mit.edu:80 or hkp://keyserver.ubuntu.com:80 in the previous command. 93 | # Installation 94 | tput setaf 2 95 | echo "Updating apt-get" 96 | tput sgr0 97 | sudo apt-get update 98 | tput setaf 2 99 | echo "Installing ROS" 100 | tput sgr0 101 | # This is where you might start to modify the packages being installed, i.e. 102 | # sudo apt-get install ros-melodic-desktop 103 | 104 | # Here we loop through any packages passed on the command line 105 | # Install packages ... 106 | for package in "${packages[@]}"; do 107 | sudo apt-get install $package -y 108 | done 109 | 110 | # Add Individual Packages here 111 | # You can install a specific ROS package (replace underscores with dashes of the package name): 112 | # sudo apt-get install ros-melodic-PACKAGE 113 | # e.g. 114 | # sudo apt-get install ros-melodic-navigation 115 | # 116 | # To find available packages: 117 | # apt-cache search ros-melodic 118 | # 119 | # Initialize rosdep 120 | tput setaf 2 121 | echo "Installing rosdep" 122 | tput sgr0 123 | sudo apt-get install python-rosdep -y 124 | # Certificates are messed up on earlier version Jetson for some reason 125 | # sudo c_rehash /etc/ssl/certs 126 | # Initialize rosdep 127 | tput setaf 2 128 | echo "Initializaing rosdep" 129 | tput sgr0 130 | sudo rosdep init 131 | # To find available packages, use: 132 | rosdep update 133 | # Environment Setup - Don't add /opt/ros/melodic/setup.bash if it's already in bashrc 134 | grep -q -F 'source /opt/ros/melodic/setup.bash' ~/.bashrc || echo "source /opt/ros/melodic/setup.bash" >> ~/.bashrc 135 | source ~/.bashrc 136 | # Install rosinstall 137 | tput setaf 2 138 | echo "Installing rosinstall tools" 139 | tput sgr0 140 | sudo apt-get install python-rosinstall python-rosinstall-generator python-wstool build-essential -y 141 | tput setaf 2 142 | echo "Installation complete!" 
143 | echo "Please setup your Catkin Workspace" 144 | #tput sgr0 145 | 146 | -------------------------------------------------------------------------------- /docker/jetson/install_ros_noetic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' 4 | 5 | apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 6 | 7 | # https://gist.github.com/Pyrestone/ef683aec160825eee5c252f22218ddb2 8 | apt-get update 9 | apt-get install python3-rosdep python3-rosinstall-generator python3-vcstool build-essential python3-empy libconsole-bridge-dev libpoco-dev libtinyxml-dev qtbase5-dev liborocos-kdl-dev -y 10 | 11 | rosdep init 12 | rosdep update 13 | 14 | mkdir -p /opt/ros/ros_catkin_ws 15 | cd /opt/ros/ros_catkin_ws && \ 16 | rosinstall_generator robot perception --rosdistro noetic --deps --tar > noetic-robot-perception.rosinstall 17 | 18 | mkdir -p /opt/ros/ros_catkin_ws/src && \ 19 | cd /opt/ros/ros_catkin_ws/ && \ 20 | vcs import --input noetic-robot-perception.rosinstall ./src && \ 21 | rosdep install --from-paths ./src --ignore-packages-from-source --rosdistro noetic -y 22 | 23 | cd /opt/ros/ros_catkin_ws/ && \ 24 | ./src/catkin/bin/catkin_make_isolated --install -DCMAKE_BUILD_TYPE=Release -DPYTHON_EXECUTABLE=/usr/bin/python3 --install-space /opt/ros/noetic 25 | 26 | source /opt/ros/ros_catkin_ws/install_isolated/setup.bash 27 | 28 | -------------------------------------------------------------------------------- /docs/cloud_flex_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/cloud_flex_gt.png -------------------------------------------------------------------------------- /docs/cloud_trav_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/cloud_trav_gt.png -------------------------------------------------------------------------------- /docs/colored_pc_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/colored_pc_demo.png -------------------------------------------------------------------------------- /docs/colored_pc_demo_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/colored_pc_demo_pred.png -------------------------------------------------------------------------------- /docs/colored_pc_demo_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/colored_pc_demo_rgb.png -------------------------------------------------------------------------------- /docs/docker.md: -------------------------------------------------------------------------------- 1 | ## Docker 2 | 3 | First of all install Docker: 4 | ``` 5 | sudo apt install docker.io 6 | ``` 7 | After that install [nvidia-docker v2.0](): 8 | ``` 9 | sudo apt-get 
install nvidia-docker2
10 | sudo pkill -SIGHUP dockerd
11 | ```
12 | 
13 | ### Docker image with `pytorch` and `ROS melodic` on Jetson
14 | 
15 | ```bash
16 | git checkout jetson
17 | cd ../traversability_estimation/docker/jetson/
18 | ```
19 | 
20 | Build the Docker image:
21 | ```bash
22 | make build
23 | ```
24 | 
25 | Run the Docker container:
26 | ```bash
27 | make inference
28 | ```
29 | 
--------------------------------------------------------------------------------
/docs/install.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 | 
3 | Prerequisites:
4 | - install [ROS](http://wiki.ros.org/ROS/Installation).
5 | - install [PyTorch](https://pytorch.org).
6 | - install [torchvision](https://pytorch.org/vision/stable/index.html).
7 | 
8 | If you want to use only the semantic cloud segmentation node, just build the package in a catkin workspace, for example:
9 | 
10 | ```bash
11 | mkdir -p ~/catkin_ws/src
12 | cd ~/catkin_ws/src
13 | git clone https://github.com/ctu-vras/traversability_estimation
14 | git clone https://github.com/ctu-vras/cloud_proc
15 | cd ~/catkin_ws/
16 | rosdep install --from-paths src --ignore-src --rosdistro noetic -y
17 | catkin build
18 | ```
19 | 
20 | In case you would like to run the geometric cloud segmentation, traversability fusion, or image segmentation to point cloud projection nodes,
21 | please follow the extended procedure (it requires access to additional repositories):
22 | 
23 | - Install [vcstool](http://wiki.ros.org/vcstool) for workspace creation:
24 | ```bash
25 | sudo apt install python3-vcstool
26 | ```
27 | - Create and build the ROS workspace:
28 | ```bash
29 | cd ~/catkin_ws/
30 | vcs import src < src/traversability_estimation/config/workspace.repos
31 | catkin config -DCMAKE_BUILD_TYPE=Release
32 | catkin build
33 | ```
34 | 
35 | Put the [weights](http://subtdata.felk.cvut.cz/robingas/data/traversability_estimation/weights/)
36 | into the [./config/weights/](./config/weights/) folder:
37 | 
38 | ```bash
39 | ./config/weights/
40 | ├── hrnetv2_w48_imagenet_pretrained.pth
41 | ├── seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484/
42 | ├── depth_cloud/
43 | └── smp/
44 |     └── se_resnext50_32x4d_352x640_lr1e-4.pth
45 | ```
46 | 
47 | You may also download the datasets to train image and point cloud segmentation models.
48 | Please refer to [./docs/rellis.md](./docs/rellis.md) or [./docs/trav_data.md](./docs/trav_data.md) for examples.
49 | 
--------------------------------------------------------------------------------
/docs/rellis.md:
--------------------------------------------------------------------------------
1 | ## RELLIS-3D Dataset
2 | 
3 | A multimodal dataset collected in an off-road environment, containing semantic segmentation annotations
4 | for 13,556 LiDAR scans and 6,235 images.
5 | The data comes in ROS bag format, including RGB camera images, LiDAR point clouds, a pair of stereo images,
6 | high-precision GPS measurements, and IMU data.
7 | 
8 | ### Format
9 | 
10 | - Go to the dataset [webpage](https://unmannedlab.github.io/research/RELLIS-3D).
11 | - Download the data to the path [traversability_estimation/data](../data).
12 | - Extract the zip files in order to have the following layout on disk:
13 | 
14 | ```bash
15 | ├─ Rellis_3D
16 | ├── 00000
17 | │   ├── os1_cloud_node_color_ply
18 | │   ├── os1_cloud_node_kitti_bin
19 | │   ├── os1_cloud_node_semantickitti_label_id
20 | │   ├── pylon_camera_node
21 | │   ├── pylon_camera_node_label_color
22 | │   └── pylon_camera_node_label_id
23 | ...
24 | └── calibration
25 | ├── 00000
26 | ...
27 | └── raw_data
28 | ```
29 | 
30 | Run the demo to explore data samples (assuming that the package is built):
31 | ```bash
32 | python -m datasets.rellis_3d
33 | ```
34 | See [rellis_3d.py](../src/datasets/rellis_3d.py) for more details.
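
For programmatic access outside ROS, here is a minimal sketch of loading one labelled sample; it mirrors the usage in [smp_demo.py](../notebooks/smp_demo.py) (the `Rellis3DImages` constructor arguments are copied from that demo and may differ in your checkout):

```python
# Minimal sketch: load one image/label pair from the extracted Rellis_3D data.
# Mirrors notebooks/smp_demo.py; assumes the package's src/ dir is on PYTHONPATH.
import numpy as np
from datasets.rellis_3d import Rellis3DImages

dataset = Rellis3DImages(split='test', crop_size=(704, 960))
image, gt_mask = dataset[0][:2]         # image: HxWx3, gt_mask: one-hot CxHxW
label_ids = np.argmax(gt_mask, axis=0)  # per-pixel class indices
print(image.shape, label_ids.shape)
```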
35 | 
36 | ### ROS and RELLIS-3D
37 | 
38 | Publish the RELLIS-3D data as ROS messages:
39 | 
40 | ```bash
41 | roslaunch traversability_estimation robot_data.launch data_sequence:='00000' rviz:=True
42 | ```
43 | 
--------------------------------------------------------------------------------
/docs/rgb_sem_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/rgb_sem_gt.png
--------------------------------------------------------------------------------
/docs/rgb_sem_gt4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/rgb_sem_gt4.png
--------------------------------------------------------------------------------
/docs/segmentation_labels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/segmentation_labels.png
--------------------------------------------------------------------------------
/docs/segmented_pc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/segmented_pc.png
--------------------------------------------------------------------------------
/docs/semantic_traversability_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/docs/semantic_traversability_pipeline.png
--------------------------------------------------------------------------------
/docs/trav_data.md:
--------------------------------------------------------------------------------
1 | ## Traversability dataset
2 | 
3 | [Traversability Dataset](http://subtdata.felk.cvut.cz/robingas/data/traversability_estimation/TraversabilityDataset/):
4 | 
5 | - Point clouds (151 data samples) and RGB images (250 data samples) with segmentation labels (traversable vs. non-traversable area).
6 | The data is labelled manually by a human annotator.
7 | 
8 | Forest environment | Town environment
9 | :-------------------------:|:-------------------------:
10 | ![](./rgb_sem_gt.png) | ![](./rgb_sem_gt4.png)
11 | 
12 | ![](./cloud_trav_gt.png)
13 | 
14 | - Point clouds with self-supervised annotations of the traversable area derived from the robot's trajectories (10,162 data samples).
15 | The data is labelled automatically using robot pose estimation (obtained with LiDAR SLAM),
16 | which enables learning the traversability model directly in the field (e.g. while driving through grass):
17 | terrain under the robot is designated as solid, and the labels are refined according to the local geometry.
18 | 
19 | ![](./cloud_flex_gt.png)
20 | 
21 | Download the data to the path [traversability_estimation/data](../data).
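
To make the self-supervised labelling concrete, here is a conceptual sketch: map points that lie within the robot's footprint radius of any SLAM trajectory pose are marked as traversed. This is an illustration only, not the package's actual implementation (see the `find_traversed_points` node under `scripts/nodes/`); the 0.5 m default matches the `robot_radius` parameter used in `traversability_evaluation.launch`.

```python
# Conceptual sketch of self-supervised traversability labelling:
# map points within the robot's footprint radius of any trajectory pose
# (estimated by LiDAR SLAM) are labelled as traversed, hence traversable.
import numpy as np
from scipy.spatial import cKDTree

def label_traversed_points(cloud_xyz, trajectory_xyz, robot_radius=0.5):
    """cloud_xyz: (N, 3) map points; trajectory_xyz: (M, 3) robot positions."""
    tree = cKDTree(trajectory_xyz)
    dist, _ = tree.query(cloud_xyz, k=1)  # distance to the nearest robot pose
    labels = np.zeros(len(cloud_xyz), dtype=np.uint8)  # 0: unknown / untraversed
    labels[dist <= robot_radius] = 1                   # 1: traversed -> traversable
    return labels
```

A real implementation would additionally check the height band swept by the robot body and the local surface geometry.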
22 | 23 | Run the demo to explore data samples (assuming that the package is built): 24 | ```bash 25 | python -m datasets.traversability_dataset 26 | ``` 27 | 28 | See [traversability_dataset.py](../src/datasets/traversability_dataset.py) for more details. 29 | -------------------------------------------------------------------------------- /launch/base_footprint.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | 13 | -------------------------------------------------------------------------------- /launch/cloud_filter.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 20 | 21 | min_x: -0.5 22 | max_x: 0.5 23 | min_y: -0.3 24 | max_y: 0.3 25 | min_z: -0.14 26 | max_z: 0.9 27 | negative: true 28 | keep_organized: true 29 | input_frame: base_link 30 | output_frame: os_sensor 31 | receive: 32 | rate: 33 | desired: 10 34 | delay: 35 | min: 0.1 36 | max: 0.2 37 | publish: 38 | rate: 39 | desired: 10 40 | delay: 41 | min: 0.1 42 | max: 0.2 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 53 | 54 | min_x: -0.49 55 | max_x: 0.49 56 | min_y: -0.3354 57 | max_y: 0.3354 58 | min_z: -0.13228 59 | max_z: 0.62818 60 | 61 | 62 | min_x: -0.65 63 | max_x: 0.65 64 | min_y: -0.30 65 | max_y: 0.30 66 | min_z: -0.14 67 | max_z: 0.8 68 | 69 | 70 | min_x: -0.56 71 | max_x: 0.56 72 | min_y: -0.26 73 | max_y: 0.26 74 | min_z: -0.6 75 | max_z: 0.5 76 | 77 | 78 | negative: true 79 | keep_organized: true 80 | input_frame: base_link 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /launch/cloud_fusion.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 8 | 9 | 10 | 11 | 12 | data_sequence: $(arg data_sequence) 13 | map_step: 1 14 | pose_step: 10 15 | lidar_frame: 'ouster_lidar' 16 | camera_frame: 'pylon_camera' 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | device: $(arg device) 32 | max_age: 0.2 33 | lidar_channels: 64 34 | lidar_beams: 2048 35 | lidar_fov_up: 22.5 36 | lidar_fov_down: -22.5 37 | range_projection: true 38 | debug: true 39 | weights: deeplabv3_resnet101_lr_0.0001_bs_8_epoch_90_TraversabilityClouds_depth_labels_traversability_iou_0.972.pth 40 | 41 | 42 | 43 | 44 | 45 | 46 | 49 | 50 | field: x 51 | grid: 0.1 52 | zero_valid: false 53 | 54 | 55 | 56 | 57 | 58 | 61 | 62 | field: x 63 | grid: 0.2 64 | zero_valid: false 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | map_frame: odom 73 | max_age: 1.0 74 | pts_proximity_th: 0.2 75 | label_to_fuse: cost 76 | fusion_mode: $(arg fusion_mode) 77 | 78 | 79 | 80 | 81 | 82 | 83 | 86 | 87 | field: x 88 | grid: 0.2 89 | zero_valid: false 90 | 91 | 92 | 93 | 94 | 95 | 96 | 98 | 99 | -------------------------------------------------------------------------------- /launch/cloud_ortho_stats.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 28 | 29 | extent: $(arg extent) 30 | size: $(arg size) 31 | grid: $(arg grid) 32 | mode: $(arg mode) 33 | output_z: $(arg output_z) 34 | eigenvalues: $(arg eigenvalues) 35 | target_frame: '$(arg target_frame)' 36 | use_only_orientation: $(arg use_only_orientation) 37 | min_z: $(arg min_z) 38 | max_z: $(arg max_z) 39 | zero_valid: false 40 | timeout: 0.2 41 | 42 | 43 | 
44 | 45 | 46 | 47 | 49 | 50 | field: z 51 | scale: 32.768 52 | offset: 32.768 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /launch/cloud_projection.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | height: 128 11 | width: 1024 12 | fov_azimuth: 6.283185307179586 13 | fov_elevation: 1.5707963267948966 14 | keep: 2 15 | azimuth_only: false 16 | frame: odom 17 | timeout: 0.2 18 | 19 | 20 | 21 | 22 | 23 | 25 | 26 | negative: true 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /launch/ctu_robot.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | model_name: hrnet 11 | device: $(arg device) 12 | dtype: float 13 | num_cameras: 3 14 | image_transport: compressed 15 | legend: false 16 | max_age: 1.0 17 | input_scale: 0.25 18 | traversability_labels: false 19 | label_config: $(dirname)/../config/obstacles.yaml 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 51 | 52 | fixed_frame: odom 53 | 54 | 55 | 56 | field_name: obstacle 57 | field_type: 2 58 | default_color: 1.0 59 | num_cameras: 3 60 | image_transport: compressed 61 | max_image_age: 15.0 62 | use_first_valid: true 63 | image_queue_size: 2 64 | cloud_queue_size: 2 65 | wait_for_transform: 1.0 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /launch/demo_cloud.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 31 | 32 | 33 | 34 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 46 | 47 | min_x: -0.5 48 | max_x: 0.5 49 | min_y: -0.3 50 | max_y: 0.3 51 | min_z: -0.14 52 | max_z: 0.6 53 | negative: true 54 | keep_organized: true 55 | input_frame: base_link 56 | output_frame: os_sensor 57 | receive: 58 | rate: 59 | desired: 10 60 | delay: 61 | min: 0.1 62 | max: 0.2 63 | publish: 64 | rate: 65 | desired: 10 66 | delay: 67 | min: 0.1 68 | max: 0.2 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 88 | 89 | 91 | 92 | -------------------------------------------------------------------------------- /launch/fused_traversability.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | fixed_frame: gps_odom 38 | trigger: geometric 39 | sync: false 40 | max_time_diff: 1.0 41 | dist_th: 0.25 42 | flat_cost_th: 0.5 43 | obstacle_cost_th: 0.9 44 | semantic_cost_offset: 0.5 45 | timeout: 0.5 46 | rate: 0.0 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 57 | 58 | -------------------------------------------------------------------------------- /launch/generate_points.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 18 | 19 | 20 | 21 | 22 | 23 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /launch/generate_trav_src.launch: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 54 | 55 | 57 | 58 | 59 | 60 | 63 | 64 | -------------------------------------------------------------------------------- /launch/geometric_cloud_segmentation.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 37 | 38 | 40 | 41 | 42 | 43 | 44 | max_age: 0.5 45 | fixed_frame: map 46 | z_support: 47 | range: [0.6, 8.0] 48 | grid: 0.05 49 | scale: 0.05 50 | radius: 0.05 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /launch/image_segmentation_dataset_demo.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | model_name: $(arg model_name) 23 | smp_weights: $(arg smp_weights) 24 | hrnet_weights: $(arg hrnet_weights) 25 | device: $(arg device) 26 | num_cameras: 1 27 | image_transport: 'compressed' 28 | legend: false 29 | max_age: 1.0 30 | input_scale: 0.5 31 | traversability_labels: false 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /launch/marv_robot.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | model_name: hrnet 11 | device: $(arg device) 12 | dtype: float 13 | num_cameras: 3 14 | image_transport: compressed 15 | legend: false 16 | max_age: 1.0 17 | input_scale: 0.25 18 | traversability_labels: false 19 | label_config: $(dirname)/../config/obstacles.yaml 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 51 | 52 | fixed_frame: odom 53 | 54 | 55 | 56 | field_name: obstacle 57 | field_type: 2 58 | default_color: 1.0 59 | num_cameras: 3 60 | image_transport: compressed 61 | max_image_age: 15.0 62 | use_first_valid: true 63 | image_queue_size: 2 64 | cloud_queue_size: 2 65 | wait_for_transform: 1.0 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /launch/orient_frame.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | parent_frame: $(arg parent_frame) 12 | child_frame: $(arg child_frame) 13 | oriented_frame: $(arg oriented_frame) 14 | align: z 15 | timeout: 0.1 16 | timeout_relative: false 17 | trigger_queue_size: 2 18 | tf_queue_size: 5 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /launch/play_bag.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 9 | 11 | 12 | 15 | 17 | 19 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /launch/robot_data.launch: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | data_sequence: $(arg data_sequence) 10 | pose_step: $(arg pose_step) 11 | lidar_frame: 'ouster_lidar' 12 | camera_frame: 'pylon_camera' 13 | 14 | 15 | 16 | 17 | 19 | 20 | -------------------------------------------------------------------------------- /launch/semantic_traversability_tconcord3d.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 25 | 26 | 27 | 28 | 29 | 30 | device: $(arg device) 31 | max_age: $(arg max_age) 32 | weights: $(arg weights) 33 | cloud_in: $(arg input) 34 | cloud_out: cloud_segmentation_tconcord3d/points 35 | > 36 | 37 | 38 | 39 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /launch/show_trav_data.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 28 | 29 | 32 | 33 | -------------------------------------------------------------------------------- /launch/slam.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 9 | 10 | odom_frame: $(arg odom_frame) 11 | map_frame: $(arg map_frame) 12 | 13 | robot_frame: base_link 14 | initial_map_file_name: '' 15 | initial_robot_pose: '' 16 | final_map_file_name: $(dirname)/../map.vtk 17 | final_trajectory_file_name: $(dirname)/../trajectory.vtk 18 | 19 | 20 | 21 | input_filters_config: $(find ctu_mapping)/params/realtime_input_filters.yaml 22 | icp_config: $(find ctu_mapping)/params/realtime_icp_config.yaml 23 | map_post_filters_config: $(find ctu_mapping)/params/realtime_post_filters.yaml 24 | map_update_condition: overlap 25 | map_update_overlap: 0.9 26 | map_update_delay: 0.0 27 | map_update_distance: 0.0 28 | map_publish_rate: 1.0 29 | map_tf_publish_rate: 0.0 30 | max_idle_time: 10 31 | min_dist_new_point: 0.1 32 | sensor_max_range: 25 33 | prior_dynamic: 0.6 34 | threshold_dynamic: 0.9 35 | beam_half_angle: 0.01 36 | epsilon_a: 0.01 37 | epsilon_d: 0.01 38 | alpha: 0.8 39 | beta: 0.99 40 | is_3D: true 41 | is_online: true 42 | compute_prob_dynamic: true 43 | is_mapping: true 44 | save_map_cells_on_hard_drive: false 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /launch/traversability_bag_demo.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /launch/traversability_dataset_demo.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 23 | 24 | 25 | 26 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /launch/traversability_evaluation.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 
| 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | robot_frame: base_link 46 | fixed_frame: odom 47 | robot_radius: 0.5 48 | horizon_time: 3.0 49 | horizon_step: 1.0 50 | cloud_in: graph_pcd 51 | cloud_out: traversed_points/points 52 | bag: $(arg bag) 53 | 54 | 55 | 56 | 57 | 58 | 59 | map_frame: odom 60 | max_age: 0.2 61 | pts_proximity_th: 0.5 62 | label_to_fuse: 'untrav_cost' 63 | fusion_mode: $(arg trav_fusion_mode) 64 | 65 | 66 | 67 | 68 | 69 | 70 | 72 | 73 | -------------------------------------------------------------------------------- /launch/video/record_video.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 16 | 17 | use_scale: false 18 | height: $(arg height) 19 | width: $(arg width) 20 | 21 | 22 | 23 | 24 | 25 | 27 | 28 | filename: $(arg filename) 29 | fps: $(arg fps) 30 | codec: $(arg codec) 31 | encoding: $(arg encoding) 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /notebooks/smp_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import yaml 5 | from hrnet.core.function import convert_label, convert_color 6 | import numpy as np 7 | from datasets.utils import visualize 8 | from datasets.rellis_3d import Rellis3DImages as Dataset 9 | 10 | 11 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | model = torch.load('../config/weights/smp/PSPNet_resnext50_32x4d_704x960_lr0.0001_bs6_epoch18_Rellis3D_iou_0.73.pth') 13 | model = model.to(DEVICE) 14 | model = model.eval() 15 | 16 | # prepare data 17 | test_dataset = Dataset(split='test', crop_size=(704, 960)) 18 | 19 | data_cfg = "../config/rellis.yaml" 20 | CFG = yaml.safe_load(open(data_cfg, 'r')) 21 | id_color_map = CFG["color_map"] 22 | 23 | with torch.no_grad(): 24 | image, gt_mask = test_dataset[0][:2] 25 | x = torch.from_numpy(image.transpose([2, 0, 1])).unsqueeze(0).to(DEVICE) 26 | 27 | pred = model(x) 28 | pred_np = pred.cpu().numpy().squeeze(0) 29 | 30 | pred_arg = np.argmax(pred_np, axis=0).astype(np.uint8) - 1 31 | pred_arg = convert_label(pred_arg, inverse=True) 32 | pred_color = convert_color(pred_arg, id_color_map) 33 | 34 | gt_arg = np.argmax(gt_mask, axis=0).astype(np.uint8) - 1 35 | gt_arg = convert_label(gt_arg, inverse=True) 36 | gt_color = convert_color(gt_arg, id_color_map) 37 | 38 | visualize(prediction=pred_color, gt=gt_color) 39 | -------------------------------------------------------------------------------- /package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | traversability_estimation 4 | 0.0.1 5 | Semantic Segmentation of Images and Point Clouds for Traversability Estimation 6 | https://github.com/RuslanAgishev/traversability_estimation 7 | Ruslan Agishev 8 | Tomas Petricek 9 | 10 | BSD 11 | 12 | catkin 13 | 14 | geometry_msgs 15 | nav_msgs 16 | sensor_msgs 17 | std_msgs 18 | tf2_ros 19 | ros_numpy 20 | 21 | message_generation 22 | message_generation 23 | message_runtime 24 | 25 | -------------------------------------------------------------------------------- /scripts/nodes/cloud_to_depth: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import 
absolute_import, division, print_function 3 | import cv2 as cv 4 | import numpy as np 5 | from numpy.lib.recfunctions import structured_to_unstructured 6 | from ros_numpy import msgify, numpify 7 | import rospy 8 | from sensor_msgs.msg import CompressedImage, Image, PointCloud2 9 | 10 | 11 | class CloudToDepth: 12 | def __init__(self): 13 | self.negative = rospy.get_param('~negative', False) 14 | self.image_pub = rospy.Publisher('image', Image, queue_size=2) 15 | self.compressed_pub = rospy.Publisher(self.image_pub.resolved_name + '/compressed', CompressedImage, queue_size=2) 16 | self.cloud_sub = rospy.Subscriber('cloud', PointCloud2, self.on_cloud, queue_size=2) 17 | 18 | def on_cloud(self, cloud_msg): 19 | if self.image_pub.get_num_connections() == 0 and self.compressed_pub.get_num_connections() == 0: 20 | return 21 | cloud = numpify(cloud_msg) 22 | cloud = structured_to_unstructured(cloud[['x', 'y', 'z']]) 23 | depth = 1000.0 * np.linalg.norm(cloud, 2, axis=-1) 24 | depth = depth.clip(np.iinfo(np.uint16).min, np.iinfo(np.uint16).max) 25 | depth = depth.astype(np.uint16) 26 | if self.negative: 27 | depth[depth > 0] = 2**16 - depth[depth > 0] 28 | if self.image_pub.get_num_connections(): 29 | depth_msg = msgify(Image, depth, 'mono16') 30 | depth_msg.header = cloud_msg.header 31 | self.image_pub.publish(depth_msg) 32 | if self.compressed_pub.get_num_connections(): 33 | compressed_msg = CompressedImage() 34 | compressed_msg.header = cloud_msg.header 35 | compressed_msg.format = 'mono16; png compressed' 36 | compressed_msg.data = cv.imencode('.png', depth, [cv.IMWRITE_PNG_COMPRESSION, 5])[1].tobytes() 37 | self.compressed_pub.publish(compressed_msg) 38 | 39 | 40 | def main(): 41 | rospy.init_node('cloud_to_depth', log_level=rospy.INFO) 42 | node = CloudToDepth() 43 | rospy.spin() 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /scripts/nodes/geometric_cloud_segmentation: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import, division, print_function 3 | import numpy as np 4 | from numpy.lib.recfunctions import merge_arrays, unstructured_to_structured 5 | from ros_numpy import msgify, numpify 6 | import rospy 7 | from sensor_msgs.msg import PointCloud2 8 | from tf2_ros import Buffer, TransformException, TransformListener 9 | from traversability_estimation.segmentation import filter_grid, filter_range, compute_rigid_support 10 | 11 | 12 | class GeometricCloudSegmentation(object): 13 | 14 | def __init__(self): 15 | self.max_age = rospy.get_param('~max_age', 0.2) 16 | rospy.loginfo('Max cloud age: %.3f s', self.max_age) 17 | 18 | self.fixed_frame = rospy.get_param('~fixed_frame', 'map') 19 | rospy.loginfo('Fixed frame: %s', self.fixed_frame) 20 | 21 | self.range = rospy.get_param('~range', None) 22 | if self.range is not None: 23 | rospy.loginfo('Range: [%.3f m, %.3f m]', *self.range) 24 | 25 | self.grid = rospy.get_param('~grid', None) 26 | if self.grid is not None: 27 | rospy.loginfo('Grid: %.3f m', self.grid) 28 | 29 | self.z_support = rospy.get_param('~z_support', None) 30 | if self.z_support is not None: 31 | if 'scale' in self.z_support: 32 | scale = self.z_support['scale'] 33 | if isinstance(scale, (float, int)): 34 | scale = np.array([1.0, 1.0, scale]) 35 | scale = scale.reshape((-1, 3)) 36 | self.z_support['scale'] = scale 37 | rospy.loginfo('Z support: %s', self.z_support) 38 | 39 | self.tf = 
Buffer(rospy.Duration.from_sec(10.0)) 40 | self.tf_sub = TransformListener(self.tf) 41 | 42 | self.output_pub = rospy.Publisher('output', PointCloud2, queue_size=1) 43 | self.input_sub = rospy.Subscriber('input', PointCloud2, self.on_cloud, queue_size=1) 44 | 45 | def preprocess(self, input): 46 | output = input 47 | if self.range is not None: 48 | output = filter_range(output, *self.range) 49 | if self.grid is not None: 50 | output = filter_grid(output, self.grid) 51 | return output 52 | 53 | def compute_features(self, input, input_to_fixed=None): 54 | features = {} 55 | 56 | if self.z_support is not None: 57 | z_support = compute_rigid_support(input, transform=input_to_fixed, **self.z_support) 58 | features['z_support'] = z_support 59 | 60 | arrays = [] 61 | for name in sorted(features): 62 | f = features[name][0] 63 | f = f.flatten().reshape((input.size, -1)) 64 | array = unstructured_to_structured(f, names=[name]) 65 | array = array.reshape(input.shape) 66 | arrays.append(array) 67 | 68 | output = merge_arrays([input] + arrays, flatten=True) 69 | 70 | return output 71 | 72 | def on_cloud(self, input_msg): 73 | if (rospy.Time.now() - input_msg.header.stamp).to_sec() > self.max_age: 74 | return 75 | 76 | try: 77 | input_to_fixed = self.tf.lookup_transform(self.fixed_frame, 78 | input_msg.header.frame_id, input_msg.header.stamp, 79 | timeout=rospy.Duration.from_sec(1.0)) 80 | except TransformException as ex: 81 | rospy.logwarn('Could not transform input: %s', ex) 82 | return 83 | 84 | input_to_fixed = numpify(input_to_fixed.transform) 85 | input = numpify(input_msg) 86 | output = self.preprocess(input) 87 | output = self.compute_features(output, input_to_fixed=input_to_fixed) 88 | output_msg = msgify(PointCloud2, output) 89 | output_msg.header = input_msg.header 90 | self.output_pub.publish(output_msg) 91 | 92 | 93 | def main(): 94 | rospy.init_node('geometric_cloud_segmentation', log_level=rospy.INFO) 95 | node = GeometricCloudSegmentation() 96 | rospy.spin() 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /scripts/nodes/global_map: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import absolute_import, division, print_function 4 | import rospy 5 | from sensor_msgs.msg import PointCloud2 6 | import tf2_ros 7 | import torch 8 | import numpy as np 9 | from ros_numpy import msgify, numpify 10 | from numpy.lib.recfunctions import structured_to_unstructured, unstructured_to_structured 11 | 12 | 13 | def to_cloud_msg(cloud, time_stamp, frame): 14 | # publish point cloud 15 | if cloud.dtype.names is None: 16 | cloud = unstructured_to_structured(cloud[:, :3], names=['x', 'y', 'z']) 17 | pc_msg = msgify(PointCloud2, cloud) 18 | pc_msg.header.stamp = time_stamp 19 | pc_msg.header.frame_id = frame 20 | return pc_msg 21 | 22 | 23 | class PointsProcessor: 24 | def __init__(self, pc_topic='/points'): 25 | self.odom_frame = rospy.get_param('~odom_frame', 'odom') 26 | self.pc_frame = None 27 | self.clouds = [] 28 | if torch.cuda.is_available(): 29 | self.device = torch.device("cuda:0") 30 | torch.cuda.set_device(self.device) 31 | else: 32 | self.device = torch.device("cpu") 33 | 34 | self.pc_topic = rospy.get_param('~pointcloud_topic', pc_topic) 35 | rospy.loginfo("Subscribed to " + self.pc_topic) 36 | pc_sub = rospy.Subscriber(self.pc_topic, PointCloud2, self.pc_callback) 37 | self.tf = tf2_ros.Buffer() 38 | self.tl =
tf2_ros.TransformListener(self.tf) 39 | 40 | self.map_pc_pub = rospy.Publisher('~map_cloud', PointCloud2, queue_size=1) 41 | 42 | def pc_callback(self, pc_msg): 43 | assert isinstance(pc_msg, PointCloud2) 44 | now = rospy.Time.now() 45 | 46 | cloud_lid = numpify(pc_msg) 47 | # extract xyz coordinates as an unstructured array 48 | points = structured_to_unstructured(cloud_lid[['x', 'y', 'z']]) 49 | traversability = structured_to_unstructured(cloud_lid[['obstacle']]) 50 | rospy.logdebug('Traversability values: %s', np.unique(traversability)) 51 | 52 | self.pc_frame = pc_msg.header.frame_id 53 | try: 54 | trans = self.tf.lookup_transform(self.odom_frame, self.pc_frame, rospy.Time()) 55 | except (tf2_ros.LookupException, tf2_ros.ConnectivityException, tf2_ros.ExtrapolationException): 56 | rospy.logwarn('No transformation between %s and %s', self.pc_frame, self.odom_frame) 57 | return 58 | pose = numpify(trans.transform) 59 | cloud_map = np.matmul(points, pose[:3, :3].T) + pose[:3, 3:].T 60 | self.clouds.append(cloud_map) 61 | 62 | global_cloud = np.asarray(np.concatenate(self.clouds), dtype=np.float32) 63 | global_cloud = unstructured_to_structured(global_cloud, names=['x', 'y', 'z']) 64 | rospy.logdebug('Global map shape: %s', global_cloud.shape) 65 | 66 | map_msg = to_cloud_msg(global_cloud, time_stamp=now, frame=self.odom_frame) 67 | self.map_pc_pub.publish(map_msg) 68 | 69 | 70 | if __name__ == '__main__': 71 | rospy.init_node('pc_processor_node', log_level=rospy.DEBUG) 72 | proc = PointsProcessor(pc_topic='/points_filtered_kontron_traversability') 73 | rospy.spin() 74 | -------------------------------------------------------------------------------- /scripts/nodes/latch_sensor_info: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import, division, print_function 3 | from rosbag import Bag 4 | import rospy 5 | from cras_ouster_msgs.msg import SensorInfo 6 | 7 | 8 | class LatchSensorInfo(object): 9 | 10 | def __init__(self): 11 | self.bag = rospy.get_param('~bag', None) 12 | self.pub = rospy.Publisher('/os_node/sensor_info', SensorInfo, queue_size=1, latch=True) 13 | self.sub = rospy.Subscriber('/os_node/sensor_info', SensorInfo, self.callback) 14 | self.published = False 15 | if self.bag: 16 | self.publish_from_bag() 17 | 18 | def publish_from_bag(self): 19 | with Bag(self.bag, 'r') as bag: 20 | for _, msg, t in bag.read_messages(topics=['/os_node/sensor_info']): 21 | if rospy.is_shutdown(): 22 | break 23 | self.callback(msg) 24 | 25 | def callback(self, msg): 26 | # assert isinstance(msg, SensorInfo) 27 | if self.published: 28 | return 29 | self.published = True 30 | self.pub.publish(msg) 31 | rospy.loginfo('Sensor info published.') 32 | 33 | 34 | if __name__ == '__main__': 35 | rospy.init_node('latch_sensor_info') 36 | node = LatchSensorInfo() 37 | rospy.spin() 38 | -------------------------------------------------------------------------------- /scripts/nodes/play_tf_static: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import, division, print_function 3 | import rosbag 4 | import rospy 5 | from tf2_msgs.msg import TFMessage 6 | from threading import Lock 7 | 8 | 9 | class PlayTfStatic(object): 10 | 11 | def __init__(self): 12 | self.order_by = rospy.get_param('~order_by', 'capture') 13 | assert self.order_by in ('capture', 'header') 14 | rospy.loginfo('Publish latest static transforms according to %s time.', self.order_by) 15 |
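# The ~bag parameter below accepts either a single path or a list of paths;
# extra positional command-line arguments (rospy.myargv()[1:]) are appended
# as additional bag files.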
16 | self.bag_paths = rospy.get_param('~bag', []) 17 | if isinstance(self.bag_paths, str): 18 | self.bag_paths = [self.bag_paths] 19 | self.bag_paths += rospy.myargv()[1:] 20 | self.start_time = rospy.get_param('~start_time', None) 21 | self.end_time = rospy.get_param('~end_time', None) 22 | rospy.loginfo('Publish transforms from %s [%s, %s].', 23 | ', '.join(self.bag_paths), self.start_time or 'start', self.end_time or 'end') 24 | 25 | self.lock = Lock() 26 | self.frames = [] 27 | self.transforms = {} 28 | self.times = {} 29 | 30 | self.tf_static_pub = rospy.Publisher('/tf_static', TFMessage, queue_size=1, latch=True) 31 | 32 | # Remap if you want to collect published transforms to avoid delays on start. 33 | self.tf_sub = rospy.Subscriber('~tf_static', TFMessage, self.merge_transforms, queue_size=2) 34 | 35 | def publish_transforms(self): 36 | with self.lock: 37 | self.tf_static_pub.publish(TFMessage(self.transforms.values())) 38 | frames = sorted(self.transforms.keys()) 39 | if frames == self.frames: 40 | return 41 | self.frames = frames 42 | rospy.loginfo('Latching at the latest transforms for child frames %s.', ', '.join(frames)) 43 | 44 | def update_transform(self, tf, t): 45 | with self.lock: 46 | child = tf.child_frame_id 47 | if self.order_by == 'capture': 48 | time = t 49 | else: 50 | time = tf.header.stamp 51 | 52 | if child in self.transforms: 53 | if time <= self.times[child] or tf == self.transforms[child]: 54 | rospy.logdebug('Same or more recent transform for %s already published, do nothing.', child) 55 | return False 56 | else: 57 | rospy.loginfo('Child frame %s updated.', child) 58 | 59 | self.transforms[child] = tf 60 | self.times[child] = time 61 | return True 62 | 63 | def merge_transforms(self, msg, t=None): 64 | if t is None: 65 | t = rospy.Time.now() 66 | # Need to process all transforms: create complete list first. 
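# A list comprehension (rather than a generator) matters here: any() over a
# generator would short-circuit at the first True and leave the remaining
# transforms unprocessed.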
67 | updated = [self.update_transform(tf, t) for tf in msg.transforms] 68 | if any(updated): 69 | self.publish_transforms() 70 | 71 | def spin(self): 72 | for bag_path in self.bag_paths: 73 | if rospy.is_shutdown(): 74 | break 75 | with rosbag.Bag(bag_path) as bag: 76 | start_time = (None if self.start_time is None 77 | else rospy.Time.from_sec(bag.get_start_time() + self.start_time)) 78 | end_time = (None if self.end_time is None 79 | else rospy.Time.from_sec(bag.get_start_time() + self.end_time)) 80 | for _, msg, t in bag.read_messages(topics=['/tf_static'], start_time=start_time, end_time=end_time): 81 | if rospy.is_shutdown(): 82 | break 83 | self.merge_transforms(msg, t) 84 | rospy.loginfo('Bag files processed.') 85 | 86 | 87 | if __name__ == '__main__': 88 | rospy.init_node('play_tf_static') 89 | node = PlayTfStatic() 90 | node.spin() 91 | rospy.spin() 92 | -------------------------------------------------------------------------------- /scripts/nodes/stamp_twist: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import, division, print_function 3 | import rospy 4 | from geometry_msgs.msg import Twist, TwistStamped 5 | 6 | 7 | class StampTwist(object): 8 | 9 | def __init__(self): 10 | self.frame = rospy.get_param('~frame', 'base_link') 11 | self.pub = rospy.Publisher('twist_stamped', TwistStamped, queue_size=1) 12 | self.sub = rospy.Subscriber('twist', Twist, self.callback) 13 | 14 | def callback(self, msg): 15 | assert isinstance(msg, Twist) 16 | stamped_msg = TwistStamped() 17 | stamped_msg.header.stamp = rospy.Time.now() 18 | stamped_msg.header.frame_id = self.frame 19 | stamped_msg.twist = msg 20 | self.pub.publish(stamped_msg) 21 | 22 | 23 | if __name__ == '__main__': 24 | rospy.init_node('stamp_twist') 25 | node = StampTwist() 26 | rospy.spin() 27 | -------------------------------------------------------------------------------- /scripts/tools/legacy_weights: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Convert model weights to legacy format supported by older PyTorch versions.""" 3 | from __future__ import absolute_import, division, print_function 4 | from argparse import ArgumentParser 5 | import torch 6 | 7 | 8 | def main(): 9 | parser = ArgumentParser() 10 | parser.add_argument('input', type=str) 11 | parser.add_argument('output', type=str) 12 | args = parser.parse_args() 13 | obj = torch.load(args.input, map_location='cpu') 14 | torch.save(obj, args.output, _use_new_zipfile_serialization=False) 15 | print('Model %s converted to %s in legacy format.' 
% (args.input, args.output)) 16 | 17 | 18 | if __name__ == '__main__': 19 | main() 20 | -------------------------------------------------------------------------------- /scripts/tools/test_depth: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy as np 4 | import torch 5 | from argparse import ArgumentParser 6 | import datasets 7 | import os 8 | from traversability_estimation.utils import visualize_imgs, visualize_cloud 9 | import open3d as o3d 10 | 11 | 12 | def main(): 13 | parser = ArgumentParser() 14 | # parser.add_argument('--dataset', type=str, default='Rellis3DClouds') 15 | # parser.add_argument('--dataset', type=str, default='TraversabilityClouds') 16 | parser.add_argument('--dataset', type=str, default='FlexibilityClouds') 17 | # parser.add_argument('--dataset', type=str, default='SemanticUSL') 18 | parser.add_argument('--weights', type=str, default='deeplabv3_resnet101_lr_0.0001_bs_6_epoch_80_FlexibilityClouds_depth_64x1024_labels_flexibility_iou_0.790.pth') 19 | parser.add_argument('--device', type=str, default='cpu') 20 | args = parser.parse_args() 21 | print(args) 22 | 23 | pkg_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../../')) 24 | 25 | # Initialize model with the best available weights 26 | model_name = args.weights 27 | assert args.dataset in model_name 28 | model = torch.load(os.path.join(pkg_path, 'config/weights/depth_cloud', model_name), map_location=args.device) 29 | # model = torch.load(model_name, map_location=args.device) 30 | model.eval() 31 | 32 | data_fields = [f[1:-1] for f in ['_x_', '_y_', '_z_', '_intensity_', '_depth_'] if f in model_name] 33 | print('Model takes as input: %s' % ','.join(data_fields)) 34 | 35 | if 'traversability' in model_name.lower(): 36 | output = 'traversability' 37 | ignore_label = 255 38 | elif 'flexibility' in model_name.lower(): 39 | output = 'flexibility' 40 | ignore_label = 255 41 | else: 42 | output = None 43 | ignore_label = 0 44 | 45 | Dataset = eval('datasets.%s' % args.dataset) 46 | ds = Dataset(split='test', fields=data_fields, 47 | output=output, 48 | lidar_H_step=2, lidar_W_step=1) 49 | 50 | for _ in range(5): 51 | # Apply inference preprocessing transforms 52 | inpt, label = ds[np.random.choice(range(len(ds)))] 53 | 54 | depth_img = inpt[0] 55 | power = 16 56 | depth_img_vis = np.copy(depth_img).squeeze() # depth 57 | depth_img_vis[depth_img_vis > 0] = depth_img_vis[depth_img_vis > 0] ** (1 / power) 58 | depth_img_vis[depth_img_vis > 0] = (depth_img_vis[depth_img_vis > 0] - depth_img_vis[depth_img_vis > 0].min()) / \ 59 | (depth_img_vis[depth_img_vis > 0].max() - depth_img_vis[ 60 | depth_img_vis > 0].min()) 61 | 62 | # Use the model and visualize the prediction 63 | batch = torch.from_numpy(inpt).unsqueeze(0).to(args.device) 64 | with torch.no_grad(): 65 | pred = model(batch)['out'] 66 | pred = torch.softmax(pred.squeeze(0), dim=0).cpu().numpy() 67 | pred = np.argmax(pred, axis=0) 68 | pred_ign = pred.copy() 69 | pred_ign[label == ignore_label] = ignore_label 70 | 71 | # label_flex = pred == 1 72 | # depth_img_with_flex_points = (0.3 * depth_img_vis + 0.7 * label_flex).astype("float") 73 | 74 | color_pred = ds.label_to_color(pred) 75 | color_pred_ign = ds.label_to_color(pred_ign) 76 | color_gt = ds.label_to_color(label) 77 | 78 | visualize_imgs(layout='columns', 79 | depth_img=depth_img_vis, 80 | # depth_img_with_flex_points=depth_img_with_flex_points, 81 | prediction=color_pred, 82 | 
prediction_without_background=color_pred_ign, 83 | ground_truth=color_gt, 84 | ) 85 | 86 | # visualize_cloud(xyz=ds.scan.proj_xyz.reshape((-1, 3)), color=color_pred.reshape((-1, 3))) 87 | # visualize_cloud(xyz=ds.scan.proj_xyz.reshape((-1, 3)), color=color_gt.reshape((-1, 3))) 88 | 89 | pcd = o3d.geometry.PointCloud() 90 | xyz = ds.scan.proj_xyz[::ds.lidar_H_step] 91 | pcd.points = o3d.utility.Vector3dVector(xyz.reshape((-1, 3))) 92 | pcd.colors = o3d.utility.Vector3dVector(color_pred.reshape((-1, 3)) / color_pred.max()) 93 | 94 | pcd_gt = o3d.geometry.PointCloud() 95 | pcd_gt.points = o3d.utility.Vector3dVector(xyz.reshape((-1, 3)) + np.asarray([50, 0, 0])) 96 | pcd_gt.colors = o3d.utility.Vector3dVector(color_gt.reshape((-1, 3)) / color_gt.max()) 97 | 98 | o3d.visualization.draw_geometries([pcd, pcd_gt]) 99 | 100 | 101 | if __name__ == '__main__': 102 | main() 103 | -------------------------------------------------------------------------------- /scripts/tools/test_img: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import absolute_import 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from argparse import ArgumentParser 8 | from matplotlib import pyplot as plt 9 | import datasets 10 | from traversability_estimation.utils import convert_label, convert_color 11 | import yaml 12 | import os 13 | 14 | 15 | def main(): 16 | parser = ArgumentParser() 17 | # parser.add_argument('--dataset', type=str, default='TraversabilityImagesFiftyone') 18 | parser.add_argument('--dataset', type=str, default='Rellis3DImages') 19 | parser.add_argument('--img_size', nargs='+', default=(192, 320)) 20 | parser.add_argument('--device', type=str, default='cpu') 21 | args = parser.parse_args() 22 | print(args) 23 | 24 | Dataset = eval('datasets.%s' % args.dataset) 25 | ds = Dataset(crop_size=args.img_size, split='test') 26 | 27 | # Initialize model with the best available weights 28 | model_name = 'fcn_resnet50_lr_1e-05_bs_2_epoch_44_TraversabilityImages_iou_0.86.pth' 29 | # model_name = 'fcn_resnet50_lr_1e-05_bs_1_epoch_3_TraversabilityImages_iou_0.71.pth' 30 | model_path = os.path.join('../../config/weights/image/', model_name) 31 | model = torch.load(model_path, map_location=args.device).eval() 32 | 33 | pkg_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../../')) 34 | label_config = os.path.join(pkg_path, "config/rellis.yaml") 35 | data_cfg = yaml.safe_load(open(label_config, 'r')) 36 | 37 | for i in range(5): 38 | # Apply inference preprocessing transforms 39 | img, gt_mask = ds[i] 40 | img_vis = np.uint8(255 * (img * ds.std + ds.mean)) 41 | if ds.split == 'test': 42 | img = img.transpose((2, 0, 1)) # (H x W x C) -> (C x H x W) 43 | batch = torch.from_numpy(img).unsqueeze(0).to(args.device) 44 | 45 | # Use the model and visualize the prediction 46 | with torch.no_grad(): 47 | pred = model(batch)['out'] 48 | pred = torch.softmax(pred, dim=1) 49 | pred = pred.squeeze(0).cpu().numpy() 50 | mask = np.argmax(pred, axis=0) 51 | gt_mask = np.argmax(gt_mask, axis=0) 52 | # mask = convert_label(mask, inverse=True) 53 | size = (args.img_size[1], args.img_size[0]) 54 | mask = cv2.resize(mask.astype('float32'), size, interpolation=cv2.INTER_LINEAR).astype('int8') 55 | 56 | # result = convert_color(mask, data_cfg['color_map']) 57 | result = convert_color(mask, {0: [0, 0, 0], 1: [0, 255, 0], 2: [255, 0, 0]}) 58 | gt_result = convert_color(gt_mask, {0: [0, 0, 0], 1: [0, 255, 0], 2: [255, 0, 0]}) 59 | 
plt.figure(figsize=(20, 10)) 60 | plt.subplot(1, 3, 1) 61 | plt.imshow(img_vis) 62 | plt.subplot(1, 3, 2) 63 | plt.imshow(result) 64 | plt.subplot(1, 3, 3) 65 | plt.imshow(gt_result) 66 | plt.show() 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /scripts/tools/train_smp: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import datasets 5 | from torch.utils.data import DataLoader 6 | import torch 7 | import segmentation_models_pytorch as smp 8 | from argparse import ArgumentParser 9 | 10 | 11 | def main(): 12 | parser = ArgumentParser() 13 | parser.add_argument('--lr', type=float, default=1e-4) 14 | parser.add_argument('--dataset', type=str, default='Rellis3DImages') 15 | parser.add_argument('--model', type=str, default='Unet') 16 | parser.add_argument('--encoder', type=str, default='resnet34') 17 | parser.add_argument('--encoder_weights', type=str, default='imagenet') 18 | parser.add_argument('--batch_size', type=int, default=1) 19 | parser.add_argument('--img_size', nargs='+', type=int, default=(1184, 1920)) 20 | parser.add_argument('--n_epochs', type=int, default=100) 21 | parser.add_argument('--device', type=str, default='cuda') 22 | parser.add_argument('--n_workers', type=int, default=os.cpu_count()) 23 | parser.add_argument('--num_samples', type=int, default=None) 24 | args = parser.parse_args() 25 | 26 | Dataset = eval('datasets.%s' % args.dataset) 27 | train_dataset = Dataset(crop_size=args.img_size, split='train', num_samples=args.num_samples) 28 | valid_dataset = Dataset(crop_size=(1184, 1920), split='val', num_samples=args.num_samples) 29 | 30 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.n_workers) 31 | valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=args.n_workers) 32 | 33 | # create segmentation model with pretrained encoder 34 | architecture = eval('smp.%s' % args.model) 35 | model = architecture( 36 | encoder_name=args.encoder, 37 | encoder_weights=args.encoder_weights, 38 | in_channels=3, 39 | classes=len(train_dataset.CLASSES), 40 | activation='sigmoid' if len(train_dataset.CLASSES) == 1 else 'softmax2d', 41 | ) 42 | model = model.train() 43 | 44 | # Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient 45 | loss_fn = smp.utils.losses.DiceLoss(activation='softmax2d') 46 | 47 | # IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index 48 | metrics = [smp.utils.metrics.IoU(threshold=0.5)] 49 | 50 | optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=args.lr)]) 51 | 52 | # create epoch runners 53 | # it is a simple loop of iterating over dataloader`s samples 54 | train_epoch = smp.utils.train.TrainEpoch( 55 | model, 56 | loss=loss_fn, 57 | metrics=metrics, 58 | optimizer=optimizer, 59 | device=args.device, 60 | verbose=True, 61 | ) 62 | 63 | valid_epoch = smp.utils.train.ValidEpoch( 64 | model, 65 | loss=loss_fn, 66 | metrics=metrics, 67 | device=args.device, 68 | verbose=True, 69 | ) 70 | 71 | # train model 72 | max_score = 0 73 | for i in range(0, args.n_epochs): 74 | print('\nEpoch: {}'.format(i)) 75 | train_logs = train_epoch.run(train_loader) 76 | valid_logs = valid_epoch.run(valid_loader) 77 | 78 | # do something (save model, change lr, etc.) 
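# Keep only the checkpoint with the best validation IoU seen so far; the
# file name below encodes model, encoder, crop size and training
# hyper-parameters.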
79 | if max_score < valid_logs['iou_score']: 80 | max_score = valid_logs['iou_score'] 81 | best_model_name = './%s_%s_%dx%d_lr%g_bs%d_epoch%d_%s_iou_%.2f.pth' %\ 82 | (args.model, args.encoder, args.img_size[0], args.img_size[1], 83 | args.lr, args.batch_size, i, args.dataset, max_score) 84 | torch.save(model, best_model_name) 85 | print('Model %s saved!' % best_model_name) 86 | 87 | if i == 25: 88 | optimizer.param_groups[0]['lr'] = args.lr / 10.0 89 | print('Decrease decoder learning rate!') 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # ! DO NOT MANUALLY INVOKE THIS setup.py, USE CATKIN INSTEAD 2 | from setuptools import setup 3 | from catkin_pkg.python_setup import generate_distutils_setup 4 | 5 | # fetch values from package.xml 6 | setup_args = generate_distutils_setup( 7 | packages=['traversability_estimation', 'hrnet', 'datasets', 'tconcord3d'], 8 | package_dir={'': 'src'}) 9 | 10 | setup(**setup_args) 11 | -------------------------------------------------------------------------------- /singularity/.gitignore: -------------------------------------------------------------------------------- 1 | *.sif 2 | *.simg 3 | -------------------------------------------------------------------------------- /singularity/build.sh: -------------------------------------------------------------------------------- 1 | # Building singularity image from the def file 2 | sudo singularity build --nv traversability_estimation.sif recepie.def 3 | -------------------------------------------------------------------------------- /singularity/recepie.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: ros:noetic-perception 3 | 4 | %files 5 | requirements.txt 6 | 7 | %post 8 | export XDG_CACHE_HOME=/tmp/singularity-cache # pip cache 9 | 10 | # Install Apt packages 11 | packages=" 12 | gcc 13 | g++ 14 | bridge-utils 15 | build-essential 16 | htop 17 | net-tools 18 | screen 19 | sshpass 20 | tmux 21 | vim 22 | wget 23 | curl 24 | git 25 | python3-pip 26 | python3-catkin-tools 27 | ros-noetic-ros-numpy 28 | ros-noetic-jsk-rviz-plugins 29 | ros-noetic-rviz" 30 | 31 | apt update 32 | apt install -y ${packages} 33 | 34 | # Pytorch 35 | # pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html 36 | pip install torch torchvision torchaudio torchmetrics tensorboard --extra-index-url https://download.pytorch.org/whl/cu113 37 | # Install python packages 38 | pip install -r ${SINGULARITY_ROOTFS}/requirements.txt 39 | 40 | ln -s /usr/bin/python3 /usr/bin/python 41 | -------------------------------------------------------------------------------- /singularity/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.62.3 2 | pillow==8.4.0 3 | matplotlib==3.2.1 4 | pyyaml==6.0 5 | pathlib==1.0.1 6 | yacs==0.1.6 7 | scipy==1.6.3 8 | setuptools==58.0.4 9 | rospkg==1.3.0 10 | python-dateutil==2.8.2 11 | empy==3.3.4 12 | gnupg==2.3.1 13 | configparser==5.0.2 14 | psutil==5.8.0 15 | defusedxml==0.7.1 16 | scikit-image==0.18.1 17 | sklearn==0.0 18 | scikit-learn==1.0 19 | six==1.15.0 20 | segmentation_models_pytorch==0.2.1 21 | albumentations==1.2.1 22 | opencv-python==4.2.0.32 23 | open3d==0.10.0.0 24 | #fiftyone==0.16.5 25 | 
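A side note on dataset dispatch: train_smp above (like test_img and test_depth) resolves the dataset class with eval('datasets.%s' % args.dataset). A minimal sketch of the equivalent, eval-free lookup via getattr on the datasets package defined below; the class name and constructor arguments mirror smp_demo.py and are illustrative:

import datasets

# Same dispatch as eval('datasets.Rellis3DImages'), without evaluating an
# arbitrary string from the command line.
Dataset = getattr(datasets, 'Rellis3DImages')
ds = Dataset(split='test', crop_size=(704, 960))
print('Test split size:', len(ds))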
-------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | from .rellis_3d import Rellis3DImages, Rellis3DClouds 4 | from .semantic import SemanticUSL, SemanticKITTI 5 | from .cwt import CWT 6 | from .traversability_dataset import TraversabilityImages51, FlexibilityClouds, TraversabilityClouds 7 | -------------------------------------------------------------------------------- /src/datasets/cwt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import dirname, join, realpath 3 | from traversability_estimation.utils import * 4 | from .base_dataset import BaseDatasetImages 5 | import numpy as np 6 | 7 | __all__ = [ 8 | 'data_dir', 9 | 'CWT', 10 | ] 11 | 12 | data_dir = realpath(join(dirname(__file__), '..', '..', 'data')) 13 | 14 | 15 | class CWT(BaseDatasetImages): 16 | CLASSES = ["flat", "bumpy", "water", "rock", "mixed", "excavator", "obstacle"] 17 | PALETTE = [[0, 255, 0], [255, 255, 0], [255, 0, 0], [128, 0, 0], [100, 65, 0], [0, 255, 255], [0, 0, 255]] 18 | 19 | def __init__(self, 20 | path=None, 21 | split='train', 22 | num_samples=None, 23 | classes=None, 24 | multi_scale=True, # TODO: fix padding, background must be black for masks (0) 25 | flip=True, 26 | ignore_label=-1, 27 | base_size=2048, 28 | crop_size=(1200, 1920), 29 | downsample_rate=1, 30 | scale_factor=16, 31 | mean=np.asarray([0.0, 0.0, 0.0]), 32 | std=np.asarray([1.0, 1.0, 1.0])): 33 | super(CWT, self).__init__(ignore_label, base_size, crop_size, downsample_rate, scale_factor, mean, std, ) 34 | 35 | if path is None: 36 | path = join(data_dir, 'CWT') 37 | assert os.path.exists(path) 38 | assert split in ['train', 'test', 'val'] 39 | # validation dataset is called 'test' 40 | if split == 'val': 41 | split = 'test' 42 | self.path = path 43 | self.split = split 44 | if not classes: 45 | classes = self.CLASSES 46 | # convert str names to class values on masks 47 | self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes] 48 | self.palette_values = [self.PALETTE[c] for c in self.class_values] 49 | self.color_map = {} 50 | for k, v in zip(self.class_values, self.palette_values): 51 | self.color_map[k] = v 52 | 53 | self.base_size = base_size 54 | self.crop_size = crop_size 55 | self.ignore_label = ignore_label 56 | 57 | self.mean = mean 58 | self.std = std 59 | self.scale_factor = scale_factor 60 | self.downsample_rate = 1. 
/ downsample_rate 61 | 62 | self.multi_scale = multi_scale 63 | self.flip = flip 64 | 65 | self.fps = np.genfromtxt(os.path.join(path, '%s.txt' % split), dtype=str).tolist() 66 | 67 | self.files = self.read_files() 68 | if num_samples: 69 | self.files = self.files[:num_samples] 70 | 71 | def read_files(self): 72 | files = [] 73 | for path in self.fps: 74 | name = path.split('/')[1] 75 | files.append({ 76 | "img": os.path.join(self.path, 'img', '%s.jpg' % path), 77 | "label": os.path.join(self.path, 'annotation/grey_mask', '%s.png' % path), 78 | "name": name, 79 | }) 80 | return files 81 | 82 | def __getitem__(self, index): 83 | item = self.files[index] 84 | image = cv2.imread(item["img"], cv2.IMREAD_COLOR) 85 | 86 | mask = np.array(cv2.imread(item["label"], 0)) 87 | 88 | # add augmentations 89 | image, mask = self.apply_augmentations(image, mask, self.multi_scale, self.flip) 90 | 91 | # extract certain classes from mask 92 | masks = [(mask == v) for v in self.class_values] 93 | mask = np.stack(masks, axis=0).astype('float') 94 | 95 | return image.copy(), mask.copy() 96 | 97 | 98 | def demo(): 99 | # split = np.random.choice(['test', 'train', 'val']) 100 | split = 'train' 101 | ds = CWT(split=split) 102 | 103 | for _ in range(5): 104 | image, gt_mask = ds[int(np.random.choice(range(len(ds))))] 105 | image = image.transpose([1, 2, 0]) 106 | image_vis = np.uint8(255 * (image * ds.std + ds.mean)) 107 | 108 | gt_arg = np.argmax(gt_mask, axis=0).astype(np.uint8) 109 | gt_color = convert_color(gt_arg, ds.color_map) 110 | 111 | visualize_imgs( 112 | image=image_vis, 113 | label=gt_color, 114 | ) 115 | 116 | 117 | def main(): 118 | demo() 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /src/datasets/traversability_cloud.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import open3d as o3d 4 | from datasets.base_dataset import data_dir 5 | from segments import SegmentsClient, SegmentsDataset 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | class TraversabilityCloud(object): 10 | def __init__(self, path: str, version: str = "v1.1", split: str = None): 11 | self.path = path 12 | self.split = split 13 | 14 | self.version = version 15 | self.api_key = '4bacc032570420552ef6b038e1a1e8383ac372d9' 16 | self.dataset_name = 'aleskucera/Pointcloud_traversability' 17 | 18 | self.label_map = {0: 0, 19 | 1: 1, 20 | 2: 255} 21 | 22 | self.color_map = {0: [0, 0, 0], 23 | 1: [0, 255, 0], 24 | 255: [255, 0, 0]} 25 | 26 | self.point_clouds, self.labels = self._init_dataset() 27 | 28 | def _init_dataset(self) -> (list, list): 29 | point_clouds = [] 30 | labels = [] 31 | # load and format dataset annotations 32 | client = SegmentsClient(self.api_key) 33 | release = client.get_release(self.dataset_name, self.version) 34 | dataset = SegmentsDataset(release, labelset='ground-truth', filter_by=['REVIEWED']) 35 | samples = dataset.samples 36 | for sample in samples: 37 | # get attributes of the label 38 | attributes = sample["labels"]["ground-truth"]["attributes"] 39 | point_annotations = attributes["point_annotations"] 40 | annotations = attributes["annotations"] 41 | 42 | # append sample to dataset 43 | point_clouds.append(self._get_path(sample["name"])) 44 | labels.append(self._map_annotations(point_annotations, annotations)) 45 | return self._generate_split(point_clouds, labels) 46 | 47 | def _map_annotations(self, 
point_annotations: list, annotations: list) -> np.ndarray: 48 | ret = [] 49 | 50 | # map annotations by self.label_map 51 | mapped_annotations = {0: 0} 52 | for annotation in annotations: 53 | mapped_annotations[annotation["id"]] = self.label_map[annotation["category_id"]] 54 | 55 | for instance_id in point_annotations: 56 | category_id = mapped_annotations[instance_id] 57 | ret.append(category_id) 58 | return np.array(ret) 59 | 60 | def _get_path(self, name: str) -> str: 61 | return os.path.join(self.path, name) 62 | 63 | def _generate_split(self, X: list, y: list, test_ratio=0.2) -> (list, list): 64 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio, random_state=42) 65 | if self.split == 'train': 66 | X = X_train 67 | y = y_train 68 | elif self.split in ['val', 'test']: 69 | X = X_test 70 | y = y_test 71 | return X, y 72 | 73 | def __getitem__(self, index: int) -> (np.ndarray, np.ndarray): 74 | point_cloud = o3d.io.read_point_cloud(self.point_clouds[index]) 75 | point_cloud = np.asarray(point_cloud.points).reshape((128, -1, 3)) 76 | label = self.labels[index].reshape((128, -1)) 77 | return point_cloud, label 78 | 79 | def __len__(self) -> int: 80 | return len(self.point_clouds) 81 | 82 | def visualize_sample(self, index: int) -> None: 83 | point_cloud = o3d.io.read_point_cloud(self.point_clouds[index]) 84 | 85 | colors = np.array([self.color_map[label] for label in self.labels[index]]) 86 | point_cloud.colors = o3d.utility.Vector3dVector(colors) 87 | o3d.visualization.draw_geometries([point_cloud]) 88 | 89 | 90 | def main(): 91 | # directory = "/home/ales/Datasets/points_colored" 92 | directory = os.path.join(data_dir, "TraversabilityDataset/supervised/clouds/destaggered_points_colored/") 93 | dataset = TraversabilityCloud(directory) 94 | print(f"INFO: Initialized dataset split type: {dataset.split}") 95 | print(f"INFO: Split contains {len(dataset)} samples.") 96 | for i, sample in enumerate(dataset): 97 | point_cloud, label = sample 98 | print(f"INFO: Sample {i} has shape {point_cloud.shape} and label {label.shape}") 99 | dataset.visualize_sample(i) 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /src/hrnet/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | -------------------------------------------------------------------------------- /src/hrnet/config/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | from .default import _C as config 11 | from .default import update_config 12 | from .models import MODEL_EXTRAS 13 | -------------------------------------------------------------------------------- /src/hrnet/config/default.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | 13 | from yacs.config import CfgNode as CN 14 | 15 | _C = CN() 16 | 17 | _C.OUTPUT_DIR = '' 18 | _C.LOG_DIR = '' 19 | _C.GPUS = (0,) 20 | _C.WORKERS = 4 21 | _C.PRINT_FREQ = 20 22 | _C.AUTO_RESUME = False 23 | _C.PIN_MEMORY = True 24 | _C.RANK = 0 25 | 26 | # Cudnn related params 27 | _C.CUDNN = CN() 28 | _C.CUDNN.BENCHMARK = True 29 | _C.CUDNN.DETERMINISTIC = False 30 | _C.CUDNN.ENABLED = True 31 | 32 | # common params for NETWORK 33 | _C.MODEL = CN() 34 | _C.MODEL.NAME = 'seg_hrnet' 35 | pkg_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../../..')) 36 | _C.MODEL.PRETRAINED = "%s/models/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484/best.pth" % pkg_path 37 | _C.MODEL.ALIGN_CORNERS = True 38 | _C.MODEL.NUM_OUTPUTS = 1 39 | _C.MODEL.EXTRA = CN(new_allowed=True) 40 | 41 | _C.MODEL.OCR = CN() 42 | _C.MODEL.OCR.MID_CHANNELS = 512 43 | _C.MODEL.OCR.KEY_CHANNELS = 256 44 | _C.MODEL.OCR.DROPOUT = 0.05 45 | _C.MODEL.OCR.SCALE = 1 46 | 47 | _C.LOSS = CN() 48 | _C.LOSS.USE_OHEM = False 49 | _C.LOSS.OHEMTHRES = 0.9 50 | _C.LOSS.OHEMKEEP = 100000 51 | _C.LOSS.CLASS_BALANCE = False 52 | _C.LOSS.BALANCE_WEIGHTS = [1] 53 | 54 | # DATASET related params 55 | _C.DATASET = CN() 56 | _C.DATASET.ROOT = '' 57 | _C.DATASET.DATASET = 'cityscapes' 58 | _C.DATASET.NUM_CLASSES = 19 59 | _C.DATASET.TRAIN_SET = 'list/cityscapes/train.lst' 60 | _C.DATASET.EXTRA_TRAIN_SET = '' 61 | _C.DATASET.TEST_SET = 'list/cityscapes/val.lst' 62 | 63 | # training 64 | _C.TRAIN = CN() 65 | 66 | _C.TRAIN.FREEZE_LAYERS = '' 67 | _C.TRAIN.FREEZE_EPOCHS = -1 68 | _C.TRAIN.NONBACKBONE_KEYWORDS = [] 69 | _C.TRAIN.NONBACKBONE_MULT = 10 70 | 71 | _C.TRAIN.IMAGE_SIZE = [1024, 512] # width * height 72 | _C.TRAIN.BASE_SIZE = 2048 73 | _C.TRAIN.DOWNSAMPLERATE = 1 74 | _C.TRAIN.FLIP = True 75 | _C.TRAIN.MULTI_SCALE = True 76 | _C.TRAIN.SCALE_FACTOR = 16 77 | 78 | _C.TRAIN.RANDOM_BRIGHTNESS = False 79 | _C.TRAIN.RANDOM_BRIGHTNESS_SHIFT_VALUE = 10 80 | 81 | _C.TRAIN.LR_FACTOR = 0.1 82 | _C.TRAIN.LR_STEP = [90, 110] 83 | _C.TRAIN.LR = 0.01 84 | _C.TRAIN.EXTRA_LR = 0.001 85 | 86 | _C.TRAIN.OPTIMIZER = 'sgd' 87 | _C.TRAIN.MOMENTUM = 0.9 88 | _C.TRAIN.WD = 0.0001 89 | _C.TRAIN.NESTEROV = False 90 | _C.TRAIN.IGNORE_LABEL = -1 91 | 92 | _C.TRAIN.BEGIN_EPOCH = 0 93 | _C.TRAIN.END_EPOCH = 484 94 | _C.TRAIN.EXTRA_EPOCH = 0 95 | 96 | _C.TRAIN.RESUME = False 97 | 98 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32 99 | _C.TRAIN.SHUFFLE = True 100 | # only using some training samples 101 | _C.TRAIN.NUM_SAMPLES = 0 102 | 103 | # testing 104 | _C.TEST = CN() 105 | 106 | _C.TEST.IMAGE_SIZE = [2048, 1024] # width * height 107 | _C.TEST.BASE_SIZE = 2048 108 | 109 | _C.TEST.BATCH_SIZE_PER_GPU = 32 110 | # only testing some samples 111 | _C.TEST.NUM_SAMPLES = 0 112 | 113 | _C.TEST.MODEL_FILE = '' 114 | _C.TEST.FLIP_TEST = False 115 | _C.TEST.MULTI_SCALE = False 116 | _C.TEST.SCALE_LIST = [1] 117 | 118 | _C.TEST.OUTPUT_INDEX = -1 119 | 120 | # debug 121 | _C.DEBUG = CN() 122 | _C.DEBUG.DEBUG = False 123 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False 124 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False 125 | _C.DEBUG.SAVE_HEATMAPS_GT = False 126 | _C.DEBUG.SAVE_HEATMAPS_PRED = False 127 | 128 | 129 | def update_config(cfg, args): 130 | cfg.defrost() 131 | 132 | 
cfg.merge_from_file(args.cfg) 133 | cfg.merge_from_list(args.opts) 134 | 135 | cfg.freeze() 136 | 137 | 138 | if __name__ == '__main__': 139 | import sys 140 | 141 | with open(sys.argv[1], 'w') as f: 142 | print(_C, file=f) 143 | -------------------------------------------------------------------------------- /src/hrnet/config/hrnet_config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Create by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn), Rainbowsecret (yuyua@microsoft.com) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | from yacs.config import CfgNode as CN 13 | 14 | 15 | # configs for HRNet48 16 | HRNET_48 = CN() 17 | HRNET_48.FINAL_CONV_KERNEL = 1 18 | 19 | HRNET_48.STAGE1 = CN() 20 | HRNET_48.STAGE1.NUM_MODULES = 1 21 | HRNET_48.STAGE1.NUM_BRANCHES = 1 22 | HRNET_48.STAGE1.NUM_BLOCKS = [4] 23 | HRNET_48.STAGE1.NUM_CHANNELS = [64] 24 | HRNET_48.STAGE1.BLOCK = 'BOTTLENECK' 25 | HRNET_48.STAGE1.FUSE_METHOD = 'SUM' 26 | 27 | HRNET_48.STAGE2 = CN() 28 | HRNET_48.STAGE2.NUM_MODULES = 1 29 | HRNET_48.STAGE2.NUM_BRANCHES = 2 30 | HRNET_48.STAGE2.NUM_BLOCKS = [4, 4] 31 | HRNET_48.STAGE2.NUM_CHANNELS = [48, 96] 32 | HRNET_48.STAGE2.BLOCK = 'BASIC' 33 | HRNET_48.STAGE2.FUSE_METHOD = 'SUM' 34 | 35 | HRNET_48.STAGE3 = CN() 36 | HRNET_48.STAGE3.NUM_MODULES = 4 37 | HRNET_48.STAGE3.NUM_BRANCHES = 3 38 | HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4] 39 | HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192] 40 | HRNET_48.STAGE3.BLOCK = 'BASIC' 41 | HRNET_48.STAGE3.FUSE_METHOD = 'SUM' 42 | 43 | HRNET_48.STAGE4 = CN() 44 | HRNET_48.STAGE4.NUM_MODULES = 3 45 | HRNET_48.STAGE4.NUM_BRANCHES = 4 46 | HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 47 | HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384] 48 | HRNET_48.STAGE4.BLOCK = 'BASIC' 49 | HRNET_48.STAGE4.FUSE_METHOD = 'SUM' 50 | 51 | 52 | # configs for HRNet32 53 | HRNET_32 = CN() 54 | HRNET_32.FINAL_CONV_KERNEL = 1 55 | 56 | HRNET_32.STAGE1 = CN() 57 | HRNET_32.STAGE1.NUM_MODULES = 1 58 | HRNET_32.STAGE1.NUM_BRANCHES = 1 59 | HRNET_32.STAGE1.NUM_BLOCKS = [4] 60 | HRNET_32.STAGE1.NUM_CHANNELS = [64] 61 | HRNET_32.STAGE1.BLOCK = 'BOTTLENECK' 62 | HRNET_32.STAGE1.FUSE_METHOD = 'SUM' 63 | 64 | HRNET_32.STAGE2 = CN() 65 | HRNET_32.STAGE2.NUM_MODULES = 1 66 | HRNET_32.STAGE2.NUM_BRANCHES = 2 67 | HRNET_32.STAGE2.NUM_BLOCKS = [4, 4] 68 | HRNET_32.STAGE2.NUM_CHANNELS = [32, 64] 69 | HRNET_32.STAGE2.BLOCK = 'BASIC' 70 | HRNET_32.STAGE2.FUSE_METHOD = 'SUM' 71 | 72 | HRNET_32.STAGE3 = CN() 73 | HRNET_32.STAGE3.NUM_MODULES = 4 74 | HRNET_32.STAGE3.NUM_BRANCHES = 3 75 | HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4] 76 | HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128] 77 | HRNET_32.STAGE3.BLOCK = 'BASIC' 78 | HRNET_32.STAGE3.FUSE_METHOD = 'SUM' 79 | 80 | HRNET_32.STAGE4 = CN() 81 | HRNET_32.STAGE4.NUM_MODULES = 3 82 | HRNET_32.STAGE4.NUM_BRANCHES = 4 83 | HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 84 | HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 85 | HRNET_32.STAGE4.BLOCK = 'BASIC' 86 | HRNET_32.STAGE4.FUSE_METHOD = 'SUM' 87 | 88 | 89 | # configs for HRNet18 90 | HRNET_18 = CN() 91 | HRNET_18.FINAL_CONV_KERNEL = 1 92 | 93 | HRNET_18.STAGE1 = CN() 94 | HRNET_18.STAGE1.NUM_MODULES = 1 95 | 
HRNET_18.STAGE1.NUM_BRANCHES = 1 96 | HRNET_18.STAGE1.NUM_BLOCKS = [4] 97 | HRNET_18.STAGE1.NUM_CHANNELS = [64] 98 | HRNET_18.STAGE1.BLOCK = 'BOTTLENECK' 99 | HRNET_18.STAGE1.FUSE_METHOD = 'SUM' 100 | 101 | HRNET_18.STAGE2 = CN() 102 | HRNET_18.STAGE2.NUM_MODULES = 1 103 | HRNET_18.STAGE2.NUM_BRANCHES = 2 104 | HRNET_18.STAGE2.NUM_BLOCKS = [4, 4] 105 | HRNET_18.STAGE2.NUM_CHANNELS = [18, 36] 106 | HRNET_18.STAGE2.BLOCK = 'BASIC' 107 | HRNET_18.STAGE2.FUSE_METHOD = 'SUM' 108 | 109 | HRNET_18.STAGE3 = CN() 110 | HRNET_18.STAGE3.NUM_MODULES = 4 111 | HRNET_18.STAGE3.NUM_BRANCHES = 3 112 | HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4] 113 | HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72] 114 | HRNET_18.STAGE3.BLOCK = 'BASIC' 115 | HRNET_18.STAGE3.FUSE_METHOD = 'SUM' 116 | 117 | HRNET_18.STAGE4 = CN() 118 | HRNET_18.STAGE4.NUM_MODULES = 3 119 | HRNET_18.STAGE4.NUM_BRANCHES = 4 120 | HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 121 | HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144] 122 | HRNET_18.STAGE4.BLOCK = 'BASIC' 123 | HRNET_18.STAGE4.FUSE_METHOD = 'SUM' 124 | 125 | 126 | MODEL_CONFIGS = { 127 | 'hrnet18': HRNET_18, 128 | 'hrnet32': HRNET_32, 129 | 'hrnet48': HRNET_48, 130 | } -------------------------------------------------------------------------------- /src/hrnet/config/models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from yacs.config import CfgNode as CN 12 | 13 | # high_resoluton_net related params for segmentation 14 | HIGH_RESOLUTION_NET = CN() 15 | HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] 16 | HIGH_RESOLUTION_NET.STEM_INPLANES = 64 17 | HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 18 | HIGH_RESOLUTION_NET.WITH_HEAD = True 19 | 20 | HIGH_RESOLUTION_NET.STAGE2 = CN() 21 | HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 22 | HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 23 | HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] 24 | HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64] 25 | HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC' 26 | HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' 27 | 28 | HIGH_RESOLUTION_NET.STAGE3 = CN() 29 | HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 30 | HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 31 | HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] 32 | HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128] 33 | HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC' 34 | HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' 35 | 36 | HIGH_RESOLUTION_NET.STAGE4 = CN() 37 | HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 38 | HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 39 | HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 40 | HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 41 | HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC' 42 | HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' 43 | 44 | MODEL_EXTRAS = { 45 | 'seg_hrnet': HIGH_RESOLUTION_NET, 46 | } 47 | -------------------------------------------------------------------------------- /src/hrnet/core/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 
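For orientation, the HRNet width variants defined in hrnet_config.py above differ only in their per-stage NUM_CHANNELS, and MODEL_CONFIGS maps the variant names to the config nodes. A minimal sketch of reading one variant (values copied from the definitions above):

from hrnet.config.hrnet_config import MODEL_CONFIGS

hrnet48 = MODEL_CONFIGS['hrnet48']
print(hrnet48.STAGE4.NUM_CHANNELS)  # [48, 96, 192, 384]
print(hrnet48.STAGE4.BLOCK)         # 'BASIC'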
-------------------------------------------------------------------------------- /src/hrnet/core/criterion.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | import logging 11 | from hrnet.config import config 12 | 13 | 14 | class CrossEntropy(nn.Module): 15 | def __init__(self, ignore_label=-1, weight=None): 16 | super(CrossEntropy, self).__init__() 17 | self.ignore_label = ignore_label 18 | self.criterion = nn.CrossEntropyLoss( 19 | weight=weight, 20 | ignore_index=ignore_label 21 | ) 22 | 23 | def _forward(self, score, target): 24 | ph, pw = score.size(2), score.size(3) 25 | h, w = target.size(1), target.size(2) 26 | if ph != h or pw != w: 27 | score = F.interpolate(input=score, size=( 28 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) 29 | 30 | loss = self.criterion(score, target) 31 | 32 | return loss 33 | 34 | def forward(self, score, target): 35 | 36 | if config.MODEL.NUM_OUTPUTS == 1: 37 | score = [score] 38 | 39 | weights = config.LOSS.BALANCE_WEIGHTS 40 | assert len(weights) == len(score) 41 | 42 | return sum([w * self._forward(x, target) for (w, x) in zip(weights, score)]) 43 | 44 | 45 | class OhemCrossEntropy(nn.Module): 46 | def __init__(self, ignore_label=-1, thres=0.7, 47 | min_kept=100000, weight=None): 48 | super(OhemCrossEntropy, self).__init__() 49 | self.thresh = thres 50 | self.min_kept = max(1, min_kept) 51 | self.ignore_label = ignore_label 52 | self.criterion = nn.CrossEntropyLoss( 53 | weight=weight, 54 | ignore_index=ignore_label, 55 | reduction='none' 56 | ) 57 | 58 | def _ce_forward(self, score, target): 59 | ph, pw = score.size(2), score.size(3) 60 | h, w = target.size(1), target.size(2) 61 | if ph != h or pw != w: 62 | score = F.interpolate(input=score, size=( 63 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) 64 | 65 | loss = self.criterion(score, target) 66 | 67 | return loss 68 | 69 | def _ohem_forward(self, score, target, **kwargs): 70 | ph, pw = score.size(2), score.size(3) 71 | h, w = target.size(1), target.size(2) 72 | if ph != h or pw != w: 73 | score = F.interpolate(input=score, size=( 74 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) 75 | pred = F.softmax(score, dim=1) 76 | pixel_losses = self.criterion(score, target).contiguous().view(-1) 77 | mask = target.contiguous().view(-1) != self.ignore_label 78 | 79 | tmp_target = target.clone() 80 | tmp_target[tmp_target == self.ignore_label] = 0 81 | pred = pred.gather(1, tmp_target.unsqueeze(1)) 82 | pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort() 83 | min_value = pred[min(self.min_kept, pred.numel() - 1)] 84 | threshold = max(min_value, self.thresh) 85 | 86 | pixel_losses = pixel_losses[mask][ind] 87 | pixel_losses = pixel_losses[pred < threshold] 88 | return pixel_losses.mean() 89 | 90 | def forward(self, score, target): 91 | 92 | if config.MODEL.NUM_OUTPUTS == 1: 93 | score = [score] 94 | 95 | weights = config.LOSS.BALANCE_WEIGHTS 96 | assert len(weights) == len(score) 97 | 98 | functions = [self._ce_forward] * \ 99 | (len(weights) - 1) + [self._ohem_forward] 100 | return sum([ 101 | w * func(x, target) 102 | for (w, x, 
func) in zip(weights, score, functions) 103 | ]) 104 | -------------------------------------------------------------------------------- /src/hrnet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from .rellis import Rellis as rellis 12 | -------------------------------------------------------------------------------- /src/hrnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import hrnet.models.seg_hrnet 12 | import hrnet.models.seg_hrnet_ocr 13 | -------------------------------------------------------------------------------- /src/hrnet/models/bn_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | 4 | if torch.__version__.startswith('0'): 5 | from .sync_bn.inplace_abn.bn import InPlaceABNSync 6 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 7 | BatchNorm2d_class = InPlaceABNSync 8 | relu_inplace = False 9 | else: 10 | # BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm # Cannot be used on CPU. 11 | BatchNorm2d_class = BatchNorm2d = torch.nn.BatchNorm2d 12 | relu_inplace = True -------------------------------------------------------------------------------- /src/hrnet/models/sync_bn/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, mapillary 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | 
-------------------------------------------------------------------------------- /src/hrnet/models/sync_bn/__init__.py: --------------------------------------------------------------------------------
1 | from .inplace_abn import bn
-------------------------------------------------------------------------------- /src/hrnet/models/sync_bn/inplace_abn/__init__.py: --------------------------------------------------------------------------------
1 | from .bn import ABN, InPlaceABN, InPlaceABNSync
2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
-------------------------------------------------------------------------------- /src/hrnet/models/sync_bn/inplace_abn/src/common.h: --------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <ATen/ATen.h>
4 | 
5 | /*
6 |  * General settings
7 |  */
8 | const int WARP_SIZE = 32;
9 | const int MAX_BLOCK_SIZE = 512;
10 | 
11 | template <typename T>
12 | struct Pair {
13 |   T v1, v2;
14 |   __device__ Pair() {}
15 |   __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
16 |   __device__ Pair(T v) : v1(v), v2(v) {}
17 |   __device__ Pair(int v) : v1(v), v2(v) {}
18 |   __device__ Pair &operator+=(const Pair &a) {
19 |     v1 += a.v1;
20 |     v2 += a.v2;
21 |     return *this;
22 |   }
23 | };
24 | 
25 | /*
26 |  * Utility functions
27 |  */
28 | template <typename T>
29 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
30 |                                            unsigned int mask = 0xffffffff) {
31 | #if CUDART_VERSION >= 9000
32 |   return __shfl_xor_sync(mask, value, laneMask, width);
33 | #else
34 |   return __shfl_xor(value, laneMask, width);
35 | #endif
36 | }
37 | 
38 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
39 | 
40 | static int getNumThreads(int nElem) {
41 |   int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
42 |   for (int i = 0; i != 5; ++i) {
43 |     if (nElem <= threadSizes[i]) {
44 |       return threadSizes[i];
45 |     }
46 |   }
47 |   return MAX_BLOCK_SIZE;
48 | }
49 | 
50 | template <typename T>
51 | static __device__ __forceinline__ T warpSum(T val) {
52 | #if __CUDA_ARCH__ >= 300
53 |   for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
54 |     val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
55 |   }
56 | #else
57 |   __shared__ T values[MAX_BLOCK_SIZE];
58 |   values[threadIdx.x] = val;
59 |   __threadfence_block();
60 |   const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
61 |   for (int i = 1; i < WARP_SIZE; i++) {
62 |     val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
63 |   }
64 | #endif
65 |   return val;
66 | }
67 | 
68 | template <typename T>
69 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
70 |   value.v1 = warpSum(value.v1);
71 |   value.v2 = warpSum(value.v2);
72 |   return value;
73 | }
74 | 
75 | template <typename T, typename Op>
76 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
77 |   T sum = (T)0;
78 |   for (int batch = 0; batch < N; ++batch) {
79 |     for (int x = threadIdx.x; x < S; x += blockDim.x) {
80 |       sum += op(batch, plane, x);
81 |     }
82 |   }
83 | 
84 |   // sum over NumThreads within a warp
85 |   sum = warpSum(sum);
86 | 
87 |   // 'transpose', and reduce within warp again
88 |   __shared__ T shared[32];
89 |   __syncthreads();
90 |   if (threadIdx.x % WARP_SIZE == 0) {
91 |     shared[threadIdx.x / WARP_SIZE] = sum;
92 |   }
93 |   if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
94 |     // zero out the other entries in shared
95 |     shared[threadIdx.x] = (T)0;
96 |   }
97 |   __syncthreads();
98 |   if (threadIdx.x / WARP_SIZE == 0) {
99 |     sum = warpSum(shared[threadIdx.x]);
100 |     if (threadIdx.x == 0) {
101 |       shared[0] = sum;
102 |     }
103 |   }
104 |   __syncthreads();
105 | 
106 |   // Everyone picks it up, should be broadcast into the whole gradInput
107 |   return shared[0];
108 | }
-------------------------------------------------------------------------------- /src/hrnet/models/sync_bn/inplace_abn/src/inplace_abn.cpp: --------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <vector>
4 | 
5 | #include "inplace_abn.h"
6 | 
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 |   if (x.is_cuda()) {
9 |     return mean_var_cuda(x);
10 |   } else {
11 |     return mean_var_cpu(x);
12 |   }
13 | }
14 | 
15 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 |                    bool affine, float eps) {
17 |   if (x.is_cuda()) {
18 |     return forward_cuda(x, mean, var, weight, bias, affine, eps);
19 |   } else {
20 |     return forward_cpu(x, mean, var, weight, bias, affine, eps);
21 |   }
22 | }
23 | 
24 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
25 |                                  bool affine, float eps) {
26 |   if (z.is_cuda()) {
27 |     return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
28 |   } else {
29 |     return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
30 |   }
31 | }
32 | 
33 | std::vector<at::Tensor> backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
34 |                                  at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
35 |   if (z.is_cuda()) {
36 |     return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
37 |   } else {
38 |     return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
39 |   }
40 | }
41 | 
42 | void leaky_relu_forward(at::Tensor z, float slope) {
43 |   at::leaky_relu_(z, slope);
44 | }
45 | 
46 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
47 |   if (z.is_cuda()) {
48 |     return leaky_relu_backward_cuda(z, dz, slope);
49 |   } else {
50 |     return leaky_relu_backward_cpu(z, dz, slope);
51 |   }
52 | }
53 | 
54 | void elu_forward(at::Tensor z) {
55 |   at::elu_(z);
56 | }
57 | 
58 | void elu_backward(at::Tensor z, at::Tensor dz) {
59 |   if (z.is_cuda()) {
60 |     return elu_backward_cuda(z, dz);
61 |   } else {
62 |     return elu_backward_cpu(z, dz);
63 |   }
64 | }
65 | 
66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
67 |   m.def("mean_var", &mean_var, "Mean and variance computation");
68 |   m.def("forward", &forward, "In-place forward computation");
69 |   m.def("edz_eydz", &edz_eydz, "First part of backward computation");
70 |   m.def("backward", &backward, "Second part of backward computation");
71 |   m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
72 |   m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
73 |   m.def("elu_forward", &elu_forward, "Elu forward computation");
74 |   m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
75 | }
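The pybind11 module above is what the Python wrappers call into. A minimal sketch of compiling it with PyTorch's JIT extension builder (the module name and source paths are assumptions; the repo's bn.py may build or package it differently, and the .cu source requires a CUDA toolchain):

```python
import torch
from torch.utils.cpp_extension import load

_src = "src/hrnet/models/sync_bn/inplace_abn/src/"  # assumed path from the repo root
_backend = load(
    name="inplace_abn",  # becomes TORCH_EXTENSION_NAME inside PYBIND11_MODULE
    sources=[_src + "inplace_abn.cpp",
             _src + "inplace_abn_cpu.cpp",
             _src + "inplace_abn_cuda.cu"],
    verbose=True,
)

x = torch.randn(2, 8, 4, 4)
mean, var = _backend.mean_var(x)  # per-channel statistics computed in C++
```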
-------------------------------------------------------------------------------- /src/hrnet/models/sync_bn/inplace_abn/src/inplace_abn.h: --------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <ATen/ATen.h>
4 | 
5 | #include <vector>
6 | 
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 | 
10 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
11 |                        bool affine, float eps);
12 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
13 |                         bool affine, float eps);
14 | 
15 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
16 |                                      bool affine, float eps);
17 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
18 |                                       bool affine, float eps);
19 | 
20 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
21 |                                      at::Tensor edz, at::Tensor eydz, bool affine, float eps);
22 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
23 |                                       at::Tensor edz, at::Tensor eydz, bool affine, float eps);
24 | 
25 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
26 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
27 | 
28 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
29 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
-------------------------------------------------------------------------------- /src/hrnet/models/sync_bn/inplace_abn/src/inplace_abn_cpu.cpp: --------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | 
3 | #include <vector>
4 | 
5 | #include "inplace_abn.h"
6 | 
7 | at::Tensor reduce_sum(at::Tensor x) {
8 |   if (x.ndimension() == 2) {
9 |     return x.sum(0);
10 |   } else {
11 |     auto x_view = x.view({x.size(0), x.size(1), -1});
12 |     return x_view.sum(-1).sum(0);
13 |   }
14 | }
15 | 
16 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
17 |   if (x.ndimension() == 2) {
18 |     return v;
19 |   } else {
20 |     std::vector<int64_t> broadcast_size = {1, -1};
21 |     for (int64_t i = 2; i < x.ndimension(); ++i)
22 |       broadcast_size.push_back(1);
23 | 
24 |     return v.view(broadcast_size);
25 |   }
26 | }
27 | 
28 | int64_t count(at::Tensor x) {
29 |   int64_t count = x.size(0);
30 |   for (int64_t i = 2; i < x.ndimension(); ++i)
31 |     count *= x.size(i);
32 | 
33 |   return count;
34 | }
35 | 
36 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
37 |   if (affine) {
38 |     return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
39 |   } else {
40 |     return z;
41 |   }
42 | }
43 | 
44 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
45 |   auto num = count(x);
46 |   auto mean = reduce_sum(x) / num;
47 |   auto diff = x - broadcast_to(mean, x);
48 |   auto var = reduce_sum(diff.pow(2)) / num;
49 | 
50 |   return {mean, var};
51 | }
52 | 
53 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
54 |                        bool affine, float eps) {
55 |   auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
56 |   auto mul = at::rsqrt(var + eps) * gamma;
57 | 
58 |   x.sub_(broadcast_to(mean, x));
59 |   x.mul_(broadcast_to(mul, x));
60 |   if (affine) x.add_(broadcast_to(bias, x));
61 | 
62 |   return x;
63 | }
64 | 
65 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
66 |                                      bool affine, float eps) {
67 |   auto edz = reduce_sum(dz);
68 |   auto y = invert_affine(z, weight, bias, affine, eps);
69 |   auto eydz = reduce_sum(y * dz);
70 | 
71 |   return {edz, eydz};
72 | }
73 | 
74 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
75 |                                      at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
76 |   auto y = invert_affine(z, weight, bias, affine, eps);
77 |   auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
78 | 
79 |   auto num = count(z);
80 |   auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
81 | 
82 |   auto dweight = at::empty(z.type(), {0});
83 |   auto dbias = at::empty(z.type(), {0});
84 |   if (affine) {
85 |     dweight = eydz * at::sign(weight);
86 |     dbias = edz;
87 |   }
88 | 
89 |   return {dx, dweight, dbias};
90 | }
91 | 
92 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
93 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
94 |     int64_t count = z.numel();
95 |     auto *_z = z.data<scalar_t>();
96 |     auto *_dz = dz.data<scalar_t>();
97 | 
98 |     for (int64_t i = 0; i < count; ++i) {
99 |       if (_z[i] < 0) {
100 |         _z[i] *= 1 / slope;  // invert the forward leaky-relu to recover the input
101 |         _dz[i] *= slope;     // scale the incoming gradient accordingly
102 |       }
103 |     }
104 |   }));
105 | }
106 | 
107 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
108 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
109 |     int64_t count = z.numel();
110 |     auto *_z = z.data<scalar_t>();
111 |     auto *_dz = dz.data<scalar_t>();
112 | 
113 |     for (int64_t i = 0; i < count; ++i) {
114 |       if (_z[i] < 0) {
115 |         _dz[i] *= (_z[i] + 1.f);  // d(elu)/dx = exp(x) = z + 1, taken before inverting z
116 |         _z[i] = log1p(_z[i]);     // invert elu to recover the input
117 |       }
118 |     }
119 |   }));
120 | }
-------------------------------------------------------------------------------- /src/hrnet/utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/src/hrnet/utils/__init__.py
-------------------------------------------------------------------------------- /src/hrnet/utils/distributed.py: --------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Jingyi Xie (hsfzxjy@gmail.com)
5 | # ------------------------------------------------------------------------------
6 | 
7 | import torch
8 | import torch.distributed as torch_dist
9 | 
10 | def is_distributed():
11 |     return torch_dist.is_initialized()
12 | 
13 | def get_world_size():
14 |     if not torch_dist.is_initialized():
15 |         return 1
16 |     return torch_dist.get_world_size()
17 | 
18 | def get_rank():
19 |     if not torch_dist.is_initialized():
20 |         return 0
21 |     return torch_dist.get_rank()
-------------------------------------------------------------------------------- /src/tconcord3d/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ctu-vras/traversability_estimation/9e96f12a6769e8d90240e54cce47b4afd25a3229/src/tconcord3d/__init__.py
-------------------------------------------------------------------------------- /src/tconcord3d/builder/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Awet H. Gebrehiwot
3 | # --------------------------|
-------------------------------------------------------------------------------- /src/tconcord3d/builder/loss_builder.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Awet H. Gebrehiwot
3 | # --------------------------|
4 | import torch
5 | from tconcord3d.utils.lovasz_losses import lovasz_softmax, lovasz_softmax_lcw, cross_entropy_lcw
6 | from tconcord3d.utils.loss_func import FocalLoss
7 | 
8 | 
9 | def build(wce=True, lovasz=True, num_class=20, ignore_label=None, weights=None, ssl=False, fl=False):
10 |     # focal loss and semi-supervised learning
11 |     if ssl and fl:
12 |         if wce and lovasz:
13 |             return FocalLoss(weight=weights, ignore_index=ignore_label), lovasz_softmax_lcw
14 |         elif wce and not lovasz:
15 |             return FocalLoss(weight=weights, ignore_index=ignore_label)  # cross-entropy term only
16 |         elif not wce and lovasz:
17 |             return lovasz_softmax_lcw
18 | 
19 |     # only semi-supervised learning
20 |     if ssl:
21 |         if wce and lovasz:
22 |             return cross_entropy_lcw, lovasz_softmax_lcw
23 |         elif wce and not lovasz:
24 |             return cross_entropy_lcw
25 |         elif not wce and lovasz:
26 |             return lovasz_softmax_lcw
27 | 
28 |     # focal loss on GT (fully supervised)
29 |     if fl:
30 |         loss_funs = FocalLoss(weight=weights, ignore_index=ignore_label)
31 |     else:
32 |         loss_funs = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
33 | 
34 |     if wce and lovasz:
35 |         return loss_funs, lovasz_softmax
36 |     elif wce and not lovasz:
37 |         return loss_funs
38 |     elif not wce and lovasz:
39 |         return lovasz_softmax
40 |     else:
41 |         raise NotImplementedError
42 | 
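In the default fully supervised mode the builder returns a pair of criteria that are applied side by side during training. A small usage sketch (the class count and ignore label mirror the traversability configs; shapes are illustrative, and the conventional lovasz_softmax signature taking probabilities is assumed):

```python
import torch
from tconcord3d.builder import loss_builder

# Default mode (wce + lovasz, no SSL/focal): CrossEntropyLoss plus lovasz_softmax.
ce_loss, lovasz_fn = loss_builder.build(num_class=3, ignore_label=0)

logits = torch.randn(4, 3, 32, 32)         # batch of 4 range-image-like grids
labels = torch.randint(0, 3, (4, 32, 32))
loss = ce_loss(logits, labels) + lovasz_fn(torch.softmax(logits, dim=1), labels)
loss.backward()
```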
-------------------------------------------------------------------------------- /src/tconcord3d/builder/model_builder.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | from tconcord3d.model.cylinder_3d import get_model_class
4 | from tconcord3d.model.segment_3d import Asymm_3d_spconv
5 | from tconcord3d.model.cylinder_feature import cylinder_fea
6 | 
7 | 
8 | def build(model_config):
9 |     output_shape = model_config['output_shape']
10 |     num_class = model_config['num_class']
11 |     num_input_features = model_config['num_input_features']
12 |     use_norm = model_config['use_norm']
13 |     init_size = model_config['init_size']
14 |     fea_dim = model_config['fea_dim']
15 |     out_fea_dim = model_config['out_fea_dim']
16 | 
17 |     cylinder_3d_spconv_seg = Asymm_3d_spconv(
18 |         output_shape=output_shape,
19 |         use_norm=use_norm,
20 |         num_input_features=num_input_features,
21 |         init_size=init_size,
22 |         nclasses=num_class)
23 | 
24 |     cy_fea_net = cylinder_fea(grid_size=output_shape,
25 |                               fea_dim=fea_dim,
26 |                               out_pt_fea_dim=out_fea_dim,
27 |                               fea_compre=num_input_features)
28 | 
29 |     model = get_model_class(model_config["model_architecture"])(
30 |         cylin_model=cy_fea_net,
31 |         segmentator_spconv=cylinder_3d_spconv_seg,
32 |         sparse_shape=output_shape
33 |     )
34 | 
35 |     return model
36 | 
-------------------------------------------------------------------------------- /src/tconcord3d/config/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
-------------------------------------------------------------------------------- /src/tconcord3d/config/config.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | from pathlib import Path
3 | 
4 | from strictyaml import Bool, Float, Int, Map, Seq, Str, as_document, load
5 | 
6 | model_params = Map(
7 |     {
8 |         "model_architecture": Str(),
9 |         "output_shape": Seq(Int()),
10 |         "fea_dim": Int(),
11 |         "out_fea_dim": Int(),
12 |         "num_class": Int(),
13 |         "num_input_features": Int(),
14 |         "use_norm": Bool(),
15 |         "init_size": Int(),
16 |     }
17 | )
18 | 
19 | dataset_params = Map(
20 |     {
21 |         "dataset_type": Str(),
22 |         "pc_dataset_type": Str(),
23 |         "ignore_label": Int(),
24 |         "return_test": Bool(),
25 |         "fixed_volume_space": Bool(),
26 |         "label_mapping": Str(),
27 |         "max_volume_space": Seq(Float()),
28 |         "min_volume_space": Seq(Float()),
29 |     }
30 | )
31 | 
32 | 
33 | train_data_loader = Map(
34 |     {
35 |         "data_path": Str(),
36 |         "imageset": Str(),
37 |         "return_ref": Bool(),
38 |         "batch_size": Int(),
39 |         "shuffle": Bool(),
40 |         "num_workers": Int(),
41 |     }
42 | )
43 | 
44 | val_data_loader = Map(
45 |     {
46 |         "data_path": Str(),
47 |         "imageset": Str(),
48 |         "return_ref": Bool(),
49 |         "batch_size": Int(),
50 |         "shuffle": Bool(),
51 |         "num_workers": Int(),
52 |     }
53 | )
54 | 
55 | test_data_loader = Map(
56 |     {
57 |         "data_path": Str(),
58 |         "imageset": Str(),
59 |         "return_ref": Bool(),
60 |         "batch_size": Int(),
61 |         "shuffle": Bool(),
62 |         "num_workers": Int(),
63 |     }
64 | )
65 | 
66 | ssl_data_loader = Map(
67 |     {
68 |         "data_path": Str(),
69 |         "imageset": Str(),
70 |         "return_ref": Bool(),
71 |         "batch_size": Int(),
72 |         "shuffle": Bool(),
73 |         "num_workers": Int(),
74 |     }
75 | )
76 | 
77 | train_params = Map(
78 |     {
79 |         "model_load_path": Str(),
80 |         "model_save_path": Str(),
81 |         "checkpoint_every_n_steps": Int(),
82 |         "max_num_epochs": Int(),
83 |         "eval_every_n_steps": Int(),
84 |         "learning_rate": Float(),
85 |         "past": Int(),
86 |         "future": Int(),
87 |         "T_past": Str(),
88 |         "T_future": Str(),
89 |         "ssl": Bool(),
90 |         "rgb": Bool(),
91 |     }
92 | )
93 | 
94 | schema_v4 = Map(
95 |     {
96 |         "format_version": Int(),
97 |         "model_params": model_params,
98 |         "dataset_params": dataset_params,
99 |         "train_data_loader": train_data_loader,
100 |         "val_data_loader": val_data_loader,
101 |         "test_data_loader": test_data_loader,
102 |         "ssl_data_loader": ssl_data_loader,
103 |         "train_params": train_params,
104 |     }
105 | )
106 | 
107 | 
108 | SCHEMA_FORMAT_VERSION_TO_SCHEMA = {4: schema_v4}
109 | 
110 | 
111 | def load_config_data(path: str) -> dict:
112 |     yaml_string = Path(path).read_text()
113 |     cfg_without_schema = load(yaml_string, schema=None)
114 |     schema_version = int(cfg_without_schema["format_version"])
115 |     if schema_version not in SCHEMA_FORMAT_VERSION_TO_SCHEMA:
116 |         raise
Exception(f"Unsupported schema format version: {schema_version}.") 117 | 118 | strict_cfg = load(yaml_string, schema=SCHEMA_FORMAT_VERSION_TO_SCHEMA[schema_version]) 119 | cfg: dict = strict_cfg.data 120 | return cfg 121 | 122 | 123 | def config_data_to_config(data): # type: ignore 124 | return as_document(data, schema_v4) 125 | 126 | 127 | def save_config_data(data: dict, path: str) -> None: 128 | cfg_document = config_data_to_config(data) 129 | with open(Path(path), "w") as f: 130 | f.write(cfg_document.as_yaml()) 131 | -------------------------------------------------------------------------------- /src/tconcord3d/config/label_mapping/sematickitti/semantic-kitti_ssl_s20_p80.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | labels: 3 | 0 : "traversable" 4 | 1 : "obstacle" 5 | 2 : "traver" 6 | 10: "car" 7 | 11: "bicycle" 8 | 13: "bus" 9 | 15: "motorcycle" 10 | 16: "on-rails" 11 | 18: "truck" 12 | 20: "other-vehicle" 13 | 30: "person" 14 | 31: "bicyclist" 15 | 32: "motorcyclist" 16 | 40: "road" 17 | 44: "parking" 18 | 48: "sidewalk" 19 | 49: "other-ground" 20 | 50: "building" 21 | 51: "fence" 22 | 52: "other-structure" 23 | 60: "lane-marking" 24 | 70: "vegetation" 25 | 71: "trunk" 26 | 72: "terrain" 27 | 80: "pole" 28 | 81: "traffic-sign" 29 | 99: "other-object" 30 | 252: "moving-car" 31 | 253: "moving-bicyclist" 32 | 254: "moving-person" 33 | 255: "moving-motorcyclist" 34 | 256: "moving-on-rails" 35 | 257: "moving-bus" 36 | 258: "moving-truck" 37 | 259: "moving-other-vehicle" 38 | color_map: # bgr 39 | 0 : [0, 0, 0] 40 | 1 : [0, 0, 255] 41 | 2 : [0, 0, 255] 42 | 10: [245, 150, 100] 43 | 11: [245, 230, 100] 44 | 13: [250, 80, 100] 45 | 15: [150, 60, 30] 46 | 16: [255, 0, 0] 47 | 18: [180, 30, 80] 48 | 20: [255, 0, 0] 49 | 30: [30, 30, 255] 50 | 31: [200, 40, 255] 51 | 32: [90, 30, 150] 52 | 40: [255, 0, 255] 53 | 44: [255, 150, 255] 54 | 48: [75, 0, 75] 55 | 49: [75, 0, 175] 56 | 50: [0, 200, 255] 57 | 51: [50, 120, 255] 58 | 52: [0, 150, 255] 59 | 60: [170, 255, 150] 60 | 70: [0, 175, 0] 61 | 71: [0, 60, 135] 62 | 72: [80, 240, 150] 63 | 80: [150, 240, 255] 64 | 81: [0, 0, 255] 65 | 99: [255, 255, 50] 66 | 252: [245, 150, 100] 67 | 256: [255, 0, 0] 68 | 253: [200, 40, 255] 69 | 254: [30, 30, 255] 70 | 255: [90, 30, 150] 71 | 257: [250, 80, 100] 72 | 258: [180, 30, 80] 73 | 259: [255, 0, 0] 74 | content: # as a ratio with the total number of points 75 | 0: 0.018889854628292943 76 | 1: 0.0002937197336781505 77 | 10: 0.040818519255974316 78 | 11: 0.00016609538710764618 79 | 13: 2.7879693665067774e-05 80 | 15: 0.00039838616015114444 81 | 16: 0.0 82 | 18: 0.0020633612104619787 83 | 20: 0.0016218197275284021 84 | 30: 0.00017698551338515307 85 | 31: 1.1065903904919655e-08 86 | 32: 5.532951952459828e-09 87 | 40: 0.1987493871255525 88 | 44: 0.014717169549888214 89 | 48: 0.14392298360372 90 | 49: 0.0039048553037472045 91 | 50: 0.1326861944777486 92 | 51: 0.0723592229456223 93 | 52: 0.002395131480328884 94 | 60: 4.7084144280367186e-05 95 | 70: 0.26681502148037506 96 | 71: 0.006035012012626033 97 | 72: 0.07814222006271769 98 | 80: 0.002855498193863172 99 | 81: 0.0006155958086189918 100 | 99: 0.009923127583046915 101 | 252: 0.001789309418528068 102 | 253: 0.00012709999297008662 103 | 254: 0.00016059776092534436 104 | 255: 3.745553104802113e-05 105 | 256: 0.0 106 | 257: 0.00011351574470342043 107 | 258: 0.00010157861367183268 108 | 259: 4.3840131989471124e-05 109 | # classes that are 
indistinguishable from single scan or inconsistent in 110 | # ground truth are mapped to their closest equivalent 111 | learning_map: 112 | 0: 0 # "unlabeled" 113 | 1: 0 # "outlier" 114 | 10: 1 # "car" 115 | 11: 1 # "bicycle" 116 | 13: 1 # "bus" 117 | 15: 1 # "motorcycle" 118 | 16: 1 # "on-rails" 119 | 18: 1 # "truck" 120 | 20: 1 # "other-vehicle" 121 | 30: 1 # "person" 122 | 31: 1 # "bicyclist" 123 | 32: 1 # "motorcyclist" 124 | 40: 2 # "road" 125 | 44: 2 # "parking" 126 | 48: 2 # "sidewalk" 127 | 49: 2 # "other-ground" 128 | 50: 1 # "building" 129 | 51: 1 # "fence" 130 | 52: 1 # "other-structure" 131 | 60: 2 # "lane-marking" 132 | 70: 1 # "vegetation" 133 | 71: 1 # "trunk" 134 | 72: 2 # "terrain" 135 | 80: 1 # "pole" 136 | 81: 1 # "traffic-sign" 137 | 99: 1 # "other-object" 138 | 252: 1 # "moving-car" 139 | 253: 1 # "moving-bicyclist" 140 | 254: 1 # "moving-person" 141 | 255: 1 # "moving-motorcyclist" 142 | 256: 1 # "moving-on-rails" 143 | 257: 1 # "moving-bus" 144 | 258: 1 # "moving-truck" 145 | 259: 1 # "moving-other-vehicle 146 | learning_map_inv: # inverse of previous map 147 | 0: 255 # "unlabeled", and others ignored 148 | 1: 1 # "car" 149 | 2: 0 # "bicycle" 150 | learning_ignore: # Ignore classes 151 | 0: True # "unlabeled", and others ignored 152 | 1: False # "car" 153 | 2: False # "bicycle" 154 | 155 | split: # sequence numbers 156 | train: # 20 percent gt 157 | # - '0020' 158 | # - '0120' 159 | # - '0220' 160 | # - '0320' 161 | - '04' 162 | # - '0520' 163 | # - '0620' 164 | # - '0720' 165 | # - '0920' 166 | # - '1020' 167 | 168 | pseudo: # 80 percent pseudo-labeled 169 | # - '0080' 170 | # - '0180' 171 | # - '0280' 172 | - '03' 173 | # - '0480' 174 | # - '0580' 175 | # - '0680' 176 | # - '0780' 177 | # - '0980' 178 | # - '1080' 179 | 180 | valid: 181 | # - '08' 182 | - '08' 183 | test: 184 | - '11' 185 | - '12' 186 | - '13' 187 | - '14' 188 | - '15' 189 | - '16' 190 | - '17' 191 | - '18' 192 | - '19' 193 | - '20' 194 | - '21' 195 | -------------------------------------------------------------------------------- /src/tconcord3d/config/semantickitti/semantickitti_S0_0_T11_33_ssl_s20_p80.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 480 11 | - 360 12 | - 32 13 | 14 | fea_dim: 9 15 | out_fea_dim: 256 16 | num_class: 3 17 | num_input_features: 16 18 | use_norm: True 19 | init_size: 32 20 | 21 | ################### 22 | ## Dataset options 23 | dataset_params: 24 | dataset_type: "cylinder_dataset" 25 | pc_dataset_type: "SemKITTI_sk_multiscan" # # # "SemKITTI_sk" # 26 | ignore_label: 0 27 | return_test: False 28 | fixed_volume_space: True 29 | label_mapping: "./config/label_mapping/sematickitti/semantic-kitti_ssl_s20_p80.yaml" 30 | max_volume_space: 31 | - 50 32 | - 3.1415926 33 | - 2 34 | min_volume_space: 35 | - 0 36 | - -3.1415926 37 | - -4 38 | 39 | ################### 40 | ## Data_loader options 41 | train_data_loader: 42 | data_path: "/home/ruslan/data/datasets/KITTI/SemanticKITTI/sequences" 43 | imageset: "train" 44 | return_ref: True 45 | batch_size: 4 #4 46 | shuffle: True 47 | num_workers: 0 48 | 49 | ssl_data_loader: 50 | data_path: "/home/ruslan/data/datasets/KITTI/SemanticKITTI/sequences" 51 | imageset: "pseudo" 52 | return_ref: True 53 | batch_size: 4 54 | shuffle: False 55 | num_workers: 4 56 | 57 | val_data_loader: 58 | data_path: 
"/home/ruslan/data/datasets/KITTI/SemanticKITTI/sequences" 59 | imageset: "val" 60 | return_ref: True 61 | batch_size: 2 62 | shuffle: False 63 | num_workers: 0 64 | 65 | test_data_loader: 66 | data_path: "/home/ruslan/data/datasets/KITTI/SemanticKITTI/sequences" 67 | imageset: "test" 68 | return_ref: True 69 | batch_size: 2 70 | shuffle: False 71 | num_workers: 4 72 | 73 | 74 | ################### 75 | ## test params 76 | train_params: 77 | model_load_path: "./model_save_dir/student_kitti_traversablity_f0_0_time_ema.pt" 78 | # model_load_path: "./model_save_dir/model_save_f0_0_T11_33_ssl_s20_p80_b4_singlescan.pt" 79 | model_save_path: "./model_save_dir/model_save_f0_0_T11_33_ssl_s20_p80_b4_singlescan.pt" 80 | checkpoint_every_n_steps: 1000 #4599 81 | max_num_epochs: 40 82 | eval_every_n_steps: 1000 #4599 83 | learning_rate: 0.001 84 | past: 0 85 | future: 0 86 | T_past: 11 87 | T_future: 33 88 | ssl: True 89 | rgb: False 90 | -------------------------------------------------------------------------------- /src/tconcord3d/config/semantickitti/semantickitti_S0_0_test.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 480 11 | - 360 12 | - 32 13 | 14 | fea_dim: 9 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 16 18 | use_norm: True 19 | init_size: 32 20 | 21 | ################### 22 | ## Dataset options 23 | dataset_params: 24 | dataset_type: "cylinder_dataset" 25 | pc_dataset_type: "SemKITTI_sk_multiscan" #"SemKITTI_sk_multiscan" # # # "SemKITTI_sk" # 26 | ignore_label: 0 27 | return_test: False 28 | fixed_volume_space: True 29 | label_mapping: "./config/label_mapping/sematickitti/semantic-kitti_2.yaml" 30 | max_volume_space: 31 | - 50 32 | - 3.1415926 33 | - 2 34 | min_volume_space: 35 | - 0 36 | - -3.1415926 37 | - -4 38 | 39 | ################### 40 | ## Data_loader options 41 | train_data_loader: 42 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/dataset/sequences" 43 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 44 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/dataset/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 45 | imageset: "train" 46 | return_ref: True 47 | batch_size: 5 #4 48 | shuffle: True 49 | num_workers: 0 50 | 51 | val_data_loader: 52 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/dataset/sequences" 53 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 54 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/dataset/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 55 | imageset: "val" 56 | return_ref: True 57 | batch_size: 5 58 | shuffle: False 59 | num_workers: 0 60 | 61 | test_data_loader: 62 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/dataset/sequences" 63 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 64 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/dataset/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 65 | imageset: "test" 66 | return_ref: True 67 | batch_size: 5 68 | shuffle: False 69 | num_workers: 0 70 | 71 | ssl_data_loader: 72 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/dataset/sequences" 73 | imageset: 
"ssl" 74 | return_ref: True 75 | batch_size: 5 76 | shuffle: False 77 | num_workers: 4 78 | 79 | ################### 80 | ## test params 81 | train_params: 82 | model_load_path: "./model_save_dir/model_save_f0_0_s20_b4_singlescan.pt" 83 | model_save_path: "./model_save_dir/model_save_f0_0_s20_b4_singlescan.pt" 84 | checkpoint_every_n_steps: 1000 85 | max_num_epochs: 40 86 | eval_every_n_steps: 1000 87 | learning_rate: 0.001 88 | past: 0 89 | future: 0 90 | T_past: 0 91 | T_future: 20 92 | ssl: False 93 | rgb: False 94 | -------------------------------------------------------------------------------- /src/tconcord3d/config/semantickitti/semantickitti_T0_0.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 480 11 | - 360 12 | - 32 13 | 14 | fea_dim: 9 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 16 18 | use_norm: True 19 | init_size: 32 20 | 21 | ################### 22 | ## Dataset options 23 | dataset_params: 24 | dataset_type: "cylinder_dataset" 25 | pc_dataset_type: "SemKITTI_sk_multiscan" # # # "SemKITTI_sk" # 26 | ignore_label: 0 27 | return_test: False 28 | fixed_volume_space: True 29 | label_mapping: "./config/label_mapping/sematickitti/semantic-kitti_ssl_s20_p80.yaml" 30 | max_volume_space: 31 | - 50 32 | - 3.1415926 33 | - 2 34 | min_volume_space: 35 | - 0 36 | - -3.1415926 37 | - -4 38 | 39 | ################### 40 | ## Data_loader options 41 | train_data_loader: 42 | data_path: "/home/ruslan/data/datasets/KITTI/SemanticKITTI/sequences" 43 | imageset: "train" 44 | return_ref: True 45 | batch_size: 4 #4 46 | shuffle: True 47 | num_workers: 4 48 | 49 | val_data_loader: 50 | data_path: "/home/ruslan/data/datasets/KITTI/SemanticKITTI/sequences" 51 | imageset: "val" 52 | return_ref: True 53 | batch_size: 4 54 | shuffle: False 55 | num_workers: 0 56 | 57 | test_data_loader: 58 | data_path: "/home/ruslan/data/datasets/KITTI/SemanticKITTI/sequences" 59 | imageset: "test" 60 | return_ref: True 61 | batch_size: 4 62 | shuffle: False 63 | num_workers: 4 64 | 65 | ssl_data_loader: 66 | data_path: "/home/ruslan/data/datasets/KITTI/SemanticKITTI/sequences" 67 | imageset: "pseudo" 68 | return_ref: True 69 | batch_size: 4 70 | shuffle: False 71 | num_workers: 4 72 | 73 | ################### 74 | ## test params 75 | train_params: 76 | model_load_path: "./model_save_dir/model_save_f0_0_s100_b4_t1.pt" 77 | model_save_path: "./model_save_dir/model_save_f0_0_s100_b4_t1.pt" 78 | checkpoint_every_n_steps: 1000 79 | max_num_epochs: 40 80 | eval_every_n_steps: 1000 81 | learning_rate: 0.001 82 | past: 0 83 | future: 0 84 | T_past: 0 85 | T_future: 0 86 | ssl: False 87 | rgb: False 88 | -------------------------------------------------------------------------------- /src/tconcord3d/config/semantickitti/semantickitti_T1_1.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 480 11 | - 360 12 | - 32 13 | 14 | fea_dim: 9 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 16 18 | use_norm: True 19 | init_size: 32 20 | 21 | ################### 22 | ## Dataset options 23 | dataset_params: 24 | dataset_type: "cylinder_dataset" 25 | 
pc_dataset_type: "SemKITTI_sk_multiscan" # # # "SemKITTI_sk" # 26 | ignore_label: 0 27 | return_test: False 28 | fixed_volume_space: True 29 | label_mapping: "./config/label_mapping/sematickitti/semantic-kitti_ssl_teacher_student.yaml" 30 | max_volume_space: 31 | - 50 32 | - 3.1415926 33 | - 2 34 | min_volume_space: 35 | - 0 36 | - -3.1415926 37 | - -4 38 | 39 | ################### 40 | ## Data_loader options 41 | train_data_loader: 42 | data_path: "/mnt/personal/gebreawe/Datasets/RealWorld/semantic-kitti/train_pseudo_60/sequences" 43 | imageset: "train" 44 | return_ref: True 45 | batch_size: 4 #4 46 | shuffle: True 47 | num_workers: 4 48 | 49 | val_data_loader: 50 | data_path: "/mnt/personal/gebreawe/Datasets/RealWorld/semantic-kitti/train_pseudo_60/sequences" 51 | imageset: "val" 52 | return_ref: True 53 | batch_size: 4 54 | shuffle: False 55 | num_workers: 0 56 | 57 | test_data_loader: 58 | data_path: "/mnt/personal/gebreawe/Datasets/RealWorld/semantic-kitti/train_pseudo_60/sequences" 59 | imageset: "test" 60 | return_ref: True 61 | batch_size: 4 62 | shuffle: False 63 | num_workers: 4 64 | 65 | ssl_data_loader: 66 | data_path: "/mnt/personal/gebreawe/Datasets/RealWorld/semantic-kitti/train_pseudo_60/sequences" 67 | imageset: "pseudo" 68 | return_ref: True 69 | batch_size: 4 70 | shuffle: False 71 | num_workers: 4 72 | 73 | ################### 74 | ## test params 75 | train_params: 76 | model_load_path: "./model_save_dir/model_save_f1_1_s60_b4_t1.pt" 77 | model_save_path: "./model_save_dir/model_save_f1_1_s60_b4_t1.pt" 78 | checkpoint_every_n_steps: 1000 79 | max_num_epochs: 40 80 | eval_every_n_steps: 1000 81 | learning_rate: 0.001 82 | past: 1 83 | future: 1 84 | T_past: 1 85 | T_future: 11 86 | ssl: False 87 | rgb: False 88 | -------------------------------------------------------------------------------- /src/tconcord3d/config/semantickitti/semantickitti_T2_2.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 480 11 | - 360 12 | - 32 13 | 14 | fea_dim: 9 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 16 18 | use_norm: True 19 | init_size: 32 20 | 21 | ################### 22 | ## Dataset options 23 | dataset_params: 24 | dataset_type: "cylinder_dataset" 25 | pc_dataset_type: "SemKITTI_sk_multiscan" # # # "SemKITTI_sk" # 26 | ignore_label: 0 27 | return_test: False 28 | fixed_volume_space: True 29 | label_mapping: "./config/label_mapping/sematickitti/semantic-kitti_s20.yaml" 30 | max_volume_space: 31 | - 50 32 | - 3.1415926 33 | - 2 34 | min_volume_space: 35 | - 0 36 | - -3.1415926 37 | - -4 38 | 39 | ################### 40 | ## Data_loader options 41 | train_data_loader: 42 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" 43 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 44 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 45 | imageset: "train" 46 | return_ref: True 47 | batch_size: 3 #4 48 | shuffle: True 49 | num_workers: 4 50 | 51 | val_data_loader: 52 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" 53 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 54 | 
#"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 55 | imageset: "val" 56 | return_ref: True 57 | batch_size: 4 58 | shuffle: False 59 | num_workers: 0 60 | 61 | test_data_loader: 62 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" 63 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 64 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 65 | imageset: "test" 66 | return_ref: True 67 | batch_size: 4 68 | shuffle: False 69 | num_workers: 4 70 | 71 | ssl_data_loader: 72 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" 73 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 74 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_10/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 75 | imageset: "ssl" 76 | return_ref: True 77 | batch_size: 1 78 | shuffle: False 79 | num_workers: 4 80 | 81 | ################### 82 | ## test params 83 | train_params: 84 | model_load_path: "./model_save_dir/model_save_f2_2_s20_b4_t2.pt" 85 | model_save_path: "./model_save_dir/model_save_f2_2_s20_b4_t2.pt" 86 | checkpoint_every_n_steps: 1000 87 | max_num_epochs: 40 88 | eval_every_n_steps: 1000 89 | learning_rate: 0.001 90 | past: 2 91 | future: 2 92 | T_past: 2 93 | T_future: 22 94 | ssl: False 95 | rgb: False 96 | -------------------------------------------------------------------------------- /src/tconcord3d/config/semantickitti/semantickitti_T3_3.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 480 11 | - 360 12 | - 32 13 | 14 | fea_dim: 9 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 16 18 | use_norm: True 19 | init_size: 32 20 | 21 | ################### 22 | ## Dataset options 23 | dataset_params: 24 | dataset_type: "cylinder_dataset" 25 | pc_dataset_type: "SemKITTI_sk_multiscan" # # # "SemKITTI_sk" # 26 | ignore_label: 0 27 | return_test: False 28 | fixed_volume_space: True 29 | label_mapping: "./config/label_mapping/sematickitti/semantic-kitti_s20.yaml" 30 | max_volume_space: 31 | - 50 32 | - 3.1415926 33 | - 2 34 | min_volume_space: 35 | - 0 36 | - -3.1415926 37 | - -4 38 | 39 | ################### 40 | ## Data_loader options 41 | train_data_loader: 42 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" 43 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 44 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 45 | imageset: "train" 46 | return_ref: True 47 | batch_size: 3 #4 48 | shuffle: True 49 | num_workers: 4 50 | 51 | val_data_loader: 52 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" 53 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 54 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 55 | imageset: 
"val" 56 | return_ref: True 57 | batch_size: 3 58 | shuffle: False 59 | num_workers: 4 60 | 61 | test_data_loader: 62 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" 63 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 64 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 65 | imageset: "test" 66 | return_ref: True 67 | batch_size: 8 68 | shuffle: False 69 | num_workers: 0 70 | 71 | ssl_data_loader: 72 | data_path: "/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" 73 | #"/mnt/data/vras/data/gebreawe/semantic-kitti/dataset/sequences" 74 | #"/mnt/beegfs/gpu/argoverse-tracking-all-training/semantic-kitti/train_pseudo_20/sequences" #"/data/dataset/semantic_kitti/data_semkitti/dataset/sequences/" 75 | imageset: "ssl" 76 | return_ref: True 77 | batch_size: 8 78 | shuffle: False 79 | num_workers: 0 80 | 81 | ################### 82 | ## test params 83 | train_params: 84 | model_load_path: "./model_save_dir/model_save_f3_3_s20_b3.pt" 85 | model_save_path: "./model_save_dir/model_save_f3_3_s20_b3.pt" 86 | checkpoint_every_n_steps: 1000 87 | max_num_epochs: 40 88 | eval_every_n_steps: 1000 89 | learning_rate: 0.001 90 | past: 3 91 | future: 3 92 | T_past: 3 93 | T_future: 3 94 | ssl: False 95 | rgb: False 96 | -------------------------------------------------------------------------------- /src/tconcord3d/model/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | -------------------------------------------------------------------------------- /src/tconcord3d/model/cylinder_3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import torch 3 | from torch import nn 4 | 5 | REGISTERED_MODELS_CLASSES = {} 6 | 7 | 8 | def register_model(cls, name=None): 9 | global REGISTERED_MODELS_CLASSES 10 | if name is None: 11 | name = cls.__name__ 12 | assert name not in REGISTERED_MODELS_CLASSES, f"exist class: {REGISTERED_MODELS_CLASSES}" 13 | REGISTERED_MODELS_CLASSES[name] = cls 14 | return cls 15 | 16 | 17 | def get_model_class(name): 18 | global REGISTERED_MODELS_CLASSES 19 | assert name in REGISTERED_MODELS_CLASSES, f"available class: {REGISTERED_MODELS_CLASSES}" 20 | return REGISTERED_MODELS_CLASSES[name] 21 | 22 | 23 | @register_model 24 | class cylinder_asym(nn.Module): 25 | def __init__(self, 26 | cylin_model, 27 | segmentator_spconv, 28 | sparse_shape, 29 | ): 30 | super().__init__() 31 | self.name = "cylinder_asym" 32 | 33 | self.cylinder_3d_generator = cylin_model 34 | 35 | self.cylinder_3d_spconv_seg = segmentator_spconv 36 | 37 | self.sparse_shape = sparse_shape 38 | 39 | def forward(self, train_pt_fea_ten, train_vox_ten, batch_size, val_grid=None, voting_num=4, use_tta=False): 40 | coords, features_3d = self.cylinder_3d_generator(train_pt_fea_ten, train_vox_ten) 41 | 42 | # spatial_features = self.cylinder_3d_spconv_seg(features_3d, coords, batch_size) 43 | # 44 | # return spatial_features 45 | if use_tta: 46 | batch_size *= voting_num 47 | 48 | spatial_features = self.cylinder_3d_spconv_seg(features_3d, coords, batch_size) 49 | 50 | if use_tta: 51 | features_ori = torch.split(spatial_features, 1, dim=0) 52 | fused_predict = features_ori[0][0, :, val_grid[0][:, 0], val_grid[0][:, 1], val_grid[0][:, 2]] 53 | for idx in range(1, voting_num, 1): 54 | 
                fused_predict += features_ori[idx][0, :, val_grid[idx][:, 0], val_grid[idx][:, 1], val_grid[idx][:, 2]]
55 |             return fused_predict
56 |         else:
57 |             return spatial_features
58 | 
-------------------------------------------------------------------------------- /src/tconcord3d/model/cylinder_feature.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import numba as nb
9 | import multiprocessing
10 | import torch_scatter
11 | 
12 | 
13 | class cylinder_fea(nn.Module):
14 | 
15 |     def __init__(self, grid_size, fea_dim=3,
16 |                  out_pt_fea_dim=64, max_pt_per_encode=64, fea_compre=None):
17 |         super(cylinder_fea, self).__init__()
18 | 
19 |         self.PPmodel = nn.Sequential(
20 |             nn.BatchNorm1d(fea_dim),
21 | 
22 |             nn.Linear(fea_dim, 64),
23 |             nn.BatchNorm1d(64),
24 |             nn.ReLU(),
25 | 
26 |             nn.Linear(64, 128),
27 |             nn.BatchNorm1d(128),
28 |             nn.ReLU(),
29 | 
30 |             nn.Linear(128, 256),
31 |             nn.BatchNorm1d(256),
32 |             nn.ReLU(),
33 | 
34 |             nn.Linear(256, out_pt_fea_dim)
35 |         )
36 | 
37 |         self.max_pt = max_pt_per_encode
38 |         self.fea_compre = fea_compre
39 |         self.grid_size = grid_size
40 |         kernel_size = 3
41 |         self.local_pool_op = torch.nn.MaxPool2d(kernel_size, stride=1,
42 |                                                 padding=(kernel_size - 1) // 2,
43 |                                                 dilation=1)
44 |         self.pool_dim = out_pt_fea_dim
45 | 
46 |         # point feature compression
47 |         if self.fea_compre is not None:
48 |             self.fea_compression = nn.Sequential(
49 |                 nn.Linear(self.pool_dim, self.fea_compre),
50 |                 nn.ReLU())
51 |             self.pt_fea_dim = self.fea_compre
52 |         else:
53 |             self.pt_fea_dim = self.pool_dim
54 | 
55 |     def forward(self, pt_fea, xy_ind):
56 |         cur_dev = pt_fea[0].get_device()
57 | 
58 |         # concatenate everything
59 |         cat_pt_ind = []
60 |         # for i_batch in range(len(xy_ind)):
61 |         #     cat_pt_ind.append(F.pad(xy_ind[i_batch], (1, 0), 'constant', value=i_batch))
62 |         # optimized: the append loop above was replaced by a list comprehension for faster runtime
63 |         cat_pt_ind = [F.pad(xy_ind[i_batch], (1, 0), 'constant', value=i_batch) for i_batch in range(len(xy_ind))]
64 | 
65 |         cat_pt_fea = torch.cat(pt_fea, dim=0)
66 |         cat_pt_ind = torch.cat(cat_pt_ind, dim=0)
67 |         pt_num = cat_pt_ind.shape[0]
68 | 
69 |         # shuffle the data
70 |         shuffled_ind = torch.randperm(pt_num, device=cur_dev)
71 |         cat_pt_fea = cat_pt_fea[shuffled_ind, :]
72 |         cat_pt_ind = cat_pt_ind[shuffled_ind, :]
73 | 
74 |         # unique xy grid index
75 |         unq, unq_inv, unq_cnt = torch.unique(cat_pt_ind, return_inverse=True, return_counts=True, dim=0)
76 |         unq = unq.type(torch.int64)
77 | 
78 |         # process feature
79 |         processed_cat_pt_fea = self.PPmodel(cat_pt_fea)
80 |         pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0]
81 | 
82 |         if self.fea_compre:
83 |             processed_pooled_data = self.fea_compression(pooled_data)
84 |         else:
85 |             processed_pooled_data = pooled_data
86 | 
87 |         return unq, processed_pooled_data
88 | 
-------------------------------------------------------------------------------- /src/tconcord3d/utils/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py.py
-------------------------------------------------------------------------------- /src/tconcord3d/utils/load_save_util.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import torch
3 | 
4 | 
5 | def load_checkpoint(model_load_path, model, map_location=None):
6 |     my_model_dict = model.state_dict()
7 |     if map_location is not None:
8 |         pre_weight = torch.load(model_load_path, map_location=f'cuda:{map_location}')
9 |     else:
10 |         pre_weight = torch.load(model_load_path)
11 | 
12 |     part_load = {}
13 |     match_size = 0
14 |     nomatch_size = 0
15 |     for k in pre_weight.keys():
16 |         value = pre_weight[k]
17 |         if k[:7] == 'module.':  # strip the DataParallel prefix if present
18 |             k = k[7:]
19 |         if k in my_model_dict and my_model_dict[k].shape == value.shape:
20 |             # print("loading ", k)
21 |             match_size += 1
22 |             part_load[k] = value
23 |         else:
24 |             nomatch_size += 1
25 | 
26 |     print("matched parameter sets: {}, unmatched: {}".format(match_size, nomatch_size))
27 | 
28 |     my_model_dict.update(part_load)
29 |     model.load_state_dict(my_model_dict)
30 | 
31 |     return model
32 | 
33 | 
34 | def load_checkpoint_1b1(model_load_path, model):
35 |     my_model_dict = model.state_dict()
36 |     pre_weight = torch.load(model_load_path)
37 | 
38 |     part_load = {}
39 |     match_size = 0
40 |     nomatch_size = 0
41 | 
42 |     pre_weight_list = [*pre_weight]
43 |     my_model_dict_list = [*my_model_dict]
44 | 
45 |     for idx in range(len(pre_weight_list)):
46 |         key_ = pre_weight_list[idx]
47 |         key_2 = my_model_dict_list[idx]
48 |         value_ = pre_weight[key_]
49 |         if my_model_dict[key_2].shape == pre_weight[key_].shape:
50 |             # print("loading ", k)
51 |             match_size += 1
52 |             part_load[key_2] = value_
53 |         else:
54 |             print(key_)
55 |             print(key_2)
56 |             nomatch_size += 1
57 | 
58 |     print("matched parameter sets: {}, unmatched: {}".format(match_size, nomatch_size))
59 | 
60 |     my_model_dict.update(part_load)
61 |     model.load_state_dict(my_model_dict)
62 | 
63 |     return model
64 | 
-------------------------------------------------------------------------------- /src/tconcord3d/utils/log_util.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | def save_to_log(logdir, logfile, message):
3 |     # append one line; the context manager closes the file even on error
4 |     with open(logdir + '/' + logfile, "a") as f:
5 |         f.write(message + '\n')
6 |     return
-------------------------------------------------------------------------------- /src/tconcord3d/utils/loss_func.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Awet H. Gebrehiwot
3 | # --------------------------|
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class FocalLoss(nn.Module):
10 | 
11 |     def __init__(self, weight=None, ignore_index=None,
12 |                  gamma=2., reduction='none', ssl=False):
13 |         nn.Module.__init__(self)
14 |         self.ignore_index = ignore_index
15 |         self.weight = weight
16 |         self.gamma = gamma
17 |         self.reduction = reduction
18 |         self.ssl = ssl
19 | 
20 |     def forward(self, input_tensor, target_tensor, lcw=None):
21 |         log_prob = F.log_softmax(input_tensor, dim=1)
22 |         prob = torch.exp(log_prob)
23 |         raw_loss = F.nll_loss(
24 |             ((1 - prob) ** self.gamma) * log_prob,
25 |             target_tensor,
26 |             weight=self.weight,
27 |             reduction=self.reduction,
28 |             ignore_index=self.ignore_index
29 |         )
30 | 
31 |         if self.ssl and lcw is not None:
32 |             norm_lcw = (lcw / 100.0)  # confidence weights arrive as percentages
33 |             weighted_loss = (raw_loss * norm_lcw).mean()
34 |             return weighted_loss
35 |         else:
36 |             return raw_loss.mean()
37 | 
38 | 
39 | class WeightedFocalLoss(nn.Module):
40 |     """Per-class weighted version of focal loss."""
41 |     def __init__(self, weight=None, ignore_index=None,
42 |                  gamma=2., reduction='none', ssl=False):
43 |         super().__init__()
44 |         self.ignore_index = ignore_index
45 |         self.weight = weight
46 |         self.gamma = gamma
47 |         self.reduction = reduction
48 |         self.ssl = ssl
49 | 
50 |     def forward(self, inputs, targets):
51 |         inputs = inputs.squeeze()
52 |         targets = targets.squeeze()
53 | 
54 |         BCE_loss = F.cross_entropy(inputs, targets, reduction='none')
55 |         pt = torch.exp(-BCE_loss)
56 |         F_loss = self.weight[targets] * (1 - pt) ** self.gamma * BCE_loss
57 | 
58 |         return F_loss.mean()
-------------------------------------------------------------------------------- /src/tconcord3d/utils/metric_util.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import numpy as np
3 | 
4 | 
5 | def fast_hist(pred, label, n):
6 |     k = (label >= 0) & (label < n)
7 |     bin_count = np.bincount(
8 |         n * label[k].astype(int) + pred[k], minlength=n ** 2)
9 |     return bin_count[:n ** 2].reshape(n, n)
10 | 
11 | 
12 | def per_class_iu(hist):
13 |     return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
14 | 
15 | 
16 | def fast_hist_crop(output, target, unique_label):
17 |     hist = fast_hist(output.flatten(), target.flatten(), np.max(unique_label) + 2)
18 |     hist = hist[unique_label + 1, :]
19 |     hist = hist[:, unique_label + 1]
20 |     return hist
21 | 
22 | 
23 | # TODO: check if this is implemented correctly
24 | def fast_ups_crop(uncrt, target, unique_label):
25 |     hist = [np.sum(uncrt[target == i]) for i in range(20)]
26 |     va, cla_count = np.unique(target, return_counts=True)
27 |     class_count = np.zeros(20)
28 |     class_count[va] = cla_count
29 |     return hist, class_count
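A toy check of the two metric helpers above, with a 3-class confusion matrix built from five points (the numbers can be verified by hand):

```python
import numpy as np
from tconcord3d.utils.metric_util import fast_hist, per_class_iu

pred  = np.array([0, 1, 2, 2, 1])
label = np.array([0, 1, 1, 2, 1])
hist = fast_hist(pred, label, 3)   # rows index labels, columns predictions
print(per_class_iu(hist))          # per-class IoU -> [1.0, 0.667, 0.5]
```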
-------------------------------------------------------------------------------- /src/tconcord3d/utils/ups.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Awet H. Gebrehiwot
3 | # --------------------------|
4 | 
5 | 
6 | def enable_dropout(model):  # keep dropout active at inference (MC-dropout uncertainty)
7 |     for m in model.modules():
8 |         if m.__class__.__name__.startswith('Dropout'):
9 |             m.train()
10 | 
-------------------------------------------------------------------------------- /src/traversability_estimation/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
-------------------------------------------------------------------------------- /src/traversability_estimation/ransac.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | from math import log, ceil
3 | import numpy as np
4 | from tqdm import tqdm
5 | 
6 | 
7 | def num_iters(xi, zeta, m):
8 |     if xi < 1e-6:
9 |         return int(1e9)  # degenerate inlier ratio: effectively unbounded
10 |     iters = ceil(log(zeta) / log(1 - xi**m))  # standard RANSAC bound: P(all samples bad) <= zeta
11 |     return iters
12 | 
13 | 
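num_iters is the standard RANSAC sample bound: the number of random draws needed so that the probability of never hitting an all-inlier minimal sample stays below zeta. A quick sanity check (values chosen only for illustration): with inlier ratio xi = 0.5, sample size m = 3 and zeta = 1e-3, ceil(log(1e-3) / log(1 - 0.125)) = 52.

```python
from traversability_estimation.ransac import num_iters

# 50% inliers, 3-point minimal samples, 0.1% chance of missing a good model:
print(num_iters(0.5, 1e-3, 3))  # -> 52
```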
/src/traversability_estimation/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function
--------------------------------------------------------------------------------

/src/traversability_estimation/ransac.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function
from math import log, ceil
import numpy as np


def num_iters(xi, zeta, m):
    # Number of iterations needed so that the probability of never drawing an
    # all-inlier sample of size m is at most zeta, given inlier ratio xi.
    if xi < 1e-6:
        return int(1e9)
    iters = ceil(log(zeta) / log(1 - xi**m))
    return iters


def ransac(x, min_sample, get_model, get_inliers, fail_prob=1e-3,
           max_iters=10000, inl_ratio=0.0, lo_iters=0, verbosity=0):
    """Random Sample Consensus (RANSAC)

    Stochastic parameter estimator which can handle a large number of outliers.

    @param x: Data matrix with points in rows.
    @param min_sample: Minimum sample size to determine model parameters.
    @param get_model: Model constructor called as model = get_model(x[sample]).
        Should handle len(sample) >= min_sample.
    @param get_inliers: A function called as inliers = get_inliers(model, x).
    @param fail_prob: An acceptable probability of not finding the correct solution.
    @param max_iters: The maximum number of iterations to perform.
    @param inl_ratio: An initial estimate of the inlier ratio.
    @param lo_iters: The number of local-optimization iterations.
        If > 0, get_model(x[inliers]) is called for len(inliers) > min_sample.
    @param verbosity: Verbosity level.
    @return: Tuple of the best model parameters found and its corresponding inliers.
    """

    best_model = None
    inliers = []

    for i in range(max_iters):
        # max_iters may have been lowered inside the loop; honor the new bound.
        if i >= max_iters:
            break
        sample = np.random.choice(len(x), min_sample, replace=False)
        model = get_model(x[sample])
        if model is None:
            if verbosity > 0:
                print('Failed to fit model to sample.')
            continue
        support = get_inliers(model, x)
        if verbosity > 0 and len(support) < min_sample:
            print('Support lower than minimal sample.')

        # Local optimization if requested.
        if len(support) > min_sample:
            for j in range(lo_iters):
                new_model = get_model(x[support])
                if new_model is None:
                    if verbosity > 0:
                        print('Failed to fit model to support.')
                    break
                new_support = get_inliers(new_model, x)
                if verbosity > 0 and len(new_support) < min_sample:
                    print('Optimized support lower than minimal sample.')
                if len(new_support) > len(support):
                    if verbosity > 0:
                        print('Improved optimized model %s with %i inliers (prev. %i).'
                              % (new_model, len(new_support), len(support)))
                    model = new_model
                    support = new_support
                else:
                    # Not improving, halt local optimization.
                    break

        if len(support) > len(inliers):
            if verbosity > 0:
                print('New best model %s with %i inliers (prev. %i).'
                      % (model, len(support), len(inliers)))
            best_model = model
            inliers = support

        # Tighten the iteration bound using the current inlier ratio.
        inl_ratio = len(support) / len(x)
        max_iters = min(max_iters, num_iters(inl_ratio, fail_prob, min_sample))

    return best_model, inliers
--------------------------------------------------------------------------------
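For illustration, ransac can fit a ground plane to noisy points; fit_plane and
plane_inliers below are hypothetical callbacks written for this sketch, not
part of the package:

    import numpy as np

    def fit_plane(pts):
        # Least-squares plane z = a*x + b*y + c; None for degenerate samples.
        A = np.c_[pts[:, :2], np.ones(len(pts))]
        sol, _, rank, _ = np.linalg.lstsq(A, pts[:, 2], rcond=None)
        return sol if rank == 3 else None

    def plane_inliers(model, pts, tol=0.05):
        a, b, c = model
        dist = np.abs(a * pts[:, 0] + b * pts[:, 1] + c - pts[:, 2])
        return np.flatnonzero(dist < tol)

    x = np.random.rand(500, 3)
    x[:, 2] = 0.1 * x[:, 0] - 0.2 * x[:, 1] + 0.01 * np.random.randn(500)
    model, inliers = ransac(x, min_sample=3, get_model=fit_plane,
                            get_inliers=plane_inliers)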
58 | """ 59 | assert len(msgs) == len(self.pubs) 60 | assert self.response_empty() 61 | assert not self.event.is_set() 62 | self.event.clear() 63 | for i, msg in enumerate(msgs): 64 | self.pubs[i].publish(msg) 65 | if not self.event.wait(self.timeout): 66 | raise TimeoutError('Service call timed out.') 67 | assert self.response_complete() 68 | response = self.response 69 | self.clear_response() 70 | return response 71 | 72 | def __call__(self, msgs): 73 | return self.call(msgs) 74 | 75 | 76 | def test(): 77 | import roslaunch 78 | from std_msgs.msg import String 79 | from time import sleep 80 | 81 | class RosCore(object): 82 | def __init__(self): 83 | uuid = roslaunch.rlutil.get_or_generate_uuid(options_runid=None, options_wait_for_master=False) 84 | roslaunch.configure_logging(uuid) 85 | self.launch = roslaunch.parent.ROSLaunchParent(uuid, roslaunch_files=[], is_core=True) 86 | self.launch.start() 87 | 88 | def __del__(self): 89 | self.launch.shutdown() 90 | 91 | class Repeater(object): 92 | def __init__(self, input, output): 93 | self.pub = rospy.Publisher(output, String, queue_size=1) 94 | self.sub = rospy.Subscriber(input, String, self.callback, queue_size=1) 95 | 96 | def callback(self, msg): 97 | self.pub.publish(msg) 98 | 99 | roscore = RosCore() 100 | sleep(2) 101 | 102 | rospy.init_node('topic_service_proxy_test') 103 | 104 | print('Creating repeater service...') 105 | repeater = Repeater('request', 'response') 106 | sleep(2) 107 | 108 | print('Creating service proxy...') 109 | repeater_proxy = TopicServiceProxy([('request', String)], [('response', String)]) 110 | void_proxy = TopicServiceProxy([('void', String)], [('silent', String)], timeout=1.0) 111 | sleep(2) 112 | 113 | request = String('Hello World!') 114 | 115 | try: 116 | print('Calling repeater service...') 117 | response, = repeater_proxy([request]) 118 | assert response == request 119 | 120 | print('Calling void service...') 121 | response, = void_proxy([request]) 122 | timed_out = False 123 | 124 | except TimeoutError as ex: 125 | timed_out = True 126 | print('Timed out.') 127 | 128 | assert timed_out 129 | 130 | 131 | def main(): 132 | test() 133 | 134 | 135 | if __name__ == '__main__': 136 | main() 137 | --------------------------------------------------------------------------------