├── .gitignore
├── README.md
├── assets
│   ├── result.gif
│   └── result.png
├── ddqn.py
├── dqn.py
└── saved_networks
    └── Breakout-v0
        ├── Breakout-v0-4500000
        ├── Breakout-v0-4500000.meta
        └── checkpoint
/.gitignore:
--------------------------------------------------------------------------------
1 | summary
2 | *.py[cod]
3 | .DS_Store
4 | .Python
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DQN in Keras + TensorFlow + OpenAI Gym
2 | This is an implementation of DQN (based on [Mnih et al., 2015](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)) in Keras + TensorFlow + OpenAI Gym.
3 |
4 | ## Requirements
5 | - gym (Atari environment)
6 | - scikit-image
7 | - keras
8 | - tensorflow
9 |
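The dependencies can be installed with pip (a sketch, assuming a working Python environment; you may need to pin versions compatible with the era of this code):

```
pip install 'gym[atari]' scikit-image keras tensorflow
```
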
10 | ## Results
11 | This is the result of training DQN for about 28 hours (12K episodes, 4.7 million frames) on an AWS EC2 g2.2xlarge instance.
12 |
13 | 
14 |
15 |
16 | Statistics of average loss, average max Q value, duration, and total reward per episode.
17 |
18 | 
19 |
20 | ## Usage
21 | #### Training
22 | For DQN, run:
23 |
24 | ```
25 | python dqn.py
26 | ```
27 |
28 | For Double DQN, run:
29 |
30 | ```
31 | python ddqn.py
32 | ```
33 |
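#### Testing
There is no command-line switch between training and testing; instead, flip the module-level flags near the top of `dqn.py` / `ddqn.py` (shown here as a sketch):

```
TRAIN = False        # play NUM_EPISODES_AT_TEST episodes with a mostly greedy policy
LOAD_NETWORK = True  # restore the latest checkpoint from saved_networks/<env>
```

Then run the script as above; the agent renders each episode. A pretrained Breakout checkpoint is included under `saved_networks/Breakout-v0`.
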
34 | #### Visualizing learning with TensorBoard
35 | Run the following:
36 |
37 | ```
38 | tensorboard --logdir=summary/
39 | ```
40 |
41 | ## Using GPU
42 | I built an AMI for this experiment. All of the requirements, plus CUDA and cuDNN, are pre-installed in the AMI.
43 | The AMI name is `DQN-AMI`, the ID is `ami-c4a969a9`, and the region is N. Virginia. Feel free to use it.
44 |
45 | ## ToDo
46 | - [RMSPropGraves](http://arxiv.org/abs/1308.0850)
47 | - [Dueling Network](https://arxiv.org/abs/1511.06581)
48 |
49 | ## References
50 | - [Mnih et al., 2013, Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602)
51 | - [Mnih et al., 2015, Human-level control through deep reinforcement learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
52 | - [van Hasselt et al., 2016, Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461)
53 | - [devsisters/DQN-tensorflow](https://github.com/devsisters/DQN-tensorflow)
54 | - [spragunr/deep_q_rl](https://github.com/spragunr/deep_q_rl)
55 |
--------------------------------------------------------------------------------
/assets/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tokb23/dqn/23f410d366facf066ede9dbb4d940a9fdb83fd81/assets/result.gif
--------------------------------------------------------------------------------
/assets/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tokb23/dqn/23f410d366facf066ede9dbb4d940a9fdb83fd81/assets/result.png
--------------------------------------------------------------------------------
/ddqn.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 |
3 | import os
4 | import gym
5 | import random
6 | import numpy as np
7 | import tensorflow as tf
8 | from collections import deque
9 | from skimage.color import rgb2gray
10 | from skimage.transform import resize
11 | from keras.models import Sequential
12 | from keras.layers import Convolution2D, Flatten, Dense
13 |
14 | ENV_NAME = 'Breakout-v0' # Environment name
15 | FRAME_WIDTH = 84 # Resized frame width
16 | FRAME_HEIGHT = 84 # Resized frame height
17 | NUM_EPISODES = 12000 # Number of episodes the agent plays
18 | STATE_LENGTH = 4 # Number of most recent frames to produce the input to the network
19 | GAMMA = 0.99 # Discount factor
20 | EXPLORATION_STEPS = 1000000 # Number of steps over which the initial value of epsilon is linearly annealed to its final value
21 | INITIAL_EPSILON = 1.0 # Initial value of epsilon in epsilon-greedy
22 | FINAL_EPSILON = 0.1 # Final value of epsilon in epsilon-greedy
23 | INITIAL_REPLAY_SIZE = 20000 # Number of steps to populate the replay memory before training starts
24 | NUM_REPLAY_MEMORY = 400000  # Maximum number of transitions stored in replay memory
25 | BATCH_SIZE = 32 # Mini batch size
26 | TARGET_UPDATE_INTERVAL = 10000 # The frequency with which the target network is updated
27 | TRAIN_INTERVAL = 4 # The agent selects 4 actions between successive updates
28 | LEARNING_RATE = 0.00025 # Learning rate used by RMSProp
29 | MOMENTUM = 0.95 # Momentum used by RMSProp
30 | MIN_GRAD = 0.01 # Constant added to the squared gradient in the denominator of the RMSProp update
31 | SAVE_INTERVAL = 300000 # The frequency with which the network is saved
32 | NO_OP_STEPS = 30 # Maximum number of "do nothing" actions to be performed by the agent at the start of an episode
33 | LOAD_NETWORK = False
34 | TRAIN = True
35 | SAVE_NETWORK_PATH = 'saved_networks/' + ENV_NAME
36 | SAVE_SUMMARY_PATH = 'summary/' + ENV_NAME
37 | NUM_EPISODES_AT_TEST = 30 # Number of episodes the agent plays at test time
38 |
39 |
40 | class Agent():
41 | def __init__(self, num_actions):
42 | self.num_actions = num_actions
43 | self.epsilon = INITIAL_EPSILON
44 | self.epsilon_step = (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORATION_STEPS
45 | self.t = 0
46 |
47 | # Parameters used for summary
48 | self.total_reward = 0
49 | self.total_q_max = 0
50 | self.total_loss = 0
51 | self.duration = 0
52 | self.episode = 0
53 |
54 | # Create replay memory
55 | self.replay_memory = deque()
56 |
57 | # Create q network
58 | self.s, self.q_values, q_network = self.build_network()
59 | q_network_weights = q_network.trainable_weights
60 |
61 | # Create target network
62 | self.st, self.target_q_values, target_network = self.build_network()
63 | target_network_weights = target_network.trainable_weights
64 |
65 | # Define target network update operation
66 | self.update_target_network = [target_network_weights[i].assign(q_network_weights[i]) for i in range(len(target_network_weights))]
67 |
68 | # Define loss and gradient update operation
69 | self.a, self.y, self.loss, self.grads_update = self.build_training_op(q_network_weights)
70 |
71 | self.sess = tf.InteractiveSession()
72 | self.saver = tf.train.Saver(q_network_weights)
73 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary()
74 | self.summary_writer = tf.train.SummaryWriter(SAVE_SUMMARY_PATH, self.sess.graph)
75 |
76 | if not os.path.exists(SAVE_NETWORK_PATH):
77 | os.makedirs(SAVE_NETWORK_PATH)
78 |
79 | self.sess.run(tf.initialize_all_variables())
80 |
81 | # Load network
82 | if LOAD_NETWORK:
83 | self.load_network()
84 |
85 | # Initialize target network
86 | self.sess.run(self.update_target_network)
87 |
88 | def build_network(self):
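        # Architecture from Mnih et al. (2015): 8x8x32 stride-4, 4x4x64 stride-2, and
        # 3x3x64 stride-1 convolutions, a 512-unit dense layer, and a linear output per action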
89 | model = Sequential()
90 | model.add(Convolution2D(32, 8, 8, subsample=(4, 4), activation='relu', input_shape=(STATE_LENGTH, FRAME_WIDTH, FRAME_HEIGHT)))
91 | model.add(Convolution2D(64, 4, 4, subsample=(2, 2), activation='relu'))
92 | model.add(Convolution2D(64, 3, 3, subsample=(1, 1), activation='relu'))
93 | model.add(Flatten())
94 | model.add(Dense(512, activation='relu'))
95 | model.add(Dense(self.num_actions))
96 |
97 | s = tf.placeholder(tf.float32, [None, STATE_LENGTH, FRAME_WIDTH, FRAME_HEIGHT])
98 | q_values = model(s)
99 |
100 | return s, q_values, model
101 |
102 | def build_training_op(self, q_network_weights):
103 | a = tf.placeholder(tf.int64, [None])
104 | y = tf.placeholder(tf.float32, [None])
105 |
106 | # Convert action to one hot vector
107 | a_one_hot = tf.one_hot(a, self.num_actions, 1.0, 0.0)
108 | q_value = tf.reduce_sum(tf.mul(self.q_values, a_one_hot), reduction_indices=1)
109 |
110 | # Clip the error, the loss is quadratic when the error is in (-1, 1), and linear outside of that region
111 | error = tf.abs(y - q_value)
112 | quadratic_part = tf.clip_by_value(error, 0.0, 1.0)
113 | linear_part = error - quadratic_part
114 | loss = tf.reduce_mean(0.5 * tf.square(quadratic_part) + linear_part)
115 |
116 | optimizer = tf.train.RMSPropOptimizer(LEARNING_RATE, momentum=MOMENTUM, epsilon=MIN_GRAD)
117 | grads_update = optimizer.minimize(loss, var_list=q_network_weights)
118 |
119 | return a, y, loss, grads_update
120 |
121 | def get_initial_state(self, observation, last_observation):
122 | processed_observation = np.maximum(observation, last_observation)
123 | processed_observation = np.uint8(resize(rgb2gray(processed_observation), (FRAME_WIDTH, FRAME_HEIGHT)) * 255)
124 | state = [processed_observation for _ in range(STATE_LENGTH)]
125 | return np.stack(state, axis=0)
126 |
127 | def get_action(self, state):
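        # Epsilon-greedy action selection; actions are purely random until the
        # replay memory holds INITIAL_REPLAY_SIZE transitions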
128 | if self.epsilon >= random.random() or self.t < INITIAL_REPLAY_SIZE:
129 | action = random.randrange(self.num_actions)
130 | else:
131 | action = np.argmax(self.q_values.eval(feed_dict={self.s: [np.float32(state / 255.0)]}))
132 |
133 | # Anneal epsilon linearly over time
134 | if self.epsilon > FINAL_EPSILON and self.t >= INITIAL_REPLAY_SIZE:
135 | self.epsilon -= self.epsilon_step
136 |
137 | return action
138 |
139 | def run(self, state, action, reward, terminal, observation):
140 | next_state = np.append(state[1:, :, :], observation, axis=0)
141 |
142 | # Clip all positive rewards at 1 and all negative rewards at -1, leaving 0 rewards unchanged
143 | reward = np.clip(reward, -1, 1)
144 |
145 | # Store transition in replay memory
146 | self.replay_memory.append((state, action, reward, next_state, terminal))
147 | if len(self.replay_memory) > NUM_REPLAY_MEMORY:
148 | self.replay_memory.popleft()
149 |
150 | if self.t >= INITIAL_REPLAY_SIZE:
151 | # Train network
152 | if self.t % TRAIN_INTERVAL == 0:
153 | self.train_network()
154 |
155 | # Update target network
156 | if self.t % TARGET_UPDATE_INTERVAL == 0:
157 | self.sess.run(self.update_target_network)
158 |
159 | # Save network
160 | if self.t % SAVE_INTERVAL == 0:
161 | save_path = self.saver.save(self.sess, SAVE_NETWORK_PATH + '/' + ENV_NAME, global_step=self.t)
162 | print('Successfully saved: ' + save_path)
163 |
164 | self.total_reward += reward
165 | self.total_q_max += np.max(self.q_values.eval(feed_dict={self.s: [np.float32(state / 255.0)]}))
166 | self.duration += 1
167 |
168 | if terminal:
169 | # Write summary
170 | if self.t >= INITIAL_REPLAY_SIZE:
171 | stats = [self.total_reward, self.total_q_max / float(self.duration),
172 | self.duration, self.total_loss / (float(self.duration) / float(TRAIN_INTERVAL))]
173 | for i in range(len(stats)):
174 | self.sess.run(self.update_ops[i], feed_dict={
175 | self.summary_placeholders[i]: float(stats[i])
176 | })
177 | summary_str = self.sess.run(self.summary_op)
178 | self.summary_writer.add_summary(summary_str, self.episode + 1)
179 |
180 | # Debug
181 | if self.t < INITIAL_REPLAY_SIZE:
182 | mode = 'random'
183 | elif INITIAL_REPLAY_SIZE <= self.t < INITIAL_REPLAY_SIZE + EXPLORATION_STEPS:
184 | mode = 'explore'
185 | else:
186 | mode = 'exploit'
187 | print('EPISODE: {0:6d} / TIMESTEP: {1:8d} / DURATION: {2:5d} / EPSILON: {3:.5f} / TOTAL_REWARD: {4:3.0f} / AVG_MAX_Q: {5:2.4f} / AVG_LOSS: {6:.5f} / MODE: {7}'.format(
188 | self.episode + 1, self.t, self.duration, self.epsilon,
189 | self.total_reward, self.total_q_max / float(self.duration),
190 | self.total_loss / (float(self.duration) / float(TRAIN_INTERVAL)), mode))
191 |
192 | self.total_reward = 0
193 | self.total_q_max = 0
194 | self.total_loss = 0
195 | self.duration = 0
196 | self.episode += 1
197 |
198 | self.t += 1
199 |
200 | return next_state
201 |
202 | def train_network(self):
203 | state_batch = []
204 | action_batch = []
205 | reward_batch = []
206 | next_state_batch = []
207 | terminal_batch = []
208 | y_batch = []
209 |
210 | # Sample random minibatch of transition from replay memory
211 | minibatch = random.sample(self.replay_memory, BATCH_SIZE)
212 | for data in minibatch:
213 | state_batch.append(data[0])
214 | action_batch.append(data[1])
215 | reward_batch.append(data[2])
216 | next_state_batch.append(data[3])
217 | terminal_batch.append(data[4])
218 |
219 | # Convert True to 1, False to 0
220 | terminal_batch = np.array(terminal_batch) + 0
221 |
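        # Double DQN target: the online network selects the greedy next action,
        # and the target network evaluates it (van Hasselt et al., 2016)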
222 |         next_action_batch = np.argmax(self.q_values.eval(feed_dict={self.s: np.float32(np.array(next_state_batch) / 255.0)}), axis=1)
223 |         target_q_values_batch = self.target_q_values.eval(feed_dict={self.st: np.float32(np.array(next_state_batch) / 255.0)})
224 |         for i in range(len(minibatch)):
225 | y_batch.append(reward_batch[i] + (1 - terminal_batch[i]) * GAMMA * target_q_values_batch[i][next_action_batch[i]])
226 |
227 | loss, _ = self.sess.run([self.loss, self.grads_update], feed_dict={
228 | self.s: np.float32(np.array(state_batch) / 255.0),
229 | self.a: action_batch,
230 | self.y: y_batch
231 | })
232 |
233 | self.total_loss += loss
234 |
235 | def setup_summary(self):
236 | episode_total_reward = tf.Variable(0.)
237 | tf.scalar_summary(ENV_NAME + '/Total Reward/Episode', episode_total_reward)
238 | episode_avg_max_q = tf.Variable(0.)
239 | tf.scalar_summary(ENV_NAME + '/Average Max Q/Episode', episode_avg_max_q)
240 | episode_duration = tf.Variable(0.)
241 | tf.scalar_summary(ENV_NAME + '/Duration/Episode', episode_duration)
242 | episode_avg_loss = tf.Variable(0.)
243 | tf.scalar_summary(ENV_NAME + '/Average Loss/Episode', episode_avg_loss)
244 | summary_vars = [episode_total_reward, episode_avg_max_q, episode_duration, episode_avg_loss]
245 | summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))]
246 | update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))]
247 | summary_op = tf.merge_all_summaries()
248 | return summary_placeholders, update_ops, summary_op
249 |
250 | def load_network(self):
251 | checkpoint = tf.train.get_checkpoint_state(SAVE_NETWORK_PATH)
252 | if checkpoint and checkpoint.model_checkpoint_path:
253 | self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
254 | print('Successfully loaded: ' + checkpoint.model_checkpoint_path)
255 | else:
256 | print('Training new network...')
257 |
258 | def get_action_at_test(self, state):
259 | if random.random() <= 0.05:
260 | action = random.randrange(self.num_actions)
261 | else:
262 | action = np.argmax(self.q_values.eval(feed_dict={self.s: [np.float32(state / 255.0)]}))
263 |
264 | self.t += 1
265 |
266 | return action
267 |
268 |
269 | def preprocess(observation, last_observation):
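    # Pixel-wise max over the two most recent raw frames removes Atari sprite flicker;
    # the result is converted to an 84x84 grayscale frame stored as uint8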
270 | processed_observation = np.maximum(observation, last_observation)
271 | processed_observation = np.uint8(resize(rgb2gray(processed_observation), (FRAME_WIDTH, FRAME_HEIGHT)) * 255)
272 | return np.reshape(processed_observation, (1, FRAME_WIDTH, FRAME_HEIGHT))
273 |
274 |
275 | def main():
276 | env = gym.make(ENV_NAME)
277 | agent = Agent(num_actions=env.action_space.n)
278 |
279 | if TRAIN: # Train mode
280 | for _ in range(NUM_EPISODES):
281 | terminal = False
282 | observation = env.reset()
283 | for _ in range(random.randint(1, NO_OP_STEPS)):
284 | last_observation = observation
285 | observation, _, _, _ = env.step(0) # Do nothing
286 | state = agent.get_initial_state(observation, last_observation)
287 | while not terminal:
288 | last_observation = observation
289 | action = agent.get_action(state)
290 | observation, reward, terminal, _ = env.step(action)
291 | # env.render()
292 | processed_observation = preprocess(observation, last_observation)
293 | state = agent.run(state, action, reward, terminal, processed_observation)
294 | else: # Test mode
295 | # env.monitor.start(ENV_NAME + '-test')
296 | for _ in range(NUM_EPISODES_AT_TEST):
297 | terminal = False
298 | observation = env.reset()
299 | for _ in range(random.randint(1, NO_OP_STEPS)):
300 | last_observation = observation
301 | observation, _, _, _ = env.step(0) # Do nothing
302 | state = agent.get_initial_state(observation, last_observation)
303 | while not terminal:
304 | last_observation = observation
305 | action = agent.get_action_at_test(state)
306 | observation, _, terminal, _ = env.step(action)
307 | env.render()
308 | processed_observation = preprocess(observation, last_observation)
309 | state = np.append(state[1:, :, :], processed_observation, axis=0)
310 | # env.monitor.close()
311 |
312 |
313 | if __name__ == '__main__':
314 | main()
315 |
--------------------------------------------------------------------------------
/dqn.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 |
3 | import os
4 | import gym
5 | import random
6 | import numpy as np
7 | import tensorflow as tf
8 | from collections import deque
9 | from skimage.color import rgb2gray
10 | from skimage.transform import resize
11 | from keras.models import Sequential
12 | from keras.layers import Convolution2D, Flatten, Dense
13 |
14 | ENV_NAME = 'Breakout-v0' # Environment name
15 | FRAME_WIDTH = 84 # Resized frame width
16 | FRAME_HEIGHT = 84 # Resized frame height
17 | NUM_EPISODES = 12000 # Number of episodes the agent plays
18 | STATE_LENGTH = 4 # Number of most recent frames to produce the input to the network
19 | GAMMA = 0.99 # Discount factor
20 | EXPLORATION_STEPS = 1000000 # Number of steps over which the initial value of epsilon is linearly annealed to its final value
21 | INITIAL_EPSILON = 1.0 # Initial value of epsilon in epsilon-greedy
22 | FINAL_EPSILON = 0.1 # Final value of epsilon in epsilon-greedy
23 | INITIAL_REPLAY_SIZE = 20000 # Number of steps to populate the replay memory before training starts
24 | NUM_REPLAY_MEMORY = 400000  # Maximum number of transitions stored in replay memory
25 | BATCH_SIZE = 32 # Mini batch size
26 | TARGET_UPDATE_INTERVAL = 10000 # The frequency with which the target network is updated
27 | TRAIN_INTERVAL = 4 # The agent selects 4 actions between successive updates
28 | LEARNING_RATE = 0.00025 # Learning rate used by RMSProp
29 | MOMENTUM = 0.95 # Momentum used by RMSProp
30 | MIN_GRAD = 0.01 # Constant added to the squared gradient in the denominator of the RMSProp update
31 | SAVE_INTERVAL = 300000 # The frequency with which the network is saved
32 | NO_OP_STEPS = 30 # Maximum number of "do nothing" actions to be performed by the agent at the start of an episode
33 | LOAD_NETWORK = False
34 | TRAIN = True
35 | SAVE_NETWORK_PATH = 'saved_networks/' + ENV_NAME
36 | SAVE_SUMMARY_PATH = 'summary/' + ENV_NAME
37 | NUM_EPISODES_AT_TEST = 30 # Number of episodes the agent plays at test time
38 |
39 |
40 | class Agent():
41 | def __init__(self, num_actions):
42 | self.num_actions = num_actions
43 | self.epsilon = INITIAL_EPSILON
44 | self.epsilon_step = (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORATION_STEPS
45 | self.t = 0
46 |
47 | # Parameters used for summary
48 | self.total_reward = 0
49 | self.total_q_max = 0
50 | self.total_loss = 0
51 | self.duration = 0
52 | self.episode = 0
53 |
54 | # Create replay memory
55 | self.replay_memory = deque()
56 |
57 | # Create q network
58 | self.s, self.q_values, q_network = self.build_network()
59 | q_network_weights = q_network.trainable_weights
60 |
61 | # Create target network
62 | self.st, self.target_q_values, target_network = self.build_network()
63 | target_network_weights = target_network.trainable_weights
64 |
65 | # Define target network update operation
66 | self.update_target_network = [target_network_weights[i].assign(q_network_weights[i]) for i in range(len(target_network_weights))]
67 |
68 | # Define loss and gradient update operation
69 | self.a, self.y, self.loss, self.grads_update = self.build_training_op(q_network_weights)
70 |
71 | self.sess = tf.InteractiveSession()
72 | self.saver = tf.train.Saver(q_network_weights)
73 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary()
74 | self.summary_writer = tf.train.SummaryWriter(SAVE_SUMMARY_PATH, self.sess.graph)
75 |
76 | if not os.path.exists(SAVE_NETWORK_PATH):
77 | os.makedirs(SAVE_NETWORK_PATH)
78 |
79 | self.sess.run(tf.initialize_all_variables())
80 |
81 | # Load network
82 | if LOAD_NETWORK:
83 | self.load_network()
84 |
85 | # Initialize target network
86 | self.sess.run(self.update_target_network)
87 |
88 | def build_network(self):
89 | model = Sequential()
90 | model.add(Convolution2D(32, 8, 8, subsample=(4, 4), activation='relu', input_shape=(STATE_LENGTH, FRAME_WIDTH, FRAME_HEIGHT)))
91 | model.add(Convolution2D(64, 4, 4, subsample=(2, 2), activation='relu'))
92 | model.add(Convolution2D(64, 3, 3, subsample=(1, 1), activation='relu'))
93 | model.add(Flatten())
94 | model.add(Dense(512, activation='relu'))
95 | model.add(Dense(self.num_actions))
96 |
97 | s = tf.placeholder(tf.float32, [None, STATE_LENGTH, FRAME_WIDTH, FRAME_HEIGHT])
98 | q_values = model(s)
99 |
100 | return s, q_values, model
101 |
102 | def build_training_op(self, q_network_weights):
103 | a = tf.placeholder(tf.int64, [None])
104 | y = tf.placeholder(tf.float32, [None])
105 |
106 | # Convert action to one hot vector
107 | a_one_hot = tf.one_hot(a, self.num_actions, 1.0, 0.0)
108 | q_value = tf.reduce_sum(tf.mul(self.q_values, a_one_hot), reduction_indices=1)
109 |
110 | # Clip the error, the loss is quadratic when the error is in (-1, 1), and linear outside of that region
111 | error = tf.abs(y - q_value)
112 | quadratic_part = tf.clip_by_value(error, 0.0, 1.0)
113 | linear_part = error - quadratic_part
114 | loss = tf.reduce_mean(0.5 * tf.square(quadratic_part) + linear_part)
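        # e.g. a per-sample error of 2.5 contributes 0.5 * 1.0**2 + 1.5 = 2.0 to the mean,
        # rather than 0.5 * 2.5**2 = 3.125 for an unclipped squared error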
115 |
116 | optimizer = tf.train.RMSPropOptimizer(LEARNING_RATE, momentum=MOMENTUM, epsilon=MIN_GRAD)
117 | grads_update = optimizer.minimize(loss, var_list=q_network_weights)
118 |
119 | return a, y, loss, grads_update
120 |
121 | def get_initial_state(self, observation, last_observation):
122 | processed_observation = np.maximum(observation, last_observation)
123 | processed_observation = np.uint8(resize(rgb2gray(processed_observation), (FRAME_WIDTH, FRAME_HEIGHT)) * 255)
124 | state = [processed_observation for _ in range(STATE_LENGTH)]
125 | return np.stack(state, axis=0)
126 |
127 | def get_action(self, state):
128 | if self.epsilon >= random.random() or self.t < INITIAL_REPLAY_SIZE:
129 | action = random.randrange(self.num_actions)
130 | else:
131 | action = np.argmax(self.q_values.eval(feed_dict={self.s: [np.float32(state / 255.0)]}))
132 |
133 | # Anneal epsilon linearly over time
134 | if self.epsilon > FINAL_EPSILON and self.t >= INITIAL_REPLAY_SIZE:
135 | self.epsilon -= self.epsilon_step
136 |
137 | return action
138 |
139 | def run(self, state, action, reward, terminal, observation):
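        # Slide the 4-frame window: drop the oldest frame and append the newly preprocessed one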
140 | next_state = np.append(state[1:, :, :], observation, axis=0)
141 |
142 | # Clip all positive rewards at 1 and all negative rewards at -1, leaving 0 rewards unchanged
143 | reward = np.clip(reward, -1, 1)
144 |
145 | # Store transition in replay memory
146 | self.replay_memory.append((state, action, reward, next_state, terminal))
147 | if len(self.replay_memory) > NUM_REPLAY_MEMORY:
148 | self.replay_memory.popleft()
149 |
150 | if self.t >= INITIAL_REPLAY_SIZE:
151 | # Train network
152 | if self.t % TRAIN_INTERVAL == 0:
153 | self.train_network()
154 |
155 | # Update target network
156 | if self.t % TARGET_UPDATE_INTERVAL == 0:
157 | self.sess.run(self.update_target_network)
158 |
159 | # Save network
160 | if self.t % SAVE_INTERVAL == 0:
161 | save_path = self.saver.save(self.sess, SAVE_NETWORK_PATH + '/' + ENV_NAME, global_step=self.t)
162 | print('Successfully saved: ' + save_path)
163 |
164 | self.total_reward += reward
165 | self.total_q_max += np.max(self.q_values.eval(feed_dict={self.s: [np.float32(state / 255.0)]}))
166 | self.duration += 1
167 |
168 | if terminal:
169 | # Write summary
170 | if self.t >= INITIAL_REPLAY_SIZE:
171 | stats = [self.total_reward, self.total_q_max / float(self.duration),
172 | self.duration, self.total_loss / (float(self.duration) / float(TRAIN_INTERVAL))]
173 | for i in range(len(stats)):
174 | self.sess.run(self.update_ops[i], feed_dict={
175 | self.summary_placeholders[i]: float(stats[i])
176 | })
177 | summary_str = self.sess.run(self.summary_op)
178 | self.summary_writer.add_summary(summary_str, self.episode + 1)
179 |
180 | # Debug
181 | if self.t < INITIAL_REPLAY_SIZE:
182 | mode = 'random'
183 | elif INITIAL_REPLAY_SIZE <= self.t < INITIAL_REPLAY_SIZE + EXPLORATION_STEPS:
184 | mode = 'explore'
185 | else:
186 | mode = 'exploit'
187 | print('EPISODE: {0:6d} / TIMESTEP: {1:8d} / DURATION: {2:5d} / EPSILON: {3:.5f} / TOTAL_REWARD: {4:3.0f} / AVG_MAX_Q: {5:2.4f} / AVG_LOSS: {6:.5f} / MODE: {7}'.format(
188 | self.episode + 1, self.t, self.duration, self.epsilon,
189 | self.total_reward, self.total_q_max / float(self.duration),
190 | self.total_loss / (float(self.duration) / float(TRAIN_INTERVAL)), mode))
191 |
192 | self.total_reward = 0
193 | self.total_q_max = 0
194 | self.total_loss = 0
195 | self.duration = 0
196 | self.episode += 1
197 |
198 | self.t += 1
199 |
200 | return next_state
201 |
202 | def train_network(self):
203 | state_batch = []
204 | action_batch = []
205 | reward_batch = []
206 | next_state_batch = []
207 | terminal_batch = []
208 | y_batch = []
209 |
210 | # Sample random minibatch of transition from replay memory
211 | minibatch = random.sample(self.replay_memory, BATCH_SIZE)
212 | for data in minibatch:
213 | state_batch.append(data[0])
214 | action_batch.append(data[1])
215 | reward_batch.append(data[2])
216 | next_state_batch.append(data[3])
217 | terminal_batch.append(data[4])
218 |
219 | # Convert True to 1, False to 0
220 | terminal_batch = np.array(terminal_batch) + 0
221 |
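        # Standard DQN target: evaluate the next state with the target network and take the max Q-value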
222 | target_q_values_batch = self.target_q_values.eval(feed_dict={self.st: np.float32(np.array(next_state_batch) / 255.0)})
223 | y_batch = reward_batch + (1 - terminal_batch) * GAMMA * np.max(target_q_values_batch, axis=1)
224 |
225 | loss, _ = self.sess.run([self.loss, self.grads_update], feed_dict={
226 | self.s: np.float32(np.array(state_batch) / 255.0),
227 | self.a: action_batch,
228 | self.y: y_batch
229 | })
230 |
231 | self.total_loss += loss
232 |
233 | def setup_summary(self):
234 | episode_total_reward = tf.Variable(0.)
235 | tf.scalar_summary(ENV_NAME + '/Total Reward/Episode', episode_total_reward)
236 | episode_avg_max_q = tf.Variable(0.)
237 | tf.scalar_summary(ENV_NAME + '/Average Max Q/Episode', episode_avg_max_q)
238 | episode_duration = tf.Variable(0.)
239 | tf.scalar_summary(ENV_NAME + '/Duration/Episode', episode_duration)
240 | episode_avg_loss = tf.Variable(0.)
241 | tf.scalar_summary(ENV_NAME + '/Average Loss/Episode', episode_avg_loss)
242 | summary_vars = [episode_total_reward, episode_avg_max_q, episode_duration, episode_avg_loss]
243 | summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))]
244 | update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))]
245 | summary_op = tf.merge_all_summaries()
246 | return summary_placeholders, update_ops, summary_op
247 |
248 | def load_network(self):
249 | checkpoint = tf.train.get_checkpoint_state(SAVE_NETWORK_PATH)
250 | if checkpoint and checkpoint.model_checkpoint_path:
251 | self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
252 | print('Successfully loaded: ' + checkpoint.model_checkpoint_path)
253 | else:
254 | print('Training new network...')
255 |
256 | def get_action_at_test(self, state):
257 | if random.random() <= 0.05:
258 | action = random.randrange(self.num_actions)
259 | else:
260 | action = np.argmax(self.q_values.eval(feed_dict={self.s: [np.float32(state / 255.0)]}))
261 |
262 | self.t += 1
263 |
264 | return action
265 |
266 |
267 | def preprocess(observation, last_observation):
268 | processed_observation = np.maximum(observation, last_observation)
269 | processed_observation = np.uint8(resize(rgb2gray(processed_observation), (FRAME_WIDTH, FRAME_HEIGHT)) * 255)
270 | return np.reshape(processed_observation, (1, FRAME_WIDTH, FRAME_HEIGHT))
271 |
272 |
273 | def main():
274 | env = gym.make(ENV_NAME)
275 | agent = Agent(num_actions=env.action_space.n)
276 |
277 | if TRAIN: # Train mode
278 | for _ in range(NUM_EPISODES):
279 | terminal = False
280 | observation = env.reset()
281 | for _ in range(random.randint(1, NO_OP_STEPS)):
282 | last_observation = observation
283 | observation, _, _, _ = env.step(0) # Do nothing
284 | state = agent.get_initial_state(observation, last_observation)
285 | while not terminal:
286 | last_observation = observation
287 | action = agent.get_action(state)
288 | observation, reward, terminal, _ = env.step(action)
289 | # env.render()
290 | processed_observation = preprocess(observation, last_observation)
291 | state = agent.run(state, action, reward, terminal, processed_observation)
292 | else: # Test mode
293 | # env.monitor.start(ENV_NAME + '-test')
294 | for _ in range(NUM_EPISODES_AT_TEST):
295 | terminal = False
296 | observation = env.reset()
297 | for _ in range(random.randint(1, NO_OP_STEPS)):
298 | last_observation = observation
299 | observation, _, _, _ = env.step(0) # Do nothing
300 | state = agent.get_initial_state(observation, last_observation)
301 | while not terminal:
302 | last_observation = observation
303 | action = agent.get_action_at_test(state)
304 | observation, _, terminal, _ = env.step(action)
305 | env.render()
306 | processed_observation = preprocess(observation, last_observation)
307 | state = np.append(state[1:, :, :], processed_observation, axis=0)
308 | # env.monitor.close()
309 |
310 |
311 | if __name__ == '__main__':
312 | main()
313 |
--------------------------------------------------------------------------------
/saved_networks/Breakout-v0/Breakout-v0-4500000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tokb23/dqn/23f410d366facf066ede9dbb4d940a9fdb83fd81/saved_networks/Breakout-v0/Breakout-v0-4500000
--------------------------------------------------------------------------------
/saved_networks/Breakout-v0/Breakout-v0-4500000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tokb23/dqn/23f410d366facf066ede9dbb4d940a9fdb83fd81/saved_networks/Breakout-v0/Breakout-v0-4500000.meta
--------------------------------------------------------------------------------
/saved_networks/Breakout-v0/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "Breakout-v0-4500000"
2 | all_model_checkpoint_paths: "Breakout-v0-3300000"
3 | all_model_checkpoint_paths: "Breakout-v0-3600000"
4 | all_model_checkpoint_paths: "Breakout-v0-3900000"
5 | all_model_checkpoint_paths: "Breakout-v0-4200000"
6 | all_model_checkpoint_paths: "Breakout-v0-4500000"
7 |
--------------------------------------------------------------------------------