├── scores
├── __init__.py
├── scores.png
├── solved.png
├── solved.csv
├── scores.csv
└── score_logger.py
├── .github
└── FUNDING.yml
├── requirements.txt
├── assets
├── cartpole_example.gif
└── cartpole_icon_web.png
├── .idea
├── vcs.xml
├── other.xml
├── modules.xml
├── misc.xml
├── cartpole.iml
└── workspace.xml
├── LICENSE
├── .gitignore
├── README.md
└── cartpole.py
/scores/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | patreon: gsurma
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | gym
3 | keras
4 | matplotlib
5 | tensorflow
6 |
--------------------------------------------------------------------------------
/scores/scores.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gsurma/cartpole/HEAD/scores/scores.png
--------------------------------------------------------------------------------
/scores/solved.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gsurma/cartpole/HEAD/scores/solved.png
--------------------------------------------------------------------------------
/assets/cartpole_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gsurma/cartpole/HEAD/assets/cartpole_example.gif
--------------------------------------------------------------------------------
/assets/cartpole_icon_web.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gsurma/cartpole/HEAD/assets/cartpole_icon_web.png
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
46 |
47 |
48 | ##### Example trial chart
49 |
50 |
51 |
52 | ##### Solved trials chart
53 |
54 |
55 |
56 | ## Author
57 |
58 | **Greg (Grzegorz) Surma**
59 |
60 | [**PORTFOLIO**](https://gsurma.github.io)
61 |
62 | [**GITHUB**](https://github.com/gsurma)
63 |
64 | [**BLOG**](https://medium.com/@gsurma)
65 |
66 |
--------------------------------------------------------------------------------
/cartpole.py:
--------------------------------------------------------------------------------
1 | import random
2 | import gym
3 | import numpy as np
4 | from collections import deque
5 | from keras.models import Sequential
6 | from keras.layers import Dense
7 | from keras.optimizers import Adam
8 |
9 |
10 | from scores.score_logger import ScoreLogger
11 |
12 | ENV_NAME = "CartPole-v1"  # Gym environment id; passed to gym.make() and to ScoreLogger
13 |
14 | GAMMA = 0.95  # discount applied to the best predicted next-state Q-value in experience_replay()
15 | LEARNING_RATE = 0.001  # learning rate handed to the Adam optimizer
16 |
17 | MEMORY_SIZE = 1000000  # maxlen of the replay-memory deque
18 | BATCH_SIZE = 20  # transitions sampled per experience_replay() call
19 |
20 | EXPLORATION_MAX = 1.0  # initial exploration rate (probability of a random action)
21 | EXPLORATION_MIN = 0.01  # lower bound the exploration rate is clamped to
22 | EXPLORATION_DECAY = 0.995  # factor the exploration rate is multiplied by after each replay
23 |
24 |
class DQNSolver:
    """Deep Q-Network agent with epsilon-greedy action selection and replay.

    The Q-function is approximated by a small fully connected Keras network
    (two hidden ReLU layers of 24 units, linear output with one unit per
    action) trained with MSE loss and the Adam optimizer.

    Attributes:
        exploration_rate: Current probability of choosing a random action.
        action_space: Number of discrete actions.
        memory: Bounded deque of ``(state, action, reward, next_state, done)``
            transitions used for experience replay.
        model: The compiled Keras ``Sequential`` Q-network.
    """

    def __init__(self, observation_space, action_space):
        """Build the Q-network.

        Args:
            observation_space: Length of the observation vector (network
                input size).
            action_space: Number of discrete actions (network output size).
        """
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)

        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        self.model.add(Dense(24, activation="relu"))
        self.model.add(Dense(self.action_space, activation="linear"))
        # `lr` is a deprecated alias that current Keras releases reject;
        # `learning_rate` is the supported keyword.
        self.model.compile(loss="mse", optimizer=Adam(learning_rate=LEARNING_RATE))

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory.

        Once the deque is full, the oldest transition is discarded.
        """
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action for ``state`` using an epsilon-greedy policy.

        Args:
            state: Observation in the batch form the network's ``predict``
                accepts (the caller reshapes it to ``[1, observation_space]``).

        Returns:
            int: A random action with probability ``exploration_rate``,
            otherwise the action with the highest predicted Q-value.
        """
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        # verbose=0 keeps Keras from printing a progress bar on every step.
        q_values = self.model.predict(state, verbose=0)
        # Cast so both branches return a plain Python int.
        return int(np.argmax(q_values[0]))

    def experience_replay(self):
        """Fit the network on a random mini-batch, then decay exploration.

        Does nothing until at least ``BATCH_SIZE`` transitions are stored.
        Each sampled transition is fitted individually against a target in
        which only the taken action's Q-value is replaced by the Bellman
        target (``reward`` for terminal transitions, otherwise
        ``reward + GAMMA * max_a Q(next_state, a)``).
        """
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                next_q = self.model.predict(state_next, verbose=0)[0]
                q_update = reward + GAMMA * np.amax(next_q)
            # Start from the current predictions so the untouched actions
            # contribute zero error to the fit.
            q_values = self.model.predict(state, verbose=0)
            q_values[0][action] = q_update
            self.model.fit(state, q_values, verbose=0)
        # Decay once per replay call, never dropping below the minimum.
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)
61 |
62 |
def cartpole():
    """Train a DQNSolver on the ENV_NAME environment, logging each episode.

    Runs episodes in an endless loop. Every step is stored in the agent's
    replay memory and followed by one experience-replay update. When an
    episode ends, the number of steps survived is printed and passed to
    ``ScoreLogger.add_score``. This function itself never returns.
    """
    env = gym.make(ENV_NAME)
    score_logger = ScoreLogger(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)
    run = 0
    while True:
        run += 1
        reset_result = env.reset()
        # Gym >= 0.26 returns (observation, info); older releases return
        # the observation alone.
        if (isinstance(reset_result, tuple) and len(reset_result) == 2
                and isinstance(reset_result[1], dict)):
            state = reset_result[0]
        else:
            state = reset_result
        state = np.reshape(state, [1, observation_space])
        step = 0
        while True:
            step += 1
            # Uncomment to watch the episode: env.render()
            action = dqn_solver.act(state)
            step_result = env.step(action)
            if len(step_result) == 5:
                # Gym >= 0.26: (obs, reward, terminated, truncated, info).
                state_next, reward, terminated, truncated, _info = step_result
                terminal = terminated or truncated
            else:
                # Older Gym: (obs, reward, done, info).
                state_next, reward, terminal, _info = step_result
            # Penalise the step that ends the episode.
            reward = reward if not terminal else -reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, terminal)
            state = state_next
            if terminal:
                print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                score_logger.add_score(step, run)
                break
            dqn_solver.experience_replay()
89 |
90 |
91 | if __name__ == "__main__":  # start training only when this file is run as a script
92 | cartpole()
93 |
--------------------------------------------------------------------------------
/scores/score_logger.py:
--------------------------------------------------------------------------------
1 | from statistics import mean
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import matplotlib.pyplot as plt
5 | from collections import deque
6 | import os
7 | import csv
8 | import numpy as np
9 |
10 | SCORES_CSV_PATH = "./scores/scores.csv"  # per-run scores; deleted when a ScoreLogger is created
11 | SCORES_PNG_PATH = "./scores/scores.png"  # plot of SCORES_CSV_PATH; deleted when a ScoreLogger is created
12 | SOLVED_CSV_PATH = "./scores/solved.csv"  # one row appended per solved trial; not deleted by ScoreLogger
13 | SOLVED_PNG_PATH = "./scores/solved.png"  # plot of SOLVED_CSV_PATH
14 | AVERAGE_SCORE_TO_SOLVE = 195  # mean score over the window required to count as solved
15 | CONSECUTIVE_RUNS_TO_SOLVE = 100  # size of the rolling score window
16 |
17 |
class ScoreLogger:
    """Record per-run scores to CSV, plot them, and detect when the task is solved.

    Attributes:
        scores: Rolling window of the last ``CONSECUTIVE_RUNS_TO_SOLVE`` scores.
        env_name: Environment name, used as the plot title.
    """

    def __init__(self, env_name):
        """Start a fresh trial, removing the previous trial's score CSV and plot.

        The "solved" CSV/plot are intentionally left in place so that they
        accumulate across trials.
        """
        self.scores = deque(maxlen=CONSECUTIVE_RUNS_TO_SOLVE)
        self.env_name = env_name

        for stale_path in (SCORES_PNG_PATH, SCORES_CSV_PATH):
            if os.path.exists(stale_path):
                os.remove(stale_path)

    def add_score(self, score, run):
        """Log one run's score, refresh the plot, and check the solve criterion.

        Args:
            score: Score achieved in this run.
            run: 1-based index of this run.

        Raises:
            SystemExit: When the rolling window is full and its mean reaches
                ``AVERAGE_SCORE_TO_SOLVE``; the solved CSV and plot are
                updated first.
        """
        self._save_csv(SCORES_CSV_PATH, score)
        self._save_png(input_path=SCORES_CSV_PATH,
                       output_path=SCORES_PNG_PATH,
                       x_label="runs",
                       y_label="scores",
                       average_of_n_last=CONSECUTIVE_RUNS_TO_SOLVE,
                       show_goal=True,
                       show_trend=True,
                       show_legend=True)
        self.scores.append(score)
        mean_score = mean(self.scores)
        print("Scores: (min: " + str(min(self.scores)) + ", avg: " + str(mean_score) + ", max: " + str(max(self.scores)) + ")\n")
        if mean_score >= AVERAGE_SCORE_TO_SOLVE and len(self.scores) >= CONSECUTIVE_RUNS_TO_SOLVE:
            # Runs needed before the qualifying window started.
            solve_score = run - CONSECUTIVE_RUNS_TO_SOLVE
            print("Solved in " + str(solve_score) + " runs, " + str(run) + " total runs.")
            self._save_csv(SOLVED_CSV_PATH, solve_score)
            self._save_png(input_path=SOLVED_CSV_PATH,
                           output_path=SOLVED_PNG_PATH,
                           x_label="trials",
                           y_label="steps before solve",
                           average_of_n_last=None,
                           show_goal=False,
                           show_trend=False,
                           show_legend=False)
            # Same effect as exit(), without depending on the `site` helper.
            raise SystemExit

    def _save_png(self, input_path, output_path, x_label, y_label, average_of_n_last, show_goal, show_trend, show_legend):
        """Plot the first column of a CSV file and save the figure as a PNG.

        Args:
            input_path: CSV file whose first column holds integer values.
            output_path: Destination image path.
            x_label: Label for the x axis.
            y_label: Label for the y axis.
            average_of_n_last: Window size for the dashed average line, or
                ``None`` to average over all rows.
            show_goal: Draw a horizontal line at ``AVERAGE_SCORE_TO_SOLVE``.
            show_trend: Draw a linear trend line (needs more than one row).
            show_legend: Add a legend to the plot.
        """
        # newline="" is what the csv module expects; blank rows are skipped
        # so files written without it remain readable.
        with open(input_path, "r", newline="") as scores:
            y = [int(row[0]) for row in csv.reader(scores) if row]
        x = list(range(len(y)))

        plt.subplots()
        plt.plot(x, y, label="score per run")

        average_range = average_of_n_last if average_of_n_last is not None else len(x)
        recent_x = x[-average_range:]
        recent_y = y[-average_range:]
        plt.plot(recent_x, [np.mean(recent_y)] * len(recent_y), linestyle="--", label="last " + str(average_range) + " runs average")

        if show_goal:
            plt.plot(x, [AVERAGE_SCORE_TO_SOLVE] * len(x), linestyle=":", label=str(AVERAGE_SCORE_TO_SOLVE) + " score average goal")

        if show_trend and len(x) > 1:
            # The first run is left out of the linear fit.
            trend_x = x[1:]
            z = np.polyfit(np.array(trend_x), np.array(y[1:]), 1)
            p = np.poly1d(z)
            plt.plot(trend_x, p(trend_x), linestyle="-.", label="trend")

        plt.title(self.env_name)
        plt.xlabel(x_label)
        plt.ylabel(y_label)

        if show_legend:
            plt.legend(loc="upper left")

        plt.savefig(output_path, bbox_inches="tight")
        plt.close()

    def _save_csv(self, path, score):
        """Append ``score`` as a single-column row to the CSV file at ``path``.

        The file is created if it does not exist yet.
        """
        # Append mode creates a missing file; newline="" stops the csv
        # writer's line endings from being translated a second time.
        with open(path, "a", newline="") as scores_file:
            csv.writer(scores_file).writerow([score])
99 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |