├── .github
    └── FUNDING.yml
├── .gitignore
├── LICENCE
├── README.md
├── data
    ├── readme_gifs
    │   └── BreakoutNoFrameskip-v4_1.gif
    ├── readme_pics
    │   └── dqn.jpg
    └── readme_visualizations
    │   ├── breakout.jpg
    │   ├── fps_metric.PNG
    │   ├── grads.PNG
    │   ├── huber_loss.PNG
    │   ├── pong.jpg
    │   ├── rewards_per_episode.PNG
    │   ├── state_all_frames.PNG
    │   ├── state_initial.PNG
    │   └── steps_per_episode.PNG
├── environment.yml
├── evaluate_DQN_script.py
├── models
    └── definitions
    │   └── DQN.py
├── playground.py
├── train_DQN_script.py
└── utils
    ├── constants.py
    ├── replay_buffer.py
    ├── utils.py
    └── video_utils.py


/.github/FUNDING.yml:
--------------------------------------------------------------------------------
patreon: theaiepiphany

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# PyCharm IDE
.idea
__pycache__

# Jupyter notebook checkpoints
.ipynb_checkpoints

# Models checkpoints and binaries
models/checkpoints
models/binaries

# Data directory
data/

# Tensorboard dump dir
runs/
runs_baseline/

# OpenAI gym Monitor's video dump location
gym_monitor/

--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Aleksa Gordić

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Reinforcement Learning (PyTorch) :robot: + :cake: = :heart:

This repo will contain PyTorch implementations of various fundamental RL algorithms.
It's aimed at making it **easy** to start playing and learning about RL.

The problem I came across while investigating other DQN projects is that they:
* Don't have any evidence that they've actually achieved the published results
* Don't have a "smart" replay buffer (i.e. they allocate (1M, 4, 84, 84) ~ 28 GB! instead of (1M, 84, 84) ~ 7 GB)
* Lack visualizations and debugging utils

This repo will aim to solve these problems.

## Table of Contents
  * [RL agents](#rl-agents)
    * [DQN](#dqn)
  * [DQN current results](#dqn-current-results)
  * [Setup](#setup)
  * [Usage](#usage)
  * [Training DQN](#training-dqn)
  * [Visualization and debugging tools](#visualization-and-debugging-tools)
  * [Hardware requirements](#hardware-requirements)
  * [Future todos](#future-todos)
  * [Learning material](#learning-material)

## RL agents

## DQN

This was the project that started the revolution in the RL world - deep Q-network (:link: [Mnih et al.](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)),
aka "Human-level control through deep RL".

The DQN model learned to play **29 Atari games** (out of the 49 it was tested on) at a **super-human**/comparable-to-human level.
Here is the schematic of its CNN architecture:

![DQN architecture](data/readme_pics/dqn.jpg)

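The input size of the first fully-connected layer follows from the standard output-size formula for an unpadded convolution, `(size - kernel) // stride + 1`, applied layer by layer. Below is a minimal pure-Python sketch; `conv_output_size` and `dqn_fc_input_dim` are hypothetical helpers, not part of this repo (`DQN.py` infers the same number with a dummy forward pass instead), and it assumes the Nature network's three conv layers (32 8x8 stride 4, 64 4x4 stride 2, 64 3x3 stride 1) on 84x84 inputs.

```python
def conv_output_size(size, kernel_size, stride, padding=0):
    """Spatial output size of a square conv layer (no dilation, floor division)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def dqn_fc_input_dim(input_size=84, channels=(32, 64, 64), kernel_sizes=(8, 4, 3), strides=(4, 2, 1)):
    """Flattened CNN output size that feeds the first FC layer (hypothetical helper)."""
    size = input_size
    for kernel_size, stride in zip(kernel_sizes, strides):
        size = conv_output_size(size, kernel_size, stride)
    # Output of the last conv layer is (channels[-1], size, size), flattened into one vector
    return channels[-1] * size * size


if __name__ == '__main__':
    # 84 -> 20 -> 9 -> 7, so the first FC layer sees 64 * 7 * 7 features
    print(dqn_fc_input_dim())  # 3136
```

Changing the input resolution or the conv hyperparameters only requires new arguments, which is the same flexibility `DQN.py` gets from its automatic shape calculation.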

The fascinating part is that it learned only from "high-dimensional" (84x84) images and (usually sparse) rewards.
The same architecture was used for all 49 games - although the model has to be retrained, from scratch, every single time.

## DQN current results

Since it takes [lots of compute and time](#hardware-requirements) to train all 49 models, I'll consider the DQN project completed once
I succeed in achieving the published results on:
* Breakout
* Pong

---

Having said that, the experiments are still in progress, so feel free to **contribute**!
* For some reason the models aren't learning very well, so if you find a bug, open up a PR! :heart:
* I'm also experiencing slowdowns - so any PRs that would improve/explain the perf are welcome!
* If you decide to train the DQN using this repo on some other Atari game, I'll gladly check in your model!

**Important note: please follow the coding guidelines of this repo before you submit a PR so that we can minimize
the back-and-forth. I'm a fairly busy guy, as I assume you are too.**

### Current results - Breakout

![Breakout agent playing](data/readme_gifs/BreakoutNoFrameskip-v4_1.gif)

As you can see, the model did learn something, although it's far from being really good.

### Current results - Pong

todo

## Setup

Let's get this thing running! Follow the next steps:

1. `git clone https://github.com/gordicaleksa/pytorch-learn-reinforcement-learning`
2. Open the Anaconda console and navigate into the project directory `cd path_to_repo`
3. Run `conda env create` from the project directory (this will create a brand new conda environment).
4. Run `conda activate pytorch-rl-env` (for running scripts from your console, or set up the interpreter in your IDE)

If you're on Windows, you'll additionally need to run
`pip install -f https://github.com/Kojoley/atari-py/releases atari_py` to install gym's Atari dependencies.

Otherwise, `pip install 'gym[atari]'` should do it; if it doesn't work, check out [this](https://stackoverflow.com/questions/49947555/openai-gym-trouble-installing-atari-dependency-mac-os-x) and [this](https://github.com/openai/gym/issues/1170).

That's it! `conda env create` reads the environment.yml file, which takes care of the remaining dependencies, so it should work out of the box.

-----

The PyTorch package comes bundled with some version of CUDA/cuDNN,
but it is highly recommended that you install a system-wide CUDA beforehand, mostly because of the GPU drivers.
I also recommend using the Miniconda installer as a way to get conda on your system.
Follow through points 1 and 2 of [this setup](https://github.com/Petlja/PSIML/blob/master/docs/MachineSetup.md)
and use the most up-to-date versions of Miniconda and CUDA/cuDNN for your system.

## Usage

#### Option 1: Jupyter Notebook

Coming soon.

#### Option 2: Use your IDE of choice

You just need to link the Python environment you created in the [setup](#setup) section.

## Training DQN

To run with default settings, just run `python train_DQN_script.py`.

Settings you'll want to experiment with:
* `--seed` - it may just so happen that I've chosen a bad one (RL is very sensitive)
* `--learning_rate` - DQN originally used RMSProp; I saw that Adam with 1e-4 worked for Stable-Baselines3
* `--grad_clipping_value` - there was [a lot of noise](#visualization-and-debugging-tools) in the gradients, so I used this to control it
* Try using RMSProp (I haven't yet). Adam was an improvement over RMSProp, so I doubt it's causing the issues

Less important settings for getting DQN to work:
* `--env_id` - depending on which game you want to train on (I'd focus on the easiest one for now - Breakout)
* `--replay_buffer_size` - hopefully you can train DQN with 1M, as in the original paper; if not, make it smaller
* `--dont_crash_if_no_mem` - add this flag if you want to run with a 1M replay buffer even if you don't have enough RAM

The training script will:
* Dump checkpoint `*.pth` models into `models/checkpoints/`
* Dump the best (highest reward) `*.pth` model into `models/binaries/` <- TODO
* Periodically write some training metadata to the console
* Save tensorboard metrics into `runs/`; to use them, check out [the visualization section](#visualization-and-debugging-tools)

## Visualization and debugging tools

You can visualize the metrics during training by calling `tensorboard --logdir=runs` from your console
and pasting the `http://localhost:6006/` URL into your browser.

I'm currently visualizing the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) (and you can see there is something weird going on):

![Huber loss](data/readme_visualizations/huber_loss.PNG)

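For reference, with its default `beta=1.0`, `nn.SmoothL1Loss` (the loss used in `train_DQN_script.py`) is the Huber loss with threshold 1: quadratic for TD errors in [-1, 1] and linear outside. Its gradient with respect to the error is therefore clipped to [-1, 1], which is how the Nature paper describes its error clipping. A small pure-Python sketch of the per-element value and gradient; the function names are illustrative and not from this repo:

```python
def huber(error, delta=1.0):
    """Huber loss of one TD error: quadratic inside [-delta, delta], linear outside."""
    abs_error = abs(error)
    if abs_error < delta:
        return 0.5 * error ** 2
    return delta * (abs_error - 0.5 * delta)


def huber_grad(error, delta=1.0):
    """Derivative w.r.t. the error: the error itself inside [-delta, delta], +/-delta outside."""
    if abs(error) < delta:
        return error
    return delta if error > 0 else -delta


def mean_huber(predictions, targets, delta=1.0):
    """Batch mean, like the default reduction='mean' of nn.SmoothL1Loss."""
    errors = [p - t for p, t in zip(predictions, targets)]
    return sum(huber(e, delta) for e in errors) / len(errors)


if __name__ == '__main__':
    print(huber(0.5), huber(3.0))  # 0.125 2.5
    print(huber_grad(0.5), huber_grad(-3.0))  # 0.5 -1.0
    print(mean_huber([1.0, 0.0], [0.5, 3.0]))  # 1.3125
```

Because large TD errors only contribute a bounded gradient, a single badly estimated target cannot dominate an update the way it would with a plain MSE loss.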
I'm also logging the rewards and steps taken per episode (there is a fair bit of correlation between these two):

![Rewards per episode](data/readme_visualizations/rewards_per_episode.PNG)
![Steps per episode](data/readme_visualizations/steps_per_episode.PNG)

And the gradient L2 norms of the weights and biases of every CNN/FC layer, as well as of the complete gradient vector:

![Gradient L2 norms](data/readme_visualizations/grads.PNG)

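The `total` curve treats all gradients as one concatenated vector. Its L2 norm can be recovered from the per-parameter norms alone, as the square root of the sum of their squares, which is what the gradient-logging code in `train_DQN_script.py` does. `torch.nn.utils.clip_grad_norm_` (applied when `--grad_clipping_value` is set) rescales all gradients by one common factor once this same total norm exceeds the threshold. Below is a pure-Python sketch of both ideas with made-up gradient values; `clip_by_total_norm` is a hypothetical helper that only approximates PyTorch's implementation.

```python
import math


def l2_norm(values):
    """Euclidean (L2) norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def total_norm_from_parts(per_param_norms):
    """Norm of all gradients concatenated, computed from the per-parameter L2 norms only."""
    return math.sqrt(sum(n ** 2 for n in per_param_norms))


def clip_by_total_norm(grads, max_norm, eps=1e-6):
    """Scale every gradient by the same factor if the total norm exceeds max_norm (hypothetical sketch)."""
    total = l2_norm([v for g in grads for v in g])
    scale = max_norm / (total + eps)
    if scale < 1.0:
        return [[v * scale for v in g] for g in grads]
    return grads  # already small enough, leave untouched


if __name__ == '__main__':
    # Made-up, already-flattened gradients of three parameter tensors
    grads = [[3.0, 4.0], [12.0], [0.0, 84.0]]
    per_param = [l2_norm(g) for g in grads]  # [5.0, 12.0, 84.0]
    concatenated = [v for g in grads for v in g]
    print(total_norm_from_parts(per_param), l2_norm(concatenated))  # 85.0 85.0
    clipped = clip_by_total_norm(grads, max_norm=10.0)
    print(round(l2_norm([v for g in clipped for v in g]), 4))  # 10.0
```

Since clipping uses one common factor, the direction of the update is preserved and only its length shrinks, so the per-layer curves keep their relative proportions.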
I also log epsilon (from the epsilon-greedy algorithm), but that plot is not very informative, so I'll omit it here.

As you can see, the plots are super **noisy**! That much was expected, but the progress stagnates from a certain point onwards,
and that's what I'm trying to debug at the moment.

---

To enter the debug mode, add the `--debug` flag to your console or IDE's list of script arguments.

It'll visualize the current state that's being fed into the RL agent.
Sometimes the state will have some black frames prepended, since not enough frames have been experienced in the current episode yet:

![Initial state with black frames](data/readme_visualizations/state_initial.PNG)

But most of the time, all 4 frames will be in there:

![State with all 4 frames](data/readme_visualizations/state_all_frames.PNG)

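Both the black padding frames and the (1M, 84, 84) ~ 7 GB figure from the intro follow from storing every preprocessed frame only once and assembling the 4-frame state on demand. Below is a simplified pure-Python sketch of that idea; `FrameRingBuffer`, its methods and `replay_buffer_bytes` are hypothetical names and this is not the repo's `ReplayBuffer` (whose source isn't shown here). It assumes uint8 grayscale frames, marks an episode end on the last frame of that episode, and replaces frames that are missing or belong to the previous episode with all-zero (black) frames.

```python
def replay_buffer_bytes(num_entries, height=84, width=84, frames_per_entry=1, bytes_per_pixel=1):
    """Bytes for the frame storage alone (uint8 grayscale by default; actions/rewards/dones ignored)."""
    return num_entries * frames_per_entry * height * width * bytes_per_pixel


def frame(value, size=2):
    """Tiny size x size stand-in for an 84x84 grayscale frame, filled with `value`."""
    return [[value] * size for _ in range(size)]


class FrameRingBuffer:
    """Hypothetical sketch: stores every frame once and stacks the last `history` frames on demand."""

    def __init__(self, capacity, history=4):
        self.capacity = capacity
        self.history = history
        self.frames = [None] * capacity
        self.dones = [False] * capacity  # dones[i] is True if the episode ended right after frame i
        self.next_idx = 0
        self.num_stored = 0

    def store_frame(self, frame_data):
        idx = self.next_idx
        self.frames[idx] = frame_data
        self.dones[idx] = False  # clear a stale flag when overwriting an old slot
        self.next_idx = (idx + 1) % self.capacity
        self.num_stored = min(self.num_stored + 1, self.capacity)
        return idx

    def mark_done(self, idx, done):
        self.dones[idx] = done

    def stacked_state(self, idx):
        """Return `history` frames ending at idx (oldest first), zero-padded at the front."""
        oldest = self.next_idx if self.num_stored == self.capacity else 0
        num_older = (idx - oldest) % self.capacity  # how many stored frames precede idx
        frames = [self.frames[idx]]
        for offset in range(1, self.history):
            prev = (idx - offset) % self.capacity
            # Stop before running past the oldest stored frame or into the previous episode
            if offset > num_older or self.dones[prev]:
                break
            frames.append(self.frames[prev])
        frames.reverse()
        blank = [[0] * len(row) for row in frames[0]]  # black frame; shared, so treat the result as read-only
        return [blank] * (self.history - len(frames)) + frames


if __name__ == '__main__':
    buf = FrameRingBuffer(capacity=8)
    first = buf.store_frame(frame(1))
    print([f[0][0] for f in buf.stacked_state(first)])  # [0, 0, 0, 1]
    end = buf.store_frame(frame(2))
    buf.mark_done(end, True)  # the episode ended after acting on frame 2
    buf.store_frame(frame(3))  # first frame after env.reset()
    latest = buf.store_frame(frame(4))
    print([f[0][0] for f in buf.stacked_state(latest)])  # [0, 0, 3, 4]
    print(replay_buffer_bytes(10**6) / 1e9, replay_buffer_bytes(10**6, frames_per_entry=4) / 1e9)  # 7.056 28.224
```

The last line reproduces the ~7 GB vs ~28 GB comparison in decimal gigabytes: storing pre-stacked (4, 84, 84) states would keep every frame four times.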
In debug mode, it will also render the game frames (`Pong` and `Breakout` shown here from left to right):

![Pong](data/readme_visualizations/pong.jpg)
![Breakout](data/readme_visualizations/breakout.jpg)

## Hardware requirements

You'll need some decent hardware to train the DQN in reasonable time so that you can iterate fast:
1) **16+ GB of RAM** (the replay buffer takes ~7 GB of RAM).
2) The faster your GPU is - the better! :sweat_smile: That said, VRAM is not the bottleneck; you'll need **2+ GB of VRAM**.

With 16 GB of RAM and an RTX 2080, it takes ~5 days to train DQN on my machine - I'm **experiencing some slowdowns**, which I
haven't debugged yet. Here is the FPS (frames-per-second) metric I'm logging:

![FPS metric](data/readme_visualizations/fps_metric.PNG)

The shorter, green one is the current experiment I'm running; the red one took over 5 days to train.

## Future todos

1) Debug DQN and achieve the published results
2) Add Vanilla PG
3) Add PPO

## Learning material

Here are some videos I made on RL which may help you better understand how DQN and other RL algorithms work:

DQN paper explained

And some other ones:
* [DeepMind: AlphaGo](https://www.youtube.com/watch?v=Z1BELqFQZVM)
* [DeepMind: AlphaGo Zero and AlphaZero](https://www.youtube.com/watch?v=0slFo1rV0EM)
* [OpenAI: Solving Rubik's Cube with a Robot Hand](https://www.youtube.com/watch?v=eTa-k1pgvnU)
* [DeepMind: MuZero](https://www.youtube.com/watch?v=mH7f7N7s79s)

And in this one I filmed the process while the project was not nearly as polished as it is now:
* [DQN project update](https://www.youtube.com/watch?v=DrOp_MQGn9o&ab_channel=TheAIEpiphany)

I'll soon write a blog post on how to get started with RL - so stay tuned for that!

## Acknowledgements

I found these resources useful while developing this project, sorted (approximately) by usefulness:

* [Stable Baselines 3 DQN](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/dqn/dqn.py)
* [PyTorch reimplementation of Berkeley's DQN](https://github.com/transedward/pytorch-dqn) and [Berkeley's DQN](https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3)
* [pytorch-dqn](https://github.com/jacobaustin123/pytorch-dqn/blob/master/dqn.py)
* [RL adventures DQN](https://github.com/higgsfield/RL-Adventure/blob/master/1.dqn.ipynb) and [minimal DQN](https://github.com/econti/minimal_dqn/blob/master/main.py)
* [PyTorch tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)

## Citation

If you find this code useful, please cite the following:

```
@misc{Gordić2021PyTorchLearnReinforcementLearning,
  author = {Gordić, Aleksa},
  title = {pytorch-learn-reinforcement-learning},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gordicaleksa/pytorch-learn-reinforcement-learning}},
}
```

## Licence

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/gordicaleksa/pytorch-learn-reinforcement-learning/blob/master/LICENCE)
--------------------------------------------------------------------------------
/data/readme_gifs/BreakoutNoFrameskip-v4_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_gifs/BreakoutNoFrameskip-v4_1.gif
--------------------------------------------------------------------------------
/data/readme_pics/dqn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_pics/dqn.jpg
--------------------------------------------------------------------------------
/data/readme_visualizations/breakout.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/breakout.jpg
--------------------------------------------------------------------------------
/data/readme_visualizations/fps_metric.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/fps_metric.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/grads.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/grads.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/huber_loss.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/huber_loss.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/pong.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/pong.jpg
--------------------------------------------------------------------------------
/data/readme_visualizations/rewards_per_episode.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/rewards_per_episode.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/state_all_frames.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/state_all_frames.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/state_initial.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/state_initial.PNG
--------------------------------------------------------------------------------
/data/readme_visualizations/steps_per_episode.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-learn-reinforcement-learning/26dd439e73bb804b2065969caa5fa5429becfdd5/data/readme_visualizations/steps_per_episode.PNG
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: pytorch-rl-env
channels:
  - defaults
  - pytorch
dependencies:
  - python==3.8.5
  - pip==21.0.1
  - pytorch==1.8.1
  - pip:
    - matplotlib==3.3.3
    - GitPython==3.1.2
    - psutil==5.8.0
    - stable-baselines3==1.0
    - opencv-python==4.5.1.48
    - imageio==2.9.0
    - jupyter==1.0.0
    - numpy==1.19.2
    - tensorboard==2.2.2
    - gym==0.17.3

--------------------------------------------------------------------------------
/evaluate_DQN_script.py:
--------------------------------------------------------------------------------
import os
import shutil


import torch
import matplotlib.pyplot as plt
import cv2 as cv
import numpy as np


import utils.utils as utils
from models.definitions.DQN import DQN
from utils.constants import *
from utils.replay_buffer import ReplayBuffer
from utils.video_utils import create_gif


if __name__ == '__main__':
    # Step 0: Modify these as needed
    buffer_size = 100000
    epsilon_eval = 0.05
    env_id = 'BreakoutNoFrameskip-v4'
    model_name = 'dqn_BreakoutNoFrameskip-v4_ckpt_steps_6810000.pth'
    should_record_video = True

    game_frames_dump_dir = os.path.join(DATA_DIR_PATH, 'dqn_eval_dump_dir')
    if os.path.exists(game_frames_dump_dir):
        shutil.rmtree(game_frames_dump_dir)
    os.makedirs(game_frames_dump_dir, exist_ok=True)

    # Step 1: Prepare environment, replay buffer and schedule
    env = utils.get_env_wrapper(env_id, record_video=should_record_video)
    replay_buffer = ReplayBuffer(buffer_size)
    const_schedule = utils.ConstSchedule(epsilon_eval)  # lambda would also do - doing it like this for consistency

    # Step 2: Prepare the DQN model
    model_path = os.path.join(BINARIES_PATH, model_name)
    # map_location='cpu' lets GPU-trained checkpoints load on CPU-only machines (load_state_dict copies onto the device)
    model_state = torch.load(model_path, map_location='cpu')
    assert model_state['env_id'] == env_id, \
        f"Model {model_name} was trained on {model_state['env_id']} but you're running it on {env_id}."
    utils.print_model_metadata(model_state)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dqn = DQN(env, number_of_actions=env.action_space.n, epsilon_schedule=const_schedule).to(device)
    dqn.load_state_dict(model_state["state_dict"], strict=True)
    dqn.eval()

    # Step 3: Evaluate the agent on a single episode
    print(f'{"*"*10} Starting the game. {"*"*10}')
    last_frame = env.reset()

    score = 0
    cnt = 0
    while True:
        replay_buffer.store_frame(last_frame)
        current_state = replay_buffer.fetch_last_state()  # fetch the state, shape = (1, 4, 84, 84) for Atari

        with torch.no_grad():
            action = dqn.epsilon_greedy(current_state)  # act in this state

        new_frame, reward, done, _ = env.step(action)  # send the action to the environment
        score += reward

        env.render()  # plot the current game frame
        screen = env.render(mode='rgb_array')  # but also save it as an image
        # cv expects BGR, hence the ::-1 channel flip; the frame is also upscaled 1.5x
        processed_screen = cv.resize(screen[:, :, ::-1], (0, 0), fx=1.5, fy=1.5, interpolation=cv.INTER_NEAREST)
        cv.imwrite(os.path.join(game_frames_dump_dir, f'{str(cnt).zfill(5)}.jpg'), processed_screen)
        cnt += 1

        if done:
            print(f'Episode over, score = {score}.')
            break

        last_frame = new_frame  # set the last frame to the newly acquired frame from the env

    create_gif(game_frames_dump_dir, os.path.join(DATA_DIR_PATH, f'{env_id}.gif'), fps=60)

--------------------------------------------------------------------------------
/models/definitions/DQN.py:
--------------------------------------------------------------------------------
from torch import nn
import torch
import numpy as np


from utils.utils import get_env_wrapper


class DQN(nn.Module):
    """
    I wrote the architecture a bit more generically, hence more lines of code,
    but it's more flexible if you want to experiment with the DQN architecture.

    """
    def __init__(self, env, num_in_channels=4, number_of_actions=3, epsilon_schedule=None):
        super().__init__()
        self.env = env
        self.epsilon_schedule = epsilon_schedule  # defines the annealing strategy for epsilon in epsilon-greedy
        self.num_calls_to_epsilon_greedy = 0  # counts the number of calls to epsilon greedy function

        #
        # CNN params - from the Nature DQN paper - MODIFY this part if you want to experiment
        #
        num_of_filters_cnn = [num_in_channels, 32, 64, 64]
        kernel_sizes = [8, 4, 3]
        strides = [4, 2, 1]

        #
        # Build CNN part of DQN
        #
        cnn_modules = []
        for i in range(len(num_of_filters_cnn) - 1):
            cnn_modules.extend(
                self._cnn_block(num_of_filters_cnn[i], num_of_filters_cnn[i + 1], kernel_sizes[i], strides[i])
            )

        self.cnn_part = nn.Sequential(
            *cnn_modules,
            nn.Flatten()  # flatten from (B, C, H, W) into (B, C*H*W), where B is batch size and C number of in channels
        )

        #
        # Build fully-connected part of DQN
        #
        with torch.no_grad():  # automatically figure out the shape for the given env observation
            # shape = (1, C', H, W), where C' is originally 1, i.e. grayscale frames
            # cast uint8 -> float up front so the conv layers also work when no channel repeat is needed
            dummy_input = torch.from_numpy(env.observation_space.sample()[np.newaxis]).float()

            if dummy_input.shape[1] != num_in_channels:
                assert num_in_channels % dummy_input.shape[1] == 0
                # shape = (1, C, H, W)
                dummy_input = dummy_input.repeat(1, num_in_channels // dummy_input.shape[1], 1, 1)

            num_nodes_fc1 = self.cnn_part(dummy_input).shape[1]  # cnn output shape = (B, C*H*W)
            print(f"DQN's first FC layer input dimension: {num_nodes_fc1}")

        #
        # FC params - MODIFY this part if you want to experiment
        #
        num_of_neurons_fc = [num_nodes_fc1, 512, number_of_actions]

        fc_modules = []
        for i in range(len(num_of_neurons_fc) - 1):
            # i goes up to len - 2, which indexes the last layer; it shouldn't have an activation (Q-value is unbounded)
            last_layer = i == len(num_of_neurons_fc) - 2
            fc_modules.extend(self._fc_block(num_of_neurons_fc[i], num_of_neurons_fc[i + 1], use_relu=not last_layer))

        self.fc_part = nn.Sequential(
            *fc_modules
        )

    def forward(self, states):
        # shape: (B, C, H, W) -> (B, NA) - where NA is the Number of Actions
        return self.fc_part(self.cnn_part(states))

    def epsilon_greedy(self, state):
        assert self.epsilon_schedule is not None, "No schedule provided, can't call epsilon_greedy function"
        assert state.shape[0] == 1, 'Agent can only act on a single state'
        self.num_calls_to_epsilon_greedy += 1

        # Epsilon-greedy exploration
        if np.random.rand() < self.epsilon_value():
            # With epsilon probability act randomly
            action = self.env.action_space.sample()
        else:
            # Otherwise act greedily - choosing an action that maximizes Q
            # Shape evolution: (1, C, H, W) -> (forward) (1, NA) -> (argmax) (1,) -> [0] scalar -> Python int
            action = self.forward(state).argmax(dim=1)[0].item()

        return action

    def epsilon_value(self):
        return self.epsilon_schedule(self.num_calls_to_epsilon_greedy)

    #
    # Helper/"private" functions
    #

    # The original CNN didn't use any padding: https://github.com/deepmind/dqn/blob/master/dqn/convnet.lua
    # Not that it matters - it would probably work either way, so feel free to experiment with the architecture.
    def _cnn_block(self, num_in_filters, num_out_filters, kernel_size, stride):
        layers = [nn.Conv2d(num_in_filters, num_out_filters, kernel_size=kernel_size, stride=stride), nn.ReLU()]
        return layers

    def _fc_block(self, num_in_neurons, num_out_neurons, use_relu=True):
        layers = [nn.Linear(num_in_neurons, num_out_neurons)]
        if use_relu:
            layers.append(nn.ReLU())
        return layers


# Test DQN network
if __name__ == '__main__':
    # NoFrameskip - receive every frame from the env, whereas the plain v4 version repeats each action for a random 2-4 frames
    # v4 - actions we send to env are executed, whereas v0 would ignore the last action we sent with 0.25 probability
    env_id = "PongNoFrameskip-v4"
    env_wrapped = get_env_wrapper(env_id)
    dqn = DQN(env_wrapped)  # testing only the __init__ function (mainly the automatic shape calculation mechanism)


--------------------------------------------------------------------------------
/playground.py:
--------------------------------------------------------------------------------
import os


from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import DQN


def run_dqn_baseline():
    env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=1, seed=0)
    env = VecFrameStack(env, n_stack=4)
    tensorboard_log = os.path.join(os.path.dirname(__file__), 'runs_baseline')
    buffer_size = 100000
    num_training_steps = 1000000

    model = DQN(
        'CnnPolicy',
        env,
        verbose=0,
        buffer_size=buffer_size,
        learning_starts=50000,
        optimize_memory_usage=False,
        tensorboard_log=tensorboard_log
    )
    model.learn(total_timesteps=num_training_steps)

    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()


if __name__ == '__main__':
    run_dqn_baseline()

--------------------------------------------------------------------------------
/train_DQN_script.py:
--------------------------------------------------------------------------------
"""
    Implementation of the original DQN Nature paper:
    https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf

    Some of the complexity is captured via wrappers, but the main components, such as the DQN model itself,
    the training loop and the memory-efficient replay buffer, are implemented from scratch.

    Some modifications:
        * Using Adam instead of RMSProp

"""

import os
import argparse
import time
import copy


import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter


import utils.utils as utils
from utils.replay_buffer import ReplayBuffer
from utils.constants import *
from models.definitions.DQN import DQN


class ActorLearner:

    def __init__(self, config, env, replay_buffer, dqn, target_dqn, last_frame):

        self.start_time = time.time()

        self.config = config
        self.env = env
        self.last_frame = last_frame  # always keeps the latest frame from the environment
        self.replay_buffer = replay_buffer

        # DQN Models
        self.dqn = dqn
        self.target_dqn = target_dqn

        # Logging/debugging-related
        self.debug = config['debug']
        self.log_freq = config['log_freq']
        self.episode_log_freq = config['episode_log_freq']
        self.grads_log_freq = config['grads_log_freq']
        self.checkpoint_freq = config['checkpoint_freq']
        self.tensorboard_writer = SummaryWriter()
        self.huber_loss = []
        self.best_episode_reward = -np.inf
        self.best_dqn_model = None  # keeps a deep copy of the best DQN model so far (best = highest episode reward)

        # MSE/L2 between [-1,1] and L1 otherwise (as stated in the Nature paper) aka "Huber loss"
        self.loss = nn.SmoothL1Loss()
        self.optimizer = Adam(self.dqn.parameters(), lr=config['learning_rate'])
        self.grad_clip_value = config['grad_clipping_value']

        self.acting_learning_step_ratio = config['acting_learning_step_ratio']
        self.num_warmup_steps = config['num_warmup_steps']
        self.batch_size = config['batch_size']
        self.gamma = config['gamma']  # discount factor

        self.learner_cnt = 0
        self.target_dqn_update_interval = config['target_dqn_update_interval']
        # should perform a hard or a soft update of target DQN weights
        self.tau = config['tau']

    def collect_experience(self):
        # We're collecting more experience than we're doing weight updates (4x in the Nature paper)
        for _ in range(self.acting_learning_step_ratio):
            last_index = self.replay_buffer.store_frame(self.last_frame)
            state = self.replay_buffer.fetch_last_state()  # state = 4 preprocessed last frames for Atari

            action = self.sample_action(state)
            new_frame, reward, done_flag, _ = self.env.step(action)

            self.replay_buffer.store_action_reward_done(last_index, action, reward, done_flag)

            if done_flag:
                new_frame = self.env.reset()
                self.maybe_log_episode()

            self.last_frame = new_frame

            if self.debug:
                self.visualize_state(state)
                self.env.render()

            self.maybe_log()

    def sample_action(self, state):
        if self.env.get_total_steps() < self.num_warmup_steps:
            action = self.env.action_space.sample()  # initial warm up period - no learning, acting randomly
        else:
            with torch.no_grad():
                action = self.dqn.epsilon_greedy(state)
        return action

    def get_number_of_env_steps(self):
        return self.env.get_total_steps()

    def learn_from_experience(self):
        current_states, actions, rewards, next_states, done_flags = self.replay_buffer.fetch_random_states(self.batch_size)

        # Better than detaching: in addition to target dqn not being a part of the computational graph it also
        # saves time/memory because we're not storing activations during forward propagation needed for the backprop
        with torch.no_grad():
            # shape = (B, NA) -> (B, 1), where NA - number of actions
            # [0] because max returns (values, indices) tuples
            next_state_max_q_values = self.target_dqn(next_states).max(dim=1, keepdim=True)[0]

            # shape = (B, 1), TD targets. We need (1 - done) because when we're in a terminal state the next
            # state Q value should be 0 and we only use the reward information
            target_q_values = rewards + (1 - done_flags) * self.gamma * next_state_max_q_values

        # shape = (B, 1), pick those Q values that correspond to the actions we made in those states
        current_state_q_values = self.dqn(current_states).gather(dim=1, index=actions)

        loss = self.loss(current_state_q_values, target_q_values)  # PyTorch convention: loss(input, target)
        self.huber_loss.append(loss.item())

        self.optimizer.zero_grad()
        loss.backward()  # compute the gradients

        if self.grad_clip_value is not None:  # potentially clip gradients for stability reasons
            nn.utils.clip_grad_norm_(self.dqn.parameters(), self.grad_clip_value)

        self.optimizer.step()  # update step
        self.learner_cnt += 1

        # Periodically update the target DQN weights (coupled to the number of DQN weight updates and not # env steps)
        if self.learner_cnt % self.target_dqn_update_interval == 0:
            if self.tau == 1.:
print('Update target DQN (hard update)') 141 | self.target_dqn.load_state_dict(self.dqn.state_dict()) 142 | else: # soft update, the 2 branches can be merged together, leaving it like this for now 143 | raise Exception(f'Soft update is not yet implemented (hard update was used in the original paper)') 144 | 145 | @staticmethod 146 | def visualize_state(state): 147 | state = state[0].to('cpu').numpy() # (1/B, C, H, W) -> (C, H, W) 148 | stacked_frames = np.hstack([np.repeat((img * 255).astype(np.uint8)[:, :, np.newaxis], 3, axis=2) for img in state]) # (C, H, W) -> (H, C*W, 3) 149 | plt.imshow(stacked_frames) 150 | plt.show() 151 | 152 | def maybe_log_episode(self): 153 | rewards = self.env.get_episode_rewards() # we can do this thanks to the Monitor wrapper 154 | episode_lengths = self.env.get_episode_lengths() 155 | num_episodes = len(rewards) 156 | 157 | if self.episode_log_freq is not None and num_episodes % self.episode_log_freq == 0: 158 | self.tensorboard_writer.add_scalar('Rewards per episode', rewards[-1], num_episodes) 159 | self.tensorboard_writer.add_scalar('Steps per episode', episode_lengths[-1], num_episodes) 160 | 161 | if rewards[-1] > self.best_episode_reward: 162 | self.best_episode_reward = rewards[-1] 163 | self.config['best_episode_reward'] = self.best_episode_reward # metadata 164 | self.best_dqn_model = copy.deepcopy(self.dqn) # keep track of the model that gave the best reward 165 | 166 | def maybe_log(self): 167 | num_steps = self.env.get_total_steps() 168 | 169 | if self.log_freq is not None and num_steps > 0 and num_steps % self.log_freq == 0: 170 | self.tensorboard_writer.add_scalar('Epsilon', self.dqn.epsilon_value(), num_steps) 171 | if len(self.huber_loss) > 0: 172 | self.tensorboard_writer.add_scalar('Huber loss', np.mean(self.huber_loss), num_steps) 173 | self.tensorboard_writer.add_scalar('FPS', num_steps / (time.time() - self.start_time), num_steps) 174 | 175 | self.huber_loss = [] # clear the loss values and start recollecting 
them again 176 | 177 | # Periodically save DQN models 178 | if self.checkpoint_freq is not None and num_steps > 0 and num_steps % self.checkpoint_freq == 0: 179 | ckpt_model_name = f'dqn_{self.config["env_id"]}_ckpt_steps_{num_steps}.pth' 180 | torch.save(utils.get_training_state(self.config, self.dqn), os.path.join(CHECKPOINTS_PATH, ckpt_model_name)) 181 | 182 | # Log the gradients 183 | if self.grads_log_freq is not None and self.learner_cnt > 0 and self.learner_cnt % self.grads_log_freq == 0: 184 | total_grad_l2_norm = 0 185 | 186 | for cnt, (name, weight_or_bias_parameters) in enumerate(self.dqn.named_parameters()): 187 | grad_l2_norm = weight_or_bias_parameters.grad.data.norm(p=2).item() 188 | self.tensorboard_writer.add_scalar(f'grad_norms/{name}', grad_l2_norm, self.learner_cnt) 189 | total_grad_l2_norm += grad_l2_norm ** 2 190 | 191 | # As if we concatenated all of the params into a single vector and took L2 192 | total_grad_l2_norm = total_grad_l2_norm ** (1/2) 193 | self.tensorboard_writer.add_scalar(f'grad_norms/total', total_grad_l2_norm, self.learner_cnt) 194 | 195 | def log_to_console(self): # keep it minimal for now, I mostly use tensorboard - feel free to expand functionality 196 | print(f'Number of env steps = {self.get_number_of_env_steps()}') 197 | 198 | 199 | def train_dqn(config): 200 | env = utils.get_env_wrapper(config['env_id']) 201 | replay_buffer = ReplayBuffer(config['replay_buffer_size'], crash_if_no_mem=config['dont_crash_if_no_mem']) 202 | 203 | utils.set_random_seeds(env, config['seed']) 204 | 205 | linear_schedule = utils.LinearSchedule( 206 | config['epsilon_start_value'], 207 | config['epsilon_end_value'], 208 | config['epsilon_duration'] 209 | ) 210 | 211 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 212 | dqn = DQN(env, number_of_actions=env.action_space.n, epsilon_schedule=linear_schedule).to(device) 213 | target_dqn = DQN(env, number_of_actions=env.action_space.n).to(device) 214 | 215 | # Don't get 
confused by the actor-learner terminology, DQN is not an actor-critic method, but conceptually 216 | # we can split the learning process into collecting experience/acting in the env and learning from that experience 217 | actor_learner = ActorLearner(config, env, replay_buffer, dqn, target_dqn, env.reset()) 218 | 219 | while actor_learner.get_number_of_env_steps() < config['num_of_training_steps']: 220 | 221 | num_env_steps = actor_learner.get_number_of_env_steps() 222 | if config['console_log_freq'] is not None and num_env_steps % config['console_log_freq'] == 0: 223 | actor_learner.log_to_console() 224 | 225 | actor_learner.collect_experience() 226 | 227 | if num_env_steps > config['num_warmup_steps']: 228 | actor_learner.learn_from_experience() 229 | 230 | torch.save( # save the best DQN model overall (gave the highest reward in an episode) 231 | utils.get_training_state(config, actor_learner.best_dqn_model), 232 | os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['env_id'])) 233 | ) 234 | 235 | 236 | def get_training_args(): 237 | parser = argparse.ArgumentParser() 238 | 239 | # Training related 240 | parser.add_argument("--seed", type=int, help="Very important for reproducibility - set the random seed", default=23) 241 | parser.add_argument("--env_id", type=str, help="Atari game id", default='BreakoutNoFrameskip-v4') 242 | parser.add_argument("--num_of_training_steps", type=int, help="Number of training env steps", default=50000000) 243 | parser.add_argument("--acting_learning_step_ratio", type=int, help="Number of experience collection steps for every learning step", default=4) 244 | parser.add_argument("--learning_rate", type=float, default=1e-4) 245 | parser.add_argument("--grad_clipping_value", type=float, default=5) # 5 is fairly arbitrarily chosen 246 | 247 | parser.add_argument("--replay_buffer_size", type=int, help="Number of frames to store in buffer", default=1000000) 248 | parser.add_argument("--dont_crash_if_no_mem", 
action='store_false', help="Pass this flag to only print a warning (instead of crashing before the training even starts) if there isn't enough RAM for the replay buffer") 249 | parser.add_argument("--num_warmup_steps", type=int, help="Number of env steps (acting randomly) before learning starts", default=50000) 250 | parser.add_argument("--target_dqn_update_interval", type=int, help="Update the target DQN after this many learning (weight update) steps", default=10000) 251 | 252 | parser.add_argument("--batch_size", type=int, help="Number of states in a batch (from replay buffer)", default=32) 253 | parser.add_argument("--gamma", type=float, help="Discount factor", default=0.99) 254 | parser.add_argument("--tau", type=float, help='Set to 1 for a hard target DQN update, < 1 for a soft one', default=1.) 255 | 256 | # epsilon-greedy annealing params 257 | parser.add_argument("--epsilon_start_value", type=float, default=1.) 258 | parser.add_argument("--epsilon_end_value", type=float, default=0.1) 259 | parser.add_argument("--epsilon_duration", type=int, default=1000000) 260 | 261 | # Logging/debugging/checkpoint related (helps a lot with experimentation) 262 | parser.add_argument("--console_log_freq", type=int, help="Log to console after this many env steps (None = no logging)", default=10000) 263 | parser.add_argument("--log_freq", type=int, help="Log metrics to tensorboard after this many env steps (None = no logging)", default=10000) 264 | parser.add_argument("--episode_log_freq", type=int, help="Log metrics to tensorboard after this many episodes (None = no logging)", default=5) 265 | parser.add_argument("--checkpoint_freq", type=int, help="Save checkpoint model after this many env steps (None = no checkpointing)", default=10000) 266 | parser.add_argument("--grads_log_freq", type=int, help="Log grad norms after this many weight update steps (None = no logging)", default=2500) 267 | parser.add_argument("--debug", action='store_true', help='Train in debugging mode') 268 | args = parser.parse_args() 269 | 270 | # Wrapping training configuration into a dictionary 271 |
training_config = dict() 272 | for arg in vars(args): 273 | training_config[arg] = getattr(args, arg) 274 | 275 | return training_config 276 | 277 | 278 | if __name__ == '__main__': 279 | # Train the DQN model 280 | train_dqn(get_training_args()) 281 | 282 | 283 | -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | BINARIES_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'models', 'binaries') 5 | CHECKPOINTS_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'models', 'checkpoints') 6 | DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'data') 7 | 8 | 9 | # Make sure these exist as the rest of the code assumes it 10 | os.makedirs(BINARIES_PATH, exist_ok=True) 11 | os.makedirs(CHECKPOINTS_PATH, exist_ok=True) 12 | -------------------------------------------------------------------------------- /utils/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | import numpy as np 5 | import psutil 6 | import torch 7 | 8 | 9 | from utils.utils import get_env_wrapper 10 | 11 | 12 | class ReplayBuffer: 13 | """ 14 | Since stable baselines 3 doesn't currently support a smart replay buffer (more concretely the "lazy frames" feature) 15 | i.e. allocating (10^6, 84, 84) (~7 GB) for Atari and extracting 4 frames as needed, instead of (10^6, 4, 84, 84), 16 | here is an efficient implementation. 
17 | 18 | Note: inspired by Berkeley's implementation: https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3 19 | 20 | Further improvements: 21 | * Much more concise (and hopefully readable) 22 | * Reports an error in advance if you don't have enough RAM to allocate this buffer 23 | * Fixed a subtle buffer boundary bug (start index edge case) 24 | 25 | """ 26 | def __init__(self, max_buffer_size, num_last_frames_to_fetch=4, frame_shape=[1, 84, 84], crash_if_no_mem=True): 27 | self.max_buffer_size = max_buffer_size 28 | self.current_buffer_size = 0 29 | self.current_free_slot_index = 0 30 | 31 | assert frame_shape[0] in (1, 3), f'Expected mono/color image frame got shape={frame_shape}.' 32 | self.frame_height = frame_shape[1] 33 | self.frame_width = frame_shape[2] 34 | self.num_previous_frames_to_fetch = num_last_frames_to_fetch 35 | 36 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | # Create main buffer containers - all types are chosen so as to optimize memory consumption 39 | self.frames = np.zeros([self.max_buffer_size] + frame_shape, dtype=np.uint8) 40 | self.actions = np.zeros([self.max_buffer_size, 1], dtype=np.uint8) 41 | self.rewards = np.zeros([self.max_buffer_size, 1], dtype=np.float32) # we need extra precision for rewards 42 | self.dones = np.zeros([self.max_buffer_size, 1], dtype=np.uint8) 43 | 44 | # The OS typically allocates np.zeros memory lazily (pages only get backed by RAM once written to), so it can happen that 45 | # the training hits the RAM limit only after a while, then page swapping kicks in and slows it down significantly, possibly after hours of training!
46 | # hence the _check_enough_ram function (I'm saving you time and money thank me later ^^) 47 | self._check_enough_ram(crash_if_no_mem) 48 | 49 | # 50 | # public API functions 51 | # 52 | 53 | def store_frame(self, frame): 54 | self.frames[self.current_free_slot_index] = frame 55 | 56 | self.current_free_slot_index = (self.current_free_slot_index + 1) % self.max_buffer_size # circular logic 57 | self.current_buffer_size = min(self.max_buffer_size, self.current_buffer_size + 1) 58 | 59 | return (self.current_free_slot_index - 1) % self.max_buffer_size # we still need to store (action, reward, done) at this index 60 | 61 | def store_action_reward_done(self, index, action, reward, done): 62 | self.actions[index] = action 63 | self.rewards[index] = reward 64 | self.dones[index] = done 65 | 66 | def fetch_random_states(self, batch_size): 67 | assert self._has_enough_data(batch_size), "Can't fetch states from the replay buffer - not enough data." 68 | # Uniform random sampling without replacement. -1 because for Q-learning we need both the current and the next state, 69 | # but the newest frame has no next frame yet (once the buffer wraps around it sits right before the head pointer) 70 | random_unique_indices = [(self.current_free_slot_index + offset) % self.current_buffer_size for offset in random.sample(range(self.current_buffer_size - 1), batch_size)] 71 | 72 | states = self._postprocess_state( 73 | np.concatenate([self._fetch_state(i) for i in random_unique_indices], 0) # shape = (B, C, H, W) 74 | ) 75 | next_states = self._postprocess_state( 76 | np.concatenate([self._fetch_state((i + 1) % self.max_buffer_size) for i in random_unique_indices], 0) # shape = (B, C, H, W) 77 | ) 78 | # Long is needed because actions are used for indexing of tensors (PyTorch constraint) 79 | actions = torch.from_numpy(self.actions[random_unique_indices]).to(self.device).long() 80 | rewards = torch.from_numpy(self.rewards[random_unique_indices]).to(self.device) 81 | # Float is needed because we'll be multiplying Q values with done flags (1-done actually) 82 | dones = torch.from_numpy(self.dones[random_unique_indices]).to(self.device).float() 83 |
84 | return states, actions, rewards, next_states, dones 85 | 86 | def fetch_last_state(self): 87 | # shape = (1, C, H, W) where C - number of past frames, 4 for Atari 88 | return self._postprocess_state( 89 | self._fetch_state((self.current_free_slot_index - 1) % self.max_buffer_size) 90 | ) 91 | 92 | def get_current_size(self): 93 | return self.current_buffer_size 94 | 95 | # 96 | # Helper functions 97 | # 98 | 99 | def _fetch_state(self, end_index): 100 | """ 101 | We fetch the frame at end_index and the ("num_last_frames_to_fetch" - 1) frames preceding it (4 in total in the case of Atari) 102 | in order to generate a state. 103 | 104 | The replay buffer has 2 kinds of edge cases that we need to handle: 105 | 1) start_index related: 106 | * index is "too close"* to 0 and our circular buffer is still not full, thus we don't have enough frames 107 | * index is "too close" to the buffer's head pointer (once the buffer is full), so we could mix very old and very new observations 108 | 109 | 2) done flag is True - we don't want to take observations at or before that index since they belong to a different 110 | life or episode.
111 | 112 | Notes: 113 | * "too close" is defined by the 'num_last_frames_to_fetch' argument 114 | * terminology: a state consists of multiple observations (frames in the Atari case) 115 | 116 | """ 117 | # Start index is included, end index is excluded <=> [) 118 | end_index += 1 119 | start_index = end_index - self.num_previous_frames_to_fetch 120 | start_index = self._handle_start_index_edge_cases(start_index, end_index) 121 | 122 | num_of_missing_frames = self.num_previous_frames_to_fetch - (end_index - start_index) 123 | 124 | if start_index < 0 or num_of_missing_frames > 0: # start_index:end_index indexing won't work if start_index < 0 125 | # If frames are missing because of the edge cases handled above, fill them with black frames as in the 126 | # original DeepMind Lua implementation: https://github.com/deepmind/dqn/blob/master/dqn/TransitionTable.lua#L171 127 | state = [np.zeros_like(self.frames[0]) for _ in range(num_of_missing_frames)] 128 | 129 | for index in range(start_index, end_index): 130 | state.append(self.frames[index % self.max_buffer_size]) 131 | 132 | # shape = (C, H, W) -> (1, C, H, W) where C - number of past frames, 4 for Atari 133 | return np.concatenate(state, 0)[np.newaxis, :] 134 | else: 135 | # reshape from (C, 1, H, W) to (1, C, H, W) where C - number of past frames, 4 for Atari 136 | return self.frames[start_index:end_index].reshape(-1, self.frame_height, self.frame_width)[np.newaxis, :] 137 | 138 | def _postprocess_state(self, state): 139 | # numpy -> tensor, move to device, uint8 -> float, [0,255] -> [0, 1] 140 | return torch.from_numpy(state).to(self.device).float().div(255) 141 | 142 | def _handle_start_index_edge_cases(self, start_index, end_index): 143 | # Edge case 1: 144 | if not self._buffer_full() and start_index < 0: 145 | start_index = 0 146 | 147 | # Edge case 2: 148 | # Handle the case where start index crosses the buffer head pointer - the data before and after the head pointer 149 | # belongs to completely different episodes 150 |
if self._buffer_full(): 151 | if 0 < (self.current_free_slot_index - start_index) % self.max_buffer_size < self.num_previous_frames_to_fetch: 152 | start_index = self.current_free_slot_index 153 | 154 | # Edge case 3: 155 | # A done flag marks a boundary between different episodes or lives; either way, we shouldn't take frames 156 | # at or before the done flag into consideration 157 | for index in range(start_index, end_index - 1): 158 | if self.dones[index % self.max_buffer_size]: 159 | start_index = index + 1 160 | 161 | return start_index 162 | 163 | def _buffer_full(self): 164 | return self.current_buffer_size == self.max_buffer_size 165 | 166 | def _has_enough_data(self, batch_size): 167 | return batch_size < self.current_buffer_size # e.g. if batch size is 32 we need at least 33 frames (the newest frame has no next state), hence < 168 | 169 | def _check_enough_ram(self, crash_if_no_mem): 170 | def to_GBs(memory_in_bytes): # beautify memory output - helper function 171 | return f'{memory_in_bytes / 2 ** 30:.2f} GBs' 172 | 173 | available_memory = psutil.virtual_memory().available 174 | required_memory = self.frames.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes 175 | print(f'required memory = {to_GBs(required_memory)}, available memory = {to_GBs(available_memory)}') 176 | 177 | if required_memory > available_memory: 178 | message = f"Not enough memory to store the complete replay buffer! \n" \ 179 | f"required: {to_GBs(required_memory)} > available: {to_GBs(available_memory)} \n" \ 180 | f"Page swapping will make your training super slow once you hit your RAM limit. \n" \ 181 | f"You can either decrease the replay_buffer_size argument or set crash_if_no_mem to False to ignore this check."
182 | if crash_if_no_mem: 183 | raise Exception(message) 184 | else: 185 | print(message) 186 | 187 | 188 | # Basic replay buffer testing 189 | if __name__ == '__main__': 190 | size = 500000 191 | num_of_collection_steps = 10000 192 | batch_size = 32 193 | 194 | # Step 0: Create replay buffer and the env 195 | replay_buffer = ReplayBuffer(size) 196 | 197 | # NoFrameskip - receive every frame from the env, whereas versions without NoFrameskip skip frames internally (a random 2-4 frames per step, or a fixed 4 for the Deterministic versions) 198 | # v4 - actions we send to env are always executed, whereas v0 would, with 0.25 probability, ignore the action we sent and repeat the previous one 199 | env_id = "PongNoFrameskip-v4" 200 | env = get_env_wrapper(env_id) 201 | 202 | # Step 1: Collect experience 203 | frame = env.reset() 204 | 205 | for i in range(num_of_collection_steps): 206 | random_action = env.action_space.sample() 207 | 208 | # For some reason, gym returns more than 3 actions for Pong. 209 | print(f'Sampling action {random_action} - {env.unwrapped.get_action_meanings()[random_action]}') 210 | 211 | index = replay_buffer.store_frame(frame) # store the frame we're acting on (same order as in train_DQN_script.py) 212 | 213 | frame, reward, done, info = env.step(random_action) 214 | replay_buffer.store_action_reward_done(index, random_action, reward, done) 215 | 216 | if done: 217 | frame = env.reset() 218 | 219 | # Step 2: Fetch states from the buffer 220 | states, actions, rewards, next_states, dones = replay_buffer.fetch_random_states(batch_size) 221 | 222 | print(states.shape, next_states.shape, actions.shape, rewards.shape, dones.shape) 223 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | 5 | 6 | import torch 7 | import git 8 | import gym 9 | import numpy as np 10 | from stable_baselines3.common.atari_wrappers import AtariWrapper 11 | from gym.wrappers import Monitor 12 | 13 | 14 | from .constants import * 15 | 16 | 17 | def get_env_wrapper(env_id,
record_video=False): 18 | """ 19 | It's not entirely clear why SB3's and OpenAI gym's Atari wrappers are, for the most part, copy/pasted code. 20 | It seems that OpenAI gym's wrappers don't have reward clipping, which is necessary for Atari. 21 | 22 | I'm using SB3 because it's more actively maintained than OpenAI's gym and it has reward clipping by default. 23 | 24 | """ 25 | monitor_dump_dir = os.path.join(os.path.dirname(__file__), os.pardir, 'gym_monitor') 26 | 27 | # This is necessary because AtariWrapper skips 4 frames by default, so we can't have additional skipping through 28 | # the environment itself - hence the NoFrameskip requirement 29 | assert 'NoFrameskip' in env_id, f'Expected NoFrameskip environment got {env_id}' 30 | 31 | # The only additional thing needed, on top of AtariWrapper, 32 | # is to convert the shape to channel-first because PyTorch's conv layers expect channel-first inputs 33 | env_wrapped = Monitor(ChannelFirst(AtariWrapper(gym.make(env_id))), monitor_dump_dir, force=True, video_callable=lambda episode: record_video) 34 | 35 | return env_wrapped 36 | 37 | 38 | class ChannelFirst(gym.ObservationWrapper): 39 | def __init__(self, env): 40 | super().__init__(env) 41 | new_shape = np.roll(self.observation_space.shape, shift=1) # shape: (H, W, C) -> (C, H, W) 42 | 43 | # Update the observation space because this is the last wrapper that modifies observations - we'll be polling the env for observation shape info 44 | self.observation_space = gym.spaces.Box(low=0, high=255, shape=new_shape, dtype=np.uint8) 45 | 46 | def observation(self, observation): 47 | return np.moveaxis(observation, 2, 0) # shape: (H, W, C) -> (C, H, W) 48 | 49 | 50 | class LinearSchedule: 51 | 52 | def __init__(self, schedule_start_value, schedule_end_value, schedule_duration): 53 | self.start_value = schedule_start_value 54 | self.end_value = schedule_end_value 55 | self.schedule_duration = schedule_duration 56 | 57 | def __call__(self, num_steps): 58 | progress = np.clip(num_steps / self.schedule_duration, a_min=None, a_max=1) # goes from 0
-> 1 and saturates 59 | return self.start_value + (self.end_value - self.start_value) * progress 60 | 61 | 62 | class ConstSchedule: 63 | """ Dummy schedule - used for DQN evaluation in evaluate_DQN_script.py. """ 64 | def __init__(self, value): 65 | self.value = value 66 | 67 | def __call__(self, num_steps): 68 | return self.value 69 | 70 | 71 | def print_model_metadata(training_state): 72 | header = f'\n{"*"*5} DQN model training metadata: {"*"*5}' 73 | print(header) 74 | 75 | for key, value in training_state.items(): 76 | if key != 'state_dict': # don't print state_dict, it's just a bunch of numbers... 77 | print(f'{key}: {value}') 78 | print(f'{"*" * len(header)}\n') 79 | 80 | 81 | def get_training_state(training_config, model): 82 | training_state = { 83 | # Reproducibility details 84 | "commit_hash": git.Repo(search_parent_directories=True).head.object.hexsha, 85 | "seed": training_config['seed'], 86 | 87 | # Env details 88 | "env_id": training_config['env_id'], 89 | 90 | # Training details 91 | "best_episode_reward": training_config['best_episode_reward'], 92 | 93 | # Model state 94 | "state_dict": model.state_dict() 95 | } 96 | 97 | return training_state 98 | 99 | 100 | def get_available_binary_name(env_id='env_unknown'): 101 | prefix = f'dqn_{env_id}' 102 | 103 | def valid_binary_name(binary_name): 104 | # First time you see a raw f-string? Don't worry, the only trick is to double the curly braces.
105 | pattern = re.compile(rf'{prefix}_[0-9]{{6}}\.pth') 106 | return re.fullmatch(pattern, binary_name) is not None 107 | 108 | # List the existing binaries so that we don't overwrite any of them and write to a new one instead 109 | valid_binary_names = list(filter(valid_binary_name, os.listdir(BINARIES_PATH))) 110 | if len(valid_binary_names) > 0: 111 | last_binary_name = sorted(valid_binary_names)[-1] 112 | new_suffix = int(last_binary_name.split('.')[0][-6:]) + 1 # increment by 1 113 | return f'{prefix}_{str(new_suffix).zfill(6)}.pth' 114 | else: 115 | return f'{prefix}_000000.pth' 116 | 117 | 118 | def set_random_seeds(env, seed): 119 | if seed is not None: 120 | torch.manual_seed(seed) # PyTorch 121 | np.random.seed(seed) # NumPy 122 | random.seed(seed) # Python 123 | env.action_space.seed(seed) # probably redundant, but I found an article where somebody had a problem with this 124 | env.seed(seed) # OpenAI gym 125 | 126 | # todo: A/B test the impact on the FPS metric 127 | # Deterministic operations for CuDNN - this may hurt performance 128 | if torch.cuda.is_available(): 129 | torch.backends.cudnn.deterministic = True 130 | torch.backends.cudnn.benchmark = False 131 | 132 | 133 | # Test utils 134 | if __name__ == '__main__': 135 | import matplotlib.pyplot as plt 136 | 137 | schedule = LinearSchedule(schedule_start_value=1., schedule_end_value=0.1, schedule_duration=50) 138 | 139 | schedule_values = [] 140 | for i in range(100): 141 | schedule_values.append(schedule(i)) 142 | 143 | plt.plot(schedule_values) 144 | plt.show() 145 | 146 | -------------------------------------------------------------------------------- /utils/video_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | import numpy as np 5 | import cv2 as cv 6 | import imageio 7 | 8 | 9 | def load_image(img_path, target_shape=None): 10 | if not os.path.exists(img_path): 11 | raise Exception(f'Path does not exist: {img_path}') 12 | img =
cv.imread(img_path)[:, :, ::-1] # [:, :, ::-1] converts BGR (opencv format...) into RGB 13 | 14 | if target_shape is not None: # resize section 15 | if isinstance(target_shape, int) and target_shape != -1: # scalar -> implicitly setting the width 16 | current_height, current_width = img.shape[:2] 17 | new_width = target_shape 18 | new_height = int(current_height * (new_width / current_width)) 19 | img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC) 20 | else: # set both dimensions to target shape 21 | img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC) 22 | 23 | # this needs to go after resizing - otherwise cv.resize will push values outside of [0,1] range 24 | img = img.astype(np.float32) # convert from uint8 to float32 25 | img /= 255.0 # get to [0, 1] range 26 | return img 27 | 28 | 29 | def create_gif(frames_dir, out_path, fps=30, img_width=None): 30 | assert os.path.splitext(out_path)[1].lower() == '.gif', f'Expected .gif got {os.path.splitext(out_path)[1]}.' 31 | 32 | frame_paths = [os.path.join(frames_dir, frame_name) for frame_name in sorted(os.listdir(frames_dir)) if frame_name.endswith('.jpg')] # sort - os.listdir order is arbitrary 33 | 34 | images = [imageio.imread(frame_path) if img_width is None else np.round(load_image(frame_path, target_shape=img_width) * 255).astype(np.uint8) for frame_path in frame_paths] # optionally resize to img_width 35 | imageio.mimwrite(out_path, images, fps=fps) 36 | print(f'Saved gif to {out_path}.') 37 | --------------------------------------------------------------------------------
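The TD target in `ActorLearner.learn_from_experience` (train_DQN_script.py) can be sanity-checked on a toy batch. This is a minimal NumPy sketch, not the repo's code: `td_targets` is a hypothetical helper, NumPy arrays stand in for the PyTorch tensors, and the Q values are made up. It computes `rewards + (1 - done_flags) * gamma * max_a Q_target(s', a)` and shows that a terminal transition keeps only its reward.

```python
import numpy as np

def td_targets(rewards, done_flags, next_q_values, gamma=0.99):
    """Hypothetical helper: DQN TD targets, shape (B, 1).

    rewards, done_flags: shape (B, 1)
    next_q_values: shape (B, NA) - target network Q values for every action in the next state
    """
    # Max over the action dimension; keepdims keeps the (B, 1) shape, like max(dim=1, keepdim=True)
    next_state_max_q = next_q_values.max(axis=1, keepdims=True)
    # (1 - done) zeroes out the bootstrapped term for terminal transitions
    return rewards + (1 - done_flags) * gamma * next_state_max_q

rewards = np.array([[1.0], [1.0]])
done_flags = np.array([[0.0], [1.0]])  # the second transition ends the episode
next_q_values = np.array([[0.5, 2.0],
                          [3.0, 1.0]])

targets = td_targets(rewards, done_flags, next_q_values)
print(targets.ravel())  # first target bootstraps from max Q = 2.0, second keeps only the reward
```

The first target is 1 + 0.99 * 2.0 = 2.98, while the second stays at 1.0 even though its next state has a larger max Q value, which is exactly why the script multiplies by `(1 - done_flags)`.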
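The `grad_norms/total` scalar logged in `maybe_log` (train_DQN_script.py) is the square root of the summed squared per-parameter norms, which the code comment describes as the L2 norm of all gradients concatenated into one vector. The NumPy sketch below checks that equivalence on random stand-in gradients; the parameter names and shapes are hypothetical. With the default `norm_type=2`, this is also the total norm that `nn.utils.clip_grad_norm_` compares against `--grad_clipping_value`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical gradients for a few parameters (names and shapes are made up)
grads = {
    "conv1.weight": rng.standard_normal((4, 3, 3)),
    "conv1.bias": rng.standard_normal(4),
    "fc.weight": rng.standard_normal((2, 8)),
}

# What maybe_log accumulates: square root of the summed squared per-parameter L2 norms
total_from_parts = sum(np.linalg.norm(g.ravel()) ** 2 for g in grads.values()) ** 0.5

# Reference: flatten every gradient, concatenate into one vector, take its L2 norm
total_from_concat = np.linalg.norm(np.concatenate([g.ravel() for g in grads.values()]))

print(total_from_parts, total_from_concat)
```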
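The memory estimate in the `ReplayBuffer` docstring (~7 GB for 10^6 frames of 84x84) follows from the dtypes chosen in `ReplayBuffer.__init__`. The pure-Python sketch below redoes that arithmetic. `replay_buffer_bytes` is a hypothetical helper, and `frames_per_slot=4` models the naive layout that stores the whole 4-frame state in every slot. It uses the same `2 ** 30` divisor as `to_GBs` in `_check_enough_ram`, so the printed numbers are GiB.

```python
GIB = 2 ** 30  # same divisor as to_GBs() in ReplayBuffer._check_enough_ram

def replay_buffer_bytes(num_frames, frame_shape=(1, 84, 84), frames_per_slot=1):
    """Hypothetical helper: bytes needed by the ReplayBuffer containers.

    frames_per_slot=1 -> lazy layout used by ReplayBuffer (one uint8 frame per slot)
    frames_per_slot=4 -> naive layout storing the full 4-frame state per slot
    """
    pixels_per_frame = 1
    for dim in frame_shape:
        pixels_per_frame *= dim
    frames = num_frames * frames_per_slot * pixels_per_frame  # uint8 -> 1 byte per pixel
    actions = num_frames * 1  # uint8
    rewards = num_frames * 4  # float32
    dones = num_frames * 1    # uint8
    return frames + actions + rewards + dones

lazy = replay_buffer_bytes(10 ** 6)
naive = replay_buffer_bytes(10 ** 6, frames_per_slot=4)
print(f"lazy: {lazy / GIB:.2f} GiB, naive: {naive / GIB:.2f} GiB")
```

For the default `--replay_buffer_size` of 1000000 this prints `lazy: 6.58 GiB, naive: 26.29 GiB`, which is why storing single frames and stacking them on demand in `_fetch_state` pays off.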
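The `--tau` argument anticipates soft target updates, but `learn_from_experience` currently raises an exception for any `tau` other than 1. A common way to do a soft update is Polyak averaging, `target <- tau * online + (1 - tau) * target`. The sketch below is an assumption, not the repo's implementation: `soft_update` is a hypothetical helper that operates on dicts of NumPy arrays standing in for `state_dict()` tensors. With `tau=1` it reduces to the hard copy the script performs via `load_state_dict`.

```python
import numpy as np

def soft_update(target_params, online_params, tau):
    """Hypothetical helper: Polyak-average online weights into the target weights, in place.

    target <- tau * online + (1 - tau) * target
    """
    for name, online in online_params.items():
        target = target_params[name]  # a reference, so the in-place ops below update the dict entry
        target *= (1.0 - tau)
        target += tau * online

online = {"fc.weight": np.array([1.0, 2.0]), "fc.bias": np.array([0.5])}
target = {"fc.weight": np.zeros(2), "fc.bias": np.zeros(1)}

soft_update(target, online, tau=0.1)
print(target["fc.weight"])  # moved 10% of the way towards the online weights

# tau = 1 behaves like the hard update (a plain copy of the online weights)
hard_target = {name: np.zeros_like(p) for name, p in online.items()}
soft_update(hard_target, online, tau=1.0)
```

In PyTorch the same loop would run over `zip(target_dqn.parameters(), dqn.parameters())` inside `torch.no_grad()`. How often to call it (every learning step with a small tau, or every `target_dqn_update_interval` steps) is a design choice the script leaves open.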