├── IMG
│   ├── EOTL-2D-CAR-312.png
│   ├── EOTL-2D-CYC-066.png
│   ├── EOTL-2D-PED-022.png
│   ├── EOTL-3D-CAR-276.png
│   ├── EOTL-3D-CYC-125.png
│   └── EOTL-3D-PED-091.png
├── README.md
├── autoware_tracker
│   ├── CMakeLists.txt
│   ├── LICENSE
│   ├── README.md
│   ├── config
│   │   └── params.yaml
│   ├── launch
│   │   ├── run.launch
│   │   └── rviz.rviz
│   ├── msg
│   │   ├── Centroids.msg
│   │   ├── CloudCluster.msg
│   │   ├── CloudClusterArray.msg
│   │   ├── DetectedObject.msg
│   │   └── DetectedObjectArray.msg
│   ├── package.xml
│   └── src
│       ├── detected_objects_visualizer
│       │   ├── visualize_detected_objects.cpp
│       │   ├── visualize_detected_objects.h
│       │   ├── visualize_detected_objects_main.cpp
│       │   ├── visualize_rects.cpp
│       │   ├── visualize_rects.h
│       │   └── visualize_rects_main.cpp
│       ├── libkitti
│       │   ├── kitti.cpp
│       │   └── kitti.h
│       ├── lidar_euclidean_cluster_detect
│       │   ├── cluster.cpp
│       │   ├── cluster.h
│       │   ├── gencolors.cpp
│       │   ├── lidar_euclidean_cluster_detect.cpp
│       │   └── lidar_euclidean_cluster_detect.cpp.old
│       └── lidar_imm_ukf_pda_track
│           ├── imm_ukf_pda.cpp
│           ├── imm_ukf_pda.cpp.old
│           ├── imm_ukf_pda.h
│           ├── imm_ukf_pda_main.cpp
│           ├── ukf.cpp
│           └── ukf.h
├── efficient_det_ros
│   ├── LICENSE
│   ├── __pycache__
│   │   └── backbone.cpython-37.pyc
│   ├── backbone.py
│   ├── benchmark
│   │   └── coco_eval_result
│   ├── coco_eval.py
│   ├── efficient_det_node.py
│   ├── efficientdet
│   │   ├── __pycache__
│   │   │   ├── model.cpython-37.pyc
│   │   │   └── utils.cpython-37.pyc
│   │   ├── config.py
│   │   ├── dataset.py
│   │   ├── loss.py
│   │   ├── model.py
│   │   └── utils.py
│   ├── efficientdet_test.py
│   ├── efficientdet_test_videos.py
│   ├── efficientnet
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── model.cpython-37.pyc
│   │   │   ├── utils.cpython-37.pyc
│   │   │   └── utils_extra.cpython-37.pyc
│   │   ├── model.py
│   │   ├── utils.py
│   │   └── utils_extra.py
│   ├── projects
│   │   ├── coco.yml
│   │   ├── kitti.yml
│   │   └── shape.yml
│   ├── readme.md
│   ├── train.py
│   ├── tutorial
│   │   └── train_shape.ipynb
│   ├── utils
│   │   ├── __pycache__
│   │   │   └── utils.cpython-37.pyc
│   │   ├── sync_batchnorm
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── batchnorm.cpython-37.pyc
│   │   │   │   ├── comm.cpython-37.pyc
│   │   │   │   └── replicate.cpython-37.pyc
│   │   │   ├── batchnorm.py
│   │   │   ├── batchnorm_reimpl.py
│   │   │   ├── comm.py
│   │   │   ├── replicate.py
│   │   │   └── unittest.py
│   │   └── utils.py
│   └── weights
│       └── efficientdet-d2.pth
├── kitti_camera_ros
│   ├── .gitignore
│   ├── .travis.yml
│   ├── CMakeLists.txt
│   ├── LICENSE
│   ├── README.md
│   ├── launch
│   │   └── kitti_camera_ros.launch
│   ├── package.xml
│   └── src
│       └── kitti_camera_ros.cpp
├── kitti_velodyne_ros
│   ├── .gitignore
│   ├── .travis.yml
│   ├── CMakeLists.txt
│   ├── LICENSE
│   ├── README.md
│   ├── launch
│   │   ├── kitti_velodyne_ros.launch
│   │   ├── kitti_velodyne_ros.rviz
│   │   ├── kitti_velodyne_ros_loam.launch
│   │   └── kitti_velodyne_ros_loam.rviz
│   ├── package.xml
│   └── src
│       └── kitti_velodyne_ros.cpp
├── launch
│   ├── efficient_online_learning.launch
│   └── efficient_online_learning.rviz
├── online_forests_ros
│   ├── CMakeLists.txt
│   ├── LICENSE
│   ├── README.md
│   ├── config
│   │   └── orf.conf
│   ├── doc
│   │   └── 2009-OnlineRandomForests.pdf
│   ├── include
│   │   └── online_forests
│   │       ├── classifier.h
│   │       ├── data.h
│   │       ├── hyperparameters.h
│   │       ├── onlinenode.h
│   │       ├── onlinerf.h
│   │       ├── onlinetree.h
│   │       ├── randomtest.h
│   │       └── utilities.h
│   ├── launch
│   │   └── online_forests_ros.launch
│   ├── model
│   │   ├── dna-test.libsvm
│   │   └── dna-train.libsvm
│   ├── package.xml
│   └── src
│       ├── online_forests
│       │   ├── Online-Forest.cpp
│       │   ├── classifier.cpp
│       │   ├── data.cpp
│       │   ├── hyperparameters.cpp
│       │   ├── onlinenode.cpp
│       │   ├── onlinerf.cpp
│       │   ├── onlinetree.cpp
│       │   ├── randomtest.cpp
│       │   └── utilities.cpp
│       └── online_forests_ros.cpp
├── online_svm_ros
│   ├── CMakeLists.txt
│   ├── LICENSE
│   ├── README.md
│   ├── config
│   │   └── svm.yaml
│   ├── launch
│   │   └── online_svm_ros.launch
│   ├── package.xml
│   └── src
│       └── online_svm_ros.cpp
└── point_cloud_features
    ├── CMakeLists.txt
    ├── LICENSE
    ├── README.md
    ├── include
    │   └── point_cloud_features
    │       └── point_cloud_features.h
    ├── launch
    │   └── point_cloud_feature_extractor.launch
    ├── package.xml
    └── src
        ├── point_cloud_feature_extractor.cpp
        └── point_cloud_features
            └── point_cloud_features.cpp
/IMG/EOTL-2D-CAR-312.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-2D-CAR-312.png
--------------------------------------------------------------------------------
/IMG/EOTL-2D-CYC-066.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-2D-CYC-066.png
--------------------------------------------------------------------------------
/IMG/EOTL-2D-PED-022.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-2D-PED-022.png
--------------------------------------------------------------------------------
/IMG/EOTL-3D-CAR-276.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-3D-CAR-276.png
--------------------------------------------------------------------------------
/IMG/EOTL-3D-CYC-125.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-3D-CYC-125.png
--------------------------------------------------------------------------------
/IMG/EOTL-3D-PED-091.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/IMG/EOTL-3D-PED-091.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Online Transfer Learning for 3D Object Classification in Autonomous Driving #
2 |
3 | *We are actively updating this repository (especially removing hard-coded values and adding comments) to make it easier to use. If you have any questions, please open an issue. Thanks!*
4 |
5 | This is a ROS-based efficient online learning framework for object classification in 3D LiDAR scans, taking advantage of robust multi-target tracking to avoid the need for data annotation by a human expert.
6 | The system has only been tested on Ubuntu 18.04 with ROS Melodic (compilation fails on Ubuntu 20.04 with ROS Noetic).
7 |
8 | Please watch the video below for more details.
9 |
10 | [Watch the video on YouTube](https://www.youtube.com/watch?v=wl5ehOFV5Ac)
11 |
12 | ## NEWS
13 | [2023-01-11] Our evaluation results on the KITTI 3D object detection benchmark rank 276th for car, **91st** for pedestrian, and 125th for cyclist.
14 |
15 | 3D Object Detection
16 |
17 | *CAR*
18 | ![](IMG/EOTL-3D-CAR-276.png)
19 |
20 | *PEDESTRIAN*
21 | ![](IMG/EOTL-3D-PED-091.png)
22 |
23 | *CYCLIST*
24 | ![](IMG/EOTL-3D-CYC-125.png)
25 |
26 |
27 | [2023-01-11] Our evaluation results on the KITTI 2D object detection benchmark achieved rankings of **22nd** for pedestrian and **66th** for cyclist!
28 |
29 | 2D Object Detection
30 |
31 | *CAR*
32 | ![](IMG/EOTL-2D-CAR-312.png)
33 |
34 | *PEDESTRIAN*
35 | ![](IMG/EOTL-2D-PED-022.png)
36 |
37 | *CYCLIST*
38 | ![](IMG/EOTL-2D-CYC-066.png)
39 |
40 |
41 | ## Install & Build
42 | Please read the README of each sub-package first and install the corresponding dependencies.
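For example, with a standard catkin workspace (the workspace path is an assumption; adjust to your setup):

```sh
cd ~/catkin_ws/src
git clone https://github.com/epan-utbm/efficient_online_learning.git
cd ~/catkin_ws
catkin_make
```

See autoware_tracker's README for the `-j1` note on the first build.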
43 |
44 | ## Run
45 | ### 1. Prepare dataset
46 | * (Optional) Download the [raw data](http://www.cvlibs.net/datasets/kitti/raw_data.php) from KITTI.
47 |
48 | * (Optional) Download the [sample data](https://github.com/epan-utbm/efficient_online_learning/releases/download/sample_data/2011_09_26_drive_0005_sync.tar) for testing.
49 |
50 | * (Optional) Prepare a customized dataset following the format of the sample data (see the layout sketch below).
51 |
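For reference, a minimal sketch of the sample-data layout, based on the public KITTI raw format and the calibration path used in `autoware_tracker/config/params.yaml` (verify against your download):

```
2011_09_26_drive_0005_sync/
├── calib_2011_09_26.txt        # camera-LiDAR calibration
├── image_02/data/              # left color camera frames (*.png)
├── oxts/data/                  # GPS/IMU readings
└── velodyne_points/data/       # LiDAR scans (*.bin)
```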
52 | ### 2. Manually set specific path parameters
53 | * launch/efficient_online_learning.launch
54 | * autoware_tracker/config/params.yaml
55 |
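For example, the entries below in `autoware_tracker/config/params.yaml` contain machine-specific values (shown here as shipped in this repository) and must point at your own dataset and calibration file:

```yaml
autoware_tracker:
  cluster:
    label_source: "/image_detections"
    extrinsic_calibration: "/home/epan/Rui/datasets/2011_09_26_drive_0005_sync/calib_2011_09_26.txt"
```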
56 | ### 3. Run the project
57 | ```sh
58 | cd catkin_ws
59 | source devel/setup.bash
60 | roslaunch src/efficient_online_learning/launch/efficient_online_learning.launch
61 | ```
62 |
63 | ## Citation
64 |
65 | If you use this code, please cite the following:
66 |
67 | ```
68 | @article{yangr23sensors,
69 | author = {Rui Yang and Zhi Yan and Tao Yang and Yaonan Wang and Yassine Ruichek},
70 | title = {Efficient Online Transfer Learning for Road Participants Detection in Autonomous Driving},
71 | journal = {IEEE Sensors Journal},
72 | volume = {23},
73 | number = {19},
74 | pages = {23522--23535},
75 | year = {2023}
76 | }
77 |
78 | @inproceedings{yangr21itsc,
79 | title={Efficient online transfer learning for 3D object classification in autonomous driving},
80 | author={Rui Yang and Zhi Yan and Tao Yang and Yassine Ruichek},
81 | booktitle = {Proceedings of the 2021 IEEE International Conference on Intelligent Transportation Systems (ITSC)},
82 | pages = {2950--2957},
83 | address = {Indianapolis, USA},
84 | month = {September},
85 | year = {2021}
86 | }
87 | ```
88 |
--------------------------------------------------------------------------------
/autoware_tracker/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(autoware_tracker)
3 |
4 | set(CMAKE_BUILD_TYPE "Release")
5 | set(CMAKE_CXX_FLAGS "-std=c++11")
6 | set(CMAKE_CXX_FLAGS_RELEASE "-O3 -Wall -g -pthread")
7 |
8 | find_package(catkin REQUIRED COMPONENTS
9 | tf
10 | pcl_ros
11 | roscpp
12 | std_msgs
13 | sensor_msgs
14 | geometry_msgs
15 | visualization_msgs
16 | cv_bridge
17 | image_transport
18 | jsk_recognition_msgs
19 | jsk_rviz_plugins
20 | message_generation
21 | )
22 |
23 | # Messages
24 | add_message_files(
25 | DIRECTORY msg
26 | FILES
27 | Centroids.msg
28 | CloudCluster.msg
29 | CloudClusterArray.msg
30 | DetectedObject.msg
31 | DetectedObjectArray.msg
32 | )
33 | generate_messages(
34 | DEPENDENCIES
35 | geometry_msgs
36 | jsk_recognition_msgs
37 | sensor_msgs
38 | std_msgs
39 | )
40 |
41 | catkin_package(
42 | #INCLUDE_DIRS include
43 | CATKIN_DEPENDS
44 | tf
45 | pcl_ros
46 | roscpp
47 | std_msgs
48 | sensor_msgs
49 | geometry_msgs
50 | visualization_msgs
51 | cv_bridge
52 | image_transport
53 | jsk_recognition_msgs
54 | jsk_rviz_plugins
55 | message_runtime
56 | message_generation
57 | )
58 |
59 |
60 | find_package(OpenMP)
61 | find_package(OpenCV REQUIRED)
62 | find_package(Eigen3 QUIET)
63 |
64 | include_directories(
65 | #include
66 | ${catkin_INCLUDE_DIRS}
67 | ${OpenCV_INCLUDE_DIRS}
68 | )
69 |
70 | add_executable(visualize_detected_objects
71 | src/detected_objects_visualizer/visualize_detected_objects_main.cpp
72 | src/detected_objects_visualizer/visualize_detected_objects.cpp
73 | )
74 | add_dependencies(visualize_detected_objects ${catkin_EXPORTED_TARGETS} ${PROJECT_NAME}_generate_messages_cpp)
75 | target_link_libraries(visualize_detected_objects ${OpenCV_LIBRARIES} ${EIGEN3_LIBRARIES} ${catkin_LIBRARIES})
76 |
77 | add_executable(lidar_euclidean_cluster_detect
78 | src/lidar_euclidean_cluster_detect/lidar_euclidean_cluster_detect.cpp
79 | src/lidar_euclidean_cluster_detect/cluster.cpp
80 | src/libkitti/kitti.cpp)
81 | add_dependencies(lidar_euclidean_cluster_detect ${catkin_EXPORTED_TARGETS} ${PROJECT_NAME}_generate_messages_cpp)
82 | target_link_libraries(lidar_euclidean_cluster_detect ${OpenCV_LIBRARIES} ${catkin_LIBRARIES})
83 |
84 | add_executable(imm_ukf_pda
85 | src/lidar_imm_ukf_pda_track/imm_ukf_pda_main.cpp
86 | src/lidar_imm_ukf_pda_track/imm_ukf_pda.h
87 | src/lidar_imm_ukf_pda_track/imm_ukf_pda.cpp
88 | src/lidar_imm_ukf_pda_track/ukf.cpp
89 | )
90 | add_dependencies(imm_ukf_pda ${catkin_EXPORTED_TARGETS} ${PROJECT_NAME}_generate_messages_cpp)
91 | target_link_libraries(imm_ukf_pda ${catkin_LIBRARIES})
92 |
93 | install(DIRECTORY include/${PROJECT_NAME}/
94 | DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
95 | FILES_MATCHING PATTERN "*.h"
96 | PATTERN ".svn" EXCLUDE)
97 |
--------------------------------------------------------------------------------
/autoware_tracker/README.md:
--------------------------------------------------------------------------------
1 | # autoware_tracker
2 |
3 | This package is forked from [https://github.com/TixiaoShan/autoware_tracker](https://github.com/TixiaoShan/autoware_tracker), and the original README is below the dividing line.
4 |
5 | [2020-10-xx]: Added "automatic annotation" for point clouds; please install the dependency first: `$ sudo apt install ros-melodic-vision-msgs`.
6 |
7 | [2020-09-18]: Added "intensity" to the points, which is essential for our online learning system, as the intensity can help us distinguish objects.
8 |
9 | ---
10 |
11 | # Readme
12 |
13 | Barebone package for point cloud object tracking used in Autoware. The package has only been tested on Ubuntu 16.04 with ROS Kinetic. No deep learning is used.
14 |
15 | # Install JSK
16 | ```
17 | sudo apt-get install ros-kinetic-jsk-recognition-msgs
18 | sudo apt-get install ros-kinetic-jsk-rviz-plugins
19 | ```
20 |
21 | # Compile
22 | ```
23 | cd ~/catkin_ws/src
24 | git clone https://github.com/TixiaoShan/autoware_tracker.git
25 | cd ~/catkin_ws
26 | catkin_make -j1
27 | ```
28 | ```-j1``` is only needed for message generation in the first install.
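Once the messages have been generated, later rebuilds can use the default parallelism:

```
cd ~/catkin_ws
catkin_make
```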
29 |
30 | # Sample data
31 |
32 | In case you don't have some bag files handy, you can download a sample bag using:
33 | ```
34 | wget https://autoware-ai.s3.us-east-2.amazonaws.com/sample_moriyama_150324.tar.gz
35 | ```
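The download is a gzipped tarball; extract it before playing (standard tar invocation, assuming the default filename):

```
tar xzf sample_moriyama_150324.tar.gz
```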
36 |
37 | # Demo
38 |
39 | Run the autoware tracker:
40 | ```
41 | roslaunch autoware_tracker run.launch
42 | ```
43 |
44 | Play the sample ros bag:
45 | ```
46 | rosbag play sample_moriyama_150324.bag
47 | ```
48 |
--------------------------------------------------------------------------------
/autoware_tracker/config/params.yaml:
--------------------------------------------------------------------------------
1 | autoware_tracker:
2 |
3 | cluster:
4 | label_source: "/image_detections"
5 | extrinsic_calibration: "/home/epan/Rui/datasets/2011_09_26_drive_0005_sync/calib_2011_09_26.txt"
6 | iou_threshold: 0.5
7 |
8 | points_node: "/points_raw"
9 | output_frame: "velodyne"
10 |
11 | remove_ground: true
12 |
13 | downsample_cloud: true
14 | leaf_size: 0.1
15 |
16 | use_multiple_thres: true
17 |
18 | cluster_size_min: 20
19 | cluster_size_max: 100000
20 | clip_min_height: -2.0
21 | clip_max_height: 0.0
22 | cluster_merge_threshold: 1.5
23 | clustering_distance: 0.75
24 | remove_points_min: 2.0
25 | remove_points_max: 100.0
26 |
27 | keep_lanes: false
28 | keep_lane_left_distance: 5.0
29 | keep_lane_right_distance: 5.0
30 |
31 | use_diffnormals: false
32 | pose_estimation: true
33 |
34 | tracker:
35 | gating_thres: 9.22
36 | gate_probability: 0.99
37 | detection_probability: 0.9
38 | life_time_thres: 8
39 | static_velocity_thres: 0.5
40 | static_num_history_thres: 3
41 | prevent_explosion_thres: 1000
42 | merge_distance_threshold: 0.5
43 | use_sukf: false
44 |
45 | tracking_frame: "/world"
46 | # namespace: /detection/object_tracker/
47 | track_probability: 0.7
48 |
--------------------------------------------------------------------------------
/autoware_tracker/launch/run.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/autoware_tracker/launch/rviz.rviz:
--------------------------------------------------------------------------------
1 | Panels:
2 | - Class: rviz/Displays
3 | Help Height: 0
4 | Name: Displays
5 | Property Tree Widget:
6 | Expanded:
7 | - /Global Options1
8 | - /MarkerArray1
9 | - /MarkerArray1/Namespaces1
10 | Splitter Ratio: 0.505813956
11 | Tree Height: 1124
12 | - Class: rviz/Selection
13 | Name: Selection
14 | - Class: rviz/Tool Properties
15 | Expanded:
16 | - /2D Pose Estimate1
17 | - /2D Nav Goal1
18 | - /Publish Point1
19 | Name: Tool Properties
20 | Splitter Ratio: 0.588679016
21 | - Class: rviz/Views
22 | Expanded:
23 | - /Current View1
24 | Name: Views
25 | Splitter Ratio: 0.5
26 | - Class: rviz/Time
27 | Experimental: false
28 | Name: Time
29 | SyncMode: 0
30 | SyncSource: Point cluster
31 | Toolbars:
32 | toolButtonStyle: 2
33 | Visualization Manager:
34 | Class: ""
35 | Displays:
36 | - Class: rviz/Axes
37 | Enabled: true
38 | Length: 1
39 | Name: Axes
40 | Radius: 0.300000012
41 | Reference Frame:
42 | Value: true
43 | - Class: rviz/Group
44 | Displays:
45 | - Alpha: 1
46 | Autocompute Intensity Bounds: true
47 | Autocompute Value Bounds:
48 | Max Value: 10
49 | Min Value: -10
50 | Value: true
51 | Axis: Z
52 | Channel Name: intensity
53 | Class: rviz/PointCloud2
54 | Color: 255; 255; 255
55 | Color Transformer: Intensity
56 | Decay Time: 0
57 | Enabled: true
58 | Invert Rainbow: false
59 | Max Color: 255; 255; 255
60 | Max Intensity: 142
61 | Min Color: 0; 0; 0
62 | Min Intensity: 0
63 | Name: Point cloud
64 | Position Transformer: XYZ
65 | Queue Size: 10
66 | Selectable: true
67 | Size (Pixels): 1
68 | Size (m): 0.00999999978
69 | Style: Points
70 | Topic: /points_raw
71 | Unreliable: false
72 | Use Fixed Frame: true
73 | Use rainbow: true
74 | Value: true
75 | - Alpha: 1
76 | Autocompute Intensity Bounds: true
77 | Autocompute Value Bounds:
78 | Max Value: 10
79 | Min Value: -10
80 | Value: true
81 | Axis: Z
82 | Channel Name: intensity
83 | Class: rviz/PointCloud2
84 | Color: 255; 255; 255
85 | Color Transformer: RGB8
86 | Decay Time: 0
87 | Enabled: true
88 | Invert Rainbow: false
89 | Max Color: 255; 255; 255
90 | Max Intensity: 4096
91 | Min Color: 0; 0; 0
92 | Min Intensity: 0
93 | Name: Point cluster
94 | Position Transformer: XYZ
95 | Queue Size: 10
96 | Selectable: true
97 | Size (Pixels): 5
98 | Size (m): 0.00999999978
99 | Style: Points
100 | Topic: /autoware_tracker/cluster/points_cluster
101 | Unreliable: false
102 | Use Fixed Frame: true
103 | Use rainbow: true
104 | Value: true
105 | Enabled: true
106 | Name: Lidar cluster
107 | - Class: rviz/MarkerArray
108 | Enabled: true
109 | Marker Topic: /autoware_tracker/visualizer/objects
110 | Name: MarkerArray
111 | Namespaces:
112 | arrow_markers: true
113 | box_markers: false
114 | centroid_markers: true
115 | hull_markers: true
116 | label_markers: true
117 | Queue Size: 100
118 | Value: true
119 | Enabled: true
120 | Global Options:
121 | Background Color: 48; 48; 48
122 | Default Light: true
123 | Fixed Frame: velodyne
124 | Frame Rate: 30
125 | Name: root
126 | Tools:
127 | - Class: rviz/Interact
128 | Hide Inactive Objects: true
129 | - Class: rviz/MoveCamera
130 | - Class: rviz/Select
131 | - Class: rviz/FocusCamera
132 | - Class: rviz/Measure
133 | - Class: rviz/SetInitialPose
134 | Topic: /initialpose
135 | - Class: rviz/SetGoal
136 | Topic: /move_base_simple/goal
137 | - Class: rviz/PublishPoint
138 | Single click: true
139 | Topic: /clicked_point
140 | Value: true
141 | Views:
142 | Current:
143 | Class: rviz/Orbit
144 | Distance: 48.6941605
145 | Enable Stereo Rendering:
146 | Stereo Eye Separation: 0.0599999987
147 | Stereo Focal Distance: 1
148 | Swap Stereo Eyes: false
149 | Value: false
150 | Focal Point:
151 | X: 0
152 | Y: 0
153 | Z: 0
154 | Focal Shape Fixed Size: true
155 | Focal Shape Size: 0.0500000007
156 | Invert Z Axis: false
157 | Name: Current View
158 | Near Clip Distance: 0.00999999978
159 | Pitch: 0.869796932
160 | Target Frame: base_link
161 | Value: Orbit (rviz)
162 | Yaw: 2.58810186
163 | Saved: ~
164 | Window Geometry:
165 | Displays:
166 | collapsed: false
167 | Height: 1274
168 | Hide Left Dock: false
169 | Hide Right Dock: true
170 | QMainWindow State: 000000ff00000000fd000000040000000000000206000004abfc020000000efb0000001200530065006c0065006300740069006f006e00000001e10000009b0000006300fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afc0000002e000004ab000000dd00fffffffa000000020100000003fb0000000a0049006d0061006700650000000000ffffffff0000000000000000fb0000000c00430061006d0065007200610000000000ffffffff0000000000000000fb000000100044006900730070006c0061007900730100000000000001360000017900fffffffb0000000a0049006d006100670065010000028e000000d20000000000000000fb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb000000120049006d006100670065005f0072006100770000000000ffffffff0000000000000000fb0000000c00430061006d006500720061000000024e000001710000000000000000fb000000120049006d00610067006500200052006100770100000421000000160000000000000000fb0000000a0049006d00610067006501000002f4000000cb0000000000000000fb0000000a0049006d006100670065010000056c0000026c0000000000000000000000010000016300000313fc0200000003fb0000000a00560069006500770073000000002e00000313000000b700fffffffb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000001200530065006c0065006300740069006f006e010000025a000000b20000000000000000000000020000073f000000a8fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000006400000005cfc0100000002fb0000000800540069006d00650000000000000006400000038300fffffffb0000000800540069006d00650100000000000004500000000000000000000006ae000004ab00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
171 | Selection:
172 | collapsed: false
173 | Time:
174 | collapsed: false
175 | Tool Properties:
176 | collapsed: false
177 | Views:
178 | collapsed: true
179 | Width: 2235
180 | X: 692
181 | Y: 316
182 |
--------------------------------------------------------------------------------
/autoware_tracker/msg/Centroids.msg:
--------------------------------------------------------------------------------
1 | std_msgs/Header header
2 | geometry_msgs/Point[] points
3 |
--------------------------------------------------------------------------------
/autoware_tracker/msg/CloudCluster.msg:
--------------------------------------------------------------------------------
1 | std_msgs/Header header
2 |
3 | uint32 id
4 | string label
5 | float64 score
6 |
7 | sensor_msgs/PointCloud2 cloud
8 |
9 | geometry_msgs/PointStamped min_point
10 | geometry_msgs/PointStamped max_point
11 | geometry_msgs/PointStamped avg_point
12 | geometry_msgs/PointStamped centroid_point
13 |
14 | float64 estimated_angle
15 |
16 | geometry_msgs/Vector3 dimensions
17 | geometry_msgs/Vector3 eigen_values
18 | geometry_msgs/Vector3[] eigen_vectors
19 |
20 | #Array of 33 floats containing the FPFH descriptor
21 | std_msgs/Float32MultiArray fpfh_descriptor
22 |
23 | jsk_recognition_msgs/BoundingBox bounding_box
24 | geometry_msgs/PolygonStamped convex_hull
25 |
26 | # Indicator information
27 | # INDICATOR_LEFT 0
28 | # INDICATOR_RIGHT 1
29 | # INDICATOR_BOTH 2
30 | # INDICATOR_NONE 3
31 | uint32 indicator_state
--------------------------------------------------------------------------------
/autoware_tracker/msg/CloudClusterArray.msg:
--------------------------------------------------------------------------------
1 | std_msgs/Header header
2 | CloudCluster[] clusters
--------------------------------------------------------------------------------
/autoware_tracker/msg/DetectedObject.msg:
--------------------------------------------------------------------------------
1 | std_msgs/Header header
2 |
3 | uint32 id
4 | string label
5 | float32 score # Score as defined by the detection, optional
6 | std_msgs/ColorRGBA color # Defines this object's specific color
7 | bool valid # Whether this object is valid or invalid, as determined by the filtering
8 |
9 | ################ 3D BB
10 | string space_frame # 3D space coordinate frame of the object, required if pose and dimensions are defined
11 | geometry_msgs/Pose pose
12 | geometry_msgs/Vector3 dimensions
13 | geometry_msgs/Vector3 variance
14 | geometry_msgs/Twist velocity
15 | geometry_msgs/Twist acceleration
16 |
17 | sensor_msgs/PointCloud2 pointcloud
18 |
19 | geometry_msgs/PolygonStamped convex_hull
20 | # autoware_msgs/LaneArray candidate_trajectories
21 |
22 | bool pose_reliable
23 | bool velocity_reliable
24 | bool acceleration_reliable
25 |
26 | ############### 2D Rect
27 | string image_frame # Image coordinate Frame, Required if x,y,w,h defined
28 | int32 x # X coord in image space(pixel) of the initial point of the Rect
29 | int32 y # Y coord in image space(pixel) of the initial point of the Rect
30 | int32 width # Width of the Rect in pixels
31 | int32 height # Height of the Rect in pixels
32 | float32 angle # Angle [0 to 2*PI), allow rotated rects
33 |
34 | sensor_msgs/Image roi_image
35 |
36 | ############### Indicator information
37 | uint8 indicator_state # INDICATOR_LEFT = 0, INDICATOR_RIGHT = 1, INDICATOR_BOTH = 2, INDICATOR_NONE = 3
38 |
39 | ############### Behavior State of the Detected Object
40 | uint8 behavior_state # FORWARD_STATE = 0, STOPPING_STATE = 1, BRANCH_LEFT_STATE = 2, BRANCH_RIGHT_STATE = 3, YIELDING_STATE = 4, ACCELERATING_STATE = 5, SLOWDOWN_STATE = 6
41 |
42 | #
43 | string[] user_defined_info
44 |
45 | # yang21itsc
46 | bool last_sample
--------------------------------------------------------------------------------
/autoware_tracker/msg/DetectedObjectArray.msg:
--------------------------------------------------------------------------------
1 | std_msgs/Header header
2 | DetectedObject[] objects
--------------------------------------------------------------------------------
/autoware_tracker/package.xml:
--------------------------------------------------------------------------------
 1 | <?xml version="1.0"?>
 2 | <package>
 3 |   <name>autoware_tracker</name>
 4 |   <version>1.0.0</version>
 5 |   <description>The autoware tracker package</description>
 6 |   <maintainer>Tixiao Shan</maintainer>
 7 |   <license>Apache 2</license>
 8 |
 9 |   <buildtool_depend>catkin</buildtool_depend>
10 |
11 |   <build_depend>jsk_recognition_msgs</build_depend>
12 |   <run_depend>jsk_recognition_msgs</run_depend>
13 |
14 |   <build_depend>pcl_ros</build_depend>
15 |   <build_depend>roscpp</build_depend>
16 |   <build_depend>geometry_msgs</build_depend>
17 |   <build_depend>std_msgs</build_depend>
18 |   <build_depend>sensor_msgs</build_depend>
19 |   <build_depend>tf</build_depend>
20 |   <build_depend>visualization_msgs</build_depend>
21 |   <build_depend>cv_bridge</build_depend>
22 |   <build_depend>image_transport</build_depend>
23 |
24 |   <run_depend>pcl_ros</run_depend>
25 |   <run_depend>roscpp</run_depend>
26 |   <run_depend>geometry_msgs</run_depend>
27 |   <run_depend>std_msgs</run_depend>
28 |   <run_depend>sensor_msgs</run_depend>
29 |   <run_depend>tf</run_depend>
30 |   <run_depend>visualization_msgs</run_depend>
31 |   <run_depend>cv_bridge</run_depend>
32 |   <run_depend>image_transport</run_depend>
33 |
34 |   <build_depend>jsk_rviz_plugins</build_depend>
35 |   <run_depend>jsk_rviz_plugins</run_depend>
36 |
37 |   <build_depend>message_generation</build_depend>
38 |   <run_depend>message_generation</run_depend>
39 |   <build_depend>message_runtime</build_depend>
40 |   <run_depend>message_runtime</run_depend>
41 | </package>
--------------------------------------------------------------------------------
/autoware_tracker/src/detected_objects_visualizer/visualize_detected_objects.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | *
16 | ********************
17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com)
18 | */
19 |
20 | #ifndef _VISUALIZEDETECTEDOBJECTS_H
21 | #define _VISUALIZEDETECTEDOBJECTS_H
22 |
23 | #include
24 | #include
25 | #include
26 | #include
27 | #include
28 |
29 | #include
30 | #include
31 |
32 | #include
33 |
34 | #include
35 | #include
36 |
37 | #include "autoware_tracker/DetectedObject.h"
38 | #include "autoware_tracker/DetectedObjectArray.h"
39 |
40 | #define __APP_NAME__ "visualize_detected_objects"
41 |
42 | class VisualizeDetectedObjects
43 | {
44 | private:
45 | const double arrow_height_;
46 | const double label_height_;
47 | const double object_max_linear_size_ = 50.;
48 | double object_speed_threshold_;
49 | double arrow_speed_threshold_;
50 | double marker_display_duration_;
51 |
52 | int marker_id_;
53 |
54 | std_msgs::ColorRGBA label_color_, box_color_, hull_color_, arrow_color_, centroid_color_, model_color_;
55 |
56 | std::string input_topic_, ros_namespace_;
57 |
58 | ros::NodeHandle node_handle_;
59 | ros::Subscriber subscriber_detected_objects_;
60 |
61 | ros::Publisher publisher_markers_;
62 |
63 | visualization_msgs::MarkerArray ObjectsToLabels(const autoware_tracker::DetectedObjectArray &in_objects);
64 |
65 | visualization_msgs::MarkerArray ObjectsToArrows(const autoware_tracker::DetectedObjectArray &in_objects);
66 |
67 | visualization_msgs::MarkerArray ObjectsToBoxes(const autoware_tracker::DetectedObjectArray &in_objects);
68 |
69 | visualization_msgs::MarkerArray ObjectsToModels(const autoware_tracker::DetectedObjectArray &in_objects);
70 |
71 | visualization_msgs::MarkerArray ObjectsToHulls(const autoware_tracker::DetectedObjectArray &in_objects);
72 |
73 | visualization_msgs::MarkerArray ObjectsToCentroids(const autoware_tracker::DetectedObjectArray &in_objects);
74 |
75 | std::string ColorToString(const std_msgs::ColorRGBA &in_color);
76 |
77 | void DetectedObjectsCallback(const autoware_tracker::DetectedObjectArray &in_objects);
78 |
79 | bool IsObjectValid(const autoware_tracker::DetectedObject &in_object);
80 |
81 | float CheckColor(double value);
82 |
83 | float CheckAlpha(double value);
84 |
85 | std_msgs::ColorRGBA ParseColor(const std::vector<double> &in_color);
86 |
87 | public:
88 | VisualizeDetectedObjects();
89 | };
90 |
91 | #endif // _VISUALIZEDETECTEDOBJECTS_H
92 |
--------------------------------------------------------------------------------
/autoware_tracker/src/detected_objects_visualizer/visualize_detected_objects_main.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | *
16 | ********************
17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com)
18 | */
19 |
20 | #include "visualize_detected_objects.h"
21 |
22 | int main(int argc, char** argv)
23 | {
24 | ros::init(argc, argv, "visualize_detected_objects");
25 | ros::console::set_logger_level(ROSCONSOLE_DEFAULT_NAME, ros::console::levels::Warn);
26 |
27 | VisualizeDetectedObjects app;
28 | ros::spin();
29 |
30 | return 0;
31 | }
32 |
--------------------------------------------------------------------------------
/autoware_tracker/src/detected_objects_visualizer/visualize_rects.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | *
16 | ********************
17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com)
18 | */
19 |
20 | #include "visualize_rects.h"
21 |
22 | VisualizeRects::VisualizeRects()
23 | {
24 | ros::NodeHandle private_nh_("~");
25 |
26 | ros::NodeHandle nh;
27 |
28 | std::string image_src_topic;
29 | std::string object_src_topic;
30 | std::string image_out_topic;
31 |
32 | private_nh_.param<std::string>("image_src", image_src_topic, "/image_raw");
33 | private_nh_.param<std::string>("object_src", object_src_topic, "/detection/image_detector/objects");
34 | private_nh_.param<std::string>("image_out", image_out_topic, "/image_rects");
35 |
36 | //get namespace from topic
37 | std::string ros_namespace = image_src_topic;
38 | std::size_t found_pos = ros_namespace.rfind("/");//find last / from topic name to extract namespace
39 | std::cout << ros_namespace << std::endl;
40 | if (found_pos!=std::string::npos)
41 | ros_namespace.erase(found_pos, ros_namespace.length()-found_pos);
42 | std::cout << ros_namespace << std::endl;
43 | image_out_topic = ros_namespace + image_out_topic;
44 |
45 | image_filter_subscriber_ = new message_filters::Subscriber<sensor_msgs::Image>(private_nh_,
46 | image_src_topic,
47 | 1);
48 | ROS_INFO("[%s] image_src: %s", __APP_NAME__, image_src_topic.c_str());
49 | detection_filter_subscriber_ = new message_filters::Subscriber<autoware_tracker::DetectedObjectArray>(private_nh_,
50 | object_src_topic,
51 | 1);
52 | ROS_INFO("[%s] object_src: %s", __APP_NAME__, object_src_topic.c_str());
53 |
54 | detections_synchronizer_ =
55 | new message_filters::Synchronizer<SyncPolicyT>(SyncPolicyT(10),
56 | *image_filter_subscriber_,
57 | *detection_filter_subscriber_);
58 | detections_synchronizer_->registerCallback(
59 | boost::bind(&VisualizeRects::SyncedDetectionsCallback, this, _1, _2));
60 |
61 |
62 | publisher_image_ = node_handle_.advertise<sensor_msgs::Image>(
63 | image_out_topic, 1);
64 | ROS_INFO("[%s] image_out: %s", __APP_NAME__, image_out_topic.c_str());
65 |
66 | }
67 |
68 | void
69 | VisualizeRects::SyncedDetectionsCallback(
70 | const sensor_msgs::Image::ConstPtr &in_image_msg,
71 | const autoware_tracker::DetectedObjectArray::ConstPtr &in_objects)
72 | {
73 | try
74 | {
75 | image_ = cv_bridge::toCvShare(in_image_msg, "bgr8")->image;
76 | cv::Mat drawn_image;
77 | drawn_image = ObjectsToRects(image_, in_objects);
78 | sensor_msgs::ImagePtr drawn_msg = cv_bridge::CvImage(in_image_msg->header, "bgr8", drawn_image).toImageMsg();
79 | publisher_image_.publish(drawn_msg);
80 | }
81 | catch (cv_bridge::Exception& e)
82 | {
83 | ROS_ERROR("[%s] Could not convert from '%s' to 'bgr8'.", __APP_NAME__, in_image_msg->encoding.c_str());
84 | }
85 | }
86 |
87 | cv::Mat
88 | VisualizeRects::ObjectsToRects(cv::Mat in_image, const autoware_tracker::DetectedObjectArray::ConstPtr &in_objects)
89 | {
90 | cv::Mat final_image = in_image.clone();
91 | for (auto const &object: in_objects->objects)
92 | {
93 | if (IsObjectValid(object))
94 | {
95 | cv::Rect rect;
96 | rect.x = object.x;
97 | rect.y = object.y;
98 | rect.width = object.width;
99 | rect.height = object.height;
100 |
101 | if (rect.x+rect.width >= in_image.cols)
102 | rect.width = in_image.cols -rect.x - 1;
103 |
104 | if (rect.y+rect.height >= in_image.rows)
105 | rect.height = in_image.rows -rect.y - 1;
106 |
107 | //draw rectangle
108 | cv::rectangle(final_image,
109 | rect,
110 | cv::Scalar(244,134,66),
111 | 4,
112 | CV_AA);
113 |
114 | //draw label
115 | std::string label = "";
116 | if (!object.label.empty() && object.label != "unknown")
117 | {
118 | label = object.label;
119 | }
120 | int font_face = cv::FONT_HERSHEY_DUPLEX;
121 | double font_scale = 1.5;
122 | int thickness = 1;
123 |
124 | int baseline=0;
125 | cv::Size text_size = cv::getTextSize(label,
126 | font_face,
127 | font_scale,
128 | thickness,
129 | &baseline);
130 | baseline += thickness;
131 |
132 | cv::Point text_origin(object.x - text_size.height,object.y);
133 |
134 | cv::rectangle(final_image,
135 | text_origin + cv::Point(0, baseline),
136 | text_origin + cv::Point(text_size.width, -text_size.height),
137 | cv::Scalar(0,0,0),
138 | CV_FILLED,
139 | CV_AA,
140 | 0);
141 |
142 | cv::putText(final_image,
143 | label,
144 | text_origin,
145 | font_face,
146 | font_scale,
147 | cv::Scalar::all(255),
148 | thickness,
149 | CV_AA,
150 | false);
151 |
152 | }
153 | }
154 | return final_image;
155 | }//ObjectsToBoxes
156 |
157 | bool VisualizeRects::IsObjectValid(const autoware_tracker::DetectedObject &in_object)
158 | {
159 | if (!in_object.valid ||
160 | in_object.width < 0 ||
161 | in_object.height < 0 ||
162 | in_object.x < 0 ||
163 | in_object.y < 0
164 | )
165 | {
166 | return false;
167 | }
168 | return true;
169 | }//end IsObjectValid
--------------------------------------------------------------------------------
/autoware_tracker/src/detected_objects_visualizer/visualize_rects.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | *
16 | ********************
17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com)
18 | */
19 |
20 | #ifndef _VISUALIZERECTS_H
21 | #define _VISUALIZERECTS_H
22 |
23 | #include
24 | #include
25 | #include
26 | #include
27 | #include
28 |
29 | #include
30 |
31 | #include
32 | #include
33 | #include
34 | #include
35 | #include
36 |
37 | #include
38 | #include
39 | #include
40 |
41 | #include "autoware_tracker/DetectedObject.h"
42 | #include "autoware_tracker/DetectedObjectArray.h"
43 |
44 | #define __APP_NAME__ "visualize_rects"
45 |
46 | class VisualizeRects
47 | {
48 | private:
49 | std::string input_topic_;
50 |
51 | ros::NodeHandle node_handle_;
52 | ros::Subscriber subscriber_detected_objects_;
53 | image_transport::Subscriber subscriber_image_;
54 |
55 | message_filters::Subscriber<autoware_tracker::DetectedObjectArray> *detection_filter_subscriber_;
56 | message_filters::Subscriber<sensor_msgs::Image> *image_filter_subscriber_;
57 |
58 | ros::Publisher publisher_image_;
59 |
60 | cv::Mat image_;
61 | std_msgs::Header image_header_;
62 |
63 | typedef
64 | message_filters::sync_policies::ApproximateTime<sensor_msgs::Image,
65 |   autoware_tracker::DetectedObjectArray> SyncPolicyT;
66 |
67 | message_filters::Synchronizer<SyncPolicyT>
68 | *detections_synchronizer_;
69 |
70 | void
71 | SyncedDetectionsCallback(
72 | const sensor_msgs::Image::ConstPtr &in_image_msg,
73 | const autoware_tracker::DetectedObjectArray::ConstPtr &in_range_detections);
74 |
75 | bool IsObjectValid(const autoware_tracker::DetectedObject &in_object);
76 |
77 | cv::Mat ObjectsToRects(cv::Mat in_image, const autoware_tracker::DetectedObjectArray::ConstPtr &in_objects);
78 |
79 | public:
80 | VisualizeRects();
81 | };
82 |
83 | #endif // _VISUALIZERECTS_H
84 |
--------------------------------------------------------------------------------
/autoware_tracker/src/detected_objects_visualizer/visualize_rects_main.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | *
16 | ********************
17 | * v1.0: amc-nu (abrahammonrroy@yahoo.com)
18 | */
19 |
20 | #include "visualize_rects.h"
21 |
22 | int main(int argc, char** argv)
23 | {
24 | ros::init(argc, argv, "visualize_rects");
25 | VisualizeRects app;
26 | ros::spin();
27 |
28 | return 0;
29 | }
30 |
--------------------------------------------------------------------------------
/autoware_tracker/src/lidar_euclidean_cluster_detect/cluster.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Cluster.h
3 | *
4 | * Created on: Oct 19, 2016
5 | * Author: Ne0
6 | */
7 | #ifndef CLUSTER_H_
8 | #define CLUSTER_H_
9 |
10 | #include
11 | #include
12 |
13 | #include
14 | #include
15 | #include
16 | #include
17 | #include
18 |
19 | #include
20 | #include
21 |
22 | #include
23 | #include
24 | #include
25 |
26 | #include
27 | #include
28 | #include
29 |
30 | #include
31 |
32 | #include
33 | #include
34 |
35 | #include
36 | #include
37 | #include
38 |
39 | #include
40 | #include
41 |
42 | #include
43 | #include
44 |
45 | #include
46 |
47 | #include
48 | #include
49 | #include
50 |
51 | #include "autoware_tracker/CloudCluster.h"
52 |
53 | #include "opencv2/core/core.hpp"
54 | #include "opencv2/imgproc/imgproc.hpp"
55 |
56 | #include
57 | #include
58 | #include
59 |
60 | class Cluster
61 | {
62 | pcl::PointCloud<pcl::PointXYZI>::Ptr pointcloud_;
63 | pcl::PointXYZI min_point_;
64 | pcl::PointXYZI max_point_;
65 | pcl::PointXYZI average_point_;
66 | pcl::PointXYZI centroid_;
67 | double orientation_angle_;
68 | float length_, width_, height_;
69 |
70 | jsk_recognition_msgs::BoundingBox bounding_box_;
71 | geometry_msgs::PolygonStamped polygon_;
72 |
73 | std::string label_;
74 | int id_;
75 | int r_, g_, b_;
76 |
77 | Eigen::Matrix3f eigen_vectors_;
78 | Eigen::Vector3f eigen_values_;
79 |
80 | bool valid_cluster_;
81 |
82 | public:
83 | /* \brief Constructor. Creates a Cluster object using the specified points in a PointCloud
84 | * \param[in] in_origin_cloud_ptr Origin PointCloud
85 | * \param[in] in_cluster_indices Indices of the Origin Pointcloud to create the Cluster
86 | * \param[in] in_id ID of the cluster
87 | * \param[in] in_r Amount of Red [0-255]
88 | * \param[in] in_g Amount of Green [0-255]
89 | * \param[in] in_b Amount of Blue [0-255]
90 | * \param[in] in_label Label to identify this cluster (optional)
91 | * \param[in] in_estimate_pose Flag to enable Pose Estimation of the Bounding Box
92 | * */
93 | void SetCloud(const pcl::PointCloud<pcl::PointXYZI>::Ptr in_origin_cloud_ptr,
94 |               const std::vector<int>& in_cluster_indices, std_msgs::Header in_ros_header, int in_id, int in_r,
95 |               int in_g, int in_b, std::string in_label, bool in_estimate_pose);
96 |
97 | /* \brief Returns the autoware_tracker::CloudCluster message associated to this Cluster */
98 | void ToROSMessage(std_msgs::Header in_ros_header, autoware_tracker::CloudCluster& out_cluster_message);
99 |
100 | Cluster();
101 | virtual ~Cluster();
102 |
103 | /* \brief Returns the pointer to the PointCloud containing the points in this Cluster */
104 | pcl::PointCloud<pcl::PointXYZI>::Ptr GetCloud();
105 | /* \brief Returns the minimum point in the cluster */
106 | pcl::PointXYZI GetMinPoint();
107 | /* \brief Returns the maximum point in the cluster*/
108 | pcl::PointXYZI GetMaxPoint();
109 | /* \brief Returns the average point in the cluster*/
110 | pcl::PointXYZI GetAveragePoint();
111 | /* \brief Returns the centroid point in the cluster */
112 | pcl::PointXYZI GetCentroid();
113 | /* \brief Returns the calculated BoundingBox of the object */
114 | jsk_recognition_msgs::BoundingBox GetBoundingBox();
115 | /* \brief Returns the calculated PolygonArray of the object */
116 | geometry_msgs::PolygonStamped GetPolygon();
117 | /* \brief Returns the angle in radians of the BoundingBox. 0 if pose estimation was not enabled. */
118 | double GetOrientationAngle();
119 | /* \brief Returns the Length of the Cluster */
120 | float GetLenght();
121 | /* \brief Returns the Width of the Cluster */
122 | float GetWidth();
123 | /* \brief Returns the Height of the Cluster */
124 | float GetHeight();
125 | /* \brief Returns the Id of the Cluster */
126 | int GetId();
127 | /* \brief Returns the Label of the Cluster */
128 | std::string GetLabel();
129 | /* \brief Returns the Eigen Vectors of the cluster */
130 | Eigen::Matrix3f GetEigenVectors();
131 | /* \brief Returns the Eigen Values of the Cluster */
132 | Eigen::Vector3f GetEigenValues();
133 |
134 | /* \brief Returns if the Cluster is marked as valid or not*/
135 | bool IsValid();
136 | /* \brief Sets whether the Cluster is valid or not*/
137 | void SetValidity(bool in_valid);
138 |
139 | /* \brief Returns a pointer to a PointCloud object containing the merged points between current Cluster and the
140 | * specified PointCloud
141 | * \param[in] in_cloud_ptr Origin PointCloud
142 | * */
143 | pcl::PointCloud<pcl::PointXYZI>::Ptr JoinCloud(const pcl::PointCloud<pcl::PointXYZI>::Ptr in_cloud_ptr);
144 |
145 | /* \brief Calculates and returns a pointer to the FPFH Descriptor of this cluster
146 | *
147 | */
148 | std::vector<float> GetFpfhDescriptor(const unsigned int& in_ompnum_threads, const double& in_normal_search_radius,
149 |                                      const double& in_fpfh_search_radius);
150 | };
151 |
152 | typedef boost::shared_ptr<Cluster> ClusterPtr;
153 |
154 | #endif /* CLUSTER_H_ */
155 |
--------------------------------------------------------------------------------
/autoware_tracker/src/lidar_euclidean_cluster_detect/gencolors.cpp:
--------------------------------------------------------------------------------
1 | #ifndef GENCOLORS_CPP_
2 | #define GENCOLORS_CPP_
3 |
4 | /*M///////////////////////////////////////////////////////////////////////////////////////
5 | //
6 | // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
7 | //
8 | // By downloading, copying, installing or using the software you agree to this license.
9 | // If you do not agree to this license, do not download, install,
10 | // copy or use the software.
11 | //
12 | //
13 | // License Agreement
14 | // For Open Source Computer Vision Library
15 | //
16 | // Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
17 | // Copyright (C) 2009, Willow Garage Inc., all rights reserved.
18 | // Third party copyrights are property of their respective owners.
19 | //
20 | // Redistribution and use in source and binary forms, with or without modification,
21 | // are permitted provided that the following conditions are met:
22 | //
23 | // * Redistribution's of source code must retain the above copyright notice,
24 | // this list of conditions and the following disclaimer.
25 | //
26 | // * Redistribution's in binary form must reproduce the above copyright notice,
27 | // this list of conditions and the following disclaimer in the documentation
28 | // and/or other materials provided with the distribution.
29 | //
30 | // * The name of the copyright holders may not be used to endorse or promote products
31 | // derived from this software without specific prior written permission.
32 | //
33 | // This software is provided by the copyright holders and contributors "as is" and
34 | // any express or implied warranties, including, but not limited to, the implied
35 | // warranties of merchantability and fitness for a particular purpose are disclaimed.
36 | // In no event shall the Intel Corporation or contributors be liable for any direct,
37 | // indirect, incidental, special, exemplary, or consequential damages
38 | // (including, but not limited to, procurement of substitute goods or services;
39 | // loss of use, data, or profits; or business interruption) however caused
40 | // and on any theory of liability, whether in contract, strict liability,
41 | // or tort (including negligence or otherwise) arising in any way out of
42 | // the use of this software, even if advised of the possibility of such damage.
43 | //
44 | //M*/
45 | #include "opencv2/core/core.hpp"
46 | //#include "precomp.hpp"
47 | #include
48 |
49 | #include
50 |
51 | using namespace cv;
52 |
53 | static void downsamplePoints(const Mat& src, Mat& dst, size_t count)
54 | {
55 | CV_Assert(count >= 2);
56 | CV_Assert(src.cols == 1 || src.rows == 1);
57 | CV_Assert(src.total() >= count);
58 | CV_Assert(src.type() == CV_8UC3);
59 |
60 | dst.create(1, (int)count, CV_8UC3);
61 | // TODO: optimize by exploiting symmetry in the distance matrix
62 | Mat dists((int)src.total(), (int)src.total(), CV_32FC1, Scalar(0));
63 | if (dists.empty())
64 | std::cerr << "Such a big matrix can't be created." << std::endl;
65 |
66 | for (int i = 0; i < dists.rows; i++)
67 | {
68 | for (int j = i; j < dists.cols; j++)
69 | {
70 | float dist = (float)norm(src.at<Point3_<uchar> >(i) - src.at<Point3_<uchar> >(j));
71 | dists.at<float>(j, i) = dists.at<float>(i, j) = dist;
72 | }
73 | }
74 |
75 | double maxVal;
76 | Point maxLoc;
77 | minMaxLoc(dists, 0, &maxVal, 0, &maxLoc);
78 |
79 | dst.at<Point3_<uchar> >(0) = src.at<Point3_<uchar> >(maxLoc.x);
80 | dst.at<Point3_<uchar> >(1) = src.at<Point3_<uchar> >(maxLoc.y);
81 |
82 | Mat activedDists(0, dists.cols, dists.type());
83 | Mat candidatePointsMask(1, dists.cols, CV_8UC1, Scalar(255));
84 | activedDists.push_back(dists.row(maxLoc.y));
85 | candidatePointsMask.at<uchar>(0, maxLoc.y) = 0;
86 |
87 | for (size_t i = 2; i < count; i++)
88 | {
89 | activedDists.push_back(dists.row(maxLoc.x));
90 | candidatePointsMask.at<uchar>(0, maxLoc.x) = 0;
91 |
92 | Mat minDists;
93 | reduce(activedDists, minDists, 0, CV_REDUCE_MIN);
94 | minMaxLoc(minDists, 0, &maxVal, 0, &maxLoc, candidatePointsMask);
95 | dst.at<Point3_<uchar> >((int)i) = src.at<Point3_<uchar> >(maxLoc.x);
96 | }
97 | }
98 |
99 | void generateColors(std::vector<Scalar>& colors, size_t count, size_t factor = 10)
100 | {
101 | if (count < 1)
102 | return;
103 |
104 | colors.resize(count);
105 |
106 | if (count == 1)
107 | {
108 | colors[0] = Scalar(0, 0, 255); // red
109 | return;
110 | }
111 | if (count == 2)
112 | {
113 | colors[0] = Scalar(0, 0, 255); // red
114 | colors[1] = Scalar(0, 255, 0); // green
115 | return;
116 | }
117 |
118 | // Generate a set of colors in RGB space. The size of the set is several times (= factor) larger than
119 | // the needed count of colors.
120 | Mat bgr(1, (int)(count * factor), CV_8UC3);
121 | randu(bgr, 0, 256);
122 |
123 | // Convert the colors set to Lab space.
124 | // Distances between colors in this space correspond to human perception.
125 | Mat lab;
126 | cvtColor(bgr, lab, cv::COLOR_BGR2Lab);
127 |
128 | // Subsample colors from the generated set so that
129 | // to maximize the minimum distances between each other.
130 | // Douglas-Peucker algorithm is used for this.
131 | Mat lab_subset;
132 | downsamplePoints(lab, lab_subset, count);
133 |
134 | // Convert subsampled colors back to RGB
135 | Mat bgr_subset;
136 | cvtColor(lab_subset, bgr_subset, cv::COLOR_Lab2BGR);
137 |
138 | CV_Assert(bgr_subset.total() == count);
139 | for (size_t i = 0; i < count; i++)
140 | {
141 | Point3_<uchar> c = bgr_subset.at<Point3_<uchar> >((int)i);
142 | colors[i] = Scalar(c.x, c.y, c.z);
143 | }
144 | }
145 |
146 | #endif // GENCOLORS_CPP
147 |
--------------------------------------------------------------------------------
/autoware_tracker/src/lidar_imm_ukf_pda_track/imm_ukf_pda.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #ifndef OBJECT_TRACKING_IMM_UKF_JPDAF_H
18 | #define OBJECT_TRACKING_IMM_UKF_JPDAF_H
19 |
20 |
21 | #include
22 | #include
23 | #include
24 |
25 |
26 | #include
27 | #include
28 |
29 | #include
30 | #include
31 | #include
32 | #include
33 |
34 | #include
35 |
36 | // #include
37 |
38 | #include "autoware_tracker/DetectedObject.h"
39 | #include "autoware_tracker/DetectedObjectArray.h"
40 |
41 | #include "ukf.h"
42 |
43 | class ImmUkfPda
44 | {
45 | private:
46 | int target_id_;
47 | bool init_;
48 | double timestamp_;
49 |
50 | std::vector<UKF> targets_;
51 |
52 | // probabilistic data association params
53 | double gating_thres_;
54 | double gate_probability_;
55 | double detection_probability_;
56 |
57 | // object association param
58 | int life_time_thres_;
59 |
60 | // static classification param
61 | double static_velocity_thres_;
62 | int static_num_history_thres_;
63 |
64 | // switch sukf and ImmUkfPda
65 | bool use_sukf_;
66 |
67 | // whether benchmarking the tracking result
68 | bool is_benchmark_;
69 | int frame_count_;
70 | std::string kitti_data_dir_;
71 |
72 | // for benchmark
73 | std::string result_file_path_;
74 |
75 | // prevent-explosion parameter for the UKF
76 | double prevent_explosion_thres_;
77 |
78 | // for vectormap-assisted tracking
79 | bool use_vectormap_;
80 | bool has_subscribed_vectormap_;
81 | double lane_direction_chi_thres_;
82 | double nearest_lane_distance_thres_;
83 | std::string vectormap_frame_;
84 | // vector_map::VectorMap vmap_;
85 | // std::vector lanes_;
86 |
87 | double merge_distance_threshold_;
88 | const double CENTROID_DISTANCE = 0.2;//distance to consider centroids the same
89 |
90 | std::string input_topic_;
91 | std::string output_topic_;
92 |
93 | std::string tracking_frame_;
94 |
95 | tf::TransformListener tf_listener_;
96 | tf::StampedTransform local2global_;
97 | tf::StampedTransform tracking_frame2lane_frame_;
98 | tf::StampedTransform lane_frame2tracking_frame_;
99 |
100 | ros::NodeHandle node_handle_;
101 | ros::NodeHandle private_nh_;
102 | ros::Subscriber sub_detected_array_;
103 | ros::Publisher pub_object_array_;
104 | //yang21itsc
105 | ros::Publisher pub_example_array_;
106 | ros::Publisher vis_examples_;
107 | //yang21itsc
108 |
109 | std_msgs::Header input_header_;
110 |
111 | // yang21itsc
112 | std::vector<autoware_tracker::DetectedObject> learning_buffer; // assumed element type; template argument lost in extraction
113 | double track_probability_;
114 | // yang21itsc
115 |
116 | void callback(const autoware_tracker::DetectedObjectArray& input);
117 |
118 | void transformPoseToGlobal(const autoware_tracker::DetectedObjectArray& input,
119 | autoware_tracker::DetectedObjectArray& transformed_input);
120 | void transformPoseToLocal(autoware_tracker::DetectedObjectArray& detected_objects_output);
121 |
122 | geometry_msgs::Pose getTransformedPose(const geometry_msgs::Pose& in_pose,
123 | const tf::StampedTransform& tf_stamp);
124 |
125 | bool updateNecessaryTransform();
126 |
127 | void measurementValidation(const autoware_tracker::DetectedObjectArray& input, UKF& target, const bool second_init,
128 |                            const Eigen::VectorXd& max_det_z, const Eigen::MatrixXd& max_det_s,
129 |                            std::vector<autoware_tracker::DetectedObject>& object_vec, std::vector<bool>& matching_vec);
130 | autoware_tracker::DetectedObject getNearestObject(UKF& target,
131 |                                                   const std::vector<autoware_tracker::DetectedObject>& object_vec);
132 | void updateBehaviorState(const UKF& target, const bool use_sukf, autoware_tracker::DetectedObject& object);
133 |
134 | void initTracker(const autoware_tracker::DetectedObjectArray& input, double timestamp);
135 | void secondInit(UKF& target, const std::vector<autoware_tracker::DetectedObject>& object_vec, double dt);
136 |
137 | void updateTrackingNum(const std::vector<autoware_tracker::DetectedObject>& object_vec, UKF& target);
138 |
139 | bool probabilisticDataAssociation(const autoware_tracker::DetectedObjectArray& input, const double dt,
140 |                                   std::vector<bool>& matching_vec,
141 |                                   std::vector<autoware_tracker::DetectedObject>& object_vec, UKF& target);
142 | void makeNewTargets(const double timestamp, const autoware_tracker::DetectedObjectArray& input,
143 |                     const std::vector<bool>& matching_vec);
144 |
145 | void staticClassification();
146 |
147 | void makeOutput(const autoware_tracker::DetectedObjectArray& input,
148 |                 const std::vector<bool>& matching_vec,
149 |                 autoware_tracker::DetectedObjectArray& detected_objects_output);
150 |
151 | void removeUnnecessaryTarget();
152 |
153 | void dumpResultText(autoware_tracker::DetectedObjectArray& detected_objects);
154 |
155 | void tracker(const autoware_tracker::DetectedObjectArray& transformed_input,
156 | autoware_tracker::DetectedObjectArray& detected_objects_output);
157 |
158 | bool updateDirection(const double smallest_nis, const autoware_tracker::DetectedObject& in_object,
159 | autoware_tracker::DetectedObject& out_object, UKF& target);
160 |
161 | bool storeObjectWithNearestLaneDirection(const autoware_tracker::DetectedObject& in_object,
162 | autoware_tracker::DetectedObject& out_object);
163 |
164 | void checkVectormapSubscription();
165 |
166 | autoware_tracker::DetectedObjectArray
167 | removeRedundantObjects(const autoware_tracker::DetectedObjectArray& in_detected_objects,
168 | const std::vector<size_t> in_tracker_indices);
169 |
170 | autoware_tracker::DetectedObjectArray
171 | forwardNonMatchedObject(const autoware_tracker::DetectedObjectArray& tmp_objects,
172 | const autoware_tracker::DetectedObjectArray& input,
173 | const std::vector<bool>& matching_vec);
174 |
175 | bool
176 | arePointsClose(const geometry_msgs::Point& in_point_a,
177 | const geometry_msgs::Point& in_point_b,
178 | float in_radius);
179 |
180 | bool
181 | arePointsEqual(const geometry_msgs::Point& in_point_a,
182 | const geometry_msgs::Point& in_point_b);
183 |
184 | bool
185 | isPointInPool(const std::vector<geometry_msgs::Point>& in_pool,
186 | const geometry_msgs::Point& in_point);
187 |
188 | void updateTargetWithAssociatedObject(const std::vector<autoware_tracker::DetectedObject>& object_vec,
189 | UKF& target);
190 |
191 | public:
192 | ImmUkfPda();
193 | void run();
194 | };
195 |
196 | #endif /* OBJECT_TRACKING_IMM_UKF_JPDAF_H */
197 |
--------------------------------------------------------------------------------
/autoware_tracker/src/lidar_imm_ukf_pda_track/imm_ukf_pda_main.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018-2019 Autoware Foundation. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include "imm_ukf_pda.h"
18 |
19 | int main(int argc, char** argv)
20 | {
21 | ros::init(argc, argv, "imm_ukf_pda_tracker");
22 | ros::console::set_logger_level(ROSCONSOLE_DEFAULT_NAME, ros::console::levels::Warn);
23 |
24 | ImmUkfPda app;
25 | app.run();
26 | ros::spin();
27 | return 0;
28 | }
29 |
--------------------------------------------------------------------------------
/efficient_det_ros/__pycache__/backbone.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/__pycache__/backbone.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/backbone.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet
7 | from efficientdet.utils import Anchors
8 |
9 |
10 | class EfficientDetBackbone(nn.Module):
11 | def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs):
12 | super(EfficientDetBackbone, self).__init__()
13 | self.compound_coef = compound_coef
14 |
15 | self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7]
16 | self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384]
17 | self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8]
18 | self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
19 | self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5]
20 | self.pyramid_levels = [5, 5, 5, 5, 5, 5, 5, 5, 6]
21 | self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5., 4.]
22 | self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
23 | self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
24 | conv_channel_coef = {
25 | # the channels of P3/P4/P5.
26 | 0: [40, 112, 320],
27 | 1: [40, 112, 320],
28 | 2: [48, 120, 352],
29 | 3: [48, 136, 384],
30 | 4: [56, 160, 448],
31 | 5: [64, 176, 512],
32 | 6: [72, 200, 576],
33 | 7: [72, 200, 576],
34 | 8: [80, 224, 640],
35 | }
36 |
37 | num_anchors = len(self.aspect_ratios) * self.num_scales
38 |
39 | self.bifpn = nn.Sequential(
40 | *[BiFPN(self.fpn_num_filters[self.compound_coef],
41 | conv_channel_coef[compound_coef],
42 | True if _ == 0 else False,
43 | attention=True if compound_coef < 6 else False,
44 | use_p8=compound_coef > 7)
45 | for _ in range(self.fpn_cell_repeats[compound_coef])])
46 |
47 | self.num_classes = num_classes
48 | self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
49 | num_layers=self.box_class_repeats[self.compound_coef],
50 | pyramid_levels=self.pyramid_levels[self.compound_coef])
51 | self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
52 | num_classes=num_classes,
53 | num_layers=self.box_class_repeats[self.compound_coef],
54 | pyramid_levels=self.pyramid_levels[self.compound_coef])
55 |
56 | self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
57 | pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(),
58 | **kwargs)
59 |
60 | self.backbone_net = EfficientNet(self.backbone_compound_coef[compound_coef], load_weights)
61 |
62 | def freeze_bn(self):
63 | for m in self.modules():
64 | if isinstance(m, nn.BatchNorm2d):
65 | m.eval()
66 |
67 | def forward(self, inputs):
68 | max_size = inputs.shape[-1]
69 |
70 | _, p3, p4, p5 = self.backbone_net(inputs)
71 |
72 | features = (p3, p4, p5)
73 | features = self.bifpn(features)
74 |
75 | regression = self.regressor(features)
76 | classification = self.classifier(features)
77 | anchors = self.anchors(inputs, inputs.dtype)
78 |
79 | return features, regression, classification, anchors
80 |
81 | def init_backbone(self, path):
82 | state_dict = torch.load(path)
83 | try:
84 | ret = self.load_state_dict(state_dict, strict=False)
85 | print(ret)
86 | except RuntimeError as e:
87 | print('Ignoring ' + str(e))
88 |
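
A minimal usage sketch for the backbone above (illustrative, not part of the file; run from the efficient_det_ros root). num_classes=3 mirrors projects/kitti.yml, and 512 is input_sizes[0], which is divisible by the largest stride 2^7 = 128:

import torch
from backbone import EfficientDetBackbone

model = EfficientDetBackbone(num_classes=3, compound_coef=0)  # default COCO anchors
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)  # NCHW dummy image
    features, regression, classification, anchors = model(x)

# 9 anchors per cell (3 ratios x 3 scales) over P3..P7: 5456 cells -> 49104 boxes
print(regression.shape)      # torch.Size([1, 49104, 4])
print(classification.shape)  # torch.Size([1, 49104, 3])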
--------------------------------------------------------------------------------
/efficient_det_ros/coco_eval.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | """
4 | COCO-Style Evaluations
5 |
6 | put images here datasets/your_project_name/val_set_name/*.jpg
7 | put annotations here datasets/your_project_name/annotations/instances_{val_set_name}.json
8 | put weights here /path/to/your/weights/*.pth
9 | change compound_coef
10 |
11 | """
12 |
13 | import json
14 | import os
15 |
16 | import argparse
17 | import torch
18 | import yaml
19 | from tqdm import tqdm
20 | from pycocotools.coco import COCO
21 | from pycocotools.cocoeval import COCOeval
22 |
23 | from backbone import EfficientDetBackbone
24 | from efficientdet.utils import BBoxTransform, ClipBoxes
25 | from utils.utils import preprocess, invert_affine, postprocess, boolean_string
26 |
27 | ap = argparse.ArgumentParser()
28 | ap.add_argument('-p', '--project', type=str, default='coco', help='project file that contains parameters')
29 | ap.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
30 | ap.add_argument('-w', '--weights', type=str, default=None, help='/path/to/weights')
31 | ap.add_argument('--nms_threshold', type=float, default=0.5, help='nms threshold, don\'t change it if not for testing purposes')
32 | ap.add_argument('--cuda', type=boolean_string, default=True)
33 | ap.add_argument('--device', type=int, default=0)
34 | ap.add_argument('--float16', type=boolean_string, default=False)
35 | ap.add_argument('--override', type=boolean_string, default=True, help='override previous bbox results file if exists')
36 | args = ap.parse_args()
37 |
38 | compound_coef = args.compound_coef
39 | nms_threshold = args.nms_threshold
40 | use_cuda = args.cuda
41 | gpu = args.device
42 | use_float16 = args.float16
43 | override_prev_results = args.override
44 | project_name = args.project
45 | weights_path = f'weights/efficientdet-d{compound_coef}.pth' if args.weights is None else args.weights
46 |
47 | print(f'running coco-style evaluation on project {project_name}, weights {weights_path}...')
48 |
49 | params = yaml.safe_load(open(f'projects/{project_name}.yml'))
50 | obj_list = params['obj_list']
51 |
52 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
53 |
54 |
55 | def evaluate_coco(img_path, set_name, image_ids, coco, model, threshold=0.05):
56 | results = []
57 |
58 | regressBoxes = BBoxTransform()
59 | clipBoxes = ClipBoxes()
60 |
61 | for image_id in tqdm(image_ids):
62 | image_info = coco.loadImgs(image_id)[0]
63 | image_path = img_path + image_info['file_name']
64 |
65 | ori_imgs, framed_imgs, framed_metas = preprocess(image_path, max_size=input_sizes[compound_coef], mean=params['mean'], std=params['std'])
66 | x = torch.from_numpy(framed_imgs[0])
67 |
68 | if use_cuda:
69 | x = x.cuda(gpu)
70 | if use_float16:
71 | x = x.half()
72 | else:
73 | x = x.float()
74 | else:
75 | x = x.float()
76 |
77 | x = x.unsqueeze(0).permute(0, 3, 1, 2)
78 | features, regression, classification, anchors = model(x)
79 |
80 | preds = postprocess(x,
81 | anchors, regression, classification,
82 | regressBoxes, clipBoxes,
83 | threshold, nms_threshold)
84 |
85 | if not preds:
86 | continue
87 |
88 | preds = invert_affine(framed_metas, preds)[0]
89 |
90 | scores = preds['scores']
91 | class_ids = preds['class_ids']
92 | rois = preds['rois']
93 |
94 | if rois.shape[0] > 0:
95 | # x1,y1,x2,y2 -> x1,y1,w,h
96 | rois[:, 2] -= rois[:, 0]
97 | rois[:, 3] -= rois[:, 1]
98 |
99 | bbox_score = scores
100 |
101 | for roi_id in range(rois.shape[0]):
102 | score = float(bbox_score[roi_id])
103 | label = int(class_ids[roi_id])
104 | box = rois[roi_id, :]
105 |
106 | image_result = {
107 | 'image_id': image_id,
108 | 'category_id': label + 1,
109 | 'score': float(score),
110 | 'bbox': box.tolist(),
111 | }
112 |
113 | results.append(image_result)
114 |
115 | if not len(results):
116 | raise Exception('the model does not provide any valid output, check model architecture and the data input')
117 |
118 | # write output
119 | filepath = f'{set_name}_bbox_results.json'
120 | if os.path.exists(filepath):
121 | os.remove(filepath)
122 | json.dump(results, open(filepath, 'w'), indent=4)
123 |
124 |
125 | def _eval(coco_gt, image_ids, pred_json_path):
126 | # load results in COCO evaluation tool
127 | coco_pred = coco_gt.loadRes(pred_json_path)
128 |
129 | # run COCO evaluation
130 | print('BBox')
131 | coco_eval = COCOeval(coco_gt, coco_pred, 'bbox')
132 | coco_eval.params.imgIds = image_ids
133 | coco_eval.evaluate()
134 | coco_eval.accumulate()
135 | coco_eval.summarize()
136 |
137 |
138 | if __name__ == '__main__':
139 | SET_NAME = params['val_set']
140 | VAL_GT = f'datasets/{params["project_name"]}/annotations/instances_{SET_NAME}.json'
141 | VAL_IMGS = f'datasets/{params["project_name"]}/{SET_NAME}/'
142 | MAX_IMAGES = 10000
143 | coco_gt = COCO(VAL_GT)
144 | image_ids = coco_gt.getImgIds()[:MAX_IMAGES]
145 |
146 | if override_prev_results or not os.path.exists(f'{SET_NAME}_bbox_results.json'):
147 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
148 | ratios=eval(params['anchors_ratios']), scales=eval(params['anchors_scales']))
149 | model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
150 | model.requires_grad_(False)
151 | model.eval()
152 |
153 | if use_cuda:
154 | model.cuda(gpu)
155 |
156 | if use_float16:
157 | model.half()
158 |
159 | evaluate_coco(VAL_IMGS, SET_NAME, image_ids, coco_gt, model)
160 |
161 | _eval(coco_gt, image_ids, f'{SET_NAME}_bbox_results.json')
162 |
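
A typical invocation against the KITTI project defined in projects/kitti.yml (the weights path is an assumption) would be `python coco_eval.py -p kitti -c 2 -w weights/efficientdet-d2.pth`: the script then reads validation images from datasets/KITTI_3D_Object/val/, writes val_bbox_results.json, and scores it with COCOeval.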
--------------------------------------------------------------------------------
/efficient_det_ros/efficientdet/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientdet/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/efficientdet/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientdet/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/efficientdet/config.py:
--------------------------------------------------------------------------------
1 | COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
2 | "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog",
3 | "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
4 | "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
5 | "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
6 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
7 | "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
8 | "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
9 | "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
10 | "teddy bear", "hair drier", "toothbrush"]
11 |
12 | colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122),
13 | (80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17),
14 | (82, 144, 81), (92, 7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60),
15 | (16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34),
16 | (53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181),
17 | (42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108),
18 | (52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26),
19 | (122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50),
20 | (56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33),
21 | (105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91),
22 | (31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130),
23 | (31, 139, 34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3),
24 | (129, 87, 143), (49, 87, 40), (128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108),
25 | (81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95),
26 | (2, 20, 184), (122, 37, 185)]
27 |
--------------------------------------------------------------------------------
/efficient_det_ros/efficientdet/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | from torch.utils.data import Dataset, DataLoader
6 | from pycocotools.coco import COCO
7 | import cv2
8 |
9 |
10 | class CocoDataset(Dataset):
11 | def __init__(self, root_dir, set='train2017', transform=None):
12 |
13 | self.root_dir = root_dir
14 | self.set_name = set
15 | self.transform = transform
16 |
17 | self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json'))
18 | self.image_ids = self.coco.getImgIds()
19 |
20 | self.load_classes()
21 |
22 | def load_classes(self):
23 |
24 | # load class names (name -> label)
25 | categories = self.coco.loadCats(self.coco.getCatIds())
26 | categories.sort(key=lambda x: x['id'])
27 |
28 | self.classes = {}
29 | for c in categories:
30 | self.classes[c['name']] = len(self.classes)
31 |
32 | # also load the reverse (label -> name)
33 | self.labels = {}
34 | for key, value in self.classes.items():
35 | self.labels[value] = key
36 |
37 | def __len__(self):
38 | return len(self.image_ids)
39 |
40 | def __getitem__(self, idx):
41 |
42 | img = self.load_image(idx)
43 | annot = self.load_annotations(idx)
44 | sample = {'img': img, 'annot': annot}
45 | if self.transform:
46 | sample = self.transform(sample)
47 | return sample
48 |
49 | def load_image(self, image_index):
50 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
51 | path = os.path.join(self.root_dir, self.set_name, image_info['file_name'])
52 | img = cv2.imread(path)
53 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
54 |
55 | return img.astype(np.float32) / 255.
56 |
57 | def load_annotations(self, image_index):
58 | # get ground truth annotations
59 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
60 | annotations = np.zeros((0, 5))
61 |
62 | # some images appear to miss annotations
63 | if len(annotations_ids) == 0:
64 | return annotations
65 |
66 | # parse annotations
67 | coco_annotations = self.coco.loadAnns(annotations_ids)
68 | for idx, a in enumerate(coco_annotations):
69 |
70 | # some annotations have basically no width / height, skip them
71 | if a['bbox'][2] < 1 or a['bbox'][3] < 1:
72 | continue
73 |
74 | annotation = np.zeros((1, 5))
75 | annotation[0, :4] = a['bbox']
76 | annotation[0, 4] = a['category_id'] - 1
77 | annotations = np.append(annotations, annotation, axis=0)
78 |
79 | # transform from [x, y, w, h] to [x1, y1, x2, y2]
80 | annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
81 | annotations[:, 3] = annotations[:, 1] + annotations[:, 3]
82 |
83 | return annotations
84 |
85 |
86 | def collater(data):
87 | imgs = [s['img'] for s in data]
88 | annots = [s['annot'] for s in data]
89 | scales = [s['scale'] for s in data]
90 |
91 | imgs = torch.from_numpy(np.stack(imgs, axis=0))
92 |
93 | max_num_annots = max(annot.shape[0] for annot in annots)
94 |
95 | if max_num_annots > 0:
96 |
97 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1
98 |
99 | for idx, annot in enumerate(annots):
100 | if annot.shape[0] > 0:
101 | annot_padded[idx, :annot.shape[0], :] = annot
102 | else:
103 | annot_padded = torch.ones((len(annots), 1, 5)) * -1
104 |
105 | imgs = imgs.permute(0, 3, 1, 2)
106 |
107 | return {'img': imgs, 'annot': annot_padded, 'scale': scales}
108 |
109 |
110 | class Resizer(object):
111 | """Convert ndarrays in sample to Tensors."""
112 |
113 | def __init__(self, img_size=512):
114 | self.img_size = img_size
115 |
116 | def __call__(self, sample):
117 | image, annots = sample['img'], sample['annot']
118 | height, width, _ = image.shape
119 | if height > width:
120 | scale = self.img_size / height
121 | resized_height = self.img_size
122 | resized_width = int(width * scale)
123 | else:
124 | scale = self.img_size / width
125 | resized_height = int(height * scale)
126 | resized_width = self.img_size
127 |
128 | image = cv2.resize(image, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)
129 |
130 | new_image = np.zeros((self.img_size, self.img_size, 3))
131 | new_image[0:resized_height, 0:resized_width] = image
132 |
133 | annots[:, :4] *= scale
134 |
135 | return {'img': torch.from_numpy(new_image).to(torch.float32), 'annot': torch.from_numpy(annots), 'scale': scale}
136 |
137 |
138 | class Augmenter(object):
139 | """Convert ndarrays in sample to Tensors."""
140 |
141 | def __call__(self, sample, flip_x=0.5):
142 | if np.random.rand() < flip_x:
143 | image, annots = sample['img'], sample['annot']
144 | image = image[:, ::-1, :]
145 |
146 | rows, cols, channels = image.shape
147 |
148 | x1 = annots[:, 0].copy()
149 | x2 = annots[:, 2].copy()
150 |
151 | x_tmp = x1.copy()
152 |
153 | annots[:, 0] = cols - x2
154 | annots[:, 2] = cols - x_tmp
155 |
156 | sample = {'img': image, 'annot': annots}
157 |
158 | return sample
159 |
160 |
161 | class Normalizer(object):
162 |
163 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
164 | self.mean = np.array([[mean]])
165 | self.std = np.array([[std]])
166 |
167 | def __call__(self, sample):
168 | image, annots = sample['img'], sample['annot']
169 |
170 | return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots}
171 |
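
A short wiring sketch (illustrative; the dataset root is an assumption that follows projects/kitti.yml and presumes a COCO-format annotation file). The transforms compose as Normalizer -> Augmenter -> Resizer, and collater pads annotations to the batch maximum with -1:

from torch.utils.data import DataLoader
from torchvision import transforms

from efficientdet.dataset import CocoDataset, Normalizer, Augmenter, Resizer, collater

train_set = CocoDataset(root_dir='datasets/KITTI_3D_Object', set='train',
                        transform=transforms.Compose([Normalizer(), Augmenter(), Resizer(512)]))
loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=collater)

batch = next(iter(loader))
print(batch['img'].shape)    # torch.Size([4, 3, 512, 512]), NCHW after the permute
print(batch['annot'].shape)  # torch.Size([4, max_annots_in_batch, 5])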
--------------------------------------------------------------------------------
/efficient_det_ros/efficientdet/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 |
7 | class BBoxTransform(nn.Module):
8 | def forward(self, anchors, regression):
9 | """
10 | decode_box_outputs adapted from https://github.com/google/automl/blob/master/efficientdet/anchors.py
11 |
12 | Args:
13 | anchors: [batchsize, boxes, (y1, x1, y2, x2)]
14 | regression: [batchsize, boxes, (dy, dx, dh, dw)]
15 |
16 | Returns:
17 |
18 | """
19 | y_centers_a = (anchors[..., 0] + anchors[..., 2]) / 2
20 | x_centers_a = (anchors[..., 1] + anchors[..., 3]) / 2
21 | ha = anchors[..., 2] - anchors[..., 0]
22 | wa = anchors[..., 3] - anchors[..., 1]
23 |
24 | w = regression[..., 3].exp() * wa
25 | h = regression[..., 2].exp() * ha
26 |
27 | y_centers = regression[..., 0] * ha + y_centers_a
28 | x_centers = regression[..., 1] * wa + x_centers_a
29 |
30 | ymin = y_centers - h / 2.
31 | xmin = x_centers - w / 2.
32 | ymax = y_centers + h / 2.
33 | xmax = x_centers + w / 2.
34 |
35 | return torch.stack([xmin, ymin, xmax, ymax], dim=2)
36 |
37 |
38 | class ClipBoxes(nn.Module):
39 |
40 | def __init__(self):
41 | super(ClipBoxes, self).__init__()
42 |
43 | def forward(self, boxes, img):
44 | batch_size, num_channels, height, width = img.shape
45 |
46 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
47 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)
48 |
49 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width - 1)
50 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height - 1)
51 |
52 | return boxes
53 |
54 |
55 | class Anchors(nn.Module):
56 | """
57 | adapted and modified from https://github.com/google/automl/blob/master/efficientdet/anchors.py by Zylo117
58 | """
59 |
60 | def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs):
61 | super().__init__()
62 | self.anchor_scale = anchor_scale
63 |
64 | if pyramid_levels is None:
65 | self.pyramid_levels = [3, 4, 5, 6, 7]
66 | else:
67 | self.pyramid_levels = pyramid_levels
68 |
69 | self.strides = kwargs.get('strides', [2 ** x for x in self.pyramid_levels])
70 | self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
71 | self.ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
72 |
73 | self.last_anchors = {}
74 | self.last_shape = None
75 |
76 | def forward(self, image, dtype=torch.float32):
77 | """Generates multiscale anchor boxes.
78 |
79 | Args:
80 | image_size: integer input image size. The input image has the same
81 | width and height, and image_size should be divisible by the largest
82 | feature stride 2^max_level.
83 | anchor_scale: float representing the scale of the base anchor size
84 | relative to the feature stride 2^level.
85 | anchor_configs: a dictionary with keys as the levels of anchors and
86 | values as a list of anchor configuration.
87 |
88 | Returns:
89 | anchor_boxes: a tensor with shape [1, N, 4], which stacks anchors from
90 | all feature levels.
91 | Raises:
92 | ValueError: input size must be a multiple of the largest feature stride.
93 | """
94 | image_shape = image.shape[2:]
95 |
96 | if image_shape == self.last_shape and image.device in self.last_anchors:
97 | return self.last_anchors[image.device]
98 |
99 | if self.last_shape is None or self.last_shape != image_shape:
100 | self.last_shape = image_shape
101 |
102 | if dtype == torch.float16:
103 | dtype = np.float16
104 | else:
105 | dtype = np.float32
106 |
107 | boxes_all = []
108 | for stride in self.strides:
109 | boxes_level = []
110 | for scale, ratio in itertools.product(self.scales, self.ratios):
111 | if image_shape[1] % stride != 0:
112 | raise ValueError('input size must be divisible by the stride.')
113 | base_anchor_size = self.anchor_scale * stride * scale
114 | anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0
115 | anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0
116 |
117 | x = np.arange(stride / 2, image_shape[1], stride)
118 | y = np.arange(stride / 2, image_shape[0], stride)
119 | xv, yv = np.meshgrid(x, y)
120 | xv = xv.reshape(-1)
121 | yv = yv.reshape(-1)
122 |
123 | # y1,x1,y2,x2
124 | boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
125 | yv + anchor_size_y_2, xv + anchor_size_x_2))
126 | boxes = np.swapaxes(boxes, 0, 1)
127 | boxes_level.append(np.expand_dims(boxes, axis=1))
128 | # concat anchors on the same level to the reshape NxAx4
129 | boxes_level = np.concatenate(boxes_level, axis=1)
130 | boxes_all.append(boxes_level.reshape([-1, 4]))
131 |
132 | anchor_boxes = np.vstack(boxes_all)
133 |
134 | anchor_boxes = torch.from_numpy(anchor_boxes.astype(dtype)).to(image.device)
135 | anchor_boxes = anchor_boxes.unsqueeze(0)
136 |
137 | # save it for later use to reduce overhead
138 | self.last_anchors[image.device] = anchor_boxes
139 | return anchor_boxes
140 |
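
A worked check of BBoxTransform above (illustrative): with zero regression, exp(0) = 1 and the center offsets vanish, so the output is the anchor itself, only re-ordered from (y1, x1, y2, x2) to (x1, y1, x2, y2):

import torch
from efficientdet.utils import BBoxTransform

transform = BBoxTransform()
anchors = torch.tensor([[[10., 20., 50., 60.]]])  # one anchor: y1, x1, y2, x2
regression = torch.zeros(1, 1, 4)                 # dy = dx = dh = dw = 0

print(transform(anchors, regression))  # tensor([[[20., 10., 60., 50.]]])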
--------------------------------------------------------------------------------
/efficient_det_ros/efficientdet_test.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | """
4 | Simple Inference Script of EfficientDet-Pytorch
5 | """
6 | import time
7 | import torch
8 | from torch.backends import cudnn
9 | from matplotlib import colors
10 |
11 | from backbone import EfficientDetBackbone
12 | import cv2
13 | import numpy as np
14 |
15 | from efficientdet.utils import BBoxTransform, ClipBoxes
16 | from utils.utils import preprocess, invert_affine, postprocess, STANDARD_COLORS, standard_to_bgr, get_index_label, plot_one_box
17 |
18 | compound_coef = 0
19 | force_input_size = None # set None to use default size
20 | img_path = 'test/0000000000.png'
21 |
22 | # replace this part with your project's anchor config
23 | anchor_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
24 | anchor_scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
25 |
26 | threshold = 0.2
27 | iou_threshold = 0.2
28 |
29 | use_cuda = True
30 | use_float16 = False
31 | cudnn.fastest = True
32 | cudnn.benchmark = True
33 |
34 | # obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
35 | # 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
36 | # 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
37 | # 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
38 | # 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
39 | # 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
40 | # 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
41 | # 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
42 | # 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
43 | # 'toothbrush']
44 |
45 | obj_list = ['car', 'person', 'cyclist']
46 |
47 |
48 | color_list = standard_to_bgr(STANDARD_COLORS)
49 | # tf bilinear interpolation is different from any other's, just make do
50 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
51 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size
52 | ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size)
53 |
54 | if use_cuda:
55 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)
56 | else:
57 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)
58 |
59 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)
60 |
61 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
62 | ratios=anchor_ratios, scales=anchor_scales)
63 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth', map_location='cpu'))
64 | model.requires_grad_(False)
65 | model.eval()
66 |
67 | if use_cuda:
68 | model = model.cuda()
69 | if use_float16:
70 | model = model.half()
71 |
72 | with torch.no_grad():
73 | features, regression, classification, anchors = model(x)
74 |
75 | regressBoxes = BBoxTransform()
76 | clipBoxes = ClipBoxes()
77 |
78 | out = postprocess(x,
79 | anchors, regression, classification,
80 | regressBoxes, clipBoxes,
81 | threshold, iou_threshold)
82 |
83 | def display(preds, imgs, imshow=True, imwrite=False):
84 | for i in range(len(imgs)):
85 | if len(preds[i]['rois']) == 0:
86 | continue
87 |
88 | imgs[i] = imgs[i].copy()
89 |
90 | for j in range(len(preds[i]['rois'])):
91 | x1, y1, x2, y2 = preds[i]['rois'][j].astype(int)  # np.int was removed in NumPy 1.24
92 | obj = obj_list[preds[i]['class_ids'][j]]
93 | score = float(preds[i]['scores'][j])
94 | plot_one_box(imgs[i], [x1, y1, x2, y2], label=obj, score=score, color=color_list[get_index_label(obj, obj_list)])
95 |
96 |
97 | if imshow:
98 | cv2.imshow('img', imgs[i])
99 | cv2.waitKey(0)
100 |
101 | if imwrite:
102 | cv2.imwrite(f'test/img_inferred_d{compound_coef}_this_repo_{i}.jpg', imgs[i])
103 |
104 |
105 | out = invert_affine(framed_metas, out)
106 | display(out, ori_imgs, imshow=False, imwrite=True)
107 |
108 | print('running speed test...')
109 | with torch.no_grad():
110 | print('test1: model inferring and postprocessing')
111 | print('inferring image for 10 times...')
112 | t1 = time.time()
113 | for _ in range(10):
114 | _, regression, classification, anchors = model(x)
115 |
116 | out = postprocess(x,
117 | anchors, regression, classification,
118 | regressBoxes, clipBoxes,
119 | threshold, iou_threshold)
120 | out = invert_affine(framed_metas, out)
121 |
122 | t2 = time.time()
123 | tact_time = (t2 - t1) / 10
124 | print(f'{tact_time} seconds, {1 / tact_time} FPS, @batch_size 1')
125 |
126 | # uncomment this if you want an extreme fps test
127 | # print('test2: model inferring only')
128 | # print('inferring images for batch_size 32 for 10 times...')
129 | # t1 = time.time()
130 | # x = torch.cat([x] * 32, 0)
131 | # for _ in range(10):
132 | # _, regression, classification, anchors = model(x)
133 | #
134 | # t2 = time.time()
135 | # tact_time = (t2 - t1) / 10
136 | # print(f'{tact_time} seconds, {32 / tact_time} FPS, @batch_size 32')
137 |
--------------------------------------------------------------------------------
/efficient_det_ros/efficientdet_test_videos.py:
--------------------------------------------------------------------------------
1 | # Core Author: Zylo117
2 | # Script's Author: winter2897
3 |
4 | """
5 | Simple Inference Script of EfficientDet-Pytorch for detecting objects on webcam
6 | """
7 | import time
8 | import torch
9 | import cv2
10 | import numpy as np
11 | from torch.backends import cudnn
12 | from backbone import EfficientDetBackbone
13 | from efficientdet.utils import BBoxTransform, ClipBoxes
14 | from utils.utils import preprocess, invert_affine, postprocess, preprocess_video
15 |
16 | # Video's path
17 | video_src = 'videotest.mp4'  # set an int to use a webcam, or a str to read from a video file
18 |
19 | compound_coef = 0
20 | force_input_size = None # set None to use default size
21 |
22 | threshold = 0.2
23 | iou_threshold = 0.2
24 |
25 | use_cuda = True
26 | use_float16 = False
27 | cudnn.fastest = True
28 | cudnn.benchmark = True
29 |
30 | obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
31 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
32 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
33 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
34 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
35 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
36 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
37 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
38 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
39 | 'toothbrush']
40 |
41 | # tf bilinear interpolation is different from any other's, just make do
42 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
43 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size
44 |
45 | # load model
46 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list))
47 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth'))
48 | model.requires_grad_(False)
49 | model.eval()
50 |
51 | if use_cuda:
52 | model = model.cuda()
53 | if use_float16:
54 | model = model.half()
55 |
56 | # function for display
57 | def display(preds, imgs):
58 | for i in range(len(imgs)):
59 | if len(preds[i]['rois']) == 0:
60 | return imgs[i]
61 |
62 | for j in range(len(preds[i]['rois'])):
63 | (x1, y1, x2, y2) = preds[i]['rois'][j].astype(int)
64 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
65 | obj = obj_list[preds[i]['class_ids'][j]]
66 | score = float(preds[i]['scores'][j])
67 |
68 | cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
69 | (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
70 | (255, 255, 0), 1)
71 |
72 | return imgs[i]
73 | # Box
74 | regressBoxes = BBoxTransform()
75 | clipBoxes = ClipBoxes()
76 |
77 | # Video capture
78 | cap = cv2.VideoCapture(video_src)
79 |
80 | while True:
81 | ret, frame = cap.read()
82 | if not ret:
83 | break
84 |
85 | # frame preprocessing
86 | ori_imgs, framed_imgs, framed_metas = preprocess_video(frame, max_size=input_size)
87 |
88 | if use_cuda:
89 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)
90 | else:
91 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)
92 |
93 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)
94 |
95 | # model predict
96 | with torch.no_grad():
97 | features, regression, classification, anchors = model(x)
98 |
99 | out = postprocess(x,
100 | anchors, regression, classification,
101 | regressBoxes, clipBoxes,
102 | threshold, iou_threshold)
103 |
104 | # result
105 | out = invert_affine(framed_metas, out)
106 | img_show = display(out, ori_imgs)
107 |
108 | # show frame by frame
109 | cv2.imshow('frame', img_show)
110 | if cv2.waitKey(1) & 0xFF == ord('q'):
111 | break
112 |
113 | cap.release()
114 | cv2.destroyAllWindows()
115 |
116 |
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/efficient_det_ros/efficientnet/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.1"
2 | from .model import EfficientNet
3 | from .utils import (
4 | GlobalParams,
5 | BlockArgs,
6 | BlockDecoder,
7 | efficientnet,
8 | get_model_params,
9 | )
10 |
11 |
--------------------------------------------------------------------------------
/efficient_det_ros/efficientnet/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientnet/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/efficientnet/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientnet/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/efficientnet/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientnet/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/efficientnet/__pycache__/utils_extra.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/efficientnet/__pycache__/utils_extra.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/efficientnet/utils_extra.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import math
4 |
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class Conv2dStaticSamePadding(nn.Module):
10 | """
11 | created by Zylo117
12 | The real keras/tensorflow conv2d with same padding
13 | """
14 |
15 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
16 | super().__init__()
17 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
18 | bias=bias, groups=groups)
19 | self.stride = self.conv.stride
20 | self.kernel_size = self.conv.kernel_size
21 | self.dilation = self.conv.dilation
22 |
23 | if isinstance(self.stride, int):
24 | self.stride = [self.stride] * 2
25 | elif len(self.stride) == 1:
26 | self.stride = [self.stride[0]] * 2
27 |
28 | if isinstance(self.kernel_size, int):
29 | self.kernel_size = [self.kernel_size] * 2
30 | elif len(self.kernel_size) == 1:
31 | self.kernel_size = [self.kernel_size[0]] * 2
32 |
33 | def forward(self, x):
34 | h, w = x.shape[-2:]
35 |
36 | extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1]
37 | extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0]
38 |
39 | left = extra_h // 2
40 | right = extra_h - left
41 | top = extra_v // 2
42 | bottom = extra_v - top
43 |
44 | x = F.pad(x, [left, right, top, bottom])
45 |
46 | x = self.conv(x)
47 | return x
48 |
49 |
50 | class MaxPool2dStaticSamePadding(nn.Module):
51 | """
52 | created by Zylo117
53 | The real keras/tensorflow MaxPool2d with same padding
54 | """
55 |
56 | def __init__(self, *args, **kwargs):
57 | super().__init__()
58 | self.pool = nn.MaxPool2d(*args, **kwargs)
59 | self.stride = self.pool.stride
60 | self.kernel_size = self.pool.kernel_size
61 |
62 | if isinstance(self.stride, int):
63 | self.stride = [self.stride] * 2
64 | elif len(self.stride) == 1:
65 | self.stride = [self.stride[0]] * 2
66 |
67 | if isinstance(self.kernel_size, int):
68 | self.kernel_size = [self.kernel_size] * 2
69 | elif len(self.kernel_size) == 1:
70 | self.kernel_size = [self.kernel_size[0]] * 2
71 |
72 | def forward(self, x):
73 | h, w = x.shape[-2:]
74 |
75 | extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1]
76 | extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0]
77 |
78 | left = extra_h // 2
79 | right = extra_h - left
80 | top = extra_v // 2
81 | bottom = extra_v - top
82 |
83 | x = F.pad(x, [left, right, top, bottom])
84 |
85 | x = self.pool(x)
86 | return x
87 |
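
A quick shape check (illustrative) of the TF/Keras "same"-padding rule implemented above: the output spatial size is ceil(input / stride), independent of kernel size:

import torch
from efficientnet.utils_extra import Conv2dStaticSamePadding

conv = Conv2dStaticSamePadding(3, 8, kernel_size=3, stride=2)
x = torch.randn(1, 3, 37, 53)  # odd sizes on purpose
print(conv(x).shape)           # torch.Size([1, 8, 19, 27]): ceil(37/2), ceil(53/2)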
--------------------------------------------------------------------------------
/efficient_det_ros/projects/coco.yml:
--------------------------------------------------------------------------------
1 | project_name: coco  # also the folder name of the dataset under the data_path folder
2 | train_set: train2017
3 | val_set: val2017
4 | num_gpus: 4
5 |
6 | # mean and std in RGB order; this should remain unchanged as long as your dataset is similar to COCO.
7 | mean: [0.485, 0.456, 0.406]
8 | std: [0.229, 0.224, 0.225]
9 |
10 | # these are the COCO anchors; change them if necessary
11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
13 |
14 | # must match your dataset's category_id.
15 | # category_id is one-indexed;
16 | # for example, 'car' has index 2 here, while its category_id is 3
17 | obj_list: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
18 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
19 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
20 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
21 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
22 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
23 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
24 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
25 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
26 | 'toothbrush']
--------------------------------------------------------------------------------
/efficient_det_ros/projects/kitti.yml:
--------------------------------------------------------------------------------
1 | project_name: KITTI_3D_Object  # also the folder name of the dataset under the data_path folder
2 | train_set: train
3 | val_set: val
4 | num_gpus: 2
5 |
6 | # mean and std in RGB order; this should remain unchanged as long as your dataset is similar to COCO.
7 | mean: [0.485, 0.456, 0.406]
8 | std: [0.229, 0.224, 0.225]
9 |
10 | # these are the COCO anchors; change them if necessary
11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
12 | anchors_ratios: '[(0.6, 1.5), (1.1, 0.9), (1.5, 0.7)]'
13 |
14 | # must match your dataset's category_id.
15 | # category_id is one-indexed;
16 | # for example, 'car' has index 2 here, while its category_id is 3
17 | obj_list: ['car', 'person', 'cyclist']
18 |
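
These fields are consumed by coco_eval.py above, which yaml-loads the project file and eval()s the anchor strings; a quick load sketch (illustrative):

import yaml

params = yaml.safe_load(open('projects/kitti.yml'))
print(params['obj_list'])              # ['car', 'person', 'cyclist']
print(eval(params['anchors_ratios']))  # [(0.6, 1.5), (1.1, 0.9), (1.5, 0.7)]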
--------------------------------------------------------------------------------
/efficient_det_ros/projects/shape.yml:
--------------------------------------------------------------------------------
1 | project_name: shape  # also the folder name of the dataset under the data_path folder
2 | train_set: train
3 | val_set: val
4 | num_gpus: 1
5 |
6 | # mean and std in RGB order; this should remain unchanged as long as your dataset is similar to COCO.
7 | mean: [0.485, 0.456, 0.406]
8 | std: [0.229, 0.224, 0.225]
9 |
10 | # these anchors are adapted to the dataset
11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
13 |
14 | obj_list: ['rectangle', 'circle']
--------------------------------------------------------------------------------
/efficient_det_ros/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .batchnorm import patch_sync_batchnorm, convert_model
13 | from .replicate import DataParallelWithCallback, patch_replication_callback
14 |
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/__pycache__/comm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/sync_batchnorm/__pycache__/comm.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
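
A sanity sketch (illustrative; the import path assumes the efficient_det_ros root): after copying the affine parameters (reset_parameters() initialises weight with uniform_ here, unlike nn.BatchNorm2d's ones), the re-implementation should match nn.BatchNorm2d in training mode up to floating-point tolerance:

import torch
import torch.nn as nn

from utils.sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl

torch.manual_seed(0)
x = torch.randn(8, 4, 5, 5)

ref = nn.BatchNorm2d(4, eps=1e-5, momentum=0.1)
reimpl = BatchNorm2dReimpl(4, eps=1e-5, momentum=0.1)
reimpl.weight.data.copy_(ref.weight.data)
reimpl.bias.data.copy_(ref.bias.data)

print(torch.allclose(ref(x), reimpl(x), atol=1e-5))  # True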
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 | - During the replication, as data parallel will trigger a callback on each module, all slave devices should
60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected
62 | and passed to a registered callback.
63 | - After receiving the messages, the master device should gather the information and determine the message passed
64 | back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 | Register a slave device.
87 |
88 | Args:
89 | identifier: an identifier, usually the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 | The messages are first collected from each device (including the master device), and then
106 | a callback is invoked to compute the message to be sent back to each device
107 | (including the master device).
108 |
109 | Args:
110 | master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 | assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
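
A minimal handshake illustration (not from the repo; the import path assumes the efficient_det_ros root): one slave thread registers and exchanges messages with the master, and the callback simply sums the payloads, so both sides receive 1 + 2 = 3:

import threading

from utils.sync_batchnorm.comm import SyncMaster

# the callback receives [(id, msg), ...] and must return one (id, result) per device
master = SyncMaster(lambda inter: [(i, sum(m for _, m in inter)) for i, _ in inter])
pipe = master.register_slave(1)

t = threading.Thread(target=lambda: print('slave got', pipe.run_slave(2)))
t.start()
print('master got', master.run_master(1))
t.join()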
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 | Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called before the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is created by
55 | the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
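A minimal usage sketch of the two patching styles documented above, assuming this package's `__init__` exports `SynchronizedBatchNorm2d` and that at least two GPUs are visible (the import paths and module shapes are illustrative):

```python
import torch
# Import paths assumed from this repository's layout.
from utils.sync_batchnorm import SynchronizedBatchNorm2d
from utils.sync_batchnorm.replicate import DataParallelWithCallback, patch_replication_callback

net = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), SynchronizedBatchNorm2d(16)).cuda()

# Style 1: use the subclass, which runs the replication callbacks itself.
dp = DataParallelWithCallback(net, device_ids=[0, 1])

# Style 2: monkey-patch a plain DataParallel instance after construction.
dp2 = torch.nn.DataParallel(net, device_ids=[0, 1])
patch_replication_callback(dp2)

out = dp(torch.randn(8, 3, 32, 32).cuda())  # BN statistics are synchronized across replicas
```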
--------------------------------------------------------------------------------
/efficient_det_ros/utils/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y), message)
29 |
30 |
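A hedged example of how this helper might be used in a test case (tensor values are illustrative):

```python
import unittest
import torch
from utils.sync_batchnorm.unittest import TorchTestCase  # import path assumed from this repo's layout

class ExampleTest(TorchTestCase):
    def test_close_tensors(self):
        a = torch.tensor([1.0, 2.0, 3.0])
        # Passes: the element-wise difference is within torch.allclose tolerances.
        self.assertTensorClose(a, a + 1e-9)

if __name__ == '__main__':
    unittest.main()
```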
--------------------------------------------------------------------------------
/efficient_det_ros/weights/efficientdet-d2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/efficient_det_ros/weights/efficientdet-d2.pth
--------------------------------------------------------------------------------
/kitti_camera_ros/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 |
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 |
14 | # Compiled Dynamic libraries
15 | *.so
16 | *.dylib
17 | *.dll
18 |
19 | # Fortran module files
20 | *.mod
21 | *.smod
22 |
23 | # Compiled Static libraries
24 | *.lai
25 | *.la
26 | *.a
27 | *.lib
28 |
29 | # Executables
30 | *.exe
31 | *.out
32 | *.app
33 |
--------------------------------------------------------------------------------
/kitti_camera_ros/.travis.yml:
--------------------------------------------------------------------------------
1 | dist: bionic
2 | sudo: required
3 | language: generic
4 | cache: apt
5 |
6 | install:
7 | - sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list'
8 | - sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654
9 | - sudo apt update
10 | - sudo apt install ros-melodic-desktop-full
11 | - source /opt/ros/melodic/setup.bash
12 | - mkdir -p ~/catkin_ws/src
13 | - cd ~/catkin_ws/
14 | - catkin_make
15 | - source devel/setup.bash
16 |
17 | script:
18 | - cd ~/catkin_ws/src
19 | - git clone -b melodic https://github.com/epan-utbm/kitti_velodyne_ros.git
20 | - cd ~/catkin_ws
21 | - catkin_make
22 |
--------------------------------------------------------------------------------
/kitti_camera_ros/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(kitti_camera_ros)
3 |
4 | find_package(catkin REQUIRED COMPONENTS
5 | roscpp
6 | sensor_msgs
7 | geometry_msgs
8 | visualization_msgs
9 | pcl_conversions
10 | pcl_ros
11 | )
12 |
13 | find_package(PCL REQUIRED)
14 |
15 | include_directories(include ${catkin_INCLUDE_DIRS} ${PCL_INCLUDE_DIRS})
16 |
17 | catkin_package()
18 |
19 | add_executable(kitti_camera_ros src/kitti_camera_ros.cpp)
20 | target_link_libraries(kitti_camera_ros ${catkin_LIBRARIES} ${PCL_LIBRARIES})
21 | if(catkin_EXPORTED_TARGETS)
22 | add_dependencies(kitti_camera_ros ${catkin_EXPORTED_TARGETS})
23 | endif()
24 |
--------------------------------------------------------------------------------
/kitti_camera_ros/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, Rui Yang
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/kitti_camera_ros/README.md:
--------------------------------------------------------------------------------
1 | # kitti_camera_ros
2 |
3 | Load KITTI camera data, play in ROS.
4 |
5 | ## Usage
6 |
7 | ```console
8 | $ roslaunch kitti_camera_ros kitti_camera_ros.launch
9 | ```
10 |
--------------------------------------------------------------------------------
/kitti_camera_ros/launch/kitti_camera_ros.launch:
--------------------------------------------------------------------------------
1 | <launch>
2 |   <node name="kitti_camera_ros" pkg="kitti_camera_ros" type="kitti_camera_ros" output="screen">
3 |     <param name="frequency" value="10"/>
4 |     <!-- camera_dir: directory of per-frame detection .txt files -->
5 |     <param name="camera_dir" value="/path/to/kitti/detections/"/>
6 |   </node>
7 | </launch>
--------------------------------------------------------------------------------
/kitti_camera_ros/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package>
3 |   <name>kitti_camera_ros</name>
4 |   <version>0.0.1</version>
5 |   <description>Load KITTI camera data, play in ROS</description>
6 |   <maintainer>Rui Yang</maintainer>
7 |   <license>BSD</license>
8 |
9 |   <url type="website">https://github.com/epan-utbm/efficient_online_learning</url>
10 |   <author>Rui Yang</author>
11 |
12 |   <buildtool_depend>catkin</buildtool_depend>
13 |
14 |   <build_depend>roscpp</build_depend>
15 |   <build_depend>sensor_msgs</build_depend>
16 |   <build_depend>geometry_msgs</build_depend>
17 |   <build_depend>visualization_msgs</build_depend>
18 |   <build_depend>pcl_conversions</build_depend>
19 |   <build_depend>pcl_ros</build_depend>
20 |
21 |   <run_depend>roscpp</run_depend>
22 |   <run_depend>sensor_msgs</run_depend>
23 |   <run_depend>geometry_msgs</run_depend>
24 |   <run_depend>visualization_msgs</run_depend>
25 |   <run_depend>pcl_conversions</run_depend>
26 |   <run_depend>pcl_ros</run_depend>
27 | </package>
28 |
--------------------------------------------------------------------------------
/kitti_camera_ros/src/kitti_camera_ros.cpp:
--------------------------------------------------------------------------------
1 | // ROS
2 | #include <ros/ros.h>
3 | #include <vision_msgs/Detection2D.h>
4 | #include <vision_msgs/Detection2DArray.h>
5 | #include <vision_msgs/ObjectHypothesisWithPose.h>
6 | #include <visualization_msgs/MarkerArray.h>
7 | // PCL
8 | #include <pcl_conversions/pcl_conversions.h>
9 | // C++
10 | #include <fstream>
11 | // Rui
12 | #include <dirent.h>
13 | #include <cstdlib>
14 |
15 | int main(int argc, char **argv) {
16 | double frequency;
17 | std::string camera_dir;
18 | double timestamp;
19 |
20 | ros::init(argc, argv, "kitti_camera_ros");
21 | ros::NodeHandle private_nh("~");
22 |
23 | ros::Publisher camera_pub = private_nh.advertise<vision_msgs::Detection2DArray>("/image_detections", 100, true);
24 |
25 | private_nh.param<double>("frequency", frequency, 10.0);
26 | private_nh.param<std::string>("camera_dir", camera_dir, "camera_dir_path");
27 |
28 | ros::Rate loop_rate(frequency);
29 |
30 | //vision_msgs::Detection2DArray detection_results;
31 |
32 | struct dirent **filelist;
33 | int n_file = scandir(camera_dir.c_str(), &filelist, NULL, alphasort);
34 | if(n_file == -1) {
35 | ROS_ERROR_STREAM("[kitti_camera_ros] Could not open directory: " << camera_dir);
36 | return EXIT_FAILURE;
37 | } else {
38 | ROS_INFO_STREAM("[kitti_camera_ros] Load camera files in " << camera_dir);
39 | ROS_INFO_STREAM("[kitti_camera_ros] frequency (loop rate): " << frequency);
40 | }
41 |
42 | int i_file = 2; // 0 = . 1 = ..
43 | while(ros::ok() && i_file < n_file) {
44 | vision_msgs::Detection2DArray detection_results;
45 |
46 | /*** Camera ***/
47 | std::string s = camera_dir + filelist[i_file]->d_name;
48 | std::fstream camera_txt(s.c_str(), std::ios::in | std::ios::binary);
49 | //std::cerr << "s: " << s.c_str() << std::endl;
50 | if(!camera_txt.good()) {
51 | ROS_ERROR_STREAM("[kitti_camera_ros] Could not read file: " << s);
52 | return EXIT_FAILURE;
53 | } else {
54 | camera_txt >> timestamp;
55 | ros::Time timestamp_ros(timestamp == 0 ? ros::TIME_MIN.toSec() : timestamp);
56 | detection_results.header.stamp = timestamp_ros;
57 |
58 | //camera_txt.seekg(0, std::ios::beg);
59 |
60 | for(int i = 0; camera_txt.good() && !camera_txt.eof(); i++) {
61 | vision_msgs::Detection2D detection;
62 | vision_msgs::ObjectHypothesisWithPose result;
63 | camera_txt >> detection.bbox.center.x;
64 | camera_txt >> detection.bbox.center.y;
65 | camera_txt >> detection.bbox.size_x;
66 | camera_txt >> detection.bbox.size_y;
67 | camera_txt >> result.id;
68 | camera_txt >> result.score;
69 | detection.results.push_back(result);
70 | detection_results.detections.push_back(detection);
71 | }
72 | camera_txt.close();
73 |
74 | camera_pub.publish(detection_results);
75 | // ROS_INFO_STREAM("[kitti_camera_ros] detection_results.size " << detection_results.detections.size());
76 | // ROS_INFO_STREAM("--------------------------------------------");
77 | // for(int n = 0; n < detection_results.detections.size(); n++) {
78 | // ROS_INFO_STREAM("[kitti_camera_ros] detections.label " << detection_results.detections[n].results[0].id);
79 | // ROS_INFO_STREAM("[kitti_camera_ros] detections.score " << detection_results.detections[n].results[0].score);
80 | // }
81 | }
82 |
83 | ros::spinOnce();
84 | loop_rate.sleep();
85 | i_file++;
86 | }
87 |
88 | for(int i = 0; i < n_file; i++) { // free every scandir entry, including "." and ".."
89 | free(filelist[i]);
90 | }
91 | free(filelist);
92 |
93 | return EXIT_SUCCESS;
94 | }
95 |
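For reference, the parsing loop above implies a simple whitespace-separated per-frame file format: one leading timestamp, then repeated `center_x center_y size_x size_y id score` groups, one detection per group (`id` is read into the integer `result.id` field). A small sketch, assuming hypothetical file names and values, that writes a file the node could read:

```python
# Hypothetical per-frame detection file in the format kitti_camera_ros parses:
# <timestamp> followed by "center_x center_y size_x size_y id score" groups.
detections = [
    (613.5, 184.0, 55.0, 42.0, 0, 0.91),  # class id 0, e.g. car
    (302.2, 170.6, 30.5, 61.3, 1, 0.78),  # class id 1, e.g. pedestrian
]
with open("0000000000.txt", "w") as f:
    f.write("1317046650.05\n")
    for cx, cy, sx, sy, cid, score in detections:
        f.write(f"{cx} {cy} {sx} {sy} {cid} {score}\n")
```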
--------------------------------------------------------------------------------
/kitti_velodyne_ros/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 |
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 |
14 | # Compiled Dynamic libraries
15 | *.so
16 | *.dylib
17 | *.dll
18 |
19 | # Fortran module files
20 | *.mod
21 | *.smod
22 |
23 | # Compiled Static libraries
24 | *.lai
25 | *.la
26 | *.a
27 | *.lib
28 |
29 | # Executables
30 | *.exe
31 | *.out
32 | *.app
33 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/.travis.yml:
--------------------------------------------------------------------------------
1 | dist: bionic
2 | sudo: required
3 | language: generic
4 | cache: apt
5 |
6 | install:
7 | - sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list'
8 | - sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654
9 | - sudo apt update
10 | - sudo apt install ros-melodic-desktop-full
11 | - source /opt/ros/melodic/setup.bash
12 | - mkdir -p ~/catkin_ws/src
13 | - cd ~/catkin_ws/
14 | - catkin_make
15 | - source devel/setup.bash
16 |
17 | script:
18 | - cd ~/catkin_ws/src
19 | - git clone -b melodic https://github.com/epan-utbm/kitti_velodyne_ros.git
20 | - cd ~/catkin_ws
21 | - catkin_make
22 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(kitti_velodyne_ros)
3 |
4 | find_package(catkin REQUIRED COMPONENTS
5 | roscpp
6 | sensor_msgs
7 | geometry_msgs
8 | visualization_msgs
9 | pcl_conversions
10 | pcl_ros
11 | )
12 |
13 | find_package(PCL REQUIRED)
14 |
15 | include_directories(include ${catkin_INCLUDE_DIRS} ${PCL_INCLUDE_DIRS})
16 |
17 | catkin_package()
18 |
19 | add_executable(kitti_velodyne_ros src/kitti_velodyne_ros.cpp)
20 | target_link_libraries(kitti_velodyne_ros ${catkin_LIBRARIES} ${PCL_LIBRARIES})
21 | if(catkin_EXPORTED_TARGETS)
22 | add_dependencies(kitti_velodyne_ros ${catkin_EXPORTED_TARGETS})
23 | endif()
24 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Zhi Yan
4 | Copyright (c) 2021, Rui Yang
5 | All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | 1. Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | 2. Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | 3. Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived from
19 | this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/README.md:
--------------------------------------------------------------------------------
1 | # kitti_velodyne_ros
2 |
3 | [](https://travis-ci.org/epan-utbm/kitti_velodyne_ros) [](https://app.codacy.com/gh/epan-utbm/kitti_velodyne_ros?utm_source=github.com&utm_medium=referral&utm_content=epan-utbm/kitti_velodyne_ros&utm_campaign=Badge_Grade_Dashboard) [](https://opensource.org/licenses/BSD-3-Clause)
4 |
5 | Load KITTI velodyne data, play in ROS.
6 |
7 | ## Usage
8 |
9 | ```console
10 | $ roslaunch kitti_velodyne_ros kitti_velodyne_ros.launch
11 | ```
12 |
13 | If you want to save the point cloud as a csv file, simply activate the corresponding `<param>` in [kitti_velodyne_ros.launch](launch/kitti_velodyne_ros.launch).
14 |
18 |
19 | In case you want to play with [LOAM](https://github.com/laboshinl/loam_velodyne):
20 |
21 | ```console
22 | $ roslaunch kitti_velodyne_ros kitti_velodyne_ros_loam.launch
23 | ```
24 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/launch/kitti_velodyne_ros.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/launch/kitti_velodyne_ros.rviz:
--------------------------------------------------------------------------------
1 | Panels:
2 | - Class: rviz/Displays
3 | Help Height: 78
4 | Name: Displays
5 | Property Tree Widget:
6 | Expanded:
7 | - /Global Options1
8 | - /Status1
9 | Splitter Ratio: 0.5
10 | Tree Height: 731
11 | - Class: rviz/Selection
12 | Name: Selection
13 | - Class: rviz/Tool Properties
14 | Expanded:
15 | - /2D Pose Estimate1
16 | - /2D Nav Goal1
17 | - /Publish Point1
18 | Name: Tool Properties
19 | Splitter Ratio: 0.5886790156364441
20 | - Class: rviz/Views
21 | Expanded:
22 | - /Current View1
23 | Name: Views
24 | Splitter Ratio: 0.5
25 | - Class: rviz/Time
26 | Experimental: false
27 | Name: Time
28 | SyncMode: 0
29 | SyncSource: PointCloud2
30 | Preferences:
31 | PromptSaveOnExit: true
32 | Toolbars:
33 | toolButtonStyle: 2
34 | Visualization Manager:
35 | Class: ""
36 | Displays:
37 | - Alpha: 1
38 | Autocompute Intensity Bounds: true
39 | Autocompute Value Bounds:
40 | Max Value: 10
41 | Min Value: -10
42 | Value: true
43 | Axis: Z
44 | Channel Name: intensity
45 | Class: rviz/PointCloud2
46 | Color: 255; 255; 255
47 | Color Transformer: Intensity
48 | Decay Time: 0
49 | Enabled: true
50 | Invert Rainbow: false
51 | Max Color: 255; 255; 255
52 | Max Intensity: 0.9900000095367432
53 | Min Color: 0; 0; 0
54 | Min Intensity: 0
55 | Name: PointCloud2
56 | Position Transformer: XYZ
57 | Queue Size: 10
58 | Selectable: true
59 | Size (Pixels): 1
60 | Size (m): 0.019999999552965164
61 | Style: Flat Squares
62 | Topic: /kitti_velodyne_ros/velodyne_points
63 | Unreliable: false
64 | Use Fixed Frame: true
65 | Use rainbow: true
66 | Value: true
67 | - Alpha: 1
68 | Arrow Length: 0.30000001192092896
69 | Axes Length: 0.30000001192092896
70 | Axes Radius: 0.009999999776482582
71 | Class: rviz/PoseArray
72 | Color: 255; 25; 0
73 | Enabled: false
74 | Head Length: 0.07000000029802322
75 | Head Radius: 0.029999999329447746
76 | Name: GroundTruthPose
77 | Shaft Length: 0.23000000417232513
78 | Shaft Radius: 0.009999999776482582
79 | Shape: Arrow (Flat)
80 | Topic: /kitti_velodyne_ros/poses
81 | Unreliable: false
82 | Value: false
83 | - Alpha: 0.5
84 | Cell Size: 1
85 | Class: rviz/Grid
86 | Color: 160; 160; 164
87 | Enabled: false
88 | Line Style:
89 | Line Width: 0.029999999329447746
90 | Value: Lines
91 | Name: Grid
92 | Normal Cell Count: 0
93 | Offset:
94 | X: 0
95 | Y: 0
96 | Z: 0
97 | Plane: XY
98 | Plane Cell Count: 10
99 | Reference Frame:
100 | Value: false
101 | - Class: rviz/TF
102 | Enabled: false
103 | Frame Timeout: 15
104 | Frames:
105 | All Enabled: true
106 | Marker Scale: 7
107 | Name: TF
108 | Show Arrows: true
109 | Show Axes: true
110 | Show Names: true
111 | Tree:
112 | {}
113 | Update Interval: 0
114 | Value: false
115 | - Class: rviz/MarkerArray
116 | Enabled: false
117 | Marker Topic: /kitti_velodyne_ros/markers
118 | Name: MarkerArray
119 | Namespaces:
120 | {}
121 | Queue Size: 100
122 | Value: false
123 | Enabled: true
124 | Global Options:
125 | Background Color: 20; 20; 20
126 | Default Light: true
127 | Fixed Frame: velodyne
128 | Frame Rate: 30
129 | Name: root
130 | Tools:
131 | - Class: rviz/Interact
132 | Hide Inactive Objects: true
133 | - Class: rviz/MoveCamera
134 | - Class: rviz/Select
135 | - Class: rviz/FocusCamera
136 | - Class: rviz/Measure
137 | - Class: rviz/SetInitialPose
138 | Theta std deviation: 0.2617993950843811
139 | Topic: /initialpose
140 | X std deviation: 0.5
141 | Y std deviation: 0.5
142 | - Class: rviz/SetGoal
143 | Topic: /move_base_simple/goal
144 | - Class: rviz/PublishPoint
145 | Single click: true
146 | Topic: /clicked_point
147 | Value: true
148 | Views:
149 | Current:
150 | Class: rviz/Orbit
151 | Distance: 59.19333267211914
152 | Enable Stereo Rendering:
153 | Stereo Eye Separation: 0.05999999865889549
154 | Stereo Focal Distance: 1
155 | Swap Stereo Eyes: false
156 | Value: false
157 | Focal Point:
158 | X: 0.8557648658752441
159 | Y: -1.995847225189209
160 | Z: 0.9883638620376587
161 | Focal Shape Fixed Size: true
162 | Focal Shape Size: 0.05000000074505806
163 | Invert Z Axis: false
164 | Name: Current View
165 | Near Clip Distance: 0.009999999776482582
166 | Pitch: 0.6902026534080505
167 | Target Frame:
168 | Value: Orbit (rviz)
169 | Yaw: 0.9703954458236694
170 | Saved: ~
171 | Window Geometry:
172 | Displays:
173 | collapsed: false
174 | Height: 1025
175 | Hide Left Dock: false
176 | Hide Right Dock: true
177 | QMainWindow State: 000000ff00000000fd00000004000000000000016a00000366fc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d00000366000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f00000396fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000002800000396000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e100000197000000030000073d0000003bfc0100000002fb0000000800540069006d006501000000000000073d000002eb00fffffffb0000000800540069006d00650100000000000004500000000000000000000005cd0000036600000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
178 | Selection:
179 | collapsed: false
180 | Time:
181 | collapsed: false
182 | Tool Properties:
183 | collapsed: false
184 | Views:
185 | collapsed: true
186 | Width: 1853
187 | X: 67
188 | Y: 27
189 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/launch/kitti_velodyne_ros_loam.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/launch/kitti_velodyne_ros_loam.rviz:
--------------------------------------------------------------------------------
1 | Panels:
2 | - Class: rviz/Displays
3 | Help Height: 78
4 | Name: Displays
5 | Property Tree Widget:
6 | Expanded:
7 | - /Global Options1
8 | - /Status1
9 | Splitter Ratio: 0.5
10 | Tree Height: 561
11 | - Class: rviz/Selection
12 | Name: Selection
13 | - Class: rviz/Tool Properties
14 | Expanded:
15 | - /2D Pose Estimate1
16 | - /2D Nav Goal1
17 | - /Publish Point1
18 | Name: Tool Properties
19 | Splitter Ratio: 0.588679016
20 | - Class: rviz/Views
21 | Expanded:
22 | - /Current View1
23 | Name: Views
24 | Splitter Ratio: 0.5
25 | - Class: rviz/Time
26 | Experimental: false
27 | Name: Time
28 | SyncMode: 0
29 | SyncSource: PointCloud2
30 | Visualization Manager:
31 | Class: ""
32 | Displays:
33 | - Alpha: 1
34 | Autocompute Intensity Bounds: true
35 | Autocompute Value Bounds:
36 | Max Value: 10
37 | Min Value: -10
38 | Value: true
39 | Axis: Z
40 | Channel Name: intensity
41 | Class: rviz/PointCloud2
42 | Color: 255; 255; 255
43 | Color Transformer: Intensity
44 | Decay Time: 0
45 | Enabled: true
46 | Invert Rainbow: false
47 | Max Color: 255; 255; 255
48 | Max Intensity: 0.99000001
49 | Min Color: 0; 0; 0
50 | Min Intensity: 0
51 | Name: PointCloud2
52 | Position Transformer: XYZ
53 | Queue Size: 10
54 | Selectable: true
55 | Size (Pixels): 1
56 | Size (m): 0.0199999996
57 | Style: Flat Squares
58 | Topic: /kitti_velodyne_ros/velodyne_points
59 | Unreliable: false
60 | Use Fixed Frame: true
61 | Use rainbow: true
62 | Value: true
63 | - Alpha: 1
64 | Arrow Length: 0.300000012
65 | Axes Length: 0.300000012
66 | Axes Radius: 0.00999999978
67 | Class: rviz/PoseArray
68 | Color: 255; 25; 0
69 | Enabled: false
70 | Head Length: 0.0700000003
71 | Head Radius: 0.0299999993
72 | Name: GroundTruthPose
73 | Shaft Length: 0.230000004
74 | Shaft Radius: 0.00999999978
75 | Shape: Arrow (Flat)
76 | Topic: /kitti_velodyne_ros/poses
77 | Unreliable: false
78 | Value: false
79 | - Alpha: 0.5
80 | Cell Size: 1
81 | Class: rviz/Grid
82 | Color: 160; 160; 164
83 | Enabled: true
84 | Line Style:
85 | Line Width: 0.0299999993
86 | Value: Lines
87 | Name: Grid
88 | Normal Cell Count: 0
89 | Offset:
90 | X: 0
91 | Y: 0
92 | Z: 0
93 | Plane: XY
94 | Plane Cell Count: 10
95 | Reference Frame:
96 | Value: true
97 | - Class: rviz/TF
98 | Enabled: true
99 | Frame Timeout: 15
100 | Frames:
101 | All Enabled: true
102 | base_link:
103 | Value: true
104 | camera_init:
105 | Value: true
106 | map:
107 | Value: true
108 | velodyne:
109 | Value: true
110 | Marker Scale: 7
111 | Name: TF
112 | Show Arrows: true
113 | Show Axes: true
114 | Show Names: true
115 | Tree:
116 | map:
117 | camera_init:
118 | base_link:
119 | velodyne:
120 | {}
121 | Update Interval: 0
122 | Value: true
123 | - Class: rviz/MarkerArray
124 | Enabled: true
125 | Marker Topic: /kitti_velodyne_ros/markers
126 | Name: MarkerArray
127 | Namespaces:
128 | "": true
129 | Queue Size: 100
130 | Value: true
131 | Enabled: true
132 | Global Options:
133 | Background Color: 20; 20; 20
134 | Default Light: true
135 | Fixed Frame: map
136 | Frame Rate: 30
137 | Name: root
138 | Tools:
139 | - Class: rviz/Interact
140 | Hide Inactive Objects: true
141 | - Class: rviz/MoveCamera
142 | - Class: rviz/Select
143 | - Class: rviz/FocusCamera
144 | - Class: rviz/Measure
145 | - Class: rviz/SetInitialPose
146 | Topic: /initialpose
147 | - Class: rviz/SetGoal
148 | Topic: /move_base_simple/goal
149 | - Class: rviz/PublishPoint
150 | Single click: true
151 | Topic: /clicked_point
152 | Value: true
153 | Views:
154 | Current:
155 | Class: rviz/Orbit
156 | Distance: 81.0036621
157 | Enable Stereo Rendering:
158 | Stereo Eye Separation: 0.0599999987
159 | Stereo Focal Distance: 1
160 | Swap Stereo Eyes: false
161 | Value: false
162 | Focal Point:
163 | X: 0.855764866
164 | Y: -1.99584723
165 | Z: 0.988363862
166 | Focal Shape Fixed Size: true
167 | Focal Shape Size: 0.0500000007
168 | Invert Z Axis: false
169 | Name: Current View
170 | Near Clip Distance: 0.00999999978
171 | Pitch: 0.690202653
172 | Target Frame:
173 | Value: Orbit (rviz)
174 | Yaw: 0.970395446
175 | Saved: ~
176 | Window Geometry:
177 | Displays:
178 | collapsed: false
179 | Height: 839
180 | Hide Left Dock: false
181 | Hide Right Dock: true
182 | QMainWindow State: 000000ff00000000fd00000004000000000000016a000002c0fc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000006100fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c0061007900730100000028000002c0000000d700fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f00000396fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000002800000396000000ad00fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000005490000003bfc0100000002fb0000000800540069006d00650100000000000005490000030000fffffffb0000000800540069006d00650100000000000004500000000000000000000003d9000002c000000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
183 | Selection:
184 | collapsed: false
185 | Time:
186 | collapsed: false
187 | Tool Properties:
188 | collapsed: false
189 | Views:
190 | collapsed: true
191 | Width: 1353
192 | X: 65
193 | Y: 24
194 |
--------------------------------------------------------------------------------
/kitti_velodyne_ros/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package>
3 |   <name>kitti_velodyne_ros</name>
4 |   <version>0.0.1</version>
5 |   <description>Load KITTI velodyne data, play in ROS</description>
6 |   <maintainer>Rui Yang</maintainer>
7 |   <license>BSD</license>
8 |
9 |   <url type="website">https://github.com/epan-utbm/efficient_online_learning</url>
10 |   <author>Zhi Yan</author>
11 |
12 |   <buildtool_depend>catkin</buildtool_depend>
13 |
14 |   <build_depend>roscpp</build_depend>
15 |   <build_depend>sensor_msgs</build_depend>
16 |   <build_depend>geometry_msgs</build_depend>
17 |   <build_depend>visualization_msgs</build_depend>
18 |   <build_depend>pcl_conversions</build_depend>
19 |   <build_depend>pcl_ros</build_depend>
20 |
21 |   <run_depend>roscpp</run_depend>
22 |   <run_depend>sensor_msgs</run_depend>
23 |   <run_depend>geometry_msgs</run_depend>
24 |   <run_depend>visualization_msgs</run_depend>
25 |   <run_depend>pcl_conversions</run_depend>
26 |   <run_depend>pcl_ros</run_depend>
27 | </package>
28 |
--------------------------------------------------------------------------------
/launch/efficient_online_learning.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/online_forests_ros/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(online_forests_ros)
3 |
4 | find_package(catkin REQUIRED COMPONENTS
5 | roscpp
6 | std_msgs
7 | )
8 |
9 | include_directories(
10 | include
11 | ${catkin_INCLUDE_DIRS}
12 | )
13 |
14 | catkin_package(
15 | INCLUDE_DIRS include
16 | )
17 |
18 | add_executable(online_forests_ros
19 | src/online_forests/classifier.cpp
20 | src/online_forests/hyperparameters.cpp
21 | src/online_forests/onlinenode.cpp
22 | src/online_forests/onlinetree.cpp
23 | src/online_forests/utilities.cpp
24 | src/online_forests/data.cpp
25 | # src/online_forests/Online-Forest.cpp
26 | src/online_forests/onlinerf.cpp
27 | src/online_forests/randomtest.cpp
28 | src/online_forests_ros.cpp
29 | )
30 |
31 | target_link_libraries(online_forests_ros
32 | ${catkin_LIBRARIES}
33 | config++
34 | blas
35 | )
36 |
37 | if(catkin_EXPORTED_TARGETS)
38 | add_dependencies(online_forests_ros
39 | ${catkin_EXPORTED_TARGETS}
40 | )
41 | endif()
42 |
--------------------------------------------------------------------------------
/online_forests_ros/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014 Amir Saffari
4 | Copyright (c) 2020 Zhi Yan
5 | Copyright (c) 2021 Rui Yang
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
25 |
--------------------------------------------------------------------------------
/online_forests_ros/README.md:
--------------------------------------------------------------------------------
1 | # online_forests_ros
2 |
3 | This package is forked from [https://github.com/amirsaffari/online-random-forests](https://github.com/amirsaffari/online-random-forests), and the original README is reproduced below the dividing line.
4 |
5 | [2020-09-18]: ROSified the original Online Random Forests and added support for data streams (not limited to files).
6 |
7 | [2020-09-27]: Added a storage function: the binary tree structure allows the model to be saved and loaded.
8 |
9 | ---
10 |
11 | # Install prerequisites
12 |
13 | ```
14 | sudo apt install libgmm++-dev libconfig++-dev libatlas-base-dev libblas-dev liblapack-dev
15 | ```
16 |
17 | ---
18 |
19 | Online Random Forests
20 | =====================
21 |
22 | This is the original implementation of the Online Random Forest algorithm [1]. There is a more recent implementation of this algorithm at https://github.com/amirsaffari/online-multiclass-lpboost which was used in [2].
23 |
24 | Read the INSTALL file for build instructions.
25 |
26 | Usage:
27 | ======
28 | Input arguments:
29 |
30 | -h | --help : will display this message.
31 | -c : path to the config file.
32 |
33 | --ort : use Online Random Tree (ORT) algorithm.
34 | --orf : use Online Random Forest (ORF) algorithm.
35 |
36 |
37 | --train : train the classifier.
38 | --test : test the classifier.
39 | --t2 : train and test the classifier at the same time.
40 |
41 |
42 | Examples:
43 | ./Online-Forest -c conf/orf.conf --orf --t2
44 |
45 | Config file:
46 | ============
47 | All the settings for the classifier are passed via the config file. You can find the
48 | config file in the "conf" folder. The meaning of each setting is as follows:
50 | Data:
51 | * trainData = path to the training file
52 | * testData = path to the test file
53 |
54 | Tree:
55 | * maxDepth = maximum depth for a tree
56 | * numRandomTests = number of random tests for each node
57 | * numProjectionFeatures = number of features for hyperplane tests
58 | * counterThreshold = number of samples to be seen for an online node before splitting
59 |
60 | Forest:
61 | * numTrees = number of trees in the forest
62 | * numEpochs = number of online training epochs
63 | * useSoftVoting = boolean flag for using hard or soft voting
64 |
65 | Output:
66 | * savePath = path to save the results (not implemented yet)
67 | * verbose = defines the verbosity level (0: silence)
68 |
69 | Data format:
70 | ============
71 | The data format used is very similar to the LIBSVM file format. It only needs to have
72 | one extra header line which contains the following information:
73 | \#Samples \#Features \#Classes \#FeatureMinIndex
74 |
75 | where
76 |
77 | \#Samples: number of samples
78 |
79 | \#Features: number of features
80 |
81 | \#Classes: number of classes
82 |
83 | \#FeatureMinIndex: the index of the first feature used
84 |
85 | You can find a few datasets in the data folder, check their header to see some examples.
86 | Currently, there is only one limitation with the data files: the classes should be
87 | labeled consecutively starting from 0. For example, for a 3-class problem
88 | the labels should be in the set {0, 1, 2}.
89 |
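For instance, a hypothetical training file with 3 samples, 4 features, 2 classes, and feature indices starting at 1 would look like this (header first, then one LIBSVM-style `label index:value` line per sample; the values are illustrative):

```
3 4 2 1
0 1:0.5 3:-1.2
1 2:0.7 4:0.1
0 1:-0.3 2:2.4
```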
90 | ===========
91 | REFERENCES:
92 | ===========
93 | [1] Amir Saffari, Christian Leistner, Jakob Santner, Martin Godec, and Horst Bischof,
94 | "On-line Random Forests,"
95 | 3rd IEEE ICCV Workshop on On-line Computer Vision, 2009.
96 |
97 | [2] Amir Saffari, Martin Godec, Thomas Pock, Christian Leistner, Horst Bischof,
98 | “Online Multi-Class LPBoost“,
99 | Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2010.
100 |
--------------------------------------------------------------------------------
/online_forests_ros/config/orf.conf:
--------------------------------------------------------------------------------
1 | Data:
2 | {
3 | trainData = "data/dna-train.libsvm";
4 | trainLabels = "data/train.label";
5 | testData = "data/dna-test.libsvm";
6 | testLabels = "data/test.label";
7 | };
8 | Tree:
9 | {
10 | maxDepth = 50;
11 | numRandomTests = 30;
12 | numProjectionFeatures = 2;
13 | counterThreshold = 50;
14 | };
15 | Forest:
16 | {
17 | numTrees = 100;
18 | numEpochs = 20;
19 | useSoftVoting = 1;
20 | };
21 | Output:
22 | {
23 | savePath = "/tmp/online-forest-";
24 | verbose = 1; // 0 = None
25 | };
26 |
--------------------------------------------------------------------------------
/online_forests_ros/doc/2009-OnlineRandomForests.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuiYang-1010/efficient_online_learning/2dca08c20a0a8f88b957c3e276f3ed4bf6f78c31/online_forests_ros/doc/2009-OnlineRandomForests.pdf
--------------------------------------------------------------------------------
/online_forests_ros/include/online_forests/classifier.h:
--------------------------------------------------------------------------------
1 | #ifndef CLASSIFIER_H_
2 | #define CLASSIFIER_H_
3 |
4 | #include <vector>
5 |
6 | #include "online_forests/data.h"
7 |
8 | using namespace std;
9 |
10 | class Classifier {
11 | public:
12 | virtual void update(Sample &sample) = 0;
13 | virtual void train(DataSet &dataset) = 0;
14 | virtual Result eval(Sample &sample) = 0;
15 | virtual vector<Result> test(DataSet &dataset) = 0;
16 | virtual vector<Result> trainAndTest(DataSet &dataset_tr, DataSet &dataset_ts) = 0;
17 |
18 | double compError(const vector<Result> &results, const DataSet &dataset) {
19 | double error = 0.0;
20 | for (int i = 0; i < dataset.m_numSamples; i++) {
21 | if (results[i].prediction != dataset.m_samples[i].y) {
22 | error++;
23 | }
24 | }
25 |
26 | return error / dataset.m_numSamples;
27 | }
28 |
29 | void dispErrors(const vector<double> &errors) {
30 | for (int i = 0; i < (int) errors.size(); i++) {
31 | cout << i + 1 << ": " << errors[i] << " --- ";
32 | }
33 | cout << endl;
34 | }
35 | };
36 |
37 | #endif /* CLASSIFIER_H_ */
38 |
--------------------------------------------------------------------------------
/online_forests_ros/include/online_forests/data.h:
--------------------------------------------------------------------------------
1 | #ifndef DATA_H_
2 | #define DATA_H_
3 |
4 | #include <iostream>
5 | #include <string>
6 | #include <vector>
7 | #include <gmm/gmm.h>
8 |
9 | using namespace std;
10 | using namespace gmm;
11 |
12 | // TYPEDEFS
13 | typedef int Label;
14 | typedef double Weight;
15 | typedef rsvector<double> SparseVector;
16 |
17 | // DATA CLASSES
18 | class Sample {
19 | public:
20 | SparseVector x;
21 | Label y;
22 | Weight w;
23 |
24 | void disp() {
25 | cout << "Sample: y = " << y << ", w = " << w << ", x = ";
26 | cout << x << endl;
27 | }
28 | };
29 |
30 | class DataSet {
31 | public:
32 | vector<Sample> m_samples;
33 | int m_numSamples;
34 | int m_numFeatures;
35 | int m_numClasses;
36 |
37 | vector<double> m_minFeatRange;
38 | vector<double> m_maxFeatRange;
39 |
40 | void findFeatRange();
41 |
42 | void loadLIBSVM(string filename);
43 | void loadLIBSVM2(string data);
44 | };
45 |
46 | class Result {
47 | public:
48 | vector<double> confidence;
49 | int prediction;
50 | };
51 |
52 | #endif /* DATA_H_ */
53 |
--------------------------------------------------------------------------------
/online_forests_ros/include/online_forests/hyperparameters.h:
--------------------------------------------------------------------------------
1 | #ifndef HYPERPARAMETERS_H_
2 | #define HYPERPARAMETERS_H_
3 |
4 | #include <string>
5 | using namespace std;
6 |
7 | class Hyperparameters
8 | {
9 | public:
10 | Hyperparameters();
11 | Hyperparameters(const string& confFile);
12 |
13 | // Online node
14 | int numRandomTests;
15 | int numProjectionFeatures;
16 | int counterThreshold;
17 | int maxDepth;
18 |
19 | // Online tree
20 |
21 | // Online forest
22 | int numTrees;
23 | int useSoftVoting;
24 | int numEpochs;
25 |
26 | // Data
27 | string trainData;
28 | string testData;
29 |
30 | // Output
31 | int verbose;
32 | };
33 |
34 | #endif /* HYPERPARAMETERS_H_ */
35 |
--------------------------------------------------------------------------------
/online_forests_ros/include/online_forests/onlinerf.h:
--------------------------------------------------------------------------------
1 | #ifndef ONLINERF_H_
2 | #define ONLINERF_H_
3 |
4 | #include "online_forests/classifier.h"
5 | #include "online_forests/data.h"
6 | #include "online_forests/hyperparameters.h"
7 | #include "online_forests/onlinetree.h"
8 | #include "online_forests/utilities.h"
9 |
10 | class OnlineRF: public Classifier {
11 | public:
12 | OnlineRF(const Hyperparameters &hp, const int &numClasses, const int &numFeatures, const vector<double> &minFeatRange,
13 | const vector<double> &maxFeatRange) :
14 | m_numClasses(&numClasses), m_counter(0.0), m_oobe(0.0), m_hp(&hp) {
15 | OnlineTree *tree;
16 | for (int i = 0; i < hp.numTrees; i++) {
17 | tree = new OnlineTree(hp, numClasses, numFeatures, minFeatRange, maxFeatRange);
18 | m_trees.push_back(tree);
19 | }
20 | }
21 |
22 | ~OnlineRF() {
23 | for (int i = 0; i < m_hp->numTrees; i++) {
24 | delete m_trees[i];
25 | }
26 | }
27 |
28 | virtual void update(Sample &sample) {
29 | m_counter += sample.w;
30 |
31 | Result result, treeResult;
32 | for (int i = 0; i < *m_numClasses; i++) {
33 | result.confidence.push_back(0.0);
34 | }
35 |
36 | int numTries;
37 | for (int i = 0; i < m_hp->numTrees; i++) {
38 | numTries = poisson(1.0);
39 | if (numTries) {
40 | for (int n = 0; n < numTries; n++) {
41 | m_trees[i]->update(sample);
42 | }
43 | } else {
44 | treeResult = m_trees[i]->eval(sample);
45 | if (m_hp->useSoftVoting) {
46 | add(treeResult.confidence, result.confidence);
47 | } else {
48 | result.confidence[treeResult.prediction]++;
49 | }
50 | }
51 | }
52 |
53 | if (argmax(result.confidence) != sample.y) {
54 | m_oobe += sample.w;
55 | }
56 | }
57 |
58 | virtual void train(DataSet &dataset) {
59 | vector<int> randIndex;
60 | int sampRatio = dataset.m_numSamples / 10;
61 | for (int n = 0; n < m_hp->numEpochs; n++) {
62 | randPerm(dataset.m_numSamples, randIndex);
63 | for (int i = 0; i < dataset.m_numSamples; i++) {
64 | update(dataset.m_samples[randIndex[i]]);
65 | if (m_hp->verbose >= 1 && (i % sampRatio) == 0) {
66 | //cout << "--- Online Random Forest training --- Epoch: " << n + 1 << " --- ";
67 | //cout << (10 * i) / sampRatio << "%" << endl;
68 | }
69 | }
70 | cout << "--- Online Random Forest training --- Epoch: " << n + 1 << " --- " << endl;
71 | }
72 | }
73 |
74 | virtual Result eval(Sample &sample) {
75 | Result result, treeResult;
76 | for (int i = 0; i < *m_numClasses; i++) {
77 | result.confidence.push_back(0.0);
78 | }
79 |
80 | for (int i = 0; i < m_hp->numTrees; i++) {
81 | treeResult = m_trees[i]->eval(sample);
82 | if (m_hp->useSoftVoting) {
83 | add(treeResult.confidence, result.confidence);
84 | } else {
85 | result.confidence[treeResult.prediction]++;
86 | }
87 | }
88 |
89 | scale(result.confidence, 1.0 / m_hp->numTrees);
90 | result.prediction = argmax(result.confidence);
91 | return result;
92 | }
93 |
94 | virtual vector<Result> test(DataSet &dataset) {
95 | vector<Result> results;
96 | for (int i = 0; i < dataset.m_numSamples; i++) {
97 | results.push_back(eval(dataset.m_samples[i]));
98 | }
99 |
100 | double error = compError(results, dataset);
101 | if (m_hp->verbose >= 1) {
102 | cout << "--- Online Random Forest test error: " << error << endl;
103 | }
104 |
105 | return results;
106 | }
107 |
108 | virtual vector<Result> trainAndTest(DataSet &dataset_tr, DataSet &dataset_ts) {
109 | vector<Result> results;
110 | vector<int> randIndex;
111 | int sampRatio = dataset_tr.m_numSamples / 10;
112 | vector<double> testError;
113 | for (int n = 0; n < m_hp->numEpochs; n++) {
114 | randPerm(dataset_tr.m_numSamples, randIndex);
115 | for (int i = 0; i < dataset_tr.m_numSamples; i++) {
116 | update(dataset_tr.m_samples[randIndex[i]]);
117 | if (m_hp->verbose >= 1 && (i % sampRatio) == 0) {
118 | cout << "--- Online Random Forest training --- Epoch: " << n + 1 << " --- ";
119 | cout << (10 * i) / sampRatio << "%" << endl;
120 | }
121 | }
122 |
123 | results = test(dataset_ts);
124 | testError.push_back(compError(results, dataset_ts));
125 | }
126 |
127 | if (m_hp->verbose >= 1) {
128 | cout << endl << "--- Online Random Forest test error over epochs: ";
129 | dispErrors(testError);
130 | }
131 |
132 | return results;
133 | }
134 |
135 | virtual void writeForest(string fileName) {
136 | cout << "Writing forest" << endl;
137 | FILE *fp = fopen(fileName.c_str(), "w");
138 |
139 | for (int i = 0; i < m_hp->numTrees; i++) {
140 | m_trees[i]->writeTree(fp);
141 | }
142 | fclose(fp);
143 | cout << "Writing forest done" << endl;
144 | }
145 |
146 | virtual void loadForest(string fileName) {
147 | cout << "Loading forest" << endl;
148 | FILE *fp = fopen(fileName.c_str(), "r");
149 |
150 | for (int i = 0; i < m_hp->numTrees; i++) {
151 | m_trees[i]->loadTree(fp, i);
152 | }
153 | fclose(fp);
154 | cout << "Loading forest done" << endl;
155 | }
156 |
157 | protected:
158 | const int *m_numClasses;
159 | double m_counter;
160 | double m_oobe;
161 | const Hyperparameters *m_hp;
162 |
163 | vector<OnlineTree*> m_trees;
164 |
165 | };
166 |
167 | #endif /* ONLINERF_H_ */
168 |
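`OnlineRF::update` above implements online bagging (Oza and Russell's scheme): each tree sees an incoming sample k ~ Poisson(1) times, and trees that draw k = 0 vote on the sample instead, which is what accumulates the out-of-bag error in `m_oobe`. A minimal Python sketch of that control flow, with illustrative function names and without the k <= 10 truncation that utilities.h applies:

```python
import math
import random

def poisson1():
    """k ~ Poisson(1) via the product-of-uniforms method, as in utilities.h."""
    k, p = 0, random.random()
    while p >= math.exp(-1.0):
        p *= random.random()
        k += 1
    return k

def online_bagging_update(trees, sample, update_tree, eval_tree):
    """One OnlineRF-style step: each tree trains on the sample k times;
    trees with k == 0 produce out-of-bag votes instead."""
    oob_votes = []
    for tree in trees:
        k = poisson1()
        if k:
            for _ in range(k):
                update_tree(tree, sample)
        else:
            oob_votes.append(eval_tree(tree, sample))
    return oob_votes  # compare against the true label to track OOB error
```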
--------------------------------------------------------------------------------
/online_forests_ros/include/online_forests/onlinetree.h:
--------------------------------------------------------------------------------
1 | #ifndef ONLINETREE_H_
2 | #define ONLINETREE_H_
3 |
4 | #include <cstdio>
5 | #include <vector>
6 |
7 | #include "online_forests/classifier.h"
8 | #include "online_forests/data.h"
9 | #include "online_forests/hyperparameters.h"
10 | #include "online_forests/onlinenode.h"
11 |
12 | using namespace std;
13 |
14 | class OnlineTree: public Classifier {
15 | public:
16 | OnlineTree(const Hyperparameters &hp, const int &numClasses, const int &numFeatures, const vector<double> &minFeatRange,
17 | const vector<double> &maxFeatRange) :
18 | m_counter(0.0), m_hp(&hp) {
19 | m_rootNode = new OnlineNode(hp, numClasses, numFeatures, minFeatRange, maxFeatRange, 0);
20 | }
21 |
22 | ~OnlineTree() {
23 | delete m_rootNode;
24 | }
25 |
26 | virtual void update(Sample &sample) {
27 | m_rootNode->update(sample);
28 | }
29 |
30 | virtual void train(DataSet &dataset) {
31 | vector<int> randIndex;
32 | int sampRatio = dataset.m_numSamples / 10;
33 | for (int n = 0; n < m_hp->numEpochs; n++) {
34 | randPerm(dataset.m_numSamples, randIndex);
35 | for (int i = 0; i < dataset.m_numSamples; i++) {
36 | update(dataset.m_samples[randIndex[i]]);
37 | if (m_hp->verbose >= 3 && (i % sampRatio) == 0) {
38 | cout << "--- Online Random Tree training --- Epoch: " << n + 1 << " --- ";
39 | cout << (10 * i) / sampRatio << "%" << endl;
40 | }
41 | }
42 | }
43 | }
44 |
45 | virtual Result eval(Sample &sample) {
46 | return m_rootNode->eval(sample);
47 | }
48 |
49 | virtual vector<Result> test(DataSet &dataset) {
50 | vector<Result> results;
51 | for (int i = 0; i < dataset.m_numSamples; i++) {
52 | results.push_back(eval(dataset.m_samples[i]));
53 | }
54 |
55 | double error = compError(results, dataset);
56 | if (m_hp->verbose >= 3) {
57 | cout << "--- Online Random Tree test error: " << error << endl;
58 | }
59 |
60 | return results;
61 | }
62 |
63 | virtual vector<Result> trainAndTest(DataSet &dataset_tr, DataSet &dataset_ts) {
64 | vector<Result> results;
65 | vector<int> randIndex;
66 | int sampRatio = dataset_tr.m_numSamples / 10;
67 | vector<double> testError;
68 | for (int n = 0; n < m_hp->numEpochs; n++) {
69 | randPerm(dataset_tr.m_numSamples, randIndex);
70 | for (int i = 0; i < dataset_tr.m_numSamples; i++) {
71 | update(dataset_tr.m_samples[randIndex[i]]);
72 | if (m_hp->verbose >= 3 && (i % sampRatio) == 0) {
73 | cout << "--- Online Random Tree training --- Epoch: " << n + 1 << " --- ";
74 | cout << (10 * i) / sampRatio << "%" << endl;
75 | }
76 | }
77 |
78 | results = test(dataset_ts);
79 | testError.push_back(compError(results, dataset_ts));
80 | }
81 |
82 | if (m_hp->verbose >= 3) {
83 | cout << endl << "--- Online Random Tree test error over epochs: ";
84 | dispErrors(testError);
85 | }
86 |
87 | return results;
88 | }
89 |
90 | virtual void writeTree(FILE * fp) {
91 | m_rootNode->writeNode(m_rootNode, m_counter, fp);
92 | fprintf(fp,"T\n");
93 | }
94 |
95 | virtual void loadTree(FILE * fp, int tree_index) {
96 | m_rootNode->loadNode(m_rootNode, fp);
97 | return;
98 | }
99 |
100 | private:
101 | double m_counter;
102 | const Hyperparameters *m_hp;
103 |
104 | OnlineNode* m_rootNode;
105 | };
106 |
107 | #endif /* ONLINETREE_H_ */
108 |
--------------------------------------------------------------------------------
/online_forests_ros/include/online_forests/randomtest.h:
--------------------------------------------------------------------------------
1 | #ifndef RANDOMTEST_H_
2 | #define RANDOMTEST_H_
3 |
4 | #include "online_forests/data.h"
5 | #include "online_forests/utilities.h"
6 |
7 | class RandomTest {
8 | public:
9 | RandomTest() {
10 |
11 | }
12 |
13 | RandomTest(const int &numClasses) :
14 | m_numClasses(&numClasses), m_trueCount(0.0), m_falseCount(0.0) {
15 | for (int i = 0; i < numClasses; i++) {
16 | m_trueStats.push_back(0.0);
17 | m_falseStats.push_back(0.0);
18 | }
19 | m_threshold = randomFromRange(-1, 1);
20 | }
21 |
22 | RandomTest(const int &numClasses, const double featMin, const double featMax) :
23 | m_numClasses(&numClasses), m_trueCount(0.0), m_falseCount(0.0) {
24 | for (int i = 0; i < numClasses; i++) {
25 | m_trueStats.push_back(0.0);
26 | m_falseStats.push_back(0.0);
27 | }
28 | m_threshold = randomFromRange(featMin, featMax);
29 | }
30 |
31 | void updateStats(const Sample &sample, const bool decision) {
32 | if (decision) {
33 | m_trueCount += sample.w;
34 | m_trueStats[sample.y] += sample.w;
35 | } else {
36 | m_falseCount += sample.w;
37 | m_falseStats[sample.y] += sample.w;
38 | }
39 | }
40 |
41 | double score() {
42 | double totalCount = m_trueCount + m_falseCount;
43 |
44 | // Split Entropy
45 | double p, splitEntropy = 0.0;
46 | if (m_trueCount) {
47 | p = m_trueCount / totalCount;
48 | splitEntropy -= p * log2(p);
49 | }
50 | if (m_falseCount) {
51 | p = m_falseCount / totalCount;
52 | splitEntropy -= p * log2(p);
53 | }
54 |
55 | // Prior Entropy
56 | double priorEntropy = 0.0;
57 | for (int i = 0; i < *m_numClasses; i++) {
58 | p = (m_trueStats[i] + m_falseStats[i]) / totalCount;
59 | if (p) {
60 | priorEntropy -= p * log2(p);
61 | }
62 | }
63 |
64 | // Posterior Entropy
65 | double trueScore = 0.0, falseScore = 0.0;
66 | if (m_trueCount) {
67 | for (int i = 0; i < *m_numClasses; i++) {
68 | p = m_trueStats[i] / m_trueCount;
69 | if (p) {
70 | trueScore -= p * log2(p);
71 | }
72 | }
73 | }
74 | if (m_falseCount) {
75 | for (int i = 0; i < *m_numClasses; i++) {
76 | p = m_falseStats[i] / m_falseCount;
77 | if (p) {
78 | falseScore -= p * log2(p);
79 | }
80 | }
81 | }
82 | double posteriorEntropy = (m_trueCount * trueScore + m_falseCount * falseScore) / totalCount;
83 |
84 | // Information Gain
85 | return (2.0 * (priorEntropy - posteriorEntropy)) / (priorEntropy * splitEntropy + 1e-10);
86 | }
87 |
88 | pair<vector<double>, vector<double> > getStats() {
89 | return pair<vector<double>, vector<double> >(m_trueStats, m_falseStats);
90 | }
91 |
92 | protected:
93 | const int *m_numClasses;
94 | double m_threshold;
95 | double m_trueCount;
96 | double m_falseCount;
97 | vector<double> m_trueStats;
98 | vector<double> m_falseStats;
99 | };
100 |
101 | class HyperplaneFeature: public RandomTest {
102 | public:
103 | HyperplaneFeature() {
104 |
105 | }
106 |
107 | HyperplaneFeature(const int &numClasses, const int &numFeatures, const int &numProjFeatures, const vector<double> &minFeatRange,
108 | const vector<double> &maxFeatRange) :
109 | RandomTest(numClasses), m_numProjFeatures(&numProjFeatures) {
110 | randPerm(numFeatures, numProjFeatures, m_features);
111 | fillWithRandomNumbers(numProjFeatures, m_weights);
112 |
113 | // Find min and max range of the projection
114 | double minRange = 0.0, maxRange = 0.0;
115 | for (int i = 0; i < numProjFeatures; i++) {
116 | minRange += minFeatRange[m_features[i]] * m_weights[i];
117 | maxRange += maxFeatRange[m_features[i]] * m_weights[i];
118 | }
119 |
120 | m_threshold = randomFromRange(minRange, maxRange);
121 | }
122 |
123 | void update(Sample &sample) {
124 | updateStats(sample, eval(sample));
125 | }
126 |
127 | bool eval(Sample &sample) {
128 | double proj = 0.0;
129 | for (int i = 0; i < *m_numProjFeatures; i++) {
130 | proj += sample.x[m_features[i]] * m_weights[i];
131 | }
132 |
133 | return (proj > m_threshold) ? true : false;
134 | }
135 |
136 | void writeTest(FILE *fp){
137 | fprintf(fp," %lf",m_threshold);
138 | fprintf(fp," %lf",m_trueCount);
139 | fprintf(fp," %lf",m_falseCount);
140 |
141 | for (int i = 0; i < *m_numClasses; i++) {
142 | fprintf(fp," %lf",m_trueStats[i]);
143 | fprintf(fp," %lf",m_falseStats[i]);
144 | }
145 |
146 | for (int i = 0; i < *m_numProjFeatures; i++) {
147 | fprintf(fp," %d",m_features[i]);
148 | fprintf(fp," %lf",m_weights[i]);
149 | }
150 | }
151 |
152 | void loadTest(FILE *fp){
153 | fscanf(fp, "%lf ", &m_threshold);
154 | fscanf(fp, "%lf ", &m_trueCount);
155 | fscanf(fp, "%lf ", &m_falseCount);
156 |
157 | for (int i = 0; i < *m_numClasses; i++) {
158 | fscanf(fp, "%lf ", &m_trueStats[i]);
159 | fscanf(fp, "%lf ", &m_falseStats[i]);
160 | }
161 |
162 | for (int i = 0; i < *m_numProjFeatures; i++) {
163 | fscanf(fp, "%d ", &m_features[i]);
164 | fscanf(fp, "%lf ", &m_weights[i]);
165 | }
166 | }
167 |
168 | private:
169 | const int *m_numProjFeatures;
170 | vector<int> m_features;
171 | vector<double> m_weights;
172 | };
173 |
174 | #endif /* RANDOMTEST_H_ */
175 |
--------------------------------------------------------------------------------
/online_forests_ros/include/online_forests/utilities.h:
--------------------------------------------------------------------------------
1 | #ifndef UTILITIES_H_
2 | #define UTILITIES_H_
3 |
4 | #include <cstdlib>
5 | #include <cmath>
6 | #include <ctime>
7 | #include <fstream>
8 | #include <iostream>
9 | #include <string>
10 | #include <vector>
11 | #ifndef WIN32
12 | #include <sys/time.h>
13 | #endif
14 |
15 | using namespace std;
16 |
17 | // Random Numbers Generators
18 | unsigned int getDevRandom();
19 |
20 | //! Returns a random number in [0, 1)
21 | inline double randDouble() {
22 | static bool didSeeding = false;
23 |
24 | if (!didSeeding) {
25 | #ifdef WIN32
26 | srand(0);
27 | #else
28 | unsigned int seedNum;
29 | struct timeval TV;
30 | unsigned int curTime;
31 |
32 | gettimeofday(&TV, NULL);
33 | curTime = (unsigned int) TV.tv_usec;
34 | seedNum = (unsigned int) time(NULL) + curTime + getpid() + getDevRandom();
35 |
36 | srand(seedNum);
37 | #endif
38 | didSeeding = true;
39 | }
40 | return rand() / (RAND_MAX + 1.0);
41 | }
42 |
43 | //! Returns a random number in [min, max]
44 | inline double randomFromRange(const double &minRange, const double &maxRange) {
45 | return minRange + (maxRange - minRange) * randDouble();
46 | }
47 |
48 | //! Random permutations
49 | void randPerm(const int &inNum, vector<int> &outVect);
50 | void randPerm(const int &inNum, const int inPart, vector<int> &outVect);
51 |
52 | inline void fillWithRandomNumbers(const int &length, vector<double> &inVect) {
53 | inVect.clear();
54 | for (int i = 0; i < length; i++) {
55 | inVect.push_back(2.0 * (randDouble() - 0.5));
56 | }
57 | }
58 |
59 | inline int argmax(const vector<double> &inVect) {
60 | double maxValue = inVect[0];
61 | int maxIndex = 0, i = 1;
62 | vector<double>::const_iterator itr(inVect.begin() + 1), end(inVect.end());
63 | while (itr != end) {
64 | if (*itr > maxValue) {
65 | maxValue = *itr;
66 | maxIndex = i;
67 | }
68 |
69 | ++i;
70 | ++itr;
71 | }
72 |
73 | return maxIndex;
74 | }
75 |
76 | inline double sum(const vector<double> &inVect) {
77 | double val = 0.0;
78 | vector<double>::const_iterator itr(inVect.begin()), end(inVect.end());
79 | while (itr != end) {
80 | val += *itr;
81 | ++itr;
82 | }
83 |
84 | return val;
85 | }
86 |
87 | //! Poisson sampling (Knuth's product-of-uniforms method, truncated at k > 10; called with A = 1 for online bagging)
88 | inline int poisson(double A) {
89 | int k = 0;
90 | int maxK = 10;
91 | while (1) {
92 | double U_k = randDouble();
93 | A *= U_k;
94 | if (k > maxK || A < exp(-1.0)) {
95 | break;
96 | }
97 | k++;
98 | }
99 | return k;
100 | }
101 |
102 | #endif /* UTILITIES_H_ */
103 |
--------------------------------------------------------------------------------
/online_forests_ros/launch/online_forests_ros.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/online_forests_ros/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="2">
3 |   <name>online_forests_ros</name>
4 |   <version>0.1.2</version>
5 |   <description>
6 |     A package that ROSifies the Online Random Forests.
7 |   </description>
8 |   <maintainer>Rui Yang</maintainer>
9 |   <maintainer>Zhi Yan</maintainer>
10 |   <license>MIT</license>
11 |
12 |   <url>https://github.com/amirsaffari/online-random-forests</url>
13 |   <author>Amir Saffari</author>
14 |
15 |   <buildtool_depend>catkin</buildtool_depend>
16 |
17 |   <depend>roscpp</depend>
18 |
19 |   <build_depend>libconfig++-dev</build_depend>
20 |   <build_depend>libgmm++-dev</build_depend>
21 |   <build_depend>libatlas-base-dev</build_depend>
22 |   <build_depend>libblas-dev</build_depend>
23 |   <build_depend>liblapack-dev</build_depend>
24 |   <build_depend>autoware_tracker</build_depend>
25 |   <build_depend>std_msgs</build_depend>
26 |
27 |   <exec_depend>autoware_tracker</exec_depend>
28 |   <exec_depend>std_msgs</exec_depend>
29 |
30 | </package>
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/Online-Forest.cpp:
--------------------------------------------------------------------------------
1 | #define GMM_USES_BLAS
2 |
3 | #include <cstdlib>
4 | #include <cstring>
5 | #include <ctime>
6 | #include <iostream>
7 | #include <string>
8 | #include <libconfig.h++>
9 |
10 | #include "online_forests/data.h"
11 | #include "online_forests/onlinetree.h"
12 | #include "online_forests/onlinerf.h"
13 |
14 | using namespace std;
15 | using namespace libconfig;
16 |
17 | typedef enum {
18 | ORT, ORF
19 | } CLASSIFIER_TYPE;
20 |
21 | //! Prints the interface help message
22 | void help() {
23 | cout << endl;
24 | cout << "OnlineForest Classification Package:" << endl;
25 | cout << "Input arguments:" << endl;
26 | cout << "\t -h | --help : \t will display this message." << endl;
27 | cout << "\t -c : \t\t path to the config file." << endl << endl;
28 | cout << "\t --ort : \t use Online Random Tree (ORT) algorithm." << endl;
29 | cout << "\t --orf : \t use Online Random Forest (ORF) algorithm." << endl;
30 | cout << endl << endl;
31 | cout << "\t --train : \t train the classifier." << endl;
32 | cout << "\t --test : \t test the classifier." << endl;
33 | cout << "\t --t2 : \t train and test the classifier at the same time." << endl;
34 | cout << endl << endl;
35 | cout << "\tExamples:" << endl;
36 | cout << "\t ./Online-Forest -c conf/orf.conf --orf --train --test" << endl;
37 | }
38 |
39 | //! Returns the time (in seconds) elapsed between two calls to this function
40 | double timeIt(int reset) {
41 | static time_t startTime, endTime;
42 | static int timerWorking = 0;
43 |
44 | if (reset) {
45 | startTime = time(NULL);
46 | timerWorking = 1;
47 | return -1;
48 | } else {
49 | if (timerWorking) {
50 | endTime = time(NULL);
51 | timerWorking = 0;
52 | return (double) (endTime - startTime);
53 | } else {
54 | startTime = time(NULL);
55 | timerWorking = 1;
56 | return -1;
57 | }
58 | }
59 | }
60 |
61 | int main(int argc, char *argv[]) {
62 | // Parsing command line
63 | string confFileName;
64 |     int classifier = -1, inputCounter = 1; bool doTraining = false, doTesting = false, doT2 = false;
65 |
66 | if (argc == 1) {
67 | cout << "\tNo input argument specified: aborting." << endl;
68 | help();
69 | exit(EXIT_SUCCESS);
70 | }
71 |
72 | while (inputCounter < argc) {
73 | if (!strcmp(argv[inputCounter], "-h") || !strcmp(argv[inputCounter], "--help")) {
74 | help();
75 | return EXIT_SUCCESS;
76 | } else if (!strcmp(argv[inputCounter], "-c")) {
77 | confFileName = argv[++inputCounter];
78 | } else if (!strcmp(argv[inputCounter], "--ort")) {
79 | classifier = ORT;
80 | } else if (!strcmp(argv[inputCounter], "--orf")) {
81 | classifier = ORF;
82 | } else if (!strcmp(argv[inputCounter], "--train")) {
83 | doTraining = true;
84 | } else if (!strcmp(argv[inputCounter], "--test")) {
85 | doTesting = true;
86 | } else if (!strcmp(argv[inputCounter], "--t2")) {
87 | doT2 = true;
88 | } else {
89 | cout << "\tUnknown input argument: " << argv[inputCounter];
90 | cout << ", please try --help for more information." << endl;
91 | exit(EXIT_FAILURE);
92 | }
93 |
94 | inputCounter++;
95 | }
96 |
97 |     cout << "OnlineForest Classification Package:" << endl;
98 |
99 | if (!doTraining && !doTesting && !doT2) {
100 |         cout << "\tNothing to do: specify --train, --test, or --t2." << endl;
101 | exit(EXIT_FAILURE);
102 | }
103 |
104 | if (doT2) {
105 | doTraining = false;
106 | doTesting = false;
107 | }
108 |
109 | // Load the hyperparameters
110 | Hyperparameters hp(confFileName);
111 |
112 | // Creating the train data
113 | DataSet dataset_tr, dataset_ts;
114 | dataset_tr.loadLIBSVM(hp.trainData);
115 | if (doT2 || doTesting) {
116 | dataset_ts.loadLIBSVM(hp.testData);
117 | }
118 |
119 | // Calling training/testing
120 | switch (classifier) {
121 | case ORT: {
122 | OnlineTree model(hp, dataset_tr.m_numClasses, dataset_tr.m_numFeatures, dataset_tr.m_minFeatRange, dataset_tr.m_maxFeatRange);
123 | if (doT2) {
124 | timeIt(1);
125 | model.trainAndTest(dataset_tr, dataset_ts);
126 | cout << "Training/Test time: " << timeIt(0) << endl;
127 | }
128 | if (doTraining) {
129 | timeIt(1);
130 | model.train(dataset_tr);
131 | cout << "Training time: " << timeIt(0) << endl;
132 |         }
133 |         if (doTesting) { // run independently of --train, as in the ORF case
134 |             timeIt(1); model.test(dataset_ts);
135 |             cout << "Test time: " << timeIt(0) << endl;
136 |         }
137 | break;
138 | }
139 | case ORF: {
140 | OnlineRF model(hp, dataset_tr.m_numClasses, dataset_tr.m_numFeatures, dataset_tr.m_minFeatRange, dataset_tr.m_maxFeatRange);
141 | if (doT2) {
142 | timeIt(1);
143 | model.trainAndTest(dataset_tr, dataset_ts);
144 | cout << "Training/Test time: " << timeIt(0) << endl;
145 | }
146 | if (doTraining) {
147 | timeIt(1);
148 | model.train(dataset_tr);
149 | cout << "Training time: " << timeIt(0) << endl;
150 | }
151 | if (doTesting) {
152 | timeIt(1);
153 | model.test(dataset_ts);
154 | cout << "Test time: " << timeIt(0) << endl;
155 | }
156 | break;
157 | }
158 | }
159 |
160 | return EXIT_SUCCESS;
161 | }
162 |
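Since timeIt() is based on time(NULL), it can only resolve whole seconds, so short runs will report 0. If sub-second timing ever matters, a std::chrono-based replacement is straightforward; this is a sketch (the name timeItMs is hypothetical), not part of the original sources:

    #include <chrono>

    // Millisecond-resolution stopwatch with the same reset/read convention
    // as timeIt(): call with 1 to start, with 0 to read the elapsed time.
    inline double timeItMs(int reset) {
        static std::chrono::steady_clock::time_point start;
        if (reset) {
            start = std::chrono::steady_clock::now();
            return -1;
        }
        return std::chrono::duration<double, std::milli>(
            std::chrono::steady_clock::now() - start).count();
    }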
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/classifier.cpp:
--------------------------------------------------------------------------------
1 | #include "online_forests/classifier.h"
2 |
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/data.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdlib>
2 | #include <fstream>
3 | #include <iostream>
4 | #include <sstream>
5 | #include "online_forests/data.h"
6 |
7 | using namespace std;
8 |
9 | void DataSet::findFeatRange() {
10 | double minVal, maxVal;
11 | for (int i = 0; i < m_numFeatures; i++) {
12 | minVal = m_samples[0].x[i];
13 | maxVal = m_samples[0].x[i];
14 | for (int n = 1; n < m_numSamples; n++) {
15 | if (m_samples[n].x[i] < minVal) {
16 | minVal = m_samples[n].x[i];
17 | }
18 | if (m_samples[n].x[i] > maxVal) {
19 | maxVal = m_samples[n].x[i];
20 | }
21 | }
22 |
23 | m_minFeatRange.push_back(minVal);
24 | m_maxFeatRange.push_back(maxVal);
25 | }
26 | }
27 |
28 | void DataSet::loadLIBSVM(string filename) {
29 | ifstream fp(filename.c_str(), ios::binary);
30 | if (!fp) {
31 | cout << "Could not open input file " << filename << endl;
32 | exit(EXIT_FAILURE);
33 | }
34 |
35 | cout << "Loading data file: " << filename << " ... " << endl;
36 |
37 | // Reading the header
38 | int startIndex;
39 | fp >> m_numSamples;
40 | fp >> m_numFeatures;
41 | fp >> m_numClasses;
42 | fp >> startIndex;
43 |
44 | // Reading the data
45 | string line, tmpStr;
46 | int prePos, curPos, colIndex;
47 | m_samples.clear();
48 |
49 | for (int i = 0; i < m_numSamples; i++) {
50 |         wsvector<double> x(m_numFeatures);
51 | Sample sample;
52 | resize(sample.x, m_numFeatures);
53 | fp >> sample.y; // read label
54 | sample.w = 1.0; // set weight
55 |
56 | getline(fp, line); // read the rest of the line
57 | prePos = 0;
58 | curPos = line.find(' ', 0);
59 | while (prePos <= curPos) {
60 | prePos = curPos + 1;
61 | curPos = line.find(':', prePos);
62 | tmpStr = line.substr(prePos, curPos - prePos);
63 | colIndex = atoi(tmpStr.c_str()) - startIndex;
64 |
65 | prePos = curPos + 1;
66 | curPos = line.find(' ', prePos);
67 | tmpStr = line.substr(prePos, curPos - prePos);
68 | x[colIndex] = atof(tmpStr.c_str());
69 | }
70 | copy(x, sample.x);
71 | m_samples.push_back(sample); // push sample into dataset
72 | }
73 |
74 | fp.close();
75 |
76 | if (m_numSamples != (int) m_samples.size()) {
77 | cout << "Could not load " << m_numSamples << " samples from " << filename;
78 | cout << ". There were only " << m_samples.size() << " samples!" << endl;
79 | exit(EXIT_FAILURE);
80 | }
81 |
82 | // Find the data range
83 | findFeatRange();
84 |
85 | cout << "Loaded " << m_numSamples << " samples with " << m_numFeatures;
86 | cout << " features and " << m_numClasses << " classes." << endl;
87 | }
88 |
89 | void DataSet::loadLIBSVM2(string data) {
90 | // Reading the header
91 | std::istringstream iss(data);
92 | string line;
93 | int startIndex;
94 |
95 | getline(iss, line, ' ');
96 | m_numSamples = atoi(line.c_str());
97 | getline(iss, line, ' ');
98 | m_numFeatures = atoi(line.c_str());
99 | getline(iss, line, ' ');
100 | m_numClasses = atoi(line.c_str());
101 | getline(iss, line, '\n');
102 | startIndex = atoi(line.c_str());
103 |
104 | // Reading the data
105 | string tmpStr;
106 | int prePos, curPos, colIndex;
107 | m_samples.clear();
108 |
109 | for (int i = 0; i < m_numSamples; i++) {
110 |         wsvector<double> x(m_numFeatures);
111 | Sample sample;
112 | resize(sample.x, m_numFeatures);
113 | // getline(iss, line);
114 | // sample.y = atoi(line.substr(line.find(' ')).c_str()); // read label
115 | getline(iss, line, ' ');
116 | sample.y = atoi(line.c_str()); // read label
117 | sample.w = 1.0; // set weight
118 |
119 | getline(iss, line);
120 | prePos = 0;
121 | curPos = line.find(' ', 0);
122 | while (prePos <= curPos) {
123 | prePos = curPos + 1;
124 | curPos = line.find(':', prePos);
125 | tmpStr = line.substr(prePos, curPos - prePos);
126 | colIndex = atoi(tmpStr.c_str()) - startIndex;
127 |
128 | prePos = curPos + 1;
129 | curPos = line.find(' ', prePos);
130 | tmpStr = line.substr(prePos, curPos - prePos);
131 | x[colIndex] = atof(tmpStr.c_str());
132 | }
133 | copy(x, sample.x);
134 | m_samples.push_back(sample); // push sample into dataset
135 | }
136 |
137 | if (m_numSamples != (int) m_samples.size()) {
138 |         cout << "Could not load " << m_numSamples << " samples";
139 | cout << ". There were only " << m_samples.size() << " samples!" << endl;
140 | exit(EXIT_FAILURE);
141 | }
142 |
143 | // Find the data range
144 | findFeatRange();
145 |
146 | cout << "Loaded " << m_numSamples << " samples with " << m_numFeatures;
147 | cout << " features and " << m_numClasses << " classes." << endl;
148 | }
149 |
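Both loaders consume LIBSVM-formatted rows preceded by a one-line header: sample count, feature count, class count, and the start index of the feature numbering. A small hand-written example of the layout (illustrative data, not taken from the shipped dna-train.libsvm):

    3 4 2 1
    0 1:0.25 3:1.0
    1 2:0.5 4:0.75
    0 1:0.1 2:0.2 3:0.3

One subtlety worth knowing: the parsing loop starts scanning at the first separator it finds. In loadLIBSVM() the operator>> that reads the label leaves the following space in the stream, so the format above works as-is; in loadLIBSVM2() the getline() that reads the label consumes that space, so each row in the string payload needs a second space after the label for the first index:value pair to be picked up.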
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/hyperparameters.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <libconfig.h++>
3 |
4 | #include "online_forests/hyperparameters.h"
5 |
6 | using namespace std;
7 | using namespace libconfig;
8 |
9 | Hyperparameters::Hyperparameters(const string& confFile) {
10 | cout << "Loading config file: " << confFile << " ... ";
11 |
12 | Config configFile;
13 | configFile.readFile(confFile.c_str());
14 |
15 | // Node/Tree
16 | maxDepth = configFile.lookup("Tree.maxDepth");
17 | numRandomTests = configFile.lookup("Tree.numRandomTests");
18 | numProjectionFeatures = configFile.lookup("Tree.numProjectionFeatures");
19 | counterThreshold = configFile.lookup("Tree.counterThreshold");
20 |
21 | // Forest
22 | numTrees = configFile.lookup("Forest.numTrees");
23 | numEpochs = configFile.lookup("Forest.numEpochs");
24 | useSoftVoting = configFile.lookup("Forest.useSoftVoting");
25 |
26 | // Data
27 | trainData = (const char *) configFile.lookup("Data.trainData");
28 | testData = (const char *) configFile.lookup("Data.testData");
29 |
30 | // Output
31 | verbose = configFile.lookup("Output.verbose");
32 |
33 | cout << "Done." << endl;
34 | }
35 |
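For reference, these lookups correspond to a libconfig file shaped like the following; the values shown are illustrative, not necessarily the ones shipped in config/orf.conf:

    Tree = {
        maxDepth = 20;
        numRandomTests = 10;
        numProjectionFeatures = 2;
        counterThreshold = 140;
    };

    Forest = {
        numTrees = 100;
        numEpochs = 1;
        useSoftVoting = 1;
    };

    Data = {
        trainData = "model/dna-train.libsvm";
        testData = "model/dna-test.libsvm";
    };

    Output = {
        verbose = 1;
    };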
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/onlinenode.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 |
3 | #include "online_forests/onlinenode.h"
4 |
5 | using namespace std;
6 |
7 | void OnlineNode::update(Sample &sample) {
8 | m_counter += sample.w;
9 | m_labelStats[sample.y] += sample.w;
10 |
11 | if (m_isLeaf) {
12 | // Update online tests
13 | for (int i = 0; i < m_hp->numRandomTests; i++) {
14 | m_onlineTests[i].update(sample);
15 | }
16 |
17 | // Update the label
18 | m_label = argmax(m_labelStats);
19 |
20 | // Decide for split
21 | if (shouldISplit()) {
22 | m_isLeaf = false;
23 |
24 | // Find the best online test
25 | int maxIndex = 0;
26 | double maxScore = -1e10, score;
27 | for (int i = 0; i < m_hp->numRandomTests; i++) {
28 | score = m_onlineTests[i].score();
29 | if (score > maxScore) {
30 | maxScore = score;
31 | maxIndex = i;
32 | }
33 | }
34 | m_bestTest = m_onlineTests[maxIndex];
35 | m_onlineTests.clear();
36 |
37 | if (m_hp->verbose >= 4) {
38 | cout << "--- Splitting node --- best score: " << maxScore;
39 | cout << " by test number: " << maxIndex << endl;
40 | }
41 |
42 | // Split
43 |             pair<vector<double>, vector<double> > parentStats = m_bestTest.getStats();
44 | m_rightChildNode = new OnlineNode(*m_hp, *m_numClasses, *m_numFeatures, *m_minFeatRange, *m_maxFeatRange, m_depth + 1,
45 | parentStats.first);
46 | m_leftChildNode = new OnlineNode(*m_hp, *m_numClasses, *m_numFeatures, *m_minFeatRange, *m_maxFeatRange, m_depth + 1,
47 | parentStats.second);
48 | }
49 | } else {
50 | if (m_bestTest.eval(sample)) {
51 | m_rightChildNode->update(sample);
52 | } else {
53 | m_leftChildNode->update(sample);
54 | }
55 | }
56 | }
57 |
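shouldISplit() itself is defined in the header. In the original Online Random Forests it amounts to two guards over the hyperparameters used above; a sketch (the exact definition in this package may differ):

    // Sketch: a leaf asks to split once it has seen enough samples and still
    // has depth budget left. Member names follow those used in update();
    // the original additionally refuses to split already-pure leaves.
    bool OnlineNode::shouldISplit() {
        return m_counter >= m_hp->counterThreshold && m_depth < m_hp->maxDepth;
    }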
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/onlinerf.cpp:
--------------------------------------------------------------------------------
1 | #include "online_forests/onlinerf.h"
2 |
3 |
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/onlinetree.cpp:
--------------------------------------------------------------------------------
1 | #include "online_forests/onlinetree.h"
2 |
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/randomtest.cpp:
--------------------------------------------------------------------------------
1 | #include "online_forests/randomtest.h"
2 |
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests/utilities.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdlib>
2 | #include <cmath>
3 | #include <fstream>
4 | #include <iostream>
5 | #include <vector>
6 | #ifndef WIN32
7 | #include <unistd.h>
8 | #endif
9 |
10 | #include "online_forests/utilities.h"
11 |
12 | using namespace std;
13 |
14 | unsigned int getDevRandom() {
15 |     ifstream devFile("/dev/urandom", ios::binary);
16 |     unsigned int outInt = 0;
17 |
18 |     // Read the random bytes directly into the integer; running them
19 |     // through atoi() would treat them as text and discard the entropy.
20 |     devFile.read(reinterpret_cast<char *>(&outInt), sizeof(outInt));
21 |
22 |     devFile.close();
23 |
24 |     return outInt;
25 | }
26 |
27 | void randPerm(const int &inNum, vector<int> &outVect) {
28 | outVect.resize(inNum);
29 | int randIndex, tempIndex;
30 | for (int nFeat = 0; nFeat < inNum; nFeat++) {
31 | outVect[nFeat] = nFeat;
32 | }
33 |     for (int nFeat = 0; nFeat < inNum; nFeat++) {
34 | randIndex = (int) floor(((double) inNum - nFeat) * randDouble()) + nFeat;
35 | if (randIndex == inNum) {
36 | randIndex--;
37 | }
38 | tempIndex = outVect[nFeat];
39 | outVect[nFeat] = outVect[randIndex];
40 | outVect[randIndex] = tempIndex;
41 | }
42 | }
43 |
44 | void randPerm(const int &inNum, const int inPart, vector<int> &outVect) {
45 | outVect.resize(inNum);
46 | int randIndex, tempIndex;
47 | for (int nFeat = 0; nFeat < inNum; nFeat++) {
48 | outVect[nFeat] = nFeat;
49 | }
50 |     for (int nFeat = 0; nFeat < inPart; nFeat++) {
51 | randIndex = (int) floor(((double) inNum - nFeat) * randDouble()) + nFeat;
52 | if (randIndex == inNum) {
53 | randIndex--;
54 | }
55 | tempIndex = outVect[nFeat];
56 | outVect[nFeat] = outVect[randIndex];
57 | outVect[randIndex] = tempIndex;
58 | }
59 |
60 | outVect.erase(outVect.begin() + inPart, outVect.end());
61 | }
62 |
--------------------------------------------------------------------------------
/online_forests_ros/src/online_forests_ros.cpp:
--------------------------------------------------------------------------------
1 | // (c) 2020 Zhi Yan, Rui Yang
2 | // This code is licensed under MIT license (see LICENSE.txt for details)
3 | #define GMM_USES_BLAS
4 |
5 | #include <fstream>   // std::ofstream for the timing log
6 | #include <unistd.h>  // access()
7 | #include <ros/ros.h>
8 | #include <std_msgs/String.h>
9 | #include "online_forests/onlinetree.h"
10 | #include "online_forests/onlinerf.h"
11 |
12 | int main(int argc, char **argv) {
13 | std::ofstream icra_log;
14 | std::string log_name = "orf_time_log_"+std::to_string(ros::WallTime::now().toSec());
15 |
16 | std::string conf_file_name;
17 | std::string model_file_name;
18 | int mode; // 1 - train, 2 - test, 3 - train and test.
19 | int minimum_samples;
20 | int total_samples = 0;
21 |
22 | ros::init(argc, argv, "online_forests_ros");
23 | ros::NodeHandle nh, private_nh("~");
24 |
25 | if(private_nh.getParam("conf_file_name", conf_file_name)) {
26 | ROS_INFO("Got param 'conf_file_name': %s", conf_file_name.c_str());
27 | } else {
28 | ROS_ERROR("Failed to get param 'conf_file_name'");
29 |     exit(EXIT_FAILURE); // a missing required parameter is an error
30 | }
31 |
32 | if(private_nh.getParam("model_file_name", model_file_name)) {
33 | ROS_INFO("Got param 'model_file_name': %s", model_file_name.c_str());
34 | } else {
35 | ROS_ERROR("Failed to get param 'model_file_name'");
36 |     exit(EXIT_FAILURE);
37 | }
38 |
39 | if(private_nh.getParam("mode", mode)) {
40 | ROS_INFO("Got param 'mode': %d", mode);
41 | } else {
42 | ROS_ERROR("Failed to get param 'mode'");
43 |     exit(EXIT_FAILURE);
44 | }
45 |
46 | private_nh.param("minimum_samples", minimum_samples, 1);
47 |
48 | Hyperparameters hp(conf_file_name);
49 | std_msgs::String::ConstPtr features;
50 |
51 | while (ros::ok()) {
52 |     features = ros::topic::waitForMessage<std_msgs::String>("/point_cloud_features/features"); // blocks until a message arrives
53 |
54 | // Creating the train data
55 | DataSet dataset_tr;
56 | dataset_tr.loadLIBSVM2(features->data);
57 |
58 | // Creating the test data
59 | DataSet dataset_ts;
60 | // dataset_ts.loadLIBSVM(hp.testData);
61 |
62 | if(atoi(features->data.substr(0, features->data.find(" ")).c_str()) >= minimum_samples) {
63 | OnlineRF model(hp, dataset_tr.m_numClasses, dataset_tr.m_numFeatures, dataset_tr.m_minFeatRange, dataset_tr.m_maxFeatRange); // TOTEST: OnlineTree
64 |
65 | //string model_file_name = "";
66 | icra_log.open(log_name, std::ofstream::out | std::ofstream::app);
67 |         double start_time = ros::WallTime::now().toSec(); // toSec() returns a double; time_t would truncate to whole seconds
68 |
69 | switch(mode) {
70 | case 1: // train only
71 | if(access( model_file_name.c_str(), F_OK ) != -1){
72 | model.loadForest(model_file_name);
73 | }
74 | model.train(dataset_tr);
75 |         //model.writeForest(model_file_name); // model writing is disabled
76 | break;
77 | case 2: // test only
78 | model.loadForest(model_file_name);
79 | model.test(dataset_ts);
80 | break;
81 | case 3: // train and test
82 | model.trainAndTest(dataset_tr, dataset_ts);
83 | break;
84 | default:
85 | ROS_ERROR("Unknown 'mode'");
86 | }
87 |
88 |       std::cout << "[online_forests_ros] Processing time: " << ros::WallTime::now().toSec() - start_time << " s" << std::endl;
89 | icra_log << (total_samples+=dataset_tr.m_numSamples) << " " << ros::WallTime::now().toSec()-start_time << "\n";
90 | icra_log.close();
91 | }
92 |
93 | ros::spinOnce();
94 | }
95 |
96 | return EXIT_SUCCESS;
97 | }
98 |
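The loop above blocks on /point_cloud_features/features and hands each std_msgs/String payload to DataSet::loadLIBSVM2(), so the string must carry the same header-plus-LIBSVM layout described for the file loader. A minimal stub publisher for exercising the node in isolation (a sketch; in the real pipeline the point_cloud_features node produces these messages):

    #include <ros/ros.h>
    #include <std_msgs/String.h>

    int main(int argc, char **argv) {
        ros::init(argc, argv, "features_stub");
        ros::NodeHandle nh;
        ros::Publisher pub =
            nh.advertise<std_msgs::String>("/point_cloud_features/features", 1, true);

        // Header "2 3 2 1": 2 samples, 3 features, 2 classes, indices start at 1.
        // Note the double space after each label: loadLIBSVM2() consumes one
        // space with getline() and expects another before the first pair.
        std_msgs::String msg;
        msg.data = "2 3 2 1\n0  1:0.5 3:1.0\n1  2:0.25\n";
        pub.publish(msg);

        ros::spin();
        return 0;
    }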
--------------------------------------------------------------------------------
/online_svm_ros/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(online_svm_ros)
3 |
4 | find_package(catkin REQUIRED COMPONENTS
5 | roscpp
6 | std_msgs
7 | )
8 |
9 | find_package(yaml-cpp REQUIRED)
10 |
11 | include_directories(
12 | ${catkin_INCLUDE_DIRS}
13 | )
14 |
15 | add_executable(online_svm_ros
16 | src/online_svm_ros.cpp
17 | )
18 |
19 | target_link_libraries(online_svm_ros
20 | ${catkin_LIBRARIES}
21 | yaml-cpp
22 | svm
23 | )
24 |
25 | if(catkin_EXPORTED_TARGETS)
26 | add_dependencies(online_svm_ros
27 | ${catkin_EXPORTED_TARGETS}
28 | )
29 | endif()
30 |
--------------------------------------------------------------------------------
/online_svm_ros/README.md:
--------------------------------------------------------------------------------
1 | # online_svm_ros
2 |
3 | Bare-bones package for online training of an SVM-based classifier, used in [online learning](https://github.com/yzrobot/online_learning).
4 |
--------------------------------------------------------------------------------
/online_svm_ros/config/svm.yaml:
--------------------------------------------------------------------------------
1 | # https://github.com/cjlin1/libsvm
2 | x_lower: -1.0
3 | x_upper: 1.0
4 | max_examples: 5000
5 | svm_type: 0 # default 0 (C_SVC)
6 | kernel_type: 2 # default 2 (RBF)
7 | degree: 3 # default 3
8 | gamma: 0.02 # default 1.0/(float)FEATURE_SIZE
9 | coef0: 0 # default 0
10 | cache_size: 256 # default 100
11 | eps: 0.001 # default 0.001
12 | C: 8 # default 1
13 | nr_weight: 0
14 | nu: 0.5
15 | p: 0.1
16 | shrinking: 0
17 | probability: 1
18 | save_data: true
19 | best_params: false
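These keys mirror the fields of libsvm's svm_parameter struct. A sketch of how a node might read them once the YAML has been loaded onto the parameter server (the function name is hypothetical; the actual loading code lives in src/online_svm_ros.cpp and may differ, and the libsvm header path varies by install):

    #include <ros/ros.h>
    #include <libsvm/svm.h> // Debian libsvm-dev; plain <svm.h> on other installs

    // Fill libsvm's svm_parameter from the keys above. Pass the node's
    // private NodeHandle ("~") if the YAML was loaded inside the <node> tag.
    svm_parameter loadSvmParameter(ros::NodeHandle &nh) {
        svm_parameter param;
        nh.param("svm_type", param.svm_type, 0);       // C_SVC
        nh.param("kernel_type", param.kernel_type, 2); // RBF
        nh.param("degree", param.degree, 3);
        nh.param("gamma", param.gamma, 0.02);
        nh.param("coef0", param.coef0, 0.0);
        nh.param("cache_size", param.cache_size, 256.0);
        nh.param("eps", param.eps, 0.001);
        nh.param("C", param.C, 8.0);
        nh.param("nr_weight", param.nr_weight, 0);
        nh.param("nu", param.nu, 0.5);
        nh.param("p", param.p, 0.1);
        nh.param("shrinking", param.shrinking, 0);
        nh.param("probability", param.probability, 1);
        param.weight_label = NULL; // no per-class weights (nr_weight is 0)
        param.weight = NULL;
        return param;
    }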
--------------------------------------------------------------------------------
/online_svm_ros/launch/online_svm_ros.launch:
--------------------------------------------------------------------------------
1 | <launch>
2 |   <!-- Example configuration: loads the SVM settings from config/svm.yaml. -->
3 |   <node pkg="online_svm_ros" type="online_svm_ros" name="online_svm_ros" output="screen">
4 |     <rosparam command="load" file="$(find online_svm_ros)/config/svm.yaml"/>
5 |   </node>
6 | </launch>
--------------------------------------------------------------------------------
/online_svm_ros/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="2">
3 |   <name>online_svm_ros</name>
4 |   <version>0.0.1</version>
5 |   <description>
6 |     Online SVM.
7 |   </description>
8 |   <maintainer>Rui Yang</maintainer>
9 |   <license>GPLv3</license>
10 |
11 |   <url>https://github.com/yzrobot/online_learning</url>
12 |   <author>Zhi Yan</author>
13 |
14 |   <buildtool_depend>catkin</buildtool_depend>
15 |
16 |   <depend>roscpp</depend>
17 |   <depend>yaml-cpp</depend>
18 |   <depend>libsvm-dev</depend>
19 |   <depend>std_msgs</depend>
20 |
21 | </package>
--------------------------------------------------------------------------------
/point_cloud_features/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(point_cloud_features)
3 |
4 | find_package(catkin REQUIRED COMPONENTS
5 | roscpp
6 | pcl_conversions
7 | pcl_ros
8 | std_msgs
9 | autoware_tracker
10 | )
11 |
12 | find_package(PCL REQUIRED)
13 |
14 | include_directories(include ${catkin_INCLUDE_DIRS})
15 |
16 | catkin_package(INCLUDE_DIRS include)
17 |
18 | add_executable(point_cloud_features
19 | src/point_cloud_feature_extractor.cpp
20 | src/${PROJECT_NAME}/point_cloud_features.cpp
21 | )
22 |
23 | target_link_libraries(point_cloud_features
24 | ${catkin_LIBRARIES}
25 | )
26 |
27 | if(catkin_EXPORTED_TARGETS)
28 | add_dependencies(point_cloud_features
29 | ${catkin_EXPORTED_TARGETS}
30 | )
31 | endif()
32 |
--------------------------------------------------------------------------------
/point_cloud_features/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Zhi Yan
4 | Copyright (c) 2021, Rui Yang
5 | All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | 1. Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | 2. Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | 3. Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived from
19 | this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/point_cloud_features/README.md:
--------------------------------------------------------------------------------
1 | # point_cloud_features
2 |
3 | Bare-bones package for point cloud feature extraction, used in [online learning](https://github.com/yzrobot/online_learning).
4 |
5 | [2020-11-14]: Adjusted the calculation of minimum_points.
6 |
--------------------------------------------------------------------------------
/point_cloud_features/include/point_cloud_features/point_cloud_features.h:
--------------------------------------------------------------------------------
1 | /**
2 | * BSD 3-Clause License
3 | *
4 | * Copyright (c) 2020, Zhi Yan
5 | * All rights reserved.
6 |
7 | * Redistribution and use in source and binary forms, with or without
8 | * modification, are permitted provided that the following conditions are met:
9 |
10 | * 1. Redistributions of source code must retain the above copyright notice, this
11 | * list of conditions and the following disclaimer.
12 |
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 |
17 | * 3. Neither the name of the copyright holder nor the names of its
18 | * contributors may be used to endorse or promote products derived from
19 | * this software without specific prior written permission.
20 |
21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | **/
32 |
33 | #ifndef POINT_CLOUD_FEATURES_H
34 | #define POINT_CLOUD_FEATURES_H
35 |
36 | // ROS
37 | #include
38 |
39 | // PCL
40 | #include
41 | #include
42 |
43 | int numberOfPoints(pcl::PointCloud<pcl::PointXYZI>::Ptr);
44 | float minDistance(pcl::PointCloud<pcl::PointXYZI>::Ptr);
45 | void covarianceMat3D(pcl::PointCloud<pcl::PointXYZI>::Ptr, std::vector<float> &);
46 | void normalizedMOIT(pcl::PointCloud<pcl::PointXYZI>::Ptr, std::vector<float> &);
47 | void sliceFeature(pcl::PointCloud<pcl::PointXYZI>::Ptr, int, std::vector<float> &);
48 | void intensityDistribution(pcl::PointCloud<pcl::PointXYZI>::Ptr, int, std::vector<float> &);
49 |
50 | #endif /* POINT_CLOUD_FEATURES_H */
51 |
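Together these functions yield one fixed-length descriptor per cluster, which downstream nodes flatten into the LIBSVM strings consumed by the learners. A sketch of assembling the full vector (the bin counts for f5/f6 are illustrative; the actual values are chosen in the extractor source):

    #include <vector>
    #include "point_cloud_features/point_cloud_features.h"

    // Concatenate the f1..f6 features for one cluster.
    std::vector<float> extractFeatures(pcl::PointCloud<pcl::PointXYZI>::Ptr cluster) {
        std::vector<float> f;
        f.push_back(numberOfPoints(cluster));  // f1 (1d)
        f.push_back(minDistance(cluster));     // f2 (1d)
        covarianceMat3D(cluster, f);           // f3 (6d)
        normalizedMOIT(cluster, f);            // f4 (6d)
        sliceFeature(cluster, 10, f);          // f5 (n*2d), n = 10 here
        intensityDistribution(cluster, 25, f); // f6, 25 bins here
        return f;
    }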
--------------------------------------------------------------------------------
/point_cloud_features/launch/point_cloud_feature_extractor.launch:
--------------------------------------------------------------------------------
1 | <launch>
2 |   <!-- Example configuration. -->
3 |   <node pkg="point_cloud_features" type="point_cloud_features" name="point_cloud_features" output="screen">
4 |   </node>
5 | </launch>
--------------------------------------------------------------------------------
/point_cloud_features/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="2">
3 |   <name>point_cloud_features</name>
4 |   <version>0.0.1</version>
5 |   <description>
6 |     A ROS package for extracting features from 3D point clouds.
7 |   </description>
8 |   <maintainer>Rui Yang</maintainer>
9 |   <license>BSD</license>
10 |
11 |   <url>https://github.com/yzrobot/online_learning</url>
12 |   <author>Zhi Yan</author>
13 |
14 |   <buildtool_depend>catkin</buildtool_depend>
15 |
16 |   <depend>roscpp</depend>
17 |   <depend>pcl_conversions</depend>
18 |   <depend>pcl_ros</depend>
19 |   <depend>std_msgs</depend>
20 |   <depend>autoware_tracker</depend>
21 |
22 | </package>
--------------------------------------------------------------------------------
/point_cloud_features/src/point_cloud_features/point_cloud_features.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * BSD 3-Clause License
3 | *
4 | * Copyright (c) 2020, Zhi Yan
5 | * All rights reserved.
6 |
7 | * Redistribution and use in source and binary forms, with or without
8 | * modification, are permitted provided that the following conditions are met:
9 |
10 | * 1. Redistributions of source code must retain the above copyright notice, this
11 | * list of conditions and the following disclaimer.
12 |
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 |
17 | * 3. Neither the name of the copyright holder nor the names of its
18 | * contributors may be used to endorse or promote products derived from
19 | * this software without specific prior written permission.
20 |
21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | **/
32 | #include <cfloat>
33 | #include <pcl/common/centroid.h>
34 | #include "point_cloud_features/point_cloud_features.h"
35 | /* f1 (1d): Number of points included in a cluster */
36 | int numberOfPoints(pcl::PointCloud<pcl::PointXYZI>::Ptr pc) {
37 | return pc->size();
38 | }
39 |
40 | /* f2 (1d): The minimum distance of the cluster to the sensor */
41 | /* f1 and f2 are best used together, since f1 varies as f2 changes */
42 | float minDistance(pcl::PointCloud<pcl::PointXYZI>::Ptr pc) {
43 | float m = FLT_MAX;
44 |
45 | for(int i = 0; i < pc->size(); i++) {
46 | m = std::min(m, pc->points[i].x*pc->points[i].x + pc->points[i].y*pc->points[i].y + pc->points[i].z*pc->points[i].z);
47 | }
48 |
49 | return sqrt(m);
50 | }
51 |
52 | /* f3 (6d): 3D covariance matrix of the cluster */
53 | void covarianceMat3D(pcl::PointCloud<pcl::PointXYZI>::Ptr pc, std::vector<float> &res) {
54 |     Eigen::Matrix3f covariance_3d;
55 |     pcl::PCA<pcl::PointXYZI> pca;
56 |     pcl::PointCloud<pcl::PointXYZI>::Ptr pc_projected(new pcl::PointCloud<pcl::PointXYZI>);
57 | Eigen::Vector4f centroid;
58 |
59 | pca.setInputCloud(pc);
60 | pca.project(*pc, *pc_projected);
61 | pcl::compute3DCentroid(*pc, centroid);
62 | pcl::computeCovarianceMatrixNormalized(*pc_projected, centroid, covariance_3d);
63 |
64 | // Only 6 elements are needed as covariance_3d is symmetric.
65 | res.push_back(covariance_3d(0,0));
66 | res.push_back(covariance_3d(0,1));
67 | res.push_back(covariance_3d(0,2));
68 | res.push_back(covariance_3d(1,1));
69 | res.push_back(covariance_3d(1,2));
70 | res.push_back(covariance_3d(2,2));
71 | }
72 |
73 | /* f4 (6d): The normalized moment of inertia tensor */
74 | void normalizedMOIT(pcl::PointCloud<pcl::PointXYZI>::Ptr pc, std::vector<float> &res) {
75 |     Eigen::Matrix3f moment_3d;
76 |     pcl::PCA<pcl::PointXYZI> pca;
77 |     pcl::PointCloud<pcl::PointXYZI>::Ptr pc_projected(new pcl::PointCloud<pcl::PointXYZI>);
78 |
79 | moment_3d.setZero();
80 | pca.setInputCloud(pc);
81 | pca.project(*pc, *pc_projected);
82 | for(int i = 0; i < (*pc_projected).size(); i++) {
83 | moment_3d(0,0) += (*pc_projected)[i].y*(*pc_projected)[i].y + (*pc_projected)[i].z*(*pc_projected)[i].z;
84 | moment_3d(0,1) -= (*pc_projected)[i].x*(*pc_projected)[i].y;
85 | moment_3d(0,2) -= (*pc_projected)[i].x*(*pc_projected)[i].z;
86 | moment_3d(1,1) += (*pc_projected)[i].x*(*pc_projected)[i].x + (*pc_projected)[i].z*(*pc_projected)[i].z;
87 | moment_3d(1,2) -= (*pc_projected)[i].y*(*pc_projected)[i].z;
88 | moment_3d(2,2) += (*pc_projected)[i].x*(*pc_projected)[i].x + (*pc_projected)[i].y*(*pc_projected)[i].y;
89 | }
90 |
91 | // Only 6 elements are needed as moment_3d is symmetric.
92 | res.push_back(moment_3d(0,0));
93 | res.push_back(moment_3d(0,1));
94 | res.push_back(moment_3d(0,2));
95 | res.push_back(moment_3d(1,1));
96 | res.push_back(moment_3d(1,2));
97 | res.push_back(moment_3d(2,2));
98 | }
99 |
100 | /* f5 (n*2d): Slice feature for the cluster */
101 | void sliceFeature(pcl::PointCloud<pcl::PointXYZI>::Ptr pc, int n, std::vector