├── IMG ├── EOTL-2D-CAR-312.png ├── EOTL-2D-CYC-066.png ├── EOTL-2D-PED-022.png ├── EOTL-3D-CAR-276.png ├── EOTL-3D-CYC-125.png └── EOTL-3D-PED-091.png ├── README.md ├── autoware_tracker ├── CMakeLists.txt ├── LICENSE ├── README.md ├── config │ └── params.yaml ├── launch │ ├── run.launch │ └── rviz.rviz ├── msg │ ├── Centroids.msg │ ├── CloudCluster.msg │ ├── CloudClusterArray.msg │ ├── DetectedObject.msg │ └── DetectedObjectArray.msg ├── package.xml └── src │ ├── detected_objects_visualizer │ ├── visualize_detected_objects.cpp │ ├── visualize_detected_objects.h │ ├── visualize_detected_objects_main.cpp │ ├── visualize_rects.cpp │ ├── visualize_rects.h │ └── visualize_rects_main.cpp │ ├── libkitti │ ├── kitti.cpp │ └── kitti.h │ ├── lidar_euclidean_cluster_detect │ ├── cluster.cpp │ ├── cluster.h │ ├── gencolors.cpp │ ├── lidar_euclidean_cluster_detect.cpp │ └── lidar_euclidean_cluster_detect.cpp.old │ └── lidar_imm_ukf_pda_track │ ├── imm_ukf_pda.cpp │ ├── imm_ukf_pda.cpp.old │ ├── imm_ukf_pda.h │ ├── imm_ukf_pda_main.cpp │ ├── ukf.cpp │ └── ukf.h ├── efficient_det_ros ├── LICENSE ├── __pycache__ │ └── backbone.cpython-37.pyc ├── backbone.py ├── benchmark │ └── coco_eval_result ├── coco_eval.py ├── efficient_det_node.py ├── efficientdet │ ├── __pycache__ │ │ ├── model.cpython-37.pyc │ │ └── utils.cpython-37.pyc │ ├── config.py │ ├── dataset.py │ ├── loss.py │ ├── model.py │ └── utils.py ├── efficientdet_test.py ├── efficientdet_test_videos.py ├── efficientnet │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── model.cpython-37.pyc │ │ ├── utils.cpython-37.pyc │ │ └── utils_extra.cpython-37.pyc │ ├── model.py │ ├── utils.py │ └── utils_extra.py ├── projects │ ├── coco.yml │ ├── kitti.yml │ └── shape.yml ├── readme.md ├── train.py ├── tutorial │ └── train_shape.ipynb ├── utils │ ├── __pycache__ │ │ └── utils.cpython-37.pyc │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── 
batchnorm.cpython-37.pyc │ │ │ ├── comm.cpython-37.pyc │ │ │ └── replicate.cpython-37.pyc │ │ ├── batchnorm.py │ │ ├── batchnorm_reimpl.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ └── utils.py └── weights │ └── efficientdet-d2.pth ├── kitti_camera_ros ├── .gitignore ├── .travis.yml ├── CMakeLists.txt ├── LICENSE ├── README.md ├── launch │ └── kitti_camera_ros.launch ├── package.xml └── src │ └── kitti_camera_ros.cpp ├── kitti_velodyne_ros ├── .gitignore ├── .travis.yml ├── CMakeLists.txt ├── LICENSE ├── README.md ├── launch │ ├── kitti_velodyne_ros.launch │ ├── kitti_velodyne_ros.rviz │ ├── kitti_velodyne_ros_loam.launch │ └── kitti_velodyne_ros_loam.rviz ├── package.xml └── src │ └── kitti_velodyne_ros.cpp ├── launch ├── efficient_online_learning.launch └── efficient_online_learning.rviz ├── online_forests_ros ├── CMakeLists.txt ├── LICENSE ├── README.md ├── config │ └── orf.conf ├── doc │ └── 2009-OnlineRandomForests.pdf ├── include │ └── online_forests │ │ ├── classifier.h │ │ ├── data.h │ │ ├── hyperparameters.h │ │ ├── onlinenode.h │ │ ├── onlinerf.h │ │ ├── onlinetree.h │ │ ├── randomtest.h │ │ └── utilities.h ├── launch │ └── online_forests_ros.launch ├── model │ ├── dna-test.libsvm │ └── dna-train.libsvm ├── package.xml └── src │ ├── online_forests │ ├── Online-Forest.cpp │ ├── classifier.cpp │ ├── data.cpp │ ├── hyperparameters.cpp │ ├── onlinenode.cpp │ ├── onlinerf.cpp │ ├── onlinetree.cpp │ ├── randomtest.cpp │ └── utilities.cpp │ └── online_forests_ros.cpp ├── online_svm_ros ├── CMakeLists.txt ├── LICENSE ├── README.md ├── config │ └── svm.yaml ├── launch │ └── online_svm_ros.launch ├── package.xml └── src │ └── online_svm_ros.cpp └── point_cloud_features ├── CMakeLists.txt ├── LICENSE ├── README.md ├── include └── point_cloud_features │ └── point_cloud_features.h ├── launch └── point_cloud_feature_extractor.launch ├── package.xml └── src ├── point_cloud_feature_extractor.cpp └── point_cloud_features └── point_cloud_features.cpp 
/IMG/EOTL-2D-CAR-312.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-2D-CAR-312.png -------------------------------------------------------------------------------- /IMG/EOTL-2D-CYC-066.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-2D-CYC-066.png -------------------------------------------------------------------------------- /IMG/EOTL-2D-PED-022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-2D-PED-022.png -------------------------------------------------------------------------------- /IMG/EOTL-3D-CAR-276.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-3D-CAR-276.png -------------------------------------------------------------------------------- /IMG/EOTL-3D-CYC-125.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-3D-CYC-125.png -------------------------------------------------------------------------------- /IMG/EOTL-3D-PED-091.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-3D-PED-091.png -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Efficient Online Transfer Learning for 3D Object Classification in Autonomous Driving # 2 | 3 | *We are actively updating this repository (especially removing hard-coded values and adding comments) to make it easier to use. If you have any questions, please open an issue. Thanks!* 4 | 5 | This is a ROS-based efficient online learning framework for object classification in 3D LiDAR scans, taking advantage of robust multi-target tracking to avoid the need for data annotation by a human expert. 6 | The system has only been tested on Ubuntu 18.04 with ROS Melodic (compilation fails on Ubuntu 20.04 with ROS Noetic). 7 | 8 | Please watch the video below for more details. 9 | 10 | [![YouTube Video 1](https://img.youtube.com/vi/wl5ehOFV5Ac/0.jpg)](https://www.youtube.com/watch?v=wl5ehOFV5Ac) 11 | 12 | ## NEWS 13 | [2023-01-11] On the KITTI 3D Object Detection benchmark, our results are ranked 276, **91** and 125 for car, pedestrian and cyclist, respectively. 14 |
15 | 3D Object Detection 16 | 17 | *CAR* 18 | ![image](https://github.com/epan-utbm/efficient_online_learning/blob/master/IMG/EOTL-3D-CAR-276.png) 19 | 20 | *PEDESTRIAN* 21 | ![image](https://github.com/epan-utbm/efficient_online_learning/blob/master/IMG/EOTL-3D-PED-091.png) 22 | 23 | *CYCLIST* 24 | ![image](https://github.com/epan-utbm/efficient_online_learning/blob/master/IMG/EOTL-3D-CYC-125.png) 25 |
26 | 27 | [2023-01-11] Our evaluation results in KITTI 2D OBJECT DETECTION achieved rankings of **22** on pedestrian and **66** on cyclist! 28 |
29 | 2D Object Detection 30 | 31 | *CAR* 32 | ![image](https://github.com/epan-utbm/efficient_online_learning/blob/master/IMG/EOTL-2D-CAR-312.png) 33 | 34 | *PEDESTRIAN* 35 | ![image](https://github.com/epan-utbm/efficient_online_learning/blob/master/IMG/EOTL-2D-PED-022.png) 36 | 37 | *CYCLIST* 38 | ![image](https://github.com/epan-utbm/efficient_online_learning/blob/master/IMG/EOTL-2D-CYC-066.png) 39 |
40 | 41 | ## Install & Build 42 | Please first read the README of each sub-package and install the corresponding dependencies. 43 | 44 | ## Run 45 | ### 1. Prepare dataset 46 | * (Optional) Download the [raw data](http://www.cvlibs.net/datasets/kitti/raw_data.php) from KITTI. 47 | 48 | * (Optional) Download the [sample data](https://github.com/epan-utbm/efficient_online_learning/releases/download/sample_data/2011_09_26_drive_0005_sync.tar) for testing. 49 | 50 | * (Optional) Prepare a customized dataset according to the format of the sample data. 51 | 52 | ### 2. Manually set the path parameters in 53 | * `launch/efficient_online_learning.launch` 54 | * `autoware_tracker/config/params.yaml` (e.g. `extrinsic_calibration`) 55 | 56 | ### 3. Run the project 57 | ```sh 58 | cd catkin_ws 59 | source devel/setup.bash 60 | roslaunch src/efficient_online_learning/launch/efficient_online_learning.launch 61 | ``` 62 | 63 | ## Citation 64 | 65 | If you use this code, please cite the following: 66 | 67 | ``` 68 | @article{yangr23sensors, 69 | author = {Rui Yang and Zhi Yan and Tao Yang and Yaonan Wang and Yassine Ruichek}, 70 | title = {Efficient Online Transfer Learning for Road Participants Detection in Autonomous Driving}, 71 | journal = {IEEE Sensors Journal}, 72 | volume = {23}, 73 | number = {19}, 74 | Pages = {23522--23535}, 75 | year = {2023} 76 | } 77 | 78 | @inproceedings{yangr21itsc, 79 | title={Efficient online transfer learning for 3D object classification in autonomous driving}, 80 | author={Rui Yang and Zhi Yan and Tao Yang and Yassine Ruichek}, 81 | booktitle = {Proceedings of the 2021 IEEE International Conference on Intelligent Transportation Systems (ITSC)}, 82 | pages = {2950--2957}, 83 | address = {Indianapolis, USA}, 84 | month = {September}, 85 | year = {2021} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /autoware_tracker/CMakeLists.txt:
-------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(autoware_tracker) 3 | # Build settings: always Release; C++11 is appended to (not substituted for) any CXX flags supplied by the user or toolchain. 4 | set(CMAKE_BUILD_TYPE "Release") 5 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") 6 | set(CMAKE_CXX_FLAGS_RELEASE "-O3 -Wall -g -pthread") 7 | 8 | find_package(catkin REQUIRED COMPONENTS 9 | tf 10 | pcl_ros 11 | roscpp 12 | std_msgs 13 | sensor_msgs 14 | geometry_msgs 15 | visualization_msgs 16 | cv_bridge 17 | image_transport 18 | jsk_recognition_msgs 19 | jsk_rviz_plugins 20 | message_generation 21 | ) 22 | 23 | # Messages 24 | add_message_files( 25 | DIRECTORY msg 26 | FILES 27 | Centroids.msg 28 | CloudCluster.msg 29 | CloudClusterArray.msg 30 | DetectedObject.msg 31 | DetectedObjectArray.msg 32 | ) 33 | generate_messages( 34 | DEPENDENCIES 35 | geometry_msgs 36 | jsk_recognition_msgs 37 | sensor_msgs 38 | std_msgs 39 | ) 40 | 41 | catkin_package( 42 | #INCLUDE_DIRS include 43 | CATKIN_DEPENDS 44 | tf 45 | pcl_ros 46 | roscpp 47 | std_msgs 48 | sensor_msgs 49 | geometry_msgs 50 | visualization_msgs 51 | cv_bridge 52 | image_transport 53 | jsk_recognition_msgs 54 | jsk_rviz_plugins 55 | message_runtime 56 | # message_generation is a build-time-only dependency, so only message_runtime is exported 57 | ) 58 | 59 | 60 | find_package(OpenMP) 61 | find_package(OpenCV REQUIRED) 62 | find_package(Eigen3 QUIET) 63 | 64 | include_directories( 65 | #include 66 | ${catkin_INCLUDE_DIRS} 67 | ${OpenCV_INCLUDE_DIRS} 68 | ) 69 | 70 | add_executable(visualize_detected_objects 71 | src/detected_objects_visualizer/visualize_detected_objects_main.cpp 72 | src/detected_objects_visualizer/visualize_detected_objects.cpp 73 | ) 74 | add_dependencies(visualize_detected_objects ${catkin_EXPORTED_TARGETS} ${PROJECT_NAME}_generate_messages_cpp) 75 | target_link_libraries(visualize_detected_objects ${OpenCV_LIBRARIES} ${EIGEN3_LIBRARIES} ${catkin_LIBRARIES}) 76 | 77 | add_executable(lidar_euclidean_cluster_detect 78 | src/lidar_euclidean_cluster_detect/lidar_euclidean_cluster_detect.cpp 79 |
src/lidar_euclidean_cluster_detect/cluster.cpp 80 | src/libkitti/kitti.cpp) 81 | add_dependencies(lidar_euclidean_cluster_detect ${catkin_EXPORTED_TARGETS} ${PROJECT_NAME}_generate_messages_cpp) 82 | target_link_libraries(lidar_euclidean_cluster_detect ${OpenCV_LIBRARIES} ${catkin_LIBRARIES}) 83 | 84 | add_executable(imm_ukf_pda 85 | src/lidar_imm_ukf_pda_track/imm_ukf_pda_main.cpp 86 | src/lidar_imm_ukf_pda_track/imm_ukf_pda.h 87 | src/lidar_imm_ukf_pda_track/imm_ukf_pda.cpp 88 | src/lidar_imm_ukf_pda_track/ukf.cpp 89 | ) 90 | add_dependencies(imm_ukf_pda ${catkin_EXPORTED_TARGETS} ${PROJECT_NAME}_generate_messages_cpp) 91 | target_link_libraries(imm_ukf_pda ${catkin_LIBRARIES}) 92 | # The package currently has no include/ directory (INCLUDE_DIRS is commented out above); guard the header install so `make install` does not fail. 93 | if(EXISTS ${PROJECT_SOURCE_DIR}/include/${PROJECT_NAME}) 94 | install(DIRECTORY include/${PROJECT_NAME}/ DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} 95 | FILES_MATCHING PATTERN "*.h" PATTERN ".svn" EXCLUDE) 96 | endif() 97 | -------------------------------------------------------------------------------- /autoware_tracker/README.md: -------------------------------------------------------------------------------- 1 | # autoware_tracker 2 | 3 | This package is forked from [https://github.com/TixiaoShan/autoware_tracker](https://github.com/TixiaoShan/autoware_tracker), and the original Readme file is below the dividing line. 4 | 5 | [2020-10-xx]: Added "automatic annotation" for point clouds, please install the dependencies first: `$ sudo apt install ros-melodic-vision-msgs`. 6 | 7 | [2020-09-18]: Added "intensity" to the points, which is essential for our online learning system, as the intensity can help us distinguish objects. 8 | 9 | --- 10 | 11 | # Readme 12 | 13 | Barebone package for point cloud object tracking used in Autoware. The package is only tested in Ubuntu 16.04 and ROS Kinetic. No deep learning is used.
14 | 15 | # Install JSK 16 | ``` 17 | sudo apt-get install ros-kinetic-jsk-recognition-msgs 18 | sudo apt-get install ros-kinetic-jsk-rviz-plugins 19 | ``` 20 | 21 | # Compile 22 | ``` 23 | cd ~/catkin_ws/src 24 | git clone https://github.com/TixiaoShan/autoware_tracker.git 25 | cd ~/catkin_ws 26 | catkin_make -j1 27 | ``` 28 | ```-j1``` is only needed for message generation in the first install. 29 | 30 | # Sample data 31 | 32 | In case you don't have some bag files handy, you can download a sample bag using: 33 | ``` 34 | wget https://autoware-ai.s3.us-east-2.amazonaws.com/sample_moriyama_150324.tar.gz 35 | ``` 36 | 37 | # Demo 38 | 39 | Run the autoware tracker: 40 | ``` 41 | roslaunch autoware_tracker run.launch 42 | ``` 43 | 44 | Play the sample ros bag: 45 | ``` 46 | rosbag play sample_moriyama_150324.bag 47 | ``` 48 | -------------------------------------------------------------------------------- /autoware_tracker/config/params.yaml: -------------------------------------------------------------------------------- 1 | autoware_tracker: 2 | 3 | cluster: 4 | label_source: "/image_detections" 5 | extrinsic_calibration: "/home/epan/Rui/datasets/2011_09_26_drive_0005_sync/calib_2011_09_26.txt" 6 | iou_threshold: 0.5 7 | 8 | points_node: "/points_raw" 9 | output_frame: "velodyne" 10 | 11 | remove_ground: true 12 | 13 | downsample_cloud: true 14 | leaf_size: 0.1 15 | 16 | use_multiple_thres: true 17 | 18 | cluster_size_min: 20 19 | cluster_size_max: 100000 20 | clip_min_height: -2.0 21 | clip_max_height: 0.0 22 | cluster_merge_threshold: 1.5 23 | clustering_distance: 0.75 24 | remove_points_min: 2.0 25 | remove_points_max: 100.0 26 | 27 | keep_lanes: false 28 | keep_lane_left_distance: 5.0 29 | keep_lane_right_distance: 5.0 30 | 31 | use_diffnormals: false 32 | pose_estimation: true 33 | 34 | tracker: 35 | gating_thres: 9.22 36 | gate_probability: 0.99 37 | detection_probability: 0.9 38 | life_time_thres: 8 39 | static_velocity_thres: 0.5 40 | 
static_num_history_thres: 3 41 | prevent_explosion_thres: 1000 42 | merge_distance_threshold: 0.5 43 | use_sukf: false 44 | 45 | tracking_frame: "/world" 46 | # namespace: /detection/object_tracker/ 47 | track_probability: 0.7 48 | -------------------------------------------------------------------------------- /autoware_tracker/launch/run.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /autoware_tracker/launch/rviz.rviz: -------------------------------------------------------------------------------- 1 | Panels: 2 | - Class: rviz/Displays 3 | Help Height: 0 4 | Name: Displays 5 | Property Tree Widget: 6 | Expanded: 7 | - /Global Options1 8 | - /MarkerArray1 9 | - /MarkerArray1/Namespaces1 10 | Splitter Ratio: 0.505813956 11 | Tree Height: 1124 12 | - Class: rviz/Selection 13 | Name: Selection 14 | - Class: rviz/Tool Properties 15 | Expanded: 16 | - /2D Pose Estimate1 17 | - /2D Nav Goal1 18 | - /Publish Point1 19 | Name: Tool Properties 20 | Splitter Ratio: 0.588679016 21 | - Class: rviz/Views 22 | Expanded: 23 | - /Current View1 24 | Name: Views 25 | Splitter Ratio: 0.5 26 | - Class: rviz/Time 27 | Experimental: false 28 | Name: Time 29 | SyncMode: 0 30 | SyncSource: Point cluster 31 | Toolbars: 32 | toolButtonStyle: 2 33 | Visualization Manager: 34 | Class: "" 35 | Displays: 36 | - Class: rviz/Axes 37 | Enabled: true 38 | Length: 1 39 | Name: Axes 40 | Radius: 0.300000012 41 | Reference Frame: 42 | Value: true 43 | - Class: rviz/Group 44 | Displays: 45 | - Alpha: 1 46 | Autocompute Intensity Bounds: true 47 | Autocompute Value Bounds: 48 | Max Value: 10 49 | Min Value: -10 50 | Value: true 51 | Axis: Z 52 | Channel Name: intensity 53 | Class: rviz/PointCloud2 54 | Color: 255; 255; 255 55 | Color Transformer: Intensity 56 | Decay Time: 0 57 | Enabled: true 58 | Invert Rainbow: false 
59 | Max Color: 255; 255; 255 60 | Max Intensity: 142 61 | Min Color: 0; 0; 0 62 | Min Intensity: 0 63 | Name: Point cloud 64 | Position Transformer: XYZ 65 | Queue Size: 10 66 | Selectable: true 67 | Size (Pixels): 1 68 | Size (m): 0.00999999978 69 | Style: Points 70 | Topic: /points_raw 71 | Unreliable: false 72 | Use Fixed Frame: true 73 | Use rainbow: true 74 | Value: true 75 | - Alpha: 1 76 | Autocompute Intensity Bounds: true 77 | Autocompute Value Bounds: 78 | Max Value: 10 79 | Min Value: -10 80 | Value: true 81 | Axis: Z 82 | Channel Name: intensity 83 | Class: rviz/PointCloud2 84 | Color: 255; 255; 255 85 | Color Transformer: RGB8 86 | Decay Time: 0 87 | Enabled: true 88 | Invert Rainbow: false 89 | Max Color: 255; 255; 255 90 | Max Intensity: 4096 91 | Min Color: 0; 0; 0 92 | Min Intensity: 0 93 | Name: Point cluster 94 | Position Transformer: XYZ 95 | Queue Size: 10 96 | Selectable: true 97 | Size (Pixels): 5 98 | Size (m): 0.00999999978 99 | Style: Points 100 | Topic: /autoware_tracker/cluster/points_cluster 101 | Unreliable: false 102 | Use Fixed Frame: true 103 | Use rainbow: true 104 | Value: true 105 | Enabled: true 106 | Name: Lidar cluster 107 | - Class: rviz/MarkerArray 108 | Enabled: true 109 | Marker Topic: /autoware_tracker/visualizer/objects 110 | Name: MarkerArray 111 | Namespaces: 112 | arrow_markers: true 113 | box_markers: false 114 | centroid_markers: true 115 | hull_markers: true 116 | label_markers: true 117 | Queue Size: 100 118 | Value: true 119 | Enabled: true 120 | Global Options: 121 | Background Color: 48; 48; 48 122 | Default Light: true 123 | Fixed Frame: velodyne 124 | Frame Rate: 30 125 | Name: root 126 | Tools: 127 | - Class: rviz/Interact 128 | Hide Inactive Objects: true 129 | - Class: rviz/MoveCamera 130 | - Class: rviz/Select 131 | - Class: rviz/FocusCamera 132 | - Class: rviz/Measure 133 | - Class: rviz/SetInitialPose 134 | Topic: /initialpose 135 | - Class: rviz/SetGoal 136 | Topic: /move_base_simple/goal 137 | - 
Class: rviz/PublishPoint 138 | Single click: true 139 | Topic: /clicked_point 140 | Value: true 141 | Views: 142 | Current: 143 | Class: rviz/Orbit 144 | Distance: 48.6941605 145 | Enable Stereo Rendering: 146 | Stereo Eye Separation: 0.0599999987 147 | Stereo Focal Distance: 1 148 | Swap Stereo Eyes: false 149 | Value: false 150 | Focal Point: 151 | X: 0 152 | Y: 0 153 | Z: 0 154 | Focal Shape Fixed Size: true 155 | Focal Shape Size: 0.0500000007 156 | Invert Z Axis: false 157 | Name: Current View 158 | Near Clip Distance: 0.00999999978 159 | Pitch: 0.869796932 160 | Target Frame: base_link 161 | Value: Orbit (rviz) 162 | Yaw: 2.58810186 163 | Saved: ~ 164 | Window Geometry: 165 | Displays: 166 | collapsed: false 167 | Height: 1274 168 | Hide Left Dock: false 169 | Hide Right Dock: true 170 | QMainWindow State: 000000ff00000000fd000000040000000000000206000004abfc020000000efb0000001200530065006c0065006300740069006f006e00000001e10000009b0000006300fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afc0000002e000004ab000000dd00fffffffa000000020100000003fb0000000a0049006d0061006700650000000000ffffffff0000000000000000fb0000000c00430061006d0065007200610000000000ffffffff0000000000000000fb000000100044006900730070006c0061007900730100000000000001360000017900fffffffb0000000a0049006d006100670065010000028e000000d20000000000000000fb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb000000120049006d006100670065005f0072006100770000000000ffffffff0000000000000000fb0000000c00430061006d006500720061000000024e000001710000000000000000fb000000
120049006d00610067006500200052006100770100000421000000160000000000000000fb0000000a0049006d00610067006501000002f4000000cb0000000000000000fb0000000a0049006d006100670065010000056c0000026c0000000000000000000000010000016300000313fc0200000003fb0000000a00560069006500770073000000002e00000313000000b700fffffffb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000001200530065006c0065006300740069006f006e010000025a000000b20000000000000000000000020000073f000000a8fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000006400000005cfc0100000002fb0000000800540069006d00650000000000000006400000038300fffffffb0000000800540069006d00650100000000000004500000000000000000000006ae000004ab00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 171 | Selection: 172 | collapsed: false 173 | Time: 174 | collapsed: false 175 | Tool Properties: 176 | collapsed: false 177 | Views: 178 | collapsed: true 179 | Width: 2235 180 | X: 692 181 | Y: 316 182 | -------------------------------------------------------------------------------- /autoware_tracker/msg/Centroids.msg: -------------------------------------------------------------------------------- 1 | std_msgs/Header header 2 | geometry_msgs/Point[] points 3 | -------------------------------------------------------------------------------- /autoware_tracker/msg/CloudCluster.msg: -------------------------------------------------------------------------------- 1 | std_msgs/Header header 2 | 3 | uint32 id 4 | string label 5 | float64 score 6 | 7 | sensor_msgs/PointCloud2 cloud 8 | 9 | geometry_msgs/PointStamped min_point 10 | geometry_msgs/PointStamped max_point 11 | geometry_msgs/PointStamped avg_point 12 | geometry_msgs/PointStamped centroid_point 13 | 14 | float64 estimated_angle 15 | 16 | geometry_msgs/Vector3 dimensions 17 | geometry_msgs/Vector3 eigen_values 18 | geometry_msgs/Vector3[] 
eigen_vectors 19 | 20 | #Array of 33 floats containing the FPFH descriptor 21 | std_msgs/Float32MultiArray fpfh_descriptor 22 | 23 | jsk_recognition_msgs/BoundingBox bounding_box 24 | geometry_msgs/PolygonStamped convex_hull 25 | 26 | # Indicator information 27 | # INDICATOR_LEFT 0 28 | # INDICATOR_RIGHT 1 29 | # INDICATOR_BOTH 2 30 | # INDICATOR_NONE 3 31 | uint32 indicator_state -------------------------------------------------------------------------------- /autoware_tracker/msg/CloudClusterArray.msg: -------------------------------------------------------------------------------- 1 | std_msgs/Header header 2 | CloudCluster[] clusters -------------------------------------------------------------------------------- /autoware_tracker/msg/DetectedObject.msg: -------------------------------------------------------------------------------- 1 | std_msgs/Header header 2 | 3 | uint32 id 4 | string label 5 | float32 score #Score as defined by the detection, Optional 6 | std_msgs/ColorRGBA color # Define this object specific color 7 | bool valid # Defines if this object is valid, or invalid as defined by the filtering 8 | 9 | ################ 3D BB 10 | string space_frame #3D Space coordinate frame of the object, required if pose and dimensions are defines 11 | geometry_msgs/Pose pose 12 | geometry_msgs/Vector3 dimensions 13 | geometry_msgs/Vector3 variance 14 | geometry_msgs/Twist velocity 15 | geometry_msgs/Twist acceleration 16 | 17 | sensor_msgs/PointCloud2 pointcloud 18 | 19 | geometry_msgs/PolygonStamped convex_hull 20 | # autoware_msgs/LaneArray candidate_trajectories 21 | 22 | bool pose_reliable 23 | bool velocity_reliable 24 | bool acceleration_reliable 25 | 26 | ############### 2D Rect 27 | string image_frame # Image coordinate Frame, Required if x,y,w,h defined 28 | int32 x # X coord in image space(pixel) of the initial point of the Rect 29 | int32 y # Y coord in image space(pixel) of the initial point of the Rect 30 | int32 width # Width of the Rect in pixels 
31 | int32 height # Height of the Rect in pixels 32 | float32 angle # Angle [0 to 2*PI), allow rotated rects 33 | 34 | sensor_msgs/Image roi_image 35 | 36 | ############### Indicator information 37 | uint8 indicator_state # INDICATOR_LEFT = 0, INDICATOR_RIGHT = 1, INDICATOR_BOTH = 2, INDICATOR_NONE = 3 38 | 39 | ############### Behavior State of the Detected Object 40 | uint8 behavior_state # FORWARD_STATE = 0, STOPPING_STATE = 1, BRANCH_LEFT_STATE = 2, BRANCH_RIGHT_STATE = 3, YIELDING_STATE = 4, ACCELERATING_STATE = 5, SLOWDOWN_STATE = 6 41 | 42 | # 43 | string[] user_defined_info 44 | 45 | # yang21itsc 46 | bool last_sample -------------------------------------------------------------------------------- /autoware_tracker/msg/DetectedObjectArray.msg: -------------------------------------------------------------------------------- 1 | std_msgs/Header header 2 | DetectedObject[] objects -------------------------------------------------------------------------------- /autoware_tracker/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | autoware_tracker 4 | 1.0.0 5 | The autoware tracker package 6 | Tixiao Shan 7 | Apache 2 8 | 9 | catkin 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | jsk_recognition_msgs 20 | jsk_recognition_msgs 21 | 22 | 23 | pcl_ros 24 | roscpp 25 | geometry_msgs 26 | std_msgs 27 | sensor_msgs 28 | tf 29 | visualization_msgs 30 | cv_bridge 31 | image_transport 32 | 33 | 34 | pcl_ros 35 | roscpp 36 | geometry_msgs 37 | std_msgs 38 | sensor_msgs 39 | tf 40 | visualization_msgs 41 | cv_bridge 42 | image_transport 43 | 44 | 45 | 46 | 47 | jsk_rviz_plugins 48 | 49 | 50 | 51 | jsk_rviz_plugins 52 | 53 | 54 | message_generation 55 | message_generation 56 | message_runtime 57 | message_runtime 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- 
/autoware_tracker/src/detected_objects_visualizer/visualize_detected_objects.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | ******************** 17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com) 18 | */ 19 | 20 | #ifndef AUTOWARE_TRACKER_VISUALIZE_DETECTED_OBJECTS_H 21 | #define AUTOWARE_TRACKER_VISUALIZE_DETECTED_OBJECTS_H 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | #include 30 | #include 31 | 32 | #include 33 | 34 | #include 35 | #include 36 | 37 | #include "autoware_tracker/DetectedObject.h" 38 | #include "autoware_tracker/DetectedObjectArray.h" 39 | // NOTE(review): identifiers containing a double underscore are reserved in C++; __APP_NAME__ is kept unchanged because the project already uses this macro name (e.g. visualize_rects.cpp). 40 | #define __APP_NAME__ "visualize_detected_objects" 41 | /// Turns autoware_tracker::DetectedObjectArray messages into visualization_msgs::MarkerArray markers (labels, arrows, boxes, models, hulls, centroids); member definitions are compiled from visualize_detected_objects.cpp. 42 | class VisualizeDetectedObjects 43 | { 44 | private: 45 | const double arrow_height_; 46 | const double label_height_; 47 | const double object_max_linear_size_ = 50.; // NOTE(review): unit not visible here (presumably metres) - confirm where it is used. 48 | double object_speed_threshold_; 49 | double arrow_speed_threshold_; 50 | double marker_display_duration_; 51 | 52 | int marker_id_; 53 | // Per-family marker colors; presumably filled via ParseColor()/CheckColor()/CheckAlpha() - confirm in the .cpp. 54 | std_msgs::ColorRGBA label_color_, box_color_, hull_color_, arrow_color_, centroid_color_, model_color_; 55 | 56 | std::string input_topic_, ros_namespace_; 57 | 58 | ros::NodeHandle node_handle_; 59 | ros::Subscriber subscriber_detected_objects_; 60 | 61 | ros::Publisher publisher_markers_; 62 | // Each ObjectsTo*() builder converts the same input array into one family of markers. 63 | visualization_msgs::MarkerArray
ObjectsToLabels(const autoware_tracker::DetectedObjectArray &in_objects); 64 | 65 | visualization_msgs::MarkerArray ObjectsToArrows(const autoware_tracker::DetectedObjectArray &in_objects); 66 | 67 | visualization_msgs::MarkerArray ObjectsToBoxes(const autoware_tracker::DetectedObjectArray &in_objects); 68 | 69 | visualization_msgs::MarkerArray ObjectsToModels(const autoware_tracker::DetectedObjectArray &in_objects); 70 | 71 | visualization_msgs::MarkerArray ObjectsToHulls(const autoware_tracker::DetectedObjectArray &in_objects); 72 | 73 | visualization_msgs::MarkerArray ObjectsToCentroids(const autoware_tracker::DetectedObjectArray &in_objects); 74 | 75 | std::string ColorToString(const std_msgs::ColorRGBA &in_color); 76 | // Callback for incoming object arrays (presumably registered on subscriber_detected_objects_ - confirm in the .cpp). 77 | void DetectedObjectsCallback(const autoware_tracker::DetectedObjectArray &in_objects); 78 | // NOTE(review): presumably decides whether in_object is visualized - criteria live in the .cpp. 79 | bool IsObjectValid(const autoware_tracker::DetectedObject &in_object); 80 | // Color helpers - NOTE(review): CheckColor/CheckAlpha presumably clamp a component, ParseColor builds a ColorRGBA from a value list; verify in the .cpp. 81 | float CheckColor(double value); 82 | 83 | float CheckAlpha(double value); 84 | 85 | std_msgs::ColorRGBA ParseColor(const std::vector &in_color); 86 | 87 | public: 88 | VisualizeDetectedObjects(); 89 | }; 90 | 91 | #endif // AUTOWARE_TRACKER_VISUALIZE_DETECTED_OBJECTS_H 92 | -------------------------------------------------------------------------------- /autoware_tracker/src/detected_objects_visualizer/visualize_detected_objects_main.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | ******************** 17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com) 18 | */ 19 | 20 | #include "visualize_detected_objects.h" 21 | // Node entry point: initialise ROS, restrict the default console logger to WARN and above, then hand control to the visualizer. 22 | int main(int argc, char* argv[]) 23 | { 24 | ros::init(argc, argv, "visualize_detected_objects"); 25 | ros::console::set_logger_level(ROSCONSOLE_DEFAULT_NAME, ros::console::levels::Warn); 26 | // The visualizer owns the ROS subscriber/publisher, so it must outlive the spin loop below. 27 | VisualizeDetectedObjects visualizer; 28 | ros::spin(); 29 | // ros::spin() only returns once the node is shutting down. 30 | return 0; 31 | } 32 | -------------------------------------------------------------------------------- /autoware_tracker/src/detected_objects_visualizer/visualize_rects.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * 16 | ******************** 17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com) 18 | */ 19 | 20 | #include "visualize_rects.h" 21 | 22 | VisualizeRects::VisualizeRects() 23 | { 24 | ros::NodeHandle private_nh_("~"); 25 | 26 | ros::NodeHandle nh; 27 | 28 | std::string image_src_topic; 29 | std::string object_src_topic; 30 | std::string image_out_topic; 31 | 32 | private_nh_.param("image_src", image_src_topic, "/image_raw"); 33 | private_nh_.param("object_src", object_src_topic, "/detection/image_detector/objects"); 34 | private_nh_.param("image_out", image_out_topic, "/image_rects"); 35 | 36 | //get namespace from topic 37 | std::string ros_namespace = image_src_topic; 38 | std::size_t found_pos = ros_namespace.rfind("/");//find last / from topic name to extract namespace 39 | std::cout << ros_namespace << std::endl; 40 | if (found_pos!=std::string::npos) 41 | ros_namespace.erase(found_pos, ros_namespace.length()-found_pos); 42 | std::cout << ros_namespace << std::endl; 43 | image_out_topic = ros_namespace + image_out_topic; 44 | 45 | image_filter_subscriber_ = new message_filters::Subscriber(private_nh_, 46 | image_src_topic, 47 | 1); 48 | ROS_INFO("[%s] image_src: %s", __APP_NAME__, image_src_topic.c_str()); 49 | detection_filter_subscriber_ = new message_filters::Subscriber(private_nh_, 50 | object_src_topic, 51 | 1); 52 | ROS_INFO("[%s] object_src: %s", __APP_NAME__, object_src_topic.c_str()); 53 | 54 | detections_synchronizer_ = 55 | new message_filters::Synchronizer(SyncPolicyT(10), 56 | *image_filter_subscriber_, 57 | *detection_filter_subscriber_); 58 | detections_synchronizer_->registerCallback( 59 | boost::bind(&VisualizeRects::SyncedDetectionsCallback, this, _1, _2)); 60 | 61 | 62 | publisher_image_ = node_handle_.advertise( 63 | image_out_topic, 1); 64 | ROS_INFO("[%s] image_out: %s", __APP_NAME__, image_out_topic.c_str()); 65 | 66 | } 67 | 68 | void 69 | VisualizeRects::SyncedDetectionsCallback( 70 | const sensor_msgs::Image::ConstPtr &in_image_msg, 71 
| const autoware_tracker::DetectedObjectArray::ConstPtr &in_objects) 72 | { 73 | try 74 | { 75 | image_ = cv_bridge::toCvShare(in_image_msg, "bgr8")->image; 76 | cv::Mat drawn_image; 77 | drawn_image = ObjectsToRects(image_, in_objects); 78 | sensor_msgs::ImagePtr drawn_msg = cv_bridge::CvImage(in_image_msg->header, "bgr8", drawn_image).toImageMsg(); 79 | publisher_image_.publish(drawn_msg); 80 | } 81 | catch (cv_bridge::Exception& e) 82 | { 83 | ROS_ERROR("[%s] Could not convert from '%s' to 'bgr8'.", __APP_NAME__, in_image_msg->encoding.c_str()); 84 | } 85 | } 86 | 87 | cv::Mat 88 | VisualizeRects::ObjectsToRects(cv::Mat in_image, const autoware_tracker::DetectedObjectArray::ConstPtr &in_objects) 89 | { 90 | cv::Mat final_image = in_image.clone(); 91 | for (auto const &object: in_objects->objects) 92 | { 93 | if (IsObjectValid(object)) 94 | { 95 | cv::Rect rect; 96 | rect.x = object.x; 97 | rect.y = object.y; 98 | rect.width = object.width; 99 | rect.height = object.height; 100 | 101 | if (rect.x+rect.width >= in_image.cols) 102 | rect.width = in_image.cols -rect.x - 1; 103 | 104 | if (rect.y+rect.height >= in_image.rows) 105 | rect.height = in_image.rows -rect.y - 1; 106 | 107 | //draw rectangle 108 | cv::rectangle(final_image, 109 | rect, 110 | cv::Scalar(244,134,66), 111 | 4, 112 | CV_AA); 113 | 114 | //draw label 115 | std::string label = ""; 116 | if (!object.label.empty() && object.label != "unknown") 117 | { 118 | label = object.label; 119 | } 120 | int font_face = cv::FONT_HERSHEY_DUPLEX; 121 | double font_scale = 1.5; 122 | int thickness = 1; 123 | 124 | int baseline=0; 125 | cv::Size text_size = cv::getTextSize(label, 126 | font_face, 127 | font_scale, 128 | thickness, 129 | &baseline); 130 | baseline += thickness; 131 | 132 | cv::Point text_origin(object.x - text_size.height,object.y); 133 | 134 | cv::rectangle(final_image, 135 | text_origin + cv::Point(0, baseline), 136 | text_origin + cv::Point(text_size.width, -text_size.height), 137 | 
cv::Scalar(0,0,0), 138 | CV_FILLED, 139 | CV_AA, 140 | 0); 141 | 142 | cv::putText(final_image, 143 | label, 144 | text_origin, 145 | font_face, 146 | font_scale, 147 | cv::Scalar::all(255), 148 | thickness, 149 | CV_AA, 150 | false); 151 | 152 | } 153 | } 154 | return final_image; 155 | }//ObjectsToBoxes 156 | 157 | bool VisualizeRects::IsObjectValid(const autoware_tracker::DetectedObject &in_object) 158 | { 159 | if (!in_object.valid || 160 | in_object.width < 0 || 161 | in_object.height < 0 || 162 | in_object.x < 0 || 163 | in_object.y < 0 164 | ) 165 | { 166 | return false; 167 | } 168 | return true; 169 | }//end IsObjectValid -------------------------------------------------------------------------------- /autoware_tracker/src/detected_objects_visualizer/visualize_rects.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | ******************** 17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com) 18 | */ 19 | 20 | #ifndef _VISUALIZERECTS_H 21 | #define _VISUALIZERECTS_H 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | #include 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | 37 | #include 38 | #include 39 | #include 40 | 41 | #include "autoware_tracker/DetectedObject.h" 42 | #include "autoware_tracker/DetectedObjectArray.h" 43 | 44 | #define __APP_NAME__ "visualize_rects" 45 | 46 | class VisualizeRects 47 | { 48 | private: 49 | std::string input_topic_; 50 | 51 | ros::NodeHandle node_handle_; 52 | ros::Subscriber subscriber_detected_objects_; 53 | image_transport::Subscriber subscriber_image_; 54 | 55 | message_filters::Subscriber *detection_filter_subscriber_; 56 | message_filters::Subscriber *image_filter_subscriber_; 57 | 58 | ros::Publisher publisher_image_; 59 | 60 | cv::Mat image_; 61 | std_msgs::Header image_header_; 62 | 63 | typedef 64 | message_filters::sync_policies::ApproximateTime SyncPolicyT; 66 | 67 | message_filters::Synchronizer 68 | *detections_synchronizer_; 69 | 70 | void 71 | SyncedDetectionsCallback( 72 | const sensor_msgs::Image::ConstPtr &in_image_msg, 73 | const autoware_tracker::DetectedObjectArray::ConstPtr &in_range_detections); 74 | 75 | bool IsObjectValid(const autoware_tracker::DetectedObject &in_object); 76 | 77 | cv::Mat ObjectsToRects(cv::Mat in_image, const autoware_tracker::DetectedObjectArray::ConstPtr &in_objects); 78 | 79 | public: 80 | VisualizeRects(); 81 | }; 82 | 83 | #endif // _VISUALIZERECTS_H 84 | -------------------------------------------------------------------------------- /autoware_tracker/src/detected_objects_visualizer/visualize_rects_main.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved. 
*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 ********************
 * v1.0: amc-nu (abrahammonrroy@yahoo.com)
 */

#include "visualize_rects.h"

// Entry point of the visualize_rects node: the VisualizeRects constructor sets up the
// synchronized image/detection subscriptions and the output publisher, and
// ros::spin() then services their callbacks until shutdown.
int main(int argc, char** argv)
{
  ros::init(argc, argv, "visualize_rects");
  VisualizeRects app;
  ros::spin();

  return 0;
}

--------------------------------------------------------------------------------
/autoware_tracker/src/lidar_euclidean_cluster_detect/cluster.h:
--------------------------------------------------------------------------------
/*
 * Cluster.h
 *
 *  Created on: Oct 19, 2016
 *      Author: Ne0
 */
#ifndef CLUSTER_H_
#define CLUSTER_H_

// NOTE(review): in this copy of the file the targets of the bare #include lines below,
// and several template arguments (pcl::PointCloud<...>, std::vector<...>,
// boost::shared_ptr<...>), are missing; restore them from upstream before building.
#include
#include

#include
#include
#include
#include
#include

#include
#include

#include
#include
#include

#include
#include
#include

#include

#include
#include

#include
#include
#include

#include
#include

#include
#include

#include

#include
#include
#include

#include "autoware_tracker/CloudCluster.h"

#include "opencv2/core/core.hpp"
#include "opencv2/imgproc/imgproc.hpp"

#include
#include
#include

/* \brief A group of points taken from a larger PointCloud, together with derived
 * geometry (min/max/average/centroid points, dimensions, orientation, bounding box,
 * polygon, eigen decomposition) and display metadata (id, label, RGB colour, validity).
 */
class Cluster
{
  pcl::PointCloud::Ptr pointcloud_;
  pcl::PointXYZI min_point_;
  pcl::PointXYZI max_point_;
  pcl::PointXYZI average_point_;
  pcl::PointXYZI centroid_;
  double orientation_angle_;
  float length_, width_, height_;

  jsk_recognition_msgs::BoundingBox bounding_box_;
  geometry_msgs::PolygonStamped polygon_;

  std::string label_;
  int id_;
  int r_, g_, b_;

  Eigen::Matrix3f eigen_vectors_;
  Eigen::Vector3f eigen_values_;

  bool valid_cluster_;

public:
  /* \brief Initialises this Cluster from the points of a PointCloud selected by a list of indices
   * (this is a setter, not a constructor)
   * \param[in] in_origin_cloud_ptr Origin PointCloud
   * \param[in] in_cluster_indices Indices of the Origin Pointcloud to create the Cluster
   * \param[in] in_ros_header ROS header of the origin cloud (presumably used for the frame/stamp of the
   *            generated bounding box and polygon — confirm in cluster.cpp)
   * \param[in] in_id ID of the cluster
   * \param[in] in_r Amount of Red [0-255]
   * \param[in] in_g Amount of Green [0-255]
   * \param[in] in_b Amount of Blue [0-255]
   * \param[in] in_label Label to identify this cluster (optional)
   * \param[in] in_estimate_pose Flag to enable Pose Estimation of the Bounding Box
   * */
  void SetCloud(const pcl::PointCloud::Ptr in_origin_cloud_ptr,
                const std::vector& in_cluster_indices, std_msgs::Header in_ros_header, int in_id, int in_r,
                int in_g, int in_b, std::string in_label, bool in_estimate_pose);

  /* \brief Fills out_cluster_message with the autoware_tracker::CloudCluster message describing this Cluster
   * (returned through the output parameter, not by value)
   * \param[in]  in_ros_header Header to use for the message (presumably — confirm in cluster.cpp)
   * \param[out] out_cluster_message Message to fill
   * */
  void ToROSMessage(std_msgs::Header in_ros_header, autoware_tracker::CloudCluster& out_cluster_message);

  Cluster();
  virtual ~Cluster();

  /* \brief Returns the pointer to the PointCloud containing the points in this Cluster */
  pcl::PointCloud::Ptr GetCloud();
  /* \brief Returns the minimum point in the cluster */
  pcl::PointXYZI GetMinPoint();
  /* \brief Returns the maximum point in the cluster*/
  pcl::PointXYZI GetMaxPoint();
  /* \brief Returns the average point in the cluster*/
  pcl::PointXYZI GetAveragePoint();
  /* \brief Returns the centroid point in the cluster */
  pcl::PointXYZI GetCentroid();
  /* \brief Returns the calculated BoundingBox of the object */
  jsk_recognition_msgs::BoundingBox GetBoundingBox();
  /* \brief Returns the calculated polygon of the object (a single geometry_msgs::PolygonStamped) */
  geometry_msgs::PolygonStamped GetPolygon();
  /* \brief Returns the angle in radians of the BoundingBox. 0 if pose estimation was not enabled. */
  double GetOrientationAngle();
  /* \brief Returns the Length of the Cluster (the misspelled name is part of the existing API) */
  float GetLenght();
  /* \brief Returns the Width of the Cluster */
  float GetWidth();
  /* \brief Returns the Height of the Cluster */
  float GetHeight();
  /* \brief Returns the Id of the Cluster */
  int GetId();
  /* \brief Returns the Label of the Cluster */
  std::string GetLabel();
  /* \brief Returns the Eigen Vectors of the cluster */
  Eigen::Matrix3f GetEigenVectors();
  /* \brief Returns the Eigen Values of the Cluster */
  Eigen::Vector3f GetEigenValues();

  /* \brief Returns if the Cluster is marked as valid or not*/
  bool IsValid();
  /* \brief Sets whether the Cluster is valid or not*/
  void SetValidity(bool in_valid);

  /* \brief Returns a pointer to a PointCloud object containing the merged points between current Cluster and the
   * specified PointCloud
   * \param[in] in_cloud_ptr Origin PointCloud
   * */
  pcl::PointCloud::Ptr JoinCloud(const pcl::PointCloud::Ptr in_cloud_ptr);

  /* \brief Calculates and returns (by value, as a std::vector) the FPFH Descriptor of this cluster
   * \param[in] in_ompnum_threads Presumably the number of OpenMP threads for the computation — confirm
   * \param[in] in_normal_search_radius Search radius used for normal estimation
   * \param[in] in_fpfh_search_radius Search radius used for the FPFH computation
   */
  std::vector GetFpfhDescriptor(const unsigned int& in_ompnum_threads, const double& in_normal_search_radius,
                                const double& in_fpfh_search_radius);
};

// Shared-ownership handle to a Cluster (template argument missing in this copy; presumably Cluster).
typedef boost::shared_ptr ClusterPtr;

#endif /* CLUSTER_H_ */

--------------------------------------------------------------------------------
/autoware_tracker/src/lidar_euclidean_cluster_detect/gencolors.cpp: -------------------------------------------------------------------------------- 1 | #ifndef GENCOLORS_CPP_ 2 | #define GENCOLORS_CPP_ 3 | 4 | /*M/////////////////////////////////////////////////////////////////////////////////////// 5 | // 6 | // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 7 | // 8 | // By downloading, copying, installing or using the software you agree to this license. 9 | // If you do not agree to this license, do not download, install, 10 | // copy or use the software. 11 | // 12 | // 13 | // License Agreement 14 | // For Open Source Computer Vision Library 15 | // 16 | // Copyright (C) 2000-2008, Intel Corporation, all rights reserved. 17 | // Copyright (C) 2009, Willow Garage Inc., all rights reserved. 18 | // Third party copyrights are property of their respective owners. 19 | // 20 | // Redistribution and use in source and binary forms, with or without modification, 21 | // are permitted provided that the following conditions are met: 22 | // 23 | // * Redistribution's of source code must retain the above copyright notice, 24 | // this list of conditions and the following disclaimer. 25 | // 26 | // * Redistribution's in binary form must reproduce the above copyright notice, 27 | // this list of conditions and the following disclaimer in the documentation 28 | // and/or other materials provided with the distribution. 29 | // 30 | // * The name of the copyright holders may not be used to endorse or promote products 31 | // derived from this software without specific prior written permission. 32 | // 33 | // This software is provided by the copyright holders and contributors "as is" and 34 | // any express or implied warranties, including, but not limited to, the implied 35 | // warranties of merchantability and fitness for a particular purpose are disclaimed. 
36 | // In no event shall the Intel Corporation or contributors be liable for any direct, 37 | // indirect, incidental, special, exemplary, or consequential damages 38 | // (including, but not limited to, procurement of substitute goods or services; 39 | // loss of use, data, or profits; or business interruption) however caused 40 | // and on any theory of liability, whether in contract, strict liability, 41 | // or tort (including negligence or otherwise) arising in any way out of 42 | // the use of this software, even if advised of the possibility of such damage. 43 | // 44 | //M*/ 45 | #include "opencv2/core/core.hpp" 46 | //#include "precomp.hpp" 47 | #include 48 | 49 | #include 50 | 51 | using namespace cv; 52 | 53 | static void downsamplePoints(const Mat& src, Mat& dst, size_t count) 54 | { 55 | CV_Assert(count >= 2); 56 | CV_Assert(src.cols == 1 || src.rows == 1); 57 | CV_Assert(src.total() >= count); 58 | CV_Assert(src.type() == CV_8UC3); 59 | 60 | dst.create(1, (int)count, CV_8UC3); 61 | // TODO: optimize by exploiting symmetry in the distance matrix 62 | Mat dists((int)src.total(), (int)src.total(), CV_32FC1, Scalar(0)); 63 | if (dists.empty()) 64 | std::cerr << "Such big matrix cann't be created." 
<< std::endl; 65 | 66 | for (int i = 0; i < dists.rows; i++) 67 | { 68 | for (int j = i; j < dists.cols; j++) 69 | { 70 | float dist = (float)norm(src.at >(i) - src.at >(j)); 71 | dists.at(j, i) = dists.at(i, j) = dist; 72 | } 73 | } 74 | 75 | double maxVal; 76 | Point maxLoc; 77 | minMaxLoc(dists, 0, &maxVal, 0, &maxLoc); 78 | 79 | dst.at >(0) = src.at >(maxLoc.x); 80 | dst.at >(1) = src.at >(maxLoc.y); 81 | 82 | Mat activedDists(0, dists.cols, dists.type()); 83 | Mat candidatePointsMask(1, dists.cols, CV_8UC1, Scalar(255)); 84 | activedDists.push_back(dists.row(maxLoc.y)); 85 | candidatePointsMask.at(0, maxLoc.y) = 0; 86 | 87 | for (size_t i = 2; i < count; i++) 88 | { 89 | activedDists.push_back(dists.row(maxLoc.x)); 90 | candidatePointsMask.at(0, maxLoc.x) = 0; 91 | 92 | Mat minDists; 93 | reduce(activedDists, minDists, 0, CV_REDUCE_MIN); 94 | minMaxLoc(minDists, 0, &maxVal, 0, &maxLoc, candidatePointsMask); 95 | dst.at >((int)i) = src.at >(maxLoc.x); 96 | } 97 | } 98 | 99 | void generateColors(std::vector& colors, size_t count, size_t factor = 10) 100 | { 101 | if (count < 1) 102 | return; 103 | 104 | colors.resize(count); 105 | 106 | if (count == 1) 107 | { 108 | colors[0] = Scalar(0, 0, 255); // red 109 | return; 110 | } 111 | if (count == 2) 112 | { 113 | colors[0] = Scalar(0, 0, 255); // red 114 | colors[1] = Scalar(0, 255, 0); // green 115 | return; 116 | } 117 | 118 | // Generate a set of colors in RGB space. A size of the set is severel times (=factor) larger then 119 | // the needed count of colors. 120 | Mat bgr(1, (int)(count * factor), CV_8UC3); 121 | randu(bgr, 0, 256); 122 | 123 | // Convert the colors set to Lab space. 124 | // Distances between colors in this space correspond a human perception. 125 | Mat lab; 126 | cvtColor(bgr, lab, cv::COLOR_BGR2Lab); 127 | 128 | // Subsample colors from the generated set so that 129 | // to maximize the minimum distances between each other. 130 | // Douglas-Peucker algorithm is used for this. 
131 | Mat lab_subset; 132 | downsamplePoints(lab, lab_subset, count); 133 | 134 | // Convert subsampled colors back to RGB 135 | Mat bgr_subset; 136 | cvtColor(lab_subset, bgr_subset, cv::COLOR_BGR2Lab); 137 | 138 | CV_Assert(bgr_subset.total() == count); 139 | for (size_t i = 0; i < count; i++) 140 | { 141 | Point3_ c = bgr_subset.at >((int)i); 142 | colors[i] = Scalar(c.x, c.y, c.z); 143 | } 144 | } 145 | 146 | #endif // GENCOLORS_CPP 147 | -------------------------------------------------------------------------------- /autoware_tracker/src/lidar_imm_ukf_pda_track/imm_ukf_pda.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
*/

#ifndef OBJECT_TRACKING_IMM_UKF_JPDAF_H
#define OBJECT_TRACKING_IMM_UKF_JPDAF_H

// NOTE(review): in this copy of the file the targets of the bare #include lines below,
// and the template arguments of several std::vector declarations in the class, are
// missing; restore them from upstream before building.
#include
#include
#include


#include
#include

#include
#include
#include
#include

#include

// Disabled include; presumably the vector_map header that goes with the commented-out
// vmap_/lanes_ members below — confirm.
// #include

#include "autoware_tracker/DetectedObject.h"
#include "autoware_tracker/DetectedObjectArray.h"

#include "ukf.h"

/*
 * Object tracker driven by autoware_tracker::DetectedObjectArray messages (see
 * callback()). Each tracked object is a UKF instance kept in targets_; measurements
 * are associated to targets by probabilistic data association
 * (probabilisticDataAssociation()), and poses are moved between frames with tf
 * (transformPoseToGlobal()/transformPoseToLocal()).
 */
class ImmUkfPda
{
private:
  int target_id_;
  bool init_;
  double timestamp_;

  std::vector targets_;

  // probabilistic data association params
  double gating_thres_;
  double gate_probability_;
  double detection_probability_;

  // object association param
  int life_time_thres_;

  // static classification param
  double static_velocity_thres_;
  int static_num_history_thres_;

  // switch between sukf and ImmUkfPda
  bool use_sukf_;

  // whether the tracking result is being benchmarked
  bool is_benchmark_;
  int frame_count_;
  std::string kitti_data_dir_;

  // for benchmark
  std::string result_file_path_;

  // threshold used to keep the ukf from exploding (diverging)
  double prevent_explosion_thres_;

  // for vectormap assisted tracking
  bool use_vectormap_;
  bool has_subscribed_vectormap_;
  double lane_direction_chi_thres_;
  double nearest_lane_distance_thres_;
  std::string vectormap_frame_;
  // vector_map::VectorMap vmap_;
  // std::vector lanes_;

  double merge_distance_threshold_;
  const double CENTROID_DISTANCE = 0.2;//distance to consider centroids the same

  std::string input_topic_;
  std::string output_topic_;

  std::string tracking_frame_;

  tf::TransformListener tf_listener_;
  tf::StampedTransform local2global_;
  tf::StampedTransform tracking_frame2lane_frame_;
  tf::StampedTransform lane_frame2tracking_frame_;

  ros::NodeHandle node_handle_;
  ros::NodeHandle private_nh_;
  ros::Subscriber sub_detected_array_;
  ros::Publisher pub_object_array_;
  //yang21itsc
  ros::Publisher pub_example_array_;
  ros::Publisher vis_examples_;
  //yang21itsc

  std_msgs::Header input_header_;

  // yang21itsc
  std::vector learning_buffer;
  double track_probability_;
  // yang21itsc

  // Subscriber callback for incoming detections.
  void callback(const autoware_tracker::DetectedObjectArray& input);

  void transformPoseToGlobal(const autoware_tracker::DetectedObjectArray& input,
                             autoware_tracker::DetectedObjectArray& transformed_input);
  void transformPoseToLocal(autoware_tracker::DetectedObjectArray& detected_objects_output);

  geometry_msgs::Pose getTransformedPose(const geometry_msgs::Pose& in_pose,
                                         const tf::StampedTransform& tf_stamp);

  bool updateNecessaryTransform();

  void measurementValidation(const autoware_tracker::DetectedObjectArray& input, UKF& target, const bool second_init,
                             const Eigen::VectorXd& max_det_z, const Eigen::MatrixXd& max_det_s,
                             std::vector& object_vec, std::vector& matching_vec);
  autoware_tracker::DetectedObject getNearestObject(UKF& target,
                                                    const std::vector& object_vec);
  void updateBehaviorState(const UKF& target, const bool use_sukf, autoware_tracker::DetectedObject& object);

  void initTracker(const autoware_tracker::DetectedObjectArray& input, double timestamp);
  void secondInit(UKF& target, const std::vector& object_vec, double dt);

  void updateTrackingNum(const std::vector& object_vec, UKF& target);

  bool probabilisticDataAssociation(const autoware_tracker::DetectedObjectArray& input, const double dt,
                                    std::vector& matching_vec,
                                    std::vector& object_vec, UKF& target);
  void makeNewTargets(const double timestamp, const autoware_tracker::DetectedObjectArray& input,
                      const std::vector& matching_vec);

  void staticClassification();

  void makeOutput(const autoware_tracker::DetectedObjectArray& input,
                  const std::vector& matching_vec,
                  autoware_tracker::DetectedObjectArray& detected_objects_output);

  void removeUnnecessaryTarget();

  void dumpResultText(autoware_tracker::DetectedObjectArray& detected_objects);

  void tracker(const autoware_tracker::DetectedObjectArray& transformed_input,
               autoware_tracker::DetectedObjectArray& detected_objects_output);

  bool updateDirection(const double smallest_nis, const autoware_tracker::DetectedObject& in_object,
                       autoware_tracker::DetectedObject& out_object, UKF& target);

  bool storeObjectWithNearestLaneDirection(const autoware_tracker::DetectedObject& in_object,
                                           autoware_tracker::DetectedObject& out_object);

  void checkVectormapSubscription();

  // NOTE(review): in_tracker_indices is taken by const value, so the vector is copied on
  // every call; a const reference would avoid that, but the definition in
  // imm_ukf_pda.cpp must be changed together with this declaration.
  autoware_tracker::DetectedObjectArray
  removeRedundantObjects(const autoware_tracker::DetectedObjectArray& in_detected_objects,
                         const std::vector in_tracker_indices);

  autoware_tracker::DetectedObjectArray
  forwardNonMatchedObject(const autoware_tracker::DetectedObjectArray& tmp_objects,
                          const autoware_tracker::DetectedObjectArray& input,
                          const std::vector& matching_vec);

  bool
  arePointsClose(const geometry_msgs::Point& in_point_a,
                 const geometry_msgs::Point& in_point_b,
                 float in_radius);

  bool
  arePointsEqual(const geometry_msgs::Point& in_point_a,
                 const geometry_msgs::Point& in_point_b);

  bool
  isPointInPool(const std::vector& in_pool,
                const geometry_msgs::Point& in_point);

  void updateTargetWithAssociatedObject(const std::vector& object_vec,
                                        UKF& target);

public:
  ImmUkfPda();
  void run();
};

#endif /* OBJECT_TRACKING_IMM_UKF_JPDAF_H */

-------------------------------------------------------------------------------- /autoware_tracker/src/lidar_imm_ukf_pda_track/imm_ukf_pda_main.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "imm_ukf_pda.h" 18 | 19 | int main(int argc, char** argv) 20 | { 21 | ros::init(argc, argv, "imm_ukf_pda_tracker"); 22 | ros::console::set_logger_level(ROSCONSOLE_DEFAULT_NAME, ros::console::levels::Warn); 23 | 24 | ImmUkfPda app; 25 | app.run(); 26 | ros::spin(); 27 | return 0; 28 | } 29 | -------------------------------------------------------------------------------- /efficient_det_ros/__pycache__/backbone.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/__pycache__/backbone.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/backbone.py: -------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet 7 | from efficientdet.utils import Anchors 8 | 9 | 
10 | class EfficientDetBackbone(nn.Module): 11 | def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs): 12 | super(EfficientDetBackbone, self).__init__() 13 | self.compound_coef = compound_coef 14 | 15 | self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7] 16 | self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384] 17 | self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8] 18 | self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536] 19 | self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5] 20 | self.pyramid_levels = [5, 5, 5, 5, 5, 5, 5, 5, 6] 21 | self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5., 4.] 22 | self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]) 23 | self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])) 24 | conv_channel_coef = { 25 | # the channels of P3/P4/P5. 26 | 0: [40, 112, 320], 27 | 1: [40, 112, 320], 28 | 2: [48, 120, 352], 29 | 3: [48, 136, 384], 30 | 4: [56, 160, 448], 31 | 5: [64, 176, 512], 32 | 6: [72, 200, 576], 33 | 7: [72, 200, 576], 34 | 8: [80, 224, 640], 35 | } 36 | 37 | num_anchors = len(self.aspect_ratios) * self.num_scales 38 | 39 | self.bifpn = nn.Sequential( 40 | *[BiFPN(self.fpn_num_filters[self.compound_coef], 41 | conv_channel_coef[compound_coef], 42 | True if _ == 0 else False, 43 | attention=True if compound_coef < 6 else False, 44 | use_p8=compound_coef > 7) 45 | for _ in range(self.fpn_cell_repeats[compound_coef])]) 46 | 47 | self.num_classes = num_classes 48 | self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors, 49 | num_layers=self.box_class_repeats[self.compound_coef], 50 | pyramid_levels=self.pyramid_levels[self.compound_coef]) 51 | self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors, 52 | num_classes=num_classes, 53 | num_layers=self.box_class_repeats[self.compound_coef], 54 | 
pyramid_levels=self.pyramid_levels[self.compound_coef]) 55 | 56 | self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef], 57 | pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(), 58 | **kwargs) 59 | 60 | self.backbone_net = EfficientNet(self.backbone_compound_coef[compound_coef], load_weights) 61 | 62 | def freeze_bn(self): 63 | for m in self.modules(): 64 | if isinstance(m, nn.BatchNorm2d): 65 | m.eval() 66 | 67 | def forward(self, inputs): 68 | max_size = inputs.shape[-1] 69 | 70 | _, p3, p4, p5 = self.backbone_net(inputs) 71 | 72 | features = (p3, p4, p5) 73 | features = self.bifpn(features) 74 | 75 | regression = self.regressor(features) 76 | classification = self.classifier(features) 77 | anchors = self.anchors(inputs, inputs.dtype) 78 | 79 | return features, regression, classification, anchors 80 | 81 | def init_backbone(self, path): 82 | state_dict = torch.load(path) 83 | try: 84 | ret = self.load_state_dict(state_dict, strict=False) 85 | print(ret) 86 | except RuntimeError as e: 87 | print('Ignoring ' + str(e) + '"') 88 | -------------------------------------------------------------------------------- /efficient_det_ros/coco_eval.py: -------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | """ 4 | COCO-Style Evaluations 5 | 6 | put images here datasets/your_project_name/val_set_name/*.jpg 7 | put annotations here datasets/your_project_name/annotations/instances_{val_set_name}.json 8 | put weights here /path/to/your/weights/*.pth 9 | change compound_coef 10 | 11 | """ 12 | 13 | import json 14 | import os 15 | 16 | import argparse 17 | import torch 18 | import yaml 19 | from tqdm import tqdm 20 | from pycocotools.coco import COCO 21 | from pycocotools.cocoeval import COCOeval 22 | 23 | from backbone import EfficientDetBackbone 24 | from efficientdet.utils import BBoxTransform, ClipBoxes 25 | from utils.utils import preprocess, invert_affine, postprocess, 
boolean_string 26 | 27 | ap = argparse.ArgumentParser() 28 | ap.add_argument('-p', '--project', type=str, default='coco', help='project file that contains parameters') 29 | ap.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet') 30 | ap.add_argument('-w', '--weights', type=str, default=None, help='/path/to/weights') 31 | ap.add_argument('--nms_threshold', type=float, default=0.5, help='nms threshold, don\'t change it if not for testing purposes') 32 | ap.add_argument('--cuda', type=boolean_string, default=True) 33 | ap.add_argument('--device', type=int, default=0) 34 | ap.add_argument('--float16', type=boolean_string, default=False) 35 | ap.add_argument('--override', type=boolean_string, default=True, help='override previous bbox results file if exists') 36 | args = ap.parse_args() 37 | 38 | compound_coef = args.compound_coef 39 | nms_threshold = args.nms_threshold 40 | use_cuda = args.cuda 41 | gpu = args.device 42 | use_float16 = args.float16 43 | override_prev_results = args.override 44 | project_name = args.project 45 | weights_path = f'weights/efficientdet-d{compound_coef}.pth' if args.weights is None else args.weights 46 | 47 | print(f'running coco-style evaluation on project {project_name}, weights {weights_path}...') 48 | 49 | params = yaml.safe_load(open(f'projects/{project_name}.yml')) 50 | obj_list = params['obj_list'] 51 | 52 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536] 53 | 54 | 55 | def evaluate_coco(img_path, set_name, image_ids, coco, model, threshold=0.05): 56 | results = [] 57 | 58 | regressBoxes = BBoxTransform() 59 | clipBoxes = ClipBoxes() 60 | 61 | for image_id in tqdm(image_ids): 62 | image_info = coco.loadImgs(image_id)[0] 63 | image_path = img_path + image_info['file_name'] 64 | 65 | ori_imgs, framed_imgs, framed_metas = preprocess(image_path, max_size=input_sizes[compound_coef], mean=params['mean'], std=params['std']) 66 | x = torch.from_numpy(framed_imgs[0]) 67 | 68 | if 
use_cuda: 69 | x = x.cuda(gpu) 70 | if use_float16: 71 | x = x.half() 72 | else: 73 | x = x.float() 74 | else: 75 | x = x.float() 76 | 77 | x = x.unsqueeze(0).permute(0, 3, 1, 2) 78 | features, regression, classification, anchors = model(x) 79 | 80 | preds = postprocess(x, 81 | anchors, regression, classification, 82 | regressBoxes, clipBoxes, 83 | threshold, nms_threshold) 84 | 85 | if not preds: 86 | continue 87 | 88 | preds = invert_affine(framed_metas, preds)[0] 89 | 90 | scores = preds['scores'] 91 | class_ids = preds['class_ids'] 92 | rois = preds['rois'] 93 | 94 | if rois.shape[0] > 0: 95 | # x1,y1,x2,y2 -> x1,y1,w,h 96 | rois[:, 2] -= rois[:, 0] 97 | rois[:, 3] -= rois[:, 1] 98 | 99 | bbox_score = scores 100 | 101 | for roi_id in range(rois.shape[0]): 102 | score = float(bbox_score[roi_id]) 103 | label = int(class_ids[roi_id]) 104 | box = rois[roi_id, :] 105 | 106 | image_result = { 107 | 'image_id': image_id, 108 | 'category_id': label + 1, 109 | 'score': float(score), 110 | 'bbox': box.tolist(), 111 | } 112 | 113 | results.append(image_result) 114 | 115 | if not len(results): 116 | raise Exception('the model does not provide any valid output, check model architecture and the data input') 117 | 118 | # write output 119 | filepath = f'{set_name}_bbox_results.json' 120 | if os.path.exists(filepath): 121 | os.remove(filepath) 122 | json.dump(results, open(filepath, 'w'), indent=4) 123 | 124 | 125 | def _eval(coco_gt, image_ids, pred_json_path): 126 | # load results in COCO evaluation tool 127 | coco_pred = coco_gt.loadRes(pred_json_path) 128 | 129 | # run COCO evaluation 130 | print('BBox') 131 | coco_eval = COCOeval(coco_gt, coco_pred, 'bbox') 132 | coco_eval.params.imgIds = image_ids 133 | coco_eval.evaluate() 134 | coco_eval.accumulate() 135 | coco_eval.summarize() 136 | 137 | 138 | if __name__ == '__main__': 139 | SET_NAME = params['val_set'] 140 | VAL_GT = f'datasets/{params["project_name"]}/annotations/instances_{SET_NAME}.json' 141 | VAL_IMGS = 
f'datasets/{params["project_name"]}/{SET_NAME}/' 142 | MAX_IMAGES = 10000 143 | coco_gt = COCO(VAL_GT) 144 | image_ids = coco_gt.getImgIds()[:MAX_IMAGES] 145 | 146 | if override_prev_results or not os.path.exists(f'{SET_NAME}_bbox_results.json'): 147 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list), 148 | ratios=eval(params['anchors_ratios']), scales=eval(params['anchors_scales'])) 149 | model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) 150 | model.requires_grad_(False) 151 | model.eval() 152 | 153 | if use_cuda: 154 | model.cuda(gpu) 155 | 156 | if use_float16: 157 | model.half() 158 | 159 | evaluate_coco(VAL_IMGS, SET_NAME, image_ids, coco_gt, model) 160 | 161 | _eval(coco_gt, image_ids, f'{SET_NAME}_bbox_results.json') 162 | -------------------------------------------------------------------------------- /efficient_det_ros/efficientdet/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientdet/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/efficientdet/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientdet/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/efficientdet/config.py: -------------------------------------------------------------------------------- 1 | COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", 2 | "traffic light", "fire hydrant", "stop sign", "parking meter", 
"bench", "bird", "cat", "dog", 3 | "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", 4 | "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", 5 | "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", 6 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", 7 | "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", 8 | "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", 9 | "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", 10 | "teddy bear", "hair drier", "toothbrush"] 11 | 12 | colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122), 13 | (80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17), 14 | (82, 144, 81), (92, 7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60), 15 | (16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34), 16 | (53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181), 17 | (42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108), 18 | (52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26), 19 | (122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50), 20 | (56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33), 21 | (105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91), 22 | (31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130), 23 | (31, 139, 
34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3), 24 | (129, 87, 143), (49, 87, 40), (128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108), 25 | (81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95), 26 | (2, 20, 184), (122, 37, 185)] 27 | -------------------------------------------------------------------------------- /efficient_det_ros/efficientdet/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | from pycocotools.coco import COCO 7 | import cv2 8 | 9 | 10 | class CocoDataset(Dataset): 11 | def __init__(self, root_dir, set='train2017', transform=None): 12 | 13 | self.root_dir = root_dir 14 | self.set_name = set 15 | self.transform = transform 16 | 17 | self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json')) 18 | self.image_ids = self.coco.getImgIds() 19 | 20 | self.load_classes() 21 | 22 | def load_classes(self): 23 | 24 | # load class names (name -> label) 25 | categories = self.coco.loadCats(self.coco.getCatIds()) 26 | categories.sort(key=lambda x: x['id']) 27 | 28 | self.classes = {} 29 | for c in categories: 30 | self.classes[c['name']] = len(self.classes) 31 | 32 | # also load the reverse (label -> name) 33 | self.labels = {} 34 | for key, value in self.classes.items(): 35 | self.labels[value] = key 36 | 37 | def __len__(self): 38 | return len(self.image_ids) 39 | 40 | def __getitem__(self, idx): 41 | 42 | img = self.load_image(idx) 43 | annot = self.load_annotations(idx) 44 | sample = {'img': img, 'annot': annot} 45 | if self.transform: 46 | sample = self.transform(sample) 47 | return sample 48 | 49 | def load_image(self, image_index): 50 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0] 51 | path = os.path.join(self.root_dir, 
self.set_name, image_info['file_name']) 52 | img = cv2.imread(path) 53 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 54 | 55 | return img.astype(np.float32) / 255. 56 | 57 | def load_annotations(self, image_index): 58 | # get ground truth annotations 59 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) 60 | annotations = np.zeros((0, 5)) 61 | 62 | # some images appear to miss annotations 63 | if len(annotations_ids) == 0: 64 | return annotations 65 | 66 | # parse annotations 67 | coco_annotations = self.coco.loadAnns(annotations_ids) 68 | for idx, a in enumerate(coco_annotations): 69 | 70 | # some annotations have basically no width / height, skip them 71 | if a['bbox'][2] < 1 or a['bbox'][3] < 1: 72 | continue 73 | 74 | annotation = np.zeros((1, 5)) 75 | annotation[0, :4] = a['bbox'] 76 | annotation[0, 4] = a['category_id'] - 1 77 | annotations = np.append(annotations, annotation, axis=0) 78 | 79 | # transform from [x, y, w, h] to [x1, y1, x2, y2] 80 | annotations[:, 2] = annotations[:, 0] + annotations[:, 2] 81 | annotations[:, 3] = annotations[:, 1] + annotations[:, 3] 82 | 83 | return annotations 84 | 85 | 86 | def collater(data): 87 | imgs = [s['img'] for s in data] 88 | annots = [s['annot'] for s in data] 89 | scales = [s['scale'] for s in data] 90 | 91 | imgs = torch.from_numpy(np.stack(imgs, axis=0)) 92 | 93 | max_num_annots = max(annot.shape[0] for annot in annots) 94 | 95 | if max_num_annots > 0: 96 | 97 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1 98 | 99 | for idx, annot in enumerate(annots): 100 | if annot.shape[0] > 0: 101 | annot_padded[idx, :annot.shape[0], :] = annot 102 | else: 103 | annot_padded = torch.ones((len(annots), 1, 5)) * -1 104 | 105 | imgs = imgs.permute(0, 3, 1, 2) 106 | 107 | return {'img': imgs, 'annot': annot_padded, 'scale': scales} 108 | 109 | 110 | class Resizer(object): 111 | """Convert ndarrays in sample to Tensors.""" 112 | 113 | def __init__(self, img_size=512): 
114 | self.img_size = img_size 115 | 116 | def __call__(self, sample): 117 | image, annots = sample['img'], sample['annot'] 118 | height, width, _ = image.shape 119 | if height > width: 120 | scale = self.img_size / height 121 | resized_height = self.img_size 122 | resized_width = int(width * scale) 123 | else: 124 | scale = self.img_size / width 125 | resized_height = int(height * scale) 126 | resized_width = self.img_size 127 | 128 | image = cv2.resize(image, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR) 129 | 130 | new_image = np.zeros((self.img_size, self.img_size, 3)) 131 | new_image[0:resized_height, 0:resized_width] = image 132 | 133 | annots[:, :4] *= scale 134 | 135 | return {'img': torch.from_numpy(new_image).to(torch.float32), 'annot': torch.from_numpy(annots), 'scale': scale} 136 | 137 | 138 | class Augmenter(object): 139 | """Convert ndarrays in sample to Tensors.""" 140 | 141 | def __call__(self, sample, flip_x=0.5): 142 | if np.random.rand() < flip_x: 143 | image, annots = sample['img'], sample['annot'] 144 | image = image[:, ::-1, :] 145 | 146 | rows, cols, channels = image.shape 147 | 148 | x1 = annots[:, 0].copy() 149 | x2 = annots[:, 2].copy() 150 | 151 | x_tmp = x1.copy() 152 | 153 | annots[:, 0] = cols - x2 154 | annots[:, 2] = cols - x_tmp 155 | 156 | sample = {'img': image, 'annot': annots} 157 | 158 | return sample 159 | 160 | 161 | class Normalizer(object): 162 | 163 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 164 | self.mean = np.array([[mean]]) 165 | self.std = np.array([[std]]) 166 | 167 | def __call__(self, sample): 168 | image, annots = sample['img'], sample['annot'] 169 | 170 | return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots} 171 | -------------------------------------------------------------------------------- /efficient_det_ros/efficientdet/utils.py: -------------------------------------------------------------------------------- 1 | import 
itertools 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | class BBoxTransform(nn.Module): 8 | def forward(self, anchors, regression): 9 | """ 10 | decode_box_outputs adapted from https://github.com/google/automl/blob/master/efficientdet/anchors.py 11 | 12 | Args: 13 | anchors: [batchsize, boxes, (y1, x1, y2, x2)] 14 | regression: [batchsize, boxes, (dy, dx, dh, dw)] 15 | 16 | Returns: 17 | 18 | """ 19 | y_centers_a = (anchors[..., 0] + anchors[..., 2]) / 2 20 | x_centers_a = (anchors[..., 1] + anchors[..., 3]) / 2 21 | ha = anchors[..., 2] - anchors[..., 0] 22 | wa = anchors[..., 3] - anchors[..., 1] 23 | 24 | w = regression[..., 3].exp() * wa 25 | h = regression[..., 2].exp() * ha 26 | 27 | y_centers = regression[..., 0] * ha + y_centers_a 28 | x_centers = regression[..., 1] * wa + x_centers_a 29 | 30 | ymin = y_centers - h / 2. 31 | xmin = x_centers - w / 2. 32 | ymax = y_centers + h / 2. 33 | xmax = x_centers + w / 2. 34 | 35 | return torch.stack([xmin, ymin, xmax, ymax], dim=2) 36 | 37 | 38 | class ClipBoxes(nn.Module): 39 | 40 | def __init__(self): 41 | super(ClipBoxes, self).__init__() 42 | 43 | def forward(self, boxes, img): 44 | batch_size, num_channels, height, width = img.shape 45 | 46 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0) 47 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0) 48 | 49 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width - 1) 50 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height - 1) 51 | 52 | return boxes 53 | 54 | 55 | class Anchors(nn.Module): 56 | """ 57 | adapted and modified from https://github.com/google/automl/blob/master/efficientdet/anchors.py by Zylo117 58 | """ 59 | 60 | def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs): 61 | super().__init__() 62 | self.anchor_scale = anchor_scale 63 | 64 | if pyramid_levels is None: 65 | self.pyramid_levels = [3, 4, 5, 6, 7] 66 | else: 67 | self.pyramid_levels = pyramid_levels 68 | 69 | self.strides = 
kwargs.get('strides', [2 ** x for x in self.pyramid_levels]) 70 | self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])) 71 | self.ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]) 72 | 73 | self.last_anchors = {} 74 | self.last_shape = None 75 | 76 | def forward(self, image, dtype=torch.float32): 77 | """Generates multiscale anchor boxes. 78 | 79 | Args: 80 | image_size: integer number of input image size. The input image has the 81 | same dimension for width and height. The image_size should be divided by 82 | the largest feature stride 2^max_level. 83 | anchor_scale: float number representing the scale of size of the base 84 | anchor to the feature stride 2^level. 85 | anchor_configs: a dictionary with keys as the levels of anchors and 86 | values as a list of anchor configuration. 87 | 88 | Returns: 89 | anchor_boxes: a numpy array with shape [N, 4], which stacks anchors on all 90 | feature levels. 91 | Raises: 92 | ValueError: input size must be the multiple of largest feature stride. 
93 | """ 94 | image_shape = image.shape[2:] 95 | 96 | if image_shape == self.last_shape and image.device in self.last_anchors: 97 | return self.last_anchors[image.device] 98 | 99 | if self.last_shape is None or self.last_shape != image_shape: 100 | self.last_shape = image_shape 101 | 102 | if dtype == torch.float16: 103 | dtype = np.float16 104 | else: 105 | dtype = np.float32 106 | 107 | boxes_all = [] 108 | for stride in self.strides: 109 | boxes_level = [] 110 | for scale, ratio in itertools.product(self.scales, self.ratios): 111 | if image_shape[1] % stride != 0: 112 | raise ValueError('input size must be divided by the stride.') 113 | base_anchor_size = self.anchor_scale * stride * scale 114 | anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0 115 | anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0 116 | 117 | x = np.arange(stride / 2, image_shape[1], stride) 118 | y = np.arange(stride / 2, image_shape[0], stride) 119 | xv, yv = np.meshgrid(x, y) 120 | xv = xv.reshape(-1) 121 | yv = yv.reshape(-1) 122 | 123 | # y1,x1,y2,x2 124 | boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2, 125 | yv + anchor_size_y_2, xv + anchor_size_x_2)) 126 | boxes = np.swapaxes(boxes, 0, 1) 127 | boxes_level.append(np.expand_dims(boxes, axis=1)) 128 | # concat anchors on the same level to the reshape NxAx4 129 | boxes_level = np.concatenate(boxes_level, axis=1) 130 | boxes_all.append(boxes_level.reshape([-1, 4])) 131 | 132 | anchor_boxes = np.vstack(boxes_all) 133 | 134 | anchor_boxes = torch.from_numpy(anchor_boxes.astype(dtype)).to(image.device) 135 | anchor_boxes = anchor_boxes.unsqueeze(0) 136 | 137 | # save it for later use to reduce overhead 138 | self.last_anchors[image.device] = anchor_boxes 139 | return anchor_boxes 140 | -------------------------------------------------------------------------------- /efficient_det_ros/efficientdet_test.py: -------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | """ 4 | 
Simple Inference Script of EfficientDet-Pytorch 5 | """ 6 | import time 7 | import torch 8 | from torch.backends import cudnn 9 | from matplotlib import colors 10 | 11 | from backbone import EfficientDetBackbone 12 | import cv2 13 | import numpy as np 14 | 15 | from efficientdet.utils import BBoxTransform, ClipBoxes 16 | from utils.utils import preprocess, invert_affine, postprocess, STANDARD_COLORS, standard_to_bgr, get_index_label, plot_one_box 17 | 18 | compound_coef = 0 19 | force_input_size = None # set None to use default size 20 | img_path = 'test/0000000000.png' 21 | 22 | # replace this part with your project's anchor config 23 | anchor_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)] 24 | anchor_scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)] 25 | 26 | threshold = 0.2 27 | iou_threshold = 0.2 28 | 29 | use_cuda = True 30 | use_float16 = False 31 | cudnn.fastest = True 32 | cudnn.benchmark = True 33 | 34 | # obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 35 | # 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 36 | # 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie', 37 | # 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 38 | # 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 39 | # 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 40 | # 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv', 41 | # 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 42 | # 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 43 | # 'toothbrush'] 44 | 45 | obj_list = ['car', 'person', 'cyclist'] 46 | 47 | 48 | color_list = 
standard_to_bgr(STANDARD_COLORS) 49 | # tf bilinear interpolation is different from any other's, just make do 50 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536] 51 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size 52 | ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size) 53 | 54 | if use_cuda: 55 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0) 56 | else: 57 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0) 58 | 59 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2) 60 | 61 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list), 62 | ratios=anchor_ratios, scales=anchor_scales) 63 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth', map_location='cpu')) 64 | model.requires_grad_(False) 65 | model.eval() 66 | 67 | if use_cuda: 68 | model = model.cuda() 69 | if use_float16: 70 | model = model.half() 71 | 72 | with torch.no_grad(): 73 | features, regression, classification, anchors = model(x) 74 | 75 | regressBoxes = BBoxTransform() 76 | clipBoxes = ClipBoxes() 77 | 78 | out = postprocess(x, 79 | anchors, regression, classification, 80 | regressBoxes, clipBoxes, 81 | threshold, iou_threshold) 82 | 83 | def display(preds, imgs, imshow=True, imwrite=False): 84 | for i in range(len(imgs)): 85 | if len(preds[i]['rois']) == 0: 86 | continue 87 | 88 | imgs[i] = imgs[i].copy() 89 | 90 | for j in range(len(preds[i]['rois'])): 91 | x1, y1, x2, y2 = preds[i]['rois'][j].astype(np.int) 92 | obj = obj_list[preds[i]['class_ids'][j]] 93 | score = float(preds[i]['scores'][j]) 94 | plot_one_box(imgs[i], [x1, y1, x2, y2], label=obj,score=score,color=color_list[get_index_label(obj, obj_list)]) 95 | 96 | 97 | if imshow: 98 | cv2.imshow('img', imgs[i]) 99 | cv2.waitKey(0) 100 | 101 | if imwrite: 102 | cv2.imwrite(f'test/img_inferred_d{compound_coef}_this_repo_{i}.jpg', 
imgs[i]) 103 | 104 | 105 | out = invert_affine(framed_metas, out) 106 | display(out, ori_imgs, imshow=False, imwrite=True) 107 | 108 | print('running speed test...') 109 | with torch.no_grad(): 110 | print('test1: model inferring and postprocessing') 111 | print('inferring image for 10 times...') 112 | t1 = time.time() 113 | for _ in range(10): 114 | _, regression, classification, anchors = model(x) 115 | 116 | out = postprocess(x, 117 | anchors, regression, classification, 118 | regressBoxes, clipBoxes, 119 | threshold, iou_threshold) 120 | out = invert_affine(framed_metas, out) 121 | 122 | t2 = time.time() 123 | tact_time = (t2 - t1) / 10 124 | print(f'{tact_time} seconds, {1 / tact_time} FPS, @batch_size 1') 125 | 126 | # uncomment this if you want a extreme fps test 127 | # print('test2: model inferring only') 128 | # print('inferring images for batch_size 32 for 10 times...') 129 | # t1 = time.time() 130 | # x = torch.cat([x] * 32, 0) 131 | # for _ in range(10): 132 | # _, regression, classification, anchors = model(x) 133 | # 134 | # t2 = time.time() 135 | # tact_time = (t2 - t1) / 10 136 | # print(f'{tact_time} seconds, {32 / tact_time} FPS, @batch_size 32') 137 | -------------------------------------------------------------------------------- /efficient_det_ros/efficientdet_test_videos.py: -------------------------------------------------------------------------------- 1 | # Core Author: Zylo117 2 | # Script's Author: winter2897 3 | 4 | """ 5 | Simple Inference Script of EfficientDet-Pytorch for detecting objects on webcam 6 | """ 7 | import time 8 | import torch 9 | import cv2 10 | import numpy as np 11 | from torch.backends import cudnn 12 | from backbone import EfficientDetBackbone 13 | from efficientdet.utils import BBoxTransform, ClipBoxes 14 | from utils.utils import preprocess, invert_affine, postprocess, preprocess_video 15 | 16 | # Video's path 17 | video_src = 'videotest.mp4' # set int to use webcam, set str to read from a video file 18 | 19 | 
compound_coef = 0 20 | force_input_size = None # set None to use default size 21 | 22 | threshold = 0.2 23 | iou_threshold = 0.2 24 | 25 | use_cuda = True 26 | use_float16 = False 27 | cudnn.fastest = True 28 | cudnn.benchmark = True 29 | 30 | obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 31 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 32 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie', 33 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 34 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 35 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 36 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv', 37 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 38 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 39 | 'toothbrush'] 40 | 41 | # tf bilinear interpolation is different from any other's, just make do 42 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536] 43 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size 44 | 45 | # load model 46 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list)) 47 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth')) 48 | model.requires_grad_(False) 49 | model.eval() 50 | 51 | if use_cuda: 52 | model = model.cuda() 53 | if use_float16: 54 | model = model.half() 55 | 56 | # function for display 57 | def display(preds, imgs): 58 | for i in range(len(imgs)): 59 | if len(preds[i]['rois']) == 0: 60 | return imgs[i] 61 | 62 | for j in range(len(preds[i]['rois'])): 63 | 
(x1, y1, x2, y2) = preds[i]['rois'][j].astype(np.int) 64 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2) 65 | obj = obj_list[preds[i]['class_ids'][j]] 66 | score = float(preds[i]['scores'][j]) 67 | 68 | cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score), 69 | (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 70 | (255, 255, 0), 1) 71 | 72 | return imgs[i] 73 | # Box 74 | regressBoxes = BBoxTransform() 75 | clipBoxes = ClipBoxes() 76 | 77 | # Video capture 78 | cap = cv2.VideoCapture(video_src) 79 | 80 | while True: 81 | ret, frame = cap.read() 82 | if not ret: 83 | break 84 | 85 | # frame preprocessing 86 | ori_imgs, framed_imgs, framed_metas = preprocess_video(frame, max_size=input_size) 87 | 88 | if use_cuda: 89 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0) 90 | else: 91 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0) 92 | 93 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2) 94 | 95 | # model predict 96 | with torch.no_grad(): 97 | features, regression, classification, anchors = model(x) 98 | 99 | out = postprocess(x, 100 | anchors, regression, classification, 101 | regressBoxes, clipBoxes, 102 | threshold, iou_threshold) 103 | 104 | # result 105 | out = invert_affine(framed_metas, out) 106 | img_show = display(out, ori_imgs) 107 | 108 | # show frame by frame 109 | cv2.imshow('frame',img_show) 110 | if cv2.waitKey(1) & 0xFF == ord('q'): 111 | break 112 | 113 | cap.release() 114 | cv2.destroyAllWindows() 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /efficient_det_ros/efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.1" 2 | from .model import EfficientNet 3 | from .utils import ( 4 | GlobalParams, 5 | BlockArgs, 6 | BlockDecoder, 7 | efficientnet, 8 | get_model_params, 9 | ) 10 | 11 | 
-------------------------------------------------------------------------------- /efficient_det_ros/efficientnet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientnet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/efficientnet/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientnet/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/efficientnet/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientnet/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/efficientnet/__pycache__/utils_extra.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientnet/__pycache__/utils_extra.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/efficientnet/utils_extra.py: -------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | import math 4 | 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class 
Conv2dStaticSamePadding(nn.Module): 10 | """ 11 | created by Zylo117 12 | The real keras/tensorflow conv2d with same padding 13 | """ 14 | 15 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs): 16 | super().__init__() 17 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, 18 | bias=bias, groups=groups) 19 | self.stride = self.conv.stride 20 | self.kernel_size = self.conv.kernel_size 21 | self.dilation = self.conv.dilation 22 | 23 | if isinstance(self.stride, int): 24 | self.stride = [self.stride] * 2 25 | elif len(self.stride) == 1: 26 | self.stride = [self.stride[0]] * 2 27 | 28 | if isinstance(self.kernel_size, int): 29 | self.kernel_size = [self.kernel_size] * 2 30 | elif len(self.kernel_size) == 1: 31 | self.kernel_size = [self.kernel_size[0]] * 2 32 | 33 | def forward(self, x): 34 | h, w = x.shape[-2:] 35 | 36 | extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1] 37 | extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0] 38 | 39 | left = extra_h // 2 40 | right = extra_h - left 41 | top = extra_v // 2 42 | bottom = extra_v - top 43 | 44 | x = F.pad(x, [left, right, top, bottom]) 45 | 46 | x = self.conv(x) 47 | return x 48 | 49 | 50 | class MaxPool2dStaticSamePadding(nn.Module): 51 | """ 52 | created by Zylo117 53 | The real keras/tensorflow MaxPool2d with same padding 54 | """ 55 | 56 | def __init__(self, *args, **kwargs): 57 | super().__init__() 58 | self.pool = nn.MaxPool2d(*args, **kwargs) 59 | self.stride = self.pool.stride 60 | self.kernel_size = self.pool.kernel_size 61 | 62 | if isinstance(self.stride, int): 63 | self.stride = [self.stride] * 2 64 | elif len(self.stride) == 1: 65 | self.stride = [self.stride[0]] * 2 66 | 67 | if isinstance(self.kernel_size, int): 68 | self.kernel_size = [self.kernel_size] * 2 69 | elif len(self.kernel_size) == 1: 70 | self.kernel_size = [self.kernel_size[0]] * 
2 71 | 72 | def forward(self, x): 73 | h, w = x.shape[-2:] 74 | 75 | extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1] 76 | extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0] 77 | 78 | left = extra_h // 2 79 | right = extra_h - left 80 | top = extra_v // 2 81 | bottom = extra_v - top 82 | 83 | x = F.pad(x, [left, right, top, bottom]) 84 | 85 | x = self.pool(x) 86 | return x 87 | -------------------------------------------------------------------------------- /efficient_det_ros/projects/coco.yml: -------------------------------------------------------------------------------- 1 | project_name: coco # also the folder name of the dataset that under data_path folder 2 | train_set: train2017 3 | val_set: val2017 4 | num_gpus: 4 5 | 6 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco. 7 | mean: [0.485, 0.456, 0.406] 8 | std: [0.229, 0.224, 0.225] 9 | 10 | # this is coco anchors, change it if necessary 11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]' 12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]' 13 | 14 | # must match your dataset's category_id. 
15 | # category_id is one-indexed, 16 | # for example, index of 'car' here is 2, while its category_id is 3 17 | obj_list: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 18 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 19 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie', 20 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 21 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 22 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 23 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv', 24 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 25 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 26 | 'toothbrush'] -------------------------------------------------------------------------------- /efficient_det_ros/projects/kitti.yml: -------------------------------------------------------------------------------- 1 | project_name: KITTI_3D_Object # also the name of the dataset folder under data_path 2 | train_set: train 3 | val_set: val 4 | num_gpus: 2 5 | 6 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco. 7 | mean: [0.485, 0.456, 0.406] 8 | std: [0.229, 0.224, 0.225] 9 | 10 | # the anchor scales are the coco defaults, but the ratios below differ from coco's; change them if necessary 11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]' 12 | anchors_ratios: '[(0.6, 1.5), (1.1, 0.9), (1.5, 0.7)]' 13 | 14 | # must match your dataset's category_id.
15 | # category_id is one-indexed, 16 | # for example, index of 'car' here is 0, so its category_id is 1 17 | obj_list: ['car', 'person', 'cyclist'] 18 | -------------------------------------------------------------------------------- /efficient_det_ros/projects/shape.yml: -------------------------------------------------------------------------------- 1 | project_name: shape # also the name of the dataset folder under data_path 2 | train_set: train 3 | val_set: val 4 | num_gpus: 1 5 | 6 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco. 7 | mean: [0.485, 0.456, 0.406] 8 | std: [0.229, 0.224, 0.225] 9 | 10 | # these anchors are currently identical to the coco defaults; adapt them to the dataset if necessary 11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]' 12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]' 13 | 14 | obj_list: ['rectangle', 'circle'] -------------------------------------------------------------------------------- /efficient_det_ros/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc 
-------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean 
= ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 
90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /efficient_det_ros/utils/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /efficient_det_ros/weights/efficientdet-d2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/weights/efficientdet-d2.pth -------------------------------------------------------------------------------- /kitti_camera_ros/.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | -------------------------------------------------------------------------------- /kitti_camera_ros/.travis.yml: -------------------------------------------------------------------------------- 1 | dist: bionic 2 | sudo: required 3 | language: generic 4 | cache: apt 5 | 6 | install: 7 | - sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' 8 | - sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 9 | - sudo apt update 10 
| - sudo apt install ros-melodic-desktop-full 11 | - source /opt/ros/melodic/setup.bash 12 | - mkdir -p ~/catkin_ws/src 13 | - cd ~/catkin_ws/ 14 | - catkin_make 15 | - source devel/setup.bash 16 | 17 | script: 18 | - cd ~/catkin_ws/src 19 | - git clone -b melodic https://github.com/epan-utbm/kitti_velodyne_ros.git 20 | - cd ~/catkin_ws 21 | - catkin_make 22 | -------------------------------------------------------------------------------- /kitti_camera_ros/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(kitti_camera_ros) 3 | 4 | find_package(catkin REQUIRED COMPONENTS 5 | roscpp 6 | sensor_msgs 7 | geometry_msgs 8 | visualization_msgs 9 | pcl_conversions 10 | pcl_ros 11 | ) 12 | 13 | find_package(PCL REQUIRED) 14 | 15 | include_directories(include ${catkin_INCLUDE_DIRS} ${PCL_INCLUDE_DIRS}) 16 | 17 | catkin_package() 18 | 19 | add_executable(kitti_camera_ros src/kitti_camera_ros.cpp) 20 | target_link_libraries(kitti_camera_ros ${catkin_LIBRARIES} ${PCL_LIBRARIES}) 21 | if(catkin_EXPORTED_TARGETS) 22 | add_dependencies(kitti_camera_ros ${catkin_EXPORTED_TARGETS}) 23 | endif() 24 | -------------------------------------------------------------------------------- /kitti_camera_ros/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Rui Yang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /kitti_camera_ros/README.md: -------------------------------------------------------------------------------- 1 | # kitti_camera_ros 2 | 3 | Load KITTI camera data, play in ROS. 
4 | 5 | ## Usage 6 | 7 | ```console 8 | $ roslaunch kitti_camera_ros kitti_camera_ros.launch 9 | ``` 10 | -------------------------------------------------------------------------------- /kitti_camera_ros/launch/kitti_camera_ros.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /kitti_camera_ros/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | kitti_camera_ros 4 | 0.0.1 5 | Load KITTI camera data, play in ROS 6 | Rui Yang 7 | BSD 8 | 9 | https://github.com/epan-utbm/efficient_online_learning 10 | Rui Yang 11 | 12 | catkin 13 | 14 | roscpp 15 | sensor_msgs 16 | geometry_msgs 17 | visualization_msgs 18 | pcl_conversions 19 | pcl_ros 20 | 21 | roscpp 22 | sensor_msgs 23 | geometry_msgs 24 | visualization_msgs 25 | pcl_conversions 26 | pcl_ros 27 | 28 | -------------------------------------------------------------------------------- /kitti_camera_ros/src/kitti_camera_ros.cpp: -------------------------------------------------------------------------------- 1 | // ROS 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | // PCL 8 | #include 9 | // C++ 10 | #include 11 | // Rui 12 | #include 13 | #include 14 | 15 | // Replays per-frame 2D detections stored as text files (one file per frame, read in sorted 16 | // filename order from ~camera_dir) as vision_msgs::Detection2DArray messages on the latched 17 | // topic /image_detections, publishing at most ~frequency files per second. 18 | // Each file holds a timestamp in seconds (0 is mapped to ros::TIME_MIN), followed by zero or more 19 | // whitespace-separated records "center_x center_y size_x size_y id score". 20 | int main(int argc, char **argv) { 21 | double frequency; 22 | std::string camera_dir; 23 | double timestamp; 24 | 25 | ros::init(argc, argv, "kitti_camera_ros"); 26 | ros::NodeHandle private_nh("~"); 27 | 28 | ros::Publisher camera_pub = private_nh.advertise("/image_detections", 100, true); 29 | 30 | private_nh.param("frequency", frequency, 10); 31 | private_nh.param("camera_dir", camera_dir, "camera_dir_path"); 32 | // File paths are built by plain concatenation below, so the directory must end with '/'. 33 | if(!camera_dir.empty() && camera_dir.back() != '/') { 34 | camera_dir += '/'; 35 | } 36 | 37 | ros::Rate loop_rate(frequency); 38 | 39 | //vision_msgs::Detection2DArray detection_results; 40 | 41 | // Sorted directory listing; the filter drops hidden entries (including "." and ".."), so the 42 | // loop below does not have to assume that "." and ".." are the first two entries. 43 | struct dirent **filelist; 44 | int n_file = scandir(camera_dir.c_str(), &filelist, 45 | [](const struct dirent *entry) -> int { return entry->d_name[0] != '.'; }, 46 | alphasort); 47 | if(n_file == -1) {
48 | ROS_ERROR_STREAM("[kitti_camera_ros] Could not open directory: " << camera_dir); 49 | return EXIT_FAILURE; 50 | } else { 51 | ROS_INFO_STREAM("[kitti_camera_ros] Load camera files in " << camera_dir); 52 | ROS_INFO_STREAM("[kitti_camera_ros] frequency (loop rate): " << frequency); 53 | } 54 | 55 | int status = EXIT_SUCCESS; 56 | int i_file = 0; 57 | while(ros::ok() && i_file < n_file) { 58 | vision_msgs::Detection2DArray detection_results; 59 | 60 | /*** Camera ***/ 61 | std::string s = camera_dir + filelist[i_file]->d_name; 62 | std::fstream camera_txt(s.c_str(), std::ios::in | std::ios::binary); 63 | //std::cerr << "s: " << s.c_str() << std::endl; 64 | if(!camera_txt.good()) { 65 | ROS_ERROR_STREAM("[kitti_camera_ros] Could not read file: " << s); 66 | // Leave the loop instead of returning so that the scandir() entries are still freed below. 67 | status = EXIT_FAILURE; 68 | break; 69 | } else { 70 | camera_txt >> timestamp; 71 | ros::Time timestamp_ros(timestamp == 0 ? ros::TIME_MIN.toSec() : timestamp); 72 | detection_results.header.stamp = timestamp_ros; 73 | 74 | //camera_txt.seekg(0, std::ios::beg); 75 | 76 | // Keep a record only if all six fields were extracted: checking good()/eof() before reading 77 | // is not enough, because trailing whitespace leaves the stream good and the failed read would 78 | // then be published as an extra zero-filled detection. 79 | while(true) { 80 | vision_msgs::Detection2D detection; 81 | vision_msgs::ObjectHypothesisWithPose result; 82 | camera_txt >> detection.bbox.center.x >> detection.bbox.center.y 83 | >> detection.bbox.size_x >> detection.bbox.size_y 84 | >> result.id >> result.score; 85 | if(camera_txt.fail()) { 86 | break; 87 | } 88 | detection.results.push_back(result); 89 | detection_results.detections.push_back(detection); 90 | } 91 | camera_txt.close(); 92 | 93 | camera_pub.publish(detection_results); 94 | // ROS_INFO_STREAM("[kitti_camera_ros] detection_results.size " << detection_results.detections.size()); 95 | // ROS_INFO_STREAM("--------------------------------------------"); 96 | // for(int n = 0; n < detection_results.detections.size(); n++) { 97 | // ROS_INFO_STREAM("[kitti_camera_ros] detections.label " << detection_results.detections[n].results[0].id); 98 | // ROS_INFO_STREAM("[kitti_camera_ros] detections.score " << detection_results.detections[n].results[0].score);
99 | // } 100 | } 101 | 102 | ros::spinOnce(); 103 | loop_rate.sleep(); 104 | i_file++; 105 | } 106 | 107 | // Release every entry scandir() allocated, then the array itself. 108 | for(int i = 0; i < n_file; i++) { 109 | free(filelist[i]); 110 | } 111 | free(filelist); 112 | 113 | return status; 114 | } 115 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/.travis.yml: -------------------------------------------------------------------------------- 1 | dist: bionic 2 | sudo: required 3 | language: generic 4 | cache: apt 5 | 6 | install: 7 | - sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' 8 | - sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 9 | - sudo apt update 10 | - sudo apt install ros-melodic-desktop-full 11 | - source /opt/ros/melodic/setup.bash 12 | - mkdir -p ~/catkin_ws/src 13 | - cd ~/catkin_ws/ 14 | - catkin_make 15 | - source devel/setup.bash 16 | 17 | script: 18 | - cd ~/catkin_ws/src 19 | - git clone -b melodic https://github.com/epan-utbm/kitti_velodyne_ros.git 20 | - cd ~/catkin_ws 21 | - catkin_make 22 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/CMakeLists.txt: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(kitti_velodyne_ros) 3 | 4 | find_package(catkin REQUIRED COMPONENTS 5 | roscpp 6 | sensor_msgs 7 | geometry_msgs 8 | visualization_msgs 9 | pcl_conversions 10 | pcl_ros 11 | ) 12 | 13 | find_package(PCL REQUIRED) 14 | 15 | include_directories(include ${catkin_INCLUDE_DIRS} ${PCL_INCLUDE_DIRS}) 16 | 17 | catkin_package() 18 | 19 | add_executable(kitti_velodyne_ros src/kitti_velodyne_ros.cpp) 20 | target_link_libraries(kitti_velodyne_ros ${catkin_LIBRARIES} ${PCL_LIBRARIES}) 21 | if(catkin_EXPORTED_TARGETS) 22 | add_dependencies(kitti_velodyne_ros ${catkin_EXPORTED_TARGETS}) 23 | endif() 24 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Zhi Yan 4 | Copyright (c) 2021, Rui Yang 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/README.md: -------------------------------------------------------------------------------- 1 | # kitti_velodyne_ros 2 | 3 | [![Build Status](https://travis-ci.org/epan-utbm/kitti_velodyne_ros.svg?branch=melodic)](https://travis-ci.org/epan-utbm/kitti_velodyne_ros) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/24e89caa1d40456f966e039145f64edf)](https://app.codacy.com/gh/epan-utbm/kitti_velodyne_ros?utm_source=github.com&utm_medium=referral&utm_content=epan-utbm/kitti_velodyne_ros&utm_campaign=Badge_Grade_Dashboard) [![License](https://img.shields.io/badge/License-BSD%203--Clause-gree.svg)](https://opensource.org/licenses/BSD-3-Clause) 4 | 5 | Load KITTI velodyne data, play in ROS. 
6 | 7 | ## Usage 8 | 9 | ```console 10 | $ roslaunch kitti_velodyne_ros kitti_velodyne_ros.launch 11 | ``` 12 | 13 | If you want to save the point cloud as a csv file, simply activate in [kitti_velodyne_ros.launch](launch/kitti_velodyne_ros.launch) : 14 | 15 | ```console 16 | 17 | ``` 18 | 19 | In case you want to play with [LOAM](https://github.com/laboshinl/loam_velodyne): 20 | 21 | ```console 22 | $ roslaunch kitti_velodyne_ros kitti_velodyne_ros_loam.launch 23 | ``` 24 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/launch/kitti_velodyne_ros.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/launch/kitti_velodyne_ros.rviz: -------------------------------------------------------------------------------- 1 | Panels: 2 | - Class: rviz/Displays 3 | Help Height: 78 4 | Name: Displays 5 | Property Tree Widget: 6 | Expanded: 7 | - /Global Options1 8 | - /Status1 9 | Splitter Ratio: 0.5 10 | Tree Height: 731 11 | - Class: rviz/Selection 12 | Name: Selection 13 | - Class: rviz/Tool Properties 14 | Expanded: 15 | - /2D Pose Estimate1 16 | - /2D Nav Goal1 17 | - /Publish Point1 18 | Name: Tool Properties 19 | Splitter Ratio: 0.5886790156364441 20 | - Class: rviz/Views 21 | Expanded: 22 | - /Current View1 23 | Name: Views 24 | Splitter Ratio: 0.5 25 | - Class: rviz/Time 26 | Experimental: false 27 | Name: Time 28 | SyncMode: 0 29 | SyncSource: PointCloud2 30 | Preferences: 31 | PromptSaveOnExit: true 32 | Toolbars: 33 | toolButtonStyle: 2 34 | Visualization Manager: 35 | Class: "" 36 | Displays: 37 | - Alpha: 1 38 | Autocompute Intensity Bounds: true 39 | Autocompute Value Bounds: 40 | Max Value: 10 41 | Min Value: -10 42 | Value: true 43 | Axis: Z 44 | Channel Name: intensity 45 | Class: rviz/PointCloud2 46 | Color: 255; 255; 255 47 | Color 
Transformer: Intensity 48 | Decay Time: 0 49 | Enabled: true 50 | Invert Rainbow: false 51 | Max Color: 255; 255; 255 52 | Max Intensity: 0.9900000095367432 53 | Min Color: 0; 0; 0 54 | Min Intensity: 0 55 | Name: PointCloud2 56 | Position Transformer: XYZ 57 | Queue Size: 10 58 | Selectable: true 59 | Size (Pixels): 1 60 | Size (m): 0.019999999552965164 61 | Style: Flat Squares 62 | Topic: /kitti_velodyne_ros/velodyne_points 63 | Unreliable: false 64 | Use Fixed Frame: true 65 | Use rainbow: true 66 | Value: true 67 | - Alpha: 1 68 | Arrow Length: 0.30000001192092896 69 | Axes Length: 0.30000001192092896 70 | Axes Radius: 0.009999999776482582 71 | Class: rviz/PoseArray 72 | Color: 255; 25; 0 73 | Enabled: false 74 | Head Length: 0.07000000029802322 75 | Head Radius: 0.029999999329447746 76 | Name: GroundTruthPose 77 | Shaft Length: 0.23000000417232513 78 | Shaft Radius: 0.009999999776482582 79 | Shape: Arrow (Flat) 80 | Topic: /kitti_velodyne_ros/poses 81 | Unreliable: false 82 | Value: false 83 | - Alpha: 0.5 84 | Cell Size: 1 85 | Class: rviz/Grid 86 | Color: 160; 160; 164 87 | Enabled: false 88 | Line Style: 89 | Line Width: 0.029999999329447746 90 | Value: Lines 91 | Name: Grid 92 | Normal Cell Count: 0 93 | Offset: 94 | X: 0 95 | Y: 0 96 | Z: 0 97 | Plane: XY 98 | Plane Cell Count: 10 99 | Reference Frame: 100 | Value: false 101 | - Class: rviz/TF 102 | Enabled: false 103 | Frame Timeout: 15 104 | Frames: 105 | All Enabled: true 106 | Marker Scale: 7 107 | Name: TF 108 | Show Arrows: true 109 | Show Axes: true 110 | Show Names: true 111 | Tree: 112 | {} 113 | Update Interval: 0 114 | Value: false 115 | - Class: rviz/MarkerArray 116 | Enabled: false 117 | Marker Topic: /kitti_velodyne_ros/markers 118 | Name: MarkerArray 119 | Namespaces: 120 | {} 121 | Queue Size: 100 122 | Value: false 123 | Enabled: true 124 | Global Options: 125 | Background Color: 20; 20; 20 126 | Default Light: true 127 | Fixed Frame: velodyne 128 | Frame Rate: 30 129 | Name: root 130 | 
Tools: 131 | - Class: rviz/Interact 132 | Hide Inactive Objects: true 133 | - Class: rviz/MoveCamera 134 | - Class: rviz/Select 135 | - Class: rviz/FocusCamera 136 | - Class: rviz/Measure 137 | - Class: rviz/SetInitialPose 138 | Theta std deviation: 0.2617993950843811 139 | Topic: /initialpose 140 | X std deviation: 0.5 141 | Y std deviation: 0.5 142 | - Class: rviz/SetGoal 143 | Topic: /move_base_simple/goal 144 | - Class: rviz/PublishPoint 145 | Single click: true 146 | Topic: /clicked_point 147 | Value: true 148 | Views: 149 | Current: 150 | Class: rviz/Orbit 151 | Distance: 59.19333267211914 152 | Enable Stereo Rendering: 153 | Stereo Eye Separation: 0.05999999865889549 154 | Stereo Focal Distance: 1 155 | Swap Stereo Eyes: false 156 | Value: false 157 | Focal Point: 158 | X: 0.8557648658752441 159 | Y: -1.995847225189209 160 | Z: 0.9883638620376587 161 | Focal Shape Fixed Size: true 162 | Focal Shape Size: 0.05000000074505806 163 | Invert Z Axis: false 164 | Name: Current View 165 | Near Clip Distance: 0.009999999776482582 166 | Pitch: 0.6902026534080505 167 | Target Frame: 168 | Value: Orbit (rviz) 169 | Yaw: 0.9703954458236694 170 | Saved: ~ 171 | Window Geometry: 172 | Displays: 173 | collapsed: false 174 | Height: 1025 175 | Hide Left Dock: false 176 | Hide Right Dock: true 177 | QMainWindow State: 
000000ff00000000fd00000004000000000000016a00000366fc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d00000366000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f00000396fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000002800000396000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e100000197000000030000073d0000003bfc0100000002fb0000000800540069006d006501000000000000073d000002eb00fffffffb0000000800540069006d00650100000000000004500000000000000000000005cd0000036600000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 178 | Selection: 179 | collapsed: false 180 | Time: 181 | collapsed: false 182 | Tool Properties: 183 | collapsed: false 184 | Views: 185 | collapsed: true 186 | Width: 1853 187 | X: 67 188 | Y: 27 189 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/launch/kitti_velodyne_ros_loam.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 
-------------------------------------------------------------------------------- /kitti_velodyne_ros/launch/kitti_velodyne_ros_loam.rviz: -------------------------------------------------------------------------------- 1 | Panels: 2 | - Class: rviz/Displays 3 | Help Height: 78 4 | Name: Displays 5 | Property Tree Widget: 6 | Expanded: 7 | - /Global Options1 8 | - /Status1 9 | Splitter Ratio: 0.5 10 | Tree Height: 561 11 | - Class: rviz/Selection 12 | Name: Selection 13 | - Class: rviz/Tool Properties 14 | Expanded: 15 | - /2D Pose Estimate1 16 | - /2D Nav Goal1 17 | - /Publish Point1 18 | Name: Tool Properties 19 | Splitter Ratio: 0.588679016 20 | - Class: rviz/Views 21 | Expanded: 22 | - /Current View1 23 | Name: Views 24 | Splitter Ratio: 0.5 25 | - Class: rviz/Time 26 | Experimental: false 27 | Name: Time 28 | SyncMode: 0 29 | SyncSource: PointCloud2 30 | Visualization Manager: 31 | Class: "" 32 | Displays: 33 | - Alpha: 1 34 | Autocompute Intensity Bounds: true 35 | Autocompute Value Bounds: 36 | Max Value: 10 37 | Min Value: -10 38 | Value: true 39 | Axis: Z 40 | Channel Name: intensity 41 | Class: rviz/PointCloud2 42 | Color: 255; 255; 255 43 | Color Transformer: Intensity 44 | Decay Time: 0 45 | Enabled: true 46 | Invert Rainbow: false 47 | Max Color: 255; 255; 255 48 | Max Intensity: 0.99000001 49 | Min Color: 0; 0; 0 50 | Min Intensity: 0 51 | Name: PointCloud2 52 | Position Transformer: XYZ 53 | Queue Size: 10 54 | Selectable: true 55 | Size (Pixels): 1 56 | Size (m): 0.0199999996 57 | Style: Flat Squares 58 | Topic: /kitti_velodyne_ros/velodyne_points 59 | Unreliable: false 60 | Use Fixed Frame: true 61 | Use rainbow: true 62 | Value: true 63 | - Alpha: 1 64 | Arrow Length: 0.300000012 65 | Axes Length: 0.300000012 66 | Axes Radius: 0.00999999978 67 | Class: rviz/PoseArray 68 | Color: 255; 25; 0 69 | Enabled: false 70 | Head Length: 0.0700000003 71 | Head Radius: 0.0299999993 72 | Name: GroundTruthPose 73 | Shaft Length: 0.230000004 74 | Shaft Radius: 
0.00999999978 75 | Shape: Arrow (Flat) 76 | Topic: /kitti_velodyne_ros/poses 77 | Unreliable: false 78 | Value: false 79 | - Alpha: 0.5 80 | Cell Size: 1 81 | Class: rviz/Grid 82 | Color: 160; 160; 164 83 | Enabled: true 84 | Line Style: 85 | Line Width: 0.0299999993 86 | Value: Lines 87 | Name: Grid 88 | Normal Cell Count: 0 89 | Offset: 90 | X: 0 91 | Y: 0 92 | Z: 0 93 | Plane: XY 94 | Plane Cell Count: 10 95 | Reference Frame: 96 | Value: true 97 | - Class: rviz/TF 98 | Enabled: true 99 | Frame Timeout: 15 100 | Frames: 101 | All Enabled: true 102 | base_link: 103 | Value: true 104 | camera_init: 105 | Value: true 106 | map: 107 | Value: true 108 | velodyne: 109 | Value: true 110 | Marker Scale: 7 111 | Name: TF 112 | Show Arrows: true 113 | Show Axes: true 114 | Show Names: true 115 | Tree: 116 | map: 117 | camera_init: 118 | base_link: 119 | velodyne: 120 | {} 121 | Update Interval: 0 122 | Value: true 123 | - Class: rviz/MarkerArray 124 | Enabled: true 125 | Marker Topic: /kitti_velodyne_ros/markers 126 | Name: MarkerArray 127 | Namespaces: 128 | "": true 129 | Queue Size: 100 130 | Value: true 131 | Enabled: true 132 | Global Options: 133 | Background Color: 20; 20; 20 134 | Default Light: true 135 | Fixed Frame: map 136 | Frame Rate: 30 137 | Name: root 138 | Tools: 139 | - Class: rviz/Interact 140 | Hide Inactive Objects: true 141 | - Class: rviz/MoveCamera 142 | - Class: rviz/Select 143 | - Class: rviz/FocusCamera 144 | - Class: rviz/Measure 145 | - Class: rviz/SetInitialPose 146 | Topic: /initialpose 147 | - Class: rviz/SetGoal 148 | Topic: /move_base_simple/goal 149 | - Class: rviz/PublishPoint 150 | Single click: true 151 | Topic: /clicked_point 152 | Value: true 153 | Views: 154 | Current: 155 | Class: rviz/Orbit 156 | Distance: 81.0036621 157 | Enable Stereo Rendering: 158 | Stereo Eye Separation: 0.0599999987 159 | Stereo Focal Distance: 1 160 | Swap Stereo Eyes: false 161 | Value: false 162 | Focal Point: 163 | X: 0.855764866 164 | Y: -1.99584723 
165 | Z: 0.988363862 166 | Focal Shape Fixed Size: true 167 | Focal Shape Size: 0.0500000007 168 | Invert Z Axis: false 169 | Name: Current View 170 | Near Clip Distance: 0.00999999978 171 | Pitch: 0.690202653 172 | Target Frame: 173 | Value: Orbit (rviz) 174 | Yaw: 0.970395446 175 | Saved: ~ 176 | Window Geometry: 177 | Displays: 178 | collapsed: false 179 | Height: 839 180 | Hide Left Dock: false 181 | Hide Right Dock: true 182 | QMainWindow State: 000000ff00000000fd00000004000000000000016a000002c0fc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000006100fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c0061007900730100000028000002c0000000d700fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f00000396fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000002800000396000000ad00fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000005490000003bfc0100000002fb0000000800540069006d00650100000000000005490000030000fffffffb0000000800540069006d00650100000000000004500000000000000000000003d9000002c000000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 183 | Selection: 184 | collapsed: false 185 | Time: 186 | collapsed: false 187 | Tool Properties: 
188 | collapsed: false 189 | Views: 190 | collapsed: true 191 | Width: 1353 192 | X: 65 193 | Y: 24 194 | -------------------------------------------------------------------------------- /kitti_velodyne_ros/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | kitti_velodyne_ros 4 | 0.0.1 5 | Load KITTI velodyne data, play in ROS 6 | Rui Yang 7 | BSD 8 | 9 | https://github.com/epan-utbm/efficient_online_learning 10 | Zhi Yan 11 | 12 | catkin 13 | 14 | roscpp 15 | sensor_msgs 16 | geometry_msgs 17 | visualization_msgs 18 | pcl_conversions 19 | pcl_ros 20 | 21 | roscpp 22 | sensor_msgs 23 | geometry_msgs 24 | visualization_msgs 25 | pcl_conversions 26 | pcl_ros 27 | 28 | -------------------------------------------------------------------------------- /launch/efficient_online_learning.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /online_forests_ros/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(online_forests_ros) 3 | 4 | find_package(catkin REQUIRED COMPONENTS 5 | roscpp 6 | std_msgs 7 | ) 8 | 9 | include_directories( 10 | include 11 | ${catkin_INCLUDE_DIRS} 12 | ) 13 | 14 | catkin_package( 15 | INCLUDE_DIRS include 16 | ) 17 | 18 | add_executable(online_forests_ros 19 | src/online_forests/classifier.cpp 20 | src/online_forests/hyperparameters.cpp 21 | src/online_forests/onlinenode.cpp 22 | src/online_forests/onlinetree.cpp 23 | src/online_forests/utilities.cpp 24 | src/online_forests/data.cpp 25 | # src/online_forests/Online-Forest.cpp 26 | 
src/online_forests/onlinerf.cpp 27 | src/online_forests/randomtest.cpp 28 | src/online_forests_ros.cpp 29 | ) 30 | 31 | target_link_libraries(online_forests_ros 32 | ${catkin_LIBRARIES} 33 | config++ 34 | blas 35 | ) 36 | 37 | if(catkin_EXPORTED_TARGETS) 38 | add_dependencies(online_forests_ros 39 | ${catkin_EXPORTED_TARGETS} 40 | ) 41 | endif() 42 | -------------------------------------------------------------------------------- /online_forests_ros/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Amir Saffari 4 | Copyright (c) 2020 Zhi Yan 5 | Copyright (c) 2021 Rui Yang 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | 25 | -------------------------------------------------------------------------------- /online_forests_ros/README.md: -------------------------------------------------------------------------------- 1 | # online_forests_ros 2 | 3 | This package is forked from [https://github.com/amirsaffari/online-random-forests](https://github.com/amirsaffari/online-random-forests); the original README is reproduced below the dividing line. 4 | 5 | [2020-09-18]: ROSified the original Online Random Forests and added support for data streams (not limited to files). 6 | 7 | [2020-09-27]: Added a storage function based on the binary tree structure, which allows the model to be saved and loaded. 8 | 9 | --- 10 | 11 | # Install prerequisites 12 | 13 | ``` 14 | sudo apt install libgmm++-dev libconfig++-dev libatlas-base-dev libblas-dev liblapack-dev 15 | ``` 16 | 17 | --- 18 | 19 | Online Random Forests 20 | ===================== 21 | 22 | This is the original implementation of the Online Random Forest algorithm [1]. There is a more recent implementation of this algorithm at https://github.com/amirsaffari/online-multiclass-lpboost which was used in [2]. 23 | 24 | Read the INSTALL file for build instructions. 25 | 26 | Usage: 27 | ====== 28 | Input arguments: 29 | 30 | -h | --help : will display this message. 31 | -c : path to the config file. 32 | 33 | --ort : use Online Random Tree (ORT) algorithm. 34 | --orf : use Online Random Forest (ORF) algorithm. 35 | 36 | 37 | --train : train the classifier. 38 | --test : test the classifier. 39 | --t2 : train and test the classifier at the same time. 40 | 41 | 42 | Examples: 43 | ./Online-Forest -c conf/orf.conf --orf --t2 44 | 45 | Config file: 46 | ============ 47 | All the settings for the classifier are passed via the config file. You can find the 48 | config file in the "conf" folder.
The meaning of each of 49 | these settings is straightforward: 50 | Data: 51 | * trainData = path to the training file 52 | * testData = path to the test file 53 | 54 | Tree: 55 | * maxDepth = maximum depth for a tree 56 | * numRandomTests = number of random tests for each node 57 | * numProjectionFeatures = number of features for hyperplane tests 58 | * counterThreshold = number of samples to be seen for an online node before splitting 59 | 60 | Forest: 61 | * numTrees = number of trees in the forest 62 | * numEpochs = number of online training epochs 63 | * useSoftVoting = boolean flag: non-zero for soft voting, 0 for hard voting 64 | 65 | Output: 66 | * savePath = path to save the results (not implemented yet) 67 | * verbose = defines the verbosity level (0: silence) 68 | 69 | Data format: 70 | ============ 71 | The data format used is very similar to the LIBSVM file format. It only needs 72 | one header line which contains the following information: 73 | \#Samples \#Features \#Classes \#FeatureMinIndex 74 | 75 | where 76 | 77 | \#Samples: number of samples 78 | 79 | \#Features: number of features 80 | 81 | \#Classes: number of classes 82 | 83 | \#FeatureMinIndex: the index of the first feature used 84 | 85 | You can find a few datasets in the data folder; check their headers to see some examples. 86 | Currently, there is only one limitation on the data files: the class labels must be 87 | consecutive integers starting from 0. For example, for a 3-class problem 88 | the labels should be in the set {0, 1, 2}. 89 | 90 | =========== 91 | REFERENCES: 92 | =========== 93 | [1] Amir Saffari, Christian Leistner, Jakob Santner, Martin Godec, and Horst Bischof, 94 | "On-line Random Forests," 95 | 3rd IEEE ICCV Workshop on On-line Computer Vision, 2009.
96 | 97 | [2] Amir Saffari, Martin Godec, Thomas Pock, Christian Leistner, Horst Bischof, 98 | “Online Multi-Class LPBoost“, 99 | Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2010. 100 | -------------------------------------------------------------------------------- /online_forests_ros/config/orf.conf: -------------------------------------------------------------------------------- 1 | Data: 2 | { 3 | trainData = "data/dna-train.libsvm"; 4 | trainLabels = "data/train.label"; 5 | testData = "data/dna-test.libsvm"; 6 | testLabels = "data/test.label"; 7 | }; 8 | Tree: 9 | { 10 | maxDepth = 50; 11 | numRandomTests = 30; 12 | numProjectionFeatures = 2; 13 | counterThreshold = 50; 14 | }; 15 | Forest: 16 | { 17 | numTrees = 100; 18 | numEpochs = 20; 19 | useSoftVoting = 1; 20 | }; 21 | Output: 22 | { 23 | savePath = "/tmp/online-forest-"; 24 | verbose = 1; // 0 = None 25 | }; 26 | -------------------------------------------------------------------------------- /online_forests_ros/doc/2009-OnlineRandomForests.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/online_forests_ros/doc/2009-OnlineRandomForests.pdf -------------------------------------------------------------------------------- /online_forests_ros/include/online_forests/classifier.h: -------------------------------------------------------------------------------- 1 | #ifndef CLASSIFIER_H_ 2 | #define CLASSIFIER_H_ 3 | 4 | #include 5 | 6 | #include "online_forests/data.h" 7 | 8 | using namespace std; 9 | 10 | class Classifier { 11 | public: 12 | virtual void update(Sample &sample) = 0; 13 | virtual void train(DataSet &dataset) = 0; 14 | virtual Result eval(Sample &sample) = 0; 15 | virtual vector test(DataSet & dataset) = 0; 16 | virtual vector trainAndTest(DataSet &dataset_tr, DataSet &dataset_ts) = 0; 17 | 18 | double 
compError(const vector &results, const DataSet &dataset) { 19 | double error = 0.0; 20 | for (int i = 0; i < dataset.m_numSamples; i++) { 21 | if (results[i].prediction != dataset.m_samples[i].y) { 22 | error++; 23 | } 24 | } 25 | 26 | return error / dataset.m_numSamples; 27 | } 28 | 29 | void dispErrors(const vector &errors) { 30 | for (int i = 0; i < (int) errors.size(); i++) { 31 | cout << i + 1 << ": " << errors[i] << " --- "; 32 | } 33 | cout << endl; 34 | } 35 | }; 36 | 37 | #endif /* CLASSIFIER_H_ */ 38 | -------------------------------------------------------------------------------- /online_forests_ros/include/online_forests/data.h: -------------------------------------------------------------------------------- 1 | #ifndef DATA_H_ 2 | #define DATA_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace std; 10 | using namespace gmm; 11 | 12 | // TYPEDEFS 13 | typedef int Label; 14 | typedef double Weight; 15 | typedef rsvector SparseVector; 16 | 17 | // DATA CLASSES 18 | class Sample { 19 | public: 20 | SparseVector x; 21 | Label y; 22 | Weight w; 23 | 24 | void disp() { 25 | cout << "Sample: y = " << y << ", w = " << w << ", x = "; 26 | cout << x << endl; 27 | } 28 | }; 29 | 30 | class DataSet { 31 | public: 32 | vector m_samples; 33 | int m_numSamples; 34 | int m_numFeatures; 35 | int m_numClasses; 36 | 37 | vector m_minFeatRange; 38 | vector m_maxFeatRange; 39 | 40 | void findFeatRange(); 41 | 42 | void loadLIBSVM(string filename); 43 | void loadLIBSVM2(string data); 44 | }; 45 | 46 | class Result { 47 | public: 48 | vector confidence; 49 | int prediction; 50 | }; 51 | 52 | #endif /* DATA_H_ */ 53 | -------------------------------------------------------------------------------- /online_forests_ros/include/online_forests/hyperparameters.h: -------------------------------------------------------------------------------- 1 | #ifndef HYPERPARAMETERS_H_ 2 | #define HYPERPARAMETERS_H_ 3 | 4 | #include 5 | using namespace std; 6 | 
7 | class Hyperparameters 8 | { 9 | public: 10 | Hyperparameters(); 11 | Hyperparameters(const string& confFile); 12 | 13 | // Online node 14 | int numRandomTests; 15 | int numProjectionFeatures; 16 | int counterThreshold; 17 | int maxDepth; 18 | 19 | // Online tree 20 | 21 | // Online forest 22 | int numTrees; 23 | int useSoftVoting; 24 | int numEpochs; 25 | 26 | // Data 27 | string trainData; 28 | string testData; 29 | 30 | // Output 31 | int verbose; 32 | }; 33 | 34 | #endif /* HYPERPARAMETERS_H_ */ 35 | -------------------------------------------------------------------------------- /online_forests_ros/include/online_forests/onlinerf.h: -------------------------------------------------------------------------------- 1 | #ifndef ONLINERF_H_ 2 | #define ONLINERF_H_ 3 | 4 | #include "online_forests/classifier.h" 5 | #include "online_forests/data.h" 6 | #include "online_forests/hyperparameters.h" 7 | #include "online_forests/onlinetree.h" 8 | #include "online_forests/utilities.h" 9 | 10 | class OnlineRF: public Classifier { 11 | public: 12 | OnlineRF(const Hyperparameters &hp, const int &numClasses, const int &numFeatures, const vector &minFeatRange, 13 | const vector &maxFeatRange) : 14 | m_numClasses(&numClasses), m_counter(0.0), m_oobe(0.0), m_hp(&hp) { 15 | OnlineTree *tree; 16 | for (int i = 0; i < hp.numTrees; i++) { 17 | tree = new OnlineTree(hp, numClasses, numFeatures, minFeatRange, maxFeatRange); 18 | m_trees.push_back(tree); 19 | } 20 | } 21 | 22 | ~OnlineRF() { 23 | for (int i = 0; i < m_hp->numTrees; i++) { 24 | delete m_trees[i]; 25 | } 26 | } 27 | 28 | virtual void update(Sample &sample) { 29 | m_counter += sample.w; 30 | 31 | Result result, treeResult; 32 | for (int i = 0; i < *m_numClasses; i++) { 33 | result.confidence.push_back(0.0); 34 | } 35 | 36 | int numTries; 37 | for (int i = 0; i < m_hp->numTrees; i++) { 38 | numTries = poisson(1.0); 39 | if (numTries) { 40 | for (int n = 0; n < numTries; n++) { 41 | m_trees[i]->update(sample); 42 | } 43 | 
} else { 44 | treeResult = m_trees[i]->eval(sample); 45 | if (m_hp->useSoftVoting) { 46 | add(treeResult.confidence, result.confidence); 47 | } else { 48 | result.confidence[treeResult.prediction]++; 49 | } 50 | } 51 | } 52 | 53 | if (argmax(result.confidence) != sample.y) { 54 | m_oobe += sample.w; 55 | } 56 | } 57 | 58 | virtual void train(DataSet &dataset) { 59 | vector randIndex; 60 | int sampRatio = dataset.m_numSamples / 10; 61 | for (int n = 0; n < m_hp->numEpochs; n++) { 62 | randPerm(dataset.m_numSamples, randIndex); 63 | for (int i = 0; i < dataset.m_numSamples; i++) { 64 | update(dataset.m_samples[randIndex[i]]); 65 | if (m_hp->verbose >= 1 && (i % sampRatio) == 0) { 66 | //cout << "--- Online Random Forest training --- Epoch: " << n + 1 << " --- "; 67 | //cout << (10 * i) / sampRatio << "%" << endl; 68 | } 69 | } 70 | cout << "--- Online Random Forest training --- Epoch: " << n + 1 << " --- " << endl; 71 | } 72 | } 73 | 74 | virtual Result eval(Sample &sample) { 75 | Result result, treeResult; 76 | for (int i = 0; i < *m_numClasses; i++) { 77 | result.confidence.push_back(0.0); 78 | } 79 | 80 | for (int i = 0; i < m_hp->numTrees; i++) { 81 | treeResult = m_trees[i]->eval(sample); 82 | if (m_hp->useSoftVoting) { 83 | add(treeResult.confidence, result.confidence); 84 | } else { 85 | result.confidence[treeResult.prediction]++; 86 | } 87 | } 88 | 89 | scale(result.confidence, 1.0 / m_hp->numTrees); 90 | result.prediction = argmax(result.confidence); 91 | return result; 92 | } 93 | 94 | virtual vector test(DataSet &dataset) { 95 | vector results; 96 | for (int i = 0; i < dataset.m_numSamples; i++) { 97 | results.push_back(eval(dataset.m_samples[i])); 98 | } 99 | 100 | double error = compError(results, dataset); 101 | if (m_hp->verbose >= 1) { 102 | cout << "--- Online Random Forest test error: " << error << endl; 103 | } 104 | 105 | return results; 106 | } 107 | 108 | virtual vector trainAndTest(DataSet &dataset_tr, DataSet &dataset_ts) { 109 | vector 
results; 110 | vector randIndex; 111 | int sampRatio = dataset_tr.m_numSamples / 10; 112 | vector testError; 113 | for (int n = 0; n < m_hp->numEpochs; n++) { 114 | randPerm(dataset_tr.m_numSamples, randIndex); 115 | for (int i = 0; i < dataset_tr.m_numSamples; i++) { 116 | update(dataset_tr.m_samples[randIndex[i]]); 117 | if (m_hp->verbose >= 1 && (i % sampRatio) == 0) { 118 | cout << "--- Online Random Forest training --- Epoch: " << n + 1 << " --- "; 119 | cout << (10 * i) / sampRatio << "%" << endl; 120 | } 121 | } 122 | 123 | results = test(dataset_ts); 124 | testError.push_back(compError(results, dataset_ts)); 125 | } 126 | 127 | if (m_hp->verbose >= 1) { 128 | cout << endl << "--- Online Random Forest test error over epochs: "; 129 | dispErrors(testError); 130 | } 131 | 132 | return results; 133 | } 134 | 135 | virtual void writeForest(string fileName) { 136 | cout<<"Writing forest"<numTrees; i++) { 140 | m_trees[i]->writeTree(fp); 141 | } 142 | fclose(fp); 143 | cout<<"Writing forest done"<numTrees; i++) { 151 | m_trees[i]->loadTree(fp, i); 152 | } 153 | fclose(fp); 154 | cout<<"Loading forest done"< m_trees; 165 | }; 166 | 167 | #endif /* ONLINERF_H_ */ 168 | -------------------------------------------------------------------------------- /online_forests_ros/include/online_forests/onlinetree.h: -------------------------------------------------------------------------------- 1 | #ifndef ONLINETREE_H_ 2 | #define ONLINETREE_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include "online_forests/classifier.h" 8 | #include "online_forests/data.h" 9 | #include "online_forests/hyperparameters.h" 10 | #include "online_forests/onlinenode.h" 11 | 12 | using namespace std; 13 | 14 | class OnlineTree: public Classifier { 15 | public: 16 | OnlineTree(const Hyperparameters &hp, const int &numClasses, const int &numFeatures, const vector &minFeatRange, 17 | const vector &maxFeatRange) : 18 | m_counter(0.0), m_hp(&hp) { 19 | m_rootNode = new OnlineNode(hp, numClasses, 
numFeatures, minFeatRange, maxFeatRange, 0); 20 | } 21 | 22 | ~OnlineTree() { 23 | delete m_rootNode; 24 | } 25 | 26 | virtual void update(Sample &sample) { 27 | m_rootNode->update(sample); 28 | } 29 | 30 | virtual void train(DataSet &dataset) { 31 | vector randIndex; 32 | int sampRatio = dataset.m_numSamples / 10; 33 | for (int n = 0; n < m_hp->numEpochs; n++) { 34 | randPerm(dataset.m_numSamples, randIndex); 35 | for (int i = 0; i < dataset.m_numSamples; i++) { 36 | update(dataset.m_samples[randIndex[i]]); 37 | if (m_hp->verbose >= 3 && (i % sampRatio) == 0) { 38 | cout << "--- Online Random Tree training --- Epoch: " << n + 1 << " --- "; 39 | cout << (10 * i) / sampRatio << "%" << endl; 40 | } 41 | } 42 | } 43 | } 44 | 45 | virtual Result eval(Sample &sample) { 46 | return m_rootNode->eval(sample); 47 | } 48 | 49 | virtual vector test(DataSet &dataset) { 50 | vector results; 51 | for (int i = 0; i < dataset.m_numSamples; i++) { 52 | results.push_back(eval(dataset.m_samples[i])); 53 | } 54 | 55 | double error = compError(results, dataset); 56 | if (m_hp->verbose >= 3) { 57 | cout << "--- Online Random Tree test error: " << error << endl; 58 | } 59 | 60 | return results; 61 | } 62 | 63 | virtual vector trainAndTest(DataSet &dataset_tr, DataSet &dataset_ts) { 64 | vector results; 65 | vector randIndex; 66 | int sampRatio = dataset_tr.m_numSamples / 10; 67 | vector testError; 68 | for (int n = 0; n < m_hp->numEpochs; n++) { 69 | randPerm(dataset_tr.m_numSamples, randIndex); 70 | for (int i = 0; i < dataset_tr.m_numSamples; i++) { 71 | update(dataset_tr.m_samples[randIndex[i]]); 72 | if (m_hp->verbose >= 3 && (i % sampRatio) == 0) { 73 | cout << "--- Online Random Tree training --- Epoch: " << n + 1 << " --- "; 74 | cout << (10 * i) / sampRatio << "%" << endl; 75 | } 76 | } 77 | 78 | results = test(dataset_ts); 79 | testError.push_back(compError(results, dataset_ts)); 80 | } 81 | 82 | if (m_hp->verbose >= 3) { 83 | cout << endl << "--- Online Random Tree test error 
over epochs: "; 84 | dispErrors(testError); 85 | } 86 | 87 | return results; 88 | } 89 | 90 | virtual void writeTree(FILE * fp) { 91 | m_rootNode->writeNode(m_rootNode, m_counter, fp); 92 | fprintf(fp,"T\n"); 93 | } 94 | 95 | virtual void loadTree(FILE * fp, int tree_index) { 96 | m_rootNode->loadNode(m_rootNode, fp); 97 | return; 98 | } 99 | 100 | private: 101 | double m_counter; 102 | const Hyperparameters *m_hp; 103 | 104 | OnlineNode* m_rootNode; 105 | }; 106 | 107 | #endif /* ONLINETREE_H_ */ 108 | -------------------------------------------------------------------------------- /online_forests_ros/include/online_forests/randomtest.h: -------------------------------------------------------------------------------- 1 | #ifndef RANDOMTEST_H_ 2 | #define RANDOMTEST_H_ 3 | 4 | #include "online_forests/data.h" 5 | #include "online_forests/utilities.h" 6 | 7 | class RandomTest { 8 | public: 9 | RandomTest() { 10 | 11 | } 12 | 13 | RandomTest(const int &numClasses) : 14 | m_numClasses(&numClasses), m_trueCount(0.0), m_falseCount(0.0) { 15 | for (int i = 0; i < numClasses; i++) { 16 | m_trueStats.push_back(0.0); 17 | m_falseStats.push_back(0.0); 18 | } 19 | m_threshold = randomFromRange(-1, 1); 20 | } 21 | 22 | RandomTest(const int &numClasses, const double featMin, const double featMax) : 23 | m_numClasses(&numClasses), m_trueCount(0.0), m_falseCount(0.0) { 24 | for (int i = 0; i < numClasses; i++) { 25 | m_trueStats.push_back(0.0); 26 | m_falseStats.push_back(0.0); 27 | } 28 | m_threshold = randomFromRange(featMin, featMax); 29 | } 30 | 31 | void updateStats(const Sample &sample, const bool decision) { 32 | if (decision) { 33 | m_trueCount += sample.w; 34 | m_trueStats[sample.y] += sample.w; 35 | } else { 36 | m_falseCount += sample.w; 37 | m_falseStats[sample.y] += sample.w; 38 | } 39 | } 40 | 41 | double score() { 42 | double totalCount = m_trueCount + m_falseCount; 43 | 44 | // Split Entropy 45 | double p, splitEntropy = 0.0; 46 | if (m_trueCount) { 47 | p = 
m_trueCount / totalCount; 48 | splitEntropy -= p * log2(p); 49 | } 50 | if (m_trueCount) { 51 | p = m_trueCount / totalCount; 52 | splitEntropy -= p * log2(p); 53 | } 54 | 55 | // Prior Entropy 56 | double priorEntropy = 0.0; 57 | for (int i = 0; i < *m_numClasses; i++) { 58 | p = (m_trueStats[i] + m_falseStats[i]) / totalCount; 59 | if (p) { 60 | priorEntropy -= p * log2(p); 61 | } 62 | } 63 | 64 | // Posterior Entropy 65 | double trueScore = 0.0, falseScore = 0.0; 66 | if (m_trueCount) { 67 | for (int i = 0; i < *m_numClasses; i++) { 68 | p = m_trueStats[i] / m_trueCount; 69 | if (p) { 70 | trueScore -= p * log2(p); 71 | } 72 | } 73 | } 74 | if (m_falseCount) { 75 | for (int i = 0; i < *m_numClasses; i++) { 76 | p = m_falseStats[i] / m_falseCount; 77 | if (p) { 78 | falseScore -= p * log2(p); 79 | } 80 | } 81 | } 82 | double posteriorEntropy = (m_trueCount * trueScore + m_falseCount * falseScore) / totalCount; 83 | 84 | // Information Gain 85 | return (2.0 * (priorEntropy - posteriorEntropy)) / (priorEntropy * splitEntropy + 1e-10); 86 | } 87 | 88 | pair , vector > getStats() { 89 | return pair , vector > (m_trueStats, m_falseStats); 90 | } 91 | 92 | protected: 93 | const int *m_numClasses; 94 | double m_threshold; 95 | double m_trueCount; 96 | double m_falseCount; 97 | vector m_trueStats; 98 | vector m_falseStats; 99 | }; 100 | 101 | class HyperplaneFeature: public RandomTest { 102 | public: 103 | HyperplaneFeature() { 104 | 105 | } 106 | 107 | HyperplaneFeature(const int &numClasses, const int &numFeatures, const int &numProjFeatures, const vector &minFeatRange, 108 | const vector &maxFeatRange) : 109 | RandomTest(numClasses), m_numProjFeatures(&numProjFeatures) { 110 | randPerm(numFeatures, numProjFeatures, m_features); 111 | fillWithRandomNumbers(numProjFeatures, m_weights); 112 | 113 | // Find min and max range of the projection 114 | double minRange = 0.0, maxRange = 0.0; 115 | for (int i = 0; i < numProjFeatures; i++) { 116 | minRange += 
minFeatRange[m_features[i]] * m_weights[i]; 117 | maxRange += maxFeatRange[m_features[i]] * m_weights[i]; 118 | } 119 | 120 | m_threshold = randomFromRange(minRange, maxRange); 121 | } 122 | 123 | void update(Sample &sample) { 124 | updateStats(sample, eval(sample)); 125 | } 126 | 127 | bool eval(Sample &sample) { 128 | double proj = 0.0; 129 | for (int i = 0; i < *m_numProjFeatures; i++) { 130 | proj += sample.x[m_features[i]] * m_weights[i]; 131 | } 132 | 133 | return (proj > m_threshold) ? true : false; 134 | } 135 | 136 | void writeTest(FILE *fp){ 137 | fprintf(fp," %lf",m_threshold); 138 | fprintf(fp," %lf",m_trueCount); 139 | fprintf(fp," %lf",m_falseCount); 140 | 141 | for (int i = 0; i < *m_numClasses; i++) { 142 | fprintf(fp," %lf",m_trueStats[i]); 143 | fprintf(fp," %lf",m_falseStats[i]); 144 | } 145 | 146 | for (int i = 0; i < *m_numProjFeatures; i++) { 147 | fprintf(fp," %d",m_features[i]); 148 | fprintf(fp," %lf",m_weights[i]); 149 | } 150 | } 151 | 152 | void loadTest(FILE *fp){ 153 | fscanf(fp, "%lf ", &m_threshold); 154 | fscanf(fp, "%lf ", &m_trueCount); 155 | fscanf(fp, "%lf ", &m_falseCount); 156 | 157 | for (int i = 0; i < *m_numClasses; i++) { 158 | fscanf(fp, "%lf ", &m_trueStats[i]); 159 | fscanf(fp, "%lf ", &m_falseStats[i]); 160 | } 161 | 162 | for (int i = 0; i < *m_numProjFeatures; i++) { 163 | fscanf(fp, "%d ", &m_features[i]); 164 | fscanf(fp, "%lf ", &m_weights[i]); 165 | } 166 | } 167 | 168 | private: 169 | const int *m_numProjFeatures; 170 | vector m_features; 171 | vector m_weights; 172 | }; 173 | 174 | #endif /* RANDOMTEST_H_ */ 175 | -------------------------------------------------------------------------------- /online_forests_ros/include/online_forests/utilities.h: -------------------------------------------------------------------------------- 1 | #ifndef UTILITIES_H_ 2 | #define UTILITIES_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #ifndef WIN32 12 | #include 13 | 
#endif 14 | 15 | using namespace std; 16 | 17 | // Random Numbers Generators 18 | unsigned int getDevRandom(); 19 | 20 | //! Returns a random number in [0, 1] 21 | inline double randDouble() { 22 | static bool didSeeding = false; 23 | 24 | if (!didSeeding) { 25 | #ifdef WIN32 26 | srand(0); 27 | #else 28 | unsigned int seedNum; 29 | struct timeval TV; 30 | unsigned int curTime; 31 | 32 | gettimeofday(&TV, NULL); 33 | curTime = (unsigned int) TV.tv_usec; 34 | seedNum = (unsigned int) time(NULL) + curTime + getpid() + getDevRandom(); 35 | 36 | srand(seedNum); 37 | #endif 38 | didSeeding = true; 39 | } 40 | return rand() / (RAND_MAX + 1.0); 41 | } 42 | 43 | //! Returns a random number in [min, max] 44 | inline double randomFromRange(const double &minRange, const double &maxRange) { 45 | return minRange + (maxRange - minRange) * randDouble(); 46 | } 47 | 48 | //! Random permutations 49 | void randPerm(const int &inNum, vector &outVect); 50 | void randPerm(const int &inNum, const int inPart, vector &outVect); 51 | 52 | inline void fillWithRandomNumbers(const int &length, vector &inVect) { 53 | inVect.clear(); 54 | for (int i = 0; i < length; i++) { 55 | inVect.push_back(2.0 * (randDouble() - 0.5)); 56 | } 57 | } 58 | 59 | inline int argmax(const vector &inVect) { 60 | double maxValue = inVect[0]; 61 | int maxIndex = 0, i = 1; 62 | vector::const_iterator itr(inVect.begin() + 1), end(inVect.end()); 63 | while (itr != end) { 64 | if (*itr > maxValue) { 65 | maxValue = *itr; 66 | maxIndex = i; 67 | } 68 | 69 | ++i; 70 | ++itr; 71 | } 72 | 73 | return maxIndex; 74 | } 75 | 76 | inline double sum(const vector &inVect) { 77 | double val = 0.0; 78 | vector::const_iterator itr(inVect.begin()), end(inVect.end()); 79 | while (itr != end) { 80 | val += *itr; 81 | ++itr; 82 | } 83 | 84 | return val; 85 | } 86 | 87 | //! 
Poisson sampling 88 | inline int poisson(double A) { 89 | int k = 0; 90 | int maxK = 10; 91 | while (1) { 92 | double U_k = randDouble(); 93 | A *= U_k; 94 | if (k > maxK || A < exp(-1.0)) { 95 | break; 96 | } 97 | k++; 98 | } 99 | return k; 100 | } 101 | 102 | #endif /* UTILITIES_H_ */ 103 | -------------------------------------------------------------------------------- /online_forests_ros/launch/online_forests_ros.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /online_forests_ros/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | online_forests_ros 4 | 0.1.2 5 | 6 | A package that ROSifies the Online Random Forests. 7 | 8 | Rui Yang 9 | Zhi Yan 10 | MIT 11 | 12 | https://github.com/amirsaffari/online-random-forests 13 | Amir Saffari 14 | 15 | catkin 16 | 17 | roscpp 18 | 19 | libconfig++-dev 20 | libgmm++-dev 21 | libatlas-base-dev 22 | libblas-dev 23 | liblapack-dev 24 | autoware_tracker 25 | std_msgs 26 | 27 | autoware_tracker 28 | std_msgs 29 | 30 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/Online-Forest.cpp: -------------------------------------------------------------------------------- 1 | #define GMM_USES_BLAS 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "online_forests/data.h" 11 | #include "online_forests/onlinetree.h" 12 | #include "online_forests/onlinerf.h" 13 | 14 | using namespace std; 15 | using namespace libconfig; 16 | 17 | typedef enum { 18 | ORT, ORF 19 | } CLASSIFIER_TYPE; 20 | 21 | //! 
Prints the interface help message 22 | void help() { 23 | cout << endl; 24 | cout << "OnlineForest Classification Package:" << endl; 25 | cout << "Input arguments:" << endl; 26 | cout << "\t -h | --help : \t will display this message." << endl; 27 | cout << "\t -c : \t\t path to the config file." << endl << endl; 28 | cout << "\t --ort : \t use Online Random Tree (ORT) algorithm." << endl; 29 | cout << "\t --orf : \t use Online Random Forest (ORF) algorithm." << endl; 30 | cout << endl << endl; 31 | cout << "\t --train : \t train the classifier." << endl; 32 | cout << "\t --test : \t test the classifier." << endl; 33 | cout << "\t --t2 : \t train and test the classifier at the same time." << endl; 34 | cout << endl << endl; 35 | cout << "\tExamples:" << endl; 36 | cout << "\t ./Online-Forest -c conf/orf.conf --orf --train --test" << endl; 37 | } 38 | 39 | //! Returns the time (in whole seconds, from time()) elapsed between two calls to this function; returns -1 whenever it (re)starts the timer 40 | double timeIt(int reset) { 41 | static time_t startTime, endTime; 42 | static int timerWorking = 0; 43 | 44 | if (reset) { 45 | startTime = time(NULL); 46 | timerWorking = 1; 47 | return -1; 48 | } else { 49 | if (timerWorking) { 50 | endTime = time(NULL); 51 | timerWorking = 0; 52 | return (double) (endTime - startTime); 53 | } else { 54 | startTime = time(NULL); 55 | timerWorking = 1; 56 | return -1; 57 | } 58 | } 59 | } 60 | //! Command-line driver: parses options, loads the data sets named in the config file and runs the selected classifier 61 | int main(int argc, char *argv[]) { 62 | // Parsing command line 63 | string confFileName; 64 | int classifier = -1, doTraining = false, doTesting = false, doT2 = false, inputCounter = 1; 65 | 66 | if (argc == 1) { 67 | cout << "\tNo input argument specified: aborting."
<< endl; 68 | help(); 69 | exit(EXIT_SUCCESS); 70 | } 71 | 72 | while (inputCounter < argc) { 73 | if (!strcmp(argv[inputCounter], "-h") || !strcmp(argv[inputCounter], "--help")) { 74 | help(); 75 | return EXIT_SUCCESS; 76 | } else if (!strcmp(argv[inputCounter], "-c")) { 77 | if (inputCounter + 1 >= argc) { cout << "\tMissing value for -c, please try --help for more information." << endl; exit(EXIT_FAILURE); } confFileName = argv[++inputCounter]; 78 | } else if (!strcmp(argv[inputCounter], "--ort")) { 79 | classifier = ORT; 80 | } else if (!strcmp(argv[inputCounter], "--orf")) { 81 | classifier = ORF; 82 | } else if (!strcmp(argv[inputCounter], "--train")) { 83 | doTraining = true; 84 | } else if (!strcmp(argv[inputCounter], "--test")) { 85 | doTesting = true; 86 | } else if (!strcmp(argv[inputCounter], "--t2")) { 87 | doT2 = true; 88 | } else { 89 | cout << "\tUnknown input argument: " << argv[inputCounter]; 90 | cout << ", please try --help for more information." << endl; 91 | exit(EXIT_FAILURE); 92 | } 93 | 94 | inputCounter++; 95 | } 96 | 97 | cout << "OnlineMCBoost Classification Package:" << endl; 98 | 99 | if (!doTraining && !doTesting && !doT2) { 100 | cout << "\tNothing to do, no training, no testing !!!"
<< endl; 101 | exit(EXIT_FAILURE); 102 | } 103 | if (classifier != ORT && classifier != ORF) { cout << "\tNo classifier selected, please use --ort or --orf." << endl; exit(EXIT_FAILURE); } 104 | if (doT2) { 105 | doTraining = false; 106 | doTesting = false; 107 | } 108 | 109 | // Load the hyperparameters 110 | Hyperparameters hp(confFileName); 111 | 112 | // Creating the train data 113 | DataSet dataset_tr, dataset_ts; 114 | dataset_tr.loadLIBSVM(hp.trainData); 115 | if (doT2 || doTesting) { 116 | dataset_ts.loadLIBSVM(hp.testData); 117 | } 118 | 119 | // Calling training/testing 120 | switch (classifier) { 121 | case ORT: { 122 | OnlineTree model(hp, dataset_tr.m_numClasses, dataset_tr.m_numFeatures, dataset_tr.m_minFeatRange, dataset_tr.m_maxFeatRange); 123 | if (doT2) { 124 | timeIt(1); 125 | model.trainAndTest(dataset_tr, dataset_ts); 126 | cout << "Training/Test time: " << timeIt(0) << endl; 127 | } 128 | if (doTraining) { 129 | timeIt(1); 130 | model.train(dataset_tr); 131 | cout << "Training time: " << timeIt(0) << endl; 132 | } if (doTesting) { // independent of --train, as in the ORF branch 133 | timeIt(1); 134 | model.test(dataset_ts); 135 | cout << "Test time: " << timeIt(0) << endl; 136 | } 137 | break; 138 | } 139 | case ORF: { 140 | OnlineRF model(hp, dataset_tr.m_numClasses, dataset_tr.m_numFeatures, dataset_tr.m_minFeatRange, dataset_tr.m_maxFeatRange); 141 | if (doT2) { 142 | timeIt(1); 143 | model.trainAndTest(dataset_tr, dataset_ts); 144 | cout << "Training/Test time: " << timeIt(0) << endl; 145 | } 146 | if (doTraining) { 147 | timeIt(1); 148 | model.train(dataset_tr); 149 | cout << "Training time: " << timeIt(0) << endl; 150 | } 151 | if (doTesting) { 152 | timeIt(1); 153 | model.test(dataset_ts); 154 | cout << "Test time: " << timeIt(0) << endl; 155 | } 156 | break; 157 | } 158 | } 159 | 160 | return EXIT_SUCCESS; 161 | } 162 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/classifier.cpp: -------------------------------------------------------------------------------- 1 | #include "online_forests/classifier.h" 2 |
-------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/data.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "online_forests/data.h" 6 | 7 | using namespace std; 8 | //! Recomputes the per-feature [min, max] over all samples into m_minFeatRange / m_maxFeatRange 9 | void DataSet::findFeatRange() { 10 | double minVal, maxVal; m_minFeatRange.clear(); m_maxFeatRange.clear(); // start fresh so repeated loads do not append stale ranges 11 | for (int i = 0; i < m_numFeatures; i++) { 12 | if (m_samples.empty()) { m_minFeatRange.push_back(0.0); m_maxFeatRange.push_back(0.0); continue; } minVal = m_samples[0].x[i]; // no samples: one zero range per feature instead of reading m_samples[0] 13 | maxVal = m_samples[0].x[i]; 14 | for (int n = 1; n < m_numSamples; n++) { 15 | if (m_samples[n].x[i] < minVal) { 16 | minVal = m_samples[n].x[i]; 17 | } 18 | if (m_samples[n].x[i] > maxVal) { 19 | maxVal = m_samples[n].x[i]; 20 | } 21 | } 22 | 23 | m_minFeatRange.push_back(minVal); 24 | m_maxFeatRange.push_back(maxVal); 25 | } 26 | } 27 | //! Loads a LIBSVM-style file whose header holds numSamples numFeatures numClasses startIndex; exits on I/O or sample-count errors 28 | void DataSet::loadLIBSVM(string filename) { 29 | ifstream fp(filename.c_str(), ios::binary); 30 | if (!fp) { 31 | cout << "Could not open input file " << filename << endl; 32 | exit(EXIT_FAILURE); 33 | } 34 | 35 | cout << "Loading data file: " << filename << " ...
" << endl; 36 | 37 | // Reading the header 38 | int startIndex; 39 | fp >> m_numSamples; 40 | fp >> m_numFeatures; 41 | fp >> m_numClasses; 42 | fp >> startIndex; 43 | 44 | // Reading the data 45 | string line, tmpStr; 46 | size_t colonPos; int colIndex; 47 | m_samples.clear(); 48 | 49 | for (int i = 0; i < m_numSamples; i++) { 50 | wsvector x(m_numFeatures); 51 | Sample sample; 52 | resize(sample.x, m_numFeatures); 53 | if (!(fp >> sample.y)) { break; } // read label; stop on EOF/malformed input so the sample-count check below can report it 54 | sample.w = 1.0; // set weight 55 | 56 | getline(fp, line); // read the rest of the line 57 | istringstream tokens(line); // whitespace-separated "index:value" pairs; tolerant of leading/trailing spaces 58 | while (tokens >> tmpStr) { 59 | colonPos = tmpStr.find(':'); 60 | if (colonPos == string::npos) { 61 | continue; // not an index:value pair 62 | } 63 | colIndex = atoi(tmpStr.substr(0, colonPos).c_str()) - startIndex; 64 | if (colIndex < 0 || colIndex >= m_numFeatures) { 65 | continue; // out-of-range feature index: skip rather than write outside x 66 | } 67 | x[colIndex] = atof(tmpStr.substr(colonPos + 1).c_str()); 68 | } 69 | 70 | copy(x, sample.x); 71 | m_samples.push_back(sample); // push sample into dataset 72 | } 73 | 74 | fp.close(); 75 | 76 | if (m_numSamples != (int) m_samples.size()) { 77 | cout << "Could not load " << m_numSamples << " samples from " << filename; 78 | cout << ". There were only " << m_samples.size() << " samples!" << endl; 79 | exit(EXIT_FAILURE); 80 | } 81 | 82 | // Find the data range 83 | findFeatRange(); 84 | 85 | cout << "Loaded " << m_numSamples << " samples with " << m_numFeatures; 86 | cout << " features and " << m_numClasses << " classes."
<< endl; 87 | } 88 | //! Same as loadLIBSVM, but parses an in-memory string: a "numSamples numFeatures numClasses startIndex" header line, then one sample per line; exits on sample-count errors 89 | void DataSet::loadLIBSVM2(string data) { 90 | // Reading the header 91 | std::istringstream iss(data); 92 | string line; 93 | int startIndex; 94 | 95 | getline(iss, line, ' '); 96 | m_numSamples = atoi(line.c_str()); 97 | getline(iss, line, ' '); 98 | m_numFeatures = atoi(line.c_str()); 99 | getline(iss, line, ' '); 100 | m_numClasses = atoi(line.c_str()); 101 | getline(iss, line, '\n'); 102 | startIndex = atoi(line.c_str()); 103 | 104 | // Reading the data 105 | string tmpStr; 106 | size_t colonPos; int colIndex; 107 | m_samples.clear(); 108 | 109 | for (int i = 0; i < m_numSamples; i++) { 110 | wsvector x(m_numFeatures); 111 | Sample sample; 112 | resize(sample.x, m_numFeatures); 113 | // getline(iss, line); 114 | // sample.y = atoi(line.substr(line.find(' ')).c_str()); // read label 115 | getline(iss, line, ' '); 116 | sample.y = atoi(line.c_str()); // read label 117 | sample.w = 1.0; // set weight 118 | 119 | getline(iss, line); // rest of the line; the label's trailing space is already consumed, so the first pair starts at column 0 and must be parsed too 120 | istringstream tokens(line); // whitespace-separated "index:value" pairs 121 | while (tokens >> tmpStr) { 122 | colonPos = tmpStr.find(':'); 123 | if (colonPos == string::npos) { 124 | continue; // not an index:value pair 125 | } 126 | colIndex = atoi(tmpStr.substr(0, colonPos).c_str()) - startIndex; 127 | if (colIndex < 0 || colIndex >= m_numFeatures) { 128 | continue; // out-of-range feature index: skip rather than write outside x 129 | } 130 | x[colIndex] = atof(tmpStr.substr(colonPos + 1).c_str()); 131 | } 132 | 133 | copy(x, sample.x); 134 | m_samples.push_back(sample); // push sample into dataset 135 | } 136 | 137 | if (m_numSamples != (int) m_samples.size()) { 138 | cout << "Could not load " << m_numSamples << " samples"; 139 | cout << ". There were only " << m_samples.size() << " samples!" << endl; 140 | exit(EXIT_FAILURE); 141 | } 142 | 143 | // Find the data range 144 | findFeatRange(); 145 | 146 | cout << "Loaded " << m_numSamples << " samples with " << m_numFeatures; 147 | cout << " features and " << m_numClasses << " classes."
<< endl; 148 | } 149 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/hyperparameters.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "online_forests/hyperparameters.h" 5 | 6 | using namespace std; 7 | using namespace libconfig; 8 | //! Reads the tree, forest, data and output settings from the libconfig file confFile 9 | Hyperparameters::Hyperparameters(const string& confFile) { 10 | cout << "Loading config file: " << confFile << " ... "; 11 | 12 | Config configFile; 13 | configFile.readFile(confFile.c_str()); 14 | 15 | // Node/Tree 16 | maxDepth = configFile.lookup("Tree.maxDepth"); 17 | numRandomTests = configFile.lookup("Tree.numRandomTests"); 18 | numProjectionFeatures = configFile.lookup("Tree.numProjectionFeatures"); 19 | counterThreshold = configFile.lookup("Tree.counterThreshold"); 20 | 21 | // Forest 22 | numTrees = configFile.lookup("Forest.numTrees"); 23 | numEpochs = configFile.lookup("Forest.numEpochs"); 24 | useSoftVoting = configFile.lookup("Forest.useSoftVoting"); 25 | 26 | // Data 27 | trainData = (const char *) configFile.lookup("Data.trainData"); 28 | testData = (const char *) configFile.lookup("Data.testData"); 29 | 30 | // Output 31 | verbose = configFile.lookup("Output.verbose"); 32 | 33 | cout << "Done."
<< endl; 34 | } 35 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/onlinenode.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "online_forests/onlinenode.h" 4 | 5 | using namespace std; 6 | //! Adds a sample: a leaf updates its label stats and candidate tests and may split; an internal node forwards the sample to one child 7 | void OnlineNode::update(Sample &sample) { 8 | m_counter += sample.w; 9 | m_labelStats[sample.y] += sample.w; 10 | 11 | if (m_isLeaf) { 12 | // Update online tests 13 | for (int i = 0; i < m_hp->numRandomTests; i++) { 14 | m_onlineTests[i].update(sample); 15 | } 16 | 17 | // Update the label 18 | m_label = argmax(m_labelStats); 19 | 20 | // Decide for split 21 | if (shouldISplit()) { 22 | m_isLeaf = false; 23 | 24 | // Find the best online test 25 | int maxIndex = 0; 26 | double maxScore = -1e10, score; 27 | for (int i = 0; i < m_hp->numRandomTests; i++) { 28 | score = m_onlineTests[i].score(); 29 | if (score > maxScore) { 30 | maxScore = score; 31 | maxIndex = i; 32 | } 33 | } 34 | m_bestTest = m_onlineTests[maxIndex]; 35 | m_onlineTests.clear(); 36 | 37 | if (m_hp->verbose >= 4) { 38 | cout << "--- Splitting node --- best score: " << maxScore; 39 | cout << " by test number: " << maxIndex << endl; 40 | } 41 | 42 | // Split: the right child starts from the best test's true-side stats and later receives the samples it evaluates true; the left child gets the false side 43 | pair , vector > parentStats = m_bestTest.getStats(); 44 | m_rightChildNode = new OnlineNode(*m_hp, *m_numClasses, *m_numFeatures, *m_minFeatRange, *m_maxFeatRange, m_depth + 1, 45 | parentStats.first); 46 | m_leftChildNode = new OnlineNode(*m_hp, *m_numClasses, *m_numFeatures, *m_minFeatRange, *m_maxFeatRange, m_depth + 1, 47 | parentStats.second); 48 | } 49 | } else { 50 | if (m_bestTest.eval(sample)) { 51 | m_rightChildNode->update(sample); 52 | } else { 53 | m_leftChildNode->update(sample); 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/onlinerf.cpp:
-------------------------------------------------------------------------------- 1 | #include "online_forests/onlinerf.h" 2 | 3 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/onlinetree.cpp: -------------------------------------------------------------------------------- 1 | #include "online_forests/onlinetree.h" 2 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/randomtest.cpp: -------------------------------------------------------------------------------- 1 | #include "online_forests/randomtest.h" 2 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests/utilities.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #ifndef WIN32 7 | #include 8 | #endif 9 | 10 | #include "online_forests/utilities.h" 11 | 12 | using namespace std; 13 | //! Returns sizeof(unsigned int) bytes read from /dev/urandom packed into an unsigned int; 0 if the read fails 14 | unsigned int getDevRandom() { 15 | ifstream devFile("/dev/urandom", ios::binary); 16 | unsigned int outInt = 0; 17 | char tempChar[sizeof(outInt)]; 18 | 19 | devFile.read(tempChar, sizeof(outInt)); 20 | if (devFile.gcount() == (streamsize) sizeof(outInt)) { for (size_t b = 0; b < sizeof(outInt); b++) { outInt = (outInt << 8) | (unsigned char) tempChar[b]; } } // use the raw bytes; they are neither decimal text nor NUL-terminated, so atoi() cannot be used 21 | 22 | devFile.close(); 23 | 24 | return outInt; 25 | } 26 | //! Fills outVect with a uniformly random permutation of 0 .. inNum-1 (Fisher-Yates) 27 | void randPerm(const int &inNum, vector &outVect) { 28 | outVect.resize(inNum); 29 | int randIndex, tempIndex; 30 | for (int nFeat = 0; nFeat < inNum; nFeat++) { 31 | outVect[nFeat] = nFeat; 32 | } 33 | for (int nFeat = 0; nFeat < inNum; nFeat++) { 34 | randIndex = (int) floor(((double) inNum - nFeat) * randDouble()) + nFeat; 35 | if (randIndex == inNum) { 36 | randIndex--; 37 | } 38 | tempIndex = outVect[nFeat]; 39 | outVect[nFeat] = outVect[randIndex]; 40 | outVect[randIndex] = tempIndex; 41 | } 42 | } 43 | //! Returns the first inPart entries (clamped to [0, inNum]) of a random permutation of 0 .. inNum-1 (partial Fisher-Yates) 44 | void randPerm(const int &inNum, const int inPart, vector &outVect) { 45 | outVect.resize(inNum); 46 | int
randIndex, tempIndex; const int numPart = (inPart < 0) ? 0 : ((inPart < inNum) ? inPart : inNum); // never shuffle or keep more entries than exist 47 | for (int nFeat = 0; nFeat < inNum; nFeat++) { 48 | outVect[nFeat] = nFeat; 49 | } 50 | for (int nFeat = 0; nFeat < numPart; nFeat++) { 51 | randIndex = (int) floor(((double) inNum - nFeat) * randDouble()) + nFeat; 52 | if (randIndex == inNum) { 53 | randIndex--; 54 | } 55 | tempIndex = outVect[nFeat]; 56 | outVect[nFeat] = outVect[randIndex]; 57 | outVect[randIndex] = tempIndex; 58 | } 59 | 60 | outVect.erase(outVect.begin() + numPart, outVect.end()); 61 | } 62 | -------------------------------------------------------------------------------- /online_forests_ros/src/online_forests_ros.cpp: -------------------------------------------------------------------------------- 1 | // (c) 2020 Zhi Yan, Rui Yang 2 | // This code is licensed under MIT license (see LICENSE.txt for details) 3 | #define GMM_USES_BLAS 4 | 5 | // ROS 6 | #include 7 | #include 8 | // Online Random Forests 9 | #include "online_forests/onlinetree.h" 10 | #include "online_forests/onlinerf.h" 11 | //! ROS node: waits for LIBSVM-formatted feature messages and trains and/or tests an online random forest on each batch, logging the elapsed time 12 | int main(int argc, char **argv) { 13 | std::ofstream icra_log; 14 | std::string log_name = "orf_time_log_"+std::to_string(ros::WallTime::now().toSec()); 15 | 16 | std::string conf_file_name; 17 | std::string model_file_name; 18 | int mode; // 1 - train, 2 - test, 3 - train and test.
19 | int minimum_samples; 20 | int total_samples = 0; 21 | 22 | ros::init(argc, argv, "online_forests_ros"); 23 | ros::NodeHandle nh, private_nh("~"); 24 | 25 | if(private_nh.getParam("conf_file_name", conf_file_name)) { 26 | ROS_INFO("Got param 'conf_file_name': %s", conf_file_name.c_str()); 27 | } else { 28 | ROS_ERROR("Failed to get param 'conf_file_name'"); 29 | exit(EXIT_FAILURE); 30 | } 31 | 32 | if(private_nh.getParam("model_file_name", model_file_name)) { 33 | ROS_INFO("Got param 'model_file_name': %s", model_file_name.c_str()); 34 | } else { 35 | ROS_ERROR("Failed to get param 'model_file_name'"); 36 | exit(EXIT_FAILURE); 37 | } 38 | 39 | if(private_nh.getParam("mode", mode)) { 40 | ROS_INFO("Got param 'mode': %d", mode); 41 | } else { 42 | ROS_ERROR("Failed to get param 'mode'"); 43 | exit(EXIT_FAILURE); 44 | } 45 | 46 | private_nh.param("minimum_samples", minimum_samples, 1); 47 | 48 | Hyperparameters hp(conf_file_name); 49 | std_msgs::String::ConstPtr features; 50 | 51 | while (ros::ok()) { 52 | features = ros::topic::waitForMessage("/point_cloud_features/features"); // process blocked waiting 53 | if (!features) { break; } // waitForMessage() yields an empty pointer once ROS is shutting down 54 | // Creating the train data 55 | DataSet dataset_tr; 56 | dataset_tr.loadLIBSVM2(features->data); 57 | 58 | // Creating the test data 59 | DataSet dataset_ts; 60 | // dataset_ts.loadLIBSVM(hp.testData); 61 | 62 | if(dataset_tr.m_numSamples >= minimum_samples) { 63 | OnlineRF model(hp, dataset_tr.m_numClasses, dataset_tr.m_numFeatures, dataset_tr.m_minFeatRange, dataset_tr.m_maxFeatRange); // TOTEST: OnlineTree 64 | 65 | //string model_file_name = ""; 66 | icra_log.open(log_name, std::ofstream::out | std::ofstream::app); 67 | double start_time = ros::WallTime::now().toSec(); // double, not time_t: time_t truncated the start to whole seconds and skewed the measured duration 68 | 69 | switch(mode) { 70 | case 1: // train only 71 | if(access( model_file_name.c_str(), F_OK ) != -1){ 72 | model.loadForest(model_file_name); 73 | } 74 | model.train(dataset_tr); 75 | //model.writeForest(model_file_name); //turning off the
writing function 76 | break; 77 | case 2: // test only 78 | model.loadForest(model_file_name); 79 | model.test(dataset_ts); 80 | break; 81 | case 3: // train and test 82 | model.trainAndTest(dataset_tr, dataset_ts); 83 | break; 84 | default: 85 | ROS_ERROR("Unknown 'mode'"); 86 | } 87 | const double elapsed = ros::WallTime::now().toSec() - start_time; // measured once so the printed and logged durations agree 88 | std::cout << "[online_forests_ros] Training time: " << elapsed << " s" << std::endl; 89 | icra_log << (total_samples+=dataset_tr.m_numSamples) << " " << elapsed << "\n"; 90 | icra_log.close(); 91 | } 92 | 93 | ros::spinOnce(); 94 | } 95 | 96 | return EXIT_SUCCESS; 97 | } 98 | -------------------------------------------------------------------------------- /online_svm_ros/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(online_svm_ros) 3 | 4 | find_package(catkin REQUIRED COMPONENTS 5 | roscpp 6 | std_msgs 7 | ) 8 | 9 | find_package(yaml-cpp REQUIRED) 10 | 11 | include_directories( 12 | ${catkin_INCLUDE_DIRS} 13 | ) 14 | 15 | add_executable(online_svm_ros 16 | src/online_svm_ros.cpp 17 | ) 18 | 19 | target_link_libraries(online_svm_ros 20 | ${catkin_LIBRARIES} 21 | yaml-cpp 22 | svm 23 | ) 24 | 25 | if(catkin_EXPORTED_TARGETS) 26 | add_dependencies(online_svm_ros 27 | ${catkin_EXPORTED_TARGETS} 28 | ) 29 | endif() 30 | -------------------------------------------------------------------------------- /online_svm_ros/README.md: -------------------------------------------------------------------------------- 1 | # online_svm_ros 2 | 3 | Barebone package for SVM-based classifier online training used in [online learning](https://github.com/yzrobot/online_learning).
4 | -------------------------------------------------------------------------------- /online_svm_ros/config/svm.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/cjlin1/libsvm 2 | x_lower: -1.0 3 | x_upper: 1.0 4 | max_examples: 5000 5 | svm_type: 0 # default 0 (C_SVC) 6 | kernel_type: 2 # default 2 (RBF) 7 | degree: 3 # default 3 8 | gamma: 0.02 # default 1.0/(float)FEATURE_SIZE 9 | coef0: 0 # default 0 10 | cache_size: 256 # default 100 11 | eps: 0.001 # default 0.001 12 | C: 8 # default 1 13 | nr_weight: 0 14 | nu: 0.5 15 | p: 0.1 16 | shrinking: 0 17 | probability: 1 18 | save_data: true 19 | best_params: false -------------------------------------------------------------------------------- /online_svm_ros/launch/online_svm_ros.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /online_svm_ros/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | online_svm_ros 4 | 0.0.1 5 | 6 | Online SVM. 
7 | 8 | Rui Yang 9 | GPLv3 10 | 11 | https://github.com/yzrobot/online_learning 12 | Zhi Yan 13 | 14 | catkin 15 | 16 | roscpp 17 | yaml-cpp 18 | libsvm-dev 19 | std_msgs 20 | 21 | -------------------------------------------------------------------------------- /point_cloud_features/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(point_cloud_features) 3 | 4 | find_package(catkin REQUIRED COMPONENTS 5 | roscpp 6 | pcl_conversions 7 | pcl_ros 8 | std_msgs 9 | autoware_tracker 10 | ) 11 | 12 | find_package(PCL REQUIRED) 13 | 14 | include_directories(include ${catkin_INCLUDE_DIRS}) 15 | 16 | catkin_package(INCLUDE_DIRS include) 17 | 18 | add_executable(point_cloud_features 19 | src/point_cloud_feature_extractor.cpp 20 | src/${PROJECT_NAME}/point_cloud_features.cpp 21 | ) 22 | 23 | target_link_libraries(point_cloud_features 24 | ${catkin_LIBRARIES} 25 | ) 26 | 27 | if(catkin_EXPORTED_TARGETS) 28 | add_dependencies(point_cloud_features 29 | ${catkin_EXPORTED_TARGETS} 30 | ) 31 | endif() 32 | -------------------------------------------------------------------------------- /point_cloud_features/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Zhi Yan 4 | Copyright (c) 2021, Rui Yang 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | 3. 
Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /point_cloud_features/README.md: -------------------------------------------------------------------------------- 1 | # point_cloud_features 2 | 3 | Barebone package for point cloud feature extraction used in [online learning](https://github.com/yzrobot/online_learning). 4 | 5 | [2020-11-14]: Adjusted the calculation of minimum_points. 6 | -------------------------------------------------------------------------------- /point_cloud_features/include/point_cloud_features/point_cloud_features.h: -------------------------------------------------------------------------------- 1 | /** 2 | * BSD 3-Clause License 3 | * 4 | * Copyright (c) 2020, Zhi Yan 5 | * All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | * modification, are permitted provided that the following conditions are met: 9 | 10 | * 1. 
Redistributions of source code must retain the above copyright notice, this 11 | * list of conditions and the following disclaimer. 12 | 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | 17 | * 3. Neither the name of the copyright holder nor the names of its 18 | * contributors may be used to endorse or promote products derived from 19 | * this software without specific prior written permission. 20 | 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | **/ 32 | 33 | #ifndef POINT_CLOUD_FEATURES_H 34 | #define POINT_CLOUD_FEATURES_H 35 | 36 | // ROS 37 | #include 38 | 39 | // PCL 40 | #include 41 | #include 42 | 43 | int numberOfPoints(pcl::PointCloud::Ptr); 44 | float minDistance(pcl::PointCloud::Ptr); 45 | void covarianceMat3D(pcl::PointCloud::Ptr, std::vector &); 46 | void normalizedMOIT(pcl::PointCloud::Ptr, std::vector &); 47 | void sliceFeature(pcl::PointCloud::Ptr, int, std::vector &); 48 | void intensityDistribution(pcl::PointCloud::Ptr, int, std::vector &); 49 | 50 | #endif /* POINT_CLOUD_FEATURES_H */ 51 | -------------------------------------------------------------------------------- /point_cloud_features/launch/point_cloud_feature_extractor.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /point_cloud_features/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | point_cloud_features 4 | 0.0.1 5 | 6 | A ROS package for extracting features from 3D point clouds. 7 | 8 | Rui Yang 9 | BSD 10 | 11 | https://github.com/yzrobot/online_learning 12 | Zhi Yan 13 | 14 | catkin 15 | 16 | roscpp 17 | pcl_conversions 18 | pcl_ros 19 | std_msgs 20 | autoware_tracker 21 | 22 | -------------------------------------------------------------------------------- /point_cloud_features/src/point_cloud_features/point_cloud_features.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * BSD 3-Clause License 3 | * 4 | * Copyright (c) 2020, Zhi Yan 5 | * All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | * modification, are permitted provided that the following conditions are met: 9 | 10 | * 1. 
Redistributions of source code must retain the above copyright notice, this 11 | * list of conditions and the following disclaimer. 12 | 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | 17 | * 3. Neither the name of the copyright holder nor the names of its 18 | * contributors may be used to endorse or promote products derived from 19 | * this software without specific prior written permission. 20 | 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | **/ 32 | 33 | #include "point_cloud_features/point_cloud_features.h" 34 | 35 | /* f1 (1d): Number of points included in a cluster */ 36 | int numberOfPoints(pcl::PointCloud::Ptr pc) { 37 | return pc->size(); 38 | } 39 | 40 | /* f2 (1d): The minimum distance of the cluster to the sensor */ 41 | /* f1 and f2 could be used in pairs, since f1 varies with f2 changes */ 42 | float minDistance(pcl::PointCloud::Ptr pc) { 43 | float m = FLT_MAX; 44 | 45 | for(int i = 0; i < pc->size(); i++) { 46 | m = std::min(m, pc->points[i].x*pc->points[i].x + pc->points[i].y*pc->points[i].y + pc->points[i].z*pc->points[i].z); 47 | } 48 | 49 | return sqrt(m); 50 | } 51 | 52 | /* f3 (6d): 3D covariance matrix of the cluster */ 53 | void covarianceMat3D(pcl::PointCloud::Ptr pc, std::vector &res) { 54 | Eigen::Matrix3f covariance_3d; 55 | pcl::PCA pca; 56 | pcl::PointCloud::Ptr pc_projected(new pcl::PointCloud); 57 | Eigen::Vector4f centroid; 58 | 59 | pca.setInputCloud(pc); 60 | pca.project(*pc, *pc_projected); 61 | pcl::compute3DCentroid(*pc, centroid); 62 | pcl::computeCovarianceMatrixNormalized(*pc_projected, centroid, covariance_3d); 63 | 64 | // Only 6 elements are needed as covariance_3d is symmetric. 
65 | res.push_back(covariance_3d(0,0)); 66 | res.push_back(covariance_3d(0,1)); 67 | res.push_back(covariance_3d(0,2)); 68 | res.push_back(covariance_3d(1,1)); 69 | res.push_back(covariance_3d(1,2)); 70 | res.push_back(covariance_3d(2,2)); 71 | } 72 | 73 | /* f4 (6d): The normalized moment of inertia tensor */ 74 | void normalizedMOIT(pcl::PointCloud::Ptr pc, std::vector &res) { 75 | Eigen::Matrix3f moment_3d; 76 | pcl::PCA pca; 77 | pcl::PointCloud::Ptr pc_projected(new pcl::PointCloud); 78 | 79 | moment_3d.setZero(); 80 | pca.setInputCloud(pc); 81 | pca.project(*pc, *pc_projected); 82 | for(int i = 0; i < (*pc_projected).size(); i++) { 83 | moment_3d(0,0) += (*pc_projected)[i].y*(*pc_projected)[i].y + (*pc_projected)[i].z*(*pc_projected)[i].z; 84 | moment_3d(0,1) -= (*pc_projected)[i].x*(*pc_projected)[i].y; 85 | moment_3d(0,2) -= (*pc_projected)[i].x*(*pc_projected)[i].z; 86 | moment_3d(1,1) += (*pc_projected)[i].x*(*pc_projected)[i].x + (*pc_projected)[i].z*(*pc_projected)[i].z; 87 | moment_3d(1,2) -= (*pc_projected)[i].y*(*pc_projected)[i].z; 88 | moment_3d(2,2) += (*pc_projected)[i].x*(*pc_projected)[i].x + (*pc_projected)[i].y*(*pc_projected)[i].y; 89 | } 90 | 91 | // Only 6 elements are needed as moment_3d is symmetric. 
92 | res.push_back(moment_3d(0,0)); 93 | res.push_back(moment_3d(0,1)); 94 | res.push_back(moment_3d(0,2)); 95 | res.push_back(moment_3d(1,1)); 96 | res.push_back(moment_3d(1,2)); 97 | res.push_back(moment_3d(2,2)); 98 | } 99 | 100 | /* f5 (n*2d): Slice feature for the cluster */ 101 | void sliceFeature(pcl::PointCloud::Ptr pc, int n, std::vector &res) { 102 | for(int i = 0; i < n*2; i++) { 103 | res.push_back(0); 104 | } 105 | 106 | Eigen::Vector4f pc_min, pc_max; 107 | pcl::getMinMax3D(*pc, pc_min, pc_max); 108 | 109 | pcl::PointCloud::Ptr blocks[n]; 110 | float itv = (pc_max[2] - pc_min[2]) / n; 111 | 112 | if(itv > 0) { 113 | for(int i = 0; i < n; i++) { 114 | blocks[i].reset(new pcl::PointCloud); 115 | } 116 | for(unsigned int i = 0, j; i < pc->size(); i++) { 117 | j = std::min((n-1), (int)((pc->points[i].z - pc_min[2]) / itv)); 118 | blocks[j]->points.push_back(pc->points[i]); 119 | } 120 | 121 | Eigen::Vector4f block_min, block_max; 122 | for(int i = 0; i < n; i++) { 123 | if(blocks[i]->size() > 2) { // At least 3 points to perform pca. 
124 | pcl::PCA pca; 125 | pcl::PointCloud::Ptr block_projected(new pcl::PointCloud); 126 | pca.setInputCloud(blocks[i]); 127 | pca.project(*blocks[i], *block_projected); 128 | pcl::getMinMax3D(*block_projected, block_min, block_max); 129 | } else { 130 | block_min.setZero(); 131 | block_max.setZero(); 132 | } 133 | res[i*2] = block_max[0] - block_min[0]; 134 | res[i*2+1] = block_max[1] - block_min[1]; 135 | } 136 | } 137 | } 138 | 139 | /* f6 (n+2d): Distribution of the reflection intensity, including the mean, the standard deviation and the normalized 1D histogram (n is the number of bins) */ 140 | void intensityDistribution(pcl::PointCloud::Ptr pc, int n, std::vector &res) { 141 | float sum = 0, min = FLT_MAX, max = -FLT_MAX, mean = 0, sum_dev = 0; 142 | 143 | for(int i = 0; i < n+2; i++) { 144 | res.push_back(0); 145 | } 146 | 147 | for(int i = 0; i < pc->size(); i++) { 148 | sum += pc->points[i].intensity; 149 | min = std::min(min, pc->points[i].intensity); 150 | max = std::max(max, pc->points[i].intensity); 151 | } 152 | mean = sum / pc->size(); 153 | 154 | for(int i = 0; i < pc->size(); i++) { 155 | sum_dev += (pc->points[i].intensity - mean) * (pc->points[i].intensity - mean); 156 | 157 | int j = std::min(float(n-1), std::floor((pc->points[i].intensity-min) / ((max-min) / n))); 158 | res[j]++; 159 | } 160 | 161 | res[n] = sqrt(sum_dev / pc->size()); 162 | res[n+1] = mean; 163 | } 164 | --------------------------------------------------------------------------------