├── README.md
├── dqn
│   ├── LICENSE
│   ├── NeuralQLearner.lua
│   ├── Rectifier.lua
│   ├── Scale.lua
│   ├── TransitionTable.lua
│   ├── convnet.lua
│   ├── convnet_atari3.lua
│   ├── initenv.lua
│   ├── net_downsample_2x_full_y.lua
│   ├── nnutils.lua
│   └── train_agent.lua
├── install_dependencies.sh
├── roms
│   └── README
├── run_cpu
└── run_gpu

/README.md:
--------------------------------------------------------------------------------
1 | # DQN 3.0
2 |
3 | This project contains the source code of DQN 3.0, a Lua-based deep reinforcement
4 | learning architecture, necessary to reproduce the experiments
5 | described in the paper ["Human-level control through deep reinforcement
6 | learning", Nature 518, 529–533 (26 February 2015)
7 | doi:10.1038/nature14236](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html).
8 |
9 | To replicate the experiment results, a number of dependencies need to be
10 | installed, namely:
11 |
12 | - LuaJIT and Torch 7.0
13 | - nngraph
14 | - Xitari (a fork of the Arcade Learning Environment (Bellemare et al., 2013))
15 | - AleWrap (a Lua interface to Xitari)
16 |
17 | An install script for these dependencies is provided.
18 |
19 | Two run scripts are provided: run_cpu and run_gpu. As the names imply,
20 | the former trains the DQN network using regular CPUs, while the latter uses
21 | GPUs (CUDA), which typically results in a significant speed-up.
22 |
23 | # Installation instructions
24 |
25 | The installation requires Linux with apt-get.
26 |
27 | Note: In order to run the GPU version of DQN, you should additionally have the
28 | NVIDIA® CUDA® (version 5.5 or later) toolkit installed prior to the Torch
29 | installation below.
30 | This can be downloaded from https://developer.nvidia.com/cuda-toolkit
31 | and installation instructions can be found at
32 | http://docs.nvidia.com/cuda/cuda-getting-started-guide-for-linux
33 |
34 |
35 | To train DQN on Atari games, the following components must be installed:
36 |
37 | - LuaJIT and Torch 7.0
38 | - nngraph
39 | - Xitari
40 | - AleWrap
41 |
42 | To install all of the above in a subdirectory called 'torch', it should be enough to run
43 |
44 |     ./install_dependencies.sh
45 |
46 | from the base directory of the package.
47 |
48 |
49 | Note: The above install script will install the following packages via apt-get:
50 | build-essential, gcc, g++, cmake, curl, libreadline-dev, git-core, libjpeg-dev,
51 | libpng-dev, ncurses-dev, imagemagick, unzip
52 |
53 | # Training DQN on Atari games
54 |
55 | Prior to running DQN on a game, you should copy its ROM into the 'roms' subdirectory.
56 | It should then be sufficient to run the script
57 |
58 |     ./run_cpu <game name>
59 |
60 | Or, if GPU support is enabled,
61 |
62 |     ./run_gpu <game name>
63 |
64 | Note: On a system with more than one GPU, DQN training can be launched on a
65 | specified GPU by setting the environment variable GPU_ID, e.g. by
66 |
67 |     GPU_ID=2 ./run_gpu <game name>
68 |
69 | If GPU_ID is not specified, the first available GPU (ID 0) will be used by default.
70 |
71 | # Options
72 |
73 | Options to DQN are set within run_cpu (respectively, run_gpu). You may,
74 | for example, want to change the frequency at which information is output
75 | to stdout by setting 'prog_freq' to a different value.
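
As a concrete example of how these options reach the agent, the long
agent_params string built in run_cpu is parsed into a Lua table by
str_to_table in dqn/initenv.lua, so every "key=value" pair becomes a field of
the agent's argument table. A minimal sketch of that parsing, using a
shortened option string:

    -- mirrors str_to_table in dqn/initenv.lua
    local str = 'lr=0.00025,minibatch_size=32,network="convnet_atari3"'
    loadstring('tt = {' .. str .. '}')()
    print(tt.lr, tt.minibatch_size, tt.network)
    --> 0.00025  32  convnet_atari3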
76 | -------------------------------------------------------------------------------- /dqn/LICENSE: -------------------------------------------------------------------------------- 1 | LIMITED LICENSE: 2 | 3 | Copyright (c) 2016 DeepMind Technologies Limited 4 | Limited License: Under no circumstance is commercial use, reproduction, or 5 | distribution permitted. Use, reproduction, and distribution are permitted 6 | solely for academic use in evaluating and reviewing claims made in 7 | "Human-level control through deep reinforcement learning", Nature 518, 529–533 8 | (26 February 2015) doi:10.1038/nature14236, provided that the following 9 | conditions are met: 10 | 11 | * Any reproduction or distribution of source code must retain the above 12 | copyright notice and the full text of this license including the Disclaimer, 13 | below.
 14 | 15 | * Any reproduction or distribution in binary form must reproduce the above 16 | copyright notice and the full text of this license including the Disclaimer 17 | below
in the documentation and/or other materials provided with the 18 | Distribution. 19 | 20 | * Any publication that discloses findings arising from using this source code 21 | must cite “Human-level control through deep reinforcement learning", Nature 22 | 518, 529–533 (26 February 2015) doi:10.1038/nature14236. 23 | 24 | DISCLAIMER 25 | 26 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 27 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 28 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 29 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 30 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 31 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 32 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 33 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 34 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 35 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 36 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 37 | -------------------------------------------------------------------------------- /dqn/NeuralQLearner.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | if not dqn then 8 | require 'initenv' 9 | end 10 | 11 | local nql = torch.class('dqn.NeuralQLearner') 12 | 13 | 14 | function nql:__init(args) 15 | self.state_dim = args.state_dim -- State dimensionality. 16 | self.actions = args.actions 17 | self.n_actions = #self.actions 18 | self.verbose = args.verbose 19 | self.best = args.best 20 | 21 | --- epsilon annealing 22 | self.ep_start = args.ep or 1 23 | self.ep = self.ep_start -- Exploration probability. 24 | self.ep_end = args.ep_end or self.ep 25 | self.ep_endt = args.ep_endt or 1000000 26 | 27 | ---- learning rate annealing 28 | self.lr_start = args.lr or 0.01 --Learning rate. 29 | self.lr = self.lr_start 30 | self.lr_end = args.lr_end or self.lr 31 | self.lr_endt = args.lr_endt or 1000000 32 | self.wc = args.wc or 0 -- L2 weight cost. 33 | self.minibatch_size = args.minibatch_size or 1 34 | self.valid_size = args.valid_size or 500 35 | 36 | --- Q-learning parameters 37 | self.discount = args.discount or 0.99 --Discount factor. 38 | self.update_freq = args.update_freq or 1 39 | -- Number of points to replay per learning step. 40 | self.n_replay = args.n_replay or 1 41 | -- Number of steps after which learning starts. 42 | self.learn_start = args.learn_start or 0 43 | -- Size of the transition table. 44 | self.replay_memory = args.replay_memory or 1000000 45 | self.hist_len = args.hist_len or 1 46 | self.rescale_r = args.rescale_r 47 | self.max_reward = args.max_reward 48 | self.min_reward = args.min_reward 49 | self.clip_delta = args.clip_delta 50 | self.target_q = args.target_q 51 | self.bestq = 0 52 | 53 | self.gpu = args.gpu 54 | 55 | self.ncols = args.ncols or 1 -- number of color channels in input 56 | self.input_dims = args.input_dims or {self.hist_len*self.ncols, 84, 84} 57 | self.preproc = args.preproc -- name of preprocessing network 58 | self.histType = args.histType or "linear" -- history type to use 59 | self.histSpacing = args.histSpacing or 1 60 | self.nonTermProb = args.nonTermProb or 1 61 | self.bufferSize = args.bufferSize or 512 62 | 63 | self.transition_params = args.transition_params or {} 64 | 65 | self.network = args.network or self:createNetwork() 66 | 67 | -- check whether there is a network file 68 | local network_function 69 | if not (type(self.network) == 'string') then 70 | error("The type of the network provided in NeuralQLearner" .. 71 | " is not a string!") 72 | end 73 | 74 | local msg, err = pcall(require, self.network) 75 | if not msg then 76 | -- try to load saved agent 77 | local err_msg, exp = pcall(torch.load, self.network) 78 | if not err_msg then 79 | error("Could not find network file ") 80 | end 81 | if self.best and exp.best_model then 82 | self.network = exp.best_model 83 | else 84 | self.network = exp.model 85 | end 86 | else 87 | print('Creating Agent Network from ' .. self.network) 88 | self.network = err 89 | self.network = self:network() 90 | end 91 | 92 | if self.gpu and self.gpu >= 0 then 93 | self.network:cuda() 94 | else 95 | self.network:float() 96 | end 97 | 98 | -- Load preprocessing network. 
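    -- A rough sketch of what the default preprocessing net does, assuming the
    -- run scripts' settings (preproc = "net_downsample_2x_full_y", which builds
    -- an nn.Scale(84, 84), and state_dim = 7056 = 84*84):
    --   local frame = image.rgb2y(rawstate)             -- RGB screen -> luminance
    --   frame = image.scale(frame, 84, 84, 'bilinear')  -- downsample to 84x84
    --   local state = frame:reshape(7056)               -- flattened by nql:preprocess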
99 | if not (type(self.preproc == 'string')) then 100 | error('The preprocessing is not a string') 101 | end 102 | msg, err = pcall(require, self.preproc) 103 | if not msg then 104 | error("Error loading preprocessing net") 105 | end 106 | self.preproc = err 107 | self.preproc = self:preproc() 108 | self.preproc:float() 109 | 110 | if self.gpu and self.gpu >= 0 then 111 | self.network:cuda() 112 | self.tensor_type = torch.CudaTensor 113 | else 114 | self.network:float() 115 | self.tensor_type = torch.FloatTensor 116 | end 117 | 118 | -- Create transition table. 119 | ---- assuming the transition table always gets floating point input 120 | ---- (Foat or Cuda tensors) and always returns one of the two, as required 121 | ---- internally it always uses ByteTensors for states, scaling and 122 | ---- converting accordingly 123 | local transition_args = { 124 | stateDim = self.state_dim, numActions = self.n_actions, 125 | histLen = self.hist_len, gpu = self.gpu, 126 | maxSize = self.replay_memory, histType = self.histType, 127 | histSpacing = self.histSpacing, nonTermProb = self.nonTermProb, 128 | bufferSize = self.bufferSize 129 | } 130 | 131 | self.transitions = dqn.TransitionTable(transition_args) 132 | 133 | self.numSteps = 0 -- Number of perceived states. 134 | self.lastState = nil 135 | self.lastAction = nil 136 | self.v_avg = 0 -- V running average. 137 | self.tderr_avg = 0 -- TD error running average. 138 | 139 | self.q_max = 1 140 | self.r_max = 1 141 | 142 | self.w, self.dw = self.network:getParameters() 143 | self.dw:zero() 144 | 145 | self.deltas = self.dw:clone():fill(0) 146 | 147 | self.tmp= self.dw:clone():fill(0) 148 | self.g = self.dw:clone():fill(0) 149 | self.g2 = self.dw:clone():fill(0) 150 | 151 | if self.target_q then 152 | self.target_network = self.network:clone() 153 | end 154 | end 155 | 156 | 157 | function nql:reset(state) 158 | if not state then 159 | return 160 | end 161 | self.best_network = state.best_network 162 | self.network = state.model 163 | self.w, self.dw = self.network:getParameters() 164 | self.dw:zero() 165 | self.numSteps = 0 166 | print("RESET STATE SUCCESFULLY") 167 | end 168 | 169 | 170 | function nql:preprocess(rawstate) 171 | if self.preproc then 172 | return self.preproc:forward(rawstate:float()) 173 | :clone():reshape(self.state_dim) 174 | end 175 | 176 | return rawstate 177 | end 178 | 179 | 180 | function nql:getQUpdate(args) 181 | local s, a, r, s2, term, delta 182 | local q, q2, q2_max 183 | 184 | s = args.s 185 | a = args.a 186 | r = args.r 187 | s2 = args.s2 188 | term = args.term 189 | 190 | -- The order of calls to forward is a bit odd in order 191 | -- to avoid unnecessary calls (we only need 2). 192 | 193 | -- delta = r + (1-terminal) * gamma * max_a Q(s2, a) - Q(s, a) 194 | term = term:clone():float():mul(-1):add(1) 195 | 196 | local target_q_net 197 | if self.target_q then 198 | target_q_net = self.target_network 199 | else 200 | target_q_net = self.network 201 | end 202 | 203 | -- Compute max_a Q(s_2, a). 
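    -- Worked example with the run scripts' settings (discount = 0.99,
    -- clip_delta = 1) and reward rescaling ignored: for a transition with
    -- r = 1, term = 0, max_a Q_target(s2, a) = 2.0 and Q(s, a) = 1.5,
    --   delta = 1 + 0.99 * 2.0 - 1.5 = 1.48, clipped to 1.0,
    -- and only the chosen action's entry of `targets` receives that value.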
204 | q2_max = target_q_net:forward(s2):float():max(2) 205 | 206 | -- Compute q2 = (1-terminal) * gamma * max_a Q(s2, a) 207 | q2 = q2_max:clone():mul(self.discount):cmul(term) 208 | 209 | delta = r:clone():float() 210 | 211 | if self.rescale_r then 212 | delta:div(self.r_max) 213 | end 214 | delta:add(q2) 215 | 216 | -- q = Q(s,a) 217 | local q_all = self.network:forward(s):float() 218 | q = torch.FloatTensor(q_all:size(1)) 219 | for i=1,q_all:size(1) do 220 | q[i] = q_all[i][a[i]] 221 | end 222 | delta:add(-1, q) 223 | 224 | if self.clip_delta then 225 | delta[delta:ge(self.clip_delta)] = self.clip_delta 226 | delta[delta:le(-self.clip_delta)] = -self.clip_delta 227 | end 228 | 229 | local targets = torch.zeros(self.minibatch_size, self.n_actions):float() 230 | for i=1,math.min(self.minibatch_size,a:size(1)) do 231 | targets[i][a[i]] = delta[i] 232 | end 233 | 234 | if self.gpu >= 0 then targets = targets:cuda() end 235 | 236 | return targets, delta, q2_max 237 | end 238 | 239 | 240 | function nql:qLearnMinibatch() 241 | -- Perform a minibatch Q-learning update: 242 | -- w += alpha * (r + gamma max Q(s2,a2) - Q(s,a)) * dQ(s,a)/dw 243 | assert(self.transitions:size() > self.minibatch_size) 244 | 245 | local s, a, r, s2, term = self.transitions:sample(self.minibatch_size) 246 | 247 | local targets, delta, q2_max = self:getQUpdate{s=s, a=a, r=r, s2=s2, 248 | term=term, update_qmax=true} 249 | 250 | -- zero gradients of parameters 251 | self.dw:zero() 252 | 253 | -- get new gradient 254 | self.network:backward(s, targets) 255 | 256 | -- add weight cost to gradient 257 | self.dw:add(-self.wc, self.w) 258 | 259 | -- compute linearly annealed learning rate 260 | local t = math.max(0, self.numSteps - self.learn_start) 261 | self.lr = (self.lr_start - self.lr_end) * (self.lr_endt - t)/self.lr_endt + 262 | self.lr_end 263 | self.lr = math.max(self.lr, self.lr_end) 264 | 265 | -- use gradients 266 | self.g:mul(0.95):add(0.05, self.dw) 267 | self.tmp:cmul(self.dw, self.dw) 268 | self.g2:mul(0.95):add(0.05, self.tmp) 269 | self.tmp:cmul(self.g, self.g) 270 | self.tmp:mul(-1) 271 | self.tmp:add(self.g2) 272 | self.tmp:add(0.01) 273 | self.tmp:sqrt() 274 | 275 | -- accumulate update 276 | self.deltas:mul(0):addcdiv(self.lr, self.dw, self.tmp) 277 | self.w:add(self.deltas) 278 | end 279 | 280 | 281 | function nql:sample_validation_data() 282 | local s, a, r, s2, term = self.transitions:sample(self.valid_size) 283 | self.valid_s = s:clone() 284 | self.valid_a = a:clone() 285 | self.valid_r = r:clone() 286 | self.valid_s2 = s2:clone() 287 | self.valid_term = term:clone() 288 | end 289 | 290 | 291 | function nql:compute_validation_statistics() 292 | local targets, delta, q2_max = self:getQUpdate{s=self.valid_s, 293 | a=self.valid_a, r=self.valid_r, s2=self.valid_s2, term=self.valid_term} 294 | 295 | self.v_avg = self.q_max * q2_max:mean() 296 | self.tderr_avg = delta:clone():abs():mean() 297 | end 298 | 299 | 300 | function nql:perceive(reward, rawstate, terminal, testing, testing_ep) 301 | -- Preprocess state (will be set to nil if terminal) 302 | local state = self:preprocess(rawstate):float() 303 | local curState 304 | 305 | if self.max_reward then 306 | reward = math.min(reward, self.max_reward) 307 | end 308 | if self.min_reward then 309 | reward = math.max(reward, self.min_reward) 310 | end 311 | if self.rescale_r then 312 | self.r_max = math.max(self.r_max, reward) 313 | end 314 | 315 | self.transitions:add_recent_state(state, terminal) 316 | 317 | local currentFullState = 
self.transitions:get_recent() 318 | 319 | --Store transition s, a, r, s' 320 | if self.lastState and not testing then 321 | self.transitions:add(self.lastState, self.lastAction, reward, 322 | self.lastTerminal, priority) 323 | end 324 | 325 | if self.numSteps == self.learn_start+1 and not testing then 326 | self:sample_validation_data() 327 | end 328 | 329 | curState= self.transitions:get_recent() 330 | curState = curState:resize(1, unpack(self.input_dims)) 331 | 332 | -- Select action 333 | local actionIndex = 1 334 | if not terminal then 335 | actionIndex = self:eGreedy(curState, testing_ep) 336 | end 337 | 338 | self.transitions:add_recent_action(actionIndex) 339 | 340 | --Do some Q-learning updates 341 | if self.numSteps > self.learn_start and not testing and 342 | self.numSteps % self.update_freq == 0 then 343 | for i = 1, self.n_replay do 344 | self:qLearnMinibatch() 345 | end 346 | end 347 | 348 | if not testing then 349 | self.numSteps = self.numSteps + 1 350 | end 351 | 352 | self.lastState = state:clone() 353 | self.lastAction = actionIndex 354 | self.lastTerminal = terminal 355 | 356 | if self.target_q and self.numSteps % self.target_q == 1 then 357 | self.target_network = self.network:clone() 358 | end 359 | 360 | if not terminal then 361 | return actionIndex 362 | else 363 | return 0 364 | end 365 | end 366 | 367 | 368 | function nql:eGreedy(state, testing_ep) 369 | self.ep = testing_ep or (self.ep_end + 370 | math.max(0, (self.ep_start - self.ep_end) * (self.ep_endt - 371 | math.max(0, self.numSteps - self.learn_start))/self.ep_endt)) 372 | -- Epsilon greedy 373 | if torch.uniform() < self.ep then 374 | return torch.random(1, self.n_actions) 375 | else 376 | return self:greedy(state) 377 | end 378 | end 379 | 380 | 381 | function nql:greedy(state) 382 | -- Turn single state into minibatch. Needed for convolutional nets. 383 | if state:dim() == 2 then 384 | assert(false, 'Input must be at least 3D') 385 | state = state:resize(1, state:size(1), state:size(2)) 386 | end 387 | 388 | if self.gpu >= 0 then 389 | state = state:cuda() 390 | end 391 | 392 | local q = self.network:forward(state):float():squeeze() 393 | local maxq = q[1] 394 | local besta = {1} 395 | 396 | -- Evaluate all other actions (with random tie-breaking) 397 | for a = 2, self.n_actions do 398 | if q[a] > maxq then 399 | besta = { a } 400 | maxq = q[a] 401 | elseif q[a] == maxq then 402 | besta[#besta+1] = a 403 | end 404 | end 405 | self.bestq = maxq 406 | 407 | local r = torch.random(1, #besta) 408 | 409 | self.lastAction = besta[r] 410 | 411 | return besta[r] 412 | end 413 | 414 | 415 | function nql:createNetwork() 416 | local n_hid = 128 417 | local mlp = nn.Sequential() 418 | mlp:add(nn.Reshape(self.hist_len*self.ncols*self.state_dim)) 419 | mlp:add(nn.Linear(self.hist_len*self.ncols*self.state_dim, n_hid)) 420 | mlp:add(nn.Rectifier()) 421 | mlp:add(nn.Linear(n_hid, n_hid)) 422 | mlp:add(nn.Rectifier()) 423 | mlp:add(nn.Linear(n_hid, self.n_actions)) 424 | 425 | return mlp 426 | end 427 | 428 | 429 | function nql:_loadNet() 430 | local net = self.network 431 | if self.gpu then 432 | net:cuda() 433 | else 434 | net:float() 435 | end 436 | return net 437 | end 438 | 439 | 440 | function nql:init(arg) 441 | self.actions = arg.actions 442 | self.n_actions = #self.actions 443 | self.network = self:_loadNet() 444 | -- Generate targets. 
445 | self.transitions:empty() 446 | end 447 | 448 | 449 | function nql:report() 450 | print(get_weight_norms(self.network)) 451 | print(get_grad_norms(self.network)) 452 | end 453 | -------------------------------------------------------------------------------- /dqn/Rectifier.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | --[[ Rectified Linear Unit. 8 | 9 | The output is max(0, input). 10 | --]] 11 | 12 | local Rectifier, parent = torch.class('nn.Rectifier', 'nn.Module') 13 | 14 | -- This module accepts minibatches 15 | function Rectifier:updateOutput(input) 16 | return self.output:resizeAs(input):copy(input):abs():add(input):div(2) 17 | end 18 | 19 | function Rectifier:updateGradInput(input, gradOutput) 20 | self.gradInput:resizeAs(self.output) 21 | return self.gradInput:sign(self.output):cmul(gradOutput) 22 | end -------------------------------------------------------------------------------- /dqn/Scale.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "nn" 8 | require "image" 9 | 10 | local scale = torch.class('nn.Scale', 'nn.Module') 11 | 12 | 13 | function scale:__init(height, width) 14 | self.height = height 15 | self.width = width 16 | end 17 | 18 | function scale:forward(x) 19 | local x = x 20 | if x:dim() > 3 then 21 | x = x[1] 22 | end 23 | 24 | x = image.rgb2y(x) 25 | x = image.scale(x, self.width, self.height, 'bilinear') 26 | return x 27 | end 28 | 29 | function scale:updateOutput(input) 30 | return self:forward(input) 31 | end 32 | 33 | function scale:float() 34 | end 35 | -------------------------------------------------------------------------------- /dqn/TransitionTable.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'image' 8 | 9 | local trans = torch.class('dqn.TransitionTable') 10 | 11 | 12 | function trans:__init(args) 13 | self.stateDim = args.stateDim 14 | self.numActions = args.numActions 15 | self.histLen = args.histLen 16 | self.maxSize = args.maxSize or 1024^2 17 | self.bufferSize = args.bufferSize or 1024 18 | self.histType = args.histType or "linear" 19 | self.histSpacing = args.histSpacing or 1 20 | self.zeroFrames = args.zeroFrames or 1 21 | self.nonTermProb = args.nonTermProb or 1 22 | self.nonEventProb = args.nonEventProb or 1 23 | self.gpu = args.gpu 24 | self.numEntries = 0 25 | self.insertIndex = 0 26 | 27 | self.histIndices = {} 28 | local histLen = self.histLen 29 | if self.histType == "linear" then 30 | -- History is the last histLen frames. 31 | self.recentMemSize = self.histSpacing*histLen 32 | for i=1,histLen do 33 | self.histIndices[i] = i*self.histSpacing 34 | end 35 | elseif self.histType == "exp2" then 36 | -- The ith history frame is from 2^(i-1) frames ago. 37 | self.recentMemSize = 2^(histLen-1) 38 | self.histIndices[1] = 1 39 | for i=1,histLen-1 do 40 | self.histIndices[i+1] = self.histIndices[i] + 2^(7-i) 41 | end 42 | elseif self.histType == "exp1.25" then 43 | -- The ith history frame is from 1.25^(i-1) frames ago. 
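        -- Worked example: with histLen = 4 the loop below gives intermediate
        -- indices {8, 5, 3, 1}, so recentMemSize = 8 and, after the final
        -- remapping, histIndices = {1, 4, 6, 8}: frames sampled at growing
        -- spacing, with the most recent frame at offset 8. (The Atari run
        -- scripts use the default "linear" history type instead.)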
44 | self.histIndices[histLen] = 1 45 | for i=histLen-1,1,-1 do 46 | self.histIndices[i] = math.ceil(1.25*self.histIndices[i+1])+1 47 | end 48 | self.recentMemSize = self.histIndices[1] 49 | for i=1,histLen do 50 | self.histIndices[i] = self.recentMemSize - self.histIndices[i] + 1 51 | end 52 | end 53 | 54 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 55 | self.a = torch.LongTensor(self.maxSize):fill(0) 56 | self.r = torch.zeros(self.maxSize) 57 | self.t = torch.ByteTensor(self.maxSize):fill(0) 58 | self.action_encodings = torch.eye(self.numActions) 59 | 60 | -- Tables for storing the last histLen states. They are used for 61 | -- constructing the most recent agent state more easily. 62 | self.recent_s = {} 63 | self.recent_a = {} 64 | self.recent_t = {} 65 | 66 | local s_size = self.stateDim*histLen 67 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 68 | self.buf_r = torch.zeros(self.bufferSize) 69 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 70 | self.buf_s = torch.ByteTensor(self.bufferSize, s_size):fill(0) 71 | self.buf_s2 = torch.ByteTensor(self.bufferSize, s_size):fill(0) 72 | 73 | if self.gpu and self.gpu >= 0 then 74 | self.gpu_s = self.buf_s:float():cuda() 75 | self.gpu_s2 = self.buf_s2:float():cuda() 76 | end 77 | end 78 | 79 | 80 | function trans:reset() 81 | self.numEntries = 0 82 | self.insertIndex = 0 83 | end 84 | 85 | 86 | function trans:size() 87 | return self.numEntries 88 | end 89 | 90 | 91 | function trans:empty() 92 | return self.numEntries == 0 93 | end 94 | 95 | 96 | function trans:fill_buffer() 97 | assert(self.numEntries >= self.bufferSize) 98 | -- clear CPU buffers 99 | self.buf_ind = 1 100 | local ind 101 | for buf_ind=1,self.bufferSize do 102 | local s, a, r, s2, term = self:sample_one(1) 103 | self.buf_s[buf_ind]:copy(s) 104 | self.buf_a[buf_ind] = a 105 | self.buf_r[buf_ind] = r 106 | self.buf_s2[buf_ind]:copy(s2) 107 | self.buf_term[buf_ind] = term 108 | end 109 | self.buf_s = self.buf_s:float():div(255) 110 | self.buf_s2 = self.buf_s2:float():div(255) 111 | if self.gpu and self.gpu >= 0 then 112 | self.gpu_s:copy(self.buf_s) 113 | self.gpu_s2:copy(self.buf_s2) 114 | end 115 | end 116 | 117 | 118 | function trans:sample_one() 119 | assert(self.numEntries > 1) 120 | local index 121 | local valid = false 122 | while not valid do 123 | -- start at 2 because of previous action 124 | index = torch.random(2, self.numEntries-self.recentMemSize) 125 | if self.t[index+self.recentMemSize-1] == 0 then 126 | valid = true 127 | end 128 | if self.nonTermProb < 1 and self.t[index+self.recentMemSize] == 0 and 129 | torch.uniform() > self.nonTermProb then 130 | -- Discard non-terminal states with probability (1-nonTermProb). 131 | -- Note that this is the terminal flag for s_{t+1}. 132 | valid = false 133 | end 134 | if self.nonEventProb < 1 and self.t[index+self.recentMemSize] == 0 and 135 | self.r[index+self.recentMemSize-1] == 0 and 136 | torch.uniform() > self.nonTermProb then 137 | -- Discard non-terminal or non-reward states with 138 | -- probability (1-nonTermProb). 
139 | valid = false 140 | end 141 | end 142 | 143 | return self:get(index) 144 | end 145 | 146 | 147 | function trans:sample(batch_size) 148 | local batch_size = batch_size or 1 149 | assert(batch_size < self.bufferSize) 150 | 151 | if not self.buf_ind or self.buf_ind + batch_size - 1 > self.bufferSize then 152 | self:fill_buffer() 153 | end 154 | 155 | local index = self.buf_ind 156 | 157 | self.buf_ind = self.buf_ind+batch_size 158 | local range = {{index, index+batch_size-1}} 159 | 160 | local buf_s, buf_s2, buf_a, buf_r, buf_term = self.buf_s, self.buf_s2, 161 | self.buf_a, self.buf_r, self.buf_term 162 | if self.gpu and self.gpu >=0 then 163 | buf_s = self.gpu_s 164 | buf_s2 = self.gpu_s2 165 | end 166 | 167 | return buf_s[range], buf_a[range], buf_r[range], buf_s2[range], buf_term[range] 168 | end 169 | 170 | 171 | function trans:concatFrames(index, use_recent) 172 | if use_recent then 173 | s, t = self.recent_s, self.recent_t 174 | else 175 | s, t = self.s, self.t 176 | end 177 | 178 | local fullstate = s[1].new() 179 | fullstate:resize(self.histLen, unpack(s[1]:size():totable())) 180 | 181 | -- Zero out frames from all but the most recent episode. 182 | local zero_out = false 183 | local episode_start = self.histLen 184 | 185 | for i=self.histLen-1,1,-1 do 186 | if not zero_out then 187 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 188 | if t[j] == 1 then 189 | zero_out = true 190 | break 191 | end 192 | end 193 | end 194 | 195 | if zero_out then 196 | fullstate[i]:zero() 197 | else 198 | episode_start = i 199 | end 200 | end 201 | 202 | if self.zeroFrames == 0 then 203 | episode_start = 1 204 | end 205 | 206 | -- Copy frames from the current episode. 207 | for i=episode_start,self.histLen do 208 | fullstate[i]:copy(s[index+self.histIndices[i]-1]) 209 | end 210 | 211 | return fullstate 212 | end 213 | 214 | 215 | function trans:concatActions(index, use_recent) 216 | local act_hist = torch.FloatTensor(self.histLen, self.numActions) 217 | if use_recent then 218 | a, t = self.recent_a, self.recent_t 219 | else 220 | a, t = self.a, self.t 221 | end 222 | 223 | -- Zero out frames from all but the most recent episode. 224 | local zero_out = false 225 | local episode_start = self.histLen 226 | 227 | for i=self.histLen-1,1,-1 do 228 | if not zero_out then 229 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 230 | if t[j] == 1 then 231 | zero_out = true 232 | break 233 | end 234 | end 235 | end 236 | 237 | if zero_out then 238 | act_hist[i]:zero() 239 | else 240 | episode_start = i 241 | end 242 | end 243 | 244 | if self.zeroFrames == 0 then 245 | episode_start = 1 246 | end 247 | 248 | -- Copy frames from the current episode. 
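    -- (The "frames" copied here are one-hot action vectors, i.e. rows of
    -- self.action_encodings = torch.eye(numActions).)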
249 | for i=episode_start,self.histLen do 250 | act_hist[i]:copy(self.action_encodings[a[index+self.histIndices[i]-1]]) 251 | end 252 | 253 | return act_hist 254 | end 255 | 256 | 257 | function trans:get_recent() 258 | -- Assumes that the most recent state has been added, but the action has not 259 | return self:concatFrames(1, true):float():div(255) 260 | end 261 | 262 | 263 | function trans:get(index) 264 | local s = self:concatFrames(index) 265 | local s2 = self:concatFrames(index+1) 266 | local ar_index = index+self.recentMemSize-1 267 | 268 | return s, self.a[ar_index], self.r[ar_index], s2, self.t[ar_index+1] 269 | end 270 | 271 | 272 | function trans:add(s, a, r, term) 273 | assert(s, 'State cannot be nil') 274 | assert(a, 'Action cannot be nil') 275 | assert(r, 'Reward cannot be nil') 276 | 277 | -- Incremenet until at full capacity 278 | if self.numEntries < self.maxSize then 279 | self.numEntries = self.numEntries + 1 280 | end 281 | 282 | -- Always insert at next index, then wrap around 283 | self.insertIndex = self.insertIndex + 1 284 | -- Overwrite oldest experience once at capacity 285 | if self.insertIndex > self.maxSize then 286 | self.insertIndex = 1 287 | end 288 | 289 | -- Overwrite (s,a,r,t) at insertIndex 290 | self.s[self.insertIndex] = s:clone():float():mul(255) 291 | self.a[self.insertIndex] = a 292 | self.r[self.insertIndex] = r 293 | if term then 294 | self.t[self.insertIndex] = 1 295 | else 296 | self.t[self.insertIndex] = 0 297 | end 298 | end 299 | 300 | 301 | function trans:add_recent_state(s, term) 302 | local s = s:clone():float():mul(255):byte() 303 | if #self.recent_s == 0 then 304 | for i=1,self.recentMemSize do 305 | table.insert(self.recent_s, s:clone():zero()) 306 | table.insert(self.recent_t, 1) 307 | end 308 | end 309 | 310 | table.insert(self.recent_s, s) 311 | if term then 312 | table.insert(self.recent_t, 1) 313 | else 314 | table.insert(self.recent_t, 0) 315 | end 316 | 317 | -- Keep recentMemSize states. 318 | if #self.recent_s > self.recentMemSize then 319 | table.remove(self.recent_s, 1) 320 | table.remove(self.recent_t, 1) 321 | end 322 | end 323 | 324 | 325 | function trans:add_recent_action(a) 326 | if #self.recent_a == 0 then 327 | for i=1,self.recentMemSize do 328 | table.insert(self.recent_a, 1) 329 | end 330 | end 331 | 332 | table.insert(self.recent_a, a) 333 | 334 | -- Keep recentMemSize steps. 335 | if #self.recent_a > self.recentMemSize then 336 | table.remove(self.recent_a, 1) 337 | end 338 | end 339 | 340 | 341 | --[[ 342 | Override the write function to serialize this class into a file. 343 | We do not want to store anything into the file, just the necessary info 344 | to create an empty transition table. 345 | 346 | @param file (FILE object ) @see torch.DiskFile 347 | --]] 348 | function trans:write(file) 349 | file:writeObject({self.stateDim, 350 | self.numActions, 351 | self.histLen, 352 | self.maxSize, 353 | self.bufferSize, 354 | self.numEntries, 355 | self.insertIndex, 356 | self.recentMemSize, 357 | self.histIndices}) 358 | end 359 | 360 | 361 | --[[ 362 | Override the read function to desearialize this class from file. 363 | Recreates an empty table. 
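Only the sizing parameters are restored; the stored transitions, the
recent-state tables and the sampling buffers are rebuilt empty, so an agent
reloaded from disk starts with an empty replay memory.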
364 | 365 | @param file (FILE object ) @see torch.DiskFile 366 | --]] 367 | function trans:read(file) 368 | local stateDim, numActions, histLen, maxSize, bufferSize, numEntries, insertIndex, recentMemSize, histIndices = unpack(file:readObject()) 369 | self.stateDim = stateDim 370 | self.numActions = numActions 371 | self.histLen = histLen 372 | self.maxSize = maxSize 373 | self.bufferSize = bufferSize 374 | self.recentMemSize = recentMemSize 375 | self.histIndices = histIndices 376 | self.numEntries = 0 377 | self.insertIndex = 0 378 | 379 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 380 | self.a = torch.LongTensor(self.maxSize):fill(0) 381 | self.r = torch.zeros(self.maxSize) 382 | self.t = torch.ByteTensor(self.maxSize):fill(0) 383 | self.action_encodings = torch.eye(self.numActions) 384 | 385 | -- Tables for storing the last histLen states. They are used for 386 | -- constructing the most recent agent state more easily. 387 | self.recent_s = {} 388 | self.recent_a = {} 389 | self.recent_t = {} 390 | 391 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 392 | self.buf_r = torch.zeros(self.bufferSize) 393 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 394 | self.buf_s = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 395 | self.buf_s2 = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 396 | 397 | if self.gpu and self.gpu >= 0 then 398 | self.gpu_s = self.buf_s:float():cuda() 399 | self.gpu_s2 = self.buf_s2:float():cuda() 400 | end 401 | end 402 | -------------------------------------------------------------------------------- /dqn/convnet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
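
Rough shape walk-through, assuming the Atari settings supplied by
convnet_atari3.lua and the run scripts (4 stacked 84x84 luminance frames):
  input   4x84x84
  conv1   32 filters, 8x8, stride 4, pad 1  -> 32x20x20
  conv2   64 filters, 4x4, stride 2         -> 64x9x9
  conv3   64 filters, 3x3, stride 1         -> 64x7x7   (nel = 3136)
  linear  3136 -> 512 -> n_actions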
5 | ]] 6 | 7 | require "initenv" 8 | 9 | function create_network(args) 10 | 11 | local net = nn.Sequential() 12 | net:add(nn.Reshape(unpack(args.input_dims))) 13 | 14 | --- first convolutional layer 15 | local convLayer = nn.SpatialConvolution 16 | 17 | if args.gpu >= 0 then 18 | net:add(nn.Transpose({1,2},{2,3},{3,4})) 19 | convLayer = nn.SpatialConvolutionCUDA 20 | end 21 | 22 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1], 23 | args.filter_size[1], args.filter_size[1], 24 | args.filter_stride[1], args.filter_stride[1],1)) 25 | net:add(args.nl()) 26 | 27 | -- Add convolutional layers 28 | for i=1,(#args.n_units-1) do 29 | -- second convolutional layer 30 | net:add(convLayer(args.n_units[i], args.n_units[i+1], 31 | args.filter_size[i+1], args.filter_size[i+1], 32 | args.filter_stride[i+1], args.filter_stride[i+1])) 33 | net:add(args.nl()) 34 | end 35 | 36 | local nel 37 | if args.gpu >= 0 then 38 | net:add(nn.Transpose({4,3},{3,2},{2,1})) 39 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims)) 40 | :cuda()):nElement() 41 | else 42 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement() 43 | end 44 | 45 | -- reshape all feature planes into a vector per example 46 | net:add(nn.Reshape(nel)) 47 | 48 | -- fully connected layer 49 | net:add(nn.Linear(nel, args.n_hid[1])) 50 | net:add(args.nl()) 51 | local last_layer_size = args.n_hid[1] 52 | 53 | for i=1,(#args.n_hid-1) do 54 | -- add Linear layer 55 | last_layer_size = args.n_hid[i+1] 56 | net:add(nn.Linear(args.n_hid[i], last_layer_size)) 57 | net:add(args.nl()) 58 | end 59 | 60 | -- add the last fully connected layer (to actions) 61 | net:add(nn.Linear(last_layer_size, args.n_actions)) 62 | 63 | if args.gpu >=0 then 64 | net:cuda() 65 | end 66 | if args.verbose >= 2 then 67 | print(net) 68 | print('Convolutional layers flattened output size:', nel) 69 | end 70 | return net 71 | end 72 | -------------------------------------------------------------------------------- /dqn/convnet_atari3.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | 16 | return create_network(args) 17 | end 18 | 19 | -------------------------------------------------------------------------------- /dqn/initenv.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
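
Note on device selection: the run scripts pass -gpu 0, and torchSetup below
maps the optional GPU_ID environment variable (0-based) to the 1-based device
index expected by cutorch.setDevice; when GPU_ID is unset, the default CUDA
device is left selected.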
5 | ]] 6 | dqn = {} 7 | 8 | require 'torch' 9 | require 'nn' 10 | require 'nngraph' 11 | require 'nnutils' 12 | require 'image' 13 | require 'Scale' 14 | require 'NeuralQLearner' 15 | require 'TransitionTable' 16 | require 'Rectifier' 17 | 18 | 19 | function torchSetup(_opt) 20 | _opt = _opt or {} 21 | local opt = table.copy(_opt) 22 | assert(opt) 23 | 24 | -- preprocess options: 25 | --- convert options strings to tables 26 | if opt.pool_frms then 27 | opt.pool_frms = str_to_table(opt.pool_frms) 28 | end 29 | if opt.env_params then 30 | opt.env_params = str_to_table(opt.env_params) 31 | end 32 | if opt.agent_params then 33 | opt.agent_params = str_to_table(opt.agent_params) 34 | opt.agent_params.gpu = opt.gpu 35 | opt.agent_params.best = opt.best 36 | opt.agent_params.verbose = opt.verbose 37 | if opt.network ~= '' then 38 | opt.agent_params.network = opt.network 39 | end 40 | end 41 | 42 | --- general setup 43 | opt.tensorType = opt.tensorType or 'torch.FloatTensor' 44 | torch.setdefaulttensortype(opt.tensorType) 45 | if not opt.threads then 46 | opt.threads = 4 47 | end 48 | torch.setnumthreads(opt.threads) 49 | if not opt.verbose then 50 | opt.verbose = 10 51 | end 52 | if opt.verbose >= 1 then 53 | print('Torch Threads:', torch.getnumthreads()) 54 | end 55 | 56 | --- set gpu device 57 | if opt.gpu and opt.gpu >= 0 then 58 | require 'cutorch' 59 | require 'cunn' 60 | if opt.gpu == 0 then 61 | local gpu_id = tonumber(os.getenv('GPU_ID')) 62 | if gpu_id then opt.gpu = gpu_id+1 end 63 | end 64 | if opt.gpu > 0 then cutorch.setDevice(opt.gpu) end 65 | opt.gpu = cutorch.getDevice() 66 | print('Using GPU device id:', opt.gpu-1) 67 | else 68 | opt.gpu = -1 69 | if opt.verbose >= 1 then 70 | print('Using CPU code only. GPU device id:', opt.gpu) 71 | end 72 | end 73 | 74 | --- set up random number generators 75 | -- removing lua RNG; seeding torch RNG with opt.seed and setting cutorch 76 | -- RNG seed to the first uniform random int32 from the previous RNG; 77 | -- this is preferred because using the same seed for both generators 78 | -- may introduce correlations; we assume that both torch RNGs ensure 79 | -- adequate dispersion for different seeds. 
80 | math.random = nil 81 | opt.seed = opt.seed or 1 82 | torch.manualSeed(opt.seed) 83 | if opt.verbose >= 1 then 84 | print('Torch Seed:', torch.initialSeed()) 85 | end 86 | local firstRandInt = torch.random() 87 | if opt.gpu >= 0 then 88 | cutorch.manualSeed(firstRandInt) 89 | if opt.verbose >= 1 then 90 | print('CUTorch Seed:', cutorch.initialSeed()) 91 | end 92 | end 93 | 94 | return opt 95 | end 96 | 97 | 98 | function setup(_opt) 99 | assert(_opt) 100 | 101 | --preprocess options: 102 | --- convert options strings to tables 103 | _opt.pool_frms = str_to_table(_opt.pool_frms) 104 | _opt.env_params = str_to_table(_opt.env_params) 105 | _opt.agent_params = str_to_table(_opt.agent_params) 106 | if _opt.agent_params.transition_params then 107 | _opt.agent_params.transition_params = 108 | str_to_table(_opt.agent_params.transition_params) 109 | end 110 | 111 | --- first things first 112 | local opt = torchSetup(_opt) 113 | 114 | -- load training framework and environment 115 | local framework = require(opt.framework) 116 | assert(framework) 117 | 118 | local gameEnv = framework.GameEnvironment(opt) 119 | local gameActions = gameEnv:getActions() 120 | 121 | -- agent options 122 | _opt.agent_params.actions = gameActions 123 | _opt.agent_params.gpu = _opt.gpu 124 | _opt.agent_params.best = _opt.best 125 | if _opt.network ~= '' then 126 | _opt.agent_params.network = _opt.network 127 | end 128 | _opt.agent_params.verbose = _opt.verbose 129 | if not _opt.agent_params.state_dim then 130 | _opt.agent_params.state_dim = gameEnv:nObsFeature() 131 | end 132 | 133 | local agent = dqn[_opt.agent](_opt.agent_params) 134 | 135 | if opt.verbose >= 1 then 136 | print('Set up Torch using these options:') 137 | for k, v in pairs(opt) do 138 | print(k, v) 139 | end 140 | end 141 | 142 | return gameEnv, gameActions, agent, opt 143 | end 144 | 145 | 146 | 147 | --- other functions 148 | 149 | function str_to_table(str) 150 | if type(str) == 'table' then 151 | return str 152 | end 153 | if not str or type(str) ~= 'string' then 154 | if type(str) == 'table' then 155 | return str 156 | end 157 | return {} 158 | end 159 | local ttr 160 | if str ~= '' then 161 | local ttx=tt 162 | loadstring('tt = {' .. str .. '}')() 163 | ttr = tt 164 | tt = ttx 165 | else 166 | ttr = {} 167 | end 168 | return ttr 169 | end 170 | 171 | function table.copy(t) 172 | if t == nil then return nil end 173 | local nt = {} 174 | for k, v in pairs(t) do 175 | if type(v) == 'table' then 176 | nt[k] = table.copy(v) 177 | else 178 | nt[k] = v 179 | end 180 | end 181 | setmetatable(nt, table.copy(getmetatable(t))) 182 | return nt 183 | end 184 | -------------------------------------------------------------------------------- /dqn/net_downsample_2x_full_y.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "image" 8 | require "Scale" 9 | 10 | local function create_network(args) 11 | -- Y (luminance) 12 | return nn.Scale(84, 84, true) 13 | end 14 | 15 | return create_network 16 | -------------------------------------------------------------------------------- /dqn/nnutils.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "torch" 8 | 9 | function recursive_map(module, field, func) 10 | local str = "" 11 | if module[field] or module.modules then 12 | str = str .. torch.typename(module) .. ": " 13 | end 14 | if module[field] then 15 | str = str .. func(module[field]) 16 | end 17 | if module.modules then 18 | str = str .. "[" 19 | for i, submodule in ipairs(module.modules) do 20 | local submodule_str = recursive_map(submodule, field, func) 21 | str = str .. submodule_str 22 | if i < #module.modules and string.len(submodule_str) > 0 then 23 | str = str .. " " 24 | end 25 | end 26 | str = str .. "]" 27 | end 28 | 29 | return str 30 | end 31 | 32 | function abs_mean(w) 33 | return torch.mean(torch.abs(w:clone():float())) 34 | end 35 | 36 | function abs_max(w) 37 | return torch.abs(w:clone():float()):max() 38 | end 39 | 40 | -- Build a string of average absolute weight values for the modules in the 41 | -- given network. 42 | function get_weight_norms(module) 43 | return "Weight norms:\n" .. recursive_map(module, "weight", abs_mean) .. 44 | "\nWeight max:\n" .. recursive_map(module, "weight", abs_max) 45 | end 46 | 47 | -- Build a string of average absolute weight gradient values for the modules 48 | -- in the given network. 49 | function get_grad_norms(module) 50 | return "Weight grad norms:\n" .. 51 | recursive_map(module, "gradWeight", abs_mean) .. 52 | "\nWeight grad max:\n" .. recursive_map(module, "gradWeight", abs_max) 53 | end 54 | -------------------------------------------------------------------------------- /dqn/train_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | if not dqn then 8 | require "initenv" 9 | end 10 | 11 | local cmd = torch.CmdLine() 12 | cmd:text() 13 | cmd:text('Train Agent in Environment:') 14 | cmd:text() 15 | cmd:text('Options:') 16 | 17 | cmd:option('-framework', '', 'name of training framework') 18 | cmd:option('-env', '', 'name of environment to use') 19 | cmd:option('-game_path', '', 'path to environment file (ROM)') 20 | cmd:option('-env_params', '', 'string of environment parameters') 21 | cmd:option('-pool_frms', '', 22 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 23 | cmd:option('-actrep', 1, 'how many times to repeat action') 24 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 
25 | 'number of times at the start of each training episode') 26 | 27 | cmd:option('-name', '', 'filename used for saving network and training history') 28 | cmd:option('-network', '', 'reload pretrained network') 29 | cmd:option('-agent', '', 'name of agent file to use') 30 | cmd:option('-agent_params', '', 'string of agent parameters') 31 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments') 32 | cmd:option('-saveNetworkParams', false, 33 | 'saves the agent network in a separate file') 34 | cmd:option('-prog_freq', 5*10^3, 'frequency of progress output') 35 | cmd:option('-save_freq', 5*10^4, 'the model is saved every save_freq steps') 36 | cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation') 37 | cmd:option('-save_versions', 0, '') 38 | 39 | cmd:option('-steps', 10^5, 'number of training steps to perform') 40 | cmd:option('-eval_steps', 10^5, 'number of evaluation steps') 41 | 42 | cmd:option('-verbose', 2, 43 | 'the higher the level, the more information is printed to screen') 44 | cmd:option('-threads', 1, 'number of BLAS threads') 45 | cmd:option('-gpu', -1, 'gpu flag') 46 | 47 | cmd:text() 48 | 49 | local opt = cmd:parse(arg) 50 | 51 | --- General setup. 52 | local game_env, game_actions, agent, opt = setup(opt) 53 | 54 | -- override print to always flush the output 55 | local old_print = print 56 | local print = function(...) 57 | old_print(...) 58 | io.flush() 59 | end 60 | 61 | local learn_start = agent.learn_start 62 | local start_time = sys.clock() 63 | local reward_counts = {} 64 | local episode_counts = {} 65 | local time_history = {} 66 | local v_history = {} 67 | local qmax_history = {} 68 | local td_history = {} 69 | local reward_history = {} 70 | local step = 0 71 | time_history[1] = 0 72 | 73 | local total_reward 74 | local nrewards 75 | local nepisodes 76 | local episode_reward 77 | 78 | local screen, reward, terminal = game_env:getState() 79 | 80 | print("Iteration ..", step) 81 | while step < opt.steps do 82 | step = step + 1 83 | local action_index = agent:perceive(reward, screen, terminal) 84 | 85 | -- game over? get next game! 86 | if not terminal then 87 | screen, reward, terminal = game_env:step(game_actions[action_index], true) 88 | else 89 | if opt.random_starts > 0 then 90 | screen, reward, terminal = game_env:nextRandomGame() 91 | else 92 | screen, reward, terminal = game_env:newGame() 93 | end 94 | end 95 | 96 | if step % opt.prog_freq == 0 then 97 | assert(step==agent.numSteps, 'trainer step: ' .. step .. 98 | ' & agent.numSteps: ' .. 
agent.numSteps) 99 | print("Steps: ", step) 100 | agent:report() 101 | collectgarbage() 102 | end 103 | 104 | if step%1000 == 0 then collectgarbage() end 105 | 106 | if step % opt.eval_freq == 0 and step > learn_start then 107 | 108 | screen, reward, terminal = game_env:newGame() 109 | 110 | total_reward = 0 111 | nrewards = 0 112 | nepisodes = 0 113 | episode_reward = 0 114 | 115 | local eval_time = sys.clock() 116 | for estep=1,opt.eval_steps do 117 | local action_index = agent:perceive(reward, screen, terminal, true, 0.05) 118 | 119 | -- Play game in test mode (episodes don't end when losing a life) 120 | screen, reward, terminal = game_env:step(game_actions[action_index]) 121 | 122 | if estep%1000 == 0 then collectgarbage() end 123 | 124 | -- record every reward 125 | episode_reward = episode_reward + reward 126 | if reward ~= 0 then 127 | nrewards = nrewards + 1 128 | end 129 | 130 | if terminal then 131 | total_reward = total_reward + episode_reward 132 | episode_reward = 0 133 | nepisodes = nepisodes + 1 134 | screen, reward, terminal = game_env:nextRandomGame() 135 | end 136 | end 137 | 138 | eval_time = sys.clock() - eval_time 139 | start_time = start_time + eval_time 140 | agent:compute_validation_statistics() 141 | local ind = #reward_history+1 142 | total_reward = total_reward/math.max(1, nepisodes) 143 | 144 | if #reward_history == 0 or total_reward > torch.Tensor(reward_history):max() then 145 | agent.best_network = agent.network:clone() 146 | end 147 | 148 | if agent.v_avg then 149 | v_history[ind] = agent.v_avg 150 | td_history[ind] = agent.tderr_avg 151 | qmax_history[ind] = agent.q_max 152 | end 153 | print("V", v_history[ind], "TD error", td_history[ind], "Qmax", qmax_history[ind]) 154 | 155 | reward_history[ind] = total_reward 156 | reward_counts[ind] = nrewards 157 | episode_counts[ind] = nepisodes 158 | 159 | time_history[ind+1] = sys.clock() - start_time 160 | 161 | local time_dif = time_history[ind+1] - time_history[ind] 162 | 163 | local training_rate = opt.actrep*opt.eval_freq/time_dif 164 | 165 | print(string.format( 166 | '\nSteps: %d (frames: %d), reward: %.2f, epsilon: %.2f, lr: %G, ' .. 167 | 'training time: %ds, training rate: %dfps, testing time: %ds, ' .. 168 | 'testing rate: %dfps, num. ep.: %d, num. rewards: %d', 169 | step, step*opt.actrep, total_reward, agent.ep, agent.lr, time_dif, 170 | training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time, 171 | nepisodes, nrewards)) 172 | end 173 | 174 | if step % opt.save_freq == 0 or step == opt.steps then 175 | local s, a, r, s2, term = agent.valid_s, agent.valid_a, agent.valid_r, 176 | agent.valid_s2, agent.valid_term 177 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 178 | agent.valid_term = nil, nil, nil, nil, nil, nil, nil 179 | local w, dw, g, g2, delta, delta2, deltas, tmp = agent.w, agent.dw, 180 | agent.g, agent.g2, agent.delta, agent.delta2, agent.deltas, agent.tmp 181 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 182 | agent.deltas, agent.tmp = nil, nil, nil, nil, nil, nil, nil, nil 183 | 184 | local filename = opt.name 185 | if opt.save_versions > 0 then 186 | filename = filename .. "_" .. math.floor(step / opt.save_versions) 187 | end 188 | filename = filename 189 | torch.save(filename .. 
".t7", {agent = agent, 190 | model = agent.network, 191 | best_model = agent.best_network, 192 | reward_history = reward_history, 193 | reward_counts = reward_counts, 194 | episode_counts = episode_counts, 195 | time_history = time_history, 196 | v_history = v_history, 197 | td_history = td_history, 198 | qmax_history = qmax_history, 199 | arguments=opt}) 200 | if opt.saveNetworkParams then 201 | local nets = {network=w:clone():float()} 202 | torch.save(filename..'.params.t7', nets, 'ascii') 203 | end 204 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 205 | agent.valid_term = s, a, r, s2, term 206 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 207 | agent.deltas, agent.tmp = w, dw, g, g2, delta, delta2, deltas, tmp 208 | print('Saved:', filename .. '.t7') 209 | io.flush() 210 | collectgarbage() 211 | end 212 | end 213 | -------------------------------------------------------------------------------- /install_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ###################################################################### 4 | # Torch install 5 | ###################################################################### 6 | 7 | 8 | TOPDIR=$PWD 9 | 10 | # Prefix: 11 | PREFIX=$PWD/torch 12 | echo "Installing Torch into: $PREFIX" 13 | 14 | if [[ `uname` != 'Linux' ]]; then 15 | echo 'Platform unsupported, only available for Linux' 16 | exit 17 | fi 18 | if [[ `which apt-get` == '' ]]; then 19 | echo 'apt-get not found, platform not supported' 20 | exit 21 | fi 22 | 23 | # Install dependencies for Torch: 24 | sudo apt-get update 25 | sudo apt-get install -qqy build-essential 26 | sudo apt-get install -qqy gcc g++ 27 | sudo apt-get install -qqy cmake 28 | sudo apt-get install -qqy curl 29 | sudo apt-get install -qqy libreadline-dev 30 | sudo apt-get install -qqy git-core 31 | sudo apt-get install -qqy libjpeg-dev 32 | sudo apt-get install -qqy libpng-dev 33 | sudo apt-get install -qqy ncurses-dev 34 | sudo apt-get install -qqy imagemagick 35 | sudo apt-get install -qqy unzip 36 | sudo apt-get update 37 | 38 | 39 | echo "==> Torch7's dependencies have been installed" 40 | 41 | 42 | 43 | 44 | 45 | # Build and install Torch7 46 | cd /tmp 47 | rm -rf luajit-rocks 48 | git clone https://github.com/torch/luajit-rocks.git 49 | cd luajit-rocks 50 | mkdir -p build 51 | cd build 52 | git checkout master; git pull 53 | rm -f CMakeCache.txt 54 | cmake .. -DCMAKE_INSTALL_PREFIX=$PREFIX -DCMAKE_BUILD_TYPE=Release 55 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 56 | make 57 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 58 | make install 59 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. 
Exiting."; exit $RET; fi 60 | 61 | 62 | path_to_nvcc=$(which nvcc) 63 | if [ -x "$path_to_nvcc" ] 64 | then 65 | cutorch=ok 66 | cunn=ok 67 | fi 68 | 69 | # Install base packages: 70 | $PREFIX/bin/luarocks install cwrap 71 | $PREFIX/bin/luarocks install paths 72 | $PREFIX/bin/luarocks install torch 73 | $PREFIX/bin/luarocks install nn 74 | 75 | [ -n "$cutorch" ] && \ 76 | ($PREFIX/bin/luarocks install cutorch) 77 | [ -n "$cunn" ] && \ 78 | ($PREFIX/bin/luarocks install cunn) 79 | 80 | $PREFIX/bin/luarocks install luafilesystem 81 | $PREFIX/bin/luarocks install penlight 82 | $PREFIX/bin/luarocks install sys 83 | $PREFIX/bin/luarocks install xlua 84 | $PREFIX/bin/luarocks install image 85 | $PREFIX/bin/luarocks install env 86 | 87 | echo "" 88 | echo "=> Torch7 has been installed successfully" 89 | echo "" 90 | 91 | 92 | echo "Installing nngraph ... " 93 | $PREFIX/bin/luarocks install nngraph 94 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 95 | echo "nngraph installation completed" 96 | 97 | echo "Installing Xitari ... " 98 | cd /tmp 99 | rm -rf xitari 100 | git clone https://github.com/deepmind/xitari.git 101 | cd xitari 102 | $PREFIX/bin/luarocks make 103 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 104 | echo "Xitari installation completed" 105 | 106 | echo "Installing Alewrap ... " 107 | cd /tmp 108 | rm -rf alewrap 109 | git clone https://github.com/deepmind/alewrap.git 110 | cd alewrap 111 | $PREFIX/bin/luarocks make 112 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 113 | echo "Alewrap installation completed" 114 | 115 | echo 116 | echo "You can run experiments by executing: " 117 | echo 118 | echo " ./run_cpu game_name" 119 | echo 120 | echo " or " 121 | echo 122 | echo " ./run_gpu game_name" 123 | echo 124 | echo "For this you need to provide the rom files of the respective games (game_name.bin) in the roms/ directory" 125 | echo 126 | 127 | -------------------------------------------------------------------------------- /roms/README: -------------------------------------------------------------------------------- 1 | Rom files should be put in this directory 2 | -------------------------------------------------------------------------------- /run_cpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. 
./run_cpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | 9 | game_path=$PWD"/roms/" 10 | env_params="useRGB=true" 11 | agent="NeuralQLearner" 12 | n_replay=1 13 | netfile="\"convnet_atari3\"" 14 | update_freq=4 15 | actrep=4 16 | discount=0.99 17 | seed=1 18 | learn_start=50000 19 | pool_frms_type="\"max\"" 20 | pool_frms_size=2 21 | initial_priority="false" 22 | replay_memory=1000000 23 | eps_end=0.1 24 | eps_endt=replay_memory 25 | lr=0.00025 26 | agent_type="DQN3_0_1" 27 | preproc_net="\"net_downsample_2x_full_y\"" 28 | agent_name=$agent_type"_"$1"_FULL_Y" 29 | state_dim=7056 30 | ncols=1 31 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 32 | steps=50000000 33 | eval_freq=250000 34 | eval_steps=125000 35 | prog_freq=5000 36 | save_freq=125000 37 | gpu=-1 38 | random_starts=30 39 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 40 | num_threads=4 41 | 42 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads" 43 | echo $args 44 | 45 | cd dqn 46 | ../torch/bin/luajit train_agent.lua $args 47 | -------------------------------------------------------------------------------- /run_gpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. 
./run_gpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | 9 | game_path=$PWD"/roms/" 10 | env_params="useRGB=true" 11 | agent="NeuralQLearner" 12 | n_replay=1 13 | netfile="\"convnet_atari3\"" 14 | update_freq=4 15 | actrep=4 16 | discount=0.99 17 | seed=1 18 | learn_start=50000 19 | pool_frms_type="\"max\"" 20 | pool_frms_size=2 21 | initial_priority="false" 22 | replay_memory=1000000 23 | eps_end=0.1 24 | eps_endt=replay_memory 25 | lr=0.00025 26 | agent_type="DQN3_0_1" 27 | preproc_net="\"net_downsample_2x_full_y\"" 28 | agent_name=$agent_type"_"$1"_FULL_Y" 29 | state_dim=7056 30 | ncols=1 31 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 32 | steps=50000000 33 | eval_freq=250000 34 | eval_steps=125000 35 | prog_freq=10000 36 | save_freq=125000 37 | gpu=0 38 | random_starts=30 39 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 40 | num_threads=4 41 | 42 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads" 43 | echo $args 44 | 45 | cd dqn 46 | ../torch/bin/luajit train_agent.lua $args 47 | --------------------------------------------------------------------------------