├── README.md ├── example.lua ├── LICENSE ├── download_mnist.lua └── mnist_cluttered.lua /README.md: -------------------------------------------------------------------------------- 1 | Cluttered MNIST Dataset 2 | ======================= 3 | 4 | A setup script will download MNIST and produce `mnist/*.t7` files: 5 | 6 | luajit download_mnist.lua 7 | 8 | Example usage: 9 | 10 | local mnist_cluttered = require 'mnist_cluttered' 11 | -- The observation will have size 1x100x100 with 8 distractors. 12 | local dataConfig = {megapatch_w=100, num_dist=8} 13 | local dataInfo = mnist_cluttered.createData(dataConfig) 14 | local observation, target = unpack(dataInfo.nextExample()) 15 | -------------------------------------------------------------------------------- /example.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright 2014 Google Inc. All Rights Reserved. 3 | 4 | Use of this source code is governed by a BSD-style 5 | license that can be found in the LICENSE file or at 6 | https://developers.google.com/open-source/licenses/bsd 7 | ]] 8 | 9 | local mnist_cluttered = require 'mnist_cluttered' 10 | 11 | local dataConfig = {megapatch_w=100, num_dist=8} 12 | local dataInfo = mnist_cluttered.createData(dataConfig) 13 | local observation, target = unpack(dataInfo.nextExample()) 14 | print("observation size:", table.concat(observation:size():totable(), 'x')) 15 | print("targets:", target) 16 | 17 | print("Saving example.png") 18 | require 'image' 19 | local formatted = image.toDisplayTensor({input=observation}) 20 | image.save("example.png", formatted) 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2015, Google Inc. All rights reserved. 
2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of Google Inc. nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /download_mnist.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright 2014 Google Inc. All Rights Reserved. 
Use of this source code is governed by a BSD-style
license that can be found in the LICENSE file or at
https://developers.google.com/open-source/licenses/bsd
]]

--[[
Script to download and save mnist data.

- gets files from Yann LeCun's web site (http://yann.lecun.com/exdb/mnist/)
- Processes data into a table containing 'data' and 'labels' tensors.
- Writes train.t7, valid.t7 and test.t7 into the DIR directory.
]]

require 'os'
require 'torch'
require 'paths'


-- Directory that receives the raw IDX files and the generated *.t7 files.
local DIR = "mnist"
local BASE_URL = "http://yann.lecun.com/exdb/mnist/"
-- Order matters: train images, train labels, test images, test labels.
local FILENAMES = { "train-images-idx3-ubyte", "train-labels-idx1-ubyte",
                    "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte" }
local URLS = {}
for i, name in ipairs(FILENAMES) do
  URLS[i] = BASE_URL .. name .. ".gz"
end
local TRAINSIZE = 60000
local TESTSIZE = 10000
-- Number of trailing training examples that are saved as the validation set.
local VALIDSIZE = 10000
local trainSet = { data = torch.ByteTensor(TRAINSIZE, 28, 28),
                   labels = torch.ByteTensor(TRAINSIZE, 1) }
local testSet = { data = torch.ByteTensor(TESTSIZE, 28, 28),
                  labels = torch.ByteTensor(TESTSIZE, 1) }

-- Runs a shell command and raises an error if it does not succeed.
local function runCmd(cmd)
  local exitCode = os.execute(cmd)
  -- Depending on the Lua version, success is reported as 0 or as true.
  if exitCode ~= true and exitCode ~= 0 then
    error("failed cmd: " .. cmd)
  end
end

-- Downloads the four MNIST archives into the current directory,
-- unpacks them and moves the unpacked files into DIR.
local function downloadData()
  -- download
  print "Downloading data"
  for i = 1, #FILENAMES do
    -- Drop any archive left over from a previous run before fetching again.
    os.remove(FILENAMES[i] .. ".gz")
    runCmd("wget " .. URLS[i])
  end

  -- Unpack and store
  print "Unpacking data"
  runCmd("mkdir -p " .. DIR)

  for i = 1, #FILENAMES do
    -- -f: overwrite a leftover unpacked file instead of failing on it.
    runCmd("gunzip -f " .. FILENAMES[i] .. ".gz")
    assert(os.rename(FILENAMES[i], paths.concat(DIR, FILENAMES[i])))
  end
end

-- Opens a file from DIR for binary, big-endian reading.
local function openIdxFile(name)
  local file = torch.DiskFile(paths.concat(DIR, name), "r")
  file:binary()
  file:bigEndianEncoding()
  return file
end

-- Reads `expectedSize` images and labels into set.data and set.labels.
-- The headers of both files are checked before any data is read;
-- see the data format as described on http://yann.lecun.com/exdb/mnist/
local function readIdxSet(imagesName, labelsName, set, expectedSize)
  -- open image data file and check headers
  local images = openIdxFile(imagesName)
  local magicNumber = images:readInt()
  local numberOfItems = images:readInt()
  local nRows = images:readInt()
  local nCols = images:readInt()
  assert(magicNumber == 2051)
  assert(numberOfItems == expectedSize)
  assert(nRows == 28)
  assert(nCols == 28)

  -- open labels data file and check headers
  local labels = openIdxFile(labelsName)
  magicNumber = labels:readInt()
  numberOfItems = labels:readInt()
  assert(magicNumber == 2049)
  assert(numberOfItems == expectedSize)

  -- read all the data
  for i = 1, expectedSize do
    if i % 1000 == 0 then
      print(i .. "/" .. expectedSize .. " done.")
    end
    -- read the image, one byte per pixel
    set.data[i]:apply(function()
      return images:readByte()
    end)
    -- read label
    set.labels[i][1] = labels:readByte()
  end

  -- close input files
  images:close()
  labels:close()
end


-- Reads the training files and saves the first TRAINSIZE - VALIDSIZE
-- examples as train.t7 and the remaining VALIDSIZE examples as valid.t7.
local function processTrainData()
  print "Reformatting training set"
  readIdxSet(FILENAMES[1], FILENAMES[2], trainSet, TRAINSIZE)

  -- output torch files
  local nTrain = TRAINSIZE - VALIDSIZE
  torch.save(paths.concat(DIR, 'train.t7'), {
    data = trainSet.data[{{1, nTrain}}]:clone(),
    labels = trainSet.labels[{{1, nTrain}}]:clone(),
  })
  torch.save(paths.concat(DIR, 'valid.t7'), {
    data = trainSet.data[{{nTrain + 1, -1}}]:clone(),
    labels = trainSet.labels[{{nTrain + 1, -1}}]:clone(),
  })
end


-- Reads the test files and saves them as test.t7.
local function processTestData()
  print "Reformatting test set"
  readIdxSet(FILENAMES[3], FILENAMES[4], testSet, TESTSIZE)

  -- output torch files
  torch.save(paths.concat(DIR, 'test.t7'), testSet)
end

local function processData()
  print "Processing data"
  processTrainData()
  processTestData()
end

-- Execution starts here
downloadData()

-- Process data into Lua table & tensors
processData()
--------------------------------------------------------------------------------
/mnist_cluttered.lua:
--------------------------------------------------------------------------------
--[[
Copyright 2014 Google Inc. All Rights Reserved.

Use of this source code is governed by a BSD-style
license that can be found in the LICENSE file or at
https://developers.google.com/open-source/licenses/bsd
]]

require 'torch'

local M = {}

-- Copies values from src to dst.
local function update(dst, src)
  for k, v in pairs(src) do
    dst[k] = v
  end
end

-- Copies the config. An error is raised on unknown params.
-- A nil src means "no overrides" and leaves dst untouched.
local function updateDefaults(dst, src)
  src = src or {}
  for k in pairs(src) do
    if dst[k] == nil then
      error("unsupported param: " .. tostring(k))
    end
  end
  update(dst, src)
end

-- Loads a dataset table with 'data' and 'labels' tensors from path.
-- The data is converted to the default tensor type, scaled by 1/max and
-- viewed as N x 1 x side x side when the examples are not 3-dimensional.
local function loadDataset(path)
  local dataset = torch.load(path)
  dataset.data = dataset.data:type(torch.Tensor():type())
  collectgarbage()
  -- Skip the scaling for an all-zero dataset; 1/0 would turn it into NaNs.
  local maxValue = dataset.data:max()
  if maxValue ~= 0 then
    dataset.data:mul(1/maxValue)
  end

  if dataset.data[1]:dim() ~= 3 then
    local sideLen = math.sqrt(dataset.data[1]:nElement())
    dataset.data = dataset.data:view(dataset.data:size(1), 1, sideLen, sideLen)
  end

  assert(dataset.labels:min() == 0, "expecting zero-based labels")
  return dataset
end

-- Return a list with pointers to selected examples.
-- The examples are drawn uniformly, with replacement.
local function selectSamples(examples, nSamples)
  local nExamples = examples:size(1)
  local samples = {}
  for i = 1, nSamples do
    samples[i] = examples[torch.random(1, nExamples)]
  end
  return samples
end

-- Puts the sprite on a random position inside of the obs.
-- The observation should have intensities in the [0, 1] range.
-- `border` is the number of pixels on each side that the sprite must not
-- overlap. The obs is modified in place.
local function placeSpriteRandomly(obs, sprite, border)
  assert(obs:dim() == 3, "expecting an image")
  assert(sprite:dim() == 3, "expecting a sprite")
  local h, w = obs:size(2), obs:size(3)
  local spriteH, spriteW = sprite:size(2), sprite:size(3)

  local minPos = 1 + border
  local maxY = h - spriteH + 1 - border
  local maxX = w - spriteW + 1 - border
  assert(maxY >= minPos and maxX >= minPos,
         "the sprite does not fit inside of the observation border")

  local y = torch.random(minPos, maxY)
  local x = torch.random(minPos, maxX)

  local subTensor = obs[{{}, {y, y + spriteH - 1}, {x, x + spriteW - 1}}]
  subTensor:add(sprite)
  -- Keeping the values in the [0, 1] range.
  subTensor:clamp(0, 1)
end

-- Adds config.num_dist distractors to the patch, in place.
-- Each distractor is a random dist_w x dist_w crop of a random example,
-- added at a random position. Values are capped at 1 afterwards.
local function placeDistractors(config, patch, examples)
  local distractors = selectSamples(examples, config.num_dist)
  local dist_w = config.dist_w
  local megapatch_w = config.megapatch_w

  for _, d_patch in ipairs(distractors) do
    assert(dist_w >= 1 and dist_w <= megapatch_w
             and dist_w <= d_patch:size(2) and dist_w <= d_patch:size(3),
           "dist_w must fit inside of the examples and of the megapatch")
    -- Zero-based offsets of the target (t_*) and source (s_*) windows.
    local t_y = torch.random((megapatch_w - dist_w) + 1) - 1
    local t_x = torch.random((megapatch_w - dist_w) + 1) - 1
    local s_y = torch.random((d_patch:size(2) - dist_w) + 1) - 1
    local s_x = torch.random((d_patch:size(3) - dist_w) + 1) - 1
    local targetWindow = patch[{{}, {t_y + 1, t_y + dist_w}, {t_x + 1, t_x + dist_w}}]
    local sourceWindow = d_patch[{{}, {s_y + 1, s_y + dist_w}, {s_x + 1, s_x + dist_w}}]
    targetWindow:add(sourceWindow)
  end
  patch[patch:ge(1)] = 1
end

-- Returns a map from {smallerDigit, biggerOrEqualDigit}
-- to an input in the softmax output.
-- Only k == 2 is supported. Cells with i > j are left as NaN.
local function createIndexMap(n, k)
  assert(k == 2, "expecting k=2")
  local indexMap = torch.Tensor(n, n):fill(0/0)
  local nextIndex = 1
  for i = 1, n do
    for j = i, n do
      indexMap[i][j] = nextIndex
      nextIndex = nextIndex + 1
    end
  end
  assert(nextIndex - 1 == (n * (n + 1))/2, "wrong count for k=2")
  return indexMap
end

-- Each filler resizes `target` and writes the encoding of usedClasses
-- (a list of one-based class indices) into it.
local targetFilling = {}
function targetFilling.mark(target, usedClasses, config)
  -- The used encoding:
  -- target[digit + 1] will be 1 if the zero-based digit is present.
  target:resize(config.nClasses)
    :zero()
  for _, class in ipairs(usedClasses) do
    target[class] = 1
  end
end

function targetFilling.combine(target, usedClasses, config)
  -- We will have one softmax output for each
  -- combination-with-repetition of the two possible digits.
  local nClasses = config.nClasses
  local nOutputs = 1
  for k = 1, config.nDigits do
    nOutputs = nOutputs * (nClasses + k - 1) / k
  end
  target:resize(nOutputs)
    :zero()
  -- The map is built once and cached on the config.
  config.indexMap = config.indexMap or createIndexMap(nClasses, config.nDigits)
  -- The last assigned cell holds the biggest index. Reading it directly
  -- avoids taking max() over a tensor that contains NaN, whose result may
  -- depend on how the installed torch treats NaN.
  assert(config.indexMap[nClasses][nClasses] == nOutputs, "wrong nOutputs")

  -- Sort a copy, so that the list of the caller keeps its order.
  local sortedClasses = {}
  for i, class in ipairs(usedClasses) do
    sortedClasses[i] = class
  end
  table.sort(sortedClasses)
  target[config.indexMap[sortedClasses]] = 1
end

function targetFilling.sum(target, usedClasses, config)
  local maxValue = (config.nClasses - 1) * config.nDigits
  -- The possible sums are {0, 1, ..., maxValue}
  target:resize(1 + maxValue)
    :zero()
  -- usedClasses are one-based, the summed digits are zero-based.
  local value = 0
  for _, class in ipairs(usedClasses) do
    value = value + (class - 1)
  end
  assert(value >= 0 and value <= maxValue, "wrong sum")
  target[1 + value] = 1
end

-- The task is a classification of MNIST digits.
-- Each training example has a MNIST digit placed on a bigger black background.
--
-- extraConfig is an optional table that overrides the defaults below;
-- unknown keys raise an error.
-- Returns a table with a nextExample() function. Each call returns
-- {observation, target}. The same two tensors are reused and overwritten by
-- every call, so clone them if they need to be kept.
function M.createData(extraConfig)
  local config = {
    datasetPath = 'mnist/train.t7',
    -- The size of the background.
    megapatch_w = 28,
    -- Number of distractors.
    num_dist = 0,
    -- The distractor width.
    dist_w = 8,
    -- The width of a black border.
    border = 0,
    -- The number of digits in one image.
    nDigits = 1,
    -- The number of digit classes.
    nClasses = 10,

    -- The digits can be combined into one target for a softmax.
    -- Or the digits can be summed together.
    -- Otherwise the target should be modeled by Bernoulli units.
    targetFilling = "mark",
  }
  updateDefaults(config, extraConfig)

  local dataset = loadDataset(config.datasetPath)

  local nExamples = dataset.data:size(1)
  local perm = torch.Tensor()

  local obs = torch.Tensor(dataset.data[1]:size(1), config.megapatch_w, config.megapatch_w)
  assert(dataset.labels:max() < config.nClasses, "expecting labels from {0, .., nClasses - 1}")

  local target = torch.Tensor()
  local fillTarget = assert(targetFilling[config.targetFilling], "unknown targetFilling")
  -- Starting at nExamples forces a fresh permutation on the first call.
  local step = nExamples
  local function nextExample()
    obs:zero()
    placeDistractors(config, obs, dataset.data)

    local usedClasses = {}
    for i = 1, config.nDigits do
      step = step + 1
      if step > nExamples then
        -- Every example was used; reshuffle and start over.
        torch.randperm(perm, nExamples)
        step = 1
      end

      local sprite = dataset.data[perm[step]]
      placeSpriteRandomly(obs, sprite, config.border)

      local selectedDigit = dataset.labels[perm[step]][1]
      -- The marked class will be from {1, .., nClasses}.
      usedClasses[#usedClasses + 1] = selectedDigit + 1
    end

    fillTarget(target, usedClasses, config)
    return {obs, target}
  end

  return {
    nextExample = nextExample,
  }
end

return M
--------------------------------------------------------------------------------