├── CMakeLists.txt
├── README.md
├── launch
└── mav.launch
├── multirotor_base.xacro
├── package.xml
├── scripts
├── History.py
├── ReinforceLearning_node.py
├── cv_trainer.py
├── model.py
└── ops.py
├── src
└── EnvironmentTracker_node.cpp
└── srv
├── GetState.srv
└── PerformAction.srv
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.0.2)
project(rotors_reinforce)

## Catkin packages this package builds against (each listed once).
## message_generation is required because this package generates services.
find_package(catkin REQUIRED COMPONENTS
  gazebo_msgs
  gazebo_plugins
  geometry_msgs
  mav_msgs
  message_generation
  roscpp
  rospy
  rotors_gazebo_plugins
  sensor_msgs
  std_msgs
)

## There is no setup.py: the Python scripts are installed individually in the
## Install section, so catkin_python_setup() is not called.
32 |
################################################
## Declare ROS messages, services and actions ##
################################################

## Service definitions in srv/.
add_service_files(
  FILES
  PerformAction.srv
  GetState.srv
)

## Generate code for the services. DEPENDENCIES lists the message packages
## whose types the .srv files may reference.
generate_messages(
  DEPENDENCIES
  std_msgs
  sensor_msgs
)
101 |
###################################
## catkin specific configuration ##
###################################
## Dependencies exported to packages that depend on rotors_reinforce.
## message_runtime is required because this package generates its own
## services; it matches the message_runtime run dependency in package.xml.
catkin_package(
  CATKIN_DEPENDS
    gazebo_msgs
    gazebo_plugins
    geometry_msgs
    mav_msgs
    message_runtime
    roscpp
    rospy
    rotors_gazebo_plugins
    sensor_msgs
    std_msgs
)
117 |
###########
## Build ##
###########

add_executable(EnvironmentTracker_node src/EnvironmentTracker_node.cpp)

## Build this package's generated service code (and other catkin targets)
## before compiling the node.
add_dependencies(EnvironmentTracker_node
  ${${PROJECT_NAME}_EXPORTED_TARGETS}
  ${catkin_EXPORTED_TARGETS}
)

## Include paths and libraries are attached to the target rather than to the
## whole directory.
target_include_directories(EnvironmentTracker_node PRIVATE
  ${catkin_INCLUDE_DIRS}
)
target_link_libraries(EnvironmentTracker_node PRIVATE
  ${catkin_LIBRARIES}
)
158 |
#############
## Install ##
#############

## Python node plus the local modules it imports (History, model -> ops).
## The modules go next to the node so the installed script can import them.
catkin_install_python(PROGRAMS
  scripts/ReinforceLearning_node.py
  DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)
install(FILES
  scripts/History.py
  scripts/model.py
  scripts/ops.py
  DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)

## C++ environment tracker node.
install(TARGETS EnvironmentTracker_node
  RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)

## Launch files.
install(DIRECTORY launch/
  DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}/launch
)
206 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # deep-reinforcement-learning-drone-control
2 | This is a drone control system based on deep reinforcement learning, implemented in Python (TensorFlow/ROS) and C++ (ROS). To test it, clone the RotorS simulator from https://github.com/ethz-asl/rotors_simulator into your catkin workspace, then copy multirotor_base.xacro from this repository over the file of the same name in the RotorS simulator to add the camera to the drone.
3 |
4 | The drone control system takes camera images as input and outputs discretized steering commands. The neural network model is end-to-end and is a synchronous (non-asynchronous) implementation of the A3C model (https://arxiv.org/pdf/1602.01783.pdf), because the Gazebo simulator cannot run multiple copies in parallel (and neither can my laptop :D). Training starts from weights pretrained on a supervised learning task, since the simulator is very resource-intensive and training is time-consuming.
5 |
6 | The results were discussed in a practical course at RWTH Aachen University, where this agent served as a proof of concept that an end-to-end deep reinforcement learning model can be trained efficiently to control a drone in a realistic 3D environment.
7 |
--------------------------------------------------------------------------------
/launch/mav.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/multirotor_base.xacro:
--------------------------------------------------------------------------------
1 |
2 |
21 |
22 |
23 |
24 |
25 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 | ${robot_namespace}
80 | ${robot_namespace}/base_link
81 | ${rotor_velocity_slowdown_sim}
82 |
83 |
84 |
85 |
86 | 10.0
87 | 1
88 |
89 | ${robot_namespace}/base_link_fixed_joint_lump__base_link_collision
90 |
91 |
92 | true
93 | 10.0
94 | /base_collision
95 | world
96 |
97 |
98 |
99 |
100 |
101 |
102 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
123 |
124 |
125 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 | ${robot_namespace}
140 | ${robot_namespace}/rotor_${motor_number}_joint
141 | ${robot_namespace}/rotor_${motor_number}
142 | ${direction}
143 | ${time_constant_up}
144 | ${time_constant_down}
145 | ${max_rot_velocity}
146 | ${motor_constant}
147 | ${moment_constant}
148 | gazebo/command/motor_speed
149 | ${motor_number}
150 | ${rotor_drag_coefficient}
151 | ${rolling_moment_coefficient}
152 | motor_speed/${motor_number}
153 | ${rotor_velocity_slowdown_sim}
154 |
155 |
156 |
157 | Gazebo/${color}
158 |
159 | 10.0
160 | 1
161 |
162 | ${robot_namespace}/rotor_${motor_number}_collision
163 |
164 |
165 | true
166 | 10.0
167 | /rotor_collision
168 | world
169 |
170 |
171 |
172 |
173 |
174 |
--------------------------------------------------------------------------------
/package.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | rotors_reinforce
4 | 0.0.0
5 | The Reinforcement Learning package
6 |
7 |
8 |
9 |
10 | fischer_t
11 |
12 |
13 |
14 |
15 |
16 | TODO
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | catkin
43 | message_generation
44 | gazebo_msgs
45 | gazebo_plugins
46 | geometry_msgs
47 | mav_msgs
48 | roscpp
49 | rospy
50 | rotors_gazebo_plugins
51 | sensor_msgs
52 | std_msgs
53 |
54 | message_runtime
55 | gazebo_msgs
56 | gazebo_plugins
57 | geometry_msgs
58 | mav_msgs
59 | roscpp
60 | rospy
61 | rotors_gazebo_plugins
62 | sensor_msgs
63 | std_msgs
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/scripts/History.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import numpy as np
4 |
class History:
    """Episode buffer of states, actions, rewards and raw service responses.

    Entries are appended in time order and kept in plain lists until clean()
    empties them. ``actionNumber`` and ``actionSize`` are stored for callers;
    the methods of this class do not use them.
    """

    def __init__(self, actionNumber, actionSize):
        self.actionNumber = actionNumber
        self.actionSize = actionSize
        # start with empty states/actions/rewards/responses lists
        self.clean()

    # ---- actions ---------------------------------------------------------

    def addAction2History(self, action):
        """Append one action."""
        self.actions.append(action)

    def getActions(self):
        """Return the action list itself (not a copy)."""
        return self.actions

    def getLastAction(self):
        """Return the most recent action; AssertionError if none recorded."""
        assert len(self.actions) != 0, "Action history is empty!"
        return self.actions[-1]

    # ---- states and responses ---------------------------------------------

    def addState2History(self, state):
        """Append one state."""
        self.states.append(state)

    def addResponse2History(self, response):
        """Append one service response."""
        self.responses.append(response)

    def getStates(self):
        """Return a new list with every state except the most recent one."""
        return self.states[:-1]

    def getResponses(self):
        """Return a new list with every response except the most recent one."""
        return self.responses[:-1]

    def getState(self, iterator):
        """Return state number ``iterator`` with a leading batch axis of size 1."""
        assert len(self.states) != 0, "State history is empty!"
        return np.expand_dims(self.states[iterator], axis=0)

    def getLastState(self):
        """Return the most recent state with a leading batch axis of size 1."""
        assert len(self.states) != 0, "State history is empty!"
        return np.expand_dims(self.states[-1], axis=0)

    # ---- rewards -----------------------------------------------------------

    def addReward2History(self, reward):
        """Append one reward."""
        self.rewards.append(reward)

    def getRewardHistory(self):
        """Return the reward list itself (not a copy)."""
        return self.rewards

    def getLastReward(self):
        """Return the most recent reward; AssertionError if none recorded."""
        assert len(self.rewards) != 0, "Reward history is empty!"
        return self.rewards[-1]

    def sumRewards(self):
        """Return the sum of all recorded rewards (0 when empty)."""
        return sum(self.rewards)

    def clean(self):
        """Forget all recorded states, actions, rewards and responses."""
        self.states, self.actions, self.rewards, self.responses = [], [], [], []
--------------------------------------------------------------------------------
/scripts/ReinforceLearning_node.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import rospy
4 | from rotors_reinforce.srv import PerformAction
5 | from rotors_reinforce.srv import GetState
6 | from History import *
7 | from model import ComputationalGraph
8 |
9 | # ROS Image message -> OpenCV2 image converter
10 | from cv_bridge import CvBridge, CvBridgeError
11 | # OpenCV2 for saving an image
12 | import cv2
13 |
14 | import tensorflow as tf
15 | import random
16 | import numpy as np
17 |
# 1: explore with epsilon, update the policy and save checkpoints.
# 0: always take the most probable action; no update and no saving.
TRAIN_MODEL = 1

# Maximum number of steps per episode before a respawn is requested.
EPISODE_LENGTH = 120

# Epsilon-greedy schedule: start value, multiplicative decay applied every
# EPSILON_STEP episodes, and the lowest exploration rate the schedule should reach.
EPSILON, EPSILON_DECAY, EPSILON_MIN, EPSILON_STEP = 0.9, 0.97, 0.01, 10

# ACTION_NUM discrete action dimensions (roll, pitch, thrust), each with
# ACTION_SIZE choices.
ACTION_NUM, ACTION_SIZE = 3, 3

# Network input: current and previous grayscale frames stacked as 2 channels.
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 75, 75, 2

# Command sent to env_tr_perform_action to request a respawn; the value 42 is
# presumably interpreted by the environment tracker node (not visible here).
RESPAWN_CODE = [0, 0, 0, 42]

# Shared ROS Image -> OpenCV converter.
bridge = CvBridge()
35 |
36 |
def convertImage(img):
    """Convert a ROS Image message to a resized single-channel array.

    Args:
        img: ROS Image message; converted with the "mono8" encoding.

    Returns:
        uint8 array of shape (IMG_HEIGHT, IMG_WIDTH).

    Raises:
        CvBridgeError: if the message cannot be converted. The error is
            printed and re-raised so callers see the real conversion failure
            instead of an UnboundLocalError from the resize below.
    """
    try:
        # Convert the ROS Image message to an OpenCV2 image
        cv2_img = bridge.imgmsg_to_cv2(img, "mono8")
    except CvBridgeError as e:
        print(e)
        raise

    # Format for the tensor; cv2.resize takes the target size as (width, height)
    return cv2.resize(cv2_img, (IMG_WIDTH, IMG_HEIGHT))
48 |
49 |
def convertState(state, old_state):
    """Build the network input from the current and the previous observation.

    Returns an array whose channel 0 is the converted current frame
    (``state.img``) and channel 1 the converted previous frame
    (``old_state.img``).
    """
    current = np.expand_dims(convertImage(state.img), -1)
    previous = np.expand_dims(convertImage(old_state.img), -1)
    return np.concatenate((current, previous), axis=2)
54 |
def chooseAction(probabilities, e, is_training):
    """Pick one discrete choice per action dimension, epsilon-greedily.

    Args:
        probabilities: network output indexed as probabilities[0][i][j], the
            probability of choice j for action dimension i (first batch entry).
        e: exploration rate. While training, each dimension independently takes
            a uniformly random choice with probability ``e``.
        is_training: if falsy, always take the most probable choice.

    Returns:
        Float array with one chosen choice index per action dimension.
    """
    per_action = probabilities[0]
    action = np.zeros(len(per_action))

    for i, choice_probs in enumerate(per_action):
        if is_training and random.uniform(0, 1) < e:
            # explore: sample uniformly among this dimension's choices
            action[i] = random.randint(0, len(choice_probs) - 1)
        else:
            action[i] = np.argmax(choice_probs)

    return action
66 |
67 |
68 |
def _call_service(name, client, *args):
    """Wait for service ``name`` and call it through ``client`` with ``args``.

    Returns the response, or None if the call raised rospy.ServiceException
    (the error is printed).
    """
    rospy.wait_for_service(name)
    try:
        return client(*args)
    except rospy.ServiceException as exc:
        print("Service call failed: %s" % exc)
        return None


def _restore_agent(sess, saver):
    """Load the full agent checkpoint, falling back to the pretrained CV weights.

    Tries ./log/model.ckpt first. If that fails, all variables are initialized
    and only the variables under the 'CV_graph' scope are restored from
    ./cv_graph/cv_graph.cptk.
    """
    try:
        saver.restore(sess, "./log/model.ckpt")
        print("model restored")
    except Exception:
        print("model restore failed. random initialization")
        sess.run(tf.global_variables_initializer())

        # restore pretrained feature-extraction layers on top of the random init
        cv_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='CV_graph')
        cv_saver = tf.train.Saver(cv_vars)
        cv_saver.restore(sess, "./cv_graph/cv_graph.cptk")
        print("pretrained weights restored")


def _action_to_command(action):
    """Map chosen choice indices to the command sent to env_tr_perform_action.

    Layout: [roll, pitch, yaw, thrust]. Roll and pitch indices
    0..ACTION_SIZE-1 are mapped linearly onto [-1, 1], with the middle index
    as the neutral command. Yaw is always 0. The thrust index is mapped onto [0, 1].
    """
    half_span = float(ACTION_SIZE - 1) / 2.0
    command = np.zeros(ACTION_NUM + 1)
    command[0] = (float(action[0]) - half_span) / half_span
    command[1] = (float(action[1]) - half_span) / half_span
    command[2] = 0.  # the yaw command dimension is skipped
    command[3] = float(action[2]) / float(ACTION_SIZE - 1)
    return command


def _one_hot(action):
    """Encode choice indices as an (ACTION_NUM, ACTION_SIZE) one-hot matrix."""
    matrix = np.zeros([ACTION_NUM, ACTION_SIZE])
    for i in range(ACTION_NUM):
        matrix[i][int(action[i])] = 1
    return matrix


def reinforce_node():
    """Run the epsilon-greedy training / evaluation loop against the environment.

    Each episode proceeds as follows:
      * Fetch an initial observation from ``env_tr_get_state``.
      * Repeatedly feed the last state to the policy network and send the
        chosen command to ``env_tr_perform_action``.
      * Record each transition in a History.

    An episode ends when any of these happens:
      * The drone crashes.
      * EPISODE_LENGTH steps elapse (a respawn is then requested with RESPAWN_CODE).
      * A service call fails.

    When TRAIN_MODEL == 1, the policy is updated after every episode that took
    at least one action, and a checkpoint is saved every 50 episodes. Epsilon
    decays every EPSILON_STEP episodes and never drops below EPSILON_MIN.
    """
    # set up env
    rospy.init_node('ReinforceLearning', anonymous=True)
    perform_action_client = rospy.ServiceProxy('env_tr_perform_action', PerformAction)
    get_state_client = rospy.ServiceProxy('env_tr_get_state', GetState)

    history = History(ACTION_NUM, ACTION_SIZE)
    graph = ComputationalGraph(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

    sess = tf.Session()
    graph.constructGraph(sess, ACTION_NUM, ACTION_SIZE)

    saver = tf.train.Saver()
    _restore_agent(sess, saver)

    step = 0
    # Kept separate from exception names so an `except ... as` cannot rebind it.
    epsilon = EPSILON

    # main loop: one iteration per episode
    while not rospy.is_shutdown():
        response = _call_service('env_tr_get_state', get_state_client)
        if response is None:
            # No initial observation to start an episode from; retry shortly.
            rospy.sleep(1.0)
            continue

        old_response = response
        history.clean()
        history.addState2History(convertState(response, old_response))

        print(response.target_position)

        crashed_flag = False
        episode_step = 0
        # run episode, while not crashed and simulation is running
        while not crashed_flag and not rospy.is_shutdown():
            # per-dimension choice probabilities for the latest state
            probabilities = graph.calculateAction(sess, history.getLastState())

            # most probable choice, with epsilon-random exploration while training
            action = chooseAction(probabilities, epsilon, TRAIN_MODEL)

            response = _call_service('env_tr_perform_action', perform_action_client,
                                     _action_to_command(action))
            if response is None:
                # End the episode instead of recording the previous response again.
                break

            state = convertState(response, old_response)
            old_response = response

            # update history
            history.addAction2History(_one_hot(action))
            history.addState2History(state)
            history.addResponse2History(response)
            history.addReward2History(response.reward)

            crashed_flag = response.crashed

            episode_step += 1
            if episode_step >= EPISODE_LENGTH:
                _call_service('env_tr_perform_action', perform_action_client, RESPAWN_CODE)
                break

        # update policy (skipped when the episode ended before any action)
        if TRAIN_MODEL == 1 and history.getActions():
            graph.updatePolicy(sess, history, step)

        # save every 50 episodes
        if TRAIN_MODEL == 1 and step % 50 == 0:
            save_path = saver.save(sess, "log/model.ckpt")
            print("Model saved in file: %s" % save_path)

        if step % EPSILON_STEP == 0:
            epsilon = max(EPSILON_MIN, epsilon * EPSILON_DECAY)

        print("episode number: %s" % step)
        print("total reward: %s" % history.sumRewards())
        step += 1
188 |
189 |
if __name__ == '__main__':
    print("starting...")
    try:
        reinforce_node()
    except rospy.ROSInterruptException:
        # rospy interrupted a blocking call because the node is shutting down.
        pass
196 |
197 |
--------------------------------------------------------------------------------
/scripts/cv_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # ROS Image message -> OpenCV2 image converter
4 | from cv_bridge import CvBridge, CvBridgeError
5 | import matplotlib.pyplot as plt
6 | # OpenCV2 for saving an image
7 | import cv2
8 |
9 | import tensorflow as tf
10 | import numpy as np
11 |
12 | import pickle
13 |
14 | from ops import conv2d, dense
15 |
16 |
# Network input: current and previous grayscale frames stacked as 2 channels.
IMG_WIDTH, IMG_HEIGHT, CHANNELS = 75, 75, 2
# Number of regression outputs (sampleData uses the first three coordinate components).
OUTPUT_SIZE = 3

# Optimisation settings.
NUM_EPOCHS = 40
BATCH_SIZE = 128
LEARNING_RATE = 0.0001
DROPOUT_KEEP_PROBABILITY = 0.5

# Pickled TrainingData datasets loaded by train_net().
FILENAME_TRAINING_DATA = "data_training"
FILENAME_TEST_DATA = "data_test"

# Checkpoint directory; created at import time so the saver can write into it.
SAVE_PATH = "./cv_graph/"
if not os.path.exists(SAVE_PATH):
    os.makedirs(SAVE_PATH)
33 |
class TrainingData:
    """Parallel lists of samples: image, coordinate and orientation per entry.

    Instances are loaded with pickle, so the attribute names must not change.
    """

    def __init__(self):
        self.img, self.coordinate, self.orientation = [], [], []

    def add(self, im, c, o):
        """Append one sample (image ``im``, coordinate ``c``, orientation ``o``)."""
        for bucket, value in ((self.img, im), (self.coordinate, c), (self.orientation, o)):
            bucket.append(value)
44 |
45 |
class CVGraph:
    """Supervised CNN that regresses ``o`` values from an h x w x c input image.

    The layers conv1-conv4, dense1 and dense2 under the "CV_graph" variable
    scope have the same names as the feature extractor in model.py. The RL node
    restores that scope from ./cv_graph/cv_graph.cptk, the checkpoint that
    train_net() writes, so these layer names must not change.
    """
    def __init__(self, h, w, c, o):
        # input geometry and number of regression outputs
        self.screen_height = h
        self.screen_width = w
        self.channels = c
        self.output_size = o
        # stored but not used by the methods of this class
        self.orientation_size = 4
        self.cnn_output_size = 3

    def buildGraph(self):
        """Create placeholders, layers, the MSE loss and the Adam optimizer."""
        with tf.variable_scope("CV_graph"):
            self.input = tf.placeholder('float32',
                [None, self.screen_height, self.screen_width, self.channels], name='input')

            initializer = tf.contrib.layers.xavier_initializer()
            activation_fn = tf.nn.relu

            # conv2d (from ops) returns (output, weights, bias)
            self.conv1, self.conv1_w, self.conv1_b = conv2d(self.input,
                16,
                [5, 5],
                [1, 1],
                initializer,
                activation_fn,
                'NHWC',
                name='conv1')

            self.conv2, self.conv2_w, self.conv2_b = conv2d(self.conv1,
                16,
                [5, 5],
                [1, 1],
                initializer,
                activation_fn,
                'NHWC',
                name='conv2')

            # 2x2 max pooling with stride 2
            self.max_pool1 = tf.nn.max_pool(self.conv2,
                [1, 2, 2, 1],
                [1, 2, 2, 1],
                padding='VALID',
                name='max_pool1')

            self.conv3, self.conv3_w, self.conv3_b = conv2d(self.max_pool1,
                32,
                [3, 3],
                [1, 1],
                initializer,
                activation_fn,
                'NHWC',
                name='conv3')

            self.conv4, self.conv4_w, self.conv4_b = conv2d(self.conv3,
                32,
                [3, 3],
                [1, 1],
                initializer,
                activation_fn,
                'NHWC',
                name='conv4')

            self.max_pool2 = tf.nn.max_pool(self.conv4,
                [1, 2, 2, 1],
                [1, 2, 2, 1],
                padding='VALID',
                name='max_pool2')

            self.flat = tf.contrib.layers.flatten(self.max_pool2)

            # fed with DROPOUT_KEEP_PROBABILITY in train() and 1.0 otherwise
            self.keep_probability = tf.placeholder('float32', name='keep_probability')

            self.dense1, self.dense1_w, self.dense1_b = dense(inputs=self.flat, units=100,
                activation=tf.nn.relu, name='dense1')
            self.dropout1 = tf.nn.dropout(self.dense1, self.keep_probability)
            self.dense2, self.dense2_w, self.dense2_b = dense(inputs=self.dropout1, units=50,
                activation=tf.nn.relu, name='dense2')
            self.dropout2 = tf.nn.dropout(self.dense2, self.keep_probability)

            # regression head (dense3/output are not shared with model.py)
            self.dense3, self.dense3_w, self.dense3_b = dense(inputs=self.dropout2,
                units=10, activation=tf.nn.relu,
                name='dense3')

            self.output, self.output_w, self.output_b = dense(inputs=self.dense3,
                units=self.output_size, activation=None,
                name='output')

            # NOTE(review): confirm whether the loss/optimizer below belong inside
            # the CV_graph scope; it affects the checkpoint names of global_step
            # and the optimizer's variables, not the computation.
            self.ground_truth = tf.placeholder('float32', [None, self.output_size], name='ground_truth')
            # mean squared error over all batch entries and output components
            self.loss = tf.reduce_mean(tf.square(tf.subtract(self.output, self.ground_truth)))
            # not passed to the optimizer and not read by the loss
            self.global_step = tf.Variable(0, trainable=False)

            self.optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(self.loss)

    def train(self, sess, img, labels, step):
        """Run one Adam step on a batch with dropout enabled; return the batch loss.

        ``step`` is fed into self.global_step. As far as this class shows,
        nothing in the graph reads it.
        """
        _, loss = sess.run([self.optimizer, self.loss], feed_dict={self.input: img,
            self.ground_truth: labels,
            self.global_step: step,
            self.keep_probability: DROPOUT_KEEP_PROBABILITY})
        return loss

    def prediction(self, sess, img):
        """Return the network outputs for ``img`` with dropout disabled."""
        output = sess.run(self.output, feed_dict={self.input: img,
            self.keep_probability: 1.0})

        return output

    def validate(self, sess, img, labels, step):
        """Return the loss on a batch with dropout disabled, without updating weights."""
        loss = sess.run(self.loss, feed_dict={self.input: img,
            self.ground_truth: labels,
            self.global_step: step,
            self.keep_probability: 1.0})
        return loss
156 |
157 |
def convertImage(img):
    """Convert a ROS Image message to an (IMG_HEIGHT, IMG_WIDTH) grayscale array.

    Args:
        img: ROS Image message; converted with the "mono8" encoding.

    Returns:
        uint8 array of shape (IMG_HEIGHT, IMG_WIDTH).

    Raises:
        CvBridgeError: if the message cannot be converted. The error is
            printed and re-raised instead of continuing with no image.
    """
    bridge = CvBridge()

    try:
        # Convert the ROS Image message to an OpenCV2 image
        cv2_img = bridge.imgmsg_to_cv2(img, "mono8")
    except CvBridgeError as e:
        print(e)
        raise

    # cv2.resize takes the target size as (width, height)
    return cv2.resize(cv2_img, (IMG_WIDTH, IMG_HEIGHT))
170 |
def merge(image, label, old_image, old_label):
    """Stack two frames as channels and difference their labels.

    Returns:
        (merged_image, merged_label): ``image`` in channel 0 and ``old_image``
        in channel 1, and ``label - old_label``.
    """
    channels = [np.expand_dims(frame, -1) for frame in (image, old_image)]
    return np.concatenate(channels, axis=2), label - old_label
175 |
def process(data):
    """Pair every frame with its successor, in place, and return ``data``.

    Sample i (for i below the last index) becomes merge(frame i+1, frame i):
      * a 2-channel image of (frame i+1, frame i);
      * labelled with coordinate[i+1] - coordinate[i].

    The last sample is merged with itself.
    """
    last = len(data.img) - 1
    for i in range(last + 1):
        nxt = i + 1 if i < last else i
        data.img[i], data.coordinate[i] = merge(data.img[nxt], data.coordinate[nxt],
                                                data.img[i], data.coordinate[i])
    return data
188 |
def sampleData(data, index, sample_size):
    """Return ``sample_size`` consecutive samples starting at ``index``.

    Returns:
        (images, labels): the images stacked into one array, and one row per
        sample holding the first three coordinate components as floats.

    Raises:
        IndexError: if the requested window runs past the end of the data.
    """
    positions = range(index, index + sample_size)
    images = np.array([data.img[p] for p in positions])
    labels = np.array([[float(data.coordinate[p][axis]) for axis in (0, 1, 2)]
                       for p in positions])
    return images, labels
195 |
196 |
def visualize(data):
    """Show each sample in turn, printing its index, coordinate and orientation first."""
    for idx, image in enumerate(data.img):
        print("%s %s %s" % (idx, data.coordinate[idx], data.orientation[idx]))
        plt.imshow(image, cmap='gray')
        plt.title(str(idx))
        plt.show()
203 |
204 |
def _load_dataset(path):
    """Unpickle a TrainingData object from ``path``, closing the file afterwards.

    NOTE: pickle can execute arbitrary code; only load files you created yourself.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def train_net():
    """Train CVGraph on the pickled training set, validating after every epoch.

    Each epoch does the following:
      * Runs over all full BATCH_SIZE batches in order.
      * Saves a checkpoint to SAVE_PATH + 'cv_graph.cptk'; the RL node restores
        this file as pretrained weights.
      * Prints the validation loss and the mean training loss.

    Raises:
        ValueError: if the training set holds fewer than BATCH_SIZE samples.
    """
    # kept module-global (as before) so the session stays reachable after training
    global sess
    sess = tf.Session()

    graph = CVGraph(IMG_HEIGHT, IMG_WIDTH, CHANNELS, OUTPUT_SIZE)
    graph.buildGraph()

    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()

    # data containing current camera image, distance to target, orientation
    data = _load_dataset(FILENAME_TRAINING_DATA)
    test_data = _load_dataset(FILENAME_TEST_DATA)

    data = process(data)
    test_data = process(test_data)

    # trailing samples that do not fill a whole batch are skipped
    num_batches = len(data.img) // BATCH_SIZE
    if num_batches == 0:
        raise ValueError("training set has %d samples, fewer than one batch of %d"
                         % (len(data.img), BATCH_SIZE))

    global_step = 1

    for epoch in range(NUM_EPOCHS):
        print("epoch %d" % epoch)
        total_loss = 0
        for batch in range(num_batches):
            images, labels = sampleData(data, batch * BATCH_SIZE, BATCH_SIZE)
            total_loss += graph.train(sess, images, labels, global_step)
            global_step += 1

        avg_loss = total_loss / float(num_batches)

        saver.save(sess, SAVE_PATH + 'cv_graph.cptk')
        test_loss = test_net(sess, graph, global_step, test_data)

        print("validation Loss: %s" % test_loss)
        print("training Loss: %s" % avg_loss)
245 |
246 |
247 | def test_net(sess, graph, global_step, test_data):
248 | num_batches = int(len(test_data.img) / BATCH_SIZE)
249 | index = 0
250 | loss = 0
251 | for i in range(num_batches):
252 | images, labels = sampleData(test_data, index, BATCH_SIZE)
253 | loss += graph.validate(sess, images, labels, global_step)
254 | index += BATCH_SIZE
255 |
256 | return loss/num_batches
257 |
258 |
if __name__ == "__main__":
    # Entry point: train and validate the supervised CV network.
    print("starting...")
    train_net()
262 |
--------------------------------------------------------------------------------
/scripts/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from ops import conv2d, dense, variable_summaries
4 |
5 | class ComputationalGraph:
6 | def __init__(self, img_h, img_w, img_c):
7 | self.image_height, self.image_width, self.image_channels = img_h, img_w, img_c
8 |
9 | #learning rates suggested in A3C paper
10 | self.lrate = 2.5e-3
11 | self.learning_rate_minimum = 5.0e-4
12 |
13 | self.global_step = tf.Variable(0, trainable=False)
14 | self.decay_step = 40
15 | self.dropout_keep_prob = 0.5
16 |
    def constructGraph(self, sess, action_num, action_size):
        """Build the actor-critic graph: shared CNN features, policy and value heads, optimizer.

        Args:
            sess: tf.Session, passed to constructSummary, which writes sess.graph.
            action_num: number of independent discrete action dimensions.
            action_size: number of choices per action dimension.

        The variables under the "CV_graph" scope (conv1-conv4, dense1, dense2)
        must keep their names. The RL node restores this scope from the
        pretrained checkpoint ./cv_graph/cv_graph.cptk.
        """
        # episode_rewards is only consumed by constructSummary (variable_summaries)
        self.episode_rewards = tf.placeholder(tf.float32, [None], name="episode_rewards")
        # batch of stacked input images
        self.state = tf.placeholder(tf.float32, [None, self.image_height, self.image_width, self.image_channels], name="state")

        ## ConvNet for feature extraction (shared by the policy and value heads)
        with tf.variable_scope("CV_graph"):

            initializer = tf.contrib.layers.xavier_initializer()
            activation_fn = tf.nn.relu

            # conv2d (from ops) returns (output, weights, bias)
            self.conv1, self.conv1_w, self.conv1_b = conv2d(self.state,
                                                            16,
                                                            [5, 5],
                                                            [1, 1],
                                                            initializer,
                                                            activation_fn,
                                                            'NHWC',
                                                            name='conv1')


            self.conv2, self.conv2_w, self.conv2_b = conv2d(self.conv1,
                                                            16,
                                                            [5, 5],
                                                            [1, 1],
                                                            initializer,
                                                            activation_fn,
                                                            'NHWC',
                                                            name='conv2')

            # 2x2 max pooling with stride 2
            self.max_pool1 = tf.nn.max_pool(self.conv2,
                                            [1, 2, 2, 1],
                                            [1, 2, 2, 1],
                                            padding='VALID',
                                            name='max_pool1')

            self.conv3, self.conv3_w, self.conv3_b = conv2d(self.max_pool1,
                                                            32,
                                                            [3, 3],
                                                            [1, 1],
                                                            initializer,
                                                            activation_fn,
                                                            'NHWC',
                                                            name='conv3')

            self.conv4, self.conv4_w, self.conv4_b = conv2d(self.conv3,
                                                            32,
                                                            [3, 3],
                                                            [1, 1],
                                                            initializer,
                                                            activation_fn,
                                                            'NHWC',
                                                            name='conv4')

            self.max_pool2 = tf.nn.max_pool(self.conv4,
                                            [1, 2, 2, 1],
                                            [1, 2, 2, 1],
                                            padding='VALID',
                                            name='max_pool2')

            self.flat = tf.contrib.layers.flatten(self.max_pool2)

            # NOTE(review): dropout uses the fixed Python float self.dropout_keep_prob
            # rather than a placeholder. It is therefore also active when
            # calculateAction/calculateReward evaluate the network; confirm this
            # is intended.
            self.dense1, self.dense1_w, self.dense1_b = dense(inputs=self.flat, units=100,
                                                              activation=tf.nn.relu, name='dense1')
            self.dropout1 = tf.nn.dropout(self.dense1, self.dropout_keep_prob)
            self.dense2, self.dense2_w, self.dense2_b = dense(inputs=self.dropout1, units=50,
                                                              activation=tf.nn.relu, name='dense2')
            self.dropout2 = tf.nn.dropout(self.dense2, self.dropout_keep_prob)

        with tf.variable_scope("Agent"):
            ## policy network

            self.po_dense3, self.po_dense3_w, self.po_dense3_b = dense(inputs=self.dropout2, units=10,
                                                                       activation=tf.nn.relu, name='po_dense3')

            self.po_dense4, self.po_dense4_w, self.po_dense4_b = dense(inputs=self.po_dense3, units=action_num*action_size, activation=None, name='po_dense4')
            # softmax over the last axis: the action_size choices of each of the action_num dimensions
            self.po_probabilities = tf.nn.softmax(tf.reshape(self.po_dense4, [-1, action_num, action_size]))

            # taken actions (presumably the one-hot matrices stored in History) and their returns
            self.po_prev_actions = tf.placeholder(tf.float32, [None, action_num, action_size], name="po_prev_action")
            self.po_return = tf.placeholder(tf.float32, [None, 1], name="po_return")
            # log-probability of each taken choice, shape [batch, action_num], scaled by the return
            self.po_eligibility = tf.log(tf.reduce_sum(tf.multiply(self.po_prev_actions, self.po_probabilities), axis=-1)) * self.po_return
            self.po_loss = -tf.reduce_sum(self.po_eligibility)

            ## value network

            self.v_dense3, self.v_dense3_w, self.v_dense3_b = dense(inputs=self.dropout2, units=10, activation=tf.nn.relu, name='v_dense3')

            self.v_output, self.v_output_w, self.v_output_b = dense(inputs=self.v_dense3, activation=None, units=1, name='v_output')

            self.v_actual_return = tf.placeholder(tf.float32, [None, 1], name="v_actual_return")
            # tf.nn.l2_loss computes sum(x ** 2) / 2
            self.v_loss = tf.nn.l2_loss(tf.subtract(self.v_output, self.v_actual_return))

            ## optimization
            # NOTE(review): confirm this section's placement relative to the
            # "Agent" scope; it only affects op/variable names, not the computation.

            # lrate decays by a factor of 0.98 every decay_step steps of global_step
            # (staircase), clipped below at learning_rate_minimum
            self.learning_rate_op = tf.maximum(self.learning_rate_minimum,
                                               tf.train.exponential_decay(
                                                   self.lrate,
                                                   self.global_step,
                                                   self.decay_step,
                                                   0.98,
                                                   staircase=True))

            # critic loss weighted by 0.5 relative to the policy loss
            self.loss = 0.5 * self.v_loss + self.po_loss

            # NOTE(review): global_step is not passed to minimize(), so it only
            # changes if a caller feeds or assigns it; confirm updatePolicy does.
            self.optimizer = tf.train.AdamOptimizer(self.learning_rate_op).minimize(self.loss)

        self.constructSummary(sess)
123 |
def constructSummary(self, sess):
    """Register TensorBoard summaries and open the training log writer.

    Attaches the summaries produced by ``variable_summaries`` to
    ``self.episode_rewards``, merges every summary op of the default graph
    into ``self.merged`` and stores in ``self.train_writer`` a FileWriter
    that writes to ``./log/train`` together with ``sess.graph``.

    Args:
        sess: TensorFlow session whose graph is written to the event log.
    """
    variable_summaries(self.episode_rewards)
    merged_op = tf.summary.merge_all()
    self.merged = merged_op
    log_dir = './log/train'
    self.train_writer = tf.summary.FileWriter(log_dir, sess.graph)
128 |
129 |
def calculateAction(self, sess, state):
    """Evaluate the policy head for ``state``.

    Args:
        sess: TensorFlow session holding the model graph.
        state: value fed to the ``self.state`` placeholder.

    Returns:
        Whatever ``sess.run`` yields for ``self.po_probabilities``: the
        softmax over the graph tensor reshaped to
        [-1, action_num, action_size].
    """
    feed = {self.state: state}
    probabilities = sess.run(self.po_probabilities, feed_dict=feed)
    return probabilities
132 |
def calculateReward(self, sess, state):
    """Return the critic's value estimate for ``state`` as a scalar.

    Despite the name this evaluates ``self.v_output`` (the value head, one
    unit per sample) rather than an environment reward. Only element
    ``[0][0]`` is returned, so ``state`` presumably holds a single sample;
    TODO confirm against the History caller.
    """
    values = sess.run(self.v_output, feed_dict={self.state: state})
    first_row = values[0]
    return first_row[0]
136 |
def updatePolicy(self, sess, history, step, discount=0.98):
    """Run one actor-critic update over a recorded episode.

    For every step ``i`` the discounted return
    ``G_i = r_i + discount * G_{i+1}`` is computed. ``G_i`` is the critic
    target (fed to ``self.v_actual_return``). ``G_i - V(s_i)`` is the
    policy weight (fed to ``self.po_return``), where ``V`` comes from
    ``self.calculateReward``. Both losses are minimised in a single
    ``sess.run`` of ``self.optimizer``. The merged summaries are then
    written to ``self.train_writer`` at ``step``.

    Args:
        sess: TensorFlow session holding the model graph.
        history: episode buffer providing ``getRewardHistory()``,
            ``getState(i)``, ``getStates()`` and ``getActions()``.
        step: fed to ``self.global_step`` (drives the learning-rate decay)
            and used as the summary step.
        discount: reward discount factor; defaults to the previously
            hard-coded 0.98.
    """
    rewards = history.getRewardHistory()
    num_steps = len(rewards)

    # One backward pass gives every discounted return in O(n). This replaces
    # the O(n^2) nested loop, which also depended on Python-2-only xrange.
    returns = [0.0] * num_steps
    running = 0.0
    for i in reversed(range(num_steps)):
        running = rewards[i] + discount * running
        returns[i] = running

    # Advantage = observed return minus the critic's current estimate.
    advantages = []
    for i, future_reward in enumerate(returns):
        prediction = self.calculateReward(sess, history.getState(i))
        advantages.append(future_reward - prediction)

    statistics, _ = sess.run(
        [self.merged, self.optimizer],
        feed_dict={self.state: history.getStates(),
                   self.po_prev_actions: history.getActions(),
                   self.po_return: np.expand_dims(advantages, axis=1),
                   self.episode_rewards: rewards,
                   self.v_actual_return: np.expand_dims(returns, axis=1),
                   self.global_step: step})

    self.train_writer.add_summary(statistics, step)
    self.train_writer.flush()
--------------------------------------------------------------------------------
/scripts/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.layers.python.layers import initializers
3 |
def clipped_error(x):
    """Element-wise Huber loss with delta = 1.

    Returns ``0.5 * x**2`` where ``|x| < 1`` and ``|x| - 0.5`` elsewhere.

    Args:
        x: error tensor.
    """
    # tf.select was renamed tf.where in TensorFlow 1.0; use whichever exists.
    # The previous bare ``except:`` also swallowed genuine errors (e.g. a bad
    # ``x``, KeyboardInterrupt) and silently retried.
    select = getattr(tf, 'select', None)
    if select is None:
        select = tf.where
    abs_x = tf.abs(x)
    return select(abs_x < 1.0, 0.5 * tf.square(x), abs_x - 0.5)
10 |
11 |
def variable_summaries(var):
    """Attach TensorBoard summaries describing the tensor ``var``.

    Inside the ``summaries`` name scope this records scalar summaries for
    the mean, the population standard deviation (built inside a nested
    ``stddev`` scope), the max, min and sum of ``var``, plus a histogram of
    its values.
    """
    with tf.name_scope('summaries'):
        average = tf.reduce_mean(var)
        tf.summary.scalar('mean', average)

        with tf.name_scope('stddev'):
            deviation = tf.sqrt(tf.reduce_mean(tf.square(var - average)))
        tf.summary.scalar('stddev', deviation)

        for label, reducer in (('max', tf.reduce_max), ('min', tf.reduce_min)):
            tf.summary.scalar(label, reducer(var))
        tf.summary.histogram('histogram', var)
        tf.summary.scalar('sum', tf.reduce_sum(var))
24 |
def conv2d(x,
           output_dim,
           kernel_size,
           stride,
           initializer=tf.contrib.layers.xavier_initializer(),
           activation_fn=tf.nn.relu,
           data_format='NHWC',
           padding='VALID',
           name='conv2d'):
    """2-D convolution + bias + optional activation in its own variable scope.

    Args:
        x: input tensor laid out according to ``data_format``.
        output_dim: number of output channels.
        kernel_size: (height, width) of the kernel.
        stride: (height, width) stride.
        initializer: initializer for the kernel variable ``w``.
        activation_fn: applied to the biased output; ``None`` skips it.
        data_format: 'NHWC' (channels last) or 'NCHW' (channels first).
        padding: padding mode passed to ``tf.nn.conv2d``.
        name: variable scope that holds ``w`` and ``biases``.

    Returns:
        (output, w, b): output tensor, kernel variable and bias variable.

    Raises:
        ValueError: if ``data_format`` is neither 'NHWC' nor 'NCHW'. This
            previously surfaced as an UnboundLocalError on ``kernel_shape``.
    """
    # Expand the 2-D stride to the 4-D form and read the input channel count
    # from the axis that holds channels in the chosen layout.
    if data_format == 'NCHW':
        strides = [1, 1, stride[0], stride[1]]
        in_channels = x.get_shape()[1]
    elif data_format == 'NHWC':
        strides = [1, stride[0], stride[1], 1]
        in_channels = x.get_shape()[-1]
    else:
        raise ValueError("conv2d: unsupported data_format %r "
                         "(expected 'NHWC' or 'NCHW')" % (data_format,))
    kernel_shape = [kernel_size[0], kernel_size[1], in_channels, output_dim]

    with tf.variable_scope(name):
        w = tf.get_variable('w', kernel_shape, tf.float32, initializer=initializer)
        conv = tf.nn.conv2d(x, w, strides, padding, data_format=data_format)

        b = tf.get_variable('biases', [output_dim],
                            initializer=tf.constant_initializer(0.0))
        out = tf.nn.bias_add(conv, b, data_format)

        if activation_fn is not None:
            out = activation_fn(out)

    return out, w, b
52 |
def dense(inputs, activation, units, initializer=tf.contrib.layers.xavier_initializer(), name='dense'):
    """Fully connected layer: ``activation(inputs @ weights + bias)``.

    Args:
        inputs: rank-2 tensor [batch, features]; the feature dimension must
            be statically known because it sizes ``weights``.
        activation: callable applied to the output, or ``None`` for a linear
            layer.
        units: number of output units.
        initializer: initializer for ``weights`` (``bias`` starts at 0).
        name: variable scope that holds ``weights`` and ``bias``.

    Returns:
        (output, weights, bias).

    Raises:
        ValueError: if ``inputs`` is not rank 2 or its feature dimension is
            unknown. Such inputs previously failed with an IndexError or an
            opaque variable-shape error.
    """
    shape = inputs.get_shape().as_list()
    if len(shape) != 2 or shape[1] is None:
        raise ValueError('dense: expected a rank-2 input with a known feature '
                         'dimension, got shape %s' % (shape,))

    with tf.variable_scope(name):
        w = tf.get_variable('weights', [shape[1], units], tf.float32,
                            initializer=initializer)
        b = tf.get_variable('bias', [units],
                            initializer=tf.constant_initializer(0.0))

        out = tf.nn.bias_add(tf.matmul(inputs, w), b)

        if activation is not None:
            out = activation(out)

    return out, w, b
68 |
--------------------------------------------------------------------------------
/src/EnvironmentTracker_node.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | #include "ros/ros.h"
11 | #include "mav_msgs/default_topics.h"
12 | #include "mav_msgs/RollPitchYawrateThrust.h"
13 | #include "mav_msgs/Actuators.h"
14 | #include "mav_msgs/eigen_mav_msgs.h"
15 | #include "geometry_msgs/Pose.h"
16 | #include "std_srvs/Empty.h"
17 | #include "gazebo_msgs/GetModelState.h"
18 | #include "gazebo_msgs/SetModelState.h"
19 | #include "gazebo_msgs/ContactsState.h"
20 | #include "rotors_comm/WindSpeed.h"
21 | #include "gazebo_msgs/ModelState.h"
22 |
23 | #include
24 |
25 | #include "rotors_reinforce/PerformAction.h"
26 | #include "rotors_reinforce/GetState.h"
27 |
28 | #include
29 | #include
30 |
// Scale factors that performAction applies to the agent's action components
// when building a RollPitchYawrateThrust command.
constexpr double max_v_xy     = 1.0;                   // [m/s] (not referenced in this file)
constexpr double max_roll     = 10.0 * M_PI / 180.0;   // [rad]
constexpr double max_pitch    = 10.0 * M_PI / 180.0;   // [rad]
constexpr double max_rate_yaw = 45.0 * M_PI / 180.0;   // [rad/s]
constexpr double max_thrust   = 30.0;                  // [N]

// Upper bound of a random wind gust in m/s (3.4 m/s ~ 12 km/h, wind force 3).
constexpr double MAX_WIND_VELOCTIY = 3.4;

// Sign applied when mapping each action axis onto the command axis.
constexpr double axes_roll_direction   = -1.0;
constexpr double axes_pitch_direction  =  1.0;
constexpr double axes_thrust_direction =  1.0;
42 |
43 | class environmentTracker {
44 |
45 | private:
46 | ros::NodeHandle n;
47 | ros::Publisher firefly_control_pub;
48 | ros::Publisher firefly_wind_pub;
49 |
50 | ros::ServiceClient firefly_reset_client;
51 | ros::ServiceClient get_position_client;
52 | ros::ServiceClient pause_physics;
53 | ros::ServiceClient unpause_physics;
54 | ros::ServiceClient set_state;
55 |
56 |
57 | ros::Subscriber firefly_position_sub;
58 | ros::Subscriber firefly_collision_sub;
59 | ros::Subscriber firefly_ground_collision_sub;
60 | ros::Subscriber firefly_camera_sub;
61 | ros::Subscriber firefly_camera_depth_sub;
62 | ros::ServiceServer perform_action_srv;
63 | ros::ServiceServer get_state_srv;
64 |
65 | std::default_random_engine re;
66 |
67 | sensor_msgs::Image current_img;
68 | sensor_msgs::Image current_img_depth;
69 |
70 | int step_counter;
71 | double current_yaw_vel_;
72 | bool crashed_flag;
73 | bool random_target;
74 | bool enable_wind;
75 |
76 | public:
77 | std::vector current_position;
78 | std::vector current_orientation;
79 | std::vector current_control_params;
80 |
81 | std::vector target_position;
82 |
83 | environmentTracker(ros::NodeHandle node, const std::vector target_pos, bool wind ) {
84 | current_position.resize(3);
85 | //hard constants for target
86 | if (target_pos[0] != 0.0 || target_pos[1] != 0.0 || target_pos[2] != 0.0) {
87 | target_position = target_pos;
88 | random_target = false;
89 | }
90 | else {
91 | target_position = {3.0, 1.0, 7.5};
92 | random_target = true;
93 | }
94 | enable_wind = wind;
95 | current_orientation.resize(4);
96 | current_control_params.resize(4, 0);
97 | step_counter = 0;
98 |
99 | current_yaw_vel_ = 0.0;
100 |
101 | n = node;
102 | firefly_control_pub = n.advertise("/firefly/command/roll_pitch_yawrate_thrust", 1000);
103 | firefly_wind_pub = n.advertise("/firefly/wind_speed", 1000);
104 | firefly_collision_sub = n.subscribe("/rotor_collision", 100, &environmentTracker::onCollision, this);
105 | firefly_ground_collision_sub = n.subscribe("/base_collision", 100, &environmentTracker::onCollisionGround, this);
106 | firefly_reset_client = n.serviceClient("/gazebo/reset_world");
107 | firefly_camera_sub = n.subscribe("/firefly/vi_sensor/camera_depth/camera/image_raw", 1, &environmentTracker::getImage, this);
108 | //firefly_camera_depth_sub = n.subscribe("/firefly/vi_sensor/camera_depth/depth/disparity", 1, &environmentTracker::getImageDepth, this);
109 |
110 | pause_physics = n.serviceClient("/gazebo/pause_physics");
111 | unpause_physics = n.serviceClient("/gazebo/unpause_physics");
112 | set_state = n.serviceClient("/gazebo/set_model_state");
113 |
114 | get_position_client = n.serviceClient("/gazebo/get_model_state");
115 | perform_action_srv = n.advertiseService("env_tr_perform_action", &environmentTracker::performAction, this);
116 | get_state_srv = n.advertiseService("env_tr_get_state", &environmentTracker::getState, this);
117 |
118 | gazebo_msgs::ModelState modelstate;
119 | gazebo_msgs::SetModelState set_state_srv;
120 | modelstate.model_name = "firefly";
121 | modelstate.pose.position.x = 0.0;
122 | modelstate.pose.position.y = 0.0;
123 | modelstate.pose.position.z = 5.0;
124 | set_state_srv.request.model_state = modelstate;
125 | set_state.call(set_state_srv);
126 | }
127 |
128 | double round_number(double number){
129 | return round( number * 100.0 ) / 100.0;
130 | }
131 |
132 | void getImage(const sensor_msgs::ImageConstPtr& msg)
133 | {
134 |
135 | current_img = *msg;
136 | }
137 |
138 | void getImageDepth(const sensor_msgs::ImageConstPtr& msg)
139 | {
140 |
141 | current_img_depth = *msg;
142 | }
143 |
144 | void setCrossPosition(std::vector pos) {
145 | gazebo_msgs::ModelState modelstate;
146 |
147 | gazebo_msgs::SetModelState set_state_srv;
148 |
149 | modelstate.model_name = "small box";
150 | modelstate.pose.position.x = pos[0];
151 | modelstate.pose.position.y = pos[1];
152 | modelstate.pose.position.z = 0.08;
153 |
154 | set_state_srv.request.model_state = modelstate;
155 |
156 | if (!set_state.call(set_state_srv)) {
157 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z);
158 | ROS_ERROR("Failed to set position");
159 | }
160 |
161 | modelstate.model_name = "small box_0";
162 | modelstate.pose.position.x = pos[0];
163 | modelstate.pose.position.y = pos[1] + 0.2;
164 | modelstate.pose.position.z = 0.08;
165 |
166 | set_state_srv.request.model_state = modelstate;
167 |
168 | if (!set_state.call(set_state_srv)) {
169 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z);
170 | ROS_ERROR("Failed to set position");
171 | }
172 |
173 | modelstate.model_name = "small box_1";
174 | modelstate.pose.position.x = pos[0];
175 | modelstate.pose.position.y = pos[1] - 0.2;
176 | modelstate.pose.position.z = 0.08;
177 |
178 | set_state_srv.request.model_state = modelstate;
179 |
180 | if (!set_state.call(set_state_srv)) {
181 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z);
182 | ROS_ERROR("Failed to set position");
183 | }
184 |
185 | modelstate.model_name = "small box_2";
186 | modelstate.pose.position.x = pos[0] + 0.31;
187 | modelstate.pose.position.y = pos[1];
188 | modelstate.pose.position.z = 0.08;
189 |
190 | set_state_srv.request.model_state = modelstate;
191 |
192 | if (!set_state.call(set_state_srv)) {
193 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z);
194 | ROS_ERROR("Failed to set position");
195 | }
196 |
197 | modelstate.model_name = "small box_3";
198 | modelstate.pose.position.x = pos[0] - 0.31;
199 | modelstate.pose.position.y = pos[1];
200 | modelstate.pose.position.z = 0.08;
201 |
202 | set_state_srv.request.model_state = modelstate;
203 |
204 | if (!set_state.call(set_state_srv)) {
205 | //ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z);
206 | ROS_ERROR("Failed to set position");
207 | }
208 | }
209 |
210 | void respawn() {
211 | mav_msgs::RollPitchYawrateThrust msg;
212 | msg.roll = 0;
213 | msg.pitch = 0;
214 | msg.yaw_rate = 0;
215 | msg.thrust.z = 0;
216 | std_srvs::Empty srv;
217 |
218 | current_position = {0,0,0};
219 | current_orientation = {0,0,0,0};
220 | step_counter = 0;
221 |
222 | if (random_target) {
223 | std::uniform_real_distribution unif(0, 8);
224 |
225 | target_position = {round_number(unif(re))-4, round_number(unif(re))-4, 7.5};//round_number(unifz(re))};
226 | }
227 |
228 | current_control_params.resize(4, 0);
229 | firefly_reset_client.call(srv);
230 | firefly_control_pub.publish(msg);
231 | setCrossPosition(target_position);
232 | gazebo_msgs::ModelState modelstate;
233 | gazebo_msgs::SetModelState set_state_srv;
234 | modelstate.model_name = "firefly";
235 | modelstate.pose.position.x = 0.0;
236 | modelstate.pose.position.y = 0.0;
237 | modelstate.pose.position.z = 5.0;
238 | set_state_srv.request.model_state = modelstate;
239 | set_state.call(set_state_srv);
240 | }
241 |
242 | void pausePhysics() {
243 | std_srvs::Empty srv;
244 | pause_physics.call(srv);
245 | }
246 |
247 | void unpausePhysics() {
248 | std_srvs::Empty srv;
249 | unpause_physics.call(srv);
250 | }
251 |
252 | void onCollisionGround(const gazebo_msgs::ContactsState::ConstPtr& msg) {
253 | if ((step_counter > 5) && (current_position[2] < 0.5 || current_position[2] > 12.5 || msg->states.size() > 0)) {
254 | ROS_INFO("Crash, respawn...");
255 | step_counter = 0;
256 | crashed_flag = true;
257 | respawn();
258 | }
259 | }
260 |
261 | void onCollision(const gazebo_msgs::ContactsState::ConstPtr& msg) {
262 | if ((step_counter > 5) && (current_position[2] < 0.5 || current_position[2] > 12.5 || msg->states.size() > 0)) {
263 | ROS_INFO("Crash, respawn...");
264 | step_counter = 0;
265 | crashed_flag = true;
266 | respawn();
267 | }
268 | }
269 |
270 | bool performAction(rotors_reinforce::PerformAction::Request &req, rotors_reinforce::PerformAction::Response &res) {
271 | //respawn code
272 | if (req.action[3] == 42) {
273 | respawn();
274 | return true;
275 | }
276 |
277 | if ((step_counter > 5) && (current_position[2] < 0.5 || current_position[2] > 12.5)) {
278 | ROS_INFO("Crash, respawn...");
279 | step_counter = 0;
280 | crashed_flag = true;
281 | respawn();
282 | }
283 |
284 | mav_msgs::RollPitchYawrateThrust msg;
285 | msg.roll = req.action[0] * max_roll * axes_roll_direction;
286 | msg.pitch = req.action[1] * max_pitch * axes_pitch_direction;
287 |
288 | if(req.action[2] > 0.01) {
289 | current_yaw_vel_ = max_rate_yaw;
290 | }
291 | else if (req.action[2] < -0.01) {
292 | current_yaw_vel_ = max_rate_yaw;
293 | }
294 | else {
295 | current_yaw_vel_ = 0.0;
296 | }
297 |
298 | msg.yaw_rate = current_yaw_vel_;
299 | msg.thrust.z = req.action[3] * max_thrust * axes_thrust_direction;
300 |
301 |
302 | ROS_INFO("roll: %f, pitch: %f, yaw_rate: %f, thrust %f", msg.roll, msg.pitch, msg.yaw_rate, msg.thrust.z);
303 |
304 | if (enable_wind && step_counter % 20 == 0) { //new wind after every 20 steps (1 second)
305 | rotors_comm::WindSpeed wind_msg;
306 | std::uniform_real_distribution unif(0, 1);
307 | if (unif(re) > 0.75) { //probability of 75% for no wind
308 | double rand_number = unif(re);
309 | if (rand_number < 0.34) { //wind comes only from one direction
310 | wind_msg.velocity.z = 0;
311 | wind_msg.velocity.y = 0;
312 | wind_msg.velocity.x = unif(re) * MAX_WIND_VELOCTIY;
313 | } else if (rand_number < 0.67) {
314 | wind_msg.velocity.x = 0;
315 | wind_msg.velocity.z = 0;
316 | wind_msg.velocity.y = unif(re) * MAX_WIND_VELOCTIY;
317 | } else {
318 | wind_msg.velocity.x = 0;
319 | wind_msg.velocity.y = 0;
320 | wind_msg.velocity.z = unif(re) * MAX_WIND_VELOCTIY;
321 | }
322 | ROS_INFO("wind with velocity of x: %f, y: %f, z: %f", wind_msg.velocity.x, wind_msg.velocity.y, wind_msg.velocity.z);
323 | }
324 | else {
325 | ROS_INFO("no wind");
326 | wind_msg.velocity.x = 0;
327 | wind_msg.velocity.y = 0;
328 | wind_msg.velocity.z = 0;
329 | }
330 | firefly_wind_pub.publish(wind_msg);
331 | }
332 | unpausePhysics();
333 | unpausePhysics();
334 | firefly_control_pub.publish(msg);
335 | ros::Duration(0.05).sleep(); //sleep 50ms of simulation time
336 | getPosition();
337 | pausePhysics();
338 | step_counter++;
339 |
340 | res.target_position = target_position;
341 | res.position = current_position;
342 | res.orientation = current_orientation;
343 | res.img = current_img;
344 | res.img_depth = current_img_depth;
345 |
346 | res.reward = getReward(step_counter);
347 | res.crashed = false;
348 |
349 | //crash check at the end
350 | if(crashed_flag) {
351 | res.crashed = true;
352 | crashed_flag = false;
353 | }
354 |
355 | return true;
356 | }
357 |
358 | bool getState(rotors_reinforce::GetState::Request &req, rotors_reinforce::GetState::Response &res) {
359 | getPosition();
360 | res.target_position = target_position;
361 | res.position = current_position;
362 | res.orientation = current_orientation;
363 | res.img = current_img;
364 | res.img_depth = current_img_depth;
365 |
366 | res.reward = 0.0;
367 | res.crashed = crashed_flag;
368 | return true;
369 | }
370 |
371 | void getPosition() {
372 | gazebo_msgs::GetModelState srv;
373 | srv.request.model_name = "firefly";
374 | if (get_position_client.call(srv)) {
375 | ROS_INFO("Position: %f %f %f", (float)srv.response.pose.position.x, (float)srv.response.pose.position.y, (float)srv.response.pose.position.z);
376 |
377 | current_position[0] = (double)srv.response.pose.position.x;
378 | current_position[1] = (double)srv.response.pose.position.y;
379 | current_position[2] = (double)srv.response.pose.position.z;
380 | current_orientation[0] = (double)srv.response.pose.orientation.x;
381 | current_orientation[1] = (double)srv.response.pose.orientation.y;
382 | current_orientation[2] = (double)srv.response.pose.orientation.z;
383 | current_orientation[3] = (double)srv.response.pose.orientation.w;
384 | }
385 | else {
386 | ROS_ERROR("Failed to get position");
387 | }
388 | }
389 |
390 | double getReward(const int count) {
391 | double difx = current_position[0] - target_position[0];
392 | double dify = current_position[1] - target_position[1];
393 | double difz = current_position[2] - target_position[2];
394 |
395 | double current_distance = std::sqrt(difx * difx + dify * dify + 2.0 * difz * difz);
396 |
397 | double reward = 0.0;
398 |
399 | if (crashed_flag) {
400 | return 0.0;
401 | }
402 |
403 | double reward4position = 1/(current_distance + 1.0);
404 | //double reward4orientation = 1/((current_orientation[0] * current_orientation[0] + current_orientation[1] * current_orientation[1] + current_orientation[2] * current_orientation[2])/(current_orientation[3] * current_orientation[3]) + 1);
405 |
406 | reward = reward4position;//*0.95 + 0.05 * reward4orientation;
407 |
408 | return reward;
409 | }
410 |
411 | };
412 |
413 |
414 |
415 | int main(int argc, char **argv)
416 | {
417 | ros::init(argc, argv, "talker");
418 |
419 | ros::NodeHandle n;
420 |
421 | ros::Rate loop_rate(100);
422 |
423 | std::vector target_position = {0.0, 0.0, 0.0};
424 | bool wind = false;
425 |
426 | if(argc > 1) {
427 | target_position[0] = atoi(argv[1]);
428 | target_position[1] = atoi(argv[2]);
429 | target_position[2] = atoi(argv[3]);
430 | wind = atoi(argv[4]);
431 | }
432 |
433 | environmentTracker tracker(n, target_position, wind);
434 |
435 | ROS_INFO("Comunication node ready");
436 |
437 | ros::spin();
438 |
439 | //delete tracker;
440 | return 0;
441 | }
442 |
--------------------------------------------------------------------------------
/srv/GetState.srv:
--------------------------------------------------------------------------------
1 | ---
2 | float64[] position
3 | float64[] velocity
4 | float64[] target_position
5 | float64[] orientation
6 | sensor_msgs/Image img
7 | sensor_msgs/Image img_depth
8 | float64 reward
9 | bool crashed
10 |
--------------------------------------------------------------------------------
/srv/PerformAction.srv:
--------------------------------------------------------------------------------
1 | float64[] action
2 | ---
3 | float64[] position
4 | float64[] target_position
5 | float64[] orientation
6 | sensor_msgs/Image img
7 | sensor_msgs/Image img_depth
8 | float64 reward
9 | bool crashed
10 |
--------------------------------------------------------------------------------