├── README.md
├── config.py
├── dqn_model.py
├── dqn_utils.py
├── env.py
├── figs
│   ├── breakout_avg_eval_reward_steps.png
│   ├── breakout_avg_training_reward_steps.png
│   ├── breakout_winning_gaps.gif
│   ├── freeway_9heads_prior_episode_reward.png
│   ├── freeway_eval_rewards_steps.png
│   ├── pong_eval_rewards_steps.png
│   ├── small_freeway_ATARI_s14547756_R34.gif
│   └── small_pong_ATARI_step0002509459_r0021_testcolor.gif
├── replay.py
├── roms
│   ├── air_raid.bin
│   ├── alien.bin
│   ├── amidar.bin
│   ├── assault.bin
│   ├── asterix.bin
│   ├── asteroids.bin
│   ├── atlantis.bin
│   ├── bank_heist.bin
│   ├── battle_zone.bin
│   ├── beam_rider.bin
│   ├── berzerk.bin
│   ├── bowling.bin
│   ├── boxing.bin
│   ├── breakout.bin
│   ├── carnival.bin
│   ├── centipede.bin
│   ├── chopper_command.bin
│   ├── crazy_climber.bin
│   ├── defender.bin
│   ├── demon_attack.bin
│   ├── double_dunk.bin
│   ├── elevator_action.bin
│   ├── enduro.bin
│   ├── fishing_derby.bin
│   ├── freeway.bin
│   ├── frostbite.bin
│   ├── gopher.bin
│   ├── gravitar.bin
│   ├── hero.bin
│   ├── ice_hockey.bin
│   ├── jamesbond.bin
│   ├── journey_escape.bin
│   ├── kangaroo.bin
│   ├── krull.bin
│   ├── kung_fu_master.bin
│   ├── montezuma_revenge.bin
│   ├── ms_pacman.bin
│   ├── name_this_game.bin
│   ├── phoenix.bin
│   ├── pitfall.bin
│   ├── pong.bin
│   ├── pooyan.bin
│   ├── private_eye.bin
│   ├── qbert.bin
│   ├── riverraid.bin
│   ├── road_runner.bin
│   ├── robotank.bin
│   ├── seaquest.bin
│   ├── skiing.bin
│   ├── solaris.bin
│   ├── space_invaders.bin
│   ├── star_gunner.bin
│   ├── tennis.bin
│   ├── time_pilot.bin
│   ├── tutankham.bin
│   ├── up_n_down.bin
│   ├── venture.bin
│   ├── video_pinball.bin
│   ├── wizard_of_wor.bin
│   ├── yars_revenge.bin
│   └── zaxxon.bin
└── run_bootstrap.py

/README.md:
--------------------------------------------------------------------------------
1 | # Bootstrapped DQN with options to add a Randomized Prior, Dueling, and Double DQN in ALE games
2 |
3 | This repo contains our implementation of a Bootstrapped DQN with options to add a Randomized Prior, Dueling, and Double DQN in ALE games.
4 |
5 | [Deep Exploration via Bootstrapped DQN](https://arxiv.org/abs/1602.04621)
6 |
7 | [Randomized Prior Functions for Deep Reinforcement Learning](https://arxiv.org/abs/1806.03335)
8 |
9 | # Some results on Breakout
10 |
11 | ![alt text](figs/breakout_winning_gaps.gif?raw=true "Breakout Agent - Bootstrap, Prior")
12 |
13 | This gif depicts the orange agent from the comparison below winning the first game of Breakout and eventually winning a second game. The agent reaches a high score of 830 in this evaluation. There are several gaps in playback due to file size. We show agent steps [1000-1500], [2400-2600], [3000-4500], and [16000-16300].
14 |
15 | ### Comparison:
16 |
17 | - (blue) DQN with epsilon greedy annealed between 1 and 0.01
18 | - (orange) Bootstrap with epsilon greedy annealed between 1 and 0.01
19 | - (green) Bootstrap without epsilon greedy exploration
20 | - (red) Bootstrap with randomized prior
21 |
22 | All agents were implemented as Dueling, Double DQNs. The x-axis in these plots, "steps",
23 | refers to the number of states the agent has observed so far in training. Multiply by 4 (the frame skip) to get the total number of frames the emulator has progressed.
24 |
25 | Our agents are sent a terminal signal at the end of each life. They face a deterministic state progression after a random number (<30) of no-op steps at the beginning of each episode.
26 |
27 | ![alt text](figs/breakout_avg_training_reward_steps.png?raw=true "Breakout Agent - Bootstrap, Prior")
28 |
29 | # Some results on Pong
30 |
31 | Here are some results on Pong with Bootstrap DQN with a Randomized Prior. An optimal strategy is learned within 2.5M steps.
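A comparable run can be launched with `python run_bootstrap.py -c` (GPU) or `python run_bootstrap.py` (CPU); the default `info` dict in `run_bootstrap.py` already targets `roms/pong.bin` with the randomized prior, dueling, and double DQN options enabled.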
32 |
33 | ![alt text](figs/small_pong_ATARI_step0002509459_r0021_testcolor.gif?raw=true "Pong Agent - Bootstrap with Prior")
34 |
35 | Pong agent score in evaluation - reward vs steps
36 | ![alt text](figs/pong_eval_rewards_steps.png?raw=true "Eval Pong Agent - Bootstrap with Prior Reward v Steps")
37 |
38 | # Some results on Freeway
39 |
40 | Here are some results on Freeway with Bootstrap DQN with a Randomized Prior. The randomized prior allowed us to solve this "hard exploration" problem within 4 million steps.
41 |
42 | ![alt text](figs/small_freeway_ATARI_s14547756_R34.gif?raw=true "Freeway Agent - Bootstrap with Prior")
43 |
44 | Freeway agent score in evaluation - reward vs steps
45 |
46 | ![alt text](figs/freeway_eval_rewards_steps.png?raw=true "Freeway Agent - Bootstrap with Prior")
47 |
48 | # Dependencies
49 |
50 | atari-py installed from https://github.com/kastnerkyle/atari-py
51 | torch='1.0.1.post2'
52 | cv2='4.0.0'
53 |
54 |
55 | # References
56 |
57 | We referenced several excellent examples/blogposts to build this codebase:
58 |
59 | [Discussion and debugging w/ Kyle Kastner](https://gist.github.com/kastnerkyle/a4498fdf431a3a6d551bcc30cd9a35a0)
60 |
61 | [Fabio M. Graetz's DQN](https://github.com/fg91/Deep-Q-Learning/blob/master/DQN.ipynb)
62 |
63 | [hengyuan-hu's Rainbow](https://github.com/hengyuan-hu/rainbow)
64 |
65 | [Dopamine's baseline](https://github.com/google/dopamine)
66 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | base_datadir = '../dataset/'
2 | model_savedir = '../model_savedir'
3 | results_savedir = '../results'
--------------------------------------------------------------------------------
/dqn_model.py:
--------------------------------------------------------------------------------
1 |
2 | # Model style from Kyle @
3 | # https://gist.github.com/kastnerkyle/a4498fdf431a3a6d551bcc30cd9a35a0
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | from IPython import embed
9 |
10 | # from the DQN paper
11 | #The first convolution layer convolves the input with 32 filters of size 8 (stride 4),
12 | #the second layer has 64 filters of size 4
13 | #(stride 2), the final convolution layer has 64 filters of size 3 (stride
14 | #1). This is followed by a fully-connected hidden layer of 512 units.
15 |
16 | # init func used by hengyuan-hu
17 | def weights_init(m):
18 |     """custom weights initialization"""
19 |     classtype = m.__class__
20 |     if classtype == nn.Linear or classtype == nn.Conv2d:
21 |         print("default init")
22 |         #m.weight.data.normal_(0.0, 0.02)
23 |         #m.bias.data.fill_(0)
24 |     elif classtype == nn.BatchNorm2d:
25 |         m.weight.data.normal_(1.0, 0.02)
26 |         m.bias.data.fill_(0)
27 |     else:
28 |         print('%s is not initialized.' %classtype)
29 |
30 |
31 | class CoreNet(nn.Module):
32 |     def __init__(self, network_output_size=84, num_channels=4):
33 |         super(CoreNet, self).__init__()
34 |         self.network_output_size = network_output_size
35 |         self.num_channels = num_channels
36 |         # params from ddqn appendix
37 |         self.conv1 = nn.Conv2d(self.num_channels, 32, 8, 4)
38 |         # TODO - should we have this init during PRIOR code?
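        # with the default 84x84 input, spatial dims go 84 -> 20 (conv1) -> 9 (conv2) -> 7 (conv3),
        # which is where the 64*7*7 flatten size used in forward() comes from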
39 | self.conv2 = nn.Conv2d(32, 64, 4, 2) 40 | self.conv3 = nn.Conv2d(64, 64, 3, 1) 41 | self.conv1.apply(weights_init) 42 | self.conv2.apply(weights_init) 43 | self.conv3.apply(weights_init) 44 | 45 | def forward(self, x): 46 | x = F.relu(self.conv1(x)) 47 | x = F.relu(self.conv2(x)) 48 | x = F.relu(self.conv3(x)) 49 | # size after conv3 50 | reshape = 64*7*7 51 | x = x.view(-1, reshape) 52 | return x 53 | 54 | class DuelingHeadNet(nn.Module): 55 | def __init__(self, n_actions=4): 56 | super(DuelingHeadNet, self).__init__() 57 | mult = 64*7*7 58 | self.split_size = 512 59 | self.fc1 = nn.Linear(mult, self.split_size*2) 60 | self.value = nn.Linear(self.split_size, 1) 61 | self.advantage = nn.Linear(self.split_size, n_actions) 62 | self.fc1.apply(weights_init) 63 | self.value.apply(weights_init) 64 | self.advantage.apply(weights_init) 65 | 66 | def forward(self, x): 67 | x1,x2 = torch.split(F.relu(self.fc1(x)), self.split_size, dim=1) 68 | value = self.value(x1) 69 | advantage = self.advantage(x2) 70 | # value is shape [batch_size, 1] 71 | # advantage is shape [batch_size, n_actions] 72 | q = value + torch.sub(advantage, torch.mean(advantage, dim=1, keepdim=True)) 73 | return q 74 | 75 | class HeadNet(nn.Module): 76 | def __init__(self, n_actions=4): 77 | super(HeadNet, self).__init__() 78 | mult = 64*7*7 79 | self.fc1 = nn.Linear(mult, 512) 80 | self.fc2 = nn.Linear(512, n_actions) 81 | self.fc1.apply(weights_init) 82 | self.fc2.apply(weights_init) 83 | 84 | def forward(self, x): 85 | x = F.relu(self.fc1(x)) 86 | x = self.fc2(x) 87 | return x 88 | 89 | class EnsembleNet(nn.Module): 90 | def __init__(self, n_ensemble, n_actions, network_output_size, num_channels, dueling=False): 91 | super(EnsembleNet, self).__init__() 92 | self.core_net = CoreNet(network_output_size=network_output_size, num_channels=num_channels) 93 | self.dueling = dueling 94 | if self.dueling: 95 | print("using dueling dqn") 96 | self.net_list = nn.ModuleList([DuelingHeadNet(n_actions=n_actions) for k in range(n_ensemble)]) 97 | else: 98 | self.net_list = nn.ModuleList([HeadNet(n_actions=n_actions) for k in range(n_ensemble)]) 99 | 100 | def _core(self, x): 101 | return self.core_net(x) 102 | 103 | def _heads(self, x): 104 | return [net(x) for net in self.net_list] 105 | 106 | def forward(self, x, k): 107 | if k is not None: 108 | return self.net_list[k](self.core_net(x)) 109 | else: 110 | core_cache = self._core(x) 111 | net_heads = self._heads(core_cache) 112 | return net_heads 113 | 114 | class NetWithPrior(nn.Module): 115 | def __init__(self, net, prior, prior_scale=1.): 116 | super(NetWithPrior, self).__init__() 117 | self.net = net 118 | # used when scaling core net 119 | self.core_net = self.net.core_net 120 | self.prior_scale = prior_scale 121 | if self.prior_scale > 0.: 122 | self.prior = prior 123 | 124 | def forward(self, x, k): 125 | if hasattr(self.net, "net_list"): 126 | if k is not None: 127 | if self.prior_scale > 0.: 128 | return self.net(x, k) + self.prior_scale * self.prior(x, k).detach() 129 | else: 130 | return self.net(x, k) 131 | else: 132 | core_cache = self.net._core(x) 133 | net_heads = self.net._heads(core_cache) 134 | if self.prior_scale <= 0.: 135 | return net_heads 136 | else: 137 | prior_core_cache = self.prior._core(x) 138 | prior_heads = self.prior._heads(prior_core_cache) 139 | return [n + self.prior_scale * p.detach() for n, p in zip(net_heads, prior_heads)] 140 | else: 141 | raise ValueError("Only works with a net_list model") 142 | 143 | 144 | 
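# Illustrative usage sketch: build the bootstrapped ensemble wrapped with a
# randomized prior, query a single head, and let all heads vote on an action.
# The sizes here (9 heads, 4 actions, a 4x84x84 input) are example values chosen
# to match the defaults used elsewhere in this repo.
if __name__ == '__main__':
    from collections import Counter
    n_ensemble, n_actions = 9, 4
    policy = EnsembleNet(n_ensemble=n_ensemble, n_actions=n_actions,
                         network_output_size=84, num_channels=4, dueling=True)
    prior = EnsembleNet(n_ensemble=n_ensemble, n_actions=n_actions,
                        network_output_size=84, num_channels=4, dueling=True)
    policy = NetWithPrior(policy, prior, prior_scale=10.)
    x = torch.zeros(1, 4, 84, 84)           # fake stack of 4 grayscale frames
    q_head3 = policy(x, 3)                  # Q-values from head 3 only, shape [1, n_actions]
    q_all = policy(x, None)                 # list with one Q-value tensor per head
    votes = [int(torch.argmax(q, dim=1)) for q in q_all]
    action = Counter(votes).most_common(1)[0][0]
    print(q_head3.shape, action)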
-------------------------------------------------------------------------------- /dqn_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | import sys 5 | from imageio import mimsave 6 | #from skimage.transform import resize 7 | import cv2 8 | 9 | def save_checkpoint(state, filename='model.pkl'): 10 | print("starting save of model %s" %filename) 11 | torch.save(state, filename) 12 | print("finished save of model %s" %filename) 13 | 14 | def seed_everything(seed=1234): 15 | #random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | np.random.seed(seed) 19 | os.environ['PYTHONHASHSEED'] = str(seed) 20 | #torch.backends.cudnn.deterministic = True 21 | 22 | def handle_step(random_state, cnt, S_hist, S_prime, action, reward, finished, k_used, acts, episodic_reward, replay_buffer, checkpoint='', n_ensemble=1, bernoulli_p=1.0): 23 | # mask to determine which head can use this experience 24 | exp_mask = random_state.binomial(1, bernoulli_p, n_ensemble).astype(np.uint8) 25 | # at this observed state 26 | experience = [S_prime, action, reward, finished, exp_mask, k_used, acts, cnt] 27 | batch = replay_buffer.send((checkpoint, experience)) 28 | # update so "state" representation is past history_size frames 29 | S_hist.pop(0) 30 | S_hist.append(S_prime) 31 | episodic_reward += reward 32 | cnt+=1 33 | return cnt, S_hist, batch, episodic_reward 34 | 35 | def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon): 36 | """ from dopamine - Returns the current epsilon for the agent's epsilon-greedy policy. 37 | This follows the Nature DQN schedule of a linearly decaying epsilon (Mnih et 38 | al., 2015). The schedule is as follows: 39 | Begin at 1. until warmup_steps steps have been taken; then 40 | Linearly decay epsilon from 1. to epsilon in decay_period steps; and then 41 | Use epsilon from there on. 42 | Args: 43 | decay_period: float, the period over which epsilon is decayed. 44 | step: int, the number of training steps completed so far. 45 | warmup_steps: int, the number of steps taken before epsilon is decayed. 46 | epsilon: float, the final value to which to decay the epsilon parameter. 47 | Returns: 48 | A float, the current epsilon value computed according to the schedule. 49 | """ 50 | steps_left = decay_period + warmup_steps - step 51 | bonus = (1.0 - epsilon) * steps_left / decay_period 52 | bonus = np.clip(bonus, 0., 1. 
- epsilon) 53 | return epsilon + bonus 54 | 55 | def write_info_file(info, model_base_filepath, cnt): 56 | info_filename = model_base_filepath + "_%010d_info.txt"%cnt 57 | info_f = open(info_filename, 'w') 58 | for (key,val) in info.items(): 59 | info_f.write('%s=%s\n'%(key,val)) 60 | info_f.close() 61 | 62 | def generate_gif(base_dir, step_number, frames_for_gif, reward, name='', results=[]): 63 | """ 64 | from @fg91 65 | Args: 66 | step_number: Integer, determining the number of the current frame 67 | frames_for_gif: A sequence of (210, 160, 3) frames of an Atari game in RGB 68 | reward: Integer, Total reward of the episode that es ouputted as a gif 69 | path: String, path where gif is saved 70 | """ 71 | for idx, frame_idx in enumerate(frames_for_gif): 72 | frames_for_gif[idx] = cv2.resize(frame_idx, (320, 220)).astype(np.uint8) 73 | 74 | if len(frames_for_gif[0].shape) == 2: 75 | name+='gray' 76 | else: 77 | name+='color' 78 | gif_fname = os.path.join(base_dir, "ATARI_step%010d_r%04d_%s.gif"%(step_number, int(reward), name)) 79 | 80 | print("WRITING GIF", gif_fname) 81 | mimsave(gif_fname, frames_for_gif, duration=1/30) 82 | if len(results): 83 | txt_fname = gif_fname.replace('.gif', '.txt') 84 | ff = open(txt_fname, 'w') 85 | for ex in results: 86 | ff.write(ex+'\n') 87 | ff.close() 88 | 89 | 90 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import deque 3 | import numpy as np 4 | from atari_py.ale_python_interface import ALEInterface 5 | import cv2 6 | 7 | #from skimage.transform import resize 8 | #from skimage.color import rgb2gray 9 | #from imageio import imwrite 10 | #def preprocess_frame(observ, output_size): 11 | # return resize(rgb2gray(observ),(output_size, output_size)).astype(np.float32, copy=False) 12 | 13 | # opencv is ~3x faster than skimage 14 | def cv_preprocess_frame(observ, output_size): 15 | gray = cv2.cvtColor(observ, cv2.COLOR_RGB2GRAY) 16 | output = cv2.resize(gray, (output_size, output_size), interpolation=cv2.INTER_NEAREST) 17 | return output 18 | 19 | class Environment(object): 20 | def __init__(self, 21 | rom_file, 22 | frame_skip=4, 23 | num_frames=4, 24 | frame_size=84, 25 | no_op_start=30, 26 | rand_seed=393, 27 | dead_as_end=True, 28 | max_episode_steps=18000, 29 | autofire=False): 30 | self.max_episode_steps = max_episode_steps 31 | self.random_state = np.random.RandomState(rand_seed+15) 32 | self.ale = self._init_ale(rand_seed, rom_file) 33 | # normally (160, 210) 34 | self.actions = self.ale.getMinimalActionSet() 35 | 36 | self.frame_skip = frame_skip 37 | self.num_frames = num_frames 38 | self.frame_size = frame_size 39 | self.no_op_start = no_op_start 40 | self.dead_as_end = dead_as_end 41 | 42 | self.total_reward = 0 43 | screen_width, screen_height = self.ale.getScreenDims() 44 | self.prev_screen = np.zeros( 45 | (screen_height, screen_width, 3), dtype=np.uint8) 46 | self.frame_queue = deque(maxlen=num_frames) 47 | self.end = True 48 | 49 | @staticmethod 50 | def _init_ale(rand_seed, rom_file): 51 | assert os.path.exists(rom_file), '%s does not exists.' 
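        # note: ALE is configured below with frame_skip=1 and
        # repeat_action_probability=0; frame skipping and the max-pooling of
        # consecutive screens are handled by Environment.step() and
        # _get_current_frame() instead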
52 | ale = ALEInterface() 53 | ale.setInt('random_seed', rand_seed) 54 | ale.setBool('showinfo', False) 55 | ale.setInt('frame_skip', 1) 56 | ale.setFloat('repeat_action_probability', 0.0) 57 | ale.setBool('color_averaging', False) 58 | ale.loadROM(rom_file) 59 | return ale 60 | 61 | @property 62 | def num_actions(self): 63 | return len(self.actions) 64 | 65 | def _get_current_frame(self): 66 | # global glb_counter 67 | screen = self.ale.getScreenRGB() 68 | max_screen = np.maximum(self.prev_screen, screen) 69 | frame = cv_preprocess_frame(max_screen, self.frame_size) 70 | return frame 71 | 72 | def reset(self): 73 | self.steps = 0 74 | self.end = False 75 | self.plot_frames = [] 76 | self.gray_plot_frames = [] 77 | for _ in range(self.num_frames - 1): 78 | self.frame_queue.append( 79 | np.zeros((self.frame_size, self.frame_size), dtype=np.uint8)) 80 | 81 | # steps are in steps the agent sees 82 | self.ale.reset_game() 83 | self.total_reward = 0 84 | self.prev_screen = np.zeros(self.prev_screen.shape, dtype=np.uint8) 85 | 86 | n = self.random_state.randint(0, self.no_op_start) 87 | for i in range(n): 88 | if i == n - 1: 89 | self.prev_screen = self.ale.getScreenRGB() 90 | self.ale.act(0) 91 | 92 | self.frame_queue.append(self._get_current_frame()) 93 | self.plot_frames.append(self.prev_screen) 94 | a = np.array(self.frame_queue) 95 | out = np.concatenate((a[0],a[1],a[2],a[3]),axis=0).T 96 | self.gray_plot_frames.append(out) 97 | if self.ale.game_over(): 98 | print("Unexpected game over in reset", self.reset()) 99 | return np.array(self.frame_queue) 100 | 101 | def step(self, action_idx): 102 | """Perform action and return frame sequence and reward. 103 | Return: 104 | state: [frames] of length num_frames, 0 if fewer is available 105 | reward: float 106 | """ 107 | assert not self.end 108 | reward = 0 109 | old_lives = self.ale.lives() 110 | 111 | for i in range(self.frame_skip): 112 | if i == self.frame_skip - 1: 113 | self.prev_screen = self.ale.getScreenRGB() 114 | r = self.ale.act(self.actions[action_idx]) 115 | reward += r 116 | dead = (self.ale.lives() < old_lives) 117 | if self.dead_as_end and dead: 118 | lives_dead = True 119 | else: 120 | lives_dead = False 121 | 122 | if self.ale.game_over(): 123 | self.end = True 124 | lives_dead = True 125 | self.steps +=1 126 | if self.steps >= self.max_episode_steps: 127 | self.end = True 128 | lives_dead = True 129 | self.frame_queue.append(self._get_current_frame()) 130 | self.total_reward += reward 131 | a = np.array(self.frame_queue) 132 | self.prev_screen = self.ale.getScreenRGB() 133 | self.gray_plot_frames.append(np.concatenate((a[0],a[1],a[2],a[3]),axis=0)) 134 | self.plot_frames.append(self.prev_screen) 135 | return np.array(self.frame_queue), reward, lives_dead, self.end 136 | 137 | 138 | if __name__ == '__main__': 139 | 140 | import time 141 | env = Environment('roms/breakout.bin', 4, 4, 84, 30, 33, True) 142 | print('starting with game over?', env.ale.game_over()) 143 | random_state = np.random.RandomState(304) 144 | 145 | state = env.reset() 146 | i = 0 147 | times = [0.0 for a in range(1000)] 148 | total_reward = 0 149 | for t in times: 150 | action = random_state.randint(0, env.num_actions) 151 | st = time.time() 152 | state, reward, end, do_reset = env.step(action) 153 | total_reward += reward 154 | times[i] = time.time()-st 155 | i += 1 156 | if end: 157 | print(i,"END") 158 | if do_reset: 159 | print(i,"RESET") 160 | state = env.reset() 161 | print('total_reward:', total_reward) 162 | print('total steps:', i) 163 | print('mean 
time', np.mean(times)) 164 | print('max time', np.max(times)) 165 | # with cv - 1000 steps ('mean time', 0.0008075275421142578) max 0.0008950233459472656 166 | # with skimage - ('mean time', 0.0022023658752441406) max, 0.003056049346923828 167 | -------------------------------------------------------------------------------- /figs/breakout_avg_eval_reward_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/figs/breakout_avg_eval_reward_steps.png -------------------------------------------------------------------------------- /figs/breakout_avg_training_reward_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/figs/breakout_avg_training_reward_steps.png -------------------------------------------------------------------------------- /figs/breakout_winning_gaps.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/figs/breakout_winning_gaps.gif -------------------------------------------------------------------------------- /figs/freeway_9heads_prior_episode_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/figs/freeway_9heads_prior_episode_reward.png -------------------------------------------------------------------------------- /figs/freeway_eval_rewards_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/figs/freeway_eval_rewards_steps.png -------------------------------------------------------------------------------- /figs/pong_eval_rewards_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/figs/pong_eval_rewards_steps.png -------------------------------------------------------------------------------- /figs/small_freeway_ATARI_s14547756_R34.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/figs/small_freeway_ATARI_s14547756_R34.gif -------------------------------------------------------------------------------- /figs/small_pong_ATARI_step0002509459_r0021_testcolor.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/figs/small_pong_ATARI_step0002509459_r0021_testcolor.gif -------------------------------------------------------------------------------- /replay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | # This function was mostly pulled from 5 | # https://github.com/fg91/Deep-Q-Learning/blob/master/DQN.ipynb 6 | class ReplayMemory: 7 | """Replay Memory that stores the last size=1,000,000 transitions""" 8 | def __init__(self, size=1000000, frame_height=84, frame_width=84, 9 | agent_history_length=4, batch_size=32, num_heads=1, 
bernoulli_probability=1.0): 10 | """ 11 | Args: 12 | size: Integer, Number of stored transitions 13 | frame_height: Integer, Height of a frame of an Atari game 14 | frame_width: Integer, Width of a frame of an Atari game 15 | agent_history_length: Integer, Number of frames stacked together to create a state 16 | batch_size: Integer, Number if transitions returned in a minibatch 17 | num_heads: integer number of heads needed in mask 18 | bernoulli_probability: bernoulli probability that an experience will go to a particular head 19 | """ 20 | self.bernoulli_probability = bernoulli_probability 21 | assert(self.bernoulli_probability > 0) 22 | self.size = size 23 | self.frame_height = frame_height 24 | self.frame_width = frame_width 25 | self.agent_history_length = agent_history_length 26 | self.count = 0 27 | self.current = 0 28 | self.num_heads = num_heads 29 | # Pre-allocate memory 30 | self.actions = np.empty(self.size, dtype=np.int32) 31 | self.rewards = np.empty(self.size, dtype=np.float32) 32 | self.frames = np.empty((self.size, self.frame_height, self.frame_width), dtype=np.uint8) 33 | self.terminal_flags = np.empty(self.size, dtype=np.bool) 34 | self.masks = np.empty((self.size, self.num_heads), dtype=np.bool) 35 | 36 | # Pre-allocate memory for the states and new_states in a minibatch 37 | self.states = np.empty((batch_size, self.agent_history_length, 38 | self.frame_height, self.frame_width), dtype=np.uint8) 39 | self.new_states = np.empty((batch_size, self.agent_history_length, 40 | self.frame_height, self.frame_width), dtype=np.uint8) 41 | self.indices = np.empty(batch_size, dtype=np.int32) 42 | self.random_state = np.random.RandomState(393) 43 | if self.num_heads == 1: 44 | assert(self.bernoulli_probability == 1.0) 45 | 46 | def save_buffer(self, filepath): 47 | st = time.time() 48 | print("starting save of buffer to %s"%filepath, st) 49 | np.savez(filepath, 50 | frames=self.frames, actions=self.actions, rewards=self.rewards, 51 | terminal_flags=self.terminal_flags, masks=self.masks, 52 | count=self.count, current=self.current, 53 | agent_history_length=self.agent_history_length, 54 | frame_height=self.frame_height, frame_width=self.frame_width, 55 | num_heads=self.num_heads, bernoulli_probability=self.bernoulli_probability, 56 | ) 57 | print("finished saving buffer", time.time()-st) 58 | 59 | def load_buffer(self, filepath): 60 | st = time.time() 61 | print("starting load of buffer from %s"%filepath, st) 62 | npfile = np.load(filepath) 63 | self.frames = npfile['frames'] 64 | self.actions = npfile['actions'] 65 | self.rewards = npfile['rewards'] 66 | self.terminal_flags = npfile['terminal_flags'] 67 | self.masks = npfile['masks'] 68 | self.count = npfile['count'] 69 | self.current = npfile['current'] 70 | self.agent_history_length = npfile['agent_history_length'] 71 | self.frame_height = npfile['frame_height'] 72 | self.frame_width = npfile['frame_width'] 73 | self.num_heads = npfile['num_heads'] 74 | self.bernoulli_probability = npfile['bernoulli_probability'] 75 | if self.num_heads == 1: 76 | assert(self.bernoulli_probability == 1.0) 77 | print("finished loading buffer", time.time()-st) 78 | print("loaded buffer current is", self.current) 79 | 80 | def add_experience(self, action, frame, reward, terminal): 81 | """ 82 | Args: 83 | action: An integer between 0 and env.action_space.n - 1 84 | determining the action the agent perfomed 85 | frame: A (84, 84, 1) frame of an Atari game in grayscale 86 | reward: A float determining the reward the agend received for performing an 
action 87 | terminal: A bool stating whether the episode terminated 88 | """ 89 | if frame.shape != (self.frame_height, self.frame_width): 90 | raise ValueError('Dimension of frame is wrong!') 91 | self.actions[self.current] = action 92 | self.frames[self.current, ...] = frame 93 | self.rewards[self.current] = reward 94 | self.terminal_flags[self.current] = terminal 95 | mask = self.random_state.binomial(1, self.bernoulli_probability, self.num_heads) 96 | self.masks[self.current] = mask 97 | self.count = max(self.count, self.current+1) 98 | self.current = (self.current + 1) % self.size 99 | 100 | def _get_state(self, index): 101 | if self.count is 0: 102 | raise ValueError("The replay memory is empty!") 103 | if index < self.agent_history_length - 1: 104 | raise ValueError("Index must be min 3") 105 | return self.frames[index-self.agent_history_length+1:index+1, ...] 106 | 107 | def _get_valid_indices(self, batch_size): 108 | if batch_size != self.indices.shape[0]: 109 | self.indices = np.empty(batch_size, dtype=np.int32) 110 | 111 | for i in range(batch_size): 112 | while True: 113 | index = self.random_state.randint(self.agent_history_length, self.count - 1) 114 | if index < self.agent_history_length: 115 | continue 116 | if index >= self.current and index - self.agent_history_length <= self.current: 117 | continue 118 | # dont add if there was a terminal flag in previous 119 | # history_length steps 120 | if self.terminal_flags[index - self.agent_history_length:index].any(): 121 | continue 122 | break 123 | self.indices[i] = index 124 | 125 | def get_minibatch(self, batch_size): 126 | """ 127 | Returns a minibatch of batch_size 128 | """ 129 | if batch_size != self.states.shape[0]: 130 | self.states = np.empty((batch_size, self.agent_history_length, 131 | self.frame_height, self.frame_width), dtype=np.uint8) 132 | self.new_states = np.empty((batch_size, self.agent_history_length, 133 | self.frame_height, self.frame_width), dtype=np.uint8) 134 | 135 | if self.count < self.agent_history_length: 136 | raise ValueError('Not enough memories to get a minibatch') 137 | 138 | self._get_valid_indices(batch_size) 139 | 140 | for i, idx in enumerate(self.indices): 141 | self.states[i] = self._get_state(idx - 1) 142 | self.new_states[i] = self._get_state(idx) 143 | return self.states, self.actions[self.indices], self.rewards[self.indices], self.new_states, self.terminal_flags[self.indices], self.masks[self.indices] 144 | 145 | 146 | -------------------------------------------------------------------------------- /roms/air_raid.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/air_raid.bin -------------------------------------------------------------------------------- /roms/alien.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/alien.bin -------------------------------------------------------------------------------- /roms/amidar.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/amidar.bin -------------------------------------------------------------------------------- /roms/assault.bin: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/assault.bin -------------------------------------------------------------------------------- /roms/asterix.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/asterix.bin -------------------------------------------------------------------------------- /roms/asteroids.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/asteroids.bin -------------------------------------------------------------------------------- /roms/atlantis.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/atlantis.bin -------------------------------------------------------------------------------- /roms/bank_heist.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/bank_heist.bin -------------------------------------------------------------------------------- /roms/battle_zone.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/battle_zone.bin -------------------------------------------------------------------------------- /roms/beam_rider.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/beam_rider.bin -------------------------------------------------------------------------------- /roms/berzerk.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/berzerk.bin -------------------------------------------------------------------------------- /roms/bowling.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/bowling.bin -------------------------------------------------------------------------------- /roms/boxing.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/boxing.bin -------------------------------------------------------------------------------- /roms/breakout.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/breakout.bin -------------------------------------------------------------------------------- /roms/carnival.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/carnival.bin -------------------------------------------------------------------------------- /roms/centipede.bin: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/centipede.bin -------------------------------------------------------------------------------- /roms/chopper_command.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/chopper_command.bin -------------------------------------------------------------------------------- /roms/crazy_climber.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/crazy_climber.bin -------------------------------------------------------------------------------- /roms/defender.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/defender.bin -------------------------------------------------------------------------------- /roms/demon_attack.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/demon_attack.bin -------------------------------------------------------------------------------- /roms/double_dunk.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/double_dunk.bin -------------------------------------------------------------------------------- /roms/elevator_action.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/elevator_action.bin -------------------------------------------------------------------------------- /roms/enduro.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/enduro.bin -------------------------------------------------------------------------------- /roms/fishing_derby.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/fishing_derby.bin -------------------------------------------------------------------------------- /roms/freeway.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/freeway.bin -------------------------------------------------------------------------------- /roms/frostbite.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/frostbite.bin -------------------------------------------------------------------------------- /roms/gopher.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/gopher.bin 
-------------------------------------------------------------------------------- /roms/gravitar.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/gravitar.bin -------------------------------------------------------------------------------- /roms/hero.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/hero.bin -------------------------------------------------------------------------------- /roms/ice_hockey.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/ice_hockey.bin -------------------------------------------------------------------------------- /roms/jamesbond.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/jamesbond.bin -------------------------------------------------------------------------------- /roms/journey_escape.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/journey_escape.bin -------------------------------------------------------------------------------- /roms/kangaroo.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/kangaroo.bin -------------------------------------------------------------------------------- /roms/krull.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/krull.bin -------------------------------------------------------------------------------- /roms/kung_fu_master.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/kung_fu_master.bin -------------------------------------------------------------------------------- /roms/montezuma_revenge.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/montezuma_revenge.bin -------------------------------------------------------------------------------- /roms/ms_pacman.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/ms_pacman.bin -------------------------------------------------------------------------------- /roms/name_this_game.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/name_this_game.bin -------------------------------------------------------------------------------- /roms/phoenix.bin: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/phoenix.bin -------------------------------------------------------------------------------- /roms/pitfall.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/pitfall.bin -------------------------------------------------------------------------------- /roms/pong.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/pong.bin -------------------------------------------------------------------------------- /roms/pooyan.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/pooyan.bin -------------------------------------------------------------------------------- /roms/private_eye.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/private_eye.bin -------------------------------------------------------------------------------- /roms/qbert.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/qbert.bin -------------------------------------------------------------------------------- /roms/riverraid.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/riverraid.bin -------------------------------------------------------------------------------- /roms/road_runner.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/road_runner.bin -------------------------------------------------------------------------------- /roms/robotank.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/robotank.bin -------------------------------------------------------------------------------- /roms/seaquest.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/seaquest.bin -------------------------------------------------------------------------------- /roms/skiing.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/skiing.bin -------------------------------------------------------------------------------- /roms/solaris.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/solaris.bin -------------------------------------------------------------------------------- /roms/space_invaders.bin: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/space_invaders.bin -------------------------------------------------------------------------------- /roms/star_gunner.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/star_gunner.bin -------------------------------------------------------------------------------- /roms/tennis.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/tennis.bin -------------------------------------------------------------------------------- /roms/time_pilot.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/time_pilot.bin -------------------------------------------------------------------------------- /roms/tutankham.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/tutankham.bin -------------------------------------------------------------------------------- /roms/up_n_down.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/up_n_down.bin -------------------------------------------------------------------------------- /roms/venture.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/venture.bin -------------------------------------------------------------------------------- /roms/video_pinball.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/video_pinball.bin -------------------------------------------------------------------------------- /roms/wizard_of_wor.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/wizard_of_wor.bin -------------------------------------------------------------------------------- /roms/yars_revenge.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/yars_revenge.bin -------------------------------------------------------------------------------- /roms/zaxxon.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannah/bootstrap_dqn/036a6f638f9f68b07ae8a6a463ffad452cb3610f/roms/zaxxon.bin -------------------------------------------------------------------------------- /run_bootstrap.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import os 6 | import numpy 
as np 7 | from IPython import embed 8 | from collections import Counter 9 | import torch 10 | torch.set_num_threads(2) 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import datetime 15 | import time 16 | from dqn_model import EnsembleNet, NetWithPrior 17 | from dqn_utils import seed_everything, write_info_file, generate_gif, save_checkpoint 18 | from env import Environment 19 | from replay import ReplayMemory 20 | import config 21 | 22 | def rolling_average(a, n=5) : 23 | if n == 0: 24 | return a 25 | ret = np.cumsum(a, dtype=float) 26 | ret[n:] = ret[n:] - ret[:-n] 27 | return ret[n - 1:] / n 28 | 29 | def plot_dict_losses(plot_dict, name='loss_example.png', rolling_length=4, plot_title=''): 30 | f,ax=plt.subplots(1,1,figsize=(6,6)) 31 | for n in plot_dict.keys(): 32 | print('plotting', n) 33 | ax.plot(rolling_average(plot_dict[n]['index']), rolling_average(plot_dict[n]['val']), lw=1) 34 | ax.scatter(rolling_average(plot_dict[n]['index']), rolling_average(plot_dict[n]['val']), label=n, s=3) 35 | ax.legend() 36 | if plot_title != '': 37 | plt.title(plot_title) 38 | plt.savefig(name) 39 | plt.close() 40 | 41 | def matplotlib_plot_all(p): 42 | epoch_num = len(p['steps']) 43 | epochs = np.arange(epoch_num) 44 | steps = p['steps'] 45 | plot_dict_losses({'episode steps':{'index':epochs,'val':p['episode_step']}}, name=os.path.join(model_base_filedir, 'episode_step.png'), rolling_length=0) 46 | plot_dict_losses({'episode steps':{'index':epochs,'val':p['episode_relative_times']}}, name=os.path.join(model_base_filedir, 'episode_relative_times.png'), rolling_length=10) 47 | plot_dict_losses({'episode head':{'index':epochs, 'val':p['episode_head']}}, name=os.path.join(model_base_filedir, 'episode_head.png'), rolling_length=0) 48 | plot_dict_losses({'steps loss':{'index':steps, 'val':p['episode_loss']}}, name=os.path.join(model_base_filedir, 'steps_loss.png')) 49 | plot_dict_losses({'steps eps':{'index':steps, 'val':p['eps_list']}}, name=os.path.join(model_base_filedir, 'steps_mean_eps.png'), rolling_length=0) 50 | plot_dict_losses({'steps reward':{'index':steps,'val':p['episode_reward']}}, name=os.path.join(model_base_filedir, 'steps_reward.png'), rolling_length=0) 51 | plot_dict_losses({'episode reward':{'index':epochs, 'val':p['episode_reward']}}, name=os.path.join(model_base_filedir, 'episode_reward.png'), rolling_length=0) 52 | plot_dict_losses({'episode times':{'index':epochs,'val':p['episode_times']}}, name=os.path.join(model_base_filedir, 'episode_times.png'), rolling_length=5) 53 | plot_dict_losses({'steps avg reward':{'index':steps,'val':p['avg_rewards']}}, name=os.path.join(model_base_filedir, 'steps_avg_reward.png'), rolling_length=0) 54 | plot_dict_losses({'eval rewards':{'index':p['eval_steps'], 'val':p['eval_rewards']}}, name=os.path.join(model_base_filedir, 'eval_rewards_steps.png'), rolling_length=0) 55 | 56 | def handle_checkpoint(last_save, cnt): 57 | if (cnt-last_save) >= info['CHECKPOINT_EVERY_STEPS']: 58 | st = time.time() 59 | print("beginning checkpoint", st) 60 | last_save = cnt 61 | state = {'info':info, 62 | 'optimizer':opt.state_dict(), 63 | 'cnt':cnt, 64 | 'policy_net_state_dict':policy_net.state_dict(), 65 | 'target_net_state_dict':target_net.state_dict(), 66 | 'perf':perf, 67 | } 68 | filename = os.path.abspath(model_base_filepath + "_%010dq.pkl"%cnt) 69 | save_checkpoint(state, filename) 70 | # npz will be added 71 | buff_filename = os.path.abspath(model_base_filepath + "_%010dq_train_buffer"%cnt) 72 | 
replay_memory.save_buffer(buff_filename) 73 | print("finished checkpoint", time.time()-st) 74 | return last_save 75 | else: return last_save 76 | 77 | 78 | class ActionGetter: 79 | """Determines an action according to an epsilon greedy strategy with annealing epsilon""" 80 | """This class is from fg91's dqn. TODO put my function back in""" 81 | def __init__(self, n_actions, eps_initial=1, eps_final=0.1, eps_final_frame=0.01, 82 | eps_evaluation=0.0, eps_annealing_frames=100000, 83 | replay_memory_start_size=50000, max_steps=25000000, random_seed=122): 84 | """ 85 | Args: 86 | n_actions: Integer, number of possible actions 87 | eps_initial: Float, Exploration probability for the first 88 | replay_memory_start_size frames 89 | eps_final: Float, Exploration probability after 90 | replay_memory_start_size + eps_annealing_frames frames 91 | eps_final_frame: Float, Exploration probability after max_frames frames 92 | eps_evaluation: Float, Exploration probability during evaluation 93 | eps_annealing_frames: Int, Number of frames over which the 94 | exploration probabilty is annealed from eps_initial to eps_final 95 | replay_memory_start_size: Integer, Number of frames during 96 | which the agent only explores 97 | max_frames: Integer, Total number of frames shown to the agent 98 | """ 99 | self.n_actions = n_actions 100 | self.eps_initial = eps_initial 101 | self.eps_final = eps_final 102 | self.eps_final_frame = eps_final_frame 103 | self.eps_evaluation = eps_evaluation 104 | self.eps_annealing_frames = eps_annealing_frames 105 | self.replay_memory_start_size = replay_memory_start_size 106 | self.max_steps = max_steps 107 | self.random_state = np.random.RandomState(random_seed) 108 | 109 | # Slopes and intercepts for exploration decrease 110 | if self.eps_annealing_frames > 0: 111 | self.slope = -(self.eps_initial - self.eps_final)/self.eps_annealing_frames 112 | self.intercept = self.eps_initial - self.slope*self.replay_memory_start_size 113 | self.slope_2 = -(self.eps_final - self.eps_final_frame)/(self.max_steps - self.eps_annealing_frames - self.replay_memory_start_size) 114 | self.intercept_2 = self.eps_final_frame - self.slope_2*self.max_steps 115 | 116 | def pt_get_action(self, step_number, state, active_head=None, evaluation=False): 117 | """ 118 | Args: 119 | step_number: int number of the current step 120 | state: A (4, 84, 84) sequence of frames of an atari game in grayscale 121 | active_head: number of head to use, if None, will run all heads and vote 122 | evaluation: A boolean saying whether the agent is being evaluated 123 | Returns: 124 | An integer between 0 and n_actions 125 | """ 126 | if evaluation: 127 | eps = self.eps_evaluation 128 | elif step_number < self.replay_memory_start_size: 129 | eps = self.eps_initial 130 | elif self.eps_annealing_frames > 0: 131 | # TODO check this 132 | if step_number >= self.replay_memory_start_size and step_number < self.replay_memory_start_size + self.eps_annealing_frames: 133 | eps = self.slope*step_number + self.intercept 134 | elif step_number >= self.replay_memory_start_size + self.eps_annealing_frames: 135 | eps = self.slope_2*step_number + self.intercept_2 136 | else: 137 | eps = 0 138 | if self.random_state.rand() < eps: 139 | return eps, self.random_state.randint(0, self.n_actions) 140 | else: 141 | state = torch.Tensor(state.astype(np.float)/info['NORM_BY'])[None,:].to(info['DEVICE']) 142 | vals = policy_net(state, active_head) 143 | if active_head is not None: 144 | action = torch.argmax(vals, dim=1).item() 145 | return eps, 
action 146 | else: 147 | # vote 148 | acts = [torch.argmax(vals[h],dim=1).item() for h in range(info['N_ENSEMBLE'])] 149 | data = Counter(acts) 150 | action = data.most_common(1)[0][0] 151 | return eps, action 152 | 153 | def ptlearn(states, actions, rewards, next_states, terminal_flags, masks): 154 | states = torch.Tensor(states.astype(np.float)/info['NORM_BY']).to(info['DEVICE']) 155 | next_states = torch.Tensor(next_states.astype(np.float)/info['NORM_BY']).to(info['DEVICE']) 156 | rewards = torch.Tensor(rewards).to(info['DEVICE']) 157 | actions = torch.LongTensor(actions).to(info['DEVICE']) 158 | terminal_flags = torch.Tensor(terminal_flags.astype(np.int)).to(info['DEVICE']) 159 | masks = torch.FloatTensor(masks.astype(np.int)).to(info['DEVICE']) 160 | # min history to learn is 200,000 frames in dqn - 50000 steps 161 | losses = [0.0 for _ in range(info['N_ENSEMBLE'])] 162 | opt.zero_grad() 163 | q_policy_vals = policy_net(states, None) 164 | next_q_target_vals = target_net(next_states, None) 165 | next_q_policy_vals = policy_net(next_states, None) 166 | cnt_losses = [] 167 | for k in range(info['N_ENSEMBLE']): 168 | #TODO finish masking 169 | total_used = torch.sum(masks[:,k]) 170 | if total_used > 0.0: 171 | next_q_vals = next_q_target_vals[k].data 172 | if info['DOUBLE_DQN']: 173 | next_actions = next_q_policy_vals[k].data.max(1, True)[1] 174 | next_qs = next_q_vals.gather(1, next_actions).squeeze(1) 175 | else: 176 | next_qs = next_q_vals.max(1)[0] # max returns a pair 177 | 178 | preds = q_policy_vals[k].gather(1, actions[:,None]).squeeze(1) 179 | targets = rewards + info['GAMMA'] * next_qs * (1-terminal_flags) 180 | l1loss = F.smooth_l1_loss(preds, targets, reduction='mean') 181 | full_loss = masks[:,k]*l1loss 182 | loss = torch.sum(full_loss/total_used) 183 | cnt_losses.append(loss) 184 | losses[k] = loss.cpu().detach().item() 185 | 186 | loss = sum(cnt_losses)/info['N_ENSEMBLE'] 187 | loss.backward() 188 | for param in policy_net.core_net.parameters(): 189 | if param.grad is not None: 190 | # divide grads in core 191 | param.grad.data *=1.0/float(info['N_ENSEMBLE']) 192 | nn.utils.clip_grad_norm_(policy_net.parameters(), info['CLIP_GRAD']) 193 | opt.step() 194 | return np.mean(losses) 195 | 196 | def train(step_number, last_save): 197 | """Contains the training and evaluation loops""" 198 | epoch_num = len(perf['steps']) 199 | while step_number < info['MAX_STEPS']: 200 | ######################## 201 | ####### Training ####### 202 | ######################## 203 | epoch_frame = 0 204 | while epoch_frame < info['EVAL_FREQUENCY']: 205 | terminal = False 206 | life_lost = True 207 | state = env.reset() 208 | start_steps = step_number 209 | st = time.time() 210 | episode_reward_sum = 0 211 | random_state.shuffle(heads) 212 | active_head = heads[0] 213 | epoch_num += 1 214 | ep_eps_list = [] 215 | ptloss_list = [] 216 | while not terminal: 217 | if life_lost: 218 | action = 1 219 | eps = 0 220 | else: 221 | eps,action = action_getter.pt_get_action(step_number, state=state, active_head=active_head) 222 | ep_eps_list.append(eps) 223 | next_state, reward, life_lost, terminal = env.step(action) 224 | # Store transition in the replay memory 225 | replay_memory.add_experience(action=action, 226 | frame=next_state[-1], 227 | reward=np.sign(reward), # TODO -maybe there should be +1 here 228 | terminal=life_lost) 229 | 230 | step_number += 1 231 | epoch_frame += 1 232 | episode_reward_sum += reward 233 | state = next_state 234 | 235 | if step_number % info['LEARN_EVERY_STEPS'] == 0 and 
step_number > info['MIN_HISTORY_TO_LEARN']: 236 | _states, _actions, _rewards, _next_states, _terminal_flags, _masks = replay_memory.get_minibatch(info['BATCH_SIZE']) 237 | ptloss = ptlearn(_states, _actions, _rewards, _next_states, _terminal_flags, _masks) 238 | ptloss_list.append(ptloss) 239 | if step_number % info['TARGET_UPDATE'] == 0 and step_number > info['MIN_HISTORY_TO_LEARN']: 240 | print("++++++++++++++++++++++++++++++++++++++++++++++++") 241 | print('updating target network at %s'%step_number) 242 | target_net.load_state_dict(policy_net.state_dict()) 243 | 244 | et = time.time() 245 | ep_time = et-st 246 | perf['steps'].append(step_number) 247 | perf['episode_step'].append(step_number-start_steps) 248 | perf['episode_head'].append(active_head) 249 | perf['eps_list'].append(np.mean(ep_eps_list)) 250 | perf['episode_loss'].append(np.mean(ptloss_list)) 251 | perf['episode_reward'].append(episode_reward_sum) 252 | perf['episode_times'].append(ep_time) 253 | perf['episode_relative_times'].append(time.time()-info['START_TIME']) 254 | perf['avg_rewards'].append(np.mean(perf['episode_reward'][-100:])) 255 | last_save = handle_checkpoint(last_save, step_number) 256 | 257 | if not epoch_num%info['PLOT_EVERY_EPISODES'] and step_number > info['MIN_HISTORY_TO_LEARN']: 258 | # TODO plot title 259 | print('avg reward', perf['avg_rewards'][-1]) 260 | print('last rewards', perf['episode_reward'][-info['PLOT_EVERY_EPISODES']:]) 261 | 262 | matplotlib_plot_all(perf) 263 | with open('rewards.txt', 'a') as reward_file: 264 | print(len(perf['episode_reward']), step_number, perf['avg_rewards'][-1], file=reward_file) 265 | avg_eval_reward = evaluate(step_number) 266 | perf['eval_rewards'].append(avg_eval_reward) 267 | perf['eval_steps'].append(step_number) 268 | matplotlib_plot_all(perf) 269 | 270 | def evaluate(step_number): 271 | print(""" 272 | ######################### 273 | ####### Evaluation ###### 274 | ######################### 275 | """) 276 | eval_rewards = [] 277 | evaluate_step_number = 0 278 | frames_for_gif = [] 279 | results_for_eval = [] 280 | # only run one 281 | for i in range(info['NUM_EVAL_EPISODES']): 282 | state = env.reset() 283 | episode_reward_sum = 0 284 | terminal = False 285 | life_lost = True 286 | episode_steps = 0 287 | while not terminal: 288 | if life_lost: 289 | action = 1 290 | else: 291 | eps,action = action_getter.pt_get_action(step_number, state, active_head=None, evaluation=True) 292 | next_state, reward, life_lost, terminal = env.step(action) 293 | evaluate_step_number += 1 294 | episode_steps +=1 295 | episode_reward_sum += reward 296 | if not i: 297 | # only save first episode 298 | frames_for_gif.append(env.ale.getScreenRGB()) 299 | results_for_eval.append("%s, %s, %s, %s" %(action, reward, life_lost, terminal)) 300 | if not episode_steps%100: 301 | print('eval', episode_steps, episode_reward_sum) 302 | state = next_state 303 | eval_rewards.append(episode_reward_sum) 304 | 305 | print("Evaluation score:\n", np.mean(eval_rewards)) 306 | generate_gif(model_base_filedir, step_number, frames_for_gif, eval_rewards[0], name='test', results=results_for_eval) 307 | 308 | # Show the evaluation score in tensorboard 309 | efile = os.path.join(model_base_filedir, 'eval_rewards.txt') 310 | with open(efile, 'a') as eval_reward_file: 311 | print(step_number, np.mean(eval_rewards), file=eval_reward_file) 312 | return np.mean(eval_rewards) 313 | 314 | if __name__ == '__main__': 315 | from argparse import ArgumentParser 316 | parser = ArgumentParser() 317 | 
parser.add_argument('-c', '--cuda', action='store_true', default=False) 318 | parser.add_argument('-l', '--model_loadpath', default='', help='.pkl model file full path') 319 | parser.add_argument('-b', '--buffer_loadpath', default='', help='.npz replay buffer file full path') 320 | args = parser.parse_args() 321 | if args.cuda: 322 | device = 'cuda' 323 | else: 324 | device = 'cpu' 325 | print("running on %s"%device) 326 | 327 | info = { 328 | #"GAME":'roms/breakout.bin', # gym prefix 329 | "GAME":'roms/pong.bin', # gym prefix 330 | "DEVICE":device, #cpu vs gpu set by argument 331 | "NAME":'FRANKbootstrap_fasteranneal_pong', # start files with name 332 | "DUELING":True, # use dueling dqn 333 | "DOUBLE_DQN":True, # use double dqn 334 | "PRIOR":True, # turn on to use randomized prior 335 | "PRIOR_SCALE":10, # what to scale prior by 336 | "N_ENSEMBLE":9, # number of bootstrap heads to use. when 1, this is a normal dqn 337 | "LEARN_EVERY_STEPS":4, # updates every 4 steps in osband 338 | "BERNOULLI_PROBABILITY": 0.9, # Probability of experience to go to each head - if 1, every experience goes to every head 339 | "TARGET_UPDATE":10000, # how often to update target network 340 | "MIN_HISTORY_TO_LEARN":50000, # in environment frames 341 | "NORM_BY":255., # divide the float(of uint) by this number to normalize - max val of data is 255 342 | "EPS_INITIAL":1.0, # should be 1 343 | "EPS_FINAL":0.01, # 0.01 in osband 344 | "EPS_EVAL":0.0, # 0 in osband, .05 in others.... 345 | "EPS_ANNEALING_FRAMES":int(1e6), # this may have been 1e6 in osband 346 | #"EPS_ANNEALING_FRAMES":0, # if it annealing is zero, then it will only use the bootstrap after the first MIN_EXAMPLES_TO_LEARN steps which are random 347 | "EPS_FINAL_FRAME":0.01, 348 | "NUM_EVAL_EPISODES":1, # num examples to average in eval 349 | "BUFFER_SIZE":int(1e6), # Buffer size for experience replay 350 | "CHECKPOINT_EVERY_STEPS":500000, # how often to write pkl of model and npz of data buffer 351 | "EVAL_FREQUENCY":250000, # how often to run evaluation episodes 352 | "ADAM_LEARNING_RATE":6.25e-5, 353 | "RMS_LEARNING_RATE": 0.00025, # according to paper = 0.00025 354 | "RMS_DECAY":0.95, 355 | "RMS_MOMENTUM":0.0, 356 | "RMS_EPSILON":0.00001, 357 | "RMS_CENTERED":True, 358 | "HISTORY_SIZE":4, # how many past frames to use for state input 359 | "N_EPOCHS":90000, # Number of episodes to run 360 | "BATCH_SIZE":32, # Batch size to use for learning 361 | "GAMMA":.99, # Gamma weight in Q update 362 | "PLOT_EVERY_EPISODES": 50, 363 | "CLIP_GRAD":5, # Gradient clipping setting 364 | "SEED":101, 365 | "RANDOM_HEAD":-1, # just used in plotting as demarcation 366 | "NETWORK_INPUT_SIZE":(84,84), 367 | "START_TIME":time.time(), 368 | "MAX_STEPS":int(50e6), # 50e6 steps is 200e6 frames 369 | "MAX_EPISODE_STEPS":27000, # Orig dqn give 18k steps, Rainbow seems to give 27k steps 370 | "FRAME_SKIP":4, # deterministic frame skips to match deepmind 371 | "MAX_NO_OP_FRAMES":30, # random number of noops applied to beginning of each episode 372 | "DEAD_AS_END":True, # do you send finished=true to agent while training when it loses a life 373 | } 374 | 375 | info['FAKE_ACTS'] = [info['RANDOM_HEAD'] for x in range(info['N_ENSEMBLE'])] 376 | info['args'] = args 377 | info['load_time'] = datetime.date.today().ctime() 378 | info['NORM_BY'] = float(info['NORM_BY']) 379 | 380 | # create environment 381 | env = Environment(rom_file=info['GAME'], frame_skip=info['FRAME_SKIP'], 382 | num_frames=info['HISTORY_SIZE'], no_op_start=info['MAX_NO_OP_FRAMES'], rand_seed=info['SEED'], 383 | 
dead_as_end=info['DEAD_AS_END'], max_episode_steps=info['MAX_EPISODE_STEPS']) 384 | 385 | # create replay buffer 386 | replay_memory = ReplayMemory(size=info['BUFFER_SIZE'], 387 | frame_height=info['NETWORK_INPUT_SIZE'][0], 388 | frame_width=info['NETWORK_INPUT_SIZE'][1], 389 | agent_history_length=info['HISTORY_SIZE'], 390 | batch_size=info['BATCH_SIZE'], 391 | num_heads=info['N_ENSEMBLE'], 392 | bernoulli_probability=info['BERNOULLI_PROBABILITY']) 393 | 394 | random_state = np.random.RandomState(info["SEED"]) 395 | action_getter = ActionGetter(n_actions=env.num_actions, 396 | eps_initial=info['EPS_INITIAL'], 397 | eps_final=info['EPS_FINAL'], 398 | eps_final_frame=info['EPS_FINAL_FRAME'], 399 | eps_annealing_frames=info['EPS_ANNEALING_FRAMES'], 400 | eps_evaluation=info['EPS_EVAL'], 401 | replay_memory_start_size=info['MIN_HISTORY_TO_LEARN'], 402 | max_steps=info['MAX_STEPS']) 403 | 404 | if args.model_loadpath != '': 405 | # load data from loadpath - save model load for later. we need some of 406 | # these parameters to setup other things 407 | print('loading model from: %s' %args.model_loadpath) 408 | model_dict = torch.load(args.model_loadpath) 409 | info = model_dict['info'] 410 | info['DEVICE'] = device 411 | # set a new random seed 412 | info["SEED"] = model_dict['cnt'] 413 | model_base_filedir = os.path.split(args.model_loadpath)[0] 414 | start_step_number = start_last_save = model_dict['cnt'] 415 | info['loaded_from'] = args.model_loadpath 416 | perf = model_dict['perf'] 417 | start_step_number = perf['steps'][-1] 418 | else: 419 | # create new project 420 | perf = {'steps':[], 421 | 'avg_rewards':[], 422 | 'episode_step':[], 423 | 'episode_head':[], 424 | 'eps_list':[], 425 | 'episode_loss':[], 426 | 'episode_reward':[], 427 | 'episode_times':[], 428 | 'episode_relative_times':[], 429 | 'eval_rewards':[], 430 | 'eval_steps':[]} 431 | 432 | start_step_number = 0 433 | start_last_save = 0 434 | # make new directory for this run in the case that there is already a 435 | # project with this name 436 | run_num = 0 437 | model_base_filedir = os.path.join(config.model_savedir, info['NAME'] + '%02d'%run_num) 438 | while os.path.exists(model_base_filedir): 439 | run_num +=1 440 | model_base_filedir = os.path.join(config.model_savedir, info['NAME'] + '%02d'%run_num) 441 | os.makedirs(model_base_filedir) 442 | print("----------------------------------------------") 443 | print("starting NEW project: %s"%model_base_filedir) 444 | 445 | model_base_filepath = os.path.join(model_base_filedir, info['NAME']) 446 | write_info_file(info, model_base_filepath, start_step_number) 447 | heads = list(range(info['N_ENSEMBLE'])) 448 | seed_everything(info["SEED"]) 449 | 450 | policy_net = EnsembleNet(n_ensemble=info['N_ENSEMBLE'], 451 | n_actions=env.num_actions, 452 | network_output_size=info['NETWORK_INPUT_SIZE'][0], 453 | num_channels=info['HISTORY_SIZE'], dueling=info['DUELING']).to(info['DEVICE']) 454 | target_net = EnsembleNet(n_ensemble=info['N_ENSEMBLE'], 455 | n_actions=env.num_actions, 456 | network_output_size=info['NETWORK_INPUT_SIZE'][0], 457 | num_channels=info['HISTORY_SIZE'], dueling=info['DUELING']).to(info['DEVICE']) 458 | if info['PRIOR']: 459 | prior_net = EnsembleNet(n_ensemble=info['N_ENSEMBLE'], 460 | n_actions=env.num_actions, 461 | network_output_size=info['NETWORK_INPUT_SIZE'][0], 462 | num_channels=info['HISTORY_SIZE'], dueling=info['DUELING']).to(info['DEVICE']) 463 | 464 | print("using randomized prior") 465 | policy_net = NetWithPrior(policy_net, prior_net, 
info['PRIOR_SCALE'])
466 |         target_net = NetWithPrior(target_net, prior_net, info['PRIOR_SCALE'])
467 | 
468 |     target_net.load_state_dict(policy_net.state_dict())
469 |     # create optimizer
470 |     #opt = optim.RMSprop(policy_net.parameters(),
471 |     #                    lr=info["RMS_LEARNING_RATE"],
472 |     #                    momentum=info["RMS_MOMENTUM"],
473 |     #                    eps=info["RMS_EPSILON"],
474 |     #                    centered=info["RMS_CENTERED"],
475 |     #                    alpha=info["RMS_DECAY"])
476 |     opt = optim.Adam(policy_net.parameters(), lr=info['ADAM_LEARNING_RATE'])
477 | 
478 |     if args.model_loadpath != '':
479 |         # what about random states - they will be wrong now???
480 |         # TODO - what about target net update cnt
481 |         target_net.load_state_dict(model_dict['target_net_state_dict'])
482 |         policy_net.load_state_dict(model_dict['policy_net_state_dict'])
483 |         opt.load_state_dict(model_dict['optimizer'])
484 |         print("loaded model state_dicts")
485 |         if args.buffer_loadpath == '':
486 |             args.buffer_loadpath = args.model_loadpath.replace('.pkl', '_train_buffer.npz')
487 |             print("auto loading buffer from: %s" %args.buffer_loadpath)
488 |         try:
489 |             replay_memory.load_buffer(args.buffer_loadpath)
490 |         except Exception as e:
491 |             print(e)
492 |             print('not able to load from buffer: %s. exit() to continue with empty buffer' %args.buffer_loadpath)
493 | 
494 |     train(start_step_number, start_last_save)
495 | 
496 | 
--------------------------------------------------------------------------------
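The slope/intercept pairs in `ActionGetter` implement a piecewise-linear epsilon schedule: epsilon is held at `eps_initial` while the replay buffer warms up, annealed linearly to `eps_final` over `eps_annealing_frames` steps, then annealed more slowly toward `eps_final_frame` by `max_steps`. Below is a minimal standalone sketch of that schedule using the values from the `info` dict; the function name and default arguments are illustrative and not part of this repo.

```python
def epsilon_schedule(step, eps_initial=1.0, eps_final=0.01, eps_final_frame=0.01,
                     warmup=50000, anneal_frames=int(1e6), max_steps=int(50e6)):
    """Piecewise-linear epsilon: flat during warmup, fast anneal, then slow anneal."""
    if step < warmup:
        return eps_initial
    if step < warmup + anneal_frames:
        slope = -(eps_initial - eps_final) / anneal_frames
        intercept = eps_initial - slope * warmup
        return slope * step + intercept
    slope_2 = -(eps_final - eps_final_frame) / (max_steps - anneal_frames - warmup)
    intercept_2 = eps_final_frame - slope_2 * max_steps
    return max(eps_final_frame, slope_2 * step + intercept_2)

# epsilon_schedule(0) -> 1.0; epsilon_schedule(550000) -> ~0.505; epsilon_schedule(2000000) -> 0.01
```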
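`ptlearn` computes a smooth-L1 TD loss per bootstrap head and averages it only over the transitions whose Bernoulli mask assigned them to that head, so each head effectively trains on a different subset of the replay buffer. A stripped-down sketch of that per-head masking follows; the function and tensor names are placeholders, assuming a 0/1 float mask as produced by the replay buffer.

```python
import torch
import torch.nn.functional as F

def masked_head_loss(preds, targets, head_mask):
    """Smooth-L1 TD loss averaged only over transitions assigned to this head.

    preds, targets: (batch,) tensors; head_mask: (batch,) of 0./1. values drawn
    from Bernoulli(BERNOULLI_PROBABILITY) when each transition was stored.
    """
    per_sample = F.smooth_l1_loss(preds, targets, reduction='none')
    used = head_mask.sum()
    if used == 0:
        return preds.new_zeros(())           # this head saw nothing in the batch
    return (head_mask * per_sample).sum() / used

# example: a head that was assigned 3 of the 4 transitions in the batch
preds = torch.tensor([0.5, 1.0, -0.2, 0.3])
targets = torch.tensor([0.4, 0.7, 0.1, 0.3])
print(masked_head_loss(preds, targets, torch.tensor([1., 1., 0., 1.])))
```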
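With `PRIOR` turned on, each head's Q-values are the trainable network's output plus a frozen, randomly initialized prior network scaled by `PRIOR_SCALE`, following the randomized prior functions paper linked in the README; gradients only flow through the trainable part. The sketch below is a minimal illustration of that idea; the class name and single-input `forward` are hypothetical and do not necessarily match `NetWithPrior` in `dqn_model.py`, which also handles per-head selection.

```python
import torch
import torch.nn as nn

class RandomizedPriorQ(nn.Module):
    """Q(s) = trainable(s) + prior_scale * prior(s), with the prior frozen."""
    def __init__(self, trainable, prior, prior_scale=10.0):
        super().__init__()
        self.trainable = trainable
        self.prior = prior
        self.prior_scale = prior_scale
        for p in self.prior.parameters():
            p.requires_grad = False          # the prior is never updated

    def forward(self, x):
        with torch.no_grad():                # no gradients flow into the prior
            prior_q = self.prior(x)
        return self.trainable(x) + self.prior_scale * prior_q

# toy usage: a 4-dim state, 3 actions, linear value heads
q = RandomizedPriorQ(nn.Linear(4, 3), nn.Linear(4, 3), prior_scale=10.0)
print(q(torch.randn(2, 4)).shape)            # torch.Size([2, 3])
```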