├── .gitignore
├── .vscode
└── launch.json
├── LICENSE
├── README.md
├── agent.py
├── dqn.py
├── experience_replay.py
├── hyperparameters.yml
└── runs
├── cartpole1.log
├── cartpole1.png
├── cartpole1.pt
├── flappybird1.log
├── flappybird1.png
└── flappybird1.pt
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Train cartpole1",
9 | "type": "debugpy",
10 | "request": "launch",
11 | "program": "agent.py",
12 | "console": "integratedTerminal",
13 | "args": ["cartpole1", "--train"]
14 | },
15 | {
16 | "name": "Train flappybird1",
17 | "type": "debugpy",
18 | "request": "launch",
19 | "program": "agent.py",
20 | "console": "integratedTerminal",
21 | "args": ["flappybird1", "--train"]
22 | }
23 | ]
24 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 johnnycode8
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Implement DQN in PyTorch - Beginner Tutorials
4 |
5 | This repository contains an implementation of the DQN algorithm from my Deep Q-Learning, aka Deep Q-Network (DQN), YouTube tutorial series ([@johnnycode](https://www.youtube.com/@johnnycode)). In this series, we code the DQN algorithm from scratch with Python and PyTorch, and then use it to train an agent to play the Flappy Bird game. If you find the code and tutorials helpful, please consider supporting me:
6 |
7 |
8 |
9 |
10 | If you are brand new to Reinforcement Learning, you may want to start with my Q-Learning tutorials first, then continue on to Deep Q-Learning: https://github.com/johnnycode8/gym_solutions
11 |
12 | ## 1. Install FlappyBird Gymnasium & Setup Development Environment
13 | We'll set up our development environment with VS Code and Conda, and then install Flappy Bird Gymnasium, PyTorch, and TensorFlow (we'll use TensorFlow's TensorBoard to monitor training progress). There are two versions of Flappy Bird: one provides the positions of the last pipe, the next pipe, and the bird, while the other provides RGB (image) frames. The RGB version requires a Convolutional Neural Network, which is more complicated, while the positional version can be trained with a regular Neural Network. We'll start with the positional version and maybe tackle the RGB version in the future.
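A quick way to sanity-check the positional version after installation is a minimal sketch like the one below. The `use_lidar=False` parameter mirrors the `env_make_params` entry in this repo's hyperparameters.yml; the printed sizes in the comment are what I'd expect, not guaranteed.

```python
import flappy_bird_gymnasium  # registers the FlappyBird environments with Gymnasium
import gymnasium as gym

# Positional version: numeric readings for the pipes and the bird (use_lidar=False,
# the same setting used in this repo's hyperparameters.yml).
env = gym.make("FlappyBird-v0", use_lidar=False)
obs, info = env.reset()
print(env.observation_space.shape, env.action_space.n)  # e.g. (12,) observations, 2 actions
env.close()
```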
14 |
15 |
16 |
17 | ## 2. Implement the Deep Q-Network Module
18 | A Deep Q-Network is nothing more than a regular Neural Network with fully connected layers. This network is the brain of the bird. What makes this neural network special is that the network's input layer represents the State of the environment and the output layer represents the expected Q-values of the set of Actions. The State is a combination of the position of the last pipe, the next pipe, and the bird. The Action with the highest Q-value is the best Action for a given State.
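A minimal sketch of such a network (essentially the non-dueling path of this repo's dqn.py):

```python
import torch
from torch import nn
import torch.nn.functional as F

class SimpleDQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)      # input layer: the State
        self.output = nn.Linear(hidden_dim, action_dim)  # output layer: one Q-value per Action

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.output(x)                            # highest Q-value = best Action for this State

net = SimpleDQN(state_dim=12, action_dim=2)              # 12 positional inputs, 2 actions (flap / do nothing)
print(net(torch.randn(10, 12)).shape)                    # torch.Size([10, 2]) for a batch of 10 states
```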
19 |
20 |
21 |
22 | ## 3. Implement Experience Replay & Load Hyperparameters from YAML
23 | The concept of Experience Replay is to collect a large set of "experiences" so that the DQN can be trained using smaller samples. Experience Replay is essential because we need to show the neural network many examples of similar situations to help it learn general patterns. An "experience" consists of the current state, the action that was taken, the resulting new state, the reward that was received, and a flag to indicate if the new state is terminal. When training, we randomly sample from this memory to ensure diverse training data. We also create a separate hyperparameter file to manage parameters like replay memory size, training batch size, and Epsilon for the Epsilon-Greedy algorithm. This way, we can easily change these parameters for different environments.
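A minimal sketch of both pieces, mirroring experience_replay.py and the YAML loading in agent.py (run from the repo root so hyperparameters.yml is found; `cartpole1` is just one of the sets defined there):

```python
import random
import yaml
from collections import deque

# Load one named hyperparameter set, as agent.py does.
with open('hyperparameters.yml', 'r') as file:
    params = yaml.safe_load(file)['cartpole1']

# Replay memory: a bounded deque of (state, action, new_state, reward, terminated) tuples;
# once full, the oldest experiences are dropped automatically.
memory = deque([], maxlen=params['replay_memory_size'])

memory.append(([0.1, 0.0, 0.2, 0.0], 1, [0.1, 0.1, 0.2, 0.1], 1.0, False))  # dummy experience

# Once enough experience is collected, train on a random mini-batch for diverse training data.
if len(memory) > params['mini_batch_size']:
    mini_batch = random.sample(memory, params['mini_batch_size'])
```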
24 |
25 |
26 |
27 | ## 4. Implement Epsilon-Greedy & Debug the Training Loop
28 | The Epsilon-Greedy algorithm is used for exploration (the bird taking a random action) and exploitation (the bird taking the best known action at the moment). We start by initializing the Epsilon value to 1, so initially, the agent will choose 100% random actions. As training progresses, we slowly decay Epsilon, making the agent more likely to select actions based on its learned policy. We'll also convert all necessary inputs to tensors before feeding them into our PyTorch-implemented DQN, so we can use CUDA (GPU) to train the network.
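A minimal sketch of the selection step, assuming `env`, `policy_dqn`, and the current `state` tensor already exist as they do in agent.py's training loop:

```python
import random
import torch

def select_action(env, policy_dqn, state, epsilon, device='cpu'):
    """Epsilon-greedy: explore with probability epsilon, otherwise exploit the learned policy."""
    if random.random() < epsilon:
        action = env.action_space.sample()                       # exploration: random action
        return torch.tensor(action, dtype=torch.int64, device=device)
    with torch.no_grad():
        # add a batch dimension, run the policy network, pick the action with the highest Q-value
        return policy_dqn(state.unsqueeze(dim=0)).squeeze().argmax()

# After each episode, epsilon decays toward a floor (values come from hyperparameters.yml):
# epsilon = max(epsilon * epsilon_decay, epsilon_min)
```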
29 |
30 |
31 |
32 | ## 5. Implement the Target Network
33 | Using the DQN module from earlier, we instantiate a Policy Network and a Target Network. The Target Network starts off identical to the Policy Network. The Policy Network represents the brain of the bird; this is the network that we train. The Target Network is used to estimate target Q-values, which are used to train the Policy Network. While it is possible to use the Policy Network to perform the Q-value estimation, the Policy Network is constantly changing during training, so it is more stable to use a Target Network for estimation. After a number of steps (actions), we sync the two networks by copying the Policy Network's weights and biases into the Target Network.
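An excerpt-style sketch of the two key moments, using the same names as agent.py (not standalone; `DQN`, `policy_dqn`, `num_states`, `num_actions`, `fc1_nodes`, `device`, `step_count`, and `network_sync_rate` all exist in that file, with `self.` prefixes dropped here):

```python
# At startup: create the target network and make it identical to the policy network.
target_dqn = DQN(num_states, num_actions, fc1_nodes).to(device)
target_dqn.load_state_dict(policy_dqn.state_dict())

# During training, the target network estimates the target Q-value for each experience, e.g.
# target_q = reward + discount_factor_g * max_a target_dqn(new_state) for non-terminal states.

# Every network_sync_rate steps: copy the policy network's weights and biases across.
if step_count > network_sync_rate:
    target_dqn.load_state_dict(policy_dqn.state_dict())
    step_count = 0
```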
34 |
35 |
36 |
37 | ## 6. Explain Loss, Backpropagation, and Gradient Descent
38 | In case you are not familiar with how Neural Networks learn, this video explains the high-level process. The Loss (using the Mean Squared Error function, as an example) measures how far our current policy's Q-values are from our target Q-values. Backpropagation calculates the slope (gradient) of the loss function with respect to each weight and bias, which indicates the direction to adjust them to lower the loss. Gradient Descent then uses those gradients to adjust the weights and biases in the direction that minimizes the loss.
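A tiny self-contained example of one training step with PyTorch, using the same loss (MSE) and optimizer (Adam) as agent.py; the network and Q-value tensors here are dummies just to make the step runnable:

```python
import torch
from torch import nn

policy_dqn = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))  # dummy network
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(policy_dqn.parameters(), lr=0.001)

states = torch.randn(8, 4)              # dummy mini-batch of 8 states
target_q = torch.randn(8, 2)            # dummy target Q-values
current_q = policy_dqn(states)          # current policy's Q-values

loss = loss_fn(current_q, target_q)     # how far current Q-values are from target Q-values
optimizer.zero_grad()                   # clear old gradients
loss.backward()                         # backpropagation: compute the gradient of the loss
optimizer.step()                        # gradient descent: adjust weights and biases to lower the loss
```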
39 |
40 |
41 |
42 | ## 7. Optimize Target Network PyTorch Calculations
43 | In the implementation of the Target Network calculations from earlier, we loop through a batch of experiences and calculate the target Q-value for each one. That code is easy to read and understand; however, it is slow to execute because we process each experience one at a time. PyTorch can process the whole batch at once, which is much more efficient. We'll modify the code to take advantage of PyTorch's batched tensor operations.
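For comparison, a sketch of both styles; the batched line is the one agent.py's optimize() actually uses, while the loop is illustrative. `target_dqn`, `rewards`, `new_states`, `terminations`, and `discount_factor_g` are assumed to be the batch tensors and scalar built in that function:

```python
with torch.no_grad():
    # Earlier, easy-to-read version: one forward pass per experience (slow).
    target_q_list = []
    for new_state, reward, terminated in zip(new_states, rewards, terminations):
        target_q_list.append(reward + (1 - terminated) * discount_factor_g
                             * target_dqn(new_state.unsqueeze(dim=0)).squeeze().max())
    target_q_slow = torch.stack(target_q_list)

    # Batched version: one forward pass for the whole mini-batch (what optimize() does).
    target_q = rewards + (1 - terminations) * discount_factor_g * target_dqn(new_states).max(dim=1)[0]
```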
44 |
45 |
46 |
47 | ## 8. Test DQN Algorithm on CartPole-v1
48 | Reinforcement Learning is fragile; many factors can cause training to fail. We want to make sure that our DQN code is bug-free, so we test it on a simple environment that gives us feedback quickly. The Gymnasium Cart Pole environment is perfect for that. Once we are certain that the code is solid, we can finally train Flappy Bird!
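To run this check yourself with the included `cartpole1` hyperparameter set, train with `python agent.py cartpole1 --train` (which writes `runs/cartpole1.pt`, `.log`, and `.png`), then run `python agent.py cartpole1` without the flag to load the saved model and watch it play. The `--train` invocation is the same one the "Train cartpole1" VS Code launch configuration uses.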
49 |
50 |
51 |
52 | ## 9. Train DQN Algorithm on Flappy Bird!
53 | Finally, we can train our DQN algorithm on Flappy Bird! I'll show the results of a 24-hour training session. The bird can fly past quite a few pipes; however, it did not learn to fly indefinitely, which would perhaps take several more days of training. I explain why it takes so long to train using DQN.
54 |
55 |
56 |
57 |
58 | ## 10. Double DQN Explained and Implemented
59 | Since the introduction of DQN, there have been many enhancements to the algorithm. Double DQN (DDQN) was the first major enhancement. I explain the concept behind Double DQN using Flappy Bird as an example. The main objective of Double DQN is to reduce the overestimation of Q-values, which wastes training time on paths that don't lead to good outcomes. However, it's important to note that DDQN may not always lead to significant performance gains in all environments.
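The change is confined to the target Q-value calculation; this excerpt mirrors the `enable_double_dqn` branch in agent.py's optimize() (batch tensors as built in that function, `self.` prefixes dropped):

```python
with torch.no_grad():
    if enable_double_dqn:
        # Double DQN: the policy network selects the best action for each new state...
        best_actions_from_policy = policy_dqn(new_states).argmax(dim=1)
        # ...and the target network evaluates that action's Q-value.
        target_q = rewards + (1 - terminations) * discount_factor_g * \
            target_dqn(new_states).gather(dim=1, index=best_actions_from_policy.unsqueeze(dim=1)).squeeze()
    else:
        # Vanilla DQN: the target network both selects and evaluates (max over its own Q-values).
        target_q = rewards + (1 - terminations) * discount_factor_g * target_dqn(new_states).max(dim=1)[0]
```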
60 |
61 |
62 |
63 |
64 | ## 11. Dueling DQN Explained and Implemented
65 | Dueling Architecture or Dueling DQN is another enhancement to the DQN algorithm. The main objective of Dueling DQN is to improve training efficiency by splitting the Q-values into two components: Value and Advantages. I explain the concept behind Dueling DQN using Flappy Bird as an example and also implement the Dueling Architecture changes in the DQN module.
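The split lives in dqn.py's forward(); this excerpt shows how the two streams are recombined into Q-values when `enable_dueling_dqn` is True (`x` is the output of the shared hidden layer):

```python
v = F.relu(self.fc_value(x))
V = self.value(v)                               # Value: how good the state is overall

a = F.relu(self.fc_advantages(x))
A = self.advantages(a)                          # Advantages: how much better each action is than the others

# Q(s, a) = V(s) + A(s, a) - mean over actions of A(s, ·); subtracting the mean keeps the split stable.
Q = V + A - torch.mean(A, dim=1, keepdim=True)
```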
66 |
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/agent.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | import numpy as np
3 |
4 | import matplotlib
5 | import matplotlib.pyplot as plt
6 |
7 | import random
8 | import torch
9 | from torch import nn
10 | import yaml
11 |
12 | from experience_replay import ReplayMemory
13 | from dqn import DQN
14 |
15 | from datetime import datetime, timedelta
16 | import argparse
17 | import itertools
18 |
19 | import flappy_bird_gymnasium
20 | import os
21 |
22 | # For printing date and time
23 | DATE_FORMAT = "%m-%d %H:%M:%S"
24 |
25 | # Directory for saving run info
26 | RUNS_DIR = "runs"
27 | os.makedirs(RUNS_DIR, exist_ok=True)
28 |
29 | # 'Agg': used to generate plots as images and save them to a file instead of rendering to screen
30 | matplotlib.use('Agg')
31 |
32 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
33 | device = 'cpu' # force CPU; the GPU is not always faster than the CPU due to the overhead of moving data to the GPU
34 |
35 | # Deep Q-Learning Agent
36 | class Agent():
37 |
38 | def __init__(self, hyperparameter_set):
39 | with open('hyperparameters.yml', 'r') as file:
40 | all_hyperparameter_sets = yaml.safe_load(file)
41 | hyperparameters = all_hyperparameter_sets[hyperparameter_set]
42 | # print(hyperparameters)
43 |
44 | self.hyperparameter_set = hyperparameter_set
45 |
46 | # Hyperparameters (adjustable)
47 | self.env_id = hyperparameters['env_id']
48 | self.learning_rate_a = hyperparameters['learning_rate_a'] # learning rate (alpha)
49 | self.discount_factor_g = hyperparameters['discount_factor_g'] # discount rate (gamma)
50 | self.network_sync_rate = hyperparameters['network_sync_rate'] # number of steps the agent takes before syncing the policy and target network
51 | self.replay_memory_size = hyperparameters['replay_memory_size'] # size of replay memory
52 | self.mini_batch_size = hyperparameters['mini_batch_size'] # size of the training data set sampled from the replay memory
53 | self.epsilon_init = hyperparameters['epsilon_init'] # 1 = 100% random actions
54 | self.epsilon_decay = hyperparameters['epsilon_decay'] # epsilon decay rate
55 | self.epsilon_min = hyperparameters['epsilon_min'] # minimum epsilon value
56 | self.stop_on_reward = hyperparameters['stop_on_reward'] # stop training after reaching this number of rewards
57 | self.fc1_nodes = hyperparameters['fc1_nodes']
58 | self.env_make_params = hyperparameters.get('env_make_params',{}) # Get optional environment-specific parameters, default to empty dict
59 | self.enable_double_dqn = hyperparameters['enable_double_dqn'] # double dqn on/off flag
60 | self.enable_dueling_dqn = hyperparameters['enable_dueling_dqn'] # dueling dqn on/off flag
61 |
62 | # Neural Network
63 | self.loss_fn = nn.MSELoss() # NN Loss function. MSE=Mean Squared Error can be swapped to something else.
64 | self.optimizer = None # NN Optimizer. Initialize later.
65 |
66 | # Path to Run info
67 | self.LOG_FILE = os.path.join(RUNS_DIR, f'{self.hyperparameter_set}.log')
68 | self.MODEL_FILE = os.path.join(RUNS_DIR, f'{self.hyperparameter_set}.pt')
69 | self.GRAPH_FILE = os.path.join(RUNS_DIR, f'{self.hyperparameter_set}.png')
70 |
71 | def run(self, is_training=True, render=False):
72 | if is_training:
73 | start_time = datetime.now()
74 | last_graph_update_time = start_time
75 |
76 | log_message = f"{start_time.strftime(DATE_FORMAT)}: Training starting..."
77 | print(log_message)
78 | with open(self.LOG_FILE, 'w') as file:
79 | file.write(log_message + '\n')
80 |
81 | # Create instance of the environment.
82 | # Use "**self.env_make_params" to pass in environment-specific parameters from hyperparameters.yml.
83 | env = gym.make(self.env_id, render_mode='human' if render else None, **self.env_make_params)
84 |
85 | # Number of possible actions
86 | num_actions = env.action_space.n
87 |
88 | # Get observation space size
89 | num_states = env.observation_space.shape[0] # Expecting type: Box(low, high, (shape0,), float64)
90 |
91 | # List to keep track of rewards collected per episode.
92 | rewards_per_episode = []
93 |
94 | # Create policy and target network. Number of nodes in the hidden layer can be adjusted.
95 | policy_dqn = DQN(num_states, num_actions, self.fc1_nodes, self.enable_dueling_dqn).to(device)
96 |
97 | if is_training:
98 | # Initialize epsilon
99 | epsilon = self.epsilon_init
100 |
101 | # Initialize replay memory
102 | memory = ReplayMemory(self.replay_memory_size)
103 |
104 | # Create the target network and make it identical to the policy network
105 | target_dqn = DQN(num_states, num_actions, self.fc1_nodes, self.enable_dueling_dqn).to(device)
106 | target_dqn.load_state_dict(policy_dqn.state_dict())
107 |
108 | # Policy network optimizer. "Adam" optimizer can be swapped to something else.
109 | self.optimizer = torch.optim.Adam(policy_dqn.parameters(), lr=self.learning_rate_a)
110 |
111 | # List to keep track of epsilon decay
112 | epsilon_history = []
113 |
114 | # Track number of steps taken. Used for syncing policy => target network.
115 | step_count=0
116 |
117 | # Track best reward
118 | best_reward = -9999999
119 | else:
120 | # Load learned policy
121 | policy_dqn.load_state_dict(torch.load(self.MODEL_FILE))
122 |
123 | # switch model to evaluation mode
124 | policy_dqn.eval()
125 |
126 | # Train INDEFINITELY, manually stop the run when you are satisfied (or unsatisfied) with the results
127 | for episode in itertools.count():
128 |
129 | state, _ = env.reset() # Initialize environment. Reset returns (state,info).
130 | state = torch.tensor(state, dtype=torch.float, device=device) # Convert state to tensor directly on device
131 |
132 | terminated = False # True when agent reaches goal or fails
133 | episode_reward = 0.0 # Used to accumulate rewards per episode
134 |
135 | # Perform actions until episode terminates or reaches max rewards
136 | # (on some envs, it is possible for the agent to train to a point where it NEVER terminates, so stop on reward is necessary)
137 | while(not terminated and episode_reward < self.stop_on_reward):
138 |
139 | # Select action based on epsilon-greedy
140 | if is_training and random.random() < epsilon:
141 | # select random action
142 | action = env.action_space.sample()
143 | action = torch.tensor(action, dtype=torch.int64, device=device)
144 | else:
145 | # select best action
146 | with torch.no_grad():
147 | # state.unsqueeze(dim=0): PyTorch expects a batch dimension, so add one, i.e. tensor([1, 2, 3]) unsqueezes to tensor([[1, 2, 3]])
148 | # policy_dqn returns tensor([[1, 2, 3]]) (a batch of one row of Q-values), so squeeze it back to tensor([1, 2, 3]).
149 | # argmax finds the index of the largest element.
150 | action = policy_dqn(state.unsqueeze(dim=0)).squeeze().argmax()
151 |
152 | # Execute action. Truncated and info is not used.
153 | new_state,reward,terminated,truncated,info = env.step(action.item())
154 |
155 | # Accumulate rewards
156 | episode_reward += reward
157 |
158 | # Convert new state and reward to tensors on device
159 | new_state = torch.tensor(new_state, dtype=torch.float, device=device)
160 | reward = torch.tensor(reward, dtype=torch.float, device=device)
161 |
162 | if is_training:
163 | # Save experience into memory
164 | memory.append((state, action, new_state, reward, terminated))
165 |
166 | # Increment step counter
167 | step_count+=1
168 |
169 | # Move to the next state
170 | state = new_state
171 |
172 | # Keep track of the rewards collected per episode.
173 | rewards_per_episode.append(episode_reward)
174 |
175 | # Save model when new best reward is obtained.
176 | if is_training:
177 | if episode_reward > best_reward:
178 | log_message = f"{datetime.now().strftime(DATE_FORMAT)}: New best reward {episode_reward:0.1f} ({(episode_reward-best_reward)/best_reward*100:+.1f}%) at episode {episode}, saving model..."
179 | print(log_message)
180 | with open(self.LOG_FILE, 'a') as file:
181 | file.write(log_message + '\n')
182 |
183 | torch.save(policy_dqn.state_dict(), self.MODEL_FILE)
184 | best_reward = episode_reward
185 |
186 |
187 | # Update graph every x seconds
188 | current_time = datetime.now()
189 | if current_time - last_graph_update_time > timedelta(seconds=10):
190 | self.save_graph(rewards_per_episode, epsilon_history)
191 | last_graph_update_time = current_time
192 |
193 | # If enough experience has been collected
194 | if len(memory)>self.mini_batch_size:
195 | mini_batch = memory.sample(self.mini_batch_size)
196 | self.optimize(mini_batch, policy_dqn, target_dqn)
197 |
198 | # Decay epsilon
199 | epsilon = max(epsilon * self.epsilon_decay, self.epsilon_min)
200 | epsilon_history.append(epsilon)
201 |
202 | # Copy policy network to target network after a certain number of steps
203 | if step_count > self.network_sync_rate:
204 | target_dqn.load_state_dict(policy_dqn.state_dict())
205 | step_count=0
206 |
207 |
208 | def save_graph(self, rewards_per_episode, epsilon_history):
209 | # Save plots
210 | fig = plt.figure(1)
211 |
212 | # Plot average rewards (Y-axis) vs episodes (X-axis)
213 | mean_rewards = np.zeros(len(rewards_per_episode))
214 | for x in range(len(mean_rewards)):
215 | mean_rewards[x] = np.mean(rewards_per_episode[max(0, x-99):(x+1)])
216 | plt.subplot(121) # plot on a 1 row x 2 col grid, at cell 1
217 | # plt.xlabel('Episodes')
218 | plt.ylabel('Mean Rewards')
219 | plt.plot(mean_rewards)
220 |
221 | # Plot epsilon decay (Y-axis) vs episodes (X-axis)
222 | plt.subplot(122) # plot on a 1 row x 2 col grid, at cell 2
223 | # plt.xlabel('Time Steps')
224 | plt.ylabel('Epsilon Decay')
225 | plt.plot(epsilon_history)
226 |
227 | plt.subplots_adjust(wspace=1.0, hspace=1.0)
228 |
229 | # Save plots
230 | fig.savefig(self.GRAPH_FILE)
231 | plt.close(fig)
232 |
233 |
234 | # Optimize policy network
235 | def optimize(self, mini_batch, policy_dqn, target_dqn):
236 |
237 | # Transpose the list of experiences and separate each element
238 | states, actions, new_states, rewards, terminations = zip(*mini_batch)
239 |
240 | # Stack tensors to create batch tensors
241 | # tensor([[1,2,3]])
242 | states = torch.stack(states)
243 |
244 | actions = torch.stack(actions)
245 |
246 | new_states = torch.stack(new_states)
247 |
248 | rewards = torch.stack(rewards)
249 | terminations = torch.tensor(terminations).float().to(device)
250 |
251 | with torch.no_grad():
252 | if self.enable_double_dqn:
253 | best_actions_from_policy = policy_dqn(new_states).argmax(dim=1)
254 |
255 | target_q = rewards + (1-terminations) * self.discount_factor_g * \
256 | target_dqn(new_states).gather(dim=1, index=best_actions_from_policy.unsqueeze(dim=1)).squeeze()
257 | else:
258 | # Calculate target Q values (expected returns)
259 | target_q = rewards + (1-terminations) * self.discount_factor_g * target_dqn(new_states).max(dim=1)[0]
260 | '''
261 | target_dqn(new_states) ==> tensor([[1,2,3],[4,5,6]])
262 | .max(dim=1) ==> torch.return_types.max(values=tensor([3,6]), indices=tensor([2,2]))
263 | [0] ==> tensor([3,6])
264 | '''
265 |
266 | # Calculate Q values from current policy
267 | current_q = policy_dqn(states).gather(dim=1, index=actions.unsqueeze(dim=1)).squeeze()
268 | '''
269 | policy_dqn(states) ==> tensor([[1,2,3],[4,5,6]])
270 | actions.unsqueeze(dim=1), e.g. actions=tensor([1,0]) ==> tensor([[1],[0]])
271 | .gather(dim=1, index=actions.unsqueeze(dim=1)) ==> tensor([[2],[4]]) (the Q-value of the action taken in each row)
272 | .squeeze() ==> tensor([2,4])
273 | '''
274 |
275 | # Compute loss
276 | loss = self.loss_fn(current_q, target_q)
277 |
278 | # Optimize the model (backpropagation)
279 | self.optimizer.zero_grad() # Clear gradients
280 | loss.backward() # Compute gradients
281 | self.optimizer.step() # Update network parameters i.e. weights and biases
282 |
283 | if __name__ == '__main__':
284 | # Parse command line inputs
285 | parser = argparse.ArgumentParser(description='Train or test model.')
286 | parser.add_argument('hyperparameters', help='')
287 | parser.add_argument('--train', help='Training mode', action='store_true')
288 | args = parser.parse_args()
289 |
290 | dql = Agent(hyperparameter_set=args.hyperparameters)
291 |
292 | if args.train:
293 | dql.run(is_training=True)
294 | else:
295 | dql.run(is_training=False, render=True)
--------------------------------------------------------------------------------
/dqn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class DQN(nn.Module):
6 |
7 | def __init__(self, state_dim, action_dim, hidden_dim=256, enable_dueling_dqn=True):
8 | super(DQN, self).__init__()
9 |
10 | self.enable_dueling_dqn=enable_dueling_dqn
11 |
12 | self.fc1 = nn.Linear(state_dim, hidden_dim)
13 |
14 | if self.enable_dueling_dqn:
15 | # Value stream
16 | self.fc_value = nn.Linear(hidden_dim, 256)
17 | self.value = nn.Linear(256, 1)
18 |
19 | # Advantages stream
20 | self.fc_advantages = nn.Linear(hidden_dim, 256)
21 | self.advantages = nn.Linear(256, action_dim)
22 |
23 | else:
24 | self.output = nn.Linear(hidden_dim, action_dim)
25 |
26 | def forward(self, x):
27 | x = F.relu(self.fc1(x))
28 |
29 | if self.enable_dueling_dqn:
30 | # Value calc
31 | v = F.relu(self.fc_value(x))
32 | V = self.value(v)
33 |
34 | # Advantages calc
35 | a = F.relu(self.fc_advantages(x))
36 | A = self.advantages(a)
37 |
38 | # Calc Q
39 | Q = V + A - torch.mean(A, dim=1, keepdim=True)
40 |
41 | else:
42 | Q = self.output(x)
43 |
44 | return Q
45 |
46 |
47 | if __name__ == '__main__':
48 | state_dim = 12
49 | action_dim = 2
50 | net = DQN(state_dim, action_dim)
51 | state = torch.randn(10, state_dim)
52 | output = net(state)
53 | print(output)
54 |
55 |
--------------------------------------------------------------------------------
/experience_replay.py:
--------------------------------------------------------------------------------
1 | # Define memory for Experience Replay
2 | from collections import deque
3 | import random
4 | class ReplayMemory():
5 | def __init__(self, maxlen, seed=None):
6 | self.memory = deque([], maxlen=maxlen)
7 |
8 | # Optional seed for reproducibility
9 | if seed is not None:
10 | random.seed(seed)
11 |
12 | def append(self, transition):
13 | self.memory.append(transition)
14 |
15 | def sample(self, sample_size):
16 | return random.sample(self.memory, sample_size)
17 |
18 | def __len__(self):
19 | return len(self.memory)
--------------------------------------------------------------------------------
/hyperparameters.yml:
--------------------------------------------------------------------------------
1 | cartpole1:
2 | env_id: CartPole-v1
3 | replay_memory_size: 100000
4 | mini_batch_size: 64
5 | epsilon_init: 1
6 | epsilon_decay: 0.9995
7 | epsilon_min: 0.01
8 | network_sync_rate: 100
9 | learning_rate_a: 0.001
10 | discount_factor_g: 0.99
11 | stop_on_reward: 100000
12 | fc1_nodes: 128
13 | enable_double_dqn: False
14 | enable_dueling_dqn: True
15 | flappybird1:
16 | env_id: FlappyBird-v0
17 | replay_memory_size: 100000
18 | mini_batch_size: 32
19 | epsilon_init: 1
20 | epsilon_decay: 0.99_99_5
21 | epsilon_min: 0.05
22 | network_sync_rate: 10
23 | learning_rate_a: 0.0001
24 | discount_factor_g: 0.99
25 | stop_on_reward: 100000
26 | fc1_nodes: 512
27 | env_make_params:
28 | use_lidar: False
29 | enable_double_dqn: True
30 | enable_dueling_dqn: True
--------------------------------------------------------------------------------
/runs/cartpole1.log:
--------------------------------------------------------------------------------
1 | 11-14 22:10:33: Training starting...
2 | 11-14 22:10:35: New best reward 17.0 (-100.0%) at episode 0, saving model...
3 | 11-14 22:10:35: New best reward 22.0 (+29.4%) at episode 1, saving model...
4 | 11-14 22:10:35: New best reward 26.0 (+18.2%) at episode 5, saving model...
5 | 11-14 22:10:35: New best reward 28.0 (+7.7%) at episode 8, saving model...
6 | 11-14 22:10:35: New best reward 35.0 (+25.0%) at episode 10, saving model...
7 | 11-14 22:10:35: New best reward 45.0 (+28.6%) at episode 22, saving model...
8 | 11-14 22:10:35: New best reward 46.0 (+2.2%) at episode 49, saving model...
9 | 11-14 22:10:35: New best reward 60.0 (+30.4%) at episode 62, saving model...
10 | 11-14 22:10:36: New best reward 63.0 (+5.0%) at episode 217, saving model...
11 | 11-14 22:10:36: New best reward 83.0 (+31.7%) at episode 232, saving model...
12 | 11-14 22:10:38: New best reward 117.0 (+41.0%) at episode 663, saving model...
13 | 11-14 22:10:40: New best reward 124.0 (+6.0%) at episode 959, saving model...
14 | 11-14 22:10:45: New best reward 129.0 (+4.0%) at episode 1484, saving model...
15 | 11-14 22:10:46: New best reward 149.0 (+15.5%) at episode 1550, saving model...
16 | 11-14 22:10:48: New best reward 158.0 (+6.0%) at episode 1745, saving model...
17 | 11-14 22:10:49: New best reward 194.0 (+22.8%) at episode 1810, saving model...
18 | 11-14 22:10:49: New best reward 254.0 (+30.9%) at episode 1819, saving model...
19 | 11-14 22:10:54: New best reward 260.0 (+2.4%) at episode 2146, saving model...
20 | 11-14 22:10:54: New best reward 263.0 (+1.2%) at episode 2172, saving model...
21 | 11-14 22:10:57: New best reward 402.0 (+52.9%) at episode 2317, saving model...
22 | 11-14 22:11:00: New best reward 408.0 (+1.5%) at episode 2440, saving model...
23 | 11-14 22:11:04: New best reward 700.0 (+71.6%) at episode 2558, saving model...
24 | 11-14 22:11:19: New best reward 761.0 (+8.7%) at episode 2947, saving model...
25 | 11-14 22:11:20: New best reward 1168.0 (+53.5%) at episode 2962, saving model...
26 | 11-14 22:11:35: New best reward 1398.0 (+19.7%) at episode 3217, saving model...
27 | 11-14 22:11:38: New best reward 1581.0 (+13.1%) at episode 3271, saving model...
28 | 11-14 22:11:41: New best reward 1747.0 (+10.5%) at episode 3306, saving model...
29 | 11-14 22:13:09: New best reward 5349.0 (+206.2%) at episode 4494, saving model...
30 | 11-14 22:13:23: New best reward 100000.0 (+1769.5%) at episode 4512, saving model...
31 |
--------------------------------------------------------------------------------
/runs/cartpole1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnnycode8/dqn_pytorch/24ccb030ed9a19cd4b3ff7f6d815b7bf6582661f/runs/cartpole1.png
--------------------------------------------------------------------------------
/runs/cartpole1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnnycode8/dqn_pytorch/24ccb030ed9a19cd4b3ff7f6d815b7bf6582661f/runs/cartpole1.pt
--------------------------------------------------------------------------------
/runs/flappybird1.log:
--------------------------------------------------------------------------------
1 | 06-15 21:16:23: Training starting...
2 | 06-15 21:16:34: New best reward -6.9 (-100.0%) at episode 0, saving model...
3 | 06-15 21:16:34: New best reward -5.7 (-17.4%) at episode 1, saving model...
4 | 06-15 21:16:35: New best reward -3.9 (-31.6%) at episode 30, saving model...
5 | 06-15 21:16:35: New best reward -2.1 (-46.2%) at episode 32, saving model...
6 | 06-15 21:16:43: New best reward -1.5 (-28.6%) at episode 958, saving model...
7 | 06-15 21:16:48: New best reward 3.3 (-320.0%) at episode 1523, saving model...
8 | 06-15 21:16:52: New best reward 3.9 (+18.2%) at episode 1935, saving model...
9 | 06-15 21:17:57: New best reward 4.0 (+2.6%) at episode 9189, saving model...
10 | 06-15 21:18:15: New best reward 4.2 (+5.0%) at episode 11036, saving model...
11 | 06-15 21:18:27: New best reward 4.8 (+14.3%) at episode 12231, saving model...
12 | 06-15 21:18:30: New best reward 4.9 (+2.1%) at episode 12484, saving model...
13 | 06-15 21:18:57: New best reward 6.5 (+32.7%) at episode 15071, saving model...
14 | 06-15 21:19:06: New best reward 8.4 (+29.2%) at episode 15937, saving model...
15 | 06-15 21:21:06: New best reward 9.4 (+11.9%) at episode 26098, saving model...
16 | 06-15 22:36:35: New best reward 10.7 (+13.8%) at episode 240814, saving model...
17 | 06-15 22:36:36: New best reward 11.0 (+2.8%) at episode 240882, saving model...
18 | 06-15 22:36:37: New best reward 12.9 (+17.3%) at episode 240931, saving model...
19 | 06-15 22:39:38: New best reward 15.0 (+16.3%) at episode 247469, saving model...
20 | 06-15 22:41:46: New best reward 17.9 (+19.3%) at episode 251675, saving model...
21 | 06-15 22:42:37: New best reward 18.8 (+5.0%) at episode 253192, saving model...
22 | 06-15 22:42:48: New best reward 22.9 (+21.8%) at episode 253493, saving model...
23 | 06-15 22:44:35: New best reward 23.5 (+2.6%) at episode 256924, saving model...
24 | 06-15 22:44:53: New best reward 26.9 (+14.5%) at episode 257482, saving model...
25 | 06-15 22:44:58: New best reward 31.9 (+18.6%) at episode 257792, saving model...
26 | 06-15 22:46:36: New best reward 34.6 (+8.5%) at episode 260551, saving model...
27 | 06-15 22:46:56: New best reward 36.4 (+5.2%) at episode 261078, saving model...
28 | 06-15 22:53:26: New best reward 40.9 (+12.4%) at episode 271818, saving model...
29 | 06-15 22:57:06: New best reward 41.0 (+0.2%) at episode 278161, saving model...
30 | 06-15 23:04:49: New best reward 48.7 (+18.8%) at episode 290927, saving model...
31 | 06-15 23:14:08: New best reward 60.7 (+24.6%) at episode 304692, saving model...
32 | 06-16 01:08:16: New best reward 64.4 (+6.1%) at episode 458928, saving model...
33 | 06-16 02:21:54: New best reward 68.9 (+7.0%) at episode 522084, saving model...
34 | 06-16 06:01:32: New best reward 80.9 (+17.4%) at episode 656794, saving model...
35 | 06-16 08:31:42: New best reward 82.9 (+2.5%) at episode 733104, saving model...
36 | 06-16 13:25:42: New best reward 85.7 (+3.4%) at episode 845852, saving model...
37 | 06-16 20:21:55: New best reward 106.4 (+24.2%) at episode 941863, saving model...
38 |
--------------------------------------------------------------------------------
/runs/flappybird1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnnycode8/dqn_pytorch/24ccb030ed9a19cd4b3ff7f6d815b7bf6582661f/runs/flappybird1.png
--------------------------------------------------------------------------------
/runs/flappybird1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnnycode8/dqn_pytorch/24ccb030ed9a19cd4b3ff7f6d815b7bf6582661f/runs/flappybird1.pt
--------------------------------------------------------------------------------