├── README.md
├── chapter10
│   ├── DQN.ipynb
│   ├── agent.py
│   ├── cnn.py
│   ├── deep_q.py
│   ├── environment.py
│   ├── experience.py
│   ├── history.py
│   ├── main.py
│   └── statistic.py
├── chapter2
│   ├── Breakout.ipynb
│   ├── Breakout.py
│   ├── Environment.ipynb
│   ├── Environment.py
│   ├── greedy.ipynb
│   └── greedy.py
├── chapter3
│   ├── Policy Evaluation.ipynb
│   ├── Policy Evaluation.py
│   ├── Policy Improvement.ipynb
│   ├── Policy Improvement.py
│   ├── Value Iteration.ipynb
│   └── Value Iteration.py
├── chapter4
│   ├── MC firstvisit prediciton.ipynb
│   ├── MC firstvisit prediciton.py
│   ├── MC_blackjack.ipynb
│   ├── MC_blackjack.py
│   ├── MC_firstvisit_control .ipynb
│   ├── MC_firstvisit_control .py
│   ├── MC_off_policy_weighted_importance_sampleing.ipynb
│   └── MC_off_policy_weighted_importance_sampleing.py
├── chapter5
│   ├── TD_CartPole.ipynb
│   ├── TD_CartPole.py
│   ├── TD_Qlearning.ipynb
│   ├── TD_Qlearning.py
│   ├── TD_sarsa.ipynb
│   └── TD_sarsa.py
├── chapter6
│   ├── FA_Qlearning.ipynb
│   ├── FA_Qlearning.py
│   ├── FA_Qlearning2.ipynb
│   ├── FA_Qlearning2.py
│   ├── FA_SARSA.ipynb
│   └── FA_SARSA.py
└── chapter7
    ├── PG_ACPG.ipynb
    └── PG_MCPG.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Reinforcement Learning: Principles and Practices
2 | 
3 | 
4 | Code examples for "Deep Reinforcement Learning: Principles and Practices" (《深度强化学习:原理与实践》)
5 | 
6 | Douban page: https://book.douban.com/subject/32568833/
7 | 
8 | 
9 | ### Introduction
10 | 
11 | "Deep Reinforcement Learning: Principles and Practices" contains 12 chapters and 5 appendices: Chapters 1-8 cover reinforcement learning, Chapters 9-12 cover deep reinforcement learning, and Appendices A-E introduce the deep-learning fundamentals used throughout the book. Based on the logical relationships between the chapters, the book groups the 12 chapters into four parts (Parts II-IV form the core), each of which is briefly introduced below.
12 | 
13 | ### Part I (Chapters 1-2): A First Look at Reinforcement Learning
14 | 
15 | This part centers on the concepts and basic framework of reinforcement learning, including its fundamental notions and mathematical principles. The material introduced here runs through the whole book; although the formulas and derivations are somewhat involved, they help the reader build a deep understanding of the basics. Chapter 1 introduces, in order, the history, basic theory, application cases, characteristics and future of reinforcement learning: the history clarifies the relationship between reinforcement learning and machine learning, the basic theory gives an overall picture of the field, the application cases show how reinforcement learning is put into practice, and the chapter closes with a broader discussion of its characteristics and outlook. Chapter 2 focuses on the underlying mathematics, starting from the Markov decision process formulation of reinforcement learning tasks and moving on to value functions and policies. The value function is the core of reinforcement learning, and most of the solution methods in later chapters concentrate on approximating it.
16 | 
17 | ### Part II (Chapters 3-5): Solving Reinforcement Learning
18 | 
19 | This part discusses how to obtain the optimal policy of a reinforcement learning task by mathematical means: dynamic programming for model-based tasks, and Monte Carlo and temporal-difference methods for model-free tasks. Note that all of the methods in this part are tabular. Chapter 3 introduces dynamic programming, where the iterative interplay of policy evaluation and policy improvement yields the policy iteration algorithm for computing value functions and policies; because policy iteration suffers from low efficiency and sensitivity to random initialization, the value iteration algorithm was proposed. Since complete knowledge of the environment is often unavailable in practice, Chapter 4 turns to model-free methods: Monte Carlo methods learn from experience trajectories sampled from real or simulated environments, and the chapter works through Monte Carlo prediction, Monte Carlo evaluation and Monte Carlo control in detail. Monte Carlo methods still have drawbacks, such as offline learning, high variance and slow convergence, which limit their effectiveness in real environments. Chapter 5 therefore introduces online temporal-difference learning, chiefly the on-policy Sarsa algorithm and the off-policy Q-learning algorithm; Q-learning later serves as one of the foundational algorithms of deep reinforcement learning (Part IV).
20 | 
21 | ### Part III (Chapters 6-8): Advanced Solution Methods
22 | 
23 | Dynamic programming, Monte Carlo and temporal-difference methods are all tabular. The approximate methods introduced in this part replace the target function with an approximating function, greatly reducing the computational scale and complexity of tabular solutions. They fall into three categories: value-based methods (value function approximation), policy-based methods (policy gradients), and model-based methods (learning and planning). Chapter 6 details the value-based approach, i.e. approximating the value function, using a mathematical treatment of function approximation to introduce the concepts and methods of value function approximation. Because value-function approximation struggles with continuous action spaces, Chapter 7 introduces policy gradient methods, which turn policy learning from a set of probabilities into a policy function and obtain the optimal policy by maximizing a policy objective. Chapter 8 covers model-based reinforcement learning, where the agent learns an environment model from real experience and plans with virtual trajectories generated by that model to obtain a value function or a policy.
24 | 
25 | ### Part IV (Chapters 9-12): Deep Reinforcement Learning
26 | 
27 | This part focuses on deep reinforcement learning, which combines the representational power of deep learning with the decision-making power of reinforcement learning, giving agents stronger learning ability and enabling them to solve more complex perception and decision problems. Chapter 9 first reviews three classic deep-learning architectures: deep neural networks, convolutional neural networks and recurrent neural networks; it then introduces the basic concepts of deep reinforcement learning and surveys its representative applications. Chapter 10 presents the first deep reinforcement learning algorithm, DQN, which combines Q-learning, an experience replay mechanism and a convolutional neural network for generating target Q values, effectively resolving the problems that arise when deep learning and reinforcement learning are fused. Chapter 11 discusses the shortcomings of DQN and the representative algorithms proposed subsequently: DDPG, A3C, Rainbow and Ape-X. Chapter 12 gives a thorough account of the design ideas and principles behind AlphaGo, together with the algorithmic details of AlphaGo and AlphaGo Zero.
28 | 
29 | Appendices A-E at the end of the book cover related deep-learning functions, algorithms and techniques for readers to study and use.
30 | 
--------------------------------------------------------------------------------
/chapter10/DQN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 | 
"metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 13 | " from ._conv import register_converters as _register_converters\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import gym\n", 19 | "import os\n", 20 | "import sys\n", 21 | "import itertools\n", 22 | "import numpy as np\n", 23 | "import random\n", 24 | "import tensorflow as tf\n", 25 | "from collections import defaultdict, namedtuple\n", 26 | "\n", 27 | "import matplotlib\n", 28 | "from matplotlib import pyplot as plt\n", 29 | "%matplotlib inline\n", 30 | "matplotlib.style.use('ggplot')" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# Deep q Network\n", 40 | "flags.DEFINE_boolean('use_gpu', True, 'Whether to use gpu or not. gpu use NHWC and gpu use NCHW for data_format')\n", 41 | "flags.DEFINE_string('agent_type', 'DQN', 'The type of agent [DQN]')\n", 42 | "flags.DEFINE_boolean('double_q', False, 'Whether to use double Q-learning')\n", 43 | "flags.DEFINE_string('network_header_type', 'nips', 'The type of network header [mlp, nature, nips]')\n", 44 | "flags.DEFINE_string('network_output_type', 'normal', 'The type of network output [normal, dueling]')\n", 45 | "\n", 46 | "# Environment\n", 47 | "flags.DEFINE_string('env_name', 'Breakout-v0', 'The name of gym environment to use')\n", 48 | "flags.DEFINE_integer('n_action_repeat', 1, 'The number of actions to repeat')\n", 49 | "flags.DEFINE_integer('max_random_start', 30, 'The maximum number of NOOP actions at the beginning of an episode')\n", 50 | "flags.DEFINE_integer('history_length', 4, 'The length of history of observation to use as an input to DQN')\n", 51 | "flags.DEFINE_integer('max_r', +1, 'The maximum value of clipped reward')\n", 52 | "flags.DEFINE_integer('min_r', -1, 'The minimum value of clipped reward')\n", 53 | "flags.DEFINE_string('observation_dims', '[80, 80]', 'The dimension of gym observation')\n", 54 | "flags.DEFINE_boolean('random_start', True, 'Whether to start with random state')\n", 55 | "flags.DEFINE_boolean('use_cumulated_reward', False, 'Whether to use cumulated reward or not')\n", 56 | "\n", 57 | "# Training\n", 58 | "flags.DEFINE_boolean('is_train', True, 'Whether to do training or testing')\n", 59 | "flags.DEFINE_integer('max_delta', None, 'The maximum value of delta')\n", 60 | "flags.DEFINE_integer('min_delta', None, 'The minimum value of delta')\n", 61 | "flags.DEFINE_float('ep_start', 1., 'The value of epsilon at start in e-greedy')\n", 62 | "flags.DEFINE_float('ep_end', 0.01, 'The value of epsilnon at the end in e-greedy')\n", 63 | "flags.DEFINE_integer('batch_size', 32, 'The size of batch for minibatch training')\n", 64 | "flags.DEFINE_integer('max_grad_norm', None, 'The maximum norm of gradient while updating')\n", 65 | "flags.DEFINE_float('discount_r', 0.99, 'The discount factor for reward')\n", 66 | "\n", 67 | "# Timer\n", 68 | "flags.DEFINE_integer('t_train_freq', 4, '')\n", 69 | "\n", 70 | "# Below numbers will be multiplied by scale\n", 71 | "flags.DEFINE_integer('scale', 10000, 'The scale for big numbers')\n", 72 | "flags.DEFINE_integer('memory_size', 100, 'The size of experience memory (*= scale)')\n", 73 | 
"flags.DEFINE_integer('t_target_q_update_freq', 1, 'The frequency of target network to be updated (*= scale)')\n", 74 | "flags.DEFINE_integer('t_test', 1, 'The maximum number of t while training (*= scale)')\n", 75 | "flags.DEFINE_integer('t_ep_end', 100, 'The time when epsilon reach ep_end (*= scale)')\n", 76 | "flags.DEFINE_integer('t_train_max', 5000, 'The maximum number of t while training (*= scale)')\n", 77 | "flags.DEFINE_float('t_learn_start', 5, 'The time when to begin training (*= scale)')\n", 78 | "flags.DEFINE_float('learning_rate_decay_step', 5, 'The learning rate of training (*= scale)')\n", 79 | "\n", 80 | "# Optimizer\n", 81 | "flags.DEFINE_float('learning_rate', 0.00025, 'The learning rate of training')\n", 82 | "flags.DEFINE_float('learning_rate_minimum', 0.00025, 'The minimum learning rate of training')\n", 83 | "flags.DEFINE_float('learning_rate_decay', 0.96, 'The decay of learning rate of training')\n", 84 | "flags.DEFINE_float('decay', 0.99, 'Decay of RMSProp optimizer')\n", 85 | "flags.DEFINE_float('momentum', 0.0, 'Momentum of RMSProp optimizer')\n", 86 | "flags.DEFINE_float('gamma', 0.99, 'Discount factor of return')\n", 87 | "flags.DEFINE_float('beta', 0.01, 'Beta of RMSProp optimizer')" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "FLAGS.__delattr__() \n", 97 | "\n", 98 | "flags = tf.app.flags\n", 99 | "FLAGS = flags.FLAGS\n", 100 | "\n", 101 | "flags.DEFINE_boolean(\"duele\", False, \"use dueling deep Q-learning\")\n", 102 | "flags.DEFINE_boolean(\"double\", False, \"use double Q-learning\")\n", 103 | "\n", 104 | "flags.DEFINE_boolean(\"is_train\", True, \"training or testing\")\n", 105 | "flags.DEFINE_integer(\"random_seed\", 123, \"value of random seed\")\n", 106 | "flags.DEFINE_boolean(\"display\", False, \"display the game\")\n", 107 | "flags.DEFINE_integer(\"scale\", 10000, \"step and the memory size\")\n", 108 | "flags.DEFINE_integer(\"batch_size\", 32, \"batch size\")\n", 109 | "flags.DEFINE_float(\"discount\", 0.99)\n", 110 | "flags.DEFINE_float(\"learning_rate\", 0.00025)\n", 111 | "flags.DEFINE_float(\"learning_rate_min\", 0.00025)\n", 112 | "flags.DEFINE_float(\"learning_rate_decay\", 0.96)\n", 113 | "flags.DEFINE_integer(\"history_length\", 4)\n", 114 | "flags.DEFINE_integer(\"train_frequency\", 4)\n", 115 | "\n", 116 | "flags.DEFINE_integer(\"learn_start\", 50000)\n", 117 | "flags.DEFINE_integer(\"frame_width\", 84)\n", 118 | "flags.DEFINE_integer(\"frame_height\", 84)\n", 119 | "flags.DEFINE_integer(\"max_reward\", 1)\n", 120 | "flags.DEFINE_integer(\"min_reward\", -1)\n", 121 | "flags.DEFINE_integer(\"episode_in_test\", 80)\n", 122 | "flags.DEFINE_integer(\"episode_in_train\", 18000)\n", 123 | "flags.DEFINE_integer(\"test_max_step\", 10000)\n", 124 | "\n", 125 | "FLAGS = flags.FLAGS" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "env = gym.make('Breakout-v0')\n", 135 | "env = env.unwrapped\n", 136 | "env.seed(FLAGS.random_seed)\n", 137 | "tf.set_random_seed(FLAGS.random_seed)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "class Environment(object):\n", 147 | " def __init__(self, env, history):\n", 148 | " self.env = env\n", 149 | " self.reward = 0\n", 150 | " self.terminal = False\n", 151 | " self.state_history = history\n", 152 | " self.state_dim 
= (FLAGS.frame_width, FLAGS.frame_height)\n",
153 |     "        self.nA = self.env.action_space.n\n",
154 |     "        self.nS = None\n",
155 |     "\n",
156 |     "    def reset(self):\n",
157 |     "        self.env.reset()\n",
158 |     " \n",
159 |     "    def random_start(self):\n",
160 |     "        self.reset()\n",
161 |     "        for i in reversed(range(random.randint(4, 30))):\n",
162 |     "            state, _, _, _ = self.env.step(0)\n",
163 |     "            if 4 - i > 0:\n",
164 |     "                self.state_history.push(self.__frame(state))\n",
165 |     " \n",
166 |     "        self.env.render()\n",
167 |     "        return self.state_history\n",
168 |     " \n",
169 |     "    def step(self, action):\n",
170 |     "        state, self.reward, self.terminal, _ = self.env.step(action)\n",
171 |     "        self.state = self.__frame(state)\n",
172 |     "\n",
173 |     "        self.env.render()\n",
174 |     "        return self.state, self.reward, self.terminal\n",
175 |     " \n",
176 |     "    # helper: convert a raw RGB frame to a grayscale frame resized to state_dim\n",
177 |     "    def __frame(self, state):\n",
178 |     "        processed_state = np.array(state)\n",
179 |     "        frame_state = np.uint8(resize(rgb2gray(processed_state), self.state_dim) * 255.)\n",
180 |     "        return frame_state"
181 |    ]
182 |   },
183 |   {
184 |    "cell_type": "code",
185 |    "execution_count": null,
186 |    "metadata": {},
187 |    "outputs": [],
188 |    "source": [
189 |     "class History(object):\n",
190 |     "    def __init__(self):\n",
191 |     "        self.history = np.zeros([FLAGS.history_length, \n",
192 |     "                                 FLAGS.frame_width,\n",
193 |     "                                 FLAGS.frame_height], dtype=np.float32)\n",
194 |     " \n",
195 |     "    def push(self, state):\n",
196 |     "        self.history[:-1] = self.history[1:]\n",
197 |     "        self.history[-1] = state\n",
198 |     " \n",
199 |     "    def get(self):\n",
200 |     "        return self.history\n",
201 |     " \n",
202 |     "    def clean(self):\n",
203 |     "        self.history *= 0"
204 |    ]
205 |   },
206 |   {
207 |    "cell_type": "code",
208 |    "execution_count": null,
209 |    "metadata": {},
210 |    "outputs": [],
211 |    "source": [
212 |     "class Memory(object):\n",
213 |     "    def __init__(self):\n",
214 |     "        pass\n",
215 |     " \n",
216 |     "    def push(self, state, reward, action, done):\n",
217 |     "        pass\n",
218 |     " \n",
219 |     "    def getState(self):\n",
220 |     "        pass\n",
221 |     " \n",
222 |     "    def sample(self):\n",
223 |     "        pass"
224 |    ]
225 |   },
226 |   {
227 |    "cell_type": "code",
228 |    "execution_count": null,
229 |    "metadata": {},
230 |    "outputs": [],
231 |    "source": [
232 |     "class Agent(object):\n",
233 |     "    def __init__(self, env, history, memory):\n",
234 |     "        self.env = env\n",
235 |     "        self.nA = env.nA\n",
236 |     "        self.state_history = history\n",
237 |     "        self.state_memory = memory\n",
238 |     "        self.t = 0\n",
239 |     " \n",
240 |     "        self.q_value, self.q_network = self.__build_network()\n",
241 |     "        self.target_q_value, self.target_q_network = self.__build_network()\n",
242 |     " \n",
243 |     "        self.sess = tf.Session()\n",
244 |     " \n",
245 |     "        tf.summary.FileWriter(\"summary/\", self.sess.graph)\n",
246 |     "        self.sess.run(tf.global_variables_initializer())\n",
247 |     "        self.saver = tf.train.Saver()\n",
248 |     " \n",
249 |     "    def predict(self, state):\n",
250 |     "        if self.t < FLAGS.learn_start:\n",
251 |     "            action = random.randrange(self.nA)\n",
252 |     "        else:\n",
253 |     "            action = np.argmax(self.q_value.eval(feed_dict={self.X: state}))\n",
254 |     " \n",
255 |     "        return action\n",
256 |     " \n",
257 |     "    def run(self, state, reward, action, done):\n",
258 |     "        reward = max(FLAGS.min_reward, min(FLAGS.max_reward, reward))\n",
259 |     " \n",
260 |     "        self.state_history.push(state)\n",
261 |     "        self.state_memory.push(state, reward, action, done)\n",
262 |     " \n",
263 |     "        if self.t > FLAGS.learn_start:\n",
264 |     "            if self.t % FLAGS.train_frequency == 0:\n",
265 |     "                self.q_\n",
266 |     " \n",
267 |     " \n",
268 |     "                # run the graph with sess.run to produce one training step's summary data \n",
269 |     "                train_summary = self.sess.run(merge_summary, feed_dict={})\n",
270 |     "                # save the summary together with the current step via train_writer.add_summary \n",
271 |     "                train_writer.add_summary(train_summary, self.t)\n",
272 |     " \n",
273 |     "        self.t += 1\n",
274 |     " \n",
275 |     "    def __build_network(self):\n",
276 |     "        \"\"\"\n",
277 |     "        build the neural network\n",
278 |     "        \"\"\"\n",
279 |     "        # Create placeholders\n",
280 |     "        with tf.name_scope('actor_inputs'):\n",
281 |     "            self.X = tf.placeholder(tf.float32, \n",
282 |     "                                    shape=[None, FLAGS.frame_width, FLAGS.frame_height, FLAGS.history_length],\n",
283 |     "                                    name=\"states\")\n",
284 |     "            self.Y = tf.placeholder(tf.float32, shape=(None, self.nA), name=\"action\")\n",
285 |     "            self.disc_norm_ep_reward = tf.placeholder(tf.float32, name=\"td_error\")\n",
286 |     "\n",
287 |     "        with tf.name_scope(\"conv1\"):\n",
288 |     "            conv1 = tf.layers.conv2d(self.X, 32, \n",
289 |     "                                     kernel_size=[8, 8], strides=[4, 4],\n",
290 |     "                                     padding=\"same\", activation=tf.nn.relu)\n",
291 |     "            # pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2,2], strides=2)\n",
292 |     "        with tf.name_scope(\"conv2\"):\n",
293 |     "            conv2 = tf.layers.conv2d(conv1, 64,\n",
294 |     "                                     kernel_size=[4, 4], strides=[2, 2],\n",
295 |     "                                     padding=\"same\", activation=tf.nn.relu)\n",
296 |     "            # pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2,2], strides=2)\n",
297 |     "        with tf.name_scope(\"conv3\"):\n",
298 |     "            conv3 = tf.layers.conv2d(conv2, 64,\n",
299 |     "                                     kernel_size=[3, 3], strides=[1, 1],\n",
300 |     "                                     padding=\"same\", activation=tf.nn.relu)\n",
301 |     "            # pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2,2], strides=2)\n",
302 |     "\n",
303 |     "        with tf.name_scope(\"dense_layer\"):\n",
304 |     "            conv3_flat = tf.layers.flatten(conv3)\n",
305 |     "            dense1 = tf.layers.dense(conv3_flat, units=512, activation=tf.nn.relu)\n",
306 |     " \n",
307 |     "        with tf.name_scope(\"logits_layer\"):\n",
308 |     "            logits = tf.layers.dense(dense1, units=self.env.nA)\n",
309 |     " \n",
310 |     "\n",
311 |     "        with tf.name_scope('actor_loss'):\n",
312 |     "            neg_log_prob = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=self.Y)\n",
313 |     "            loss = tf.reduce_mean(neg_log_prob * self.disc_norm_ep_reward) # reward guided loss\n",
314 |     "\n",
315 |     "        with tf.name_scope('actor_train'):\n",
316 |     "            self.train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)\n",
317 |     " \n",
318 |     "    def __weight_variable(self, shape, name):\n",
319 |     "        initial = tf.contrib.layers.xavier_initializer(seed=1)\n",
320 |     "        return tf.get_variable(name, shape, initializer=initial)\n",
321 |     " \n",
322 |     "    def __bias_variable(self, shape, name):\n",
323 |     "        initial = tf.contrib.layers.xavier_initializer(seed=1)\n",
324 |     "        return tf.get_variable(name, shape, initializer=initial)\n",
325 |     " \n",
326 |     "    def setup_summary(self):\n",
327 |     "        # scalar summaries of the average reward, loss and Q value \n",
328 |     "        average_reward = tf.placeholder('float32', None, name=\"average_reward\")\n",
329 |     "        tf.summary.scalar(\"average_reward\", average_reward)\n",
330 |     "        average_loss = tf.placeholder('float32', None, name=\"average_loss\")\n",
331 |     "        tf.summary.scalar(\"average_loss\", average_loss)\n",
332 |     "        average_q = tf.placeholder('float',None, name=\"average_q\")\n",
333 |     "        tf.summary.scalar(\"average_q\", average_q)\n",
334 |     " \n",
335 |     "        episode_max_reward = tf.placeholder('float',None, name=\"episode_max_reward\")\n",
336 |     "        tf.summary.scalar(\"episode_max_reward\", episode_max_reward)\n",
337 |     "        episode_min_reward = tf.placeholder('float',None, name=\"episode_min_reward\")\n",
338 |     "        
tf.summary.scalar(\"episode_min_reward\", episode_min_reward)\n", 339 | " episode_avg_reward = tf.placeholder('float',None, name=\"episode_avg_reward\")\n", 340 | " tf.summary.scalar(\"episode_avg_reward\", episode_avg_reward)\n", 341 | " episode_num = tf.placeholder('float',None, name=\"episode_num\")\n", 342 | " tf.summary.scalar(\"episode_num\", episode_num)\n", 343 | " episode_learning_rate = tf.placeholder('float',None, name=\"episode_learning_rate\")\n", 344 | " tf.summary.scalar(\"episode_learning_rate\", episode_learning_rate)\n", 345 | " \n", 346 | " # 定义一个写入summary的目标文件,dir为写入文件地址 \n", 347 | " merge_summary = tf.summary.merge_all()\n", 348 | " train_writer = tf.summary.FileWriter('./logs/', self.sess.graph)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 27, 354 | "metadata": {}, 355 | "outputs": [ 356 | { 357 | "ename": "IndentationError", 358 | "evalue": "expected an indented block (, line 5)", 359 | "output_type": "error", 360 | "traceback": [ 361 | "\u001b[1;36m File \u001b[1;32m\"\"\u001b[1;36m, line \u001b[1;32m5\u001b[0m\n\u001b[1;33m else:\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mIndentationError\u001b[0m\u001b[1;31m:\u001b[0m expected an indented block\n" 362 | ] 363 | } 364 | ], 365 | "source": [ 366 | "def deep_Qlearning(env):\n", 367 | " state_history = History()\n", 368 | " state_memory = Memory()\n", 369 | " env = Environment(gym.make('Breakout-v0'), state_history)\n", 370 | " agent = Agent(env, state_history, state_memory)\n", 371 | "\n", 372 | " if FLAGS.is_train:\n", 373 | " # for trainning the deep Q learning model\n", 374 | " max_avg_reward = 0\n", 375 | " \n", 376 | " for _ in range(FLAGS.episode_in_trains):\n", 377 | " total_reward, total_loss, total_q, ep_reward = 0, 0, 0, 0\n", 378 | " rewards, actions = [], []\n", 379 | " state_history = env.random_start()\n", 380 | " \n", 381 | " for t in itertools.count():\n", 382 | " # predict\n", 383 | " action = agent.predict(state_history.get())\n", 384 | " # action\n", 385 | " state, reward, done = env.step(action)\n", 386 | " # record\n", 387 | " # target = reward + gamma * np.amax(model.predict(next_state))\n", 388 | " agent.run(state, reward, action, done)\n", 389 | " \n", 390 | " if done:\n", 391 | " ep_reward = 0\n", 392 | " rewards.append(ep_reward)\n", 393 | " state_history = env.random_start()\n", 394 | " else:\n", 395 | " ep_reward += reward\n", 396 | " \n", 397 | " actions.append(action)\n", 398 | " total_reward += reward \n", 399 | " else:\n", 400 | " # for test the deep Q learning model\n", 401 | " best_reward, best_idx = 0, 0\n", 402 | " for _ in range(FLAGS.episode_in_test):\n", 403 | " state_history = env.random_start()\n", 404 | " current_reward = 0\n", 405 | " \n", 406 | " for t in itertools.count():\n", 407 | " # predict\n", 408 | " action = agent.predict(state_history.get())\n", 409 | " # action\n", 410 | " state, reward, done = env.step(action)\n", 411 | " # record\n", 412 | " state_history.push(state)\n", 413 | " \n", 414 | " current_reward += reward\n", 415 | " if done: break\n", 416 | " \n", 417 | " # print out the reward \n", 418 | " if current_reward > best_reward:\n", 419 | " best_reward = current_reward\n", 420 | " best_idx = _\n", 421 | " print(\"*\"*80)\n", 422 | " print(\"[{}] Best reward:{}\".format(best_idx, best_reward))\n", 423 | " " 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 12, 429 | "metadata": {}, 430 | "outputs": [ 431 | { 432 | "ename": "TypeError", 433 | "evalue": "'int' object is not subscriptable", 434 | 
"output_type": "error", 435 | "traceback": [ 436 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 437 | "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", 438 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m5\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", 439 | "\u001b[1;31mTypeError\u001b[0m: 'int' object is not subscriptable" 440 | ] 441 | } 442 | ], 443 | "source": [ 444 | "random.randrange(5)" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": null, 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [] 453 | } 454 | ], 455 | "metadata": { 456 | "kernelspec": { 457 | "display_name": "Python 3", 458 | "language": "python", 459 | "name": "python3" 460 | }, 461 | "language_info": { 462 | "codemirror_mode": { 463 | "name": "ipython", 464 | "version": 3 465 | }, 466 | "file_extension": ".py", 467 | "mimetype": "text/x-python", 468 | "name": "python", 469 | "nbconvert_exporter": "python", 470 | "pygments_lexer": "ipython3", 471 | "version": "3.6.5" 472 | } 473 | }, 474 | "nbformat": 4, 475 | "nbformat_minor": 2 476 | } 477 | -------------------------------------------------------------------------------- /chapter10/agent.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import time 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | import tensorflow as tf 7 | from logging import getLogger 8 | 9 | from .history import History 10 | from .experience import Experience 11 | 12 | logger = getLogger(__name__) 13 | 14 | def get_time(): 15 | return time.strftime("%Y-%m-%d_%H:%M:%S", time.gmtime()) 16 | 17 | class Agent(object): 18 | def __init__(self, sess, pred_network, env, stat, conf, target_network=None): 19 | self.sess = sess 20 | self.stat = stat 21 | 22 | self.ep_start = conf.ep_start 23 | self.ep_end = conf.ep_end 24 | self.history_length = conf.history_length 25 | self.t_ep_end = conf.t_ep_end 26 | self.t_learn_start = conf.t_learn_start 27 | self.t_train_freq = conf.t_train_freq 28 | self.t_target_q_update_freq = conf.t_target_q_update_freq 29 | self.env_name = conf.env_name 30 | 31 | self.discount_r = conf.discount_r 32 | self.min_r = conf.min_r 33 | self.max_r = conf.max_r 34 | self.min_delta = conf.min_delta 35 | self.max_delta = conf.max_delta 36 | self.max_grad_norm = conf.max_grad_norm 37 | self.observation_dims = conf.observation_dims 38 | 39 | self.learning_rate = conf.learning_rate 40 | self.learning_rate_minimum = conf.learning_rate_minimum 41 | self.learning_rate_decay = conf.learning_rate_decay 42 | self.learning_rate_decay_step = conf.learning_rate_decay_step 43 | 44 | # network 45 | self.double_q = conf.double_q 46 | self.pred_network = pred_network 47 | self.target_network = target_network 48 | self.target_network.create_copy_op(self.pred_network) 49 | 50 | self.env = env 51 | self.history = History(conf.data_format, 52 | conf.batch_size, conf.history_length, conf.observation_dims) 53 | self.experience = Experience(conf.data_format, 54 | conf.batch_size, conf.history_length, conf.memory_size, conf.observation_dims) 55 | 56 | if conf.random_start: 57 | self.new_game = self.env.new_random_game 58 | else: 59 | self.new_game = self.env.new_game 60 | 61 | def 
train(self, t_max): 62 | tf.global_variables_initializer().run() 63 | 64 | self.stat.load_model() 65 | self.target_network.run_copy() 66 | 67 | start_t = self.stat.get_t() 68 | observation, reward, terminal = self.new_game() 69 | 70 | for _ in range(self.history_length): 71 | self.history.add(observation) 72 | 73 | for self.t in tqdm(range(start_t, t_max), ncols=70, initial=start_t): 74 | ep = (self.ep_end + 75 | max(0., (self.ep_start - self.ep_end) 76 | * (self.t_ep_end - max(0., self.t - self.t_learn_start)) / self.t_ep_end)) 77 | 78 | # 1. predict 79 | action = self.predict(self.history.get(), ep) 80 | # 2. act 81 | observation, reward, terminal, info = self.env.step(action, is_training=True) 82 | # 3. observe 83 | q, loss, is_update = self.observe(observation, reward, action, terminal) 84 | 85 | logger.debug("a: %d, r: %d, t: %d, q: %.4f, l: %.2f" % \ 86 | (action, reward, terminal, np.mean(q), loss)) 87 | 88 | if self.stat: 89 | self.stat.on_step(self.t, action, reward, terminal, 90 | ep, q, loss, is_update, self.learning_rate_op) 91 | if terminal: 92 | observation, reward, terminal = self.new_game() 93 | 94 | def play(self, test_ep, n_step=10000, n_episode=100): 95 | tf.initialize_all_variables().run() 96 | 97 | self.stat.load_model() 98 | self.target_network.run_copy() 99 | 100 | if not self.env.display: 101 | gym_dir = '/tmp/%s-%s' % (self.env_name, get_time()) 102 | env = gym.wrappers.Monitor(self.env.env, gym_dir) 103 | 104 | best_reward, best_idx, best_count = 0, 0, 0 105 | try: 106 | itr = xrange(n_episode) 107 | except NameError: 108 | itr = range(n_episode) 109 | for idx in itr: 110 | observation, reward, terminal = self.new_game() 111 | current_reward = 0 112 | 113 | for _ in range(self.history_length): 114 | self.history.add(observation) 115 | 116 | for self.t in tqdm(range(n_step), ncols=70): 117 | # 1. predict 118 | action = self.predict(self.history.get(), test_ep) 119 | # 2. act 120 | observation, reward, terminal, info = self.env.step(action, is_training=False) 121 | # 3. 
observe 122 | q, loss, is_update = self.observe(observation, reward, action, terminal) 123 | 124 | logger.debug("a: %d, r: %d, t: %d, q: %.4f, l: %.2f" % \ 125 | (action, reward, terminal, np.mean(q), loss)) 126 | current_reward += reward 127 | 128 | if terminal: 129 | break 130 | 131 | if current_reward > best_reward: 132 | best_reward = current_reward 133 | best_idx = idx 134 | best_count = 0 135 | elif current_reward == best_reward: 136 | best_count += 1 137 | 138 | print ("="*30) 139 | print (" [%d] Best reward : %d (dup-percent: %d/%d)" % (best_idx, best_reward, best_count, n_episode)) 140 | print ("="*30) 141 | 142 | #if not self.env.display: 143 | #gym.upload(gym_dir, writeup='https://github.com/devsisters/DQN-tensorflow', api_key='') 144 | 145 | def predict(self, s_t, ep): 146 | if random.random() < ep: 147 | action = random.randrange(self.env.action_size) 148 | else: 149 | action = self.pred_network.calc_actions([s_t])[0] 150 | return action 151 | 152 | def q_learning_minibatch_test(self): 153 | s_t = np.array([[[ 0., 0., 0., 0.], 154 | [ 0., 0., 0., 0.], 155 | [ 0., 0., 0., 0.], 156 | [ 1., 0., 0., 0.]]], dtype=np.uint8) 157 | s_t_plus_1 = np.array([[[ 0., 0., 0., 0.], 158 | [ 0., 0., 0., 0.], 159 | [ 1., 0., 0., 0.], 160 | [ 0., 0., 0., 0.]]], dtype=np.uint8) 161 | s_t = s_t.reshape([1, 1] + self.observation_dims) 162 | s_t_plus_1 = s_t_plus_1.reshape([1, 1] + self.observation_dims) 163 | 164 | action = [3] 165 | reward = [1] 166 | terminal = [0] 167 | 168 | terminal = np.array(terminal) + 0. 169 | max_q_t_plus_1 = self.target_network.calc_max_outputs(s_t_plus_1) 170 | target_q_t = (1. - terminal) * self.discount_r * max_q_t_plus_1 + reward 171 | 172 | _, q_t, a, loss = self.sess.run([ 173 | self.optim, self.pred_network.outputs, self.pred_network.actions, self.loss 174 | ], { 175 | self.targets: target_q_t, 176 | self.actions: action, 177 | self.pred_network.inputs: s_t, 178 | }) 179 | 180 | logger.info("q: %s, a: %d, l: %.2f" % (q_t, a, loss)) 181 | 182 | def update_target_q_network(self): 183 | assert self.target_network != None 184 | self.target_network.run_copy() 185 | -------------------------------------------------------------------------------- /chapter10/cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from functools import reduce 4 | from tensorflow.contrib.layers.python.layers import initializers 5 | 6 | 7 | class CNN(Network): 8 | def __init__(self, sess, 9 | data_format, 10 | history_length, 11 | observation_dims, 12 | output_size, 13 | trainable=True, 14 | hidden_activation_fn=tf.nn.relu, 15 | output_activation_fn=None, 16 | weights_initializer=initializers.xavier_initializer(), 17 | biases_initializer=tf.constant_initializer(0.1), 18 | value_hidden_sizes=[512], 19 | advantage_hidden_sizes=[512], 20 | network_output_type='dueling', 21 | network_header_type='nips', 22 | name='CNN'): 23 | super(CNN, self).__init__(sess, name) 24 | 25 | if data_format == 'NHWC': 26 | self.inputs = tf.placeholder('float32', 27 | [None] + observation_dims + [history_length], name='inputs') 28 | elif data_format == 'NCHW': 29 | self.inputs = tf.placeholder('float32', 30 | [None, history_length] + observation_dims, name='inputs') 31 | else: 32 | raise ValueError("unknown data_format : %s" % data_format) 33 | 34 | self.var = {} 35 | self.l0 = tf.div(self.inputs, 255.) 
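# The two network headers built below follow the DQN conv stacks from the NIPS 2013
# and Nature 2015 papers: 'nips' uses 2 conv layers plus a 256-unit dense layer, while
# 'nature' uses 3 conv layers plus a 512-unit dense layer; l0 above rescales the raw
# screen input from [0, 255] to [0, 1] before it enters the first convolution.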
36 | 37 | with tf.variable_scope(name): 38 | if network_header_type.lower() == 'nature': 39 | self.l1, self.var['l1_w'], self.var['l1_b'] = conv2d(self.l0, 40 | 32, [8, 8], [4, 4], weights_initializer, biases_initializer, 41 | hidden_activation_fn, data_format, name='l1_conv') 42 | self.l2, self.var['l2_w'], self.var['l2_b'] = conv2d(self.l1, 43 | 64, [4, 4], [2, 2], weights_initializer, biases_initializer, 44 | hidden_activation_fn, data_format, name='l2_conv') 45 | self.l3, self.var['l3_w'], self.var['l3_b'] = conv2d(self.l2, 46 | 64, [3, 3], [1, 1], weights_initializer, biases_initializer, 47 | hidden_activation_fn, data_format, name='l3_conv') 48 | self.l4, self.var['l4_w'], self.var['l4_b'] = \ 49 | linear(self.l3, 512, weights_initializer, biases_initializer, 50 | hidden_activation_fn, data_format, name='l4_conv') 51 | layer = self.l4 52 | elif network_header_type.lower() == 'nips': 53 | self.l1, self.var['l1_w'], self.var['l1_b'] = conv2d(self.l0, 54 | 16, [8, 8], [4, 4], weights_initializer, biases_initializer, 55 | hidden_activation_fn, data_format, name='l1_conv') 56 | self.l2, self.var['l2_w'], self.var['l2_b'] = conv2d(self.l1, 57 | 32, [4, 4], [2, 2], weights_initializer, biases_initializer, 58 | hidden_activation_fn, data_format, name='l2_conv') 59 | self.l3, self.var['l3_w'], self.var['l3_b'] = \ 60 | linear(self.l2, 256, weights_initializer, biases_initializer, 61 | hidden_activation_fn, data_format, name='l3_conv') 62 | layer = self.l3 63 | else: 64 | raise ValueError('Wrong DQN type: %s' % network_header_type) 65 | 66 | self.build_output_ops(layer, network_output_type, 67 | value_hidden_sizes, advantage_hidden_sizes, output_size, 68 | weights_initializer, biases_initializer, hidden_activation_fn, 69 | output_activation_fn, trainable) 70 | 71 | 72 | 73 | 74 | class Network(object): 75 | def __init__(self, sess, name): 76 | self.sess = sess 77 | self.copy_op = None 78 | self.name = name 79 | self.var = {} 80 | 81 | def build_output_ops(self, input_layer, network_output_type, 82 | value_hidden_sizes, advantage_hidden_sizes, output_size, 83 | weights_initializer, biases_initializer, hidden_activation_fn, 84 | output_activation_fn, trainable): 85 | 86 | self.outputs, self.var['w_out'], self.var['b_out'] = linear(input_layer, output_size, weights_initializer, 87 | biases_initializer, output_activation_fn, trainable, name='out') 88 | 89 | 90 | self.max_outputs = tf.reduce_max(self.outputs, reduction_indices=1) 91 | self.outputs_idx = tf.placeholder('int32', [None, None], 'outputs_idx') 92 | self.outputs_with_idx = tf.gather_nd(self.outputs, self.outputs_idx) 93 | self.actions = tf.argmax(self.outputs, axis=1) 94 | 95 | def run_copy(self): 96 | if self.copy_op is None: 97 | raise Exception("run `create_copy_op` first before copy") 98 | else: 99 | self.sess.run(self.copy_op) 100 | 101 | def create_copy_op(self, network): 102 | with tf.variable_scope(self.name): 103 | copy_ops = [] 104 | 105 | for name in self.var.keys(): 106 | copy_op = self.var[name].assign(network.var[name]) 107 | copy_ops.append(copy_op) 108 | 109 | self.copy_op = tf.group(*copy_ops, name='copy_op') 110 | 111 | def calc_actions(self, observation): 112 | return self.actions.eval({self.inputs: observation}, session=self.sess) 113 | 114 | def calc_outputs(self, observation): 115 | return self.outputs.eval({self.inputs: observation}, session=self.sess) 116 | 117 | def calc_max_outputs(self, observation): 118 | return self.max_outputs.eval({self.inputs: observation}, session=self.sess) 119 | 120 | def 
calc_outputs_with_idx(self, observation, idx): 121 | return self.outputs_with_idx.eval( 122 | {self.inputs: observation, self.outputs_idx: idx}, session=self.sess) 123 | 124 | 125 | 126 | def conv2d(x, 127 | output_dim, 128 | kernel_size, 129 | stride, 130 | weights_initializer=tf.contrib.layers.xavier_initializer(), 131 | biases_initializer=tf.zeros_initializer, 132 | activation_fn=tf.nn.relu, 133 | data_format='NHWC', 134 | padding='VALID', 135 | name='conv2d', 136 | trainable=True): 137 | with tf.variable_scope(name): 138 | if data_format == 'NCHW': 139 | stride = [1, 1, stride[0], stride[1]] 140 | kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[1], output_dim] 141 | elif data_format == 'NHWC': 142 | stride = [1, stride[0], stride[1], 1] 143 | kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[-1], output_dim] 144 | 145 | w = tf.get_variable('w', kernel_shape, 146 | tf.float32, initializer=weights_initializer, trainable=trainable) 147 | conv = tf.nn.conv2d(x, w, stride, padding, data_format=data_format) 148 | 149 | b = tf.get_variable('b', [output_dim], 150 | tf.float32, initializer=biases_initializer, trainable=trainable) 151 | out = tf.nn.bias_add(conv, b, data_format) 152 | 153 | if activation_fn != None: 154 | out = activation_fn(out) 155 | 156 | return out, w, b 157 | 158 | def linear(input_, 159 | output_size, 160 | weights_initializer=initializers.xavier_initializer(), 161 | biases_initializer=tf.zeros_initializer, 162 | activation_fn=None, 163 | trainable=True, 164 | name='linear'): 165 | shape = input_.get_shape().as_list() 166 | 167 | if len(shape) > 2: 168 | input_ = tf.reshape(input_, [-1, reduce(lambda x, y: x * y, shape[1:])]) 169 | shape = input_.get_shape().as_list() 170 | 171 | with tf.variable_scope(name): 172 | w = tf.get_variable('w', [shape[1], output_size], tf.float32, 173 | initializer=weights_initializer, trainable=trainable) 174 | b = tf.get_variable('b', [output_size], 175 | initializer=biases_initializer, trainable=trainable) 176 | out = tf.nn.bias_add(tf.matmul(input_, w), b) 177 | 178 | if activation_fn != None: 179 | return activation_fn(out), w, b 180 | else: 181 | return out, w, b 182 | 183 | def batch_sample(probs, name='batch_sample'): 184 | with tf.variable_scope(name): 185 | uniform = tf.random_uniform(tf.shape(probs), minval=0, maxval=1) 186 | samples = tf.argmax(probs - uniform, dimension=1) 187 | return samples 188 | -------------------------------------------------------------------------------- /chapter10/deep_q.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from logging import getLogger 6 | 7 | from .agent import Agent 8 | 9 | logger = getLogger(__name__) 10 | 11 | class DeepQ(Agent): 12 | def __init__(self, sess, pred_network, env, stat, conf, target_network=None): 13 | super(DeepQ, self).__init__(sess, pred_network, env, stat, conf, target_network=target_network) 14 | 15 | # Optimizer 16 | with tf.variable_scope('optimizer'): 17 | self.targets = tf.placeholder('float32', [None], name='target_q_t') 18 | self.actions = tf.placeholder('int64', [None], name='action') 19 | 20 | actions_one_hot = tf.one_hot(self.actions, self.env.action_size, 1.0, 0.0, name='action_one_hot') 21 | pred_q = tf.reduce_sum(self.pred_network.outputs * actions_one_hot, reduction_indices=1, name='q_acted') 22 | 23 | self.delta = self.targets - pred_q 24 | self.clipped_error = tf.where(tf.abs(self.delta) < 1.0, 25 | 0.5 * 
tf.square(self.delta), 26 | tf.abs(self.delta) - 0.5, name='clipped_error') 27 | 28 | self.loss = tf.reduce_mean(self.clipped_error, name='loss') 29 | 30 | self.learning_rate_op = tf.maximum(self.learning_rate_minimum, 31 | tf.train.exponential_decay( 32 | self.learning_rate, 33 | self.stat.t_op, 34 | self.learning_rate_decay_step, 35 | self.learning_rate_decay, 36 | staircase=True)) 37 | 38 | optimizer = tf.train.RMSPropOptimizer( 39 | self.learning_rate_op, momentum=0.95, epsilon=0.01) 40 | 41 | if self.max_grad_norm != None: 42 | grads_and_vars = optimizer.compute_gradients(self.loss) 43 | for idx, (grad, var) in enumerate(grads_and_vars): 44 | if grad is not None: 45 | grads_and_vars[idx] = (tf.clip_by_norm(grad, self.max_grad_norm), var) 46 | self.optim = optimizer.apply_gradients(grads_and_vars) 47 | else: 48 | self.optim = optimizer.minimize(self.loss) 49 | 50 | def observe(self, observation, reward, action, terminal): 51 | reward = max(self.min_r, min(self.max_r, reward)) 52 | 53 | self.history.add(observation) 54 | self.experience.add(observation, reward, action, terminal) 55 | 56 | # q, loss, is_update 57 | result = [], 0, False 58 | 59 | if self.t > self.t_learn_start: 60 | if self.t % self.t_train_freq == 0: 61 | result = self.q_learning_minibatch() 62 | 63 | if self.t % self.t_target_q_update_freq == self.t_target_q_update_freq - 1: 64 | self.update_target_q_network() 65 | 66 | return result 67 | 68 | def q_learning_minibatch(self): 69 | if self.experience.count < self.history_length: 70 | return [], 0, False 71 | else: 72 | s_t, action, reward, s_t_plus_1, terminal = self.experience.sample() 73 | 74 | terminal = np.array(terminal) + 0. 75 | 76 | if self.double_q: 77 | # Double Q-learning 78 | pred_action = self.pred_network.calc_actions(s_t_plus_1) 79 | q_t_plus_1_with_pred_action = self.target_network.calc_outputs_with_idx( 80 | s_t_plus_1, [[idx, pred_a] for idx, pred_a in enumerate(pred_action)]) 81 | target_q_t = (1. - terminal) * self.discount_r * q_t_plus_1_with_pred_action + reward 82 | else: 83 | # Deep Q-learning 84 | max_q_t_plus_1 = self.target_network.calc_max_outputs(s_t_plus_1) 85 | target_q_t = (1. 
- terminal) * self.discount_r * max_q_t_plus_1 + reward 86 | 87 | _, q_t, loss = self.sess.run([self.optim, self.pred_network.outputs, self.loss], { 88 | self.targets: target_q_t, 89 | self.actions: action, 90 | self.pred_network.inputs: s_t, 91 | }) 92 | 93 | return q_t, loss, True 94 | -------------------------------------------------------------------------------- /chapter10/environment.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import random 3 | import logging 4 | import numpy as np 5 | 6 | try: 7 | import scipy.misc 8 | imresize = scipy.misc.imresize 9 | imwrite = scipy.misc.imsave 10 | except: 11 | import cv2 12 | imresize = cv2.resize 13 | imwrite = cv2.imwrite 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | class Environment(object): 18 | def __init__(self, env_name, n_action_repeat, max_random_start, 19 | observation_dims, data_format, display, use_cumulated_reward=False): 20 | self.env = gym.make(env_name) 21 | 22 | self.n_action_repeat = n_action_repeat 23 | self.max_random_start = max_random_start 24 | self.action_size = self.env.action_space.n 25 | 26 | self.display = display 27 | self.data_format = data_format 28 | self.observation_dims = observation_dims 29 | self.use_cumulated_reward = use_cumulated_reward 30 | 31 | if hasattr(self.env, 'get_action_meanings'): 32 | logger.info("Using %d actions : %s" % (self.action_size, ", ".join(self.env.get_action_meanings()))) 33 | 34 | def new_game(self): 35 | return self.preprocess(self.env.reset()), 0, False 36 | 37 | def new_random_game(self): 38 | return self.new_game() 39 | 40 | def step(self, action, is_training=False): 41 | observation, reward, terminal, info = self.env.step(action) 42 | if self.display: self.env.render() 43 | return self.preprocess(observation), reward, terminal, info 44 | 45 | def preprocess(self): 46 | raise NotImplementedError() 47 | 48 | class ToyEnvironment(Environment): 49 | def preprocess(self, obs): 50 | new_obs = np.zeros([self.env.observation_space.n]) 51 | new_obs[obs] = 1 52 | return new_obs 53 | 54 | class AtariEnvironment(Environment): 55 | def __init__(self, env_name, n_action_repeat, max_random_start, 56 | observation_dims, data_format, display, use_cumulated_reward): 57 | super(AtariEnvironment, self).__init__(env_name, 58 | n_action_repeat, max_random_start, observation_dims, data_format, display, use_cumulated_reward) 59 | 60 | def new_game(self, from_random_game=False): 61 | screen = self.env.reset() 62 | screen, reward, terminal, _ = self.env.step(0) 63 | 64 | if self.display: 65 | self.env.render() 66 | 67 | if from_random_game: 68 | return screen, 0, False 69 | else: 70 | self.lives = self.env.unwrapped.ale.lives() 71 | terminal = False 72 | return self.preprocess(screen, terminal), 0, terminal 73 | 74 | def new_random_game(self): 75 | screen, reward, terminal = self.new_game(True) 76 | 77 | for idx in range(random.randrange(self.max_random_start)): 78 | screen, reward, terminal, _ = self.env.step(0) 79 | 80 | if terminal: logger.warning("warning: terminal signal received after %d 0-steps", idx) 81 | 82 | if self.display: 83 | self.env.render() 84 | 85 | self.lives = self.env.unwrapped.ale.lives() 86 | 87 | terminal = False 88 | return self.preprocess(screen, terminal), 0, terminal 89 | 90 | def step(self, action, is_training): 91 | if action == -1: 92 | # Step with random action 93 | action = self.env.action_space.sample() 94 | 95 | cumulated_reward = 0 96 | 97 | for _ in range(self.n_action_repeat): 98 | screen, 
reward, terminal, _ = self.env.step(action) 99 | cumulated_reward += reward 100 | current_lives = self.env.unwrapped.ale.lives() 101 | 102 | if is_training and self.lives > current_lives: 103 | terminal = True 104 | 105 | if terminal: break 106 | 107 | if self.display: 108 | self.env.render() 109 | 110 | if not terminal: 111 | self.lives = current_lives 112 | 113 | if self.use_cumulated_reward: 114 | return self.preprocess(screen, terminal), cumulated_reward, terminal, {} 115 | else: 116 | return self.preprocess(screen, terminal), reward, terminal, {} 117 | 118 | def preprocess(self, raw_screen, terminal): 119 | y = 0.2126 * raw_screen[:, :, 0] + 0.7152 * raw_screen[:, :, 1] + 0.0722 * raw_screen[:, :, 2] 120 | y = y.astype(np.uint8) 121 | y_screen = imresize(y, self.observation_dims) 122 | return y_screen 123 | -------------------------------------------------------------------------------- /chapter10/experience.py: -------------------------------------------------------------------------------- 1 | """Modification of https://github.com/tambetm/simple_dqn/blob/master/src/replay_memory.py""" 2 | 3 | import random 4 | import numpy as np 5 | 6 | class Experience(object): 7 | def __init__(self, data_format, batch_size, history_length, memory_size, observation_dims): 8 | self.data_format = data_format 9 | self.batch_size = batch_size 10 | self.history_length = history_length 11 | self.memory_size = memory_size 12 | 13 | self.actions = np.empty(self.memory_size, dtype=np.uint8) 14 | self.rewards = np.empty(self.memory_size, dtype=np.int8) 15 | self.observations = np.empty([self.memory_size] + observation_dims, dtype=np.uint8) 16 | self.terminals = np.empty(self.memory_size, dtype=np.bool) 17 | 18 | # pre-allocate prestates and poststates for minibatch 19 | self.prestates = np.empty([self.batch_size, self.history_length] + observation_dims, dtype = np.float16) 20 | self.poststates = np.empty([self.batch_size, self.history_length] + observation_dims, dtype = np.float16) 21 | 22 | self.count = 0 23 | self.current = 0 24 | 25 | def add(self, observation, reward, action, terminal): 26 | self.actions[self.current] = action 27 | self.rewards[self.current] = reward 28 | self.observations[self.current, ...] = observation 29 | self.terminals[self.current] = terminal 30 | self.count = max(self.count, self.current + 1) 31 | self.current = (self.current + 1) % self.memory_size 32 | 33 | def sample(self): 34 | indexes = [] 35 | while len(indexes) < self.batch_size: 36 | while True: 37 | index = random.randint(self.history_length, self.count - 1) 38 | if index >= self.current and index - self.history_length < self.current: 39 | continue 40 | if self.terminals[(index - self.history_length):index].any(): 41 | continue 42 | break 43 | 44 | self.prestates[len(indexes), ...] = self.retreive(index - 1) 45 | self.poststates[len(indexes), ...] = self.retreive(index) 46 | indexes.append(index) 47 | 48 | actions = self.actions[indexes] 49 | rewards = self.rewards[indexes] 50 | terminals = self.terminals[indexes] 51 | 52 | if self.data_format == 'NHWC' and len(self.prestates.shape) == 4: 53 | return np.transpose(self.prestates, (0, 2, 3, 1)), actions, \ 54 | rewards, np.transpose(self.poststates, (0, 2, 3, 1)), terminals 55 | else: 56 | return self.prestates, actions, rewards, self.poststates, terminals 57 | 58 | def retreive(self, index): 59 | index = index % self.count 60 | if index >= self.history_length - 1: 61 | return self.observations[(index - (self.history_length - 1)):(index + 1), ...] 
62 | else: 63 | indexes = [(index - i) % self.count for i in reversed(range(self.history_length))] 64 | return self.observations[indexes, ...] 65 | -------------------------------------------------------------------------------- /chapter10/history.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class History: 4 | def __init__(self, data_format, batch_size, history_length, screen_dims): 5 | self.data_format = data_format 6 | self.history = np.zeros([history_length] + screen_dims, dtype=np.float32) 7 | 8 | def add(self, screen): 9 | self.history[:-1] = self.history[1:] 10 | self.history[-1] = screen 11 | 12 | def reset(self): 13 | self.history *= 0 14 | 15 | def get(self): 16 | if self.data_format == 'NHWC' and len(self.history.shape) == 3: 17 | return np.transpose(self.history, (1, 2, 0)) 18 | else: 19 | return self.history 20 | -------------------------------------------------------------------------------- /chapter10/main.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import random 3 | import logging 4 | import tensorflow as tf 5 | 6 | from utils import get_model_dir 7 | from cnn import CNN 8 | from statistic import Statistic 9 | from environment import ToyEnvironment, AtariEnvironment 10 | 11 | flags = tf.app.flags 12 | 13 | # Deep q Network 14 | flags.DEFINE_boolean('use_gpu', True, 'Whether to use gpu or not. gpu use NHWC and gpu use NCHW for data_format') 15 | flags.DEFINE_string('agent_type', 'DQN', 'The type of agent [DQN]') 16 | flags.DEFINE_boolean('double_q', False, 'Whether to use double Q-learning') 17 | flags.DEFINE_string('network_header_type', 'nips', 'The type of network header [mlp, nature, nips]') 18 | flags.DEFINE_string('network_output_type', 'normal', 'The type of network output [normal, dueling]') 19 | 20 | # Environment 21 | flags.DEFINE_string('env_name', 'Breakout-v0', 'The name of gym environment to use') 22 | flags.DEFINE_integer('n_action_repeat', 1, 'The number of actions to repeat') 23 | flags.DEFINE_integer('max_random_start', 30, 'The maximum number of NOOP actions at the beginning of an episode') 24 | flags.DEFINE_integer('history_length', 4, 'The length of history of observation to use as an input to DQN') 25 | flags.DEFINE_integer('max_r', +1, 'The maximum value of clipped reward') 26 | flags.DEFINE_integer('min_r', -1, 'The minimum value of clipped reward') 27 | flags.DEFINE_string('observation_dims', '[80, 80]', 'The dimension of gym observation') 28 | flags.DEFINE_boolean('random_start', True, 'Whether to start with random state') 29 | flags.DEFINE_boolean('use_cumulated_reward', False, 'Whether to use cumulated reward or not') 30 | 31 | # Training 32 | flags.DEFINE_boolean('is_train', True, 'Whether to do training or testing') 33 | flags.DEFINE_integer('max_delta', None, 'The maximum value of delta') 34 | flags.DEFINE_integer('min_delta', None, 'The minimum value of delta') 35 | flags.DEFINE_float('ep_start', 1., 'The value of epsilon at start in e-greedy') 36 | flags.DEFINE_float('ep_end', 0.01, 'The value of epsilnon at the end in e-greedy') 37 | flags.DEFINE_integer('batch_size', 32, 'The size of batch for minibatch training') 38 | flags.DEFINE_integer('max_grad_norm', None, 'The maximum norm of gradient while updating') 39 | flags.DEFINE_float('discount_r', 0.99, 'The discount factor for reward') 40 | 41 | # Timer 42 | flags.DEFINE_integer('t_train_freq', 4, '') 43 | 44 | # Below numbers will be multiplied by scale 45 | 
flags.DEFINE_integer('scale', 10000, 'The scale for big numbers') 46 | flags.DEFINE_integer('memory_size', 100, 'The size of experience memory (*= scale)') 47 | flags.DEFINE_integer('t_target_q_update_freq', 1, 'The frequency of target network to be updated (*= scale)') 48 | flags.DEFINE_integer('t_test', 1, 'The maximum number of t while training (*= scale)') 49 | flags.DEFINE_integer('t_ep_end', 100, 'The time when epsilon reach ep_end (*= scale)') 50 | flags.DEFINE_integer('t_train_max', 5000, 'The maximum number of t while training (*= scale)') 51 | flags.DEFINE_float('t_learn_start', 5, 'The time when to begin training (*= scale)') 52 | flags.DEFINE_float('learning_rate_decay_step', 5, 'The learning rate of training (*= scale)') 53 | 54 | # Optimizer 55 | flags.DEFINE_float('learning_rate', 0.00025, 'The learning rate of training') 56 | flags.DEFINE_float('learning_rate_minimum', 0.00025, 'The minimum learning rate of training') 57 | flags.DEFINE_float('learning_rate_decay', 0.96, 'The decay of learning rate of training') 58 | flags.DEFINE_float('decay', 0.99, 'Decay of RMSProp optimizer') 59 | flags.DEFINE_float('momentum', 0.0, 'Momentum of RMSProp optimizer') 60 | flags.DEFINE_float('gamma', 0.99, 'Discount factor of return') 61 | flags.DEFINE_float('beta', 0.01, 'Beta of RMSProp optimizer') 62 | 63 | # Debug 64 | flags.DEFINE_boolean('display', False, 'Whether to do display the game screen or not') 65 | flags.DEFINE_string('log_level', 'INFO', 'Log level [DEBUG, INFO, WARNING, ERROR, CRITICAL]') 66 | flags.DEFINE_integer('random_seed', 123, 'Value of random seed') 67 | flags.DEFINE_string('tag', '', 'The name of tag for a model, only for debugging') 68 | flags.DEFINE_boolean('allow_soft_placement', True, 'Whether to use part or all of a GPU') 69 | #flags.DEFINE_string('gpu_fraction', '1/1', 'idx / # of gpu fraction e.g. 
1/3, 2/3, 3/3') 70 | 71 | # Internal 72 | # It is forbidden to set a flag that is not defined 73 | flags.DEFINE_string('data_format', 'NCHW', 'INTERNAL USED ONLY') 74 | 75 | def calc_gpu_fraction(fraction_string): 76 | idx, num = fraction_string.split('/') 77 | idx, num = float(idx), float(num) 78 | 79 | fraction = 1 / (num - idx + 1) 80 | print (" [*] GPU : %.4f" % fraction) 81 | return fraction 82 | 83 | conf = flags.FLAGS 84 | 85 | from deep_q import DeepQ 86 | TrainAgent = DeepQ 87 | 88 | 89 | logger = logging.getLogger() 90 | logger.propagate = False 91 | logger.setLevel(conf.log_level) 92 | 93 | # set random seed 94 | tf.set_random_seed(conf.random_seed) 95 | random.seed(conf.random_seed) 96 | 97 | def main(_): 98 | # preprocess 99 | conf.observation_dims = eval(conf.observation_dims) 100 | 101 | for flag in ['memory_size', 't_target_q_update_freq', 't_test', 102 | 't_ep_end', 't_train_max', 't_learn_start', 'learning_rate_decay_step']: 103 | setattr(conf, flag, getattr(conf, flag) * conf.scale) 104 | 105 | if conf.use_gpu: 106 | conf.data_format = 'NCHW' 107 | else: 108 | conf.data_format = 'NHWC' 109 | 110 | model_dir = get_model_dir(conf, 111 | ['use_gpu', 'max_random_start', 'n_worker', 'is_train', 'memory_size', 'gpu_fraction', 112 | 't_save', 't_train', 'display', 'log_level', 'random_seed', 'tag', 'scale']) 113 | 114 | sess_config = tf.ConfigProto( 115 | log_device_placement=False, allow_soft_placement=conf.allow_soft_placement) 116 | sess_config.gpu_options.allow_growth = conf.allow_soft_placement 117 | 118 | with tf.Session(config=sess_config) as sess: 119 | env = AtariEnvironment(conf.env_name, conf.n_action_repeat, 120 | conf.max_random_start, conf.observation_dims, 121 | conf.data_format, conf.display, conf.use_cumulated_reward) 122 | 123 | 124 | pred_network = CNN(sess=sess, 125 | data_format=conf.data_format, 126 | history_length=conf.history_length, 127 | observation_dims=conf.observation_dims, 128 | output_size=env.env.action_space.n, 129 | network_header_type=conf.network_header_type, 130 | name='pred_network', trainable=True) 131 | target_network = CNN(sess=sess, 132 | data_format=conf.data_format, 133 | history_length=conf.history_length, 134 | observation_dims=conf.observation_dims, 135 | output_size=env.env.action_space.n, 136 | network_header_type=conf.network_header_type, 137 | name='target_network', trainable=False) 138 | 139 | 140 | stat = Statistic(sess, conf.t_test, conf.t_learn_start, model_dir, pred_network.var.values()) 141 | agent = TrainAgent(sess, pred_network, env, stat, conf, target_network=target_network) 142 | 143 | if conf.is_train: 144 | agent.train(conf.t_train_max) 145 | else: 146 | agent.play(conf.ep_end) 147 | 148 | if __name__ == '__main__': 149 | tf.app.run() 150 | -------------------------------------------------------------------------------- /chapter10/statistic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | class Statistic(object): 6 | def __init__(self, sess, t_test, t_learn_start, model_dir, variables, max_to_keep=20): 7 | self.sess = sess 8 | self.t_test = t_test 9 | self.t_learn_start = t_learn_start 10 | 11 | self.reset() 12 | self.max_avg_ep_reward = 0 13 | 14 | with tf.variable_scope('t'): 15 | self.t_op = tf.Variable(0, trainable=False, name='t') 16 | self.t_add_op = self.t_op.assign_add(1) 17 | 18 | self.model_dir = model_dir 19 | self.saver = tf.train.Saver(list(variables) + [self.t_op], max_to_keep=max_to_keep) 20 | 
self.writer = tf.summary.FileWriter('./logs/%s' % self.model_dir, self.sess.graph) 21 | 22 | with tf.variable_scope('summary'): 23 | scalar_summary_tags = [ 24 | 'average/reward', 'average/loss', 'average/q', 25 | 'episode/max_reward', 'episode/min_reward', 'episode/avg_reward', 26 | 'episode/num_of_game', 'training/learning_rate', 'training/epsilon', 27 | ] 28 | 29 | self.summary_placeholders = {} 30 | self.summary_ops = {} 31 | 32 | for tag in scalar_summary_tags: 33 | self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_')) 34 | self.summary_ops[tag] = tf.summary.scalar(tag, self.summary_placeholders[tag]) 35 | 36 | histogram_summary_tags = ['episode/rewards', 'episode/actions'] 37 | 38 | for tag in histogram_summary_tags: 39 | self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_')) 40 | self.summary_ops[tag] = tf.summary.histogram(tag, self.summary_placeholders[tag]) 41 | 42 | 43 | def reset(self): 44 | self.num_game = 0 45 | self.update_count = 0 46 | self.ep_reward = 0. 47 | self.total_loss = 0. 48 | self.total_reward = 0. 49 | self.actions = [] 50 | self.total_q = [] 51 | self.ep_rewards = [] 52 | 53 | def on_step(self, t, action, reward, terminal, 54 | ep, q, loss, is_update, learning_rate_op): 55 | if t >= self.t_learn_start: 56 | self.total_q.extend(q) 57 | self.actions.append(action) 58 | 59 | self.total_loss += loss 60 | self.total_reward += reward 61 | 62 | if terminal: 63 | self.num_game += 1 64 | self.ep_rewards.append(self.ep_reward) 65 | self.ep_reward = 0. 66 | else: 67 | self.ep_reward += reward 68 | 69 | if is_update: 70 | self.update_count += 1 71 | 72 | if t % self.t_test == self.t_test - 1 and self.update_count != 0: 73 | avg_q = np.mean(self.total_q) 74 | avg_loss = self.total_loss / self.update_count 75 | avg_reward = self.total_reward / self.t_test 76 | 77 | try: 78 | max_ep_reward = np.max(self.ep_rewards) 79 | min_ep_reward = np.min(self.ep_rewards) 80 | avg_ep_reward = np.mean(self.ep_rewards) 81 | except: 82 | max_ep_reward, min_ep_reward, avg_ep_reward = 0, 0, 0 83 | 84 | print ('\navg_r: %.4f, avg_l: %.6f, avg_q: %3.6f, avg_ep_r: %.4f, max_ep_r: %.4f, min_ep_r: %.4f, # game: %d' \ 85 | % (avg_reward, avg_loss, avg_q, avg_ep_reward, max_ep_reward, min_ep_reward, self.num_game)) 86 | 87 | if self.max_avg_ep_reward * 0.9 <= avg_ep_reward: 88 | assert t == self.get_t() 89 | 90 | self.save_model(t) 91 | 92 | self.max_avg_ep_reward = max(self.max_avg_ep_reward, avg_ep_reward) 93 | 94 | self.inject_summary({ 95 | 'average/q': avg_q, 96 | 'average/loss': avg_loss, 97 | 'average/reward': avg_reward, 98 | 'episode/max_reward': max_ep_reward, 99 | 'episode/min_reward': min_ep_reward, 100 | 'episode/avg_reward': avg_ep_reward, 101 | 'episode/num_of_game': self.num_game, 102 | 'episode/actions': self.actions, 103 | 'episode/rewards': self.ep_rewards, 104 | 'training/learning_rate': learning_rate_op.eval(session=self.sess), 105 | 'training/epsilon': ep, 106 | }, t) 107 | 108 | self.reset() 109 | 110 | self.t_add_op.eval(session=self.sess) 111 | 112 | def inject_summary(self, tag_dict, t): 113 | summary_str_lists = self.sess.run([self.summary_ops[tag] for tag in tag_dict.keys()], { 114 | self.summary_placeholders[tag]: value for tag, value in tag_dict.items() 115 | }) 116 | for summary_str in summary_str_lists: 117 | self.writer.add_summary(summary_str, t) 118 | 119 | def get_t(self): 120 | return self.t_op.eval(session=self.sess) 121 | 122 | def save_model(self, t): 123 | print(" [*] Saving 
checkpoints...") 124 | model_name = type(self).__name__ 125 | 126 | if not os.path.exists(self.model_dir): 127 | os.makedirs(self.model_dir) 128 | self.saver.save(self.sess, self.model_dir, global_step=t) 129 | 130 | def load_model(self): 131 | ckpt = tf.train.get_checkpoint_state(self.model_dir) 132 | if ckpt and ckpt.model_checkpoint_path: 133 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 134 | fname = os.path.join(self.model_dir, ckpt_name) 135 | self.saver.restore(self.sess, fname) 136 | print(" [*] Load SUCCESS: %s" % fname) 137 | return True 138 | else: 139 | print(" [!] Load FAILED: %s" % self.model_dir) 140 | return False 141 | -------------------------------------------------------------------------------- /chapter2/Breakout.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import time\n", 11 | "from gym import wrappers" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": { 18 | "collapsed": true 19 | }, 20 | "outputs": [ 21 | { 22 | "name": "stderr", 23 | "output_type": "stream", 24 | "text": [ 25 | "[2017-07-05 12:55:41,434] Making new env: Breakout-v0\n", 26 | "[2017-07-05 12:55:41,716] An unexpected error occurred while tokenizing input\n", 27 | "The following traceback may be corrupted or invalid\n", 28 | "The error message is: ('EOF in multi-line string', (56, 105))\n", 29 | "\n" 30 | ] 31 | }, 32 | { 33 | "ename": "Error", 34 | "evalue": "Trying to write to monitor directory /Users/chenzomi with existing monitor files: /Users/chenzomi/openaigym.manifest.0.1871.manifest.json.\n\n You should use a unique directory for each training run, or use 'force=True' to automatically clear previous monitor files.", 35 | "output_type": "error", 36 | "traceback": [ 37 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 38 | "\u001b[0;31mError\u001b[0m Traceback (most recent call last)", 39 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Breakout-v0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwrappers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMonitor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'/Users/chenzomi'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi_episode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mobservation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 40 | "\u001b[0;32m/Users/chenzomi/gym/gym/wrappers/monitoring.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, env, directory, video_callable, force, resume, write_upon_reset, uid, mode)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m 
self._start(directory, video_callable, force, resume,\n\u001b[0;32m---> 29\u001b[0;31m write_upon_reset, uid, mode)\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 41 | "\u001b[0;32m/Users/chenzomi/gym/gym/wrappers/monitoring.py\u001b[0m in \u001b[0;36m_start\u001b[0;34m(self, directory, video_callable, force, resume, write_upon_reset, uid, mode)\u001b[0m\n\u001b[1;32m 97\u001b[0m raise error.Error('''Trying to write to monitor directory {} with existing monitor files: {}.\n\u001b[1;32m 98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m You should use a unique directory for each training run, or use 'force=True' to automatically clear previous monitor files.'''.format(directory, ', '.join(training_manifests[:5])))\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_monitor_id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmonitor_closer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mregister\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 42 | "\u001b[0;31mError\u001b[0m: Trying to write to monitor directory /Users/chenzomi with existing monitor files: /Users/chenzomi/openaigym.manifest.0.1871.manifest.json.\n\n You should use a unique directory for each training run, or use 'force=True' to automatically clear previous monitor files." 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "env = gym.make('Breakout-v0')\n", 48 | "env = wrappers.Monitor(env, '/Users/chenzomi')\n", 49 | "\n", 50 | "for i_episode in range(20):\n", 51 | " observation = env.reset()\n", 52 | " for t in range(100):\n", 53 | " env.render()\n", 54 | " \n", 55 | " time.sleep(2)\n", 56 | " print(observation)\n", 57 | " action = env.action_space.sample()\n", 58 | " observation, reward, done, info = env.step(action)\n", 59 | " if done:\n", 60 | " print(\"Episode finished after {} timesteps\".format(t+1))\n", 61 | " break" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "collapsed": true 69 | }, 70 | "outputs": [], 71 | "source": [] 72 | } 73 | ], 74 | "metadata": { 75 | "kernelspec": { 76 | "display_name": "Python 3", 77 | "language": "python", 78 | "name": "python3" 79 | }, 80 | "language_info": { 81 | "codemirror_mode": { 82 | "name": "ipython", 83 | "version": 3 84 | }, 85 | "file_extension": ".py", 86 | "mimetype": "text/x-python", 87 | "name": "python", 88 | "nbconvert_exporter": "python", 89 | "pygments_lexer": "ipython3", 90 | "version": "3.6.0" 91 | } 92 | }, 93 | "nbformat": 4, 94 | "nbformat_minor": 2 95 | } 96 | -------------------------------------------------------------------------------- /chapter2/Breakout.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import gym 8 | import time 9 | from gym import wrappers 10 | 11 | 12 | # In[2]: 13 | 14 | 15 | env = gym.make('Breakout-v0') 16 | env = wrappers.Monitor(env, '/Users/chenzomi') 17 | 18 | for i_episode in range(20): 19 | observation = env.reset() 20 | for t in range(100): 21 | env.render() 22 | 23 | time.sleep(2) 24 | print(observation) 25 | action = env.action_space.sample() 26 | observation, reward, 
done, info = env.step(action) 27 | if done: 28 | print("Episode finished after {} timesteps".format(t+1)) 29 | break 30 | 31 | -------------------------------------------------------------------------------- /chapter2/Environment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 49, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "Pretty printing has been turned ON\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import sys\n", 21 | "from six import StringIO, b\n", 22 | "from pprint import PrettyPrinter\n", 23 | "%pprint\n", 24 | "\n", 25 | "from gym import utils\n", 26 | "from gym.envs.toy_text import discrete" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 60, 32 | "metadata": { 33 | "collapsed": false 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "pp = PrettyPrinter(indent=2)\n", 38 | "\n", 39 | "UP = 0\n", 40 | "RIGHT = 1\n", 41 | "DOWN = 2\n", 42 | "LEFT = 3\n", 43 | "\n", 44 | "MAPS = {'4x4':[\"SOOO\",\"OXOX\",\"OOOX\",\"XOOG\"]}" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 61, 50 | "metadata": { 51 | "collapsed": false 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "class GridworldEnv(discrete.DiscreteEnv):\n", 56 | " \"\"\"\n", 57 | " FrozenLakeEnv1 is a copy environment from GYM toy_text FrozenLake-01\n", 58 | "\n", 59 | " You are an agent on an 4x4 grid and your goal is to reach the terminal\n", 60 | " state at the bottom right corner.\n", 61 | " \n", 62 | " For example, a 4x4 grid looks as follows:\n", 63 | " \n", 64 | " S O O O\n", 65 | " O X O X\n", 66 | " O O O X\n", 67 | " X O O G\n", 68 | " \n", 69 | " S : starting point, safe\n", 70 | " O : frozen surface, safe\n", 71 | " X : hole, fall to your doom\n", 72 | " G : goal, where the frisbee is located\n", 73 | " \n", 74 | " The episode ends when you reach the goal or fall in a hole.\n", 75 | " You receive a reward of 1 if you reach the goal, and zero otherwise.\n", 76 | " \n", 77 | " You can take actions in each direction (UP=0, RIGHT=1, DOWN=2, LEFT=3).\n", 78 | " Actions going off the edge leave you in your current state.\n", 79 | " \"\"\"\n", 80 | " metadata = {'render.modes': ['human', 'ansi']}\n", 81 | " \n", 82 | " def __init__(self, desc=None, map_name='4x4'):\n", 83 | " self.desc = desc = np.asarray(MAPS[map_name], dtype='c')\n", 84 | " self.nrow, self.ncol = nrow, ncol = desc.shape\n", 85 | " self.shape = desc.shape\n", 86 | " \n", 87 | " nA = 4 # 动作集个数\n", 88 | " nS = np.prod(desc.shape) # 状态集个数\n", 89 | "\n", 90 | " MAX_Y = desc.shape[0]\n", 91 | " MAX_X = desc.shape[1]\n", 92 | "\n", 93 | " # initial state distribution [ 1. 0. 0. ...] 
\n", 94 | " isd = np.array(desc == b'S').astype('float64').ravel()\n", 95 | " isd /= isd.sum()\n", 96 | " \n", 97 | " P = {} \n", 98 | " state_grid = np.arange(nS).reshape(self.shape)\n", 99 | " it = np.nditer(state_grid, flags=['multi_index'])\n", 100 | " \n", 101 | " while not it.finished:\n", 102 | " s = it.iterindex\n", 103 | " y, x = it.multi_index\n", 104 | "\n", 105 | " # P[s][a] == [(probability, nextstate, reward, done), ...]\n", 106 | " P[s] = {a : [] for a in range(nA)}\n", 107 | "\n", 108 | " s_letter = desc[y][x]\n", 109 | " is_done = lambda letter: letter in b'GX'\n", 110 | " reward = 0.0 if s_letter in b'G' else -1.0\n", 111 | " \n", 112 | " if is_done(s_letter):\n", 113 | " P[s][UP] = [(1.0, s, reward, True)]\n", 114 | " P[s][RIGHT] = [(1.0, s, reward, True)]\n", 115 | " P[s][DOWN] = [(1.0, s, reward, True)]\n", 116 | " P[s][LEFT] = [(1.0, s, reward, True)]\n", 117 | " else:\n", 118 | " ns_up = s if y == 0 else s - MAX_X\n", 119 | " ns_right = s if x == (MAX_X - 1) else s + 1\n", 120 | " ns_down = s if y == (MAX_Y - 1) else s + MAX_X\n", 121 | " ns_left = s if x == 0 else s - 1\n", 122 | "\n", 123 | " sl_up = desc[ns_up//MAX_Y][ns_up%MAX_X]\n", 124 | " sl_right = desc[ns_right//MAX_Y][ns_right%MAX_X]\n", 125 | " sl_down = desc[ns_down//MAX_Y][ns_down%MAX_X]\n", 126 | " sl_left = desc[ns_left//MAX_Y][ns_left%MAX_X]\n", 127 | " \n", 128 | " P[s][UP] = [(1.0, ns_up, reward, is_done(sl_up))]\n", 129 | " P[s][RIGHT] = [(1.0, ns_right, reward, is_done(sl_right))]\n", 130 | " P[s][DOWN] = [(1.0, ns_down, reward, is_done(sl_down))]\n", 131 | " P[s][LEFT] = [(1.0, ns_left, reward, is_done(sl_left))]\n", 132 | " \n", 133 | " it.iternext()\n", 134 | " \n", 135 | " self.P = P\n", 136 | " \n", 137 | " super(GridworldEnv, self).__init__(nS, nA, P, isd)\n", 138 | "\n", 139 | " def _render(self, mode='human', close=False):\n", 140 | " if close: # 初始化环境Environment的时候不显示\n", 141 | " return\n", 142 | " \n", 143 | " outfile = StringIO() if mode == 'ansi' else sys.stdout\n", 144 | "\n", 145 | " desc = self.desc.tolist()\n", 146 | " desc = [[c.decode('utf-8') for c in line] for line in desc]\n", 147 | " \n", 148 | " state_grid = np.arange(self.nS).reshape(self.shape)\n", 149 | " it = np.nditer(state_grid, flags=['multi_index'])\n", 150 | " \n", 151 | " while not it.finished:\n", 152 | " s = it.iterindex\n", 153 | " y, x = it.multi_index\n", 154 | " \n", 155 | " # 对于当前状态用红色标注\n", 156 | " if self.s == s:\n", 157 | " desc[y][x] = utils.colorize(desc[y][x], \"red\", highlight=True)\n", 158 | " \n", 159 | " it.iternext()\n", 160 | " \n", 161 | " outfile.write(\"\\n\".join(' '.join(line) for line in desc)+\"\\n\")\n", 162 | "\n", 163 | " if mode != 'human':\n", 164 | " return outfile\n", 165 | " \n", 166 | "env = GridworldEnv()" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 62, 172 | "metadata": { 173 | "collapsed": false 174 | }, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "\u001b[41mS\u001b[0m O O O\n", 181 | "O X O X\n", 182 | "O O O X\n", 183 | "X O O G\n", 184 | "action:0(Up)\n", 185 | "done:False, observation:0, reward:-1.0\n", 186 | "\u001b[41mS\u001b[0m O O O\n", 187 | "O X O X\n", 188 | "O O O X\n", 189 | "X O O G\n", 190 | "action:1(Right)\n", 191 | "done:False, observation:1, reward:-1.0\n", 192 | "S \u001b[41mO\u001b[0m O O\n", 193 | "O X O X\n", 194 | "O O O X\n", 195 | "X O O G\n", 196 | "action:2(Down)\n", 197 | "done:True, observation:5, reward:-1.0\n", 198 | "{ 0: { 0: [(1.0, 0, -1.0, 
False)],\n", 199 | " 1: [(1.0, 1, -1.0, False)],\n", 200 | " 2: [(1.0, 4, -1.0, False)],\n", 201 | " 3: [(1.0, 0, -1.0, False)]},\n", 202 | " 1: { 0: [(1.0, 1, -1.0, False)],\n", 203 | " 1: [(1.0, 2, -1.0, False)],\n", 204 | " 2: [(1.0, 5, -1.0, True)],\n", 205 | " 3: [(1.0, 0, -1.0, False)]},\n", 206 | " 2: { 0: [(1.0, 2, -1.0, False)],\n", 207 | " 1: [(1.0, 3, -1.0, False)],\n", 208 | " 2: [(1.0, 6, -1.0, False)],\n", 209 | " 3: [(1.0, 1, -1.0, False)]},\n", 210 | " 3: { 0: [(1.0, 3, -1.0, False)],\n", 211 | " 1: [(1.0, 3, -1.0, False)],\n", 212 | " 2: [(1.0, 7, -1.0, True)],\n", 213 | " 3: [(1.0, 2, -1.0, False)]},\n", 214 | " 4: { 0: [(1.0, 0, -1.0, False)],\n", 215 | " 1: [(1.0, 5, -1.0, True)],\n", 216 | " 2: [(1.0, 8, -1.0, False)],\n", 217 | " 3: [(1.0, 4, -1.0, False)]},\n", 218 | " 5: { 0: [(1.0, 5, -1.0, True)],\n", 219 | " 1: [(1.0, 5, -1.0, True)],\n", 220 | " 2: [(1.0, 5, -1.0, True)],\n", 221 | " 3: [(1.0, 5, -1.0, True)]},\n", 222 | " 6: { 0: [(1.0, 2, -1.0, False)],\n", 223 | " 1: [(1.0, 7, -1.0, True)],\n", 224 | " 2: [(1.0, 10, -1.0, False)],\n", 225 | " 3: [(1.0, 5, -1.0, True)]},\n", 226 | " 7: { 0: [(1.0, 7, -1.0, True)],\n", 227 | " 1: [(1.0, 7, -1.0, True)],\n", 228 | " 2: [(1.0, 7, -1.0, True)],\n", 229 | " 3: [(1.0, 7, -1.0, True)]},\n", 230 | " 8: { 0: [(1.0, 4, -1.0, False)],\n", 231 | " 1: [(1.0, 9, -1.0, False)],\n", 232 | " 2: [(1.0, 12, -1.0, True)],\n", 233 | " 3: [(1.0, 8, -1.0, False)]},\n", 234 | " 9: { 0: [(1.0, 5, -1.0, True)],\n", 235 | " 1: [(1.0, 10, -1.0, False)],\n", 236 | " 2: [(1.0, 13, -1.0, False)],\n", 237 | " 3: [(1.0, 8, -1.0, False)]},\n", 238 | " 10: { 0: [(1.0, 6, -1.0, False)],\n", 239 | " 1: [(1.0, 11, -1.0, True)],\n", 240 | " 2: [(1.0, 14, -1.0, False)],\n", 241 | " 3: [(1.0, 9, -1.0, False)]},\n", 242 | " 11: { 0: [(1.0, 11, -1.0, True)],\n", 243 | " 1: [(1.0, 11, -1.0, True)],\n", 244 | " 2: [(1.0, 11, -1.0, True)],\n", 245 | " 3: [(1.0, 11, -1.0, True)]},\n", 246 | " 12: { 0: [(1.0, 12, -1.0, True)],\n", 247 | " 1: [(1.0, 12, -1.0, True)],\n", 248 | " 2: [(1.0, 12, -1.0, True)],\n", 249 | " 3: [(1.0, 12, -1.0, True)]},\n", 250 | " 13: { 0: [(1.0, 9, -1.0, False)],\n", 251 | " 1: [(1.0, 14, -1.0, False)],\n", 252 | " 2: [(1.0, 13, -1.0, False)],\n", 253 | " 3: [(1.0, 12, -1.0, True)]},\n", 254 | " 14: { 0: [(1.0, 10, -1.0, False)],\n", 255 | " 1: [(1.0, 15, -1.0, True)],\n", 256 | " 2: [(1.0, 14, -1.0, False)],\n", 257 | " 3: [(1.0, 13, -1.0, False)]},\n", 258 | " 15: { 0: [(1.0, 15, 0.0, True)],\n", 259 | " 1: [(1.0, 15, 0.0, True)],\n", 260 | " 2: [(1.0, 15, 0.0, True)],\n", 261 | " 3: [(1.0, 15, 0.0, True)]}}\n", 262 | "Episode finished after 3 timesteps\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "observation = env.reset()\n", 268 | "for _ in range(5):\n", 269 | " env.render()\n", 270 | " action = env.action_space.sample()\n", 271 | " observation, reward, done, info = env.step(action)\n", 272 | " print(\"action:{}({})\".format(action, [\"Up\",\"Right\",\"Down\",\"Left\"][action]))\n", 273 | " print(\"done:{}, observation:{}, reward:{}\".format(done, observation, reward))\n", 274 | " if done:\n", 275 | " pp.pprint(env.P)\n", 276 | " print(\"Episode finished after {} timesteps\".format(_+1))\n", 277 | " break" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": { 284 | "collapsed": true 285 | }, 286 | "outputs": [], 287 | "source": [] 288 | } 289 | ], 290 | "metadata": { 291 | "kernelspec": { 292 | "display_name": "Python 3", 293 | "language": "python", 294 | "name": 
"python3" 295 | }, 296 | "language_info": { 297 | "codemirror_mode": { 298 | "name": "ipython", 299 | "version": 3 300 | }, 301 | "file_extension": ".py", 302 | "mimetype": "text/x-python", 303 | "name": "python", 304 | "nbconvert_exporter": "python", 305 | "pygments_lexer": "ipython3", 306 | "version": "3.6.0" 307 | } 308 | }, 309 | "nbformat": 4, 310 | "nbformat_minor": 2 311 | } 312 | -------------------------------------------------------------------------------- /chapter2/Environment.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[49]: 5 | 6 | 7 | import numpy as np 8 | import sys 9 | from six import StringIO, b 10 | from pprint import PrettyPrinter 11 | get_ipython().run_line_magic('pprint', '') 12 | 13 | from gym import utils 14 | from gym.envs.toy_text import discrete 15 | 16 | 17 | # In[60]: 18 | 19 | 20 | pp = PrettyPrinter(indent=2) 21 | 22 | UP = 0 23 | RIGHT = 1 24 | DOWN = 2 25 | LEFT = 3 26 | 27 | MAPS = {'4x4':["SOOO","OXOX","OOOX","XOOG"]} 28 | 29 | 30 | # In[61]: 31 | 32 | 33 | class GridworldEnv(discrete.DiscreteEnv): 34 | """ 35 | FrozenLakeEnv1 is a copy environment from GYM toy_text FrozenLake-01 36 | 37 | You are an agent on an 4x4 grid and your goal is to reach the terminal 38 | state at the bottom right corner. 39 | 40 | For example, a 4x4 grid looks as follows: 41 | 42 | S O O O 43 | O X O X 44 | O O O X 45 | X O O G 46 | 47 | S : starting point, safe 48 | O : frozen surface, safe 49 | X : hole, fall to your doom 50 | G : goal, where the frisbee is located 51 | 52 | The episode ends when you reach the goal or fall in a hole. 53 | You receive a reward of 1 if you reach the goal, and zero otherwise. 54 | 55 | You can take actions in each direction (UP=0, RIGHT=1, DOWN=2, LEFT=3). 56 | Actions going off the edge leave you in your current state. 57 | """ 58 | metadata = {'render.modes': ['human', 'ansi']} 59 | 60 | def __init__(self, desc=None, map_name='4x4'): 61 | self.desc = desc = np.asarray(MAPS[map_name], dtype='c') 62 | self.nrow, self.ncol = nrow, ncol = desc.shape 63 | self.shape = desc.shape 64 | 65 | nA = 4 # 动作集个数 66 | nS = np.prod(desc.shape) # 状态集个数 67 | 68 | MAX_Y = desc.shape[0] 69 | MAX_X = desc.shape[1] 70 | 71 | # initial state distribution [ 1. 0. 0. ...] 72 | isd = np.array(desc == b'S').astype('float64').ravel() 73 | isd /= isd.sum() 74 | 75 | P = {} 76 | state_grid = np.arange(nS).reshape(self.shape) 77 | it = np.nditer(state_grid, flags=['multi_index']) 78 | 79 | while not it.finished: 80 | s = it.iterindex 81 | y, x = it.multi_index 82 | 83 | # P[s][a] == [(probability, nextstate, reward, done), ...] 
84 | P[s] = {a : [] for a in range(nA)} 85 | 86 | s_letter = desc[y][x] 87 | is_done = lambda letter: letter in b'GX' 88 | reward = 0.0 if s_letter in b'G' else -1.0 89 | 90 | if is_done(s_letter): 91 | P[s][UP] = [(1.0, s, reward, True)] 92 | P[s][RIGHT] = [(1.0, s, reward, True)] 93 | P[s][DOWN] = [(1.0, s, reward, True)] 94 | P[s][LEFT] = [(1.0, s, reward, True)] 95 | else: 96 | ns_up = s if y == 0 else s - MAX_X 97 | ns_right = s if x == (MAX_X - 1) else s + 1 98 | ns_down = s if y == (MAX_Y - 1) else s + MAX_X 99 | ns_left = s if x == 0 else s - 1 100 | 101 | sl_up = desc[ns_up//MAX_Y][ns_up%MAX_X] 102 | sl_right = desc[ns_right//MAX_Y][ns_right%MAX_X] 103 | sl_down = desc[ns_down//MAX_Y][ns_down%MAX_X] 104 | sl_left = desc[ns_left//MAX_Y][ns_left%MAX_X] 105 | 106 | P[s][UP] = [(1.0, ns_up, reward, is_done(sl_up))] 107 | P[s][RIGHT] = [(1.0, ns_right, reward, is_done(sl_right))] 108 | P[s][DOWN] = [(1.0, ns_down, reward, is_done(sl_down))] 109 | P[s][LEFT] = [(1.0, ns_left, reward, is_done(sl_left))] 110 | 111 | it.iternext() 112 | 113 | self.P = P 114 | 115 | super(GridworldEnv, self).__init__(nS, nA, P, isd) 116 | 117 | def _render(self, mode='human', close=False): 118 | if close: # 初始化环境Environment的时候不显示 119 | return 120 | 121 | outfile = StringIO() if mode == 'ansi' else sys.stdout 122 | 123 | desc = self.desc.tolist() 124 | desc = [[c.decode('utf-8') for c in line] for line in desc] 125 | 126 | state_grid = np.arange(self.nS).reshape(self.shape) 127 | it = np.nditer(state_grid, flags=['multi_index']) 128 | 129 | while not it.finished: 130 | s = it.iterindex 131 | y, x = it.multi_index 132 | 133 | # 对于当前状态用红色标注 134 | if self.s == s: 135 | desc[y][x] = utils.colorize(desc[y][x], "red", highlight=True) 136 | 137 | it.iternext() 138 | 139 | outfile.write("\n".join(' '.join(line) for line in desc)+"\n") 140 | 141 | if mode != 'human': 142 | return outfile 143 | 144 | env = GridworldEnv() 145 | 146 | 147 | # In[62]: 148 | 149 | 150 | observation = env.reset() 151 | for _ in range(5): 152 | env.render() 153 | action = env.action_space.sample() 154 | observation, reward, done, info = env.step(action) 155 | print("action:{}({})".format(action, ["Up","Right","Down","Left"][action])) 156 | print("done:{}, observation:{}, reward:{}".format(done, observation, reward)) 157 | if done: 158 | pp.pprint(env.P) 159 | print("Episode finished after {} timesteps".format(_+1)) 160 | break 161 | 162 | -------------------------------------------------------------------------------- /chapter2/greedy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import random" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 17, 18 | "metadata": { 19 | "collapsed": false 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "def epsilon_greedy(nA, R, T, epsilon=0.6):\n", 24 | " \"\"\"\n", 25 | " 输入:\n", 26 | " nA 动作数量\n", 27 | " R 奖励函数\n", 28 | " T 迭代次数\n", 29 | " \"\"\"\n", 30 | " # 初始化累积奖励 r、动作计数 count 与动作价值估计 q_value\n", 31 | " r = 0 \n", 32 | " count, q_value = [0]*nA, np.zeros(nA)\n", 33 | " \n", 34 | " for _ in range(T):\n", 35 | " if np.random.rand() < epsilon:\n", 36 | " # 探索:以均匀分布随机选择\n", 37 | " a = np.random.randint(q_value.shape[0])\n", 38 | " else:\n", 39 | " # 利用:选择价值函数最大的动作\n", 40 | " a = np.argmax(q_value[:])\n", 41 | " \n", 42 | " # 更新累积奖励和价值函数\n", 43 | " v = R(a)\n", 44 | " r = r + v\n", 45 | " 
q_value[a] = (q_value[a] * count[a] + v)/(count[a]+1)\n", 46 | " count[a] += 1\n", 47 | " \n", 48 | " return r" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [] 59 | } 60 | ], 61 | "metadata": { 62 | "kernelspec": { 63 | "display_name": "Python 3", 64 | "language": "python", 65 | "name": "python3" 66 | }, 67 | "language_info": { 68 | "codemirror_mode": { 69 | "name": "ipython", 70 | "version": 3 71 | }, 72 | "file_extension": ".py", 73 | "mimetype": "text/x-python", 74 | "name": "python", 75 | "nbconvert_exporter": "python", 76 | "pygments_lexer": "ipython3", 77 | "version": "3.6.0" 78 | } 79 | }, 80 | "nbformat": 4, 81 | "nbformat_minor": 2 82 | } 83 | -------------------------------------------------------------------------------- /chapter2/greedy.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[8]: 5 | 6 | 7 | import numpy as np 8 | import random 9 | 10 | 11 | # In[17]: 12 | 13 | 14 | def epsilon_greedy(nA, R, T, epsilon=0.6): 15 | """ 16 | 输入: 17 | nA 动作数量 18 | R 奖励函数 19 | T 迭代次数 20 | """ 21 | # 初始化累积奖励 r、动作计数 count 与动作价值估计 q_value 22 | r = 0 23 | count, q_value = [0]*nA, np.zeros(nA) 24 | 25 | for _ in range(T): 26 | if np.random.rand() < epsilon: 27 | # 探索:以均匀分布随机选择 28 | a = np.random.randint(q_value.shape[0]) 29 | else: 30 | # 利用:选择价值函数最大的动作 31 | a = np.argmax(q_value[:]) 32 | 33 | # 更新累积奖励和价值函数 34 | v = R(a) 35 | r = r + v 36 | q_value[a] = (q_value[a] * count[a] + v)/(count[a]+1) 37 | count[a] += 1 38 | 39 | return r 40 | 41 | -------------------------------------------------------------------------------- /chapter3/Policy Evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "Pretty printing has been turned OFF\n", 15 | "Pretty printing has been turned ON\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import pprint\n", 22 | "from Environment import GridworldEnv\n", 23 | "from pprint import PrettyPrinter\n", 24 | "\n", 25 | "%pprint\n", 26 | "pp = PrettyPrinter(indent=4)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 32, 32 | "metadata": { 33 | "collapsed": false 34 | }, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "[[ 0.25 0.25 0.25 0.25]\n", 41 | " [ 0.25 0.25 0.25 0.25]\n", 42 | " [ 0.25 0.25 0.25 0.25]\n", 43 | " [ 0.25 0.25 0.25 0.25]\n", 44 | " [ 0.25 0.25 0.25 0.25]\n", 45 | " [ 0.25 0.25 0.25 0.25]\n", 46 | " [ 0.25 0.25 0.25 0.25]\n", 47 | " [ 0.25 0.25 0.25 0.25]\n", 48 | " [ 0.25 0.25 0.25 0.25]\n", 49 | " [ 0.25 0.25 0.25 0.25]\n", 50 | " [ 0.25 0.25 0.25 0.25]\n", 51 | " [ 0.25 0.25 0.25 0.25]\n", 52 | " [ 0.25 0.25 0.25 0.25]\n", 53 | " [ 0.25 0.25 0.25 0.25]\n", 54 | " [ 0.25 0.25 0.25 0.25]\n", 55 | " [ 0.25 0.25 0.25 0.25]]\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "env = GridworldEnv()\n", 61 | "random_policy = np.ones([env.nS, env.nA])/env.nA\n", 62 | "print(random_policy)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 41, 68 | "metadata": { 69 | "collapsed": true 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "def policy_eval(policy, environment, discount_factor=1.0, theta=1.0):\n", 74 | " env = environment # 环境变量\n", 75
| " \n", 76 | " # 初始化一个全0的价值函数\n", 77 | " V = np.zeros(env.nS)\n", 78 | " \n", 79 | " # 迭代开始\n", 80 | " for _ in range(10000):\n", 81 | " delta = 0\n", 82 | " \n", 83 | " # 对于GridWorld中的每一个状态都进行全备份\n", 84 | " for s in range(env.nS):\n", 85 | " v = 0\n", 86 | " # 检查下一个有可能执行的动作\n", 87 | " for a, action_prob in enumerate(policy[s]):\n", 88 | " \n", 89 | " # 对于每一个动作检查下一个状态\n", 90 | " for prob, next_state, reward, done in env.P[s][a]:\n", 91 | " # 累积计算下一个动作的期望价值\n", 92 | " v += action_prob * prob * (reward + discount_factor * V[next_state])\n", 93 | " \n", 94 | " # 选出最大的变化量\n", 95 | " delta = max(delta, np.abs(v - V[s]))\n", 96 | " V[s] = v\n", 97 | " \n", 98 | " print(\"=\"*60, _)\n", 99 | " print(V.reshape(env.shape))\n", 100 | " \n", 101 | " # 停止标志位\n", 102 | " if delta <= theta:\n", 103 | " break\n", 104 | " \n", 105 | " return np.array(V)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": { 112 | "collapsed": false 113 | }, 114 | "outputs": [ 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "============================================================ 0\n", 120 | "[[-1. -1.25 -1.3125 -1.328125 ]\n", 121 | " [-1.25 -1. -1.578125 -1. ]\n", 122 | " [-1.3125 -1.578125 -1.7890625 -1. ]\n", 123 | " [-1. -1.64453125 -1.85839844 1. ]]\n", 124 | "============================================================ 1\n", 125 | "[[-2.125 -2.421875 -2.66015625 -2.57910156]\n", 126 | " [-2.421875 -2. -2.86230469 -2. ]\n", 127 | " [-2.578125 -3.00292969 -3.1809082 -2. ]\n", 128 | " [-2. -3.12646484 -2.79144287 2. ]]\n", 129 | "============================================================ 2\n", 130 | "[[-3.2734375 -3.58886719 -3.92260742 -3.77020264]\n", 131 | " [-3.56835938 -3. -4.02587891 -3. ]\n", 132 | " [-3.78735352 -4.27368164 -4.27275085 -3. ]\n", 133 | " [-3. -4.29789734 -3.34052277 3. ]]\n", 134 | "============================================================ 3\n", 135 | "[[-4.42602539 -4.734375 -5.11326599 -4.91341782]\n", 136 | " [-4.69543457 -4. -5.09650421 -4. ]\n", 137 | " [-4.93911743 -5.37744141 -5.2036171 -4. ]\n", 138 | " [-4. -5.25396538 -3.69952631 4. ]]\n", 139 | "============================================================ 4\n", 140 | "[[-5.57046509 -5.85452652 -6.24442863 -6.01781607]\n", 141 | " [-5.80125427 -5. -6.11201143 -5. ]\n", 142 | " [-6.02945328 -6.37175894 -6.04582417 -5. ]\n", 143 | " [-5. -6.08131266 -3.95666578 5. ]]\n", 144 | "============================================================ 5\n", 145 | "[[-6.69917774 -6.94953322 -7.33094734 -7.09164487]\n", 146 | " [-6.88247132 -6. -7.09419288 -6. ]\n", 147 | " [-7.07092088 -7.29951443 -6.83759327 -6. ]\n", 148 | " [-6. -6.83437322 -4.15715807 6. ]]\n", 149 | "============================================================ 6\n", 150 | "[[-7.80759001 -8.02201764 -8.38470068 -8.1419976 ]\n", 151 | " [-7.94024555 -7. -8.05557349 -7. ]\n", 152 | " [-8.07767022 -8.18740918 -7.60003518 -7. ]\n", 153 | " [-7. -7.54473512 -4.32548209 7. ]]\n", 154 | "============================================================ 7\n", 155 | "[[-8.8943608 -9.07526978 -9.41438539 -9.17459515]\n", 156 | " [-8.97806914 -8. -9.00360514 -8. ]\n", 157 | " [-9.06078713 -9.05138936 -8.34511915 -8. ]\n", 158 | " [-8. -8.23040164 -4.47525072 8. ]]\n", 159 | "============================================================ 8\n", 160 | "[[ -9.96051513 -10.11254258 -10.42628206 -10.19386809]\n", 161 | " [ -9.99984285 -9. -9.9428503 -9. ]\n", 162 | " [-10.02800484 -9.90088141 -9.07974561 -9. 
]\n", 163 | " [ -9. -8.90163344 -4.61415744 9. ]]\n", 164 | "============================================================ 9\n", 165 | "[[-11.00835392 -11.13679464 -11.42494877 -11.20317124]\n", 166 | " [-11.0090504 -10. -10.8761736 -10. ]\n", 167 | " [-10.98448416 -10.7414658 -9.80794921 -10. ]\n", 168 | " [-10. -9.56431417 -4.74660521 10. ]]\n", 169 | "============================================================ 10\n", 170 | "[[-12.04063822 -12.15059541 -12.41372225 -12.20501618]\n", 171 | " [-12.0085432 -11. -11.80541787 -11. ]\n", 172 | " [-11.93362329 -11.57647167 -10.53212369 -11. ]\n", 173 | " [-11. -10.22184776 -4.87514416 11. ]]\n", 174 | "============================================================ 11\n", 175 | "[[-13.06010376 -13.15610536 -13.39506542 -13.20127445]\n", 176 | " [-13.00056756 -12. -12.73179728 -12. ]\n", 177 | " [-12.87766563 -12.40790927 -11.25371268 -12. ]\n", 178 | " [-12. -10.8762253 -5.00127053 12. ]]\n", 179 | "============================================================ 12\n", 180 | "[[-14.06922011 -14.15509772 -14.37080871 -14.1933394 ]\n", 181 | " [-13.98686333 -13. -13.65613035 -13. ]\n", 182 | " [-13.81810956 -13.23701188 -11.97360319 -13. ]\n", 183 | " [-13. -11.52862693 -5.12587516 13. ]]\n", 184 | "============================================================ 13\n", 185 | "[[-15.07010032 -15.14900169 -15.34232004 -15.18224971]\n", 186 | " [-14.9687683 -14. -14.57898081 -14. ]\n", 187 | " [-14.75597243 -14.06455064 -12.69235165 -14. ]\n", 188 | " [-14. -12.17976318 -5.2494975 14. ]]\n", 189 | "============================================================ 14\n", 190 | "[[-16.06449266 -16.1389536 -16.31062604 -16.16878136]\n", 191 | " [-15.94730835 -15. -15.50074442 -15. ]\n", 192 | " [-15.69195786 -14.89101817 -13.41031502 -15. ]\n", 193 | " [-15. -12.83006971 -5.37247056 15. ]]\n", 194 | "Reshaped Grid Value Function:\n", 195 | "[[-16.06449266 -16.1389536 -16.31062604 -16.16878136]\n", 196 | " [-15.94730835 -15. -15.50074442 -15. ]\n", 197 | " [-15.69195786 -14.89101817 -13.41031502 -15. ]\n", 198 | " [-15. -12.83006971 -5.37247056 15. 
]]\n", 199 | "\n" 200 | ] 201 | } 202 | ], 203 | "source": [ 204 | "v = policy_eval(random_policy, env)\n", 205 | "print(\"Reshaped Grid Value Function:\")\n", 206 | "print(v.reshape(env.shape))\n", 207 | "print(\"\")" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 21, 213 | "metadata": { 214 | "collapsed": false 215 | }, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "[array([2, 1, 3, 1, 2, 0, 1, 0, 1, 1, 0, 1, 0, 3, 2, 3]),\n", 222 | " array([0, 3, 1, 1, 0, 3, 3, 1, 2, 1, 0, 1, 2, 1, 2, 1]),\n", 223 | " array([2, 2, 2, 1, 2, 3, 1, 1, 1, 1, 1, 1, 2, 0, 3, 1]),\n", 224 | " array([3, 2, 0, 0, 2, 1, 1, 2, 0, 1, 3, 3, 0, 3, 0, 1]),\n", 225 | " array([1, 1, 1, 1, 1, 1, 3, 1, 0, 3, 1, 0, 0, 2, 3, 2]),\n", 226 | " array([0, 0, 0, 0, 0, 2, 3, 1, 1, 3, 2, 1, 3, 1, 0, 3]),\n", 227 | " array([2, 1, 3, 3, 2, 3, 2, 3, 2, 3, 3, 1, 0, 3, 0, 3]),\n", 228 | " array([2, 2, 2, 1, 3, 1, 0, 1, 3, 0, 2, 3, 3, 1, 0, 3]),\n", 229 | " array([0, 1, 2, 3, 0, 2, 2, 0, 0, 2, 2, 2, 3, 3, 1, 3]),\n", 230 | " array([3, 1, 2, 0, 0, 3, 0, 0, 3, 1, 3, 3, 0, 2, 1, 1])]\n" 231 | ] 232 | } 233 | ], 234 | "source": [ 235 | "def gen_random_policy():\n", 236 | " return np.random.choice(4, size=((16)))\n", 237 | "\n", 238 | "n_policy = 100\n", 239 | "policy_pop = [gen_random_policy() for _ in range(n_policy)]\n", 240 | "pprint.pprint(policy_pop[:10])" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | "# Test the Spesfiy Input Policy\n", 248 | "\n", 249 | "if I input a speciay policy what the value funciton calcuate by the policy evalutaiton." 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 9, 255 | "metadata": { 256 | "collapsed": false 257 | }, 258 | "outputs": [ 259 | { 260 | "name": "stdout", 261 | "output_type": "stream", 262 | "text": [ 263 | "[[ 0. 0. 1. 0.]\n", 264 | " [ 0. 1. 0. 0.]\n", 265 | " [ 0. 0. 1. 0.]\n", 266 | " [ 0. 0. 0. 1.]\n", 267 | " [ 0. 0. 1. 0.]\n", 268 | " [ 1. 0. 0. 0.]\n", 269 | " [ 0. 0. 1. 0.]\n", 270 | " [ 1. 0. 0. 0.]\n", 271 | " [ 0. 1. 0. 0.]\n", 272 | " [ 0. 0. 1. 0.]\n", 273 | " [ 0. 0. 1. 0.]\n", 274 | " [ 1. 0. 0. 0.]\n", 275 | " [ 1. 0. 0. 0.]\n", 276 | " [ 0. 1. 0. 0.]\n", 277 | " [ 0. 1. 0. 0.]\n", 278 | " [ 1. 0. 0. 
0.]]\n" 279 | ] 280 | } 281 | ], 282 | "source": [ 283 | "input_policy = [2,1,2,3,2,0,2,0,1,2,2,0,0,1,1,0]\n", 284 | "\n", 285 | "env = GridworldEnv()\n", 286 | "policy = np.zeros([env.nS, env.nA])\n", 287 | "\n", 288 | "for _, x in enumerate(input_policy):\n", 289 | " policy[_][x] = 1\n", 290 | " \n", 291 | "print(policy)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 7, 297 | "metadata": { 298 | "collapsed": true 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "def policy_eval(policy, environment, discount_factor=1.0, theta=0.1):\n", 303 | " env = environment # 环境变量\n", 304 | " \n", 305 | " # 初始化一个全0的价值函数\n", 306 | " V = np.zeros(env.nS)\n", 307 | " \n", 308 | " # 迭代开始\n", 309 | " for _ in range(50):\n", 310 | " delta = 0\n", 311 | " \n", 312 | " # 对于GridWorld中的每一个状态都进行全备份\n", 313 | " for s in range(env.nS):\n", 314 | " v = 0\n", 315 | " # 检查下一个有可能执行的动作\n", 316 | " for a, action_prob in enumerate(policy[s]):\n", 317 | " \n", 318 | " # 对于每一个动作检查下一个状态\n", 319 | " for prob, next_state, reward, done in env.P[s][a]:\n", 320 | " # 累积计算下一个动作的期望价值\n", 321 | " v += action_prob * prob * (reward + discount_factor * V[next_state])\n", 322 | " \n", 323 | " # 选出最大的变化量\n", 324 | " delta = max(delta, np.abs(v - V[s]))\n", 325 | " V[s] = v\n", 326 | " \n", 327 | " print(\"=\"*60, _)\n", 328 | " print(V.reshape(env.shape))\n", 329 | " \n", 330 | " # 停止标志位\n", 331 | " if delta <= theta:\n", 332 | " break\n", 333 | " \n", 334 | " return np.array(V)" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 8, 340 | "metadata": { 341 | "collapsed": false 342 | }, 343 | "outputs": [ 344 | { 345 | "name": "stdout", 346 | "output_type": "stream", 347 | "text": [ 348 | "============================================================ 0\n", 349 | "[[-1. -2. -1. -1.]\n", 350 | " [-1. -1. -1. -1.]\n", 351 | " [-1. -1. -1. -1.]\n", 352 | " [-1. -1. -1. 1.]]\n", 353 | "============================================================ 1\n", 354 | "[[-2. -3. -2. -2.]\n", 355 | " [-2. -2. -2. -2.]\n", 356 | " [-2. -2. -2. -2.]\n", 357 | " [-2. -2. 0. 2.]]\n", 358 | "============================================================ 2\n", 359 | "[[-3. -4. -3. -3.]\n", 360 | " [-3. -3. -3. -3.]\n", 361 | " [-3. -3. -1. -3.]\n", 362 | " [-3. -1. 1. 3.]]\n", 363 | "============================================================ 3\n", 364 | "[[-4. -5. -4. -4.]\n", 365 | " [-4. -4. -2. -4.]\n", 366 | " [-4. -2. 0. -4.]\n", 367 | " [-4. 0. 2. 4.]]\n", 368 | "============================================================ 4\n", 369 | "[[-5. -6. -3. -5.]\n", 370 | " [-5. -5. -1. -5.]\n", 371 | " [-3. -1. 1. -5.]\n", 372 | " [-5. 1. 3. 5.]]\n", 373 | "============================================================ 5\n", 374 | "[[-6. -7. -2. -6.]\n", 375 | " [-4. -6. 0. -6.]\n", 376 | " [-2. 0. 2. -6.]\n", 377 | " [-6. 2. 4. 6.]]\n", 378 | "============================================================ 6\n", 379 | "[[-5. -6. -1. -7.]\n", 380 | " [-3. -7. 1. -7.]\n", 381 | " [-1. 1. 3. -7.]\n", 382 | " [-7. 3. 5. 7.]]\n", 383 | "============================================================ 7\n", 384 | "[[-4. -5. 0. -8.]\n", 385 | " [-2. -8. 2. -8.]\n", 386 | " [ 0. 2. 4. -8.]\n", 387 | " [-8. 4. 6. 8.]]\n", 388 | "============================================================ 8\n", 389 | "[[-3. -4. 1. -9.]\n", 390 | " [-1. -9. 3. -9.]\n", 391 | " [ 1. 3. 5. -9.]\n", 392 | " [-9. 5. 7. 9.]]\n", 393 | "============================================================ 9\n", 394 | "[[ -2. -3. 2. 
-10.]\n", 395 | " [ 0. -10. 4. -10.]\n", 396 | " [ 2. 4. 6. -10.]\n", 397 | " [-10. 6. 8. 10.]]\n", 398 | "============================================================ 10\n", 399 | "[[ -1. -2. 3. -11.]\n", 400 | " [ 1. -11. 5. -11.]\n", 401 | " [ 3. 5. 7. -11.]\n", 402 | " [-11. 7. 9. 11.]]\n", 403 | "============================================================ 11\n", 404 | "[[ 0. -1. 4. -12.]\n", 405 | " [ 2. -12. 6. -12.]\n", 406 | " [ 4. 6. 8. -12.]\n", 407 | " [-12. 8. 10. 12.]]\n", 408 | "============================================================ 12\n", 409 | "[[ 1. 0. 5. -13.]\n", 410 | " [ 3. -13. 7. -13.]\n", 411 | " [ 5. 7. 9. -13.]\n", 412 | " [-13. 9. 11. 13.]]\n", 413 | "============================================================ 13\n", 414 | "[[ 2. 1. 6. -14.]\n", 415 | " [ 4. -14. 8. -14.]\n", 416 | " [ 6. 8. 10. -14.]\n", 417 | " [-14. 10. 12. 14.]]\n", 418 | "============================================================ 14\n", 419 | "[[ 3. 2. 7. -15.]\n", 420 | " [ 5. -15. 9. -15.]\n", 421 | " [ 7. 9. 11. -15.]\n", 422 | " [-15. 11. 13. 15.]]\n", 423 | "============================================================ 15\n", 424 | "[[ 4. 3. 8. -16.]\n", 425 | " [ 6. -16. 10. -16.]\n", 426 | " [ 8. 10. 12. -16.]\n", 427 | " [-16. 12. 14. 16.]]\n", 428 | "============================================================ 16\n", 429 | "[[ 5. 4. 9. -17.]\n", 430 | " [ 7. -17. 11. -17.]\n", 431 | " [ 9. 11. 13. -17.]\n", 432 | " [-17. 13. 15. 17.]]\n", 433 | "============================================================ 17\n", 434 | "[[ 6. 5. 10. -18.]\n", 435 | " [ 8. -18. 12. -18.]\n", 436 | " [ 10. 12. 14. -18.]\n", 437 | " [-18. 14. 16. 18.]]\n", 438 | "============================================================ 18\n", 439 | "[[ 7. 6. 11. -19.]\n", 440 | " [ 9. -19. 13. -19.]\n", 441 | " [ 11. 13. 15. -19.]\n", 442 | " [-19. 15. 17. 19.]]\n", 443 | "============================================================ 19\n", 444 | "[[ 8. 7. 12. -20.]\n", 445 | " [ 10. -20. 14. -20.]\n", 446 | " [ 12. 14. 16. -20.]\n", 447 | " [-20. 16. 18. 20.]]\n", 448 | "============================================================ 20\n", 449 | "[[ 9. 8. 13. -21.]\n", 450 | " [ 11. -21. 15. -21.]\n", 451 | " [ 13. 15. 17. -21.]\n", 452 | " [-21. 17. 19. 21.]]\n", 453 | "============================================================ 21\n", 454 | "[[ 10. 9. 14. -22.]\n", 455 | " [ 12. -22. 16. -22.]\n", 456 | " [ 14. 16. 18. -22.]\n", 457 | " [-22. 18. 20. 22.]]\n", 458 | "============================================================ 22\n", 459 | "[[ 11. 10. 15. -23.]\n", 460 | " [ 13. -23. 17. -23.]\n", 461 | " [ 15. 17. 19. -23.]\n", 462 | " [-23. 19. 21. 23.]]\n", 463 | "============================================================ 23\n", 464 | "[[ 12. 11. 16. -24.]\n", 465 | " [ 14. -24. 18. -24.]\n", 466 | " [ 16. 18. 20. -24.]\n", 467 | " [-24. 20. 22. 24.]]\n", 468 | "============================================================ 24\n", 469 | "[[ 13. 12. 17. -25.]\n", 470 | " [ 15. -25. 19. -25.]\n", 471 | " [ 17. 19. 21. -25.]\n", 472 | " [-25. 21. 23. 25.]]\n", 473 | "============================================================ 25\n", 474 | "[[ 14. 13. 18. -26.]\n", 475 | " [ 16. -26. 20. -26.]\n", 476 | " [ 18. 20. 22. -26.]\n", 477 | " [-26. 22. 24. 26.]]\n", 478 | "============================================================ 26\n", 479 | "[[ 15. 14. 19. -27.]\n", 480 | " [ 17. -27. 21. -27.]\n", 481 | " [ 19. 21. 23. -27.]\n", 482 | " [-27. 23. 25. 
27.]]\n", 483 | "============================================================ 27\n", 484 | "[[ 16. 15. 20. -28.]\n", 485 | " [ 18. -28. 22. -28.]\n", 486 | " [ 20. 22. 24. -28.]\n", 487 | " [-28. 24. 26. 28.]]\n", 488 | "============================================================ 28\n", 489 | "[[ 17. 16. 21. -29.]\n", 490 | " [ 19. -29. 23. -29.]\n", 491 | " [ 21. 23. 25. -29.]\n", 492 | " [-29. 25. 27. 29.]]\n", 493 | "============================================================ 29\n", 494 | "[[ 18. 17. 22. -30.]\n", 495 | " [ 20. -30. 24. -30.]\n", 496 | " [ 22. 24. 26. -30.]\n", 497 | " [-30. 26. 28. 30.]]\n", 498 | "============================================================ 30\n", 499 | "[[ 19. 18. 23. -31.]\n", 500 | " [ 21. -31. 25. -31.]\n", 501 | " [ 23. 25. 27. -31.]\n", 502 | " [-31. 27. 29. 31.]]\n", 503 | "============================================================ 31\n", 504 | "[[ 20. 19. 24. -32.]\n", 505 | " [ 22. -32. 26. -32.]\n", 506 | " [ 24. 26. 28. -32.]\n", 507 | " [-32. 28. 30. 32.]]\n", 508 | "============================================================ 32\n", 509 | "[[ 21. 20. 25. -33.]\n", 510 | " [ 23. -33. 27. -33.]\n", 511 | " [ 25. 27. 29. -33.]\n", 512 | " [-33. 29. 31. 33.]]\n", 513 | "============================================================ 33\n", 514 | "[[ 22. 21. 26. -34.]\n", 515 | " [ 24. -34. 28. -34.]\n", 516 | " [ 26. 28. 30. -34.]\n", 517 | " [-34. 30. 32. 34.]]\n", 518 | "============================================================ 34\n", 519 | "[[ 23. 22. 27. -35.]\n", 520 | " [ 25. -35. 29. -35.]\n", 521 | " [ 27. 29. 31. -35.]\n", 522 | " [-35. 31. 33. 35.]]\n", 523 | "============================================================ 35\n", 524 | "[[ 24. 23. 28. -36.]\n", 525 | " [ 26. -36. 30. -36.]\n", 526 | " [ 28. 30. 32. -36.]\n", 527 | " [-36. 32. 34. 36.]]\n", 528 | "============================================================ 36\n", 529 | "[[ 25. 24. 29. -37.]\n", 530 | " [ 27. -37. 31. -37.]\n", 531 | " [ 29. 31. 33. -37.]\n", 532 | " [-37. 33. 35. 37.]]\n", 533 | "============================================================ 37\n", 534 | "[[ 26. 25. 30. -38.]\n", 535 | " [ 28. -38. 32. -38.]\n", 536 | " [ 30. 32. 34. -38.]\n", 537 | " [-38. 34. 36. 38.]]\n", 538 | "============================================================ 38\n", 539 | "[[ 27. 26. 31. -39.]\n", 540 | " [ 29. -39. 33. -39.]\n", 541 | " [ 31. 33. 35. -39.]\n", 542 | " [-39. 35. 37. 39.]]\n", 543 | "============================================================ 39\n", 544 | "[[ 28. 27. 32. -40.]\n", 545 | " [ 30. -40. 34. -40.]\n", 546 | " [ 32. 34. 36. -40.]\n", 547 | " [-40. 36. 38. 40.]]\n", 548 | "============================================================ 40\n", 549 | "[[ 29. 28. 33. -41.]\n", 550 | " [ 31. -41. 35. -41.]\n", 551 | " [ 33. 35. 37. -41.]\n", 552 | " [-41. 37. 39. 41.]]\n", 553 | "============================================================ 41\n", 554 | "[[ 30. 29. 34. -42.]\n", 555 | " [ 32. -42. 36. -42.]\n", 556 | " [ 34. 36. 38. -42.]\n", 557 | " [-42. 38. 40. 42.]]\n", 558 | "============================================================ 42\n", 559 | "[[ 31. 30. 35. -43.]\n", 560 | " [ 33. -43. 37. -43.]\n", 561 | " [ 35. 37. 39. -43.]\n", 562 | " [-43. 39. 41. 43.]]\n", 563 | "============================================================ 43\n", 564 | "[[ 32. 31. 36. -44.]\n", 565 | " [ 34. -44. 38. -44.]\n", 566 | " [ 36. 38. 40. -44.]\n", 567 | " [-44. 40. 42. 
44.]]\n", 568 | "============================================================ 44\n", 569 | "[[ 33. 32. 37. -45.]\n", 570 | " [ 35. -45. 39. -45.]\n", 571 | " [ 37. 39. 41. -45.]\n", 572 | " [-45. 41. 43. 45.]]\n", 573 | "============================================================ 45\n", 574 | "[[ 34. 33. 38. -46.]\n", 575 | " [ 36. -46. 40. -46.]\n", 576 | " [ 38. 40. 42. -46.]\n", 577 | " [-46. 42. 44. 46.]]\n", 578 | "============================================================ 46\n", 579 | "[[ 35. 34. 39. -47.]\n", 580 | " [ 37. -47. 41. -47.]\n", 581 | " [ 39. 41. 43. -47.]\n", 582 | " [-47. 43. 45. 47.]]\n", 583 | "============================================================ 47\n", 584 | "[[ 36. 35. 40. -48.]\n", 585 | " [ 38. -48. 42. -48.]\n", 586 | " [ 40. 42. 44. -48.]\n", 587 | " [-48. 44. 46. 48.]]\n", 588 | "============================================================ 48\n", 589 | "[[ 37. 36. 41. -49.]\n", 590 | " [ 39. -49. 43. -49.]\n", 591 | " [ 41. 43. 45. -49.]\n", 592 | " [-49. 45. 47. 49.]]\n", 593 | "============================================================ 49\n", 594 | "[[ 38. 37. 42. -50.]\n", 595 | " [ 40. -50. 44. -50.]\n", 596 | " [ 42. 44. 46. -50.]\n", 597 | " [-50. 46. 48. 50.]]\n", 598 | "Reshaped Grid Value Function:\n", 599 | "[[ 38. 37. 42. -50.]\n", 600 | " [ 40. -50. 44. -50.]\n", 601 | " [ 42. 44. 46. -50.]\n", 602 | " [-50. 46. 48. 50.]]\n", 603 | "\n" 604 | ] 605 | } 606 | ], 607 | "source": [ 608 | "v = policy_eval(policy, env)\n", 609 | "print(\"Reshaped Grid Value Function:\")\n", 610 | "print(v.reshape(env.shape))\n", 611 | "print(\"\")" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "metadata": { 618 | "collapsed": true 619 | }, 620 | "outputs": [], 621 | "source": [] 622 | } 623 | ], 624 | "metadata": { 625 | "kernelspec": { 626 | "display_name": "Python 3", 627 | "language": "python", 628 | "name": "python3" 629 | }, 630 | "language_info": { 631 | "codemirror_mode": { 632 | "name": "ipython", 633 | "version": 3 634 | }, 635 | "file_extension": ".py", 636 | "mimetype": "text/x-python", 637 | "name": "python", 638 | "nbconvert_exporter": "python", 639 | "pygments_lexer": "ipython3", 640 | "version": "3.6.0" 641 | } 642 | }, 643 | "nbformat": 4, 644 | "nbformat_minor": 2 645 | } 646 | -------------------------------------------------------------------------------- /chapter3/Policy Evaluation.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[2]: 5 | 6 | 7 | import numpy as np 8 | import pprint 9 | from Environment import GridworldEnv 10 | from pprint import PrettyPrinter 11 | 12 | get_ipython().run_line_magic('pprint', '') 13 | pp = PrettyPrinter(indent=4) 14 | 15 | 16 | # In[32]: 17 | 18 | 19 | env = GridworldEnv() 20 | random_policy = np.ones([env.nS, env.nA])/env.nA 21 | print(random_policy) 22 | 23 | 24 | # In[41]: 25 | 26 | 27 | def policy_eval(policy, environment, discount_factor=1.0, theta=1.0): 28 | env = environment # 环境变量 29 | 30 | # 初始化一个全0的价值函数 31 | V = np.zeros(env.nS) 32 | 33 | # 迭代开始 34 | for _ in range(10000): 35 | delta = 0 36 | 37 | # 对于GridWorld中的每一个状态都进行全备份 38 | for s in range(env.nS): 39 | v = 0 40 | # 检查下一个有可能执行的动作 41 | for a, action_prob in enumerate(policy[s]): 42 | 43 | # 对于每一个动作检查下一个状态 44 | for prob, next_state, reward, done in env.P[s][a]: 45 | # 累积计算下一个动作的期望价值 46 | v += action_prob * prob * (reward + discount_factor * V[next_state]) 47 | 48 | # 选出最大的变化量 49 | delta = max(delta, np.abs(v - V[s])) 50 | V[s] = 
v 51 | 52 | print("="*60, _) 53 | print(V.reshape(env.shape)) 54 | 55 | # 停止标志位 56 | if delta <= theta: 57 | break 58 | 59 | return np.array(V) 60 | 61 | 62 | # In[6]: 63 | 64 | 65 | v = policy_eval(random_policy, env) 66 | print("Reshaped Grid Value Function:") 67 | print(v.reshape(env.shape)) 68 | print("") 69 | 70 | 71 | # In[21]: 72 | 73 | 74 | def gen_random_policy(): 75 | return np.random.choice(4, size=((16))) 76 | 77 | n_policy = 100 78 | policy_pop = [gen_random_policy() for _ in range(n_policy)] 79 | pprint.pprint(policy_pop[:10]) 80 | 81 | 82 | # # Test a Specified Input Policy 83 | # 84 | # If I input a specific policy, what value function does policy evaluation compute? 85 | 86 | # In[9]: 87 | 88 | 89 | input_policy = [2,1,2,3,2,0,2,0,1,2,2,0,0,1,1,0] 90 | 91 | env = GridworldEnv() 92 | policy = np.zeros([env.nS, env.nA]) 93 | 94 | for _, x in enumerate(input_policy): 95 | policy[_][x] = 1 96 | 97 | print(policy) 98 | 99 | 100 | # In[7]: 101 | 102 | 103 | def policy_eval(policy, environment, discount_factor=1.0, theta=0.1): 104 | env = environment # 环境变量 105 | 106 | # 初始化一个全0的价值函数 107 | V = np.zeros(env.nS) 108 | 109 | # 迭代开始 110 | for _ in range(50): 111 | delta = 0 112 | 113 | # 对于GridWorld中的每一个状态都进行全备份 114 | for s in range(env.nS): 115 | v = 0 116 | # 检查下一个有可能执行的动作 117 | for a, action_prob in enumerate(policy[s]): 118 | 119 | # 对于每一个动作检查下一个状态 120 | for prob, next_state, reward, done in env.P[s][a]: 121 | # 累积计算下一个动作的期望价值 122 | v += action_prob * prob * (reward + discount_factor * V[next_state]) 123 | 124 | # 选出最大的变化量 125 | delta = max(delta, np.abs(v - V[s])) 126 | V[s] = v 127 | 128 | print("="*60, _) 129 | print(V.reshape(env.shape)) 130 | 131 | # 停止标志位 132 | if delta <= theta: 133 | break 134 | 135 | return np.array(V) 136 | 137 | 138 | # In[8]: 139 | 140 | 141 | v = policy_eval(policy, env) 142 | print("Reshaped Grid Value Function:") 143 | print(v.reshape(env.shape)) 144 | print("") 145 | 146 | -------------------------------------------------------------------------------- /chapter3/Policy Improvement.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "Pretty printing has been turned OFF\n", 15 | "Pretty printing has been turned ON\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import pprint\n", 22 | "from Environment import GridworldEnv\n", 23 | "from pprint import PrettyPrinter\n", 24 | "\n", 25 | "%pprint\n", 26 | "pp = PrettyPrinter(indent=4)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 15, 32 | "metadata": { 33 | "collapsed": false 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "env = GridworldEnv()" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 16, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "def policy_eval(policy, environment, discount_factor=1.0, theta=0.1):\n", 49 | " env = environment # 环境变量\n", 50 | " \n", 51 | " # 初始化一个全0的价值函数\n", 52 | " V = np.zeros(env.nS)\n", 53 | " \n", 54 | " # 迭代开始\n", 55 | " for _ in range(50):\n", 56 | " delta = 0\n", 57 | " \n", 58 | " # 对于GridWorld中的每一个状态都进行全备份\n", 59 | " for s in range(env.nS):\n", 60 | " v = 0\n", 61 | " # 检查下一个有可能执行的动作\n", 62 | " for a, action_prob in enumerate(policy[s]):\n", 63 | " \n", 64 | " # 对于每一个动作检查下一个状态\n", 
65 | " for prob, next_state, reward, done in env.P[s][a]:\n", 66 | " # 累积计算下一个动作的期望价值\n", 67 | " v += action_prob * prob * (reward + discount_factor * V[next_state])\n", 68 | " # 选出最大的变化量\n", 69 | " delta = max(delta, np.abs(v - V[s]))\n", 70 | " V[s] = v\n", 71 | " \n", 72 | " # 停止标志位\n", 73 | " if delta <= theta:\n", 74 | " break\n", 75 | " \n", 76 | " return np.array(V)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 21, 82 | "metadata": { 83 | "collapsed": true 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "def policy_improvement(env, policy, discount_factor=1.0):\n", 88 | " \"\"\"\n", 89 | " Policy Imrpovement.\n", 90 | " Iterativedly evaluates and improves a policy until an \n", 91 | " optimal policy is found or to the limited iter threshold.\n", 92 | " \n", 93 | " Args:\n", 94 | " env: the environment.\n", 95 | " policy_eval_fun: Policy Evaluation function with 3 \n", 96 | " argements: policy, env, discount_factor.\n", 97 | " \n", 98 | " Returns:\n", 99 | " tuple(policy, V).\n", 100 | " \"\"\"\n", 101 | " k = 0\n", 102 | " while True:\n", 103 | " print(k)\n", 104 | " V = policy_eval(policy, env, discount_factor)\n", 105 | " print(\"random policy:\\n\", policy)\n", 106 | " print(\"policy eval:\\n\",V.reshape(env.shape))\n", 107 | " policy_stable = True\n", 108 | " for s in range(env.nS):\n", 109 | " chosen_a = np.argmax(policy[s])\n", 110 | " \n", 111 | " action_values = np.zeros(env.nA)\n", 112 | " for a in range(env.nA):\n", 113 | " for prob, next_state, reward, done in env.P[s][a]:\n", 114 | " action_values[a] += prob * (reward + discount_factor * V[next_state])\n", 115 | " if done and next_state != 15:\n", 116 | " action_values[a] = float('-inf')\n", 117 | "\n", 118 | " print(\"action_values:\\n\",s, action_values)\n", 119 | " \n", 120 | " best_a = np.argmax(action_values)\n", 121 | " \n", 122 | " if chosen_a != best_a:\n", 123 | " policy_stable = False\n", 124 | " policy[s] = np.eye(env.nA)[best_a]\n", 125 | " \n", 126 | " print(\"policy\\n\", np.reshape(np.argmax(policy, axis=1), env.shape))\n", 127 | " \n", 128 | " if policy_stable:\n", 129 | " return policy, V\n", 130 | " k+=1\n" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 22, 136 | "metadata": { 137 | "collapsed": false 138 | }, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "0\n", 145 | "random policy:\n", 146 | " [[ 0.25 0.25 0.25 0.25]\n", 147 | " [ 0.25 0.25 0.25 0.25]\n", 148 | " [ 0.25 0.25 0.25 0.25]\n", 149 | " [ 0.25 0.25 0.25 0.25]\n", 150 | " [ 0.25 0.25 0.25 0.25]\n", 151 | " [ 0.25 0.25 0.25 0.25]\n", 152 | " [ 0.25 0.25 0.25 0.25]\n", 153 | " [ 0.25 0.25 0.25 0.25]\n", 154 | " [ 0.25 0.25 0.25 0.25]\n", 155 | " [ 0.25 0.25 0.25 0.25]\n", 156 | " [ 0.25 0.25 0.25 0.25]\n", 157 | " [ 0.25 0.25 0.25 0.25]\n", 158 | " [ 0.25 0.25 0.25 0.25]\n", 159 | " [ 0.25 0.25 0.25 0.25]\n", 160 | " [ 0.25 0.25 0.25 0.25]\n", 161 | " [ 0.25 0.25 0.25 0.25]]\n", 162 | "policy eval:\n", 163 | " [[-50.16293984 -50.36830499 -49.87680931 -50.45935702]\n", 164 | " [-49.84606245 -50. -47.65840563 -50. ]\n", 165 | " [-48.27313919 -43.72740352 -38.47270705 -50. ]\n", 166 | " [-50. -35.52802106 -9.62643679 50. 
]]\n", 167 | "action_values:\n", 168 | " 0 [-51.16293984 -51.36830499 -50.84606245 -51.16293984]\n", 169 | "action_values:\n", 170 | " 1 [-51.36830499 -50.87680931 -inf -51.16293984]\n", 171 | "action_values:\n", 172 | " 2 [-50.87680931 -51.45935702 -48.65840563 -51.36830499]\n", 173 | "action_values:\n", 174 | " 3 [-51.45935702 -51.45935702 -inf -50.87680931]\n", 175 | "action_values:\n", 176 | " 4 [-51.16293984 -inf -49.27313919 -50.84606245]\n", 177 | "action_values:\n", 178 | " 5 [-inf -inf -inf -inf]\n", 179 | "action_values:\n", 180 | " 6 [-50.87680931 -inf -39.47270705 -inf]\n", 181 | "action_values:\n", 182 | " 7 [-inf -inf -inf -inf]\n", 183 | "action_values:\n", 184 | " 8 [-50.84606245 -44.72740352 -inf -49.27313919]\n", 185 | "action_values:\n", 186 | " 9 [ -inf -39.47270705 -36.52802106 -49.27313919]\n", 187 | "action_values:\n", 188 | " 10 [-48.65840563 -inf -10.62643679 -44.72740352]\n", 189 | "action_values:\n", 190 | " 11 [-inf -inf -inf -inf]\n", 191 | "action_values:\n", 192 | " 12 [-inf -inf -inf -inf]\n", 193 | "action_values:\n", 194 | " 13 [-44.72740352 -10.62643679 -36.52802106 -inf]\n", 195 | "action_values:\n", 196 | " 14 [-39.47270705 49. -10.62643679 -36.52802106]\n", 197 | "action_values:\n", 198 | " 15 [ 51. 51. 51. 51.]\n", 199 | "policy\n", 200 | " [[2 1 2 3]\n", 201 | " [2 0 2 0]\n", 202 | " [1 2 2 0]\n", 203 | " [0 1 1 0]]\n", 204 | "1\n", 205 | "random policy:\n", 206 | " [[ 0. 0. 1. 0.]\n", 207 | " [ 0. 1. 0. 0.]\n", 208 | " [ 0. 0. 1. 0.]\n", 209 | " [ 0. 0. 0. 1.]\n", 210 | " [ 0. 0. 1. 0.]\n", 211 | " [ 1. 0. 0. 0.]\n", 212 | " [ 0. 0. 1. 0.]\n", 213 | " [ 1. 0. 0. 0.]\n", 214 | " [ 0. 1. 0. 0.]\n", 215 | " [ 0. 0. 1. 0.]\n", 216 | " [ 0. 0. 1. 0.]\n", 217 | " [ 1. 0. 0. 0.]\n", 218 | " [ 1. 0. 0. 0.]\n", 219 | " [ 0. 1. 0. 0.]\n", 220 | " [ 0. 1. 0. 0.]\n", 221 | " [ 1. 0. 0. 0.]]\n", 222 | "policy eval:\n", 223 | " [[ 38. 40. 42. 41.]\n", 224 | " [ 40. -50. 44. -50.]\n", 225 | " [ 42. 44. 46. -50.]\n", 226 | " [-50. 46. 48. 50.]]\n", 227 | "action_values:\n", 228 | " 0 [ 37. 39. 39. 37.]\n", 229 | "action_values:\n", 230 | " 1 [ 39. 41. -inf 37.]\n", 231 | "action_values:\n", 232 | " 2 [ 41. 40. 43. 39.]\n", 233 | "action_values:\n", 234 | " 3 [ 40. 40. -inf 41.]\n", 235 | "action_values:\n", 236 | " 4 [ 37. -inf 41. 39.]\n", 237 | "action_values:\n", 238 | " 5 [-inf -inf -inf -inf]\n", 239 | "action_values:\n", 240 | " 6 [ 41. -inf 45. -inf]\n", 241 | "action_values:\n", 242 | " 7 [-inf -inf -inf -inf]\n", 243 | "action_values:\n", 244 | " 8 [ 39. 43. -inf 41.]\n", 245 | "action_values:\n", 246 | " 9 [-inf 45. 45. 41.]\n", 247 | "action_values:\n", 248 | " 10 [ 43. -inf 47. 43.]\n", 249 | "action_values:\n", 250 | " 11 [-inf -inf -inf -inf]\n", 251 | "action_values:\n", 252 | " 12 [-inf -inf -inf -inf]\n", 253 | "action_values:\n", 254 | " 13 [ 43. 47. 45. -inf]\n", 255 | "action_values:\n", 256 | " 14 [ 45. 49. 47. 45.]\n", 257 | "action_values:\n", 258 | " 15 [ 51. 51. 51. 51.]\n", 259 | "policy\n", 260 | " [[1 1 2 3]\n", 261 | " [2 0 2 0]\n", 262 | " [1 1 2 0]\n", 263 | " [0 1 1 0]]\n", 264 | "2\n", 265 | "random policy:\n", 266 | " [[ 0. 1. 0. 0.]\n", 267 | " [ 0. 1. 0. 0.]\n", 268 | " [ 0. 0. 1. 0.]\n", 269 | " [ 0. 0. 0. 1.]\n", 270 | " [ 0. 0. 1. 0.]\n", 271 | " [ 1. 0. 0. 0.]\n", 272 | " [ 0. 0. 1. 0.]\n", 273 | " [ 1. 0. 0. 0.]\n", 274 | " [ 0. 1. 0. 0.]\n", 275 | " [ 0. 1. 0. 0.]\n", 276 | " [ 0. 0. 1. 0.]\n", 277 | " [ 1. 0. 0. 0.]\n", 278 | " [ 1. 0. 0. 0.]\n", 279 | " [ 0. 1. 0. 0.]\n", 280 | " [ 0. 1. 0. 0.]\n", 281 | " [ 1. 0. 0. 
0.]]\n", 282 | "policy eval:\n", 283 | " [[ 38. 40. 42. 41.]\n", 284 | " [ 40. -50. 44. -50.]\n", 285 | " [ 42. 44. 46. -50.]\n", 286 | " [-50. 46. 48. 50.]]\n", 287 | "action_values:\n", 288 | " 0 [ 37. 39. 39. 37.]\n", 289 | "action_values:\n", 290 | " 1 [ 39. 41. -inf 37.]\n", 291 | "action_values:\n", 292 | " 2 [ 41. 40. 43. 39.]\n", 293 | "action_values:\n", 294 | " 3 [ 40. 40. -inf 41.]\n", 295 | "action_values:\n", 296 | " 4 [ 37. -inf 41. 39.]\n", 297 | "action_values:\n", 298 | " 5 [-inf -inf -inf -inf]\n", 299 | "action_values:\n", 300 | " 6 [ 41. -inf 45. -inf]\n", 301 | "action_values:\n", 302 | " 7 [-inf -inf -inf -inf]\n", 303 | "action_values:\n", 304 | " 8 [ 39. 43. -inf 41.]\n", 305 | "action_values:\n", 306 | " 9 [-inf 45. 45. 41.]\n", 307 | "action_values:\n", 308 | " 10 [ 43. -inf 47. 43.]\n", 309 | "action_values:\n", 310 | " 11 [-inf -inf -inf -inf]\n", 311 | "action_values:\n", 312 | " 12 [-inf -inf -inf -inf]\n", 313 | "action_values:\n", 314 | " 13 [ 43. 47. 45. -inf]\n", 315 | "action_values:\n", 316 | " 14 [ 45. 49. 47. 45.]\n", 317 | "action_values:\n", 318 | " 15 [ 51. 51. 51. 51.]\n", 319 | "policy\n", 320 | " [[1 1 2 3]\n", 321 | " [2 0 2 0]\n", 322 | " [1 1 2 0]\n", 323 | " [0 1 1 0]]\n", 324 | "\n", 325 | "Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n", 326 | "[[1 1 2 3]\n", 327 | " [2 0 2 0]\n", 328 | " [1 1 2 0]\n", 329 | " [0 1 1 0]]\n", 330 | "\n", 331 | "Reshaped Grid Value Function:\n", 332 | "[[ 38. 40. 42. 41.]\n", 333 | " [ 40. -50. 44. -50.]\n", 334 | " [ 42. 44. 46. -50.]\n", 335 | " [-50. 46. 48. 50.]]\n", 336 | "\n" 337 | ] 338 | } 339 | ], 340 | "source": [ 341 | "random_policy = np.ones([env.nS, env.nA])/env.nA\n", 342 | "policy, v = policy_improvement(env, random_policy)\n", 343 | "\n", 344 | "print(\"\\nReshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\")\n", 345 | "print(np.reshape(np.argmax(policy, axis=1), env.shape))\n", 346 | "print(\"\")\n", 347 | "\n", 348 | "print(\"Reshaped Grid Value Function:\")\n", 349 | "print(v.reshape(env.shape))\n", 350 | "print(\"\")" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": { 356 | "collapsed": true 357 | }, 358 | "source": [ 359 | "The real policy Ieration function following\n", 360 | "=============" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": { 367 | "collapsed": true 368 | }, 369 | "outputs": [], 370 | "source": [ 371 | "def policy_iteration(env, policy, discount_factor=1.0):\n", 372 | " while True:\n", 373 | " # 评估当前策略 policy\n", 374 | " V = policy_eval(policy, env, discount_factor)\n", 375 | "\n", 376 | " # policy 标志位,当某状态的策略更改后该标志位为 False\n", 377 | " policy_stable = True\n", 378 | " \n", 379 | " # 策略改进\n", 380 | " for s in range(env.nS):\n", 381 | " # 在当前状态和策略下选择概率最高的动作\n", 382 | " old_action = np.argmax(policy[s])\n", 383 | " \n", 384 | " # 在当前状态和策略下找到最优动作\n", 385 | " action_values = np.zeros(env.nA)\n", 386 | " for a in range(env.nA):\n", 387 | " for prob, next_state, reward, done in env.P[s][a]:\n", 388 | " action_values[a] += prob * (reward + discount_factor * V[next_state])\n", 389 | " if done and next_state != 15:\n", 390 | " action_values[a] = float('-inf')\n", 391 | "\n", 392 | " print(\"action_values:\\n\",s, action_values)\n", 393 | " \n", 394 | " # 采用贪婪算法更新当前策略\n", 395 | " best_action = np.argmax(action_values)\n", 396 | " \n", 397 | " if old_action != best_action:\n", 398 | " policy_stable = False\n", 399 | " policy[s] = np.eye(env.nA)[best_a]\n", 400 | " \n", 401 | " \n", 402 | " # 
选择的动作不再变化,则代表策略已经稳定下来\n", 403 | " if policy_stable:\n", 404 | " # 返回最优策略和对应状态值\n", 405 | " return policy, V" 406 | ] 407 | } 408 | ], 409 | "metadata": { 410 | "kernelspec": { 411 | "display_name": "Python 3", 412 | "language": "python", 413 | "name": "python3" 414 | }, 415 | "language_info": { 416 | "codemirror_mode": { 417 | "name": "ipython", 418 | "version": 3 419 | }, 420 | "file_extension": ".py", 421 | "mimetype": "text/x-python", 422 | "name": "python", 423 | "nbconvert_exporter": "python", 424 | "pygments_lexer": "ipython3", 425 | "version": "3.6.0" 426 | } 427 | }, 428 | "nbformat": 4, 429 | "nbformat_minor": 2 430 | } 431 | -------------------------------------------------------------------------------- /chapter3/Policy Improvement.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[3]: 5 | 6 | 7 | import numpy as np 8 | import pprint 9 | from Environment import GridworldEnv 10 | from pprint import PrettyPrinter 11 | 12 | get_ipython().run_line_magic('pprint', '') 13 | pp = PrettyPrinter(indent=4) 14 | 15 | 16 | # In[15]: 17 | 18 | 19 | env = GridworldEnv() 20 | 21 | 22 | # In[16]: 23 | 24 | 25 | def policy_eval(policy, environment, discount_factor=1.0, theta=0.1): 26 | env = environment # 环境变量 27 | 28 | # 初始化一个全0的价值函数 29 | V = np.zeros(env.nS) 30 | 31 | # 迭代开始 32 | for _ in range(50): 33 | delta = 0 34 | 35 | # 对于GridWorld中的每一个状态都进行全备份 36 | for s in range(env.nS): 37 | v = 0 38 | # 检查下一个有可能执行的动作 39 | for a, action_prob in enumerate(policy[s]): 40 | 41 | # 对于每一个动作检查下一个状态 42 | for prob, next_state, reward, done in env.P[s][a]: 43 | # 累积计算下一个动作的期望价值 44 | v += action_prob * prob * (reward + discount_factor * V[next_state]) 45 | # 选出最大的变化量 46 | delta = max(delta, np.abs(v - V[s])) 47 | V[s] = v 48 | 49 | # 停止标志位 50 | if delta <= theta: 51 | break 52 | 53 | return np.array(V) 54 | 55 | 56 | # In[21]: 57 | 58 | 59 | def policy_improvement(env, policy, discount_factor=1.0): 60 | """ 61 | Policy Imrpovement. 62 | Iterativedly evaluates and improves a policy until an 63 | optimal policy is found or to the limited iter threshold. 64 | 65 | Args: 66 | env: the environment. 67 | policy_eval_fun: Policy Evaluation function with 3 68 | argements: policy, env, discount_factor. 69 | 70 | Returns: 71 | tuple(policy, V). 
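An aside on the structure of this routine: the evaluate-then-greedily-improve loop can also be packaged as a compact standalone policy-iteration sketch. This is illustrative only; it assumes a GridworldEnv-style environment exposing nS, nA and the transition table P, plus a policy_eval function like the one defined above.

import numpy as np

def greedy_policy_from_value(env, V, discount_factor=1.0):
    # One-step lookahead: make the policy greedy with respect to the current V.
    policy = np.zeros([env.nS, env.nA])
    for s in range(env.nS):
        q = np.zeros(env.nA)
        for a in range(env.nA):
            for prob, next_state, reward, done in env.P[s][a]:
                q[a] += prob * (reward + discount_factor * V[next_state])
        policy[s, np.argmax(q)] = 1.0
    return policy

def policy_iteration_sketch(env, policy_eval_fn, discount_factor=1.0):
    # Alternate evaluation and greedy improvement until the policy stops changing.
    policy = np.ones([env.nS, env.nA]) / env.nA
    while True:
        V = policy_eval_fn(policy, env, discount_factor)
        improved = greedy_policy_from_value(env, V, discount_factor)
        if np.array_equal(improved, policy):
            return policy, V
        policy = improved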
72 | """ 73 | k = 0 74 | while True: 75 | print(k) 76 | V = policy_eval(policy, env, discount_factor) 77 | print("random policy:\n", policy) 78 | print("policy eval:\n",V.reshape(env.shape)) 79 | policy_stable = True 80 | for s in range(env.nS): 81 | chosen_a = np.argmax(policy[s]) 82 | 83 | action_values = np.zeros(env.nA) 84 | for a in range(env.nA): 85 | for prob, next_state, reward, done in env.P[s][a]: 86 | action_values[a] += prob * (reward + discount_factor * V[next_state]) 87 | if done and next_state != 15: 88 | action_values[a] = float('-inf') 89 | 90 | print("action_values:\n",s, action_values) 91 | 92 | best_a = np.argmax(action_values) 93 | 94 | if chosen_a != best_a: 95 | policy_stable = False 96 | policy[s] = np.eye(env.nA)[best_a] 97 | 98 | print("policy\n", np.reshape(np.argmax(policy, axis=1), env.shape)) 99 | 100 | if policy_stable: 101 | return policy, V 102 | k+=1 103 | 104 | 105 | # In[22]: 106 | 107 | 108 | random_policy = np.ones([env.nS, env.nA])/env.nA 109 | policy, v = policy_improvement(env, random_policy) 110 | 111 | print("\nReshaped Grid Policy (0=up, 1=right, 2=down, 3=left):") 112 | print(np.reshape(np.argmax(policy, axis=1), env.shape)) 113 | print("") 114 | 115 | print("Reshaped Grid Value Function:") 116 | print(v.reshape(env.shape)) 117 | print("") 118 | 119 | 120 | # The real policy Ieration function following 121 | # ============= 122 | 123 | # In[ ]: 124 | 125 | 126 | def policy_iteration(env, policy, discount_factor=1.0): 127 | while True: 128 | # 评估当前策略 policy 129 | V = policy_eval(policy, env, discount_factor) 130 | 131 | # policy 标志位,当某状态的策略更改后该标志位为 False 132 | policy_stable = True 133 | 134 | # 策略改进 135 | for s in range(env.nS): 136 | # 在当前状态和策略下选择概率最高的动作 137 | old_action = np.argmax(policy[s]) 138 | 139 | # 在当前状态和策略下找到最优动作 140 | action_values = np.zeros(env.nA) 141 | for a in range(env.nA): 142 | for prob, next_state, reward, done in env.P[s][a]: 143 | action_values[a] += prob * (reward + discount_factor * V[next_state]) 144 | if done and next_state != 15: 145 | action_values[a] = float('-inf') 146 | 147 | print("action_values:\n",s, action_values) 148 | 149 | # 采用贪婪算法更新当前策略 150 | best_action = np.argmax(action_values) 151 | 152 | if old_action != best_action: 153 | policy_stable = False 154 | policy[s] = np.eye(env.nA)[best_a] 155 | 156 | 157 | # 选择的动作不再变化,则代表策略已经稳定下来 158 | if policy_stable: 159 | # 返回最优策略和对应状态值 160 | return policy, V 161 | 162 | -------------------------------------------------------------------------------- /chapter3/Value Iteration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Pretty printing has been turned OFF\n", 13 | "Pretty printing has been turned ON\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import pprint\n", 20 | "from Environment import GridworldEnv\n", 21 | "from pprint import PrettyPrinter\n", 22 | "\n", 23 | "%pprint\n", 24 | "pp = PrettyPrinter(indent=4)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 43, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "def calc_action_value(state, V, discount_factor=1.0):\n", 34 | " \"\"\"\n", 35 | " Calculate the expected value of each action in a given state.\n", 36 | " 对于给定的状态 s 计算其动作 a 的期望值\n", 37 | " \"\"\"\n", 38 | " A = np.zeros(env.nA)\n", 39 | " for a in 
range(env.nA):\n", 40 | " for prob, next_state, reward, done in env.P[state][a]:\n", 41 | " A[a] += prob * (reward + discount_factor * V[next_state])\n", 42 | " return A\n", 43 | " \n", 44 | " \n", 45 | "def value_iteration(env, theta=0.1, discount_factor=1.0):\n", 46 | " \"\"\"\n", 47 | " Value Iteration Algorithm. 值迭代算法\n", 48 | " \"\"\"\n", 49 | " # 初始化状态值\n", 50 | " V = np.zeros(env.nS)\n", 51 | "\n", 52 | " # 迭代计算找到最优的状态值函数 optimal value function\n", 53 | " for _ in range(50):\n", 54 | " delta = 0 # 停止标志位\n", 55 | " \n", 56 | " # 计算每个状态的状态值\n", 57 | " for s in range(env.nS):\n", 58 | " A = calc_action_value(s, V) # 执行一次找到当前状态的动作期望\n", 59 | " best_action_value = np.max(A) # 选择最好的动作期望作为新的状态值\n", 60 | " \n", 61 | " # 计算停止标志位\n", 62 | " delta = max(delta, np.abs(best_action_value - V[s])) \n", 63 | " \n", 64 | " # 更新状态值函数\n", 65 | " V[s] = best_action_value \n", 66 | " \n", 67 | " if delta < theta:\n", 68 | " break\n", 69 | " \n", 70 | " \n", 71 | " # 输出最优策略:通过最优状态值函数找到决定性策略\n", 72 | " policy = np.zeros([env.nS, env.nA]) # 初始化策略\n", 73 | " \n", 74 | " for s in range(env.nS):\n", 75 | " # 执行一次找到当前状态的最优状态值的动作期望 A\n", 76 | " A = calc_action_value(s, V)\n", 77 | " \n", 78 | " # 选出状态值最大的作为最优动作\n", 79 | " best_action = np.argmax(A)\n", 80 | " policy[s, best_action] = 1.0\n", 81 | " \n", 82 | " return policy, V" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 44, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n", 95 | "[[1 1 2 3]\n", 96 | " [2 0 2 0]\n", 97 | " [1 1 2 0]\n", 98 | " [0 1 1 0]]\n", 99 | "\n", 100 | "Reshaped Grid Value Function:\n", 101 | "[[ 38. 40. 42. 41.]\n", 102 | " [ 40. -50. 44. -50.]\n", 103 | " [ 42. 44. 46. -50.]\n", 104 | " [-50. 46. 48. 
50.]]\n", 105 | "\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "env = GridworldEnv()\n", 111 | "policy, v = value_iteration(env)\n", 112 | "\n", 113 | "print(\"Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\")\n", 114 | "print(np.reshape(np.argmax(policy, axis=1), env.shape))\n", 115 | "print(\"\")\n", 116 | "\n", 117 | "print(\"Reshaped Grid Value Function:\")\n", 118 | "print(v.reshape(env.shape))\n", 119 | "print(\"\")" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "collapsed": true 127 | }, 128 | "outputs": [], 129 | "source": [] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.6.0" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 2 153 | } 154 | -------------------------------------------------------------------------------- /chapter3/Value Iteration.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | import numpy as np 7 | import pprint 8 | from Environment import GridworldEnv 9 | from pprint import PrettyPrinter 10 | 11 | get_ipython().magic('pprint') 12 | pp = PrettyPrinter(indent=4) 13 | 14 | 15 | # In[43]: 16 | 17 | def calc_action_value(state, V, discount_factor=1.0): 18 | """ 19 | Calculate the expected value of each action in a given state. 20 | 对于给定的状态 s 计算其动作 a 的期望值 21 | """ 22 | A = np.zeros(env.nA) 23 | for a in range(env.nA): 24 | for prob, next_state, reward, done in env.P[state][a]: 25 | A[a] += prob * (reward + discount_factor * V[next_state]) 26 | return A 27 | 28 | 29 | def value_iteration(env, theta=0.1, discount_factor=1.0): 30 | """ 31 | Value Iteration Algorithm. 
值迭代算法 32 | """ 33 | # 初始化状态值 34 | V = np.zeros(env.nS) 35 | 36 | # 迭代计算找到最优的状态值函数 optimal value function 37 | for _ in range(50): 38 | delta = 0 # 停止标志位 39 | 40 | # 计算每个状态的状态值 41 | for s in range(env.nS): 42 | A = calc_action_value(s, V) # 执行一次找到当前状态的动作期望 43 | best_action_value = np.max(A) # 选择最好的动作期望作为新的状态值 44 | 45 | # 计算停止标志位 46 | delta = max(delta, np.abs(best_action_value - V[s])) 47 | 48 | # 更新状态值函数 49 | V[s] = best_action_value 50 | 51 | if delta < theta: 52 | break 53 | 54 | 55 | # 输出最优策略:通过最优状态值函数找到决定性策略 56 | policy = np.zeros([env.nS, env.nA]) # 初始化策略 57 | 58 | for s in range(env.nS): 59 | # 执行一次找到当前状态的最优状态值的动作期望 A 60 | A = calc_action_value(s, V) 61 | 62 | # 选出状态值最大的作为最优动作 63 | best_action = np.argmax(A) 64 | policy[s, best_action] = 1.0 65 | 66 | return policy, V 67 | 68 | 69 | # In[44]: 70 | 71 | env = GridworldEnv() 72 | policy, v = value_iteration(env) 73 | 74 | print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):") 75 | print(np.reshape(np.argmax(policy, axis=1), env.shape)) 76 | print("") 77 | 78 | print("Reshaped Grid Value Function:") 79 | print(v.reshape(env.shape)) 80 | print("") 81 | 82 | 83 | # The real policy Ieration function following 84 | # ============= 85 | 86 | # In[ ]: 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /chapter4/MC firstvisit prediciton.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[227]: 5 | 6 | 7 | import gym 8 | import sys 9 | import numpy as np 10 | import matplotlib 11 | from collections import defaultdict 12 | from matplotlib import pyplot as plt 13 | from mpl_toolkits.mplot3d import Axes3D 14 | 15 | 16 | get_ipython().run_line_magic('matplotlib', 'inline') 17 | matplotlib.style.use("ggplot") 18 | 19 | 20 | # In[245]: 21 | 22 | 23 | def plot_value_function(V, title): 24 | """ 25 | Plots the value function as a surface plot. 26 | """ 27 | min_x = min(k[0] for k in V.keys()) 28 | max_x = max(k[0] for k in V.keys()) 29 | min_y = min(k[1] for k in V.keys()) 30 | max_y = max(k[1] for k in V.keys()) 31 | 32 | x_range = np.arange(min_x, max_x + 1) 33 | y_range = np.arange(min_y, max_y + 1) 34 | X, Y = np.meshgrid(x_range, y_range) 35 | 36 | Z_noace = np.apply_along_axis(lambda _: V[(_[0], _[1], False)], 2, np.dstack([X, Y])) 37 | Z_ace = np.apply_along_axis(lambda _: V[(_[0], _[1], True)], 2, np.dstack([X, Y])) 38 | 39 | def plot_surface(X, Y, Z, title=None): 40 | fig = plt.figure(figsize=(20, 10), facecolor='white') 41 | 42 | ax = fig.add_subplot(111, projection='3d') 43 | surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, 44 | cmap=matplotlib.cm.coolwarm, vmin=-1.0, vmax=1.0) 45 | ax.set_xlabel('Player sum') 46 | ax.set_ylabel('Dealer showing') 47 | ax.set_zlabel('Value') 48 | if title: ax.set_title(title) 49 | ax.view_init(ax.elev, -120) 50 | ax.set_facecolor("white") 51 | fig.colorbar(surf) 52 | plt.show() 53 | 54 | plot_surface(X, Y, Z_noace, "(No Usable Ace)") 55 | plot_surface(X, Y, Z_ace, "(Usable Ace)") 56 | 57 | 58 | # In[246]: 59 | 60 | 61 | env = gym.make("Blackjack-v0") 62 | 63 | def simple_policy(state): 64 | player_score, _, _ = state 65 | return 0 if player_score >= 18 else 1 66 | 67 | def mc_firstvisit_prediction(policy, env, num_episodes, 68 | episode_endtime= 10, discount = 1.0): 69 | r_sum = defaultdict(float) 70 | r_count = defaultdict(float) 71 | r_V = defaultdict(float) 72 | 73 | for i in range(num_episodes): 74 | # print out the episodes rate for displaying. 
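One clarifying aside on the return G accumulated below: for an episode's reward sequence, the discounted return can be computed backwards in a couple of lines. A minimal illustrative helper:

def discounted_return(rewards, discount=1.0):
    # G = r_1 + discount * r_2 + discount**2 * r_3 + ...
    G = 0.0
    for r in reversed(rewards):
        G = r + discount * G
    return G

# Example: rewards [0, 0, 1] with discount 0.9 give 0.9**2 * 1 = 0.81.
assert abs(discounted_return([0, 0, 1], discount=0.9) - 0.81) < 1e-12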
75 | episode_rate = int(40 * i / num_episodes) 76 | print("Episode {}/{}".format(i+1, num_episodes), end="\r") 77 | sys.stdout.flush() 78 | 79 | # init the episode list and state 80 | episode = [] 81 | state = env.reset() 82 | 83 | # Generate an episode which including tuple(state, aciton, reward). 84 | for j in range(episode_endtime): 85 | action = policy(state) 86 | next_state, reward, done, _ = env.step(action) 87 | episode.append((state, action, reward)) 88 | if done: break 89 | state = next_state 90 | 91 | # first visit mc method 92 | for k, data_k in enumerate(episode): 93 | state_k = data_k[0] 94 | G = sum([x[2] * np.power(discount, i) for i,x in enumerate(episode[k:])]) 95 | r_sum[state_k] += G 96 | r_count[state_k] += 1.0 97 | r_V[state_k] = r_sum[state_k] / r_count[state_k] 98 | 99 | return r_V 100 | 101 | v1 = mc_firstvisit_prediction(simple_policy, env, 100000) 102 | plot_value_function(v1, title=None) 103 | 104 | 105 | # In[235]: 106 | 107 | 108 | env = env = gym.make("Blackjack-v0") 109 | 110 | def simple_policy(state): 111 | player_score, _, _ = state 112 | return 0 if player_score >= 18 else 1 113 | 114 | def mc_everyvisit_prediction(policy, env, num_episodes, episode_endtime = 10, discount = 1.0): 115 | r_sum = defaultdict(float) 116 | r_count = defaultdict(float) 117 | r_V = defaultdict(float) 118 | 119 | for i in range(num_episodes): 120 | # print out the episodes rate for displaying. 121 | episode_rate = int(80 * i / num_episodes) 122 | print("Episode {}/{}".format(i+1, num_episodes) + "=" * episode_rate, end="\r") 123 | sys.stdout.flush() 124 | 125 | # init the episode list and state 126 | episode = [] 127 | state = env.reset() 128 | 129 | # Generate an episode which including tuple(state, aciton, reward). 130 | for j in range(episode_endtime): 131 | action = policy(state) 132 | next_state, reward, done, _ = env.step(action) 133 | episode.append((state, action, reward)) 134 | if done: break 135 | state = next_state 136 | 137 | # every visit mc method 138 | for k, data_k in enumerate(episode): 139 | state_k = data_k[0] 140 | G = sum([x[2] * np.power(discount, i) for i,x in enumerate(episode)]) 141 | r_sum[state_k] += G 142 | r_count[state_k] += 1.0 143 | r_V[state_k] = r_sum[state_k] / r_count[state_k] 144 | 145 | return r_V 146 | 147 | v2 = mc_everyvisit_prediction(simple_policy, env, 100000) 148 | plot_value_function(v2, title=None) 149 | 150 | 151 | # In[247]: 152 | 153 | 154 | v1 155 | 156 | -------------------------------------------------------------------------------- /chapter4/MC_blackjack.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 24, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import gym\n", 11 | "from termcolor import colored" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 31, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "\n", 24 | "==============================\n", 25 | "Player:20, ace:False, Dealer:18\n", 26 | "Player Simple strategy take action:STAND\n", 27 | "Player:20, ace:False, Dealer:18\n", 28 | "Game win.(Reward 1)\n", 29 | "PLAYER:\u001b[31m[10, 10]\u001b[0m\t DEALER:\u001b[32m[8, 10]\u001b[0m\n", 30 | "\n", 31 | "==============================\n", 32 | "Player:21, ace:True, Dealer:14\n", 33 | "Player Simple strategy take action:STAND\n", 34 | "Player:21, ace:True, Dealer:24\n", 35 | "Game 
win.(Reward 1)\n", 36 | "PLAYER:\u001b[31m[1, 10]\u001b[0m\t DEALER:\u001b[32m[10, 4, 10]\u001b[0m\n", 37 | "\n", 38 | "==============================\n", 39 | "Player:17, ace:True, Dealer:20\n", 40 | "Player Simple strategy take action:HIT\n", 41 | "Player:17, ace:False, Dealer:20\n", 42 | "Player Simple strategy take action:HIT\n", 43 | "Player:27, ace:False, Dealer:20\n", 44 | "Game loss.(Reward -1)\n", 45 | "PLAYER:\u001b[31m[6, 1, 10, 10]\u001b[0m\t DEALER:\u001b[32m[10, 10]\u001b[0m\n", 46 | "\n", 47 | "==============================\n", 48 | "Player:5, ace:False, Dealer:14\n", 49 | "Player Simple strategy take action:HIT\n", 50 | "Player:13, ace:False, Dealer:14\n", 51 | "Player Simple strategy take action:HIT\n", 52 | "Player:16, ace:False, Dealer:14\n", 53 | "Player Simple strategy take action:HIT\n", 54 | "Player:26, ace:False, Dealer:14\n", 55 | "Game loss.(Reward -1)\n", 56 | "PLAYER:\u001b[31m[3, 2, 8, 3, 10]\u001b[0m\t DEALER:\u001b[32m[10, 4]\u001b[0m\n", 57 | "\n", 58 | "==============================\n", 59 | "Player:21, ace:True, Dealer:20\n", 60 | "Player Simple strategy take action:STAND\n", 61 | "Player:21, ace:True, Dealer:20\n", 62 | "Game win.(Reward 1)\n", 63 | "PLAYER:\u001b[31m[10, 1]\u001b[0m\t DEALER:\u001b[32m[10, 10]\u001b[0m\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "env = gym.make(\"Blackjack-v0\")\n", 69 | "\n", 70 | "def show_state(state):\n", 71 | " player, dealer, ace = state\n", 72 | " dealer = sum(env.dealer)\n", 73 | " print(\"Player:{}, ace:{}, Dealer:{}\".format(player, ace, dealer))\n", 74 | "\n", 75 | "def simple_strategy(state):\n", 76 | " player, dealer, ace = state\n", 77 | " return 0 if player >= 18 else 1\n", 78 | "\n", 79 | "def episode(num_episodes):\n", 80 | " episode = []\n", 81 | " for i_episode in range(5):\n", 82 | " print(\"\\n\" + \"=\"* 30)\n", 83 | " state = env.reset()\n", 84 | " for t in range(10):\n", 85 | " show_state(state)\n", 86 | " action = simple_strategy(state)\n", 87 | " action_ = [\"STAND\", \"HIT\"][action]\n", 88 | " print(\"Player Simple strategy take action:{}\".format(action_))\n", 89 | " \n", 90 | " next_state, reward, done, _ = env.step(action)\n", 91 | " episode.append((state, action, reward))\n", 92 | " if done:\n", 93 | " show_state(state)\n", 94 | " # [-1(loss), -(push), 1(win)]\n", 95 | " reward_ = [\"loss\", \"push\", \"win\"][int(reward+1)]\n", 96 | " print(\"Game {}.(Reward {})\".format(reward_, int(reward)))\n", 97 | " print(\"PLAYER:{}\\t DEALER:{}\".format(colored(env.player, 'red'), \n", 98 | " colored(env.dealer, 'green')))\n", 99 | " break\n", 100 | " \n", 101 | " state = next_state\n", 102 | "\n", 103 | "episode(1000)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [] 119 | } 120 | ], 121 | "metadata": { 122 | "kernelspec": { 123 | "display_name": "Python 3", 124 | "language": "python", 125 | "name": "python3" 126 | }, 127 | "language_info": { 128 | "codemirror_mode": { 129 | "name": "ipython", 130 | "version": 3 131 | }, 132 | "file_extension": ".py", 133 | "mimetype": "text/x-python", 134 | "name": "python", 135 | "nbconvert_exporter": "python", 136 | "pygments_lexer": "ipython3", 137 | "version": "3.6.0" 138 | } 139 | }, 140 | "nbformat": 4, 141 | "nbformat_minor": 2 142 | } 143 | -------------------------------------------------------------------------------- 
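One remark on the prediction code above before the script version of the Blackjack walkthrough: the r_sum / r_count running averages can equivalently be maintained incrementally. A minimal sketch, assuming episodes are lists of (state, action, reward) tuples as in the notebooks:

from collections import defaultdict

def mc_first_visit_update(V, N, episode, discount=1.0):
    # Incremental-mean form of first-visit MC prediction:
    # V(s) <- V(s) + (G - V(s)) / N(s), applied only at the first visit to each state.
    seen = set()
    for k, (state, _, _) in enumerate(episode):
        if state in seen:
            continue
        seen.add(state)
        G = sum(r * discount ** i for i, (_, _, r) in enumerate(episode[k:]))
        N[state] += 1.0
        V[state] += (G - V[state]) / N[state]

V, N = defaultdict(float), defaultdict(float)
mc_first_visit_update(V, N, [((13, 10, False), 1, 0.0), ((20, 10, False), 0, 1.0)])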
/chapter4/MC_blackjack.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[24]: 5 | 6 | 7 | import numpy as np 8 | import gym 9 | from termcolor import colored 10 | 11 | 12 | # In[31]: 13 | 14 | 15 | env = gym.make("Blackjack-v0") 16 | 17 | def show_state(state): 18 | player, dealer, ace = state 19 | dealer = sum(env.dealer) 20 | print("Player:{}, ace:{}, Dealer:{}".format(player, ace, dealer)) 21 | 22 | def simple_strategy(state): 23 | player, dealer, ace = state 24 | return 0 if player >= 18 else 1 25 | 26 | def episode(num_episodes): 27 | episode = [] 28 | for i_episode in range(5): 29 | print("\n" + "="* 30) 30 | state = env.reset() 31 | for t in range(10): 32 | show_state(state) 33 | action = simple_strategy(state) 34 | action_ = ["STAND", "HIT"][action] 35 | print("Player Simple strategy take action:{}".format(action_)) 36 | 37 | next_state, reward, done, _ = env.step(action) 38 | episode.append((state, action, reward)) 39 | if done: 40 | show_state(state) 41 | # [-1(loss), -(push), 1(win)] 42 | reward_ = ["loss", "push", "win"][int(reward+1)] 43 | print("Game {}.(Reward {})".format(reward_, int(reward))) 44 | print("PLAYER:{}\t DEALER:{}".format(colored(env.player, 'red'), 45 | colored(env.dealer, 'green'))) 46 | break 47 | 48 | state = next_state 49 | 50 | episode(1000) 51 | 52 | -------------------------------------------------------------------------------- /chapter4/MC_firstvisit_control .py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[32]: 5 | 6 | 7 | import gym 8 | import numpy as np 9 | import sys 10 | import matplotlib 11 | from matplotlib import pyplot as plt 12 | from mpl_toolkits.mplot3d import Axes3D 13 | 14 | from collections import defaultdict 15 | 16 | get_ipython().run_line_magic('matplotlib', 'inline') 17 | matplotlib.style.use('ggplot') 18 | 19 | 20 | # In[27]: 21 | 22 | 23 | def plot_value_function(V, title): 24 | """ 25 | Plots the value function as a surface plot. 
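A quick way to sanity-check the fixed threshold policy used above is to roll it out many times without printing and look at the average reward. A minimal sketch, assuming the same Blackjack-v0 environment and the old gym step API used throughout these scripts:

import gym
import numpy as np

def average_reward(n_episodes=10000, threshold=18):
    # 0 = STAND, 1 = HIT; stand as soon as the player total reaches `threshold`.
    env = gym.make("Blackjack-v0")
    rewards = []
    for _ in range(n_episodes):
        state = env.reset()
        done = False
        while not done:
            player_sum, _, _ = state
            state, reward, done, _ = env.step(0 if player_sum >= threshold else 1)
        rewards.append(reward)
    # The naive threshold policy loses on average, so expect a mean below zero.
    return np.mean(rewards)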
26 | """ 27 | min_x = min(k[0] for k in V.keys()) 28 | max_x = max(k[0] for k in V.keys()) 29 | min_y = min(k[1] for k in V.keys()) 30 | max_y = max(k[1] for k in V.keys()) 31 | 32 | x_range = np.arange(min_x, max_x + 1) 33 | y_range = np.arange(min_y, max_y + 1) 34 | X, Y = np.meshgrid(x_range, y_range) 35 | 36 | Z_noace = np.apply_along_axis(lambda _: V[(_[0], _[1], False)], 2, np.dstack([X, Y])) 37 | Z_ace = np.apply_along_axis(lambda _: V[(_[0], _[1], True)], 2, np.dstack([X, Y])) 38 | 39 | def plot_surface(X, Y, Z, title=None): 40 | fig = plt.figure(figsize=(20, 10), facecolor='white') 41 | 42 | ax = fig.add_subplot(111, projection='3d') 43 | surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, 44 | cmap=matplotlib.cm.coolwarm, vmin=-1.0, vmax=1.0) 45 | ax.set_xlabel('Player sum') 46 | ax.set_ylabel('Dealer showing') 47 | ax.set_zlabel('Value') 48 | if title: ax.set_title(title) 49 | ax.view_init(ax.elev, -120) 50 | ax.set_facecolor("white") 51 | fig.colorbar(surf) 52 | plt.show() 53 | 54 | plot_surface(X, Y, Z_noace, "(No Usable Ace)") 55 | plot_surface(X, Y, Z_ace, "(Usable Ace)") 56 | 57 | 58 | # In[26]: 59 | 60 | 61 | env = gym.make("Blackjack-v0") 62 | 63 | def epsilon_greddy_policy(q, epsilon, nA): 64 | 65 | def policy_(state): 66 | A_ = np.ones(nA, dtype=float) 67 | A = A_ * epsilon / nA 68 | best = np.argmax(q[state]) 69 | A[best] += 1 - epsilon 70 | return A 71 | 72 | return policy_ 73 | 74 | def mc_firstvisit_control_epsilon_greddy(env, num_episodes=100, epsilon=0.1, 75 | episode_endtime = 10, discount=1.0): 76 | nA = env.action_space.n 77 | Q = defaultdict(lambda: np.zeros(nA)) 78 | r_sum = defaultdict(float) 79 | r_cou = defaultdict(float) 80 | 81 | policy = epsilon_greddy_policy(Q, epsilon, nA) 82 | 83 | for i in range(num_episodes): 84 | # print out the episodes rate for displaying. 85 | episode_rate = int(40 * i / num_episodes) 86 | print("Episode {}/{}".format(i+1, num_episodes), end="\r") 87 | sys.stdout.flush() 88 | 89 | # init the episode list and state 90 | episode = [] 91 | state = env.reset() 92 | 93 | # Generate an episode which including tuple(state, aciton, reward). 94 | for j in range(episode_endtime): 95 | # explore and explict the state-action by epsilon greddy algorithm. 
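To make the epsilon-greedy step below concrete: the selection rule spreads probability epsilon uniformly over all actions and puts the remaining 1 - epsilon on the currently greedy one. A standalone illustration (the script itself uses its epsilon_greddy_policy helper for this):

import numpy as np

def epsilon_greedy_probs(q_values, epsilon):
    nA = len(q_values)
    probs = np.ones(nA) * epsilon / nA           # exploration mass, shared by every action
    probs[np.argmax(q_values)] += 1.0 - epsilon  # the greedy action gets the rest
    return probs

probs = epsilon_greedy_probs(np.array([0.1, 0.5]), epsilon=0.1)  # -> [0.05, 0.95]
action = np.random.choice(len(probs), p=probs)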
96 | action_prob = policy(state) 97 | action = np.random.choice(np.arange(action_prob.shape[0]), p=action_prob) 98 | 99 | next_state, reward, done, _ = env.step(action) 100 | episode.append((state, action, reward)) 101 | if done: break 102 | state = next_state 103 | 104 | for k, (state, actions, reward) in enumerate(episode): 105 | # state action pair in tuple type 106 | sa_pair = (state, action) 107 | first_visit_idx = k 108 | G = sum([x[2] * np.power(discount, i) for i, x in enumerate(episode[first_visit_idx:])]) 109 | 110 | r_sum[sa_pair] += G 111 | r_cou[sa_pair] += 1.0 112 | Q[state][actions] = r_sum[sa_pair] / r_cou[sa_pair] 113 | 114 | return Q 115 | 116 | Q = mc_firstvisit_control_epsilon_greddy(env, num_episodes=500000) 117 | 118 | V = defaultdict(float) 119 | for state, actions in Q.items(): 120 | V[state] = np.max(actions) 121 | 122 | plot_value_function(v1, title=None) 123 | 124 | 125 | # In[33]: 126 | 127 | 128 | plot_value_function(V, title=None) 129 | 130 | -------------------------------------------------------------------------------- /chapter4/MC_off_policy_weighted_importance_sampleing.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import gym 8 | import sys 9 | import numpy as np 10 | import matplotlib 11 | from collections import defaultdict 12 | from matplotlib import pyplot as plt 13 | from mpl_toolkits.mplot3d import Axes3D 14 | 15 | 16 | get_ipython().run_line_magic('matplotlib', 'inline') 17 | matplotlib.style.use("ggplot") 18 | 19 | 20 | # In[14]: 21 | 22 | 23 | def plot_value_function(V, title): 24 | """ 25 | Plots the value function as a surface plot. 26 | """ 27 | min_x = min(k[0] for k in V.keys()) 28 | max_x = max(k[0] for k in V.keys()) 29 | min_y = min(k[1] for k in V.keys()) 30 | max_y = max(k[1] for k in V.keys()) 31 | 32 | x_range = np.arange(min_x, max_x + 1) 33 | y_range = np.arange(min_y, max_y + 1) 34 | X, Y = np.meshgrid(x_range, y_range) 35 | 36 | Z_noace = np.apply_along_axis(lambda _: V[(_[0], _[1], False)], 2, np.dstack([X, Y])) 37 | Z_ace = np.apply_along_axis(lambda _: V[(_[0], _[1], True)], 2, np.dstack([X, Y])) 38 | 39 | def plot_surface(X, Y, Z, title=None): 40 | fig = plt.figure(figsize=(20, 10), facecolor='white') 41 | 42 | ax = fig.add_subplot(111, projection='3d') 43 | surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, 44 | cmap=matplotlib.cm.coolwarm, vmin=-1.0, vmax=1.0) 45 | ax.set_xlabel('Player sum') 46 | ax.set_ylabel('Dealer showing') 47 | ax.set_zlabel('Value') 48 | if title: ax.set_title(title) 49 | ax.view_init(ax.elev, -120) 50 | ax.set_facecolor("white") 51 | fig.colorbar(surf) 52 | plt.show() 53 | 54 | plot_surface(X, Y, Z_noace, "(No Usable Ace)") 55 | plot_surface(X, Y, Z_ace, "(Usable Ace)") 56 | 57 | 58 | # In[3]: 59 | 60 | 61 | env = gym.make("Blackjack-v0") 62 | 63 | 64 | # In[4]: 65 | 66 | 67 | def create_random_policy(nA): 68 | 69 | A = np.ones(nA, dtype=float) / nA 70 | def policy_fn(observation): 71 | return A 72 | return policy_fn 73 | 74 | def create_greedy_policy(Q): 75 | 76 | def policy_fn(state): 77 | A = np.zeros_like(Q[state], dtype=float) 78 | best_action = np.argmax(Q[state]) 79 | A[best_action] = 1.0 80 | return A 81 | return policy_fn 82 | 83 | 84 | # In[5]: 85 | 86 | 87 | def mc_control_importance_sampling(env, num_episodes, behavior_policy, discount_factor=1.0): 88 | Q = defaultdict(lambda: np.zeros(env.action_space.n)) 89 | C = defaultdict(lambda: np.zeros(env.action_space.n)) 90 | 91 | # Our greedily policy 
we want to learn 92 | target_policy = create_greedy_policy(Q) 93 | 94 | for i_episode in range(1, num_episodes + 1): 95 | if i_episode % 1000 == 0: 96 | print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="") 97 | sys.stdout.flush() 98 | 99 | # An episode is an array of (state, action, reward) tuples 100 | episode = [] 101 | state = env.reset() 102 | for t in range(100): 103 | # Sample an action from our policy 104 | probs = behavior_policy(state) 105 | action = np.random.choice(np.arange(len(probs)), p=probs) 106 | next_state, reward, done, _ = env.step(action) 107 | episode.append((state, action, reward)) 108 | if done: 109 | break 110 | state = next_state 111 | 112 | # Sum of discounted returns 113 | G = 0.0 114 | # The importance sampling ratio (the weights of the returns) 115 | W = 1.0 116 | # For each step in the episode, backwards 117 | for t in range(len(episode))[::-1]: 118 | state, action, reward = episode[t] 119 | # Update the total reward since step t 120 | G = discount_factor * G + reward 121 | # Update weighted importance sampling formula denominator 122 | C[state][action] += W 123 | # Update the action-value function using the incremental update formula (5.7) 124 | # This also improves our target policy which holds a reference to Q 125 | Q[state][action] += (W / C[state][action]) * (G - Q[state][action]) 126 | # If the action taken by the behavior policy is not the action 127 | # taken by the target policy the probability will be 0 and we can break 128 | if action != np.argmax(target_policy(state)): 129 | break 130 | W = W * 1./behavior_policy(state)[action] 131 | 132 | return Q, target_policy 133 | 134 | 135 | # In[6]: 136 | 137 | 138 | random_policy = create_random_policy(env.action_space.n) 139 | Q, policy = mc_control_importance_sampling(env, num_episodes=500000, behavior_policy=random_policy) 140 | 141 | 142 | # In[15]: 143 | 144 | 145 | V = defaultdict(float) 146 | for state, action_values in Q.items(): 147 | action_value = np.max(action_values) 148 | V[state] = action_value 149 | plot_value_function(V, title="Optimal Value Function") 150 | 151 | -------------------------------------------------------------------------------- /chapter5/TD_CartPole.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import numpy as np\n", 11 | "import sys\n", 12 | "import time" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 8, 18 | "metadata": { 19 | "scrolled": true 20 | }, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . 
Please provide explicit dtype.\u001b[0m\n", 27 | "Episode finished after 17 timesteps\n", 28 | "Game over...\n", 29 | "Episode finished after 15 timesteps\n", 30 | "Game over...\n", 31 | "Episode finished after 42 timesteps\n", 32 | "Game over...\n", 33 | "Episode finished after 19 timesteps\n", 34 | "Game over...\n", 35 | "Episode finished after 45 timesteps\n", 36 | "Game over...\n", 37 | "Episode finished after 16 timesteps\n", 38 | "Game over...\n", 39 | "Episode finished after 25 timesteps\n", 40 | "Game over...\n", 41 | "Episode finished after 41 timesteps\n", 42 | "Game over...\n", 43 | "Episode finished after 32 timesteps\n", 44 | "Game over...\n", 45 | "Episode finished after 37 timesteps\n", 46 | "Game over...\n", 47 | "Episode finished after 24 timesteps\n", 48 | "Game over...\n", 49 | "Episode finished after 24 timesteps\n", 50 | "Game over...\n", 51 | "Episode finished after 33 timesteps\n", 52 | "Game over...\n", 53 | "Episode finished after 17 timesteps\n", 54 | "Game over...\n", 55 | "Episode finished after 11 timesteps\n", 56 | "Game over...\n", 57 | "Episode finished after 48 timesteps\n", 58 | "Game over...\n", 59 | "Episode finished after 12 timesteps\n", 60 | "Game over...\n", 61 | "Episode finished after 36 timesteps\n", 62 | "Game over...\n", 63 | "Episode finished after 16 timesteps\n", 64 | "Game over...\n", 65 | "Episode finished after 15 timesteps\n", 66 | "Game over...\n", 67 | "Episode finished after 30 timesteps\n", 68 | "Game over...\n", 69 | "Episode finished after 12 timesteps\n", 70 | "Game over...\n", 71 | "Episode finished after 25 timesteps\n", 72 | "Game over...\n", 73 | "Episode finished after 24 timesteps\n", 74 | "Game over...\n", 75 | "Episode finished after 24 timesteps\n", 76 | "Game over...\n", 77 | "Episode finished after 30 timesteps\n", 78 | "Game over...\n", 79 | "Episode finished after 28 timesteps\n", 80 | "Game over...\n", 81 | "Episode finished after 18 timesteps\n", 82 | "Game over...\n", 83 | "Episode finished after 16 timesteps\n", 84 | "Game over...\n", 85 | "Episode finished after 14 timesteps\n", 86 | "Game over...\n", 87 | "Episode finished after 21 timesteps\n", 88 | "Game over...\n", 89 | "Episode finished after 15 timesteps\n", 90 | "Game over...\n", 91 | "Episode finished after 13 timesteps\n", 92 | "Game over...\n", 93 | "Episode finished after 16 timesteps\n", 94 | "Game over...\n", 95 | "Episode finished after 15 timesteps\n", 96 | "Game over...\n", 97 | "Episode finished after 18 timesteps\n", 98 | "Game over...\n", 99 | "Episode finished after 13 timesteps\n", 100 | "Game over...\n", 101 | "Episode finished after 61 timesteps\n", 102 | "Game over...\n", 103 | "Episode finished after 43 timesteps\n", 104 | "Game over...\n", 105 | "Episode finished after 16 timesteps\n", 106 | "Game over...\n", 107 | "Episode finished after 33 timesteps\n", 108 | "Game over...\n", 109 | "Episode finished after 23 timesteps\n", 110 | "Game over...\n", 111 | "Episode finished after 15 timesteps\n", 112 | "Game over...\n", 113 | "Episode finished after 13 timesteps\n", 114 | "Game over...\n", 115 | "Episode finished after 15 timesteps\n", 116 | "Game over...\n", 117 | "Episode finished after 23 timesteps\n", 118 | "Game over...\n", 119 | "Episode finished after 19 timesteps\n", 120 | "Game over...\n", 121 | "Episode finished after 44 timesteps\n", 122 | "Game over...\n", 123 | "Episode finished after 19 timesteps\n", 124 | "Game over...\n", 125 | "Episode finished after 16 timesteps\n", 126 | "Game over...\n", 127 | "Episode finished after 12 
timesteps\n", 128 | "Game over...\n", 129 | "Episode finished after 17 timesteps\n", 130 | "Game over...\n", 131 | "Episode finished after 11 timesteps\n", 132 | "Game over...\n", 133 | "Episode finished after 45 timesteps\n", 134 | "Game over...\n", 135 | "Episode finished after 14 timesteps\n", 136 | "Game over...\n", 137 | "Episode finished after 49 timesteps\n", 138 | "Game over...\n", 139 | "Episode finished after 13 timesteps\n", 140 | "Game over...\n", 141 | "Episode finished after 16 timesteps\n", 142 | "Game over...\n", 143 | "Episode finished after 19 timesteps\n", 144 | "Game over...\n", 145 | "Episode finished after 21 timesteps\n", 146 | "Game over...\n", 147 | "Episode finished after 16 timesteps\n", 148 | "Game over...\n", 149 | "Episode finished after 20 timesteps\n", 150 | "Game over...\n", 151 | "Episode finished after 38 timesteps\n", 152 | "Game over...\n", 153 | "Episode finished after 12 timesteps\n", 154 | "Game over...\n", 155 | "Episode finished after 18 timesteps\n", 156 | "Game over...\n", 157 | "Episode finished after 13 timesteps\n", 158 | "Game over...\n", 159 | "Episode finished after 16 timesteps\n", 160 | "Game over...\n", 161 | "Episode finished after 16 timesteps\n", 162 | "Game over...\n", 163 | "Episode finished after 13 timesteps\n", 164 | "Game over...\n", 165 | "Episode finished after 12 timesteps\n", 166 | "Game over...\n", 167 | "Episode finished after 13 timesteps\n", 168 | "Game over...\n", 169 | "Episode finished after 29 timesteps\n", 170 | "Game over...\n", 171 | "Episode finished after 14 timesteps\n", 172 | "Game over...\n", 173 | "Episode finished after 25 timesteps\n", 174 | "Game over...\n", 175 | "Episode finished after 21 timesteps\n", 176 | "Game over...\n", 177 | "Episode finished after 20 timesteps\n", 178 | "Game over...\n", 179 | "Episode finished after 30 timesteps\n", 180 | "Game over...\n", 181 | "Episode finished after 19 timesteps\n", 182 | "Game over...\n", 183 | "Episode finished after 22 timesteps\n", 184 | "Game over...\n", 185 | "Episode finished after 13 timesteps\n", 186 | "Game over...\n", 187 | "Episode finished after 14 timesteps\n", 188 | "Game over...\n", 189 | "Episode finished after 32 timesteps\n", 190 | "Game over...\n", 191 | "Episode finished after 16 timesteps\n", 192 | "Game over...\n", 193 | "Episode finished after 22 timesteps\n", 194 | "Game over...\n", 195 | "Episode finished after 16 timesteps\n", 196 | "Game over...\n", 197 | "Episode finished after 21 timesteps\n", 198 | "Game over...\n", 199 | "Episode finished after 41 timesteps\n", 200 | "Game over...\n", 201 | "Episode finished after 15 timesteps\n", 202 | "Game over...\n", 203 | "Episode finished after 20 timesteps\n", 204 | "Game over...\n", 205 | "Episode finished after 12 timesteps\n", 206 | "Game over...\n", 207 | "Episode finished after 13 timesteps\n", 208 | "Game over...\n", 209 | "Episode finished after 22 timesteps\n", 210 | "Game over...\n", 211 | "Episode finished after 15 timesteps\n", 212 | "Game over...\n", 213 | "Episode finished after 14 timesteps\n", 214 | "Game over...\n", 215 | "Episode finished after 21 timesteps\n", 216 | "Game over...\n", 217 | "Episode finished after 10 timesteps\n", 218 | "Game over...\n", 219 | "Episode finished after 34 timesteps\n", 220 | "Game over...\n", 221 | "Episode finished after 20 timesteps\n", 222 | "Game over...\n", 223 | "Episode finished after 19 timesteps\n", 224 | "Game over...\n", 225 | "Episode finished after 44 timesteps\n", 226 | "Game over...\n", 227 | "Episode finished after 33 
timesteps\n", 228 | "Game over...\n", 229 | "Episode finished after 18 timesteps\n", 230 | "Game over...\n", 231 | "Episode finished after 16 timesteps\n", 232 | "Game over...\n", 233 | "Episode finished after 21 timesteps\n", 234 | "Game over...\n", 235 | "Episode finished after 37 timesteps\n", 236 | "Game over...\n", 237 | "Episode finished after 15 timesteps\n", 238 | "Game over...\n", 239 | "Episode finished after 18 timesteps\n", 240 | "Game over...\n", 241 | "Episode finished after 22 timesteps\n", 242 | "Game over...\n", 243 | "Episode finished after 34 timesteps\n", 244 | "Game over...\n", 245 | "Episode finished after 86 timesteps\n", 246 | "Game over...\n", 247 | "Episode finished after 18 timesteps\n", 248 | "Game over...\n", 249 | "Episode finished after 34 timesteps\n", 250 | "Game over...\n", 251 | "Episode finished after 11 timesteps\n", 252 | "Game over...\n", 253 | "Episode finished after 11 timesteps\n", 254 | "Game over...\n", 255 | "Episode finished after 24 timesteps\n", 256 | "Game over...\n", 257 | "Episode finished after 27 timesteps\n", 258 | "Game over...\n", 259 | "Episode finished after 17 timesteps\n", 260 | "Game over...\n", 261 | "Episode finished after 10 timesteps\n", 262 | "Game over...\n", 263 | "Episode finished after 22 timesteps\n", 264 | "Game over...\n", 265 | "Episode finished after 15 timesteps\n", 266 | "Game over...\n", 267 | "Episode finished after 38 timesteps\n", 268 | "Game over...\n", 269 | "Episode finished after 37 timesteps\n", 270 | "Game over...\n", 271 | "Episode finished after 20 timesteps\n", 272 | "Game over...\n", 273 | "Episode finished after 14 timesteps\n", 274 | "Game over...\n", 275 | "Episode finished after 20 timesteps\n", 276 | "Game over...\n", 277 | "Episode finished after 15 timesteps\n", 278 | "Game over...\n", 279 | "Episode finished after 19 timesteps\n", 280 | "Game over...\n", 281 | "Episode finished after 20 timesteps\n", 282 | "Game over...\n", 283 | "Episode finished after 21 timesteps\n", 284 | "Game over...\n", 285 | "Episode finished after 13 timesteps\n", 286 | "Game over...\n", 287 | "Episode finished after 23 timesteps\n", 288 | "Game over...\n", 289 | "Episode finished after 17 timesteps\n", 290 | "Game over...\n", 291 | "Episode finished after 51 timesteps\n", 292 | "Game over...\n", 293 | "Episode finished after 24 timesteps\n", 294 | "Game over...\n", 295 | "Episode finished after 34 timesteps\n", 296 | "Game over...\n", 297 | "Episode finished after 29 timesteps\n", 298 | "Game over...\n", 299 | "Episode finished after 58 timesteps\n", 300 | "Game over...\n", 301 | "Episode finished after 34 timesteps\n", 302 | "Game over...\n", 303 | "Episode finished after 15 timesteps\n", 304 | "Game over...\n", 305 | "Episode finished after 11 timesteps\n", 306 | "Game over...\n", 307 | "Episode finished after 17 timesteps\n", 308 | "Game over...\n", 309 | "Episode finished after 29 timesteps\n", 310 | "Game over...\n", 311 | "Episode finished after 30 timesteps\n", 312 | "Game over...\n", 313 | "Episode finished after 23 timesteps\n", 314 | "Game over...\n", 315 | "Episode finished after 10 timesteps\n", 316 | "Game over...\n", 317 | "Episode finished after 28 timesteps\n", 318 | "Game over...\n", 319 | "Episode finished after 14 timesteps\n", 320 | "Game over...\n", 321 | "Episode finished after 13 timesteps\n", 322 | "Game over...\n", 323 | "Episode finished after 11 timesteps\n", 324 | "Game over...\n", 325 | "Episode finished after 17 timesteps\n", 326 | "Game over...\n", 327 | "Episode finished after 26 
timesteps\n", 328 | "Game over...\n", 329 | "Episode finished after 40 timesteps\n", 330 | "Game over...\n", 331 | "Episode finished after 24 timesteps\n", 332 | "Game over...\n", 333 | "Episode finished after 12 timesteps\n", 334 | "Game over...\n", 335 | "Episode finished after 14 timesteps\n", 336 | "Game over...\n", 337 | "Episode finished after 64 timesteps\n", 338 | "Game over...\n", 339 | "Episode finished after 45 timesteps\n", 340 | "Game over...\n", 341 | "Episode finished after 14 timesteps\n", 342 | "Game over...\n", 343 | "Episode finished after 16 timesteps\n", 344 | "Game over...\n", 345 | "Episode finished after 16 timesteps\n", 346 | "Game over...\n", 347 | "Episode finished after 36 timesteps\n", 348 | "Game over...\n", 349 | "Episode finished after 36 timesteps\n", 350 | "Game over...\n", 351 | "Episode finished after 22 timesteps\n", 352 | "Game over...\n", 353 | "Episode finished after 33 timesteps\n", 354 | "Game over...\n", 355 | "Episode finished after 41 timesteps\n", 356 | "Game over...\n" 357 | ] 358 | }, 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "Episode finished after 17 timesteps\n", 364 | "Game over...\n", 365 | "Episode finished after 23 timesteps\n", 366 | "Game over...\n", 367 | "Episode finished after 22 timesteps\n", 368 | "Game over...\n", 369 | "Episode finished after 24 timesteps\n", 370 | "Game over...\n", 371 | "Episode finished after 46 timesteps\n", 372 | "Game over...\n", 373 | "Episode finished after 51 timesteps\n", 374 | "Game over...\n", 375 | "Episode finished after 35 timesteps\n", 376 | "Game over...\n", 377 | "Episode finished after 30 timesteps\n", 378 | "Game over...\n", 379 | "Episode finished after 23 timesteps\n", 380 | "Game over...\n", 381 | "Episode finished after 12 timesteps\n", 382 | "Game over...\n", 383 | "Episode finished after 12 timesteps\n", 384 | "Game over...\n", 385 | "Episode finished after 35 timesteps\n", 386 | "Game over...\n", 387 | "Episode finished after 35 timesteps\n", 388 | "Game over...\n", 389 | "Episode finished after 19 timesteps\n", 390 | "Game over...\n", 391 | "Episode finished after 23 timesteps\n", 392 | "Game over...\n", 393 | "Episode finished after 38 timesteps\n", 394 | "Game over...\n", 395 | "Episode finished after 20 timesteps\n", 396 | "Game over...\n", 397 | "Episode finished after 31 timesteps\n", 398 | "Game over...\n", 399 | "Episode finished after 26 timesteps\n", 400 | "Game over...\n", 401 | "Episode finished after 15 timesteps\n", 402 | "Game over...\n", 403 | "Episode finished after 16 timesteps\n", 404 | "Game over...\n", 405 | "Episode finished after 77 timesteps\n", 406 | "Game over...\n", 407 | "Episode finished after 22 timesteps\n", 408 | "Game over...\n", 409 | "Episode finished after 17 timesteps\n", 410 | "Game over...\n", 411 | "Episode finished after 10 timesteps\n", 412 | "Game over...\n", 413 | "Episode finished after 13 timesteps\n", 414 | "Game over...\n", 415 | "Episode finished after 22 timesteps\n", 416 | "Game over...\n", 417 | "Episode finished after 28 timesteps\n", 418 | "Game over...\n", 419 | "Episode finished after 16 timesteps\n", 420 | "Game over...\n", 421 | "Episode finished after 19 timesteps\n", 422 | "Game over...\n", 423 | "Episode finished after 10 timesteps\n", 424 | "Game over...\n", 425 | "Episode finished after 28 timesteps\n", 426 | "Game over...\n", 427 | "Episode finished after 11 timesteps\n", 428 | "Game over...\n", 429 | "Episode finished after 58 timesteps\n", 430 | "Game over...\n", 431 | 
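The episode lengths printed in this cell come from a purely random policy; the same baseline can be measured headlessly (no env.render()), which runs much faster. A minimal sketch, separate from the notebook cell:

import gym

def random_baseline(n_episodes=200):
    env = gym.make("CartPole-v0")
    lengths = []
    for _ in range(n_episodes):
        env.reset()
        steps, done = 0, False
        while not done:
            _, _, done, _ = env.step(env.action_space.sample())
            steps += 1
        lengths.append(steps)
    env.close()
    # With random actions the pole typically falls after roughly 20 steps on average.
    return sum(lengths) / len(lengths)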
"Episode finished after 17 timesteps\n", 432 | "Game over...\n" 433 | ] 434 | } 435 | ], 436 | "source": [ 437 | "env = gym.make(\"CartPole-v0\")\n", 438 | "# env = gym.wrappers.Monitor(env, 'cartpole-experiment-1', force=True)\n", 439 | "\n", 440 | "sumlist = []\n", 441 | "for t in range(200):\n", 442 | " state = env.reset()\n", 443 | " i = 0\n", 444 | " while(True):\n", 445 | " i += 1\n", 446 | " env.render()\n", 447 | " action = env.action_space.sample()\n", 448 | " nA = env.action_space.n\n", 449 | " state, reward, done, _ = env.step(action)\n", 450 | " # print(state, action, reward)\n", 451 | "\n", 452 | " if done:\n", 453 | " print(\"Episode finished after {} timesteps\".format(i+1))\n", 454 | " break\n", 455 | " \n", 456 | " sumlist.append(i)\n", 457 | " print(\"Game over...\")\n", 458 | " \n", 459 | "# env.monitor.close()" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 9, 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [ 468 | "env.close()" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 10, 474 | "metadata": {}, 475 | "outputs": [ 476 | { 477 | "name": "stdout", 478 | "output_type": "stream", 479 | "text": [ 480 | "CartPole game iter average time is: 22.745\n" 481 | ] 482 | } 483 | ], 484 | "source": [ 485 | "iter_time = sum(sumlist)/len(sumlist)\n", 486 | "print(\"CartPole game iter average time is: {}\".format(iter_time))" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "metadata": {}, 493 | "outputs": [], 494 | "source": [] 495 | } 496 | ], 497 | "metadata": { 498 | "kernelspec": { 499 | "display_name": "Python 3", 500 | "language": "python", 501 | "name": "python3" 502 | }, 503 | "language_info": { 504 | "codemirror_mode": { 505 | "name": "ipython", 506 | "version": 3 507 | }, 508 | "file_extension": ".py", 509 | "mimetype": "text/x-python", 510 | "name": "python", 511 | "nbconvert_exporter": "python", 512 | "pygments_lexer": "ipython3", 513 | "version": "3.6.0" 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 2 518 | } 519 | -------------------------------------------------------------------------------- /chapter5/TD_CartPole.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[2]: 5 | 6 | 7 | import gym 8 | import numpy as np 9 | import sys 10 | import time 11 | 12 | 13 | # In[8]: 14 | 15 | 16 | env = gym.make("CartPole-v0") 17 | # env = gym.wrappers.Monitor(env, 'cartpole-experiment-1', force=True) 18 | 19 | sumlist = [] 20 | for t in range(200): 21 | state = env.reset() 22 | i = 0 23 | while(True): 24 | i += 1 25 | env.render() 26 | action = env.action_space.sample() 27 | nA = env.action_space.n 28 | state, reward, done, _ = env.step(action) 29 | # print(state, action, reward) 30 | 31 | if done: 32 | print("Episode finished after {} timesteps".format(i+1)) 33 | break 34 | 35 | sumlist.append(i) 36 | print("Game over...") 37 | 38 | # env.monitor.close() 39 | 40 | 41 | # In[9]: 42 | 43 | 44 | env.close() 45 | 46 | 47 | # In[10]: 48 | 49 | 50 | iter_time = sum(sumlist)/len(sumlist) 51 | print("CartPole game iter average time is: {}".format(iter_time)) 52 | 53 | -------------------------------------------------------------------------------- /chapter5/TD_Qlearning.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[17]: 5 | 6 | 7 | import gym 8 | import numpy as np 9 | import sys 10 | import time 11 | import pandas as 
pd 12 | import matplotlib 13 | from collections import defaultdict, namedtuple 14 | 15 | get_ipython().run_line_magic('matplotlib', 'inline') 16 | matplotlib.style.use('ggplot') 17 | 18 | 19 | # In[4]: 20 | 21 | 22 | env = gym.make("CartPole-v0") 23 | 24 | 25 | # In[20]: 26 | 27 | 28 | class QLearning(): 29 | def __init__(self, env, num_episodes, discount=1.0, alpha=0.5, epsilon=0.1, n_bins=10): 30 | self.nA = env.action_space.n 31 | self.nS = env.observation_space.shape[0] 32 | self.env = env 33 | self.num_episodes = num_episodes 34 | self.discount = discount 35 | self.alpha = alpha 36 | self.epsilon = epsilon 37 | # Initialize Q(s; a) 38 | self.Q = defaultdict(lambda: np.zeros(self.nA)) 39 | 40 | # Keeps track of useful statistics 41 | record = namedtuple("Record", ["episode_lengths","episode_rewards"]) 42 | self.rec = record(episode_lengths=np.zeros(num_episodes), 43 | episode_rewards=np.zeros(num_episodes)) 44 | 45 | self.cart_position_bins = pd.cut([-2.4, 2.4], bins=n_bins, retbins=True)[1] 46 | self.pole_angle_bins = pd.cut([-2, 2], bins=n_bins, retbins=True)[1] 47 | self.cart_velocity_bins = pd.cut([-1, 1], bins=n_bins, retbins=True)[1] 48 | self.angle_rate_bins = pd.cut([-3.5, 3.5], bins=n_bins, retbins=True)[1] 49 | 50 | def __get_bins_states(self, state): 51 | """ 52 | Case number of the sate is huge so in order to simplify the situation 53 | cut the state sapece in to bins. 54 | 55 | if the state_idx is [1,3,6,4] than the return will be 1364 56 | """ 57 | s1_, s2_, s3_, s4_ = state 58 | cart_position_idx = np.digitize(s1_, self.cart_position_bins) 59 | pole_angle_idx = np.digitize(s2_, self.pole_angle_bins) 60 | cart_velocity_idx = np.digitize(s3_, self.cart_velocity_bins) 61 | angle_rate_idx = np.digitize(s4_, self.angle_rate_bins) 62 | 63 | state_ = [cart_position_idx, pole_angle_idx, 64 | cart_velocity_idx, angle_rate_idx] 65 | 66 | state = map(lambda s: int(s), state_) 67 | return tuple(state) 68 | 69 | def __epislon_greedy_policy(self, epsilon, nA): 70 | 71 | def policy(state): 72 | A = np.ones(nA, dtype=float) * epsilon / nA 73 | best_action = np.argmax(self.Q[state]) 74 | A[best_action] += (1.0 - epsilon) 75 | return A 76 | 77 | return policy 78 | 79 | def __next_action(self, prob): 80 | return np.random.choice(np.arange(len(prob)), p=prob) 81 | 82 | def qlearning(self): 83 | """ 84 | q-learning algo 85 | """ 86 | policy = self.__epislon_greedy_policy(self.epsilon, self.nA) 87 | sumlist = [] 88 | 89 | for i_episode in range(self.num_episodes): 90 | # Print out which episode we are on 91 | if 0 == (i_episode+1) % 10: 92 | print("\r Episode {} in {}".format(i_episode+1, self.num_episodes)) 93 | # sys.stdout.flush() 94 | 95 | step = 0 96 | # Initialize S 97 | state__ = self.env.reset() 98 | state = self.__get_bins_states(state__) 99 | 100 | # Repeat (for each step of episode) 101 | while(True): 102 | # Choose A from S using policy derived from Q 103 | prob_actions = policy(state) 104 | action = self.__next_action(prob_actions) 105 | 106 | # Take action A, observe R, S' 107 | next_state__, reward, done, info = env.step(action) 108 | next_state = self.__get_bins_states(next_state__) 109 | 110 | # update history record 111 | self.rec.episode_lengths[i_episode] += reward 112 | self.rec.episode_rewards[i_episode] = step 113 | 114 | # TD update: Q(S; A)<-Q(S; A) + aplha*[R + discount * max Q(S'; a) − Q(S; A)] 115 | best_next_action = np.argmax(self.Q[next_state]) 116 | td_target = reward + self.discount * self.Q[next_state][best_next_action] 117 | td_delta = td_target - 
self.Q[state][action] 118 | self.Q[state][action] += self.alpha * td_delta 119 | 120 | if done: 121 | # until S is terminal 122 | print("Episode finished after {} timesteps".format(step)) 123 | sumlist.append(step) 124 | break 125 | else: 126 | step += 1 127 | # S<-S' 128 | state = next_state 129 | 130 | iter_time = sum(sumlist)/len(sumlist) 131 | print("CartPole game iter average time is: {}".format(iter_time)) 132 | return self.Q 133 | 134 | cls_qlearning = QLearning(env, num_episodes=200) 135 | Q = cls_qlearning.qlearning() 136 | 137 | 138 | # In[27]: 139 | 140 | 141 | from matplotlib import pyplot as plt 142 | 143 | def plot_episode_stats(stats, smoothing_window=10, noshow=False): 144 | # Plot the episode length over time 145 | fig1 = plt.figure(figsize=(10,5)) 146 | plt.plot(stats.episode_lengths[:200]) 147 | plt.xlabel("Episode") 148 | plt.ylabel("Episode Length") 149 | plt.title("Episode Length over Time") 150 | if noshow: 151 | plt.close(fig1) 152 | else: 153 | plt.show(fig1) 154 | 155 | # Plot the episode reward over time 156 | fig2 = plt.figure(figsize=(10,5)) 157 | rewards_smoothed = pd.Series(stats.episode_rewards[:200]).rolling(smoothing_window, min_periods=smoothing_window).mean() 158 | plt.plot(rewards_smoothed) 159 | plt.xlabel("Episode") 160 | plt.ylabel("Episode Reward") 161 | plt.title("Episode Reward over Time".format(smoothing_window)) 162 | if noshow: 163 | plt.close(fig2) 164 | else: 165 | plt.show(fig2) 166 | 167 | return fig1, fig2 168 | 169 | plot_episode_stats(cls_qlearning.rec) 170 | 171 | -------------------------------------------------------------------------------- /chapter5/TD_sarsa.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[4]: 5 | 6 | 7 | import gym 8 | import numpy as np 9 | import sys 10 | import time 11 | import pandas as pd 12 | import matplotlib 13 | from collections import defaultdict, namedtuple 14 | 15 | get_ipython().run_line_magic('matplotlib', 'inline') 16 | matplotlib.style.use('ggplot') 17 | 18 | 19 | # In[5]: 20 | 21 | 22 | env = gym.make("CartPole-v0") 23 | 24 | 25 | # In[7]: 26 | 27 | 28 | class SARSA(): 29 | def __init__(self, env, num_episodes, discount=1.0, alpha=0.5, epsilon=0.1, n_bins=10): 30 | self.nA = env.action_space.n 31 | self.nS = env.observation_space.shape[0] 32 | self.env = env 33 | self.num_episodes = num_episodes 34 | self.discount = discount 35 | self.alpha = alpha 36 | self.epsilon = epsilon 37 | self.Q = defaultdict(lambda: np.zeros(self.nA)) 38 | 39 | # Keeps track of useful statistics 40 | record = namedtuple("Record", ["episode_lengths","episode_rewards"]) 41 | self.rec = record(episode_lengths=np.zeros(num_episodes), 42 | episode_rewards=np.zeros(num_episodes)) 43 | 44 | self.cart_position_bins = pd.cut([-2.4, 2.4], bins=n_bins, retbins=True)[1] 45 | self.pole_angle_bins = pd.cut([-2, 2], bins=n_bins, retbins=True)[1] 46 | self.cart_velocity_bins = pd.cut([-1, 1], bins=n_bins, retbins=True)[1] 47 | self.angle_rate_bins = pd.cut([-3.5, 3.5], bins=n_bins, retbins=True)[1] 48 | 49 | def __get_bins_states(self, state): 50 | """ 51 | Case number of the sate is huge so in order to simplify the situation 52 | cut the state sapece in to bins. 
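Since TD_Qlearning.py and TD_sarsa.py differ mainly in how they bootstrap the TD target, a compact side-by-side may help. A sketch only, assuming Q is a dict of per-state action-value arrays like the classes above use:

import numpy as np

def q_learning_target(Q, reward, next_state, discount=1.0):
    # Off-policy: bootstrap from the greedy action in the next state.
    return reward + discount * np.max(Q[next_state])

def sarsa_target(Q, reward, next_state, next_action, discount=1.0):
    # On-policy: bootstrap from the action the epsilon-greedy policy actually selected.
    return reward + discount * Q[next_state][next_action]

def td_update(Q, state, action, target, alpha=0.5):
    Q[state][action] += alpha * (target - Q[state][action])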
53 | 54 | if the state_idx is [1,3,6,4] than the return will be 1364 55 | """ 56 | s1_, s2_, s3_, s4_ = state 57 | cart_position_idx = np.digitize(s1_, self.cart_position_bins) 58 | pole_angle_idx = np.digitize(s2_, self.pole_angle_bins) 59 | cart_velocity_idx = np.digitize(s3_, self.cart_velocity_bins) 60 | angle_rate_idx = np.digitize(s4_, self.angle_rate_bins) 61 | 62 | state_ = [cart_position_idx, pole_angle_idx, 63 | cart_velocity_idx, angle_rate_idx] 64 | 65 | state = map(lambda s: int(s), state_) 66 | return tuple(state) 67 | 68 | def __epislon_greedy_policy(self, epsilon, nA): 69 | 70 | def policy(state): 71 | A = np.ones(nA, dtype=float) * epsilon / nA 72 | best_action = np.argmax(self.Q[state]) 73 | A[best_action] += (1.0 - epsilon) 74 | return A 75 | 76 | return policy 77 | 78 | def __next_action(self, prob): 79 | return np.random.choice(np.arange(len(prob)), p=prob) 80 | 81 | def sarsa(self): 82 | """ 83 | SARSA algo 84 | """ 85 | policy = self.__epislon_greedy_policy(self.epsilon, self.nA) 86 | sumlist = [] 87 | 88 | for i_episode in range(self.num_episodes): 89 | if 0 == (i_episode+1) % 10: 90 | print("\r Episode {} in {}".format(i_episode+1, self.num_episodes)) 91 | # sys.stdout.flush() 92 | 93 | step = 0 94 | state__ = self.env.reset() 95 | state = self.__get_bins_states(state__) 96 | prob_actions = policy(state) 97 | action = self.__next_action(prob_actions) 98 | 99 | # one step 100 | while(True): 101 | next_state__, reward, done, info = env.step(action) 102 | next_state = self.__get_bins_states(next_state__) 103 | 104 | prob_next_actions = policy(next_state) 105 | next_action = self.__next_action(prob_next_actions) 106 | 107 | # update history record 108 | self.rec.episode_lengths[i_episode] += reward 109 | self.rec.episode_rewards[i_episode] = step 110 | 111 | # TD update 112 | td_target = reward + self.discount * self.Q[next_state][next_action] 113 | td_delta = td_target - self.Q[state][action] 114 | self.Q[state][action] += self.alpha * td_delta 115 | 116 | if done: 117 | reward = -200 118 | print("Episode finished after {} timesteps".format(step)) 119 | sumlist.append(step) 120 | break 121 | else: 122 | step += 1 123 | state = next_state 124 | action = next_action 125 | 126 | iter_time = sum(sumlist)/len(sumlist) 127 | print("CartPole game iter average time is: {}".format(iter_time)) 128 | return self.Q 129 | 130 | cls_sarsa = SARSA(env, num_episodes=1000) 131 | Q = cls_sarsa.sarsa() 132 | 133 | 134 | # In[9]: 135 | 136 | 137 | from matplotlib import pyplot as plt 138 | 139 | def plot_episode_stats(stats, smoothing_window=10, noshow=False): 140 | # Plot the episode length over time 141 | fig1 = plt.figure(figsize=(10,5)) 142 | plt.plot(stats.episode_lengths[:200]) 143 | plt.xlabel("Episode") 144 | plt.ylabel("Episode Length") 145 | plt.title("Episode Length over Time") 146 | if noshow: 147 | plt.close(fig1) 148 | else: 149 | plt.show(fig1) 150 | 151 | # Plot the episode reward over time 152 | fig2 = plt.figure(figsize=(10,5)) 153 | rewards_smoothed = pd.Series(stats.episode_rewards[:200]).rolling(smoothing_window, min_periods=smoothing_window).mean() 154 | plt.plot(rewards_smoothed) 155 | plt.xlabel("Episode") 156 | plt.ylabel("Episode Reward") 157 | plt.title("Episode Reward over Time".format(smoothing_window)) 158 | if noshow: 159 | plt.close(fig2) 160 | else: 161 | plt.show(fig2) 162 | 163 | return fig1, fig2 164 | 165 | plot_episode_stats(cls_sarsa.rec) 166 | 167 | -------------------------------------------------------------------------------- 
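Both chapter-5 scripts above (TD_Qlearning.py and TD_sarsa.py) can only index a tabular `Q` because they first bucket CartPole's four continuous observation components into bins. Since the docstring of `__get_bins_states` only hints at how this works, here is a minimal standalone sketch of the same idea; the bounds mirror the ones hard-coded in the classes, while the helper name `to_discrete` and the sample observation are illustrative only, not part of the repository.

```python
import numpy as np
import pandas as pd

n_bins = 10
# pd.cut(..., retbins=True)[1] returns the array of bin edges for a fixed value range.
cart_position_bins = pd.cut([-2.4, 2.4], bins=n_bins, retbins=True)[1]
pole_angle_bins = pd.cut([-2.0, 2.0], bins=n_bins, retbins=True)[1]
cart_velocity_bins = pd.cut([-1.0, 1.0], bins=n_bins, retbins=True)[1]
angle_rate_bins = pd.cut([-3.5, 3.5], bins=n_bins, retbins=True)[1]

def to_discrete(state):
    """Map a 4-dim continuous CartPole observation to a tuple of bin indices."""
    s1, s2, s3, s4 = state
    # np.digitize returns the index of the bin each component falls into.
    return (int(np.digitize(s1, cart_position_bins)),
            int(np.digitize(s2, pole_angle_bins)),
            int(np.digitize(s3, cart_velocity_bins)),
            int(np.digitize(s4, angle_rate_bins)))

print(to_discrete([0.03, 0.01, -0.25, 0.5]))  # a hashable tuple such as (6, 6, 4, 6)
```

Returning a tuple of plain ints matters: it is hashable, so it can serve as a key into the `defaultdict` Q-table, and wrapping the `map` result in `tuple(...)` is what keeps the original code working under Python 3, where `map` returns an iterator.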
/chapter6/FA_Qlearning.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[58]: 5 | 6 | 7 | import gym 8 | import sys 9 | import numpy as np 10 | import matplotlib 11 | from collections import defaultdict, namedtuple 12 | from matplotlib import pyplot as plt 13 | from mpl_toolkits.mplot3d import Axes3D 14 | 15 | from sklearn import pipeline 16 | from sklearn.preprocessing import StandardScaler 17 | from sklearn.linear_model import SGDRegressor 18 | from sklearn.kernel_approximation import RBFSampler 19 | 20 | get_ipython().run_line_magic('matplotlib', 'inline') 21 | matplotlib.style.use("ggplot") 22 | 23 | 24 | # In[59]: 25 | 26 | 27 | env = gym.envs.make("MountainCar-v0") 28 | 29 | 30 | # In[60]: 31 | 32 | 33 | state_samples = np.array([env.observation_space.sample() for x in range(10000)]) 34 | # Num Observation Min Max 35 | # 0 position -1.2 0.6 36 | # 1 velocity -0.07 0.07 37 | position_max = np.amax(state_samples[:, 0]) 38 | position_min = np.amin(state_samples[:, 0]) 39 | velocity_max = np.amax(state_samples[:, 1]) 40 | velocity_min = np.amin(state_samples[:, 1]) 41 | 42 | scaler = StandardScaler() 43 | scaler.fit(state_samples) 44 | scaler_samples = scaler.transform(state_samples) 45 | 46 | featurizer_state = RBFSampler(gamma=0.5, n_components=100) 47 | featurizer_state.fit(scaler_samples) 48 | print(featurizer_state) 49 | 50 | state = env.reset() 51 | print(state_samples[20]) 52 | featurized = featurizer_state.transform([state_samples[10]]) 53 | 54 | 55 | # In[75]: 56 | 57 | 58 | class ValueFunction(object): 59 | """ 60 | Value Function approximator. 61 | """ 62 | def __init__(self): 63 | # sample environment states in order to fit the featurizer 64 | state_samples = np.array([env.observation_space.sample() for x in range(10000)]) 65 | 66 | # Standardize features by removing the mean and scaling to unit variance 67 | self.scaler = StandardScaler() 68 | self.scaler.fit(state_samples) 69 | scaler_samples = self.scaler.transform(state_samples) 70 | 71 | # Approximates feature map of an RBF kernel 72 | # by Monte Carlo approximation of its Fourier transform. 73 | self.featurizer_state = RBFSampler(gamma=0.5, n_components=100) 74 | self.featurizer_state.fit(scaler_samples) 75 | 76 | # one SGD regressor per action 77 | self.action_models = [] 78 | nA = env.action_space.n 79 | for na in range(nA): 80 | # Linear regressors fitted by SGD training.
81 | model = SGDRegressor(learning_rate="constant") 82 | model.partial_fit([self.__featurize_state(env.reset())], [0]) 83 | self.action_models.append(model) 84 | 85 | # print(self.action_models) 86 | 87 | def __featurize_state(self, state): 88 | scaler_state = self.scaler.transform([state]) 89 | return self.featurizer_state.transform(scaler_state)[0] 90 | 91 | def predict(self, state): 92 | curr_features = self.__featurize_state(state) 93 | action_probs = np.array([m.predict([curr_features])[0] for m in self.action_models]) 94 | # print(action_probs) 95 | return action_probs 96 | 97 | def update(self, state, action, y): 98 | curr_features = self.__featurize_state(state) 99 | self.action_models[action].partial_fit([curr_features], [y]) 100 | 101 | class QLearning(): 102 | def __init__(self, env, num_episodes, discount=1.0, alpha=0.5, epsilon=0.1, ep_decay=1.0): 103 | self.nA = env.action_space.n 104 | self.nS = env.observation_space.shape[0] 105 | self.env = env 106 | self.num_episodes = num_episodes 107 | self.discount = discount 108 | self.alpha = alpha 109 | self.epsilon = epsilon 110 | self.vfa = ValueFunction() 111 | self.ep_decay = ep_decay 112 | 113 | # Keeps track of useful statistics 114 | record = namedtuple("Record", ["episode_lengths","episode_rewards"]) 115 | self.rec = record(episode_lengths=np.zeros(num_episodes), 116 | episode_rewards=np.zeros(num_episodes)) 117 | 118 | def __epislon_greedy_policy(self, epsilon, nA): 119 | 120 | def policy(state): 121 | A = np.ones(nA, dtype=float) * epsilon / nA 122 | Q = self.vfa.predict(state) 123 | best_action = np.argmax(Q) 124 | A[best_action] += (1.0 - epsilon) 125 | return A 126 | 127 | return policy 128 | 129 | def __next_action(self, prob): 130 | return np.random.choice(np.arange(len(prob)), p=prob) 131 | 132 | def qlearning(self): 133 | """ 134 | Q-learning algo 135 | """ 136 | sumlist = [] 137 | 138 | for i_episode in range(self.num_episodes): 139 | # Print out which episode we are on 140 | if 0 == (i_episode) % 10: 141 | print("Episode {} in {}".format(i_episode+1, self.num_episodes)) 142 | # sys.stdout.flush() 143 | 144 | # following current policy 145 | policy_epsilon = self.epsilon*self.ep_decay**i_episode 146 | policy = self.__epislon_greedy_policy(policy_epsilon, self.nA) 147 | 148 | step = 0 149 | # Initialize S 150 | state = self.env.reset() 151 | 152 | # Repeat (for each step of episode) 153 | while(True): 154 | # Choose A from S using policy derived from Q 155 | prob_actions = policy(state) 156 | action = self.__next_action(prob_actions) 157 | 158 | # Take action A, observe R, S' 159 | next_state, reward, done, info = env.step(action) 160 | 161 | # update history record 162 | self.rec.episode_lengths[i_episode] += reward 163 | self.rec.episode_rewards[i_episode] = step 164 | 165 | # TD update: Q(S; A)<-Q(S; A) + aplha*[R + discount * max Q(S'; a) − Q(S; A)] 166 | q_next_value = self.vfa.predict(next_state) 167 | td_target = reward + self.discount * np.max(q_next_value) 168 | self.vfa.update(state, action, td_target) 169 | 170 | if done: 171 | # until S is terminal 172 | print("Episode finished after {} timesteps".format(step)) 173 | sumlist.append(step) 174 | break 175 | else: 176 | step += 1 177 | # S<-S' 178 | state = next_state 179 | 180 | iter_time = sum(sumlist)/len(sumlist) 181 | print("MountainCar game iter average time is: {}".format(iter_time)) 182 | 183 | cls_qlearning = QLearning(env, num_episodes=200) 184 | cls_qlearning.qlearning() 185 | 186 | 187 | # In[77]: 188 | 189 | 190 | from matplotlib import pyplot 
as plt 191 | import pandas as pd 192 | 193 | def plot_episode_stats(stats, smoothing_window=10, noshow=False): 194 | # Plot the episode length over time 195 | fig1 = plt.figure(figsize=(10,5)) 196 | plt.plot(stats.episode_lengths[:200]) 197 | plt.xlabel("Episode") 198 | plt.ylabel("Episode Length") 199 | plt.title("Episode Length over Time") 200 | if noshow: 201 | plt.close(fig1) 202 | else: 203 | plt.show(fig1) 204 | 205 | # Plot the episode reward over time 206 | fig2 = plt.figure(figsize=(10,5)) 207 | rewards_smoothed = pd.Series(stats.episode_rewards[:200]).rolling(smoothing_window, min_periods=smoothing_window).mean() 208 | plt.plot(rewards_smoothed) 209 | plt.xlabel("Episode") 210 | plt.ylabel("Episode Reward") 211 | plt.title("Episode Reward over Time".format(smoothing_window)) 212 | if noshow: 213 | plt.close(fig2) 214 | else: 215 | plt.show(fig2) 216 | 217 | return fig1, fig2 218 | 219 | plot_episode_stats(cls_qlearning.rec) 220 | 221 | -------------------------------------------------------------------------------- /chapter6/FA_Qlearning2.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import gym 8 | import sys 9 | import itertools 10 | import matplotlib 11 | import numpy as np 12 | import pandas as pd 13 | from collections import defaultdict, namedtuple 14 | 15 | from sklearn.pipeline import FeatureUnion 16 | from sklearn.preprocessing import StandardScaler as Scaler 17 | from sklearn.linear_model import SGDRegressor as SGD 18 | from sklearn.kernel_approximation import RBFSampler as RBF 19 | 20 | from matplotlib import pyplot as plt 21 | from mpl_toolkits.mplot3d import Axes3D 22 | 23 | get_ipython().run_line_magic('matplotlib', 'inline') 24 | matplotlib.style.use('ggplot') 25 | 26 | 27 | # In[2]: 28 | 29 | 30 | env = gym.envs.make("MountainCar-v0") 31 | 32 | 33 | # In[7]: 34 | 35 | 36 | class Estimator(): 37 | """ 38 | Value Function approximator. 39 | """ 40 | def __init__(self): 41 | # sampleing envrionment state in order to featurize it. 42 | observation_examples = np.array([env.observation_space.sample() for x in range(10000)]) 43 | 44 | # Feature Preprocessing: Normalize to zero mean and unit variance 45 | # We use a few samples from the observation space to do this 46 | self.scaler = Scaler() 47 | self.scaler.fit(observation_examples) 48 | 49 | # Used to convert a state to a featurizes represenation. 50 | # We use RBF kernels with different variances to cover different parts of the space 51 | self.featurizer = FeatureUnion([ 52 | ("rbf1", RBF(gamma=5.0, n_components=100)), 53 | ("rbf2", RBF(gamma=2.0, n_components=100)), 54 | ("rbf3", RBF(gamma=1.0, n_components=100)), 55 | ("rbf4", RBF(gamma=0.5, n_components=100)) 56 | ]) 57 | self.featurizer.fit(self.scaler.transform(observation_examples)) 58 | 59 | # action model for SGD regressor 60 | self.action_models = [] 61 | self.nA = env.action_space.n 62 | 63 | for na in range(self.nA): 64 | model = SGD(learning_rate="constant") 65 | model.partial_fit([self.__featurize_state(env.reset())], [0]) 66 | self.action_models.append(model) 67 | 68 | # print(self.action_models) 69 | 70 | def __featurize_state(self, state): 71 | """ 72 | Returns the featurized representation for a state. 73 | """ 74 | scaled = self.scaler.transform([state]) 75 | return self.featurizer.transform(scaled)[0] 76 | 77 | def predict(self, s, a=None): 78 | """ 79 | Makes value function predictions. 
80 | 81 | Args: 82 | s: state to make a prediction for 83 | a: (Optional) action to make a prediction for 84 | 85 | Returns 86 | If an action a is given this returns a single number as the prediction. 87 | If no action is given this returns a vector or predictions for all actions 88 | in the environment where pred[i] is the prediction for action i. 89 | 90 | """ 91 | features = self.__featurize_state(s) 92 | if not a: 93 | return np.array([model.predict([features])[0] for model in self.action_models]) 94 | else: 95 | return self.action_models[a].predict([features])[0] 96 | 97 | def update(self, s, a, y): 98 | """ 99 | Updates the estimator parameters for a given state and action towards 100 | the target y. 101 | """ 102 | cur_features = self.__featurize_state(s) 103 | self.action_models[a].partial_fit([cur_features], [y]) 104 | 105 | 106 | # In[9]: 107 | 108 | 109 | class VF_QLearning(): 110 | """ 111 | Value Funciton Approximator with Q-learning 112 | 113 | Q-Learning algorithm for TD control using Function Approximation. 114 | Finds the optimal greedy policy while following an epsilon-greedy policy. 115 | """ 116 | def __init__(self, env, estimator, 117 | num_episodes, epsilon=0.1, 118 | discount_factor=1.0, epsilon_decay=1.0): 119 | 120 | self.nA = env.action_space.n 121 | self.nS = env.observation_space.shape[0] 122 | self.env = env 123 | self.num_episodes = num_episodes 124 | self.epsilon = epsilon 125 | self.discount_factor = discount_factor 126 | self.epsilon_decay = epsilon_decay 127 | self.estimator = estimator 128 | 129 | # Keeps track of useful statistics 130 | record_head = namedtuple("Stats",["episode_lengths", "episode_rewards"]) 131 | self.record = record_head( 132 | episode_lengths=np.zeros(num_episodes), 133 | episode_rewards=np.zeros(num_episodes)) 134 | 135 | def __epislon_greedy_policy(self, nA, epislon=0.5): 136 | """ 137 | epislon greedy policy algorithm 138 | """ 139 | def policy(state): 140 | A = np.ones(nA, dtype=float) * epislon / nA 141 | Q = self.estimator.predict(state) 142 | best_action = np.argmax(Q) 143 | A[best_action] += (1.0 - epislon) 144 | 145 | return A 146 | 147 | return policy 148 | 149 | def __random_aciton(self, action_prob): 150 | """ 151 | """ 152 | return np.random.choice(np.arange(len(action_prob)), p=action_prob) 153 | 154 | def q_learning(self): 155 | """ 156 | """ 157 | for i_episode in range(self.num_episodes): 158 | # print the number iter episode 159 | num_present = (i_episode+1)/self.num_episodes 160 | print("Episode {}/{}".format(i_episode + 1, self.num_episodes), end="") 161 | print("="*round(num_present*60)) 162 | 163 | # The policy we're following 164 | policy_epislon = self.epsilon * self.epsilon_decay**i_episode 165 | policy = self.__epislon_greedy_policy(self.nA, policy_epislon) 166 | 167 | # Print out which episode we're on, useful for debugging. 
168 | # Also print reward for last episode 169 | last_reward = self.record.episode_rewards[i_episode - 1] 170 | sys.stdout.flush() 171 | 172 | # Reset the environment and pick the first action 173 | state = env.reset() 174 | 175 | next_action = None 176 | 177 | # One step in the environment, replace while(True) 178 | for t in itertools.count(): 179 | action_probs = policy(state) 180 | action = self.__random_aciton(action_probs) 181 | 182 | # Take a step 183 | next_state, reward, done, _ = env.step(action) 184 | 185 | # Update statistics 186 | self.record.episode_rewards[i_episode] += reward 187 | self.record.episode_lengths[i_episode] = t 188 | 189 | # TD Update 190 | q_values_next = estimator.predict(next_state) 191 | # Q-Value TD Target 192 | td_target = reward + self.discount_factor * np.max(q_values_next) 193 | # Update the function approximator using our target 194 | estimator.update(state, action, td_target) 195 | 196 | print("\rStep {} with reward ({})".format(t, last_reward), end="") 197 | 198 | if done: break 199 | 200 | state = next_state 201 | 202 | return self.record 203 | 204 | 205 | # In[32]: 206 | 207 | 208 | estimator = Estimator() 209 | vf = VF_QLearning(env, estimator, num_episodes=100, epsilon=0.2) 210 | result = vf.q_learning() 211 | 212 | 213 | # In[33]: 214 | 215 | 216 | def plot_cost_to_go_mountain_car(env, estimator, niter, num_tiles=20): 217 | x = np.linspace(env.observation_space.low[0], env.observation_space.high[0], num=num_tiles) 218 | y = np.linspace(env.observation_space.low[1], env.observation_space.high[1], num=num_tiles) 219 | X, Y = np.meshgrid(x, y) 220 | Z = np.apply_along_axis(lambda _: -np.max(estimator.predict(_)), 2, np.dstack([X, Y])) 221 | 222 | fig = plt.figure(figsize=(15,7.5)) 223 | ax = fig.add_subplot(111, projection='3d') 224 | surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, 225 | cmap=matplotlib.cm.coolwarm, vmin=0, vmax=160) 226 | ax.set_xlabel('Position') 227 | ax.set_ylabel('Velocity') 228 | ax.set_zlabel('Value') 229 | ax.set_zlim(0, 160) 230 | ax.set_facecolor("white") 231 | ax.set_title("Cost To Go Function (iter:{})".format(niter)) 232 | fig.colorbar(surf) 233 | plt.show() 234 | 235 | def plot_episode_stats(stats, smoothing_window=10, noshow=False): 236 | # Plot the episode length over time 237 | fig1 = plt.figure(figsize=(15,7.5)) 238 | plt.plot(stats.episode_lengths) 239 | plt.xlabel("Episode") 240 | plt.ylabel("Episode Length") 241 | plt.title("Episode Length over Time") 242 | if noshow: 243 | plt.close(fig1) 244 | else: 245 | plt.show(fig1) 246 | 247 | # Plot the episode reward over time 248 | fig2 = plt.figure(figsize=(15,7.5)) 249 | rewards_smoothed = pd.Series(stats.episode_rewards).rolling(smoothing_window, min_periods=smoothing_window).mean() 250 | plt.plot(rewards_smoothed) 251 | plt.xlabel("Episode") 252 | plt.ylabel("Episode Reward (Smoothed)") 253 | plt.title("Episode Reward over Time (Smoothed over window size {})".format(smoothing_window)) 254 | if noshow: 255 | plt.close(fig2) 256 | else: 257 | plt.show(fig2) 258 | 259 | # Plot time steps and episode number 260 | fig3 = plt.figure(figsize=(15,7.5)) 261 | plt.plot(np.cumsum(stats.episode_lengths), np.arange(len(stats.episode_lengths))) 262 | plt.xlabel("Time Steps") 263 | plt.ylabel("Episode") 264 | plt.title("Episode per time step") 265 | if noshow: 266 | plt.close(fig3) 267 | else: 268 | plt.show(fig3) 269 | 270 | return fig1, fig2, fig3 271 | 272 | 273 | plot_cost_to_go_mountain_car(env, estimator, 100) 274 | plot_episode_stats(result) 275 | 276 | 
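FA_Qlearning.py fits a single RBFSampler over standardized state samples, while FA_Qlearning2.py stacks four RBFSampler instances with different gamma values inside a FeatureUnion so the features cover several length scales; both then keep one SGDRegressor per discrete action and train it online with partial_fit. The sketch below isolates that shared recipe under the same old-style gym API used throughout this chapter; the helper names featurize, q_values and td_update are illustrative and not part of the repository.

```python
import gym
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.kernel_approximation import RBFSampler
from sklearn.linear_model import SGDRegressor

env = gym.make("MountainCar-v0")
samples = np.array([env.observation_space.sample() for _ in range(10000)])

# Standardize the raw observations, then project them onto random RBF features.
scaler = StandardScaler().fit(samples)
rbf = RBFSampler(gamma=0.5, n_components=100).fit(scaler.transform(samples))

def featurize(state):
    return rbf.transform(scaler.transform([state]))[0]

# One linear model per discrete action; partial_fit performs a single SGD step,
# which is exactly the per-transition update these scripts rely on.
models = []
for _ in range(env.action_space.n):
    m = SGDRegressor(learning_rate="constant")
    m.partial_fit([featurize(env.reset())], [0.0])  # initialise with a dummy target
    models.append(m)

def q_values(state):
    phi = featurize(state)
    return np.array([m.predict([phi])[0] for m in models])

def td_update(state, action, reward, next_state, gamma=1.0):
    # Semi-gradient Q-learning: regress Q(s, a) towards r + gamma * max_a' Q(s', a').
    target = reward + gamma * np.max(q_values(next_state))
    models[action].partial_fit([featurize(state)], [target])
```

The reliance on partial_fit is the key design choice: calling fit instead would re-initialise the regressor on every update and discard everything learned so far.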
-------------------------------------------------------------------------------- /chapter6/FA_SARSA.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[58]: 5 | 6 | 7 | import gym 8 | import sys 9 | import numpy as np 10 | import matplotlib 11 | from collections import defaultdict, namedtuple 12 | from matplotlib import pyplot as plt 13 | from mpl_toolkits.mplot3d import Axes3D 14 | 15 | from sklearn import pipeline 16 | from sklearn.preprocessing import StandardScaler 17 | from sklearn.linear_model import SGDRegressor 18 | from sklearn.kernel_approximation import RBFSampler 19 | 20 | get_ipython().run_line_magic('matplotlib', 'inline') 21 | matplotlib.style.use("ggplot") 22 | 23 | 24 | # In[59]: 25 | 26 | 27 | env = gym.envs.make("MountainCar-v0") 28 | 29 | 30 | # In[75]: 31 | 32 | 33 | class ValueFunction(object): 34 | """ 35 | Value Function approximator. 36 | """ 37 | def __init__(self): 38 | # sample environment states in order to fit the featurizer 39 | state_samples = np.array([env.observation_space.sample() for x in range(10000)]) 40 | 41 | # Standardize features by removing the mean and scaling to unit variance 42 | self.scaler = StandardScaler() 43 | self.scaler.fit(state_samples) 44 | scaler_samples = self.scaler.transform(state_samples) 45 | 46 | # Approximates feature map of an RBF kernel 47 | # by Monte Carlo approximation of its Fourier transform. 48 | self.featurizer_state = RBFSampler(gamma=0.5, n_components=100) 49 | self.featurizer_state.fit(scaler_samples) 50 | 51 | # one SGD regressor per action 52 | self.action_models = [] 53 | nA = env.action_space.n 54 | for na in range(nA): 55 | # Linear regressors fitted by SGD training. 56 | model = SGDRegressor(learning_rate="constant") 57 | model.partial_fit([self.__featurize_state(env.reset())], [0]) 58 | self.action_models.append(model) 59 | 60 | # print(self.action_models) 61 | 62 | def __featurize_state(self, state): 63 | scaler_state = self.scaler.transform([state]) 64 | return self.featurizer_state.transform(scaler_state)[0] 65 | 66 | def predict(self, state): 67 | curr_features = self.__featurize_state(state) 68 | action_probs = np.array([m.predict([curr_features])[0] for m in self.action_models]) 69 | # print(action_probs) 70 | return action_probs 71 | 72 | def update(self, state, action, y): 73 | curr_features = self.__featurize_state(state) 74 | self.action_models[action].partial_fit([curr_features], [y]) 75 | 76 | class QLearning(): 77 | def __init__(self, env, num_episodes, discount=1.0, alpha=0.5, epsilon=0.1, ep_decay=1.0): 78 | self.nA = env.action_space.n 79 | self.nS = env.observation_space.shape[0] 80 | self.env = env 81 | self.num_episodes = num_episodes 82 | self.discount = discount 83 | self.alpha = alpha 84 | self.epsilon = epsilon 85 | self.vfa = ValueFunction() 86 | self.ep_decay = ep_decay 87 | 88 | # Keeps track of useful statistics 89 | record = namedtuple("Record", ["episode_lengths","episode_rewards"]) 90 | self.rec = record(episode_lengths=np.zeros(num_episodes), 91 | episode_rewards=np.zeros(num_episodes)) 92 | 93 | def __epislon_greedy_policy(self, epsilon, nA): 94 | 95 | def policy(state): 96 | A = np.ones(nA, dtype=float) * epsilon / nA 97 | Q = self.vfa.predict(state) 98 | best_action = np.argmax(Q) 99 | A[best_action] += (1.0 - epsilon) 100 | return A 101 | 102 | return policy 103 | 104 | def __next_action(self, prob): 105 | return np.random.choice(np.arange(len(prob)), p=prob) 106 | 107 | def sarsa(self): 108 | """ 109 | SARSA algorithm (on-policy TD control with value function approximation)
110 | """ 111 | sumlist = [] 112 | 113 | for i_episode in range(self.num_episodes): 114 | # Print out which episode we are on 115 | if 0 == (i_episode) % 10: 116 | print("Episode {} in {}".format(i_episode+1, self.num_episodes)) 117 | # sys.stdout.flush() 118 | 119 | # following current policy (epsilon decays each episode) 120 | policy_epsilon = self.epsilon*self.ep_decay**i_episode 121 | policy = self.__epislon_greedy_policy(policy_epsilon, self.nA) 122 | 123 | step = 0 124 | # Initialize S 125 | state = self.env.reset() 126 | # Choose A from S using policy derived from Q 127 | prob_actions = policy(state) 128 | action = self.__next_action(prob_actions) 129 | 130 | # Repeat (for each step of episode) 131 | while(True): 132 | # Take action A, observe R, S' 133 | next_state, reward, done, info = env.step(action) 134 | 135 | # Choose A' from S' using the same epsilon-greedy policy (on-policy) 136 | prob_next_actions = policy(next_state) 137 | next_action = self.__next_action(prob_next_actions) 138 | 139 | # update history record 140 | self.rec.episode_lengths[i_episode] = step 141 | self.rec.episode_rewards[i_episode] += reward 142 | 143 | # TD update: Q(S; A)<-Q(S; A) + alpha*[R + discount * Q(S'; A') - Q(S; A)] 144 | q_next_value = self.vfa.predict(next_state) 145 | td_target = reward + self.discount * q_next_value[next_action] 146 | self.vfa.update(state, action, td_target) 147 | 148 | if done: 149 | # until S is terminal 150 | print("Episode finished after {} timesteps".format(step)) 151 | sumlist.append(step) 152 | break 153 | else: 154 | step += 1 155 | # S<-S'; A<-A' 156 | state = next_state 157 | action = next_action 158 | 159 | iter_time = sum(sumlist)/len(sumlist) 160 | print("MountainCar game iter average time is: {}".format(iter_time)) 161 | 162 | cls_sarsa = QLearning(env, num_episodes=200) 163 | cls_sarsa.sarsa() 164 | 165 | 166 | # In[77]: 167 | 168 | 169 | from matplotlib import pyplot as plt 170 | import pandas as pd 171 | 172 | def plot_episode_stats(stats, smoothing_window=10, noshow=False): 173 | # Plot the episode length over time 174 | fig1 = plt.figure(figsize=(10,5)) 175 | plt.plot(stats.episode_lengths[:200]) 176 | plt.xlabel("Episode") 177 | plt.ylabel("Episode Length") 178 | plt.title("Episode Length over Time") 179 | if noshow: 180 | plt.close(fig1) 181 | else: 182 | plt.show(fig1) 183 | 184 | # Plot the episode reward over time 185 | fig2 = plt.figure(figsize=(10,5)) 186 | rewards_smoothed = pd.Series(stats.episode_rewards[:200]).rolling(smoothing_window, min_periods=smoothing_window).mean() 187 | plt.plot(rewards_smoothed) 188 | plt.xlabel("Episode") 189 | plt.ylabel("Episode Reward") 190 | plt.title("Episode Reward over Time (Smoothed over window size {})".format(smoothing_window)) 191 | if noshow: 192 | plt.close(fig2) 193 | else: 194 | plt.show(fig2) 195 | 196 | return fig1, fig2 197 | 198 | plot_episode_stats(cls_sarsa.rec) 199 | 200 | --------------------------------------------------------------------------------
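To close the chapter-6 examples: the only substantive difference between FA_SARSA.py and the two Q-learning scripts is the TD target. Q-learning bootstraps from the greedy action in the next state, whereas SARSA bootstraps from the action the epsilon-greedy behaviour policy actually selects there. The side-by-side below assumes some q_values(state) helper that returns the vector of action values, as the predict methods above do; the function names are illustrative only.

```python
import numpy as np

def q_learning_target(reward, next_state, gamma, q_values):
    # Off-policy: back up from the greedy (maximum-valued) action in S'.
    return reward + gamma * np.max(q_values(next_state))

def sarsa_target(reward, next_state, next_action, gamma, q_values):
    # On-policy: back up from the action A' actually chosen in S' by the behaviour policy.
    return reward + gamma * q_values(next_state)[next_action]
```

Either target is then handed to the per-action regressor through update(state, action, target), so switching between the two algorithms changes only this one line of the training loop.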