├── LICENSE ├── README.md ├── rl_utils.py ├── 第10章-Actor-Critic算法.ipynb ├── 第11章-TRPO算法.ipynb ├── 第12章-PPO算法.ipynb ├── 第13章-DDPG算法.ipynb ├── 第14章-SAC算法.ipynb ├── 第15章-模仿学习.ipynb ├── 第16章-模型预测控制.ipynb ├── 第17章-基于模型的策略优化.ipynb ├── 第18章-离线强化学习.ipynb ├── 第19章-目标导向的强化学习.ipynb ├── 第20章-多智能体强化学习入门.ipynb ├── 第21章-多智能体强化学习进阶.ipynb ├── 第2章-多臂老虎机问题.ipynb ├── 第3章-马尔可夫决策过程.ipynb ├── 第4章-动态规划算法.ipynb ├── 第5章-时序差分算法.ipynb ├── 第6章-Dyna-Q算法.ipynb ├── 第7章-DQN算法.ipynb ├── 第8章-DQN改进算法.ipynb └── 第9章-策略梯度算法.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 动手学强化学习 2 | 3 | Tips: 若运行gym环境的代码时遇到报错,请尝试 `pip install gym==0.18.3` 安装此版本的gym库,若仍有问题,欢迎提交issue! 4 | 5 | 欢迎来到《动手学强化学习》(Hands-on Reinforcement Learning)的地带。该系列从强化学习的定义等基础讲起,一步步由浅入深,介绍目前一些主流的强化学习算法。每一章内容都是一个Jupyter Notebook,内含详细的图文介绍和代码讲解。 6 | 7 | * 由于GitHub上渲染notebook效果有限,我们推荐读者前往[Hands-on RL主页](https://hrl.boyuai.com/)进行浏览,我们在此提供了纯代码版本的notebook,供大家下载运行。 8 | 9 | * 欢迎在[京东](https://item.jd.com/13129509.html)和[当当网](http://product.dangdang.com/29391150.html)购买《动手学强化学习》。 10 | 11 | * 如果你发现了本书的任何问题,或者有任何改进建议,欢迎提交issue!
12 | 13 | * 本书配套的强化学习课程已上线到[伯禹学习平台](https://www.boyuai.com/elites/course/xVqhU42F5IDky94x),所有人都可以免费学习和讨论。 14 | 15 | ![](https://boyuai.oss-cn-shanghai.aliyuncs.com/disk/tmp/hrl-poster.jpeg) 16 | -------------------------------------------------------------------------------- /rl_utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import torch 4 | import collections 5 | import random 6 | 7 | class ReplayBuffer: 8 | def __init__(self, capacity): 9 | self.buffer = collections.deque(maxlen=capacity) 10 | 11 | def add(self, state, action, reward, next_state, done): 12 | self.buffer.append((state, action, reward, next_state, done)) 13 | 14 | def sample(self, batch_size): 15 | transitions = random.sample(self.buffer, batch_size) 16 | state, action, reward, next_state, done = zip(*transitions) 17 | return np.array(state), action, reward, np.array(next_state), done 18 | 19 | def size(self): 20 | return len(self.buffer) 21 | 22 | def moving_average(a, window_size): 23 | cumulative_sum = np.cumsum(np.insert(a, 0, 0)) 24 | middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size 25 | r = np.arange(1, window_size-1, 2) 26 | begin = np.cumsum(a[:window_size-1])[::2] / r 27 | end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1] 28 | return np.concatenate((begin, middle, end)) 29 | 30 | def train_on_policy_agent(env, agent, num_episodes): 31 | return_list = [] 32 | for i in range(10): 33 | with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar: 34 | for i_episode in range(int(num_episodes/10)): 35 | episode_return = 0 36 | transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []} 37 | state = env.reset() 38 | done = False 39 | while not done: 40 | action = agent.take_action(state) 41 | next_state, reward, done, _ = env.step(action) 42 | transition_dict['states'].append(state) 43 | 
transition_dict['actions'].append(action) 44 | transition_dict['next_states'].append(next_state) 45 | transition_dict['rewards'].append(reward) 46 | transition_dict['dones'].append(done) 47 | state = next_state 48 | episode_return += reward 49 | return_list.append(episode_return) 50 | agent.update(transition_dict) 51 | if (i_episode+1) % 10 == 0: 52 | pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])}) 53 | pbar.update(1) 54 | return return_list 55 | 56 | def train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size): 57 | return_list = [] 58 | for i in range(10): 59 | with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar: 60 | for i_episode in range(int(num_episodes/10)): 61 | episode_return = 0 62 | state = env.reset() 63 | done = False 64 | while not done: 65 | action = agent.take_action(state) 66 | next_state, reward, done, _ = env.step(action) 67 | replay_buffer.add(state, action, reward, next_state, done) 68 | state = next_state 69 | episode_return += reward 70 | if replay_buffer.size() > minimal_size: 71 | b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size) 72 | transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d} 73 | agent.update(transition_dict) 74 | return_list.append(episode_return) 75 | if (i_episode+1) % 10 == 0: 76 | pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])}) 77 | pbar.update(1) 78 | return return_list 79 | 80 | 81 | def compute_advantage(gamma, lmbda, td_delta): 82 | td_delta = td_delta.detach().numpy() 83 | advantage_list = [] 84 | advantage = 0.0 85 | for delta in td_delta[::-1]: 86 | advantage = gamma * lmbda * advantage + delta 87 | advantage_list.append(advantage) 88 | advantage_list.reverse() 89 | return torch.tensor(advantage_list, dtype=torch.float) 90 | 
-------------------------------------------------------------------------------- /第16章-模型预测控制.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "executionInfo": { 8 | "elapsed": 6698, 9 | "status": "ok", 10 | "timestamp": 1649956814219, 11 | "user": { 12 | "displayName": "Sam Lu", 13 | "userId": "15789059763790170725" 14 | }, 15 | "user_tz": -480 16 | }, 17 | "id": "pkDNguALCr-X" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "from scipy.stats import truncnorm\n", 23 | "import gym\n", 24 | "import itertools\n", 25 | "import torch\n", 26 | "import torch.nn as nn\n", 27 | "import torch.nn.functional as F\n", 28 | "import collections\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "\n", 31 | "\n", 32 | "class CEM:\n", 33 | " def __init__(self, n_sequence, elite_ratio, fake_env, upper_bound,\n", 34 | " lower_bound):\n", 35 | " self.n_sequence = n_sequence\n", 36 | " self.elite_ratio = elite_ratio\n", 37 | " self.upper_bound = upper_bound\n", 38 | " self.lower_bound = lower_bound\n", 39 | " self.fake_env = fake_env\n", 40 | "\n", 41 | " def optimize(self, state, init_mean, init_var):\n", 42 | " mean, var = init_mean, init_var\n", 43 | " X = truncnorm(-2, 2, loc=np.zeros_like(mean), scale=np.ones_like(var))\n", 44 | " state = np.tile(state, (self.n_sequence, 1))\n", 45 | "\n", 46 | " for _ in range(5):\n", 47 | " lb_dist, ub_dist = mean - self.lower_bound, self.upper_bound - mean\n", 48 | " constrained_var = np.minimum(\n", 49 | " np.minimum(np.square(lb_dist / 2), np.square(ub_dist / 2)),\n", 50 | " var)\n", 51 | " # 生成动作序列\n", 52 | " action_sequences = [X.rvs() for _ in range(self.n_sequence)\n", 53 | " ] * np.sqrt(constrained_var) + mean\n", 54 | " # 计算每条动作序列的累积奖励\n", 55 | " returns = self.fake_env.propagate(state, action_sequences)[:, 0]\n", 56 | " # 选取累积奖励最高的若干条动作序列\n", 57 | " elites = 
action_sequences[np.argsort(\n", 58 | " returns)][-int(self.elite_ratio * self.n_sequence):]\n", 59 | " new_mean = np.mean(elites, axis=0)\n", 60 | " new_var = np.var(elites, axis=0)\n", 61 | " # 更新动作序列分布\n", 62 | " mean = 0.1 * mean + 0.9 * new_mean\n", 63 | " var = 0.1 * var + 0.9 * new_var\n", 64 | "\n", 65 | " return mean" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 2, 71 | "metadata": { 72 | "executionInfo": { 73 | "elapsed": 9, 74 | "status": "ok", 75 | "timestamp": 1649956814220, 76 | "user": { 77 | "displayName": "Sam Lu", 78 | "userId": "15789059763790170725" 79 | }, 80 | "user_tz": -480 81 | }, 82 | "id": "coGG5UOpCr-Z" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\n", 87 | " \"cpu\")\n", 88 | "\n", 89 | "\n", 90 | "class Swish(nn.Module):\n", 91 | " ''' Swish激活函数 '''\n", 92 | " def __init__(self):\n", 93 | " super(Swish, self).__init__()\n", 94 | "\n", 95 | " def forward(self, x):\n", 96 | " return x * torch.sigmoid(x)\n", 97 | "\n", 98 | "\n", 99 | "def init_weights(m):\n", 100 | " ''' 初始化模型权重 '''\n", 101 | " def truncated_normal_init(t, mean=0.0, std=0.01):\n", 102 | " torch.nn.init.normal_(t, mean=mean, std=std)\n", 103 | " while True:\n", 104 | " cond = (t < mean - 2 * std) | (t > mean + 2 * std)\n", 105 | " if not torch.sum(cond):\n", 106 | " break\n", 107 | " t = torch.where(\n", 108 | " cond,\n", 109 | " torch.nn.init.normal_(torch.ones(t.shape, device=device),\n", 110 | " mean=mean,\n", 111 | " std=std), t)\n", 112 | " return t\n", 113 | "\n", 114 | " if type(m) == nn.Linear or isinstance(m, FCLayer):\n", 115 | " truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(m._input_dim)))\n", 116 | " m.bias.data.fill_(0.0)\n", 117 | "\n", 118 | "\n", 119 | "class FCLayer(nn.Module):\n", 120 | " ''' 集成之后的全连接层 '''\n", 121 | " def __init__(self, input_dim, output_dim, ensemble_size, activation):\n", 122 | " super(FCLayer, self).__init__()\n", 123 
| " self._input_dim, self._output_dim = input_dim, output_dim\n", 124 | " self.weight = nn.Parameter(\n", 125 | " torch.Tensor(ensemble_size, input_dim, output_dim).to(device))\n", 126 | " self._activation = activation\n", 127 | " self.bias = nn.Parameter(\n", 128 | " torch.Tensor(ensemble_size, output_dim).to(device))\n", 129 | "\n", 130 | " def forward(self, x):\n", 131 | " return self._activation(\n", 132 | " torch.add(torch.bmm(x, self.weight), self.bias[:, None, :]))" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 3, 138 | "metadata": { 139 | "executionInfo": { 140 | "elapsed": 8, 141 | "status": "ok", 142 | "timestamp": 1649956814220, 143 | "user": { 144 | "displayName": "Sam Lu", 145 | "userId": "15789059763790170725" 146 | }, 147 | "user_tz": -480 148 | }, 149 | "id": "SNVDgXI2Cr-a" 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "class EnsembleModel(nn.Module):\n", 154 | " ''' 环境模型集成 '''\n", 155 | " def __init__(self,\n", 156 | " state_dim,\n", 157 | " action_dim,\n", 158 | " ensemble_size=5,\n", 159 | " learning_rate=1e-3):\n", 160 | " super(EnsembleModel, self).__init__()\n", 161 | " # 输出包括均值和方差,因此是状态与奖励维度之和的两倍\n", 162 | " self._output_dim = (state_dim + 1) * 2\n", 163 | " self._max_logvar = nn.Parameter((torch.ones(\n", 164 | " (1, self._output_dim // 2)).float() / 2).to(device),\n", 165 | " requires_grad=False)\n", 166 | " self._min_logvar = nn.Parameter((-torch.ones(\n", 167 | " (1, self._output_dim // 2)).float() * 10).to(device),\n", 168 | " requires_grad=False)\n", 169 | "\n", 170 | " self.layer1 = FCLayer(state_dim + action_dim, 200, ensemble_size,\n", 171 | " Swish())\n", 172 | " self.layer2 = FCLayer(200, 200, ensemble_size, Swish())\n", 173 | " self.layer3 = FCLayer(200, 200, ensemble_size, Swish())\n", 174 | " self.layer4 = FCLayer(200, 200, ensemble_size, Swish())\n", 175 | " self.layer5 = FCLayer(200, self._output_dim, ensemble_size,\n", 176 | " nn.Identity())\n", 177 | " self.apply(init_weights) # 
初始化环境模型中的参数\n", 178 | " self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)\n", 179 | "\n", 180 | " def forward(self, x, return_log_var=False):\n", 181 | " ret = self.layer5(self.layer4(self.layer3(self.layer2(\n", 182 | " self.layer1(x)))))\n", 183 | " mean = ret[:, :, :self._output_dim // 2]\n", 184 | " # 在PETS算法中,将方差控制在最小值和最大值之间\n", 185 | " logvar = self._max_logvar - F.softplus(\n", 186 | " self._max_logvar - ret[:, :, self._output_dim // 2:])\n", 187 | " logvar = self._min_logvar + F.softplus(logvar - self._min_logvar)\n", 188 | " return mean, logvar if return_log_var else torch.exp(logvar)\n", 189 | "\n", 190 | " def loss(self, mean, logvar, labels, use_var_loss=True):\n", 191 | " inverse_var = torch.exp(-logvar)\n", 192 | " if use_var_loss:\n", 193 | " mse_loss = torch.mean(torch.mean(torch.pow(mean - labels, 2) *\n", 194 | " inverse_var,\n", 195 | " dim=-1),\n", 196 | " dim=-1)\n", 197 | " var_loss = torch.mean(torch.mean(logvar, dim=-1), dim=-1)\n", 198 | " total_loss = torch.sum(mse_loss) + torch.sum(var_loss)\n", 199 | " else:\n", 200 | " mse_loss = torch.mean(torch.pow(mean - labels, 2), dim=(1, 2))\n", 201 | " total_loss = torch.sum(mse_loss)\n", 202 | " return total_loss, mse_loss\n", 203 | "\n", 204 | " def train(self, loss):\n", 205 | " self.optimizer.zero_grad()\n", 206 | " loss += 0.01 * torch.sum(self._max_logvar) - 0.01 * torch.sum(\n", 207 | " self._min_logvar)\n", 208 | " loss.backward()\n", 209 | " self.optimizer.step()" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 4, 215 | "metadata": { 216 | "executionInfo": { 217 | "elapsed": 8, 218 | "status": "ok", 219 | "timestamp": 1649956814221, 220 | "user": { 221 | "displayName": "Sam Lu", 222 | "userId": "15789059763790170725" 223 | }, 224 | "user_tz": -480 225 | }, 226 | "id": "kVE0nKi6Cr-b" 227 | }, 228 | "outputs": [], 229 | "source": [ 230 | "class EnsembleDynamicsModel:\n", 231 | " ''' 环境模型集成,加入精细化的训练 '''\n", 232 | " def __init__(self, 
state_dim, action_dim, num_network=5):\n", 233 | " self._num_network = num_network\n", 234 | " self._state_dim, self._action_dim = state_dim, action_dim\n", 235 | " self.model = EnsembleModel(state_dim,\n", 236 | " action_dim,\n", 237 | " ensemble_size=num_network)\n", 238 | " self._epoch_since_last_update = 0\n", 239 | "\n", 240 | " def train(self,\n", 241 | " inputs,\n", 242 | " labels,\n", 243 | " batch_size=64,\n", 244 | " holdout_ratio=0.1,\n", 245 | " max_iter=20):\n", 246 | " # 设置训练集与验证集\n", 247 | " permutation = np.random.permutation(inputs.shape[0])\n", 248 | " inputs, labels = inputs[permutation], labels[permutation]\n", 249 | " num_holdout = int(inputs.shape[0] * holdout_ratio)\n", 250 | " train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]\n", 251 | " holdout_inputs, holdout_labels = inputs[:\n", 252 | " num_holdout], labels[:\n", 253 | " num_holdout]\n", 254 | " holdout_inputs = torch.from_numpy(holdout_inputs).float().to(device)\n", 255 | " holdout_labels = torch.from_numpy(holdout_labels).float().to(device)\n", 256 | " holdout_inputs = holdout_inputs[None, :, :].repeat(\n", 257 | " [self._num_network, 1, 1])\n", 258 | " holdout_labels = holdout_labels[None, :, :].repeat(\n", 259 | " [self._num_network, 1, 1])\n", 260 | "\n", 261 | " # 保留最好的结果\n", 262 | " self._snapshots = {i: (None, 1e10) for i in range(self._num_network)}\n", 263 | "\n", 264 | " for epoch in itertools.count():\n", 265 | " # 定义每一个网络的训练数据\n", 266 | " train_index = np.vstack([\n", 267 | " np.random.permutation(train_inputs.shape[0])\n", 268 | " for _ in range(self._num_network)\n", 269 | " ])\n", 270 | " # 所有真实数据都用来训练\n", 271 | " for batch_start_pos in range(0, train_inputs.shape[0], batch_size):\n", 272 | " batch_index = train_index[:, batch_start_pos:batch_start_pos +\n", 273 | " batch_size]\n", 274 | " train_input = torch.from_numpy(\n", 275 | " train_inputs[batch_index]).float().to(device)\n", 276 | " train_label = torch.from_numpy(\n", 277 | " 
train_labels[batch_index]).float().to(device)\n", 278 | "\n", 279 | " mean, logvar = self.model(train_input, return_log_var=True)\n", 280 | " loss, _ = self.model.loss(mean, logvar, train_label)\n", 281 | " self.model.train(loss)\n", 282 | "\n", 283 | " with torch.no_grad():\n", 284 | " mean, logvar = self.model(holdout_inputs, return_log_var=True)\n", 285 | " _, holdout_losses = self.model.loss(mean,\n", 286 | " logvar,\n", 287 | " holdout_labels,\n", 288 | " use_var_loss=False)\n", 289 | " holdout_losses = holdout_losses.cpu()\n", 290 | " break_condition = self._save_best(epoch, holdout_losses)\n", 291 | " if break_condition or epoch > max_iter: # 结束训练\n", 292 | " break\n", 293 | "\n", 294 | " def _save_best(self, epoch, losses, threshold=0.1):\n", 295 | " updated = False\n", 296 | " for i in range(len(losses)):\n", 297 | " current = losses[i]\n", 298 | " _, best = self._snapshots[i]\n", 299 | " improvement = (best - current) / best\n", 300 | " if improvement > threshold:\n", 301 | " self._snapshots[i] = (epoch, current)\n", 302 | " updated = True\n", 303 | " self._epoch_since_last_update = 0 if updated else self._epoch_since_last_update + 1\n", 304 | " return self._epoch_since_last_update > 5\n", 305 | "\n", 306 | " def predict(self, inputs, batch_size=64):\n", 307 | " mean, var = [], []\n", 308 | " for i in range(0, inputs.shape[0], batch_size):\n", 309 | " input = torch.from_numpy(\n", 310 | " inputs[i:min(i +\n", 311 | " batch_size, inputs.shape[0])]).float().to(device)\n", 312 | " cur_mean, cur_var = self.model(input[None, :, :].repeat(\n", 313 | " [self._num_network, 1, 1]),\n", 314 | " return_log_var=False)\n", 315 | " mean.append(cur_mean.detach().cpu().numpy())\n", 316 | " var.append(cur_var.detach().cpu().numpy())\n", 317 | " return np.hstack(mean), np.hstack(var)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 5, 323 | "metadata": { 324 | "executionInfo": { 325 | "elapsed": 7, 326 | "status": "ok", 327 | "timestamp": 
1649956814221, 328 | "user": { 329 | "displayName": "Sam Lu", 330 | "userId": "15789059763790170725" 331 | }, 332 | "user_tz": -480 333 | }, 334 | "id": "1auD04WgCr-c" 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "class FakeEnv:\n", 339 | " def __init__(self, model):\n", 340 | " self.model = model\n", 341 | "\n", 342 | " def step(self, obs, act):\n", 343 | " inputs = np.concatenate((obs, act), axis=-1)\n", 344 | " ensemble_model_means, ensemble_model_vars = self.model.predict(inputs)\n", 345 | " ensemble_model_means[:, :, 1:] += obs.numpy()\n", 346 | " ensemble_model_stds = np.sqrt(ensemble_model_vars)\n", 347 | " ensemble_samples = ensemble_model_means + np.random.normal(\n", 348 | " size=ensemble_model_means.shape) * ensemble_model_stds\n", 349 | "\n", 350 | " num_models, batch_size, _ = ensemble_model_means.shape\n", 351 | " models_to_use = np.random.choice(\n", 352 | " [i for i in range(self.model._num_network)], size=batch_size)\n", 353 | " batch_inds = np.arange(0, batch_size)\n", 354 | " samples = ensemble_samples[models_to_use, batch_inds]\n", 355 | " rewards, next_obs = samples[:, :1], samples[:, 1:]\n", 356 | " return rewards, next_obs\n", 357 | "\n", 358 | " def propagate(self, obs, actions):\n", 359 | " with torch.no_grad():\n", 360 | " obs = np.copy(obs)\n", 361 | " total_reward = np.expand_dims(np.zeros(obs.shape[0]), axis=-1)\n", 362 | " obs, actions = torch.as_tensor(obs), torch.as_tensor(actions)\n", 363 | " for i in range(actions.shape[1]):\n", 364 | " action = torch.unsqueeze(actions[:, i], 1)\n", 365 | " rewards, next_obs = self.step(obs, action)\n", 366 | " total_reward += rewards\n", 367 | " obs = torch.as_tensor(next_obs)\n", 368 | " return total_reward" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 6, 374 | "metadata": { 375 | "executionInfo": { 376 | "elapsed": 8, 377 | "status": "ok", 378 | "timestamp": 1649956814222, 379 | "user": { 380 | "displayName": "Sam Lu", 381 | "userId": 
class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)``."""

    def __init__(self, capacity):
        # A deque with ``maxlen`` silently drops the oldest transition when full.
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def size(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

    def return_all_samples(self):
        """Return every stored transition, column-wise.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` where states and
            next_states are ndarrays and the other three are tuples.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError('replay buffer is empty')
        state, action, reward, next_state, done = zip(*self.buffer)
        return np.array(state), action, reward, np.array(next_state), done


class PETS:
    """PETS algorithm: ensemble dynamics model + CEM-based model-predictive control."""

    def __init__(self, env, replay_buffer, n_sequence, elite_ratio,
                 plan_horizon, num_episodes):
        """
        Args:
            env: gym-style environment (``reset``, ``step``, spaces).
            replay_buffer: buffer that stores the real transitions.
            n_sequence: number of action sequences sampled by CEM.
            elite_ratio: fraction of sequences kept as elites by CEM.
            plan_horizon: planning horizon, in steps.
            num_episodes: total number of episodes run by ``train``.
        """
        self._env = env
        # Use the buffer supplied by the caller. The original ignored this
        # argument and built a new one from the global ``buffer_size``.
        self._env_pool = replay_buffer

        obs_dim = env.observation_space.shape[0]
        self._action_dim = env.action_space.shape[0]
        self._model = EnsembleDynamicsModel(obs_dim, self._action_dim)
        self._fake_env = FakeEnv(self._model)
        self.upper_bound = env.action_space.high[0]
        self.lower_bound = env.action_space.low[0]

        self._cem = CEM(n_sequence, elite_ratio, self._fake_env,
                        self.upper_bound, self.lower_bound)
        self.plan_horizon = plan_horizon
        self.num_episodes = num_episodes

    def train_model(self):
        """Fit the dynamics model on all real transitions collected so far.

        Inputs are ``[obs, action]``; labels are ``[reward, next_obs - obs]``.
        """
        obs, actions, rewards, next_obs, _ = self._env_pool.return_all_samples()
        actions = np.array(actions)
        rewards = np.array(rewards).reshape(-1, 1)
        inputs = np.concatenate((obs, actions), axis=-1)
        labels = np.concatenate((rewards, next_obs - obs), axis=-1)
        self._model.train(inputs, labels)

    def _step_and_store(self, obs, action):
        """Apply ``action`` in the real env, record the transition, return
        ``(next_obs, reward, done)``."""
        next_obs, reward, done, _ = self._env.step(action)
        self._env_pool.add(obs, action, reward, next_obs, done)
        return next_obs, reward, done

    def mpc(self):
        """Run one real episode, replanning with CEM at every step.

        Returns:
            The undiscounted episode return.
        """
        mean = np.tile((self.upper_bound + self.lower_bound) / 2.0,
                       self.plan_horizon)
        var = np.tile(
            np.square(self.upper_bound - self.lower_bound) / 16,
            self.plan_horizon)
        obs, done, episode_return = self._env.reset(), False, 0
        while not done:
            actions = self._cem.optimize(obs, mean, var)
            action = actions[:self._action_dim]  # execute only the first action
            obs, reward, done = self._step_and_store(obs, action)
            episode_return += reward
            # Warm-start the next plan: shift the sequence one step forward and
            # pad the tail with zeros.
            mean = np.concatenate(
                [actions[self._action_dim:],
                 np.zeros(self._action_dim)])
        return episode_return

    def explore(self):
        """Run one real episode with uniformly random actions.

        Returns:
            The undiscounted episode return.
        """
        obs, done, episode_return = self._env.reset(), False, 0
        while not done:
            action = self._env.action_space.sample()
            obs, reward, done = self._step_and_store(obs, action)
            episode_return += reward
        return episode_return

    def train(self):
        """Run ``num_episodes`` episodes and return the list of episode returns.

        The first episode uses a random policy to seed the buffer; every later
        episode retrains the model and then acts with MPC.
        """
        return_list = []
        explore_return = self.explore()
        print('episode: 1, return: %d' % explore_return)
        return_list.append(explore_return)

        for i_episode in range(self.num_episodes - 1):
            self.train_model()
            episode_return = self.mpc()
            return_list.append(episode_return)
            print('episode: %d, return: %d' % (i_episode + 2, episode_return))
        return return_list
# ---- Hyper-parameters for the PETS experiment on Pendulum ----
buffer_size = 100000   # capacity of the real-transition buffer
n_sequence = 50        # action sequences sampled per CEM iteration
elite_ratio = 0.2      # fraction of sequences kept as elites
plan_horizon = 25      # planning horizon (steps)
num_episodes = 10      # total episodes, including the random warm-up one
env_name = 'Pendulum-v0'
env = gym.make(env_name)

# ---- Train ----
replay_buffer = ReplayBuffer(buffer_size)
pets = PETS(env, replay_buffer, n_sequence, elite_ratio, plan_horizon,
            num_episodes)
return_list = pets.train()

# ---- Plot the learning curve ----
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'PETS on {env_name}')
plt.show()

# Reference output from one earlier run (results vary with the random seed):
# episode: 1, return: -1062
# episode: 2, return: -1257
# episode: 3, return: -1792
# episode: 4, return: -1225
# episode: 5, return: -248
# episode: 6, return: -124
# episode: 7, return: -249
# episode: 8, return: -269
# episode: 9, return: -245
# episode: 10, return: -119
class PolicyNet(torch.nn.Module):
    """Gaussian policy with tanh squashing for bounded continuous actions."""

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        """Sample an action for each state in ``x``.

        Returns:
            ``(action, log_prob)``: the action scaled to
            ``[-action_bound, action_bound]`` and the log-density of the
            tanh-squashed sample (the constant ``log(action_bound)`` term is
            not included, as in the original).
        """
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))
        dist = Normal(mu, std)
        normal_sample = dist.rsample()  # reparameterised sample
        log_prob = dist.log_prob(normal_sample)
        squashed = torch.tanh(normal_sample)
        # Change of variables for a = tanh(u): subtract log(1 - tanh(u)^2).
        # The original applied tanh a second time here, giving a wrong density.
        log_prob = log_prob - torch.log(1 - squashed.pow(2) + 1e-7)
        return squashed * self.action_bound, log_prob


class QValueNet(torch.nn.Module):
    """State-action value network Q(s, a) with one hidden layer."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)  # concatenate state and action
        x = F.relu(self.fc1(cat))
        return self.fc2(x)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")


class SAC:
    """Soft Actor-Critic for continuous actions.

    NOTE(review): ``take_action`` and ``update`` assume a one-dimensional
    action, and ``update`` rescales rewards for the Pendulum task.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim,
                               action_bound).to(device)
        # Twin critics and their target copies.
        self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_1 = QValueNet(state_dim, hidden_dim,
                                         action_dim).to(device)
        self.target_critic_2 = QValueNet(state_dim, hidden_dim,
                                         action_dim).to(device)
        # Target networks start with the same parameters as the critics.
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # Optimise log(alpha) rather than alpha for more stable training.
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy
        self.gamma = gamma
        self.tau = tau

    @staticmethod
    def _to_tensor(values):
        """Convert an array-like to a float32 tensor on ``device``.

        Going through one ndarray avoids the slow list-of-ndarrays path that
        ``torch.tensor`` warns about.
        """
        return torch.as_tensor(np.asarray(values),
                               dtype=torch.float).to(device)

    def take_action(self, state):
        """Sample an action for one state; returns a one-element list."""
        state = self._to_tensor(state).unsqueeze(0)
        with torch.no_grad():
            action = self.actor(state)[0]
        return [action.item()]

    def calc_target(self, rewards, next_states, dones):
        """Compute the soft TD target from the minimum of the two target critics."""
        next_actions, log_prob = self.actor(next_states)
        entropy = -log_prob
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        next_value = torch.min(q1_value,
                               q2_value) + self.log_alpha.exp() * entropy
        return rewards + self.gamma * next_value * (1 - dones)

    def soft_update(self, net, target_net):
        """Polyak-average ``net`` into ``target_net`` with rate ``tau``."""
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        """Run one SAC update step on a batch.

        ``transition_dict`` holds the keys ``'states'``, ``'actions'``,
        ``'rewards'``, ``'next_states'`` and ``'dones'``.
        """
        states = self._to_tensor(transition_dict['states'])
        actions = self._to_tensor(transition_dict['actions']).view(-1, 1)
        rewards = self._to_tensor(transition_dict['rewards']).view(-1, 1)
        next_states = self._to_tensor(transition_dict['next_states'])
        dones = self._to_tensor(transition_dict['dones']).view(-1, 1)
        rewards = (rewards + 8.0) / 8.0  # reward shaping for Pendulum

        # Update both critics.
        td_target = self.calc_target(rewards, next_states, dones).detach()
        critic_1_loss = F.mse_loss(self.critic_1(states, actions), td_target)
        critic_2_loss = F.mse_loss(self.critic_2(states, actions), td_target)
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # Update the policy.
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy -
                                torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update the temperature. The original read a module-level
        # ``target_entropy`` here, ignoring the constructor argument.
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


def init_weights(m):
    """Initialise a layer: truncated-normal weights, zero bias.

    Intended for ``module.apply(init_weights)``. Only ``nn.Linear`` and
    ``FCLayer`` modules are touched; the weight std is
    ``1 / (2 * sqrt(fan_in))`` and samples are truncated to two std.
    """
    def truncated_normal_init(t, mean=0.0, std=0.01):
        """Fill ``t`` in place with N(mean, std) samples inside mean +/- 2*std."""
        with torch.no_grad():
            t.normal_(mean=mean, std=std)
            while True:
                outliers = (t < mean - 2 * std) | (t > mean + 2 * std)
                if not outliers.any():
                    break
                # Resample only the outliers, writing into ``t`` itself. The
                # original rebound the local name via ``torch.where``, so the
                # parameter was never truncated.
                t[outliers] = torch.empty_like(t).normal_(mean=mean,
                                                          std=std)[outliers]
        return t

    if isinstance(m, nn.Linear):
        fan_in = m.in_features  # nn.Linear has no ``_input_dim`` attribute
    elif isinstance(m, FCLayer):
        fan_in = m._input_dim
    else:
        return
    truncated_normal_init(m.weight, std=float(1 / (2 * np.sqrt(fan_in))))
    if m.bias is not None:
        m.bias.data.fill_(0.0)


class FCLayer(nn.Module):
    """Fully connected layer evaluated for a whole ensemble at once.

    Holds one weight matrix and one bias vector per ensemble member and
    applies them with a batched matmul, so the input must have shape
    ``(ensemble_size, batch, input_dim)``.
    """

    def __init__(self, input_dim, output_dim, ensemble_size, activation):
        super(FCLayer, self).__init__()
        self._input_dim, self._output_dim = input_dim, output_dim
        # Zero-filled rather than uninitialised memory; call ``init_weights``
        # (e.g. via ``module.apply``) for a proper initialisation.
        self.weight = nn.Parameter(
            torch.zeros(ensemble_size, input_dim, output_dim).to(device))
        self._activation = activation
        self.bias = nn.Parameter(
            torch.zeros(ensemble_size, output_dim).to(device))

    def forward(self, x):
        # bias[:, None, :] broadcasts each member's bias over its batch.
        return self._activation(
            torch.add(torch.bmm(x, self.weight), self.bias[:, None, :]))
class EnsembleModel(nn.Module):
    """Ensemble of probabilistic dynamics networks.

    Each member maps ``[state, action]`` to a Gaussian over
    ``[reward, state delta]``.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 model_alpha,
                 ensemble_size=5,
                 learning_rate=1e-3):
        super(EnsembleModel, self).__init__()
        # Mean and log-variance for the reward plus every state dimension.
        self._output_dim = (state_dim + 1) * 2
        self._model_alpha = model_alpha  # weight of the log-variance bound term
        self._max_logvar = nn.Parameter((torch.ones(
            (1, self._output_dim // 2)).float() / 2).to(device),
                                        requires_grad=False)
        self._min_logvar = nn.Parameter((-torch.ones(
            (1, self._output_dim // 2)).float() * 10).to(device),
                                        requires_grad=False)

        self.layer1 = FCLayer(state_dim + action_dim, 200, ensemble_size,
                              Swish())
        self.layer2 = FCLayer(200, 200, ensemble_size, Swish())
        self.layer3 = FCLayer(200, 200, ensemble_size, Swish())
        self.layer4 = FCLayer(200, 200, ensemble_size, Swish())
        self.layer5 = FCLayer(200, self._output_dim, ensemble_size,
                              nn.Identity())
        self.apply(init_weights)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x, return_log_var=False):
        """Return ``(mean, var)``, or ``(mean, logvar)`` if ``return_log_var``."""
        ret = self.layer5(self.layer4(self.layer3(self.layer2(
            self.layer1(x)))))
        half = self._output_dim // 2
        mean = ret[:, :, :half]
        # Softly clamp the log-variance to [_min_logvar, _max_logvar], as in PETS.
        logvar = self._max_logvar - F.softplus(self._max_logvar -
                                               ret[:, :, half:])
        logvar = self._min_logvar + F.softplus(logvar - self._min_logvar)
        return mean, logvar if return_log_var else torch.exp(logvar)

    def loss(self, mean, logvar, labels, use_var_loss=True):
        """Return ``(total_loss, per_network_mse)``.

        With ``use_var_loss`` the squared error is weighted by the inverse
        variance and a log-variance penalty is added (Gaussian NLL up to
        constants); otherwise the plain MSE is used.
        """
        inverse_var = torch.exp(-logvar)
        if use_var_loss:
            mse_loss = torch.mean(torch.mean(torch.pow(mean - labels, 2) *
                                             inverse_var,
                                             dim=-1),
                                  dim=-1)
            var_loss = torch.mean(torch.mean(logvar, dim=-1), dim=-1)
            total_loss = torch.sum(mse_loss) + torch.sum(var_loss)
        else:
            mse_loss = torch.mean(torch.pow(mean - labels, 2), dim=(1, 2))
            total_loss = torch.sum(mse_loss)
        return total_loss, mse_loss

    def train(self, loss=True):
        """Take one optimiser step on ``loss``.

        This method shadows ``nn.Module.train(mode)``. To keep ``.eval()`` and
        ``.train()`` working, a ``bool`` argument is forwarded to the base
        class; a tensor argument performs an optimisation step.
        """
        if isinstance(loss, bool):
            return super(EnsembleModel, self).train(loss)
        self.optimizer.zero_grad()
        # Out-of-place add so the caller's loss tensor is not modified.
        loss = loss + self._model_alpha * torch.sum(
            self._max_logvar) - self._model_alpha * torch.sum(self._min_logvar)
        loss.backward()
        self.optimizer.step()


class EnsembleDynamicsModel:
    """Training and prediction wrapper around ``EnsembleModel``, with
    hold-out based early stopping."""

    def __init__(self, state_dim, action_dim, model_alpha=0.01, num_network=5):
        self._num_network = num_network
        self._state_dim, self._action_dim = state_dim, action_dim
        self.model = EnsembleModel(state_dim,
                                   action_dim,
                                   model_alpha,
                                   ensemble_size=num_network)
        self._epoch_since_last_update = 0

    def _replicate(self, array):
        """Turn ``(n, d)`` data into an ``(num_network, n, d)`` float tensor."""
        tensor = torch.from_numpy(array).float().to(device)
        return tensor[None, :, :].repeat([self._num_network, 1, 1])

    def train(self,
              inputs,
              labels,
              batch_size=64,
              holdout_ratio=0.1,
              max_iter=20):
        """Fit the ensemble on ``(inputs, labels)`` ndarrays.

        A ``holdout_ratio`` fraction is held out for validation. Training
        stops once no member has improved for several epochs, or after
        ``max_iter`` is exceeded.
        """
        permutation = np.random.permutation(inputs.shape[0])
        inputs, labels = inputs[permutation], labels[permutation]
        num_holdout = int(inputs.shape[0] * holdout_ratio)
        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs = self._replicate(inputs[:num_holdout])
        holdout_labels = self._replicate(labels[:num_holdout])

        # Best validation loss seen so far for each member.
        self._snapshots = {i: (None, 1e10) for i in range(self._num_network)}
        # Start the patience counter fresh for every call.
        self._epoch_since_last_update = 0

        for epoch in itertools.count():
            # Each member sees the training data in its own random order.
            train_index = np.vstack([
                np.random.permutation(train_inputs.shape[0])
                for _ in range(self._num_network)
            ])
            for start in range(0, train_inputs.shape[0], batch_size):
                batch_index = train_index[:, start:start + batch_size]
                train_input = torch.from_numpy(
                    train_inputs[batch_index]).float().to(device)
                train_label = torch.from_numpy(
                    train_labels[batch_index]).float().to(device)

                mean, logvar = self.model(train_input, return_log_var=True)
                loss, _ = self.model.loss(mean, logvar, train_label)
                self.model.train(loss)

            with torch.no_grad():
                mean, logvar = self.model(holdout_inputs, return_log_var=True)
                _, holdout_losses = self.model.loss(mean,
                                                    logvar,
                                                    holdout_labels,
                                                    use_var_loss=False)
                holdout_losses = holdout_losses.cpu()
                break_condition = self._save_best(epoch, holdout_losses)
                if break_condition or epoch > max_iter:
                    break

    def _save_best(self, epoch, losses, threshold=0.1):
        """Record per-member improvements; return True when training should stop.

        A member counts as improved when its loss drops by more than
        ``threshold`` relative to its best. Stops after more than five
        consecutive epochs without any improvement.
        """
        updated = False
        for i, current in enumerate(losses):
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            if improvement > threshold:
                self._snapshots[i] = (epoch, current)
                updated = True
        self._epoch_since_last_update = (
            0 if updated else self._epoch_since_last_update + 1)
        return self._epoch_since_last_update > 5

    def predict(self, inputs, batch_size=64):
        """Return ``(mean, var)`` ndarrays of shape ``(num_network, n, out)``.

        ``batch_size`` is accepted for interface compatibility but unused.
        """
        inputs = np.tile(inputs, (self._num_network, 1, 1))
        inputs = torch.tensor(inputs, dtype=torch.float).to(device)
        with torch.no_grad():  # inference only, no autograd graph needed
            mean, var = self.model(inputs, return_log_var=False)
        return mean.cpu().numpy(), var.cpu().numpy()


class FakeEnv:
    """Single-transition simulator backed by the learned ensemble."""

    def __init__(self, model):
        self.model = model

    def step(self, obs, act):
        """Sample one model transition for a single ``(obs, act)`` pair.

        Returns:
            ``(reward, next_obs)``: a scalar and a 1-D array.
        """
        inputs = np.concatenate((obs, act), axis=-1)
        means, variances = self.model.predict(inputs)
        means[:, :, 1:] += obs  # the model predicts the state delta
        stds = np.sqrt(variances)
        ensemble_samples = means + np.random.normal(size=means.shape) * stds

        # Use the prediction of one randomly chosen ensemble member.
        batch_size = means.shape[1]
        models_to_use = np.random.choice(self.model._num_network,
                                         size=batch_size)
        batch_inds = np.arange(batch_size)
        samples = ensemble_samples[models_to_use, batch_inds]
        return samples[0, 0], samples[0, 1:]


class MBPO:
    """Model-Based Policy Optimisation: short model rollouts branched from
    real states augment the data used to train the agent."""

    def __init__(self, env, agent, fake_env, env_pool, model_pool,
                 rollout_length, rollout_batch_size, real_ratio, num_episode):
        self.env = env
        self.agent = agent
        self.fake_env = fake_env
        self.env_pool = env_pool
        self.model_pool = model_pool
        self.rollout_length = rollout_length
        self.rollout_batch_size = rollout_batch_size
        self.real_ratio = real_ratio
        self.num_episode = num_episode

    def rollout_model(self):
        """Branch ``rollout_length``-step model rollouts from sampled real
        states and store them in ``model_pool``."""
        observations, _, _, _, _ = self.env_pool.sample(
            self.rollout_batch_size)
        for obs in observations:
            for _ in range(self.rollout_length):
                action = self.agent.take_action(obs)
                reward, next_obs = self.fake_env.step(obs, action)
                self.model_pool.add(obs, action, reward, next_obs, False)
                obs = next_obs

    def update_agent(self, policy_train_batch_size=64, num_updates=10):
        """Update the agent ``num_updates`` times on mixed batches.

        Each batch has ``real_ratio`` real transitions and the rest from the
        model pool; it is all real data while the model pool is empty.
        """
        env_batch_size = int(policy_train_batch_size * self.real_ratio)
        model_batch_size = policy_train_batch_size - env_batch_size
        for _ in range(num_updates):
            obs, action, reward, next_obs, done = self.env_pool.sample(
                env_batch_size)
            if self.model_pool.size() > 0:
                (model_obs, model_action, model_reward, model_next_obs,
                 model_done) = self.model_pool.sample(model_batch_size)
                obs = np.concatenate((obs, model_obs), axis=0)
                action = np.concatenate((action, model_action), axis=0)
                next_obs = np.concatenate((next_obs, model_next_obs), axis=0)
                reward = np.concatenate((reward, model_reward), axis=0)
                done = np.concatenate((done, model_done), axis=0)
            transition_dict = {
                'states': obs,
                'actions': action,
                'next_states': next_obs,
                'rewards': reward,
                'dones': done
            }
            self.agent.update(transition_dict)

    def train_model(self):
        """Fit the dynamics model on all real transitions.

        Inputs are ``[obs, action]``; labels are ``[reward, next_obs - obs]``.
        """
        obs, action, reward, next_obs, _ = self.env_pool.return_all_samples()
        inputs = np.concatenate((obs, action), axis=-1)
        reward = np.array(reward)
        labels = np.concatenate(
            (np.reshape(reward, (reward.shape[0], -1)), next_obs - obs),
            axis=-1)
        self.fake_env.model.train(inputs, labels)

    def explore(self):
        """Run one real episode with the current agent to seed ``env_pool``.

        Returns:
            The undiscounted episode return.
        """
        obs, done, episode_return = self.env.reset(), False, 0
        while not done:
            action = self.agent.take_action(obs)
            next_obs, reward, done, _ = self.env.step(action)
            self.env_pool.add(obs, action, reward, next_obs, done)
            obs = next_obs
            episode_return += reward
        return episode_return

    def train(self):
        """Run ``num_episode`` episodes and return the list of episode returns."""
        return_list = []
        explore_return = self.explore()  # initial data collection
        print('episode: 1, return: %d' % explore_return)
        return_list.append(explore_return)

        for i_episode in range(self.num_episode - 1):
            obs, done, episode_return = self.env.reset(), False, 0
            step = 0
            while not done:
                # Every 50 real steps, refit the model and refresh the rollouts.
                if step % 50 == 0:
                    self.train_model()
                    self.rollout_model()
                action = self.agent.take_action(obs)
                next_obs, reward, done, _ = self.env.step(action)
                self.env_pool.add(obs, action, reward, next_obs, done)
                obs = next_obs
                episode_return += reward

                self.update_agent()
                step += 1
            return_list.append(episode_return)
            print('episode: %d, return: %d' % (i_episode + 2, episode_return))
        return return_list


class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)``."""

    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition (the oldest is dropped when full)."""
        self.buffer.append((state, action, reward, next_state, done))

    def size(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

    @staticmethod
    def _unpack(transitions):
        """Convert a list of transitions to column form.

        States and next states become ndarrays; the rest stay tuples.
        """
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def sample(self, batch_size):
        """Return ``batch_size`` transitions sampled without replacement.

        If fewer are stored, the whole buffer is returned.
        """
        if batch_size > len(self.buffer):
            return self.return_all_samples()
        return self._unpack(random.sample(self.buffer, batch_size))

    def return_all_samples(self):
        """Return every stored transition, in insertion order."""
        return self._unpack(list(self.buffer))
(Triggered internally at ../torch/csrc/utils/tensor_new.cpp:201.)\n" 603 | ] 604 | }, 605 | { 606 | "name": "stdout", 607 | "output_type": "stream", 608 | "text": [ 609 | "episode: 1, return: -1617\n", 610 | "episode: 2, return: -1463\n", 611 | "episode: 3, return: -1407\n", 612 | "episode: 4, return: -929\n", 613 | "episode: 5, return: -860\n", 614 | "episode: 6, return: -643\n", 615 | "episode: 7, return: -128\n", 616 | "episode: 8, return: -368\n", 617 | "episode: 9, return: -118\n", 618 | "episode: 10, return: -123\n", 619 | "episode: 11, return: -122\n", 620 | "episode: 12, return: -118\n", 621 | "episode: 13, return: -119\n", 622 | "episode: 14, return: -119\n", 623 | "episode: 15, return: -121\n", 624 | "episode: 16, return: -123\n", 625 | "episode: 17, return: 0\n", 626 | "episode: 18, return: -125\n", 627 | "episode: 19, return: -126\n", 628 | "episode: 20, return: -243\n" 629 | ] 630 | }, 631 | { 632 | "data": { 633 | "image/png": "\n", 634 | "text/plain": [ 635 | "
" 636 | ] 637 | }, 638 | "metadata": { 639 | "needs_background": "light" 640 | }, 641 | "output_type": "display_data" 642 | } 643 | ], 644 | "source": [ 645 | "real_ratio = 0.5\n", 646 | "env_name = 'Pendulum-v0'\n", 647 | "env = gym.make(env_name)\n", 648 | "num_episodes = 20\n", 649 | "actor_lr = 5e-4\n", 650 | "critic_lr = 5e-3\n", 651 | "alpha_lr = 1e-3\n", 652 | "hidden_dim = 128\n", 653 | "gamma = 0.98\n", 654 | "tau = 0.005 # 软更新参数\n", 655 | "buffer_size = 10000\n", 656 | "target_entropy = -1\n", 657 | "model_alpha = 0.01 # 模型损失函数中的加权权重\n", 658 | "state_dim = env.observation_space.shape[0]\n", 659 | "action_dim = env.action_space.shape[0]\n", 660 | "action_bound = env.action_space.high[0] # 动作最大值\n", 661 | "\n", 662 | "rollout_batch_size = 1000\n", 663 | "rollout_length = 1 # 推演长度k,推荐更多尝试\n", 664 | "model_pool_size = rollout_batch_size * rollout_length\n", 665 | "\n", 666 | "agent = SAC(state_dim, hidden_dim, action_dim, action_bound, actor_lr,\n", 667 | " critic_lr, alpha_lr, target_entropy, tau, gamma)\n", 668 | "model = EnsembleDynamicsModel(state_dim, action_dim, model_alpha)\n", 669 | "fake_env = FakeEnv(model)\n", 670 | "env_pool = ReplayBuffer(buffer_size)\n", 671 | "model_pool = ReplayBuffer(model_pool_size)\n", 672 | "mbpo = MBPO(env, agent, fake_env, env_pool, model_pool, rollout_length,\n", 673 | " rollout_batch_size, real_ratio, num_episodes)\n", 674 | "\n", 675 | "return_list = mbpo.train()\n", 676 | "\n", 677 | "episodes_list = list(range(len(return_list)))\n", 678 | "plt.plot(episodes_list, return_list)\n", 679 | "plt.xlabel('Episodes')\n", 680 | "plt.ylabel('Returns')\n", 681 | "plt.title('MBPO on {}'.format(env_name))\n", 682 | "plt.show()\n", 683 | "\n", 684 | "# episode: 1, return: -1083\n", 685 | "# episode: 2, return: -1324\n", 686 | "# episode: 3, return: -979\n", 687 | "# episode: 4, return: -130\n", 688 | "# episode: 5, return: -246\n", 689 | "# episode: 6, return: -2\n", 690 | "# episode: 7, return: -239\n", 691 | "# episode: 8, 
return: -2\n", 692 | "# episode: 9, return: -122\n", 693 | "# episode: 10, return: -236\n", 694 | "# episode: 11, return: -238\n", 695 | "# episode: 12, return: -2\n", 696 | "# episode: 13, return: -127\n", 697 | "# episode: 14, return: -128\n", 698 | "# episode: 15, return: -125\n", 699 | "# episode: 16, return: -124\n", 700 | "# episode: 17, return: -125\n", 701 | "# episode: 18, return: -247\n", 702 | "# episode: 19, return: -127\n", 703 | "# episode: 20, return: -129" 704 | ] 705 | } 706 | ], 707 | "metadata": { 708 | "colab": { 709 | "collapsed_sections": [], 710 | "name": "第17章-基于模型的策略优化.ipynb", 711 | "provenance": [] 712 | }, 713 | "kernelspec": { 714 | "display_name": "Python 3", 715 | "language": "python", 716 | "name": "python3" 717 | }, 718 | "language_info": { 719 | "codemirror_mode": { 720 | "name": "ipython", 721 | "version": 3 722 | }, 723 | "file_extension": ".py", 724 | "mimetype": "text/x-python", 725 | "name": "python", 726 | "nbconvert_exporter": "python", 727 | "pygments_lexer": "ipython3", 728 | "version": "3.7.6" 729 | } 730 | }, 731 | "nbformat": 4, 732 | "nbformat_minor": 1 733 | } 734 | -------------------------------------------------------------------------------- /第19章-目标导向的强化学习.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "executionInfo": { 8 | "elapsed": 6862, 9 | "status": "ok", 10 | "timestamp": 1650010812842, 11 | "user": { 12 | "displayName": "Sam Lu", 13 | "userId": "15789059763790170725" 14 | }, 15 | "user_tz": -480 16 | }, 17 | "id": "98nP9Uh9GUTL" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import torch\n", 22 | "import torch.nn.functional as F\n", 23 | "import numpy as np\n", 24 | "import random\n", 25 | "from tqdm import tqdm\n", 26 | "import collections\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "\n", 29 | "\n", 30 | "class WorldEnv:\n", 31 | " def __init__(self):\n", 
32 | " self.distance_threshold = 0.15\n", 33 | " self.action_bound = 1\n", 34 | "\n", 35 | " def reset(self): # 重置环境\n", 36 | " # 生成一个目标状态, 坐标范围是[3.5~4.5, 3.5~4.5]\n", 37 | " self.goal = np.array(\n", 38 | " [4 + random.uniform(-0.5, 0.5), 4 + random.uniform(-0.5, 0.5)])\n", 39 | " self.state = np.array([0, 0]) # 初始状态\n", 40 | " self.count = 0\n", 41 | " return np.hstack((self.state, self.goal))\n", 42 | "\n", 43 | " def step(self, action):\n", 44 | " action = np.clip(action, -self.action_bound, self.action_bound)\n", 45 | " x = max(0, min(5, self.state[0] + action[0]))\n", 46 | " y = max(0, min(5, self.state[1] + action[1]))\n", 47 | " self.state = np.array([x, y])\n", 48 | " self.count += 1\n", 49 | "\n", 50 | " dis = np.sqrt(np.sum(np.square(self.state - self.goal)))\n", 51 | " reward = -1.0 if dis > self.distance_threshold else 0\n", 52 | " if dis <= self.distance_threshold or self.count == 50:\n", 53 | " done = True\n", 54 | " else:\n", 55 | " done = False\n", 56 | "\n", 57 | " return np.hstack((self.state, self.goal)), reward, done" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": { 64 | "executionInfo": { 65 | "elapsed": 3, 66 | "status": "ok", 67 | "timestamp": 1650010812843, 68 | "user": { 69 | "displayName": "Sam Lu", 70 | "userId": "15789059763790170725" 71 | }, 72 | "user_tz": -480 73 | }, 74 | "id": "hhrV6UDwGUTP" 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "class PolicyNet(torch.nn.Module):\n", 79 | " def __init__(self, state_dim, hidden_dim, action_dim, action_bound):\n", 80 | " super(PolicyNet, self).__init__()\n", 81 | " self.fc1 = torch.nn.Linear(state_dim, hidden_dim)\n", 82 | " self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)\n", 83 | " self.fc3 = torch.nn.Linear(hidden_dim, action_dim)\n", 84 | " self.action_bound = action_bound # action_bound是环境可以接受的动作最大值\n", 85 | "\n", 86 | " def forward(self, x):\n", 87 | " x = F.relu(self.fc2(F.relu(self.fc1(x))))\n", 88 | " return torch.tanh(self.fc3(x)) * 
self.action_bound\n", 89 | "\n", 90 | "\n", 91 | "class QValueNet(torch.nn.Module):\n", 92 | " def __init__(self, state_dim, hidden_dim, action_dim):\n", 93 | " super(QValueNet, self).__init__()\n", 94 | " self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)\n", 95 | " self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)\n", 96 | " self.fc3 = torch.nn.Linear(hidden_dim, 1)\n", 97 | "\n", 98 | " def forward(self, x, a):\n", 99 | " cat = torch.cat([x, a], dim=1) # 拼接状态和动作\n", 100 | " x = F.relu(self.fc2(F.relu(self.fc1(cat))))\n", 101 | " return self.fc3(x)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 3, 107 | "metadata": { 108 | "executionInfo": { 109 | "elapsed": 2, 110 | "status": "ok", 111 | "timestamp": 1650010819329, 112 | "user": { 113 | "displayName": "Sam Lu", 114 | "userId": "15789059763790170725" 115 | }, 116 | "user_tz": -480 117 | }, 118 | "id": "bxiqOl_vGUTR" 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "class DDPG:\n", 123 | " ''' DDPG算法 '''\n", 124 | " def __init__(self, state_dim, hidden_dim, action_dim, action_bound,\n", 125 | " actor_lr, critic_lr, sigma, tau, gamma, device):\n", 126 | " self.action_dim = action_dim\n", 127 | " self.actor = PolicyNet(state_dim, hidden_dim, action_dim,\n", 128 | " action_bound).to(device)\n", 129 | " self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)\n", 130 | " self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim,\n", 131 | " action_bound).to(device)\n", 132 | " self.target_critic = QValueNet(state_dim, hidden_dim,\n", 133 | " action_dim).to(device)\n", 134 | " # 初始化目标价值网络并使其参数和价值网络一样\n", 135 | " self.target_critic.load_state_dict(self.critic.state_dict())\n", 136 | " # 初始化目标策略网络并使其参数和策略网络一样\n", 137 | " self.target_actor.load_state_dict(self.actor.state_dict())\n", 138 | " self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),\n", 139 | " lr=actor_lr)\n", 140 | " self.critic_optimizer = 
torch.optim.Adam(self.critic.parameters(),\n", 141 | " lr=critic_lr)\n", 142 | " self.gamma = gamma\n", 143 | " self.sigma = sigma # 高斯噪声的标准差,均值直接设为0\n", 144 | " self.tau = tau # 目标网络软更新参数\n", 145 | " self.action_bound = action_bound\n", 146 | " self.device = device\n", 147 | "\n", 148 | " def take_action(self, state):\n", 149 | " state = torch.tensor([state], dtype=torch.float).to(self.device)\n", 150 | " action = self.actor(state).detach().cpu().numpy()[0]\n", 151 | " # 给动作添加噪声,增加探索\n", 152 | " action = action + self.sigma * np.random.randn(self.action_dim)\n", 153 | " return action\n", 154 | "\n", 155 | " def soft_update(self, net, target_net):\n", 156 | " for param_target, param in zip(target_net.parameters(),\n", 157 | " net.parameters()):\n", 158 | " param_target.data.copy_(param_target.data * (1.0 - self.tau) +\n", 159 | " param.data * self.tau)\n", 160 | "\n", 161 | " def update(self, transition_dict):\n", 162 | " states = torch.tensor(transition_dict['states'],\n", 163 | " dtype=torch.float).to(self.device)\n", 164 | " actions = torch.tensor(transition_dict['actions'],\n", 165 | " dtype=torch.float).to(self.device)\n", 166 | " rewards = torch.tensor(transition_dict['rewards'],\n", 167 | " dtype=torch.float).view(-1, 1).to(self.device)\n", 168 | " next_states = torch.tensor(transition_dict['next_states'],\n", 169 | " dtype=torch.float).to(self.device)\n", 170 | " dones = torch.tensor(transition_dict['dones'],\n", 171 | " dtype=torch.float).view(-1, 1).to(self.device)\n", 172 | "\n", 173 | " next_q_values = self.target_critic(next_states,\n", 174 | " self.target_actor(next_states))\n", 175 | " q_targets = rewards + self.gamma * next_q_values * (1 - dones)\n", 176 | " # MSE损失函数\n", 177 | " critic_loss = torch.mean(\n", 178 | " F.mse_loss(self.critic(states, actions), q_targets))\n", 179 | " self.critic_optimizer.zero_grad()\n", 180 | " critic_loss.backward()\n", 181 | " self.critic_optimizer.step()\n", 182 | "\n", 183 | " # 策略网络就是为了使Q值最大化\n", 184 | " 
actor_loss = -torch.mean(self.critic(states, self.actor(states)))\n", 185 | " self.actor_optimizer.zero_grad()\n", 186 | " actor_loss.backward()\n", 187 | " self.actor_optimizer.step()\n", 188 | "\n", 189 | " self.soft_update(self.actor, self.target_actor) # 软更新策略网络\n", 190 | " self.soft_update(self.critic, self.target_critic) # 软更新价值网络" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 4, 196 | "metadata": { 197 | "executionInfo": { 198 | "elapsed": 303, 199 | "status": "ok", 200 | "timestamp": 1650010821234, 201 | "user": { 202 | "displayName": "Sam Lu", 203 | "userId": "15789059763790170725" 204 | }, 205 | "user_tz": -480 206 | }, 207 | "id": "aw60NZwLGUTS" 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "class Trajectory:\n", 212 | " ''' 用来记录一条完整轨迹 '''\n", 213 | " def __init__(self, init_state):\n", 214 | " self.states = [init_state]\n", 215 | " self.actions = []\n", 216 | " self.rewards = []\n", 217 | " self.dones = []\n", 218 | " self.length = 0\n", 219 | "\n", 220 | " def store_step(self, action, state, reward, done):\n", 221 | " self.actions.append(action)\n", 222 | " self.states.append(state)\n", 223 | " self.rewards.append(reward)\n", 224 | " self.dones.append(done)\n", 225 | " self.length += 1\n", 226 | "\n", 227 | "\n", 228 | "class ReplayBuffer_Trajectory:\n", 229 | " ''' 存储轨迹的经验回放池 '''\n", 230 | " def __init__(self, capacity):\n", 231 | " self.buffer = collections.deque(maxlen=capacity)\n", 232 | "\n", 233 | " def add_trajectory(self, trajectory):\n", 234 | " self.buffer.append(trajectory)\n", 235 | "\n", 236 | " def size(self):\n", 237 | " return len(self.buffer)\n", 238 | "\n", 239 | " def sample(self, batch_size, use_her, dis_threshold=0.15, her_ratio=0.8):\n", 240 | " batch = dict(states=[],\n", 241 | " actions=[],\n", 242 | " next_states=[],\n", 243 | " rewards=[],\n", 244 | " dones=[])\n", 245 | " for _ in range(batch_size):\n", 246 | " traj = random.sample(self.buffer, 1)[0]\n", 247 | " step_state = 
np.random.randint(traj.length)\n", 248 | " state = traj.states[step_state]\n", 249 | " next_state = traj.states[step_state + 1]\n", 250 | " action = traj.actions[step_state]\n", 251 | " reward = traj.rewards[step_state]\n", 252 | " done = traj.dones[step_state]\n", 253 | "\n", 254 | " if use_her and np.random.uniform() <= her_ratio:\n", 255 | " step_goal = np.random.randint(step_state + 1, traj.length + 1)\n", 256 | " goal = traj.states[step_goal][:2] # 使用HER算法的future方案设置目标\n", 257 | " dis = np.sqrt(np.sum(np.square(next_state[:2] - goal)))\n", 258 | " reward = -1.0 if dis > dis_threshold else 0\n", 259 | " done = False if dis > dis_threshold else True\n", 260 | " state = np.hstack((state[:2], goal))\n", 261 | " next_state = np.hstack((next_state[:2], goal))\n", 262 | "\n", 263 | " batch['states'].append(state)\n", 264 | " batch['next_states'].append(next_state)\n", 265 | " batch['actions'].append(action)\n", 266 | " batch['rewards'].append(reward)\n", 267 | " batch['dones'].append(done)\n", 268 | "\n", 269 | " batch['states'] = np.array(batch['states'])\n", 270 | " batch['next_states'] = np.array(batch['next_states'])\n", 271 | " batch['actions'] = np.array(batch['actions'])\n", 272 | " return batch" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 6, 278 | "metadata": { 279 | "colab": { 280 | "base_uri": "https://localhost:8080/", 281 | "height": 506 282 | }, 283 | "executionInfo": { 284 | "elapsed": 890109, 285 | "status": "ok", 286 | "timestamp": 1650011748151, 287 | "user": { 288 | "displayName": "Sam Lu", 289 | "userId": "15789059763790170725" 290 | }, 291 | "user_tz": -480 292 | }, 293 | "id": "0wLycs-3GUTT", 294 | "outputId": "a73c3a0d-d7ac-486b-87aa-b141aa9d4d0c" 295 | }, 296 | "outputs": [ 297 | { 298 | "name": "stderr", 299 | "output_type": "stream", 300 | "text": [ 301 | "Iteration 0: 0%| | 0/200 [00:00" 319 | ] 320 | }, 321 | "metadata": { 322 | "needs_background": "light" 323 | }, 324 | "output_type": "display_data" 325 | } 
326 | ], 327 | "source": [ 328 | "actor_lr = 1e-3\n", 329 | "critic_lr = 1e-3\n", 330 | "hidden_dim = 128\n", 331 | "state_dim = 4\n", 332 | "action_dim = 2\n", 333 | "action_bound = 1\n", 334 | "sigma = 0.1\n", 335 | "tau = 0.005\n", 336 | "gamma = 0.98\n", 337 | "num_episodes = 2000\n", 338 | "n_train = 20\n", 339 | "batch_size = 256\n", 340 | "minimal_episodes = 200\n", 341 | "buffer_size = 10000\n", 342 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\n", 343 | " \"cpu\")\n", 344 | "\n", 345 | "random.seed(0)\n", 346 | "np.random.seed(0)\n", 347 | "torch.manual_seed(0)\n", 348 | "env = WorldEnv()\n", 349 | "replay_buffer = ReplayBuffer_Trajectory(buffer_size)\n", 350 | "agent = DDPG(state_dim, hidden_dim, action_dim, action_bound, actor_lr,\n", 351 | " critic_lr, sigma, tau, gamma, device)\n", 352 | "\n", 353 | "return_list = []\n", 354 | "for i in range(10):\n", 355 | " with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:\n", 356 | " for i_episode in range(int(num_episodes / 10)):\n", 357 | " episode_return = 0\n", 358 | " state = env.reset()\n", 359 | " traj = Trajectory(state)\n", 360 | " done = False\n", 361 | " while not done:\n", 362 | " action = agent.take_action(state)\n", 363 | " state, reward, done = env.step(action)\n", 364 | " episode_return += reward\n", 365 | " traj.store_step(action, state, reward, done)\n", 366 | " replay_buffer.add_trajectory(traj)\n", 367 | " return_list.append(episode_return)\n", 368 | " if replay_buffer.size() >= minimal_episodes:\n", 369 | " for _ in range(n_train):\n", 370 | " transition_dict = replay_buffer.sample(batch_size, True)\n", 371 | " agent.update(transition_dict)\n", 372 | " if (i_episode + 1) % 10 == 0:\n", 373 | " pbar.set_postfix({\n", 374 | " 'episode':\n", 375 | " '%d' % (num_episodes / 10 * i + i_episode + 1),\n", 376 | " 'return':\n", 377 | " '%.3f' % np.mean(return_list[-10:])\n", 378 | " })\n", 379 | " pbar.update(1)\n", 380 | "\n", 381 | 
"episodes_list = list(range(len(return_list)))\n", 382 | "plt.plot(episodes_list, return_list)\n", 383 | "plt.xlabel('Episodes')\n", 384 | "plt.ylabel('Returns')\n", 385 | "plt.title('DDPG with HER on {}'.format('GridWorld'))\n", 386 | "plt.show()\n", 387 | "\n", 388 | "# Iteration 0: 100%|██████████| 200/200 [00:03<00:00, 58.91it/s, episode=200,\n", 389 | "# return=-50.000]\n", 390 | "# Iteration 1: 100%|██████████| 200/200 [01:17<00:00, 2.56it/s, episode=400,\n", 391 | "# return=-4.200]\n", 392 | "# Iteration 2: 100%|██████████| 200/200 [01:18<00:00, 2.56it/s, episode=600,\n", 393 | "# return=-4.700]\n", 394 | "# Iteration 3: 100%|██████████| 200/200 [01:18<00:00, 2.56it/s, episode=800,\n", 395 | "# return=-4.300]\n", 396 | "# Iteration 4: 100%|██████████| 200/200 [01:17<00:00, 2.57it/s, episode=1000,\n", 397 | "# return=-3.800]\n", 398 | "# Iteration 5: 100%|██████████| 200/200 [01:17<00:00, 2.57it/s, episode=1200,\n", 399 | "# return=-4.800]\n", 400 | "# Iteration 6: 100%|██████████| 200/200 [01:18<00:00, 2.54it/s, episode=1400,\n", 401 | "# return=-4.500]\n", 402 | "# Iteration 7: 100%|██████████| 200/200 [01:19<00:00, 2.52it/s, episode=1600,\n", 403 | "# return=-4.400]\n", 404 | "# Iteration 8: 100%|██████████| 200/200 [01:18<00:00, 2.55it/s, episode=1800,\n", 405 | "# return=-4.200]\n", 406 | "# Iteration 9: 100%|██████████| 200/200 [01:18<00:00, 2.55it/s, episode=2000,\n", 407 | "# return=-4.300]" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": { 414 | "id": "Cc0b1OlFGUTV" 415 | }, 416 | "outputs": [], 417 | "source": [ 418 | "random.seed(0)\n", 419 | "np.random.seed(0)\n", 420 | "torch.manual_seed(0)\n", 421 | "env = WorldEnv()\n", 422 | "replay_buffer = ReplayBuffer_Trajectory(buffer_size)\n", 423 | "agent = DDPG(state_dim, hidden_dim, action_dim, action_bound, actor_lr,\n", 424 | " critic_lr, sigma, tau, gamma, device)\n", 425 | "\n", 426 | "return_list = []\n", 427 | "for i in range(10):\n", 428 | " 
with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:\n", 429 | " for i_episode in range(int(num_episodes / 10)):\n", 430 | " episode_return = 0\n", 431 | " state = env.reset()\n", 432 | " traj = Trajectory(state)\n", 433 | " done = False\n", 434 | " while not done:\n", 435 | " action = agent.take_action(state)\n", 436 | " state, reward, done = env.step(action)\n", 437 | " episode_return += reward\n", 438 | " traj.store_step(action, state, reward, done)\n", 439 | " replay_buffer.add_trajectory(traj)\n", 440 | " return_list.append(episode_return)\n", 441 | " if replay_buffer.size() >= minimal_episodes:\n", 442 | " for _ in range(n_train):\n", 443 | " # 和使用HER训练的唯一区别\n", 444 | " transition_dict = replay_buffer.sample(batch_size, False)\n", 445 | " agent.update(transition_dict)\n", 446 | " if (i_episode + 1) % 10 == 0:\n", 447 | " pbar.set_postfix({\n", 448 | " 'episode':\n", 449 | " '%d' % (num_episodes / 10 * i + i_episode + 1),\n", 450 | " 'return':\n", 451 | " '%.3f' % np.mean(return_list[-10:])\n", 452 | " })\n", 453 | " pbar.update(1)\n", 454 | "\n", 455 | "episodes_list = list(range(len(return_list)))\n", 456 | "plt.plot(episodes_list, return_list)\n", 457 | "plt.xlabel('Episodes')\n", 458 | "plt.ylabel('Returns')\n", 459 | "plt.title('DDPG without HER on {}'.format('GridWorld'))\n", 460 | "plt.show()\n", 461 | "\n", 462 | "# Iteration 0: 100%|██████████| 200/200 [00:03<00:00, 62.82it/s, episode=200,\n", 463 | "# return=-50.000]\n", 464 | "# Iteration 1: 100%|██████████| 200/200 [00:39<00:00, 5.01it/s, episode=400,\n", 465 | "# return=-50.000]\n", 466 | "# Iteration 2: 100%|██████████| 200/200 [00:41<00:00, 4.83it/s, episode=600,\n", 467 | "# return=-50.000]\n", 468 | "# Iteration 3: 100%|██████████| 200/200 [00:41<00:00, 4.82it/s, episode=800,\n", 469 | "# return=-50.000]\n", 470 | "# Iteration 4: 100%|██████████| 200/200 [00:41<00:00, 4.81it/s, episode=1000,\n", 471 | "# return=-50.000]\n", 472 | "# Iteration 5: 100%|██████████| 200/200 
[00:41<00:00, 4.79it/s, episode=1200,\n", 473 | "# return=-50.000]\n", 474 | "# Iteration 6: 100%|██████████| 200/200 [00:42<00:00, 4.76it/s, episode=1400,\n", 475 | "# return=-45.500]\n", 476 | "# Iteration 7: 100%|██████████| 200/200 [00:41<00:00, 4.80it/s, episode=1600,\n", 477 | "# return=-42.600]\n", 478 | "# Iteration 8: 100%|██████████| 200/200 [00:40<00:00, 4.92it/s, episode=1800,\n", 479 | "# return=-4.800]\n", 480 | "# Iteration 9: 100%|██████████| 200/200 [00:40<00:00, 4.99it/s, episode=2000,\n", 481 | "# return=-4.800]" 482 | ] 483 | } 484 | ], 485 | "metadata": { 486 | "accelerator": "GPU", 487 | "colab": { 488 | "collapsed_sections": [], 489 | "name": "第19章-目标导向的强化学习.ipynb", 490 | "provenance": [] 491 | }, 492 | "kernelspec": { 493 | "display_name": "Python 3", 494 | "language": "python", 495 | "name": "python3" 496 | }, 497 | "language_info": { 498 | "codemirror_mode": { 499 | "name": "ipython", 500 | "version": 3 501 | }, 502 | "file_extension": ".py", 503 | "mimetype": "text/x-python", 504 | "name": "python", 505 | "nbconvert_exporter": "python", 506 | "pygments_lexer": "ipython3", 507 | "version": "3.7.6" 508 | } 509 | }, 510 | "nbformat": 4, 511 | "nbformat_minor": 1 512 | } 513 | -------------------------------------------------------------------------------- /第20章-多智能体强化学习入门.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "executionInfo": { 11 | "elapsed": 10107, 12 | "status": "ok", 13 | "timestamp": 1650012696153, 14 | "user": { 15 | "displayName": "Sam Lu", 16 | "userId": "15789059763790170725" 17 | }, 18 | "user_tz": -480 19 | }, 20 | "id": "-_L_dhppItIk", 21 | "outputId": "6c1eecf0-fd72-4d13-ad05-192463636129" 22 | }, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Cloning into 
'ma-gym'...\n", 29 | "remote: Enumerating objects: 1072, done.\u001b[K\n", 30 | "remote: Counting objects: 100% (141/141), done.\u001b[K\n", 31 | "remote: Compressing objects: 100% (131/131), done.\u001b[K\n", 32 | "remote: Total 1072 (delta 61), reused 31 (delta 6), pack-reused 931\u001b[K\n", 33 | "Receiving objects: 100% (1072/1072), 3.74 MiB | 4.47 MiB/s, done.\n", 34 | "Resolving deltas: 100% (524/524), done.\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "import torch\n", 40 | "import torch.nn.functional as F\n", 41 | "import numpy as np\n", 42 | "import rl_utils\n", 43 | "from tqdm import tqdm\n", 44 | "import matplotlib.pyplot as plt\n", 45 | "\n", 46 | "! git clone https://github.com/boyu-ai/ma-gym.git\n", 47 | "import sys\n", 48 | "sys.path.append(\"./ma-gym\")\n", 49 | "from ma_gym.envs.combat.combat import Combat" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": { 56 | "id": "HdZSfYc7ItIn" 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "class PolicyNet(torch.nn.Module):\n", 61 | " def __init__(self, state_dim, hidden_dim, action_dim):\n", 62 | " super(PolicyNet, self).__init__()\n", 63 | " self.fc1 = torch.nn.Linear(state_dim, hidden_dim)\n", 64 | " self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)\n", 65 | " self.fc3 = torch.nn.Linear(hidden_dim, action_dim)\n", 66 | "\n", 67 | " def forward(self, x):\n", 68 | " x = F.relu(self.fc2(F.relu(self.fc1(x))))\n", 69 | " return F.softmax(self.fc3(x), dim=1)\n", 70 | "\n", 71 | "\n", 72 | "class ValueNet(torch.nn.Module):\n", 73 | " def __init__(self, state_dim, hidden_dim):\n", 74 | " super(ValueNet, self).__init__()\n", 75 | " self.fc1 = torch.nn.Linear(state_dim, hidden_dim)\n", 76 | " self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)\n", 77 | " self.fc3 = torch.nn.Linear(hidden_dim, 1)\n", 78 | "\n", 79 | " def forward(self, x):\n", 80 | " x = F.relu(self.fc2(F.relu(self.fc1(x))))\n", 81 | " return self.fc3(x)\n", 82 | "\n", 83 | "\n", 84 | "class PPO:\n", 
85 | " ''' PPO算法,采用截断方式 '''\n", 86 | " def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,\n", 87 | " lmbda, eps, gamma, device):\n", 88 | " self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)\n", 89 | " self.critic = ValueNet(state_dim, hidden_dim).to(device)\n", 90 | " self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),\n", 91 | " lr=actor_lr)\n", 92 | " self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),\n", 93 | " lr=critic_lr)\n", 94 | " self.gamma = gamma\n", 95 | " self.lmbda = lmbda\n", 96 | " self.eps = eps # PPO中截断范围的参数\n", 97 | " self.device = device\n", 98 | "\n", 99 | " def take_action(self, state):\n", 100 | " state = torch.tensor([state], dtype=torch.float).to(self.device)\n", 101 | " probs = self.actor(state)\n", 102 | " action_dist = torch.distributions.Categorical(probs)\n", 103 | " action = action_dist.sample()\n", 104 | " return action.item()\n", 105 | "\n", 106 | " def update(self, transition_dict):\n", 107 | " states = torch.tensor(transition_dict['states'],\n", 108 | " dtype=torch.float).to(self.device)\n", 109 | " actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(\n", 110 | " self.device)\n", 111 | " rewards = torch.tensor(transition_dict['rewards'],\n", 112 | " dtype=torch.float).view(-1, 1).to(self.device)\n", 113 | " next_states = torch.tensor(transition_dict['next_states'],\n", 114 | " dtype=torch.float).to(self.device)\n", 115 | " dones = torch.tensor(transition_dict['dones'],\n", 116 | " dtype=torch.float).view(-1, 1).to(self.device)\n", 117 | " td_target = rewards + self.gamma * self.critic(next_states) * (1 -\n", 118 | " dones)\n", 119 | " td_delta = td_target - self.critic(states)\n", 120 | " advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,\n", 121 | " td_delta.cpu()).to(self.device)\n", 122 | " old_log_probs = torch.log(self.actor(states).gather(1,\n", 123 | " actions)).detach()\n", 124 | "\n", 125 | " log_probs = 
torch.log(self.actor(states).gather(1, actions))\n", 126 | " ratio = torch.exp(log_probs - old_log_probs)\n", 127 | " surr1 = ratio * advantage\n", 128 | " surr2 = torch.clamp(ratio, 1 - self.eps,\n", 129 | " 1 + self.eps) * advantage # 截断\n", 130 | " actor_loss = torch.mean(-torch.min(surr1, surr2)) # PPO损失函数\n", 131 | " critic_loss = torch.mean(\n", 132 | " F.mse_loss(self.critic(states), td_target.detach()))\n", 133 | " self.actor_optimizer.zero_grad()\n", 134 | " self.critic_optimizer.zero_grad()\n", 135 | " actor_loss.backward()\n", 136 | " critic_loss.backward()\n", 137 | " self.actor_optimizer.step()\n", 138 | " self.critic_optimizer.step()" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "colab": { 146 | "base_uri": "https://localhost:8080/" 147 | }, 148 | "executionInfo": { 149 | "elapsed": 2805926, 150 | "status": "ok", 151 | "timestamp": 1649963248923, 152 | "user": { 153 | "displayName": "Sam Lu", 154 | "userId": "15789059763790170725" 155 | }, 156 | "user_tz": -480 157 | }, 158 | "id": "t8FsMOFPItIp", 159 | "outputId": "2f453795-508c-45ff-91e1-fb8b81eb5e9c" 160 | }, 161 | "outputs": [ 162 | { 163 | "name": "stderr", 164 | "output_type": "stream", 165 | "text": [ 166 | "/usr/local/lib/python3.7/dist-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", 167 | " warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n", 168 | "Iteration 0: 100%|██████████| 10000/10000 [07:17<00:00, 22.85it/s, episode=10000, return=0.310]\n", 169 | "Iteration 1: 100%|██████████| 10000/10000 [05:43<00:00, 29.08it/s, episode=20000, return=0.370]\n", 170 | "Iteration 2: 100%|██████████| 10000/10000 [05:30<00:00, 30.26it/s, episode=30000, return=0.560]\n", 171 | "Iteration 3: 100%|██████████| 10000/10000 [04:54<00:00, 33.96it/s, episode=40000, return=0.670]\n", 172 | "Iteration 4: 100%|██████████| 10000/10000 [04:20<00:00, 38.46it/s, 
episode=50000, return=0.670]\n", 173 | "Iteration 5: 100%|██████████| 10000/10000 [03:52<00:00, 43.09it/s, episode=60000, return=0.620]\n", 174 | "Iteration 6: 100%|██████████| 10000/10000 [03:55<00:00, 42.53it/s, episode=70000, return=0.610]\n", 175 | "Iteration 7: 100%|██████████| 10000/10000 [03:40<00:00, 45.26it/s, episode=80000, return=0.640]\n", 176 | "Iteration 8: 100%|██████████| 10000/10000 [03:48<00:00, 43.81it/s, episode=90000, return=0.650]\n", 177 | "Iteration 9: 100%|██████████| 10000/10000 [03:42<00:00, 44.91it/s, episode=100000, return=0.770]\n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "actor_lr = 3e-4\n", 183 | "critic_lr = 1e-3\n", 184 | "num_episodes = 100000\n", 185 | "hidden_dim = 64\n", 186 | "gamma = 0.99\n", 187 | "lmbda = 0.97\n", 188 | "eps = 0.2\n", 189 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\n", 190 | " \"cpu\")\n", 191 | "\n", 192 | "team_size = 2\n", 193 | "grid_size = (15, 15)\n", 194 | "#创建Combat环境,格子世界的大小为15x15,己方智能体和敌方智能体数量都为2\n", 195 | "env = Combat(grid_shape=grid_size, n_agents=team_size, n_opponents=team_size)\n", 196 | "\n", 197 | "state_dim = env.observation_space[0].shape[0]\n", 198 | "action_dim = env.action_space[0].n\n", 199 | "#两个智能体共享同一个策略\n", 200 | "agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, eps,\n", 201 | " gamma, device)\n", 202 | "\n", 203 | "win_list = []\n", 204 | "for i in range(10):\n", 205 | " with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:\n", 206 | " for i_episode in range(int(num_episodes / 10)):\n", 207 | " transition_dict_1 = {\n", 208 | " 'states': [],\n", 209 | " 'actions': [],\n", 210 | " 'next_states': [],\n", 211 | " 'rewards': [],\n", 212 | " 'dones': []\n", 213 | " }\n", 214 | " transition_dict_2 = {\n", 215 | " 'states': [],\n", 216 | " 'actions': [],\n", 217 | " 'next_states': [],\n", 218 | " 'rewards': [],\n", 219 | " 'dones': []\n", 220 | " }\n", 221 | " s = env.reset()\n", 222 | " 
terminal = False\n", 223 | " while not terminal:\n", 224 | " a_1 = agent.take_action(s[0])\n", 225 | " a_2 = agent.take_action(s[1])\n", 226 | " next_s, r, done, info = env.step([a_1, a_2])\n", 227 | " transition_dict_1['states'].append(s[0])\n", 228 | " transition_dict_1['actions'].append(a_1)\n", 229 | " transition_dict_1['next_states'].append(next_s[0])\n", 230 | " transition_dict_1['rewards'].append(\n", 231 | " r[0] + 100 if info['win'] else r[0] - 0.1)\n", 232 | " transition_dict_1['dones'].append(False)\n", 233 | " transition_dict_2['states'].append(s[1])\n", 234 | " transition_dict_2['actions'].append(a_2)\n", 235 | " transition_dict_2['next_states'].append(next_s[1])\n", 236 | " transition_dict_2['rewards'].append(\n", 237 | " r[1] + 100 if info['win'] else r[1] - 0.1)\n", 238 | " transition_dict_2['dones'].append(False)\n", 239 | " s = next_s\n", 240 | " terminal = all(done)\n", 241 | " win_list.append(1 if info[\"win\"] else 0)\n", 242 | " agent.update(transition_dict_1)\n", 243 | " agent.update(transition_dict_2)\n", 244 | " if (i_episode + 1) % 100 == 0:\n", 245 | " pbar.set_postfix({\n", 246 | " 'episode':\n", 247 | " '%d' % (num_episodes / 10 * i + i_episode + 1),\n", 248 | " 'return':\n", 249 | " '%.3f' % np.mean(win_list[-100:])\n", 250 | " })\n", 251 | " pbar.update(1)\n", 252 | "\n", 253 | "# /usr/local/lib/python3.7/dist-packages/gym/logger.py:30: UserWarning:[33mWARN:\n", 254 | "# Box bound precision lowered by casting to float32[0m\n", 255 | "# warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n", 256 | "\n", 257 | "# Iteration 0: 100%|██████████| 10000/10000 [05:22<00:00, 31.02it/s, episode=10000,\n", 258 | "# return=0.220]\n", 259 | "# Iteration 1: 100%|██████████| 10000/10000 [04:03<00:00, 41.07it/s, episode=20000,\n", 260 | "# return=0.400]\n", 261 | "# Iteration 2: 100%|██████████| 10000/10000 [03:37<00:00, 45.96it/s, episode=30000,\n", 262 | "# return=0.670]\n", 263 | "# Iteration 3: 100%|██████████| 10000/10000 
[03:13<00:00, 51.55it/s, episode=40000,\n", 264 | "# return=0.590]\n", 265 | "# Iteration 4: 100%|██████████| 10000/10000 [02:58<00:00, 56.07it/s, episode=50000,\n", 266 | "# return=0.750]\n", 267 | "# Iteration 5: 100%|██████████| 10000/10000 [02:58<00:00, 56.09it/s, episode=60000,\n", 268 | "# return=0.660]\n", 269 | "# Iteration 6: 100%|██████████| 10000/10000 [02:57<00:00, 56.42it/s, episode=70000,\n", 270 | "# return=0.660]\n", 271 | "# Iteration 7: 100%|██████████| 10000/10000 [03:04<00:00, 54.20it/s, episode=80000,\n", 272 | "# return=0.720]\n", 273 | "# Iteration 8: 100%|██████████| 10000/10000 [02:59<00:00, 55.84it/s, episode=90000,\n", 274 | "# return=0.530]\n", 275 | "# Iteration 9: 100%|██████████| 10000/10000 [03:03<00:00, 54.55it/s, episode=100000,\n", 276 | "# return=0.710]" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "colab": { 284 | "base_uri": "https://localhost:8080/", 285 | "height": 295 286 | }, 287 | "executionInfo": { 288 | "elapsed": 20, 289 | "status": "ok", 290 | "timestamp": 1649963248923, 291 | "user": { 292 | "displayName": "Sam Lu", 293 | "userId": "15789059763790170725" 294 | }, 295 | "user_tz": -480 296 | }, 297 | "id": "OT2mwoZdItIq", 298 | "outputId": "6ea70d1d-bb28-456e-ffca-fe4f8106e0b8" 299 | }, 300 | "outputs": [ 301 | { 302 | "data": { 303 | "image/png": "\n", 304 | "text/plain": [ 305 | "
" 306 | ] 307 | }, 308 | "metadata": { 309 | "needs_background": "light" 310 | }, 311 | "output_type": "display_data" 312 | } 313 | ], 314 | "source": [ 315 | "win_array = np.array(win_list)\n", 316 | "#每100条轨迹取一次平均\n", 317 | "win_array = np.mean(win_array.reshape(-1, 100), axis=1)\n", 318 | "\n", 319 | "episodes_list = np.arange(win_array.shape[0]) * 100\n", 320 | "plt.plot(episodes_list, win_array)\n", 321 | "plt.xlabel('Episodes')\n", 322 | "plt.ylabel('Win rate')\n", 323 | "plt.title('IPPO on Combat')\n", 324 | "plt.show()" 325 | ] 326 | } 327 | ], 328 | "metadata": { 329 | "colab": { 330 | "collapsed_sections": [], 331 | "name": "第20章-多智能体强化学习入门.ipynb", 332 | "provenance": [] 333 | }, 334 | "kernelspec": { 335 | "display_name": "Python 3", 336 | "language": "python", 337 | "name": "python3" 338 | }, 339 | "language_info": { 340 | "codemirror_mode": { 341 | "name": "ipython", 342 | "version": 3 343 | }, 344 | "file_extension": ".py", 345 | "mimetype": "text/x-python", 346 | "name": "python", 347 | "nbconvert_exporter": "python", 348 | "pygments_lexer": "ipython3", 349 | "version": "3.7.6" 350 | } 351 | }, 352 | "nbformat": 4, 353 | "nbformat_minor": 1 354 | } 355 | -------------------------------------------------------------------------------- /第3章-马尔可夫决策过程.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "executionInfo": { 11 | "elapsed": 5, 12 | "status": "ok", 13 | "timestamp": 1649954662434, 14 | "user": { 15 | "displayName": "Sam Lu", 16 | "userId": "15789059763790170725" 17 | }, 18 | "user_tz": -480 19 | }, 20 | "id": "5OzU9RtB9fWZ", 21 | "outputId": "146722e2-c641-46dc-a690-01e8eed9160c" 22 | }, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "根据本序列计算得到回报为:-2.5。\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 
import numpy as np

np.random.seed(0)

# ---------------------------------------------------------------------------
# Markov reward process (MRP)
# ---------------------------------------------------------------------------
# State-transition probability matrix: P[i][j] is the probability of moving
# from state s(i+1) to state s(j+1).  The last state (s6) only loops to itself.
P = [
    [0.9, 0.1, 0.0, 0.0, 0.0, 0.0],
    [0.5, 0.0, 0.5, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.6, 0.0, 0.4],
    [0.0, 0.0, 0.0, 0.0, 0.3, 0.7],
    [0.0, 0.2, 0.3, 0.5, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
]
P = np.array(P)

rewards = [-1, -2, -2, 10, 1, 0]  # reward function: rewards[k - 1] is the reward of state sk
gamma = 0.5  # discount factor


def compute_return(start_index, chain, gamma, reward_table=None):
    """Return the discounted return collected along ``chain``.

    Args:
        start_index: Index into ``chain`` of the state the return starts from.
        chain: Sequence of 1-based state numbers (``1`` means s1).
        gamma: Discount factor.
        reward_table: Optional per-state rewards, indexed by ``state - 1``.
            Defaults to the module-level ``rewards`` list, which is what the
            three-argument form has always used.

    Returns:
        ``sum(gamma**k * reward(chain[start_index + k]))`` over the rest of
        the chain; ``0`` when there is nothing left to visit.
    """
    table = rewards if reward_table is None else reward_table
    G = 0
    # Accumulate from the last state backwards so every step is one
    # multiply-add: G_t = r_t + gamma * G_{t+1}.
    for i in reversed(range(start_index, len(chain))):
        G = gamma * G + table[chain[i] - 1]
    return G


# A state sequence s1 -> s2 -> s3 -> s6.
chain = [1, 2, 3, 6]
start_index = 0
G = compute_return(start_index, chain, gamma)
print("根据本序列计算得到回报为:%s。" % G)

# Output recorded in the original notebook:
# 根据本序列计算得到回报为:-2.5。


def compute(P, rewards, gamma, states_num):
    """Solve the matrix form of the Bellman equation of an MRP.

    Computes V = (I - gamma * P)^{-1} R.

    Args:
        P: ``states_num x states_num`` transition matrix (array or nested list).
        rewards: One reward per state.
        gamma: Discount factor.
        states_num: Number of states of the MRP.

    Returns:
        A ``(states_num, 1)`` array holding the value of each state.
    """
    rewards = np.asarray(rewards).reshape((-1, 1))  # rewards as a column vector
    # Solve (I - gamma * P) V = R directly rather than forming the inverse
    # and multiplying: same result, cheaper and numerically better behaved.
    system = np.eye(states_num, states_num) - gamma * np.asarray(P)
    value = np.linalg.solve(system, rewards)
    return value


V = compute(P, rewards, gamma, 6)
print("MRP中每个状态价值分别为\n", V)

# Output recorded in the original notebook:
# MRP中每个状态价值分别为
#  [[-2.01950168]
#  [-2.21451846]
#  [ 1.16142785]
#  [10.53809283]
#  [ 3.58728554]
#  [ 0.        ]]

# ---------------------------------------------------------------------------
# Markov decision process (MDP)
# ---------------------------------------------------------------------------
S = ["s1", "s2", "s3", "s4", "s5"]  # state set
A = ["保持s1", "前往s1", "前往s2", "前往s3", "前往s4", "前往s5", "概率前往"]  # action set
# State-transition function, keyed by "state-action-next_state".
P = {
    "s1-保持s1-s1": 1.0,
    "s1-前往s2-s2": 1.0,
    "s2-前往s1-s1": 1.0,
    "s2-前往s3-s3": 1.0,
    "s3-前往s4-s4": 1.0,
    "s3-前往s5-s5": 1.0,
    "s4-前往s5-s5": 1.0,
    "s4-概率前往-s2": 0.2,
    "s4-概率前往-s3": 0.4,
    "s4-概率前往-s4": 0.4,
}
# Reward function, keyed by "state-action".
R = {
    "s1-保持s1": -1,
    "s1-前往s2": 0,
    "s2-前往s1": -1,
    "s2-前往s3": -2,
    "s3-前往s4": -2,
    "s3-前往s5": 0,
    "s4-前往s5": 10,
    "s4-概率前往": 1,
}
gamma = 0.5  # discount factor
MDP = (S, A, P, R, gamma)

# Policy 1: uniformly random over the actions available in each state.
Pi_1 = {
    "s1-保持s1": 0.5,
    "s1-前往s2": 0.5,
    "s2-前往s1": 0.5,
    "s2-前往s3": 0.5,
    "s3-前往s4": 0.5,
    "s3-前往s5": 0.5,
    "s4-前往s5": 0.5,
    "s4-概率前往": 0.5,
}
# Policy 2.
Pi_2 = {
    "s1-保持s1": 0.6,
    "s1-前往s2": 0.4,
    "s2-前往s1": 0.3,
    "s2-前往s3": 0.7,
    "s3-前往s4": 0.5,
    "s3-前往s5": 0.5,
    "s4-前往s5": 0.1,
    "s4-概率前往": 0.9,
}


def join(str1, str2):
    """Join two strings with "-" to build the keys used by ``P``, ``R`` and the policies."""
    return str1 + '-' + str2


gamma = 0.5
# Transition matrix and rewards of the MRP obtained by averaging the MDP
# above over the uniformly random policy (0.5 per available action).
P_from_mdp_to_mrp = [
    [0.5, 0.5, 0.0, 0.0, 0.0],
    [0.5, 0.0, 0.5, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.5, 0.5],
    [0.0, 0.1, 0.2, 0.2, 0.5],
    [0.0, 0.0, 0.0, 0.0, 1.0],
]
P_from_mdp_to_mrp = np.array(P_from_mdp_to_mrp)
R_from_mdp_to_mrp = [-0.5, -1.5, -1.0, 5.5, 0]

V = compute(P_from_mdp_to_mrp, R_from_mdp_to_mrp, gamma, 5)
print("MDP中每个状态价值分别为\n", V)

# Output recorded in the original notebook:
# MDP中每个状态价值分别为
#  [[-1.22555411]
#  [-1.67666232]
#  [ 0.51890482]
#  [ 6.0756193 ]
#  [ 0.        ]]


def _draw(options, prob_of):
    """Draw one element of ``options`` by inverse-CDF sampling.

    Consumes exactly one ``np.random.rand()``.  ``prob_of(option)`` gives the
    probability mass of each option.  If rounding leaves the cumulative mass
    at or below the drawn number, the last option with positive mass is
    returned instead of leaking a value from an earlier draw.

    Raises:
        ValueError: if no option has positive probability.
    """
    threshold, cumulative = np.random.rand(), 0
    fallback = None
    for option in options:
        mass = prob_of(option)
        if mass > 0:
            fallback = option
        cumulative += mass
        if cumulative > threshold:
            return option
    if fallback is None:
        raise ValueError("no option has positive probability")
    return fallback


def sample(MDP, Pi, timestep_max, number):
    """Sample episodes from ``MDP`` by following policy ``Pi``.

    Args:
        MDP: Tuple ``(S, A, P, R, gamma)``.  The last state in ``S`` is
            treated as the terminal state (``"s5"`` in this notebook).
        Pi: Policy, mapping "state-action" keys to probabilities.
        timestep_max: Maximum number of transitions per episode.
        number: Number of episodes to sample.

    Returns:
        A list of episodes; each episode is a list of
        ``(s, a, r, s_next)`` tuples with at most ``timestep_max`` entries.

    Raises:
        ValueError: if a visited state has no action, or a chosen
            state-action pair has no successor, with positive probability.
    """
    S, A, P, R, gamma = MDP
    terminal = S[-1]
    episodes = []
    for _ in range(number):
        episode = []
        # Start from a uniformly chosen non-terminal state.
        s = S[np.random.randint(len(S) - 1)]
        # Stop at the terminal state or once the step budget is used up.
        # (The loop used to run while `timestep <= timestep_max`, which let
        # through one transition more than the documented limit.)
        while s != terminal and len(episode) < timestep_max:
            # Choose an action in state s according to the policy ...
            a = _draw(A, lambda a_opt: Pi.get(join(s, a_opt), 0))
            r = R.get(join(s, a), 0)
            # ... then the next state according to the transition function.
            s_next = _draw(S, lambda s_opt: P.get(join(join(s, a), s_opt), 0))
            episode.append((s, a, r, s_next))
            s = s_next
        episodes.append(episode)
    return episodes


# Sample 5 episodes of at most 20 steps each.
episodes = sample(MDP, Pi_1, 20, 5)
print('第一条序列\n', episodes[0])
print('第二条序列\n', episodes[1])
print('第五条序列\n', episodes[4])

# Output recorded in the original notebook (produced before the step limit
# in sample() was tightened, so a fresh run may differ):
# 第一条序列
#  [('s1', '前往s2', 0, 's2'), ('s2', '前往s3', -2, 's3'), ('s3', '前往s5', 0, 's5')]
# 第二条序列
#  [('s4', '概率前往', 1, 's4'), ('s4', '前往s5', 10, 's5')]
# 第五条序列
#  [('s2', '前往s3', -2, 's3'), ('s3', '前往s4', -2, 's4'), ('s4', '前往s5', 10, 's5')]


def MC(episodes, V, N, gamma):
    """Every-visit Monte Carlo estimate of the state values, updated in place.

    Args:
        episodes: Episodes as returned by ``sample``.
        V: Dict state -> running value estimate (mutated).
        N: Dict state -> visit counter (mutated).
        gamma: Discount factor.
    """
    for episode in episodes:
        G = 0
        # Walk the episode backwards so the return is built incrementally.
        for s, _, r, _ in reversed(episode):
            G = r + gamma * G
            N[s] = N[s] + 1
            # Incremental mean of the returns observed from s.
            V[s] = V[s] + (G - V[s]) / N[s]


timestep_max = 20
# Sample 1000 episodes; adjust as desired.
episodes = sample(MDP, Pi_1, timestep_max, 1000)
gamma = 0.5
V = {"s1": 0, "s2": 0, "s3": 0, "s4": 0, "s5": 0}
N = {"s1": 0, "s2": 0, "s3": 0, "s4": 0, "s5": 0}
MC(episodes, V, N, gamma)
print("使用蒙特卡洛方法计算MDP的状态价值为\n", V)

# Output recorded in the original notebook (sampling-dependent; a fresh run
# may differ slightly):
# 使用蒙特卡洛方法计算MDP的状态价值为
#  {'s1': -1.228923788722258, 's2': -1.6955696284402704, 's3': 0.4823809701532294,
#   's4': 5.967514743019431, 's5': 0}


def occupancy(episodes, s, a, timestep_max, gamma):
    """Estimate the occupancy measure of the state-action pair ``(s, a)``.

    For every time step ``t < timestep_max`` the relative frequency of
    ``(s_t, a_t) == (s, a)`` among the episodes that reach step ``t`` is
    computed, and these frequencies are combined as
    ``(1 - gamma) * sum_t gamma**t * freq_t``.

    Args:
        episodes: Episodes as returned by ``sample``.
        s: State of the pair.
        a: Action of the pair.
        timestep_max: Horizon; transitions at index ``timestep_max`` or later
            are ignored.
        gamma: Discount factor.
    """
    rho = 0
    total_times = np.zeros(timestep_max)  # how often each time step t was reached
    occur_times = np.zeros(timestep_max)  # how often (s_t, a_t) == (s, a)
    for episode in episodes:
        # Slicing keeps an over-long episode from indexing past the arrays.
        for i, (s_opt, a_opt, _, _) in enumerate(episode[:timestep_max]):
            total_times[i] += 1
            if s == s_opt and a == a_opt:
                occur_times[i] += 1
    for i in reversed(range(timestep_max)):
        if total_times[i]:
            rho += gamma**i * occur_times[i] / total_times[i]
    return (1 - gamma) * rho


gamma = 0.5
timestep_max = 1000

episodes_1 = sample(MDP, Pi_1, timestep_max, 1000)
episodes_2 = sample(MDP, Pi_2, timestep_max, 1000)
rho_1 = occupancy(episodes_1, "s4", "概率前往", timestep_max, gamma)
rho_2 = occupancy(episodes_2, "s4", "概率前往", timestep_max, gamma)
print(rho_1, rho_2)

# Output recorded in the original notebook (sampling-dependent):
# 0.112567796310472 0.23199480615618912
import copy


class CliffWalkingEnv:
    """Cliff Walking grid world with a fully known transition model.

    States are numbered row by row, ``state = row * ncol + col``, with the
    origin (0, 0) in the top-left corner.  In the bottom row, column 0 is an
    ordinary cell, the last column is the goal and every cell in between is
    cliff.
    """
    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol  # number of grid columns
        self.nrow = nrow  # number of grid rows
        # Transition table: P[state][action] = [(p, next_state, reward, done)]
        self.P = self.createP()

    def createP(self):
        """Build and return the transition table described in ``__init__``."""
        # Four actions as (dx, dy) offsets:
        # change[0] up, change[1] down, change[2] left, change[3] right.
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        P = [[[] for _ in change] for _ in range(self.nrow * self.ncol)]
        for i in range(self.nrow):
            for j in range(self.ncol):
                state = i * self.ncol + j
                for a, (dx, dy) in enumerate(change):
                    # Cliff or goal cell: interaction is over, so every
                    # action loops back with reward 0.
                    if i == self.nrow - 1 and j > 0:
                        P[state][a] = [(1, state, 0, True)]
                        continue
                    # Any other cell: move one step, clipped to the grid.
                    next_x = min(self.ncol - 1, max(0, j + dx))
                    next_y = min(self.nrow - 1, max(0, i + dy))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False
                    # The next cell is cliff or goal.
                    if next_y == self.nrow - 1 and next_x > 0:
                        done = True
                        if next_x != self.ncol - 1:  # the next cell is cliff
                            reward = -100
                    P[state][a] = [(1, next_state, reward, done)]
        return P


class PolicyIteration:
    """Policy iteration on an environment with a known transition table.

    ``env`` must expose ``ncol``, ``nrow`` and
    ``P[state][action] = [(p, next_state, reward, done), ...]`` with actions
    numbered ``0 .. n - 1``.  The number of actions is taken from ``env.P``
    rather than being fixed at four.
    """
    def __init__(self, env, theta, gamma):
        self.env = env
        state_count = self.env.ncol * self.env.nrow
        self.v = [0] * state_count  # state values start at 0
        # Start from the uniformly random policy.
        self.pi = [[1 / len(self.env.P[s])] * len(self.env.P[s])
                   for s in range(state_count)]
        self.theta = theta  # convergence threshold of policy evaluation
        self.gamma = gamma  # discount factor

    def _q_values(self, s):
        """Return ``[Q(s, a) for every action a]`` under the current ``self.v``."""
        q_values = []
        for a in range(len(self.env.P[s])):
            qsa = 0
            for p, next_state, r, done in self.env.P[s][a]:
                # The reward depends on the next state in these environments,
                # so it is weighted by the transition probability as well.
                # (1 - done) drops the bootstrap term on terminal transitions.
                qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
            q_values.append(qsa)
        return q_values

    def policy_evaluation(self):
        """Evaluate ``self.pi`` by synchronous sweeps until the largest
        change of any state value falls below ``self.theta``."""
        cnt = 1  # number of sweeps performed
        while True:
            max_diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol * self.env.nrow):
                # V(s) = sum_a pi(a|s) * Q(s, a)
                new_v[s] = sum(self.pi[s][a] * qsa
                               for a, qsa in enumerate(self._q_values(s)))
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta:
                break  # converged
            cnt += 1
        print("策略评估进行%d轮后完成" % cnt)

    def policy_improvement(self):
        """Make ``self.pi`` greedy with respect to ``self.v`` and return it."""
        for s in range(self.env.nrow * self.env.ncol):
            qsa_list = self._q_values(s)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)  # how many actions attain the best Q value
            # Those actions share the probability mass equally.
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]
        print("策略提升完成")
        return self.pi

    def policy_iteration(self):
        """Alternate evaluation and improvement until the policy is stable."""
        while True:
            self.policy_evaluation()
            # Deep copy: policy_improvement() rewrites self.pi in place.
            old_pi = copy.deepcopy(self.pi)
            new_pi = self.policy_improvement()
            if old_pi == new_pi:
                break


def print_agent(agent, action_meaning, disaster=(), end=()):
    """Print the agent's state values and policy as grids.

    Args:
        agent: Object with ``env.nrow``, ``env.ncol``, ``v`` and ``pi``.
        action_meaning: One symbol per action, in action order.
        disaster: State indices printed as ``****`` (e.g. the cliff cells).
        end: State indices printed as ``EEEE`` (goal states).
    """
    ncol = agent.env.ncol
    print("状态价值:")
    for i in range(agent.env.nrow):
        for j in range(ncol):
            # Pad/truncate to 6 characters so the columns line up.
            print('%6.6s' % ('%.3f' % agent.v[i * ncol + j]), end=' ')
        print()

    print("策略:")
    for i in range(agent.env.nrow):
        for j in range(ncol):
            state = i * ncol + j
            if state in disaster:  # special states such as the cliff
                print('****', end=' ')
            elif state in end:  # goal states
                print('EEEE', end=' ')
            else:
                # Show the symbol of every action with positive probability.
                pi_str = ''.join(
                    symbol if prob > 0 else 'o'
                    for symbol, prob in zip(action_meaning, agent.pi[state]))
                print(pi_str, end=' ')
        print()


env = CliffWalkingEnv()
action_meaning = ['^', 'v', '<', '>']
theta = 0.001
gamma = 0.9
agent = PolicyIteration(env, theta, gamma)
agent.policy_iteration()
print_agent(agent, action_meaning, list(range(37, 47)), [47])

# Output recorded in the original notebook:
# 策略评估进行60轮后完成
# 策略提升完成
# 策略评估进行72轮后完成
# 策略提升完成
# 策略评估进行44轮后完成
# 策略提升完成
# 策略评估进行12轮后完成
# 策略提升完成
# 策略评估进行1轮后完成
# 策略提升完成
# 状态价值:
# -7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710
# -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900
# -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000
# -7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000
# 策略:
# ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo
# ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo
# ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
# ^ooo **** **** **** **** **** **** **** **** **** **** EEEE
class ValueIteration:
    """Value iteration on an environment with a known transition table.

    ``env`` must expose ``ncol``, ``nrow`` and
    ``P[state][action] = [(p, next_state, reward, done), ...]`` with actions
    numbered ``0 .. n - 1``.  The number of actions is taken from ``env.P``
    rather than being fixed at four.
    """
    def __init__(self, env, theta, gamma):
        self.env = env
        self.v = [0] * self.env.ncol * self.env.nrow  # state values start at 0
        self.theta = theta  # convergence threshold on the value change
        self.gamma = gamma
        # Policy derived once value iteration has finished.
        self.pi = [None for i in range(self.env.ncol * self.env.nrow)]

    def _q_values(self, s):
        """Return ``[Q(s, a) for every action a]`` under the current ``self.v``."""
        q_values = []
        for a in range(len(self.env.P[s])):
            qsa = 0
            for p, next_state, r, done in self.env.P[s][a]:
                # (1 - done) drops the bootstrap term on terminal transitions.
                qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
            q_values.append(qsa)
        return q_values

    def value_iteration(self):
        """Sweep until the largest value change is below ``self.theta``,
        then derive the greedy policy into ``self.pi``."""
        cnt = 0  # sweeps that did not yet meet the convergence test
        while True:
            max_diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol * self.env.nrow):
                # Taking the max over actions (instead of the policy-weighted
                # sum) is the key difference from policy evaluation.
                new_v[s] = max(self._q_values(s))
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta:
                break  # converged
            cnt += 1
        print("价值迭代一共进行%d轮" % cnt)
        self.get_policy()

    def get_policy(self):
        """Derive a greedy policy from ``self.v`` and store it in ``self.pi``."""
        for s in range(self.env.nrow * self.env.ncol):
            qsa_list = self._q_values(s)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)  # how many actions attain the best Q value
            # Those actions share the probability mass equally.
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]


# Example usage on the cliff-walking environment (left disabled, as in the
# original cell):
# env = CliffWalkingEnv()
# action_meaning = ['^', 'v', '<', '>']
# theta = 0.001
# gamma = 0.9
# agent = ValueIteration(env, theta, gamma)
# agent.value_iteration()
# print_agent(agent, action_meaning, list(range(37, 47)), [47])

# Output recorded in the original notebook for that example:
# 价值迭代一共进行14轮
# 状态价值:
# -7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710
# -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900
# -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000
# -7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000
# 策略:
# ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo
# ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo
# ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
# ^ooo **** **** **** **** **** **** **** **** **** **** EEEE
import gym

env = gym.make("FrozenLake-v0")  # create the environment
env = env.unwrapped  # unwrap to get access to the transition table P
env.render()  # render the environment (opens a window or prints a text view)

# Classify terminal states from the transition table.  The recorded output
# below shows each entry of env.P[s][a] as (probability, next_state, reward, done).
holes = set()
ends = set()
for s in env.P:
    for a in env.P[s]:
        for _, next_state, reward, done in env.P[s][a]:
            if reward == 1.0:  # a reward of 1 marks the goal
                ends.add(next_state)
            if done:
                holes.add(next_state)
holes = holes - ends  # terminal states that are not the goal are holes
print("冰洞的索引:", holes)
print("目标的索引:", ends)

for a in env.P[14]:  # transitions of the cell immediately left of the goal
    print(env.P[14][a])

# Output recorded in the original notebook:
# SFFF
# FHFH
# FFFH
# HFFG
# 冰洞的索引: {11, 12, 5, 7}
# 目标的索引: {15}
# [(0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 13, 0.0, False),
#  (0.3333333333333333, 14, 0.0, False)]
# [(0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False),
#  (0.3333333333333333, 15, 1.0, True)]
# [(0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 15, 1.0, True),
#  (0.3333333333333333, 10, 0.0, False)]
# [(0.3333333333333333, 15, 1.0, True), (0.3333333333333333, 10, 0.0, False),
#  (0.3333333333333333, 13, 0.0, False)]

# --- Policy iteration on FrozenLake ----------------------------------------
# This action order is the one Gym defines for the FrozenLake environment.
action_meaning = ['<', 'v', '>', '^']
theta = 1e-5
gamma = 0.9
agent = PolicyIteration(env, theta, gamma)
agent.policy_iteration()
# Use the hole/goal sets computed above instead of re-typing their indices
# ({5, 7, 11, 12} and {15} in the recorded run).
print_agent(agent, action_meaning, holes, ends)

# Output recorded in the original notebook.  NOTE(review): the policy grid is
# reconstructed from a damaged copy of the recorded output; confirm by re-running.
# 策略评估进行25轮后完成
# 策略提升完成
# 策略评估进行58轮后完成
# 策略提升完成
# 状态价值:
#  0.069  0.061  0.074  0.056
#  0.092  0.000  0.112  0.000
#  0.145  0.247  0.300  0.000
#  0.000  0.380  0.639  0.000
# 策略:
# <ooo ooo^ <ooo ooo^
# <ooo **** <o>o ****
# ooo^ ovoo <ooo ****
# **** oo>o ovoo EEEE

# --- Value iteration on FrozenLake -----------------------------------------
action_meaning = ['<', 'v', '>', '^']
theta = 1e-5
gamma = 0.9
agent = ValueIteration(env, theta, gamma)
agent.value_iteration()
print_agent(agent, action_meaning, holes, ends)

# Output recorded in the original notebook (policy grid reconstructed as above):
# 价值迭代一共进行60轮
# 状态价值:
#  0.069  0.061  0.074  0.056
#  0.092  0.000  0.112  0.000
#  0.145  0.247  0.300  0.000
#  0.000  0.380  0.639  0.000
# 策略:
# <ooo ooo^ <ooo ooo^
# <ooo **** <o>o ****
# ooo^ ovoo <ooo ****
# **** oo>o ovoo EEEE