├── 1.无状态问题 ├── .ipynb_checkpoints │ ├── 1.贪婪算法-checkpoint.ipynb │ ├── 2.递减的贪婪算法-checkpoint.ipynb │ ├── 3.上置信界算法-checkpoint.ipynb │ └── 4.汤普森采样算法-checkpoint.ipynb ├── 1.贪婪算法.ipynb ├── 2.递减的贪婪算法.ipynb ├── 3.上置信界算法.ipynb └── 4.汤普森采样算法.ipynb ├── 10.PPO算法 ├── .ipynb_checkpoints │ ├── 1.PPO算法_平衡车-checkpoint.ipynb │ └── 2.PPO算法_倒立摆-checkpoint.ipynb ├── 1.PPO算法_平衡车.ipynb └── 2.PPO算法_倒立摆.ipynb ├── 11.DDPG算法 ├── .ipynb_checkpoints │ └── 1.DDPG算法-checkpoint.ipynb └── 1.DDPG算法.ipynb ├── 12.SAC算法 ├── .ipynb_checkpoints │ ├── 1.SAC算法_倒立摆-checkpoint.ipynb │ ├── 2.SAC算法_平衡车-checkpoint.ipynb │ ├── x1.倒立摆-checkpoint.ipynb │ └── x2.平衡车-checkpoint.ipynb ├── 1.SAC算法_倒立摆.ipynb ├── 2.SAC算法_平衡车.ipynb ├── x1.倒立摆.ipynb └── x2.平衡车.ipynb ├── 13.模仿学习 ├── .ipynb_checkpoints │ └── 1.模仿学习-checkpoint.ipynb ├── 1.模仿学习_平衡车.ipynb └── 2.模仿学习_倒立摆.ipynb ├── 14.离线学习 ├── .ipynb_checkpoints │ └── 1.离线学习-checkpoint.ipynb └── 1.离线学习.ipynb ├── 15.MPC ├── .ipynb_checkpoints │ └── 1.MPC-checkpoint.ipynb └── 1.MPC.ipynb ├── 16.MBPO ├── .ipynb_checkpoints │ ├── 1.MBPO-Copy1-checkpoint.ipynb │ └── 1.MBPO-checkpoint.ipynb └── 1.MBPO.ipynb ├── 17.目标导向的强化学习 ├── .ipynb_checkpoints │ └── 1.目标导向的强化学习-checkpoint.ipynb └── 1.目标导向的强化学习.ipynb ├── 18.多智能体 ├── .ipynb_checkpoints │ ├── 1.多智能体-Copy1-checkpoint.ipynb │ └── 1.多智能体-checkpoint.ipynb ├── 1.多智能体.ipynb ├── __pycache__ │ └── combat.cpython-36.pyc └── combat.py ├── 2.马尔可夫决策过程 ├── .ipynb_checkpoints │ ├── 1.蒙特卡洛法-checkpoint.ipynb │ └── 2.贝尔曼方程矩阵-checkpoint.ipynb ├── 1.蒙特卡洛法.ipynb └── 2.贝尔曼方程矩阵.ipynb ├── 3.动态规划算法 ├── .ipynb_checkpoints │ ├── 1.策略迭代算法-checkpoint.ipynb │ ├── 2.价值迭代算法-checkpoint.ipynb │ └── 3.冰湖-checkpoint.ipynb ├── 1.策略迭代算法.ipynb ├── 2.价值迭代算法.ipynb └── 3.冰湖.ipynb ├── 4.时序差分算法 ├── .ipynb_checkpoints │ ├── 1.Sarsa算法-checkpoint.ipynb │ ├── 2.N步Sarsa算法-checkpoint.ipynb │ └── 3.QLearning-checkpoint.ipynb ├── 1.Sarsa算法.ipynb ├── 2.N步Sarsa算法.ipynb └── 3.QLearning.ipynb ├── 5.DynaQ算法 ├── .ipynb_checkpoints │ └── 1.DynaQ-checkpoint.ipynb └── 1.DynaQ.ipynb ├── 6.DQN算法 ├── .ipynb_checkpoints │ ├── 1.单模型-checkpoint.ipynb │ ├── 2.双模型_平衡车-checkpoint.ipynb │ ├── 3.双模型_倒立摆-checkpoint.ipynb │ ├── 4.DoubleDQN-checkpoint.ipynb │ └── 5.DuelingDQN-checkpoint.ipynb ├── 1.单模型.ipynb ├── 2.双模型_平衡车.ipynb ├── 3.双模型_倒立摆.ipynb ├── 4.DoubleDQN.ipynb └── 5.DuelingDQN.ipynb ├── 7.策略梯度算法 ├── .ipynb_checkpoints │ ├── 1.Reinforce算法-checkpoint.ipynb │ ├── 2.Actor_Critic算法-checkpoint.ipynb │ ├── 3.TRPO算法_未完成-checkpoint.ipynb │ └── 4.PPO算法-checkpoint.ipynb └── 1.Reinforce算法.ipynb ├── 8.Actor_Critic算法 ├── .ipynb_checkpoints │ └── 1.Actor_Critic算法-checkpoint.ipynb └── 1.Actor_Critic算法.ipynb ├── README.md └── x1.gym ├── .ipynb_checkpoints └── 1.gym-checkpoint.ipynb └── 1.gym.ipynb /1.无状态问题/.ipynb_checkpoints/1.贪婪算法-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": [ 13 | "(array([0.80199222, 0.69477733, 0.83000436, 0.60975194, 0.55430339,\n", 14 | " 0.4454938 , 0.48716133, 0.41699328, 0.26842395, 0.59417058]),\n", 15 | " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "#每个老虎机的中奖概率,0-1之间的均匀分布\n", 27 | "probs = np.random.uniform(size=10)\n", 28 | "\n", 29 | "#记录每个老虎机的返回值\n", 30 | "rewards = [[1] for _ in 
range(10)]\n", 31 | "\n", 32 | "probs, rewards" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "0" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "import random\n", 53 | "\n", 54 | "\n", 55 | "#贪婪算法\n", 56 | "def choose_one():\n", 57 | " #有小概率随机选择一根拉杆\n", 58 | " if random.random() < 0.01:\n", 59 | " return random.randint(0, 9)\n", 60 | "\n", 61 | " #计算每个老虎机的奖励平均\n", 62 | " rewards_mean = [np.mean(i) for i in rewards]\n", 63 | "\n", 64 | " #选择期望奖励估值最大的拉杆\n", 65 | " return np.argmax(rewards_mean)\n", 66 | "\n", 67 | "\n", 68 | "choose_one()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": { 75 | "scrolled": true 76 | }, 77 | "outputs": [ 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "[[1, 1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]" 82 | ] 83 | }, 84 | "execution_count": 3, 85 | "metadata": {}, 86 | "output_type": "execute_result" 87 | } 88 | ], 89 | "source": [ 90 | "def try_and_play():\n", 91 | " i = choose_one()\n", 92 | "\n", 93 | " #玩老虎机,得到结果\n", 94 | " reward = 0\n", 95 | " if random.random() < probs[i]:\n", 96 | " reward = 1\n", 97 | "\n", 98 | " #记录玩的结果\n", 99 | " rewards[i].append(reward)\n", 100 | "\n", 101 | "\n", 102 | "try_and_play()\n", 103 | "\n", 104 | "rewards" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "metadata": { 111 | "colab": { 112 | "base_uri": "https://localhost:8080/", 113 | "height": 312 114 | }, 115 | "executionInfo": { 116 | "elapsed": 676, 117 | "status": "ok", 118 | "timestamp": 1649954384006, 119 | "user": { 120 | "displayName": "Sam Lu", 121 | "userId": "15789059763790170725" 122 | }, 123 | "user_tz": -480 124 | }, 125 | "id": "wIHh_wRA8YDz", 126 | "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2" 127 | }, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "(4150.021823075759, 4077)" 133 | ] 134 | }, 135 | "execution_count": 4, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "def get_result():\n", 142 | " #玩N次\n", 143 | " for _ in range(5000):\n", 144 | " try_and_play()\n", 145 | "\n", 146 | " #期望的最好结果\n", 147 | " target = probs.max() * 5000\n", 148 | "\n", 149 | " #实际玩出的结果\n", 150 | " result = sum([sum(i) for i in rewards])\n", 151 | "\n", 152 | " return target, result\n", 153 | "\n", 154 | "\n", 155 | "get_result()" 156 | ] 157 | } 158 | ], 159 | "metadata": { 160 | "colab": { 161 | "collapsed_sections": [], 162 | "name": "第2章-多臂老虎机问题.ipynb", 163 | "provenance": [] 164 | }, 165 | "kernelspec": { 166 | "display_name": "Python 3", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.6.13" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 1 185 | } 186 | -------------------------------------------------------------------------------- /1.无状态问题/.ipynb_checkpoints/2.递减的贪婪算法-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 
| "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": [ 13 | "(array([0.59403249, 0.07876075, 0.91117829, 0.59479119, 0.48536744,\n", 14 | " 0.85182017, 0.78686838, 0.09419114, 0.25834016, 0.2345657 ]),\n", 15 | " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "#每个老虎机的中奖概率,0-1之间的均匀分布\n", 27 | "probs = np.random.uniform(size=10)\n", 28 | "\n", 29 | "#记录每个老虎机的返回值\n", 30 | "rewards = [[1] for _ in range(10)]\n", 31 | "\n", 32 | "probs, rewards" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "8" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "import random\n", 53 | "\n", 54 | "\n", 55 | "#随机选择的概率递减的贪婪算法\n", 56 | "def choose_one():\n", 57 | " #求出现在已经玩了多少次了\n", 58 | " played_count = sum([len(i) for i in rewards])\n", 59 | "\n", 60 | " #随机选择的概率逐渐下降\n", 61 | " if random.random() < 1 / played_count:\n", 62 | " return random.randint(0, 9)\n", 63 | "\n", 64 | " #计算每个老虎机的奖励平均\n", 65 | " rewards_mean = [np.mean(i) for i in rewards]\n", 66 | "\n", 67 | " #选择期望奖励估值最大的拉杆\n", 68 | " return np.argmax(rewards_mean)\n", 69 | "\n", 70 | "\n", 71 | "choose_one()" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": { 78 | "scrolled": true 79 | }, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "[[1, 1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]" 85 | ] 86 | }, 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "def try_and_play():\n", 94 | " i = choose_one()\n", 95 | "\n", 96 | " #玩老虎机,得到结果\n", 97 | " reward = 0\n", 98 | " if random.random() < probs[i]:\n", 99 | " reward = 1\n", 100 | "\n", 101 | " #记录玩的结果\n", 102 | " rewards[i].append(reward)\n", 103 | "\n", 104 | "\n", 105 | "try_and_play()\n", 106 | "\n", 107 | "rewards" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "metadata": { 114 | "colab": { 115 | "base_uri": "https://localhost:8080/", 116 | "height": 312 117 | }, 118 | "executionInfo": { 119 | "elapsed": 676, 120 | "status": "ok", 121 | "timestamp": 1649954384006, 122 | "user": { 123 | "displayName": "Sam Lu", 124 | "userId": "15789059763790170725" 125 | }, 126 | "user_tz": -480 127 | }, 128 | "id": "wIHh_wRA8YDz", 129 | "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2" 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "(4555.891425873478, 4540)" 136 | ] 137 | }, 138 | "execution_count": 4, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "def get_result():\n", 145 | " #玩N次\n", 146 | " for _ in range(5000):\n", 147 | " try_and_play()\n", 148 | "\n", 149 | " #期望的最好结果\n", 150 | " target = probs.max() * 5000\n", 151 | "\n", 152 | " #实际玩出的结果\n", 153 | " result = sum([sum(i) for i in rewards])\n", 154 | "\n", 155 | " return target, result\n", 156 | "\n", 157 | "\n", 158 | "get_result()" 159 | ] 160 | } 161 | ], 162 | "metadata": { 163 | "colab": { 164 | "collapsed_sections": [], 165 | "name": "第2章-多臂老虎机问题.ipynb", 166 | "provenance": [] 167 | }, 168 | "kernelspec": { 169 | "display_name": "Python 3", 170 | "language": "python", 171 | "name": "python3" 172 | }, 173 | 
"language_info": { 174 | "codemirror_mode": { 175 | "name": "ipython", 176 | "version": 3 177 | }, 178 | "file_extension": ".py", 179 | "mimetype": "text/x-python", 180 | "name": "python", 181 | "nbconvert_exporter": "python", 182 | "pygments_lexer": "ipython3", 183 | "version": "3.6.13" 184 | } 185 | }, 186 | "nbformat": 4, 187 | "nbformat_minor": 1 188 | } 189 | -------------------------------------------------------------------------------- /1.无状态问题/.ipynb_checkpoints/3.上置信界算法-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": [ 13 | "(array([0.4324468 , 0.48173807, 0.98724231, 0.51548606, 0.71232303,\n", 14 | " 0.97799668, 0.60370915, 0.32634855, 0.38733207, 0.08664855]),\n", 15 | " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "#每个老虎机的中奖概率,0-1之间的均匀分布\n", 27 | "probs = np.random.uniform(size=10)\n", 28 | "\n", 29 | "#记录每个老虎机的返回值\n", 30 | "rewards = [[1] for _ in range(10)]\n", 31 | "\n", 32 | "probs, rewards" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "0" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "import random\n", 53 | "\n", 54 | "\n", 55 | "#随机选择的概率递减的贪婪算法\n", 56 | "def choose_one():\n", 57 | " #求出每个老虎机各玩了多少次\n", 58 | " played_count = [len(i) for i in rewards]\n", 59 | " played_count = np.array(played_count)\n", 60 | "\n", 61 | " #求出上置信界\n", 62 | " #分子是总共玩了多少次,取根号后让他的增长速度变慢\n", 63 | " #分母是每台老虎机玩的次数,乘以2让他的增长速度变快\n", 64 | " #随着玩的次数增加,分母会很快超过分子的增长速度,导致分数越来越小\n", 65 | " #具体到每一台老虎机,则是玩的次数越多,分数就越小,也就是ucb的加权越小\n", 66 | " #所以ucb衡量了每一台老虎机的不确定性,不确定性越大,探索的价值越大\n", 67 | " fenzi = played_count.sum()**0.5\n", 68 | " fenmu = played_count * 2\n", 69 | " ucb = fenzi / fenmu\n", 70 | "\n", 71 | " #ucb本身取根号\n", 72 | " #大于1的数会被缩小,小于1的数会被放大,这样保持ucb恒定在一定的数值范围内\n", 73 | " ucb = ucb**0.5\n", 74 | "\n", 75 | " #计算每个老虎机的奖励平均\n", 76 | " rewards_mean = [np.mean(i) for i in rewards]\n", 77 | " rewards_mean = np.array(rewards_mean)\n", 78 | "\n", 79 | " #ucb和期望求和\n", 80 | " ucb += rewards_mean\n", 81 | "\n", 82 | " return ucb.argmax()\n", 83 | "\n", 84 | "\n", 85 | "choose_one()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": { 92 | "scrolled": true 93 | }, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "[[1, 1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]" 99 | ] 100 | }, 101 | "execution_count": 3, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "def try_and_play():\n", 108 | " i = choose_one()\n", 109 | "\n", 110 | " #玩老虎机,得到结果\n", 111 | " reward = 0\n", 112 | " if random.random() < probs[i]:\n", 113 | " reward = 1\n", 114 | "\n", 115 | " #记录玩的结果\n", 116 | " rewards[i].append(reward)\n", 117 | "\n", 118 | "\n", 119 | "try_and_play()\n", 120 | "\n", 121 | "rewards" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "metadata": { 128 | "colab": { 129 | "base_uri": "https://localhost:8080/", 130 | "height": 312 131 | }, 132 | "executionInfo": { 
133 | "elapsed": 676, 134 | "status": "ok", 135 | "timestamp": 1649954384006, 136 | "user": { 137 | "displayName": "Sam Lu", 138 | "userId": "15789059763790170725" 139 | }, 140 | "user_tz": -480 141 | }, 142 | "id": "wIHh_wRA8YDz", 143 | "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2", 144 | "scrolled": false 145 | }, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "text/plain": [ 150 | "(4936.211534652689, 4553)" 151 | ] 152 | }, 153 | "execution_count": 4, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | } 157 | ], 158 | "source": [ 159 | "def get_result():\n", 160 | " #玩N次\n", 161 | " for _ in range(5000):\n", 162 | " try_and_play()\n", 163 | "\n", 164 | " #期望的最好结果\n", 165 | " target = probs.max() * 5000\n", 166 | "\n", 167 | " #实际玩出的结果\n", 168 | " result = sum([sum(i) for i in rewards])\n", 169 | "\n", 170 | " return target, result\n", 171 | "\n", 172 | "\n", 173 | "get_result()" 174 | ] 175 | } 176 | ], 177 | "metadata": { 178 | "colab": { 179 | "collapsed_sections": [], 180 | "name": "第2章-多臂老虎机问题.ipynb", 181 | "provenance": [] 182 | }, 183 | "kernelspec": { 184 | "display_name": "Python 3", 185 | "language": "python", 186 | "name": "python3" 187 | }, 188 | "language_info": { 189 | "codemirror_mode": { 190 | "name": "ipython", 191 | "version": 3 192 | }, 193 | "file_extension": ".py", 194 | "mimetype": "text/x-python", 195 | "name": "python", 196 | "nbconvert_exporter": "python", 197 | "pygments_lexer": "ipython3", 198 | "version": "3.6.13" 199 | } 200 | }, 201 | "nbformat": 4, 202 | "nbformat_minor": 1 203 | } 204 | -------------------------------------------------------------------------------- /1.无状态问题/.ipynb_checkpoints/4.汤普森采样算法-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": [ 13 | "(array([0.13223698, 0.39830518, 0.93960761, 0.3007807 , 0.59217994,\n", 14 | " 0.92562934, 0.92710191, 0.01909585, 0.20277616, 0.29105418]),\n", 15 | " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "#每个老虎机的中奖概率,0-1之间的均匀分布\n", 27 | "probs = np.random.uniform(size=10)\n", 28 | "\n", 29 | "#记录每个老虎机的返回值\n", 30 | "rewards = [[1] for _ in range(10)]\n", 31 | "\n", 32 | "probs, rewards" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "当数字小的时候,beta分布的概率有很大的随机性\n", 45 | "0.9566924357894874\n", 46 | "0.796533273269566\n", 47 | "0.14083572337004413\n", 48 | "0.3350811260642629\n", 49 | "0.5601835883123273\n", 50 | "当数字大时,beta分布逐渐稳定\n", 51 | "0.4980336738406946\n", 52 | "0.5014911804072641\n", 53 | "0.49954932416995235\n", 54 | "0.49752638673683025\n", 55 | "0.5003858155869424\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "#beta分布测试\n", 61 | "print('当数字小的时候,beta分布的概率有很大的随机性')\n", 62 | "for _ in range(5):\n", 63 | " print(np.random.beta(1, 1))\n", 64 | "\n", 65 | "print('当数字大时,beta分布逐渐稳定')\n", 66 | "for _ in range(5):\n", 67 | " print(np.random.beta(1e5, 1e5))" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "9" 79 | ] 80 | }, 
81 | "execution_count": 3, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "import random\n", 88 | "\n", 89 | "\n", 90 | "def choose_one():\n", 91 | " #求出每个老虎机出1的次数+1\n", 92 | " count_1 = [sum(i) + 1 for i in rewards]\n", 93 | "\n", 94 | " #求出每个老虎机出0的次数+1\n", 95 | " count_0 = [sum(1 - np.array(i)) + 1 for i in rewards]\n", 96 | "\n", 97 | " #按照beta分布计算奖励分布,这可以认为是每一台老虎机中奖的概率\n", 98 | " beta = np.random.beta(count_1, count_0)\n", 99 | "\n", 100 | " return beta.argmax()\n", 101 | "\n", 102 | "\n", 103 | "choose_one()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": { 110 | "scrolled": true 111 | }, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "[[1], [1], [1], [1], [1], [1], [1], [1], [1], [1, 0]]" 117 | ] 118 | }, 119 | "execution_count": 4, 120 | "metadata": {}, 121 | "output_type": "execute_result" 122 | } 123 | ], 124 | "source": [ 125 | "def try_and_play():\n", 126 | " i = choose_one()\n", 127 | "\n", 128 | " #玩老虎机,得到结果\n", 129 | " reward = 0\n", 130 | " if random.random() < probs[i]:\n", 131 | " reward = 1\n", 132 | "\n", 133 | " #记录玩的结果\n", 134 | " rewards[i].append(reward)\n", 135 | "\n", 136 | "\n", 137 | "try_and_play()\n", 138 | "\n", 139 | "rewards" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "metadata": { 146 | "colab": { 147 | "base_uri": "https://localhost:8080/", 148 | "height": 312 149 | }, 150 | "executionInfo": { 151 | "elapsed": 676, 152 | "status": "ok", 153 | "timestamp": 1649954384006, 154 | "user": { 155 | "displayName": "Sam Lu", 156 | "userId": "15789059763790170725" 157 | }, 158 | "user_tz": -480 159 | }, 160 | "id": "wIHh_wRA8YDz", 161 | "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2", 162 | "scrolled": false 163 | }, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "(4698.038052814513, 4680)" 169 | ] 170 | }, 171 | "execution_count": 5, 172 | "metadata": {}, 173 | "output_type": "execute_result" 174 | } 175 | ], 176 | "source": [ 177 | "def get_result():\n", 178 | " #玩N次\n", 179 | " for _ in range(5000):\n", 180 | " try_and_play()\n", 181 | "\n", 182 | " #期望的最好结果\n", 183 | " target = probs.max() * 5000\n", 184 | "\n", 185 | " #实际玩出的结果\n", 186 | " result = sum([sum(i) for i in rewards])\n", 187 | "\n", 188 | " return target, result\n", 189 | "\n", 190 | "\n", 191 | "get_result()" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "colab": { 197 | "collapsed_sections": [], 198 | "name": "第2章-多臂老虎机问题.ipynb", 199 | "provenance": [] 200 | }, 201 | "kernelspec": { 202 | "display_name": "Python 3", 203 | "language": "python", 204 | "name": "python3" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.6.13" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 1 221 | } 222 | -------------------------------------------------------------------------------- /1.无状态问题/1.贪婪算法.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": [ 13 | "(array([0.80199222, 0.69477733, 0.83000436, 0.60975194, 0.55430339,\n", 14 | " 0.4454938 , 
0.48716133, 0.41699328, 0.26842395, 0.59417058]),\n", 15 | " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "#每个老虎机的中奖概率,0-1之间的均匀分布\n", 27 | "probs = np.random.uniform(size=10)\n", 28 | "\n", 29 | "#记录每个老虎机的返回值\n", 30 | "rewards = [[1] for _ in range(10)]\n", 31 | "\n", 32 | "probs, rewards" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "0" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "import random\n", 53 | "\n", 54 | "\n", 55 | "#贪婪算法\n", 56 | "def choose_one():\n", 57 | " #有小概率随机选择一根拉杆\n", 58 | " if random.random() < 0.01:\n", 59 | " return random.randint(0, 9)\n", 60 | "\n", 61 | " #计算每个老虎机的奖励平均\n", 62 | " rewards_mean = [np.mean(i) for i in rewards]\n", 63 | "\n", 64 | " #选择期望奖励估值最大的拉杆\n", 65 | " return np.argmax(rewards_mean)\n", 66 | "\n", 67 | "\n", 68 | "choose_one()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": { 75 | "scrolled": true 76 | }, 77 | "outputs": [ 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "[[1, 1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]" 82 | ] 83 | }, 84 | "execution_count": 3, 85 | "metadata": {}, 86 | "output_type": "execute_result" 87 | } 88 | ], 89 | "source": [ 90 | "def try_and_play():\n", 91 | " i = choose_one()\n", 92 | "\n", 93 | " #玩老虎机,得到结果\n", 94 | " reward = 0\n", 95 | " if random.random() < probs[i]:\n", 96 | " reward = 1\n", 97 | "\n", 98 | " #记录玩的结果\n", 99 | " rewards[i].append(reward)\n", 100 | "\n", 101 | "\n", 102 | "try_and_play()\n", 103 | "\n", 104 | "rewards" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "metadata": { 111 | "colab": { 112 | "base_uri": "https://localhost:8080/", 113 | "height": 312 114 | }, 115 | "executionInfo": { 116 | "elapsed": 676, 117 | "status": "ok", 118 | "timestamp": 1649954384006, 119 | "user": { 120 | "displayName": "Sam Lu", 121 | "userId": "15789059763790170725" 122 | }, 123 | "user_tz": -480 124 | }, 125 | "id": "wIHh_wRA8YDz", 126 | "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2" 127 | }, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "(4150.021823075759, 4077)" 133 | ] 134 | }, 135 | "execution_count": 4, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "def get_result():\n", 142 | " #玩N次\n", 143 | " for _ in range(5000):\n", 144 | " try_and_play()\n", 145 | "\n", 146 | " #期望的最好结果\n", 147 | " target = probs.max() * 5000\n", 148 | "\n", 149 | " #实际玩出的结果\n", 150 | " result = sum([sum(i) for i in rewards])\n", 151 | "\n", 152 | " return target, result\n", 153 | "\n", 154 | "\n", 155 | "get_result()" 156 | ] 157 | } 158 | ], 159 | "metadata": { 160 | "colab": { 161 | "collapsed_sections": [], 162 | "name": "第2章-多臂老虎机问题.ipynb", 163 | "provenance": [] 164 | }, 165 | "kernelspec": { 166 | "display_name": "Python 3", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": 
"3.6.13" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 1 185 | } 186 | -------------------------------------------------------------------------------- /1.无状态问题/2.递减的贪婪算法.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": [ 13 | "(array([0.59403249, 0.07876075, 0.91117829, 0.59479119, 0.48536744,\n", 14 | " 0.85182017, 0.78686838, 0.09419114, 0.25834016, 0.2345657 ]),\n", 15 | " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "#每个老虎机的中奖概率,0-1之间的均匀分布\n", 27 | "probs = np.random.uniform(size=10)\n", 28 | "\n", 29 | "#记录每个老虎机的返回值\n", 30 | "rewards = [[1] for _ in range(10)]\n", 31 | "\n", 32 | "probs, rewards" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "8" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "import random\n", 53 | "\n", 54 | "\n", 55 | "#随机选择的概率递减的贪婪算法\n", 56 | "def choose_one():\n", 57 | " #求出现在已经玩了多少次了\n", 58 | " played_count = sum([len(i) for i in rewards])\n", 59 | "\n", 60 | " #随机选择的概率逐渐下降\n", 61 | " if random.random() < 1 / played_count:\n", 62 | " return random.randint(0, 9)\n", 63 | "\n", 64 | " #计算每个老虎机的奖励平均\n", 65 | " rewards_mean = [np.mean(i) for i in rewards]\n", 66 | "\n", 67 | " #选择期望奖励估值最大的拉杆\n", 68 | " return np.argmax(rewards_mean)\n", 69 | "\n", 70 | "\n", 71 | "choose_one()" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": { 78 | "scrolled": true 79 | }, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "[[1, 1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]" 85 | ] 86 | }, 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "def try_and_play():\n", 94 | " i = choose_one()\n", 95 | "\n", 96 | " #玩老虎机,得到结果\n", 97 | " reward = 0\n", 98 | " if random.random() < probs[i]:\n", 99 | " reward = 1\n", 100 | "\n", 101 | " #记录玩的结果\n", 102 | " rewards[i].append(reward)\n", 103 | "\n", 104 | "\n", 105 | "try_and_play()\n", 106 | "\n", 107 | "rewards" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "metadata": { 114 | "colab": { 115 | "base_uri": "https://localhost:8080/", 116 | "height": 312 117 | }, 118 | "executionInfo": { 119 | "elapsed": 676, 120 | "status": "ok", 121 | "timestamp": 1649954384006, 122 | "user": { 123 | "displayName": "Sam Lu", 124 | "userId": "15789059763790170725" 125 | }, 126 | "user_tz": -480 127 | }, 128 | "id": "wIHh_wRA8YDz", 129 | "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2" 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "(4555.891425873478, 4540)" 136 | ] 137 | }, 138 | "execution_count": 4, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "def get_result():\n", 145 | " #玩N次\n", 146 | " for _ in range(5000):\n", 147 | " try_and_play()\n", 148 | "\n", 149 | " #期望的最好结果\n", 150 | " target = probs.max() * 5000\n", 151 | "\n", 152 | " #实际玩出的结果\n", 153 | " result = sum([sum(i) for i 
in rewards])\n", 154 | "\n", 155 | " return target, result\n", 156 | "\n", 157 | "\n", 158 | "get_result()" 159 | ] 160 | } 161 | ], 162 | "metadata": { 163 | "colab": { 164 | "collapsed_sections": [], 165 | "name": "第2章-多臂老虎机问题.ipynb", 166 | "provenance": [] 167 | }, 168 | "kernelspec": { 169 | "display_name": "Python 3", 170 | "language": "python", 171 | "name": "python3" 172 | }, 173 | "language_info": { 174 | "codemirror_mode": { 175 | "name": "ipython", 176 | "version": 3 177 | }, 178 | "file_extension": ".py", 179 | "mimetype": "text/x-python", 180 | "name": "python", 181 | "nbconvert_exporter": "python", 182 | "pygments_lexer": "ipython3", 183 | "version": "3.6.13" 184 | } 185 | }, 186 | "nbformat": 4, 187 | "nbformat_minor": 1 188 | } 189 | -------------------------------------------------------------------------------- /1.无状态问题/3.上置信界算法.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": [ 13 | "(array([0.4324468 , 0.48173807, 0.98724231, 0.51548606, 0.71232303,\n", 14 | " 0.97799668, 0.60370915, 0.32634855, 0.38733207, 0.08664855]),\n", 15 | " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "#每个老虎机的中奖概率,0-1之间的均匀分布\n", 27 | "probs = np.random.uniform(size=10)\n", 28 | "\n", 29 | "#记录每个老虎机的返回值\n", 30 | "rewards = [[1] for _ in range(10)]\n", 31 | "\n", 32 | "probs, rewards" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "0" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "import random\n", 53 | "\n", 54 | "\n", 55 | "#随机选择的概率递减的贪婪算法\n", 56 | "def choose_one():\n", 57 | " #求出每个老虎机各玩了多少次\n", 58 | " played_count = [len(i) for i in rewards]\n", 59 | " played_count = np.array(played_count)\n", 60 | "\n", 61 | " #求出上置信界\n", 62 | " #分子是总共玩了多少次,取根号后让他的增长速度变慢\n", 63 | " #分母是每台老虎机玩的次数,乘以2让他的增长速度变快\n", 64 | " #随着玩的次数增加,分母会很快超过分子的增长速度,导致分数越来越小\n", 65 | " #具体到每一台老虎机,则是玩的次数越多,分数就越小,也就是ucb的加权越小\n", 66 | " #所以ucb衡量了每一台老虎机的不确定性,不确定性越大,探索的价值越大\n", 67 | " fenzi = played_count.sum()**0.5\n", 68 | " fenmu = played_count * 2\n", 69 | " ucb = fenzi / fenmu\n", 70 | "\n", 71 | " #ucb本身取根号\n", 72 | " #大于1的数会被缩小,小于1的数会被放大,这样保持ucb恒定在一定的数值范围内\n", 73 | " ucb = ucb**0.5\n", 74 | "\n", 75 | " #计算每个老虎机的奖励平均\n", 76 | " rewards_mean = [np.mean(i) for i in rewards]\n", 77 | " rewards_mean = np.array(rewards_mean)\n", 78 | "\n", 79 | " #ucb和期望求和\n", 80 | " ucb += rewards_mean\n", 81 | "\n", 82 | " return ucb.argmax()\n", 83 | "\n", 84 | "\n", 85 | "choose_one()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": { 92 | "scrolled": true 93 | }, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "[[1, 1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]" 99 | ] 100 | }, 101 | "execution_count": 3, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "def try_and_play():\n", 108 | " i = choose_one()\n", 109 | "\n", 110 | " #玩老虎机,得到结果\n", 111 | " reward = 0\n", 112 | " if random.random() < probs[i]:\n", 113 | " reward = 1\n", 114 
| "\n", 115 | " #记录玩的结果\n", 116 | " rewards[i].append(reward)\n", 117 | "\n", 118 | "\n", 119 | "try_and_play()\n", 120 | "\n", 121 | "rewards" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "metadata": { 128 | "colab": { 129 | "base_uri": "https://localhost:8080/", 130 | "height": 312 131 | }, 132 | "executionInfo": { 133 | "elapsed": 676, 134 | "status": "ok", 135 | "timestamp": 1649954384006, 136 | "user": { 137 | "displayName": "Sam Lu", 138 | "userId": "15789059763790170725" 139 | }, 140 | "user_tz": -480 141 | }, 142 | "id": "wIHh_wRA8YDz", 143 | "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2", 144 | "scrolled": false 145 | }, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "text/plain": [ 150 | "(4936.211534652689, 4553)" 151 | ] 152 | }, 153 | "execution_count": 4, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | } 157 | ], 158 | "source": [ 159 | "def get_result():\n", 160 | " #玩N次\n", 161 | " for _ in range(5000):\n", 162 | " try_and_play()\n", 163 | "\n", 164 | " #期望的最好结果\n", 165 | " target = probs.max() * 5000\n", 166 | "\n", 167 | " #实际玩出的结果\n", 168 | " result = sum([sum(i) for i in rewards])\n", 169 | "\n", 170 | " return target, result\n", 171 | "\n", 172 | "\n", 173 | "get_result()" 174 | ] 175 | } 176 | ], 177 | "metadata": { 178 | "colab": { 179 | "collapsed_sections": [], 180 | "name": "第2章-多臂老虎机问题.ipynb", 181 | "provenance": [] 182 | }, 183 | "kernelspec": { 184 | "display_name": "Python 3", 185 | "language": "python", 186 | "name": "python3" 187 | }, 188 | "language_info": { 189 | "codemirror_mode": { 190 | "name": "ipython", 191 | "version": 3 192 | }, 193 | "file_extension": ".py", 194 | "mimetype": "text/x-python", 195 | "name": "python", 196 | "nbconvert_exporter": "python", 197 | "pygments_lexer": "ipython3", 198 | "version": "3.6.13" 199 | } 200 | }, 201 | "nbformat": 4, 202 | "nbformat_minor": 1 203 | } 204 | -------------------------------------------------------------------------------- /1.无状态问题/4.汤普森采样算法.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": [ 13 | "(array([0.13223698, 0.39830518, 0.93960761, 0.3007807 , 0.59217994,\n", 14 | " 0.92562934, 0.92710191, 0.01909585, 0.20277616, 0.29105418]),\n", 15 | " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "#每个老虎机的中奖概率,0-1之间的均匀分布\n", 27 | "probs = np.random.uniform(size=10)\n", 28 | "\n", 29 | "#记录每个老虎机的返回值\n", 30 | "rewards = [[1] for _ in range(10)]\n", 31 | "\n", 32 | "probs, rewards" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "当数字小的时候,beta分布的概率有很大的随机性\n", 45 | "0.9566924357894874\n", 46 | "0.796533273269566\n", 47 | "0.14083572337004413\n", 48 | "0.3350811260642629\n", 49 | "0.5601835883123273\n", 50 | "当数字大时,beta分布逐渐稳定\n", 51 | "0.4980336738406946\n", 52 | "0.5014911804072641\n", 53 | "0.49954932416995235\n", 54 | "0.49752638673683025\n", 55 | "0.5003858155869424\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "#beta分布测试\n", 61 | "print('当数字小的时候,beta分布的概率有很大的随机性')\n", 62 | "for _ in range(5):\n", 63 | " 
print(np.random.beta(1, 1))\n", 64 | "\n", 65 | "print('当数字大时,beta分布逐渐稳定')\n", 66 | "for _ in range(5):\n", 67 | " print(np.random.beta(1e5, 1e5))" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "9" 79 | ] 80 | }, 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "import random\n", 88 | "\n", 89 | "\n", 90 | "def choose_one():\n", 91 | " #求出每个老虎机出1的次数+1\n", 92 | " count_1 = [sum(i) + 1 for i in rewards]\n", 93 | "\n", 94 | " #求出每个老虎机出0的次数+1\n", 95 | " count_0 = [sum(1 - np.array(i)) + 1 for i in rewards]\n", 96 | "\n", 97 | " #按照beta分布计算奖励分布,这可以认为是每一台老虎机中奖的概率\n", 98 | " beta = np.random.beta(count_1, count_0)\n", 99 | "\n", 100 | " return beta.argmax()\n", 101 | "\n", 102 | "\n", 103 | "choose_one()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": { 110 | "scrolled": true 111 | }, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "[[1], [1], [1], [1], [1], [1], [1], [1], [1], [1, 0]]" 117 | ] 118 | }, 119 | "execution_count": 4, 120 | "metadata": {}, 121 | "output_type": "execute_result" 122 | } 123 | ], 124 | "source": [ 125 | "def try_and_play():\n", 126 | " i = choose_one()\n", 127 | "\n", 128 | " #玩老虎机,得到结果\n", 129 | " reward = 0\n", 130 | " if random.random() < probs[i]:\n", 131 | " reward = 1\n", 132 | "\n", 133 | " #记录玩的结果\n", 134 | " rewards[i].append(reward)\n", 135 | "\n", 136 | "\n", 137 | "try_and_play()\n", 138 | "\n", 139 | "rewards" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "metadata": { 146 | "colab": { 147 | "base_uri": "https://localhost:8080/", 148 | "height": 312 149 | }, 150 | "executionInfo": { 151 | "elapsed": 676, 152 | "status": "ok", 153 | "timestamp": 1649954384006, 154 | "user": { 155 | "displayName": "Sam Lu", 156 | "userId": "15789059763790170725" 157 | }, 158 | "user_tz": -480 159 | }, 160 | "id": "wIHh_wRA8YDz", 161 | "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2", 162 | "scrolled": false 163 | }, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "(4698.038052814513, 4680)" 169 | ] 170 | }, 171 | "execution_count": 5, 172 | "metadata": {}, 173 | "output_type": "execute_result" 174 | } 175 | ], 176 | "source": [ 177 | "def get_result():\n", 178 | " #玩N次\n", 179 | " for _ in range(5000):\n", 180 | " try_and_play()\n", 181 | "\n", 182 | " #期望的最好结果\n", 183 | " target = probs.max() * 5000\n", 184 | "\n", 185 | " #实际玩出的结果\n", 186 | " result = sum([sum(i) for i in rewards])\n", 187 | "\n", 188 | " return target, result\n", 189 | "\n", 190 | "\n", 191 | "get_result()" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "colab": { 197 | "collapsed_sections": [], 198 | "name": "第2章-多臂老虎机问题.ipynb", 199 | "provenance": [] 200 | }, 201 | "kernelspec": { 202 | "display_name": "Python 3", 203 | "language": "python", 204 | "name": "python3" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.6.13" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 1 221 | } 222 | -------------------------------------------------------------------------------- /15.MPC/.ipynb_checkpoints/1.MPC-checkpoint.ipynb: 
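The four bandit notebooks above (1.无状态问题) share the same setup: ten Bernoulli arms with hidden win probabilities `probs`, a per-arm reward history `rewards` seeded with a single 1, and a `get_result()` loop that plays 5000 rounds and reports the best achievable payout `probs.max() * 5000` next to what was actually earned. They differ only in the selection rule inside `choose_one`. The sketch below puts the four rules side by side in plain Python/NumPy; it is illustrative only and not a file from this repository, the function names are invented for comparison, and the UCB bonus copies the notebooks' simplified form rather than the usual sqrt(2·ln N / n) term.

import random
import numpy as np

probs = np.random.uniform(size=10)      # hidden win probability of each arm
rewards = [[1] for _ in range(10)]      # observed payouts, one list per arm


def choose_greedy(eps=0.01):
    # 1.贪婪算法: exploit the best empirical mean, explore with a small fixed probability
    if random.random() < eps:
        return random.randint(0, 9)
    return int(np.argmax([np.mean(r) for r in rewards]))


def choose_decaying_greedy():
    # 2.递减的贪婪算法: exploration probability decays as 1 / total plays
    played = sum(len(r) for r in rewards)
    if random.random() < 1 / played:
        return random.randint(0, 9)
    return int(np.argmax([np.mean(r) for r in rewards]))


def choose_ucb():
    # 3.上置信界算法: empirical mean plus an uncertainty bonus that shrinks
    # for frequently played arms (the notebooks' simplified bonus, not the textbook one)
    counts = np.array([len(r) for r in rewards])
    bonus = (counts.sum() ** 0.5 / (2 * counts)) ** 0.5
    means = np.array([np.mean(r) for r in rewards])
    return int((means + bonus).argmax())


def choose_thompson():
    # 4.汤普森采样算法: sample a win rate per arm from its Beta posterior, pick the largest
    wins = [sum(r) + 1 for r in rewards]
    losses = [len(r) - sum(r) + 1 for r in rewards]
    return int(np.random.beta(wins, losses).argmax())


def play(choose_one, rounds=5000):
    # pull arms with the given strategy and record Bernoulli outcomes,
    # mirroring try_and_play / get_result in the notebooks
    for _ in range(rounds):
        i = choose_one()
        rewards[i].append(1 if random.random() < probs[i] else 0)
    return probs.max() * rounds, sum(sum(r) for r in rewards)

As in the notebooks, the returned pair is the best achievable payout versus what the strategy actually earned over the 5000 rounds.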
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4a033743", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "array([-0.35493255, 0.93489194, -0.32008162], dtype=float32)" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import gym\n", 22 | "\n", 23 | "\n", 24 | "#定义环境\n", 25 | "class MyWrapper(gym.Wrapper):\n", 26 | " def __init__(self):\n", 27 | " env = gym.make('Pendulum-v1', render_mode='rgb_array')\n", 28 | " super().__init__(env)\n", 29 | " self.env = env\n", 30 | " self.step_n = 0\n", 31 | "\n", 32 | " def reset(self):\n", 33 | " state, _ = self.env.reset()\n", 34 | " self.step_n = 0\n", 35 | " return state\n", 36 | "\n", 37 | " def step(self, action):\n", 38 | " state, reward, terminated, truncated, info = self.env.step(action)\n", 39 | " done = terminated or truncated\n", 40 | " self.step_n += 1\n", 41 | " if self.step_n >= 200:\n", 42 | " done = True\n", 43 | " return state, reward, done, info\n", 44 | "\n", 45 | "\n", 46 | "env = MyWrapper()\n", 47 | "\n", 48 | "env.reset()" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "id": "9c383bd4", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "(200,\n", 61 | " ([-0.9102440476417542, -0.4140722155570984, 0.9576234817504883],\n", 62 | " 0.35440483689308167,\n", 63 | " -7.461259538753717,\n", 64 | " [-0.8951918482780457, -0.44568097591400146, 0.700230062007904],\n", 65 | " False),\n", 66 | " torch.Size([200, 4]),\n", 67 | " torch.Size([200, 4]))" 68 | ] 69 | }, 70 | "execution_count": 2, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "import numpy as np\n", 77 | "import torch\n", 78 | "\n", 79 | "\n", 80 | "class Pool:\n", 81 | " def __init__(self, limit):\n", 82 | " #样本池\n", 83 | " self.datas = []\n", 84 | " self.limit = limit\n", 85 | "\n", 86 | " def add(self, state, action, reward, next_state, over):\n", 87 | " if isinstance(state, np.ndarray) or isinstance(state, torch.Tensor):\n", 88 | " state = state.reshape(3).tolist()\n", 89 | "\n", 90 | " action = float(action)\n", 91 | "\n", 92 | " reward = float(reward)\n", 93 | "\n", 94 | " if isinstance(next_state, np.ndarray) or isinstance(\n", 95 | " next_state, torch.Tensor):\n", 96 | " next_state = next_state.reshape(3).tolist()\n", 97 | "\n", 98 | " over = bool(over)\n", 99 | "\n", 100 | " self.datas.append((state, action, reward, next_state, over))\n", 101 | " #数据上限,超出时从最古老的开始删除\n", 102 | " while len(self.datas) > self.limit:\n", 103 | " self.datas.pop(0)\n", 104 | "\n", 105 | " #获取一批数据样本\n", 106 | " def get_sample(self):\n", 107 | " #从样本池中采样\n", 108 | " samples = self.datas\n", 109 | "\n", 110 | " #[b, 3]\n", 111 | " state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n", 112 | " #[b, 1]\n", 113 | " action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)\n", 114 | " #[b, 1]\n", 115 | " reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n", 116 | " #[b, 3]\n", 117 | " next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n", 118 | " #[b, 1]\n", 119 | " over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n", 120 | "\n", 121 | " #[b, 4]\n", 122 | " input = torch.cat([state, action], dim=1)\n", 123 | " #[b, 4]\n", 124 | " label = torch.cat([reward, next_state - 
state], dim=1)\n", 125 | "\n", 126 | " return input, label\n", 127 | "\n", 128 | " def __len__(self):\n", 129 | " return len(self.datas)\n", 130 | "\n", 131 | "\n", 132 | "pool = Pool(100000)\n", 133 | "\n", 134 | "\n", 135 | "#初始化一局游戏的数据\n", 136 | "def _():\n", 137 | " #初始化游戏\n", 138 | " state = env.reset()\n", 139 | "\n", 140 | " #玩到游戏结束为止\n", 141 | " over = False\n", 142 | " while not over:\n", 143 | " #随机一个动作\n", 144 | " action = env.action_space.sample()[0]\n", 145 | "\n", 146 | " #执行动作,得到反馈\n", 147 | " next_state, reward, over, _ = env.step([action])\n", 148 | "\n", 149 | " #记录数据样本\n", 150 | " pool.add(state, action, reward, next_state, over)\n", 151 | "\n", 152 | " #更新游戏状态,开始下一个动作\n", 153 | " state = next_state\n", 154 | "\n", 155 | "\n", 156 | "_()\n", 157 | "\n", 158 | "a, b = pool.get_sample()\n", 159 | "\n", 160 | "len(pool), pool.datas[0], a.shape, b.shape" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 3, 166 | "id": "e64c315e", 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "text/plain": [ 172 | "(torch.Size([5, 64, 4]), torch.Size([5, 64, 4]))" 173 | ] 174 | }, 175 | "execution_count": 3, 176 | "metadata": {}, 177 | "output_type": "execute_result" 178 | } 179 | ], 180 | "source": [ 181 | "import random\n", 182 | "\n", 183 | "\n", 184 | "#定义主模型\n", 185 | "class Model(torch.nn.Module):\n", 186 | "\n", 187 | " #swish激活函数\n", 188 | " class Swish(torch.nn.Module):\n", 189 | " def __init__(self):\n", 190 | " super().__init__()\n", 191 | "\n", 192 | " def forward(self, x):\n", 193 | " return x * torch.sigmoid(x)\n", 194 | "\n", 195 | " #定义一个工具层\n", 196 | " class FCLayer(torch.nn.Module):\n", 197 | " def __init__(self, in_size, out_size):\n", 198 | " super().__init__()\n", 199 | " self.in_size = in_size\n", 200 | "\n", 201 | " #初始化参数\n", 202 | " std = in_size**0.5\n", 203 | " std *= 2\n", 204 | " std = 1 / std\n", 205 | "\n", 206 | " weight = torch.empty(5, in_size, out_size)\n", 207 | " torch.nn.init.normal_(weight, mean=0.0, std=std)\n", 208 | "\n", 209 | " #[5, in, out]\n", 210 | " self.weight = torch.nn.Parameter(weight)\n", 211 | "\n", 212 | " #[5, 1, out]\n", 213 | " self.bias = torch.nn.Parameter(torch.zeros(5, 1, out_size))\n", 214 | "\n", 215 | " def forward(self, x):\n", 216 | " #x -> [5, b, in]\n", 217 | "\n", 218 | " #[5, b, in] * [5, in, out] -> [5, b, out]\n", 219 | " x = torch.bmm(x, self.weight)\n", 220 | "\n", 221 | " #[5, b, out] + [5, 1, out] -> [5, b, out]\n", 222 | " x = x + self.bias\n", 223 | "\n", 224 | " return x\n", 225 | "\n", 226 | " def __init__(self):\n", 227 | " super().__init__()\n", 228 | "\n", 229 | " self.sequential = torch.nn.Sequential(\n", 230 | " self.FCLayer(4, 200),\n", 231 | " self.Swish(),\n", 232 | " self.FCLayer(200, 200),\n", 233 | " self.Swish(),\n", 234 | " self.FCLayer(200, 200),\n", 235 | " self.Swish(),\n", 236 | " self.FCLayer(200, 200),\n", 237 | " self.Swish(),\n", 238 | " self.FCLayer(200, 8),\n", 239 | " torch.nn.Identity(),\n", 240 | " )\n", 241 | "\n", 242 | " self.softplus = torch.nn.Softplus()\n", 243 | "\n", 244 | " self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n", 245 | "\n", 246 | " def forward(self, x):\n", 247 | " #x -> [5, b, 4]\n", 248 | "\n", 249 | " #[5, b, 4] -> [5, b, 8]\n", 250 | " x = self.sequential(x)\n", 251 | "\n", 252 | " #[5, b, 8] -> [5, b, 4]\n", 253 | " mean = x[..., :4]\n", 254 | "\n", 255 | " #[5, b, 8] -> [5, b, 4]\n", 256 | " logvar = x[..., 4:]\n", 257 | "\n", 258 | " #[1, 1, 4] - [5, b, 4] -> [5, b, 4]\n", 259 | " logvar = 0.5 
- logvar\n", 260 | "\n", 261 | " #[1, 1, 4] - [5, b, 4] -> [5, b, 4]\n", 262 | " logvar = 0.5 - self.softplus(logvar)\n", 263 | "\n", 264 | " #[5, b, 4] - [1, 1, 4] -> [5, b, 4]\n", 265 | " logvar = logvar + 10\n", 266 | "\n", 267 | " #[5, b, 4] + [1, 1, 4] -> [5, b, 4]\n", 268 | " logvar = self.softplus(logvar) - 10\n", 269 | "\n", 270 | " #[5, b, 4],[5, b, 4]\n", 271 | " return mean, logvar\n", 272 | "\n", 273 | " def train(self, input, label):\n", 274 | " #input -> [b, 4]\n", 275 | " #label -> [b, 4]\n", 276 | "\n", 277 | " #反复训练N次\n", 278 | " for _ in range(len(input) // 64 * 20):\n", 279 | " #从全量数据中抽样64个,反复抽5遍,形成5份数据\n", 280 | " #[5, 64]\n", 281 | " select = [torch.randperm(len(input))[:64] for _ in range(5)]\n", 282 | " select = torch.stack(select)\n", 283 | " #[5, b, 4],[5, b, 4]\n", 284 | " input_select = input[select]\n", 285 | " label_select = label[select]\n", 286 | " del select\n", 287 | "\n", 288 | " #模型计算\n", 289 | " #[5, b, 4] -> [5, b, 4],[5, b, 4]\n", 290 | " mean, logvar = model(input_select)\n", 291 | "\n", 292 | " #计算loss\n", 293 | " #[b, 4] - [b, 4] * [b, 4] -> [b, 4]\n", 294 | " mse_loss = (mean - label_select)**2 * (-logvar).exp()\n", 295 | "\n", 296 | " #[b, 4] -> [b] -> scala\n", 297 | " mse_loss = mse_loss.mean(dim=1).mean()\n", 298 | "\n", 299 | " #[b, 4] -> [b] -> scala\n", 300 | " var_loss = logvar.mean(dim=1).mean()\n", 301 | "\n", 302 | " loss = mse_loss + var_loss\n", 303 | "\n", 304 | " self.optimizer.zero_grad()\n", 305 | " loss.backward()\n", 306 | " self.optimizer.step()\n", 307 | "\n", 308 | "\n", 309 | "model = Model()\n", 310 | "#model.train(torch.randn(200, 4), torch.randn(200, 4))\n", 311 | "\n", 312 | "a, b = model(torch.randn(5, 64, 4))\n", 313 | "a.shape, b.shape" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 4, 319 | "id": "1bced33b", 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "name": "stdout", 324 | "output_type": "stream", 325 | "text": [ 326 | "torch.Size([200, 1]) torch.Size([200, 3])\n", 327 | "torch.Size([25])\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "class MPC:\n", 333 | " def _fake_step(self, state, action):\n", 334 | " #state -> [b, 3]\n", 335 | " #action -> [b, 1]\n", 336 | "\n", 337 | " #[b, 4]\n", 338 | " input = torch.cat([state, action], dim=1)\n", 339 | "\n", 340 | " #重复5遍\n", 341 | " #[b, 4] -> [1, b, 4] -> [5, b, 4]\n", 342 | " input = input.unsqueeze(dim=0).repeat([5, 1, 1])\n", 343 | "\n", 344 | " #模型计算\n", 345 | " #[5, b, 4] -> [5, b, 4],[5, b, 4]\n", 346 | " with torch.no_grad():\n", 347 | " mean, std = model(input)\n", 348 | " std = std.exp().sqrt()\n", 349 | " del input\n", 350 | "\n", 351 | " #means的后3列加上环境数据\n", 352 | " mean[:, :, 1:] += state\n", 353 | "\n", 354 | " #重采样\n", 355 | " #[5, b ,4]\n", 356 | " sample = torch.distributions.Normal(0, 1).sample(mean.shape)\n", 357 | " sample = mean + sample * std\n", 358 | "\n", 359 | " #0-4的值域采样b个元素\n", 360 | " #[4, 4, 2, 4, 3, 4, 1, 3, 3, 0, 2,...\n", 361 | " select = [random.choice(range(5)) for _ in range(mean.shape[1])]\n", 362 | "\n", 363 | " #重采样结果,的结果,第0个维度,0-4随机选择,第二个维度,0-b顺序选择\n", 364 | " #[5, b ,4] -> [b, 4]\n", 365 | " sample = sample[select, range(mean.shape[1])]\n", 366 | "\n", 367 | " #切分一下,就成了rewards, next_state\n", 368 | " reward, next_state = sample[:, :1], sample[:, 1:]\n", 369 | "\n", 370 | " return reward, next_state\n", 371 | "\n", 372 | " def _cem_optimize(self, state, mean):\n", 373 | " state = torch.FloatTensor(state).reshape(1, 3)\n", 374 | " var = torch.ones(25)\n", 375 | " #state -> [1, 3]\n", 376 
| " #mean -> [25]\n", 377 | "\n", 378 | " #当前游戏的环境信息,复制50次\n", 379 | " #[1, 3] -> [50, 3]\n", 380 | " state = state.repeat(50, 1)\n", 381 | "\n", 382 | " #循环5次,寻找最优解\n", 383 | " for _ in range(5):\n", 384 | " #采样50个标准正态分布数据作为action\n", 385 | " actions = torch.distributions.Normal(0, 1).sample([50, 25])\n", 386 | "\n", 387 | " #乘以标准差,加上均值\n", 388 | " #[50, 25] * [25] -> [50, 25]\n", 389 | " actions *= var**0.5\n", 390 | " #[50, 25] * [25] -> [50, 25]\n", 391 | " actions += mean\n", 392 | "\n", 393 | " #计算每条动作序列的累积奖励\n", 394 | " #[50, 1]\n", 395 | " reward_sum = torch.zeros(50, 1)\n", 396 | "\n", 397 | " #遍历25个动作\n", 398 | " for i in range(25):\n", 399 | " #[50, 25] -> [50, 1]\n", 400 | " action = actions[:, i].unsqueeze(dim=1)\n", 401 | "\n", 402 | " #现在是不能真的去玩游戏的,只能去预测reward和next_state\n", 403 | " #[50, 3],[50, 1] -> [50, 1],[50, 3]\n", 404 | " reward, state = self._fake_step(state, action)\n", 405 | "\n", 406 | " #[50, 1] + [50, 1] -> [50, 1]\n", 407 | " reward_sum += reward\n", 408 | "\n", 409 | " #按照reward_sum从小到大排列\n", 410 | " #[50]\n", 411 | " select = torch.sort(reward_sum.squeeze(dim=1)).indices\n", 412 | " #[50, 25]\n", 413 | " actions = actions[select]\n", 414 | " del select\n", 415 | "\n", 416 | " #取反馈最优的10个动作链\n", 417 | " #[10, 25]\n", 418 | " actions = actions[-10:]\n", 419 | "\n", 420 | " #在下一次随机动作时,希望贴近这些动作的分布\n", 421 | " #[25]\n", 422 | " new_mean = actions.mean(dim=0)\n", 423 | " new_var = actions.var(dim=0)\n", 424 | "\n", 425 | " #增量更新\n", 426 | " #[25] + [25] -> [25]\n", 427 | " mean = 0.1 * mean + 0.9 * new_mean\n", 428 | " #[25] + [25] -> [25]\n", 429 | " var = 0.1 * var + 0.9 * new_var\n", 430 | "\n", 431 | " return mean\n", 432 | "\n", 433 | " def mpc(self):\n", 434 | " #初始化动作的分布均值都是0\n", 435 | " mean = torch.zeros(25)\n", 436 | "\n", 437 | " reward_sum = 0\n", 438 | " state = env.reset()\n", 439 | " over = False\n", 440 | " while not over:\n", 441 | " #在当前状态下,找25个最优动作的均值\n", 442 | " #[1, 3],[25],[25] -> [25]\n", 443 | " actions = self._cem_optimize(state, mean)\n", 444 | "\n", 445 | " #执行第一个动作\n", 446 | " action = actions[0].item()\n", 447 | "\n", 448 | " #执行动作\n", 449 | " next_state, reward, over, _ = env.step([action])\n", 450 | "\n", 451 | " #增加数据\n", 452 | " pool.add(state, action, reward, next_state, over)\n", 453 | "\n", 454 | " state = next_state\n", 455 | " reward_sum += reward\n", 456 | "\n", 457 | " #下个动作的均值,在当前动作均值的基础上寻找\n", 458 | " #[25]\n", 459 | " mean = torch.empty(actions.shape)\n", 460 | " mean[:-1] = actions[1:]\n", 461 | " mean[-1] = 0\n", 462 | "\n", 463 | " return reward_sum\n", 464 | "\n", 465 | "\n", 466 | "mpc = MPC()\n", 467 | "\n", 468 | "a, b = mpc._fake_step(torch.randn(200, 3), torch.randn(200, 1))\n", 469 | "\n", 470 | "print(a.shape, b.shape)\n", 471 | "\n", 472 | "print(mpc._cem_optimize(torch.randn(1, 3), torch.zeros(25)).shape)\n", 473 | "\n", 474 | "#print(mpc.mpc())" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 5, 480 | "id": "b748ed1d", 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "name": "stdout", 485 | "output_type": "stream", 486 | "text": [ 487 | "0 400 -970.1462027358459\n", 488 | "1 600 -1044.89966441498\n", 489 | "2 800 -984.9260862715894\n", 490 | "3 1000 -1691.3082241692753\n", 491 | "4 1200 -886.7489707469861\n", 492 | "5 1400 -872.9982830534415\n", 493 | "6 1600 -1320.1991576977232\n", 494 | "7 1800 -772.4313091943403\n", 495 | "8 2000 -844.8528234513566\n", 496 | "9 2200 -502.5910743767378\n" 497 | ] 498 | } 499 | ], 500 | "source": [ 501 | "for i in range(10):\n", 502 | " 
input, label = pool.get_sample()\n", 503 | " model.train(input, label)\n", 504 | " reward_sum = mpc.mpc()\n", 505 | " print(i, len(pool), reward_sum)" 506 | ] 507 | } 508 | ], 509 | "metadata": { 510 | "kernelspec": { 511 | "display_name": "Python [conda env:pt39]", 512 | "language": "python", 513 | "name": "conda-env-pt39-py" 514 | }, 515 | "language_info": { 516 | "codemirror_mode": { 517 | "name": "ipython", 518 | "version": 3 519 | }, 520 | "file_extension": ".py", 521 | "mimetype": "text/x-python", 522 | "name": "python", 523 | "nbconvert_exporter": "python", 524 | "pygments_lexer": "ipython3", 525 | "version": "3.9.13" 526 | } 527 | }, 528 | "nbformat": 4, 529 | "nbformat_minor": 5 530 | } 531 | -------------------------------------------------------------------------------- /15.MPC/1.MPC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4a033743", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "array([-0.35493255, 0.93489194, -0.32008162], dtype=float32)" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import gym\n", 22 | "\n", 23 | "\n", 24 | "#定义环境\n", 25 | "class MyWrapper(gym.Wrapper):\n", 26 | " def __init__(self):\n", 27 | " env = gym.make('Pendulum-v1', render_mode='rgb_array')\n", 28 | " super().__init__(env)\n", 29 | " self.env = env\n", 30 | " self.step_n = 0\n", 31 | "\n", 32 | " def reset(self):\n", 33 | " state, _ = self.env.reset()\n", 34 | " self.step_n = 0\n", 35 | " return state\n", 36 | "\n", 37 | " def step(self, action):\n", 38 | " state, reward, terminated, truncated, info = self.env.step(action)\n", 39 | " done = terminated or truncated\n", 40 | " self.step_n += 1\n", 41 | " if self.step_n >= 200:\n", 42 | " done = True\n", 43 | " return state, reward, done, info\n", 44 | "\n", 45 | "\n", 46 | "env = MyWrapper()\n", 47 | "\n", 48 | "env.reset()" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "id": "9c383bd4", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "(200,\n", 61 | " ([-0.9102440476417542, -0.4140722155570984, 0.9576234817504883],\n", 62 | " 0.35440483689308167,\n", 63 | " -7.461259538753717,\n", 64 | " [-0.8951918482780457, -0.44568097591400146, 0.700230062007904],\n", 65 | " False),\n", 66 | " torch.Size([200, 4]),\n", 67 | " torch.Size([200, 4]))" 68 | ] 69 | }, 70 | "execution_count": 2, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "import numpy as np\n", 77 | "import torch\n", 78 | "\n", 79 | "\n", 80 | "class Pool:\n", 81 | " def __init__(self, limit):\n", 82 | " #样本池\n", 83 | " self.datas = []\n", 84 | " self.limit = limit\n", 85 | "\n", 86 | " def add(self, state, action, reward, next_state, over):\n", 87 | " if isinstance(state, np.ndarray) or isinstance(state, torch.Tensor):\n", 88 | " state = state.reshape(3).tolist()\n", 89 | "\n", 90 | " action = float(action)\n", 91 | "\n", 92 | " reward = float(reward)\n", 93 | "\n", 94 | " if isinstance(next_state, np.ndarray) or isinstance(\n", 95 | " next_state, torch.Tensor):\n", 96 | " next_state = next_state.reshape(3).tolist()\n", 97 | "\n", 98 | " over = bool(over)\n", 99 | "\n", 100 | " self.datas.append((state, action, reward, next_state, over))\n", 101 | " #数据上限,超出时从最古老的开始删除\n", 102 | " while len(self.datas) > self.limit:\n", 
103 | " self.datas.pop(0)\n", 104 | "\n", 105 | " #获取一批数据样本\n", 106 | " def get_sample(self):\n", 107 | " #从样本池中采样\n", 108 | " samples = self.datas\n", 109 | "\n", 110 | " #[b, 3]\n", 111 | " state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n", 112 | " #[b, 1]\n", 113 | " action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)\n", 114 | " #[b, 1]\n", 115 | " reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n", 116 | " #[b, 3]\n", 117 | " next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n", 118 | " #[b, 1]\n", 119 | " over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n", 120 | "\n", 121 | " #[b, 4]\n", 122 | " input = torch.cat([state, action], dim=1)\n", 123 | " #[b, 4]\n", 124 | " label = torch.cat([reward, next_state - state], dim=1)\n", 125 | "\n", 126 | " return input, label\n", 127 | "\n", 128 | " def __len__(self):\n", 129 | " return len(self.datas)\n", 130 | "\n", 131 | "\n", 132 | "pool = Pool(100000)\n", 133 | "\n", 134 | "\n", 135 | "#初始化一局游戏的数据\n", 136 | "def _():\n", 137 | " #初始化游戏\n", 138 | " state = env.reset()\n", 139 | "\n", 140 | " #玩到游戏结束为止\n", 141 | " over = False\n", 142 | " while not over:\n", 143 | " #随机一个动作\n", 144 | " action = env.action_space.sample()[0]\n", 145 | "\n", 146 | " #执行动作,得到反馈\n", 147 | " next_state, reward, over, _ = env.step([action])\n", 148 | "\n", 149 | " #记录数据样本\n", 150 | " pool.add(state, action, reward, next_state, over)\n", 151 | "\n", 152 | " #更新游戏状态,开始下一个动作\n", 153 | " state = next_state\n", 154 | "\n", 155 | "\n", 156 | "_()\n", 157 | "\n", 158 | "a, b = pool.get_sample()\n", 159 | "\n", 160 | "len(pool), pool.datas[0], a.shape, b.shape" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 3, 166 | "id": "e64c315e", 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "text/plain": [ 172 | "(torch.Size([5, 64, 4]), torch.Size([5, 64, 4]))" 173 | ] 174 | }, 175 | "execution_count": 3, 176 | "metadata": {}, 177 | "output_type": "execute_result" 178 | } 179 | ], 180 | "source": [ 181 | "import random\n", 182 | "\n", 183 | "\n", 184 | "#定义主模型\n", 185 | "class Model(torch.nn.Module):\n", 186 | "\n", 187 | " #swish激活函数\n", 188 | " class Swish(torch.nn.Module):\n", 189 | " def __init__(self):\n", 190 | " super().__init__()\n", 191 | "\n", 192 | " def forward(self, x):\n", 193 | " return x * torch.sigmoid(x)\n", 194 | "\n", 195 | " #定义一个工具层\n", 196 | " class FCLayer(torch.nn.Module):\n", 197 | " def __init__(self, in_size, out_size):\n", 198 | " super().__init__()\n", 199 | " self.in_size = in_size\n", 200 | "\n", 201 | " #初始化参数\n", 202 | " std = in_size**0.5\n", 203 | " std *= 2\n", 204 | " std = 1 / std\n", 205 | "\n", 206 | " weight = torch.empty(5, in_size, out_size)\n", 207 | " torch.nn.init.normal_(weight, mean=0.0, std=std)\n", 208 | "\n", 209 | " #[5, in, out]\n", 210 | " self.weight = torch.nn.Parameter(weight)\n", 211 | "\n", 212 | " #[5, 1, out]\n", 213 | " self.bias = torch.nn.Parameter(torch.zeros(5, 1, out_size))\n", 214 | "\n", 215 | " def forward(self, x):\n", 216 | " #x -> [5, b, in]\n", 217 | "\n", 218 | " #[5, b, in] * [5, in, out] -> [5, b, out]\n", 219 | " x = torch.bmm(x, self.weight)\n", 220 | "\n", 221 | " #[5, b, out] + [5, 1, out] -> [5, b, out]\n", 222 | " x = x + self.bias\n", 223 | "\n", 224 | " return x\n", 225 | "\n", 226 | " def __init__(self):\n", 227 | " super().__init__()\n", 228 | "\n", 229 | " self.sequential = torch.nn.Sequential(\n", 230 | " self.FCLayer(4, 200),\n", 231 | " 
self.Swish(),\n", 232 | " self.FCLayer(200, 200),\n", 233 | " self.Swish(),\n", 234 | " self.FCLayer(200, 200),\n", 235 | " self.Swish(),\n", 236 | " self.FCLayer(200, 200),\n", 237 | " self.Swish(),\n", 238 | " self.FCLayer(200, 8),\n", 239 | " torch.nn.Identity(),\n", 240 | " )\n", 241 | "\n", 242 | " self.softplus = torch.nn.Softplus()\n", 243 | "\n", 244 | " self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n", 245 | "\n", 246 | " def forward(self, x):\n", 247 | " #x -> [5, b, 4]\n", 248 | "\n", 249 | " #[5, b, 4] -> [5, b, 8]\n", 250 | " x = self.sequential(x)\n", 251 | "\n", 252 | " #[5, b, 8] -> [5, b, 4]\n", 253 | " mean = x[..., :4]\n", 254 | "\n", 255 | " #[5, b, 8] -> [5, b, 4]\n", 256 | " logvar = x[..., 4:]\n", 257 | "\n", 258 | " #[1, 1, 4] - [5, b, 4] -> [5, b, 4]\n", 259 | " logvar = 0.5 - logvar\n", 260 | "\n", 261 | " #[1, 1, 4] - [5, b, 4] -> [5, b, 4]\n", 262 | " logvar = 0.5 - self.softplus(logvar)\n", 263 | "\n", 264 | " #[5, b, 4] - [1, 1, 4] -> [5, b, 4]\n", 265 | " logvar = logvar + 10\n", 266 | "\n", 267 | " #[5, b, 4] + [1, 1, 4] -> [5, b, 4]\n", 268 | " logvar = self.softplus(logvar) - 10\n", 269 | "\n", 270 | " #[5, b, 4],[5, b, 4]\n", 271 | " return mean, logvar\n", 272 | "\n", 273 | " def train(self, input, label):\n", 274 | " #input -> [b, 4]\n", 275 | " #label -> [b, 4]\n", 276 | "\n", 277 | " #反复训练N次\n", 278 | " for _ in range(len(input) // 64 * 20):\n", 279 | " #从全量数据中抽样64个,反复抽5遍,形成5份数据\n", 280 | " #[5, 64]\n", 281 | " select = [torch.randperm(len(input))[:64] for _ in range(5)]\n", 282 | " select = torch.stack(select)\n", 283 | " #[5, b, 4],[5, b, 4]\n", 284 | " input_select = input[select]\n", 285 | " label_select = label[select]\n", 286 | " del select\n", 287 | "\n", 288 | " #模型计算\n", 289 | " #[5, b, 4] -> [5, b, 4],[5, b, 4]\n", 290 | " mean, logvar = model(input_select)\n", 291 | "\n", 292 | " #计算loss\n", 293 | " #[b, 4] - [b, 4] * [b, 4] -> [b, 4]\n", 294 | " mse_loss = (mean - label_select)**2 * (-logvar).exp()\n", 295 | "\n", 296 | " #[b, 4] -> [b] -> scala\n", 297 | " mse_loss = mse_loss.mean(dim=1).mean()\n", 298 | "\n", 299 | " #[b, 4] -> [b] -> scala\n", 300 | " var_loss = logvar.mean(dim=1).mean()\n", 301 | "\n", 302 | " loss = mse_loss + var_loss\n", 303 | "\n", 304 | " self.optimizer.zero_grad()\n", 305 | " loss.backward()\n", 306 | " self.optimizer.step()\n", 307 | "\n", 308 | "\n", 309 | "model = Model()\n", 310 | "#model.train(torch.randn(200, 4), torch.randn(200, 4))\n", 311 | "\n", 312 | "a, b = model(torch.randn(5, 64, 4))\n", 313 | "a.shape, b.shape" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 4, 319 | "id": "1bced33b", 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "name": "stdout", 324 | "output_type": "stream", 325 | "text": [ 326 | "torch.Size([200, 1]) torch.Size([200, 3])\n", 327 | "torch.Size([25])\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "class MPC:\n", 333 | " def _fake_step(self, state, action):\n", 334 | " #state -> [b, 3]\n", 335 | " #action -> [b, 1]\n", 336 | "\n", 337 | " #[b, 4]\n", 338 | " input = torch.cat([state, action], dim=1)\n", 339 | "\n", 340 | " #重复5遍\n", 341 | " #[b, 4] -> [1, b, 4] -> [5, b, 4]\n", 342 | " input = input.unsqueeze(dim=0).repeat([5, 1, 1])\n", 343 | "\n", 344 | " #模型计算\n", 345 | " #[5, b, 4] -> [5, b, 4],[5, b, 4]\n", 346 | " with torch.no_grad():\n", 347 | " mean, std = model(input)\n", 348 | " std = std.exp().sqrt()\n", 349 | " del input\n", 350 | "\n", 351 | " #means的后3列加上环境数据\n", 352 | " mean[:, :, 1:] += state\n", 353 | 
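# Supplementary sketch (not in the original notebook): the _fake_step method above
# asks the 5-member ensemble for a mean and log-variance over [reward, next_state - state],
# adds the current state back onto the state part of the mean, and then draws a
# Gaussian sample mean + eps * std. The sizes below (5 ensemble members, batch of 4,
# 3-dimensional state) are illustrative assumptions.
import torch

ensemble, batch, state_dim = 5, 4, 3

mean = torch.randn(ensemble, batch, 1 + state_dim)    # predicted [reward, delta_state]
logvar = torch.randn(ensemble, batch, 1 + state_dim)  # predicted log-variance
state = torch.randn(batch, state_dim)                 # current environment states

# turn the delta prediction back into an absolute next-state prediction
mean[:, :, 1:] += state

# reparameterized sample: mean + eps * std with std = exp(logvar) ** 0.5
std = logvar.exp().sqrt()
eps = torch.randn_like(mean)
sample = mean + eps * std

# split the sample into reward and next_state, per ensemble member
reward, next_state = sample[..., :1], sample[..., 1:]
print(reward.shape, next_state.shape)  # torch.Size([5, 4, 1]) torch.Size([5, 4, 3])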
"\n", 354 | " #重采样\n", 355 | " #[5, b ,4]\n", 356 | " sample = torch.distributions.Normal(0, 1).sample(mean.shape)\n", 357 | " sample = mean + sample * std\n", 358 | "\n", 359 | " #0-4的值域采样b个元素\n", 360 | " #[4, 4, 2, 4, 3, 4, 1, 3, 3, 0, 2,...\n", 361 | " select = [random.choice(range(5)) for _ in range(mean.shape[1])]\n", 362 | "\n", 363 | " #重采样结果,的结果,第0个维度,0-4随机选择,第二个维度,0-b顺序选择\n", 364 | " #[5, b ,4] -> [b, 4]\n", 365 | " sample = sample[select, range(mean.shape[1])]\n", 366 | "\n", 367 | " #切分一下,就成了rewards, next_state\n", 368 | " reward, next_state = sample[:, :1], sample[:, 1:]\n", 369 | "\n", 370 | " return reward, next_state\n", 371 | "\n", 372 | " def _cem_optimize(self, state, mean):\n", 373 | " state = torch.FloatTensor(state).reshape(1, 3)\n", 374 | " var = torch.ones(25)\n", 375 | " #state -> [1, 3]\n", 376 | " #mean -> [25]\n", 377 | "\n", 378 | " #当前游戏的环境信息,复制50次\n", 379 | " #[1, 3] -> [50, 3]\n", 380 | " state = state.repeat(50, 1)\n", 381 | "\n", 382 | " #循环5次,寻找最优解\n", 383 | " for _ in range(5):\n", 384 | " #采样50个标准正态分布数据作为action\n", 385 | " actions = torch.distributions.Normal(0, 1).sample([50, 25])\n", 386 | "\n", 387 | " #乘以标准差,加上均值\n", 388 | " #[50, 25] * [25] -> [50, 25]\n", 389 | " actions *= var**0.5\n", 390 | " #[50, 25] * [25] -> [50, 25]\n", 391 | " actions += mean\n", 392 | "\n", 393 | " #计算每条动作序列的累积奖励\n", 394 | " #[50, 1]\n", 395 | " reward_sum = torch.zeros(50, 1)\n", 396 | "\n", 397 | " #遍历25个动作\n", 398 | " for i in range(25):\n", 399 | " #[50, 25] -> [50, 1]\n", 400 | " action = actions[:, i].unsqueeze(dim=1)\n", 401 | "\n", 402 | " #现在是不能真的去玩游戏的,只能去预测reward和next_state\n", 403 | " #[50, 3],[50, 1] -> [50, 1],[50, 3]\n", 404 | " reward, state = self._fake_step(state, action)\n", 405 | "\n", 406 | " #[50, 1] + [50, 1] -> [50, 1]\n", 407 | " reward_sum += reward\n", 408 | "\n", 409 | " #按照reward_sum从小到大排列\n", 410 | " #[50]\n", 411 | " select = torch.sort(reward_sum.squeeze(dim=1)).indices\n", 412 | " #[50, 25]\n", 413 | " actions = actions[select]\n", 414 | " del select\n", 415 | "\n", 416 | " #取反馈最优的10个动作链\n", 417 | " #[10, 25]\n", 418 | " actions = actions[-10:]\n", 419 | "\n", 420 | " #在下一次随机动作时,希望贴近这些动作的分布\n", 421 | " #[25]\n", 422 | " new_mean = actions.mean(dim=0)\n", 423 | " new_var = actions.var(dim=0)\n", 424 | "\n", 425 | " #增量更新\n", 426 | " #[25] + [25] -> [25]\n", 427 | " mean = 0.1 * mean + 0.9 * new_mean\n", 428 | " #[25] + [25] -> [25]\n", 429 | " var = 0.1 * var + 0.9 * new_var\n", 430 | "\n", 431 | " return mean\n", 432 | "\n", 433 | " def mpc(self):\n", 434 | " #初始化动作的分布均值都是0\n", 435 | " mean = torch.zeros(25)\n", 436 | "\n", 437 | " reward_sum = 0\n", 438 | " state = env.reset()\n", 439 | " over = False\n", 440 | " while not over:\n", 441 | " #在当前状态下,找25个最优动作的均值\n", 442 | " #[1, 3],[25],[25] -> [25]\n", 443 | " actions = self._cem_optimize(state, mean)\n", 444 | "\n", 445 | " #执行第一个动作\n", 446 | " action = actions[0].item()\n", 447 | "\n", 448 | " #执行动作\n", 449 | " next_state, reward, over, _ = env.step([action])\n", 450 | "\n", 451 | " #增加数据\n", 452 | " pool.add(state, action, reward, next_state, over)\n", 453 | "\n", 454 | " state = next_state\n", 455 | " reward_sum += reward\n", 456 | "\n", 457 | " #下个动作的均值,在当前动作均值的基础上寻找\n", 458 | " #[25]\n", 459 | " mean = torch.empty(actions.shape)\n", 460 | " mean[:-1] = actions[1:]\n", 461 | " mean[-1] = 0\n", 462 | "\n", 463 | " return reward_sum\n", 464 | "\n", 465 | "\n", 466 | "mpc = MPC()\n", 467 | "\n", 468 | "a, b = mpc._fake_step(torch.randn(200, 3), torch.randn(200, 1))\n", 469 | "\n", 470 | 
"print(a.shape, b.shape)\n", 471 | "\n", 472 | "print(mpc._cem_optimize(torch.randn(1, 3), torch.zeros(25)).shape)\n", 473 | "\n", 474 | "#print(mpc.mpc())" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 5, 480 | "id": "b748ed1d", 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "name": "stdout", 485 | "output_type": "stream", 486 | "text": [ 487 | "0 400 -970.1462027358459\n", 488 | "1 600 -1044.89966441498\n", 489 | "2 800 -984.9260862715894\n", 490 | "3 1000 -1691.3082241692753\n", 491 | "4 1200 -886.7489707469861\n", 492 | "5 1400 -872.9982830534415\n", 493 | "6 1600 -1320.1991576977232\n", 494 | "7 1800 -772.4313091943403\n", 495 | "8 2000 -844.8528234513566\n", 496 | "9 2200 -502.5910743767378\n" 497 | ] 498 | } 499 | ], 500 | "source": [ 501 | "for i in range(10):\n", 502 | " input, label = pool.get_sample()\n", 503 | " model.train(input, label)\n", 504 | " reward_sum = mpc.mpc()\n", 505 | " print(i, len(pool), reward_sum)" 506 | ] 507 | } 508 | ], 509 | "metadata": { 510 | "kernelspec": { 511 | "display_name": "Python [conda env:pt39]", 512 | "language": "python", 513 | "name": "conda-env-pt39-py" 514 | }, 515 | "language_info": { 516 | "codemirror_mode": { 517 | "name": "ipython", 518 | "version": 3 519 | }, 520 | "file_extension": ".py", 521 | "mimetype": "text/x-python", 522 | "name": "python", 523 | "nbconvert_exporter": "python", 524 | "pygments_lexer": "ipython3", 525 | "version": "3.9.13" 526 | } 527 | }, 528 | "nbformat": 4, 529 | "nbformat_minor": 5 530 | } 531 | -------------------------------------------------------------------------------- /18.多智能体/.ipynb_checkpoints/1.多智能体-Copy1-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": { 7 | "scrolled": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "state= 150 150\n", 15 | "action= [5, 2]\n", 16 | "reward= [0, 0]\n", 17 | "next_state= 150 150\n", 18 | "over= [False, False]\n" 19 | ] 20 | }, 21 | { 22 | "data": { 23 | "text/plain": [ 24 | "" 25 | ] 26 | }, 27 | "execution_count": 4, 28 | "metadata": {}, 29 | "output_type": "execute_result" 30 | } 31 | ], 32 | "source": [ 33 | "from combat import Combat\n", 34 | "\n", 35 | "\n", 36 | "def test_env():\n", 37 | " state = env.reset()\n", 38 | " action = env.action_space.sample()\n", 39 | " next_state, reward, over, _ = env.step(action)\n", 40 | "\n", 41 | " print('state=', len(state[0]), len(state[1]))\n", 42 | " print('action=', action)\n", 43 | " print('reward=', reward)\n", 44 | " print('next_state=', len(next_state[0]), len(next_state[1]))\n", 45 | " print('over=', over)\n", 46 | "\n", 47 | "\n", 48 | "env = Combat(grid_shape=(15, 15), n_agents=2, n_opponents=2)\n", 49 | "\n", 50 | "test_env()\n", 51 | "\n", 52 | "env" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "5" 64 | ] 65 | }, 66 | "execution_count": 2, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "import random\n", 73 | "import torch\n", 74 | "\n", 75 | "\n", 76 | "class PPO:\n", 77 | " def __init__(self):\n", 78 | " self.model_action = torch.nn.Sequential(\n", 79 | " torch.nn.Linear(150, 64),\n", 80 | " torch.nn.ReLU(),\n", 81 | " torch.nn.Linear(64, 64),\n", 82 | " torch.nn.ReLU(),\n", 83 | " torch.nn.Linear(64, 
7),\n", 84 | " torch.nn.Softmax(dim=1),\n", 85 | " )\n", 86 | " self.model_value = torch.nn.Sequential(\n", 87 | " torch.nn.Linear(150, 64),\n", 88 | " torch.nn.ReLU(),\n", 89 | " torch.nn.Linear(64, 64),\n", 90 | " torch.nn.ReLU(),\n", 91 | " torch.nn.Linear(64, 1),\n", 92 | " )\n", 93 | " self.optimizer_action = torch.optim.Adam(\n", 94 | " self.model_action.parameters(), lr=3e-4)\n", 95 | " self.optimizer_value = torch.optim.Adam(self.model_value.parameters(),\n", 96 | " lr=3e-3)\n", 97 | "\n", 98 | " self.mse_loss = torch.nn.MSELoss()\n", 99 | "\n", 100 | " def get_action(self, state):\n", 101 | " state = torch.FloatTensor(state).reshape(1, 150)\n", 102 | "\n", 103 | " #[1, 150] -> [7]\n", 104 | " weights = self.model_action(state).squeeze(dim=0).tolist()\n", 105 | "\n", 106 | " #[7] -> scala\n", 107 | " action = random.choices(range(7), weights=weights, k=1)[0]\n", 108 | "\n", 109 | " return action\n", 110 | "\n", 111 | " def _get_advantages(self, deltas):\n", 112 | " advantages = []\n", 113 | "\n", 114 | " #反向遍历deltas\n", 115 | " s = 0.0\n", 116 | " for delta in deltas[::-1]:\n", 117 | " s = 0.99 * 0.97 * s + delta\n", 118 | " advantages.append(s)\n", 119 | "\n", 120 | " #逆序\n", 121 | " advantages.reverse()\n", 122 | " return advantages\n", 123 | "\n", 124 | " def _get_target(self, next_state, reward, over):\n", 125 | " #[b, 150] -> [b, 1]\n", 126 | " target = self.model_value(next_state)\n", 127 | " target *= 0.99\n", 128 | " target *= (1 - over)\n", 129 | " target += reward\n", 130 | " return target\n", 131 | "\n", 132 | " def _get_value(self, state):\n", 133 | " #[b, 150] -> [b, 1]\n", 134 | " return self.model_value(state)\n", 135 | "\n", 136 | " def train(self, state, action, reward, next_state, over):\n", 137 | " #state -> [b, 150]\n", 138 | " #action -> [b, 1]\n", 139 | " #reward -> [b, 1]\n", 140 | " #next_state -> [b, 150]\n", 141 | " #over -> [b, 1]\n", 142 | "\n", 143 | " #[b, 1]\n", 144 | " target = self._get_target(next_state, reward, over).detach()\n", 145 | " #[b, 150] -> [b, 1]\n", 146 | " value = self._get_value(state)\n", 147 | "\n", 148 | " #[b, 1] - [b, 1] -> [b, 1] -> [b]\n", 149 | " delta = (target - value).squeeze(dim=1).tolist()\n", 150 | " #[b] -> [b]\n", 151 | " advantages = self._get_advantages(delta)\n", 152 | " #[b] -> [b, 1]\n", 153 | " advantages = torch.FloatTensor(advantages).reshape(-1, 1)\n", 154 | "\n", 155 | " #[b, 150] -> [b, 7]\n", 156 | " old_prob = self.model_action(state)\n", 157 | " #[b, 7] -> [b, 1]\n", 158 | " old_prob = old_prob.gather(1, action)\n", 159 | " #[b, 1] -> [b, 1]\n", 160 | " old_prob = old_prob.log().detach()\n", 161 | "\n", 162 | " for _ in range(1):\n", 163 | " #[b, 150] -> [b, 7]\n", 164 | " new_prob = self.model_action(state)\n", 165 | " #[b, 7] -> [b, 1]\n", 166 | " new_prob = new_prob.gather(1, action)\n", 167 | " #[b, 1] -> [b, 1]\n", 168 | " new_prob = new_prob.log()\n", 169 | "\n", 170 | " #[b, 1] - [b, 1] -> [b, 1]\n", 171 | " ratio = (new_prob - old_prob).exp()\n", 172 | "\n", 173 | " #[b, 1] * [b, 1] -> [b, 1]\n", 174 | " surr1 = ratio * advantages\n", 175 | " #[b, 1] * [b, 1] -> [b, 1]\n", 176 | " surr2 = torch.clamp(ratio, 0.8, 1.2) * advantages\n", 177 | "\n", 178 | " #[b, 1]\n", 179 | " loss_action = torch.min(surr1, surr2)\n", 180 | " loss_action = -loss_action\n", 181 | " #[b, 1] -> scala\n", 182 | " loss_action = loss_action.mean()\n", 183 | "\n", 184 | " self.optimizer_action.zero_grad()\n", 185 | " loss_action.backward()\n", 186 | " self.optimizer_action.step()\n", 187 | "\n", 188 | " #[b, 4] -> [b, 
1]\n", 189 | " value = self._get_value(state)\n", 190 | "\n", 191 | " self.optimizer_value.zero_grad()\n", 192 | " #[b, 1],[b, 1] -> scala\n", 193 | " loss_action = self.mse_loss(value, target)\n", 194 | " self.optimizer_value.step()\n", 195 | "\n", 196 | "\n", 197 | "ppo = PPO()\n", 198 | "\n", 199 | "ppo.train(\n", 200 | " torch.randn(5, 150),\n", 201 | " torch.ones(5, 1).long(),\n", 202 | " torch.randn(5, 1),\n", 203 | " torch.randn(5, 150),\n", 204 | " torch.zeros(5, 1).long(),\n", 205 | ")\n", 206 | "\n", 207 | "ppo.get_action(list(range(150)))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 3, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "data": { 217 | "text/plain": [ 218 | "({'state': tensor([[0., 0., 0., ..., 0., 0., 0.],\n", 219 | " [0., 0., 0., ..., 0., 0., 0.],\n", 220 | " [0., 0., 0., ..., 0., 0., 0.],\n", 221 | " ...,\n", 222 | " [0., 0., 0., ..., 0., 0., 0.],\n", 223 | " [0., 0., 0., ..., 0., 0., 0.],\n", 224 | " [0., 0., 0., ..., 0., 0., 0.]]),\n", 225 | " 'action': tensor([[1],\n", 226 | " [2],\n", 227 | " [1],\n", 228 | " [3],\n", 229 | " [1],\n", 230 | " [1],\n", 231 | " [0],\n", 232 | " [5],\n", 233 | " [1],\n", 234 | " [4],\n", 235 | " [4],\n", 236 | " [4],\n", 237 | " [3],\n", 238 | " [1],\n", 239 | " [5],\n", 240 | " [3],\n", 241 | " [0],\n", 242 | " [6],\n", 243 | " [3],\n", 244 | " [5],\n", 245 | " [3],\n", 246 | " [2]]),\n", 247 | " 'reward': tensor([[-1.1000],\n", 248 | " [-0.1000],\n", 249 | " [-1.1000],\n", 250 | " [-0.1000],\n", 251 | " [-1.1000],\n", 252 | " [-0.1000],\n", 253 | " [-0.1000],\n", 254 | " [-0.1000],\n", 255 | " [-0.1000],\n", 256 | " [-0.1000],\n", 257 | " [-0.1000],\n", 258 | " [-0.1000],\n", 259 | " [-0.1000],\n", 260 | " [-0.1000],\n", 261 | " [-0.1000],\n", 262 | " [-0.1000],\n", 263 | " [-0.1000],\n", 264 | " [-0.1000],\n", 265 | " [-0.1000],\n", 266 | " [-0.1000],\n", 267 | " [-0.1000],\n", 268 | " [-0.1000]]),\n", 269 | " 'next_state': tensor([[0., 0., 0., ..., 0., 0., 0.],\n", 270 | " [0., 0., 0., ..., 0., 0., 0.],\n", 271 | " [0., 0., 0., ..., 0., 0., 0.],\n", 272 | " ...,\n", 273 | " [0., 0., 0., ..., 0., 0., 0.],\n", 274 | " [0., 0., 0., ..., 0., 0., 0.],\n", 275 | " [0., 0., 0., ..., 0., 0., 0.]]),\n", 276 | " 'over': tensor([[0],\n", 277 | " [0],\n", 278 | " [0],\n", 279 | " [0],\n", 280 | " [0],\n", 281 | " [0],\n", 282 | " [0],\n", 283 | " [0],\n", 284 | " [0],\n", 285 | " [0],\n", 286 | " [0],\n", 287 | " [0],\n", 288 | " [0],\n", 289 | " [0],\n", 290 | " [0],\n", 291 | " [0],\n", 292 | " [0],\n", 293 | " [0],\n", 294 | " [0],\n", 295 | " [0],\n", 296 | " [0],\n", 297 | " [0]])},\n", 298 | " {'state': tensor([[0., 0., 0., ..., 0., 0., 0.],\n", 299 | " [0., 0., 0., ..., 0., 0., 0.],\n", 300 | " [0., 0., 0., ..., 0., 0., 0.],\n", 301 | " ...,\n", 302 | " [0., 0., 0., ..., 0., 0., 0.],\n", 303 | " [0., 0., 0., ..., 0., 0., 0.],\n", 304 | " [0., 0., 0., ..., 0., 0., 0.]]),\n", 305 | " 'action': tensor([[0],\n", 306 | " [1],\n", 307 | " [1],\n", 308 | " [5],\n", 309 | " [5],\n", 310 | " [4],\n", 311 | " [3],\n", 312 | " [3],\n", 313 | " [5],\n", 314 | " [6],\n", 315 | " [2],\n", 316 | " [6],\n", 317 | " [3],\n", 318 | " [1],\n", 319 | " [1],\n", 320 | " [2],\n", 321 | " [6],\n", 322 | " [3],\n", 323 | " [1],\n", 324 | " [5],\n", 325 | " [5],\n", 326 | " [5]]),\n", 327 | " 'reward': tensor([[-0.1000],\n", 328 | " [-0.1000],\n", 329 | " [-0.1000],\n", 330 | " [-0.1000],\n", 331 | " [-0.1000],\n", 332 | " [-0.1000],\n", 333 | " [-0.1000],\n", 334 | " [-0.1000],\n", 335 | " [-0.1000],\n", 336 | " 
[-0.1000],\n", 337 | " [-0.1000],\n", 338 | " [-0.1000],\n", 339 | " [-0.1000],\n", 340 | " [-0.1000],\n", 341 | " [-0.1000],\n", 342 | " [-0.1000],\n", 343 | " [-0.1000],\n", 344 | " [-0.1000],\n", 345 | " [-1.1000],\n", 346 | " [-0.1000],\n", 347 | " [-0.1000],\n", 348 | " [-0.1000]]),\n", 349 | " 'next_state': tensor([[0., 0., 0., ..., 0., 0., 0.],\n", 350 | " [0., 0., 0., ..., 0., 0., 0.],\n", 351 | " [0., 0., 0., ..., 0., 0., 0.],\n", 352 | " ...,\n", 353 | " [0., 0., 0., ..., 0., 0., 0.],\n", 354 | " [0., 0., 0., ..., 0., 0., 0.],\n", 355 | " [0., 0., 0., ..., 0., 0., 0.]]),\n", 356 | " 'over': tensor([[0],\n", 357 | " [0],\n", 358 | " [0],\n", 359 | " [0],\n", 360 | " [0],\n", 361 | " [0],\n", 362 | " [0],\n", 363 | " [0],\n", 364 | " [0],\n", 365 | " [0],\n", 366 | " [0],\n", 367 | " [0],\n", 368 | " [0],\n", 369 | " [0],\n", 370 | " [0],\n", 371 | " [0],\n", 372 | " [0],\n", 373 | " [0],\n", 374 | " [0],\n", 375 | " [0],\n", 376 | " [0],\n", 377 | " [0]])},\n", 378 | " False)" 379 | ] 380 | }, 381 | "execution_count": 3, 382 | "metadata": {}, 383 | "output_type": "execute_result" 384 | } 385 | ], 386 | "source": [ 387 | "def get_data():\n", 388 | " data0 = {\n", 389 | " 'state': [],\n", 390 | " 'action': [],\n", 391 | " 'reward': [],\n", 392 | " 'next_state': [],\n", 393 | " 'over': [],\n", 394 | " }\n", 395 | "\n", 396 | " data1 = {\n", 397 | " 'state': [],\n", 398 | " 'action': [],\n", 399 | " 'reward': [],\n", 400 | " 'next_state': [],\n", 401 | " 'over': [],\n", 402 | " }\n", 403 | "\n", 404 | " state = env.reset()\n", 405 | " over = False\n", 406 | " while not over:\n", 407 | " action = [None, None]\n", 408 | "\n", 409 | " action[0] = ppo.get_action(state[0])\n", 410 | " action[1] = ppo.get_action(state[1])\n", 411 | "\n", 412 | " next_state, reward, over, info = env.step(action)\n", 413 | " win = info['win']\n", 414 | " del info\n", 415 | "\n", 416 | " #对reward进行偏移\n", 417 | " if win:\n", 418 | " reward[0] += 100\n", 419 | " reward[1] += 100\n", 420 | " else:\n", 421 | " reward[0] -= 0.1\n", 422 | " reward[1] -= 0.1\n", 423 | "\n", 424 | " data0['state'].append(state[0])\n", 425 | " data0['action'].append(action[0])\n", 426 | " data0['reward'].append(reward[0])\n", 427 | " data0['next_state'].append(next_state[0])\n", 428 | " data0['over'].append(False) #常量\n", 429 | "\n", 430 | " data1['state'].append(state[1])\n", 431 | " data1['action'].append(action[1])\n", 432 | " data1['reward'].append(reward[1])\n", 433 | " data1['next_state'].append(next_state[1])\n", 434 | " data1['over'].append(False) #常量\n", 435 | "\n", 436 | " state = next_state\n", 437 | " over = over[0] and over[1]\n", 438 | "\n", 439 | " data0['state'] = torch.FloatTensor(data0['state']).reshape(-1, 150)\n", 440 | " data0['action'] = torch.LongTensor(data0['action']).reshape(-1, 1)\n", 441 | " data0['reward'] = torch.FloatTensor(data0['reward']).reshape(-1, 1)\n", 442 | " data0['next_state'] = torch.FloatTensor(data0['next_state']).reshape(\n", 443 | " -1, 150)\n", 444 | " data0['over'] = torch.LongTensor(data0['over']).reshape(-1, 1)\n", 445 | "\n", 446 | " data1['state'] = torch.FloatTensor(data1['state']).reshape(-1, 150)\n", 447 | " data1['action'] = torch.LongTensor(data1['action']).reshape(-1, 1)\n", 448 | " data1['reward'] = torch.FloatTensor(data1['reward']).reshape(-1, 1)\n", 449 | " data1['next_state'] = torch.FloatTensor(data1['next_state']).reshape(\n", 450 | " -1, 150)\n", 451 | " data1['over'] = torch.LongTensor(data1['over']).reshape(-1, 1)\n", 452 | "\n", 453 | " return data0, data1, win\n", 
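# Supplementary sketch (not in the original notebook): PPO.train defined earlier in
# this notebook optimizes the clipped surrogate objective and regresses the critic
# towards a TD target. Note that in that cell the value branch calls
# self.optimizer_value.step() without a preceding loss.backward(), so the critic
# parameters are never actually updated there. The standalone sketch below shows the
# usual ordering; all tensors are random stand-ins (illustrative assumptions).
import torch

b = 8
new_logp = torch.randn(b, 1) * 0.1             # stand-in for log pi_new(a|s)
old_logp = torch.randn(b, 1) * 0.1             # stand-in for log pi_old(a|s), normally detached
advantages = torch.randn(b, 1)                 # stand-in for the GAE-style advantages
value = torch.randn(b, 1, requires_grad=True)  # stand-in for the critic output V(s)
target = torch.randn(b, 1)                     # stand-in for r + gamma * V(s')

# clipped surrogate: take the pessimistic minimum of the clipped and unclipped term
ratio = (new_logp - old_logp).exp()
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 0.8, 1.2) * advantages
loss_action = -torch.min(surr1, surr2).mean()

# critic regression: backward() must run before the optimizer step can change anything
loss_value = torch.nn.functional.mse_loss(value, target)
loss_value.backward()
print(loss_action.item(), value.grad.shape)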
454 | "\n", 455 | "\n", 456 | "get_data()" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 4, 462 | "metadata": { 463 | "colab": { 464 | "base_uri": "https://localhost:8080/" 465 | }, 466 | "executionInfo": { 467 | "elapsed": 10107, 468 | "status": "ok", 469 | "timestamp": 1650012696153, 470 | "user": { 471 | "displayName": "Sam Lu", 472 | "userId": "15789059763790170725" 473 | }, 474 | "user_tz": -480 475 | }, 476 | "id": "-_L_dhppItIk", 477 | "outputId": "6c1eecf0-fd72-4d13-ad05-192463636129" 478 | }, 479 | "outputs": [ 480 | { 481 | "name": "stdout", 482 | "output_type": "stream", 483 | "text": [ 484 | "0 0.0\n", 485 | "5000 0.02\n", 486 | "10000 0.13\n", 487 | "15000 0.23\n", 488 | "20000 0.24\n", 489 | "25000 0.36\n", 490 | "30000 0.28\n", 491 | "35000 0.23\n", 492 | "40000 0.35\n", 493 | "45000 0.31\n", 494 | "50000 0.13\n", 495 | "55000 0.29\n", 496 | "60000 0.26\n", 497 | "65000 0.27\n", 498 | "70000 0.36\n", 499 | "75000 0.33\n", 500 | "80000 0.37\n", 501 | "85000 0.33\n", 502 | "90000 0.43\n", 503 | "95000 0.45\n" 504 | ] 505 | } 506 | ], 507 | "source": [ 508 | "import torch.nn.functional as F\n", 509 | "import numpy as np\n", 510 | "import rl_utils\n", 511 | "\n", 512 | "wins = []\n", 513 | "for i in range(100000):\n", 514 | " data0, data1, win = get_data()\n", 515 | " wins.append(win)\n", 516 | "\n", 517 | " ppo.train(**data0)\n", 518 | " ppo.train(**data1)\n", 519 | "\n", 520 | " if i % 5000 == 0:\n", 521 | " wins = wins[-100:]\n", 522 | " print(i, sum(wins) / len(wins))\n", 523 | " wins = []" 524 | ] 525 | } 526 | ], 527 | "metadata": { 528 | "colab": { 529 | "collapsed_sections": [], 530 | "name": "第20章-多智能体强化学习入门.ipynb", 531 | "provenance": [] 532 | }, 533 | "kernelspec": { 534 | "display_name": "Python 3", 535 | "language": "python", 536 | "name": "python3" 537 | }, 538 | "language_info": { 539 | "codemirror_mode": { 540 | "name": "ipython", 541 | "version": 3 542 | }, 543 | "file_extension": ".py", 544 | "mimetype": "text/x-python", 545 | "name": "python", 546 | "nbconvert_exporter": "python", 547 | "pygments_lexer": "ipython3", 548 | "version": "3.6.13" 549 | } 550 | }, 551 | "nbformat": 4, 552 | "nbformat_minor": 1 553 | } 554 | -------------------------------------------------------------------------------- /18.多智能体/__pycache__/combat.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lansinuote/Simple_Reinforcement_Learning/07675afa0f1a47192dcc69ad361bc78d6b98544a/18.多智能体/__pycache__/combat.cpython-36.pyc -------------------------------------------------------------------------------- /2.马尔可夫决策过程/.ipynb_checkpoints/2.贝尔曼方程矩阵-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "40ce79be", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "(array([[0.9, 0.1, 0. , 0. , 0. , 0. ],\n", 13 | " [0.5, 0. , 0.5, 0. , 0. , 0. ],\n", 14 | " [0. , 0. , 0. , 0.6, 0. , 0.4],\n", 15 | " [0. , 0. , 0. , 0. , 0.3, 0.7],\n", 16 | " [0. , 0.2, 0.3, 0.5, 0. , 0. ],\n", 17 | " [0. , 0. , 0. , 0. , 0. , 1. 
]]),\n", 18 | " array([-1, -2, -2, 10, 1, 0]))" 19 | ] 20 | }, 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "output_type": "execute_result" 24 | } 25 | ], 26 | "source": [ 27 | "import numpy as np\n", 28 | "\n", 29 | "#状态转移概率矩阵\n", 30 | "P = np.array([\n", 31 | " [0.9, 0.1, 0.0, 0.0, 0.0, 0.0],\n", 32 | " [0.5, 0.0, 0.5, 0.0, 0.0, 0.0],\n", 33 | " [0.0, 0.0, 0.0, 0.6, 0.0, 0.4],\n", 34 | " [0.0, 0.0, 0.0, 0.0, 0.3, 0.7],\n", 35 | " [0.0, 0.2, 0.3, 0.5, 0.0, 0.0],\n", 36 | " [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],\n", 37 | "])\n", 38 | "\n", 39 | "#到达每一个状态的奖励\n", 40 | "R = np.array([-1, -2, -2, 10, 1, 0])\n", 41 | "\n", 42 | "P, R" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "id": "b68758d1", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/plain": [ 54 | "-2.5" 55 | ] 56 | }, 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "#给定一条序列,计算回报\n", 64 | "def value_by_chain(chain):\n", 65 | " s = 0\n", 66 | " for i, c in enumerate(chain):\n", 67 | " #给每一步的反馈做一个系数,随着步数往前衰减\n", 68 | " s += R[c] * 0.5**i\n", 69 | "\n", 70 | " #最终的反馈是所有步数衰减后的求和\n", 71 | " return s\n", 72 | "\n", 73 | "\n", 74 | "value_by_chain(np.array([0, 1, 2, 5]))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "id": "ada08c2b", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "array([-2.01950168e+00, -2.21451846e+00, 1.16142785e+00, 1.05380928e+01,\n", 87 | " 3.58728554e+00, 6.22301528e-61])" 88 | ] 89 | }, 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "#梯度下降法计算贝尔曼矩阵\n", 97 | "def get_bellman():\n", 98 | " #初始化values\n", 99 | " value = np.ones([6])\n", 100 | "\n", 101 | " for _ in range(200):\n", 102 | " for i in range(6):\n", 103 | " #每一行的概率和它对应的value相乘,乘以gamma,然后和奖励相加\n", 104 | " #反复计算,就收敛到了贝尔曼方程矩阵\n", 105 | " value[i] = R[i] + 0.5 * P[i].dot(value)\n", 106 | "\n", 107 | " return value\n", 108 | "\n", 109 | "\n", 110 | "get_bellman()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 4, 116 | "id": "480f9498", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "array([-2.01950168, -2.21451846, 1.16142785, 10.53809283, 3.58728554,\n", 123 | " 0. 
])" 124 | ] 125 | }, 126 | "execution_count": 4, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "#解析解贝尔曼方程矩阵\n", 133 | "def get_bellman():\n", 134 | " mat = np.eye(*P.shape)\n", 135 | " mat -= 0.5 * P\n", 136 | " mat = np.linalg.inv(mat)\n", 137 | "\n", 138 | " return mat.dot(R)\n", 139 | "\n", 140 | "\n", 141 | "get_bellman()" 142 | ] 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "Python 3", 148 | "language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.6.13" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 5 166 | } 167 | -------------------------------------------------------------------------------- /2.马尔可夫决策过程/1.蒙特卡洛法.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "(array([[0.5, 0.5, 0. , 0. , 0. ],\n", 12 | " [0.5, 0. , 0.5, 0. , 0. ],\n", 13 | " [0. , 0. , 0. , 0.5, 0.5],\n", 14 | " [0. , 0.1, 0.2, 0.2, 0.5],\n", 15 | " [0. , 0. , 0. , 0. , 0. ]]),\n", 16 | " array([[ -1., 0., -100., -100., -100.],\n", 17 | " [ -1., -100., -2., -100., -100.],\n", 18 | " [-100., -100., -100., -2., 0.],\n", 19 | " [-100., 1., 1., 1., 10.],\n", 20 | " [-100., -100., -100., -100., -100.]]))" 21 | ] 22 | }, 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "output_type": "execute_result" 26 | } 27 | ], 28 | "source": [ 29 | "import numpy as np\n", 30 | "\n", 31 | "#状态转移概率矩阵\n", 32 | "#很显然,状态4(第5行)就是重点了,要进入状态4,只能从状态2,3进入\n", 33 | "#[5, 5]\n", 34 | "P = np.array([\n", 35 | " [0.5, 0.5, 0.0, 0.0, 0.0],\n", 36 | " [0.5, 0.0, 0.5, 0.0, 0.0],\n", 37 | " [0.0, 0.0, 0.0, 0.5, 0.5],\n", 38 | " [0.0, 0.1, 0.2, 0.2, 0.5],\n", 39 | " [0.0, 0.0, 0.0, 0.0, 0.0],\n", 40 | "])\n", 41 | "\n", 42 | "#反馈矩阵,-100的位置是不可能走到的\n", 43 | "#[5, 5]\n", 44 | "R = np.array([\n", 45 | " [-1.0, 0.0, -100.0, -100.0, -100.0],\n", 46 | " [-1.0, -100.0, -2.0, -100.0, -100.0],\n", 47 | " [-100.0, -100.0, -100.0, -2.0, 0.0],\n", 48 | " [-100.0, 1.0, 1.0, 1.0, 10.0],\n", 49 | " [-100.0, -100.0, -100.0, -100.0, -100.0],\n", 50 | "])\n", 51 | "\n", 52 | "P, R" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "([2, 3, 4], [-2.0, 10.0])" 64 | ] 65 | }, 66 | "execution_count": 2, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "import numpy as np\n", 73 | "import random\n", 74 | "\n", 75 | "\n", 76 | "#生成一个chain\n", 77 | "def get_chain(max_lens):\n", 78 | " #采样结果\n", 79 | " ss = []\n", 80 | " rs = []\n", 81 | "\n", 82 | " #随机选择一个除4以外的状态作为起点\n", 83 | " s = random.choice(range(4))\n", 84 | " ss.append(s)\n", 85 | "\n", 86 | " for _ in range(max_lens):\n", 87 | " #按照P的概率,找到下一个状态\n", 88 | " s_next = np.random.choice(np.arange(5), p=P[s])\n", 89 | "\n", 90 | " #取到r\n", 91 | " r = R[s, s_next]\n", 92 | "\n", 93 | " #s_next变成当前状态,开始接下来的循环\n", 94 | " s = s_next\n", 95 | "\n", 96 | " ss.append(s)\n", 97 | " rs.append(r)\n", 98 | "\n", 99 | " #如果状态到了4则结束\n", 100 | " if s == 4:\n", 101 | " break\n", 102 | "\n", 103 | " 
return ss, rs\n", 104 | "\n", 105 | "\n", 106 | "get_chain(20)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 3, 112 | "metadata": { 113 | "scrolled": true 114 | }, 115 | "outputs": [ 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "([[2, 4],\n", 120 | " [0, 0, 0, 1, 0, 1, 2, 3, 4],\n", 121 | " [0, 1, 0, 1, 0, 0, 1, 2, 3, 4],\n", 122 | " [1, 2, 4],\n", 123 | " [3, 4],\n", 124 | " [1, 2, 4],\n", 125 | " [1, 0, 1, 0, 0, 1, 2, 3, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 4],\n", 126 | " [2, 4],\n", 127 | " [2, 3, 4],\n", 128 | " [2, 4],\n", 129 | " [0, 1, 0, 0, 0, 1, 2, 4],\n", 130 | " [3, 4],\n", 131 | " [1, 2, 3, 1, 0, 1, 0, 1, 2, 4],\n", 132 | " [2, 3, 3, 3, 2, 4],\n", 133 | " [1, 2, 4],\n", 134 | " [0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 2, 4],\n", 135 | " [0, 1, 2, 4],\n", 136 | " [2, 4],\n", 137 | " [2, 4],\n", 138 | " [2, 4],\n", 139 | " [2, 4],\n", 140 | " [1, 2, 4],\n", 141 | " [2, 4],\n", 142 | " [3, 2, 3, 1, 2, 3, 4],\n", 143 | " [0, 0, 1, 2, 3, 3, 3, 4],\n", 144 | " [2, 3, 2, 4],\n", 145 | " [2, 4],\n", 146 | " [2, 4],\n", 147 | " [1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0],\n", 148 | " [1, 2, 4],\n", 149 | " [2, 3, 3, 4],\n", 150 | " [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 2, 3, 4],\n", 151 | " [1, 2, 3, 4],\n", 152 | " [0, 1, 2, 3, 4],\n", 153 | " [1, 0, 0, 0, 0, 1, 0, 1, 2, 3, 1, 2, 4],\n", 154 | " [2, 3, 4],\n", 155 | " [2, 4],\n", 156 | " [2, 4],\n", 157 | " [3, 3, 1, 2, 3, 2, 3, 4],\n", 158 | " [3, 4],\n", 159 | " [0, 1, 2, 3, 3, 4],\n", 160 | " [3, 4],\n", 161 | " [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 2, 3, 4],\n", 162 | " [3, 3, 4],\n", 163 | " [2, 3, 4],\n", 164 | " [1, 0, 1, 0, 1, 2, 3, 4],\n", 165 | " [0, 1, 2, 3, 3, 3, 4],\n", 166 | " [1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 3, 4],\n", 167 | " [3, 4],\n", 168 | " [2, 3, 3, 4],\n", 169 | " [2, 4],\n", 170 | " [1, 2, 3, 3, 3, 2, 3, 3, 4],\n", 171 | " [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 2, 4],\n", 172 | " [2, 4],\n", 173 | " [1, 2, 3, 4],\n", 174 | " [2, 4],\n", 175 | " [3, 2, 4],\n", 176 | " [1, 0, 1, 2, 4],\n", 177 | " [3, 3, 3, 3, 4],\n", 178 | " [1, 0, 1, 2, 4],\n", 179 | " [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 3, 4],\n", 180 | " [3, 4],\n", 181 | " [0, 0, 0, 1, 2, 4],\n", 182 | " [0, 1, 0, 0, 0, 1, 2, 4],\n", 183 | " [2, 4],\n", 184 | " [0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 2, 4],\n", 185 | " [2, 4],\n", 186 | " [3, 3, 3, 4],\n", 187 | " [0, 0, 0, 1, 2, 4],\n", 188 | " [2, 3, 2, 4],\n", 189 | " [3, 4],\n", 190 | " [2, 3, 4],\n", 191 | " [0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 2, 4],\n", 192 | " [2, 4],\n", 193 | " [1, 0, 1, 2, 4],\n", 194 | " [3, 4],\n", 195 | " [1, 0, 1, 0, 0, 0, 0, 1, 2, 3, 3, 3, 2, 3, 3, 1, 2, 4],\n", 196 | " [2, 3, 4],\n", 197 | " [0, 1, 0, 0, 1, 2, 3, 4],\n", 198 | " [0, 0, 1, 2, 4],\n", 199 | " [3, 2, 3, 4],\n", 200 | " [0, 0, 1, 2, 3, 3, 4],\n", 201 | " [0, 1, 2, 4],\n", 202 | " [0, 0, 1, 2, 4],\n", 203 | " [0, 0, 1, 0, 0, 0, 1, 0, 1, 2, 4],\n", 204 | " [3, 4],\n", 205 | " [3, 1, 2, 4],\n", 206 | " [2, 3, 2, 4],\n", 207 | " [1, 2, 4],\n", 208 | " [0, 1, 2, 4],\n", 209 | " [1, 0, 0, 0, 0, 1, 2, 3, 3, 3, 3, 4],\n", 210 | " [3, 3, 4],\n", 211 | " [3, 3, 4],\n", 212 | " [3, 3, 2, 3, 4],\n", 213 | " [0, 0, 0, 1, 0, 0, 0, 0, 1, 2, 4],\n", 214 | " [3, 2, 3, 2, 3, 2, 3, 1, 0, 0, 1, 2, 4],\n", 215 | " [2, 3, 1, 2, 4],\n", 216 | " [1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 2, 4],\n", 217 | " [3, 3, 4],\n", 218 | " [3, 4]],\n", 219 | " [[0.0],\n", 220 | " [-1.0, -1.0, 0.0, -1.0, 0.0, -2.0, -2.0, 10.0],\n", 221 | " [0.0, -1.0, 0.0, -1.0, -1.0, 0.0, -2.0, -2.0, 10.0],\n", 222 | " 
[-2.0, 0.0],\n", 223 | " [10.0],\n", 224 | " [-2.0, 0.0],\n", 225 | " [-1.0,\n", 226 | " 0.0,\n", 227 | " -1.0,\n", 228 | " -1.0,\n", 229 | " 0.0,\n", 230 | " -2.0,\n", 231 | " -2.0,\n", 232 | " 1.0,\n", 233 | " -1.0,\n", 234 | " -1.0,\n", 235 | " -1.0,\n", 236 | " 0.0,\n", 237 | " -1.0,\n", 238 | " -1.0,\n", 239 | " -1.0,\n", 240 | " 0.0,\n", 241 | " -2.0,\n", 242 | " 0.0],\n", 243 | " [0.0],\n", 244 | " [-2.0, 10.0],\n", 245 | " [0.0],\n", 246 | " [0.0, -1.0, -1.0, -1.0, 0.0, -2.0, 0.0],\n", 247 | " [10.0],\n", 248 | " [-2.0, -2.0, 1.0, -1.0, 0.0, -1.0, 0.0, -2.0, 0.0],\n", 249 | " [-2.0, 1.0, 1.0, 1.0, 0.0],\n", 250 | " [-2.0, 0.0],\n", 251 | " [-1.0, -1.0, 0.0, -1.0, -1.0, -1.0, 0.0, -1.0, 0.0, -2.0, 0.0],\n", 252 | " [0.0, -2.0, 0.0],\n", 253 | " [0.0],\n", 254 | " [0.0],\n", 255 | " [0.0],\n", 256 | " [0.0],\n", 257 | " [-2.0, 0.0],\n", 258 | " [0.0],\n", 259 | " [1.0, -2.0, 1.0, -2.0, -2.0, 10.0],\n", 260 | " [-1.0, 0.0, -2.0, -2.0, 1.0, 1.0, 10.0],\n", 261 | " [-2.0, 1.0, 0.0],\n", 262 | " [0.0],\n", 263 | " [0.0],\n", 264 | " [-1.0,\n", 265 | " -1.0,\n", 266 | " 0.0,\n", 267 | " -1.0,\n", 268 | " -1.0,\n", 269 | " 0.0,\n", 270 | " -1.0,\n", 271 | " 0.0,\n", 272 | " -1.0,\n", 273 | " -1.0,\n", 274 | " -1.0,\n", 275 | " 0.0,\n", 276 | " -1.0,\n", 277 | " 0.0,\n", 278 | " -1.0,\n", 279 | " -1.0,\n", 280 | " -1.0,\n", 281 | " -1.0,\n", 282 | " 0.0,\n", 283 | " -1.0],\n", 284 | " [-2.0, 0.0],\n", 285 | " [-2.0, 1.0, 10.0],\n", 286 | " [-1.0,\n", 287 | " -1.0,\n", 288 | " -1.0,\n", 289 | " 0.0,\n", 290 | " -1.0,\n", 291 | " -1.0,\n", 292 | " -1.0,\n", 293 | " 0.0,\n", 294 | " -1.0,\n", 295 | " -1.0,\n", 296 | " 0.0,\n", 297 | " -2.0,\n", 298 | " -2.0,\n", 299 | " 10.0],\n", 300 | " [-2.0, -2.0, 10.0],\n", 301 | " [0.0, -2.0, -2.0, 10.0],\n", 302 | " [-1.0, -1.0, -1.0, -1.0, 0.0, -1.0, 0.0, -2.0, -2.0, 1.0, -2.0, 0.0],\n", 303 | " [-2.0, 10.0],\n", 304 | " [0.0],\n", 305 | " [0.0],\n", 306 | " [1.0, 1.0, -2.0, -2.0, 1.0, -2.0, 10.0],\n", 307 | " [10.0],\n", 308 | " [0.0, -2.0, -2.0, 1.0, 10.0],\n", 309 | " [10.0],\n", 310 | " [-1.0,\n", 311 | " 0.0,\n", 312 | " -1.0,\n", 313 | " -1.0,\n", 314 | " 0.0,\n", 315 | " -1.0,\n", 316 | " -1.0,\n", 317 | " -1.0,\n", 318 | " 0.0,\n", 319 | " -1.0,\n", 320 | " 0.0,\n", 321 | " -1.0,\n", 322 | " 0.0,\n", 323 | " -2.0,\n", 324 | " -2.0,\n", 325 | " 10.0],\n", 326 | " [1.0, 10.0],\n", 327 | " [-2.0, 10.0],\n", 328 | " [-1.0, 0.0, -1.0, 0.0, -2.0, -2.0, 10.0],\n", 329 | " [0.0, -2.0, -2.0, 1.0, 1.0, 10.0],\n", 330 | " [-1.0, -1.0, -1.0, 0.0, -1.0, -1.0, -1.0, 0.0, -2.0, -2.0, 10.0],\n", 331 | " [10.0],\n", 332 | " [-2.0, 1.0, 10.0],\n", 333 | " [0.0],\n", 334 | " [-2.0, -2.0, 1.0, 1.0, 1.0, -2.0, 1.0, 10.0],\n", 335 | " [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -2.0, 0.0],\n", 336 | " [0.0],\n", 337 | " [-2.0, -2.0, 10.0],\n", 338 | " [0.0],\n", 339 | " [1.0, 0.0],\n", 340 | " [-1.0, 0.0, -2.0, 0.0],\n", 341 | " [1.0, 1.0, 1.0, 10.0],\n", 342 | " [-1.0, 0.0, -2.0, 0.0],\n", 343 | " [-1.0,\n", 344 | " 0.0,\n", 345 | " -1.0,\n", 346 | " -1.0,\n", 347 | " -1.0,\n", 348 | " -1.0,\n", 349 | " 0.0,\n", 350 | " -1.0,\n", 351 | " -1.0,\n", 352 | " -1.0,\n", 353 | " 0.0,\n", 354 | " -1.0,\n", 355 | " -1.0,\n", 356 | " -1.0,\n", 357 | " 0.0,\n", 358 | " -2.0,\n", 359 | " -2.0,\n", 360 | " 10.0],\n", 361 | " [10.0],\n", 362 | " [-1.0, -1.0, 0.0, -2.0, 0.0],\n", 363 | " [0.0, -1.0, -1.0, -1.0, 0.0, -2.0, 0.0],\n", 364 | " [0.0],\n", 365 | " [-1.0, -1.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -2.0, 0.0],\n", 366 | " [0.0],\n", 367 | " 
[1.0, 1.0, 10.0],\n", 368 | " [-1.0, -1.0, 0.0, -2.0, 0.0],\n", 369 | " [-2.0, 1.0, 0.0],\n", 370 | " [10.0],\n", 371 | " [-2.0, 10.0],\n", 372 | " [-1.0, -1.0, 0.0, -1.0, 0.0, -1.0, -1.0, 0.0, -1.0, -1.0, 0.0, -2.0, 0.0],\n", 373 | " [0.0],\n", 374 | " [-1.0, 0.0, -2.0, 0.0],\n", 375 | " [10.0],\n", 376 | " [-1.0,\n", 377 | " 0.0,\n", 378 | " -1.0,\n", 379 | " -1.0,\n", 380 | " -1.0,\n", 381 | " -1.0,\n", 382 | " 0.0,\n", 383 | " -2.0,\n", 384 | " -2.0,\n", 385 | " 1.0,\n", 386 | " 1.0,\n", 387 | " 1.0,\n", 388 | " -2.0,\n", 389 | " 1.0,\n", 390 | " 1.0,\n", 391 | " -2.0,\n", 392 | " 0.0],\n", 393 | " [-2.0, 10.0],\n", 394 | " [0.0, -1.0, -1.0, 0.0, -2.0, -2.0, 10.0],\n", 395 | " [-1.0, 0.0, -2.0, 0.0],\n", 396 | " [1.0, -2.0, 10.0],\n", 397 | " [-1.0, 0.0, -2.0, -2.0, 1.0, 10.0],\n", 398 | " [0.0, -2.0, 0.0],\n", 399 | " [-1.0, 0.0, -2.0, 0.0],\n", 400 | " [-1.0, 0.0, -1.0, -1.0, -1.0, 0.0, -1.0, 0.0, -2.0, 0.0],\n", 401 | " [10.0],\n", 402 | " [1.0, -2.0, 0.0],\n", 403 | " [-2.0, 1.0, 0.0],\n", 404 | " [-2.0, 0.0],\n", 405 | " [0.0, -2.0, 0.0],\n", 406 | " [-1.0, -1.0, -1.0, -1.0, 0.0, -2.0, -2.0, 1.0, 1.0, 1.0, 10.0],\n", 407 | " [1.0, 10.0],\n", 408 | " [1.0, 10.0],\n", 409 | " [1.0, 1.0, -2.0, 10.0],\n", 410 | " [-1.0, -1.0, 0.0, -1.0, -1.0, -1.0, -1.0, 0.0, -2.0, 0.0],\n", 411 | " [1.0, -2.0, 1.0, -2.0, 1.0, -2.0, 1.0, -1.0, -1.0, 0.0, -2.0, 0.0],\n", 412 | " [-2.0, 1.0, -2.0, 0.0],\n", 413 | " [-1.0, -1.0, -1.0, -1.0, -1.0, 0.0, -1.0, -1.0, 0.0, -2.0, 0.0],\n", 414 | " [1.0, 10.0],\n", 415 | " [10.0]])" 416 | ] 417 | }, 418 | "execution_count": 3, 419 | "metadata": {}, 420 | "output_type": "execute_result" 421 | } 422 | ], 423 | "source": [ 424 | "#生成N个chain\n", 425 | "def get_chains(N, max_lens):\n", 426 | " ss = []\n", 427 | " rs = []\n", 428 | " for _ in range(N):\n", 429 | " s, r = get_chain(max_lens)\n", 430 | " ss.append(s)\n", 431 | " rs.append(r)\n", 432 | "\n", 433 | " return ss, rs\n", 434 | "\n", 435 | "\n", 436 | "ss, rs = get_chains(100, 20)\n", 437 | "\n", 438 | "ss, rs" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 4, 444 | "metadata": {}, 445 | "outputs": [ 446 | { 447 | "data": { 448 | "text/plain": [ 449 | "0.0" 450 | ] 451 | }, 452 | "execution_count": 4, 453 | "metadata": {}, 454 | "output_type": "execute_result" 455 | } 456 | ], 457 | "source": [ 458 | "#给定一条链,计算回报\n", 459 | "def get_value(rs):\n", 460 | " sum = 0\n", 461 | " for i, r in enumerate(rs):\n", 462 | " #给每一步的反馈做一个系数,随着步数往后衰减,也就是说,越早的动作影响越大\n", 463 | " sum += 0.5**i * r\n", 464 | "\n", 465 | " #最终的反馈是所有步数衰减后的求和\n", 466 | " return sum\n", 467 | "\n", 468 | "\n", 469 | "get_value(rs[0])" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 5, 475 | "metadata": {}, 476 | "outputs": [ 477 | { 478 | "name": "stderr", 479 | "output_type": "stream", 480 | "text": [ 481 | "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/numpy/core/fromnumeric.py:3373: RuntimeWarning: Mean of empty slice.\n", 482 | " out=out, **kwargs)\n", 483 | "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/numpy/core/_methods.py:170: RuntimeWarning: invalid value encountered in double_scalars\n", 484 | " ret = ret.dtype.type(ret / rcount)\n" 485 | ] 486 | }, 487 | { 488 | "data": { 489 | "text/plain": [ 490 | "[-1.2689316385800076,\n", 491 | " -1.595738185128587,\n", 492 | " 0.5337043907456025,\n", 493 | " 5.91608556615244,\n", 494 | " nan]" 495 | ] 496 | }, 497 | "execution_count": 5, 498 | "metadata": {}, 499 | "output_type": "execute_result" 500 | } 501 | ], 502 | 
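# Supplementary sketch (not in the original notebook): get_values_by_monte_carlo in
# this notebook estimates each state's value by averaging the discounted return of
# every sampled chain, grouped by the chain's first state. The same estimate can be
# maintained incrementally as V(s) <- V(s) + (G - V(s)) / N(s), which avoids storing
# all returns. The (start_state, return) pairs below are illustrative assumptions.
import numpy as np

V = np.zeros(5)  # running value estimate per starting state
N = np.zeros(5)  # number of chains seen per starting state

toy_samples = [(0, -1.3), (0, -1.1), (2, 0.5), (3, 6.0), (3, 5.8)]

for s, G in toy_samples:
    N[s] += 1
    # incremental mean: identical to np.mean over all returns seen so far
    V[s] += (G - V[s]) / N[s]

print(V)  # approximately [-1.2  0.   0.5  5.9  0. ]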
"source": [ 503 | "#蒙特卡洛法评估每个状态的价值\n", 504 | "def get_values_by_monte_carlo(ss, rs):\n", 505 | " #记录5个不同开头的价值\n", 506 | " #其实只有4个,因为状态4是不可能作为开头状态的\n", 507 | " values = [[] for _ in range(5)]\n", 508 | "\n", 509 | " #遍历所有链\n", 510 | " for s, r in zip(ss, rs):\n", 511 | " #计算不同开头的价值\n", 512 | " values[s[0]].append(get_value(r))\n", 513 | "\n", 514 | " #求每个开头的平均价值\n", 515 | " return [np.mean(i) for i in values]\n", 516 | "\n", 517 | "\n", 518 | "#-1.228923788722258,-1.6955696284402704,0.4823809701532294,5.967514743019431,0\n", 519 | "get_values_by_monte_carlo(*get_chains(2000, 20))" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 6, 525 | "metadata": {}, 526 | "outputs": [ 527 | { 528 | "data": { 529 | "text/plain": [ 530 | "0.11304324114416356" 531 | ] 532 | }, 533 | "execution_count": 6, 534 | "metadata": {}, 535 | "output_type": "execute_result" 536 | } 537 | ], 538 | "source": [ 539 | "#计算状态动作对(s,a)出现的频率,以此来估算策略的占用度量\n", 540 | "def occupancy(ss, rs, s, a):\n", 541 | " rho = 0\n", 542 | "\n", 543 | " count_by_time = np.zeros(max_time)\n", 544 | " count_by_s_a = np.zeros(max_time)\n", 545 | "\n", 546 | " for si, ri in zip(ss, rs):\n", 547 | " for i in range(len(ri)):\n", 548 | " s_opt = si[i]\n", 549 | " a_opt = si[i + 1]\n", 550 | "\n", 551 | " #统计每个时间步的次数\n", 552 | " count_by_time[i] += 1\n", 553 | "\n", 554 | " #统计s,a出现的次数\n", 555 | " if s == s_opt and a == a_opt:\n", 556 | " count_by_s_a[i] += 1\n", 557 | "\n", 558 | " #i -> [999 - 0]\n", 559 | " for i in reversed(range(max_time)):\n", 560 | " if count_by_time[i] == 0:\n", 561 | " continue\n", 562 | "\n", 563 | " #以时间逐渐衰减\n", 564 | " rho += 0.5**i * count_by_s_a[i] / count_by_time[i]\n", 565 | "\n", 566 | " return (1 - 0.5) * rho\n", 567 | "\n", 568 | "\n", 569 | "max_time = 1000\n", 570 | "ss, rs = get_chains(max_time, 2000)\n", 571 | "\n", 572 | "#0.112567796310472\n", 573 | "occupancy(ss, rs, 3, 1) + occupancy(ss, rs, 3, 2) + occupancy(ss, rs, 3, 3)" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 7, 579 | "metadata": {}, 580 | "outputs": [ 581 | { 582 | "data": { 583 | "text/plain": [ 584 | "0.23167624185977734" 585 | ] 586 | }, 587 | "execution_count": 7, 588 | "metadata": {}, 589 | "output_type": "execute_result" 590 | } 591 | ], 592 | "source": [ 593 | "#重新定义状态转移概率矩阵\n", 594 | "P = np.array([\n", 595 | " [0.6, 0.4, 0.0, 0.0, 0.0],\n", 596 | " [0.3, 0.0, 0.7, 0.0, 0.0],\n", 597 | " [0.0, 0.0, 0.0, 0.5, 0.5],\n", 598 | " [0.0, 0.18, 0.36, 0.36, 0.1],\n", 599 | " [0.0, 0.0, 0.0, 0.0, 0.0],\n", 600 | "])\n", 601 | "\n", 602 | "ss, rs = get_chains(max_time, 2000)\n", 603 | "\n", 604 | "#0.23199480615618912\n", 605 | "occupancy(ss, rs, 3, 1) + occupancy(ss, rs, 3, 2) + occupancy(ss, rs, 3, 3)" 606 | ] 607 | } 608 | ], 609 | "metadata": { 610 | "colab": { 611 | "collapsed_sections": [], 612 | "name": "第3章-马尔可夫决策过程.ipynb", 613 | "provenance": [] 614 | }, 615 | "kernelspec": { 616 | "display_name": "Python 3", 617 | "language": "python", 618 | "name": "python3" 619 | }, 620 | "language_info": { 621 | "codemirror_mode": { 622 | "name": "ipython", 623 | "version": 3 624 | }, 625 | "file_extension": ".py", 626 | "mimetype": "text/x-python", 627 | "name": "python", 628 | "nbconvert_exporter": "python", 629 | "pygments_lexer": "ipython3", 630 | "version": "3.6.13" 631 | } 632 | }, 633 | "nbformat": 4, 634 | "nbformat_minor": 1 635 | } 636 | -------------------------------------------------------------------------------- /2.马尔可夫决策过程/2.贝尔曼方程矩阵.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "40ce79be", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "(array([[0.9, 0.1, 0. , 0. , 0. , 0. ],\n", 13 | " [0.5, 0. , 0.5, 0. , 0. , 0. ],\n", 14 | " [0. , 0. , 0. , 0.6, 0. , 0.4],\n", 15 | " [0. , 0. , 0. , 0. , 0.3, 0.7],\n", 16 | " [0. , 0.2, 0.3, 0.5, 0. , 0. ],\n", 17 | " [0. , 0. , 0. , 0. , 0. , 1. ]]),\n", 18 | " array([-1, -2, -2, 10, 1, 0]))" 19 | ] 20 | }, 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "output_type": "execute_result" 24 | } 25 | ], 26 | "source": [ 27 | "import numpy as np\n", 28 | "\n", 29 | "#状态转移概率矩阵\n", 30 | "P = np.array([\n", 31 | " [0.9, 0.1, 0.0, 0.0, 0.0, 0.0],\n", 32 | " [0.5, 0.0, 0.5, 0.0, 0.0, 0.0],\n", 33 | " [0.0, 0.0, 0.0, 0.6, 0.0, 0.4],\n", 34 | " [0.0, 0.0, 0.0, 0.0, 0.3, 0.7],\n", 35 | " [0.0, 0.2, 0.3, 0.5, 0.0, 0.0],\n", 36 | " [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],\n", 37 | "])\n", 38 | "\n", 39 | "#到达每一个状态的奖励\n", 40 | "R = np.array([-1, -2, -2, 10, 1, 0])\n", 41 | "\n", 42 | "P, R" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "id": "b68758d1", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/plain": [ 54 | "-2.5" 55 | ] 56 | }, 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "#给定一条序列,计算回报\n", 64 | "def value_by_chain(chain):\n", 65 | " s = 0\n", 66 | " for i, c in enumerate(chain):\n", 67 | " #给每一步的反馈做一个系数,随着步数往前衰减\n", 68 | " s += R[c] * 0.5**i\n", 69 | "\n", 70 | " #最终的反馈是所有步数衰减后的求和\n", 71 | " return s\n", 72 | "\n", 73 | "\n", 74 | "value_by_chain(np.array([0, 1, 2, 5]))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "id": "ada08c2b", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "array([-2.01950168e+00, -2.21451846e+00, 1.16142785e+00, 1.05380928e+01,\n", 87 | " 3.58728554e+00, 6.22301528e-61])" 88 | ] 89 | }, 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "#梯度下降法计算贝尔曼矩阵\n", 97 | "def get_bellman():\n", 98 | " #初始化values\n", 99 | " value = np.ones([6])\n", 100 | "\n", 101 | " for _ in range(200):\n", 102 | " for i in range(6):\n", 103 | " #每一行的概率和它对应的value相乘,乘以gamma,然后和奖励相加\n", 104 | " #反复计算,就收敛到了贝尔曼方程矩阵\n", 105 | " value[i] = R[i] + 0.5 * P[i].dot(value)\n", 106 | "\n", 107 | " return value\n", 108 | "\n", 109 | "\n", 110 | "get_bellman()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 4, 116 | "id": "480f9498", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "array([-2.01950168, -2.21451846, 1.16142785, 10.53809283, 3.58728554,\n", 123 | " 0. 
])" 124 | ] 125 | }, 126 | "execution_count": 4, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "#解析解贝尔曼方程矩阵\n", 133 | "def get_bellman():\n", 134 | " mat = np.eye(*P.shape)\n", 135 | " mat -= 0.5 * P\n", 136 | " mat = np.linalg.inv(mat)\n", 137 | "\n", 138 | " return mat.dot(R)\n", 139 | "\n", 140 | "\n", 141 | "get_bellman()" 142 | ] 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "Python 3", 148 | "language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.6.13" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 5 166 | } 167 | -------------------------------------------------------------------------------- /4.时序差分算法/.ipynb_checkpoints/1.Sarsa算法-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'ground'" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "#获取一个格子的状态\n", 21 | "def get_state(row, col):\n", 22 | " if row != 3:\n", 23 | " return 'ground'\n", 24 | "\n", 25 | " if row == 3 and col == 0:\n", 26 | " return 'ground'\n", 27 | "\n", 28 | " if row == 3 and col == 11:\n", 29 | " return 'terminal'\n", 30 | "\n", 31 | " return 'trap'\n", 32 | "\n", 33 | "\n", 34 | "get_state(0, 0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "(0, 1, -1)" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#在一个格子里做一个动作\n", 55 | "def move(row, col, action):\n", 56 | " #如果当前已经在陷阱或者终点,则不能执行任何动作\n", 57 | " if get_state(row, col) in ['trap', 'terminal']:\n", 58 | " return row, col, 0\n", 59 | "\n", 60 | " #↑\n", 61 | " if action == 0:\n", 62 | " row -= 1\n", 63 | "\n", 64 | " #↓\n", 65 | " if action == 1:\n", 66 | " row += 1\n", 67 | "\n", 68 | " #←\n", 69 | " if action == 2:\n", 70 | " col -= 1\n", 71 | "\n", 72 | " #→\n", 73 | " if action == 3:\n", 74 | " col += 1\n", 75 | "\n", 76 | " #不允许走到地图外面去\n", 77 | " row = max(0, row)\n", 78 | " row = min(3, row)\n", 79 | " col = max(0, col)\n", 80 | " col = min(11, col)\n", 81 | "\n", 82 | " #是陷阱的话,奖励是-100,否则都是-1\n", 83 | " reward = -1\n", 84 | " if get_state(row, col) == 'trap':\n", 85 | " reward = -100\n", 86 | "\n", 87 | " return row, col, reward\n", 88 | "\n", 89 | "\n", 90 | "move(0, 0, 3)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "(4, 12, 4)" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "import numpy as np\n", 111 | "\n", 112 | "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n", 113 | "Q = np.zeros([4, 12, 4])\n", 114 | "\n", 115 | "Q.shape" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | 
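# Supplementary sketch (not in the original notebook): the get_update function in
# this notebook implements the one-step Sarsa rule
#   Q(s,a) <- Q(s,a) + alpha * (r + gamma * Q(s',a') - Q(s,a))
# with alpha = 0.1 and gamma = 0.9. The standalone helper below mirrors that rule;
# the function and argument names are illustrative assumptions.
import numpy as np

Q_sketch = np.zeros([4, 12, 4])  # same shape as this notebook's Q table


def sarsa_update(Q, s, a, r, s_next, a_next, alpha=0.1, gamma=0.9):
    # the TD target bootstraps from the action actually chosen in the next state
    target = r + gamma * Q[s_next][a_next]
    td_error = target - Q[s][a]
    Q[s][a] += alpha * td_error
    return td_error


# one step: from (0,0) move right (action 3), get reward -1, land in (0,1), pick right again
err = sarsa_update(Q_sketch, (0, 0), 3, -1, (0, 1), 3)
print(err, Q_sketch[0, 0, 3])  # -1.0 -0.1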
"data": { 125 | "text/plain": [ 126 | "0" 127 | ] 128 | }, 129 | "execution_count": 4, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "import random\n", 136 | "\n", 137 | "\n", 138 | "#根据状态选择一个动作\n", 139 | "def get_action(row, col):\n", 140 | " #有小概率选择随机动作\n", 141 | " if random.random() < 0.1:\n", 142 | " return random.choice(range(4))\n", 143 | "\n", 144 | " #否则选择分数最高的动作\n", 145 | " return Q[row, col].argmax()\n", 146 | "\n", 147 | "\n", 148 | "get_action(0, 0)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "-0.1" 160 | ] 161 | }, 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "#更新分数,每次更新取决于当前的格子,当前的动作,下个格子,和下个格子的动作\n", 169 | "def get_update(row, col, action, reward, next_row, next_col, next_action):\n", 170 | "\n", 171 | " #计算target\n", 172 | " target = 0.9 * Q[next_row, next_col, next_action]\n", 173 | " target += reward\n", 174 | "\n", 175 | " #计算value\n", 176 | " value = Q[row, col, action]\n", 177 | "\n", 178 | " #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n", 179 | " #此处是求两者的差,越接近0越好\n", 180 | " update = target - value\n", 181 | "\n", 182 | " #这个0.1相当于lr\n", 183 | " update *= 0.1\n", 184 | "\n", 185 | " #更新当前状态和动作的分数\n", 186 | " return update\n", 187 | "\n", 188 | "\n", 189 | "#在0,0向右走,得到-1,到达0,1,再次执行向右走\n", 190 | "get_update(0, 0, 3, -1, 0, 1, 3)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 6, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "0 -116\n", 203 | "150 -25\n", 204 | "300 -20\n", 205 | "450 -18\n", 206 | "600 -17\n", 207 | "750 -15\n", 208 | "900 -18\n", 209 | "1050 -18\n", 210 | "1200 -18\n", 211 | "1350 -16\n" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "#训练\n", 217 | "def train():\n", 218 | " for epoch in range(1500):\n", 219 | " #初始化当前位置\n", 220 | " row = random.choice(range(4))\n", 221 | " col = 0\n", 222 | "\n", 223 | " #初始化第一个动作\n", 224 | " action = get_action(row, col)\n", 225 | "\n", 226 | " #计算反馈的和,这个数字应该越来越小\n", 227 | " reward_sum = 0\n", 228 | "\n", 229 | " #循环直到到达终点或者掉进陷阱\n", 230 | " while get_state(row, col) not in ['terminal', 'trap']:\n", 231 | "\n", 232 | " #执行动作\n", 233 | " next_row, next_col, reward = move(row, col, action)\n", 234 | " reward_sum += reward\n", 235 | "\n", 236 | " #求新位置的动作\n", 237 | " next_action = get_action(next_row, next_col)\n", 238 | "\n", 239 | " #更新分数\n", 240 | " update = get_update(row, col, action, reward, next_row, next_col,\n", 241 | " next_action)\n", 242 | " Q[row, col, action] += update\n", 243 | "\n", 244 | " #更新当前位置\n", 245 | " row = next_row\n", 246 | " col = next_col\n", 247 | " action = next_action\n", 248 | "\n", 249 | " if epoch % 150 == 0:\n", 250 | " print(epoch, reward_sum)\n", 251 | "\n", 252 | "\n", 253 | "train()" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 7, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "□□□□□□□□□□□□\n", 266 | "□↑□□□□□□□□□□\n", 267 | "□□□□□□□□□□□□\n", 268 | "□○○○○○○○○○○❤\n" 269 | ] 270 | } 271 | ], 272 | "source": [ 273 | "#打印游戏,方便测试\n", 274 | "def show(row, col, action):\n", 275 | " graph = [\n", 276 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 277 | " '□', 
'□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 278 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n", 279 | " '○', '○', '○', '○', '○', '❤'\n", 280 | " ]\n", 281 | "\n", 282 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 283 | "\n", 284 | " graph[row * 12 + col] = action\n", 285 | "\n", 286 | " graph = ''.join(graph)\n", 287 | "\n", 288 | " for i in range(0, 4 * 12, 12):\n", 289 | " print(graph[i:i + 12])\n", 290 | "\n", 291 | "\n", 292 | "show(1, 1, 0)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 8, 298 | "metadata": { 299 | "scrolled": true 300 | }, 301 | "outputs": [ 302 | { 303 | "name": "stdout", 304 | "output_type": "stream", 305 | "text": [ 306 | "□□□□□□□□□□□□\n", 307 | "□□□□□□□□□□□□\n", 308 | "□□□□□□□□□□□↓\n", 309 | "□○○○○○○○○○○❤\n" 310 | ] 311 | } 312 | ], 313 | "source": [ 314 | "from IPython import display\n", 315 | "import time\n", 316 | "\n", 317 | "\n", 318 | "def test():\n", 319 | " #起点\n", 320 | " row = random.choice(range(4))\n", 321 | " col = 0\n", 322 | "\n", 323 | " #最多玩N步\n", 324 | " for _ in range(200):\n", 325 | "\n", 326 | " #获取当前状态,如果状态是终点或者掉陷阱则终止\n", 327 | " if get_state(row, col) in ['trap', 'terminal']:\n", 328 | " break\n", 329 | "\n", 330 | " #选择最优动作\n", 331 | " action = Q[row, col].argmax()\n", 332 | "\n", 333 | " #打印这个动作\n", 334 | " display.clear_output(wait=True)\n", 335 | " time.sleep(0.1)\n", 336 | " show(row, col, action)\n", 337 | "\n", 338 | " #执行动作\n", 339 | " row, col, reward = move(row, col, action)\n", 340 | "\n", 341 | "\n", 342 | "test()" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 9, 348 | "metadata": {}, 349 | "outputs": [ 350 | { 351 | "name": "stdout", 352 | "output_type": "stream", 353 | "text": [ 354 | "→→→→→→→→→→→↓\n", 355 | "→→↑→→→→→→→→↓\n", 356 | "↑↑↑↑←↑↑↑↑↑→↓\n", 357 | "↑↑↑↑↑↑↑↑↑↑↑↑\n" 358 | ] 359 | } 360 | ], 361 | "source": [ 362 | "#打印所有格子的动作倾向\n", 363 | "for row in range(4):\n", 364 | " line = ''\n", 365 | " for col in range(12):\n", 366 | " action = Q[row, col].argmax()\n", 367 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 368 | " line += action\n", 369 | " print(line)" 370 | ] 371 | } 372 | ], 373 | "metadata": { 374 | "colab": { 375 | "collapsed_sections": [], 376 | "name": "第5章-时序差分算法.ipynb", 377 | "provenance": [] 378 | }, 379 | "kernelspec": { 380 | "display_name": "Python 3", 381 | "language": "python", 382 | "name": "python3" 383 | }, 384 | "language_info": { 385 | "codemirror_mode": { 386 | "name": "ipython", 387 | "version": 3 388 | }, 389 | "file_extension": ".py", 390 | "mimetype": "text/x-python", 391 | "name": "python", 392 | "nbconvert_exporter": "python", 393 | "pygments_lexer": "ipython3", 394 | "version": "3.6.13" 395 | } 396 | }, 397 | "nbformat": 4, 398 | "nbformat_minor": 1 399 | } 400 | -------------------------------------------------------------------------------- /4.时序差分算法/.ipynb_checkpoints/2.N步Sarsa算法-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'ground'" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "#获取一个格子的状态\n", 21 | "def get_state(row, col):\n", 22 | " if row != 3:\n", 23 | " return 'ground'\n", 24 | "\n", 25 | " if row == 3 and col == 0:\n", 26 | " return 
'ground'\n", 27 | "\n", 28 | " if row == 3 and col == 11:\n", 29 | " return 'terminal'\n", 30 | "\n", 31 | " return 'trap'\n", 32 | "\n", 33 | "\n", 34 | "get_state(0, 0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "(0, 1, -1)" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#在一个格子里做一个动作\n", 55 | "def move(row, col, action):\n", 56 | " #如果当前已经在陷阱或者终点,则不能执行任何动作\n", 57 | " if get_state(row, col) in ['trap', 'terminal']:\n", 58 | " return row, col, 0\n", 59 | "\n", 60 | " #↑\n", 61 | " if action == 0:\n", 62 | " row -= 1\n", 63 | "\n", 64 | " #↓\n", 65 | " if action == 1:\n", 66 | " row += 1\n", 67 | "\n", 68 | " #←\n", 69 | " if action == 2:\n", 70 | " col -= 1\n", 71 | "\n", 72 | " #→\n", 73 | " if action == 3:\n", 74 | " col += 1\n", 75 | "\n", 76 | " #不允许走到地图外面去\n", 77 | " row = max(0, row)\n", 78 | " row = min(3, row)\n", 79 | " col = max(0, col)\n", 80 | " col = min(11, col)\n", 81 | "\n", 82 | " #是陷阱的话,奖励是-100,否则都是-1\n", 83 | " reward = -1\n", 84 | " if get_state(row, col) == 'trap':\n", 85 | " reward = -100\n", 86 | "\n", 87 | " return row, col, reward\n", 88 | "\n", 89 | "\n", 90 | "move(0, 0, 3)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "(4, 12, 4)" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "import numpy as np\n", 111 | "\n", 112 | "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n", 113 | "Q = np.zeros([4, 12, 4])\n", 114 | "\n", 115 | "#初始化3个list,用来存储状态,动作,反馈的历史数据,因为后面要回溯这些数据\n", 116 | "state_list = []\n", 117 | "action_list = []\n", 118 | "reward_list = []\n", 119 | "\n", 120 | "Q.shape" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 4, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "0" 132 | ] 133 | }, 134 | "execution_count": 4, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "import random\n", 141 | "\n", 142 | "\n", 143 | "#根据状态选择一个动作\n", 144 | "def get_action(row, col):\n", 145 | " #有小概率选择随机动作\n", 146 | " if random.random() < 0.1:\n", 147 | " return random.choice(range(4))\n", 148 | "\n", 149 | " #否则选择分数最高的动作\n", 150 | " return Q[row, col].argmax()\n", 151 | "\n", 152 | "\n", 153 | "get_action(0, 0)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 5, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "#获取5个时间步分别的分数\n", 163 | "def get_update_list(next_row, next_col, next_action):\n", 164 | " #初始化的target是最后一个state和最后一个action的分数\n", 165 | " target = Q[next_row, next_col, next_action]\n", 166 | "\n", 167 | " #计算每一步的target\n", 168 | " #每一步的tagret等于下一步的tagret*0.9,再加上本步的reward\n", 169 | " #时间从后往前回溯,越以前的tagret会累加的信息越多\n", 170 | " #[4, 3, 2, 1, 0]\n", 171 | " target_list = []\n", 172 | " for i in reversed(range(5)):\n", 173 | " target = 0.9 * target + reward_list[i]\n", 174 | " target_list.append(target)\n", 175 | "\n", 176 | " #把时间顺序正过来\n", 177 | " target_list = list(reversed(target_list))\n", 178 | "\n", 179 | " #计算每一步的value\n", 180 | " value_list = []\n", 181 | " for i in range(5):\n", 182 | " row, col = state_list[i]\n", 183 | " action = action_list[i]\n", 184 | " value_list.append(Q[row, 
col, action])\n", 185 | "\n", 186 | "\n", 187 | " #计算每一步的更新量\n", 188 | " update_list = []\n", 189 | " for i in range(5):\n", 190 | " #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n", 191 | " #此处是求两者的差,越接近0越好\n", 192 | " update = target_list[i] - value_list[i]\n", 193 | "\n", 194 | " #这个0.1相当于lr\n", 195 | " update *= 0.1\n", 196 | "\n", 197 | " update_list.append(update)\n", 198 | "\n", 199 | " return update_list\n", 200 | "\n", 201 | "\n", 202 | "#get_update_list(0, 0, 0)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 6, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "0 -250\n", 215 | "100 -16\n", 216 | "200 -23\n", 217 | "300 -29\n", 218 | "400 -31\n", 219 | "500 -20\n", 220 | "600 -17\n", 221 | "700 -18\n", 222 | "800 -19\n", 223 | "900 -19\n", 224 | "1000 -21\n", 225 | "1100 -39\n", 226 | "1200 -19\n", 227 | "1300 -18\n", 228 | "1400 -26\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "#训练\n", 234 | "def train():\n", 235 | " for epoch in range(1500):\n", 236 | " #初始化当前位置\n", 237 | " row = random.choice(range(4))\n", 238 | " col = 0\n", 239 | "\n", 240 | " #初始化第一个动作\n", 241 | " action = get_action(row, col)\n", 242 | "\n", 243 | " #计算反馈的和,这个数字应该越来越小\n", 244 | " reward_sum = 0\n", 245 | "\n", 246 | " #初始化3个列表\n", 247 | " state_list.clear()\n", 248 | " action_list.clear()\n", 249 | " reward_list.clear()\n", 250 | "\n", 251 | " #循环直到到达终点或者掉进陷阱\n", 252 | " while get_state(row, col) not in ['terminal', 'trap']:\n", 253 | "\n", 254 | " #执行动作\n", 255 | " next_row, next_col, reward = move(row, col, action)\n", 256 | " reward_sum += reward\n", 257 | "\n", 258 | " #求新位置的动作\n", 259 | " next_action = get_action(next_row, next_col)\n", 260 | "\n", 261 | " #记录历史数据\n", 262 | " state_list.append([row, col])\n", 263 | " action_list.append(action)\n", 264 | " reward_list.append(reward)\n", 265 | "\n", 266 | " #积累到5步以后再开始更新参数\n", 267 | " if len(state_list) == 5:\n", 268 | "\n", 269 | " #计算分数\n", 270 | " update_list = get_update_list(next_row, next_col, next_action)\n", 271 | "\n", 272 | " #只更新第一步的分数\n", 273 | " row, col = state_list[0]\n", 274 | " action = action_list[0]\n", 275 | " update = update_list[0]\n", 276 | "\n", 277 | " Q[row, col, action] += update\n", 278 | "\n", 279 | " #移除第一步,这样在下一次循环时保持列表是5个元素\n", 280 | " state_list.pop(0)\n", 281 | " action_list.pop(0)\n", 282 | " reward_list.pop(0)\n", 283 | "\n", 284 | " #更新当前位置\n", 285 | " row = next_row\n", 286 | " col = next_col\n", 287 | " action = next_action\n", 288 | "\n", 289 | " #走到终点以后,更新剩下步数的update\n", 290 | " for i in range(len(state_list)):\n", 291 | " row, col = state_list[i]\n", 292 | " action = action_list[i]\n", 293 | " update = update_list[i]\n", 294 | " Q[row, col, action] += update\n", 295 | "\n", 296 | " if epoch % 100 == 0:\n", 297 | " print(epoch, reward_sum)\n", 298 | "\n", 299 | "\n", 300 | "train()" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 7, 306 | "metadata": {}, 307 | "outputs": [ 308 | { 309 | "name": "stdout", 310 | "output_type": "stream", 311 | "text": [ 312 | "□□□□□□□□□□□□\n", 313 | "□↑□□□□□□□□□□\n", 314 | "□□□□□□□□□□□□\n", 315 | "□○○○○○○○○○○❤\n" 316 | ] 317 | } 318 | ], 319 | "source": [ 320 | "#打印游戏,方便测试\n", 321 | "def show(row, col, action):\n", 322 | " graph = [\n", 323 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 324 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 325 | " '□', '□', '□', '□', 
'□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n", 326 | " '○', '○', '○', '○', '○', '❤'\n", 327 | " ]\n", 328 | "\n", 329 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 330 | "\n", 331 | " graph[row * 12 + col] = action\n", 332 | "\n", 333 | " graph = ''.join(graph)\n", 334 | "\n", 335 | " for i in range(0, 4 * 12, 12):\n", 336 | " print(graph[i:i + 12])\n", 337 | "\n", 338 | "\n", 339 | "show(1, 1, 0)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 8, 345 | "metadata": { 346 | "scrolled": false 347 | }, 348 | "outputs": [ 349 | { 350 | "name": "stdout", 351 | "output_type": "stream", 352 | "text": [ 353 | "□□□□□□□□□□□□\n", 354 | "□□□□□□□□□□□□\n", 355 | "□□□□□□□□□□□↓\n", 356 | "□○○○○○○○○○○❤\n" 357 | ] 358 | } 359 | ], 360 | "source": [ 361 | "from IPython import display\n", 362 | "import time\n", 363 | "\n", 364 | "\n", 365 | "def play():\n", 366 | " #起点\n", 367 | " row = random.choice(range(4))\n", 368 | " col = 0\n", 369 | "\n", 370 | " #最多玩N步\n", 371 | " for _ in range(200):\n", 372 | "\n", 373 | " #获取当前状态,如果状态是终点或者掉陷阱则终止\n", 374 | " if get_state(row, col) in ['trap', 'terminal']:\n", 375 | " break\n", 376 | "\n", 377 | " #选择最优动作\n", 378 | " action = Q[row, col].argmax()\n", 379 | "\n", 380 | " #打印这个动作\n", 381 | " display.clear_output(wait=True)\n", 382 | " time.sleep(0.1)\n", 383 | " show(row, col, action)\n", 384 | "\n", 385 | " #执行动作\n", 386 | " row, col, reward = move(row, col, action)\n", 387 | "\n", 388 | "\n", 389 | "play()" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 9, 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "name": "stdout", 399 | "output_type": "stream", 400 | "text": [ 401 | "→→→→→→→→→↓→↓\n", 402 | "→↑↑→↑↑↑↑←→↑↓\n", 403 | "↑↑→←↑↑↑→↑←→↓\n", 404 | "↑↑↑↑↑↑↑↑↑↑↑↑\n" 405 | ] 406 | } 407 | ], 408 | "source": [ 409 | "#打印所有格子的动作倾向\n", 410 | "for row in range(4):\n", 411 | " line = ''\n", 412 | " for col in range(12):\n", 413 | " action = Q[row, col].argmax()\n", 414 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 415 | " line += action\n", 416 | " print(line)" 417 | ] 418 | } 419 | ], 420 | "metadata": { 421 | "colab": { 422 | "collapsed_sections": [], 423 | "name": "第5章-时序差分算法.ipynb", 424 | "provenance": [] 425 | }, 426 | "kernelspec": { 427 | "display_name": "Python 3", 428 | "language": "python", 429 | "name": "python3" 430 | }, 431 | "language_info": { 432 | "codemirror_mode": { 433 | "name": "ipython", 434 | "version": 3 435 | }, 436 | "file_extension": ".py", 437 | "mimetype": "text/x-python", 438 | "name": "python", 439 | "nbconvert_exporter": "python", 440 | "pygments_lexer": "ipython3", 441 | "version": "3.6.13" 442 | } 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 1 446 | } 447 | -------------------------------------------------------------------------------- /4.时序差分算法/.ipynb_checkpoints/3.QLearning-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'ground'" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "#获取一个格子的状态\n", 21 | "def get_state(row, col):\n", 22 | " if row != 3:\n", 23 | " return 'ground'\n", 24 | "\n", 25 | " if row == 3 and col == 0:\n", 26 | " return 'ground'\n", 27 | "\n", 28 | " if row == 3 and col == 11:\n", 29 | " return 'terminal'\n", 30 | "\n", 31 
| " return 'trap'\n", 32 | "\n", 33 | "\n", 34 | "get_state(0, 0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "(0, 1, -1)" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#在一个格子里做一个动作\n", 55 | "def move(row, col, action):\n", 56 | " #如果当前已经在陷阱或者终点,则不能执行任何动作\n", 57 | " if get_state(row, col) in ['trap', 'terminal']:\n", 58 | " return row, col, 0\n", 59 | "\n", 60 | " #↑\n", 61 | " if action == 0:\n", 62 | " row -= 1\n", 63 | "\n", 64 | " #↓\n", 65 | " if action == 1:\n", 66 | " row += 1\n", 67 | "\n", 68 | " #←\n", 69 | " if action == 2:\n", 70 | " col -= 1\n", 71 | "\n", 72 | " #→\n", 73 | " if action == 3:\n", 74 | " col += 1\n", 75 | "\n", 76 | " #不允许走到地图外面去\n", 77 | " row = max(0, row)\n", 78 | " row = min(3, row)\n", 79 | " col = max(0, col)\n", 80 | " col = min(11, col)\n", 81 | "\n", 82 | " #是陷阱的话,奖励是-100,否则都是-1\n", 83 | " reward = -1\n", 84 | " if get_state(row, col) == 'trap':\n", 85 | " reward = -100\n", 86 | "\n", 87 | " return row, col, reward\n", 88 | "\n", 89 | "\n", 90 | "move(0, 0, 3)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "(4, 12, 4)" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "import numpy as np\n", 111 | "\n", 112 | "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n", 113 | "Q = np.zeros([4, 12, 4])\n", 114 | "\n", 115 | "Q.shape" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "0" 127 | ] 128 | }, 129 | "execution_count": 4, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "import random\n", 136 | "\n", 137 | "\n", 138 | "#根据状态选择一个动作\n", 139 | "def get_action(row, col):\n", 140 | " #有小概率选择随机动作\n", 141 | " if random.random() < 0.1:\n", 142 | " return random.choice(range(4))\n", 143 | "\n", 144 | " #否则选择分数最高的动作\n", 145 | " return Q[row, col].argmax()\n", 146 | "\n", 147 | "\n", 148 | "get_action(0, 0)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "-0.1" 160 | ] 161 | }, 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "def get_update(row, col, action, reward, next_row, next_col):\n", 169 | " #target为下一个格子的最高分数,这里的计算和下一步的动作无关\n", 170 | " target = 0.9 * Q[next_row, next_col].max()\n", 171 | " #加上本步的分数\n", 172 | " target += reward\n", 173 | "\n", 174 | " #value为当前state和action的分数\n", 175 | " value = Q[row, col, action]\n", 176 | "\n", 177 | " #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n", 178 | " #此处是求两者的差,越接近0越好\n", 179 | " update = target - value\n", 180 | "\n", 181 | " #这个0.1相当于lr\n", 182 | " update *= 0.1\n", 183 | "\n", 184 | " return update\n", 185 | "\n", 186 | "\n", 187 | "get_update(0, 0, 3, -1, 0, 1)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 6, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "0 -118\n", 200 | "100 -49\n", 201 | "200 
-31\n", 202 | "300 -20\n", 203 | "400 -109\n", 204 | "500 -102\n", 205 | "600 -12\n", 206 | "700 -13\n", 207 | "800 -18\n", 208 | "900 -13\n", 209 | "1000 -12\n", 210 | "1100 -13\n", 211 | "1200 -15\n", 212 | "1300 -14\n", 213 | "1400 -105\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "#训练\n", 219 | "def train():\n", 220 | " for epoch in range(1500):\n", 221 | " #初始化当前位置\n", 222 | " row = random.choice(range(4))\n", 223 | " col = 0\n", 224 | "\n", 225 | " #初始化第一个动作\n", 226 | " action = get_action(row, col)\n", 227 | "\n", 228 | " #计算反馈的和,这个数字应该越来越小\n", 229 | " reward_sum = 0\n", 230 | "\n", 231 | " #循环直到到达终点或者掉进陷阱\n", 232 | " while get_state(row, col) not in ['terminal', 'trap']:\n", 233 | "\n", 234 | " #执行动作\n", 235 | " next_row, next_col, reward = move(row, col, action)\n", 236 | " reward_sum += reward\n", 237 | "\n", 238 | " #求新位置的动作\n", 239 | " next_action = get_action(next_row, next_col)\n", 240 | "\n", 241 | " #计算分数\n", 242 | " update = get_update(row, col, action, reward, next_row, next_col)\n", 243 | "\n", 244 | " #更新分数\n", 245 | " Q[row, col, action] += update\n", 246 | "\n", 247 | " #更新当前位置\n", 248 | " row = next_row\n", 249 | " col = next_col\n", 250 | " action = next_action\n", 251 | "\n", 252 | " if epoch % 100 == 0:\n", 253 | " print(epoch, reward_sum)\n", 254 | "\n", 255 | "\n", 256 | "train()" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 7, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "□□□□□□□□□□□□\n", 269 | "□↑□□□□□□□□□□\n", 270 | "□□□□□□□□□□□□\n", 271 | "□○○○○○○○○○○❤\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "#打印游戏,方便测试\n", 277 | "def show(row, col, action):\n", 278 | " graph = [\n", 279 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 280 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 281 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n", 282 | " '○', '○', '○', '○', '○', '❤'\n", 283 | " ]\n", 284 | "\n", 285 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 286 | "\n", 287 | " graph[row * 12 + col] = action\n", 288 | "\n", 289 | " graph = ''.join(graph)\n", 290 | "\n", 291 | " for i in range(0, 4 * 12, 12):\n", 292 | " print(graph[i:i + 12])\n", 293 | "\n", 294 | "\n", 295 | "show(1, 1, 0)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 8, 301 | "metadata": { 302 | "scrolled": false 303 | }, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "□□□□□□□□□□□□\n", 310 | "□□□□□□□□□□□□\n", 311 | "□□□□□□□□□□□↓\n", 312 | "□○○○○○○○○○○❤\n" 313 | ] 314 | } 315 | ], 316 | "source": [ 317 | "from IPython import display\n", 318 | "import time\n", 319 | "\n", 320 | "\n", 321 | "def test():\n", 322 | " #起点\n", 323 | " row = random.choice(range(4))\n", 324 | " col = 0\n", 325 | "\n", 326 | " #最多玩N步\n", 327 | " for _ in range(200):\n", 328 | "\n", 329 | " #获取当前状态,如果状态是终点或者掉陷阱则终止\n", 330 | " if get_state(row, col) in ['trap', 'terminal']:\n", 331 | " break\n", 332 | "\n", 333 | " #选择最优动作\n", 334 | " action = Q[row, col].argmax()\n", 335 | "\n", 336 | " #打印这个动作\n", 337 | " display.clear_output(wait=True)\n", 338 | " time.sleep(0.1)\n", 339 | " show(row, col, action)\n", 340 | "\n", 341 | " #执行动作\n", 342 | " row, col, reward = move(row, col, action)\n", 343 | "\n", 344 | "\n", 345 | "test()" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 9, 351 | "metadata": {}, 
352 | "outputs": [ 353 | { 354 | "name": "stdout", 355 | "output_type": "stream", 356 | "text": [ 357 | "→→→→→→↓→→↓→↓\n", 358 | "↓→→↓→→→→→↓→↓\n", 359 | "→→→→→→→→→→→↓\n", 360 | "↑↑↑↑↑↑↑↑↑↑↑↑\n" 361 | ] 362 | } 363 | ], 364 | "source": [ 365 | "#打印所有格子的动作倾向\n", 366 | "for row in range(4):\n", 367 | " line = ''\n", 368 | " for col in range(12):\n", 369 | " action = Q[row, col].argmax()\n", 370 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 371 | " line += action\n", 372 | " print(line)" 373 | ] 374 | } 375 | ], 376 | "metadata": { 377 | "colab": { 378 | "collapsed_sections": [], 379 | "name": "第5章-时序差分算法.ipynb", 380 | "provenance": [] 381 | }, 382 | "kernelspec": { 383 | "display_name": "Python 3", 384 | "language": "python", 385 | "name": "python3" 386 | }, 387 | "language_info": { 388 | "codemirror_mode": { 389 | "name": "ipython", 390 | "version": 3 391 | }, 392 | "file_extension": ".py", 393 | "mimetype": "text/x-python", 394 | "name": "python", 395 | "nbconvert_exporter": "python", 396 | "pygments_lexer": "ipython3", 397 | "version": "3.6.13" 398 | } 399 | }, 400 | "nbformat": 4, 401 | "nbformat_minor": 1 402 | } 403 | -------------------------------------------------------------------------------- /4.时序差分算法/1.Sarsa算法.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'ground'" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "#获取一个格子的状态\n", 21 | "def get_state(row, col):\n", 22 | " if row != 3:\n", 23 | " return 'ground'\n", 24 | "\n", 25 | " if row == 3 and col == 0:\n", 26 | " return 'ground'\n", 27 | "\n", 28 | " if row == 3 and col == 11:\n", 29 | " return 'terminal'\n", 30 | "\n", 31 | " return 'trap'\n", 32 | "\n", 33 | "\n", 34 | "get_state(0, 0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "(0, 1, -1)" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#在一个格子里做一个动作\n", 55 | "def move(row, col, action):\n", 56 | " #如果当前已经在陷阱或者终点,则不能执行任何动作\n", 57 | " if get_state(row, col) in ['trap', 'terminal']:\n", 58 | " return row, col, 0\n", 59 | "\n", 60 | " #↑\n", 61 | " if action == 0:\n", 62 | " row -= 1\n", 63 | "\n", 64 | " #↓\n", 65 | " if action == 1:\n", 66 | " row += 1\n", 67 | "\n", 68 | " #←\n", 69 | " if action == 2:\n", 70 | " col -= 1\n", 71 | "\n", 72 | " #→\n", 73 | " if action == 3:\n", 74 | " col += 1\n", 75 | "\n", 76 | " #不允许走到地图外面去\n", 77 | " row = max(0, row)\n", 78 | " row = min(3, row)\n", 79 | " col = max(0, col)\n", 80 | " col = min(11, col)\n", 81 | "\n", 82 | " #是陷阱的话,奖励是-100,否则都是-1\n", 83 | " reward = -1\n", 84 | " if get_state(row, col) == 'trap':\n", 85 | " reward = -100\n", 86 | "\n", 87 | " return row, col, reward\n", 88 | "\n", 89 | "\n", 90 | "move(0, 0, 3)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "(4, 12, 4)" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "import numpy as np\n", 111 | "\n", 112 | "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n", 113 
| "Q = np.zeros([4, 12, 4])\n", 114 | "\n", 115 | "Q.shape" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "0" 127 | ] 128 | }, 129 | "execution_count": 4, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "import random\n", 136 | "\n", 137 | "\n", 138 | "#根据状态选择一个动作\n", 139 | "def get_action(row, col):\n", 140 | " #有小概率选择随机动作\n", 141 | " if random.random() < 0.1:\n", 142 | " return random.choice(range(4))\n", 143 | "\n", 144 | " #否则选择分数最高的动作\n", 145 | " return Q[row, col].argmax()\n", 146 | "\n", 147 | "\n", 148 | "get_action(0, 0)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "-0.1" 160 | ] 161 | }, 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "#更新分数,每次更新取决于当前的格子,当前的动作,下个格子,和下个格子的动作\n", 169 | "def get_update(row, col, action, reward, next_row, next_col, next_action):\n", 170 | "\n", 171 | " #计算target\n", 172 | " target = 0.9 * Q[next_row, next_col, next_action]\n", 173 | " target += reward\n", 174 | "\n", 175 | " #计算value\n", 176 | " value = Q[row, col, action]\n", 177 | "\n", 178 | " #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n", 179 | " #此处是求两者的差,越接近0越好\n", 180 | " update = target - value\n", 181 | "\n", 182 | " #这个0.1相当于lr\n", 183 | " update *= 0.1\n", 184 | "\n", 185 | " #更新当前状态和动作的分数\n", 186 | " return update\n", 187 | "\n", 188 | "\n", 189 | "#在0,0向右走,得到-1,到达0,1,再次执行向右走\n", 190 | "get_update(0, 0, 3, -1, 0, 1, 3)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 6, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "0 -116\n", 203 | "150 -25\n", 204 | "300 -20\n", 205 | "450 -18\n", 206 | "600 -17\n", 207 | "750 -15\n", 208 | "900 -18\n", 209 | "1050 -18\n", 210 | "1200 -18\n", 211 | "1350 -16\n" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "#训练\n", 217 | "def train():\n", 218 | " for epoch in range(1500):\n", 219 | " #初始化当前位置\n", 220 | " row = random.choice(range(4))\n", 221 | " col = 0\n", 222 | "\n", 223 | " #初始化第一个动作\n", 224 | " action = get_action(row, col)\n", 225 | "\n", 226 | " #计算反馈的和,这个数字应该越来越小\n", 227 | " reward_sum = 0\n", 228 | "\n", 229 | " #循环直到到达终点或者掉进陷阱\n", 230 | " while get_state(row, col) not in ['terminal', 'trap']:\n", 231 | "\n", 232 | " #执行动作\n", 233 | " next_row, next_col, reward = move(row, col, action)\n", 234 | " reward_sum += reward\n", 235 | "\n", 236 | " #求新位置的动作\n", 237 | " next_action = get_action(next_row, next_col)\n", 238 | "\n", 239 | " #更新分数\n", 240 | " update = get_update(row, col, action, reward, next_row, next_col,\n", 241 | " next_action)\n", 242 | " Q[row, col, action] += update\n", 243 | "\n", 244 | " #更新当前位置\n", 245 | " row = next_row\n", 246 | " col = next_col\n", 247 | " action = next_action\n", 248 | "\n", 249 | " if epoch % 150 == 0:\n", 250 | " print(epoch, reward_sum)\n", 251 | "\n", 252 | "\n", 253 | "train()" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 7, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "□□□□□□□□□□□□\n", 266 | "□↑□□□□□□□□□□\n", 267 | "□□□□□□□□□□□□\n", 268 | "□○○○○○○○○○○❤\n" 269 | ] 270 | } 271 | ], 
272 | "source": [ 273 | "#打印游戏,方便测试\n", 274 | "def show(row, col, action):\n", 275 | " graph = [\n", 276 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 277 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 278 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n", 279 | " '○', '○', '○', '○', '○', '❤'\n", 280 | " ]\n", 281 | "\n", 282 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 283 | "\n", 284 | " graph[row * 12 + col] = action\n", 285 | "\n", 286 | " graph = ''.join(graph)\n", 287 | "\n", 288 | " for i in range(0, 4 * 12, 12):\n", 289 | " print(graph[i:i + 12])\n", 290 | "\n", 291 | "\n", 292 | "show(1, 1, 0)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 8, 298 | "metadata": { 299 | "scrolled": true 300 | }, 301 | "outputs": [ 302 | { 303 | "name": "stdout", 304 | "output_type": "stream", 305 | "text": [ 306 | "□□□□□□□□□□□□\n", 307 | "□□□□□□□□□□□□\n", 308 | "□□□□□□□□□□□↓\n", 309 | "□○○○○○○○○○○❤\n" 310 | ] 311 | } 312 | ], 313 | "source": [ 314 | "from IPython import display\n", 315 | "import time\n", 316 | "\n", 317 | "\n", 318 | "def test():\n", 319 | " #起点\n", 320 | " row = random.choice(range(4))\n", 321 | " col = 0\n", 322 | "\n", 323 | " #最多玩N步\n", 324 | " for _ in range(200):\n", 325 | "\n", 326 | " #获取当前状态,如果状态是终点或者掉陷阱则终止\n", 327 | " if get_state(row, col) in ['trap', 'terminal']:\n", 328 | " break\n", 329 | "\n", 330 | " #选择最优动作\n", 331 | " action = Q[row, col].argmax()\n", 332 | "\n", 333 | " #打印这个动作\n", 334 | " display.clear_output(wait=True)\n", 335 | " time.sleep(0.1)\n", 336 | " show(row, col, action)\n", 337 | "\n", 338 | " #执行动作\n", 339 | " row, col, reward = move(row, col, action)\n", 340 | "\n", 341 | "\n", 342 | "test()" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 9, 348 | "metadata": {}, 349 | "outputs": [ 350 | { 351 | "name": "stdout", 352 | "output_type": "stream", 353 | "text": [ 354 | "→→→→→→→→→→→↓\n", 355 | "→→↑→→→→→→→→↓\n", 356 | "↑↑↑↑←↑↑↑↑↑→↓\n", 357 | "↑↑↑↑↑↑↑↑↑↑↑↑\n" 358 | ] 359 | } 360 | ], 361 | "source": [ 362 | "#打印所有格子的动作倾向\n", 363 | "for row in range(4):\n", 364 | " line = ''\n", 365 | " for col in range(12):\n", 366 | " action = Q[row, col].argmax()\n", 367 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 368 | " line += action\n", 369 | " print(line)" 370 | ] 371 | } 372 | ], 373 | "metadata": { 374 | "colab": { 375 | "collapsed_sections": [], 376 | "name": "第5章-时序差分算法.ipynb", 377 | "provenance": [] 378 | }, 379 | "kernelspec": { 380 | "display_name": "Python 3", 381 | "language": "python", 382 | "name": "python3" 383 | }, 384 | "language_info": { 385 | "codemirror_mode": { 386 | "name": "ipython", 387 | "version": 3 388 | }, 389 | "file_extension": ".py", 390 | "mimetype": "text/x-python", 391 | "name": "python", 392 | "nbconvert_exporter": "python", 393 | "pygments_lexer": "ipython3", 394 | "version": "3.6.13" 395 | } 396 | }, 397 | "nbformat": 4, 398 | "nbformat_minor": 1 399 | } 400 | -------------------------------------------------------------------------------- /4.时序差分算法/2.N步Sarsa算法.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'ground'" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | 
"#获取一个格子的状态\n", 21 | "def get_state(row, col):\n", 22 | " if row != 3:\n", 23 | " return 'ground'\n", 24 | "\n", 25 | " if row == 3 and col == 0:\n", 26 | " return 'ground'\n", 27 | "\n", 28 | " if row == 3 and col == 11:\n", 29 | " return 'terminal'\n", 30 | "\n", 31 | " return 'trap'\n", 32 | "\n", 33 | "\n", 34 | "get_state(0, 0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "(0, 1, -1)" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#在一个格子里做一个动作\n", 55 | "def move(row, col, action):\n", 56 | " #如果当前已经在陷阱或者终点,则不能执行任何动作\n", 57 | " if get_state(row, col) in ['trap', 'terminal']:\n", 58 | " return row, col, 0\n", 59 | "\n", 60 | " #↑\n", 61 | " if action == 0:\n", 62 | " row -= 1\n", 63 | "\n", 64 | " #↓\n", 65 | " if action == 1:\n", 66 | " row += 1\n", 67 | "\n", 68 | " #←\n", 69 | " if action == 2:\n", 70 | " col -= 1\n", 71 | "\n", 72 | " #→\n", 73 | " if action == 3:\n", 74 | " col += 1\n", 75 | "\n", 76 | " #不允许走到地图外面去\n", 77 | " row = max(0, row)\n", 78 | " row = min(3, row)\n", 79 | " col = max(0, col)\n", 80 | " col = min(11, col)\n", 81 | "\n", 82 | " #是陷阱的话,奖励是-100,否则都是-1\n", 83 | " reward = -1\n", 84 | " if get_state(row, col) == 'trap':\n", 85 | " reward = -100\n", 86 | "\n", 87 | " return row, col, reward\n", 88 | "\n", 89 | "\n", 90 | "move(0, 0, 3)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "(4, 12, 4)" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "import numpy as np\n", 111 | "\n", 112 | "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n", 113 | "Q = np.zeros([4, 12, 4])\n", 114 | "\n", 115 | "#初始化3个list,用来存储状态,动作,反馈的历史数据,因为后面要回溯这些数据\n", 116 | "state_list = []\n", 117 | "action_list = []\n", 118 | "reward_list = []\n", 119 | "\n", 120 | "Q.shape" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 4, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "0" 132 | ] 133 | }, 134 | "execution_count": 4, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "import random\n", 141 | "\n", 142 | "\n", 143 | "#根据状态选择一个动作\n", 144 | "def get_action(row, col):\n", 145 | " #有小概率选择随机动作\n", 146 | " if random.random() < 0.1:\n", 147 | " return random.choice(range(4))\n", 148 | "\n", 149 | " #否则选择分数最高的动作\n", 150 | " return Q[row, col].argmax()\n", 151 | "\n", 152 | "\n", 153 | "get_action(0, 0)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 5, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "#获取5个时间步分别的分数\n", 163 | "def get_update_list(next_row, next_col, next_action):\n", 164 | " #初始化的target是最后一个state和最后一个action的分数\n", 165 | " target = Q[next_row, next_col, next_action]\n", 166 | "\n", 167 | " #计算每一步的target\n", 168 | " #每一步的tagret等于下一步的tagret*0.9,再加上本步的reward\n", 169 | " #时间从后往前回溯,越以前的tagret会累加的信息越多\n", 170 | " #[4, 3, 2, 1, 0]\n", 171 | " target_list = []\n", 172 | " for i in reversed(range(5)):\n", 173 | " target = 0.9 * target + reward_list[i]\n", 174 | " target_list.append(target)\n", 175 | "\n", 176 | " #把时间顺序正过来\n", 177 | " target_list = list(reversed(target_list))\n", 178 | "\n", 179 | " #计算每一步的value\n", 180 
| " value_list = []\n", 181 | " for i in range(5):\n", 182 | " row, col = state_list[i]\n", 183 | " action = action_list[i]\n", 184 | " value_list.append(Q[row, col, action])\n", 185 | "\n", 186 | "\n", 187 | " #计算每一步的更新量\n", 188 | " update_list = []\n", 189 | " for i in range(5):\n", 190 | " #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n", 191 | " #此处是求两者的差,越接近0越好\n", 192 | " update = target_list[i] - value_list[i]\n", 193 | "\n", 194 | " #这个0.1相当于lr\n", 195 | " update *= 0.1\n", 196 | "\n", 197 | " update_list.append(update)\n", 198 | "\n", 199 | " return update_list\n", 200 | "\n", 201 | "\n", 202 | "#get_update_list(0, 0, 0)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 6, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "0 -250\n", 215 | "100 -16\n", 216 | "200 -23\n", 217 | "300 -29\n", 218 | "400 -31\n", 219 | "500 -20\n", 220 | "600 -17\n", 221 | "700 -18\n", 222 | "800 -19\n", 223 | "900 -19\n", 224 | "1000 -21\n", 225 | "1100 -39\n", 226 | "1200 -19\n", 227 | "1300 -18\n", 228 | "1400 -26\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "#训练\n", 234 | "def train():\n", 235 | " for epoch in range(1500):\n", 236 | " #初始化当前位置\n", 237 | " row = random.choice(range(4))\n", 238 | " col = 0\n", 239 | "\n", 240 | " #初始化第一个动作\n", 241 | " action = get_action(row, col)\n", 242 | "\n", 243 | " #计算反馈的和,这个数字应该越来越小\n", 244 | " reward_sum = 0\n", 245 | "\n", 246 | " #初始化3个列表\n", 247 | " state_list.clear()\n", 248 | " action_list.clear()\n", 249 | " reward_list.clear()\n", 250 | "\n", 251 | " #循环直到到达终点或者掉进陷阱\n", 252 | " while get_state(row, col) not in ['terminal', 'trap']:\n", 253 | "\n", 254 | " #执行动作\n", 255 | " next_row, next_col, reward = move(row, col, action)\n", 256 | " reward_sum += reward\n", 257 | "\n", 258 | " #求新位置的动作\n", 259 | " next_action = get_action(next_row, next_col)\n", 260 | "\n", 261 | " #记录历史数据\n", 262 | " state_list.append([row, col])\n", 263 | " action_list.append(action)\n", 264 | " reward_list.append(reward)\n", 265 | "\n", 266 | " #积累到5步以后再开始更新参数\n", 267 | " if len(state_list) == 5:\n", 268 | "\n", 269 | " #计算分数\n", 270 | " update_list = get_update_list(next_row, next_col, next_action)\n", 271 | "\n", 272 | " #只更新第一步的分数\n", 273 | " row, col = state_list[0]\n", 274 | " action = action_list[0]\n", 275 | " update = update_list[0]\n", 276 | "\n", 277 | " Q[row, col, action] += update\n", 278 | "\n", 279 | " #移除第一步,这样在下一次循环时保持列表是5个元素\n", 280 | " state_list.pop(0)\n", 281 | " action_list.pop(0)\n", 282 | " reward_list.pop(0)\n", 283 | "\n", 284 | " #更新当前位置\n", 285 | " row = next_row\n", 286 | " col = next_col\n", 287 | " action = next_action\n", 288 | "\n", 289 | " #走到终点以后,更新剩下步数的update\n", 290 | " for i in range(len(state_list)):\n", 291 | " row, col = state_list[i]\n", 292 | " action = action_list[i]\n", 293 | " update = update_list[i]\n", 294 | " Q[row, col, action] += update\n", 295 | "\n", 296 | " if epoch % 100 == 0:\n", 297 | " print(epoch, reward_sum)\n", 298 | "\n", 299 | "\n", 300 | "train()" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 7, 306 | "metadata": {}, 307 | "outputs": [ 308 | { 309 | "name": "stdout", 310 | "output_type": "stream", 311 | "text": [ 312 | "□□□□□□□□□□□□\n", 313 | "□↑□□□□□□□□□□\n", 314 | "□□□□□□□□□□□□\n", 315 | "□○○○○○○○○○○❤\n" 316 | ] 317 | } 318 | ], 319 | "source": [ 320 | "#打印游戏,方便测试\n", 321 | "def show(row, col, action):\n", 322 | " graph = [\n", 323 | " '□', '□', '□', '□', 
'□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 324 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 325 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n", 326 | " '○', '○', '○', '○', '○', '❤'\n", 327 | " ]\n", 328 | "\n", 329 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 330 | "\n", 331 | " graph[row * 12 + col] = action\n", 332 | "\n", 333 | " graph = ''.join(graph)\n", 334 | "\n", 335 | " for i in range(0, 4 * 12, 12):\n", 336 | " print(graph[i:i + 12])\n", 337 | "\n", 338 | "\n", 339 | "show(1, 1, 0)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 8, 345 | "metadata": { 346 | "scrolled": false 347 | }, 348 | "outputs": [ 349 | { 350 | "name": "stdout", 351 | "output_type": "stream", 352 | "text": [ 353 | "□□□□□□□□□□□□\n", 354 | "□□□□□□□□□□□□\n", 355 | "□□□□□□□□□□□↓\n", 356 | "□○○○○○○○○○○❤\n" 357 | ] 358 | } 359 | ], 360 | "source": [ 361 | "from IPython import display\n", 362 | "import time\n", 363 | "\n", 364 | "\n", 365 | "def play():\n", 366 | " #起点\n", 367 | " row = random.choice(range(4))\n", 368 | " col = 0\n", 369 | "\n", 370 | " #最多玩N步\n", 371 | " for _ in range(200):\n", 372 | "\n", 373 | " #获取当前状态,如果状态是终点或者掉陷阱则终止\n", 374 | " if get_state(row, col) in ['trap', 'terminal']:\n", 375 | " break\n", 376 | "\n", 377 | " #选择最优动作\n", 378 | " action = Q[row, col].argmax()\n", 379 | "\n", 380 | " #打印这个动作\n", 381 | " display.clear_output(wait=True)\n", 382 | " time.sleep(0.1)\n", 383 | " show(row, col, action)\n", 384 | "\n", 385 | " #执行动作\n", 386 | " row, col, reward = move(row, col, action)\n", 387 | "\n", 388 | "\n", 389 | "play()" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 9, 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "name": "stdout", 399 | "output_type": "stream", 400 | "text": [ 401 | "→→→→→→→→→↓→↓\n", 402 | "→↑↑→↑↑↑↑←→↑↓\n", 403 | "↑↑→←↑↑↑→↑←→↓\n", 404 | "↑↑↑↑↑↑↑↑↑↑↑↑\n" 405 | ] 406 | } 407 | ], 408 | "source": [ 409 | "#打印所有格子的动作倾向\n", 410 | "for row in range(4):\n", 411 | " line = ''\n", 412 | " for col in range(12):\n", 413 | " action = Q[row, col].argmax()\n", 414 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 415 | " line += action\n", 416 | " print(line)" 417 | ] 418 | } 419 | ], 420 | "metadata": { 421 | "colab": { 422 | "collapsed_sections": [], 423 | "name": "第5章-时序差分算法.ipynb", 424 | "provenance": [] 425 | }, 426 | "kernelspec": { 427 | "display_name": "Python 3", 428 | "language": "python", 429 | "name": "python3" 430 | }, 431 | "language_info": { 432 | "codemirror_mode": { 433 | "name": "ipython", 434 | "version": 3 435 | }, 436 | "file_extension": ".py", 437 | "mimetype": "text/x-python", 438 | "name": "python", 439 | "nbconvert_exporter": "python", 440 | "pygments_lexer": "ipython3", 441 | "version": "3.6.13" 442 | } 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 1 446 | } 447 | -------------------------------------------------------------------------------- /4.时序差分算法/3.QLearning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'ground'" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "#获取一个格子的状态\n", 21 | "def get_state(row, col):\n", 22 | " if row != 3:\n", 23 | " return 'ground'\n", 24 | "\n", 25 | " if row == 3 and 
col == 0:\n", 26 | " return 'ground'\n", 27 | "\n", 28 | " if row == 3 and col == 11:\n", 29 | " return 'terminal'\n", 30 | "\n", 31 | " return 'trap'\n", 32 | "\n", 33 | "\n", 34 | "get_state(0, 0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "(0, 1, -1)" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#在一个格子里做一个动作\n", 55 | "def move(row, col, action):\n", 56 | " #如果当前已经在陷阱或者终点,则不能执行任何动作\n", 57 | " if get_state(row, col) in ['trap', 'terminal']:\n", 58 | " return row, col, 0\n", 59 | "\n", 60 | " #↑\n", 61 | " if action == 0:\n", 62 | " row -= 1\n", 63 | "\n", 64 | " #↓\n", 65 | " if action == 1:\n", 66 | " row += 1\n", 67 | "\n", 68 | " #←\n", 69 | " if action == 2:\n", 70 | " col -= 1\n", 71 | "\n", 72 | " #→\n", 73 | " if action == 3:\n", 74 | " col += 1\n", 75 | "\n", 76 | " #不允许走到地图外面去\n", 77 | " row = max(0, row)\n", 78 | " row = min(3, row)\n", 79 | " col = max(0, col)\n", 80 | " col = min(11, col)\n", 81 | "\n", 82 | " #是陷阱的话,奖励是-100,否则都是-1\n", 83 | " reward = -1\n", 84 | " if get_state(row, col) == 'trap':\n", 85 | " reward = -100\n", 86 | "\n", 87 | " return row, col, reward\n", 88 | "\n", 89 | "\n", 90 | "move(0, 0, 3)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "(4, 12, 4)" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "import numpy as np\n", 111 | "\n", 112 | "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n", 113 | "Q = np.zeros([4, 12, 4])\n", 114 | "\n", 115 | "Q.shape" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "0" 127 | ] 128 | }, 129 | "execution_count": 4, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "import random\n", 136 | "\n", 137 | "\n", 138 | "#根据状态选择一个动作\n", 139 | "def get_action(row, col):\n", 140 | " #有小概率选择随机动作\n", 141 | " if random.random() < 0.1:\n", 142 | " return random.choice(range(4))\n", 143 | "\n", 144 | " #否则选择分数最高的动作\n", 145 | " return Q[row, col].argmax()\n", 146 | "\n", 147 | "\n", 148 | "get_action(0, 0)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "-0.1" 160 | ] 161 | }, 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "def get_update(row, col, action, reward, next_row, next_col):\n", 169 | " #target为下一个格子的最高分数,这里的计算和下一步的动作无关\n", 170 | " target = 0.9 * Q[next_row, next_col].max()\n", 171 | " #加上本步的分数\n", 172 | " target += reward\n", 173 | "\n", 174 | " #value为当前state和action的分数\n", 175 | " value = Q[row, col, action]\n", 176 | "\n", 177 | " #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n", 178 | " #此处是求两者的差,越接近0越好\n", 179 | " update = target - value\n", 180 | "\n", 181 | " #这个0.1相当于lr\n", 182 | " update *= 0.1\n", 183 | "\n", 184 | " return update\n", 185 | "\n", 186 | "\n", 187 | "get_update(0, 0, 3, -1, 0, 1)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 6, 193 | "metadata": {}, 194 | "outputs": 
[ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "0 -118\n", 200 | "100 -49\n", 201 | "200 -31\n", 202 | "300 -20\n", 203 | "400 -109\n", 204 | "500 -102\n", 205 | "600 -12\n", 206 | "700 -13\n", 207 | "800 -18\n", 208 | "900 -13\n", 209 | "1000 -12\n", 210 | "1100 -13\n", 211 | "1200 -15\n", 212 | "1300 -14\n", 213 | "1400 -105\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "#训练\n", 219 | "def train():\n", 220 | " for epoch in range(1500):\n", 221 | " #初始化当前位置\n", 222 | " row = random.choice(range(4))\n", 223 | " col = 0\n", 224 | "\n", 225 | " #初始化第一个动作\n", 226 | " action = get_action(row, col)\n", 227 | "\n", 228 | " #计算反馈的和,这个数字应该越来越小\n", 229 | " reward_sum = 0\n", 230 | "\n", 231 | " #循环直到到达终点或者掉进陷阱\n", 232 | " while get_state(row, col) not in ['terminal', 'trap']:\n", 233 | "\n", 234 | " #执行动作\n", 235 | " next_row, next_col, reward = move(row, col, action)\n", 236 | " reward_sum += reward\n", 237 | "\n", 238 | " #求新位置的动作\n", 239 | " next_action = get_action(next_row, next_col)\n", 240 | "\n", 241 | " #计算分数\n", 242 | " update = get_update(row, col, action, reward, next_row, next_col)\n", 243 | "\n", 244 | " #更新分数\n", 245 | " Q[row, col, action] += update\n", 246 | "\n", 247 | " #更新当前位置\n", 248 | " row = next_row\n", 249 | " col = next_col\n", 250 | " action = next_action\n", 251 | "\n", 252 | " if epoch % 100 == 0:\n", 253 | " print(epoch, reward_sum)\n", 254 | "\n", 255 | "\n", 256 | "train()" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 7, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "□□□□□□□□□□□□\n", 269 | "□↑□□□□□□□□□□\n", 270 | "□□□□□□□□□□□□\n", 271 | "□○○○○○○○○○○❤\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "#打印游戏,方便测试\n", 277 | "def show(row, col, action):\n", 278 | " graph = [\n", 279 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 280 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 281 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n", 282 | " '○', '○', '○', '○', '○', '❤'\n", 283 | " ]\n", 284 | "\n", 285 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 286 | "\n", 287 | " graph[row * 12 + col] = action\n", 288 | "\n", 289 | " graph = ''.join(graph)\n", 290 | "\n", 291 | " for i in range(0, 4 * 12, 12):\n", 292 | " print(graph[i:i + 12])\n", 293 | "\n", 294 | "\n", 295 | "show(1, 1, 0)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 8, 301 | "metadata": { 302 | "scrolled": false 303 | }, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "□□□□□□□□□□□□\n", 310 | "□□□□□□□□□□□□\n", 311 | "□□□□□□□□□□□↓\n", 312 | "□○○○○○○○○○○❤\n" 313 | ] 314 | } 315 | ], 316 | "source": [ 317 | "from IPython import display\n", 318 | "import time\n", 319 | "\n", 320 | "\n", 321 | "def test():\n", 322 | " #起点\n", 323 | " row = random.choice(range(4))\n", 324 | " col = 0\n", 325 | "\n", 326 | " #最多玩N步\n", 327 | " for _ in range(200):\n", 328 | "\n", 329 | " #获取当前状态,如果状态是终点或者掉陷阱则终止\n", 330 | " if get_state(row, col) in ['trap', 'terminal']:\n", 331 | " break\n", 332 | "\n", 333 | " #选择最优动作\n", 334 | " action = Q[row, col].argmax()\n", 335 | "\n", 336 | " #打印这个动作\n", 337 | " display.clear_output(wait=True)\n", 338 | " time.sleep(0.1)\n", 339 | " show(row, col, action)\n", 340 | "\n", 341 | " #执行动作\n", 342 | " row, col, reward = move(row, col, action)\n", 343 | "\n", 
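# --- Editor's note (illustrative sketch, not part of the original notebook) ---
# Q-Learning differs from the Sarsa notebook in one line of get_update: the
# target bootstraps from the greedy value max_a' Q(s', a') rather than from the
# action actually taken, which makes it off-policy.  That is also why the final
# policy printed below hugs the row just above the cliff, while the epsilon-greedy
# behaviour still falls into a trap occasionally during training (the -100s in
# the log above).  Function and argument names below are the editor's own.
import numpy as np

def q_learning_update(Q, row, col, action, reward, next_row, next_col,
                      gamma=0.9, alpha=0.1):
    td_target = reward + gamma * Q[next_row, next_col].max()  # greedy bootstrap
    td_error = td_target - Q[row, col, action]
    Q[row, col, action] += alpha * td_error
    return alpha * td_error

# Q = np.zeros([4, 12, 4])
# q_learning_update(Q, 0, 0, 3, -1, 0, 1)   # -> -0.1, matching get_update(0, 0, 3, -1, 0, 1)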
344 | "\n", 345 | "test()" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 9, 351 | "metadata": {}, 352 | "outputs": [ 353 | { 354 | "name": "stdout", 355 | "output_type": "stream", 356 | "text": [ 357 | "→→→→→→↓→→↓→↓\n", 358 | "↓→→↓→→→→→↓→↓\n", 359 | "→→→→→→→→→→→↓\n", 360 | "↑↑↑↑↑↑↑↑↑↑↑↑\n" 361 | ] 362 | } 363 | ], 364 | "source": [ 365 | "#打印所有格子的动作倾向\n", 366 | "for row in range(4):\n", 367 | " line = ''\n", 368 | " for col in range(12):\n", 369 | " action = Q[row, col].argmax()\n", 370 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 371 | " line += action\n", 372 | " print(line)" 373 | ] 374 | } 375 | ], 376 | "metadata": { 377 | "colab": { 378 | "collapsed_sections": [], 379 | "name": "第5章-时序差分算法.ipynb", 380 | "provenance": [] 381 | }, 382 | "kernelspec": { 383 | "display_name": "Python 3", 384 | "language": "python", 385 | "name": "python3" 386 | }, 387 | "language_info": { 388 | "codemirror_mode": { 389 | "name": "ipython", 390 | "version": 3 391 | }, 392 | "file_extension": ".py", 393 | "mimetype": "text/x-python", 394 | "name": "python", 395 | "nbconvert_exporter": "python", 396 | "pygments_lexer": "ipython3", 397 | "version": "3.6.13" 398 | } 399 | }, 400 | "nbformat": 4, 401 | "nbformat_minor": 1 402 | } 403 | -------------------------------------------------------------------------------- /5.DynaQ算法/.ipynb_checkpoints/1.DynaQ-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'ground'" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "#获取一个格子的状态\n", 21 | "def get_state(row, col):\n", 22 | " if row != 3:\n", 23 | " return 'ground'\n", 24 | "\n", 25 | " if row == 3 and col == 0:\n", 26 | " return 'ground'\n", 27 | "\n", 28 | " if row == 3 and col == 11:\n", 29 | " return 'terminal'\n", 30 | "\n", 31 | " return 'trap'\n", 32 | "\n", 33 | "\n", 34 | "get_state(0, 0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "(0, 1, -1)" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#在一个格子里做一个动作\n", 55 | "def move(row, col, action):\n", 56 | " #如果当前已经在陷阱或者终点,则不能执行任何动作\n", 57 | " if get_state(row, col) in ['trap', 'terminal']:\n", 58 | " return row, col, 0\n", 59 | "\n", 60 | " #↑\n", 61 | " if action == 0:\n", 62 | " row -= 1\n", 63 | "\n", 64 | " #↓\n", 65 | " if action == 1:\n", 66 | " row += 1\n", 67 | "\n", 68 | " #←\n", 69 | " if action == 2:\n", 70 | " col -= 1\n", 71 | "\n", 72 | " #→\n", 73 | " if action == 3:\n", 74 | " col += 1\n", 75 | "\n", 76 | " #不允许走到地图外面去\n", 77 | " row = max(0, row)\n", 78 | " row = min(3, row)\n", 79 | " col = max(0, col)\n", 80 | " col = min(11, col)\n", 81 | "\n", 82 | " #是陷阱的话,奖励是-100,否则都是-1\n", 83 | " reward = -1\n", 84 | " if get_state(row, col) == 'trap':\n", 85 | " reward = -100\n", 86 | "\n", 87 | " return row, col, reward\n", 88 | "\n", 89 | "\n", 90 | "move(0, 0, 3)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "((4, 12, 4), {})" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 
| "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "import numpy as np\n", 111 | "\n", 112 | "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n", 113 | "Q = np.zeros([4, 12, 4])\n", 114 | "\n", 115 | "#保存历史数据,键是(row,col,action),值是(next_row,next_col,reward)\n", 116 | "history = dict()\n", 117 | "\n", 118 | "Q.shape, history" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 4, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "0" 130 | ] 131 | }, 132 | "execution_count": 4, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "import random\n", 139 | "\n", 140 | "\n", 141 | "#根据状态选择一个动作\n", 142 | "def get_action(row, col):\n", 143 | " #有小概率选择随机动作\n", 144 | " if random.random() < 0.1:\n", 145 | " return random.choice(range(4))\n", 146 | "\n", 147 | " #否则选择分数最高的动作\n", 148 | " return Q[row, col].argmax()\n", 149 | "\n", 150 | "\n", 151 | "get_action(0, 0)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 5, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "text/plain": [ 162 | "-0.1" 163 | ] 164 | }, 165 | "execution_count": 5, 166 | "metadata": {}, 167 | "output_type": "execute_result" 168 | } 169 | ], 170 | "source": [ 171 | "def get_update(row, col, action, reward, next_row, next_col):\n", 172 | " #target为下一个格子的最高分数,这里的计算和下一步的动作无关\n", 173 | " target = 0.9 * Q[next_row, next_col].max()\n", 174 | " #加上本步的分数\n", 175 | " target += reward\n", 176 | "\n", 177 | " #计算value\n", 178 | " value = Q[row, col, action]\n", 179 | "\n", 180 | " #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n", 181 | " #此处是求两者的差,越接近0越好\n", 182 | " update = target - value\n", 183 | "\n", 184 | " #这个0.1相当于lr\n", 185 | " update *= 0.1\n", 186 | "\n", 187 | " return update\n", 188 | "\n", 189 | "\n", 190 | "get_update(0, 0, 3, -1, 0, 1)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 6, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "import random\n", 200 | "\n", 201 | "\n", 202 | "def q_planning():\n", 203 | " #Q planning循环,相当于是在反刍历史数据,随机取N个历史数据再进行离线学习\n", 204 | " for _ in range(20):\n", 205 | " #随机选择曾经遇到过的状态动作对\n", 206 | " row, col, action = random.choice(list(history.keys()))\n", 207 | "\n", 208 | " #再获取下一个状态和反馈\n", 209 | " next_row, next_col, reward = history[(row, col, action)]\n", 210 | "\n", 211 | " #计算分数\n", 212 | " update = get_update(row, col, action, reward, next_row, next_col)\n", 213 | "\n", 214 | " #更新分数\n", 215 | " Q[row, col, action] += update\n", 216 | "\n", 217 | "\n", 218 | "#q_planning()" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 7, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "0 -129\n", 231 | "20 -108\n", 232 | "40 -16\n", 233 | "60 -12\n", 234 | "80 -15\n", 235 | "100 -13\n", 236 | "120 -12\n", 237 | "140 -15\n", 238 | "160 -13\n", 239 | "180 -15\n", 240 | "200 -14\n", 241 | "220 -15\n", 242 | "240 -14\n", 243 | "260 -12\n", 244 | "280 -14\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "#训练\n", 250 | "def train():\n", 251 | " for epoch in range(300):\n", 252 | " #初始化当前位置\n", 253 | " row = random.choice(range(4))\n", 254 | " col = 0\n", 255 | "\n", 256 | " #初始化第一个动作\n", 257 | " action = get_action(row, col)\n", 258 | "\n", 259 | " #计算反馈的和,这个数字应该越来越小\n", 260 | " reward_sum = 0\n", 261 | "\n", 262 | " #循环直到到达终点或者掉进陷阱\n", 263 | " while 
get_state(row, col) not in ['terminal', 'trap']:\n", 264 | "\n", 265 | " #执行动作\n", 266 | " next_row, next_col, reward = move(row, col, action)\n", 267 | " reward_sum += reward\n", 268 | "\n", 269 | " #求新位置的动作\n", 270 | " next_action = get_action(next_row, next_col)\n", 271 | "\n", 272 | " #计算分数\n", 273 | " update = get_update(row, col, action, reward, next_row, next_col)\n", 274 | "\n", 275 | " #更新分数\n", 276 | " Q[row, col, action] += update\n", 277 | "\n", 278 | " #将数据添加到模型中\n", 279 | " history[(row, col, action)] = next_row, next_col, reward\n", 280 | "\n", 281 | " #反刍历史数据,进行离线学习\n", 282 | " q_planning()\n", 283 | "\n", 284 | " #更新当前位置\n", 285 | " row = next_row\n", 286 | " col = next_col\n", 287 | " action = next_action\n", 288 | "\n", 289 | " if epoch % 20 == 0:\n", 290 | " print(epoch, reward_sum)\n", 291 | "\n", 292 | "\n", 293 | "train()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 8, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "□□□□□□□□□□□□\n", 306 | "□↑□□□□□□□□□□\n", 307 | "□□□□□□□□□□□□\n", 308 | "□○○○○○○○○○○❤\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "#打印游戏,方便测试\n", 314 | "def show(row, col, action):\n", 315 | " graph = [\n", 316 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 317 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 318 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n", 319 | " '○', '○', '○', '○', '○', '❤'\n", 320 | " ]\n", 321 | "\n", 322 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 323 | "\n", 324 | " graph[row * 12 + col] = action\n", 325 | "\n", 326 | " graph = ''.join(graph)\n", 327 | "\n", 328 | " for i in range(0, 4 * 12, 12):\n", 329 | " print(graph[i:i + 12])\n", 330 | "\n", 331 | "\n", 332 | "show(1, 1, 0)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 9, 338 | "metadata": { 339 | "scrolled": false 340 | }, 341 | "outputs": [ 342 | { 343 | "name": "stdout", 344 | "output_type": "stream", 345 | "text": [ 346 | "□□□□□□□□□□□□\n", 347 | "□□□□□□□□□□□□\n", 348 | "□□□□□□□□□□□↓\n", 349 | "□○○○○○○○○○○❤\n" 350 | ] 351 | } 352 | ], 353 | "source": [ 354 | "from IPython import display\n", 355 | "import time\n", 356 | "\n", 357 | "\n", 358 | "def test():\n", 359 | " #起点\n", 360 | " row = random.choice(range(4))\n", 361 | " col = 0\n", 362 | "\n", 363 | " #最多玩N步\n", 364 | " for _ in range(200):\n", 365 | "\n", 366 | " #获取当前状态,如果状态是终点或者掉陷阱则终止\n", 367 | " if get_state(row, col) in ['trap', 'terminal']:\n", 368 | " break\n", 369 | "\n", 370 | " #选择最优动作\n", 371 | " action = Q[row, col].argmax()\n", 372 | "\n", 373 | " #打印这个动作\n", 374 | " display.clear_output(wait=True)\n", 375 | " time.sleep(0.1)\n", 376 | " show(row, col, action)\n", 377 | "\n", 378 | " #执行动作\n", 379 | " row, col, reward = move(row, col, action)\n", 380 | "\n", 381 | " \n", 382 | "\n", 383 | "\n", 384 | "test()" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 10, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "name": "stdout", 394 | "output_type": "stream", 395 | "text": [ 396 | "→→↓↓↓↓↓↓↓↓↓↓\n", 397 | "↓↓↓↓↓↓↓↓↓↓↓↓\n", 398 | "→→→→→→→→→→→↓\n", 399 | "↑↑↑↑↑↑↑↑↑↑↑↑\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "#打印所有格子的动作倾向\n", 405 | "for row in range(4):\n", 406 | " line = ''\n", 407 | " for col in range(12):\n", 408 | " action = Q[row, col].argmax()\n", 409 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 
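# --- Editor's note (illustrative sketch, not part of the original notebook) ---
# Dyna-Q is Q-Learning plus a learned model: every real transition is stored in
# the `history` dict keyed by (row, col, action), and q_planning() above replays
# 20 randomly chosen stored transitions after each real step.  A deterministic
# grid world makes this simple dict a sufficient model.  The condensed sketch
# below uses the editor's own names; it mirrors the q_planning cell above.
import random

def planning_sweep(Q, history, n_planning=20, gamma=0.9, alpha=0.1):
    for _ in range(n_planning):
        (row, col, action), (next_row, next_col, reward) = random.choice(list(history.items()))
        td_target = reward + gamma * Q[next_row, next_col].max()
        Q[row, col, action] += alpha * (td_target - Q[row, col, action])

# This replay of remembered data is why the Dyna-Q notebook converges within its
# 300 training episodes, whereas the plain Q-Learning notebook uses 1500.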
410 | " line += action\n", 411 | " print(line)" 412 | ] 413 | } 414 | ], 415 | "metadata": { 416 | "colab": { 417 | "collapsed_sections": [], 418 | "name": "第5章-时序差分算法.ipynb", 419 | "provenance": [] 420 | }, 421 | "kernelspec": { 422 | "display_name": "Python 3", 423 | "language": "python", 424 | "name": "python3" 425 | }, 426 | "language_info": { 427 | "codemirror_mode": { 428 | "name": "ipython", 429 | "version": 3 430 | }, 431 | "file_extension": ".py", 432 | "mimetype": "text/x-python", 433 | "name": "python", 434 | "nbconvert_exporter": "python", 435 | "pygments_lexer": "ipython3", 436 | "version": "3.6.13" 437 | } 438 | }, 439 | "nbformat": 4, 440 | "nbformat_minor": 1 441 | } 442 | -------------------------------------------------------------------------------- /5.DynaQ算法/1.DynaQ.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'ground'" 12 | ] 13 | }, 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "#获取一个格子的状态\n", 21 | "def get_state(row, col):\n", 22 | " if row != 3:\n", 23 | " return 'ground'\n", 24 | "\n", 25 | " if row == 3 and col == 0:\n", 26 | " return 'ground'\n", 27 | "\n", 28 | " if row == 3 and col == 11:\n", 29 | " return 'terminal'\n", 30 | "\n", 31 | " return 'trap'\n", 32 | "\n", 33 | "\n", 34 | "get_state(0, 0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "(0, 1, -1)" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#在一个格子里做一个动作\n", 55 | "def move(row, col, action):\n", 56 | " #如果当前已经在陷阱或者终点,则不能执行任何动作\n", 57 | " if get_state(row, col) in ['trap', 'terminal']:\n", 58 | " return row, col, 0\n", 59 | "\n", 60 | " #↑\n", 61 | " if action == 0:\n", 62 | " row -= 1\n", 63 | "\n", 64 | " #↓\n", 65 | " if action == 1:\n", 66 | " row += 1\n", 67 | "\n", 68 | " #←\n", 69 | " if action == 2:\n", 70 | " col -= 1\n", 71 | "\n", 72 | " #→\n", 73 | " if action == 3:\n", 74 | " col += 1\n", 75 | "\n", 76 | " #不允许走到地图外面去\n", 77 | " row = max(0, row)\n", 78 | " row = min(3, row)\n", 79 | " col = max(0, col)\n", 80 | " col = min(11, col)\n", 81 | "\n", 82 | " #是陷阱的话,奖励是-100,否则都是-1\n", 83 | " reward = -1\n", 84 | " if get_state(row, col) == 'trap':\n", 85 | " reward = -100\n", 86 | "\n", 87 | " return row, col, reward\n", 88 | "\n", 89 | "\n", 90 | "move(0, 0, 3)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "((4, 12, 4), {})" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "import numpy as np\n", 111 | "\n", 112 | "#初始化在每一个格子里采取每个动作的分数,初始化都是0,因为没有任何的知识\n", 113 | "Q = np.zeros([4, 12, 4])\n", 114 | "\n", 115 | "#保存历史数据,键是(row,col,action),值是(next_row,next_col,reward)\n", 116 | "history = dict()\n", 117 | "\n", 118 | "Q.shape, history" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 4, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "0" 130 | ] 131 | }, 132 | "execution_count": 4, 133 | "metadata": {}, 134 | "output_type": 
"execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "import random\n", 139 | "\n", 140 | "\n", 141 | "#根据状态选择一个动作\n", 142 | "def get_action(row, col):\n", 143 | " #有小概率选择随机动作\n", 144 | " if random.random() < 0.1:\n", 145 | " return random.choice(range(4))\n", 146 | "\n", 147 | " #否则选择分数最高的动作\n", 148 | " return Q[row, col].argmax()\n", 149 | "\n", 150 | "\n", 151 | "get_action(0, 0)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 5, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "text/plain": [ 162 | "-0.1" 163 | ] 164 | }, 165 | "execution_count": 5, 166 | "metadata": {}, 167 | "output_type": "execute_result" 168 | } 169 | ], 170 | "source": [ 171 | "def get_update(row, col, action, reward, next_row, next_col):\n", 172 | " #target为下一个格子的最高分数,这里的计算和下一步的动作无关\n", 173 | " target = 0.9 * Q[next_row, next_col].max()\n", 174 | " #加上本步的分数\n", 175 | " target += reward\n", 176 | "\n", 177 | " #计算value\n", 178 | " value = Q[row, col, action]\n", 179 | "\n", 180 | " #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward\n", 181 | " #此处是求两者的差,越接近0越好\n", 182 | " update = target - value\n", 183 | "\n", 184 | " #这个0.1相当于lr\n", 185 | " update *= 0.1\n", 186 | "\n", 187 | " return update\n", 188 | "\n", 189 | "\n", 190 | "get_update(0, 0, 3, -1, 0, 1)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 6, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "import random\n", 200 | "\n", 201 | "\n", 202 | "def q_planning():\n", 203 | " #Q planning循环,相当于是在反刍历史数据,随机取N个历史数据再进行离线学习\n", 204 | " for _ in range(20):\n", 205 | " #随机选择曾经遇到过的状态动作对\n", 206 | " row, col, action = random.choice(list(history.keys()))\n", 207 | "\n", 208 | " #再获取下一个状态和反馈\n", 209 | " next_row, next_col, reward = history[(row, col, action)]\n", 210 | "\n", 211 | " #计算分数\n", 212 | " update = get_update(row, col, action, reward, next_row, next_col)\n", 213 | "\n", 214 | " #更新分数\n", 215 | " Q[row, col, action] += update\n", 216 | "\n", 217 | "\n", 218 | "#q_planning()" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 7, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "0 -129\n", 231 | "20 -108\n", 232 | "40 -16\n", 233 | "60 -12\n", 234 | "80 -15\n", 235 | "100 -13\n", 236 | "120 -12\n", 237 | "140 -15\n", 238 | "160 -13\n", 239 | "180 -15\n", 240 | "200 -14\n", 241 | "220 -15\n", 242 | "240 -14\n", 243 | "260 -12\n", 244 | "280 -14\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "#训练\n", 250 | "def train():\n", 251 | " for epoch in range(300):\n", 252 | " #初始化当前位置\n", 253 | " row = random.choice(range(4))\n", 254 | " col = 0\n", 255 | "\n", 256 | " #初始化第一个动作\n", 257 | " action = get_action(row, col)\n", 258 | "\n", 259 | " #计算反馈的和,这个数字应该越来越小\n", 260 | " reward_sum = 0\n", 261 | "\n", 262 | " #循环直到到达终点或者掉进陷阱\n", 263 | " while get_state(row, col) not in ['terminal', 'trap']:\n", 264 | "\n", 265 | " #执行动作\n", 266 | " next_row, next_col, reward = move(row, col, action)\n", 267 | " reward_sum += reward\n", 268 | "\n", 269 | " #求新位置的动作\n", 270 | " next_action = get_action(next_row, next_col)\n", 271 | "\n", 272 | " #计算分数\n", 273 | " update = get_update(row, col, action, reward, next_row, next_col)\n", 274 | "\n", 275 | " #更新分数\n", 276 | " Q[row, col, action] += update\n", 277 | "\n", 278 | " #将数据添加到模型中\n", 279 | " history[(row, col, action)] = next_row, next_col, reward\n", 280 | "\n", 281 | " #反刍历史数据,进行离线学习\n", 282 | " 
q_planning()\n", 283 | "\n", 284 | " #更新当前位置\n", 285 | " row = next_row\n", 286 | " col = next_col\n", 287 | " action = next_action\n", 288 | "\n", 289 | " if epoch % 20 == 0:\n", 290 | " print(epoch, reward_sum)\n", 291 | "\n", 292 | "\n", 293 | "train()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 8, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "□□□□□□□□□□□□\n", 306 | "□↑□□□□□□□□□□\n", 307 | "□□□□□□□□□□□□\n", 308 | "□○○○○○○○○○○❤\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "#打印游戏,方便测试\n", 314 | "def show(row, col, action):\n", 315 | " graph = [\n", 316 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 317 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',\n", 318 | " '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',\n", 319 | " '○', '○', '○', '○', '○', '❤'\n", 320 | " ]\n", 321 | "\n", 322 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 323 | "\n", 324 | " graph[row * 12 + col] = action\n", 325 | "\n", 326 | " graph = ''.join(graph)\n", 327 | "\n", 328 | " for i in range(0, 4 * 12, 12):\n", 329 | " print(graph[i:i + 12])\n", 330 | "\n", 331 | "\n", 332 | "show(1, 1, 0)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 9, 338 | "metadata": { 339 | "scrolled": false 340 | }, 341 | "outputs": [ 342 | { 343 | "name": "stdout", 344 | "output_type": "stream", 345 | "text": [ 346 | "□□□□□□□□□□□□\n", 347 | "□□□□□□□□□□□□\n", 348 | "□□□□□□□□□□□↓\n", 349 | "□○○○○○○○○○○❤\n" 350 | ] 351 | } 352 | ], 353 | "source": [ 354 | "from IPython import display\n", 355 | "import time\n", 356 | "\n", 357 | "\n", 358 | "def test():\n", 359 | " #起点\n", 360 | " row = random.choice(range(4))\n", 361 | " col = 0\n", 362 | "\n", 363 | " #最多玩N步\n", 364 | " for _ in range(200):\n", 365 | "\n", 366 | " #获取当前状态,如果状态是终点或者掉陷阱则终止\n", 367 | " if get_state(row, col) in ['trap', 'terminal']:\n", 368 | " break\n", 369 | "\n", 370 | " #选择最优动作\n", 371 | " action = Q[row, col].argmax()\n", 372 | "\n", 373 | " #打印这个动作\n", 374 | " display.clear_output(wait=True)\n", 375 | " time.sleep(0.1)\n", 376 | " show(row, col, action)\n", 377 | "\n", 378 | " #执行动作\n", 379 | " row, col, reward = move(row, col, action)\n", 380 | "\n", 381 | " \n", 382 | "\n", 383 | "\n", 384 | "test()" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 10, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "name": "stdout", 394 | "output_type": "stream", 395 | "text": [ 396 | "→→↓↓↓↓↓↓↓↓↓↓\n", 397 | "↓↓↓↓↓↓↓↓↓↓↓↓\n", 398 | "→→→→→→→→→→→↓\n", 399 | "↑↑↑↑↑↑↑↑↑↑↑↑\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "#打印所有格子的动作倾向\n", 405 | "for row in range(4):\n", 406 | " line = ''\n", 407 | " for col in range(12):\n", 408 | " action = Q[row, col].argmax()\n", 409 | " action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]\n", 410 | " line += action\n", 411 | " print(line)" 412 | ] 413 | } 414 | ], 415 | "metadata": { 416 | "colab": { 417 | "collapsed_sections": [], 418 | "name": "第5章-时序差分算法.ipynb", 419 | "provenance": [] 420 | }, 421 | "kernelspec": { 422 | "display_name": "Python 3", 423 | "language": "python", 424 | "name": "python3" 425 | }, 426 | "language_info": { 427 | "codemirror_mode": { 428 | "name": "ipython", 429 | "version": 3 430 | }, 431 | "file_extension": ".py", 432 | "mimetype": "text/x-python", 433 | "name": "python", 434 | "nbconvert_exporter": "python", 435 | "pygments_lexer": "ipython3", 436 | 
"version": "3.6.13" 437 | } 438 | }, 439 | "nbformat": 4, 440 | "nbformat_minor": 1 441 | } 442 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 视频课程:https://www.bilibili.com/video/BV1Ge4y1i7L6/ 2 | 3 |
4 | Runtime environment (a quick API check follows the version list below): 5 |
6 | python=3.9 7 |
8 | pytorch=1.12.1 9 |
10 | gym=0.26.2 11 | 12 |
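Note (not part of the original README): the notebooks rely on the post-0.26 gym API, in which `reset()` returns `(observation, info)` and `step()` returns five values. A minimal sketch for checking the installed version behaves this way; `CartPole-v1` is used here only as an illustrative environment.

```python
import gym

env = gym.make('CartPole-v1')                    # any installed environment works here
state, info = env.reset(seed=0)                  # gym>=0.26: reset returns (observation, info)

done = False
while not done:
    action = env.action_space.sample()           # random policy, only to exercise the API
    state, reward, terminated, truncated, info = env.step(action)  # five return values
    done = terminated or truncated

env.close()
```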

13 | Updated 2023-05-05: gym upgraded to 0.26.2, python upgraded to 3.9, torch upgraded to 1.12.1
--------------------------------------------------------------------------------
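Appended for reference (not a file in the original repository): a condensed sketch of the Dyna-Q loop from 5.DynaQ算法/1.DynaQ.ipynb above, on the same 4×12 cliff grid. It compresses the notebook's cells into one script; the hyperparameters (gamma=0.9, lr=0.1, epsilon=0.1, 20 planning steps per real step) follow the notebook, and the per-step ordering — act, TD update, store the transition in `history`, replay — is unchanged.

```python
import random
import numpy as np

# Q[row, col, action]: action values on the 4x12 cliff grid; history is the learned model.
Q = np.zeros([4, 12, 4])
history = {}  # (row, col, action) -> (next_row, next_col, reward)

def get_state(row, col):
    # bottom row is the cliff, except the start column (0) and the goal column (11)
    if row != 3 or col == 0:
        return 'ground'
    return 'terminal' if col == 11 else 'trap'

def move(row, col, action):
    # no movement once in a trap or at the terminal cell
    if get_state(row, col) in ['trap', 'terminal']:
        return row, col, 0
    drow, dcol = [(-1, 0), (1, 0), (0, -1), (0, 1)][action]  # up, down, left, right
    row = min(3, max(0, row + drow))
    col = min(11, max(0, col + dcol))
    reward = -100 if get_state(row, col) == 'trap' else -1
    return row, col, reward

def get_action(row, col, eps=0.1):
    # epsilon-greedy over the current value estimates
    if random.random() < eps:
        return random.choice(range(4))
    return int(Q[row, col].argmax())

def td_update(row, col, action, reward, next_row, next_col, gamma=0.9, lr=0.1):
    # Q-learning target: reward + gamma * best value of the next state
    target = reward + gamma * Q[next_row, next_col].max()
    Q[row, col, action] += lr * (target - Q[row, col, action])

def q_planning(n=20):
    # planning step: replay n transitions sampled from the learned model
    for _ in range(n):
        (row, col, action), (next_row, next_col, reward) = random.choice(list(history.items()))
        td_update(row, col, action, reward, next_row, next_col)

for epoch in range(300):
    row, col = random.choice(range(4)), 0
    reward_sum = 0
    while get_state(row, col) not in ['terminal', 'trap']:
        action = get_action(row, col)
        next_row, next_col, reward = move(row, col, action)
        reward_sum += reward
        td_update(row, col, action, reward, next_row, next_col)   # learn from the real step
        history[(row, col, action)] = (next_row, next_col, reward)  # update the model
        q_planning()                                               # learn from replayed steps
        row, col = next_row, next_col
    if epoch % 20 == 0:
        print(epoch, reward_sum)
```

One simplification versus the notebook: the action is re-sampled at the top of each loop iteration rather than carried over as `next_action`; since the update uses the max over next-state actions, this does not change the learning rule itself.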