├── figures ├── pal_asterix.png ├── a3c_beam_rider.png ├── dqn_space_invaders.png ├── doubleq_space_invaders.png └── dueling_space_invaders.png ├── test ├── testAll.lua ├── testBinaryHeap.lua └── testExperience.lua ├── async ├── AbstractAgent.lua ├── AsyncModel.lua ├── SarsaAgent.lua ├── AsyncEvaluation.lua ├── NStepQAgent.lua ├── OneStepQAgent.lua ├── QAgent.lua ├── AsyncAgent.lua ├── A3CAgent.lua ├── AsyncMaster.lua └── ValidationAgent.lua ├── examples ├── GridWorldNet.lua └── GridWorldVis.lua ├── .gitignore ├── modules ├── GradientRescale.lua ├── GuidedReLU.lua ├── sharedRmsProp.lua ├── HuberCriterion.lua ├── DeconvnetReLU.lua ├── MinDim.lua ├── DuelAggregator.lua └── rmspropm.lua ├── main.lua ├── structures ├── Singleton.lua ├── CircularQueue.lua └── BinaryHeap.lua ├── models ├── Catch.lua ├── Atari2013.lua └── Atari.lua ├── LICENSE.md ├── Evaluator.lua ├── Display.lua ├── .travis.yml ├── roms └── README.md ├── run.sh ├── Validation.lua ├── CONTRIBUTING.md ├── Master.lua ├── README.md ├── Model.lua ├── Experience.lua ├── Setup.lua └── Agent.lua /figures/pal_asterix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaixhin/Atari/HEAD/figures/pal_asterix.png -------------------------------------------------------------------------------- /figures/a3c_beam_rider.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaixhin/Atari/HEAD/figures/a3c_beam_rider.png -------------------------------------------------------------------------------- /figures/dqn_space_invaders.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaixhin/Atari/HEAD/figures/dqn_space_invaders.png -------------------------------------------------------------------------------- /figures/doubleq_space_invaders.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaixhin/Atari/HEAD/figures/doubleq_space_invaders.png -------------------------------------------------------------------------------- /figures/dueling_space_invaders.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaixhin/Atari/HEAD/figures/dueling_space_invaders.png -------------------------------------------------------------------------------- /test/testAll.lua: -------------------------------------------------------------------------------- 1 | require 'torch' -- on travis luajit is invoked and this is needed 2 | 3 | tester = torch.Tester() 4 | 5 | tester:add(require 'test/testBinaryHeap') 6 | tester:add(require 'test/testExperience') 7 | 8 | tester:run() -------------------------------------------------------------------------------- /async/AbstractAgent.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | 3 | local AbstractAgent = classic.class('AbstractAgent') 4 | 5 | 6 | AbstractAgent:mustHave('observe') 7 | AbstractAgent:mustHave('training') 8 | AbstractAgent:mustHave('evaluate') 9 | 10 | return AbstractAgent -------------------------------------------------------------------------------- /examples/GridWorldNet.lua: -------------------------------------------------------------------------------- 1 | local nn = require 'nn' 2 | require 'classic.torch' -- Enables serialisation 3 | 4 | local Body = classic.class('Body') 5 | 6 | -- Constructor 7 | 
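-- Example network body for the GridWorld environment: GridWorldVis exposes a
-- 2-element position vector as the state (see examples/GridWorldVis.lua), so the
-- trunk returned by createBody below is simply View(2) -> Linear(2, 32) -> ReLU.
-- Presumably selected with the -modelBody option, as run.sh does with
-- -modelBody models.Atari for the Atari trunks.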
function Body:_init(opts) 8 | opts = opts or {} 9 | end 10 | 11 | function Body:createBody() 12 | local net = nn.Sequential() 13 | net:add(nn.View(2)) 14 | net:add(nn.Linear(2, 32)) 15 | net:add(nn.ReLU(true)) 16 | 17 | return net 18 | end 19 | 20 | return Body 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Lua sources 2 | luac.out 3 | 4 | # luarocks build files 5 | *.src.rock 6 | *.zip 7 | *.tar.gz 8 | 9 | # Object files 10 | *.o 11 | *.os 12 | *.ko 13 | *.obj 14 | *.elf 15 | 16 | # Precompiled Headers 17 | *.gch 18 | *.pch 19 | 20 | # Libraries 21 | *.lib 22 | *.a 23 | *.la 24 | *.lo 25 | *.def 26 | *.exp 27 | 28 | # Shared objects (inc. Windows DLLs) 29 | *.dll 30 | *.so 31 | *.so.* 32 | *.dylib 33 | 34 | # Executables 35 | *.exe 36 | *.out 37 | *.app 38 | *.i*86 39 | *.x86_64 40 | *.hex 41 | core.* 42 | 43 | # ROMs 44 | roms/*.bin 45 | # Experiments 46 | experiments 47 | # Scratch space (for recordings) 48 | scratch/*.jpg 49 | # Videos 50 | videos/*.webm 51 | -------------------------------------------------------------------------------- /modules/GradientRescale.lua: -------------------------------------------------------------------------------- 1 | local GradientRescale, parent = torch.class('nn.GradientRescale', 'nn.Module') 2 | 3 | function GradientRescale:__init(scaleFactor, inplace) 4 | parent.__init(self) 5 | self.scaleFactor = scaleFactor 6 | self.inplace = inplace 7 | end 8 | 9 | function GradientRescale:updateOutput(input) 10 | self.output = input 11 | return self.output 12 | end 13 | 14 | function GradientRescale:updateGradInput(input, gradOutput) 15 | if self.inplace then 16 | self.gradInput = gradOutput 17 | else 18 | self.gradInput:resizeAs(gradOutput) 19 | self.gradInput:copy(gradOutput) 20 | end 21 | self.gradInput:mul(self.scaleFactor) 22 | return self.gradInput 23 | end 24 | -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | local Setup = require 'Setup' 2 | local Master = require 'Master' 3 | local AsyncMaster = require 'async/AsyncMaster' 4 | local AsyncEvaluation = require 'async/AsyncEvaluation' 5 | 6 | -- Parse options and perform setup 7 | local setup = Setup(arg) 8 | local opt = setup.opt 9 | 10 | -- Start master experiment runner 11 | if opt.async then 12 | if opt.mode == 'train' then 13 | local master = AsyncMaster(opt) 14 | master:start() 15 | elseif opt.mode == 'eval' then 16 | local eval = AsyncEvaluation(opt) 17 | eval:evaluate() 18 | end 19 | else 20 | local master = Master(opt) 21 | 22 | if opt.mode == 'train' then 23 | master:train() 24 | elseif opt.mode == 'eval' then 25 | master:evaluate() 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /modules/GuidedReLU.lua: -------------------------------------------------------------------------------- 1 | local GuidedReLU, parent = torch.class('nn.GuidedReLU', 'nn.ReLU') 2 | 3 | function GuidedReLU:__init(p) 4 | parent.__init(self, p) 5 | self.guide = false 6 | end 7 | 8 | function GuidedReLU:updateOutput(input) 9 | return parent.updateOutput(self, input) 10 | end 11 | 12 | function GuidedReLU:updateGradInput(input, gradOutput) 13 | parent.updateGradInput(self, input, gradOutput) 14 | if self.guide then 15 | -- Only backpropagate positive error signals 16 | 
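-- Guided backpropagation: parent.updateGradInput above already zeroes gradients
-- where the forward input was <= 0 (the standard ReLU mask); the extra mask below
-- additionally zeroes positions where the incoming gradient is <= 0, so only
-- positive error signals reach earlier layers (used for saliency visualisation).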
self.gradInput:cmul(torch.gt(gradOutput, 0):typeAs(gradOutput)) 17 | end 18 | return self.gradInput 19 | end 20 | 21 | function GuidedReLU:salientBackprop() 22 | self.guide = true 23 | end 24 | 25 | function GuidedReLU:normalBackprop() 26 | self.guide = false 27 | end 28 | -------------------------------------------------------------------------------- /structures/Singleton.lua: -------------------------------------------------------------------------------- 1 | local class = require 'classic' 2 | require 'classic.torch' -- Enables serialisation 3 | 4 | local Singleton = classic.class('Singleton') 5 | 6 | function Singleton:_init(fields) 7 | -- Check for existing object 8 | if not Singleton.getInstance() then 9 | -- Populate new object with data 10 | for k, v in pairs(fields) do 11 | self[k] = v 12 | end 13 | 14 | -- Set static instance 15 | Singleton.static.instance = self 16 | end 17 | end 18 | 19 | -- Gets static instance 20 | function Singleton.static.getInstance() 21 | return Singleton.static.instance 22 | end 23 | 24 | -- Sets static instance 25 | function Singleton.static.setInstance(inst) 26 | Singleton.static.instance = inst 27 | end 28 | 29 | return Singleton 30 | -------------------------------------------------------------------------------- /test/testBinaryHeap.lua: -------------------------------------------------------------------------------- 1 | local BinaryHeap = require 'structures/BinaryHeap' 2 | local tds = require 'tds' 3 | 4 | local Test = torch.TestSuite() 5 | local standalone = tester == nil 6 | if standalone then 7 | tester = torch.Tester() 8 | end 9 | 10 | 11 | function Test:BinaryHeap_Test() 12 | local heap = BinaryHeap(1000) 13 | local vec = tds.Vec() 14 | 15 | for i=1,100 do 16 | local r = torch.random(100) 17 | vec[#vec+1] = r 18 | heap:insert(r,r*2) 19 | end 20 | 21 | vec:sort(function(a,b) return a > b end) 22 | 23 | tester:eq(heap:findMax(), vec[1]) 24 | 25 | for i=1,100 do 26 | local entry = heap:pop() 27 | local r = vec[i] 28 | 29 | tester:eq(entry[1], r) 30 | tester:eq(entry[2], r*2) 31 | end 32 | end 33 | 34 | 35 | 36 | if standalone then 37 | tester:add(Test) 38 | tester:run() 39 | end 40 | 41 | return Test 42 | -------------------------------------------------------------------------------- /models/Catch.lua: -------------------------------------------------------------------------------- 1 | local nn = require 'nn' 2 | require 'classic.torch' -- Enables serialisation 3 | 4 | local Body = classic.class('Body') 5 | 6 | -- Constructor 7 | function Body:_init(opts) 8 | opts = opts or {} 9 | 10 | self.recurrent = opts.recurrent 11 | self.histLen = opts.histLen 12 | self.stateSpec = opts.stateSpec 13 | end 14 | 15 | function Body:createBody() 16 | -- Number of input frames for recurrent networks is always 1 17 | local histLen = self.recurrent and 1 or self.histLen 18 | local net = nn.Sequential() 19 | net:add(nn.View(histLen*self.stateSpec[2][1], self.stateSpec[2][2], self.stateSpec[2][3])) 20 | net:add(nn.SpatialConvolution(histLen*self.stateSpec[2][1], 32, 5, 5, 2, 2, 1, 1)) 21 | net:add(nn.ReLU(true)) 22 | net:add(nn.SpatialConvolution(32, 32, 5, 5, 2, 2)) 23 | net:add(nn.ReLU(true)) 24 | 25 | return net 26 | end 27 | 28 | return Body 29 | -------------------------------------------------------------------------------- /modules/sharedRmsProp.lua: -------------------------------------------------------------------------------- 1 | function optim.sharedRmsProp(opfunc, x, config, state) 2 | -- Get state 3 | local config = config or {} 4 | local state = 
state or config 5 | local lr = config.learningRate or 1e-2 6 | local momentum = config.momentum or 0.95 7 | local epsilon = config.rmsEpsilon or 0.01 8 | 9 | -- Evaluate f(x) and df/dx 10 | local fx, dfdx = opfunc(x) 11 | 12 | -- Initialise storage 13 | if not state.g then 14 | state.g = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 15 | end 16 | 17 | if not state.tmp then 18 | state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx) 19 | end 20 | 21 | state.g:mul(momentum):addcmul(1 - momentum, dfdx, dfdx) 22 | state.tmp:copy(state.g):add(epsilon):sqrt() 23 | 24 | -- Update x = x - lr x df/dx / tmp 25 | x:addcdiv(-lr, dfdx, state.tmp) 26 | 27 | -- Return x*, f(x) before optimisation 28 | return x, {fx} 29 | end -------------------------------------------------------------------------------- /async/AsyncModel.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local Model = require 'Model' 3 | 4 | local AsyncModel = classic.class('AsyncModel') 5 | 6 | function AsyncModel:_init(opt) 7 | -- Initialise environment 8 | log.info('Setting up ' .. opt.env) 9 | local Env = require(opt.env) 10 | self.env = Env(opt) -- Environment instantiation 11 | 12 | -- Augment environment with extra methods if missing 13 | if not self.env.training then 14 | self.env.training = function() end 15 | end 16 | if not self.env.evaluate then 17 | self.env.evaluate = function() end 18 | end 19 | 20 | self.model = Model(opt) 21 | self.a3c = opt.async == 'A3C' 22 | 23 | classic.strict(self) 24 | end 25 | 26 | function AsyncModel:getEnvAndModel() 27 | return self.env, self.model 28 | end 29 | 30 | function AsyncModel:createNet() 31 | return self.model:create() 32 | end 33 | 34 | return AsyncModel 35 | -------------------------------------------------------------------------------- /modules/HuberCriterion.lua: -------------------------------------------------------------------------------- 1 | local HuberCriterion, parent = torch.class('nn.HuberCriterion', 'nn.Criterion') 2 | 3 | function HuberCriterion:__init(delta) 4 | parent.__init(self) 5 | self.delta = delta or 1 -- Boundary 6 | self.alpha = torch.Tensor() -- Residual 7 | end 8 | 9 | function HuberCriterion:updateOutput(input, target) 10 | -- Calculate residual 11 | self.alpha = target - input 12 | 13 | self.absAlpha = torch.abs(self.alpha) 14 | self.diffAlpha = torch.cmin(self.absAlpha, self.delta) 15 | 16 | self.output = torch.cmul(self.diffAlpha, self.absAlpha:mul(2):add(-self.diffAlpha)):mul(0.5):mean() 17 | 18 | return self.output 19 | end 20 | 21 | function HuberCriterion:updateGradInput(input, target) 22 | self.gradInput:resizeAs(target) 23 | 24 | self.gradInput = self.alpha:sign():cmul(self.diffAlpha) 25 | 26 | return self.gradInput 27 | end 28 | 29 | return nn.HuberCriterion 30 | -------------------------------------------------------------------------------- /models/Atari2013.lua: -------------------------------------------------------------------------------- 1 | local nn = require 'nn' 2 | require 'classic.torch' -- Enables serialisation 3 | 4 | local Body = classic.class('Body') 5 | 6 | -- Constructor 7 | function Body:_init(opts) 8 | opts = opts or {} 9 | 10 | self.recurrent = opts.recurrent 11 | self.histLen = opts.histLen 12 | self.stateSpec = opts.stateSpec 13 | end 14 | 15 | function Body:createBody() 16 | -- Number of input frames for recurrent networks is always 1 17 | local histLen = self.recurrent and 1 or self.histLen 18 | local net = nn.Sequential() 19 | 
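-- Two-layer convolutional trunk from the original 2013 DQN paper:
-- 16 filters of 8x8 with stride 4, then 32 filters of 4x4 with stride 2.
-- Compare models/Atari.lua, which uses the deeper three-layer Nature (2015) trunk.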
net:add(nn.View(histLen*self.stateSpec[2][1], self.stateSpec[2][2], self.stateSpec[2][3])) -- Concatenate history in channel dimension 20 | net:add(nn.SpatialConvolution(histLen*self.stateSpec[2][1], 16, 8, 8, 4, 4, 1, 1)) 21 | net:add(nn.ReLU(true)) 22 | net:add(nn.SpatialConvolution(16, 32, 4, 4, 2, 2)) 23 | net:add(nn.ReLU(true)) 24 | 25 | return net 26 | end 27 | 28 | return Body 29 | -------------------------------------------------------------------------------- /async/SarsaAgent.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local QAgent = require 'async/OneStepQAgent' 3 | 4 | local SarsaAgent, super = classic.class('SarsaAgent', 'OneStepQAgent') 5 | 6 | 7 | function SarsaAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 8 | super._init(self, opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 9 | log.info('creating SarsaAgent') 10 | self.agentName = 'SarsaAgent' 11 | classic.strict(self) 12 | end 13 | 14 | 15 | function SarsaAgent:accumulateGradient(state, action, state_, reward, terminal) 16 | local Y = reward 17 | local Q_state = self.QCurr[action] 18 | 19 | if not terminal then 20 | local action_ = self:eGreedy(state_, self.targetNet) 21 | 22 | Y = Y + self.gamma * self.QCurr[action_] 23 | end 24 | 25 | local tdErr = Y - Q_state 26 | 27 | self:accumulateGradientTdErr(state, action, tdErr, self.policyNet) 28 | end 29 | 30 | 31 | return SarsaAgent 32 | -------------------------------------------------------------------------------- /modules/DeconvnetReLU.lua: -------------------------------------------------------------------------------- 1 | local DeconvnetReLU, parent = torch.class('nn.DeconvnetReLU', 'nn.ReLU') 2 | 3 | function DeconvnetReLU:__init(p) 4 | parent.__init(self, p) 5 | self.deconv = false 6 | end 7 | 8 | function DeconvnetReLU:updateOutput(input) 9 | return parent.updateOutput(self, input) 10 | end 11 | 12 | function DeconvnetReLU:updateGradInput(input, gradOutput) 13 | if self.deconv then 14 | -- Backpropagate all positive error signals (irrelevant of positive inputs) 15 | if self.inplace then 16 | self.gradInput = gradOutput 17 | else 18 | self.gradInput:resizeAs(gradOutput):copy(gradOutput) 19 | end 20 | 21 | self.gradInput:cmul(torch.gt(gradOutput, 0):typeAs(gradOutput)) 22 | else 23 | parent.updateGradInput(self, input, gradOutput) 24 | end 25 | 26 | return self.gradInput 27 | end 28 | 29 | function DeconvnetReLU:salientBackprop() 30 | self.deconv = true 31 | end 32 | 33 | function DeconvnetReLU:normalBackprop() 34 | self.deconv = false 35 | end 36 | -------------------------------------------------------------------------------- /models/Atari.lua: -------------------------------------------------------------------------------- 1 | local nn = require 'nn' 2 | require 'classic.torch' -- Enables serialisation 3 | 4 | local Body = classic.class('Body') 5 | 6 | -- Constructor 7 | function Body:_init(opts) 8 | opts = opts or {} 9 | 10 | self.recurrent = opts.recurrent 11 | self.histLen = opts.histLen 12 | self.stateSpec = opts.stateSpec 13 | end 14 | 15 | function Body:createBody() 16 | -- Number of input frames for recurrent networks is always 1 17 | local histLen = self.recurrent and 1 or self.histLen 18 | local net = nn.Sequential() 19 | net:add(nn.View(histLen*self.stateSpec[2][1], self.stateSpec[2][2], self.stateSpec[2][3])) -- Concatenate history in channel dimension 20 | net:add(nn.SpatialConvolution(histLen*self.stateSpec[2][1], 32, 8, 8, 
4, 4, 1, 1)) 21 | net:add(nn.ReLU(true)) 22 | net:add(nn.SpatialConvolution(32, 64, 4, 4, 2, 2)) 23 | net:add(nn.ReLU(true)) 24 | net:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1)) 25 | net:add(nn.ReLU(true)) 26 | 27 | return net 28 | end 29 | 30 | return Body 31 | -------------------------------------------------------------------------------- /modules/MinDim.lua: -------------------------------------------------------------------------------- 1 | local MinDim, parent = torch.class('nn.MinDim', 'nn.Module') 2 | 3 | local function _assertTensor(t) 4 | assert(torch.isTensor(t), "This module only works on tensor") 5 | end 6 | 7 | function MinDim:__init(pos, minInputDims) 8 | parent.__init(self) 9 | self.pos = pos or error('the position to insert singleton dim not specified') 10 | self:setMinInputDims(minInputDims) 11 | end 12 | 13 | function MinDim:setMinInputDims(numInputDims) 14 | self.numInputDims = numInputDims 15 | return self 16 | end 17 | 18 | function MinDim:updateOutput(input) 19 | _assertTensor(input) 20 | self.output = input 21 | if input:dim() < self.numInputDims then 22 | nn.utils.addSingletonDimension(self.output, input, self.pos) 23 | end 24 | return self.output 25 | end 26 | 27 | function MinDim:updateGradInput(input, gradOutput) 28 | _assertTensor(input) 29 | _assertTensor(gradOutput) 30 | assert(input:nElement() == gradOutput:nElement()) 31 | self.gradInput:view(gradOutput, input:size()) 32 | return self.gradInput 33 | end 34 | -------------------------------------------------------------------------------- /modules/DuelAggregator.lua: -------------------------------------------------------------------------------- 1 | -- Creates aggregator module for a dueling architecture based on a number of discrete actions 2 | local DuelAggregator = function(m) 3 | local aggregator = nn.Sequential() 4 | local aggParallel = nn.ParallelTable() 5 | 6 | -- Advantage duplicator (for calculating and subtracting mean) 7 | local advDuplicator = nn.Sequential() 8 | local advConcat = nn.ConcatTable() 9 | advConcat:add(nn.Identity()) 10 | -- Advantage mean duplicator 11 | local advMeanDuplicator = nn.Sequential() 12 | advMeanDuplicator:add(nn.Mean(1, 1)) 13 | advMeanDuplicator:add(nn.Replicate(m, 2, 2)) 14 | advConcat:add(advMeanDuplicator) 15 | advDuplicator:add(advConcat) 16 | -- Subtract mean from advantage values 17 | advDuplicator:add(nn.CSubTable()) 18 | 19 | -- Add value and advantage duplicators 20 | aggParallel:add(nn.Replicate(m, 2, 2)) 21 | aggParallel:add(advDuplicator) 22 | 23 | -- Calculate Q^ = V^ + A^ 24 | aggregator:add(aggParallel) 25 | aggregator:add(nn.CAddTable()) 26 | 27 | return aggregator 28 | end 29 | 30 | return DuelAggregator 31 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Kai Arulkumaran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /async/AsyncEvaluation.lua: -------------------------------------------------------------------------------- 1 | local Display = require 'Display' 2 | local ValidationAgent = require 'async/ValidationAgent' 3 | local AsyncModel = require 'async/AsyncModel' 4 | local classic = require 'classic' 5 | local tds = require 'tds' 6 | 7 | local AsyncEvaluation = classic.class('AsyncEvaluation') 8 | 9 | 10 | function AsyncEvaluation:_init(opt) 11 | local asyncModel = AsyncModel(opt) 12 | local env = asyncModel:getEnvAndModel() 13 | local policyNet = asyncModel:createNet() 14 | local theta = policyNet:getParameters() 15 | 16 | local weightsFile = paths.concat('experiments', opt._id, 'last.weights.t7') 17 | local weights = torch.load(weightsFile) 18 | theta:copy(weights) 19 | 20 | local atomic = tds.AtomicCounter() 21 | self.validAgent = ValidationAgent(opt, theta, atomic) 22 | 23 | local state = env:start() 24 | self.hasDisplay = false 25 | if opt.displaySpec then 26 | self.hasDisplay = true 27 | self.display = Display(opt, env:getDisplay()) 28 | end 29 | 30 | classic.strict(self) 31 | end 32 | 33 | 34 | function AsyncEvaluation:evaluate() 35 | local display = self.hasDisplay and self.display or nil 36 | self.validAgent:evaluate(display) 37 | end 38 | 39 | return AsyncEvaluation 40 | -------------------------------------------------------------------------------- /examples/GridWorldVis.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local GridWorld = require 'rlenvs.GridWorld' 3 | 4 | local GridWorldVis, super = classic.class('GridWorldVis', GridWorld) 5 | 6 | function GridWorldVis:_init(opts) 7 | super._init(self) 8 | 9 | -- Create screen 10 | self.screen = torch.Tensor(3, 21, 21):zero() 11 | end 12 | 13 | function GridWorldVis:getStateSpec() 14 | return {'real', {2}, {0, 1}} 15 | end 16 | 17 | function GridWorldVis:getDisplaySpec() 18 | return {'real', {3, 21, 21}, {0, 1}} 19 | end 20 | 21 | function GridWorldVis:getDisplay() 22 | return self.screen 23 | end 24 | 25 | function GridWorldVis:drawPixel(draw) 26 | if draw then 27 | self.screen[{{}, {20*self.position[2]+1}, {20*self.position[1]+1}}] = 1 28 | else 29 | self.screen[{{}, {20*self.position[2]+1}, {20*self.position[1]+1}}] = 0 30 | end 31 | end 32 | 33 | function GridWorldVis:start() 34 | super.start(self) 35 | 36 | self.screen:zero() 37 | self:drawPixel(true) 38 | 39 | return torch.Tensor(self.position) 40 | end 41 | 42 | function GridWorldVis:step(action) 43 | self:drawPixel(false) 44 | 45 | local reward, __, terminal = super.step(self, action) 46 | 47 | self:drawPixel(true) 48 | 49 | return reward, torch.Tensor(self.position), terminal 50 | end 51 | 52 | return GridWorldVis 53 | -------------------------------------------------------------------------------- /modules/rmspropm.lua: -------------------------------------------------------------------------------- 1 | -- RMSProp with momentum as found in "Generating 
Sequences With Recurrent Neural Networks" 2 | function optim.rmspropm(opfunc, x, config, state) 3 | -- Get state 4 | local config = config or {} 5 | local state = state or config 6 | local lr = config.learningRate or 1e-2 7 | local momentum = config.momentum or 0.95 8 | local epsilon = config.epsilon or 0.01 9 | 10 | -- Evaluate f(x) and df/dx 11 | local fx, dfdx = opfunc(x) 12 | 13 | -- Initialise storage 14 | if not state.g then 15 | state.g = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 16 | state.gSq = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 17 | state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx) 18 | end 19 | 20 | -- g = αg + (1 - α)df/dx 21 | state.g:mul(momentum):add(1 - momentum, dfdx) -- Calculate momentum 22 | -- tmp = df/dx . df/dx 23 | state.tmp:cmul(dfdx, dfdx) 24 | -- gSq = αgSq + (1 - α)df/dx 25 | state.gSq:mul(momentum):add(1 - momentum, state.tmp) -- Calculate "squared" momentum 26 | -- tmp = g . g 27 | state.tmp:cmul(state.g, state.g) 28 | -- tmp = (-tmp + gSq + ε)^0.5 29 | state.tmp:neg():add(state.gSq):add(epsilon):sqrt() 30 | 31 | -- Update x = x - lr x df/dx / tmp 32 | x:addcdiv(-lr, dfdx, state.tmp) 33 | 34 | -- Return x*, f(x) before optimisation 35 | return x, {fx} 36 | end 37 | -------------------------------------------------------------------------------- /structures/CircularQueue.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | require 'classic.torch' -- Enables serialisation 3 | require 'torchx' 4 | 5 | -- A non-standard circular queue 6 | local CircularQueue = classic.class('CircularQueue') 7 | 8 | -- Creates a new fixed-length circular queue and tensor creation function 9 | function CircularQueue:_init(length, createTensor, tensorSizes) 10 | self.length = length 11 | self.queue = {} 12 | self.reset = false 13 | 14 | -- Initialise zero tensors 15 | for i = 1, self.length do 16 | self.queue[#self.queue + 1] = createTensor(torch.LongStorage(tensorSizes)):zero() 17 | end 18 | end 19 | 20 | -- Pushes a new element to the end of the queue and moves all others down 21 | function CircularQueue:push(tensor) 22 | if self.reset then 23 | -- If reset flag set, zero old tensors 24 | for i = 1, self.length - 1 do 25 | self.queue[i]:zero() 26 | end 27 | 28 | -- Unset reset flag 29 | self.reset = false 30 | else 31 | -- Otherwise, move old elements down 32 | for i = 1, self.length - 1 do 33 | self.queue[i] = self.queue[i + 1] 34 | end 35 | end 36 | 37 | -- Add new element (casting if needed, will keep reference if not) 38 | self.queue[self.length] = tensor:typeAs(self.queue[1]) 39 | end 40 | 41 | -- Pushes a new element to the end of the queue and sets reset flag 42 | function CircularQueue:pushReset(tensor) 43 | -- Move old elements down 44 | for i = 1, self.length - 1 do 45 | self.queue[i] = self.queue[i + 1] 46 | end 47 | 48 | -- Add new element (casting if needed, will keep reference if not) 49 | self.queue[self.length] = tensor:typeAs(self.queue[1]) 50 | 51 | -- Set reset flag 52 | self.reset = true 53 | end 54 | 55 | -- Resets (zeros) the entire queue 56 | function CircularQueue:clear() 57 | for i = 1, self.length do 58 | self.queue[i]:zero() 59 | end 60 | end 61 | 62 | -- Reads entire queue as a large tensor 63 | function CircularQueue:readAll() 64 | return torch.concat(self.queue) 65 | end 66 | 67 | return CircularQueue 68 | -------------------------------------------------------------------------------- /Evaluator.lua: 
-------------------------------------------------------------------------------- 1 | local _ = require 'moses' 2 | local classic = require 'classic' 3 | require 'classic.torch' -- Enables serialisation 4 | 5 | -- Table of game names 6 | local games = {'alien', 'amidar', 'assault', 'asterix', 'asteroids', 'atlantis', 'bank_heist', 'battle_zone', 'beam_rider', 'bowling', 'boxing', 'breakout', 'centipede', 'chopper_command', 'crazy_climber', 'demon_attack', 'double_dunk', 'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar', 'hero', 'ice_hockey', 'james_bond', 'kangaroo', 'krull', 'kung_fu_master', 'montezuma_revenge', 'ms_pacman', 'name_this_game', 'pong', 'private_eye', 'q_bert', 'river_raid', 'road_runner', 'robotank', 'seaquest', 'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham', 'up_n_down', 'venture', 'video_pinball', 'wizard_of_wor', 'zaxxon'} 7 | 8 | -- Table of random no-op scores 9 | local random = {227.80, 5.80, 222.40, 210.00, 719.10, 12850.00, 14.20, 2360.00, 363.90, 23.10, 0.10, 1.70, 2090.90, 811.00, 10780.50, 152.10, -18.60, 0.00, -91.70, 0.00, 65.20, 257.60, 173.00, 1027.00, -11.20, 29.00, 52.00, 1598.00, 258.50, 0.00, 307.30, 2292.30, -20.70, 24.90, 163.90, 1338.50, 11.50, 2.20, 68.40, 148.00, 664.00, -23.80, 3568.00, 11.40, 533.40, 0.00, 16256.90, 563.50, 32.50} 10 | 11 | -- Table of human scores 12 | local human = {6875.40, 1675.80, 1496.40, 8503.30, 13156.70, 29028.10, 734.40, 37800.00, 5774.70, 154.80, 4.30, 31.80, 11963.20, 9881.80, 35410.50, 3401.30, -15.50, 309.60, 5.50, 29.60, 4334.70, 2321.00, 2672.00, 25762.50, 0.90, 406.70, 3035.00, 2394.60, 22736.20, 4366.70, 15693.40, 4076.20, 9.30, 69571.30, 13455.00, 13513.30, 7845.00, 11.90, 20181.80, 1652.30, 10250.00, -8.90, 5925.00, 167.60, 9082.00, 1187.50, 17297.60, 4756.50, 9173.30} 13 | 14 | local Evaluator = classic.class('Evaluator') 15 | 16 | function Evaluator:_init(game) 17 | -- Game index 18 | self.index = _.find(games, game) 19 | end 20 | 21 | -- Calculates a normalised game score based on random and human performance 22 | function Evaluator:normaliseScore(score) 23 | -- Return (score_agent - score_random)/abs(score_human - score_random) 24 | return self.index and (score - random[self.index]) / math.abs(human[self.index] - random[self.index]) or nil -- Returns nil if game not included 25 | end 26 | 27 | return Evaluator 28 | -------------------------------------------------------------------------------- /Display.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local image = require 'image' 3 | 4 | -- Detect QT for image display 5 | local qt = pcall(require, 'qt') 6 | 7 | -- Display is responsible for handling QT/recording logic 8 | local Display = classic.class('Display') 9 | 10 | -- Creates display; live if using QT 11 | function Display:_init(opt, display) 12 | self._id = opt._id 13 | self.zoom = opt.zoom 14 | self.displayHeight = opt.displaySpec[2][2] 15 | self.displayWidth = opt.displaySpec[2][3] 16 | self.saliency = opt.saliency 17 | self.record = opt.mode == 'eval' and opt.record 18 | self.fps = 60 19 | 20 | -- Activate live display if using QT 21 | self.window = qt and image.display({image=display, zoom=self.zoom}) 22 | 23 | -- Set up recording 24 | if self.record then 25 | -- Recreate scratch directory 26 | paths.rmall('scratch', 'yes') 27 | paths.mkdir('scratch') 28 | 29 | log.info('Recording screen') 30 | end 31 | 32 | classic.strict(self) 33 | end 34 | 35 | -- Computes saliency map for 
display from agent field 36 | function Display:createSaliencyMap(agent, display) 37 | local screen = display:clone() -- Cloned to prevent side-effects 38 | local saliencyMap = agent.saliencyMap:float() 39 | 40 | -- Use red channel for saliency map 41 | screen:select(1, 1):copy(image.scale(saliencyMap, self.displayWidth, self.displayHeight)) 42 | 43 | return screen 44 | end 45 | 46 | -- Show display (handles recording as well for efficiency) 47 | function Display:display(agent, display, step) 48 | if qt or self.record then 49 | local screen = self.saliency and self:createSaliencyMap(agent, display) or display 50 | 51 | -- Display 52 | if qt then 53 | image.display({image=screen, zoom=self.zoom, win=self.window}) 54 | end 55 | 56 | -- Record 57 | if self.record then 58 | image.save(paths.concat('scratch', self._id .. '_' .. string.format('%06d', step) .. '.jpg'), screen) 59 | end 60 | end 61 | end 62 | 63 | -- Creates videos from frames if recording 64 | function Display:createVideo() 65 | if self.record then 66 | log.info('Recorded screen') 67 | 68 | -- Create videos directory 69 | if not paths.dirp('videos') then 70 | paths.mkdir('videos') 71 | end 72 | 73 | -- Use FFmpeg to create a video from the screens 74 | log.info('Creating video') 75 | os.execute('ffmpeg -framerate ' .. self.fps .. ' -start_number 1 -i scratch/' .. self._id .. '_%06d.jpg -c:v libvpx-vp9 -crf 0 -b:v 0 videos/' .. self._id .. '.webm') 76 | log.info('Created video') 77 | 78 | -- Clear scratch space 79 | paths.rmall('scratch', 'yes') 80 | end 81 | end 82 | 83 | return Display 84 | -------------------------------------------------------------------------------- /test/testExperience.lua: -------------------------------------------------------------------------------- 1 | local Singleton = require 'structures/Singleton' 2 | local Experience = require 'Experience' 3 | 4 | local Test = torch.TestSuite() 5 | local standalone = tester == nil 6 | if standalone then 7 | tester = torch.Tester() 8 | end 9 | 10 | torch.manualSeed(1) 11 | 12 | local globals = Singleton({step = 1}) 13 | 14 | local isValidation = false 15 | local capacity = 1e4 16 | local opt = { 17 | histLen = 1, 18 | stateSpec = { 19 | 'real', 20 | {1, 10, 10, 10}, 21 | {0, 1} 22 | }, 23 | discretiseMem = true, 24 | batchSize = 10, 25 | bootstraps = 5, 26 | gpu = false, 27 | memPriority = '', 28 | learnStart = 0, 29 | steps = 1e6, 30 | alpha = .65, 31 | betaZero = 0.45, 32 | Tensor = torch.Tensor, 33 | } 34 | 35 | 36 | local function randomPopulate(priorities, experience) 37 | local state = torch.Tensor(table.unpack(opt.stateSpec[2])) 38 | local terminal = false 39 | local action = 1 40 | local reward = 1 41 | local idx = torch.Tensor(1) 42 | local prio = torch.Tensor(1) 43 | local maxPrio = 1000 44 | local heads = math.max(opt.bootstraps, 1) 45 | local mask = torch.ByteTensor(heads) 46 | 47 | for i=1,capacity do 48 | experience:store(reward, state, terminal, action, mask:clone():bernoulli(0.5)) 49 | idx[1] = i 50 | prio[1] = torch.random(maxPrio) - maxPrio / 2 51 | priorities[i] = prio[1] 52 | experience:updatePriorities(idx, prio) 53 | end 54 | end 55 | 56 | local function samplePriorityMeans(times) 57 | local experience = Experience(capacity, opt, isValidation) 58 | local priorities = torch.Tensor(capacity) 59 | local heads = math.max(opt.bootstraps, 1) 60 | randomPopulate(priorities, experience) 61 | 62 | local samplePriorities = torch.Tensor(times, opt.batchSize) 63 | 64 | for i=1,times do 65 | local idxs = experience:sample(torch.random(heads)) 66 | 
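-- Record the stored priorities at the sampled indices; the tests below check that
-- uniform replay gives roughly equal mean |priority| across batch positions, while
-- rank-based prioritised replay gives non-increasing means from the first to the
-- last position of the batch.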
samplePriorities[i] = priorities:gather(1, idxs) 67 | end 68 | 69 | -- print(samplePriorities) 70 | local means = samplePriorities:abs():mean(1):squeeze() 71 | print(means) 72 | 73 | return means 74 | end 75 | 76 | function Test:TestExperience_TestUniform() 77 | torch.manualSeed(1) 78 | opt.memPriority = false 79 | local means = samplePriorityMeans(1000) 80 | 81 | for i=1,means:size(1) do 82 | tester:assert(means[i]>235 and means[i]<265) 83 | end 84 | end 85 | 86 | function Test:TestExperience_TestRank() 87 | torch.manualSeed(1) 88 | opt.memPriority = 'rank' 89 | local means = samplePriorityMeans(10) 90 | 91 | for i=2,means:size(1) do 92 | tester:assertle(means[i], means[i-1]) 93 | end 94 | end 95 | 96 | 97 | if standalone then 98 | tester:add(Test) 99 | tester:run() 100 | end 101 | 102 | return Test 103 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | compiler: 3 | - gcc 4 | # - clang 5 | cache: 6 | directories: 7 | - $HOME/OpenBlasInstall 8 | env: 9 | - TORCH_LUA_VERSION=LUAJIT21 10 | # - TORCH_LUA_VERSION=LUA51 11 | # - TORCH_LUA_VERSION=LUA52 12 | sudo: false 13 | addons: 14 | apt: 15 | packages: 16 | - cmake 17 | - gfortran 18 | - gcc-multilib 19 | - gfortran-multilib 20 | - liblapack-dev 21 | - build-essential 22 | - gcc 23 | - g++ 24 | - curl 25 | - cmake 26 | - libreadline-dev 27 | - git-core 28 | - libqt4-core 29 | - libqt4-gui 30 | - libqt4-dev 31 | - libjpeg-dev 32 | - libpng-dev 33 | - ncurses-dev 34 | - imagemagick 35 | - libzmq3-dev 36 | - gfortran 37 | - unzip 38 | - gnuplot 39 | - gnuplot-x11 40 | before_script: 41 | - export ROOT_TRAVIS_DIR=$(pwd) 42 | - export INSTALL_PREFIX=~/torch/install 43 | - ls $HOME/OpenBlasInstall/lib || (cd /tmp/ && git clone https://github.com/xianyi/OpenBLAS.git -b master && cd OpenBLAS && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make PREFIX=$HOME/OpenBlasInstall install) 44 | - git clone https://github.com/torch/distro.git ~/torch --recursive 45 | - cd ~/torch && git submodule update --init --recursive 46 | - mkdir build && cd build 47 | - export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH 48 | - cmake .. 
-DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON 49 | - make && make install 50 | - cd $ROOT_TRAVIS_DIR 51 | - export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH 52 | script: 53 | - ${INSTALL_PREFIX}/bin/luarocks install luaffi 54 | - ${INSTALL_PREFIX}/bin/luarocks install luaposix 33.4.0 55 | - ${INSTALL_PREFIX}/bin/luarocks install moses 56 | - ${INSTALL_PREFIX}/bin/luarocks install logroll 57 | - ${INSTALL_PREFIX}/bin/luarocks install classic 58 | - ${INSTALL_PREFIX}/bin/luarocks install torchx 59 | - ${INSTALL_PREFIX}/bin/luarocks install dpnn 60 | - ${INSTALL_PREFIX}/bin/luarocks install tds 61 | - ${INSTALL_PREFIX}/bin/luarocks install nninit 62 | - ${INSTALL_PREFIX}/bin/luarocks install https://raw.githubusercontent.com/Kaixhin/rlenvs/master/rocks/rlenvs-scm-1.rockspec 63 | - export PATH=${INSTALL_PREFIX}/bin:$PATH 64 | - export TESTLUA=$(which luajit lua | head -n 1) 65 | - ${TESTLUA} test/testAll.lua 66 | notifications: 67 | email: 68 | on_success: change 69 | on_failure: change 70 | on_start: never 71 | webhooks: 72 | urls: 73 | - https://webhooks.gitter.im/e/faf2d7f3cc77829f144c 74 | on_success: change # options: [always|never|change] default: always 75 | on_failure: always # options: [always|never|change] default: always 76 | on_start: never # options: [always|never|change] default: always 77 | -------------------------------------------------------------------------------- /async/NStepQAgent.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local optim = require 'optim' 3 | local QAgent = require 'async/QAgent' 4 | require 'modules/sharedRmsProp' 5 | 6 | local NStepQAgent, super = classic.class('NStepQAgent', 'QAgent') 7 | 8 | 9 | function NStepQAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 10 | super._init(self, opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 11 | self.policyNet_ = self.policyNet:clone() 12 | self.policyNet_:training() 13 | self.theta_, self.dTheta_ = self.policyNet_:getParameters() 14 | self.dTheta_:zero() 15 | 16 | self.rewards = torch.Tensor(self.batchSize) 17 | self.actions = torch.ByteTensor(self.batchSize) 18 | self.states = torch.Tensor(0) 19 | 20 | self.env:training() 21 | 22 | self.alwaysComputeGreedyQ = false 23 | 24 | classic.strict(self) 25 | end 26 | 27 | 28 | function NStepQAgent:learn(steps, from) 29 | self.step = from or 0 30 | self.stateBuffer:clear() 31 | 32 | log.info('NStepQAgent starting | steps=%d | ε=%.2f -> %.2f', steps, self.epsilon, self.epsilonEnd) 33 | local reward, terminal, state = self:start() 34 | 35 | self.states:resize(self.batchSize, table.unpack(state:size():totable())) 36 | self.tic = torch.tic() 37 | repeat 38 | self.theta_:copy(self.theta) 39 | self.batchIdx = 0 40 | repeat 41 | self.batchIdx = self.batchIdx + 1 42 | self.states[self.batchIdx]:copy(state) 43 | 44 | local action = self:eGreedy(state, self.policyNet_) 45 | self.actions[self.batchIdx] = action 46 | 47 | reward, terminal, state = self:takeAction(action) 48 | self.rewards[self.batchIdx] = reward 49 | 50 | self:progress(steps) 51 | until terminal or self.batchIdx == self.batchSize 52 | 53 | self:accumulateGradients(terminal, state) 54 | 55 | if terminal then 56 | reward, terminal, state = self:start() 57 | end 58 | 59 | self:applyGradients(self.policyNet_, self.dTheta_, self.theta) 60 | until self.step >= steps 61 | 62 | log.info('NStepQAgent ended learning steps=%d ε=%.4f', 
steps, self.epsilon) 63 | end 64 | 65 | 66 | function NStepQAgent:accumulateGradients(terminal, state) 67 | local R = 0 68 | if not terminal then 69 | local QPrimes = self.targetNet:forward(state):squeeze() 70 | local APrimeMax = QPrimes:max(1):squeeze() 71 | 72 | if self.doubleQ then 73 | local _,APrimeMaxInds = self.policyNet_:forward(state):squeeze():max(1) 74 | APrimeMax = QPrimes[APrimeMaxInds[1]] 75 | end 76 | R = APrimeMax 77 | end 78 | 79 | for i=self.batchIdx,1,-1 do 80 | R = self.rewards[i] + self.gamma * R 81 | local Q_i = self.policyNet_:forward(self.states[i]):squeeze() 82 | local tdErr = R - Q_i[self.actions[i]] 83 | self:accumulateGradientTdErr(self.states[i], self.actions[i], tdErr, self.policyNet_) 84 | end 85 | end 86 | 87 | 88 | return NStepQAgent 89 | -------------------------------------------------------------------------------- /async/OneStepQAgent.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local optim = require 'optim' 3 | local QAgent = require 'async/QAgent' 4 | require 'modules/sharedRmsProp' 5 | 6 | local OneStepQAgent, super = classic.class('OneStepQAgent', 'QAgent') 7 | 8 | 9 | function OneStepQAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 10 | super._init(self, opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 11 | self.agentName = 'OneStepQAgent' 12 | self.lstm = opt.recurrent and self.policyNet:findModules('nn.FastLSTM')[1] 13 | self.lstmTarget = opt.recurrent and self.targetNet:findModules('nn.FastLSTM')[1] 14 | classic.strict(self) 15 | end 16 | 17 | 18 | function OneStepQAgent:learn(steps, from) 19 | self.step = from or 0 20 | self.policyNet:training() 21 | self.stateBuffer:clear() 22 | self.env:training() 23 | 24 | log.info('%s starting | steps=%d | ε=%.2f -> %.2f', self.agentName, steps, self.epsilon, self.epsilonEnd) 25 | local reward, terminal, state = self:start() 26 | 27 | local action, state_ 28 | 29 | self.tic = torch.tic() 30 | for step1=1,steps do 31 | if not terminal then 32 | action = self:eGreedy(state, self.policyNet) 33 | reward, terminal, state_ = self:takeAction(action) 34 | else 35 | reward, terminal, state_ = self:start() 36 | end 37 | 38 | if state ~= nil then 39 | self:accumulateGradient(state, action, state_, reward, terminal) 40 | self.batchIdx = self.batchIdx + 1 41 | end 42 | 43 | if not terminal then 44 | state = state_ 45 | else 46 | if self.lstm then 47 | self.lstm:forget() 48 | self.lstmTarget:forget() 49 | end 50 | state = nil 51 | end 52 | 53 | if self.batchIdx == self.batchSize or terminal then 54 | self:applyGradients(self.policyNet, self.dTheta, self.theta) 55 | if self.lstm then 56 | self.lstm:forget() 57 | self.lstmTarget:forget() 58 | end 59 | self.batchIdx = 0 60 | end 61 | 62 | self:progress(steps) 63 | end 64 | 65 | log.info('%s ended learning steps=%d ε=%.4f', self.agentName, steps, self.epsilon) 66 | end 67 | 68 | 69 | function OneStepQAgent:accumulateGradient(state, action, state_, reward, terminal) 70 | local Y = reward 71 | if self.lstm then -- LSTM targetNet needs to see all states as well 72 | self.targetNet:forward(state) 73 | end 74 | if not terminal then 75 | local QPrimes = self.targetNet:forward(state_):squeeze() 76 | local APrimeMax = QPrimes:max(1):squeeze() 77 | 78 | if self.doubleQ then 79 | local _,APrimeMaxInds = self.policyNet:forward(state_):squeeze():max(1) 80 | APrimeMax = QPrimes[APrimeMaxInds[1]] 81 | end 82 | 83 | Y = Y + self.gamma * APrimeMax 84 | end 85 | 86 | 
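-- Double Q-learning target: the greedy action is selected with policyNet (argmax
-- above) but evaluated with targetNet, i.e. Y = r + γ·Q_target(s', argmax_a Q_policy(s', a)).
-- Q(s,·) is recomputed below because, with doubleQ set, eGreedy may not have cached
-- it and the forward pass on s' above reuses the network's output buffer.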
if self.doubleQ then 87 | self.QCurr = self.policyNet:forward(state):squeeze() 88 | end 89 | 90 | local tdErr = Y - self.QCurr[action] 91 | 92 | self:accumulateGradientTdErr(state, action, tdErr, self.policyNet) 93 | end 94 | 95 | 96 | return OneStepQAgent 97 | -------------------------------------------------------------------------------- /roms/README.md: -------------------------------------------------------------------------------- 1 | ROMs 2 | ==== 3 | 4 | Atari 2600 binary ROM files should be placed in this directory with the appropriate filenames. 5 | 6 | Supported Games 7 | --------------- 8 | 9 | | Game | ROM Name | 10 | |---------------------|-------------------| 11 | | Air Raid | air_raid | 12 | | Alien | alien | 13 | | Amidar | amidar | 14 | | Assault | assault | 15 | | Asterix | asterix | 16 | | Asteroids | asteroids | 17 | | Atlantis | atlantis | 18 | | Bank Heist | bank_heist | 19 | | Battlezone | battle_zone | 20 | | Beamrider | beam_rider | 21 | | Berzerk | berzerk | 22 | | Bowling | bowling | 23 | | Boxing | boxing | 24 | | Breakout | breakout | 25 | | Carnival | carnival | 26 | | Centipede | centipede | 27 | | Chopper Command | chopper_command | 28 | | Crazy Climber | crazy_climber | 29 | | Defender | defender | 30 | | Demon Attack | demon_attack | 31 | | Double Dunk | double_dunk | 32 | | Elevator Action | elevator_action | 33 | | Enduro | enduro | 34 | | Fishing Derby | fishing_derby | 35 | | Freeway | freeway | 36 | | Frostbite | frostbite | 37 | | Gopher | gopher | 38 | | Gravitar | gravitar | 39 | | H.E.R.O. | hero | 40 | | Ice Hockey | ice_hockey | 41 | | James Bond 007 | james_bond | 42 | | Journey Escape | journey_escape | 43 | | Kangaroo | kangaroo | 44 | | Krull | krull | 45 | | Kung-Fu Master | kung_fu_master | 46 | | Montezuma's Revenge | montezuma_revenge | 47 | | Ms. Pac-Man | ms_pacman | 48 | | Name This Game | name_this_game | 49 | | Pac-Man | pacman | 50 | | Phoenix | phoenix | 51 | | Pitfall! 
| pitfall | 52 | | Pong | pong | 53 | | Pooyan | pooyan | 54 | | Private Eye | private_eye | 55 | | Q*bert | q_bert | 56 | | River Raid | riverraid | 57 | | Road Runner | road_runner | 58 | | Robot Tank | robotank | 59 | | Seaquest | seaquest | 60 | | Skiing | skiing | 61 | | Solaris | solaris | 62 | | Space Invaders | space_invaders | 63 | | Stargunner | star_gunner | 64 | | Surround | surround | 65 | | Tennis | tennis | 66 | | Time Pilot | time_pilot | 67 | | Tutankham | tutankham | 68 | | Up’n Down | up_n_down | 69 | | Venture | venture | 70 | | Video Chess | video_chess | 71 | | Video Pinball | video_pinball | 72 | | Wizard of Wor | wizard_of_wor | 73 | | Yars' Revenge | yars_revenge | 74 | | Zaxxon | zaxxon | 75 | -------------------------------------------------------------------------------- /async/QAgent.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local QAgent = require 'async/AsyncAgent' 3 | 4 | local QAgent, super = classic.class('QAgent', 'AsyncAgent') 5 | 6 | local EPSILON_ENDS = { 0.1, 0.01, 0.5} 7 | local EPSILON_PROBS = { 0.4, 0.7, 1 } 8 | 9 | 10 | function QAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 11 | super._init(self, opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 12 | self.super = super 13 | 14 | self.targetNet = targetNet:clone('weight', 'bias') 15 | self.targetNet:evaluate() 16 | 17 | self.targetTheta = targetTheta 18 | local __, gradParams = self.policyNet:parameters() 19 | self.dTheta = nn.Module.flatten(gradParams) 20 | self.dTheta:zero() 21 | 22 | self.doubleQ = opt.doubleQ 23 | 24 | self.epsilonStart = opt.epsilonStart 25 | self.epsilon = self.epsilonStart 26 | self.PALpha = opt.PALpha 27 | 28 | self.target = self.Tensor(self.m) 29 | 30 | self.totalSteps = math.floor(opt.steps / opt.threads) 31 | 32 | self:setEpsilon(opt) 33 | self.tic = 0 34 | self.step = 0 35 | 36 | -- Forward state anyway if recurrent 37 | self.alwaysComputeGreedyQ = opt.recurrent or not self.doubleQ 38 | 39 | self.QCurr = torch.Tensor(0) 40 | end 41 | 42 | 43 | function QAgent:setEpsilon(opt) 44 | local r = torch.rand(1):squeeze() 45 | local e = 3 46 | if r < EPSILON_PROBS[1] then 47 | e = 1 48 | elseif r < EPSILON_PROBS[2] then 49 | e = 2 50 | end 51 | self.epsilonEnd = EPSILON_ENDS[e] 52 | self.epsilonGrad = (self.epsilonEnd - opt.epsilonStart) / opt.epsilonSteps 53 | end 54 | 55 | 56 | function QAgent:eGreedy(state, net) 57 | self.epsilon = math.max(self.epsilonStart + (self.step - 1)*self.epsilonGrad, self.epsilonEnd) 58 | 59 | if self.alwaysComputeGreedyQ then 60 | self.QCurr = net:forward(state):squeeze() 61 | end 62 | 63 | if torch.uniform() < self.epsilon then 64 | return torch.random(1,self.m) 65 | end 66 | 67 | if not self.alwaysComputeGreedyQ then 68 | self.QCurr = net:forward(state):squeeze() 69 | end 70 | 71 | local _, maxIdx = self.QCurr:max(1) 72 | return maxIdx[1] 73 | end 74 | 75 | 76 | function QAgent:progress(steps) 77 | self.step = self.step + 1 78 | if self.atomic:inc() % self.tau == 0 then 79 | self.targetTheta:copy(self.theta) 80 | if self.tau>1000 then 81 | log.info('QAgent | updated targetNetwork at %d', self.atomic:get()) 82 | end 83 | end 84 | if self.step % self.progFreq == 0 then 85 | local progressPercent = 100 * self.step / steps 86 | local speed = self.progFreq / torch.toc(self.tic) 87 | self.tic = torch.tic() 88 | log.info('AsyncAgent | step=%d | %.02f%% | speed=%d/sec | ε=%.2f -> %.2f | η=%.8f', 89 | self.step, 
progressPercent, speed ,self.epsilon, self.epsilonEnd, self.optimParams.learningRate) 90 | end 91 | end 92 | 93 | 94 | function QAgent:accumulateGradientTdErr(state, action, tdErr, net) 95 | if self.tdClip > 0 then 96 | if tdErr > self.tdClip then tdErr = self.tdClip end 97 | if tdErr <-self.tdClip then tdErr =-self.tdClip end 98 | end 99 | 100 | self.target:zero() 101 | self.target[action] = -tdErr 102 | 103 | net:backward(state, self.target) 104 | end 105 | 106 | 107 | return QAgent 108 | 109 | -------------------------------------------------------------------------------- /async/AsyncAgent.lua: -------------------------------------------------------------------------------- 1 | local AbstractAgent = require 'async/AbstractAgent' 2 | local AsyncModel = require 'async/AsyncModel' 3 | local CircularQueue = require 'structures/CircularQueue' 4 | local classic = require 'classic' 5 | local optim = require 'optim' 6 | require 'modules/sharedRmsProp' 7 | 8 | local AsyncAgent = classic.class('AsyncAgent', AbstractAgent) 9 | 10 | local methods = { 11 | OneStepQ = 'OneStepQAgent', 12 | Sarsa = 'SarsaAgent', 13 | NStepQ = 'NStepQAgent', 14 | A3C = 'A3CAgent' 15 | } 16 | 17 | function AsyncAgent.static.build(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 18 | local Agent = require('async/'..methods[opt.async]) 19 | return Agent(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 20 | end 21 | 22 | 23 | function AsyncAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 24 | local asyncModel = AsyncModel(opt) 25 | self.env, self.model = asyncModel:getEnvAndModel() 26 | 27 | self.id = __threadid or 1 28 | self.atomic = atomic 29 | 30 | self.optimiser = optim[opt.optimiser] 31 | self.optimParams = { 32 | learningRate = opt.eta, 33 | momentum = opt.momentum, 34 | rmsEpsilon = opt.rmsEpsilon, 35 | g = sharedG 36 | } 37 | 38 | self.learningRateStart = opt.eta 39 | 40 | local actionSpec = self.env:getActionSpec() 41 | self.m = actionSpec[3][2] - actionSpec[3][1] + 1 42 | self.actionOffset = 1 - actionSpec[3][1] 43 | 44 | self.policyNet = policyNet:clone('weight', 'bias') 45 | 46 | self.theta = theta 47 | local __, gradParams = self.policyNet:parameters() 48 | self.dTheta = nn.Module.flatten(gradParams) 49 | self.dTheta:zero() 50 | 51 | self.stateBuffer = CircularQueue(opt.recurrent and 1 or opt.histLen, opt.Tensor, opt.stateSpec[2]) 52 | 53 | self.gamma = opt.gamma 54 | self.rewardClip = opt.rewardClip 55 | self.tdClip = opt.tdClip 56 | 57 | self.progFreq = opt.progFreq 58 | self.batchSize = opt.batchSize 59 | self.gradClip = opt.gradClip 60 | self.tau = opt.tau 61 | self.Tensor = opt.Tensor 62 | 63 | self.batchIdx = 0 64 | 65 | self.totalSteps = math.floor(opt.steps / opt.threads) 66 | 67 | self.tic = 0 68 | self.step = 0 69 | end 70 | 71 | 72 | function AsyncAgent:start() 73 | local reward, rawObservation, terminal = 0, self.env:start(), false 74 | local observation = self.model:preprocess(rawObservation) 75 | self.stateBuffer:push(observation) 76 | return reward, terminal, self.stateBuffer:readAll() 77 | end 78 | 79 | 80 | function AsyncAgent:takeAction(action) 81 | local reward, rawObservation, terminal = self.env:step(action - self.actionOffset) 82 | if self.rewardClip > 0 then 83 | reward = math.max(reward, -self.rewardClip) 84 | reward = math.min(reward, self.rewardClip) 85 | end 86 | 87 | local observation = self.model:preprocess(rawObservation) 88 | if terminal then 89 | self.stateBuffer:pushReset(observation) 90 | else 91 | 
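-- Non-terminal step: append the preprocessed frame to the circular state history.
-- pushReset above also flags the buffer so that the stale history is zeroed on the
-- next push (see structures/CircularQueue.lua).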
self.stateBuffer:push(observation) 92 | end 93 | 94 | return reward, terminal, self.stateBuffer:readAll() 95 | end 96 | 97 | 98 | function AsyncAgent:applyGradients(net, dTheta, theta) 99 | if self.gradClip > 0 then 100 | net:gradParamClip(self.gradClip) 101 | end 102 | 103 | local feval = function() 104 | -- loss needed for validation stats only which is not computed for async yet, so just 0 105 | local loss = 0 -- 0.5 * tdErr ^2 106 | return loss, dTheta 107 | end 108 | 109 | self.optimParams.learningRate = self.learningRateStart * (self.totalSteps - self.step) / self.totalSteps 110 | self.optimiser(feval, theta, self.optimParams) 111 | 112 | dTheta:zero() 113 | end 114 | 115 | 116 | function AsyncAgent:observe() 117 | error('not implemented yet') 118 | end 119 | 120 | 121 | function AsyncAgent:training() 122 | error('not implemented yet') 123 | end 124 | 125 | 126 | function AsyncAgent:evaluate() 127 | error('not implemented yet') 128 | end 129 | 130 | 131 | return AsyncAgent 132 | -------------------------------------------------------------------------------- /async/A3CAgent.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local optim = require 'optim' 3 | local AsyncAgent = require 'async/AsyncAgent' 4 | require 'modules/sharedRmsProp' 5 | 6 | local A3CAgent,super = classic.class('A3CAgent', 'AsyncAgent') 7 | 8 | local TINY_EPSILON = 1e-20 9 | 10 | function A3CAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 11 | super._init(self, opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 12 | 13 | log.info('creating A3CAgent') 14 | 15 | self.policyNet_ = policyNet:clone() 16 | 17 | self.theta_, self.dTheta_ = self.policyNet_:getParameters() 18 | self.dTheta_:zero() 19 | 20 | self.policyTarget = self.Tensor(self.m) 21 | self.vTarget = self.Tensor(1) 22 | self.targets = { self.vTarget, self.policyTarget } 23 | 24 | self.rewards = torch.Tensor(self.batchSize) 25 | self.actions = torch.ByteTensor(self.batchSize) 26 | self.states = torch.Tensor(0) 27 | self.beta = opt.entropyBeta 28 | 29 | self.env:training() 30 | 31 | classic.strict(self) 32 | end 33 | 34 | 35 | function A3CAgent:learn(steps, from) 36 | self.step = from or 0 37 | 38 | self.stateBuffer:clear() 39 | 40 | log.info('A3CAgent starting | steps=%d', steps) 41 | local reward, terminal, state = self:start() 42 | 43 | self.states:resize(self.batchSize, table.unpack(state:size():totable())) 44 | 45 | self.tic = torch.tic() 46 | repeat 47 | self.theta_:copy(self.theta) 48 | self.batchIdx = 0 49 | repeat 50 | self.batchIdx = self.batchIdx + 1 51 | self.states[self.batchIdx]:copy(state) 52 | 53 | local V, probability = table.unpack(self.policyNet_:forward(state)) 54 | local action = torch.multinomial(probability, 1):squeeze() 55 | 56 | self.actions[self.batchIdx] = action 57 | 58 | reward, terminal, state = self:takeAction(action) 59 | self.rewards[self.batchIdx] = reward 60 | 61 | self:progress(steps) 62 | until terminal or self.batchIdx == self.batchSize 63 | 64 | self:accumulateGradients(terminal, state) 65 | 66 | if terminal then 67 | reward, terminal, state = self:start() 68 | end 69 | 70 | self:applyGradients(self.policyNet_, self.dTheta_, self.theta) 71 | until self.step >= steps 72 | 73 | log.info('A3CAgent ended learning steps=%d', steps) 74 | end 75 | 76 | 77 | function A3CAgent:accumulateGradients(terminal, state) 78 | local R = 0 79 | if not terminal then 80 | R = self.policyNet_:forward(state)[1] 81 | end 82 | 
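-- Walk the rollout backwards, accumulating the n-step return R <- r_i + γ·R,
-- bootstrapped from V(s_last) above unless the episode terminated (R = 0).
-- The advantage (R - V(s_i)) then drives both the critic target (vTarget) and the
-- policy-gradient target (policyTarget), with an entropy term added to discourage
-- premature convergence to a deterministic policy.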
83 | for i=self.batchIdx,1,-1 do 84 | R = self.rewards[i] + self.gamma * R 85 | 86 | local action = self.actions[i] 87 | local V, probability = table.unpack(self.policyNet_:forward(self.states[i])) 88 | probability:add(TINY_EPSILON) -- could contain 0 -> log(0)= -inf -> theta = nans 89 | 90 | self.vTarget[1] = -0.5 * (R - V) 91 | 92 | -- ∇θ logp(s) = 1/p(a) for chosen a, 0 otherwise 93 | self.policyTarget:zero() 94 | -- f(s) ∇θ logp(s) 95 | self.policyTarget[action] = -(R - V) / probability[action] -- Negative target for gradient descent 96 | 97 | -- Calculate (negative of) gradient of entropy of policy (for gradient descent): -(-logp(s) - 1) 98 | local gradEntropy = torch.log(probability) + 1 99 | -- Add to target to improve exploration (prevent convergence to suboptimal deterministic policy) 100 | self.policyTarget:add(self.beta, gradEntropy) 101 | 102 | self.policyNet_:backward(self.states[i], self.targets) 103 | end 104 | end 105 | 106 | 107 | function A3CAgent:progress(steps) 108 | self.atomic:inc() 109 | self.step = self.step + 1 110 | if self.step % self.progFreq == 0 then 111 | local progressPercent = 100 * self.step / steps 112 | local speed = self.progFreq / torch.toc(self.tic) 113 | self.tic = torch.tic() 114 | log.info('A3CAgent | step=%d | %.02f%% | speed=%d/sec | η=%.8f', 115 | self.step, progressPercent, speed, self.optimParams.learningRate) 116 | end 117 | end 118 | 119 | return A3CAgent 120 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Switch to script directory 4 | cd `dirname -- "$0"` 5 | 6 | # Specify paper/hyperparameters 7 | if [ -z "$1" ]; then 8 | echo "Please enter paper, e.g. ./run nature" 9 | echo "Atari Choices: nature|doubleq|duel|prioritised|priorduel|persistent|bootstrap|recurrent|async-nstep|async-a3c" 10 | echo "Catch Choices: demo|demo-async|demo-async-a3c" 11 | echo "Example Choices: demo-grid" 12 | exit 0 13 | else 14 | PAPER=$1 15 | shift 16 | fi 17 | 18 | # Specify game 19 | if ! [[ "$PAPER" =~ demo ]]; then 20 | if [ -z "$1" ]; then 21 | echo "Please enter game, e.g. 
./run nature breakout" 22 | exit 0 23 | else 24 | GAME=$1 25 | shift 26 | fi 27 | fi 28 | 29 | if [[ "$PAPER" =~ async ]]; then 30 | echo "Async mode specified, setting OpenMP threads to 1" 31 | export OMP_NUM_THREADS=1 32 | fi 33 | 34 | if [ "$PAPER" == "demo" ]; then 35 | # Catch demo 36 | th main.lua -gpu 0 -zoom 4 -hiddenSize 32 -optimiser adam -steps 500000 -learnStart 50000 -tau 4 -memSize 50000 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -bootstraps 0 -memPriority rank -PALpha 0 "$@" 37 | elif [ "$PAPER" == "nature" ]; then 38 | # Nature 39 | th main.lua -env rlenvs.Atari -modelBody models.Atari -game $GAME -cudnn true -height 84 -width 84 -colorSpace y -duel false -bootstraps 0 -epsilonEnd 0.1 -tau 10000 -doubleQ false -PALpha 0 -eta 0.00025 -gradClip 0 "$@" 40 | elif [ "$PAPER" == "doubleq" ]; then 41 | # Double-Q (tuned) 42 | th main.lua -env rlenvs.Atari -modelBody models.Atari -game $GAME -cudnn true -height 84 -width 84 -colorSpace y -duel false -bootstraps 0 -PALpha 0 -eta 0.00025 -gradClip 0 "$@" 43 | elif [ "$PAPER" == "duel" ]; then 44 | # Duel (eta is apparently lower but not specified in paper) 45 | # Note from Tom Schaul: Tuned DDQN hyperparameters are used 46 | th main.lua -env rlenvs.Atari -modelBody models.Atari -game $GAME -cudnn true -height 84 -width 84 -colorSpace y -bootstraps 0 -PALpha 0 -eta 0.00025 "$@" 47 | elif [ "$PAPER" == "prioritised" ]; then 48 | # Prioritised (rank-based) 49 | th main.lua -env rlenvs.Atari -modelBody models.Atari -game $GAME -cudnn true -height 84 -width 84 -colorSpace y -duel false -bootstraps 0 -memPriority rank -alpha 0.7 -betaZero 0.5 -PALpha 0 -gradClip 0 "$@" 50 | elif [ "$PAPER" == "priorduel" ]; then 51 | # Duel with rank-based prioritised experience replay (in duel paper) 52 | th main.lua -env rlenvs.Atari -modelBody models.Atari -game $GAME -cudnn true -height 84 -width 84 -colorSpace y -bootstraps 0 -memPriority rank -alpha 0.7 -betaZero 0.5 -PALpha 0 "$@" 53 | elif [ "$PAPER" == "persistent" ]; then 54 | # Persistent 55 | th main.lua -env rlenvs.Atari -modelBody models.Atari -game $GAME -cudnn true -height 84 -width 84 -colorSpace y -duel false -bootstraps 0 -epsilonEnd 0.1 -tau 10000 -doubleQ false -eta 0.00025 -gradClip 0 "$@" 56 | elif [ "$PAPER" == "bootstrap" ]; then 57 | # Bootstrap 58 | th main.lua -env rlenvs.Atari -modelBody models.Atari -game $GAME -cudnn true -height 84 -width 84 -colorSpace y -duel false -tau 10000 -PALpha 0 -eta 0.00025 -gradClip 0 "$@" 59 | elif [ "$PAPER" == "recurrent" ]; then 60 | # Recurrent (note that evaluation methodology is different) 61 | th main.lua -env rlenvs.Atari -modelBody models.Atari -game $GAME -cudnn true -height 84 -width 84 -colorSpace y -histLen 10 -duel false -bootstraps 0 -recurrent true -memSize 400000 -memSampleFreq 1 -epsilonEnd 0.1 -tau 10000 -doubleQ false -PALpha 0 -optimiser adadelta -eta 0.1 "$@" 62 | 63 | # Async modes 64 | elif [ "$PAPER" == "demo-async" ]; then 65 | # N-Step Q-learning Catch demo 66 | th main.lua -zoom 4 -async NStepQ -eta 0.00025 -momentum 0.99 -bootstraps 0 -batchSize 5 -hiddenSize 32 -doubleQ false -duel false -optimiser adam -steps 15000000 -tau 4 -memSize 20000 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -bootstraps 0 -PALpha 0 "$@" 67 | elif [ "$PAPER" == "demo-async-a3c" ]; then 68 | # A3C Catch demo 69 | th main.lua -zoom 4 -async A3C -eta 0.0007 -momentum 0.99 -bootstraps 0 -batchSize 5 -hiddenSize 32 -doubleQ false -duel false -optimiser adam -steps 15000000 -tau 4 -memSize 20000 -epsilonSteps 10000 -valFreq 10000 
-valSteps 6000 -bootstraps 0 -PALpha 0 -entropyBeta 0 "$@" 70 | elif [ "$PAPER" == "async-nstep" ]; then 71 | # Steps for "1 day" = 80 * 1e6; for "4 days" = 1e9 72 | th main.lua -env rlenvs.Atari -modelBody models.Atari2013 -hiddenSize 256 -game $GAME -height 84 -width 84 -colorSpace y -async NStepQ -bootstraps 0 -batchSize 5 -momentum 0.99 -rmsEpsilon 0.1 -steps 80000000 -duel false -tau 40000 -optimiser sharedRmsProp -epsilonSteps 4000000 -doubleQ false -PALpha 0 -eta 0.0007 -gradClip 0 "$@" 73 | elif [ "$PAPER" == "async-a3c" ]; then 74 | th main.lua -env rlenvs.Atari -modelBody models.Atari2013 -hiddenSize 256 -game $GAME -height 84 -width 84 -colorSpace y -async A3C -bootstraps 0 -batchSize 5 -momentum 0.99 -rmsEpsilon 0.1 -steps 80000000 -duel false -tau 40000 -optimiser sharedRmsProp -epsilonSteps 4000000 -doubleQ false -PALpha 0 -eta 0.0007 -gradClip 0 "$@" 75 | 76 | # Examples 77 | elif [ "$PAPER" == "demo-grid" ]; then 78 | # GridWorld 79 | th main.lua -env examples/GridWorldVis -modelBody examples/GridWorldNet -histLen 1 -async A3C -zoom 4 -hiddenSize 32 -optimiser adam -steps 400000 -tau 4 -memSize 20000 -valFreq 10000 -valSteps 6000 -doubleQ false -duel false -bootstraps 0 -PALpha 0 "$@" 80 | else 81 | echo "Invalid options" 82 | fi 83 | -------------------------------------------------------------------------------- /Validation.lua: -------------------------------------------------------------------------------- 1 | local _ = require 'moses' 2 | local classic = require 'classic' 3 | local Evaluator = require 'Evaluator' 4 | 5 | local Validation = classic.class('Validation') 6 | 7 | function Validation:_init(opt, agent, env, display) 8 | self.opt = opt 9 | self.agent = agent 10 | self.env = env 11 | 12 | self.hasDisplay = false 13 | if display then 14 | self.hasDisplay = true 15 | self.display = display 16 | end 17 | 18 | -- Create (Atari normalised score) evaluator 19 | self.evaluator = Evaluator(opt.game) 20 | 21 | self.bestValScore = _.max(self.agent.valScores) or -math.huge -- Retrieve best validation score from agent if available 22 | 23 | classic.strict(self) 24 | end 25 | 26 | 27 | function Validation:validate(step) 28 | log.info('Validating') 29 | -- Set environment and agent to evaluation mode 30 | self.env:evaluate() 31 | self.agent:evaluate() 32 | 33 | -- Start new game 34 | local reward, state, terminal = 0, self.env:start(), false 35 | 36 | -- Validation variables 37 | local valEpisode = 1 38 | local valEpisodeScore = 0 39 | local valTotalScore = 0 40 | local valStepStrFormat = '%0' .. (math.floor(math.log10(self.opt.valSteps)) + 1) .. 'd' -- String format for padding step with zeros 41 | 42 | for valStep = 1, self.opt.valSteps do 43 | -- Observe and choose next action (index) 44 | local action = self.agent:observe(reward, state, terminal) 45 | if not terminal then 46 | -- Act on environment 47 | reward, state, terminal = self.env:step(action) 48 | -- Track score 49 | valEpisodeScore = valEpisodeScore + reward 50 | else 51 | -- Print score every 10 episodes 52 | if valEpisode % 10 == 0 then 53 | log.info('[VAL] Steps: ' .. string.format(valStepStrFormat, valStep) .. '/' .. self.opt.valSteps .. ' | Episode ' .. valEpisode .. ' | Score: ' .. 
valEpisodeScore) 54 | end 55 | 56 | -- Start a new episode 57 | valEpisode = valEpisode + 1 58 | reward, state, terminal = 0, self.env:start(), false 59 | valTotalScore = valTotalScore + valEpisodeScore -- Only add to total score at end of episode 60 | valEpisodeScore = reward -- Reset episode score 61 | end 62 | 63 | -- Display (if available) 64 | if self.hasDisplay then 65 | self.display:display(self.agent, self.env:getDisplay()) 66 | end 67 | end 68 | 69 | -- If no episodes completed then use score from incomplete episode 70 | if valEpisode == 1 then 71 | valTotalScore = valEpisodeScore 72 | end 73 | 74 | -- Print total and average score 75 | log.info('Total Score: ' .. valTotalScore) 76 | valTotalScore = valTotalScore/math.max(valEpisode - 1, 1) -- Only average score for completed episodes in general 77 | log.info('Average Score: ' .. valTotalScore) 78 | -- Pass to agent (for storage and plotting) 79 | self.agent.valScores[#self.agent.valScores + 1] = valTotalScore 80 | -- Calculate normalised score (if valid) 81 | local normScore = self.evaluator:normaliseScore(valTotalScore) 82 | if normScore then 83 | log.info('Normalised Score: ' .. normScore) 84 | self.agent.normScores[#self.agent.normScores + 1] = normScore 85 | end 86 | 87 | -- Visualise convolutional filters 88 | self.agent:visualiseFilters() 89 | 90 | -- Use transitions sampled for validation to test performance 91 | local avgV, avgTdErr = self.agent:validate() 92 | log.info('Average V: ' .. avgV) 93 | log.info('Average δ: ' .. avgTdErr) 94 | 95 | -- Save latest weights 96 | log.info('Saving weights') 97 | if self.opt.checkpoint then 98 | self.agent:saveWeights(paths.concat(self.opt.experiments, self.opt._id, step .. '.weights.t7')) 99 | else 100 | self.agent:saveWeights(paths.concat(self.opt.experiments, self.opt._id, 'last.weights.t7')) 101 | end 102 | 103 | -- Save "best weights" if best score achieved 104 | if valTotalScore > self.bestValScore then 105 | log.info('New best average score') 106 | self.bestValScore = valTotalScore 107 | 108 | log.info('Saving new best weights') 109 | self.agent:saveWeights(paths.concat(self.opt.experiments, self.opt._id, 'best.weights.t7')) 110 | end 111 | 112 | -- Set environment and agent to training mode 113 | self.env:training() 114 | self.agent:training() 115 | end 116 | 117 | 118 | function Validation:evaluate() 119 | log.info('Evaluation mode') 120 | -- Set environment and agent to evaluation mode 121 | self.env:evaluate() 122 | self.agent:evaluate() 123 | 124 | local reward, state, terminal = 0, self.env:start(), false 125 | 126 | -- Report episode score 127 | local episodeScore = reward 128 | 129 | -- Play one game (episode) 130 | local step = 1 131 | while not terminal do 132 | -- Observe and choose next action (index) 133 | action = self.agent:observe(reward, state, terminal) 134 | -- Act on environment 135 | reward, state, terminal = self.env:step(action) 136 | episodeScore = episodeScore + reward 137 | 138 | -- Record (if available) 139 | if self.hasDisplay then 140 | self.display:display(self.agent, self.env:getDisplay(), step) 141 | end 142 | -- Increment evaluation step counter 143 | step = step + 1 144 | end 145 | log.info('Final Score: ' .. 
episodeScore) 146 | 147 | -- Record (if available) 148 | if self.hasDisplay then 149 | self.display:createVideo() 150 | end 151 | end 152 | 153 | 154 | return Validation 155 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Atari 2 | 3 | Looking to contribute something to Atari? Here's how you can help. 4 | 5 | Please take a moment to review this document in order to make the contribution process easy and effective for everyone involved. 6 | 7 | Following these guidelines helps to communicate that you respect the time of the developer managing and developing this open source project. In return, they should reciprocate that respect in addressing your issue or assessing patches and features. 8 | 9 | ## Using the issue tracker 10 | 11 | The [issue tracker](https://github.com/Kaixhin/Atari/issues) is the preferred channel for [bug reports](#bug-reports), [feature requests](#feature-requests) and [submitting pull requests](#pull-requests), but please respect the following restrictions: 12 | 13 | - Please **do not** use the issue tracker for personal support requests. This includes asking for help on your own code. 14 | 15 | - Please **do not** derail or troll issues. Keep the discussion on topic and respect the opinions of others. 16 | 17 | - Please **do not** post comments consisting solely of "+1" or ":thumbsup:". Use [GitHub's "reactions" feature](https://github.com/blog/2119-add-reactions-to-pull-requests-issues-and-comments) instead. I reserve the right to delete comments which violate this rule. 18 | 19 | - Please **do not** ask for this repository to be ported to another language/outside of Torch7. 20 | 21 | ## Bug reports 22 | 23 | A bug is a _demonstrable problem_ that is caused by the code in the repository. Good bug reports are extremely helpful - thank you! 24 | 25 | Guidelines for bug reports: 26 | 27 | 1. **Use the GitHub issue search** — check if the issue has already been reported. 28 | 29 | 2. **Check if the issue has been fixed** — try to reproduce it using the latest `master` or development branch in the repository. 30 | 31 | 3. **Isolate the problem** — ideally create a test case that is within reason, preferably within 100 lines of code. 32 | 33 | A good bug report shouldn't leave others needing to chase you up for more information. Please try to be as detailed as possible in your report. What is your environment? What steps will reproduce the issue? On what OS do you experience the problem? What would you expect to be the outcome? All these details will help people to fix any potential bugs. 34 | 35 | Bugs can be somewhat difficult to isolate in machine learning code, but most of the above should still be applicable. 36 | 37 | ## Feature requests 38 | 39 | Feature requests are welcome. This project has one primary developer who works on it in his free time, so please keep that in mind. 41 | Before opening a feature request, please take a moment to find out whether your idea fits with the scope and aims of the project. It's up to *you* to make a strong case to convince the project's developer of the merits of this feature. Please provide as much detail and context as possible. 42 | 43 | ## Pull requests 44 | 45 | Good pull requests - patches, improvements, new features - are a fantastic help. 
They should remain focused in scope **and avoid containing unrelated commits.** 46 | 47 | **Please ask first** before embarking on any significant pull request (e.g. implementing features, refactoring code), otherwise you risk spending a lot of time working on something that the project's developers might not want to merge into the project. 48 | 49 | Please adhere to the [coding guidelines](#code-guidelines) used throughout the project (indentation, accurate comments, etc.) and any other requirements (such as test coverage). 50 | 51 | Adhering to the following process is the best way to get your work included in the project: 52 | 53 | 1. [Fork](https://help.github.com/articles/fork-a-repo) the project, clone your fork, and configure the remotes: 54 | 55 | ```bash 56 | # Clone your fork of the repo into the current directory 57 | git clone https://github.com/<your-username>/Atari.git 58 | # Navigate to the newly cloned directory 59 | cd Atari 60 | # Assign the original repo to a remote called "upstream" 61 | git remote add upstream https://github.com/Kaixhin/Atari.git 62 | ``` 63 | 64 | 2. If you cloned a while ago, get the latest changes from upstream: 65 | 66 | ```bash 67 | git checkout master 68 | git pull upstream master 69 | ``` 70 | 71 | 3. Create a new topic branch (off the main project development branch) to contain your feature, change, or fix: 72 | 73 | ```bash 74 | git checkout -b <topic-branch-name> 75 | ``` 76 | 77 | 4. Commit your changes in logical chunks. Please try to adhere to these [git commit message guidelines](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html). Use Git's [interactive rebase](https://help.github.com/articles/about-git-rebase) feature to tidy up your commits before making them public. 78 | 79 | 5. Locally merge (or rebase) the upstream development branch into your topic branch: 80 | 81 | ```bash 82 | git pull [--rebase] upstream master 83 | ``` 84 | 85 | 6. Push your topic branch up to your fork: 86 | 87 | ```bash 88 | git push origin <topic-branch-name> 89 | ``` 90 | 91 | 7. [Open a Pull Request](https://help.github.com/articles/using-pull-requests/) with a clear title and description. 92 | 93 | ### Code guidelines 94 | 95 | Please try to follow the general coding style in this repository. Loosely: 96 | 97 | - 2 space indentation (spaces or tab) 98 | - camelCase 99 | - Comments encouraged (especially those including mathematical equations) 100 | 101 | ### License 102 | 103 | By contributing your code, you agree to license your contribution under the [MIT License](https://github.com/Kaixhin/Atari/blob/master/LICENSE.md). 
104 | -------------------------------------------------------------------------------- /Master.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local signal = require 'posix.signal' 3 | local Singleton = require 'structures/Singleton' 4 | local Agent = require 'Agent' 5 | local Display = require 'Display' 6 | local Validation = require 'Validation' 7 | 8 | local Master = classic.class('Master') 9 | 10 | -- Sets up environment and agent 11 | function Master:_init(opt) 12 | self.opt = opt 13 | self.verbose = opt.verbose 14 | self.learnStart = opt.learnStart 15 | self.progFreq = opt.progFreq 16 | self.reportWeights = opt.reportWeights 17 | self.noValidation = opt.noValidation 18 | self.valFreq = opt.valFreq 19 | self.experiments = opt.experiments 20 | self._id = opt._id 21 | 22 | -- Set up singleton global object for transferring step 23 | self.globals = Singleton({step = 1}) -- Initial step 24 | 25 | -- Initialise environment 26 | log.info('Setting up ' .. opt.env) 27 | local Env = require(opt.env) 28 | self.env = Env(opt) -- Environment instantiation 29 | 30 | -- Create DQN agent 31 | log.info('Creating DQN') 32 | self.agent = Agent(opt) 33 | if paths.filep(opt.network) then 34 | -- Load saved agent if specified 35 | log.info('Loading pretrained network weights') 36 | self.agent:loadWeights(opt.network) 37 | elseif paths.filep(paths.concat(opt.experiments, opt._id, 'agent.t7')) then 38 | -- Ask to load saved agent if found in experiment folder (resuming training) 39 | log.info('Saved agent found - load (y/n)?') 40 | if io.read() == 'y' then 41 | log.info('Loading saved agent') 42 | self.agent = torch.load(paths.concat(opt.experiments, opt._id, 'agent.t7')) 43 | 44 | -- Reset globals (step) from agent 45 | Singleton.setInstance(self.agent.globals) 46 | self.globals = Singleton.getInstance() 47 | 48 | -- Switch saliency style 49 | self.agent:setSaliency(opt.saliency) 50 | end 51 | end 52 | 53 | -- Start gaming 54 | log.info('Starting ' .. opt.env) 55 | if opt.game ~= '' then 56 | log.info('Starting game: ' .. opt.game) 57 | end 58 | local state = self.env:start() 59 | 60 | -- Set up display (if available) 61 | self.hasDisplay = false 62 | if opt.displaySpec then 63 | self.hasDisplay = true 64 | self.display = Display(opt, self.env:getDisplay()) 65 | end 66 | 67 | -- Set up validation (with display if available) 68 | self.validation = Validation(opt, self.agent, self.env, self.display) 69 | 70 | classic.strict(self) 71 | end 72 | 73 | -- Trains agent 74 | function Master:train() 75 | log.info('Training mode') 76 | 77 | -- Catch CTRL-C to save 78 | self:catchSigInt() 79 | 80 | local reward, state, terminal = 0, self.env:start(), false 81 | 82 | -- Set environment and agent to training mode 83 | self.env:training() 84 | self.agent:training() 85 | 86 | -- Training variables (reported in verbose mode) 87 | local episode = 1 88 | local episodeScore = reward 89 | 90 | -- Training loop 91 | local initStep = self.globals.step -- Extract step 92 | local stepStrFormat = '%0' .. (math.floor(math.log10(self.opt.steps)) + 1) .. 
'd' -- String format for padding step with zeros 93 | for step = initStep, self.opt.steps do 94 | self.globals.step = step -- Pass step number to globals for use in other modules 95 | 96 | -- Observe results of previous transition (r, s', terminal') and choose next action (index) 97 | local action = self.agent:observe(reward, state, terminal) -- As results received, learn in training mode 98 | if not terminal then 99 | -- Act on environment (to cause transition) 100 | reward, state, terminal = self.env:step(action) 101 | -- Track score 102 | episodeScore = episodeScore + reward 103 | else 104 | if self.verbose then 105 | -- Print score for episode 106 | log.info('Steps: ' .. string.format(stepStrFormat, step) .. '/' .. self.opt.steps .. ' | Episode ' .. episode .. ' | Score: ' .. episodeScore) 107 | end 108 | 109 | -- Start a new episode 110 | episode = episode + 1 111 | reward, state, terminal = 0, self.env:start(), false 112 | episodeScore = reward -- Reset episode score 113 | end 114 | 115 | -- Display (if available) 116 | if self.hasDisplay then 117 | self.display:display(self.agent, self.env:getDisplay()) 118 | end 119 | 120 | -- Trigger learning after a while (wait to accumulate experience) 121 | if step == self.learnStart then 122 | log.info('Learning started') 123 | end 124 | 125 | -- Report progress 126 | if step % self.progFreq == 0 then 127 | log.info('Steps: ' .. string.format(stepStrFormat, step) .. '/' .. self.opt.steps) 128 | -- Report weight and weight gradient statistics 129 | if self.reportWeights then 130 | local reports = self.agent:report() 131 | for r = 1, #reports do 132 | log.info(reports[r]) 133 | end 134 | end 135 | end 136 | 137 | -- Validate 138 | if not self.noValidation and step >= self.learnStart and step % self.valFreq == 0 then 139 | self.validation:validate(step) -- Sets env and agent to evaluation mode and then back to training mode 140 | 141 | log.info('Resuming training') 142 | -- Start new game (as previous one was interrupted) 143 | reward, state, terminal = 0, self.env:start(), false 144 | episodeScore = reward 145 | end 146 | end 147 | 148 | log.info('Finished training') 149 | end 150 | 151 | function Master:evaluate() 152 | self.validation:evaluate() -- Sets env and agent to evaluation mode 153 | end 154 | 155 | -- Sets up SIGINT (Ctrl+C) handler to save network before quitting 156 | function Master:catchSigInt() 157 | signal.signal(signal.SIGINT, function(signum) 158 | log.warn('SIGINT received') 159 | log.info('Save agent (y/n)?') 160 | if io.read() == 'y' then 161 | log.info('Saving agent') 162 | torch.save(paths.concat(self.experiments, self._id, 'agent.t7'), self.agent) -- Save agent to resume training 163 | end 164 | log.warn('Exiting') 165 | os.exit(128 + signum) 166 | end) 167 | end 168 | 169 | return Master 170 | -------------------------------------------------------------------------------- /structures/BinaryHeap.lua: -------------------------------------------------------------------------------- 1 | local _ = require 'moses' 2 | local classic = require 'classic' 3 | require 'classic.torch' -- Enables serialisation 4 | 5 | -- Implements a Priority Queue using a non-standard (Maximum) Binary Heap 6 | local BinaryHeap = classic.class('BinaryHeap') 7 | 8 | --[[ 9 | -- Priority queue elements: 10 | -- array row 1 (priority/key): absolute TD-error |δ| 11 | -- array row 2 (value): experience replay index 12 | -- ephash key: experience replay index 13 | -- ephash value: priority queue array index 14 | -- pehash key: priority queue array index 
15 | -- pehash value: experience replay index 16 | --]] 17 | 18 | -- Creates a new Binary Heap with a length or existing tensor 19 | function BinaryHeap:_init(init) 20 | -- Use values as indices in hash tables (ER -> PQ, PQ -> ER) 21 | self.ephash = {} 22 | self.pehash = {} 23 | 24 | if type(init) == 'number' then 25 | -- init is treated as the length of the heap 26 | self.array = torch.Tensor(init, 2) -- Priorities are 1st, values (which are used as hash table keys) are 2nd 27 | self.size = 0 28 | else 29 | -- Otherwise assume tensor to build heap from 30 | self.array = init 31 | self.size = init:size(1) 32 | -- Convert values to form hash tables 33 | self.ephash = torch.totable(self.array:select(2, 2)) 34 | self.pehash = _.invert(self.ephash) 35 | -- Rebalance 36 | for i = math.floor(self.size/2) - 1, 1, -1 do 37 | self:downHeap(i) 38 | end 39 | end 40 | end 41 | 42 | -- Checks if heap is full 43 | function BinaryHeap:isFull() 44 | return self.size == self.array:size(1) 45 | end 46 | 47 | --[[ 48 | -- Indices of connected nodes: 49 | -- Parent(i) = floor(i/2) 50 | -- Left_Child(i) = 2i 51 | -- Right_Child(i) = 2i+1 52 | --]] 53 | 54 | -- Inserts a new value 55 | function BinaryHeap:insert(priority, val) 56 | -- Refuse to add values if no space left 57 | if self:isFull() then 58 | print('Error: no space left in heap to add value ' .. val .. ' with priority ' .. priority) 59 | return 60 | end 61 | 62 | -- Add value to end 63 | self.size = self.size + 1 64 | self.array[self.size][1] = priority 65 | self.array[self.size][2] = val 66 | -- Update hash tables 67 | self.ephash[val] = self.size 68 | self.pehash[self.size] = val 69 | 70 | -- Rebalance 71 | self:upHeap(self.size) 72 | end 73 | 74 | -- Updates a value (and rebalances) 75 | function BinaryHeap:update(i, priority, val) 76 | if i > self.size then 77 | print('Error: index ' .. i .. 
' is greater than the current size of the heap') 78 | return 79 | end 80 | 81 | -- Replace value 82 | self.array[i][1] = priority 83 | self.array[i][2] = val 84 | -- Update hash tables 85 | self.ephash[val] = i 86 | self.pehash[i] = val 87 | 88 | -- Rebalance 89 | self:downHeap(i) 90 | self:upHeap(i) 91 | end 92 | 93 | -- Updates a value by using the value (using the ER -> PQ hash table) 94 | function BinaryHeap:updateByVal(valKey, priority, val) 95 | self:update(self.ephash[valKey], priority, val) 96 | end 97 | 98 | -- Returns the maximum priority with value 99 | function BinaryHeap:findMax() 100 | return self.size ~= 0 and self.array[1][1] or nil 101 | end 102 | 103 | -- Removes and returns the maximum priority with value 104 | function BinaryHeap:pop() 105 | -- Return nil if no values 106 | if self.size == 0 then 107 | print('Error: no values in heap') 108 | return nil 109 | end 110 | 111 | local max = self.array[1]:clone() 112 | 113 | -- Move the last value (not necessarily the smallest) to the root 114 | self.array[1] = self.array[self.size] 115 | self.size = self.size - 1 116 | -- Update hash tables 117 | self.ephash[self.array[1][2]], self.pehash[1] = 1, self.array[1][2] 118 | 119 | -- Rebalance 120 | self:downHeap(1) 121 | 122 | return max 123 | end 124 | 125 | -- Rebalances the heap (by moving large values up) 126 | function BinaryHeap:upHeap(i) 127 | -- Calculate parent index 128 | local p = math.floor(i/2) 129 | 130 | if i > 1 then 131 | -- If parent is smaller than child then swap 132 | if self.array[p][1] < self.array[i][1] then 133 | self.array[i], self.array[p] = self.array[p]:clone(), self.array[i]:clone() 134 | -- Update hash tables 135 | self.ephash[self.array[i][2]], self.ephash[self.array[p][2]], self.pehash[i], self.pehash[p] = i, p, self.array[i][2], self.array[p][2] 136 | 137 | -- Continue rebalancing 138 | self:upHeap(p) 139 | end 140 | end 141 | end 142 | 143 | -- Rebalances the heap (by moving small values down) 144 | function BinaryHeap:downHeap(i) 145 | -- Calculate left and right child indices 146 | local l, r = 2*i, 2*i + 1 147 | 148 | -- Find the index of the greatest of these elements 149 | local greatest 150 | if l <= self.size and self.array[l][1] > self.array[i][1] then 151 | greatest = l 152 | else 153 | greatest = i 154 | end 155 | if r <= self.size and self.array[r][1] > self.array[greatest][1] then 156 | greatest = r 157 | end 158 | 159 | -- Continue rebalancing if necessary 160 | if greatest ~= i then 161 | self.array[i], self.array[greatest] = self.array[greatest]:clone(), self.array[i]:clone() 162 | -- Update hash tables 163 | self.ephash[self.array[i][2]], self.ephash[self.array[greatest][2]], self.pehash[i], self.pehash[greatest] = i, greatest, self.array[i][2], self.array[greatest][2] 164 | 165 | self:downHeap(greatest) 166 | end 167 | end 168 | 169 | -- Retrieves priorities 170 | function BinaryHeap:getPriorities() 171 | return self.array:narrow(2, 1, 1) 172 | end 173 | 174 | -- Retrieves values 175 | function BinaryHeap:getValues() 176 | return self.array:narrow(2, 2, 1) 177 | end 178 | 179 | -- Basic visualisation of heap 180 | function BinaryHeap:__tostring() 181 | local str = '' 182 | local level = -1 183 | local maxLevel = math.floor(math.log(self.size, 2)) 184 | 185 | -- Print each level 186 | for i = 1, self.size do 187 | -- Add a new line and spacing for each new level 188 | local l = math.floor(math.log(i, 2)) 189 | if level ~= l then 190 | str = str .. '\n' .. 
string.rep(' ', math.pow(2, maxLevel - l)) 191 | level = l 192 | end 193 | -- Print value and spacing 194 | str = str .. string.format('%.2f ', self.array[i][2]) .. string.rep(' ', maxLevel - l) 195 | end 196 | 197 | return str 198 | end 199 | 200 | -- Retrieves a value by using the value (using the PQ -> ER hash table) 201 | function BinaryHeap:getValueByVal(hashIndex) 202 | return self.pehash[hashIndex] 203 | end 204 | 205 | -- Retrieves a list of values by using the value (using the PQ -> ER hash table) 206 | function BinaryHeap:getValuesByVal(hashIndices) 207 | return _.at(self.pehash, table.unpack(hashIndices)) 208 | end 209 | 210 | -- Rebalances the heap 211 | -- Note from Tom Schaul: Solution for rebalancing (below) is good; original solution not revealed 212 | function BinaryHeap:rebalance() 213 | -- Sort underlying array 214 | local sortArray, sortIndices = torch.sort(self.array, 1, true) 215 | -- Retrieve values (indices) in descending priority order 216 | sortIndices = self.array:index(1, sortIndices:select(2, 1)):select(2, 2) 217 | -- Put values with corresponding priorities 218 | sortArray[{{}, {2}}] = sortIndices 219 | -- Convert values to form hash tables 220 | self.pehash = torch.totable(sortIndices) 221 | self.ephash = _.invert(self.pehash) 222 | -- Replace array 223 | self.array = sortArray 224 | -- Fix heap 225 | for i = math.floor(self.size/2) - 1, 1, -1 do 226 | self:downHeap(i) 227 | end 228 | end 229 | 230 | return BinaryHeap 231 | -------------------------------------------------------------------------------- /async/AsyncMaster.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | local threads = require 'threads' 3 | local tds = require 'tds' 4 | local signal = require 'posix.signal' 5 | local AsyncModel = require 'async/AsyncModel' 6 | local AsyncAgent = require 'async/AsyncAgent' 7 | local QAgent = require 'async/QAgent' 8 | local OneStepQAgent = require 'async/OneStepQAgent' 9 | local NStepQAgent = require 'async/NStepQAgent' 10 | local A3CAgent = require 'async/A3CAgent' 11 | local ValidationAgent = require 'async/ValidationAgent' 12 | require 'socket' 13 | threads.Threads.serialization('threads.sharedserialize') 14 | 15 | local FINISHED = -99999999 16 | 17 | local AsyncMaster = classic.class('AsyncMaster') 18 | 19 | 20 | local function checkNotNan(t) 21 | local sum = t:sum() 22 | local ok = sum == sum 23 | if not ok then 24 | log.error('ERROR '.. sum) 25 | end 26 | assert(ok) 27 | end 28 | 29 | local function torchSetup(opt) 30 | local tensorType = opt.tensorType 31 | local seed = opt.seed 32 | return function() 33 | log.info('Setting up Torch7') 34 | require 'nn' 35 | require 'rnn' 36 | require 'nngraph' 37 | require 'modules/GradientRescale' 38 | -- Set number of BLAS threads to 1 (per thread) 39 | torch.setnumthreads(1) 40 | -- Set default Tensor type (float is more efficient than double) 41 | torch.setdefaulttensortype(tensorType) 42 | -- Set manual seed (different for each thread to have different experiences) 43 | torch.manualSeed(seed * __threadid) 44 | end 45 | end 46 | 47 | local function threadedFormatter(thread) 48 | local threadName = thread 49 | 50 | return function(level, ...) 
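    -- Build the message: with more than one argument, treat the first as a format string
    -- for the rest; otherwise pretty-print the single argument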
51 | local msg = nil 52 | 53 | if #{...} > 1 then 54 | msg = string.format(({...})[1], table.unpack(fn.rest({...}))) 55 | else 56 | msg = pprint.pretty_string(({...})[1]) 57 | end 58 | 59 | return string.format("[%s: %s - %s] - %s\n", threadName, logroll.levels[level], os.date("%Y_%m_%d_%X"), msg) 60 | end 61 | end 62 | 63 | local function setupLogging(opt, thread) 64 | local _id = opt._id 65 | local threadName = thread 66 | return function() 67 | unpack = table.unpack -- TODO: Remove global unpack from dependencies 68 | -- Create log10 for Lua 5.2 69 | if not math.log10 then 70 | math.log10 = function(x) 71 | return math.log(x, 10) 72 | end 73 | end 74 | 75 | require 'logroll' 76 | local thread = threadName or __threadid 77 | if type(thread) == 'number' then 78 | thread = ('%02d'):format(thread) 79 | end 80 | local file = paths.concat('experiments', _id, 'log.'.. thread ..'.txt') 81 | local flog = logroll.file_logger(file) 82 | local formatterFunc = threadedFormatter(thread) 83 | local plog = logroll.print_logger({formatter = formatterFunc}) 84 | log = logroll.combine(flog, plog) 85 | end 86 | end 87 | 88 | 89 | function AsyncMaster:_init(opt) 90 | self.opt = opt 91 | 92 | self.stateFile = paths.concat('experiments', opt._id, 'agent.async.t7') 93 | 94 | local asyncModel = AsyncModel(opt) 95 | local policyNet = asyncModel:createNet() 96 | self.theta = policyNet:getParameters() 97 | 98 | log.info('%s', policyNet) 99 | 100 | if paths.filep(opt.network) then 101 | log.info('Loading pretrained network weights') 102 | local weights = torch.load(opt.network) 103 | self.theta:copy(weights) 104 | end 105 | 106 | self.atomic = tds.AtomicCounter() 107 | 108 | local targetNet = policyNet:clone() 109 | self.targetTheta = targetNet:getParameters() 110 | local sharedG = self.theta:clone():zero() 111 | 112 | local theta = self.theta 113 | local targetTheta = self.targetTheta 114 | local stateFile = self.stateFile 115 | local atomic = self.atomic 116 | 117 | self.controlPool = threads.Threads(1) 118 | 119 | self.controlPool:addjob(setupLogging(opt, 'VA')) 120 | self.controlPool:addjob(torchSetup(opt)) 121 | self.controlPool:addjob(function() 122 | -- distinguish from thread 1 in the agent pool 123 | __threadid = 0 124 | local signal = require 'posix.signal' 125 | local ValidationAgent = require 'async/ValidationAgent' 126 | validAgent = ValidationAgent(opt, theta, atomic) 127 | if not opt.noValidation then 128 | signal.signal(signal.SIGINT, function(signum) 129 | log.warn('SIGINT received') 130 | log.info('Saving agent') 131 | local globalSteps = atomic:get() 132 | local state = { globalSteps = globalSteps } 133 | torch.save(stateFile, state) 134 | 135 | validAgent:saveWeights('last') 136 | log.warn('Exiting') 137 | os.exit(128 + signum) 138 | end) 139 | end 140 | end) 141 | 142 | self.controlPool:synchronize() 143 | 144 | -- without locking xitari sometimes crashes during initialization 145 | -- but not later... but is it really threadsafe then...? 
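  -- Construct the per-thread agents one at a time: each worker recreates the mutex from its id
  -- and locks it around AsyncAgent.build, serialising xitari initialisation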
146 | local mutex = threads.Mutex() 147 | local mutexId = mutex:id() 148 | self.pool = threads.Threads(self.opt.threads, function() 149 | end, 150 | setupLogging(opt), 151 | torchSetup(opt), 152 | function() 153 | local threads1 = require 'threads' 154 | local AsyncAgent = require 'async/AsyncAgent' 155 | local mutex1 = threads1.Mutex(mutexId) 156 | mutex1:lock() 157 | agent = AsyncAgent.build(opt, policyNet, targetNet, theta, targetTheta, atomic, sharedG) 158 | mutex1:unlock() 159 | end 160 | ) 161 | mutex:free() 162 | 163 | classic.strict(self) 164 | end 165 | 166 | 167 | function AsyncMaster:start() 168 | local stepsToGo = math.floor(self.opt.steps / self.opt.threads) 169 | local startStep = 0 170 | if self.opt.network ~= '' and paths.filep(self.stateFile) then 171 | local state = torch.load(self.stateFile) 172 | stepsToGo = math.floor((self.opt.steps - state.globalSteps) / self.opt.threads) 173 | startStep = math.floor(state.globalSteps / self.opt.threads) 174 | self.atomic:set(state.globalSteps) 175 | log.info('Resuming training from step %d', state.globalSteps) 176 | log.info('Loading pretrained network weights') 177 | local weights = torch.load(paths.concat('experiments', self.opt._id, 'last.weights.t7')) 178 | self.theta:copy(weights) 179 | self.targetTheta:copy(self.theta) 180 | end 181 | 182 | local atomic = self.atomic 183 | local opt = self.opt 184 | local theta = self.theta 185 | local targetTheta = self.targetTheta 186 | 187 | local validator = function() 188 | local posix = require 'posix' 189 | validAgent:start() 190 | local lastUpdate = 0 191 | while true do 192 | local globalStep = atomic:get() 193 | if globalStep < 0 then return end 194 | 195 | local countSince = globalStep - lastUpdate 196 | if countSince > opt.valFreq then 197 | log.info('starting validation after %d steps', countSince) 198 | lastUpdate = globalStep 199 | local status, err = xpcall(validAgent.validate, debug.traceback, validAgent) 200 | if not status then 201 | log.error('%s', err) 202 | os.exit(128) 203 | end 204 | end 205 | posix.sleep(1) 206 | end 207 | end 208 | 209 | if not self.opt.noValidation then 210 | self.controlPool:addjob(validator) 211 | end 212 | 213 | for i=1,self.opt.threads do 214 | self.pool:addjob(function() 215 | local status, err = xpcall(agent.learn, debug.traceback, agent, stepsToGo, startStep) 216 | if not status then 217 | log.error('%s', err) 218 | os.exit(128) 219 | end 220 | end) 221 | end 222 | 223 | self.pool:synchronize() 224 | self.atomic:set(FINISHED) 225 | 226 | self.controlPool:synchronize() 227 | 228 | self.pool:terminate() 229 | self.controlPool:terminate() 230 | end 231 | 232 | return AsyncMaster 233 | 234 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Atari ![Space Invader](http://www.rw-designer.com/cursor-view/74522.png) 2 | [![Build Status](https://img.shields.io/travis/Kaixhin/Atari.svg)](https://travis-ci.org/Kaixhin/Atari) 3 | [![MIT License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md) 4 | [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/Kaixhin/Atari?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 5 | 6 | **Work In Progress:** Crossed out items have been partially implemented. 
7 | 8 | ~~Prioritised experience replay~~ [[1]](#references) persistent advantage learning [[2]](#references) ~~bootstrapped~~ [[3]](#references) dueling [[4]](#references) double [[5]](#references) deep ~~recurrent~~ [[6]](#references) Q-network [[7]](#references) for the Arcade Learning Environment [[8]](#references) (and [custom environments](#custom)). Or PERPALB(triple-D)RQN for short... 9 | 10 | Additional asynchronous agents [[9]](#references): 11 | 12 | - One-step Sarsa 13 | - One-step Q-learning 14 | - N-step Q-learning 15 | - Advantage actor-critic 16 | 17 | Run `th main.lua` to run headless, or `qlua main.lua` to display the game. The main options are `-game` to choose the ROM (see the [ROM directory](roms/README.md) for more details) and `-mode` as either `train` or `eval`. Saliency maps [[10]](#references) can be visualised, optionally using guided [[11]](#references) or "deconvnet" [[12]](#references) backpropagation. Saliency map modes are applied at runtime so that they can be applied retrospectively to saved models. 18 | 19 | To run experiments based on hyperparameters specified in the individual papers, use `./run.sh <paper> <game> [options]`, where the trailing `[options]` can be used to overwrite arguments specified earlier (in the script); for more details see the script itself. By default the code trains on a demo environment called Catch - use `./run.sh demo` to run the demo with good default parameters. Note that this code uses CUDA if available, but the Catch network is small enough that it runs faster on CPU. If cuDNN is available, it can be enabled using `-cudnn true`; note that by default cuDNN is nondeterministic, and its deterministic modes are slower than cutorch. 20 | 21 | In training mode, quitting with `Ctrl+C` will be caught and you will be asked if you would like to save the agent. Note that for non-asynchronous agents the experience replay memory will be included in the saved agent, totalling ~7GB. The main script also automatically saves the last weights (`last.weights.t7`) and the weights of the best performing DQN (according to the average validation score) (`best.weights.t7`). 22 | 23 | In evaluation mode you can create recordings with `-record true` (requires FFmpeg); this does not require using `qlua`. Recordings will be stored in the videos directory. 24 | 25 | ## Requirements 26 | 27 | Requires [Torch7](http://torch.ch/), and can use CUDA and cuDNN if available. Also requires the following extra luarocks packages: 28 | 29 | - luaposix 33.4.0 30 | - luasocket 31 | - moses 32 | - logroll 33 | - classic 34 | - torchx 35 | - rnn 36 | - dpnn 37 | - nninit 38 | - tds 39 | - **xitari** 40 | - **alewrap** 41 | - **rlenvs** 42 | 43 | xitari, alewrap and rlenvs can be installed using the following commands: 44 | 45 | ```sh 46 | luarocks install https://raw.githubusercontent.com/lake4790k/xitari/master/xitari-0-0.rockspec 47 | luarocks install https://raw.githubusercontent.com/Kaixhin/alewrap/master/alewrap-0-0.rockspec 48 | luarocks install https://raw.githubusercontent.com/Kaixhin/rlenvs/master/rocks/rlenvs-scm-1.rockspec 49 | ``` 50 | 51 | ## Custom 52 | 53 | You can use a custom environment (as the path to a Lua file/`rlenvs`-namespaced environment) using `-env`, as long as the class returned respects the `rlenvs` [API](https://github.com/Kaixhin/rlenvs#api). One restriction is that the state must be represented as a single tensor (with arbitrary dimensionality), and only a single discrete action must be returned. 
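As a rough illustration, a bare-bones custom environment could look like the sketch below. The names (`Corridor`, `length`) are hypothetical, and the spec tables assume the `{type, shape, range}` convention of the `rlenvs` API linked above; treat it as a starting point rather than a drop-in implementation.

```lua
local classic = require 'classic'
require 'torch'

-- Hypothetical example: a tiny 1D corridor with two actions (stay/move right)
local Corridor = classic.class('Corridor')

function Corridor:_init(opts)
  self.length = 10
end

-- Single real-valued state tensor with entries in [0, 1] (a 1x1x10 one-hot position)
function Corridor:getStateSpec()
  return {'real', {1, 1, self.length}, {0, 1}}
end

-- Single discrete action in {0, 1}
function Corridor:getActionSpec()
  return {'int', 1, {0, 1}}
end

-- Builds the one-hot observation for the current position
function Corridor:observation()
  local state = torch.zeros(1, 1, self.length)
  state[1][1][self.position] = 1
  return state
end

-- Starts a new episode and returns the initial state
function Corridor:start()
  self.position = 1
  return self:observation()
end

-- Advances one step; returns reward, state and terminal flag
function Corridor:step(action)
  self.position = math.min(self.position + action, self.length)
  local terminal = self.position == self.length
  return terminal and 1 or 0, self:observation(), terminal
end

return Corridor
```

Such a file would be passed via `-env` in the same require-style form as the GridWorld example mentioned below.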
To prevent massive memory consumption for agents that use experience replay memory, states are discretised to integers ∈ [0, 255], assuming the state is comprised of reals ∈ [0, 1] - this can be disabled with `-discretiseMem false`. Visual environments can make use of explicit `-height`, `-width` and `-colorSpace` options to perform preprocessing for the network. 54 | 55 | If the environment has separate behaviour during training and testing it should also implement `training` and `evaluate` methods - otherwise these will be added as empty methods during runtime. The environment can also implement a `getDisplay` method (with a mandatory `getDisplaySpec` method for determining screen size) which will be used for displaying the screen/computing saliency maps, where `getDisplay` must return a RGB (3D) tensor; this can also be utilised even if the state is not an image (although saliency can only be computed for states that are images). This **must** be implemented to have a visual display/computing saliency maps. The `-zoom` factor can be used to increase the size of small displays. 56 | 57 | Environments are meant to be ephemeral, as an instance is created in order to first extract environment details (e.g. state representation), which will later be automatically garbage collected (not under the control of this code). 58 | 59 | You can also use a custom model (body) with `-modelBody`, which replaces the usual DQN convolutional layers with a custom Torch neural network (as the path to a Lua file/`models`-namespaced environment). The class must include a `createBody` method which returns the custom neural network. The model will receive a stack of the previous states (as determined by `-histLen`), and must reshape them manually if needed. The DQN "heads" will then be constructed as normal, with `-hiddenSize` used to change the size of the fully connected layer if needed. 60 | 61 | For an example on a GridWorld environment, run `./run.sh demo-grid` - the demo also works with `qlua` and experience replay agents. The custom environment and network can be found in the [examples](https://github.com/Kaixhin/Atari/tree/master/examples) folder. 62 | 63 | ## Results 64 | 65 | Single run results from various papers can be seen below. DQN-based agents use [ε = 0.001](https://github.com/Kaixhin/Atari/blob/master/Agent.lua#L162) for evaluation [[4, 5]](#references). 66 | 67 | ### DQN (Space Invaders) [[7]](#references) 68 | 69 | ![DQN](figures/dqn_space_invaders.png) 70 | 71 | ### Double DQN (Space Invaders) [[5]](#references) 72 | 73 | ![DDQN](figures/doubleq_space_invaders.png) 74 | 75 | ### Dueling DQN (Space Invaders) [[4]](#references) 76 | 77 | ![DuelingDQN](figures/dueling_space_invaders.png) 78 | 79 | ### Persistent Advantage Learning DQN (Asterix) [[2]](#references) 80 | 81 | ![PALDQN](figures/pal_asterix.png) 82 | 83 | ### A3C (Beam Rider) [[9]](#references) 84 | 85 | ![A3C](figures/a3c_beam_rider.png) 86 | 87 | ## Acknowledgements 88 | 89 | - [@GeorgOstrovski](https://github.com/GeorgOstrovski) for confirmation on network usage in advantage operators + note on interaction with Double DQN. 90 | - [@schaul](https://github.com/schaul) for clarifications on prioritised experience replay + dueling DQN hyperparameters. 
91 | 92 | ## Citation 93 | 94 | If you find this library useful and would like to cite it, the following would be appropriate: 95 | 96 | ``` 97 | @misc{Atari, 98 | author = {Arulkumaran, Kai and Keri, Laszlo}, 99 | title = {Kaixhin/Atari}, 100 | url = {https://github.com/Kaixhin/Atari}, 101 | year = {2015} 102 | } 103 | ``` 104 | 105 | ## References 106 | 107 | [1] [Prioritized Experience Replay](http://arxiv.org/abs/1511.05952) 108 | [2] [Increasing the Action Gap: New Operators for Reinforcement Learning](http://arxiv.org/abs/1512.04860) 109 | [3] [Deep Exploration via Bootstrapped DQN](http://arxiv.org/abs/1602.04621) 110 | [4] [Dueling Network Architectures for Deep Reinforcement Learning](http://arxiv.org/abs/1511.06581) 111 | [5] [Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461) 112 | [6] [Deep Recurrent Q-Learning for Partially Observable MDPs](http://arxiv.org/abs/1507.06527) 113 | [7] [Playing Atari with Deep Reinforcement Learning](http://arxiv.org/abs/1312.5602) 114 | [8] [The Arcade Learning Environment: An Evaluation Platform for General Agents](http://arxiv.org/abs/1207.4708) 115 | [9] [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783) 116 | [10] [Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps](http://arxiv.org/abs/1312.6034) 117 | [11] [Striving for Simplicity: The All Convolutional Net](http://arxiv.org/abs/1412.6806) 118 | [12] [Visualizing and Understanding Convolutional Networks](http://arxiv.org/abs/1311.2901) 119 | -------------------------------------------------------------------------------- /Model.lua: -------------------------------------------------------------------------------- 1 | local _ = require 'moses' 2 | local paths = require 'paths' 3 | local classic = require 'classic' 4 | local nn = require 'nn' 5 | local hasCudnn, cudnn = pcall(require, 'cudnn') -- Use cuDNN if available 6 | local nninit = require 'nninit' 7 | local image = require 'image' 8 | local DuelAggregator = require 'modules/DuelAggregator' 9 | require 'classic.torch' -- Enables serialisation 10 | require 'rnn' 11 | require 'dpnn' -- Adds gradParamClip method 12 | require 'modules/GuidedReLU' 13 | require 'modules/DeconvnetReLU' 14 | require 'modules/GradientRescale' 15 | require 'modules/MinDim' 16 | 17 | local Model = classic.class('Model') 18 | 19 | -- Creates a Model (a helper for the network it creates) 20 | function Model:_init(opt) 21 | -- Extract relevant options 22 | self.tensorType = opt.tensorType 23 | self.gpu = opt.gpu 24 | self.cudnn = opt.cudnn 25 | self.colorSpace = opt.colorSpace 26 | self.width = opt.width 27 | self.height = opt.height 28 | self.modelBody = opt.modelBody 29 | self.hiddenSize = opt.hiddenSize 30 | self.histLen = opt.histLen 31 | self.duel = opt.duel 32 | self.bootstraps = opt.bootstraps 33 | self.recurrent = opt.recurrent 34 | self.env = opt.env 35 | self.modelBody = opt.modelBody 36 | self.async = opt.async 37 | self.a3c = opt.async == 'A3C' 38 | self.stateSpec = opt.stateSpec 39 | 40 | self.m = opt.actionSpec[3][2] - opt.actionSpec[3][1] + 1 -- Number of discrete actions 41 | -- Set up resizing 42 | if opt.width ~= 0 or opt.height ~= 0 then 43 | self.resize = true 44 | self.width = opt.width ~= 0 and opt.width or opt.stateSpec[2][3] 45 | self.height = opt.height ~= 0 and opt.height or opt.stateSpec[2][2] 46 | end 47 | end 48 | 49 | -- Processes a single frame for DQN input; must not return same memory to prevent side-effects 50 | function 
Model:preprocess(observation) 51 | local frame = observation:type(self.tensorType) -- Convert from CudaTensor if necessary 52 | 53 | -- Perform colour conversion if needed 54 | if self.colorSpace then 55 | frame = image['rgb2' .. self.colorSpace](frame) 56 | end 57 | 58 | -- Resize screen if needed 59 | if self.resize then 60 | frame = image.scale(frame, self.width, self.height) 61 | end 62 | 63 | -- Clone if needed 64 | if frame == observation then 65 | frame = frame:clone() 66 | end 67 | 68 | return frame 69 | end 70 | 71 | -- Calculates network output size 72 | local function getOutputSize(net, inputDims) 73 | return net:forward(torch.Tensor(torch.LongStorage(inputDims))):size():totable() 74 | end 75 | 76 | -- Creates a DQN/AC model based on a number of discrete actions 77 | function Model:create() 78 | -- Number of input frames for recurrent networks is always 1 79 | local histLen = self.recurrent and 1 or self.histLen 80 | 81 | -- Network starting with convolutional layers/model body 82 | local net = nn.Sequential() 83 | if self.recurrent then 84 | net:add(nn.Copy(nil, nil, true)) -- Needed when splitting batch x seq x input over seq for DRQN; better than nn.Contiguous 85 | end 86 | 87 | -- Add network body 88 | log.info('Setting up ' .. self.modelBody) 89 | local Body = require(self.modelBody) 90 | local body = Body(self):createBody() 91 | 92 | -- Calculate body output size 93 | local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(body, _.append({histLen}, self.stateSpec[2])))) 94 | if not self.async and self.recurrent then 95 | body:add(nn.View(-1, bodyOutputSize)) 96 | net:add(nn.MinDim(1, 4)) 97 | net:add(nn.Transpose({1, 2})) 98 | body = nn.Bottle(body, 4, 2) 99 | net:add(body) 100 | net:add(nn.MinDim(1, 3)) 101 | else 102 | body:add(nn.View(bodyOutputSize)) 103 | net:add(body) 104 | end 105 | 106 | -- Network head 107 | local head = nn.Sequential() 108 | local heads = math.max(self.bootstraps, 1) 109 | if self.duel then 110 | -- Value approximator V^(s) 111 | local valStream = nn.Sequential() 112 | if self.recurrent and self.async then 113 | local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen) 114 | lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) 115 | lstm:remember('both') 116 | valStream:add(lstm) 117 | elseif self.recurrent then 118 | local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize) 119 | lstm:remember('both') 120 | valStream:add(lstm) 121 | valStream:add(nn.Select(-3, -1)) -- Select last timestep 122 | else 123 | valStream:add(nn.Linear(bodyOutputSize, self.hiddenSize)) 124 | valStream:add(nn.ReLU(true)) 125 | end 126 | valStream:add(nn.Linear(self.hiddenSize, 1)) -- Predicts value for state 127 | 128 | -- Advantage approximator A^(s, a) 129 | local advStream = nn.Sequential() 130 | if self.recurrent and self.async then 131 | local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen) 132 | lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000) 133 | lstm:remember('both') 134 | advStream:add(lstm) 135 | elseif self.recurrent then 136 | local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize) 137 | lstm:remember('both') 138 | advStream:add(lstm) 139 | advStream:add(nn.Select(-3, -1)) -- Select last timestep 140 | else 141 | advStream:add(nn.Linear(bodyOutputSize, self.hiddenSize)) 142 | advStream:add(nn.ReLU(true)) 143 | end 144 | advStream:add(nn.Linear(self.hiddenSize, self.m)) -- Predicts 
action-conditional advantage 145 | 146 | -- Streams container 147 | local streams = nn.ConcatTable() 148 | streams:add(valStream) 149 | streams:add(advStream) 150 | 151 | -- Network finishing with fully connected layers 152 | head:add(nn.GradientRescale(1/math.sqrt(2), true)) -- Heuristic that mildly increases stability for duel 153 | -- Create dueling streams 154 | head:add(streams) 155 | -- Add dueling streams aggregator module 156 | head:add(DuelAggregator(self.m)) 157 | else 158 | if self.recurrent and self.async then 159 | local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen) 160 | lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000) 161 | lstm:remember('both') 162 | head:add(lstm) 163 | elseif self.recurrent then 164 | local lstm = nn.SeqLSTM(bodyOutputSize, self.hiddenSize) 165 | lstm:remember('both') 166 | head:add(lstm) 167 | head:add(nn.Select(-3, -1)) -- Select last timestep 168 | else 169 | head:add(nn.Linear(bodyOutputSize, self.hiddenSize)) 170 | head:add(nn.ReLU(true)) -- DRQN paper reports worse performance with ReLU after LSTM 171 | end 172 | head:add(nn.Linear(self.hiddenSize, self.m)) -- Note: Tuned DDQN uses shared bias at last layer 173 | end 174 | 175 | if self.bootstraps > 0 then 176 | -- Add bootstrap heads 177 | local headConcat = nn.ConcatTable() 178 | for h = 1, heads do 179 | -- Clone head structure 180 | local bootHead = head:clone() 181 | -- Each head should use a different random initialisation to construct bootstrap (currently Torch default) 182 | local linearLayers = bootHead:findModules('nn.Linear') 183 | for l = 1, #linearLayers do 184 | linearLayers[l]:init('weight', nninit.kaiming, {dist = 'uniform', gain = 1/math.sqrt(3)}):init('bias', nninit.kaiming, {dist = 'uniform', gain = 1/math.sqrt(3)}) 185 | end 186 | headConcat:add(bootHead) 187 | end 188 | net:add(nn.GradientRescale(1/self.bootstraps)) -- Normalise gradients by number of heads 189 | net:add(headConcat) 190 | elseif self.a3c then 191 | -- Actor-critic does not use the normal head but instead a concatenated value function V and policy π 192 | net:add(nn.Linear(bodyOutputSize, self.hiddenSize)) 193 | net:add(nn.ReLU(true)) 194 | 195 | local valueAndPolicy = nn.ConcatTable() -- π and V share all layers except the last 196 | 197 | -- Value function V(s; θv) 198 | local valueFunction = nn.Linear(self.hiddenSize, 1) 199 | 200 | -- Policy π(a | s; θπ) 201 | local policy = nn.Sequential() 202 | policy:add(nn.Linear(self.hiddenSize, self.m)) 203 | policy:add(nn.SoftMax()) 204 | 205 | valueAndPolicy:add(valueFunction) 206 | valueAndPolicy:add(policy) 207 | 208 | net:add(valueAndPolicy) 209 | else 210 | -- Add head via ConcatTable (simplifies bootstrap code in agent) 211 | local headConcat = nn.ConcatTable() 212 | headConcat:add(head) 213 | net:add(headConcat) 214 | end 215 | 216 | if not self.a3c then 217 | net:add(nn.JoinTable(1, 1)) 218 | net:add(nn.View(heads, self.m)) 219 | end 220 | -- GPU conversion 221 | if self.gpu > 0 then 222 | require 'cunn' 223 | net:cuda() 224 | 225 | if self.cudnn and hasCudnn then 226 | cudnn.convert(net, cudnn) 227 | -- The following is legacy code that can make cuDNN deterministic (with a large drop in performance) 228 | --[[ 229 | local convs = net:findModules('cudnn.SpatialConvolution') 230 | for i, v in ipairs(convs) do 231 | v:setMode('CUDNN_CONVOLUTION_FWD_ALGO_GEMM', 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_1', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1') 232 | end 233 | 
--]] 234 | end 235 | end 236 | 237 | -- Save reference to network 238 | self.net = net 239 | 240 | return net 241 | end 242 | 243 | function Model:setNetwork(net) 244 | self.net = net 245 | end 246 | 247 | -- Return list of convolutional filters as list of images 248 | function Model:getFilters() 249 | local filters = {} 250 | 251 | -- Find convolutional layers 252 | local convs = self.net:findModules(self.cudnn and hasCudnn and 'cudnn.SpatialConvolution' or 'nn.SpatialConvolution') 253 | for i, v in ipairs(convs) do 254 | -- Add filter to list (with each layer on a separate row) 255 | filters[#filters + 1] = image.toDisplayTensor(v.weight:view(v.nOutputPlane*v.nInputPlane, v.kH, v.kW), 1, v.nInputPlane, true) 256 | end 257 | 258 | return filters 259 | end 260 | 261 | -- Set ReLUs up for specified saliency visualisation type 262 | function Model:setSaliency(saliency) 263 | -- Set saliency 264 | self.saliency = saliency 265 | 266 | -- Find ReLUs on existing model 267 | local relus, relucontainers = self.net:findModules(hasCudnn and 'cudnn.ReLU' or 'nn.ReLU') 268 | if #relus == 0 then 269 | relus, relucontainers = self.net:findModules('nn.GuidedReLU') 270 | end 271 | if #relus == 0 then 272 | relus, relucontainers = self.net:findModules('nn.DeconvnetReLU') 273 | end 274 | 275 | -- Work out which ReLU to use now 276 | local layerConstructor = hasCudnn and cudnn.ReLU or nn.ReLU 277 | self.relus = {} --- Clear special ReLU list to iterate over for salient backpropagation 278 | if saliency == 'guided' then 279 | layerConstructor = nn.GuidedReLU 280 | elseif saliency == 'deconvnet' then 281 | layerConstructor = nn.DeconvnetReLU 282 | end 283 | 284 | -- Replace ReLUs 285 | for i = 1, #relus do 286 | -- Create new special ReLU 287 | local layer = layerConstructor() 288 | 289 | -- Copy everything over 290 | for key, val in pairs(relus[i]) do 291 | layer[key] = val 292 | end 293 | 294 | -- Find ReLU in containing module and replace 295 | for j = 1, #(relucontainers[i].modules) do 296 | if relucontainers[i].modules[j] == relus[i] then 297 | relucontainers[i].modules[j] = layer 298 | end 299 | end 300 | end 301 | 302 | -- Create special ReLU list to iterate over for salient backpropagation 303 | self.relus = self.net:findModules(saliency == 'guided' and 'nn.GuidedReLU' or 'nn.DeconvnetReLU') 304 | end 305 | 306 | -- Switches the backward computation of special ReLUs for salient backpropagation 307 | function Model:salientBackprop() 308 | for i, v in ipairs(self.relus) do 309 | v:salientBackprop() 310 | end 311 | end 312 | 313 | -- Switches the backward computation of special ReLUs for normal backpropagation 314 | function Model:normalBackprop() 315 | for i, v in ipairs(self.relus) do 316 | v:normalBackprop() 317 | end 318 | end 319 | 320 | return Model 321 | -------------------------------------------------------------------------------- /async/ValidationAgent.lua: -------------------------------------------------------------------------------- 1 | local _ = require 'moses' 2 | local AsyncModel = require 'async/AsyncModel' 3 | local Evaluator = require 'Evaluator' 4 | local Experience = require 'Experience' 5 | local CircularQueue = require 'structures/CircularQueue' 6 | local classic = require 'classic' 7 | local gnuplot = require 'gnuplot' 8 | require 'classic.torch' 9 | 10 | local ValidationAgent = classic.class('ValidationAgent') 11 | 12 | function ValidationAgent:_init(opt, theta, atomic) 13 | log.info('creating ValidationAgent') 14 | local asyncModel = AsyncModel(opt) 15 | self.env, self.model = 
asyncModel:getEnvAndModel() 16 | self.policyNet_ = asyncModel:createNet() 17 | 18 | self.lstm = opt.recurrent and self.policyNet_:findModules('nn.FastLSTM')[1] 19 | 20 | self.theta_ = self.policyNet_:getParameters() 21 | self.theta = theta 22 | 23 | self.atomic = atomic 24 | self._id = opt._id 25 | 26 | -- Validation variables 27 | self.valSize = opt.valSize 28 | self.losses = {} 29 | self.avgV = {} -- Running average of V(s') 30 | self.avgTdErr = {} -- Running average of TD-error δ 31 | self.valScores = {} -- Validation scores (passed from main script) 32 | self.normScores = {} -- Normalised validation scores (passed from main script) 33 | 34 | self.m = opt.actionSpec[3][2] - opt.actionSpec[3][1] + 1 -- Number of discrete actions 35 | self.actionOffset = 1 - opt.actionSpec[3][1] -- Calculate offset if first action is not indexed as 1 36 | 37 | self.env:training() 38 | 39 | self.stateBuffer = CircularQueue(opt.recurrent and 1 or opt.histLen, opt.Tensor, opt.stateSpec[2]) 40 | self.progFreq = opt.progFreq 41 | self.Tensor = opt.Tensor 42 | 43 | self.reportWeights = opt.reportWeights 44 | self.valSteps = opt.valSteps 45 | self.evaluator = Evaluator(opt.game) 46 | 47 | opt.batchSize = opt.valSize -- override in this thread ONLY 48 | self.valMemory = Experience(opt.valSize + 3, opt, true) 49 | 50 | self.bestValScore = -math.huge 51 | 52 | self.selectAction = self.eGreedyAction 53 | self.a3c = opt.async == 'A3C' 54 | if self.a3c then self.selectAction = self.probabilisticAction end 55 | 56 | classic.strict(self) 57 | end 58 | 59 | 60 | function ValidationAgent:start() 61 | log.info('ValidationAgent | filling ValMemory ') 62 | local reward, rawObservation, terminal = 0, self.env:start(), false 63 | local action = 1 64 | for i=1,self.valSize+1 do 65 | local observation = self.model:preprocess(rawObservation) 66 | self.valMemory:store(reward, observation, terminal, action) 67 | if not terminal then 68 | action = torch.random(1,self.m) 69 | reward, rawObservation, terminal = self.env:step(action - self.actionOffset) 70 | else 71 | reward, rawObservation, terminal = 0, self.env:start(), false 72 | end 73 | end 74 | log.info('ValidationAgent | ValMemory filled') 75 | end 76 | 77 | 78 | function ValidationAgent:eGreedyAction(state) 79 | local epsilon = 0.001 -- Taken from tuned DDQN evaluation 80 | 81 | local Q = self.policyNet_:forward(state):squeeze() 82 | 83 | if torch.uniform() < epsilon then 84 | return torch.random(1,self.m) 85 | end 86 | 87 | local _, maxIdx = Q:max(1) 88 | return maxIdx[1] 89 | end 90 | 91 | 92 | function ValidationAgent:probabilisticAction(state) 93 | local __, probability = table.unpack(self.policyNet_:forward(state)) 94 | return torch.multinomial(probability, 1):squeeze() 95 | end 96 | 97 | 98 | function ValidationAgent:validate() 99 | self.theta_:copy(self.theta) 100 | if self.lstm then self.lstm:forget() end 101 | 102 | self.stateBuffer:clear() 103 | self.env:evaluate() 104 | self.policyNet_:evaluate() 105 | 106 | local valStepStrFormat = '%0' .. (math.floor(math.log10(self.valSteps)) + 1) .. 
'd' 107 | local valEpisode = 1 108 | local valEpisodeScore = 0 109 | local valTotalScore = 0 110 | 111 | local reward, observation, terminal = 0, self.env:start(), false 112 | 113 | for valStep = 1, self.valSteps do 114 | observation = self.model:preprocess(observation) 115 | if terminal then 116 | self.stateBuffer:clear() 117 | else 118 | self.stateBuffer:push(observation) 119 | end 120 | if not terminal then 121 | local state = self.stateBuffer:readAll() 122 | 123 | local action = self:selectAction(state) 124 | reward, observation, terminal = self.env:step(action - self.actionOffset) 125 | valEpisodeScore = valEpisodeScore + reward 126 | else 127 | if self.lstm then self.lstm:forget() end 128 | 129 | -- Print score every 10 episodes 130 | if valEpisode % 10 == 0 then 131 | local avgScore = valTotalScore/math.max(valEpisode - 1, 1) 132 | log.info('[VAL] Steps: ' .. string.format(valStepStrFormat, valStep) .. '/' .. self.valSteps .. ' | Episode ' .. valEpisode 133 | .. ' | Score: ' .. valEpisodeScore .. ' | TotScore: ' .. valTotalScore .. ' | AvgScore: %.2f', avgScore) 134 | end 135 | 136 | -- Start a new episode 137 | valEpisode = valEpisode + 1 138 | reward, observation, terminal = 0, self.env:start(), false 139 | valTotalScore = valTotalScore + valEpisodeScore -- Only add to total score at end of episode 140 | valEpisodeScore = reward -- Reset episode score 141 | end 142 | end 143 | 144 | -- If no episodes completed then use score from incomplete episode 145 | if valEpisode == 1 then 146 | valTotalScore = valEpisodeScore 147 | end 148 | 149 | log.info('Validated @ '.. self.atomic:get()) 150 | log.info('Total Score: ' .. valTotalScore) 151 | local valAvgScore = valTotalScore/math.max(valEpisode - 1, 1) -- Only average score for completed episodes in general 152 | log.info('Average Score: ' .. valAvgScore) 153 | self.valScores[#self.valScores + 1] = valAvgScore 154 | local normScore = self.evaluator:normaliseScore(valAvgScore) 155 | if normScore then 156 | log.info('Normalised Score: ' .. normScore) 157 | self.normScores[#self.normScores + 1] = normScore 158 | end 159 | 160 | self:visualiseFilters() 161 | 162 | local avgV = self:validationStats() 163 | log.info('Average V: ' .. avgV) 164 | 165 | self:saveWeights('last') 166 | if valAvgScore > self.bestValScore then 167 | log.info('New best average score') 168 | self.bestValScore = valAvgScore 169 | self:saveWeights('best') 170 | end 171 | 172 | if self.reportWeights then 173 | local reports = self:weightsReport() 174 | for r = 1, #reports do 175 | log.info(reports[r]) 176 | end 177 | end 178 | end 179 | 180 | function ValidationAgent:saveWeights(name) 181 | log.info('Saving weights') 182 | torch.save(paths.concat('experiments', self._id, name..'.weights.t7'), self.theta) 183 | end 184 | 185 | -- Saves network convolutional filters as images 186 | function ValidationAgent:visualiseFilters() 187 | local filters = self.model:getFilters() 188 | 189 | for i, v in ipairs(filters) do 190 | image.save(paths.concat('experiments', self._id, 'conv_layer_' .. i .. '.png'), v) 191 | end 192 | end 193 | 194 | local pprintArr = function(memo, v) 195 | return memo .. ', ' .. 
v 196 | end 197 | 198 | -- Reports absolute network weights and gradients 199 | function ValidationAgent:weightsReport() 200 | -- Collect layer with weights 201 | local weightLayers = self.policyNet_:findModules('nn.SpatialConvolution') 202 | if #weightLayers == 0 then 203 | -- Assume cuDNN convolutions 204 | weightLayers = self.policyNet_:findModules('cudnn.SpatialConvolution') 205 | end 206 | local fcLayers = self.policyNet_:findModules('nn.Linear') 207 | weightLayers = _.append(weightLayers, fcLayers) 208 | 209 | -- Array of norms and maxima 210 | local wNorms = {} 211 | local wMaxima = {} 212 | local wGradNorms = {} 213 | local wGradMaxima = {} 214 | 215 | -- Collect statistics 216 | for l = 1, #weightLayers do 217 | local w = weightLayers[l].weight:clone():abs() -- Weights (absolute) 218 | wNorms[#wNorms + 1] = torch.mean(w) -- Weight norms: 219 | wMaxima[#wMaxima + 1] = torch.max(w) -- Weight max 220 | w = weightLayers[l].gradWeight:clone():abs() -- Weight gradients (absolute) 221 | wGradNorms[#wGradNorms + 1] = torch.mean(w) -- Weight grad norms: 222 | wGradMaxima[#wGradMaxima + 1] = torch.max(w) -- Weight grad max 223 | end 224 | 225 | -- Create report string table 226 | local reports = { 227 | 'Weight norms: ' .. _.reduce(wNorms, pprintArr), 228 | 'Weight max: ' .. _.reduce(wMaxima, pprintArr), 229 | 'Weight gradient norms: ' .. _.reduce(wGradNorms, pprintArr), 230 | 'Weight gradient max: ' .. _.reduce(wGradMaxima, pprintArr) 231 | } 232 | 233 | return reports 234 | end 235 | 236 | 237 | function ValidationAgent:validationStats() 238 | local indices = torch.linspace(2, self.valSize+1, self.valSize):long() 239 | local states, actions, rewards, transitions, terminals = self.valMemory:retrieve(indices) 240 | 241 | local totalV 242 | if self.a3c then 243 | local Vs = self.policyNet_:forward(transitions)[1] 244 | totalV = Vs:sum() 245 | else 246 | local QPrimes = self.policyNet_:forward(transitions) -- In real learning this would use targetNet, but it doesn't matter for validation 247 | local VPrime = torch.max(QPrimes, 3) 248 | totalV = VPrime:sum() 249 | end 250 | local avgV = totalV / self.valSize 251 | self.avgV[#self.avgV + 1] = avgV 252 | self:plotValidation() 253 | return avgV 254 | end 255 | 256 | 257 | function ValidationAgent:plotValidation() 258 | -- Plot and save losses 259 | if #self.losses > 0 then 260 | local losses = torch.Tensor(self.losses) 261 | gnuplot.pngfigure(paths.concat('experiments', self._id, 'losses.png')) 262 | gnuplot.plot('Loss', torch.linspace(math.floor(self.learnStart/self.progFreq), math.floor(self.globals.step/self.progFreq), #self.losses), losses, '-') 263 | gnuplot.xlabel('Step (x' .. self.progFreq ..
')') 264 | gnuplot.ylabel('Loss') 265 | gnuplot.plotflush() 266 | torch.save(paths.concat('experiments', self._id, 'losses.t7'), losses) 267 | end 268 | -- Plot and save V 269 | local epochIndices = torch.linspace(1, #self.avgV, #self.avgV) 270 | local Vs = torch.Tensor(self.avgV) 271 | gnuplot.pngfigure(paths.concat('experiments', self._id, 'Vs.png')) 272 | gnuplot.plot('V', epochIndices, Vs, '-') 273 | gnuplot.xlabel('Epoch') 274 | gnuplot.ylabel('V') 275 | gnuplot.movelegend('left', 'top') 276 | gnuplot.plotflush() 277 | torch.save(paths.concat('experiments', self._id, 'V.t7'), Vs) 278 | -- Plot and save TD-error δ 279 | if #self.avgTdErr>0 then 280 | local TDErrors = torch.Tensor(self.avgTdErr) 281 | gnuplot.pngfigure(paths.concat('experiments', self._id, 'TDErrors.png')) 282 | gnuplot.plot('TD-Error', epochIndices, TDErrors, '-') 283 | gnuplot.xlabel('Epoch') 284 | gnuplot.ylabel('TD-Error') 285 | gnuplot.plotflush() 286 | torch.save(paths.concat('experiments', self._id, 'TDErrors.t7'), TDErrors) 287 | end 288 | -- Plot and save average score 289 | local scores = torch.Tensor(self.valScores) 290 | gnuplot.pngfigure(paths.concat('experiments', self._id, 'scores.png')) 291 | gnuplot.plot('Score', epochIndices, scores, '-') 292 | gnuplot.xlabel('Epoch') 293 | gnuplot.ylabel('Average Score') 294 | gnuplot.movelegend('left', 'top') 295 | gnuplot.plotflush() 296 | torch.save(paths.concat('experiments', self._id, 'scores.t7'), scores) 297 | -- Plot and save normalised score 298 | if #self.normScores > 0 then 299 | local normScores = torch.Tensor(self.normScores) 300 | gnuplot.pngfigure(paths.concat('experiments', self._id, 'normScores.png')) 301 | gnuplot.plot('Score', epochIndices, normScores, '-') 302 | gnuplot.xlabel('Epoch') 303 | gnuplot.ylabel('Normalised Score') 304 | gnuplot.movelegend('left', 'top') 305 | gnuplot.plotflush() 306 | torch.save(paths.concat('experiments', self._id, 'normScores.t7'), normScores) 307 | end 308 | gnuplot.close() 309 | end 310 | 311 | 312 | function ValidationAgent:evaluate(display) 313 | self.theta_:copy(self.theta) 314 | 315 | log.info('Evaluation mode') 316 | -- Set environment and agent to evaluation mode 317 | self.env:evaluate() 318 | 319 | local reward, observation, terminal = 0, self.env:start(), false 320 | 321 | -- Report episode score 322 | local episodeScore = reward 323 | 324 | -- Play one game (episode) 325 | local step = 1 326 | while not terminal do 327 | observation = self.model:preprocess(observation) 328 | if terminal then 329 | self.stateBuffer:pushReset(observation) 330 | else 331 | self.stateBuffer:push(observation) 332 | end 333 | -- Observe and choose next action (index) 334 | local state = self.stateBuffer:readAll() 335 | local action = self:selectAction(state) 336 | 337 | -- Act on environment 338 | if not terminal then 339 | reward, observation, terminal = self.env:step(action - self.actionOffset) 340 | else 341 | reward, observation, terminal = 0, self.env:start(), false 342 | end 343 | episodeScore = episodeScore + reward 344 | 345 | if display then 346 | display:display(self, self.env:getDisplay(), step) 347 | end 348 | -- Increment evaluation step counter 349 | step = step + 1 350 | end 351 | log.info('Final Score: ' .. 
episodeScore) 352 | 353 | if display then 354 | display:createVideo() 355 | end 356 | end 357 | 358 | 359 | return ValidationAgent 360 | -------------------------------------------------------------------------------- /Experience.lua: -------------------------------------------------------------------------------- 1 | local _ = require 'moses' 2 | local classic = require 'classic' 3 | local BinaryHeap = require 'structures/BinaryHeap' 4 | local Singleton = require 'structures/Singleton' 5 | require 'classic.torch' -- Enables serialisation 6 | 7 | local Experience = classic.class('Experience') 8 | 9 | -- Creates experience replay memory 10 | function Experience:_init(capacity, opt, isValidation) 11 | self.capacity = capacity 12 | -- Extract relevant options 13 | self.batchSize = opt.batchSize 14 | self.histLen = opt.histLen 15 | self.gpu = opt.gpu 16 | self.discretiseMem = opt.discretiseMem 17 | self.memPriority = opt.memPriority 18 | self.learnStart = opt.learnStart 19 | self.alpha = opt.alpha 20 | self.betaZero = opt.betaZero 21 | self.heads = math.max(opt.bootstraps, 1) 22 | 23 | -- Create transition tuples buffer 24 | local bufferStateSize = torch.LongStorage(_.append({opt.batchSize, opt.histLen}, opt.stateSpec[2])) 25 | self.transTuples = { 26 | states = opt.Tensor(bufferStateSize), 27 | actions = torch.ByteTensor(opt.batchSize), 28 | rewards = opt.Tensor(opt.batchSize), 29 | transitions = opt.Tensor(bufferStateSize), 30 | terminals = torch.ByteTensor(opt.batchSize), 31 | priorities = opt.Tensor(opt.batchSize) 32 | } 33 | self.indices = torch.LongTensor(opt.batchSize) 34 | self.w = opt.Tensor(opt.batchSize):fill(1) -- Importance-sampling weights w, 1 if no correction needed 35 | 36 | -- Allocate memory for experience 37 | local stateSize = torch.LongStorage(_.append({capacity}, opt.stateSpec[2])) -- Calculate state storage size 38 | self.imgDiscLevels = 255 -- Number of discretisation levels for images (used for float <-> byte conversion) 39 | if opt.discretiseMem then 40 | -- For the standard DQN problem, float vs. byte storage is 24GB vs. 
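-- Illustrative, standalone sketch (not part of the original file): with discretiseMem enabled,
-- float frames in [0, 1] are stored as bytes in {0, ..., 255} (1 byte per pixel instead of 4)
-- and rescaled back to floats on retrieval, at the cost of a quantisation error of at most ~1/255:
local torch = require 'torch'
local frame = torch.rand(1, 4, 4):float() -- Toy float frame with values in [0, 1]
local stored = torch.mul(frame, 255):byte() -- float -> byte for storage
local restored = stored:float():div(255) -- byte -> float on retrieval
print((restored - frame):abs():max()) -- At most ~1/255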
6GB memory, so this prevents/minimises slow swap usage 41 | self.states = torch.ByteTensor(stateSize) -- ByteTensor to avoid massive memory usage 42 | else 43 | self.states = torch.Tensor(stateSize) 44 | end 45 | self.actions = torch.ByteTensor(capacity) -- Discrete action indices 46 | self.rewards = torch.FloatTensor(capacity) -- Stored at time t (not t + 1) 47 | -- Terminal conditions stored at time t+1, encoded by 0 = false, 1 = true 48 | self.terminals = torch.ByteTensor(capacity):fill(1) -- Filling with 1 prevents going back in history at beginning 49 | -- Validation flags (used if state is stored without transition) 50 | self.invalid = torch.ByteTensor(capacity) -- 1 is used to denote invalid 51 | -- Internal pointer 52 | self.masks = torch.ByteTensor(capacity, self.heads):fill(0) 53 | -- Masking flags for Bootstrap heads 54 | self.allIndexes = torch.LongTensor():range(1,capacity) 55 | -- Used during finding unmasked samples for bootstrap head 56 | self.index = 1 57 | self.isFull = false 58 | self.size = 0 59 | 60 | -- TD-error δ-based priorities 61 | self.priorityQueue = BinaryHeap(capacity) -- Stored at time t 62 | self.smallConst = 1e-12 63 | -- Sampling priority 64 | if not isValidation and opt.memPriority == 'rank' then 65 | -- Cache partition indices for several values of N as α is static 66 | self.distributions = {} 67 | local nPartitions = 100 -- learnStart must be at least 1/100 of capacity (arbitrary constant) 68 | local partitionNum = 1 69 | local partitionDivision = math.floor(capacity/nPartitions) 70 | 71 | for n = partitionDivision, capacity, partitionDivision do 72 | if n >= opt.learnStart or n == capacity then -- Do not calculate distributions for before learnStart occurs 73 | -- Set up power-law PDF and CDF 74 | local distribution = {} 75 | distribution.pdf = torch.linspace(1, n, n):pow(-opt.alpha) 76 | local pdfSum = torch.sum(distribution.pdf) 77 | distribution.pdf:div(pdfSum) -- Normalise PDF 78 | local cdf = torch.cumsum(distribution.pdf) 79 | 80 | -- Set up strata for stratified sampling (transitions will have varying TD-error magnitudes |δ|) 81 | distribution.strataEnds = torch.LongTensor(opt.batchSize + 1) 82 | distribution.strataEnds[1] = 0 -- First index is 0 (+1) 83 | distribution.strataEnds[opt.batchSize + 1] = n -- Last index is n 84 | -- Use linear search to find strata indices 85 | local stratumEnd = 1/opt.batchSize 86 | local index = 1 87 | for s = 2, opt.batchSize do 88 | while cdf[index] < stratumEnd do 89 | index = index + 1 90 | end 91 | distribution.strataEnds[s] = index -- Save index 92 | stratumEnd = stratumEnd + 1/opt.batchSize -- Set condition for next stratum 93 | end 94 | 95 | -- Check that enough transitions are available (to prevent an infinite loop of infinite tuples) 96 | if distribution.strataEnds[2] - distribution.strataEnds[1] <= opt.histLen then 97 | log.error('Experience replay strata are too small - use a smaller alpha/larger memSize/greater learnStart') 98 | error('Experience replay strata are too small - use a smaller alpha/larger memSize/greater learnStart') 99 | end 100 | 101 | -- Store distribution 102 | self.distributions[partitionNum] = distribution 103 | end 104 | 105 | partitionNum = partitionNum + 1 106 | end 107 | end 108 | 109 | -- Initialise first time step (s0) 110 | self.states[1]:zero() -- Blank out state 111 | self.terminals[1] = 0 112 | self.actions[1] = 1 -- Action is no-op 113 | self.invalid[1] = 0 -- First step is a fake blanked-out state, but can thereby be utilised 114 | self.masks[1]:zero() -- Mask out for 
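-- Illustrative, standalone sketch (not part of the original file): rank-based prioritised
-- replay (Schaul et al., 2015), as set up above, samples ranks with probability
-- P(i) ∝ 1/rank(i)^α, splits the CDF into batchSize equal-probability strata so each minibatch
-- element comes from a different priority band, and corrects the sampling bias with
-- importance-sampling weights w_i = (N·P(i))^-β normalised by their maximum.
-- The values below (N = 10, α = 0.65, β = 0.45, batchSize = 2) are arbitrary toy choices:
local torch = require 'torch'
local N, alpha, beta, batchSize = 10, 0.65, 0.45, 2
local pdf = torch.linspace(1, N, N):pow(-alpha)
pdf:div(pdf:sum()) -- Normalised power-law PDF over ranks
local cdf = torch.cumsum(pdf)
local stratumEnd = 1
while cdf[stratumEnd] < 1/batchSize do stratumEnd = stratumEnd + 1 end -- First stratum boundary
local w = pdf:clone():mul(N):pow(-beta) -- Importance-sampling weights w_i = (N·P(i))^-β
w:div(w:max()) -- Normalise by max weight so updates only ever scale downwards
print('First stratum covers ranks 1..' .. stratumEnd)
print(w)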
all 115 | if self.memPriority then 116 | self.priorityQueue:insert(1, 1) -- First priority = 1 117 | end 118 | 119 | -- Calculate β growth factor (linearly annealed till end of training) 120 | self.betaGrad = (1 - opt.betaZero)/(opt.steps - opt.learnStart) 121 | 122 | -- Get singleton instance for step 123 | self.globals = Singleton.getInstance() 124 | end 125 | 126 | -- Calculates circular indices 127 | function Experience:circIndex(x) 128 | local ind = x % self.capacity 129 | return ind == 0 and self.capacity or ind -- Correct 0-index 130 | end 131 | 132 | -- Stores experience tuple parts (including pre-emptive action) 133 | function Experience:store(reward, state, terminal, action, mask) 134 | self.rewards[self.index] = reward 135 | 136 | -- Increment index and size 137 | self.index = self.index + 1 138 | self.size = math.min(self.size + 1, self.capacity) 139 | -- Circle back to beginning if memory limit reached 140 | if self.index > self.capacity then 141 | self.isFull = true -- Full memory flag 142 | self.index = 1 -- Reset index 143 | end 144 | 145 | if self.discretiseMem then 146 | self.states[self.index] = torch.mul(state, self.imgDiscLevels) -- float -> byte 147 | else 148 | self.states[self.index] = state:clone() 149 | end 150 | self.terminals[self.index] = terminal and 1 or 0 151 | self.actions[self.index] = action 152 | self.invalid[self.index] = 0 153 | self.masks[self.index] = mask:clone() 154 | 155 | -- Store with maximal priority 156 | if self.memPriority then 157 | -- TODO: Correct PER by not storing terminal states at all 158 | local maxPriority = terminal and 0 or self.priorityQueue:findMax() -- Terminal states cannot be sampled so assign priority 0 159 | if self.isFull then 160 | self.priorityQueue:updateByVal(self.index, maxPriority, self.index) 161 | else 162 | self.priorityQueue:insert(maxPriority, self.index) 163 | end 164 | end 165 | end 166 | 167 | -- Sets current state as invalid (utilised when switching to evaluation mode) 168 | function Experience:setInvalid() 169 | self.invalid[self.index] = 1 170 | end 171 | 172 | -- Retrieves experience tuples (s, a, r, s', t) 173 | function Experience:retrieve(indices) 174 | local N = indices:size(1) 175 | -- Blank out history in one go 176 | self.transTuples.states:zero() 177 | self.transTuples.transitions:zero() 178 | 179 | -- Iterate over indices 180 | for n = 1, N do 181 | local memIndex = indices[n] 182 | -- Retrieve action 183 | self.transTuples.actions[n] = self.actions[memIndex] 184 | -- Retrieve rewards 185 | self.transTuples.rewards[n] = self.rewards[memIndex] 186 | -- Retrieve terminal status (of transition) 187 | self.transTuples.terminals[n] = self.terminals[self:circIndex(memIndex + 1)] 188 | 189 | -- Go back in history whilst episode exists 190 | local histIndex = self.histLen 191 | repeat 192 | if self.discretiseMem then 193 | -- Copy state (converting to float first for non-integer division) 194 | self.transTuples.states[n][histIndex]:div(self.states[memIndex]:typeAs(self.transTuples.states), self.imgDiscLevels) -- byte -> float 195 | else 196 | self.transTuples.states[n][histIndex] = self.states[memIndex]:typeAs(self.transTuples.states) 197 | end 198 | -- Adjust indices 199 | memIndex = self:circIndex(memIndex - 1) 200 | histIndex = histIndex - 1 201 | until histIndex == 0 or self.terminals[memIndex] == 1 or self.invalid[memIndex] == 1 202 | 203 | -- If transition not terminal, fill in transition history (invalid states should not be selected in the first place) 204 | if self.transTuples.terminals[n] == 0 
then 205 | -- Copy most recent state 206 | for h = 2, self.histLen do 207 | self.transTuples.transitions[n][h - 1] = self.transTuples.states[n][h] 208 | end 209 | -- Get transition frame 210 | local memTIndex = self:circIndex(indices[n] + 1) 211 | if self.discretiseMem then 212 | self.transTuples.transitions[n][self.histLen]:div(self.states[memTIndex]:typeAs(self.transTuples.transitions), self.imgDiscLevels) -- byte -> float 213 | else 214 | self.transTuples.transitions[n][self.histLen] = self.states[memTIndex]:typeAs(self.transTuples.transitions) 215 | end 216 | end 217 | end 218 | 219 | return self.transTuples.states[{{1, N}}], self.transTuples.actions[{{1, N}}], self.transTuples.rewards[{{1, N}}], self.transTuples.transitions[{{1, N}}], self.transTuples.terminals[{{1, N}}] 220 | end 221 | 222 | -- Determines if an index points to a valid transition state 223 | function Experience:validateTransition(index) 224 | -- Calculate beginning of state and end of transition for checking overlap with head of buffer 225 | local minIndex, maxIndex = index - self.histLen, self:circIndex(index + 1) 226 | -- State must not be terminal, invalid, or overlap with head of buffer 227 | return self.terminals[index] == 0 and self.invalid[index] == 0 and (self.index <= minIndex or self.index >= maxIndex) 228 | end 229 | 230 | -- Determines if an index points to a masked state 231 | function Experience:isUnmasked(index, head) 232 | return self.masks[index][head] == 1 233 | end 234 | 235 | -- Returns indices and importance-sampling weights based on (stochastic) proportional prioritised sampling 236 | function Experience:sample(head) 237 | local N = self.size 238 | local unmaskedIndexes = self.allIndexes[self.masks[{{},head}]] 239 | local M = unmaskedIndexes:size(1) 240 | 241 | -- Priority 'none' = uniform sampling 242 | if not self.memPriority then 243 | 244 | -- Keep uniformly picking random indices until indices filled 245 | for n = 1, self.batchSize do 246 | local index 247 | local isValid = false 248 | 249 | -- Generate random index until valid transition found 250 | while not isValid do 251 | index = torch.random(1, M) 252 | isValid = self:validateTransition(unmaskedIndexes[index]) 253 | end 254 | 255 | -- Store index 256 | self.indices[n] = index 257 | end 258 | 259 | elseif self.memPriority == 'rank' then 260 | 261 | -- Find closest precomputed distribution by size 262 | local distIndex = math.floor(N / self.capacity * 100) 263 | local distribution = self.distributions[distIndex] 264 | N = distIndex * 100 265 | 266 | -- Create table to store indices (by rank) 267 | local rankIndices = torch.LongTensor(self.batchSize) -- In reality the underlying array-based binary heap is used as an approximation of a ranked (sorted) array 268 | -- Perform stratified sampling 269 | for n = 1, self.batchSize do 270 | local index 271 | local isValid = false 272 | 273 | -- Generate random index until valid transition found 274 | while not isValid do 275 | -- Sample within stratum 276 | rankIndices[n] = torch.random(distribution.strataEnds[n] + 1, distribution.strataEnds[n+1]) 277 | -- Retrieve actual transition index 278 | index = self.priorityQueue:getValueByVal(rankIndices[n]) 279 | isValid = self:validateTransition(index) -- The last stratum might be full of terminal states, leading to many checks 280 | end 281 | 282 | -- Store actual transition index 283 | self.indices[n] = index 284 | end 285 | 286 | -- Compute importance-sampling weights w = (N * p(rank))^-β 287 | local beta = math.min(self.betaZero + 
(self.globals.step - self.learnStart - 1)*self.betaGrad, 1) 288 | self.w = distribution.pdf:index(1, rankIndices):mul(N):pow(-beta) -- torch.index does memory copy 289 | -- Calculate max importance-sampling weight 290 | -- Note from Tom Schaul: Calculated over minibatch, not entire distribution 291 | local wMax = torch.max(self.w) 292 | -- Normalise weights so updates only scale downwards (for stability) 293 | self.w:div(wMax) 294 | 295 | elseif self.memPriority == 'proportional' then 296 | 297 | -- TODO: Proportional prioritised experience replay 298 | 299 | end 300 | 301 | return self.indices, self.w 302 | end 303 | 304 | -- Update experience priorities using TD-errors δ 305 | function Experience:updatePriorities(indices, delta) 306 | if self.memPriority then 307 | local priorities = torch.abs(delta):float() -- Use absolute values 308 | if self.memPriority == 'proportional' then 309 | priorities:add(self.smallConstant) -- Allows transitions to be sampled even if error is 0 310 | end 311 | 312 | for p = 1, indices:size(1) do 313 | self.priorityQueue:updateByVal(indices[p], priorities[p], indices[p]) 314 | end 315 | end 316 | end 317 | 318 | -- Rebalance prioritised experience replay heap 319 | function Experience:rebalance() 320 | self.priorityQueue:rebalance() 321 | end 322 | 323 | return Experience 324 | -------------------------------------------------------------------------------- /Setup.lua: -------------------------------------------------------------------------------- 1 | require 'logroll' 2 | local _ = require 'moses' 3 | local classic = require 'classic' 4 | local cjson = require 'cjson' 5 | 6 | local Setup = classic.class('Setup') 7 | 8 | -- Performs global setup 9 | function Setup:_init(arg) 10 | -- Create log10 for Lua 5.2 11 | if not math.log10 then 12 | math.log10 = function(x) 13 | return math.log(x, 10) 14 | end 15 | end 16 | 17 | -- Parse command-line options 18 | self.opt = self:parseOptions(arg) 19 | 20 | -- Create experiment directory 21 | if not paths.dirp(self.opt.experiments) then 22 | paths.mkdir(self.opt.experiments) 23 | end 24 | paths.mkdir(paths.concat(self.opt.experiments, self.opt._id)) 25 | -- Save options for reference 26 | local file = torch.DiskFile(paths.concat(self.opt.experiments, self.opt._id, 'opts.json'), 'w') 27 | file:writeString(cjson.encode(self.opt)) 28 | file:close() 29 | 30 | -- Set up logging 31 | local flog = logroll.file_logger(paths.concat(self.opt.experiments, self.opt._id, 'log.txt')) 32 | local plog = logroll.print_logger() 33 | log = logroll.combine(flog, plog) -- Global logger 34 | 35 | -- Validate command-line options (logging errors) 36 | self:validateOptions() 37 | 38 | -- Augment environments to meet spec 39 | self:augmentEnv() 40 | 41 | -- Torch setup 42 | log.info('Setting up Torch7') 43 | -- Set number of BLAS threads 44 | torch.setnumthreads(self.opt.threads) 45 | -- Set default Tensor type (float is more efficient than double) 46 | torch.setdefaulttensortype(self.opt.tensorType) 47 | -- Set manual seed 48 | torch.manualSeed(self.opt.seed) 49 | 50 | -- Tensor creation function for removing need to cast to CUDA if GPU is enabled 51 | -- TODO: Replace with local functions across codebase 52 | self.opt.Tensor = function(...) 53 | return torch.Tensor(...) 
54 | end 55 | 56 | -- GPU setup 57 | if self.opt.gpu > 0 then 58 | log.info('Setting up GPU') 59 | cutorch.setDevice(self.opt.gpu) 60 | -- Set manual seeds using random numbers to reduce correlations 61 | cutorch.manualSeed(torch.random()) 62 | -- Replace tensor creation function 63 | self.opt.Tensor = function(...) 64 | return torch.CudaTensor(...) 65 | end 66 | end 67 | 68 | classic.strict(self) 69 | end 70 | 71 | -- Parses command-line options 72 | function Setup:parseOptions(arg) 73 | -- Detect and use GPU 1 by default 74 | local cuda = pcall(require, 'cutorch') 75 | 76 | local cmd = torch.CmdLine() 77 | -- Base Torch7 options 78 | cmd:option('-seed', 1, 'Random seed') 79 | cmd:option('-threads', 4, 'Number of BLAS or async threads') 80 | cmd:option('-tensorType', 'torch.FloatTensor', 'Default tensor type') 81 | cmd:option('-gpu', cuda and 1 or 0, 'GPU device ID (0 to disable)') 82 | cmd:option('-cudnn', 'false', 'Utilise cuDNN (if available)') 83 | -- Environment options 84 | cmd:option('-env', 'rlenvs.Catch', 'Environment class (Lua file to be loaded/rlenv)') 85 | cmd:option('-zoom', 1, 'Display zoom (requires QT)') 86 | cmd:option('-game', '', 'Name of Atari ROM (stored in "roms" directory)') 87 | -- Training vs. evaluate mode 88 | cmd:option('-mode', 'train', 'Train vs. test mode: train|eval') 89 | -- State preprocessing options (for visual states) 90 | cmd:option('-height', 0, 'Resized screen height (0 to disable)') 91 | cmd:option('-width', 0, 'Resize screen width (0 to disable)') 92 | cmd:option('-colorSpace', '', 'Colour space conversion (screen is RGB): |y|lab|yuv|hsl|hsv|nrgb') 93 | -- Model options 94 | cmd:option('-modelBody', 'models.Catch', 'Path to Torch nn model to be used as DQN "body"') 95 | cmd:option('-hiddenSize', 512, 'Number of units in the hidden fully connected layer') 96 | cmd:option('-histLen', 4, 'Number of consecutive states processed/used for backpropagation-through-time') -- DQN standard is 4, DRQN is 10 97 | cmd:option('-duel', 'true', 'Use dueling network architecture (learns advantage function)') 98 | cmd:option('-bootstraps', 10, 'Number of bootstrap heads (0 to disable)') 99 | --cmd:option('-bootstrapMask', 1, 'Independent probability of masking a transition for each bootstrap head ~ Ber(bootstrapMask) (1 to disable)') 100 | cmd:option('-recurrent', 'false', 'Use recurrent connections') 101 | -- Experience replay options 102 | cmd:option('-discretiseMem', 'true', 'Discretise states to integers ∈ [0, 255] for storage') 103 | cmd:option('-memSize', 1e6, 'Experience replay memory size (number of tuples)') 104 | cmd:option('-memSampleFreq', 4, 'Interval of steps between sampling from memory to learn') 105 | cmd:option('-memNSamples', 1, 'Number of times to sample per learning step') 106 | cmd:option('-memPriority', '', 'Type of prioritised experience replay: |rank|proportional') -- TODO: Implement proportional prioritised experience replay 107 | cmd:option('-alpha', 0.65, 'Prioritised experience replay exponent α') -- Best vals are rank = 0.7, proportional = 0.6 108 | cmd:option('-betaZero', 0.45, 'Initial value of importance-sampling exponent β') -- Best vals are rank = 0.5, proportional = 0.4 109 | -- Reinforcement learning parameters 110 | cmd:option('-gamma', 0.99, 'Discount rate γ') 111 | cmd:option('-epsilonStart', 1, 'Initial value of greediness ε') 112 | cmd:option('-epsilonEnd', 0.01, 'Final value of greediness ε') -- Tuned DDQN final greediness (1/10 that of DQN) 113 | cmd:option('-epsilonSteps', 1e6, 'Number of steps to linearly decay 
epsilonStart to epsilonEnd') -- Usually same as memory size 114 | cmd:option('-tau', 30000, 'Steps between target net updates τ') -- Tuned DDQN target net update interval (3x that of DQN) 115 | cmd:option('-rewardClip', 1, 'Clips reward magnitude at rewardClip (0 to disable)') 116 | cmd:option('-tdClip', 1, 'Clips TD-error δ magnitude at tdClip (0 to disable)') 117 | cmd:option('-doubleQ', 'true', 'Use Double Q-learning') 118 | -- Note from Georg Ostrovski: The advantage operators and Double DQN are not entirely orthogonal as the increased action gap seems to reduce the statistical bias that leads to value over-estimation in a similar way that Double DQN does 119 | cmd:option('-PALpha', 0.9, 'Persistent advantage learning parameter α (0 to disable)') 120 | -- Training options 121 | cmd:option('-optimiser', 'rmspropm', 'Training algorithm') -- RMSProp with momentum as found in "Generating Sequences With Recurrent Neural Networks" 122 | cmd:option('-eta', 0.0000625, 'Learning rate η') -- Prioritied experience replay learning rate (1/4 that of DQN; does not account for Duel as well) 123 | cmd:option('-momentum', 0.95, 'Gradient descent momentum') 124 | cmd:option('-batchSize', 32, 'Minibatch size') 125 | cmd:option('-steps', 5e7, 'Training iterations (steps)') -- Frame := step in ALE; Time step := consecutive frames treated atomically by the agent 126 | cmd:option('-learnStart', 50000, 'Number of steps after which learning starts') 127 | cmd:option('-gradClip', 10, 'Clips L2 norm of gradients at gradClip (0 to disable)') 128 | -- Evaluation options 129 | cmd:option('-progFreq', 10000, 'Interval of steps between reporting progress') 130 | cmd:option('-reportWeights', 'false', 'Report weight and weight gradient statistics') 131 | cmd:option('-noValidation', 'false', 'Disable asynchronous agent validation thread') -- TODO: Make behaviour consistent across Master/AsyncMaster 132 | cmd:option('-valFreq', 250000, 'Interval of steps between validating agent') -- valFreq steps is used as an epoch, hence #epochs = steps/valFreq 133 | cmd:option('-valSteps', 125000, 'Number of steps to use for validation') 134 | cmd:option('-valSize', 500, 'Number of transitions to use for calculating validation statistics') 135 | -- Async options 136 | cmd:option('-async', '', 'Async agent: |Sarsa|OneStepQ|NStepQ|A3C') -- TODO: Change names 137 | cmd:option('-rmsEpsilon', 0.1, 'Epsilon for sharedRmsProp') 138 | cmd:option('-entropyBeta', 0.01, 'Policy entropy regularisation β') 139 | -- ALEWrap options 140 | cmd:option('-fullActions', 'false', 'Use full set of 18 actions') 141 | cmd:option('-actRep', 4, 'Times to repeat action') -- Independent of history length 142 | cmd:option('-randomStarts', 30, 'Max number of no-op actions played before presenting the start of each training episode') 143 | cmd:option('-poolFrmsType', 'max', 'Type of pooling over previous emulator frames: max|mean') 144 | cmd:option('-poolFrmsSize', 2, 'Number of emulator frames to pool over') 145 | cmd:option('-lifeLossTerminal', 'true', 'Use life loss as terminal signal (training only)') 146 | cmd:option('-flickering', 0, 'Probability of screen flickering (Catch only)') 147 | -- Experiment options 148 | cmd:option('-experiments', 'experiments', 'Base directory to store experiments') 149 | cmd:option('-_id', '', 'ID of experiment (used to store saved results, defaults to game name)') 150 | cmd:option('-network', '', 'Saved network weights file to load (weights.t7)') 151 | cmd:option('-checkpoint', 'false', 'Checkpoint network weights (instead of 
saving just latest weights)') 152 | cmd:option('-verbose', 'false', 'Log info for every episode (only in train mode)') 153 | cmd:option('-saliency', '', 'Display saliency maps (requires QT): |normal|guided|deconvnet') 154 | cmd:option('-record', 'false', 'Record screen (only in eval mode)') 155 | local opt = cmd:parse(arg) 156 | 157 | -- Process boolean options (Torch fails to accept false on the command line) 158 | opt.cudnn = opt.cudnn == 'true' 159 | opt.duel = opt.duel == 'true' 160 | opt.recurrent = opt.recurrent == 'true' 161 | opt.discretiseMem = opt.discretiseMem == 'true' 162 | opt.doubleQ = opt.doubleQ == 'true' 163 | opt.reportWeights = opt.reportWeights == 'true' 164 | opt.fullActions = opt.fullActions == 'true' 165 | opt.lifeLossTerminal = opt.lifeLossTerminal == 'true' 166 | opt.checkpoint = opt.checkpoint == 'true' 167 | opt.verbose = opt.verbose == 'true' 168 | opt.record = opt.record == 'true' 169 | opt.noValidation = opt.noValidation == 'true' 170 | 171 | -- Process boolean/enum options 172 | if opt.colorSpace == '' then opt.colorSpace = false end 173 | if opt.memPriority == '' then opt.memPriority = false end 174 | if opt.async == '' then opt.async = false end 175 | if opt.saliency == '' then opt.saliency = false end 176 | if opt.async then opt.gpu = 0 end -- Asynchronous agents are CPU-only 177 | 178 | -- Set ID as env (plus game name) if not set 179 | if opt._id == '' then 180 | local envName = paths.basename(opt.env) 181 | if opt.game == '' then 182 | opt._id = envName 183 | else 184 | opt._id = envName .. '.' .. opt.game 185 | end 186 | end 187 | 188 | -- Create one environment to extract specifications 189 | local Env = require(opt.env) 190 | local env = Env(opt) 191 | opt.stateSpec = env:getStateSpec() 192 | opt.actionSpec = env:getActionSpec() 193 | -- Process display if available (can be used for saliency recordings even without QT) 194 | if env.getDisplay then 195 | opt.displaySpec = env:getDisplaySpec() 196 | end 197 | 198 | return opt 199 | end 200 | 201 | -- Logs and aborts on error 202 | local function abortIf(err, msg) 203 | if err then 204 | log.error(msg) 205 | error(msg) 206 | end 207 | end 208 | 209 | -- Validates setup options 210 | function Setup:validateOptions() 211 | -- Check environment state is a single tensor 212 | abortIf(#self.opt.stateSpec ~= 3 or not _.isArray(self.opt.stateSpec[2]), 'Environment state is not a single tensor') 213 | 214 | -- Check environment has discrete actions 215 | abortIf(self.opt.actionSpec[1] ~= 'int' or self.opt.actionSpec[2] ~= 1, 'Environment does not have discrete actions') 216 | 217 | -- Change state spec if resizing 218 | if self.opt.height ~= 0 then 219 | self.opt.stateSpec[2][2] = self.opt.height 220 | end 221 | if self.opt.width ~= 0 then 222 | self.opt.stateSpec[2][3] = self.opt.width 223 | end 224 | 225 | -- Check colour conversions 226 | if self.opt.colorSpace then 227 | abortIf(not _.contains({'y', 'lab', 'yuv', 'hsl', 'hsv', 'nrgb'}, self.opt.colorSpace), 'Unsupported colour space for conversion') 228 | abortIf(self.opt.stateSpec[2][1] ~= 3, 'Original colour space must be RGB for conversion') 229 | -- Change state spec if converting from colour to greyscale 230 | if self.opt.colorSpace == 'y' then 231 | self.opt.stateSpec[2][1] = 1 232 | end 233 | end 234 | 235 | -- Check start of learning occurs after at least one minibatch of data has been collected 236 | abortIf(self.opt.learnStart <= self.opt.batchSize, 'learnStart must be greater than batchSize') 237 | 238 | -- Check enough validation transitions 
will be collected before first validation 239 | abortIf(self.opt.valFreq <= self.opt.valSize, 'valFreq must be greater than valSize') 240 | 241 | -- Check prioritised experience replay options 242 | abortIf(self.opt.memPriority and not _.contains({'rank', 'proportional'}, self.opt.memPriority), 'Type of prioritised experience replay unrecognised') 243 | abortIf(self.opt.memPriority == 'proportional', 'Proportional prioritised experience replay not implemented yet') -- TODO: Implement 244 | 245 | -- Check no prioritized replay is done when bootstrap 246 | abortIf(self.opt.bootstraps > 0 and _.contains({'rank', 'proportional'}, self.opt.memPriority), 'Prioritized experience replay not possible with bootstrap') 247 | 248 | -- Check start of learning occurs after at least 1/100 of memory has been filled 249 | abortIf(self.opt.learnStart <= self.opt.memSize/100, 'learnStart must be greater than memSize/100') 250 | 251 | -- Check memory size is multiple of 100 (makes prioritised sampling partitioning simpler) 252 | abortIf(self.opt.memSize % 100 ~= 0, 'memSize must be a multiple of 100') 253 | 254 | -- Check learning occurs after first progress report 255 | abortIf(self.opt.learnStart < self.opt.progFreq, 'learnStart must be greater than progFreq') 256 | 257 | -- Check saliency map options 258 | abortIf(self.opt.saliency and not _.contains({'normal', 'guided', 'deconvnet'}, self.opt.saliency), 'Unrecognised method for visualising saliency maps') 259 | 260 | -- Check saliency is valid 261 | abortIf(self.opt.saliency and not self.opt.displaySpec, 'Saliency cannot be shown without env:getDisplay()') 262 | abortIf(self.opt.saliency and #self.opt.stateSpec[2] ~= 3 and (self.opt.stateSpec[2][1] ~= 3 or self.opt.stateSpec[2][1] ~= 1), 'Saliency cannot be shown without visual state') 263 | 264 | -- Check async options 265 | if self.opt.async then 266 | abortIf(self.opt.recurrent and self.opt.async ~= 'OneStepQ', 'Recurrent connections only supported for OneStepQ in async for now') 267 | abortIf(self.opt.PALpha > 0, 'Persistent advantage learning not supported in async modes yet') 268 | abortIf(self.opt.bootstraps > 0, 'Bootstrap heads not supported in async mode yet') 269 | abortIf(self.opt.async == 'A3C' and self.opt.duel, 'Dueling networks and A3C are incompatible') 270 | abortIf(self.opt.async == 'A3C' and self.opt.doubleQ, 'Double Q-learning and A3C are incompatible') 271 | abortIf(self.opt.saliency, 'Saliency maps not supported in async modes yet') 272 | end 273 | end 274 | 275 | -- Augments environments with extra methods if missing 276 | function Setup:augmentEnv() 277 | local Env = require(self.opt.env) 278 | local env = Env(self.opt) 279 | 280 | -- Set up fake training mode (if needed) 281 | if not env.training then 282 | Env.training = function() end 283 | end 284 | -- Set up fake evaluation mode (if needed) 285 | if not env.evaluate then 286 | Env.evaluate = function() end 287 | end 288 | end 289 | 290 | return Setup 291 | -------------------------------------------------------------------------------- /Agent.lua: -------------------------------------------------------------------------------- 1 | local _ = require 'moses' 2 | local class = require 'classic' 3 | local optim = require 'optim' 4 | local gnuplot = require 'gnuplot' 5 | local Model = require 'Model' 6 | local Experience = require 'Experience' 7 | local CircularQueue = require 'structures/CircularQueue' 8 | local Singleton = require 'structures/Singleton' 9 | local AbstractAgent = require 'async/AbstractAgent' 10 | require 
'classic.torch' -- Enables serialisation 11 | require 'modules/rmspropm' -- Add RMSProp with momentum 12 | 13 | -- Detect QT for image display 14 | local qt = pcall(require, 'qt') 15 | 16 | local Agent = classic.class('Agent', AbstractAgent) 17 | 18 | -- Creates a DQN agent 19 | function Agent:_init(opt) 20 | -- Experiment ID 21 | self._id = opt._id 22 | self.experiments = opt.experiments 23 | -- Actions 24 | self.m = opt.actionSpec[3][2] - opt.actionSpec[3][1] + 1 -- Number of discrete actions 25 | self.actionOffset = 1 - opt.actionSpec[3][1] -- Calculate offset if first action is not indexed as 1 26 | 27 | -- Initialise model helper 28 | self.model = Model(opt) 29 | -- Create policy and target networks 30 | self.policyNet = self.model:create() 31 | self.targetNet = self.policyNet:clone() -- Create deep copy for target network 32 | self.targetNet:evaluate() -- Target network always in evaluation mode 33 | self.tau = opt.tau 34 | self.doubleQ = opt.doubleQ 35 | -- Network parameters θ and gradients dθ 36 | self.theta, self.dTheta = self.policyNet:getParameters() 37 | 38 | -- Boostrapping 39 | self.bootstraps = opt.bootstraps 40 | self.head = 1 -- Identity of current episode bootstrap head 41 | self.heads = math.max(opt.bootstraps, 1) -- Number of heads 42 | 43 | -- Recurrency 44 | self.recurrent = opt.recurrent 45 | self.histLen = opt.histLen 46 | 47 | -- Reinforcement learning parameters 48 | self.gamma = opt.gamma 49 | self.rewardClip = opt.rewardClip 50 | self.tdClip = opt.tdClip 51 | self.epsilonStart = opt.epsilonStart 52 | self.epsilonEnd = opt.epsilonEnd 53 | self.epsilonGrad = (opt.epsilonEnd - opt.epsilonStart)/opt.epsilonSteps -- Greediness ε decay factor 54 | self.PALpha = opt.PALpha 55 | 56 | -- State buffer 57 | self.stateBuffer = CircularQueue(opt.recurrent and 1 or opt.histLen, opt.Tensor, opt.stateSpec[2]) 58 | -- Experience replay memory 59 | self.memory = Experience(opt.memSize, opt) 60 | self.memSampleFreq = opt.memSampleFreq 61 | self.memNSamples = opt.memNSamples 62 | self.memSize = opt.memSize 63 | self.memPriority = opt.memPriority 64 | 65 | -- Training mode 66 | self.isTraining = false 67 | self.batchSize = opt.batchSize 68 | self.learnStart = opt.learnStart 69 | self.progFreq = opt.progFreq 70 | self.gradClip = opt.gradClip 71 | -- Optimiser parameters 72 | self.optimiser = opt.optimiser 73 | self.optimParams = { 74 | learningRate = opt.eta, 75 | momentum = opt.momentum 76 | } 77 | 78 | -- Q-learning variables (per head) 79 | self.QPrimes = opt.Tensor(opt.batchSize, self.heads, self.m) 80 | self.tdErr = opt.Tensor(opt.batchSize, self.heads) 81 | self.VPrime = opt.Tensor(opt.batchSize, self.heads, 1) 82 | 83 | -- Validation variables 84 | self.valSize = opt.valSize 85 | self.valMemory = Experience(opt.valSize + 3, opt, true) -- Validation experience replay memory (with empty starting state...states...final transition...blank state) 86 | self.losses = {} 87 | self.avgV = {} -- Running average of V(s') 88 | self.avgTdErr = {} -- Running average of TD-error δ 89 | self.valScores = {} -- Validation scores (passed from main script) 90 | self.normScores = {} -- Normalised validation scores (passed from main script) 91 | 92 | -- Tensor creation 93 | self.Tensor = opt.Tensor 94 | 95 | -- Saliency display 96 | self:setSaliency(opt.saliency) -- Set saliency option on agent and model 97 | if #opt.stateSpec[2] == 3 then -- Make saliency map only for visual states 98 | self.saliencyMap = opt.Tensor(1, opt.stateSpec[2][2], opt.stateSpec[2][3]):zero() 99 | self.inputGrads = 
opt.Tensor(opt.histLen*opt.stateSpec[2][1], opt.stateSpec[2][2], opt.stateSpec[2][3]):zero() -- Gradients with respect to the input (for saliency maps) 100 | end 101 | 102 | -- Get singleton instance for step 103 | self.globals = Singleton.getInstance() 104 | end 105 | 106 | -- Sets training mode 107 | function Agent:training() 108 | self.isTraining = true 109 | self.policyNet:training() 110 | -- Clear state buffer 111 | self.stateBuffer:clear() 112 | -- Reset bootstrap head 113 | if self.bootstraps > 0 then 114 | self.head = torch.random(self.bootstraps) 115 | end 116 | -- Forget last sequence 117 | if self.recurrent then 118 | self.policyNet:forget() 119 | self.targetNet:forget() 120 | end 121 | end 122 | 123 | -- Sets evaluation mode 124 | function Agent:evaluate() 125 | self.isTraining = false 126 | self.policyNet:evaluate() 127 | -- Clear state buffer 128 | self.stateBuffer:clear() 129 | -- Set previously stored state as invalid (as no transition stored) 130 | self.memory:setInvalid() 131 | -- Reset bootstrap head 132 | if self.bootstraps > 0 then 133 | self.head = torch.random(self.bootstraps) 134 | end 135 | -- Forget last sequence 136 | if self.recurrent then 137 | self.policyNet:forget() 138 | end 139 | end 140 | 141 | -- Observes the results of the previous transition and chooses the next action to perform 142 | function Agent:observe(reward, rawObservation, terminal) 143 | -- Clip reward for stability 144 | if self.rewardClip > 0 then 145 | reward = math.max(reward, -self.rewardClip) 146 | reward = math.min(reward, self.rewardClip) 147 | end 148 | 149 | -- Process observation of current state 150 | local observation = self.model:preprocess(rawObservation) -- Must avoid side-effects on observation from env 151 | 152 | -- Store in buffer depending on terminal status 153 | if terminal then 154 | self.stateBuffer:pushReset(observation) -- Will clear buffer on next push 155 | else 156 | self.stateBuffer:push(observation) 157 | end 158 | -- Retrieve current and historical states from state buffer 159 | local state = self.stateBuffer:readAll() 160 | 161 | -- Set ε based on training vs. 
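-- Illustrative, standalone sketch (not part of the original file): in training mode the code
-- below keeps ε at epsilonStart until learnStart, then anneals it linearly towards epsilonEnd
-- over epsilonSteps steps:
--   ε(step) = max(epsilonStart + (step - learnStart)·(epsilonEnd - epsilonStart)/epsilonSteps, epsilonEnd)
-- Checked here with the default options (ε: 1 -> 0.01 over 1e6 steps, learnStart = 50000):
local epsilonStart, epsilonEnd, epsilonSteps, learnStart = 1, 0.01, 1e6, 50000
local epsilonGrad = (epsilonEnd - epsilonStart)/epsilonSteps
local function annealedEpsilon(step)
  if step < learnStart then return epsilonStart end
  return math.max(epsilonStart + (step - learnStart - 1)*epsilonGrad, epsilonEnd)
end
print(annealedEpsilon(50000), annealedEpsilon(550000), annealedEpsilon(2e6)) -- ~1, ~0.5, 0.01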
evaluation mode 162 | local epsilon = 0.001 -- Taken from tuned DDQN evaluation 163 | if self.isTraining then 164 | if self.globals.step < self.learnStart then 165 | -- Keep ε constant before learning starts 166 | epsilon = self.epsilonStart 167 | else 168 | -- Use annealing ε 169 | epsilon = math.max(self.epsilonStart + (self.globals.step - self.learnStart - 1)*self.epsilonGrad, self.epsilonEnd) 170 | end 171 | end 172 | 173 | local aIndex = 1 -- In a terminal state, choose no-op/first action by default 174 | if not terminal then 175 | if not self.isTraining and self.bootstraps > 0 then 176 | -- Retrieve estimates from all heads 177 | local QHeads = self.policyNet:forward(state) 178 | 179 | -- Use ensemble policy with bootstrap heads (in evaluation mode) 180 | local QHeadsMax, QHeadsMaxInds = QHeads:max(2) -- Find max action per head 181 | aIndex = torch.mode(QHeadsMaxInds:float(), 1)[1][1] -- TODO: Torch.CudaTensor:mode is missing 182 | 183 | -- Plot uncertainty in ensemble policy 184 | if qt then 185 | gnuplot.hist(QHeadsMaxInds, self.m, 0.5, self.m + 0.5) 186 | end 187 | 188 | -- Compute saliency map 189 | if self.saliency then 190 | self:computeSaliency(state, aIndex, true) 191 | end 192 | elseif torch.uniform() < epsilon then 193 | -- Choose action by ε-greedy exploration (even with bootstraps) 194 | aIndex = torch.random(1, self.m) 195 | 196 | -- Forward state anyway if recurrent 197 | if self.recurrent then 198 | self.policyNet:forward(state) 199 | end 200 | 201 | -- Reset saliency if action not chosen by network 202 | if self.saliency then 203 | self.saliencyMap:zero() 204 | end 205 | else 206 | -- Retrieve estimates from all heads 207 | local QHeads = self.policyNet:forward(state) 208 | 209 | -- Sample from current episode head (indexes on first dimension with no batch) 210 | local Qs = QHeads:select(1, self.head) 211 | local maxQ = Qs[1] 212 | local bestAs = {1} 213 | -- Find best actions 214 | for a = 2, self.m do 215 | if Qs[a] > maxQ then 216 | maxQ = Qs[a] 217 | bestAs = {a} 218 | elseif Qs[a] == maxQ then -- Ties can occur even with floats 219 | bestAs[#bestAs + 1] = a 220 | end 221 | end 222 | -- Perform random tie-breaking (if more than one argmax action) 223 | aIndex = bestAs[torch.random(1, #bestAs)] 224 | 225 | -- Compute saliency 226 | if self.saliency then 227 | self:computeSaliency(state, aIndex, false) 228 | end 229 | end 230 | end 231 | 232 | -- If training 233 | if self.isTraining then 234 | -- Store experience tuple parts (including pre-emptive action) 235 | 236 | local defaultMask = torch.ByteTensor(self.heads):fill(1) -- By default, the no head is masked 237 | local mask = defaultMask:clone() 238 | if self.bootstraps > 0 then 239 | mask = mask:bernoulli(0.5) -- Sample a mask for bootstrap using p = 0.5; Given in https://arxiv.org/pdf/1602.04621.pdf 240 | end 241 | self.memory:store(reward, observation, terminal, aIndex, mask) -- TODO: Sample independent Bernoulli(p) bootstrap masks for all heads; p = 1 means no masks needed 242 | 243 | -- Collect validation transitions at the start 244 | if self.globals.step <= self.valSize + 1 then 245 | self.valMemory:store(reward, observation, terminal, aIndex, defaultMask) 246 | end 247 | 248 | -- Sample uniformly or with prioritised sampling 249 | if self.globals.step % self.memSampleFreq == 0 and self.globals.step >= self.learnStart then 250 | for n = 1, self.memNSamples do 251 | -- Optimise (learn) from experience tuples 252 | self:optimise(self.memory:sample(self.head)) 253 | end 254 | end 255 | 256 | -- Update target 
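-- Illustrative, standalone sketch (not part of the original file): the target network below is
-- refreshed by cloning the policy network every τ steps, and learn() later uses it to build the
-- (Double) Q-learning target
--   Y = r + γ·(1 - terminal)·Q_target(s', argmax_a Q_policy(s', a))
-- Scalar version of that target for a single toy transition:
local gamma, r, terminal = 0.99, 1, 0
local QPolicyPrime = {1.0, 3.0, 2.0} -- Q_policy(s', ·): selects the argmax action (here action 2)
local QTargetPrime = {0.5, 2.5, 4.0} -- Q_target(s', ·): evaluates the selected action
local argmaxA, best = 1, QPolicyPrime[1]
for a = 2, #QPolicyPrime do
  if QPolicyPrime[a] > best then argmaxA, best = a, QPolicyPrime[a] end
end
local Y = r + gamma * (1 - terminal) * QTargetPrime[argmaxA]
print(Y) -- 1 + 0.99·2.5 = 3.475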
network every τ steps 257 | if self.globals.step % self.tau == 0 and self.globals.step >= self.learnStart then 258 | self.targetNet = self.policyNet:clone() 259 | self.targetNet:evaluate() 260 | end 261 | 262 | -- Rebalance priority queue for prioritised experience replay 263 | if self.globals.step % self.memSize == 0 and self.memPriority then 264 | self.memory:rebalance() 265 | end 266 | end 267 | 268 | if terminal then 269 | if self.bootstraps > 0 then 270 | -- Change bootstrap head for next episode 271 | self.head = torch.random(self.bootstraps) 272 | elseif self.recurrent then 273 | -- Forget last sequence 274 | self.policyNet:forget() 275 | end 276 | end 277 | 278 | -- Return action index with offset applied 279 | return aIndex - self.actionOffset 280 | end 281 | 282 | -- Learns from experience 283 | function Agent:learn(x, indices, ISWeights, isValidation) 284 | -- Copy x to parameters θ if necessary 285 | if x ~= self.theta then 286 | self.theta:copy(x) 287 | end 288 | -- Reset gradients dθ 289 | self.dTheta:zero() 290 | 291 | -- Retrieve experience tuples 292 | local memory = isValidation and self.valMemory or self.memory 293 | local states, actions, rewards, transitions, terminals = memory:retrieve(indices) -- Terminal status is for transition (can't act in terminal state) 294 | local N = actions:size(1) 295 | 296 | if self.recurrent then 297 | -- Forget last sequence 298 | self.policyNet:forget() 299 | self.targetNet:forget() 300 | end 301 | 302 | -- Perform argmax action selection 303 | local APrimeMax, APrimeMaxInds 304 | if self.doubleQ then 305 | -- Calculate Q-values from transition using policy network 306 | self.QPrimes = self.policyNet:forward(transitions) -- Find argmax actions using policy network 307 | -- Perform argmax action selection on transition using policy network: argmax_a[Q(s', a; θpolicy)] 308 | APrimeMax, APrimeMaxInds = torch.max(self.QPrimes, 3) 309 | -- Calculate Q-values from transition using target network 310 | self.QPrimes = self.targetNet:forward(transitions) -- Evaluate Q-values of argmax actions using target network 311 | else 312 | -- Calculate Q-values from transition using target network 313 | self.QPrimes = self.targetNet:forward(transitions) -- Find and evaluate Q-values of argmax actions using target network 314 | -- Perform argmax action selection on transition using target network: argmax_a[Q(s', a; θtarget)] 315 | APrimeMax, APrimeMaxInds = torch.max(self.QPrimes, 3) 316 | end 317 | 318 | -- Initially set target Y = Q(s', argmax_a[Q(s', a; θ)]; θtarget), where initial θ is either θtarget (DQN) or θpolicy (DDQN) 319 | local Y = self.Tensor(N, self.heads) 320 | for n = 1, N do 321 | self.QPrimes[n]:mul(1 - terminals[n]) -- Zero Q(s' a) when s' is terminal 322 | Y[n] = self.QPrimes[n]:gather(2, APrimeMaxInds[n]) 323 | end 324 | -- Calculate target Y := r + γ.Q(s', argmax_a[Q(s', a; θ)]; θtarget) 325 | Y:mul(self.gamma):add(rewards:repeatTensor(1, self.heads)) 326 | 327 | -- Get all predicted Q-values from the current state 328 | if self.recurrent and self.doubleQ then 329 | self.policyNet:forget() 330 | end 331 | local QCurr = self.policyNet:forward(states) -- Correct internal state of policy network before backprop 332 | local QTaken = self.Tensor(N, self.heads) 333 | -- Get prediction of current Q-values with given actions 334 | for n = 1, N do 335 | QTaken[n] = QCurr[n][{{}, {actions[n]}}] 336 | end 337 | 338 | -- Calculate TD-errors δ := ∆Q(s, a) = Y − Q(s, a) 339 | self.tdErr = Y - QTaken 340 | 341 | -- Calculate Advantage Learning 
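-- Illustrative, standalone sketch (not part of the original file): with PALpha > 0 the TD-error δ
-- computed below is replaced by the Persistent Advantage Learning update (Bellemare et al., 2016):
--   ∆AL  = δ - α·(V(s)  - Q(s, a))
--   ∆PAL = max(∆AL, δ - α·(V(s') - Q(s', a)))
-- Scalar version for one transition (toy numbers):
local alpha = 0.9
local delta = 0.5 -- Ordinary TD-error δ
local V, Q = 2.0, 1.5 -- max_a Q(s, ·) and Q(s, a) under the target network
local VPrime, QPrime = 3.0, 2.8 -- The same quantities at the next state s'
local tdErrAL = delta - alpha * (V - Q)
local tdErrPAL = math.max(tdErrAL, delta - alpha * (VPrime - QPrime))
print(tdErrAL, tdErrPAL) -- 0.05 and 0.32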
update(s) 342 | if self.PALpha > 0 then 343 | -- Calculate Q(s, a) and V(s) using target network 344 | if self.recurrent then 345 | self.targetNet:forget() 346 | end 347 | local Qs = self.targetNet:forward(states) 348 | local Q = self.Tensor(N, self.heads) 349 | for n = 1, N do 350 | Q[n] = Qs[n][{{}, {actions[n]}}] 351 | end 352 | local V = torch.max(Qs, 3) -- Current states cannot be terminal 353 | 354 | -- Calculate Advantage Learning update ∆ALQ(s, a) := δ − αPAL(V(s) − Q(s, a)) 355 | local tdErrAL = self.tdErr - V:csub(Q):mul(self.PALpha) 356 | 357 | -- Calculate Q(s', a) and V(s') using target network 358 | local QPrime = self.Tensor(N, self.heads) 359 | for n = 1, N do 360 | QPrime[n] = self.QPrimes[n][{{}, {actions[n]}}] 361 | end 362 | self.VPrime = torch.max(self.QPrimes, 3) 363 | 364 | -- Calculate Persistent Advantage Learning update ∆PALQ(s, a) := max[∆ALQ(s, a), δ − αPAL(V(s') − Q(s', a))] 365 | self.tdErr = torch.max(torch.cat(tdErrAL, self.tdErr:csub((self.VPrime:csub(QPrime):mul(self.PALpha))), 3), 3):view(N, self.heads, 1) 366 | end 367 | 368 | -- Calculate loss 369 | local loss 370 | if self.tdClip > 0 then 371 | -- Squared loss is used within clipping range, absolute loss is used outside (approximates Huber loss) 372 | local sqLoss = torch.cmin(torch.abs(self.tdErr), self.tdClip) 373 | local absLoss = torch.abs(self.tdErr) - sqLoss 374 | loss = torch.mean(sqLoss:pow(2):mul(0.5):add(absLoss:mul(self.tdClip))) -- Average over heads 375 | 376 | -- Clip TD-errors δ 377 | self.tdErr:clamp(-self.tdClip, self.tdClip) 378 | else 379 | -- Squared loss 380 | loss = torch.mean(self.tdErr:clone():pow(2):mul(0.5)) -- Average over heads 381 | end 382 | 383 | -- Exit if being used for validation metrics 384 | if isValidation then 385 | return 386 | end 387 | 388 | -- Send TD-errors δ to be used as priorities 389 | self.memory:updatePriorities(indices, torch.mean(self.tdErr, 2)) -- Use average error over heads 390 | 391 | -- Zero QCurr outputs (no error) 392 | QCurr:zero() 393 | -- Set TD-errors δ with given actions 394 | for n = 1, N do 395 | -- Correct prioritisation bias with importance-sampling weights 396 | QCurr[n][{{}, {actions[n]}}] = torch.mul(-self.tdErr[n], ISWeights[n]) -- Negate target to use gradient descent (not ascent) optimisers 397 | end 398 | 399 | -- Backpropagate (network accumulates gradients internally) 400 | self.policyNet:backward(states, QCurr) -- TODO: Work out why DRQN crashes on different batch sizes 401 | -- Clip the L2 norm of the gradients 402 | if self.gradClip > 0 then 403 | self.policyNet:gradParamClip(self.gradClip) 404 | end 405 | 406 | if self.recurrent then 407 | -- Forget last sequence 408 | self.policyNet:forget() 409 | self.targetNet:forget() 410 | -- Previous hidden state of policy net not restored as model parameters changed 411 | end 412 | 413 | return loss, self.dTheta 414 | end 415 | 416 | -- Optimises the network parameters θ 417 | function Agent:optimise(indices, ISWeights) 418 | -- Create function to evaluate given parameters x 419 | local feval = function(x) 420 | return self:learn(x, indices, ISWeights) 421 | end 422 | 423 | -- Optimise 424 | local __, loss = optim[self.optimiser](feval, self.theta, self.optimParams) 425 | -- Store loss 426 | if self.globals.step % self.progFreq == 0 then 427 | self.losses[#self.losses + 1] = loss[1] 428 | end 429 | 430 | return loss[1] 431 | end 432 | 433 | -- Pretty prints array 434 | local pprintArr = function(memo, v) 435 | return memo .. ', ' .. 
v 436 | end 437 | 438 | -- Reports absolute network weights and gradients 439 | function Agent:report() 440 | -- Collect layer with weights 441 | local weightLayers = self.policyNet:findModules('nn.SpatialConvolution') 442 | if #weightLayers == 0 then 443 | -- Assume cuDNN convolutions 444 | weightLayers = self.policyNet:findModules('cudnn.SpatialConvolution') 445 | end 446 | local fcLayers = self.policyNet:findModules('nn.Linear') 447 | weightLayers = _.append(weightLayers, fcLayers) 448 | 449 | -- Array of norms and maxima 450 | local wNorms = {} 451 | local wMaxima = {} 452 | local wGradNorms = {} 453 | local wGradMaxima = {} 454 | 455 | -- Collect statistics 456 | for l = 1, #weightLayers do 457 | local w = weightLayers[l].weight:clone():abs() -- Weights (absolute) 458 | wNorms[#wNorms + 1] = torch.mean(w) -- Weight norms: 459 | wMaxima[#wMaxima + 1] = torch.max(w) -- Weight max 460 | w = weightLayers[l].gradWeight:clone():abs() -- Weight gradients (absolute) 461 | wGradNorms[#wGradNorms + 1] = torch.mean(w) -- Weight grad norms: 462 | wGradMaxima[#wGradMaxima + 1] = torch.max(w) -- Weight grad max 463 | end 464 | 465 | -- Create report string table 466 | local reports = { 467 | 'Weight norms: ' .. _.reduce(wNorms, pprintArr), 468 | 'Weight max: ' .. _.reduce(wMaxima, pprintArr), 469 | 'Weight gradient norms: ' .. _.reduce(wGradNorms, pprintArr), 470 | 'Weight gradient max: ' .. _.reduce(wGradMaxima, pprintArr) 471 | } 472 | 473 | return reports 474 | end 475 | 476 | -- Reports stats for validation 477 | function Agent:validate() 478 | -- Validation variables 479 | local totalV, totalTdErr = 0, 0 480 | 481 | -- Loop over validation transitions 482 | local nBatches = math.ceil(self.valSize / self.batchSize) 483 | local ISWeights = self.Tensor(self.batchSize):fill(1) 484 | local startIndex, endIndex, batchSize, indices 485 | for n = 1, nBatches do 486 | startIndex = (n - 1)*self.batchSize + 2 487 | endIndex = math.min(n*self.batchSize + 1, self.valSize + 1) 488 | batchSize = endIndex - startIndex + 1 489 | indices = torch.linspace(startIndex, endIndex, batchSize):long() 490 | 491 | -- Perform "learning" (without optimisation) 492 | self:learn(self.theta, indices, ISWeights:narrow(1, 1, batchSize), true) 493 | 494 | -- Calculate V(s') and TD-error δ 495 | if self.PALpha == 0 then 496 | self.VPrime = torch.max(self.QPrimes, 3) 497 | end 498 | -- Average over heads 499 | totalV = totalV + torch.mean(self.VPrime, 2):sum() 500 | totalTdErr = totalTdErr + torch.mean(self.tdErr, 2):abs():sum() 501 | end 502 | 503 | -- Average and insert values 504 | self.avgV[#self.avgV + 1] = totalV / self.valSize 505 | self.avgTdErr[#self.avgTdErr + 1] = totalTdErr / self.valSize 506 | 507 | -- Plot and save losses 508 | if #self.losses > 0 then 509 | local losses = torch.Tensor(self.losses) 510 | gnuplot.pngfigure(paths.concat(self.experiments, self._id, 'losses.png')) 511 | gnuplot.plot('Loss', torch.linspace(math.floor(self.learnStart/self.progFreq), math.floor(self.globals.step/self.progFreq), #self.losses), losses, '-') 512 | gnuplot.xlabel('Step (x' .. self.progFreq .. 
')') 513 | gnuplot.ylabel('Loss') 514 | gnuplot.plotflush() 515 | torch.save(paths.concat(self.experiments, self._id, 'losses.t7'), losses) 516 | end 517 | -- Plot and save V 518 | local epochIndices = torch.linspace(1, #self.avgV, #self.avgV) 519 | local Vs = torch.Tensor(self.avgV) 520 | gnuplot.pngfigure(paths.concat(self.experiments, self._id, 'Vs.png')) 521 | gnuplot.plot('V', epochIndices, Vs, '-') 522 | gnuplot.xlabel('Epoch') 523 | gnuplot.ylabel('V') 524 | gnuplot.movelegend('left', 'top') 525 | gnuplot.plotflush() 526 | torch.save(paths.concat(self.experiments, self._id, 'V.t7'), Vs) 527 | -- Plot and save TD-error δ 528 | local TDErrors = torch.Tensor(self.avgTdErr) 529 | gnuplot.pngfigure(paths.concat(self.experiments, self._id, 'TDErrors.png')) 530 | gnuplot.plot('TD-Error', epochIndices, TDErrors, '-') 531 | gnuplot.xlabel('Epoch') 532 | gnuplot.ylabel('TD-Error') 533 | gnuplot.plotflush() 534 | torch.save(paths.concat(self.experiments, self._id, 'TDErrors.t7'), TDErrors) 535 | -- Plot and save average score 536 | local scores = torch.Tensor(self.valScores) 537 | gnuplot.pngfigure(paths.concat(self.experiments, self._id, 'scores.png')) 538 | gnuplot.plot('Score', epochIndices, scores, '-') 539 | gnuplot.xlabel('Epoch') 540 | gnuplot.ylabel('Average Score') 541 | gnuplot.movelegend('left', 'top') 542 | gnuplot.plotflush() 543 | torch.save(paths.concat(self.experiments, self._id, 'scores.t7'), scores) 544 | -- Plot and save normalised score 545 | if #self.normScores > 0 then 546 | local normScores = torch.Tensor(self.normScores) 547 | gnuplot.pngfigure(paths.concat(self.experiments, self._id, 'normScores.png')) 548 | gnuplot.plot('Score', epochIndices, normScores, '-') 549 | gnuplot.xlabel('Epoch') 550 | gnuplot.ylabel('Normalised Score') 551 | gnuplot.movelegend('left', 'top') 552 | gnuplot.plotflush() 553 | torch.save(paths.concat(self.experiments, self._id, 'normScores.t7'), normScores) 554 | end 555 | gnuplot.close() 556 | 557 | return self.avgV[#self.avgV], self.avgTdErr[#self.avgTdErr] 558 | end 559 | 560 | -- Saves network convolutional filters as images 561 | function Agent:visualiseFilters() 562 | local filters = self.model:getFilters() 563 | 564 | for i, v in ipairs(filters) do 565 | image.save(paths.concat(self.experiments, self._id, 'conv_layer_' .. i .. 
'.png'), v) 566 | end 567 | end 568 | 569 | -- Sets saliency style 570 | function Agent:setSaliency(saliency) 571 | self.saliency = saliency 572 | self.model:setSaliency(saliency) 573 | end 574 | 575 | -- Computes a saliency map (assuming a forward pass of a single state) 576 | function Agent:computeSaliency(state, index, ensemble) 577 | -- Switch to possibly special backpropagation 578 | self.model:salientBackprop() 579 | 580 | -- Create artificial high target 581 | local maxTarget = self.Tensor(self.heads, self.m):zero() 582 | if ensemble then 583 | -- Set target on all heads (when using ensemble policy) 584 | maxTarget[{{}, {index}}] = 1 585 | else 586 | -- Set target on current head 587 | maxTarget[self.head][index] = 1 588 | end 589 | 590 | -- Backpropagate to inputs 591 | self.inputGrads = self.policyNet:backward(state, maxTarget) 592 | -- Saliency map ref used by Display 593 | self.saliencyMap = torch.abs(self.inputGrads:select(1, self.recurrent and 1 or self.histLen):float()) 594 | 595 | -- Switch back to normal backpropagation 596 | self.model:normalBackprop() 597 | end 598 | 599 | -- Saves the network parameters θ 600 | function Agent:saveWeights(path) 601 | torch.save(path, self.theta:float()) -- Do not save as CudaTensor to increase compatibility 602 | end 603 | 604 | -- Loads network parameters θ 605 | function Agent:loadWeights(path) 606 | local weights = torch.load(path) 607 | self.theta:copy(weights) 608 | self.targetNet = self.policyNet:clone() 609 | self.targetNet:evaluate() 610 | end 611 | 612 | return Agent 613 | --------------------------------------------------------------------------------
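
The two densest steps inside Agent:learn are forming the double-DQN target Y = r + γ·Q(s', argmax_a Q(s', a; θpolicy); θtarget) and computing the clipped, Huber-like loss used when tdClip > 0. The following is a minimal, self-contained Torch sketch of just those two steps; it is not part of the repository, and the toy tensors, single-head shapes and variable names (rewards, terminals, QPolicy, QTarget, QTaken) are illustrative assumptions only.

local torch = require 'torch'

-- Hypothetical batch of two transitions with three actions and a single head
local gamma, tdClip = 0.99, 1
local rewards = torch.Tensor{1, 0}
local terminals = torch.Tensor{0, 1} -- Second transition ends in a terminal state
local QPolicy = torch.Tensor{{1.0, 2.0, 0.5}, {0.2, 0.1, 0.3}} -- Q(s', ·; θpolicy)
local QTarget = torch.Tensor{{0.9, 1.5, 0.7}, {0.4, 0.2, 0.1}} -- Q(s', ·; θtarget)
local QTaken = torch.Tensor{1.2, 0.6} -- Q(s, a; θpolicy) for the actions actually taken

-- Double DQN: select argmax actions with the policy network, evaluate them with the target network
local _, aMax = torch.max(QPolicy, 2)
local Y = QTarget:gather(2, aMax):squeeze()
-- Zero the bootstrapped value for terminal transitions, then add the discounted rewards
local notTerminal = terminals:clone():mul(-1):add(1)
Y:cmul(notTerminal):mul(gamma):add(rewards)

-- TD-errors δ and the clipped, Huber-like loss (squared inside ±tdClip, absolute outside)
local tdErr = Y - QTaken
local sqLoss = torch.cmin(torch.abs(tdErr), tdClip)
local absLoss = torch.abs(tdErr) - sqLoss
local loss = torch.mean(sqLoss:clone():pow(2):mul(0.5):add(absLoss:mul(tdClip)))
tdErr:clamp(-tdClip, tdClip) -- Clipped δ is what gets backpropagated (and reused as priorities)

print(Y, tdErr, loss)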