├── .gitignore ├── LICENSE ├── README.md ├── image ├── CartPole-v0.gif ├── CartPole-v0_reward_curve.png ├── CartPole-v1.gif ├── CartPole-v1_reward_curve.png ├── LunarLander-v2.gif ├── LunarLander-v2_reward_curve.png ├── MountainCar-v0.gif └── MountainCar-v0_reward_curve.png └── src ├── PPO.ipynb ├── PPO.py ├── PPO2.ipynb ├── PPO2.py ├── categorical_dqn.ipynb ├── categorical_dueling_ddqn.ipynb ├── core.ipynb ├── core.py ├── dqn.ipynb ├── dueling_double_dqn.ipynb ├── grpo_discrete_episodic.ipynb ├── ppo_discrete_episodic.ipynb ├── ppo_discrete_episodic2.ipynb ├── ppo_discrete_step.ipynb ├── ppo_discrete_step.py ├── ppo_discrete_step_parallel.py ├── ppo_discrete_step_test.ipynb ├── ppo_discrete_step_test.py ├── rand_net_distill ├── dueling_ddqn_with_rnd.ipynb └── ppo_step_with_rnd.ipynb ├── running_mean_std.py ├── soft_actor_critic.ipynb └── test ├── running_mean_std.py ├── saved_models ├── CartPole-v0_ep179_clear_model_ppo_st.pt ├── CartPole-v0_ep87_clear_model_ppo_st.pt ├── CartPole-v1_ep108_clear_model_ppo_st.pt ├── CartPole-v1_ep112_clear_model_dddqn.pt ├── CartPole-v1_ep1150_clear_model_ppo_st.pt ├── CartPole-v1_ep118_clear_model_ppo_st.pt ├── CartPole-v1_ep1338_clear_model_ppo_st.pt ├── CartPole-v1_ep134_clear_model_ppo_st.pt ├── CartPole-v1_ep164_clear_model_ppo_st.pt ├── CartPole-v1_ep168_clear_model_dddqn.pt ├── CartPole-v1_ep1715_clear_model_ppo_st.pt ├── CartPole-v1_ep213_clear_model_ppo_st.pt ├── CartPole-v1_ep216_clear_model_ppo_st.pt ├── CartPole-v1_ep222_clear_model_ppo_st.pt ├── CartPole-v1_ep223_clear_model_ppo_st.pt ├── CartPole-v1_ep258_clear_model_ppo_st.pt ├── CartPole-v1_ep270_clear_model_ppo_st.pt ├── CartPole-v1_ep273_clear_model_ppo_st.pt ├── CartPole-v1_ep294_clear_model_ppo_st.pt ├── CartPole-v1_ep295_clear_model_ppo_st.pt ├── CartPole-v1_ep296_clear_model_ppo_st.pt ├── CartPole-v1_ep308_clear_model_ppo_st.pt ├── CartPole-v1_ep313_clear_model_ppo_st.pt ├── CartPole-v1_ep320_clear_model_ppo_st.pt ├── CartPole-v1_ep326_clear_model_ppo_st.pt ├── CartPole-v1_ep329_clear_model_ppo_st.pt ├── CartPole-v1_ep357_clear_model_ppo_st.pt ├── CartPole-v1_ep376_clear_model_ppo_st.pt ├── CartPole-v1_ep385_clear_model_ppo_st.pt ├── CartPole-v1_ep388_clear_model_ppo_st.pt ├── CartPole-v1_ep408_clear_model_ppo_st.pt ├── CartPole-v1_ep417_clear_model_ppo_st.pt ├── CartPole-v1_ep418_clear_model_ppo_st.pt ├── CartPole-v1_ep420_clear_model_ppo_st.pt ├── CartPole-v1_ep424_clear_model_ppo_st.pt ├── CartPole-v1_ep433_clear_model_ppo_st.pt ├── CartPole-v1_ep436_clear_model_ppo_st.pt ├── CartPole-v1_ep445_clear_model_ppo_st.pt ├── CartPole-v1_ep454_clear_model_ppo_st.pt ├── CartPole-v1_ep456_clear_model_ppo_st.pt ├── CartPole-v1_ep458_clear_model_ppo_st.pt ├── CartPole-v1_ep463_clear_model_ppo_st.pt ├── CartPole-v1_ep470_clear_model_ppo_st.pt ├── CartPole-v1_ep474_clear_model_ppo_st.pt ├── CartPole-v1_ep475_clear_model_ppo_st.pt ├── CartPole-v1_ep492_clear_model_ppo_st.pt ├── CartPole-v1_ep498_clear_model_ppo_st.pt ├── CartPole-v1_ep504_clear_model_ppo_st.pt ├── CartPole-v1_ep506_clear_model_ppo_st.pt ├── CartPole-v1_ep509_clear_model_ppo_st.pt ├── CartPole-v1_ep516_clear_model_ppo_st.pt ├── CartPole-v1_ep519_clear_model_ppo_st.pt ├── CartPole-v1_ep530_clear_model_ppo_st.pt ├── CartPole-v1_ep534_clear_model_ppo_st.pt ├── CartPole-v1_ep546_clear_model_ppo_st.pt ├── CartPole-v1_ep552_clear_model_ppo_st.pt ├── CartPole-v1_ep555_clear_model_ppo_st.pt ├── CartPole-v1_ep559_clear_model_ppo_st.pt ├── CartPole-v1_ep566_clear_model_ppo_st.pt ├── CartPole-v1_ep569_clear_model_ppo_st.pt ├── 
CartPole-v1_ep586_clear_model_ppo_st.pt ├── CartPole-v1_ep589_clear_model_ppo_st.pt ├── CartPole-v1_ep595_clear_model_ppo_st.pt ├── CartPole-v1_ep631_clear_model_ppo_st.pt ├── CartPole-v1_ep635_clear_model_ppo_st.pt ├── CartPole-v1_ep637_clear_model_ppo_st.pt ├── CartPole-v1_ep642_clear_model_ppo_st.pt ├── CartPole-v1_ep675_clear_model_ppo_st.pt ├── CartPole-v1_ep690_clear_model_ppo_st.pt ├── CartPole-v1_ep727_clear_model_ppo_st.pt ├── CartPole-v1_ep741_clear_model_ppo_st.pt ├── CartPole-v1_ep763_clear_model_ppo_st.pt ├── CartPole-v1_ep840_clear_model_ppo_st.pt ├── CartPole-v1_ep845_clear_model_ppo_st.pt ├── CartPole-v1_ep866_clear_model_ppo_st.pt ├── CartPole-v1_ep871_clear_model_ppo_st.pt ├── CartPole-v1_ep892_clear_model_ppo_st.pt ├── CartPole-v1_ep903_clear_model_ppo_st.pt ├── CartPole-v1_ep924_clear_model_ppo_st.pt ├── CartPole-v1_ep939_clear_model_ppo_st.pt ├── CartPole-v1_up50_clear_model_ppo_st.pt ├── CartPole-v1_up50_clear_norm_obs.pkl ├── LunarLander-v2_ep260_clear_model_dqn.pt ├── LunarLander-v2_ep370_clear_model_dddqn.pt ├── LunarLander-v2_ep8461_clear_model_ppo_st.pt ├── LunarLander-v2_ep876_clear_model_ppo_st.pt ├── LunarLander-v2_up1099_clear_model_ppo_st.pt ├── LunarLander-v2_up1317_clear_model_ppo_st.pt ├── LunarLander-v2_up254_clear_model_ppo_st.pt ├── LunarLander-v2_up570_clear_model_ppo_st.pt ├── LunarLander-v2_up75_clear_model_ppo_st.pt ├── MountainCar-v0_ep304_clear_model_dqn.pt ├── MountainCar-v0_ep532_clear_model_dddqn.pt ├── MountainCar-v0_ep984_clear_model_dddqn.pt ├── MountainCar-v0_up1441_clear_model_ppo_st.pt └── MountainCar-v0_up1441_clear_norm_obs.pkl ├── test_dqn.ipynb ├── test_dueling_double_dqn.ipynb ├── test_ppo.ipynb └── test_ppo_with_rnd.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # VS Code 107 | .vscode 108 | 109 | # tensorboard 110 | runs 111 | 112 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jungdae Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Deep RL Algorithms in PyTorch 2 | 3 | ### Models 4 | - DQN 5 | - Dueling Double DQN 6 | - Categorical DQN (C51) 7 | - Categorical Dueling Double DQN 8 | - Proximal Policy Optimization (PPO) 9 | + discrete (episodic, n-step) 10 | - Group Relative Policy Optimization (GRPO) 11 | 12 |
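All of the PPO variants above optimize the clipped surrogate objective. A minimal sketch of the policy loss, mirroring `compute_loss_pi` in `src/PPO.py` (tensor names here are illustrative, not taken from the code):

```python
import torch

def clipped_policy_loss(logp_new, logp_old, adv, clip_ratio=0.2):
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space.
    ratio = torch.exp(logp_new - logp_old)
    # Clipped surrogate: keep the pessimistic (minimum) of the two branches.
    clipped_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
    return -torch.min(ratio * adv, clipped_adv).mean()
```

`logp_old` comes from the rollout (behavior) policy and carries no gradient; only `logp_new` is differentiated, so the clamp bounds how far a single update can move the policy.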
13 | 14 | ### Exploration 15 | - Random Network Distillation (RND) 16 |
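RND, used in the `rand_net_distill` notebooks, adds an intrinsic reward equal to the prediction error of a trained predictor network against a fixed, randomly initialized target network: rarely visited states produce large errors and therefore a large exploration bonus. A minimal sketch of the idea (not the exact architecture used in the notebooks, whose contents are not shown here):

```python
import torch
import torch.nn as nn

class RNDBonus(nn.Module):
    def __init__(self, obs_dim, feat_dim=64, hidden=64):
        super().__init__()
        # Fixed, randomly initialized target network; its weights are never trained.
        self.target = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, feat_dim))
        for p in self.target.parameters():
            p.requires_grad = False
        # Predictor network trained to match the target's features.
        self.predictor = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, feat_dim))

    def forward(self, obs):
        # Per-state squared prediction error; doubles as the intrinsic reward
        # and as the predictor's training loss.
        return (self.predictor(obs) - self.target(obs)).pow(2).mean(dim=-1)
```

Because the same error is minimized while training the predictor, frequently visited states stop yielding a bonus over time.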
17 | 18 | ### Experiments 19 | Results from runs that met each environment's built-in "solving" criterion (reward threshold). 20 | - **Dueling Double DQN** 21 |   + Only one hyperparameter ("UP_COEF") was adjusted. 22 | ###### CartPole-v0 23 |
24 | ![CartPole-v0](image/CartPole-v0.gif) ![CartPole-v0 reward curve](image/CartPole-v0_reward_curve.png) 25 |
26 | 27 | ###### CartPole-v1 28 |
29 | ![CartPole-v1](image/CartPole-v1.gif) ![CartPole-v1 reward curve](image/CartPole-v1_reward_curve.png) 30 |
31 | 32 | ###### MountainCar-v0 33 |
34 | ![MountainCar-v0](image/MountainCar-v0.gif) ![MountainCar-v0 reward curve](image/MountainCar-v0_reward_curve.png) 35 |
36 | 37 | ###### LunarLander-v2 38 |
39 | ![LunarLander-v2](image/LunarLander-v2.gif) ![LunarLander-v2 reward curve](image/LunarLander-v2_reward_curve.png) 40 |
41 |
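"Solved" in these experiments means the running average of recent episode returns reached `env.spec.reward_threshold`; that is the stopping condition used by the training loops in `src/` (some of which also save a checkpoint to `test/saved_models/` at that point). Roughly, with the averaging window varying between scripts (10 to 100 episodes):

```python
import numpy as np

def is_solved(episode_returns, env, n_eval=20):
    # Mean of the most recent n_eval episode returns compared against the
    # environment's built-in threshold (e.g. 475.0 for CartPole-v1).
    if len(episode_returns) < n_eval:
        return False
    return np.mean(episode_returns[-n_eval:]) >= env.spec.reward_threshold
```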
42 | 43 | ### TODO 44 | - Proximal Policy Optimization (PPO) 45 | + continuous 46 | -------------------------------------------------------------------------------- /image/CartPole-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/CartPole-v0.gif -------------------------------------------------------------------------------- /image/CartPole-v0_reward_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/CartPole-v0_reward_curve.png -------------------------------------------------------------------------------- /image/CartPole-v1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/CartPole-v1.gif -------------------------------------------------------------------------------- /image/CartPole-v1_reward_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/CartPole-v1_reward_curve.png -------------------------------------------------------------------------------- /image/LunarLander-v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/LunarLander-v2.gif -------------------------------------------------------------------------------- /image/LunarLander-v2_reward_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/LunarLander-v2_reward_curve.png -------------------------------------------------------------------------------- /image/MountainCar-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/MountainCar-v0.gif -------------------------------------------------------------------------------- /image/MountainCar-v0_reward_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/MountainCar-v0_reward_curve.png -------------------------------------------------------------------------------- /src/PPO.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import random\n", 10 | "import numpy as np\n", 11 | "import torch\n", 12 | "from torch.optim import Adam, AdamW\n", 13 | "import gym\n", 14 | "import time\n", 15 | "import core\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "from IPython.display import clear_output\n", 18 | "from running_mean_std import RunningMeanStd\n", 19 | "\n", 20 | "\n", 21 | "class PPOBuffer(object):\n", 22 | " def __init__(self, obs_dim, act_dim, size, gamma=0.999, lam=0.97):\n", 23 | " self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)\n", 24 | " self.act_buf = 
np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)\n", 25 | " self.adv_buf = np.zeros(size, dtype=np.float32)\n", 26 | " self.rew_buf = np.zeros(size, dtype=np.float32)\n", 27 | " self.ret_buf = np.zeros(size, dtype=np.float32)\n", 28 | " self.val_buf = np.zeros(size, dtype=np.float32)\n", 29 | " self.logp_buf = np.zeros(size, dtype=np.float32)\n", 30 | " self.gamma, self.lam = gamma, lam\n", 31 | " self.ptr, self.path_start_idx, self.max_size = 0, 0, size\n", 32 | "\n", 33 | " def store(self, obs, act, rew, val, logp):\n", 34 | " assert self.ptr < self.max_size\n", 35 | " self.obs_buf[self.ptr] = obs\n", 36 | " self.act_buf[self.ptr] = act\n", 37 | " self.rew_buf[self.ptr] = rew\n", 38 | " self.val_buf[self.ptr] = val\n", 39 | " self.logp_buf[self.ptr] = logp\n", 40 | " self.ptr += 1\n", 41 | "\n", 42 | " def finish_path(self, last_val=0):\n", 43 | " path_slice = slice(self.path_start_idx, self.ptr)\n", 44 | " rews = np.append(self.rew_buf[path_slice], last_val)\n", 45 | " vals = np.append(self.val_buf[path_slice], last_val)\n", 46 | " \n", 47 | " deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]\n", 48 | " self.adv_buf[path_slice] = core.discount_cumsum(deltas, self.gamma * self.lam)\n", 49 | " self.ret_buf[path_slice] = core.discount_cumsum(rews, self.gamma)[:-1]\n", 50 | " self.path_start_idx = self.ptr\n", 51 | "\n", 52 | " def get(self):\n", 53 | " assert self.ptr == self.max_size\n", 54 | " self.ptr, self.path_start_idx = 0, 0\n", 55 | " adv_mean = np.mean(self.adv_buf)\n", 56 | " adv_std = np.std(self.adv_buf)\n", 57 | " self.adv_buf = (self.adv_buf - adv_mean) / adv_std\n", 58 | " data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf,\n", 59 | " adv=self.adv_buf, logp=self.logp_buf)\n", 60 | " return {k: torch.as_tensor(v, dtype=torch.float32) for k,v in data.items()}" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "def plot(ep_ret_buf, eval_ret_buf, loss_buf):\n", 70 | " clear_output(True)\n", 71 | " plt.figure(figsize=(16, 5))\n", 72 | " plt.subplot(131)\n", 73 | " plt.plot(ep_ret_buf, alpha=0.5)\n", 74 | " plt.subplot(131)\n", 75 | " plt.plot(eval_ret_buf)\n", 76 | " plt.title(f\"Reward: {eval_ret_buf[-1]:.0f}\")\n", 77 | " plt.subplot(132)\n", 78 | " plt.plot(loss_buf['pi'], alpha=0.5)\n", 79 | " plt.title(f\"Pi_Loss: {np.mean(loss_buf['pi'][:-20:]):.3f}\")\n", 80 | " plt.subplot(133)\n", 81 | " plt.plot(loss_buf['vf'], alpha=0.5)\n", 82 | " plt.title(f\"Vf_Loss: {np.mean(loss_buf['vf'][-20:]):.2f}\")\n", 83 | " plt.show()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "def compute_loss_pi(data, ac, clip_ratio):\n", 93 | " obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']\n", 94 | "\n", 95 | " # Policy loss\n", 96 | " pi, logp = ac.pi(obs, act)\n", 97 | " ratio = torch.exp(logp - logp_old)\n", 98 | " clip_adv = torch.clamp(ratio, 1-clip_ratio, 1+clip_ratio) * adv\n", 99 | " loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()\n", 100 | "\n", 101 | " # Useful extra info\n", 102 | "# approx_kl = (logp_old - logp).mean().item()\n", 103 | " kl_div = ((logp.exp() * (logp - logp_old)).mean()).detach().item()\n", 104 | " ent = pi.entropy().mean().detach().item()\n", 105 | " clipped = ratio.gt(1+clip_ratio) | ratio.lt(1-clip_ratio)\n", 106 | " clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().detach().item()\n", 107 | " pi_info = 
dict(kl=kl_div, ent=ent, cf=clipfrac)\n", 108 | "\n", 109 | " return loss_pi, pi_info\n", 110 | "\n", 111 | "def compute_loss_v(data, ac):\n", 112 | " obs, ret = data['obs'], data['ret']\n", 113 | " return ((ac.v(obs) - ret)**2).mean()\n", 114 | "\n", 115 | "\n", 116 | "def update(buf, train_pi_iters, train_vf_iters, clip_ratio, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf):\n", 117 | " data = buf.get()\n", 118 | "\n", 119 | " # Train policy with multiple steps of gradient descent\n", 120 | " for i in range(train_pi_iters):\n", 121 | " pi_optimizer.zero_grad()\n", 122 | " loss_pi, pi_info = compute_loss_pi(data, ac, clip_ratio)\n", 123 | " loss_buf['pi'].append(loss_pi.item())\n", 124 | " kl = pi_info['kl']\n", 125 | " if kl > 1.5 * target_kl:\n", 126 | " print('Early stopping at step %d due to reaching max kl.'%i)\n", 127 | " break\n", 128 | " loss_pi.backward()\n", 129 | " pi_optimizer.step()\n", 130 | "\n", 131 | " # Value function learning\n", 132 | " for i in range(train_vf_iters):\n", 133 | " vf_optimizer.zero_grad()\n", 134 | " loss_vf = compute_loss_v(data, ac)\n", 135 | " loss_buf['vf'].append(loss_vf.item())\n", 136 | " loss_vf.backward()\n", 137 | " vf_optimizer.step()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "def main():\n", 147 | " actor_critic=core.MLPActorCritic\n", 148 | " hidden_size = 64\n", 149 | " activation = torch.nn.Tanh\n", 150 | " seed = 5\n", 151 | " steps_per_epoch = 4096\n", 152 | " epochs = 1000\n", 153 | " gamma = 0.99\n", 154 | " lam = 0.97\n", 155 | " clip_ratio = 0.2\n", 156 | " pi_lr = 3e-4\n", 157 | " vf_lr = 1e-3\n", 158 | " train_pi_iters = 80\n", 159 | " train_vf_iters = 80\n", 160 | " max_ep_len = 1000\n", 161 | " target_kl = 0.01\n", 162 | " save_freq = 10\n", 163 | " obs_norm = True\n", 164 | " view_curve = True\n", 165 | "\n", 166 | " # make an environment\n", 167 | "# env = gym.make('CartPole-v0')\n", 168 | "# env = gym.make('CartPole-v1')\n", 169 | "# env = gym.make('MountainCar-v0')\n", 170 | "# env = gym.make('LunarLander-v2')\n", 171 | " env = gym.make('BipedalWalker-v3')\n", 172 | " print(f\"reward_threshold: {env.spec.reward_threshold}\")\n", 173 | "\n", 174 | " obs_dim = env.observation_space.shape\n", 175 | " act_dim = env.action_space.shape\n", 176 | "\n", 177 | " # Random seed\n", 178 | " env.seed(seed)\n", 179 | " random.seed(seed)\n", 180 | " torch.manual_seed(seed)\n", 181 | " np.random.seed(seed)\n", 182 | "\n", 183 | " # Create actor-critic module\n", 184 | " ac = actor_critic(env.observation_space, env.action_space, (hidden_size, hidden_size), activation)\n", 185 | " \n", 186 | " # Set up optimizers for policy and value function\n", 187 | " pi_optimizer = AdamW(ac.pi.parameters(), lr=pi_lr, eps=1e-6)\n", 188 | " vf_optimizer = AdamW(ac.v.parameters(), lr=vf_lr, eps=1e-6)\n", 189 | "\n", 190 | " # Count variables\n", 191 | " var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])\n", 192 | "\n", 193 | " # Set up experience buffer\n", 194 | " local_steps_per_epoch = int(steps_per_epoch)\n", 195 | " buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)\n", 196 | " \n", 197 | " # Prepare for interaction with environment\n", 198 | " start_time = time.time()\n", 199 | " o, ep_ret, ep_len = env.reset(), 0, 0\n", 200 | " ep_num = 0\n", 201 | " ep_ret_buf, eval_ret_buf = [], []\n", 202 | " loss_buf = {'pi': [], 'vf': []}\n", 203 | " obs_normalizer = 
RunningMeanStd(shape=env.observation_space.shape)\n", 204 | " # Main loop: collect experience in env and update/log each epoch\n", 205 | " for epoch in range(epochs):\n", 206 | " for t in range(local_steps_per_epoch):\n", 207 | " if obs_norm:\n", 208 | " obs_normalizer.update(np.array([o]))\n", 209 | " o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10)\n", 210 | " a, v, logp = ac.step(torch.as_tensor(o_norm, dtype=torch.float32))\n", 211 | " else:\n", 212 | " a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))\n", 213 | "\n", 214 | " next_o, r, d, _ = env.step(a)\n", 215 | " ep_ret += r\n", 216 | " ep_len += 1\n", 217 | "\n", 218 | " # save and log\n", 219 | " if obs_norm:\n", 220 | " buf.store(o_norm, a, r, v, logp)\n", 221 | " else:\n", 222 | " buf.store(o, a, r, v, logp)\n", 223 | "\n", 224 | " # Update obs\n", 225 | " o = next_o\n", 226 | "\n", 227 | " timeout = ep_len == max_ep_len\n", 228 | " terminal = d or timeout\n", 229 | " epoch_ended = t==local_steps_per_epoch-1\n", 230 | "\n", 231 | " if terminal or epoch_ended:\n", 232 | " if timeout or epoch_ended:\n", 233 | " if obs_norm:\n", 234 | " obs_normalizer.update(np.array([o]))\n", 235 | " o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10)\n", 236 | " _, v, _ = ac.step(torch.as_tensor(o_norm, dtype=torch.float32))\n", 237 | " else:\n", 238 | " _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))\n", 239 | " else:\n", 240 | " if obs_norm:\n", 241 | " obs_normalizer.update(np.array([o]))\n", 242 | " v = 0\n", 243 | " buf.finish_path(v)\n", 244 | " if terminal:\n", 245 | " ep_ret_buf.append(ep_ret)\n", 246 | " eval_ret_buf.append(np.mean(ep_ret_buf[-20:]))\n", 247 | " ep_num += 1\n", 248 | " if view_curve:\n", 249 | " plot(ep_ret_buf, eval_ret_buf, loss_buf)\n", 250 | " else:\n", 251 | " print(f'Episode: {ep_num:3} Reward: {ep_reward:3}')\n", 252 | " if eval_ret_buf[-1] >= env.spec.reward_threshold:\n", 253 | " print(f\"\\n{env.spec.id} is sloved! 
{ep_num} Episode\")\n", 254 | " return\n", 255 | "\n", 256 | " o, ep_ret, ep_len = env.reset(), 0, 0\n", 257 | " # Perform PPO update!\n", 258 | " update(buf, train_pi_iters, train_vf_iters, clip_ratio, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "main()" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [] 276 | } 277 | ], 278 | "metadata": { 279 | "kernelspec": { 280 | "display_name": "Python 3", 281 | "language": "python", 282 | "name": "python3" 283 | }, 284 | "language_info": { 285 | "codemirror_mode": { 286 | "name": "ipython", 287 | "version": 3 288 | }, 289 | "file_extension": ".py", 290 | "mimetype": "text/x-python", 291 | "name": "python", 292 | "nbconvert_exporter": "python", 293 | "pygments_lexer": "ipython3", 294 | "version": "3.6.10" 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 4 299 | } 300 | -------------------------------------------------------------------------------- /src/PPO.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # %% 4 | 5 | # %% 6 | 7 | import pickle 8 | import random 9 | 10 | import core 11 | import gym 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import torch 15 | from running_mean_std import RunningMeanStd 16 | from torch.optim import AdamW 17 | 18 | 19 | class PPOBuffer(object): 20 | def __init__(self, obs_dim, act_dim, size, gamma=0.999, lam=0.97): 21 | self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32) 22 | self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32) 23 | self.adv_buf = np.zeros(size, dtype=np.float32) 24 | self.rew_buf = np.zeros(size, dtype=np.float32) 25 | self.ret_buf = np.zeros(size, dtype=np.float32) 26 | self.val_buf = np.zeros(size, dtype=np.float32) 27 | self.logp_buf = np.zeros(size, dtype=np.float32) 28 | self.gamma, self.lam = gamma, lam 29 | self.ptr, self.path_start_idx, self.max_size = 0, 0, size 30 | 31 | def store(self, obs, act, rew, val, logp): 32 | assert self.ptr < self.max_size 33 | self.obs_buf[self.ptr] = obs 34 | self.act_buf[self.ptr] = act 35 | self.rew_buf[self.ptr] = rew 36 | self.val_buf[self.ptr] = val 37 | self.logp_buf[self.ptr] = logp 38 | self.ptr += 1 39 | 40 | def finish_path(self, last_val=0): 41 | path_slice = slice(self.path_start_idx, self.ptr) 42 | rews = np.append(self.rew_buf[path_slice], last_val) 43 | vals = np.append(self.val_buf[path_slice], last_val) 44 | deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1] 45 | self.adv_buf[path_slice] = core.discount_cumsum(deltas, self.gamma * self.lam) 46 | self.ret_buf[path_slice] = core.discount_cumsum(rews, self.gamma)[:-1] 47 | self.path_start_idx = self.ptr 48 | 49 | def get(self): 50 | assert self.ptr == self.max_size 51 | self.ptr, self.path_start_idx = 0, 0 52 | adv_mean = np.mean(self.adv_buf) 53 | adv_std = np.std(self.adv_buf) 54 | self.adv_buf = (self.adv_buf - adv_mean) / adv_std 55 | data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf, adv=self.adv_buf, logp=self.logp_buf) 56 | return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()} 57 | 58 | 59 | # %% 60 | 61 | 62 | def plot(ep_ret_buf, eval_ret_buf, loss_buf): 63 | plt.figure(figsize=(16, 5)) 64 | plt.subplot(131) 65 | 
plt.plot(ep_ret_buf, alpha=0.5) 66 | plt.subplot(131) 67 | plt.plot(eval_ret_buf) 68 | plt.title(f"Reward: {eval_ret_buf[-1]:.0f}") 69 | plt.subplot(132) 70 | plt.plot(loss_buf["pi"], alpha=0.5) 71 | plt.title(f"Pi_Loss: {np.mean(loss_buf['pi'][:-20:]):.3f}") 72 | plt.subplot(133) 73 | plt.plot(loss_buf["vf"], alpha=0.5) 74 | plt.title(f"Vf_Loss: {np.mean(loss_buf['vf'][-20:]):.2f}") 75 | plt.show() 76 | 77 | 78 | # %% 79 | 80 | 81 | def compute_loss_pi(data, ac, clip_ratio): 82 | obs, act, adv, logp_old = data["obs"], data["act"], data["adv"], data["logp"] 83 | 84 | # Policy loss 85 | pi, logp = ac.pi(obs, act) 86 | ratio = torch.exp(logp - logp_old) 87 | clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv 88 | loss_pi = -(torch.min(ratio * adv, clip_adv)).mean() 89 | 90 | # Useful extra info 91 | # approx_kl = (logp_old - logp).mean().item() 92 | kl_div = ((logp.exp() * (logp - logp_old)).mean()).detach().item() 93 | ent = pi.entropy().mean().detach().item() 94 | clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) 95 | clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().detach().item() 96 | pi_info = dict(kl=kl_div, ent=ent, cf=clipfrac) 97 | 98 | return loss_pi, pi_info 99 | 100 | 101 | def compute_loss_v(data, ac): 102 | obs, ret = data["obs"], data["ret"] 103 | return ((ac.v(obs) - ret) ** 2).mean() 104 | 105 | 106 | def update(buf, train_pi_iters, train_vf_iters, clip_ratio, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf): 107 | data = buf.get() 108 | 109 | # Train policy with multiple steps of gradient descent 110 | for i in range(train_pi_iters): 111 | pi_optimizer.zero_grad() 112 | loss_pi, pi_info = compute_loss_pi(data, ac, clip_ratio) 113 | loss_buf["pi"].append(loss_pi.item()) 114 | kl = pi_info["kl"] 115 | if kl > 1.5 * target_kl: 116 | print("Early stopping at step %d due to reaching max kl." 
% i) 117 | break 118 | loss_pi.backward() 119 | pi_optimizer.step() 120 | 121 | # Value function learning 122 | for i in range(train_vf_iters): 123 | vf_optimizer.zero_grad() 124 | loss_vf = compute_loss_v(data, ac) 125 | loss_buf["vf"].append(loss_vf.item()) 126 | loss_vf.backward() 127 | vf_optimizer.step() 128 | 129 | 130 | # %% 131 | 132 | 133 | def main(): 134 | actor_critic = core.MLPActorCritic 135 | hidden_size = 64 136 | activation = torch.nn.Tanh 137 | seed = 5 138 | steps_per_epoch = 2048 139 | epochs = 1000 140 | gamma = 0.99 141 | lam = 0.97 142 | clip_ratio = 0.2 143 | pi_lr = 3e-4 144 | vf_lr = 1e-3 145 | train_pi_iters = 80 146 | train_vf_iters = 80 147 | max_ep_len = 1000 148 | target_kl = 0.01 149 | save_freq = 10 150 | obs_norm = True 151 | view_curve = False 152 | 153 | # make an environment 154 | # env = gym.make('CartPole-v0') 155 | # env = gym.make('CartPole-v1') 156 | # env = gym.make('MountainCar-v0') 157 | # env = gym.make('LunarLander-v2') 158 | env = gym.make("BipedalWalker-v3") 159 | print(f"reward_threshold: {env.spec.reward_threshold}") 160 | 161 | obs_dim = env.observation_space.shape 162 | act_dim = env.action_space.shape 163 | 164 | # Random seed 165 | env.seed(seed) 166 | random.seed(seed) 167 | torch.manual_seed(seed) 168 | np.random.seed(seed) 169 | 170 | # Create actor-critic module 171 | ac = actor_critic(env.observation_space, env.action_space, (hidden_size, hidden_size), activation) 172 | 173 | # Set up optimizers for policy and value function 174 | pi_optimizer = AdamW(ac.pi.parameters(), lr=pi_lr, eps=1e-6) 175 | vf_optimizer = AdamW(ac.v.parameters(), lr=vf_lr, eps=1e-6) 176 | 177 | # Set up experience buffer 178 | local_steps_per_epoch = int(steps_per_epoch) 179 | buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam) 180 | 181 | # Prepare for interaction with environment 182 | o, ep_ret, ep_len = env.reset(), 0, 0 183 | ep_num = 0 184 | ep_ret_buf, eval_ret_buf = [], [] 185 | loss_buf = {"pi": [], "vf": []} 186 | obs_normalizer = RunningMeanStd(shape=env.observation_space.shape) 187 | # Main loop: collect experience in env and update/log each epoch 188 | for epoch in range(epochs): 189 | for t in range(local_steps_per_epoch): 190 | env.render() 191 | if obs_norm: 192 | obs_normalizer.update(np.array([o])) 193 | o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10) 194 | a, v, logp = ac.step(torch.as_tensor(o_norm, dtype=torch.float32)) 195 | else: 196 | a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32)) 197 | 198 | next_o, r, d, _ = env.step(a) 199 | ep_ret += r 200 | ep_len += 1 201 | 202 | # save and log 203 | if obs_norm: 204 | buf.store(o_norm, a, r, v, logp) 205 | else: 206 | buf.store(o, a, r, v, logp) 207 | 208 | # Update obs 209 | o = next_o 210 | 211 | timeout = ep_len == max_ep_len 212 | terminal = d or timeout 213 | epoch_ended = t == local_steps_per_epoch - 1 214 | 215 | if terminal or epoch_ended: 216 | if timeout or epoch_ended: 217 | if obs_norm: 218 | obs_normalizer.update(np.array([o])) 219 | o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10) 220 | _, v, _ = ac.step(torch.as_tensor(o_norm, dtype=torch.float32)) 221 | else: 222 | _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32)) 223 | else: 224 | if obs_norm: 225 | obs_normalizer.update(np.array([o])) 226 | v = 0 227 | buf.finish_path(v) 228 | if terminal: 229 | ep_ret_buf.append(ep_ret) 230 | eval_ret_buf.append(np.mean(ep_ret_buf[-20:])) 231 | ep_num += 1 232 | if view_curve: 233 | 
plot(ep_ret_buf, eval_ret_buf, loss_buf) 234 | else: 235 | print(f"Episode: {ep_num:3}\tReward: {ep_ret:3}") 236 | if eval_ret_buf[-1] >= env.spec.reward_threshold: 237 | print(f"\n{env.spec.id} is sloved! {ep_num} Episode") 238 | torch.save(ac.state_dict(), f"./test/saved_models/{env.spec.id}_ep{ep_num}_clear_model_ppo.pt") 239 | with open(f"./test/saved_models/{env.spec.id}_ep{ep_num}_clear_norm_obs.pkl", "wb") as f: 240 | pickle.dump(obs_normalizer, f, pickle.HIGHEST_PROTOCOL) 241 | return 242 | 243 | o, ep_ret, ep_len = env.reset(), 0, 0 244 | # Perform PPO update! 245 | update(buf, train_pi_iters, train_vf_iters, clip_ratio, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf) 246 | 247 | 248 | # %% 249 | 250 | 251 | main() 252 | -------------------------------------------------------------------------------- /src/PPO2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # %% 4 | import random 5 | import time 6 | 7 | import core 8 | import gym 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | from IPython.display import clear_output 13 | from running_mean_std import RunningMeanStd 14 | from torch.optim import Adam, AdamW 15 | 16 | 17 | class PPOBuffer(object): 18 | def __init__(self, obs_dim, act_dim, size, gamma=0.999, lam=0.97): 19 | self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32) 20 | self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32) 21 | self.adv_buf = np.zeros(size, dtype=np.float32) 22 | self.rew_buf = np.zeros(size, dtype=np.float32) 23 | self.ret_buf = np.zeros(size, dtype=np.float32) 24 | self.val_buf = np.zeros(size, dtype=np.float32) 25 | self.logp_buf = np.zeros(size, dtype=np.float32) 26 | self.gamma, self.lam = gamma, lam 27 | self.ptr, self.path_start_idx, self.max_size = 0, 0, size 28 | 29 | def store(self, obs, act, rew, val, logp): 30 | assert self.ptr < self.max_size 31 | self.obs_buf[self.ptr] = obs 32 | self.act_buf[self.ptr] = act 33 | self.rew_buf[self.ptr] = rew 34 | self.val_buf[self.ptr] = val 35 | self.logp_buf[self.ptr] = logp 36 | self.ptr += 1 37 | 38 | def finish_path(self, last_val=0): 39 | path_slice = slice(self.path_start_idx, self.ptr) 40 | rews = np.append(self.rew_buf[path_slice], last_val) 41 | vals = np.append(self.val_buf[path_slice], last_val) 42 | 43 | deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1] 44 | self.adv_buf[path_slice] = core.discount_cumsum(deltas, self.gamma * self.lam) 45 | self.ret_buf[path_slice] = core.discount_cumsum(rews, self.gamma)[:-1] 46 | self.path_start_idx = self.ptr 47 | 48 | def get(self): 49 | assert self.ptr == self.max_size 50 | self.ptr, self.path_start_idx = 0, 0 51 | adv_mean = np.mean(self.adv_buf) 52 | adv_std = np.std(self.adv_buf) 53 | self.adv_buf = (self.adv_buf - adv_mean) / adv_std 54 | data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf, adv=self.adv_buf, logp=self.logp_buf) 55 | return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()} 56 | 57 | 58 | def plot(ep_ret_buf, eval_ret_buf, loss_buf): 59 | clear_output(True) 60 | plt.figure(figsize=(16, 5)) 61 | plt.subplot(131) 62 | plt.plot(ep_ret_buf, alpha=0.5) 63 | plt.subplot(131) 64 | plt.plot(eval_ret_buf) 65 | plt.title(f"Reward: {eval_ret_buf[-1]:.0f}") 66 | plt.subplot(132) 67 | plt.plot(loss_buf["pi"], alpha=0.5) 68 | plt.title(f"Pi_Loss: {np.mean(loss_buf['pi'][:-10:]):.3f}") 69 | plt.subplot(133) 70 | plt.plot(loss_buf["vf"], 
alpha=0.5) 71 | plt.title(f"Vf_Loss: {np.mean(loss_buf['vf'][-10:]):.2f}") 72 | plt.show() 73 | 74 | 75 | def compute_loss_pi(data, ac, beta): 76 | obs, act, adv, logp_old = data["obs"], data["act"], data["adv"], data["logp"] 77 | 78 | # Policy loss 79 | pi, logp = ac.pi(obs, act) 80 | ratio = torch.exp(logp - logp_old) 81 | kl_div = logp.exp() * (logp - logp_old) 82 | loss_pi = -(ratio * adv - beta * kl_div).mean() 83 | 84 | # Useful extra info 85 | kl = ((logp.exp() * (logp - logp_old)).mean()).detach().item() 86 | ent = pi.entropy().mean().detach().item() 87 | pi_info = dict(kl=kl, ent=ent) 88 | return loss_pi, pi_info 89 | 90 | 91 | def compute_loss_v(data, ac): 92 | obs, ret = data["obs"], data["ret"] 93 | return ((ac.v(obs) - ret) ** 2).mean() 94 | 95 | 96 | def update(buf, train_pi_iters, train_vf_iters, beta, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf): 97 | data = buf.get() 98 | 99 | # Train policy with multiple steps of gradient descent 100 | for i in range(train_pi_iters): 101 | pi_optimizer.zero_grad() 102 | loss_pi, pi_info = compute_loss_pi(data, ac, beta) 103 | loss_buf["pi"].append(loss_pi.item()) 104 | loss_pi.backward() 105 | pi_optimizer.step() 106 | 107 | # Value function learning 108 | for i in range(train_vf_iters): 109 | vf_optimizer.zero_grad() 110 | loss_vf = compute_loss_v(data, ac) 111 | loss_buf["vf"].append(loss_vf.item()) 112 | loss_vf.backward() 113 | vf_optimizer.step() 114 | 115 | 116 | def main(): 117 | actor_critic = core.MLPActorCritic 118 | hidden_size = 64 119 | activation = torch.nn.Tanh 120 | seed = 5 121 | steps_per_epoch = 2048 122 | epochs = 1000 123 | gamma = 0.99 124 | lam = 0.97 125 | beta = 3.0 126 | pi_lr = 3e-4 127 | vf_lr = 1e-3 128 | train_pi_iters = 80 129 | train_vf_iters = 80 130 | max_ep_len = 1000 131 | target_kl = 0.01 132 | save_freq = 10 133 | obs_norm = True 134 | view_curve = False 135 | 136 | # make an environment 137 | # env = gym.make('CartPole-v0') 138 | # env = gym.make('CartPole-v1') 139 | # env = gym.make('MountainCar-v0') 140 | # env = gym.make('LunarLander-v2') 141 | env = gym.make("BipedalWalker-v3") 142 | print(f"reward_threshold: {env.spec.reward_threshold}") 143 | 144 | obs_dim = env.observation_space.shape 145 | act_dim = env.action_space.shape 146 | 147 | # Random seed 148 | env.seed(seed) 149 | random.seed(seed) 150 | torch.manual_seed(seed) 151 | np.random.seed(seed) 152 | 153 | # Create actor-critic module 154 | ac = actor_critic(env.observation_space, env.action_space, (hidden_size, hidden_size), activation) 155 | 156 | # Set up optimizers for policy and value function 157 | pi_optimizer = AdamW(ac.pi.parameters(), lr=pi_lr, eps=1e-6) 158 | vf_optimizer = AdamW(ac.v.parameters(), lr=vf_lr, eps=1e-6) 159 | 160 | # Count variables 161 | var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v]) 162 | 163 | # Set up experience buffer 164 | local_steps_per_epoch = int(steps_per_epoch) 165 | buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam) 166 | 167 | # Prepare for interaction with environment 168 | start_time = time.time() 169 | o, ep_ret, ep_len = env.reset(), 0, 0 170 | ep_num = 0 171 | ep_ret_buf, eval_ret_buf = [], [] 172 | loss_buf = {"pi": [], "vf": []} 173 | obs_normalizer = RunningMeanStd(shape=env.observation_space.shape) 174 | # Main loop: collect experience in env and update/log each epoch 175 | for epoch in range(epochs): 176 | for t in range(local_steps_per_epoch): 177 | env.render() 178 | if obs_norm: 179 | obs_normalizer.update(np.array([o])) 180 | o_norm = 
np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10) 181 | a, v, logp = ac.step(torch.as_tensor(o_norm, dtype=torch.float32)) 182 | else: 183 | a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32)) 184 | 185 | next_o, r, d, _ = env.step(a) 186 | ep_ret += r 187 | ep_len += 1 188 | 189 | # save and log 190 | if obs_norm: 191 | buf.store(o_norm, a, r, v, logp) 192 | else: 193 | buf.store(o, a, r, v, logp) 194 | 195 | # Update obs 196 | o = next_o 197 | 198 | timeout = ep_len == max_ep_len 199 | terminal = d or timeout 200 | epoch_ended = t == local_steps_per_epoch - 1 201 | 202 | if terminal or epoch_ended: 203 | if timeout or epoch_ended: 204 | if obs_norm: 205 | obs_normalizer.update(np.array([o])) 206 | o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10) 207 | _, v, _ = ac.step(torch.as_tensor(o_norm, dtype=torch.float32)) 208 | else: 209 | _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32)) 210 | else: 211 | if obs_norm: 212 | obs_normalizer.update(np.array([o])) 213 | v = 0 214 | buf.finish_path(v) 215 | if terminal: 216 | ep_ret_buf.append(ep_ret) 217 | eval_ret_buf.append(np.mean(ep_ret_buf[-100:])) 218 | ep_num += 1 219 | if view_curve: 220 | plot(ep_ret_buf, eval_ret_buf, loss_buf) 221 | else: 222 | print(f"Episode: {ep_num:3} Reward: {ep_ret:3}") 223 | if eval_ret_buf[-1] >= env.spec.reward_threshold: 224 | print(f"\n{env.spec.id} is sloved! {ep_num} Episode") 225 | return 226 | 227 | o, ep_ret, ep_len = env.reset(), 0, 0 228 | # Perform PPO update! 229 | update(buf, train_pi_iters, train_vf_iters, beta, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf) 230 | 231 | 232 | if __name__ == "__main__": 233 | main() 234 | 235 | 236 | # %% 237 | # %% 238 | -------------------------------------------------------------------------------- /src/core.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import scipy.signal\n", 11 | "from gym.spaces import Box, Discrete\n", 12 | "\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "from torch.distributions.normal import Normal\n", 16 | "from torch.distributions.categorical import Categorical\n", 17 | "\n", 18 | "\n", 19 | "def combined_shape(length, shape=None):\n", 20 | " if shape is None:\n", 21 | " return (length,)\n", 22 | " return (length, shape) if np.isscalar(shape) else (length, *shape)\n", 23 | "\n", 24 | "\n", 25 | "def mlp(sizes, activation, output_activation=nn.Identity):\n", 26 | " layers = []\n", 27 | " for j in range(len(sizes)-1):\n", 28 | " act = activation if j < len(sizes)-2 else output_activation\n", 29 | " layers += [nn.Linear(sizes[j], sizes[j+1]), act()]\n", 30 | " return nn.Sequential(*layers)\n", 31 | "\n", 32 | "\n", 33 | "def count_vars(module):\n", 34 | " return sum([np.prod(p.shape) for p in module.parameters()])\n", 35 | "\n", 36 | "\n", 37 | "def discount_cumsum(x, discount):\n", 38 | " \"\"\"\n", 39 | " magic from rllab for computing discounted cumulative sums of vectors.\n", 40 | " input: \n", 41 | " vector x, \n", 42 | " [x0, \n", 43 | " x1, \n", 44 | " x2]\n", 45 | " output:\n", 46 | " [x0 + discount * x1 + discount^2 * x2, \n", 47 | " x1 + discount * x2,\n", 48 | " x2]\n", 49 | " \"\"\"\n", 50 | " return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]\n", 51 | "\n", 52 | "\n", 53 | "class 
Actor(nn.Module):\n", 54 | "\n", 55 | " def _distribution(self, obs):\n", 56 | " raise NotImplementedError\n", 57 | "\n", 58 | " def _log_prob_from_distribution(self, pi, act):\n", 59 | " raise NotImplementedError\n", 60 | "\n", 61 | " def forward(self, obs, act=None):\n", 62 | " # Produce action distributions for given observations, and \n", 63 | " # optionally compute the log likelihood of given actions under\n", 64 | " # those distributions.\n", 65 | " pi = self._distribution(obs)\n", 66 | " logp_a = None\n", 67 | " if act is not None:\n", 68 | " logp_a = self._log_prob_from_distribution(pi, act)\n", 69 | " return pi, logp_a\n", 70 | "\n", 71 | "\n", 72 | "class MLPCategoricalActor(Actor):\n", 73 | " \n", 74 | " def __init__(self, obs_dim, act_dim, hidden_sizes, activation):\n", 75 | " super().__init__()\n", 76 | " self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)\n", 77 | "\n", 78 | " def _distribution(self, obs):\n", 79 | " logits = self.logits_net(obs)\n", 80 | " return Categorical(logits=logits)\n", 81 | "\n", 82 | " def _log_prob_from_distribution(self, pi, act):\n", 83 | " return pi.log_prob(act)\n", 84 | "\n", 85 | "\n", 86 | "class MLPGaussianActor(Actor):\n", 87 | "\n", 88 | " def __init__(self, obs_dim, act_dim, hidden_sizes, activation):\n", 89 | " super().__init__()\n", 90 | " log_std = -0.5 * np.ones(act_dim, dtype=np.float32)\n", 91 | " self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))\n", 92 | " self.mu_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)\n", 93 | "\n", 94 | " def _distribution(self, obs):\n", 95 | " mu = self.mu_net(obs)\n", 96 | " std = torch.exp(self.log_std)\n", 97 | " return Normal(mu, std)\n", 98 | "\n", 99 | " def _log_prob_from_distribution(self, pi, act):\n", 100 | " return pi.log_prob(act).sum(axis=-1) # Last axis sum needed for Torch Normal distribution\n", 101 | "\n", 102 | "\n", 103 | "class MLPCritic(nn.Module):\n", 104 | "\n", 105 | " def __init__(self, obs_dim, hidden_sizes, activation):\n", 106 | " super().__init__()\n", 107 | " self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation)\n", 108 | "\n", 109 | " def forward(self, obs):\n", 110 | " return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape.\n", 111 | "\n", 112 | "\n", 113 | "\n", 114 | "class MLPActorCritic(nn.Module):\n", 115 | "\n", 116 | "\n", 117 | " def __init__(self, observation_space, action_space, \n", 118 | " hidden_sizes=(64,64), activation=nn.Tanh):\n", 119 | " super().__init__()\n", 120 | "\n", 121 | " obs_dim = observation_space.shape[0]\n", 122 | "\n", 123 | " # policy builder depends on action space\n", 124 | " if isinstance(action_space, Box):\n", 125 | " self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation)\n", 126 | " elif isinstance(action_space, Discrete):\n", 127 | " self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)\n", 128 | "\n", 129 | " # build value function\n", 130 | " self.v = MLPCritic(obs_dim, hidden_sizes, activation)\n", 131 | "\n", 132 | " def step(self, obs):\n", 133 | " with torch.no_grad():\n", 134 | " pi = self.pi._distribution(obs)\n", 135 | " a = pi.sample()\n", 136 | " logp_a = self.pi._log_prob_from_distribution(pi, a)\n", 137 | " v = self.v(obs)\n", 138 | " return a.numpy(), v.numpy(), logp_a.numpy()\n", 139 | "\n", 140 | " def act(self, obs):\n", 141 | " return self.step(obs)[0]" 142 | ] 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "Python 3", 148 | 
"language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.6.10" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 4 166 | } 167 | -------------------------------------------------------------------------------- /src/core.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # %% 4 | import numpy as np 5 | import scipy.signal 6 | import torch 7 | import torch.nn as nn 8 | from gym.spaces import Box, Discrete 9 | from torch.distributions.categorical import Categorical 10 | from torch.distributions.normal import Normal 11 | 12 | 13 | def combined_shape(length, shape=None): 14 | if shape is None: 15 | return (length,) 16 | return (length, shape) if np.isscalar(shape) else (length, *shape) 17 | 18 | 19 | def mlp(sizes, activation, output_activation=nn.Identity): 20 | layers = [] 21 | for j in range(len(sizes) - 1): 22 | act = activation if j < len(sizes) - 2 else output_activation 23 | layers += [nn.Linear(sizes[j], sizes[j + 1]), act()] 24 | return nn.Sequential(*layers) 25 | 26 | 27 | def count_vars(module): 28 | return sum([np.prod(p.shape) for p in module.parameters()]) 29 | 30 | 31 | def discount_cumsum(x, discount): 32 | """ 33 | magic from rllab for computing discounted cumulative sums of vectors. 34 | input: 35 | vector x, 36 | [x0, 37 | x1, 38 | x2] 39 | output: 40 | [x0 + discount * x1 + discount^2 * x2, 41 | x1 + discount * x2, 42 | x2] 43 | """ 44 | return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1] 45 | 46 | 47 | class Actor(nn.Module): 48 | 49 | def _distribution(self, obs): 50 | raise NotImplementedError 51 | 52 | def _log_prob_from_distribution(self, pi, act): 53 | raise NotImplementedError 54 | 55 | def forward(self, obs, act=None): 56 | # Produce action distributions for given observations, and 57 | # optionally compute the log likelihood of given actions under 58 | # those distributions. 
59 | pi = self._distribution(obs) 60 | logp_a = None 61 | if act is not None: 62 | logp_a = self._log_prob_from_distribution(pi, act) 63 | return pi, logp_a 64 | 65 | 66 | class MLPCategoricalActor(Actor): 67 | 68 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation): 69 | super().__init__() 70 | self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation) 71 | 72 | def _distribution(self, obs): 73 | logits = self.logits_net(obs) 74 | return Categorical(logits=logits) 75 | 76 | def _log_prob_from_distribution(self, pi, act): 77 | return pi.log_prob(act) 78 | 79 | 80 | class MLPGaussianActor(Actor): 81 | 82 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation): 83 | super().__init__() 84 | log_std = -0.5 * np.ones(act_dim, dtype=np.float32) 85 | self.log_std = torch.nn.Parameter(torch.as_tensor(log_std)) 86 | self.mu_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation) 87 | 88 | def _distribution(self, obs): 89 | mu = self.mu_net(obs) 90 | std = torch.exp(self.log_std) 91 | return Normal(mu, std) 92 | 93 | def _log_prob_from_distribution(self, pi, act): 94 | return pi.log_prob(act).sum(axis=-1) # Last axis sum needed for Torch Normal distribution 95 | 96 | 97 | class MLPCritic(nn.Module): 98 | 99 | def __init__(self, obs_dim, hidden_sizes, activation): 100 | super().__init__() 101 | self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation) 102 | 103 | def forward(self, obs): 104 | return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape. 105 | 106 | 107 | class MLPActorCritic(nn.Module): 108 | 109 | def __init__(self, observation_space, action_space, hidden_sizes=(64, 64), activation=nn.Tanh): 110 | super().__init__() 111 | 112 | obs_dim = observation_space.shape[0] 113 | 114 | # policy builder depends on action space 115 | if isinstance(action_space, Box): 116 | self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation) 117 | elif isinstance(action_space, Discrete): 118 | self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation) 119 | 120 | # build value function 121 | self.v = MLPCritic(obs_dim, hidden_sizes, activation) 122 | 123 | def step(self, obs): 124 | with torch.no_grad(): 125 | pi = self.pi._distribution(obs) 126 | a = pi.sample() 127 | logp_a = self.pi._log_prob_from_distribution(pi, a) 128 | v = self.v(obs) 129 | return a.numpy(), v.numpy(), logp_a.numpy() 130 | 131 | def act(self, obs): 132 | return self.step(obs)[0] 133 | -------------------------------------------------------------------------------- /src/ppo_discrete_step.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # %% 4 | import random 5 | from collections import deque 6 | from copy import deepcopy 7 | 8 | import gym 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from torch.distributions import Categorical 15 | from torch.utils.data import DataLoader 16 | from IPython.display import clear_output 17 | 18 | 19 | # %% 20 | SEED = 5 21 | BATCH_SIZE = 256 22 | LR = 0.01 23 | EPOCHS = 10 24 | CLIP = 0.2 25 | GAMMA = 0.999 26 | LAMBDA = 0.98 27 | ENT_COEF = 0.01 28 | V_COEF = 0.5 29 | V_CLIP = True 30 | LIN_REDUCE = False 31 | GRAD_NORM = False 32 | # set device 33 | use_cuda = torch.cuda.is_available() 34 | print('cuda:', use_cuda) 35 | device = torch.device('cuda' if use_cuda else 'cpu') 36 | 37 | # random seed 
38 | random.seed(SEED) 39 | np.random.seed(SEED) 40 | torch.manual_seed(SEED) 41 | if use_cuda: 42 | torch.cuda.manual_seed_all(SEED) 43 | 44 | 45 | # %% 46 | class ActorCriticNet(nn.Module): 47 | def __init__(self, obs_space, action_space): 48 | super().__init__() 49 | h = 32 50 | self.head = nn.Sequential( 51 | nn.Linear(obs_space, h), 52 | nn.Tanh() 53 | ) 54 | self.pol = nn.Sequential( 55 | nn.Linear(h, h), 56 | nn.Tanh(), 57 | nn.Linear(h, action_space) 58 | ) 59 | self.val = nn.Sequential( 60 | nn.Linear(h, h), 61 | nn.Tanh(), 62 | nn.Linear(h, 1) 63 | ) 64 | self.log_softmax = nn.LogSoftmax(dim=-1) 65 | 66 | def forward(self, x): 67 | out = self.head(x) 68 | logit = self.pol(out).reshape(out.shape[0], -1) 69 | log_p = self.log_softmax(logit) 70 | v = self.val(out).reshape(out.shape[0], 1) 71 | 72 | return log_p, v 73 | 74 | 75 | # %% 76 | losses = [] 77 | 78 | 79 | def learn(net, old_net, optimizer, train_memory): 80 | global CLIP, LR 81 | global total_epochs 82 | net.train() 83 | old_net.train() 84 | 85 | for epoch in range(EPOCHS): 86 | if LIN_REDUCE: 87 | lr = LR - (LR * epoch / total_epochs) 88 | clip = CLIP - (CLIP * epoch / total_epochs) 89 | else: 90 | lr = LR 91 | clip = CLIP 92 | 93 | for param_group in optimizer.param_groups: 94 | param_group['lr'] = lr 95 | 96 | dataloader = DataLoader( 97 | train_memory, 98 | shuffle=True, 99 | batch_size=BATCH_SIZE, 100 | pin_memory=use_cuda 101 | ) 102 | 103 | for (s, a, ret, adv) in dataloader: 104 | s_batch = s.to(device).float() 105 | a_batch = a.to(device).long() 106 | ret_batch = ret.to(device).float() 107 | ret_batch = (ret_batch - ret_batch.mean()) / (ret_batch.std() + 1e-6) 108 | adv_batch = adv.to(device).float() 109 | adv_batch = (adv_batch - adv_batch.mean()) / (adv_batch.std() + 1e-6) 110 | with torch.no_grad(): 111 | log_p_batch_old, v_batch_old = old_net(s_batch) 112 | log_p_acting_old = log_p_batch_old[range(BATCH_SIZE), a_batch] 113 | 114 | log_p_batch, v_batch = net(s_batch) 115 | log_p_acting = log_p_batch[range(BATCH_SIZE), a_batch] 116 | p_ratio = (log_p_acting - log_p_acting_old).exp() 117 | p_ratio_clip = torch.clamp(p_ratio, 1 - clip, 1 + clip) 118 | p_loss = torch.min(p_ratio * adv_batch, 119 | p_ratio_clip * adv_batch).mean() 120 | if V_CLIP: 121 | v_clip = v_batch_old + torch.clamp(v_batch - v_batch_old, -clip, clip) 122 | v_loss1 = (ret_batch - v_clip).pow(2) 123 | v_loss2 = (ret_batch - v_batch).pow(2) 124 | v_loss = 0.5 * torch.max(v_loss1, v_loss2).mean() 125 | else: 126 | v_loss = 0.5 * (ret_batch - v_batch).pow(2).mean() 127 | 128 | log_p, _ = net(s_batch) 129 | entropy = -(log_p.exp() * log_p).sum(dim=1).mean() 130 | 131 | # loss 132 | loss = -(p_loss - V_COEF * v_loss + ENT_COEF * entropy) 133 | losses.append(loss.item()) 134 | 135 | optimizer.zero_grad() 136 | loss.backward() 137 | if GRAD_NORM: 138 | nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5) 139 | optimizer.step() 140 | train_memory.clear() 141 | 142 | 143 | def get_action_and_value(obs, old_net): 144 | old_net.eval() 145 | with torch.no_grad(): 146 | state = torch.tensor([obs]).to(device).float() 147 | log_p, v = old_net(state) 148 | m = Categorical(log_p.exp()) 149 | action = m.sample() 150 | 151 | return action.item(), v.item() 152 | 153 | 154 | def compute_adv_with_gae(rewards, values, roll_memory): 155 | rew = np.array(rewards, 'float') 156 | val = np.array(values[:-1], 'float') 157 | _val = np.array(values[1:], 'float') 158 | delta = rew + GAMMA * _val - val 159 | dis_r = np.array([GAMMA**(i) * r for i, r in enumerate(rewards)], 
'float') 160 | gae_dt = np.array([(GAMMA * LAMBDA)**(i) * dt for i, 161 | dt in enumerate(delta.tolist())], 'float') 162 | for i, data in enumerate(roll_memory): 163 | data.append(sum(dis_r[i:] / GAMMA**(i))) 164 | data.append(sum(gae_dt[i:] / (GAMMA * LAMBDA)**(i))) 165 | 166 | rewards.clear() 167 | values.clear() 168 | 169 | return roll_memory 170 | 171 | 172 | def plot(): 173 | clear_output(True) 174 | plt.figure(figsize=(16, 5)) 175 | plt.subplot(121) 176 | plt.plot(ep_rewards, alpha=0.5) 177 | plt.subplot(121) 178 | plt.plot(reward_eval) 179 | plt.title(f'Reward: ' 180 | f'{reward_eval[-1]}') 181 | plt.subplot(122) 182 | plt.plot(losses, alpha=0.5) 183 | plt.title(f'Loss: ' 184 | f'{np.mean(list(reversed(losses))[: n_eval]).round(decimals=2)}') 185 | plt.show() 186 | 187 | 188 | # %% 189 | # make an environment 190 | # env = gym.make('CartPole-v0') 191 | env = gym.make('CartPole-v1') 192 | # env = gym.make('MountainCar-v0') 193 | # env = gym.make('LunarLander-v2') 194 | 195 | env.seed(SEED) 196 | obs_space = env.observation_space.shape[0] 197 | action_space = env.action_space.n 198 | 199 | # hyperparameter 200 | n_episodes = 100000 201 | roll_len = 2048 202 | total_epochs = roll_len // BATCH_SIZE 203 | n_eval = 10 204 | 205 | # global values 206 | steps = 0 207 | ep_rewards = [] 208 | reward_eval = [] 209 | is_rollout = False 210 | is_solved = False 211 | 212 | # make memories 213 | train_memory = [] 214 | roll_memory = [] 215 | rewards = [] 216 | values = [] 217 | 218 | # make nerual networks 219 | net = ActorCriticNet(obs_space, action_space).to(device) 220 | old_net = deepcopy(net) 221 | # no_decay = ['bias'] 222 | # grouped_parameters = [ 223 | # {'params': [p for n, p in net.named_parameters() if not any( 224 | # nd in n for nd in no_decay)], 'weight_decay': 0.0}, 225 | # {'params': [p for n, p in net.named_parameters() if any( 226 | # nd in n for nd in no_decay)], 'weight_decay': 0.0} 227 | # ] 228 | optimizer = torch.optim.AdamW(net.parameters(), lr=LR, eps=1e-6) 229 | 230 | # play! 231 | for i in range(1, n_episodes + 1): 232 | obs = env.reset() 233 | done = False 234 | ep_reward = 0 235 | while not done: 236 | env.render() 237 | action, value = get_action_and_value(obs, old_net) 238 | _obs, reward, done, _ = env.step(action) 239 | 240 | # store 241 | roll_memory.append([obs, action]) 242 | rewards.append(reward) 243 | values.append(value) 244 | 245 | obs = _obs 246 | steps += 1 247 | ep_reward += reward 248 | 249 | if done or steps % roll_len == 0: 250 | if done: 251 | _value = 0. 252 | else: 253 | _, _value = get_action_and_value(_obs, old_net) 254 | 255 | values.append(_value) 256 | train_memory.extend(compute_adv_with_gae( 257 | rewards, values, roll_memory)) 258 | roll_memory.clear() 259 | 260 | if steps % roll_len == 0: 261 | learn(net, old_net, optimizer, train_memory) 262 | old_net.load_state_dict(net.state_dict()) 263 | 264 | if done: 265 | ep_rewards.append(ep_reward) 266 | reward_eval.append( 267 | np.mean(list(reversed(ep_rewards))[: n_eval]).round(decimals=2)) 268 | # plot() 269 | print('{:3} Episode in {:5} steps, reward {:.2f}'.format( 270 | i, steps, ep_reward)) 271 | 272 | if len(ep_rewards) >= n_eval: 273 | if reward_eval[-1] >= env.spec.reward_threshold: 274 | # if reward_eval[-1] >= 495: 275 | print('\n{} is sloved! 
{:3} Episode in {:3} steps'.format( 276 | env.spec.id, i, steps)) 277 | torch.save(net.state_dict(), 278 | f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_ppo_st.pt') 279 | break 280 | env.close() 281 | 282 | 283 | # %% 284 | print(env.spec.max_episode_steps) 285 | 286 | 287 | # %% 288 | # [ 289 | # ('CartPole-v0', 889, 2048, 0.2, 10, 0.5, 0.01, False, 0.999, 0.98), 290 | # ('CartPole-v1', 801, 2048, 0.2, 10, 0.5, 0.01, False, 0.999, 0.98), 291 | # ('MountainCar-v0', None), 292 | # ('LunarLander-v2', 876, 2048, 0.2, 10, 1.0, 0.01, False, 0.999, 0.98) 293 | # ] 294 | 295 | -------------------------------------------------------------------------------- /src/ppo_discrete_step_parallel.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import deque 3 | from copy import deepcopy 4 | 5 | import gym 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.multiprocessing as mp 11 | from torch.distributions import Categorical 12 | from torch.utils.data import DataLoader 13 | from torch.utils.tensorboard import SummaryWriter 14 | from running_mean_std import RunningMeanStd 15 | 16 | N_PROCESS = 4 17 | ROLL_LEN = 2048 * N_PROCESS 18 | BATCH_SIZE = 2048 19 | P_LR = 3e-4 20 | V_LR = 1e-3 21 | ITER = 80 22 | CLIP = 0.2 23 | GAMMA = 0.999 24 | LAMBDA = 0.97 25 | # BETA = 3.0 26 | # ENT_COEF = 0.0 27 | GRAD_NORM = False 28 | OBS_NORM = True 29 | 30 | # set device 31 | use_cuda = torch.cuda.is_available() 32 | print('cuda:', use_cuda) 33 | device = torch.device('cuda:0' if use_cuda else 'cpu') 34 | writer = SummaryWriter() 35 | 36 | # random seed 37 | torch.manual_seed(5) 38 | if use_cuda: 39 | torch.cuda.manual_seed_all(5) 40 | 41 | # make an environment 42 | # env = gym.make('CartPole-v0') 43 | # env = gym.make('CartPole-v1') 44 | env = gym.make('MountainCar-v0') 45 | # env = gym.make('LunarLander-v2') 46 | 47 | 48 | class ActorCriticNet(nn.Module): 49 | def __init__(self, obs_space, action_space): 50 | super().__init__() 51 | h = 32 52 | self.pol = nn.Sequential( 53 | nn.Linear(obs_space, h), 54 | nn.Tanh(), 55 | nn.Linear(h, h), 56 | nn.Tanh(), 57 | nn.Linear(h, action_space) 58 | ) 59 | self.val = nn.Sequential( 60 | nn.Linear(obs_space, h), 61 | nn.Tanh(), 62 | nn.Linear(h, h), 63 | nn.Tanh(), 64 | nn.Linear(h, 1) 65 | ) 66 | self.log_softmax = nn.LogSoftmax(dim=-1) 67 | 68 | def forward(self, x): 69 | logit = self.pol(x).reshape(x.shape[0], -1) 70 | log_p = self.log_softmax(logit) 71 | v = self.val(x).reshape(x.shape[0], 1) 72 | return log_p, v 73 | 74 | 75 | # + 76 | def learn(net, optimizer, train_memory): 77 | global steps 78 | old_net = deepcopy(net) 79 | net.train() 80 | old_net.train() 81 | dataloader = DataLoader( 82 | train_memory, 83 | shuffle=False, 84 | batch_size=BATCH_SIZE, 85 | pin_memory=use_cuda 86 | ) 87 | advs = [] 88 | for data in dataloader.dataset: 89 | advs.append(data[3]) 90 | advs = torch.tensor(advs, dtype=torch.float32).to(device) 91 | adv_mean = advs.mean() 92 | adv_std = advs.std() 93 | for _ in range(ITER): 94 | for (s, a, ret, adv) in dataloader: 95 | s_batch = s.to(device).float() 96 | a_batch = a.to(device).long() 97 | ret_batch = ret.to(device).float() 98 | adv_batch = adv.to(device).float() 99 | adv_batch = (adv_batch - adv_mean) / adv_std 100 | with torch.no_grad(): 101 | log_p_batch_old, v_batch_old = old_net(s_batch) 102 | log_p_acting_old = log_p_batch_old[range(BATCH_SIZE), a_batch] 103 | 104 | log_p_batch, v_batch = 
net(s_batch) 105 | log_p_acting = log_p_batch[range(BATCH_SIZE), a_batch] 106 | p_ratio = (log_p_acting - log_p_acting_old).exp() 107 | p_ratio_clip = torch.clamp(p_ratio, 1 - CLIP, 1 + CLIP) 108 | p_loss = -(torch.min(p_ratio * adv_batch, p_ratio_clip * adv_batch).mean()) 109 | v_loss = (ret_batch - v_batch).pow(2).mean() 110 | 111 | kl_div = ((log_p_batch.exp() * (log_p_batch - log_p_batch_old)).sum(dim=-1).mean()).detach().item() 112 | 113 | # log_p, _ = net(s_batch) 114 | # entropy = -(log_p.exp() * log_p).sum(dim=1).mean() 115 | 116 | # loss 117 | loss = p_loss + v_loss 118 | 119 | if kl_div <= 0.01 * 1.5: 120 | optimizer[0].zero_grad() 121 | p_loss.backward() 122 | if GRAD_NORM: 123 | nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0) 124 | optimizer[0].step() 125 | else: 126 | print("Pass the Pi update!") 127 | optimizer[1].zero_grad() 128 | v_loss.backward() 129 | if GRAD_NORM: 130 | nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0) 131 | optimizer[1].step() 132 | if rank == 0: 133 | writer.add_scalar('data/p_loss', p_loss.item(), steps) 134 | writer.add_scalar('data/v_loss', v_loss.item(), steps) 135 | writer.add_scalar('data/kl_div', kl_div, steps) 136 | steps += 1 137 | train_memory.clear() 138 | return net.state_dict() 139 | 140 | 141 | # - 142 | 143 | def get_action_and_value(obs, old_net): 144 | old_net.eval() 145 | with torch.no_grad(): 146 | state = torch.tensor([obs]).to(device).float() 147 | log_p, v = old_net(state) 148 | m = Categorical(log_p.exp()) 149 | action = m.sample() 150 | return action.item(), v.item() 151 | 152 | 153 | def compute_adv_with_gae(rewards, values, roll_memory): 154 | rew = np.array(rewards, 'float') 155 | val = np.array(values[:-1], 'float') 156 | _val = np.array(values[1:], 'float') 157 | ret = rew + GAMMA * _val 158 | delta = ret - val 159 | gae_dt = np.array([(GAMMA * LAMBDA)**(i) * dt for i, dt in enumerate(delta.tolist())], 'float') 160 | for i, data in enumerate(roll_memory): 161 | data.append(ret[i]) 162 | data.append(sum(gae_dt[i:] / (GAMMA * LAMBDA)**(i))) 163 | 164 | rewards.clear() 165 | values.clear() 166 | return roll_memory 167 | 168 | 169 | def roll_out(env, length, rank, child): 170 | env.seed(rank) 171 | 172 | # hyperparameter 173 | roll_len = length 174 | 175 | # for play 176 | episodes = 0 177 | ep_steps = 0 178 | ep_rewards = [] 179 | 180 | # memories 181 | train_memory = [] 182 | roll_memory = [] 183 | rewards = [] 184 | values = [] 185 | 186 | # receive 187 | old_net, norm_obs = child.recv() 188 | 189 | # Play! 190 | while True: 191 | obs = env.reset() 192 | done = False 193 | ep_reward = 0 194 | while not done: 195 | # env.render() 196 | if OBS_NORM: 197 | norm_obs.update(np.array([obs])) 198 | obs_norm = np.clip((obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10) 199 | action, value = get_action_and_value(obs_norm, old_net) 200 | else: 201 | action, value = get_action_and_value(obs, old_net) 202 | 203 | # step 204 | _obs, reward, done, _ = env.step(action) 205 | 206 | # store 207 | values.append(value) 208 | 209 | if OBS_NORM: 210 | roll_memory.append([obs_norm, action]) 211 | else: 212 | roll_memory.append([obs, action]) 213 | 214 | rewards.append(reward) 215 | obs = _obs 216 | ep_reward += reward 217 | ep_steps += 1 218 | 219 | if done or ep_steps % roll_len == 0: 220 | if OBS_NORM: 221 | norm_obs.update(np.array([_obs])) 222 | if done: 223 | _value = 0.
224 | else: 225 | if OBS_NORM: 226 | _obs_norm = np.clip((_obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10) 227 | _, _value = get_action_and_value(_obs_norm, old_net) 228 | else: 229 | _, _value = get_action_and_value(_obs, old_net) 230 | 231 | values.append(_value) 232 | train_memory.extend(compute_adv_with_gae(rewards, values, roll_memory)) 233 | roll_memory.clear() 234 | 235 | if ep_steps % roll_len == 0: 236 | child.send(((train_memory, ep_rewards), 'train', rank)) 237 | train_memory.clear() 238 | ep_rewards.clear() 239 | state_dict, norm_obs = child.recv() 240 | old_net.load_state_dict(state_dict) 241 | break 242 | 243 | if done: 244 | episodes += 1 245 | ep_rewards.append(ep_reward) 246 | print('{:3} Episode in {:4} steps, reward {:4} [Process-{}]'.format(episodes, ep_steps, ep_reward, rank)) 247 | 248 | 249 | if __name__ == '__main__': 250 | mp.set_start_method('spawn') 251 | obs_space = env.observation_space.shape[0] 252 | action_space = env.action_space.n 253 | n_eval = 100 254 | net = ActorCriticNet(obs_space, action_space).to(device) 255 | net.share_memory() 256 | param_p = [p for n, p in net.named_parameters() if 'pol' in n] 257 | param_v = [p for n, p in net.named_parameters() if 'val' in n] 258 | optim_p = torch.optim.AdamW(param_p, lr=P_LR, eps=1e-6) 259 | optim_v = torch.optim.AdamW(param_v, lr=V_LR, eps=1e-6) 260 | optimizer = [optim_p, optim_v] 261 | norm_obs = RunningMeanStd(shape=env.observation_space.shape) 262 | 263 | jobs = [] 264 | pipes = [] 265 | trajectory = [] 266 | rewards = deque(maxlen=n_eval) 267 | update = 0 268 | steps = 0 269 | for i in range(N_PROCESS): 270 | parent, child = mp.Pipe() 271 | p = mp.Process(target=roll_out, args=(env, ROLL_LEN//N_PROCESS, i, child), daemon=True) 272 | jobs.append(p) 273 | pipes.append(parent) 274 | 275 | for i in range(N_PROCESS): 276 | pipes[i].send((net, norm_obs)) 277 | jobs[i].start() 278 | 279 | while True: 280 | for i in range(N_PROCESS): 281 | data, msg, rank = pipes[i].recv() 282 | if msg == 'train': 283 | traj, ep_rews = data 284 | trajectory.extend(traj) 285 | rewards.extend(ep_rews) 286 | 287 | if len(rewards) == n_eval: 288 | writer.add_scalar('data/reward', np.mean(rewards), update) 289 | if np.mean(rewards) >= env.spec.reward_threshold: 290 | print('\n{} is solved!
[{} Update]'.format(env.spec.id, update)) 291 | torch.save(net.state_dict(), f'./test/saved_models/{env.spec.id}_up{update}_clear_model_ppo_st.pt') 292 | with open(f'./test/saved_models/{env.spec.id}_up{update}_clear_norm_obs.pkl', 'wb') as f: 293 | pickle.dump(norm_obs, f, pickle.HIGHEST_PROTOCOL) 294 | break 295 | 296 | if len(trajectory) == ROLL_LEN: 297 | state_dict = learn(net, optimizer, trajectory) 298 | update += 1 299 | print(f'Update: {update}') 300 | for i in range(N_PROCESS): 301 | pipes[i].send((state_dict, norm_obs)) 302 | env.close() 303 | -------------------------------------------------------------------------------- /src/ppo_discrete_step_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": {}, 8 | "colab_type": "code", 9 | "id": "IWnm3qot3o1W" 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "import random\n", 14 | "from collections import deque\n", 15 | "from copy import deepcopy\n", 16 | "\n", 17 | "import gym\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "import numpy as np\n", 20 | "import torch\n", 21 | "import torch.nn as nn\n", 22 | "import torch.nn.functional as F\n", 23 | "import torch.optim as optim\n", 24 | "from torch.distributions import Categorical\n", 25 | "from torch.utils.data import DataLoader\n", 26 | "from IPython.display import clear_output\n", 27 | "\n", 28 | "from running_mean_std import RunningMeanStd" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "colab": {}, 36 | "colab_type": "code", 37 | "id": "IWnm3qot3o1W" 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "SEED = 5\n", 42 | "BATCH_SIZE = 2048\n", 43 | "P_LR = 3e-4\n", 44 | "V_LR = 1e-3\n", 45 | "ITER = 80\n", 46 | "CLIP = 0.2\n", 47 | "GAMMA = 0.999\n", 48 | "LAMBDA = 0.97\n", 49 | "BETA = 3.0\n", 50 | "# ENT_COEF = 0.0\n", 51 | "GRAD_NORM = False\n", 52 | "OBS_NORM = True\n", 53 | "VIEW_CURVE = False\n", 54 | "\n", 55 | "# set device\n", 56 | "use_cuda = torch.cuda.is_available()\n", 57 | "device = torch.device('cuda' if use_cuda else 'cpu')\n", 58 | "\n", 59 | "# random seed\n", 60 | "random.seed(SEED)\n", 61 | "np.random.seed(SEED)\n", 62 | "torch.manual_seed(SEED)\n", 63 | "if use_cuda:\n", 64 | " torch.cuda.manual_seed_all(SEED)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "colab": {}, 72 | "colab_type": "code", 73 | "id": "IWnm3qot3o1W" 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "class ActorCriticNet(nn.Module):\n", 78 | " def __init__(self, obs_space, action_space):\n", 79 | " super().__init__()\n", 80 | " h = 32\n", 81 | " self.pol = nn.Sequential(\n", 82 | " nn.Linear(obs_space, h),\n", 83 | " nn.Tanh(),\n", 84 | " nn.Linear(h, h),\n", 85 | " nn.Tanh(),\n", 86 | " nn.Linear(h, action_space)\n", 87 | " )\n", 88 | " self.val = nn.Sequential(\n", 89 | " nn.Linear(obs_space, h),\n", 90 | " nn.Tanh(),\n", 91 | " nn.Linear(h, h),\n", 92 | " nn.Tanh(),\n", 93 | " nn.Linear(h, 1)\n", 94 | " )\n", 95 | " self.log_softmax = nn.LogSoftmax(dim=-1)\n", 96 | "\n", 97 | " def forward(self, x):\n", 98 | " logit = self.pol(x).reshape(x.shape[0], -1)\n", 99 | " log_p = self.log_softmax(logit)\n", 100 | " v = self.val(x).reshape(x.shape[0], 1)\n", 101 | " return log_p, v" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "metadata": { 108 | "colab": {}, 109 | "colab_type": "code", 110 | "id": 
"IWnm3qot3o1W" 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "losses = []\n", 115 | "kl_divs = []\n", 116 | "\n", 117 | "def learn(net, old_net, optimizer, train_memory):\n", 118 | " net.train()\n", 119 | " old_net.train()\n", 120 | " dataloader = DataLoader(\n", 121 | " train_memory,\n", 122 | " shuffle=False,\n", 123 | " batch_size=BATCH_SIZE,\n", 124 | " pin_memory=True,\n", 125 | " num_workers=0,\n", 126 | " )\n", 127 | " for _ in range(ITER): \n", 128 | " for i, (s, a, ret, adv) in enumerate(dataloader):\n", 129 | " s = s.to(device).float()\n", 130 | " a = a.to(device).long()\n", 131 | " ret = ret.to(device).float()\n", 132 | " adv = adv.to(device).float()\n", 133 | " \n", 134 | " with torch.no_grad():\n", 135 | " log_p_old, v_old = old_net(s)\n", 136 | " log_p_act_old = log_p_old[range(BATCH_SIZE), a]\n", 137 | "\n", 138 | " log_p, v = net(s)\n", 139 | " log_p_act = log_p[range(BATCH_SIZE), a]\n", 140 | " p_ratio = (log_p_act - log_p_act_old).exp()\n", 141 | " p_ratio_clip = torch.clamp(p_ratio, 1 - CLIP, 1 + CLIP)\n", 142 | " p_loss = -(torch.min(p_ratio * adv, p_ratio_clip * adv).mean())\n", 143 | "# kl_div = (log_p.exp() * (log_p - log_p_old)).sum(dim=-1)\n", 144 | "# p_loss = -(p_ratio * adv - BETA * kl_div).mean()\n", 145 | " v_loss = (ret - v).pow(2).mean()\n", 146 | "# log_p, _ = net(s)\n", 147 | "# entropy = -(log_p.exp() * log_p).sum(dim=1).mean()\n", 148 | " # loss\n", 149 | "# loss = p_loss + v_loss - ENT_COEF * entropy\n", 150 | " kl_div = ((log_p.exp() * (log_p - log_p_old)).sum(dim=-1).mean()).detach().item()\n", 151 | " loss = p_loss + v_loss\n", 152 | " losses.append(loss.item())\n", 153 | " kl_divs.append(kl_div)\n", 154 | " if kl_div <= 0.01 * 1.5:\n", 155 | " optimizer[0].zero_grad()\n", 156 | " p_loss.backward()\n", 157 | " if GRAD_NORM:\n", 158 | " nn.utils.clip_grad_norm_(net.parameters() , max_norm=1.0)\n", 159 | " optimizer[0].step()\n", 160 | " else:\n", 161 | " if not VIEW_CURVE:\n", 162 | " print(\"Pass the Pi update!\")\n", 163 | " optimizer[1].zero_grad()\n", 164 | " v_loss.backward()\n", 165 | " if GRAD_NORM:\n", 166 | " nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)\n", 167 | " optimizer[1].step()\n", 168 | "\n", 169 | "\n", 170 | "def get_action_and_value(obs, old_net):\n", 171 | " old_net.eval()\n", 172 | " with torch.no_grad():\n", 173 | " state = torch.tensor([obs]).to(device).float()\n", 174 | " log_p, v = old_net(state)\n", 175 | " m = Categorical(log_p.exp())\n", 176 | " action = m.sample()\n", 177 | " return action.item(), v.item()\n", 178 | "\n", 179 | "\n", 180 | "def compute_adv_with_gae(rewards, values, roll_memory):\n", 181 | " rew = np.array(rewards, np.float32)\n", 182 | " val = np.array(values[:-1], np.float32)\n", 183 | " _val = np.array(values[1:], np.float32)\n", 184 | " delta = rew + GAMMA * _val - val\n", 185 | " dis_r = np.array([GAMMA**(i) * r for i, r in enumerate(rewards)], np.float32)\n", 186 | " gae_dt = np.array([(GAMMA * LAMBDA)**(i) * dt for i, dt in enumerate(delta.tolist())], np.float32)\n", 187 | " for i, data in enumerate(roll_memory):\n", 188 | " data.append(sum(dis_r[i:] / GAMMA**(i)))\n", 189 | " data.append(sum(gae_dt[i:] / (GAMMA * LAMBDA)**(i)))\n", 190 | "\n", 191 | " rewards.clear()\n", 192 | " values.clear()\n", 193 | " return roll_memory\n", 194 | "\n", 195 | "\n", 196 | "def plot():\n", 197 | " clear_output(True)\n", 198 | " plt.figure(figsize=(16, 5))\n", 199 | " plt.subplot(131)\n", 200 | " plt.plot(ep_rewards, alpha=0.5)\n", 201 | " plt.subplot(131)\n", 202 | " plt.plot(reward_eval)\n", 
203 | " plt.title(f'Reward: {reward_eval[-1]:.0f}')\n", 204 | " if losses:\n", 205 | " plt.subplot(132)\n", 206 | " plt.plot(losses, alpha=0.5)\n", 207 | " plt.title(f'Loss: {losses[-1]:.2f}')\n", 208 | " plt.subplot(133)\n", 209 | " plt.plot(kl_divs, alpha=0.5)\n", 210 | " plt.title(f'kl_div: {kl_divs[-1]:.4f}')\n", 211 | " plt.show()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 5, 217 | "metadata": { 218 | "colab": {}, 219 | "colab_type": "code", 220 | "id": "IWnm3qot3o1W", 221 | "scrolled": false 222 | }, 223 | "outputs": [ 224 | { 225 | "name": "stdout", 226 | "output_type": "stream", 227 | "text": [ 228 | "n_episodes: 625\n", 229 | "reward_threshold: 300\n" 230 | ] 231 | }, 232 | { 233 | "name": "stderr", 234 | "output_type": "stream", 235 | "text": [ 236 | "/workspace/media/ai/Storage/gym/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", 237 | " warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n" 238 | ] 239 | }, 240 | { 241 | "ename": "TypeError", 242 | "evalue": "new(): argument 'size' must be tuple of ints, but found element of type tuple at pos 2", 243 | "output_type": "error", 244 | "traceback": [ 245 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 246 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 247 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;31m# make nerual networks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mActorCriticNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobs_space\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction_space\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0mold_net\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 248 | "\u001b[0;32m\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, obs_space, action_space)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m32\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self.pol = nn.Sequential(\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobs_space\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTanh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 249 | "\u001b[0;32m/opt/conda/lib/python3.6/site-packages/torch/nn/modules/linear.py\u001b[0m in 
\u001b[0;36m__init__\u001b[0;34m(self, in_features, out_features, bias)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0min_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0min_features\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_features\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParameter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParameter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 250 | "\u001b[0;31mTypeError\u001b[0m: new(): argument 'size' must be tuple of ints, but found element of type tuple at pos 2" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "# make an environment\n", 256 | "# env = gym.make('CartPole-v0')\n", 257 | "# env = gym.make('CartPole-v1')\n", 258 | "# env = gym.make('MountainCar-v0')\n", 259 | "# env = gym.make('LunarLander-v2')\n", 260 | "env = gym.make('BipedalWalker-v3')\n", 261 | "\n", 262 | "env.seed(SEED)\n", 263 | "obs_space = env.observation_space.shape\n", 264 | "action_space = env.action_space.shape\n", 265 | "\n", 266 | "# hyperparameter\n", 267 | "n_episodes = int(1e6 / env.spec.max_episode_steps)\n", 268 | "roll_len = 2048\n", 269 | "n_eval = 100\n", 270 | "\n", 271 | "print(f\"n_episodes: {n_episodes}\")\n", 272 | "print(f\"reward_threshold: {env.spec.reward_threshold}\")\n", 273 | "\n", 274 | "# global values\n", 275 | "steps = 0\n", 276 | "ep_rewards = []\n", 277 | "reward_eval = []\n", 278 | "is_rollout = False\n", 279 | "is_solved = False\n", 280 | "\n", 281 | "# make memories\n", 282 | "train_memory = []\n", 283 | "roll_memory = []\n", 284 | "rewards = []\n", 285 | "values = []\n", 286 | "norm_obs = RunningMeanStd(shape=env.observation_space.shape)\n", 287 | "\n", 288 | "# make nerual networks\n", 289 | "net = ActorCriticNet(obs_space, action_space).to(device)\n", 290 | "old_net = deepcopy(net)\n", 291 | "\n", 292 | "param_p = [p for n, p in net.named_parameters() if 'pol' in n]\n", 293 | "param_v = [p for n, p in net.named_parameters() if 'val' in n]\n", 294 | "optim_p = torch.optim.AdamW(param_p, lr=P_LR, eps=1e-6, weight_decay=0.01)\n", 295 | "optim_v = torch.optim.AdamW(param_v, lr=V_LR, eps=1e-6, weight_decay=0.01)\n", 296 | "optimizer = [optim_p, optim_v]\n", 297 | "\n", 298 | "# play!\n", 299 | "for i in range(1, n_episodes + 1):\n", 300 | " obs = env.reset()\n", 301 | " done = False\n", 302 | " ep_reward = 0\n", 303 | " while not done:\n", 304 | "# env.render()\n", 305 | " if OBS_NORM:\n", 306 | " print(obs)\n", 307 | " norm_obs.update(np.array([obs]))\n", 308 | " 
obs_norm = np.clip((obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)\n", 309 | " action, value = get_action_and_value(obs_norm, old_net)\n", 310 | " else:\n", 311 | " action, value = get_action_and_value(obs, old_net)\n", 312 | " \n", 313 | " _obs, reward, done, _ = env.step(action)\n", 314 | "\n", 315 | " # store\n", 316 | " if OBS_NORM:\n", 317 | " roll_memory.append([obs_norm, action])\n", 318 | " else:\n", 319 | " roll_memory.append([obs, action])\n", 320 | " \n", 321 | "\n", 322 | " rewards.append(reward)\n", 323 | " values.append(value)\n", 324 | "\n", 325 | " obs = _obs\n", 326 | " steps += 1\n", 327 | " ep_reward += reward\n", 328 | "\n", 329 | " if done or steps % roll_len == 0:\n", 330 | " if OBS_NORM:\n", 331 | " norm_obs.update(np.array([_obs]))\n", 332 | " if done:\n", 333 | " _value = 0.\n", 334 | " else:\n", 335 | " if OBS_NORM:\n", 336 | " _obs_norm = np.clip((_obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)\n", 337 | " _, _value = get_action_and_value(_obs_norm, old_net)\n", 338 | " else:\n", 339 | " _, _value = get_action_and_value(_obs, old_net)\n", 340 | "\n", 341 | " values.append(_value)\n", 342 | " train_memory.extend(compute_adv_with_gae(rewards, values, roll_memory))\n", 343 | " roll_memory.clear()\n", 344 | "\n", 345 | " if steps % roll_len == 0:\n", 346 | " # adv normalize\n", 347 | " advs = []\n", 348 | " for m in train_memory:\n", 349 | " advs.append(m[3])\n", 350 | " advs = np.array(advs)\n", 351 | " for m in train_memory:\n", 352 | " m[3] = (m[3] - np.mean(advs)) / np.std(advs)\n", 353 | " \n", 354 | " learn(net, old_net, optimizer, train_memory)\n", 355 | " old_net.load_state_dict(net.state_dict())\n", 356 | " train_memory.clear()\n", 357 | " break\n", 358 | "\n", 359 | " if done:\n", 360 | " ep_rewards.append(ep_reward)\n", 361 | " reward_eval.append(np.mean(ep_rewards[-n_eval:]))\n", 362 | " if VIEW_CURVE:\n", 363 | " plot()\n", 364 | " else:\n", 365 | " print(f'{i:3} Episode in {steps:5} steps, reward {ep_reward:.2f}')\n", 366 | "\n", 367 | " if len(ep_rewards) >= n_eval:\n", 368 | " if reward_eval[-1] >= env.spec.reward_threshold:\n", 369 | " print('\\n{} is sloved! 
{:3} Episode in {:3} steps'.format(\n", 370 | " env.spec.id, i, steps))\n", 371 | " torch.save(net.state_dict(),\n", 372 | " f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_ppo_st.pt')\n", 373 | " break\n", 374 | "env.close()" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 1, 380 | "metadata": {}, 381 | "outputs": [ 382 | { 383 | "name": "stdout", 384 | "output_type": "stream", 385 | "text": [ 386 | "[NbConvertApp] Converting notebook PPO.ipynb to script\n", 387 | "[NbConvertApp] Writing 8573 bytes to PPO.py\n" 388 | ] 389 | } 390 | ], 391 | "source": [ 392 | "# !jupyter nbconvert --to script PPO.ipynb" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": {}, 399 | "outputs": [], 400 | "source": [] 401 | } 402 | ], 403 | "metadata": { 404 | "colab": { 405 | "collapsed_sections": [], 406 | "name": "C51_tensorflow.ipynb", 407 | "provenance": [], 408 | "version": "0.3.2" 409 | }, 410 | "kernelspec": { 411 | "display_name": "Python 3", 412 | "language": "python", 413 | "name": "python3" 414 | }, 415 | "language_info": { 416 | "codemirror_mode": { 417 | "name": "ipython", 418 | "version": 3 419 | }, 420 | "file_extension": ".py", 421 | "mimetype": "text/x-python", 422 | "name": "python", 423 | "nbconvert_exporter": "python", 424 | "pygments_lexer": "ipython3", 425 | "version": "3.6.10" 426 | } 427 | }, 428 | "nbformat": 4, 429 | "nbformat_minor": 1 430 | } 431 | -------------------------------------------------------------------------------- /src/ppo_discrete_step_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # %% 4 | import random 5 | from collections import deque 6 | from copy import deepcopy 7 | 8 | import gym 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from torch.distributions import Categorical 16 | from torch.utils.data import DataLoader 17 | from IPython.display import clear_output 18 | 19 | from running_mean_std import RunningMeanStd 20 | 21 | 22 | # %% 23 | SEED = 5 24 | BATCH_SIZE = 32 25 | LR = 3e-4 26 | EPOCHS = 20 27 | CLIP = 0.2 28 | GAMMA = 0.999 29 | LAMBDA = 0.97 30 | ENT_COEF = 0.0 31 | V_COEF = 2.0 32 | V_CLIP = False 33 | LIN_REDUCE = False 34 | GRAD_NORM = False 35 | OBS_NORM = True 36 | REW_NORM = False 37 | # set device 38 | use_cuda = torch.cuda.is_available() 39 | print('cuda:', use_cuda) 40 | device = torch.device('cuda' if use_cuda else 'cpu') 41 | 42 | # random seed 43 | random.seed(SEED) 44 | np.random.seed(SEED) 45 | torch.manual_seed(SEED) 46 | if use_cuda: 47 | torch.cuda.manual_seed_all(SEED) 48 | 49 | 50 | # %% 51 | 52 | 53 | class ActorCriticNet(nn.Module): 54 | def __init__(self, obs_space, action_space): 55 | super().__init__() 56 | h = 32 57 | # self.head = nn.Sequential( 58 | # nn.Linear(obs_space, h), 59 | # nn.Tanh() 60 | # ) 61 | self.pol = nn.Sequential( 62 | nn.Linear(obs_space, h), 63 | nn.Tanh(), 64 | nn.Linear(h, h), 65 | nn.Tanh(), 66 | nn.Linear(h, action_space) 67 | ) 68 | self.val = nn.Sequential( 69 | nn.Linear(obs_space, h), 70 | nn.Tanh(), 71 | nn.Linear(h, h), 72 | nn.Tanh(), 73 | nn.Linear(h, 1) 74 | ) 75 | self.log_softmax = nn.LogSoftmax(dim=-1) 76 | 77 | def forward(self, x): 78 | # x = self.head(x) 79 | logit = self.pol(x).reshape(x.shape[0], -1) 80 | log_p = self.log_softmax(logit) 81 | v = self.val(x).reshape(x.shape[0], 1) 
82 | 83 | return log_p, v 84 | 85 | 86 | # %% 87 | 88 | 89 | losses = [] 90 | 91 | 92 | def learn(net, old_net, optimizer, train_memory): 93 | global CLIP, LR 94 | global total_epochs 95 | net.train() 96 | old_net.train() 97 | dataloader = DataLoader( 98 | train_memory, 99 | shuffle=True, 100 | batch_size=BATCH_SIZE, 101 | pin_memory=False, 102 | num_workers=0, 103 | ) 104 | for epoch in range(EPOCHS): 105 | if LIN_REDUCE: 106 | clip = CLIP - (CLIP * epoch / total_epochs) 107 | lr = LR - (LR * epoch / total_epochs) 108 | for param_group in optimizer.param_groups: 109 | param_group['lr'] = lr 110 | else: 111 | clip = CLIP 112 | 113 | for (s, a, ret, adv) in dataloader: 114 | s_batch = s.to(device).float() 115 | a_batch = a.to(device).long() 116 | ret_batch = ret.to(device).float() 117 | ret_batch = (ret_batch - ret_batch.mean()) / ret_batch.std() 118 | adv_batch = adv.to(device).float() 119 | adv_batch = (adv_batch - adv_batch.mean()) / adv_batch.std() 120 | 121 | for optim in optimizer: 122 | optim.zero_grad() 123 | 124 | with torch.no_grad(): 125 | log_p_batch_old, v_batch_old = old_net(s_batch) 126 | log_p_acting_old = log_p_batch_old[range(BATCH_SIZE), a_batch] 127 | 128 | log_p_batch, v_batch = net(s_batch) 129 | log_p_acting = log_p_batch[range(BATCH_SIZE), a_batch] 130 | p_ratio = (log_p_acting - log_p_acting_old).exp() 131 | p_ratio_clip = torch.clamp(p_ratio, 1 - clip, 1 + clip) 132 | p_loss = -(torch.min(p_ratio * adv_batch, p_ratio_clip * adv_batch)).mean() 133 | # approx_kl = (log_p_batch - log_p_batch_old).mean().detach().item() 134 | # kl_div = (log_p_batch.exp() * (log_p_batch - log_p_batch_old)).sum(dim=-1) 135 | # p_loss = -(p_ratio * adv_batch - 3.0 * kl_div).mean() 136 | if V_CLIP: 137 | v_clip = v_batch_old + torch.clamp(v_batch - v_batch_old, -clip, clip) 138 | v_loss1 = (ret_batch - v_clip).pow(2) 139 | v_loss2 = (ret_batch - v_batch).pow(2) 140 | v_loss = 0.5 * torch.max(v_loss1, v_loss2).mean() 141 | else: 142 | v_loss = 0.5 * (ret_batch - v_batch).pow(2).mean() 143 | 144 | # log_p, _ = net(s_batch) 145 | # entropy = -(log_p.exp() * log_p).sum(dim=1).mean() 146 | 147 | # loss 148 | # loss = p_loss + V_COEF * v_loss - ENT_COEF * entropy 149 | loss = p_loss + V_COEF * v_loss 150 | 151 | losses.append(loss.item()) 152 | # if approx_kl <= 1.5* 0.01: 153 | # p_loss.backward() 154 | # optimizer[0].step() 155 | # v_loss.backward() 156 | # optimizer[1].step() 157 | loss.backward() 158 | if GRAD_NORM: 159 | nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5) 160 | for optim in optimizer: 161 | optim.step() 162 | 163 | 164 | def get_action_and_value(obs, old_net): 165 | old_net.eval() 166 | with torch.no_grad(): 167 | state = torch.tensor([obs]).to(device).float() 168 | log_p, v = old_net(state) 169 | m = Categorical(log_p.exp()) 170 | action = m.sample() 171 | # print(action) 172 | 173 | return action.item(), v.item() 174 | 175 | 176 | def compute_adv_with_gae(rewards, values, roll_memory): 177 | rew = np.array(rewards, np.float64) 178 | val = np.array(values[:-1], np.float64) 179 | _val = np.array(values[1:], np.float64) 180 | delta = rew + GAMMA * _val - val 181 | dis_r = np.array([GAMMA**(i) * r for i, r in enumerate(rewards)], np.float64) 182 | gae_dt = np.array([(GAMMA * LAMBDA)**(i) * dt for i, 183 | dt in enumerate(delta.tolist())], np.float64) 184 | for i, data in enumerate(roll_memory): 185 | data.append(sum(dis_r[i:] / GAMMA**(i))) 186 | data.append(sum(gae_dt[i:] / (GAMMA * LAMBDA)**(i))) 187 | 188 | rewards.clear() 189 | values.clear() 190 | 191 | return 
roll_memory 192 | 193 | 194 | def plot(): 195 | clear_output(True) 196 | plt.figure(figsize=(16, 5)) 197 | plt.subplot(121) 198 | plt.plot(ep_rewards, alpha=0.5) 199 | plt.subplot(121) 200 | plt.plot(reward_eval) 201 | plt.title(f'Reward: ' 202 | f'{reward_eval[-1]:.2f}') 203 | plt.subplot(122) 204 | plt.plot(losses, alpha=0.5) 205 | plt.title(f'Loss: ' 206 | f'{np.mean(list(reversed(losses))[: n_eval]):.2f}') 207 | plt.show() 208 | 209 | 210 | # %% 211 | 212 | 213 | # make an environment 214 | # env = gym.make('CartPole-v0') 215 | env = gym.make('CartPole-v1') 216 | # env = gym.make('MountainCar-v0') 217 | # env = gym.make('LunarLander-v2') 218 | 219 | env.seed(SEED) 220 | obs_space = env.observation_space.shape[0] 221 | action_space = env.action_space.n 222 | 223 | # hyperparameter 224 | n_episodes = 100000 225 | roll_len = 4096 226 | total_epochs = roll_len // BATCH_SIZE 227 | n_eval = 10 228 | 229 | # global values 230 | steps = 0 231 | ep_rewards = [] 232 | reward_eval = [] 233 | is_rollout = False 234 | is_solved = False 235 | 236 | # make memories 237 | train_memory = [] 238 | roll_memory = [] 239 | obses = [] 240 | rews = [] 241 | rewards = [] 242 | values = [] 243 | norm_obs = RunningMeanStd(shape=env.observation_space.shape) 244 | norm_rew = RunningMeanStd() 245 | 246 | # make neural networks 247 | net = ActorCriticNet(obs_space, action_space).to(device) 248 | old_net = deepcopy(net) 249 | 250 | # grouped_parameters = [ 251 | # {'params': [p for n, p in net.named_parameters() if n == 'val'], 'lr': LR * 0.1}, 252 | # {'params': [p for n, p in net.named_parameters() if n != 'val'], 'lr': LR} 253 | # ] 254 | param_p = [p for n, p in net.named_parameters() if 'val' not in n] 255 | param_v = [p for n, p in net.named_parameters() if 'val' in n] 256 | optim_p = torch.optim.AdamW(param_p, lr=LR, eps=1e-6) 257 | optim_v = torch.optim.AdamW(param_v, lr=0.001, eps=1e-6) 258 | optimizer = [optim_p, optim_v] 259 | # optimizer = [torch.optim.AdamW(net.parameters(), lr=LR, eps=1e-6)] 260 | 261 | # play! 262 | for i in range(1, n_episodes + 1): 263 | obs = env.reset() 264 | done = False 265 | ep_reward = 0 266 | while not done: 267 | # env.render() 268 | obses.append(obs) 269 | if OBS_NORM: 270 | obs_norm = np.clip((obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10) 271 | action, value = get_action_and_value(obs_norm, old_net) 272 | else: 273 | action, value = get_action_and_value(obs, old_net) 274 | 275 | _obs, reward, done, _ = env.step(action) 276 | 277 | # store 278 | if OBS_NORM: 279 | roll_memory.append([obs_norm, action]) 280 | else: 281 | roll_memory.append([obs, action]) 282 | 283 | if REW_NORM: 284 | rew_norm = np.clip((reward - norm_rew.mean) / np.sqrt(norm_rew.var), -10, 10) 285 | rewards.append(rew_norm) 286 | rews.append(reward) 287 | else: 288 | rewards.append(reward) 289 | 290 | values.append(value) 291 | 292 | obs = _obs 293 | steps += 1 294 | ep_reward += reward 295 | 296 | if done or steps % roll_len == 0: 297 | if done: 298 | obses.append(_obs) 299 | _value = 0.
300 | else: 301 | if OBS_NORM: 302 | _obs_norm = np.clip((_obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10) 303 | _, _value = get_action_and_value(_obs_norm, old_net) 304 | else: 305 | _, _value = get_action_and_value(_obs, old_net) 306 | 307 | values.append(_value) 308 | train_memory.extend(compute_adv_with_gae(rewards, values, roll_memory)) 309 | roll_memory.clear() 310 | 311 | if steps % roll_len == 0: 312 | learn(net, old_net, optimizer, train_memory) 313 | old_net.load_state_dict(net.state_dict()) 314 | if OBS_NORM: 315 | norm_obs.update(np.array(obses)) 316 | if REW_NORM: 317 | norm_rew.update(np.array(rews)) 318 | train_memory.clear() 319 | obses.clear() 320 | rews.clear() 321 | 322 | if done: 323 | ep_rewards.append(ep_reward) 324 | reward_eval.append(np.mean(list(reversed(ep_rewards))[: n_eval])) 325 | # plot() 326 | print('{:3} Episode in {:5} steps, reward {:.2f}'.format( 327 | i, steps, ep_reward)) 328 | 329 | if len(ep_rewards) >= n_eval: 330 | if reward_eval[-1] >= env.spec.reward_threshold: 331 | print('\n{} is solved! {:3} Episode in {:3} steps'.format( 332 | env.spec.id, i, steps)) 333 | torch.save(net.state_dict(), 334 | f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_ppo_st.pt') 335 | break 336 | env.close() 337 | 338 | 339 | # %% 340 | 341 | 342 | # env.spec.reward_threshold 343 | 344 | 345 | # %% 346 | 347 | 348 | # [ 349 | # ('CartPole-v0', 889, 2048, 0.2, 10, 0.5, 0.01, False, 0.999, 0.98), 350 | # ('CartPole-v1', 801, 2048, 0.2, 10, 0.5, 0.01, False, 0.999, 0.98), 675, 351 | # ('MountainCar-v0', None), 352 | # ('LunarLander-v2', 876, 2048, 0.2, 10, 1.0, 0.01, False, 0.999, 0.98) 353 | # ] 354 | 355 | 356 | # %% 357 | 358 | 359 | # get_ipython().system('jupyter nbconvert --to script ppo_discrete_step_test.ipynb') 360 | 361 | -------------------------------------------------------------------------------- /src/running_mean_std.py: -------------------------------------------------------------------------------- 1 | # https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py 2 | import numpy as np 3 | 4 | 5 | class RunningMeanStd(object): 6 | def __init__(self, epsilon=1e-4, shape=()): 7 | self.mean = np.zeros(shape, np.float64) 8 | self.var = np.ones(shape, np.float64) 9 | self.count = epsilon 10 | 11 | def update(self, x): 12 | batch_mean = np.mean(x, axis=0) 13 | batch_var = np.var(x, axis=0) 14 | batch_count = x.shape[0] 15 | self.update_from_moments(batch_mean, batch_var, batch_count) 16 | 17 | def update_from_moments(self, batch_mean, batch_var, batch_count): 18 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 19 | self.mean, self.var, self.count, 20 | batch_mean, batch_var, batch_count) 21 | 22 | 23 | def update_mean_var_count_from_moments(mean, var, count, 24 | batch_mean, batch_var, batch_count): 25 | delta = batch_mean - mean 26 | tot_count = count + batch_count 27 | 28 | new_mean = mean + delta * batch_count / tot_count 29 | m_a = var * count 30 | m_b = batch_var * batch_count 31 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 32 | new_var = M2 / tot_count 33 | new_count = tot_count 34 | 35 | return new_mean, new_var, new_count 36 | -------------------------------------------------------------------------------- /src/soft_actor_critic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": {}, 8 | "colab_type": "code", 9 | "id":
"IWnm3qot3o1W" 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "import random\n", 14 | "import math\n", 15 | "from collections import deque\n", 16 | "from copy import deepcopy\n", 17 | "\n", 18 | "import gym\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import numpy as np\n", 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "import torch.optim as optim\n", 24 | "from torch.distributions import Categorical, Dirichlet\n", 25 | "from torch.utils.data import DataLoader\n", 26 | "from IPython.display import clear_output" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "colab": { 34 | "base_uri": "https://localhost:8080/", 35 | "height": 35 36 | }, 37 | "colab_type": "code", 38 | "executionInfo": { 39 | "elapsed": 708, 40 | "status": "ok", 41 | "timestamp": 1534482400648, 42 | "user": { 43 | "displayName": "윤승제", 44 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg", 45 | "userId": "105654037995838004821" 46 | }, 47 | "user_tz": -540 48 | }, 49 | "id": "maRVADiTlzHD", 50 | "outputId": "783b7610-95c2-4b54-b2ce-d8e853c484ba" 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "SEED = 1\n", 55 | "BATCH_SIZE = 32\n", 56 | "LR = 0.0003\n", 57 | "UP_COEF = 0.01\n", 58 | "ENT_COEF = 0.01\n", 59 | "GAMMA = 0.99\n", 60 | "\n", 61 | "# set device\n", 62 | "use_cuda = torch.cuda.is_available()\n", 63 | "device = torch.device('cuda' if use_cuda else 'cpu')\n", 64 | "\n", 65 | "# random seed\n", 66 | "random.seed(SEED)\n", 67 | "np.random.seed(SEED)\n", 68 | "torch.manual_seed(SEED)\n", 69 | "if use_cuda:\n", 70 | " torch.cuda.manual_seed_all(SEED)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": { 77 | "colab": {}, 78 | "colab_type": "code", 79 | "id": "9Ffkl_5C4R81" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "class QNet(nn.Module):\n", 84 | " def __init__(self, obs_space, action_space):\n", 85 | " super().__init__()\n", 86 | "\n", 87 | " self.head = nn.Sequential(\n", 88 | " nn.Linear(obs_space, obs_space*10),\n", 89 | " nn.SELU()\n", 90 | " )\n", 91 | " self.fc = nn.Sequential(\n", 92 | " nn.Linear(obs_space*10, 512),\n", 93 | " nn.SELU(),\n", 94 | " nn.Linear(512, 512),\n", 95 | " nn.SELU(),\n", 96 | " nn.Linear(512, action_space)\n", 97 | " )\n", 98 | "\n", 99 | " def forward(self, x):\n", 100 | " out = self.head(x)\n", 101 | " q = self.fc(out).reshape(out.shape[0], -1)\n", 102 | "\n", 103 | " return q\n", 104 | "\n", 105 | "\n", 106 | "class PolicyNet(nn.Module):\n", 107 | " def __init__(self, obs_space, action_space):\n", 108 | " super().__init__()\n", 109 | "\n", 110 | " self.head = nn.Sequential(\n", 111 | " nn.Linear(obs_space, obs_space*10),\n", 112 | " nn.SELU()\n", 113 | " )\n", 114 | "\n", 115 | " self.fc = nn.Sequential(\n", 116 | " nn.Linear(obs_space*10, 512),\n", 117 | " nn.SELU(),\n", 118 | " nn.Linear(512, 512),\n", 119 | " nn.SELU(),\n", 120 | " nn.Linear(512, action_space)\n", 121 | " )\n", 122 | "\n", 123 | " self.log_softmax = nn.LogSoftmax(dim=-1)\n", 124 | "\n", 125 | " def forward(self, x):\n", 126 | " out = self.head(x)\n", 127 | " logit = self.fc(out).reshape(out.shape[0], -1)\n", 128 | " log_p = self.log_softmax(logit)\n", 129 | "\n", 130 | " return log_p\n", 131 | "\n", 132 | "\n", 133 | "class ValueNet(nn.Module):\n", 134 | " def __init__(self, obs_space):\n", 135 | " super().__init__()\n", 136 | "\n", 137 | " self.head = nn.Sequential(\n", 138 | " nn.Linear(obs_space, obs_space*10),\n", 139 | " 
nn.SELU()\n", 140 | " )\n", 141 | "\n", 142 | " self.fc = nn.Sequential(\n", 143 | " nn.Linear(obs_space*10, 512),\n", 144 | " nn.SELU(),\n", 145 | " nn.Linear(512, 512),\n", 146 | " nn.SELU(),\n", 147 | " nn.Linear(512, 1)\n", 148 | " )\n", 149 | "\n", 150 | " def forward(self, x):\n", 151 | " out = self.head(x)\n", 152 | " v = self.fc(out).reshape(out.shape[0], 1)\n", 153 | "\n", 154 | " return v" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 4, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "losses = []\n", 164 | "\n", 165 | "\n", 166 | "def learn(q_net, p_net, v_net, v_tgt, optimizer, rep_memory):\n", 167 | " global action_space\n", 168 | " \n", 169 | " q_net.train()\n", 170 | " p_net.train()\n", 171 | " v_net.train()\n", 172 | " v_tgt.train()\n", 173 | "\n", 174 | " train_data = random.sample(rep_memory, BATCH_SIZE)\n", 175 | " dataloader = DataLoader(train_data,\n", 176 | " batch_size=BATCH_SIZE,\n", 177 | " pin_memory=use_cuda)\n", 178 | "\n", 179 | " for i, (s, a, r, _s, d) in enumerate(dataloader):\n", 180 | " s_batch = s.to(device).float()\n", 181 | " a_batch = a.to(device).long()\n", 182 | " _s_batch = _s.to(device).float()\n", 183 | " r_batch = r.to(device).float()\n", 184 | " done_mask = 1. - d.to(device).float()\n", 185 | " discount = torch.full_like(r_batch, GAMMA)\n", 186 | " \n", 187 | " q_batch = q_net(s_batch)\n", 188 | " q_acting = q_batch[range(BATCH_SIZE), a_batch]\n", 189 | " q_acting_ = q_acting.detach() \n", 190 | " \n", 191 | " v_batch = v_net(s_batch)\n", 192 | " \n", 193 | " with torch.no_grad():\n", 194 | " _log_p_batch = p_net(_s_batch)\n", 195 | " _log_p_acting = _log_p_batch[range(BATCH_SIZE), a_batch]\n", 196 | " v_target = q_acting_ - ENT_COEF * _log_p_acting\n", 197 | " \n", 198 | " v_loss = (v_target - v_batch).pow(2).mean()\n", 199 | " \n", 200 | " with torch.no_grad():\n", 201 | " _v_batch = v_tgt(_s_batch) * done_mask\n", 202 | " q_target = r_batch + _v_batch * discount\n", 203 | " \n", 204 | " q_loss = (q_target - q_acting).pow(2).mean()\n", 205 | " \n", 206 | " log_p_batch = p_net(s_batch)\n", 207 | " q_batch_ = q_net(s_batch)\n", 208 | " entropy = -(ENT_COEF * log_p_batch.exp() * q_batch).sum(dim=-1).mean()\n", 209 | " \n", 210 | " loss = v_loss + q_loss + entropy\n", 211 | " \n", 212 | " optimizer.zero_grad()\n", 213 | " loss.backward()\n", 214 | "# nn.utils.clip_grad_norm_(total_params, max_norm=0.5)\n", 215 | " optimizer.step()\n", 216 | "\n", 217 | "\n", 218 | "def select_action(obs, p_net):\n", 219 | " p_net.eval()\n", 220 | " with torch.no_grad():\n", 221 | " state = torch.tensor([obs]).to(device).float()\n", 222 | " log_p = p_net(state)\n", 223 | " m = Categorical(log_p.exp())\n", 224 | " action = m.sample()\n", 225 | "\n", 226 | " return action.item()\n", 227 | "\n", 228 | "\n", 229 | "def plot():\n", 230 | " clear_output(True)\n", 231 | " plt.figure(figsize=(16, 5))\n", 232 | " plt.subplot(121)\n", 233 | " plt.plot(ep_rewards)\n", 234 | " plt.title('Reward')\n", 235 | " plt.subplot(122)\n", 236 | " plt.plot(losses)\n", 237 | " plt.title('Loss')\n", 238 | " plt.show()" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "## Main" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": { 252 | "colab": { 253 | "base_uri": "https://localhost:8080/", 254 | "height": 3377 255 | }, 256 | "colab_type": "code", 257 | "executionInfo": { 258 | "elapsed": 135196, 259 | "status": "ok", 260 | "timestamp": 
1534482559393, 261 | "user": { 262 | "displayName": "윤승제", 263 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg", 264 | "userId": "105654037995838004821" 265 | }, 266 | "user_tz": -540 267 | }, 268 | "id": "PnifSBJglzHh", 269 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a", 270 | "scrolled": true 271 | }, 272 | "outputs": [ 273 | { 274 | "name": "stderr", 275 | "output_type": "stream", 276 | "text": [ 277 | "/home/jay/anaconda3/lib/python3.7/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", 278 | " result = entry_point.load(False)\n" 279 | ] 280 | } 281 | ], 282 | "source": [ 283 | "# make an environment\n", 284 | "env = gym.make('CartPole-v0')\n", 285 | "# env = gym.make('CartPole-v1')\n", 286 | "# env = gym.make('MountainCar-v0')\n", 287 | "# env = gym.make('LunarLander-v2')\n", 288 | "\n", 289 | "env.seed(SEED)\n", 290 | "obs_space = env.observation_space.shape[0]\n", 291 | "action_space = env.action_space.n\n", 292 | "\n", 293 | "# hyperparameter\n", 294 | "n_episodes = 10000\n", 295 | "learn_start = 1500\n", 296 | "memory_size = 50000\n", 297 | "update_frq = 1\n", 298 | "use_eps_decay = False\n", 299 | "n_eval = env.spec.trials\n", 300 | "\n", 301 | "# global values\n", 302 | "total_steps = 0\n", 303 | "learn_steps = 0\n", 304 | "rewards = []\n", 305 | "reward_eval = deque(maxlen=n_eval)\n", 306 | "is_learned = False\n", 307 | "is_solved = False\n", 308 | "\n", 309 | "# make two nerual networks\n", 310 | "q_net = QNet(obs_space, action_space).to(device)\n", 311 | "p_net = PolicyNet(obs_space, action_space).to(device)\n", 312 | "v_net = ValueNet(obs_space).to(device)\n", 313 | "v_tgt = deepcopy(v_net)\n", 314 | "\n", 315 | "# make optimizer\n", 316 | "total_params = list(q_net.parameters()) + list(p_net.parameters()) + list(v_net.parameters())\n", 317 | "optimizer = optim.Adam(total_params, lr=LR, eps=1e-5)\n", 318 | "\n", 319 | "# make a memory\n", 320 | "rep_memory = deque(maxlen=memory_size)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "data": { 330 | "text/plain": [ 331 | "200" 332 | ] 333 | }, 334 | "execution_count": 6, 335 | "metadata": {}, 336 | "output_type": "execute_result" 337 | } 338 | ], 339 | "source": [ 340 | "env.spec.max_episode_steps" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "data": { 350 | "text/plain": [ 351 | "100" 352 | ] 353 | }, 354 | "execution_count": 7, 355 | "metadata": {}, 356 | "output_type": "execute_result" 357 | } 358 | ], 359 | "source": [ 360 | "env.spec.trials" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "text/plain": [ 371 | "195.0" 372 | ] 373 | }, 374 | "execution_count": 8, 375 | "metadata": {}, 376 | "output_type": "execute_result" 377 | } 378 | ], 379 | "source": [ 380 | "env.spec.reward_threshold" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": { 387 | "colab": { 388 | "base_uri": "https://localhost:8080/", 389 | "height": 3377 390 | }, 391 | "colab_type": "code", 392 | "executionInfo": { 393 | "elapsed": 135196, 394 | "status": "ok", 395 | "timestamp": 1534482559393, 396 | "user": { 397 | "displayName": "윤승제", 398 | 
"photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg", 399 | "userId": "105654037995838004821" 400 | }, 401 | "user_tz": -540 402 | }, 403 | "id": "PnifSBJglzHh", 404 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a", 405 | "scrolled": true 406 | }, 407 | "outputs": [ 408 | { 409 | "name": "stdout", 410 | "output_type": "stream", 411 | "text": [ 412 | " 1 Episode in 15 steps, reward 15.00\n", 413 | " 2 Episode in 41 steps, reward 26.00\n", 414 | " 3 Episode in 51 steps, reward 10.00\n", 415 | " 4 Episode in 62 steps, reward 11.00\n", 416 | " 5 Episode in 72 steps, reward 10.00\n", 417 | " 6 Episode in 105 steps, reward 33.00\n", 418 | " 7 Episode in 151 steps, reward 46.00\n", 419 | " 8 Episode in 163 steps, reward 12.00\n", 420 | " 9 Episode in 189 steps, reward 26.00\n", 421 | " 10 Episode in 206 steps, reward 17.00\n", 422 | " 11 Episode in 224 steps, reward 18.00\n", 423 | " 12 Episode in 242 steps, reward 18.00\n", 424 | " 13 Episode in 267 steps, reward 25.00\n", 425 | " 14 Episode in 282 steps, reward 15.00\n", 426 | " 15 Episode in 295 steps, reward 13.00\n", 427 | " 16 Episode in 307 steps, reward 12.00\n", 428 | " 17 Episode in 324 steps, reward 17.00\n", 429 | " 18 Episode in 344 steps, reward 20.00\n", 430 | " 19 Episode in 364 steps, reward 20.00\n", 431 | " 20 Episode in 383 steps, reward 19.00\n", 432 | " 21 Episode in 430 steps, reward 47.00\n", 433 | " 22 Episode in 452 steps, reward 22.00\n", 434 | " 23 Episode in 469 steps, reward 17.00\n", 435 | " 24 Episode in 488 steps, reward 19.00\n", 436 | " 25 Episode in 511 steps, reward 23.00\n", 437 | " 26 Episode in 524 steps, reward 13.00\n", 438 | " 27 Episode in 546 steps, reward 22.00\n", 439 | " 28 Episode in 559 steps, reward 13.00\n", 440 | " 29 Episode in 577 steps, reward 18.00\n", 441 | " 30 Episode in 591 steps, reward 14.00\n", 442 | " 31 Episode in 609 steps, reward 18.00\n", 443 | " 32 Episode in 630 steps, reward 21.00\n", 444 | " 33 Episode in 643 steps, reward 13.00\n", 445 | " 34 Episode in 703 steps, reward 60.00\n", 446 | " 35 Episode in 717 steps, reward 14.00\n", 447 | " 36 Episode in 771 steps, reward 54.00\n", 448 | " 37 Episode in 793 steps, reward 22.00\n", 449 | " 38 Episode in 818 steps, reward 25.00\n", 450 | " 39 Episode in 834 steps, reward 16.00\n", 451 | " 40 Episode in 849 steps, reward 15.00\n", 452 | " 41 Episode in 869 steps, reward 20.00\n", 453 | " 42 Episode in 898 steps, reward 29.00\n", 454 | " 43 Episode in 913 steps, reward 15.00\n", 455 | " 44 Episode in 923 steps, reward 10.00\n", 456 | " 45 Episode in 957 steps, reward 34.00\n", 457 | " 46 Episode in 975 steps, reward 18.00\n", 458 | " 47 Episode in 1001 steps, reward 26.00\n", 459 | " 48 Episode in 1011 steps, reward 10.00\n", 460 | " 49 Episode in 1028 steps, reward 17.00\n", 461 | " 50 Episode in 1049 steps, reward 21.00\n", 462 | " 51 Episode in 1060 steps, reward 11.00\n", 463 | " 52 Episode in 1079 steps, reward 19.00\n", 464 | " 53 Episode in 1089 steps, reward 10.00\n", 465 | " 54 Episode in 1103 steps, reward 14.00\n", 466 | " 55 Episode in 1117 steps, reward 14.00\n", 467 | " 56 Episode in 1130 steps, reward 13.00\n", 468 | " 57 Episode in 1151 steps, reward 21.00\n", 469 | " 58 Episode in 1178 steps, reward 27.00\n", 470 | " 59 Episode in 1200 steps, reward 22.00\n", 471 | " 60 Episode in 1218 steps, reward 18.00\n", 472 | " 61 Episode in 1229 steps, reward 11.00\n", 473 | " 62 Episode in 1241 steps, reward 12.00\n", 474 | " 63 Episode in 1263 steps, reward 
22.00\n", 475 | " 64 Episode in 1277 steps, reward 14.00\n", 476 | " 65 Episode in 1300 steps, reward 23.00\n", 477 | " 66 Episode in 1340 steps, reward 40.00\n", 478 | " 67 Episode in 1352 steps, reward 12.00\n", 479 | " 68 Episode in 1372 steps, reward 20.00\n", 480 | " 69 Episode in 1388 steps, reward 16.00\n", 481 | " 70 Episode in 1405 steps, reward 17.00\n", 482 | " 71 Episode in 1420 steps, reward 15.00\n", 483 | " 72 Episode in 1459 steps, reward 39.00\n", 484 | " 73 Episode in 1469 steps, reward 10.00\n", 485 | " 74 Episode in 1480 steps, reward 11.00\n", 486 | " 75 Episode in 1497 steps, reward 17.00\n", 487 | "\n", 488 | "============ Start Learning ============\n", 489 | "\n", 490 | " 76 Episode in 1513 steps, reward 16.00\n", 491 | " 77 Episode in 1524 steps, reward 11.00\n", 492 | " 78 Episode in 1544 steps, reward 20.00\n", 493 | " 79 Episode in 1556 steps, reward 12.00\n", 494 | " 80 Episode in 1568 steps, reward 12.00\n", 495 | " 81 Episode in 1591 steps, reward 23.00\n", 496 | " 82 Episode in 1609 steps, reward 18.00\n", 497 | " 83 Episode in 1622 steps, reward 13.00\n", 498 | " 84 Episode in 1634 steps, reward 12.00\n", 499 | " 85 Episode in 1644 steps, reward 10.00\n", 500 | " 86 Episode in 1654 steps, reward 10.00\n", 501 | " 87 Episode in 1664 steps, reward 10.00\n", 502 | " 88 Episode in 1678 steps, reward 14.00\n", 503 | " 89 Episode in 1688 steps, reward 10.00\n", 504 | " 90 Episode in 1698 steps, reward 10.00\n", 505 | " 91 Episode in 1708 steps, reward 10.00\n", 506 | " 92 Episode in 1718 steps, reward 10.00\n", 507 | " 93 Episode in 1728 steps, reward 10.00\n", 508 | " 94 Episode in 1737 steps, reward 9.00\n", 509 | " 95 Episode in 1746 steps, reward 9.00\n", 510 | " 96 Episode in 1756 steps, reward 10.00\n", 511 | " 97 Episode in 1766 steps, reward 10.00\n", 512 | " 98 Episode in 1776 steps, reward 10.00\n", 513 | " 99 Episode in 1787 steps, reward 11.00\n", 514 | "100 Episode in 1801 steps, reward 14.00\n", 515 | "101 Episode in 1824 steps, reward 23.00\n", 516 | "102 Episode in 1840 steps, reward 16.00\n", 517 | "103 Episode in 1852 steps, reward 12.00\n", 518 | "104 Episode in 1865 steps, reward 13.00\n", 519 | "105 Episode in 1877 steps, reward 12.00\n", 520 | "106 Episode in 1894 steps, reward 17.00\n", 521 | "107 Episode in 1918 steps, reward 24.00\n", 522 | "108 Episode in 1937 steps, reward 19.00\n", 523 | "109 Episode in 1963 steps, reward 26.00\n", 524 | "110 Episode in 2005 steps, reward 42.00\n", 525 | "111 Episode in 2039 steps, reward 34.00\n", 526 | "112 Episode in 2065 steps, reward 26.00\n", 527 | "113 Episode in 2087 steps, reward 22.00\n", 528 | "114 Episode in 2107 steps, reward 20.00\n", 529 | "115 Episode in 2132 steps, reward 25.00\n", 530 | "116 Episode in 2179 steps, reward 47.00\n", 531 | "117 Episode in 2201 steps, reward 22.00\n", 532 | "118 Episode in 2231 steps, reward 30.00\n", 533 | "119 Episode in 2282 steps, reward 51.00\n", 534 | "120 Episode in 2344 steps, reward 62.00\n", 535 | "121 Episode in 2365 steps, reward 21.00\n", 536 | "122 Episode in 2378 steps, reward 13.00\n", 537 | "123 Episode in 2388 steps, reward 10.00\n", 538 | "124 Episode in 2396 steps, reward 8.00\n", 539 | "125 Episode in 2406 steps, reward 10.00\n", 540 | "126 Episode in 2414 steps, reward 8.00\n", 541 | "127 Episode in 2423 steps, reward 9.00\n", 542 | "128 Episode in 2434 steps, reward 11.00\n", 543 | "129 Episode in 2447 steps, reward 13.00\n", 544 | "130 Episode in 2457 steps, reward 10.00\n", 545 | "131 Episode in 2469 steps, reward 12.00\n", 546 | 
"132 Episode in 2479 steps, reward 10.00\n", 547 | "133 Episode in 2494 steps, reward 15.00\n", 548 | "134 Episode in 2517 steps, reward 23.00\n", 549 | "135 Episode in 2537 steps, reward 20.00\n", 550 | "136 Episode in 2562 steps, reward 25.00\n", 551 | "137 Episode in 2580 steps, reward 18.00\n", 552 | "138 Episode in 2594 steps, reward 14.00\n", 553 | "139 Episode in 2606 steps, reward 12.00\n", 554 | "140 Episode in 2618 steps, reward 12.00\n", 555 | "141 Episode in 2631 steps, reward 13.00\n", 556 | "142 Episode in 2647 steps, reward 16.00\n", 557 | "143 Episode in 2663 steps, reward 16.00\n", 558 | "144 Episode in 2678 steps, reward 15.00\n", 559 | "145 Episode in 2688 steps, reward 10.00\n", 560 | "146 Episode in 2698 steps, reward 10.00\n", 561 | "147 Episode in 2707 steps, reward 9.00\n", 562 | "148 Episode in 2719 steps, reward 12.00\n", 563 | "149 Episode in 2734 steps, reward 15.00\n", 564 | "150 Episode in 2753 steps, reward 19.00\n", 565 | "151 Episode in 2772 steps, reward 19.00\n", 566 | "152 Episode in 2791 steps, reward 19.00\n", 567 | "153 Episode in 2806 steps, reward 15.00\n", 568 | "154 Episode in 2817 steps, reward 11.00\n", 569 | "155 Episode in 2830 steps, reward 13.00\n", 570 | "156 Episode in 2841 steps, reward 11.00\n", 571 | "157 Episode in 2852 steps, reward 11.00\n", 572 | "158 Episode in 2869 steps, reward 17.00\n", 573 | "159 Episode in 2889 steps, reward 20.00\n", 574 | "160 Episode in 2907 steps, reward 18.00\n", 575 | "161 Episode in 2925 steps, reward 18.00\n", 576 | "162 Episode in 2937 steps, reward 12.00\n", 577 | "163 Episode in 2948 steps, reward 11.00\n", 578 | "164 Episode in 2958 steps, reward 10.00\n", 579 | "165 Episode in 2967 steps, reward 9.00\n", 580 | "166 Episode in 2977 steps, reward 10.00\n", 581 | "167 Episode in 2987 steps, reward 10.00\n", 582 | "168 Episode in 2997 steps, reward 10.00\n", 583 | "169 Episode in 3007 steps, reward 10.00\n", 584 | "170 Episode in 3016 steps, reward 9.00\n", 585 | "171 Episode in 3026 steps, reward 10.00\n", 586 | "172 Episode in 3036 steps, reward 10.00\n", 587 | "173 Episode in 3045 steps, reward 9.00\n", 588 | "174 Episode in 3054 steps, reward 9.00\n", 589 | "175 Episode in 3063 steps, reward 9.00\n", 590 | "176 Episode in 3074 steps, reward 11.00\n", 591 | "177 Episode in 3082 steps, reward 8.00\n", 592 | "178 Episode in 3093 steps, reward 11.00\n", 593 | "179 Episode in 3103 steps, reward 10.00\n", 594 | "180 Episode in 3114 steps, reward 11.00\n", 595 | "181 Episode in 3124 steps, reward 10.00\n", 596 | "182 Episode in 3137 steps, reward 13.00\n", 597 | "183 Episode in 3153 steps, reward 16.00\n", 598 | "184 Episode in 3218 steps, reward 65.00\n", 599 | "185 Episode in 3230 steps, reward 12.00\n", 600 | "186 Episode in 3241 steps, reward 11.00\n", 601 | "187 Episode in 3250 steps, reward 9.00\n", 602 | "188 Episode in 3260 steps, reward 10.00\n", 603 | "189 Episode in 3270 steps, reward 10.00\n", 604 | "190 Episode in 3281 steps, reward 11.00\n", 605 | "191 Episode in 3289 steps, reward 8.00\n", 606 | "192 Episode in 3298 steps, reward 9.00\n", 607 | "193 Episode in 3308 steps, reward 10.00\n", 608 | "194 Episode in 3318 steps, reward 10.00\n", 609 | "195 Episode in 3327 steps, reward 9.00\n", 610 | "196 Episode in 3337 steps, reward 10.00\n", 611 | "197 Episode in 3345 steps, reward 8.00\n", 612 | "198 Episode in 3354 steps, reward 9.00\n", 613 | "199 Episode in 3363 steps, reward 9.00\n", 614 | "200 Episode in 3373 steps, reward 10.00\n", 615 | "201 Episode in 3383 steps, reward 10.00\n" 616 | 
] 617 | }, 618 | { 619 | "name": "stdout", 620 | "output_type": "stream", 621 | "text": [ 622 | "202 Episode in 3392 steps, reward 9.00\n", 623 | "203 Episode in 3400 steps, reward 8.00\n", 624 | "204 Episode in 3408 steps, reward 8.00\n", 625 | "205 Episode in 3418 steps, reward 10.00\n", 626 | "206 Episode in 3427 steps, reward 9.00\n", 627 | "207 Episode in 3436 steps, reward 9.00\n", 628 | "208 Episode in 3446 steps, reward 10.00\n", 629 | "209 Episode in 3454 steps, reward 8.00\n", 630 | "210 Episode in 3463 steps, reward 9.00\n", 631 | "211 Episode in 3473 steps, reward 10.00\n", 632 | "212 Episode in 3481 steps, reward 8.00\n", 633 | "213 Episode in 3491 steps, reward 10.00\n", 634 | "214 Episode in 3499 steps, reward 8.00\n", 635 | "215 Episode in 3508 steps, reward 9.00\n", 636 | "216 Episode in 3518 steps, reward 10.00\n", 637 | "217 Episode in 3528 steps, reward 10.00\n", 638 | "218 Episode in 3536 steps, reward 8.00\n", 639 | "219 Episode in 3545 steps, reward 9.00\n", 640 | "220 Episode in 3555 steps, reward 10.00\n", 641 | "221 Episode in 3563 steps, reward 8.00\n", 642 | "222 Episode in 3572 steps, reward 9.00\n", 643 | "223 Episode in 3580 steps, reward 8.00\n", 644 | "224 Episode in 3590 steps, reward 10.00\n", 645 | "225 Episode in 3600 steps, reward 10.00\n", 646 | "226 Episode in 3610 steps, reward 10.00\n", 647 | "227 Episode in 3620 steps, reward 10.00\n", 648 | "228 Episode in 3629 steps, reward 9.00\n", 649 | "229 Episode in 3638 steps, reward 9.00\n", 650 | "230 Episode in 3646 steps, reward 8.00\n", 651 | "231 Episode in 3655 steps, reward 9.00\n", 652 | "232 Episode in 3664 steps, reward 9.00\n", 653 | "233 Episode in 3673 steps, reward 9.00\n", 654 | "234 Episode in 3682 steps, reward 9.00\n", 655 | "235 Episode in 3691 steps, reward 9.00\n", 656 | "236 Episode in 3701 steps, reward 10.00\n", 657 | "237 Episode in 3711 steps, reward 10.00\n", 658 | "238 Episode in 3719 steps, reward 8.00\n", 659 | "239 Episode in 3729 steps, reward 10.00\n", 660 | "240 Episode in 3739 steps, reward 10.00\n", 661 | "241 Episode in 3749 steps, reward 10.00\n", 662 | "242 Episode in 3759 steps, reward 10.00\n", 663 | "243 Episode in 3768 steps, reward 9.00\n", 664 | "244 Episode in 3777 steps, reward 9.00\n", 665 | "245 Episode in 3787 steps, reward 10.00\n", 666 | "246 Episode in 3796 steps, reward 9.00\n", 667 | "247 Episode in 3804 steps, reward 8.00\n", 668 | "248 Episode in 3813 steps, reward 9.00\n", 669 | "249 Episode in 3823 steps, reward 10.00\n", 670 | "250 Episode in 3832 steps, reward 9.00\n", 671 | "251 Episode in 3841 steps, reward 9.00\n", 672 | "252 Episode in 3851 steps, reward 10.00\n", 673 | "253 Episode in 3860 steps, reward 9.00\n", 674 | "254 Episode in 3869 steps, reward 9.00\n", 675 | "255 Episode in 3878 steps, reward 9.00\n", 676 | "256 Episode in 3887 steps, reward 9.00\n", 677 | "257 Episode in 3897 steps, reward 10.00\n", 678 | "258 Episode in 3906 steps, reward 9.00\n", 679 | "259 Episode in 3914 steps, reward 8.00\n", 680 | "260 Episode in 3922 steps, reward 8.00\n", 681 | "261 Episode in 3931 steps, reward 9.00\n", 682 | "262 Episode in 3940 steps, reward 9.00\n", 683 | "263 Episode in 3950 steps, reward 10.00\n", 684 | "264 Episode in 3959 steps, reward 9.00\n", 685 | "265 Episode in 3968 steps, reward 9.00\n", 686 | "266 Episode in 3976 steps, reward 8.00\n", 687 | "267 Episode in 3986 steps, reward 10.00\n", 688 | "268 Episode in 3995 steps, reward 9.00\n", 689 | "269 Episode in 4004 steps, reward 9.00\n", 690 | "270 Episode in 4014 steps, reward 
10.00\n", 691 | "271 Episode in 4024 steps, reward 10.00\n", 692 | "272 Episode in 4033 steps, reward 9.00\n", 693 | "273 Episode in 4042 steps, reward 9.00\n" 694 | ] 695 | } 696 | ], 697 | "source": [ 698 | "# play\n", 699 | "for i in range(1, n_episodes + 1):\n", 700 | " obs = env.reset()\n", 701 | " done = False\n", 702 | " ep_reward = 0\n", 703 | " while not done:\n", 704 | "# env.render()\n", 705 | " action = select_action(obs, p_net)\n", 706 | " _obs, reward, done, _ = env.step(action)\n", 707 | " rep_memory.append((obs, action, reward, _obs, done))\n", 708 | "\n", 709 | " obs = _obs\n", 710 | " total_steps += 1\n", 711 | " ep_reward += reward\n", 712 | "\n", 713 | " if len(rep_memory) >= learn_start:\n", 714 | " if len(rep_memory) == learn_start:\n", 715 | " print('\\n============ Start Learning ============\\n')\n", 716 | " learn(q_net, p_net, v_net, v_tgt, optimizer, rep_memory)\n", 717 | " learn_steps += 1\n", 718 | "\n", 719 | " if learn_steps == update_frq:\n", 720 | " # target smoothing update\n", 721 | " with torch.no_grad():\n", 722 | " for t, n in zip(v_tgt.parameters(), v_net.parameters()):\n", 723 | " t.data = UP_COEF * n.data + (1 - UP_COEF) * t.data\n", 724 | " learn_steps = 0\n", 725 | " if done:\n", 726 | " rewards.append(ep_reward)\n", 727 | " reward_eval.append(ep_reward)\n", 728 | " plot()\n", 729 | "# print('{:3} Episode in {:5} steps, reward {:.2f}'.format(\n", 730 | "# i, total_steps, ep_reward))\n", 731 | "\n", 732 | " if len(reward_eval) >= n_eval:\n", 733 | " if np.mean(reward_eval) >= env.spec.reward_threshold:\n", 734 | " print('\\n{} is sloved! {:3} Episode in {:3} steps'.format(\n", 735 | " env.spec.id, i, total_steps))\n", 736 | " torch.save(target_net.state_dict(),\n", 737 | " f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_sac.pt')\n", 738 | " break\n", 739 | "env.close()" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": null, 745 | "metadata": { 746 | "scrolled": false 747 | }, 748 | "outputs": [], 749 | "source": [ 750 | "plt.figure(figsize=(15, 5))\n", 751 | "plt.title('Reward')\n", 752 | "plt.plot(rewards)\n", 753 | "plt.figure(figsize=(15, 5))\n", 754 | "plt.title('Loss')\n", 755 | "plt.plot(losses)\n", 756 | "plt.show()" 757 | ] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "execution_count": null, 762 | "metadata": {}, 763 | "outputs": [], 764 | "source": [ 765 | "[\n", 766 | " ('CartPole-v0', 299, 0.25),\n", 767 | " ('CartPole-v1', 413, 0.025),\n", 768 | " ('MountainCar-v0', None ,0.05)\n", 769 | "]" 770 | ] 771 | } 772 | ], 773 | "metadata": { 774 | "colab": { 775 | "collapsed_sections": [], 776 | "name": "C51_tensorflow.ipynb", 777 | "provenance": [], 778 | "version": "0.3.2" 779 | }, 780 | "kernelspec": { 781 | "display_name": "Python 3", 782 | "language": "python", 783 | "name": "python3" 784 | }, 785 | "language_info": { 786 | "codemirror_mode": { 787 | "name": "ipython", 788 | "version": 3 789 | }, 790 | "file_extension": ".py", 791 | "mimetype": "text/x-python", 792 | "name": "python", 793 | "nbconvert_exporter": "python", 794 | "pygments_lexer": "ipython3", 795 | "version": "3.7.0" 796 | } 797 | }, 798 | "nbformat": 4, 799 | "nbformat_minor": 1 800 | } 801 | -------------------------------------------------------------------------------- /src/test/running_mean_std.py: -------------------------------------------------------------------------------- 1 | # https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py 2 | import numpy as np 3 | 4 | 5 | class 
RunningMeanStd(object): 6 | def __init__(self, epsilon=1e-4, shape=()): 7 | self.mean = np.zeros(shape, 'float64') 8 | self.var = np.ones(shape, 'float64') 9 | self.count = epsilon 10 | 11 | def update(self, x): 12 | batch_mean = np.mean(x, axis=0) 13 | batch_var = np.var(x, axis=0) 14 | batch_count = x.shape[0] 15 | self.update_from_moments(batch_mean, batch_var, batch_count) 16 | 17 | def update_from_moments(self, batch_mean, batch_var, batch_count): 18 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 19 | self.mean, self.var, self.count, 20 | batch_mean, batch_var, batch_count) 21 | 22 | 23 | def update_mean_var_count_from_moments(mean, var, count, 24 | batch_mean, batch_var, batch_count): 25 | delta = batch_mean - mean 26 | tot_count = count + batch_count 27 | 28 | new_mean = mean + delta * batch_count / tot_count 29 | m_a = var * count 30 | m_b = batch_var * batch_count 31 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 32 | new_var = M2 / tot_count 33 | new_count = tot_count 34 | 35 | return new_mean, new_var, new_count 36 | -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v0_ep179_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v0_ep179_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v0_ep87_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v0_ep87_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep108_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep108_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep112_clear_model_dddqn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep112_clear_model_dddqn.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep1150_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep1150_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep118_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep118_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep1338_clear_model_ppo_st.pt: 
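
Note on the RunningMeanStd helper above: update_mean_var_count_from_moments is the exact parallel-moments combination — with delta = batch_mean - mean and n = count + batch_count, the new mean is mean + delta * batch_count / n, the combined M2 is var*count + batch_var*batch_count + delta^2 * count * batch_count / n, and new_var = M2 / n. The sketch below is not part of the repository; the import path, toy data, and clipping constants are illustrative assumptions, and it presumes the class is importable or already defined in the session.

import numpy as np
# from running_mean_std import RunningMeanStd   # assumed import path; adjust as needed

rms = RunningMeanStd(shape=(4,))                 # e.g. a CartPole-sized observation vector
data = np.random.randn(1000, 4) * 3.0 + 5.0      # toy observations with non-zero mean/variance
for batch in np.array_split(data, 10):           # stream the data in mini-batches
    rms.update(batch)

# up to the tiny epsilon prior on count, the streaming moments match the full-batch statistics
assert np.allclose(rms.mean, data.mean(axis=0), atol=1e-3)
assert np.allclose(rms.var, data.var(axis=0), atol=1e-3)

# one common use (cf. the *_norm_obs.pkl files saved alongside the PPO checkpoints):
# normalize and clip observations before feeding them to a policy network
obs = data[0]
norm_obs = np.clip((obs - rms.mean) / np.sqrt(rms.var + 1e-8), -5.0, 5.0)
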
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep1338_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep134_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep134_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep164_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep164_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep168_clear_model_dddqn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep168_clear_model_dddqn.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep1715_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep1715_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep213_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep213_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep216_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep216_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep222_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep222_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep223_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep223_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep258_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep258_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep270_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep270_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep273_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep273_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep294_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep294_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep295_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep295_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep296_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep296_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep308_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep308_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep313_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep313_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep320_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep320_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep326_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep326_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep329_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep329_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep357_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep357_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep376_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep376_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep385_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep385_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep388_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep388_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep408_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep408_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep417_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep417_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep418_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep418_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep420_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep420_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep424_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep424_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep433_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep433_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep436_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep436_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep445_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep445_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep454_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep454_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep456_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep456_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep458_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep458_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep463_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep463_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep470_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep470_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep474_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep474_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep475_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep475_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep492_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep492_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep498_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep498_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep504_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep504_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep506_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep506_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep509_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep509_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep516_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep516_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep519_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep519_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep530_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep530_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep534_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep534_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep546_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep546_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep552_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep552_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep555_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep555_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep559_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep559_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep566_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep566_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep569_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep569_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep586_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep586_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep589_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep589_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep595_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep595_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep631_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep631_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep635_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep635_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep637_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep637_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep642_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep642_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep675_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep675_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep690_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep690_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep727_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep727_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep741_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep741_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep763_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep763_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep840_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep840_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep845_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep845_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep866_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep866_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep871_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep871_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep892_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep892_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep903_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep903_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep924_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep924_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_ep939_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep939_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_up50_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_up50_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/CartPole-v1_up50_clear_norm_obs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_up50_clear_norm_obs.pkl -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_ep260_clear_model_dqn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_ep260_clear_model_dqn.pt -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_ep370_clear_model_dddqn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_ep370_clear_model_dddqn.pt -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_ep8461_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_ep8461_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_ep876_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_ep876_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_up1099_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up1099_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_up1317_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up1317_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_up254_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up254_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_up570_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up570_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/LunarLander-v2_up75_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up75_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/MountainCar-v0_ep304_clear_model_dqn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_ep304_clear_model_dqn.pt -------------------------------------------------------------------------------- /src/test/saved_models/MountainCar-v0_ep532_clear_model_dddqn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_ep532_clear_model_dddqn.pt -------------------------------------------------------------------------------- /src/test/saved_models/MountainCar-v0_ep984_clear_model_dddqn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_ep984_clear_model_dddqn.pt -------------------------------------------------------------------------------- /src/test/saved_models/MountainCar-v0_up1441_clear_model_ppo_st.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_up1441_clear_model_ppo_st.pt -------------------------------------------------------------------------------- /src/test/saved_models/MountainCar-v0_up1441_clear_norm_obs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_up1441_clear_norm_obs.pkl -------------------------------------------------------------------------------- /src/test/test_dueling_double_dqn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | 
"metadata": { 7 | "colab": {}, 8 | "colab_type": "code", 9 | "id": "IWnm3qot3o1W" 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "from collections import deque\n", 14 | "\n", 15 | "import gym\n", 16 | "import imageio\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import numpy as np\n", 19 | "import torch\n", 20 | "import torch.nn as nn\n", 21 | "from torch.nn import functional as F\n", 22 | "from torch.distributions import Categorical" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": { 29 | "colab": {}, 30 | "colab_type": "code", 31 | "id": "9Ffkl_5C4R81" 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "class DuelingDQN(nn.Module):\n", 36 | " def __init__(self, obs_space, action_space):\n", 37 | " super().__init__()\n", 38 | "\n", 39 | " self.head = nn.Sequential(\n", 40 | " nn.Linear(obs_space, obs_space*10),\n", 41 | " nn.SELU()\n", 42 | " )\n", 43 | "\n", 44 | " self.val = nn.Sequential(\n", 45 | " nn.Linear(obs_space*10, 512),\n", 46 | " nn.SELU(),\n", 47 | " nn.Linear(512, 512),\n", 48 | " nn.SELU(),\n", 49 | " nn.Linear(512, 1)\n", 50 | " )\n", 51 | "\n", 52 | " self.adv = nn.Sequential(\n", 53 | " nn.Linear(obs_space*10, 512),\n", 54 | " nn.SELU(),\n", 55 | " nn.Linear(512, 512),\n", 56 | " nn.SELU(),\n", 57 | " nn.Linear(512, action_space)\n", 58 | " )\n", 59 | "\n", 60 | " def forward(self, x):\n", 61 | " out = self.head(x)\n", 62 | " val_out = self.val(out).reshape(out.shape[0], 1)\n", 63 | " adv_out = self.adv(out).reshape(out.shape[0], -1)\n", 64 | " adv_mean = adv_out.mean(dim=1, keepdim=True)\n", 65 | " q = val_out + adv_out - adv_mean\n", 66 | "\n", 67 | " return q" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "def select_action_(obs, tgt_net):\n", 77 | " tgt_net.eval()\n", 78 | " with torch.no_grad():\n", 79 | " state = torch.tensor([obs]).to(device).float()\n", 80 | " q = tgt_net(state)\n", 81 | " action = torch.argmax(q)\n", 82 | "\n", 83 | " return action.item()\n", 84 | "\n", 85 | "\n", 86 | "def select_action(obs, tgt_net):\n", 87 | " tgt_net.eval()\n", 88 | " with torch.no_grad():\n", 89 | " state = torch.tensor([obs]).to(device).float()\n", 90 | " q = tgt_net(state)\n", 91 | " probs = F.softmax(q/0.35)\n", 92 | " m = Categorical(probs)\n", 93 | " action = m.sample()\n", 94 | "\n", 95 | " return action.item()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## Main" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": { 109 | "colab": { 110 | "base_uri": "https://localhost:8080/", 111 | "height": 3377 112 | }, 113 | "colab_type": "code", 114 | "executionInfo": { 115 | "elapsed": 135196, 116 | "status": "ok", 117 | "timestamp": 1534482559393, 118 | "user": { 119 | "displayName": "윤승제", 120 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg", 121 | "userId": "105654037995838004821" 122 | }, 123 | "user_tz": -540 124 | }, 125 | "id": "PnifSBJglzHh", 126 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a", 127 | "scrolled": false 128 | }, 129 | "outputs": [ 130 | { 131 | "name": "stderr", 132 | "output_type": "stream", 133 | "text": [ 134 | "/home/jay/anaconda3/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. 
Call .resolve and .require separately.\n", 135 | " result = entry_point.load(False)\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "# set device\n", 141 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 142 | "\n", 143 | "# make an environment\n", 144 | "# env = gym.make('CartPole-v0')\n", 145 | "env = gym.make('CartPole-v1')\n", 146 | "# env = gym.make('MountainCar-v0')\n", 147 | "# env = gym.make('LunarLander-v2')\n", 148 | "\n", 149 | "SEED = 0\n", 150 | "env.seed(SEED)\n", 151 | "obs_space = env.observation_space.shape[0]\n", 152 | "action_space = env.action_space.n\n", 153 | "\n", 154 | "# hyperparameter\n", 155 | "n_episodes = 1000\n", 156 | "n_eval = env.spec.trials\n", 157 | "\n", 158 | "# global values\n", 159 | "total_steps = 0\n", 160 | "rewards = []\n", 161 | "reward_eval = deque(maxlen=n_eval)\n", 162 | "is_solved = False\n", 163 | "\n", 164 | "# load a model\n", 165 | "target_net = DuelingDQN(obs_space, action_space).to(device)\n", 166 | "target_net.load_state_dict(torch.load(\n", 167 | " './saved_models/CartPole-v1_ep217_clear_model_dddqn.pt'))" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 5, 173 | "metadata": { 174 | "scrolled": true 175 | }, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "text/plain": [ 180 | "500" 181 | ] 182 | }, 183 | "execution_count": 5, 184 | "metadata": {}, 185 | "output_type": "execute_result" 186 | } 187 | ], 188 | "source": [ 189 | "env.spec.max_episode_steps" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 6, 195 | "metadata": { 196 | "scrolled": true 197 | }, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "text/plain": [ 202 | "100" 203 | ] 204 | }, 205 | "execution_count": 6, 206 | "metadata": {}, 207 | "output_type": "execute_result" 208 | } 209 | ], 210 | "source": [ 211 | "env.spec.trials" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 7, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "data": { 221 | "text/plain": [ 222 | "475.0" 223 | ] 224 | }, 225 | "execution_count": 7, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "env.spec.reward_threshold" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 8, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "# env.metadata['video.frames_per_second'] = 60" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 9, 246 | "metadata": { 247 | "colab": { 248 | "base_uri": "https://localhost:8080/", 249 | "height": 3377 250 | }, 251 | "colab_type": "code", 252 | "executionInfo": { 253 | "elapsed": 135196, 254 | "status": "ok", 255 | "timestamp": 1534482559393, 256 | "user": { 257 | "displayName": "윤승제", 258 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg", 259 | "userId": "105654037995838004821" 260 | }, 261 | "user_tz": -540 262 | }, 263 | "id": "PnifSBJglzHh", 264 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a", 265 | "scrolled": true 266 | }, 267 | "outputs": [ 268 | { 269 | "name": "stderr", 270 | "output_type": "stream", 271 | "text": [ 272 | "/home/jay/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:16: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", 273 | " app.launch_new_instance()\n" 274 | ] 275 | }, 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | " 1 Episode in 500 steps, reward 500.00\n", 281 | " 2 Episode in 1000 steps, reward 500.00\n", 282 | " 3 Episode in 1500 steps, reward 500.00\n", 283 | " 4 Episode in 2000 steps, reward 500.00\n", 284 | " 5 Episode in 2500 steps, reward 500.00\n", 285 | " 6 Episode in 3000 steps, reward 500.00\n", 286 | " 7 Episode in 3500 steps, reward 500.00\n", 287 | " 8 Episode in 4000 steps, reward 500.00\n", 288 | " 9 Episode in 4500 steps, reward 500.00\n", 289 | " 10 Episode in 5000 steps, reward 500.00\n", 290 | " 11 Episode in 5500 steps, reward 500.00\n", 291 | " 12 Episode in 6000 steps, reward 500.00\n", 292 | " 13 Episode in 6500 steps, reward 500.00\n", 293 | " 14 Episode in 7000 steps, reward 500.00\n", 294 | " 15 Episode in 7500 steps, reward 500.00\n", 295 | " 16 Episode in 8000 steps, reward 500.00\n", 296 | " 17 Episode in 8500 steps, reward 500.00\n", 297 | " 18 Episode in 9000 steps, reward 500.00\n", 298 | " 19 Episode in 9500 steps, reward 500.00\n", 299 | " 20 Episode in 10000 steps, reward 500.00\n", 300 | " 21 Episode in 10500 steps, reward 500.00\n", 301 | " 22 Episode in 11000 steps, reward 500.00\n", 302 | " 23 Episode in 11500 steps, reward 500.00\n", 303 | " 24 Episode in 12000 steps, reward 500.00\n", 304 | " 25 Episode in 12500 steps, reward 500.00\n", 305 | " 26 Episode in 13000 steps, reward 500.00\n", 306 | " 27 Episode in 13500 steps, reward 500.00\n", 307 | " 28 Episode in 14000 steps, reward 500.00\n", 308 | " 29 Episode in 14500 steps, reward 500.00\n", 309 | " 30 Episode in 15000 steps, reward 500.00\n", 310 | " 31 Episode in 15500 steps, reward 500.00\n", 311 | " 32 Episode in 16000 steps, reward 500.00\n", 312 | " 33 Episode in 16500 steps, reward 500.00\n", 313 | " 34 Episode in 17000 steps, reward 500.00\n", 314 | " 35 Episode in 17500 steps, reward 500.00\n", 315 | " 36 Episode in 18000 steps, reward 500.00\n", 316 | " 37 Episode in 18500 steps, reward 500.00\n", 317 | " 38 Episode in 19000 steps, reward 500.00\n", 318 | " 39 Episode in 19500 steps, reward 500.00\n", 319 | " 40 Episode in 20000 steps, reward 500.00\n", 320 | " 41 Episode in 20500 steps, reward 500.00\n", 321 | " 42 Episode in 21000 steps, reward 500.00\n", 322 | " 43 Episode in 21500 steps, reward 500.00\n", 323 | " 44 Episode in 22000 steps, reward 500.00\n", 324 | " 45 Episode in 22500 steps, reward 500.00\n", 325 | " 46 Episode in 23000 steps, reward 500.00\n", 326 | " 47 Episode in 23500 steps, reward 500.00\n", 327 | " 48 Episode in 24000 steps, reward 500.00\n", 328 | " 49 Episode in 24500 steps, reward 500.00\n", 329 | " 50 Episode in 25000 steps, reward 500.00\n", 330 | " 51 Episode in 25500 steps, reward 500.00\n", 331 | " 52 Episode in 26000 steps, reward 500.00\n", 332 | " 53 Episode in 26500 steps, reward 500.00\n", 333 | " 54 Episode in 27000 steps, reward 500.00\n", 334 | " 55 Episode in 27500 steps, reward 500.00\n", 335 | " 56 Episode in 28000 steps, reward 500.00\n", 336 | " 57 Episode in 28500 steps, reward 500.00\n", 337 | " 58 Episode in 29000 steps, reward 500.00\n", 338 | " 59 Episode in 29500 steps, reward 500.00\n", 339 | " 60 Episode in 30000 steps, reward 500.00\n", 340 | " 61 Episode in 30500 steps, reward 500.00\n", 341 | " 62 Episode in 31000 steps, reward 500.00\n", 342 | " 63 Episode in 31500 steps, reward 500.00\n", 343 | " 64 Episode in 32000 steps, reward 500.00\n", 344 | 
" 65 Episode in 32500 steps, reward 500.00\n", 345 | " 66 Episode in 33000 steps, reward 500.00\n", 346 | " 67 Episode in 33500 steps, reward 500.00\n", 347 | " 68 Episode in 34000 steps, reward 500.00\n", 348 | " 69 Episode in 34500 steps, reward 500.00\n", 349 | " 70 Episode in 35000 steps, reward 500.00\n", 350 | " 71 Episode in 35500 steps, reward 500.00\n", 351 | " 72 Episode in 36000 steps, reward 500.00\n", 352 | " 73 Episode in 36500 steps, reward 500.00\n", 353 | " 74 Episode in 37000 steps, reward 500.00\n", 354 | " 75 Episode in 37500 steps, reward 500.00\n", 355 | " 76 Episode in 38000 steps, reward 500.00\n", 356 | " 77 Episode in 38500 steps, reward 500.00\n", 357 | " 78 Episode in 39000 steps, reward 500.00\n", 358 | " 79 Episode in 39500 steps, reward 500.00\n", 359 | " 80 Episode in 40000 steps, reward 500.00\n", 360 | " 81 Episode in 40500 steps, reward 500.00\n", 361 | " 82 Episode in 41000 steps, reward 500.00\n", 362 | " 83 Episode in 41500 steps, reward 500.00\n", 363 | " 84 Episode in 42000 steps, reward 500.00\n", 364 | " 85 Episode in 42500 steps, reward 500.00\n", 365 | " 86 Episode in 43000 steps, reward 500.00\n", 366 | " 87 Episode in 43500 steps, reward 500.00\n", 367 | " 88 Episode in 44000 steps, reward 500.00\n", 368 | " 89 Episode in 44500 steps, reward 500.00\n", 369 | " 90 Episode in 45000 steps, reward 500.00\n", 370 | " 91 Episode in 45500 steps, reward 500.00\n", 371 | " 92 Episode in 46000 steps, reward 500.00\n", 372 | " 93 Episode in 46500 steps, reward 500.00\n", 373 | " 94 Episode in 47000 steps, reward 500.00\n", 374 | " 95 Episode in 47500 steps, reward 500.00\n", 375 | " 96 Episode in 48000 steps, reward 500.00\n", 376 | " 97 Episode in 48500 steps, reward 500.00\n", 377 | " 98 Episode in 49000 steps, reward 500.00\n", 378 | " 99 Episode in 49500 steps, reward 500.00\n", 379 | "100 Episode in 50000 steps, reward 500.00\n", 380 | "\n", 381 | "CartPole-v1 is sloved! 100 Episode in 50000 steps\n", 382 | "Mean Reward: 500.0\n" 383 | ] 384 | } 385 | ], 386 | "source": [ 387 | "# play\n", 388 | "# frames = []\n", 389 | "for i in range(1, n_episodes + 1):\n", 390 | " obs = env.reset()\n", 391 | " done = False\n", 392 | " ep_reward = 0\n", 393 | " while not done:\n", 394 | "# frames.append(env.render(mode = 'rgb_array'))\n", 395 | " env.render()\n", 396 | " action = select_action(obs, target_net)\n", 397 | " _obs, reward, done, _ = env.step(action)\n", 398 | " obs = _obs\n", 399 | " total_steps += 1\n", 400 | " ep_reward += reward \n", 401 | " if done:\n", 402 | " env.render()\n", 403 | " rewards.append(ep_reward)\n", 404 | " reward_eval.append(ep_reward)\n", 405 | " print('{:3} Episode in {:5} steps, reward {:.2f}'.format(\n", 406 | " i, total_steps, ep_reward))\n", 407 | "# frames.append(env.render(mode = 'rgb_array'))\n", 408 | "# imageio.mimsave(f'{env.spec.id}.gif', frames,)\n", 409 | " \n", 410 | " if len(reward_eval) >= n_eval:\n", 411 | " if np.mean(reward_eval) >= env.spec.reward_threshold:\n", 412 | " print('\\n{} is sloved! 
{:3} Episode in {:3} steps'.format(\n", 413 | " env.spec.id, i, total_steps))\n", 414 | " print(f'Mean Reward: {np.mean(reward_eval).round(decimals=2)}')\n", 415 | " break\n", 416 | "env.close()" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 10, 422 | "metadata": { 423 | "scrolled": false 424 | }, 425 | "outputs": [ 426 | { 427 | "data": { 428 | "image/png": "[base64 PNG data omitted: reward curve plot]\n", 429 | "text/plain": [ 430 | "
" 431 | ] 432 | }, 433 | "metadata": { 434 | "needs_background": "light" 435 | }, 436 | "output_type": "display_data" 437 | } 438 | ], 439 | "source": [ 440 | "plt.figure(figsize=(15, 5))\n", 441 | "plt.title('reward')\n", 442 | "plt.plot(rewards)\n", 443 | "plt.show()" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 11, 449 | "metadata": {}, 450 | "outputs": [ 451 | { 452 | "data": { 453 | "text/plain": [ 454 | "[('CartPole-v0', 412, 1),\n", 455 | " ('CartPole-v1', 452, 0.05),\n", 456 | " ('MountainCar-v0', 193, 0.1),\n", 457 | " ('LunarLander-v2', 260, 0.1)]" 458 | ] 459 | }, 460 | "execution_count": 11, 461 | "metadata": {}, 462 | "output_type": "execute_result" 463 | } 464 | ], 465 | "source": [ 466 | "[\n", 467 | " ('CartPole-v0', 412, 1),\n", 468 | " ('CartPole-v1', 452, 0.05),\n", 469 | " ('MountainCar-v0', 193, 0.1),\n", 470 | " ('LunarLander-v2', 260, 0.1)\n", 471 | "]" 472 | ] 473 | } 474 | ], 475 | "metadata": { 476 | "colab": { 477 | "collapsed_sections": [], 478 | "name": "C51_tensorflow.ipynb", 479 | "provenance": [], 480 | "version": "0.3.2" 481 | }, 482 | "kernelspec": { 483 | "display_name": "Python 3", 484 | "language": "python", 485 | "name": "python3" 486 | }, 487 | "language_info": { 488 | "codemirror_mode": { 489 | "name": "ipython", 490 | "version": 3 491 | }, 492 | "file_extension": ".py", 493 | "mimetype": "text/x-python", 494 | "name": "python", 495 | "nbconvert_exporter": "python", 496 | "pygments_lexer": "ipython3", 497 | "version": "3.6.7" 498 | } 499 | }, 500 | "nbformat": 4, 501 | "nbformat_minor": 1 502 | } 503 | -------------------------------------------------------------------------------- /src/test/test_ppo_with_rnd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": {}, 8 | "colab_type": "code", 9 | "id": "IWnm3qot3o1W" 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "from collections import deque\n", 14 | "\n", 15 | "import gym\n", 16 | "import imageio\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import numpy as np\n", 19 | "import torch\n", 20 | "import torch.nn as nn\n", 21 | "from torch.distributions import Categorical " 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": { 28 | "colab": {}, 29 | "colab_type": "code", 30 | "id": "9Ffkl_5C4R81" 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "class ActorCriticNet(nn.Module):\n", 35 | " def __init__(self, obs_space, action_space):\n", 36 | " super().__init__()\n", 37 | "\n", 38 | " self.head = nn.Sequential(\n", 39 | " nn.Linear(obs_space, obs_space*10),\n", 40 | " nn.SELU()\n", 41 | " )\n", 42 | " self.pol = nn.Sequential(\n", 43 | " nn.Linear(obs_space*10, 512),\n", 44 | " nn.SELU(),\n", 45 | " nn.Linear(512, 512),\n", 46 | " nn.SELU(),\n", 47 | " nn.Linear(512, action_space)\n", 48 | " )\n", 49 | " self.val_ex = nn.Sequential(\n", 50 | " nn.Linear(obs_space*10, 512),\n", 51 | " nn.SELU(),\n", 52 | " nn.Linear(512, 512),\n", 53 | " nn.SELU(),\n", 54 | " nn.Linear(512, 1)\n", 55 | " )\n", 56 | " self.val_in = nn.Sequential(\n", 57 | " nn.Linear(obs_space*10, 512),\n", 58 | " nn.SELU(),\n", 59 | " nn.Linear(512, 512),\n", 60 | " nn.SELU(),\n", 61 | " nn.Linear(512, 1)\n", 62 | " )\n", 63 | " self.log_softmax = nn.LogSoftmax(dim=-1)\n", 64 | "\n", 65 | " def forward(self, x):\n", 66 | " out = self.head(x)\n", 67 | " logit = self.pol(out).reshape(out.shape[0], -1)\n", 68 | " 
value_ex = self.val_ex(out).reshape(out.shape[0], 1)\n", 69 | " value_in = self.val_in(out).reshape(out.shape[0], 1)\n", 70 | " log_probs = self.log_softmax(logit)\n", 71 | " \n", 72 | " return log_probs, value_ex, value_in" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "def get_action_and_value(obs, old_net):\n", 82 | " old_net.eval()\n", 83 | " with torch.no_grad():\n", 84 | " state = torch.tensor([obs]).to(device).float()\n", 85 | " log_p, _, _ = old_net(state)\n", 86 | " m = Categorical(log_p.exp())\n", 87 | " action = m.sample()\n", 88 | "\n", 89 | " return action.item()" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "## Main" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 4, 102 | "metadata": { 103 | "colab": { 104 | "base_uri": "https://localhost:8080/", 105 | "height": 3377 106 | }, 107 | "colab_type": "code", 108 | "executionInfo": { 109 | "elapsed": 135196, 110 | "status": "ok", 111 | "timestamp": 1534482559393, 112 | "user": { 113 | "displayName": "윤승제", 114 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg", 115 | "userId": "105654037995838004821" 116 | }, 117 | "user_tz": -540 118 | }, 119 | "id": "PnifSBJglzHh", 120 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a", 121 | "scrolled": false 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "# set device\n", 126 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 127 | "\n", 128 | "# make an environment\n", 129 | "# env = gym.make('CartPole-v0')\n", 130 | "env = gym.make('CartPole-v1')\n", 131 | "# env = gym.make('MountainCar-v0')\n", 132 | "# env = gym.make('LunarLander-v2')\n", 133 | "\n", 134 | "SEED = 0\n", 135 | "env.seed(SEED)\n", 136 | "obs_space = env.observation_space.shape[0]\n", 137 | "action_space = env.action_space.n\n", 138 | "\n", 139 | "# hyperparameter\n", 140 | "n_episodes = 1000\n", 141 | "n_eval = env.spec.trials\n", 142 | "\n", 143 | "# global values\n", 144 | "total_steps = 0\n", 145 | "rewards = []\n", 146 | "reward_eval = deque(maxlen=n_eval)\n", 147 | "is_solved = False\n", 148 | "\n", 149 | "# load a model\n", 150 | "target_net = ActorCriticNet(obs_space, action_space).to(device)\n", 151 | "target_net.load_state_dict(torch.load(\n", 152 | " './saved_models/CartPole-v1_ep225_clear_model_ppo.pt'))" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 5, 158 | "metadata": { 159 | "scrolled": true 160 | }, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "500" 166 | ] 167 | }, 168 | "execution_count": 5, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | } 172 | ], 173 | "source": [ 174 | "env.spec.max_episode_steps" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 6, 180 | "metadata": { 181 | "scrolled": true 182 | }, 183 | "outputs": [ 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "100" 188 | ] 189 | }, 190 | "execution_count": 6, 191 | "metadata": {}, 192 | "output_type": "execute_result" 193 | } 194 | ], 195 | "source": [ 196 | "env.spec.trials" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 7, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "text/plain": [ 207 | "475.0" 208 | ] 209 | }, 210 | "execution_count": 7, 211 | "metadata": {}, 212 | "output_type": "execute_result" 213 | } 214 | ], 215 | 
"source": [ 216 | "env.spec.reward_threshold" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 8, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "# env.metadata['video.frames_per_second'] = 60" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 9, 231 | "metadata": { 232 | "colab": { 233 | "base_uri": "https://localhost:8080/", 234 | "height": 3377 235 | }, 236 | "colab_type": "code", 237 | "executionInfo": { 238 | "elapsed": 135196, 239 | "status": "ok", 240 | "timestamp": 1534482559393, 241 | "user": { 242 | "displayName": "윤승제", 243 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg", 244 | "userId": "105654037995838004821" 245 | }, 246 | "user_tz": -540 247 | }, 248 | "id": "PnifSBJglzHh", 249 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a", 250 | "scrolled": true 251 | }, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | " 1 Episode in 500 steps, reward 500.00\n", 258 | " 2 Episode in 1000 steps, reward 500.00\n", 259 | " 3 Episode in 1500 steps, reward 500.00\n", 260 | " 4 Episode in 2000 steps, reward 500.00\n", 261 | " 5 Episode in 2500 steps, reward 500.00\n", 262 | " 6 Episode in 3000 steps, reward 500.00\n", 263 | " 7 Episode in 3500 steps, reward 500.00\n", 264 | " 8 Episode in 4000 steps, reward 500.00\n", 265 | " 9 Episode in 4500 steps, reward 500.00\n", 266 | " 10 Episode in 5000 steps, reward 500.00\n", 267 | " 11 Episode in 5500 steps, reward 500.00\n", 268 | " 12 Episode in 6000 steps, reward 500.00\n", 269 | " 13 Episode in 6500 steps, reward 500.00\n", 270 | " 14 Episode in 7000 steps, reward 500.00\n", 271 | " 15 Episode in 7500 steps, reward 500.00\n", 272 | " 16 Episode in 8000 steps, reward 500.00\n", 273 | " 17 Episode in 8500 steps, reward 500.00\n", 274 | " 18 Episode in 9000 steps, reward 500.00\n", 275 | " 19 Episode in 9500 steps, reward 500.00\n", 276 | " 20 Episode in 10000 steps, reward 500.00\n", 277 | " 21 Episode in 10500 steps, reward 500.00\n", 278 | " 22 Episode in 11000 steps, reward 500.00\n", 279 | " 23 Episode in 11500 steps, reward 500.00\n", 280 | " 24 Episode in 12000 steps, reward 500.00\n", 281 | " 25 Episode in 12500 steps, reward 500.00\n", 282 | " 26 Episode in 13000 steps, reward 500.00\n", 283 | " 27 Episode in 13500 steps, reward 500.00\n", 284 | " 28 Episode in 14000 steps, reward 500.00\n", 285 | " 29 Episode in 14500 steps, reward 500.00\n", 286 | " 30 Episode in 15000 steps, reward 500.00\n", 287 | " 31 Episode in 15500 steps, reward 500.00\n", 288 | " 32 Episode in 16000 steps, reward 500.00\n", 289 | " 33 Episode in 16500 steps, reward 500.00\n", 290 | " 34 Episode in 17000 steps, reward 500.00\n", 291 | " 35 Episode in 17500 steps, reward 500.00\n", 292 | " 36 Episode in 18000 steps, reward 500.00\n", 293 | " 37 Episode in 18500 steps, reward 500.00\n", 294 | " 38 Episode in 19000 steps, reward 500.00\n", 295 | " 39 Episode in 19500 steps, reward 500.00\n", 296 | " 40 Episode in 20000 steps, reward 500.00\n", 297 | " 41 Episode in 20500 steps, reward 500.00\n", 298 | " 42 Episode in 21000 steps, reward 500.00\n", 299 | " 43 Episode in 21500 steps, reward 500.00\n", 300 | " 44 Episode in 22000 steps, reward 500.00\n", 301 | " 45 Episode in 22500 steps, reward 500.00\n", 302 | " 46 Episode in 23000 steps, reward 500.00\n", 303 | " 47 Episode in 23500 steps, reward 500.00\n", 304 | " 48 Episode in 24000 steps, reward 500.00\n", 305 | " 49 
Episode in 24500 steps, reward 500.00\n", 306 | " 50 Episode in 25000 steps, reward 500.00\n", 307 | " 51 Episode in 25500 steps, reward 500.00\n", 308 | " 52 Episode in 26000 steps, reward 500.00\n", 309 | " 53 Episode in 26500 steps, reward 500.00\n", 310 | " 54 Episode in 27000 steps, reward 500.00\n", 311 | " 55 Episode in 27500 steps, reward 500.00\n", 312 | " 56 Episode in 28000 steps, reward 500.00\n", 313 | " 57 Episode in 28500 steps, reward 500.00\n", 314 | " 58 Episode in 29000 steps, reward 500.00\n", 315 | " 59 Episode in 29500 steps, reward 500.00\n", 316 | " 60 Episode in 30000 steps, reward 500.00\n", 317 | " 61 Episode in 30500 steps, reward 500.00\n", 318 | " 62 Episode in 31000 steps, reward 500.00\n", 319 | " 63 Episode in 31500 steps, reward 500.00\n", 320 | " 64 Episode in 32000 steps, reward 500.00\n", 321 | " 65 Episode in 32500 steps, reward 500.00\n", 322 | " 66 Episode in 33000 steps, reward 500.00\n", 323 | " 67 Episode in 33500 steps, reward 500.00\n", 324 | " 68 Episode in 34000 steps, reward 500.00\n", 325 | " 69 Episode in 34500 steps, reward 500.00\n", 326 | " 70 Episode in 35000 steps, reward 500.00\n", 327 | " 71 Episode in 35500 steps, reward 500.00\n", 328 | " 72 Episode in 36000 steps, reward 500.00\n", 329 | " 73 Episode in 36500 steps, reward 500.00\n", 330 | " 74 Episode in 37000 steps, reward 500.00\n", 331 | " 75 Episode in 37500 steps, reward 500.00\n", 332 | " 76 Episode in 38000 steps, reward 500.00\n", 333 | " 77 Episode in 38500 steps, reward 500.00\n", 334 | " 78 Episode in 39000 steps, reward 500.00\n", 335 | " 79 Episode in 39500 steps, reward 500.00\n", 336 | " 80 Episode in 40000 steps, reward 500.00\n", 337 | " 81 Episode in 40500 steps, reward 500.00\n", 338 | " 82 Episode in 41000 steps, reward 500.00\n", 339 | " 83 Episode in 41500 steps, reward 500.00\n", 340 | " 84 Episode in 42000 steps, reward 500.00\n", 341 | " 85 Episode in 42500 steps, reward 500.00\n", 342 | " 86 Episode in 43000 steps, reward 500.00\n", 343 | " 87 Episode in 43500 steps, reward 500.00\n", 344 | " 88 Episode in 44000 steps, reward 500.00\n", 345 | " 89 Episode in 44500 steps, reward 500.00\n", 346 | " 90 Episode in 45000 steps, reward 500.00\n", 347 | " 91 Episode in 45500 steps, reward 500.00\n", 348 | " 92 Episode in 46000 steps, reward 500.00\n", 349 | " 93 Episode in 46500 steps, reward 500.00\n", 350 | " 94 Episode in 47000 steps, reward 500.00\n", 351 | " 95 Episode in 47500 steps, reward 500.00\n", 352 | " 96 Episode in 48000 steps, reward 500.00\n", 353 | " 97 Episode in 48500 steps, reward 500.00\n", 354 | " 98 Episode in 49000 steps, reward 500.00\n", 355 | " 99 Episode in 49500 steps, reward 500.00\n", 356 | "100 Episode in 50000 steps, reward 500.00\n", 357 | "\n", 358 | "CartPole-v1 is solved! 
100 Episode in 50000 steps\n", 359 | "500.0\n" 360 | ] 361 | } 362 | ], 363 | "source": [ 364 | "# play\n", 365 | "# frames = []\n", 366 | "for i in range(1, n_episodes + 1):\n", 367 | " obs = env.reset()\n", 368 | " done = False\n", 369 | " ep_reward = 0\n", 370 | " while not done:\n", 371 | "# frames.append(env.render(mode = 'rgb_array'))\n", 372 | " env.render()\n", 373 | " action = get_action_and_value(obs, target_net)\n", 374 | " _obs, reward, done, _ = env.step(action)\n", 375 | " obs = _obs\n", 376 | " total_steps += 1\n", 377 | " ep_reward += reward \n", 378 | " if done:\n", 379 | " env.render()\n", 380 | " rewards.append(ep_reward)\n", 381 | " reward_eval.append(ep_reward)\n", 382 | " print('{:3} Episode in {:5} steps, reward {:.2f}'.format(\n", 383 | " i, total_steps, ep_reward))\n", 384 | "# frames.append(env.render(mode = 'rgb_array'))\n", 385 | "# imageio.mimsave(f'{env.spec.id}.gif', frames,)\n", 386 | " \n", 387 | " if len(reward_eval) >= n_eval:\n", 388 | " if np.mean(reward_eval) >= env.spec.reward_threshold:\n", 389 | " print('\\n{} is solved! {:3} Episode in {:3} steps'.format(\n", 390 | " env.spec.id, i, total_steps))\n", 391 | " print(f'Mean Reward: {np.mean(reward_eval).round(decimals=2)}')\n", 392 | " break\n", 393 | "env.close()" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": 10, 399 | "metadata": { 400 | "scrolled": false 401 | }, 402 | "outputs": [ 403 | { 404 | "data": { 405 | "image/png": "[base64 PNG data omitted: reward curve plot]\n", 406 | "text/plain": [ 407 | "
" 408 | ] 409 | }, 410 | "metadata": { 411 | "needs_background": "light" 412 | }, 413 | "output_type": "display_data" 414 | } 415 | ], 416 | "source": [ 417 | "plt.figure(figsize=(15, 5))\n", 418 | "plt.title('reward')\n", 419 | "plt.plot(rewards)\n", 420 | "plt.show()" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 11, 426 | "metadata": {}, 427 | "outputs": [ 428 | { 429 | "data": { 430 | "text/plain": [ 431 | "[('CartPole-v0', 412, 1),\n", 432 | " ('CartPole-v1', 452, 0.05),\n", 433 | " ('MountainCar-v0', 193, 0.1),\n", 434 | " ('LunarLander-v2', 260, 0.1)]" 435 | ] 436 | }, 437 | "execution_count": 11, 438 | "metadata": {}, 439 | "output_type": "execute_result" 440 | } 441 | ], 442 | "source": [ 443 | "[\n", 444 | " ('CartPole-v0', 412, 1),\n", 445 | " ('CartPole-v1', 452, 0.05),\n", 446 | " ('MountainCar-v0', 193, 0.1),\n", 447 | " ('LunarLander-v2', 260, 0.1)\n", 448 | "]" 449 | ] 450 | } 451 | ], 452 | "metadata": { 453 | "colab": { 454 | "collapsed_sections": [], 455 | "name": "C51_tensorflow.ipynb", 456 | "provenance": [], 457 | "version": "0.3.2" 458 | }, 459 | "kernelspec": { 460 | "display_name": "Python 3", 461 | "language": "python", 462 | "name": "python3" 463 | }, 464 | "language_info": { 465 | "codemirror_mode": { 466 | "name": "ipython", 467 | "version": 3 468 | }, 469 | "file_extension": ".py", 470 | "mimetype": "text/x-python", 471 | "name": "python", 472 | "nbconvert_exporter": "python", 473 | "pygments_lexer": "ipython3", 474 | "version": "3.7.0" 475 | } 476 | }, 477 | "nbformat": 4, 478 | "nbformat_minor": 1 479 | } 480 | --------------------------------------------------------------------------------