├── common
│   ├── __init__.py
│   └── multiprocessing_env.py
├── README.md
├── 1.actor-critic.ipynb
├── 2.gae.ipynb
├── 3.ppo.ipynb
├── 5.ddpg.ipynb
└── 8.gail.ipynb
/common/__init__.py:
--------------------------------------------------------------------------------
1 | from . import multiprocessing_env
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RL-Adventure-2: Policy Gradients
2 |
3 |
4 |
5 |
6 | PyTorch tutorials on: actor critic / proximal policy optimization / acer / ddpg / twin delayed ddpg (td3) / soft actor critic / generative adversarial imitation learning / hindsight experience replay
7 |
8 | The deep reinforcement learning community has made several improvements to the [policy gradient](http://rll.berkeley.edu/deeprlcourse/f17docs/lecture_4_policy_gradient.pdf) algorithms. This tutorial presents the latest extensions in the following order:
9 |
10 | 1. Advantage Actor Critic (A2C)
11 | - [actor-critic.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/1.actor-critic.ipynb)
12 | - [A3C Paper](https://arxiv.org/pdf/1602.01783.pdf)
13 | - [OpenAI blog](https://blog.openai.com/baselines-acktr-a2c/#a2canda3c)
14 | 2. High-Dimensional Continuous Control Using Generalized Advantage Estimation
15 | - [gae.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/2.gae.ipynb)
16 | - [GAE Paper](https://arxiv.org/abs/1506.02438)
17 | 3. Proximal Policy Optimization Algorithms
18 | - [ppo.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/3.ppo.ipynb)
19 | - [PPO Paper](https://arxiv.org/abs/1707.06347)
20 | - [OpenAI blog](https://blog.openai.com/openai-baselines-ppo/)
21 | 4. Sample Efficient Actor-Critic with Experience Replay
22 | - [acer.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/4.acer.ipynb)
23 | - [ACER Paper](https://arxiv.org/abs/1611.01224)
24 | 5. Continuous control with deep reinforcement learning
25 | - [ddpg.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/5.ddpg.ipynb)
26 | - [DDPG Paper](https://arxiv.org/abs/1509.02971)
27 | 6. Addressing Function Approximation Error in Actor-Critic Methods
28 | - [td3.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/6.td3.ipynb)
29 | - [TD3 (Twin Delayed DDPG) Paper](https://arxiv.org/abs/1802.09477)
30 | 7. Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor
31 | - [soft actor-critic.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/7.soft%20actor-critic.ipynb)
32 | - [Soft Actor-Critic Paper](https://arxiv.org/abs/1801.01290)
33 | 8. Generative Adversarial Imitation Learning
34 | - [gail.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/8.gail.ipynb)
35 | - [GAIL Paper](https://arxiv.org/abs/1606.03476)
36 | 9. Hindsight Experience Replay
37 | - [her.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/9.her.ipynb)
38 | - [HER Paper](https://arxiv.org/abs/1707.01495)
39 | - [OpenAI Blog](https://blog.openai.com/ingredients-for-robotics-research/#understandingher)
40 |
41 | # If you get stuck…
42 | - Remember you are not stuck unless you have spent more than a week on a single algorithm. It is perfectly normal if you do not have all the required knowledge of mathematics and CS.
43 | - Carefully go through the paper. Try to understand what problem the authors are solving. Get a high-level idea of the approach, then read the code (skipping the proofs), and afterwards go over the mathematical details and proofs.
44 |
45 | # RL Algorithms
46 | Deep Q Learning tutorial: [DQN Adventure: from Zero to State of the Art](https://github.com/higgsfield/RL-Adventure)
47 |
48 | Awesome RL libs: rlkit [@vitchyr](https://github.com/vitchyr), pytorch-a2c-ppo-acktr [@ikostrikov](https://github.com/ikostrikov),
49 | ACER [@Kaixhin](https://github.com/Kaixhin)
50 |
51 | # Best RL courses
52 | - Berkeley deep RL [link](http://rll.berkeley.edu/deeprlcourse/)
53 | - Deep RL Bootcamp [link](https://sites.google.com/view/deep-rl-bootcamp/lectures)
54 | - David Silver's course [link](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html)
55 | - Practical RL [link](https://github.com/yandexdataschool/Practical_RL)
56 |
--------------------------------------------------------------------------------
/common/multiprocessing_env.py:
--------------------------------------------------------------------------------
1 | # This code is from OpenAI Baselines:
2 | # https://github.com/openai/baselines/tree/master/baselines/common/vec_env
3 |
4 | import numpy as np
5 | from multiprocessing import Process, Pipe
6 |
7 | def worker(remote, parent_remote, env_fn_wrapper):
8 |     parent_remote.close()
9 |     env = env_fn_wrapper.x()
10 |     while True:
11 |         cmd, data = remote.recv()
12 |         if cmd == 'step':
13 |             ob, reward, done, info = env.step(data)
14 |             if done:
15 |                 ob = env.reset()  # auto-reset: this ob is the first of the next episode
16 |             remote.send((ob, reward, done, info))
17 |         elif cmd == 'reset':
18 |             ob = env.reset()
19 |             remote.send(ob)
20 |         elif cmd == 'reset_task':
21 |             ob = env.reset_task()
22 |             remote.send(ob)
23 |         elif cmd == 'close':
24 |             remote.close()
25 |             break
26 |         elif cmd == 'get_spaces':
27 |             remote.send((env.observation_space, env.action_space))
28 |         else:
29 |             raise NotImplementedError
30 |
31 | class VecEnv(object):
32 |     """
33 |     An abstract asynchronous, vectorized environment.
34 |     """
35 |     def __init__(self, num_envs, observation_space, action_space):
36 |         self.num_envs = num_envs
37 |         self.observation_space = observation_space
38 |         self.action_space = action_space
39 |
40 |     def reset(self):
41 |         """
42 |         Reset all the environments and return an array of
43 |         observations, or a tuple of observation arrays.
44 |         If step_async is still doing work, that work will
45 |         be cancelled and step_wait() should not be called
46 |         until step_async() is invoked again.
47 |         """
48 |         pass
49 |
50 |     def step_async(self, actions):
51 |         """
52 |         Tell all the environments to start taking a step
53 |         with the given actions.
54 |         Call step_wait() to get the results of the step.
55 |         You should not call this if a step_async run is
56 |         already pending.
57 |         """
58 |         pass
59 |
60 |     def step_wait(self):
61 |         """
62 |         Wait for the step taken with step_async().
63 |         Returns (obs, rews, dones, infos):
64 |          - obs: an array of observations, or a tuple of
65 |                 arrays of observations.
66 |          - rews: an array of rewards
67 |          - dones: an array of "episode done" booleans
68 |          - infos: a sequence of info objects
69 |         """
70 |         pass
71 |
72 |     def close(self):
73 |         """
74 |         Clean up the environments' resources.
75 |         """
76 |         pass
77 |
78 |     def step(self, actions):
79 |         self.step_async(actions)
80 |         return self.step_wait()
81 |
82 |
83 | class CloudpickleWrapper(object):
84 |     """
85 |     Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
86 |     """
87 |     def __init__(self, x):
88 |         self.x = x
89 |     def __getstate__(self):
90 |         import cloudpickle
91 |         return cloudpickle.dumps(self.x)
92 |     def __setstate__(self, ob):
93 |         import pickle
94 |         self.x = pickle.loads(ob)
95 |
96 |
97 | class SubprocVecEnv(VecEnv):
98 |     def __init__(self, env_fns, spaces=None):
99 |         """
100 |         env_fns: list of callables, each returning a gym environment to run in its own subprocess
101 |         """
102 |         self.waiting = False
103 |         self.closed = False
104 |         nenvs = len(env_fns)
105 |         self.nenvs = nenvs
106 |         self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
107 |         self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
108 |             for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
109 |         for p in self.ps:
110 |             p.daemon = True # if the main process crashes, we should not cause things to hang
111 |             p.start()
112 |         for remote in self.work_remotes:
113 |             remote.close()
114 |
115 |         self.remotes[0].send(('get_spaces', None))
116 |         observation_space, action_space = self.remotes[0].recv()
117 |         VecEnv.__init__(self, len(env_fns), observation_space, action_space)
118 |
119 |     def step_async(self, actions):
120 |         for remote, action in zip(self.remotes, actions):
121 |             remote.send(('step', action))
122 |         self.waiting = True
123 |
124 |     def step_wait(self):
125 |         results = [remote.recv() for remote in self.remotes]
126 |         self.waiting = False
127 |         obs, rews, dones, infos = zip(*results)
128 |         return np.stack(obs), np.stack(rews), np.stack(dones), infos
129 |
130 |     def reset(self):
131 |         for remote in self.remotes:
132 |             remote.send(('reset', None))
133 |         return np.stack([remote.recv() for remote in self.remotes])
134 |
135 |     def reset_task(self):
136 |         for remote in self.remotes:
137 |             remote.send(('reset_task', None))
138 |         return np.stack([remote.recv() for remote in self.remotes])
139 |
140 |     def close(self):
141 |         if self.closed:
142 |             return
143 |         if self.waiting:
144 |             for remote in self.remotes:
145 |                 remote.recv()
146 |         for remote in self.remotes:
147 |             remote.send(('close', None))
148 |         for p in self.ps:
149 |             p.join()
150 |         self.closed = True
151 |
152 |     def __len__(self):
153 |         return self.nenvs
--------------------------------------------------------------------------------
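The `worker` loop above defines a small command protocol (`step`, `reset`, `close`, ...) and auto-resets an environment as soon as it reports `done`, so the observation returned together with `done=True` already belongs to the next episode. The sketch below is a hypothetical in-process equivalent that keeps the same `reset`/`step` contract, which can be handy for debugging without subprocesses. `SequentialVecEnv` and the toy `CountdownEnv` are illustrative names, not part of this repo or of gym; OpenAI Baselines ships a comparable `DummyVecEnv`, but this is not that class.

```python
import numpy as np


class CountdownEnv:
    """Hypothetical toy env: an episode ends after `length` steps.

    The observation is the number of steps left; the reward equals the action.
    """

    def __init__(self, length):
        self.length = length
        self.left = length

    def reset(self):
        self.left = self.length
        return np.array([self.left], dtype=np.float32)

    def step(self, action):
        self.left -= 1
        done = self.left == 0
        return np.array([self.left], dtype=np.float32), float(action), done, {}


class SequentialVecEnv:
    """In-process stand-in for SubprocVecEnv with the same reset/step contract."""

    def __init__(self, env_fns):
        # Like SubprocVecEnv, take thunks and build each environment once.
        self.envs = [fn() for fn in env_fns]
        self.num_envs = len(self.envs)

    def reset(self):
        return np.stack([env.reset() for env in self.envs])

    def step(self, actions):
        results = []
        for env, action in zip(self.envs, actions):
            ob, reward, done, info = env.step(action)
            if done:
                # Mirror worker(): auto-reset, so `ob` starts the next episode.
                ob = env.reset()
            results.append((ob, reward, done, info))
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos


def make_env(length):
    # Thunk factory, shaped like the notebooks' make_env().
    def _thunk():
        return CountdownEnv(length)
    return _thunk


envs = SequentialVecEnv([make_env(2), make_env(3)])
obs = envs.reset()
obs, rews, dones, infos = envs.step([1, 1])
obs, rews, dones, infos = envs.step([1, 1])
print(obs[:, 0], dones)
```

After the second step the first environment reports `done=True` and already hands back a fresh observation (`2.0`), which is the behaviour the training loops assume when they set `state = next_state` and cut the bootstrap with `masks`.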
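`CloudpickleWrapper` exists because the notebooks hand nested functions (the `_thunk` closures returned by `make_env()`) to `Process`. When multiprocessing uses the `spawn` start method (the default on Windows, and on macOS since Python 3.8), process arguments are pickled, and the standard `pickle` module cannot serialize a function defined inside another function. The stdlib-only sketch below reproduces that failure; the `make_env(env_name)` signature and the returned string are illustrative stand-ins, not the notebooks' code.

```python
import pickle


def make_env(env_name):
    # Hypothetical stand-in for the notebooks' make_env(): returns a closure.
    def _thunk():
        return "gym.make(%r) would be called here" % env_name
    return _thunk


thunk = make_env("CartPole-v0")

try:
    pickle.dumps(thunk)
    stdlib_pickle_ok = True
except Exception as exc:
    # The exception type depends on the Python version (AttributeError or PicklingError).
    stdlib_pickle_ok = False
    print("pickle failed:", type(exc).__name__)

# The closure itself still works in-process.
print(thunk())
```

`cloudpickle` serializes such closures by value, which is why `__getstate__` calls `cloudpickle.dumps` while `__setstate__` can load the bytes with plain `pickle.loads`.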
/1.actor-critic.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import math\n",
10 | "import random\n",
11 | "\n",
12 | "import gym\n",
13 | "import numpy as np\n",
14 | "\n",
15 | "import torch\n",
16 | "import torch.nn as nn\n",
17 | "import torch.optim as optim\n",
18 | "import torch.nn.functional as F\n",
19 | "from torch.distributions import Categorical"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from IPython.display import clear_output\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "%matplotlib inline"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "Use CUDA "
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "use_cuda = torch.cuda.is_available()\n",
47 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "Create Environments "
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 4,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "from common.multiprocessing_env import SubprocVecEnv\n",
64 | "\n",
65 | "num_envs = 16\n",
66 | "env_name = \"CartPole-v0\"\n",
67 | "\n",
68 | "def make_env():\n",
69 | "    def _thunk():\n",
70 | "        env = gym.make(env_name)\n",
71 | "        return env\n",
72 | "\n",
73 | "    return _thunk\n",
74 | "\n",
75 | "envs = [make_env() for i in range(num_envs)]\n",
76 | "envs = SubprocVecEnv(envs)\n",
77 | "\n",
78 | "env = gym.make(env_name)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "Neural Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 19,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "class ActorCritic(nn.Module):\n",
95 | "    def __init__(self, num_inputs, num_outputs, hidden_size, std=0.0):\n",
96 | "        super(ActorCritic, self).__init__()\n",
97 | "\n",
98 | "        self.critic = nn.Sequential(\n",
99 | "            nn.Linear(num_inputs, hidden_size),\n",
100 | "            nn.ReLU(),\n",
101 | "            nn.Linear(hidden_size, 1)\n",
102 | "        )\n",
103 | "\n",
104 | "        self.actor = nn.Sequential(\n",
105 | "            nn.Linear(num_inputs, hidden_size),\n",
106 | "            nn.ReLU(),\n",
107 | "            nn.Linear(hidden_size, num_outputs),\n",
108 | "            nn.Softmax(dim=1),\n",
109 | "        )\n",
110 | "\n",
111 | "    def forward(self, x):\n",
112 | "        value = self.critic(x)\n",
113 | "        probs = self.actor(x)\n",
114 | "        dist = Categorical(probs)\n",
115 | "        return dist, value"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 20,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "def plot(frame_idx, rewards):\n",
125 | "    clear_output(True)\n",
126 | "    plt.figure(figsize=(20,5))\n",
127 | "    plt.subplot(131)\n",
128 | "    plt.title('frame %s. reward: %s' % (frame_idx, rewards[-1]))\n",
129 | "    plt.plot(rewards)\n",
130 | "    plt.show()\n",
131 | "\n",
132 | "def test_env(vis=False):\n",
133 | "    state = env.reset()\n",
134 | "    if vis: env.render()\n",
135 | "    done = False\n",
136 | "    total_reward = 0\n",
137 | "    while not done:\n",
138 | "        state = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
139 | "        dist, _ = model(state)\n",
140 | "        next_state, reward, done, _ = env.step(dist.sample().cpu().numpy()[0])\n",
141 | "        state = next_state\n",
142 | "        if vis: env.render()\n",
143 | "        total_reward += reward\n",
144 | "    return total_reward"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "A2C: Synchronous Advantage Actor Critic \n",
152 | "\n",
153 | "The Asynchronous Advantage Actor Critic method (A3C) has been very influential since the paper was published. The algorithm combines a few key ideas:\n",
154 | "\n",
155 | "\n",
156 | "1. An updating scheme that operates on fixed-length segments of experience (say, 20 timesteps) and uses these segments to compute estimators of the returns and advantage function.\n",
157 | "2. Architectures that share layers between the policy and value function.\n",
158 | "3. Asynchronous updates.\n",
159 | "\n",
160 | "\n",
161 | "After reading the paper, AI researchers wondered whether the asynchrony led to improved performance (e.g. “perhaps the added noise would provide some regularization or exploration?”), or if it was just an implementation detail that allowed for faster training with a CPU-based implementation.\n",
162 | "\n",
163 | "As an alternative to the asynchronous implementation, researchers found you can write a synchronous, deterministic implementation that waits for each actor to finish its segment of experience before performing an update, averaging over all of the actors. One advantage of this method is that it can use GPUs more effectively, since GPUs perform best with large batch sizes. This algorithm is naturally called A2C, short for advantage actor critic. (This term has been used in several papers.)"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 21,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "def compute_returns(next_value, rewards, masks, gamma=0.99):\n",
173 | "    R = next_value\n",
174 | "    returns = []\n",
175 | "    for step in reversed(range(len(rewards))):\n",
176 | "        R = rewards[step] + gamma * R * masks[step]\n",
177 | "        returns.insert(0, R)\n",
178 | "    return returns"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 22,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "num_inputs = envs.observation_space.shape[0]\n",
188 | "num_outputs = envs.action_space.n\n",
189 | "\n",
190 | "#Hyper params:\n",
191 | "hidden_size = 256\n",
192 | "lr = 3e-4\n",
193 | "num_steps = 5\n",
194 | "\n",
195 | "model = ActorCritic(num_inputs, num_outputs, hidden_size).to(device)\n",
196 | "optimizer = optim.Adam(model.parameters(), lr=lr)"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 23,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "max_frames = 20000\n",
206 | "frame_idx = 0\n",
207 | "test_rewards = []"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 17,
213 | "metadata": {},
214 | "outputs": [
215 | {
216 | "data": {
217 | "image/png": "\n",
218 | "text/plain": [
219 | ""
220 | ]
221 | },
222 | "metadata": {},
223 | "output_type": "display_data"
224 | }
225 | ],
226 | "source": [
227 | "state = envs.reset()\n",
228 | "\n",
229 | "while frame_idx < max_frames:\n",
230 | "\n",
231 | "    log_probs = []\n",
232 | "    values = []\n",
233 | "    rewards = []\n",
234 | "    masks = []\n",
235 | "    entropy = 0\n",
236 | "\n",
237 | "    for _ in range(num_steps):\n",
238 | "        state = torch.FloatTensor(state).to(device)\n",
239 | "        dist, value = model(state)\n",
240 | "\n",
241 | "        action = dist.sample()\n",
242 | "        next_state, reward, done, _ = envs.step(action.cpu().numpy())\n",
243 | "\n",
244 | "        log_prob = dist.log_prob(action)\n",
245 | "        entropy += dist.entropy().mean()\n",
246 | "\n",
247 | "        log_probs.append(log_prob)\n",
248 | "        values.append(value)\n",
249 | "        rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))\n",
250 | "        masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))\n",
251 | "\n",
252 | "        state = next_state\n",
253 | "        frame_idx += 1\n",
254 | "\n",
255 | "        if frame_idx % 1000 == 0:\n",
256 | "            test_rewards.append(np.mean([test_env() for _ in range(10)]))\n",
257 | "            plot(frame_idx, test_rewards)\n",
258 | "\n",
259 | "    next_state = torch.FloatTensor(next_state).to(device)\n",
260 | "    _, next_value = model(next_state)\n",
261 | "    returns = compute_returns(next_value, rewards, masks)\n",
262 | "\n",
263 | "    log_probs = torch.cat(log_probs).unsqueeze(1)  # (N,) -> (N, 1), same shape as advantage\n",
264 | "    returns = torch.cat(returns).detach()\n",
265 | "    values = torch.cat(values)\n",
266 | "\n",
267 | "    advantage = returns - values\n",
268 | "\n",
269 | "    actor_loss = -(log_probs * advantage.detach()).mean()\n",
270 | "    critic_loss = advantage.pow(2).mean()\n",
271 | "\n",
272 | "    loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy\n",
273 | "\n",
274 | "    optimizer.zero_grad()\n",
275 | "    loss.backward()\n",
276 | "    optimizer.step()"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 26,
282 | "metadata": {},
283 | "outputs": [
284 | {
285 | "data": {
286 | "text/plain": [
287 | "200.0"
288 | ]
289 | },
290 | "execution_count": 26,
291 | "metadata": {},
292 | "output_type": "execute_result"
293 | }
294 | ],
295 | "source": [
296 | "test_env(True)"
297 | ]
298 | }
299 | ],
300 | "metadata": {
301 | "kernelspec": {
302 | "display_name": "Python [conda env:pytorch4]",
303 | "language": "python",
304 | "name": "conda-env-pytorch4-py"
305 | },
306 | "language_info": {
307 | "codemirror_mode": {
308 | "name": "ipython",
309 | "version": 3
310 | },
311 | "file_extension": ".py",
312 | "mimetype": "text/x-python",
313 | "name": "python",
314 | "nbconvert_exporter": "python",
315 | "pygments_lexer": "ipython3",
316 | "version": "3.5.5"
317 | }
318 | },
319 | "nbformat": 4,
320 | "nbformat_minor": 2
321 | }
322 |
--------------------------------------------------------------------------------
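`compute_returns` in the A2C notebook walks the rollout backwards, bootstrapping from the critic's estimate `next_value` for the state after the last step and multiplying by `masks[step] = 1 - done` so that a return never leaks across an episode boundary. The pure-Python copy below uses scalars instead of `(num_envs, 1)` tensors and a made-up three-step rollout, so the recursion can be checked by hand: R_2 = 1 + 0.5·10·1 = 6, R_1 = 1 + 0.5·6·0 = 1, R_0 = 1 + 0.5·1·1 = 1.5.

```python
def compute_returns(next_value, rewards, masks, gamma=0.99):
    """Bootstrapped n-step discounted returns (same recursion as the notebook).

    masks[t] = 1 - done[t]; a 0 stops the bootstrap at an episode end.
    """
    R = next_value
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R * masks[step]
        returns.insert(0, R)
    return returns


# Made-up rollout for one environment; the episode ends at step 1 (mask 0).
rewards = [1.0, 1.0, 1.0]
masks = [1.0, 0.0, 1.0]
returns = compute_returns(next_value=10.0, rewards=rewards, masks=masks, gamma=0.5)
print(returns)
```

Because step 1 ends an episode, the return at step 0 only sees the reward of step 1, not the bootstrapped value from the following episode.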
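In the A2C loop, `Categorical.log_prob(action)` returns one value per environment, shape `(num_envs,)`, whereas the critic returns `(num_envs, 1)`; after `torch.cat` these become `(N,)` and `(N, 1)`. PyTorch broadcasts like NumPy, so multiplying them directly produces an `(N, N)` matrix that pairs every log-probability with every advantage, which is why the log-probabilities need a column shape before the actor loss is formed. The NumPy sketch below uses made-up numbers to show the difference and how the three loss terms combine; the entropy value is an arbitrary placeholder.

```python
import numpy as np

# Made-up batch of N = 4 (num_steps * num_envs) samples.
log_probs = np.log(np.array([0.5, 0.25, 0.5, 0.25]))  # like Categorical.log_prob: shape (N,)
advantage = np.array([[1.0], [-1.0], [2.0], [0.0]])    # returns - values: shape (N, 1)

outer = log_probs * advantage             # broadcasts to (N, N): wrong pairing
paired = log_probs[:, None] * advantage   # (N, 1): each log-prob with its own advantage

actor_loss = -paired.mean()
critic_loss = (advantage ** 2).mean()
entropy = 0.6  # placeholder for the accumulated dist.entropy().mean() terms

# Same weighting as the notebook: 0.5 on the critic, 0.001 entropy bonus.
loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy
print(outer.shape, paired.shape, round(float(loss), 4))
```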
/2.gae.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 5,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import math\n",
10 | "import random\n",
11 | "\n",
12 | "import gym\n",
13 | "import numpy as np\n",
14 | "\n",
15 | "import torch\n",
16 | "import torch.nn as nn\n",
17 | "import torch.optim as optim\n",
18 | "import torch.nn.functional as F\n",
19 | "from torch.distributions import Normal"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 6,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from IPython.display import clear_output\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "%matplotlib inline"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "Use CUDA "
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 7,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "use_cuda = torch.cuda.is_available()\n",
47 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "Create Environments "
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 22,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "from common.multiprocessing_env import SubprocVecEnv\n",
64 | "\n",
65 | "num_envs = 16\n",
66 | "env_name = \"Pendulum-v0\"\n",
67 | "\n",
68 | "def make_env():\n",
69 | "    def _thunk():\n",
70 | "        env = gym.make(env_name)\n",
71 | "        return env\n",
72 | "\n",
73 | "    return _thunk\n",
74 | "\n",
75 | "envs = [make_env() for i in range(num_envs)]\n",
76 | "envs = SubprocVecEnv(envs)\n",
77 | "\n",
78 | "env = gym.make(env_name)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "Neural Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 10,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "def init_weights(m):\n",
95 | "    if isinstance(m, nn.Linear):\n",
96 | "        nn.init.normal_(m.weight, mean=0., std=0.1)\n",
97 | "        nn.init.constant_(m.bias, 0.1)\n",
98 | "\n",
99 | "\n",
100 | "class ActorCritic(nn.Module):\n",
101 | "    def __init__(self, num_inputs, num_outputs, hidden_size, std=0.0):\n",
102 | "        super(ActorCritic, self).__init__()\n",
103 | "\n",
104 | "        self.critic = nn.Sequential(\n",
105 | "            nn.Linear(num_inputs, hidden_size),\n",
106 | "            nn.ReLU(),\n",
107 | "            nn.Linear(hidden_size, 1)\n",
108 | "        )\n",
109 | "\n",
110 | "        self.actor = nn.Sequential(\n",
111 | "            nn.Linear(num_inputs, hidden_size),\n",
112 | "            nn.ReLU(),\n",
113 | "            nn.Linear(hidden_size, num_outputs),\n",
114 | "        )\n",
115 | "        self.log_std = nn.Parameter(torch.ones(1, num_outputs) * std)\n",
116 | "\n",
117 | "        self.apply(init_weights)\n",
118 | "\n",
119 | "    def forward(self, x):\n",
120 | "        value = self.critic(x)\n",
121 | "        mu = self.actor(x)\n",
122 | "        std = self.log_std.exp().expand_as(mu)\n",
123 | "        dist = Normal(mu, std)\n",
124 | "        return dist, value"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 11,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def plot(frame_idx, rewards):\n",
134 | "    clear_output(True)\n",
135 | "    plt.figure(figsize=(20,5))\n",
136 | "    plt.subplot(131)\n",
137 | "    plt.title('frame %s. reward: %s' % (frame_idx, rewards[-1]))\n",
138 | "    plt.plot(rewards)\n",
139 | "    plt.show()\n",
140 | "\n",
141 | "def test_env(vis=False):\n",
142 | "    state = env.reset()\n",
143 | "    if vis: env.render()\n",
144 | "    done = False\n",
145 | "    total_reward = 0\n",
146 | "    while not done:\n",
147 | "        state = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
148 | "        dist, _ = model(state)\n",
149 | "        next_state, reward, done, _ = env.step(dist.sample().cpu().numpy()[0])\n",
150 | "        state = next_state\n",
151 | "        if vis: env.render()\n",
152 | "        total_reward += reward\n",
153 | "    return total_reward"
154 | ]
155 | },
156 | {
157 | "cell_type": "markdown",
158 | "metadata": {},
159 | "source": [
160 | "High-Dimensional Continuous Control Using Generalized Advantage Estimation \n",
161 | ""
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": 17,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):\n",
171 | "    values = values + [next_value]\n",
172 | "    gae = 0\n",
173 | "    returns = []\n",
174 | "    for step in reversed(range(len(rewards))):\n",
175 | "        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]\n",
176 | "        gae = delta + gamma * tau * masks[step] * gae\n",
177 | "        returns.insert(0, gae + values[step])\n",
178 | "    return returns"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 29,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "num_inputs = envs.observation_space.shape[0]\n",
188 | "num_outputs = envs.action_space.shape[0]\n",
189 | "\n",
190 | "#Hyper params:\n",
191 | "hidden_size = 256\n",
192 | "lr = 3e-4\n",
193 | "num_steps = 20\n",
194 | "\n",
195 | "model = ActorCritic(num_inputs, num_outputs, hidden_size).to(device)\n",
196 | "optimizer = optim.Adam(model.parameters(), lr=lr)"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 30,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "max_frames = 100000\n",
206 | "frame_idx = 0\n",
207 | "test_rewards = []"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 31,
213 | "metadata": {},
214 | "outputs": [
215 | {
216 | "data": {
217 | "image/png": "\n",
218 | "text/plain": [
219 | ""
220 | ]
221 | },
222 | "metadata": {},
223 | "output_type": "display_data"
224 | }
225 | ],
226 | "source": [
227 | "state = envs.reset()\n",
228 | "\n",
229 | "while frame_idx < max_frames:\n",
230 | "\n",
231 | "    log_probs = []\n",
232 | "    values = []\n",
233 | "    rewards = []\n",
234 | "    masks = []\n",
235 | "    entropy = 0\n",
236 | "\n",
237 | "    for _ in range(num_steps):\n",
238 | "        state = torch.FloatTensor(state).to(device)\n",
239 | "        dist, value = model(state)\n",
240 | "\n",
241 | "        action = dist.sample()\n",
242 | "        next_state, reward, done, _ = envs.step(action.cpu().numpy())\n",
243 | "\n",
244 | "        log_prob = dist.log_prob(action)\n",
245 | "        entropy += dist.entropy().mean()\n",
246 | "\n",
247 | "        log_probs.append(log_prob)\n",
248 | "        values.append(value)\n",
249 | "        rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))\n",
250 | "        masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))\n",
251 | "\n",
252 | "        state = next_state\n",
253 | "        frame_idx += 1\n",
254 | "\n",
255 | "        if frame_idx % 1000 == 0:\n",
256 | "            test_rewards.append(np.mean([test_env() for _ in range(10)]))\n",
257 | "            plot(frame_idx, test_rewards)\n",
258 | "\n",
259 | "    next_state = torch.FloatTensor(next_state).to(device)\n",
260 | "    _, next_value = model(next_state)\n",
261 | "    returns = compute_gae(next_value, rewards, masks, values)\n",
262 | "\n",
263 | "    log_probs = torch.cat(log_probs)\n",
264 | "    returns = torch.cat(returns).detach()\n",
265 | "    values = torch.cat(values)\n",
266 | "\n",
267 | "    advantage = returns - values\n",
268 | "\n",
269 | "    actor_loss = -(log_probs * advantage.detach()).mean()\n",
270 | "    critic_loss = advantage.pow(2).mean()\n",
271 | "\n",
272 | "    loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy\n",
273 | "\n",
274 | "    optimizer.zero_grad()\n",
275 | "    loss.backward()\n",
276 | "    optimizer.step()"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 32,
282 | "metadata": {},
283 | "outputs": [
284 | {
285 | "data": {
286 | "text/plain": [
287 | "-283.0576102217745"
288 | ]
289 | },
290 | "execution_count": 32,
291 | "metadata": {},
292 | "output_type": "execute_result"
293 | }
294 | ],
295 | "source": [
296 | "test_env(True)"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "metadata": {},
303 | "outputs": [],
304 | "source": []
305 | }
306 | ],
307 | "metadata": {
308 | "kernelspec": {
309 | "display_name": "Python [conda env:pytorch4]",
310 | "language": "python",
311 | "name": "conda-env-pytorch4-py"
312 | },
313 | "language_info": {
314 | "codemirror_mode": {
315 | "name": "ipython",
316 | "version": 3
317 | },
318 | "file_extension": ".py",
319 | "mimetype": "text/x-python",
320 | "name": "python",
321 | "nbconvert_exporter": "python",
322 | "pygments_lexer": "ipython3",
323 | "version": "3.5.5"
324 | }
325 | },
326 | "nbformat": 4,
327 | "nbformat_minor": 2
328 | }
329 |
--------------------------------------------------------------------------------
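`compute_gae` in the GAE notebook implements the backward recursion $\delta_t = r_t + \gamma\, m_t V(s_{t+1}) - V(s_t)$ and $\hat{A}_t = \delta_t + \gamma \tau\, m_t \hat{A}_{t+1}$, with $m_t = 1 - \text{done}_t$, and returns $\hat{A}_t + V(s_t)$ as the critic's target; `tau` plays the role of the paper's $\lambda$. Two limiting cases make a useful sanity check: with `tau=1` the result telescopes to the bootstrapped n-step returns of A2C's `compute_returns`, and with `tau=0` it reduces to the one-step TD target $r_t + \gamma m_t V(s_{t+1})$. The NumPy sketch below copies both functions for scalar rewards and checks these cases on random data (the data and seed are arbitrary).

```python
import numpy as np


def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):
    """GAE returns (advantage + value), same recursion as the notebook; tau = lambda."""
    values = list(values) + [next_value]
    gae = 0.0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns


def compute_returns(next_value, rewards, masks, gamma=0.99):
    # Bootstrapped n-step returns, as in the A2C notebook, for comparison.
    R = next_value
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R * masks[step]
        returns.insert(0, R)
    return returns


rng = np.random.default_rng(0)
rewards = rng.normal(size=5)
values = rng.normal(size=5)
masks = np.array([1.0, 1.0, 0.0, 1.0, 1.0])  # episode ends at step 2
next_value = 0.7

gae_full = compute_gae(next_value, rewards, masks, values, tau=1.0)
n_step = compute_returns(next_value, rewards, masks)

gae_td0 = compute_gae(next_value, rewards, masks, values, tau=0.0)
v_ext = np.append(values, next_value)
td_target = [rewards[t] + 0.99 * masks[t] * v_ext[t + 1] for t in range(len(rewards))]
```

Values of `tau` between 0 and 1 interpolate between these two extremes, trading the bias of the one-step target against the variance of the full n-step return.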
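In the continuous-action `ActorCritic` of the GAE notebook, the `std` constructor argument initialises `self.log_std`, so it is really an initial log standard deviation: the default `std=0.0` gives σ = exp(0) = 1, shared across all states through `expand_as(mu)`. The entropy of a Gaussian, 0.5·log(2πeσ²), does not depend on the mean, so the `- 0.001 * entropy` term in the loss only pushes on `log_std`. The plain-Python sketch below writes out the univariate log-density and entropy that `torch.distributions.Normal` evaluates per action dimension; the function names are illustrative.

```python
import math


def normal_log_prob(x, mu, log_std):
    # log N(x; mu, sigma^2) with sigma = exp(log_std)
    var = math.exp(2 * log_std)
    return -((x - mu) ** 2) / (2 * var) - log_std - 0.5 * math.log(2 * math.pi)


def normal_entropy(log_std):
    # 0.5 * log(2 * pi * e * sigma^2), rearranged; independent of the mean
    return 0.5 + 0.5 * math.log(2 * math.pi) + log_std


# The default std=0.0 in ActorCritic means log_std = 0, i.e. sigma = 1.
h0 = normal_entropy(0.0)
lp_at_mean = normal_log_prob(0.3, 0.3, 0.0)
print(round(h0, 4), round(lp_at_mean, 4))
```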
/5.ddpg.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import math\n",
10 | "import random\n",
11 | "\n",
12 | "import gym\n",
13 | "import numpy as np\n",
14 | "\n",
15 | "import torch\n",
16 | "import torch.nn as nn\n",
17 | "import torch.optim as optim\n",
18 | "import torch.nn.functional as F\n",
19 | "from torch.distributions import Normal"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from IPython.display import clear_output\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "%matplotlib inline"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "Use CUDA "
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "use_cuda = torch.cuda.is_available()\n",
47 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "Replay Buffer "
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 5,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "class ReplayBuffer:\n",
64 | "    def __init__(self, capacity):\n",
65 | "        self.capacity = capacity\n",
66 | "        self.buffer = []\n",
67 | "        self.position = 0\n",
68 | "\n",
69 | "    def push(self, state, action, reward, next_state, done):\n",
70 | "        if len(self.buffer) < self.capacity:\n",
71 | "            self.buffer.append(None)\n",
72 | "        self.buffer[self.position] = (state, action, reward, next_state, done)\n",
73 | "        self.position = (self.position + 1) % self.capacity\n",
74 | "\n",
75 | "    def sample(self, batch_size):\n",
76 | "        batch = random.sample(self.buffer, batch_size)\n",
77 | "        state, action, reward, next_state, done = map(np.stack, zip(*batch))\n",
78 | "        return state, action, reward, next_state, done\n",
79 | "\n",
80 | "    def __len__(self):\n",
81 | "        return len(self.buffer)"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "metadata": {},
87 | "source": [
88 | "Normalize action space "
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 8,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "class NormalizedActions(gym.ActionWrapper):\n",
98 | "\n",
99 | "    def _action(self, action):\n",
100 | "        low_bound = self.action_space.low\n",
101 | "        upper_bound = self.action_space.high\n",
102 | "\n",
103 | "        action = low_bound + (action + 1.0) * 0.5 * (upper_bound - low_bound)\n",
104 | "        action = np.clip(action, low_bound, upper_bound)\n",
105 | "\n",
106 | "        return action\n",
107 | "\n",
108 | "    def _reverse_action(self, action):\n",
109 | "        low_bound = self.action_space.low\n",
110 | "        upper_bound = self.action_space.high\n",
111 | "\n",
112 | "        action = 2 * (action - low_bound) / (upper_bound - low_bound) - 1\n",
113 | "        action = np.clip(action, -1.0, 1.0)  # normalized actions live in [-1, 1]\n",
114 | "\n",
115 | "        return action"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {},
121 | "source": [
122 | "Ornstein-Uhlenbeck process \n",
123 | "Adding time-correlated noise to the actions taken by the deterministic policy \n",
124 | "wiki "
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 12,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "class OUNoise(object):\n",
134 | "    def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.3, decay_period=100000):\n",
135 | "        self.mu = mu\n",
136 | "        self.theta = theta\n",
137 | "        self.sigma = max_sigma\n",
138 | "        self.max_sigma = max_sigma\n",
139 | "        self.min_sigma = min_sigma\n",
140 | "        self.decay_period = decay_period\n",
141 | "        self.action_dim = action_space.shape[0]\n",
142 | "        self.low = action_space.low\n",
143 | "        self.high = action_space.high\n",
144 | "        self.reset()\n",
145 | "\n",
146 | "    def reset(self):\n",
147 | "        self.state = np.ones(self.action_dim) * self.mu\n",
148 | "\n",
149 | "    def evolve_state(self):\n",
150 | "        x = self.state\n",
151 | "        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)\n",
152 | "        self.state = x + dx\n",
153 | "        return self.state\n",
154 | "\n",
155 | "    def get_action(self, action, t=0):\n",
156 | "        ou_state = self.evolve_state()\n",
157 | "        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period)\n",
158 | "        return np.clip(action + ou_state, self.low, self.high)\n",
159 | "\n",
160 | "#https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 16,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "def plot(frame_idx, rewards):\n",
170 | " clear_output(True)\n",
171 | " plt.figure(figsize=(20,5))\n",
172 | " plt.subplot(131)\n",
173 | " plt.title('frame %s. reward: %s' % (frame_idx, rewards[-1]))\n",
174 | " plt.plot(rewards)\n",
175 | " plt.show()"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | " Continuous control with deep reinforcement learning \n",
183 | ""
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 18,
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "class ValueNetwork(nn.Module):\n",
193 | " def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):\n",
194 | " super(ValueNetwork, self).__init__()\n",
195 | " \n",
196 | " self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)\n",
197 | " self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
198 | " self.linear3 = nn.Linear(hidden_size, 1)\n",
199 | " \n",
200 | " self.linear3.weight.data.uniform_(-init_w, init_w)\n",
201 | " self.linear3.bias.data.uniform_(-init_w, init_w)\n",
202 | " \n",
203 | " def forward(self, state, action):\n",
204 | " x = torch.cat([state, action], 1)\n",
205 | " x = F.relu(self.linear1(x))\n",
206 | " x = F.relu(self.linear2(x))\n",
207 | " x = self.linear3(x)\n",
208 | " return x\n",
209 | " \n",
210 | "\n",
211 | "class PolicyNetwork(nn.Module):\n",
212 | " def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):\n",
213 | " super(PolicyNetwork, self).__init__()\n",
214 | " \n",
215 | " self.linear1 = nn.Linear(num_inputs, hidden_size)\n",
216 | " self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
217 | " self.linear3 = nn.Linear(hidden_size, num_actions)\n",
218 | " \n",
219 | " self.linear3.weight.data.uniform_(-init_w, init_w)\n",
220 | " self.linear3.bias.data.uniform_(-init_w, init_w)\n",
221 | " \n",
222 | " def forward(self, state):\n",
223 | " x = F.relu(self.linear1(state))\n",
224 | " x = F.relu(self.linear2(x))\n",
225 | " x = torch.tanh(self.linear3(x))\n",
226 | " return x\n",
227 | " \n",
228 | " def get_action(self, state):\n",
229 | " state = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
230 | " action = self.forward(state)\n",
231 | " return action.detach().cpu().numpy()[0]"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "DDPG Update "
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 19,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "def ddpg_update(batch_size, \n",
248 | " gamma = 0.99,\n",
249 | " min_value=-np.inf,\n",
250 | " max_value=np.inf,\n",
251 | " soft_tau=1e-2):\n",
252 | " \n",
253 | " state, action, reward, next_state, done = replay_buffer.sample(batch_size)\n",
254 | " \n",
255 | " state = torch.FloatTensor(state).to(device)\n",
256 | " next_state = torch.FloatTensor(next_state).to(device)\n",
257 | " action = torch.FloatTensor(action).to(device)\n",
258 | " reward = torch.FloatTensor(reward).unsqueeze(1).to(device)\n",
259 | " done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)\n",
260 | "\n",
261 | " policy_loss = value_net(state, policy_net(state))\n",
262 | " policy_loss = -policy_loss.mean()\n",
263 | "\n",
264 | " next_action = target_policy_net(next_state)\n",
265 | " target_value = target_value_net(next_state, next_action.detach())\n",
266 | " expected_value = reward + (1.0 - done) * gamma * target_value\n",
267 | " expected_value = torch.clamp(expected_value, min_value, max_value)\n",
268 | "\n",
269 | " value = value_net(state, action)\n",
270 | " value_loss = value_criterion(value, expected_value.detach())\n",
271 | "\n",
272 | "\n",
273 | " policy_optimizer.zero_grad()\n",
274 | " policy_loss.backward()\n",
275 | " policy_optimizer.step()\n",
276 | "\n",
277 | " value_optimizer.zero_grad()\n",
278 | " value_loss.backward()\n",
279 | " value_optimizer.step()\n",
280 | "\n",
281 | " for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):\n",
282 | " target_param.data.copy_(\n",
283 | " target_param.data * (1.0 - soft_tau) + param.data * soft_tau\n",
284 | " )\n",
285 | "\n",
286 | " for target_param, param in zip(target_policy_net.parameters(), policy_net.parameters()):\n",
287 | " target_param.data.copy_(\n",
288 | " target_param.data * (1.0 - soft_tau) + param.data * soft_tau\n",
289 | " )"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": 25,
295 | "metadata": {},
296 | "outputs": [
297 | {
298 | "name": "stdout",
299 | "output_type": "stream",
300 | "text": [
301 | "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
302 | "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n"
303 | ]
304 | }
305 | ],
306 | "source": [
307 | "env = NormalizedActions(gym.make(\"Pendulum-v0\"))\n",
308 | "ou_noise = OUNoise(env.action_space)\n",
309 | "\n",
310 | "state_dim = env.observation_space.shape[0]\n",
311 | "action_dim = env.action_space.shape[0]\n",
312 | "hidden_dim = 256\n",
313 | "\n",
314 | "value_net = ValueNetwork(state_dim, action_dim, hidden_dim).to(device)\n",
315 | "policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)\n",
316 | "\n",
317 | "target_value_net = ValueNetwork(state_dim, action_dim, hidden_dim).to(device)\n",
318 | "target_policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)\n",
319 | "\n",
320 | "for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):\n",
321 | " target_param.data.copy_(param.data)\n",
322 | "\n",
323 | "for target_param, param in zip(target_policy_net.parameters(), policy_net.parameters()):\n",
324 | " target_param.data.copy_(param.data)\n",
325 | " \n",
326 | " \n",
327 | "value_lr = 1e-3\n",
328 | "policy_lr = 1e-4\n",
329 | "\n",
330 | "value_optimizer = optim.Adam(value_net.parameters(), lr=value_lr)\n",
331 | "policy_optimizer = optim.Adam(policy_net.parameters(), lr=policy_lr)\n",
332 | "\n",
333 | "value_criterion = nn.MSELoss()\n",
334 | "\n",
335 | "replay_buffer_size = 1000000\n",
336 | "replay_buffer = ReplayBuffer(replay_buffer_size)"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": 28,
342 | "metadata": {},
343 | "outputs": [],
344 | "source": [
345 | "max_frames = 12000\n",
346 | "max_steps = 500\n",
347 | "frame_idx = 0\n",
348 | "rewards = []\n",
349 | "batch_size = 128"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": 29,
355 | "metadata": {},
356 | "outputs": [],
368 | "source": [
369 | "while frame_idx < max_frames:\n",
370 | " state = env.reset()\n",
371 | " ou_noise.reset()\n",
372 | " episode_reward = 0\n",
373 | " \n",
374 | " for step in range(max_steps):\n",
375 | " action = policy_net.get_action(state)\n",
376 | " action = ou_noise.get_action(action, step)\n",
377 | " next_state, reward, done, _ = env.step(action)\n",
378 | " \n",
379 | " replay_buffer.push(state, action, reward, next_state, done)\n",
380 | " if len(replay_buffer) > batch_size:\n",
381 | " ddpg_update(batch_size)\n",
382 | " \n",
383 | " state = next_state\n",
384 | " episode_reward += reward\n",
385 | " frame_idx += 1\n",
386 | " \n",
387 | " if frame_idx % max(1000, max_steps + 1) == 0:\n",
388 | " plot(frame_idx, rewards)\n",
389 | " \n",
390 | " if done:\n",
391 | " break\n",
392 | " \n",
393 | " rewards.append(episode_reward)"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": null,
399 | "metadata": {},
400 | "outputs": [],
401 | "source": []
402 | }
403 | ],
404 | "metadata": {
405 | "kernelspec": {
406 | "display_name": "Python [conda env:pytorch4]",
407 | "language": "python",
408 | "name": "conda-env-pytorch4-py"
409 | },
410 | "language_info": {
411 | "codemirror_mode": {
412 | "name": "ipython",
413 | "version": 3
414 | },
415 | "file_extension": ".py",
416 | "mimetype": "text/x-python",
417 | "name": "python",
418 | "nbconvert_exporter": "python",
419 | "pygments_lexer": "ipython3",
420 | "version": "3.5.5"
421 | }
422 | },
423 | "nbformat": 4,
424 | "nbformat_minor": 2
425 | }
426 |
--------------------------------------------------------------------------------
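The DDPG notebook above depends on three small numeric pieces: the `NormalizedActions` rescaling between the policy's tanh range [-1, 1] and the environment's action bounds, the Ornstein-Uhlenbeck step in `OUNoise.evolve_state`, and the soft (Polyak) target-network update at the end of `ddpg_update`. The following is a NumPy-only sketch of those pieces, kept separate from the notebooks. The function names are hypothetical, it uses `np.random.default_rng` (NumPy 1.17 or later, newer than the notebook's environment), and target networks are modeled as plain lists of arrays rather than `nn.Module` parameters.

```python
import numpy as np


def to_env_action(action, low, high):
    """Map a tanh-squashed action in [-1, 1] onto the environment box [low, high]."""
    action = low + (action + 1.0) * 0.5 * (high - low)
    return np.clip(action, low, high)


def to_policy_action(action, low, high):
    """Inverse map: an environment action in [low, high] back to the policy range [-1, 1]."""
    action = 2.0 * (action - low) / (high - low) - 1.0
    return np.clip(action, -1.0, 1.0)


def ou_step(x, mu=0.0, theta=0.15, sigma=0.3, rng=None):
    """One unit-time-step update of an Ornstein-Uhlenbeck process,
    x <- x + theta * (mu - x) + sigma * N(0, 1), with independent noise per dimension."""
    rng = np.random.default_rng() if rng is None else rng
    return x + theta * (mu - x) + sigma * rng.standard_normal(x.shape)


def soft_update(target_params, params, tau=1e-2):
    """Polyak averaging, target <- (1 - tau) * target + tau * online, per parameter array."""
    return [t * (1.0 - tau) + p * tau for t, p in zip(target_params, params)]


if __name__ == "__main__":
    low, high = np.array([-2.0]), np.array([2.0])  # Pendulum-v0 torque bounds
    print(to_env_action(np.array([0.5]), low, high))  # [1.]

    # Each OU step starts from the previous state, so successive samples are correlated.
    rng = np.random.default_rng(0)
    noise = np.zeros(1)
    for step in range(3):
        noise = ou_step(noise, rng=rng)
        print(step, noise)
```

With `sigma=0` the OU step simply decays toward `mu` by a factor of `1 - theta` per step; the `sigma` term is what turns this mean reversion into exploration noise.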
/8.gail.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import math\n",
10 | "import random\n",
11 | "\n",
12 | "import gym\n",
13 | "import numpy as np\n",
14 | "\n",
15 | "import torch\n",
16 | "import torch.nn as nn\n",
17 | "import torch.optim as optim\n",
18 | "import torch.nn.functional as F\n",
19 | "from torch.distributions import Normal"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from IPython.display import clear_output\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "%matplotlib inline"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "Use CUDA "
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "use_cuda = torch.cuda.is_available()\n",
47 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "Create Environments "
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 4,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "from common.multiprocessing_env import SubprocVecEnv\n",
64 | "\n",
65 | "num_envs = 16\n",
66 | "env_name = \"Pendulum-v0\"\n",
67 | "\n",
68 | "def make_env():\n",
69 | " def _thunk():\n",
70 | " env = gym.make(env_name)\n",
71 | " return env\n",
72 | "\n",
73 | " return _thunk\n",
74 | "\n",
75 | "envs = [make_env() for i in range(num_envs)]\n",
76 | "envs = SubprocVecEnv(envs)\n",
77 | "\n",
78 | "env = gym.make(env_name)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "Neural Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 6,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "def init_weights(m):\n",
95 | " if isinstance(m, nn.Linear):\n",
96 | " nn.init.normal_(m.weight, mean=0., std=0.1)\n",
97 | " nn.init.constant_(m.bias, 0.1)\n",
98 | " \n",
99 | "\n",
100 | "class ActorCritic(nn.Module):\n",
101 | " def __init__(self, num_inputs, num_outputs, hidden_size, std=0.0):\n",
102 | " super(ActorCritic, self).__init__()\n",
103 | " \n",
104 | " self.critic = nn.Sequential(\n",
105 | " nn.Linear(num_inputs, hidden_size),\n",
106 | " nn.ReLU(),\n",
107 | " nn.Linear(hidden_size, 1)\n",
108 | " )\n",
109 | " \n",
110 | " self.actor = nn.Sequential(\n",
111 | " nn.Linear(num_inputs, hidden_size),\n",
112 | " nn.ReLU(),\n",
113 | " nn.Linear(hidden_size, num_outputs),\n",
114 | " )\n",
115 | " self.log_std = nn.Parameter(torch.ones(1, num_outputs) * std)\n",
116 | " \n",
117 | " self.apply(init_weights)\n",
118 | " \n",
119 | " def forward(self, x):\n",
120 | " value = self.critic(x)\n",
121 | " mu = self.actor(x)\n",
122 | " std = self.log_std.exp().expand_as(mu)\n",
123 | " dist = Normal(mu, std)\n",
124 | " return dist, value"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 7,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def plot(frame_idx, rewards):\n",
134 | " clear_output(True)\n",
135 | " plt.figure(figsize=(20,5))\n",
136 | " plt.subplot(131)\n",
137 | " plt.title('frame %s. reward: %s' % (frame_idx, rewards[-1]))\n",
138 | " plt.plot(rewards)\n",
139 | " plt.show()\n",
140 | " \n",
141 | "def test_env(vis=False):\n",
142 | " state = env.reset()\n",
143 | " if vis: env.render()\n",
144 | " done = False\n",
145 | " total_reward = 0\n",
146 | " while not done:\n",
147 | " state = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
148 | " dist, _ = model(state)\n",
149 | " next_state, reward, done, _ = env.step(dist.sample().cpu().numpy()[0])\n",
150 | " state = next_state\n",
151 | " if vis: env.render()\n",
152 | " total_reward += reward\n",
153 | " return total_reward"
154 | ]
155 | },
156 | {
157 | "cell_type": "markdown",
158 | "metadata": {},
159 | "source": [
160 | "GAE "
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 9,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):\n",
170 | " values = values + [next_value]\n",
171 | " gae = 0\n",
172 | " returns = []\n",
173 | " for step in reversed(range(len(rewards))):\n",
174 | " delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]\n",
175 | " gae = delta + gamma * tau * masks[step] * gae\n",
176 | " returns.insert(0, gae + values[step])\n",
177 | " return returns"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "PPO "
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 33,
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "def ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantage):\n",
194 | " batch_size = states.size(0)\n",
195 | " for _ in range(batch_size // mini_batch_size):\n",
196 | " rand_ids = np.random.randint(0, batch_size, mini_batch_size)\n",
197 | " yield states[rand_ids, :], actions[rand_ids, :], log_probs[rand_ids, :], returns[rand_ids, :], advantage[rand_ids, :]\n",
198 | " \n",
199 | " \n",
200 | "\n",
201 | "def ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantages, clip_param=0.2):\n",
202 | " for _ in range(ppo_epochs):\n",
203 | " for state, action, old_log_probs, return_, advantage in ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantages):\n",
204 | " dist, value = model(state)\n",
205 | " entropy = dist.entropy().mean()\n",
206 | " new_log_probs = dist.log_prob(action)\n",
207 | "\n",
208 | " ratio = (new_log_probs - old_log_probs).exp()\n",
209 | " surr1 = ratio * advantage\n",
210 | " surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage\n",
211 | "\n",
212 | " actor_loss = - torch.min(surr1, surr2).mean()\n",
213 | " critic_loss = (return_ - value).pow(2).mean()\n",
214 | "\n",
215 | " loss = 0.5 * critic_loss + actor_loss - 0.001 * entropy\n",
216 | "\n",
217 | " optimizer.zero_grad()\n",
218 | " loss.backward()\n",
219 | " optimizer.step()"
220 | ]
221 | },
222 | {
223 | "cell_type": "markdown",
224 | "metadata": {},
225 | "source": [
226 | "Load the expert trajectories saved by notebook №3 (3.ppo.ipynb)"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 23,
232 | "metadata": {},
233 | "outputs": [],
234 | "source": [
235 | "try:\n",
236 | "    expert_traj = np.load(\"expert_traj.npy\")\n",
237 | "except FileNotFoundError:\n",
238 | "    # GAIL cannot train without expert data, so stop with an explicit error\n",
239 | "    raise RuntimeError(\"Train, generate and save expert trajectories in notebook №3 first\")"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {},
245 | "source": [
246 | "Generative Adversarial Imitation Learning \n",
247 | ""
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": 24,
253 | "metadata": {},
254 | "outputs": [],
255 | "source": [
256 | "class Discriminator(nn.Module):\n",
257 | " def __init__(self, num_inputs, hidden_size):\n",
258 | " super(Discriminator, self).__init__()\n",
259 | " \n",
260 | " self.linear1 = nn.Linear(num_inputs, hidden_size)\n",
261 | " self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
262 | " self.linear3 = nn.Linear(hidden_size, 1)\n",
263 | " self.linear3.weight.data.mul_(0.1)\n",
264 | " self.linear3.bias.data.mul_(0.0)\n",
265 | " \n",
266 | " def forward(self, x):\n",
267 | " x = torch.tanh(self.linear1(x))\n",
268 | " x = torch.tanh(self.linear2(x))\n",
269 | " prob = torch.sigmoid(self.linear3(x))\n",
270 | " return prob"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 25,
276 | "metadata": {},
277 | "outputs": [],
278 | "source": [
279 | "def expert_reward(state, action):\n",
280 | " state = state.cpu().numpy()\n",
281 | " state_action = torch.FloatTensor(np.concatenate([state, action], 1)).to(device)\n",
282 | " return -np.log(discriminator(state_action).cpu().data.numpy())"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": 35,
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 | "num_inputs = envs.observation_space.shape[0]\n",
292 | "num_outputs = envs.action_space.shape[0]\n",
293 | "\n",
294 | "\n",
295 | "#Hyper params:\n",
296 | "a2c_hidden_size = 256\n",
297 | "discrim_hidden_size = 128\n",
298 | "lr = 3e-3\n",
299 | "num_steps = 20\n",
300 | "mini_batch_size = 5\n",
301 | "ppo_epochs = 4\n",
302 | "threshold_reward = -200\n",
303 | "\n",
304 | "\n",
305 | "model = ActorCritic(num_inputs, num_outputs, a2c_hidden_size).to(device)\n",
306 | "discriminator = Discriminator(num_inputs + num_outputs, discrim_hidden_size).to(device)\n",
307 | "\n",
308 | "discrim_criterion = nn.BCELoss()\n",
309 | "\n",
310 | "optimizer = optim.Adam(model.parameters(), lr=lr)\n",
311 | "optimizer_discrim = optim.Adam(discriminator.parameters(), lr=lr)"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 36,
317 | "metadata": {},
318 | "outputs": [],
319 | "source": [
320 | "test_rewards = []\n",
321 | "max_frames = 100000\n",
322 | "frame_idx = 0"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 37,
328 | "metadata": {},
329 | "outputs": [
330 | {
331 | "data": {
332 | "image/png": "\n",
333 | "text/plain": [
334 | ""
335 | ]
336 | },
337 | "metadata": {},
338 | "output_type": "display_data"
339 | }
340 | ],
341 | "source": [
342 | "i_update = 0\n",
343 | "state = envs.reset()\n",
344 | "early_stop = False\n",
345 | "\n",
346 | "while frame_idx < max_frames and not early_stop:\n",
347 | " i_update += 1\n",
348 | " \n",
349 | " log_probs = []\n",
350 | " values = []\n",
351 | " states = []\n",
352 | " actions = []\n",
353 | " rewards = []\n",
354 | " masks = []\n",
355 | " entropy = 0\n",
356 | "\n",
357 | " for _ in range(num_steps):\n",
358 | " state = torch.FloatTensor(state).to(device)\n",
359 | " dist, value = model(state)\n",
360 | "\n",
361 | " action = dist.sample()\n",
362 | " next_state, reward, done, _ = envs.step(action.cpu().numpy())\n",
363 | " reward = expert_reward(state, action.cpu().numpy())\n",
364 | " \n",
365 | " log_prob = dist.log_prob(action)\n",
366 | " entropy += dist.entropy().mean()\n",
367 | " \n",
368 | " log_probs.append(log_prob)\n",
369 | " values.append(value)\n",
370 | " rewards.append(torch.FloatTensor(reward).to(device))\n",
371 | " masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))\n",
372 | " \n",
373 | " states.append(state)\n",
374 | " actions.append(action)\n",
375 | " \n",
376 | " state = next_state\n",
377 | " frame_idx += 1\n",
378 | " \n",
379 | " if frame_idx % 1000 == 0:\n",
380 | " test_reward = np.mean([test_env() for _ in range(10)])\n",
381 | " test_rewards.append(test_reward)\n",
382 | " plot(frame_idx, test_rewards)\n",
383 | " if test_reward > threshold_reward: early_stop = True\n",
384 | " \n",
385 | "\n",
386 | " next_state = torch.FloatTensor(next_state).to(device)\n",
387 | " _, next_value = model(next_state)\n",
388 | " returns = compute_gae(next_value, rewards, masks, values)\n",
389 | "\n",
390 | " returns = torch.cat(returns).detach()\n",
391 | " log_probs = torch.cat(log_probs).detach()\n",
392 | " values = torch.cat(values).detach()\n",
393 | " states = torch.cat(states)\n",
394 | " actions = torch.cat(actions)\n",
395 | " advantage = returns - values\n",
396 | " \n",
397 | " if i_update % 3 == 0:\n",
398 | " ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantage)\n",
399 | " \n",
400 | " \n",
401 | " expert_state_action = expert_traj[np.random.randint(0, expert_traj.shape[0], 2 * num_steps * num_envs), :]\n",
402 | " expert_state_action = torch.FloatTensor(expert_state_action).to(device)\n",
403 | " state_action = torch.cat([states, actions], 1)\n",
404 | " fake = discriminator(state_action)\n",
405 | " real = discriminator(expert_state_action)\n",
406 | " optimizer_discrim.zero_grad()\n",
407 | " discrim_loss = discrim_criterion(fake, torch.ones((states.shape[0], 1)).to(device)) + \\\n",
408 | " discrim_criterion(real, torch.zeros((expert_state_action.size(0), 1)).to(device))\n",
409 | " discrim_loss.backward()\n",
410 | " optimizer_discrim.step()"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "execution_count": null,
416 | "metadata": {},
417 | "outputs": [],
418 | "source": []
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": [
426 | "test_env(True)"
427 | ]
428 | }
429 | ],
430 | "metadata": {
431 | "kernelspec": {
432 | "display_name": "Python [conda env:pytorch4]",
433 | "language": "python",
434 | "name": "conda-env-pytorch4-py"
435 | },
436 | "language_info": {
437 | "codemirror_mode": {
438 | "name": "ipython",
439 | "version": 3
440 | },
441 | "file_extension": ".py",
442 | "mimetype": "text/x-python",
443 | "name": "python",
444 | "nbconvert_exporter": "python",
445 | "pygments_lexer": "ipython3",
446 | "version": "3.5.5"
447 | }
448 | },
449 | "nbformat": 4,
450 | "nbformat_minor": 2
451 | }
452 |
--------------------------------------------------------------------------------
/3.ppo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import math\n",
10 | "import random\n",
11 | "\n",
12 | "import gym\n",
13 | "import numpy as np\n",
14 | "\n",
15 | "import torch\n",
16 | "import torch.nn as nn\n",
17 | "import torch.optim as optim\n",
18 | "import torch.nn.functional as F\n",
19 | "from torch.distributions import Normal"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from IPython.display import clear_output\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "%matplotlib inline"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "Use CUDA "
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "use_cuda = torch.cuda.is_available()\n",
47 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "Create Environments "
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 4,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "from common.multiprocessing_env import SubprocVecEnv\n",
64 | "\n",
65 | "num_envs = 16\n",
66 | "env_name = \"Pendulum-v0\"\n",
67 | "\n",
68 | "def make_env():\n",
69 | " def _thunk():\n",
70 | " env = gym.make(env_name)\n",
71 | " return env\n",
72 | "\n",
73 | " return _thunk\n",
74 | "\n",
75 | "envs = [make_env() for i in range(num_envs)]\n",
76 | "envs = SubprocVecEnv(envs)\n",
77 | "\n",
78 | "env = gym.make(env_name)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "Neural Network "
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 71,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "def init_weights(m):\n",
95 | " if isinstance(m, nn.Linear):\n",
96 | " nn.init.normal_(m.weight, mean=0., std=0.1)\n",
97 | " nn.init.constant_(m.bias, 0.1)\n",
98 | " \n",
99 | "\n",
100 | "class ActorCritic(nn.Module):\n",
101 | " def __init__(self, num_inputs, num_outputs, hidden_size, std=0.0):\n",
102 | " super(ActorCritic, self).__init__()\n",
103 | " \n",
104 | " self.critic = nn.Sequential(\n",
105 | " nn.Linear(num_inputs, hidden_size),\n",
106 | " nn.ReLU(),\n",
107 | " nn.Linear(hidden_size, 1)\n",
108 | " )\n",
109 | " \n",
110 | " self.actor = nn.Sequential(\n",
111 | " nn.Linear(num_inputs, hidden_size),\n",
112 | " nn.ReLU(),\n",
113 | " nn.Linear(hidden_size, num_outputs),\n",
114 | " )\n",
115 | " self.log_std = nn.Parameter(torch.ones(1, num_outputs) * std)\n",
116 | " \n",
117 | " self.apply(init_weights)\n",
118 | " \n",
119 | " def forward(self, x):\n",
120 | " value = self.critic(x)\n",
121 | " mu = self.actor(x)\n",
122 | " std = self.log_std.exp().expand_as(mu)\n",
123 | " dist = Normal(mu, std)\n",
124 | " return dist, value"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 72,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def plot(frame_idx, rewards):\n",
134 | " clear_output(True)\n",
135 | " plt.figure(figsize=(20,5))\n",
136 | " plt.subplot(131)\n",
137 | " plt.title('frame %s. reward: %s' % (frame_idx, rewards[-1]))\n",
138 | " plt.plot(rewards)\n",
139 | " plt.show()\n",
140 | " \n",
141 | "def test_env(vis=False):\n",
142 | " state = env.reset()\n",
143 | " if vis: env.render()\n",
144 | " done = False\n",
145 | " total_reward = 0\n",
146 | " while not done:\n",
147 | " state = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
148 | " dist, _ = model(state)\n",
149 | " next_state, reward, done, _ = env.step(dist.sample().cpu().numpy()[0])\n",
150 | " state = next_state\n",
151 | " if vis: env.render()\n",
152 | " total_reward += reward\n",
153 | " return total_reward"
154 | ]
155 | },
156 | {
157 | "cell_type": "markdown",
158 | "metadata": {},
159 | "source": [
160 | "GAE "
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 73,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):\n",
170 | " values = values + [next_value]\n",
171 | " gae = 0\n",
172 | " returns = []\n",
173 | " for step in reversed(range(len(rewards))):\n",
174 | " delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]\n",
175 | " gae = delta + gamma * tau * masks[step] * gae\n",
176 | " returns.insert(0, gae + values[step])\n",
177 | " return returns"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | " Proximal Policy Optimization Algorithm \n",
185 | ""
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 74,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "def ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantage):\n",
195 | " batch_size = states.size(0)\n",
196 | " for _ in range(batch_size // mini_batch_size):\n",
197 | " rand_ids = np.random.randint(0, batch_size, mini_batch_size)\n",
198 | " yield states[rand_ids, :], actions[rand_ids, :], log_probs[rand_ids, :], returns[rand_ids, :], advantage[rand_ids, :]\n",
199 | " \n",
200 | " \n",
201 | "\n",
202 | "def ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantages, clip_param=0.2):\n",
203 | " for _ in range(ppo_epochs):\n",
204 | " for state, action, old_log_probs, return_, advantage in ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantages):\n",
205 | " dist, value = model(state)\n",
206 | " entropy = dist.entropy().mean()\n",
207 | " new_log_probs = dist.log_prob(action)\n",
208 | "\n",
209 | " ratio = (new_log_probs - old_log_probs).exp()\n",
210 | " surr1 = ratio * advantage\n",
211 | " surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage\n",
212 | "\n",
213 | " actor_loss = - torch.min(surr1, surr2).mean()\n",
214 | " critic_loss = (return_ - value).pow(2).mean()\n",
215 | "\n",
216 | " loss = 0.5 * critic_loss + actor_loss - 0.001 * entropy\n",
217 | "\n",
218 | " optimizer.zero_grad()\n",
219 | " loss.backward()\n",
220 | " optimizer.step()"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 82,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "num_inputs = envs.observation_space.shape[0]\n",
230 | "num_outputs = envs.action_space.shape[0]\n",
231 | "\n",
232 | "#Hyper params:\n",
233 | "hidden_size = 256\n",
234 | "lr = 3e-4\n",
235 | "num_steps = 20\n",
236 | "mini_batch_size = 5\n",
237 | "ppo_epochs = 4\n",
238 | "threshold_reward = -200\n",
239 | "\n",
240 | "model = ActorCritic(num_inputs, num_outputs, hidden_size).to(device)\n",
241 | "optimizer = optim.Adam(model.parameters(), lr=lr)"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 83,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "max_frames = 15000\n",
251 | "frame_idx = 0\n",
252 | "test_rewards = []"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": 86,
258 | "metadata": {},
259 | "outputs": [],
271 | "source": [
272 | "state = envs.reset()\n",
273 | "early_stop = False\n",
274 | "\n",
275 | "while frame_idx < max_frames and not early_stop:\n",
276 | "\n",
277 | " log_probs = []\n",
278 | " values = []\n",
279 | " states = []\n",
280 | " actions = []\n",
281 | " rewards = []\n",
282 | " masks = []\n",
283 | " entropy = 0\n",
284 | "\n",
285 | " for _ in range(num_steps):\n",
286 | " state = torch.FloatTensor(state).to(device)\n",
287 | " dist, value = model(state)\n",
288 | "\n",
289 | " action = dist.sample()\n",
290 | " next_state, reward, done, _ = envs.step(action.cpu().numpy())\n",
291 | "\n",
292 | " log_prob = dist.log_prob(action)\n",
293 | " entropy += dist.entropy().mean()\n",
294 | " \n",
295 | " log_probs.append(log_prob)\n",
296 | " values.append(value)\n",
297 | " rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))\n",
298 | " masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))\n",
299 | " \n",
300 | " states.append(state)\n",
301 | " actions.append(action)\n",
302 | " \n",
303 | " state = next_state\n",
304 | " frame_idx += 1\n",
305 | " \n",
306 | " if frame_idx % 1000 == 0:\n",
307 | " test_reward = np.mean([test_env() for _ in range(10)])\n",
308 | " test_rewards.append(test_reward)\n",
309 | " plot(frame_idx, test_rewards)\n",
310 | " if test_reward > threshold_reward: early_stop = True\n",
311 | " \n",
312 | "\n",
313 | " next_state = torch.FloatTensor(next_state).to(device)\n",
314 | " _, next_value = model(next_state)\n",
315 | " returns = compute_gae(next_value, rewards, masks, values)\n",
316 | "\n",
317 | " returns = torch.cat(returns).detach()\n",
318 | " log_probs = torch.cat(log_probs).detach()\n",
319 | " values = torch.cat(values).detach()\n",
320 | " states = torch.cat(states)\n",
321 | " actions = torch.cat(actions)\n",
322 | " advantage = returns - values\n",
323 | " \n",
324 | " ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantage)"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {},
330 | "source": [
331 | "Saving trajectories for GAIL "
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 87,
337 | "metadata": {},
338 | "outputs": [
339 | {
340 | "name": "stdout",
341 | "output_type": "stream",
342 | "text": [
343 | "episode: 0 reward: -133.5056485070341\n",
344 | "episode: 1 reward: -3.3737309166625002\n",
345 | "episode: 2 reward: -135.0328820133956\n",
346 | "episode: 3 reward: -131.27964142064513\n",
347 | "episode: 4 reward: -125.12845453838382\n",
348 | "episode: 5 reward: -4.247933460422459\n",
349 | "episode: 6 reward: -395.59297834503883\n",
350 | "episode: 7 reward: -253.25736991568547\n",
351 | "episode: 8 reward: -135.50603026103278\n",
352 | "episode: 9 reward: -132.72095459732952\n",
353 | "episode: 10 reward: -133.89608385869212\n",
354 | "episode: 11 reward: -4.5990508813314035\n",
355 | "episode: 12 reward: -134.44470210766775\n",
356 | "episode: 13 reward: -801.7661346371387\n",
357 | "episode: 14 reward: -131.97725229377644\n",
358 | "episode: 15 reward: -266.76940521674015\n",
359 | "episode: 16 reward: -247.5062278004002\n",
360 | "episode: 17 reward: -4.914595620774103\n",
361 | "episode: 18 reward: -138.7990887577753\n",
362 | "episode: 19 reward: -268.3754189751262\n",
363 | "episode: 20 reward: -363.28764882256417\n",
364 | "episode: 21 reward: -128.15870842354997\n",
365 | "episode: 22 reward: -134.94598918501788\n",
366 | "episode: 23 reward: -309.9577786212293\n",
367 | "episode: 24 reward: -131.91670030817002\n",
368 | "episode: 25 reward: -134.65823444568952\n",
369 | "episode: 26 reward: -134.5615349098279\n",
370 | "episode: 27 reward: -273.5740578550409\n",
371 | "episode: 28 reward: -265.05553942459926\n",
372 | "episode: 29 reward: -258.0591054576666\n",
373 | "episode: 30 reward: -128.91060595426686\n",
374 | "episode: 31 reward: -656.2461074160591\n",
375 | "episode: 32 reward: -136.84071690580248\n",
376 | "episode: 33 reward: -259.2365200533221\n",
377 | "episode: 34 reward: -132.68644155022494\n",
378 | "episode: 35 reward: -260.66364797902054\n",
379 | "episode: 36 reward: -128.8211009270027\n",
380 | "episode: 37 reward: -384.53615237759317\n",
381 | "episode: 38 reward: -4.612904346743044\n",
382 | "episode: 39 reward: -401.1162060114804\n",
383 | "episode: 40 reward: -126.25334578262932\n",
384 | "episode: 41 reward: -3.845934927726255\n",
385 | "episode: 42 reward: -132.44253012402612\n",
386 | "episode: 43 reward: -134.1267203432647\n",
387 | "episode: 44 reward: -128.56866661753938\n",
388 | "episode: 45 reward: -4.97856955649956\n",
389 | "episode: 46 reward: -392.498679426522\n",
390 | "episode: 47 reward: -4.756869243844947\n",
391 | "episode: 48 reward: -4.59189846851519\n",
392 | "episode: 49 reward: -4.7496626929539225\n",
393 | "episode: 50 reward: -131.08999767991665\n",
394 | "episode: 51 reward: -138.17235302513578\n",
395 | "episode: 52 reward: -3.751761058079555\n",
396 | "episode: 53 reward: -260.6317126814632\n",
397 | "episode: 54 reward: -4.535299319594524\n",
398 | "episode: 55 reward: -133.70892423024802\n",
399 | "episode: 56 reward: -134.8732103854694\n",
400 | "episode: 57 reward: -5.315182694344295\n",
401 | "episode: 58 reward: -265.04898120165\n",
402 | "episode: 59 reward: -124.99288470795233\n",
403 | "episode: 60 reward: -4.247632479535832\n",
404 | "episode: 61 reward: -3.68334723705883\n",
405 | "episode: 62 reward: -133.617727327027\n",
406 | "episode: 63 reward: -136.28353948776376\n",
407 | "episode: 64 reward: -5.056124136459314\n",
408 | "episode: 65 reward: -262.7844771770983\n",
409 | "episode: 66 reward: -251.52420165781922\n",
410 | "episode: 67 reward: -133.4014820950796\n",
411 | "episode: 68 reward: -7.0558924646711\n",
412 | "episode: 69 reward: -135.41150554590206\n",
413 | "episode: 70 reward: -131.8871841825757\n",
414 | "episode: 71 reward: -130.8724972571845\n",
415 | "episode: 72 reward: -367.7339135957503\n",
416 | "episode: 73 reward: -134.25198778254116\n",
417 | "episode: 74 reward: -133.86858295338342\n",
418 | "episode: 75 reward: -378.9443227440811\n",
419 | "episode: 76 reward: -3.5473336732949625\n",
420 | "episode: 77 reward: -261.5470895641183\n",
421 | "episode: 78 reward: -408.34135925288217\n",
422 | "episode: 79 reward: -257.6727990499033\n",
423 | "episode: 80 reward: -399.78682205537433\n",
424 | "episode: 81 reward: -266.08087229456055\n",
425 | "episode: 82 reward: -817.186490578741\n",
426 | "episode: 83 reward: -4.500140134501902\n",
427 | "episode: 84 reward: -508.65456581456573\n",
428 | "episode: 85 reward: -378.46002005145874\n",
429 | "episode: 86 reward: -137.76181809972095\n",
430 | "episode: 87 reward: -674.8280917415572\n",
431 | "episode: 88 reward: -128.65034230393303\n",
432 | "episode: 89 reward: -3.922315525193146\n",
433 | "episode: 90 reward: -131.00005239353024\n",
434 | "episode: 91 reward: -130.68974732718007\n",
435 | "episode: 92 reward: -135.21946982972375\n",
436 | "episode: 93 reward: -137.3667851983452\n",
437 | "episode: 94 reward: -136.9119001250973\n",
438 | "episode: 95 reward: -254.5371556381929\n",
439 | "episode: 96 reward: -374.827391591992\n",
440 | "episode: 97 reward: -523.9964989484117\n",
441 | "episode: 98 reward: -133.94200200894622\n",
442 | "episode: 99 reward: -133.74880434577523\n",
443 | "episode: 100 reward: -247.32247835568552\n",
444 | "episode: 101 reward: -138.75528548988993\n",
445 | "episode: 102 reward: -4.847096453940289\n",
446 | "episode: 103 reward: -136.62732481247133\n",
447 | "episode: 104 reward: -262.20300946977864\n",
448 | "episode: 105 reward: -6.5435854338994\n",
449 | "episode: 106 reward: -125.17361036750681\n",
450 | "episode: 107 reward: -690.5202921080676\n",
451 | "episode: 108 reward: -280.53617631459497\n",
452 | "episode: 109 reward: -135.40352441695322\n",
453 | "episode: 110 reward: -131.07617970631023\n",
454 | "episode: 111 reward: -247.0260554601557\n",
455 | "episode: 112 reward: -135.40673404514774\n",
456 | "episode: 113 reward: -395.03306256658476\n",
457 | "episode: 114 reward: -384.1784417792837\n",
458 | "episode: 115 reward: -128.4500742980931\n",
459 | "episode: 116 reward: -463.6977661877445\n",
460 | "episode: 117 reward: -130.94801971085445\n",
461 | "episode: 118 reward: -144.0228791279258\n",
462 | "episode: 119 reward: -667.2634492717342\n",
463 | "episode: 120 reward: -131.79948959004724\n",
464 | "episode: 121 reward: -138.03140142705894\n",
465 | "episode: 122 reward: -129.26779443720966\n",
466 | "episode: 123 reward: -3.2877798185337896\n",
467 | "episode: 124 reward: -134.72016283865193\n",
468 | "episode: 125 reward: -382.2159098741087\n",
469 | "episode: 126 reward: -264.6491917411121\n",
470 | "episode: 127 reward: -134.2254720027939\n",
471 | "episode: 128 reward: -424.8235005744391\n",
472 | "episode: 129 reward: -134.52619102883028\n",
473 | "episode: 130 reward: -537.7406839640856\n",
474 | "episode: 131 reward: -133.90654715605245\n",
475 | "episode: 132 reward: -132.20198118805123\n",
476 | "episode: 133 reward: -400.3589991495165\n",
477 | "episode: 134 reward: -130.12695949420717\n",
478 | "episode: 135 reward: -290.86810229081595\n",
479 | "episode: 136 reward: -394.9043391522139\n",
480 | "episode: 137 reward: -133.42125091255778\n",
481 | "episode: 138 reward: -134.96306459417266\n",
482 | "episode: 139 reward: -3.8499366797706336\n",
483 | "episode: 140 reward: -3.828788719469504\n",
484 | "episode: 141 reward: -5.554963437941836\n",
485 | "episode: 142 reward: -4.510403163975261\n",
486 | "episode: 143 reward: -325.97799775791754\n",
487 | "episode: 144 reward: -3.1174779530363375\n",
488 | "episode: 145 reward: -134.55262416681552\n",
489 | "episode: 146 reward: -350.45777263184095\n",
490 | "episode: 147 reward: -137.33235583532627\n",
491 | "episode: 148 reward: -452.0061280718382\n",
492 | "episode: 149 reward: -265.98673902850385\n",
493 | "episode: 150 reward: -284.8590382363739\n",
494 | "episode: 151 reward: -250.06981206461143\n",
495 | "episode: 152 reward: -129.50428228187013\n",
496 | "episode: 153 reward: -393.09302439930724\n",
497 | "episode: 154 reward: -5.075964808667517\n",
498 | "episode: 155 reward: -129.83816358490287\n",
499 | "episode: 156 reward: -266.1020126434327\n",
500 | "episode: 157 reward: -132.23463644630868\n",
501 | "episode: 158 reward: -779.5855091317233\n",
502 | "episode: 159 reward: -3.763971510946643\n",
503 | "episode: 160 reward: -132.67794144748086\n",
504 | "episode: 161 reward: -662.5587064643477\n",
505 | "episode: 162 reward: -135.2401324340408\n",
506 | "episode: 163 reward: -259.9633585943629\n",
507 | "episode: 164 reward: -6.232862086437321\n",
508 | "episode: 165 reward: -139.498411973157\n",
509 | "episode: 166 reward: -135.35070491390638\n",
510 | "episode: 167 reward: -135.1400077480551\n",
511 | "episode: 168 reward: -347.3664683729514\n",
512 | "episode: 169 reward: -427.1984854733556\n",
513 | "episode: 170 reward: -5.15672209428849\n",
514 | "episode: 171 reward: -525.916662268042\n",
515 | "episode: 172 reward: -133.7053511504196\n",
516 | "episode: 173 reward: -271.26784680564384\n",
517 | "episode: 174 reward: -124.85474506625023\n",
518 | "episode: 175 reward: -134.19873581079943\n",
519 | "episode: 176 reward: -255.83160338962983\n",
520 | "episode: 177 reward: -135.13400569542506\n",
521 | "episode: 178 reward: -4.960226836538054\n",
522 | "episode: 179 reward: -139.19809065222032\n",
523 | "episode: 180 reward: -140.05080094044732\n",
524 | "episode: 181 reward: -137.76647105767526\n",
525 | "episode: 182 reward: -403.1731636539886\n",
526 | "episode: 183 reward: -257.970427512537\n",
527 | "episode: 184 reward: -3.7473226459331066\n",
528 | "episode: 185 reward: -278.3098063643893\n",
529 | "episode: 186 reward: -255.99692458401518\n",
530 | "episode: 187 reward: -4.6365121508813445\n",
531 | "episode: 188 reward: -244.67627722290948\n",
532 | "episode: 189 reward: -131.21920785362062\n",
533 | "episode: 190 reward: -777.3698354491825\n",
534 | "episode: 191 reward: -132.07220706141683\n",
535 | "episode: 192 reward: -392.09434598281683\n",
536 | "episode: 193 reward: -136.06354238422503\n",
537 | "episode: 194 reward: -377.4409927865957\n",
538 | "episode: 195 reward: -132.18253486880235\n",
539 | "episode: 196 reward: -129.15162595976702\n",
540 | "episode: 197 reward: -396.5254064840202\n",
541 | "episode: 198 reward: -3.610361833207753\n",
542 | "episode: 199 reward: -245.53736015092704\n",
543 | "episode: 200 reward: -270.99181854480565\n",
544 | "episode: 201 reward: -247.4231450110685\n",
545 | "episode: 202 reward: -131.59894474370887\n",
546 | "episode: 203 reward: -144.7898370619998\n",
547 | "episode: 204 reward: -926.5588068852352\n",
548 | "episode: 205 reward: -133.39727923189105\n",
549 | "episode: 206 reward: -131.93566436017008\n",
550 | "episode: 207 reward: -6.40529176710689\n",
551 | "episode: 208 reward: -257.08448208556194\n",
552 | "episode: 209 reward: -130.92098423630432\n",
553 | "episode: 210 reward: -262.2927047192545\n",
554 | "episode: 211 reward: -6.859901180492491\n",
555 | "episode: 212 reward: -262.70877767928914\n",
556 | "episode: 213 reward: -134.56588203218894\n",
557 | "episode: 214 reward: -135.22465193371625\n",
558 | "episode: 215 reward: -137.9657247788344\n",
559 | "episode: 216 reward: -135.13425433384725\n",
560 | "episode: 217 reward: -132.3215993693809\n",
561 | "episode: 218 reward: -400.611961792729\n",
562 | "episode: 219 reward: -401.91908212383294\n",
563 | "episode: 220 reward: -282.5082305011229\n",
564 | "episode: 221 reward: -135.42191465289923\n",
565 | "episode: 222 reward: -399.7881535647735\n",
566 | "episode: 223 reward: -131.06522770318847\n",
567 | "episode: 224 reward: -130.7681491912167\n",
568 | "episode: 225 reward: -135.31477016876133\n",
569 | "episode: 226 reward: -3.914901001828447\n",
570 | "episode: 227 reward: -134.5129393394648\n",
571 | "episode: 228 reward: -376.1469783238271\n",
572 | "episode: 229 reward: -133.09045533066046\n",
573 | "episode: 230 reward: -383.2750315233141\n",
574 | "episode: 231 reward: -263.71240275232276\n",
575 | "episode: 232 reward: -500.0083919266878\n",
576 | "episode: 233 reward: -135.22531187168758\n",
577 | "episode: 234 reward: -135.17818433537522\n",
578 | "episode: 235 reward: -395.9834332194123\n",
579 | "episode: 236 reward: -126.08778928679216\n",
580 | "episode: 237 reward: -413.7495701300203\n",
581 | "episode: 238 reward: -131.37116502717876\n",
582 | "episode: 239 reward: -121.6506938627967\n",
583 | "episode: 240 reward: -653.7053929625495\n",
584 | "episode: 241 reward: -254.87183145095838\n",
585 | "episode: 242 reward: -129.71331746419523\n",
586 | "episode: 243 reward: -265.9795936355916\n",
587 | "episode: 244 reward: -400.65989274385277\n",
588 | "episode: 245 reward: -251.82522565834446\n",
589 | "episode: 246 reward: -3.95924871368981\n",
590 | "episode: 247 reward: -312.7505665224348\n",
591 | "episode: 248 reward: -135.5875093701436\n",
592 | "episode: 249 reward: -441.6053043293015\n"
593 | ]
594 | }
595 | ],
596 | "source": [
597 | "from itertools import count\n",
598 | "\n",
599 | "max_expert_num = 50000\n",
600 | "num_expert_steps = 0  # separate counter so PPO's `num_steps` rollout length is not overwritten\n",
601 | "expert_traj = []\n",
602 | "\n",
603 | "for i_episode in count():\n",
604 | "    state = env.reset()\n",
605 | "    done = False\n",
606 | "    total_reward = 0\n",
607 | "    \n",
608 | "    while not done:\n",
609 | "        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
610 | "        dist, _ = model(state_tensor)\n",
611 | "        action = dist.sample().cpu().numpy()[0]\n",
612 | "        next_state, reward, done, _ = env.step(action)\n",
613 | "        expert_traj.append(np.hstack([state, action]))  # store (s_t, a_t) before advancing the state\n",
614 | "        state = next_state\n",
615 | "        total_reward += reward\n",
616 | "        num_expert_steps += 1\n",
617 | "    \n",
618 | "    print(\"episode:\", i_episode, \"reward:\", total_reward)\n",
619 | "    \n",
620 | "    if num_expert_steps >= max_expert_num:\n",
621 | "        break\n",
622 | "    \n",
623 | "expert_traj = np.stack(expert_traj)\n",
624 | "print()\n",
625 | "print(expert_traj.shape)\n",
626 | "print()\n",
627 | "np.save(\"expert_traj.npy\", expert_traj)"
628 | ]
629 | }
630 | ],
631 | "metadata": {
632 | "kernelspec": {
633 | "display_name": "Python [conda env:pytorch4]",
634 | "language": "python",
635 | "name": "conda-env-pytorch4-py"
636 | },
637 | "language_info": {
638 | "codemirror_mode": {
639 | "name": "ipython",
640 | "version": 3
641 | },
642 | "file_extension": ".py",
643 | "mimetype": "text/x-python",
644 | "name": "python",
645 | "nbconvert_exporter": "python",
646 | "pygments_lexer": "ipython3",
647 | "version": "3.5.5"
648 | }
649 | },
650 | "nbformat": 4,
651 | "nbformat_minor": 2
652 | }
653 |
--------------------------------------------------------------------------------
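The PPO loop in `3.ppo.ipynb` calls `compute_gae(next_value, rewards, masks, values)`, which is defined in an earlier cell outside this excerpt, and then forms `advantage = returns - values`. That only works if the returned targets are advantage plus the value baseline. A minimal pure-Python sketch of the generalized advantage estimation recursion for a single environment with scalar values; the name `compute_gae_sketch` and the `gamma = 0.99`, `tau = 0.95` defaults are assumptions, not taken from the notebook.

```python
from typing import List


def compute_gae_sketch(next_value: float,
                       rewards: List[float],
                       masks: List[float],
                       values: List[float],
                       gamma: float = 0.99,
                       tau: float = 0.95) -> List[float]:
    """Return GAE targets (advantage + value) for one rollout of one env.

    masks[t] is 0.0 if the episode ended at step t, else 1.0
    (the role played by `1 - done` in the PPO loop).
    """
    values = values + [next_value]  # bootstrap the last step with V(s_T)
    gae = 0.0
    returns: List[float] = []
    for t in reversed(range(len(rewards))):
        # one-step TD error: r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        # exponentially weighted sum of TD errors, reset at episode ends
        gae = delta + gamma * tau * masks[t] * gae
        # target = advantage + baseline, so `returns - values` is the advantage
        returns.insert(0, gae + values[t])
    return returns
```

A zero in `masks` both stops bootstrapping from `values[t + 1]` and resets the running `gae`, so advantages never leak across episode boundaries. With `tau = 0` this reduces to one-step TD advantages; with `tau = 1` the targets become discounted returns bootstrapped from `next_value`.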
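`ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantage)` is likewise defined outside this excerpt. The old `log_probs` are detached before the call, so inside the update they can only act as constants, which is how the denominator of PPO's probability ratio is normally treated. A minimal per-sample sketch of the clipped surrogate objective from the PPO paper, in pure Python; the function name `clipped_surrogate` and the `clip_param = 0.2` default are assumptions, not the notebook's code.

```python
import math


def clipped_surrogate(new_log_prob: float,
                      old_log_prob: float,
                      advantage: float,
                      clip_param: float = 0.2) -> float:
    """Per-sample PPO clipped surrogate objective (to be maximized)."""
    # probability ratio pi_new(a|s) / pi_old(a|s), computed in log space
    ratio = math.exp(new_log_prob - old_log_prob)
    surr1 = ratio * advantage
    # clamp the ratio to [1 - clip_param, 1 + clip_param]
    surr2 = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param) * advantage
    # pessimistic bound: keep the smaller of the two surrogates
    return min(surr1, surr2)
```

A batched implementation would minimize the negative mean of this quantity over each minibatch, usually alongside a value-function loss; the sketch covers only the policy term.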