├── .gitignore ├── README.md ├── dqn ├── LICENSE ├── NeuralQLearner.lua ├── Rectifier.lua ├── Scale.lua ├── TransitionTable.lua ├── convnet.lua ├── convnet_atari3.lua ├── convnet_nes.lua ├── initenv.lua ├── net_downsample_2x_full_y.lua ├── nnutils.lua ├── test_agent.lua └── train_agent.lua ├── install_dependencies.sh ├── logs └── .gitignore ├── roms ├── README └── breakout.bin ├── saves └── .gitignore ├── test_cpu.sh ├── test_gpu.sh ├── train_cpu.sh └── train_gpu.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | torch 3 | dqn/*.t7 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Q MarI/O 2 | This is a fork of Google's Deep Q Network code used to master classic Atari games. This has been adapted here to play Super Mario Bros. It uses a double deep Q network to control an open-source Nintendo Entertainment System emulator called FCEUX. 3 | 4 | For instructions and a summary of changes to the original Google project, please see [this blog post.](http://www.ehrenbrav.com/2016/08/teaching-your-computer-to-play-super-mario-bros-a-fork-of-the-google-deepmind-atari-machine-learning-project/) 5 | 6 | Tested on Debian (x64) with an nVidia GTX 980: 7 | 8 | 9 | -------------------------------------------------------------------------------- /dqn/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | LIMITED LICENSE: 3 | 4 | Copyright (c) 2014 Google Inc. 5 | Limited License: Under no circumstance is commercial use, reproduction, or 6 | distribution permitted. Use, reproduction, and distribution are permitted 7 | solely for academic use in evaluating and reviewing claims made in 8 | "Human-level control through deep reinforcement learning", Nature 518, 529–533 9 | (26 February 2015) doi:10.1038/nature14236, provided that the following 10 | conditions are met: 11 | 12 | * Any reproduction or distribution of source code must retain the above 13 | copyright notice and the full text of this license including the following 14 | disclaimer.
 15 | 16 | * Any reproduction or distribution in binary form must reproduce the above 17 | copyright notice and the full text of this license including the following 18 | disclaimer
 in the documentation and/or other materials provided with the 19 | distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /dqn/NeuralQLearner.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | if not dqn then 8 | require 'initenv' 9 | end 10 | 11 | local nql = torch.class('dqn.NeuralQLearner') 12 | local win = nil 13 | 14 | 15 | function nql:__init(args) 16 | self.state_dim = args.state_dim -- State dimensionality. 17 | self.actions = args.actions 18 | self.n_actions = #self.actions 19 | self.verbose = args.verbose 20 | self.best = args.best -- Whether we should load the best or the latest network. 21 | 22 | --- epsilon annealing 23 | self.ep_start = args.ep or 1 24 | self.ep = self.ep_start -- Exploration probability. 25 | self.ep_end = args.ep_end or self.ep 26 | self.ep_endt = args.ep_endt or 1000000 27 | 28 | ---- learning rate annealing 29 | self.lr_start = args.lr or 0.01 --Learning rate. 30 | self.lr = self.lr_start 31 | self.lr_end = args.lr_end or self.lr 32 | self.lr_endt = args.lr_endt or 1000000 33 | self.wc = args.wc or 0 -- L2 weight cost. 34 | self.minibatch_size = args.minibatch_size or 1 35 | self.valid_size = args.valid_size or 500 36 | 37 | --- Q-learning parameters 38 | self.discount = args.discount or 0.99 --Discount factor. 39 | self.update_freq = args.update_freq or 1 40 | -- Number of points to replay per learning step. 41 | self.n_replay = args.n_replay or 1 42 | -- Number of steps after which learning starts. 43 | self.learn_start = args.learn_start or 0 44 | -- Size of the transition table. 45 | self.replay_memory = args.replay_memory or 1000000 46 | self.hist_len = args.hist_len or 1 47 | self.rescale_r = args.rescale_r 48 | self.max_reward = args.max_reward 49 | self.min_reward = args.min_reward 50 | self.clip_delta = args.clip_delta 51 | self.target_q = args.target_q 52 | self.bestq = 0 53 | 54 | self.gpu = args.gpu 55 | 56 | self.ncols = args.ncols or 1 -- number of color channels in input 57 | self.input_dims = args.input_dims or {self.hist_len*self.ncols, 84, 84} 58 | self.preproc = args.preproc -- name of preprocessing network 59 | self.histType = args.histType or "linear" -- history type to use 60 | self.histSpacing = args.histSpacing or 1 61 | self.nonTermProb = args.nonTermProb or 1 62 | self.nonEventProb = args.nonEventProb 63 | self.bufferSize = args.bufferSize or 512 64 | 65 | self.transition_params = args.transition_params or {} 66 | 67 | self.network = args.network or self:createNetwork() 68 | 69 | -- check whether there is a network file 70 | local network_function 71 | if not (type(self.network) == 'string') then 72 | error("The type of the network provided in NeuralQLearner" .. 73 | " is not a string!") 74 | end 75 | 76 | local msg, err = pcall(require, self.network) 77 | if not msg then 78 | -- try to load saved agent 79 | local err_msg, exp = pcall(torch.load, self.network) 80 | if not err_msg then 81 | error("Could not find network file. Error: " .. exp) 82 | end 83 | if self.best and exp.best_model then 84 | print("Loading best model...") 85 | self.network = exp.best_model 86 | else 87 | print("Loading the latest (not necessarily the best) model...") 88 | self.network = exp.model 89 | end 90 | else 91 | print('Creating Agent Network from ' .. 
self.network) 92 | self.network = err 93 | self.network = self:network() 94 | end 95 | 96 | if self.gpu and self.gpu >= 0 then 97 | self.network:cuda() 98 | else 99 | self.network:float() 100 | end 101 | 102 | -- Load preprocessing network. 103 | if not (type(self.preproc == 'string')) then 104 | error('The preprocessing is not a string') 105 | end 106 | msg, err = pcall(require, self.preproc) 107 | if not msg then 108 | error("Error loading preprocessing net. Error: " .. err) 109 | end 110 | self.preproc = err 111 | self.preproc = self:preproc() 112 | self.preproc:float() 113 | 114 | if self.gpu and self.gpu >= 0 then 115 | self.network:cuda() 116 | self.tensor_type = torch.CudaTensor 117 | else 118 | self.network:float() 119 | self.tensor_type = torch.FloatTensor 120 | end 121 | 122 | -- Create transition table. 123 | ---- assuming the transition table always gets floating point input 124 | ---- (Foat or Cuda tensors) and always returns one of the two, as required 125 | ---- internally it always uses ByteTensors for states, scaling and 126 | ---- converting accordingly 127 | local transition_args = { 128 | stateDim = self.state_dim, numActions = self.n_actions, 129 | histLen = self.hist_len, gpu = self.gpu, 130 | maxSize = self.replay_memory, histType = self.histType, 131 | histSpacing = self.histSpacing, nonTermProb = self.nonTermProb, 132 | bufferSize = self.bufferSize, nonEventProb = self.nonEventProb 133 | } 134 | 135 | self.transitions = dqn.TransitionTable(transition_args) 136 | 137 | self.numSteps = 0 -- Number of perceived states. 138 | self.lastState = nil 139 | self.lastAction = nil 140 | self.v_avg = 0 -- V running average. 141 | self.tderr_avg = 0 -- Temporal-difference error running average. 142 | 143 | self.q_max = 1 144 | self.r_max = math .max(torch.abs(self.max_reward), torch.abs(self.min_reward)) 145 | 146 | self.w, self.dw = self.network:getParameters() -- Load the weights. 147 | self.dw:zero() -- Set gradient to zero. 148 | 149 | self.deltas = self.dw:clone():fill(0) 150 | 151 | self.tmp= self.dw:clone():fill(0) 152 | self.g = self.dw:clone():fill(0) 153 | self.g2 = self.dw:clone():fill(0) 154 | 155 | -- Initialize the target nework to be equal to the current network. 156 | if self.target_q then 157 | self.target_network = self.network:clone() 158 | end 159 | end 160 | 161 | 162 | function nql:reset(state) 163 | if not state then 164 | return 165 | end 166 | self.best_network = state.best_network 167 | self.network = state.model 168 | self.w, self.dw = self.network:getParameters() 169 | self.dw:zero() 170 | self.numSteps = 0 171 | print("RESET STATE SUCCESFULLY") 172 | end 173 | 174 | 175 | function nql:preprocess(rawstate) 176 | 177 | if self.preproc then 178 | local input_state = self.preproc:forward(rawstate:float()) 179 | :clone():reshape(self.state_dim) 180 | 181 | -- Optionally display the preprocessed image... 182 | if self.verbose > 3 then 183 | win = image.display({image=input_state:clone():reshape(84, 84), win=win}) 184 | end 185 | 186 | return input_state 187 | end 188 | 189 | return rawstate 190 | end 191 | 192 | -- The idea here is to calculate the predicted ideal actions 193 | -- each state in the minibatch, and to update the network 194 | -- such that the prediction matches the actual desireable 195 | -- outcome. 
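-- Concretely, for each transition (s, a, r, s2, terminal) in the minibatch the
-- target built below is
--   y = r + (1 - terminal) * discount * Q_target(s2, argmax_a' Q_online(s2, a'))
-- (the Double DQN rule: the online network picks the action, the target
-- network scores it), and the error that gets backpropagated is
-- delta = y - Q_online(s, a), after optional reward rescaling and clipping.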
196 | function nql:getQUpdate(args) 197 | local s, a, r, s2, term, delta, best_a 198 | local q, q2, q2_online, q2_target 199 | 200 | s = args.s 201 | a = args.a 202 | r = args.r 203 | s2 = args.s2 204 | term = args.term 205 | 206 | -- The order of calls to forward is a bit odd in order 207 | -- to avoid unnecessary calls (we only need 2). 208 | 209 | -- delta = r + (1-terminal) * gamma * max_a Q(s2, a) - Q(s, a) 210 | term = term:clone():float():mul(-1):add(1) 211 | 212 | -- If we don't have a target Q network yet, make one and use it. 213 | local target_q_net 214 | if self.target_q then 215 | target_q_net = self.target_network 216 | else 217 | target_q_net = self.network 218 | end 219 | 220 | -- Using *Double* DQN here... 221 | -- For each s2 in the minibatch, 222 | -- pick the action with the highest value using the *online* network 223 | -- and then calculate the Q-value of s2 given this action using the *target* network. 224 | q2_online, best_a = self.network:forward(s2):float():max(2) 225 | 226 | -- Get the Q-values for the best actions we identified above, using the *target* network. 227 | q2_target = target_q_net:forward(s2):float() 228 | q2_max = torch.FloatTensor(best_a:size()) 229 | for i=1, best_a:size(1) do 230 | local a_index = best_a[i][1] 231 | q2_max[i] = q2_target[i][a_index] 232 | end 233 | 234 | -- Compute q2 = (1-terminal) * gamma * max_a Q(s2, a) 235 | -- Discounted by gamma and set to zero if terminal. 236 | q2 = q2_max:clone():mul(self.discount):cmul(term) 237 | 238 | -- Set delta equal to the rewards in the minibatch. 239 | delta = r:clone():float() 240 | 241 | -- Rescale the reward to [-1, 1] if requested. 242 | if self.rescale_r then 243 | delta:div(self.r_max) 244 | end 245 | 246 | -- Add the discounted Q(s2, a) values to these rewards. 247 | delta:add(q2) 248 | 249 | -- q = Q(s,a) 250 | -- This estimates the value of state s for actions a using the *online* network, 251 | local q_all = self.network:forward(s):float() 252 | q = torch.FloatTensor(q_all:size(1)) 253 | for i=1,q_all:size(1) do 254 | q[i] = q_all[i][a[i]] 255 | end 256 | 257 | -- Finally, subtract out the Q(s, a) values. 258 | delta:add(-1, q) 259 | 260 | -- Keep the deltas bounded, if requested. 261 | if self.clip_delta then 262 | delta[delta:ge(self.clip_delta)] = self.clip_delta 263 | delta[delta:le(-self.clip_delta)] = -self.clip_delta 264 | end 265 | 266 | local targets = torch.zeros(self.minibatch_size, self.n_actions):float() 267 | for i=1,math.min(self.minibatch_size,a:size(1)) do 268 | targets[i][a[i]] = delta[i] 269 | end 270 | 271 | if self.gpu >= 0 then targets = targets:cuda() end 272 | 273 | return targets, delta, q2_max 274 | end 275 | 276 | 277 | function nql:qLearnMinibatch() 278 | -- Perform a minibatch Q-learning update: 279 | -- w += lr * [r + (discount * max Q(s2,a2)) - Q(s,a)] * dQ(s,a)/dw 280 | assert(self.transitions:size() > self.minibatch_size) 281 | 282 | -- Load a minibatch of experiences. 283 | local s, a, r, s2, term = self.transitions:sample(self.minibatch_size) 284 | 285 | -- Feed these experiences into the Q network. 286 | local targets, delta, q2_max = self:getQUpdate{s=s, a=a, r=r, s2=s2, 287 | term=term, update_qmax=true} 288 | 289 | -- zero gradients of parameters 290 | self.dw:zero() 291 | 292 | -- Do a backwards pass to calculate the gradients. 293 | self.network:backward(s, targets) 294 | 295 | -- add weight cost to gradient - this defaults to zero. 
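-- The remainder of this minibatch update, in order:
--   1. subtract the L2 weight cost from the gradient: dw <- dw - wc * w
--      (wc defaults to zero, so this is usually a no-op)
--   2. linearly anneal the learning rate from lr_start down to lr_end
--   3. take an RMSProp-style step that divides the gradient by a running
--      estimate of its standard deviation:
--        g  <- 0.95 * g  + 0.05 * dw
--        g2 <- 0.95 * g2 + 0.05 * dw^2      (elementwise)
--        w  <- w + lr * dw / sqrt(g2 - g^2 + 0.01)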
296 | self.dw:add(-self.wc, self.w) 297 | 298 | -- compute linearly annealed learning rate 299 | local t = math.max(0, self.numSteps - self.learn_start) 300 | self.lr = (self.lr_start - self.lr_end) * (self.lr_endt - t)/self.lr_endt + 301 | self.lr_end 302 | self.lr = math.max(self.lr, self.lr_end) 303 | 304 | -- use gradients 305 | self.g:mul(0.95):add(0.05, self.dw) 306 | self.tmp:cmul(self.dw, self.dw) 307 | self.g2:mul(0.95):add(0.05, self.tmp) 308 | self.tmp:cmul(self.g, self.g) 309 | self.tmp:mul(-1) 310 | self.tmp:add(self.g2) 311 | self.tmp:add(0.01) 312 | self.tmp:sqrt() 313 | 314 | -- accumulate update 315 | self.deltas:mul(0):addcdiv(self.lr, self.dw, self.tmp) 316 | self.w:add(self.deltas) 317 | end 318 | 319 | -- Returns valid_size experiences as validation data. 320 | function nql:sample_validation_data() 321 | local s, a, r, s2, term = self.transitions:sample(self.valid_size) 322 | self.valid_s = s:clone() 323 | self.valid_a = a:clone() 324 | self.valid_r = r:clone() 325 | self.valid_s2 = s2:clone() 326 | self.valid_term = term:clone() 327 | end 328 | 329 | -- Compute the mean Q value and the TD error for our early validation experiences. 330 | function nql:compute_validation_statistics() 331 | local targets, delta, q2_max = self:getQUpdate{s=self.valid_s, 332 | a=self.valid_a, r=self.valid_r, s2=self.valid_s2, term=self.valid_term} 333 | 334 | -- This is the average Q value of the target network for the highest-value action. 335 | -- This ideally should rise with learning and stabalize at a reasonable value... 336 | self.v_avg = self.q_max * q2_max:mean() 337 | 338 | -- This in essence is the difference between the target and current networks' value estimate for Q(s, a). 339 | -- This should approach zero with time as learning slows... 340 | self.tderr_avg = delta:clone():abs():mean() 341 | end 342 | 343 | -- Main function for observing the results and learning. 344 | function nql:perceive(reward, rawstate, terminal, testing, testing_ep) 345 | 346 | -- Preprocess state (will be set to nil if terminal) 347 | local state = self:preprocess(rawstate):float() 348 | local curState 349 | 350 | -- Clip the reward to the max/min, if requested. 351 | if self.max_reward then 352 | reward = math.min(reward, self.max_reward) 353 | end 354 | if self.min_reward then 355 | reward = math.max(reward, self.min_reward) 356 | end 357 | 358 | -- Add the preprocessed state and terminal value to the recent state table. 359 | self.transitions:add_recent_state(state, terminal) 360 | 361 | --Store transition s, a, r, s' 362 | if self.lastState and not testing then 363 | self.transitions:add(self.lastState, self.lastAction, reward, 364 | self.lastTerminal, priority) 365 | end 366 | 367 | -- Load validation data once we're past the initial phase. 368 | -- This is just a sample of experiences. 369 | if self.numSteps == self.learn_start+1 and not testing then 370 | self:sample_validation_data() 371 | end 372 | 373 | -- Get the hist_len most recent frames... 374 | -- Dimensions should be (hist_len, width, height). 375 | curState= self.transitions:get_recent() 376 | 377 | -- Add a dimension to make this into a one-entry minibatch 378 | -- to keep the network happy. 379 | curState = curState:resize(1, unpack(self.input_dims)) 380 | 381 | -- OK use the Q network to select an action based on 382 | -- the trailing hist_len frames. 383 | local actionIndex = 1 384 | if not terminal then 385 | actionIndex = self:eGreedy(curState, testing_ep) 386 | end 387 | 388 | -- Add this action to our experiences. 
389 | -- This makes the recent states list complete with frames and actions. 390 | self.transitions:add_recent_action(actionIndex) 391 | 392 | -- Learn... 393 | if self.numSteps > self.learn_start and not testing and 394 | self.numSteps % self.update_freq == 0 then 395 | for i = 1, self.n_replay do 396 | self:qLearnMinibatch() 397 | end 398 | end 399 | 400 | -- Track the number of learning steps we've undertaken. 401 | if not testing then 402 | self.numSteps = self.numSteps + 1 403 | end 404 | 405 | -- Save the state and action for the next round. 406 | self.lastState = state:clone() 407 | self.lastAction = actionIndex 408 | self.lastTerminal = terminal 409 | 410 | -- After target_q steps, replace the existing Q network with the newer one. 411 | if self.target_q and self.numSteps % self.target_q == 1 then 412 | self.target_network = self.network:clone() 413 | end 414 | 415 | -- Return the action so we can feed it to the emulator. 416 | if not terminal then 417 | return actionIndex 418 | else 419 | return 0 420 | end 421 | end 422 | 423 | -- Return an action for the given state. 424 | function nql:eGreedy(state, testing_ep) 425 | self.ep = testing_ep or (self.ep_end + 426 | math.max(0, (self.ep_start - self.ep_end) * (self.ep_endt - 427 | math.max(0, self.numSteps - self.learn_start))/self.ep_endt)) 428 | 429 | -- Select an action, maybe randomly. 430 | if torch.uniform() < self.ep then 431 | 432 | -- Select a random action, with probability ep. 433 | return torch.random(1, self.n_actions) 434 | else 435 | 436 | -- Select the action with the highest Q value. 437 | return self:greedy(state) 438 | end 439 | end 440 | 441 | -- Return the action with the highest value given this state. 442 | function nql:greedy(state) 443 | -- Turn single state into minibatch. Needed for convolutional nets. 444 | if state:dim() == 2 then 445 | assert(false, 'Input must be at least 3D') 446 | state = state:resize(1, state:size(1), state:size(2)) 447 | end 448 | 449 | if self.gpu >= 0 then 450 | state = state:cuda() 451 | end 452 | 453 | -- Feed the state into the current network. 454 | local q = self.network:forward(state):float():squeeze() 455 | 456 | -- Initialize the best Q and best action variables. 457 | local maxq = q[1] 458 | local besta = {1} 459 | 460 | -- Evaluate all other actions (with random tie-breaking) 461 | for a = 2, self.n_actions do 462 | if q[a] > maxq then 463 | besta = { a } 464 | maxq = q[a] 465 | 466 | -- Tie, add a second best action to the list. 467 | elseif q[a] == maxq then 468 | besta[#besta+1] = a 469 | end 470 | end 471 | 472 | -- Keep track of our highest Q value. 473 | self.bestq = maxq 474 | 475 | local r = torch.random(1, #besta) 476 | 477 | -- Pick at random from the equally-performing actions. 
478 | self.lastAction = besta[r] 479 | 480 | return besta[r] 481 | end 482 | 483 | 484 | function nql:createNetwork() 485 | local n_hid = 128 486 | local mlp = nn.Sequential() 487 | mlp:add(nn.Reshape(self.hist_len*self.ncols*self.state_dim)) 488 | mlp:add(nn.Linear(self.hist_len*self.ncols*self.state_dim, n_hid)) 489 | mlp:add(nn.Rectifier()) 490 | mlp:add(nn.Linear(n_hid, n_hid)) 491 | mlp:add(nn.Rectifier()) 492 | mlp:add(nn.Linear(n_hid, self.n_actions)) 493 | 494 | return mlp 495 | end 496 | 497 | 498 | function nql:_loadNet() 499 | local net = self.network 500 | if self.gpu then 501 | net:cuda() 502 | else 503 | net:float() 504 | end 505 | return net 506 | end 507 | 508 | 509 | function nql:init(arg) 510 | self.actions = arg.actions 511 | self.n_actions = #self.actions 512 | self.network = self:_loadNet() 513 | -- Generate targets. 514 | self.transitions:empty() 515 | end 516 | 517 | 518 | function nql:report() 519 | print(get_weight_norms(self.network)) 520 | print(get_grad_norms(self.network)) 521 | end 522 | 523 | -- Prints the hist_len most recent frames. 524 | -- Assumes images are square... 525 | function nql:printRecent() 526 | 527 | print("Saving frame snapshot...") 528 | 529 | for i = 1, self.hist_len do 530 | 531 | local filename = "Frame" .. self.transitions.histIndices[i] .. ".png" 532 | image.save(filename, self.transitions.recent_s[self.transitions.histIndices[i]]:resize(self.ncols, self.state_dim^.5, self.state_dim^.5)) 533 | --image.save(filename, self.transitions.recent_s[i]:clone():resize(self.ncols, self.state_dim^.5, self.state_dim^.5)) 534 | 535 | end 536 | end 537 | -------------------------------------------------------------------------------- /dqn/Rectifier.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | --[[ Rectified Linear Unit. 8 | 9 | The output is max(0, input). 10 | --]] 11 | 12 | local Rectifier, parent = torch.class('nn.Rectifier', 'nn.Module') 13 | 14 | -- This module accepts minibatches 15 | function Rectifier:updateOutput(input) 16 | return self.output:resizeAs(input):copy(input):abs():add(input):div(2) 17 | end 18 | 19 | function Rectifier:updateGradInput(input, gradOutput) 20 | self.gradInput:resizeAs(self.output) 21 | return self.gradInput:sign(self.output):cmul(gradOutput) 22 | end -------------------------------------------------------------------------------- /dqn/Scale.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "nn" 8 | require "image" 9 | 10 | local scale = torch.class('nn.Scale', 'nn.Module') 11 | 12 | 13 | function scale:__init(height, width) 14 | self.height = height 15 | self.width = width 16 | end 17 | 18 | function scale:forward(x) 19 | local x = x 20 | if x:dim() > 3 then 21 | x = x[1] 22 | end 23 | 24 | x = image.rgb2y(x) 25 | x = image.scale(x, self.width, self.height, 'bilinear') 26 | return x 27 | end 28 | 29 | function scale:updateOutput(input) 30 | return self:forward(input) 31 | end 32 | 33 | function scale:float() 34 | end 35 | -------------------------------------------------------------------------------- /dqn/TransitionTable.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 
3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'image' 8 | 9 | local trans = torch.class('dqn.TransitionTable') 10 | 11 | 12 | function trans:__init(args) 13 | self.stateDim = args.stateDim 14 | self.numActions = args.numActions 15 | self.histLen = args.histLen 16 | self.maxSize = args.maxSize or 1024^2 17 | self.bufferSize = args.bufferSize or 1024 18 | self.histType = args.histType or "linear" 19 | self.histSpacing = args.histSpacing or 1 20 | self.zeroFrames = args.zeroFrames or 1 21 | self.nonTermProb = args.nonTermProb or 1 22 | self.nonEventProb = args.nonEventProb or 1 23 | self.gpu = args.gpu 24 | self.numEntries = 0 25 | self.insertIndex = 0 26 | 27 | self.histIndices = {} 28 | local histLen = self.histLen 29 | if self.histType == "linear" then 30 | -- History is the last histLen frames. 31 | self.recentMemSize = self.histSpacing*histLen 32 | for i=1,histLen do 33 | self.histIndices[i] = i*self.histSpacing 34 | end 35 | elseif self.histType == "exp2" then 36 | -- The ith history frame is from 2^(i-1) frames ago. 37 | self.recentMemSize = 2^(histLen-1) 38 | self.histIndices[1] = 1 39 | for i=1,histLen-1 do 40 | self.histIndices[i+1] = self.histIndices[i] + 2^(7-i) 41 | end 42 | elseif self.histType == "exp1.25" then 43 | -- The ith history frame is from 1.25^(i-1) frames ago. 44 | self.histIndices[histLen] = 1 45 | for i=histLen-1,1,-1 do 46 | self.histIndices[i] = math.ceil(1.25*self.histIndices[i+1])+1 47 | end 48 | self.recentMemSize = self.histIndices[1] 49 | for i=1,histLen do 50 | self.histIndices[i] = self.recentMemSize - self.histIndices[i] + 1 51 | end 52 | end 53 | 54 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 55 | self.a = torch.LongTensor(self.maxSize):fill(0) 56 | self.r = torch.zeros(self.maxSize) 57 | self.t = torch.ByteTensor(self.maxSize):fill(0) 58 | self.action_encodings = torch.eye(self.numActions) 59 | 60 | -- Tables for storing the last histLen states. They are used for 61 | -- constructing the most recent agent state more easily. 
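-- (recent_s, recent_a and recent_t are plain Lua tables; add_recent_state and
-- add_recent_action below trim them to the last recentMemSize entries, which
-- is histLen * histSpacing for the default "linear" history type.)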
62 | self.recent_s = {} 63 | self.recent_a = {} 64 | self.recent_t = {} 65 | 66 | local s_size = self.stateDim*histLen 67 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 68 | self.buf_r = torch.zeros(self.bufferSize) 69 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 70 | self.buf_s = torch.ByteTensor(self.bufferSize, s_size):fill(0) 71 | self.buf_s2 = torch.ByteTensor(self.bufferSize, s_size):fill(0) 72 | 73 | if self.gpu and self.gpu >= 0 then 74 | self.gpu_s = self.buf_s:float():cuda() 75 | self.gpu_s2 = self.buf_s2:float():cuda() 76 | end 77 | end 78 | 79 | 80 | function trans:reset() 81 | self.numEntries = 0 82 | self.insertIndex = 0 83 | end 84 | 85 | 86 | function trans:size() 87 | return self.numEntries 88 | end 89 | 90 | 91 | function trans:empty() 92 | return self.numEntries == 0 93 | end 94 | 95 | 96 | function trans:fill_buffer() 97 | assert(self.numEntries >= self.bufferSize) 98 | -- clear CPU buffers 99 | self.buf_ind = 1 100 | local ind 101 | for buf_ind=1,self.bufferSize do 102 | local s, a, r, s2, term = self:sample_one(1) 103 | self.buf_s[buf_ind]:copy(s) 104 | self.buf_a[buf_ind] = a 105 | self.buf_r[buf_ind] = r 106 | self.buf_s2[buf_ind]:copy(s2) 107 | self.buf_term[buf_ind] = term 108 | end 109 | self.buf_s = self.buf_s:float():div(255) 110 | self.buf_s2 = self.buf_s2:float():div(255) 111 | if self.gpu and self.gpu >= 0 then 112 | self.gpu_s:copy(self.buf_s) 113 | self.gpu_s2:copy(self.buf_s2) 114 | end 115 | end 116 | 117 | -- Returns one randomly sampled experience from memory. 118 | -- Consider setting a non-zero probability that we discard termal 119 | -- or non-reward experiences. 120 | function trans:sample_one() 121 | assert(self.numEntries > 1) 122 | local index 123 | local valid = false 124 | 125 | -- Loop to find a valid experience to send back. 126 | while not valid do 127 | 128 | -- Start at 2 because the experience at 1 might not have 129 | -- an action yet? 130 | -- The upper bound ensures there's enough frames to 131 | -- return the trailing hist_len frames. 132 | -- Grab a random experience from the memory table. 133 | index = torch.random(2, self.numEntries-self.recentMemSize) 134 | 135 | 136 | if self.t[index+self.recentMemSize-1] == 0 then 137 | valid = true 138 | end 139 | 140 | -- If nonTermProb is set to less than 1, there is a chance 141 | -- we discard terminal experiences. Not sure why we'd want to do this... 142 | if self.nonTermProb < 1 and self.t[index+self.recentMemSize] == 0 and 143 | torch.uniform() > self.nonTermProb then 144 | -- Note that this is the terminal flag for s_{t+1}. 145 | valid = false 146 | end 147 | 148 | -- If nonEventProb is set to less than one, there is a chance 149 | -- we discard experiences not resulting in rewards. Would this accelerate learning? 150 | if self.nonEventProb < 1 and self.r[index+self.recentMemSize-1] == 0 and 151 | torch.uniform() > self.nonEventProb then 152 | -- probability (1-nonEventProb). 153 | valid = false 154 | end 155 | end 156 | 157 | -- This returns the trailing hist_len frames from the requested index. 
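-- get() also returns the stored action, reward, the following state s2 and
-- its terminal flag, so sample_one() hands back a full (s, a, r, s2, term) tuple.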
158 | return self:get(index) 159 | end 160 | 161 | 162 | function trans:sample(batch_size) 163 | local batch_size = batch_size or 1 164 | assert(batch_size < self.bufferSize) 165 | 166 | if not self.buf_ind or self.buf_ind + batch_size - 1 > self.bufferSize then 167 | self:fill_buffer() 168 | end 169 | 170 | local index = self.buf_ind 171 | 172 | self.buf_ind = self.buf_ind+batch_size 173 | local range = {{index, index+batch_size-1}} 174 | 175 | local buf_s, buf_s2, buf_a, buf_r, buf_term = self.buf_s, self.buf_s2, 176 | self.buf_a, self.buf_r, self.buf_term 177 | if self.gpu and self.gpu >=0 then 178 | buf_s = self.gpu_s 179 | buf_s2 = self.gpu_s2 180 | end 181 | 182 | return buf_s[range], buf_a[range], buf_r[range], buf_s2[range], buf_term[range] 183 | end 184 | 185 | 186 | function trans:concatFrames(index, use_recent) 187 | 188 | -- Should we use the recent state tables? 189 | if use_recent then 190 | s, t = self.recent_s, self.recent_t 191 | else 192 | s, t = self.s, self.t 193 | end 194 | 195 | local fullstate = s[1].new() 196 | fullstate:resize(self.histLen, unpack(s[1]:size():totable())) 197 | 198 | local zero_out = false 199 | local episode_start = self.histLen 200 | 201 | -- Zero out any frames occuring after a terminal frame. 202 | for i=self.histLen-1,1,-1 do 203 | if not zero_out then 204 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 205 | if t[j] == 1 then 206 | zero_out = true 207 | break 208 | end 209 | end 210 | end 211 | 212 | if zero_out then 213 | fullstate[i]:zero() 214 | else 215 | episode_start = i 216 | end 217 | end 218 | 219 | -- If there are no zero frames, copy the hist_len most recent frames. 220 | if self.zeroFrames == 0 then 221 | episode_start = 1 222 | end 223 | 224 | -- Copy frames from the current episode. 225 | for i=episode_start,self.histLen do 226 | fullstate[i]:copy(s[index+self.histIndices[i]-1]) 227 | end 228 | 229 | return fullstate 230 | end 231 | 232 | 233 | -- Get the hist_len most recent frames. 234 | function trans:get_recent() 235 | -- Assumes that the most recent state has been added, but the action has not 236 | return self:concatFrames(1, true):float():div(255) 237 | end 238 | 239 | -- Return a full state in a given index: (s, a, r, s2, terminal). 240 | function trans:get(index) 241 | local s = self:concatFrames(index) 242 | local s2 = self:concatFrames(index+1) 243 | local ar_index = index+self.recentMemSize-1 244 | 245 | return s, self.a[ar_index], self.r[ar_index], s2, self.t[ar_index+1] 246 | end 247 | 248 | -- Add a new experience. 249 | function trans:add(s, a, r, term) 250 | assert(s, 'State cannot be nil') 251 | assert(a, 'Action cannot be nil') 252 | assert(r, 'Reward cannot be nil') 253 | 254 | -- Incremement the memory counter. 255 | if self.numEntries < self.maxSize then 256 | self.numEntries = self.numEntries + 1 257 | 258 | -- Spam the console. 259 | if self.numEntries % 1000 == 0 then 260 | print("Recorded experiences: " .. 
self.numEntries) 261 | end 262 | if self.numEntries == self.maxSize then 263 | print("Filled up the experience record...") 264 | end 265 | end 266 | 267 | -- Always insert at next index, then wrap around 268 | self.insertIndex = self.insertIndex + 1 269 | 270 | -- Overwrite oldest experience once at capacity 271 | if self.insertIndex > self.maxSize then 272 | self.insertIndex = 1 273 | end 274 | 275 | -- Overwrite (s,a,r,t) at insertIndex 276 | self.s[self.insertIndex] = s:clone():float():mul(255) 277 | self.a[self.insertIndex] = a 278 | self.r[self.insertIndex] = r 279 | 280 | -- Record whether this was a terminal experience. 281 | if term then 282 | self.t[self.insertIndex] = 1 283 | else 284 | self.t[self.insertIndex] = 0 285 | end 286 | end 287 | 288 | 289 | function trans:add_recent_state(s, term) 290 | 291 | -- Process... 292 | local s = s:clone():float():mul(255):byte() 293 | 294 | -- Handle no recent states. 295 | if #self.recent_s == 0 then 296 | for i=1,self.recentMemSize do 297 | table.insert(self.recent_s, s:clone():zero()) 298 | table.insert(self.recent_t, 1) 299 | end 300 | end 301 | 302 | -- Record the state in the recent_s list. 303 | table.insert(self.recent_s, s) 304 | 305 | -- Record whether this was a terminal state in the recent_t list. 306 | if term then 307 | table.insert(self.recent_t, 1) 308 | else 309 | table.insert(self.recent_t, 0) 310 | end 311 | 312 | -- Kick out old recent states. 313 | -- recentMemSize is equal to the hist_len * hist_spacing 314 | if #self.recent_s > self.recentMemSize then 315 | table.remove(self.recent_s, 1) 316 | table.remove(self.recent_t, 1) 317 | end 318 | end 319 | 320 | 321 | function trans:add_recent_action(a) 322 | if #self.recent_a == 0 then 323 | for i=1,self.recentMemSize do 324 | table.insert(self.recent_a, 1) 325 | end 326 | end 327 | 328 | table.insert(self.recent_a, a) 329 | 330 | -- Keep recentMemSize steps. 331 | if #self.recent_a > self.recentMemSize then 332 | table.remove(self.recent_a, 1) 333 | end 334 | end 335 | 336 | 337 | --[[ 338 | Override the write function to serialize this class into a file. 339 | We do not want to store anything into the file, just the necessary info 340 | to create an empty transition table. 341 | 342 | @param file (FILE object ) @see torch.DiskFile 343 | --]] 344 | function trans:write(file) 345 | file:writeObject({self.stateDim, 346 | self.numActions, 347 | self.histLen, 348 | self.maxSize, 349 | self.bufferSize, 350 | self.numEntries, 351 | self.insertIndex, 352 | self.recentMemSize, 353 | self.histIndices}) 354 | end 355 | 356 | 357 | --[[ 358 | Override the read function to desearialize this class from file. 359 | Recreates an empty table. 
360 | 361 | @param file (FILE object ) @see torch.DiskFile 362 | --]] 363 | function trans:read(file) 364 | local stateDim, numActions, histLen, maxSize, bufferSize, numEntries, insertIndex, recentMemSize, histIndices = unpack(file:readObject()) 365 | self.stateDim = stateDim 366 | self.numActions = numActions 367 | self.histLen = histLen 368 | self.maxSize = maxSize 369 | self.bufferSize = bufferSize 370 | self.recentMemSize = recentMemSize 371 | self.histIndices = histIndices 372 | self.numEntries = 0 373 | self.insertIndex = 0 374 | 375 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 376 | self.a = torch.LongTensor(self.maxSize):fill(0) 377 | self.r = torch.zeros(self.maxSize) 378 | self.t = torch.ByteTensor(self.maxSize):fill(0) 379 | self.action_encodings = torch.eye(self.numActions) 380 | 381 | -- Tables for storing the last histLen states. They are used for 382 | -- constructing the most recent agent state more easily. 383 | self.recent_s = {} 384 | self.recent_a = {} 385 | self.recent_t = {} 386 | 387 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 388 | self.buf_r = torch.zeros(self.bufferSize) 389 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 390 | self.buf_s = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 391 | self.buf_s2 = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 392 | 393 | if self.gpu and self.gpu >= 0 then 394 | self.gpu_s = self.buf_s:float():cuda() 395 | self.gpu_s2 = self.buf_s2:float():cuda() 396 | end 397 | end 398 | -------------------------------------------------------------------------------- /dqn/convnet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "initenv" 8 | 9 | function create_network(args) 10 | 11 | local net = nn.Sequential() 12 | net:add(nn.Reshape(unpack(args.input_dims))) 13 | 14 | --- first convolutional layer 15 | local convLayer = nn.SpatialConvolution 16 | 17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1], 18 | args.filter_size[1], args.filter_size[1], 19 | args.filter_stride[1], args.filter_stride[1],1)) 20 | net:add(args.nl()) 21 | 22 | -- Add convolutional layers 23 | for i=1,(#args.n_units-1) do 24 | -- second convolutional layer 25 | net:add(convLayer(args.n_units[i], args.n_units[i+1], 26 | args.filter_size[i+1], args.filter_size[i+1], 27 | args.filter_stride[i+1], args.filter_stride[i+1])) 28 | net:add(args.nl()) 29 | end 30 | 31 | local nel 32 | if args.gpu >= 0 then 33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims)) 34 | :cuda()):nElement() 35 | else 36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement() 37 | end 38 | 39 | -- reshape all feature planes into a vector per example 40 | net:add(nn.Reshape(nel)) 41 | 42 | -- fully connected layer 43 | net:add(nn.Linear(nel, args.n_hid[1])) 44 | net:add(args.nl()) 45 | local last_layer_size = args.n_hid[1] 46 | 47 | for i=1,(#args.n_hid-1) do 48 | -- add Linear layer 49 | last_layer_size = args.n_hid[i+1] 50 | net:add(nn.Linear(args.n_hid[i], last_layer_size)) 51 | net:add(args.nl()) 52 | end 53 | 54 | -- add the last fully connected layer (to actions) 55 | net:add(nn.Linear(last_layer_size, args.n_actions)) 56 | 57 | if args.gpu >=0 then 58 | net:cuda() 59 | end 60 | if args.verbose >= 2 then 61 | print(net) 62 | print('Convolutional layers flattened output size:', nel) 63 | end 64 | return net 65 | end 66 | -------------------------------------------------------------------------------- /dqn/convnet_atari3.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | 16 | return create_network(args) 17 | end 18 | 19 | -------------------------------------------------------------------------------- /dqn/convnet_nes.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'convnet' 3 | 4 | return function(args) 5 | args.n_units = {32, 64, 128} 6 | args.filter_size = {8, 4, 3} 7 | args.filter_stride = {4, 2, 1} 8 | args.n_hid = {512} 9 | args.nl = nn.Rectifier 10 | 11 | return create_network(args) 12 | end 13 | 14 | -------------------------------------------------------------------------------- /dqn/initenv.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | dqn = {} 7 | 8 | require 'torch' 9 | require 'nn' 10 | require 'nngraph' 11 | require 'nnutils' 12 | require 'image' 13 | require 'Scale' 14 | require 'NeuralQLearner' 15 | require 'TransitionTable' 16 | require 'Rectifier' 17 | 18 | 19 | function torchSetup(_opt) 20 | _opt = _opt or {} 21 | local opt = table.copy(_opt) 22 | assert(opt) 23 | 24 | -- preprocess options: 25 | --- convert options strings to tables 26 | if opt.pool_frms then 27 | opt.pool_frms = str_to_table(opt.pool_frms) 28 | end 29 | if opt.env_params then 30 | opt.env_params = str_to_table(opt.env_params) 31 | end 32 | if opt.agent_params then 33 | opt.agent_params = str_to_table(opt.agent_params) 34 | opt.agent_params.gpu = opt.gpu 35 | opt.agent_params.best = opt.best 36 | opt.agent_params.verbose = opt.verbose 37 | if opt.network ~= '' then 38 | opt.agent_params.network = opt.network 39 | end 40 | end 41 | 42 | --- general setup 43 | opt.tensorType = opt.tensorType or 'torch.FloatTensor' 44 | torch.setdefaulttensortype(opt.tensorType) 45 | if not opt.threads then 46 | opt.threads = 4 47 | end 48 | torch.setnumthreads(opt.threads) 49 | if not opt.verbose then 50 | opt.verbose = 10 51 | end 52 | if opt.verbose >= 1 then 53 | print('Torch Threads:', torch.getnumthreads()) 54 | end 55 | 56 | --- set gpu device 57 | if opt.gpu and opt.gpu >= 0 then 58 | require 'cutorch' 59 | require 'cunn' 60 | if opt.gpu == 0 then 61 | local gpu_id = tonumber(os.getenv('GPU_ID')) 62 | if gpu_id then opt.gpu = gpu_id+1 end 63 | end 64 | if opt.gpu > 0 then cutorch.setDevice(opt.gpu) end 65 | opt.gpu = cutorch.getDevice() 66 | print('Using GPU device id:', opt.gpu-1) 67 | else 68 | opt.gpu = -1 69 | if opt.verbose >= 1 then 70 | print('Using CPU code only. GPU device id:', opt.gpu) 71 | end 72 | end 73 | 74 | --- set up random number generators 75 | -- removing lua RNG; seeding torch RNG with opt.seed and setting cutorch 76 | -- RNG seed to the first uniform random int32 from the previous RNG; 77 | -- this is preferred because using the same seed for both generators 78 | -- may introduce correlations; we assume that both torch RNGs ensure 79 | -- adequate dispersion for different seeds. 
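-- Concretely: seed the torch generator with opt.seed, then seed cutorch with
-- one integer drawn from that (already seeded) torch generator.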
80 | math.random = nil 81 | opt.seed = opt.seed or 1 82 | torch.manualSeed(opt.seed) 83 | if opt.verbose >= 1 then 84 | print('Torch Seed:', torch.initialSeed()) 85 | end 86 | local firstRandInt = torch.random() 87 | if opt.gpu >= 0 then 88 | cutorch.manualSeed(firstRandInt) 89 | if opt.verbose >= 1 then 90 | print('CUTorch Seed:', cutorch.initialSeed()) 91 | end 92 | end 93 | 94 | return opt 95 | end 96 | 97 | 98 | function setup(_opt) 99 | assert(_opt) 100 | 101 | --preprocess options: 102 | --- convert options strings to tables 103 | _opt.pool_frms = str_to_table(_opt.pool_frms) 104 | _opt.env_params = str_to_table(_opt.env_params) 105 | _opt.agent_params = str_to_table(_opt.agent_params) 106 | if _opt.agent_params.transition_params then 107 | _opt.agent_params.transition_params = 108 | str_to_table(_opt.agent_params.transition_params) 109 | end 110 | 111 | --- first things first 112 | local opt = torchSetup(_opt) 113 | 114 | -- load training framework and environment 115 | local framework = require(opt.framework) 116 | assert(framework) 117 | 118 | local gameEnv = framework.GameEnvironment(opt) 119 | local gameActions = gameEnv:getActions() 120 | 121 | -- agent options 122 | _opt.agent_params.actions = gameActions 123 | _opt.agent_params.gpu = _opt.gpu 124 | _opt.agent_params.best = _opt.best 125 | if _opt.network ~= '' then 126 | _opt.agent_params.network = _opt.network 127 | end 128 | _opt.agent_params.verbose = _opt.verbose 129 | if not _opt.agent_params.state_dim then 130 | _opt.agent_params.state_dim = gameEnv:nObsFeature() 131 | end 132 | 133 | local agent = dqn[_opt.agent](_opt.agent_params) 134 | 135 | if opt.verbose >= 1 then 136 | print('Set up Torch using these options:') 137 | for k, v in pairs(opt) do 138 | print(k, v) 139 | end 140 | end 141 | 142 | return gameEnv, gameActions, agent, opt 143 | end 144 | 145 | 146 | 147 | --- other functions 148 | 149 | function str_to_table(str) 150 | if type(str) == 'table' then 151 | return str 152 | end 153 | if not str or type(str) ~= 'string' then 154 | if type(str) == 'table' then 155 | return str 156 | end 157 | return {} 158 | end 159 | local ttr 160 | if str ~= '' then 161 | local ttx=tt 162 | loadstring('tt = {' .. str .. '}')() 163 | ttr = tt 164 | tt = ttx 165 | else 166 | ttr = {} 167 | end 168 | return ttr 169 | end 170 | 171 | function table.copy(t) 172 | if t == nil then return nil end 173 | local nt = {} 174 | for k, v in pairs(t) do 175 | if type(v) == 'table' then 176 | nt[k] = table.copy(v) 177 | else 178 | nt[k] = v 179 | end 180 | end 181 | setmetatable(nt, table.copy(getmetatable(t))) 182 | return nt 183 | end 184 | -------------------------------------------------------------------------------- /dqn/net_downsample_2x_full_y.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "image" 8 | require "Scale" 9 | 10 | local function create_network(args) 11 | -- Y (luminance) 12 | return nn.Scale(84, 84, true) 13 | end 14 | 15 | return create_network 16 | -------------------------------------------------------------------------------- /dqn/nnutils.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "torch" 8 | 9 | function recursive_map(module, field, func) 10 | local str = "" 11 | if module[field] or module.modules then 12 | str = str .. torch.typename(module) .. ": " 13 | end 14 | if module[field] then 15 | str = str .. func(module[field]) 16 | end 17 | if module.modules then 18 | str = str .. "[" 19 | for i, submodule in ipairs(module.modules) do 20 | local submodule_str = recursive_map(submodule, field, func) 21 | str = str .. submodule_str 22 | if i < #module.modules and string.len(submodule_str) > 0 then 23 | str = str .. " " 24 | end 25 | end 26 | str = str .. "]" 27 | end 28 | 29 | return str 30 | end 31 | 32 | function abs_mean(w) 33 | return torch.mean(torch.abs(w:clone():float())) 34 | end 35 | 36 | function abs_max(w) 37 | return torch.abs(w:clone():float()):max() 38 | end 39 | 40 | -- Build a string of average absolute weight values for the modules in the 41 | -- given network. 42 | function get_weight_norms(module) 43 | return "Average absolute value weights:\n" .. recursive_map(module, "weight", abs_mean) .. 44 | "\nWeight max:\n" .. recursive_map(module, "weight", abs_max) 45 | end 46 | 47 | -- Build a string of average absolute weight gradient values for the modules 48 | -- in the given network. 49 | function get_grad_norms(module) 50 | return "Average absolute value weight gradients:\n" .. 51 | recursive_map(module, "gradWeight", abs_mean) .. 52 | "\nWeight grad max:\n" .. recursive_map(module, "gradWeight", abs_max) 53 | end 54 | -------------------------------------------------------------------------------- /dqn/test_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | See LICENSE file for full terms of limited license. 4 | ]] 5 | 6 | gd = require "gd" 7 | 8 | if not dqn then 9 | require "initenv" 10 | end 11 | 12 | local cmd = torch.CmdLine() 13 | cmd:text() 14 | cmd:text('Train Agent in Environment:') 15 | cmd:text() 16 | cmd:text('Options:') 17 | 18 | cmd:option('-framework', '', 'name of training framework') 19 | cmd:option('-env', '', 'name of environment to use') 20 | cmd:option('-game_path', '', 'path to environment file (ROM)') 21 | cmd:option('-env_params', '', 'string of environment parameters') 22 | cmd:option('-pool_frms', '', 23 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 24 | cmd:option('-actrep', 1, 'how many times to repeat action') 25 | cmd:option('-gameOverPenalty', 0, 'penalty for the game ending') 26 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 27 | 'number of times at the start of each training episode') 28 | 29 | cmd:option('-name', '', 'filename used for saving network and training history') 30 | cmd:option('-network', '', 'reload pretrained network') 31 | cmd:option('-agent', '', 'name of agent file to use') 32 | cmd:option('-agent_params', '', 'string of agent parameters') 33 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments') 34 | 35 | cmd:option('-verbose', 2, 36 | 'the higher the level, the more information is printed to screen') 37 | cmd:option('-threads', 1, 'number of BLAS threads') 38 | cmd:option('-gpu', -1, 'gpu flag') 39 | cmd:option('-gif_file', '', 'GIF path to write session screens') 40 | cmd:option('-csv_file', '', 'CSV path to write session data') 41 | 42 | cmd:text() 43 | 44 | local opt = cmd:parse(arg) 45 | 46 | --- General setup. 
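-- setup() (defined in initenv.lua) converts the option strings into tables,
-- initializes Torch (and CUDA if requested), loads the training framework and
-- its action set, and constructs the DQN agent; it returns the game
-- environment, the action list, the agent and the merged options table.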
47 | local game_env, game_actions, agent, opt = setup(opt) 48 | 49 | -- override print to always flush the output 50 | local old_print = print 51 | local print = function(...) 52 | old_print(...) 53 | io.flush() 54 | end 55 | 56 | -- file names from command line 57 | local gif_filename = opt.gif_file 58 | 59 | -- start a new game 60 | local screen, reward, terminal = game_env:newGame() 61 | 62 | -- compress screen to JPEG with 100% quality 63 | local jpg = image.compressJPG(screen:squeeze(), 100) 64 | -- create gd image from JPEG string 65 | local im = gd.createFromJpegStr(jpg:storage():string()) 66 | -- convert truecolor to palette 67 | im:trueColorToPalette(false, 256) 68 | 69 | -- write GIF header, use global palette and infinite looping 70 | im:gifAnimBegin(gif_filename, true, 0) 71 | -- write first frame 72 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE) 73 | 74 | -- remember the image and show it first 75 | local previm = im 76 | local win = nil 77 | -- local win = image.display({image=screen}) 78 | 79 | print("Started playing...") 80 | 81 | -- play one episode (game) 82 | while not terminal do 83 | -- if action was chosen randomly, Q-value is 0 84 | agent.bestq = 0 85 | 86 | -- choose the best action 87 | local action_index = agent:perceive(reward, screen, terminal, true, 0.05) 88 | 89 | -- play game in test mode (episodes don't end when losing a life) 90 | screen, reward, terminal = game_env:step(game_actions[action_index], false) 91 | 92 | -- display screen 93 | -- image.display({image=screen, win=win}) 94 | 95 | -- create gd image from tensor 96 | jpg = image.compressJPG(screen:squeeze(), 100) 97 | im = gd.createFromJpegStr(jpg:storage():string()) 98 | 99 | -- use palette from previous (first) image 100 | im:trueColorToPalette(false, 256) 101 | im:paletteCopy(previm) 102 | 103 | -- write new GIF frame, no local palette, starting from left-top, 7ms delay 104 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE) 105 | -- remember previous screen for optimal compression 106 | previm = im 107 | 108 | end 109 | 110 | -- end GIF animation and close CSV file 111 | gd.gifAnimEnd(gif_filename) 112 | 113 | print("Finished playing, close window to exit!") 114 | -------------------------------------------------------------------------------- /dqn/train_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | if not dqn then 8 | require "initenv" 9 | end 10 | 11 | local cmd = torch.CmdLine() 12 | cmd:text() 13 | cmd:text('Train Agent in Environment:') 14 | cmd:text() 15 | cmd:text('Options:') 16 | 17 | cmd:option('-framework', '', 'name of training framework') 18 | cmd:option('-env', '', 'name of environment to use') 19 | cmd:option('-game_path', '', 'path to environment file (ROM)') 20 | cmd:option('-env_params', '', 'string of environment parameters') 21 | cmd:option('-pool_frms', '', 22 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 23 | cmd:option('-actrep', 1, 'how many times to repeat action') 24 | cmd:option('-gameOverPenalty', 0, 'penalty for the game ending') 25 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 
26 | 'number of times at the start of each training episode') 27 | 28 | cmd:option('-name', '', 'filename used for saving network and training history') 29 | cmd:option('-network', '', 'reload pretrained network') 30 | cmd:option('-agent', '', 'name of agent file to use') 31 | cmd:option('-agent_params', '', 'string of agent parameters') 32 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments') 33 | cmd:option('-saveNetworkParams', true, 34 | 'saves the agent network in a separate file') 35 | cmd:option('-prog_freq', 5*10^3, 'frequency of progress output') 36 | cmd:option('-save_freq', 5*10^4, 'the model is saved every save_freq steps') 37 | cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation') 38 | cmd:option('-save_versions', 0, '') 39 | 40 | cmd:option('-steps', 10^5, 'number of training steps to perform') 41 | cmd:option('-eval_steps', 10^5, 'number of evaluation steps') 42 | 43 | cmd:option('-verbose', 2, 44 | 'the higher the level, the more information is printed to screen') 45 | cmd:option('-threads', 1, 'number of BLAS threads') 46 | cmd:option('-gpu', -1, 'gpu flag') 47 | 48 | cmd:text() 49 | 50 | local opt = cmd:parse(arg) 51 | 52 | --- General setup. 53 | local game_env, game_actions, agent, opt = setup(opt) 54 | 55 | -- override print to always flush the output 56 | local old_print = print 57 | local print = function(...) 58 | old_print(...) 59 | io.flush() 60 | end 61 | 62 | local learn_start = agent.learn_start 63 | local start_time = sys.clock() 64 | local reward_counts = {} 65 | local episode_counts = {} 66 | local time_history = {} 67 | local v_history = {} 68 | local qmax_history = {} 69 | local td_history = {} 70 | local reward_history = {} 71 | local step = 0 72 | time_history[1] = 0 73 | 74 | local total_reward 75 | local nrewards 76 | local nepisodes 77 | local episode_reward 78 | 79 | -- Take one single initial step to get kicked-off... 80 | local screen, reward, terminal = game_env:getState() 81 | 82 | local last_step_log_time = sys.clock() 83 | local win = nil 84 | while step < opt.steps do 85 | step = step + 1 86 | local action_index = agent:perceive(reward, screen, terminal) 87 | 88 | -- game over? get next game! 89 | if not terminal then 90 | 91 | -- Play the selected action in the emulator. 92 | -- Record the resulting screen, reward, and whether this was terminal. 93 | screen, reward, terminal = game_env:step(game_actions[action_index], true) 94 | 95 | -- Spam the console. 96 | if opt.verbose > 3 and reward ~= 0 then 97 | print("Reward: " .. reward) 98 | end 99 | else 100 | if opt.random_starts > 0 then 101 | screen, reward, terminal = game_env:nextRandomGame() 102 | 103 | -- Spam the console. 104 | if opt.verbose > 3 then 105 | print("New random episode.") 106 | end 107 | else 108 | screen, reward, terminal = game_env:newGame() 109 | 110 | -- Spam the console. 111 | if opt.verbose > 3 then 112 | print("New episode.") 113 | end 114 | end 115 | end 116 | 117 | -- display screen 118 | -- win = image.display({image=screen, win=win}) 119 | 120 | -- Logging... 121 | if step % 10000 == 0 then 122 | local elapsed_step_time = sys.clock() - last_step_log_time 123 | last_step_log_time = sys.clock() 124 | print("Steps: " .. step .. " Time: " .. elapsed_step_time) 125 | end 126 | 127 | if step % opt.prog_freq == 0 then 128 | assert(step==agent.numSteps, 'trainer step: ' .. step .. 129 | ' & agent.numSteps: ' .. 
agent.numSteps) 130 | print("Steps: ", step) 131 | print("Epsilon: ", agent.ep) 132 | agent:report() 133 | 134 | -- Save the hist_len most recent frames. 135 | if opt.verbose > 3 then 136 | agent:printRecent() 137 | end 138 | collectgarbage() 139 | end 140 | 141 | if step%1000 == 0 then collectgarbage() end 142 | 143 | if step % opt.eval_freq == 0 and step > learn_start then 144 | 145 | print("***********") 146 | print("Starting evaluation!") 147 | print("***********") 148 | 149 | screen, reward, terminal = game_env:newGame() 150 | 151 | total_reward = 0 152 | nrewards = 0 153 | nepisodes = 0 154 | episode_reward = 0 155 | 156 | local eval_time = sys.clock() 157 | for estep=1,opt.eval_steps do 158 | local action_index = agent:perceive(reward, screen, terminal, true, 0.05) 159 | 160 | -- Play game in test mode (episodes don't end when losing a life) 161 | screen, reward, terminal = game_env:step(game_actions[action_index]) 162 | 163 | -- display screen 164 | -- This seems to cause crashes :\ 165 | -- win = image.display({image=screen, win=win}) 166 | 167 | if estep%1000 == 0 then collectgarbage() end 168 | 169 | -- record every reward 170 | episode_reward = episode_reward + reward 171 | if reward ~= 0 then 172 | nrewards = nrewards + 1 173 | end 174 | 175 | if opt.verbose > 3 and reward ~= 0 then 176 | print("Episode Reward: " .. episode_reward) 177 | print ("Number of Rewards: " .. nrewards) 178 | end 179 | 180 | if terminal then 181 | total_reward = total_reward + episode_reward 182 | episode_reward = 0 183 | nepisodes = nepisodes + 1 184 | 185 | if opt.verbose > 3 then 186 | print("Total Reward: " .. total_reward) 187 | end 188 | if opt.random_starts > 0 then 189 | screen, reward, terminal = game_env:nextRandomGame() 190 | else 191 | screen, reward, terminal = game_env:newGame() 192 | end 193 | end 194 | end 195 | 196 | eval_time = sys.clock() - eval_time 197 | start_time = start_time + eval_time 198 | agent:compute_validation_statistics() 199 | local ind = #reward_history+1 200 | total_reward = total_reward/math.max(1, nepisodes) 201 | 202 | if #reward_history == 0 or total_reward > torch.Tensor(reward_history):max() then 203 | agent.best_network = agent.network:clone() 204 | end 205 | 206 | if agent.v_avg then 207 | v_history[ind] = agent.v_avg 208 | td_history[ind] = agent.tderr_avg 209 | qmax_history[ind] = agent.q_max 210 | end 211 | print("V", v_history[ind], "TD error", td_history[ind], "Qmax", qmax_history[ind]) 212 | 213 | reward_history[ind] = total_reward 214 | reward_counts[ind] = nrewards 215 | episode_counts[ind] = nepisodes 216 | 217 | time_history[ind+1] = sys.clock() - start_time 218 | 219 | local time_dif = time_history[ind+1] - time_history[ind] 220 | 221 | local training_rate = opt.actrep*opt.eval_freq/time_dif 222 | 223 | print(string.format( 224 | '\nSteps: %d (frames: %d), reward: %.2f, epsilon: %.2f, lr: %G, ' .. 225 | 'training time: %ds, training rate: %dfps, testing time: %ds, ' .. 226 | 'testing rate: %dfps, num. ep.: %d, num. 
rewards: %d', 227 | step, step*opt.actrep, total_reward, agent.ep, agent.lr, time_dif, 228 | training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time, 229 | nepisodes, nrewards)) 230 | end 231 | 232 | if step % opt.save_freq == 0 or step == opt.steps then 233 | local s, a, r, s2, term = agent.valid_s, agent.valid_a, agent.valid_r, 234 | agent.valid_s2, agent.valid_term 235 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 236 | agent.valid_term = nil, nil, nil, nil, nil, nil, nil 237 | local w, dw, g, g2, delta, delta2, deltas, tmp = agent.w, agent.dw, 238 | agent.g, agent.g2, agent.delta, agent.delta2, agent.deltas, agent.tmp 239 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 240 | agent.deltas, agent.tmp = nil, nil, nil, nil, nil, nil, nil, nil 241 | 242 | local filename = opt.name 243 | if opt.save_versions > 0 then 244 | filename = filename .. "_" .. math.floor(step / opt.save_versions) 245 | end 246 | filename = filename 247 | torch.save(filename .. ".t7", {agent = agent, 248 | model = agent.network, 249 | best_model = agent.best_network, 250 | reward_history = reward_history, 251 | reward_counts = reward_counts, 252 | episode_counts = episode_counts, 253 | time_history = time_history, 254 | v_history = v_history, 255 | td_history = td_history, 256 | qmax_history = qmax_history, 257 | arguments=opt}) 258 | if opt.saveNetworkParams then 259 | local nets = {network=w:clone():float()} 260 | torch.save(filename..'.params.t7', nets, 'ascii') 261 | end 262 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 263 | agent.valid_term = s, a, r, s2, term 264 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 265 | agent.deltas, agent.tmp = w, dw, g, g2, delta, delta2, deltas, tmp 266 | print("***********") 267 | print('Saved:', filename .. 
'.t7') 268 | print("***********") 269 | io.flush() 270 | collectgarbage() 271 | end 272 | end 273 | -------------------------------------------------------------------------------- /install_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ###################################################################### 4 | # Torch install 5 | ###################################################################### 6 | 7 | 8 | TOPDIR=$PWD 9 | 10 | # Prefix: 11 | PREFIX=$PWD/torch 12 | echo "Installing Torch into: $PREFIX" 13 | 14 | if [[ `uname` != 'Linux' ]]; then 15 | echo 'Platform unsupported, only available for Linux' 16 | exit 17 | fi 18 | if [[ `which apt-get` == '' ]]; then 19 | echo 'apt-get not found, platform not supported' 20 | exit 21 | fi 22 | 23 | # Install dependencies for Torch: 24 | sudo apt-get update 25 | sudo apt-get install -qqy build-essential 26 | sudo apt-get install -qqy gcc g++ 27 | sudo apt-get install -qqy cmake 28 | sudo apt-get install -qqy curl 29 | sudo apt-get install -qqy libreadline-dev 30 | sudo apt-get install -qqy git-core 31 | sudo apt-get install -qqy libjpeg-dev 32 | sudo apt-get install -qqy libpng-dev 33 | sudo apt-get install -qqy ncurses-dev 34 | sudo apt-get install -qqy imagemagick 35 | sudo apt-get install -qqy unzip 36 | sudo apt-get install -qqy libqt4-dev 37 | sudo apt-get install -qqy liblua5.1-0-dev 38 | sudo apt-get install -qqy libgd-dev 39 | sudo apt-get install -qqy scons 40 | sudo apt-get install -qqy libgtk2.0-dev 41 | sudo apt-get install -qqy libsdl-dev 42 | sudo apt-get update 43 | 44 | 45 | echo "==> Torch7's dependencies have been installed" 46 | 47 | 48 | 49 | 50 | 51 | # Build and install Torch7 52 | cd /tmp 53 | rm -rf luajit-rocks 54 | git clone https://github.com/torch/luajit-rocks.git 55 | cd luajit-rocks 56 | mkdir -p build 57 | cd build 58 | git checkout master; git pull 59 | rm -f CMakeCache.txt 60 | cmake .. -DCMAKE_INSTALL_PREFIX=$PREFIX -DCMAKE_BUILD_TYPE=Release 61 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 62 | make 63 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 64 | make install 65 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 66 | 67 | 68 | path_to_nvcc=$(which nvcc) 69 | if [ -x "$path_to_nvcc" ] 70 | then 71 | cutorch=ok 72 | cunn=ok 73 | fi 74 | 75 | # Install base packages: 76 | $PREFIX/bin/luarocks install cwrap 77 | $PREFIX/bin/luarocks install paths 78 | $PREFIX/bin/luarocks install torch 79 | $PREFIX/bin/luarocks install nn 80 | 81 | [ -n "$cutorch" ] && \ 82 | ($PREFIX/bin/luarocks install cutorch) 83 | [ -n "$cunn" ] && \ 84 | ($PREFIX/bin/luarocks install cunn) 85 | 86 | $PREFIX/bin/luarocks install luafilesystem 87 | $PREFIX/bin/luarocks install penlight 88 | $PREFIX/bin/luarocks install sys 89 | $PREFIX/bin/luarocks install xlua 90 | $PREFIX/bin/luarocks install image 91 | $PREFIX/bin/luarocks install env 92 | $PREFIX/bin/luarocks install qtlua 93 | $PREFIX/bin/luarocks install qttorch 94 | 95 | echo "" 96 | echo "=> Torch7 has been installed successfully" 97 | echo "" 98 | 99 | 100 | echo "Installing nngraph ... " 101 | $PREFIX/bin/luarocks install nngraph 102 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 103 | echo "nngraph installation completed" 104 | 105 | echo "Installing FCEUX_Learning_Environment ... 
" 106 | cd /tmp 107 | rm -rf FCEUX_Learning_Environment 108 | git clone https://github.com/ehrenbrav/FCEUX_Learning_Environment.git 109 | cd FCEUX_Learning_Environment 110 | $PREFIX/bin/luarocks make 111 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 112 | echo "FCEUX installation completed" 113 | 114 | echo "Installing neswrap ... " 115 | cd /tmp 116 | rm -rf neswrap 117 | git clone https://github.com/ehrenbrav/neswrap.git 118 | cd neswrap 119 | $PREFIX/bin/luarocks make 120 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 121 | echo "neswrap installation completed" 122 | 123 | echo "Installing Lua-GD ... " 124 | mkdir $PREFIX/src 125 | cd $PREFIX/src 126 | rm -rf lua-gd 127 | git clone https://github.com/ittner/lua-gd.git 128 | cd lua-gd 129 | sed -i "s/LUABIN=lua5.1/LUABIN=..\/..\/bin\/luajit/" Makefile 130 | $PREFIX/bin/luarocks make 131 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 132 | echo "Lua-GD installation completed" 133 | 134 | #echo "Installing GPU dependencies..." 135 | #$PREFIX/bin/luarocks install cutorch 136 | #$PREFIX/bin/luarocks install cunn 137 | #echo "Done trying to install the GPU dependencies." 138 | 139 | echo 140 | echo "All done!" 141 | 142 | -------------------------------------------------------------------------------- /logs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /roms/README: -------------------------------------------------------------------------------- 1 | Rom files should be put in this directory 2 | -------------------------------------------------------------------------------- /roms/breakout.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ehrenbrav/DeepQNetwork/98c8d2cf0858bf575322c871bd693f7860397270/roms/breakout.bin -------------------------------------------------------------------------------- /saves/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /test_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0 5 | fi 6 | 7 | if [ -z "$2" ] 8 | then echo "Please provide the pretrained network file, e.g. ./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0 9 | fi 10 | 11 | ENV=$1 12 | NETWORK=$2 13 | FRAMEWORK="neswrap" 14 | 15 | game_path=$PWD"/roms/" 16 | env_params="useRGB=true" 17 | agent="NeuralQLearner" 18 | n_replay=4 19 | netfile="\"convnet_nes\"" 20 | update_freq=4 21 | actrep=8 22 | discount=0.99 23 | seed=1 24 | learn_start=5000 25 | pool_frms_type="\"max\"" 26 | pool_frms_size=1 27 | initial_priority="false" 28 | replay_memory=100 # This doesn't matter for testing... 
29 | eps_end=0.1 30 | eps_endt=500000 31 | lr=0.01 32 | agent_type="DQN3_0_1" 33 | preproc_net="\"net_downsample_2x_full_y\"" 34 | agent_name=$agent_type"_"$1"_FULL_Y" 35 | state_dim=7056 36 | ncols=1 37 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=256,rescale_r=1,ncols="$ncols",bufferSize=1024,valid_size=1000,target_q=10000,clip_delta=1,min_reward=-10000,max_reward=10000" 38 | gif_file="../gifs/$ENV.gif" 39 | gpu=-1 40 | random_starts=0 41 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 42 | num_threads=8 43 | 44 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file" 45 | echo $args 46 | 47 | cd dqn 48 | ../torch/bin/qlua test_agent.lua $args 49 | -------------------------------------------------------------------------------- /test_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0 5 | fi 6 | 7 | if [ -z "$2" ] 8 | then echo "Please provide the pretrained network file, e.g. ./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0 9 | fi 10 | 11 | ENV=$1 12 | NETWORK=$2 13 | FRAMEWORK="neswrap" 14 | 15 | game_path=$PWD"/roms/" 16 | env_params="useRGB=true" 17 | agent="NeuralQLearner" 18 | n_replay=4 19 | netfile="\"convnet_nes\"" 20 | update_freq=4 21 | actrep=8 22 | discount=0.99 23 | seed=1 24 | learn_start=5000 25 | pool_frms_type="\"max\"" 26 | pool_frms_size=1 27 | initial_priority="false" 28 | replay_memory=100 # This doesn't matter for testing... 29 | eps_end=0.1 30 | eps_endt=500000 31 | lr=0.01 32 | agent_type="DQN3_0_1" 33 | preproc_net="\"net_downsample_2x_full_y\"" 34 | agent_name=$agent_type"_"$1"_FULL_Y" 35 | state_dim=7056 36 | ncols=1 37 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=256,rescale_r=1,ncols="$ncols",bufferSize=1024,valid_size=1000,target_q=10000,clip_delta=1,min_reward=-10000,max_reward=10000" 38 | gif_file="../gifs/$ENV.gif" 39 | gpu=0 40 | random_starts=0 41 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 42 | num_threads=8 43 | 44 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file" 45 | echo $args 46 | 47 | cd dqn 48 | ../torch/bin/qlua test_agent.lua $args 49 | -------------------------------------------------------------------------------- /train_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. 
./run_cpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | 8 | # FRAMEWORK OPTIONS 9 | FRAMEWORK="neswrap" # Wrapper for the FCEUX Nintendo emulator. 10 | game_path=$PWD"/roms/" 11 | env_params="useRGB=true" 12 | steps=5000000 # Total steps to run the model. 50M for Atari. 13 | save_freq=100000 # Save every save_freq steps. Save early and often! 125k for Atari. 14 | 15 | # PREPROCESSOR OPTIONS 16 | preproc_net="\"net_downsample_2x_full_y\"" 17 | pool_frms_type="\"max\"" 18 | pool_frms_size=1 # Changed from 2 for Atari, since we don't have the same limitations for NES. 19 | initial_priority="false" 20 | state_dim=7056 # The number of pixels in the screen. 21 | ncols=1 # Represents just the Y (ie - grayscale) channel. 22 | 23 | # AGENT OPTIONS 24 | agent="NeuralQLearner" 25 | agent_type="DQN3_0_1" 26 | agent_name=$agent_type"_"$1"_FULL_Y" 27 | actrep=8 # Number of times an action is repeated (and a screen returned). 4 for Atari... 28 | ep=1 # The probability of choosing a random action rather than the best predicted action. 29 | eps_end=0.01 # What epsilon ends up as going forward. 30 | eps_endt=1000000 # This probability decreases over time, presumably as we get better. 31 | max_reward=10000 # Rewards are clipped to this value. 32 | min_reward=-10000 # Ditto. 33 | rescale_r=1 # Rescale rewards to [-1, 1] 34 | gameOverPenalty=1 # Gives a negative reward upon dying. 35 | 36 | # LEARNING OPTIONS 37 | lr=0.00025 # .00025 for Atari. 38 | learn_start=50000 # Only start learning after this many steps. Should be bigger than bufferSize. Was set to 50k for Atari. 39 | replay_memory=1000000 # Set small to speed up debugging. 1M is the Atari setting... Big memory object! 40 | n_replay=4 # Minibatches to learn from each learning step. 41 | nonEventProb=nil # Probability of selecting a non-reward-bearing experience. 42 | clip_delta=1 # Limit the delta to +/- 1. 43 | 44 | # Q NETWORK OPTIONS 45 | netfile="\"convnet_nes\"" 46 | target_q=30000 # Steps to replace target network with the updated one. Atari: 10k. DoubleDQN: 30k 47 | update_freq=4 # How often do we update the Q network? 48 | hist_len=4 # Number of trailing frames to input into the Q network. 4 for Atari... 49 | discount=0.99 # Discount rate given to future rewards. 50 | 51 | # VALIDATION AND EVALUATION 52 | eval_freq=50000 # Evaluate the model every eval_freq steps by calculating the score per episode for a few games. 250k for Atari. 53 | eval_steps=10000 # How many steps does an evaluation last? 125k for Atari. 54 | prog_freq=50000 # How often do you want a progress report? 55 | 56 | # PERFORMANCE AND DEBUG OPTIONS 57 | gpu=-1 # Zero means "use the GPU" which is a bit confusing... -1 for CPU. 58 | num_threads=8 59 | verbose=3 # 2 is default. 3 turns on debugging messages about what the model is doing. 60 | random_starts=0 # How many NOOPs to perform at the start of a game (random number up to this value). Shouldn't matter for SMB? 
61 | seed=1 62 | #saved_network="" 63 | 64 | # THE UGLY UNDERBELLY 65 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 66 | 67 | agent_params="lr="$lr",ep="$ep",ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len="$hist_len",learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,ncols="$ncols",bufferSize=1024,valid_size=1000,target_q="$target_q",clip_delta="$clip_delta"",min_reward="$min_reward",max_reward="$max_reward",rescale_r="$rescale_r",nonEventProb="$nonEventProb" 68 | 69 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -verbose $verbose -gameOverPenalty $gameOverPenalty" #-network $saved_network" 70 | 71 | # Copy stdout and stderr to a logfile. 72 | LOGFILE="logs/dqn_log_`/bin/date +\"%F:%R\"`" 73 | exec > >(tee -i ${LOGFILE}) 74 | exec 2>&1 75 | 76 | echo $args 77 | 78 | cd dqn 79 | ../torch/bin/qlua train_agent.lua $args 80 | -------------------------------------------------------------------------------- /train_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./run_cpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | 8 | # FRAMEWORK OPTIONS 9 | FRAMEWORK="neswrap" # Wrapper for the FCEUX Nintendo emulator. 10 | game_path=$PWD"/roms/" 11 | env_params="useRGB=true" 12 | steps=5000000 # Total steps to run the model. 50M for Atari. 13 | save_freq=100000 # Save every save_freq steps. Save early and often! 125k for Atari. 14 | 15 | # PREPROCESSOR OPTIONS 16 | preproc_net="\"net_downsample_2x_full_y\"" 17 | pool_frms_type="\"max\"" 18 | pool_frms_size=1 # Changed from 2 for Atari, since we don't have the same limitations for NES. 19 | initial_priority="false" 20 | state_dim=7056 # The number of pixels in the screen. 21 | ncols=1 # Represents just the Y (ie - grayscale) channel. 22 | 23 | # AGENT OPTIONS 24 | agent="NeuralQLearner" 25 | agent_type="DQN3_0_1" 26 | agent_name=$agent_type"_"$1"_FULL_Y" 27 | actrep=8 # Number of times an action is repeated (and a screen returned). 4 for Atari... 28 | ep=1 # The probability of choosing a random action rather than the best predicted action. 29 | eps_end=0.01 # What epsilon ends up as going forward. 30 | eps_endt=1000000 # This probability decreases over time, presumably as we get better. 31 | max_reward=10000 # Rewards are clipped to this value. 32 | min_reward=-10000 # Ditto. 33 | rescale_r=1 # Rescale rewards to [-1, 1] 34 | gameOverPenalty=1 # Gives a negative reward upon dying. 35 | 36 | # LEARNING OPTIONS 37 | lr=0.00025 # .00025 for Atari. 38 | learn_start=50000 # Only start learning after this many steps. Should be bigger than bufferSize. Was set to 50k for Atari. 39 | replay_memory=1000000 # Set small to speed up debugging. 1M is the Atari setting... Big memory object! 40 | n_replay=4 # Minibatches to learn from each learning step. 41 | nonEventProb=nil # Probability of selecting a non-reward-bearing experience. 42 | clip_delta=1 # Limit the delta to +/- 1. 
43 | 44 | # Q NETWORK OPTIONS 45 | netfile="\"convnet_nes\"" 46 | target_q=30000 # Steps to replace target network with the updated one. Atari: 10k. DoubleDQN: 30k 47 | update_freq=4 # How often do we update the Q network? 48 | hist_len=4 # Number of trailing frames to input into the Q network. 4 for Atari... 49 | discount=0.99 # Discount rate given to future rewards. 50 | 51 | # VALIDATION AND EVALUATION 52 | eval_freq=50000 # Evaluate the model every eval_freq steps by calculating the score per episode for a few games. 250k for Atari. 53 | eval_steps=10000 # How many steps does an evaluation last? 125k for Atari. 54 | prog_freq=50000 # How often do you want a progress report? 55 | 56 | # PERFORMANCE AND DEBUG OPTIONS 57 | gpu=0 # Zero means "use the GPU" which is a bit confusing... -1 for CPU. 58 | num_threads=8 59 | verbose=3 # 2 is default. 3 turns on debugging messages about what the model is doing. 60 | random_starts=0 # How many NOOPs to perform at the start of a game (random number up to this value). Shouldn't matter for SMB? 61 | seed=1 62 | #saved_network="" 63 | 64 | # THE UGLY UNDERBELLY 65 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 66 | 67 | agent_params="lr="$lr",ep="$ep",ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len="$hist_len",learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,ncols="$ncols",bufferSize=1024,valid_size=1000,target_q="$target_q",clip_delta="$clip_delta"",min_reward="$min_reward",max_reward="$max_reward",rescale_r="$rescale_r",nonEventProb="$nonEventProb" 68 | 69 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -verbose $verbose -gameOverPenalty $gameOverPenalty" #-network $saved_network" 70 | 71 | # Copy stdout and stderr to a logfile. 72 | LOGFILE="logs/dqn_log_`/bin/date +\"%F:%R\"`" 73 | exec > >(tee -i ${LOGFILE}) 74 | exec 2>&1 75 | 76 | echo $args 77 | 78 | cd dqn 79 | ../torch/bin/qlua train_agent.lua $args 80 | --------------------------------------------------------------------------------
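A note on the exploration settings shared by train_cpu.sh and train_gpu.sh (ep, eps_end, eps_endt, learn_start): they describe a linearly annealed exploration probability that starts at ep, stays there until learning begins, and then decays to eps_end over eps_endt steps. The Lua snippet below is a minimal illustrative sketch of that schedule under the scripts' default values; the helper name annealed_epsilon is hypothetical and the code is not a copy of the repository's agent implementation.

-- Minimal sketch (hypothetical helper, not repository code): a linearly
-- annealed exploration rate derived from ep (start value), eps_end, eps_endt
-- and learn_start as configured in the training scripts above.
local function annealed_epsilon(ep_start, ep_end, ep_endt, learn_start, num_steps)
    -- Count only the steps taken after learning has started.
    local steps_since_learn = math.max(0, num_steps - learn_start)
    -- How much of the annealing window is still remaining.
    local remaining = math.max(0, ep_endt - steps_since_learn)
    return ep_end + (ep_start - ep_end) * remaining / ep_endt
end

-- With the defaults above (ep=1, eps_end=0.01, eps_endt=1000000, learn_start=50000):
print(annealed_epsilon(1, 0.01, 1000000, 50000, 0))        --> 1.0   (pure exploration before learning starts)
print(annealed_epsilon(1, 0.01, 1000000, 50000, 550000))   --> 0.505 (halfway through the annealing window)
print(annealed_epsilon(1, 0.01, 1000000, 50000, 2000000))  --> 0.01  (fully annealed)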