├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENCE
├── README.md
├── data
│   ├── readme_gifs
│   │   └── BreakoutNoFrameskip-v4_1.gif
│   ├── readme_pics
│   │   └── dqn.jpg
│   └── readme_visualizations
│       ├── breakout.jpg
│       ├── fps_metric.PNG
│       ├── grads.PNG
│       ├── huber_loss.PNG
│       ├── pong.jpg
│       ├── rewards_per_episode.PNG
│       ├── state_all_frames.PNG
│       ├── state_initial.PNG
│       └── steps_per_episode.PNG
├── environment.yml
├── evaluate_DQN_script.py
├── models
│   └── definitions
│       └── DQN.py
├── playground.py
├── train_DQN_script.py
└── utils
    ├── constants.py
    ├── replay_buffer.py
    ├── utils.py
    └── video_utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | patreon: theaiepiphany
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # PyCharm IDE
2 | .idea
3 | __pycache__
4 |
5 | # Jupyter notebook checkpoints
6 | .ipynb_checkpoints
7 |
8 | # Models checkpoints and binaries
9 | models/checkpoints
10 | models/binaries
11 |
12 | # Data directory
13 | data/
14 |
15 | # Tensorboard dump dir
16 | runs/
17 | runs_baseline/
18 |
19 | # OpenAI gym Monitor's video dump location
20 | gym_monitor/
21 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Aleksa Gordić
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Reinforcement Learning (PyTorch) :robot: + :cake: = :heart:
2 |
3 | This repo will contain PyTorch implementations of various fundamental RL algorithms.
4 | It's aimed at making it **easy** to start playing and learning about RL.
5 |
6 | The problem I came across while investigating other DQN projects is that they either:
7 | * Don't have any evidence that they've actually achieved the published results
8 | * Don't have a "smart" replay buffer (i.e. they allocate (1M, 4, 84, 84) ~28 GB! instead of (1M, 84, 84) ~7 GB - see the quick arithmetic below)
9 | * Lack visualizations and debugging utils
10 |
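The replay buffer arithmetic from the second bullet, spelled out (frames are stored as uint8, i.e. 1 byte per pixel):

```python
# Memory needed for a 1M-frame replay buffer of 84x84 grayscale Atari frames
naive = 1_000_000 * 4 * 84 * 84 / 1e9  # storing full (1M, 4, 84, 84) states -> ~28.2 GB
smart = 1_000_000 * 1 * 84 * 84 / 1e9  # storing (1M, 84, 84) frames and stacking them lazily -> ~7.1 GB
print(f'naive: {naive:.1f} GB vs smart: {smart:.1f} GB')
```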
11 | This repo will aim to solve these problems.
12 |
13 | ## Table of Contents
14 | * [RL agents](#rl-agents)
15 | * [DQN](#dqn)
16 | * [DQN current results](#dqn-current-results)
17 | * [Setup](#setup)
18 | * [Usage](#usage)
19 | * [Training DQN](#training-dqn)
20 | * [Visualization and debugging tools](#visualization-and-debugging-tools)
21 | * [Hardware requirements](#hardware-requirements)
22 | * [Future todos](#future-todos)
23 | * [Learning material](#learning-material)
24 |
25 | ## RL agents
26 |
27 | ## DQN
28 |
29 | This was the project that started the revolution in the RL world - deep Q-network (:link: [Mnih et al.](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)),
30 | aka "Human-level control through deep RL".
31 |
32 | The DQN model learned to play **29 Atari games** (out of the 49 it was tested on) at a **super-human**/comparable-to-human level.
33 | Here is the schematic of its CNN architecture:
34 |
35 |
36 |
37 |
38 |
39 | The fascinating part is that it learned only from "high-dimensional" (84x84) images and (usually sparse) rewards.
40 | The same architecture was used for all of the 49 games - although the model has to be retrained, from scratch, every single time.
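If you just want the gist, here is a minimal PyTorch sketch of that architecture (the actual, more flexible implementation lives in `models/definitions/DQN.py`; `num_actions` depends on the game, 4 is what Breakout uses):

```python
import torch
from torch import nn

def nature_dqn(num_actions=4):
    # 4 stacked 84x84 grayscale frames in, one Q-value per action out
    return nn.Sequential(
        nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        nn.Flatten(),                     # (B, 64, 7, 7) -> (B, 3136)
        nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
        nn.Linear(512, num_actions),      # Q-values are unbounded, so no activation here
    )

q_values = nature_dqn(num_actions=4)(torch.rand(1, 4, 84, 84))  # shape = (1, 4)
```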
41 |
42 | ## DQN current results
43 |
44 | Since it takes [lots of compute and time](#hardware-requirements) to train all of the 49 models, I'll consider the DQN project completed once
45 | I succeed in achieving the published results on:
46 | * Breakout
47 | * Pong
48 |
49 | ---
50 |
51 | Having said that, the experiments are still in progress, so feel free to **contribute**!
52 | * For some reason the models aren't learning very well, so if you find a bug open up a PR! :heart:
53 | * I'm also experiencing slowdowns - so any PRs that would improve/explain the performance are welcome!
54 | * If you decide to train the DQN using this repo on some other Atari game I'll gladly check in your model!
55 |
56 | **Important note: please follow the coding guidelines of this repo before you submit a PR so that we can minimize
57 | the back-and-forth. I'm a decently busy guy as I assume you are.**
58 |
59 | ### Current results - Breakout
60 |
61 |
62 |
63 |
64 |
65 | As you can see, the model did learn something, although it's far from being really good.
66 |
67 | ### Current results - Pong
68 |
69 | todo
70 |
71 | ## Setup
72 |
73 | Let's get this thing running! Follow the next steps:
74 |
75 | 1. `git clone https://github.com/gordicaleksa/pytorch-learn-reinforcement-learning`
76 | 2. Open the Anaconda console and navigate into the project directory: `cd path_to_repo`
77 | 3. Run `conda env create` from the project directory (this will create a brand new conda environment).
78 | 4. Run `activate pytorch-rl-env` (for running scripts from your console or set up the interpreter in your IDE)
79 |
80 | If you're on Windows you'll additionally need to run this:
81 | `pip install -f https://github.com/Kojoley/atari-py/releases atari_py` to install gym's Atari dependencies.
82 |
83 | Otherwise this should do it: `pip install 'gym[atari]'`. If it's not working, check out [this](https://stackoverflow.com/questions/49947555/openai-gym-trouble-installing-atari-dependency-mac-os-x) and [this](https://github.com/openai/gym/issues/1170).
84 |
85 | That's it! It should work out of the box - the environment.yml file deals with all the dependencies.
86 |
87 | -----
88 |
89 | The PyTorch pip package comes bundled with its own version of CUDA/cuDNN,
90 | but it is highly recommended that you install a system-wide CUDA beforehand, mostly because of the GPU drivers.
91 | I also recommend using Miniconda installer as a way to get conda on your system.
92 | Follow through points 1 and 2 of [this setup](https://github.com/Petlja/PSIML/blob/master/docs/MachineSetup.md)
93 | and use the most up-to-date versions of Miniconda and CUDA/cuDNN for your system.
94 |
95 | ## Usage
96 |
97 | #### Option 1: Jupyter Notebook
98 |
99 | Coming soon.
100 |
101 | #### Option 2: Use your IDE of choice
102 |
103 | You just need to link the Python environment you created in the [setup](#setup) section.
104 |
105 | ## Training DQN
106 |
107 | To run with default settings just run `python train_DQN_script.py`.
108 |
109 | Settings you'll want to experiment with:
110 | * `--seed` - it may just so happen that I've chosen a bad one (RL is very sensitive)
111 | * `--learning_rate` - DQN originally used RMSProp, I saw that Adam with 1e-4 worked for stable baselines 3
112 | * `--grad_clipping_value` - there was [a lot of noise](#visualization-and-debugging-tools) in the gradients so I used this to control it
113 | * Try using RMSProp (I haven't yet). Adam was an improvement over RMSProp so I doubt it's causing the issues
114 |
115 | Less important settings for getting DQN to work:
116 | * `--env_id` - depending on which game you want to train on (I'd focus on the easiest one for now - Breakout)
117 | * `--replay_buffer_size` - hopefully you can train DQN with 1M, as in the original paper, if not make it smaller
118 | * `--dont_crash_if_no_mem` - add this flag if you want to run with a 1M replay buffer even if you don't have enough RAM (see the snippet below)
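That last flag maps onto the `crash_if_no_mem` check in `utils/replay_buffer.py` - a quick way to see it in action (note that this allocates the full buffer up front, ~7 GB for 1M frames):

```python
from utils.replay_buffer import ReplayBuffer

# The buffer pre-allocates all of its numpy arrays and checks available RAM via psutil;
# with crash_if_no_mem=False it only prints a warning instead of raising an exception
replay_buffer = ReplayBuffer(max_buffer_size=1_000_000, crash_if_no_mem=False)
print(replay_buffer.get_current_size())  # 0 - nothing stored yet
```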
119 |
120 | The training script will:
121 | * Dump checkpoint *.pth models into `models/checkpoints/` (see the snippet below for how to inspect them)
122 | * Dump the best (highest reward) *.pth model into `models/binaries/` <- TODO
123 | * Periodically write some training metadata to the console
124 | * Save tensorboard metrics into `runs/` - to use them check out [the visualization section](#visualization-and-debugging-tools)
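The checkpoints are plain PyTorch dictionaries (see `utils.get_training_state` for the exact keys), so you can inspect them directly. A minimal sketch - the checkpoint filename below is just an example, use whatever the script actually dumped for you:

```python
import os
import torch

from utils.constants import CHECKPOINTS_PATH
from utils.utils import print_model_metadata

ckpt_name = 'dqn_BreakoutNoFrameskip-v4_ckpt_steps_10000.pth'  # example name, adjust to your run
training_state = torch.load(os.path.join(CHECKPOINTS_PATH, ckpt_name))
print_model_metadata(training_state)  # prints everything except the state_dict (env_id, seed, commit hash, ...)
```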
125 |
126 | ## Visualization and debugging tools
127 |
128 | You can visualize the metrics during training by calling `tensorboard --logdir=runs` from your console
129 | and pasting the `http://localhost:6006/` URL into your browser.
130 |
131 | I'm currently visualizing the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) (and you can see there is something weird going on):
132 |
133 |
134 |
135 |
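For reference, the Huber loss (what `nn.SmoothL1Loss` computes in the training script) is quadratic for small TD errors and linear for large ones, which makes it less sensitive to outliers than plain MSE. With the default threshold of 1 it looks like this:

```python
def huber(td_error, beta=1.0):
    # Quadratic inside [-beta, beta], linear outside - matches nn.SmoothL1Loss (default beta=1)
    x = abs(td_error)
    return 0.5 * x ** 2 / beta if x < beta else x - 0.5 * beta
```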
136 |
137 | Rewards and steps taken per episode (there is a fair bit of correlation between these 2):
138 |
139 |
140 |
141 |
142 |
143 |
144 | And gradient L2 norms of weights and biases of every CNN/FC layer as well as the complete grad vector:
145 |
146 |
147 |
148 |
149 |
150 | As well as epsilon (from the epsilon-greedy algorithm) but that plot is not that informative so I'll omit it here.
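(For context, epsilon is annealed linearly via `utils.LinearSchedule`: with the default training arguments it decays from 1.0 to 0.1 over the first 1M epsilon-greedy calls and stays constant afterwards.)

```python
from utils.utils import LinearSchedule

schedule = LinearSchedule(schedule_start_value=1.0, schedule_end_value=0.1, schedule_duration=1_000_000)
print(schedule(0), schedule(500_000), schedule(2_000_000))  # 1.0 0.55 0.1
```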
151 |
152 | As you can see the plots are super **noisy**! That much I expected, but the progress just stagnates from a certain point onwards,
153 | and that's what I'm trying to debug at the moment.
154 |
155 | ---
156 |
157 | To enter the debug mode add the `--debug` flag to your console or IDE's list of script arguments.
158 |
159 | It'll visualize the current state that's being fed into the RL agent.
160 | Sometimes the state will have some black frames prepended since there aren't enough frames experienced in the current episode:
161 |
162 |
163 |
164 |
165 |
166 | But mostly all of the 4 frames will be in there:
167 |
168 |
169 |
170 |
171 |
172 | And it will start rendering the game frames (`Pong` and `Breakout` shown here from left to right):
173 |
174 |
175 |
176 |
177 |
178 |
179 | ## Hardware requirements
180 |
181 | You'll need some decent hardware to train the DQN in reasonable time so that you can iterate fast:
182 | 1) **16+ GB of RAM** (the replay buffer alone takes ~7 GB of RAM).
183 | 2) The faster your GPU is - the better! :sweat_smile: Having said that, VRAM is not the bottleneck - you'll need **2+ GB of VRAM**.
184 |
185 | With 16 GB of RAM and an RTX 2080 it takes ~5 days to train DQN on my machine - I'm **experiencing some slowdowns** which I
186 | haven't debugged yet. Here is the FPS (frames-per-second) metric I'm logging:
187 |
188 |
189 |
190 |
191 |
192 | The shorter, green one is the current experiment I'm running; the red one took over 5 days to train.
193 |
194 | ## Future todos
195 |
196 | 1) Debug DQN and achieve the published results
197 | 2) Add Vanilla PG
198 | 3) Add PPO
199 |
200 | ## Learning material
201 |
202 | Here are some videos I made on RL which may help you to better understand how DQN and other RL algorithms work:
203 |
204 |
205 |
207 |
208 |
209 | And some other ones:
210 | * [DeepMind: AlphaGo](https://www.youtube.com/watch?v=Z1BELqFQZVM)
211 | * [DeepMind: AlphaGo Zero and AlphaZero](https://www.youtube.com/watch?v=0slFo1rV0EM)
212 | * [OpenAI: Solving Rubik's Cube with a Robot Hand](https://www.youtube.com/watch?v=eTa-k1pgvnU)
213 | * [DeepMind: MuZero](https://www.youtube.com/watch?v=mH7f7N7s79s)
214 |
215 | And in this one I tried to film the process while the project was not nearly as polished as it is now:
216 | * [DQN project update](https://www.youtube.com/watch?v=DrOp_MQGn9o&ab_channel=TheAIEpiphany)
217 |
218 | I'll soon create a blog on how to get started with RL - so stay tuned for that!
219 |
220 | ## Acknowledgements
221 |
222 | I found these resources useful while developing this project, sorted (approximately) by usefulness:
223 |
224 | * [Stable Baselines 3 DQN](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/dqn/dqn.py)
225 | * [PyTorch reimplementation of Berkeley's DQN](https://github.com/transedward/pytorch-dqn) and [Berkeley's DQN](https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3)
226 | * [pytorch-dqn](https://github.com/jacobaustin123/pytorch-dqn/blob/master/dqn.py)
227 | * [RL adventures DQN](https://github.com/higgsfield/RL-Adventure/blob/master/1.dqn.ipynb) and [minimal DQN](https://github.com/econti/minimal_dqn/blob/master/main.py)
228 | * [PyTorch tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)
229 |
230 | ## Citation
231 |
232 | If you find this code useful, please cite the following:
233 |
234 | ```
235 | @misc{Gordić2021PyTorchLearnReinforcementLearning,
236 | author = {Gordić, Aleksa},
237 | title = {pytorch-learn-reinforcement-learning},
238 | year = {2021},
239 | publisher = {GitHub},
240 | journal = {GitHub repository},
241 | howpublished = {\url{https://github.com/gordicaleksa/pytorch-learn-reinforcement-learning}},
242 | }
243 | ```
244 |
245 | ## Licence
246 |
247 | [MIT](https://github.com/gordicaleksa/pytorch-learn-reinforcement-learning/blob/master/LICENCE)
--------------------------------------------------------------------------------
/data/readme_gifs/BreakoutNoFrameskip-v4_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_gifs/BreakoutNoFrameskip-v4_1.gif
--------------------------------------------------------------------------------
/data/readme_pics/dqn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_pics/dqn.jpg
--------------------------------------------------------------------------------
/data/readme_visualizations/breakout.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/breakout.jpg
--------------------------------------------------------------------------------
/data/readme_visualizations/fps_metric.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/fps_metric.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/grads.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/grads.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/huber_loss.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/huber_loss.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/pong.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/pong.jpg
--------------------------------------------------------------------------------
/data/readme_visualizations/rewards_per_episode.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/rewards_per_episode.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/state_all_frames.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/state_all_frames.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/state_initial.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/state_initial.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/steps_per_episode.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/steps_per_episode.PNG
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: pytorch-rl-env
2 | channels:
3 | - defaults
4 | - pytorch
5 | dependencies:
6 | - python==3.8.5
7 | - pip==21.0.1
8 | - pytorch==1.8.1
9 | - pip:
10 | - matplotlib==3.3.3
11 | - GitPython==3.1.2
12 | - psutil==5.8.0
13 | - stable-baselines3==1.0
14 | - opencv-python==4.5.1.48
15 | - imageio==2.9.0
16 | - jupyter==1.0.0
17 | - numpy==1.19.2
18 | - tensorboard==2.2.2
19 | - gym==0.17.3
20 |
--------------------------------------------------------------------------------
/evaluate_DQN_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | import torch
6 | import matplotlib.pyplot as plt
7 | import cv2 as cv
8 | import numpy as np
9 |
10 |
11 | import utils.utils as utils
12 | from models.definitions.DQN import DQN
13 | from utils.constants import *
14 | from utils.replay_buffer import ReplayBuffer
15 | from utils.video_utils import create_gif
16 |
17 |
18 | if __name__ == '__main__':
19 | # Step 0: Modify these as needed
20 | buffer_size = 100000
21 | epsilon_eval = 0.05
22 | env_id = 'BreakoutNoFrameskip-v4'
23 | model_name = 'dqn_BreakoutNoFrameskip-v4_ckpt_steps_6810000.pth'
24 | should_record_video = True
25 |
26 | game_frames_dump_dir = os.path.join(DATA_DIR_PATH, 'dqn_eval_dump_dir')
27 | if os.path.exists(game_frames_dump_dir):
28 | shutil.rmtree(game_frames_dump_dir)
29 | os.makedirs(game_frames_dump_dir, exist_ok=True)
30 |
31 | # Step 1: Prepare environment, replay buffer and schedule
32 | env = utils.get_env_wrapper(env_id, record_video=should_record_video)
33 | replay_buffer = ReplayBuffer(buffer_size)
34 | const_schedule = utils.ConstSchedule(epsilon_eval) # lambda would also do - doing it like this for consistency
35 |
36 | # Step 2: Prepare the DQN model
37 | model_path = os.path.join(BINARIES_PATH, model_name)
38 | model_state = torch.load(model_path)
39 | assert model_state['env_id'] == env_id, \
40 | f"Model {model_name} was trained on {model_state['env_id']} but you're running it on {env_id}."
41 | utils.print_model_metadata(model_state)
42 |
43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44 | dqn = DQN(env, number_of_actions=env.action_space.n, epsilon_schedule=const_schedule).to(device)
45 | dqn.load_state_dict(model_state["state_dict"], strict=True)
46 | dqn.eval()
47 |
48 | # Step 3: Evaluate the agent on a single episode
49 | print(f'{"*"*10} Starting the game. {"*"*10}')
50 | last_frame = env.reset()
51 |
52 | score = 0
53 | cnt = 0
54 | while True:
55 | replay_buffer.store_frame(last_frame)
56 | current_state = replay_buffer.fetch_last_state() # fetch the state, shape = (4, 84, 84) for Atari
57 |
58 | with torch.no_grad():
59 | action = dqn.epsilon_greedy(current_state) # act in this state
60 |
61 | new_frame, reward, done, _ = env.step(action) # send the action to the environment
62 | score += reward
63 |
64 | env.render() # plot the current game frame
65 | screen = env.render(mode='rgb_array') # but also save it as an image
66 | processed_screen = cv.resize(screen[:, :, ::-1], (0, 0), fx=1.5, fy=1.5, interpolation=cv.INTER_NEAREST)
67 | cv.imwrite(os.path.join(game_frames_dump_dir, f'{str(cnt).zfill(5)}.jpg'), processed_screen) # cv expects BGR hence ::-1
68 | cnt += 1
69 |
70 | if done:
71 | print(f'Episode over, score = {score}.')
72 | break
73 |
74 | last_frame = new_frame # set the last frame to the newly acquired frame from the env
75 |
76 | create_gif(game_frames_dump_dir, os.path.join(DATA_DIR_PATH, f'{env_id}.gif'), fps=60)
77 |
--------------------------------------------------------------------------------
/models/definitions/DQN.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import numpy as np
4 |
5 |
6 | from utils.utils import get_env_wrapper
7 |
8 |
9 | class DQN(nn.Module):
10 | """
11 | I wrote the architecture a bit more generic, hence more lines of code,
12 | but it's more flexible if you want to experiment with the DQN architecture.
13 |
14 | """
15 | def __init__(self, env, num_in_channels=4, number_of_actions=3, epsilon_schedule=None):
16 | super().__init__()
17 | self.env = env
18 | self.epsilon_schedule = epsilon_schedule # defines the annealing strategy for epsilon in epsilon-greedy
19 | self.num_calls_to_epsilon_greedy = 0 # counts the number of calls to epsilon greedy function
20 |
21 | #
22 | # CNN params - from the Nature DQN paper - MODIFY this part if you want to experiment
23 | #
24 | num_of_filters_cnn = [num_in_channels, 32, 64, 64]
25 | kernel_sizes = [8, 4, 3]
26 | strides = [4, 2, 1]
27 |
28 | #
29 | # Build CNN part of DQN
30 | #
31 | cnn_modules = []
32 | for i in range(len(num_of_filters_cnn) - 1):
33 | cnn_modules.extend(
34 | self._cnn_block(num_of_filters_cnn[i], num_of_filters_cnn[i + 1], kernel_sizes[i], strides[i])
35 | )
36 |
37 | self.cnn_part = nn.Sequential(
38 | *cnn_modules,
39 | nn.Flatten() # flatten from (B, C, H, W) into (B, C*H*W), where B is batch size and C number of in channels
40 | )
41 |
42 | #
43 | # Build fully-connected part of DQN
44 | #
45 | with torch.no_grad(): # automatically figure out the shape for the given env observation
46 | # shape = (1, C', H, W), uint8, where C' is originally 1, i.e. grayscale frames
47 | dummy_input = torch.from_numpy(env.observation_space.sample()[np.newaxis])
48 |
49 | if dummy_input.shape[1] != num_in_channels:
50 | assert num_in_channels % dummy_input.shape[1] == 0
51 | # shape = (1, C, H, W), float
52 | dummy_input = dummy_input.repeat(1, int(num_in_channels / dummy_input.shape[1]), 1, 1).float()
53 |
54 | num_nodes_fc1 = self.cnn_part(dummy_input).shape[1] # cnn output shape = (B, C*H*W)
55 | print(f"DQN's first FC layer input dimension: {num_nodes_fc1}")
56 |
57 | #
58 | # FC params - MODIFY this part if you want to experiment
59 | #
60 | num_of_neurons_fc = [num_nodes_fc1, 512, number_of_actions]
61 |
62 | fc_modules = []
63 | for i in range(len(num_of_neurons_fc) - 1):
64 | last_layer = i == len(num_of_neurons_fc) - 2 # last layer shouldn't have activation (Q-value is unbounded)
65 | fc_modules.extend(self._fc_block(num_of_neurons_fc[i], num_of_neurons_fc[i + 1], use_relu=not last_layer))
66 |
67 | self.fc_part = nn.Sequential(
68 | *fc_modules
69 | )
70 |
71 | def forward(self, states):
72 | # shape: (B, C, H, W) -> (B, NA) - where NA is the Number of Actions
73 | return self.fc_part(self.cnn_part(states))
74 |
75 | def epsilon_greedy(self, state):
76 | assert self.epsilon_schedule is not None, f"No schedule provided, can't call epsilon_greedy function"
77 | assert state.shape[0] == 1, f'Agent can only act on a single state'
78 | self.num_calls_to_epsilon_greedy += 1
79 |
80 | # Epsilon-greedy exploration
81 | if np.random.rand() < self.epsilon_value():
82 | # With epsilon probability act random
83 | action = self.env.action_space.sample()
84 | else:
85 | # Otherwise act greedily - choosing an action that maximizes Q
86 | # Shape evolution: (1, C, H, W) -> (forward) (1, NA) -> (argmax) (1, 1) -> [0] scalar
87 | action = self.forward(state).argmax(dim=1)[0].cpu().numpy()
88 |
89 | return action
90 |
91 | def epsilon_value(self):
92 | return self.epsilon_schedule(self.num_calls_to_epsilon_greedy)
93 |
94 | #
95 | # Helper/"private" functions
96 | #
97 |
98 | # The original CNN didn't use any padding: https://github.com/deepmind/dqn/blob/master/dqn/convnet.lua
99 | # not that it matters - it would probably work either way feel free to experiment with the architecture.
100 | def _cnn_block(self, num_in_filters, num_out_filters, kernel_size, stride):
101 | layers = [nn.Conv2d(num_in_filters, num_out_filters, kernel_size=kernel_size, stride=stride), nn.ReLU()]
102 | return layers
103 |
104 | def _fc_block(self, num_in_neurons, num_out_neurons, use_relu=True):
105 | layers = [nn.Linear(num_in_neurons, num_out_neurons)]
106 | if use_relu:
107 | layers.append(nn.ReLU())
108 | return layers
109 |
110 |
111 | # Test DQN network
112 | if __name__ == '__main__':
113 | # NoFrameskip - receive every frame from the env whereas the version without NoFrameskip would give every 4th frame
114 | # v4 - actions we send to env are executed, whereas v0 would ignore the last action we sent with 0.25 probability
115 | env_id = "PongNoFrameskip-v4"
116 | env_wrapped = get_env_wrapper(env_id)
117 | dqn = DQN(env_wrapped) # testing only the __init__ function (mainly the automatic shape calculation mechanism)
118 |
119 |
120 |
--------------------------------------------------------------------------------
/playground.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | from stable_baselines3.common.env_util import make_atari_env
5 | from stable_baselines3.common.vec_env import VecFrameStack
6 | from stable_baselines3 import DQN
7 |
8 |
9 | def run_dqn_baseline():
10 | env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=1, seed=0)
11 | env = VecFrameStack(env, n_stack=4)
12 | tensorboard_log = os.path.join(os.path.dirname(__file__), 'runs_baseline')
13 | buffer_size = 100000
14 | num_training_steps = 1000000
15 |
16 | model = DQN(
17 | 'CnnPolicy',
18 | env,
19 | verbose=0,
20 | buffer_size=buffer_size,
21 | learning_starts=50000,
22 | optimize_memory_usage=False,
23 | tensorboard_log=tensorboard_log
24 | )
25 | model.learn(total_timesteps=num_training_steps)
26 |
27 | obs = env.reset()
28 | while True:
29 | action, _states = model.predict(obs)
30 | obs, rewards, dones, info = env.step(action)
31 | env.render()
32 |
33 |
34 | if __name__ == '__main__':
35 | run_dqn_baseline()
36 |
--------------------------------------------------------------------------------
/train_DQN_script.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of the original DQN Nature paper:
3 | https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
4 |
5 | Some of the complexity is captured via wrappers but the main components such as the DQN model itself,
6 | the training loop, and the memory-efficient replay buffer are implemented from scratch.
7 |
8 | Some modifications:
9 | * Using Adam instead of RMSProp
10 |
11 | """
12 |
13 | import os
14 | import argparse
15 | import time
16 | import copy
17 |
18 |
19 | import numpy as np
20 | import torch
21 | from torch import nn
22 | import matplotlib.pyplot as plt
23 | from torch.optim import Adam
24 | from torch.utils.tensorboard import SummaryWriter
25 |
26 |
27 | import utils.utils as utils
28 | from utils.replay_buffer import ReplayBuffer
29 | from utils.constants import *
30 | from models.definitions.DQN import DQN
31 |
32 |
33 | class ActorLearner:
34 |
35 | def __init__(self, config, env, replay_buffer, dqn, target_dqn, last_frame):
36 |
37 | self.start_time = time.time()
38 |
39 | self.config = config
40 | self.env = env
41 | self.last_frame = last_frame # always keeps the latest frame from the environment
42 | self.replay_buffer = replay_buffer
43 |
44 | # DQN Models
45 | self.dqn = dqn
46 | self.target_dqn = target_dqn
47 |
48 | # Logging/debugging-related
49 | self.debug = config['debug']
50 | self.log_freq = config['log_freq']
51 | self.episode_log_freq = config['episode_log_freq']
52 | self.grads_log_freq = config['grads_log_freq']
53 | self.checkpoint_freq = config['checkpoint_freq']
54 | self.tensorboard_writer = SummaryWriter()
55 | self.huber_loss = []
56 | self.best_episode_reward = -np.inf
57 | self.best_dqn_model = None # keeps a deep copy of the best DQN model so far (best = highest episode reward)
58 |
59 | # MSE/L2 between [-1,1] and L1 otherwise (as stated in the Nature paper) aka "Huber loss"
60 | self.loss = nn.SmoothL1Loss()
61 | self.optimizer = Adam(self.dqn.parameters(), lr=config['learning_rate'])
62 | self.grad_clip_value = config['grad_clipping_value']
63 |
64 | self.acting_learning_step_ratio = config['acting_learning_step_ratio']
65 | self.num_warmup_steps = config['num_warmup_steps']
66 | self.batch_size = config['batch_size']
67 | self.gamma = config['gamma'] # discount factor
68 |
69 | self.learner_cnt = 0
70 | self.target_dqn_update_interval = config['target_dqn_update_interval']
71 | # whether to perform a hard or a soft update of target DQN weights
72 | self.tau = config['tau']
73 |
74 | def collect_experience(self):
75 | # We're collecting more experience than we're doing weight updates (4x in the Nature paper)
76 | for _ in range(self.acting_learning_step_ratio):
77 | last_index = self.replay_buffer.store_frame(self.last_frame)
78 | state = self.replay_buffer.fetch_last_state() # state = 4 preprocessed last frames for Atari
79 |
80 | action = self.sample_action(state)
81 | new_frame, reward, done_flag, _ = self.env.step(action)
82 |
83 | self.replay_buffer.store_action_reward_done(last_index, action, reward, done_flag)
84 |
85 | if done_flag:
86 | new_frame = self.env.reset()
87 | self.maybe_log_episode()
88 |
89 | self.last_frame = new_frame
90 |
91 | if self.debug:
92 | self.visualize_state(state)
93 | self.env.render()
94 |
95 | self.maybe_log()
96 |
97 | def sample_action(self, state):
98 | if self.env.get_total_steps() < self.num_warmup_steps:
99 | action = self.env.action_space.sample() # initial warm up period - no learning, acting randomly
100 | else:
101 | with torch.no_grad():
102 | action = self.dqn.epsilon_greedy(state)
103 | return action
104 |
105 | def get_number_of_env_steps(self):
106 | return self.env.get_total_steps()
107 |
108 | def learn_from_experience(self):
109 | current_states, actions, rewards, next_states, done_flags = self.replay_buffer.fetch_random_states(self.batch_size)
110 |
111 | # Better than detaching: in addition to target dqn not being a part of the computational graph it also
112 | # saves time/memory because we're not storing activations during forward propagation needed for the backprop
113 | with torch.no_grad():
114 | # shape = (B, NA) -> (B, 1), where NA - number of actions
115 | # [0] because max returns (values, indices) tuples
116 | next_state_max_q_values = self.target_dqn(next_states).max(dim=1, keepdim=True)[0]
117 |
118 | # shape = (B, 1), TD targets. We need (1 - done) because when we're in a terminal state the next
119 | # state Q value should be 0 and we only use the reward information
120 | target_q_values = rewards + (1 - done_flags) * self.gamma * next_state_max_q_values
121 |
122 | # shape = (B, 1), pick those Q values that correspond to the actions we made in those states
123 | current_state_q_values = self.dqn(current_states).gather(dim=1, index=actions)
124 |
125 | loss = self.loss(target_q_values, current_state_q_values)
126 | self.huber_loss.append(loss.item())
127 |
128 | self.optimizer.zero_grad()
129 | loss.backward() # compute the gradients
130 |
131 | if self.grad_clip_value is not None: # potentially clip gradients for stability reasons
132 | nn.utils.clip_grad_norm_(self.dqn.parameters(), self.grad_clip_value)
133 |
134 | self.optimizer.step() # update step
135 | self.learner_cnt += 1
136 |
137 | # Periodically update the target DQN weights (coupled to the number of DQN weight updates and not # env steps)
138 | if self.learner_cnt % self.target_dqn_update_interval == 0:
139 | if self.tau == 1.:
140 | print('Update target DQN (hard update)')
141 | self.target_dqn.load_state_dict(self.dqn.state_dict())
142 | else: # soft update, the 2 branches can be merged together, leaving it like this for now
143 | raise Exception(f'Soft update is not yet implemented (hard update was used in the original paper)')
144 |
145 | @staticmethod
146 | def visualize_state(state):
147 | state = state[0].to('cpu').numpy() # (1/B, C, H, W) -> (C, H, W)
148 | stacked_frames = np.hstack([np.repeat((img * 255).astype(np.uint8)[:, :, np.newaxis], 3, axis=2) for img in state]) # (C, H, W) -> (H, C*W, 3)
149 | plt.imshow(stacked_frames)
150 | plt.show()
151 |
152 | def maybe_log_episode(self):
153 | rewards = self.env.get_episode_rewards() # we can do this thanks to the Monitor wrapper
154 | episode_lengths = self.env.get_episode_lengths()
155 | num_episodes = len(rewards)
156 |
157 | if self.episode_log_freq is not None and num_episodes % self.episode_log_freq == 0:
158 | self.tensorboard_writer.add_scalar('Rewards per episode', rewards[-1], num_episodes)
159 | self.tensorboard_writer.add_scalar('Steps per episode', episode_lengths[-1], num_episodes)
160 |
161 | if rewards[-1] > self.best_episode_reward:
162 | self.best_episode_reward = rewards[-1]
163 | self.config['best_episode_reward'] = self.best_episode_reward # metadata
164 | self.best_dqn_model = copy.deepcopy(self.dqn) # keep track of the model that gave the best reward
165 |
166 | def maybe_log(self):
167 | num_steps = self.env.get_total_steps()
168 |
169 | if self.log_freq is not None and num_steps > 0 and num_steps % self.log_freq == 0:
170 | self.tensorboard_writer.add_scalar('Epsilon', self.dqn.epsilon_value(), num_steps)
171 | if len(self.huber_loss) > 0:
172 | self.tensorboard_writer.add_scalar('Huber loss', np.mean(self.huber_loss), num_steps)
173 | self.tensorboard_writer.add_scalar('FPS', num_steps / (time.time() - self.start_time), num_steps)
174 |
175 | self.huber_loss = [] # clear the loss values and start recollecting them again
176 |
177 | # Periodically save DQN models
178 | if self.checkpoint_freq is not None and num_steps > 0 and num_steps % self.checkpoint_freq == 0:
179 | ckpt_model_name = f'dqn_{self.config["env_id"]}_ckpt_steps_{num_steps}.pth'
180 | torch.save(utils.get_training_state(self.config, self.dqn), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))
181 |
182 | # Log the gradients
183 | if self.grads_log_freq is not None and self.learner_cnt > 0 and self.learner_cnt % self.grads_log_freq == 0:
184 | total_grad_l2_norm = 0
185 |
186 | for cnt, (name, weight_or_bias_parameters) in enumerate(self.dqn.named_parameters()):
187 | grad_l2_norm = weight_or_bias_parameters.grad.data.norm(p=2).item()
188 | self.tensorboard_writer.add_scalar(f'grad_norms/{name}', grad_l2_norm, self.learner_cnt)
189 | total_grad_l2_norm += grad_l2_norm ** 2
190 |
191 | # As if we concatenated all of the params into a single vector and took L2
192 | total_grad_l2_norm = total_grad_l2_norm ** (1/2)
193 | self.tensorboard_writer.add_scalar(f'grad_norms/total', total_grad_l2_norm, self.learner_cnt)
194 |
195 | def log_to_console(self): # keep it minimal for now, I mostly use tensorboard - feel free to expand functionality
196 | print(f'Number of env steps = {self.get_number_of_env_steps()}')
197 |
198 |
199 | def train_dqn(config):
200 | env = utils.get_env_wrapper(config['env_id'])
201 | replay_buffer = ReplayBuffer(config['replay_buffer_size'], crash_if_no_mem=config['dont_crash_if_no_mem'])
202 |
203 | utils.set_random_seeds(env, config['seed'])
204 |
205 | linear_schedule = utils.LinearSchedule(
206 | config['epsilon_start_value'],
207 | config['epsilon_end_value'],
208 | config['epsilon_duration']
209 | )
210 |
211 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
212 | dqn = DQN(env, number_of_actions=env.action_space.n, epsilon_schedule=linear_schedule).to(device)
213 | target_dqn = DQN(env, number_of_actions=env.action_space.n).to(device)
214 |
215 | # Don't get confused by the actor-learner terminology, DQN is not an actor-critic method, but conceptually
216 | # we can split the learning process into collecting experience/acting in the env and learning from that experience
217 | actor_learner = ActorLearner(config, env, replay_buffer, dqn, target_dqn, env.reset())
218 |
219 | while actor_learner.get_number_of_env_steps() < config['num_of_training_steps']:
220 |
221 | num_env_steps = actor_learner.get_number_of_env_steps()
222 | if config['console_log_freq'] is not None and num_env_steps % config['console_log_freq'] == 0:
223 | actor_learner.log_to_console()
224 |
225 | actor_learner.collect_experience()
226 |
227 | if num_env_steps > config['num_warmup_steps']:
228 | actor_learner.learn_from_experience()
229 |
230 | torch.save( # save the best DQN model overall (gave the highest reward in an episode)
231 | utils.get_training_state(config, actor_learner.best_dqn_model),
232 | os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['env_id']))
233 | )
234 |
235 |
236 | def get_training_args():
237 | parser = argparse.ArgumentParser()
238 |
239 | # Training related
240 | parser.add_argument("--seed", type=int, help="Very important for reproducibility - set the random seed", default=23)
241 | parser.add_argument("--env_id", type=str, help="Atari game id", default='BreakoutNoFrameskip-v4')
242 | parser.add_argument("--num_of_training_steps", type=int, help="Number of training env steps", default=50000000)
243 | parser.add_argument("--acting_learning_step_ratio", type=int, help="Number of experience collection steps for every learning step", default=4)
244 | parser.add_argument("--learning_rate", type=float, default=1e-4)
245 | parser.add_argument("--grad_clipping_value", type=float, default=5) # 5 is fairly arbitrarily chosen
246 |
247 | parser.add_argument("--replay_buffer_size", type=int, help="Number of frames to store in buffer", default=1000000)
248 | parser.add_argument("--dont_crash_if_no_mem", action='store_false', help="Optimization - crash if not enough RAM before the training even starts (default=True)")
249 | parser.add_argument("--num_warmup_steps", type=int, help="Number of steps before learning starts", default=50000)
250 | parser.add_argument("--target_dqn_update_interval", type=int, help="Target DQN update freq per learning update", default=10000)
251 |
252 | parser.add_argument("--batch_size", type=int, help="Number of states in a batch (from replay buffer)", default=32)
253 | parser.add_argument("--gamma", type=float, help="Discount factor", default=0.99)
254 | parser.add_argument("--tau", type=float, help='Set to 1 for a hard target DQN update, < 1 for a soft one', default=1.)
255 |
256 | # epsilon-greedy annealing params
257 | parser.add_argument("--epsilon_start_value", type=float, default=1.)
258 | parser.add_argument("--epsilon_end_value", type=float, default=0.1)
259 | parser.add_argument("--epsilon_duration", type=int, default=1000000)
260 |
261 | # Logging/debugging/checkpoint related (helps a lot with experimentation)
262 | parser.add_argument("--console_log_freq", type=int, help="Log to console after this many env steps (None = no logging)", default=10000)
263 | parser.add_argument("--log_freq", type=int, help="Log metrics to tensorboard after this many env steps (None = no logging)", default=10000)
264 | parser.add_argument("--episode_log_freq", type=int, help="Log metrics to tensorboard after this many episodes (None = no logging)", default=5)
265 | parser.add_argument("--checkpoint_freq", type=int, help="Save checkpoint model after this many env steps (None = no checkpointing)", default=10000)
266 | parser.add_argument("--grads_log_freq", type=int, help="Log grad norms after this many weight update steps (None = no logging)", default=2500)
267 | parser.add_argument("--debug", action='store_true', help='Train in debugging mode')
268 | args = parser.parse_args()
269 |
270 | # Wrapping training configuration into a dictionary
271 | training_config = dict()
272 | for arg in vars(args):
273 | training_config[arg] = getattr(args, arg)
274 |
275 | return training_config
276 |
277 |
278 | if __name__ == '__main__':
279 | # Train the DQN model
280 | train_dqn(get_training_args())
281 |
282 |
283 |
--------------------------------------------------------------------------------
/utils/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | BINARIES_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'models', 'binaries')
5 | CHECKPOINTS_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'models', 'checkpoints')
6 | DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'data')
7 |
8 |
9 | # Make sure these exist as the rest of the code assumes it
10 | os.makedirs(BINARIES_PATH, exist_ok=True)
11 | os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
12 |
--------------------------------------------------------------------------------
/utils/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 |
4 | import numpy as np
5 | import psutil
6 | import torch
7 |
8 |
9 | from utils.utils import get_env_wrapper
10 |
11 |
12 | class ReplayBuffer:
13 | """
14 | Since stable baselines 3 doesn't currently support a smart replay buffer (more concretely the "lazy frames" feature)
15 | i.e. allocating (10^6, 84, 84) (~7 GB) for Atari and extracting 4 frames as needed, instead of (10^6, 4, 84, 84),
16 | here is an efficient implementation.
17 |
18 | Note: inspired by Berkeley's implementation: https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3
19 |
20 | Further improvements:
21 | * Much more concise (and hopefully readable)
22 | * Reports error if you don't have enough RAM in advance to allocate this buffer
23 | * Fixed a subtle buffer boundary bug (start index edge case)
24 |
25 | """
26 | def __init__(self, max_buffer_size, num_last_frames_to_fetch=4, frame_shape=[1, 84, 84], crash_if_no_mem=True):
27 | self.max_buffer_size = max_buffer_size
28 | self.current_buffer_size = 0
29 | self.current_free_slot_index = 0
30 |
31 | assert frame_shape[0] in (1, 3), f'Expected mono/color image frame got shape={frame_shape}.'
32 | self.frame_height = frame_shape[1]
33 | self.frame_width = frame_shape[2]
34 | self.num_previous_frames_to_fetch = num_last_frames_to_fetch
35 |
36 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37 |
38 | # Create main buffer containers - all types are chosen so as to optimize memory consumption
39 | self.frames = np.zeros([self.max_buffer_size] + frame_shape, dtype=np.uint8)
40 | self.actions = np.zeros([self.max_buffer_size, 1], dtype=np.uint8)
41 | self.rewards = np.zeros([self.max_buffer_size, 1], dtype=np.float32) # we need extra precision for rewards
42 | self.dones = np.zeros([self.max_buffer_size, 1], dtype=np.uint8)
43 |
44 | # Numpy does lazy execution so it can happen that only after a while the training starts hitting the RAM limit
45 | # then the page swapping kicks in which will slow-down the training significantly only after hours of training!
46 | # hence the _check_enough_ram function (I'm saving you time and money thank me later ^^)
47 | self._check_enough_ram(crash_if_no_mem)
48 |
49 | #
50 | # public API functions
51 | #
52 |
53 | def store_frame(self, frame):
54 | self.frames[self.current_free_slot_index] = frame
55 |
56 | self.current_free_slot_index = (self.current_free_slot_index + 1) % self.max_buffer_size # circular logic
57 | self.current_buffer_size = min(self.max_buffer_size, self.current_buffer_size + 1)
58 |
59 | return self.current_free_slot_index - 1 # we yet need to store (action, reward, done) at this index
60 |
61 | def store_action_reward_done(self, index, action, reward, done):
62 | self.actions[index] = action
63 | self.rewards[index] = reward
64 | self.dones[index] = done
65 |
66 | def fetch_random_states(self, batch_size):
67 | assert self._has_enough_data(batch_size), "Can't fetch states from the replay buffer - not enough data."
68 | # Uniform random sampling without replacement. -1 because we always need to fetch the current and the immediate
69 | # next state for Q-learning but the last state in the buffer doesn't have the next state
70 | random_unique_indices = random.sample(range(self.current_buffer_size - 1), batch_size)
71 |
72 | states = self._postprocess_state(
73 | np.concatenate([self._fetch_state(i) for i in random_unique_indices], 0) # shape = (B, C, H, W)
74 | )
75 | next_states = self._postprocess_state(
76 | np.concatenate([self._fetch_state(i + 1) for i in random_unique_indices], 0) # shape = (B, C, H, W)
77 | )
78 | # Long is needed because actions are used for indexing of tensors (PyTorch constraint)
79 | actions = torch.from_numpy(self.actions[random_unique_indices]).to(self.device).long()
80 | rewards = torch.from_numpy(self.rewards[random_unique_indices]).to(self.device)
81 | # Float is needed because we'll be multiplying Q values with done flags (1-done actually)
82 | dones = torch.from_numpy(self.dones[random_unique_indices]).to(self.device).float()
83 |
84 | return states, actions, rewards, next_states, dones
85 |
86 | def fetch_last_state(self):
87 | # shape = (1, C, H, W) where C - number of past frames, 4 for Atari
88 | return self._postprocess_state(
89 | self._fetch_state((self.current_free_slot_index - 1) % self.max_buffer_size)
90 | )
91 |
92 | def get_current_size(self):
93 | return self.current_buffer_size
94 |
95 | #
96 | # Helper functions
97 | #
98 |
99 | def _fetch_state(self, end_index):
100 | """
101 | We fetch end_index frame and ("num_last_frames_to_fetch" - 1) last frames (4 in total in the case of Atari)
102 | in order to generate a state.
103 |
104 | Replay buffer has 2 edge cases that we need to handle:
105 | 1) start_index related:
106 | * index is "too close"* to 0 and our circular buffer is still not full, thus we don't have enough frames
107 | * index is "too close" to the buffer boundary we could mix very old/new observations
108 |
109 | 2) done flag is True - we don't want to take observations before that index since it belongs to a different
110 | life or episode.
111 |
112 | Notes:
113 | * "too close" is defined by 'num_last_frames_to_fetch' variable
114 | * terminology: state consists out of multiple observations (frames in Atari case)
115 |
116 | """
117 | # Start index is included, end index is excluded <=> [)
118 | end_index += 1
119 | start_index = end_index - self.num_previous_frames_to_fetch
120 | start_index = self._handle_start_index_edge_cases(start_index, end_index)
121 |
122 | num_of_missing_frames = self.num_previous_frames_to_fetch - (end_index - start_index)
123 |
124 | if start_index < 0 or num_of_missing_frames > 0: # start_index:end_index indexing won't work if start_index < 0
125 | # If there are missing frames, because of the above handled edge-cases, fill them with black frames as per
126 | # original DeepMind Lua imp: https://github.com/deepmind/dqn/blob/master/dqn/TransitionTable.lua#L171
127 | state = [np.zeros_like(self.frames[0]) for _ in range(num_of_missing_frames)]
128 |
129 | for index in range(start_index, end_index):
130 | state.append(self.frames[index % self.max_buffer_size])
131 |
132 | # shape = (C, H, W) -> (1, C, H, W) where C - number of past frames, 4 for Atari
133 | return np.concatenate(state, 0)[np.newaxis, :]
134 | else:
135 | # reshape from (C, 1, H, W) to (1, C, H, W) where C number of past frames, 4 for Atari
136 | return self.frames[start_index:end_index].reshape(-1, self.frame_height, self.frame_width)[np.newaxis, :]
137 |
138 | def _postprocess_state(self, state):
139 | # numpy -> tensor, move to device, uint8 -> float, [0,255] -> [0, 1]
140 | return torch.from_numpy(state).to(self.device).float().div(255)
141 |
142 | def _handle_start_index_edge_cases(self, start_index, end_index):
143 | # Edge case 1:
144 | if not self._buffer_full() and start_index < 0:
145 | start_index = 0
146 |
147 | # Edge case 2:
148 | # Handle the case where start index crosses the buffer head pointer - the data before and after the head pointer
149 | # belongs to completely different episodes
150 | if self._buffer_full():
151 | if 0 < (self.current_free_slot_index - start_index) % self.max_buffer_size < self.num_previous_frames_to_fetch:
152 | start_index = self.current_free_slot_index
153 |
154 | # Edge case 3:
155 | # A done flag marks a boundary between different episodes or lives; either way we shouldn't take frames
156 | # before or at the done flag into consideration
157 | for index in range(start_index, end_index - 1):
158 | if self.dones[index % self.max_buffer_size]:
159 | start_index = index + 1
160 |
161 | return start_index
162 |
163 | def _buffer_full(self):
164 | return self.current_buffer_size == self.max_buffer_size
165 |
166 | def _has_enough_data(self, batch_size):
167 | return batch_size < self.current_buffer_size # e.g. if buffer size is 32 we need at least 33 frames hence <
168 |
169 | def _check_enough_ram(self, crash_if_no_mem):
170 | def to_GBs(memory_in_bytes): # beautify memory output - helper function
171 | return f'{memory_in_bytes / 2 ** 30:.2f} GBs'
172 |
173 | available_memory = psutil.virtual_memory().available
174 | required_memory = self.frames.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
175 | print(f'required memory = {to_GBs(required_memory)}, available memory = {to_GBs(available_memory)}')
176 |
177 | if required_memory > available_memory:
178 | message = f"Not enough memory to store the complete replay buffer! \n" \
179 | f"required: {to_GBs(required_memory)} > available: {to_GBs(available_memory)} \n" \
180 | f"Page swapping will make your training super slow once you hit your RAM limit.\n" \
181 | f"You can either modify replay_buffer_size argument or set crash_if_no_mem to False to ignore it."
182 | if crash_if_no_mem:
183 | raise Exception(message)
184 | else:
185 | print(message)
186 |
187 |
188 | # Basic replay buffer testing
189 | if __name__ == '__main__':
190 | size = 500000
191 | num_of_collection_steps = 10000
192 | batch_size = 32
193 |
194 | # Step 0: Create replay buffer and the env
195 | replay_buffer = ReplayBuffer(size)
196 |
197 | # NoFrameskip - receive every frame from the env whereas the version without NoFrameskip would give every 4th frame
198 | # v4 - actions we send to env are executed, whereas v0 would ignore the last action we sent with 0.25 probability
199 | env_id = "PongNoFrameskip-v4"
200 | env = get_env_wrapper(env_id)
201 |
202 | # Step 1: Collect experience
203 | frame = env.reset()
204 |
205 | for i in range(num_of_collection_steps):
206 | random_action = env.action_space.sample()
207 |
208 | # For some reason for Pong gym returns more than 3 actions.
209 | print(f'Sampling action {random_action} - {env.unwrapped.get_action_meanings()[random_action]}')
210 |
211 | frame, reward, done, info = env.step(random_action)
212 |
213 | index = replay_buffer.store_frame(frame)
214 | replay_buffer.store_action_reward_done(index, random_action, reward, done)
215 |
216 | if done:
217 | env.reset()
218 |
219 | # Step 2: Fetch states from the buffer
220 | states, actions, rewards, next_states, dones = replay_buffer.fetch_random_states(batch_size)
221 |
222 | print(states.shape, next_states.shape, actions.shape, rewards.shape, dones.shape)
223 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import re
4 |
5 |
6 | import torch
7 | import git
8 | import gym
9 | import numpy as np
10 | from stable_baselines3.common.atari_wrappers import AtariWrapper
11 | from gym.wrappers import Monitor
12 |
13 |
14 | from .constants import *
15 |
16 |
17 | def get_env_wrapper(env_id, record_video=False):
18 | """
19 | Ultimately it's not very clear why SB3's wrappers and OpenAI gym's are copy/pasted code for the most part.
20 | It seems that OpenAI gym doesn't have reward clipping which is necessary for Atari.
21 |
22 | I'm using SB3 because it's actively maintained compared to OpenAI's gym and it has reward clipping by default.
23 |
24 | """
25 | monitor_dump_dir = os.path.join(os.path.dirname(__file__), os.pardir, 'gym_monitor')
26 |
27 | # This is necessary because AtariWrapper skips 4 frames by default, so we can't have additional skipping through
28 | # the environment itself - hence NoFrameskip requirement
29 | assert 'NoFrameskip' in env_id, f'Expected NoFrameskip environment got {env_id}'
30 |
31 | # The only additional thing needed, on top of AtariWrapper,
32 | # is to convert the shape to channel-first because of PyTorch's models
33 | env_wrapped = Monitor(ChannelFirst(AtariWrapper(gym.make(env_id))), monitor_dump_dir, force=True, video_callable=lambda episode: record_video)
34 |
35 | return env_wrapped
36 |
37 |
38 | class ChannelFirst(gym.ObservationWrapper):
39 | def __init__(self, env):
40 | super().__init__(env)
41 | new_shape = np.roll(self.observation_space.shape, shift=1) # shape: (H, W, C) -> (C, H, W)
42 |
43 | # Update because this is the last wrapper in the hierarchy, we'll be polling the env for observation shape info
44 | self.observation_space = gym.spaces.Box(low=0, high=255, shape=new_shape, dtype=np.uint8)
45 |
46 | def observation(self, observation):
47 | return np.moveaxis(observation, 2, 0) # shape: (H, W, C) -> (C, H, W)
48 |
49 |
50 | class LinearSchedule:
51 |
52 | def __init__(self, schedule_start_value, schedule_end_value, schedule_duration):
53 | self.start_value = schedule_start_value
54 | self.end_value = schedule_end_value
55 | self.schedule_duration = schedule_duration
56 |
57 | def __call__(self, num_steps):
58 | progress = np.clip(num_steps / self.schedule_duration, a_min=None, a_max=1) # goes from 0 -> 1 and saturates
59 | return self.start_value + (self.end_value - self.start_value) * progress
60 |
61 |
62 | class ConstSchedule:
63 | """ Dummy schedule - used for DQN evaluation in evaluate_DQN_script.py. """
64 | def __init__(self, value):
65 | self.value = value
66 |
67 | def __call__(self, num_steps):
68 | return self.value
69 |
70 |
71 | def print_model_metadata(training_state):
72 | header = f'\n{"*"*5} DQN model training metadata: {"*"*5}'
73 | print(header)
74 |
75 | for key, value in training_state.items():
76 | if key != 'state_dict': # don't print state_dict it's a bunch of numbers...
77 | print(f'{key}: {value}')
78 | print(f'{"*" * len(header)}\n')
79 |
80 |
81 | def get_training_state(training_config, model):
82 | training_state = {
83 | # Reproducibility details
84 | "commit_hash": git.Repo(search_parent_directories=True).head.object.hexsha,
85 | "seed": training_config['seed'],
86 |
87 | # Env details
88 | "env_id": training_config['env_id'],
89 |
90 | # Training details
91 | "best_episode_reward": training_config['best_episode_reward'],
92 |
93 | # Model state
94 | "state_dict": model.state_dict()
95 | }
96 |
97 | return training_state
98 |
99 |
100 | def get_available_binary_name(env_id='env_unknown'):
101 | prefix = f'dqn_{env_id}'
102 |
103 | def valid_binary_name(binary_name):
104 | # First time you see raw f-string? Don't worry the only trick is to double the brackets.
105 | pattern = re.compile(rf'{prefix}_[0-9]{{6}}\.pth')
106 | return re.fullmatch(pattern, binary_name) is not None
107 |
108 | # Just list the existing binaries so that we don't overwrite them but write to a new one
109 | valid_binary_names = list(filter(valid_binary_name, os.listdir(BINARIES_PATH)))
110 | if len(valid_binary_names) > 0:
111 | last_binary_name = sorted(valid_binary_names)[-1]
112 | new_suffix = int(last_binary_name.split('.')[0][-6:]) + 1 # increment by 1
113 | return f'{prefix}_{str(new_suffix).zfill(6)}.pth'
114 | else:
115 | return f'{prefix}_000000.pth'
116 |
117 |
118 | def set_random_seeds(env, seed):
119 | if seed is not None:
120 | torch.manual_seed(seed) # PyTorch
121 | np.random.seed(seed) # NumPy
122 | random.seed(seed) # Python
123 | env.action_space.seed(seed) # probably redundant but I found an article where somebody had a problem with this
124 | env.seed(seed) # OpenAI gym
125 |
126 | # todo: AB test impact on FPS metric
127 | # Deterministic operations for CuDNN, it may impact performance
128 | if torch.cuda.is_available():
129 | torch.backends.cudnn.deterministic = True
130 | torch.backends.cudnn.benchmark = False
131 |
132 |
133 | # Test utils
134 | if __name__ == '__main__':
135 | import matplotlib.pyplot as plt
136 |
137 | schedule = LinearSchedule(schedule_start_value=1., schedule_end_value=0.1, schedule_duration=50)
138 |
139 | schedule_values = []
140 | for i in range(100):
141 | schedule_values.append(schedule(i))
142 |
143 | plt.plot(schedule_values)
144 | plt.show()
145 |
146 |
--------------------------------------------------------------------------------
/utils/video_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | import numpy as np
5 | import cv2 as cv
6 | import imageio
7 |
8 |
9 | def load_image(img_path, target_shape=None):
10 | if not os.path.exists(img_path):
11 | raise Exception(f'Path does not exist: {img_path}')
12 | img = cv.imread(img_path)[:, :, ::-1] # [:, :, ::-1] converts BGR (opencv format...) into RGB
13 |
14 | if target_shape is not None: # resize section
15 | if isinstance(target_shape, int) and target_shape != -1: # scalar -> implicitly setting the width
16 | current_height, current_width = img.shape[:2]
17 | new_width = target_shape
18 | new_height = int(current_height * (new_width / current_width))
19 | img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
20 | else: # set both dimensions to target shape
21 | img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC)
22 |
23 | # this needs to go after resizing - otherwise cv.resize would push values outside of the [0, 1] range
24 | img = img.astype(np.float32) # convert from uint8 to float32
25 | img /= 255.0 # get to [0, 1] range
26 | return img
27 |
28 |
29 | def create_gif(frames_dir, out_path, fps=30, img_width=None):
30 | assert os.path.splitext(out_path)[1].lower() == '.gif', f'Expected .gif got {os.path.splitext(out_path)[1]}.'
31 |
32 | frame_paths = sorted(os.path.join(frames_dir, frame_name) for frame_name in os.listdir(frames_dir) if frame_name.endswith('.jpg')) # sorted so frames appear in chronological order
33 |
34 | images = [imageio.imread(frame_path) for frame_path in frame_paths]
35 | imageio.mimwrite(out_path, images, fps=fps)
36 | print(f'Saved gif to {out_path}.')
37 |
--------------------------------------------------------------------------------