├── CMakeLists.txt ├── README.md ├── launch └── mav.launch ├── multirotor_base.xacro ├── package.xml ├── scripts ├── History.py ├── ReinforceLearning_node.py ├── cv_trainer.py ├── model.py └── ops.py ├── src └── EnvironmentTracker_node.cpp └── srv ├── GetState.srv └── PerformAction.srv /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(rotors_reinforce) 3 | 4 | ## Add support for C++11, supported in ROS Kinetic and newer 5 | # add_definitions(-std=c++11) 6 | 7 | ## Find catkin macros and libraries 8 | ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) 9 | ## is used, also find other catkin packages 10 | find_package(catkin REQUIRED COMPONENTS 11 | gazebo_msgs 12 | gazebo_plugins 13 | geometry_msgs 14 | mav_msgs 15 | roscpp 16 | rospy 17 | std_msgs 18 | message_generation 19 | rotors_gazebo_plugins 20 | sensor_msgs 21 | std_msgs 22 | ) 23 | 24 | ## System dependencies are found with CMake's conventions 25 | # find_package(Boost REQUIRED COMPONENTS system) 26 | 27 | 28 | ## Uncomment this if the package has a setup.py. This macro ensures 29 | ## modules and global scripts declared therein get installed 30 | ## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html 31 | #catkin_python_setup() 32 | 33 | ################################################ 34 | ## Declare ROS messages, services and actions ## 35 | ################################################ 36 | 37 | ## To declare and build messages, services or actions from within this 38 | ## package, follow these steps: 39 | ## * Let MSG_DEP_SET be the set of packages whose message types you use in 40 | ## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
41 | ## * In the file package.xml: 42 | ## * add a build_depend tag for "message_generation" 43 | ## * add a build_depend and a run_depend tag for each package in MSG_DEP_SET 44 | ## * If MSG_DEP_SET isn't empty the following dependency has been pulled in 45 | ## but can be declared for certainty nonetheless: 46 | ## * add a run_depend tag for "message_runtime" 47 | ## * In this file (CMakeLists.txt): 48 | ## * add "message_generation" and every package in MSG_DEP_SET to 49 | ## find_package(catkin REQUIRED COMPONENTS ...) 50 | ## * add "message_runtime" and every package in MSG_DEP_SET to 51 | ## catkin_package(CATKIN_DEPENDS ...) 52 | ## * uncomment the add_*_files sections below as needed 53 | ## and list every .msg/.srv/.action file to be processed 54 | ## * uncomment the generate_messages entry below 55 | ## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...) 56 | 57 | ## Generate messages in the 'msg' folder 58 | # add_message_files( 59 | # FILES 60 | # Message1.msg 61 | # Message2.msg 62 | # ) 63 | 64 | ## Generate services in the 'srv' folder 65 | add_service_files(FILES PerformAction.srv GetState.srv) 66 | 67 | ## Generate actions in the 'action' folder 68 | # add_action_files( 69 | # FILES 70 | # Action1.action 71 | # Action2.action 72 | # ) 73 | 74 | ## Generate added messages and services with any dependencies listed here 75 | #generate_messages( 76 | # DEPENDENCIES 77 | # gazebo_msgs# geometry_msgs# mav_msgs# sensor_msgs 78 | # ) 79 | 80 | generate_messages(DEPENDENCIES std_msgs sensor_msgs) 81 | 82 | ################################################ 83 | ## Declare ROS dynamic reconfigure parameters ## 84 | ################################################ 85 | 86 | ## To declare and build dynamic reconfigure parameters within this 87 | ## package, follow these steps: 88 | ## * In the file package.xml: 89 | ## * add a build_depend and a run_depend tag for "dynamic_reconfigure" 90 | ## * In this file (CMakeLists.txt): 91 | ## * add
"dynamic_reconfigure" to 92 | ## find_package(catkin REQUIRED COMPONENTS ...) 93 | ## * uncomment the "generate_dynamic_reconfigure_options" section below 94 | ## and list every .cfg file to be processed 95 | 96 | ## Generate dynamic reconfigure parameters in the 'cfg' folder 97 | # generate_dynamic_reconfigure_options( 98 | # cfg/DynReconf1.cfg 99 | # cfg/DynReconf2.cfg 100 | # ) 101 | 102 | ################################### 103 | ## catkin specific configuration ## 104 | ################################### 105 | ## The catkin_package macro generates cmake config files for your package 106 | ## Declare things to be passed to dependent projects 107 | ## INCLUDE_DIRS: uncomment this if you package contains header files 108 | ## LIBRARIES: libraries you create in this project that dependent projects also need 109 | ## CATKIN_DEPENDS: catkin_packages dependent projects also need 110 | ## DEPENDS: system dependencies of this project that dependent projects also need 111 | catkin_package( 112 | # INCLUDE_DIRS include 113 | # LIBRARIES rotors_reinforce 114 | CATKIN_DEPENDS gazebo_msgs gazebo_plugins geometry_msgs mav_msgs roscpp rospy rotors_gazebo_plugins sensor_msgs std_msgs message_runtime # message_runtime is required because this package generates services (add_service_files above) 115 | # DEPENDS system_lib 116 | ) 117 | 118 | ########### 119 | ## Build ## 120 | ########### 121 | 122 | ## Specify additional locations of header files 123 | ## Your package locations should be listed before other locations 124 | # include_directories(include) 125 | include_directories( 126 | ${catkin_INCLUDE_DIRS} 127 | ) 128 | 129 | ## Declare a C++ library 130 | # add_library(${PROJECT_NAME} 131 | # src/${PROJECT_NAME}/rotors_reinforce.cpp 132 | # ) 133 | 134 | ## Add cmake target dependencies of the library 135 | ## as an example, code may need to be generated before libraries 136 | ## either from message generation or dynamic reconfigure 137 | #add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) 138 | 139 | ## Declare a C++ executable 140 | ##
With catkin_make all packages are built within a single CMake context 141 | ## The recommended prefix ensures that target names across packages don't collide 142 | add_executable(EnvironmentTracker_node src/EnvironmentTracker_node.cpp) 143 | 144 | ## Rename C++ executable without prefix 145 | ## The above recommended prefix causes long target names, the following renames the 146 | ## target back to the shorter version for ease of user use 147 | ## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" 148 | # set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") 149 | 150 | ## Add cmake target dependencies of the executable 151 | ## same as for the library above 152 | add_dependencies(EnvironmentTracker_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) 153 | 154 | ## Specify libraries to link a library or executable target against 155 | target_link_libraries(EnvironmentTracker_node 156 | ${catkin_LIBRARIES} 157 | ) 158 | 159 | ############# 160 | ## Install ## 161 | ############# 162 | 163 | # all install targets should use catkin DESTINATION variables 164 | # See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html 165 | 166 | ## Mark executable scripts (Python etc.)
for installation 167 | ## in contrast to setup.py, you can choose the destination 168 | catkin_install_python(PROGRAMS 169 | scripts/ReinforceLearning_node.py 170 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 171 | ) 172 | 173 | ## Mark executables and/or libraries for installation (install the C++ environment node alongside the Python node above) 174 | install(TARGETS EnvironmentTracker_node 175 | ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} 176 | LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} 177 | RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 178 | ) 179 | 180 | ## Mark cpp header files for installation 181 | # install(DIRECTORY include/${PROJECT_NAME}/ 182 | # DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} 183 | # FILES_MATCHING PATTERN "*.h" 184 | # PATTERN ".svn" EXCLUDE 185 | # ) 186 | 187 | ## Mark other files for installation (e.g. launch and bag files, etc.) 188 | # install(FILES 189 | # # myfile1 190 | # # myfile2 191 | # DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} 192 | # ) 193 | 194 | ############# 195 | ## Testing ## 196 | ############# 197 | 198 | ## Add gtest based cpp test target and link libraries 199 | # catkin_add_gtest(${PROJECT_NAME}-test test/test_rotors_reinforce.cpp) 200 | # if(TARGET ${PROJECT_NAME}-test) 201 | # target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) 202 | # endif() 203 | 204 | ## Add folders to be run by python nosetests 205 | # catkin_add_nosetests(test) 206 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deep-reinforcement-learning-drone-control 2 | This is a deep reinforcement learning based drone control system implemented in python (Tensorflow/ROS) and C++ (ROS). To test it, please clone the rotors simulator from https://github.com/ethz-asl/rotors_simulator into your catkin workspace. Copy multirotor_base.xacro into the rotors simulator to add the camera to the drone.
3 | 4 | The drone control system operates on camera images as input and a discretized version of the steering commands as output. The neural network model is end-to-end and a non-asynchronous implementation of the A3C model (https://arxiv.org/pdf/1602.01783.pdf), because the gazebo simulator is not capable of running multiple copies in parallel (and neither is my laptop :D). The training is performed on the basis of pretrained weights from a supervised learning task, since the simulator is very resource intensive and training is time consuming. 5 | 6 | The outcome was discussed in a practical course at RWTH Aachen, where this agent served as a proof of concept that an end-to-end deep reinforcement learning model can be trained efficiently to control a drone in a realistic 3D environment. 7 | -------------------------------------------------------------------------------- /launch/mav.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /multirotor_base.xacro: -------------------------------------------------------------------------------- 1 | 2 | 21 | 22 | 23 | 24 | 25 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | ${robot_namespace} 80 | ${robot_namespace}/base_link 81 | ${rotor_velocity_slowdown_sim} 82 | 83 | 84 | 85 | 86 | 10.0 87 | 1 88 | 89 | ${robot_namespace}/base_link_fixed_joint_lump__base_link_collision 90 | 91 | 92 | true 93 | 10.0 94 | /base_collision 95 | world 96 | 97 | 98 | 99 | 100 | 101 | 102 | 104 | 105 | 106 | 107
| 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 123 | 124 | 125 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | ${robot_namespace} 140 | ${robot_namespace}/rotor_${motor_number}_joint 141 | ${robot_namespace}/rotor_${motor_number} 142 | ${direction} 143 | ${time_constant_up} 144 | ${time_constant_down} 145 | ${max_rot_velocity} 146 | ${motor_constant} 147 | ${moment_constant} 148 | gazebo/command/motor_speed 149 | ${motor_number} 150 | ${rotor_drag_coefficient} 151 | ${rolling_moment_coefficient} 152 | motor_speed/${motor_number} 153 | ${rotor_velocity_slowdown_sim} 154 | 155 | 156 | 157 | Gazebo/${color} 158 | 159 | 10.0 160 | 1 161 | 162 | ${robot_namespace}/rotor_${motor_number}_collision 163 | 164 | 165 | true 166 | 10.0 167 | /rotor_collision 168 | world 169 | 170 | 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | rotors_reinforce 4 | 0.0.0 5 | The Reinforcement Learning package 6 | 7 | 8 | 9 | 10 | fischer_t 11 | 12 | 13 | 14 | 15 | 16 | TODO 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | catkin 43 | message_generation 44 | gazebo_msgs 45 | gazebo_plugins 46 | geometry_msgs 47 | mav_msgs 48 | roscpp 49 | rospy 50 | rotors_gazebo_plugins 51 | sensor_msgs 52 | std_msgs 53 | 54 | message_runtime 55 | gazebo_msgs 56 | gazebo_plugins 57 | geometry_msgs 58 | mav_msgs 59 | roscpp 60 | rospy 61 | rotors_gazebo_plugins 62 | sensor_msgs 63 | std_msgs 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /scripts/History.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy as np 4 | 5 | class History: 6 | def 
__init__(self, actionNumber, actionSize): 7 | self.states = [] 8 | self.actions = [] 9 | self.rewards = [] 10 | self.responses = [] 11 | self.actionNumber = actionNumber 12 | self.actionSize = actionSize 13 | 14 | def addAction2History(self, action): 15 | 16 | self.actions.append(action) 17 | 18 | def getActions(self): 19 | return self.actions 20 | 21 | def getLastAction(self): 22 | assert (len(self.actions) > 0), "Action history is empty!" 23 | return self.actions[-1] 24 | 25 | def addState2History(self, state): 26 | self.states.append(state) 27 | 28 | def addResponse2History(self, response): 29 | self.responses.append(response) 30 | 31 | def getStates(self): 32 | return self.states[:-1] 33 | 34 | def getResponses(self): 35 | return self.responses[:-1] 36 | 37 | def getState(self, iterator): 38 | assert (len(self.states) > 0), "State history is empty!" 39 | return np.expand_dims(self.states[iterator], 0) 40 | 41 | def getLastState(self): 42 | assert (len(self.states) > 0), "State history is empty!" 43 | return np.expand_dims(self.states[-1], 0) 44 | 45 | def addReward2History(self, reward): 46 | self.rewards.append(reward) 47 | 48 | def getRewardHistory(self): 49 | return self.rewards 50 | 51 | def getLastReward(self): 52 | assert (len(self.rewards) > 0), "Reward history is empty!" 
53 | return self.rewards[-1] 54 | 55 | def sumRewards(self): 56 | return sum(self.rewards) 57 | 58 | def clean(self): 59 | self.states = [] 60 | self.actions = [] 61 | self.rewards = [] 62 | self.responses = [] -------------------------------------------------------------------------------- /scripts/ReinforceLearning_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import rospy 4 | from rotors_reinforce.srv import PerformAction 5 | from rotors_reinforce.srv import GetState 6 | from History import * 7 | from model import ComputationalGraph 8 | 9 | # ROS Image message -> OpenCV2 image converter 10 | from cv_bridge import CvBridge, CvBridgeError 11 | # OpenCV2 for saving an image 12 | import cv2 13 | 14 | import tensorflow as tf 15 | import random 16 | import numpy as np 17 | 18 | TRAIN_MODEL = 1 19 | 20 | EPISODE_LENGTH = 120 21 | 22 | EPSILON = 0.9 23 | EPSILON_DECAY = 0.97 24 | EPSILON_MIN = 0.01 25 | EPSILON_STEP = 10 26 | 27 | ACTION_NUM = 3 28 | ACTION_SIZE = 3 29 | IMG_HEIGHT = 75 30 | IMG_WIDTH = 75 31 | IMG_CHANNELS = 2 32 | RESPAWN_CODE = [0, 0, 0, 42] 33 | 34 | bridge = CvBridge() 35 | 36 | 37 | def convertImage(img): 38 | try: 39 | # Convert your ROS Image message to OpenCV2 40 | cv2_img = bridge.imgmsg_to_cv2(img, "mono8") 41 | except CvBridgeError, e: 42 | print(e) 43 | 44 | # Format for the Tensor 45 | cv2_img = cv2.resize(cv2_img, (IMG_WIDTH, IMG_HEIGHT)) 46 | 47 | return cv2_img 48 | 49 | 50 | def convertState(state, old_state): 51 | converted_state = np.concatenate((np.expand_dims(convertImage(state.img), -1), np.expand_dims(convertImage(old_state.img), -1)), axis=2) 52 | return converted_state #return 2 channel image consisting of 2 states 53 | #return np.expand_dims(convertImage(state.img), -1) 54 | 55 | def chooseAction(probabilities, e, is_training): 56 | action = np.zeros(ACTION_NUM) 57 | 58 | for i in range(len(probabilities[0])): 59 | 60 | if is_training and 
random.uniform(0, 1) < e: 61 | action[i] = random.randint(0, 2) 62 | else: 63 | action[i] = np.argmax(probabilities[0][i]) 64 | 65 | return action 66 | 67 | 68 | 69 | def reinforce_node(): 70 | 71 | #set up env 72 | rospy.init_node('ReinforceLearning', anonymous=True) 73 | perform_action_client = rospy.ServiceProxy('env_tr_perform_action', PerformAction) 74 | get_state_client = rospy.ServiceProxy('env_tr_get_state', GetState) 75 | 76 | history = History(ACTION_NUM, ACTION_SIZE) 77 | 78 | graph = ComputationalGraph(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) 79 | 80 | action = np.zeros(ACTION_NUM) 81 | executed_action = np.zeros(ACTION_NUM + 1) 82 | 83 | sess = tf.Session() 84 | 85 | graph.constructGraph(sess, ACTION_NUM, ACTION_SIZE) 86 | 87 | # restoring agent 88 | saver = tf.train.Saver() 89 | try: 90 | saver.restore(sess, "./log/model.ckpt") 91 | print "model restored" 92 | except: 93 | print "model restore failed. random initialization" 94 | init_op = tf.global_variables_initializer() 95 | sess.run(init_op) 96 | 97 | # restore pretrained layers for random init 98 | vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='CV_graph') 99 | cv_saver = tf.train.Saver(vars) 100 | cv_saver.restore(sess, "./cv_graph/cv_graph.cptk") 101 | print "pretrained weights restored" 102 | 103 | 104 | step = 0 105 | e = EPSILON 106 | 107 | # main loop: 108 | while not rospy.is_shutdown(): 109 | crashed_flag = False 110 | 111 | #get initial state 112 | rospy.wait_for_service('env_tr_get_state') 113 | try: 114 | response = get_state_client() 115 | except rospy.ServiceException, e: 116 | print "Service call failed: %s" % e 117 | 118 | 119 | old_response = response 120 | 121 | state = convertState(response, old_response) 122 | history.clean() 123 | history.addState2History(state) 124 | 125 | print response.target_position 126 | 127 | episode_step = 0 128 | # run episode, while not crashed and simulation is running 129 | while not crashed_flag and not rospy.is_shutdown(): 130 | #get 
most probable variant to act for each action, and the probabilities 131 | probabilities = graph.calculateAction(sess, history.getLastState()) 132 | 133 | #choose action according to softmax distribution (add epsilon randomness in training) 134 | action = chooseAction(probabilities, e, TRAIN_MODEL) 135 | 136 | #choose roll, pitch and thrust according to network output 137 | executed_action[0] = (float(action[0]) - float(ACTION_SIZE) / 2) * 2.0 / float(ACTION_SIZE - 1) 138 | executed_action[1] = (float(action[1]) - float(ACTION_SIZE) / 2) * 2.0 / float(ACTION_SIZE - 1) 139 | #executed_action[2] = (float(action[2]) - float(ACTION_SIZE) / 2) 140 | executed_action[2] = 0. #we skip the yaw command dimesion 141 | executed_action[3] = float(action[2]) / float(ACTION_SIZE - 1) 142 | 143 | rospy.wait_for_service('env_tr_perform_action') 144 | try: 145 | response = perform_action_client(executed_action) 146 | except rospy.ServiceException, e: 147 | print "Service call failed: %s" % e 148 | 149 | state = convertState(response, old_response) 150 | old_response = response 151 | 152 | #update history 153 | actionmatrix = np.zeros([ACTION_NUM, ACTION_SIZE]) 154 | for i in xrange(len(actionmatrix)): 155 | actionmatrix[i][int(action[i])] = 1 156 | 157 | history.addAction2History(actionmatrix) 158 | history.addState2History(state) 159 | history.addResponse2History(response) 160 | history.addReward2History(response.reward) 161 | 162 | crashed_flag = response.crashed 163 | 164 | episode_step+=1 165 | if episode_step >= EPISODE_LENGTH: 166 | rospy.wait_for_service('env_tr_perform_action') 167 | try: 168 | response = perform_action_client(RESPAWN_CODE) 169 | except rospy.ServiceException, e: 170 | print "Service call failed: %s" % e 171 | break 172 | 173 | # update policy 174 | if TRAIN_MODEL == 1: 175 | graph.updatePolicy(sess, history, step) 176 | 177 | #save every 50 episodes 178 | if TRAIN_MODEL == 1 and step % 50 == 0: 179 | save_path = saver.save(sess, "log/model.ckpt") 180 | 
print("Model saved in file: %s" % save_path) 181 | 182 | if step % EPSILON_STEP == 0: 183 | e = e * EPSILON_DECAY 184 | 185 | print "episode number: ", step 186 | print "total reward: ", history.sumRewards() 187 | step += 1 188 | 189 | 190 | if __name__ == '__main__': 191 | try: 192 | print("starting...") 193 | reinforce_node() 194 | except rospy.ROSInterruptException: 195 | pass 196 | 197 | -------------------------------------------------------------------------------- /scripts/cv_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # ROS Image message -> OpenCV2 image converter 4 | from cv_bridge import CvBridge, CvBridgeError 5 | import matplotlib.pyplot as plt 6 | # OpenCV2 for saving an image 7 | import cv2 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | import pickle 13 | 14 | from ops import conv2d, dense 15 | 16 | 17 | IMG_WIDTH = 75 18 | IMG_HEIGHT = 75 19 | CHANNELS = 2 20 | OUTPUT_SIZE = 3 21 | 22 | NUM_EPOCHS = 40 23 | BATCH_SIZE = 128 24 | LEARNING_RATE = 0.0001 25 | DROPOUT_KEEP_PROBABILITY = 0.5 26 | 27 | FILENAME_TRAINING_DATA = "data_training" 28 | FILENAME_TEST_DATA = "data_test" 29 | 30 | SAVE_PATH = "./cv_graph/" 31 | if not os.path.exists(SAVE_PATH): 32 | os.makedirs(SAVE_PATH) 33 | 34 | class TrainingData: 35 | def __init__(self): 36 | self.img = [] 37 | self.coordinate = [] 38 | self.orientation = [] 39 | 40 | def add(self, im, c, o): 41 | self.img.append(im) 42 | self.coordinate.append(c) 43 | self.orientation.append(o) 44 | 45 | 46 | class CVGraph: 47 | def __init__(self, h, w, c, o): 48 | self.screen_height = h 49 | self.screen_width = w 50 | self.channels = c 51 | self.output_size = o 52 | self.orientation_size = 4 53 | self.cnn_output_size = 3 54 | 55 | def buildGraph(self): 56 | with tf.variable_scope("CV_graph"): 57 | self.input = tf.placeholder('float32', 58 | [None, self.screen_height, self.screen_width, self.channels], name='input') 59 | 60 | 
initializer = tf.contrib.layers.xavier_initializer() 61 | activation_fn = tf.nn.relu 62 | 63 | self.conv1, self.conv1_w, self.conv1_b = conv2d(self.input, 64 | 16, 65 | [5, 5], 66 | [1, 1], 67 | initializer, 68 | activation_fn, 69 | 'NHWC', 70 | name='conv1') 71 | 72 | self.conv2, self.conv2_w, self.conv2_b = conv2d(self.conv1, 73 | 16, 74 | [5, 5], 75 | [1, 1], 76 | initializer, 77 | activation_fn, 78 | 'NHWC', 79 | name='conv2') 80 | 81 | self.max_pool1 = tf.nn.max_pool(self.conv2, 82 | [1, 2, 2, 1], 83 | [1, 2, 2, 1], 84 | padding='VALID', 85 | name='max_pool1') 86 | 87 | self.conv3, self.conv3_w, self.conv3_b = conv2d(self.max_pool1, 88 | 32, 89 | [3, 3], 90 | [1, 1], 91 | initializer, 92 | activation_fn, 93 | 'NHWC', 94 | name='conv3') 95 | 96 | self.conv4, self.conv4_w, self.conv4_b = conv2d(self.conv3, 97 | 32, 98 | [3, 3], 99 | [1, 1], 100 | initializer, 101 | activation_fn, 102 | 'NHWC', 103 | name='conv4') 104 | 105 | self.max_pool2 = tf.nn.max_pool(self.conv4, 106 | [1, 2, 2, 1], 107 | [1, 2, 2, 1], 108 | padding='VALID', 109 | name='max_pool2') 110 | 111 | self.flat = tf.contrib.layers.flatten(self.max_pool2) 112 | 113 | 114 | self.keep_probability = tf.placeholder('float32', name='keep_probability') 115 | 116 | self.dense1, self.dense1_w, self.dense1_b = dense(inputs=self.flat, units=100, 117 | activation=tf.nn.relu, name='dense1') 118 | self.dropout1 = tf.nn.dropout(self.dense1, self.keep_probability) 119 | self.dense2, self.dense2_w, self.dense2_b = dense(inputs=self.dropout1, units=50, 120 | activation=tf.nn.relu, name='dense2') 121 | self.dropout2 = tf.nn.dropout(self.dense2, self.keep_probability) 122 | 123 | self.dense3, self.dense3_w, self.dense3_b = dense(inputs=self.dropout2, 124 | units=10, activation=tf.nn.relu, 125 | name='dense3') 126 | 127 | self.output, self.output_w, self.output_b = dense(inputs=self.dense3, 128 | units=self.output_size, activation=None, 129 | name='output') 130 | 131 | self.ground_truth = tf.placeholder('float32', 
[None, self.output_size], name='ground_truth') 132 | self.loss = tf.reduce_mean(tf.square(tf.subtract(self.output, self.ground_truth))) 133 | self.global_step = tf.Variable(0, trainable=False) 134 | 135 | self.optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(self.loss) 136 | 137 | def train(self, sess, img, labels, step): 138 | _, loss = sess.run([self.optimizer, self.loss], feed_dict={self.input: img, 139 | self.ground_truth: labels, 140 | self.global_step: step, 141 | self.keep_probability: DROPOUT_KEEP_PROBABILITY}) 142 | return loss 143 | 144 | def prediction(self, sess, img): 145 | output = sess.run(self.output, feed_dict={self.input: img, 146 | self.keep_probability: 1.0}) 147 | 148 | return output 149 | 150 | def validate(self, sess, img, labels, step): 151 | loss = sess.run(self.loss, feed_dict={self.input: img, 152 | self.ground_truth: labels, 153 | self.global_step: step, 154 | self.keep_probability: 1.0}) 155 | return loss 156 | 157 | 158 | def convertImage(img): 159 | bridge = CvBridge() 160 | 161 | try: 162 | # Convert your ROS Image message to OpenCV2 163 | cv2_img = bridge.imgmsg_to_cv2(img, "mono8") 164 | except CvBridgeError, e: 165 | print(e) 166 | 167 | cv2_img = cv2.resize(cv2_img, (IMG_WIDTH, IMG_HEIGHT)) 168 | 169 | return cv2_img 170 | 171 | def merge(image, label, old_image, old_label): 172 | merged_image = np.concatenate((np.expand_dims(image, -1), np.expand_dims(old_image, -1)), axis=2) 173 | merged_label = label - old_label 174 | return merged_image, merged_label #return 2 channel image consisting of 2 states 175 | 176 | def process(data): 177 | for i in range(len(data.img)): 178 | if (i+1 >= len(data.img)): 179 | merged_image, merged_label = merge(data.img[i], data.coordinate[i], data.img[i], data.coordinate[i]) 180 | data.img[i] = merged_image 181 | data.coordinate[i] = merged_label 182 | else: 183 | merged_image, merged_label = merge(data.img[i+1], data.coordinate[i+1], data.img[i], data.coordinate[i]) 184 | data.img[i] = 
merged_image 185 | data.coordinate[i] = merged_label 186 | 187 | return data 188 | 189 | def sampleData(data, index, sample_size): 190 | indexes = range(index, index + sample_size) 191 | images = np.array([data.img[i] for i in indexes]) 192 | labels = np.array([[float(data.coordinate[i][0]), float(data.coordinate[i][1]), float(data.coordinate[i][2])] for i in indexes]) 193 | 194 | return images, labels 195 | 196 | 197 | def visualize(data): 198 | 199 | for i in range(len(data.img)): 200 | print i, data.coordinate[i], data.orientation[i] 201 | plt.imshow(data.img[i], 'gray'), plt.title(str(i)) 202 | plt.show() 203 | 204 | 205 | def train_net(): 206 | global sess 207 | sess = tf.Session() 208 | 209 | graph = CVGraph(IMG_HEIGHT, IMG_WIDTH, CHANNELS, OUTPUT_SIZE) 210 | graph.buildGraph() 211 | 212 | init_op = tf.global_variables_initializer() 213 | sess.run(init_op) 214 | 215 | saver = tf.train.Saver() 216 | 217 | #data containing current camera image, distance to target, orientation 218 | data = pickle.load(open(FILENAME_TRAINING_DATA, "rb")) 219 | test_data = pickle.load(open(FILENAME_TEST_DATA, "rb")) 220 | 221 | data = process(data) 222 | test_data = process(test_data) 223 | 224 | global_step = 1 225 | 226 | for epoch in range(NUM_EPOCHS): 227 | print "epoch ", epoch 228 | avg_loss = 0 229 | num_batches = int(len(data.img)/(BATCH_SIZE)) 230 | index = 0 231 | for i in range(num_batches): 232 | images, labels = sampleData(data, index, BATCH_SIZE) 233 | loss = graph.train(sess, images, labels, global_step) 234 | avg_loss += loss 235 | global_step += 1 236 | index += BATCH_SIZE 237 | 238 | avg_loss = avg_loss / float(num_batches) 239 | 240 | saver.save(sess, SAVE_PATH + 'cv_graph.cptk') 241 | test_loss = test_net(sess, graph, global_step, test_data) 242 | 243 | print "validation Loss: ", test_loss 244 | print "training Loss: ", avg_loss 245 | 246 | 247 | def test_net(sess, graph, global_step, test_data): 248 | num_batches = int(len(test_data.img) / BATCH_SIZE) 249 | 
index = 0 250 | loss = 0 251 | for i in range(num_batches): 252 | images, labels = sampleData(test_data, index, BATCH_SIZE) 253 | loss += graph.validate(sess, images, labels, global_step) 254 | index += BATCH_SIZE 255 | 256 | return loss/num_batches 257 | 258 | 259 | if __name__ == '__main__': 260 | print("starting...") 261 | train_net() 262 | -------------------------------------------------------------------------------- /scripts/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from ops import conv2d, dense, variable_summaries 4 | 5 | class ComputationalGraph: 6 | def __init__(self, img_h, img_w, img_c): 7 | self.image_height, self.image_width, self.image_channels = img_h, img_w, img_c 8 | 9 | #learning rates suggested in A3C paper 10 | self.lrate = 2.5e-3 11 | self.learning_rate_minimum = 5.0e-4 12 | 13 | self.global_step = tf.Variable(0, trainable=False) 14 | self.decay_step = 40 15 | self.dropout_keep_prob = 0.5 16 | 17 | def constructGraph(self, sess, action_num, action_size): 18 | self.episode_rewards = tf.placeholder(tf.float32, [None], name="episode_rewards") 19 | self.state = tf.placeholder(tf.float32, [None, self.image_height, self.image_width, self.image_channels], name="state") 20 | 21 | ##ConvNet for feature extraction 22 | with tf.variable_scope("CV_graph"): 23 | 24 | initializer = tf.contrib.layers.xavier_initializer() 25 | activation_fn = tf.nn.relu 26 | 27 | self.conv1, self.conv1_w, self.conv1_b = conv2d(self.state, 28 | 16, 29 | [5, 5], 30 | [1, 1], 31 | initializer, 32 | activation_fn, 33 | 'NHWC', 34 | name='conv1') 35 | 36 | 37 | self.conv2, self.conv2_w, self.conv2_b = conv2d(self.conv1, 38 | 16, 39 | [5, 5], 40 | [1, 1], 41 | initializer, 42 | activation_fn, 43 | 'NHWC', 44 | name='conv2') 45 | 46 | self.max_pool1 = tf.nn.max_pool(self.conv2, 47 | [1, 2, 2, 1], 48 | [1, 2, 2, 1], 49 | padding='VALID', 50 | name='max_pool1') 51 | 52 | self.conv3, 
self.conv3_w, self.conv3_b = conv2d(self.max_pool1, 53 | 32, 54 | [3, 3], 55 | [1, 1], 56 | initializer, 57 | activation_fn, 58 | 'NHWC', 59 | name='conv3') 60 | 61 | self.conv4, self.conv4_w, self.conv4_b = conv2d(self.conv3, 62 | 32, 63 | [3, 3], 64 | [1, 1], 65 | initializer, 66 | activation_fn, 67 | 'NHWC', 68 | name='conv4') 69 | 70 | self.max_pool2 = tf.nn.max_pool(self.conv4, 71 | [1, 2, 2, 1], 72 | [1, 2, 2, 1], 73 | padding='VALID', 74 | name='max_pool2') 75 | 76 | self.flat = tf.contrib.layers.flatten(self.max_pool2) 77 | 78 | self.dense1, self.dense1_w, self.dense1_b = dense(inputs=self.flat, units=100, 79 | activation=tf.nn.relu, name='dense1') 80 | self.dropout1 = tf.nn.dropout(self.dense1, self.dropout_keep_prob) 81 | self.dense2, self.dense2_w, self.dense2_b = dense(inputs=self.dropout1, units=50, 82 | activation=tf.nn.relu, name='dense2') 83 | self.dropout2 = tf.nn.dropout(self.dense2, self.dropout_keep_prob) 84 | 85 | with tf.variable_scope("Agent"): 86 | ##policy network 87 | 88 | self.po_dense3, self.po_dense3_w, self.po_dense3_b = dense(inputs=self.dropout2, units=10, 89 | activation=tf.nn.relu, name='po_dense3') 90 | 91 | self.po_dense4, self.po_dense4_w, self.po_dense4_b = dense(inputs=self.po_dense3, units=action_num*action_size, activation=None, name='po_dense4') 92 | self.po_probabilities = tf.nn.softmax(tf.reshape(self.po_dense4, [-1, action_num, action_size])) 93 | 94 | self.po_prev_actions = tf.placeholder(tf.float32, [None, action_num, action_size], name="po_prev_action") 95 | self.po_return = tf.placeholder(tf.float32, [None, 1], name="po_return") 96 | self.po_eligibility = tf.log(tf.reduce_sum(tf.multiply(self.po_prev_actions, self.po_probabilities), axis=-1)) * self.po_return 97 | self.po_loss = -tf.reduce_sum(self.po_eligibility) 98 | 99 | ##value network 100 | 101 | self.v_dense3, self.v_dense3_w, self.v_dense3_b = dense(inputs=self.dropout2, units=10, activation=tf.nn.relu, name='v_dense3') 102 | 103 | self.v_output, 
self.v_output_w, self.v_output_b = dense(inputs=self.v_dense3, activation=None, units=1, name='v_output') 104 | 105 | self.v_actual_return = tf.placeholder(tf.float32, [None, 1], name="v_actual_return") 106 | self.v_loss = tf.nn.l2_loss(tf.subtract(self.v_output, self.v_actual_return)) 107 | 108 | ##optimization 109 | 110 | self.learning_rate_op = tf.maximum(self.learning_rate_minimum, 111 | tf.train.exponential_decay( 112 | self.lrate, 113 | self.global_step, 114 | self.decay_step, 115 | 0.98, 116 | staircase=True)) 117 | 118 | self.loss = 0.5 * self.v_loss + self.po_loss 119 | 120 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate_op).minimize(self.loss) 121 | 122 | self.constructSummary(sess) 123 | 124 | def constructSummary(self, sess): 125 | variable_summaries(self.episode_rewards) 126 | self.merged = tf.summary.merge_all() 127 | self.train_writer = tf.summary.FileWriter('./log/train', sess.graph) 128 | 129 | 130 | def calculateAction(self, sess, state): 131 | return sess.run(self.po_probabilities, feed_dict={self.state: state}) 132 | 133 | def calculateReward(self, sess, state): 134 | reward = sess.run(self.v_output, feed_dict={self.state: state}) 135 | return reward[0][0] 136 | 137 | def updatePolicy(self, sess, history, step): 138 | rewards = history.getRewardHistory() 139 | advantages = [] 140 | update_vals = [] 141 | episode_reward = 0 142 | 143 | for i, reward in enumerate(rewards): 144 | episode_reward += reward 145 | future_reward = 0 146 | future_transitions = len(rewards) - i 147 | decrease = 1 148 | for index2 in xrange(future_transitions): 149 | future_reward += rewards[(index2) + i] * decrease 150 | decrease = decrease * 0.98 151 | 152 | prediction = self.calculateReward(sess, history.getState(i)) 153 | advantages.append(future_reward - prediction) 154 | update_vals.append(future_reward) 155 | 156 | statistics, _ = sess.run([self.merged, self.optimizer], feed_dict={self.state: history.getStates(), 157 | self.po_prev_actions: 
history.getActions(), 158 | self.po_return: np.expand_dims(advantages, axis=1), 159 | self.episode_rewards: history.getRewardHistory(), 160 | self.v_actual_return: np.expand_dims(update_vals, axis=1), 161 | self.global_step: step 162 | }) 163 | 164 | self.train_writer.add_summary(statistics, step) 165 | self.train_writer.flush() -------------------------------------------------------------------------------- /scripts/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.layers.python.layers import initializers 3 | 4 | def clipped_error(x): 5 | # Huber loss 6 | try: 7 | return tf.select(tf.abs(x) < 1.0, 0.5 * tf.square(x), tf.abs(x) - 0.5) 8 | except: 9 | return tf.where(tf.abs(x) < 1.0, 0.5 * tf.square(x), tf.abs(x) - 0.5) 10 | 11 | 12 | def variable_summaries(var): 13 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" 14 | with tf.name_scope('summaries'): 15 | mean = tf.reduce_mean(var) 16 | tf.summary.scalar('mean', mean) 17 | with tf.name_scope('stddev'): 18 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 19 | tf.summary.scalar('stddev', stddev) 20 | tf.summary.scalar('max', tf.reduce_max(var)) 21 | tf.summary.scalar('min', tf.reduce_min(var)) 22 | tf.summary.histogram('histogram', var) 23 | tf.summary.scalar('sum', tf.reduce_sum(var)) 24 | 25 | def conv2d(x, 26 | output_dim, 27 | kernel_size, 28 | stride, 29 | initializer=tf.contrib.layers.xavier_initializer(), 30 | activation_fn=tf.nn.relu, 31 | data_format='NHWC', 32 | padding='VALID', 33 | name='conv2d'): 34 | with tf.variable_scope(name): 35 | if data_format == 'NCHW': 36 | stride = [1, 1, stride[0], stride[1]] 37 | kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[1], output_dim] 38 | elif data_format == 'NHWC': 39 | stride = [1, stride[0], stride[1], 1] 40 | kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[-1], output_dim] 41 | 42 | w = 
tf.get_variable('w', kernel_shape, tf.float32, initializer=initializer) 43 | conv = tf.nn.conv2d(x, w, stride, padding, data_format=data_format) 44 | 45 | b = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 46 | out = tf.nn.bias_add(conv, b, data_format) 47 | 48 | if activation_fn != None: 49 | out = activation_fn(out) 50 | 51 | return out, w, b 52 | 53 | def dense(inputs, activation, units, initializer=tf.contrib.layers.xavier_initializer(), name='dense'): 54 | shape = inputs.get_shape().as_list() 55 | 56 | with tf.variable_scope(name): 57 | w = tf.get_variable('weights', [shape[1], units], tf.float32, 58 | initializer) 59 | b = tf.get_variable('bias', [units], 60 | initializer=tf.constant_initializer(0.0)) 61 | 62 | out = tf.nn.bias_add(tf.matmul(inputs, w), b) 63 | 64 | if activation != None: 65 | return activation(out), w, b 66 | else: 67 | return out, w, b 68 | -------------------------------------------------------------------------------- /src/EnvironmentTracker_node.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "ros/ros.h" 11 | #include "mav_msgs/default_topics.h" 12 | #include "mav_msgs/RollPitchYawrateThrust.h" 13 | #include "mav_msgs/Actuators.h" 14 | #include "mav_msgs/eigen_mav_msgs.h" 15 | #include "geometry_msgs/Pose.h" 16 | #include "std_srvs/Empty.h" 17 | #include "gazebo_msgs/GetModelState.h" 18 | #include "gazebo_msgs/SetModelState.h" 19 | #include "gazebo_msgs/ContactsState.h" 20 | #include "rotors_comm/WindSpeed.h" 21 | #include "gazebo_msgs/ModelState.h" 22 | 23 | #include 24 | 25 | #include "rotors_reinforce/PerformAction.h" 26 | #include "rotors_reinforce/GetState.h" 27 | 28 | #include 29 | #include 30 | 31 | const double max_v_xy = 1.0; // [m/s] 32 | const double max_roll = 10.0 * M_PI / 180.0; // [rad] 33 | const double max_pitch = 10.0 * M_PI / 180.0; // 
[rad] 34 | const double max_rate_yaw = 45.0 * M_PI / 180.0; // [rad/s] 35 | const double max_thrust = 30.0; // [N] 36 | 37 | const double MAX_WIND_VELOCTIY = 3.4; // meter per second --> max 12km/h (equals wind force 3) 38 | 39 | const double axes_roll_direction = -1.0; 40 | const double axes_pitch_direction = 1.0; 41 | const double axes_thrust_direction = 1.0; 42 | 43 | class environmentTracker { 44 | 45 | private: 46 | ros::NodeHandle n; 47 | ros::Publisher firefly_control_pub; 48 | ros::Publisher firefly_wind_pub; 49 | 50 | ros::ServiceClient firefly_reset_client; 51 | ros::ServiceClient get_position_client; 52 | ros::ServiceClient pause_physics; 53 | ros::ServiceClient unpause_physics; 54 | ros::ServiceClient set_state; 55 | 56 | 57 | ros::Subscriber firefly_position_sub; 58 | ros::Subscriber firefly_collision_sub; 59 | ros::Subscriber firefly_ground_collision_sub; 60 | ros::Subscriber firefly_camera_sub; 61 | ros::Subscriber firefly_camera_depth_sub; 62 | ros::ServiceServer perform_action_srv; 63 | ros::ServiceServer get_state_srv; 64 | 65 | std::default_random_engine re; 66 | 67 | sensor_msgs::Image current_img; 68 | sensor_msgs::Image current_img_depth; 69 | 70 | int step_counter; 71 | double current_yaw_vel_; 72 | bool crashed_flag; 73 | bool random_target; 74 | bool enable_wind; 75 | 76 | public: 77 | std::vector current_position; 78 | std::vector current_orientation; 79 | std::vector current_control_params; 80 | 81 | std::vector target_position; 82 | 83 | environmentTracker(ros::NodeHandle node, const std::vector target_pos, bool wind ) { 84 | current_position.resize(3); 85 | //hard constants for target 86 | if (target_pos[0] != 0.0 || target_pos[1] != 0.0 || target_pos[2] != 0.0) { 87 | target_position = target_pos; 88 | random_target = false; 89 | } 90 | else { 91 | target_position = {3.0, 1.0, 7.5}; 92 | random_target = true; 93 | } 94 | enable_wind = wind; 95 | current_orientation.resize(4); 96 | current_control_params.resize(4, 0); 97 | step_counter 
= 0; 98 | 99 | current_yaw_vel_ = 0.0; 100 | 101 | n = node; 102 | firefly_control_pub = n.advertise("/firefly/command/roll_pitch_yawrate_thrust", 1000); 103 | firefly_wind_pub = n.advertise("/firefly/wind_speed", 1000); 104 | firefly_collision_sub = n.subscribe("/rotor_collision", 100, &environmentTracker::onCollision, this); 105 | firefly_ground_collision_sub = n.subscribe("/base_collision", 100, &environmentTracker::onCollisionGround, this); 106 | firefly_reset_client = n.serviceClient("/gazebo/reset_world"); 107 | firefly_camera_sub = n.subscribe("/firefly/vi_sensor/camera_depth/camera/image_raw", 1, &environmentTracker::getImage, this); 108 | //firefly_camera_depth_sub = n.subscribe("/firefly/vi_sensor/camera_depth/depth/disparity", 1, &environmentTracker::getImageDepth, this); 109 | 110 | pause_physics = n.serviceClient("/gazebo/pause_physics"); 111 | unpause_physics = n.serviceClient("/gazebo/unpause_physics"); 112 | set_state = n.serviceClient("/gazebo/set_model_state"); 113 | 114 | get_position_client = n.serviceClient("/gazebo/get_model_state"); 115 | perform_action_srv = n.advertiseService("env_tr_perform_action", &environmentTracker::performAction, this); 116 | get_state_srv = n.advertiseService("env_tr_get_state", &environmentTracker::getState, this); 117 | 118 | gazebo_msgs::ModelState modelstate; 119 | gazebo_msgs::SetModelState set_state_srv; 120 | modelstate.model_name = "firefly"; 121 | modelstate.pose.position.x = 0.0; 122 | modelstate.pose.position.y = 0.0; 123 | modelstate.pose.position.z = 5.0; 124 | set_state_srv.request.model_state = modelstate; 125 | set_state.call(set_state_srv); 126 | } 127 | 128 | double round_number(double number){ 129 | return round( number * 100.0 ) / 100.0; 130 | } 131 | 132 | void getImage(const sensor_msgs::ImageConstPtr& msg) 133 | { 134 | 135 | current_img = *msg; 136 | } 137 | 138 | void getImageDepth(const sensor_msgs::ImageConstPtr& msg) 139 | { 140 | 141 | current_img_depth = *msg; 142 | } 143 | 144 | void 
setCrossPosition(std::vector pos) { 145 | gazebo_msgs::ModelState modelstate; 146 | 147 | gazebo_msgs::SetModelState set_state_srv; 148 | 149 | modelstate.model_name = "small box"; 150 | modelstate.pose.position.x = pos[0]; 151 | modelstate.pose.position.y = pos[1]; 152 | modelstate.pose.position.z = 0.08; 153 | 154 | set_state_srv.request.model_state = modelstate; 155 | 156 | if (!set_state.call(set_state_srv)) { 157 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z); 158 | ROS_ERROR("Failed to set position"); 159 | } 160 | 161 | modelstate.model_name = "small box_0"; 162 | modelstate.pose.position.x = pos[0]; 163 | modelstate.pose.position.y = pos[1] + 0.2; 164 | modelstate.pose.position.z = 0.08; 165 | 166 | set_state_srv.request.model_state = modelstate; 167 | 168 | if (!set_state.call(set_state_srv)) { 169 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z); 170 | ROS_ERROR("Failed to set position"); 171 | } 172 | 173 | modelstate.model_name = "small box_1"; 174 | modelstate.pose.position.x = pos[0]; 175 | modelstate.pose.position.y = pos[1] - 0.2; 176 | modelstate.pose.position.z = 0.08; 177 | 178 | set_state_srv.request.model_state = modelstate; 179 | 180 | if (!set_state.call(set_state_srv)) { 181 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z); 182 | ROS_ERROR("Failed to set position"); 183 | } 184 | 185 | modelstate.model_name = "small box_2"; 186 | modelstate.pose.position.x = pos[0] + 0.31; 187 | modelstate.pose.position.y = pos[1]; 188 | modelstate.pose.position.z = 0.08; 189 | 190 | set_state_srv.request.model_state = modelstate; 191 | 192 | if (!set_state.call(set_state_srv)) { 193 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, 
(float)srv.response.pose.position.y, (float)srv.response.pose.position.z); 194 | ROS_ERROR("Failed to set position"); 195 | } 196 | 197 | modelstate.model_name = "small box_3"; 198 | modelstate.pose.position.x = pos[0] - 0.31; 199 | modelstate.pose.position.y = pos[1]; 200 | modelstate.pose.position.z = 0.08; 201 | 202 | set_state_srv.request.model_state = modelstate; 203 | 204 | if (!set_state.call(set_state_srv)) { 205 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z); 206 | ROS_ERROR("Failed to set position"); 207 | } 208 | } 209 | 210 | void respawn() { 211 | mav_msgs::RollPitchYawrateThrust msg; 212 | msg.roll = 0; 213 | msg.pitch = 0; 214 | msg.yaw_rate = 0; 215 | msg.thrust.z = 0; 216 | std_srvs::Empty srv; 217 | 218 | current_position = {0,0,0}; 219 | current_orientation = {0,0,0,0}; 220 | step_counter = 0; 221 | 222 | if (random_target) { 223 | std::uniform_real_distribution unif(0, 8); 224 | 225 | target_position = {round_number(unif(re))-4, round_number(unif(re))-4, 7.5};//round_number(unifz(re))}; 226 | } 227 | 228 | current_control_params.resize(4, 0); 229 | firefly_reset_client.call(srv); 230 | firefly_control_pub.publish(msg); 231 | setCrossPosition(target_position); 232 | gazebo_msgs::ModelState modelstate; 233 | gazebo_msgs::SetModelState set_state_srv; 234 | modelstate.model_name = "firefly"; 235 | modelstate.pose.position.x = 0.0; 236 | modelstate.pose.position.y = 0.0; 237 | modelstate.pose.position.z = 5.0; 238 | set_state_srv.request.model_state = modelstate; 239 | set_state.call(set_state_srv); 240 | } 241 | 242 | void pausePhysics() { 243 | std_srvs::Empty srv; 244 | pause_physics.call(srv); 245 | } 246 | 247 | void unpausePhysics() { 248 | std_srvs::Empty srv; 249 | unpause_physics.call(srv); 250 | } 251 | 252 | void onCollisionGround(const gazebo_msgs::ContactsState::ConstPtr& msg) { 253 | if ((step_counter > 5) && (current_position[2] < 0.5 || 
current_position[2] > 12.5 || msg->states.size() > 0)) { 254 | ROS_INFO("Crash, respawn..."); 255 | step_counter = 0; 256 | crashed_flag = true; 257 | respawn(); 258 | } 259 | } 260 | 261 | void onCollision(const gazebo_msgs::ContactsState::ConstPtr& msg) { 262 | if ((step_counter > 5) && (current_position[2] < 0.5 || current_position[2] > 12.5 || msg->states.size() > 0)) { 263 | ROS_INFO("Crash, respawn..."); 264 | step_counter = 0; 265 | crashed_flag = true; 266 | respawn(); 267 | } 268 | } 269 | 270 | bool performAction(rotors_reinforce::PerformAction::Request &req, rotors_reinforce::PerformAction::Response &res) { 271 | //respawn code 272 | if (req.action[3] == 42) { 273 | respawn(); 274 | return true; 275 | } 276 | 277 | if ((step_counter > 5) && (current_position[2] < 0.5 || current_position[2] > 12.5)) { 278 | ROS_INFO("Crash, respawn..."); 279 | step_counter = 0; 280 | crashed_flag = true; 281 | respawn(); 282 | } 283 | 284 | mav_msgs::RollPitchYawrateThrust msg; 285 | msg.roll = req.action[0] * max_roll * axes_roll_direction; 286 | msg.pitch = req.action[1] * max_pitch * axes_pitch_direction; 287 | 288 | if(req.action[2] > 0.01) { 289 | current_yaw_vel_ = max_rate_yaw; 290 | } 291 | else if (req.action[2] < -0.01) { 292 | current_yaw_vel_ = max_rate_yaw; 293 | } 294 | else { 295 | current_yaw_vel_ = 0.0; 296 | } 297 | 298 | msg.yaw_rate = current_yaw_vel_; 299 | msg.thrust.z = req.action[3] * max_thrust * axes_thrust_direction; 300 | 301 | 302 | ROS_INFO("roll: %f, pitch: %f, yaw_rate: %f, thrust %f", msg.roll, msg.pitch, msg.yaw_rate, msg.thrust.z); 303 | 304 | if (enable_wind && step_counter % 20 == 0) { //new wind after every 20 steps (1 second) 305 | rotors_comm::WindSpeed wind_msg; 306 | std::uniform_real_distribution unif(0, 1); 307 | if (unif(re) > 0.75) { //probability of 75% for no wind 308 | double rand_number = unif(re); 309 | if (rand_number < 0.34) { //wind comes only from one direction 310 | wind_msg.velocity.z = 0; 311 | wind_msg.velocity.y 
= 0; 312 | wind_msg.velocity.x = unif(re) * MAX_WIND_VELOCTIY; 313 | } else if (rand_number < 0.67) { 314 | wind_msg.velocity.x = 0; 315 | wind_msg.velocity.z = 0; 316 | wind_msg.velocity.y = unif(re) * MAX_WIND_VELOCTIY; 317 | } else { 318 | wind_msg.velocity.x = 0; 319 | wind_msg.velocity.y = 0; 320 | wind_msg.velocity.z = unif(re) * MAX_WIND_VELOCTIY; 321 | } 322 | ROS_INFO("wind with velocity of x: %f, y: %f, z: %f", wind_msg.velocity.x, wind_msg.velocity.y, wind_msg.velocity.z); 323 | } 324 | else { 325 | ROS_INFO("no wind"); 326 | wind_msg.velocity.x = 0; 327 | wind_msg.velocity.y = 0; 328 | wind_msg.velocity.z = 0; 329 | } 330 | firefly_wind_pub.publish(wind_msg); 331 | } 332 | unpausePhysics(); 333 | unpausePhysics(); 334 | firefly_control_pub.publish(msg); 335 | ros::Duration(0.05).sleep(); //sleep 50ms of simulation time 336 | getPosition(); 337 | pausePhysics(); 338 | step_counter++; 339 | 340 | res.target_position = target_position; 341 | res.position = current_position; 342 | res.orientation = current_orientation; 343 | res.img = current_img; 344 | res.img_depth = current_img_depth; 345 | 346 | res.reward = getReward(step_counter); 347 | res.crashed = false; 348 | 349 | //crash check at the end 350 | if(crashed_flag) { 351 | res.crashed = true; 352 | crashed_flag = false; 353 | } 354 | 355 | return true; 356 | } 357 | 358 | bool getState(rotors_reinforce::GetState::Request &req, rotors_reinforce::GetState::Response &res) { 359 | getPosition(); 360 | res.target_position = target_position; 361 | res.position = current_position; 362 | res.orientation = current_orientation; 363 | res.img = current_img; 364 | res.img_depth = current_img_depth; 365 | 366 | res.reward = 0.0; 367 | res.crashed = crashed_flag; 368 | return true; 369 | } 370 | 371 | void getPosition() { 372 | gazebo_msgs::GetModelState srv; 373 | srv.request.model_name = "firefly"; 374 | if (get_position_client.call(srv)) { 375 | ROS_INFO("Position: %f %f %f", 
(float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z); 376 | 377 | current_position[0] = (double)srv.response.pose.position.x; 378 | current_position[1] = (double)srv.response.pose.position.y; 379 | current_position[2] = (double)srv.response.pose.position.z; 380 | current_orientation[0] = (double)srv.response.pose.orientation.x; 381 | current_orientation[1] = (double)srv.response.pose.orientation.y; 382 | current_orientation[2] = (double)srv.response.pose.orientation.z; 383 | current_orientation[3] = (double)srv.response.pose.orientation.w; 384 | } 385 | else { 386 | ROS_ERROR("Failed to get position"); 387 | } 388 | } 389 | 390 | double getReward(const int count) { 391 | double difx = current_position[0] - target_position[0]; 392 | double dify = current_position[1] - target_position[1]; 393 | double difz = current_position[2] - target_position[2]; 394 | 395 | double current_distance = std::sqrt(difx * difx + dify * dify + 2.0 * difz * difz); 396 | 397 | double reward = 0.0; 398 | 399 | if (crashed_flag) { 400 | return 0.0; 401 | } 402 | 403 | double reward4position = 1/(current_distance + 1.0); 404 | //double reward4orientation = 1/((current_orientation[0] * current_orientation[0] + current_orientation[1] * current_orientation[1] + current_orientation[2] * current_orientation[2])/(current_orientation[3] * current_orientation[3]) + 1); 405 | 406 | reward = reward4position;//*0.95 + 0.05 * reward4orientation; 407 | 408 | return reward; 409 | } 410 | 411 | }; 412 | 413 | 414 | 415 | int main(int argc, char **argv) 416 | { 417 | ros::init(argc, argv, "talker"); 418 | 419 | ros::NodeHandle n; 420 | 421 | ros::Rate loop_rate(100); 422 | 423 | std::vector target_position = {0.0, 0.0, 0.0}; 424 | bool wind = false; 425 | 426 | if(argc > 1) { 427 | target_position[0] = atoi(argv[1]); 428 | target_position[1] = atoi(argv[2]); 429 | target_position[2] = atoi(argv[3]); 430 | wind = atoi(argv[4]); 431 | } 432 | 433 | 
environmentTracker tracker(n, target_position, wind); 434 | 435 | ROS_INFO("Comunication node ready"); 436 | 437 | ros::spin(); 438 | 439 | //delete tracker; 440 | return 0; 441 | } 442 | -------------------------------------------------------------------------------- /srv/GetState.srv: -------------------------------------------------------------------------------- 1 | --- 2 | float64[] position 3 | float64[] velocity 4 | float64[] target_position 5 | float64[] orientation 6 | sensor_msgs/Image img 7 | sensor_msgs/Image img_depth 8 | float64 reward 9 | bool crashed 10 | -------------------------------------------------------------------------------- /srv/PerformAction.srv: -------------------------------------------------------------------------------- 1 | float64[] action 2 | --- 3 | float64[] position 4 | float64[] target_position 5 | float64[] orientation 6 | sensor_msgs/Image img 7 | sensor_msgs/Image img_depth 8 | float64 reward 9 | bool crashed 10 | --------------------------------------------------------------------------------