├── requirements.txt ├── 利用PARL复现基于神经网络与DQN算法.py ├── 利用PARL复现基于神经网络与DQN算法.md ├── README.md └── 利用PARL复现基于神经网络与DQN算法.ipynb /requirements.txt: -------------------------------------------------------------------------------- 1 | paddlepaddle==1.6.3 2 | parl==1.3.1 3 | gym 4 | -------------------------------------------------------------------------------- /利用PARL复现基于神经网络与DQN算法.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # 深度学习入门 | 三岁在飞桨带你入门深度学习—CartPole,利用PARL复现基于神经网络与DQN算法(真的是0基础) 5 | # 大家好,这里是三岁,众所周知三岁是编程界小白,为了给大家贡献一个“爬山”的模板, 6 | # 三岁利用最基础的深度学习“hello world”项目给大家解析,并做示范。 7 | # 三岁老规矩,白话,简单,入门,基础 8 | # 如果有什么不准确,不正确的地方希望大家可以提出来! 9 | # (代码源于[强化学习7日打卡营-世界冠军带你从零实践>PARL强化学习公开课Lesson3_DQN](https://aistudio.baidu.com/aistudio/projectdetail/569647)) 10 | # * 以下项目适用于CPU环境 11 | # ## 参考资料 12 | # * B站视频地址:[https://www.bilibili.com/video/bv1v54y1v7Qf](https://www.bilibili.com/video/bv1v54y1v7Qf) 13 | # * AI 社区文章地址:[https://ai.baidu.com/forum/topic/show/962531](https://ai.baidu.com/forum/topic/show/962531) 14 | # * CSDN文章地址:[https://editor.csdn.net/md?articleId=107393006](https://editor.csdn.net/md?articleId=107393006) 15 | # * 三岁推文地址:[https://mp.weixin.qq.com/s/6-6RR0XuvTNuXKhX7fFXaQ](https://mp.weixin.qq.com/s/6-6RR0XuvTNuXKhX7fFXaQ) 16 | # * 参考论文:[https://www.nature.com/articles/nature14236](https://www.nature.com/articles/nature14236) 17 | # * DQNgithub地址:[https://github.com/PaddlePaddle/PARL/tree/develop/examples](https://github.com/PaddlePaddle/PARL/tree/develop/examples) 18 | # * 参考视频:[https://www.bilibili.com/video/BV1yv411i7xd?p=12](https://www.bilibili.com/video/BV1yv411i7xd?p=12) 19 | # * CartPole参考资料:[https://gym.openai.com/envs/CartPole-v1/](https://gym.openai.com/envs/CartPole-v1/) 20 | # * PARL官方地址:[https://github.com/PaddlePaddle/PARL](https://github.com/PaddlePaddle/PARL) 21 | # 22 | # 那么我们接下来就开始爬山吧,记得唱着小白船然后录像呦,三岁会帮你调整一下jio的位置的【滑稽】 23 | # 24 | 25 | # ### 环境预设 26 | # 根据实际情况把AI Studio的环境进行修改,使其更加符合代码的运行。 27 | 28 | # In[1]: 29 | 30 | 31 | get_ipython().system('pip uninstall -y parl # 说明:AIStudio预装的parl版本太老,容易跟其他库产生兼容性冲突,建议先卸载') 32 | get_ipython().system('pip uninstall -y pandas scikit-learn # 提示:在AIStudio中卸载这两个库再import parl可避免warning提示,不卸载也不影响parl的使用') 33 | 34 | get_ipython().system('pip install gym') 35 | get_ipython().system('pip install paddlepaddle==1.6.3') 36 | get_ipython().system('pip install parl==1.3.1') 37 | 38 | # 建议下载paddle系列产品时添加百度源 -i https://mirror.baidu.com/pypi/simple 39 | # python -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple 40 | # pip install parl==1.3.1 -i https://mirror.baidu.com/pypi/simple 41 | 42 | # 说明:安装日志中出现两条红色的关于 paddlehub 和 visualdl 的 ERROR 与parl无关,可以忽略,不影响使用 43 | 44 | 45 | # ### Step2 导入依赖 46 | # 如果依赖导入失败,有可能是没有安装第三方库,可以在此前加上代码块 47 | # !pip install 第三方库名 48 | # 49 | # 如果安装失败可以加上镜像 50 | # 51 | # ![](https://ai-studio-static-online.cdn.bcebos.com/8ee07cbb24874e96bdac93b07e3a22e3b6b465981c904129990af727701bd0b0) 52 | # ``` 53 | # 镜像源地址: 54 | # 55 | # 百度:https://mirror.baidu.com/pypi/simple 56 | # 57 | # 清华:https://pypi.tuna.tsinghua.edu.cn/simple 58 | # 59 | # 阿里云:http://mirrors.aliyun.com/pypi/simple/ 60 | # 61 | # 中国科技大学:https://pypi.mirrors.ustc.edu.cn/simple/ 62 | # 63 | # 华中理工大学:http://pypi.hustunique.com/ 64 | # 65 | # 山东理工大学:http://pypi.sdutlinux.org/ 66 | # 67 | # 豆瓣:http://pypi.douban.com/simple/ 68 | # 69 | # 例:!pip install jieba -i https://mirror.baidu.com/pypi/simple 70 | # ``` 71 | 72 | # In[2]: 73 | 74 | 75 | import parl 76 | from parl
import layers 77 | import paddle.fluid as fluid 78 | import copy 79 | import numpy as np 80 | import os 81 | import gym 82 | from parl.utils import logger 83 | 84 | 85 | # # 流程解析 86 | # ![](https://ai-studio-static-online.cdn.bcebos.com/25f4800e0c3b44bc9ddc908a4d3dba1793dc74f0f5a045e799771f0466b88e5b) 87 | # 88 | 89 | # ### Step3 设置超参数 90 | 91 | # In[3]: 92 | 93 | 94 | LEARN_FREQ = 5 # 训练频率,不需要每一个step都learn,攒一些新增经验后再learn,提高效率 95 | MEMORY_SIZE = 20000 # replay memory的大小,越大越占用内存 96 | MEMORY_WARMUP_SIZE = 200 # replay_memory 里需要预存一些经验数据,再开启训练 97 | BATCH_SIZE = 32 # 每次给agent learn的数据数量,从replay memory里随机sample一批数据出来 98 | LEARNING_RATE = 0.001 # 学习率 99 | GAMMA = 0.99 # reward 的衰减因子,一般取 0.9 到 0.999 不等 100 | 101 | 102 | # # 搭建PARL 103 | # 什么是PARL呢?让我们看看官方文档怎么说 104 | # 105 | # ![飞桨官网描述](https://ai-studio-static-online.cdn.bcebos.com/1f72929e036549d493dee85c2834b8a98e330c4ac6a447f08b647cab09e12bfa) 飞桨官网描述 106 | # ![PARL GitHub描述](https://ai-studio-static-online.cdn.bcebos.com/9a19623bdf41411ca07c089d3cf0f64752e0eb51a58a47e694bc4e4614079078) PARL GitHub描述 107 | # 108 | # * PARL主要是基于Model、Algorithm、Agent三个代码块来实现,其中Model和Agent是用户自定义操作。 109 | # * Model:即网络结构,要三层网络还是四层网络都是在Model中去定义(下文是三层的网络结构) 110 | # * Agent:是PARL与环境的一个接口,通过对模板的修改即可运用到各个不同的环境中去。 111 | # 112 | # * 至于Algorithm是内部已经封装好了的,直接加入参数运行即可,主要是算法模块的展现 113 | # 114 | 115 | # ### Step4 搭建Model、Algorithm、Agent架构 116 | # * `Agent`把产生的数据传给`algorithm`,`algorithm`根据`model`的模型结构计算出`Loss`,使用`SGD`或者其他优化器不断的优化,`PARL`这种架构可以很方便的应用在各类深度强化学习问题中。 117 | # 118 | # #### (1)Model 119 | # * `Model`用来定义前向(`Forward`)网络,用户可以自由的定制自己的网络结构。 120 | 121 | # In[4]: 122 | 123 | 124 | class Model(parl.Model): 125 | def __init__(self, act_dim): 126 | hid1_size = 128 127 | hid2_size = 128 128 | # 3层全连接网络 129 | self.fc1 = layers.fc(size=hid1_size, act='relu') 130 | self.fc2 = layers.fc(size=hid2_size, act='relu') 131 | self.fc3 = layers.fc(size=act_dim, act=None) 132 | 133 | def value(self, obs): 134 | # 定义网络 135 | # 输入state,输出所有action对应的Q,[Q(s,a1), Q(s,a2), Q(s,a3)...] 136 | h1 = self.fc1(obs) 137 | h2 = self.fc2(h1) 138 | Q = self.fc3(h2) 139 | return Q 140 | 141 | 142 | # ## Q算法的“藏身之地” 143 | # ### 优势 144 | # DQN算法较普通算法在经验回放和固定Q目标两方面有了较大的改进 145 | # 146 | # * 1、经验回放:它充分利用了off-policy的优势,通过训练把结果(成绩)存入Q表格,然后随机从表格中取出一条结果进行优化。这样一方面可以减少样本之间的关联性,另一方面可以提高样本的利用率 147 | # 注:训练结果会存进Q表格,当Q表格满了以后,存进来的数据会把最早存进去的数据“挤出去”(弹出) 148 | # * 2、固定Q目标:它解决了算法更新不平稳的问题。 149 | # 和监督学习做比较,监督学习的最终值要逼近实际结果,这个结果是固定的;但是我们的DQN却不是,它的目标值是经过神经网络以后的一个值,这个值是变动的,不好拟合。DQN团队想到了一个很好的办法:让这个值在一定时间里面保持不变,这样这个目标就可以确定了;目标值更新以后更加接近实际结果,可以更好地进行训练。 150 | 151 | # #### (2)Algorithm 152 | # * `Algorithm` 定义了具体的算法来更新前向网络(`Model`),也就是通过定义损失函数来更新`Model`,和算法相关的计算都放在`algorithm`中。 153 | # 154 | # 155 | 156 | # In[5]: 157 | 158 | 159 | # from parl.algorithms import DQN # 也可以直接从parl库中导入DQN算法 160 | 161 | class DQN(parl.Algorithm): 162 | def __init__(self, model, act_dim=None, gamma=None, lr=None): 163 | """ DQN algorithm 164 | 165 | Args: 166 | model (parl.Model): 定义Q函数的前向网络结构 167 | act_dim (int): action空间的维度,即有几个action 168 | gamma (float): reward的衰减因子 169 | lr (float): learning rate 学习率. 170 | """ 171 | self.model = model 172 | self.target_model = copy.deepcopy(model) 173 | 174 | assert isinstance(act_dim, int) 175 | assert isinstance(gamma, float) 176 | assert isinstance(lr, float) 177 | self.act_dim = act_dim 178 | self.gamma = gamma 179 | self.lr = lr 180 | 181 | def predict(self, obs): 182 | """ 使用self.model的value网络来获取 [Q(s,a1),Q(s,a2),...]
183 | """ 184 | return self.model.value(obs) 185 | 186 | def learn(self, obs, action, reward, next_obs, terminal): 187 | """ 使用DQN算法更新self.model的value网络 188 | """ 189 | # 从target_model中获取 max Q' 的值,用于计算target_Q 190 | next_pred_value = self.target_model.value(next_obs) 191 | best_v = layers.reduce_max(next_pred_value, dim=1) 192 | best_v.stop_gradient = True # 阻止梯度传递 193 | terminal = layers.cast(terminal, dtype='float32') 194 | target = reward + (1.0 - terminal) * self.gamma * best_v 195 | 196 | pred_value = self.model.value(obs) # 获取Q预测值 197 | # 将action转onehot向量,比如:3 => [0,0,0,1,0] 198 | action_onehot = layers.one_hot(action, self.act_dim) 199 | action_onehot = layers.cast(action_onehot, dtype='float32') 200 | # 下面一行是逐元素相乘,拿到action对应的 Q(s,a) 201 | # 比如:pred_value = [[2.3, 5.7, 1.2, 3.9, 1.4]], action_onehot = [[0,0,0,1,0]] 202 | # ==> pred_action_value = [[3.9]] 203 | pred_action_value = layers.reduce_sum( 204 | layers.elementwise_mul(action_onehot, pred_value), dim=1) 205 | 206 | # 计算 Q(s,a) 与 target_Q的均方差,得到loss 207 | cost = layers.square_error_cost(pred_action_value, target) 208 | cost = layers.reduce_mean(cost) 209 | optimizer = fluid.optimizer.Adam(learning_rate=self.lr) # 使用Adam优化器 210 | optimizer.minimize(cost) 211 | return cost 212 | 213 | def sync_target(self): 214 | """ 把 self.model 的模型参数值同步到 self.target_model 215 | """ 216 | self.model.sync_weights_to(self.target_model) 217 | 218 | 219 | # #### (3)Agent 220 | # * `Agent` 负责算法与环境的交互,在交互过程中把生成的数据提供给`Algorithm`来更新模型(`Model`),数据的预处理流程也一般定义在这里。 221 | 222 | # In[6]: 223 | 224 | 225 | class Agent(parl.Agent): 226 | def __init__(self, 227 | algorithm, 228 | obs_dim, 229 | act_dim, 230 | e_greed=0.1, 231 | e_greed_decrement=0): 232 | assert isinstance(obs_dim, int) 233 | assert isinstance(act_dim, int) 234 | self.obs_dim = obs_dim 235 | self.act_dim = act_dim 236 | super(Agent, self).__init__(algorithm) 237 | 238 | self.global_step = 0 239 | self.update_target_steps = 200 # 每隔200个training steps再把model的参数复制到target_model中 240 | 241 | self.e_greed = e_greed # 有一定概率随机选取动作,探索 242 | self.e_greed_decrement = e_greed_decrement # 随着训练逐步收敛,探索的程度慢慢降低 243 | 244 | def build_program(self): 245 | self.pred_program = fluid.Program() 246 | self.learn_program = fluid.Program() 247 | 248 | with fluid.program_guard(self.pred_program): # 搭建计算图用于 预测动作,定义输入输出变量 249 | obs = layers.data( 250 | name='obs', shape=[self.obs_dim], dtype='float32') 251 | self.value = self.alg.predict(obs) 252 | 253 | with fluid.program_guard(self.learn_program): # 搭建计算图用于 更新Q网络,定义输入输出变量 254 | obs = layers.data( 255 | name='obs', shape=[self.obs_dim], dtype='float32') 256 | action = layers.data(name='act', shape=[1], dtype='int32') 257 | reward = layers.data(name='reward', shape=[], dtype='float32') 258 | next_obs = layers.data( 259 | name='next_obs', shape=[self.obs_dim], dtype='float32') 260 | terminal = layers.data(name='terminal', shape=[], dtype='bool') 261 | self.cost = self.alg.learn(obs, action, reward, next_obs, terminal) 262 | 263 | def sample(self, obs): 264 | sample = np.random.rand() # 产生0~1之间的小数 265 | if sample < self.e_greed: 266 | act = np.random.randint(self.act_dim) # 探索:每个动作都有概率被选择 267 | else: 268 | act = self.predict(obs) # 选择最优动作 269 | self.e_greed = max( 270 | 0.01, self.e_greed - self.e_greed_decrement) # 随着训练逐步收敛,探索的程度慢慢降低 271 | return act 272 | 273 | def predict(self, obs): # 选择最优动作 274 | obs = np.expand_dims(obs, axis=0) 275 | pred_Q = self.fluid_executor.run( 276 | self.pred_program, 277 | feed={'obs': obs.astype('float32')}, 278 | 
fetch_list=[self.value])[0] 279 | pred_Q = np.squeeze(pred_Q, axis=0) 280 | act = np.argmax(pred_Q) # 选择Q最大的下标,即对应的动作 281 | return act 282 | 283 | def learn(self, obs, act, reward, next_obs, terminal): 284 | # 每隔200个training steps同步一次model和target_model的参数 285 | if self.global_step % self.update_target_steps == 0: 286 | self.alg.sync_target() 287 | self.global_step += 1 288 | 289 | act = np.expand_dims(act, -1) 290 | feed = { 291 | 'obs': obs.astype('float32'), 292 | 'act': act.astype('int32'), 293 | 'reward': reward, 294 | 'next_obs': next_obs.astype('float32'), 295 | 'terminal': terminal 296 | } 297 | cost = self.fluid_executor.run( 298 | self.learn_program, feed=feed, fetch_list=[self.cost])[0] # 训练一次网络 299 | return cost 300 | 301 | 302 | # ### Step5 ReplayMemory 303 | # * 经验池:用于存储多条经验,实现 经验回放。 304 | # ![](https://ai-studio-static-online.cdn.bcebos.com/340817e194974c24ac63dd569fba336cd500d120729e4354825fb9c3501108ab) 305 | # 306 | 307 | # In[7]: 308 | 309 | 310 | import random 311 | import collections 312 | import numpy as np 313 | 314 | 315 | class ReplayMemory(object): 316 | def __init__(self, max_size): 317 | self.buffer = collections.deque(maxlen=max_size) 318 | 319 | # 增加一条经验到经验池中 320 | def append(self, exp): 321 | self.buffer.append(exp) 322 | 323 | # 从经验池中选取N条经验出来 324 | def sample(self, batch_size): 325 | mini_batch = random.sample(self.buffer, batch_size) 326 | obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], [] 327 | 328 | for experience in mini_batch: 329 | s, a, r, s_p, done = experience 330 | obs_batch.append(s) 331 | action_batch.append(a) 332 | reward_batch.append(r) 333 | next_obs_batch.append(s_p) 334 | done_batch.append(done) 335 | 336 | return np.array(obs_batch).astype('float32'), np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'), np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32') 337 | 338 | def __len__(self): 339 | return len(self.buffer) 340 | 341 | 342 | # ### Step6 Training && Test(训练&&测试) 343 | # ![](https://ai-studio-static-online.cdn.bcebos.com/3b9f168ee09a46dc9962decb0d2a60f7d35196d459824b9f923b163a0b4bbe4e) 344 | # ![](https://ai-studio-static-online.cdn.bcebos.com/4ffa39f252744a3791c04c5d8381824d4f7447a592d4409da658045061bcd0b9) 345 | # * 训练和评估的一个模块 346 | 347 | # In[8]: 348 | 349 | 350 | # 训练一个episode 351 | def run_episode(env, agent, rpm): 352 | total_reward = 0 353 | obs = env.reset() 354 | step = 0 355 | while True: 356 | step += 1 357 | action = agent.sample(obs) # 采样动作,所有动作都有概率被尝试到 358 | next_obs, reward, done, _ = env.step(action) 359 | rpm.append((obs, action, reward, next_obs, done)) 360 | 361 | # train model 362 | if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0): 363 | (batch_obs, batch_action, batch_reward, batch_next_obs, 364 | batch_done) = rpm.sample(BATCH_SIZE) 365 | train_loss = agent.learn(batch_obs, batch_action, batch_reward, 366 | batch_next_obs, 367 | batch_done) # s,a,r,s',done 368 | 369 | total_reward += reward 370 | obs = next_obs 371 | if done: 372 | break 373 | return total_reward 374 | 375 | 376 | # 评估 agent, 跑 5 个episode,总reward求平均 377 | def evaluate(env, agent, render=False): 378 | eval_reward = [] 379 | for i in range(5): 380 | obs = env.reset() 381 | episode_reward = 0 382 | while True: 383 | action = agent.predict(obs) # 预测动作,只选最优动作 384 | obs, reward, done, _ = env.step(action) 385 | episode_reward += reward 386 | if render: 387 | env.render() 388 | if done: 389 | break 390 | eval_reward.append(episode_reward) 391 
| return np.mean(eval_reward) 392 | 393 | 394 | # ### Step7 创建环境和Agent,创建经验池,启动训练,保存模型 395 | 396 | # # 主函数 397 | # * 讲不清,理还乱,看一波图解,冷静冷静!!! 398 | # ![](https://ai-studio-static-online.cdn.bcebos.com/dab5d244b98c41108a84e37b79db0009109ed32d4d734ca7b04b68e2e301da09) 399 | # ![](https://ai-studio-static-online.cdn.bcebos.com/a297e56a6de045169e4a04391c22aa80bf7cc9d9608446fd97bef4dcd16431c8) 400 | # ![](https://ai-studio-static-online.cdn.bcebos.com/8bc3b4cd1913479480baa286b120f45b49cb8ac7c36645df8423d873a7bd6d07) 401 | # 402 | 403 | # In[9]: 404 | 405 | 406 | env = gym.make('CartPole-v0') # CartPole-v0: 预期最后一次评估总分 > 180(最大值是200) 407 | action_dim = env.action_space.n # CartPole-v0: 2 408 | obs_shape = env.observation_space.shape # CartPole-v0: (4,) 409 | 410 | rpm = ReplayMemory(MEMORY_SIZE) # DQN的经验回放池 411 | 412 | # 根据parl框架构建agent 413 | model = Model(act_dim=action_dim) 414 | algorithm = DQN(model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE) 415 | agent = Agent( 416 | algorithm, 417 | obs_dim=obs_shape[0], 418 | act_dim=action_dim, 419 | e_greed=0.1, # 有一定概率随机选取动作,探索 420 | e_greed_decrement=1e-6) # 随着训练逐步收敛,探索的程度慢慢降低 421 | 422 | # 加载模型 423 | # save_path = './dqn_model.ckpt' 424 | # agent.restore(save_path) 425 | 426 | # 先往经验池里存一些数据,避免最开始训练的时候样本丰富度不够 427 | while len(rpm) < MEMORY_WARMUP_SIZE: 428 | run_episode(env, agent, rpm) 429 | 430 | max_episode = 2000 431 | 432 | # 开始训练 433 | episode = 0 434 | while episode < max_episode: # 训练max_episode个回合,test部分不计算入episode数量 435 | # train part 436 | for i in range(0, 50): 437 | total_reward = run_episode(env, agent, rpm) 438 | episode += 1 439 | 440 | # test part 441 | eval_reward = evaluate(env, agent, render=False) # render=True 查看显示效果 442 | logger.info('episode:{} e_greed:{} test_reward:{}'.format( 443 | episode, agent.e_greed, eval_reward)) 444 | 445 | # 训练结束,保存模型 446 | save_path = './dqn_model.ckpt' 447 | agent.save(save_path) 448 | 449 | 450 | # ![](https://ai-studio-static-online.cdn.bcebos.com/94a7460d1da44164a7805c5680a2779689b1dc4c78494104aa3d2bda075411fa) 451 | # 452 | 453 | # # 运行经历 454 | # 这一串代码是从课件上copy下来的,所以里面的参数及数据都是已经调整好了的,但是在线下跑数据时发现,每一次跑也并非到最后都是200,也就是说每一次的结果都是不确定的;然后我开了显示模式,里面的测试画面和结果会被显示和打印,前面的速度会比后面的快,因为分数低时间自然就短了。 455 | # 根据这段时间对深度学习的学习,基本上代码运行没有问题后,调整几个超参就可以得到一个比较好的拟合效果。 456 | # 457 | 458 | # # 心得体会 459 | # 本次“爬山”,发现之前的一些盲点在本次有了解决,对深度学习、PARL、算法等都有了新的认识,虽然还是那个小白,但是认识提上去了,以后还是可以更加努力的去奋斗的,这就是传说中的回头看的时候就知道自己以前是多么的无知了。 460 | 461 | # # 这里是三岁,请大家多多指教啊! 462 | -------------------------------------------------------------------------------- /利用PARL复现基于神经网络与DQN算法.md: -------------------------------------------------------------------------------- 1 | # 深度学习入门 | 三岁在飞桨带你入门深度学习—CartPole,利用PARL复现基于神经网络与DQN算法(真的是0基础) 2 | 大家好,这里是三岁,众所周知三岁是编程界小白,为了给大家贡献一个“爬山”的模板, 3 | 三岁利用最基础的深度学习“hello world”项目给大家解析,并做示范。 4 | 三岁老规矩,白话,简单,入门,基础 5 | 如果有什么不准确,不正确的地方希望大家可以提出来!
6 | (代码源于[强化学习7日打卡营-世界冠军带你从零实践>PARL强化学习公开课Lesson3_DQN](https://aistudio.baidu.com/aistudio/projectdetail/569647)) 7 | * 以下项目适用于CPU环境 8 | ## 参考资料 9 | * B站视频地址:[https://www.bilibili.com/video/bv1v54y1v7Qf](https://www.bilibili.com/video/bv1v54y1v7Qf) 10 | * AI 社区文章地址:[https://ai.baidu.com/forum/topic/show/962531](https://ai.baidu.com/forum/topic/show/962531) 11 | * CSDN文章地址:[https://editor.csdn.net/md?articleId=107393006](https://editor.csdn.net/md?articleId=107393006) 12 | * 三岁推文地址:[https://mp.weixin.qq.com/s/6-6RR0XuvTNuXKhX7fFXaQ](https://mp.weixin.qq.com/s/6-6RR0XuvTNuXKhX7fFXaQ) 13 | * 参考论文:[https://www.nature.com/articles/nature14236](https://www.nature.com/articles/nature14236) 14 | * DQNgithub地址:[https://github.com/PaddlePaddle/PARL/tree/develop/examples](https://github.com/PaddlePaddle/PARL/tree/develop/examples) 15 | * 参考视频:[https://www.bilibili.com/video/BV1yv411i7xd?p=12](https://www.bilibili.com/video/BV1yv411i7xd?p=12) 16 | * Carpoel参考资料:[https://gym.openai.com/envs/CartPole-v1/](https://gym.openai.com/envs/CartPole-v1/) 17 | * PARL官方地址:[https://github.com/PaddlePaddle/PARL](https://github.com/PaddlePaddle/PARL) 18 | 19 | 那么我们接下来就开始爬山吧,记得唱着小白船然后录像呦,三岁会帮你调整一下jio的位置的【滑稽】 20 | 21 | 22 | ### 环境预设 23 | 根据实际情况把AI Studio的环境进行修改使其更加符合代码的运行。 24 | 25 | 26 | ```python 27 | !pip uninstall -y parl # 说明:AIStudio预装的parl版本太老,容易跟其他库产生兼容性冲突,建议先卸载 28 | !pip uninstall -y pandas scikit-learn # 提示:在AIStudio中卸载这两个库再import parl可避免warning提示,不卸载也不影响parl的使用 29 | 30 | !pip install gym 31 | !pip install paddlepaddle==1.6.3 32 | !pip install parl==1.3.1 33 | 34 | # 建议下载paddle系列产品时添加百度源 -i https://mirror.baidu.com/pypi/simple 35 | # python -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple 36 | # pip install parl==1.3.1 -i https://mirror.baidu.com/pypi/simple 37 | 38 | # 说明:安装日志中出现两条红色的关于 paddlehub 和 visualdl 的 ERROR 与parl无关,可以忽略,不影响使用 39 | ``` 40 | 41 | Uninstalling parl-1.1.2: 42 | Successfully uninstalled parl-1.1.2 43 | Uninstalling pandas-0.23.4: 44 | Successfully uninstalled pandas-0.23.4 45 | Uninstalling scikit-learn-0.20.0: 46 | Successfully uninstalled scikit-learn-0.20.0 47 | Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/ 48 | Requirement already satisfied: gym in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.12.1) 49 | Requirement already satisfied: requests>=2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (2.22.0) 50 | Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.15.0) 51 | Requirement already satisfied: numpy>=1.10.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.16.4) 52 | Requirement already satisfied: pyglet>=1.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.4.5) 53 | Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.3.0) 54 | Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (2019.9.11) 55 | Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (3.0.4) 56 | Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (2.8) 57 | Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (1.25.6) 58 | Requirement already satisfied: future in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pyglet>=1.2.0->gym) (0.18.0) 59 | Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/ 60 | Collecting paddlepaddle==1.6.3 61 | [?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/96/28/e72bebb3c9b3d98eb9b15d9f6d85150f3cbd63e695e59882ff9f04846686/paddlepaddle-1.6.3-cp37-cp37m-manylinux1_x86_64.whl (90.9MB) 62 |  |████████████████████████████████| 90.9MB 488kB/s eta 0:00:011 |███████████████▌ | 44.0MB 481kB/s eta 0:01:38 63 | [?25hRequirement already satisfied: nltk; python_version >= "3.5" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.4.5) 64 | Requirement already satisfied: prettytable in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (0.7.2) 65 | Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.15.0) 66 | Requirement already satisfied: protobuf>=3.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.10.0) 67 | Requirement already satisfied: numpy>=1.12; python_version >= "3.5" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.16.4) 68 | Requirement already satisfied: scipy; python_version >= "3.5" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.3.0) 69 | Requirement already satisfied: objgraph in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.4.1) 70 | Requirement already satisfied: decorator in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (4.4.0) 71 | Requirement already satisfied: matplotlib; python_version >= "3.6" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (2.2.3) 72 | Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (7.1.2) 73 | Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (4.1.1.26) 74 | Requirement already satisfied: requests>=2.20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (2.22.0) 75 | Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (5.1.2) 76 | Requirement already satisfied: funcsigs in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.0.2) 77 | Requirement already satisfied: rarfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.1) 78 | Requirement already satisfied: graphviz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (0.13) 79 | Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from protobuf>=3.1.0->paddlepaddle==1.6.3) (41.4.0) 80 | Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) 
(0.10.0) 81 | Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (1.1.0) 82 | Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (2019.3) 83 | Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (2.8.0) 84 | Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (2.4.2) 85 | Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (2.8) 86 | Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (3.0.4) 87 | Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (2019.9.11) 88 | Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (1.25.6) 89 | Installing collected packages: paddlepaddle 90 | Found existing installation: paddlepaddle 1.8.0 91 | Uninstalling paddlepaddle-1.8.0: 92 | Successfully uninstalled paddlepaddle-1.8.0 93 | Successfully installed paddlepaddle-1.6.3 94 | Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/ 95 | Collecting parl==1.3.1 96 | [?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/62/79/590af38a920792c71afb73fad7583967928b4d0ba9fca76250d935c7fda8/parl-1.3.1-py2.py3-none-any.whl (521kB) 97 |  |████████████████████████████████| 522kB 17.8MB/s eta 0:00:01 98 | [?25hRequirement already satisfied: tb-nightly==1.15.0a20190801 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.15.0a20190801) 99 | Requirement already satisfied: visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (2.0.0b4) 100 | Requirement already satisfied: pyzmq==18.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (18.0.1) 101 | Requirement already satisfied: flask>=1.0.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.1.1) 102 | Collecting flask-cors (from parl==1.3.1) 103 | Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/78/38/e68b11daa5d613e3a91e4bf3da76c94ac9ee0d9cd515af9c1ab80d36f709/Flask_Cors-3.0.8-py2.py3-none-any.whl 104 | Requirement already satisfied: click in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (7.0) 105 | Collecting psutil>=5.6.2 (from parl==1.3.1) 106 | [?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/aa/3e/d18f2c04cf2b528e18515999b0c8e698c136db78f62df34eee89cee205f1/psutil-5.7.2.tar.gz (460kB) 107 |  |████████████████████████████████| 460kB 52.1MB/s eta 0:00:01 108 | [?25hRequirement already satisfied: tensorboardX==1.8 in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.8) 109 | Requirement already satisfied: pyarrow==0.13.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (0.13.0) 110 | Requirement already satisfied: scipy>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.3.0) 111 | Requirement already satisfied: termcolor>=1.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.1.0) 112 | Requirement already satisfied: cloudpickle==1.2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.2.1) 113 | Requirement already satisfied: absl-py>=0.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (0.8.1) 114 | Requirement already satisfied: six>=1.10.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.15.0) 115 | Requirement already satisfied: grpcio>=1.6.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.26.0) 116 | Requirement already satisfied: wheel>=0.26; python_version >= "3" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (0.33.6) 117 | Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (41.4.0) 118 | Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (0.16.0) 119 | Requirement already satisfied: markdown>=2.6.8 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (3.1.1) 120 | Requirement already satisfied: protobuf>=3.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (3.10.0) 121 | Requirement already satisfied: numpy>=1.12.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.16.4) 122 | Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.22.0) 123 | Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.21.0) 124 | Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (3.8.2) 125 | Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.0.0) 126 | Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (4.1.1.26) 127 | Requirement already satisfied: Pillow>=7.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from 
visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (7.1.2) 128 | Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.0.4->parl==1.3.1) (1.1.0) 129 | Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.0.4->parl==1.3.1) (2.10.3) 130 | Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2019.9.11) 131 | Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.8) 132 | Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (3.0.4) 133 | Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.25.6) 134 | Requirement already satisfied: importlib-metadata; python_version < "3.8" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (0.23) 135 | Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.4.10) 136 | Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.3.4) 137 | Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.0.1) 138 | Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.3.0) 139 | Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (16.7.9) 140 | Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (0.10.0) 141 | Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (5.1.2) 142 | Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (0.6.1) 143 | Requirement already satisfied: 
pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.6.0) 144 | Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.2.0) 145 | Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2019.3) 146 | Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.8.0) 147 | Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.0.4->parl==1.3.1) (1.1.1) 148 | Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata; python_version < "3.8"->pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (0.6.0) 149 | Requirement already satisfied: more-itertools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from zipp>=0.5->importlib-metadata; python_version < "3.8"->pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (7.2.0) 150 | Building wheels for collected packages: psutil 151 | Building wheel for psutil (setup.py) ... [?25ldone 152 | [?25h Created wheel for psutil: filename=psutil-5.7.2-cp37-cp37m-linux_x86_64.whl size=268459 sha256=ef019c256f341f219f0260fa4acfe2863e8f0bf1f96ed95b3f910a2a4c44a74c 153 | Stored in directory: /home/aistudio/.cache/pip/wheels/a8/74/a2/9f54383a7c48678163f965a5d2f4acb794417e60ab0d7351f8 154 | Successfully built psutil 155 | Installing collected packages: flask-cors, psutil, parl 156 | Successfully installed flask-cors-3.0.8 parl-1.3.1 psutil-5.7.2 157 | 158 | 159 | ### Step2 导入依赖 160 | 如果依赖导入失败有可能没有下载第三方库可以在此前加上代码块 161 | !pip instudio 第三方库名 162 | 163 | 如果安装失败可以加上镜像 164 | 165 | ![](https://ai-studio-static-online.cdn.bcebos.com/8ee07cbb24874e96bdac93b07e3a22e3b6b465981c904129990af727701bd0b0) 166 | ``` 167 | 镜像源地址: 168 | 169 | 百度:https://mirror.baidu.com/pypi/simple 170 | 171 | 清华:https://pypi.tuna.tsinghua.edu.cn/simple 172 | 173 | 阿里云:http://mirrors.aliyun.com/pypi/simple/ 174 | 175 | 中国科技大学 https://pypi.mirrors.ustc.edu.cn/simple/ 176 | 177 | 华中理工大学:http://pypi.hustunique.com/ 178 | 179 | 山东理工大学:http://pypi.sdutlinux.org/ 180 | 181 | 豆瓣:http://pypi.douban.com/simple/ 182 | 183 | 例:!pip instudio jieba -i https://mirror.baidu.com/pypi/simple 184 | ``` 185 | 186 | 187 | ```python 188 | import parl 189 | from parl import layers 190 | import paddle.fluid as fluid 191 | import copy 192 | import numpy as np 193 | import os 194 | import gym 195 | from parl.utils import logger 196 | ``` 197 | 198 | # 流程解析 199 | ![](https://ai-studio-static-online.cdn.bcebos.com/25f4800e0c3b44bc9ddc908a4d3dba1793dc74f0f5a045e799771f0466b88e5b) 200 | 201 | 202 | ### Step3 设置超参数 203 | 204 | 205 | ```python 206 | LEARN_FREQ = 5 # 训练频率,不需要每一个step都learn,攒一些新增经验后再learn,提高效率 207 | MEMORY_SIZE = 20000 # replay memory的大小,越大越占用内存 208 | MEMORY_WARMUP_SIZE = 200 # replay_memory 
里需要预存一些经验数据,再开启训练 209 | BATCH_SIZE = 32 # 每次给agent learn的数据数量,从replay memory里随机sample一批数据出来 210 | LEARNING_RATE = 0.001 # 学习率 211 | GAMMA = 0.99 # reward 的衰减因子,一般取 0.9 到 0.999 不等 212 | ``` 213 | 214 | # 搭建PARL 215 | 什么是PARL呢?让我们看看官方文档怎么说 216 | 217 | ![飞桨官网描述](https://ai-studio-static-online.cdn.bcebos.com/1f72929e036549d493dee85c2834b8a98e330c4ac6a447f08b647cab09e12bfa) 飞桨官网描述 218 | ![PARL GitHub描述](https://ai-studio-static-online.cdn.bcebos.com/9a19623bdf41411ca07c089d3cf0f64752e0eb51a58a47e694bc4e4614079078) PARL GitHub描述 219 | 220 | * PARL主要是基于Model、Algorithm、Agent三个代码块来实现,其中Model和Agent是用户自定义操作。 221 | * Model:即网络结构,要三层网络还是四层网络都是在Model中去定义(下文是三层的网络结构) 222 | * Agent:是PARL与环境的一个接口,通过对模板的修改即可运用到各个不同的环境中去。 223 | 224 | * 至于Algorithm是内部已经封装好了的,直接加入参数运行即可,主要是算法模块的展现 225 | 226 | 227 | ### Step4 搭建Model、Algorithm、Agent架构 228 | * `Agent`把产生的数据传给`algorithm`,`algorithm`根据`model`的模型结构计算出`Loss`,使用`SGD`或者其他优化器不断的优化,`PARL`这种架构可以很方便的应用在各类深度强化学习问题中。 229 | 230 | #### (1)Model 231 | * `Model`用来定义前向(`Forward`)网络,用户可以自由的定制自己的网络结构。 232 | 233 | 234 | ```python 235 | class Model(parl.Model): 236 | def __init__(self, act_dim): 237 | hid1_size = 128 238 | hid2_size = 128 239 | # 3层全连接网络 240 | self.fc1 = layers.fc(size=hid1_size, act='relu') 241 | self.fc2 = layers.fc(size=hid2_size, act='relu') 242 | self.fc3 = layers.fc(size=act_dim, act=None) 243 | 244 | def value(self, obs): 245 | # 定义网络 246 | # 输入state,输出所有action对应的Q,[Q(s,a1), Q(s,a2), Q(s,a3)...] 247 | h1 = self.fc1(obs) 248 | h2 = self.fc2(h1) 249 | Q = self.fc3(h2) 250 | return Q 251 | ``` 252 | 253 | ## Q算法的“藏身之地” 254 | ### 优势 255 | DQN算法较普通算法在经验回放和固定Q目标两方面有了较大的改进 256 | 257 | * 1、经验回放:它充分利用了off-policy的优势,通过训练把结果(成绩)存入Q表格,然后随机从表格中取出一条结果进行优化。这样一方面可以减少样本之间的关联性,另一方面可以提高样本的利用率 258 | 注:训练结果会存进Q表格,当Q表格满了以后,存进来的数据会把最早存进去的数据“挤出去”(弹出) 259 | * 2、固定Q目标:它解决了算法更新不平稳的问题。 260 | 和监督学习做比较,监督学习的最终值要逼近实际结果,这个结果是固定的;但是我们的DQN却不是,它的目标值是经过神经网络以后的一个值,这个值是变动的,不好拟合。DQN团队想到了一个很好的办法:让这个值在一定时间里面保持不变,这样这个目标就可以确定了;目标值更新以后更加接近实际结果,可以更好地进行训练。 261 | 262 | #### (2)Algorithm 263 | * `Algorithm` 定义了具体的算法来更新前向网络(`Model`),也就是通过定义损失函数来更新`Model`,和算法相关的计算都放在`algorithm`中。 264 | 265 | 266 | 267 | 268 | ```python 269 | # from parl.algorithms import DQN # 也可以直接从parl库中导入DQN算法 270 | 271 | class DQN(parl.Algorithm): 272 | def __init__(self, model, act_dim=None, gamma=None, lr=None): 273 | """ DQN algorithm 274 | 275 | Args: 276 | model (parl.Model): 定义Q函数的前向网络结构 277 | act_dim (int): action空间的维度,即有几个action 278 | gamma (float): reward的衰减因子 279 | lr (float): learning rate 学习率. 280 | """ 281 | self.model = model 282 | self.target_model = copy.deepcopy(model) 283 | 284 | assert isinstance(act_dim, int) 285 | assert isinstance(gamma, float) 286 | assert isinstance(lr, float) 287 | self.act_dim = act_dim 288 | self.gamma = gamma 289 | self.lr = lr 290 | 291 | def predict(self, obs): 292 | """ 使用self.model的value网络来获取 [Q(s,a1),Q(s,a2),...]
293 | """ 294 | return self.model.value(obs) 295 | 296 | def learn(self, obs, action, reward, next_obs, terminal): 297 | """ 使用DQN算法更新self.model的value网络 298 | """ 299 | # 从target_model中获取 max Q' 的值,用于计算target_Q 300 | next_pred_value = self.target_model.value(next_obs) 301 | best_v = layers.reduce_max(next_pred_value, dim=1) 302 | best_v.stop_gradient = True # 阻止梯度传递 303 | terminal = layers.cast(terminal, dtype='float32') 304 | target = reward + (1.0 - terminal) * self.gamma * best_v 305 | 306 | pred_value = self.model.value(obs) # 获取Q预测值 307 | # 将action转onehot向量,比如:3 => [0,0,0,1,0] 308 | action_onehot = layers.one_hot(action, self.act_dim) 309 | action_onehot = layers.cast(action_onehot, dtype='float32') 310 | # 下面一行是逐元素相乘,拿到action对应的 Q(s,a) 311 | # 比如:pred_value = [[2.3, 5.7, 1.2, 3.9, 1.4]], action_onehot = [[0,0,0,1,0]] 312 | # ==> pred_action_value = [[3.9]] 313 | pred_action_value = layers.reduce_sum( 314 | layers.elementwise_mul(action_onehot, pred_value), dim=1) 315 | 316 | # 计算 Q(s,a) 与 target_Q的均方差,得到loss 317 | cost = layers.square_error_cost(pred_action_value, target) 318 | cost = layers.reduce_mean(cost) 319 | optimizer = fluid.optimizer.Adam(learning_rate=self.lr) # 使用Adam优化器 320 | optimizer.minimize(cost) 321 | return cost 322 | 323 | def sync_target(self): 324 | """ 把 self.model 的模型参数值同步到 self.target_model 325 | """ 326 | self.model.sync_weights_to(self.target_model) 327 | 328 | ``` 329 | 330 | #### (3)Agent 331 | * `Agent` 负责算法与环境的交互,在交互过程中把生成的数据提供给`Algorithm`来更新模型(`Model`),数据的预处理流程也一般定义在这里。 332 | 333 | 334 | ```python 335 | class Agent(parl.Agent): 336 | def __init__(self, 337 | algorithm, 338 | obs_dim, 339 | act_dim, 340 | e_greed=0.1, 341 | e_greed_decrement=0): 342 | assert isinstance(obs_dim, int) 343 | assert isinstance(act_dim, int) 344 | self.obs_dim = obs_dim 345 | self.act_dim = act_dim 346 | super(Agent, self).__init__(algorithm) 347 | 348 | self.global_step = 0 349 | self.update_target_steps = 200 # 每隔200个training steps再把model的参数复制到target_model中 350 | 351 | self.e_greed = e_greed # 有一定概率随机选取动作,探索 352 | self.e_greed_decrement = e_greed_decrement # 随着训练逐步收敛,探索的程度慢慢降低 353 | 354 | def build_program(self): 355 | self.pred_program = fluid.Program() 356 | self.learn_program = fluid.Program() 357 | 358 | with fluid.program_guard(self.pred_program): # 搭建计算图用于 预测动作,定义输入输出变量 359 | obs = layers.data( 360 | name='obs', shape=[self.obs_dim], dtype='float32') 361 | self.value = self.alg.predict(obs) 362 | 363 | with fluid.program_guard(self.learn_program): # 搭建计算图用于 更新Q网络,定义输入输出变量 364 | obs = layers.data( 365 | name='obs', shape=[self.obs_dim], dtype='float32') 366 | action = layers.data(name='act', shape=[1], dtype='int32') 367 | reward = layers.data(name='reward', shape=[], dtype='float32') 368 | next_obs = layers.data( 369 | name='next_obs', shape=[self.obs_dim], dtype='float32') 370 | terminal = layers.data(name='terminal', shape=[], dtype='bool') 371 | self.cost = self.alg.learn(obs, action, reward, next_obs, terminal) 372 | 373 | def sample(self, obs): 374 | sample = np.random.rand() # 产生0~1之间的小数 375 | if sample < self.e_greed: 376 | act = np.random.randint(self.act_dim) # 探索:每个动作都有概率被选择 377 | else: 378 | act = self.predict(obs) # 选择最优动作 379 | self.e_greed = max( 380 | 0.01, self.e_greed - self.e_greed_decrement) # 随着训练逐步收敛,探索的程度慢慢降低 381 | return act 382 | 383 | def predict(self, obs): # 选择最优动作 384 | obs = np.expand_dims(obs, axis=0) 385 | pred_Q = self.fluid_executor.run( 386 | self.pred_program, 387 | feed={'obs': obs.astype('float32')}, 388 | 
fetch_list=[self.value])[0] 389 | pred_Q = np.squeeze(pred_Q, axis=0) 390 | act = np.argmax(pred_Q) # 选择Q最大的下标,即对应的动作 391 | return act 392 | 393 | def learn(self, obs, act, reward, next_obs, terminal): 394 | # 每隔200个training steps同步一次model和target_model的参数 395 | if self.global_step % self.update_target_steps == 0: 396 | self.alg.sync_target() 397 | self.global_step += 1 398 | 399 | act = np.expand_dims(act, -1) 400 | feed = { 401 | 'obs': obs.astype('float32'), 402 | 'act': act.astype('int32'), 403 | 'reward': reward, 404 | 'next_obs': next_obs.astype('float32'), 405 | 'terminal': terminal 406 | } 407 | cost = self.fluid_executor.run( 408 | self.learn_program, feed=feed, fetch_list=[self.cost])[0] # 训练一次网络 409 | return cost 410 | ``` 411 | 412 | ### Step5 ReplayMemory 413 | * 经验池:用于存储多条经验,实现 经验回放。 414 | ![](https://ai-studio-static-online.cdn.bcebos.com/340817e194974c24ac63dd569fba336cd500d120729e4354825fb9c3501108ab) 415 | 416 | 417 | 418 | ```python 419 | import random 420 | import collections 421 | import numpy as np 422 | 423 | 424 | class ReplayMemory(object): 425 | def __init__(self, max_size): 426 | self.buffer = collections.deque(maxlen=max_size) 427 | 428 | # 增加一条经验到经验池中 429 | def append(self, exp): 430 | self.buffer.append(exp) 431 | 432 | # 从经验池中选取N条经验出来 433 | def sample(self, batch_size): 434 | mini_batch = random.sample(self.buffer, batch_size) 435 | obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], [] 436 | 437 | for experience in mini_batch: 438 | s, a, r, s_p, done = experience 439 | obs_batch.append(s) 440 | action_batch.append(a) 441 | reward_batch.append(r) 442 | next_obs_batch.append(s_p) 443 | done_batch.append(done) 444 | 445 | return np.array(obs_batch).astype('float32'), \ 446 | np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'),\ 447 | np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32') 448 | 449 | def __len__(self): 450 | return len(self.buffer) 451 | 452 | ``` 453 | 454 | ### Step6 Training && Test(训练&&测试) 455 | ![](https://ai-studio-static-online.cdn.bcebos.com/3b9f168ee09a46dc9962decb0d2a60f7d35196d459824b9f923b163a0b4bbe4e) 456 | ![](https://ai-studio-static-online.cdn.bcebos.com/4ffa39f252744a3791c04c5d8381824d4f7447a592d4409da658045061bcd0b9) 457 | * 训练和评估的一个模块 458 | 459 | 460 | ```python 461 | # 训练一个episode 462 | def run_episode(env, agent, rpm): 463 | total_reward = 0 464 | obs = env.reset() 465 | step = 0 466 | while True: 467 | step += 1 468 | action = agent.sample(obs) # 采样动作,所有动作都有概率被尝试到 469 | next_obs, reward, done, _ = env.step(action) 470 | rpm.append((obs, action, reward, next_obs, done)) 471 | 472 | # train model 473 | if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0): 474 | (batch_obs, batch_action, batch_reward, batch_next_obs, 475 | batch_done) = rpm.sample(BATCH_SIZE) 476 | train_loss = agent.learn(batch_obs, batch_action, batch_reward, 477 | batch_next_obs, 478 | batch_done) # s,a,r,s',done 479 | 480 | total_reward += reward 481 | obs = next_obs 482 | if done: 483 | break 484 | return total_reward 485 | 486 | 487 | # 评估 agent, 跑 5 个episode,总reward求平均 488 | def evaluate(env, agent, render=False): 489 | eval_reward = [] 490 | for i in range(5): 491 | obs = env.reset() 492 | episode_reward = 0 493 | while True: 494 | action = agent.predict(obs) # 预测动作,只选最优动作 495 | obs, reward, done, _ = env.step(action) 496 | episode_reward += reward 497 | if render: 498 | env.render() 499 | if done: 500 | break 501 | eval_reward.append(episode_reward) 
502 | return np.mean(eval_reward) 503 | 504 | ``` 505 | 506 | ### Step7 创建环境和Agent,创建经验池,启动训练,保存模型 507 | 508 | # 主函数 509 | * 讲不清,理还乱,看一波图解,冷静冷静!!! 510 | ![](https://ai-studio-static-online.cdn.bcebos.com/dab5d244b98c41108a84e37b79db0009109ed32d4d734ca7b04b68e2e301da09) 511 | ![](https://ai-studio-static-online.cdn.bcebos.com/a297e56a6de045169e4a04391c22aa80bf7cc9d9608446fd97bef4dcd16431c8) 512 | ![](https://ai-studio-static-online.cdn.bcebos.com/8bc3b4cd1913479480baa286b120f45b49cb8ac7c36645df8423d873a7bd6d07) 513 | 514 | 515 | 516 | ```python 517 | env = gym.make('CartPole-v0') # CartPole-v0: 预期最后一次评估总分 > 180(最大值是200) 518 | action_dim = env.action_space.n # CartPole-v0: 2 519 | obs_shape = env.observation_space.shape # CartPole-v0: (4,) 520 | 521 | rpm = ReplayMemory(MEMORY_SIZE) # DQN的经验回放池 522 | 523 | # 根据parl框架构建agent 524 | model = Model(act_dim=action_dim) 525 | algorithm = DQN(model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE) 526 | agent = Agent( 527 | algorithm, 528 | obs_dim=obs_shape[0], 529 | act_dim=action_dim, 530 | e_greed=0.1, # 有一定概率随机选取动作,探索 531 | e_greed_decrement=1e-6) # 随着训练逐步收敛,探索的程度慢慢降低 532 | 533 | # 加载模型 534 | # save_path = './dqn_model.ckpt' 535 | # agent.restore(save_path) 536 | 537 | # 先往经验池里存一些数据,避免最开始训练的时候样本丰富度不够 538 | while len(rpm) < MEMORY_WARMUP_SIZE: 539 | run_episode(env, agent, rpm) 540 | 541 | max_episode = 2000 542 | 543 | # 开始训练 544 | episode = 0 545 | while episode < max_episode: # 训练max_episode个回合,test部分不计算入episode数量 546 | # train part 547 | for i in range(0, 50): 548 | total_reward = run_episode(env, agent, rpm) 549 | episode += 1 550 | 551 | # test part 552 | eval_reward = evaluate(env, agent, render=False) # render=True 查看显示效果 553 | logger.info('episode:{} e_greed:{} test_reward:{}'.format( 554 | episode, agent.e_greed, eval_reward)) 555 | 556 | # 训练结束,保存模型 557 | save_path = './dqn_model.ckpt' 558 | agent.save(save_path) 559 | ``` 560 | 561 | [07-27 18:19:17 MainThread @machine_info.py:88] Cannot find available GPU devices, using CPU now. 562 | [07-27 18:19:17 MainThread @machine_info.py:88] Cannot find available GPU devices, using CPU now. 563 | [07-27 18:19:18 MainThread @machine_info.py:88] Cannot find available GPU devices, using CPU now. 
564 | [07-27 18:19:19 MainThread @:38] episode:50 e_greed:0.09930599999999931 test_reward:11.0 565 | [07-27 18:19:21 MainThread @:38] episode:100 e_greed:0.09877699999999878 test_reward:9.4 566 | [07-27 18:19:23 MainThread @:38] episode:150 e_greed:0.09828699999999829 test_reward:9.4 567 | [07-27 18:19:24 MainThread @:38] episode:200 e_greed:0.09781299999999782 test_reward:9.2 568 | [07-27 18:19:26 MainThread @:38] episode:250 e_greed:0.09733199999999734 test_reward:9.2 569 | [07-27 18:19:27 MainThread @:38] episode:300 e_greed:0.09683499999999684 test_reward:9.6 570 | [07-27 18:19:29 MainThread @:38] episode:350 e_greed:0.09634199999999635 test_reward:9.6 571 | [07-27 18:19:30 MainThread @:38] episode:400 e_greed:0.09583999999999585 test_reward:9.6 572 | [07-27 18:19:32 MainThread @:38] episode:450 e_greed:0.09533799999999534 test_reward:11.0 573 | [07-27 18:19:34 MainThread @:38] episode:500 e_greed:0.09476499999999477 test_reward:9.4 574 | [07-27 18:19:37 MainThread @:38] episode:550 e_greed:0.09395199999999396 test_reward:44.4 575 | [07-27 18:19:43 MainThread @:38] episode:600 e_greed:0.0921909999999922 test_reward:21.2 576 | [07-27 18:20:05 MainThread @:38] episode:650 e_greed:0.08648299999998649 test_reward:186.0 577 | [07-27 18:20:39 MainThread @:38] episode:700 e_greed:0.07723099999997723 test_reward:199.0 578 | [07-27 18:21:12 MainThread @:38] episode:750 e_greed:0.06832399999996833 test_reward:188.6 579 | [07-27 18:21:42 MainThread @:38] episode:800 e_greed:0.06003099999996003 test_reward:124.6 580 | [07-27 18:22:09 MainThread @:38] episode:850 e_greed:0.05293599999995294 test_reward:165.6 581 | [07-27 18:22:31 MainThread @:38] episode:900 e_greed:0.04666999999994667 test_reward:110.2 582 | [07-27 18:22:56 MainThread @:38] episode:950 e_greed:0.040123999999940124 test_reward:111.0 583 | [07-27 18:23:21 MainThread @:38] episode:1000 e_greed:0.03343099999993343 test_reward:133.6 584 | [07-27 18:23:47 MainThread @:38] episode:1050 e_greed:0.026627999999926627 test_reward:130.0 585 | [07-27 18:24:08 MainThread @:38] episode:1100 e_greed:0.02101999999992102 test_reward:130.0 586 | [07-27 18:24:31 MainThread @:38] episode:1150 e_greed:0.015293999999915868 test_reward:179.4 587 | [07-27 18:25:03 MainThread @:38] episode:1200 e_greed:0.01 test_reward:183.0 588 | [07-27 18:25:33 MainThread @:38] episode:1250 e_greed:0.01 test_reward:130.8 589 | [07-27 18:25:59 MainThread @:38] episode:1300 e_greed:0.01 test_reward:172.8 590 | [07-27 18:26:25 MainThread @:38] episode:1350 e_greed:0.01 test_reward:152.2 591 | [07-27 18:26:49 MainThread @:38] episode:1400 e_greed:0.01 test_reward:127.8 592 | [07-27 18:27:19 MainThread @:38] episode:1450 e_greed:0.01 test_reward:182.0 593 | [07-27 18:27:52 MainThread @:38] episode:1500 e_greed:0.01 test_reward:120.6 594 | [07-27 18:28:25 MainThread @:38] episode:1550 e_greed:0.01 test_reward:133.6 595 | [07-27 18:28:45 MainThread @:38] episode:1600 e_greed:0.01 test_reward:14.0 596 | [07-27 18:28:49 MainThread @:38] episode:1650 e_greed:0.01 test_reward:15.8 597 | [07-27 18:29:09 MainThread @:38] episode:1700 e_greed:0.01 test_reward:164.2 598 | [07-27 18:29:47 MainThread @:38] episode:1750 e_greed:0.01 test_reward:200.0 599 | [07-27 18:30:26 MainThread @:38] episode:1800 e_greed:0.01 test_reward:200.0 600 | [07-27 18:31:03 MainThread @:38] episode:1850 e_greed:0.01 test_reward:182.4 601 | [07-27 18:31:40 MainThread @:38] episode:1900 e_greed:0.01 test_reward:193.8 602 | [07-27 18:32:17 MainThread @:38] episode:1950 e_greed:0.01 test_reward:193.4 603 | [07-27 
18:32:53 MainThread @:38] episode:2000 e_greed:0.01 test_reward:170.2 604 | 605 | 606 | ![](https://ai-studio-static-online.cdn.bcebos.com/94a7460d1da44164a7805c5680a2779689b1dc4c78494104aa3d2bda075411fa) 607 | 608 | 609 | # 运行经历 610 | 这一串代码是从课件上copy下来的,所以里面的参数及数据都是已经调整好了的,但是在线下跑数据发现每一次跑也并非到最后都是200也就是每一次的结果都是不一定的,然后我开了显示模式,里面的测试画面和结果会被显示和打印,前面的速度会比后面的快,因为分数低时间自然就短了。 611 | 更据这段时间对深度学习的学习,基本上代码运行没有问题后那么调整几个超参基本上可以得到一个比较好的拟合效果。 612 | 613 | 614 | # 心得体会 615 | 本次“爬山”,发现之前的一些盲点在本次有了解决,,之前对深度学习、PARL、算法等都有了最新的认识,虽然还是那个小白但是认识提上去了,以后还是可以更加努力的去奋斗的,这就是传说中的回头看的时候就知道自己以前是多么的无知了。 616 | 617 | # 这里是三岁,请大家多多指教啊! 618 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Carpoel-at-age-three-DQN-algorithm 2 | 3 | ## 本代码基于AIStudio环境运行!!! 4 | 深度学习入门 | 三岁在飞桨带你入门深度学习—Carpoel,利用PARL复现基于神经网络与DQN算法(真的是0基础) 5 | # 深度学习入门 | 三岁在飞桨带你入门深度学习—Carpoel,利用PARL复现基于神经网络与DQN算法(真的是0基础) 6 | 大家好,这里是三岁,众所周知三岁是编程届小白,为了给大家贡献一个“爬山”的模板, 7 | 三岁利用最基础的深度学习“hello world”项目给大家解析,及做示范。 8 | 三岁老规矩,白话,简单,入门,基础 9 | 如果有什么不准确,不正确的地方希望大家可以提出来! 10 | (代码源于[强化学习7日打卡营-世界冠军带你从零实践>PARL强化学习公开课Lesson3_DQN](https://aistudio.baidu.com/aistudio/projectdetail/569647)) 11 | * 以下项目适用于CPU环境 12 | ## 参考资料 13 | * B站视频地址:[https://www.bilibili.com/video/bv1v54y1v7Qf](https://www.bilibili.com/video/bv1v54y1v7Qf) 14 | * AI 社区文章地址:[https://ai.baidu.com/forum/topic/show/962531](https://ai.baidu.com/forum/topic/show/962531) 15 | * CSDN文章地址:[https://editor.csdn.net/md?articleId=107393006](https://editor.csdn.net/md?articleId=107393006) 16 | * 三岁推文地址:[https://mp.weixin.qq.com/s/6-6RR0XuvTNuXKhX7fFXaQ](https://mp.weixin.qq.com/s/6-6RR0XuvTNuXKhX7fFXaQ) 17 | * 参考论文:[https://www.nature.com/articles/nature14236](https://www.nature.com/articles/nature14236) 18 | * DQNgithub地址:[https://github.com/PaddlePaddle/PARL/tree/develop/examples](https://github.com/PaddlePaddle/PARL/tree/develop/examples) 19 | * 参考视频:[https://www.bilibili.com/video/BV1yv411i7xd?p=12](https://www.bilibili.com/video/BV1yv411i7xd?p=12) 20 | * Carpoel参考资料:[https://gym.openai.com/envs/CartPole-v1/](https://gym.openai.com/envs/CartPole-v1/) 21 | * PARL官方地址:[https://github.com/PaddlePaddle/PARL](https://github.com/PaddlePaddle/PARL) 22 | 23 | 那么我们接下来就开始爬山吧,记得唱着小白船然后录像呦,三岁会帮你调整一下jio的位置的【滑稽】 24 | 25 | 26 | ### 环境预设 27 | 根据实际情况把AI Studio的环境进行修改使其更加符合代码的运行。 28 | 29 | 30 | ```python 31 | !pip uninstall -y parl # 说明:AIStudio预装的parl版本太老,容易跟其他库产生兼容性冲突,建议先卸载 32 | !pip uninstall -y pandas scikit-learn # 提示:在AIStudio中卸载这两个库再import parl可避免warning提示,不卸载也不影响parl的使用 33 | 34 | !pip install gym 35 | !pip install paddlepaddle==1.6.3 36 | !pip install parl==1.3.1 37 | 38 | # 建议下载paddle系列产品时添加百度源 -i https://mirror.baidu.com/pypi/simple 39 | # python -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple 40 | # pip install parl==1.3.1 -i https://mirror.baidu.com/pypi/simple 41 | 42 | # 说明:安装日志中出现两条红色的关于 paddlehub 和 visualdl 的 ERROR 与parl无关,可以忽略,不影响使用 43 | ``` 44 | 45 | Uninstalling parl-1.1.2: 46 | Successfully uninstalled parl-1.1.2 47 | Uninstalling pandas-0.23.4: 48 | Successfully uninstalled pandas-0.23.4 49 | Uninstalling scikit-learn-0.20.0: 50 | Successfully uninstalled scikit-learn-0.20.0 51 | Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/ 52 | Requirement already satisfied: gym in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.12.1) 53 | Requirement already satisfied: requests>=2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages 
(from gym) (2.22.0) 54 | Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.15.0) 55 | Requirement already satisfied: numpy>=1.10.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.16.4) 56 | Requirement already satisfied: pyglet>=1.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.4.5) 57 | Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.3.0) 58 | Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (2019.9.11) 59 | Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (3.0.4) 60 | Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (2.8) 61 | Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (1.25.6) 62 | Requirement already satisfied: future in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pyglet>=1.2.0->gym) (0.18.0) 63 | Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/ 64 | Collecting paddlepaddle==1.6.3 65 | [?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/96/28/e72bebb3c9b3d98eb9b15d9f6d85150f3cbd63e695e59882ff9f04846686/paddlepaddle-1.6.3-cp37-cp37m-manylinux1_x86_64.whl (90.9MB) 66 |  |████████████████████████████████| 90.9MB 488kB/s eta 0:00:011 |███████████████▌ | 44.0MB 481kB/s eta 0:01:38 67 | [?25hRequirement already satisfied: nltk; python_version >= "3.5" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.4.5) 68 | Requirement already satisfied: prettytable in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (0.7.2) 69 | Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.15.0) 70 | Requirement already satisfied: protobuf>=3.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.10.0) 71 | Requirement already satisfied: numpy>=1.12; python_version >= "3.5" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.16.4) 72 | Requirement already satisfied: scipy; python_version >= "3.5" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.3.0) 73 | Requirement already satisfied: objgraph in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.4.1) 74 | Requirement already satisfied: decorator in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (4.4.0) 75 | Requirement already satisfied: matplotlib; python_version >= "3.6" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (2.2.3) 76 | Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (7.1.2) 77 | Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) 
(4.1.1.26) 78 | Requirement already satisfied: requests>=2.20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (2.22.0) 79 | Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (5.1.2) 80 | Requirement already satisfied: funcsigs in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.0.2) 81 | Requirement already satisfied: rarfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.1) 82 | Requirement already satisfied: graphviz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (0.13) 83 | Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from protobuf>=3.1.0->paddlepaddle==1.6.3) (41.4.0) 84 | Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (0.10.0) 85 | Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (1.1.0) 86 | Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (2019.3) 87 | Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (2.8.0) 88 | Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= "3.6"->paddlepaddle==1.6.3) (2.4.2) 89 | Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (2.8) 90 | Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (3.0.4) 91 | Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (2019.9.11) 92 | Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (1.25.6) 93 | Installing collected packages: paddlepaddle 94 | Found existing installation: paddlepaddle 1.8.0 95 | Uninstalling paddlepaddle-1.8.0: 96 | Successfully uninstalled paddlepaddle-1.8.0 97 | Successfully installed paddlepaddle-1.6.3 98 | Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/ 99 | Collecting parl==1.3.1 100 | [?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/62/79/590af38a920792c71afb73fad7583967928b4d0ba9fca76250d935c7fda8/parl-1.3.1-py2.py3-none-any.whl (521kB) 101 |  |████████████████████████████████| 522kB 17.8MB/s eta 0:00:01 102 | [?25hRequirement already satisfied: tb-nightly==1.15.0a20190801 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.15.0a20190801) 103 | Requirement already satisfied: visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux" in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (2.0.0b4) 104 | Requirement already satisfied: pyzmq==18.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (18.0.1) 105 | Requirement already satisfied: flask>=1.0.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.1.1) 106 | Collecting flask-cors (from parl==1.3.1) 107 | Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/78/38/e68b11daa5d613e3a91e4bf3da76c94ac9ee0d9cd515af9c1ab80d36f709/Flask_Cors-3.0.8-py2.py3-none-any.whl 108 | Requirement already satisfied: click in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (7.0) 109 | Collecting psutil>=5.6.2 (from parl==1.3.1) 110 | [?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/aa/3e/d18f2c04cf2b528e18515999b0c8e698c136db78f62df34eee89cee205f1/psutil-5.7.2.tar.gz (460kB) 111 |  |████████████████████████████████| 460kB 52.1MB/s eta 0:00:01 112 | [?25hRequirement already satisfied: tensorboardX==1.8 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.8) 113 | Requirement already satisfied: pyarrow==0.13.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (0.13.0) 114 | Requirement already satisfied: scipy>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.3.0) 115 | Requirement already satisfied: termcolor>=1.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.1.0) 116 | Requirement already satisfied: cloudpickle==1.2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.2.1) 117 | Requirement already satisfied: absl-py>=0.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (0.8.1) 118 | Requirement already satisfied: six>=1.10.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.15.0) 119 | Requirement already satisfied: grpcio>=1.6.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.26.0) 120 | Requirement already satisfied: wheel>=0.26; python_version >= "3" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (0.33.6) 121 | Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (41.4.0) 122 | Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (0.16.0) 123 | Requirement already satisfied: markdown>=2.6.8 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (3.1.1) 124 | Requirement already satisfied: protobuf>=3.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (3.10.0) 125 | Requirement already satisfied: numpy>=1.12.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.16.4) 126 | Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages 
(from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.22.0) 127 | Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.21.0) 128 | Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (3.8.2) 129 | Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.0.0) 130 | Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (4.1.1.26) 131 | Requirement already satisfied: Pillow>=7.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (7.1.2) 132 | Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.0.4->parl==1.3.1) (1.1.0) 133 | Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.0.4->parl==1.3.1) (2.10.3) 134 | Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2019.9.11) 135 | Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.8) 136 | Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (3.0.4) 137 | Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.25.6) 138 | Requirement already satisfied: importlib-metadata; python_version < "3.8" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (0.23) 139 | Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.4.10) 140 | Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.3.4) 141 | Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.0.1) 142 | Requirement already satisfied: aspy.yaml in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (1.3.0) 143 | Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (16.7.9) 144 | Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (0.10.0) 145 | Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (5.1.2) 146 | Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (0.6.1) 147 | Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.6.0) 148 | Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.2.0) 149 | Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2019.3) 150 | Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (2.8.0) 151 | Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.0.4->parl==1.3.1) (1.1.1) 152 | Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata; python_version < "3.8"->pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (0.6.0) 153 | Requirement already satisfied: more-itertools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from zipp>=0.5->importlib-metadata; python_version < "3.8"->pre-commit->visualdl>=2.0.0b; python_version >= "3" and platform_system == "Linux"->parl==1.3.1) (7.2.0) 154 | Building wheels for collected packages: psutil 155 | Building wheel for psutil (setup.py) ... 
[?25ldone 156 | [?25h Created wheel for psutil: filename=psutil-5.7.2-cp37-cp37m-linux_x86_64.whl size=268459 sha256=ef019c256f341f219f0260fa4acfe2863e8f0bf1f96ed95b3f910a2a4c44a74c 157 | Stored in directory: /home/aistudio/.cache/pip/wheels/a8/74/a2/9f54383a7c48678163f965a5d2f4acb794417e60ab0d7351f8 158 | Successfully built psutil 159 | Installing collected packages: flask-cors, psutil, parl 160 | Successfully installed flask-cors-3.0.8 parl-1.3.1 psutil-5.7.2 161 | 162 | 163 | ### Step2 导入依赖 164 | 如果依赖导入失败有可能没有下载第三方库可以在此前加上代码块 165 | !pip instudio 第三方库名 166 | 167 | 如果安装失败可以加上镜像 168 | 169 | ![](https://ai-studio-static-online.cdn.bcebos.com/8ee07cbb24874e96bdac93b07e3a22e3b6b465981c904129990af727701bd0b0) 170 | ``` 171 | 镜像源地址: 172 | 173 | 百度:https://mirror.baidu.com/pypi/simple 174 | 175 | 清华:https://pypi.tuna.tsinghua.edu.cn/simple 176 | 177 | 阿里云:http://mirrors.aliyun.com/pypi/simple/ 178 | 179 | 中国科技大学 https://pypi.mirrors.ustc.edu.cn/simple/ 180 | 181 | 华中理工大学:http://pypi.hustunique.com/ 182 | 183 | 山东理工大学:http://pypi.sdutlinux.org/ 184 | 185 | 豆瓣:http://pypi.douban.com/simple/ 186 | 187 | 例:!pip instudio jieba -i https://mirror.baidu.com/pypi/simple 188 | ``` 189 | 190 | 191 | ```python 192 | import parl 193 | from parl import layers 194 | import paddle.fluid as fluid 195 | import copy 196 | import numpy as np 197 | import os 198 | import gym 199 | from parl.utils import logger 200 | ``` 201 | 202 | # 流程解析 203 | ![](https://ai-studio-static-online.cdn.bcebos.com/25f4800e0c3b44bc9ddc908a4d3dba1793dc74f0f5a045e799771f0466b88e5b) 204 | 205 | 206 | ### Step3 设置超参数 207 | 208 | 209 | ```python 210 | LEARN_FREQ = 5 # 训练频率,不需要每一个step都learn,攒一些新增经验后再learn,提高效率 211 | MEMORY_SIZE = 20000 # replay memory的大小,越大越占用内存 212 | MEMORY_WARMUP_SIZE = 200 # replay_memory 里需要预存一些经验数据,再开启训练 213 | BATCH_SIZE = 32 # 每次给agent learn的数据数量,从replay memory随机里sample一批数据出来 214 | LEARNING_RATE = 0.001 # 学习率 215 | GAMMA = 0.99 # reward 的衰减因子,一般取 0.9 到 0.999 不等 216 | ``` 217 | 218 | # 搭建PARL 219 | 什么是PARL呢?让我们看看官方文档怎么说 220 | 221 | ![飞桨官网描述](https://ai-studio-static-online.cdn.bcebos.com/1f72929e036549d493dee85c2834b8a98e330c4ac6a447f08b647cab09e12bfa) 飞桨官网描述 222 | ![PARL GitHub描述](https://ai-studio-static-online.cdn.bcebos.com/9a19623bdf41411ca07c089d3cf0f64752e0eb51a58a47e694bc4e4614079078) PARL GitHub描述 223 | 224 | * PARL主要是基于Model、Algorithm、Agent三个代码块来实现,其中Model和Agent是用户自定义操作。 225 | * Model:是网络结构:要三层网络还是四层网络都是在Model中去定义(下文是三层的网络结构) 226 | * Agent:是PARL与环境的一个接口,通过对模板的修改即可运用到各个不同的环境中去。 227 | 228 | * 至于Algorithm是内部已经封装好了的,直接加入参数运行即可,主要是算法的模块的展现 229 | 230 | 231 | ### Step4 搭建Model、Algorithm、Agent架构 232 | * `Agent`把产生的数据传给`algorithm`,`algorithm`根据`model`的模型结构计算出`Loss`,使用`SGD`或者其他优化器不断的优化,`PARL`这种架构可以很方便的应用在各类深度强化学习问题中。 233 | 234 | #### (1)Model 235 | * `Model`用来定义前向(`Forward`)网络,用户可以自由的定制自己的网络结构。 236 | 237 | 238 | ```python 239 | class Model(parl.Model): 240 | def __init__(self, act_dim): 241 | hid1_size = 128 242 | hid2_size = 128 243 | # 3层全连接网络 244 | self.fc1 = layers.fc(size=hid1_size, act='relu') 245 | self.fc2 = layers.fc(size=hid2_size, act='relu') 246 | self.fc3 = layers.fc(size=act_dim, act=None) 247 | 248 | def value(self, obs): 249 | # 定义网络 250 | # 输入state,输出所有action对应的Q,[Q(s,a1), Q(s,a2), Q(s,a3)...] 
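        # (added note) The three lines below are the forward pass through the fully connected
        # layers created in __init__: two 128-unit ReLU hidden layers (fc1, fc2) followed by
        # fc3, which outputs one Q value per action.
        # For the CartPole-v0 environment used later in this project, obs has 4 features and
        # act_dim is 2, so Q has shape [batch_size, 2].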
251 | h1 = self.fc1(obs) 252 | h2 = self.fc2(h1) 253 | Q = self.fc3(h2) 254 | return Q 255 | ``` 256 | 257 | ## Q算法的“藏身之地” 258 | ### 优势 259 | DQN算法较普通算法在经验回放和固定Q目标有了较大的改进 260 | 261 | * 1、经验回放:他充分利用了off-colicp的优势,通过训练把结果(成绩)存入Q表格,然后随机从表格中取出一条结果进行优化。这样子一方面可以:减少样本之间的关联性另一方面:提高样本的利用率 262 | 注:训练结果会存进Q表格,当Q表格满了以后,存进来的数据会把最早存进去的数据“挤出去”(弹出) 263 | * 2、固定Q目标他解决了算法更新不平稳的问题。 264 | 和监督学习做比较,监督学习的最终值要逼近实际结果,这个结果是固定的,但是我们的DQN却不是,他的目标值是经过神经网络以后的一个值,那么这个值是变动的不好拟合,怎么办,DQN团队想到了一个很好的办法,让这个值在一定时间里面保持不变,这样子这个目标就可以确定了,然后目标值更新以后更加接近实际结果,可以更好的进行训练。 265 | 266 | #### (2)Algorithm 267 | * `Algorithm` 定义了具体的算法来更新前向网络(`Model`),也就是通过定义损失函数来更新`Model`,和算法相关的计算都放在`algorithm`中。 268 | 269 | 270 | 271 | 272 | ```python 273 | # from parl.algorithms import DQN # 也可以直接从parl库中导入DQN算法 274 | 275 | class DQN(parl.Algorithm): 276 | def __init__(self, model, act_dim=None, gamma=None, lr=None): 277 | """ DQN algorithm 278 | 279 | Args: 280 | model (parl.Model): 定义Q函数的前向网络结构 281 | act_dim (int): action空间的维度,即有几个action 282 | gamma (float): reward的衰减因子 283 | lr (float): learning rate 学习率. 284 | """ 285 | self.model = model 286 | self.target_model = copy.deepcopy(model) 287 | 288 | assert isinstance(act_dim, int) 289 | assert isinstance(gamma, float) 290 | assert isinstance(lr, float) 291 | self.act_dim = act_dim 292 | self.gamma = gamma 293 | self.lr = lr 294 | 295 | def predict(self, obs): 296 | """ 使用self.model的value网络来获取 [Q(s,a1),Q(s,a2),...] 297 | """ 298 | return self.model.value(obs) 299 | 300 | def learn(self, obs, action, reward, next_obs, terminal): 301 | """ 使用DQN算法更新self.model的value网络 302 | """ 303 | # 从target_model中获取 max Q' 的值,用于计算target_Q 304 | next_pred_value = self.target_model.value(next_obs) 305 | best_v = layers.reduce_max(next_pred_value, dim=1) 306 | best_v.stop_gradient = True # 阻止梯度传递 307 | terminal = layers.cast(terminal, dtype='float32') 308 | target = reward + (1.0 - terminal) * self.gamma * best_v 309 | 310 | pred_value = self.model.value(obs) # 获取Q预测值 311 | # 将action转onehot向量,比如:3 => [0,0,0,1,0] 312 | action_onehot = layers.one_hot(action, self.act_dim) 313 | action_onehot = layers.cast(action_onehot, dtype='float32') 314 | # 下面一行是逐元素相乘,拿到action对应的 Q(s,a) 315 | # 比如:pred_value = [[2.3, 5.7, 1.2, 3.9, 1.4]], action_onehot = [[0,0,0,1,0]] 316 | # ==> pred_action_value = [[3.9]] 317 | pred_action_value = layers.reduce_sum( 318 | layers.elementwise_mul(action_onehot, pred_value), dim=1) 319 | 320 | # 计算 Q(s,a) 与 target_Q的均方差,得到loss 321 | cost = layers.square_error_cost(pred_action_value, target) 322 | cost = layers.reduce_mean(cost) 323 | optimizer = fluid.optimizer.Adam(learning_rate=self.lr) # 使用Adam优化器 324 | optimizer.minimize(cost) 325 | return cost 326 | 327 | def sync_target(self): 328 | """ 把 self.model 的模型参数值同步到 self.target_model 329 | """ 330 | self.model.sync_weights_to(self.target_model) 331 | 332 | ``` 333 | 334 | #### (3)Agent 335 | * `Agent` 负责算法与环境的交互,在交互过程中把生成的数据提供给`Algorithm`来更新模型(`Model`),数据的预处理流程也一般定义在这里。 336 | 337 | 338 | ```python 339 | class Agent(parl.Agent): 340 | def __init__(self, 341 | algorithm, 342 | obs_dim, 343 | act_dim, 344 | e_greed=0.1, 345 | e_greed_decrement=0): 346 | assert isinstance(obs_dim, int) 347 | assert isinstance(act_dim, int) 348 | self.obs_dim = obs_dim 349 | self.act_dim = act_dim 350 | super(Agent, self).__init__(algorithm) 351 | 352 | self.global_step = 0 353 | self.update_target_steps = 200 # 每隔200个training steps再把model的参数复制到target_model中 354 | 355 | self.e_greed = e_greed # 有一定概率随机选取动作,探索 356 | self.e_greed_decrement = e_greed_decrement # 随着训练逐步收敛,探索的程度慢慢降低 357 
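    # (added note) build_program, defined next, builds two separate fluid Programs:
    # pred_program runs only the forward pass used for action selection, while
    # learn_program wires up the DQN loss and Adam update defined in DQN.learn().
    # The feed names declared there ('obs', 'act', 'reward', 'next_obs', 'terminal')
    # must match the keys of the feed dict assembled in Agent.learn() below.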
| 358 | def build_program(self): 359 | self.pred_program = fluid.Program() 360 | self.learn_program = fluid.Program() 361 | 362 | with fluid.program_guard(self.pred_program): # 搭建计算图用于 预测动作,定义输入输出变量 363 | obs = layers.data( 364 | name='obs', shape=[self.obs_dim], dtype='float32') 365 | self.value = self.alg.predict(obs) 366 | 367 | with fluid.program_guard(self.learn_program): # 搭建计算图用于 更新Q网络,定义输入输出变量 368 | obs = layers.data( 369 | name='obs', shape=[self.obs_dim], dtype='float32') 370 | action = layers.data(name='act', shape=[1], dtype='int32') 371 | reward = layers.data(name='reward', shape=[], dtype='float32') 372 | next_obs = layers.data( 373 | name='next_obs', shape=[self.obs_dim], dtype='float32') 374 | terminal = layers.data(name='terminal', shape=[], dtype='bool') 375 | self.cost = self.alg.learn(obs, action, reward, next_obs, terminal) 376 | 377 | def sample(self, obs): 378 | sample = np.random.rand() # 产生0~1之间的小数 379 | if sample < self.e_greed: 380 | act = np.random.randint(self.act_dim) # 探索:每个动作都有概率被选择 381 | else: 382 | act = self.predict(obs) # 选择最优动作 383 | self.e_greed = max( 384 | 0.01, self.e_greed - self.e_greed_decrement) # 随着训练逐步收敛,探索的程度慢慢降低 385 | return act 386 | 387 | def predict(self, obs): # 选择最优动作 388 | obs = np.expand_dims(obs, axis=0) 389 | pred_Q = self.fluid_executor.run( 390 | self.pred_program, 391 | feed={'obs': obs.astype('float32')}, 392 | fetch_list=[self.value])[0] 393 | pred_Q = np.squeeze(pred_Q, axis=0) 394 | act = np.argmax(pred_Q) # 选择Q最大的下标,即对应的动作 395 | return act 396 | 397 | def learn(self, obs, act, reward, next_obs, terminal): 398 | # 每隔200个training steps同步一次model和target_model的参数 399 | if self.global_step % self.update_target_steps == 0: 400 | self.alg.sync_target() 401 | self.global_step += 1 402 | 403 | act = np.expand_dims(act, -1) 404 | feed = { 405 | 'obs': obs.astype('float32'), 406 | 'act': act.astype('int32'), 407 | 'reward': reward, 408 | 'next_obs': next_obs.astype('float32'), 409 | 'terminal': terminal 410 | } 411 | cost = self.fluid_executor.run( 412 | self.learn_program, feed=feed, fetch_list=[self.cost])[0] # 训练一次网络 413 | return cost 414 | ``` 415 | 416 | ### Step5 ReplayMemory 417 | * 经验池:用于存储多条经验,实现 经验回放。 418 | ![](https://ai-studio-static-online.cdn.bcebos.com/340817e194974c24ac63dd569fba336cd500d120729e4354825fb9c3501108ab) 419 | 420 | 421 | 422 | ```python 423 | import random 424 | import collections 425 | import numpy as np 426 | 427 | 428 | class ReplayMemory(object): 429 | def __init__(self, max_size): 430 | self.buffer = collections.deque(maxlen=max_size) 431 | 432 | # 增加一条经验到经验池中 433 | def append(self, exp): 434 | self.buffer.append(exp) 435 | 436 | # 从经验池中选取N条经验出来 437 | def sample(self, batch_size): 438 | mini_batch = random.sample(self.buffer, batch_size) 439 | obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], [] 440 | 441 | for experience in mini_batch: 442 | s, a, r, s_p, done = experience 443 | obs_batch.append(s) 444 | action_batch.append(a) 445 | reward_batch.append(r) 446 | next_obs_batch.append(s_p) 447 | done_batch.append(done) 448 | 449 | return np.array(obs_batch).astype('float32'), \ 450 | np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'),\ 451 | np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32') 452 | 453 | def __len__(self): 454 | return len(self.buffer) 455 | 456 | ``` 457 | 458 | ### Step6 Training && Test(训练&&测试) 459 | 
![](https://ai-studio-static-online.cdn.bcebos.com/3b9f168ee09a46dc9962decb0d2a60f7d35196d459824b9f923b163a0b4bbe4e) 460 | ![](https://ai-studio-static-online.cdn.bcebos.com/4ffa39f252744a3791c04c5d8381824d4f7447a592d4409da658045061bcd0b9) 461 | * 训练和评估的一个模块 462 | 463 | 464 | ```python 465 | # 训练一个episode 466 | def run_episode(env, agent, rpm): 467 | total_reward = 0 468 | obs = env.reset() 469 | step = 0 470 | while True: 471 | step += 1 472 | action = agent.sample(obs) # 采样动作,所有动作都有概率被尝试到 473 | next_obs, reward, done, _ = env.step(action) 474 | rpm.append((obs, action, reward, next_obs, done)) 475 | 476 | # train model 477 | if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0): 478 | (batch_obs, batch_action, batch_reward, batch_next_obs, 479 | batch_done) = rpm.sample(BATCH_SIZE) 480 | train_loss = agent.learn(batch_obs, batch_action, batch_reward, 481 | batch_next_obs, 482 | batch_done) # s,a,r,s',done 483 | 484 | total_reward += reward 485 | obs = next_obs 486 | if done: 487 | break 488 | return total_reward 489 | 490 | 491 | # 评估 agent, 跑 5 个episode,总reward求平均 492 | def evaluate(env, agent, render=False): 493 | eval_reward = [] 494 | for i in range(5): 495 | obs = env.reset() 496 | episode_reward = 0 497 | while True: 498 | action = agent.predict(obs) # 预测动作,只选最优动作 499 | obs, reward, done, _ = env.step(action) 500 | episode_reward += reward 501 | if render: 502 | env.render() 503 | if done: 504 | break 505 | eval_reward.append(episode_reward) 506 | return np.mean(eval_reward) 507 | 508 | ``` 509 | 510 | ### Step7 创建环境和Agent,创建经验池,启动训练,保存模型 511 | 512 | # 主函数 513 | * 讲不清,理还乱,看一波图解,冷静冷静!!! 514 | ![](https://ai-studio-static-online.cdn.bcebos.com/dab5d244b98c41108a84e37b79db0009109ed32d4d734ca7b04b68e2e301da09) 515 | ![](https://ai-studio-static-online.cdn.bcebos.com/a297e56a6de045169e4a04391c22aa80bf7cc9d9608446fd97bef4dcd16431c8) 516 | ![](https://ai-studio-static-online.cdn.bcebos.com/8bc3b4cd1913479480baa286b120f45b49cb8ac7c36645df8423d873a7bd6d07) 517 | 518 | 519 | 520 | ```python 521 | env = gym.make('CartPole-v0') # CartPole-v0: 预期最后一次评估总分 > 180(最大值是200) 522 | action_dim = env.action_space.n # CartPole-v0: 2 523 | obs_shape = env.observation_space.shape # CartPole-v0: (4,) 524 | 525 | rpm = ReplayMemory(MEMORY_SIZE) # DQN的经验回放池 526 | 527 | # 根据parl框架构建agent 528 | model = Model(act_dim=action_dim) 529 | algorithm = DQN(model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE) 530 | agent = Agent( 531 | algorithm, 532 | obs_dim=obs_shape[0], 533 | act_dim=action_dim, 534 | e_greed=0.1, # 有一定概率随机选取动作,探索 535 | e_greed_decrement=1e-6) # 随着训练逐步收敛,探索的程度慢慢降低 536 | 537 | # 加载模型 538 | # save_path = './dqn_model.ckpt' 539 | # agent.restore(save_path) 540 | 541 | # 先往经验池里存一些数据,避免最开始训练的时候样本丰富度不够 542 | while len(rpm) < MEMORY_WARMUP_SIZE: 543 | run_episode(env, agent, rpm) 544 | 545 | max_episode = 2000 546 | 547 | # 开始训练 548 | episode = 0 549 | while episode < max_episode: # 训练max_episode个回合,test部分不计算入episode数量 550 | # train part 551 | for i in range(0, 50): 552 | total_reward = run_episode(env, agent, rpm) 553 | episode += 1 554 | 555 | # test part 556 | eval_reward = evaluate(env, agent, render=False) # render=True 查看显示效果 557 | logger.info('episode:{} e_greed:{} test_reward:{}'.format( 558 | episode, agent.e_greed, eval_reward)) 559 | 560 | # 训练结束,保存模型 561 | save_path = './dqn_model.ckpt' 562 | agent.save(save_path) 563 | ``` 564 | 565 | [07-27 18:19:17 MainThread @machine_info.py:88] Cannot find available GPU devices, using CPU now. 
566 | [07-27 18:19:17 MainThread @machine_info.py:88] Cannot find available GPU devices, using CPU now. 567 | [07-27 18:19:18 MainThread @machine_info.py:88] Cannot find available GPU devices, using CPU now. 568 | [07-27 18:19:19 MainThread @:38] episode:50 e_greed:0.09930599999999931 test_reward:11.0 569 | [07-27 18:19:21 MainThread @:38] episode:100 e_greed:0.09877699999999878 test_reward:9.4 570 | [07-27 18:19:23 MainThread @:38] episode:150 e_greed:0.09828699999999829 test_reward:9.4 571 | [07-27 18:19:24 MainThread @:38] episode:200 e_greed:0.09781299999999782 test_reward:9.2 572 | [07-27 18:19:26 MainThread @:38] episode:250 e_greed:0.09733199999999734 test_reward:9.2 573 | [07-27 18:19:27 MainThread @:38] episode:300 e_greed:0.09683499999999684 test_reward:9.6 574 | [07-27 18:19:29 MainThread @:38] episode:350 e_greed:0.09634199999999635 test_reward:9.6 575 | [07-27 18:19:30 MainThread @:38] episode:400 e_greed:0.09583999999999585 test_reward:9.6 576 | [07-27 18:19:32 MainThread @:38] episode:450 e_greed:0.09533799999999534 test_reward:11.0 577 | [07-27 18:19:34 MainThread @:38] episode:500 e_greed:0.09476499999999477 test_reward:9.4 578 | [07-27 18:19:37 MainThread @:38] episode:550 e_greed:0.09395199999999396 test_reward:44.4 579 | [07-27 18:19:43 MainThread @:38] episode:600 e_greed:0.0921909999999922 test_reward:21.2 580 | [07-27 18:20:05 MainThread @:38] episode:650 e_greed:0.08648299999998649 test_reward:186.0 581 | [07-27 18:20:39 MainThread @:38] episode:700 e_greed:0.07723099999997723 test_reward:199.0 582 | [07-27 18:21:12 MainThread @:38] episode:750 e_greed:0.06832399999996833 test_reward:188.6 583 | [07-27 18:21:42 MainThread @:38] episode:800 e_greed:0.06003099999996003 test_reward:124.6 584 | [07-27 18:22:09 MainThread @:38] episode:850 e_greed:0.05293599999995294 test_reward:165.6 585 | [07-27 18:22:31 MainThread @:38] episode:900 e_greed:0.04666999999994667 test_reward:110.2 586 | [07-27 18:22:56 MainThread @:38] episode:950 e_greed:0.040123999999940124 test_reward:111.0 587 | [07-27 18:23:21 MainThread @:38] episode:1000 e_greed:0.03343099999993343 test_reward:133.6 588 | [07-27 18:23:47 MainThread @:38] episode:1050 e_greed:0.026627999999926627 test_reward:130.0 589 | [07-27 18:24:08 MainThread @:38] episode:1100 e_greed:0.02101999999992102 test_reward:130.0 590 | [07-27 18:24:31 MainThread @:38] episode:1150 e_greed:0.015293999999915868 test_reward:179.4 591 | [07-27 18:25:03 MainThread @:38] episode:1200 e_greed:0.01 test_reward:183.0 592 | [07-27 18:25:33 MainThread @:38] episode:1250 e_greed:0.01 test_reward:130.8 593 | [07-27 18:25:59 MainThread @:38] episode:1300 e_greed:0.01 test_reward:172.8 594 | [07-27 18:26:25 MainThread @:38] episode:1350 e_greed:0.01 test_reward:152.2 595 | [07-27 18:26:49 MainThread @:38] episode:1400 e_greed:0.01 test_reward:127.8 596 | [07-27 18:27:19 MainThread @:38] episode:1450 e_greed:0.01 test_reward:182.0 597 | [07-27 18:27:52 MainThread @:38] episode:1500 e_greed:0.01 test_reward:120.6 598 | [07-27 18:28:25 MainThread @:38] episode:1550 e_greed:0.01 test_reward:133.6 599 | [07-27 18:28:45 MainThread @:38] episode:1600 e_greed:0.01 test_reward:14.0 600 | [07-27 18:28:49 MainThread @:38] episode:1650 e_greed:0.01 test_reward:15.8 601 | [07-27 18:29:09 MainThread @:38] episode:1700 e_greed:0.01 test_reward:164.2 602 | [07-27 18:29:47 MainThread @:38] episode:1750 e_greed:0.01 test_reward:200.0 603 | [07-27 18:30:26 MainThread @:38] episode:1800 e_greed:0.01 test_reward:200.0 604 | [07-27 18:31:03 MainThread @:38] episode:1850 
e_greed:0.01 test_reward:182.4 605 | [07-27 18:31:40 MainThread @:38] episode:1900 e_greed:0.01 test_reward:193.8 606 | [07-27 18:32:17 MainThread @:38] episode:1950 e_greed:0.01 test_reward:193.4 607 | [07-27 18:32:53 MainThread @:38] episode:2000 e_greed:0.01 test_reward:170.2 608 | 609 | 610 | ![](https://ai-studio-static-online.cdn.bcebos.com/94a7460d1da44164a7805c5680a2779689b1dc4c78494104aa3d2bda075411fa) 611 | 612 | 613 | # 运行经历 614 | 这一串代码是从课件上copy下来的,所以里面的参数及数据都是已经调整好了的,但是在线下跑数据发现每一次跑也并非到最后都是200也就是每一次的结果都是不一定的,然后我开了显示模式,里面的测试画面和结果会被显示和打印,前面的速度会比后面的快,因为分数低时间自然就短了。 615 | 更据这段时间对深度学习的学习,基本上代码运行没有问题后那么调整几个超参基本上可以得到一个比较好的拟合效果。 616 | 617 | 618 | # 心得体会 619 | 本次“爬山”,发现之前的一些盲点在本次有了解决,,之前对深度学习、PARL、算法等都有了最新的认识,虽然还是那个小白但是认识提上去了,以后还是可以更加努力的去奋斗的,这就是传说中的回头看的时候就知道自己以前是多么的无知了。 620 | 621 | # 这里是三岁,请大家多多指教啊! 622 | -------------------------------------------------------------------------------- /利用PARL复现基于神经网络与DQN算法.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": false 7 | }, 8 | "source": [ 9 | "# 深度学习入门 | 三岁在飞桨带你入门深度学习—Carpoel,利用PARL复现基于神经网络与DQN算法(真的是0基础)\n", 10 | "大家好,这里是三岁,众所周知三岁是编程届小白,为了给大家贡献一个“爬山”的模板,\n", 11 | "三岁利用最基础的深度学习“hello world”项目给大家解析,及做示范。\n", 12 | "三岁老规矩,白话,简单,入门,基础\n", 13 | "如果有什么不准确,不正确的地方希望大家可以提出来!\n", 14 | "(代码源于[强化学习7日打卡营-世界冠军带你从零实践>PARL强化学习公开课Lesson3_DQN](https://aistudio.baidu.com/aistudio/projectdetail/569647))\n", 15 | "* 以下项目适用于CPU环境\n", 16 | "## 参考资料\n", 17 | "* B站视频地址:[https://www.bilibili.com/video/bv1v54y1v7Qf](https://www.bilibili.com/video/bv1v54y1v7Qf)\n", 18 | "* AI 社区文章地址:[https://ai.baidu.com/forum/topic/show/962531](https://ai.baidu.com/forum/topic/show/962531)\n", 19 | "* CSDN文章地址:[https://editor.csdn.net/md?articleId=107393006](https://editor.csdn.net/md?articleId=107393006)\n", 20 | "* 三岁推文地址:[https://mp.weixin.qq.com/s/6-6RR0XuvTNuXKhX7fFXaQ](https://mp.weixin.qq.com/s/6-6RR0XuvTNuXKhX7fFXaQ)\n", 21 | "* 参考论文:[https://www.nature.com/articles/nature14236](https://www.nature.com/articles/nature14236)\n", 22 | "* DQNgithub地址:[https://github.com/PaddlePaddle/PARL/tree/develop/examples](https://github.com/PaddlePaddle/PARL/tree/develop/examples)\n", 23 | "* 参考视频:[https://www.bilibili.com/video/BV1yv411i7xd?p=12](https://www.bilibili.com/video/BV1yv411i7xd?p=12)\n", 24 | "* Carpoel参考资料:[https://gym.openai.com/envs/CartPole-v1/](https://gym.openai.com/envs/CartPole-v1/)\n", 25 | "* PARL官方地址:[https://github.com/PaddlePaddle/PARL](https://github.com/PaddlePaddle/PARL)\n", 26 | "\n", 27 | "那么我们接下来就开始爬山吧,记得唱着小白船然后录像呦,三岁会帮你调整一下jio的位置的【滑稽】\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "collapsed": false 34 | }, 35 | "source": [ 36 | "### 环境预设\n", 37 | "根据实际情况把AI Studio的环境进行修改使其更加符合代码的运行。" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": { 44 | "collapsed": false 45 | }, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Uninstalling parl-1.1.2:\n", 52 | " Successfully uninstalled parl-1.1.2\n", 53 | "Uninstalling pandas-0.23.4:\n", 54 | " Successfully uninstalled pandas-0.23.4\n", 55 | "Uninstalling scikit-learn-0.20.0:\n", 56 | " Successfully uninstalled scikit-learn-0.20.0\n", 57 | "Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/\n", 58 | "Requirement already satisfied: gym in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.12.1)\n", 59 | "Requirement already satisfied: requests>=2.0 
in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (2.22.0)\n", 60 | "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.15.0)\n", 61 | "Requirement already satisfied: numpy>=1.10.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.16.4)\n", 62 | "Requirement already satisfied: pyglet>=1.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.4.5)\n", 63 | "Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gym) (1.3.0)\n", 64 | "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (2019.9.11)\n", 65 | "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (3.0.4)\n", 66 | "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (2.8)\n", 67 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.0->gym) (1.25.6)\n", 68 | "Requirement already satisfied: future in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pyglet>=1.2.0->gym) (0.18.0)\n", 69 | "Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/\n", 70 | "Collecting paddlepaddle==1.6.3\n", 71 | "\u001b[?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/96/28/e72bebb3c9b3d98eb9b15d9f6d85150f3cbd63e695e59882ff9f04846686/paddlepaddle-1.6.3-cp37-cp37m-manylinux1_x86_64.whl (90.9MB)\n", 72 | "\u001b[K |████████████████████████████████| 90.9MB 488kB/s eta 0:00:011 |███████████████▌ | 44.0MB 481kB/s eta 0:01:38\n", 73 | "\u001b[?25hRequirement already satisfied: nltk; python_version >= \"3.5\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.4.5)\n", 74 | "Requirement already satisfied: prettytable in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (0.7.2)\n", 75 | "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.15.0)\n", 76 | "Requirement already satisfied: protobuf>=3.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.10.0)\n", 77 | "Requirement already satisfied: numpy>=1.12; python_version >= \"3.5\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.16.4)\n", 78 | "Requirement already satisfied: scipy; python_version >= \"3.5\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.3.0)\n", 79 | "Requirement already satisfied: objgraph in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.4.1)\n", 80 | "Requirement already satisfied: decorator in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (4.4.0)\n", 81 | "Requirement already satisfied: matplotlib; python_version >= \"3.6\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (2.2.3)\n", 82 | "Requirement already satisfied: Pillow in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (7.1.2)\n", 83 | "Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (4.1.1.26)\n", 84 | "Requirement already satisfied: requests>=2.20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (2.22.0)\n", 85 | "Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (5.1.2)\n", 86 | "Requirement already satisfied: funcsigs in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (1.0.2)\n", 87 | "Requirement already satisfied: rarfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (3.1)\n", 88 | "Requirement already satisfied: graphviz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle==1.6.3) (0.13)\n", 89 | "Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from protobuf>=3.1.0->paddlepaddle==1.6.3) (41.4.0)\n", 90 | "Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= \"3.6\"->paddlepaddle==1.6.3) (0.10.0)\n", 91 | "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= \"3.6\"->paddlepaddle==1.6.3) (1.1.0)\n", 92 | "Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= \"3.6\"->paddlepaddle==1.6.3) (2019.3)\n", 93 | "Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= \"3.6\"->paddlepaddle==1.6.3) (2.8.0)\n", 94 | "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib; python_version >= \"3.6\"->paddlepaddle==1.6.3) (2.4.2)\n", 95 | "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (2.8)\n", 96 | "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (3.0.4)\n", 97 | "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (2019.9.11)\n", 98 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle==1.6.3) (1.25.6)\n", 99 | "Installing collected packages: paddlepaddle\n", 100 | " Found existing installation: paddlepaddle 1.8.0\n", 101 | " Uninstalling paddlepaddle-1.8.0:\n", 102 | " Successfully uninstalled paddlepaddle-1.8.0\n", 103 | "Successfully installed paddlepaddle-1.6.3\n", 104 | "Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/\n", 105 | "Collecting parl==1.3.1\n", 106 | "\u001b[?25l Downloading 
https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/62/79/590af38a920792c71afb73fad7583967928b4d0ba9fca76250d935c7fda8/parl-1.3.1-py2.py3-none-any.whl (521kB)\n", 107 | "\u001b[K |████████████████████████████████| 522kB 17.8MB/s eta 0:00:01\n", 108 | "\u001b[?25hRequirement already satisfied: tb-nightly==1.15.0a20190801 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.15.0a20190801)\n", 109 | "Requirement already satisfied: visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (2.0.0b4)\n", 110 | "Requirement already satisfied: pyzmq==18.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (18.0.1)\n", 111 | "Requirement already satisfied: flask>=1.0.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.1.1)\n", 112 | "Collecting flask-cors (from parl==1.3.1)\n", 113 | " Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/78/38/e68b11daa5d613e3a91e4bf3da76c94ac9ee0d9cd515af9c1ab80d36f709/Flask_Cors-3.0.8-py2.py3-none-any.whl\n", 114 | "Requirement already satisfied: click in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (7.0)\n", 115 | "Collecting psutil>=5.6.2 (from parl==1.3.1)\n", 116 | "\u001b[?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/aa/3e/d18f2c04cf2b528e18515999b0c8e698c136db78f62df34eee89cee205f1/psutil-5.7.2.tar.gz (460kB)\n", 117 | "\u001b[K |████████████████████████████████| 460kB 52.1MB/s eta 0:00:01\n", 118 | "\u001b[?25hRequirement already satisfied: tensorboardX==1.8 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.8)\n", 119 | "Requirement already satisfied: pyarrow==0.13.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (0.13.0)\n", 120 | "Requirement already satisfied: scipy>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.3.0)\n", 121 | "Requirement already satisfied: termcolor>=1.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.1.0)\n", 122 | "Requirement already satisfied: cloudpickle==1.2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from parl==1.3.1) (1.2.1)\n", 123 | "Requirement already satisfied: absl-py>=0.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (0.8.1)\n", 124 | "Requirement already satisfied: six>=1.10.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.15.0)\n", 125 | "Requirement already satisfied: grpcio>=1.6.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.26.0)\n", 126 | "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (0.33.6)\n", 127 | "Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (41.4.0)\n", 128 | "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from 
tb-nightly==1.15.0a20190801->parl==1.3.1) (0.16.0)\n", 129 | "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (3.1.1)\n", 130 | "Requirement already satisfied: protobuf>=3.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (3.10.0)\n", 131 | "Requirement already satisfied: numpy>=1.12.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-nightly==1.15.0a20190801->parl==1.3.1) (1.16.4)\n", 132 | "Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (2.22.0)\n", 133 | "Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (1.21.0)\n", 134 | "Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (3.8.2)\n", 135 | "Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (1.0.0)\n", 136 | "Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (4.1.1.26)\n", 137 | "Requirement already satisfied: Pillow>=7.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (7.1.2)\n", 138 | "Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.0.4->parl==1.3.1) (1.1.0)\n", 139 | "Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.0.4->parl==1.3.1) (2.10.3)\n", 140 | "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (2019.9.11)\n", 141 | "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (2.8)\n", 142 | "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (3.0.4)\n", 143 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (1.25.6)\n", 144 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and 
platform_system == \"Linux\"->parl==1.3.1) (0.23)\n", 145 | "Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (1.4.10)\n", 146 | "Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (1.3.4)\n", 147 | "Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (2.0.1)\n", 148 | "Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (1.3.0)\n", 149 | "Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (16.7.9)\n", 150 | "Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (0.10.0)\n", 151 | "Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (5.1.2)\n", 152 | "Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (0.6.1)\n", 153 | "Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (2.6.0)\n", 154 | "Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (2.2.0)\n", 155 | "Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (2019.3)\n", 156 | "Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (2.8.0)\n", 157 | "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.0.4->parl==1.3.1) (1.1.1)\n", 158 | "Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata; python_version < \"3.8\"->pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (0.6.0)\n", 159 | "Requirement already satisfied: more-itertools in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from zipp>=0.5->importlib-metadata; python_version < \"3.8\"->pre-commit->visualdl>=2.0.0b; python_version >= \"3\" and platform_system == \"Linux\"->parl==1.3.1) (7.2.0)\n", 160 | "Building wheels for collected packages: psutil\n", 161 | " Building wheel for psutil (setup.py) ... \u001b[?25ldone\n", 162 | "\u001b[?25h Created wheel for psutil: filename=psutil-5.7.2-cp37-cp37m-linux_x86_64.whl size=268459 sha256=ef019c256f341f219f0260fa4acfe2863e8f0bf1f96ed95b3f910a2a4c44a74c\n", 163 | " Stored in directory: /home/aistudio/.cache/pip/wheels/a8/74/a2/9f54383a7c48678163f965a5d2f4acb794417e60ab0d7351f8\n", 164 | "Successfully built psutil\n", 165 | "Installing collected packages: flask-cors, psutil, parl\n", 166 | "Successfully installed flask-cors-3.0.8 parl-1.3.1 psutil-5.7.2\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "!pip uninstall -y parl # 说明:AIStudio预装的parl版本太老,容易跟其他库产生兼容性冲突,建议先卸载\n", 172 | "!pip uninstall -y pandas scikit-learn # 提示:在AIStudio中卸载这两个库再import parl可避免warning提示,不卸载也不影响parl的使用\n", 173 | "\n", 174 | "!pip install gym\n", 175 | "!pip install paddlepaddle==1.6.3\n", 176 | "!pip install parl==1.3.1\n", 177 | "\n", 178 | "# 建议下载paddle系列产品时添加百度源 -i https://mirror.baidu.com/pypi/simple\n", 179 | "# python -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple\n", 180 | "# pip install parl==1.3.1 -i https://mirror.baidu.com/pypi/simple\n", 181 | "\n", 182 | "# 说明:安装日志中出现两条红色的关于 paddlehub 和 visualdl 的 ERROR 与parl无关,可以忽略,不影响使用" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": { 188 | "collapsed": false 189 | }, 190 | "source": [ 191 | "### Step2 导入依赖\n", 192 | "如果依赖导入失败有可能没有下载第三方库可以在此前加上代码块\n", 193 | "!pip instudio 第三方库名 \n", 194 | "\n", 195 | "如果安装失败可以加上镜像\n", 196 | "\n", 197 | "![](https://ai-studio-static-online.cdn.bcebos.com/8ee07cbb24874e96bdac93b07e3a22e3b6b465981c904129990af727701bd0b0)\n", 198 | "```\n", 199 | "镜像源地址:\n", 200 | "\n", 201 | "百度:https://mirror.baidu.com/pypi/simple\n", 202 | "\n", 203 | "清华:https://pypi.tuna.tsinghua.edu.cn/simple\n", 204 | "\n", 205 | "阿里云:http://mirrors.aliyun.com/pypi/simple/\n", 206 | "\n", 207 | "中国科技大学 https://pypi.mirrors.ustc.edu.cn/simple/\n", 208 | "\n", 209 | "华中理工大学:http://pypi.hustunique.com/\n", 210 | "\n", 211 | "山东理工大学:http://pypi.sdutlinux.org/\n", 212 | "\n", 213 | "豆瓣:http://pypi.douban.com/simple/\n", 214 | "\n", 215 | "例:!pip instudio jieba -i https://mirror.baidu.com/pypi/simple\n", 216 | "```" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 2, 222 | "metadata": { 223 | "collapsed": false 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "import parl\n", 228 | "from parl import layers\n", 229 | "import paddle.fluid as fluid\n", 230 | "import copy\n", 231 | "import numpy as np\n", 232 | "import os\n", 233 | "import gym\n", 234 | "from parl.utils import logger" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": { 240 | "collapsed": false 241 | }, 242 | "source": [ 243 | "# 流程解析\n", 244 | "![](https://ai-studio-static-online.cdn.bcebos.com/25f4800e0c3b44bc9ddc908a4d3dba1793dc74f0f5a045e799771f0466b88e5b)\n" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": { 250 | "collapsed": false 251 | }, 252 | "source": [ 253 | "### Step3 设置超参数" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 3, 259 | "metadata": { 260 | "collapsed": false 261 | }, 262 | "outputs": [], 263 | "source": [ 264 | "LEARN_FREQ = 5 # 
训练频率,不需要每一个step都learn,攒一些新增经验后再learn,提高效率\n", 265 | "MEMORY_SIZE = 20000 # replay memory的大小,越大越占用内存\n", 266 | "MEMORY_WARMUP_SIZE = 200 # replay_memory 里需要预存一些经验数据,再开启训练\n", 267 | "BATCH_SIZE = 32 # 每次给agent learn的数据数量,从replay memory里随机sample一批数据出来\n", 268 | "LEARNING_RATE = 0.001 # 学习率\n", 269 | "GAMMA = 0.99 # reward 的衰减因子,一般取 0.9 到 0.999 不等" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": { 275 | "collapsed": false 276 | }, 277 | "source": [ 278 | "# 搭建PARL\n", 279 | "什么是PARL呢?让我们看看官方文档怎么说\n", 280 | "\n", 281 | "![飞桨官网描述](https://ai-studio-static-online.cdn.bcebos.com/1f72929e036549d493dee85c2834b8a98e330c4ac6a447f08b647cab09e12bfa) 飞桨官网描述\n", 282 | "![PARL GitHub描述](https://ai-studio-static-online.cdn.bcebos.com/9a19623bdf41411ca07c089d3cf0f64752e0eb51a58a47e694bc4e4614079078) PARL GitHub描述\n", 283 | "\n", 284 | "* PARL主要是基于Model、Algorithm、Agent三个代码块来实现,其中Model和Agent是用户自定义的部分。\n", 285 | "* Model:是网络结构,要三层网络还是四层网络都是在Model中去定义(下文是三层的网络结构)\n", 286 | "* Agent:是PARL与环境的一个接口,通过对模板的修改即可运用到各个不同的环境中去。\n", 287 | "\n", 288 | "* 至于Algorithm,PARL内部已经封装好了,直接传入参数运行即可,主要是算法模块的实现\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": { 294 | "collapsed": false 295 | }, 296 | "source": [ 297 | "### Step4 搭建Model、Algorithm、Agent架构\n", 298 | "* `Agent`把产生的数据传给`algorithm`,`algorithm`根据`model`的模型结构计算出`Loss`,使用`SGD`或者其他优化器不断地优化,`PARL`这种架构可以很方便地应用在各类深度强化学习问题中。\n", 299 | "\n", 300 | "#### (1)Model\n", 301 | "* `Model`用来定义前向(`Forward`)网络,用户可以自由地定制自己的网络结构。" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 4, 307 | "metadata": { 308 | "collapsed": false 309 | }, 310 | "outputs": [], 311 | "source": [ 312 | "class Model(parl.Model):\n", 313 | " def __init__(self, act_dim):\n", 314 | " hid1_size = 128\n", 315 | " hid2_size = 128\n", 316 | " # 3层全连接网络\n", 317 | " self.fc1 = layers.fc(size=hid1_size, act='relu')\n", 318 | " self.fc2 = layers.fc(size=hid2_size, act='relu')\n", 319 | " self.fc3 = layers.fc(size=act_dim, act=None)\n", 320 | "\n", 321 | " def value(self, obs):\n", 322 | " # 定义网络\n", 323 | " # 输入state,输出所有action对应的Q,[Q(s,a1), Q(s,a2), Q(s,a3)...]\n", 324 | " h1 = self.fc1(obs)\n", 325 | " h2 = self.fc2(h1)\n", 326 | " Q = self.fc3(h2)\n", 327 | " return Q" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": { 333 | "collapsed": false 334 | }, 335 | "source": [ 336 | "## Q算法的“藏身之地”\n", 337 | "### 优势\n", 338 | "DQN算法较普通算法在经验回放和固定Q目标这两方面有较大的改进\n", 339 | "\n", 340 | "* 1、经验回放:它充分利用了off-policy的优势,把与环境交互得到的经验(状态、动作、奖励等)存入经验池(replay memory),然后随机从中取出一批经验进行训练。这样一方面可以减少样本之间的关联性,另一方面可以提高样本的利用率\n", 341 | "注:经验会不断存进经验池,当经验池满了以后,新存进来的数据会把最早存进去的数据“挤出去”(弹出)\n", 342 | "* 2、固定Q目标:它解决了算法更新不平稳的问题。\n", 343 | "和监督学习做比较:监督学习的预测值要逼近实际结果,这个目标是固定的;但是DQN不一样,它的目标值是经过神经网络计算出来的,会随着网络的更新而变动,不好拟合。怎么办?DQN团队想到了一个很好的办法:让这个目标值在一定时间内保持不变(用一个固定的target网络来计算),这样目标就确定了;每隔一段时间再同步一次目标网络,使其更加接近实际结果,从而可以更好地进行训练。" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": { 349 | "collapsed": false 350 | }, 351 | "source": [ 352 | "#### (2)Algorithm\n", 353 | "* `Algorithm` 定义了具体的算法来更新前向网络(`Model`),也就是通过定义损失函数来更新`Model`,和算法相关的计算都放在`algorithm`中。\n", 354 | "\n" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": 5, 360 | "metadata": { 361 | "collapsed": false 362 | }, 363 | "outputs": [], 364 | "source": [ 365 | "# from parl.algorithms import DQN # 也可以直接从parl库中导入DQN算法\n", 366 | "\n", 367 | "class DQN(parl.Algorithm):\n", 368 | " def __init__(self, model, act_dim=None, gamma=None, lr=None):\n", 369 | " \"\"\" DQN algorithm\n", 370 | " \n", 371 | " Args:\n", 372 | " model 
(parl.Model): 定义Q函数的前向网络结构\n", 373 | " act_dim (int): action空间的维度,即有几个action\n", 374 | " gamma (float): reward的衰减因子\n", 375 | " lr (float): learning rate 学习率.\n", 376 | " \"\"\"\n", 377 | " self.model = model\n", 378 | " self.target_model = copy.deepcopy(model)\n", 379 | "\n", 380 | " assert isinstance(act_dim, int)\n", 381 | " assert isinstance(gamma, float)\n", 382 | " assert isinstance(lr, float)\n", 383 | " self.act_dim = act_dim\n", 384 | " self.gamma = gamma\n", 385 | " self.lr = lr\n", 386 | "\n", 387 | " def predict(self, obs):\n", 388 | " \"\"\" 使用self.model的value网络来获取 [Q(s,a1),Q(s,a2),...]\n", 389 | " \"\"\"\n", 390 | " return self.model.value(obs)\n", 391 | "\n", 392 | " def learn(self, obs, action, reward, next_obs, terminal):\n", 393 | " \"\"\" 使用DQN算法更新self.model的value网络\n", 394 | " \"\"\"\n", 395 | " # 从target_model中获取 max Q' 的值,用于计算target_Q\n", 396 | " next_pred_value = self.target_model.value(next_obs)\n", 397 | " best_v = layers.reduce_max(next_pred_value, dim=1)\n", 398 | " best_v.stop_gradient = True # 阻止梯度传递\n", 399 | " terminal = layers.cast(terminal, dtype='float32')\n", 400 | " target = reward + (1.0 - terminal) * self.gamma * best_v\n", 401 | "\n", 402 | " pred_value = self.model.value(obs) # 获取Q预测值\n", 403 | " # 将action转onehot向量,比如:3 => [0,0,0,1,0]\n", 404 | " action_onehot = layers.one_hot(action, self.act_dim)\n", 405 | " action_onehot = layers.cast(action_onehot, dtype='float32')\n", 406 | " # 下面一行是逐元素相乘,拿到action对应的 Q(s,a)\n", 407 | " # 比如:pred_value = [[2.3, 5.7, 1.2, 3.9, 1.4]], action_onehot = [[0,0,0,1,0]]\n", 408 | " # ==> pred_action_value = [[3.9]]\n", 409 | " pred_action_value = layers.reduce_sum(\n", 410 | " layers.elementwise_mul(action_onehot, pred_value), dim=1)\n", 411 | "\n", 412 | " # 计算 Q(s,a) 与 target_Q的均方差,得到loss\n", 413 | " cost = layers.square_error_cost(pred_action_value, target)\n", 414 | " cost = layers.reduce_mean(cost)\n", 415 | " optimizer = fluid.optimizer.Adam(learning_rate=self.lr) # 使用Adam优化器\n", 416 | " optimizer.minimize(cost)\n", 417 | " return cost\n", 418 | "\n", 419 | " def sync_target(self):\n", 420 | " \"\"\" 把 self.model 的模型参数值同步到 self.target_model\n", 421 | " \"\"\"\n", 422 | " self.model.sync_weights_to(self.target_model)\n" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": { 428 | "collapsed": false 429 | }, 430 | "source": [ 431 | "#### (3)Agent\n", 432 | "* `Agent` 负责算法与环境的交互,在交互过程中把生成的数据提供给`Algorithm`来更新模型(`Model`),数据的预处理流程也一般定义在这里。" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 6, 438 | "metadata": { 439 | "collapsed": false 440 | }, 441 | "outputs": [], 442 | "source": [ 443 | "class Agent(parl.Agent):\n", 444 | " def __init__(self,\n", 445 | " algorithm,\n", 446 | " obs_dim,\n", 447 | " act_dim,\n", 448 | " e_greed=0.1,\n", 449 | " e_greed_decrement=0):\n", 450 | " assert isinstance(obs_dim, int)\n", 451 | " assert isinstance(act_dim, int)\n", 452 | " self.obs_dim = obs_dim\n", 453 | " self.act_dim = act_dim\n", 454 | " super(Agent, self).__init__(algorithm)\n", 455 | "\n", 456 | " self.global_step = 0\n", 457 | " self.update_target_steps = 200 # 每隔200个training steps再把model的参数复制到target_model中\n", 458 | "\n", 459 | " self.e_greed = e_greed # 有一定概率随机选取动作,探索\n", 460 | " self.e_greed_decrement = e_greed_decrement # 随着训练逐步收敛,探索的程度慢慢降低\n", 461 | "\n", 462 | " def build_program(self):\n", 463 | " self.pred_program = fluid.Program()\n", 464 | " self.learn_program = fluid.Program()\n", 465 | "\n", 466 | " with fluid.program_guard(self.pred_program): # 搭建计算图用于 
预测动作,定义输入输出变量\n", 467 | " obs = layers.data(\n", 468 | " name='obs', shape=[self.obs_dim], dtype='float32')\n", 469 | " self.value = self.alg.predict(obs)\n", 470 | "\n", 471 | " with fluid.program_guard(self.learn_program): # 搭建计算图用于 更新Q网络,定义输入输出变量\n", 472 | " obs = layers.data(\n", 473 | " name='obs', shape=[self.obs_dim], dtype='float32')\n", 474 | " action = layers.data(name='act', shape=[1], dtype='int32')\n", 475 | " reward = layers.data(name='reward', shape=[], dtype='float32')\n", 476 | " next_obs = layers.data(\n", 477 | " name='next_obs', shape=[self.obs_dim], dtype='float32')\n", 478 | " terminal = layers.data(name='terminal', shape=[], dtype='bool')\n", 479 | " self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)\n", 480 | "\n", 481 | " def sample(self, obs):\n", 482 | " sample = np.random.rand() # 产生0~1之间的小数\n", 483 | " if sample < self.e_greed:\n", 484 | " act = np.random.randint(self.act_dim) # 探索:每个动作都有概率被选择\n", 485 | " else:\n", 486 | " act = self.predict(obs) # 选择最优动作\n", 487 | " self.e_greed = max(\n", 488 | " 0.01, self.e_greed - self.e_greed_decrement) # 随着训练逐步收敛,探索的程度慢慢降低\n", 489 | " return act\n", 490 | "\n", 491 | " def predict(self, obs): # 选择最优动作\n", 492 | " obs = np.expand_dims(obs, axis=0)\n", 493 | " pred_Q = self.fluid_executor.run(\n", 494 | " self.pred_program,\n", 495 | " feed={'obs': obs.astype('float32')},\n", 496 | " fetch_list=[self.value])[0]\n", 497 | " pred_Q = np.squeeze(pred_Q, axis=0)\n", 498 | " act = np.argmax(pred_Q) # 选择Q最大的下标,即对应的动作\n", 499 | " return act\n", 500 | "\n", 501 | " def learn(self, obs, act, reward, next_obs, terminal):\n", 502 | " # 每隔200个training steps同步一次model和target_model的参数\n", 503 | " if self.global_step % self.update_target_steps == 0:\n", 504 | " self.alg.sync_target()\n", 505 | " self.global_step += 1\n", 506 | "\n", 507 | " act = np.expand_dims(act, -1)\n", 508 | " feed = {\n", 509 | " 'obs': obs.astype('float32'),\n", 510 | " 'act': act.astype('int32'),\n", 511 | " 'reward': reward,\n", 512 | " 'next_obs': next_obs.astype('float32'),\n", 513 | " 'terminal': terminal\n", 514 | " }\n", 515 | " cost = self.fluid_executor.run(\n", 516 | " self.learn_program, feed=feed, fetch_list=[self.cost])[0] # 训练一次网络\n", 517 | " return cost" 518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": { 523 | "collapsed": false 524 | }, 525 | "source": [ 526 | "### Step5 ReplayMemory\n", 527 | "* 经验池:用于存储多条经验,实现 经验回放。\n", 528 | "![](https://ai-studio-static-online.cdn.bcebos.com/340817e194974c24ac63dd569fba336cd500d120729e4354825fb9c3501108ab)\n" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": 7, 534 | "metadata": { 535 | "collapsed": false 536 | }, 537 | "outputs": [], 538 | "source": [ 539 | "import random\n", 540 | "import collections\n", 541 | "import numpy as np\n", 542 | "\n", 543 | "\n", 544 | "class ReplayMemory(object):\n", 545 | " def __init__(self, max_size):\n", 546 | " self.buffer = collections.deque(maxlen=max_size)\n", 547 | "\n", 548 | " # 增加一条经验到经验池中\n", 549 | " def append(self, exp):\n", 550 | " self.buffer.append(exp)\n", 551 | "\n", 552 | " # 从经验池中选取N条经验出来\n", 553 | " def sample(self, batch_size):\n", 554 | " mini_batch = random.sample(self.buffer, batch_size)\n", 555 | " obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []\n", 556 | "\n", 557 | " for experience in mini_batch:\n", 558 | " s, a, r, s_p, done = experience\n", 559 | " obs_batch.append(s)\n", 560 | " action_batch.append(a)\n", 561 | " reward_batch.append(r)\n", 
562 | " next_obs_batch.append(s_p)\n", 563 | " done_batch.append(done)\n", 564 | "\n", 565 | " return np.array(obs_batch).astype('float32'), \\\n", 566 | " np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'),\\\n", 567 | " np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32')\n", 568 | "\n", 569 | " def __len__(self):\n", 570 | " return len(self.buffer)\n" 571 | ] 572 | }, 573 | { 574 | "cell_type": "markdown", 575 | "metadata": { 576 | "collapsed": false 577 | }, 578 | "source": [ 579 | "### Step6 Training && Test(训练&&测试)\n", 580 | "![](https://ai-studio-static-online.cdn.bcebos.com/3b9f168ee09a46dc9962decb0d2a60f7d35196d459824b9f923b163a0b4bbe4e)\n", 581 | "![](https://ai-studio-static-online.cdn.bcebos.com/4ffa39f252744a3791c04c5d8381824d4f7447a592d4409da658045061bcd0b9)\n", 582 | "* 训练和评估的一个模块" 583 | ] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": 8, 588 | "metadata": { 589 | "collapsed": false 590 | }, 591 | "outputs": [], 592 | "source": [ 593 | "# 训练一个episode\n", 594 | "def run_episode(env, agent, rpm):\n", 595 | " total_reward = 0\n", 596 | " obs = env.reset()\n", 597 | " step = 0\n", 598 | " while True:\n", 599 | " step += 1\n", 600 | " action = agent.sample(obs) # 采样动作,所有动作都有概率被尝试到\n", 601 | " next_obs, reward, done, _ = env.step(action)\n", 602 | " rpm.append((obs, action, reward, next_obs, done))\n", 603 | "\n", 604 | " # train model\n", 605 | " if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):\n", 606 | " (batch_obs, batch_action, batch_reward, batch_next_obs,\n", 607 | " batch_done) = rpm.sample(BATCH_SIZE)\n", 608 | " train_loss = agent.learn(batch_obs, batch_action, batch_reward,\n", 609 | " batch_next_obs,\n", 610 | " batch_done) # s,a,r,s',done\n", 611 | "\n", 612 | " total_reward += reward\n", 613 | " obs = next_obs\n", 614 | " if done:\n", 615 | " break\n", 616 | " return total_reward\n", 617 | "\n", 618 | "\n", 619 | "# 评估 agent, 跑 5 个episode,总reward求平均\n", 620 | "def evaluate(env, agent, render=False):\n", 621 | " eval_reward = []\n", 622 | " for i in range(5):\n", 623 | " obs = env.reset()\n", 624 | " episode_reward = 0\n", 625 | " while True:\n", 626 | " action = agent.predict(obs) # 预测动作,只选最优动作\n", 627 | " obs, reward, done, _ = env.step(action)\n", 628 | " episode_reward += reward\n", 629 | " if render:\n", 630 | " env.render()\n", 631 | " if done:\n", 632 | " break\n", 633 | " eval_reward.append(episode_reward)\n", 634 | " return np.mean(eval_reward)\n" 635 | ] 636 | }, 637 | { 638 | "cell_type": "markdown", 639 | "metadata": { 640 | "collapsed": false 641 | }, 642 | "source": [ 643 | "### Step7 创建环境和Agent,创建经验池,启动训练,保存模型" 644 | ] 645 | }, 646 | { 647 | "cell_type": "markdown", 648 | "metadata": { 649 | "collapsed": false 650 | }, 651 | "source": [ 652 | "# 主函数\n", 653 | "* 讲不清,理还乱,看一波图解,冷静冷静!!!\n", 654 | "![](https://ai-studio-static-online.cdn.bcebos.com/dab5d244b98c41108a84e37b79db0009109ed32d4d734ca7b04b68e2e301da09)\n", 655 | "![](https://ai-studio-static-online.cdn.bcebos.com/a297e56a6de045169e4a04391c22aa80bf7cc9d9608446fd97bef4dcd16431c8)\n", 656 | "![](https://ai-studio-static-online.cdn.bcebos.com/8bc3b4cd1913479480baa286b120f45b49cb8ac7c36645df8423d873a7bd6d07)\n" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 9, 662 | "metadata": { 663 | "collapsed": false 664 | }, 665 | "outputs": [ 666 | { 667 | "name": "stdout", 668 | "output_type": "stream", 669 | "text": [ 670 | "\u001b[32m[07-27 18:19:17 MainThread 
@machine_info.py:88]\u001b[0m Cannot find available GPU devices, using CPU now.\n", 671 | "\u001b[32m[07-27 18:19:17 MainThread @machine_info.py:88]\u001b[0m Cannot find available GPU devices, using CPU now.\n", 672 | "\u001b[32m[07-27 18:19:18 MainThread @machine_info.py:88]\u001b[0m Cannot find available GPU devices, using CPU now.\n", 673 | "\u001b[32m[07-27 18:19:19 MainThread @:38]\u001b[0m episode:50 e_greed:0.09930599999999931 test_reward:11.0\n", 674 | "\u001b[32m[07-27 18:19:21 MainThread @:38]\u001b[0m episode:100 e_greed:0.09877699999999878 test_reward:9.4\n", 675 | "\u001b[32m[07-27 18:19:23 MainThread @:38]\u001b[0m episode:150 e_greed:0.09828699999999829 test_reward:9.4\n", 676 | "\u001b[32m[07-27 18:19:24 MainThread @:38]\u001b[0m episode:200 e_greed:0.09781299999999782 test_reward:9.2\n", 677 | "\u001b[32m[07-27 18:19:26 MainThread @:38]\u001b[0m episode:250 e_greed:0.09733199999999734 test_reward:9.2\n", 678 | "\u001b[32m[07-27 18:19:27 MainThread @:38]\u001b[0m episode:300 e_greed:0.09683499999999684 test_reward:9.6\n", 679 | "\u001b[32m[07-27 18:19:29 MainThread @:38]\u001b[0m episode:350 e_greed:0.09634199999999635 test_reward:9.6\n", 680 | "\u001b[32m[07-27 18:19:30 MainThread @:38]\u001b[0m episode:400 e_greed:0.09583999999999585 test_reward:9.6\n", 681 | "\u001b[32m[07-27 18:19:32 MainThread @:38]\u001b[0m episode:450 e_greed:0.09533799999999534 test_reward:11.0\n", 682 | "\u001b[32m[07-27 18:19:34 MainThread @:38]\u001b[0m episode:500 e_greed:0.09476499999999477 test_reward:9.4\n", 683 | "\u001b[32m[07-27 18:19:37 MainThread @:38]\u001b[0m episode:550 e_greed:0.09395199999999396 test_reward:44.4\n", 684 | "\u001b[32m[07-27 18:19:43 MainThread @:38]\u001b[0m episode:600 e_greed:0.0921909999999922 test_reward:21.2\n", 685 | "\u001b[32m[07-27 18:20:05 MainThread @:38]\u001b[0m episode:650 e_greed:0.08648299999998649 test_reward:186.0\n", 686 | "\u001b[32m[07-27 18:20:39 MainThread @:38]\u001b[0m episode:700 e_greed:0.07723099999997723 test_reward:199.0\n", 687 | "\u001b[32m[07-27 18:21:12 MainThread @:38]\u001b[0m episode:750 e_greed:0.06832399999996833 test_reward:188.6\n", 688 | "\u001b[32m[07-27 18:21:42 MainThread @:38]\u001b[0m episode:800 e_greed:0.06003099999996003 test_reward:124.6\n", 689 | "\u001b[32m[07-27 18:22:09 MainThread @:38]\u001b[0m episode:850 e_greed:0.05293599999995294 test_reward:165.6\n", 690 | "\u001b[32m[07-27 18:22:31 MainThread @:38]\u001b[0m episode:900 e_greed:0.04666999999994667 test_reward:110.2\n", 691 | "\u001b[32m[07-27 18:22:56 MainThread @:38]\u001b[0m episode:950 e_greed:0.040123999999940124 test_reward:111.0\n", 692 | "\u001b[32m[07-27 18:23:21 MainThread @:38]\u001b[0m episode:1000 e_greed:0.03343099999993343 test_reward:133.6\n", 693 | "\u001b[32m[07-27 18:23:47 MainThread @:38]\u001b[0m episode:1050 e_greed:0.026627999999926627 test_reward:130.0\n", 694 | "\u001b[32m[07-27 18:24:08 MainThread @:38]\u001b[0m episode:1100 e_greed:0.02101999999992102 test_reward:130.0\n", 695 | "\u001b[32m[07-27 18:24:31 MainThread @:38]\u001b[0m episode:1150 e_greed:0.015293999999915868 test_reward:179.4\n", 696 | "\u001b[32m[07-27 18:25:03 MainThread @:38]\u001b[0m episode:1200 e_greed:0.01 test_reward:183.0\n", 697 | "\u001b[32m[07-27 18:25:33 MainThread @:38]\u001b[0m episode:1250 e_greed:0.01 test_reward:130.8\n", 698 | "\u001b[32m[07-27 18:25:59 MainThread @:38]\u001b[0m episode:1300 e_greed:0.01 test_reward:172.8\n", 699 | "\u001b[32m[07-27 18:26:25 MainThread @:38]\u001b[0m episode:1350 e_greed:0.01 test_reward:152.2\n", 700 | 
"\u001b[32m[07-27 18:26:49 MainThread @:38]\u001b[0m episode:1400 e_greed:0.01 test_reward:127.8\n", 701 | "\u001b[32m[07-27 18:27:19 MainThread @:38]\u001b[0m episode:1450 e_greed:0.01 test_reward:182.0\n", 702 | "\u001b[32m[07-27 18:27:52 MainThread @:38]\u001b[0m episode:1500 e_greed:0.01 test_reward:120.6\n", 703 | "\u001b[32m[07-27 18:28:25 MainThread @:38]\u001b[0m episode:1550 e_greed:0.01 test_reward:133.6\n", 704 | "\u001b[32m[07-27 18:28:45 MainThread @:38]\u001b[0m episode:1600 e_greed:0.01 test_reward:14.0\n", 705 | "\u001b[32m[07-27 18:28:49 MainThread @:38]\u001b[0m episode:1650 e_greed:0.01 test_reward:15.8\n", 706 | "\u001b[32m[07-27 18:29:09 MainThread @:38]\u001b[0m episode:1700 e_greed:0.01 test_reward:164.2\n", 707 | "\u001b[32m[07-27 18:29:47 MainThread @:38]\u001b[0m episode:1750 e_greed:0.01 test_reward:200.0\n", 708 | "\u001b[32m[07-27 18:30:26 MainThread @:38]\u001b[0m episode:1800 e_greed:0.01 test_reward:200.0\n", 709 | "\u001b[32m[07-27 18:31:03 MainThread @:38]\u001b[0m episode:1850 e_greed:0.01 test_reward:182.4\n", 710 | "\u001b[32m[07-27 18:31:40 MainThread @:38]\u001b[0m episode:1900 e_greed:0.01 test_reward:193.8\n", 711 | "\u001b[32m[07-27 18:32:17 MainThread @:38]\u001b[0m episode:1950 e_greed:0.01 test_reward:193.4\n", 712 | "\u001b[32m[07-27 18:32:53 MainThread @:38]\u001b[0m episode:2000 e_greed:0.01 test_reward:170.2\n" 713 | ] 714 | } 715 | ], 716 | "source": [ 717 | "env = gym.make('CartPole-v0') # CartPole-v0: 预期最后一次评估总分 > 180(最大值是200)\n", 718 | "action_dim = env.action_space.n # CartPole-v0: 2\n", 719 | "obs_shape = env.observation_space.shape # CartPole-v0: (4,)\n", 720 | "\n", 721 | "rpm = ReplayMemory(MEMORY_SIZE) # DQN的经验回放池\n", 722 | "\n", 723 | "# 根据parl框架构建agent\n", 724 | "model = Model(act_dim=action_dim)\n", 725 | "algorithm = DQN(model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)\n", 726 | "agent = Agent(\n", 727 | " algorithm,\n", 728 | " obs_dim=obs_shape[0],\n", 729 | " act_dim=action_dim,\n", 730 | " e_greed=0.1, # 有一定概率随机选取动作,探索\n", 731 | " e_greed_decrement=1e-6) # 随着训练逐步收敛,探索的程度慢慢降低\n", 732 | "\n", 733 | "# 加载模型\n", 734 | "# save_path = './dqn_model.ckpt'\n", 735 | "# agent.restore(save_path)\n", 736 | "\n", 737 | "# 先往经验池里存一些数据,避免最开始训练的时候样本丰富度不够\n", 738 | "while len(rpm) < MEMORY_WARMUP_SIZE:\n", 739 | " run_episode(env, agent, rpm)\n", 740 | "\n", 741 | "max_episode = 2000\n", 742 | "\n", 743 | "# 开始训练\n", 744 | "episode = 0\n", 745 | "while episode < max_episode: # 训练max_episode个回合,test部分不计算入episode数量\n", 746 | " # train part\n", 747 | " for i in range(0, 50):\n", 748 | " total_reward = run_episode(env, agent, rpm)\n", 749 | " episode += 1\n", 750 | "\n", 751 | " # test part\n", 752 | " eval_reward = evaluate(env, agent, render=False) # render=True 查看显示效果\n", 753 | " logger.info('episode:{} e_greed:{} test_reward:{}'.format(\n", 754 | " episode, agent.e_greed, eval_reward))\n", 755 | "\n", 756 | "# 训练结束,保存模型\n", 757 | "save_path = './dqn_model.ckpt'\n", 758 | "agent.save(save_path)" 759 | ] 760 | }, 761 | { 762 | "cell_type": "markdown", 763 | "metadata": { 764 | "collapsed": false 765 | }, 766 | "source": [ 767 | "![](https://ai-studio-static-online.cdn.bcebos.com/94a7460d1da44164a7805c5680a2779689b1dc4c78494104aa3d2bda075411fa)\n" 768 | ] 769 | }, 770 | { 771 | "cell_type": "markdown", 772 | "metadata": { 773 | "collapsed": false 774 | }, 775 | "source": [ 776 | "# 运行经历\n", 777 | "这一串代码是从课件上copy下来的,所以里面的参数及数据都是已经调整好了的,但是在线下跑数据发现每一次跑也并非到最后都是200也就是每一次的结果都是不一定的,然后我开了显示模式,里面的测试画面和结果会被显示和打印,前面的速度会比后面的快,因为分数低时间自然就短了。\n", 
778 | "更据这段时间对深度学习的学习,基本上代码运行没有问题后那么调整几个超参基本上可以得到一个比较好的拟合效果。\n" 779 | ] 780 | }, 781 | { 782 | "cell_type": "markdown", 783 | "metadata": { 784 | "collapsed": false 785 | }, 786 | "source": [ 787 | "# 心得体会\n", 788 | "本次“爬山”,发现之前的一些盲点在本次有了解决,,之前对深度学习、PARL、算法等都有了最新的认识,虽然还是那个小白但是认识提上去了,以后还是可以更加努力的去奋斗的,这就是传说中的回头看的时候就知道自己以前是多么的无知了。" 789 | ] 790 | }, 791 | { 792 | "cell_type": "markdown", 793 | "metadata": { 794 | "collapsed": false 795 | }, 796 | "source": [ 797 | "# 这里是三岁,请大家多多指教啊!" 798 | ] 799 | } 800 | ], 801 | "metadata": { 802 | "kernelspec": { 803 | "display_name": "PaddlePaddle 1.8.0 (Python 3.5)", 804 | "language": "python", 805 | "name": "py35-paddle1.2.0" 806 | }, 807 | "language_info": { 808 | "codemirror_mode": { 809 | "name": "ipython", 810 | "version": 3 811 | }, 812 | "file_extension": ".py", 813 | "mimetype": "text/x-python", 814 | "name": "python", 815 | "nbconvert_exporter": "python", 816 | "pygments_lexer": "ipython3", 817 | "version": "3.7.4" 818 | } 819 | }, 820 | "nbformat": 4, 821 | "nbformat_minor": 1 822 | } 823 | --------------------------------------------------------------------------------