├── DQN_Cartpole.ipynb ├── DQN_Cartpole.py ├── DuelingDQN.ipynb ├── DuelingDQN.py ├── PrioritizedReplayDQN-ProportionalVariant.ipynb ├── README.md └── drl-gym ├── DQNModel.py ├── argument.py ├── atari_wrappers.py ├── envWrapper.py ├── logger.py ├── memory.py ├── netFrame.py ├── run.py ├── testModel.py └── utils.py /DQN_Cartpole.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import tensorflow as tf\n", 12 | "import numpy as np\n", 13 | "import collections\n", 14 | "import gym\n", 15 | "import random\n", 16 | "import tensorflow.contrib.layers as layers" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "ENV = \"CartPole-v0\"" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "metadata": { 34 | "collapsed": true 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "MEMORY_SIZE = 10000\n", 39 | "EPISODES = 500\n", 40 | "MAX_STEP = 500\n", 41 | "BATCH_SIZE = 32\n", 42 | "UPDATE_PERIOD = 200 # update target network parameters\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": { 49 | "collapsed": true 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "##built class for the DQN\n", 54 | "class DeepQNetwork():\n", 55 | " def __init__(self , env , sess=None , gamma = 0.8, epsilon = 0.8 ):\n", 56 | " self.gamma = gamma\n", 57 | " self.epsilon = epsilon\n", 58 | " self.action_dim = env.action_space.n\n", 59 | " self.state_dim = env.observation_space.shape[0]\n", 60 | " self.network()\n", 61 | " self.sess = sess\n", 62 | " self.sess.run(tf.global_variables_initializer())\n", 63 | " tf.summary.FileWriter(\"DQN/summaries\" , sess.graph )\n", 64 | " \n", 65 | " # net_frame using for creating Q & target network\n", 66 | " def net_frame(self , hiddens, inpt, num_actions, scope, reuse=None):\n", 67 | " with tf.variable_scope(scope, reuse=reuse):\n", 68 | " out = inpt \n", 69 | " for hidden in hiddens:\n", 70 | " out = layers.fully_connected(out, num_outputs=hidden, activation_fn=tf.nn.relu)\n", 71 | " out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None) \n", 72 | " return out\n", 73 | " \n", 74 | " # create q_network & target_network \n", 75 | " def network(self): \n", 76 | " # q_network\n", 77 | " self.inputs_q = tf.placeholder(dtype = tf.float32 , shape = [None , self.state_dim] , name = \"inputs_q\")\n", 78 | " scope_var = \"q_network\" \n", 79 | " self.q_value = self.net_frame([64] , self.inputs_q , self.action_dim , scope_var , reuse = True )\n", 80 | " \n", 81 | " # target_network\n", 82 | " self.inputs_target = tf.placeholder(dtype = tf.float32 , shape = [None , self.state_dim] , name = \"inputs_target\")\n", 83 | " scope_tar = \"target_network\" \n", 84 | " self.q_target = self.net_frame([64] , self.inputs_target , self.action_dim , scope_tar )\n", 85 | " \n", 86 | " with tf.variable_scope(\"loss\"):\n", 87 | "# #【方案一】\n", 88 | "# self.target = tf.placeholder(dtype = tf.float32 , shape = [None , self.action_dim] , name = \"target\")\n", 89 | "# self.loss = tf.reduce_mean( tf.square(self.q_value - self.target))\n", 90 | " #【方案二】\n", 91 | " self.action = tf.placeholder(dtype = tf.int32 , shape = [ None ] , name = \"action\")\n", 92 | " action_one_hot = tf.one_hot(self.action , 
self.action_dim )\n", 93 | " q_action = tf.reduce_sum( tf.multiply(self.q_value , action_one_hot) , axis = 1 ) \n", 94 | " \n", 95 | " self.target = tf.placeholder(dtype = tf.float32 , shape = [None ] , name = \"target\")\n", 96 | " self.loss = tf.reduce_mean( tf.square(q_action - self.target))\n", 97 | "\n", 98 | " with tf.variable_scope(\"train\"):\n", 99 | " optimizer = tf.train.RMSPropOptimizer(0.001)\n", 100 | " self.train_op = optimizer.minimize(self.loss) \n", 101 | " \n", 102 | " # training\n", 103 | " def train(self , state , reward , action , state_next , done):\n", 104 | " q , q_target = self.sess.run([self.q_value , self.q_target] , \n", 105 | " feed_dict={self.inputs_q : state , self.inputs_target : state_next } )\n", 106 | "# #【方案一】\n", 107 | "# target = reward + self.gamma * np.max(q_target , axis = 1)*(1.0 - done)\n", 108 | " \n", 109 | "# self.reform_target = q.copy()\n", 110 | "# batch_index = np.arange(BATCH_SIZE , dtype = np.int32)\n", 111 | "# self.reform_target[batch_index , action] = target\n", 112 | " \n", 113 | "# loss , _ = self.sess.run([self.loss , self.train_op] , feed_dict={self.inputs_q: state , self.target: self.reform_target} )\n", 114 | "\n", 115 | " #【方案二】\n", 116 | " q_target_best = np.max(q_target , axis = 1)\n", 117 | " q_target_best_mask = ( 1.0 - done) * q_target_best\n", 118 | " \n", 119 | " target = reward + self.gamma * q_target_best_mask\n", 120 | " \n", 121 | " loss , _ = self.sess.run([self.loss , self.train_op] , \n", 122 | " feed_dict={self.inputs_q: state , self.target:target , self.action:action} ) \n", 123 | " # chose action\n", 124 | " def chose_action(self , current_state):\n", 125 | " current_state = current_state[np.newaxis , :] #*** array dim: (xx,) --> (1 , xx) ***\n", 126 | " q = self.sess.run(self.q_value , feed_dict={self.inputs_q : current_state} )\n", 127 | " \n", 128 | " # e-greedy\n", 129 | " if np.random.random() < self.epsilon:\n", 130 | " action_chosen = np.random.randint(0 , self.action_dim)\n", 131 | " else:\n", 132 | " action_chosen = np.argmax(q)\n", 133 | " \n", 134 | " return action_chosen\n", 135 | " \n", 136 | " #upadate parmerters\n", 137 | " def update_prmt(self):\n", 138 | " q_prmts = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES , \"q_network\" )\n", 139 | " target_prmts = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, \"target_network\" )\n", 140 | " self.sess.run( [tf.assign(t , q)for t,q in zip(target_prmts , q_prmts)]) #***\n", 141 | " print(\"updating target-network parmeters...\")\n", 142 | " \n", 143 | " def decay_epsilon(self):\n", 144 | " if self.epsilon > 0.03:\n", 145 | " self.epsilon = self.epsilon - 0.02" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 5, 151 | "metadata": { 152 | "collapsed": true 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "# memory for momery replay\n", 157 | "memory = []\n", 158 | "Transition = collections.namedtuple(\"Transition\" , [\"state\", \"action\" , \"reward\" , \"next_state\" , \"done\"])" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 6, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stderr", 168 | "output_type": "stream", 169 | "text": [ 170 | "[2017-06-16 10:32:02,009] Making new env: CartPole-v0\n" 171 | ] 172 | }, 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "updating target-network parmeters...\n", 178 | "updating target-network parmeters...\n", 179 | "updating target-network parmeters...\n", 180 | "updating target-network 
parmeters...\n", 181 | "updating target-network parmeters...\n", 182 | "updating target-network parmeters...\n", 183 | "updating target-network parmeters...\n", 184 | "updating target-network parmeters...\n", 185 | "updating target-network parmeters...\n", 186 | "updating target-network parmeters...\n", 187 | "updating target-network parmeters...\n", 188 | "updating target-network parmeters...\n", 189 | "updating target-network parmeters...\n", 190 | "updating target-network parmeters...\n", 191 | "updating target-network parmeters...\n", 192 | "updating target-network parmeters...\n", 193 | "updating target-network parmeters...\n", 194 | "updating target-network parmeters...\n", 195 | "updating target-network parmeters...\n", 196 | "updating target-network parmeters...\n", 197 | "updating target-network parmeters...\n", 198 | "updating target-network parmeters...\n", 199 | "updating target-network parmeters...\n", 200 | "updating target-network parmeters...\n", 201 | "updating target-network parmeters...\n", 202 | "updating target-network parmeters...\n", 203 | "updating target-network parmeters...\n", 204 | "updating target-network parmeters...\n", 205 | "updating target-network parmeters...\n", 206 | "updating target-network parmeters...\n", 207 | "updating target-network parmeters...\n", 208 | "updating target-network parmeters...\n", 209 | "updating target-network parmeters...\n", 210 | "updating target-network parmeters...\n", 211 | "updating target-network parmeters...\n", 212 | "updating target-network parmeters...\n", 213 | "updating target-network parmeters...\n", 214 | "updating target-network parmeters...\n", 215 | "updating target-network parmeters...\n", 216 | "updating target-network parmeters...\n", 217 | "updating target-network parmeters...\n", 218 | "updating target-network parmeters...\n", 219 | "updating target-network parmeters...\n", 220 | "updating target-network parmeters...\n", 221 | "updating target-network parmeters...\n", 222 | "updating target-network parmeters...\n", 223 | "updating target-network parmeters...\n", 224 | "updating target-network parmeters...\n", 225 | "updating target-network parmeters...\n", 226 | "updating target-network parmeters...\n", 227 | "updating target-network parmeters...\n", 228 | "updating target-network parmeters...\n", 229 | "updating target-network parmeters...\n", 230 | "updating target-network parmeters...\n", 231 | "updating target-network parmeters...\n", 232 | "updating target-network parmeters...\n", 233 | "updating target-network parmeters...\n", 234 | "updating target-network parmeters...\n", 235 | "updating target-network parmeters...\n", 236 | "updating target-network parmeters...\n", 237 | "updating target-network parmeters...\n", 238 | "updating target-network parmeters...\n", 239 | "updating target-network parmeters...\n", 240 | "updating target-network parmeters...\n", 241 | "updating target-network parmeters...\n", 242 | "updating target-network parmeters...\n", 243 | "updating target-network parmeters...\n", 244 | "updating target-network parmeters...\n", 245 | "updating target-network parmeters...\n", 246 | "updating target-network parmeters...\n", 247 | "updating target-network parmeters...\n", 248 | "updating target-network parmeters...\n", 249 | "updating target-network parmeters...\n", 250 | "updating target-network parmeters...\n", 251 | "updating target-network parmeters...\n", 252 | "updating target-network parmeters...\n", 253 | "updating target-network parmeters...\n", 254 | "updating target-network 
parmeters...\n", 255 | "updating target-network parmeters...\n", 256 | "updating target-network parmeters...\n", 257 | "updating target-network parmeters...\n", 258 | "updating target-network parmeters...\n", 259 | "updating target-network parmeters...\n", 260 | "updating target-network parmeters...\n", 261 | "updating target-network parmeters...\n", 262 | "updating target-network parmeters...\n", 263 | "updating target-network parmeters...\n", 264 | "updating target-network parmeters...\n", 265 | "updating target-network parmeters...\n", 266 | "updating target-network parmeters...\n", 267 | "updating target-network parmeters...\n", 268 | "updating target-network parmeters...\n", 269 | "updating target-network parmeters...\n", 270 | "updating target-network parmeters...\n", 271 | "updating target-network parmeters...\n", 272 | "updating target-network parmeters...\n", 273 | "updating target-network parmeters...\n", 274 | "updating target-network parmeters...\n", 275 | "updating target-network parmeters...\n", 276 | "updating target-network parmeters...\n", 277 | "updating target-network parmeters...\n", 278 | "updating target-network parmeters...\n", 279 | "updating target-network parmeters...\n", 280 | "updating target-network parmeters...\n", 281 | "updating target-network parmeters...\n", 282 | "updating target-network parmeters...\n", 283 | "updating target-network parmeters...\n", 284 | "updating target-network parmeters...\n", 285 | "updating target-network parmeters...\n", 286 | "updating target-network parmeters...\n", 287 | "updating target-network parmeters...\n", 288 | "updating target-network parmeters...\n", 289 | "updating target-network parmeters...\n", 290 | "updating target-network parmeters...\n", 291 | "updating target-network parmeters...\n", 292 | "updating target-network parmeters...\n", 293 | "updating target-network parmeters...\n", 294 | "updating target-network parmeters...\n", 295 | "updating target-network parmeters...\n", 296 | "updating target-network parmeters...\n", 297 | "updating target-network parmeters...\n", 298 | "updating target-network parmeters...\n", 299 | "updating target-network parmeters...\n", 300 | "updating target-network parmeters...\n", 301 | "updating target-network parmeters...\n", 302 | "updating target-network parmeters...\n", 303 | "[episode = 0 ] step = 125\n", 304 | "updating target-network parmeters...\n", 305 | "updating target-network parmeters...\n", 306 | "[episode = 1 ] step = 24\n", 307 | "[episode = 2 ] step = 45\n", 308 | "[episode = 3 ] step = 51\n", 309 | "updating target-network parmeters...\n", 310 | "[episode = 4 ] step = 78\n", 311 | "[episode = 5 ] step = 58\n", 312 | "[episode = 6 ] step = 32\n", 313 | "[episode = 7 ] step = 28\n", 314 | "[episode = 8 ] step = 34\n", 315 | "[episode = 9 ] step = 11\n", 316 | "[episode = 10 ] step = 9\n", 317 | "[episode = 11 ] step = 8\n", 318 | "[episode = 12 ] step = 8\n", 319 | "updating target-network parmeters...\n", 320 | "[episode = 13 ] step = 8\n", 321 | "[episode = 14 ] step = 30\n", 322 | "[episode = 15 ] step = 45\n", 323 | "[episode = 16 ] step = 17\n", 324 | "[episode = 17 ] step = 8\n", 325 | "[episode = 18 ] step = 13\n", 326 | "[episode = 19 ] step = 9\n", 327 | "[episode = 20 ] step = 9\n", 328 | "[episode = 21 ] step = 10\n", 329 | "[episode = 22 ] step = 35\n", 330 | "updating target-network parmeters...\n", 331 | "[episode = 23 ] step = 11\n", 332 | "[episode = 24 ] step = 33\n", 333 | "[episode = 25 ] step = 39\n", 334 | "[episode = 26 ] step = 30\n", 335 | 
"[episode = 27 ] step = 22\n", 336 | "[episode = 28 ] step = 30\n", 337 | "[episode = 29 ] step = 19\n", 338 | "[episode = 30 ] step = 16\n", 339 | "updating target-network parmeters...\n", 340 | "[episode = 31 ] step = 38\n", 341 | "[episode = 32 ] step = 34\n", 342 | "[episode = 33 ] step = 49\n", 343 | "[episode = 34 ] step = 34\n", 344 | "[episode = 35 ] step = 18\n", 345 | "[episode = 36 ] step = 10\n", 346 | "updating target-network parmeters...\n", 347 | "[episode = 37 ] step = 20\n", 348 | "[episode = 38 ] step = 33\n", 349 | "[episode = 39 ] step = 62\n", 350 | "[episode = 40 ] step = 44\n", 351 | "[episode = 41 ] step = 26\n", 352 | "updating target-network parmeters...\n", 353 | "[episode = 42 ] step = 32\n", 354 | "[episode = 43 ] step = 10\n", 355 | "[episode = 44 ] step = 14\n", 356 | "[episode = 45 ] step = 45\n", 357 | "[episode = 46 ] step = 9\n", 358 | "[episode = 47 ] step = 9\n", 359 | "[episode = 48 ] step = 8\n", 360 | "[episode = 49 ] step = 9\n", 361 | "[episode = 50 ] step = 13\n", 362 | "[episode = 51 ] step = 9\n", 363 | "[episode = 52 ] step = 9\n", 364 | "[episode = 53 ] step = 8\n", 365 | "[episode = 54 ] step = 9\n", 366 | "[episode = 55 ] step = 9\n", 367 | "[episode = 56 ] step = 8\n", 368 | "updating target-network parmeters...\n", 369 | "[episode = 57 ] step = 11\n", 370 | "[episode = 58 ] step = 21\n", 371 | "[episode = 59 ] step = 35\n", 372 | "[episode = 60 ] step = 19\n", 373 | "[episode = 61 ] step = 11\n", 374 | "[episode = 62 ] step = 9\n", 375 | "[episode = 63 ] step = 22\n", 376 | "[episode = 64 ] step = 17\n", 377 | "[episode = 65 ] step = 9\n", 378 | "[episode = 66 ] step = 9\n", 379 | "[episode = 67 ] step = 11\n", 380 | "updating target-network parmeters...\n", 381 | "[episode = 68 ] step = 32\n", 382 | "[episode = 69 ] step = 14\n", 383 | "[episode = 70 ] step = 9\n", 384 | "[episode = 71 ] step = 11\n", 385 | "[episode = 72 ] step = 9\n", 386 | "[episode = 73 ] step = 17\n", 387 | "[episode = 74 ] step = 14\n", 388 | "[episode = 75 ] step = 22\n", 389 | "[episode = 76 ] step = 30\n", 390 | "[episode = 77 ] step = 8\n", 391 | "[episode = 78 ] step = 13\n", 392 | "[episode = 79 ] step = 10\n", 393 | "[episode = 80 ] step = 10\n", 394 | "updating target-network parmeters...\n", 395 | "[episode = 81 ] step = 9\n", 396 | "[episode = 82 ] step = 39\n", 397 | "[episode = 83 ] step = 30\n", 398 | "[episode = 84 ] step = 9\n", 399 | "[episode = 85 ] step = 33\n", 400 | "[episode = 86 ] step = 11\n", 401 | "[episode = 87 ] step = 9\n", 402 | "[episode = 88 ] step = 25\n", 403 | "[episode = 89 ] step = 10\n", 404 | "[episode = 90 ] step = 15\n", 405 | "updating target-network parmeters...\n", 406 | "[episode = 91 ] step = 21\n", 407 | "[episode = 92 ] step = 86\n", 408 | "[episode = 93 ] step = 19\n", 409 | "[episode = 94 ] step = 11\n", 410 | "[episode = 95 ] step = 9\n", 411 | "[episode = 96 ] step = 56\n", 412 | "updating target-network parmeters...\n", 413 | "[episode = 97 ] step = 9\n", 414 | "[episode = 98 ] step = 9\n", 415 | "[episode = 99 ] step = 10\n", 416 | "[episode = 100 ] step = 24\n", 417 | "[episode = 101 ] step = 20\n", 418 | "[episode = 102 ] step = 22\n", 419 | "[episode = 103 ] step = 36\n", 420 | "[episode = 104 ] step = 38\n", 421 | "updating target-network parmeters...\n", 422 | "[episode = 105 ] step = 34\n", 423 | "[episode = 106 ] step = 38\n", 424 | "[episode = 107 ] step = 92\n", 425 | "[episode = 108 ] step = 31\n", 426 | "updating target-network parmeters...\n", 427 | "[episode = 109 ] step = 60\n", 428 | "[episode = 110 
] step = 150\n", 429 | "updating target-network parmeters...\n", 430 | "[episode = 111 ] step = 68\n", 431 | "[episode = 112 ] step = 42\n", 432 | "[episode = 113 ] step = 22\n", 433 | "[episode = 114 ] step = 31\n", 434 | "[episode = 115 ] step = 31\n" 435 | ] 436 | }, 437 | { 438 | "name": "stdout", 439 | "output_type": "stream", 440 | "text": [ 441 | "updating target-network parmeters...\n", 442 | "[episode = 116 ] step = 39\n", 443 | "[episode = 117 ] step = 44\n", 444 | "[episode = 118 ] step = 34\n", 445 | "[episode = 119 ] step = 29\n", 446 | "[episode = 120 ] step = 48\n", 447 | "updating target-network parmeters...\n", 448 | "[episode = 121 ] step = 44\n", 449 | "[episode = 122 ] step = 33\n", 450 | "[episode = 123 ] step = 19\n", 451 | "[episode = 124 ] step = 44\n", 452 | "[episode = 125 ] step = 56\n", 453 | "updating target-network parmeters...\n", 454 | "[episode = 126 ] step = 43\n", 455 | "[episode = 127 ] step = 36\n", 456 | "[episode = 128 ] step = 24\n", 457 | "[episode = 129 ] step = 33\n", 458 | "[episode = 130 ] step = 33\n", 459 | "updating target-network parmeters...\n", 460 | "[episode = 131 ] step = 52\n", 461 | "[episode = 132 ] step = 56\n", 462 | "[episode = 133 ] step = 59\n", 463 | "[episode = 134 ] step = 39\n", 464 | "updating target-network parmeters...\n", 465 | "[episode = 135 ] step = 48\n", 466 | "[episode = 136 ] step = 44\n", 467 | "[episode = 137 ] step = 39\n", 468 | "[episode = 138 ] step = 33\n", 469 | "updating target-network parmeters...\n", 470 | "[episode = 139 ] step = 135\n", 471 | "[episode = 140 ] step = 36\n", 472 | "[episode = 141 ] step = 38\n", 473 | "[episode = 142 ] step = 50\n", 474 | "updating target-network parmeters...\n", 475 | "[episode = 143 ] step = 44\n", 476 | "[episode = 144 ] step = 61\n", 477 | "[episode = 145 ] step = 48\n", 478 | "updating target-network parmeters...\n", 479 | "[episode = 146 ] step = 76\n", 480 | "[episode = 147 ] step = 106\n", 481 | "[episode = 148 ] step = 45\n", 482 | "updating target-network parmeters...\n", 483 | "[episode = 149 ] step = 74\n", 484 | "[episode = 150 ] step = 33\n", 485 | "[episode = 151 ] step = 61\n", 486 | "[episode = 152 ] step = 40\n", 487 | "updating target-network parmeters...\n", 488 | "[episode = 153 ] step = 62\n", 489 | "[episode = 154 ] step = 69\n", 490 | "[episode = 155 ] step = 42\n", 491 | "updating target-network parmeters...\n", 492 | "[episode = 156 ] step = 50\n", 493 | "[episode = 157 ] step = 58\n", 494 | "[episode = 158 ] step = 42\n", 495 | "[episode = 159 ] step = 65\n", 496 | "updating target-network parmeters...\n", 497 | "[episode = 160 ] step = 55\n", 498 | "[episode = 161 ] step = 41\n", 499 | "[episode = 162 ] step = 63\n", 500 | "updating target-network parmeters...\n", 501 | "[episode = 163 ] step = 57\n", 502 | "[episode = 164 ] step = 68\n", 503 | "[episode = 165 ] step = 48\n", 504 | "[episode = 166 ] step = 51\n", 505 | "updating target-network parmeters...\n", 506 | "[episode = 167 ] step = 80\n", 507 | "[episode = 168 ] step = 56\n", 508 | "[episode = 169 ] step = 49\n", 509 | "updating target-network parmeters...\n", 510 | "[episode = 170 ] step = 55\n", 511 | "[episode = 171 ] step = 73\n", 512 | "[episode = 172 ] step = 38\n", 513 | "[episode = 173 ] step = 56\n", 514 | "updating target-network parmeters...\n", 515 | "[episode = 174 ] step = 58\n", 516 | "[episode = 175 ] step = 94\n", 517 | "updating target-network parmeters...\n", 518 | "[episode = 176 ] step = 52\n", 519 | "[episode = 177 ] step = 46\n", 520 | "[episode = 178 ] step = 
53\n", 521 | "[episode = 179 ] step = 67\n", 522 | "updating target-network parmeters...\n", 523 | "[episode = 180 ] step = 41\n", 524 | "[episode = 181 ] step = 79\n", 525 | "updating target-network parmeters...\n", 526 | "[episode = 182 ] step = 151\n", 527 | "[episode = 183 ] step = 51\n", 528 | "[episode = 184 ] step = 99\n", 529 | "updating target-network parmeters...\n", 530 | "[episode = 185 ] step = 48\n", 531 | "[episode = 186 ] step = 99\n", 532 | "updating target-network parmeters...\n", 533 | "[episode = 187 ] step = 69\n", 534 | "[episode = 188 ] step = 60\n", 535 | "updating target-network parmeters...\n", 536 | "[episode = 189 ] step = 172\n", 537 | "[episode = 190 ] step = 53\n", 538 | "[episode = 191 ] step = 56\n", 539 | "updating target-network parmeters...\n", 540 | "[episode = 192 ] step = 54\n", 541 | "[episode = 193 ] step = 93\n", 542 | "[episode = 194 ] step = 48\n", 543 | "updating target-network parmeters...\n", 544 | "[episode = 195 ] step = 89\n", 545 | "updating target-network parmeters...\n", 546 | "[episode = 196 ] step = 163\n", 547 | "[episode = 197 ] step = 106\n", 548 | "[episode = 198 ] step = 61\n", 549 | "updating target-network parmeters...\n", 550 | "[episode = 199 ] step = 61\n", 551 | "[episode = 200 ] step = 54\n", 552 | "[episode = 201 ] step = 67\n", 553 | "updating target-network parmeters...\n", 554 | "[episode = 202 ] step = 53\n", 555 | "[episode = 203 ] step = 80\n", 556 | "[episode = 204 ] step = 64\n", 557 | "updating target-network parmeters...\n", 558 | "[episode = 205 ] step = 68\n", 559 | "[episode = 206 ] step = 76\n", 560 | "[episode = 207 ] step = 66\n", 561 | "updating target-network parmeters...\n", 562 | "[episode = 208 ] step = 78\n", 563 | "[episode = 209 ] step = 59\n", 564 | "[episode = 210 ] step = 70\n", 565 | "updating target-network parmeters...\n", 566 | "[episode = 211 ] step = 68\n", 567 | "[episode = 212 ] step = 67\n", 568 | "updating target-network parmeters...\n", 569 | "[episode = 213 ] step = 73\n", 570 | "[episode = 214 ] step = 84\n", 571 | "[episode = 215 ] step = 62\n", 572 | "updating target-network parmeters...\n", 573 | "[episode = 216 ] step = 56\n", 574 | "[episode = 217 ] step = 98\n", 575 | "[episode = 218 ] step = 59\n", 576 | "updating target-network parmeters...\n", 577 | "[episode = 219 ] step = 66\n", 578 | "[episode = 220 ] step = 123\n", 579 | "updating target-network parmeters...\n", 580 | "[episode = 221 ] step = 184\n", 581 | "updating target-network parmeters...\n", 582 | "[episode = 222 ] step = 92\n", 583 | "[episode = 223 ] step = 96\n", 584 | "updating target-network parmeters...\n", 585 | "[episode = 224 ] step = 78\n", 586 | "[episode = 225 ] step = 165\n", 587 | "updating target-network parmeters...\n", 588 | "[episode = 226 ] step = 77\n", 589 | "[episode = 227 ] step = 63\n", 590 | "updating target-network parmeters...\n", 591 | "[episode = 228 ] step = 108\n", 592 | "[episode = 229 ] step = 63\n", 593 | "[episode = 230 ] step = 103\n", 594 | "updating target-network parmeters...\n", 595 | "[episode = 231 ] step = 155\n", 596 | "updating target-network parmeters...\n", 597 | "[episode = 232 ] step = 83\n", 598 | "[episode = 233 ] step = 88\n", 599 | "[episode = 234 ] step = 11\n", 600 | "updating target-network parmeters...\n", 601 | "[episode = 235 ] step = 106\n", 602 | "[episode = 236 ] step = 97\n", 603 | "updating target-network parmeters...\n", 604 | "[episode = 237 ] step = 95\n", 605 | "[episode = 238 ] step = 153\n", 606 | "updating target-network parmeters...\n", 607 | 
"[episode = 239 ] step = 107\n", 608 | "[episode = 240 ] step = 93\n", 609 | "updating target-network parmeters...\n", 610 | "[episode = 241 ] step = 96\n", 611 | "[episode = 242 ] step = 93\n", 612 | "updating target-network parmeters...\n", 613 | "[episode = 243 ] step = 72\n", 614 | "[episode = 244 ] step = 67\n", 615 | "updating target-network parmeters...\n", 616 | "[episode = 245 ] step = 89\n", 617 | "[episode = 246 ] step = 10\n", 618 | "[episode = 247 ] step = 12\n", 619 | "[episode = 248 ] step = 67\n", 620 | "[episode = 249 ] step = 76\n", 621 | "updating target-network parmeters...\n", 622 | "[episode = 250 ] step = 112\n", 623 | "updating target-network parmeters...\n", 624 | "[episode = 251 ] step = 107\n", 625 | "[episode = 252 ] step = 85\n", 626 | "[episode = 253 ] step = 73\n", 627 | "updating target-network parmeters...\n", 628 | "[episode = 254 ] step = 106\n", 629 | "updating target-network parmeters...\n", 630 | "[episode = 255 ] step = 199\n", 631 | "[episode = 256 ] step = 125\n", 632 | "updating target-network parmeters...\n", 633 | "[episode = 257 ] step = 64\n", 634 | "[episode = 258 ] step = 117\n", 635 | "updating target-network parmeters...\n", 636 | "[episode = 259 ] step = 93\n", 637 | "[episode = 260 ] step = 84\n", 638 | "[episode = 261 ] step = 10\n", 639 | "updating target-network parmeters...\n", 640 | "[episode = 262 ] step = 72\n", 641 | "[episode = 263 ] step = 79\n", 642 | "updating target-network parmeters...\n", 643 | "[episode = 264 ] step = 80\n", 644 | "[episode = 265 ] step = 70\n", 645 | "[episode = 266 ] step = 88\n", 646 | "updating target-network parmeters...\n", 647 | "[episode = 267 ] step = 112\n", 648 | "[episode = 268 ] step = 88\n", 649 | "updating target-network parmeters...\n", 650 | "[episode = 269 ] step = 80\n", 651 | "updating target-network parmeters...\n", 652 | "[episode = 270 ] step = 151\n", 653 | "[episode = 271 ] step = 161\n", 654 | "updating target-network parmeters...\n", 655 | "[episode = 272 ] step = 177\n", 656 | "updating target-network parmeters...\n", 657 | "[episode = 273 ] step = 179\n", 658 | "updating target-network parmeters...\n", 659 | "[episode = 274 ] step = 127\n", 660 | "[episode = 275 ] step = 121\n", 661 | "updating target-network parmeters...\n", 662 | "[episode = 276 ] step = 147\n", 663 | "updating target-network parmeters...\n", 664 | "[episode = 277 ] step = 138\n", 665 | "[episode = 278 ] step = 109\n", 666 | "updating target-network parmeters...\n", 667 | "[episode = 279 ] step = 145\n", 668 | "updating target-network parmeters...\n", 669 | "[episode = 280 ] step = 148\n", 670 | "[episode = 281 ] step = 119\n", 671 | "updating target-network parmeters...\n", 672 | "[episode = 282 ] step = 162\n", 673 | "updating target-network parmeters...\n", 674 | "[episode = 283 ] step = 129\n", 675 | "updating target-network parmeters...\n", 676 | "[episode = 284 ] step = 139\n", 677 | "[episode = 285 ] step = 116\n", 678 | "updating target-network parmeters...\n", 679 | "[episode = 286 ] step = 106\n", 680 | "[episode = 287 ] step = 74\n", 681 | "updating target-network parmeters...\n", 682 | "[episode = 288 ] step = 163\n", 683 | "[episode = 289 ] step = 113\n", 684 | "updating target-network parmeters...\n", 685 | "[episode = 290 ] step = 117\n", 686 | "updating target-network parmeters...\n", 687 | "[episode = 291 ] step = 117\n", 688 | "[episode = 292 ] step = 107\n", 689 | "updating target-network parmeters...\n", 690 | "[episode = 293 ] step = 91\n", 691 | "[episode = 294 ] step = 136\n", 692 | 
"updating target-network parmeters...\n", 693 | "[episode = 295 ] step = 134\n", 694 | "updating target-network parmeters...\n", 695 | "[episode = 296 ] step = 193\n", 696 | "updating target-network parmeters...\n", 697 | "[episode = 297 ] step = 146\n", 698 | "updating target-network parmeters...\n", 699 | "[episode = 298 ] step = 183\n", 700 | "updating target-network parmeters...\n", 701 | "[episode = 299 ] step = 197\n", 702 | "[episode = 300 ] step = 140\n", 703 | "updating target-network parmeters...\n", 704 | "[episode = 301 ] step = 113\n", 705 | "updating target-network parmeters...\n", 706 | "[episode = 302 ] step = 132\n", 707 | "[episode = 303 ] step = 174\n", 708 | "updating target-network parmeters...\n", 709 | "[episode = 304 ] step = 123\n", 710 | "updating target-network parmeters...\n", 711 | "[episode = 305 ] step = 196\n", 712 | "updating target-network parmeters...\n" 713 | ] 714 | }, 715 | { 716 | "name": "stdout", 717 | "output_type": "stream", 718 | "text": [ 719 | "[episode = 306 ] step = 176\n", 720 | "updating target-network parmeters...\n", 721 | "[episode = 307 ] step = 188\n", 722 | "updating target-network parmeters...\n", 723 | "[episode = 308 ] step = 130\n", 724 | "[episode = 309 ] step = 185\n", 725 | "updating target-network parmeters...\n", 726 | "[episode = 310 ] step = 199\n", 727 | "updating target-network parmeters...\n", 728 | "[episode = 311 ] step = 199\n", 729 | "updating target-network parmeters...\n", 730 | "[episode = 312 ] step = 196\n", 731 | "updating target-network parmeters...\n", 732 | "[episode = 313 ] step = 199\n", 733 | "updating target-network parmeters...\n", 734 | "[episode = 314 ] step = 152\n", 735 | "updating target-network parmeters...\n", 736 | "[episode = 315 ] step = 199\n", 737 | "updating target-network parmeters...\n", 738 | "[episode = 316 ] step = 170\n", 739 | "updating target-network parmeters...\n", 740 | "[episode = 317 ] step = 144\n", 741 | "updating target-network parmeters...\n", 742 | "[episode = 318 ] step = 166\n", 743 | "updating target-network parmeters...\n", 744 | "[episode = 319 ] step = 184\n", 745 | "updating target-network parmeters...\n", 746 | "[episode = 320 ] step = 197\n", 747 | "[episode = 321 ] step = 147\n", 748 | "updating target-network parmeters...\n", 749 | "[episode = 322 ] step = 130\n", 750 | "updating target-network parmeters...\n", 751 | "[episode = 323 ] step = 133\n", 752 | "updating target-network parmeters...\n", 753 | "[episode = 324 ] step = 195\n", 754 | "[episode = 325 ] step = 169\n", 755 | "updating target-network parmeters...\n", 756 | "[episode = 326 ] step = 199\n", 757 | "updating target-network parmeters...\n", 758 | "[episode = 327 ] step = 199\n", 759 | "updating target-network parmeters...\n", 760 | "[episode = 328 ] step = 199\n", 761 | "updating target-network parmeters...\n", 762 | "[episode = 329 ] step = 199\n", 763 | "updating target-network parmeters...\n", 764 | "[episode = 330 ] step = 199\n", 765 | "updating target-network parmeters...\n", 766 | "[episode = 331 ] step = 158\n", 767 | "updating target-network parmeters...\n", 768 | "[episode = 332 ] step = 185\n", 769 | "updating target-network parmeters...\n", 770 | "[episode = 333 ] step = 175\n", 771 | "updating target-network parmeters...\n", 772 | "[episode = 334 ] step = 152\n", 773 | "updating target-network parmeters...\n", 774 | "[episode = 335 ] step = 168\n", 775 | "[episode = 336 ] step = 156\n", 776 | "updating target-network parmeters...\n", 777 | "[episode = 337 ] step = 170\n", 778 | 
"updating target-network parmeters...\n", 779 | "[episode = 338 ] step = 199\n", 780 | "updating target-network parmeters...\n", 781 | "[episode = 339 ] step = 199\n", 782 | "updating target-network parmeters...\n", 783 | "[episode = 340 ] step = 199\n", 784 | "updating target-network parmeters...\n", 785 | "[episode = 341 ] step = 199\n", 786 | "updating target-network parmeters...\n", 787 | "[episode = 342 ] step = 199\n", 788 | "updating target-network parmeters...\n", 789 | "[episode = 343 ] step = 62\n", 790 | "updating target-network parmeters...\n", 791 | "[episode = 344 ] step = 199\n", 792 | "updating target-network parmeters...\n", 793 | "[episode = 345 ] step = 187\n", 794 | "[episode = 346 ] step = 71\n", 795 | "updating target-network parmeters...\n", 796 | "[episode = 347 ] step = 153\n", 797 | "[episode = 348 ] step = 141\n", 798 | "updating target-network parmeters...\n", 799 | "[episode = 349 ] step = 152\n", 800 | "updating target-network parmeters...\n", 801 | "[episode = 350 ] step = 140\n", 802 | "updating target-network parmeters...\n", 803 | "[episode = 351 ] step = 193\n", 804 | "[episode = 352 ] step = 94\n", 805 | "updating target-network parmeters...\n", 806 | "[episode = 353 ] step = 125\n", 807 | "updating target-network parmeters...\n", 808 | "[episode = 354 ] step = 119\n", 809 | "[episode = 355 ] step = 60\n", 810 | "updating target-network parmeters...\n", 811 | "[episode = 356 ] step = 185\n", 812 | "[episode = 357 ] step = 127\n", 813 | "updating target-network parmeters...\n", 814 | "[episode = 358 ] step = 49\n", 815 | "updating target-network parmeters...\n", 816 | "[episode = 359 ] step = 199\n", 817 | "updating target-network parmeters...\n", 818 | "[episode = 360 ] step = 199\n", 819 | "[episode = 361 ] step = 61\n", 820 | "[episode = 362 ] step = 17\n", 821 | "[episode = 363 ] step = 23\n", 822 | "updating target-network parmeters...\n", 823 | "[episode = 364 ] step = 199\n", 824 | "updating target-network parmeters...\n", 825 | "[episode = 365 ] step = 107\n", 826 | "updating target-network parmeters...\n", 827 | "[episode = 366 ] step = 180\n", 828 | "[episode = 367 ] step = 21\n", 829 | "[episode = 368 ] step = 69\n", 830 | "updating target-network parmeters...\n", 831 | "[episode = 369 ] step = 153\n", 832 | "[episode = 370 ] step = 34\n", 833 | "[episode = 371 ] step = 34\n", 834 | "updating target-network parmeters...\n", 835 | "[episode = 372 ] step = 60\n", 836 | "[episode = 373 ] step = 111\n", 837 | "updating target-network parmeters...\n", 838 | "[episode = 374 ] step = 122\n", 839 | "[episode = 375 ] step = 23\n", 840 | "[episode = 376 ] step = 53\n", 841 | "updating target-network parmeters...\n", 842 | "[episode = 377 ] step = 107\n", 843 | "[episode = 378 ] step = 28\n", 844 | "[episode = 379 ] step = 22\n", 845 | "[episode = 380 ] step = 74\n", 846 | "[episode = 381 ] step = 20\n", 847 | "updating target-network parmeters...\n", 848 | "[episode = 382 ] step = 39\n", 849 | "[episode = 383 ] step = 121\n", 850 | "updating target-network parmeters...\n", 851 | "[episode = 384 ] step = 164\n", 852 | "[episode = 385 ] step = 76\n", 853 | "updating target-network parmeters...\n", 854 | "[episode = 386 ] step = 159\n", 855 | "updating target-network parmeters...\n", 856 | "[episode = 387 ] step = 134\n", 857 | "[episode = 388 ] step = 100\n", 858 | "updating target-network parmeters...\n", 859 | "[episode = 389 ] step = 118\n", 860 | "updating target-network parmeters...\n", 861 | "[episode = 390 ] step = 113\n", 862 | "[episode = 391 ] 
step = 143\n", 863 | "updating target-network parmeters...\n", 864 | "[episode = 392 ] step = 122\n", 865 | "updating target-network parmeters...\n", 866 | "[episode = 393 ] step = 130\n", 867 | "[episode = 394 ] step = 121\n", 868 | "updating target-network parmeters...\n", 869 | "[episode = 395 ] step = 122\n", 870 | "[episode = 396 ] step = 106\n", 871 | "updating target-network parmeters...\n", 872 | "[episode = 397 ] step = 114\n", 873 | "[episode = 398 ] step = 117\n", 874 | "updating target-network parmeters...\n", 875 | "[episode = 399 ] step = 104\n", 876 | "updating target-network parmeters...\n", 877 | "[episode = 400 ] step = 109\n", 878 | "[episode = 401 ] step = 119\n", 879 | "updating target-network parmeters...\n", 880 | "[episode = 402 ] step = 121\n", 881 | "[episode = 403 ] step = 108\n", 882 | "updating target-network parmeters...\n", 883 | "[episode = 404 ] step = 123\n", 884 | "updating target-network parmeters...\n", 885 | "[episode = 405 ] step = 166\n", 886 | "[episode = 406 ] step = 117\n", 887 | "updating target-network parmeters...\n", 888 | "[episode = 407 ] step = 132\n", 889 | "updating target-network parmeters...\n", 890 | "[episode = 408 ] step = 125\n", 891 | "[episode = 409 ] step = 110\n", 892 | "[episode = 410 ] step = 61\n", 893 | "updating target-network parmeters...\n", 894 | "[episode = 411 ] step = 101\n", 895 | "updating target-network parmeters...\n", 896 | "[episode = 412 ] step = 199\n", 897 | "updating target-network parmeters...\n", 898 | "[episode = 413 ] step = 199\n", 899 | "[episode = 414 ] step = 94\n", 900 | "updating target-network parmeters...\n", 901 | "[episode = 415 ] step = 155\n", 902 | "updating target-network parmeters...\n", 903 | "[episode = 416 ] step = 199\n", 904 | "updating target-network parmeters...\n", 905 | "[episode = 417 ] step = 199\n", 906 | "updating target-network parmeters...\n", 907 | "[episode = 418 ] step = 199\n", 908 | "updating target-network parmeters...\n", 909 | "[episode = 419 ] step = 199\n", 910 | "updating target-network parmeters...\n", 911 | "[episode = 420 ] step = 199\n", 912 | "updating target-network parmeters...\n", 913 | "[episode = 421 ] step = 163\n", 914 | "[episode = 422 ] step = 61\n", 915 | "updating target-network parmeters...\n", 916 | "[episode = 423 ] step = 154\n", 917 | "updating target-network parmeters...\n", 918 | "[episode = 424 ] step = 153\n", 919 | "updating target-network parmeters...\n", 920 | "[episode = 425 ] step = 199\n", 921 | "updating target-network parmeters...\n", 922 | "[episode = 426 ] step = 181\n", 923 | "updating target-network parmeters...\n", 924 | "[episode = 427 ] step = 150\n", 925 | "[episode = 428 ] step = 137\n", 926 | "updating target-network parmeters...\n", 927 | "[episode = 429 ] step = 134\n", 928 | "[episode = 430 ] step = 87\n", 929 | "updating target-network parmeters...\n", 930 | "[episode = 431 ] step = 102\n", 931 | "updating target-network parmeters...\n", 932 | "[episode = 432 ] step = 199\n", 933 | "[episode = 433 ] step = 73\n", 934 | "updating target-network parmeters...\n", 935 | "[episode = 434 ] step = 116\n", 936 | "[episode = 435 ] step = 79\n", 937 | "updating target-network parmeters...\n", 938 | "[episode = 436 ] step = 199\n", 939 | "updating target-network parmeters...\n", 940 | "[episode = 437 ] step = 199\n", 941 | "updating target-network parmeters...\n", 942 | "[episode = 438 ] step = 160\n", 943 | "updating target-network parmeters...\n", 944 | "[episode = 439 ] step = 143\n", 945 | "updating target-network 
parmeters...\n", 946 | "[episode = 440 ] step = 197\n", 947 | "[episode = 441 ] step = 35\n", 948 | "updating target-network parmeters...\n", 949 | "[episode = 442 ] step = 183\n", 950 | "updating target-network parmeters...\n", 951 | "[episode = 443 ] step = 199\n", 952 | "updating target-network parmeters...\n", 953 | "[episode = 444 ] step = 199\n", 954 | "updating target-network parmeters...\n", 955 | "[episode = 445 ] step = 199\n", 956 | "updating target-network parmeters...\n", 957 | "[episode = 446 ] step = 154\n", 958 | "updating target-network parmeters...\n", 959 | "[episode = 447 ] step = 199\n", 960 | "updating target-network parmeters...\n", 961 | "[episode = 448 ] step = 195\n", 962 | "updating target-network parmeters...\n", 963 | "[episode = 449 ] step = 199\n", 964 | "updating target-network parmeters...\n", 965 | "[episode = 450 ] step = 186\n", 966 | "[episode = 451 ] step = 188\n", 967 | "updating target-network parmeters...\n", 968 | "[episode = 452 ] step = 142\n", 969 | "updating target-network parmeters...\n", 970 | "[episode = 453 ] step = 199\n", 971 | "updating target-network parmeters...\n", 972 | "[episode = 454 ] step = 199\n", 973 | "updating target-network parmeters...\n", 974 | "[episode = 455 ] step = 186\n", 975 | "updating target-network parmeters...\n", 976 | "[episode = 456 ] step = 169\n", 977 | "updating target-network parmeters...\n", 978 | "[episode = 457 ] step = 190\n", 979 | "[episode = 458 ] step = 25\n" 980 | ] 981 | }, 982 | { 983 | "name": "stdout", 984 | "output_type": "stream", 985 | "text": [ 986 | "updating target-network parmeters...\n", 987 | "[episode = 459 ] step = 90\n", 988 | "[episode = 460 ] step = 52\n", 989 | "[episode = 461 ] step = 36\n", 990 | "updating target-network parmeters...\n", 991 | "[episode = 462 ] step = 175\n", 992 | "updating target-network parmeters...\n", 993 | "[episode = 463 ] step = 170\n", 994 | "updating target-network parmeters...\n", 995 | "[episode = 464 ] step = 199\n", 996 | "updating target-network parmeters...\n", 997 | "[episode = 465 ] step = 199\n", 998 | "[episode = 466 ] step = 79\n", 999 | "updating target-network parmeters...\n", 1000 | "[episode = 467 ] step = 199\n", 1001 | "[episode = 468 ] step = 46\n", 1002 | "[episode = 469 ] step = 15\n", 1003 | "updating target-network parmeters...\n", 1004 | "[episode = 470 ] step = 199\n", 1005 | "updating target-network parmeters...\n", 1006 | "[episode = 471 ] step = 50\n", 1007 | "[episode = 472 ] step = 119\n", 1008 | "updating target-network parmeters...\n", 1009 | "[episode = 473 ] step = 199\n", 1010 | "updating target-network parmeters...\n", 1011 | "[episode = 474 ] step = 188\n", 1012 | "updating target-network parmeters...\n", 1013 | "[episode = 475 ] step = 199\n", 1014 | "updating target-network parmeters...\n", 1015 | "[episode = 476 ] step = 199\n", 1016 | "updating target-network parmeters...\n", 1017 | "[episode = 477 ] step = 71\n", 1018 | "[episode = 478 ] step = 89\n", 1019 | "[episode = 479 ] step = 66\n", 1020 | "updating target-network parmeters...\n", 1021 | "[episode = 480 ] step = 123\n", 1022 | "updating target-network parmeters...\n", 1023 | "[episode = 481 ] step = 170\n", 1024 | "updating target-network parmeters...\n", 1025 | "[episode = 482 ] step = 199\n", 1026 | "updating target-network parmeters...\n", 1027 | "[episode = 483 ] step = 198\n", 1028 | "updating target-network parmeters...\n", 1029 | "[episode = 484 ] step = 183\n", 1030 | "updating target-network parmeters...\n", 1031 | "[episode = 485 ] step = 
146\n", 1032 | "[episode = 486 ] step = 146\n", 1033 | "updating target-network parmeters...\n", 1034 | "[episode = 487 ] step = 199\n", 1035 | "[episode = 488 ] step = 36\n", 1036 | "updating target-network parmeters...\n", 1037 | "[episode = 489 ] step = 68\n", 1038 | "updating target-network parmeters...\n", 1039 | "[episode = 490 ] step = 199\n", 1040 | "[episode = 491 ] step = 84\n", 1041 | "updating target-network parmeters...\n", 1042 | "[episode = 492 ] step = 93\n", 1043 | "[episode = 493 ] step = 126\n", 1044 | "updating target-network parmeters...\n", 1045 | "[episode = 494 ] step = 122\n", 1046 | "updating target-network parmeters...\n", 1047 | "[episode = 495 ] step = 199\n", 1048 | "updating target-network parmeters...\n", 1049 | "[episode = 496 ] step = 199\n", 1050 | "updating target-network parmeters...\n", 1051 | "[episode = 497 ] step = 165\n", 1052 | "[episode = 498 ] step = 125\n", 1053 | "updating target-network parmeters...\n", 1054 | "[episode = 499 ] step = 170\n" 1055 | ] 1056 | } 1057 | ], 1058 | "source": [ 1059 | "if __name__ == \"__main__\":\n", 1060 | " env = gym.make(ENV)\n", 1061 | " with tf.Session() as sess:\n", 1062 | " DQN = DeepQNetwork(env , sess )\n", 1063 | " update_iter = 0\n", 1064 | " step_his = []\n", 1065 | " for episode in range(EPISODES):\n", 1066 | " state = env.reset()\n", 1067 | " env.render() \n", 1068 | " reward_all = 0\n", 1069 | " #training\n", 1070 | " for step in range(MAX_STEP):\n", 1071 | " action = DQN.chose_action(state)\n", 1072 | " next_state , reward , done , _ = env.step(action)\n", 1073 | " reward_all += reward \n", 1074 | "\n", 1075 | " if len(memory) > MEMORY_SIZE:\n", 1076 | " memory.pop(0)\n", 1077 | " memory.append(Transition(state, action , reward , next_state , float(done)))\n", 1078 | "\n", 1079 | " if len(memory) > BATCH_SIZE * 4:\n", 1080 | " batch_transition = random.sample(memory , BATCH_SIZE)\n", 1081 | " #***\n", 1082 | " batch_state, batch_action, batch_reward, batch_next_state, batch_done = map(np.array , zip(*batch_transition)) \n", 1083 | " DQN.train(state = batch_state ,\n", 1084 | " reward = batch_reward , \n", 1085 | " action = batch_action , \n", 1086 | " state_next = batch_next_state,\n", 1087 | " done = batch_done\n", 1088 | " )\n", 1089 | " update_iter += 1\n", 1090 | "\n", 1091 | " if update_iter % UPDATE_PERIOD == 0:\n", 1092 | " DQN.update_prmt()\n", 1093 | " \n", 1094 | " if update_iter % 200 == 0:\n", 1095 | " DQN.decay_epsilon()\n", 1096 | "\n", 1097 | " if done:\n", 1098 | " print(\"[episode = {} ] step = {}\".format(episode , step))\n", 1099 | " break\n", 1100 | " \n", 1101 | " state = next_state\n", 1102 | " " 1103 | ] 1104 | } 1105 | ], 1106 | "metadata": { 1107 | "kernelspec": { 1108 | "display_name": "Python 3", 1109 | "language": "python", 1110 | "name": "python3" 1111 | }, 1112 | "language_info": { 1113 | "codemirror_mode": { 1114 | "name": "ipython", 1115 | "version": 3 1116 | }, 1117 | "file_extension": ".py", 1118 | "mimetype": "text/x-python", 1119 | "name": "python", 1120 | "nbconvert_exporter": "python", 1121 | "pygments_lexer": "ipython3", 1122 | "version": "3.5.2" 1123 | } 1124 | }, 1125 | "nbformat": 4, 1126 | "nbformat_minor": 2 1127 | } 1128 | -------------------------------------------------------------------------------- /DQN_Cartpole.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import collections 4 | import gym 5 | import random 6 | import tensorflow.contrib.layers as layers 7 | 
8 | ENV = "CartPole-v0" 9 | 10 | MEMORY_SIZE = 10000 11 | EPISODES = 500 12 | MAX_STEP = 500 13 | BATCH_SIZE = 32 14 | UPDATE_PERIOD = 200 # update target network parameters 15 | 16 | 17 | ##built class for the DQN 18 | class DeepQNetwork(): 19 | def __init__(self , env , sess=None , gamma = 0.8, epsilon = 0.8 ): 20 | self.gamma = gamma 21 | self.epsilon = epsilon 22 | self.action_dim = env.action_space.n 23 | self.state_dim = env.observation_space.shape[0] 24 | self.network() 25 | self.sess = sess 26 | self.sess.run(tf.global_variables_initializer()) 27 | tf.summary.FileWriter("DQN/summaries" , sess.graph ) 28 | 29 | # net_frame using for creating Q & target network 30 | def net_frame(self , hiddens, inpt, num_actions, scope, reuse=None): 31 | with tf.variable_scope(scope, reuse=reuse): 32 | out = inpt 33 | for hidden in hiddens: 34 | out = layers.fully_connected(out, num_outputs=hidden, activation_fn=tf.nn.relu) 35 | out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None) 36 | return out 37 | 38 | # create q_network & target_network 39 | def network(self): 40 | # q_network 41 | self.inputs_q = tf.placeholder(dtype = tf.float32 , shape = [None , self.state_dim] , name = "inputs_q") 42 | scope_var = "q_network" 43 | self.q_value = self.net_frame([64] , self.inputs_q , self.action_dim , scope_var , reuse = True ) 44 | 45 | # target_network 46 | self.inputs_target = tf.placeholder(dtype = tf.float32 , shape = [None , self.state_dim] , name = "inputs_target") 47 | scope_tar = "target_network" 48 | self.q_target = self.net_frame([64] , self.inputs_target , self.action_dim , scope_tar ) 49 | 50 | with tf.variable_scope("loss"): 51 | # #【方案一】 52 | # self.target = tf.placeholder(dtype = tf.float32 , shape = [None , self.action_dim] , name = "target") 53 | # self.loss = tf.reduce_mean( tf.square(self.q_value - self.target)) 54 | #【方案二】 55 | self.action = tf.placeholder(dtype = tf.int32 , shape = [ None ] , name = "action") 56 | action_one_hot = tf.one_hot(self.action , self.action_dim ) 57 | q_action = tf.reduce_sum( tf.multiply(self.q_value , action_one_hot) , axis = 1 ) 58 | 59 | self.target = tf.placeholder(dtype = tf.float32 , shape = [None ] , name = "target") 60 | self.loss = tf.reduce_mean( tf.square(q_action - self.target)) 61 | 62 | with tf.variable_scope("train"): 63 | optimizer = tf.train.RMSPropOptimizer(0.001) 64 | self.train_op = optimizer.minimize(self.loss) 65 | 66 | # training 67 | def train(self , state , reward , action , state_next , done): 68 | q , q_target = self.sess.run([self.q_value , self.q_target] , 69 | feed_dict={self.inputs_q : state , self.inputs_target : state_next } ) 70 | # #【方案一】 71 | # target = reward + self.gamma * np.max(q_target , axis = 1)*(1.0 - done) 72 | 73 | # self.reform_target = q.copy() 74 | # batch_index = np.arange(BATCH_SIZE , dtype = np.int32) 75 | # self.reform_target[batch_index , action] = target 76 | 77 | # loss , _ = self.sess.run([self.loss , self.train_op] , feed_dict={self.inputs_q: state , self.target: self.reform_target} ) 78 | 79 | #【方案二】 80 | q_target_best = np.max(q_target , axis = 1) 81 | q_target_best_mask = ( 1.0 - done) * q_target_best 82 | 83 | target = reward + self.gamma * q_target_best_mask 84 | 85 | loss , _ = self.sess.run([self.loss , self.train_op] , 86 | feed_dict={self.inputs_q: state , self.target:target , self.action:action} ) 87 | # chose action 88 | def chose_action(self , current_state): 89 | current_state = current_state[np.newaxis , :] #*** array dim: (xx,) --> (1 , xx) *** 90 | q = 
self.sess.run(self.q_value , feed_dict={self.inputs_q : current_state} ) 91 | 92 | # e-greedy 93 | if np.random.random() < self.epsilon: 94 | action_chosen = np.random.randint(0 , self.action_dim) 95 | else: 96 | action_chosen = np.argmax(q) 97 | 98 | return action_chosen 99 | 100 | #upadate parmerters 101 | def update_prmt(self): 102 | q_prmts = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES , "q_network" ) 103 | target_prmts = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, "target_network" ) 104 | self.sess.run( [tf.assign(t , q)for t,q in zip(target_prmts , q_prmts)]) #*** 105 | print("updating target-network parmeters...") 106 | 107 | def decay_epsilon(self): 108 | if self.epsilon > 0.03: 109 | self.epsilon = self.epsilon - 0.02 110 | 111 | # memory for momery replay 112 | memory = [] 113 | Transition = collections.namedtuple("Transition" , ["state", "action" , "reward" , "next_state" , "done"]) 114 | 115 | if __name__ == "__main__": 116 | env = gym.make(ENV) 117 | with tf.Session() as sess: 118 | DQN = DeepQNetwork(env , sess ) 119 | update_iter = 0 120 | step_his = [] 121 | for episode in range(EPISODES): 122 | state = env.reset() 123 | env.render() 124 | reward_all = 0 125 | #training 126 | for step in range(MAX_STEP): 127 | action = DQN.chose_action(state) 128 | next_state , reward , done , _ = env.step(action) 129 | reward_all += reward 130 | 131 | if len(memory) > MEMORY_SIZE: 132 | memory.pop(0) 133 | memory.append(Transition(state, action , reward , next_state , float(done))) 134 | 135 | if len(memory) > BATCH_SIZE * 4: 136 | batch_transition = random.sample(memory , BATCH_SIZE) 137 | #*** 138 | batch_state, batch_action, batch_reward, batch_next_state, batch_done = map(np.array , zip(*batch_transition)) 139 | DQN.train(state = batch_state , 140 | reward = batch_reward , 141 | action = batch_action , 142 | state_next = batch_next_state, 143 | done = batch_done 144 | ) 145 | update_iter += 1 146 | 147 | if update_iter % UPDATE_PERIOD == 0: 148 | DQN.update_prmt() 149 | 150 | if update_iter % 200 == 0: 151 | DQN.decay_epsilon() 152 | 153 | if done: 154 | print("[episode = {} ] step = {}".format(episode , step)) 155 | break 156 | 157 | state = next_state 158 | -------------------------------------------------------------------------------- /DuelingDQN.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import collections 4 | import gym 5 | import random 6 | import tensorflow.contrib.layers as layers 7 | 8 | ENV = "CartPole-v0" 9 | 10 | MEMORY_SIZE = 10000 11 | EPISODES = 1000 12 | MAX_STEP = 500 13 | BATCH_SIZE = 32 14 | UPDATE_PERIOD = 200 # update target network parameters 15 | 16 | 17 | 18 | 19 | 20 | 21 | ##built class for the DQN 22 | class DeepQNetwork(): 23 | def __init__(self , scope_main , env , sess=None , gamma = 0.8, epsilon = 0.8 , dueling = True , out_graph = False , out_dqn = True): 24 | self.gamma = gamma 25 | self.epsilon = epsilon 26 | self.loss_his = [] 27 | 28 | self.scope_main = scope_main 29 | self.dueling = dueling 30 | self.out_dqn = out_dqn 31 | 32 | self.action_dim = env.action_space.n 33 | self.state_dim = env.observation_space.shape[0] 34 | self.network() 35 | self.sess = sess 36 | self.sess.run(tf.global_variables_initializer()) 37 | tf.summary.FileWriter("DQN/summaries" , sess.graph ) 38 | 39 | # net_frame using for creating Q & target network 40 | def net_frame(self , hiddens, inpt, num_actions, scope, hiddens_a , hiddens_v , reuse=None): 41 | with 
tf.variable_scope(scope, reuse=reuse): 42 | out = inpt 43 | for hidden in hiddens: 44 | out = layers.fully_connected(out, num_outputs=hidden, activation_fn=tf.nn.relu) 45 | 46 | if self.dueling == True : 47 | # value_stream 48 | with tf.variable_scope("value_stream"): 49 | value = out 50 | for hidden in hiddens_v: 51 | value = layers.fully_connected(value, num_outputs= hidden , activation_fn=None) 52 | value = layers.fully_connected(value, num_outputs= 1 , activation_fn=None) 53 | 54 | # advantage_stream 55 | with tf.variable_scope("advantage_stream"): 56 | advantage = out 57 | for hidden in hiddens_a: 58 | advantage = layers.fully_connected(advantage , num_outputs = hidden , activation_fn=None) 59 | advantage = layers.fully_connected(advantage , num_outputs= num_actions , activation_fn=None) 60 | 61 | # aggregating_moudle 62 | with tf.variable_scope("aggregating_moudle"): 63 | q_out = value + advantage - tf.reduce_mean(advantage , axis = 1 , keep_dims = True ) # ***keep_dims 64 | 65 | elif self.out_dqn: 66 | with tf.variable_scope("dqn_out"): 67 | q_out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None) 68 | 69 | return q_out 70 | 71 | 72 | # create q_network & target_network 73 | def network(self): 74 | # q_network 75 | self.inputs_q = tf.placeholder(dtype = tf.float32 , shape = [None , self.state_dim] , name = "inputs_q") 76 | scope_var = "q_network" 77 | self.q_value = self.net_frame([64] , self.inputs_q , self.action_dim , scope_var , [20] , [20] , reuse = True ) 78 | 79 | # target_network 80 | self.inputs_target = tf.placeholder(dtype = tf.float32 , shape = [None , self.state_dim] , name = "inputs_target") 81 | scope_tar = "target_network" 82 | self.q_target = self.net_frame([64] , self.inputs_target , self.action_dim , scope_tar , [20] , [20] ) 83 | 84 | with tf.variable_scope("loss"): 85 | self.action = tf.placeholder(dtype = tf.int32 , shape = [ None ] , name = "action") 86 | action_one_hot = tf.one_hot(self.action , self.action_dim ) 87 | q_action = tf.reduce_sum( tf.multiply(self.q_value , action_one_hot) , axis = 1 ) 88 | 89 | self.target = tf.placeholder(dtype = tf.float32 , shape = [None ] , name = "target") 90 | self.loss = tf.reduce_mean( tf.square(q_action - self.target)) 91 | 92 | with tf.variable_scope("train"): 93 | optimizer = tf.train.RMSPropOptimizer(0.001) 94 | self.train_op = optimizer.minimize(self.loss) 95 | 96 | # training 97 | def train(self , state , reward , action , state_next , done): 98 | q , q_target = self.sess.run([self.q_value , self.q_target] , 99 | feed_dict={self.inputs_q : state , self.inputs_target : state_next } ) 100 | 101 | q_target_best = np.max(q_target , axis = 1) 102 | q_target_best_mask = ( 1.0 - done) * q_target_best 103 | 104 | target = reward + self.gamma * q_target_best_mask 105 | 106 | loss , _ = self.sess.run([self.loss , self.train_op] , 107 | feed_dict={self.inputs_q: state , self.target:target , self.action:action} ) 108 | self.loss_his.append(loss) 109 | 110 | # chose action 111 | def chose_action(self , current_state): 112 | current_state = current_state[np.newaxis , :] #*** array dim: (xx,) --> (1 , xx) *** 113 | q = self.sess.run(self.q_value , feed_dict={self.inputs_q : current_state} ) 114 | 115 | # e-greedy 116 | if np.random.random() < self.epsilon: 117 | action_chosen = np.random.randint(0 , self.action_dim) 118 | else: 119 | action_chosen = np.argmax(q) 120 | 121 | return action_chosen 122 | 123 | #upadate parmerters 124 | def update_prmt(self): 125 | q_prmts = tf.get_collection( 
tf.GraphKeys.GLOBAL_VARIABLES , self.scope_main + "/q_network" ) 126 | target_prmts = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, self.scope_main + "/target_network" ) 127 | self.sess.run( [tf.assign(t , q)for t,q in zip(target_prmts , q_prmts)]) #*** 128 | print("updating target-network parmeters...") 129 | 130 | def decay_epsilon(self): 131 | if self.epsilon > 0.03: 132 | self.epsilon = self.epsilon - 0.02 133 | 134 | def greedy_action(self , current_state): 135 | current_state = current_state[np.newaxis , :] 136 | q = self.sess.run(self.q_value , feed_dict={self.inputs_q : current_state} ) 137 | action_greedy = np.argmax(q) 138 | return action_greedy 139 | 140 | 141 | 142 | 143 | # memory for momery replay 144 | memory = [] 145 | Transition = collections.namedtuple("Transition" , ["state", "action" , "reward" , "next_state" , "done"]) 146 | 147 | 148 | 149 | 150 | 151 | def train( DQN , env ): 152 | reward_his = [] 153 | all_reward = 0 154 | step_his = [] 155 | update_iter = 0 156 | for episode in range(EPISODES): 157 | state = env.reset() 158 | # env.render() 159 | # reward_all = 0 160 | #training 161 | for step in range(MAX_STEP): 162 | action = DQN.chose_action(state) 163 | next_state , reward , done , _ = env.step(action) 164 | all_reward += reward 165 | 166 | if len(memory) > MEMORY_SIZE: 167 | memory.pop(0) 168 | memory.append(Transition(state, action , reward , next_state , float(done))) 169 | 170 | if len(memory) > BATCH_SIZE * 4: 171 | batch_transition = random.sample(memory , BATCH_SIZE) 172 | #*** 173 | batch_state, batch_action, batch_reward, batch_next_state, batch_done = map(np.array , zip(*batch_transition)) 174 | DQN.train(state = batch_state , 175 | reward = batch_reward , 176 | action = batch_action , 177 | state_next = batch_next_state, 178 | done = batch_done 179 | ) 180 | update_iter += 1 181 | 182 | if update_iter % UPDATE_PERIOD == 0: 183 | DQN.update_prmt() 184 | 185 | if update_iter % 200 == 0: 186 | DQN.decay_epsilon() 187 | 188 | if done: 189 | step_his.append(step) 190 | reward_his.append(all_reward) 191 | print("[episode= {} ] step = {}".format(episode , step)) 192 | break 193 | 194 | state = next_state 195 | 196 | loss_his = DQN.loss_his 197 | return [step_his , reward_his , loss_his] 198 | 199 | 200 | 201 | 202 | if __name__ == "__main__": 203 | env = gym.make(ENV) 204 | with tf.Session() as sess: 205 | with tf.variable_scope("DQN"): 206 | DQN = DeepQNetwork( "DQN" , env , sess , dueling = False , out_graph = False , out_dqn = True ) 207 | with tf.variable_scope("Deuling"): 208 | Dueling = DeepQNetwork("Deuling" , env , sess , dueling = True , out_graph = False , out_dqn = False ) 209 | 210 | step_dqn , reward_dqn , loss_dqn = train(DQN , env) 211 | step_dueling , reward_dueling , loss_dueling = train(Dueling , env) -------------------------------------------------------------------------------- /PrioritizedReplayDQN-ProportionalVariant.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import tensorflow as tf\n", 12 | "import numpy as np\n", 13 | "import gym\n", 14 | "import random\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "import collections" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "ENV = \"CartPole-v0\"\n", 28 | 
"\n", 29 | "CAPCITY = 2000\n", 30 | "\n", 31 | "UPDAT_PRMT_PERIOD = 100\n", 32 | "REPLAY_PERIOD = 3\n", 33 | "BATCH_SIZE = 32\n", 34 | "\n", 35 | "EPISODE = 10000\n", 36 | "MAX_STEP = 300\n", 37 | "\n", 38 | "ALPHA = 0.5\n", 39 | "BETA_INIT = 0.1 " 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "# proportional variant / used in class Memory\n", 51 | "class SumTree():\n", 52 | " def __init__(self , ):\n", 53 | " self.capcity = CAPCITY\n", 54 | " self.alpha = ALPHA\n", 55 | " \n", 56 | " # struct of SumTree & memory for the transition\n", 57 | " self.tree = np.zeros( 2 * self.capcity - 1)\n", 58 | " self.data = [None] * self.capcity\n", 59 | " \n", 60 | " # pointer for the position\n", 61 | " self.pointer = 0\n", 62 | " \n", 63 | " # add new priority in leaf_node\n", 64 | " def add_leaf_node(self , p_alpha , transition):\n", 65 | " leaf_idex = self.pointer + self.capcity - 1\n", 66 | " \n", 67 | " self.data[self.pointer] = transition\n", 68 | " self.update_leaf_node(leaf_idex , p_alpha)\n", 69 | " \n", 70 | " self.pointer += 1\n", 71 | " if self.pointer >= self.capcity: # !not self.capcity-1 \n", 72 | " self.pointer = 0\n", 73 | " \n", 74 | " # update leaf_node according leaf_idex\n", 75 | " def update_leaf_node(self , leaf_idex , p_alpha):\n", 76 | " change = p_alpha - self.tree[leaf_idex]\n", 77 | " self.tree[leaf_idex] = p_alpha \n", 78 | " self._update_parent_node(change , leaf_idex )\n", 79 | " \n", 80 | " # update the value of sum p in parent node\n", 81 | " def _update_parent_node(self , change , child_idex ):\n", 82 | " parent_idex = (child_idex - 1) // 2\n", 83 | " \n", 84 | " self.tree[parent_idex] += change \n", 85 | " \n", 86 | " if parent_idex != 0:\n", 87 | " self._update_parent_node(change , parent_idex) \n", 88 | " \n", 89 | " # sampling to get leaf idex and transition\n", 90 | " def sampling(self , sample_idex):\n", 91 | " leaf_idex = self._retrieve(sample_idex)\n", 92 | " data_idex = leaf_idex - self.capcity + 1\n", 93 | " \n", 94 | " return [leaf_idex , self.tree[leaf_idex] , self.data[data_idex] ]\n", 95 | " \n", 96 | " # retrieve with O(log n)\n", 97 | " def _retrieve(self , sample_idex , node_idex = 0):\n", 98 | " left_child_idex = node_idex * 2 + 1\n", 99 | " right_child_idex = left_child_idex + 1\n", 100 | " \n", 101 | " if left_child_idex >= len(self.tree): # ! must be >= \n", 102 | " return node_idex\n", 103 | " \n", 104 | " if self.tree[left_child_idex] == self.tree[right_child_idex]: \n", 105 | " return self._retrieve(sample_idex , np.random.choice([left_child_idex , right_child_idex]))\n", 106 | " if self.tree[left_child_idex] > sample_idex:\n", 107 | " return self._retrieve(sample_idex , node_idex = left_child_idex)\n", 108 | " else:\n", 109 | " return self._retrieve(sample_idex - self.tree[left_child_idex] , node_idex = right_child_idex )\n", 110 | " \n", 111 | " # sum of p in root node\n", 112 | " def root_priority(self):\n", 113 | " return self.tree[0]" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 4, 119 | "metadata": { 120 | "collapsed": true 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "class Memory():\n", 125 | " def __init__(self , tree_epsilon = 0.01):\n", 126 | " self.epsilon = tree_epsilon\n", 127 | " self.p_init = 1. 
\n", 128 | " self.beta = BETA_INIT\n", 129 | " self.beta_change_step = 0.001\n", 130 | " self.capcity = CAPCITY\n", 131 | " \n", 132 | " self.sum_tree = SumTree() \n", 133 | " \n", 134 | " # store transition & priority before replay\n", 135 | " def store(self , transition):\n", 136 | " p_max = np.max(self.sum_tree.tree[-self.capcity:])\n", 137 | " if p_max == 0:\n", 138 | " p_max = self.p_init\n", 139 | " self.sum_tree.add_leaf_node(p_max , transition)\n", 140 | " \n", 141 | " # update SumTree\n", 142 | " def update(self , leaf_idex , td_error ):\n", 143 | " p = np.abs(td_error) + self.epsilon\n", 144 | " p_alpha = np.power(p , ALPHA)\n", 145 | " \n", 146 | " for i in range(len(leaf_idex)):\n", 147 | " self.sum_tree.update_leaf_node(leaf_idex[i] , p_alpha[i] )\n", 148 | " \n", 149 | " # sample\n", 150 | " def sampling(self , batch_size ):\n", 151 | " batch_idex = []\n", 152 | " batch_transition = []\n", 153 | " batch_ISweight = []\n", 154 | " \n", 155 | " segment = self.sum_tree.root_priority() / batch_size\n", 156 | " \n", 157 | " for i in range(batch_size):\n", 158 | " low = segment * i\n", 159 | " high = segment * (i + 1)\n", 160 | " sample_idex = np.random.uniform(low , high)\n", 161 | " idex , p_alpha , transition = self.sum_tree.sampling(sample_idex)\n", 162 | " prob = p_alpha / self.sum_tree.root_priority()\n", 163 | " batch_ISweight.append( np.power(self.capcity * prob , -self.beta) )\n", 164 | " batch_idex.append( idex )\n", 165 | " batch_transition.append( transition)\n", 166 | "\n", 167 | " i_maxiwi = np.power(self.capcity * np.min(self.sum_tree.tree[-self.capcity:]) / self.sum_tree.root_priority() , self.beta)\n", 168 | "\n", 169 | " batch_ISweight = np.array(batch_ISweight) * i_maxiwi \n", 170 | " \n", 171 | " return batch_idex , batch_transition , batch_ISweight\n", 172 | " \n", 173 | " # change beta\n", 174 | " def change_beta(self):\n", 175 | " self.beta -= self.beta_change_step\n", 176 | " return np.min(1 , self.beta)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 5, 182 | "metadata": { 183 | "collapsed": true 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "class AgentModel():\n", 188 | " def __init__(self , env , sess):\n", 189 | " self.state_dim = env.observation_space.shape[0]\n", 190 | " self.action_dim = env.action_space.n\n", 191 | " self.step_size = 0.01\n", 192 | " self.gamma = 0.5\n", 193 | " self.sess = sess\n", 194 | " self.epsilon = 0.8\n", 195 | " self.build_net()\n", 196 | " self.sess.run(tf.global_variables_initializer())\n", 197 | " \n", 198 | " # net_frame\n", 199 | " def net_frame(self , clt_name , inputs ):\n", 200 | " weight_init = tf.random_normal_initializer()\n", 201 | " bias_init = tf.constant_initializer(0.1)\n", 202 | " \n", 203 | " layer1_units = 64\n", 204 | " layer2_units = 32\n", 205 | " \n", 206 | " with tf.variable_scope(\"layer1\"):\n", 207 | " weights1 = tf.get_variable(\"weight\", initializer = weight_init , collections = clt_name , \n", 208 | " shape = [self.state_dim , layer1_units ])\n", 209 | " bias1 = tf.get_variable(\"bias\" , initializer = bias_init , collections = clt_name , shape = [layer1_units] )\n", 210 | " wx_b1 = tf.matmul(inputs , weights1) + bias1\n", 211 | " h1 = tf.nn.relu(wx_b1)\n", 212 | " \n", 213 | " with tf.variable_scope(\"layer2\"):\n", 214 | " weights2 = tf.get_variable(\"weight\" , initializer = weight_init , collections = clt_name , \n", 215 | " shape = [layer1_units , layer2_units])\n", 216 | " bias2 = tf.get_variable(\"bias\" , initializer = weight_init , collections = 
clt_name , \n", 217 | " shape = [layer2_units])\n", 218 | " wx_b2 = tf.matmul(h1 , weights2) + bias2\n", 219 | " h2 = tf.nn.relu(wx_b2)\n", 220 | " \n", 221 | " with tf.variable_scope(\"layer3\"):\n", 222 | " weights3 = tf.get_variable(\"weight\" , initializer = weight_init , collections = clt_name , \n", 223 | " shape = [layer2_units , self.action_dim])\n", 224 | " bias3 = tf.get_variable(\"bias\" , initializer = weight_init , collections = clt_name , \n", 225 | " shape = [self.action_dim])\n", 226 | " q_out = tf.matmul(h2 , weights3) + bias3\n", 227 | " \n", 228 | " return q_out \n", 229 | " \n", 230 | " # build net\n", 231 | " def build_net(self):\n", 232 | " with tf.variable_scope(\"q_net\"):\n", 233 | " clt_name_q = [\"q_net_prmts\" , tf.GraphKeys.GLOBAL_VARIABLES]\n", 234 | " self.inputs_q = tf.placeholder( dtype = tf.float32 , shape = [None , self.state_dim] , name = \"q_inputs\" )\n", 235 | " self.q_value = self.net_frame(clt_name_q , self.inputs_q)\n", 236 | " \n", 237 | " with tf.variable_scope(\"target_net\"):\n", 238 | " clt_name_target = [\"target_net_prmts\" , tf.GraphKeys.GLOBAL_VARIABLES]\n", 239 | " self.inputs_target = tf.placeholder( dtype = tf.float32 , shape = [None , self.state_dim] , name = \"q_target\")\n", 240 | " self.q_target = self.net_frame(clt_name_target , self.inputs_target)\n", 241 | " \n", 242 | " with tf.variable_scope(\"loss\"):\n", 243 | " self.target = tf.placeholder( dtype = tf.float32 , shape=[ None , self.action_dim ] , name=\"target\" )\n", 244 | " self.ISweight = tf.placeholder( dtype = tf.float32 , shape = [ None , self.action_dim] , name = \"wi\")\n", 245 | " \n", 246 | " self.td_error = self.target - self.q_value\n", 247 | " self.loss = tf.reduce_sum(self.ISweight * tf.squared_difference(self.target , self.q_value )) # ***\n", 248 | " \n", 249 | " with tf.variable_scope(\"train\"):\n", 250 | " self.train_op = tf.train.RMSPropOptimizer(self.step_size).minimize(self.loss)\n", 251 | " \n", 252 | " # training\n", 253 | " def training(self , state , next_state , reward , action , batch_ISweight ):\n", 254 | " q_target , q_value = self.sess.run([self.q_target , self.q_value] , \n", 255 | " feed_dict = {self.inputs_q : state , self.inputs_target : next_state }) \n", 256 | " \n", 257 | " target = reward + self.gamma * np.max(q_target)\n", 258 | " reform_target = q_value.copy()\n", 259 | " batch_idex = np.arange(BATCH_SIZE)\n", 260 | " reform_target[batch_idex , action] = target\n", 261 | " \n", 262 | " batch_ISweight = np.stack([batch_ISweight , batch_ISweight] , axis = -1 )\n", 263 | " \n", 264 | " _ , td_error , loss = self.sess.run([self.train_op , self.td_error , self.loss] ,\n", 265 | " feed_dict = {self.ISweight : batch_ISweight , \n", 266 | " self.inputs_q : state ,\n", 267 | " self.target : reform_target })\n", 268 | " return td_error\n", 269 | " \n", 270 | " # update target_net prmt\n", 271 | " def update_target_prmts(self):\n", 272 | " q_value_prmts = tf.get_collection(\"q_net_prmts\")\n", 273 | " q_target_prmts = tf.get_collection(\"target_net_prmts\")\n", 274 | " self.sess.run( [tf.assign(v , t) for v , t in zip(q_value_prmts , q_target_prmts) ])\n", 275 | " print(\"updating target network prmts...\")\n", 276 | " \n", 277 | " # chose action\n", 278 | " def chose_action(self , current_state):\n", 279 | " state = current_state[np.newaxis , :]\n", 280 | " q_value = self.sess.run(self.q_value , feed_dict={self.inputs_q : state})\n", 281 | " if np.random.random() > self.epsilon:\n", 282 | " action = np.argmax(q_value)\n", 283 | " else:\n", 284 | 
" action = np.random.randint(0 , self.action_dim)\n", 285 | " return action\n", 286 | " \n", 287 | " def greedy_action(self , current_state):\n", 288 | " current_state = current_state[np.newaxis , :] \n", 289 | " q = self.sess.run(self.q_value , feed_dict={self.inputs_q : current_state} ) \n", 290 | " action_greedy = np.argmax(q)\n", 291 | " return action_greedy\n", 292 | " " 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 6, 298 | "metadata": { 299 | "collapsed": true 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "def train( env , agent , memory ):\n", 304 | " Transition = collections.namedtuple(\"Transition\" , [\"state\" , \"action\" , \"reward\" , \"next_state\" , \"done\"])\n", 305 | " replay_iter = 0\n", 306 | " update_target_iter = 0\n", 307 | " for episode in range(EPISODE):\n", 308 | " state = env.reset()\n", 309 | " for step in range(MAX_STEP):\n", 310 | " action = agent.chose_action(state)\n", 311 | " next_state , reward , done , _ = env.step(action)\n", 312 | " replay_iter += 1\n", 313 | " \n", 314 | " memory.store(Transition(state , action , reward , next_state , done))\n", 315 | " \n", 316 | " if replay_iter > CAPCITY and replay_iter % REPLAY_PERIOD == 0:\n", 317 | " batch_idex , batch_transition , batch_ISweight = memory.sampling(BATCH_SIZE)\n", 318 | " batch_state , batch_action , batch_reward , batch_next_state , batch_done = map(np.array , zip(*batch_transition))\n", 319 | " td_error = agent.training(batch_state , batch_next_state , batch_reward , batch_action , batch_ISweight)\n", 320 | "\n", 321 | " batch_iter = np.arange(BATCH_SIZE)\n", 322 | " memory.update(batch_idex , td_error[batch_iter , batch_action])\n", 323 | "\n", 324 | " if replay_iter > CAPCITY and replay_iter % UPDAT_PRMT_PERIOD == 0:\n", 325 | " agent.update_target_prmts()\n", 326 | " \n", 327 | " if done:\n", 328 | " if episode % 50 == 0:\n", 329 | " print(\"episode: %d , step: %d\" %( episode , step ))\n", 330 | " break\n", 331 | " state = next_state\n", 332 | " \n", 333 | " \n", 334 | " " 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "metadata": { 341 | "collapsed": true 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "if __name__ == \"__main__\":\n", 346 | " env = gym.make(ENV)\n", 347 | " memory = Memory()\n", 348 | " sess = tf.Session()\n", 349 | " agent = AgentModel(env , sess)\n", 350 | " train( env , agent , memory )" 351 | ] 352 | } 353 | ], 354 | "metadata": { 355 | "kernelspec": { 356 | "display_name": "Python 3", 357 | "language": "python", 358 | "name": "python3" 359 | }, 360 | "language_info": { 361 | "codemirror_mode": { 362 | "name": "ipython", 363 | "version": 3 364 | }, 365 | "file_extension": ".py", 366 | "mimetype": "text/x-python", 367 | "name": "python", 368 | "nbconvert_exporter": "python", 369 | "pygments_lexer": "ipython3", 370 | "version": "3.5.2" 371 | } 372 | }, 373 | "nbformat": 4, 374 | "nbformat_minor": 2 375 | } 376 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReinforcementLearningCode 2 | Codes for understanding Reinforcement Learning(updating... ) 3 | 4 | “drl-gym” floder gives a updating model to learn both atari games such as breakout & pong, and toy games such as CartPole. 5 | It can be run though there still is some works wanted to be done. 6 | 7 | The others are just demos for toy games, they are Simple Implementation to understand the algorithms. 
8 | -------------------------------------------------------------------------------- /drl-gym/DQNModel.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import gym 7 | import random 8 | import tensorflow.contrib.layers as layers 9 | import os 10 | import time 11 | from netFrame import net_frame_mlp, net_frame_cnn_to_mlp, build_net 12 | from memory import Memory 13 | from utils import * 14 | from argument import args 15 | 16 | class DeepQNetwork4Atari(): 17 | def __init__(self , scope_main, env, flag_double, flag_cnn, sess=None , gamma = 0.99): 18 | self.gamma = gamma 19 | self.epsilon = 1.0 20 | self.action_dim = env.action_space.n 21 | self.state_shape = [None] + list(env.observation_space.shape) # levin soft code: [None,84,84,4] or [None, xx..] 22 | self.scope_main = scope_main 23 | self.flag_double = flag_double 24 | self.flag_cnn = flag_cnn 25 | self.network() 26 | self.sess = sess 27 | 28 | # self.merged = tf.summary.merge_all() 29 | # self.write2 = tf.summary.FileWriter("HVDQN/test1/2" , sess.graph ) 30 | 31 | self.sess.run(tf.global_variables_initializer()) 32 | # self.merged = tf.merge_all_summaries() 33 | # self.result_tensorboar 34 | 35 | # create q_network & target_network 36 | def network(self): 37 | self.inputs_q = tf.placeholder(dtype = tf.float32 , shape = self.state_shape, name = "inputs_q") 38 | scope_var = "q_network" 39 | 40 | self.inputs_target = tf.placeholder(dtype = tf.float32 , shape = self.state_shape , name = "inputs_target") 41 | scope_tar = "target_network" 42 | 43 | # q_network 44 | if self.flag_cnn: 45 | self.q_value = net_frame_cnn_to_mlp(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 46 | hiddens=[512], 47 | inpt=self.inputs_q, 48 | num_actions=self.action_dim, 49 | scope=scope_var, 50 | dueling=0,) 51 | else: 52 | self.q_value = net_frame_mlp([32,16] , self.inputs_q , self.action_dim , scope=scope_var ) 53 | 54 | # target_network 55 | if self.flag_cnn: 56 | self.q_target = net_frame_cnn_to_mlp(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 57 | hiddens=[512], 58 | inpt=self.inputs_target, 59 | num_actions=self.action_dim, 60 | scope=scope_tar, 61 | dueling=0,) 62 | else: 63 | self.q_target = net_frame_mlp([32,16] , self.inputs_target , self.action_dim , scope_tar ) 64 | 65 | # # === test === 66 | # self.q_value = build_net(self.inputs_q, scope_var) 67 | # self.q_target = build_net(self.inputs_target, scope_tar) 68 | 69 | with tf.variable_scope("loss"): 70 | self.action = tf.placeholder(dtype = tf.int32 , shape = [None] , name = "action") 71 | self.action_one_hot = tf.one_hot(self.action , self.action_dim ) 72 | q_action = tf.reduce_sum( tf.multiply(self.q_value , self.action_one_hot) , axis = 1 ) 73 | 74 | self.target = tf.placeholder(dtype = tf.float32 , shape = [None] , name = "target") 75 | 76 | # # ----- L2Loss or huberLoss ------- 77 | # self.loss = tf.reduce_mean( tf.square(q_action - self.target)) 78 | self.loss = tf.reduce_mean( huber_loss(q_action - self.target)) 79 | 80 | with tf.variable_scope("train"): 81 | # optimizer = tf.train.RMSPropOptimizer(args.learning_rate_Q, decay=0.99, momentum=0.0, epsilon=1e-6) # 0.001 0.0005(better) 0.0002(net-best) 82 | # # optimizer = tf.train.AdamOptimizer(args.learning_rate_Q,) 83 | # if args.flag_CLIP_NORM : 84 | # gradients = optimizer.compute_gradients(self.loss) 85 | # for i , (g, v) in enumerate(gradients): 86 | # if g is not None: 87 | # gradients[i] = 
(tf.clip_by_norm(g , 1) , v) 88 | # self.train_op = optimizer.apply_gradients(gradients) 89 | # else: 90 | # self.train_op = optimizer.minimize(self.loss) 91 | self.train_op = build_rmsprop_optimizer(args.learning_rate_Q, 0.99, 1e-6, 1, 'rmsprop', loss=self.loss) 92 | 93 | # training 94 | def train(self, state, reward, action, state_next, done, episode_return, estim_Qvalue_argmax, estim_Qvalue_expect): 95 | q, q_target = self.sess.run([self.q_value, self.q_target], 96 | feed_dict={self.inputs_q: state, self.inputs_target: state_next}) 97 | # dqn 98 | if not self.flag_double: 99 | q_target_best = np.max(q_target, axis = 1) 100 | # doubel dqn 101 | else: 102 | q_next = self.sess.run(self.q_value , feed_dict={self.inputs_q : state_next}) 103 | action_best = np.argmax(q_next , axis = 1) 104 | action_best_one_hot = self.sess.run(self.action_one_hot, feed_dict={self.action: action_best}) 105 | q_target_best = np.sum(q_target * action_best_one_hot, axis=1) 106 | 107 | q_target_best_mask = ( 1.0 - done) * q_target_best 108 | target = reward + self.gamma * q_target_best_mask 109 | 110 | loss, _ = self.sess.run([self.loss, self.train_op] , 111 | feed_dict={self.inputs_q: state , self.target: target , self.action: action} ) 112 | # if update_iter % SUMMARY_PERIOD == 0: 113 | # result = self.sess.run(self.merged, 114 | # feed_dict={self.inputs_q: state , self.target: target , self.action: action, self.inputs_target: state} ) 115 | # self.write.add_summary(result, update_iter) 116 | 117 | return estim_Qvalue_argmax, estim_Qvalue_expect 118 | 119 | # chose action 120 | def chose_action(self , current_state): 121 | 122 | # e-greedy 123 | if np.random.random() < self.epsilon: 124 | # action_chosen = np.random.randint(0 , self.action_dim) 125 | action_chosen = random.randrange(self.action_dim) 126 | 127 | else: 128 | current_state = current_state[np.newaxis , :] # *** array dim: (xx,) --> (1 , xx) *** 129 | q = self.sess.run(self.q_value , feed_dict={self.inputs_q : current_state} ) 130 | action_chosen = np.argmax(q) 131 | return action_chosen 132 | 133 | def greedy_action(self , current_state): 134 | current_state = current_state[np.newaxis , :] 135 | # print(current_state) 136 | q = self.sess.run(self.q_value , feed_dict={self.inputs_q : current_state} ) 137 | action_greedy = np.argmax(q) 138 | return action_greedy 139 | 140 | #upadate parmerters 141 | def update_prmt(self): 142 | q_prmts = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, self.scope_main + "/q_network" ) 143 | target_prmts = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, self.scope_main + "/target_network" ) 144 | self.sess.run( [tf.assign(t , q)for t,q in zip(target_prmts , q_prmts)]) #*** 145 | print("===updating target-network parmeters===") 146 | 147 | def decay_epsilon(self, episode, SUM_EP): 148 | episode = episode * 1.0 149 | faster_factor = 1.0 # 1.2 ... 2... 
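        # Note on the formula below: run.py calls decay_epsilon(update_iter - observe_step,
        # explore_step) every 100 training steps, so `episode` is really a step count here.
        # With the defaults (explore_step = 500000, faster_factor = 1.0), epsilon is annealed
        # linearly from 1.0 -- roughly 0.5 after 250k training steps -- and stops changing
        # once it has dropped to about 0.1; a faster_factor above 1 simply reaches that
        # floor sooner.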
150 | if self.epsilon > 0.1: 151 | self.epsilon = 1 - episode * faster_factor / SUM_EP 152 | # print("epsilon:............................",self.epsilon) 153 | 154 | -------------------------------------------------------------------------------- /drl-gym/argument.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import argparse 4 | import os 5 | import time 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--env_name", type=str, default="BreakoutNoFrameskip-v4") # "Acrobot-v1" "BreakoutNoFrameskip-v4" CartPole-v0 11 | parser.add_argument("--epoches", type=int, default=200, help="If flag_done_break is True, it is episodes; Else, it like periods") # 100000 12 | parser.add_argument("--max_step", type=int, default=5000, help="If flag_done_break, it usually has no means; Else, it work with epoches") # 1000 13 | 14 | parser.add_argument("--flag_done_break", type=bool, default=False, help="if True, when done,restart a new episode; Else, keep in this epoch") 15 | parser.add_argument("--update_period", type=int, default=2500, help="period to update targetNet. If flag_done_break, it is episodes; else, steps") # 2500 16 | parser.add_argument("--reveal_period", type=int, default=5000, help="period to reveal the results, it only work when not flag_done_break") # 2500 17 | parser.add_argument("--test_period", type=int, default=10000, help="period to update targetNet. If flag_done_break, it is episodes; else, steps") 18 | 19 | parser.add_argument("--memory_size", type=int, default=200000, help="If atari, suggesting lagrge; else, maybe 20000. Depend on memory of your computer") 20 | parser.add_argument("--batch_size", type=int, default=32) 21 | parser.add_argument("--learning_rate_Q", type=float, default=0.00025) #0.001 22 | 23 | parser.add_argument("--seed", type=int, default=11037) 24 | parser.add_argument("--env_seed", type=int, default=1234) 25 | 26 | parser.add_argument("--observe_step", type=int, default=50000) # 50000 150 27 | parser.add_argument("--explore_step", type=int, default=500000) # 400000 1000000 10000 28 | parser.add_argument("--train_step", type=int, default=3000) # 29 | 30 | parser.add_argument("--frame_skip", type=int, default=4) 31 | 32 | parser.add_argument("--flag_double_dqn", type=bool, default=False) 33 | parser.add_argument("--flag_dueling_dqn", type=bool, default=False) 34 | parser.add_argument("--flag_CLIP_NORM", type=bool, default=True) 35 | parser.add_argument("--flag_flag_cnn", type=bool, default=True) 36 | 37 | parser.add_argument('--file_path', type=str, default=str(os.getcwd())) 38 | parser.add_argument('--flag_save', type=bool, default=True) 39 | # file_path + file_path 40 | parser.add_argument('--config_file', type=str, default='/config'+str(int(time.time()))+'.txt') 41 | parser.add_argument('--train_file', type=str, default='/train_file'+str(int(time.time()))+'.csv') 42 | parser.add_argument('--test_file', type=str, default='/test_file'+str(int(time.time()))+'.csv') 43 | parser.add_argument('--all_file', type=str, default='/all_file'+str(int(time.time()))+'.csv') 44 | 45 | 46 | args_origin = parser.parse_args() 47 | 48 | def argsWrapper(args): 49 | experiment_name="seed1_dqn" 50 | # set config 51 | env_set = ["BreakoutNoFrameskip-v4", "PongNoFrameskip-v4"] 52 | if args.env_name in env_set: 53 | args.flag_cnn = True 54 | else: 55 | args.flag_cnn = False 56 | if not args.flag_cnn: 57 | print("make true your parms, such as memory_size, update_period") 
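        # The Atari-scale defaults are usually oversized for MLP tasks: the --memory_size
        # help text suggests around 20000 for non-Atari environments, the CartPole
        # notebooks in this repository update the target network every few hundred steps,
        # and a much smaller observe_step than 50000 is probably appropriate as well.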
58 | print(" ") 59 | print(" ") 60 | # set dir_path 61 | if args.flag_save: 62 | dir_path = str(int(time.time())) 63 | path = str(os.getcwd()) 64 | new_path = path + '/results/' + experiment_name + "_" + args.env_name + dir_path 65 | # new_path = path + '/' + args.env_name + dir_path 66 | os.makedirs(new_path) 67 | args.file_path = new_path 68 | args.config_file = args.file_path + args.config_file 69 | args.train_file = args.file_path + args.train_file 70 | args.test_file = args.file_path + args.test_file 71 | args.all_file = args.file_path + args.all_file 72 | return args 73 | 74 | args = argsWrapper(args_origin) 75 | 76 | print(args) -------------------------------------------------------------------------------- /drl-gym/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import os 6 | os.environ.setdefault('PATH', '') 7 | from collections import deque 8 | import gym 9 | from gym import spaces 10 | import cv2 11 | cv2.ocl.setUseOpenCL(False) 12 | 13 | class NoopResetEnv(gym.Wrapper): 14 | def __init__(self, env, noop_max=30): 15 | """Sample initial states by taking random number of no-ops on reset. 16 | No-op is assumed to be action 0. 17 | """ 18 | gym.Wrapper.__init__(self, env) 19 | self.noop_max = noop_max 20 | self.override_num_noops = None 21 | self.noop_action = 0 22 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 23 | 24 | def reset(self, **kwargs): 25 | """ Do no-op action for a number of steps in [1, noop_max].""" 26 | self.env.reset(**kwargs) 27 | if self.override_num_noops is not None: 28 | noops = self.override_num_noops 29 | else: 30 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101 31 | assert noops > 0 32 | obs = None 33 | for _ in range(noops): 34 | obs, _, done, _ = self.env.step(self.noop_action) 35 | if done: 36 | obs = self.env.reset(**kwargs) 37 | return obs 38 | 39 | def step(self, ac): 40 | return self.env.step(ac) 41 | 42 | class FireResetEnv(gym.Wrapper): 43 | def __init__(self, env): 44 | """Take action on reset for environments that are fixed until firing.""" 45 | gym.Wrapper.__init__(self, env) 46 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 47 | assert len(env.unwrapped.get_action_meanings()) >= 3 48 | 49 | def reset(self, **kwargs): 50 | self.env.reset(**kwargs) 51 | obs, _, done, _ = self.env.step(1) 52 | if done: 53 | self.env.reset(**kwargs) 54 | obs, _, done, _ = self.env.step(2) 55 | if done: 56 | self.env.reset(**kwargs) 57 | return obs 58 | 59 | def step(self, ac): 60 | return self.env.step(ac) 61 | 62 | class EpisodicLifeEnv(gym.Wrapper): 63 | def __init__(self, env): 64 | """Make end-of-life == end-of-episode, but only reset on true game over. 65 | Done by DeepMind for the DQN and co. since it helps value estimation. 66 | """ 67 | gym.Wrapper.__init__(self, env) 68 | self.lives = 0 69 | self.was_real_done = True 70 | 71 | def step(self, action): 72 | obs, reward, done, info = self.env.step(action) 73 | self.was_real_done = done 74 | # check current lives, make loss of life terminal, 75 | # then update lives to handle bonus lives 76 | lives = self.env.unwrapped.ale.lives() 77 | if lives < self.lives and lives > 0: 78 | # for Qbert sometimes we stay in lives == 0 condtion for a few frames 79 | # so its important to keep lives > 0, so that we only reset once 80 | # the environment advertises done. 
81 | done = True 82 | self.lives = lives 83 | return obs, reward, done, info 84 | 85 | def reset(self, **kwargs): 86 | """Reset only when lives are exhausted. 87 | This way all states are still reachable even though lives are episodic, 88 | and the learner need not know about any of this behind-the-scenes. 89 | """ 90 | if self.was_real_done: 91 | obs = self.env.reset(**kwargs) 92 | else: 93 | # no-op step to advance from terminal/lost life state 94 | obs, _, _, _ = self.env.step(0) 95 | self.lives = self.env.unwrapped.ale.lives() 96 | return obs 97 | 98 | class MaxAndSkipEnv(gym.Wrapper): 99 | def __init__(self, env, skip=4): 100 | """Return only every `skip`-th frame""" 101 | gym.Wrapper.__init__(self, env) 102 | # most recent raw observations (for max pooling across time steps) 103 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 104 | self._skip = skip 105 | 106 | def step(self, action): 107 | """Repeat action, sum reward, and max over last observations.""" 108 | total_reward = 0.0 109 | done = None 110 | for i in range(self._skip): 111 | obs, reward, done, info = self.env.step(action) 112 | if i == self._skip - 2: self._obs_buffer[0] = obs 113 | if i == self._skip - 1: self._obs_buffer[1] = obs 114 | total_reward += reward 115 | if done: 116 | break 117 | # Note that the observation on the done=True frame 118 | # doesn't matter 119 | max_frame = self._obs_buffer.max(axis=0) 120 | 121 | return max_frame, total_reward, done, info 122 | 123 | def reset(self, **kwargs): 124 | return self.env.reset(**kwargs) 125 | 126 | class ClipRewardEnv(gym.RewardWrapper): 127 | def __init__(self, env): 128 | gym.RewardWrapper.__init__(self, env) 129 | 130 | def reward(self, reward): 131 | """Bin reward to {+1, 0, -1} by its sign.""" 132 | return np.sign(reward) 133 | 134 | class WarpFrame(gym.ObservationWrapper): 135 | def __init__(self, env): 136 | """Warp frames to 84x84 as done in the Nature paper and later work.""" 137 | gym.ObservationWrapper.__init__(self, env) 138 | self.width = 84 139 | self.height = 84 140 | self.observation_space = spaces.Box(low=0, high=255, 141 | shape=(self.height, self.width, 1), dtype=np.uint8) 142 | 143 | def observation(self, frame): 144 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 145 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) 146 | return frame[:, :, None] 147 | 148 | class FrameStack(gym.Wrapper): 149 | def __init__(self, env, k): 150 | """Stack k last frames. 151 | 152 | Returns lazy array, which is much more memory efficient. 
153 | 154 | See Also 155 | -------- 156 | baselines.common.atari_wrappers.LazyFrames 157 | """ 158 | gym.Wrapper.__init__(self, env) 159 | self.k = k 160 | self.frames = deque([], maxlen=k) 161 | shp = env.observation_space.shape 162 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype) 163 | 164 | def reset(self): 165 | ob = self.env.reset() 166 | for _ in range(self.k): 167 | self.frames.append(ob) 168 | return self._get_ob() 169 | 170 | def step(self, action): 171 | ob, reward, done, info = self.env.step(action) 172 | self.frames.append(ob) 173 | return self._get_ob(), reward, done, info 174 | 175 | def _get_ob(self): 176 | assert len(self.frames) == self.k 177 | return LazyFrames(list(self.frames)) 178 | 179 | class ScaledFloatFrame(gym.ObservationWrapper): 180 | def __init__(self, env): 181 | gym.ObservationWrapper.__init__(self, env) 182 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) 183 | 184 | def observation(self, observation): 185 | # careful! This undoes the memory optimization, use 186 | # with smaller replay buffers only. 187 | return np.array(observation).astype(np.float32) / 255.0 188 | 189 | class LazyFrames(object): 190 | def __init__(self, frames): 191 | """This object ensures that common frames between the observations are only stored once. 192 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 193 | buffers. 194 | 195 | This object should only be converted to numpy array before being passed to the model. 196 | 197 | You'd not believe how complex the previous solution was.""" 198 | self._frames = frames 199 | self._out = None 200 | 201 | def _force(self): 202 | if self._out is None: 203 | self._out = np.concatenate(self._frames, axis=2) 204 | self._frames = None 205 | return self._out 206 | 207 | def __array__(self, dtype=None): 208 | out = self._force() 209 | if dtype is not None: 210 | out = out.astype(dtype) 211 | return out 212 | 213 | def __len__(self): 214 | return len(self._force()) 215 | 216 | def __getitem__(self, i): 217 | return self._force()[i] 218 | 219 | def make_atari(env_id): 220 | env = gym.make(env_id) 221 | assert 'NoFrameskip' in env.spec.id 222 | env = NoopResetEnv(env, noop_max=30) 223 | env = MaxAndSkipEnv(env, skip=4) 224 | return env 225 | 226 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False): 227 | """Configure environment for DeepMind-style Atari. 
228 | """ 229 | if episode_life: 230 | env = EpisodicLifeEnv(env) 231 | if 'FIRE' in env.unwrapped.get_action_meanings(): 232 | env = FireResetEnv(env) 233 | env = WarpFrame(env) 234 | if scale: 235 | env = ScaledFloatFrame(env) 236 | if clip_rewards: 237 | env = ClipRewardEnv(env) 238 | if frame_stack: 239 | env = FrameStack(env, 4) 240 | return env 241 | 242 | -------------------------------------------------------------------------------- /drl-gym/envWrapper.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import gym 5 | from atari_wrappers import * 6 | from argument import args 7 | 8 | def envMakeWrapper(env_name, flag_cnn=False): 9 | if not flag_cnn: 10 | print("not wrapper") 11 | env = gym.make(env_name) 12 | env.seed(args.env_seed) 13 | return 14 | 15 | else: 16 | print("wrapper") 17 | env = make_atari(env_name) 18 | env = wrap_deepmind(env, frame_stack=True, scale=True) 19 | env.seed(args.env_seed) 20 | return env 21 | 22 | -------------------------------------------------------------------------------- /drl-gym/logger.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import os 6 | 7 | def write(file, value): 8 | file = open(file_path, 'a') 9 | if isinstance(value, int): 10 | file.write() 11 | ''' 12 | 作用在于: 13 | store_resutl(): 进行数据储存; 14 | reveal_last(): 定期输出该期间的平均值; 15 | write_last(): 定期将平均值结果写入文件,防止数据丢失; 16 | write_final(): 将每次保存的数据进行写入,不止是平均值。 17 | ''' 18 | class Logger: 19 | def __init__(self): 20 | self.result_dict = dict() 21 | self.scale_dict = dict() 22 | 23 | self.result_iter = dict() # used in reveal_last 24 | self.result_last_iter = dict() # used in reveal_last 25 | self.flag_title = 1 26 | self.final_flag_title = 1 27 | 28 | def store_num(self, **kwargs): 29 | for k,v in kwargs.items(): 30 | if not(k in self.scale_dict.keys()): 31 | self.scale_dict[k] = [] 32 | self.result_iter[k] = 0 33 | if not isinstance(v, np.ndarray): 34 | v = np.array(v) 35 | self.scale_dict[k].append(v) 36 | self.result_iter[k] += 1 37 | 38 | def store_result(self, **kwargs): # warning: kwargs must be a dict or using "=" 39 | for k,v in kwargs.items(): 40 | if not(k in self.result_dict.keys()): 41 | self.result_dict[k] = [] 42 | self.result_iter[k] = 0 43 | self.result_last_iter[k] = 0 44 | if not isinstance(v, np.ndarray): 45 | v = np.array(v) 46 | self.result_dict[k].append(v) 47 | self.result_iter[k] += 1 48 | 49 | # only reveal, without write file 50 | def reveal_last(self, *args): 51 | # auto reveal last k mean results, where k = reveal_period in main_loop, s.t. k = 1 ... 
52 | if len(args) > 0: 53 | for key in args: 54 | if key in self.scale_dict.keys(): 55 | print(str(key) , ":" , self.scale_dict[key][self.result_iter[key]-1], end=" , ") 56 | elif key in self.result_dict.keys(): 57 | value_last = np.mean(self.result_dict[key][self.result_last_iter[key]:self.result_iter[key]], axis=0) 58 | print(str(key) , ":" , value_last, end=" , ") 59 | else: 60 | raise KeyError(key) 61 | self.result_last_iter[key] = self.result_iter[key] 62 | print("\n") 63 | # reveal all results (lat k mean results) 64 | else: 65 | for key in self.scale_dict.keys(): 66 | print(str(key) , ":" , self.scale_dict[key][self.result_iter[key]-1], end=" , ") 67 | for key in self.result_dict.keys(): 68 | value_last = np.mean(self.result_dict[key][self.result_last_iter[key]:self.result_iter[key]], axis=0) 69 | print(str(key) , ":" , value_last, end=" , ") 70 | self.result_last_iter[key] = self.result_iter[key] 71 | print(" ") 72 | 73 | # # test... dont use 74 | # def reveal_last_value(self, *args): 75 | # # auto reveal last k mean results, where k = reveal_period in main_loop, s.t. k = 1 ... 76 | # if len(args) > 0: 77 | # # reveal 78 | # for key in args: 79 | # assert key in self.result_dict.keys() 80 | # value_last = np.mean(self.result_dict[key][self.result_last_iter[key]:self.result_iter[key]], axis=0) 81 | # print(str(key) , ":" , value_last, end=" , ") 82 | # self.result_last_iter[key] = self.result_iter[key] 83 | # print(" ") 84 | # # reveal all results (lat k mean results) 85 | # else: 86 | # for key in self.result_dict.keys(): 87 | # value_last = np.mean(self.result_dict[key][self.result_last_iter[key]:self.result_iter[key]], axis=0) 88 | # print(str(key) , ":" , value_last, end=" , ") 89 | # self.result_last_iter[key] = self.result_iter[key] 90 | # print(" ") 91 | 92 | # only write, without reveal 93 | def write_last(self, save_path=os.getcwd(), save_name='result.csv', write_period=1): 94 | self.fl=open(save_path + '/' + save_name, 'a') 95 | # TODO-1: a judge --> diff save_name for diff files 96 | # TODO-2: only write all, next adding *args 97 | 98 | # write title 99 | if self.flag_title: 100 | self._write_title(self.fl, self.scale_dict) 101 | self._write_title(self.fl, self.result_dict) 102 | self.fl.write("\n") 103 | self.flag_title = 0 104 | # write value 105 | for key in self.scale_dict.keys(): 106 | space_num = self.scale_dict[key][0].size 107 | if space_num == 1: 108 | self.fl.write(str(self.scale_dict[key][self.result_iter[key]-1]) + ",") 109 | else: 110 | for j in range(space_num): 111 | self.fl.write(str(self.scale_dict[key][self.result_iter[key]-1][j]) + ",") 112 | for key in self.result_dict.keys(): 113 | space_num = self.result_dict[key][0].size 114 | if space_num == 1: 115 | self.fl.write(str(np.mean(self.result_dict[key][-write_period:], axis=0)) + ",") 116 | else: 117 | for j in range(space_num): 118 | # print(np.mean(self.result_dict[key][-write_period:], axis=0)) 119 | # assert True 120 | self.fl.write(str(np.mean(self.result_dict[key][-write_period:], axis=0)[j]) + ",") 121 | # print(str(self.result_dict[key][iter_key][j])) 122 | # print(str(np.mean(self.result_dict[key], axis=0)[j])) 123 | self.fl.write("\n") 124 | # self.fl.flush() 125 | self.fl.close() 126 | 127 | def _write_title(self, file, key_list): 128 | for key in key_list.keys(): 129 | # print(key_list[key][0]) 130 | # print(key_list[key]) 131 | # print(key) 132 | space_num = key_list[key][0].size 133 | # print(space_num) 134 | file.write(str(key) + ",") 135 | for j in range(space_num - 1): 136 | 
file.write(" " + ",") 137 | 138 | 139 | def reveal_all(self, *args): 140 | # reveal some results based on args 141 | if len(args) > 0: 142 | for key in args: 143 | assert key in self.result_dict.keys() 144 | print(str(key) , ": " , self.result_dict[key]) 145 | # reveal all results 146 | else: 147 | for key in self.result_dict.keys(): 148 | print(str(key) , ": " , self.result_dict[key]) 149 | 150 | def write_final(self, save_path=os.getcwd(), save_name='result_all.csv'): 151 | fl=open(save_path + '/' + save_name, 'w') 152 | for key in self.result_dict.keys(): 153 | space_num = self.result_dict[key][0].size 154 | value_num = len(self.result_dict[key]) 155 | # print(space_num) 156 | fl.write(str(key) + ",") 157 | for j in range(space_num - 1): 158 | fl.write(" " + ",") 159 | fl.write("\n") 160 | # write value 161 | for iter_key in range(value_num): 162 | for key in self.result_dict.keys(): 163 | space_num = self.result_dict[key][0].size 164 | if space_num == 1: 165 | fl.write(str(self.result_dict[key][iter_key]) + ",") 166 | else: 167 | for j in range(space_num): 168 | fl.write(str(self.result_dict[key][iter_key][j]) + ",") # TODO-error 169 | # print(str(self.result_dict[key][iter_key][j])) 170 | fl.write("\n") 171 | fl.close() 172 | 173 | 174 | # # 使用测试: 175 | # Rcd = Logger() 176 | 177 | # for i in range(5): 178 | # a = np.random.randn(1) 179 | # b = np.random.randn(2) 180 | # c = np.random.randn(1) 181 | # Rcd.store_result(resA=i) 182 | # Rcd.store_result(resB=b) 183 | # Rcd.store_result(resC=c) 184 | # if (i+1) % 2 == 0: 185 | # # Rcd.write_last(write_period=2) 186 | # print(a) 187 | # print("last") 188 | # Rcd.reveal_last("resA") 189 | # Rcd.write_final() 190 | 191 | # a = np.array(np.random.randn(1,2)) 192 | # print(type(a)) 193 | # print(a) 194 | # print(a.shape) 195 | # a = [] 196 | # for j in range(10): 197 | # a.append(j) 198 | # if (j+1) % 2 == 0: 199 | # print("a",a) 200 | # print("seg:",a[-2:]) 201 | 202 | 203 | # print(a) 204 | # # print(a[1:3]) 205 | # # print(np.mean(a[1:3])) 206 | # k = 3 207 | # print(a[-k:]) -------------------------------------------------------------------------------- /drl-gym/memory.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | import collections 6 | import numpy as np 7 | import random 8 | 9 | # memory for momery replay 10 | Transition = collections.namedtuple("Transition" , ["state", "action", "reward", "next_state", "done", "episode_return"]) 11 | class Memory: 12 | def __init__(self, size, flag_piexl=0): 13 | self.memory = [] 14 | self.capacity = size 15 | self.flag_piexl = flag_piexl 16 | 17 | def store(self, state, action, reward, next_state, done, ep_return=0): 18 | if len(self.memory) > self.capacity: 19 | self.memory.pop(0) 20 | if self.flag_piexl: 21 | assert np.amin(state) >= 0.0 22 | assert np.amax(state) <= 1.0 23 | 24 | # Class LazyFrame --> np.array() 25 | state = np.array(state) 26 | next_state = np.array(next_state) 27 | 28 | state = (state * 255).round().astype(np.uint8) 29 | next_state = (next_state * 255).round().astype(np.uint8) 30 | 31 | self.memory.append(Transition(state, action , reward , next_state , float(done), ep_return)) 32 | 33 | def batchSample(self, batch_size): 34 | batch_transition = random.sample(self.memory, batch_size) 35 | state, action, reward, next_state, done, ep_return = map(np.array , zip(*batch_transition)) 36 | if self.flag_piexl: 37 | state = state.astype(np.float32) / 255.0 38 | 
next_state = next_state.astype(np.float32) / 255.0 39 | return state, action, reward, next_state, done, ep_return 40 | 41 | def size(self): 42 | return len(self.memory) 43 | -------------------------------------------------------------------------------- /drl-gym/netFrame.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | import tensorflow.contrib.layers as layers 6 | from argument import args 7 | 8 | # net_frame using for creating Q & target network 9 | def net_frame_mlp(hiddens, inpt, num_actions, scope, reuse=False, layer_norm=False): 10 | with tf.variable_scope(scope, reuse=reuse): 11 | out = inpt 12 | for hidden in hiddens: 13 | out = layers.fully_connected(out, num_outputs=hidden, activation_fn=None) 14 | if layer_norm: 15 | out = layers.layer_norm(out, center=True, scale=True) 16 | out = tf.nn.relu(out) 17 | q_out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None) 18 | return q_out 19 | 20 | def net_frame_cnn_to_mlp(convs, hiddens, inpt, num_actions, scope, dueling=False, reuse=False, layer_norm=False): 21 | with tf.variable_scope(scope, reuse=reuse): 22 | out = inpt 23 | with tf.variable_scope("convnet"): 24 | for num_outputs, kernel_size, stride in convs: 25 | out = layers.convolution2d(out, 26 | num_outputs=num_outputs, 27 | kernel_size=kernel_size, 28 | stride=stride, 29 | activation_fn=tf.nn.relu) 30 | conv_out = layers.flatten(out) 31 | with tf.variable_scope("action_value"): 32 | action_out = conv_out 33 | for hidden in hiddens: 34 | action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None) 35 | if layer_norm: 36 | action_out = layers.layer_norm(action_out, center=True, scale=True) 37 | action_out = tf.nn.relu(action_out) 38 | action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None) 39 | 40 | if dueling: 41 | with tf.variable_scope("state_value"): 42 | state_out = conv_out 43 | for hidden in hiddens: 44 | state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None) 45 | if layer_norm: 46 | state_out = layers.layer_norm(state_out, center=True, scale=True) 47 | state_out = tf.nn.relu(state_out) 48 | state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None) 49 | action_scores_mean = tf.reduce_mean(action_scores, 1) 50 | action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1) 51 | q_out = state_score + action_scores_centered 52 | else: 53 | q_out = action_scores 54 | return q_out 55 | 56 | 57 | # ref: break-out-master 58 | def build_net(s, var_scope, dueling=0): 59 | with tf.variable_scope(var_scope): 60 | with tf.variable_scope('conv1'): 61 | W1 = init_W(shape=[8, 8, 4, 32]) 62 | b1 = init_b(shape=[32]) 63 | conv1 = conv2d(s, W1, strides=4) 64 | h_conv1 = tf.nn.relu(tf.nn.bias_add(conv1, b1)) 65 | 66 | # with tf.name_scope('max_pool1'): 67 | # h_pool1 = max_pool(h_conv1) 68 | 69 | with tf.variable_scope('conv2'): 70 | W2 = init_W(shape=[4, 4, 32, 64]) 71 | b2 = init_b(shape=[64]) 72 | conv2 = conv2d(h_conv1, W2, strides=2) 73 | h_conv2 = tf.nn.relu(tf.nn.bias_add(conv2, b2)) 74 | 75 | with tf.variable_scope('conv3'): 76 | W3 = init_W(shape=[3, 3, 64, 64]) 77 | b3 = init_b(shape=[64]) 78 | conv3 = conv2d(h_conv2, W3, strides=1) 79 | h_conv3 = tf.nn.relu(tf.nn.bias_add(conv3, b3)) 80 | 81 | h_flatten = tf.reshape(h_conv3, [-1, 3136]) 82 | 83 | with tf.variable_scope('fc1'): 84 | W_fc1 = 
init_W(shape=[3136, 512]) 85 | b_fc1 = init_b(shape=[512]) 86 | fc1 = tf.nn.bias_add(tf.matmul(h_flatten, W_fc1), b_fc1) 87 | 88 | if not dueling: 89 | with tf.variable_scope('fc2'): 90 | h_fc1 = tf.nn.relu(fc1) 91 | W_fc2 = init_W(shape=[512, 4]) 92 | b_fc2 = init_b(shape=[4]) 93 | out = tf.nn.bias_add(tf.matmul(h_fc1, W_fc2), b_fc2, name='Q') 94 | else: 95 | with tf.variable_scope('Value'): 96 | h_fc1_v = tf.nn.relu(fc1) 97 | W_v = init_W(shape=[512, 1]) 98 | b_v = init_b(shape=[1]) 99 | V = tf.nn.bias_add(tf.matmul(h_fc1_v, W_v), b_v, name='V') 100 | 101 | with tf.variable_scope('Advantage'): 102 | h_fc1_a = tf.nn.relu(fc1) 103 | W_a = init_W(shape=[512, 4]) 104 | b_a = init_b(shape=[4]) 105 | A = tf.nn.bias_add(tf.matmul(h_fc1_a, W_a), b_a, name='A') 106 | 107 | with tf.variable_scope('Q'): 108 | out = V + ( A - tf.reduce_mean( A, axis=1, keep_dims=True)) 109 | return out 110 | 111 | def init_W(shape, name='weights', w_initializer=tf.truncated_normal_initializer(0, 1e-2)): 112 | 113 | return tf.get_variable( 114 | name=name, 115 | shape=shape, 116 | initializer=w_initializer) 117 | 118 | def init_b(shape, name='biases', b_initializer = tf.constant_initializer(1e-2)): 119 | 120 | return tf.get_variable( 121 | name=name, 122 | shape=shape, 123 | initializer=b_initializer) 124 | 125 | def conv2d(x, kernel, strides=4): 126 | 127 | return tf.nn.conv2d( 128 | input=x, 129 | filter=kernel, 130 | strides=[1, strides, strides, 1], 131 | padding="VALID") 132 | 133 | def max_pool(x, ksize=2, strides=2): 134 | return tf.nn.max_pool(x, 135 | ksize=[1, ksize, ksize, 1], 136 | strides=[1, strides, strides, 1], 137 | padding="SAME") 138 | -------------------------------------------------------------------------------- /drl-gym/run.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import gym 7 | import os 8 | import time 9 | 10 | from testModel import testQGame 11 | from netFrame import net_frame_mlp, net_frame_cnn_to_mlp 12 | from utils import * 13 | from envWrapper import envMakeWrapper 14 | from DQNModel import DeepQNetwork4Atari 15 | from memory import Memory 16 | 17 | from argument import args 18 | from logger import Logger 19 | 20 | 21 | def training(agent, env, flag_using_heu, file = None): 22 | update_iter = 0 23 | reward_every = [] 24 | Max_return = 0 25 | reward_all = 0 26 | 27 | reward_all_his = [] 28 | estim_Qvalue_argmax = [] 29 | estim_Qvalue_expect = [] 30 | 31 | test_reward_all = [] 32 | test_every_average1_Q = [] 33 | test_every_each10_Q = [] 34 | 35 | test_reward_H = [] 36 | test_every_H = [] 37 | 38 | reward_episode = 0 #!!!!! 39 | ep_flag = 0 40 | 41 | for episode in range(args.epoches): 42 | state = env.reset() 43 | # env.render() 44 | 45 | if episode != 0 and (episode + 1) % args.test_period == 0 and update_iter > args.observe_step: 46 | test_reward_all, test_every_average1_Q, test_every_each10_Q = testQGame(args.env_name, agent, episode, test_reward_all, test_every_average1_Q, test_every_each10_Q, args.flag_cnn) 47 | 48 | #training 49 | for step in range(args.max_step): 50 | action = agent.chose_action(state) 51 | next_state , reward , done , _ = env.step(action) 52 | update_iter += 1 53 | reward_episode += reward 54 | # tf.summary.scalar('reward_all',reward_all) 55 | 56 | memory.store(state, action , reward , next_state , done) 57 | 58 | if memory.size() > args.observe_step: # [TODO-why: observe so much?] 
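            # Together these two conditions follow standard DQN practice: the
            # memory.size() > observe_step check above delays any gradient update until
            # the replay buffer holds enough experience gathered under the initial
            # epsilon = 1.0 (random) policy -- the "replay start size", 50000 by default
            # here, matching the Nature DQN setup. The check below reuses args.frame_skip
            # as an update frequency: one training step every 4 environment steps, which
            # is independent of the frame skipping already performed by MaxAndSkipEnv in
            # atari_wrappers.py.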
59 | if update_iter % args.frame_skip == 0: # levin: without testing [TODO-why: significant?] 60 | #*** 61 | batch_state, batch_action, batch_reward, batch_next_state, batch_done, batch_return = memory.batchSample(args.batch_size) 62 | estim_Qvalue_argmax, estim_Qvalue_expect = agent.train(state=batch_state , 63 | reward= batch_reward , 64 | action = batch_action , 65 | state_next = batch_next_state, 66 | done = batch_done, 67 | episode_return = batch_return, 68 | estim_Qvalue_argmax=estim_Qvalue_argmax, 69 | estim_Qvalue_expect=estim_Qvalue_expect, 70 | 71 | ) 72 | # if flag_summary: 73 | # agent.write_summary(state = batch_state , 74 | # reward = batch_reward , 75 | # action = batch_action , 76 | # state_next = batch_next_state, 77 | # done = batch_done, 78 | # episode_return = batch_return, 79 | # summary_iter = update_iter, 80 | # reward_all = reward_all, 81 | # flag_using_heu = flag_using_heu 82 | # ) 83 | 84 | 85 | if update_iter > args.observe_step and update_iter % args.update_period == 0: 86 | agent.update_prmt() 87 | 88 | if update_iter > args.observe_step and update_iter % 100 == 0 and update_iter != 0: 89 | agent.decay_epsilon(update_iter - args.observe_step, args.explore_step) 90 | 91 | # episode or epoch, if episode: done --> break 92 | if done: 93 | if args.flag_done_break: 94 | print(" epoch:%3d , step: %3d , epsilon: %.3f , reward: %d"%(episode, step, agent.epsilon, reward_episode)) 95 | # file = open(file_path,'w') 96 | # file.write(" episode:%3d , step:%3d , reward: %d"%(episode, step, reward_episode)) 97 | # file.write("\n")d 98 | reward_every.append(reward_episode) 99 | # break 100 | 101 | reward_all += reward_episode 102 | reward_all_his.append(reward_all) 103 | reward_episode = 0 104 | # next_state = env.reset() 105 | break 106 | 107 | else: 108 | next_state = env.reset() # TODO: doubt --- seems wrong? 
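                    # Re the TODO above: the terminal transition was already stored with
                    # the true next_state before this branch, so resetting here only
                    # decides where the next step starts (via state = next_state at the
                    # bottom of the loop). A side effect is that reward_episode keeps
                    # accumulating across these in-place resets and is only appended to
                    # reward_every once step reaches max_step - 1, so with
                    # flag_done_break = False it behaves as a per-epoch return rather
                    # than a per-episode one.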
109 | 110 | if step == args.max_step - 1: 111 | print(" epoch:%3d , step: %3d , epsilon: %.3f , reward: %d"%(episode, step, agent.epsilon, reward_episode)) 112 | # file = open(file_path,'w') 113 | # file.write(" episode:%3d , step:%3d , reward: %d"%(episode, step, reward_episode)) 114 | # file.write("\n") 115 | reward_every.append(reward_episode) 116 | # break 117 | 118 | reward_all += reward_episode 119 | reward_all_his.append(reward_all) 120 | reward_episode = 0 121 | state = next_state 122 | return reward_all_his, reward_every, test_reward_all, test_every_average1_Q, test_every_each10_Q 123 | 124 | 125 | if __name__ == "__main__": 126 | set_random_seed(args.seed) 127 | 128 | env_set = ["BreakoutNoFrameskip-v4", "PongNoFrameskip-v4"] 129 | if args.env_name in env_set: 130 | args.flag_cnn = True 131 | else: 132 | args.flag_cnn = False 133 | 134 | memory = Memory(args.memory_size, flag_piexl=args.flag_cnn) 135 | 136 | config = tf.ConfigProto() 137 | config.gpu_options.allow_growth = True 138 | # config.gpu_options.per_process_gpu_memory_fraction = 0.4 139 | print(config) 140 | 141 | with tf.Session(config=config) as sess2: 142 | env_1 = envMakeWrapper(args.env_name, args.flag_cnn) 143 | scope_name=str(time.time()) 144 | with tf.variable_scope(scope_name): 145 | print("") 146 | print("*******************") 147 | print("double-DQN" if args.flag_double_dqn else "DQN") 148 | print("*******************") 149 | DQN = DeepQNetwork4Atari(scope_name, env_1 , args.flag_double_dqn, args.flag_cnn, sess2) 150 | reward_all_his, reward_every, test_reward_all, test_every_average1_Q, test_every_each10_Q= training(DQN, env_1, args.flag_double_dqn) 151 | -------------------------------------------------------------------------------- /drl-gym/testModel.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import gym 6 | import random 7 | from envWrapper import envMakeWrapper 8 | 9 | def testQGame(env_name, agent, episode, test_reward_all_his, test_every_average1, test_every_each10, flag_cnn, max_step=100000000, test_times=10): 10 | # print("testing_q...") 11 | # TODO: remove the loop of "maxStep" 12 | env1 = envMakeWrapper(env_name, flag_cnn) 13 | if len(test_reward_all_his) is 0: 14 | test_reward_all = 0 15 | else: 16 | test_reward_all = test_reward_all_his[-1] 17 | reward_all_episode = 0 18 | for i in range(test_times): 19 | # env1.render() 20 | reward_episode = 0 21 | state = env1.reset() 22 | for step in range(max_step): 23 | action = agent.greedy_action(state) 24 | next_state , reward , done , _ = env1.step(action) 25 | reward_episode += reward 26 | if done: 27 | print("times: %2d -- step: %3d -- reward: %d"%(i, step, reward_episode)) 28 | 29 | test_reward_all += reward_episode 30 | reward_all_episode += reward_episode 31 | test_reward_all_his.append(test_reward_all) 32 | test_every_each10.append(reward_episode) 33 | break 34 | state = next_state 35 | test_average = reward_all_episode / test_times 36 | test_every_average1.append(test_average) 37 | 38 | print("test_times: %d , per_reward: %.4f"%( test_times , test_average) ) 39 | return test_reward_all_his, test_every_average1, test_every_each10 40 | -------------------------------------------------------------------------------- /drl-gym/utils.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | import random 6 | import 
numpy as np 7 | 8 | def set_random_seed(SEED=0): 9 | tf.set_random_seed(SEED) 10 | np.random.seed(SEED) 11 | random.seed(SEED) 12 | 13 | def huber_loss(x, delta=1.0): 14 | """Reference: https://en.wikipedia.org/wiki/Huber_loss""" 15 | return tf.where( 16 | tf.abs(x) < delta, 17 | tf.square(x) * 0.5, 18 | delta * (tf.abs(x) - 0.5 * delta) 19 | ) 20 | 21 | 22 | def build_rmsprop_optimizer(learning_rate, rmsprop_decay, rmsprop_constant, gradient_clip, version, loss): 23 | with tf.name_scope('rmsprop'): 24 | optimizer = None 25 | if version == 'rmsprop': 26 | optimizer = tf.train.RMSPropOptimizer(learning_rate, decay=rmsprop_decay, momentum=0.0, epsilon=rmsprop_constant) 27 | elif version == 'graves_rmsprop': 28 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 29 | 30 | grads_and_vars = optimizer.compute_gradients(loss) 31 | grads = [gv[0] for gv in grads_and_vars] 32 | params = [gv[1] for gv in grads_and_vars] 33 | # print(grads) 34 | if gradient_clip > 0: 35 | grads = tf.clip_by_global_norm(grads, gradient_clip)[0] 36 | 37 | grads = [grad for grad in grads if grad != None] 38 | 39 | if version == 'rmsprop': 40 | return optimizer.apply_gradients(zip(grads, params)) 41 | elif version == 'graves_rmsprop': 42 | square_grads = [tf.square(grad) for grad in grads if grad != None] 43 | 44 | avg_grads = [tf.Variable(tf.zeros(var.get_shape())) for var in params] 45 | avg_square_grads = [tf.Variable(tf.zeros(var.get_shape())) for var in params] 46 | 47 | update_avg_grads = [grad_pair[0].assign((rmsprop_decay * grad_pair[0]) + ((1 - rmsprop_decay) * grad_pair[1])) 48 | for grad_pair in zip(avg_grads, grads)] 49 | 50 | update_avg_square_grads = [grad_pair[0].assign((rmsprop_decay * grad_pair[0]) + ((1 - rmsprop_decay) * tf.square(grad_pair[1]))) 51 | for grad_pair in zip(avg_square_grads, grads)] 52 | avg_grad_updates = update_avg_grads + update_avg_square_grads 53 | 54 | rms = [tf.sqrt(avg_grad_pair[1] - tf.square(avg_grad_pair[0]) + rmsprop_constant) 55 | for avg_grad_pair in zip(avg_grads, avg_square_grads)] 56 | 57 | rms_updates = [grad_rms_pair[0] / grad_rms_pair[1] for grad_rms_pair in zip(grads, rms)] 58 | train = optimizer.apply_gradients(zip(rms_updates, params)) 59 | 60 | return tf.group(train, tf.group(*avg_grad_updates)) --------------------------------------------------------------------------------
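A minimal, self-contained sketch (not part of the repository) of how the helpers in utils.py fit together. It assumes TensorFlow 1.x and that it is run from inside the drl-gym folder so that `from utils import ...` resolves; the two-action linear Q-function and the random batch are illustrative stand-ins for the real networks from netFrame.py and for replay samples drawn from memory.py.

# Illustrative only: a tiny two-action linear Q-function trained with the same
# huber_loss + build_rmsprop_optimizer combination that DQNModel.py uses.
import numpy as np
import tensorflow as tf
from utils import set_random_seed, huber_loss, build_rmsprop_optimizer

set_random_seed(0)

state = tf.placeholder(tf.float32, [None, 4], name="state")
action = tf.placeholder(tf.int32, [None], name="action")
target = tf.placeholder(tf.float32, [None], name="target")   # r + gamma * max_a' Q_target(s', a')

w = tf.Variable(tf.zeros([4, 2]), name="w")                   # stand-in for the Q-network
q_values = tf.matmul(state, w)
q_action = tf.reduce_sum(q_values * tf.one_hot(action, 2), axis=1)

loss = tf.reduce_mean(huber_loss(q_action - target))
train_op = build_rmsprop_optimizer(0.00025, 0.99, 1e-6, 1, 'rmsprop', loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {state: np.random.rand(32, 4).astype(np.float32),
            action: np.random.randint(0, 2, size=32),
            target: np.random.rand(32).astype(np.float32)}
    print("loss before:", sess.run(loss, feed_dict=feed))
    sess.run(train_op, feed_dict=feed)
    print("loss after:", sess.run(loss, feed_dict=feed))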