├── scores ├── __init__.py ├── scores.png ├── solved.png ├── solved.csv ├── scores.csv └── score_logger.py ├── .github └── FUNDING.yml ├── requirements.txt ├── assets ├── cartpole_example.gif └── cartpole_icon_web.png ├── .idea ├── vcs.xml ├── other.xml ├── modules.xml ├── misc.xml ├── cartpole.iml └── workspace.xml ├── LICENSE ├── .gitignore ├── README.md └── cartpole.py /scores/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | patreon: gsurma 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | gym 3 | keras 4 | matplotlib 5 | tensorflow 6 | -------------------------------------------------------------------------------- /scores/scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gsurma/cartpole/HEAD/scores/scores.png -------------------------------------------------------------------------------- /scores/solved.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gsurma/cartpole/HEAD/scores/solved.png -------------------------------------------------------------------------------- /assets/cartpole_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gsurma/cartpole/HEAD/assets/cartpole_example.gif -------------------------------------------------------------------------------- /assets/cartpole_icon_web.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gsurma/cartpole/HEAD/assets/cartpole_icon_web.png 
-------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /scores/solved.csv: -------------------------------------------------------------------------------- 1 | 88 2 | 572 3 | 66 4 | 361 5 | 57 6 | 115 7 | 64 8 | 25 9 | 9 10 | 324 11 | 0 12 | 234 13 | 11 14 | 123 15 | 527 16 | 10 17 | 49 18 | 72 19 | 3 20 | 237 21 | 395 22 | 313 23 | 15 24 | 5 25 | 120 26 | 0 27 | 78 28 | 36 29 | 37 30 | 12 31 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/cartpole.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Grzegorz Surma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scores/scores.csv: -------------------------------------------------------------------------------- 1 | 43 2 | 40 3 | 26 4 | 32 5 | 13 6 | 13 7 | 12 8 | 11 9 | 13 10 | 15 11 | 8 12 | 11 13 | 9 14 | 13 15 | 10 16 | 10 17 | 8 18 | 14 19 | 16 20 | 9 21 | 11 22 | 12 23 | 10 24 | 8 25 | 10 26 | 10 27 | 13 28 | 9 29 | 88 30 | 97 31 | 56 32 | 28 33 | 24 34 | 41 35 | 45 36 | 29 37 | 30 38 | 68 39 | 49 40 | 34 41 | 62 42 | 67 43 | 87 44 | 59 45 | 97 46 | 69 47 | 96 48 | 109 49 | 184 50 | 201 51 | 176 52 | 139 53 | 340 54 | 238 55 | 283 56 | 237 57 | 250 58 | 374 59 | 226 60 | 256 61 | 419 62 | 230 63 | 265 64 | 280 65 | 220 66 | 260 67 | 234 68 | 240 69 | 209 70 | 500 71 | 500 72 | 424 73 | 212 74 | 500 75 | 300 76 | 269 77 | 446 78 | 209 79 | 203 80 | 251 81 | 229 82 | 203 83 | 500 84 | 232 85 | 360 86 | 388 87 | 317 88 | 184 89 | 500 90 | 500 91 | 306 92 | 500 93 | 425 94 | 464 95 | 297 96 | 346 97 | 105 98 | 10 99 | 9 100 | 11 101 | 10 102 | 11 103 | 10 104 | 440 105 | 475 106 | 500 107 | 431 108 | 500 109 | 179 110 | 500 111 | 13 112 | 500 113 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | # Cartpole 6 | 7 | Reinforcement Learning solution of [OpenAI's CartPole](https://gym.openai.com/envs/CartPole-v0/). 8 | 9 | Check out the corresponding Medium article: [Cartpole - Introduction to Reinforcement Learning (DQN - Deep Q-Learning)](https://towardsdatascience.com/cartpole-introduction-to-reinforcement-learning-ed0eb5b58288) 10 | 11 | ## About 12 | 13 | > A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. The episode ends when the pole is more than 15 degrees from vertical, or the cart moves more than 2.4 units from the center. [source](https://gym.openai.com/envs/CartPole-v0/) 14 | 15 | ## DQN 16 | Standard DQN with Experience Replay. 17 | 18 | ### Hyperparameters: 19 | 20 | * GAMMA = 0.95 21 | * LEARNING_RATE = 0.001 22 | * MEMORY_SIZE = 1000000 23 | * BATCH_SIZE = 20 24 | * EXPLORATION_MAX = 1.0 25 | * EXPLORATION_MIN = 0.01 26 | * EXPLORATION_DECAY = 0.995 27 | 28 | ### Model structure: 29 | 30 | 1. Dense layer - input: **4**, output: **24**, activation: **relu** 31 | 2. Dense layer - input: **24**, output: **24**, activation: **relu** 32 | 3. Dense layer - input: **24**, output: **2**, activation: **linear** 33 | 34 | * **MSE** loss function 35 | * **Adam** optimizer 36 | 37 | 38 | ## Performance 39 | Note: `cartpole.py` trains on `CartPole-v1` but applies the v0 solving criterion quoted below (an average score of at least 195 over the last 100 runs). 40 | > CartPole-v0 defines "solving" as getting average reward of 195.0 over 100 consecutive trials. 
[source](https://gym.openai.com/envs/CartPole-v0/) 41 | > 42 | 43 | ##### Example trial gif 44 | 45 | 46 | 47 | 48 | ##### Example trial chart 49 | 50 | 51 | 52 | ##### Solved trials chart 53 | 54 | 55 | 56 | ## Author 57 | 58 | **Greg (Grzegorz) Surma** 59 | 60 | [**PORTFOLIO**](https://gsurma.github.io) 61 | 62 | [**GITHUB**](https://github.com/gsurma) 63 | 64 | [**BLOG**](https://medium.com/@gsurma) 65 | 66 | -------------------------------------------------------------------------------- /cartpole.py: -------------------------------------------------------------------------------- 1 | import random 2 | import gym 3 | import numpy as np 4 | from collections import deque 5 | from keras.models import Sequential 6 | from keras.layers import Dense 7 | from keras.optimizers import Adam 8 | 9 | 10 | from scores.score_logger import ScoreLogger 11 | 12 | ENV_NAME = "CartPole-v1" 13 | 14 | GAMMA = 0.95 15 | LEARNING_RATE = 0.001 16 | 17 | MEMORY_SIZE = 1000000 18 | BATCH_SIZE = 20 19 | 20 | EXPLORATION_MAX = 1.0 21 | EXPLORATION_MIN = 0.01 22 | EXPLORATION_DECAY = 0.995 23 | 24 | 25 | class DQNSolver: 26 | 27 | def __init__(self, observation_space, action_space): 28 | self.exploration_rate = EXPLORATION_MAX 29 | 30 | self.action_space = action_space 31 | self.memory = deque(maxlen=MEMORY_SIZE) 32 | 33 | self.model = Sequential() 34 | self.model.add(Dense(24, input_shape=(observation_space,), activation="relu")) 35 | self.model.add(Dense(24, activation="relu")) 36 | self.model.add(Dense(self.action_space, activation="linear")) 37 | self.model.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE)) 38 | 39 | def remember(self, state, action, reward, next_state, done): 40 | self.memory.append((state, action, reward, next_state, done)) 41 | 42 | def act(self, state): 43 | if np.random.rand() < self.exploration_rate: 44 | return random.randrange(self.action_space) 45 | q_values = self.model.predict(state) 46 | return np.argmax(q_values[0]) 47 | 48 | def experience_replay(self): 
49 | if len(self.memory) < BATCH_SIZE: 50 | return 51 | batch = random.sample(self.memory, BATCH_SIZE) 52 | for state, action, reward, state_next, terminal in batch: 53 | q_update = reward 54 | if not terminal: 55 | q_update = (reward + GAMMA * np.amax(self.model.predict(state_next)[0])) 56 | q_values = self.model.predict(state) 57 | q_values[0][action] = q_update 58 | self.model.fit(state, q_values, verbose=0) 59 | self.exploration_rate *= EXPLORATION_DECAY 60 | self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate) 61 | 62 | 63 | def cartpole(): 64 | env = gym.make(ENV_NAME) 65 | score_logger = ScoreLogger(ENV_NAME) 66 | observation_space = env.observation_space.shape[0] 67 | action_space = env.action_space.n 68 | dqn_solver = DQNSolver(observation_space, action_space) 69 | run = 0 70 | while True: 71 | run += 1 72 | state = env.reset() 73 | state = np.reshape(state, [1, observation_space]) 74 | step = 0 75 | while True: 76 | step += 1 77 | #env.render() 78 | action = dqn_solver.act(state) 79 | state_next, reward, terminal, info = env.step(action) 80 | reward = reward if not terminal else -reward 81 | state_next = np.reshape(state_next, [1, observation_space]) 82 | dqn_solver.remember(state, action, reward, state_next, terminal) 83 | state = state_next 84 | if terminal: 85 | print "Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step) 86 | score_logger.add_score(step, run) 87 | break 88 | dqn_solver.experience_replay() 89 | 90 | 91 | if __name__ == "__main__": 92 | cartpole() 93 | -------------------------------------------------------------------------------- /scores/score_logger.py: -------------------------------------------------------------------------------- 1 | from statistics import mean 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | from collections import deque 6 | import os 7 | import csv 8 | import numpy as np 9 | 10 | SCORES_CSV_PATH = 
"./scores/scores.csv" 11 | SCORES_PNG_PATH = "./scores/scores.png" 12 | SOLVED_CSV_PATH = "./scores/solved.csv" 13 | SOLVED_PNG_PATH = "./scores/solved.png" 14 | AVERAGE_SCORE_TO_SOLVE = 195 15 | CONSECUTIVE_RUNS_TO_SOLVE = 100 16 | 17 | 18 | class ScoreLogger: 19 | 20 | def __init__(self, env_name): 21 | self.scores = deque(maxlen=CONSECUTIVE_RUNS_TO_SOLVE) 22 | self.env_name = env_name 23 | 24 | if os.path.exists(SCORES_PNG_PATH): 25 | os.remove(SCORES_PNG_PATH) 26 | if os.path.exists(SCORES_CSV_PATH): 27 | os.remove(SCORES_CSV_PATH) 28 | 29 | def add_score(self, score, run): 30 | self._save_csv(SCORES_CSV_PATH, score) 31 | self._save_png(input_path=SCORES_CSV_PATH, 32 | output_path=SCORES_PNG_PATH, 33 | x_label="runs", 34 | y_label="scores", 35 | average_of_n_last=CONSECUTIVE_RUNS_TO_SOLVE, 36 | show_goal=True, 37 | show_trend=True, 38 | show_legend=True) 39 | self.scores.append(score) 40 | mean_score = mean(self.scores) 41 | print "Scores: (min: " + str(min(self.scores)) + ", avg: " + str(mean_score) + ", max: " + str(max(self.scores)) + ")\n" 42 | if mean_score >= AVERAGE_SCORE_TO_SOLVE and len(self.scores) >= CONSECUTIVE_RUNS_TO_SOLVE: 43 | solve_score = run-CONSECUTIVE_RUNS_TO_SOLVE 44 | print "Solved in " + str(solve_score) + " runs, " + str(run) + " total runs." 
45 | self._save_csv(SOLVED_CSV_PATH, solve_score) 46 | self._save_png(input_path=SOLVED_CSV_PATH, 47 | output_path=SOLVED_PNG_PATH, 48 | x_label="trials", 49 | y_label="steps before solve", 50 | average_of_n_last=None, 51 | show_goal=False, 52 | show_trend=False, 53 | show_legend=False) 54 | exit() 55 | 56 | def _save_png(self, input_path, output_path, x_label, y_label, average_of_n_last, show_goal, show_trend, show_legend): 57 | x = [] 58 | y = [] 59 | with open(input_path, "r") as scores: 60 | reader = csv.reader(scores) 61 | data = list(reader) 62 | for i in range(0, len(data)): 63 | x.append(int(i)) 64 | y.append(int(data[i][0])) 65 | 66 | plt.subplots() 67 | plt.plot(x, y, label="score per run") 68 | 69 | average_range = average_of_n_last if average_of_n_last is not None else len(x) 70 | plt.plot(x[-average_range:], [np.mean(y[-average_range:])] * len(y[-average_range:]), linestyle="--", label="last " + str(average_range) + " runs average") 71 | 72 | if show_goal: 73 | plt.plot(x, [AVERAGE_SCORE_TO_SOLVE] * len(x), linestyle=":", label=str(AVERAGE_SCORE_TO_SOLVE) + " score average goal") 74 | 75 | if show_trend and len(x) > 1: 76 | trend_x = x[1:] 77 | z = np.polyfit(np.array(trend_x), np.array(y[1:]), 1) 78 | p = np.poly1d(z) 79 | plt.plot(trend_x, p(trend_x), linestyle="-.", label="trend") 80 | 81 | plt.title(self.env_name) 82 | plt.xlabel(x_label) 83 | plt.ylabel(y_label) 84 | 85 | if show_legend: 86 | plt.legend(loc="upper left") 87 | 88 | plt.savefig(output_path, bbox_inches="tight") 89 | plt.close() 90 | 91 | def _save_csv(self, path, score): 92 | if not os.path.exists(path): 93 | with open(path, "w"): 94 | pass 95 | scores_file = open(path, "a") 96 | with scores_file: 97 | writer = csv.writer(scores_file) 98 | writer.writerow([score]) 99 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 
| 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 60 | 61 | 62 | 63 | gam 64 | dense 65 | 66 | 67 | 68 | 70 | 71 | 83 | 84 | 85 | 86 | 87 | true 88 | DEFINITION_ORDER 89 | 90 | 91 | 92 | 93 | 94 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 |