├── .gitattributes
├── Cartpole.png
├── model
│   ├── actor.pkl
│   └── critic.pkl
├── reinforcement_learning.png
├── .idea
│   ├── misc.xml
│   ├── modules.xml
│   ├── Actor-Critic-pytorch.iml
│   └── workspace.xml
├── LICENSE
├── .gitignore
├── README.md
└── Actor-Critic.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/Cartpole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yc930401/Actor-Critic-pytorch/HEAD/Cartpole.png
--------------------------------------------------------------------------------
/model/actor.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yc930401/Actor-Critic-pytorch/HEAD/model/actor.pkl
--------------------------------------------------------------------------------
/model/critic.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yc930401/Actor-Critic-pytorch/HEAD/model/critic.pkl
--------------------------------------------------------------------------------
/reinforcement_learning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yc930401/Actor-Critic-pytorch/HEAD/reinforcement_learning.png
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/Actor-Critic-pytorch.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Yang Cheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Actor-Critic to play CartPole with PyTorch

An Actor-Critic agent that learns to play the CartPole game, implemented in PyTorch.

## Introduction

Humans excel at solving a wide variety of challenging problems, from low-level motor control through to high-level cognitive tasks.
Like a human, an RL agent learns for itself which strategies lead to the greatest long-term reward. This paradigm of
learning by trial and error, solely from rewards or punishments, is known as reinforcement learning (RL). Also like a human, the agent
constructs and learns its own knowledge directly from raw inputs, such as vision, without any hand-engineered features or domain heuristics.
This is achieved with deep neural networks.

The agent must continually make value judgements in order to select good actions over bad ones. In DQN this knowledge is represented by a Q-network that
estimates the total reward the agent can expect to receive after taking a particular action. The key idea is to use a deep neural network
to represent the Q-network and to train it to predict total reward. Previous attempts to combine RL with neural networks had
largely failed because learning was unstable. To address these instabilities, the Deep Q-Networks (DQN) algorithm stores all of the agent's experiences
and then randomly samples and replays these experiences to provide diverse and decorrelated training data.
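In standard notation (a reminder, not specific to this repository), the Q-network approximates the expected discounted return of taking action $a$ in state $s$, and DQN trains it toward a bootstrapped one-step target:

$$Q(s,a) \approx \mathbb{E}\Big[\sum_{k \ge 0} \gamma^{k} r_{t+k} \,\Big|\, s_t = s,\ a_t = a\Big], \qquad y = r + \gamma \max_{a'} Q(s', a').$$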
Reinforcement learning:

![Reinforcement learning](reinforcement_learning.png)

In this post, I implement an Actor-Critic agent for the CartPole game:

![Cartpole](Cartpole.png)

## Methodology

1. Define an Actor network and a Critic network
2. Collect transitions (state, next_state, reward, done) from the gym environment
3. Play the CartPole game; at the end of each episode, compute the discounted return for every step and train the two networks (see the sketch below)
4. Save the models
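
A minimal sketch of the update in step 3, mirroring what `Actor-Critic.py` below does with its default `gamma = 0.99` (variable names follow the script):

```python
# Discounted returns, bootstrapped from the critic's value of the state reached after
# the last step; masks[step] is 0 where the episode ended, which cuts off the bootstrap.
def compute_returns(next_value, rewards, masks, gamma=0.99):
    R = next_value
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R * masks[step]
        returns.insert(0, R)
    return returns

# advantage   = returns - values                          # how much better than the critic expected
# actor_loss  = -(log_probs * advantage.detach()).mean()  # policy gradient, critic held fixed
# critic_loss = advantage.pow(2).mean()                   # regress the critic toward the returns
```

Note that `Actor-Critic.py` targets the classic gym API (`env.reset()` returning only the observation and `env.step()` returning four values), so it assumes an older gym release such as `gym<0.26`, and it calls `env.render()` on every step.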


## References
- https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/5-2-policy-gradient-softmax2/
- https://github.com/higgsfield/RL-Adventure-2/blob/master/1.actor-critic.ipynb
- https://arxiv.org/pdf/1509.02971.pdf

--------------------------------------------------------------------------------
/Actor-Critic.py:
--------------------------------------------------------------------------------
import gym, os
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("CartPole-v0").unwrapped

state_size = env.observation_space.shape[0]
action_size = env.action_space.n
lr = 0.0001  # defined but unused below; both Adam optimizers use their default learning rate


class Actor(nn.Module):
    """Policy network: maps a state to a categorical distribution over actions."""

    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.linear1 = nn.Linear(self.state_size, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, self.action_size)

    def forward(self, state):
        output = F.relu(self.linear1(state))
        output = F.relu(self.linear2(output))
        output = self.linear3(output)
        distribution = Categorical(F.softmax(output, dim=-1))
        return distribution


class Critic(nn.Module):
    """Value network: maps a state to a scalar estimate of its value."""

    def __init__(self, state_size, action_size):
        super(Critic, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.linear1 = nn.Linear(self.state_size, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 1)

    def forward(self, state):
        output = F.relu(self.linear1(state))
        output = F.relu(self.linear2(output))
        value = self.linear3(output)
        return value


def compute_returns(next_value, rewards, masks, gamma=0.99):
    """Discounted returns for one episode, bootstrapped from the critic's value of the
    state reached after the last step; masks[step] is 0 where the episode ended."""
    R = next_value
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R * masks[step]
        returns.insert(0, R)
    return returns


def trainIters(actor, critic, n_iters):
    optimizerA = optim.Adam(actor.parameters())
    optimizerC = optim.Adam(critic.parameters())
    for iter in range(n_iters):
        state = env.reset()
        log_probs = []
        values = []
        rewards = []
        masks = []
        entropy = 0  # accumulated below but not used in the loss

        for i in count():
            env.render()
            state = torch.FloatTensor(state).to(device)
            dist, value = actor(state), critic(state)

            # Sample an action from the current policy and step the environment.
            action = dist.sample()
            next_state, reward, done, _ = env.step(action.cpu().numpy())

            log_prob = dist.log_prob(action).unsqueeze(0)
            entropy += dist.entropy().mean()

            log_probs.append(log_prob)
            values.append(value)
            rewards.append(torch.tensor([reward], dtype=torch.float, device=device))
            masks.append(torch.tensor([1 - done], dtype=torch.float, device=device))

            state = next_state

            if done:
                print('Iteration: {}, Score: {}'.format(iter, i))
                break

        # Bootstrap from the value of the final state, then build returns and advantages.
        next_state = torch.FloatTensor(next_state).to(device)
        next_value = critic(next_state)
        returns = compute_returns(next_value, rewards, masks)

        log_probs = torch.cat(log_probs)
        returns = torch.cat(returns).detach()
        values = torch.cat(values)

        advantage = returns - values

        # Policy-gradient loss for the actor; squared advantage (TD-style error) for the critic.
        actor_loss = -(log_probs * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()

        optimizerA.zero_grad()
        optimizerC.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        optimizerA.step()
        optimizerC.step()
    torch.save(actor, 'model/actor.pkl')
    torch.save(critic, 'model/critic.pkl')
    env.close()


if __name__ == '__main__':
    # Reuse saved models if they exist, otherwise start from scratch.
    if os.path.exists('model/actor.pkl'):
        actor = torch.load('model/actor.pkl')
        print('Actor Model loaded')
    else:
        actor = Actor(state_size, action_size).to(device)
    if os.path.exists('model/critic.pkl'):
        critic = torch.load('model/critic.pkl')
        print('Critic Model loaded')
    else:
        critic = Critic(state_size, action_size).to(device)
    trainIters(actor, critic, n_iters=100)
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------