├── .gitignore
├── images
│   ├── camvid12.png
│   └── greenhouse4.png
├── script
│   ├── pytorch_enet_ros_node.py
│   ├── visualizer.py
│   └── model.py
├── src
│   ├── pytorch_seg_trav_path_node.cpp
│   ├── pytorch_seg_node.cpp
│   ├── pytorch_seg_trav_node.cpp
│   └── impl
│       ├── pytorch_cpp_wrapper_seg.cpp
│       ├── pytorch_cpp_wrapper_seg_trav.cpp
│       ├── pytorch_cpp_wrapper_seg_trav_path.cpp
│       ├── pytorch_cpp_wrapper_base.cpp
│       ├── pytorch_seg_ros.cpp
│       ├── pytorch_seg_trav_ros.cpp
│       └── pytorch_seg_trav_path_ros.cpp
├── include
│   ├── pytorch_cpp_wrapper
│   │   ├── pytorch_cpp_wrapper_seg.h
│   │   ├── pytorch_cpp_wrapper_seg_trav.h
│   │   ├── pytorch_cpp_wrapper_seg_trav_path.h
│   │   └── pytorch_cpp_wrapper_base.h
│   └── pytorch_ros
│       ├── pytorch_seg_ros.h
│       ├── pytorch_seg_trav_ros.h
│       └── pytorch_seg_trav_path_ros.h
├── launch
│   └── pytorch_enet_ros.launch
├── CMakeLists.txt
├── package.xml
└── README.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
models/

--------------------------------------------------------------------------------
/images/camvid12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ActiveIntelligentSystemsLab/pytorch_enet_ros/HEAD/images/camvid12.png

--------------------------------------------------------------------------------
/images/greenhouse4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ActiveIntelligentSystemsLab/pytorch_enet_ros/HEAD/images/greenhouse4.png

--------------------------------------------------------------------------------
/script/pytorch_enet_ros_node.py:
--------------------------------------------------------------------------------
import torch
from model import ENet  # NOTE: the relative import 'from .model import ENet' fails when this file is run as a script


class pytorch_enet_ros:
    def __init__(self):
        pass


def main():
    # Stub entry point; the Python node is not implemented yet
    pass


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/src/pytorch_seg_trav_path_node.cpp:
--------------------------------------------------------------------------------
/*
 * A node to use a PyTorch model via the PyTorch C++ API
 * Shigemichi Matsuzaki
 *
 */

#include <ros/ros.h>
#include <pytorch_ros/pytorch_seg_trav_path_ros.h>

int main(int argc, char* argv[]) {
  // Initialize the node
  ros::init(argc, argv, "pytorch_seg_trav_path");

  ros::NodeHandle nh("~");

  // Initialize the class
  PyTorchSegTravPathROS pytorch_ros(nh);

  ROS_INFO("[PyTorchSegTravPathROS] The node has been initialized");

  ros::spin();
}

--------------------------------------------------------------------------------
/src/pytorch_seg_node.cpp:
--------------------------------------------------------------------------------
/*
 * A node to use a PyTorch model via the PyTorch C++ API
 * Shigemichi Matsuzaki
 *
 */

#include <ros/ros.h>
#include <pytorch_ros/pytorch_seg_ros.h>

int main(int argc, char* argv[]) {
  // Initialize the node
  ros::init(argc, argv, "pytorch_seg_ros");

  ros::NodeHandle nh("~");

  // Initialize the class
  PyTorchSegROS pytorch_seg_ros(nh);

  ROS_INFO("[PyTorchSegROS] The node has been initialized");

  ros::spin();
}

--------------------------------------------------------------------------------
/src/pytorch_seg_trav_node.cpp:
--------------------------------------------------------------------------------
/*
 * A node to use a PyTorch model via the PyTorch C++ API
 * Shigemichi Matsuzaki
 *
 */

#include <ros/ros.h>
#include <pytorch_ros/pytorch_seg_trav_ros.h>

int main(int argc, char* argv[]) {
  // Initialize the node
  ros::init(argc, argv, "pytorch_seg_trav_ros");

  ros::NodeHandle nh("~");

  // Initialize the class
  PyTorchSegTravROS pytorch_seg_trav_ros(nh);

  ROS_INFO("[PyTorchSegTravROS] The node has been initialized");

  ros::spin();
}

--------------------------------------------------------------------------------
/include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg.h:
--------------------------------------------------------------------------------
#ifndef PYTORCH_CPP_WRAPPER_SEG
#define PYTORCH_CPP_WRAPPER_SEG

#include <torch/script.h> // One-stop header.
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include <string>
#include "opencv2/highgui/highgui.hpp"
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"

class PyTorchCppWrapperSeg : public PyTorchCppWrapperBase {
private:
  // c = P(s=1|y=1) in PU learning, calculated during training
  float c_{0.3};
  bool use_aux_branch_{false};

public:
  PyTorchCppWrapperSeg(const std::string & filename, const int class_num);
  PyTorchCppWrapperSeg(const char* filename, const int class_num);

  /**
   * @brief Get outputs from the model
   * @param[in] input_tensor Input tensor
   * @return Output tensor (segmentation)
   */
  at::Tensor get_output(at::Tensor input_tensor);
};
#endif

--------------------------------------------------------------------------------
/include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav.h:
--------------------------------------------------------------------------------
#ifndef PYTORCH_CPP_WRAPPER_SEG_TRAV
#define PYTORCH_CPP_WRAPPER_SEG_TRAV

#include <torch/script.h> // One-stop header.
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include <tuple>
#include "opencv2/highgui/highgui.hpp"
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"

class PyTorchCppWrapperSegTrav : public PyTorchCppWrapperBase {
private:
  // c = P(s=1|y=1) in PU learning, calculated during training
  float c_{0.3};

public:
  PyTorchCppWrapperSegTrav(const std::string & filename, const int class_num);
  PyTorchCppWrapperSegTrav(const char* filename, const int class_num);

  /**
   * @brief Get outputs from the model
   * @param[in] input_tensor Input tensor
   * @return A tuple of output tensors (segmentation and traversability)
   */
  std::tuple<at::Tensor, at::Tensor> get_output(at::Tensor input_tensor);
};
#endif

--------------------------------------------------------------------------------
/include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav_path.h:
--------------------------------------------------------------------------------
#ifndef PYTORCH_CPP_WRAPPER_SEG_TRAV_PATH
#define PYTORCH_CPP_WRAPPER_SEG_TRAV_PATH

#include <torch/script.h> // One-stop header.
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include <tuple>
#include "opencv2/highgui/highgui.hpp"
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"

class PyTorchCppWrapperSegTravPath : public PyTorchCppWrapperBase {
private:
  // c = P(s=1|y=1) in PU learning, calculated during training
  float c_{0.3};

public:
  PyTorchCppWrapperSegTravPath(const std::string & filename, const int class_num);
  PyTorchCppWrapperSegTravPath(const char* filename, const int class_num);

  /**
   * @brief Get outputs from the model
   * @param[in] input_tensor Input tensor
   * @return A tuple of output tensors (segmentation, traversability, and path (points))
   */
  std::tuple<at::Tensor, at::Tensor, at::Tensor> get_output(at::Tensor input_tensor);
};
#endif

--------------------------------------------------------------------------------
/src/impl/pytorch_cpp_wrapper_seg.cpp:
--------------------------------------------------------------------------------
/*
 * A wrapper class of the PyTorch C++ API to use a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg.h"
#include <torch/script.h> // One-stop header.
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include "opencv2/highgui/highgui.hpp"

PyTorchCppWrapperSeg::PyTorchCppWrapperSeg(const std::string & filename, const int class_num)
  : PyTorchCppWrapperBase(filename, class_num)
{ }

PyTorchCppWrapperSeg::PyTorchCppWrapperSeg(const char* filename, const int class_num)
  : PyTorchCppWrapperBase(filename, class_num)
{ }

/**
 * @brief Get outputs from the model
 * @param[in] input_tensor Input tensor
 * @return Output tensor (segmentation)
 */
at::Tensor
PyTorchCppWrapperSeg::get_output(at::Tensor input_tensor)
{
  // Execute the model and turn its output into a tensor.
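  // module_.forward() returns a torch::jit::IValue: a model with an
  // auxiliary decoder returns a tuple of two logit maps (unpacked with
  // toTuple()), while a single-branch model returns a plain tensor
  // (unpacked with toTensor()).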
  auto outputs_tmp = module_.forward({input_tensor});

  at::Tensor segmentation;
  // If the network has two branches
  if(use_aux_branch_) {
    auto outputs = outputs_tmp.toTuple();

    at::Tensor output1 = outputs->elements()[0].toTensor();
    at::Tensor output2 = outputs->elements()[1].toTensor();

    segmentation = output1 + 0.5 * output2;
  } else {
    // If there's only one segmentation branch, directly use the output
    segmentation = outputs_tmp.toTensor();
  }

  return segmentation;
}

--------------------------------------------------------------------------------
/include/pytorch_ros/pytorch_seg_ros.h:
--------------------------------------------------------------------------------
/*
 * A ROS node to do inference using a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#ifndef PYTORCH_SEG_ROS
#define PYTORCH_SEG_ROS

#include <ros/ros.h>

#include <image_transport/image_transport.h>
#include <cv_bridge/cv_bridge.h>
#include <sensor_msgs/image_encodings.h>
#include <opencv2/imgproc/imgproc.hpp>

#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg.h"

#include <semantic_segmentation_srvs/GetLabelAndProbability.h>
#include <memory>
#include <tuple>

class PyTorchSegROS {
private:
  ros::NodeHandle nh_;

  ros::ServiceServer get_label_image_server_;

  image_transport::ImageTransport it_;

  image_transport::Subscriber sub_image_;
  image_transport::Publisher pub_label_image_;
  image_transport::Publisher pub_color_image_;
  image_transport::Publisher pub_uncertainty_image_;

  std::shared_ptr<PyTorchCppWrapperSeg> pt_wrapper_ptr_;

  cv::Mat colormap_;

public:
  PyTorchSegROS(ros::NodeHandle & nh);

  void image_callback(const sensor_msgs::ImageConstPtr& msg);
  std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr> inference(cv::Mat & input_image);
  bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
                                    semantic_segmentation_srvs::GetLabelAndProbability::Response & res);
  cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg);
  cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::Image msg);
  void label_to_color(cv::Mat& label, cv::Mat& color_label);
};

#endif

--------------------------------------------------------------------------------
/src/impl/pytorch_cpp_wrapper_seg_trav.cpp:
--------------------------------------------------------------------------------
/*
 * A wrapper class of the PyTorch C++ API to use a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav.h"
#include <torch/script.h> // One-stop header.
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include "opencv2/highgui/highgui.hpp"

PyTorchCppWrapperSegTrav::PyTorchCppWrapperSegTrav(const std::string & filename, const int class_num)
  : PyTorchCppWrapperBase(filename, class_num)
{ }

PyTorchCppWrapperSegTrav::PyTorchCppWrapperSegTrav(const char* filename, const int class_num)
  : PyTorchCppWrapperBase(filename, class_num)
{ }

/**
 * @brief Get outputs from the model
 * @param[in] input_tensor Input tensor
 * @return A tuple of output tensors (segmentation and traversability)
 */
std::tuple<at::Tensor, at::Tensor>
PyTorchCppWrapperSegTrav::get_output(at::Tensor input_tensor)
{
  // Execute the model and turn its output into a tensor.
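  // The traversability head is trained with PU (positive-unlabeled)
  // learning, so the raw sigmoid output estimates P(s=1|x), the
  // probability that a pixel is a *labeled* positive. Dividing by
  // c = P(s=1|y=1) recovers an estimate of P(y=1|x) (Elkan & Noto, 2008).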
  auto outputs_tmp = module_.forward({input_tensor});

  auto outputs = outputs_tmp.toTuple();

  at::Tensor output1 = outputs->elements()[0].toTensor();
  at::Tensor output2 = outputs->elements()[1].toTensor();
  at::Tensor prob = outputs->elements()[2].toTensor();

  // Divide probability by c
  prob = torch::sigmoid(prob) / c_;
  // Limit the values in range [0, 1]
  prob = at::clamp(prob, 0.0, 1.0);

  at::Tensor segmentation = output1 + 0.5 * output2;

  return std::forward_as_tuple(segmentation, prob);
}

--------------------------------------------------------------------------------
/launch/pytorch_enet_ros.launch:
--------------------------------------------------------------------------------
<!-- The XML body of this launch file was lost in extraction; only blank element lines remained. -->

--------------------------------------------------------------------------------
/include/pytorch_ros/pytorch_seg_trav_ros.h:
--------------------------------------------------------------------------------
/*
 * A ROS node to do inference using a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#ifndef PYTORCH_SEG_TRAV_ROS
#define PYTORCH_SEG_TRAV_ROS

#include <ros/ros.h>

#include <image_transport/image_transport.h>
#include <cv_bridge/cv_bridge.h>
#include <sensor_msgs/image_encodings.h>
#include <opencv2/imgproc/imgproc.hpp>

#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav.h"

#include <semantic_segmentation_srvs/GetLabelAndProbability.h>
#include <memory>
#include <tuple>

class PyTorchSegTravROS {
private:
  ros::NodeHandle nh_;

  ros::ServiceServer get_label_image_server_;

  image_transport::ImageTransport it_;

  image_transport::Subscriber sub_image_;
  image_transport::Publisher pub_label_image_;
  image_transport::Publisher pub_color_image_;
  image_transport::Publisher pub_prob_image_;
  image_transport::Publisher pub_uncertainty_image_;

  std::shared_ptr<PyTorchCppWrapperSegTrav> pt_wrapper_ptr_;

  cv::Mat colormap_;

public:
  PyTorchSegTravROS(ros::NodeHandle & nh);

  void image_callback(const sensor_msgs::ImageConstPtr& msg);
  std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr> inference(cv::Mat & input_image);
  bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
                                    semantic_segmentation_srvs::GetLabelAndProbability::Response & res);
  cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg);
  cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::Image msg);
  void label_to_color(cv::Mat& label, cv::Mat& color_label);
};

#endif

--------------------------------------------------------------------------------
/src/impl/pytorch_cpp_wrapper_seg_trav_path.cpp:
--------------------------------------------------------------------------------
/*
 * A wrapper class of the PyTorch C++ API to use a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav_path.h"
#include <torch/script.h> // One-stop header.
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include "opencv2/highgui/highgui.hpp"

PyTorchCppWrapperSegTravPath::PyTorchCppWrapperSegTravPath(const std::string & filename, const int class_num)
  : PyTorchCppWrapperBase(filename, class_num)
{ }

PyTorchCppWrapperSegTravPath::PyTorchCppWrapperSegTravPath(const char* filename, const int class_num)
  : PyTorchCppWrapperBase(filename, class_num)
{ }

/**
 * @brief Get outputs from the model
 * @param[in] input_tensor Input tensor
 * @return A tuple of output tensors (segmentation, traversability, and path (points))
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor>
PyTorchCppWrapperSegTravPath::get_output(at::Tensor input_tensor)
{
  // Execute the model and turn its output into a tensor.
  auto outputs_tmp = module_.forward({input_tensor});

  auto outputs = outputs_tmp.toTuple();

  at::Tensor output1 = outputs->elements()[0].toTensor();
  at::Tensor output2 = outputs->elements()[1].toTensor();
  at::Tensor prob = outputs->elements()[2].toTensor();
  at::Tensor path = outputs->elements()[3].toTensor();

  // Divide probability by c
  prob = torch::sigmoid(prob) / c_;
  // Limit the values in range [0, 1]
  prob = at::clamp(prob, 0.0, 1.0);

  at::Tensor segmentation = output1 + 0.5 * output2;

  // Map the path point coordinates into [0, 1]
  path = torch::sigmoid(path);

  return std::forward_as_tuple(segmentation, prob, path);
}

--------------------------------------------------------------------------------
/include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h:
--------------------------------------------------------------------------------
#ifndef PYTORCH_CPP_WRAPPER_BASE
#define PYTORCH_CPP_WRAPPER_BASE

#include <torch/script.h> // One-stop header.
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include <string>

#include <opencv2/core/core.hpp>

/**
 * @brief This class is a base class of the C++ wrapper of PyTorch
 */
class PyTorchCppWrapperBase {
protected:
  torch::jit::script::Module module_;
  int class_num_;
  float max_entropy_;

public:
  PyTorchCppWrapperBase();
  PyTorchCppWrapperBase(const std::string & filename, const int class_num);
  PyTorchCppWrapperBase(const char* filename, const int class_num);

  /**
   * @brief Import a network
   * @param filename
   * @return true if import succeeded
   */
  bool import_module(const std::string & filename);

  /**
   * @brief Convert an image (cv::Mat) to a tensor (at::Tensor)
   * @param[in]  img
   * @param[out] tensor
   * @param[in]  use_gpu Whether to use the GPU
   */
  void img2tensor(cv::Mat & img, at::Tensor & tensor, const bool & use_gpu = true);

  /**
   * @brief Convert a tensor (at::Tensor) to an image (cv::Mat)
   * @param[in]  tensor
   * @param[out] img
   */
  void tensor2img(at::Tensor tensor, cv::Mat & img);

  /**
   * @brief Convert a tensor (at::Tensor) to an image (cv::Mat)
   * @param[in] tensor
   * @return Converted CV image
   */
  cv::Mat tensor2img(at::Tensor tensor);

  /**
   * @brief Take element-wise argmax
   * @param[in] input_tensor
   * @return Tensor that holds the index of the max value in each element
   */
  at::Tensor get_argmax(at::Tensor input_tensor);

  /**
   * @brief Take element-wise entropy
   * @param[in] input_tensor
   * @param[in] normalize Whether to normalize the entropy to [0, 1]
   * @return Tensor that holds the entropy of each element
   */
  at::Tensor get_entropy(at::Tensor input_tensor, const bool normalize);
};
#endif

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 2.8.3)
project(pytorch_ros)

add_compile_options(-std=c++14)

# Locate the cmake file of torchlib
# set(Torch_DIR "/usr/local/lib/python3.8/dist-packages/torch/share/cmake/Torch/")
set(Torch_DIR "$ENV{TORCH_PATH}/torch/share/cmake/Torch/")

## Find catkin macros and libraries
## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
## is used, also find other catkin packages
find_package(catkin REQUIRED COMPONENTS
  roscpp
  rospy
  std_msgs
  image_transport
  cv_bridge
  semantic_segmentation_srvs
  tf2
  tf2_ros
)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
find_package(OpenCV REQUIRED)

catkin_package(
  INCLUDE_DIRS include
  CATKIN_DEPENDS roscpp rospy std_msgs
)

###########
## Build ##
###########

## Specify additional locations of header files
## Your package locations should be listed before other locations
include_directories(
  include
  ${catkin_INCLUDE_DIRS}
  ${OpenCV_INCLUDE_DIRS}
)

## Declare a C++ library
add_library(${PROJECT_NAME}
  src/impl/pytorch_seg_trav_path_ros.cpp
  src/impl/pytorch_cpp_wrapper_seg.cpp
  src/impl/pytorch_cpp_wrapper_seg_trav.cpp
  src/impl/pytorch_cpp_wrapper_seg_trav_path.cpp
  src/impl/pytorch_cpp_wrapper_base.cpp
)

target_link_libraries(${PROJECT_NAME}
  ${catkin_LIBRARIES}
  ${TORCH_LIBRARIES}
  ${OpenCV_LIBS}
  opencv_core opencv_highgui opencv_imgcodecs
)
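# ${TORCH_LIBRARIES} and ${TORCH_CXX_FLAGS} are provided by find_package(Torch);
# LibTorch requires at least C++14, hence the CXX_STANDARD property below.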
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 14)

## Declare a C++ executable
## With catkin_make all packages are built within a single CMake context
## The recommended prefix ensures that target names across packages don't collide
add_executable(pytorch_seg_trav_path_node src/pytorch_seg_trav_path_node.cpp src/impl/pytorch_seg_trav_path_ros.cpp)
add_executable(pytorch_seg_trav_node src/pytorch_seg_trav_node.cpp src/impl/pytorch_seg_trav_ros.cpp)
add_executable(pytorch_seg_node src/pytorch_seg_node.cpp src/impl/pytorch_seg_ros.cpp)

## Specify libraries to link a library or executable target against
target_link_libraries(pytorch_seg_trav_path_node
  ${catkin_LIBRARIES}
  ${PROJECT_NAME}
  ${TORCH_LIBRARIES}
)

target_link_libraries(pytorch_seg_trav_node
  ${catkin_LIBRARIES}
  ${PROJECT_NAME}
  ${TORCH_LIBRARIES}
)

target_link_libraries(pytorch_seg_node
  ${catkin_LIBRARIES}
  ${PROJECT_NAME}
  ${TORCH_LIBRARIES}
)

--------------------------------------------------------------------------------
/script/visualizer.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-

# ============================================
__author__ = "ShigemichiMatsuzaki"
__maintainer__ = "ShigemichiMatsuzaki"
# ============================================

import rospy
import cv2
import message_filters
from sensor_msgs.msg import Image
from geometry_msgs.msg import PointStamped
from cv_bridge import CvBridge
import copy


class Visualizer:
    """Node that draws the estimated path line on the input image"""

    def __init__(self):
        """Constructor of the Visualizer class

        - define the class variables (ROS publisher and subscribers etc.)
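        - subscribe the image and the two path-line points via
          message_filters so that all three messages arrive at one
          synchronized callback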
        """
        # Subscribers
        self.image_sub = message_filters.Subscriber('image', Image)
        self.start_point_sub = message_filters.Subscriber(
            'start_point', PointStamped)
        self.end_point_sub = message_filters.Subscriber(
            'end_point', PointStamped)
        # Time synchronizer
        self.ts = message_filters.TimeSynchronizer(
            [self.image_sub, self.start_point_sub, self.end_point_sub], 100)
        self.ts.registerCallback(self.image_points_callback)

        # Publisher
        self.image_pub = rospy.Publisher(
            'image_and_path', Image, queue_size=100)

        self.bridge = CvBridge()

    def image_points_callback(self, img_msg, start_point_msg, end_point_msg):
        """Callback for synchronized image and point messages

        Args:
            img_msg(sensor_msgs/Image)
            start_point_msg(geometry_msgs/PointStamped)
            end_point_msg(geometry_msgs/PointStamped)
        """
        # Convert the image message to an OpenCV image
        cv_image = self.bridge.imgmsg_to_cv2(
            img_msg, desired_encoding='passthrough')

        cv_image_with_line = self.draw_line(
            cv_image,
            (int(start_point_msg.point.x), int(start_point_msg.point.y)),
            (int(end_point_msg.point.x), int(end_point_msg.point.y)))

        vis_img_msg = self.bridge.cv2_to_imgmsg(cv_image_with_line)

        self.image_pub.publish(vis_img_msg)

    def draw_line(self, cv_image, start_point, end_point):
        """Draw a line with the given start and end points on the image

        Args:
            cv_image(OpenCV image)
            start_point(tuple): (x, y) of the start point
            end_point(tuple): (x, y) of the end point

        Return:
            OpenCV image with a line drawn
        """
        ret_image = copy.deepcopy(cv_image)

        cv2.line(ret_image, start_point, end_point,
                 color=(0, 0, 255), thickness=10)

        return ret_image


def main():
    """Main function to initialize the ROS node"""
    rospy.init_node("visualizer")

    visualizer = Visualizer()

    rospy.loginfo('visualizer is initialized')

    rospy.spin()


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/package.xml:
--------------------------------------------------------------------------------
<?xml version="1.0"?>
<package format="2">
  <name>pytorch_ros</name>
  <version>0.0.0</version>
  <description>The pytorch_ros package</description>

  <maintainer email="root@todo.todo">root</maintainer>

  <license>TODO</license>

  <buildtool_depend>catkin</buildtool_depend>
  <build_depend>roscpp</build_depend>
  <build_depend>rospy</build_depend>
  <build_depend>std_msgs</build_depend>
  <build_depend>semantic_segmentation_srvs</build_depend>
  <build_depend>tf2</build_depend>
  <build_depend>tf2_ros</build_depend>
  <build_depend>tf2_sensor_msgs</build_depend>
  <build_export_depend>roscpp</build_export_depend>
  <build_export_depend>rospy</build_export_depend>
  <build_export_depend>std_msgs</build_export_depend>
  <exec_depend>roscpp</exec_depend>
  <exec_depend>rospy</exec_depend>
  <exec_depend>std_msgs</exec_depend>
  <exec_depend>semantic_segmentation_srvs</exec_depend>
  <exec_depend>tf2</exec_depend>
  <exec_depend>tf2_ros</exec_depend>
  <exec_depend>tf2_sensor_msgs</exec_depend>

  <export>
  </export>
</package>

--------------------------------------------------------------------------------
/include/pytorch_ros/pytorch_seg_trav_path_ros.h:
--------------------------------------------------------------------------------
/*
 * A ROS node to do inference using a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#ifndef PYTORCH_SEG_TRAV_PATH
#define PYTORCH_SEG_TRAV_PATH

#include <ros/ros.h>

#include <image_transport/image_transport.h>
#include <cv_bridge/cv_bridge.h>
#include <sensor_msgs/image_encodings.h>
#include <geometry_msgs/PointStamped.h>
#include <opencv2/imgproc/imgproc.hpp>
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav_path.h"
#include <semantic_segmentation_srvs/GetLabelAndProbability.h>
#include <memory>
#include <tuple>

class PyTorchSegTravPathROS {
private:
  ros::NodeHandle nh_;

  ros::ServiceServer get_label_image_server_;

  image_transport::ImageTransport it_;

  // Message subscribers and publishers
  image_transport::Subscriber sub_image_;
  image_transport::Publisher pub_label_image_;
  image_transport::Publisher pub_color_image_;
  image_transport::Publisher pub_prob_image_;
  image_transport::Publisher pub_uncertainty_image_;
  ros::Publisher pub_start_point_;
  ros::Publisher pub_end_point_;
  ros::Time stamp_of_current_image_;

  std::shared_ptr<PyTorchCppWrapperSegTravPath> pt_wrapper_ptr_;

  // Used to convert a label image to a color image
  cv::Mat colormap_;

public:
  PyTorchSegTravPathROS(ros::NodeHandle & nh);

  /**
   * @brief Image callback
   * @param[in] msg Message
   */
  void image_callback(const sensor_msgs::ImageConstPtr& msg);

  /**
   * @brief Main function for inference using the model
   * @param[in] input_image OpenCV image
   * @return A tuple of messages of the inference results
   */
  std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr,
             geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
  inference(cv::Mat & input_image);

  /**
   * @brief Service callback
   * @param[in] req Request
   * @param[in] res Response
   * @return True if the service succeeded
   */
  bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
                                    semantic_segmentation_srvs::GetLabelAndProbability::Response & res);

  /**
   * @brief Convert an Image message to cv_bridge
   * @param[in] msg Pointer to an image message
   * @return cv_bridge image pointer
   */
  cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg);

  /**
   * @brief Convert an Image message to cv_bridge
   * @param[in] msg Image message
   * @return cv_bridge image pointer
   */
  cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::Image msg);

  /**
   * @brief Convert a label image to a color label image for visualization
   * @param[in]  label Label image
   * @param[out] color_label Color image mapped from the label image
   */
  void label_to_color(cv::Mat& label, cv::Mat& color_label);

  /**
   * @brief Convert a tensor with a size of (1, 4) to start and end points (x, y)
   * @param[in] point_tensor (1, 4) tensor
   * @param[in] width Original width of the image
   * @param[in] height Original height of the image
   * @return A tuple of start and end points as geometry_msgs::PointStampedPtr
   */
  std::tuple<geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr> tensor_to_points(const at::Tensor point_tensor, const int & width, const int & height);

  /**
   * @brief Normalize a tensor to feed to the model
   * @param[in,out] input_tensor Tensor
   */
  void normalize_tensor(at::Tensor & input_tensor);
};

#endif

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch_scene_recognition_ros

## 1. Overview

A ROS package to use [LibTorch](https://pytorch.org/cppdocs/), the PyTorch C++ API, for inference with our scene recognition models.

A Docker environment for running this package is available [here](https://github.com/ActiveIntelligentSystemsLab/pytorch-enet-docker).
This package is **only tested in that virtual environment**.

## 2. Requirements
- PyTorch with LibTorch built
  - Tested with 1.5.0
- [semantic_segmentation_srvs/GetLabelAndProbability](https://github.com/ActiveIntelligentSystemsLab/aisl_utils/blob/master/aisl_srvs/semantic_segmentation_srv/srv/GetLabelAndProbability.srv)

## 3. Nodes

### 3.1 `pytorch_seg_trav_path_node`

A node that runs a multi-task model for semantic segmentation, traversability estimation, and path estimation.

#### **3.1.1 Subscribed topics**

- `image` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))

  An input image

#### **3.1.2 Published topics**

- `label` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))

  Image that stores the label index of each pixel

- `color_label` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))

  Image that stores the color label of each pixel (for visualization)

- `prob` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))

  Image that stores the *traversability* of each pixel

- `start_point` ([geometry_msgs/PointStamped](http://docs.ros.org/en/melodic/api/geometry_msgs/html/msg/PointStamped.html))

  Start point of the estimated path line

- `end_point` ([geometry_msgs/PointStamped](http://docs.ros.org/en/melodic/api/geometry_msgs/html/msg/PointStamped.html))

  End point of the estimated path line

#### **3.1.3 Service**

- `get_label_image` ([semantic_segmentation_srvs/GetLabelAndProbability](https://github.com/ActiveIntelligentSystemsLab/aisl_utils/blob/master/aisl_srvs/semantic_segmentation_srv/srv/GetLabelAndProbability.srv))

  Returns the inference results (segmentation and traversability) for a given image.

### **3.2 visualizer.py**

#### **3.2.1 Subscribed topics**

- `image` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))

  An input image

- `start_point` ([geometry_msgs/PointStamped](http://docs.ros.org/en/melodic/api/geometry_msgs/html/msg/PointStamped.html))

  Start point of the estimated path line from the inference node

- `end_point` ([geometry_msgs/PointStamped](http://docs.ros.org/en/melodic/api/geometry_msgs/html/msg/PointStamped.html))

  End point of the estimated path line from the inference node

#### **3.2.2 Published topics**

- `image_and_path` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))

  An image with the path overlaid

## 4. How to run the node

```
roslaunch pytorch_enet_ros.launch image:=<image topic name> model_name:=<path to the model file>
```

## 5. Weight files

The ROS nodes in this package use models saved as serialized Torch Script files.

At this moment, we don't provide a script to generate the weight files.

Refer to [this page](https://pytorch.org/tutorials/advanced/cpp_export.html) to get the weight file.

### CAUTION

If the version of PyTorch used to run this ROS package differs from the version used to generate your weight file (serialized Torch Script), the ROS node may fail to import the weights.

For example, if you use [our Docker environment](https://github.com/ActiveIntelligentSystemsLab/pytorch-enet-docker), the weights should be generated with PyTorch 1.5.0.
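As a rough reference, an export script generally follows the pattern below (a minimal sketch assuming a hypothetical `model.py` defining `ENet`, a checkpoint file `enet.pth`, and the 480x256 input size used by the nodes in this package; adapt the names to your own training code):

```python
import torch

from model import ENet  # hypothetical: your own model definition

# Build the model and load the trained weights (file name assumed)
model = ENet(num_classes=4)
model.load_state_dict(torch.load('enet.pth', map_location='cpu'))
model.eval()

# Trace with a dummy input shaped like the node's input: (1, 3, 256, 480)
example = torch.rand(1, 3, 256, 480)
traced = torch.jit.trace(model, example)

# Save the serialized Torch Script model; pass this file to the node
# via the 'model_file' parameter
traced.save('model.pt')
```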
## 6. Color map

For visualization of semantic segmentation, we use a color map image.

It is a 1xC PNG image file (C: the number of classes), where the color of class i is stored in the pixel at (1, i).

## 7. Publications

This repository is used in the experiments of the following publication:

Shigemichi Matsuzaki, Hiroaki Masuzawa, Jun Miura, Image-Based Scene Recognition for Robot Navigation Considering Traversable Plants and Its Manual Annotation-Free Training, IEEE Access, vol. 10, pp. 5115-5128, 2022 [[paper](https://ieeexplore.ieee.org/document/9674898)]

--------------------------------------------------------------------------------
/src/impl/pytorch_cpp_wrapper_base.cpp:
--------------------------------------------------------------------------------
/*
 * A wrapper class of the PyTorch C++ API to use a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"
#include <torch/script.h> // One-stop header.
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include <cmath>
#include "opencv2/highgui/highgui.hpp"

PyTorchCppWrapperBase::PyTorchCppWrapperBase() {}

PyTorchCppWrapperBase::PyTorchCppWrapperBase(const std::string & filename, const int class_num)
  : class_num_(class_num)
{
  // Import model
  import_module(filename);

  // Calculate the maximum possible entropy
  // to normalize the entropy value in [0, 1].
  max_entropy_ = 0;
  const float prob = (float) 1.0 / class_num_;
  for(int i = 0; i < class_num_; ++i) {
    max_entropy_ += -prob * std::log(prob);
  }
}

PyTorchCppWrapperBase::PyTorchCppWrapperBase(const char* filename, const int class_num)
  : class_num_(class_num)
{
  // Import model
  import_module(std::string(filename));

  // Calculate the maximum possible entropy
  // to normalize the entropy value in [0, 1].
  max_entropy_ = 0;
  const float prob = (float) 1.0 / class_num_;
  for(int i = 0; i < class_num_; ++i) {
    max_entropy_ += -prob * std::log(prob);
  }
}

/**
 * @brief Import a network
 * @param filename
 * @return true if import succeeded
 */
bool
PyTorchCppWrapperBase::import_module(const std::string & filename)
{
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module_ = torch::jit::load(filename);
    // Set evaluation mode
    module_.eval();

    std::cout << "Import succeeded" << std::endl;
    return true;
  }
  catch (const c10::Error& e) {
    std::cerr << e.what();
    return false;
  }
}

/**
 * @brief Convert an image (cv::Mat) to a tensor (at::Tensor)
 * @param[in]  img
 * @param[out] tensor
 * @param[in]  use_gpu Whether to use the GPU
 */
void
PyTorchCppWrapperBase::img2tensor(cv::Mat & img, at::Tensor & tensor, const bool & use_gpu)
{
  // Get the size of the input image
  int height = img.size().height;
  int width = img.size().width;

  // Create a vector of inputs.
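  // The cv::Mat buffer is laid out HWC, so wrap it as a {1, H, W, 3}
  // byte tensor first, then convert to float and move it to the target
  // device before reordering the dimensions.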
  std::vector<int64_t> shape = {1, height, width, 3};
  if(use_gpu) {
    tensor = torch::from_blob(img.data, at::IntList(shape), at::ScalarType::Byte).to(torch::kFloat).to(torch::kCUDA);
  } else {
    tensor = torch::from_blob(img.data, at::IntList(shape), at::ScalarType::Byte).to(torch::kFloat).to(torch::kCPU);
  }
  // Reorder from {1, H, W, C} to {1, C, H, W}
  tensor = at::transpose(tensor, 1, 2);
  tensor = at::transpose(tensor, 1, 3);
}

/**
 * @brief Convert a tensor (at::Tensor) to an image (cv::Mat)
 * @param[in]  tensor
 * @param[out] img
 */
void
PyTorchCppWrapperBase::tensor2img(at::Tensor tensor, cv::Mat & img)
{
  // Get the size of the input image
  int height = tensor.sizes()[0];
  int width = tensor.sizes()[1];

  // Convert to OpenCV
  img = cv::Mat(height, width, CV_8U, tensor.template data_ptr<uint8_t>());
}

/**
 * @brief Convert a tensor (at::Tensor) to an image (cv::Mat)
 * @param[in] tensor
 * @return Converted CV image
 */
cv::Mat
PyTorchCppWrapperBase::tensor2img(at::Tensor tensor)
{
  // Get the size of the input image
  int height = tensor.sizes()[0];
  int width = tensor.sizes()[1];

  // Convert to OpenCV
  return cv::Mat(height, width, CV_8U, tensor.template data_ptr<uint8_t>());
}

/**
 * @brief Take element-wise argmax
 * @param[in] input_tensor
 * @return Tensor that holds the index of the max value in each element
 */
at::Tensor
PyTorchCppWrapperBase::get_argmax(at::Tensor input_tensor)
{
  // Calculate argmax to get a label on each pixel
  at::Tensor output = at::argmax(input_tensor, /*dim=*/1).to(torch::kCPU).to(at::kByte);

  return output;
}

/**
 * @brief Take element-wise entropy
 * @param[in] input_tensor
 * @param[in] normalize Whether to normalize the entropy to [0, 1]
 * @return Tensor that holds the entropy of each element
 */
at::Tensor
PyTorchCppWrapperBase::get_entropy(at::Tensor input_tensor, const bool normalize = true)
{
  input_tensor = input_tensor.to(torch::kCUDA);
  // Calculate the entropy at each pixel
  at::Tensor log_p = torch::log_softmax(input_tensor, /*dim=*/1);
  at::Tensor p = torch::softmax(input_tensor, /*dim=*/1);

  at::Tensor entropy = -torch::sum(p * log_p, /*dim=*/1);

  if(normalize)
    entropy = entropy / max_entropy_;

  return entropy;
}

--------------------------------------------------------------------------------
/src/impl/pytorch_seg_ros.cpp:
--------------------------------------------------------------------------------
/*
 * A ROS node to do inference using a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#include <pytorch_ros/pytorch_seg_ros.h>

PyTorchSegROS::PyTorchSegROS(ros::NodeHandle & nh)
  : it_(nh), nh_(nh)
{
  sub_image_ = it_.subscribe("image", 1, &PyTorchSegROS::image_callback, this);
  pub_label_image_ = it_.advertise("label", 10);
  pub_color_image_ = it_.advertise("color_label", 10);
  pub_uncertainty_image_ = it_.advertise("uncertainty", 10);
  get_label_image_server_ = nh_.advertiseService("get_label_image", &PyTorchSegROS::image_inference_srv_callback, this);

  // Import the model
  std::string filename;
  nh_.param<std::string>("model_file", filename, "");
  pt_wrapper_ptr_.reset(new PyTorchCppWrapperSeg(filename, 4));
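  // NOTE: PyTorchCppWrapperBase's constructor already calls import_module(),
  // so the call below loads the model a second time; it also serves as an
  // explicit check that the file is valid.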
  if(!pt_wrapper_ptr_->import_module(filename)) {
    ROS_ERROR("Failed to import the model file [%s]", filename.c_str());
    ros::shutdown();
  }

  // Import color map image
  std::string colormap_name;
  nh_.param<std::string>("colormap", colormap_name, "");
  colormap_ = cv::imread(colormap_name);
  if(colormap_.empty()) {
    ROS_ERROR("Failed to import the colormap file [%s]", colormap_name.c_str());
    ros::shutdown();
  }
}

void
PyTorchSegROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
{
  ROS_INFO("[PyTorchSegROS image_callback] Let's start!!");

  // Convert the image message to a cv_bridge object
  cv_bridge::CvImagePtr cv_ptr = msg_to_cv_bridge(msg);

  // Run inference
  sensor_msgs::ImagePtr label_msg;
  sensor_msgs::ImagePtr color_label_msg;
  sensor_msgs::ImagePtr uncertainty_msg;
  std::tie(label_msg, color_label_msg, uncertainty_msg) = inference(cv_ptr->image);

  // Set header
  label_msg->header = msg->header;
  color_label_msg->header = msg->header;
  uncertainty_msg->header = msg->header;

  pub_label_image_.publish(label_msg);
  pub_color_image_.publish(color_label_msg);
  pub_uncertainty_image_.publish(uncertainty_msg);
}

/*
 * image_inference_srv_callback : Callback for the service
 */
bool
PyTorchSegROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
                                            semantic_segmentation_srvs::GetLabelAndProbability::Response & res)
{
  ROS_INFO("[PyTorchSegROS image_inference_srv_callback] Start");

  // Convert the image message to a cv_bridge object
  cv_bridge::CvImagePtr cv_ptr = msg_to_cv_bridge(req.img);

  // Run inference
  sensor_msgs::ImagePtr label_msg;
  sensor_msgs::ImagePtr color_label_msg;
  sensor_msgs::ImagePtr uncertainty_msg;
  std::tie(label_msg, color_label_msg, uncertainty_msg) = inference(cv_ptr->image);

  res.label_img = *label_msg;
  res.colorlabel_img = *color_label_msg;
  res.uncertainty_img = *uncertainty_msg;

  return true;
}

/*
 * inference : Forward the given input image through the network and return the inference result
 */
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
PyTorchSegROS::inference(cv::Mat & input_img)
{
  // The size of the original image, to which the result of inference is resized back.
  int height_orig = input_img.size().height;
  int width_orig = input_img.size().width;

  cv::Size s(480, 256);
  // Resize the input image
  cv::resize(input_img, input_img, s);

  at::Tensor input_tensor;
  pt_wrapper_ptr_->img2tensor(input_img, input_tensor);

  // Normalize from [0, 255] -> [0, 1]
  input_tensor /= 255.0;
  // z-normalization with the ImageNet mean and standard deviation
  std::vector<float> mean_vec{0.485, 0.456, 0.406};
  std::vector<float> std_vec{0.229, 0.224, 0.225};
  for(size_t i = 0; i < mean_vec.size(); i++) {
    input_tensor[0][i] = (input_tensor[0][i] - mean_vec[i]) / std_vec[i];
  }

  // Execute the model and turn its output into a tensor.
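  // get_output() runs the TorchScript module; the raw logits are then
  // post-processed in two ways: argmax over the class dimension yields
  // the per-pixel label, and the normalized entropy of the softmax
  // distribution yields a per-pixel uncertainty in [0, 1], scaled to
  // [0, 255] for publishing as a mono8 image.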
  at::Tensor segmentation;
  segmentation = pt_wrapper_ptr_->get_output(input_tensor);

  at::Tensor output_args = pt_wrapper_ptr_->get_argmax(segmentation);

  // Uncertainty of segmentation
  at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy(segmentation, true);
  uncertainty = (uncertainty[0]*255).to(torch::kCPU).to(torch::kByte);

  // Convert to OpenCV
  cv::Mat label;
  cv::Mat uncertainty_cv;
  pt_wrapper_ptr_->tensor2img(output_args[0], label);
  pt_wrapper_ptr_->tensor2img(uncertainty, uncertainty_cv);

  // Set the size
  cv::Size s_orig(width_orig, height_orig);
  // Resize the inference results back to the original size
  cv::resize(label, label, s_orig, 0, 0, cv::INTER_NEAREST);
  cv::resize(uncertainty_cv, uncertainty_cv, s_orig, 0, 0, cv::INTER_LINEAR);
  // Generate color label image
  cv::Mat color_label;
  label_to_color(label, color_label);

  // Generate image messages
  sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
  sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
  sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", uncertainty_cv).toImageMsg();

  return std::forward_as_tuple(label_msg, color_label_msg, uncertainty_msg);
}

/*
 * label_to_color : Convert a label image to a color label image for visualization
 */
void
PyTorchSegROS::label_to_color(cv::Mat& label, cv::Mat& color)
{
  cv::cvtColor(label, color, CV_GRAY2BGR);
  cv::LUT(color, colormap_, color);
}

/*
 * msg_to_cv_bridge : Generate a cv_image pointer instance from a given image message pointer
 */
cv_bridge::CvImagePtr
PyTorchSegROS::msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg)
{
  cv_bridge::CvImagePtr cv_ptr;

  // Convert the image message to a cv_bridge object
  try
  {
    cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
  }
  catch (cv_bridge::Exception& e)
  {
    ROS_ERROR("cv_bridge exception: %s", e.what());
    return nullptr;
  }

  return cv_ptr;
}

/*
 * msg_to_cv_bridge : Generate a cv_image pointer instance from a given message
 */
cv_bridge::CvImagePtr
PyTorchSegROS::msg_to_cv_bridge(sensor_msgs::Image msg)
{
  cv_bridge::CvImagePtr cv_ptr;

  // Convert the image message to a cv_bridge object
  try
  {
    cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
  }
  catch (cv_bridge::Exception& e)
  {
    ROS_ERROR("cv_bridge exception: %s", e.what());
    return nullptr;
  }

  return cv_ptr;
}

--------------------------------------------------------------------------------
/src/impl/pytorch_seg_trav_ros.cpp:
--------------------------------------------------------------------------------
/*
 * A ROS node to do inference using a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#include <pytorch_ros/pytorch_seg_trav_ros.h>

PyTorchSegTravROS::PyTorchSegTravROS(ros::NodeHandle & nh)
  : it_(nh), nh_(nh)
{
  sub_image_ = it_.subscribe("image", 1, &PyTorchSegTravROS::image_callback, this);
  pub_label_image_ = it_.advertise("label", 10);
  pub_color_image_ = it_.advertise("color_label", 10);
  pub_prob_image_ = it_.advertise("prob", 10);
  pub_uncertainty_image_ = it_.advertise("uncertainty", 10);
  get_label_image_server_ = nh_.advertiseService("get_label_image", &PyTorchSegTravROS::image_inference_srv_callback, this);

  // Import the model
  std::string filename;
  nh_.param<std::string>("model_file", filename, "");
  pt_wrapper_ptr_.reset(new PyTorchCppWrapperSegTrav(filename, 4));
  if(!pt_wrapper_ptr_->import_module(filename)) {
    ROS_ERROR("Failed to import the model file [%s]", filename.c_str());
    ros::shutdown();
  }

  // Import color map image
  std::string colormap_name;
  nh_.param<std::string>("colormap", colormap_name, "");
  colormap_ = cv::imread(colormap_name);
  if(colormap_.empty()) {
    ROS_ERROR("Failed to import the colormap file [%s]", colormap_name.c_str());
    ros::shutdown();
  }
}

void
PyTorchSegTravROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
{
  ROS_INFO("[PyTorchSegTravROS image_callback] Let's start!!");

  // Convert the image message to a cv_bridge object
  cv_bridge::CvImagePtr cv_ptr = msg_to_cv_bridge(msg);

  // Run inference
  sensor_msgs::ImagePtr label_msg;
  sensor_msgs::ImagePtr color_label_msg;
  sensor_msgs::ImagePtr prob_msg;
  sensor_msgs::ImagePtr uncertainty_msg;
  std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg) = inference(cv_ptr->image);

  // Set header
  label_msg->header = msg->header;
  color_label_msg->header = msg->header;
  prob_msg->header = msg->header;
  uncertainty_msg->header = msg->header;

  pub_label_image_.publish(label_msg);
  pub_color_image_.publish(color_label_msg);
  pub_prob_image_.publish(prob_msg);
  pub_uncertainty_image_.publish(uncertainty_msg);
}

/*
 * image_inference_srv_callback : Callback for the service
 */
bool
PyTorchSegTravROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
                                                semantic_segmentation_srvs::GetLabelAndProbability::Response & res)
{
  ROS_INFO("[PyTorchSegTravROS image_inference_srv_callback] Start");

  // Convert the image message to a cv_bridge object
  cv_bridge::CvImagePtr cv_ptr = msg_to_cv_bridge(req.img);

  // Run inference
  sensor_msgs::ImagePtr label_msg;
  sensor_msgs::ImagePtr color_label_msg;
  sensor_msgs::ImagePtr prob_msg;
  sensor_msgs::ImagePtr uncertainty_msg;
  std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg) = inference(cv_ptr->image);

  res.label_img = *label_msg;
  res.colorlabel_img = *color_label_msg;
  res.prob_img = *prob_msg;
  res.uncertainty_img = *uncertainty_msg;

  return true;
}

/*
 * inference : Forward the given input image through the network and return the inference result
 */
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
PyTorchSegTravROS::inference(cv::Mat & input_img)
{
  // The size of the original image, to which the result of inference is resized back.
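  // Note: input_img is taken by reference and resized in place, so the
  // caller's cv_bridge image buffer is modified by this function.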
  int height_orig = input_img.size().height;
  int width_orig = input_img.size().width;

  cv::Size s(480, 256);
  // Resize the input image
  cv::resize(input_img, input_img, s);

  at::Tensor input_tensor;
  pt_wrapper_ptr_->img2tensor(input_img, input_tensor);

  // Normalize from [0, 255] -> [0, 1]
  input_tensor /= 255.0;
  // z-normalization with the ImageNet mean and standard deviation
  std::vector<float> mean_vec{0.485, 0.456, 0.406};
  std::vector<float> std_vec{0.229, 0.224, 0.225};
  for(size_t i = 0; i < mean_vec.size(); i++) {
    input_tensor[0][i] = (input_tensor[0][i] - mean_vec[i]) / std_vec[i];
  }

  // Execute the model and turn its output into a tensor.
  at::Tensor segmentation;
  at::Tensor prob;
  std::tie(segmentation, prob) = pt_wrapper_ptr_->get_output(input_tensor);
  prob = (prob[0][0]*255).to(torch::kCPU).to(torch::kByte);

  // Calculate argmax to get a label on each pixel
  at::Tensor output_args = pt_wrapper_ptr_->get_argmax(segmentation);

  // Uncertainty of segmentation
  at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy(segmentation, true);
  uncertainty = (uncertainty[0]*255).to(torch::kCPU).to(torch::kByte);

  // Convert to OpenCV
  cv::Mat label;
  cv::Mat prob_cv;
  cv::Mat uncertainty_cv;
  pt_wrapper_ptr_->tensor2img(output_args[0], label);
  pt_wrapper_ptr_->tensor2img(prob, prob_cv);
  pt_wrapper_ptr_->tensor2img(uncertainty, uncertainty_cv);

  // Set the size
  cv::Size s_orig(width_orig, height_orig);
  // Resize the inference results back to the original size
  cv::resize(label, label, s_orig, 0, 0, cv::INTER_NEAREST);
  cv::resize(prob_cv, prob_cv, s_orig, 0, 0, cv::INTER_LINEAR);
  cv::resize(uncertainty_cv, uncertainty_cv, s_orig, 0, 0, cv::INTER_LINEAR);
  // Generate color label image
  cv::Mat color_label;
  label_to_color(label, color_label);

  // Generate image messages
  sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
  sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
  sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
  sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", uncertainty_cv).toImageMsg();

  return std::forward_as_tuple(label_msg, color_label_msg, prob_msg, uncertainty_msg);
}

/*
 * label_to_color : Convert a label image to a color label image for visualization
 */
void
PyTorchSegTravROS::label_to_color(cv::Mat& label, cv::Mat& color)
{
  cv::cvtColor(label, color, CV_GRAY2BGR);
  cv::LUT(color, colormap_, color);
}

/*
 * msg_to_cv_bridge : Generate a cv_image pointer instance from a given image message pointer
 */
cv_bridge::CvImagePtr
PyTorchSegTravROS::msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg)
{
  cv_bridge::CvImagePtr cv_ptr;

  // Convert the image message to a cv_bridge object
  try
  {
    cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
  }
  catch (cv_bridge::Exception& e)
  {
    ROS_ERROR("cv_bridge exception: %s", e.what());
    return nullptr;
  }

  return cv_ptr;
}

/*
 * msg_to_cv_bridge : Generate a cv_image pointer instance from a given message
 */
cv_bridge::CvImagePtr
PyTorchSegTravROS::msg_to_cv_bridge(sensor_msgs::Image msg)
{
  cv_bridge::CvImagePtr cv_ptr;

  // Convert the image message to a cv_bridge object
  try
  {
    cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
  }
  catch (cv_bridge::Exception& e)
  {
    ROS_ERROR("cv_bridge exception: %s", e.what());
    return nullptr;
  }

  return cv_ptr;
}

--------------------------------------------------------------------------------
/src/impl/pytorch_seg_trav_path_ros.cpp:
--------------------------------------------------------------------------------
/*
 * A ROS node to do inference using a PyTorch model
 * Shigemichi Matsuzaki
 *
 */

#include <pytorch_ros/pytorch_seg_trav_path_ros.h>

PyTorchSegTravPathROS::PyTorchSegTravPathROS(ros::NodeHandle & nh)
  : it_(nh), nh_(nh)
{
  sub_image_ = it_.subscribe("image", 1, &PyTorchSegTravPathROS::image_callback, this);
  pub_label_image_ = it_.advertise("label", 1);
  pub_color_image_ = it_.advertise("color_label", 1);
  pub_prob_image_ = it_.advertise("prob", 1);
  pub_uncertainty_image_ = it_.advertise("uncertainty", 1);
  pub_start_point_ = nh_.advertise<geometry_msgs::PointStamped>("start_point", 1);
  pub_end_point_ = nh_.advertise<geometry_msgs::PointStamped>("end_point", 1);
  get_label_image_server_ = nh_.advertiseService("get_label_image", &PyTorchSegTravPathROS::image_inference_srv_callback, this);

  // Import the model
  std::string filename;
  nh_.param<std::string>("model_file", filename, "");
  pt_wrapper_ptr_.reset(new PyTorchCppWrapperSegTravPath(filename, 4));
  if(!pt_wrapper_ptr_->import_module(filename)) {
    ROS_ERROR("Failed to import the model file [%s]", filename.c_str());
    ros::shutdown();
  }

  // Import color map image
  std::string colormap_name;
  nh_.param<std::string>("colormap", colormap_name, "");
  colormap_ = cv::imread(colormap_name);
  if(colormap_.empty()) {
    ROS_ERROR("Failed to import the colormap file [%s]", colormap_name.c_str());
    ros::shutdown();
  }
}

/**
 * @brief Image callback
 * @param[in] msg Message
 */
void
PyTorchSegTravPathROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
{
  ROS_INFO("[PyTorchSegTravPathROS image_callback] Let's start!!");

  // Convert the image message to a cv_bridge object
  cv_bridge::CvImagePtr cv_ptr = msg_to_cv_bridge(msg);
  stamp_of_current_image_ = msg->header.stamp;

  // Run inference
  sensor_msgs::ImagePtr label_msg;
  sensor_msgs::ImagePtr color_label_msg;
  sensor_msgs::ImagePtr prob_msg;
  sensor_msgs::ImagePtr uncertainty_msg;
  geometry_msgs::PointStampedPtr start_point_msg;
  geometry_msgs::PointStampedPtr end_point_msg;
  std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg) = inference(cv_ptr->image);

  // Set header
  label_msg->header = msg->header;
  color_label_msg->header = msg->header;
  prob_msg->header = msg->header;
  uncertainty_msg->header = msg->header;

  // Publish the messages
  pub_label_image_.publish(label_msg);
  pub_color_image_.publish(color_label_msg);
pub_prob_image_.publish(prob_msg);
73 | pub_uncertainty_image_.publish(uncertainty_msg);
74 | pub_start_point_.publish(start_point_msg);
75 | pub_end_point_.publish(end_point_msg);
76 | }
77 | 
78 | /**
79 | * @brief Service callback to run inference on the image in the request
80 | * @param[in] req Request @param[out] res Response
81 | * @return True if the service succeeded
82 | */
83 | bool
84 | PyTorchSegTravPathROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
85 | semantic_segmentation_srvs::GetLabelAndProbability::Response & res)
86 | {
87 | ROS_INFO("[PyTorchSegTravPathROS image_inference_srv_callback] Start");
88 | 
89 | // Convert the image message to a cv_bridge object
90 | cv_bridge::CvImagePtr cv_ptr = msg_to_cv_bridge(req.img);
91 | 
92 | // Run inference
93 | sensor_msgs::ImagePtr label_msg;
94 | sensor_msgs::ImagePtr color_label_msg;
95 | sensor_msgs::ImagePtr prob_msg;
96 | sensor_msgs::ImagePtr uncertainty_msg;
97 | geometry_msgs::PointStampedPtr start_point_msg;
98 | geometry_msgs::PointStampedPtr end_point_msg;
99 | std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg) = inference(cv_ptr->image);
100 | 
101 | res.label_img = *label_msg;
102 | res.colorlabel_img = *color_label_msg;
103 | res.prob_img = *prob_msg;
104 | 
105 | return true;
106 | }
107 | 
108 | /**
109 | * @brief Main function for inference using the model
110 | * @param[in] input_img OpenCV image
111 | * @return A tuple of messages of the inference results
112 | * (label, color label, probability, uncertainty, start and end points)
113 | */
114 | std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr,
115 | sensor_msgs::ImagePtr, geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
116 | PyTorchSegTravPathROS::inference(cv::Mat & input_img)
117 | {
118 | 
119 | // The size of the original image, to which the result of inference is resized back.
120 | int height_orig = input_img.size().height;
121 | int width_orig = input_img.size().width;
122 | 
123 | cv::Size s(480, 256);
124 | // Resize the input image
125 | cv::resize(input_img, input_img, s);
126 | 
127 | at::Tensor input_tensor;
128 | pt_wrapper_ptr_->img2tensor(input_img, input_tensor);
129 | 
130 | normalize_tensor(input_tensor);
131 | 
132 | // Execute the model and turn its output into a tensor.
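// Editor's note (shapes assumed from the usage below, not asserted by the wrapper):
// 'segmentation' is the raw class-score map of shape (1, num_classes, H, W),
// 'prob' a single-channel traversability map of shape (1, 1, H, W), and
// 'points' a (1, 4) tensor of normalized line-point coordinates consumed by
// tensor_to_points().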
133 | at::Tensor segmentation;
134 | at::Tensor prob;
135 | at::Tensor points;
136 | // segmentation: raw output for segmentation (before softmax)
137 | // prob: traversability
138 | // points: coordinates of the line points
139 | std::tie(segmentation, prob, points) = pt_wrapper_ptr_->get_output(input_tensor);
140 | prob = (prob[0][0]*255).to(torch::kCPU).to(torch::kByte);
141 | 
142 | // Get class label map by taking argmax of 'segmentation'
143 | at::Tensor output_args = pt_wrapper_ptr_->get_argmax(segmentation);
144 | 
145 | // Uncertainty of segmentation
146 | at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy(segmentation, true);
147 | uncertainty = (uncertainty[0]*255).to(torch::kCPU).to(torch::kByte);
148 | 
149 | // Size of the original image
150 | cv::Size s_orig(width_orig, height_orig);
151 | 
152 | // Convert to OpenCV
153 | cv::Mat label;
154 | cv::Mat prob_cv;
155 | cv::Mat uncertainty_cv = cv::Mat::zeros(s_orig.height, s_orig.width, CV_8U);
156 | // Segmentation label
157 | label = pt_wrapper_ptr_->tensor2img(output_args[0]);
158 | // Uncertainty of segmentation
159 | uncertainty_cv = pt_wrapper_ptr_->tensor2img(uncertainty);
160 | // uncertainty_cv = pt_wrapper_ptr_->tensor2img((uncertainty*255).to(torch::kCPU).to(torch::kByte));
161 | // Traversability
162 | prob_cv = pt_wrapper_ptr_->tensor2img(prob);
163 | 
164 | // Resize the outputs back to the original size
165 | cv::resize(label, label, s_orig, 0, 0, cv::INTER_NEAREST);
166 | cv::resize(prob_cv, prob_cv, s_orig, 0, 0, cv::INTER_LINEAR);
167 | cv::resize(uncertainty_cv, uncertainty_cv, s_orig, 0, 0, cv::INTER_LINEAR);
168 | 
169 | // Generate color label image
170 | cv::Mat color_label;
171 | label_to_color(label, color_label);
172 | 
173 | // Generate an image message and point messages
174 | sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
175 | sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
176 | // Problem: Wrong data is sometimes assigned to 'prob_cv'
177 | sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
178 | sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", uncertainty_cv).toImageMsg();
179 | geometry_msgs::PointStampedPtr start_point_msg(new geometry_msgs::PointStamped), end_point_msg(new geometry_msgs::PointStamped);
180 | std::tie(start_point_msg, end_point_msg) = tensor_to_points(points, width_orig, height_orig);
181 | 
182 | return std::forward_as_tuple(label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg);
183 | }
184 | 
185 | /**
186 | * @brief Convert a tensor with a size of (1, 4) to start and end points (x, y)
187 | * @param[in] point_tensor (1, 4) tensor
188 | * @param[in] width Original width of the image
189 | * @param[in] height Original height of the image
190 | * @return A tuple of start and end points as geometry_msgs::PointStampedPtr
191 | */
192 | std::tuple<geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
193 | PyTorchSegTravPathROS::tensor_to_points(const at::Tensor point_tensor, const int & width, const int & height)
194 | {
195 | geometry_msgs::PointStampedPtr start_point_msg(new geometry_msgs::PointStamped), end_point_msg(new geometry_msgs::PointStamped);
196 | // Important: put the data on the CPU before accessing the data.
197 | // Absence of this code will result in a runtime error.
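// Editor's note: at::Tensor::accessor<T, N>() requires a CPU tensor; CUDA
// storage is not addressable from host code, so the .to(torch::kCPU) copy
// below is what makes the element reads (points_a[0][i]) legal.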
198 | at::Tensor points = point_tensor.to(torch::kCPU);
199 | auto points_a = points.accessor<float, 2>();
200 | 
201 | // Initialize messages
202 | start_point_msg->header.stamp = stamp_of_current_image_; //ros::Time::now();
203 | start_point_msg->header.frame_id = "kinect2_rgb_optical_frame";
204 | end_point_msg->header.stamp = stamp_of_current_image_; //ros::Time::now();
205 | end_point_msg->header.frame_id = "kinect2_rgb_optical_frame";
206 | // Point tensor has coordinate values normalized with the width and height.
207 | // Therefore each value is multiplied by width or height.
208 | start_point_msg->point.x = points_a[0][0] * width;
209 | start_point_msg->point.y = points_a[0][1] * height;
210 | end_point_msg->point.x = points_a[0][2] * width;
211 | end_point_msg->point.y = points_a[0][3] * height;
212 | 
213 | return std::forward_as_tuple(start_point_msg, end_point_msg);
214 | }
215 | 
216 | /**
217 | * @brief Convert a label image to color label image for visualization
218 | * @param[in] label Label image
219 | * @param[out] color_label Color image mapped from the label image
220 | */
221 | void
222 | PyTorchSegTravPathROS::label_to_color(cv::Mat& label, cv::Mat& color_label)
223 | {
224 | cv::cvtColor(label, color_label, CV_GRAY2BGR);
225 | cv::LUT(color_label, colormap_, color_label);
226 | }
227 | 
228 | /**
229 | * @brief Convert Image message to cv_bridge
230 | * @param[in] msg Pointer of image message
231 | * @return cv_bridge
232 | */
233 | cv_bridge::CvImagePtr
234 | PyTorchSegTravPathROS::msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg)
235 | {
236 | cv_bridge::CvImagePtr cv_ptr;
237 | 
238 | // Convert the image message to a cv_bridge object
239 | try
240 | {
241 | cv_ptr = cv_bridge::toCvCopy(msg, msg->encoding);
242 | }
243 | catch (cv_bridge::Exception& e)
244 | {
245 | ROS_ERROR("cv_bridge exception: %s", e.what());
246 | return nullptr;
247 | }
248 | 
249 | return cv_ptr;
250 | }
251 | 
252 | /**
253 | * @brief Convert Image message to cv_bridge
254 | * @param[in] msg Image message
255 | * @return cv_bridge
256 | */
257 | cv_bridge::CvImagePtr
258 | PyTorchSegTravPathROS::msg_to_cv_bridge(sensor_msgs::Image msg)
259 | {
260 | cv_bridge::CvImagePtr cv_ptr;
261 | 
262 | // Convert the image message to a cv_bridge object
263 | try
264 | {
265 | cv_ptr = cv_bridge::toCvCopy(msg, msg.encoding);
266 | }
267 | catch (cv_bridge::Exception& e)
268 | {
269 | ROS_ERROR("cv_bridge exception: %s", e.what());
270 | return nullptr;
271 | }
272 | 
273 | return cv_ptr;
274 | }
275 | 
276 | /**
277 | * @brief Normalize a tensor to feed in a model
278 | * @param[in] input_tensor Tensor to be normalized in place
279 | */
280 | void
281 | PyTorchSegTravPathROS::normalize_tensor(at::Tensor & input_tensor)
282 | {
283 | // Normalize from [0, 255] -> [0, 1]
284 | input_tensor /= 255.0;
285 | // z-normalization
286 | std::vector<float> mean_vec{0.485, 0.456, 0.406};
287 | std::vector<float> std_vec{0.229, 0.224, 0.225};
288 | for(int i = 0; i < mean_vec.size(); i++) {
289 | input_tensor[0][i] = (input_tensor[0][i] - mean_vec[i]) / std_vec[i];
290 | }
291 | }
292 | 
--------------------------------------------------------------------------------
/script/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | 
4 | 
5 | class InitialBlock(nn.Module):
6 | """The initial block is composed of two branches:
7 | 1. a main branch which performs a regular convolution with stride 2;
8 | 2. an extension branch which performs max-pooling.
9 | 
10 | Doing both operations in parallel and concatenating their results
11 | allows for efficient downsampling and expansion. The main branch
12 | outputs 13 feature maps while the extension branch outputs 3, for a
13 | total of 16 feature maps after concatenation.
14 | 
15 | Keyword arguments:
16 | - in_channels (int): the number of input channels.
17 | - out_channels (int): the number of output channels.
18 | - bias (bool, optional): Adds a learnable bias to the output if
19 | ``True``. Default: False.
20 | - relu (bool, optional): When ``True`` ReLU is used as the activation
21 | function; otherwise, PReLU is used. Default: True.
22 | 
23 | (The main-branch convolution is fixed at kernel_size=3, stride=2,
24 | padding=1, matching the 3x3, stride-2 max-pooling of the extension
25 | branch; kernel size and padding are not constructor arguments.)
26 | 
27 | """
28 | 
29 | def __init__(self,
30 | in_channels,
31 | out_channels,
32 | bias=False,
33 | relu=True):
34 | super(InitialBlock, self).__init__()
35 | 
36 | if relu:
37 | activation = nn.ReLU
38 | else:
39 | activation = nn.PReLU
40 | 
41 | # Main branch - As stated above the number of output channels for this
42 | # branch is the total minus 3, since the remaining channels come from
43 | # the extension branch
44 | self.main_branch = nn.Conv2d(
45 | in_channels,
46 | out_channels - 3,
47 | kernel_size=3,
48 | stride=2,
49 | padding=1,
50 | bias=bias)
51 | 
52 | # Extension branch
53 | self.ext_branch = nn.MaxPool2d(3, stride=2, padding=1)
54 | 
55 | # Initialize batch normalization to be used after concatenation
56 | self.batch_norm = nn.BatchNorm2d(out_channels)
57 | 
58 | # PReLU layer to apply after concatenating the branches
59 | self.out_activation = activation()
60 | 
61 | self._initialize_weights()
62 | 
63 | def forward(self, x):
64 | main = self.main_branch(x)
65 | ext = self.ext_branch(x)
66 | 
67 | # Concatenate branches
68 | out = torch.cat((main, ext), 1)
69 | 
70 | # Apply batch normalization
71 | out = self.batch_norm(out)
72 | 
73 | return self.out_activation(out)
74 | 
75 | def _initialize_weights(self):
76 | for m in self.modules():
77 | if isinstance(m, nn.Conv2d):
78 | #init.orthogonal_(m.weight.data, gain=init.calculate_gain('relu'))
79 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
80 | if m.bias is not None:
81 | nn.init.constant_(m.bias, 0)
82 | elif isinstance(m, nn.BatchNorm2d):
83 | nn.init.constant_(m.weight, 1)
84 | nn.init.constant_(m.bias, 0)
85 | elif isinstance(m, nn.Linear):
86 | nn.init.normal_(m.weight, 0, 0.01)
87 | nn.init.constant_(m.bias, 0)
88 | 
89 | 
90 | class RegularBottleneck(nn.Module):
91 | """Regular bottlenecks are the main building block of ENet.
92 | Main branch:
93 | 1. Shortcut connection.
94 | 
95 | Extension branch:
96 | 1. 1x1 convolution which decreases the number of channels by
97 | ``internal_ratio``, also called a projection;
98 | 2. regular, dilated or asymmetric convolution;
99 | 3. 1x1 convolution which increases the number of channels back to
100 | ``channels``, also called an expansion;
101 | 4. dropout as a regularizer.
102 | 
103 | Keyword arguments:
104 | - channels (int): the number of input and output channels.
105 | - internal_ratio (int, optional): a scale factor applied to
106 | ``channels`` used to compute the number of
107 | channels after the projection. E.g. given ``channels`` equal to 128 and
108 | internal_ratio equal to 2 the number of channels after the projection
109 | is 64. Default: 4.
110 | - kernel_size (int, optional): the kernel size of the filters used in
111 | the convolution layer described above in item 2 of the extension
112 | branch. Default: 3.
113 | - padding (int, optional): zero-padding added to both sides of the
114 | input. Default: 0.
115 | - dilation (int, optional): spacing between kernel elements for the
116 | convolution described in item 2 of the extension branch. Default: 1.
117 | - asymmetric (bool, optional): flags if the convolution described in
118 | item 2 of the extension branch is asymmetric or not. Default: False.
119 | - dropout_prob (float, optional): probability of an element to be
120 | zeroed. Default: 0 (no dropout).
121 | - bias (bool, optional): Adds a learnable bias to the output if
122 | ``True``. Default: False.
123 | - relu (bool, optional): When ``True`` ReLU is used as the activation
124 | function; otherwise, PReLU is used. Default: True.
125 | 
126 | """
127 | 
128 | def __init__(self,
129 | channels,
130 | internal_ratio=4,
131 | kernel_size=3,
132 | padding=0,
133 | dilation=1,
134 | asymmetric=False,
135 | dropout_prob=0,
136 | bias=False,
137 | relu=True):
138 | super(RegularBottleneck, self).__init__()
139 | 
140 | # Check if the internal_ratio parameter is within the expected range
141 | # [1, channels]
142 | if internal_ratio <= 1 or internal_ratio > channels:
143 | raise RuntimeError("Value out of range. Expected value in the "
144 | "interval [1, {0}], got internal_ratio={1}."
145 | .format(channels, internal_ratio))
146 | 
147 | internal_channels = channels // internal_ratio
148 | 
149 | if relu:
150 | activation = nn.ReLU
151 | else:
152 | activation = nn.PReLU
153 | 
154 | # Main branch - shortcut connection
155 | 
156 | # Extension branch - 1x1 convolution, followed by a regular, dilated or
157 | # asymmetric convolution, followed by another 1x1 convolution, and,
158 | # finally, a regularizer (spatial dropout). Number of channels is constant.
159 | 
160 | # 1x1 projection convolution
161 | self.ext_conv1 = nn.Sequential(
162 | nn.Conv2d(
163 | channels,
164 | internal_channels,
165 | kernel_size=1,
166 | stride=1,
167 | bias=bias), nn.BatchNorm2d(internal_channels), activation())
168 | 
169 | # If the convolution is asymmetric we split the main convolution in
170 | # two. E.g. for a 5x5 asymmetric convolution we have two convolutions:
171 | # the first is 5x1 and the second is 1x5.
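# Editor's note: factorizing a k x k kernel into a k x 1 followed by a 1 x k
# kernel cuts the weights per layer from k*k*C_in*C_out to 2*k*C_in*C_out;
# for the 5x5 case used by ENet's asymmetric bottlenecks that is 25*C^2
# versus 10*C^2 parameters for C input/output channels.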
172 | if asymmetric: 173 | self.ext_conv2 = nn.Sequential( 174 | nn.Conv2d( 175 | internal_channels, 176 | internal_channels, 177 | kernel_size=(kernel_size, 1), 178 | stride=1, 179 | padding=(padding, 0), 180 | dilation=dilation, 181 | bias=bias), nn.BatchNorm2d(internal_channels), activation(), 182 | nn.Conv2d( 183 | internal_channels, 184 | internal_channels, 185 | kernel_size=(1, kernel_size), 186 | stride=1, 187 | padding=(0, padding), 188 | dilation=dilation, 189 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 190 | else: 191 | self.ext_conv2 = nn.Sequential( 192 | nn.Conv2d( 193 | internal_channels, 194 | internal_channels, 195 | kernel_size=kernel_size, 196 | stride=1, 197 | padding=padding, 198 | dilation=dilation, 199 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 200 | 201 | # 1x1 expansion convolution 202 | self.ext_conv3 = nn.Sequential( 203 | nn.Conv2d( 204 | internal_channels, 205 | channels, 206 | kernel_size=1, 207 | stride=1, 208 | bias=bias), nn.BatchNorm2d(channels), activation()) 209 | 210 | self.ext_regul = nn.Dropout2d(p=dropout_prob) 211 | 212 | # PReLU layer to apply after adding the branches 213 | self.out_activation = activation() 214 | 215 | self._initialize_weights() 216 | 217 | def forward(self, x): 218 | # Main branch shortcut 219 | main = x 220 | 221 | # Extension branch 222 | ext = self.ext_conv1(x) 223 | ext = self.ext_conv2(ext) 224 | ext = self.ext_conv3(ext) 225 | ext = self.ext_regul(ext) 226 | 227 | # Add main and extension branches 228 | out = main + ext 229 | 230 | return self.out_activation(out) 231 | 232 | def _initialize_weights(self): 233 | for m in self.modules(): 234 | if isinstance(m, nn.Conv2d): 235 | #init.orthogonal_(m.weight.data, gain=init.calculate_gain('relu')) 236 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 237 | if m.bias is not None: 238 | nn.init.constant_(m.bias, 0) 239 | elif isinstance(m, nn.BatchNorm2d): 240 | nn.init.constant_(m.weight, 1) 241 | nn.init.constant_(m.bias, 0) 242 | elif isinstance(m, nn.Linear): 243 | nn.init.normal_(m.weight, 0, 0.01) 244 | nn.init.constant_(m.bias, 0) 245 | 246 | class DownsamplingBottleneck(nn.Module): 247 | """Downsampling bottlenecks further downsample the feature map size. 248 | 249 | Main branch: 250 | 1. max pooling with stride 2; indices are saved to be used for 251 | unpooling later. 252 | 253 | Extension branch: 254 | 1. 2x2 convolution with stride 2 that decreases the number of channels 255 | by ``internal_ratio``, also called a projection; 256 | 2. regular convolution (by default, 3x3); 257 | 3. 1x1 convolution which increases the number of channels to 258 | ``out_channels``, also called an expansion; 259 | 4. dropout as a regularizer. 260 | 261 | Keyword arguments: 262 | - in_channels (int): the number of input channels. 263 | - out_channels (int): the number of output channels. 264 | - internal_ratio (int, optional): a scale factor applied to ``channels`` 265 | used to compute the number of channels after the projection. eg. given 266 | ``channels`` equal to 128 and internal_ratio equal to 2 the number of 267 | channels after the projection is 64. Default: 4. 268 | - return_indices (bool, optional): if ``True``, will return the max 269 | indices along with the outputs. Useful when unpooling later. 270 | - dropout_prob (float, optional): probability of an element to be 271 | zeroed. Default: 0 (no dropout). 272 | - bias (bool, optional): Adds a learnable bias to the output if 273 | ``True``. Default: False. 
274 | - relu (bool, optional): When ``True`` ReLU is used as the activation
275 | function; otherwise, PReLU is used. Default: True.
276 | 
277 | """
278 | 
279 | def __init__(self,
280 | in_channels,
281 | out_channels,
282 | internal_ratio=4,
283 | return_indices=False,
284 | dropout_prob=0,
285 | bias=False,
286 | relu=True):
287 | super(DownsamplingBottleneck, self).__init__()
288 | 
289 | # Store parameters that are needed later
290 | self.return_indices = return_indices
291 | 
292 | # Check if the internal_ratio parameter is within the expected range
293 | # [1, in_channels]
294 | if internal_ratio <= 1 or internal_ratio > in_channels:
295 | raise RuntimeError("Value out of range. Expected value in the "
296 | "interval [1, {0}], got internal_ratio={1}. "
297 | .format(in_channels, internal_ratio))
298 | 
299 | internal_channels = in_channels // internal_ratio
300 | 
301 | if relu:
302 | activation = nn.ReLU
303 | else:
304 | activation = nn.PReLU
305 | 
306 | # Main branch - max pooling followed by feature map (channels) padding
307 | self.main_max1 = nn.MaxPool2d(
308 | 2,
309 | stride=2,
310 | return_indices=return_indices)
311 | 
312 | # Extension branch - 2x2 convolution, followed by a regular, dilated or
313 | # asymmetric convolution, followed by another 1x1 convolution. Number
314 | # of channels is increased to out_channels.
315 | 
316 | # 2x2 projection convolution with stride 2
317 | self.ext_conv1 = nn.Sequential(
318 | nn.Conv2d(
319 | in_channels,
320 | internal_channels,
321 | kernel_size=2,
322 | stride=2,
323 | bias=bias), nn.BatchNorm2d(internal_channels), activation())
324 | 
325 | # Convolution
326 | self.ext_conv2 = nn.Sequential(
327 | nn.Conv2d(
328 | internal_channels,
329 | internal_channels,
330 | kernel_size=3,
331 | stride=1,
332 | padding=1,
333 | bias=bias), nn.BatchNorm2d(internal_channels), activation())
334 | 
335 | # 1x1 expansion convolution
336 | self.ext_conv3 = nn.Sequential(
337 | nn.Conv2d(
338 | internal_channels,
339 | out_channels,
340 | kernel_size=1,
341 | stride=1,
342 | bias=bias), nn.BatchNorm2d(out_channels), activation())
343 | 
344 | self.ext_regul = nn.Dropout2d(p=dropout_prob)
345 | 
346 | # PReLU layer to apply after concatenating the branches
347 | self.out_activation = activation()
348 | 
349 | self._initialize_weights()
350 | 
351 | def forward(self, x):
352 | # Main branch shortcut (max_indices is None when indices are not requested)
353 | if self.return_indices:
354 | main, max_indices = self.main_max1(x)
355 | else:
356 | main, max_indices = self.main_max1(x), None
357 | 
358 | # Extension branch
359 | ext = self.ext_conv1(x)
360 | ext = self.ext_conv2(ext)
361 | ext = self.ext_conv3(ext)
362 | ext = self.ext_regul(ext)
363 | 
364 | # Main branch channel padding
365 | n, ch_ext, h, w = ext.size()
366 | ch_main = main.size()[1]
367 | padding = torch.zeros(n, ch_ext - ch_main, h, w)
368 | 
369 | # Before concatenating, check if main is on the CPU or GPU and
370 | # convert padding accordingly
371 | if main.is_cuda:
372 | padding = padding.cuda()
373 | 
374 | # Concatenate
375 | main = torch.cat((main, padding), 1)
376 | 
377 | # Add main and extension branches
378 | out = main + ext
379 | 
380 | return self.out_activation(out), max_indices
381 | 
382 | def _initialize_weights(self):
383 | for m in self.modules():
384 | if isinstance(m, nn.Conv2d):
385 | #init.orthogonal_(m.weight.data, gain=init.calculate_gain('relu'))
386 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
387 | if m.bias is not None:
388 | nn.init.constant_(m.bias, 0)
389 | elif isinstance(m, nn.BatchNorm2d):
390 | nn.init.constant_(m.weight, 1)
391 | nn.init.constant_(m.bias, 0)
392 | elif isinstance(m, nn.Linear):
393 | nn.init.normal_(m.weight, 0, 0.01)
394 | nn.init.constant_(m.bias, 0)
395 | 
396 | 
397 | 
398 | class UpsamplingBottleneck(nn.Module):
399 | """The upsampling bottlenecks upsample the feature map resolution using max
400 | pooling indices stored from the corresponding downsampling bottleneck.
401 | 
402 | Main branch:
403 | 1. 1x1 convolution with stride 1 that projects the input to
404 | ``out_channels``;
405 | 2. max unpool layer using the max pool indices from the corresponding
406 | downsampling max pool layer.
407 | 
408 | Extension branch:
409 | 1. 1x1 convolution with stride 1 that decreases the number of channels by
410 | ``internal_ratio``, also called a projection;
411 | 2. transposed convolution (2x2 with stride 2);
412 | 3. 1x1 convolution which increases the number of channels to
413 | ``out_channels``, also called an expansion;
414 | 4. dropout as a regularizer.
415 | 
416 | Keyword arguments:
417 | - in_channels (int): the number of input channels.
418 | - out_channels (int): the number of output channels.
419 | - internal_ratio (int, optional): a scale factor applied to ``in_channels``
420 | used to compute the number of channels after the projection. E.g. given
421 | ``in_channels`` equal to 128 and ``internal_ratio`` equal to 2 the number
422 | of channels after the projection is 64. Default: 4.
423 | - dropout_prob (float, optional): probability of an element to be zeroed.
424 | Default: 0 (no dropout).
425 | - bias (bool, optional): Adds a learnable bias to the output if ``True``.
426 | Default: False.
427 | - relu (bool, optional): When ``True`` ReLU is used as the activation
428 | function; otherwise, PReLU is used. Default: True.
429 | 
430 | """
431 | 
432 | def __init__(self,
433 | in_channels,
434 | out_channels,
435 | internal_ratio=4,
436 | dropout_prob=0,
437 | bias=False,
438 | relu=True):
439 | super(UpsamplingBottleneck, self).__init__()
440 | 
441 | # Check if the internal_ratio parameter is within the expected range
442 | # [1, in_channels]
443 | if internal_ratio <= 1 or internal_ratio > in_channels:
444 | raise RuntimeError("Value out of range. Expected value in the "
445 | "interval [1, {0}], got internal_ratio={1}. "
446 | .format(in_channels, internal_ratio))
447 | 
448 | internal_channels = in_channels // internal_ratio
449 | 
450 | if relu:
451 | activation = nn.ReLU
452 | else:
453 | activation = nn.PReLU
454 | 
455 | # Main branch - 1x1 projection convolution followed by max unpooling
456 | self.main_conv1 = nn.Sequential(
457 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
458 | nn.BatchNorm2d(out_channels))
459 | 
460 | # Remember that the stride is the same as the kernel_size, just like
461 | # the max pooling layers
462 | self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)
463 | 
464 | # Extension branch - 1x1 projection, followed by a transposed
465 | # convolution, followed by another 1x1 convolution. Number
466 | # of channels is reduced to ``out_channels``.
467 | 468 | # 1x1 projection convolution with stride 1 469 | self.ext_conv1 = nn.Sequential( 470 | nn.Conv2d( 471 | in_channels, internal_channels, kernel_size=1, bias=bias), 472 | nn.BatchNorm2d(internal_channels), activation()) 473 | 474 | # Transposed convolution 475 | self.ext_tconv1 = nn.ConvTranspose2d( 476 | internal_channels, 477 | internal_channels, 478 | kernel_size=2, 479 | stride=2, 480 | bias=bias) 481 | self.ext_tconv1_bnorm = nn.BatchNorm2d(internal_channels) 482 | self.ext_tconv1_activation = activation() 483 | 484 | # 1x1 expansion convolution 485 | self.ext_conv2 = nn.Sequential( 486 | nn.Conv2d( 487 | internal_channels, out_channels, kernel_size=1, bias=bias), 488 | nn.BatchNorm2d(out_channels), activation()) 489 | 490 | self.ext_regul = nn.Dropout2d(p=dropout_prob) 491 | 492 | # PReLU layer to apply after concatenating the branches 493 | self.out_activation = activation() 494 | 495 | self._initialize_weights() 496 | 497 | def forward(self, x, max_indices, output_size): 498 | # Main branch shortcut 499 | main = self.main_conv1(x) 500 | main = self.main_unpool1( 501 | main, max_indices, output_size=output_size) 502 | 503 | # Extension branch 504 | ext = self.ext_conv1(x) 505 | ext = self.ext_tconv1(ext, output_size=output_size) 506 | ext = self.ext_tconv1_bnorm(ext) 507 | ext = self.ext_tconv1_activation(ext) 508 | ext = self.ext_conv2(ext) 509 | ext = self.ext_regul(ext) 510 | 511 | # Add main and extension branches 512 | out = main + ext 513 | 514 | return self.out_activation(out) 515 | 516 | def _initialize_weights(self): 517 | for m in self.modules(): 518 | if isinstance(m, nn.Conv2d): 519 | #init.orthogonal_(m.weight.data, gain=init.calculate_gain('relu')) 520 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 521 | if m.bias is not None: 522 | nn.init.constant_(m.bias, 0) 523 | elif isinstance(m, nn.BatchNorm2d): 524 | nn.init.constant_(m.weight, 1) 525 | nn.init.constant_(m.bias, 0) 526 | elif isinstance(m, nn.Linear): 527 | nn.init.normal_(m.weight, 0, 0.01) 528 | nn.init.constant_(m.bias, 0) 529 | 530 | class ENet(nn.Module): 531 | """Generate the ENet model. 532 | 533 | Keyword arguments: 534 | - num_classes (int): the number of classes to segment. 535 | - encoder_relu (bool, optional): When ``True`` ReLU is used as the 536 | activation function in the encoder blocks/layers; otherwise, PReLU 537 | is used. Default: False. 538 | - decoder_relu (bool, optional): When ``True`` ReLU is used as the 539 | activation function in the decoder blocks/layers; otherwise, PReLU 540 | is used. Default: True. 
541 | 542 | """ 543 | 544 | def __init__(self, num_classes, encoder_relu=False, decoder_relu=True): 545 | super(ENet, self).__init__() 546 | 547 | self.initial_block = InitialBlock(3, 16, relu=encoder_relu) 548 | 549 | # Stage 1 - Encoder 550 | self.downsample1_0 = DownsamplingBottleneck( 551 | 16, 552 | 64, 553 | return_indices=True, 554 | dropout_prob=0.01, 555 | relu=encoder_relu) 556 | self.regular1_1 = RegularBottleneck( 557 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 558 | self.regular1_2 = RegularBottleneck( 559 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 560 | self.regular1_3 = RegularBottleneck( 561 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 562 | self.regular1_4 = RegularBottleneck( 563 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 564 | 565 | # Stage 2 - Encoder 566 | self.downsample2_0 = DownsamplingBottleneck( 567 | 64, 568 | 128, 569 | return_indices=True, 570 | dropout_prob=0.1, 571 | relu=encoder_relu) 572 | 573 | self.regular2_1 = RegularBottleneck( 574 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 575 | 576 | self.dilated2_2 = RegularBottleneck( 577 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 578 | 579 | self.asymmetric2_3 = RegularBottleneck( 580 | 128, 581 | kernel_size=5, 582 | padding=2, 583 | asymmetric=True, 584 | dropout_prob=0.1, 585 | relu=encoder_relu) 586 | 587 | self.dilated2_4 = RegularBottleneck( 588 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 589 | 590 | self.regular2_5 = RegularBottleneck( 591 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 592 | 593 | self.dilated2_6 = RegularBottleneck( 594 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 595 | 596 | self.asymmetric2_7 = RegularBottleneck( 597 | 128, 598 | kernel_size=5, 599 | asymmetric=True, 600 | padding=2, 601 | dropout_prob=0.1, 602 | relu=encoder_relu) 603 | 604 | self.dilated2_8 = RegularBottleneck( 605 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 606 | 607 | # Stage 3 - Encoder 608 | self.regular3_0 = RegularBottleneck( 609 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 610 | 611 | self.dilated3_1 = RegularBottleneck( 612 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 613 | 614 | self.asymmetric3_2 = RegularBottleneck( 615 | 128, 616 | kernel_size=5, 617 | padding=2, 618 | asymmetric=True, 619 | dropout_prob=0.1, 620 | relu=encoder_relu) 621 | 622 | self.dilated3_3 = RegularBottleneck( 623 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 624 | 625 | self.regular3_4 = RegularBottleneck( 626 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 627 | 628 | self.dilated3_5 = RegularBottleneck( 629 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 630 | 631 | self.asymmetric3_6 = RegularBottleneck( 632 | 128, 633 | kernel_size=5, 634 | asymmetric=True, 635 | padding=2, 636 | dropout_prob=0.1, 637 | relu=encoder_relu) 638 | 639 | self.dilated3_7 = RegularBottleneck( 640 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 641 | 642 | # Stage 4 - Decoder 643 | self.upsample4_0 = UpsamplingBottleneck( 644 | 128, 64, dropout_prob=0.1, relu=decoder_relu) 645 | 646 | self.regular4_1 = RegularBottleneck( 647 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 648 | 649 | self.regular4_2 = RegularBottleneck( 650 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 651 | 652 | # Stage 5 - Decoder 653 | self.upsample5_0 = UpsamplingBottleneck( 654 | 64, 16, dropout_prob=0.1, 
relu=decoder_relu) 655 | 656 | self.regular5_1 = RegularBottleneck( 657 | 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 658 | 659 | self.transposed_conv = nn.ConvTranspose2d( 660 | 16, 661 | num_classes, 662 | kernel_size=3, 663 | stride=2, 664 | padding=1, 665 | bias=False) 666 | 667 | self._initialize_weights() 668 | 669 | def forward(self, x): 670 | # Initial block 671 | input_size = x.size() 672 | x = self.initial_block(x) 673 | 674 | # Stage 1 - Encoder 675 | stage1_input_size = x.size() 676 | x, max_indices1_0 = self.downsample1_0(x) 677 | x = self.regular1_1(x) 678 | x = self.regular1_2(x) 679 | x = self.regular1_3(x) 680 | x = self.regular1_4(x) 681 | 682 | # Stage 2 - Encoder 683 | stage2_input_size = x.size() 684 | x, max_indices2_0 = self.downsample2_0(x) 685 | x = self.regular2_1(x) 686 | x = self.dilated2_2(x) 687 | x = self.asymmetric2_3(x) 688 | x = self.dilated2_4(x) 689 | x = self.regular2_5(x) 690 | x = self.dilated2_6(x) 691 | x = self.asymmetric2_7(x) 692 | x = self.dilated2_8(x) 693 | 694 | # Stage 3 - Encoder 695 | x = self.regular3_0(x) 696 | x = self.dilated3_1(x) 697 | x = self.asymmetric3_2(x) 698 | x = self.dilated3_3(x) 699 | x = self.regular3_4(x) 700 | x = self.dilated3_5(x) 701 | x = self.asymmetric3_6(x) 702 | x = self.dilated3_7(x) 703 | 704 | # Stage 4 - Decoder 705 | x = self.upsample4_0(x, max_indices2_0, output_size=stage2_input_size) 706 | x = self.regular4_1(x) 707 | x = self.regular4_2(x) 708 | 709 | # Stage 5 - Decoder 710 | x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size) 711 | x = self.regular5_1(x) 712 | x = self.transposed_conv(x, output_size=input_size) 713 | 714 | return x 715 | 716 | def _initialize_weights(self): 717 | for m in self.modules(): 718 | if isinstance(m, nn.Conv2d): 719 | #init.orthogonal_(m.weight.data, gain=init.calculate_gain('relu')) 720 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 721 | if m.bias is not None: 722 | nn.init.constant_(m.bias, 0) 723 | elif isinstance(m, nn.BatchNorm2d): 724 | nn.init.constant_(m.weight, 1) 725 | nn.init.constant_(m.bias, 0) 726 | elif isinstance(m, nn.Linear): 727 | nn.init.normal_(m.weight, 0, 0.01) 728 | nn.init.constant_(m.bias, 0) 729 | --------------------------------------------------------------------------------
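Editor's note: the C++ nodes above load a serialized TorchScript module via
import_module(), but the repository ships no export script. The following is a
minimal sketch of how a trained ENet might be exported for the single-output
segmentation node; the script name (export_enet.py), the checkpoint path, and
the output file name are assumptions, while the 480x256 input size and the
class count of 4 are taken from the C++ sources. The seg_trav and
seg_trav_path variants return two and three outputs respectively and would
need matching model definitions.

# export_enet.py -- hypothetical helper, not part of this repository.
import torch

from model import ENet


def main():
    num_classes = 4  # class count used by the seg_trav_path node; adjust to the target model
    model = ENet(num_classes)
    # model.load_state_dict(torch.load('enet_weights.pth'))  # assumed checkpoint path
    model.eval()

    # The nodes resize every input to 480x256 (width x height) before
    # inference, so trace with a dummy NCHW batch of that size.
    dummy_input = torch.randn(1, 3, 256, 480)
    traced = torch.jit.trace(model, dummy_input)
    traced.save('ENet.pt')  # pass this path to a node via the 'model_file' parameter


if __name__ == '__main__':
    main()

torch.jit.trace is used rather than torch.jit.script because ENet's forward
pass is a straight-line computation with no data-dependent control flow, which
is the case tracing handles well.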