├── .gitignore ├── README.md ├── core.py ├── dqn.py ├── env.py ├── figs ├── boxing.png └── breakout.png ├── logger.py ├── main.py ├── model.py ├── policy.py ├── roms ├── air_raid.bin ├── alien.bin ├── amidar.bin ├── assault.bin ├── asterix.bin ├── asteroids.bin ├── atlantis.bin ├── bank_heist.bin ├── battle_zone.bin ├── beam_rider.bin ├── berzerk.bin ├── bowling.bin ├── boxing.bin ├── breakout.bin ├── carnival.bin ├── centipede.bin ├── chopper_command.bin ├── crazy_climber.bin ├── defender.bin ├── demon_attack.bin ├── double_dunk.bin ├── elevator_action.bin ├── enduro.bin ├── fishing_derby.bin ├── freeway.bin ├── frostbite.bin ├── gopher.bin ├── gravitar.bin ├── hero.bin ├── ice_hockey.bin ├── jamesbond.bin ├── journey_escape.bin ├── kangaroo.bin ├── krull.bin ├── kung_fu_master.bin ├── montezuma_revenge.bin ├── ms_pacman.bin ├── name_this_game.bin ├── phoenix.bin ├── pitfall.bin ├── pong.bin ├── pooyan.bin ├── private_eye.bin ├── qbert.bin ├── riverraid.bin ├── road_runner.bin ├── robotank.bin ├── seaquest.bin ├── skiing.bin ├── solaris.bin ├── space_invaders.bin ├── star_gunner.bin ├── tennis.bin ├── time_pilot.bin ├── tutankham.bin ├── up_n_down.bin ├── venture.bin ├── video_pinball.bin ├── wizard_of_wor.bin ├── yars_revenge.bin └── zaxxon.bin ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | test_env/ 3 | experiments/ 4 | exps/ 5 | devs/ 6 | render_plot.py 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Rainbow 2 | 3 | This repo is a partial implementation of the [Rainbow](https://arxiv.org/pdf/1710.02298.pdf) 4 | agent published by researchers from DeepMind. 5 | The implementation is efficient and of high quality. It trains at a speed of 6 | 350 frames/s on a PC with a 3.5GHz CPU and GTX1080 GPU. 7 | 8 | Rainbow is a deep Q-learning based agent that combines a number of existing techniques, 9 | such as dueling DQN, distributional DQN, etc. This repo currently implements the 10 | following DQN variants: 11 | * [DQN](https://www.nature.com/articles/nature14236) 12 | * [Double DQN](https://arxiv.org/abs/1509.06461) 13 | * [Dueling DQN](https://arxiv.org/abs/1511.06581) 14 | * [Distributional DQN](https://arxiv.org/pdf/1707.06887.pdf) 15 | * [Noisy Net](https://arxiv.org/abs/1706.10295) 16 | 17 | and it will need the following extensions to become a full "Rainbow": 18 | * Multi-step learning 19 | * Prioritized replay 20 | 21 | ## Hyperparameters 22 | 23 | The hyperparameters in this repo follow the ones described in the 24 | [Rainbow](https://arxiv.org/pdf/1710.02298.pdf) 25 | paper as closely as possible. However, there may still be some differences due to 26 | misunderstanding. 27 | 28 | ## Performance 29 | 30 | A DQN agent often takes days to train. As a sanity check, we can 31 | train an agent to play a simple game, "boxing". The following is the learning curve 32 | of a dueling double DQN trained on boxing. 33 | 34 | ![](figs/boxing.png) 35 | 36 | The agent almost solves boxing after around 12M frames, which is a good sign 37 | that the implementation is working. 38 | 39 | To test the distributional DQN and Noisy Net, the agent is trained on "breakout", since 40 | distributional DQN performs significantly better than the other variants on this game, 41 | reaching scores above 400 rapidly while other DQN methods struggle to do so.
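The two runs above can be launched with commands along the following lines (the flags are defined in `main.py`; the exact combinations shown here are illustrative):

```
# dueling double DQN on boxing (sanity check)
python main.py --rom roms/boxing.bin --dueling --double_dqn

# distributional DQN with Noisy Net on breakout
python main.py --rom roms/breakout.bin --dist --noisy_net
```

The resulting breakout learning curve is shown below.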
42 | 43 | ![](figs/breakout.png) 44 | 45 | From the figure, we see that the agent reaches scores above 400 very rapidly and steadily. 46 | Note that the publicly reported numbers in papers are produced by training the agent for 47 | 200M frames, while here it is trained for only 50M frames due to the computational cost. 48 | 49 | The figures here are smoothed. 50 | 51 | ## Future Work 52 | 53 | We plan to implement multi-step learning and prioritized replay. Also, the current 54 | implementation uses a simple wrapper around the [Arcade Learning Environment](https://github.com/mgbellemare/Arcade-Learning-Environment). 55 | We may want to switch to OpenAI Gym for better visualization and video recording. 56 | On top of Rainbow, it will also be interesting to include other new techniques, 57 | such as [Distributional RL with Quantile Regression](https://arxiv.org/pdf/1710.10044.pdf). 58 | 59 | Contributions and bug reports are welcome! 60 | -------------------------------------------------------------------------------- /core.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | import utils 5 | from policy import GreedyEpsilonPolicy 6 | 7 | 8 | class Sample(object): 9 | def __init__(self, state, action, reward, next_state, end): 10 | utils.assert_eq(type(state), type(next_state)) 11 | 12 | self._state = (state * 255.0).astype(np.uint8) 13 | self._next_state = (next_state * 255.0).astype(np.uint8) 14 | self.action = action 15 | self.reward = reward 16 | self.end = end 17 | 18 | @property 19 | def state(self): 20 | return self._state.astype(np.float32) / 255.0 21 | 22 | @property 23 | def next_state(self): 24 | return self._next_state.astype(np.float32) / 255.0 25 | 26 | def __repr__(self): 27 | info = ('S(mean): %3.4f, A: %s, R: %s, NS(mean): %3.4f, End: %s' 28 | % (self.state.mean(), self.action, self.reward, 29 | self.next_state.mean(), self.end)) 30 | return info 31 | 32 | 33 | class ReplayMemory(object): 34 | def __init__(self, max_size): 35 | self.max_size = max_size 36 | self.samples = [] 37 | self.oldest_idx = 0 38 | 39 | def __len__(self): 40 | return len(self.samples) 41 | 42 | def _evict(self): 43 | """Simplest FIFO eviction scheme.""" 44 | to_evict = self.oldest_idx 45 | self.oldest_idx = (self.oldest_idx + 1) % self.max_size 46 | return to_evict 47 | 48 | def burn_in(self, env, agent, num_steps): 49 | policy = GreedyEpsilonPolicy(1, agent) # uniform policy 50 | i = 0 51 | while i < num_steps or not env.end: 52 | if env.end: 53 | state = env.reset() 54 | action = policy.get_action(None) 55 | next_state, reward = env.step(action) 56 | self.append(state, action, reward, next_state, env.end) 57 | state = next_state 58 | i += 1 59 | if i % 10000 == 0: 60 | print '%d frames burned in' % i 61 | print '%d frames burned into the memory.' % i 62 | 63 | def append(self, state, action, reward, next_state, end): 64 | assert len(self.samples) <= self.max_size 65 | new_sample = Sample(state, action, reward, next_state, end) 66 | if len(self.samples) == self.max_size: 67 | avail_slot = self._evict() 68 | self.samples[avail_slot] = new_sample 69 | else: 70 | self.samples.append(new_sample) 71 | 72 | def sample(self, batch_size): 73 | """Simplest uniform sampling (w/o replacement) to produce a batch.
74 | """ 75 | assert batch_size < len(self.samples), 'no enough samples to sample from' 76 | return random.sample(self.samples, batch_size) 77 | 78 | def clear(self): 79 | self.samples = [] 80 | self.oldest_idx = 0 81 | 82 | 83 | def samples_to_tensors(samples): 84 | num_samples = len(samples) 85 | 86 | states_shape = (num_samples, ) + samples[0].state.shape 87 | states = np.zeros(states_shape, dtype=np.float32) 88 | next_states = np.zeros(states_shape, dtype=np.float32) 89 | 90 | rewards = np.zeros(num_samples, dtype=np.float32) 91 | actions = np.zeros(num_samples, dtype=np.int64) 92 | non_ends = np.zeros(num_samples, dtype=np.float32) 93 | 94 | for i, s in enumerate(samples): 95 | states[i] = s.state 96 | next_states[i] = s.next_state 97 | rewards[i] = s.reward 98 | actions[i] = s.action 99 | non_ends[i] = 0.0 if s.end else 1.0 100 | 101 | states = torch.from_numpy(states).cuda() 102 | actions = torch.from_numpy(actions).cuda() 103 | rewards = torch.from_numpy(rewards).cuda() 104 | next_states = torch.from_numpy(next_states).cuda() 105 | non_ends = torch.from_numpy(non_ends).cuda() 106 | 107 | return states, actions, rewards, next_states, non_ends 108 | -------------------------------------------------------------------------------- /dqn.py: -------------------------------------------------------------------------------- 1 | """Main DQN agent.""" 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import utils 8 | 9 | 10 | class DQNAgent(object): 11 | def __init__(self, q_net, double_dqn, num_actions): 12 | self.online_q_net = q_net 13 | self.target_q_net = copy.deepcopy(q_net) 14 | self.double_dqn = double_dqn 15 | self.num_actions = num_actions 16 | 17 | def save_q_net(self, prefix): 18 | torch.save(self.online_q_net.state_dict(), prefix+'online_q_net.pth') 19 | 20 | def parameters(self): 21 | return self.online_q_net.parameters() 22 | 23 | def sync_target(self): 24 | self.target_q_net = copy.deepcopy(self.online_q_net) 25 | 26 | def target_q_values(self, states): 27 | q_vals = self.target_q_net(Variable(states, volatile=True)).data 28 | return q_vals 29 | 30 | def online_q_values(self, states): 31 | q_vals = self.online_q_net(Variable(states, volatile=True)).data 32 | return q_vals 33 | 34 | def compute_targets(self, rewards, next_states, non_ends, gamma): 35 | """Compute batch of targets for dqn 36 | 37 | params: 38 | rewards: Tensor [batch] 39 | next_states: Tensor [batch, channel, w, h] 40 | non_ends: Tensor [batch] 41 | gamma: float 42 | """ 43 | next_q_vals = self.target_q_values(next_states) 44 | 45 | if self.double_dqn: 46 | next_actions = self.online_q_values(next_states).max(1, True)[1] 47 | next_actions = utils.one_hot(next_actions, self.num_actions) 48 | next_qs = (next_q_vals * next_actions).sum(1) 49 | else: 50 | next_qs = next_q_vals.max(1)[0] # max returns a pair 51 | 52 | targets = rewards + gamma * next_qs * non_ends 53 | return targets 54 | 55 | def loss(self, states, actions, targets): 56 | """ 57 | params: 58 | states: Variable [batch, channel, w, h] 59 | actions: Variable [batch, num_actions] one hot encoding 60 | targets: Variable [batch] 61 | """ 62 | utils.assert_eq(actions.size(1), self.num_actions) 63 | 64 | qs = self.online_q_net(states) 65 | preds = (qs * actions).sum(1) 66 | err = nn.functional.smooth_l1_loss(preds, targets) 67 | return err 68 | 69 | 70 | class DistributionalDQNAgent(DQNAgent): 71 | def __init__(self, q_net, double_dqn, num_actions, num_atoms, vmin, vmax): 72 | 
super(DistributionalDQNAgent, self).__init__(q_net, double_dqn, num_actions) 73 | 74 | self.num_atoms = num_atoms 75 | self.vmin = float(vmin) 76 | self.vmax = float(vmax) 77 | 78 | self.delta_z = (self.vmax - self.vmin) / (num_atoms - 1) 79 | 80 | zpoints = np.linspace(vmin, vmax, num_atoms).astype(np.float32) 81 | self.zpoints = Variable(torch.from_numpy(zpoints).unsqueeze(0)).cuda() 82 | 83 | def _q_values(self, q_net, states): 84 | """internal function to compute q_value 85 | 86 | params: 87 | q_net: self.online_q_net or self.target_q_net 88 | states: Variable [batch, channel, w, h] 89 | """ 90 | probs = q_net(states) # [batch, num_actions, num_atoms] 91 | q_vals = (probs * self.zpoints).sum(2) 92 | return q_vals, probs 93 | 94 | def target_q_values(self, states): 95 | states = Variable(states, volatile=True) 96 | q_vals, _ = self._q_values(self.target_q_net, states) 97 | return q_vals.data 98 | 99 | def online_q_values(self, states): 100 | states = Variable(states, volatile=True) 101 | q_vals, _ = self._q_values(self.online_q_net, states) 102 | return q_vals.data 103 | 104 | def compute_targets(self, rewards, next_states, non_ends, gamma): 105 | """Compute batch of targets for distributional dqn 106 | 107 | params: 108 | rewards: Tensor [batch, 1] 109 | next_states: Tensor [batch, channel, w, h] 110 | non_ends: Tensor [batch, 1] 111 | gamma: float 112 | """ 113 | assert not self.double_dqn, 'not supported yet' 114 | 115 | # get next distribution 116 | next_states = Variable(next_states, volatile=True) 117 | # [batch, num_actions], [batch, num_actions, num_atoms] 118 | next_q_vals, next_probs = self._q_values(self.target_q_net, next_states) 119 | next_actions = next_q_vals.data.max(1, True)[1] # [batch, 1] 120 | next_actions = utils.one_hot(next_actions, self.num_actions).unsqueeze(2) 121 | next_greedy_probs = (next_actions * next_probs.data).sum(1) 122 | 123 | # transform the distribution 124 | rewards = rewards.unsqueeze(1) 125 | non_ends = non_ends.unsqueeze(1) 126 | proj_zpoints = rewards + gamma * non_ends * self.zpoints.data 127 | proj_zpoints.clamp_(self.vmin, self.vmax) 128 | 129 | # project onto shared support 130 | b = (proj_zpoints - self.vmin) / self.delta_z 131 | lower = b.floor() 132 | upper = b.ceil() 133 | # handle corner case where b is integer 134 | eq = (upper == lower).float() 135 | lower -= eq 136 | lt0 = (lower < 0).float() 137 | lower += lt0 138 | upper += lt0 139 | 140 | # note: it's faster to do the following on cpu 141 | ml = (next_greedy_probs * (upper - b)).cpu().numpy() 142 | mu = (next_greedy_probs * (b - lower)).cpu().numpy() 143 | 144 | lower = lower.cpu().numpy().astype(np.int32) 145 | upper = upper.cpu().numpy().astype(np.int32) 146 | 147 | batch_size = rewards.size(0) 148 | mass = np.zeros((batch_size, self.num_atoms), dtype=np.float32) 149 | brange = range(batch_size) 150 | for i in range(self.num_atoms): 151 | mass[brange, lower[brange, i]] += ml[brange, i] 152 | mass[brange, upper[brange, i]] += mu[brange, i] 153 | 154 | return torch.from_numpy(mass).cuda() 155 | 156 | def loss(self, states, actions, targets): 157 | """ 158 | params: 159 | states: Variable [batch, channel, w, h] 160 | actions: Variable [batch, num_actions] one hot encoding 161 | targets: Variable [batch, num_atoms] 162 | """ 163 | utils.assert_eq(actions.size(1), self.num_actions) 164 | 165 | actions = actions.unsqueeze(2) 166 | probs = self.online_q_net(states) # [batch, num_actions, num_atoms] 167 | probs = (probs * actions).sum(1) # [batch, num_atoms] 168 | xent = -(targets * 
torch.log(probs.clamp(min=utils.EPS))).sum(1) 169 | xent = xent.mean(0) 170 | return xent 171 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import deque 3 | import cv2 4 | import numpy as np 5 | from ale_python_interface import ALEInterface 6 | 7 | 8 | # glb_counter = 0 9 | 10 | 11 | def preprocess_frame(observ, output_size): 12 | gray = cv2.cvtColor(observ, cv2.COLOR_RGB2GRAY) 13 | output = cv2.resize(gray, (output_size, output_size)) 14 | output = output.astype(np.float32, copy=False) 15 | return output 16 | 17 | 18 | class Environment(object): 19 | def __init__(self, 20 | rom_file, 21 | frame_skip, 22 | num_frames, 23 | frame_size, 24 | no_op_start, 25 | rand_seed, 26 | dead_as_eoe): 27 | self.ale = self._init_ale(rand_seed, rom_file) 28 | # normally (160, 210) 29 | self.actions = self.ale.getMinimalActionSet() 30 | 31 | self.frame_skip = frame_skip 32 | self.num_frames = num_frames 33 | self.frame_size = frame_size 34 | self.no_op_start = no_op_start 35 | self.dead_as_eoe = dead_as_eoe 36 | 37 | self.clipped_reward = 0 38 | self.total_reward = 0 39 | screen_width, screen_height = self.ale.getScreenDims() 40 | self.prev_screen = np.zeros( 41 | (screen_height, screen_width, 3), dtype=np.float32) 42 | self.frame_queue = deque(maxlen=num_frames) 43 | self.end = True 44 | 45 | @staticmethod 46 | def _init_ale(rand_seed, rom_file): 47 | assert os.path.exists(rom_file), '%s does not exists.' 48 | ale = ALEInterface() 49 | ale.setInt('random_seed', rand_seed) 50 | ale.setBool('showinfo', False) 51 | ale.setInt('frame_skip', 1) 52 | ale.setFloat('repeat_action_probability', 0.0) 53 | ale.setBool('color_averaging', False) 54 | ale.loadROM(rom_file) 55 | return ale 56 | 57 | @property 58 | def num_actions(self): 59 | return len(self.actions) 60 | 61 | def _get_current_frame(self): 62 | # global glb_counter 63 | screen = self.ale.getScreenRGB() 64 | max_screen = np.maximum(self.prev_screen, screen) 65 | frame = preprocess_frame(max_screen, self.frame_size) 66 | frame /= 255.0 67 | # cv2.imwrite('test_env/%d.png' % glb_counter, cv2.resize(frame, (800, 800))) 68 | # glb_counter += 1 69 | # print 'glb_counter', glb_counter 70 | return frame 71 | 72 | def reset(self): 73 | for _ in range(self.num_frames - 1): 74 | self.frame_queue.append( 75 | np.zeros((self.frame_size, self.frame_size), dtype=np.float32)) 76 | 77 | self.ale.reset_game() 78 | self.clipped_reward = 0 79 | self.total_reward = 0 80 | self.prev_screen = np.zeros(self.prev_screen.shape, dtype=np.float32) 81 | 82 | n = np.random.randint(0, self.no_op_start) 83 | for i in range(n): 84 | if i == n - 1: 85 | self.prev_screen = self.ale.getScreenRGB() 86 | self.ale.act(0) 87 | 88 | self.frame_queue.append(self._get_current_frame()) 89 | assert not self.ale.game_over() 90 | self.end = False 91 | return np.array(self.frame_queue) 92 | 93 | def step(self, action_idx): 94 | """Perform action and return frame sequence and reward. 
95 | Return: 96 | state: [frames] of length num_frames, 0 if fewer is available 97 | reward: float 98 | """ 99 | assert not self.end 100 | reward = 0 101 | clipped_reward = 0 102 | old_lives = self.ale.lives() 103 | 104 | for _ in range(self.frame_skip): 105 | self.prev_screen = self.ale.getScreenRGB() 106 | r = self.ale.act(self.actions[action_idx]) 107 | reward += r 108 | clipped_reward += np.sign(r) 109 | dead = (self.ale.lives() < old_lives) 110 | if self.ale.game_over() or (self.dead_as_eoe and dead): 111 | self.end = True 112 | break 113 | 114 | self.frame_queue.append(self._get_current_frame()) 115 | self.total_reward += reward 116 | self.clipped_reward += clipped_reward 117 | return np.array(self.frame_queue), clipped_reward 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | env = Environment('roms/breakout.bin', 4, 4, 84, 30, 33, False) 123 | print 'starting with game over?', env.ale.game_over() 124 | 125 | state = env.reset() 126 | i = 0 127 | while not env.end: 128 | print i 129 | action = np.random.randint(0, env.num_actions) 130 | state, reward = env.step(action) 131 | if i % 100 == 0: 132 | for idx, f in enumerate(state): 133 | filename = 'test_env/f%d-%d.png' % (i, idx) 134 | cv2.imwrite(filename, cv2.resize(f, (800, 800))) 135 | i += 1 136 | print 'total_reward:', env.total_reward 137 | print 'clipped_reward', env.clipped_reward 138 | print 'total steps:', i 139 | -------------------------------------------------------------------------------- /figs/boxing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/figs/boxing.png -------------------------------------------------------------------------------- /figs/breakout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/figs/breakout.png -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # helper class to handle logging issue 2 | import numpy as np 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, output_name): 7 | self.log_file = open(output_name, 'w') 8 | self.infos = {} 9 | 10 | def append(self, key, val): 11 | vals = self.infos.setdefault(key, []) 12 | vals.append(val) 13 | 14 | def log(self, extra_msg=''): 15 | msgs = [extra_msg] 16 | for key, vals in self.infos.iteritems(): 17 | msgs.append('%s %.6f' % (key, np.mean(vals))) 18 | msg = '\n'.join(msgs) 19 | self.log_file.write(msg + '\n') 20 | self.log_file.flush() 21 | self.infos = {} 22 | return msg 23 | 24 | def write(self, msg): 25 | self.log_file.write(msg + '\n') 26 | self.log_file.flush() 27 | print msg 28 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Run Atari Environment with DQN.""" 2 | import os 3 | import argparse 4 | import torch 5 | import model 6 | import dqn 7 | import utils 8 | from env import Environment 9 | from policy import GreedyEpsilonPolicy, LinearDecayGreedyEpsilonPolicy 10 | from core import ReplayMemory 11 | from train import train, evaluate 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser(description='Run DQN on Atari') 16 | parser.add_argument('--rom', default='roms/breakout.bin', 17 | help='path to rom') 
18 | parser.add_argument('--seed', default=10001, type=int, 19 | help='Random seed') 20 | parser.add_argument('--q_net', default='', type=str, 21 | help='load pretrained q net') 22 | parser.add_argument('--gamma', default=0.99, type=float, 23 | help='discount factor') 24 | parser.add_argument('--num_iters', default=int(5e7), type=int) 25 | parser.add_argument('--replay_buffer_size', default=int(1e6), type=int) 26 | parser.add_argument('--frame_skip', default=4, type=int, 27 | help='num frames for repeated action') 28 | parser.add_argument('--num_frames', default=4, type=int, 29 | help='num stacked frames') 30 | parser.add_argument('--frame_size', default=84, type=int) 31 | parser.add_argument('--batch_size', default=32, type=int) 32 | parser.add_argument('--frames_per_update', default=4, type=int) 33 | parser.add_argument('--frames_per_sync', default=32000, type=int) 34 | 35 | # for using eps-greedy exploration 36 | parser.add_argument('--train_start_eps', default=1.0, type=float) 37 | parser.add_argument('--train_final_eps', default=0.01, type=float) 38 | parser.add_argument('--train_eps_num_steps', default=int(1e6), type=int) 39 | 40 | # for noisy net 41 | parser.add_argument('--noisy_net', action='store_true') 42 | parser.add_argument('--sigma0', default=0.4, type=float) 43 | 44 | parser.add_argument('--eval_eps', default=0.001, type=float) 45 | parser.add_argument('--frames_per_eval', default=int(5e5), type=int) 46 | parser.add_argument('--burn_in_frames', default=200000, type=int) 47 | parser.add_argument('--no_op_start', default=30, type=int) 48 | 49 | parser.add_argument('--dev', action='store_true') 50 | parser.add_argument('--output', default='exps/', type=str) 51 | parser.add_argument('--suffix', default='', type=str) 52 | parser.add_argument('--double_dqn', action='store_true') 53 | parser.add_argument('--dueling', action='store_true') 54 | parser.add_argument('--dist', action='store_true') 55 | parser.add_argument('--num_atoms', default=51, type=int) 56 | 57 | parser.add_argument('--net', default=None, type=str) 58 | 59 | args = parser.parse_args() 60 | if args.dev: 61 | args.burn_in_frames = 500 62 | args.frames_per_eval = 5000 63 | args.output = 'devs/' 64 | 65 | game_name = args.rom.split('/')[-1].split('.')[0] 66 | 67 | model_name = [] 68 | if args.noisy_net: 69 | model_name.append('noisy') 70 | 71 | if args.dist: 72 | model_name.append('dist') 73 | 74 | if args.dueling: 75 | model_name.append('dueling') 76 | else: 77 | model_name.append('basic') 78 | 79 | if args.double_dqn: 80 | model_name.append('ddqn') 81 | 82 | if args.suffix: 83 | model_name.append(args.suffix) 84 | 85 | model_name = '_'.join(model_name) 86 | args.output = os.path.join(args.output, game_name, model_name) 87 | utils.Config(vars(args)).dump(os.path.join(args.output, 'configs.txt')) 88 | return args 89 | 90 | 91 | if __name__ == '__main__': 92 | args = main() 93 | 94 | torch.backends.cudnn.benchmark = True 95 | utils.set_all_seeds(args.seed) 96 | 97 | train_env = Environment( 98 | args.rom, 99 | args.frame_skip, 100 | args.num_frames, 101 | args.frame_size, 102 | args.no_op_start + 1, 103 | utils.large_randint(), 104 | True) 105 | eval_env = Environment( 106 | args.rom, 107 | args.frame_skip, 108 | args.num_frames, 109 | args.frame_size, 110 | args.no_op_start + 1, 111 | utils.large_randint(), 112 | False) 113 | 114 | if args.dist: 115 | assert not args.dueling, 'not supported yet.'
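        # Note: the distributional head returns, for each action, a categorical
        # distribution over num_atoms support points; DistributionalDQNAgent
        # places that support uniformly on [vmin, vmax] = [-10, 10] below and
        # reads Q-values off as expectations over the support (see dqn.py).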
116 | q_net = model.build_distributional_basic_network( 117 | args.num_frames, 118 | args.frame_size, 119 | train_env.num_actions, 120 | args.num_atoms, 121 | args.noisy_net, 122 | args.sigma0, 123 | args.net) 124 | q_net.cuda() 125 | agent = dqn.DistributionalDQNAgent( 126 | q_net, args.double_dqn, train_env.num_actions, args.num_atoms, -10, 10) 127 | else: 128 | if args.dueling: 129 | q_net_builder = model.build_dueling_network 130 | else: 131 | q_net_builder = model.build_basic_network 132 | 133 | q_net = q_net_builder( 134 | args.num_frames, 135 | args.frame_size, 136 | train_env.num_actions, 137 | args.noisy_net, 138 | args.sigma0, 139 | args.net) 140 | 141 | q_net.cuda() 142 | agent = dqn.DQNAgent(q_net, args.double_dqn, train_env.num_actions) 143 | 144 | if args.noisy_net: 145 | train_policy = GreedyEpsilonPolicy(0, agent) 146 | else: 147 | train_policy = LinearDecayGreedyEpsilonPolicy( 148 | args.train_start_eps, 149 | args.train_final_eps, 150 | args.train_eps_num_steps, 151 | agent) 152 | 153 | eval_policy = GreedyEpsilonPolicy(args.eval_eps, agent) 154 | replay_memory = ReplayMemory(args.replay_buffer_size) 155 | replay_memory.burn_in(train_env, agent, args.burn_in_frames) 156 | 157 | evaluator = lambda logger: evaluate(eval_env, eval_policy, 10, logger) 158 | train(agent, 159 | train_env, 160 | train_policy, 161 | replay_memory, 162 | args.gamma, 163 | args.batch_size, 164 | args.num_iters, 165 | args.frames_per_update, 166 | args.frames_per_sync, 167 | args.frames_per_eval, 168 | evaluator, 169 | args.output) 170 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import utils 6 | 7 | 8 | class BasicNetwork(nn.Module): 9 | def __init__(self, conv, fc): 10 | super(BasicNetwork, self).__init__() 11 | self.conv = conv 12 | self.fc = fc 13 | 14 | def forward(self, x): 15 | assert x.data.max() <= 1.0 16 | batch = x.size(0) 17 | y = self.conv(x) 18 | y = y.view(batch, -1) 19 | y = self.fc(y) 20 | return y 21 | 22 | 23 | class DuelingNetwork(nn.Module): 24 | def __init__(self, conv, adv, val): 25 | super(DuelingNetwork, self).__init__() 26 | self.conv = conv 27 | self.adv = adv 28 | self.val = val 29 | 30 | def forward(self, x): 31 | assert x.data.max() <= 1.0 32 | batch = x.size(0) 33 | feat = self.conv(x) 34 | feat = feat.view(batch, -1) 35 | adv = self.adv(feat) 36 | val = self.val(feat) 37 | q = val - adv.mean(1, keepdim=True) + adv 38 | return q 39 | 40 | 41 | # TODO: DistributionalDuelingNetwork 42 | class DistributionalBasicNetwork(nn.Module): 43 | def __init__(self, conv, fc, num_actions, num_atoms): 44 | super(DistributionalBasicNetwork, self).__init__() 45 | self.conv = conv 46 | self.fc = fc 47 | self.num_actions = num_actions 48 | self.num_atoms = num_atoms 49 | 50 | def forward(self, x): 51 | batch = x.size(0) 52 | y = self.conv(x) 53 | y = y.view(batch, -1) 54 | y = self.fc(y) 55 | logits = y.view(batch, self.num_actions, self.num_atoms) 56 | probs = nn.functional.softmax(logits, 2) 57 | return probs 58 | 59 | 60 | class NoisyLinear(nn.Module): 61 | """Factorised Gaussian NoisyNet""" 62 | def __init__(self, in_features, out_features, sigma0): 63 | super(NoisyLinear, self).__init__() 64 | self.in_features = in_features 65 | self.out_features = out_features 66 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 67 | self.bias = 
nn.Parameter(torch.Tensor(out_features)) 68 | self.noisy_weight = nn.Parameter(torch.Tensor(out_features, in_features)) 69 | self.noisy_bias = nn.Parameter(torch.Tensor(out_features)) 70 | self.reset_parameters() 71 | 72 | self.noise_std = sigma0 / math.sqrt(self.in_features) 73 | self.in_noise = torch.FloatTensor(in_features).cuda() 74 | self.out_noise = torch.FloatTensor(out_features).cuda() 75 | self.noise = None 76 | self.sample_noise() 77 | 78 | def sample_noise(self): 79 | self.in_noise.normal_(0, self.noise_std) 80 | self.out_noise.normal_(0, self.noise_std) 81 | self.noise = torch.mm(self.out_noise.view(-1, 1), self.in_noise.view(1, -1)) 82 | 83 | def reset_parameters(self): 84 | stdv = 1. / math.sqrt(self.weight.size(1)) 85 | self.weight.data.uniform_(-stdv, stdv) 86 | self.noisy_weight.data.uniform_(-stdv, stdv) 87 | if self.bias is not None: 88 | self.bias.data.uniform_(-stdv, stdv) 89 | self.noisy_bias.data.uniform_(-stdv, stdv) 90 | 91 | def forward(self, x): 92 | normal_y = nn.functional.linear(x, self.weight, self.bias) 93 | if not x.volatile: 94 | # update the noise once per update 95 | self.sample_noise() 96 | 97 | noisy_weight = self.noisy_weight * Variable(self.noise) 98 | noisy_bias = self.noisy_bias * Variable(self.out_noise) 99 | noisy_y = nn.functional.linear(x, noisy_weight, noisy_bias) 100 | return noisy_y + normal_y 101 | 102 | def __repr__(self): 103 | return self.__class__.__name__ + '(' \ 104 | + 'in_features=' + str(self.in_features) \ 105 | + ', out_features=' + str(self.out_features) + ')' 106 | 107 | 108 | # --------------------------------------- 109 | 110 | 111 | def _build_default_conv(in_channels): 112 | conv = nn.Sequential( 113 | nn.Conv2d(in_channels, 32, 8, 4), 114 | nn.ReLU(), 115 | nn.Conv2d(32, 64, 4, 2), 116 | nn.ReLU(), 117 | nn.Conv2d(64, 64, 3, 1), 118 | nn.ReLU() 119 | ) 120 | return conv 121 | 122 | 123 | def _build_fc(dims): 124 | layers = [nn.Linear(dims[0], dims[1])] 125 | for i in range(1, len(dims) - 1): 126 | layers.append(nn.ReLU()) 127 | layers.append(nn.Linear(dims[i], dims[i+1])) 128 | 129 | fc = nn.Sequential(*layers) 130 | return fc 131 | 132 | 133 | def _build_noisy_fc(dims, sigma0): 134 | layers = [NoisyLinear(dims[0], dims[1], sigma0)] 135 | for i in range(1, len(dims) - 1): 136 | layers.append(nn.ReLU()) 137 | layers.append(NoisyLinear(dims[i], dims[i+1], sigma0)) 138 | 139 | fc = nn.Sequential(*layers) 140 | return fc 141 | 142 | 143 | def build_basic_network(in_channels, in_size, out_dim, noisy, sigma0, net_file): 144 | conv = _build_default_conv(in_channels) 145 | 146 | in_shape = (1, in_channels, in_size, in_size) 147 | fc_in = utils.count_output_size(in_shape, conv) 148 | fc_hid = 512 149 | dims = [fc_in, fc_hid, out_dim] 150 | if noisy: 151 | fc = _build_noisy_fc(dims, sigma0) 152 | else: 153 | fc = _build_fc(dims) 154 | 155 | net = BasicNetwork(conv, fc) 156 | utils.init_net(net, net_file) 157 | return net 158 | 159 | 160 | def build_dueling_network(in_channels, in_size, out_dim, noisy, sigma0, net_file): 161 | conv = _build_default_conv(in_channels) 162 | 163 | in_shape = (1, in_channels, in_size, in_size) 164 | fc_in = utils.count_output_size(in_shape, conv) 165 | fc_hid = 512 166 | adv_dims = [fc_in, fc_hid, out_dim] 167 | val_dims = [fc_in, fc_hid, 1] 168 | 169 | if noisy: 170 | adv = _build_noisy_fc(adv_dims, sigma0) 171 | val = _build_noisy_fc(val_dims, sigma0) 172 | else: 173 | adv = _build_fc(adv_dims) 174 | val = _build_fc(val_dims) 175 | 176 | net = DuelingNetwork(conv, adv, val) 177 | utils.init_net(net, 
net_file) 178 | return net 179 | 180 | 181 | def build_distributional_basic_network( 182 | in_channels, in_size, out_dim, num_atoms, noisy, sigma0, net_file): 183 | 184 | conv = _build_default_conv(in_channels) 185 | 186 | in_shape = (1, in_channels, in_size, in_size) 187 | fc_in = utils.count_output_size(in_shape, conv) 188 | fc_hid = 512 189 | fc_dims = [fc_in, fc_hid, out_dim * num_atoms] 190 | if noisy: 191 | fc = _build_noisy_fc(fc_dims, sigma0) 192 | else: 193 | fc = _build_fc(fc_dims) 194 | 195 | net = DistributionalBasicNetwork(conv, fc, out_dim, num_atoms) 196 | utils.init_net(net, net_file) 197 | return net 198 | 199 | 200 | if __name__ == '__main__': 201 | import copy 202 | from torch.autograd import Variable 203 | 204 | # qnet = build_basic_network(4, 84, 6, None) 205 | qnet = build_dueling_network(4, 84, 6, None) 206 | print qnet 207 | qnet_target = copy.deepcopy(qnet) 208 | 209 | for p in qnet.parameters(): 210 | print p.mean().data[0], p.std().data[0] 211 | fake_input = Variable(torch.FloatTensor(10, 4, 84, 84)) 212 | print qnet(fake_input).size() 213 | -------------------------------------------------------------------------------- /policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import utils 3 | import torch 4 | 5 | 6 | class GreedyEpsilonPolicy(object): 7 | def __init__(self, epsilon, q_agent): 8 | self.epsilon = float(epsilon) 9 | self.q_agent = q_agent 10 | 11 | def get_action(self, state): 12 | """Run Greedy-Epsilon for the given state. 13 | 14 | params: 15 | state: numpy-array [num_frames, w, h] 16 | 17 | return: 18 | action: int, in range [0, num_actions) 19 | """ 20 | if np.random.uniform() <= self.epsilon: 21 | action = np.random.randint(0, self.q_agent.num_actions) 22 | return action 23 | 24 | state = torch.from_numpy(state) 25 | state = state.unsqueeze(0).cuda() 26 | 27 | q_vals = self.q_agent.online_q_values(state) 28 | utils.assert_eq(q_vals.size(0), 1) 29 | q_vals = q_vals.view(-1) 30 | q_vals = q_vals.cpu().numpy() 31 | action = q_vals.argmax() 32 | return action 33 | 34 | def decay(self): 35 | return 36 | 37 | 38 | class LinearDecayGreedyEpsilonPolicy(GreedyEpsilonPolicy): 39 | """Policy with a parameter that decays linearly. 
40 | """ 41 | def __init__(self, start_eps, end_eps, num_steps, q_agent): 42 | super(LinearDecayGreedyEpsilonPolicy, self).__init__(start_eps, q_agent) 43 | self.num_steps = num_steps 44 | self.decay_rate = (start_eps - end_eps) / float(num_steps) 45 | 46 | def decay(self): 47 | if self.num_steps > 0: 48 | self.epsilon -= self.decay_rate 49 | self.num_steps -= 1 50 | 51 | 52 | # if __name__ == '__main__': 53 | # q_values = np.random.uniform(0, 1, (3,)) 54 | # target_actions = q_values.argmax() 55 | 56 | # greedy_policy = GreedyEpsilonPolicy(0) 57 | # actions = greedy_policy(q_values) 58 | # assert (actions == target_actions).all() 59 | 60 | # uniform_policy = GreedyEpsilonPolicy(1) 61 | # uni_actions = uniform_policy(q_values) 62 | # assert not (uni_actions == target_actions).all() 63 | 64 | # steps = 9 65 | # ldg_policy = LinearDecayGreedyEpsilonPolicy(1, 0.1, steps) 66 | # expect_eps = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.1] 67 | # actual_eps = [1.0] 68 | # for i in range(steps+1): 69 | # actions = ldg_policy(q_values) 70 | # actual_eps.append(ldg_policy.epsilon) 71 | # assert (np.abs((np.array(actual_eps) - np.array(expect_eps))) < 1e-5).all() 72 | -------------------------------------------------------------------------------- /roms/air_raid.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/air_raid.bin -------------------------------------------------------------------------------- /roms/alien.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/alien.bin -------------------------------------------------------------------------------- /roms/amidar.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/amidar.bin -------------------------------------------------------------------------------- /roms/assault.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/assault.bin -------------------------------------------------------------------------------- /roms/asterix.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/asterix.bin -------------------------------------------------------------------------------- /roms/asteroids.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/asteroids.bin -------------------------------------------------------------------------------- /roms/atlantis.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/atlantis.bin -------------------------------------------------------------------------------- /roms/bank_heist.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/bank_heist.bin 
-------------------------------------------------------------------------------- /roms/battle_zone.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/battle_zone.bin -------------------------------------------------------------------------------- /roms/beam_rider.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/beam_rider.bin -------------------------------------------------------------------------------- /roms/berzerk.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/berzerk.bin -------------------------------------------------------------------------------- /roms/bowling.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/bowling.bin -------------------------------------------------------------------------------- /roms/boxing.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/boxing.bin -------------------------------------------------------------------------------- /roms/breakout.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/breakout.bin -------------------------------------------------------------------------------- /roms/carnival.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/carnival.bin -------------------------------------------------------------------------------- /roms/centipede.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/centipede.bin -------------------------------------------------------------------------------- /roms/chopper_command.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/chopper_command.bin -------------------------------------------------------------------------------- /roms/crazy_climber.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/crazy_climber.bin -------------------------------------------------------------------------------- /roms/defender.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/defender.bin -------------------------------------------------------------------------------- /roms/demon_attack.bin: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/demon_attack.bin -------------------------------------------------------------------------------- /roms/double_dunk.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/double_dunk.bin -------------------------------------------------------------------------------- /roms/elevator_action.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/elevator_action.bin -------------------------------------------------------------------------------- /roms/enduro.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/enduro.bin -------------------------------------------------------------------------------- /roms/fishing_derby.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/fishing_derby.bin -------------------------------------------------------------------------------- /roms/freeway.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/freeway.bin -------------------------------------------------------------------------------- /roms/frostbite.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/frostbite.bin -------------------------------------------------------------------------------- /roms/gopher.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/gopher.bin -------------------------------------------------------------------------------- /roms/gravitar.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/gravitar.bin -------------------------------------------------------------------------------- /roms/hero.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/hero.bin -------------------------------------------------------------------------------- /roms/ice_hockey.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/ice_hockey.bin -------------------------------------------------------------------------------- /roms/jamesbond.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/jamesbond.bin -------------------------------------------------------------------------------- /roms/journey_escape.bin: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/journey_escape.bin -------------------------------------------------------------------------------- /roms/kangaroo.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/kangaroo.bin -------------------------------------------------------------------------------- /roms/krull.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/krull.bin -------------------------------------------------------------------------------- /roms/kung_fu_master.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/kung_fu_master.bin -------------------------------------------------------------------------------- /roms/montezuma_revenge.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/montezuma_revenge.bin -------------------------------------------------------------------------------- /roms/ms_pacman.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/ms_pacman.bin -------------------------------------------------------------------------------- /roms/name_this_game.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/name_this_game.bin -------------------------------------------------------------------------------- /roms/phoenix.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/phoenix.bin -------------------------------------------------------------------------------- /roms/pitfall.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/pitfall.bin -------------------------------------------------------------------------------- /roms/pong.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/pong.bin -------------------------------------------------------------------------------- /roms/pooyan.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/pooyan.bin -------------------------------------------------------------------------------- /roms/private_eye.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/private_eye.bin -------------------------------------------------------------------------------- 
/roms/qbert.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/qbert.bin -------------------------------------------------------------------------------- /roms/riverraid.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/riverraid.bin -------------------------------------------------------------------------------- /roms/road_runner.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/road_runner.bin -------------------------------------------------------------------------------- /roms/robotank.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/robotank.bin -------------------------------------------------------------------------------- /roms/seaquest.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/seaquest.bin -------------------------------------------------------------------------------- /roms/skiing.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/skiing.bin -------------------------------------------------------------------------------- /roms/solaris.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/solaris.bin -------------------------------------------------------------------------------- /roms/space_invaders.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/space_invaders.bin -------------------------------------------------------------------------------- /roms/star_gunner.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/star_gunner.bin -------------------------------------------------------------------------------- /roms/tennis.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/tennis.bin -------------------------------------------------------------------------------- /roms/time_pilot.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/time_pilot.bin -------------------------------------------------------------------------------- /roms/tutankham.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/tutankham.bin 
-------------------------------------------------------------------------------- /roms/up_n_down.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/up_n_down.bin -------------------------------------------------------------------------------- /roms/venture.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/venture.bin -------------------------------------------------------------------------------- /roms/video_pinball.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/video_pinball.bin -------------------------------------------------------------------------------- /roms/wizard_of_wor.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/wizard_of_wor.bin -------------------------------------------------------------------------------- /roms/yars_revenge.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/yars_revenge.bin -------------------------------------------------------------------------------- /roms/zaxxon.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hengyuan-hu/rainbow/192fcd105909ed9d86448e59c3afc92f7a1fd342/roms/zaxxon.bin -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import utils 8 | from core import samples_to_tensors 9 | from logger import Logger 10 | 11 | 12 | def update_agent(agent, replay_memory, gamma, optim, batch_size): 13 | samples = replay_memory.sample(batch_size) 14 | states, actions, rewards, next_states, non_ends = samples_to_tensors(samples) 15 | actions = utils.one_hot(actions.unsqueeze(1), agent.num_actions) 16 | targets = agent.compute_targets(rewards, next_states, non_ends, gamma) 17 | states = Variable(states) 18 | actions = Variable(actions) 19 | targets = Variable(targets) 20 | loss = agent.loss(states, actions, targets) 21 | loss.backward() 22 | optim.step() 23 | optim.zero_grad() 24 | return loss.data[0] 25 | 26 | 27 | def train(agent, 28 | env, 29 | policy, 30 | replay_memory, 31 | gamma, 32 | batch_size, 33 | num_iters, 34 | frames_per_update, 35 | frames_per_sync, 36 | frames_per_eval, 37 | evaluator, 38 | output_dir): 39 | 40 | logger = Logger(os.path.join(output_dir, 'train_log.txt')) 41 | optim = torch.optim.Adam(agent.parameters(), lr=6.25e-5, eps=1.5e-4) 42 | action_dist = np.zeros(env.num_actions) 43 | max_epsd_iters = 20000 44 | 45 | best_avg_rewards = 0 46 | num_epsd = 0 47 | epsd_iters = 0 48 | epsd_rewards = 0 49 | t = time.time() 50 | for i in xrange(num_iters): 51 | if env.end or epsd_iters > max_epsd_iters: 52 | num_epsd += 1 53 | if num_epsd % 10 == 0: 54 | fps = epsd_iters / (time.time() - t) 55 | logger.write('Episode: %d, Iter: %d, Fps: %.2f' 56 | % 
(num_epsd, i+1, fps)) 57 | logger.write('sum clipped rewards %d' % epsd_rewards) 58 | logger.log() 59 | epsd_iters = 0 60 | epsd_rewards = 0 61 | t = time.time() 62 | 63 | state = env.reset() 64 | 65 | action = policy.get_action(state) 66 | action_dist[action] += 1 67 | next_state, reward = env.step(action) 68 | replay_memory.append(state, action, reward, next_state, env.end) 69 | state = next_state 70 | epsd_iters += 1 71 | epsd_rewards += reward 72 | 73 | if (i+1) % frames_per_update == 0: 74 | loss = update_agent(agent, replay_memory, gamma, optim, batch_size) 75 | logger.append('loss', loss) 76 | policy.decay() 77 | 78 | if (i+1) % frames_per_sync == 0: 79 | logger.write('>>>syncing nets, i: %d' % (i+1)) 80 | agent.sync_target() 81 | 82 | if (i+1) % frames_per_eval == 0: 83 | logger.write('Train Action distribution:') 84 | for act, count in enumerate(action_dist): 85 | prob = float(count) / action_dist.sum() 86 | logger.write('\t action: %d, p: %.4f' % (act, prob)) 87 | action_dist = np.zeros(env.num_actions) 88 | 89 | avg_rewards = evaluator(logger) 90 | if avg_rewards > best_avg_rewards: 91 | prefix = os.path.join(output_dir, '') 92 | agent.save_q_net(prefix) 93 | best_avg_rewards = avg_rewards 94 | 95 | 96 | def evaluate(env, policy, num_epsd, logger): 97 | actions = np.zeros(env.num_actions) 98 | total_rewards = np.zeros(num_epsd) 99 | epsd_idx = 0 100 | epsd_iters = 0 101 | max_epsd_iters = 108000 102 | 103 | state = env.reset() 104 | while epsd_idx < num_epsd: 105 | action = policy.get_action(state) 106 | actions[action] += 1 107 | state, _ = env.step(action) 108 | epsd_iters += 1 109 | 110 | if env.end or epsd_iters >= max_epsd_iters: 111 | total_rewards[epsd_idx] = env.total_reward 112 | logger.write('>>>Eval: [%d/%d], rewards: %s' % 113 | (epsd_idx+1, num_epsd, total_rewards[epsd_idx])) 114 | 115 | if epsd_idx < num_epsd - 1: # leave last reset to next run 116 | state = env.reset() 117 | 118 | epsd_idx += 1 119 | epsd_iters = 0 120 | 121 | avg_rewards = total_rewards.mean() 122 | logger.write('>>>Eval: avg total rewards: %s' % avg_rewards) 123 | logger.write('>>>Eval: actions dist:') 124 | probs = list(actions/actions.sum()) 125 | for action, prob in enumerate(probs): 126 | logger.write('\t action: %d, p: %.4f' % (action, prob)) 127 | 128 | return avg_rewards 129 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Common functions you may find useful in your implementation.""" 2 | import os 3 | import json 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import numpy as np 9 | 10 | 11 | EPS = 1e-7 12 | 13 | 14 | def assert_eq(real, expected): 15 | assert real == expected, '%s (true) vs %s (expected)' % (real, expected) 16 | 17 | 18 | def assert_zero_grads(params): 19 | for p in params: 20 | if p.grad is not None: 21 | assert_eq(p.grad.data.sum(), 0) 22 | 23 | 24 | def assert_frozen(module): 25 | for p in module.parameters(): 26 | assert not p.requires_grad 27 | 28 | 29 | def weights_init(m): 30 | """custom weights initialization""" 31 | classtype = m.__class__ 32 | if classtype == nn.Linear or classtype == nn.Conv2d: 33 | m.weight.data.normal_(0.0, 0.02) 34 | elif classtype == nn.BatchNorm2d: 35 | m.weight.data.normal_(1.0, 0.02) 36 | m.bias.data.fill_(0) 37 | else: 38 | print '%s is not initialized.' 
% classtype 39 | 40 | 41 | def init_net(net, net_file): 42 | if net_file: 43 | net.load_state_dict(torch.load(net_file)) 44 | else: 45 | net.apply(weights_init) 46 | 47 | 48 | def count_output_size(input_shape, module): 49 | fake_input = Variable(torch.FloatTensor(*input_shape), volatile=True) 50 | output_size = module.forward(fake_input).view(-1).size()[0] 51 | return output_size 52 | 53 | 54 | def one_hot(x, n): 55 | assert x.dim() == 2 56 | one_hot_x = torch.zeros(x.size(0), n).cuda() 57 | one_hot_x.scatter_(1, x, 1) 58 | return one_hot_x 59 | 60 | 61 | def large_randint(): 62 | return random.randint(int(1e5), int(1e6)) 63 | 64 | 65 | def set_all_seeds(rand_seed): 66 | random.seed(rand_seed) 67 | np.random.seed(large_randint()) 68 | torch.manual_seed(large_randint()) 69 | torch.cuda.manual_seed(large_randint()) 70 | 71 | 72 | class Config(object): 73 | def __init__(self, attrs): 74 | self.__dict__.update(attrs) 75 | 76 | @classmethod 77 | def load(cls, filename): 78 | with open(filename, 'r') as f: 79 | attrs = json.load(f) 80 | return cls(attrs) 81 | 82 | def dump(self, filename): 83 | dirname = os.path.dirname(filename) 84 | if not os.path.exists(dirname): 85 | os.makedirs(dirname) 86 | print 'Results will be stored in:', dirname 87 | 88 | with open(filename, 'w') as f: 89 | json.dump(vars(self), f, sort_keys=True, indent=2) 90 | f.write('\n') 91 | 92 | def __repr__(self): 93 | return json.dumps(vars(self), sort_keys=True, indent=2) 94 | --------------------------------------------------------------------------------
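As a closing illustration of how the modules above fit together outside of `main.py`, here is a minimal evaluation-only sketch for a saved checkpoint. It is not part of the repo: the script name, checkpoint path, and hyperparameter values are assumptions, and it needs the same CUDA and ALE setup as the rest of the code.

```python
# eval_saved.py (hypothetical): evaluate a saved online_q_net.pth checkpoint
# with a near-greedy policy, reusing the evaluation loop from train.py.
import dqn
import model
import utils
from env import Environment
from policy import GreedyEpsilonPolicy
from train import evaluate
from logger import Logger

utils.set_all_seeds(10001)
# Environment(rom, frame_skip, num_frames, frame_size, no_op_start, seed, dead_as_eoe)
env = Environment('roms/breakout.bin', 4, 4, 84, 31, utils.large_randint(), False)
# basic (non-dueling, non-distributional) net; the last argument is the
# weight file produced by DQNAgent.save_q_net during training (path assumed)
q_net = model.build_basic_network(
    4, 84, env.num_actions, False, 0.4, 'exps/breakout/basic/online_q_net.pth')
q_net.cuda()
agent = dqn.DQNAgent(q_net, False, env.num_actions)
policy = GreedyEpsilonPolicy(0.001, agent)  # eps = 0.001, as in --eval_eps
logger = Logger('eval_log.txt')
evaluate(env, policy, 10, logger)  # 10 evaluation episodes
```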