├── README.md ├── build_ale.sh ├── build_alewrap.sh ├── produce_dataset.sh ├── roms └── .gitignore └── util ├── gen_bw_2014.lua ├── see_frames.lua └── split.lua /README.md: -------------------------------------------------------------------------------- 1 | Producing the dataset 2 | ===================== 3 | 1) Get [Torch](http://torch.ch/) with the `nn` and `image` packages 4 | 5 | 2) Put the Atari ROMs for freeway, pong, riverraid, seaquest and space_invaders 6 | under `roms/` 7 | 8 | SHA1 sums of the supported ROM files: 9 | 10 | 91cc7e5cd6c0d4a6f42ed66353b7ee7bb972fa3f roms/freeway.bin 11 | 1ffe89d79d55adabc0916b95cc37e18619ef7830 roms/pong.bin 12 | 40329780402f8247f294fe884ffc56cc3da0c62d roms/riverraid.bin 13 | 7324a1ebc695a477c8884718ffcad27732a98ab0 roms/seaquest.bin 14 | 31d9668fe5812c3d2e076987ca327ac6b2e280bf roms/space_invaders.bin 15 | 16 | 3) Build the [Arcade Learning Environment](http://www.arcadelearningenvironment.org/) 17 | 18 | ./build_ale.sh 19 | 20 | 4) Build the Lua wrapper for the Arcade Learning Environment. 21 | 22 | ./build_alewrap.sh 23 | 24 | 5) Produce the frames. 25 | 26 | ./produce_dataset.sh 27 | 28 | 6) See some produced frames.
29 | 30 | qlua util/see_frames.lua freeway-train.bin 31 | 32 | 33 | Any bugs can be reported to ivo@danihelka.net 34 | -------------------------------------------------------------------------------- /build_ale.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Downloads ALE 0.4.4 and builds it in ale_0.4.4/ale_0_4 (used by build_alewrap.sh and produce_dataset.sh). 3 | set -e 4 | wget http://www.arcadelearningenvironment.org/wp-content/uploads/2014/04/ale_0.4.4.zip 5 | unzip ale_0.4.4.zip 6 | cd ale_0.4.4/ale_0_4 7 | # luajit (must be on PATH) selects makefile.mac on OS X and makefile.unix elsewhere. 8 | make -f makefile.`luajit -e 'if jit.os == "OSX" then print("mac") else print("unix") end'` 9 | -------------------------------------------------------------------------------- /build_alewrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | git clone https://github.com/fidlej/alewrap.git # HTTPS clone: works without a GitHub SSH key 4 | cd alewrap 5 | mkdir build 6 | cd build 7 | cmake -DCMAKE_BUILD_TYPE=Release -DALE_INCLUDE_DIR=../../ale_0.4.4/ale_0_4/src -DALE_LIBRARY=../../ale_0.4.4/ale_0_4/libale.so .. 8 | make 9 | cd ..
10 | ln -s build/alewrap/libalewrap.so 11 | 12 | -------------------------------------------------------------------------------- /produce_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Generates edge-detected frames for every game, splits them into <game>-train.bin 3 | # and <game>-test.bin with util/split.lua, then removes the unsplit <game>.bin files. 4 | set -e 5 | 6 | # Must match the game list hard-coded in util/split.lua. 7 | GAMES="freeway pong riverraid seaquest space_invaders" 8 | 9 | # Make libale and the alewrap Lua module (built by build_ale.sh / build_alewrap.sh) 10 | # loadable. Prepend instead of overwriting, so library paths already set in the 11 | # caller's environment keep working. 12 | ALE_HOME="ale_0.4.4/ale_0_4" 13 | export LD_LIBRARY_PATH="$ALE_HOME${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" 14 | export DYLD_LIBRARY_PATH="$ALE_HOME${DYLD_LIBRARY_PATH:+:$DYLD_LIBRARY_PATH}" 15 | export LUA_CPATH="alewrap/?.so;;" 16 | export LUA_PATH="alewrap/?/init.lua;;" 17 | 18 | # Generating the frames (gen_bw_2014.lua writes <game>.bin) 19 | for game in $GAMES; do 20 | luajit util/gen_bw_2014.lua "$game" 21 | done 22 | 23 | luajit util/split.lua 24 | 25 | for game in $GAMES; do 26 | rm "$game.bin" 27 | done 28 | -------------------------------------------------------------------------------- /roms/.gitignore: -------------------------------------------------------------------------------- 1 | /*.bin 2 | -------------------------------------------------------------------------------- /util/gen_bw_2014.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | local alewrap = require 'alewrap' 3 | -- Fixed seeds for the random action choice (math.random) and frame acceptance (torch.rand). 4 | torch.manualSeed(0x123456789) 5 | math.randomseed(0x123456789) 6 | 7 | local romDir = 'roms/' 8 | local numSamplesPerGame = 100000 9 | local probAccept = 0.01 10 | 11 | -- edge detection 12 | local MAX_INTENSITY = 255 13 | local edgeNetwork = nn.SpatialConvolution(1, 1, 2, 2) 14 | edgeNetwork.weight:copy(torch.Tensor({ 15 | -- The Roberts cross operator is used to detect the edges.
16 | -- http://homepages.inf.ed.ac.uk/rbf/CVonline/LOCAL_COPIES/OWENS/LECT6/node2.html 17 | -- For intensities in [0, MAX_INTENSITY], a 2x2 window {{a, b}, {c, d}} maps to (MAX_INTENSITY + 1)*(a - d) + (b - c), which is nonzero iff a ~= d or b ~= c. 18 | -- Filter 1 19 | {{ 20 | {MAX_INTENSITY + 1, 1}, 21 | {-1, -(MAX_INTENSITY + 1)}, 22 | }}, 23 | })) 24 | edgeNetwork.bias:zero() 25 | -- Returns a 0/1 edge map of a 2-D intensity tensor. Note: `input` itself is resized to 1 x H x W. 26 | local function edgeDetector(input) 27 | input:resize(1, input:size(1), input:size(2)) 28 | local output = edgeNetwork:forward(input) 29 | output:apply(function (x) 30 | if x ~= 0 then 31 | return 1 32 | end 33 | return 0 34 | end) 35 | return output 36 | end 37 | 38 | local name = assert(arg[1], string.format("Usage: %s GAME_NAME", arg[0] or "gen_bw_2014.lua")) 39 | local romPath = romDir .. name .. '.bin' 40 | if not paths.filep(romPath) then 41 | io.stderr:write(string.format("Missing a ROM file: %s\n", romPath)) 42 | os.exit(1) 43 | end 44 | 45 | local env = alewrap.createEnv(romPath, {}) 46 | env:envStart() 47 | local actions = env:actions():storage():totable() 48 | local action = {torch.Tensor(1)} 49 | -- Plays uniformly random actions; each resulting observation is kept with probability probAccept. Returns the edge map of the kept one. 50 | local function sample() 51 | local pixels 52 | while true do 53 | action[1][1] = actions[math.random(1, #actions)] 54 | local _, observe = env:envStep(action) 55 | if torch.rand(1)[1] < probAccept then 56 | pixels = edgeDetector(observe[1]:double()) 57 | break 58 | end 59 | end 60 | return pixels 61 | end 62 | 63 | io.write(name ..
' ') 64 | -- Output: raw bytes, one per pixel (0 or 1), frames written back to back with no header. 65 | local binfile = assert(io.open(string.format('%s.bin', name), 'wb')) 66 | for ii = 1,numSamplesPerGame do 67 | local samp = sample() 68 | if ii == 1 then 69 | print('frame size:', unpack(samp:size():totable())) 70 | end 71 | io.write('.') 72 | io.flush() 73 | 74 | samp:apply(function(pixel) 75 | binfile:write((pixel == 1 and "\1") or "\0") 76 | end) 77 | -- Drop the frame and run a full GC every iteration (presumably to keep memory flat over numSamplesPerGame samples). 78 | samp = nil 79 | collectgarbage() 80 | end 81 | binfile:close() 82 | -------------------------------------------------------------------------------- /util/see_frames.lua: -------------------------------------------------------------------------------- 1 | -- Displays the frames stored in a file written by util/split.lua, e.g. `qlua util/see_frames.lua freeway-train.bin`. 2 | require 'image' 3 | 4 | local function loadFrames(path) 5 | -- Frame size; must match the 'frame size' printed by util/gen_bw_2014.lua. 6 | local nRows = 209 7 | local nCols = 159 8 | 9 | local frameStore = torch.ByteStorage(path) 10 | local frames = torch.ByteTensor(frameStore) 11 | local nFrames = frames:nElement() / (nRows * nCols) 12 | assert(nFrames == math.floor(nFrames), "unexpected frame size") 13 | frames:resize(nFrames, 1, nRows, nCols) 14 | return frames 15 | end 16 | 17 | local function main() 18 | local path = arg[1] or "pong-train.bin" 19 | print("Loading frames:", path) 20 | local frames = loadFrames(path) 21 | print("size:", frames:size()) 22 | 23 | local win 24 | for i = 1, frames:size(1) do 25 | win = image.display({image=frames[i], win=win}) 26 | end 27 | end 28 | 29 | main() 30 | -------------------------------------------------------------------------------- /util/split.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | local ffi = require 'ffi' 3 | if not torch.data then 4 | require 'torchffi' 5 | end 6 | -- Loads a raw frame file (one byte per pixel, no header) as an nFrames x nRows x nCols ByteTensor. 7 | local function loadFrames(path) 8 | local nRows = 209 9 | local nCols = 159 10 | -- The frame size above must match the 'frame size' printed by util/gen_bw_2014.lua. 11 | local frameStore = torch.ByteStorage(path) 12 | local frames = torch.ByteTensor(frameStore) 13 | local nFrames = frames:nElement() / (nRows * nCols) 14 | assert(nFrames == math.floor(nFrames), "unexpected frame size") 15 | frames:resize(nFrames, nRows, nCols)
16 | return frames 17 | end 18 | 19 | -- Writes the raw bytes of a contiguous ByteTensor to `path`, bufferSize bytes at a time. 20 | local function saveBytes(path, tensor) 21 | assert(torch.typename(tensor) == "torch.ByteTensor", "expecting a ByteTensor") 22 | assert(tensor:isContiguous(), "expecting a contiguous tensor") 23 | local flatTensor = tensor.new(tensor):resize(tensor:nElement()) 24 | local bufferSize = 8*1024 25 | -- Binary mode: the payload is raw bytes, not text. 26 | local output = assert(io.open(path, 'wb')) 27 | for startIndex = 1, tensor:nElement(), bufferSize do 28 | local includedEnd = math.min(tensor:nElement(), startIndex + bufferSize - 1) 29 | local flatView = flatTensor[{{startIndex, includedEnd}}] 30 | local bytes = ffi.string(torch.data(flatView), flatView:nElement()) 31 | output:write(bytes) 32 | end 33 | output:close() 34 | end 35 | 36 | -- Returns a new tensor holding the rows (dimension 1) of `examples` in random order. 37 | local function getShuffled(examples) 38 | local perm = torch.randperm(examples:size(1)):long() 39 | -- index() gathers the permuted rows in one call, instead of cloning 40 | -- `examples` and then overwriting the clone row by row. 41 | return examples:index(1, perm) 42 | end 43 | 44 | -- Splits off the last nTestFrames rows as the test set; returns (train, test). 45 | local function splitSet(allExamples, nTestFrames) 46 | local nExamples = allExamples:size(1) 47 | assert(nTestFrames > 0 and nTestFrames < nExamples, "need 0 < nTestFrames < number of examples") 48 | local trainingExamples = allExamples[{{1, nExamples - nTestFrames}}] 49 | local testExamples = allExamples[{{nExamples - nTestFrames + 1, nExamples}}] 50 | assert(testExamples:size(1) == nTestFrames, "wrong number of test examples") 51 | return trainingExamples, testExamples 52 | end 53 | 54 | -- Shuffles each game's frames (fixed seed) and writes <game>-train.bin and <game>-test.bin (10000 test frames). 55 | local function main() 56 | torch.manualSeed(1) 57 | for _, game in ipairs({"freeway", "pong", "riverraid", "seaquest", "space_invaders"}) do 58 | print("Loading:", game) 59 | local file = string.format("%s.bin", game) 60 | local frames = loadFrames(file) 61 | local shuffled = getShuffled(frames) 62 | local train, test = splitSet(shuffled, 10000) 63 | 64 | saveBytes(string.format("%s-train.bin", game), train) 65 | saveBytes(string.format("%s-test.bin", game), test) 66 | end 67 | end 68 | 69 | main() 70 | --------------------------------------------------------------------------------