├── .gitignore ├── README.md ├── assets └── breakout_10mio.gif ├── main.py ├── pretrained_models └── net.tar.gz ├── requirements.txt └── src ├── __init__.py ├── agent.py ├── config.py ├── dqn_agent.py ├── drqn_agent.py ├── env_wrapper.py ├── history.py ├── networks ├── __init__.py ├── base.py ├── dqn.py └── drqn.py ├── replay_memory.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | *.bk2 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRQN-tensorflow 2 | Deep Recurrent Q Learning using Tensorflow, openai/gym and openai/retro 3 | 4 | This repository contains code for training a DQN or a DRQN on [openai/gym](https://github.com/openai/gym) Atari and [openai/retro](https://github.com/openai/retro) environments. 5 | 6 | Note that training on Retro environments is completely experimental as of now and these environments have to 7 | be wrapped to reduce the action space to a more sensible subspace of all 8 | actions for each game. The wrapper currently implemented only makes sense for 9 | the SEGA Sonic environments. 10 | ### Installation 11 | You can install all dependencies by issuing following command: 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | This will install Tensorflow without GPU support. However, I highly recommend using Tensorflow with GPU support, otherwise training will take a very long time. For more information on this topic please see https://www.tensorflow.org/install/. 
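For example, if you have a CUDA-capable GPU you can install the GPU build pinned to the same version as in `requirements.txt` (assuming a matching `tensorflow-gpu` wheel exists for your CUDA/cuDNN setup):
```
pip install tensorflow-gpu==1.12.2
```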
To run the retro environments, you have to gather the ROMs of the games you want to play and import them: https://github.com/openai/retro#roms
16 | ### Running
17 | You can start training with:
18 | ```
19 | python main.py --gym=gym --steps=10000000 --train=True --network_type=dqn --env_name=Breakout-v0
20 | ```
21 | This will train a DQN on Atari Breakout for 10 million observations. For more on the command line parameters, see
22 | ```
23 | python main.py -h
24 | ```
25 | The training process can be visualized with TensorBoard:
26 | ```
27 | tensorboard --logdir=out
28 | ```
29 | ### Pretrained models
30 | A pretrained model for Breakout is available in `pretrained_models`.
31 | ### Result after training for 10 million steps (approx. 11 hours on a GTX 1080 Ti)
32 | ![Breakout after 10 million training steps](assets/breakout_10mio.gif)
33 | ### References
34 | 1. [DQN-tensorflow](https://github.com/devsisters/DQN-tensorflow)
35 | 2. [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/pdf/1312.5602.pdf)
36 | 3. [Playing FPS Games with Deep Reinforcement Learning](https://arxiv.org/pdf/1609.05521.pdf)
37 | 4. [Deep Recurrent Q-Learning for Partially Observable MDPs](https://arxiv.org/pdf/1507.06527.pdf)
38 |
--------------------------------------------------------------------------------
/assets/breakout_10mio.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marctuscher/DRQN-tensorflow/5ac7977528b677b9e19c0b4acdea63cdfc452959/assets/breakout_10mio.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from src.dqn_agent import DQNAgent
2 | from src.drqn_agent import DRQNAgent
3 | from src.config import RetroConfig, GymConfig
4 | import sys
5 |
6 | import argparse
7 |
8 | class Main():
9 |
10 |     def __init__(self, net_type, conf):
11 |         if net_type == "drqn":
12 |             self.agent = DRQNAgent(conf)
13 |         else:
14 |             self.agent = DQNAgent(conf)
15 |
16 |     def train(self, steps):
17 |         self.agent.train(steps)
18 |
19 |     def play(self, episodes, net_path):
20 |         self.agent.play(episodes, net_path)
21 |
22 |
23 | if __name__ == "__main__":
24 |     parser = argparse.ArgumentParser(description="DRQN")
25 |     parser.add_argument("--gym", type=str, default="gym", help="Type of the environment. Can either be 'gym' or 'retro'")
26 |     parser.add_argument("--network_type", type=str, default="dqn", help="Type of the network to build, can either be 'dqn' or 'drqn'")
27 |     parser.add_argument("--env_name", type=str, default="Breakout-v0", help="Name of the gym/retro environment used to train the agent")
28 |     parser.add_argument("--retro_state", type=str, default="Start", help="Name of the state (level) to start training. 
This is only necessary for retro envs") 29 | parser.add_argument("--train", type=str, default="True", help="Whether to train a network or to play with a given network") 30 | parser.add_argument("--model_dir", type=str, default="saved_session/net/", help="directory to save the model and replay memory during training") 31 | parser.add_argument("--net_path", type=str, default="", help="path to checkpoint of model") 32 | parser.add_argument("--steps", type=int, default=50000000, help="number of frames to train") 33 | args, remaining = parser.parse_known_args() 34 | 35 | if args.gym == "gym": 36 | conf = GymConfig() 37 | conf.env_name = args.env_name 38 | else: 39 | conf = RetroConfig() 40 | conf.env_name = args.env_name 41 | conf.state = args.retro_state 42 | conf.network_type = args.network_type 43 | conf.train = args.train 44 | conf.dir_save = args.model_dir 45 | conf.train_steps = args.steps 46 | main = Main(conf.network_type, conf) 47 | 48 | if conf.train == "True": 49 | print(conf.train) 50 | main.train(conf.train_steps) 51 | else: 52 | assert args.net_path != "", "Please specify a net_path using the option --net_path" 53 | main.play(100000, args.net_path) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /pretrained_models/net.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marctuscher/DRQN-tensorflow/5ac7977528b677b9e19c0b4acdea63cdfc452959/pretrained_models/net.tar.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.6.1 2 | astor==0.7.1 3 | astroid==2.1.0 4 | atari-py==0.1.7 5 | bleach==3.1.0 6 | certifi==2018.11.29 7 | chardet==3.0.4 8 | cloudpickle==0.6.1 9 | cycler==0.10.0 10 | dask==1.0.0 11 | decorator==4.3.0 12 | future==0.17.1 13 | gast==0.2.1.post0 14 | grpcio==1.17.1 15 | gym==0.10.9 16 | gym-retro==0.6.0 17 | html5lib==1.0.1 18 | idna==2.8 19 | isort==4.3.4 20 | kiwisolver==1.0.1 21 | lazy-object-proxy==1.3.1 22 | llvmlite==0.27.0 23 | Markdown==3.0.1 24 | matplotlib==3.0.2 25 | mccabe==0.6.1 26 | networkx==2.2 27 | numba==0.42.0 28 | numpy==1.15.4 29 | opencv-python==4.0.0.21 30 | Pillow==6.2.0 31 | protobuf==3.6.1 32 | pyglet==1.3.2 33 | pylint==2.2.2 34 | PyOpenGL==3.1.0 35 | pyparsing==2.3.0 36 | python-dateutil==2.7.5 37 | pytz==2018.9 38 | PyWavelets==1.0.1 39 | requests==2.21.0 40 | scikit-image==0.14.1 41 | scipy==1.2.0 42 | six==1.12.0 43 | tensorboard==1.12.2 44 | tensorflow==1.12.2 45 | termcolor==1.1.0 46 | toolz==0.9.0 47 | tqdm==4.29.0 48 | urllib3==1.24.2 49 | Werkzeug==0.14.1 50 | wrapt==1.11.0 51 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marctuscher/DRQN-tensorflow/5ac7977528b677b9e19c0b4acdea63cdfc452959/src/__init__.py -------------------------------------------------------------------------------- /src/agent.py: -------------------------------------------------------------------------------- 1 | from src.env_wrapper import GymWrapper, RetroWrapper 2 | import numpy as np 3 | 4 | class BaseAgent(): 5 | 6 | def __init__(self, config): 7 | self.config = config 8 | if config.state != None: 9 | self.env_wrapper = RetroWrapper(config) 10 | else: 11 | self.env_wrapper = GymWrapper(config) 12 | self.rewards = 0 13 | self.lens 
= 0 14 | self.epsilon = config.epsilon_start 15 | self.min_reward = -1. 16 | self.max_reward = 1.0 17 | self.replay_memory = None 18 | self.history = None 19 | self.net = None 20 | if self.config.restore: 21 | self.load() 22 | else: 23 | self.i = 0 24 | 25 | 26 | 27 | def save(self): 28 | self.replay_memory.save() 29 | self.net.save_session() 30 | np.save(self.config.dir_save+'step.npy', self.i) 31 | 32 | def load(self): 33 | self.replay_memory.load() 34 | self.net.restore_session() 35 | self.i = np.load(self.config.dir_save+'step.npy') 36 | 37 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.client import device_lib 2 | 3 | def get_available_gpus(): 4 | local_device_protos = device_lib.list_local_devices() 5 | return [x.name for x in local_device_protos if x.device_type == 'GPU'] 6 | 7 | class Config(object): 8 | 9 | train_steps = 50000000 10 | batch_size = 64 11 | history_len = 4 12 | frame_skip = 4 13 | epsilon_start = 1.0 14 | epsilon_end = 0.02 15 | max_steps = 10000 16 | epsilon_decay_episodes = 1000000 17 | train_freq = 8 18 | update_freq = 10000 19 | train_start = 20000 20 | dir_save = "saved_session/" 21 | restore = False 22 | epsilon_decay = float((epsilon_start - epsilon_end))/float(epsilon_decay_episodes) 23 | random_start = 10 24 | test_step = 5000 25 | network_type = "dqn" 26 | 27 | 28 | gamma = 0.99 29 | learning_rate_minimum = 0.00025 30 | lr_method = "rmsprop" 31 | learning_rate = 0.00025 32 | lr_decay = 0.97 33 | keep_prob = 0.8 34 | 35 | num_lstm_layers = 1 36 | lstm_size = 512 37 | min_history = 4 38 | states_to_update = 4 39 | 40 | if get_available_gpus(): 41 | cnn_format = "NCHW" 42 | else: 43 | cnn_format = "NHWC" 44 | 45 | 46 | 47 | class GymConfig(Config): 48 | state = None 49 | screen_height = 84 50 | screen_width = 84 51 | env_name = "Breakout-v0" 52 | mem_size = 800000 53 | 54 | 55 | 56 | class RetroConfig(Config): 57 | state="HydrocityZone.Act1" 58 | mem_size = 100000 59 | screen_height = 120 60 | screen_width = 160 61 | env_name = "SonicAndKnuckles3-Genesis" 62 | -------------------------------------------------------------------------------- /src/dqn_agent.py: -------------------------------------------------------------------------------- 1 | from src.agent import BaseAgent 2 | from src.history import History 3 | from src.replay_memory import DQNReplayMemory 4 | from src.networks.dqn import DQN 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | class DQNAgent(BaseAgent): 9 | 10 | def __init__(self, config): 11 | super(DQNAgent, self).__init__(config) 12 | self.history = History(config) 13 | self.replay_memory = DQNReplayMemory(config) 14 | self.net = DQN(self.env_wrapper.action_space.n, config) 15 | self.net.build() 16 | self.net.add_summary(["average_reward", "average_loss", "average_q", "ep_max_reward", "ep_min_reward", "ep_num_game", "learning_rate"], ["ep_rewards", "ep_actions"]) 17 | 18 | def observe(self): 19 | reward = max(self.min_reward, min(self.max_reward, self.env_wrapper.reward)) 20 | screen = self.env_wrapper.screen 21 | self.history.add(screen) 22 | self.replay_memory.add(screen, reward, self.env_wrapper.action, self.env_wrapper.terminal) 23 | if self.i < self.config.epsilon_decay_episodes: 24 | self.epsilon -= self.config.epsilon_decay 25 | if self.i % self.config.train_freq == 0 and self.i > self.config.train_start: 26 | state, action, reward, state_, terminal = 
self.replay_memory.sample_batch() 27 | q, loss= self.net.train_on_batch_target(state, action, reward, state_, terminal, self.i) 28 | self.total_q += q 29 | self.total_loss += loss 30 | self.update_count += 1 31 | if self.i % self.config.update_freq == 0: 32 | self.net.update_target() 33 | 34 | def policy(self): 35 | if np.random.rand() < self.epsilon: 36 | return self.env_wrapper.random_step() 37 | else: 38 | state = self.history.get()/255.0 39 | a = self.net.q_action.eval({ 40 | self.net.state : [state] 41 | }, session=self.net.sess) 42 | return a[0] 43 | 44 | 45 | def train(self, steps): 46 | render = False 47 | self.env_wrapper.new_random_game() 48 | num_game, self.update_count, ep_reward = 0,0,0. 49 | total_reward, self.total_loss, self.total_q = 0.,0.,0. 50 | ep_rewards, actions = [], [] 51 | t = 0 52 | 53 | for _ in range(self.config.history_len): 54 | self.history.add(self.env_wrapper.screen) 55 | for self.i in tqdm(range(self.i, steps)): 56 | action = self.policy() 57 | self.env_wrapper.act(action) 58 | self.observe() 59 | if self.env_wrapper.terminal: 60 | t = 0 61 | self.env_wrapper.new_random_game() 62 | num_game += 1 63 | ep_rewards.append(ep_reward) 64 | ep_reward = 0. 65 | else: 66 | ep_reward += self.env_wrapper.reward 67 | t += 1 68 | actions.append(action) 69 | total_reward += self.env_wrapper.reward 70 | 71 | if self.i >= self.config.train_start: 72 | if self.i % self.config.test_step == self.config.test_step -1: 73 | avg_reward = total_reward / self.config.test_step 74 | avg_loss = self.total_loss / self.update_count 75 | avg_q = self.total_q / self.update_count 76 | 77 | try: 78 | max_ep_reward = np.max(ep_rewards) 79 | min_ep_reward = np.min(ep_rewards) 80 | avg_ep_reward = np.mean(ep_rewards) 81 | except: 82 | max_ep_reward, min_ep_reward, avg_ep_reward = 0, 0, 0 83 | 84 | sum_dict = { 85 | 'average_reward': avg_reward, 86 | 'average_loss': avg_loss, 87 | 'average_q': avg_q, 88 | 'ep_max_reward': max_ep_reward, 89 | 'ep_min_reward': min_ep_reward, 90 | 'ep_num_game': num_game, 91 | 'learning_rate': self.net.learning_rate, 92 | 'ep_rewards': ep_rewards, 93 | 'ep_actions': actions 94 | } 95 | self.net.inject_summary(sum_dict, self.i) 96 | num_game = 0 97 | total_reward = 0. 98 | self.total_loss = 0. 99 | self.total_q = 0. 100 | self.update_count = 0 101 | ep_reward = 0. 
102 | ep_rewards = [] 103 | actions = [] 104 | 105 | if self.i % 500000 == 0 and self.i > 0: 106 | j = 0 107 | self.save() 108 | if self.i % 100000 == 0: 109 | j = 0 110 | render = True 111 | 112 | if render: 113 | self.env_wrapper.env.render() 114 | j += 1 115 | if j == 1000: 116 | render = False 117 | 118 | def play(self, episodes, net_path): 119 | self.net.restore_session(path=net_path) 120 | self.env_wrapper.new_game() 121 | i = 0 122 | for _ in range(self.config.history_len): 123 | self.history.add(self.env_wrapper.screen) 124 | episode_steps = 0 125 | while i < episodes: 126 | a = self.net.q_action.eval({ 127 | self.net.state : [self.history.get()/255.0] 128 | }, session=self.net.sess) 129 | action = a[0] 130 | self.env_wrapper.act_play(action) 131 | self.history.add(self.env_wrapper.screen) 132 | episode_steps += 1 133 | if episode_steps > self.config.max_steps: 134 | self.env_wrapper.terminal = True 135 | if self.env_wrapper.terminal: 136 | episode_steps = 0 137 | i += 1 138 | self.env_wrapper.new_play_game() 139 | for _ in range(self.config.history_len): 140 | screen = self.env_wrapper.screen 141 | self.history.add(screen) 142 | -------------------------------------------------------------------------------- /src/drqn_agent.py: -------------------------------------------------------------------------------- 1 | from src.agent import BaseAgent 2 | from src.replay_memory import DRQNReplayMemory 3 | from src.networks.drqn import DRQN 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | class DRQNAgent(BaseAgent): 8 | 9 | def __init__(self, config): 10 | super(DRQNAgent, self).__init__(config) 11 | self.replay_memory = DRQNReplayMemory(config) 12 | self.net = DRQN(self.env_wrapper.action_space.n, config) 13 | self.net.build() 14 | self.net.add_summary(["average_reward", "average_loss", "average_q", "ep_max_reward", "ep_min_reward", "ep_num_game", "learning_rate"], ["ep_rewards", "ep_actions"]) 15 | 16 | def observe(self, t): 17 | reward = max(self.min_reward, min(self.max_reward, self.env_wrapper.reward)) 18 | self.screen = self.env_wrapper.screen 19 | self.replay_memory.add(self.screen, reward, self.env_wrapper.action, self.env_wrapper.terminal, t) 20 | if self.i < self.config.epsilon_decay_episodes: 21 | self.epsilon -= self.config.epsilon_decay 22 | if self.i % self.config.train_freq == 0 and self.i > self.config.train_start: 23 | states, action, reward, terminal = self.replay_memory.sample_batch() 24 | q, loss= self.net.train_on_batch_target(states, action, reward, terminal, self.i) 25 | self.total_q += q 26 | self.total_loss += loss 27 | self.update_count += 1 28 | if self.i % self.config.update_freq == 0: 29 | self.net.update_target() 30 | 31 | def policy(self, state): 32 | self.random = False 33 | if np.random.rand() < self.epsilon: 34 | self.random = True 35 | return self.env_wrapper.random_step() 36 | else: 37 | a, self.lstm_state_c, self.lstm_state_h = self.net.sess.run([self.net.q_action, self.net.state_output_c, self.net.state_output_h],{ 38 | self.net.state : [[state]], 39 | self.net.c_state_train: self.lstm_state_c, 40 | self.net.h_state_train: self.lstm_state_h 41 | }) 42 | return a[0] 43 | 44 | 45 | def train(self, steps): 46 | render = False 47 | self.env_wrapper.new_random_game() 48 | num_game, self.update_count, ep_reward = 0,0,0. 49 | total_reward, self.total_loss, self.total_q = 0.,0.,0. 
50 | ep_rewards, actions = [], [] 51 | t = 0 52 | self.screen = self.env_wrapper.screen 53 | self.lstm_state_c, self.lstm_state_h = self.net.initial_zero_state_single, self.net.initial_zero_state_single 54 | 55 | for self.i in tqdm(range(self.i, steps)): 56 | state = self.screen/255 57 | action = self.policy(state) 58 | self.env_wrapper.act(action) 59 | if self.random: 60 | self.lstm_state_c, self.lstm_state_h = self.net.sess.run([self.net.state_output_c, self.net.state_output_h], { 61 | self.net.state: [[state]], 62 | self.net.c_state_train : self.lstm_state_c, 63 | self.net.h_state_train: self.lstm_state_h 64 | }) 65 | self.observe(t) 66 | if self.env_wrapper.terminal: 67 | t = 0 68 | self.env_wrapper.new_random_game() 69 | self.screen = self.env_wrapper.screen 70 | num_game += 1 71 | ep_rewards.append(ep_reward) 72 | ep_reward = 0. 73 | self.lstm_state_c, self.lstm_state_h = self.net.initial_zero_state_single, self.net.initial_zero_state_single 74 | else: 75 | ep_reward += self.env_wrapper.reward 76 | t += 1 77 | actions.append(action) 78 | total_reward += self.env_wrapper.reward 79 | 80 | if self.i >= self.config.train_start: 81 | if self.i % self.config.test_step == self.config.test_step -1: 82 | avg_reward = total_reward / self.config.test_step 83 | avg_loss = self.total_loss / self.update_count 84 | avg_q = self.total_q / self.update_count 85 | 86 | try: 87 | max_ep_reward = np.max(ep_rewards) 88 | min_ep_reward = np.min(ep_rewards) 89 | avg_ep_reward = np.mean(ep_rewards) 90 | except: 91 | max_ep_reward, min_ep_reward, avg_ep_reward = 0, 0, 0 92 | 93 | sum_dict = { 94 | 'average_reward': avg_reward, 95 | 'average_loss': avg_loss, 96 | 'average_q': avg_q, 97 | 'ep_max_reward': max_ep_reward, 98 | 'ep_min_reward': min_ep_reward, 99 | 'ep_num_game': num_game, 100 | 'learning_rate': self.net.learning_rate, 101 | 'ep_rewards': ep_rewards, 102 | 'ep_actions': actions 103 | } 104 | self.net.inject_summary(sum_dict, self.i) 105 | num_game = 0 106 | total_reward = 0. 107 | self.total_loss = 0. 108 | self.total_q = 0. 109 | self.update_count = 0 110 | ep_reward = 0. 
111 | ep_rewards = [] 112 | actions = [] 113 | 114 | if self.i % 500000 == 0 and self.i > 0: 115 | j = 0 116 | self.save() 117 | if self.i % 100000 == 0: 118 | j = 0 119 | render = True 120 | 121 | if render: 122 | self.env_wrapper.env.render() 123 | j += 1 124 | if j == 1000: 125 | render = False 126 | 127 | def play(self, episodes, net_path): 128 | self.net.restore_session(path=net_path) 129 | self.env_wrapper.new_game() 130 | self.lstm_state_c, self.lstm_state_h = self.net.initial_zero_state_single, self.net.initial_zero_state_single 131 | i = 0 132 | episode_steps = 0 133 | while i < episodes: 134 | a, self.lstm_state_c, self.lstm_state_h = self.net.sess.run([self.net.q_action, self.net.state_output_c, self.net.state_output_h],{ 135 | self.net.state : [[self.env_wrapper.screen]], 136 | self.net.c_state_train: self.lstm_state_c, 137 | self.net.h_state_train: self.lstm_state_h 138 | }) 139 | action = a[0] 140 | self.env_wrapper.act_play(action) 141 | episode_steps += 1 142 | if episode_steps > self.config.max_steps: 143 | self.env_wrapper.terminal = True 144 | if self.env_wrapper.terminal: 145 | episode_steps = 0 146 | i += 1 147 | self.env_wrapper.new_play_game() 148 | self.lstm_state_c, self.lstm_state_h = self.net.initial_zero_state_single, self.net.initial_zero_state_single 149 | -------------------------------------------------------------------------------- /src/env_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import retro 3 | from src.utils import resize, rgb2gray 4 | import numpy as np 5 | 6 | 7 | 8 | class GymWrapper(): 9 | 10 | def __init__(self, config): 11 | self.env = gym.make(config.env_name) 12 | self.screen_width, self.screen_height = config.screen_width, config.screen_height 13 | self.reward = 0 14 | self.terminal = True 15 | self.info = {'ale.lives': 0} 16 | self.env.env.frameskip = config.frame_skip 17 | self.random_start = config.random_start 18 | self.action_space = self.env.action_space 19 | self._screen = np.empty((210, 160), dtype=np.uint8) 20 | 21 | def new_game(self): 22 | if self.lives == 0: 23 | self.env.reset() 24 | self._step(0) 25 | self.reward = 0 26 | self.action = 0 27 | 28 | def new_random_game(self): 29 | self.new_game() 30 | for _ in range(np.random.randint(0, self.random_start)): 31 | self._step(0) 32 | 33 | 34 | def _step(self, action): 35 | self.action = action 36 | _, self.reward, self.terminal, self.info = self.env.step(action) 37 | 38 | 39 | def random_step(self): 40 | return self.action_space.sample() 41 | 42 | def act(self, action): 43 | lives_before = self.lives 44 | self._step(action) 45 | if self.lives < lives_before: 46 | self.terminal = True 47 | 48 | 49 | def act_play(self, action): 50 | lives_before = self.lives 51 | self._step(action) 52 | self.env.render() 53 | if self.lives < lives_before: 54 | self.terminal = True 55 | 56 | def new_play_game(self): 57 | self.new_game() 58 | self._step(1) 59 | 60 | @property 61 | def screen(self): 62 | self._screen = self.env.env.ale.getScreenGrayscale(self._screen) 63 | a = resize(self._screen ,(self.screen_height, self.screen_width)) 64 | return a 65 | 66 | @property 67 | def lives(self): 68 | return self.info['ale.lives'] 69 | 70 | 71 | 72 | class RetroWrapper(): 73 | 74 | def __init__(self, config): 75 | """ 76 | TODO !!! Reward Shaping!!! 
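        The only shaping applied so far happens in `_step` below: 0.1 * info['rings'] is added
        to the raw environment reward on every step, as a rough heuristic for the Sonic games.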
77 | Parameters 78 | ---------- 79 | config 80 | """ 81 | self.env = retro.make(game=config.env_name, state=config.state, use_restricted_actions=2) 82 | self.screen_width, self.screen_height = config.screen_width, config.screen_height 83 | self.reward = 0 84 | self.terminal = True 85 | self.info = {'lives': 0} 86 | self.frameskip = config.frame_skip 87 | self.random_start = config.random_start 88 | buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"] 89 | actions = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'], 90 | ['DOWN', 'B'], ['B']] 91 | self._actions = [] 92 | for action in actions: 93 | arr = np.array([False] * 12) 94 | for button in action: 95 | arr[buttons.index(button)] = True 96 | self._actions.append(arr) 97 | self.action_space = gym.spaces.Discrete(len(self._actions)) 98 | print(self.action_space.sample()) 99 | 100 | def new_game(self): 101 | if self.lives == 0: 102 | self.env.reset() 103 | self._step(0) 104 | self.reward = 0 105 | self.action = 0 106 | 107 | def new_random_game(self): 108 | self.new_game() 109 | for _ in range(np.random.randint(0, self.random_start)): 110 | self._step(0) 111 | 112 | def new_play_game(self): 113 | self.new_game() 114 | self._step(1) 115 | 116 | def _step(self, action): 117 | self.action = action 118 | self._screen, self.reward, self.terminal, self.info = self.env.step(action) 119 | self.reward += 0.1 * self.info['rings'] 120 | 121 | def random_step(self): 122 | return self.action_space.sample() 123 | 124 | def act(self, action): 125 | lives_before = self.lives 126 | # frameskip has to be incorporated on wrapper level 127 | for _ in range(self.frameskip): 128 | self._step(action) 129 | if self.lives < lives_before: 130 | self.terminal = True 131 | 132 | 133 | def act_play(self, action): 134 | lives_before = self.lives 135 | self._step(action) 136 | self.env.render() 137 | if self.lives < lives_before: 138 | self.terminal = True 139 | 140 | @property 141 | def screen(self): 142 | a = resize(rgb2gray(self._screen) ,(self.screen_height, self.screen_width)) 143 | return a 144 | 145 | @property 146 | def lives(self): 147 | return self.info['lives'] 148 | -------------------------------------------------------------------------------- /src/history.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | class History: 5 | 6 | def __init__(self, config): 7 | self.batch_size = config.batch_size 8 | self.history_len = config.history_len 9 | self.screen_width = config.screen_width 10 | self.screen_height = config.screen_height 11 | self.history = np.zeros((self.history_len, self.screen_height, self.screen_width), dtype=np.uint8) 12 | 13 | def add(self, screen): 14 | self.history[:-1] = self.history[1:] 15 | self.history[-1] = screen 16 | 17 | def reset(self): 18 | self.history *= 0 19 | 20 | def get(self): 21 | return self.history 22 | -------------------------------------------------------------------------------- /src/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marctuscher/DRQN-tensorflow/5ac7977528b677b9e19c0b4acdea63cdfc452959/src/networks/__init__.py -------------------------------------------------------------------------------- /src/networks/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | import tensorflow as tf 5 | import 
datetime 6 | import shutil 7 | from functools import reduce 8 | from tensorflow.python import debug as tf_debug 9 | from src.utils import conv2d_layer, fully_connected_layer 10 | 11 | 12 | 13 | 14 | class BaseModel(): 15 | """ 16 | Base class for deep Q learning 17 | """ 18 | 19 | def __init__(self, config, network_type): 20 | 21 | self.screen_width = config.screen_width 22 | self.screen_height = config.screen_height 23 | self.gamma = config.gamma 24 | self.dir_save = config.dir_save 25 | self.learning_rate_minimum = config.learning_rate_minimum 26 | self.debug = not True 27 | self.keep_prob = config.keep_prob 28 | self.batch_size = config.batch_size 29 | self.lr_method = config.lr_method 30 | self.learning_rate = config.learning_rate 31 | self.lr_decay = config.lr_decay 32 | self.sess = None 33 | self.saver = None 34 | # delete ./out 35 | self.dir_output = "./out/"+network_type+"/"+config.env_name+"/"+ str(datetime.datetime.utcnow()) + "/" 36 | self.dir_model = self.dir_save + "/net/" +config.env_name+"/"+ str(datetime.datetime.utcnow()) + "/" 37 | 38 | self.train_steps = 0 39 | self.is_training = False 40 | 41 | def add_train_op(self, lr_method, lr, loss, clip=-1): 42 | _lr_m = lr_method.lower() # lower to make sure 43 | 44 | with tf.variable_scope("train_step"): 45 | if _lr_m == 'adam': # sgd method 46 | optimizer = tf.train.AdamOptimizer(lr) 47 | elif _lr_m == 'adagrad': 48 | optimizer = tf.train.AdagradOptimizer(lr) 49 | elif _lr_m == 'sgd': 50 | optimizer = tf.train.GradientDescentOptimizer(lr) 51 | elif _lr_m == 'rmsprop': 52 | optimizer = tf.train.RMSPropOptimizer(lr, momentum=0.95, epsilon=0.01) 53 | else: 54 | raise NotImplementedError("Unknown method {}".format(_lr_m)) 55 | 56 | if clip > 0: # gradient clipping if clip is positive 57 | grads, vs = zip(*optimizer.compute_gradients(loss)) 58 | grads, gnorm = tf.clip_by_global_norm(grads, clip) 59 | self.train_op = optimizer.apply_gradients(zip(grads, vs)) 60 | else: 61 | self.train_op = optimizer.minimize(loss) 62 | 63 | def initialize_session(self): 64 | print("Initializing tf session") 65 | self.sess = tf.Session() 66 | if self.debug: 67 | self.sess = tf_debug.TensorBoardDebugWrapperSession(self.sess, "localhost:6064") 68 | self.sess.run(tf.global_variables_initializer()) 69 | self.saver = tf.train.Saver() 70 | 71 | def close_session(self): 72 | self.sess.close() 73 | 74 | def add_summary(self, summary_tags, histogram_tags): 75 | self.summary_placeholders = {} 76 | self.summary_ops = {} 77 | for tag in summary_tags: 78 | self.summary_placeholders[tag] = tf.placeholder(tf.float32, None, name=tag) 79 | self.summary_ops[tag] = tf.summary.scalar(tag, self.summary_placeholders[tag]) 80 | for tag in histogram_tags: 81 | self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag) 82 | self.summary_ops[tag] = tf.summary.histogram(tag, self.summary_placeholders[tag]) 83 | self.file_writer = tf.summary.FileWriter(self.dir_output + "/train", 84 | self.sess.graph) 85 | 86 | def inject_summary(self, tag_dict, step): 87 | summary_str_lists = self.sess.run([self.summary_ops[tag] for tag in tag_dict], { 88 | self.summary_placeholders[tag]: value for tag, value in tag_dict.items() 89 | }) 90 | for summ in summary_str_lists: 91 | self.file_writer.add_summary(summ, step) 92 | 93 | 94 | def save_session(self): 95 | """Saves session = weights""" 96 | if not os.path.exists(self.dir_model): 97 | os.makedirs(self.dir_model) 98 | self.saver.save(self.sess, self.dir_model) 99 | 100 | def restore_session(self, path=None): 101 | if path is 
not None: 102 | self.saver.restore(self.sess, path) 103 | else: 104 | self.saver.restore(self.sess, self.dir_model) 105 | -------------------------------------------------------------------------------- /src/networks/dqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import tensorflow as tf 4 | import shutil 5 | from functools import reduce 6 | from tensorflow.python import debug as tf_debug 7 | from src.utils import conv2d_layer, fully_connected_layer, huber_loss 8 | from src.networks.base import BaseModel 9 | 10 | 11 | 12 | 13 | class DQN(BaseModel): 14 | 15 | def __init__(self, n_actions, config): 16 | super(DQN, self).__init__(config, "dqn") 17 | self.n_actions = n_actions 18 | self.history_len = config.history_len 19 | self.cnn_format = config.cnn_format 20 | self.all_tf = not True 21 | 22 | 23 | def train_on_batch_target(self, state, action, reward, state_, terminal, steps): 24 | state_ = state_ / 255.0 25 | state = state / 255.0 26 | target_val = self.q_target_out.eval({self.state_target: state_}, session=self.sess) 27 | max_target = np.max(target_val, axis=1) 28 | target = (1. - terminal) * self.gamma * max_target + reward 29 | _, q, train_loss, q_summary, image_summary = self.sess.run( 30 | [self.train_op, self.q_out, self.loss, self.avg_q_summary, self.merged_image_sum], 31 | feed_dict={ 32 | self.state: state, 33 | self.action: action, 34 | self.target_val: target, 35 | self.lr: self.learning_rate 36 | } 37 | ) 38 | if self.train_steps % 1000 == 0: 39 | self.file_writer.add_summary(q_summary, self.train_steps) 40 | self.file_writer.add_summary(image_summary, self.train_steps) 41 | if steps % 20000 == 0 and steps > 50000: 42 | self.learning_rate *= self.lr_decay # decay learning rate 43 | if self.learning_rate < self.learning_rate_minimum: 44 | self.learning_rate = self.learning_rate_minimum 45 | self.train_steps += 1 46 | return q.mean(), train_loss 47 | 48 | def train_on_batch_all_tf(self, state, action, reward, state_, terminal, steps): 49 | state = state/255.0 50 | state_= state_/255.0 51 | _, q, train_loss, q_summary, image_summary = self.sess.run( 52 | [self.train_op, self.q_out, self.loss, self.avg_q_summary, self.merged_image_sum], feed_dict={ 53 | self.state: state, 54 | self.action: action, 55 | self.state_target:state_, 56 | self.reward: reward, 57 | self.terminal: terminal, 58 | self.lr: self.learning_rate, 59 | self.dropout: self.keep_prob 60 | } 61 | ) 62 | if self.train_steps % 1000 == 0: 63 | self.file_writer.add_summary(q_summary, self.train_steps) 64 | self.file_writer.add_summary(image_summary, self.train_steps) 65 | if steps % 20000 == 0 and steps > 50000: 66 | self.learning_rate *= self.lr_decay # decay learning rate 67 | if self.learning_rate < self.learning_rate_minimum: 68 | self.learning_rate = self.learning_rate_minimum 69 | self.train_steps += 1 70 | return q.mean(), train_loss 71 | 72 | def add_placeholders(self): 73 | self.w = {} 74 | self.w_target = {} 75 | self.state = tf.placeholder(tf.float32, shape=[None, self.history_len, self.screen_height, self.screen_width], 76 | name="input_state") 77 | self.action = tf.placeholder(tf.int32, shape=[None], name="action_input") 78 | self.reward = tf.placeholder(tf.int32, shape=[None], name="reward") 79 | 80 | self.state_target = tf.placeholder(tf.float32, 81 | shape=[None, self.history_len, self.screen_height, self.screen_width], 82 | name="input_target") 83 | self.dropout = tf.placeholder(dtype=tf.float32, shape=[], 84 | 
name="dropout") 85 | self.lr = tf.placeholder(dtype=tf.float32, shape=[], 86 | name="lr") 87 | self.terminal = tf.placeholder(dtype=tf.float32, shape=[None], name="terminal") 88 | 89 | self.target_val = tf.placeholder(dtype=tf.float32, shape=[None], name="target_val") 90 | self.target_val_tf = tf.placeholder(dtype=tf.float32, shape=[None, self.n_actions]) 91 | 92 | self.learning_rate_step = tf.placeholder("int64", None, name="learning_rate_step") 93 | 94 | def add_logits_op_train(self): 95 | self.image_summary = [] 96 | if self.cnn_format == "NHWC": 97 | x = tf.transpose(self.state, [0, 2, 3, 1]) 98 | else: 99 | x = self.state 100 | w, b, out, summary = conv2d_layer(x, 32, [8, 8], [4, 4], scope_name="conv1_train", summary_tag="conv1_out", 101 | activation=tf.nn.relu, data_format=self.cnn_format) 102 | self.w["wc1"] = w 103 | self.w["bc1"] = b 104 | self.image_summary.append(summary) 105 | 106 | w, b, out, summary = conv2d_layer(out, 64, [4, 4], [2, 2], scope_name="conv2_train", summary_tag="conv2_out", 107 | activation=tf.nn.relu, data_format=self.cnn_format) 108 | self.w["wc2"] = w 109 | self.w["bc2"] = b 110 | self.image_summary.append(summary) 111 | 112 | w, b, out, summary = conv2d_layer(out, 64, [3, 3], [1, 1], scope_name="conv3_train", summary_tag="conv3_out", 113 | activation=tf.nn.relu, data_format=self.cnn_format) 114 | self.w["wc3"] = w 115 | self.w["bc3"] = b 116 | self.image_summary.append(summary) 117 | 118 | shape = out.get_shape().as_list() 119 | out_flat = tf.reshape(out, [-1, reduce(lambda x, y: x * y, shape[1:])]) 120 | 121 | w, b, out = fully_connected_layer(out_flat, 512, scope_name="fully1_train") 122 | 123 | self.w["wf1"] = w 124 | self.w["bf1"] = b 125 | 126 | w, b, out = fully_connected_layer(out, self.n_actions, scope_name="out_train", activation=None) 127 | 128 | self.w["wout"] = w 129 | self.w["bout"] = b 130 | 131 | self.q_out = out 132 | self.q_action = tf.argmax(self.q_out, axis=1) 133 | 134 | def add_logits_op_target(self): 135 | if self.cnn_format == "NHWC": 136 | x = tf.transpose(self.state_target, [0, 2, 3, 1]) 137 | else: 138 | x = self.state_target 139 | w, b, out, _ = conv2d_layer(x, 32, [8, 8], [4, 4], scope_name="conv1_target", summary_tag=None, 140 | activation=tf.nn.relu, data_format=self.cnn_format) 141 | self.w_target["wc1"] = w 142 | self.w_target["bc1"] = b 143 | 144 | w, b, out, _ = conv2d_layer(out, 64, [4, 4], [2, 2], scope_name="conv2_target", summary_tag=None, 145 | activation=tf.nn.relu, data_format=self.cnn_format) 146 | self.w_target["wc2"] = w 147 | self.w_target["bc2"] = b 148 | 149 | w, b, out, _ = conv2d_layer(out, 64, [3, 3], [1, 1], scope_name="conv3_target", summary_tag=None, 150 | activation=tf.nn.relu, data_format=self.cnn_format) 151 | self.w_target["wc3"] = w 152 | self.w_target["bc3"] = b 153 | 154 | shape = out.get_shape().as_list() 155 | out_flat = tf.reshape(out, [-1, reduce(lambda x, y: x * y, shape[1:])]) 156 | 157 | w, b, out = fully_connected_layer(out_flat, 512, scope_name="fully1_target") 158 | 159 | self.w_target["wf1"] = w 160 | self.w_target["bf1"] = b 161 | 162 | w, b, out = fully_connected_layer(out, self.n_actions, scope_name="out_target", activation=None) 163 | 164 | self.w_target["wout"] = w 165 | self.w_target["bout"] = b 166 | 167 | self.q_target_out = out 168 | self.q_target_action = tf.argmax(self.q_target_out, axis=1) 169 | 170 | def init_update(self): 171 | self.target_w_in = {} 172 | self.target_w_assign = {} 173 | for name in self.w: 174 | self.target_w_in[name] = tf.placeholder(tf.float32, 
self.w_target[name].get_shape().as_list(), name=name) 175 | self.target_w_assign[name] = self.w_target[name].assign(self.target_w_in[name]) 176 | 177 | def add_loss_op_target(self): 178 | action_one_hot = tf.one_hot(self.action, self.n_actions, 1.0, 0.0, name='action_one_hot') 179 | train = tf.reduce_sum(self.q_out * action_one_hot, reduction_indices=1, name='q_acted') 180 | self.delta = train - self.target_val 181 | self.loss = tf.reduce_mean(huber_loss(self.delta)) 182 | 183 | avg_q = tf.reduce_mean(self.q_out, 0) 184 | q_summary = [] 185 | for i in range(self.n_actions): 186 | q_summary.append(tf.summary.histogram('q/{}'.format(i), avg_q[i])) 187 | self.merged_image_sum = tf.summary.merge(self.image_summary, "images") 188 | self.avg_q_summary = tf.summary.merge(q_summary, 'q_summary') 189 | self.loss_summary = tf.summary.scalar("loss", self.loss) 190 | 191 | def add_loss_op_target_tf(self): 192 | self.reward = tf.cast(self.reward, dtype=tf.float32) 193 | target_best = tf.reduce_max(self.q_target_out, 1) 194 | masked = (1.0 - self.terminal) * target_best 195 | target = self.reward + self.gamma * masked 196 | 197 | action_one_hot = tf.one_hot(self.action, self.n_actions, 1.0, 0.0, name='action_one_hot') 198 | train = tf.reduce_sum(self.q_out * action_one_hot, reduction_indices=1) 199 | delta = target - train 200 | self.loss = tf.reduce_mean(huber_loss(delta)) 201 | avg_q = tf.reduce_mean(self.q_out, 0) 202 | q_summary = [] 203 | for i in range(self.n_actions): 204 | q_summary.append(tf.summary.histogram('q/{}'.format(i), avg_q[i])) 205 | self.avg_q_summary = tf.summary.merge(q_summary, 'q_summary') 206 | self.loss_summary = tf.summary.scalar("loss", self.loss) 207 | self.merged_image_sum = tf.summary.merge(self.image_summary, "images") 208 | 209 | def build(self): 210 | self.add_placeholders() 211 | self.add_logits_op_train() 212 | self.add_logits_op_target() 213 | if self.all_tf: 214 | self.add_loss_op_target_tf() 215 | else: 216 | self.add_loss_op_target() 217 | self.add_train_op(self.lr_method, self.lr, self.loss, clip=10) 218 | self.initialize_session() 219 | self.init_update() 220 | 221 | def update_target(self): 222 | for name in self.w: 223 | self.target_w_assign[name].eval({self.target_w_in[name]: self.w[name].eval(session=self.sess)}, 224 | session=self.sess) 225 | 226 | -------------------------------------------------------------------------------- /src/networks/drqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import tensorflow as tf 4 | import shutil 5 | from functools import reduce 6 | from tensorflow.python import debug as tf_debug 7 | from src.utils import conv2d_layer, fully_connected_layer, stateful_lstm, huber_loss 8 | from src.networks.base import BaseModel 9 | 10 | 11 | class DRQN(BaseModel): 12 | 13 | def __init__(self, n_actions, config): 14 | super(DRQN, self).__init__(config, "drqn") 15 | self.n_actions = n_actions 16 | self.cnn_format = config.cnn_format 17 | self.num_lstm_layers = config.num_lstm_layers 18 | self.lstm_size = config.lstm_size 19 | self.min_history = config.min_history 20 | self.states_to_update = config.states_to_update 21 | 22 | def add_placeholders(self): 23 | self.w = {} 24 | self.w_target = {} 25 | self.state = tf.placeholder(tf.float32, shape=[None, 1, self.screen_height, self.screen_width], 26 | name="input_state") 27 | self.action = tf.placeholder(tf.int32, shape=[None], name="action_input") 28 | self.reward = tf.placeholder(tf.int32, shape=[None], 
name="reward") 29 | self.state_target = tf.placeholder(tf.float32, 30 | shape=[None, 1, self.screen_height, self.screen_width], 31 | name="input_target") 32 | # create placeholder to fill in lstm state 33 | self.c_state_train = tf.placeholder(tf.float32, [None, self.lstm_size], name="train_c") 34 | self.h_state_train = tf.placeholder(tf.float32, [None, self.lstm_size], name="train_h") 35 | self.lstm_state_train = tf.nn.rnn_cell.LSTMStateTuple(self.c_state_train, self.h_state_train) 36 | 37 | 38 | 39 | self.c_state_target = tf.placeholder(tf.float32, [None, self.lstm_size], name="target_c") 40 | self.h_state_target = tf.placeholder(tf.float32, [None, self.lstm_size], name="target_h") 41 | self.lstm_state_target = tf.nn.rnn_cell.LSTMStateTuple(self.c_state_target, self.h_state_target) 42 | 43 | # initial zero state to be used when starting episode 44 | self.initial_zero_state_batch = np.zeros((self.batch_size, self.lstm_size)) 45 | self.initial_zero_state_single = np.zeros((1, self.lstm_size)) 46 | 47 | self.initial_zero_complete = np.zeros((self.num_lstm_layers, 2, self.batch_size, self.lstm_size)) 48 | 49 | self.dropout = tf.placeholder(dtype=tf.float32, shape=[], 50 | name="dropout") 51 | self.lr = tf.placeholder(dtype=tf.float32, shape=[], 52 | name="lr") 53 | self.target_val = tf.placeholder(dtype=tf.float32, shape=[None], name="target_val") 54 | self.terminal = tf.placeholder(dtype=tf.float32, shape=[None], name="terminal") 55 | self.target_val_tf = tf.placeholder(dtype=tf.float32, shape=[None, self.n_actions]) 56 | 57 | def add_logits_op_train(self): 58 | if self.cnn_format == "NHWC": 59 | x = tf.transpose(self.state, [0, 2, 3, 1]) 60 | else: 61 | x = self.state 62 | self.image_summary = [] 63 | w, b, out, summary = conv2d_layer(x, 32, [8, 8], [4, 4], scope_name="conv1_train", 64 | summary_tag="conv1_out", 65 | activation=tf.nn.relu, data_format=self.cnn_format) 66 | self.w["wc1"] = w 67 | self.w["bc1"] = b 68 | self.image_summary.append(summary) 69 | 70 | w, b, out, summary = conv2d_layer(out, 64, [4, 4], [2, 2], scope_name="conv2_train", summary_tag="conv2_out", 71 | activation=tf.nn.relu, data_format=self.cnn_format) 72 | self.w["wc2"] = w 73 | self.w["bc2"] = b 74 | self.image_summary.append(summary) 75 | 76 | w, b, out, summary = conv2d_layer(out, 64, [3, 3], [1, 1], scope_name="conv3_train", summary_tag="conv3_out", 77 | activation=tf.nn.relu, data_format=self.cnn_format) 78 | self.w["wc3"] = w 79 | self.w["bc3"] = b 80 | self.image_summary.append(summary) 81 | 82 | shape = out.get_shape().as_list() 83 | out_flat = tf.reshape(out, [tf.shape(out)[0], 1, shape[1] * shape[2] * shape[3]]) 84 | out, state = stateful_lstm(out_flat, self.num_lstm_layers, self.lstm_size, tuple([self.lstm_state_train]), 85 | scope_name="lstm_train") 86 | self.state_output_c = state[0][0] 87 | self.state_output_h = state[0][1] 88 | shape = out.get_shape().as_list() 89 | out = tf.reshape(out, [tf.shape(out)[0], shape[2]]) 90 | w, b, out = fully_connected_layer(out, self.n_actions, scope_name="out_train", activation=None) 91 | 92 | self.w["wout"] = w 93 | self.w["bout"] = b 94 | 95 | self.q_out = out 96 | self.q_action = tf.argmax(self.q_out, axis=1) 97 | 98 | def add_logits_op_target(self): 99 | if self.cnn_format == "NHWC": 100 | x = tf.transpose(self.state_target, [0, 2, 3, 1]) 101 | else: 102 | x = self.state_target 103 | w, b, out, _ = conv2d_layer(x, 32, [8, 8], [4, 4], scope_name="conv1_target", summary_tag=None, 104 | activation=tf.nn.relu, data_format=self.cnn_format) 105 | self.w_target["wc1"] = w 
106 | self.w_target["bc1"] = b 107 | 108 | w, b, out, _ = conv2d_layer(out, 64, [4, 4], [2, 2], scope_name="conv2_target", summary_tag=None, 109 | activation=tf.nn.relu, data_format=self.cnn_format) 110 | self.w_target["wc2"] = w 111 | self.w_target["bc2"] = b 112 | 113 | w, b, out, _ = conv2d_layer(out, 64, [3, 3], [1, 1], scope_name="conv3_target", summary_tag=None, 114 | activation=tf.nn.relu, data_format=self.cnn_format) 115 | self.w_target["wc3"] = w 116 | self.w_target["bc3"] = b 117 | 118 | shape = out.get_shape().as_list() 119 | out_flat = tf.reshape(out, [tf.shape(out)[0], 1, shape[1] * shape[2] * shape[3]]) 120 | out, state = stateful_lstm(out_flat, self.num_lstm_layers, self.lstm_size, 121 | tuple([self.lstm_state_target]), scope_name="lstm_target") 122 | self.state_output_target_c = state[0][0] 123 | self.state_output_target_h = state[0][1] 124 | shape = out.get_shape().as_list() 125 | 126 | out = tf.reshape(out, [tf.shape(out)[0], shape[2]]) 127 | 128 | w, b, out = fully_connected_layer(out, self.n_actions, scope_name="out_target", activation=None) 129 | 130 | self.w_target["wout"] = w 131 | self.w_target["bout"] = b 132 | 133 | self.q_target_out = out 134 | self.q_target_action = tf.argmax(self.q_target_out, axis=1) 135 | 136 | def train_on_batch_target(self, states, action, reward, terminal, steps): 137 | states = states / 255.0 138 | q, loss = np.zeros((self.batch_size, self.n_actions)), 0 139 | states = np.transpose(states, [1, 0, 2, 3]) 140 | action = np.transpose(action, [1, 0]) 141 | reward = np.transpose(reward, [1, 0]) 142 | terminal = np.transpose(terminal, [1, 0]) 143 | states = np.reshape(states, [states.shape[0], states.shape[1], 1, states.shape[2], states.shape[3]]) 144 | lstm_state_c, lstm_state_h = self.initial_zero_state_batch, self.initial_zero_state_batch 145 | lstm_state_target_c, lstm_state_target_h = self.sess.run( 146 | [self.state_output_target_c, self.state_output_target_h], 147 | { 148 | self.state_target: states[0], 149 | self.c_state_target: self.initial_zero_state_batch, 150 | self.h_state_target: self.initial_zero_state_batch 151 | } 152 | ) 153 | for i in range(self.min_history): 154 | j = i + 1 155 | lstm_state_c, lstm_state_h, lstm_state_target_c, lstm_state_target_h = self.sess.run( 156 | [self.state_output_c, self.state_output_h, self.state_output_target_c, self.state_output_target_h], 157 | { 158 | self.state: states[i], 159 | self.state_target: states[j], 160 | self.c_state_target: lstm_state_target_c, 161 | self.h_state_target: lstm_state_target_h, 162 | self.c_state_train: lstm_state_c, 163 | self.h_state_train: lstm_state_h 164 | } 165 | ) 166 | for i in range(self.min_history, self.min_history + self.states_to_update): 167 | j = i + 1 168 | target_val, lstm_state_target_c, lstm_state_target_h = self.sess.run( 169 | [self.q_target_out, self.state_output_target_c, self.state_output_target_h], 170 | { 171 | self.state_target: states[j], 172 | self.c_state_target: lstm_state_target_c, 173 | self.h_state_target: lstm_state_target_h 174 | } 175 | ) 176 | max_target = np.max(target_val, axis=1) 177 | target = (1. 
- terminal[i]) * self.gamma * max_target + reward[i] 178 | _, q_, train_loss_, lstm_state_c, lstm_state_h, merged_imgs= self.sess.run( 179 | [self.train_op, self.q_out, self.loss, self.state_output_c, self.state_output_h, self.merged_image_sum], 180 | feed_dict={ 181 | self.state: states[i], 182 | self.c_state_train: lstm_state_c, 183 | self.h_state_train: lstm_state_h, 184 | self.action: action[i], 185 | self.target_val: target, 186 | self.lr: self.learning_rate 187 | } 188 | ) 189 | q += q_ 190 | loss += train_loss_ 191 | if self.train_steps % 5000 == 0: 192 | self.file_writer.add_summary(merged_imgs, steps) 193 | if steps % 20000 == 0 and steps > 50000: 194 | self.learning_rate *= self.lr_decay # decay learning rate 195 | if self.learning_rate < self.learning_rate_minimum: 196 | self.learning_rate = self.learning_rate_minimum 197 | self.train_steps += 1 198 | return q.mean(), loss / (self.states_to_update) 199 | 200 | 201 | 202 | def add_loss_op_target(self): 203 | action_one_hot = tf.one_hot(self.action, self.n_actions, 1.0, 0.0, name='action_one_hot') 204 | train = tf.reduce_sum(self.q_out * action_one_hot, reduction_indices=1, name='q_acted') 205 | self.delta = train - self.target_val 206 | self.loss = tf.reduce_mean(huber_loss(self.delta)) 207 | 208 | avg_q = tf.reduce_mean(self.q_out, 0) 209 | q_summary = [] 210 | for i in range(self.n_actions): 211 | q_summary.append(tf.summary.histogram('q/{}'.format(i), avg_q[i])) 212 | self.merged_image_sum = tf.summary.merge(self.image_summary, "images") 213 | self.avg_q_summary = tf.summary.merge(q_summary, 'q_summary') 214 | self.loss_summary = tf.summary.scalar("loss", self.loss) 215 | 216 | def build(self): 217 | self.add_placeholders() 218 | self.add_logits_op_train() 219 | self.add_logits_op_target() 220 | self.add_loss_op_target() 221 | self.add_train_op(self.lr_method, self.lr, self.loss, clip=10) 222 | self.initialize_session() 223 | self.init_update() 224 | 225 | def update_target(self): 226 | for name in self.w: 227 | self.target_w_assign[name].eval({self.target_w_in[name]: self.w[name].eval(session=self.sess)}, 228 | session=self.sess) 229 | for var in self.lstm_vars: 230 | self.target_w_assign[var.name].eval({self.target_w_in[var.name]: var.eval(session=self.sess)}, 231 | session=self.sess) 232 | 233 | def init_update(self): 234 | self.target_w_in = {} 235 | self.target_w_assign = {} 236 | for name in self.w: 237 | self.target_w_in[name] = tf.placeholder(tf.float32, self.w_target[name].get_shape().as_list(), name=name) 238 | self.target_w_assign[name] = self.w_target[name].assign(self.target_w_in[name]) 239 | 240 | self.lstm_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm_train") 241 | lstm_target_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm_target") 242 | 243 | for i, var in enumerate(self.lstm_vars): 244 | self.target_w_in[var.name] = tf.placeholder(tf.float32, var.get_shape().as_list()) 245 | self.target_w_assign[var.name] = lstm_target_vars[i].assign(self.target_w_in[var.name]) 246 | -------------------------------------------------------------------------------- /src/replay_memory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import os 4 | 5 | class ReplayMemory: 6 | 7 | def __init__(self, config): 8 | self.config = config 9 | self.actions = np.empty((self.config.mem_size), dtype=np.int32) 10 | self.rewards = np.empty((self.config.mem_size), dtype=np.int32) 11 | # Screens are dtype=np.uint8 
which saves massive amounts of memory, however the network expects state inputs 12 | # to be dtype=np.float32. Remember this every time you feed something into the network 13 | self.screens = np.empty((self.config.mem_size, self.config.screen_height, self.config.screen_width), dtype=np.uint8) 14 | self.terminals = np.empty((self.config.mem_size,), dtype=np.float16) 15 | self.count = 0 16 | self.current = 0 17 | self.dir_save = config.dir_save + "memory/" 18 | 19 | if not os.path.exists(self.dir_save): 20 | os.makedirs(self.dir_save) 21 | 22 | def save(self): 23 | np.save(self.dir_save + "screens.npy", self.screens) 24 | np.save(self.dir_save + "actions.npy", self.actions) 25 | np.save(self.dir_save + "rewards.npy", self.rewards) 26 | np.save(self.dir_save + "terminals.npy", self.terminals) 27 | 28 | def load(self): 29 | self.screens = np.load(self.dir_save + "screens.npy") 30 | self.actions = np.load(self.dir_save + "actions.npy") 31 | self.rewards = np.load(self.dir_save + "rewards.npy") 32 | self.terminals = np.load(self.dir_save + "terminals.npy") 33 | 34 | 35 | 36 | class DQNReplayMemory(ReplayMemory): 37 | 38 | def __init__(self, config): 39 | super(DQNReplayMemory, self).__init__(config) 40 | 41 | self.pre = np.empty((self.config.batch_size, self.config.history_len, self.config.screen_height, self.config.screen_width), dtype=np.uint8) 42 | self.post = np.empty((self.config.batch_size, self.config.history_len, self.config.screen_height, self.config.screen_width), dtype=np.uint8) 43 | 44 | def getState(self, index): 45 | 46 | index = index % self.count 47 | if index >= self.config.history_len - 1: 48 | a = self.screens[(index - (self.config.history_len - 1)):(index + 1), ...] 49 | return a 50 | else: 51 | indices = [(index - i) % self.count for i in reversed(range(self.config.history_len))] 52 | return self.screens[indices, ...] 
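    # Note: the memory is a ring buffer -- add() below wraps `current` around mem_size,
    # and getState() above stitches the last history_len frames together even across that wrap.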
53 | 54 | def add(self, screen, reward, action, terminal): 55 | assert screen.shape == (self.config.screen_height, self.config.screen_width) 56 | 57 | self.actions[self.current] = action 58 | self.rewards[self.current] = reward 59 | self.screens[self.current] = screen 60 | self.terminals[self.current] = float(terminal) 61 | self.count = max(self.count, self.current + 1) 62 | self.current = (self.current + 1) % self.config.mem_size 63 | 64 | def sample_batch(self): 65 | assert self.count > self.config.history_len 66 | 67 | indices = [] 68 | while len(indices) < self.config.batch_size: 69 | 70 | while True: 71 | index = random.randint(self.config.history_len, self.count-1) 72 | 73 | if index >= self.current and index - self.config.history_len < self.current: 74 | continue 75 | 76 | if self.terminals[(index - self.config.history_len): index].any(): 77 | continue 78 | break 79 | self.pre[len(indices)] = self.getState(index - 1) 80 | self.post[len(indices)] = self.getState(index) 81 | indices.append(index) 82 | 83 | actions = self.actions[indices] 84 | rewards = self.rewards[indices] 85 | terminals = self.terminals[indices] 86 | 87 | return self.pre, actions, rewards, self.post, terminals 88 | 89 | class DRQNReplayMemory(ReplayMemory): 90 | 91 | def __init__(self, config): 92 | super(DRQNReplayMemory, self).__init__(config) 93 | 94 | self.timesteps = np.empty((self.config.mem_size), dtype=np.int32) 95 | self.states = np.empty((self.config.batch_size, self.config.min_history + self.config.states_to_update + 1, self.config.screen_height, self.config.screen_width), dtype=np.uint8) 96 | self.actions_out = np.empty((self.config.batch_size, self.config.min_history + self.config.states_to_update +1)) 97 | self.rewards_out = np.empty((self.config.batch_size, self.config.min_history + self.config.states_to_update +1)) 98 | self.terminals_out = np.empty((self.config.batch_size, self.config.min_history + self.config.states_to_update +1)) 99 | 100 | def add(self, screen, reward, action, terminal, t): 101 | assert screen.shape == (self.config.screen_height, self.config.screen_width) 102 | 103 | self.actions[self.current] = action 104 | self.rewards[self.current] = reward 105 | self.screens[self.current] = screen 106 | self.timesteps[self.current] = t 107 | self.terminals[self.current] = float(terminal) 108 | self.count = max(self.count, self.current + 1) 109 | self.current = (self.current + 1) % self.config.mem_size 110 | 111 | 112 | def getState(self, index): 113 | a = self.screens[index - (self.config.min_history + self.config.states_to_update + 1): index] 114 | return a 115 | 116 | def get_scalars(self, index): 117 | t = self.terminals[index - (self.config.min_history + self.config.states_to_update + 1): index] 118 | a = self.actions[index - (self.config.min_history + self.config.states_to_update + 1): index] 119 | r = self.rewards[index - (self.config.min_history + self.config.states_to_update + 1): index] 120 | return a, t, r 121 | 122 | def sample_batch(self): 123 | assert self.count > self.config.min_history + self.config.states_to_update 124 | 125 | indices = [] 126 | while len(indices) < self.config.batch_size: 127 | 128 | while True: 129 | index = random.randint(self.config.min_history, self.count-1) 130 | if index >= self.current and index - self.config.min_history < self.current: 131 | continue 132 | if index < self.config.min_history + self.config.states_to_update + 1: 133 | continue 134 | if self.timesteps[index] < self.config.min_history + self.config.states_to_update: 135 | continue 136 | 
break 137 | self.states[len(indices)] = self.getState(index) 138 | self.actions_out[len(indices)], self.terminals_out[len(indices)], self.rewards_out[len(indices)] = self.get_scalars(index) 139 | indices.append(index) 140 | 141 | 142 | return self.states, self.actions_out, self.rewards_out, self.terminals_out 143 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from scipy.misc import imresize as resize 2 | import json 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def rgb2gray(screen): 8 | return np.dot(screen[..., :3], [0.299, 0.587, 0.114]) 9 | 10 | 11 | def load_config(config_file): 12 | pass 13 | 14 | 15 | def save_config(config_file, config_dict): 16 | with open(config_file, 'w') as fp: 17 | json.dump(config_dict, fp) 18 | 19 | 20 | def conv2d_layer(x, output_dim, kernel_size, stride, initializer=None, padding="VALID", data_format="NCHW", 21 | summary_tag=None, 22 | scope_name="conv2d", activation=tf.nn.relu): 23 | with tf.variable_scope(scope_name): 24 | if data_format == 'NCHW': 25 | stride = [1, 1, stride[0], stride[1]] 26 | kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[1], output_dim] 27 | elif data_format == 'NHWC': 28 | stride = [1, stride[0], stride[1], 1] 29 | kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[-1], output_dim] 30 | 31 | w = tf.get_variable('w', kernel_shape, tf.float32, initializer=tf.truncated_normal_initializer(0, 0.02)) 32 | conv = tf.nn.conv2d(x, w, stride, padding, data_format=data_format) 33 | 34 | b = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 35 | out = tf.nn.bias_add(conv, b, data_format) 36 | 37 | if activation != None: 38 | out = activation(out) 39 | summary = None 40 | if summary_tag is not None: 41 | # TODO general definitions 42 | if output_dim == 32: 43 | ix = 4 44 | iy = 8 45 | elif output_dim == 64: 46 | ix = 8 47 | iy = 8 48 | 49 | img = tf.slice(out, [0, 0, 0, 0], [1, -1, -1, -1]) 50 | if data_format == "NCHW": 51 | img = tf.transpose(img, [0, 2, 3, 1]) 52 | out_shape = img.get_shape().as_list() 53 | img = tf.reshape(img, [out_shape[1], out_shape[2], out_shape[3]]) 54 | out_shape[1] += 4 55 | out_shape[2] += 4 56 | img = tf.image.resize_image_with_crop_or_pad(img, out_shape[1], out_shape[2]) 57 | img = tf.reshape(img, [out_shape[1], out_shape[2], ix, iy]) 58 | img = tf.transpose(img, [2, 0, 3, 1]) 59 | img = tf.reshape(img, [1, ix * out_shape[1], iy * out_shape[2], 1]) 60 | summary = tf.summary.image(summary_tag, img) 61 | return w, b, out, summary 62 | 63 | 64 | def fully_connected_layer(x, output_dim, scope_name="fully", initializer=tf.random_normal_initializer(stddev=0.02), 65 | activation=tf.nn.relu): 66 | shape = x.get_shape().as_list() 67 | with tf.variable_scope(scope_name): 68 | w = tf.get_variable("w", [shape[1], output_dim], dtype=tf.float32, 69 | initializer=initializer) 70 | b = tf.get_variable("b", [output_dim], dtype=tf.float32, 71 | initializer=tf.zeros_initializer()) 72 | out = tf.nn.xw_plus_b(x, w, b) 73 | if activation is not None: 74 | out = activation(out) 75 | 76 | return w, b, out 77 | 78 | def stateful_lstm(x, num_layers, lstm_size, state_input, scope_name="lstm"): 79 | with tf.variable_scope(scope_name): 80 | cell = tf.nn.rnn_cell.LSTMCell(lstm_size, state_is_tuple=True) 81 | cell = tf.nn.rnn_cell.MultiRNNCell([cell]*num_layers, state_is_tuple=True) 82 | outputs, state = tf.nn.dynamic_rnn(cell, x, 
initial_state=state_input)
83 |     return outputs, state
84 |
85 |
86 | def huber_loss(x, delta=1.0):
87 |     return tf.where(tf.abs(x) < delta, 0.5 * tf.square(x), delta * tf.abs(x) - 0.5 * delta ** 2)  # quadratic inside |x| < delta, linear outside
88 |
89 | def integer_product(x):
90 |     return int(np.prod(x))
91 |
92 |
93 | def initializer_bounds_filter(filter_shape):
94 |     fan_in = integer_product(filter_shape[:3])
95 |     fan_out = integer_product(filter_shape[:2]) * filter_shape[3]
96 |     return np.sqrt(6. / (fan_in + fan_out))
97 |
--------------------------------------------------------------------------------
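A usage note on the pretrained model: `main.py` switches from training to playing whenever `--train` is not the string `True`, and then requires `--net_path` to point at a TensorFlow checkpoint. A minimal sketch, assuming `pretrained_models/net.tar.gz` has been extracted and `<checkpoint prefix>` stands in for the extracted checkpoint prefix (the archive's exact contents are not listed in this dump):
```
tar -xzf pretrained_models/net.tar.gz -C pretrained_models
python main.py --gym=gym --network_type=dqn --env_name=Breakout-v0 --train=False --net_path=pretrained_models/<checkpoint prefix>
```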