├── .gitattributes
├── Cartpole.png
├── model
│   ├── actor.pkl
│   └── critic.pkl
├── reinforcement_learning.png
├── .idea
│   ├── misc.xml
│   ├── modules.xml
│   ├── Actor-Critic-pytorch.iml
│   └── workspace.xml
├── LICENSE
├── .gitignore
├── README.md
└── Actor-Critic.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/Cartpole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yc930401/Actor-Critic-pytorch/HEAD/Cartpole.png
--------------------------------------------------------------------------------
/model/actor.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yc930401/Actor-Critic-pytorch/HEAD/model/actor.pkl
--------------------------------------------------------------------------------
/model/critic.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yc930401/Actor-Critic-pytorch/HEAD/model/critic.pkl
--------------------------------------------------------------------------------
/reinforcement_learning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yc930401/Actor-Critic-pytorch/HEAD/reinforcement_learning.png
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/Actor-Critic-pytorch.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Yang Cheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Actor-Critic to play Cartpole game with pytorch
2 |
3 | An Actor-Critic agent that learns to play the Cartpole game, implemented in pytorch.
4 |
5 | ## Introduction
6 |
7 | Humans excel at solving a wide variety of challenging problems, from low-level motor control through to high-level cognitive tasks.
8 | Like a human, our agents learn for themselves to achieve successful strategies that lead to the greatest long-term rewards. This paradigm of
9 | learning by trial and error, solely from rewards or punishments, is known as reinforcement learning (RL). Also like a human, our agents
10 | construct and learn their own knowledge directly from raw inputs, such as vision, without any hand-engineered features or domain heuristics.
11 | This is achieved by deep learning with neural networks.
12 | The agent must continually make value judgements so as to select good actions over bad. Value-based methods such as Deep Q-Networks (DQN) capture this
13 | knowledge in a single Q-network that estimates the total reward an agent can expect after taking a particular action, trained on stored and replayed experience.
14 | This repository uses an actor-critic method instead: an actor network maps a state to a probability distribution over actions, and a critic network estimates the
15 | expected discounted return (the state value) of that state. The critic's estimate serves as a baseline, so the actor is updated with the advantage, i.e. how much
16 | better the sampled return turned out than the critic predicted, which reduces the variance of a plain policy-gradient update.<br>
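To make the return, value and advantage concrete, here is a minimal, self-contained sketch; the helper mirrors the `compute_returns` function in `Actor-Critic.py`, and the reward/value numbers are made up purely for illustration:

```python
# Discounted returns and advantages for a toy 3-step CartPole episode.
def compute_returns(next_value, rewards, masks, gamma=0.99):
    # Walk backwards from the bootstrap value of the state after the last step;
    # a mask of 0 marks "done" and stops the return from leaking past the episode end.
    R = next_value
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R * masks[step]
        returns.insert(0, R)
    return returns

rewards = [1.0, 1.0, 1.0]   # CartPole pays +1 for every step the pole stays up
masks   = [1.0, 1.0, 0.0]   # the episode ends after the third step
values  = [2.5, 1.8, 0.9]   # values a critic might predict (made-up numbers)

returns = compute_returns(next_value=0.0, rewards=rewards, masks=masks)
advantages = [r - v for r, v in zip(returns, values)]
print(returns)     # approximately [2.9701, 1.99, 1.0]
print(advantages)  # roughly [0.47, 0.19, 0.10]: positive means the episode went better than predicted
```

The actor is pushed towards actions with positive advantage, while the critic is trained to shrink these gaps.<br>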
17 | Reinforcement learning:
18 | ![reinforcement learning](reinforcement_learning.png)
19 | In this post, I implement an Actor-Critic agent to play the Cartpole game:<br>
20 | ![Cartpole](Cartpole.png)
21 |
22 |
23 | ## Methodology
24 |
25 | 1. Define an Actor network and a Critic network
26 | 2. Collect data (state, next_state, reward, done signals) from the gym environment
27 | 3. Play the Cartpole game, recording the log-probability, value estimate, reward and done mask at each step; at the end of each episode, compute the discounted returns and train the two networks (see the sketch at the end of this README)
28 | 4. Save the models
29 |
30 |
31 |
32 | ## References:
33 | https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/5-2-policy-gradient-softmax2/<br>
34 | https://github.com/higgsfield/RL-Adventure-2/blob/master/1.actor-critic.ipynb
36 | https://arxiv.org/pdf/1509.02971.pdf
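
## Appendix: episode update (sketch)

This is a compact restatement of step 3 of the Methodology, condensing the end-of-episode section of `Actor-Critic.py` into one standalone function. The function name `update_episode` and the returned loss values are introduced here for illustration only; the training script itself inlines this logic right after `compute_returns`, using an Adam optimizer for each network.

```python
def update_episode(actor_optimizer, critic_optimizer, log_probs, values, returns):
    """One actor-critic update from a finished episode.

    log_probs, values, returns: 1-D tensors with one entry per time step
    (returns are the discounted returns, already detached from the graph).
    """
    # Advantage: how much better each step's return was than the critic expected.
    advantage = returns - values

    # Actor: policy gradient weighted by the detached advantage (baseline-corrected).
    actor_loss = -(log_probs * advantage.detach()).mean()
    # Critic: squared error between its value estimates and the observed returns.
    critic_loss = advantage.pow(2).mean()

    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    actor_loss.backward()
    critic_loss.backward()
    actor_optimizer.step()
    critic_optimizer.step()
    return actor_loss.item(), critic_loss.item()
```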
--------------------------------------------------------------------------------
/Actor-Critic.py:
--------------------------------------------------------------------------------
# Advantage actor-critic (single environment) on CartPole.
# Written against the classic gym API: env.reset() returns an observation and
# env.step() returns (next_state, reward, done, info). Newer gym/gymnasium
# releases changed both signatures, so pin an older gym to run this as-is.
import gym, os
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# .unwrapped drops the TimeLimit wrapper, so episodes are not cut off at 200 steps.
env = gym.make("CartPole-v0").unwrapped

state_size = env.observation_space.shape[0]
action_size = env.action_space.n
lr = 0.0001


class Actor(nn.Module):
    """Policy network: maps a state to a categorical distribution over actions."""

    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.linear1 = nn.Linear(self.state_size, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, self.action_size)

    def forward(self, state):
        output = F.relu(self.linear1(state))
        output = F.relu(self.linear2(output))
        output = self.linear3(output)
        distribution = Categorical(F.softmax(output, dim=-1))
        return distribution


class Critic(nn.Module):
    """Value network: maps a state to a scalar estimate of the expected return."""

    def __init__(self, state_size, action_size):
        super(Critic, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.linear1 = nn.Linear(self.state_size, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 1)

    def forward(self, state):
        output = F.relu(self.linear1(state))
        output = F.relu(self.linear2(output))
        value = self.linear3(output)
        return value


def compute_returns(next_value, rewards, masks, gamma=0.99):
    # Discounted returns, computed backwards from the bootstrap value of the
    # state that follows the last step; a mask of 0 (episode done) cuts the
    # bootstrap off so the terminal step only counts its immediate reward.
    R = next_value
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R * masks[step]
        returns.insert(0, R)
    return returns


def trainIters(actor, critic, n_iters):
    # The lr constant above was previously defined but never passed in;
    # wire it up so both optimizers actually use the configured learning rate.
    optimizerA = optim.Adam(actor.parameters(), lr=lr)
    optimizerC = optim.Adam(critic.parameters(), lr=lr)
    for iter in range(n_iters):
        state = env.reset()
        log_probs = []
        values = []
        rewards = []
        masks = []
        entropy = 0  # accumulated for monitoring only; not added to the loss

        for i in count():
            env.render()  # comment out to train faster without a window
            state = torch.FloatTensor(state).to(device)
            dist, value = actor(state), critic(state)

            # Sample an action from the current policy and step the environment.
            action = dist.sample()
            next_state, reward, done, _ = env.step(action.cpu().numpy())

            log_prob = dist.log_prob(action).unsqueeze(0)
            entropy += dist.entropy().mean()

            # Store everything needed for the end-of-episode update.
            log_probs.append(log_prob)
            values.append(value)
            rewards.append(torch.tensor([reward], dtype=torch.float, device=device))
            masks.append(torch.tensor([1 - done], dtype=torch.float, device=device))

            state = next_state

            if done:
                print('Iteration: {}, Score: {}'.format(iter, i))
                break

        # Bootstrap from the critic's value of the final state (masked out above
        # when the episode terminated), then build the discounted returns.
        next_state = torch.FloatTensor(next_state).to(device)
        next_value = critic(next_state)
        returns = compute_returns(next_value, rewards, masks)

        log_probs = torch.cat(log_probs)
        returns = torch.cat(returns).detach()
        values = torch.cat(values)

        # Advantage = observed return minus the critic's estimate.
        advantage = returns - values

        # Policy gradient weighted by the detached advantage; the critic is
        # regressed towards the observed returns with a squared error.
        actor_loss = -(log_probs * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()

        optimizerA.zero_grad()
        optimizerC.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        optimizerA.step()
        optimizerC.step()

    # Save the trained networks once training finishes (whole modules, pickled,
    # matching how they are loaded below).
    os.makedirs('model', exist_ok=True)
    torch.save(actor, 'model/actor.pkl')
    torch.save(critic, 'model/critic.pkl')
    env.close()


if __name__ == '__main__':
    # Resume from previously saved models if they exist, otherwise start fresh.
    # map_location keeps GPU-saved checkpoints loadable on a CPU-only machine.
    if os.path.exists('model/actor.pkl'):
        actor = torch.load('model/actor.pkl', map_location=device)
        print('Actor Model loaded')
    else:
        actor = Actor(state_size, action_size).to(device)
    if os.path.exists('model/critic.pkl'):
        critic = torch.load('model/critic.pkl', map_location=device)
        print('Critic Model loaded')
    else:
        critic = Critic(state_size, action_size).to(device)
    trainIters(actor, critic, n_iters=100)
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------