├── .gitignore ├── README.md ├── dqn ├── LICENSE ├── NeuralQLearner.lua ├── Rectifier.lua ├── Scale.lua ├── TransitionTable.lua ├── TransitionTable_spriority.lua ├── base.png ├── convnet.lua ├── convnet_atari3.lua ├── image.png ├── initenv.lua ├── net_downsample_2x_full_y.lua ├── nnutils.lua ├── pyserver.py ├── signal.def ├── signal.so ├── templates │ ├── 19.png │ ├── 275.png │ ├── 291.png │ ├── 502.png │ ├── door.png │ ├── door_new.png │ ├── key.png │ ├── ladder.png │ ├── man (copy).png │ ├── man.png │ ├── man_clean.png │ ├── man_red.png │ ├── skull.png │ └── test.png ├── test_agent.lua ├── tmp.png ├── tmp_5000.png ├── tmp_5001.png ├── tmp_5550.png ├── tmp_5555.png ├── tmp_6000.png ├── train_agent.lua └── unit_tests.lua ├── gifs ├── breakout.gif ├── enduro.gif └── enduro.mp4 ├── install_dependencies.sh ├── roms └── montezuma_revenge.bin ├── run.slurm ├── run_cpu ├── run_exp.sh ├── run_exp_multi.sh ├── run_gpu ├── run_slurm.sh ├── run_slurm_multi_exp.sh ├── slurm-10443616.out ├── stop_server.sh ├── structured_priority ├── dqn │ ├── LICENSE │ ├── NeuralQLearner.lua │ ├── Rectifier.lua │ ├── Scale.lua │ ├── TransitionTable.lua │ ├── TransitionTable_spriority.lua │ ├── convnet.lua │ ├── convnet_atari3.lua │ ├── initenv.lua │ ├── net_downsample_2x_full_y.lua │ ├── nnutils.lua │ ├── test_agent.lua │ └── train_agent.lua ├── install_dependencies.sh ├── roms │ ├── README │ └── breakout.bin ├── run_cpu ├── run_gpu ├── runner_slurm.py ├── slurm_scripts │ ├── RL_breakout_exp1_0.slurm │ ├── RL_breakout_exp1_1.slurm │ ├── RL_breakout_exp2_0.slurm │ └── RL_breakout_exp2_1.slurm ├── test_cpu └── test_gpu ├── test_cpu ├── test_exp.sh └── test_gpu /.gitignore: -------------------------------------------------------------------------------- 1 | structured_priority/slurm_scripts 2 | *.t7 3 | structured_priority/dqn/logs 4 | dqn/logs 5 | slurm_logs 6 | *~ 7 | torch 8 | dqn/*.t7 9 | .DS_Store 10 | gifs 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Code for [h-DQN, NIPS 2016](http://papers.nips.cc/paper/6233-hierarchical-deep-reinforcement-learning-integrating-temporal-abstraction-and-intrinsic-motivation) 2 | - Use the synthetic branch for the stochastic decision process example 3 | - Use the metanet branch for Atari. Pre-train the network using the iclr16_basicsubgoal branch before doing this and load this network using the metanet branch to train the full model. 4 | -------------------------------------------------------------------------------- /dqn/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | LIMITED LICENSE: 3 | 4 | Copyright (c) 2014 Google Inc. 5 | Limited License: Under no circumstance is commercial use, reproduction, or 6 | distribution permitted. Use, reproduction, and distribution are permitted 7 | solely for academic use in evaluating and reviewing claims made in 8 | "Human-level control through deep reinforcement learning", Nature 518, 529–533 9 | (26 February 2015) doi:10.1038/nature14236, provided that the following 10 | conditions are met: 11 | 12 | * Any reproduction or distribution of source code must retain the above 13 | copyright notice and the full text of this license including the following 14 | disclaimer.
 15 | 16 | * Any reproduction or distribution in binary form must reproduce the above 17 | copyright notice and the full text of this license including the following 18 | disclaimer
 in the documentation and/or other materials provided with the 19 | distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /dqn/NeuralQLearner.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | require 'image' 7 | require 'torch' 8 | if not dqn then 9 | require 'initenv' 10 | end 11 | 12 | local nql = torch.class('dqn.NeuralQLearner') 13 | 14 | ------ ZMQ server ------- 15 | local json = require ("dkjson") 16 | local zmq = require "lzmq" 17 | ctx = zmq.context() 18 | skt = ctx:socket{zmq.REQ, 19 | linger = 0, rcvtimeo = 1000; 20 | connect = "tcp://127.0.0.1:" .. ZMQ_PORT; 21 | } 22 | 23 | function nql:__init(args) 24 | self.state_dim = args.state_dim -- State dimensionality. 25 | self.actions = args.actions 26 | self.n_actions = #self.actions 27 | self.verbose = args.verbose 28 | self.best = args.best 29 | 30 | self.subgoal_dims = args.subgoal_dims 31 | self.subgoal_nhid = args.subgoal_nhid 32 | 33 | -- run subgoal specific experiments 34 | self.use_distance = args.use_distance -- if we want to use the distance as the reward 35 | 36 | -- to keep track of stats 37 | self.subgoal_success = {} 38 | self.subgoal_total = {} 39 | 40 | -- to keep track of dying position 41 | self.deathPosition = nil 42 | self.DEATH_THRESHOLD = 15 43 | 44 | --- epsilon annealing 45 | self.ep_start = args.ep or 1 46 | self.ep = self.ep_start -- Exploration probability. 47 | self.ep_end = args.ep_end or self.ep 48 | self.ep_endt = args.ep_endt or 1000000 49 | 50 | ---- learning rate annealing 51 | self.lr_start = args.lr or 0.01 --Learning rate. 52 | self.lr = self.lr_start 53 | self.lr_end = args.lr_end or self.lr 54 | self.lr_endt = args.lr_endt or 1000000 55 | self.wc = args.wc or 0 -- L2 weight cost. 56 | self.minibatch_size = args.minibatch_size or 1 57 | self.valid_size = args.valid_size or 500 58 | 59 | --- Q-learning parameters 60 | self.dynamic_discount = args.dynamic_discount 61 | self.discount = args.discount or 0.99 --Discount factor. 62 | self.discount_internal = args.discount_internal --Discount factor for internal rewards 63 | self.update_freq = args.update_freq or 1 64 | -- Number of points to replay per learning step. 65 | self.n_replay = args.n_replay or 1 66 | -- Number of steps after which learning starts. 67 | self.learn_start = args.learn_start or 0 68 | -- Size of the transition table. 
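-- Each stored transition holds the preprocessed frame, the chosen action, a
-- two-component reward ({external, external + intrinsic}, see perceive() and
-- TransitionTable:add()), the terminal flag and the active subgoal vector;
-- replay_memory below bounds how many such tuples are kept before the oldest
-- are overwritten.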
69 | self.replay_memory = args.replay_memory or 1000000 70 | self.hist_len = args.hist_len or 1 71 | self.rescale_r = args.rescale_r 72 | self.max_reward = args.max_reward 73 | self.min_reward = args.min_reward 74 | self.clip_delta = args.clip_delta 75 | self.target_q = args.target_q 76 | self.bestq = 0 77 | 78 | self.gpu = args.gpu 79 | 80 | self.ncols = args.ncols or 1 -- number of color channels in input 81 | self.input_dims = args.input_dims or {self.hist_len*self.ncols, 84, 84} 82 | self.preproc = args.preproc -- name of preprocessing network 83 | self.histType = args.histType or "linear" -- history type to use 84 | self.histSpacing = args.histSpacing or 1 85 | self.nonTermProb = args.nonTermProb or 1 86 | self.bufferSize = args.bufferSize or 512 87 | 88 | self.transition_params = args.transition_params or {} 89 | 90 | self.network = args.network or self:createNetwork() 91 | -- check whether there is a network file 92 | local network_function 93 | if not (type(self.network) == 'string') then 94 | error("The type of the network provided in NeuralQLearner" .. 95 | " is not a string!") 96 | end 97 | 98 | local msg, err = pcall(require, self.network) 99 | if not msg then 100 | -- try to load saved agent 101 | local err_msg, exp = pcall(torch.load, self.network) 102 | if not err_msg then 103 | error("Could not find network file ") 104 | end 105 | if self.best and exp.best_model then --best_model_real and model_rel if testing on non-subgoal network 106 | self.network = exp.best_model 107 | else 108 | self.network = exp.model 109 | end 110 | else 111 | print('Creating Agent Network from ' .. self.network) 112 | self.network = err 113 | self.network = self:network() 114 | end 115 | 116 | if self.gpu and self.gpu >= 0 then 117 | self.network:cuda() 118 | else 119 | self.network:float() 120 | end 121 | 122 | -- Load preprocessing network. 123 | if not (type(self.preproc == 'string')) then 124 | error('The preprocessing is not a string') 125 | end 126 | msg, err = pcall(require, self.preproc) 127 | if not msg then 128 | error("Error loading preprocessing net") 129 | end 130 | self.preproc = err 131 | self.preproc = self:preproc() 132 | self.preproc:float() 133 | 134 | if self.gpu and self.gpu >= 0 then 135 | self.network:cuda() 136 | self.tensor_type = torch.CudaTensor 137 | else 138 | self.network:float() 139 | self.tensor_type = torch.FloatTensor 140 | end 141 | 142 | -- Create transition table. 143 | ---- assuming the transition table always gets floating point input 144 | ---- (Foat or Cuda tensors) and always returns one of the two, as required 145 | ---- internally it always uses ByteTensors for states, scaling and 146 | ---- converting accordingly 147 | local transition_args = { 148 | stateDim = self.state_dim, numActions = self.n_actions, 149 | histLen = self.hist_len, gpu = self.gpu, 150 | maxSize = self.replay_memory, histType = self.histType, 151 | histSpacing = self.histSpacing, nonTermProb = self.nonTermProb, 152 | bufferSize = self.bufferSize, 153 | subgoal_dims = args.subgoal_dims 154 | } 155 | 156 | self.transitions = dqn.TransitionTable(transition_args) 157 | 158 | self.numSteps = 0 -- Number of perceived states. 159 | self.lastState = nil 160 | self.lastAction = nil 161 | self.lastSubgoal = nil 162 | self.v_avg = 0 -- V running average. 163 | self.tderr_avg = 0 -- TD error running average. 
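-- The exploration and learning-rate schedules configured in this constructor
-- are annealed linearly (see eGreedy() and qLearnMinibatch() below).  Assuming
-- *_start >= *_end and writing t = math.max(0, numSteps - learn_start):
--   ep = ep_end + (ep_start - ep_end) * math.max(0, ep_endt - t) / ep_endt
--   lr = math.max(lr_end, lr_end + (lr_start - lr_end) * (lr_endt - t) / lr_endt)
-- With the defaults ep_start = 1 and ep_endt = 1e6, exploration decays to
-- ep_end over the first million post-learn_start steps and then stays there.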
164 | 165 | self.q_max = 1 166 | self.r_max = 1 167 | 168 | self.network_real = self.network:clone() 169 | 170 | self.w, self.dw = self.network:getParameters() 171 | self.dw:zero() 172 | self.deltas = self.dw:clone():fill(0) 173 | self.tmp= self.dw:clone():fill(0) 174 | self.g = self.dw:clone():fill(0) 175 | self.g2 = self.dw:clone():fill(0) 176 | 177 | self.w_real, self.dw_real = self.network_real:getParameters() 178 | self.dw_real:zero() 179 | self.deltas_real = self.dw_real:clone():fill(0) 180 | self.tmp_real= self.dw_real:clone():fill(0) 181 | self.g_real = self.dw_real:clone():fill(0) 182 | self.g2_real = self.dw_real:clone():fill(0) 183 | 184 | 185 | if self.target_q then 186 | self.target_network = self.network:clone() 187 | self.w_target, self.dw_target = self.target_network:getParameters() 188 | 189 | self.target_network_real = self.network_real:clone() 190 | self.w_real_target, self.dw_real_target = self.target_network_real:getParameters() 191 | end 192 | 193 | --hack for testing (giving correct sequence of subgoals -- 7 6 8) 194 | self.true_subgoal_order = {7,6,8} 195 | self.true_subgoal_indx = 1 196 | end 197 | 198 | 199 | function nql:reset(state) 200 | if not state then 201 | return 202 | end 203 | self.best_network = state.best_network 204 | self.best_network_real = state.best_network_real 205 | 206 | self.network = state.model 207 | self.network_real = state.model_real 208 | 209 | self.w, self.dw = self.network:getParameters() 210 | self.dw:zero() 211 | 212 | self.w_real, self.dw_real = self.network_real:getParameters() 213 | self.dw_real:zero() 214 | 215 | self.numSteps = 0 216 | print("RESET STATE SUCCESFULLY") 217 | end 218 | 219 | 220 | function nql:preprocess(rawstate) 221 | if self.preproc then 222 | return self.preproc:forward(rawstate:float()) 223 | :clone():reshape(self.state_dim) 224 | end 225 | 226 | return rawstate 227 | end 228 | 229 | 230 | function nql:getQUpdate(args, external_r) 231 | local s, a, r, s2, term, delta 232 | local q, q2, q2_max 233 | 234 | s = args.s 235 | a = args.a 236 | r = args.r 237 | s2 = args.s2 238 | subgoals2 = args.subgoals2 239 | subgoals = args.subgoals 240 | term = args.term 241 | 242 | -- The order of calls to forward is a bit odd in order 243 | -- to avoid unnecessary calls (we only need 2). 244 | 245 | -- delta = r + (1-terminal) * gamma * max_a Q(s2, a) - Q(s, a) 246 | term = term:clone():float():mul(-1):add(1) 247 | 248 | local target_q_net 249 | if self.target_q then 250 | target_q_net = args.target_network 251 | else 252 | target_q_net = args.network 253 | end 254 | 255 | -- Compute max_a Q(s_2, a). 256 | -- print(s2:size(), subgoals2:size()) 257 | -- q2_max = target_q_net:forward({s2, subgoals2:zero()}):float():max(2) 258 | q2_max = target_q_net:forward({s2, subgoals2}):float():max(2) 259 | 260 | -- Compute q2 = (1-terminal) * gamma * max_a Q(s2, a) 261 | 262 | local discount 263 | if external_r then 264 | discount = math.max(self.dynamic_discount, self.discount) -- for real network 265 | else 266 | discount = math.max(self.dynamic_discount, self.discount_internal) -- for subgoal network 267 | end 268 | 269 | q2 = q2_max:clone():mul(discount):cmul(term) 270 | 271 | delta = r:clone():float() 272 | 273 | -- TODO: removed scaling. 
check later 274 | -- if self.rescale_r then 275 | -- delta:div(self.r_max) 276 | -- end 277 | 278 | delta:add(q2) 279 | 280 | -- q = Q(s,a) 281 | -- local q_all = args.network:forward({s, subgoals:zero()}):float() 282 | local q_all = args.network:forward({s, subgoals}):float() 283 | 284 | q = torch.FloatTensor(q_all:size(1)) 285 | for i=1,q_all:size(1) do 286 | q[i] = q_all[i][a[i]] 287 | end 288 | delta:add(-1, q) 289 | 290 | if self.clip_delta then 291 | delta[delta:ge(self.clip_delta)] = self.clip_delta 292 | delta[delta:le(-self.clip_delta)] = -self.clip_delta 293 | end 294 | 295 | local targets = torch.zeros(self.minibatch_size, self.n_actions):float() 296 | for i=1,math.min(self.minibatch_size,a:size(1)) do 297 | targets[i][a[i]] = delta[i] 298 | end 299 | 300 | if self.gpu >= 0 then targets = targets:cuda() end 301 | 302 | return targets, delta, q2_max 303 | end 304 | 305 | 306 | function nql:qLearnMinibatch(network, target_network, dw, w, g, g2, tmp, deltas, external_r) 307 | -- Perform a minibatch Q-learning update: 308 | -- w += alpha * (r + gamma max Q(s2,a2) - Q(s,a)) * dQ(s,a)/dw 309 | assert(self.transitions:size() > self.minibatch_size) 310 | 311 | local s, a, r, s2, term, subgoals, subgoals2 = self.transitions:sample(self.minibatch_size) 312 | -- print(r, s:sum(2)) 313 | if external_r then 314 | r = r[{{},1}] --extract external reward 315 | -- subgoals[{{},{1,self.subgoal_dims}}] = 0 316 | -- subgoals2[{{},{1,self.subgoal_dims}}] = 0 317 | if SUBGOAL_SCREEN then 318 | -- TODO 319 | end 320 | 321 | else 322 | r = r[{{},2}] --external + intrinsic reward 323 | end 324 | 325 | local targets, delta, q2_max = self:getQUpdate({s=s, a=a, r=r, s2=s2, 326 | term=term, subgoals = subgoals, subgoals2=subgoals2, network = network, update_qmax=true, target_network = target_network}, external_r) 327 | 328 | -- zero gradients of parameters 329 | dw:zero() 330 | 331 | -- get new gradient 332 | -- print(subgoals) 333 | network:backward({s, subgoals}, targets) 334 | 335 | -- add weight cost to gradient 336 | dw:add(-self.wc, w) 337 | 338 | -- compute linearly annealed learning rate 339 | local t = math.max(0, self.numSteps - self.learn_start) 340 | self.lr = (self.lr_start - self.lr_end) * (self.lr_endt - t)/self.lr_endt + 341 | self.lr_end 342 | self.lr = math.max(self.lr, self.lr_end) 343 | 344 | 345 | 346 | --grad normalization 347 | -- local max_norm = 1000 348 | -- local grad_norm = dw:norm() 349 | -- if grad_norm > max_norm then 350 | -- local scale_factor = max_norm/grad_norm 351 | -- dw:mul(scale_factor) 352 | -- if false and grad_norm > 1000 then 353 | -- print("Scaling down gradients. 
Norm:", grad_norm) 354 | -- end 355 | -- end 356 | 357 | -- use gradients (original) 358 | g:mul(0.95):add(0.05, dw) 359 | tmp:cmul(dw, dw) 360 | g2:mul(0.95):add(0.05, tmp) 361 | tmp:cmul(g, g) 362 | tmp:mul(-1) 363 | tmp:add(g2) 364 | tmp:add(0.01) 365 | tmp:sqrt() 366 | 367 | --rmsprop 368 | -- local smoothing_value = 1e-8 369 | -- tmp:cmul(dw, dw) 370 | -- g:mul(0.9):add(0.1, tmp) 371 | -- tmp = torch.sqrt(g) 372 | -- tmp:add(smoothing_value) --negative learning rate 373 | 374 | -- accumulate update 375 | deltas:mul(0):addcdiv(self.lr, dw, tmp) 376 | w:add(deltas) 377 | end 378 | 379 | 380 | function nql:sample_validation_data() 381 | local s, a, r, s2, term, subgoals, subgoals2 = self.transitions:sample(self.valid_size) 382 | self.valid_s = s:clone() 383 | self.valid_a = a:clone() 384 | self.valid_r = r:clone() 385 | self.valid_s2 = s2:clone() 386 | self.valid_term = term:clone() 387 | self.valid_subgoals = subgoals:clone() 388 | self.valid_subgoals2 = subgoals2:clone() 389 | end 390 | 391 | 392 | function nql:compute_validation_statistics() 393 | local targets, delta, q2_max = self:getQUpdate{s=self.valid_s, 394 | a=self.valid_a, r=self.valid_r[{{},1}], s2=self.valid_s2, term=self.valid_term, subgoals = self.valid_subgoals, 395 | subgoals2 = self.valid_subgoals2, network = self.network_real, target_network = self.target_network_real} 396 | 397 | self.v_avg = self.q_max * q2_max:mean() 398 | self.tderr_avg = delta:clone():abs():mean() 399 | end 400 | 401 | 402 | function process_pystr(msg) 403 | loadstring(msg)() 404 | for i = 1, #objlist do 405 | objlist[i] = torch.Tensor(objlist[i]) 406 | end 407 | return objlist 408 | end 409 | 410 | -- returns a table of num_objects x vectorized object reps 411 | function nql:get_objects(rawstate) 412 | image.save('tmp_' .. ZMQ_PORT .. 
'.png', rawstate[1]) 413 | skt:send("") 414 | msg = skt:recv() 415 | while msg == nil do 416 | msg = skt:recv() 417 | end 418 | local object_list = process_pystr(msg) 419 | return object_list --nn.SplitTable(1):forward(torch.rand(4, self.subgoal_dims)) 420 | end 421 | 422 | function nql:pick_subgoal(rawstate, oid) 423 | local objects = self:get_objects(rawstate) 424 | local indxs = oid or torch.random(3, #objects) -- skip first two as first is agent is the agent 425 | while objects[indxs]:sum() == 0 do -- object absent 426 | indxs = torch.random(3, #objects) -- skip first two as first is agent is the agent 427 | end 428 | 429 | --- REMOVE: giving true subgoal sequence 430 | self.true_subgoal_order = {7,6} 431 | 432 | if self.true_subgoal_indx > #self.true_subgoal_order then 433 | self.true_subgoal_indx = 1 434 | end 435 | --- pick subgoal --- TODO: remove this hack 436 | local current_goal_id = self.true_subgoal_order[self.true_subgoal_indx] 437 | if self.subgoal_success[current_goal_id] and self.subgoal_total[current_goal_id] then 438 | if (self.subgoal_success[current_goal_id] / self.subgoal_total[current_goal_id]) >= 0.3 then 439 | self.true_subgoal_indx = self.true_subgoal_indx + 1 440 | end 441 | end 442 | 443 | indxs = self.true_subgoal_order[self.true_subgoal_indx] 444 | -- self.true_subgoal_indx = self.true_subgoal_indx + 1 445 | 446 | 447 | 448 | -- concatenate subgoal with objects (input into network) 449 | local subg = objects[indxs] 450 | 451 | -- local ftrvec = torch.zeros(#objects*self.subgoal_dims) 452 | -- for i = 1,#objects do 453 | -- ftrvec[{{(i-1)*self.subgoal_dims + 1, i*self.subgoal_dims}}] = objects[i] 454 | -- end 455 | 456 | -- -- add stats 457 | -- self.subgoal_total[subg:sum()] = self.subgoal_total[subg:sum()] or 0 458 | -- self.subgoal_total[subg:sum()] = self.subgoal_total[subg:sum()] + 1 459 | self.subgoal_total[indxs] = self.subgoal_total[indxs] or 0 460 | self.subgoal_total[indxs] = self.subgoal_total[indxs] + 1 461 | 462 | -- zeroing out discrete objects 463 | -- ftrvec:zero() 464 | self.objects = objects 465 | 466 | local ftrvec = torch.zeros(#objects*self.subgoal_dims) 467 | ftrvec[indxs] = 1 468 | ftrvec[#ftrvec] = indxs 469 | return torch.cat(subg, ftrvec) 470 | end 471 | 472 | function nql:isGoalReached(subgoal, objects) 473 | local agent = objects[1] 474 | 475 | -- IMP: remember that subgoal includes both subgoal and all objects 476 | local dist = math.sqrt((subgoal[1] - agent[1])^2 + (subgoal[2]-agent[2])^2) 477 | if dist < 8 then --just a small threshold to indicate when agent meets subgoal (euc dist) 478 | print('subgoal reached!') 479 | -- local indexTensor = subgoal[{{3, self.subgoal_dims}}]:byte() 480 | -- print(subgoal, indexTensor) 481 | local subg = subgoal[{{1, self.subgoal_dims}}] 482 | -- subgoal[#subgoal] is just the subgoal ID 483 | self.subgoal_success[subgoal[#subgoal]] = self.subgoal_success[subgoal[#subgoal]] or 0 484 | self.subgoal_success[subgoal[#subgoal]] = self.subgoal_success[subgoal[#subgoal]] + 1 485 | -- self.subgoal_success[subg:sum()] = self.subgoal_success[subg:sum()] or 0 486 | -- self.subgoal_success[subg:sum()] = self.subgoal_success[subg:sum()] + 1 487 | return true 488 | else 489 | return false 490 | end 491 | end 492 | 493 | function nql:intrinsic_reward(subgoal, objects) 494 | -- return reward based on distance or 0/1 towards sub-goal 495 | local agent = objects[1] 496 | local reward 497 | -- if self.lastSubgoal then 498 | -- print("last subgoal", self.lastSubgoal[{{1,7}}]) 499 | -- end 500 | -- print("current 
subgoal", subgoal[{{1,7}}]) 501 | if self.lastSubgoal and (self.lastSubgoal[{{3,self.subgoal_dims}}] - subgoal[{{3, self.subgoal_dims}}]):abs():sum() == 0 then 502 | local dist1 = math.sqrt((subgoal[1] - agent[1])^2 + (subgoal[2]-agent[2])^2) 503 | local dist2 = math.sqrt((self.lastSubgoal[1] - self.lastobjects[1][1])^2 + (self.lastSubgoal[2]-self.lastobjects[1][2])^2) 504 | reward = dist2 - dist1 505 | else 506 | reward = 0 507 | end 508 | 509 | 510 | if not self.use_distance then 511 | reward = 0 -- no intrinsic reward except for reaching the subgoal 512 | end 513 | 514 | -- print(reward) 515 | return reward 516 | end 517 | 518 | 519 | function nql:perceive(subgoal, reward, rawstate, terminal, testing, testing_ep) 520 | if reward > 0 then 521 | print('external reward nonzero! [val:',reward, ']') 522 | end 523 | 524 | -- Preprocess state (will be set to nil if terminal) 525 | if terminal then 526 | reward = -200 527 | end 528 | 529 | local state = self:preprocess(rawstate):float() 530 | local objects = self:get_objects(rawstate) 531 | 532 | if terminal then 533 | self.deathPosition = objects[1][{{1,2}}] --just store the x and y coords of the agent 534 | end 535 | 536 | local goal_reached = self:isGoalReached(subgoal, objects) 537 | local intrinsic_reward = self:intrinsic_reward(subgoal, objects) 538 | reward = reward - 0.1 -- penalize for just standing 539 | if goal_reached then 540 | intrinsic_reward = intrinsic_reward + 50 541 | end 542 | 543 | local curState 544 | 545 | if self.max_reward then 546 | reward = math.min(reward, self.max_reward) 547 | end 548 | if self.min_reward then 549 | reward = math.max(reward, self.min_reward) 550 | end 551 | if self.rescale_r then 552 | self.r_max = math.max(self.r_max, reward) 553 | end 554 | 555 | --print(reward, intrinsic_reward) 556 | 557 | self.transitions:add_recent_state(state, terminal, subgoal) 558 | 559 | -- local currentFullState = self.transitions:get_recent() 560 | --Store transition s, a, r, s' 561 | if self.lastState and not testing and self.lastSubgoal then 562 | self.transitions:add(self.lastState, self.lastAction, torch.Tensor({reward, reward + intrinsic_reward}), 563 | self.lastTerminal, self.lastSubgoal, priority) 564 | -- print("STORING PREV TRANSITION", self.lastState:sum(), self.lastAction, torch.Tensor({reward, reward + intrinsic_reward}), 565 | -- self.lastTerminal, self.lastSubgoal:sum(), priority) 566 | end 567 | 568 | if self.numSteps == self.learn_start+1 and not testing then 569 | self:sample_validation_data() 570 | end 571 | 572 | curState, subgoal = self.transitions:get_recent() 573 | curState = curState:resize(1, unpack(self.input_dims)) 574 | 575 | -- Select action 576 | local actionIndex = 1 577 | local qfunc 578 | if not terminal then 579 | actionIndex, qfunc = self:eGreedy(curState, testing, testing_ep, subgoal) 580 | end 581 | 582 | -- actionIndex = 5 --left 583 | 584 | self.transitions:add_recent_action(actionIndex) 585 | 586 | --Do some Q-learning updates 587 | if self.numSteps > self.learn_start and not testing and 588 | self.numSteps % self.update_freq == 0 then 589 | for i = 1, self.n_replay do 590 | self:qLearnMinibatch(self.network, self.target_network, self.dw, self.w, self.g, self.g2, self.tmp, self.deltas, false) 591 | self:qLearnMinibatch(self.network_real, self.target_network_real, self.dw_real, self.w_real, self.g_real, self.g2_real, self.tmp_real, self.deltas_real, true) 592 | end 593 | end 594 | 595 | if not testing then 596 | self.numSteps = self.numSteps + 1 597 | end 598 | 599 | self.lastState 
= state:clone() 600 | self.lastAction = actionIndex 601 | self.lastTerminal = terminal 602 | if not terminal then 603 | -- print("Getting subgoal") 604 | self.lastSubgoal = subgoal 605 | --check if the game is still in the stages right after the agent dies 606 | if self.deathPosition then 607 | currentPosition = objects[1][{{1,2}}] 608 | -- print("Positions:", currentPosition, self.deathPosition) 609 | if math.sqrt((currentPosition[1]-self.deathPosition[1])^2 + (currentPosition[2]-self.deathPosition[2])^2) < self.DEATH_THRESHOLD then 610 | self.lastSubgoal = nil 611 | -- print("death overruling") 612 | else 613 | -- print("Removing death position", self.deathPosition) 614 | self.deathPosition = nil 615 | end 616 | 617 | end 618 | 619 | else 620 | -- print("LAST SUBGOAL is now NIL") 621 | -- self.lastSubgoal = nil --TODO check 622 | end 623 | 624 | self.lastobjects = objects 625 | 626 | -- target q copy 627 | if false then -- deprecated 628 | if self.target_q and self.numSteps % self.target_q == 1 then 629 | self.target_network = self.network:clone() 630 | self.target_network_real = self.network_real:clone() 631 | end 632 | else --smooth average 633 | local alpha = 0.999 634 | self.w_target:mul(0.999):add(self.w * (1-alpha)) 635 | self.w_real_target:mul(0.999):add(self.w_real * (1-alpha)) 636 | end 637 | 638 | if not terminal then 639 | return actionIndex, goal_reached, reward, reward+intrinsic_reward, qfunc 640 | else 641 | return 0, goal_reached, reward, reward+intrinsic_reward, qfunc 642 | end 643 | end 644 | 645 | 646 | function nql:eGreedy(state, testing, testing_ep, subgoal) 647 | self.ep = testing_ep or (self.ep_end + 648 | math.max(0, (self.ep_start - self.ep_end) * (self.ep_endt - 649 | math.max(0, self.numSteps - self.learn_start))/self.ep_endt)) 650 | 651 | -- if testing, zero out subgoals 652 | -- if testing then 653 | -- subgoal = subgoal:clone() 654 | -- subgoal[{{1,self.subgoal_dims}}] = 0 655 | -- end 656 | 657 | -- Epsilon greedy 658 | if torch.uniform() < self.ep then 659 | return torch.random(1, self.n_actions) 660 | else 661 | return self:greedy(state, subgoal, testing) 662 | end 663 | end 664 | 665 | 666 | function nql:greedy(state, subgoal, testing) 667 | -- Turn single state into minibatch. Needed for convolutional nets. 
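-- The network always receives a leading batch dimension: perceive() resizes
-- the recent state to (1, unpack(input_dims)) and the subgoal vector is
-- reshaped to 1 x subgoal_dims*9 just below, so a single observation is
-- treated as a minibatch of size one.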
668 | if state:dim() == 2 then 669 | assert(false, 'Input must be at least 3D') 670 | state = state:resize(1, state:size(1), state:size(2)) 671 | end 672 | subgoal = torch.reshape(subgoal, 1, self.subgoal_dims*9) 673 | if self.gpu >= 0 then 674 | state = state:cuda() 675 | subgoal = subgoal:cuda() 676 | end 677 | -- local q = self.network:forward({state, subgoal:zero()}):float():squeeze() 678 | local q = self.network:forward({state, subgoal}):float():squeeze() 679 | 680 | local maxq = q[1] 681 | local besta = {1} 682 | -- print("Q Value:", q) 683 | -- Evaluate all other actions (with random tie-breaking) 684 | for a = 2, self.n_actions do 685 | if q[a] > maxq then 686 | besta = { a } 687 | maxq = q[a] 688 | elseif q[a] == maxq then 689 | besta[#besta+1] = a 690 | end 691 | end 692 | self.bestq = maxq 693 | 694 | local r = torch.random(1, #besta) 695 | 696 | self.lastAction = besta[r] 697 | 698 | return besta[r], q 699 | end 700 | 701 | 702 | function nql:createNetwork() 703 | local n_hid = 128 704 | local mlp = nn.Sequential() 705 | mlp:add(nn.Reshape(self.hist_len*self.ncols*self.state_dim)) 706 | mlp:add(nn.Linear(self.hist_len*self.ncols*self.state_dim, n_hid)) 707 | mlp:add(nn.Rectifier()) 708 | mlp:add(nn.Linear(n_hid, n_hid)) 709 | mlp:add(nn.Rectifier()) 710 | mlp:add(nn.Linear(n_hid, self.n_actions)) 711 | 712 | return mlp 713 | end 714 | 715 | 716 | function nql:report() 717 | print("Subgoal Network\n---------------------") 718 | print(get_weight_norms(self.network)) 719 | print(get_grad_norms(self.network)) 720 | print(" Real Network\n---------------------") 721 | print(get_weight_norms(self.network_real)) 722 | print(get_grad_norms(self.network_real)) 723 | 724 | 725 | -- print stats on subgoal success rates 726 | for subg, val in pairs(self.subgoal_total) do 727 | if self.subgoal_success[subg] then 728 | print("Subgoal ID (8-key, 6/7-bottom ladders):" , subg , ' : ', self.subgoal_success[subg]/val, self.subgoal_success[subg] .. '/' .. val) 729 | else 730 | print("Subgoal ID (8-key, 6/7-bottom ladders):" , subg , ' : ') 731 | end 732 | end 733 | -- self.subgoal_success = {} 734 | -- self.subgoal_total = {} 735 | end 736 | -------------------------------------------------------------------------------- /dqn/Rectifier.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | --[[ Rectified Linear Unit. 8 | 9 | The output is max(0, input). 10 | --]] 11 | 12 | local Rectifier, parent = torch.class('nn.Rectifier', 'nn.Module') 13 | 14 | -- This module accepts minibatches 15 | function Rectifier:updateOutput(input) 16 | return self.output:resizeAs(input):copy(input):abs():add(input):div(2) 17 | end 18 | 19 | function Rectifier:updateGradInput(input, gradOutput) 20 | self.gradInput:resizeAs(self.output) 21 | return self.gradInput:sign(self.output):cmul(gradOutput) 22 | end -------------------------------------------------------------------------------- /dqn/Scale.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "nn" 8 | require "image" 9 | 10 | local scale = torch.class('nn.Scale', 'nn.Module') 11 | 12 | 13 | function scale:__init(height, width) 14 | self.height = height 15 | self.width = width 16 | end 17 | 18 | function scale:forward(x) 19 | local x = x 20 | if x:dim() > 3 then 21 | x = x[1] 22 | end 23 | 24 | x = image.rgb2y(x) 25 | x = image.scale(x, self.width, self.height, 'bilinear') 26 | return x 27 | end 28 | 29 | function scale:updateOutput(input) 30 | return self:forward(input) 31 | end 32 | 33 | function scale:float() 34 | end 35 | -------------------------------------------------------------------------------- /dqn/TransitionTable.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'image' 8 | 9 | local trans = torch.class('dqn.TransitionTable') 10 | 11 | 12 | function trans:__init(args) 13 | self.stateDim = args.stateDim 14 | self.numActions = args.numActions 15 | self.histLen = args.histLen 16 | self.maxSize = args.maxSize or 1024^2 17 | self.bufferSize = args.bufferSize or 1024 18 | 19 | self.histType = args.histType or "linear" 20 | self.histSpacing = args.histSpacing or 1 21 | self.zeroFrames = args.zeroFrames or 1 22 | self.nonTermProb = args.nonTermProb or 1 23 | self.nonEventProb = args.nonEventProb or 1 24 | self.gpu = args.gpu 25 | self.numEntries = 0 26 | self.insertIndex = 0 27 | 28 | self.histIndices = {} 29 | local histLen = self.histLen 30 | if self.histType == "linear" then 31 | -- History is the last histLen frames. 32 | self.recentMemSize = self.histSpacing*histLen 33 | for i=1,histLen do 34 | self.histIndices[i] = i*self.histSpacing 35 | end 36 | elseif self.histType == "exp2" then 37 | -- The ith history frame is from 2^(i-1) frames ago. 38 | self.recentMemSize = 2^(histLen-1) 39 | self.histIndices[1] = 1 40 | for i=1,histLen-1 do 41 | self.histIndices[i+1] = self.histIndices[i] + 2^(7-i) 42 | end 43 | elseif self.histType == "exp1.25" then 44 | -- The ith history frame is from 1.25^(i-1) frames ago. 45 | self.histIndices[histLen] = 1 46 | for i=histLen-1,1,-1 do 47 | self.histIndices[i] = math.ceil(1.25*self.histIndices[i+1])+1 48 | end 49 | self.recentMemSize = self.histIndices[1] 50 | for i=1,histLen do 51 | self.histIndices[i] = self.recentMemSize - self.histIndices[i] + 1 52 | end 53 | end 54 | 55 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 56 | self.a = torch.LongTensor(self.maxSize):fill(0) 57 | self.r = torch.zeros(self.maxSize,2) 58 | self.subgoal_dims = args.subgoal_dims*9 --TODO (total number of objects) 59 | self.subgoal = torch.zeros(self.maxSize, self.subgoal_dims) 60 | self.t = torch.ByteTensor(self.maxSize):fill(0) 61 | self.action_encodings = torch.eye(self.numActions) 62 | 63 | -- Tables for storing the last histLen states. They are used for 64 | -- constructing the most recent agent state more easily. 
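-- recent_s / recent_a / recent_t / recent_subgoal are plain Lua tables that
-- add_recent_state() and add_recent_action() keep trimmed to the last
-- recentMemSize entries; get_recent() then stacks them into the histLen-frame
-- state that the agent feeds to the network.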
65 | self.recent_s = {} 66 | self.recent_a = {} 67 | self.recent_t = {} 68 | self.recent_subgoal = {} 69 | 70 | local s_size = self.stateDim*histLen 71 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 72 | self.buf_r = torch.zeros(self.bufferSize,2) 73 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 74 | self.buf_s = torch.ByteTensor(self.bufferSize, s_size):fill(0) 75 | self.buf_s2 = torch.ByteTensor(self.bufferSize, s_size):fill(0) 76 | self.buf_subgoal = torch.zeros(self.bufferSize, self.subgoal_dims) 77 | self.buf_subgoal2 = torch.zeros(self.bufferSize, self.subgoal_dims) 78 | 79 | if self.gpu and self.gpu >= 0 then 80 | self.gpu_s = self.buf_s:float():cuda() 81 | self.gpu_s2 = self.buf_s2:float():cuda() 82 | self.gpu_subgoal = self.buf_subgoal:float():cuda() 83 | self.gpu_subgoal2 = self.buf_subgoal2:float():cuda() 84 | end 85 | end 86 | 87 | 88 | function trans:reset() 89 | self.numEntries = 0 90 | self.insertIndex = 0 91 | end 92 | 93 | 94 | function trans:size() 95 | return self.numEntries 96 | end 97 | 98 | 99 | function trans:empty() 100 | return self.numEntries == 0 101 | end 102 | 103 | 104 | function trans:fill_buffer() 105 | assert(self.numEntries >= self.bufferSize) 106 | -- clear CPU buffers 107 | self.buf_ind = 1 108 | local ind 109 | for buf_ind=1,self.bufferSize do 110 | local s, a, r, s2, term, subgoal, subgoal2 = self:sample_one(1) 111 | self.buf_s[buf_ind]:copy(s) 112 | self.buf_a[buf_ind] = a 113 | self.buf_subgoal[buf_ind] = subgoal 114 | self.buf_subgoal2[buf_ind] = subgoal2 115 | self.buf_r[buf_ind] = r 116 | self.buf_s2[buf_ind]:copy(s2) 117 | self.buf_term[buf_ind] = term 118 | end 119 | self.buf_s = self.buf_s:float():div(255) 120 | self.buf_s2 = self.buf_s2:float():div(255) 121 | if self.gpu and self.gpu >= 0 then 122 | self.gpu_s:copy(self.buf_s) 123 | self.gpu_s2:copy(self.buf_s2) 124 | self.gpu_subgoal:copy(self.buf_subgoal) 125 | self.gpu_subgoal2:copy(self.buf_subgoal2) 126 | end 127 | end 128 | 129 | 130 | function trans:sample_one() 131 | assert(self.numEntries > 1) 132 | local index 133 | local valid = false 134 | while not valid do 135 | -- start at 2 because of previous action 136 | index = torch.random(2, self.numEntries-self.recentMemSize) 137 | if self.t[index+self.recentMemSize-1] == 0 then 138 | valid = true 139 | end 140 | if self.nonTermProb < 1 and self.t[index+self.recentMemSize] == 0 and 141 | torch.uniform() > self.nonTermProb then 142 | -- Discard non-terminal states with probability (1-nonTermProb). 143 | -- Note that this is the terminal flag for s_{t+1}. 144 | valid = false 145 | end 146 | if self.nonEventProb < 1 and self.t[index+self.recentMemSize] == 0 and 147 | self.r[index+self.recentMemSize-1] == 0 and 148 | torch.uniform() > self.nonTermProb then 149 | -- Discard non-terminal or non-reward states with 150 | -- probability (1-nonTermProb). 
151 | valid = false 152 | end 153 | end 154 | 155 | return self:get(index) 156 | end 157 | 158 | 159 | function trans:sample(batch_size) 160 | local batch_size = batch_size or 1 161 | assert(batch_size < self.bufferSize) 162 | 163 | if not self.buf_ind or self.buf_ind + batch_size - 1 > self.bufferSize then 164 | self:fill_buffer() 165 | end 166 | 167 | local index = self.buf_ind 168 | 169 | self.buf_ind = self.buf_ind+batch_size 170 | local range = {{index, index+batch_size-1}} 171 | 172 | local buf_s, buf_s2, buf_a, buf_r, buf_term, buf_subgoal, buf_subgoal2 = self.buf_s, self.buf_s2, 173 | self.buf_a, self.buf_r, self.buf_term, self.buf_subgoal, self.buf_subgoal2 174 | if self.gpu and self.gpu >=0 then 175 | buf_s = self.gpu_s 176 | buf_s2 = self.gpu_s2 177 | buf_subgoal = self.gpu_subgoal 178 | buf_subgoal2 = self.gpu_subgoal2 179 | end 180 | 181 | return buf_s[range], buf_a[range], buf_r[range], buf_s2[range], buf_term[range], buf_subgoal[range], buf_subgoal2[range] 182 | end 183 | 184 | 185 | function trans:concatFrames(index, use_recent) 186 | if use_recent then 187 | s, t, subgoal = self.recent_s, self.recent_t, self.recent_subgoal[self.histLen] 188 | else 189 | s, t, subgoal = self.s, self.t, self.subgoal[index] 190 | end 191 | 192 | local fullstate = s[1].new() 193 | fullstate:resize(self.histLen, unpack(s[1]:size():totable())) 194 | 195 | -- Zero out frames from all but the most recent episode. 196 | local zero_out = false 197 | local episode_start = self.histLen 198 | 199 | for i=self.histLen-1,1,-1 do 200 | if not zero_out then 201 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 202 | if t[j] == 1 then 203 | zero_out = true 204 | break 205 | end 206 | end 207 | end 208 | 209 | if zero_out then 210 | fullstate[i]:zero() 211 | else 212 | episode_start = i 213 | end 214 | end 215 | 216 | if self.zeroFrames == 0 then 217 | episode_start = 1 218 | end 219 | 220 | -- Copy frames from the current episode. 221 | for i=episode_start,self.histLen do 222 | fullstate[i]:copy(s[index+self.histIndices[i]-1]) 223 | end 224 | 225 | return fullstate, subgoal 226 | end 227 | 228 | 229 | function trans:concatActions(index, use_recent) 230 | local act_hist = torch.FloatTensor(self.histLen, self.numActions) 231 | if use_recent then 232 | a, t = self.recent_a, self.recent_t 233 | else 234 | a, t = self.a, self.t 235 | end 236 | 237 | -- Zero out frames from all but the most recent episode. 238 | local zero_out = false 239 | local episode_start = self.histLen 240 | 241 | for i=self.histLen-1,1,-1 do 242 | if not zero_out then 243 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 244 | if t[j] == 1 then 245 | zero_out = true 246 | break 247 | end 248 | end 249 | end 250 | 251 | if zero_out then 252 | act_hist[i]:zero() 253 | else 254 | episode_start = i 255 | end 256 | end 257 | 258 | if self.zeroFrames == 0 then 259 | episode_start = 1 260 | end 261 | 262 | -- Copy frames from the current episode. 
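-- As in concatFrames() above, any history slot that falls before the most
-- recent terminal has already been zeroed, so the stacked one-hot action
-- history (rows of action_encodings = torch.eye(numActions)) never spans an
-- episode boundary.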
263 | for i=episode_start,self.histLen do 264 | act_hist[i]:copy(self.action_encodings[a[index+self.histIndices[i]-1]]) 265 | end 266 | 267 | return act_hist 268 | end 269 | 270 | 271 | function trans:get_recent() 272 | -- Assumes that the most recent state has been added, but the action has not 273 | local fullstate, subgoal = self:concatFrames(1,true) 274 | return fullstate:float():div(255), subgoal 275 | end 276 | 277 | 278 | function trans:get(index) 279 | local s, subgoal = self:concatFrames(index) 280 | local s2, subgoal2 = self:concatFrames(index+1) 281 | local ar_index = index+self.recentMemSize-1 282 | return s, self.a[ar_index], self.r[ar_index], s2, self.t[ar_index+1], self.subgoal[ar_index], self.subgoal[ar_index+1] 283 | end 284 | 285 | 286 | function trans:add(s, a, r, term, subgoal) 287 | assert(s, 'State cannot be nil') 288 | assert(a, 'Action cannot be nil') 289 | assert(r, 'Reward cannot be nil') 290 | 291 | -- Incremenet until at full capacity 292 | if self.numEntries < self.maxSize then 293 | self.numEntries = self.numEntries + 1 294 | end 295 | 296 | -- Always insert at next index, then wrap around 297 | self.insertIndex = self.insertIndex + 1 298 | -- Overwrite oldest experience once at capacity 299 | if self.insertIndex > self.maxSize then 300 | self.insertIndex = 1 301 | end 302 | 303 | -- Overwrite (s,a,r,t) at insertIndex 304 | self.s[self.insertIndex] = s:clone():float():mul(255) 305 | self.a[self.insertIndex] = a 306 | self.r[self.insertIndex] = r 307 | self.subgoal[self.insertIndex] = subgoal 308 | 309 | if term then 310 | self.t[self.insertIndex] = 1 311 | else 312 | self.t[self.insertIndex] = 0 313 | end 314 | end 315 | 316 | 317 | function trans:add_recent_state(s, term, subgoal) 318 | local s = s:clone():float():mul(255):byte() 319 | local subgoal = subgoal:clone() 320 | if #self.recent_s == 0 then 321 | for i=1,self.recentMemSize do 322 | table.insert(self.recent_s, s:clone():zero()) 323 | table.insert(self.recent_t, 1) 324 | table.insert(self.recent_subgoal, subgoal:clone():zero()) 325 | end 326 | end 327 | 328 | table.insert(self.recent_s, s) 329 | table.insert(self.recent_subgoal, subgoal) 330 | if term then 331 | table.insert(self.recent_t, 1) 332 | else 333 | table.insert(self.recent_t, 0) 334 | end 335 | 336 | -- Keep recentMemSize states. 337 | if #self.recent_s > self.recentMemSize then 338 | table.remove(self.recent_s, 1) 339 | table.remove(self.recent_t, 1) 340 | table.remove(self.recent_subgoal, 1) 341 | end 342 | end 343 | 344 | 345 | function trans:add_recent_action(a) 346 | if #self.recent_a == 0 then 347 | for i=1,self.recentMemSize do 348 | table.insert(self.recent_a, 1) 349 | end 350 | end 351 | 352 | table.insert(self.recent_a, a) 353 | 354 | -- Keep recentMemSize steps. 355 | if #self.recent_a > self.recentMemSize then 356 | table.remove(self.recent_a, 1) 357 | end 358 | end 359 | 360 | 361 | --[[ 362 | Override the write function to serialize this class into a file. 363 | We do not want to store anything into the file, just the necessary info 364 | to create an empty transition table. 365 | 366 | @param file (FILE object ) @see torch.DiskFile 367 | --]] 368 | function trans:write(file) 369 | file:writeObject({self.stateDim, 370 | self.numActions, 371 | self.histLen, 372 | self.maxSize, 373 | self.bufferSize, 374 | self.numEntries, 375 | self.insertIndex, 376 | self.recentMemSize, 377 | self.histIndices, 378 | self.subgoal_dims}) 379 | end 380 | 381 | 382 | --[[ 383 | Override the read function to desearialize this class from file. 
384 | Recreates an empty table. 385 | 386 | @param file (FILE object ) @see torch.DiskFile 387 | --]] 388 | function trans:read(file) 389 | local stateDim, numActions, histLen, maxSize, bufferSize, numEntries, insertIndex, recentMemSize, histIndices, subgoal_dims = unpack(file:readObject()) 390 | self.stateDim = stateDim 391 | self.numActions = numActions 392 | self.histLen = histLen 393 | self.maxSize = maxSize 394 | self.bufferSize = bufferSize 395 | self.recentMemSize = recentMemSize 396 | self.histIndices = histIndices 397 | self.numEntries = 0 398 | self.insertIndex = 0 399 | self.subgoal_dims = subgoal_dims 400 | 401 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 402 | self.a = torch.LongTensor(self.maxSize):fill(0) 403 | self.r = torch.zeros(self.maxSize, 2) 404 | self.t = torch.ByteTensor(self.maxSize):fill(0) 405 | self.subgoal = torch.zeros(self.maxSize, self.subgoal_dims) 406 | self.action_encodings = torch.eye(self.numActions) 407 | 408 | -- Tables for storing the last histLen states. They are used for 409 | -- constructing the most recent agent state more easily. 410 | self.recent_s = {} 411 | self.recent_a = {} 412 | self.recent_t = {} 413 | self.recent_subgoal = {} 414 | 415 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 416 | self.buf_r = torch.zeros(self.bufferSize, 2) 417 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 418 | self.buf_s = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 419 | self.buf_s2 = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 420 | self.buf_subgoal = torch.zeros(self.bufferSize, self.subgoal_dims) 421 | self.buf_subgoal2 = torch.zeros(self.bufferSize, self.subgoal_dims) 422 | 423 | if self.gpu and self.gpu >= 0 then 424 | self.gpu_s = self.buf_s:float():cuda() 425 | self.gpu_s2 = self.buf_s2:float():cuda() 426 | self.gpu_subgoal = self.buf_subgoal:float():cuda() 427 | self.gpu_subgoal2 = self.buf_subgoal2:float():cuda() 428 | end 429 | end 430 | -------------------------------------------------------------------------------- /dqn/TransitionTable_spriority.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'image' 8 | 9 | local trans = torch.class('dqn.TransitionTable') 10 | 11 | 12 | function trans:__init(args) 13 | self.stateDim = args.stateDim 14 | self.numActions = args.numActions 15 | self.histLen = args.histLen 16 | self.maxSize = args.maxSize or 1024^2 17 | self.bufferSize = args.bufferSize or 1024 18 | 19 | self.histType = args.histType or "linear" 20 | self.histSpacing = args.histSpacing or 1 21 | self.zeroFrames = args.zeroFrames or 1 22 | self.nonTermProb = args.nonTermProb or 1 23 | self.nonEventProb = args.nonEventProb or 1 24 | self.gpu = args.gpu 25 | self.numEntries = 0 26 | self.insertIndex = 0 27 | self.ptrInsertIndex = 1 28 | 29 | self.histIndices = {} 30 | local histLen = self.histLen 31 | if self.histType == "linear" then 32 | -- History is the last histLen frames. 33 | self.recentMemSize = self.histSpacing*histLen 34 | for i=1,histLen do 35 | self.histIndices[i] = i*self.histSpacing 36 | end 37 | elseif self.histType == "exp2" then 38 | -- The ith history frame is from 2^(i-1) frames ago. 
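-- For example, with histLen = 8 (which the hard-coded 2^(7-i) step below
-- appears to assume), recentMemSize = 128 and
-- histIndices = {1, 65, 97, 113, 121, 125, 127, 128}, i.e. the stacked frames
-- sit 127, 63, 31, 15, 7, 3, 1 and 0 steps behind the newest frame.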
39 | self.recentMemSize = 2^(histLen-1) 40 | self.histIndices[1] = 1 41 | for i=1,histLen-1 do 42 | self.histIndices[i+1] = self.histIndices[i] + 2^(7-i) 43 | end 44 | elseif self.histType == "exp1.25" then 45 | -- The ith history frame is from 1.25^(i-1) frames ago. 46 | self.histIndices[histLen] = 1 47 | for i=histLen-1,1,-1 do 48 | self.histIndices[i] = math.ceil(1.25*self.histIndices[i+1])+1 49 | end 50 | self.recentMemSize = self.histIndices[1] 51 | for i=1,histLen do 52 | self.histIndices[i] = self.recentMemSize - self.histIndices[i] + 1 53 | end 54 | end 55 | 56 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 57 | self.a = torch.LongTensor(self.maxSize):fill(0) 58 | self.r = torch.zeros(self.maxSize, 2) 59 | self.t = torch.ByteTensor(self.maxSize):fill(0) 60 | self.action_encodings = torch.eye(self.numActions) 61 | self.end_ptrs = {} 62 | self.dyn_ptrs = {} 63 | self.trace_indxs_with_extreward = {} --extrinsic reward 64 | self.trace_indxs_with_intreward = {} --intrinsic reward 65 | 66 | self.subgoal_dims = args.subgoal_dims*9 --TODO (total number of objects) 67 | self.subgoal = torch.zeros(self.maxSize, self.subgoal_dims) 68 | 69 | -- Tables for storing the last histLen states. They are used for 70 | -- constructing the most recent agent state more easily. 71 | self.recent_s = {} 72 | self.recent_a = {} 73 | self.recent_t = {} 74 | self.recent_subgoal = {} 75 | 76 | local s_size = self.stateDim*histLen 77 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 78 | self.buf_r = torch.zeros(self.bufferSize,2 ) 79 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 80 | self.buf_s = torch.ByteTensor(self.bufferSize, s_size):fill(0) 81 | self.buf_s2 = torch.ByteTensor(self.bufferSize, s_size):fill(0) 82 | self.buf_subgoal = torch.zeros(self.bufferSize, self.subgoal_dims) 83 | self.buf_subgoal2 = torch.zeros(self.bufferSize, self.subgoal_dims) 84 | 85 | 86 | if self.gpu and self.gpu >= 0 then 87 | self.gpu_s = self.buf_s:float():cuda() 88 | self.gpu_s2 = self.buf_s2:float():cuda() 89 | self.gpu_subgoal = self.buf_subgoal:float():cuda() 90 | self.gpu_subgoal2 = self.buf_subgoal2:float():cuda() 91 | end 92 | end 93 | 94 | 95 | function trans:reset() 96 | self.numEntries = 0 97 | self.insertIndex = 0 98 | self.ptrInsertIndex = 1 99 | end 100 | 101 | 102 | function trans:size() 103 | return self.numEntries 104 | end 105 | 106 | 107 | function trans:empty() 108 | return self.numEntries == 0 109 | end 110 | 111 | 112 | function trans:fill_buffer() 113 | assert(self.numEntries >= self.bufferSize) 114 | -- clear CPU buffers 115 | self.buf_ind = 1 116 | local ind 117 | for buf_ind=1,self.bufferSize do 118 | local s, a, r, s2, term, subgoal, subgoal2 = self:sample_one(1) 119 | self.buf_s[buf_ind]:copy(s) 120 | self.buf_a[buf_ind] = a 121 | self.buf_subgoal[buf_ind] = subgoal 122 | self.buf_subgoal2[buf_ind] = subgoal2 123 | self.buf_r[buf_ind] = r 124 | self.buf_s2[buf_ind]:copy(s2) 125 | self.buf_term[buf_ind] = term 126 | end 127 | self.buf_s = self.buf_s:float():div(255) 128 | self.buf_s2 = self.buf_s2:float():div(255) 129 | if self.gpu and self.gpu >= 0 then 130 | self.gpu_s:copy(self.buf_s) 131 | self.gpu_s2:copy(self.buf_s2) 132 | self.gpu_subgoal:copy(self.buf_subgoal) 133 | self.gpu_subgoal2:copy(self.buf_subgoal2) 134 | end 135 | end 136 | 137 | function trans:get_size(tab) 138 | if tab == nil then return 0 end 139 | local Count = 0 140 | for Index, Value in pairs(tab) do 141 | Count = Count + 1 142 | end 143 | return Count 144 | end 145 | 146 | function 
trans:get_canonical_indices() 147 | local indx; 148 | local index = -1 149 | while index <= 0 do 150 | indx = torch.random(#self.end_ptrs-1) 151 | index = self.dyn_ptrs[indx] - self.recentMemSize + 1 152 | end 153 | return indx, index 154 | end 155 | 156 | function trans:sample_one() 157 | assert(self.numEntries > 1) 158 | assert(#self.end_ptrs == #self.dyn_ptrs) 159 | -- print(self.end_ptrs) 160 | local index = -1 161 | local indx 162 | 163 | --- choose to either select traces with external or internal reward 164 | local chosen_trace_indxs = self.trace_indxs_with_extreward 165 | if self:get_size(self.trace_indxs_with_extreward) == 0 then 166 | chosen_trace_indxs = self.trace_indxs_with_intreward 167 | else 168 | if torch.uniform() > 0.5 then 169 | chosen_trace_indxs = self.trace_indxs_with_intreward 170 | end 171 | end 172 | 173 | local eps = 0.33; 174 | 175 | if torch.uniform() < eps or self:get_size(chosen_trace_indxs) <= 0 then 176 | --randomly sample without prioritization 177 | indx, index = self:get_canonical_indices() 178 | else 179 | -- prioritize and pick from stored transitions with rewards 180 | --this is only executed if #chosen_trace_indxs > 0, i.e. only if agent has received external reward 181 | while index <= 0 do 182 | local keyset={}; local n=0; 183 | for k,v in pairs(chosen_trace_indxs) do 184 | if k <= self.maxSize - self.histLen + 1 then 185 | n=n+1 186 | keyset[n]=k 187 | end 188 | end 189 | -- print('K:', keyset) 190 | if #keyset == 0 then 191 | indx, index = self:get_canonical_indices() 192 | break 193 | end 194 | local mem_indx = keyset[torch.random(#keyset)] 195 | -- print('mem_indx:', mem_indx) 196 | -- print('R:', chosen_trace_indxs) 197 | -- print('DYN:', self.dyn_ptrs) 198 | -- print('mem_indx:', mem_indx) 199 | -- print('END:', self.end_ptrs) 200 | for k,v in pairs(self.end_ptrs) do 201 | if v == mem_indx then 202 | indx = k 203 | end 204 | end 205 | if indx then 206 | index = self.dyn_ptrs[indx] - self.recentMemSize + 1 207 | else 208 | indx, index = self:get_canonical_indices() 209 | break 210 | end 211 | -- this is a corner case: when there is only 2 eps (fix this TODO) with reward but index is zero 212 | if index <= 0 and self:get_size(chosen_trace_indxs) <= 2 then 213 | indx, index = self:get_canonical_indices() 214 | -- print('INDEX:', index) 215 | break 216 | end 217 | end 218 | end 219 | -- print(index, indx) 220 | self.dyn_ptrs[indx] = self.dyn_ptrs[indx] - 1 221 | if self.dyn_ptrs[indx] <= 0 or self.dyn_ptrs[indx] == self.end_ptrs[indx-1] then 222 | self.dyn_ptrs[indx] = self.end_ptrs[indx] 223 | end 224 | return self:get(index) 225 | end 226 | 227 | 228 | 229 | function trans:sample(batch_size) 230 | local batch_size = batch_size or 1 231 | assert(batch_size < self.bufferSize) 232 | 233 | if not self.buf_ind or self.buf_ind + batch_size - 1 > self.bufferSize then 234 | self:fill_buffer() 235 | end 236 | 237 | local index = self.buf_ind 238 | 239 | self.buf_ind = self.buf_ind+batch_size 240 | local range = {{index, index+batch_size-1}} 241 | 242 | local buf_s, buf_s2, buf_a, buf_r, buf_term, buf_subgoal, buf_subgoal2 = self.buf_s, self.buf_s2, 243 | self.buf_a, self.buf_r, self.buf_term, self.buf_subgoal, self.buf_subgoal2 244 | if self.gpu and self.gpu >=0 then 245 | buf_s = self.gpu_s 246 | buf_s2 = self.gpu_s2 247 | buf_subgoal = self.gpu_subgoal 248 | buf_subgoal2 = self.gpu_subgoal2 249 | end 250 | 251 | return buf_s[range], buf_a[range], buf_r[range], buf_s2[range], buf_term[range], buf_subgoal[range], buf_subgoal2[range] 252 | end 253 | 254 
| 255 | function trans:concatFrames(index, use_recent) 256 | if use_recent then 257 | s, t, subgoal = self.recent_s, self.recent_t, self.recent_subgoal[self.histLen] 258 | else 259 | s, t, subgoal = self.s, self.t, self.subgoal[index] 260 | end 261 | 262 | local fullstate = s[1].new() 263 | fullstate:resize(self.histLen, unpack(s[1]:size():totable())) 264 | 265 | -- Zero out frames from all but the most recent episode. 266 | local zero_out = false 267 | local episode_start = self.histLen 268 | 269 | for i=self.histLen-1,1,-1 do 270 | if not zero_out then 271 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 272 | if t[j] == 1 then 273 | zero_out = true 274 | break 275 | end 276 | end 277 | end 278 | 279 | if zero_out then 280 | fullstate[i]:zero() 281 | else 282 | episode_start = i 283 | end 284 | end 285 | 286 | if self.zeroFrames == 0 then 287 | episode_start = 1 288 | end 289 | 290 | -- Copy frames from the current episode. 291 | for i=episode_start,self.histLen do 292 | fullstate[i]:copy(s[index+self.histIndices[i]-1]) 293 | end 294 | return fullstate, subgoal 295 | end 296 | 297 | 298 | function trans:concatActions(index, use_recent) 299 | local act_hist = torch.FloatTensor(self.histLen, self.numActions) 300 | if use_recent then 301 | a, t = self.recent_a, self.recent_t 302 | else 303 | a, t = self.a, self.t 304 | end 305 | 306 | -- Zero out frames from all but the most recent episode. 307 | local zero_out = false 308 | local episode_start = self.histLen 309 | 310 | for i=self.histLen-1,1,-1 do 311 | if not zero_out then 312 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 313 | if t[j] == 1 then 314 | zero_out = true 315 | break 316 | end 317 | end 318 | end 319 | 320 | if zero_out then 321 | act_hist[i]:zero() 322 | else 323 | episode_start = i 324 | end 325 | end 326 | 327 | if self.zeroFrames == 0 then 328 | episode_start = 1 329 | end 330 | 331 | -- Copy frames from the current episode. 
332 | for i=episode_start,self.histLen do 333 | act_hist[i]:copy(self.action_encodings[a[index+self.histIndices[i]-1]]) 334 | end 335 | 336 | return act_hist 337 | end 338 | 339 | 340 | function trans:get_recent() 341 | -- Assumes that the most recent state has been added, but the action has not 342 | -- return self:concatFrames(1, true):float():div(255) 343 | 344 | local fullstate, subgoal = self:concatFrames(1,true) 345 | return fullstate:float():div(255), subgoal 346 | 347 | end 348 | 349 | 350 | function trans:get(index) 351 | local s, subgoal = self:concatFrames(index) 352 | local s2, subgoal2 = self:concatFrames(index+1) 353 | local ar_index = index+self.recentMemSize-1 354 | -- print(index) 355 | return s, self.a[ar_index], self.r[ar_index], s2, self.t[ar_index+1], self.subgoal[ar_index], self.subgoal[ar_index+1] 356 | end 357 | 358 | 359 | function trans:add(s, a, r, term, subgoal) 360 | -- print('TT:', term, r) 361 | assert(s, 'State cannot be nil') 362 | assert(a, 'Action cannot be nil') 363 | assert(r, 'Reward cannot be nil') 364 | 365 | -- Incremenet until at full capacity 366 | if self.numEntries < self.maxSize then 367 | self.numEntries = self.numEntries + 1 368 | end 369 | 370 | -- Always insert at next index, then wrap around 371 | self.insertIndex = self.insertIndex + 1 372 | 373 | 374 | 375 | -- Overwrite oldest experience once at capacity 376 | if self.insertIndex > self.maxSize then 377 | self.insertIndex = 1 378 | self.ptrInsertIndex = 1 379 | end 380 | 381 | -- Overwrite (s,a,r,t) at insertIndex 382 | self.s[self.insertIndex] = s:clone():float():mul(255) 383 | self.a[self.insertIndex] = a 384 | self.r[self.insertIndex] = r 385 | self.subgoal[self.insertIndex] = subgoal 386 | 387 | if r[1] > 0 then --if extrinsic reward is non-zero, record this! 388 | self.trace_indxs_with_extreward[self.insertIndex] = 1 389 | end 390 | 391 | local intrinsic_reward = r[2] - r[1] 392 | if intrinsic_reward > 0 then --if extrinsic reward is non-zero, record this! 393 | self.trace_indxs_with_intreward[self.insertIndex] = 1 394 | end 395 | 396 | if self.end_ptrs[self.ptrInsertIndex] == self.insertIndex then 397 | table.remove(self.end_ptrs,self.ptrInsertIndex) 398 | table.remove(self.dyn_ptrs,self.ptrInsertIndex) 399 | self.trace_indxs_with_extreward[self.insertIndex] = nil 400 | self.trace_indxs_with_intreward[self.insertIndex] = nil 401 | end 402 | if term then 403 | self.t[self.insertIndex] = 1 404 | table.insert(self.end_ptrs, self.ptrInsertIndex, self.insertIndex) 405 | table.insert(self.dyn_ptrs, self.ptrInsertIndex, self.insertIndex) 406 | self.ptrInsertIndex = self.ptrInsertIndex + 1 407 | else 408 | self.t[self.insertIndex] = 0 409 | end 410 | -- print(#self.end_ptrs, term) 411 | end 412 | 413 | 414 | function trans:add_recent_state(s, term, subgoal) 415 | local s = s:clone():float():mul(255):byte() 416 | local subgoal = subgoal:clone() 417 | if #self.recent_s == 0 then 418 | for i=1,self.recentMemSize do 419 | table.insert(self.recent_s, s:clone():zero()) 420 | table.insert(self.recent_t, 1) 421 | table.insert(self.recent_subgoal, subgoal:clone():zero()) 422 | end 423 | end 424 | 425 | table.insert(self.recent_s, s) 426 | table.insert(self.recent_subgoal, subgoal) 427 | if term then 428 | table.insert(self.recent_t, 1) 429 | else 430 | table.insert(self.recent_t, 0) 431 | end 432 | 433 | -- Keep recentMemSize states. 
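-- Note: unlike TransitionTable.lua, only recent_s and recent_t are trimmed
-- here, so recent_subgoal appears to grow without bound and
-- recent_subgoal[self.histLen], read in concatFrames(), stays pinned to one of
-- the initial zeroed entries instead of tracking the current subgoal.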
434 | if #self.recent_s > self.recentMemSize then 435 | table.remove(self.recent_s, 1) 436 | table.remove(self.recent_t, 1) 437 | end 438 | end 439 | 440 | 441 | function trans:add_recent_action(a) 442 | if #self.recent_a == 0 then 443 | for i=1,self.recentMemSize do 444 | table.insert(self.recent_a, 1) 445 | end 446 | end 447 | 448 | table.insert(self.recent_a, a) 449 | 450 | -- Keep recentMemSize steps. 451 | if #self.recent_a > self.recentMemSize then 452 | table.remove(self.recent_a, 1) 453 | end 454 | end 455 | 456 | 457 | --[[ 458 | Override the write function to serialize this class into a file. 459 | We do not want to store anything into the file, just the necessary info 460 | to create an empty transition table. 461 | 462 | @param file (FILE object ) @see torch.DiskFile 463 | --]] 464 | function trans:write(file) 465 | file:writeObject({self.stateDim, 466 | self.numActions, 467 | self.histLen, 468 | self.maxSize, 469 | self.bufferSize, 470 | self.numEntries, 471 | self.insertIndex, 472 | self.recentMemSize, 473 | self.histIndices, 474 | self.subgoal_dims}) 475 | end 476 | 477 | 478 | --[[ 479 | Override the read function to desearialize this class from file. 480 | Recreates an empty table. 481 | 482 | @param file (FILE object ) @see torch.DiskFile 483 | --]] 484 | function trans:read(file) 485 | local stateDim, numActions, histLen, maxSize, bufferSize, numEntries, insertIndex, recentMemSize, histIndices, subgoal_dims = unpack(file:readObject()) 486 | self.stateDim = stateDim 487 | self.numActions = numActions 488 | self.histLen = histLen 489 | self.maxSize = maxSize 490 | self.bufferSize = bufferSize 491 | self.recentMemSize = recentMemSize 492 | self.histIndices = histIndices 493 | self.numEntries = 0 494 | self.insertIndex = 0 495 | self.subgoal_dims = subgoal_dims 496 | 497 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 498 | self.a = torch.LongTensor(self.maxSize):fill(0) 499 | self.r = torch.zeros(self.maxSize, 2) 500 | self.t = torch.ByteTensor(self.maxSize):fill(0) 501 | self.subgoal = torch.zeros(self.maxSize, self.subgoal_dims) 502 | self.action_encodings = torch.eye(self.numActions) 503 | 504 | -- Tables for storing the last histLen states. They are used for 505 | -- constructing the most recent agent state more easily. 
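    -- Note: unlike recent_s/recent_a/recent_t below, recent_subgoal is not
    -- reinitialized on read(); a table that is deserialized and then used with
    -- add_recent_state would need recent_subgoal to be set up again by the caller.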
506 | self.recent_s = {} 507 | self.recent_a = {} 508 | self.recent_t = {} 509 | 510 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 511 | self.buf_r = torch.zeros(self.bufferSize, 2) 512 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 513 | self.buf_s = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 514 | self.buf_s2 = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 515 | self.buf_subgoal = torch.zeros(self.bufferSize, self.subgoal_dims) 516 | self.buf_subgoal2 = torch.zeros(self.bufferSize, self.subgoal_dims) 517 | 518 | if self.gpu and self.gpu >= 0 then 519 | self.gpu_s = self.buf_s:float():cuda() 520 | self.gpu_s2 = self.buf_s2:float():cuda() 521 | self.gpu_subgoal = self.buf_subgoal:float():cuda() 522 | self.gpu_subgoal2 = self.buf_subgoal2:float():cuda() 523 | end 524 | end 525 | -------------------------------------------------------------------------------- /dqn/base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/base.png -------------------------------------------------------------------------------- /dqn/convnet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "initenv" 8 | 9 | function create_network(args) 10 | 11 | local net = nn.Sequential() 12 | net:add(nn.Reshape(unpack(args.input_dims))) 13 | 14 | --- first convolutional layer 15 | local convLayer = nn.SpatialConvolution 16 | 17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1], 18 | args.filter_size[1], args.filter_size[1], 19 | args.filter_stride[1], args.filter_stride[1],1)) 20 | net:add(args.nl()) 21 | 22 | -- Add convolutional layers 23 | for i=1,(#args.n_units-1) do 24 | -- second convolutional layer 25 | net:add(convLayer(args.n_units[i], args.n_units[i+1], 26 | args.filter_size[i+1], args.filter_size[i+1], 27 | args.filter_stride[i+1], args.filter_stride[i+1])) 28 | net:add(args.nl()) 29 | end 30 | 31 | local nel 32 | if args.gpu >= 0 then 33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims)) 34 | :cuda()):nElement() 35 | else 36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement() 37 | end 38 | 39 | -- reshape all feature planes into a vector per example 40 | net:add(nn.Reshape(nel)) 41 | 42 | -- join vectors 43 | 44 | local subgoal_proc = nn.Sequential() 45 | :add(nn.Linear(args.subgoal_dims*9, args.subgoal_nhid)) 46 | :add(nn.ReLU()) 47 | :add(nn.Linear(args.subgoal_nhid,args.subgoal_nhid)) 48 | :add(nn.ReLU()) 49 | 50 | local net_parallel = nn.ParallelTable(2) 51 | net_parallel:add(net) 52 | net_parallel:add(subgoal_proc) 53 | 54 | local full_net = nn.Sequential() 55 | full_net:add(net_parallel) 56 | full_net:add(nn.JoinTable(2)) 57 | 58 | 59 | -- fully connected layer 60 | full_net:add(nn.Linear(nel+args.subgoal_nhid, args.n_hid[1])) 61 | full_net:add(args.nl()) 62 | local last_layer_size = args.n_hid[1] 63 | 64 | for i=1,(#args.n_hid-1) do 65 | -- add Linear layer 66 | last_layer_size = args.n_hid[i+1] 67 | full_net:add(nn.Linear(args.n_hid[i], last_layer_size)) 68 | full_net:add(args.nl()) 69 | end 70 | 71 | 72 | -- add the last fully connected layer (to actions) 73 | full_net:add(nn.Linear(last_layer_size, args.n_actions)) 74 | 75 | if args.gpu >=0 then 76 | full_net:cuda() 77 | end 78 | if 
args.verbose >= 2 then 79 | print(full_net) 80 | print('Convolutional layers flattened output size:', nel) 81 | end 82 | return full_net 83 | end 84 | -------------------------------------------------------------------------------- /dqn/convnet_atari3.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | 16 | return create_network(args) 17 | end 18 | 19 | -------------------------------------------------------------------------------- /dqn/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/image.png -------------------------------------------------------------------------------- /dqn/initenv.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | dqn = {} 7 | 8 | require 'torch' 9 | require 'nn' 10 | require 'nngraph' 11 | require 'nnutils' 12 | require 'image' 13 | require 'Scale' 14 | require 'NeuralQLearner' 15 | require 'TransitionTable_spriority' 16 | require 'Rectifier' 17 | 18 | 19 | function torchSetup(_opt) 20 | _opt = _opt or {} 21 | local opt = table.copy(_opt) 22 | assert(opt) 23 | 24 | -- preprocess options: 25 | --- convert options strings to tables 26 | if opt.pool_frms then 27 | opt.pool_frms = str_to_table(opt.pool_frms) 28 | end 29 | if opt.env_params then 30 | opt.env_params = str_to_table(opt.env_params) 31 | end 32 | if opt.agent_params then 33 | opt.agent_params = str_to_table(opt.agent_params) 34 | opt.agent_params.gpu = opt.gpu 35 | opt.agent_params.best = opt.best 36 | opt.agent_params.verbose = opt.verbose 37 | if opt.network ~= '' then 38 | opt.agent_params.network = opt.network 39 | end 40 | end 41 | 42 | --- general setup 43 | opt.tensorType = opt.tensorType or 'torch.FloatTensor' 44 | torch.setdefaulttensortype(opt.tensorType) 45 | if not opt.threads then 46 | opt.threads = 4 47 | end 48 | torch.setnumthreads(opt.threads) 49 | if not opt.verbose then 50 | opt.verbose = 10 51 | end 52 | if opt.verbose >= 1 then 53 | print('Torch Threads:', torch.getnumthreads()) 54 | end 55 | 56 | --- set gpu device 57 | if opt.gpu and opt.gpu >= 0 then 58 | require 'cutorch' 59 | require 'cunn' 60 | if opt.gpu == 0 then 61 | local gpu_id = tonumber(os.getenv('GPU_ID')) 62 | if gpu_id then opt.gpu = gpu_id+1 end 63 | end 64 | if opt.gpu > 0 then cutorch.setDevice(opt.gpu) end 65 | opt.gpu = cutorch.getDevice() 66 | print('Using GPU device id:', opt.gpu-1) 67 | else 68 | opt.gpu = -1 69 | if opt.verbose >= 1 then 70 | print('Using CPU code only. GPU device id:', opt.gpu) 71 | end 72 | end 73 | 74 | --- set up random number generators 75 | -- removing lua RNG; seeding torch RNG with opt.seed and setting cutorch 76 | -- RNG seed to the first uniform random int32 from the previous RNG; 77 | -- this is preferred because using the same seed for both generators 78 | -- may introduce correlations; we assume that both torch RNGs ensure 79 | -- adequate dispersion for different seeds. 
80 | math.random = nil 81 | opt.seed = opt.seed or 1 82 | torch.manualSeed(opt.seed) 83 | if opt.verbose >= 1 then 84 | print('Torch Seed:', torch.initialSeed()) 85 | end 86 | local firstRandInt = torch.random() 87 | if opt.gpu >= 0 then 88 | cutorch.manualSeed(firstRandInt) 89 | if opt.verbose >= 1 then 90 | print('CUTorch Seed:', cutorch.initialSeed()) 91 | end 92 | end 93 | 94 | return opt 95 | end 96 | 97 | 98 | function setup(_opt) 99 | assert(_opt) 100 | 101 | --preprocess options: 102 | --- convert options strings to tables 103 | _opt.pool_frms = str_to_table(_opt.pool_frms) 104 | _opt.env_params = str_to_table(_opt.env_params) 105 | _opt.agent_params = str_to_table(_opt.agent_params) 106 | if _opt.agent_params.transition_params then 107 | _opt.agent_params.transition_params = 108 | str_to_table(_opt.agent_params.transition_params) 109 | end 110 | 111 | --- first things first 112 | local opt = torchSetup(_opt) 113 | 114 | -- load training framework and environment 115 | local framework = require(opt.framework) 116 | assert(framework) 117 | 118 | local gameEnv = framework.GameEnvironment(opt) 119 | local gameActions = gameEnv:getActions() 120 | 121 | -- agent options 122 | _opt.agent_params.actions = gameActions 123 | _opt.agent_params.gpu = _opt.gpu 124 | _opt.agent_params.best = _opt.best 125 | _opt.agent_params.subgoal_dims = _opt.subgoal_dims 126 | _opt.agent_params.subgoal_nhid = _opt.subgoal_nhid 127 | 128 | if _opt.network ~= '' then 129 | _opt.agent_params.network = _opt.network 130 | end 131 | _opt.agent_params.verbose = _opt.verbose 132 | if not _opt.agent_params.state_dim then 133 | _opt.agent_params.state_dim = gameEnv:nObsFeature() 134 | end 135 | 136 | local agent = dqn[_opt.agent](_opt.agent_params) 137 | 138 | if opt.verbose >= 1 then 139 | print('Set up Torch using these options:') 140 | for k, v in pairs(opt) do 141 | print(k, v) 142 | end 143 | end 144 | 145 | return gameEnv, gameActions, agent, opt 146 | end 147 | 148 | 149 | 150 | --- other functions 151 | 152 | function str_to_table(str) 153 | if type(str) == 'table' then 154 | return str 155 | end 156 | if not str or type(str) ~= 'string' then 157 | if type(str) == 'table' then 158 | return str 159 | end 160 | return {} 161 | end 162 | local ttr 163 | if str ~= '' then 164 | local ttx=tt 165 | loadstring('tt = {' .. str .. '}')() 166 | ttr = tt 167 | tt = ttx 168 | else 169 | ttr = {} 170 | end 171 | return ttr 172 | end 173 | 174 | function table.copy(t) 175 | if t == nil then return nil end 176 | local nt = {} 177 | for k, v in pairs(t) do 178 | if type(v) == 'table' then 179 | nt[k] = table.copy(v) 180 | else 181 | nt[k] = v 182 | end 183 | end 184 | setmetatable(nt, table.copy(getmetatable(t))) 185 | return nt 186 | end 187 | -------------------------------------------------------------------------------- /dqn/net_downsample_2x_full_y.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "image" 8 | require "Scale" 9 | 10 | local function create_network(args) 11 | -- Y (luminance) 12 | return nn.Scale(84, 84, true) 13 | end 14 | 15 | return create_network 16 | -------------------------------------------------------------------------------- /dqn/nnutils.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "torch" 8 | 9 | function recursive_map(module, field, func) 10 | local str = "" 11 | if module[field] or module.modules then 12 | str = str .. torch.typename(module) .. ": " 13 | end 14 | if module[field] then 15 | str = str .. func(module[field]) 16 | end 17 | if module.modules then 18 | str = str .. "[" 19 | for i, submodule in ipairs(module.modules) do 20 | local submodule_str = recursive_map(submodule, field, func) 21 | str = str .. submodule_str 22 | if i < #module.modules and string.len(submodule_str) > 0 then 23 | str = str .. " " 24 | end 25 | end 26 | str = str .. "]" 27 | end 28 | 29 | return str 30 | end 31 | 32 | function abs_mean(w) 33 | return torch.mean(torch.abs(w:clone():float())) 34 | end 35 | 36 | function abs_max(w) 37 | return torch.abs(w:clone():float()):max() 38 | end 39 | 40 | -- Build a string of average absolute weight values for the modules in the 41 | -- given network. 42 | function get_weight_norms(module) 43 | return "Weight norms:\n" .. recursive_map(module, "weight", abs_mean) .. 44 | "\nWeight max:\n" .. recursive_map(module, "weight", abs_max) 45 | end 46 | 47 | -- Build a string of average absolute weight gradient values for the modules 48 | -- in the given network. 49 | function get_grad_norms(module) 50 | return "Weight grad norms:\n" .. 51 | recursive_map(module, "gradWeight", abs_mean) .. 52 | "\nWeight grad max:\n" .. recursive_map(module, "gradWeight", abs_max) 53 | end 54 | -------------------------------------------------------------------------------- /dqn/pyserver.py: -------------------------------------------------------------------------------- 1 | import zmq 2 | import time 3 | import sys 4 | import cv2 5 | import numpy as np 6 | import copy 7 | import sys 8 | import json, pdb 9 | 10 | port = "5550" 11 | if len(sys.argv) > 1: 12 | port = int(sys.argv[1]) 13 | 14 | context = zmq.Context() 15 | socket = context.socket(zmq.REP) 16 | socket.bind("tcp://*:%s" % port) 17 | 18 | 19 | class Recognizer: 20 | def __init__(self): 21 | self.colors = {'man': [200, 72, 72], 'skull': [236,236,236]} 22 | self.map = {'man': 0, 'skull': 1, 'ladder': 2, 'door': 3, 'key': 4} 23 | 24 | def blob_detect(self, img, id): 25 | mask = np.zeros(np.shape(img)) 26 | mask[:,:,0] = self.colors[id][0]; 27 | mask[:,:,1] = self.colors[id][1]; 28 | mask[:,:,2] = self.colors[id][2]; 29 | 30 | diff = img - mask 31 | indxs = np.where(diff == 0) 32 | diff[np.where(diff < 0)] = 0 33 | diff[np.where(diff > 0)] = 0 34 | diff[indxs] = 255 35 | mean_y = np.sum(indxs[0]) / np.shape(indxs[0])[0] 36 | mean_x = np.sum(indxs[1]) / np.shape(indxs[1])[0] 37 | return (mean_y, mean_x) #flipped co-ords due to numpy blob detect 38 | # return (mean_x, mean_y) 39 | 40 | def template_detect(self, img, id): 41 | template = cv2.imread('templates/' + id + '.png') 42 | w = np.shape(template)[1] 43 | h = np.shape(template)[0] 44 | res = cv2.matchTemplate(img,template,cv2.TM_CCOEFF_NORMED) 45 | threshold = 0.8 46 | loc = np.where( res >= threshold) 47 | loc[0].setflags(write=True) 48 | loc[1].setflags(write=True) 49 | for i in range(np.shape(loc[0])[0]): 50 | loc[0][i] += h/2; loc[1][i] += w/2 51 | return loc, w, h 52 | 53 | def get(self, img): 54 | #detect man 55 | man_coords = self.blob_detect(img, 'man') 56 | skull_coords = self.blob_detect(img, 'skull') 57 | ladder_coords, ladder_w, ladder_h = self.template_detect(img, 'ladder') 58 | key_coords, key_w, key_h = self.template_detect(img, 'key') 59 | door_coords, door_w, door_h = self.template_detect(img, 'door_new') 60 | return 
{'man': man_coords, 'skull':skull_coords, 'ladder':ladder_coords, 'key':key_coords, 'door':door_coords, 'ladder_w': ladder_w, 61 | 'ladder_h':ladder_h , 'key_w':key_w, 'key_h':key_h, 'door_w':door_w, 'door_h':door_h} 62 | 63 | def drawbbox(self, inputim, coords): 64 | img = copy.deepcopy(inputim) 65 | for id in {'ladder', 'key', 'door'}: 66 | for pt in zip(*coords[id][::-1]): 67 | cv2.rectangle(img, pt, (pt[0] + coords[id+'_w'], pt[1] + coords[id+'_h']), (0,0,255), 2) 68 | cv2.rectangle(img, (coords['man'][0] - 5, coords['man'][1] - 5), (coords['man'][0] + 5, coords['man'][1] + 5), (0,0,255), 2) 69 | cv2.rectangle(img, (coords['skull'][0] - 5, coords['skull'][1] - 5), (coords['skull'][0] + 5, coords['skull'][1] + 5), (0,0,255), 2) 70 | return img 71 | 72 | def get_lives(self, img): 73 | return np.sum(img) 74 | 75 | def get_onehot(self, ID): 76 | tmp = list(np.zeros(len(self.map))) 77 | tmp[ID] = 1 78 | return tmp 79 | 80 | def process_objects(self, objects): 81 | objects_list = [] 82 | 83 | objects_list.append([objects['man'][0], objects['man'][1]] + self.get_onehot(self.map['man'])) 84 | objects_list.append([objects['skull'][0], objects['skull'][1]] + self.get_onehot(self.map['skull'])) 85 | 86 | for obj, val in objects.items(): 87 | # print(obj, val) 88 | if obj is not 'man' and obj is not 'skull': 89 | if type(val) is not type(1): 90 | if type(val[0]) == np.int64: 91 | objects_list.append([val[0], val[1]] + self.get_onehot(self.map[obj])) 92 | else: 93 | for i in range(np.shape(val[0])[0]): 94 | objects_list.append([val[0][i], val[1][i]] + self.get_onehot(self.map[obj])) 95 | #process objects and pad with zeros to ensure fixed length state dim 96 | fill_objects = 8 - len(objects_list) 97 | for j in range(fill_objects): 98 | objects_list.append([0, 0] + list(np.zeros(len(self.map)))) 99 | 100 | return objects_list 101 | 102 | 103 | def show(img): 104 | cv2.imshow('image',img) 105 | cv2.waitKey(0) 106 | # cv2.destroyAllWindows() 107 | 108 | def unit_test(): 109 | rec = Recognizer() 110 | try: 111 | img_id = str(sys.argv[1]) 112 | except: 113 | print 'Using default image 1.png' 114 | img_id = '1' 115 | img_rgb = cv2.imread('tmp.png') 116 | im_score = img_rgb[15:20, 55:95, :] 117 | img_rgb = img_rgb[30:,:,:] 118 | coords = rec.get(img_rgb) 119 | objects = rec.process_objects(coords) 120 | pdb.set_trace() 121 | img = rec.drawbbox(img_rgb, coords) 122 | show(img) 123 | 124 | # unit_test() 125 | 126 | rec = Recognizer() 127 | 128 | img_rgb = cv2.imread('base.png') 129 | im_score = img_rgb[15:20, 55:95, :] 130 | img_rgb = img_rgb[30:,:,:] 131 | coords = rec.get(img_rgb) 132 | objects_list_cache = rec.process_objects(coords) 133 | 134 | while True: 135 | # Wait for next request from client 136 | message = socket.recv() 137 | # print "Received request: ", message 138 | img_rgb = cv2.imread('tmp_'+str(port)+'.png') 139 | im_score = img_rgb[15:20, 55:95, :] 140 | img_rgb = img_rgb[30:,:,:] 141 | coords = rec.get(img_rgb) 142 | # img = rec.drawbbox(img_rgb, coords) 143 | # show(img) 144 | objects_list = copy.deepcopy(objects_list_cache) 145 | objects_list2 = rec.process_objects(coords) 146 | #agent and skull is dynamic. everything else is static. 
TODO for key 147 | objects_list[0] = objects_list2[0] 148 | objects_list[1] = objects_list2[1] 149 | if objects_list[1][0] == 0 and objects_list[1][1] == 0: 150 | objects_list[1][3] = 0 151 | 152 | assert len(objects_list) == len(objects_list2) 153 | # for ii in range(len(objects_list)): 154 | # if objects_list2[ii][0] == 0 and objects_list2[ii][1] == 0: 155 | # print('Lost something!') 156 | # objects_list[ii] = objects_list2[ii] 157 | # print(objects_list2[ii]) 158 | 159 | # print(len(objects_list)) 160 | socket.send('objlist = '+json.dumps(objects_list).replace('[','{').replace(']','}')) 161 | # socket.send("World from %s" % str(coords)) 162 | # print(rec.get_lives(im_score)) 163 | 164 | -------------------------------------------------------------------------------- /dqn/signal.def: -------------------------------------------------------------------------------- 1 | EXPORTS 2 | luaopen_signal 3 | -------------------------------------------------------------------------------- /dqn/signal.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/signal.so -------------------------------------------------------------------------------- /dqn/templates/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/19.png -------------------------------------------------------------------------------- /dqn/templates/275.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/275.png -------------------------------------------------------------------------------- /dqn/templates/291.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/291.png -------------------------------------------------------------------------------- /dqn/templates/502.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/502.png -------------------------------------------------------------------------------- /dqn/templates/door.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/door.png -------------------------------------------------------------------------------- /dqn/templates/door_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/door_new.png -------------------------------------------------------------------------------- /dqn/templates/key.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/key.png -------------------------------------------------------------------------------- /dqn/templates/ladder.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/ladder.png -------------------------------------------------------------------------------- /dqn/templates/man (copy).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/man (copy).png -------------------------------------------------------------------------------- /dqn/templates/man.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/man.png -------------------------------------------------------------------------------- /dqn/templates/man_clean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/man_clean.png -------------------------------------------------------------------------------- /dqn/templates/man_red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/man_red.png -------------------------------------------------------------------------------- /dqn/templates/skull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/skull.png -------------------------------------------------------------------------------- /dqn/templates/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/templates/test.png -------------------------------------------------------------------------------- /dqn/test_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | See LICENSE file for full terms of limited license. 4 | ]] 5 | 6 | gd = require "gd" 7 | require 'torch' 8 | 9 | local cmd = torch.CmdLine() 10 | cmd:text() 11 | cmd:text('Train Agent in Environment:') 12 | cmd:text() 13 | cmd:text('Options:') 14 | 15 | cmd:option('-framework', '', 'name of training framework') 16 | cmd:option('-env', '', 'name of environment to use') 17 | cmd:option('-game_path', '', 'path to environment file (ROM)') 18 | cmd:option('-env_params', '', 'string of environment parameters') 19 | cmd:option('-pool_frms', '', 20 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 21 | cmd:option('-actrep', 1, 'how many times to repeat action') 22 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 
23 | 'number of times at the start of each training episode') 24 | 25 | cmd:option('-name', '', 'filename used for saving network and training history') 26 | cmd:option('-network', '', 'reload pretrained network') 27 | cmd:option('-agent', '', 'name of agent file to use') 28 | cmd:option('-agent_params', '', 'string of agent parameters') 29 | cmd:option('-seed', torch.random(0,10000), 'fixed input seed for repeatable experiments') 30 | 31 | cmd:option('-verbose', 2, 32 | 'the higher the level, the more information is printed to screen') 33 | cmd:option('-threads', 1, 'number of BLAS threads') 34 | cmd:option('-gpu', -1, 'gpu flag') 35 | cmd:option('-gif_file', '', 'GIF path to write session screens') 36 | cmd:option('-csv_file', '', 'CSV path to write session data') 37 | cmd:option('-subgoal_dims', 7, 'dimensions of subgoals') 38 | cmd:option('-subgoal_nhid', 50, '') 39 | cmd:option('-port', 5550, 'Port for zmq connection') 40 | cmd:option('-stepthrough', false, 'Stepthrough') 41 | cmd:option('-human_input', false, 'Human input action') 42 | cmd:option('-subgoal_screen', false, 'overlay subgoal on screen') 43 | 44 | 45 | 46 | cmd:text() 47 | 48 | local opt = cmd:parse(arg) 49 | ZMQ_PORT = opt.port 50 | 51 | 52 | if not dqn then 53 | require "initenv" 54 | end 55 | 56 | --- General setup. 57 | local game_env, game_actions, agent, opt = setup(opt) 58 | 59 | -- override print to always flush the output 60 | local old_print = print 61 | local print = function(...) 62 | old_print(...) 63 | io.flush() 64 | end 65 | 66 | -- file names from command line 67 | local gif_filename = opt.gif_file 68 | 69 | -- start a new game 70 | local screen, reward, terminal = game_env:newGame() 71 | 72 | -- compress screen to JPEG with 100% quality 73 | local jpg = image.compressJPG(screen:squeeze(), 100) 74 | -- create gd image from JPEG string 75 | local im = gd.createFromJpegStr(jpg:storage():string()) 76 | -- convert truecolor to palette 77 | im:trueColorToPalette(false, 256) 78 | 79 | -- write GIF header, use global palette and infinite looping 80 | im:gifAnimBegin(gif_filename, true, 0) 81 | -- write first frame 82 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE) 83 | 84 | -- remember the image and show it first 85 | local previm = im 86 | local win = image.display({image=screen}) 87 | 88 | print("Started playing...") 89 | 90 | subgoal = agent:pick_subgoal(screen, 6) 91 | -- print('Subgoal:', subgoal) 92 | 93 | 94 | local action_list = {'no-op', 'fire', 'up', 'right', 'left', 'down', 'up-right','up-left','down-right','down-left', 95 | 'up-fire', 'right-fire','left-fire', 'down-fire','up-right-fire','up-left-fire', 96 | 'down-right-fire', 'down-left-fire'} 97 | 98 | -- play one episode (game) 99 | while true or not terminal do 100 | -- if action was chosen randomly, Q-value is 0 101 | agent.bestq = 0 102 | 103 | if opt.subgoal_screen then 104 | screen[{1,{}, {30+subgoal[1]-5, 30+subgoal[1]+5}, {subgoal[2]-5,subgoal[2]+5} }] = 1 105 | win = image.display({image=screen, win=win}) 106 | end 107 | 108 | -- choose the best action 109 | local action_index, isGoalReached, reward_ext, reward_tot, qfunc 110 | = agent:perceive(subgoal, reward, screen, terminal, true, 0.1) 111 | 112 | local tmp2 113 | 114 | if opt.stepthrough then 115 | print("Reward Ext", reward_ext) 116 | print("Reward Tot", reward_tot) 117 | print("Q-func") 118 | if qfunc then 119 | for i=1, #action_list do 120 | print(string.format("%s %.4f", action_list[i], qfunc[i])) 121 | end 122 | end 123 | print("Action", action_index, 
action_list[action_index]) 124 | tmp2 = io.read() 125 | end 126 | 127 | --human input of action 128 | if tmp2=='y' or opt.human_input then 129 | print("Enter action") 130 | local tmp = io.read() 131 | if tmp then 132 | action_index = tonumber(tmp) 133 | end 134 | 135 | end 136 | 137 | -- play game in test mode (episodes don't end when losing a life if false below) 138 | screen, reward, terminal = game_env:step(game_actions[action_index], false) 139 | -- screen, reward, terminal = game_env:step(game_actions[1], false) --no-op 140 | 141 | -- screen, reward, terminal = game_env:step(game_actions[action_index]) 142 | 143 | 144 | 145 | 146 | if isGoalReached then 147 | subgoal = agent:pick_subgoal(screen) 148 | end 149 | 150 | 151 | if not opt.subgoal_screen then 152 | screen_cropped = screen:clone() 153 | screen_cropped = screen_cropped[{{},{},{30,210},{1,160}}] 154 | -- screen_cropped[{1,{}, {subgoal[1]-5, subgoal[1]+5}, {subgoal[2]-5,subgoal[2]+5} }] = 1 155 | 156 | -- display screen 157 | image.display({image=screen_cropped, win=win}) 158 | end 159 | 160 | -- create gd image from tensor 161 | jpg = image.compressJPG(screen:squeeze(), 100) 162 | im = gd.createFromJpegStr(jpg:storage():string()) 163 | 164 | -- use palette from previous (first) image 165 | im:trueColorToPalette(false, 256) 166 | im:paletteCopy(previm) 167 | 168 | -- write new GIF frame, no local palette, starting from left-top, 7ms delay 169 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE) 170 | -- remember previous screen for optimal compression 171 | previm = im 172 | 173 | end 174 | 175 | -- end GIF animation and close CSV file 176 | gd.gifAnimEnd(gif_filename) 177 | 178 | print("Finished playing, close window to exit!") 179 | -------------------------------------------------------------------------------- /dqn/tmp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/tmp.png -------------------------------------------------------------------------------- /dqn/tmp_5000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/tmp_5000.png -------------------------------------------------------------------------------- /dqn/tmp_5001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/tmp_5001.png -------------------------------------------------------------------------------- /dqn/tmp_5550.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/tmp_5550.png -------------------------------------------------------------------------------- /dqn/tmp_5555.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/tmp_5555.png -------------------------------------------------------------------------------- /dqn/tmp_6000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/dqn/tmp_6000.png 
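Both test_agent.lua above and train_agent.lua below drive the agent with the same control flow: pick a subgoal, let the agent perceive the current screen and choose an action, step the emulator, and re-pick the subgoal once it is reached or the episode ends. The following is only a stripped-down sketch of that shared loop, assuming game_env, game_actions, agent and opt have already been produced by setup(opt) as in those scripts; it omits logging, display/GIF output, evaluation, checkpointing and the Montezuma-specific restart workarounds.

-- Minimal h-DQN control loop (sketch only, under the assumptions stated above).
local screen, reward, terminal = game_env:newGame()
local subgoal = agent:pick_subgoal(screen)

for step = 1, opt.steps do
    -- The agent returns the chosen action and whether the current subgoal
    -- was reached under the intrinsic reward signal.
    local action_index, isGoalReached = agent:perceive(subgoal, reward, screen, terminal)

    if not terminal then
        screen, reward, terminal = game_env:step(game_actions[action_index], true)
    else
        screen, reward, terminal = game_env:newGame()
        isGoalReached = true  -- new game, so reset the subgoal
    end

    if isGoalReached then
        subgoal = agent:pick_subgoal(screen)
    end
end

In the full scripts, perceive() additionally returns the extrinsic and total rewards and the Q-values, which are used for step-through debugging and logging.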
-------------------------------------------------------------------------------- /dqn/train_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | ./run_gpu montezuma_revenge basic1 5550 12 false 3 | ]] 4 | require 'xlua' 5 | require 'optim' 6 | require 'signal' 7 | 8 | signal.signal("SIGPIPE", function() print("raised") end) 9 | 10 | local cmd = torch.CmdLine() 11 | cmd:text() 12 | cmd:text('Train Agent in Environment:') 13 | cmd:text() 14 | cmd:text('Options:') 15 | 16 | cmd:option('-subgoal_index', 12, 'the index of the subgoal that we want to reach. used for slurm multiple runs') 17 | cmd:option('-max_subgoal_index', 12, 'used as an index to run with all the subgoals instead of only one specific one') 18 | 19 | cmd:option('-exp_folder', '', 'name of folder where current exp state is being stored') 20 | cmd:option('-framework', '', 'name of training framework') 21 | cmd:option('-env', '', 'name of environment to use') 22 | cmd:option('-game_path', '', 'path to environment file (ROM)') 23 | cmd:option('-env_params', '', 'string of environment parameters') 24 | cmd:option('-pool_frms', '', 25 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 26 | cmd:option('-actrep', 1, 'how many times to repeat action') 27 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 28 | 'number of times at the start of each training episode') 29 | 30 | cmd:option('-name', '', 'filename used for saving network and training history') 31 | cmd:option('-network', '', 'reload pretrained network') 32 | cmd:option('-agent', '', 'name of agent file to use') 33 | cmd:option('-agent_params', '', 'string of agent parameters') 34 | cmd:option('-seed', 10, 'fixed input seed for repeatable experiments') 35 | cmd:option('-saveNetworkParams', true, 36 | 'saves the agent network in a separate file') 37 | cmd:option('-prog_freq', 5*10^3, 'frequency of progress output') 38 | cmd:option('-save_freq', 5*10^4, 'the model is saved every save_freq steps') 39 | cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation') 40 | cmd:option('-save_versions', 0, '') 41 | 42 | cmd:option('-steps', 10^5, 'number of training steps to perform') 43 | cmd:option('-eval_steps', 10^5, 'number of evaluation steps') 44 | 45 | cmd:option('-verbose', 2, 46 | 'the higher the level, the more information is printed to screen') 47 | cmd:option('-threads', 1, 'number of BLAS threads') 48 | cmd:option('-gpu', -1, 'gpu flag') 49 | 50 | cmd:option('-subgoal_dims', 7, 'dimensions of subgoals') 51 | cmd:option('-subgoal_nhid', 100, '') 52 | cmd:option('-display_game', true, 'option to display game') 53 | cmd:option('-port', 5550, 'Port for zmq connection') 54 | cmd:option('-stepthrough', false, 'Stepthrough') 55 | cmd:option('-subgoal_screen', true, 'overlay subgoal on screen') 56 | 57 | cmd:option('-max_steps_episode', 1000, 'Max steps per episode') 58 | 59 | 60 | 61 | 62 | 63 | cmd:text() 64 | 65 | local opt = cmd:parse(arg) 66 | ZMQ_PORT = opt.port 67 | SUBGOAL_SCREEN = opt.subgoal_screen 68 | 69 | 70 | if not dqn then 71 | require "initenv" 72 | end 73 | 74 | 75 | print(opt.env_params) 76 | print(opt.seed) 77 | 78 | --- General setup. 79 | local game_env, game_actions, agent, opt = setup(opt) 80 | 81 | -- override print to always flush the output 82 | local old_print = print 83 | local print = function(...) 84 | old_print(...) 
85 | io.flush() 86 | end 87 | 88 | local learn_start = agent.learn_start 89 | local start_time = sys.clock() 90 | local reward_counts = {} 91 | local episode_counts = {} 92 | local time_history = {} 93 | local v_history = {} 94 | local qmax_history = {} 95 | local td_history = {} 96 | local reward_history = {} 97 | local step = 0 98 | time_history[1] = 0 99 | 100 | local total_reward 101 | local nrewards 102 | local nepisodes 103 | local episode_reward 104 | 105 | local screen, reward, terminal = game_env:getState() 106 | 107 | print("Iteration ..", step) 108 | local win = nil 109 | 110 | local subgoal 111 | 112 | if opt.subgoal_index < opt.max_subgoal_index then 113 | subgoal = agent:pick_subgoal(screen, opt.subgoal_index) 114 | else 115 | subgoal = agent:pick_subgoal(screen) 116 | end 117 | 118 | 119 | local action_list = {'no-op', 'fire', 'up', 'right', 'left', 'down', 'up-right','up-left','down-right','down-left', 120 | 'up-fire', 'right-fire','left-fire', 'down-fire','up-right-fire','up-left-fire', 121 | 'down-right-fire', 'down-left-fire'} 122 | 123 | death_counter = 0 --to handle a bug in MZ atari 124 | 125 | episode_step_counter = 0 126 | 127 | while step < opt.steps do 128 | xlua.progress(step, opt.steps) 129 | 130 | step = step + 1 131 | 132 | if opt.subgoal_screen then 133 | -- for i=3,#agent.objects do 134 | -- if agent.objects[i][1] > 0 and agent.objects[i][2] > 0 then 135 | -- screen[{1,{}, {30+agent.objects[i][1]-5, 30+agent.objects[i][1]+5}, {agent.objects[i][2]-5,agent.objects[i][2]+5} }] = 1 136 | -- end 137 | -- end 138 | 139 | screen[{1,{}, {30+subgoal[1]-5, 30+subgoal[1]+5}, {subgoal[2]-5,subgoal[2]+5} }] = 1 140 | win = image.display({image=screen, win=win}) 141 | end 142 | 143 | local action_index, isGoalReached, reward_ext, reward_tot, qfunc = agent:perceive(subgoal, reward, screen, terminal) 144 | 145 | if opt.stepthrough then 146 | print("Reward Ext", reward_ext) 147 | print("Reward Tot", reward_tot) 148 | print("Q-func") 149 | if qfunc then 150 | for i=1, #action_list do 151 | print(action_list[i], qfunc[i]) 152 | end 153 | end 154 | 155 | print("Action", action_index, action_list[action_index]) 156 | io.read() 157 | end 158 | 159 | if false and new_game then--new_game then 160 | print("Q-func") 161 | if prev_Q then 162 | for i=1, #action_list do 163 | print(action_list[i], prev_Q[i]) 164 | end 165 | end 166 | print("SUM OF PIXELS: ", screen:sum()) 167 | new_game = false 168 | end 169 | 170 | -- game over? get next game! 
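    -- The branch below: while the episode is alive and under max_steps_episode,
    -- the chosen action is applied followed by a no-op step (and, when training
    -- against a single fixed subgoal, the game restarts as soon as that subgoal
    -- is reached). Otherwise the episode is restarted via nextRandomGame() or
    -- newGame(), and death_counter forces an extra newGame() every 5 terminals
    -- to work around the Montezuma's Revenge emulator quirk noted where the
    -- counter is initialized above.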
171 | if not terminal and episode_step_counter < opt.max_steps_episode then 172 | 173 | 174 | if isGoalReached and opt.subgoal_index < opt.max_subgoal_index then 175 | screen,reward, terminal = game_env:newGame() -- restart game if focussing on single subgoal 176 | subgoal = agent:pick_subgoal(screen, opt.subgoal_index) 177 | if opt.subgoal_screen then 178 | screen[{1,{}, {30+subgoal[1]-5, 30+subgoal[1]+5}, {subgoal[2]-5,subgoal[2]+5} }] = 1 179 | end 180 | 181 | isGoalReached = false 182 | end 183 | 184 | screen, reward, terminal = game_env:step(game_actions[action_index], true) 185 | screen, reward, terminal = game_env:step(game_actions[1], true) -- noop 186 | episode_step_counter = episode_step_counter + 1 187 | -- screen, reward, terminal = game_env:step(game_actions[1], true) 188 | prev_Q = qfunc 189 | else 190 | death_counter = death_counter + 1 191 | -- print("TERMINAL ENCOUNTERED") 192 | if opt.random_starts > 0 then 193 | -- print("RANDOM GAME STARTING") 194 | screen, reward, terminal = game_env:nextRandomGame() 195 | else 196 | -- print("NEW GAME STARTING") 197 | screen, reward, terminal = game_env:newGame() 198 | end 199 | 200 | if death_counter == 5 then 201 | screen,reward, terminal = game_env:newGame() 202 | death_counter = 0 203 | end 204 | 205 | new_game = true 206 | isGoalReached = true --new game so reset goal 207 | episode_step_counter = 0 208 | end 209 | 210 | if isGoalReached then 211 | if opt.subgoal_index < opt.max_subgoal_index then 212 | subgoal = agent:pick_subgoal(screen, opt.subgoal_index) 213 | else 214 | subgoal = agent:pick_subgoal(screen) 215 | end 216 | 217 | isGoalReached = false 218 | end 219 | 220 | 221 | -- display screen 222 | if opt.display_game then 223 | if not opt.subgoal_screen then 224 | screen_cropped = screen:clone() 225 | screen_cropped = screen_cropped[{{},{},{30,210},{1,160}}] 226 | screen_cropped[{1,{}, {subgoal[1]-5, subgoal[1]+5}, {subgoal[2]-5,subgoal[2]+5} }] = 1 227 | win = image.display({image=screen_cropped, win=win}) 228 | end 229 | end 230 | 231 | if step % opt.prog_freq == 0 then 232 | assert(step==agent.numSteps, 'trainer step: ' .. step .. 233 | ' & agent.numSteps: ' .. 
agent.numSteps) 234 | print("Steps: ", step) 235 | agent:report() 236 | collectgarbage() 237 | end 238 | 239 | 240 | -- update dynamic discount 241 | -- if step > learn_start then 242 | -- agent.dynamic_discount = 0.02 + 0.98 * agent.dynamic_discount 243 | -- end 244 | 245 | if step%1000 == 0 then collectgarbage() end 246 | 247 | -- evaluation 248 | -- TODO: make it true later 249 | if false then ---step % opt.eval_freq == 0 and step > learn_start then 250 | print("Testing ...") 251 | 252 | local cum_reward_ext = 0 253 | local cum_reward_tot = 0 254 | 255 | screen, reward, terminal = game_env:newGame() 256 | subgoal = agent:pick_subgoal(screen) 257 | if opt.subgoal_screen then 258 | screen[{1,{}, {30+subgoal[1]-5, 30+subgoal[1]+5}, {subgoal[2]-5,subgoal[2]+5} }] = 1 259 | end 260 | 261 | 262 | test_avg_Q = test_avg_Q or optim.Logger(paths.concat(opt.exp_folder , 'test_avgQ.log')) 263 | test_avg_R = test_avg_R or optim.Logger(paths.concat(opt.exp_folder , 'test_avgR.log')) 264 | test_avg_R2 = test_avg_R2 or optim.Logger(paths.concat(opt.exp_folder , 'test_avgR2.log')) 265 | 266 | total_reward = 0 267 | nrewards = 0 268 | nepisodes = 0 269 | episode_reward = 0 270 | 271 | death_counter_eval = 0 272 | 273 | local eval_time = sys.clock() 274 | for estep=1,opt.eval_steps do 275 | xlua.progress(estep, opt.eval_steps) 276 | 277 | 278 | if opt.subgoal_screen then 279 | screen[{1,{}, {30+subgoal[1]-5, 30+subgoal[1]+5}, {subgoal[2]-5,subgoal[2]+5} }] = 1 280 | win = image.display({image=screen, win=win}) 281 | end 282 | 283 | local action_index, isGoalReached, reward_ext, reward_tot = agent:perceive(subgoal, reward, screen, terminal, true, 0.1) 284 | 285 | 286 | cum_reward_tot = cum_reward_tot + reward_tot 287 | cum_reward_ext = cum_reward_ext + reward_ext 288 | 289 | -- Play game in test mode (episodes don't end when losing a life) 290 | screen, reward, terminal = game_env:step(game_actions[action_index]) 291 | screen, reward, terminal = game_env:step(game_actions[1]) 292 | 293 | -- display screen 294 | if opt.display_game and not opt.subgoal_screen then 295 | screen_cropped = screen:clone() 296 | screen_cropped = screen_cropped[{{},{},{30,210},{1,160}}] 297 | screen_cropped[{1,{}, {subgoal[1]-5, subgoal[1]+5}, {subgoal[2]-5,subgoal[2]+5} }] = 1 298 | win = image.display({image=screen_cropped, win=win}) 299 | end 300 | 301 | if estep%1000 == 0 then collectgarbage() end 302 | 303 | -- record every reward 304 | episode_reward = episode_reward + reward 305 | if reward ~= 0 then 306 | nrewards = nrewards + 1 307 | end 308 | 309 | if terminal then 310 | total_reward = total_reward + episode_reward 311 | episode_reward = 0 312 | nepisodes = nepisodes + 1 313 | screen, reward, terminal = game_env:newGame() 314 | isGoalReached = true --new game so reset subgoal 315 | death_counter_eval = death_counter_eval + 1 316 | 317 | if death_counter_eval == 5 then 318 | screen,reward, terminal = game_env:newGame() 319 | death_counter_eval = 0 320 | end 321 | end 322 | if isGoalReached then 323 | subgoal = agent:pick_subgoal(screen) 324 | isGoalReached = false 325 | end 326 | 327 | end 328 | 329 | eval_time = sys.clock() - eval_time 330 | start_time = start_time + eval_time 331 | agent:compute_validation_statistics() 332 | local ind = #reward_history+1 333 | total_reward = total_reward/math.max(1, nepisodes) 334 | 335 | cum_reward_ext = cum_reward_ext / math.max(1,nepisodes) 336 | cum_reward_tot = cum_reward_tot / math.max(1,nepisodes) 337 | 338 | if #reward_history == 0 or total_reward > 
torch.Tensor(reward_history):max() then 339 | agent.best_network_real = agent.network_real:clone() 340 | end 341 | 342 | if agent.v_avg then 343 | v_history[ind] = agent.v_avg 344 | td_history[ind] = agent.tderr_avg 345 | qmax_history[ind] = agent.q_max 346 | end 347 | print("V", v_history[ind], "TD error", td_history[ind], "Qmax", qmax_history[ind]) 348 | 349 | test_avg_R:add{['% Average Extrinsic Reward'] = cum_reward_ext} 350 | test_avg_R2:add{['% Average Total Reward'] = cum_reward_tot} 351 | test_avg_Q:add{['% Average Q'] = agent.v_avg} 352 | 353 | 354 | test_avg_R:style{['% Average Extrinsic Reward'] = '-'}; test_avg_R:plot() 355 | test_avg_R2:style{['% Average Total Reward'] = '-'}; test_avg_R2:plot() 356 | 357 | test_avg_Q:style{['% Average Q'] = '-'}; test_avg_Q:plot() 358 | 359 | reward_history[ind] = total_reward 360 | reward_counts[ind] = nrewards 361 | episode_counts[ind] = nepisodes 362 | 363 | time_history[ind+1] = sys.clock() - start_time 364 | 365 | local time_dif = time_history[ind+1] - time_history[ind] 366 | 367 | local training_rate = opt.actrep*opt.eval_freq/time_dif 368 | 369 | print(string.format( 370 | '\nSteps: %d (frames: %d), extrinsic reward: %.2f, total reward (I+E): %.2f, epsilon: %.2f, lr: %G, ' .. 371 | 'training time: %ds, training rate: %dfps, testing time: %ds, ' .. 372 | 'testing rate: %dfps, num. ep.: %d, num. rewards: %d', 373 | step, step*opt.actrep, cum_reward_ext, cum_reward_tot, agent.ep, agent.lr, time_dif, 374 | training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time, 375 | nepisodes, nrewards)) 376 | end 377 | 378 | if step % opt.save_freq == 0 or step == opt.steps then 379 | local s, a, r, s2, term = agent.valid_s, agent.valid_a, agent.valid_r, 380 | agent.valid_s2, agent.valid_term 381 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 382 | agent.valid_term = nil, nil, nil, nil, nil, nil, nil 383 | local w_real, dw_real, g_real, g2_real, delta, delta2, deltas, deltas_real, tmp_real = agent.w_real, agent.dw_real, 384 | agent.g_real, agent.g2_real, agent.delta, agent.delta2, agent.deltas, agent.deltas_real, agent.tmp_real 385 | agent.w_real, agent.dw_real, agent.g_real, agent.g2_real, agent.delta, agent.delta2, agent.deltas, 386 | agent.deltas_real, agent.tmp_real = nil, nil, nil, nil, nil, nil, nil, nil, nil 387 | 388 | local filename = opt.name 389 | if opt.save_versions > 0 then 390 | filename = filename .. "_" .. math.floor(step / opt.save_versions) 391 | end 392 | filename = filename 393 | torch.save(filename .. ".t7", {agent = agent, 394 | model = agent.network, 395 | best_model = agent.best_network, 396 | model_real = agent.network_real, 397 | best_model_real = agent.best_network_real, 398 | reward_history = reward_history, 399 | reward_counts = reward_counts, 400 | episode_counts = episode_counts, 401 | time_history = time_history, 402 | v_history = v_history, 403 | td_history = td_history, 404 | qmax_history = qmax_history, 405 | arguments=opt}) 406 | if opt.saveNetworkParams then 407 | local nets = {network=w_real:clone():float()} 408 | torch.save(filename..'.params.t7', nets, 'ascii') 409 | end 410 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 411 | agent.valid_term = s, a, r, s2, term 412 | agent.w_real, agent.dw_real, agent.g_real, agent.g2_real, agent.delta, agent.delta2, agent.deltas, 413 | agent.deltas_real, agent.tmp_real = w_real, dw_real, g_real, g2_real, delta, delta2, deltas, deltas_real, tmp_real 414 | print('Saved:', filename .. 
'.t7') 415 | io.flush() 416 | collectgarbage() 417 | end 418 | end 419 | -------------------------------------------------------------------------------- /dqn/unit_tests.lua: -------------------------------------------------------------------------------- 1 | dqn = {} 2 | require 'TransitionTable_spriority' 3 | require 'cutorch' 4 | 5 | seed = torch.random(1,10000) 6 | torch.manualSeed(seed) 7 | print('seed:', seed) 8 | 9 | local function trans_table() 10 | local args = { 11 | stateDim = 10 , numActions = 5, 12 | histLen = 4, gpu = 1, 13 | maxSize = 32, histType = "linear", 14 | histSpacing = 1, nonTermProb = 1, 15 | bufferSize = 16, 16 | subgoal_dims = 1 17 | } 18 | 19 | local transitions = dqn.TransitionTable(args) 20 | for ii=1,100 do 21 | 22 | for i =1,36, 4 do 23 | for j=i,i+5 do 24 | transitions:add(torch.rand(args.stateDim), 2, torch.Tensor({0, 0}), false, 1) 25 | end 26 | transitions:add(torch.rand(args.stateDim), 2, torch.Tensor({1, 2}), true, 1) 27 | end 28 | -- print('table # -> ', transitions:size()) 29 | -- print('\n--------------------') 30 | -- print('END:', transitions.end_ptrs) 31 | -- print('Before sample DYN:', transitions.dyn_ptrs) 32 | -- print('R:', transitions.trace_indxs_with_reward) 33 | 34 | for kk=1,5 do 35 | s, a, r, s2, t, sg, sg2= transitions:sample(8) 36 | -- print(r) 37 | -- print('After sample DYN:', transitions.dyn_ptrs) 38 | end 39 | end 40 | print('[Transition Table with Prioritization] Success!') 41 | end 42 | 43 | trans_table() -- testing transition table 44 | -------------------------------------------------------------------------------- /gifs/breakout.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/gifs/breakout.gif -------------------------------------------------------------------------------- /gifs/enduro.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/gifs/enduro.gif -------------------------------------------------------------------------------- /gifs/enduro.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/gifs/enduro.mp4 -------------------------------------------------------------------------------- /install_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ###################################################################### 4 | # Torch install 5 | ###################################################################### 6 | 7 | 8 | TOPDIR=$PWD 9 | 10 | # Prefix: 11 | PREFIX=$PWD/torch 12 | echo "Installing Torch into: $PREFIX" 13 | 14 | if [[ `uname` != 'Linux' ]]; then 15 | echo 'Platform unsupported, only available for Linux' 16 | exit 17 | fi 18 | if [[ `which apt-get` == '' ]]; then 19 | echo 'apt-get not found, platform not supported' 20 | exit 21 | fi 22 | 23 | ## Install dependencies for Torch: 24 | #sudo apt-get update 25 | #sudo apt-get install -qqy build-essential 26 | #sudo apt-get install -qqy gcc g++ 27 | #sudo apt-get install -qqy cmake 28 | #sudo apt-get install -qqy curl 29 | #sudo apt-get install -qqy libreadline-dev 30 | #sudo apt-get install -qqy git-core 31 | #sudo apt-get install -qqy libjpeg-dev 32 | #sudo apt-get install -qqy 
libpng-dev 33 | #sudo apt-get install -qqy ncurses-dev 34 | #sudo apt-get install -qqy imagemagick 35 | #sudo apt-get install -qqy unzip 36 | #sudo apt-get install -qqy libqt4-dev 37 | #sudo apt-get install -qqy liblua5.1-0-dev 38 | #sudo apt-get install -qqy libgd-dev 39 | #sudo apt-get update 40 | # 41 | # 42 | #echo "==> Torch7's dependencies have been installed" 43 | # 44 | # 45 | # 46 | # 47 | # 48 | ## Build and install Torch7 49 | #cd /tmp 50 | #rm -rf luajit-rocks 51 | #git clone https://github.com/torch/luajit-rocks.git 52 | #cd luajit-rocks 53 | #mkdir -p build 54 | #cd build 55 | #git checkout master; git pull 56 | #rm -f CMakeCache.txt 57 | #cmake .. -DCMAKE_INSTALL_PREFIX=$PREFIX -DCMAKE_BUILD_TYPE=Release 58 | #RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 59 | #make 60 | #RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 61 | #make install 62 | #RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 63 | # 64 | # 65 | #path_to_nvcc=$(which nvcc) 66 | #if [ -x "$path_to_nvcc" ] 67 | #then 68 | # cutorch=ok 69 | # cunn=ok 70 | #fi 71 | # 72 | ## Install base packages: 73 | #$PREFIX/bin/luarocks install cwrap 74 | #$PREFIX/bin/luarocks install paths 75 | #$PREFIX/bin/luarocks install torch 76 | #$PREFIX/bin/luarocks install nn 77 | # 78 | #[ -n "$cutorch" ] && \ 79 | #($PREFIX/bin/luarocks install cutorch) 80 | #[ -n "$cunn" ] && \ 81 | #($PREFIX/bin/luarocks install cunn) 82 | # 83 | #$PREFIX/bin/luarocks install luafilesystem 84 | #$PREFIX/bin/luarocks install penlight 85 | #$PREFIX/bin/luarocks install sys 86 | #$PREFIX/bin/luarocks install xlua 87 | #$PREFIX/bin/luarocks install image 88 | #$PREFIX/bin/luarocks install env 89 | #$PREFIX/bin/luarocks install qtlua 90 | #$PREFIX/bin/luarocks install qttorch 91 | # 92 | #echo "" 93 | #echo "=> Torch7 has been installed successfully" 94 | #echo "" 95 | # 96 | # 97 | #echo "Installing nngraph ... " 98 | #$PREFIX/bin/luarocks install nngraph 99 | #RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 100 | #echo "nngraph installation completed" 101 | # 102 | echo "Installing Xitari ... " 103 | cd /tmp 104 | rm -rf xitari 105 | git clone https://github.com/deepmind/xitari.git 106 | cd xitari 107 | luarocks make 108 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 109 | echo "Xitari installation completed" 110 | 111 | echo "Installing Alewrap ... " 112 | cd /tmp 113 | rm -rf alewrap 114 | git clone https://github.com/deepmind/alewrap.git 115 | cd alewrap 116 | luarocks make 117 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 118 | echo "Alewrap installation completed" 119 | 120 | #echo "Installing Lua-GD ... " 121 | #mkdir src 122 | #cd src 123 | #rm -rf lua-gd 124 | #git clone https://github.com/ittner/lua-gd.git 125 | #cd lua-gd 126 | ##sed -i "s/LUABIN=lua5.1/LUABIN=..\/..\/bin\/luajit/" Makefile 127 | #sudo luarocks make 128 | #RET=$?; if [ $RET -ne 0 ]; then echo "Error. 
Exiting."; exit $RET; fi 129 | #echo "Lua-GD installation completed" 130 | 131 | echo 132 | echo "You can run experiments by executing: " 133 | echo 134 | echo " ./run_cpu game_name" 135 | echo 136 | echo " or " 137 | echo 138 | echo " ./run_gpu game_name" 139 | echo 140 | echo "For this you need to provide the rom files of the respective games (game_name.bin) in the roms/ directory" 141 | echo 142 | 143 | -------------------------------------------------------------------------------- /roms/montezuma_revenge.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/roms/montezuma_revenge.bin -------------------------------------------------------------------------------- /run.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=DQN 3 | #SBATCH --output=slurm_logs/DQN.out 4 | #SBATCH --error=slurm_logs/DQN.err 5 | #cd dqn 6 | #/home/tejask/envs/my_root/bin/python pyserver.py & 7 | #cd .. 8 | #./run_gpu montezuma_revenge 9 | 10 | ./run_exp.sh test 5000 11 | -------------------------------------------------------------------------------- /run_cpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./run_cpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | 9 | game_path=$PWD"/roms/" 10 | env_params="useRGB=true" 11 | agent="NeuralQLearner" 12 | n_replay=1 13 | netfile="\"convnet_atari3\"" 14 | update_freq=4 15 | actrep=4 16 | discount=0.99 17 | seed=1 18 | learn_start=50000 19 | pool_frms_type="\"max\"" 20 | pool_frms_size=2 21 | initial_priority="false" 22 | replay_memory=1000000 23 | eps_end=0.1 24 | eps_endt=replay_memory 25 | lr=0.00025 26 | agent_type="DQN3_0_1" 27 | preproc_net="\"net_downsample_2x_full_y\"" 28 | agent_name=$agent_type"_"$1"_FULL_Y" 29 | state_dim=7056 30 | ncols=1 31 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 32 | steps=50000000 33 | eval_freq=250000 34 | eval_steps=125000 35 | prog_freq=5000 36 | save_freq=125000 37 | gpu=-1 38 | random_starts=30 39 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 40 | num_threads=4 41 | 42 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads" 43 | echo $args 44 | 45 | cd dqn 46 | qlua train_agent.lua $args 47 | -------------------------------------------------------------------------------- /run_exp.sh: -------------------------------------------------------------------------------- 1 | if [ -z "$1" ] 2 | then echo "Please provide the logname and port for running the experiment e.g. ./run_exp basic1 5000 "; exit 0 3 | fi 4 | 5 | if [ -z "$2" ] 6 | then echo "Please provide the logname and port for running the experiment e.g. 
./run_exp basic1 5000 "; exit 0 7 | fi 8 | cd dqn; 9 | python pyserver.py $2 & 10 | cd ..; 11 | ./run_gpu montezuma_revenge $1 $2 $3 $4; -------------------------------------------------------------------------------- /run_exp_multi.sh: -------------------------------------------------------------------------------- 1 | min=$1 2 | max=$2 3 | for seed in 6050; do 4 | for subgoal in $(seq $min $max); do 5 | for usedistance in 'true'; do 6 | temp="seed_${seed}_subgoal_${subgoal}_usedistance_${usedistance}" 7 | ./run_exp.sh $temp $((seed + subgoal)) $subgoal $usedistance & 8 | done 9 | done 10 | done 11 | 12 | -------------------------------------------------------------------------------- /run_gpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game and a log name and a port, e.g. ./run_gpu breakout basic1 5000 "; exit 0 5 | fi 6 | 7 | subgoal_index=${4:-12} # a number between 2 to 11 for now, 12 means not use any subgoal 8 | use_distance=${5:-false} #using distance to subgoal as a reward 9 | 10 | learn_start=5000 #50000 11 | steps=50000000 12 | eval_freq=30000 13 | eval_steps=10000 14 | prog_freq=10000 15 | save_freq=10000 16 | replay_memory=1000000 17 | eps_end=0.1 #0.1 18 | eps_endt=replay_memory 19 | 20 | # learn_start=1024 21 | # steps=50000000 22 | # eval_freq=1024 23 | # eval_steps=1024 24 | # prog_freq=1024 25 | # save_freq=1024 26 | # replay_memory=2000 27 | # eps_end=0.1 #0.1 28 | # eps_endt=500000 #replay_memory 29 | 30 | ENV=$1 31 | FRAMEWORK="alewrap" 32 | game_path=$PWD"/roms/" 33 | env_params="useRGB=true" 34 | agent="NeuralQLearner" 35 | n_replay=1 36 | netfile="\"convnet_atari3\"" 37 | # netfile="\"logs/smoooth_target_q_limited_steps/smoooth_target_q_limited_steps_montezuma_revenge_FULL_Y.t7\"" 38 | update_freq=4 39 | actrep=4 40 | discount=0.99 41 | discount_internal=0.99 42 | dynamic_discount=0.99 #starting value for dynamic discounting scheme 43 | seed=$3 #using port as seed 44 | pool_frms_type="\"max\"" 45 | pool_frms_size=2 46 | initial_priority="false" 47 | 48 | lr=0.00025 49 | agent_type=$2 50 | preproc_net="\"net_downsample_2x_full_y\"" 51 | agent_name=$agent_type"_"$1"_FULL_Y" 52 | state_dim=7056 53 | ncols=1 54 | agent_params="use_distance="$use_distance",lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",dynamic_discount="$dynamic_discount",discount="$discount",discount_internal="$discount_internal",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=64,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=10,min_reward=-1000,max_reward=1000" 55 | 56 | gpu=1 57 | random_starts=1 #need to make this 30 later for random starting points for comparison with original DQN 58 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 59 | num_threads=4 60 | 61 | mkdir dqn/logs/$agent_type; 62 | args="-framework $FRAMEWORK -exp_folder logs/$agent_type -game_path $game_path -name logs/$agent_type/$agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -port $3 -subgoal_index $subgoal_index" 63 | echo $args 64 | 65 | cd dqn 66 | qlua 
train_agent.lua $args 67 | -------------------------------------------------------------------------------- /run_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 3 | #SBATCH -n 2 4 | #SBATCH -p general 5 | #SBATCH --mem 70000 6 | #SBATCH -t 5-23:00 7 | #SBATCH --mail-user=tejask@mit.edu 8 | 9 | cd dqn 10 | /home/tejask/envs/my_root/bin/python pyserver.py & 11 | cd .. 12 | ./run_gpu montezuma_revenge fullrun1 5550 12 false 13 | -------------------------------------------------------------------------------- /run_slurm_multi_exp.sh: -------------------------------------------------------------------------------- 1 | jobname='mz_test' 2 | 3 | for seed in 6000; do 4 | for subgoal in {2..12}; do 5 | for usedistance in 'true' 'false'; do 6 | stdOut=log.${temp}.stdout 7 | stdErr=log.${temp}.stderr 8 | temp="seed_${seed}_subgoal_${subgoal}_usedistance_${usedistance}" 9 | resFile=result.${temp} 10 | stdOut=log.${jobname}.${temp}.stdout 11 | stdErr=log.${jobname}.${temp}.stderr 12 | logRoot=slurm_logs 13 | 14 | 15 | sbatch ${jobRunTime} -o ${logRoot}/${stdOut} -e ${logRoot}/${stdErr} --job-name=${jobname} run_slurm.sh ${seed} ${subgoal} ${usedistance} 16 | sleep 2 17 | 18 | done 19 | done 20 | done 21 | 22 | -------------------------------------------------------------------------------- /stop_server.sh: -------------------------------------------------------------------------------- 1 | ps x | grep python | grep $1 | cut -d' ' -f1 | xargs kill -------------------------------------------------------------------------------- /structured_priority/dqn/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | LIMITED LICENSE: 3 | 4 | Copyright (c) 2014 Google Inc. 5 | Limited License: Under no circumstance is commercial use, reproduction, or 6 | distribution permitted. Use, reproduction, and distribution are permitted 7 | solely for academic use in evaluating and reviewing claims made in 8 | "Human-level control through deep reinforcement learning", Nature 518, 529–533 9 | (26 February 2015) doi:10.1038/nature14236, provided that the following 10 | conditions are met: 11 | 12 | * Any reproduction or distribution of source code must retain the above 13 | copyright notice and the full text of this license including the following 14 | disclaimer.
 15 | 16 | * Any reproduction or distribution in binary form must reproduce the above 17 | copyright notice and the full text of this license including the following 18 | disclaimer
 in the documentation and/or other materials provided with the 19 | distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /structured_priority/dqn/NeuralQLearner.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | if not dqn then 8 | require 'initenv' 9 | end 10 | 11 | local nql = torch.class('dqn.NeuralQLearner') 12 | 13 | 14 | function nql:__init(args) 15 | self.state_dim = args.state_dim -- State dimensionality. 16 | self.actions = args.actions 17 | self.n_actions = #self.actions 18 | self.verbose = args.verbose 19 | self.best = args.best 20 | 21 | --- epsilon annealing 22 | self.ep_start = args.ep or 1 23 | self.ep = self.ep_start -- Exploration probability. 24 | self.ep_end = args.ep_end or self.ep 25 | self.ep_endt = args.ep_endt or 1000000 26 | 27 | ---- learning rate annealing 28 | self.lr_start = args.lr or 0.01 --Learning rate. 29 | self.lr = self.lr_start 30 | self.lr_end = args.lr_end or self.lr 31 | self.lr_endt = args.lr_endt or 1000000 32 | self.wc = args.wc or 0 -- L2 weight cost. 33 | self.minibatch_size = args.minibatch_size or 1 34 | self.valid_size = args.valid_size or 500 35 | 36 | --- Q-learning parameters 37 | self.discount = args.discount or 0.99 --Discount factor. 38 | self.update_freq = args.update_freq or 1 39 | -- Number of points to replay per learning step. 40 | self.n_replay = args.n_replay or 1 41 | -- Number of steps after which learning starts. 42 | self.learn_start = args.learn_start or 0 43 | -- Size of the transition table. 44 | self.replay_memory = args.replay_memory or 1000000 45 | self.hist_len = args.hist_len or 1 46 | self.rescale_r = args.rescale_r 47 | self.max_reward = args.max_reward 48 | self.min_reward = args.min_reward 49 | self.clip_delta = args.clip_delta 50 | self.target_q = args.target_q 51 | self.bestq = 0 52 | 53 | self.gpu = args.gpu 54 | 55 | self.ncols = args.ncols or 1 -- number of color channels in input 56 | self.input_dims = args.input_dims or {self.hist_len*self.ncols, 84, 84} 57 | self.preproc = args.preproc -- name of preprocessing network 58 | self.histType = args.histType or "linear" -- history type to use 59 | self.histSpacing = args.histSpacing or 1 60 | self.nonTermProb = args.nonTermProb or 1 61 | self.bufferSize = args.bufferSize or 512 62 | 63 | self.transition_params = args.transition_params or {} 64 | 65 | self.network = args.network or self:createNetwork() 66 | 67 | -- check whether there is a network file 68 | local network_function 69 | if not (type(self.network) == 'string') then 70 | error("The type of the network provided in NeuralQLearner" .. 71 | " is not a string!") 72 | end 73 | 74 | local msg, err = pcall(require, self.network) 75 | if not msg then 76 | -- try to load saved agent 77 | local err_msg, exp = pcall(torch.load, self.network) 78 | if not err_msg then 79 | error("Could not find network file ") 80 | end 81 | if self.best and exp.best_model then 82 | self.network = exp.best_model 83 | else 84 | self.network = exp.model 85 | end 86 | else 87 | print('Creating Agent Network from ' .. self.network) 88 | self.network = err 89 | self.network = self:network() 90 | end 91 | 92 | if self.gpu and self.gpu >= 0 then 93 | self.network:cuda() 94 | else 95 | self.network:float() 96 | end 97 | 98 | -- Load preprocessing network. 
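    -- The preprocessing module named in -agent_params ("net_downsample_2x_full_y"
    -- in the run scripts) returns nn.Scale(84, 84), which converts each RGB frame
    -- to an 84x84 luminance image; this is why the scripts set state_dim = 7056
    -- (84*84). Note that the guard below evaluates type(self.preproc == 'string'),
    -- i.e. the type of a boolean, so it always passes; the intended check is
    -- presumably type(self.preproc) == 'string'.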
99 | if not (type(self.preproc == 'string')) then 100 | error('The preprocessing is not a string') 101 | end 102 | msg, err = pcall(require, self.preproc) 103 | if not msg then 104 | error("Error loading preprocessing net") 105 | end 106 | self.preproc = err 107 | self.preproc = self:preproc() 108 | self.preproc:float() 109 | 110 | if self.gpu and self.gpu >= 0 then 111 | self.network:cuda() 112 | self.tensor_type = torch.CudaTensor 113 | else 114 | self.network:float() 115 | self.tensor_type = torch.FloatTensor 116 | end 117 | 118 | -- Create transition table. 119 | ---- assuming the transition table always gets floating point input 120 | ---- (Foat or Cuda tensors) and always returns one of the two, as required 121 | ---- internally it always uses ByteTensors for states, scaling and 122 | ---- converting accordingly 123 | local transition_args = { 124 | stateDim = self.state_dim, numActions = self.n_actions, 125 | histLen = self.hist_len, gpu = self.gpu, 126 | maxSize = self.replay_memory, histType = self.histType, 127 | histSpacing = self.histSpacing, nonTermProb = self.nonTermProb, 128 | bufferSize = self.bufferSize 129 | } 130 | 131 | self.transitions = dqn.TransitionTable(transition_args) 132 | 133 | self.numSteps = 0 -- Number of perceived states. 134 | self.lastState = nil 135 | self.lastAction = nil 136 | self.v_avg = 0 -- V running average. 137 | self.tderr_avg = 0 -- TD error running average. 138 | 139 | self.q_max = 1 140 | self.r_max = 1 141 | 142 | self.w, self.dw = self.network:getParameters() 143 | self.dw:zero() 144 | 145 | self.deltas = self.dw:clone():fill(0) 146 | 147 | self.tmp= self.dw:clone():fill(0) 148 | self.g = self.dw:clone():fill(0) 149 | self.g2 = self.dw:clone():fill(0) 150 | 151 | if self.target_q then 152 | self.target_network = self.network:clone() 153 | end 154 | end 155 | 156 | 157 | function nql:reset(state) 158 | if not state then 159 | return 160 | end 161 | self.best_network = state.best_network 162 | self.network = state.model 163 | self.w, self.dw = self.network:getParameters() 164 | self.dw:zero() 165 | self.numSteps = 0 166 | print("RESET STATE SUCCESFULLY") 167 | end 168 | 169 | 170 | function nql:preprocess(rawstate) 171 | if self.preproc then 172 | return self.preproc:forward(rawstate:float()) 173 | :clone():reshape(self.state_dim) 174 | end 175 | 176 | return rawstate 177 | end 178 | 179 | 180 | function nql:getQUpdate(args) 181 | local s, a, r, s2, term, delta 182 | local q, q2, q2_max 183 | 184 | s = args.s 185 | a = args.a 186 | r = args.r 187 | s2 = args.s2 188 | term = args.term 189 | 190 | -- The order of calls to forward is a bit odd in order 191 | -- to avoid unnecessary calls (we only need 2). 192 | 193 | -- delta = r + (1-terminal) * gamma * max_a Q(s2, a) - Q(s, a) 194 | term = term:clone():float():mul(-1):add(1) 195 | 196 | local target_q_net 197 | if self.target_q then 198 | target_q_net = self.target_network 199 | else 200 | target_q_net = self.network 201 | end 202 | 203 | -- Compute max_a Q(s_2, a). 
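    -- target_q_net:forward(s2) yields a (minibatch x n_actions) matrix of
    -- Q-values for the next states; :max(2) reduces over the action dimension,
    -- giving the per-row max_a Q_target(s2, a) used in the one-step target.
    -- When target_q is set (10000 in the run scripts), target_q_net is the
    -- periodically refreshed clone of the online network (see nql:perceive).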
204 | q2_max = target_q_net:forward(s2):float():max(2) 205 | 206 | -- Compute q2 = (1-terminal) * gamma * max_a Q(s2, a) 207 | q2 = q2_max:clone():mul(self.discount):cmul(term) 208 | 209 | delta = r:clone():float() 210 | 211 | if self.rescale_r then 212 | delta:div(self.r_max) 213 | end 214 | delta:add(q2) 215 | 216 | -- q = Q(s,a) 217 | local q_all = self.network:forward(s):float() 218 | q = torch.FloatTensor(q_all:size(1)) 219 | 220 | for i=1,q_all:size(1) do 221 | q[i] = q_all[i][a[i]] 222 | end 223 | delta:add(-1, q) 224 | 225 | if self.clip_delta then 226 | delta[delta:ge(self.clip_delta)] = self.clip_delta 227 | delta[delta:le(-self.clip_delta)] = -self.clip_delta 228 | end 229 | 230 | local targets = torch.zeros(self.minibatch_size, self.n_actions):float() 231 | for i=1,math.min(self.minibatch_size,a:size(1)) do 232 | targets[i][a[i]] = delta[i] 233 | end 234 | 235 | if self.gpu >= 0 then targets = targets:cuda() end 236 | 237 | return targets, delta, q2_max 238 | end 239 | 240 | 241 | function nql:qLearnMinibatch() 242 | -- Perform a minibatch Q-learning update: 243 | -- w += alpha * (r + gamma max Q(s2,a2) - Q(s,a)) * dQ(s,a)/dw 244 | assert(self.transitions:size() > self.minibatch_size) 245 | 246 | local s, a, r, s2, term = self.transitions:sample(self.minibatch_size) 247 | local targets, delta, q2_max = self:getQUpdate{s=s, a=a, r=r, s2=s2, 248 | term=term, update_qmax=true} 249 | 250 | -- zero gradients of parameters 251 | self.dw:zero() 252 | 253 | -- get new gradient 254 | self.network:backward(s, targets) 255 | 256 | -- add weight cost to gradient 257 | self.dw:add(-self.wc, self.w) 258 | 259 | -- compute linearly annealed learning rate 260 | local t = math.max(0, self.numSteps - self.learn_start) 261 | self.lr = (self.lr_start - self.lr_end) * (self.lr_endt - t)/self.lr_endt + 262 | self.lr_end 263 | self.lr = math.max(self.lr, self.lr_end) 264 | 265 | -- use gradients 266 | self.g:mul(0.95):add(0.05, self.dw) 267 | self.tmp:cmul(self.dw, self.dw) 268 | self.g2:mul(0.95):add(0.05, self.tmp) 269 | self.tmp:cmul(self.g, self.g) 270 | self.tmp:mul(-1) 271 | self.tmp:add(self.g2) 272 | self.tmp:add(0.01) 273 | self.tmp:sqrt() 274 | 275 | -- accumulate update 276 | self.deltas:mul(0):addcdiv(self.lr, self.dw, self.tmp) 277 | self.w:add(self.deltas) 278 | end 279 | 280 | 281 | function nql:sample_validation_data() 282 | local s, a, r, s2, term = self.transitions:sample(self.valid_size) 283 | self.valid_s = s:clone() 284 | self.valid_a = a:clone() 285 | self.valid_r = r:clone() 286 | self.valid_s2 = s2:clone() 287 | self.valid_term = term:clone() 288 | end 289 | 290 | 291 | function nql:compute_validation_statistics() 292 | local targets, delta, q2_max = self:getQUpdate{s=self.valid_s, 293 | a=self.valid_a, r=self.valid_r, s2=self.valid_s2, term=self.valid_term} 294 | 295 | self.v_avg = self.q_max * q2_max:mean() 296 | self.tderr_avg = delta:clone():abs():mean() 297 | end 298 | 299 | 300 | function nql:perceive(reward, rawstate, terminal, testing, testing_ep) 301 | -- Preprocess state (will be set to nil if terminal) 302 | local state = self:preprocess(rawstate):float() 303 | local curState 304 | 305 | if self.max_reward then 306 | reward = math.min(reward, self.max_reward) 307 | end 308 | if self.min_reward then 309 | reward = math.max(reward, self.min_reward) 310 | end 311 | if self.rescale_r then 312 | self.r_max = math.max(self.r_max, reward) 313 | end 314 | 315 | self.transitions:add_recent_state(state, terminal) 316 | 317 | local currentFullState = 
self.transitions:get_recent() 318 | 319 | --Store transition s, a, r, s' 320 | if self.lastState and not testing then 321 | self.transitions:add(self.lastState, self.lastAction, reward, 322 | self.lastTerminal, priority) 323 | end 324 | 325 | if self.numSteps == self.learn_start+1 and not testing then 326 | self:sample_validation_data() 327 | end 328 | 329 | curState= self.transitions:get_recent() 330 | curState = curState:resize(1, unpack(self.input_dims)) 331 | 332 | -- Select action 333 | local actionIndex = 1 334 | if not terminal then 335 | actionIndex = self:eGreedy(curState, testing_ep) 336 | end 337 | 338 | self.transitions:add_recent_action(actionIndex) 339 | 340 | --Do some Q-learning updates 341 | if self.numSteps > self.learn_start and not testing and 342 | self.numSteps % self.update_freq == 0 then 343 | for i = 1, self.n_replay do 344 | self:qLearnMinibatch() 345 | end 346 | end 347 | 348 | if not testing then 349 | self.numSteps = self.numSteps + 1 350 | end 351 | 352 | self.lastState = state:clone() 353 | self.lastAction = actionIndex 354 | self.lastTerminal = terminal 355 | 356 | if self.target_q and self.numSteps % self.target_q == 1 then 357 | self.target_network = self.network:clone() 358 | end 359 | 360 | if not terminal then 361 | return actionIndex 362 | else 363 | return 0 364 | end 365 | end 366 | 367 | 368 | function nql:eGreedy(state, testing_ep) 369 | self.ep = testing_ep or (self.ep_end + 370 | math.max(0, (self.ep_start - self.ep_end) * (self.ep_endt - 371 | math.max(0, self.numSteps - self.learn_start))/self.ep_endt)) 372 | -- Epsilon greedy 373 | if torch.uniform() < self.ep then 374 | return torch.random(1, self.n_actions) 375 | else 376 | return self:greedy(state) 377 | end 378 | end 379 | 380 | 381 | function nql:greedy(state) 382 | -- Turn single state into minibatch. Needed for convolutional nets. 383 | if state:dim() == 2 then 384 | assert(false, 'Input must be at least 3D') 385 | state = state:resize(1, state:size(1), state:size(2)) 386 | end 387 | 388 | if self.gpu >= 0 then 389 | state = state:cuda() 390 | end 391 | 392 | local q = self.network:forward(state):float():squeeze() 393 | local maxq = q[1] 394 | local besta = {1} 395 | 396 | -- Evaluate all other actions (with random tie-breaking) 397 | for a = 2, self.n_actions do 398 | if q[a] > maxq then 399 | besta = { a } 400 | maxq = q[a] 401 | elseif q[a] == maxq then 402 | besta[#besta+1] = a 403 | end 404 | end 405 | self.bestq = maxq 406 | 407 | local r = torch.random(1, #besta) 408 | 409 | self.lastAction = besta[r] 410 | 411 | return besta[r] 412 | end 413 | 414 | 415 | function nql:createNetwork() 416 | local n_hid = 128 417 | local mlp = nn.Sequential() 418 | mlp:add(nn.Reshape(self.hist_len*self.ncols*self.state_dim)) 419 | mlp:add(nn.Linear(self.hist_len*self.ncols*self.state_dim, n_hid)) 420 | mlp:add(nn.Rectifier()) 421 | mlp:add(nn.Linear(n_hid, n_hid)) 422 | mlp:add(nn.Rectifier()) 423 | mlp:add(nn.Linear(n_hid, self.n_actions)) 424 | 425 | return mlp 426 | end 427 | 428 | 429 | function nql:_loadNet() 430 | local net = self.network 431 | if self.gpu then 432 | net:cuda() 433 | else 434 | net:float() 435 | end 436 | return net 437 | end 438 | 439 | 440 | function nql:init(arg) 441 | self.actions = arg.actions 442 | self.n_actions = #self.actions 443 | self.network = self:_loadNet() 444 | -- Generate targets. 
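    -- Note: TransitionTable:empty() only reports whether the table is empty
    -- (numEntries == 0); it does not discard stored transitions. Clearing the
    -- replay memory is done by TransitionTable:reset(), so the call below is
    -- effectively a query whose result is ignored.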
445 | self.transitions:empty() 446 | end 447 | 448 | 449 | function nql:report() 450 | print(get_weight_norms(self.network)) 451 | print(get_grad_norms(self.network)) 452 | end 453 | -------------------------------------------------------------------------------- /structured_priority/dqn/Rectifier.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | --[[ Rectified Linear Unit. 8 | 9 | The output is max(0, input). 10 | --]] 11 | 12 | local Rectifier, parent = torch.class('nn.Rectifier', 'nn.Module') 13 | 14 | -- This module accepts minibatches 15 | function Rectifier:updateOutput(input) 16 | return self.output:resizeAs(input):copy(input):abs():add(input):div(2) 17 | end 18 | 19 | function Rectifier:updateGradInput(input, gradOutput) 20 | self.gradInput:resizeAs(self.output) 21 | return self.gradInput:sign(self.output):cmul(gradOutput) 22 | end -------------------------------------------------------------------------------- /structured_priority/dqn/Scale.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "nn" 8 | require "image" 9 | 10 | local scale = torch.class('nn.Scale', 'nn.Module') 11 | 12 | 13 | function scale:__init(height, width) 14 | self.height = height 15 | self.width = width 16 | end 17 | 18 | function scale:forward(x) 19 | local x = x 20 | if x:dim() > 3 then 21 | x = x[1] 22 | end 23 | 24 | x = image.rgb2y(x) 25 | x = image.scale(x, self.width, self.height, 'bilinear') 26 | return x 27 | end 28 | 29 | function scale:updateOutput(input) 30 | return self:forward(input) 31 | end 32 | 33 | function scale:float() 34 | end 35 | -------------------------------------------------------------------------------- /structured_priority/dqn/TransitionTable.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'image' 8 | 9 | local trans = torch.class('dqn.TransitionTable') 10 | 11 | 12 | function trans:__init(args) 13 | self.stateDim = args.stateDim 14 | self.numActions = args.numActions 15 | self.histLen = args.histLen 16 | self.maxSize = args.maxSize or 1024^2 17 | self.bufferSize = args.bufferSize or 1024 18 | self.histType = args.histType or "linear" 19 | self.histSpacing = args.histSpacing or 1 20 | self.zeroFrames = args.zeroFrames or 1 21 | self.nonTermProb = args.nonTermProb or 1 22 | self.nonEventProb = args.nonEventProb or 1 23 | self.gpu = args.gpu 24 | self.numEntries = 0 25 | self.insertIndex = 0 26 | 27 | self.histIndices = {} 28 | local histLen = self.histLen 29 | if self.histType == "linear" then 30 | -- History is the last histLen frames. 31 | self.recentMemSize = self.histSpacing*histLen 32 | for i=1,histLen do 33 | self.histIndices[i] = i*self.histSpacing 34 | end 35 | elseif self.histType == "exp2" then 36 | -- The ith history frame is from 2^(i-1) frames ago. 37 | self.recentMemSize = 2^(histLen-1) 38 | self.histIndices[1] = 1 39 | for i=1,histLen-1 do 40 | self.histIndices[i+1] = self.histIndices[i] + 2^(7-i) 41 | end 42 | elseif self.histType == "exp1.25" then 43 | -- The ith history frame is from 1.25^(i-1) frames ago. 
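        -- Worked example: with histLen = 4 the loop below first produces the
        -- backward offsets {8, 5, 3, 1}, sets recentMemSize = 8, and then flips
        -- them into forward indices {1, 4, 6, 8}, so the agent state is built
        -- from frames 1, 4, 6 and 8 of the 8-frame recent window. (The scripts
        -- in this repo leave histType at its default "linear", so this branch
        -- is normally unused.)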
44 | self.histIndices[histLen] = 1 45 | for i=histLen-1,1,-1 do 46 | self.histIndices[i] = math.ceil(1.25*self.histIndices[i+1])+1 47 | end 48 | self.recentMemSize = self.histIndices[1] 49 | for i=1,histLen do 50 | self.histIndices[i] = self.recentMemSize - self.histIndices[i] + 1 51 | end 52 | end 53 | 54 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 55 | self.a = torch.LongTensor(self.maxSize):fill(0) 56 | self.r = torch.zeros(self.maxSize) 57 | self.t = torch.ByteTensor(self.maxSize):fill(0) 58 | self.action_encodings = torch.eye(self.numActions) 59 | 60 | -- Tables for storing the last histLen states. They are used for 61 | -- constructing the most recent agent state more easily. 62 | self.recent_s = {} 63 | self.recent_a = {} 64 | self.recent_t = {} 65 | 66 | local s_size = self.stateDim*histLen 67 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 68 | self.buf_r = torch.zeros(self.bufferSize) 69 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 70 | self.buf_s = torch.ByteTensor(self.bufferSize, s_size):fill(0) 71 | self.buf_s2 = torch.ByteTensor(self.bufferSize, s_size):fill(0) 72 | 73 | if self.gpu and self.gpu >= 0 then 74 | self.gpu_s = self.buf_s:float():cuda() 75 | self.gpu_s2 = self.buf_s2:float():cuda() 76 | end 77 | end 78 | 79 | 80 | function trans:reset() 81 | self.numEntries = 0 82 | self.insertIndex = 0 83 | end 84 | 85 | 86 | function trans:size() 87 | return self.numEntries 88 | end 89 | 90 | 91 | function trans:empty() 92 | return self.numEntries == 0 93 | end 94 | 95 | 96 | function trans:fill_buffer() 97 | assert(self.numEntries >= self.bufferSize) 98 | -- clear CPU buffers 99 | self.buf_ind = 1 100 | local ind 101 | for buf_ind=1,self.bufferSize do 102 | local s, a, r, s2, term = self:sample_one(1) 103 | self.buf_s[buf_ind]:copy(s) 104 | self.buf_a[buf_ind] = a 105 | self.buf_r[buf_ind] = r 106 | self.buf_s2[buf_ind]:copy(s2) 107 | self.buf_term[buf_ind] = term 108 | end 109 | self.buf_s = self.buf_s:float():div(255) 110 | self.buf_s2 = self.buf_s2:float():div(255) 111 | if self.gpu and self.gpu >= 0 then 112 | self.gpu_s:copy(self.buf_s) 113 | self.gpu_s2:copy(self.buf_s2) 114 | end 115 | end 116 | 117 | 118 | function trans:sample_one() 119 | assert(self.numEntries > 1) 120 | local index 121 | local valid = false 122 | while not valid do 123 | -- start at 2 because of previous action 124 | index = torch.random(2, self.numEntries-self.recentMemSize) 125 | if self.t[index+self.recentMemSize-1] == 0 then 126 | valid = true 127 | end 128 | if self.nonTermProb < 1 and self.t[index+self.recentMemSize] == 0 and 129 | torch.uniform() > self.nonTermProb then 130 | -- Discard non-terminal states with probability (1-nonTermProb). 131 | -- Note that this is the terminal flag for s_{t+1}. 132 | valid = false 133 | end 134 | if self.nonEventProb < 1 and self.t[index+self.recentMemSize] == 0 and 135 | self.r[index+self.recentMemSize-1] == 0 and 136 | torch.uniform() > self.nonTermProb then 137 | -- Discard non-terminal or non-reward states with 138 | -- probability (1-nonTermProb). 
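            -- As with the branch above, this discard applies only when
            -- nonEventProb < 1; note that the probability threshold tested here
            -- is still nonTermProb. With the defaults used in this repo
            -- (nonTermProb = nonEventProb = 1) neither discard ever fires, so
            -- sampling is uniform over the valid indices.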
139 | valid = false 140 | end 141 | end 142 | 143 | return self:get(index) 144 | end 145 | 146 | 147 | function trans:sample(batch_size) 148 | local batch_size = batch_size or 1 149 | assert(batch_size < self.bufferSize) 150 | 151 | if not self.buf_ind or self.buf_ind + batch_size - 1 > self.bufferSize then 152 | self:fill_buffer() 153 | end 154 | 155 | local index = self.buf_ind 156 | 157 | self.buf_ind = self.buf_ind+batch_size 158 | local range = {{index, index+batch_size-1}} 159 | 160 | local buf_s, buf_s2, buf_a, buf_r, buf_term = self.buf_s, self.buf_s2, 161 | self.buf_a, self.buf_r, self.buf_term 162 | if self.gpu and self.gpu >=0 then 163 | buf_s = self.gpu_s 164 | buf_s2 = self.gpu_s2 165 | end 166 | 167 | return buf_s[range], buf_a[range], buf_r[range], buf_s2[range], buf_term[range] 168 | end 169 | 170 | 171 | function trans:concatFrames(index, use_recent) 172 | if use_recent then 173 | s, t = self.recent_s, self.recent_t 174 | else 175 | s, t = self.s, self.t 176 | end 177 | 178 | local fullstate = s[1].new() 179 | fullstate:resize(self.histLen, unpack(s[1]:size():totable())) 180 | 181 | -- Zero out frames from all but the most recent episode. 182 | local zero_out = false 183 | local episode_start = self.histLen 184 | 185 | for i=self.histLen-1,1,-1 do 186 | if not zero_out then 187 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 188 | if t[j] == 1 then 189 | zero_out = true 190 | break 191 | end 192 | end 193 | end 194 | 195 | if zero_out then 196 | fullstate[i]:zero() 197 | else 198 | episode_start = i 199 | end 200 | end 201 | 202 | if self.zeroFrames == 0 then 203 | episode_start = 1 204 | end 205 | 206 | -- Copy frames from the current episode. 207 | for i=episode_start,self.histLen do 208 | fullstate[i]:copy(s[index+self.histIndices[i]-1]) 209 | end 210 | 211 | return fullstate 212 | end 213 | 214 | 215 | function trans:concatActions(index, use_recent) 216 | local act_hist = torch.FloatTensor(self.histLen, self.numActions) 217 | if use_recent then 218 | a, t = self.recent_a, self.recent_t 219 | else 220 | a, t = self.a, self.t 221 | end 222 | 223 | -- Zero out frames from all but the most recent episode. 224 | local zero_out = false 225 | local episode_start = self.histLen 226 | 227 | for i=self.histLen-1,1,-1 do 228 | if not zero_out then 229 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 230 | if t[j] == 1 then 231 | zero_out = true 232 | break 233 | end 234 | end 235 | end 236 | 237 | if zero_out then 238 | act_hist[i]:zero() 239 | else 240 | episode_start = i 241 | end 242 | end 243 | 244 | if self.zeroFrames == 0 then 245 | episode_start = 1 246 | end 247 | 248 | -- Copy frames from the current episode. 
249 | for i=episode_start,self.histLen do 250 | act_hist[i]:copy(self.action_encodings[a[index+self.histIndices[i]-1]]) 251 | end 252 | 253 | return act_hist 254 | end 255 | 256 | 257 | function trans:get_recent() 258 | -- Assumes that the most recent state has been added, but the action has not 259 | return self:concatFrames(1, true):float():div(255) 260 | end 261 | 262 | 263 | function trans:get(index) 264 | local s = self:concatFrames(index) 265 | local s2 = self:concatFrames(index+1) 266 | local ar_index = index+self.recentMemSize-1 267 | 268 | return s, self.a[ar_index], self.r[ar_index], s2, self.t[ar_index+1] 269 | end 270 | 271 | 272 | function trans:add(s, a, r, term) 273 | assert(s, 'State cannot be nil') 274 | assert(a, 'Action cannot be nil') 275 | assert(r, 'Reward cannot be nil') 276 | 277 | -- Incremenet until at full capacity 278 | if self.numEntries < self.maxSize then 279 | self.numEntries = self.numEntries + 1 280 | end 281 | 282 | -- Always insert at next index, then wrap around 283 | self.insertIndex = self.insertIndex + 1 284 | -- Overwrite oldest experience once at capacity 285 | if self.insertIndex > self.maxSize then 286 | self.insertIndex = 1 287 | end 288 | 289 | -- Overwrite (s,a,r,t) at insertIndex 290 | self.s[self.insertIndex] = s:clone():float():mul(255) 291 | self.a[self.insertIndex] = a 292 | self.r[self.insertIndex] = r 293 | if term then 294 | self.t[self.insertIndex] = 1 295 | else 296 | self.t[self.insertIndex] = 0 297 | end 298 | end 299 | 300 | 301 | function trans:add_recent_state(s, term) 302 | local s = s:clone():float():mul(255):byte() 303 | if #self.recent_s == 0 then 304 | for i=1,self.recentMemSize do 305 | table.insert(self.recent_s, s:clone():zero()) 306 | table.insert(self.recent_t, 1) 307 | end 308 | end 309 | 310 | table.insert(self.recent_s, s) 311 | if term then 312 | table.insert(self.recent_t, 1) 313 | else 314 | table.insert(self.recent_t, 0) 315 | end 316 | 317 | -- Keep recentMemSize states. 318 | if #self.recent_s > self.recentMemSize then 319 | table.remove(self.recent_s, 1) 320 | table.remove(self.recent_t, 1) 321 | end 322 | end 323 | 324 | 325 | function trans:add_recent_action(a) 326 | if #self.recent_a == 0 then 327 | for i=1,self.recentMemSize do 328 | table.insert(self.recent_a, 1) 329 | end 330 | end 331 | 332 | table.insert(self.recent_a, a) 333 | 334 | -- Keep recentMemSize steps. 335 | if #self.recent_a > self.recentMemSize then 336 | table.remove(self.recent_a, 1) 337 | end 338 | end 339 | 340 | 341 | --[[ 342 | Override the write function to serialize this class into a file. 343 | We do not want to store anything into the file, just the necessary info 344 | to create an empty transition table. 345 | 346 | @param file (FILE object ) @see torch.DiskFile 347 | --]] 348 | function trans:write(file) 349 | file:writeObject({self.stateDim, 350 | self.numActions, 351 | self.histLen, 352 | self.maxSize, 353 | self.bufferSize, 354 | self.numEntries, 355 | self.insertIndex, 356 | self.recentMemSize, 357 | self.histIndices}) 358 | end 359 | 360 | 361 | --[[ 362 | Override the read function to desearialize this class from file. 363 | Recreates an empty table. 
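Note that write() above stores only the table's configuration and counters,
never the transitions themselves, and read() resets numEntries and insertIndex
to zero, so an agent reloaded from disk resumes with an empty replay memory.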
364 | 365 | @param file (FILE object ) @see torch.DiskFile 366 | --]] 367 | function trans:read(file) 368 | local stateDim, numActions, histLen, maxSize, bufferSize, numEntries, insertIndex, recentMemSize, histIndices = unpack(file:readObject()) 369 | self.stateDim = stateDim 370 | self.numActions = numActions 371 | self.histLen = histLen 372 | self.maxSize = maxSize 373 | self.bufferSize = bufferSize 374 | self.recentMemSize = recentMemSize 375 | self.histIndices = histIndices 376 | self.numEntries = 0 377 | self.insertIndex = 0 378 | 379 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 380 | self.a = torch.LongTensor(self.maxSize):fill(0) 381 | self.r = torch.zeros(self.maxSize) 382 | self.t = torch.ByteTensor(self.maxSize):fill(0) 383 | self.action_encodings = torch.eye(self.numActions) 384 | 385 | -- Tables for storing the last histLen states. They are used for 386 | -- constructing the most recent agent state more easily. 387 | self.recent_s = {} 388 | self.recent_a = {} 389 | self.recent_t = {} 390 | 391 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 392 | self.buf_r = torch.zeros(self.bufferSize) 393 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 394 | self.buf_s = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 395 | self.buf_s2 = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 396 | 397 | if self.gpu and self.gpu >= 0 then 398 | self.gpu_s = self.buf_s:float():cuda() 399 | self.gpu_s2 = self.buf_s2:float():cuda() 400 | end 401 | end 402 | -------------------------------------------------------------------------------- /structured_priority/dqn/TransitionTable_spriority.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'image' 8 | 9 | local trans = torch.class('dqn.TransitionTable') 10 | 11 | 12 | function trans:__init(args) 13 | self.stateDim = args.stateDim 14 | self.numActions = args.numActions 15 | self.histLen = args.histLen 16 | self.maxSize = args.maxSize or 1024^2 17 | self.bufferSize = args.bufferSize or 1024 18 | self.histType = args.histType or "linear" 19 | self.histSpacing = args.histSpacing or 1 20 | self.zeroFrames = args.zeroFrames or 1 21 | self.nonTermProb = args.nonTermProb or 1 22 | self.nonEventProb = args.nonEventProb or 1 23 | self.gpu = args.gpu 24 | self.numEntries = 0 25 | self.insertIndex = 0 26 | self.ptrInsertIndex = 1 27 | 28 | self.histIndices = {} 29 | local histLen = self.histLen 30 | if self.histType == "linear" then 31 | -- History is the last histLen frames. 32 | self.recentMemSize = self.histSpacing*histLen 33 | for i=1,histLen do 34 | self.histIndices[i] = i*self.histSpacing 35 | end 36 | elseif self.histType == "exp2" then 37 | -- The ith history frame is from 2^(i-1) frames ago. 38 | self.recentMemSize = 2^(histLen-1) 39 | self.histIndices[1] = 1 40 | for i=1,histLen-1 do 41 | self.histIndices[i+1] = self.histIndices[i] + 2^(7-i) 42 | end 43 | elseif self.histType == "exp1.25" then 44 | -- The ith history frame is from 1.25^(i-1) frames ago. 
45 | self.histIndices[histLen] = 1 46 | for i=histLen-1,1,-1 do 47 | self.histIndices[i] = math.ceil(1.25*self.histIndices[i+1])+1 48 | end 49 | self.recentMemSize = self.histIndices[1] 50 | for i=1,histLen do 51 | self.histIndices[i] = self.recentMemSize - self.histIndices[i] + 1 52 | end 53 | end 54 | 55 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 56 | self.a = torch.LongTensor(self.maxSize):fill(0) 57 | self.r = torch.zeros(self.maxSize) 58 | self.t = torch.ByteTensor(self.maxSize):fill(0) 59 | self.action_encodings = torch.eye(self.numActions) 60 | self.end_ptrs = {} 61 | self.dyn_ptrs = {} 62 | 63 | -- Tables for storing the last histLen states. They are used for 64 | -- constructing the most recent agent state more easily. 65 | self.recent_s = {} 66 | self.recent_a = {} 67 | self.recent_t = {} 68 | 69 | local s_size = self.stateDim*histLen 70 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 71 | self.buf_r = torch.zeros(self.bufferSize) 72 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 73 | self.buf_s = torch.ByteTensor(self.bufferSize, s_size):fill(0) 74 | self.buf_s2 = torch.ByteTensor(self.bufferSize, s_size):fill(0) 75 | 76 | if self.gpu and self.gpu >= 0 then 77 | self.gpu_s = self.buf_s:float():cuda() 78 | self.gpu_s2 = self.buf_s2:float():cuda() 79 | end 80 | end 81 | 82 | 83 | function trans:reset() 84 | self.numEntries = 0 85 | self.insertIndex = 0 86 | self.ptrInsertIndex = 1 87 | end 88 | 89 | 90 | function trans:size() 91 | return self.numEntries 92 | end 93 | 94 | 95 | function trans:empty() 96 | return self.numEntries == 0 97 | end 98 | 99 | 100 | function trans:fill_buffer() 101 | assert(self.numEntries >= self.bufferSize) 102 | -- clear CPU buffers 103 | self.buf_ind = 1 104 | local ind 105 | for buf_ind=1,self.bufferSize do 106 | local s, a, r, s2, term = self:sample_one(1) 107 | self.buf_s[buf_ind]:copy(s) 108 | self.buf_a[buf_ind] = a 109 | self.buf_r[buf_ind] = r 110 | self.buf_s2[buf_ind]:copy(s2) 111 | self.buf_term[buf_ind] = term 112 | end 113 | self.buf_s = self.buf_s:float():div(255) 114 | self.buf_s2 = self.buf_s2:float():div(255) 115 | if self.gpu and self.gpu >= 0 then 116 | self.gpu_s:copy(self.buf_s) 117 | self.gpu_s2:copy(self.buf_s2) 118 | end 119 | end 120 | 121 | -- TODO : replace 122 | -- function trans:sample_one() 123 | -- assert(self.numEntries > 1) 124 | -- local index 125 | -- local valid = false 126 | -- while not valid do 127 | -- -- start at 2 because of previous action 128 | -- index = torch.random(2, self.numEntries-self.recentMemSize) 129 | -- if self.t[index+self.recentMemSize-1] == 0 then 130 | -- valid = true 131 | -- end 132 | -- if self.nonTermProb < 1 and self.t[index+self.recentMemSize] == 0 and 133 | -- torch.uniform() > self.nonTermProb then 134 | -- -- Discard non-terminal states with probability (1-nonTermProb). 135 | -- -- Note that this is the terminal flag for s_{t+1}. 136 | -- valid = false 137 | -- end 138 | -- if self.nonEventProb < 1 and self.t[index+self.recentMemSize] == 0 and 139 | -- self.r[index+self.recentMemSize-1] == 0 and 140 | -- torch.uniform() > self.nonTermProb then 141 | -- -- Discard non-terminal or non-reward states with 142 | -- -- probability (1-nonTermProb). 
143 | -- valid = false 144 | -- end 145 | -- end 146 | 147 | -- return self:get(index) 148 | -- end 149 | 150 | function trans:sample_one() 151 | assert(self.numEntries > 1) 152 | assert(#self.end_ptrs == #self.dyn_ptrs) 153 | 154 | local index = -1 155 | local indx 156 | while index <= 0 do 157 | indx = torch.random(#self.end_ptrs) 158 | index = self.dyn_ptrs[indx] - self.recentMemSize + 1 159 | end 160 | 161 | self.dyn_ptrs[indx] = self.dyn_ptrs[indx] - 1 162 | if self.dyn_ptrs[indx] <= 0 or self.dyn_ptrs[indx] == self.end_ptrs[indx-1] then 163 | self.dyn_ptrs[indx] = self.end_ptrs[indx] 164 | end 165 | 166 | return self:get(index) 167 | end 168 | 169 | 170 | 171 | function trans:sample(batch_size) 172 | local batch_size = batch_size or 1 173 | assert(batch_size < self.bufferSize) 174 | 175 | if not self.buf_ind or self.buf_ind + batch_size - 1 > self.bufferSize then 176 | self:fill_buffer() 177 | end 178 | 179 | local index = self.buf_ind 180 | 181 | self.buf_ind = self.buf_ind+batch_size 182 | local range = {{index, index+batch_size-1}} 183 | 184 | local buf_s, buf_s2, buf_a, buf_r, buf_term = self.buf_s, self.buf_s2, 185 | self.buf_a, self.buf_r, self.buf_term 186 | if self.gpu and self.gpu >=0 then 187 | buf_s = self.gpu_s 188 | buf_s2 = self.gpu_s2 189 | end 190 | 191 | return buf_s[range], buf_a[range], buf_r[range], buf_s2[range], buf_term[range] 192 | end 193 | 194 | 195 | function trans:concatFrames(index, use_recent) 196 | if use_recent then 197 | s, t = self.recent_s, self.recent_t 198 | else 199 | s, t = self.s, self.t 200 | end 201 | 202 | local fullstate = s[1].new() 203 | fullstate:resize(self.histLen, unpack(s[1]:size():totable())) 204 | 205 | -- Zero out frames from all but the most recent episode. 206 | local zero_out = false 207 | local episode_start = self.histLen 208 | 209 | for i=self.histLen-1,1,-1 do 210 | if not zero_out then 211 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 212 | if t[j] == 1 then 213 | zero_out = true 214 | break 215 | end 216 | end 217 | end 218 | 219 | if zero_out then 220 | fullstate[i]:zero() 221 | else 222 | episode_start = i 223 | end 224 | end 225 | 226 | if self.zeroFrames == 0 then 227 | episode_start = 1 228 | end 229 | 230 | -- Copy frames from the current episode. 231 | for i=episode_start,self.histLen do 232 | fullstate[i]:copy(s[index+self.histIndices[i]-1]) 233 | end 234 | 235 | return fullstate 236 | end 237 | 238 | 239 | function trans:concatActions(index, use_recent) 240 | local act_hist = torch.FloatTensor(self.histLen, self.numActions) 241 | if use_recent then 242 | a, t = self.recent_a, self.recent_t 243 | else 244 | a, t = self.a, self.t 245 | end 246 | 247 | -- Zero out frames from all but the most recent episode. 248 | local zero_out = false 249 | local episode_start = self.histLen 250 | 251 | for i=self.histLen-1,1,-1 do 252 | if not zero_out then 253 | for j=index+self.histIndices[i]-1,index+self.histIndices[i+1]-2 do 254 | if t[j] == 1 then 255 | zero_out = true 256 | break 257 | end 258 | end 259 | end 260 | 261 | if zero_out then 262 | act_hist[i]:zero() 263 | else 264 | episode_start = i 265 | end 266 | end 267 | 268 | if self.zeroFrames == 0 then 269 | episode_start = 1 270 | end 271 | 272 | -- Copy frames from the current episode. 
273 | for i=episode_start,self.histLen do 274 | act_hist[i]:copy(self.action_encodings[a[index+self.histIndices[i]-1]]) 275 | end 276 | 277 | return act_hist 278 | end 279 | 280 | 281 | function trans:get_recent() 282 | -- Assumes that the most recent state has been added, but the action has not 283 | return self:concatFrames(1, true):float():div(255) 284 | end 285 | 286 | 287 | function trans:get(index) 288 | local s = self:concatFrames(index) 289 | local s2 = self:concatFrames(index+1) 290 | local ar_index = index+self.recentMemSize-1 291 | 292 | return s, self.a[ar_index], self.r[ar_index], s2, self.t[ar_index+1] 293 | end 294 | 295 | 296 | function trans:add(s, a, r, term) 297 | assert(s, 'State cannot be nil') 298 | assert(a, 'Action cannot be nil') 299 | assert(r, 'Reward cannot be nil') 300 | 301 | -- Incremenet until at full capacity 302 | if self.numEntries < self.maxSize then 303 | self.numEntries = self.numEntries + 1 304 | end 305 | 306 | -- Always insert at next index, then wrap around 307 | self.insertIndex = self.insertIndex + 1 308 | 309 | 310 | 311 | -- Overwrite oldest experience once at capacity 312 | if self.insertIndex > self.maxSize then 313 | self.insertIndex = 1 314 | self.ptrInsertIndex = 1 315 | end 316 | 317 | -- Overwrite (s,a,r,t) at insertIndex 318 | self.s[self.insertIndex] = s:clone():float():mul(255) 319 | self.a[self.insertIndex] = a 320 | self.r[self.insertIndex] = r 321 | if self.end_ptrs[self.ptrInsertIndex] == self.insertIndex then 322 | table.remove(self.end_ptrs,self.ptrInsertIndex) 323 | table.remove(self.dyn_ptrs,self.ptrInsertIndex) 324 | end 325 | if term then 326 | self.t[self.insertIndex] = 1 327 | table.insert(self.end_ptrs, self.ptrInsertIndex, self.insertIndex) 328 | table.insert(self.dyn_ptrs, self.ptrInsertIndex, self.insertIndex) 329 | self.ptrInsertIndex = self.ptrInsertIndex + 1 330 | else 331 | self.t[self.insertIndex] = 0 332 | end 333 | end 334 | 335 | 336 | function trans:add_recent_state(s, term) 337 | local s = s:clone():float():mul(255):byte() 338 | if #self.recent_s == 0 then 339 | for i=1,self.recentMemSize do 340 | table.insert(self.recent_s, s:clone():zero()) 341 | table.insert(self.recent_t, 1) 342 | end 343 | end 344 | 345 | table.insert(self.recent_s, s) 346 | if term then 347 | table.insert(self.recent_t, 1) 348 | else 349 | table.insert(self.recent_t, 0) 350 | end 351 | 352 | -- Keep recentMemSize states. 353 | if #self.recent_s > self.recentMemSize then 354 | table.remove(self.recent_s, 1) 355 | table.remove(self.recent_t, 1) 356 | end 357 | end 358 | 359 | 360 | function trans:add_recent_action(a) 361 | if #self.recent_a == 0 then 362 | for i=1,self.recentMemSize do 363 | table.insert(self.recent_a, 1) 364 | end 365 | end 366 | 367 | table.insert(self.recent_a, a) 368 | 369 | -- Keep recentMemSize steps. 370 | if #self.recent_a > self.recentMemSize then 371 | table.remove(self.recent_a, 1) 372 | end 373 | end 374 | 375 | 376 | --[[ 377 | Override the write function to serialize this class into a file. 378 | We do not want to store anything into the file, just the necessary info 379 | to create an empty transition table. 
380 | 381 | @param file (FILE object ) @see torch.DiskFile 382 | --]] 383 | function trans:write(file) 384 | file:writeObject({self.stateDim, 385 | self.numActions, 386 | self.histLen, 387 | self.maxSize, 388 | self.bufferSize, 389 | self.numEntries, 390 | self.insertIndex, 391 | self.recentMemSize, 392 | self.histIndices}) 393 | end 394 | 395 | 396 | --[[ 397 | Override the read function to desearialize this class from file. 398 | Recreates an empty table. 399 | 400 | @param file (FILE object ) @see torch.DiskFile 401 | --]] 402 | function trans:read(file) 403 | local stateDim, numActions, histLen, maxSize, bufferSize, numEntries, insertIndex, recentMemSize, histIndices = unpack(file:readObject()) 404 | self.stateDim = stateDim 405 | self.numActions = numActions 406 | self.histLen = histLen 407 | self.maxSize = maxSize 408 | self.bufferSize = bufferSize 409 | self.recentMemSize = recentMemSize 410 | self.histIndices = histIndices 411 | self.numEntries = 0 412 | self.insertIndex = 0 413 | 414 | self.s = torch.ByteTensor(self.maxSize, self.stateDim):fill(0) 415 | self.a = torch.LongTensor(self.maxSize):fill(0) 416 | self.r = torch.zeros(self.maxSize) 417 | self.t = torch.ByteTensor(self.maxSize):fill(0) 418 | self.action_encodings = torch.eye(self.numActions) 419 | 420 | -- Tables for storing the last histLen states. They are used for 421 | -- constructing the most recent agent state more easily. 422 | self.recent_s = {} 423 | self.recent_a = {} 424 | self.recent_t = {} 425 | 426 | self.buf_a = torch.LongTensor(self.bufferSize):fill(0) 427 | self.buf_r = torch.zeros(self.bufferSize) 428 | self.buf_term = torch.ByteTensor(self.bufferSize):fill(0) 429 | self.buf_s = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 430 | self.buf_s2 = torch.ByteTensor(self.bufferSize, self.stateDim * self.histLen):fill(0) 431 | 432 | if self.gpu and self.gpu >= 0 then 433 | self.gpu_s = self.buf_s:float():cuda() 434 | self.gpu_s2 = self.buf_s2:float():cuda() 435 | end 436 | end 437 | -------------------------------------------------------------------------------- /structured_priority/dqn/convnet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "initenv" 8 | 9 | function create_network(args) 10 | 11 | local net = nn.Sequential() 12 | net:add(nn.Reshape(unpack(args.input_dims))) 13 | 14 | --- first convolutional layer 15 | local convLayer = nn.SpatialConvolution 16 | 17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1], 18 | args.filter_size[1], args.filter_size[1], 19 | args.filter_stride[1], args.filter_stride[1],1)) 20 | net:add(args.nl()) 21 | 22 | -- Add convolutional layers 23 | for i=1,(#args.n_units-1) do 24 | -- second convolutional layer 25 | net:add(convLayer(args.n_units[i], args.n_units[i+1], 26 | args.filter_size[i+1], args.filter_size[i+1], 27 | args.filter_stride[i+1], args.filter_stride[i+1])) 28 | net:add(args.nl()) 29 | end 30 | 31 | local nel 32 | if args.gpu >= 0 then 33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims)) 34 | :cuda()):nElement() 35 | else 36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement() 37 | end 38 | 39 | -- reshape all feature planes into a vector per example 40 | net:add(nn.Reshape(nel)) 41 | 42 | -- fully connected layer 43 | net:add(nn.Linear(nel, args.n_hid[1])) 44 | net:add(args.nl()) 45 | local last_layer_size = args.n_hid[1] 46 | 47 | for i=1,(#args.n_hid-1) do 48 | -- add Linear layer 49 | last_layer_size = args.n_hid[i+1] 50 | net:add(nn.Linear(args.n_hid[i], last_layer_size)) 51 | net:add(args.nl()) 52 | end 53 | 54 | -- add the last fully connected layer (to actions) 55 | net:add(nn.Linear(last_layer_size, args.n_actions)) 56 | 57 | if args.gpu >=0 then 58 | net:cuda() 59 | end 60 | if args.verbose >= 2 then 61 | print(net) 62 | print('Convolutional layers flattened output size:', nel) 63 | end 64 | return net 65 | end 66 | -------------------------------------------------------------------------------- /structured_priority/dqn/convnet_atari3.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | 16 | return create_network(args) 17 | end 18 | 19 | -------------------------------------------------------------------------------- /structured_priority/dqn/initenv.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | dqn = {} 7 | 8 | require 'torch' 9 | require 'nn' 10 | require 'nngraph' 11 | require 'nnutils' 12 | require 'image' 13 | require 'Scale' 14 | require 'NeuralQLearner' 15 | 16 | if PRIORITY_FLAG == 0 then 17 | require 'TransitionTable' 18 | else 19 | require 'TransitionTable_spriority' 20 | end 21 | 22 | require 'Rectifier' 23 | 24 | 25 | function torchSetup(_opt) 26 | _opt = _opt or {} 27 | local opt = table.copy(_opt) 28 | assert(opt) 29 | 30 | -- preprocess options: 31 | --- convert options strings to tables 32 | if opt.pool_frms then 33 | opt.pool_frms = str_to_table(opt.pool_frms) 34 | end 35 | if opt.env_params then 36 | opt.env_params = str_to_table(opt.env_params) 37 | end 38 | if opt.agent_params then 39 | opt.agent_params = str_to_table(opt.agent_params) 40 | opt.agent_params.gpu = opt.gpu 41 | opt.agent_params.best = opt.best 42 | opt.agent_params.verbose = opt.verbose 43 | if opt.network ~= '' then 44 | opt.agent_params.network = opt.network 45 | end 46 | end 47 | 48 | --- general setup 49 | opt.tensorType = opt.tensorType or 'torch.FloatTensor' 50 | torch.setdefaulttensortype(opt.tensorType) 51 | if not opt.threads then 52 | opt.threads = 4 53 | end 54 | torch.setnumthreads(opt.threads) 55 | if not opt.verbose then 56 | opt.verbose = 10 57 | end 58 | if opt.verbose >= 1 then 59 | print('Torch Threads:', torch.getnumthreads()) 60 | end 61 | 62 | --- set gpu device 63 | if opt.gpu and opt.gpu >= 0 then 64 | require 'cutorch' 65 | require 'cunn' 66 | if opt.gpu == 0 then 67 | local gpu_id = tonumber(os.getenv('GPU_ID')) 68 | if gpu_id then opt.gpu = gpu_id+1 end 69 | end 70 | if opt.gpu > 0 then cutorch.setDevice(opt.gpu) end 71 | opt.gpu = cutorch.getDevice() 72 | print('Using GPU device id:', opt.gpu-1) 73 | else 74 | opt.gpu = -1 75 | if opt.verbose >= 1 then 76 | print('Using CPU code only. GPU device id:', opt.gpu) 77 | end 78 | end 79 | 80 | --- set up random number generators 81 | -- removing lua RNG; seeding torch RNG with opt.seed and setting cutorch 82 | -- RNG seed to the first uniform random int32 from the previous RNG; 83 | -- this is preferred because using the same seed for both generators 84 | -- may introduce correlations; we assume that both torch RNGs ensure 85 | -- adequate dispersion for different seeds. 
86 | math.random = nil 87 | opt.seed = opt.seed or 1 88 | torch.manualSeed(opt.seed) 89 | if opt.verbose >= 1 then 90 | print('Torch Seed:', torch.initialSeed()) 91 | end 92 | local firstRandInt = torch.random() 93 | if opt.gpu >= 0 then 94 | cutorch.manualSeed(firstRandInt) 95 | if opt.verbose >= 1 then 96 | print('CUTorch Seed:', cutorch.initialSeed()) 97 | end 98 | end 99 | 100 | return opt 101 | end 102 | 103 | 104 | function setup(_opt) 105 | assert(_opt) 106 | 107 | --preprocess options: 108 | --- convert options strings to tables 109 | _opt.pool_frms = str_to_table(_opt.pool_frms) 110 | _opt.env_params = str_to_table(_opt.env_params) 111 | _opt.agent_params = str_to_table(_opt.agent_params) 112 | if _opt.agent_params.transition_params then 113 | _opt.agent_params.transition_params = 114 | str_to_table(_opt.agent_params.transition_params) 115 | end 116 | 117 | --- first things first 118 | local opt = torchSetup(_opt) 119 | 120 | -- load training framework and environment 121 | local framework = require(opt.framework) 122 | assert(framework) 123 | 124 | local gameEnv = framework.GameEnvironment(opt) 125 | local gameActions = gameEnv:getActions() 126 | 127 | -- agent options 128 | _opt.agent_params.actions = gameActions 129 | _opt.agent_params.gpu = _opt.gpu 130 | _opt.agent_params.best = _opt.best 131 | if _opt.network ~= '' then 132 | _opt.agent_params.network = _opt.network 133 | end 134 | _opt.agent_params.verbose = _opt.verbose 135 | if not _opt.agent_params.state_dim then 136 | _opt.agent_params.state_dim = gameEnv:nObsFeature() 137 | end 138 | 139 | local agent = dqn[_opt.agent](_opt.agent_params) 140 | 141 | if opt.verbose >= 1 then 142 | print('Set up Torch using these options:') 143 | for k, v in pairs(opt) do 144 | print(k, v) 145 | end 146 | end 147 | 148 | return gameEnv, gameActions, agent, opt 149 | end 150 | 151 | 152 | 153 | --- other functions 154 | 155 | function str_to_table(str) 156 | if type(str) == 'table' then 157 | return str 158 | end 159 | if not str or type(str) ~= 'string' then 160 | if type(str) == 'table' then 161 | return str 162 | end 163 | return {} 164 | end 165 | local ttr 166 | if str ~= '' then 167 | local ttx=tt 168 | loadstring('tt = {' .. str .. '}')() 169 | ttr = tt 170 | tt = ttx 171 | else 172 | ttr = {} 173 | end 174 | return ttr 175 | end 176 | 177 | function table.copy(t) 178 | if t == nil then return nil end 179 | local nt = {} 180 | for k, v in pairs(t) do 181 | if type(v) == 'table' then 182 | nt[k] = table.copy(v) 183 | else 184 | nt[k] = v 185 | end 186 | end 187 | setmetatable(nt, table.copy(getmetatable(t))) 188 | return nt 189 | end 190 | -------------------------------------------------------------------------------- /structured_priority/dqn/net_downsample_2x_full_y.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "image" 8 | require "Scale" 9 | 10 | local function create_network(args) 11 | -- Y (luminance) 12 | return nn.Scale(84, 84, true) 13 | end 14 | 15 | return create_network 16 | -------------------------------------------------------------------------------- /structured_priority/dqn/nnutils.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "torch" 8 | 9 | function recursive_map(module, field, func) 10 | local str = "" 11 | if module[field] or module.modules then 12 | str = str .. torch.typename(module) .. ": " 13 | end 14 | if module[field] then 15 | str = str .. func(module[field]) 16 | end 17 | if module.modules then 18 | str = str .. "[" 19 | for i, submodule in ipairs(module.modules) do 20 | local submodule_str = recursive_map(submodule, field, func) 21 | str = str .. submodule_str 22 | if i < #module.modules and string.len(submodule_str) > 0 then 23 | str = str .. " " 24 | end 25 | end 26 | str = str .. "]" 27 | end 28 | 29 | return str 30 | end 31 | 32 | function abs_mean(w) 33 | return torch.mean(torch.abs(w:clone():float())) 34 | end 35 | 36 | function abs_max(w) 37 | return torch.abs(w:clone():float()):max() 38 | end 39 | 40 | -- Build a string of average absolute weight values for the modules in the 41 | -- given network. 42 | function get_weight_norms(module) 43 | return "Weight norms:\n" .. recursive_map(module, "weight", abs_mean) .. 44 | "\nWeight max:\n" .. recursive_map(module, "weight", abs_max) 45 | end 46 | 47 | -- Build a string of average absolute weight gradient values for the modules 48 | -- in the given network. 49 | function get_grad_norms(module) 50 | return "Weight grad norms:\n" .. 51 | recursive_map(module, "gradWeight", abs_mean) .. 52 | "\nWeight grad max:\n" .. recursive_map(module, "gradWeight", abs_max) 53 | end 54 | -------------------------------------------------------------------------------- /structured_priority/dqn/test_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | See LICENSE file for full terms of limited license. 4 | ]] 5 | 6 | gd = require "gd" 7 | 8 | if not dqn then 9 | require "initenv" 10 | end 11 | 12 | local cmd = torch.CmdLine() 13 | cmd:text() 14 | cmd:text('Train Agent in Environment:') 15 | cmd:text() 16 | cmd:text('Options:') 17 | 18 | cmd:option('-framework', '', 'name of training framework') 19 | cmd:option('-env', '', 'name of environment to use') 20 | cmd:option('-game_path', '', 'path to environment file (ROM)') 21 | cmd:option('-env_params', '', 'string of environment parameters') 22 | cmd:option('-pool_frms', '', 23 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 24 | cmd:option('-actrep', 1, 'how many times to repeat action') 25 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 26 | 'number of times at the start of each training episode') 27 | 28 | cmd:option('-name', '', 'filename used for saving network and training history') 29 | cmd:option('-network', '', 'reload pretrained network') 30 | cmd:option('-agent', '', 'name of agent file to use') 31 | cmd:option('-agent_params', '', 'string of agent parameters') 32 | cmd:option('-seed', torch.random(1,1000), 'fixed input seed for repeatable experiments') 33 | 34 | cmd:option('-verbose', 2, 35 | 'the higher the level, the more information is printed to screen') 36 | cmd:option('-threads', 1, 'number of BLAS threads') 37 | cmd:option('-gpu', -1, 'gpu flag') 38 | cmd:option('-gif_file', '', 'GIF path to write session screens') 39 | cmd:option('-csv_file', '', 'CSV path to write session data') 40 | 41 | cmd:text() 42 | 43 | local opt = cmd:parse(arg) 44 | 45 | --- General setup. 46 | local game_env, game_actions, agent, opt = setup(opt) 47 | 48 | -- override print to always flush the output 49 | local old_print = print 50 | local print = function(...) 
51 | old_print(...) 52 | io.flush() 53 | end 54 | 55 | -- file names from command line 56 | local gif_filename = opt.gif_file 57 | 58 | -- start a new game 59 | local screen, reward, terminal = game_env:newGame() 60 | 61 | -- compress screen to JPEG with 100% quality 62 | local jpg = image.compressJPG(screen:squeeze(), 100) 63 | -- create gd image from JPEG string 64 | local im = gd.createFromJpegStr(jpg:storage():string()) 65 | -- convert truecolor to palette 66 | im:trueColorToPalette(false, 256) 67 | 68 | -- write GIF header, use global palette and infinite looping 69 | im:gifAnimBegin(gif_filename, true, 0) 70 | -- write first frame 71 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE) 72 | 73 | -- remember the image and show it first 74 | local previm = im 75 | local win = image.display({image=screen}) 76 | 77 | print("Started playing...") 78 | 79 | -- play one episode (game) 80 | while not terminal do 81 | -- if action was chosen randomly, Q-value is 0 82 | agent.bestq = 0 83 | 84 | -- choose the best action 85 | local action_index = agent:perceive(reward, screen, terminal, true, 0.1) 86 | 87 | -- play game in test mode (episodes don't end when losing a life) 88 | screen, reward, terminal = game_env:step(game_actions[action_index], false) 89 | 90 | -- display screen 91 | image.display({image=screen, win=win}) 92 | 93 | -- create gd image from tensor 94 | jpg = image.compressJPG(screen:squeeze(), 100) 95 | im = gd.createFromJpegStr(jpg:storage():string()) 96 | 97 | -- use palette from previous (first) image 98 | im:trueColorToPalette(false, 256) 99 | im:paletteCopy(previm) 100 | 101 | -- write new GIF frame, no local palette, starting from left-top, 7ms delay 102 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE) 103 | -- remember previous screen for optimal compression 104 | previm = im 105 | 106 | end 107 | 108 | -- end GIF animation and close CSV file 109 | gd.gifAnimEnd(gif_filename) 110 | 111 | print("Finished playing, close window to exit!") -------------------------------------------------------------------------------- /structured_priority/dqn/train_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | 8 | require 'optim' 9 | require 'xlua' 10 | 11 | 12 | local cmd = torch.CmdLine() 13 | cmd:text() 14 | cmd:text('Train Agent in Environment:') 15 | cmd:text() 16 | cmd:text('Options:') 17 | 18 | cmd:option('-framework', '', 'name of training framework') 19 | cmd:option('-env', '', 'name of environment to use') 20 | cmd:option('-game_path', '', 'path to environment file (ROM)') 21 | cmd:option('-env_params', '', 'string of environment parameters') 22 | cmd:option('-pool_frms', '', 23 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 24 | cmd:option('-actrep', 1, 'how many times to repeat action') 25 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 
26 | 'number of times at the start of each training episode') 27 | cmd:option('-exp_folder', 'logs/', 'name of folder where current exp state is being stored') 28 | 29 | cmd:option('-name', '', 'filename used for saving network and training history') 30 | cmd:option('-network', '', 'reload pretrained network') 31 | cmd:option('-agent', '', 'name of agent file to use') 32 | cmd:option('-agent_params', '', 'string of agent parameters') 33 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments') 34 | cmd:option('-saveNetworkParams', false, 35 | 'saves the agent network in a separate file') 36 | cmd:option('-prog_freq', 5*10^3, 'frequency of progress output') 37 | cmd:option('-save_freq', 5*10^4, 'the model is saved every save_freq steps') 38 | cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation') 39 | cmd:option('-save_versions', 0, '') 40 | 41 | cmd:option('-steps', 10^5, 'number of training steps to perform') 42 | cmd:option('-eval_steps', 10^5, 'number of evaluation steps') 43 | 44 | cmd:option('-verbose', 2, 45 | 'the higher the level, the more information is printed to screen') 46 | cmd:option('-threads', 1, 'number of BLAS threads') 47 | cmd:option('-gpu', -1, 'gpu flag') 48 | cmd:option('-display_game', false, 'option to display game') 49 | cmd:option('-mode', 0, 'priority sampling on/off') 50 | 51 | 52 | cmd:text() 53 | 54 | local opt = cmd:parse(arg) 55 | PRIORITY_FLAG = opt.mode 56 | if not dqn then 57 | require "initenv" 58 | end 59 | 60 | --- General setup. 61 | local game_env, game_actions, agent, opt = setup(opt) 62 | 63 | -- override print to always flush the output 64 | local old_print = print 65 | local print = function(...) 66 | old_print(...) 67 | io.flush() 68 | end 69 | 70 | local learn_start = agent.learn_start 71 | local start_time = sys.clock() 72 | local reward_counts = {} 73 | local episode_counts = {} 74 | local time_history = {} 75 | local v_history = {} 76 | local qmax_history = {} 77 | local td_history = {} 78 | local reward_history = {} 79 | local step = 0 80 | time_history[1] = 0 81 | 82 | local total_reward 83 | local nrewards 84 | local nepisodes 85 | local episode_reward 86 | 87 | local screen, reward, terminal = game_env:getState() 88 | 89 | print("Iteration ..", step) 90 | local win = nil 91 | while step < opt.steps do 92 | step = step + 1 93 | local action_index = agent:perceive(reward, screen, terminal) 94 | 95 | -- game over? get next game! 96 | if not terminal then 97 | screen, reward, terminal = game_env:step(game_actions[action_index], true) 98 | else 99 | if opt.random_starts > 0 then 100 | screen, reward, terminal = game_env:nextRandomGame() 101 | else 102 | screen, reward, terminal = game_env:newGame() 103 | end 104 | end 105 | 106 | -- display screen 107 | if opt.display_game then win = image.display({image=screen, win=win}) end 108 | 109 | if step % opt.prog_freq == 0 then 110 | assert(step==agent.numSteps, 'trainer step: ' .. step .. 111 | ' & agent.numSteps: ' .. agent.numSteps) 112 | print("Steps: ", step) 113 | agent:report() 114 | collectgarbage() 115 | end 116 | 117 | if step%1000 == 0 then collectgarbage() end 118 | 119 | if step % opt.eval_freq == 0 and step > learn_start then 120 | 121 | screen, reward, terminal = game_env:newGame() 122 | 123 | test_avg_Q = test_avg_Q or optim.Logger(paths.concat(opt.exp_folder , opt.name .. '_normal_test_avgQ.log')) 124 | test_avg_R = test_avg_R or optim.Logger(paths.concat(opt.exp_folder , opt.name .. 
'_normal_test_avgR.log')) 125 | test_avg_R2 = test_avg_R2 or optim.Logger(paths.concat(opt.exp_folder , opt.name .. '_normal_test_avgR2.log')) 126 | 127 | total_reward = 0 128 | nrewards = 0 129 | nepisodes = 0 130 | episode_reward = 0 131 | 132 | local eval_time = sys.clock() 133 | for estep=1,opt.eval_steps do 134 | local action_index = agent:perceive(reward, screen, terminal, true, 0.05) 135 | 136 | -- Play game in test mode (episodes don't end when losing a life) 137 | screen, reward, terminal = game_env:step(game_actions[action_index]) 138 | 139 | -- display screen 140 | if opt.display_game then win = image.display({image=screen, win=win}) end 141 | 142 | if estep%1000 == 0 then collectgarbage() end 143 | 144 | -- record every reward 145 | episode_reward = episode_reward + reward 146 | if reward ~= 0 then 147 | nrewards = nrewards + 1 148 | end 149 | 150 | if terminal then 151 | total_reward = total_reward + episode_reward 152 | episode_reward = 0 153 | nepisodes = nepisodes + 1 154 | screen, reward, terminal = game_env:nextRandomGame() 155 | end 156 | end 157 | 158 | eval_time = sys.clock() - eval_time 159 | start_time = start_time + eval_time 160 | agent:compute_validation_statistics() 161 | local ind = #reward_history+1 162 | total_reward = total_reward/math.max(1, nepisodes) 163 | 164 | if #reward_history == 0 or total_reward > torch.Tensor(reward_history):max() then 165 | agent.best_network = agent.network:clone() 166 | end 167 | 168 | if agent.v_avg then 169 | v_history[ind] = agent.v_avg 170 | td_history[ind] = agent.tderr_avg 171 | qmax_history[ind] = agent.q_max 172 | end 173 | print("V", v_history[ind], "TD error", td_history[ind], "Qmax", qmax_history[ind]) 174 | 175 | 176 | test_avg_R:add{['% Average Extrinsic Reward'] = cum_reward_ext} 177 | test_avg_R2:add{['% Average Total Reward'] = cum_reward_tot} 178 | test_avg_Q:add{['% Average Q'] = agent.v_avg} 179 | 180 | 181 | -- test_avg_R:style{['% Average Extrinsic Reward'] = '-'}; test_avg_R:plot() 182 | -- test_avg_R2:style{['% Average Total Reward'] = '-'}; test_avg_R2:plot() 183 | 184 | -- test_avg_Q:style{['% Average Q'] = '-'}; test_avg_Q:plot() 185 | 186 | 187 | reward_history[ind] = total_reward 188 | reward_counts[ind] = nrewards 189 | episode_counts[ind] = nepisodes 190 | 191 | time_history[ind+1] = sys.clock() - start_time 192 | 193 | local time_dif = time_history[ind+1] - time_history[ind] 194 | 195 | local training_rate = opt.actrep*opt.eval_freq/time_dif 196 | 197 | print(string.format( 198 | '\nSteps: %d (frames: %d), reward: %.2f, epsilon: %.2f, lr: %G, ' .. 199 | 'training time: %ds, training rate: %dfps, testing time: %ds, ' .. 200 | 'testing rate: %dfps, num. ep.: %d, num. 
rewards: %d', 201 | step, step*opt.actrep, total_reward, agent.ep, agent.lr, time_dif, 202 | training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time, 203 | nepisodes, nrewards)) 204 | end 205 | 206 | if step % opt.save_freq == 0 or step == opt.steps then 207 | local s, a, r, s2, term = agent.valid_s, agent.valid_a, agent.valid_r, 208 | agent.valid_s2, agent.valid_term 209 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 210 | agent.valid_term = nil, nil, nil, nil, nil, nil, nil 211 | local w, dw, g, g2, delta, delta2, deltas, tmp = agent.w, agent.dw, 212 | agent.g, agent.g2, agent.delta, agent.delta2, agent.deltas, agent.tmp 213 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 214 | agent.deltas, agent.tmp = nil, nil, nil, nil, nil, nil, nil, nil 215 | 216 | local filename = opt.name 217 | if opt.save_versions > 0 then 218 | filename = filename .. "_" .. math.floor(step / opt.save_versions) 219 | end 220 | filename = filename 221 | torch.save(filename .. ".t7", {agent = agent, 222 | model = agent.network, 223 | best_model = agent.best_network, 224 | reward_history = reward_history, 225 | reward_counts = reward_counts, 226 | episode_counts = episode_counts, 227 | time_history = time_history, 228 | v_history = v_history, 229 | td_history = td_history, 230 | qmax_history = qmax_history, 231 | arguments=opt}) 232 | if opt.saveNetworkParams then 233 | local nets = {network=w:clone():float()} 234 | torch.save(filename..'.params.t7', nets, 'ascii') 235 | end 236 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 237 | agent.valid_term = s, a, r, s2, term 238 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 239 | agent.deltas, agent.tmp = w, dw, g, g2, delta, delta2, deltas, tmp 240 | print('Saved:', filename .. 
'.t7') 241 | io.flush() 242 | collectgarbage() 243 | end 244 | end 245 | -------------------------------------------------------------------------------- /structured_priority/install_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ###################################################################### 4 | # Torch install 5 | ###################################################################### 6 | 7 | 8 | TOPDIR=$PWD 9 | 10 | # Prefix: 11 | PREFIX=$PWD/torch 12 | echo "Installing Torch into: $PREFIX" 13 | 14 | if [[ `uname` != 'Linux' ]]; then 15 | echo 'Platform unsupported, only available for Linux' 16 | exit 17 | fi 18 | if [[ `which apt-get` == '' ]]; then 19 | echo 'apt-get not found, platform not supported' 20 | exit 21 | fi 22 | 23 | # Install dependencies for Torch: 24 | sudo apt-get update 25 | sudo apt-get install -qqy build-essential 26 | sudo apt-get install -qqy gcc g++ 27 | sudo apt-get install -qqy cmake 28 | sudo apt-get install -qqy curl 29 | sudo apt-get install -qqy libreadline-dev 30 | sudo apt-get install -qqy git-core 31 | sudo apt-get install -qqy libjpeg-dev 32 | sudo apt-get install -qqy libpng-dev 33 | sudo apt-get install -qqy ncurses-dev 34 | sudo apt-get install -qqy imagemagick 35 | sudo apt-get install -qqy unzip 36 | sudo apt-get install -qqy libqt4-dev 37 | sudo apt-get install -qqy liblua5.1-0-dev 38 | sudo apt-get install -qqy libgd-dev 39 | sudo apt-get update 40 | 41 | 42 | echo "==> Torch7's dependencies have been installed" 43 | 44 | 45 | 46 | 47 | 48 | # Build and install Torch7 49 | cd /tmp 50 | rm -rf luajit-rocks 51 | git clone https://github.com/torch/luajit-rocks.git 52 | cd luajit-rocks 53 | mkdir -p build 54 | cd build 55 | git checkout master; git pull 56 | rm -f CMakeCache.txt 57 | cmake .. -DCMAKE_INSTALL_PREFIX=$PREFIX -DCMAKE_BUILD_TYPE=Release 58 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 59 | make 60 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 61 | make install 62 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 63 | 64 | 65 | path_to_nvcc=$(which nvcc) 66 | if [ -x "$path_to_nvcc" ] 67 | then 68 | cutorch=ok 69 | cunn=ok 70 | fi 71 | 72 | # Install base packages: 73 | $PREFIX/bin/luarocks install cwrap 74 | $PREFIX/bin/luarocks install paths 75 | $PREFIX/bin/luarocks install torch 76 | $PREFIX/bin/luarocks install nn 77 | 78 | [ -n "$cutorch" ] && \ 79 | ($PREFIX/bin/luarocks install cutorch) 80 | [ -n "$cunn" ] && \ 81 | ($PREFIX/bin/luarocks install cunn) 82 | 83 | $PREFIX/bin/luarocks install luafilesystem 84 | $PREFIX/bin/luarocks install penlight 85 | $PREFIX/bin/luarocks install sys 86 | $PREFIX/bin/luarocks install xlua 87 | $PREFIX/bin/luarocks install image 88 | $PREFIX/bin/luarocks install env 89 | $PREFIX/bin/luarocks install qtlua 90 | $PREFIX/bin/luarocks install qttorch 91 | 92 | echo "" 93 | echo "=> Torch7 has been installed successfully" 94 | echo "" 95 | 96 | 97 | echo "Installing nngraph ... " 98 | $PREFIX/bin/luarocks install nngraph 99 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 100 | echo "nngraph installation completed" 101 | 102 | echo "Installing Xitari ... " 103 | cd /tmp 104 | rm -rf xitari 105 | git clone https://github.com/deepmind/xitari.git 106 | cd xitari 107 | $PREFIX/bin/luarocks make 108 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. 
Exiting."; exit $RET; fi 109 | echo "Xitari installation completed" 110 | 111 | echo "Installing Alewrap ... " 112 | cd /tmp 113 | rm -rf alewrap 114 | git clone https://github.com/deepmind/alewrap.git 115 | cd alewrap 116 | $PREFIX/bin/luarocks make 117 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 118 | echo "Alewrap installation completed" 119 | 120 | echo "Installing Lua-GD ... " 121 | mkdir $PREFIX/src 122 | cd $PREFIX/src 123 | rm -rf lua-gd 124 | git clone https://github.com/ittner/lua-gd.git 125 | cd lua-gd 126 | sed -i "s/LUABIN=lua5.1/LUABIN=..\/..\/bin\/luajit/" Makefile 127 | $PREFIX/bin/luarocks make 128 | RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi 129 | echo "Lua-GD installation completed" 130 | 131 | echo 132 | echo "You can run experiments by executing: " 133 | echo 134 | echo " ./run_cpu game_name" 135 | echo 136 | echo " or " 137 | echo 138 | echo " ./run_gpu game_name" 139 | echo 140 | echo "For this you need to provide the rom files of the respective games (game_name.bin) in the roms/ directory" 141 | echo 142 | 143 | -------------------------------------------------------------------------------- /structured_priority/roms/README: -------------------------------------------------------------------------------- 1 | Rom files should be put in this directory 2 | -------------------------------------------------------------------------------- /structured_priority/roms/breakout.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrkulk/hierarchical-deep-RL/a3dd9407831b215c95ea5c97815e0a8bf639478b/structured_priority/roms/breakout.bin -------------------------------------------------------------------------------- /structured_priority/run_cpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. 
./run_cpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | 9 | game_path=$PWD"/roms/" 10 | env_params="useRGB=true" 11 | agent="NeuralQLearner" 12 | n_replay=1 13 | netfile="\"convnet_atari3\"" 14 | update_freq=4 15 | actrep=4 16 | discount=0.99 17 | seed=1 18 | learn_start=50000 19 | pool_frms_type="\"max\"" 20 | pool_frms_size=2 21 | initial_priority="false" 22 | replay_memory=1000000 23 | eps_end=0.1 24 | eps_endt=replay_memory 25 | lr=0.00025 26 | agent_type="DQN3_0_1" 27 | preproc_net="\"net_downsample_2x_full_y\"" 28 | agent_name=$agent_type"_"$1"_FULL_Y" 29 | state_dim=7056 30 | ncols=1 31 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 32 | steps=50000000 33 | eval_freq=250000 34 | eval_steps=125000 35 | prog_freq=5000 36 | save_freq=125000 37 | gpu=-1 38 | random_starts=30 39 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 40 | num_threads=4 41 | 42 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads" 43 | echo $args 44 | 45 | cd dqn 46 | ../torch/bin/qlua train_agent.lua $args 47 | -------------------------------------------------------------------------------- /structured_priority/run_gpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. 
./run_gpu breakout name="" mode=0/1"; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | NAME=$2 9 | MODE=$3 10 | 11 | game_path=$PWD"/roms/" 12 | env_params="useRGB=true" 13 | agent="NeuralQLearner" 14 | n_replay=1 15 | netfile="\"convnet_atari3\"" 16 | update_freq=4 17 | actrep=4 18 | discount=0.99 19 | seed=1 20 | learn_start=50000 21 | pool_frms_type="\"max\"" 22 | pool_frms_size=2 23 | initial_priority="false" 24 | replay_memory=1000000 25 | eps_end=0.1 26 | eps_endt=replay_memory 27 | lr=0.00025 28 | agent_type="PRIO="$MODE"_"$NAME"_DQN3_0_1" 29 | preproc_net="\"net_downsample_2x_full_y\"" 30 | agent_name=$agent_type"_"$1"_FULL_Y" 31 | state_dim=7056 32 | ncols=1 33 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 34 | steps=50000000 35 | eval_freq=250000 36 | eval_steps=125000 37 | prog_freq=10000 38 | save_freq=125000 39 | gpu=0 40 | random_starts=30 41 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 42 | num_threads=4 43 | 44 | args="-framework $FRAMEWORK -mode $MODE -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads" 45 | echo $args 46 | 47 | cd dqn 48 | th train_agent.lua $args 49 | -------------------------------------------------------------------------------- /structured_priority/runner_slurm.py: -------------------------------------------------------------------------------- 1 | #run script using slurm 2 | import os 3 | 4 | if not os.path.exists("slurm_logs"): 5 | os.makedirs("slurm_logs") 6 | 7 | if not os.path.exists("slurm_scripts"): 8 | os.makedirs("slurm_scripts") 9 | 10 | # Don't give it a `save` name - that gets generated for you 11 | jobs = [ 12 | [ 13 | 'breakout', #game 14 | 'exp1', #game name 15 | 0 #priority on/off 16 | ], 17 | [ 18 | 'breakout', 19 | 'exp1', 20 | 1 21 | ], 22 | ] 23 | 24 | for jj in range(len(jobs)): 25 | jobname = "RL" 26 | flagstring = "" 27 | for ii in range(len(jobs[jj])): 28 | jobname = jobname + "_" + str(jobs[jj][ii]) 29 | flagstring = flagstring + " " + str(jobs[jj][ii]) 30 | 31 | 32 | if not os.path.exists("slurm_logs/" + jobname): 33 | os.makedirs("slurm_logs/" + jobname) 34 | 35 | with open('slurm_scripts/' + jobname + '.slurm', 'w') as slurmfile: 36 | slurmfile.write("#!/bin/bash\n") 37 | slurmfile.write("#SBATCH --job-name"+"=" + jobname + "\n") 38 | slurmfile.write("#SBATCH --output=slurm_logs/" + jobname + ".out\n") 39 | slurmfile.write("#SBATCH --error=slurm_logs/" + jobname + ".err\n") 40 | slurmfile.write("./run_gpu" + flagstring) 41 | 42 | print ("./run_gpu" + flagstring) 43 | if False: 44 | os.system("sbatch --qos=cbmm --mem=40000 -N 1 -c 2 --gres=gpu:1 --time=5-00:00:00 slurm_scripts/" + jobname + ".slurm &") 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /structured_priority/slurm_scripts/RL_breakout_exp1_0.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH 
--job-name=RL_breakout_exp1_0 3 | #SBATCH --output=slurm_logs/RL_breakout_exp1_0.out 4 | #SBATCH --error=slurm_logs/RL_breakout_exp1_0.err 5 | ./run_gpu breakout exp1 0 -------------------------------------------------------------------------------- /structured_priority/slurm_scripts/RL_breakout_exp1_1.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=RL_breakout_exp1_1 3 | #SBATCH --output=slurm_logs/RL_breakout_exp1_1.out 4 | #SBATCH --error=slurm_logs/RL_breakout_exp1_1.err 5 | ./run_gpu breakout exp1 1 -------------------------------------------------------------------------------- /structured_priority/slurm_scripts/RL_breakout_exp2_0.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=RL_breakout_exp2_0 3 | #SBATCH --output=slurm_logs/RL_breakout_exp2_0.out 4 | #SBATCH --error=slurm_logs/RL_breakout_exp2_0.err 5 | ./run_gpu breakout exp2 0 -------------------------------------------------------------------------------- /structured_priority/slurm_scripts/RL_breakout_exp2_1.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=RL_breakout_exp2_1 3 | #SBATCH --output=slurm_logs/RL_breakout_exp2_1.out 4 | #SBATCH --error=slurm_logs/RL_breakout_exp2_1.err 5 | ./run_gpu breakout exp2 1 -------------------------------------------------------------------------------- /structured_priority/test_cpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0 5 | fi 6 | 7 | if [ -z "$2" ] 8 | then echo "Please provide the pretrained network file, e.g. 
./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0 9 | fi 10 | 11 | ENV=$1 12 | NETWORK=$2 13 | FRAMEWORK="alewrap" 14 | 15 | game_path=$PWD"/roms/" 16 | env_params="useRGB=true" 17 | agent="NeuralQLearner" 18 | n_replay=1 19 | netfile="\"convnet_atari3\"" 20 | update_freq=4 21 | actrep=4 22 | discount=0.99 23 | seed=1 24 | learn_start=50000 25 | pool_frms_type="\"max\"" 26 | pool_frms_size=2 27 | initial_priority="false" 28 | replay_memory=1000000 29 | eps_end=0.1 30 | eps_endt=replay_memory 31 | lr=0.00025 32 | agent_type="DQN3_0_1" 33 | preproc_net="\"net_downsample_2x_full_y\"" 34 | agent_name=$agent_type"_"$1"_FULL_Y" 35 | state_dim=7056 36 | ncols=1 37 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 38 | gif_file="../gifs/$ENV.gif" 39 | gpu=-1 40 | random_starts=30 41 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 42 | num_threads=4 43 | 44 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file" 45 | echo $args 46 | 47 | cd dqn 48 | ../torch/bin/qlua test_agent.lua $args 49 | -------------------------------------------------------------------------------- /structured_priority/test_gpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0 5 | fi 6 | 7 | if [ -z "$2" ] 8 | then echo "Please provide the pretrained network file, e.g. 
./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0 9 | fi 10 | 11 | ENV=$1 12 | NETWORK=$2 13 | FRAMEWORK="alewrap" 14 | 15 | game_path=$PWD"/roms/" 16 | env_params="useRGB=true" 17 | agent="NeuralQLearner" 18 | n_replay=1 19 | netfile="\"convnet_atari3\"" 20 | update_freq=4 21 | actrep=4 22 | discount=0.99 23 | seed=1 24 | learn_start=50000 25 | pool_frms_type="\"max\"" 26 | pool_frms_size=2 27 | initial_priority="false" 28 | replay_memory=1000000 29 | eps_end=0.1 30 | eps_endt=replay_memory 31 | lr=0.00025 32 | agent_type="DQN3_0_1" 33 | preproc_net="\"net_downsample_2x_full_y\"" 34 | agent_name=$agent_type"_"$1"_FULL_Y" 35 | state_dim=7056 36 | ncols=1 37 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 38 | gif_file="../gifs/$ENV.gif" 39 | gpu=0 40 | random_starts=30 41 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 42 | num_threads=4 43 | 44 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file" 45 | echo $args 46 | 47 | cd dqn 48 | qlua test_agent.lua $args 49 | -------------------------------------------------------------------------------- /test_cpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0 5 | fi 6 | 7 | if [ -z "$2" ] 8 | then echo "Please provide the pretrained network file, e.g. 
./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0 9 | fi 10 | 11 | ENV=$1 12 | NETWORK=$2 13 | FRAMEWORK="alewrap" 14 | 15 | game_path=$PWD"/roms/" 16 | env_params="useRGB=true" 17 | agent="NeuralQLearner" 18 | n_replay=1 19 | netfile="\"convnet_atari3\"" 20 | update_freq=4 21 | actrep=4 22 | discount=0.99 23 | seed=1 24 | learn_start=50000 25 | pool_frms_type="\"max\"" 26 | pool_frms_size=2 27 | initial_priority="false" 28 | replay_memory=1000000 29 | eps_end=0.1 30 | eps_endt=replay_memory 31 | lr=0.00025 32 | agent_type="DQN3_0_1" 33 | preproc_net="\"net_downsample_2x_full_y\"" 34 | agent_name=$agent_type"_"$1"_FULL_Y" 35 | state_dim=7056 36 | ncols=1 37 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 38 | gif_file="../gifs/$ENV.gif" 39 | gpu=-1 40 | random_starts=30 41 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 42 | num_threads=4 43 | 44 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file" 45 | echo $args 46 | 47 | cd dqn 48 | ../torch/bin/qlua test_agent.lua $args 49 | -------------------------------------------------------------------------------- /test_exp.sh: -------------------------------------------------------------------------------- 1 | if [ -z "$1" ] 2 | then echo "Please provide the logname and port for testing the experiment e.g. ./test_exp basic1 5000 "; exit 0 3 | fi 4 | cd dqn; 5 | python pyserver.py $2 & 6 | cd ..; 7 | ./run_gpu montezuma_revenge $1 $2; -------------------------------------------------------------------------------- /test_gpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #./test_gpu montezuma_revenge logs/basic1/basic1_montezuma_revenge_FULL_Y.t7 5550 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0 5 | fi 6 | 7 | if [ -z "$2" ] 8 | then echo "Please provide the pretrained network file, e.g. 
./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0 9 | fi 10 | 11 | if [ -z "$3" ] 12 | then echo "Please provide the zmq port for testing"; exit 0 13 | fi 14 | 15 | 16 | ENV=$1 17 | NETWORK=$2 18 | PORT=$3 19 | FRAMEWORK="alewrap" 20 | 21 | game_path=$PWD"/roms/" 22 | env_params="useRGB=true" 23 | agent="NeuralQLearner" 24 | n_replay=1 25 | netfile="\"convnet_atari3\"" 26 | update_freq=4 27 | actrep=4 28 | discount=0.99 29 | seed=1 30 | learn_start=50000 31 | pool_frms_type="\"max\"" 32 | pool_frms_size=2 33 | initial_priority="false" 34 | replay_memory=1000000 35 | eps_end=0.1 36 | eps_endt=replay_memory 37 | lr=0.00025 38 | agent_type="DQN3_0_1" 39 | preproc_net="\"net_downsample_2x_full_y\"" 40 | agent_name=$agent_type"_"$1"_FULL_Y" 41 | state_dim=7056 42 | ncols=1 43 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 44 | gif_file="../gifs/$ENV.gif" 45 | gpu=0 46 | random_starts=1 47 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 48 | num_threads=4 49 | 50 | args="-framework $FRAMEWORK -port $PORT -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file" 51 | echo $args 52 | 53 | cd dqn 54 | qlua test_agent.lua $args 55 | --------------------------------------------------------------------------------
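Note on the option-string convention shared by the launcher scripts above (run_cpu, run_gpu, test_cpu, test_gpu): values such as agent_params and pool_frms are passed to train_agent.lua / test_agent.lua as comma-separated key=value strings and converted into Lua tables by str_to_table in dqn/initenv.lua, which wraps the string in a table constructor and evaluates it with loadstring. A minimal standalone sketch of that mechanism, assuming Lua 5.1 / LuaJIT (where loadstring is available); the snippet is illustrative only and not part of the repo:

    -- the pool_frms string exactly as run_cpu / run_gpu assemble it
    local pool_frms = 'type="max",size=2'
    -- str_to_table-style parsing: evaluate the string inside a table constructor
    local chunk = loadstring('tt = {' .. pool_frms .. '}')
    chunk()                  -- defines the global table `tt`
    print(tt.type, tt.size)  --> max   2

The long agent_params string (lr=..., ep=1, ..., clip_delta=1) reaches NeuralQLearner the same way: setup() in dqn/initenv.lua converts it with str_to_table and hands the resulting table of hyperparameters to the agent constructor.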