├── .gitignore
├── LICENSE
├── README.md
├── image
│   ├── CartPole-v0.gif
│   ├── CartPole-v0_reward_curve.png
│   ├── CartPole-v1.gif
│   ├── CartPole-v1_reward_curve.png
│   ├── LunarLander-v2.gif
│   ├── LunarLander-v2_reward_curve.png
│   ├── MountainCar-v0.gif
│   └── MountainCar-v0_reward_curve.png
└── src
    ├── PPO.ipynb
    ├── PPO.py
    ├── PPO2.ipynb
    ├── PPO2.py
    ├── categorical_dqn.ipynb
    ├── categorical_dueling_ddqn.ipynb
    ├── core.ipynb
    ├── core.py
    ├── dqn.ipynb
    ├── dueling_double_dqn.ipynb
    ├── grpo_discrete_episodic.ipynb
    ├── ppo_discrete_episodic.ipynb
    ├── ppo_discrete_episodic2.ipynb
    ├── ppo_discrete_step.ipynb
    ├── ppo_discrete_step.py
    ├── ppo_discrete_step_parallel.py
    ├── ppo_discrete_step_test.ipynb
    ├── ppo_discrete_step_test.py
    ├── rand_net_distill
    │   ├── dueling_ddqn_with_rnd.ipynb
    │   └── ppo_step_with_rnd.ipynb
    ├── running_mean_std.py
    ├── soft_actor_critic.ipynb
    └── test
        ├── running_mean_std.py
        ├── saved_models
        │   ├── CartPole-v0_ep179_clear_model_ppo_st.pt
        │   ├── CartPole-v0_ep87_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep108_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep112_clear_model_dddqn.pt
        │   ├── CartPole-v1_ep1150_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep118_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep1338_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep134_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep164_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep168_clear_model_dddqn.pt
        │   ├── CartPole-v1_ep1715_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep213_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep216_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep222_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep223_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep258_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep270_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep273_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep294_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep295_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep296_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep308_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep313_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep320_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep326_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep329_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep357_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep376_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep385_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep388_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep408_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep417_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep418_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep420_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep424_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep433_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep436_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep445_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep454_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep456_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep458_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep463_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep470_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep474_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep475_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep492_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep498_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep504_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep506_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep509_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep516_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep519_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep530_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep534_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep546_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep552_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep555_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep559_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep566_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep569_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep586_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep589_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep595_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep631_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep635_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep637_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep642_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep675_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep690_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep727_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep741_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep763_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep840_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep845_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep866_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep871_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep892_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep903_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep924_clear_model_ppo_st.pt
        │   ├── CartPole-v1_ep939_clear_model_ppo_st.pt
        │   ├── CartPole-v1_up50_clear_model_ppo_st.pt
        │   ├── CartPole-v1_up50_clear_norm_obs.pkl
        │   ├── LunarLander-v2_ep260_clear_model_dqn.pt
        │   ├── LunarLander-v2_ep370_clear_model_dddqn.pt
        │   ├── LunarLander-v2_ep8461_clear_model_ppo_st.pt
        │   ├── LunarLander-v2_ep876_clear_model_ppo_st.pt
        │   ├── LunarLander-v2_up1099_clear_model_ppo_st.pt
        │   ├── LunarLander-v2_up1317_clear_model_ppo_st.pt
        │   ├── LunarLander-v2_up254_clear_model_ppo_st.pt
        │   ├── LunarLander-v2_up570_clear_model_ppo_st.pt
        │   ├── LunarLander-v2_up75_clear_model_ppo_st.pt
        │   ├── MountainCar-v0_ep304_clear_model_dqn.pt
        │   ├── MountainCar-v0_ep532_clear_model_dddqn.pt
        │   ├── MountainCar-v0_ep984_clear_model_dddqn.pt
        │   ├── MountainCar-v0_up1441_clear_model_ppo_st.pt
        │   └── MountainCar-v0_up1441_clear_norm_obs.pkl
        ├── test_dqn.ipynb
        ├── test_dueling_double_dqn.ipynb
        ├── test_ppo.ipynb
        └── test_ppo_with_rnd.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # VS Code
107 | .vscode
108 |
109 | # tensorboard
110 | runs
111 |
112 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Jungdae Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Deep RL Algorithms in PyTorch
2 |
3 | ### Models
4 | - DQN
5 | - Dueling Double DQN
6 | - Categorical DQN (C51)
7 | - Categorical Dueling Double DQN
8 | - Proximal Policy Optimization (PPO)
9 | + discrete (episodic, n-step)
10 | - Group Relative Policy Optimization (GRPO)
11 |
12 |
13 |
14 | ### Exploration
15 | - Random Network Distillation (RND)
16 |
17 |
18 | ### Experiments
19 | Results from runs that pass each environment's "solving" criterion.
20 | - **Dueling Double DQN**
21 | + Only one hyperparameter "UP_COEF" was adjusted.
22 | ###### CartPole-v0
23 |
24 |
25 |
26 |
27 | ###### CartPole-v1
28 |
29 |
30 |
31 |
32 | ###### MountainCar-v0
33 |
34 |
35 |
36 |
37 | ###### LunarLander-v2
38 |
39 |
40 |
41 |
42 |
43 | ### TODO
44 | - Proximal Policy Optimization (PPO)
45 | + continuous
46 |
--------------------------------------------------------------------------------
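The checkpoints listed under src/test/saved_models are plain PyTorch state_dicts written by the training scripts (e.g. torch.save(net.state_dict(), ...) in src/ppo_discrete_step.py); the test_*.ipynb notebooks are the repository's own way of replaying them. The snippet below is a minimal replay sketch, not part of the repository: it copies the ActorCriticNet definition from ppo_discrete_step.py so it is self-contained (importing that script directly would start a training run), uses the same pre-0.26 Gym API as the rest of the code, and the checkpoint path is only a hypothetical example taken from the file list above.

import gym
import torch
import torch.nn as nn
from torch.distributions import Categorical


class ActorCriticNet(nn.Module):
    # copy of the class in src/ppo_discrete_step.py so the checkpoint keys match
    def __init__(self, obs_space, action_space):
        super().__init__()
        h = 32
        self.head = nn.Sequential(nn.Linear(obs_space, h), nn.Tanh())
        self.pol = nn.Sequential(nn.Linear(h, h), nn.Tanh(), nn.Linear(h, action_space))
        self.val = nn.Sequential(nn.Linear(h, h), nn.Tanh(), nn.Linear(h, 1))
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        out = self.head(x)
        log_p = self.log_softmax(self.pol(out).reshape(out.shape[0], -1))
        v = self.val(out).reshape(out.shape[0], 1)
        return log_p, v


env = gym.make('CartPole-v1')
net = ActorCriticNet(env.observation_space.shape[0], env.action_space.n)
# hypothetical path: any of the *_ppo_st.pt files under src/test/saved_models
net.load_state_dict(torch.load('src/test/saved_models/CartPole-v1_ep108_clear_model_ppo_st.pt'))
net.eval()

obs, done, ep_ret = env.reset(), False, 0.0
while not done:
    with torch.no_grad():
        log_p, _ = net(torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0))
    action = Categorical(log_p.exp()).sample().item()  # stochastic; use log_p.argmax(-1) for greedy play
    obs, reward, done, _ = env.step(action)
    ep_ret += reward
print('episode return:', ep_ret)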
/image/CartPole-v0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/CartPole-v0.gif
--------------------------------------------------------------------------------
/image/CartPole-v0_reward_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/CartPole-v0_reward_curve.png
--------------------------------------------------------------------------------
/image/CartPole-v1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/CartPole-v1.gif
--------------------------------------------------------------------------------
/image/CartPole-v1_reward_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/CartPole-v1_reward_curve.png
--------------------------------------------------------------------------------
/image/LunarLander-v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/LunarLander-v2.gif
--------------------------------------------------------------------------------
/image/LunarLander-v2_reward_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/LunarLander-v2_reward_curve.png
--------------------------------------------------------------------------------
/image/MountainCar-v0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/MountainCar-v0.gif
--------------------------------------------------------------------------------
/image/MountainCar-v0_reward_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/image/MountainCar-v0_reward_curve.png
--------------------------------------------------------------------------------
/src/PPO.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import random\n",
10 | "import numpy as np\n",
11 | "import torch\n",
12 | "from torch.optim import Adam, AdamW\n",
13 | "import gym\n",
14 | "import time\n",
15 | "import core\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "from IPython.display import clear_output\n",
18 | "from running_mean_std import RunningMeanStd\n",
19 | "\n",
20 | "\n",
21 | "class PPOBuffer(object):\n",
22 | " def __init__(self, obs_dim, act_dim, size, gamma=0.999, lam=0.97):\n",
23 | " self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)\n",
24 | " self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)\n",
25 | " self.adv_buf = np.zeros(size, dtype=np.float32)\n",
26 | " self.rew_buf = np.zeros(size, dtype=np.float32)\n",
27 | " self.ret_buf = np.zeros(size, dtype=np.float32)\n",
28 | " self.val_buf = np.zeros(size, dtype=np.float32)\n",
29 | " self.logp_buf = np.zeros(size, dtype=np.float32)\n",
30 | " self.gamma, self.lam = gamma, lam\n",
31 | " self.ptr, self.path_start_idx, self.max_size = 0, 0, size\n",
32 | "\n",
33 | " def store(self, obs, act, rew, val, logp):\n",
34 | " assert self.ptr < self.max_size\n",
35 | " self.obs_buf[self.ptr] = obs\n",
36 | " self.act_buf[self.ptr] = act\n",
37 | " self.rew_buf[self.ptr] = rew\n",
38 | " self.val_buf[self.ptr] = val\n",
39 | " self.logp_buf[self.ptr] = logp\n",
40 | " self.ptr += 1\n",
41 | "\n",
42 | " def finish_path(self, last_val=0):\n",
43 | " path_slice = slice(self.path_start_idx, self.ptr)\n",
44 | " rews = np.append(self.rew_buf[path_slice], last_val)\n",
45 | " vals = np.append(self.val_buf[path_slice], last_val)\n",
46 | " \n",
47 | " deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]\n",
48 | " self.adv_buf[path_slice] = core.discount_cumsum(deltas, self.gamma * self.lam)\n",
49 | " self.ret_buf[path_slice] = core.discount_cumsum(rews, self.gamma)[:-1]\n",
50 | " self.path_start_idx = self.ptr\n",
51 | "\n",
52 | " def get(self):\n",
53 | " assert self.ptr == self.max_size\n",
54 | " self.ptr, self.path_start_idx = 0, 0\n",
55 | " adv_mean = np.mean(self.adv_buf)\n",
56 | " adv_std = np.std(self.adv_buf)\n",
57 | " self.adv_buf = (self.adv_buf - adv_mean) / adv_std\n",
58 | " data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf,\n",
59 | " adv=self.adv_buf, logp=self.logp_buf)\n",
60 | " return {k: torch.as_tensor(v, dtype=torch.float32) for k,v in data.items()}"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "def plot(ep_ret_buf, eval_ret_buf, loss_buf):\n",
70 | " clear_output(True)\n",
71 | " plt.figure(figsize=(16, 5))\n",
72 | " plt.subplot(131)\n",
73 | " plt.plot(ep_ret_buf, alpha=0.5)\n",
74 | " plt.subplot(131)\n",
75 | " plt.plot(eval_ret_buf)\n",
76 | " plt.title(f\"Reward: {eval_ret_buf[-1]:.0f}\")\n",
77 | " plt.subplot(132)\n",
78 | " plt.plot(loss_buf['pi'], alpha=0.5)\n",
79 | " plt.title(f\"Pi_Loss: {np.mean(loss_buf['pi'][:-20:]):.3f}\")\n",
80 | " plt.subplot(133)\n",
81 | " plt.plot(loss_buf['vf'], alpha=0.5)\n",
82 | " plt.title(f\"Vf_Loss: {np.mean(loss_buf['vf'][-20:]):.2f}\")\n",
83 | " plt.show()"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "def compute_loss_pi(data, ac, clip_ratio):\n",
93 | " obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']\n",
94 | "\n",
95 | " # Policy loss\n",
96 | " pi, logp = ac.pi(obs, act)\n",
97 | " ratio = torch.exp(logp - logp_old)\n",
98 | " clip_adv = torch.clamp(ratio, 1-clip_ratio, 1+clip_ratio) * adv\n",
99 | " loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()\n",
100 | "\n",
101 | " # Useful extra info\n",
102 | "# approx_kl = (logp_old - logp).mean().item()\n",
103 | " kl_div = ((logp.exp() * (logp - logp_old)).mean()).detach().item()\n",
104 | " ent = pi.entropy().mean().detach().item()\n",
105 | " clipped = ratio.gt(1+clip_ratio) | ratio.lt(1-clip_ratio)\n",
106 | " clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().detach().item()\n",
107 | " pi_info = dict(kl=kl_div, ent=ent, cf=clipfrac)\n",
108 | "\n",
109 | " return loss_pi, pi_info\n",
110 | "\n",
111 | "def compute_loss_v(data, ac):\n",
112 | " obs, ret = data['obs'], data['ret']\n",
113 | " return ((ac.v(obs) - ret)**2).mean()\n",
114 | "\n",
115 | "\n",
116 | "def update(buf, train_pi_iters, train_vf_iters, clip_ratio, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf):\n",
117 | " data = buf.get()\n",
118 | "\n",
119 | " # Train policy with multiple steps of gradient descent\n",
120 | " for i in range(train_pi_iters):\n",
121 | " pi_optimizer.zero_grad()\n",
122 | " loss_pi, pi_info = compute_loss_pi(data, ac, clip_ratio)\n",
123 | " loss_buf['pi'].append(loss_pi.item())\n",
124 | " kl = pi_info['kl']\n",
125 | " if kl > 1.5 * target_kl:\n",
126 | " print('Early stopping at step %d due to reaching max kl.'%i)\n",
127 | " break\n",
128 | " loss_pi.backward()\n",
129 | " pi_optimizer.step()\n",
130 | "\n",
131 | " # Value function learning\n",
132 | " for i in range(train_vf_iters):\n",
133 | " vf_optimizer.zero_grad()\n",
134 | " loss_vf = compute_loss_v(data, ac)\n",
135 | " loss_buf['vf'].append(loss_vf.item())\n",
136 | " loss_vf.backward()\n",
137 | " vf_optimizer.step()"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "def main():\n",
147 | " actor_critic=core.MLPActorCritic\n",
148 | " hidden_size = 64\n",
149 | " activation = torch.nn.Tanh\n",
150 | " seed = 5\n",
151 | " steps_per_epoch = 4096\n",
152 | " epochs = 1000\n",
153 | " gamma = 0.99\n",
154 | " lam = 0.97\n",
155 | " clip_ratio = 0.2\n",
156 | " pi_lr = 3e-4\n",
157 | " vf_lr = 1e-3\n",
158 | " train_pi_iters = 80\n",
159 | " train_vf_iters = 80\n",
160 | " max_ep_len = 1000\n",
161 | " target_kl = 0.01\n",
162 | " save_freq = 10\n",
163 | " obs_norm = True\n",
164 | " view_curve = True\n",
165 | "\n",
166 | " # make an environment\n",
167 | "# env = gym.make('CartPole-v0')\n",
168 | "# env = gym.make('CartPole-v1')\n",
169 | "# env = gym.make('MountainCar-v0')\n",
170 | "# env = gym.make('LunarLander-v2')\n",
171 | " env = gym.make('BipedalWalker-v3')\n",
172 | " print(f\"reward_threshold: {env.spec.reward_threshold}\")\n",
173 | "\n",
174 | " obs_dim = env.observation_space.shape\n",
175 | " act_dim = env.action_space.shape\n",
176 | "\n",
177 | " # Random seed\n",
178 | " env.seed(seed)\n",
179 | " random.seed(seed)\n",
180 | " torch.manual_seed(seed)\n",
181 | " np.random.seed(seed)\n",
182 | "\n",
183 | " # Create actor-critic module\n",
184 | " ac = actor_critic(env.observation_space, env.action_space, (hidden_size, hidden_size), activation)\n",
185 | " \n",
186 | " # Set up optimizers for policy and value function\n",
187 | " pi_optimizer = AdamW(ac.pi.parameters(), lr=pi_lr, eps=1e-6)\n",
188 | " vf_optimizer = AdamW(ac.v.parameters(), lr=vf_lr, eps=1e-6)\n",
189 | "\n",
190 | " # Count variables\n",
191 | " var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])\n",
192 | "\n",
193 | " # Set up experience buffer\n",
194 | " local_steps_per_epoch = int(steps_per_epoch)\n",
195 | " buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)\n",
196 | " \n",
197 | " # Prepare for interaction with environment\n",
198 | " start_time = time.time()\n",
199 | " o, ep_ret, ep_len = env.reset(), 0, 0\n",
200 | " ep_num = 0\n",
201 | " ep_ret_buf, eval_ret_buf = [], []\n",
202 | " loss_buf = {'pi': [], 'vf': []}\n",
203 | " obs_normalizer = RunningMeanStd(shape=env.observation_space.shape)\n",
204 | " # Main loop: collect experience in env and update/log each epoch\n",
205 | " for epoch in range(epochs):\n",
206 | " for t in range(local_steps_per_epoch):\n",
207 | " if obs_norm:\n",
208 | " obs_normalizer.update(np.array([o]))\n",
209 | " o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10)\n",
210 | " a, v, logp = ac.step(torch.as_tensor(o_norm, dtype=torch.float32))\n",
211 | " else:\n",
212 | " a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))\n",
213 | "\n",
214 | " next_o, r, d, _ = env.step(a)\n",
215 | " ep_ret += r\n",
216 | " ep_len += 1\n",
217 | "\n",
218 | " # save and log\n",
219 | " if obs_norm:\n",
220 | " buf.store(o_norm, a, r, v, logp)\n",
221 | " else:\n",
222 | " buf.store(o, a, r, v, logp)\n",
223 | "\n",
224 | " # Update obs\n",
225 | " o = next_o\n",
226 | "\n",
227 | " timeout = ep_len == max_ep_len\n",
228 | " terminal = d or timeout\n",
229 | " epoch_ended = t==local_steps_per_epoch-1\n",
230 | "\n",
231 | " if terminal or epoch_ended:\n",
232 | " if timeout or epoch_ended:\n",
233 | " if obs_norm:\n",
234 | " obs_normalizer.update(np.array([o]))\n",
235 | " o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10)\n",
236 | " _, v, _ = ac.step(torch.as_tensor(o_norm, dtype=torch.float32))\n",
237 | " else:\n",
238 | " _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))\n",
239 | " else:\n",
240 | " if obs_norm:\n",
241 | " obs_normalizer.update(np.array([o]))\n",
242 | " v = 0\n",
243 | " buf.finish_path(v)\n",
244 | " if terminal:\n",
245 | " ep_ret_buf.append(ep_ret)\n",
246 | " eval_ret_buf.append(np.mean(ep_ret_buf[-20:]))\n",
247 | " ep_num += 1\n",
248 | " if view_curve:\n",
249 | " plot(ep_ret_buf, eval_ret_buf, loss_buf)\n",
250 | " else:\n",
251 | " print(f'Episode: {ep_num:3} Reward: {ep_reward:3}')\n",
252 | " if eval_ret_buf[-1] >= env.spec.reward_threshold:\n",
253 | " print(f\"\\n{env.spec.id} is sloved! {ep_num} Episode\")\n",
254 | " return\n",
255 | "\n",
256 | " o, ep_ret, ep_len = env.reset(), 0, 0\n",
257 | " # Perform PPO update!\n",
258 | " update(buf, train_pi_iters, train_vf_iters, clip_ratio, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf)"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "main()"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": []
276 | }
277 | ],
278 | "metadata": {
279 | "kernelspec": {
280 | "display_name": "Python 3",
281 | "language": "python",
282 | "name": "python3"
283 | },
284 | "language_info": {
285 | "codemirror_mode": {
286 | "name": "ipython",
287 | "version": 3
288 | },
289 | "file_extension": ".py",
290 | "mimetype": "text/x-python",
291 | "name": "python",
292 | "nbconvert_exporter": "python",
293 | "pygments_lexer": "ipython3",
294 | "version": "3.6.10"
295 | }
296 | },
297 | "nbformat": 4,
298 | "nbformat_minor": 4
299 | }
300 |
--------------------------------------------------------------------------------
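A standalone numeric check of the clipped surrogate that compute_loss_pi builds in the notebook above; the probabilities and advantages are made up purely to show what the clamp/min combination does (an illustration, not code from the repository).

import torch

clip_ratio = 0.2
logp_old = torch.log(torch.tensor([0.50, 0.40, 0.10]))  # action probs under the old policy
logp     = torch.log(torch.tensor([0.80, 0.32, 0.05]))  # action probs under the current policy
adv      = torch.tensor([1.0, -1.0, 2.0])               # advantages

ratio = torch.exp(logp - logp_old)                      # 1.6, 0.8, 0.5
clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()
print(ratio, loss_pi)
# For the first sample the ratio 1.6 is clipped to 1.2: with a positive advantage the
# surrogate is capped, so the update stops pushing that action's probability higher.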
/src/PPO.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | # %%
4 |
5 | # %%
6 |
7 | import pickle
8 | import random
9 |
10 | import core
11 | import gym
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | import torch
15 | from running_mean_std import RunningMeanStd
16 | from torch.optim import AdamW
17 |
18 |
19 | class PPOBuffer(object):
20 | def __init__(self, obs_dim, act_dim, size, gamma=0.999, lam=0.97):
21 | self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
22 | self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
23 | self.adv_buf = np.zeros(size, dtype=np.float32)
24 | self.rew_buf = np.zeros(size, dtype=np.float32)
25 | self.ret_buf = np.zeros(size, dtype=np.float32)
26 | self.val_buf = np.zeros(size, dtype=np.float32)
27 | self.logp_buf = np.zeros(size, dtype=np.float32)
28 | self.gamma, self.lam = gamma, lam
29 | self.ptr, self.path_start_idx, self.max_size = 0, 0, size
30 |
31 | def store(self, obs, act, rew, val, logp):
32 | assert self.ptr < self.max_size
33 | self.obs_buf[self.ptr] = obs
34 | self.act_buf[self.ptr] = act
35 | self.rew_buf[self.ptr] = rew
36 | self.val_buf[self.ptr] = val
37 | self.logp_buf[self.ptr] = logp
38 | self.ptr += 1
39 |
40 | def finish_path(self, last_val=0):
41 | path_slice = slice(self.path_start_idx, self.ptr)
42 | rews = np.append(self.rew_buf[path_slice], last_val)
43 | vals = np.append(self.val_buf[path_slice], last_val)
44 | deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
45 | self.adv_buf[path_slice] = core.discount_cumsum(deltas, self.gamma * self.lam)
46 | self.ret_buf[path_slice] = core.discount_cumsum(rews, self.gamma)[:-1]
47 | self.path_start_idx = self.ptr
48 |
49 | def get(self):
50 | assert self.ptr == self.max_size
51 | self.ptr, self.path_start_idx = 0, 0
52 | adv_mean = np.mean(self.adv_buf)
53 | adv_std = np.std(self.adv_buf)
54 | self.adv_buf = (self.adv_buf - adv_mean) / adv_std
55 | data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf, adv=self.adv_buf, logp=self.logp_buf)
56 | return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()}
57 |
58 |
59 | # %%
60 |
61 |
62 | def plot(ep_ret_buf, eval_ret_buf, loss_buf):
63 | plt.figure(figsize=(16, 5))
64 | plt.subplot(131)
65 | plt.plot(ep_ret_buf, alpha=0.5)
66 | plt.subplot(131)
67 | plt.plot(eval_ret_buf)
68 | plt.title(f"Reward: {eval_ret_buf[-1]:.0f}")
69 | plt.subplot(132)
70 | plt.plot(loss_buf["pi"], alpha=0.5)
71 |     plt.title(f"Pi_Loss: {np.mean(loss_buf['pi'][-20:]):.3f}")
72 | plt.subplot(133)
73 | plt.plot(loss_buf["vf"], alpha=0.5)
74 | plt.title(f"Vf_Loss: {np.mean(loss_buf['vf'][-20:]):.2f}")
75 | plt.show()
76 |
77 |
78 | # %%
79 |
80 |
81 | def compute_loss_pi(data, ac, clip_ratio):
82 | obs, act, adv, logp_old = data["obs"], data["act"], data["adv"], data["logp"]
83 |
84 | # Policy loss
85 | pi, logp = ac.pi(obs, act)
86 | ratio = torch.exp(logp - logp_old)
87 | clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
88 | loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()
89 |
90 | # Useful extra info
91 | # approx_kl = (logp_old - logp).mean().item()
92 | kl_div = ((logp.exp() * (logp - logp_old)).mean()).detach().item()
93 | ent = pi.entropy().mean().detach().item()
94 | clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
95 | clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().detach().item()
96 | pi_info = dict(kl=kl_div, ent=ent, cf=clipfrac)
97 |
98 | return loss_pi, pi_info
99 |
100 |
101 | def compute_loss_v(data, ac):
102 | obs, ret = data["obs"], data["ret"]
103 | return ((ac.v(obs) - ret) ** 2).mean()
104 |
105 |
106 | def update(buf, train_pi_iters, train_vf_iters, clip_ratio, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf):
107 | data = buf.get()
108 |
109 | # Train policy with multiple steps of gradient descent
110 | for i in range(train_pi_iters):
111 | pi_optimizer.zero_grad()
112 | loss_pi, pi_info = compute_loss_pi(data, ac, clip_ratio)
113 | loss_buf["pi"].append(loss_pi.item())
114 | kl = pi_info["kl"]
115 | if kl > 1.5 * target_kl:
116 | print("Early stopping at step %d due to reaching max kl." % i)
117 | break
118 | loss_pi.backward()
119 | pi_optimizer.step()
120 |
121 | # Value function learning
122 | for i in range(train_vf_iters):
123 | vf_optimizer.zero_grad()
124 | loss_vf = compute_loss_v(data, ac)
125 | loss_buf["vf"].append(loss_vf.item())
126 | loss_vf.backward()
127 | vf_optimizer.step()
128 |
129 |
130 | # %%
131 |
132 |
133 | def main():
134 | actor_critic = core.MLPActorCritic
135 | hidden_size = 64
136 | activation = torch.nn.Tanh
137 | seed = 5
138 | steps_per_epoch = 2048
139 | epochs = 1000
140 | gamma = 0.99
141 | lam = 0.97
142 | clip_ratio = 0.2
143 | pi_lr = 3e-4
144 | vf_lr = 1e-3
145 | train_pi_iters = 80
146 | train_vf_iters = 80
147 | max_ep_len = 1000
148 | target_kl = 0.01
149 | save_freq = 10
150 | obs_norm = True
151 | view_curve = False
152 |
153 | # make an environment
154 | # env = gym.make('CartPole-v0')
155 | # env = gym.make('CartPole-v1')
156 | # env = gym.make('MountainCar-v0')
157 | # env = gym.make('LunarLander-v2')
158 | env = gym.make("BipedalWalker-v3")
159 | print(f"reward_threshold: {env.spec.reward_threshold}")
160 |
161 | obs_dim = env.observation_space.shape
162 | act_dim = env.action_space.shape
163 |
164 | # Random seed
165 | env.seed(seed)
166 | random.seed(seed)
167 | torch.manual_seed(seed)
168 | np.random.seed(seed)
169 |
170 | # Create actor-critic module
171 | ac = actor_critic(env.observation_space, env.action_space, (hidden_size, hidden_size), activation)
172 |
173 | # Set up optimizers for policy and value function
174 | pi_optimizer = AdamW(ac.pi.parameters(), lr=pi_lr, eps=1e-6)
175 | vf_optimizer = AdamW(ac.v.parameters(), lr=vf_lr, eps=1e-6)
176 |
177 | # Set up experience buffer
178 | local_steps_per_epoch = int(steps_per_epoch)
179 | buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)
180 |
181 | # Prepare for interaction with environment
182 | o, ep_ret, ep_len = env.reset(), 0, 0
183 | ep_num = 0
184 | ep_ret_buf, eval_ret_buf = [], []
185 | loss_buf = {"pi": [], "vf": []}
186 | obs_normalizer = RunningMeanStd(shape=env.observation_space.shape)
187 | # Main loop: collect experience in env and update/log each epoch
188 | for epoch in range(epochs):
189 | for t in range(local_steps_per_epoch):
190 | env.render()
191 | if obs_norm:
192 | obs_normalizer.update(np.array([o]))
193 | o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10)
194 | a, v, logp = ac.step(torch.as_tensor(o_norm, dtype=torch.float32))
195 | else:
196 | a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))
197 |
198 | next_o, r, d, _ = env.step(a)
199 | ep_ret += r
200 | ep_len += 1
201 |
202 | # save and log
203 | if obs_norm:
204 | buf.store(o_norm, a, r, v, logp)
205 | else:
206 | buf.store(o, a, r, v, logp)
207 |
208 | # Update obs
209 | o = next_o
210 |
211 | timeout = ep_len == max_ep_len
212 | terminal = d or timeout
213 | epoch_ended = t == local_steps_per_epoch - 1
214 |
215 | if terminal or epoch_ended:
216 | if timeout or epoch_ended:
217 | if obs_norm:
218 | obs_normalizer.update(np.array([o]))
219 | o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10)
220 | _, v, _ = ac.step(torch.as_tensor(o_norm, dtype=torch.float32))
221 | else:
222 | _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
223 | else:
224 | if obs_norm:
225 | obs_normalizer.update(np.array([o]))
226 | v = 0
227 | buf.finish_path(v)
228 | if terminal:
229 | ep_ret_buf.append(ep_ret)
230 | eval_ret_buf.append(np.mean(ep_ret_buf[-20:]))
231 | ep_num += 1
232 | if view_curve:
233 | plot(ep_ret_buf, eval_ret_buf, loss_buf)
234 | else:
235 | print(f"Episode: {ep_num:3}\tReward: {ep_ret:3}")
236 | if eval_ret_buf[-1] >= env.spec.reward_threshold:
237 |                     print(f"\n{env.spec.id} is solved! {ep_num} Episode")
238 | torch.save(ac.state_dict(), f"./test/saved_models/{env.spec.id}_ep{ep_num}_clear_model_ppo.pt")
239 | with open(f"./test/saved_models/{env.spec.id}_ep{ep_num}_clear_norm_obs.pkl", "wb") as f:
240 | pickle.dump(obs_normalizer, f, pickle.HIGHEST_PROTOCOL)
241 | return
242 |
243 | o, ep_ret, ep_len = env.reset(), 0, 0
244 | # Perform PPO update!
245 | update(buf, train_pi_iters, train_vf_iters, clip_ratio, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf)
246 |
247 |
248 | # %%
249 |
250 |
251 | main()
252 |
--------------------------------------------------------------------------------
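PPO.py above (and PPO.ipynb) import RunningMeanStd from src/running_mean_std.py, which is not reproduced in this dump. The sketch below is a hypothetical stand-in with the same interface the scripts rely on (mean, var, update(batch)), based on the parallel-variance update commonly used for observation normalization; the repository's own file may differ in details.

import numpy as np


class RunningMeanStd:
    # assumed interface: obs_normalizer.update(np.array([o])), .mean, .var
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon

    def update(self, x):
        # x has shape (batch, *shape); fold the batch moments into the running ones
        batch_mean, batch_var, batch_count = x.mean(axis=0), x.var(axis=0), x.shape[0]
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot_count
        m2 = (self.var * self.count + batch_var * batch_count
              + delta ** 2 * self.count * batch_count / tot_count)
        self.mean, self.var, self.count = new_mean, m2 / tot_count, tot_count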
/src/PPO2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | # %%
4 | import random
5 | import time
6 |
7 | import core
8 | import gym
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | from IPython.display import clear_output
13 | from running_mean_std import RunningMeanStd
14 | from torch.optim import Adam, AdamW
15 |
16 |
17 | class PPOBuffer(object):
18 | def __init__(self, obs_dim, act_dim, size, gamma=0.999, lam=0.97):
19 | self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
20 | self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
21 | self.adv_buf = np.zeros(size, dtype=np.float32)
22 | self.rew_buf = np.zeros(size, dtype=np.float32)
23 | self.ret_buf = np.zeros(size, dtype=np.float32)
24 | self.val_buf = np.zeros(size, dtype=np.float32)
25 | self.logp_buf = np.zeros(size, dtype=np.float32)
26 | self.gamma, self.lam = gamma, lam
27 | self.ptr, self.path_start_idx, self.max_size = 0, 0, size
28 |
29 | def store(self, obs, act, rew, val, logp):
30 | assert self.ptr < self.max_size
31 | self.obs_buf[self.ptr] = obs
32 | self.act_buf[self.ptr] = act
33 | self.rew_buf[self.ptr] = rew
34 | self.val_buf[self.ptr] = val
35 | self.logp_buf[self.ptr] = logp
36 | self.ptr += 1
37 |
38 | def finish_path(self, last_val=0):
39 | path_slice = slice(self.path_start_idx, self.ptr)
40 | rews = np.append(self.rew_buf[path_slice], last_val)
41 | vals = np.append(self.val_buf[path_slice], last_val)
42 |
43 | deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
44 | self.adv_buf[path_slice] = core.discount_cumsum(deltas, self.gamma * self.lam)
45 | self.ret_buf[path_slice] = core.discount_cumsum(rews, self.gamma)[:-1]
46 | self.path_start_idx = self.ptr
47 |
48 | def get(self):
49 | assert self.ptr == self.max_size
50 | self.ptr, self.path_start_idx = 0, 0
51 | adv_mean = np.mean(self.adv_buf)
52 | adv_std = np.std(self.adv_buf)
53 | self.adv_buf = (self.adv_buf - adv_mean) / adv_std
54 | data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf, adv=self.adv_buf, logp=self.logp_buf)
55 | return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()}
56 |
57 |
58 | def plot(ep_ret_buf, eval_ret_buf, loss_buf):
59 | clear_output(True)
60 | plt.figure(figsize=(16, 5))
61 | plt.subplot(131)
62 | plt.plot(ep_ret_buf, alpha=0.5)
63 | plt.subplot(131)
64 | plt.plot(eval_ret_buf)
65 | plt.title(f"Reward: {eval_ret_buf[-1]:.0f}")
66 | plt.subplot(132)
67 | plt.plot(loss_buf["pi"], alpha=0.5)
68 |     plt.title(f"Pi_Loss: {np.mean(loss_buf['pi'][-10:]):.3f}")
69 | plt.subplot(133)
70 | plt.plot(loss_buf["vf"], alpha=0.5)
71 | plt.title(f"Vf_Loss: {np.mean(loss_buf['vf'][-10:]):.2f}")
72 | plt.show()
73 |
74 |
75 | def compute_loss_pi(data, ac, beta):
76 | obs, act, adv, logp_old = data["obs"], data["act"], data["adv"], data["logp"]
77 |
78 | # Policy loss
79 | pi, logp = ac.pi(obs, act)
80 | ratio = torch.exp(logp - logp_old)
81 | kl_div = logp.exp() * (logp - logp_old)
82 | loss_pi = -(ratio * adv - beta * kl_div).mean()
83 |
84 | # Useful extra info
85 | kl = ((logp.exp() * (logp - logp_old)).mean()).detach().item()
86 | ent = pi.entropy().mean().detach().item()
87 | pi_info = dict(kl=kl, ent=ent)
88 | return loss_pi, pi_info
89 |
90 |
91 | def compute_loss_v(data, ac):
92 | obs, ret = data["obs"], data["ret"]
93 | return ((ac.v(obs) - ret) ** 2).mean()
94 |
95 |
96 | def update(buf, train_pi_iters, train_vf_iters, beta, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf):
97 | data = buf.get()
98 |
99 | # Train policy with multiple steps of gradient descent
100 | for i in range(train_pi_iters):
101 | pi_optimizer.zero_grad()
102 | loss_pi, pi_info = compute_loss_pi(data, ac, beta)
103 | loss_buf["pi"].append(loss_pi.item())
104 | loss_pi.backward()
105 | pi_optimizer.step()
106 |
107 | # Value function learning
108 | for i in range(train_vf_iters):
109 | vf_optimizer.zero_grad()
110 | loss_vf = compute_loss_v(data, ac)
111 | loss_buf["vf"].append(loss_vf.item())
112 | loss_vf.backward()
113 | vf_optimizer.step()
114 |
115 |
116 | def main():
117 | actor_critic = core.MLPActorCritic
118 | hidden_size = 64
119 | activation = torch.nn.Tanh
120 | seed = 5
121 | steps_per_epoch = 2048
122 | epochs = 1000
123 | gamma = 0.99
124 | lam = 0.97
125 | beta = 3.0
126 | pi_lr = 3e-4
127 | vf_lr = 1e-3
128 | train_pi_iters = 80
129 | train_vf_iters = 80
130 | max_ep_len = 1000
131 | target_kl = 0.01
132 | save_freq = 10
133 | obs_norm = True
134 | view_curve = False
135 |
136 | # make an environment
137 | # env = gym.make('CartPole-v0')
138 | # env = gym.make('CartPole-v1')
139 | # env = gym.make('MountainCar-v0')
140 | # env = gym.make('LunarLander-v2')
141 | env = gym.make("BipedalWalker-v3")
142 | print(f"reward_threshold: {env.spec.reward_threshold}")
143 |
144 | obs_dim = env.observation_space.shape
145 | act_dim = env.action_space.shape
146 |
147 | # Random seed
148 | env.seed(seed)
149 | random.seed(seed)
150 | torch.manual_seed(seed)
151 | np.random.seed(seed)
152 |
153 | # Create actor-critic module
154 | ac = actor_critic(env.observation_space, env.action_space, (hidden_size, hidden_size), activation)
155 |
156 | # Set up optimizers for policy and value function
157 | pi_optimizer = AdamW(ac.pi.parameters(), lr=pi_lr, eps=1e-6)
158 | vf_optimizer = AdamW(ac.v.parameters(), lr=vf_lr, eps=1e-6)
159 |
160 | # Count variables
161 | var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])
162 |
163 | # Set up experience buffer
164 | local_steps_per_epoch = int(steps_per_epoch)
165 | buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)
166 |
167 | # Prepare for interaction with environment
168 | start_time = time.time()
169 | o, ep_ret, ep_len = env.reset(), 0, 0
170 | ep_num = 0
171 | ep_ret_buf, eval_ret_buf = [], []
172 | loss_buf = {"pi": [], "vf": []}
173 | obs_normalizer = RunningMeanStd(shape=env.observation_space.shape)
174 | # Main loop: collect experience in env and update/log each epoch
175 | for epoch in range(epochs):
176 | for t in range(local_steps_per_epoch):
177 | env.render()
178 | if obs_norm:
179 | obs_normalizer.update(np.array([o]))
180 | o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10)
181 | a, v, logp = ac.step(torch.as_tensor(o_norm, dtype=torch.float32))
182 | else:
183 | a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))
184 |
185 | next_o, r, d, _ = env.step(a)
186 | ep_ret += r
187 | ep_len += 1
188 |
189 | # save and log
190 | if obs_norm:
191 | buf.store(o_norm, a, r, v, logp)
192 | else:
193 | buf.store(o, a, r, v, logp)
194 |
195 | # Update obs
196 | o = next_o
197 |
198 | timeout = ep_len == max_ep_len
199 | terminal = d or timeout
200 | epoch_ended = t == local_steps_per_epoch - 1
201 |
202 | if terminal or epoch_ended:
203 | if timeout or epoch_ended:
204 | if obs_norm:
205 | obs_normalizer.update(np.array([o]))
206 | o_norm = np.clip((o - obs_normalizer.mean) / np.sqrt(obs_normalizer.var), -10, 10)
207 | _, v, _ = ac.step(torch.as_tensor(o_norm, dtype=torch.float32))
208 | else:
209 | _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
210 | else:
211 | if obs_norm:
212 | obs_normalizer.update(np.array([o]))
213 | v = 0
214 | buf.finish_path(v)
215 | if terminal:
216 | ep_ret_buf.append(ep_ret)
217 | eval_ret_buf.append(np.mean(ep_ret_buf[-100:]))
218 | ep_num += 1
219 | if view_curve:
220 | plot(ep_ret_buf, eval_ret_buf, loss_buf)
221 | else:
222 | print(f"Episode: {ep_num:3} Reward: {ep_ret:3}")
223 | if eval_ret_buf[-1] >= env.spec.reward_threshold:
224 |                     print(f"\n{env.spec.id} is solved! {ep_num} Episode")
225 | return
226 |
227 | o, ep_ret, ep_len = env.reset(), 0, 0
228 | # Perform PPO update!
229 | update(buf, train_pi_iters, train_vf_iters, beta, target_kl, ac, pi_optimizer, vf_optimizer, loss_buf)
230 |
231 |
232 | if __name__ == "__main__":
233 | main()
234 |
235 |
236 | # %%
237 | # %%
238 |
--------------------------------------------------------------------------------
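PPO2.py swaps the clipped objective for a fixed-beta KL penalty: loss_pi = -(ratio * adv - beta * kl_div).mean(), where kl_div = exp(logp) * (logp - logp_old) is the sampled-action contribution to the KL term. A toy check of that expression, with made-up numbers and not taken from the repository:

import torch

beta = 3.0
logp_old = torch.log(torch.tensor([0.5, 0.5]))
logp     = torch.log(torch.tensor([0.7, 0.3]))
adv      = torch.tensor([1.0, -1.0])

ratio = torch.exp(logp - logp_old)          # 1.4, 0.6
kl_term = logp.exp() * (logp - logp_old)    # per-sample penalty, as in compute_loss_pi
loss_pi = -(ratio * adv - beta * kl_term).mean()
print(ratio, kl_term, loss_pi)
# A larger beta keeps the new policy close to the old one instead of clipping the ratio.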
/src/core.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import scipy.signal\n",
11 | "from gym.spaces import Box, Discrete\n",
12 | "\n",
13 | "import torch\n",
14 | "import torch.nn as nn\n",
15 | "from torch.distributions.normal import Normal\n",
16 | "from torch.distributions.categorical import Categorical\n",
17 | "\n",
18 | "\n",
19 | "def combined_shape(length, shape=None):\n",
20 | " if shape is None:\n",
21 | " return (length,)\n",
22 | " return (length, shape) if np.isscalar(shape) else (length, *shape)\n",
23 | "\n",
24 | "\n",
25 | "def mlp(sizes, activation, output_activation=nn.Identity):\n",
26 | " layers = []\n",
27 | " for j in range(len(sizes)-1):\n",
28 | " act = activation if j < len(sizes)-2 else output_activation\n",
29 | " layers += [nn.Linear(sizes[j], sizes[j+1]), act()]\n",
30 | " return nn.Sequential(*layers)\n",
31 | "\n",
32 | "\n",
33 | "def count_vars(module):\n",
34 | " return sum([np.prod(p.shape) for p in module.parameters()])\n",
35 | "\n",
36 | "\n",
37 | "def discount_cumsum(x, discount):\n",
38 | " \"\"\"\n",
39 | " magic from rllab for computing discounted cumulative sums of vectors.\n",
40 | " input: \n",
41 | " vector x, \n",
42 | " [x0, \n",
43 | " x1, \n",
44 | " x2]\n",
45 | " output:\n",
46 | " [x0 + discount * x1 + discount^2 * x2, \n",
47 | " x1 + discount * x2,\n",
48 | " x2]\n",
49 | " \"\"\"\n",
50 | " return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]\n",
51 | "\n",
52 | "\n",
53 | "class Actor(nn.Module):\n",
54 | "\n",
55 | " def _distribution(self, obs):\n",
56 | " raise NotImplementedError\n",
57 | "\n",
58 | " def _log_prob_from_distribution(self, pi, act):\n",
59 | " raise NotImplementedError\n",
60 | "\n",
61 | " def forward(self, obs, act=None):\n",
62 | " # Produce action distributions for given observations, and \n",
63 | " # optionally compute the log likelihood of given actions under\n",
64 | " # those distributions.\n",
65 | " pi = self._distribution(obs)\n",
66 | " logp_a = None\n",
67 | " if act is not None:\n",
68 | " logp_a = self._log_prob_from_distribution(pi, act)\n",
69 | " return pi, logp_a\n",
70 | "\n",
71 | "\n",
72 | "class MLPCategoricalActor(Actor):\n",
73 | " \n",
74 | " def __init__(self, obs_dim, act_dim, hidden_sizes, activation):\n",
75 | " super().__init__()\n",
76 | " self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)\n",
77 | "\n",
78 | " def _distribution(self, obs):\n",
79 | " logits = self.logits_net(obs)\n",
80 | " return Categorical(logits=logits)\n",
81 | "\n",
82 | " def _log_prob_from_distribution(self, pi, act):\n",
83 | " return pi.log_prob(act)\n",
84 | "\n",
85 | "\n",
86 | "class MLPGaussianActor(Actor):\n",
87 | "\n",
88 | " def __init__(self, obs_dim, act_dim, hidden_sizes, activation):\n",
89 | " super().__init__()\n",
90 | " log_std = -0.5 * np.ones(act_dim, dtype=np.float32)\n",
91 | " self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))\n",
92 | " self.mu_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)\n",
93 | "\n",
94 | " def _distribution(self, obs):\n",
95 | " mu = self.mu_net(obs)\n",
96 | " std = torch.exp(self.log_std)\n",
97 | " return Normal(mu, std)\n",
98 | "\n",
99 | " def _log_prob_from_distribution(self, pi, act):\n",
100 | " return pi.log_prob(act).sum(axis=-1) # Last axis sum needed for Torch Normal distribution\n",
101 | "\n",
102 | "\n",
103 | "class MLPCritic(nn.Module):\n",
104 | "\n",
105 | " def __init__(self, obs_dim, hidden_sizes, activation):\n",
106 | " super().__init__()\n",
107 | " self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation)\n",
108 | "\n",
109 | " def forward(self, obs):\n",
110 | " return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape.\n",
111 | "\n",
112 | "\n",
113 | "\n",
114 | "class MLPActorCritic(nn.Module):\n",
115 | "\n",
116 | "\n",
117 | " def __init__(self, observation_space, action_space, \n",
118 | " hidden_sizes=(64,64), activation=nn.Tanh):\n",
119 | " super().__init__()\n",
120 | "\n",
121 | " obs_dim = observation_space.shape[0]\n",
122 | "\n",
123 | " # policy builder depends on action space\n",
124 | " if isinstance(action_space, Box):\n",
125 | " self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation)\n",
126 | " elif isinstance(action_space, Discrete):\n",
127 | " self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)\n",
128 | "\n",
129 | " # build value function\n",
130 | " self.v = MLPCritic(obs_dim, hidden_sizes, activation)\n",
131 | "\n",
132 | " def step(self, obs):\n",
133 | " with torch.no_grad():\n",
134 | " pi = self.pi._distribution(obs)\n",
135 | " a = pi.sample()\n",
136 | " logp_a = self.pi._log_prob_from_distribution(pi, a)\n",
137 | " v = self.v(obs)\n",
138 | " return a.numpy(), v.numpy(), logp_a.numpy()\n",
139 | "\n",
140 | " def act(self, obs):\n",
141 | " return self.step(obs)[0]"
142 | ]
143 | }
144 | ],
145 | "metadata": {
146 | "kernelspec": {
147 | "display_name": "Python 3",
148 | "language": "python",
149 | "name": "python3"
150 | },
151 | "language_info": {
152 | "codemirror_mode": {
153 | "name": "ipython",
154 | "version": 3
155 | },
156 | "file_extension": ".py",
157 | "mimetype": "text/x-python",
158 | "name": "python",
159 | "nbconvert_exporter": "python",
160 | "pygments_lexer": "ipython3",
161 | "version": "3.6.10"
162 | }
163 | },
164 | "nbformat": 4,
165 | "nbformat_minor": 4
166 | }
167 |
--------------------------------------------------------------------------------
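A minimal usage sketch, not part of the repository, of the MLPActorCritic defined in the notebook above: it mirrors how PPO.ipynb and PPO.py call ac.step() during rollouts, and assumes the identical core.py module (shown next) is importable as core.

import gym
import torch
import core

env = gym.make('CartPole-v1')
ac = core.MLPActorCritic(env.observation_space, env.action_space, hidden_sizes=(64, 64))

o = env.reset()
a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))
print(a, v, logp)  # sampled action, value estimate, log-prob of the sampled action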
/src/core.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | # %%
4 | import numpy as np
5 | import scipy.signal
6 | import torch
7 | import torch.nn as nn
8 | from gym.spaces import Box, Discrete
9 | from torch.distributions.categorical import Categorical
10 | from torch.distributions.normal import Normal
11 |
12 |
13 | def combined_shape(length, shape=None):
14 | if shape is None:
15 | return (length,)
16 | return (length, shape) if np.isscalar(shape) else (length, *shape)
17 |
18 |
19 | def mlp(sizes, activation, output_activation=nn.Identity):
20 | layers = []
21 | for j in range(len(sizes) - 1):
22 | act = activation if j < len(sizes) - 2 else output_activation
23 | layers += [nn.Linear(sizes[j], sizes[j + 1]), act()]
24 | return nn.Sequential(*layers)
25 |
26 |
27 | def count_vars(module):
28 | return sum([np.prod(p.shape) for p in module.parameters()])
29 |
30 |
31 | def discount_cumsum(x, discount):
32 | """
33 | magic from rllab for computing discounted cumulative sums of vectors.
34 | input:
35 | vector x,
36 | [x0,
37 | x1,
38 | x2]
39 | output:
40 | [x0 + discount * x1 + discount^2 * x2,
41 | x1 + discount * x2,
42 | x2]
43 | """
44 | return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
45 |
46 |
47 | class Actor(nn.Module):
48 |
49 | def _distribution(self, obs):
50 | raise NotImplementedError
51 |
52 | def _log_prob_from_distribution(self, pi, act):
53 | raise NotImplementedError
54 |
55 | def forward(self, obs, act=None):
56 | # Produce action distributions for given observations, and
57 | # optionally compute the log likelihood of given actions under
58 | # those distributions.
59 | pi = self._distribution(obs)
60 | logp_a = None
61 | if act is not None:
62 | logp_a = self._log_prob_from_distribution(pi, act)
63 | return pi, logp_a
64 |
65 |
66 | class MLPCategoricalActor(Actor):
67 |
68 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
69 | super().__init__()
70 | self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)
71 |
72 | def _distribution(self, obs):
73 | logits = self.logits_net(obs)
74 | return Categorical(logits=logits)
75 |
76 | def _log_prob_from_distribution(self, pi, act):
77 | return pi.log_prob(act)
78 |
79 |
80 | class MLPGaussianActor(Actor):
81 |
82 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
83 | super().__init__()
84 | log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
85 | self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
86 | self.mu_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)
87 |
88 | def _distribution(self, obs):
89 | mu = self.mu_net(obs)
90 | std = torch.exp(self.log_std)
91 | return Normal(mu, std)
92 |
93 | def _log_prob_from_distribution(self, pi, act):
94 | return pi.log_prob(act).sum(axis=-1) # Last axis sum needed for Torch Normal distribution
95 |
96 |
97 | class MLPCritic(nn.Module):
98 |
99 | def __init__(self, obs_dim, hidden_sizes, activation):
100 | super().__init__()
101 | self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation)
102 |
103 | def forward(self, obs):
104 | return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape.
105 |
106 |
107 | class MLPActorCritic(nn.Module):
108 |
109 | def __init__(self, observation_space, action_space, hidden_sizes=(64, 64), activation=nn.Tanh):
110 | super().__init__()
111 |
112 | obs_dim = observation_space.shape[0]
113 |
114 | # policy builder depends on action space
115 | if isinstance(action_space, Box):
116 | self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation)
117 | elif isinstance(action_space, Discrete):
118 | self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)
119 |
120 | # build value function
121 | self.v = MLPCritic(obs_dim, hidden_sizes, activation)
122 |
123 | def step(self, obs):
124 | with torch.no_grad():
125 | pi = self.pi._distribution(obs)
126 | a = pi.sample()
127 | logp_a = self.pi._log_prob_from_distribution(pi, a)
128 | v = self.v(obs)
129 | return a.numpy(), v.numpy(), logp_a.numpy()
130 |
131 | def act(self, obs):
132 | return self.step(obs)[0]
133 |
--------------------------------------------------------------------------------
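An equivalent explicit-loop version of core.discount_cumsum, not from the repository, which can be used to sanity-check the scipy.signal.lfilter one-liner above.

import numpy as np
import scipy.signal


def discount_cumsum_loop(x, discount):
    # c[i] = x[i] + discount * c[i+1], computed backwards over the vector
    out = np.zeros(len(x), dtype=np.float64)
    running = 0.0
    for i in reversed(range(len(x))):
        running = x[i] + discount * running
        out[i] = running
    return out


x = np.array([1.0, 2.0, 3.0])
print(discount_cumsum_loop(x, 0.9))                                  # [5.23, 4.7, 3.0]
print(scipy.signal.lfilter([1], [1, -0.9], x[::-1], axis=0)[::-1])   # same values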
/src/ppo_discrete_step.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | # %%
4 | import random
5 | from collections import deque
6 | from copy import deepcopy
7 |
8 | import gym
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | from torch.distributions import Categorical
15 | from torch.utils.data import DataLoader
16 | from IPython.display import clear_output
17 |
18 |
19 | # %%
20 | SEED = 5
21 | BATCH_SIZE = 256
22 | LR = 0.01
23 | EPOCHS = 10
24 | CLIP = 0.2
25 | GAMMA = 0.999
26 | LAMBDA = 0.98
27 | ENT_COEF = 0.01
28 | V_COEF = 0.5
29 | V_CLIP = True
30 | LIN_REDUCE = False
31 | GRAD_NORM = False
32 | # set device
33 | use_cuda = torch.cuda.is_available()
34 | print('cuda:', use_cuda)
35 | device = torch.device('cuda' if use_cuda else 'cpu')
36 |
37 | # random seed
38 | random.seed(SEED)
39 | np.random.seed(SEED)
40 | torch.manual_seed(SEED)
41 | if use_cuda:
42 | torch.cuda.manual_seed_all(SEED)
43 |
44 |
45 | # %%
46 | class ActorCriticNet(nn.Module):
47 | def __init__(self, obs_space, action_space):
48 | super().__init__()
49 | h = 32
50 | self.head = nn.Sequential(
51 | nn.Linear(obs_space, h),
52 | nn.Tanh()
53 | )
54 | self.pol = nn.Sequential(
55 | nn.Linear(h, h),
56 | nn.Tanh(),
57 | nn.Linear(h, action_space)
58 | )
59 | self.val = nn.Sequential(
60 | nn.Linear(h, h),
61 | nn.Tanh(),
62 | nn.Linear(h, 1)
63 | )
64 | self.log_softmax = nn.LogSoftmax(dim=-1)
65 |
66 | def forward(self, x):
67 | out = self.head(x)
68 | logit = self.pol(out).reshape(out.shape[0], -1)
69 | log_p = self.log_softmax(logit)
70 | v = self.val(out).reshape(out.shape[0], 1)
71 |
72 | return log_p, v
73 |
74 |
75 | # %%
76 | losses = []
77 |
78 |
79 | def learn(net, old_net, optimizer, train_memory):
80 | global CLIP, LR
81 | global total_epochs
82 | net.train()
83 | old_net.train()
84 |
85 | for epoch in range(EPOCHS):
86 | if LIN_REDUCE:
87 | lr = LR - (LR * epoch / total_epochs)
88 | clip = CLIP - (CLIP * epoch / total_epochs)
89 | else:
90 | lr = LR
91 | clip = CLIP
92 |
93 | for param_group in optimizer.param_groups:
94 | param_group['lr'] = lr
95 |
96 | dataloader = DataLoader(
97 | train_memory,
98 | shuffle=True,
99 | batch_size=BATCH_SIZE,
100 | pin_memory=use_cuda
101 | )
102 |
103 | for (s, a, ret, adv) in dataloader:
104 | s_batch = s.to(device).float()
105 | a_batch = a.to(device).long()
106 | ret_batch = ret.to(device).float()
107 | ret_batch = (ret_batch - ret_batch.mean()) / (ret_batch.std() + 1e-6)
108 | adv_batch = adv.to(device).float()
109 | adv_batch = (adv_batch - adv_batch.mean()) / (adv_batch.std() + 1e-6)
110 | with torch.no_grad():
111 | log_p_batch_old, v_batch_old = old_net(s_batch)
112 |                 log_p_acting_old = log_p_batch_old[range(len(a_batch)), a_batch]
113 |
114 | log_p_batch, v_batch = net(s_batch)
115 |             log_p_acting = log_p_batch[range(len(a_batch)), a_batch]
116 | p_ratio = (log_p_acting - log_p_acting_old).exp()
117 | p_ratio_clip = torch.clamp(p_ratio, 1 - clip, 1 + clip)
118 | p_loss = torch.min(p_ratio * adv_batch,
119 | p_ratio_clip * adv_batch).mean()
120 | if V_CLIP:
121 | v_clip = v_batch_old + torch.clamp(v_batch - v_batch_old, -clip, clip)
122 | v_loss1 = (ret_batch - v_clip).pow(2)
123 | v_loss2 = (ret_batch - v_batch).pow(2)
124 | v_loss = 0.5 * torch.max(v_loss1, v_loss2).mean()
125 | else:
126 | v_loss = 0.5 * (ret_batch - v_batch).pow(2).mean()
127 |
128 | log_p, _ = net(s_batch)
129 | entropy = -(log_p.exp() * log_p).sum(dim=1).mean()
130 |
131 | # loss
132 | loss = -(p_loss - V_COEF * v_loss + ENT_COEF * entropy)
133 | losses.append(loss.item())
134 |
135 | optimizer.zero_grad()
136 | loss.backward()
137 | if GRAD_NORM:
138 | nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
139 | optimizer.step()
140 | train_memory.clear()
141 |
142 |
143 | def get_action_and_value(obs, old_net):
144 | old_net.eval()
145 | with torch.no_grad():
146 | state = torch.tensor([obs]).to(device).float()
147 | log_p, v = old_net(state)
148 | m = Categorical(log_p.exp())
149 | action = m.sample()
150 |
151 | return action.item(), v.item()
152 |
153 |
154 | def compute_adv_with_gae(rewards, values, roll_memory):
155 | rew = np.array(rewards, 'float')
156 | val = np.array(values[:-1], 'float')
157 | _val = np.array(values[1:], 'float')
158 | delta = rew + GAMMA * _val - val
159 | dis_r = np.array([GAMMA**(i) * r for i, r in enumerate(rewards)], 'float')
160 | gae_dt = np.array([(GAMMA * LAMBDA)**(i) * dt for i,
161 | dt in enumerate(delta.tolist())], 'float')
162 | for i, data in enumerate(roll_memory):
163 | data.append(sum(dis_r[i:] / GAMMA**(i)))
164 | data.append(sum(gae_dt[i:] / (GAMMA * LAMBDA)**(i)))
165 |
166 | rewards.clear()
167 | values.clear()
168 |
169 | return roll_memory
170 |
171 |
172 | def plot():
173 | clear_output(True)
174 | plt.figure(figsize=(16, 5))
175 | plt.subplot(121)
176 | plt.plot(ep_rewards, alpha=0.5)
177 | plt.subplot(121)
178 | plt.plot(reward_eval)
179 | plt.title(f'Reward: '
180 | f'{reward_eval[-1]}')
181 | plt.subplot(122)
182 | plt.plot(losses, alpha=0.5)
183 | plt.title(f'Loss: '
184 | f'{np.mean(list(reversed(losses))[: n_eval]).round(decimals=2)}')
185 | plt.show()
186 |
187 |
188 | # %%
189 | # make an environment
190 | # env = gym.make('CartPole-v0')
191 | env = gym.make('CartPole-v1')
192 | # env = gym.make('MountainCar-v0')
193 | # env = gym.make('LunarLander-v2')
194 |
195 | env.seed(SEED)
196 | obs_space = env.observation_space.shape[0]
197 | action_space = env.action_space.n
198 |
199 | # hyperparameter
200 | n_episodes = 100000
201 | roll_len = 2048
202 | total_epochs = roll_len // BATCH_SIZE
203 | n_eval = 10
204 |
205 | # global values
206 | steps = 0
207 | ep_rewards = []
208 | reward_eval = []
209 | is_rollout = False
210 | is_solved = False
211 |
212 | # make memories
213 | train_memory = []
214 | roll_memory = []
215 | rewards = []
216 | values = []
217 |
218 | # make neural networks
219 | net = ActorCriticNet(obs_space, action_space).to(device)
220 | old_net = deepcopy(net)
221 | # no_decay = ['bias']
222 | # grouped_parameters = [
223 | # {'params': [p for n, p in net.named_parameters() if not any(
224 | # nd in n for nd in no_decay)], 'weight_decay': 0.0},
225 | # {'params': [p for n, p in net.named_parameters() if any(
226 | # nd in n for nd in no_decay)], 'weight_decay': 0.0}
227 | # ]
228 | optimizer = torch.optim.AdamW(net.parameters(), lr=LR, eps=1e-6)
229 |
230 | # play!
231 | for i in range(1, n_episodes + 1):
232 | obs = env.reset()
233 | done = False
234 | ep_reward = 0
235 | while not done:
236 | env.render()
237 | action, value = get_action_and_value(obs, old_net)
238 | _obs, reward, done, _ = env.step(action)
239 |
240 | # store
241 | roll_memory.append([obs, action])
242 | rewards.append(reward)
243 | values.append(value)
244 |
245 | obs = _obs
246 | steps += 1
247 | ep_reward += reward
248 |
249 | if done or steps % roll_len == 0:
250 | if done:
251 | _value = 0.
252 | else:
253 | _, _value = get_action_and_value(_obs, old_net)
254 |
255 | values.append(_value)
256 | train_memory.extend(compute_adv_with_gae(
257 | rewards, values, roll_memory))
258 | roll_memory.clear()
259 |
260 | if steps % roll_len == 0:
261 | learn(net, old_net, optimizer, train_memory)
262 | old_net.load_state_dict(net.state_dict())
263 |
264 | if done:
265 | ep_rewards.append(ep_reward)
266 | reward_eval.append(
267 | np.mean(list(reversed(ep_rewards))[: n_eval]).round(decimals=2))
268 | # plot()
269 | print('{:3} Episode in {:5} steps, reward {:.2f}'.format(
270 | i, steps, ep_reward))
271 |
272 | if len(ep_rewards) >= n_eval:
273 | if reward_eval[-1] >= env.spec.reward_threshold:
274 | # if reward_eval[-1] >= 495:
275 | print('\n{} is solved! {:3} Episode in {:3} steps'.format(
276 | env.spec.id, i, steps))
277 | torch.save(net.state_dict(),
278 | f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_ppo_st.pt')
279 | break
280 | env.close()
281 |
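# Editor's sketch (illustrative, so it is left commented out): reloading the
# checkpoint saved above and running one greedy evaluation episode.
# eval_net = ActorCriticNet(obs_space, action_space).to(device)
# eval_net.load_state_dict(torch.load(
#     f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_ppo_st.pt'))
# eval_net.eval()
# obs, done, total = env.reset(), False, 0.
# while not done:
#     with torch.no_grad():
#         log_p, _ = eval_net(torch.tensor([obs]).to(device).float())
#     obs, reward, done, _ = env.step(log_p.argmax(dim=-1).item())
#     total += reward
# print('greedy evaluation reward:', total)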
282 |
283 | # %%
284 | print(env.spec.max_episode_steps)
285 |
286 |
287 | # %%
288 | # [
289 | # ('CartPole-v0', 889, 2048, 0.2, 10, 0.5, 0.01, False, 0.999, 0.98),
290 | # ('CartPole-v1', 801, 2048, 0.2, 10, 0.5, 0.01, False, 0.999, 0.98),
291 | # ('MountainCar-v0', None),
292 | # ('LunarLander-v2', 876, 2048, 0.2, 10, 1.0, 0.01, False, 0.999, 0.98)
293 | # ]
294 |
295 |
--------------------------------------------------------------------------------
/src/ppo_discrete_step_parallel.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from collections import deque
3 | from copy import deepcopy
4 |
5 | import gym
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.multiprocessing as mp
11 | from torch.distributions import Categorical
12 | from torch.utils.data import DataLoader
13 | from torch.utils.tensorboard import SummaryWriter
14 | from running_mean_std import RunningMeanStd
15 |
16 | N_PROCESS = 4
17 | ROLL_LEN = 2048 * N_PROCESS
18 | BATCH_SIZE = 2048
19 | P_LR = 3e-4
20 | V_LR = 1e-3
21 | ITER = 80
22 | CLIP = 0.2
23 | GAMMA = 0.999
24 | LAMBDA = 0.97
25 | # BETA = 3.0
26 | # ENT_COEF = 0.0
27 | GRAD_NORM = False
28 | OBS_NORM = True
29 |
30 | # set device
31 | use_cuda = torch.cuda.is_available()
32 | print('cuda:', use_cuda)
33 | device = torch.device('cuda:0' if use_cuda else 'cpu')
34 | writer = SummaryWriter()
35 |
36 | # random seed
37 | torch.manual_seed(5)
38 | if use_cuda:
39 | torch.cuda.manual_seed_all(5)
40 |
41 | # make an environment
42 | # env = gym.make('CartPole-v0')
43 | # env = gym.make('CartPole-v1')
44 | env = gym.make('MountainCar-v0')
45 | # env = gym.make('LunarLander-v2')
46 |
47 |
48 | class ActorCriticNet(nn.Module):
49 | def __init__(self, obs_space, action_space):
50 | super().__init__()
51 | h = 32
52 | self.pol = nn.Sequential(
53 | nn.Linear(obs_space, h),
54 | nn.Tanh(),
55 | nn.Linear(h, h),
56 | nn.Tanh(),
57 | nn.Linear(h, action_space)
58 | )
59 | self.val = nn.Sequential(
60 | nn.Linear(obs_space, h),
61 | nn.Tanh(),
62 | nn.Linear(h, h),
63 | nn.Tanh(),
64 | nn.Linear(h, 1)
65 | )
66 | self.log_softmax = nn.LogSoftmax(dim=-1)
67 |
68 | def forward(self, x):
69 | logit = self.pol(x).reshape(x.shape[0], -1)
70 | log_p = self.log_softmax(logit)
71 | v = self.val(x).reshape(x.shape[0], 1)
72 | return log_p, v
73 |
74 |
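# Editor's sketch (illustrative only, commented out so the script is unchanged):
# sanity-check that the policy head returns a valid log-probability distribution
# and the value head returns shape (N, 1).
# _net = ActorCriticNet(env.observation_space.shape[0], env.action_space.n).to(device)
# _x = torch.randn(8, env.observation_space.shape[0]).to(device)
# _log_p, _v = _net(_x)
# assert _v.shape == (8, 1)
# assert torch.allclose(_log_p.exp().sum(dim=-1), torch.ones(8, device=device))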
75 | # +
76 | def learn(net, optimizer, train_memory):
77 | global steps
78 | old_net = deepcopy(net)
79 | net.train()
80 | old_net.train()
81 | dataloader = DataLoader(
82 | train_memory,
83 | shuffle=False,
84 | batch_size=BATCH_SIZE,
85 | pin_memory=use_cuda
86 | )
87 | advs = []
88 | for data in dataloader.dataset:
89 | advs.append(data[3])
90 | advs = torch.tensor(advs, dtype=torch.float32).to(device)
91 | adv_mean = advs.mean()
92 | adv_std = advs.std()
93 | for _ in range(ITER):
94 | for (s, a, ret, adv) in dataloader:
95 | s_batch = s.to(device).float()
96 | a_batch = a.to(device).long()
97 | ret_batch = ret.to(device).float()
98 | adv_batch = adv.to(device).float()
99 | adv_batch = (adv_batch - adv_mean) / adv_std
100 | with torch.no_grad():
101 | log_p_batch_old, v_batch_old = old_net(s_batch)
102 | log_p_acting_old = log_p_batch_old[range(BATCH_SIZE), a_batch]
103 |
104 | log_p_batch, v_batch = net(s_batch)
105 | log_p_acting = log_p_batch[range(BATCH_SIZE), a_batch]
106 | p_ratio = (log_p_acting - log_p_acting_old).exp()
107 | p_ratio_clip = torch.clamp(p_ratio, 1 - CLIP, 1 + CLIP)
108 | p_loss = -(torch.min(p_ratio * adv_batch, p_ratio_clip * adv_batch).mean())
109 | v_loss = (ret_batch - v_batch).pow(2).mean()
110 |
111 | kl_div = ((log_p_batch.exp() * (log_p_batch - log_p_batch_old)).sum(dim=-1).mean()).detach().item()
112 |
113 | # log_p, _ = net(s_batch)
114 | # entropy = -(log_p.exp() * log_p).sum(dim=1).mean()
115 |
116 | # loss
117 | loss = p_loss + v_loss
118 |
119 | if kl_div <= 0.01 * 1.5:
120 | optimizer[0].zero_grad()
121 | p_loss.backward()
122 | if GRAD_NORM:
123 | nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
124 | optimizer[0].step()
125 | else:
126 | print("KL divergence too large; skipping the policy update")
127 | optimizer[1].zero_grad()
128 | v_loss.backward()
129 | if GRAD_NORM:
130 | nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
131 | optimizer[1].step()
132 | if rank == 0:
133 | writer.add_scalar('data/p_loss', p_loss.item(), steps)
134 | writer.add_scalar('data/v_loss', v_loss.item(), steps)
135 | writer.add_scalar('data/kl_div',kl_div, steps)
136 | steps += 1
137 | train_memory.clear()
138 | return net.state_dict()
139 |
140 |
141 | # -
142 |
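# Editor's note (sketch, defined but never called): learn() computes the exact
# discrete KL(pi_new || pi_old) = sum_a pi_new(a|s) * (log pi_new - log pi_old)
# and skips the policy step once it exceeds 1.5 * 0.01, a target-KL
# early-stopping heuristic. When only the taken action's log-prob is available,
# a common sampled estimate is the helper below (assumption: illustrative, not
# the author's code).
def _sampled_approx_kl(log_p_act_old, log_p_act):
    # mean of (log pi_old(a|s) - log pi_new(a|s)) over the sampled actions
    return (log_p_act_old - log_p_act).mean().item()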
143 | def get_action_and_value(obs, old_net):
144 | old_net.eval()
145 | with torch.no_grad():
146 | state = torch.tensor([obs]).to(device).float()
147 | log_p, v = old_net(state)
148 | m = Categorical(log_p.exp())
149 | action = m.sample()
150 | return action.item(), v.item()
151 |
152 |
153 | def compute_adv_with_gae(rewards, values, roll_memory):
154 | rew = np.array(rewards, 'float')
155 | val = np.array(values[:-1], 'float')
156 | _val = np.array(values[1:], 'float')
157 | ret = rew + GAMMA * _val
158 | delta = ret - val
159 | gae_dt = np.array([(GAMMA * LAMBDA)**(i) * dt for i, dt in enumerate(delta.tolist())], 'float')
160 | for i, data in enumerate(roll_memory):
161 | data.append(ret[i])
162 | data.append(sum(gae_dt[i:] / (GAMMA * LAMBDA)**(i)))
163 |
164 | rewards.clear()
165 | values.clear()
166 | return roll_memory
167 |
168 |
169 | def roll_out(env, length, rank, child):
170 | env.seed(rank)
171 |
172 | # hyperparameter
173 | roll_len = length
174 |
175 | # for play
176 | episodes = 0
177 | ep_steps = 0
178 | ep_rewards = []
179 |
180 | # memories
181 | train_memory = []
182 | roll_memory = []
183 | rewards = []
184 | values = []
185 |
186 | # receive the initial network and obs-normalization stats from the parent
187 | old_net, norm_obs = child.recv()
188 |
189 | # Play!
190 | while True:
191 | obs = env.reset()
192 | done = False
193 | ep_reward = 0
194 | while not done:
195 | # env.render()
196 | if OBS_NORM:
197 | norm_obs.update(np.array([obs]))
198 | obs_norm = np.clip((obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)
199 | action, value = get_action_and_value(obs_norm, old_net)
200 | else:
201 | action, value = get_action_and_value(obs, old_net)
202 |
203 | # step
204 | _obs, reward, done, _ = env.step(action)
205 |
206 | # store
207 | values.append(value)
208 |
209 | if OBS_NORM:
210 | roll_memory.append([obs_norm, action])
211 | else:
212 | roll_memory.append([obs, action])
213 |
214 | rewards.append(reward)
215 | obs = _obs
216 | ep_reward += reward
217 | ep_steps += 1
218 |
219 | if done or ep_steps % roll_len == 0:
220 | if OBS_NORM:
221 | norm_obs.update(np.array([_obs]))
222 | if done:
223 | _value = 0.
224 | else:
225 | if OBS_NORM:
226 | _obs_norm = np.clip((_obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)
227 | _, _value = get_action_and_value(_obs_norm, old_net)
228 | else:
229 | _, _value = get_action_and_value(_obs, old_net)
230 |
231 | values.append(_value)
232 | train_memory.extend(compute_adv_with_gae(rewards, values, roll_memory))
233 | roll_memory.clear()
234 |
235 | if ep_steps % roll_len == 0:
236 | child.send(((train_memory, ep_rewards), 'train', rank))
237 | train_memory.clear()
238 | ep_rewards.clear()
239 | state_dict, norm_obs = child.recv()
240 | old_net.load_state_dict(state_dict)
241 | break
242 |
243 | if done:
244 | episodes += 1
245 | ep_rewards.append(ep_reward)
246 | print('{:3} Episode in {:4} steps, reward {:4} [Process-{}]'.format(episodes, ep_steps, ep_reward, rank))
247 |
248 |
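# Editor's note: message protocol between each roll_out worker and the parent
# process, as used above and in the __main__ block below:
#   parent -> worker : (net_or_state_dict, norm_obs)
#   worker -> parent : ((train_memory, ep_rewards), 'train', rank)
# After every ROLL_LEN // N_PROCESS environment steps the worker blocks in
# child.recv() until the parent has finished learn() and broadcast new weights.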
249 | if __name__ == '__main__':
250 | mp.set_start_method('spawn')
251 | obs_space = env.observation_space.shape[0]
252 | action_space = env.action_space.n
253 | n_eval = 100
254 | net = ActorCriticNet(obs_space, action_space).to(device)
255 | net.share_memory()
256 | param_p = [p for n, p in net.named_parameters() if 'pol' in n]
257 | param_v = [p for n, p in net.named_parameters() if 'val' in n]
258 | optim_p = torch.optim.AdamW(param_p, lr=P_LR, eps=1e-6)
259 | optim_v = torch.optim.AdamW(param_v, lr=V_LR, eps=1e-6)
260 | optimizer = [optim_p, optim_v]
261 | norm_obs = RunningMeanStd(shape=env.observation_space.shape)
262 |
263 | jobs = []
264 | pipes = []
265 | trajectory = []
266 | rewards = deque(maxlen=n_eval)
267 | update = 0
268 | steps = 0
269 | for i in range(N_PROCESS):
270 | parent, child = mp.Pipe()
271 | p = mp.Process(target=roll_out, args=(env, ROLL_LEN//N_PROCESS, i, child), daemon=True)
272 | jobs.append(p)
273 | pipes.append(parent)
274 |
275 | for i in range(N_PROCESS):
276 | pipes[i].send((net, norm_obs))
277 | jobs[i].start()
278 |
279 | while True:
280 | for i in range(N_PROCESS):
281 | data, msg, rank = pipes[i].recv()
282 | if msg == 'train':
283 | traj, ep_rews = data
284 | trajectory.extend(traj)
285 | rewards.extend(ep_rews)
286 |
287 | if len(rewards) == n_eval:
288 | writer.add_scalar('data/reward', np.mean(rewards), update)
289 | if np.mean(rewards) >= env.spec.reward_threshold:
290 | print('\n{} is solved! [{} Update]'.format(env.spec.id, update))
291 | torch.save(net.state_dict(), f'./test/saved_models/{env.spec.id}_up{update}_clear_model_ppo_st.pt')
292 | with open(f'./test/saved_models/{env.spec.id}_up{update}_clear_norm_obs.pkl', 'wb') as f:
293 | pickle.dump(norm_obs, f, pickle.HIGHEST_PROTOCOL)
294 | break
295 |
296 | if len(trajectory) == ROLL_LEN:
297 | state_dict = learn(net, optimizer, trajectory)
298 | update += 1
299 | print(f'Update: {update}')
300 | for i in range(N_PROCESS):
301 | pipes[i].send((state_dict, norm_obs))
302 | env.close()
303 |
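# Editor's sketch (illustrative, commented out): when OBS_NORM is on, an
# evaluation run has to reload both the saved weights and the pickled
# RunningMeanStd statistics, and apply the same clipping used during training.
# with open(f'./test/saved_models/{env.spec.id}_up{update}_clear_norm_obs.pkl', 'rb') as f:
#     norm_obs = pickle.load(f)
# net.load_state_dict(torch.load(
#     f'./test/saved_models/{env.spec.id}_up{update}_clear_model_ppo_st.pt'))
# obs = env.reset()
# obs_norm = np.clip((obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)
# action, _ = get_action_and_value(obs_norm, net)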
--------------------------------------------------------------------------------
/src/ppo_discrete_step_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "colab": {},
8 | "colab_type": "code",
9 | "id": "IWnm3qot3o1W"
10 | },
11 | "outputs": [],
12 | "source": [
13 | "import random\n",
14 | "from collections import deque\n",
15 | "from copy import deepcopy\n",
16 | "\n",
17 | "import gym\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "import numpy as np\n",
20 | "import torch\n",
21 | "import torch.nn as nn\n",
22 | "import torch.nn.functional as F\n",
23 | "import torch.optim as optim\n",
24 | "from torch.distributions import Categorical\n",
25 | "from torch.utils.data import DataLoader\n",
26 | "from IPython.display import clear_output\n",
27 | "\n",
28 | "from running_mean_std import RunningMeanStd"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {
35 | "colab": {},
36 | "colab_type": "code",
37 | "id": "IWnm3qot3o1W"
38 | },
39 | "outputs": [],
40 | "source": [
41 | "SEED = 5\n",
42 | "BATCH_SIZE = 2048\n",
43 | "P_LR = 3e-4\n",
44 | "V_LR = 1e-3\n",
45 | "ITER = 80\n",
46 | "CLIP = 0.2\n",
47 | "GAMMA = 0.999\n",
48 | "LAMBDA = 0.97\n",
49 | "BETA = 3.0\n",
50 | "# ENT_COEF = 0.0\n",
51 | "GRAD_NORM = False\n",
52 | "OBS_NORM = True\n",
53 | "VIEW_CURVE = False\n",
54 | "\n",
55 | "# set device\n",
56 | "use_cuda = torch.cuda.is_available()\n",
57 | "device = torch.device('cuda' if use_cuda else 'cpu')\n",
58 | "\n",
59 | "# random seed\n",
60 | "random.seed(SEED)\n",
61 | "np.random.seed(SEED)\n",
62 | "torch.manual_seed(SEED)\n",
63 | "if use_cuda:\n",
64 | " torch.cuda.manual_seed_all(SEED)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {
71 | "colab": {},
72 | "colab_type": "code",
73 | "id": "IWnm3qot3o1W"
74 | },
75 | "outputs": [],
76 | "source": [
77 | "class ActorCriticNet(nn.Module):\n",
78 | " def __init__(self, obs_space, action_space):\n",
79 | " super().__init__()\n",
80 | " h = 32\n",
81 | " self.pol = nn.Sequential(\n",
82 | " nn.Linear(obs_space, h),\n",
83 | " nn.Tanh(),\n",
84 | " nn.Linear(h, h),\n",
85 | " nn.Tanh(),\n",
86 | " nn.Linear(h, action_space)\n",
87 | " )\n",
88 | " self.val = nn.Sequential(\n",
89 | " nn.Linear(obs_space, h),\n",
90 | " nn.Tanh(),\n",
91 | " nn.Linear(h, h),\n",
92 | " nn.Tanh(),\n",
93 | " nn.Linear(h, 1)\n",
94 | " )\n",
95 | " self.log_softmax = nn.LogSoftmax(dim=-1)\n",
96 | "\n",
97 | " def forward(self, x):\n",
98 | " logit = self.pol(x).reshape(x.shape[0], -1)\n",
99 | " log_p = self.log_softmax(logit)\n",
100 | " v = self.val(x).reshape(x.shape[0], 1)\n",
101 | " return log_p, v"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 4,
107 | "metadata": {
108 | "colab": {},
109 | "colab_type": "code",
110 | "id": "IWnm3qot3o1W"
111 | },
112 | "outputs": [],
113 | "source": [
114 | "losses = []\n",
115 | "kl_divs = []\n",
116 | "\n",
117 | "def learn(net, old_net, optimizer, train_memory):\n",
118 | " net.train()\n",
119 | " old_net.train()\n",
120 | " dataloader = DataLoader(\n",
121 | " train_memory,\n",
122 | " shuffle=False,\n",
123 | " batch_size=BATCH_SIZE,\n",
124 | " pin_memory=True,\n",
125 | " num_workers=0,\n",
126 | " )\n",
127 | " for _ in range(ITER): \n",
128 | " for i, (s, a, ret, adv) in enumerate(dataloader):\n",
129 | " s = s.to(device).float()\n",
130 | " a = a.to(device).long()\n",
131 | " ret = ret.to(device).float()\n",
132 | " adv = adv.to(device).float()\n",
133 | " \n",
134 | " with torch.no_grad():\n",
135 | " log_p_old, v_old = old_net(s)\n",
136 | " log_p_act_old = log_p_old[range(BATCH_SIZE), a]\n",
137 | "\n",
138 | " log_p, v = net(s)\n",
139 | " log_p_act = log_p[range(BATCH_SIZE), a]\n",
140 | " p_ratio = (log_p_act - log_p_act_old).exp()\n",
141 | " p_ratio_clip = torch.clamp(p_ratio, 1 - CLIP, 1 + CLIP)\n",
142 | " p_loss = -(torch.min(p_ratio * adv, p_ratio_clip * adv).mean())\n",
143 | "# kl_div = (log_p.exp() * (log_p - log_p_old)).sum(dim=-1)\n",
144 | "# p_loss = -(p_ratio * adv - BETA * kl_div).mean()\n",
145 | " v_loss = (ret - v).pow(2).mean()\n",
146 | "# log_p, _ = net(s)\n",
147 | "# entropy = -(log_p.exp() * log_p).sum(dim=1).mean()\n",
148 | " # loss\n",
149 | "# loss = p_loss + v_loss - ENT_COEF * entropy\n",
150 | " kl_div = ((log_p.exp() * (log_p - log_p_old)).sum(dim=-1).mean()).detach().item()\n",
151 | " loss = p_loss + v_loss\n",
152 | " losses.append(loss.item())\n",
153 | " kl_divs.append(kl_div)\n",
154 | " if kl_div <= 0.01 * 1.5:\n",
155 | " optimizer[0].zero_grad()\n",
156 | " p_loss.backward()\n",
157 | " if GRAD_NORM:\n",
158 | " nn.utils.clip_grad_norm_(net.parameters() , max_norm=1.0)\n",
159 | " optimizer[0].step()\n",
160 | " else:\n",
161 | " if not VIEW_CURVE:\n",
162 | " print(\"Pass the Pi update!\")\n",
163 | " optimizer[1].zero_grad()\n",
164 | " v_loss.backward()\n",
165 | " if GRAD_NORM:\n",
166 | " nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)\n",
167 | " optimizer[1].step()\n",
168 | "\n",
169 | "\n",
170 | "def get_action_and_value(obs, old_net):\n",
171 | " old_net.eval()\n",
172 | " with torch.no_grad():\n",
173 | " state = torch.tensor([obs]).to(device).float()\n",
174 | " log_p, v = old_net(state)\n",
175 | " m = Categorical(log_p.exp())\n",
176 | " action = m.sample()\n",
177 | " return action.item(), v.item()\n",
178 | "\n",
179 | "\n",
180 | "def compute_adv_with_gae(rewards, values, roll_memory):\n",
181 | " rew = np.array(rewards, np.float32)\n",
182 | " val = np.array(values[:-1], np.float32)\n",
183 | " _val = np.array(values[1:], np.float32)\n",
184 | " delta = rew + GAMMA * _val - val\n",
185 | " dis_r = np.array([GAMMA**(i) * r for i, r in enumerate(rewards)], np.float32)\n",
186 | " gae_dt = np.array([(GAMMA * LAMBDA)**(i) * dt for i, dt in enumerate(delta.tolist())], np.float32)\n",
187 | " for i, data in enumerate(roll_memory):\n",
188 | " data.append(sum(dis_r[i:] / GAMMA**(i)))\n",
189 | " data.append(sum(gae_dt[i:] / (GAMMA * LAMBDA)**(i)))\n",
190 | "\n",
191 | " rewards.clear()\n",
192 | " values.clear()\n",
193 | " return roll_memory\n",
194 | "\n",
195 | "\n",
196 | "def plot():\n",
197 | " clear_output(True)\n",
198 | " plt.figure(figsize=(16, 5))\n",
199 | " plt.subplot(131)\n",
200 | " plt.plot(ep_rewards, alpha=0.5)\n",
201 | " plt.subplot(131)\n",
202 | " plt.plot(reward_eval)\n",
203 | " plt.title(f'Reward: {reward_eval[-1]:.0f}')\n",
204 | " if losses:\n",
205 | " plt.subplot(132)\n",
206 | " plt.plot(losses, alpha=0.5)\n",
207 | " plt.title(f'Loss: {losses[-1]:.2f}')\n",
208 | " plt.subplot(133)\n",
209 | " plt.plot(kl_divs, alpha=0.5)\n",
210 | " plt.title(f'kl_div: {kl_divs[-1]:.4f}')\n",
211 | " plt.show()"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 5,
217 | "metadata": {
218 | "colab": {},
219 | "colab_type": "code",
220 | "id": "IWnm3qot3o1W",
221 | "scrolled": false
222 | },
223 | "outputs": [
224 | {
225 | "name": "stdout",
226 | "output_type": "stream",
227 | "text": [
228 | "n_episodes: 625\n",
229 | "reward_threshold: 300\n"
230 | ]
231 | },
232 | {
233 | "name": "stderr",
234 | "output_type": "stream",
235 | "text": [
236 | "/workspace/media/ai/Storage/gym/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n",
237 | " warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n"
238 | ]
239 | },
240 | {
241 | "ename": "TypeError",
242 | "evalue": "new(): argument 'size' must be tuple of ints, but found element of type tuple at pos 2",
243 | "output_type": "error",
244 | "traceback": [
245 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
246 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
247 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;31m# make nerual networks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mActorCriticNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobs_space\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction_space\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0mold_net\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
248 | "\u001b[0;32m\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, obs_space, action_space)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m32\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self.pol = nn.Sequential(\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobs_space\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTanh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
249 | "\u001b[0;32m/opt/conda/lib/python3.6/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, in_features, out_features, bias)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0min_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0min_features\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_features\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParameter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParameter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
250 | "\u001b[0;31mTypeError\u001b[0m: new(): argument 'size' must be tuple of ints, but found element of type tuple at pos 2"
251 | ]
252 | }
253 | ],
254 | "source": [
255 | "# make an environment\n",
256 | "# env = gym.make('CartPole-v0')\n",
257 | "# env = gym.make('CartPole-v1')\n",
258 | "# env = gym.make('MountainCar-v0')\n",
259 | "# env = gym.make('LunarLander-v2')\n",
260 | "env = gym.make('BipedalWalker-v3')\n",
261 | "\n",
262 | "env.seed(SEED)\n",
263 | "obs_space = env.observation_space.shape\n",
264 | "action_space = env.action_space.shape\n",
265 | "\n",
266 | "# hyperparameter\n",
267 | "n_episodes = int(1e6 / env.spec.max_episode_steps)\n",
268 | "roll_len = 2048\n",
269 | "n_eval = 100\n",
270 | "\n",
271 | "print(f\"n_episodes: {n_episodes}\")\n",
272 | "print(f\"reward_threshold: {env.spec.reward_threshold}\")\n",
273 | "\n",
274 | "# global values\n",
275 | "steps = 0\n",
276 | "ep_rewards = []\n",
277 | "reward_eval = []\n",
278 | "is_rollout = False\n",
279 | "is_solved = False\n",
280 | "\n",
281 | "# make memories\n",
282 | "train_memory = []\n",
283 | "roll_memory = []\n",
284 | "rewards = []\n",
285 | "values = []\n",
286 | "norm_obs = RunningMeanStd(shape=env.observation_space.shape)\n",
287 | "\n",
288 | "# make neural networks\n",
289 | "net = ActorCriticNet(obs_space, action_space).to(device)\n",
290 | "old_net = deepcopy(net)\n",
291 | "\n",
292 | "param_p = [p for n, p in net.named_parameters() if 'pol' in n]\n",
293 | "param_v = [p for n, p in net.named_parameters() if 'val' in n]\n",
294 | "optim_p = torch.optim.AdamW(param_p, lr=P_LR, eps=1e-6, weight_decay=0.01)\n",
295 | "optim_v = torch.optim.AdamW(param_v, lr=V_LR, eps=1e-6, weight_decay=0.01)\n",
296 | "optimizer = [optim_p, optim_v]\n",
297 | "\n",
298 | "# play!\n",
299 | "for i in range(1, n_episodes + 1):\n",
300 | " obs = env.reset()\n",
301 | " done = False\n",
302 | " ep_reward = 0\n",
303 | " while not done:\n",
304 | "# env.render()\n",
305 | " if OBS_NORM:\n",
306 | " print(obs)\n",
307 | " norm_obs.update(np.array([obs]))\n",
308 | " obs_norm = np.clip((obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)\n",
309 | " action, value = get_action_and_value(obs_norm, old_net)\n",
310 | " else:\n",
311 | " action, value = get_action_and_value(obs, old_net)\n",
312 | " \n",
313 | " _obs, reward, done, _ = env.step(action)\n",
314 | "\n",
315 | " # store\n",
316 | " if OBS_NORM:\n",
317 | " roll_memory.append([obs_norm, action])\n",
318 | " else:\n",
319 | " roll_memory.append([obs, action])\n",
320 | " \n",
321 | "\n",
322 | " rewards.append(reward)\n",
323 | " values.append(value)\n",
324 | "\n",
325 | " obs = _obs\n",
326 | " steps += 1\n",
327 | " ep_reward += reward\n",
328 | "\n",
329 | " if done or steps % roll_len == 0:\n",
330 | " if OBS_NORM:\n",
331 | " norm_obs.update(np.array([_obs]))\n",
332 | " if done:\n",
333 | " _value = 0.\n",
334 | " else:\n",
335 | " if OBS_NORM:\n",
336 | " _obs_norm = np.clip((_obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)\n",
337 | " _, _value = get_action_and_value(_obs_norm, old_net)\n",
338 | " else:\n",
339 | " _, _value = get_action_and_value(_obs, old_net)\n",
340 | "\n",
341 | " values.append(_value)\n",
342 | " train_memory.extend(compute_adv_with_gae(rewards, values, roll_memory))\n",
343 | " roll_memory.clear()\n",
344 | "\n",
345 | " if steps % roll_len == 0:\n",
346 | " # adv normalize\n",
347 | " advs = []\n",
348 | " for m in train_memory:\n",
349 | " advs.append(m[3])\n",
350 | " advs = np.array(advs)\n",
351 | " for m in train_memory:\n",
352 | " m[3] = (m[3] - np.mean(advs)) / np.std(advs)\n",
353 | " \n",
354 | " learn(net, old_net, optimizer, train_memory)\n",
355 | " old_net.load_state_dict(net.state_dict())\n",
356 | " train_memory.clear()\n",
357 | " break\n",
358 | "\n",
359 | " if done:\n",
360 | " ep_rewards.append(ep_reward)\n",
361 | " reward_eval.append(np.mean(ep_rewards[-n_eval:]))\n",
362 | " if VIEW_CURVE:\n",
363 | " plot()\n",
364 | " else:\n",
365 | " print(f'{i:3} Episode in {steps:5} steps, reward {ep_reward:.2f}')\n",
366 | "\n",
367 | " if len(ep_rewards) >= n_eval:\n",
368 | " if reward_eval[-1] >= env.spec.reward_threshold:\n",
369 | " print('\\n{} is sloved! {:3} Episode in {:3} steps'.format(\n",
370 | " env.spec.id, i, steps))\n",
371 | " torch.save(net.state_dict(),\n",
372 | " f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_ppo_st.pt')\n",
373 | " break\n",
374 | "env.close()"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 1,
380 | "metadata": {},
381 | "outputs": [
382 | {
383 | "name": "stdout",
384 | "output_type": "stream",
385 | "text": [
386 | "[NbConvertApp] Converting notebook PPO.ipynb to script\n",
387 | "[NbConvertApp] Writing 8573 bytes to PPO.py\n"
388 | ]
389 | }
390 | ],
391 | "source": [
392 | "# !jupyter nbconvert --to script PPO.ipynb"
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": null,
398 | "metadata": {},
399 | "outputs": [],
400 | "source": []
401 | }
402 | ],
403 | "metadata": {
404 | "colab": {
405 | "collapsed_sections": [],
406 | "name": "C51_tensorflow.ipynb",
407 | "provenance": [],
408 | "version": "0.3.2"
409 | },
410 | "kernelspec": {
411 | "display_name": "Python 3",
412 | "language": "python",
413 | "name": "python3"
414 | },
415 | "language_info": {
416 | "codemirror_mode": {
417 | "name": "ipython",
418 | "version": 3
419 | },
420 | "file_extension": ".py",
421 | "mimetype": "text/x-python",
422 | "name": "python",
423 | "nbconvert_exporter": "python",
424 | "pygments_lexer": "ipython3",
425 | "version": "3.6.10"
426 | }
427 | },
428 | "nbformat": 4,
429 | "nbformat_minor": 1
430 | }
431 |
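Editor's note: the TypeError recorded in this notebook comes from passing
env.observation_space.shape (a tuple) straight into nn.Linear. A minimal sketch
of the observation-side fix is below; note that BipedalWalker-v3 also has a
continuous (Box) action space, so the Categorical policy in this discrete
implementation would additionally need a continuous policy head.

obs_space = env.observation_space.shape[0]    # 24 for BipedalWalker-v3, not the tuple (24,)
action_space = env.action_space.shape[0]      # 4 continuous torques; a Box has no .n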
--------------------------------------------------------------------------------
/src/ppo_discrete_step_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | # %%
4 | import random
5 | from collections import deque
6 | from copy import deepcopy
7 |
8 | import gym
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from torch.distributions import Categorical
16 | from torch.utils.data import DataLoader
17 | from IPython.display import clear_output
18 |
19 | from running_mean_std import RunningMeanStd
20 |
21 |
22 | # %%
23 | SEED = 5
24 | BATCH_SIZE = 32
25 | LR = 3e-4
26 | EPOCHS = 20
27 | CLIP = 0.2
28 | GAMMA = 0.999
29 | LAMBDA = 0.97
30 | ENT_COEF = 0.0
31 | V_COEF = 2.0
32 | V_CLIP = False
33 | LIN_REDUCE = False
34 | GRAD_NORM = False
35 | OBS_NORM = True
36 | REW_NORM = False
37 | # set device
38 | use_cuda = torch.cuda.is_available()
39 | print('cuda:', use_cuda)
40 | device = torch.device('cuda' if use_cuda else 'cpu')
41 |
42 | # random seed
43 | random.seed(SEED)
44 | np.random.seed(SEED)
45 | torch.manual_seed(SEED)
46 | if use_cuda:
47 | torch.cuda.manual_seed_all(SEED)
48 |
49 |
50 | # %%
51 |
52 |
53 | class ActorCriticNet(nn.Module):
54 | def __init__(self, obs_space, action_space):
55 | super().__init__()
56 | h = 32
57 | # self.head = nn.Sequential(
58 | # nn.Linear(obs_space, h),
59 | # nn.Tanh()
60 | # )
61 | self.pol = nn.Sequential(
62 | nn.Linear(obs_space, h),
63 | nn.Tanh(),
64 | nn.Linear(h, h),
65 | nn.Tanh(),
66 | nn.Linear(h, action_space)
67 | )
68 | self.val = nn.Sequential(
69 | nn.Linear(obs_space, h),
70 | nn.Tanh(),
71 | nn.Linear(h, h),
72 | nn.Tanh(),
73 | nn.Linear(h, 1)
74 | )
75 | self.log_softmax = nn.LogSoftmax(dim=-1)
76 |
77 | def forward(self, x):
78 | # x = self.head(x)
79 | logit = self.pol(x).reshape(x.shape[0], -1)
80 | log_p = self.log_softmax(logit)
81 | v = self.val(x).reshape(x.shape[0], 1)
82 |
83 | return log_p, v
84 |
85 |
86 | # %%
87 |
88 |
89 | losses = []
90 |
91 |
92 | def learn(net, old_net, optimizer, train_memory):
93 | global CLIP, LR
94 | global total_epochs
95 | net.train()
96 | old_net.train()
97 | dataloader = DataLoader(
98 | train_memory,
99 | shuffle=True,
100 | batch_size=BATCH_SIZE,
101 | pin_memory=False,
102 | num_workers=0,
103 | )
104 | for epoch in range(EPOCHS):
105 | if LIN_REDUCE:
106 | clip = CLIP - (CLIP * epoch / total_epochs)
107 | lr = LR - (LR * epoch / total_epochs)
108 | for param_group in optimizer.param_groups:
109 | param_group['lr'] = lr
110 | else:
111 | clip = CLIP
112 |
113 | for (s, a, ret, adv) in dataloader:
114 | s_batch = s.to(device).float()
115 | a_batch = a.to(device).long()
116 | ret_batch = ret.to(device).float()
117 | ret_batch = (ret_batch - ret_batch.mean()) / ret_batch.std()
118 | adv_batch = adv.to(device).float()
119 | adv_batch = (adv_batch - adv_batch.mean()) / adv_batch.std()
120 |
121 | for opt in optimizer:
122 | opt.zero_grad()
123 |
124 | with torch.no_grad():
125 | log_p_batch_old, v_batch_old = old_net(s_batch)
126 | log_p_acting_old = log_p_batch_old[range(BATCH_SIZE), a_batch]
127 |
128 | log_p_batch, v_batch = net(s_batch)
129 | log_p_acting = log_p_batch[range(BATCH_SIZE), a_batch]
130 | p_ratio = (log_p_acting - log_p_acting_old).exp()
131 | p_ratio_clip = torch.clamp(p_ratio, 1 - clip, 1 + clip)
132 | p_loss = -(torch.min(p_ratio * adv_batch, p_ratio_clip * adv_batch)).mean()
133 | # approx_kl = (log_p_batch - log_p_batch_old).mean().detach().item()
134 | # kl_div = (log_p_batch.exp() * (log_p_batch - log_p_batch_old)).sum(dim=-1)
135 | # p_loss = -(p_ratio * adv_batch - 3.0 * kl_div).mean()
136 | if V_CLIP:
137 | v_clip = v_batch_old + torch.clamp(v_batch - v_batch_old, -clip, clip)
138 | v_loss1 = (ret_batch - v_clip).pow(2)
139 | v_loss2 = (ret_batch - v_batch).pow(2)
140 | v_loss = 0.5 * torch.max(v_loss1, v_loss2).mean()
141 | else:
142 | v_loss = 0.5 * (ret_batch - v_batch).pow(2).mean()
143 |
144 | # log_p, _ = net(s_batch)
145 | # entropy = -(log_p.exp() * log_p).sum(dim=1).mean()
146 |
147 | # loss
148 | # loss = p_loss + V_COEF * v_loss - ENT_COEF * entropy
149 | loss = p_loss + V_COEF * v_loss
150 |
151 | losses.append(loss.item())
152 | # if approx_kl <= 1.5* 0.01:
153 | # p_loss.backward()
154 | # optimizer[0].step()
155 | # v_loss.backward()
156 | # optimizer[1].step()
157 | loss.backward()
158 | if GRAD_NORM:
159 | nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
160 | for opt in optimizer:
161 | opt.step()
162 |
163 |
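# Editor's note (sketch, commented out): learn() standardizes both ret_batch and
# adv_batch per minibatch. Advantage standardization is the usual PPO trick;
# standardizing ret_batch additionally rescales the value-regression target, so
# V_COEF effectively weights a unit-variance target. The operation itself:
# x = torch.tensor([1., 2., 3., 4.])
# (x - x.mean()) / x.std()   # -> tensor([-1.1619, -0.3873,  0.3873,  1.1619])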
164 | def get_action_and_value(obs, old_net):
165 | old_net.eval()
166 | with torch.no_grad():
167 | state = torch.tensor([obs]).to(device).float()
168 | log_p, v = old_net(state)
169 | m = Categorical(log_p.exp())
170 | action = m.sample()
171 | # print(action)
172 |
173 | return action.item(), v.item()
174 |
175 |
176 | def compute_adv_with_gae(rewards, values, roll_memory):
177 | rew = np.array(rewards, np.float64)
178 | val = np.array(values[:-1], np.float64)
179 | _val = np.array(values[1:], np.float64)
180 | delta = rew + GAMMA * _val - val
181 | dis_r = np.array([GAMMA**(i) * r for i, r in enumerate(rewards)], np.float64)
182 | gae_dt = np.array([(GAMMA * LAMBDA)**(i) * dt for i,
183 | dt in enumerate(delta.tolist())], np.float64)
184 | for i, data in enumerate(roll_memory):
185 | data.append(sum(dis_r[i:] / GAMMA**(i)))
186 | data.append(sum(gae_dt[i:] / (GAMMA * LAMBDA)**(i)))
187 |
188 | rewards.clear()
189 | values.clear()
190 |
191 | return roll_memory
192 |
193 |
194 | def plot():
195 | clear_output(True)
196 | plt.figure(figsize=(16, 5))
197 | plt.subplot(121)
198 | plt.plot(ep_rewards, alpha=0.5)
199 | plt.subplot(121)
200 | plt.plot(reward_eval)
201 | plt.title(f'Reward: '
202 | f'{reward_eval[-1]:.2f}')
203 | plt.subplot(122)
204 | plt.plot(losses, alpha=0.5)
205 | plt.title(f'Loss: '
206 | f'{np.mean(list(reversed(losses))[: n_eval]):.2f}')
207 | plt.show()
208 |
209 |
210 | # %%
211 |
212 |
213 | # make an environment
214 | # env = gym.make('CartPole-v0')
215 | env = gym.make('CartPole-v1')
216 | # env = gym.make('MountainCar-v0')
217 | # env = gym.make('LunarLander-v2')
218 |
219 | env.seed(SEED)
220 | obs_space = env.observation_space.shape[0]
221 | action_space = env.action_space.n
222 |
223 | # hyperparameter
224 | n_episodes = 100000
225 | roll_len = 4096
226 | total_epochs = roll_len // BATCH_SIZE
227 | n_eval = 10
228 |
229 | # global values
230 | steps = 0
231 | ep_rewards = []
232 | reward_eval = []
233 | is_rollout = False
234 | is_solved = False
235 |
236 | # make memories
237 | train_memory = []
238 | roll_memory = []
239 | obses = []
240 | rews = []
241 | rewards = []
242 | values = []
243 | norm_obs = RunningMeanStd(shape=env.observation_space.shape)
244 | norm_rew = RunningMeanStd()
245 |
246 | # make neural networks
247 | net = ActorCriticNet(obs_space, action_space).to(device)
248 | old_net = deepcopy(net)
249 |
250 | # grouped_parameters = [
251 | # {'params': [p for n, p in net.named_parameters() if n == 'val'], 'lr': LR * 0.1},
252 | # {'params': [p for n, p in net.named_parameters() if n != 'val'], 'lr': LR}
253 | # ]
254 | param_p = [p for n, p in net.named_parameters() if 'val' not in n]
255 | param_v = [p for n, p in net.named_parameters() if 'val' in n]
256 | optim_p = torch.optim.AdamW(param_p, lr=LR, eps=1e-6)
257 | optim_v = torch.optim.AdamW(param_v, lr=0.001, eps=1e-6)
258 | optimizer = [optim_p, optim_v]
259 | # optimizer = [torch.optim.AdamW(net.parameters(), lr=LR, eps=1e-6)]
260 |
261 | # play!
262 | for i in range(1, n_episodes + 1):
263 | obs = env.reset()
264 | done = False
265 | ep_reward = 0
266 | while not done:
267 | # env.render()
268 | obses.append(obs)
269 | if OBS_NORM:
270 | obs_norm = np.clip((obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)
271 | action, value = get_action_and_value(obs_norm, old_net)
272 | else:
273 | action, value = get_action_and_value(obs, old_net)
274 |
275 | _obs, reward, done, _ = env.step(action)
276 |
277 | # store
278 | if OBS_NORM:
279 | roll_memory.append([obs_norm, action])
280 | else:
281 | roll_memory.append([obs, action])
282 |
283 | if REW_NORM:
284 | rew_norm = np.clip((reward - norm_rew.mean) / np.sqrt(norm_rew.var), -10, 10)
285 | rewards.append(rew_norm)
286 | rews.append(reward)
287 | else:
288 | rewards.append(reward)
289 |
290 | values.append(value)
291 |
292 | obs = _obs
293 | steps += 1
294 | ep_reward += reward
295 |
296 | if done or steps % roll_len == 0:
297 | if done:
298 | obses.append(_obs)
299 | _value = 0.
300 | else:
301 | if OBS_NORM:
302 | _obs_norm = np.clip((_obs - norm_obs.mean) / np.sqrt(norm_obs.var), -10, 10)
303 | _, _value = get_action_and_value(_obs_norm, old_net)
304 | else:
305 | _, _value = get_action_and_value(_obs, old_net)
306 |
307 | values.append(_value)
308 | train_memory.extend(compute_adv_with_gae(rewards, values, roll_memory))
309 | roll_memory.clear()
310 |
311 | if steps % roll_len == 0:
312 | learn(net, old_net, optimizer, train_memory)
313 | old_net.load_state_dict(net.state_dict())
314 | if OBS_NORM:
315 | norm_obs.update(np.array(obses))
316 | if REW_NORM:
317 | norm_rew.update(np.array(rews))
318 | train_memory.clear()
319 | obses.clear()
320 | rews.clear()
321 |
322 | if done:
323 | ep_rewards.append(ep_reward)
324 | reward_eval.append(np.mean(list(reversed(ep_rewards))[: n_eval]))
325 | # plot()
326 | print('{:3} Episode in {:5} steps, reward {:.2f}'.format(
327 | i, steps, ep_reward))
328 |
329 | if len(ep_rewards) >= n_eval:
330 | if reward_eval[-1] >= env.spec.reward_threshold:
331 | print('\n{} is solved! {:3} Episode in {:3} steps'.format(
332 | env.spec.id, i, steps))
333 | torch.save(net.state_dict(),
334 | f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_ppo_st.pt')
335 | break
336 | env.close()
337 |
338 |
339 | # %%
340 |
341 |
342 | # env.spec.reward_threshold
343 |
344 |
345 | # %%
346 |
347 |
348 | # [
349 | # ('CartPole-v0', 889, 2048, 0.2, 10, 0.5, 0.01, False, 0.999, 0.98),
350 | # ('CartPole-v1', 801, 2048, 0.2, 10, 0.5, 0.01, False, 0.999, 0.98), 675,
351 | # ('MountainCar-v0', None),
352 | # ('LunarLander-v2', 876, 2048, 0.2, 10, 1.0, 0.01, False, 0.999, 0.98)
353 | # ]
354 |
355 |
356 | # %%
357 |
358 |
359 | # get_ipython().system('jupyter nbconvert --to script ppo_discrete_step_test.ipynb')
360 |
361 |
--------------------------------------------------------------------------------
/src/running_mean_std.py:
--------------------------------------------------------------------------------
1 | # https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py
2 | import numpy as np
3 |
4 |
5 | class RunningMeanStd(object):
6 | def __init__(self, epsilon=1e-4, shape=()):
7 | self.mean = np.zeros(shape, np.float64)
8 | self.var = np.ones(shape, np.float64)
9 | self.count = epsilon
10 |
11 | def update(self, x):
12 | batch_mean = np.mean(x, axis=0)
13 | batch_var = np.var(x, axis=0)
14 | batch_count = x.shape[0]
15 | self.update_from_moments(batch_mean, batch_var, batch_count)
16 |
17 | def update_from_moments(self, batch_mean, batch_var, batch_count):
18 | self.mean, self.var, self.count = update_mean_var_count_from_moments(
19 | self.mean, self.var, self.count,
20 | batch_mean, batch_var, batch_count)
21 |
22 |
23 | def update_mean_var_count_from_moments(mean, var, count,
24 | batch_mean, batch_var, batch_count):
25 | delta = batch_mean - mean
26 | tot_count = count + batch_count
27 |
28 | new_mean = mean + delta * batch_count / tot_count
29 | m_a = var * count
30 | m_b = batch_var * batch_count
31 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
32 | new_var = M2 / tot_count
33 | new_count = tot_count
34 |
35 | return new_mean, new_var, new_count
36 |
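# Editor's sketch (self-check added for illustration; runs only when the module
# is executed directly): the update combines batch moments with the running
# moments via the parallel-variance formula, so folding batches in one at a time
# should reproduce the moments of the concatenated data, up to the epsilon
# pseudo-count used to initialise `count`.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    a, b = rng.randn(1000, 3), rng.randn(500, 3) + 2.0
    rms = RunningMeanStd(shape=(3,))
    rms.update(a)
    rms.update(b)
    both = np.concatenate([a, b])
    print(np.allclose(rms.mean, both.mean(axis=0), atol=1e-2),
          np.allclose(rms.var, both.var(axis=0), atol=1e-2))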
--------------------------------------------------------------------------------
/src/soft_actor_critic.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "colab": {},
8 | "colab_type": "code",
9 | "id": "IWnm3qot3o1W"
10 | },
11 | "outputs": [],
12 | "source": [
13 | "import random\n",
14 | "import math\n",
15 | "from collections import deque\n",
16 | "from copy import deepcopy\n",
17 | "\n",
18 | "import gym\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "import numpy as np\n",
21 | "import torch\n",
22 | "import torch.nn as nn\n",
23 | "import torch.optim as optim\n",
24 | "from torch.distributions import Categorical, Dirichlet\n",
25 | "from torch.utils.data import DataLoader\n",
26 | "from IPython.display import clear_output"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "metadata": {
33 | "colab": {
34 | "base_uri": "https://localhost:8080/",
35 | "height": 35
36 | },
37 | "colab_type": "code",
38 | "executionInfo": {
39 | "elapsed": 708,
40 | "status": "ok",
41 | "timestamp": 1534482400648,
42 | "user": {
43 | "displayName": "윤승제",
44 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg",
45 | "userId": "105654037995838004821"
46 | },
47 | "user_tz": -540
48 | },
49 | "id": "maRVADiTlzHD",
50 | "outputId": "783b7610-95c2-4b54-b2ce-d8e853c484ba"
51 | },
52 | "outputs": [],
53 | "source": [
54 | "SEED = 1\n",
55 | "BATCH_SIZE = 32\n",
56 | "LR = 0.0003\n",
57 | "UP_COEF = 0.01\n",
58 | "ENT_COEF = 0.01\n",
59 | "GAMMA = 0.99\n",
60 | "\n",
61 | "# set device\n",
62 | "use_cuda = torch.cuda.is_available()\n",
63 | "device = torch.device('cuda' if use_cuda else 'cpu')\n",
64 | "\n",
65 | "# random seed\n",
66 | "random.seed(SEED)\n",
67 | "np.random.seed(SEED)\n",
68 | "torch.manual_seed(SEED)\n",
69 | "if use_cuda:\n",
70 | " torch.cuda.manual_seed_all(SEED)"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 3,
76 | "metadata": {
77 | "colab": {},
78 | "colab_type": "code",
79 | "id": "9Ffkl_5C4R81"
80 | },
81 | "outputs": [],
82 | "source": [
83 | "class QNet(nn.Module):\n",
84 | " def __init__(self, obs_space, action_space):\n",
85 | " super().__init__()\n",
86 | "\n",
87 | " self.head = nn.Sequential(\n",
88 | " nn.Linear(obs_space, obs_space*10),\n",
89 | " nn.SELU()\n",
90 | " )\n",
91 | " self.fc = nn.Sequential(\n",
92 | " nn.Linear(obs_space*10, 512),\n",
93 | " nn.SELU(),\n",
94 | " nn.Linear(512, 512),\n",
95 | " nn.SELU(),\n",
96 | " nn.Linear(512, action_space)\n",
97 | " )\n",
98 | "\n",
99 | " def forward(self, x):\n",
100 | " out = self.head(x)\n",
101 | " q = self.fc(out).reshape(out.shape[0], -1)\n",
102 | "\n",
103 | " return q\n",
104 | "\n",
105 | "\n",
106 | "class PolicyNet(nn.Module):\n",
107 | " def __init__(self, obs_space, action_space):\n",
108 | " super().__init__()\n",
109 | "\n",
110 | " self.head = nn.Sequential(\n",
111 | " nn.Linear(obs_space, obs_space*10),\n",
112 | " nn.SELU()\n",
113 | " )\n",
114 | "\n",
115 | " self.fc = nn.Sequential(\n",
116 | " nn.Linear(obs_space*10, 512),\n",
117 | " nn.SELU(),\n",
118 | " nn.Linear(512, 512),\n",
119 | " nn.SELU(),\n",
120 | " nn.Linear(512, action_space)\n",
121 | " )\n",
122 | "\n",
123 | " self.log_softmax = nn.LogSoftmax(dim=-1)\n",
124 | "\n",
125 | " def forward(self, x):\n",
126 | " out = self.head(x)\n",
127 | " logit = self.fc(out).reshape(out.shape[0], -1)\n",
128 | " log_p = self.log_softmax(logit)\n",
129 | "\n",
130 | " return log_p\n",
131 | "\n",
132 | "\n",
133 | "class ValueNet(nn.Module):\n",
134 | " def __init__(self, obs_space):\n",
135 | " super().__init__()\n",
136 | "\n",
137 | " self.head = nn.Sequential(\n",
138 | " nn.Linear(obs_space, obs_space*10),\n",
139 | " nn.SELU()\n",
140 | " )\n",
141 | "\n",
142 | " self.fc = nn.Sequential(\n",
143 | " nn.Linear(obs_space*10, 512),\n",
144 | " nn.SELU(),\n",
145 | " nn.Linear(512, 512),\n",
146 | " nn.SELU(),\n",
147 | " nn.Linear(512, 1)\n",
148 | " )\n",
149 | "\n",
150 | " def forward(self, x):\n",
151 | " out = self.head(x)\n",
152 | " v = self.fc(out).reshape(out.shape[0], 1)\n",
153 | "\n",
154 | " return v"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 4,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "losses = []\n",
164 | "\n",
165 | "\n",
166 | "def learn(q_net, p_net, v_net, v_tgt, optimizer, rep_memory):\n",
167 | " global action_space\n",
168 | " \n",
169 | " q_net.train()\n",
170 | " p_net.train()\n",
171 | " v_net.train()\n",
172 | " v_tgt.train()\n",
173 | "\n",
174 | " train_data = random.sample(rep_memory, BATCH_SIZE)\n",
175 | " dataloader = DataLoader(train_data,\n",
176 | " batch_size=BATCH_SIZE,\n",
177 | " pin_memory=use_cuda)\n",
178 | "\n",
179 | " for i, (s, a, r, _s, d) in enumerate(dataloader):\n",
180 | " s_batch = s.to(device).float()\n",
181 | " a_batch = a.to(device).long()\n",
182 | " _s_batch = _s.to(device).float()\n",
183 | " r_batch = r.to(device).float()\n",
184 | " done_mask = 1. - d.to(device).float()\n",
185 | " discount = torch.full_like(r_batch, GAMMA)\n",
186 | " \n",
187 | " q_batch = q_net(s_batch)\n",
188 | " q_acting = q_batch[range(BATCH_SIZE), a_batch]\n",
189 | " q_acting_ = q_acting.detach() \n",
190 | " \n",
191 | " v_batch = v_net(s_batch)\n",
192 | " \n",
193 | " with torch.no_grad():\n",
194 | " _log_p_batch = p_net(_s_batch)\n",
195 | " _log_p_acting = _log_p_batch[range(BATCH_SIZE), a_batch]\n",
196 | " v_target = q_acting_ - ENT_COEF * _log_p_acting\n",
197 | " \n",
198 | " v_loss = (v_target - v_batch).pow(2).mean()\n",
199 | " \n",
200 | " with torch.no_grad():\n",
201 | " _v_batch = v_tgt(_s_batch) * done_mask\n",
202 | " q_target = r_batch + _v_batch * discount\n",
203 | " \n",
204 | " q_loss = (q_target - q_acting).pow(2).mean()\n",
205 | " \n",
206 | " log_p_batch = p_net(s_batch)\n",
207 | " q_batch_ = q_net(s_batch)\n",
208 | " entropy = -(ENT_COEF * log_p_batch.exp() * q_batch).sum(dim=-1).mean()\n",
209 | " \n",
210 | " loss = v_loss + q_loss + entropy\n",
211 | " \n",
212 | " optimizer.zero_grad()\n",
213 | " loss.backward()\n",
214 | "# nn.utils.clip_grad_norm_(total_params, max_norm=0.5)\n",
215 | " optimizer.step()\n",
216 | "\n",
217 | "\n",
218 | "def select_action(obs, p_net):\n",
219 | " p_net.eval()\n",
220 | " with torch.no_grad():\n",
221 | " state = torch.tensor([obs]).to(device).float()\n",
222 | " log_p = p_net(state)\n",
223 | " m = Categorical(log_p.exp())\n",
224 | " action = m.sample()\n",
225 | "\n",
226 | " return action.item()\n",
227 | "\n",
228 | "\n",
229 | "def plot():\n",
230 | " clear_output(True)\n",
231 | " plt.figure(figsize=(16, 5))\n",
232 | " plt.subplot(121)\n",
233 | " plt.plot(ep_rewards)\n",
234 | " plt.title('Reward')\n",
235 | " plt.subplot(122)\n",
236 | " plt.plot(losses)\n",
237 | " plt.title('Loss')\n",
238 | " plt.show()"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "## Main"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "metadata": {
252 | "colab": {
253 | "base_uri": "https://localhost:8080/",
254 | "height": 3377
255 | },
256 | "colab_type": "code",
257 | "executionInfo": {
258 | "elapsed": 135196,
259 | "status": "ok",
260 | "timestamp": 1534482559393,
261 | "user": {
262 | "displayName": "윤승제",
263 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg",
264 | "userId": "105654037995838004821"
265 | },
266 | "user_tz": -540
267 | },
268 | "id": "PnifSBJglzHh",
269 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a",
270 | "scrolled": true
271 | },
272 | "outputs": [
273 | {
274 | "name": "stderr",
275 | "output_type": "stream",
276 | "text": [
277 | "/home/jay/anaconda3/lib/python3.7/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n",
278 | " result = entry_point.load(False)\n"
279 | ]
280 | }
281 | ],
282 | "source": [
283 | "# make an environment\n",
284 | "env = gym.make('CartPole-v0')\n",
285 | "# env = gym.make('CartPole-v1')\n",
286 | "# env = gym.make('MountainCar-v0')\n",
287 | "# env = gym.make('LunarLander-v2')\n",
288 | "\n",
289 | "env.seed(SEED)\n",
290 | "obs_space = env.observation_space.shape[0]\n",
291 | "action_space = env.action_space.n\n",
292 | "\n",
293 | "# hyperparameter\n",
294 | "n_episodes = 10000\n",
295 | "learn_start = 1500\n",
296 | "memory_size = 50000\n",
297 | "update_frq = 1\n",
298 | "use_eps_decay = False\n",
299 | "n_eval = env.spec.trials\n",
300 | "\n",
301 | "# global values\n",
302 | "total_steps = 0\n",
303 | "learn_steps = 0\n",
304 | "rewards = []\n",
305 | "reward_eval = deque(maxlen=n_eval)\n",
306 | "is_learned = False\n",
307 | "is_solved = False\n",
308 | "\n",
309 | "# make two neural networks\n",
310 | "q_net = QNet(obs_space, action_space).to(device)\n",
311 | "p_net = PolicyNet(obs_space, action_space).to(device)\n",
312 | "v_net = ValueNet(obs_space).to(device)\n",
313 | "v_tgt = deepcopy(v_net)\n",
314 | "\n",
315 | "# make optimizer\n",
316 | "total_params = list(q_net.parameters()) + list(p_net.parameters()) + list(v_net.parameters())\n",
317 | "optimizer = optim.Adam(total_params, lr=LR, eps=1e-5)\n",
318 | "\n",
319 | "# make a memory\n",
320 | "rep_memory = deque(maxlen=memory_size)"
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": null,
326 | "metadata": {},
327 | "outputs": [
328 | {
329 | "data": {
330 | "text/plain": [
331 | "200"
332 | ]
333 | },
334 | "execution_count": 6,
335 | "metadata": {},
336 | "output_type": "execute_result"
337 | }
338 | ],
339 | "source": [
340 | "env.spec.max_episode_steps"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": null,
346 | "metadata": {},
347 | "outputs": [
348 | {
349 | "data": {
350 | "text/plain": [
351 | "100"
352 | ]
353 | },
354 | "execution_count": 7,
355 | "metadata": {},
356 | "output_type": "execute_result"
357 | }
358 | ],
359 | "source": [
360 | "env.spec.trials"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "metadata": {},
367 | "outputs": [
368 | {
369 | "data": {
370 | "text/plain": [
371 | "195.0"
372 | ]
373 | },
374 | "execution_count": 8,
375 | "metadata": {},
376 | "output_type": "execute_result"
377 | }
378 | ],
379 | "source": [
380 | "env.spec.reward_threshold"
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": null,
386 | "metadata": {
387 | "colab": {
388 | "base_uri": "https://localhost:8080/",
389 | "height": 3377
390 | },
391 | "colab_type": "code",
392 | "executionInfo": {
393 | "elapsed": 135196,
394 | "status": "ok",
395 | "timestamp": 1534482559393,
396 | "user": {
397 | "displayName": "윤승제",
398 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg",
399 | "userId": "105654037995838004821"
400 | },
401 | "user_tz": -540
402 | },
403 | "id": "PnifSBJglzHh",
404 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a",
405 | "scrolled": true
406 | },
407 | "outputs": [
408 | {
409 | "name": "stdout",
410 | "output_type": "stream",
411 | "text": [
412 | " 1 Episode in 15 steps, reward 15.00\n",
413 | " 2 Episode in 41 steps, reward 26.00\n",
414 | " 3 Episode in 51 steps, reward 10.00\n",
415 | " 4 Episode in 62 steps, reward 11.00\n",
416 | " 5 Episode in 72 steps, reward 10.00\n",
417 | " 6 Episode in 105 steps, reward 33.00\n",
418 | " 7 Episode in 151 steps, reward 46.00\n",
419 | " 8 Episode in 163 steps, reward 12.00\n",
420 | " 9 Episode in 189 steps, reward 26.00\n",
421 | " 10 Episode in 206 steps, reward 17.00\n",
422 | " 11 Episode in 224 steps, reward 18.00\n",
423 | " 12 Episode in 242 steps, reward 18.00\n",
424 | " 13 Episode in 267 steps, reward 25.00\n",
425 | " 14 Episode in 282 steps, reward 15.00\n",
426 | " 15 Episode in 295 steps, reward 13.00\n",
427 | " 16 Episode in 307 steps, reward 12.00\n",
428 | " 17 Episode in 324 steps, reward 17.00\n",
429 | " 18 Episode in 344 steps, reward 20.00\n",
430 | " 19 Episode in 364 steps, reward 20.00\n",
431 | " 20 Episode in 383 steps, reward 19.00\n",
432 | " 21 Episode in 430 steps, reward 47.00\n",
433 | " 22 Episode in 452 steps, reward 22.00\n",
434 | " 23 Episode in 469 steps, reward 17.00\n",
435 | " 24 Episode in 488 steps, reward 19.00\n",
436 | " 25 Episode in 511 steps, reward 23.00\n",
437 | " 26 Episode in 524 steps, reward 13.00\n",
438 | " 27 Episode in 546 steps, reward 22.00\n",
439 | " 28 Episode in 559 steps, reward 13.00\n",
440 | " 29 Episode in 577 steps, reward 18.00\n",
441 | " 30 Episode in 591 steps, reward 14.00\n",
442 | " 31 Episode in 609 steps, reward 18.00\n",
443 | " 32 Episode in 630 steps, reward 21.00\n",
444 | " 33 Episode in 643 steps, reward 13.00\n",
445 | " 34 Episode in 703 steps, reward 60.00\n",
446 | " 35 Episode in 717 steps, reward 14.00\n",
447 | " 36 Episode in 771 steps, reward 54.00\n",
448 | " 37 Episode in 793 steps, reward 22.00\n",
449 | " 38 Episode in 818 steps, reward 25.00\n",
450 | " 39 Episode in 834 steps, reward 16.00\n",
451 | " 40 Episode in 849 steps, reward 15.00\n",
452 | " 41 Episode in 869 steps, reward 20.00\n",
453 | " 42 Episode in 898 steps, reward 29.00\n",
454 | " 43 Episode in 913 steps, reward 15.00\n",
455 | " 44 Episode in 923 steps, reward 10.00\n",
456 | " 45 Episode in 957 steps, reward 34.00\n",
457 | " 46 Episode in 975 steps, reward 18.00\n",
458 | " 47 Episode in 1001 steps, reward 26.00\n",
459 | " 48 Episode in 1011 steps, reward 10.00\n",
460 | " 49 Episode in 1028 steps, reward 17.00\n",
461 | " 50 Episode in 1049 steps, reward 21.00\n",
462 | " 51 Episode in 1060 steps, reward 11.00\n",
463 | " 52 Episode in 1079 steps, reward 19.00\n",
464 | " 53 Episode in 1089 steps, reward 10.00\n",
465 | " 54 Episode in 1103 steps, reward 14.00\n",
466 | " 55 Episode in 1117 steps, reward 14.00\n",
467 | " 56 Episode in 1130 steps, reward 13.00\n",
468 | " 57 Episode in 1151 steps, reward 21.00\n",
469 | " 58 Episode in 1178 steps, reward 27.00\n",
470 | " 59 Episode in 1200 steps, reward 22.00\n",
471 | " 60 Episode in 1218 steps, reward 18.00\n",
472 | " 61 Episode in 1229 steps, reward 11.00\n",
473 | " 62 Episode in 1241 steps, reward 12.00\n",
474 | " 63 Episode in 1263 steps, reward 22.00\n",
475 | " 64 Episode in 1277 steps, reward 14.00\n",
476 | " 65 Episode in 1300 steps, reward 23.00\n",
477 | " 66 Episode in 1340 steps, reward 40.00\n",
478 | " 67 Episode in 1352 steps, reward 12.00\n",
479 | " 68 Episode in 1372 steps, reward 20.00\n",
480 | " 69 Episode in 1388 steps, reward 16.00\n",
481 | " 70 Episode in 1405 steps, reward 17.00\n",
482 | " 71 Episode in 1420 steps, reward 15.00\n",
483 | " 72 Episode in 1459 steps, reward 39.00\n",
484 | " 73 Episode in 1469 steps, reward 10.00\n",
485 | " 74 Episode in 1480 steps, reward 11.00\n",
486 | " 75 Episode in 1497 steps, reward 17.00\n",
487 | "\n",
488 | "============ Start Learning ============\n",
489 | "\n",
490 | " 76 Episode in 1513 steps, reward 16.00\n",
491 | " 77 Episode in 1524 steps, reward 11.00\n",
492 | " 78 Episode in 1544 steps, reward 20.00\n",
493 | " 79 Episode in 1556 steps, reward 12.00\n",
494 | " 80 Episode in 1568 steps, reward 12.00\n",
495 | " 81 Episode in 1591 steps, reward 23.00\n",
496 | " 82 Episode in 1609 steps, reward 18.00\n",
497 | " 83 Episode in 1622 steps, reward 13.00\n",
498 | " 84 Episode in 1634 steps, reward 12.00\n",
499 | " 85 Episode in 1644 steps, reward 10.00\n",
500 | " 86 Episode in 1654 steps, reward 10.00\n",
501 | " 87 Episode in 1664 steps, reward 10.00\n",
502 | " 88 Episode in 1678 steps, reward 14.00\n",
503 | " 89 Episode in 1688 steps, reward 10.00\n",
504 | " 90 Episode in 1698 steps, reward 10.00\n",
505 | " 91 Episode in 1708 steps, reward 10.00\n",
506 | " 92 Episode in 1718 steps, reward 10.00\n",
507 | " 93 Episode in 1728 steps, reward 10.00\n",
508 | " 94 Episode in 1737 steps, reward 9.00\n",
509 | " 95 Episode in 1746 steps, reward 9.00\n",
510 | " 96 Episode in 1756 steps, reward 10.00\n",
511 | " 97 Episode in 1766 steps, reward 10.00\n",
512 | " 98 Episode in 1776 steps, reward 10.00\n",
513 | " 99 Episode in 1787 steps, reward 11.00\n",
514 | "100 Episode in 1801 steps, reward 14.00\n",
515 | "101 Episode in 1824 steps, reward 23.00\n",
516 | "102 Episode in 1840 steps, reward 16.00\n",
517 | "103 Episode in 1852 steps, reward 12.00\n",
518 | "104 Episode in 1865 steps, reward 13.00\n",
519 | "105 Episode in 1877 steps, reward 12.00\n",
520 | "106 Episode in 1894 steps, reward 17.00\n",
521 | "107 Episode in 1918 steps, reward 24.00\n",
522 | "108 Episode in 1937 steps, reward 19.00\n",
523 | "109 Episode in 1963 steps, reward 26.00\n",
524 | "110 Episode in 2005 steps, reward 42.00\n",
525 | "111 Episode in 2039 steps, reward 34.00\n",
526 | "112 Episode in 2065 steps, reward 26.00\n",
527 | "113 Episode in 2087 steps, reward 22.00\n",
528 | "114 Episode in 2107 steps, reward 20.00\n",
529 | "115 Episode in 2132 steps, reward 25.00\n",
530 | "116 Episode in 2179 steps, reward 47.00\n",
531 | "117 Episode in 2201 steps, reward 22.00\n",
532 | "118 Episode in 2231 steps, reward 30.00\n",
533 | "119 Episode in 2282 steps, reward 51.00\n",
534 | "120 Episode in 2344 steps, reward 62.00\n",
535 | "121 Episode in 2365 steps, reward 21.00\n",
536 | "122 Episode in 2378 steps, reward 13.00\n",
537 | "123 Episode in 2388 steps, reward 10.00\n",
538 | "124 Episode in 2396 steps, reward 8.00\n",
539 | "125 Episode in 2406 steps, reward 10.00\n",
540 | "126 Episode in 2414 steps, reward 8.00\n",
541 | "127 Episode in 2423 steps, reward 9.00\n",
542 | "128 Episode in 2434 steps, reward 11.00\n",
543 | "129 Episode in 2447 steps, reward 13.00\n",
544 | "130 Episode in 2457 steps, reward 10.00\n",
545 | "131 Episode in 2469 steps, reward 12.00\n",
546 | "132 Episode in 2479 steps, reward 10.00\n",
547 | "133 Episode in 2494 steps, reward 15.00\n",
548 | "134 Episode in 2517 steps, reward 23.00\n",
549 | "135 Episode in 2537 steps, reward 20.00\n",
550 | "136 Episode in 2562 steps, reward 25.00\n",
551 | "137 Episode in 2580 steps, reward 18.00\n",
552 | "138 Episode in 2594 steps, reward 14.00\n",
553 | "139 Episode in 2606 steps, reward 12.00\n",
554 | "140 Episode in 2618 steps, reward 12.00\n",
555 | "141 Episode in 2631 steps, reward 13.00\n",
556 | "142 Episode in 2647 steps, reward 16.00\n",
557 | "143 Episode in 2663 steps, reward 16.00\n",
558 | "144 Episode in 2678 steps, reward 15.00\n",
559 | "145 Episode in 2688 steps, reward 10.00\n",
560 | "146 Episode in 2698 steps, reward 10.00\n",
561 | "147 Episode in 2707 steps, reward 9.00\n",
562 | "148 Episode in 2719 steps, reward 12.00\n",
563 | "149 Episode in 2734 steps, reward 15.00\n",
564 | "150 Episode in 2753 steps, reward 19.00\n",
565 | "151 Episode in 2772 steps, reward 19.00\n",
566 | "152 Episode in 2791 steps, reward 19.00\n",
567 | "153 Episode in 2806 steps, reward 15.00\n",
568 | "154 Episode in 2817 steps, reward 11.00\n",
569 | "155 Episode in 2830 steps, reward 13.00\n",
570 | "156 Episode in 2841 steps, reward 11.00\n",
571 | "157 Episode in 2852 steps, reward 11.00\n",
572 | "158 Episode in 2869 steps, reward 17.00\n",
573 | "159 Episode in 2889 steps, reward 20.00\n",
574 | "160 Episode in 2907 steps, reward 18.00\n",
575 | "161 Episode in 2925 steps, reward 18.00\n",
576 | "162 Episode in 2937 steps, reward 12.00\n",
577 | "163 Episode in 2948 steps, reward 11.00\n",
578 | "164 Episode in 2958 steps, reward 10.00\n",
579 | "165 Episode in 2967 steps, reward 9.00\n",
580 | "166 Episode in 2977 steps, reward 10.00\n",
581 | "167 Episode in 2987 steps, reward 10.00\n",
582 | "168 Episode in 2997 steps, reward 10.00\n",
583 | "169 Episode in 3007 steps, reward 10.00\n",
584 | "170 Episode in 3016 steps, reward 9.00\n",
585 | "171 Episode in 3026 steps, reward 10.00\n",
586 | "172 Episode in 3036 steps, reward 10.00\n",
587 | "173 Episode in 3045 steps, reward 9.00\n",
588 | "174 Episode in 3054 steps, reward 9.00\n",
589 | "175 Episode in 3063 steps, reward 9.00\n",
590 | "176 Episode in 3074 steps, reward 11.00\n",
591 | "177 Episode in 3082 steps, reward 8.00\n",
592 | "178 Episode in 3093 steps, reward 11.00\n",
593 | "179 Episode in 3103 steps, reward 10.00\n",
594 | "180 Episode in 3114 steps, reward 11.00\n",
595 | "181 Episode in 3124 steps, reward 10.00\n",
596 | "182 Episode in 3137 steps, reward 13.00\n",
597 | "183 Episode in 3153 steps, reward 16.00\n",
598 | "184 Episode in 3218 steps, reward 65.00\n",
599 | "185 Episode in 3230 steps, reward 12.00\n",
600 | "186 Episode in 3241 steps, reward 11.00\n",
601 | "187 Episode in 3250 steps, reward 9.00\n",
602 | "188 Episode in 3260 steps, reward 10.00\n",
603 | "189 Episode in 3270 steps, reward 10.00\n",
604 | "190 Episode in 3281 steps, reward 11.00\n",
605 | "191 Episode in 3289 steps, reward 8.00\n",
606 | "192 Episode in 3298 steps, reward 9.00\n",
607 | "193 Episode in 3308 steps, reward 10.00\n",
608 | "194 Episode in 3318 steps, reward 10.00\n",
609 | "195 Episode in 3327 steps, reward 9.00\n",
610 | "196 Episode in 3337 steps, reward 10.00\n",
611 | "197 Episode in 3345 steps, reward 8.00\n",
612 | "198 Episode in 3354 steps, reward 9.00\n",
613 | "199 Episode in 3363 steps, reward 9.00\n",
614 | "200 Episode in 3373 steps, reward 10.00\n",
615 | "201 Episode in 3383 steps, reward 10.00\n"
616 | ]
617 | },
618 | {
619 | "name": "stdout",
620 | "output_type": "stream",
621 | "text": [
622 | "202 Episode in 3392 steps, reward 9.00\n",
623 | "203 Episode in 3400 steps, reward 8.00\n",
624 | "204 Episode in 3408 steps, reward 8.00\n",
625 | "205 Episode in 3418 steps, reward 10.00\n",
626 | "206 Episode in 3427 steps, reward 9.00\n",
627 | "207 Episode in 3436 steps, reward 9.00\n",
628 | "208 Episode in 3446 steps, reward 10.00\n",
629 | "209 Episode in 3454 steps, reward 8.00\n",
630 | "210 Episode in 3463 steps, reward 9.00\n",
631 | "211 Episode in 3473 steps, reward 10.00\n",
632 | "212 Episode in 3481 steps, reward 8.00\n",
633 | "213 Episode in 3491 steps, reward 10.00\n",
634 | "214 Episode in 3499 steps, reward 8.00\n",
635 | "215 Episode in 3508 steps, reward 9.00\n",
636 | "216 Episode in 3518 steps, reward 10.00\n",
637 | "217 Episode in 3528 steps, reward 10.00\n",
638 | "218 Episode in 3536 steps, reward 8.00\n",
639 | "219 Episode in 3545 steps, reward 9.00\n",
640 | "220 Episode in 3555 steps, reward 10.00\n",
641 | "221 Episode in 3563 steps, reward 8.00\n",
642 | "222 Episode in 3572 steps, reward 9.00\n",
643 | "223 Episode in 3580 steps, reward 8.00\n",
644 | "224 Episode in 3590 steps, reward 10.00\n",
645 | "225 Episode in 3600 steps, reward 10.00\n",
646 | "226 Episode in 3610 steps, reward 10.00\n",
647 | "227 Episode in 3620 steps, reward 10.00\n",
648 | "228 Episode in 3629 steps, reward 9.00\n",
649 | "229 Episode in 3638 steps, reward 9.00\n",
650 | "230 Episode in 3646 steps, reward 8.00\n",
651 | "231 Episode in 3655 steps, reward 9.00\n",
652 | "232 Episode in 3664 steps, reward 9.00\n",
653 | "233 Episode in 3673 steps, reward 9.00\n",
654 | "234 Episode in 3682 steps, reward 9.00\n",
655 | "235 Episode in 3691 steps, reward 9.00\n",
656 | "236 Episode in 3701 steps, reward 10.00\n",
657 | "237 Episode in 3711 steps, reward 10.00\n",
658 | "238 Episode in 3719 steps, reward 8.00\n",
659 | "239 Episode in 3729 steps, reward 10.00\n",
660 | "240 Episode in 3739 steps, reward 10.00\n",
661 | "241 Episode in 3749 steps, reward 10.00\n",
662 | "242 Episode in 3759 steps, reward 10.00\n",
663 | "243 Episode in 3768 steps, reward 9.00\n",
664 | "244 Episode in 3777 steps, reward 9.00\n",
665 | "245 Episode in 3787 steps, reward 10.00\n",
666 | "246 Episode in 3796 steps, reward 9.00\n",
667 | "247 Episode in 3804 steps, reward 8.00\n",
668 | "248 Episode in 3813 steps, reward 9.00\n",
669 | "249 Episode in 3823 steps, reward 10.00\n",
670 | "250 Episode in 3832 steps, reward 9.00\n",
671 | "251 Episode in 3841 steps, reward 9.00\n",
672 | "252 Episode in 3851 steps, reward 10.00\n",
673 | "253 Episode in 3860 steps, reward 9.00\n",
674 | "254 Episode in 3869 steps, reward 9.00\n",
675 | "255 Episode in 3878 steps, reward 9.00\n",
676 | "256 Episode in 3887 steps, reward 9.00\n",
677 | "257 Episode in 3897 steps, reward 10.00\n",
678 | "258 Episode in 3906 steps, reward 9.00\n",
679 | "259 Episode in 3914 steps, reward 8.00\n",
680 | "260 Episode in 3922 steps, reward 8.00\n",
681 | "261 Episode in 3931 steps, reward 9.00\n",
682 | "262 Episode in 3940 steps, reward 9.00\n",
683 | "263 Episode in 3950 steps, reward 10.00\n",
684 | "264 Episode in 3959 steps, reward 9.00\n",
685 | "265 Episode in 3968 steps, reward 9.00\n",
686 | "266 Episode in 3976 steps, reward 8.00\n",
687 | "267 Episode in 3986 steps, reward 10.00\n",
688 | "268 Episode in 3995 steps, reward 9.00\n",
689 | "269 Episode in 4004 steps, reward 9.00\n",
690 | "270 Episode in 4014 steps, reward 10.00\n",
691 | "271 Episode in 4024 steps, reward 10.00\n",
692 | "272 Episode in 4033 steps, reward 9.00\n",
693 | "273 Episode in 4042 steps, reward 9.00\n"
694 | ]
695 | }
696 | ],
697 | "source": [
698 | "# play\n",
699 | "for i in range(1, n_episodes + 1):\n",
700 | " obs = env.reset()\n",
701 | " done = False\n",
702 | " ep_reward = 0\n",
703 | " while not done:\n",
704 | "# env.render()\n",
705 | " action = select_action(obs, p_net)\n",
706 | " _obs, reward, done, _ = env.step(action)\n",
707 | " rep_memory.append((obs, action, reward, _obs, done))\n",
708 | "\n",
709 | " obs = _obs\n",
710 | " total_steps += 1\n",
711 | " ep_reward += reward\n",
712 | "\n",
713 | " if len(rep_memory) >= learn_start:\n",
714 | " if len(rep_memory) == learn_start:\n",
715 | " print('\\n============ Start Learning ============\\n')\n",
716 | " learn(q_net, p_net, v_net, v_tgt, optimizer, rep_memory)\n",
717 | " learn_steps += 1\n",
718 | "\n",
719 | " if learn_steps == update_frq:\n",
720 |     "                # target smoothing update: Polyak-average v_tgt toward v_net with coefficient UP_COEF\n",
721 | " with torch.no_grad():\n",
722 | " for t, n in zip(v_tgt.parameters(), v_net.parameters()):\n",
723 | " t.data = UP_COEF * n.data + (1 - UP_COEF) * t.data\n",
724 | " learn_steps = 0\n",
725 | " if done:\n",
726 | " rewards.append(ep_reward)\n",
727 | " reward_eval.append(ep_reward)\n",
728 | " plot()\n",
729 | "# print('{:3} Episode in {:5} steps, reward {:.2f}'.format(\n",
730 | "# i, total_steps, ep_reward))\n",
731 | "\n",
732 | " if len(reward_eval) >= n_eval:\n",
733 | " if np.mean(reward_eval) >= env.spec.reward_threshold:\n",
734 |     "                print('\\n{} is solved! {:3} Episode in {:3} steps'.format(\n",
735 | " env.spec.id, i, total_steps))\n",
736 |     "                torch.save(p_net.state_dict(),\n",
737 | " f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_sac.pt')\n",
738 | " break\n",
739 | "env.close()"
740 | ]
741 | },
742 | {
743 | "cell_type": "code",
744 | "execution_count": null,
745 | "metadata": {
746 | "scrolled": false
747 | },
748 | "outputs": [],
749 | "source": [
750 | "plt.figure(figsize=(15, 5))\n",
751 | "plt.title('Reward')\n",
752 | "plt.plot(rewards)\n",
753 | "plt.figure(figsize=(15, 5))\n",
754 | "plt.title('Loss')\n",
755 | "plt.plot(losses)\n",
756 | "plt.show()"
757 | ]
758 | },
759 | {
760 | "cell_type": "code",
761 | "execution_count": null,
762 | "metadata": {},
763 | "outputs": [],
764 | "source": [
765 | "[\n",
766 | " ('CartPole-v0', 299, 0.25),\n",
767 | " ('CartPole-v1', 413, 0.025),\n",
768 |     "    ('MountainCar-v0', None, 0.05)\n",
769 | "]"
770 | ]
771 | }
772 | ],
773 | "metadata": {
774 | "colab": {
775 | "collapsed_sections": [],
776 | "name": "C51_tensorflow.ipynb",
777 | "provenance": [],
778 | "version": "0.3.2"
779 | },
780 | "kernelspec": {
781 | "display_name": "Python 3",
782 | "language": "python",
783 | "name": "python3"
784 | },
785 | "language_info": {
786 | "codemirror_mode": {
787 | "name": "ipython",
788 | "version": 3
789 | },
790 | "file_extension": ".py",
791 | "mimetype": "text/x-python",
792 | "name": "python",
793 | "nbconvert_exporter": "python",
794 | "pygments_lexer": "ipython3",
795 | "version": "3.7.0"
796 | }
797 | },
798 | "nbformat": 4,
799 | "nbformat_minor": 1
800 | }
801 |
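A minimal standalone sketch of the "target smoothing update" used in the play/learn cell above (Polyak averaging of the target value network toward the online one); `soft_update` and `tau` are illustrative names standing in for the notebook's inline loop and UP_COEF:

    import torch
    import torch.nn as nn

    def soft_update(target_net: nn.Module, source_net: nn.Module, tau: float):
        # parameter-wise: target <- tau * source + (1 - tau) * target
        with torch.no_grad():
            for t, s in zip(target_net.parameters(), source_net.parameters()):
                t.data.copy_(tau * s.data + (1.0 - tau) * t.data)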
--------------------------------------------------------------------------------
/src/test/running_mean_std.py:
--------------------------------------------------------------------------------
1 | # https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py
2 | import numpy as np
3 |
4 |
5 | class RunningMeanStd(object):
6 | def __init__(self, epsilon=1e-4, shape=()):
7 | self.mean = np.zeros(shape, 'float64')
8 | self.var = np.ones(shape, 'float64')
9 | self.count = epsilon
10 |
11 | def update(self, x):
12 | batch_mean = np.mean(x, axis=0)
13 | batch_var = np.var(x, axis=0)
14 | batch_count = x.shape[0]
15 | self.update_from_moments(batch_mean, batch_var, batch_count)
16 |
17 | def update_from_moments(self, batch_mean, batch_var, batch_count):
18 | self.mean, self.var, self.count = update_mean_var_count_from_moments(
19 | self.mean, self.var, self.count,
20 | batch_mean, batch_var, batch_count)
21 |
22 |
23 | def update_mean_var_count_from_moments(mean, var, count,
24 | batch_mean, batch_var, batch_count):
25 | delta = batch_mean - mean
26 | tot_count = count + batch_count
27 |
28 | new_mean = mean + delta * batch_count / tot_count
29 | m_a = var * count
30 | m_b = batch_var * batch_count
31 |     M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count  # parallel variance combination (Chan et al.)
32 | new_var = M2 / tot_count
33 | new_count = tot_count
34 |
35 | return new_mean, new_var, new_count
36 |
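A minimal usage sketch for RunningMeanStd (the shape, batch, and clip range below are illustrative assumptions, not taken from this repository), showing the typical pattern of folding a batch of observations into the running moments and then normalizing:

    import numpy as np

    obs_rms = RunningMeanStd(shape=(4,))          # e.g. a 4-dimensional observation space
    batch = np.random.randn(32, 4)                # stand-in for collected observations
    obs_rms.update(batch)                         # update running mean / var / count
    normalized = (batch - obs_rms.mean) / np.sqrt(obs_rms.var + 1e-8)
    normalized = np.clip(normalized, -5.0, 5.0)   # clipping after normalization is a common choice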
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v0_ep179_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v0_ep179_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v0_ep87_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v0_ep87_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep108_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep108_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep112_clear_model_dddqn.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep112_clear_model_dddqn.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep1150_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep1150_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep118_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep118_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep1338_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep1338_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep134_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep134_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep164_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep164_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep168_clear_model_dddqn.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep168_clear_model_dddqn.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep1715_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep1715_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep213_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep213_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep216_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep216_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep222_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep222_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep223_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep223_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep258_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep258_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep270_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep270_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep273_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep273_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep294_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep294_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep295_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep295_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep296_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep296_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep308_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep308_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep313_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep313_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep320_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep320_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep326_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep326_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep329_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep329_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep357_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep357_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep376_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep376_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep385_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep385_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep388_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep388_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep408_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep408_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep417_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep417_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep418_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep418_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep420_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep420_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep424_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep424_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep433_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep433_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep436_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep436_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep445_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep445_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep454_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep454_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep456_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep456_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep458_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep458_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep463_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep463_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep470_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep470_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep474_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep474_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep475_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep475_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep492_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep492_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep498_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep498_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep504_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep504_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep506_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep506_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep509_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep509_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep516_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep516_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep519_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep519_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep530_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep530_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep534_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep534_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep546_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep546_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep552_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep552_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep555_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep555_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep559_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep559_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep566_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep566_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep569_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep569_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep586_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep586_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep589_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep589_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep595_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep595_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep631_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep631_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep635_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep635_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep637_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep637_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep642_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep642_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep675_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep675_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep690_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep690_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep727_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep727_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep741_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep741_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep763_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep763_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep840_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep840_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep845_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep845_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep866_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep866_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep871_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep871_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep892_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep892_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep903_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep903_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep924_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep924_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_ep939_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_ep939_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_up50_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_up50_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/CartPole-v1_up50_clear_norm_obs.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/CartPole-v1_up50_clear_norm_obs.pkl
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_ep260_clear_model_dqn.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_ep260_clear_model_dqn.pt
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_ep370_clear_model_dddqn.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_ep370_clear_model_dddqn.pt
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_ep8461_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_ep8461_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_ep876_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_ep876_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_up1099_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up1099_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_up1317_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up1317_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_up254_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up254_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_up570_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up570_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/LunarLander-v2_up75_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/LunarLander-v2_up75_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/MountainCar-v0_ep304_clear_model_dqn.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_ep304_clear_model_dqn.pt
--------------------------------------------------------------------------------
/src/test/saved_models/MountainCar-v0_ep532_clear_model_dddqn.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_ep532_clear_model_dddqn.pt
--------------------------------------------------------------------------------
/src/test/saved_models/MountainCar-v0_ep984_clear_model_dddqn.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_ep984_clear_model_dddqn.pt
--------------------------------------------------------------------------------
/src/test/saved_models/MountainCar-v0_up1441_clear_model_ppo_st.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_up1441_clear_model_ppo_st.pt
--------------------------------------------------------------------------------
/src/test/saved_models/MountainCar-v0_up1441_clear_norm_obs.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kekmodel/rl_pytorch/0f68e3816807cb8b8a6151d18e28cebcc296f2e1/src/test/saved_models/MountainCar-v0_up1441_clear_norm_obs.pkl
--------------------------------------------------------------------------------
/src/test/test_dueling_double_dqn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "colab": {},
8 | "colab_type": "code",
9 | "id": "IWnm3qot3o1W"
10 | },
11 | "outputs": [],
12 | "source": [
13 | "from collections import deque\n",
14 | "\n",
15 | "import gym\n",
16 | "import imageio\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "import numpy as np\n",
19 | "import torch\n",
20 | "import torch.nn as nn\n",
21 | "from torch.nn import functional as F\n",
22 | "from torch.distributions import Categorical"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {
29 | "colab": {},
30 | "colab_type": "code",
31 | "id": "9Ffkl_5C4R81"
32 | },
33 | "outputs": [],
34 | "source": [
35 | "class DuelingDQN(nn.Module):\n",
36 | " def __init__(self, obs_space, action_space):\n",
37 | " super().__init__()\n",
38 | "\n",
39 | " self.head = nn.Sequential(\n",
40 | " nn.Linear(obs_space, obs_space*10),\n",
41 | " nn.SELU()\n",
42 | " )\n",
43 | "\n",
44 | " self.val = nn.Sequential(\n",
45 | " nn.Linear(obs_space*10, 512),\n",
46 | " nn.SELU(),\n",
47 | " nn.Linear(512, 512),\n",
48 | " nn.SELU(),\n",
49 | " nn.Linear(512, 1)\n",
50 | " )\n",
51 | "\n",
52 | " self.adv = nn.Sequential(\n",
53 | " nn.Linear(obs_space*10, 512),\n",
54 | " nn.SELU(),\n",
55 | " nn.Linear(512, 512),\n",
56 | " nn.SELU(),\n",
57 | " nn.Linear(512, action_space)\n",
58 | " )\n",
59 | "\n",
60 | " def forward(self, x):\n",
61 | " out = self.head(x)\n",
62 | " val_out = self.val(out).reshape(out.shape[0], 1)\n",
63 | " adv_out = self.adv(out).reshape(out.shape[0], -1)\n",
64 | " adv_mean = adv_out.mean(dim=1, keepdim=True)\n",
65 |     "        q = val_out + adv_out - adv_mean  # dueling aggregation: Q = V + (A - mean(A))\n",
66 | "\n",
67 | " return q"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 3,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "def select_action_(obs, tgt_net):\n",
77 | " tgt_net.eval()\n",
78 | " with torch.no_grad():\n",
79 | " state = torch.tensor([obs]).to(device).float()\n",
80 | " q = tgt_net(state)\n",
81 | " action = torch.argmax(q)\n",
82 | "\n",
83 | " return action.item()\n",
84 | "\n",
85 | "\n",
86 | "def select_action(obs, tgt_net):\n",
87 | " tgt_net.eval()\n",
88 | " with torch.no_grad():\n",
89 | " state = torch.tensor([obs]).to(device).float()\n",
90 | " q = tgt_net(state)\n",
91 |     "        probs = F.softmax(q / 0.35, dim=-1)  # temperature-scaled softmax over action values\n",
92 | " m = Categorical(probs)\n",
93 | " action = m.sample()\n",
94 | "\n",
95 | " return action.item()"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {},
101 | "source": [
102 | "## Main"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 4,
108 | "metadata": {
109 | "colab": {
110 | "base_uri": "https://localhost:8080/",
111 | "height": 3377
112 | },
113 | "colab_type": "code",
114 | "executionInfo": {
115 | "elapsed": 135196,
116 | "status": "ok",
117 | "timestamp": 1534482559393,
118 | "user": {
119 | "displayName": "윤승제",
120 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg",
121 | "userId": "105654037995838004821"
122 | },
123 | "user_tz": -540
124 | },
125 | "id": "PnifSBJglzHh",
126 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a",
127 | "scrolled": false
128 | },
129 | "outputs": [
130 | {
131 | "name": "stderr",
132 | "output_type": "stream",
133 | "text": [
134 | "/home/jay/anaconda3/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n",
135 | " result = entry_point.load(False)\n"
136 | ]
137 | }
138 | ],
139 | "source": [
140 | "# set device\n",
141 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
142 | "\n",
143 | "# make an environment\n",
144 | "# env = gym.make('CartPole-v0')\n",
145 | "env = gym.make('CartPole-v1')\n",
146 | "# env = gym.make('MountainCar-v0')\n",
147 | "# env = gym.make('LunarLander-v2')\n",
148 | "\n",
149 | "SEED = 0\n",
150 | "env.seed(SEED)\n",
151 | "obs_space = env.observation_space.shape[0]\n",
152 | "action_space = env.action_space.n\n",
153 | "\n",
154 | "# hyperparameter\n",
155 | "n_episodes = 1000\n",
156 | "n_eval = env.spec.trials\n",
157 | "\n",
158 | "# global values\n",
159 | "total_steps = 0\n",
160 | "rewards = []\n",
161 | "reward_eval = deque(maxlen=n_eval)\n",
162 | "is_solved = False\n",
163 | "\n",
164 | "# load a model\n",
165 | "target_net = DuelingDQN(obs_space, action_space).to(device)\n",
166 | "target_net.load_state_dict(torch.load(\n",
167 | " './saved_models/CartPole-v1_ep217_clear_model_dddqn.pt'))"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 5,
173 | "metadata": {
174 | "scrolled": true
175 | },
176 | "outputs": [
177 | {
178 | "data": {
179 | "text/plain": [
180 | "500"
181 | ]
182 | },
183 | "execution_count": 5,
184 | "metadata": {},
185 | "output_type": "execute_result"
186 | }
187 | ],
188 | "source": [
189 | "env.spec.max_episode_steps"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 6,
195 | "metadata": {
196 | "scrolled": true
197 | },
198 | "outputs": [
199 | {
200 | "data": {
201 | "text/plain": [
202 | "100"
203 | ]
204 | },
205 | "execution_count": 6,
206 | "metadata": {},
207 | "output_type": "execute_result"
208 | }
209 | ],
210 | "source": [
211 | "env.spec.trials"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 7,
217 | "metadata": {},
218 | "outputs": [
219 | {
220 | "data": {
221 | "text/plain": [
222 | "475.0"
223 | ]
224 | },
225 | "execution_count": 7,
226 | "metadata": {},
227 | "output_type": "execute_result"
228 | }
229 | ],
230 | "source": [
231 | "env.spec.reward_threshold"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 8,
237 | "metadata": {},
238 | "outputs": [],
239 | "source": [
240 | "# env.metadata['video.frames_per_second'] = 60"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 9,
246 | "metadata": {
247 | "colab": {
248 | "base_uri": "https://localhost:8080/",
249 | "height": 3377
250 | },
251 | "colab_type": "code",
252 | "executionInfo": {
253 | "elapsed": 135196,
254 | "status": "ok",
255 | "timestamp": 1534482559393,
256 | "user": {
257 | "displayName": "윤승제",
258 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg",
259 | "userId": "105654037995838004821"
260 | },
261 | "user_tz": -540
262 | },
263 | "id": "PnifSBJglzHh",
264 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a",
265 | "scrolled": true
266 | },
267 | "outputs": [
268 | {
269 | "name": "stderr",
270 | "output_type": "stream",
271 | "text": [
272 | "/home/jay/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:16: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
273 | " app.launch_new_instance()\n"
274 | ]
275 | },
276 | {
277 | "name": "stdout",
278 | "output_type": "stream",
279 | "text": [
280 | " 1 Episode in 500 steps, reward 500.00\n",
281 | " 2 Episode in 1000 steps, reward 500.00\n",
282 | " 3 Episode in 1500 steps, reward 500.00\n",
283 | " 4 Episode in 2000 steps, reward 500.00\n",
284 | " 5 Episode in 2500 steps, reward 500.00\n",
285 | " 6 Episode in 3000 steps, reward 500.00\n",
286 | " 7 Episode in 3500 steps, reward 500.00\n",
287 | " 8 Episode in 4000 steps, reward 500.00\n",
288 | " 9 Episode in 4500 steps, reward 500.00\n",
289 | " 10 Episode in 5000 steps, reward 500.00\n",
290 | " 11 Episode in 5500 steps, reward 500.00\n",
291 | " 12 Episode in 6000 steps, reward 500.00\n",
292 | " 13 Episode in 6500 steps, reward 500.00\n",
293 | " 14 Episode in 7000 steps, reward 500.00\n",
294 | " 15 Episode in 7500 steps, reward 500.00\n",
295 | " 16 Episode in 8000 steps, reward 500.00\n",
296 | " 17 Episode in 8500 steps, reward 500.00\n",
297 | " 18 Episode in 9000 steps, reward 500.00\n",
298 | " 19 Episode in 9500 steps, reward 500.00\n",
299 | " 20 Episode in 10000 steps, reward 500.00\n",
300 | " 21 Episode in 10500 steps, reward 500.00\n",
301 | " 22 Episode in 11000 steps, reward 500.00\n",
302 | " 23 Episode in 11500 steps, reward 500.00\n",
303 | " 24 Episode in 12000 steps, reward 500.00\n",
304 | " 25 Episode in 12500 steps, reward 500.00\n",
305 | " 26 Episode in 13000 steps, reward 500.00\n",
306 | " 27 Episode in 13500 steps, reward 500.00\n",
307 | " 28 Episode in 14000 steps, reward 500.00\n",
308 | " 29 Episode in 14500 steps, reward 500.00\n",
309 | " 30 Episode in 15000 steps, reward 500.00\n",
310 | " 31 Episode in 15500 steps, reward 500.00\n",
311 | " 32 Episode in 16000 steps, reward 500.00\n",
312 | " 33 Episode in 16500 steps, reward 500.00\n",
313 | " 34 Episode in 17000 steps, reward 500.00\n",
314 | " 35 Episode in 17500 steps, reward 500.00\n",
315 | " 36 Episode in 18000 steps, reward 500.00\n",
316 | " 37 Episode in 18500 steps, reward 500.00\n",
317 | " 38 Episode in 19000 steps, reward 500.00\n",
318 | " 39 Episode in 19500 steps, reward 500.00\n",
319 | " 40 Episode in 20000 steps, reward 500.00\n",
320 | " 41 Episode in 20500 steps, reward 500.00\n",
321 | " 42 Episode in 21000 steps, reward 500.00\n",
322 | " 43 Episode in 21500 steps, reward 500.00\n",
323 | " 44 Episode in 22000 steps, reward 500.00\n",
324 | " 45 Episode in 22500 steps, reward 500.00\n",
325 | " 46 Episode in 23000 steps, reward 500.00\n",
326 | " 47 Episode in 23500 steps, reward 500.00\n",
327 | " 48 Episode in 24000 steps, reward 500.00\n",
328 | " 49 Episode in 24500 steps, reward 500.00\n",
329 | " 50 Episode in 25000 steps, reward 500.00\n",
330 | " 51 Episode in 25500 steps, reward 500.00\n",
331 | " 52 Episode in 26000 steps, reward 500.00\n",
332 | " 53 Episode in 26500 steps, reward 500.00\n",
333 | " 54 Episode in 27000 steps, reward 500.00\n",
334 | " 55 Episode in 27500 steps, reward 500.00\n",
335 | " 56 Episode in 28000 steps, reward 500.00\n",
336 | " 57 Episode in 28500 steps, reward 500.00\n",
337 | " 58 Episode in 29000 steps, reward 500.00\n",
338 | " 59 Episode in 29500 steps, reward 500.00\n",
339 | " 60 Episode in 30000 steps, reward 500.00\n",
340 | " 61 Episode in 30500 steps, reward 500.00\n",
341 | " 62 Episode in 31000 steps, reward 500.00\n",
342 | " 63 Episode in 31500 steps, reward 500.00\n",
343 | " 64 Episode in 32000 steps, reward 500.00\n",
344 | " 65 Episode in 32500 steps, reward 500.00\n",
345 | " 66 Episode in 33000 steps, reward 500.00\n",
346 | " 67 Episode in 33500 steps, reward 500.00\n",
347 | " 68 Episode in 34000 steps, reward 500.00\n",
348 | " 69 Episode in 34500 steps, reward 500.00\n",
349 | " 70 Episode in 35000 steps, reward 500.00\n",
350 | " 71 Episode in 35500 steps, reward 500.00\n",
351 | " 72 Episode in 36000 steps, reward 500.00\n",
352 | " 73 Episode in 36500 steps, reward 500.00\n",
353 | " 74 Episode in 37000 steps, reward 500.00\n",
354 | " 75 Episode in 37500 steps, reward 500.00\n",
355 | " 76 Episode in 38000 steps, reward 500.00\n",
356 | " 77 Episode in 38500 steps, reward 500.00\n",
357 | " 78 Episode in 39000 steps, reward 500.00\n",
358 | " 79 Episode in 39500 steps, reward 500.00\n",
359 | " 80 Episode in 40000 steps, reward 500.00\n",
360 | " 81 Episode in 40500 steps, reward 500.00\n",
361 | " 82 Episode in 41000 steps, reward 500.00\n",
362 | " 83 Episode in 41500 steps, reward 500.00\n",
363 | " 84 Episode in 42000 steps, reward 500.00\n",
364 | " 85 Episode in 42500 steps, reward 500.00\n",
365 | " 86 Episode in 43000 steps, reward 500.00\n",
366 | " 87 Episode in 43500 steps, reward 500.00\n",
367 | " 88 Episode in 44000 steps, reward 500.00\n",
368 | " 89 Episode in 44500 steps, reward 500.00\n",
369 | " 90 Episode in 45000 steps, reward 500.00\n",
370 | " 91 Episode in 45500 steps, reward 500.00\n",
371 | " 92 Episode in 46000 steps, reward 500.00\n",
372 | " 93 Episode in 46500 steps, reward 500.00\n",
373 | " 94 Episode in 47000 steps, reward 500.00\n",
374 | " 95 Episode in 47500 steps, reward 500.00\n",
375 | " 96 Episode in 48000 steps, reward 500.00\n",
376 | " 97 Episode in 48500 steps, reward 500.00\n",
377 | " 98 Episode in 49000 steps, reward 500.00\n",
378 | " 99 Episode in 49500 steps, reward 500.00\n",
379 | "100 Episode in 50000 steps, reward 500.00\n",
380 | "\n",
381 | "CartPole-v1 is sloved! 100 Episode in 50000 steps\n",
382 | "Mean Reward: 500.0\n"
383 | ]
384 | }
385 | ],
386 | "source": [
387 | "# play\n",
388 | "# frames = []\n",
389 | "for i in range(1, n_episodes + 1):\n",
390 | " obs = env.reset()\n",
391 | " done = False\n",
392 | " ep_reward = 0\n",
393 | " while not done:\n",
394 | "# frames.append(env.render(mode = 'rgb_array'))\n",
395 | " env.render()\n",
396 | " action = select_action(obs, target_net)\n",
397 | " _obs, reward, done, _ = env.step(action)\n",
398 | " obs = _obs\n",
399 | " total_steps += 1\n",
400 | " ep_reward += reward \n",
401 | " if done:\n",
402 | " env.render()\n",
403 | " rewards.append(ep_reward)\n",
404 | " reward_eval.append(ep_reward)\n",
405 | " print('{:3} Episode in {:5} steps, reward {:.2f}'.format(\n",
406 | " i, total_steps, ep_reward))\n",
407 | "# frames.append(env.render(mode = 'rgb_array'))\n",
408 | "# imageio.mimsave(f'{env.spec.id}.gif', frames,)\n",
409 | " \n",
410 | " if len(reward_eval) >= n_eval:\n",
411 | " if np.mean(reward_eval) >= env.spec.reward_threshold:\n",
412 | " print('\\n{} is sloved! {:3} Episode in {:3} steps'.format(\n",
413 | " env.spec.id, i, total_steps))\n",
414 | " print(f'Mean Reward: {np.mean(reward_eval).round(decimals=2)}')\n",
415 | " break\n",
416 | "env.close()"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": 10,
422 | "metadata": {
423 | "scrolled": false
424 | },
425 | "outputs": [
426 | {
427 | "data": {
428 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3AAAAE/CAYAAAAHeyFHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFTRJREFUeJzt3X2wpnV93/HPV1YwUpAoKwMsZM2IaZCJhDmhNKQaJRV5iJhWp5tqZIwpfWAmJtNWpUmbxKkzzTSjlj6QUEyKbVApZputsY5UIqaTAd0VyqMmOzy426Wy+AAOKBb59o/7WjnCWfYse3Zvfnter5kz57p+53ef+3eYa67De6/rvk91dwAAAHj2e868FwAAAMDyCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAWEFVdU9V/cy81wHAwUnAAQAADELAATC8qlqzGp4TAAQcAEOablV8V1XdkuThqjqxqj5WVTur6u6q+uVp3vOq6ltVdfS0/+tV9VhVHTnt/8uq+sC0fV5V3VRVD1XVtqr6zUXPt76quqreXlVfTnLdNP4LVXVvVX21qn7tAP9nAGCVEXAAjOznk5yX5IVJNib530mOT3JWkl+pqrO7+9tJPp/kVdNjXpnk3iRnLtq/ftp+OMlbkxw1fd9/WFVveNJzvirJjyY5u6pOTnJZkl9IclySFyVZt8I/IwB8j4ADYGSXdve2JKckWdvd7+nu73T3XUn+Y5IN07zrk7xquu3xx5JcOu0/L8lPJPmzJOnuz3T3rd39eHffkuTDeSL8dvnN7n64u7+V5I1JPt7dn+3uR5P88ySP798fGYDVzP37AIxs2/T5h5IcV1XfWPS1QzKFWWYB974kpyW5Ncm1ST6Y5IwkW7v7gSSpqr+W5F9lFoSHJjksyX/dzXMms6tu39vv7oer6qv7/mMBwNJcgQNgZD193pbk7u4+atHHEd197vT1P0/yI0l+Lsn13X1HkhMzu03y+kXf76okm5Kc0N0vSPK7SWo3z5kk9yU5YddOVT0/s9soAWC/EHAAHAw+l+Sh6U1NfqCqDqmqU6rqJ5Kkux9JsiXJxXki2P48yd/P9wfcEUm+1t3frqrTk/zdPTzvNUnOr6qfqqpDk7wnfrcCsB/5JQPA8Lr7u0l+NsmpSe5O8kCSK5K8YNG065M8N7PY27V/RJLPLprzj5K8p6q+meRfJLl6D897e2ZReFVmV+O+nmT7Pv44ALBb1d17ngUAAMDcuQIHAAAwCAEHAAAwCAEHAAAwCAEHAAAwCAEHAAAwiDXzXkCSHH300b1+/fp5LwMAAGAutmzZ8kB3r93TvGdFwK1fvz6bN2+e9zIAAADmoqruXc48t1ACAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMYlkBV1X3VNWtVXVzVW2exv51VX2xqm6pqo1VddSi+ZdU1daq+lJVnb2/Fg8AALCa7M0VuFd396ndvTDtX5vklO7+sSR/keSSJKmqk5NsSPLyJK9L8h+q6pAVXDMAAMCq9IxvoezuT3X3Y9PuDUnWTdsXJPlIdz/a3Xcn2Zrk9H1bJgAAAMsNuE7yqaraUlUXLfH1X0zyP6bt45NsW/S17dMYAAAA+2DNMued2d07qurFSa6tqi9292eTpKp+LcljSf5wmltLPL6fPDCF4EVJcuKJJ+71wgEAAFabZV2B6+4d0+f7k2zMdEtkVV2Y5Pwkb+7uXZG2PckJix6+LsmOJb7n5d290N0La9eufeY/AQAAwCqxx4CrqsOr6ohd20lem+S2qnpdkncleX13P7LoIZuSbKiqw6rqJUlOSvK5lV86AADA6rKcWyiPSbKxqnbNv6q7P1lVW5McltktlUlyQ3f/g+6+vaquTnJHZrdWXtzd390/ywcAAFg99hhw3X1XklcsMf7Sp3nMe5O8d9+WBgAAwGLP+M8IAAAAcGAJOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEIOAAAgEEsK+Cq6p6qurWqbq6qzdPYm6rq9qp6vKoWnjT/kqraWlVfqqqz98fCAQAAVps1ezH31d39wKL925L8rSS/t3hSVZ2cZEOSlyc5Lsn/rKqXdfd393WxAAAAq9kzvoWyu+/s7i8t8aULknykux/t7ruTbE1y+jN9HgAAAGaWG3Cd5FNVtaWqLtrD3OOTbFu0v30a+z5VdVFVba6qzTt37lzmMgAAAFav5Qbcmd19WpJzklxcVa98mrm1xFg/ZaD78u5e6O6FtWvXLnMZAAAAq9eyAq67d0yf70+yMU9/S+T2JCcs2l+XZMczXSAAAAAzewy4qjq8qo7YtZ3ktZm9gcnubEqyoaoOq6qXJDkpyedWYrEAAACr2XLehfKYJBuratf8q7r7k1X1c0n+bZK1Sf6kqm7u7rO7+/aqujrJHUkeS3Kxd6AEAADYd9X9lJenHXALCwu9efPmeS8DAABgLqpqS3cv7GneM/4zAgAAABxYAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQywq4qrqnqm6tqpuravM09sKquraq/nL6/IPTeFXVpVW1tapuqarT9ucPAAAAsFrszRW4V3f3qd29MO2/O8mnu/ukJJ+e9pPknCQnTR8XJblspRYLAACwmu3LLZQXJLly2r4yyRsWjX+oZ25IclRVHbsPzwMAAECSNcuc10k+VVWd5Pe6+/Ikx3T3fUnS3fdV1Yunuccn2bbosdunsftWaM0HxG/999tzx46H5r0
MAABgBZ183JH5jZ99+byX8YwtN+DO7O4dU6RdW1VffJq5tcRYP2VS1UWZ3WKZE088cZnLAAAAWL2WFXDdvWP6fH9VbUxyepKvVNWx09W3Y5PcP03fnuSERQ9fl2THEt/z8iSXJ8nCwsJTAm/eRq5yAADg4LTH18BV1eFVdcSu7SSvTXJbkk1JLpymXZjkj6ftTUneOr0b5RlJHtx1qyUAAADP3HKuwB2TZGNV7Zp/VXd/sqo+n+Tqqnp7ki8nedM0/xNJzk2yNckjSd624qsGAABYhfYYcN19V5JXLDH+1SRnLTHeSS5ekdUBAADwPfvyZwQAAAA4gAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIJYdcFV1SFXdVFUfn/ZfU1VfqKrbqurKqlozjVdVXVpVW6vqlqo6bX8tHgAAYDXZmytw70hyZ5JU1XOSXJlkQ3efkuTeJBdO885JctL0cVGSy1ZstQAAAKvYsgKuqtYlOS/JFdPQi5I82t1/Me1fm+RvT9sXJPlQz9yQ5KiqOnYF1wwAALAqLfcK3AeSvDPJ49P+A0meW1UL0/4bk5wwbR+fZNuix26fxr5PVV1UVZuravPOnTv3euEAAACrzR4DrqrOT3J/d2/ZNdbdnWRDkvdX1eeSfDPJY7sessS36acMdF/e3QvdvbB27dpntHgAAIDVZM0y5pyZ5PVVdW6S5yU5sqr+S3e/JcnfSJKqem2Sl03zt+eJq3FJsi7JjpVbMgAAwOq0xytw3X1Jd6/r7vWZXXW7rrvfUlUvTpKqOizJu5L87vSQTUneOr0b5RlJHuzu+/bP8gEAAFaP5VyB251/Ot1e+Zwkl3X3ddP4J5Kcm2RrkkeSvG3flggAAECS1OzlbPO1sLDQmzdvnvcyAAAA5qKqtnT3wp7m7c3fgQMAAGCOBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAgBBwAAMAglh1wVXVIVd1UVR+f9s+qqi9U1c1V9b+q6qXT+GFV9dGq2lpVN1bV+v2zdAAAgNVlb67AvSPJnYv2L0vy5u4+NclVSX59Gn97kq9390uTvD/Jb6/EQgEAAFa7ZQVcVa1Lcl6SKxYNd5Ijp+0XJNkxbV+Q5Mpp+5okZ1VV7ftSAQAAVrc1y5z3gSTvTHLEorFfSvKJqvpWkoeSnDGNH59kW5J092NV9WCSFyV5YEVWDAAAsErt8QpcVZ2f5P7u3vKkL/1qknO7e12SP0jyvl0PWeLb9BLf96Kq2lxVm3fu3LmXywYAAFh9lnML5ZlJXl9V9yT5SJLXVNWfJHlFd984zflokp+ctrcnOSFJqmpNZrdXfu3J37S7L+/uhe5eWLt27b79FAAAAKvAHgOuuy/p7nXdvT7JhiTXZfY6txdU1cumaX8zT7zByaYkF07bb0xyXXc/5QocAAAAe2e5r4H7PtNr2/5eko9V1eNJvp7kF6cvfzDJf66qrZldeduwIisFAABY5fYq4Lr7M0k+M21vTLJxiTnfTvKmFVgbAAAAi+zN34EDAABgjgQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIKq7572GVNXOJPfOex1LODrJA/NeBAc9xxkHguOM/c0xxoHgOONAmNdx9kPdvXZPk54VAfdsVVWbu3th3uvg4OY440BwnLG/OcY4EBxnHAjP9uPMLZQAAACDEHAAAACDEHBP7/J5L4BVwXHGgeA4Y39zjHEgOM44EJ7Vx5nXwAEAAAzCFTgAAIBBCLjdqKrXVdWXqmprVb173uthfFV1QlX9aVXdWVW3V9U7pvEXVtW1VfWX0+cfnPdaGV9VHVJVN1XVx6f9l1TVjdNx9tGqOnTea2RsVXVUVV1TVV+czmt/3fmMlVRVvzr9vrytqj5cVc9zLmMlVNXvV9X9VXXborElz181c+nUBLdU1WnzW/mMgFtCVR2S5N8nOSfJyUl+vqpOnu+qOAg8luQfd/ePJjkjycXTcfXuJJ/u7pOSfHrah331jiR3Ltr/7STvn46zryd5+1xWxcHk3yT5ZHf/1SSvyOx4cz5jRVTV8Ul+OclCd5+S5JAkG+Jcxsr4T0le96Sx3Z2/zkly0vRxUZLLDtAad0vALe30JFu7+67u/k6SjyS5YM5rYnDdfV93f2Ha/mZm/7NzfGbH1pXTtCuTvGE+K+RgUVXrkpyX5Ippv5K8Jsk10xTHGfukqo5M8sokH0yS7v5Od38jzmesrDVJfqCq1iR5fpL74lzGCujuzyb52pOGd3f+uiDJh3rmhiRHVdWxB2alSxNwSzs+ybZF+9unMVgRVbU+yY8nuTHJMd19XzKLvCQvnt/KOEh8IMk7kzw+7b8oyTe6+7Fp3zmNffXDSXYm+YPpVt0rqurwOJ+xQrr7/yT5nSRfzizcHkyyJc5l7D+7O38967pAwC2tlhjzdp2siKr6K0k+luRXuvuhea+Hg0tVnZ/k/u7esnh4ianOaeyLNUlOS3JZd/94kofjdklW0PT6owuSvCTJcUkOz+xWtidzLmN/e9b9DhVwS9ue5IRF++uS7JjTWjiIVNVzM4u3P+zuP5qGv7LrUvz0+f55rY+DwplJXl9V92R2+/drMrsid9R0G1LinMa+255ke3ffOO1fk1nQOZ+xUn4myd3dvbO7/1+SP0ryk3EuY//Z3fnrWdcFAm5pn09y0vROR4dm9qLZTXNeE4ObXof0wSR3dvf7Fn1pU5ILp+0Lk/zxgV4bB4/uvqS713X3+szOXdd195uT/GmSN07THGfsk+7+v0m2VdWPTENnJbkjzmesnC
8nOaOqnj/9/tx1jDmXsb/s7vy1Kclbp3ejPCPJg7tutZwXf8h7N6rq3Mz+1fqQJL/f3e+d85IYXFX9VJI/S3Jrnnht0j/L7HVwVyc5MbNfWG/q7ie/sBb2WlX9dJJ/0t3nV9UPZ3ZF7oVJbkrylu5+dJ7rY2xVdWpmb5RzaJK7krwts38Ydj5jRVTVbyX5O5m9i/NNSX4ps9ceOZexT6rqw0l+OsnRSb6S5DeS/Lcscf6a/gHh32X2rpWPJHlbd2+ex7p3EXAAAACDcAslAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIP4/OeGmLRKsYKQAAAAASUVORK5CYII=\n",
429 | "text/plain": [
430 | ""
431 | ]
432 | },
433 | "metadata": {
434 | "needs_background": "light"
435 | },
436 | "output_type": "display_data"
437 | }
438 | ],
439 | "source": [
440 | "plt.figure(figsize=(15, 5))\n",
441 | "plt.title('reward')\n",
442 | "plt.plot(rewards)\n",
443 | "plt.show()"
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "execution_count": 11,
449 | "metadata": {},
450 | "outputs": [
451 | {
452 | "data": {
453 | "text/plain": [
454 | "[('CartPole-v0', 412, 1),\n",
455 | " ('CartPole-v1', 452, 0.05),\n",
456 | " ('MountainCar-v0', 193, 0.1),\n",
457 | " ('LunarLander-v2', 260, 0.1)]"
458 | ]
459 | },
460 | "execution_count": 11,
461 | "metadata": {},
462 | "output_type": "execute_result"
463 | }
464 | ],
465 | "source": [
466 | "[\n",
467 | " ('CartPole-v0', 412, 1),\n",
468 | " ('CartPole-v1', 452, 0.05),\n",
469 | " ('MountainCar-v0', 193, 0.1),\n",
470 | " ('LunarLander-v2', 260, 0.1)\n",
471 | "]"
472 | ]
473 | }
474 | ],
475 | "metadata": {
476 | "colab": {
477 | "collapsed_sections": [],
478 | "name": "C51_tensorflow.ipynb",
479 | "provenance": [],
480 | "version": "0.3.2"
481 | },
482 | "kernelspec": {
483 | "display_name": "Python 3",
484 | "language": "python",
485 | "name": "python3"
486 | },
487 | "language_info": {
488 | "codemirror_mode": {
489 | "name": "ipython",
490 | "version": 3
491 | },
492 | "file_extension": ".py",
493 | "mimetype": "text/x-python",
494 | "name": "python",
495 | "nbconvert_exporter": "python",
496 | "pygments_lexer": "ipython3",
497 | "version": "3.6.7"
498 | }
499 | },
500 | "nbformat": 4,
501 | "nbformat_minor": 1
502 | }
503 |
--------------------------------------------------------------------------------
/src/test/test_ppo_with_rnd.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "colab": {},
8 | "colab_type": "code",
9 | "id": "IWnm3qot3o1W"
10 | },
11 | "outputs": [],
12 | "source": [
13 | "from collections import deque\n",
14 | "\n",
15 | "import gym\n",
16 | "import imageio\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "import numpy as np\n",
19 | "import torch\n",
20 | "import torch.nn as nn\n",
21 | "from torch.distributions import Categorical "
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {
28 | "colab": {},
29 | "colab_type": "code",
30 | "id": "9Ffkl_5C4R81"
31 | },
32 | "outputs": [],
33 | "source": [
34 | "class ActorCriticNet(nn.Module):\n",
35 | " def __init__(self, obs_space, action_space):\n",
36 | " super().__init__()\n",
37 | "\n",
38 | " self.head = nn.Sequential(\n",
39 | " nn.Linear(obs_space, obs_space*10),\n",
40 | " nn.SELU()\n",
41 | " )\n",
42 | " self.pol = nn.Sequential(\n",
43 | " nn.Linear(obs_space*10, 512),\n",
44 | " nn.SELU(),\n",
45 | " nn.Linear(512, 512),\n",
46 | " nn.SELU(),\n",
47 | " nn.Linear(512, action_space)\n",
48 | " )\n",
49 | " self.val_ex = nn.Sequential(\n",
50 | " nn.Linear(obs_space*10, 512),\n",
51 | " nn.SELU(),\n",
52 | " nn.Linear(512, 512),\n",
53 | " nn.SELU(),\n",
54 | " nn.Linear(512, 1)\n",
55 | " )\n",
56 | " self.val_in = nn.Sequential(\n",
57 | " nn.Linear(obs_space*10, 512),\n",
58 | " nn.SELU(),\n",
59 | " nn.Linear(512, 512),\n",
60 | " nn.SELU(),\n",
61 | " nn.Linear(512, 1)\n",
62 | " )\n",
63 | " self.log_softmax = nn.LogSoftmax(dim=-1)\n",
64 | "\n",
65 | " def forward(self, x):\n",
66 | " out = self.head(x)\n",
67 | " logit = self.pol(out).reshape(out.shape[0], -1)\n",
68 | " value_ex = self.val_ex(out).reshape(out.shape[0], 1)\n",
69 | " value_in = self.val_in(out).reshape(out.shape[0], 1)\n",
70 | " log_probs = self.log_softmax(logit)\n",
71 | " \n",
72 | " return log_probs, value_ex, value_in"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 3,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "def get_action_and_value(obs, old_net):\n",
82 | " old_net.eval()\n",
83 | " with torch.no_grad():\n",
84 | " state = torch.tensor([obs]).to(device).float()\n",
85 | " log_p, _, _ = old_net(state)\n",
86 | " m = Categorical(log_p.exp())\n",
87 | " action = m.sample()\n",
88 | "\n",
89 | " return action.item()"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "## Main"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 4,
102 | "metadata": {
103 | "colab": {
104 | "base_uri": "https://localhost:8080/",
105 | "height": 3377
106 | },
107 | "colab_type": "code",
108 | "executionInfo": {
109 | "elapsed": 135196,
110 | "status": "ok",
111 | "timestamp": 1534482559393,
112 | "user": {
113 | "displayName": "윤승제",
114 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg",
115 | "userId": "105654037995838004821"
116 | },
117 | "user_tz": -540
118 | },
119 | "id": "PnifSBJglzHh",
120 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a",
121 | "scrolled": false
122 | },
123 | "outputs": [],
124 | "source": [
125 | "# set device\n",
126 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
127 | "\n",
128 | "# make an environment\n",
129 | "# env = gym.make('CartPole-v0')\n",
130 | "env = gym.make('CartPole-v1')\n",
131 | "# env = gym.make('MountainCar-v0')\n",
132 | "# env = gym.make('LunarLander-v2')\n",
133 | "\n",
134 | "SEED = 0\n",
135 | "env.seed(SEED)\n",
136 | "obs_space = env.observation_space.shape[0]\n",
137 | "action_space = env.action_space.n\n",
138 | "\n",
139 | "# hyperparameter\n",
140 | "n_episodes = 1000\n",
141 | "n_eval = env.spec.trials\n",
142 | "\n",
143 | "# global values\n",
144 | "total_steps = 0\n",
145 | "rewards = []\n",
146 | "reward_eval = deque(maxlen=n_eval)\n",
147 | "is_solved = False\n",
148 | "\n",
149 | "# load a model\n",
150 | "target_net = ActorCriticNet(obs_space, action_space).to(device)\n",
151 | "target_net.load_state_dict(torch.load(\n",
152 | " './saved_models/CartPole-v1_ep225_clear_model_ppo.pt'))"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 5,
158 | "metadata": {
159 | "scrolled": true
160 | },
161 | "outputs": [
162 | {
163 | "data": {
164 | "text/plain": [
165 | "500"
166 | ]
167 | },
168 | "execution_count": 5,
169 | "metadata": {},
170 | "output_type": "execute_result"
171 | }
172 | ],
173 | "source": [
174 | "env.spec.max_episode_steps"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 6,
180 | "metadata": {
181 | "scrolled": true
182 | },
183 | "outputs": [
184 | {
185 | "data": {
186 | "text/plain": [
187 | "100"
188 | ]
189 | },
190 | "execution_count": 6,
191 | "metadata": {},
192 | "output_type": "execute_result"
193 | }
194 | ],
195 | "source": [
196 | "env.spec.trials"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 7,
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "data": {
206 | "text/plain": [
207 | "475.0"
208 | ]
209 | },
210 | "execution_count": 7,
211 | "metadata": {},
212 | "output_type": "execute_result"
213 | }
214 | ],
215 | "source": [
216 | "env.spec.reward_threshold"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 8,
222 | "metadata": {},
223 | "outputs": [],
224 | "source": [
225 | "# env.metadata['video.frames_per_second'] = 60"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 9,
231 | "metadata": {
232 | "colab": {
233 | "base_uri": "https://localhost:8080/",
234 | "height": 3377
235 | },
236 | "colab_type": "code",
237 | "executionInfo": {
238 | "elapsed": 135196,
239 | "status": "ok",
240 | "timestamp": 1534482559393,
241 | "user": {
242 | "displayName": "윤승제",
243 | "photoUrl": "//lh5.googleusercontent.com/-EucKC7DmcQI/AAAAAAAAAAI/AAAAAAAAAGA/gQU1NPEmNFA/s50-c-k-no/photo.jpg",
244 | "userId": "105654037995838004821"
245 | },
246 | "user_tz": -540
247 | },
248 | "id": "PnifSBJglzHh",
249 | "outputId": "94177345-918e-4a96-d9a8-d8aba0a4bc9a",
250 | "scrolled": true
251 | },
252 | "outputs": [
253 | {
254 | "name": "stdout",
255 | "output_type": "stream",
256 | "text": [
257 | " 1 Episode in 500 steps, reward 500.00\n",
258 | " 2 Episode in 1000 steps, reward 500.00\n",
259 | " 3 Episode in 1500 steps, reward 500.00\n",
260 | " 4 Episode in 2000 steps, reward 500.00\n",
261 | " 5 Episode in 2500 steps, reward 500.00\n",
262 | " 6 Episode in 3000 steps, reward 500.00\n",
263 | " 7 Episode in 3500 steps, reward 500.00\n",
264 | " 8 Episode in 4000 steps, reward 500.00\n",
265 | " 9 Episode in 4500 steps, reward 500.00\n",
266 | " 10 Episode in 5000 steps, reward 500.00\n",
267 | " 11 Episode in 5500 steps, reward 500.00\n",
268 | " 12 Episode in 6000 steps, reward 500.00\n",
269 | " 13 Episode in 6500 steps, reward 500.00\n",
270 | " 14 Episode in 7000 steps, reward 500.00\n",
271 | " 15 Episode in 7500 steps, reward 500.00\n",
272 | " 16 Episode in 8000 steps, reward 500.00\n",
273 | " 17 Episode in 8500 steps, reward 500.00\n",
274 | " 18 Episode in 9000 steps, reward 500.00\n",
275 | " 19 Episode in 9500 steps, reward 500.00\n",
276 | " 20 Episode in 10000 steps, reward 500.00\n",
277 | " 21 Episode in 10500 steps, reward 500.00\n",
278 | " 22 Episode in 11000 steps, reward 500.00\n",
279 | " 23 Episode in 11500 steps, reward 500.00\n",
280 | " 24 Episode in 12000 steps, reward 500.00\n",
281 | " 25 Episode in 12500 steps, reward 500.00\n",
282 | " 26 Episode in 13000 steps, reward 500.00\n",
283 | " 27 Episode in 13500 steps, reward 500.00\n",
284 | " 28 Episode in 14000 steps, reward 500.00\n",
285 | " 29 Episode in 14500 steps, reward 500.00\n",
286 | " 30 Episode in 15000 steps, reward 500.00\n",
287 | " 31 Episode in 15500 steps, reward 500.00\n",
288 | " 32 Episode in 16000 steps, reward 500.00\n",
289 | " 33 Episode in 16500 steps, reward 500.00\n",
290 | " 34 Episode in 17000 steps, reward 500.00\n",
291 | " 35 Episode in 17500 steps, reward 500.00\n",
292 | " 36 Episode in 18000 steps, reward 500.00\n",
293 | " 37 Episode in 18500 steps, reward 500.00\n",
294 | " 38 Episode in 19000 steps, reward 500.00\n",
295 | " 39 Episode in 19500 steps, reward 500.00\n",
296 | " 40 Episode in 20000 steps, reward 500.00\n",
297 | " 41 Episode in 20500 steps, reward 500.00\n",
298 | " 42 Episode in 21000 steps, reward 500.00\n",
299 | " 43 Episode in 21500 steps, reward 500.00\n",
300 | " 44 Episode in 22000 steps, reward 500.00\n",
301 | " 45 Episode in 22500 steps, reward 500.00\n",
302 | " 46 Episode in 23000 steps, reward 500.00\n",
303 | " 47 Episode in 23500 steps, reward 500.00\n",
304 | " 48 Episode in 24000 steps, reward 500.00\n",
305 | " 49 Episode in 24500 steps, reward 500.00\n",
306 | " 50 Episode in 25000 steps, reward 500.00\n",
307 | " 51 Episode in 25500 steps, reward 500.00\n",
308 | " 52 Episode in 26000 steps, reward 500.00\n",
309 | " 53 Episode in 26500 steps, reward 500.00\n",
310 | " 54 Episode in 27000 steps, reward 500.00\n",
311 | " 55 Episode in 27500 steps, reward 500.00\n",
312 | " 56 Episode in 28000 steps, reward 500.00\n",
313 | " 57 Episode in 28500 steps, reward 500.00\n",
314 | " 58 Episode in 29000 steps, reward 500.00\n",
315 | " 59 Episode in 29500 steps, reward 500.00\n",
316 | " 60 Episode in 30000 steps, reward 500.00\n",
317 | " 61 Episode in 30500 steps, reward 500.00\n",
318 | " 62 Episode in 31000 steps, reward 500.00\n",
319 | " 63 Episode in 31500 steps, reward 500.00\n",
320 | " 64 Episode in 32000 steps, reward 500.00\n",
321 | " 65 Episode in 32500 steps, reward 500.00\n",
322 | " 66 Episode in 33000 steps, reward 500.00\n",
323 | " 67 Episode in 33500 steps, reward 500.00\n",
324 | " 68 Episode in 34000 steps, reward 500.00\n",
325 | " 69 Episode in 34500 steps, reward 500.00\n",
326 | " 70 Episode in 35000 steps, reward 500.00\n",
327 | " 71 Episode in 35500 steps, reward 500.00\n",
328 | " 72 Episode in 36000 steps, reward 500.00\n",
329 | " 73 Episode in 36500 steps, reward 500.00\n",
330 | " 74 Episode in 37000 steps, reward 500.00\n",
331 | " 75 Episode in 37500 steps, reward 500.00\n",
332 | " 76 Episode in 38000 steps, reward 500.00\n",
333 | " 77 Episode in 38500 steps, reward 500.00\n",
334 | " 78 Episode in 39000 steps, reward 500.00\n",
335 | " 79 Episode in 39500 steps, reward 500.00\n",
336 | " 80 Episode in 40000 steps, reward 500.00\n",
337 | " 81 Episode in 40500 steps, reward 500.00\n",
338 | " 82 Episode in 41000 steps, reward 500.00\n",
339 | " 83 Episode in 41500 steps, reward 500.00\n",
340 | " 84 Episode in 42000 steps, reward 500.00\n",
341 | " 85 Episode in 42500 steps, reward 500.00\n",
342 | " 86 Episode in 43000 steps, reward 500.00\n",
343 | " 87 Episode in 43500 steps, reward 500.00\n",
344 | " 88 Episode in 44000 steps, reward 500.00\n",
345 | " 89 Episode in 44500 steps, reward 500.00\n",
346 | " 90 Episode in 45000 steps, reward 500.00\n",
347 | " 91 Episode in 45500 steps, reward 500.00\n",
348 | " 92 Episode in 46000 steps, reward 500.00\n",
349 | " 93 Episode in 46500 steps, reward 500.00\n",
350 | " 94 Episode in 47000 steps, reward 500.00\n",
351 | " 95 Episode in 47500 steps, reward 500.00\n",
352 | " 96 Episode in 48000 steps, reward 500.00\n",
353 | " 97 Episode in 48500 steps, reward 500.00\n",
354 | " 98 Episode in 49000 steps, reward 500.00\n",
355 | " 99 Episode in 49500 steps, reward 500.00\n",
356 | "100 Episode in 50000 steps, reward 500.00\n",
357 | "\n",
358 | "CartPole-v1 is sloved! 100 Episode in 50000 steps\n",
359 | "500.0\n"
360 | ]
361 | }
362 | ],
363 | "source": [
364 | "# play\n",
365 | "# frames = []\n",
366 | "for i in range(1, n_episodes + 1):\n",
367 | " obs = env.reset()\n",
368 | " done = False\n",
369 | " ep_reward = 0\n",
370 | " while not done:\n",
371 | "# frames.append(env.render(mode = 'rgb_array'))\n",
372 | " env.render()\n",
373 | " action = get_action_and_value(obs, target_net)\n",
374 | " _obs, reward, done, _ = env.step(action)\n",
375 | " obs = _obs\n",
376 | " total_steps += 1\n",
377 | " ep_reward += reward \n",
378 | " if done:\n",
379 | " env.render()\n",
380 | " rewards.append(ep_reward)\n",
381 | " reward_eval.append(ep_reward)\n",
382 | " print('{:3} Episode in {:5} steps, reward {:.2f}'.format(\n",
383 | " i, total_steps, ep_reward))\n",
384 | "# frames.append(env.render(mode = 'rgb_array'))\n",
385 | "# imageio.mimsave(f'{env.spec.id}.gif', frames,)\n",
386 | " \n",
387 | " if len(reward_eval) >= n_eval:\n",
388 | " if np.mean(reward_eval) >= env.spec.reward_threshold:\n",
389 | " print('\\n{} is sloved! {:3} Episode in {:3} steps'.format(\n",
390 | " env.spec.id, i, total_steps))\n",
391 | " print(f'Mean Reward: {np.mean(reward_eval).round(decimals=2)}')\n",
392 | " break\n",
393 | "env.close()"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": 10,
399 | "metadata": {
400 | "scrolled": false
401 | },
402 | "outputs": [
403 | {
404 | "data": {
405 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3AAAAE/CAYAAAAHeyFHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFRJJREFUeJzt3X+wpmV93/HPVzZgpAhRVkZYyNoRkyITCXNCaaUxSiryo2ISk5JqpMZ0Jy0zMZm2RmrbJHacaaYdtbQJCcWk2AaVYLbZGutIJWI6GdCzQvmpyQ4/XDYoCyo64I8i3/7x3CvH7a57lnN2H649r9fMzrnv67me81yHuede3nvfz3OquwMAAMDT3zPmvQAAAACWR8ABAAAMQsABAAAMQsABAAAMQsABAAAMQsABAAAMQsABwCqqqnur6sfnvQ4ADk0CDgAAYBACDoDhVdW6tfCaACDgABjSdKvir1bVrUkeraqTquqDVbWzqu6pql+a5j2zqr5WVcdO+2+rqser6tnT/r+pqndP2+dX1c1V9ZWq2l5Vv77k9TZWVVfVm6rqc0mun8Z/rqruq6qHq+ptB/k/AwBrjIADYGQ/m+T8JM9JsjnJ/0lyQpKzk/xyVZ3T3V9P8qkkL5ue87Ik9yV56ZL9G6btR5O8Ickx0/f9x1X1mt1e82VJ/kaSc6rqlCSXJ/m5JMcneW6SDav8MwLAtwk4AEZ2WXdvT3JqkvXd/fbu/mZ3353kPye5aJp3Q5KXTbc9/lCSy6b9Zyb5kSSfSJLu/nh339bdT3T3rUnelyfDb5df7+5Hu/trSV6b5EPd/Ynu/kaSf5XkiQP7IwOwlrl/H4CRbZ++fn+S46vqy0seOyzJn03bNyR5Z5LTk9yW5Lok70lyZpJt3f1wklTV30zybzMLwsOTHJHkD/fymsnsqtu397v70ap6eOU/FgDsmStwAIysp6/bk9zT3ccs+XNUd583Pf7nSX4gyU8kuaG770xyUpLz8uTtk0lydZItSU7s7qOT/E6S2strJskDSU7ctVNVz8rsNkoAOCAEHACHgk8m+er0oSbfW1WHVdWpVfUjSdLdjyXZmuSSPBlsf57kF/OdAXdUki9299er6owk/2Afr3ttkguq6qyqOjzJ2+PvVgAOIH/JADC87v5WkguSnJbkniQPJbkyydFLpt2Q5Hsyi71d+0dlev/b5J8keXtVfTXJv05yzT5e947MovDqzK7GfSnJ/Sv8cQBgr6q79z0LAACAuXMFDgAAYBACDgAAYBACDgAAYBACDgAAYBACDgAAYBDr5r2AJDn22GN748aN814GAADAXGzduvWh7l6/r3lPi4DbuHFjFhcX570MAACAuaiq+5Yzzy2UAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAgxBwAAAAg1hWwFXVvVV1W1XdUlWL09i/q6rPVNWtVbW5qo5ZMv/SqtpWVZ+tqnMO1OIBAADWkv25Avfy7j6tuxem/euSnNrdP5TkL5JcmiRVdUqSi5K8OMmrkvx2VR22imsGAABYk57yLZTd/dHufnzavTHJhmn7wiTv7+5vdPc9SbYlOWNlywQAAGC5AddJPlpVW6tq0x4e//kk/3PaPiHJ9iWP3T+NAQAAsALrljnvrO7eUVXPS3JdVX2muz+RJFX1tiSPJ/mD/XnhKQQ3JclJJ520P08FAABYk5Z1Ba67d0xfH0yyOdMtkVX1D5NckOR13d3T9B1JTlzy9A3T2O7f84ruXujuhfXr1z/lHwAAAGCt2GfAVdWRVXXUru0kr0xye1W9Kslbkry6ux9b8pQtSS6qqiOq6gVJTk7yydVfOgAAwNqynFsoj0uyuap2zb+6uz9SVduSHJHZLZVJcmN3/2J331FV1yS5M7NbKy/p7m8dmOUDAACsHfsMuO6+O8lL9jD+wu/ynHckecfKlgYAAMBST/nXCAAAAHBwCTgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBCDgAAIBBLCvgqureqrqtqm6pqsVp7Ker6o6qeqKqFnabf2lVbauqz1bVOQdi4QAAAGvNuv2Y+/LufmjJ/u1JfjLJ7y6dVFWnJLkoyYuTHJ/kf1XVi7r7WytdLAAAwFr2lG+h7O67uvuze3jowiTv7+5vdPc9SbYlOeOpvg4AAAAzyw24TvLRqtpaVZv2MfeEJNuX7N8/jX2HqtpUVYtVtbhz585lLgMAAGDtWm7AndXdpyc5N8klVfWjK33h7r6iuxe6e2H9+vUr/XYAAACHvGUFXHfvmL4+mGRzvvstkTuSnLhkf8M0BgAAwArsM+Cq6siqOmrXdpJXZvYBJnuzJclFVXVEVb0gyclJPrkaiwUAAFjLlvMplMcl2VxVu+Zf3d0fqaqfSPIfk6xP8idVdUt3n9Pdd1TVNUnuTPJ4kkt8AiUAAMDKVXfPew1ZWFjoxcXFeS8DAABgLqpqa3cv7GveU/41AgAAABxcAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQAg4AAGAQywq4qrq3qm6rqluqanEae05VXVdVfzl9/b5pvKrqsqraVlW3VtXpB/IHAAAAWCv25wrcy7v7tO5emPbfmuRj3X1yko9N+0lybpKTpz+bkly+WosFAABYy1ZyC+WFSa6atq9K8pol4+/tmRuTHFNVz1/B6wAAAJBk3TLndZKPVlUn+d3uviLJcd39wPT455McN22fkGT7kufeP409kIH8xv+4I3f+1VfmvQwAAGAVnXL8s/Nrf+/F817GU7b
cgDuru3dU1fOSXFdVn1n6YHf3FHfLVlWbMrvFMieddNL+PBUAAGBNWlbAdfeO6euDVbU5yRlJvlBVz+/uB6ZbJB+cpu9IcuKSp2+Yxnb/nlckuSJJFhYW9iv+DoaRqxwAADg07fM9cFV1ZFUdtWs7ySuT3J5kS5KLp2kXJ/njaXtLkjdMn0Z5ZpJHltxqCQAAwFO0nCtwxyXZXFW75l/d3R+pqk8luaaq3pTkviQ/M83/cJLzkmxL8liSN676qgEAANagfQZcd9+d5CV7GH84ydl7GO8kl6zK6gAAAPi2lfwaAQAAAA4iAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADAIAQcAADCIZQdcVR1WVTdX1Yem/VdU1aer6vaquqqq1k3jVVWXVdW2qrq1qk4/UIsHAABYS/bnCtybk9yVJFX1jCRXJbmou09Ncl+Si6d55yY5efqzKcnlq7ZaAACANWxZAVdVG5Kcn+TKaei5Sb7Z3X8x7V+X5Kem7QuTvLdnbkxyTFU9fxXXDAAAsCYt9wrcu5O8JckT0/5DSdZV1cK0/9okJ07bJyTZvuS5909j36GqNlXVYlUt7ty5c78XDgAAsNbsM+Cq6oIkD3b31l1j3d1JLkryrqr6ZJKvJvnW/rxwd1/R3QvdvbB+/fr9XDYAAMDas24Zc16a5NVVdV6SZyZ5dlX9t+5+fZK/kyRV9cokL5rm78iTV+OSZMM0BgAAwArs8wpcd1/a3Ru6e2NmV92u7+7XV9XzkqSqjkjyq0l+Z3rKliRvmD6N8swkj3T3Awdm+QAAAGvHcq7A7c0/n26vfEaSy7v7+mn8w0nOS7ItyWNJ3riyJQIAAJAkNXs723wtLCz04uLivJcBAAAwF1W1tbsX9jVvf34PHAAAAHMk4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAYh4AAAAAax7ICrqsOq6uaq+tC0f3ZVfbqqbqmq/11VL5zGj6iqD1TVtqq6qao2HpilAwAArC37cwXuzUnuWrJ/eZLXdfdpSa5O8i+n8Tcl+VJ3vzDJu5L85mosFAAAYK1bVsBV1YYk5ye5cslwJ3n2tH10kr+ati9MctW0fW2Ss6uqVr5UAACAtW3dMue9O8lbkhy1ZOwXkny4qr6W5CtJzpzGT0iyPUm6+/GqeiTJc5M8tCorBgAAWKP2eQWuqi5I8mB3b93toV9Jcl53b0jy+0neuT8vXFWbqmqxqhZ37ty5P08FAABYk5ZzC+VLk7y6qu5N8v4kr6iqP0nyku6+aZrzgSR/e9rekeTEJKmqdZndXvnw7t+0u6/o7oXuXli/fv3KfgoAAIA1YJ8B192XdveG7t6Y5KIk12f2Prejq+pF07S/myc/4GRLkoun7dcmub67e1VXDQAAsAYt9z1w32F6b9s/SvLBqnoiyZeS/Pz08HuS/Neq2pbki5lFHwAAACu0XwHX3R9P8vFpe3OSzXuY8/UkP70KawMAAGCJ/fk9cAAAAMyRgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABhEdfe815Cq2pnkvnmvYw+OTfLQvBfBIc9xxsHgOONAc4xxMDjOOBjmdZx9f3ev39ekp0XAPV1V1WJ3L8x7HRzaHGccDI4zDjTHGAeD44yD4el+nLmFEgAAYBACDgAAYBAC7ru7Yt4LYE1wnHEwOM440BxjHAyOMw6Gp/Vx5j1wAAAAg3AFDgAAYBACbi+q6lVV9dmq2lZVb533ehhfVZ1YVX9aVXdW1R1V9eZp/DlVdV1V/eX09fvmvVbGV1WHVdXNVfWhaf8FVXXTdE77QFUdPu81MraqOqaqrq2qz1TVXVX1t5zPWE1V9SvT35e3V9X7quqZzmWshqr6vap6sKpuXzK2x/NXzVw2HXO3VtXp81v5jIDbg6o6LMlvJTk3ySlJfraqTpnvqjgEPJ7kn3b3KUnOTHLJdFy9NcnHuvvkJB+b9mGl3pzkriX7v5nkXd39wiRfSvKmuayKQ8l/SPKR7v7BJC/J7HhzPmNVVNUJSX4pyUJ3n5rksCQXxbmM1fFfkrxqt7G9nb/OTXLy9GdTkssP0hr3SsDt2RlJtnX33d39zSTvT3LhnNfE4Lr7ge7+9LT91cz+Z+eEzI6tq6ZpVyV5zXxWyKGiqjYkOT/JldN+JXlFkmunKY4zVqSqjk7yo0nekyTd/c3u/nKcz1hd65J8b1WtS/KsJA/EuYxV0N2fSPLF3Yb3dv66MMl7e+bGJMdU1fMPzkr3TMDt2QlJti/Zv38ag1VRVRuT/HCSm5Ic190PTA99Pslxc1oWh453J3lLkiem/ecm+XJ3Pz7tO6exUi9IsjPJ70+36l5ZVUfG+YxV0t07kvz7JJ/LLNweSbI1zmUcOHs7fz3tukDAwUFWVX8tyQeT/HJ3f2XpYz37WFgfDctTVlUXJHmwu7fOey0c0tYlOT3J5d39w0kezW63SzqfsRLT+48uzOwfC45PcmT+/1ve4IB4up+/BNye7Uhy4pL9DdMYrEhVfU9m8fYH3f1H0/AXdl2Kn74+OK/1cUh4aZJXV9W9md3+/YrM3qt0zHQbUuKcxsrdn+T+7r5p2r82s6BzPmO1/HiSe7p7Z3f/3yR/lNn5zbmMA2Vv56+nXRcIuD37VJKTp086OjyzN81umfOaGNz0PqT3JLmru9+55KEtSS6eti9O8scHe20cOrr70u7e0N0bMzt3Xd/dr0vyp0leO01znLEi3f35JNur6gemobOT3BnnM1bP55KcWVXPmv7+3HWMOZdxoOzt/LUlyRumT6M8M8kjS261nA
u/yHsvquq8zN5HcliS3+vud8x5SQyuqs5K8mdJbsuT7036F5m9D+6aJCcluS/Jz3T37m+shf1WVT+W5J919wVV9dczuyL3nCQ3J3l9d39jnutjbFV1WmYflHN4kruTvDGzfxh2PmNVVNVvJPn7mX2K881JfiGz9x45l7EiVfW+JD+W5NgkX0jya0n+e/Zw/pr+AeE/ZXYL72NJ3tjdi/NY9y4CDgAAYBBuoQQAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABiEgAMAABjE/wMi+6Usv0OhqAAAAABJRU5ErkJggg==\n",
406 | "text/plain": [
407 | ""
408 | ]
409 | },
410 | "metadata": {
411 | "needs_background": "light"
412 | },
413 | "output_type": "display_data"
414 | }
415 | ],
416 | "source": [
417 | "plt.figure(figsize=(15, 5))\n",
418 | "plt.title('reward')\n",
419 | "plt.plot(rewards)\n",
420 | "plt.show()"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": 11,
426 | "metadata": {},
427 | "outputs": [
428 | {
429 | "data": {
430 | "text/plain": [
431 | "[('CartPole-v0', 412, 1),\n",
432 | " ('CartPole-v1', 452, 0.05),\n",
433 | " ('MountainCar-v0', 193, 0.1),\n",
434 | " ('LunarLander-v2', 260, 0.1)]"
435 | ]
436 | },
437 | "execution_count": 11,
438 | "metadata": {},
439 | "output_type": "execute_result"
440 | }
441 | ],
442 | "source": [
443 | "[\n",
444 | " ('CartPole-v0', 412, 1),\n",
445 | " ('CartPole-v1', 452, 0.05),\n",
446 | " ('MountainCar-v0', 193, 0.1),\n",
447 | " ('LunarLander-v2', 260, 0.1)\n",
448 | "]"
449 | ]
450 | }
451 | ],
452 | "metadata": {
453 | "colab": {
454 | "collapsed_sections": [],
455 | "name": "C51_tensorflow.ipynb",
456 | "provenance": [],
457 | "version": "0.3.2"
458 | },
459 | "kernelspec": {
460 | "display_name": "Python 3",
461 | "language": "python",
462 | "name": "python3"
463 | },
464 | "language_info": {
465 | "codemirror_mode": {
466 | "name": "ipython",
467 | "version": 3
468 | },
469 | "file_extension": ".py",
470 | "mimetype": "text/x-python",
471 | "name": "python",
472 | "nbconvert_exporter": "python",
473 | "pygments_lexer": "ipython3",
474 | "version": "3.7.0"
475 | }
476 | },
477 | "nbformat": 4,
478 | "nbformat_minor": 1
479 | }
480 |
--------------------------------------------------------------------------------