├── CMakeLists.txt ├── env.lua ├── utils ├── init.lua ├── nn.lua ├── sys.lua └── table.lua ├── rocks └── torchnet-scm-1.rockspec ├── log ├── transfer.lua ├── view │ ├── status.lua │ ├── json.lua │ └── text.lua ├── init.lua └── remotelog.lua ├── CONTRIBUTING.md ├── LICENSE ├── meter ├── mapmeter.lua ├── averagevaluemeter.lua ├── timemeter.lua ├── aucmeter.lua ├── apmeter.lua ├── confusionmeter.lua ├── classerrormeter.lua ├── init.lua ├── multilabelconfusionmeter.lua ├── ndcgmeter.lua ├── recallmeter.lua ├── precisionmeter.lua └── precisionatkmeter.lua ├── PATENTS ├── dataset ├── tabledataset.lua ├── init.lua ├── resampledataset.lua ├── shuffledataset.lua ├── concatdataset.lua ├── transformdataset.lua ├── coroutinebatchdataset.lua ├── splitdataset.lua ├── listdataset.lua ├── datasetiterator.lua ├── batchdataset.lua ├── paralleldatasetiterator.lua └── indexeddataset.lua ├── engine ├── init.lua ├── optimengine.lua └── sgdengine.lua ├── example └── mnist.lua ├── init.lua └── transform.lua /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 2.8) 2 | cmake_policy(VERSION 2.8) 3 | 4 | set(PKGNAME torchnet) 5 | 6 | file(GLOB_RECURSE luafiles RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.lua") 7 | 8 | foreach(file ${luafiles}) 9 | get_filename_component(dir ${file} PATH) 10 | install(FILES ${file} DESTINATION ${LUA_PATH}/${PKGNAME}/${dir}) 11 | endforeach() 12 | -------------------------------------------------------------------------------- /env.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local tnt = {} 11 | 12 | return tnt 13 | -------------------------------------------------------------------------------- /utils/init.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local doc = require 'argcheck.doc' 12 | 13 | doc[[ 14 | 15 | ### tnt.utils 16 | 17 | *Torchnet* provides a set of util functions which are used all over torchnet. 18 | ]] 19 | 20 | local utils = {} 21 | tnt.utils = utils 22 | 23 | utils.table = require 'torchnet.utils.table' 24 | utils.nn = require 'torchnet.utils.nn' 25 | utils.sys = require 'torchnet.utils.sys' 26 | 27 | return utils 28 | -------------------------------------------------------------------------------- /rocks/torchnet-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "torchnet" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/torchnet/torchnet.git" 6 | } 7 | 8 | description = { 9 | summary = "Torch on steroids", 10 | detailed = [[ 11 | Various abstractions for torch7. 
12 | ]], 13 | homepage = "https://github.com/torchnet/torchnet", 14 | license = "BSD" 15 | } 16 | 17 | dependencies = { 18 | "lua >= 5.1", 19 | "torch >= 7.0", 20 | "nn >= 1.0", 21 | "argcheck >= 1.0", 22 | "threads >= 1.0", 23 | "md5 >= 1.0", 24 | "luafilesystem >= 1.0", 25 | "luasocket >= 1.0", 26 | "optim >= 1.0", 27 | "tds >= 1.0", 28 | } 29 | 30 | build = { 31 | type = "cmake", 32 | variables = { 33 | CMAKE_BUILD_TYPE="Release", 34 | LUA_PATH="$(LUADIR)", 35 | LUA_CPATH="$(LIBDIR)" 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /log/transfer.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local transfer = {} 11 | 12 | function transfer.send(c, data) 13 | data = torch.serialize(data) 14 | c:send(string.format("0x%0.16x", #data)) 15 | c:send(data) 16 | end 17 | 18 | function transfer.receive(c) 19 | local sz, err = c:receive(18) 20 | if err then 21 | return 22 | end 23 | sz = tonumber(sz) 24 | local data, err = c:receive(sz) 25 | if err then 26 | return 27 | end 28 | local status, data = pcall(torch.deserialize, data) 29 | if status then 30 | return data 31 | end 32 | end 33 | 34 | return transfer 35 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Torchnet 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. 
If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Make sure your code lints. 12 | 5. If you haven't already, complete the Contributor License Agreement ("CLA"). 13 | 14 | ## Contributor License Agreement ("CLA") 15 | In order to accept your pull request, we need you to submit a CLA. You only need 16 | to do this once to work on any of Facebook's open source projects. 17 | 18 | Complete your CLA here: 19 | 20 | ## Issues 21 | We use GitHub issues to track public bugs. Please ensure your description is 22 | clear and has sufficient instructions to be able to reproduce the issue. 23 | 24 | ## Coding Style 25 | * 3 spaces for indentation rather than tabs 26 | * 80 character line length 27 | * variables names all lower-case, no underlines 28 | 29 | ## License 30 | By contributing to Torchnet, you agree that your contributions will be licensed 31 | under its BSD license. 32 | -------------------------------------------------------------------------------- /log/view/status.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local argcheck = require 'argcheck' 11 | 12 | local status = argcheck{ 13 | noordered=true, 14 | {name="filename", type="string", opt=true}, 15 | {name="append", type="boolean", default=false}, 16 | call = 17 | function(filename, append) 18 | if filename and not append then 19 | local f = io.open(filename, 'w') -- reset the file 20 | assert(f, string.format("could not open file <%s> for writing", filename)) 21 | f:close() 22 | end 23 | return function(data, key, value) 24 | if key == '__status__' then 25 | local status = tostring(value) 26 | if filename then 27 | local f = io.open(filename, 'a+') -- append 28 | assert(f, string.format("could not open file <%s> for writing", filename)) 29 | f:write(status) 30 | f:write("\n") 31 | f:close() 32 | else 33 | print(status) 34 | end 35 | end 36 | end 37 | end 38 | } 39 | 40 | return status 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For Torchnet software 4 | 5 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /meter/mapmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local mAPMeter = torch.class('tnt.mAPMeter', 'tnt.Meter', tnt) 14 | 15 | 16 | mAPMeter.__init = argcheck{ 17 | doc = [[ 18 | 19 | #### tnt.mAPMeter(@ARGP) 20 | @ARGT 21 | 22 | The `tnt.mAPMeter` measures the mean average precision over all classes. 
23 | 24 | The `tnt.mAPMeter` is designed to operate on `NxK` Tensors `output` and `target` 25 | where (1) the `output` contains model output scores for `N` examples and `K` 26 | classes that ought to be higher when the model is more convinced that the 27 | example should be positively labeled, and smaller when the model believes the 28 | example should be negatively labeled (for instance, the output of a sigmoid 29 | function); and (2) the `target` contains only values 0 (for negative examples) 30 | and 1 (for positive examples). 31 | 32 | The `tnt.mAPMeter` has no parameters to be set. 33 | ]], 34 | {name="self", type="tnt.mAPMeter"}, 35 | call = function(self) 36 | self.apmeter = tnt.APMeter() 37 | end 38 | } 39 | 40 | mAPMeter.reset = argcheck{ 41 | {name="self", type="tnt.mAPMeter"}, 42 | call = function(self) 43 | self.apmeter:reset() 44 | end 45 | } 46 | 47 | mAPMeter.add = argcheck{ 48 | {name="self", type="tnt.mAPMeter"}, 49 | {name="output", type="torch.*Tensor"}, 50 | {name="target", type="torch.*Tensor"}, 51 | call = function(self, output, target) 52 | self.apmeter:add{output = output, target = target} 53 | end 54 | } 55 | 56 | mAPMeter.value = argcheck{ 57 | {name="self", type="tnt.mAPMeter"}, 58 | call = function(self) 59 | return self.apmeter:value():mean() 60 | end 61 | } 62 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the Torchnet software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. 
For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /dataset/tabledataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 
4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local utils = require 'torchnet.utils' 13 | 14 | local TableDataset, ListDataset 15 | = torch.class('tnt.TableDataset', 'tnt.ListDataset', tnt) 16 | 17 | TableDataset.__init = argcheck{ 18 | doc = [[ 19 | 20 | #### tnt.TableDataset(@ARGP) 21 | @ARGT 22 | 23 | `tnt.TableDataset` interfaces existing data 24 | to torchnet. It is useful if you want to use torchnet on a small dataset. 25 | 26 | The data must be contained in a `tds.Hash`. 27 | 28 | `tnt.TableDataset` does a shallow copy of the data. 29 | 30 | Data are loaded while constructing the `tnt.TableDataset`: 31 | ```lua 32 | > a = tnt.TableDataset({1,2,3}) 33 | > print(a:size()) 34 | 3 35 | ``` 36 | `tnt.TableDataset` assumes that table has contiguous keys starting at 1. 
37 | ]], 38 | noordered = true, 39 | {name = 'self', type = 'tnt.TableDataset'}, 40 | {name = 'data', type = 'table'}, 41 | call = function(self, data) 42 | for i = 1, #data do 43 | assert(data[i], "keys are not contiguous integers starting at 1") 44 | end 45 | local size = 0 46 | for _, _ in pairs(data) do size = size + 1 end 47 | assert(size == #data, "keys are not contiguous integers starting at 1") 48 | self.data = data 49 | end 50 | } 51 | 52 | TableDataset.size = argcheck{ 53 | {name = 'self', type = 'tnt.TableDataset'}, 54 | call = function(self) 55 | return #self.data 56 | end 57 | } 58 | 59 | TableDataset.get = argcheck{ 60 | {name = 'self', type = 'tnt.TableDataset'}, 61 | {name = 'idx', type = 'number'}, 62 | call = function(self, idx) 63 | assert(idx >= 1 and idx <= self:size(), 'index out of bound') 64 | assert(idx == math.floor(idx), 'index must be an integer') 65 | return utils.table.clone(self.data[idx]) 66 | end 67 | } 68 | -------------------------------------------------------------------------------- /utils/nn.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local argcheck = require 'argcheck' 11 | 12 | local unn = {} 13 | 14 | local function isz(okw, odw, kw, dw, pw) 15 | dw = dw or 1 16 | pw = pw or 0 17 | okw = okw*dw + kw-dw-pw*2 18 | odw = odw*dw 19 | return okw, odw 20 | end 21 | 22 | -- in specs, list of {kw=,dw=,pw=} (dw and pw optionals) 23 | -- kw (kernel width) 24 | -- dw (kernel stride) 25 | -- pw (padding) 26 | unn.inferinputsize = argcheck{ 27 | {name="specs", type="table"}, 28 | {name="size", type="number", default=1}, 29 | {name="verbose", type="boolean", default=false}, 30 | call = 31 | function(specs, size, verbose) 32 | local okw, odw = size, 1 33 | for i=#specs,1,-1 do 34 | if specs[i].kw then 35 | okw, odw = isz(okw, odw, specs[i].kw, specs[i].dw, specs[i].pw) 36 | end 37 | if verbose then 38 | print(string.format( 39 | "|| layer %d: size=%d stride=%d", 40 | i, 41 | okw, 42 | odw)) 43 | end 44 | end 45 | return okw, odw 46 | end 47 | } 48 | 49 | local function iszr(okw, odw, kw, dw, pw) 50 | dw = dw or 1 51 | pw = pw or 0 52 | okw = math.floor((okw+2*pw-kw)/dw)+1 53 | odw = odw * dw 54 | return okw, odw 55 | end 56 | 57 | unn.inferoutputsize = argcheck{ 58 | {name="specs", type="table"}, 59 | {name="size", type="number"}, 60 | {name="verbose", type="boolean", default=false}, 61 | call = 62 | function(specs, size, verbose) 63 | local okw, odw = size, 1 64 | for i=1,#specs do 65 | okw, odw = iszr(okw, odw, specs[i].kw, specs[i].dw, specs[i].pw) 66 | if verbose then 67 | print(string.format("|| layer %d: size=%dx stride=%d", okw, odw)) 68 | end 69 | end 70 | return okw, odw 71 | end 72 | } 73 | 74 | return unn 75 | -------------------------------------------------------------------------------- /dataset/init.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 
4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | doc[[ 15 | 16 | ### tnt.Dataset() 17 | 18 | *torchnet* provides a variety of data containers, which can be easily 19 | plugged between each others, allowing the user to easily concat, split, 20 | batch, resample etc... datasets. 21 | 22 | A instance `dataset` of a `tnt.Dataset()` implements two main methods: 23 | 24 | * `dataset:size()` which returns the size of the dataset. 25 | * `dataset:get(idx)` where `idx` is a number between 1 and the dataset size. 26 | 27 | While it is easy to iterate over a dataset with a for loop, several 28 | `DatasetIterator` iterators are nevertheless provided, allowing the user to 29 | filter out some samples in an on-the-fly manner, or to parallelize easily 30 | data fetching. 31 | 32 | In *torchnet*, a sample returned by `dataset:get()` is supposed to be a Lua 33 | `table`. Fields of the table can be arbitrary, even though many datasets 34 | will only work with torch tensors. 35 | 36 | ]] 37 | 38 | local Dataset = torch.class('tnt.Dataset', tnt) 39 | 40 | Dataset.__init = 41 | function() 42 | end 43 | 44 | -- returns a number 45 | Dataset.size = 46 | function(self) 47 | error(string.format( 48 | 'size not implemented for class <%s>', 49 | torch.type(self))) 50 | end 51 | 52 | -- execute a function 53 | Dataset.exec = 54 | function(self, name, ...) 55 | if type(self[name]) == 'function' then 56 | return self[name](self, ...) 57 | elseif self.dataset then 58 | return self.dataset:exec(name, ...) 59 | elseif self.__dataset then 60 | return self.__dataset:exec(name, ...) 
61 | else 62 | error(string.format('unknown function <%s>', name)) 63 | end 64 | end 65 | 66 | -- returns a table of tensors 67 | Dataset.get = 68 | function(self) 69 | error(string.format( 70 | 'get not implemented for class <%s>', 71 | torch.type(self))) 72 | end 73 | -------------------------------------------------------------------------------- /meter/averagevaluemeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local AverageValueMeter = torch.class('tnt.AverageValueMeter', 'tnt.Meter', tnt) 14 | 15 | AverageValueMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.AverageValueMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.AverageValueMeter` measures the average value of any collection of 22 | numbers that are `add`ed to it. It is useful, for instance, to measure the 23 | average loss over a collection of examples. 24 | 25 | The `add()` function expects as input a Lua number `value`, which is the value 26 | that needs to be added to the list of values to average. It also takes as input 27 | an optional parameter `n` that assigns a weight to `value` in the average, in 28 | order to facilitate computing weighted averages (default = 1). 29 | 30 | The `tnt.AverageValueMeter` has no parameters to be set at initialization time. 
31 | ]], 32 | {name="self", type="tnt.AverageValueMeter"}, 33 | call = 34 | function(self) 35 | self:reset() 36 | end 37 | } 38 | 39 | AverageValueMeter.reset = argcheck{ 40 | {name="self", type="tnt.AverageValueMeter"}, 41 | call = 42 | function(self) 43 | self.sum = 0 44 | self.n = 0 45 | self.var = 0 46 | end 47 | } 48 | 49 | AverageValueMeter.add = argcheck{ 50 | {name="self", type="tnt.AverageValueMeter"}, 51 | {name="value", type="number"}, 52 | {name="n", type="number", default=1}, 53 | call = 54 | function(self, value, n) 55 | self.sum = self.sum + value 56 | self.var = self.var + value * value 57 | self.n = self.n + n 58 | end 59 | } 60 | 61 | AverageValueMeter.value = argcheck{ 62 | {name="self", type="tnt.AverageValueMeter"}, 63 | call = 64 | function(self) 65 | local n = self.n 66 | local mean = self.sum / n 67 | -- unbiased estimator of the variance: 68 | local std = math.sqrt( (self.var - n * mean * mean) / (n-1) ) 69 | return mean, std 70 | end 71 | } 72 | -------------------------------------------------------------------------------- /log/view/json.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local argcheck = require 'argcheck' 11 | 12 | local json = argcheck{ 13 | noordered=true, 14 | {name="filename", type="string", opt=true}, 15 | {name="keys", type="table"}, 16 | {name="format", type="table", opt=true}, 17 | {name="append", type="boolean", default=false}, 18 | call = 19 | function(filename, keys__, format__, append) 20 | local keys = {} 21 | for idx, key in ipairs(keys__) do 22 | local format = format__ and format__[idx] 23 | if not format then 24 | table.insert(keys, {name=key, format=function(value) return string.format("%s %s", key, value) end}) 25 | elseif type(format) == 'function' then 26 | table.insert(keys, {name=key, format=format}) 27 | elseif type(format) == 'string' then 28 | table.insert(keys, {name=key, format=function(value) return string.format(format, value) end}) 29 | else 30 | error('format must be a string or a function') 31 | end 32 | end 33 | if filename and not append then 34 | local f = io.open(filename, 'w') -- reset the file 35 | assert(f, string.format("could not open file <%s> for writing", filename)) 36 | f:close() 37 | end 38 | return function(log) 39 | local txt = {} 40 | for _, key in ipairs(keys) do 41 | local format = key.format(log:get(key.name)) 42 | assert(type(format) == 'string', string.format("value for key %s cannot be converted to string", key)) 43 | table.insert(txt, string.format('"%s": "%s"', key.name, format)) 44 | end 45 | txt = string.format("{%s}", table.concat(txt, ", ")) 46 | if filename then 47 | local f = io.open(filename, 'a+') -- append 48 | assert(f, string.format("could not open file <%s> for writing", filename)) 49 | f:write(txt) 50 | f:write("\n") 51 | f:close() 52 | else 53 | print(txt) 54 | end 55 | end 56 | end 57 | } 58 | 59 | return json 60 | -------------------------------------------------------------------------------- /log/view/text.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, 
Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local argcheck = require 'argcheck' 11 | 12 | local text = argcheck{ 13 | noordered=true, 14 | {name="filename", type="string", opt=true}, 15 | {name="keys", type="table"}, 16 | {name="format", type="table", opt=true}, 17 | {name="separator", type="string", default=" | "}, 18 | {name="append", type="boolean", default=false}, 19 | call = 20 | function(filename, keys__, format__, separator, append) 21 | local keys = {} 22 | for idx, key in ipairs(keys__) do 23 | local format = format__ and format__[idx] 24 | if not format then 25 | table.insert(keys, {name=key, format=function(value) return string.format("%s %s", key, value) end}) 26 | elseif type(format) == 'function' then 27 | table.insert(keys, {name=key, format=format}) 28 | elseif type(format) == 'string' then 29 | table.insert(keys, {name=key, format=function(value) return string.format(format, value) end}) 30 | else 31 | error('format must be a string or a function') 32 | end 33 | end 34 | if filename and not append then 35 | local f = io.open(filename, 'w') -- reset the file 36 | assert(f, string.format("could not open file <%s> for writing", filename)) 37 | f:close() 38 | end 39 | return function(log) 40 | local txt = {} 41 | for _, key in ipairs(keys) do 42 | local format = key.format(log:get(key.name)) 43 | assert(type(format) == 'string', string.format("value for key %s cannot be converted to string", key)) 44 | table.insert(txt, format) 45 | end 46 | txt = table.concat(txt, separator) 47 | if filename then 48 | local f = io.open(filename, 'a+') -- append 49 | assert(f, string.format("could not open file <%s> for writing", filename)) 50 | f:write(txt) 51 | f:write("\n") 52 | f:close() 53 | else 54 | print(txt) 55 | 
end 56 | end 57 | end 58 | } 59 | 60 | return text 61 | -------------------------------------------------------------------------------- /dataset/resampledataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local ResampleDataset = 14 | torch.class('tnt.ResampleDataset', 'tnt.Dataset', tnt) 15 | 16 | ResampleDataset.__init = argcheck{ 17 | doc = [[ 18 | 19 | #### tnt.ResampleDataset(@ARGP) 20 | 21 | Given a `dataset`, creates a new dataset which will (re-)sample from this 22 | underlying dataset using the provided `sampler(dataset, idx)` closure. 23 | 24 | If `size` is provided, then the newly created dataset will have the 25 | specified `size`, which might be different than the underlying dataset 26 | size. 27 | 28 | If `size` is not provided, then the new dataset will have the same size 29 | than the underlying one. 30 | 31 | By default `sampler(dataset, idx)` is the identity. `dataset` corresponds 32 | to the underlying dataset provided at construction, and `idx` may take a 33 | value between 1 to `size`. It must return an index in the range acceptable 34 | for the underlying dataset. 35 | 36 | Purpose: shuffling data, re-weighting samples, getting a subset of the 37 | data. Note that an important sub-class is ([tnt.ShuffleDataset](#ShuffleDataset)), 38 | provided for convenience. 
39 | ]], 40 | {name='self', type='tnt.ResampleDataset'}, 41 | {name='dataset', type='tnt.Dataset'}, 42 | {name='sampler', type='function', default=function(dataset, idx) return idx end}, 43 | {name='size', type='number', opt=true}, 44 | call = 45 | function(self, dataset, sampler, size) 46 | self.__sampler = sampler 47 | self.__dataset = dataset 48 | self.__size = size 49 | end 50 | } 51 | 52 | ResampleDataset.size = argcheck{ 53 | {name='self', type='tnt.ResampleDataset'}, 54 | call = 55 | function(self) 56 | return (self.__size and self.__size > 0) and self.__size or self.__dataset:size() 57 | end 58 | } 59 | 60 | ResampleDataset.get = argcheck{ 61 | {name='self', type='tnt.ResampleDataset'}, 62 | {name='idx', type='number'}, 63 | call = 64 | function(self, idx) 65 | assert(idx >= 1 and idx <= self:size(), 'index out of bound') 66 | idx = self.__sampler(self.__dataset, idx) 67 | assert(idx >= 1 and idx <= self.__dataset:size(), 'index out of bound (sampler)') 68 | return self.__dataset:get(idx) 69 | end 70 | } 71 | -------------------------------------------------------------------------------- /utils/sys.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
]]--

local md5 = require 'md5'
local lfs = require 'lfs'
local tds = require 'tds'

local sys = {}

-- Return the md5 hex digest of any torch-serializable object.
function sys.md5(obj)
   local str = torch.serialize(obj)
   return md5.sumhexa(str)
end

-- Create directory `path`, including missing parents (like `mkdir -p`).
-- The path is quoted (%q) so directories containing spaces or most
-- shell-special characters work; this still shells out, so only trusted
-- paths should be passed here.
function sys.mkdir(path)
   assert(
      os.execute(string.format('mkdir -p %q', path)),
      'could not create directory'
   )
end

-- Compile the string `code` and run it inside the environment `env`.
-- Handles both Lua 5.1 (loadstring + setfenv) and Lua 5.2+ (load with an
-- explicit env). While the chunk runs, `env` temporarily falls back to
-- _G through a metatable, which is removed afterwards.
local function cmdlinecode(code, env)
   assert(type(code) == 'string', 'string expected')
   local msg
   if loadstring then -- lua 5.1
      code, msg = loadstring(code)
      if code then
         setfenv(code, env)
      end
   else
      code, msg = load(code, nil, nil, env) -- lua 5.2
   end
   if not code then
      error(string.format('compilation error: %s', msg))
   end
   assert(not getmetatable(env), 'env should have no metatable')
   setmetatable(env, {__index=_G})
   local status, msg = pcall(code)
   setmetatable(env, nil)
   if not status then
      error(msg)
   end
end

-- Run each string in the array `arg` as Lua code within `env`.
function sys.cmdline(arg, env)
   for _, code in ipairs(arg) do
      cmdlinecode(code, env)
   end
end

-- Load the lines of the file at `path` into a tds.hash mapping
-- index -> line. If `revert` is true, also map line -> index.
-- Stops after `maxload` lines when provided.
function sys.loadlist(path, revert, maxload)
   local lst = tds.hash()
   local idx = 0
   for elem in io.lines(path) do
      idx = idx + 1
      lst[idx] = elem
      if revert then
         lst[elem] = idx
      end
      if maxload and maxload == idx then
         break
      end
   end
   return lst
end

-- Recursively collect jpeg/jpg/png files under `path` into `lst`
-- (a tds.hash used as an array); files with other extensions are
-- reported and skipped.
function sys.listimgfiles(path, lst)
   lst = lst or tds.hash()
   for filename in lfs.dir(path) do
      if filename ~= '.' and filename ~= '..' then
         local fullpath = string.format('%s/%s', path, filename)
         if lfs.attributes(fullpath, 'mode') == 'directory' then
            sys.listimgfiles(fullpath, lst)
         else
            -- NOTE(review): a dot-less filename matches itself here and
            -- is then reported as ignored — confirm this is intended
            local ext = filename:match('[^%.]+$')
            if ext then
               ext = ext:lower()
               if ext == 'jpeg' or ext == 'jpg' or ext == 'png' then
                  lst[#lst+1] = fullpath
               else
                  print(string.format('ignoring <%s>', fullpath))
               end
            end
         end
      end
   end
   return lst
end

return sys
---------------------------------------------- /dataset/shuffledataset.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local ShuffleDataset, ResampleDataset =
   torch.class('tnt.ShuffleDataset', 'tnt.ResampleDataset', tnt)

ShuffleDataset.__init = argcheck{
   doc = [[

#### tnt.ShuffleDataset(@ARGP)
@ARGT

`tnt.ShuffleDataset` is a sub-class of
[tnt.ResampleDataset](#ResampleDataset) provided for convenience.

It samples uniformly from the given `dataset` with, or without
`replacement`. The chosen partition can be redrawn by calling
[resample()](#ShuffleDataset.resample).

If `replacement` is `true`, then the specified `size` may be larger than
the underlying `dataset`.

If `size` is not provided, then the new dataset size will be equal to the
underlying `dataset` size.

Purpose: the easiest way to shuffle a dataset!
]],
   {name='self', type='tnt.ShuffleDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='size', type='number', opt=true},
   {name='replacement', type='boolean', default=false},
   call =
      function(self, dataset, size, replacement)
         -- without replacement we can draw at most dataset:size() samples
         if size and not replacement and size > dataset:size() then
            error('size cannot be larger than underlying dataset size when sampling without replacement')
         end
         self.__replacement = replacement
         -- the sampler just indirects through the current permutation;
         -- the underlying dataset argument is unused here
         local function sampler(_, idx)
            return self.__perm[idx]
         end
         ResampleDataset.__init(self, {
            dataset = dataset,
            sampler = sampler,
            size = size})
         self:resample()
      end
}

ShuffleDataset.resample = argcheck{
   doc = [[

##### tnt.ShuffleDataset.resample(@ARGP)

The permutation associated to `tnt.ShuffleDataset` is fixed, such that two
calls to the same index will return the same sample from the underlying
dataset.

Call `resample()` to draw randomly a new permutation.
]],
   {name='self', type='tnt.ShuffleDataset'},
   call =
      function(self)
         -- with replacement: draw size() independent uniform indices;
         -- without: keep a prefix of a random permutation of the dataset
         self.__perm = self.__replacement
            and torch.LongTensor(self:size()):random(self.__dataset:size())
            or torch.randperm(self.__dataset:size()):long():narrow(1, 1, self:size())
      end
}
---------------------------------------------- /dataset/concatdataset.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local ConcatDataset =
   torch.class('tnt.ConcatDataset', 'tnt.Dataset', tnt)

ConcatDataset.__init = argcheck{
   doc = [[

#### tnt.ConcatDataset(@ARGP)
@ARGT

Given a Lua array (`datasets`) of [tnt.Dataset](#Dataset), concatenates
them into a single dataset. The size of the new dataset is the sum of the
underlying dataset sizes.

Purpose: useful to assemble different existing datasets, possibly
large-scale datasets as the concatenation operation is done in an
on-the-fly manner.
]],
   noordered=true,
   {name='self', type='tnt.ConcatDataset'},
   {name='datasets', type='table'},
   call =
      function(self, datasets)
         assert(#datasets > 0, 'datasets should not be an empty table')
         -- ranges[i] holds the [begin, end] global index range of dataset i
         local ranges = torch.LongTensor(#datasets, 2)
         local total = 0
         for i, dataset in ipairs(datasets) do
            assert(torch.isTypeOf(dataset, 'tnt.Dataset'),
               'each member of datasets table should be a tnt.Dataset')
            ranges[i][1] = total + 1
            total = total + dataset:size()
            ranges[i][2] = total
         end
         self.__datasets = datasets
         self.__indices = ranges
         self.__size = total
      end
}

ConcatDataset.size = argcheck{
   {name='self', type='tnt.ConcatDataset'},
   call =
      function(self)
         return self.__size
      end
}

ConcatDataset.get = argcheck{
   {name='self', type='tnt.ConcatDataset'},
   {name='idx', type='number'},
   call =
      function(self, idx)
         assert(idx >= 1 and idx <= self.__size, 'index out of bound')
         -- binary search over __indices for the dataset whose global
         -- [begin, end] range contains idx
         local indices = self.__indices
         local lo, hi = 1, indices:size(1)
         while lo ~= hi do
            local mid = math.floor((hi-lo)/2) + lo
            if lo == mid then
               -- two candidates left: decide by lo's end bound
               if idx > indices[lo][2] then
                  lo, hi = hi, hi
               else
                  lo, hi = lo, lo
               end
            else
               if idx > indices[mid][2] then
                  lo, hi = mid, hi
               elseif idx < indices[mid][1] then
                  lo, hi = lo, mid
               else
                  lo, hi = mid, mid
               end
            end
         end
         -- translate the global index into dataset lo's local index
         return self.__datasets[lo]:get(idx - self.__indices[lo][1] + 1)
      end
}
---------------------------------------------- /meter/timemeter.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local TimeMeter = torch.class('tnt.TimeMeter', 'tnt.Meter', tnt)

TimeMeter.__init = argcheck{
   doc = [[

#### tnt.TimeMeter(@ARGP)
@ARGT

The `tnt.TimeMeter` is designed to measure the time between events and can be
used to measure, for instance, the average processing time per batch of data.
It is different from most other meters in terms of the methods it provides:

At initialization time, an optional boolean parameter `unit` may be provided
(default = `false`). When set to `true`, the value returned by the meter
will be divided by the number of times that the `incUnit()` method is called.
This allows the user to compute, for instance, the average processing time per
batch by simply calling the `incUnit()` method after processing a batch.

The `tnt.TimeMeter` provides the following methods:

* `reset()` resets the timer, setting the timer and unit counter to zero.
* `stop()` stops the timer.
* `resume()` resumes the timer.
* `incUnit()` increments the unit counter by one.
* `value()` returns the time passed since the last `reset()`; divided by the counter value when `unit=true`.
]],
   {name="self", type="tnt.TimeMeter"},
   {name="unit", type="boolean", default=false},
   call =
      function(self, unit)
         self.unit = unit
         self.timer = torch.Timer()
         self:reset()
      end
}

TimeMeter.reset = argcheck{
   {name="self", type="tnt.TimeMeter"},
   call =
      function(self)
         -- restart the clock and clear the unit counter
         self.timer:reset()
         self.n = 0
      end
}

TimeMeter.stop = argcheck{
   {name="self", type="tnt.TimeMeter"},
   call =
      function(self)
         self.timer:stop()
      end
}

TimeMeter.resume = argcheck{
   {name="self", type="tnt.TimeMeter"},
   call =
      function(self)
         self.timer:resume()
      end
}

TimeMeter.incUnit = argcheck{
   {name="self", type="tnt.TimeMeter"},
   {name="value", type="number", default=1},
   call =
      function(self, value)
         self.n = self.n + value
      end
}

TimeMeter.value = argcheck{
   {name="self", type="tnt.TimeMeter"},
   call =
      function(self)
         -- wall-clock time since reset(), optionally averaged per unit
         local time = self.timer:time().real
         if self.unit then
            return time/self.n
         else
            return time
         end
      end
}
---------------------------------------------- /engine/init.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
local doc = require 'argcheck.doc'

doc[[

### tnt.Engine

In experimenting with different models and datasets, the underlying training
procedure is often the same. The Engine module provides the boilerplate logic
necessary for the training and testing of models. This might include conducting
the interaction between model (nn.Module), `tnt.DatasetIterator`s,
`nn.Criterion`s, and `tnt.Meter`s.

An instance `engine` of a `tnt.Engine()` implements two main methods:

* `engine:train()`, for training the model on data
  (i.e. sample data, forward prop, backward prop).
* `engine:test()`, for evaluating a model on data
  (optionally with respect to a `nn.Criterion`).

The Engine can be implemented for any common underlying training and testing
procedure involving a model and data. It can also be designed to allow user
control after certain events such as forward prop, criterion evaluation, or the
end of an epoch, by using coroutines (see `tnt.SGDEngine`).

]]

local Engine = torch.class('tnt.Engine', tnt)

Engine.__init = argcheck{
   nonamed=true, -- to avoid ambiguities
   {name="self", type="tnt.Engine"},
   {name="hooks", type="table"},
   call =
      function(self, hooks)
         -- register every declared hook name with a no-op closure
         self.hooks = {}
         for _, name in ipairs(hooks) do
            assert(type(name) == 'string', 'hooks must be a table of hook names (strings)')
            self.hooks[name] = function() end
         end
         -- guard the hook table: reading or writing an undeclared hook
         -- name raises; calling self.hooks(name, ...) dispatches to the
         -- registered hook
         setmetatable(
            self.hooks,
            {
               __index =
                  function(hooks, name)
                     assert(type(name) == 'string', 'hook name must be a string')
                     error(string.format('unknown hook <%s>', name))
                  end,
               __newindex =
                  function(self, name)
                     assert(type(name) == 'string', 'hook name must be a string')
                     error(string.format('unknown hook <%s>', name))
                  end,
               __call =
                  function(hooks, name, ...)
                     return hooks[name](...)
                  end
            }
         )
      end
}

Engine.train = argcheck{
   {name="self", type="tnt.Engine"},
   call =
      function(self)
         error('A tnt.Engine should implement the train() function.')
      end
}

Engine.test = argcheck{
   {name="self", type="tnt.Engine"},
   call =
      function(self)
         error('A tnt.Engine should implement the test() function.')
      end
}
---------------------------------------------- /example/mnist.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

-- load torchnet:
local tnt = require 'torchnet'

-- use GPU or not:
local cmd = torch.CmdLine()
cmd:option('-usegpu', false, 'use gpu for training')
local config = cmd:parse(arg)
print(string.format('running on %s', config.usegpu and 'GPU' or 'CPU'))

-- function that sets up the dataset iterator:
local function getIterator(mode)
   return tnt.ParallelDatasetIterator{
      nthread = 1,
      init = function() require 'torchnet' end,
      closure = function()

         -- load MNIST dataset and flatten each 28x28 image to a vector:
         local mnist = require 'mnist'
         local dataset = mnist[mode .. 'dataset']()
         dataset.data = dataset.data:reshape(dataset.data:size(1),
            dataset.data:size(2) * dataset.data:size(3)):double()

         -- return batches of data:
         return tnt.BatchDataset{
            batchsize = 128,
            dataset = tnt.ListDataset{  -- replace this by your own dataset
               list = torch.range(1, dataset.data:size(1)):long(),
               load = function(idx)
                  return {
                     input  = dataset.data[idx],
                     target = torch.LongTensor{dataset.label[idx] + 1},
                  }  -- sample contains input and target
               end,
            }
         }
      end,
   }
end

-- set up logistic regressor:
local net = nn.Sequential():add(nn.Linear(784,10))
local criterion = nn.CrossEntropyCriterion()

-- set up training engine:
local engine = tnt.SGDEngine()
local meter  = tnt.AverageValueMeter()
local clerr  = tnt.ClassErrorMeter{topk = {1}}
engine.hooks.onStartEpoch = function(state)
   meter:reset()
   clerr:reset()
end
engine.hooks.onForwardCriterion = function(state)
   meter:add(state.criterion.output)
   clerr:add(state.network.output, state.sample.target)
   if state.training then
      print(string.format('avg. loss: %2.4f; avg. error: %2.4f',
         meter:value(), clerr:value{k = 1}))
   end
end

-- set up GPU training:
if config.usegpu then

   -- copy model to GPU:
   require 'cunn'
   net       = net:cuda()
   criterion = criterion:cuda()

   -- copy sample to GPU buffer (reused across batches):
   local igpu, tgpu = torch.CudaTensor(), torch.CudaTensor()
   engine.hooks.onSample = function(state)
      igpu:resize(state.sample.input:size() ):copy(state.sample.input)
      tgpu:resize(state.sample.target:size()):copy(state.sample.target)
      state.sample.input  = igpu
      state.sample.target = tgpu
   end  -- alternatively, this logic can be implemented via a TransformDataset
end

-- train the model:
engine:train{
   network   = net,
   iterator  = getIterator('train'),
   criterion = criterion,
   lr        = 0.2,
   maxepoch  = 5,
}

-- measure test loss and error:
meter:reset()
clerr:reset()
engine:test{
   network   = net,
   iterator  = getIterator('test'),
   criterion = criterion,
}
print(string.format('test loss: %2.4f; test error: %2.4f',
   meter:value(), clerr:value{k = 1}))
---------------------------------------------- /utils/table.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local argcheck = require 'argcheck'
local doc = require 'argcheck.doc'

local utable = {}

doc[[

#### tnt.utils.table.clone(table)

This function does a deep copy of a table.

]]

-- deep copy via torch (de)serialization round-trip
function utable.clone(tbl)
   return torch.deserializeFromStorage(torch.serializeToStorage(tbl))
end

-- shallow copy: new table, same values (and value references)
function utable.copy(tbl)
   local cpy = {}
   for k,v in pairs(tbl) do
      cpy[k] = v
   end
   return cpy
end

utable.merge = argcheck{
   doc = [[

#### tnt.utils.table.merge(@ARGP)
@ARGT

This function adds to the destination table `dst` the
elements contained in the source table `src`.

The copy is shallow.

If a key exists in both tables, then the element in the source table
is preferred.
]],
   {name = "dst", type = 'table'},
   {name = "src", type = 'table'},
   call = function (dst, src)
      for k,v in pairs(src) do
         dst[k] = v
      end
      return dst
   end
}

utable.foreach = argcheck{
   doc = [[

#### tnt.utils.table.foreach(@ARGP)
@ARGT

This function applies the function defined by `closure` to the
table `tbl`.

If `recursive` is given and set to `true`, the `closure` function
will be applied recursively to the table.
]],
   {name = "tbl", type = 'table'},
   {name = "closure", type = 'function'},
   {name = "recursive", type = 'boolean', default = false},
   call = function(tbl, closure, recursive)
      -- returns a new table; the input table is left untouched
      local newtbl = {}
      for k,v in pairs(tbl) do
         if recursive and type(v) == 'table' then
            newtbl[k] = utable.foreach(v, closure, recursive)
         else
            newtbl[k] = closure(v)
         end
      end
      return newtbl
   end
}

doc[[

#### tnt.utils.table.canmergetensor(tbl)

Check if a table can be merged into a tensor.
]]

function utable.canmergetensor(tbl)
   if type(tbl) ~= 'table' then
      return false
   end

   -- mergeable iff the first element is a tensor and all elements have
   -- the same number of elements
   local typename = torch.typename(tbl[1])
   if typename and typename:match('Tensor') then
      local sz = tbl[1]:nElement()
      for i=2,#tbl do
         -- can't merge tensors of different sizes
         if tbl[i]:nElement() ~= sz then
            return false
         end
      end
      return true
   end
   return false
end

utable.mergetensor = argcheck{
   doc = [[

#### tnt.utils.table.mergetensor(@ARGP)
@ARGT

Merge a table into a tensor in one extra dimension.
]],
   {name = 'tbl', type = 'table'},
   call = function(tbl)
      -- stack the #tbl tensors along a new leading dimension
      local sz = tbl[1]:size():totable()
      table.insert(sz, 1, #tbl)
      sz = torch.LongStorage(sz)
      local res = tbl[1].new():resize(sz)
      for i=1,#tbl do
         res:select(1, i):copy(tbl[i])
      end
      return res
   end
}

return utable
---------------------------------------------- /dataset/transformdataset.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
local utils = require 'torchnet.utils'

local TransformDataset =
   torch.class('tnt.TransformDataset', 'tnt.Dataset', tnt)

TransformDataset.__init = argcheck{
   doc = [[

#### tnt.TransformDataset(@ARGP)
@ARGT

Given a closure `transform()`, and a `dataset`, `tnt.TransformDataset`
applies the closure in an on-the-fly manner when querying a sample with
`tnt.Dataset:get()`.

If key is provided, the closure is applied to the sample field specified
by `key` (only). The closure must return the new corresponding field value.

If key is not provided, the closure is applied on the full sample. The
closure must return the new sample table.

The size of the new dataset is equal to the size of the underlying `dataset`.

Purpose: when performing pre-processing operations, it is convenient to be
able to perform on-the-fly transformations to a
dataset.
]],
   {name='self', type='tnt.TransformDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='transform', type='function'},
   {name='key', type='string', opt=true},
   call =
      function(self, dataset, transform, key)
         self.dataset = dataset
         if key then
            -- transform only the requested field; the field must exist
            function self.transform(sample, idx)
               assert(sample[key], 'missing key in sample')
               sample[key] = transform(sample[key], idx)
               return sample
            end
         else
            -- transform the whole sample table
            function self.transform(sample, idx)
               return transform(sample, idx)
            end
         end
      end
}

TransformDataset.__init = argcheck{
   doc = [[

#### tnt.TransformDataset(@ARGP)
@ARGT

Given a set of closures and a `dataset`, `tnt.TransformDataset` applies
these closures in an on-the-fly manner when querying a sample with
`tnt.Dataset:get()`.

Closures are provided in `transforms`, a Lua table, where a (key,value)
pair represents a (sample field name, corresponding closure to be applied
to the field name).

Each closure must return the new value of the corresponding field.
]],
   {name='self', type='tnt.TransformDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='transforms', type='table'},
   overload = TransformDataset.__init,
   call =
      function(self, dataset, transforms)
         for k,v in pairs(transforms) do
            assert(type(v) == 'function',
               'key/function table expected for transforms')
         end
         self.dataset = dataset
         -- keep a private shallow copy so later caller mutations of the
         -- transforms table do not affect this dataset
         transforms = utils.table.copy(transforms)
         function self.transform(sample)
            for key,transform in pairs(transforms) do
               assert(sample[key], 'missing key in sample')
               sample[key] = transform(sample[key])
            end
            return sample
         end
      end
}

TransformDataset.size = argcheck{
   {name='self', type='tnt.TransformDataset'},
   call =
      function(self)
         return self.dataset:size()
      end
}

TransformDataset.get = argcheck{
   {name='self', type='tnt.TransformDataset'},
   {name='idx', type='number'},
   call =
      function(self, idx)
         return self.transform(
            self.dataset:get(idx), idx
         )
      end
}
---------------------------------------------- /dataset/coroutinebatchdataset.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local CoroutineBatchDataset, BatchDataset =
   torch.class('tnt.CoroutineBatchDataset', 'tnt.BatchDataset', tnt)

CoroutineBatchDataset.__init = argcheck{
   doc = [[

#### tnt.CoroutineBatchDataset(@ARGP)
@ARGT

Given a `dataset`, `tnt.CoroutineBatchDataset` merges samples from this dataset
to form a new sample which can be interpreted as a batch (of size `batchsize`).

It behaves the same and has the same arguments as `tnt.BatchDataset` (see the
documentation there for additional details), with one important distinction:
it allows the underlying dataset to postpone returning the individual samples
once by doing a call to `coroutine.yield()` (from the underlying dataset).

This is useful when using datasets that are inefficient or slow when they need
to provide the required sample immediately after a call to `dataset:get()`. The
general pattern of code in the underlying `dataset:get()` would be:

```lua
FooDataset.get = function(self, idx)
   prepare(idx) -- stores sample in self.__data[idx]
   coroutine.yield()
   return self.__data[idx]
end
```

Herein, the function `prepare(idx)` can implement, for instance, a buffering of
indices before actually fetching them.
]],
   {name = 'self', type = 'tnt.CoroutineBatchDataset'},
   {name = 'dataset', type = 'tnt.Dataset'},
   {name = 'batchsize', type = 'number'},
   {name = 'perm', type = 'function', default = function(idx, size) return idx end},
   {name = 'merge', type = 'function', opt = true},
   {name = 'policy', type = 'string', default = 'include-last'},
   call = function(self, dataset, batchsize, perm, merge, policy)
      -- all construction work is delegated to BatchDataset
      BatchDataset.__init(self, dataset, batchsize, perm, merge, policy)
   end
}

CoroutineBatchDataset.get = argcheck{
   {name = 'self', type = 'tnt.CoroutineBatchDataset'},
   {name = 'idx', type = 'number'},
   call = function(self, idx)
      assert(idx >= 1 and idx <= self:size(), 'index out of bound')
      assert(idx == math.floor(idx), 'index should be integer value')

      -- first pass: start one coroutine per batch element, giving each
      -- underlying get() the chance to yield once after preparing its
      -- sample (srcidx is the index into the underlying dataset; it used
      -- to shadow the batch index `idx`, which was confusing):
      local crs, samples, maxidx = {}, {}, self.dataset:size()
      for n = 1,self.batchsize do
         local srcidx = (idx - 1) * self.batchsize + n
         if srcidx > maxidx then break end

         -- start coroutine:
         crs[n] = coroutine.create(
            function() return self.dataset:get(self.perm(srcidx)) end
         ) -- create coroutine that gets example
         local status, sample = coroutine.resume(crs[n]) -- register sample
         if not status then
            error(string.format('dataset threw error: %s', sample))
         end

         -- if coroutine does not yield but dies, store sample:
         if coroutine.status(crs[n]) == 'dead' then samples[n] = sample end
      end

      -- second pass: resume the coroutines that yielded and collect their
      -- samples; each must finish on this second resume:
      for n = 1,self.batchsize do
         if crs[n] and coroutine.status(crs[n]) == 'suspended' then
            local status, sample = coroutine.resume(crs[n])
            if not status then
               error(string.format('dataset threw error: %s', sample))
            end
            assert(coroutine.status(crs[n]) == 'dead', 'coroutine did not die')
            samples[n] = sample
         end
      end

      -- return batch:
      samples = self.makebatch(samples)
      collectgarbage()
      return samples
   end
}
---------------------------------------------- /meter/aucmeter.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local AUCMeter = torch.class('tnt.AUCMeter', 'tnt.Meter', tnt)

AUCMeter.__init = argcheck{
   doc = [[

#### tnt.AUCMeter(@ARGP)
@ARGT

The `tnt.AUCMeter` measures the area under the receiver-operating characteristic
(ROC) curve for binary classification problems. The area under the curve (AUC)
can be interpreted as the probability that, given a randomly selected positive
example and a randomly selected negative example, the positive example is
assigned a higher score by the classification model than the negative example.

The `tnt.AUCMeter` is designed to operate on one-dimensional Tensors `output`
and `target`, where (1) the `output` contains model output scores that ought to
be higher when the model is more convinced that the example should be positively
labeled, and smaller when the model believes the example should be negatively
labeled (for instance, the output of a sigmoid function); and (2) the `target`
contains only values 0 (for negative examples) and 1 (for positive examples).

The `tnt.AUCMeter` has no parameters to be set.
]],
   {name="self", type="tnt.AUCMeter"},
   call =
      function(self)
         self:reset()
      end
}

AUCMeter.reset = argcheck{
   {name="self", type="tnt.AUCMeter"},
   call =
      function(self)
         -- empty score/target buffers; add() appends to these
         self.scores = torch.DoubleTensor()
         self.targets = torch.LongTensor()
      end
}

AUCMeter.add = argcheck{
   {name="self", type="tnt.AUCMeter"},
   {name="output", type="torch.*Tensor"},
   {name="target", type="torch.*Tensor"},
   call =
      function(self, output, target)
         target = target:squeeze()
         output = output:squeeze()
         assert(
            output:nDimension() == 1,
            'dimensionality of output should be 1 (e.g., nn.Sigmoid output)'
         )
         assert(
            target:nDimension() == 1,
            'dimensionality of targets should be 1'
         )
         assert(
            output:size(1) == target:size(1),
            'number of outputs and targets does not match'
         )
         assert(
            torch.eq(torch.eq(target, 0):add(torch.eq(target, 1)), 1):all(),
            'targets should be binary (0 or 1)'
         )

         -- append the new scores and targets to the accumulated storage:
         local offset1 = self.scores:nElement()
         local offset2 = self.targets:nElement()
         self.scores:resize( offset1 + output:nElement())
         self.targets:resize(offset2 + target:nElement())
         self.scores:narrow( 1, offset1 + 1, output:nElement()):copy(output)
         self.targets:narrow(1, offset2 + 1, target:nElement()):copy(target)
      end
}

AUCMeter.value = argcheck{
   {name="self", type="tnt.AUCMeter"},
   call =
      function(self)

         -- sort the scores (descending, so the curve starts at (0,0)):
         local scores, sortind = torch.sort(self.scores, 1, true)

         -- construct the ROC curve: walking down the ranking, a positive
         -- example raises TPR by one count, a negative raises FPR:
         local tpr = torch.DoubleTensor(scores:nElement() + 1):zero()
         local fpr = torch.DoubleTensor(scores:nElement() + 1):zero()
         for n = 2,scores:nElement() + 1 do
            if self.targets[sortind[n - 1]] == 1 then
               tpr[n], fpr[n] = tpr[n - 1] + 1, fpr[n - 1]
            else
               tpr[n], fpr[n] = tpr[n - 1], fpr[n - 1] + 1
            end
         end
         tpr:div(self.targets:sum())
         fpr:div(torch.mul(self.targets, -1):add(1):sum())

         -- compute AUC by integrating TPR over the FPR increments
         -- (rectangle rule). The previous `torch.div(tpr, n):sum()`
         -- averaged TPR over all ranking steps, which is only the AUC if
         -- FPR advanced at every step, i.e. only when all examples are
         -- negative:
         local auc = torch.cmul(
            tpr:narrow(1, 1, tpr:nElement() - 1),
            fpr:narrow(1, 2, fpr:nElement() - 1)
               - fpr:narrow(1, 1, fpr:nElement() - 1)):sum()

         -- return AUC and ROC curve:
         return auc, tpr, fpr
      end
}
---------------------------------------------- /dataset/splitdataset.lua: ----------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local SplitDataset = torch.class('tnt.SplitDataset', 'tnt.Dataset', tnt)

SplitDataset.__init = argcheck{
   doc = [[

#### tnt.SplitDataset(@ARGP)
@ARGT

Partition a given `dataset`, according to the specified `partitions`. Use
the method [select()](#SplitDataset.select) to select the current partition
in use.

The Lua hash table `partitions` is of the form (key, value) where key is a
user-chosen string naming the partition, and value is a number representing
the weight (in size) of the corresponding partition.

The sum of the partition weights may or may not sum to one
(`tnt.SplitDataset` will make them sum to one!).

Partitioning is achieved linearly (no shuffling). See
[tnt.ShuffleDataset](#ShuffleDataset) if you want to shuffle the dataset
before partitioning.

Purpose: useful in machine learning to perform validation procedures.
]],
   {name='self', type='tnt.SplitDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='partitions', type='table'},
   call =
      function(self, dataset, partitions)

         -- create partition size tensor and table with partition names:
         self.__dataset = dataset
         local n = 0; for _,_ in pairs(partitions) do n = n + 1 end
         self.__partitionsizes = torch.DoubleTensor(n)
         self.__names = {}
         n = 0
         for key, val in pairs(partitions) do
            n = n + 1
            self.__partitionsizes[n] = val
            self.__names[key] = n
         end

         -- assertions:
         assert(
            self.__partitionsizes:nElement() >= 2,
            'SplitDataset should have at least two partitions'
         )
         assert(
            self.__partitionsizes:min() >= 0,
            'some partition sizes are negative'
         )
         assert(
            self.__partitionsizes:max() > 0,
            'all partitions are empty'
         )

         -- if partition sizes are fractions, convert to sizes.
         -- NOTE(review): this only triggers when the weights already sum to
         -- ~1; weights with any other sum are used as absolute sizes, which
         -- does not match the doc's promise to normalize — confirm intent.
         if math.abs(self.__partitionsizes:sum() - 1) < 1e-5 then
            -- bug fix: the result of the `:long()` conversion was previously
            -- discarded (only `mul`/`floor` act in place), leaving the sizes
            -- as a DoubleTensor; assign the converted tensor explicitly.
            self.__partitionsizes = self.__partitionsizes:double():mul(
               self.__dataset:size() / self.__partitionsizes:sum()
            ):floor():long()
         end

         -- select first partition by default:
         self.__partition = 1
      end
}

SplitDataset.select = argcheck{
   doc = [[

##### tnt.SplitDataset.select(@ARGP)
@ARGT

Switch the current partition in use to the one specified by `partition`,
which must be a string corresponding to one of the names provided at
construction.

The current dataset size changes accordingly, as well as the samples returned
by the `get()` method.
]],
   {name='self', type='tnt.SplitDataset'},
   {name='partition', type='string'},
   call =
      function(self, partition)
         -- bug fix: validate before mutating, so a failed select() does not
         -- leave the dataset with a nil current partition:
         local idx = self.__names[partition]
         if not idx then error('partition not found') end
         self.__partition = idx
      end
}

-- Size of the currently selected partition.
SplitDataset.size = argcheck{
   {name='self', type='tnt.SplitDataset'},
   call =
      function(self)
         return self.__partitionsizes[self.__partition]
      end
}

-- Fetch sample `idx` of the current partition; partitions are laid out
-- linearly in the underlying dataset, so the offset is the sum of the sizes
-- of all preceding partitions.
SplitDataset.get = argcheck{
   {name='self', type='tnt.SplitDataset'},
   {name='idx', type='number'},
   call =
      function(self, idx)
         assert(idx >= 1 and idx <= self:size(), 'index out of bounds')
         local offset = (self.__partition == 1) and 0 or
            self.__partitionsizes:narrow(1, 1, self.__partition - 1):sum()
         return self.__dataset:get(offset + idx)
      end
}

-- file: dataset/listdataset.lua --------------------------------------------

--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local tds = require 'tds'
local argcheck = require 'argcheck'
local transform = require 'torchnet.transform'

local ListDataset, Dataset = torch.class('tnt.ListDataset', 'tnt.Dataset', tnt)

ListDataset.__init = argcheck{
   doc = [[

#### tnt.ListDataset(@ARGP)
@ARGT

Considering a `list` (can be a `tds.Hash`, `table` or a `torch.LongTensor`) the
i-th sample of a dataset will be returned by `load(list[i])`, where `load()` is
a closure provided by the user.

If `path` is provided, list is assumed to be a list of string, and will
each element `list[i]` will prefixed by `path/` when fed to `load()`.

Purpose: many low or medium-scale datasets can be seen as a list of files
(for example representing input samples). For this list of file, a target
can be often inferred in a simple manner.

]],
   {name='self', type='tnt.ListDataset'},
   {name='list', type='tds.Hash'},
   {name='load', type='function'},
   {name='path', type='string', opt=true},
   call =
      function(self, list, load, path)
         Dataset.__init(self)
         self.list = list
         self.load = load
         self.path = path
      end
}

-- Overload: identical construction from a plain Lua table.
ListDataset.__init = argcheck{
   {name='self', type='tnt.ListDataset'},
   {name='list', type='table'},
   {name='load', type='function'},
   {name='path', type='string', opt=true},
   overload = ListDataset.__init,
   call =
      function(self, list, load, path)
         Dataset.__init(self)
         self.list = list
         self.load = load
         self.path = path
      end
}

-- Overload: identical construction from a torch.LongTensor of indices.
ListDataset.__init = argcheck{
   {name='self', type='tnt.ListDataset'},
   {name='list', type='torch.LongTensor'},
   {name='load', type='function'},
   {name='path', type='string', opt=true},
   overload = ListDataset.__init,
   call =
      function(self, list, load, path)
         Dataset.__init(self)
         self.list = list
         self.load = load
         self.path = path
      end
}

ListDataset.__init = argcheck{
   doc = [[
#### tnt.ListDataset(@ARGP)
@ARGT

The file specified by `filename` is interpreted as a list of strings (one
string per line). The i-th sample of a dataset will be returned by
`load(line[i])`, where `load()` is a closure provided by the user an
`line[i]` is the i-the line of `filename`.

If `path` is provided, list is assumed to be a list of string, and will
each element `list[i]` will prefixed by `path/` when fed to `load()`.

]],
   {name='self', type='tnt.ListDataset'},
   {name='filename', type='string'},
   {name='load', type='function'},
   {name='maxload', type='number', opt=true},
   {name='path', type='string', opt=true},
   overload = ListDataset.__init,
   call =
      function(self, filename, load, maxload, path)
         -- read the file line by line (up to `maxload` lines when given),
         -- then delegate to the list-based constructor:
         local list = tds.hash()
         for line in io.lines(filename) do
            list[#list+1] = line
            if maxload and maxload > 0 and #list == maxload then
               break
            end
         end
         ListDataset.__init(self, list, load, path)
         print(string.format("| loaded <%s> with %d examples", filename, #list))
      end
}

-- Number of samples: tensor lists report their first dimension, table-like
-- lists their length.
ListDataset.size = argcheck{
   {name='self', type='tnt.ListDataset'},
   call =
      function(self)
         if torch.isTensor(self.list) then
            return self.list:size(1)
         end
         return #self.list
      end
}

-- Load sample `idx`, prefixing the list entry with `path/` when a path
-- was given at construction.
ListDataset.get = argcheck{
   {name='self', type='tnt.ListDataset'},
   {name='idx', type='number'},
   call =
      function(self, idx)
         assert(idx >= 1 and idx <= self:size(), 'out of bound')
         if self.path then
            return self.load(string.format("%s/%s", self.path, self.list[idx]))
         end
         return self.load(self.list[idx])
      end
}

-- file: meter/apmeter.lua --------------------------------------------------

--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local APMeter = torch.class('tnt.APMeter', 'tnt.Meter', tnt)


APMeter.__init = argcheck{
   doc = [[

#### tnt.APMeter(@ARGP)
@ARGT

The `tnt.APMeter` measures the average precision per class.

The `tnt.APMeter` is designed to operate on `NxK` Tensors `output` and `target`,
where (1) the `output` contains model output scores for `N` examples and `K`
classes that ought to be higher when the model is more convinced that the
example should be positively labeled, and smaller when the model believes the
example should be negatively labeled (for instance, the output of a sigmoid
function); and (2) the `target` contains only values 0 (for negative examples)
and 1 (for positive examples).

The `tnt.APMeter` has no parameters to be set.
]],
   {name="self", type="tnt.APMeter"},
   call = function(self)
      self:reset()
   end
}

-- Drop all accumulated examples.
APMeter.reset = argcheck{
   {name="self", type="tnt.APMeter"},
   call = function(self)
      self.n = 0
      self.scores = torch.DoubleTensor()
      self.targets = torch.LongTensor()
   end
}

-- Append a batch of NxK scores and matching binary NxK targets; 1D inputs
-- are promoted to single-column matrices.
APMeter.add = argcheck{
   {name="self", type="tnt.APMeter"},
   {name="output", type="torch.*Tensor"},
   {name="target", type="torch.*Tensor"},
   call = function(self, output, target)

      -- normalize shapes and validate the input:
      target = target:squeeze()
      output = output:squeeze()
      if output:nDimension() == 1 then
         output = output:view(output:size(1), 1)
      else
         assert(output:nDimension() == 2,
            'wrong output size (should be 1D or 2D with one column per class)'
         )
      end
      if target:nDimension() == 1 then
         target = target:view(target:size(1), 1)
      else
         assert(target:nDimension() == 2,
            'wrong target size (should be 1D or 2D with one column per class)'
         )
      end
      assert(output:size(1) == target:size(1) and
         output:size(2) == target:size(2),
         'dimensions for output and target does not match'
      )
      assert(torch.eq(torch.eq(target, 0):add(torch.eq(target, 1)), 1):all(),
         'targets should be binary (0 or 1)'
      )
      -- the class count must agree with everything added before:
      if self.scores:nDimension() > 0 then
         assert(output:size(2) == self.scores:size(2),
            'dimensions for output should match previously added examples.'
         )
      end
      if self.targets:nDimension() > 0 then
         assert(target:size(2) == self.targets:size(2),
            'dimensions for output should match previously added examples.'
         )
      end

      -- grow the row-major buffers and append the new rows:
      local nnew = output:size(1)
      self.scores:resize( self.n + nnew, output:size(2))
      self.targets:resize(self.n + nnew, target:size(2))
      self.scores:narrow( 1, self.n + 1, nnew):copy(output)
      self.targets:narrow(1, self.n + 1, nnew):copy(target)
      self.n = self.n + nnew
   end
}

-- Return a K-tensor with the average precision of each class.
APMeter.value = argcheck{
   {name="self", type="tnt.APMeter"},
   call = function(self)

      -- compute average precision for each class:
      local nexamples = self.scores:size(1)
      local nclasses  = self.scores:size(2)
      local ap = torch.DoubleTensor(nclasses):fill(0)
      local rank = torch.range(1, nexamples, 'torch.DoubleTensor')
      for class = 1,nclasses do

         -- rank examples of this class by decreasing score:
         local classscores  = self.scores:select( 2, class)
         local classtargets = self.targets:select(2, class)
         local _, order = torch.sort(classscores, 1, true)
         local sortedtruth = classtargets:index(1, order)

         -- true-positive counts at every cut-off:
         local truepos = sortedtruth:double():cumsum()

         -- precision at every cut-off:
         local prec = truepos:cdiv(rank)

         -- average the precision over the positive positions:
         ap[class] = prec[sortedtruth:byte()]:sum() /
            math.max(sortedtruth:sum(), 1)
      end
      return ap
   end
}

-- file: init.lua
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

require 'torch'

local tnt = require 'torchnet.env'
local doc = require 'argcheck.doc'

doc[[

# torchnet

*torchnet* is a framework for [torch](http://torch.ch) which provides a set
of abstractions aiming at encouraging code re-use as well as encouraging
modular programming.

At the moment, *torchnet* provides four set of important classes:
  - Dataset: handling and pre-processing data in various ways.
  - Engine: training/testing machine learning algorithm.
  - Meter: meter performance or any other quantity.
  - Log: output performance or any other string to file / disk in a consistent manner.

For an overview of the *torchnet* framework, please also refer to
[this paper](https://lvdmaaten.github.io/publications/papers/Torchnet_2016.pdf).

## Installation

Please install *torch* first, following instructions on
[torch.ch](http://torch.ch/docs/getting-started.html). If *torch* is
already installed, make sure you have an up-to-date version of
[*argcheck*](https://github.com/torch/argcheck), otherwise you will get
weird errors at runtime.

Assuming *torch* is already installed, the *torchnet* core is only a set of
lua files, so it is straightforward to install it with *luarocks*
```
luarocks install torchnet
```

## Documentation

Requiring *torchnet* returns a local variable containing all *torchnet*
class constructors.
```
local tnt = require 'torchnet'
```

]]

require 'torchnet.dataset'
require 'torchnet.dataset.listdataset'
require 'torchnet.dataset.tabledataset'
require 'torchnet.dataset.indexeddataset'
require 'torchnet.dataset.transformdataset'
require 'torchnet.dataset.batchdataset'
require 'torchnet.dataset.coroutinebatchdataset'
require 'torchnet.dataset.concatdataset'
require 'torchnet.dataset.resampledataset'
require 'torchnet.dataset.shuffledataset'
require 'torchnet.dataset.splitdataset'
require 'torchnet.dataset.datasetiterator'
require 'torchnet.dataset.paralleldatasetiterator'

require 'torchnet.engine'
require 'torchnet.engine.sgdengine'
require 'torchnet.engine.optimengine'

require 'torchnet.meter'
require 'torchnet.meter.apmeter'
require 'torchnet.meter.averagevaluemeter'
require 'torchnet.meter.aucmeter'
require 'torchnet.meter.confusionmeter'
require 'torchnet.meter.mapmeter'
require 'torchnet.meter.multilabelconfusionmeter'
require 'torchnet.meter.classerrormeter'
require 'torchnet.meter.timemeter'
require 'torchnet.meter.precisionatkmeter'
require 'torchnet.meter.recallmeter'
require 'torchnet.meter.precisionmeter'
require 'torchnet.meter.ndcgmeter'

require 'torchnet.log'
require 'torchnet.log.remotelog'

require 'torchnet.utils'
require 'torchnet.transform'

-- attach a metatable whose read/write hooks serialize the table as a
-- `require` of the named package instead of its contents:
local function _makepackageserializable(packagetbl, packagename)
   local packagemt = torch.class('package.' .. packagename)
   function packagemt:__write() end
   function packagemt:__read() end
   function packagemt:__factory() return require(packagename) end
   setmetatable(packagetbl, packagemt)
end

-- this can be removed when @locronan implements a real torch.isclass():
function torch.isclass(obj)
   local registry = debug.getregistry()
   return registry[obj] and true or false
end

-- make torchnet serializable:
local argcheck = require 'argcheck'
tnt.makepackageserializable = argcheck{
   {name = 'packagetbl', type = 'table'},
   {name = 'packagename', type = 'string'},
   call = function(packagetbl, packagename)
      assert(not torch.isclass(getmetatable(packagetbl))
         and not torch.isclass(packagetbl), 'input cant be a class (instance)')
      _makepackageserializable(packagetbl, packagename)
      -- recurse into plain sub-tables (skipping classes and instances),
      -- registering each under a dotted package path:
      for subname, subtbl in pairs(packagetbl) do
         if type(subtbl) == 'table' and not torch.isclass(getmetatable(subtbl))
            and not torch.isclass(subtbl) then
            tnt.makepackageserializable(subtbl, packagename .. '.' .. subname)
         end
      end
   end
}
tnt.makepackageserializable(tnt, 'torchnet')

return tnt

-- file: dataset/datasetiterator.lua ----------------------------------------

--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
local doc = require 'argcheck.doc'

doc[[
#### Dataset Iterators

It is easy to iterate over datasets using a for loop.
However, sometimes
one wants to filter out samples in a on-the-fly manner or thread sample fetching.

Iterators are here for this particular cases. In general, refrain writing
iterators for handling custom cases, and write instead a `tnt.Dataset`

Iterators implement two methods:

  * `run()` which returns a Lua iterator usable in a for loop.
  * `exec(funcname, ...)` which execute a given funcname on the underlying dataset.

Typical usage is achieved with a for loop:
```lua
for sample in iterator:run() do

end
```

Iterators implement the __call event, so one might also use the `()` operator:
```lua
for sample in iterator() do

end
```

]]

local DatasetIterator = torch.class('tnt.DatasetIterator', tnt)

-- iterate over a dataset
DatasetIterator.__init = argcheck{
   doc = [[

##### tnt.DatasetIterator(@ARGP)
@ARGT

The default dataset iterator.

`filter(sample)` is a closure which returns `true` if the given sample
should be considered or `false` if not.

`transform(sample)` is a closure which can perform online transformation of
samples. It returns a modified version of the given `sample`. It is the
identity by default. It is often more interesting to use
[tnt.TransformDataset](#TransformDataset) for that purpose.
]],
   {name='self', type='tnt.DatasetIterator'},
   {name='dataset', type='tnt.Dataset'},
   {name='filter', type='function', default=function(sample) return true end},
   {name='transform', type='function', default=function(sample) return sample end},
   call =
      function(self, dataset, filter, transform)
         self.dataset = dataset
         -- run() is a per-instance closure (not a method): each call builds a
         -- fresh Lua iterator that walks the dataset in index order, applies
         -- `transform`, and skips samples rejected by `filter`.
         function self.run()
            local size = dataset:size()
            local idx = 1
            return
               function()
                  -- advance until a sample passes the filter; returning nil
                  -- (when idx exceeds size) terminates the for loop.
                  while idx <= size do
                     local sample = transform(dataset:get(idx))
                     idx = idx + 1
                     if filter(sample) then
                        return sample
                     end
                  end
               end
         end
      end
}

-- iterates from another iterator
DatasetIterator.__init = argcheck{
   {name='self', type='tnt.DatasetIterator'},
   {name='iterator', type='tnt.DatasetIterator'},
   {name='filter', type='function', default=function(sample) return true end},
   {name='transform', type='function', default=function(sample) return sample end},
   overload = DatasetIterator.__init,
   call =
      function(self, iterator, filter, transform)
         self.iterator = iterator
         -- wrap the inner iterator: transform then filter each sample; the
         -- repeat/until loop keeps pulling until a sample passes the filter
         -- or the inner loop is exhausted (sample == nil).
         function self.run()
            local loop = iterator:run()
            return
               function()
                  repeat
                     local sample = loop()
                     if sample then
                        sample = transform(sample)
                        if filter(sample) then
                           return sample
                        end
                     end
                  until not sample
               end
         end
      end
}

-- __call event makes `iterator()` equivalent to `iterator:run()`.
DatasetIterator.__call__ =
   function(self, ...)
      return self:run(...)
   end

doc[[

##### tnt.DatasetIterator.exec(tnt.DatasetIterator, name, ...)

Execute the given method `name` on the underlying dataset, passing it the
subsequent arguments, and returns what the `name` method returns.
]]

-- Dispatch order: a method on the iterator itself wins; otherwise delegate
-- to the wrapped dataset (default-constructed) or the wrapped iterator
-- (iterator-constructed), recursing down the chain.
DatasetIterator.exec =
   function(self, name, ...)
      if type(self[name]) == 'function' then
         return self[name](self, ...)
      elseif self.dataset then
         return self.dataset:exec(name, ...)
      elseif self.iterator then
         return self.iterator:exec(name, ...)
      else
         error(string.format('unknown function <%s>', name))
      end
   end

-- file: meter/confusionmeter.lua -------------------------------------------

--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local ConfusionMeter = torch.class('tnt.ConfusionMeter', 'tnt.Meter', tnt)

ConfusionMeter.__init = argcheck{
   doc = [[

#### tnt.ConfusionMeter(@ARGP)
@ARGT

The `tnt.ConfusionMeter` constructs a confusion matrix for a multi-class
classification problems. It does not support multi-label, multi-class problems:
for such problems, please use `tnt.MultiLabelConfusionMeter`.

At initialization time, the `k` parameter that indicates the number
of classes in the classification problem under consideration must be specified.
Additionally, an optional parameter `normalized` (default = `false`) may be
specified that determines whether or not the confusion matrix is normalized
(that is, it contains percentages) or not (that is, it contains counts).

The `add(output, target)` method takes as input an NxK tensor `output` that
contains the output scores obtained from the model for N examples and K classes,
and a corresponding N-tensor or NxK-tensor `target` that provides the targets
for the N examples.
When `target` is an N-tensor, the targets are assumed to be
integer values between 1 and K. When target is an NxK-tensor, the targets are
assumed to be provided as one-hot vectors (that is, vectors that contain only
zeros and a single one at the location of the target value to be encoded).

The `value()` method has no parameters and returns the confusion matrix in a
KxK tensor. In the confusion matrix, rows correspond to ground-truth targets and
columns correspond to predicted targets.
]],
   noordered = true,
   {name="self", type="tnt.ConfusionMeter"},
   {name="k", type="number"},
   {name="normalized", type="boolean", default=false},
   call =
      function(self, k, normalized)
         -- raw counts live in a KxK LongTensor; normalization (if requested)
         -- is applied lazily in value():
         self.conf = torch.LongTensor(k, k)
         self.normalized = normalized
         self:reset()
      end
}

-- Zero all confusion counts.
ConfusionMeter.reset = argcheck{
   {name="self", type="tnt.ConfusionMeter"},
   call =
      function(self)
         self.conf:zero()
      end
}

ConfusionMeter.add = argcheck{
   {name="self", type="tnt.ConfusionMeter"},
   {name="output", type="torch.*Tensor"},
   {name="target", type="torch.*Tensor"},
   call =
      function(self, output, target)
         target = target:squeeze()
         output = output:squeeze()
         if output:nDimension() == 1 then
            -- single example: promote output to a 1xK row; squeeze() turns a
            -- one-element target tensor into a plain Lua number, so wrap it
            -- back into a tensor here.
            output = output:view(1, output:size(1))
            if type(target) == 'number' then
               target = torch.Tensor(1):fill(target)
            end
         end
         -- 2D target means one-hot encoding; 1D means class indices:
         local onehot = not (target:nDimension() == 1)
         assert(
            target:size(1) == output:size(1),
            'number of targets and outputs do not match'
         )
         assert(
            output:size(2) == self.conf:size(1),
            'number of outputs does not match size of confusion matrix'
         )
         assert(
            not onehot or target:size(2) == output:size(2),
            'target should be 1D Tensor or have size of output (one-hot)'
         )
         if onehot then
            assert(
               torch.eq(torch.eq(target, 0):add(torch.eq(target, 1)), 1):all(),
               'in one-hot encoding, target values should be 0 or 1'
            )
            -- exactly one 1 per row, otherwise it is a multi-label problem:
            assert(
               torch.eq(target:sum(2), 1):all(),
               'multi-label setting is not supported'
            )
         end

         -- update confusion matrix: predicted class is the argmax over the
         -- output scores; ground-truth row comes from the target encoding.
         local pos
         local _,pred = output:double():max(2)
         for n = 1,pred:size(1) do
            if onehot then _,pos = target[n]:max(1); pos = pos[1]
            else pos = target[n] end
            self.conf[pos][pred[n][1]] = self.conf[pos][pred[n][1]] + 1
         end
      end
}

ConfusionMeter.value = argcheck{
   {name="self", type="tnt.ConfusionMeter"},
   call =
      function(self)
         local confmat
         if self.normalized then
            -- row-normalize a double copy so each row sums to 1; raw counts
            -- in self.conf are left untouched.
            confmat = torch.DoubleTensor(self.conf:size()):copy(self.conf)
            confmat:cdiv(confmat:sum(2):expandAs(confmat))
         else
            confmat = self.conf
         end
         return confmat
      end
}

-- file: meter/classerrormeter.lua ------------------------------------------

--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local ClassErrorMeter = torch.class('tnt.ClassErrorMeter', 'tnt.Meter', tnt)

ClassErrorMeter.__init = argcheck{
   doc = [[

#### tnt.ClassErrorMeter(@ARGP)
@ARGT

The `tnt.ClassErrorMeter` measures the classification error (in %) of
classification models (zero-one loss).
The meter can also measure the error of
predicting the correct label among the top-k scoring labels (for instance, in
the Imagenet competition, one generally measures classification@5 errors).

At initialization time, it takes to optional parameters: (1) a table
`topk` that contains the values at which the classification@k errors should be
measures (default = {1}); and (2) a boolean `accuracy` that makes the meter
output accuracies instead of errors (accuracy = 1 - error).

The `add(output, target)` method takes as input an NxK-tensor `output` that
contains the output scores for each of the N examples and each of the K classes,
and an N-tensor `target` that contains the targets corresponding to each of the
N examples (targets are integers between 1 and K). If only one example is
`add`ed, `output` may also be a K-tensor and target a 1-tensor.

Please note that `topk` (if specified) may not contain values larger than K.

The `value()` returns a table with the classification@k errors for all values
at k that were specified in `topk` at initialization time. Alternatively,
`value(k)` returns the classification@k error as a number; only values of `k`
that were element of `topk` are allowed. If `accuracy` was set to `true` at
initialization time, the `value()` method returns accuracies instead of errors.
]],
   noordered = true,
   {name="self", type="tnt.ClassErrorMeter"},
   {name="topk", type="table", default={1}},
   {name="accuracy", type="boolean", default=false},
   call =
      function(self, topk, accuracy)
         -- keep the requested k values sorted ascending:
         self.topk = torch.LongTensor(topk):sort():totable()
         self.accuracy = accuracy
         self:reset()
      end
}

-- Reset the per-k error counts and the example counter.
ClassErrorMeter.reset = argcheck{
   {name="self", type="tnt.ClassErrorMeter"},
   call =
      function(self)
         self.sum = {}
         for _,k in ipairs(self.topk) do
            self.sum[k] = 0
         end
         self.n = 0
      end
}

-- Accumulate top-k error counts for a batch of scores and integer targets.
ClassErrorMeter.add = argcheck{
   {name="self", type="tnt.ClassErrorMeter"},
   {name="output", type="torch.*Tensor"},
   {name="target", type="torch.*Tensor"},
   call =
      function(self, output, target)
         target = target:squeeze()
         output = output:squeeze()
         if output:nDimension() == 1 then
            -- single-example case: promote the K-vector to a 1xK matrix and
            -- the (squeezed-to-number) target back to a 1-tensor:
            output = output:view(1, output:size(1))
            assert(
               type(target) == 'number',
               'target and output do not match')
            target = torch.Tensor(1):fill(target)
         else
            assert(
               output:nDimension() == 2,
               'wrong output size (1D or 2D expected)')
            assert(
               target:nDimension() == 1,
               'target and output do not match')
         end
         assert(
            target:size(1) == output:size(1),
            'target and output do not match')

         local topk = self.topk
         local maxk = topk[#topk]
         local nsample = output:size(1)

         -- one topk computation at the largest k serves every smaller k:
         local _, predictions = output:double():topk(maxk, 2, true, true)
         local iscorrect = predictions:typeAs(target):eq(
            target:view(nsample, 1):expandAs(predictions))

         for _,k in ipairs(topk) do
            self.sum[k] = self.sum[k] + nsample -
               iscorrect:narrow(2, 1, k):sum()
         end
         self.n = self.n + nsample
      end
}

-- Return the error (or accuracy) at a given k, or a table for all k values
-- configured at construction when k is omitted.
ClassErrorMeter.value = argcheck{
   {name="self", type="tnt.ClassErrorMeter"},
   {name="k", type="number", opt=true},
   call =
      function(self, k)
         if not k then
            local value = {}
            for _,kk in ipairs(self.topk) do
               value[kk] = self:value(kk)
            end
            return value
         end
         assert(self.sum[k],
            'invalid k (this k was not provided at construction time)')
         if self.accuracy then
            return (1-self.sum[k] / self.n)*100
         end
         return self.sum[k]*100 / self.n
      end
}

-- file: meter/init.lua -----------------------------------------------------

--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
local doc = require 'argcheck.doc'

local Meter = torch.class('tnt.Meter', tnt)

doc[[
#### Meters

When training a model, you generally would like to measure how the model is
performing. Specifically, you may want to measure the average processing time
required per batch of data, the classification error or AUC of a classifier a
validation set, or the precision@k of a retrieval model.

Meters provide a standardized way to measure a range of different measures,
which makes it easy to measure a wide range of properties of your models.

Nearly all meters (except `tnt.TimeMeter`) implement three methods:

  * `add()` which adds an observation to the meter.
  * `value()` which returns the value of the meter, taking into account all observations.
  * `reset()` which removes all previously added observations, resetting the meter.

The exact input arguments to the `add()` method vary depending on the meter.
Most meters define the method as `add(output, target)`, where `output` is the
output produced by the model and `target` is the ground-truth label of the data.

The `value()` method is parameterless for most meters, but for measures that
have a parameter (such as the k parameter in precision@k), they may take an
input argument.

An example of a typical usage of a meter is as follows:
```lua
local meter = tnt.Meter() -- initialize meter
for state, event in tnt.Engine:train{
   network = network,
   criterion = criterion,
   iterator = iterator,
} do
   if state == 'start-epoch' then
      meter:reset() -- reset meter
   elseif state == 'forward-criterion' then
      meter:add(state.network.output, sample.target) -- add value to meter
   elseif state == 'end-epoch' then
      print('value of meter:' .. meter:value()) -- get value of meter
   end
end
```
]]

-- Abstract base class: concrete meters must override reset/value/add.
Meter.__init = argcheck{
   {name="self", type="tnt.Meter"},
   call =
      function(self)
      end
}

Meter.reset = argcheck{
   {name="self", type="tnt.Meter"},
   call =
      function(self)
         error('A tnt.Meter should implement the reset() function.')
      end
}

Meter.value = argcheck{
   {name="self", type="tnt.Meter"},
   call =
      function(self)
         error('A tnt.Meter should implement the value() function.')
      end
}

Meter.add = argcheck{
   {name="self", type="tnt.Meter"},
   call =
      function(self)
         error('A tnt.Meter should implement the add() function.')
      end
}

-- Internal helper shared by subclasses: maintain a rolling buffer with the
-- top-k scoring rows/columns (along `dim`) across successive batches, kept
-- in self.__topkoutput / self.__topktarget with the k retained entries
-- stored in positions 1..topk along `dim`.
Meter.__updatetopk = argcheck{
   {name='self', type='tnt.Meter'},
   {name='output', type='torch.*Tensor'},
   {name='target', type='torch.*Tensor'}, -- target is k-hot vector
   {name='topk', type='number'}, -- number of values to maintain
   {name='dim', type='number', default=1}, -- top-k selection dimension
   {name='desc', type='boolean', default=true}, -- maintain largest values
   call = function(self, output, target, topk, dim, desc)
      assert(dim == 1 or dim == 2)

      -- make sure top-k buffer has the right size: the buffer is the new
      -- batch plus `topk` extra slots along `dim` that hold the entries
      -- carried over from previous calls.
      local firstinput = not (self.__topkoutput and self.__topktarget)
      self.__topkoutput = self.__topkoutput or output.new()
      self.__topktarget = self.__topktarget or target.new()
      self.__topkoutput:resize(output:size(1) + ((dim == 1) and topk or 0),
                               output:size(2) + ((dim == 2) and topk or 0))
      self.__topktarget:resize(target:size(1) + ((dim == 1) and topk or 0),
                               target:size(2) + ((dim == 2) and topk or 0))
      if firstinput then
         -- seed the carried-over slots with the worst possible value so they
         -- never win the first selection:
         self.__topkoutput:fill(desc and -math.huge or math.huge)
      end

      -- copy new inputs into buffer (after the topk carried-over slots):
      local otherdim = (dim == 1) and 2 or 1
      assert(output:size(otherdim) == self.__topkoutput:size(otherdim),
         string.format('incorrect size of dimension %d of output', otherdim))
      assert(target:size(otherdim) == self.__topktarget:size(otherdim),
         string.format('incorrect size of dimension %d of target', otherdim))
      self.__topkoutput:narrow(dim, topk + 1, output:size(dim)):copy(output)
      self.__topktarget:narrow(dim, topk + 1, target:size(dim)):copy(target)

      -- update top-k scores: select the k best over old + new, and move the
      -- winners (and their gathered targets) back into slots 1..topk.
      local topoutput, topind = torch.topk(self.__topkoutput, topk, dim, desc)
      self.__topkoutput:narrow(dim, 1, topk):copy(topoutput)
      self.__topktarget:narrow(dim, 1, topk):copy(
         self.__topktarget:gather(dim, topind)
      )
   end
}

-- file: dataset/batchdataset.lua -------------------------------------------

--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
local transform = require 'torchnet.transform'

local BatchDataset =
   torch.class('tnt.BatchDataset', 'tnt.Dataset', tnt)

BatchDataset.__init = argcheck{
   doc = [[

#### tnt.BatchDataset(@ARGP)
@ARGT

Given a `dataset`, `tnt.BatchDataset` merges samples from this dataset to
form a new sample which can be interpreted as a batch (of size
`batchsize`).

The `merge` function controls how the batch is performed. It is a closure
taking a Lua array as input containing all occurrences (for a given batch)
of a field of the sample, and returning the aggregated version of these
occurrences. By default the occurrences are supposed to be tensors, and
they are aggregated along the first dimension.

More formally, if the i-th sample of the underlying dataset is written as:
```lua
{input=<input_i>, target=<target_i>}
```
assuming only two fields `input` and `target` in the sample, then `merge()`
will be passed tables of the form:
```lua
{<input_i1>, <input_i2>, ... <input_in>}
```
or
```lua
{<target_i1>, <target_i2>, ... <target_in>}
```
with `n` being the batch size.

It is often important to shuffle examples while performing the batch
operation. `perm(idx, size)` is a closure which returns the shuffled index
of the sample at position `idx` in the underlying dataset. For convenience,
the `size` of the underlying dataset is also passed to the closure. By
default, the closure is the identity.

The underlying dataset size might or might not be always divisible by
`batchsize`. The optional `policy` string specifies how to handle corner
cases:
  - `include-last` makes sure all samples of the underlying dataset will be seen, batches will be of size equal or inferior to `batchsize`.
  - `skip-last` will skip last examples of the underlying dataset if its size is not properly divisible. Batches will be always of size equal to `batchsize`.
  - `divisible-only` will raise an error if the underlying dataset has not a size divisible by `batchsize`.

Purpose: the concept of batch is problem dependent. In *torchnet*, it is up
to the user to interpret a sample as a batch or not. When one wants to
assemble samples from an existing dataset into a batch, then
`tnt.BatchDataset` is suited for the job. Sometimes it is however more
convenient to write a dataset from scratch providing "batched" samples.
]],
   {name='self', type='tnt.BatchDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='batchsize', type='number'},
   {name='perm', type='function', default=function(idx, size) return idx end},
   {name='merge', type='function', opt=true},
   {name='policy', type='string', default='include-last'},
   call =
      function(self, dataset, batchsize, perm, merge, policy)
         assert(batchsize > 0 and math.floor(batchsize) == batchsize,
            'batchsize should be a positive integer number')
         self.dataset = dataset
         self.perm = perm
         self.batchsize = batchsize
         self.makebatch = transform.makebatch{merge=merge}
         self.policy = policy
         self:size() -- check policy (raises early on an invalid policy string)
      end
}

-- Number of batches, according to the rounding policy chosen at construction.
BatchDataset.size = argcheck{
   {name='self', type='tnt.BatchDataset'},
   call =
      function(self)
         local policy = self.policy
         if policy == 'include-last' then
            return math.ceil(self.dataset:size()/self.batchsize)
         elseif policy == 'skip-last' then
            return math.floor(self.dataset:size()/self.batchsize)
         elseif policy == 'divisible-only' then
            assert(self.dataset:size() % self.batchsize == 0,
               'dataset size is not divisible by batch size')
            return self.dataset:size()/self.batchsize
         else
            error('invalid policy (include-last | skip-last | divisible-only expected)')
         end
      end
}

-- Assemble batch `idx` by fetching (possibly permuted) samples from the
-- underlying dataset and merging them with `makebatch`.
BatchDataset.get = argcheck{
   {name='self', type='tnt.BatchDataset'},
   {name='idx', type='number'},
   call =
      function(self, idx)
         assert(idx >= 1 and idx <= self:size(), 'index out of bound')
         local samples = {}
         local maxidx = self.dataset:size()
         for i=1,self.batchsize do
            -- position of the i-th element of batch `idx` in the underlying
            -- dataset (renamed from `idx` to avoid shadowing the batch index)
            local sampleidx = (idx - 1)*self.batchsize + i
            if sampleidx > maxidx then
               break -- include-last policy: the last batch may be smaller
            end
            sampleidx = self.perm(sampleidx, maxidx)
            table.insert(samples, self.dataset:get(sampleidx))
         end
         samples = self.makebatch(samples)
         collectgarbage()
         collectgarbage()
         return samples
      end
}
-------------------------------------------------------------------------------- /engine/optimengine.lua: --------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
local doc = require 'argcheck.doc'

doc[[

### tnt.OptimEngine

The OptimEngine module wraps the optimization functions from
https://github.com/torch/optim. At the start of training, the engine will call
`getParameters` on the provided network.

The `train` method requires the following parameters in addition to the
SGDEngine.train parameters:

* `optimMethod` the optimization function (e.g optim.sgd)
* `config` a table with configuration parameters for the optimizer

Example:
```lua
local engine = tnt.OptimEngine()
engine:train{
   network = model,
   criterion = criterion,
   iterator = iterator,
   optimMethod = optim.sgd,
   config = {
      learningRate = 0.1,
      momentum = 0.9,
   },
}
```
]]

require 'nn'

local OptimEngine, SGDEngine = torch.class('tnt.OptimEngine', 'tnt.SGDEngine', tnt)

-- tnt.OptimEngine inherits the event machinery from tnt.SGDEngine; __init
-- just defers to the parent constructor, which registers the hook names
-- ("onStart", "onSample", "onForward", ...).
OptimEngine.__init = argcheck{
   {name="self", type="tnt.OptimEngine"},
   call =
      function(self)
         SGDEngine.__init(self)
      end
}

-- Train `network` against `criterion` over `iterator`, delegating parameter
-- updates to a torch/optim method (optim.sgd, optim.adam, ...).
-- `config` and `optimState` are passed through to the optimizer untouched;
-- `paramFun` may be supplied to provide (params, gradParams) when the default
-- network:getParameters() flattening is not appropriate.
-- All mutable training state is exposed to hooks via the `state` table.
OptimEngine.train = argcheck{
   {name="self", type="tnt.OptimEngine"},
   {name="network", type="nn.Module"},
   {name="criterion", type="nn.Criterion"},
   {name="iterator", type="tnt.DatasetIterator"},
   {name="maxepoch", type="number", default=1000},
   {name="optimMethod", type="function"},
   {name="config", type="table", opt=true},
   {name="optimState", type="table", opt=true},
   {name="paramFun", type="function", opt=true},
   call =
      function(self, network, criterion, iterator, maxepoch, optimMethod,
         config, optimState, paramFun)
         local state = {
            network = network,
            criterion = criterion,
            iterator = iterator,
            maxepoch = maxepoch,
            optimMethod = optimMethod,
            sample = {},
            config = config or {},
            optim = optimState or {},
            epoch = 0, -- epoch done so far
            t = 0, -- samples seen so far
            training = true
         }

         if paramFun then
            state.params, state.gradParams = paramFun()
         else
            -- flatten all network parameters/gradients into two tensors
            state.params, state.gradParams = state.network:getParameters()
         end

         -- closure handed to the optim method; by the time it is called the
         -- forward/backward passes below have already run, so it only reports
         -- the current loss and the already-computed gradients
         local function feval()
            return state.criterion.output, state.gradParams
         end

         self.hooks("onStart", state)
         while state.epoch < state.maxepoch do
            state.network:training()

            self.hooks("onStartEpoch", state)
            for sample in state.iterator() do
               state.sample = sample
               self.hooks("onSample", state)

               state.network:forward(sample.input)
               self.hooks("onForward", state)
               state.criterion:forward(state.network.output, sample.target)
               self.hooks("onForwardCriterion", state)

               state.network:zeroGradParameters()
               if state.criterion.zeroGradParameters then
                  state.criterion:zeroGradParameters()
               end

               state.criterion:backward(state.network.output, sample.target)
               self.hooks("onBackwardCriterion", state)
               state.network:backward(sample.input, state.criterion.gradInput)
               self.hooks("onBackward", state)

               -- delegate the actual parameter update to the optim package
               state.optimMethod(feval, state.params, state.config, state.optim)
               state.t = state.t + 1
               self.hooks("onUpdate", state)
            end
            state.epoch = state.epoch + 1
            self.hooks("onEndEpoch", state)
         end
         self.hooks("onEnd", state)
      end
}

-- Evaluate `network` over `iterator`: forward passes only, no gradient
-- computation or parameter updates. NOTE(review): this is a verbatim copy of
-- tnt.SGDEngine.test; consider delegating to the parent to avoid drift.
OptimEngine.test = argcheck{
   {name="self", type="tnt.OptimEngine"},
   {name="network", type="nn.Module"},
   {name="iterator", type="tnt.DatasetIterator"},
   {name="criterion", type="nn.Criterion", opt=true},
   call = function(self, network, iterator, criterion)
      local state = {
         network = network,
         iterator = iterator,
         criterion = criterion,
         sample = {},
         t = 0, -- samples seen so far
         training = false
      }

      self.hooks("onStart", state)
      state.network:evaluate()
      for sample in state.iterator() do
         state.sample = sample
         self.hooks("onSample", state)
         state.network:forward(sample.input)
         state.t = state.t + 1
         self.hooks("onForward", state)

         -- loss computation is optional at test time
         if state.criterion then
            state.criterion:forward(state.network.output, sample.target)
            self.hooks("onForwardCriterion", state)
         end

      end
      self.hooks("onEnd", state)
   end
}
-------------------------------------------------------------------------------- /engine/sgdengine.lua: --------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
local doc = require 'argcheck.doc'

doc[[

### tnt.SGDEngine

The SGDEngine module implements the Stochastic Gradient Descent training
procedure in `train`, including data sampling, forward prop, back prop, and
parameter updates. It also operates as a coroutine allowing a user control
(i.e. increment some sort of `tnt.Meter`) at events such as 'start',
'start-epoch', 'forward', 'forward-criterion', 'backward', etc.

Accordingly, `train` requires a network (nn.Module), a criterion expressing the
loss function (nn.Criterion), a dataset iterator (tnt.DatasetIterator), and a
learning rate, at the minimum. The `test` function allows for simple evaluation
of a model on a dataset.

A `state` is maintained for external access to outputs and parameters of modules
as well as sampled data.
]]

require 'nn'

local SGDEngine, Engine = torch.class('tnt.SGDEngine', 'tnt.Engine', tnt)

-- Register the full set of events this engine emits so users can attach
-- hooks (meters, loggers, checkpointers, ...) to any of them.
SGDEngine.__init = argcheck{
   {name="self", type="tnt.SGDEngine"},
   call =
      function(self)
         Engine.__init(self, {
            "onStart", "onStartEpoch", "onSample",
            "onForward", "onForwardCriterion",
            "onBackward", "onBackwardCriterion",
            "onEndEpoch", "onUpdate", "onEnd"
         })
      end
}

-- Plain SGD training loop: sample -> forward -> criterion -> backward ->
-- updateParameters(lr). `lrcriterion` defaults to `lr` (argcheck defaulta)
-- and is applied to criteria that themselves carry trainable parameters.
SGDEngine.train = argcheck{
   {name="self", type="tnt.SGDEngine"},
   {name="network", type="nn.Module"},
   {name="criterion", type="nn.Criterion"},
   {name="iterator", type="tnt.DatasetIterator"},
   {name="lr", type="number"},
   {name="lrcriterion", type="number", defaulta="lr"},
   {name="maxepoch", type="number", default=1000},
   call =
      function(self, network, criterion, iterator, lr, lrcriterion, maxepoch)
         local state = {
            network = network,
            criterion = criterion,
            iterator = iterator,
            lr = lr,
            lrcriterion = lrcriterion,
            maxepoch = maxepoch,
            sample = {},
            epoch = 0, -- epoch done so far
            t = 0, -- samples seen so far
            training = true
         }

         self.hooks("onStart", state)
         while state.epoch < state.maxepoch do
            state.network:training()

            self.hooks("onStartEpoch", state)
            for sample in state.iterator() do
               state.sample = sample
               self.hooks("onSample", state)

               state.network:forward(sample.input)
               self.hooks("onForward", state)
               state.criterion:forward(state.network.output, sample.target)
               self.hooks("onForwardCriterion", state)

               state.network:zeroGradParameters()
               if state.criterion.zeroGradParameters then
                  state.criterion:zeroGradParameters()
               end

               state.criterion:backward(state.network.output, sample.target)
               self.hooks("onBackwardCriterion", state)
               state.network:backward(sample.input, state.criterion.gradInput)
               self.hooks("onBackward", state)

               -- asserts are inside the loop so a hook mutating state.lr /
               -- state.lrcriterion to a negative value is caught immediately
               assert(state.lrcriterion >= 0, 'lrcriterion should be positive or zero')
               if state.lrcriterion > 0 and state.criterion.updateParameters then
                  state.criterion:updateParameters(state.lrcriterion)
               end
               assert(state.lr >= 0, 'lr should be positive or zero')
               if state.lr > 0 then
                  state.network:updateParameters(state.lr)
               end
               state.t = state.t + 1
               self.hooks("onUpdate", state)
            end
            state.epoch = state.epoch + 1
            self.hooks("onEndEpoch", state)
         end
         self.hooks("onEnd", state)
      end
}

-- Evaluate `network` over `iterator`: forward passes only (evaluate mode),
-- optionally computing the criterion loss when one is provided.
SGDEngine.test = argcheck{
   {name="self", type="tnt.SGDEngine"},
   {name="network", type="nn.Module"},
   {name="iterator", type="tnt.DatasetIterator"},
   {name="criterion", type="nn.Criterion", opt=true},
   call = function(self, network, iterator, criterion)
      local state = {
         network = network,
         iterator = iterator,
         criterion = criterion,
         sample = {},
         t = 0, -- samples seen so far
         training = false
      }

      self.hooks("onStart", state)
      state.network:evaluate()
      for sample in state.iterator() do
         state.sample = sample
         self.hooks("onSample", state)
         state.network:forward(sample.input)
         state.t = state.t + 1
         self.hooks("onForward", state)

         -- loss computation is optional at test time
         if state.criterion then
            state.criterion:forward(state.network.output, sample.target)
            self.hooks("onForwardCriterion", state)
         end

      end
      self.hooks("onEnd", state)
   end
}
-------------------------------------------------------------------------------- /meter/multilabelconfusionmeter.lua: --------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local tds = require 'tds'
local argcheck = require 'argcheck'

local MultiLabelConfusionMeter =
   torch.class('tnt.MultiLabelConfusionMeter', 'tnt.Meter', tnt)

MultiLabelConfusionMeter.__init = argcheck{
   doc = [[

#### tnt.MultiLabelConfusionMeter(@ARGP)
@ARGT

The `tnt.MultiLabelConfusionMeter` constructs a confusion matrix for multi-
label, multi-class classification problems. In constructing the confusion
matrix, the number of positive predictions is assumed to be equal to the number
of positive labels in the ground-truth. Correct predictions (that is, labels in
the prediction set that are also in the ground-truth set) are added to the
diagonal of the confusion matrix. Incorrect predictions (that is, labels in the
prediction set that are not in the ground-truth set) are equally divided over
all non-predicted labels in the ground-truth set.

At initialization time, the `k` parameter that indicates the number
of classes in the classification problem under consideration must be specified.
Additionally, an optional parameter `normalized` (default = `true`) may be
specified that determines whether or not the confusion matrix is normalized
(that is, it contains percentages) or not (that is, it contains counts).

The `add(output, target)` method takes as input an NxK tensor `output` that
contains the output scores obtained from the model for N examples and K classes,
and a corresponding NxK-tensor `target` that provides the targets for the N
examples using one-hot vectors (that is, vectors that contain only zeros and a
single one at the location of the target value to be encoded).

The `value()` method has no parameters and returns the confusion matrix in a
KxK tensor. In the confusion matrix, rows correspond to ground-truth targets and
columns correspond to predicted targets.
]],
   noordered = true,
   {name="self", type="tnt.MultiLabelConfusionMeter"},
   {name="k", type="number"},
   {name="normalized", type="boolean", default=true},
   call =
      function(self, k, normalized)
         self.conf = torch.DoubleTensor(k, k)
         self.normalized = normalized
         self:reset()
      end
}

MultiLabelConfusionMeter.reset = argcheck{
   {name="self", type="tnt.MultiLabelConfusionMeter"},
   call =
      function(self)
         self.conf:zero()
      end
}

-- Accumulate a batch of (output, target) pairs into the confusion matrix.
MultiLabelConfusionMeter.add = argcheck{
   {name="self", type="tnt.MultiLabelConfusionMeter"},
   {name="output", type="torch.*Tensor"},
   {name="target", type="torch.*Tensor"},
   call =
      function(self, output, target)
         target = target:squeeze()
         output = output:squeeze()
         if output:nDimension() == 1 then
            output = output:view(1, output:size(1))
         end
         if target:nDimension() == 1 then
            -- fixed: the reshape must use target's own size, not output's
            target = target:view(1, target:size(1))
         end
         assert(
            target:nDimension() == output:nDimension() and
            torch.eq(
               torch.LongTensor(target:size()),
               torch.LongTensor(output:size())
            ):all(),
            'number of targets and outputs do not match'
         )
         assert(
            torch.eq(torch.eq(target, 0):add(torch.eq(target, 1)), 1):all(),
            'target values should be 0 or 1'
         )
         assert(
            target:size(2) == self.conf:size(1),
            'target size does not match size of confusion matrix'
         )

         -- update confusion matrix:
         local nc = output:size(2)
         local _,pred = output:double():sort(2, true)
         for n = 1,pred:size(1) do

            -- convert targets and top-|targets| predictions to sets:
            local targetTable, predTable = tds.hash(), tds.hash()
            local pos = torch.range(1, nc)[torch.eq(target[n], 1)]
            for k = 1,pos:nElement() do
               targetTable[pos[k]] = 1
               predTable[pred[n][k]] = 1
            end

            -- loop over correct predictions (added to the diagonal):
            for key,_ in pairs(targetTable) do
               if predTable[key] then
                  self.conf[key][key] = self.conf[key][key] + 1
                  targetTable[key] = nil
                  predTable[key] = nil
               end
            end

            -- equally distribute mass of incorrect predictions; when all
            -- predictions were correct both sets are empty and the loops
            -- below are no-ops:
            local weight = 1 / #predTable
            for key1,_ in pairs(targetTable) do
               for key2,_ in pairs(predTable) do
                  self.conf[key1][key2] = self.conf[key1][key2] + weight
               end
            end
         end
      end
}

MultiLabelConfusionMeter.value = argcheck{
   {name="self", type="tnt.MultiLabelConfusionMeter"},
   call =
      function(self)
         local conf = torch.DoubleTensor(self.conf:size()):copy(self.conf)
         if self.normalized then
            -- row-normalize; epsilon avoids division by zero for empty rows
            conf:cdiv(conf:sum(2):expandAs(conf):add(1e-8))
         end
         return conf
      end
}
-------------------------------------------------------------------------------- /meter/ndcgmeter.lua: --------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local NDCGMeter = torch.class('tnt.NDCGMeter', 'tnt.Meter', tnt)

-- function computing discounted cumulative gain:
-- DCG_K = rel_{index[1]} + sum_{i=2}^K rel_{index[i]} / log2(i)
local function computeDCG(relevance, index, K)

   -- assertions:
   assert(relevance)
   assert(index)
   assert(K)
   assert(type(K) == 'number')
   assert(index:max() <= relevance:nElement())
   assert(index:nElement() >= K)
   relevance = relevance:squeeze()
   index = index:squeeze()
   assert(relevance:dim() == 1)
   assert(index:dim() == 1)

   -- return DCG:
   local dcg = relevance[index[1]]
   if K > 1 then
      dcg = dcg + relevance:index(1, index:narrow(1, 2, K - 1)):cdiv(
         torch.range(2, K):log():div(math.log(2)):typeAs(relevance)
      ):sum()
   end
   return dcg
end

-- function computing ideal discounted cumulative gain (DCG of the ranking
-- that sorts items by true relevance):
local function computeIDCG(relevance, K)
   relevance = relevance:squeeze()
   assert(relevance:dim() == 1)
   local _,sortind = torch.sort(relevance, 1, true) -- descending order
   return computeDCG(relevance, sortind, K)
end

-- function computing the normalized discounted cumulative gain
-- (renamed from computeNCDG; the helper name was a typo):
-- NOTE(review): if all relevances are zero, IDCG is 0 and the division
-- yields NaN, which trips the assert — confirm callers never pass
-- all-zero relevance rows.
local function computeNDCG(relevance, index, K)
   local r = computeDCG(relevance, index, K) / computeIDCG(relevance, K)
   assert(r >= 0 and r <= 1)
   return r
end

NDCGMeter.__init = argcheck{
   doc = [[

#### tnt.NDCGMeter(@ARGP)
@ARGT

The `tnt.NDCGMeter` measures the normalized discounted cumulative gain (NDCG) of
a ranking produced by a model at prespecified levels k, and averages the NDCG
over all examples.

The discounted cumulative gain at level k is defined as:

DCG_k = rel_1 + \sum{i = 2}^k (rel_i / log_2(i))

Herein, rel_i is the relevance of item i as specified by an external rater.
Defining ideal DCG (IDCG) as the best possible DCG for a given example, the NDCG
at level k is defined as:

NDCG_k = DCG_k / IDCG_k

At initialization time, the meter takes as input a table `K` that contains all
the levels k at which the NDCG is computed.

The `add(output, relevance)` method takes as input (1) a NxC tensor of model
`outputs`, which scores for all C possible outputs for a batch of N examples;
and (2) a NxC tensor `relevance` that contains the corresponding relevances for
these scores, as provided by an external rater. Relevances are generally
obtained from human raters.

The `value()` method returns a table that contains the NDCG values for all
levels K that were provided at initialization time. Alternatively, the NDCG at
a specific level k can be obtained by calling `value(k)`. Note that the level
`k` should be an element of the table `K` specified at initialization time.

Please note that the number of outputs and relevances C should always be at
least as high as the highest NDCG level k that the meter is computing.
]],
   {name="self", type="tnt.NDCGMeter"},
   {name="K", type="table", default = {1}},
   call =
      function(self, K)
         self.K = torch.LongTensor(K):sort():totable()
         self:reset()
      end
}

NDCGMeter.reset = argcheck{
   {name="self", type="tnt.NDCGMeter"},
   call =
      function(self)
         self.ndcg = {}
         for _,k in ipairs(self.K) do self.ndcg[k] = 0 end
         self.n = 0
      end
}

-- Accumulate per-example NDCG at every level in self.K.
NDCGMeter.add = argcheck{
   {name="self", type="tnt.NDCGMeter"},
   {name="output", type="torch.*Tensor"},
   {name="relevance", type="torch.*Tensor"},
   call =
      function(self, output, relevance)

         -- check inputs; use view() instead of resize() so the caller's
         -- tensors are not mutated (consistent with the other meters):
         if output:dim() == 1 then
            output = output:view(1, output:nElement())
         end
         if relevance:dim() == 1 then
            relevance = relevance:view(1, relevance:nElement())
         end
         assert(output:dim() == 2)
         assert(relevance:dim() == 2)
         assert(output:size(1) == relevance:size(1), 'batch size must match')
         assert(output:size(2) == relevance:size(2), 'result size must match')
         assert(
            relevance:size(2) >= self.K[#self.K],
            'too few results for value of K'
         )

         -- compute average NDCG:
         relevance = relevance:double()
         local _,index = torch.sort(output, 2, true) -- descending order
         for n = 1,index:size(1) do
            for _,k in ipairs(self.K) do
               self.ndcg[k] =
                  self.ndcg[k] + computeNDCG(relevance[n], index[n], k)
            end
         end
         self.n = self.n + index:size(1)
      end
}

NDCGMeter.value = argcheck{
   {name="self", type="tnt.NDCGMeter"},
   {name="K", type="number", opt=true},
   call =
      function(self, K)
         if K then
            assert(
               self.ndcg[K], 'invalid k (was not provided at construction time)'
            )
            return self.ndcg[K] / self.n
         else
            local value = {}
            for _,k in ipairs(self.K) do
               value[k] = self.ndcg[k] / self.n
            end
            return value
         end
      end
}
-------------------------------------------------------------------------------- /meter/recallmeter.lua: --------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local RecallMeter = torch.class('tnt.RecallMeter', 'tnt.Meter', tnt)

RecallMeter.__init = argcheck{
   doc = [[

#### tnt.RecallMeter(@ARGP)
@ARGT

The `tnt.RecallMeter` measures the recall of ranking methods at pre-
specified thresholds. The recall is the percentage of the correct (positive)
targets that is in the list of positively labeled items according to the model.

At initialization time, the `tnt.RecallMeter` provides two optional
parameters. The first parameter is a table `threshold` that contains all
thresholds at which the recall is measured (default = {0.5}). Thresholds
should be numbers between 0 and 1. The second parameter is a boolean `perclass`
that makes the meter measure the recall per class when set to `true`
(default = `false`). When `perclass` is set to `false`, the recall is simply
averaged over all examples.

The `add(output, target)` method takes two inputs:
* A NxK tensor that for each of the N examples indicates the probability
of the example belonging to each of the K classes, according to the model.
The probabilities should sum to one over all classes; that is, the row sums
of `output` should all be one.
* A binary NxK `target` tensor that encodes which of the K classes
are associated with the N-th input. For instance, a row of {0, 1, 0, 1}
indicates that the example is associated with classes 2 and 4.

The `value()` method returns a table containing the recall of the model
predictions measured at the `threshold`s specified at initialization time. The
`value(t)` method returns the recall at a particular threshold `t`. Note that
this threshold `t` should be an element of the `threshold` table specified at
initialization time of the meter.
]],
   noordered = true,
   {name="self", type="tnt.RecallMeter"},
   {name="threshold", type="table", default={0.5}},
   {name="perclass", type="boolean", default=false},
   call = function(self, threshold, perclass)
      -- validate and sort thresholds once at construction time
      self.threshold = {}
      for _,n in pairs(threshold) do
         assert(n >= 0 and n <= 1, 'threshold should be between 0 and 1')
         table.insert(self.threshold, n)
      end
      table.sort(self.threshold)
      self.perclass = perclass
      self:reset()
   end
}

-- Reset per-threshold counters; the per-class tensors are lazily sized on
-- the first add() (the number of classes K is not known until then).
RecallMeter.reset = argcheck{
   {name="self", type="tnt.RecallMeter"},
   call = function(self)
      self.tp = {}    -- true positives per class, keyed by threshold
      self.tpfn = {}  -- true positives + false negatives (all positives)
      for _,t in ipairs(self.threshold) do
         self.tp[t] = torch.Tensor()
         self.tpfn[t] = torch.Tensor()
      end
   end
}

-- Accumulate a batch of (output, target) pairs into the tp/tpfn counters.
RecallMeter.add = argcheck{
   {name="self", type="tnt.RecallMeter"},
   {name="output", type="torch.*Tensor"},
   {name="target", type="torch.*Tensor"}, -- target is k-hot vector
   call = function(self, output, target)
      output = output:squeeze()
      if output:nDimension() == 1 then
         output = output:view(1, output:size(1))
      else
         assert(
            output:nDimension() == 2,
            'wrong output size (1D or 2D expected)'
         )
      end
      if target:nDimension() == 1 then
         target = target:view(1, target:size(1))
      else
         assert(
            target:nDimension() == 2,
            'wrong target size (1D or 2D expected)'
         )
      end
      for s = 1,#output:size() do
         assert(
            output:size(s) == target:size(s),
            string.format('target and output do not match on dim %d', s)
         )
      end

      -- initialize upon receiving first inputs:
      for _,t in ipairs(self.threshold) do
         if self.tp[t]:nElement() == 0 then
            self.tp[t]:resize(  target:size(2)):typeAs(output):fill(0)
            self.tpfn[t]:resize(target:size(2)):typeAs(output):fill(0)
         end
      end

      -- scores of true positives (negatives are forced to -1 so they can
      -- never clear a threshold in [0, 1]):
      local true_pos = output:clone()
      true_pos[torch.eq(target, 0)] = -1

      -- sum all the things:
      for _,t in ipairs(self.threshold) do
         self.tp[t]:add(torch.ge(true_pos, t):typeAs(output):sum(1):squeeze())
         self.tpfn[t]:add(target:typeAs(output):sum(1):squeeze())
      end
   end
}

-- Recall (in percent) at threshold `t`, or a table over all thresholds.
RecallMeter.value = argcheck{
   {name="self", type="tnt.RecallMeter"},
   {name="t", type="number", opt=true},
   call = function(self, t)
      if t then
         assert(
            self.tp[t] and self.tpfn[t],
            string.format('%f is an incorrect threshold [not set]', t)
         )
         if self.perclass then
            local recallPerClass =
               torch.cdiv(self.tp[t], self.tpfn[t]):mul(100)
            -- classes with no positive targets count as perfect recall
            recallPerClass[torch.eq(self.tpfn[t], 0)] = 100
            return recallPerClass
         else
            if self.tpfn[t]:sum() == 0 then return 100 end
            return (self.tp[t]:sum() / self.tpfn[t]:sum()) * 100
         end
      else
         local value = {}
         for _,t in ipairs(self.threshold) do
            value[t] = self:value(t)
         end
         return value
      end
   end
}
-------------------------------------------------------------------------------- /meter/precisionmeter.lua: --------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.
4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local PrecisionMeter = torch.class('tnt.PrecisionMeter', 'tnt.Meter', tnt) 14 | 15 | PrecisionMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.PrecisionMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.PrecisionMeter` measures the precision of ranking methods at pre- 22 | specified thresholds. The precision is the percentage of the positively labeled 23 | items according to the model that is in the list of correct (positive) targets. 24 | 25 | At initialization time, the `tnt.PrecisionMeter` provides two optional 26 | parameters. The first parameter is a table `threshold` that contains all 27 | thresholds at which the precision is measured (default = {0.5}). Thresholds 28 | should be numbers between 0 and 1. The second parameter is a boolean `perclass` 29 | that makes the meter measure the precision per class when set to `true` 30 | (default = `false`). When `perclass` is set to `false`, the precision is simply 31 | averaged over all examples. 32 | 33 | The `add(output, target)` method takes two inputs: 34 | * A NxK tensor that for each of the N examples indicates the probability 35 | of the example belonging to each of the K classes, according to the model. 36 | The probabilities should sum to one over all classes; that is, the row sums 37 | of `output` should all be one. 38 | * A binary NxK `target` tensor that encodes which of the K classes 39 | are associated with the N-th input. For instance, a row of {0, 1, 0, 1} 40 | indicates that the example is associated with classes 2 and 4. 
41 | 42 | The `value()` method returns a table containing the precision of the model 43 | predictions measured at the `threshold`s specified at initialization time. The 44 | `value(t)` method returns the precision at a particular threshold `t`. Note that 45 | this threshold `t` should be an element of the `threshold` table specified at 46 | initialization time of the meter. 47 | ]], 48 | noordered = true, 49 | {name="self", type="tnt.PrecisionMeter"}, 50 | {name="threshold", type="table", default={0.5}}, 51 | {name="perclass", type="boolean", default=false}, 52 | call = function(self, threshold, perclass) 53 | self.threshold = {} 54 | for _,n in pairs(threshold) do 55 | assert(n >= 0 and n <= 1, 'threshold should be between 0 and 1') 56 | table.insert(self.threshold, n) 57 | end 58 | table.sort(self.threshold) 59 | self.perclass = perclass 60 | self:reset() 61 | end 62 | } 63 | 64 | PrecisionMeter.reset = argcheck{ 65 | {name="self", type="tnt.PrecisionMeter"}, 66 | call = function(self) 67 | self.tp = {} 68 | self.tpfp = {} 69 | for _,t in ipairs(self.threshold) do 70 | self.tp[t] = torch.Tensor() 71 | self.tpfp[t] = torch.Tensor() 72 | end 73 | end 74 | } 75 | 76 | PrecisionMeter.add = argcheck{ 77 | {name="self", type="tnt.PrecisionMeter"}, 78 | {name="output", type="torch.*Tensor"}, 79 | {name="target", type="torch.*Tensor"}, -- target is k-hot vector 80 | call = function(self, output, target) 81 | output = output:squeeze() 82 | if output:nDimension() == 1 then 83 | output = output:view(1, output:size(1)) 84 | else 85 | assert( 86 | output:nDimension() == 2, 87 | 'wrong output size (1D or 2D expected)' 88 | ) 89 | end 90 | if target:nDimension() == 1 then 91 | target = target:view(1, target:size(1)) 92 | else 93 | assert( 94 | target:nDimension() == 2, 95 | 'wrong target size (1D or 2D expected)' 96 | ) 97 | end 98 | for s = 1,#output:size() do 99 | assert( 100 | output:size(s) == target:size(s), 101 | string.format('target and output do not match on dim %d', 
s) 102 | ) 103 | end 104 | 105 | -- initialize upon receiving first inputs: 106 | for _,t in ipairs(self.threshold) do 107 | if self.tp[t]:nElement() == 0 then 108 | self.tp[t]:resize( target:size(2)):typeAs(output):fill(0) 109 | self.tpfp[t]:resize(target:size(2)):typeAs(output):fill(0) 110 | end 111 | end 112 | 113 | -- scores of true positives: 114 | local true_pos = output:clone() 115 | true_pos[torch.eq(target, 0)] = -1 116 | 117 | -- sum all the things: 118 | for _,t in ipairs(self.threshold) do 119 | self.tp[t]:add( torch.ge(true_pos, t):typeAs(output):sum(1):squeeze()) 120 | self.tpfp[t]:add(torch.ge(output, t):typeAs(output):sum(1):squeeze()) 121 | end 122 | end 123 | } 124 | 125 | PrecisionMeter.value = argcheck{ 126 | {name="self", type="tnt.PrecisionMeter"}, 127 | {name="t", type="number", opt=true}, 128 | call = function(self, t) 129 | if t then 130 | assert( 131 | self.tp[t] and self.tpfp[t], 132 | string.format('%f is an incorrect threshold [not set]', t) 133 | ) 134 | if self.perclass then 135 | local precisionPerClass = 136 | torch.cdiv(self.tp[t], self.tpfp[t]):mul(100) 137 | precisionPerClass[torch.eq(self.tpfp[t], 0)] = 100 138 | return precisionPerClass 139 | else 140 | if self.tpfp[t]:sum() == 0 then return 100 end 141 | return self.tp[t]:sum() / self.tpfp[t]:sum() * 100 142 | end 143 | else 144 | local value = {} 145 | for _,t in ipairs(self.threshold) do 146 | value[t] = self:value(t) 147 | end 148 | return value 149 | end 150 | end 151 | } 152 | -------------------------------------------------------------------------------- /meter/precisionatkmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. 
An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local PrecisionAtKMeter = torch.class('tnt.PrecisionAtKMeter', 'tnt.Meter', tnt) 14 | 15 | PrecisionAtKMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.PrecisionAtKMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.PrecisionAtKMeter` measures the precision@k of ranking methods at pre-specified 22 | levels k. The precision@k is the percentage of the k front-ranked 23 | items according to the model that is in the list of correct (positive) targets. 24 | 25 | At initialization time, a table `topk` may be given as input that specifies the 26 | levels k at which the precision@k will be measured (default = `{10}`). In 27 | addition, a number `dim` may be provided that specifies over which dimension the 28 | precision@k should be computed (default = 2), and a boolean `online` may be 29 | specified that indicates whether we see all inputs along dimension `dim` at once 30 | (default = `false`). 31 | 32 | The `add(output, target)` method takes two inputs. In the default mode (`dim=2` 33 | and `online=false`), the inputs mean: 34 | * A NxC tensor that for each of the N examples (queries) contains a score 35 | indicating to what extent each of the C classes (documents) is relevant to 36 | the query, according to the model. 37 | * A binary NxC `target` tensor that encodes which of the C classes 38 | (documents) are actually relevant to the N-th input (query). For 39 | instance, a row of {0, 1, 0, 1} indicates that the example is associated 40 | with classes 2 and 4. 41 | 42 | The result of setting `dim` to `1` is identical to transposing the tensors 43 | `output` and `target` in the above. 
The result of setting `online=true` is that 44 | the function assumes that it is not the number of queries `N` that is growing 45 | with repeated calls to `add()`, but the number of candidate documents `C`. (Use 46 | this mode in scenarios where `C` is large but `N` is small.) 47 | 48 | The `value()` method returns a table that contains the precision@k (that is, the 49 | percentage of targets predicted correctly) at the cutoff levels in `topk` that 50 | were specified at initialization time. Alternatively, the precision@k at 51 | a specific level k can be obtained by calling `value(k)`. Note that the level 52 | `k` should be an element of the table `topk` specified at initialization time. 53 | 54 | Please note that the maximum value in `topk` cannot be higher than the total 55 | number of classes (documents). 56 | ]], 57 | noordered = true, 58 | {name="self", type="tnt.PrecisionAtKMeter"}, 59 | {name="topk", type="table", default={10}}, 60 | {name="dim", type="number", default=2}, 61 | {name="online", type="boolean", default=false}, 62 | call = 63 | function(self, topk, dim, online) 64 | assert(dim == 1 or dim == 2, 'value of dimension should be 1 or 2') 65 | self.topk = torch.LongTensor(topk):sort():totable() 66 | self.maxk = self.topk[#self.topk] 67 | self.dim = dim 68 | self.online = online 69 | self:reset() 70 | end 71 | } 72 | 73 | PrecisionAtKMeter.reset = argcheck{ 74 | {name="self", type="tnt.PrecisionAtKMeter"}, 75 | call = 76 | function(self) 77 | self.tp = {} 78 | for _,k in ipairs(self.topk) do self.tp[k] = 0 end 79 | self.n = 0 80 | end 81 | } 82 | 83 | PrecisionAtKMeter.add = argcheck{ 84 | {name="self", type="tnt.PrecisionAtKMeter"}, 85 | {name="output", type="torch.*Tensor"}, 86 | {name="target", type="torch.*Tensor"}, -- target is k-hot vector 87 | call = 88 | function(self, output, target) 89 | output = output:squeeze() 90 | if output:nDimension() == 1 then 91 | output = output:view(1, output:size(1)) 92 | else 93 | assert( 94 | 
output:nDimension() == 2, 95 | 'wrong output size (1D or 2D expected)' 96 | ) 97 | end 98 | if target:nDimension() == 1 then 99 | target = target:view(1, target:size(1)) 100 | else 101 | assert( 102 | target:nDimension() == 2, 103 | 'wrong target size (1D or 2D expected)' 104 | ) 105 | end 106 | for s = 1,#output:size() do 107 | assert(output:size(s) == target:size(s), 108 | string.format('target and output do not match on dim %d', s)) 109 | end 110 | 111 | -- add new output-target pairs: 112 | if self.online then 113 | 114 | -- update top-k list along dimension dim: 115 | self:__updatetopk{ 116 | output = output, 117 | target = target, 118 | topk = self.maxk, 119 | dim = self.dim, 120 | desc = true, 121 | } 122 | self.n = output:size((self.dim == 1) and 2 or 1) 123 | else 124 | 125 | -- accumulate counts of true positives and total # of inputs: 126 | local topout, topind = torch.topk(output, self.maxk, self.dim, true) 127 | local _,sortind = torch.sort(topout, self.dim, true) 128 | local topind = topind:gather(self.dim, sortind) 129 | local sorttarget = target:gather(self.dim, topind) 130 | for _,k in ipairs(self.topk) do 131 | self.tp[k] = self.tp[k] + sorttarget:narrow(self.dim, 1, k):sum() 132 | end 133 | self.n = self.n + target:size((self.dim == 1) and 2 or 1) 134 | end 135 | end 136 | } 137 | 138 | PrecisionAtKMeter.value = argcheck{ 139 | {name="self", type="tnt.PrecisionAtKMeter"}, 140 | {name="k", type="number", opt=true}, 141 | call = 142 | function(self, k) 143 | 144 | -- in online mode, sort outputs and corresponding targets: 145 | if self.online then 146 | local topoutput = self.__topkoutput:narrow(self.dim, 1, self.maxk) 147 | local toptarget = self.__topktarget:narrow(self.dim, 1, self.maxk) 148 | local _,sortind = torch.sort(topoutput, self.dim, true) 149 | local sorttarget = toptarget:gather(self.dim, sortind) 150 | for _,k in ipairs(self.topk) do 151 | self.tp[k] = sorttarget:narrow(self.dim, 1, k):sum() 152 | end 153 | end 154 | 155 | -- compute 
the precision@k: 156 | if k then 157 | assert(self.tp[k], 'invalid k (not provided at construction time)') 158 | return (self.tp[k] / (self.n * k)) * 100 159 | else 160 | local value = {} 161 | for _,k in ipairs(self.topk) do value[k] = self:value(k) end 162 | return value 163 | end 164 | end 165 | } 166 | -------------------------------------------------------------------------------- /log/init.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | local Log = torch.class('tnt.Log', tnt) 15 | 16 | doc[[ 17 | ### Log 18 | 19 | Log classes act as tables indexed by string keys. Allowed keys must be 20 | provided at construction. A special key `__status__` can be also set the 21 | convenience method `log:status()` to record basic messages. 22 | 23 | Viewers closures can be attached to a Log, and called at different events: 24 | * `onSet(log, key, value)`: when setting a key to the Log with `log:set{}`. 25 | * `onGet(log, key)`: when querying a key with `log:get()`. 26 | * `onFlush(log)`: when flushing out the stored data of the Log with `log:flush()`. 27 | * `onClose(log)`: when closing a Log with `log:close()`. 28 | 29 | Typical viewer closures are `text` or `json`, which allow to write to disk 30 | or to the console a subset of the keys stored by the Log, in a particular 31 | format. The special viewer closure `status` is made to be called on `set()` 32 | events, and will print out only status records. 
 33 | 34 | A typical use case would be the following: 35 | ```lua 36 | 37 | -- require the viewers we want 38 | local logtext = require 'torchnet.log.view.text' 39 | local logstatus = require 'torchnet.log.view.status' 40 | 41 | local log = tnt.Log{ 42 | keys = {"loss", "accuracy"}, 43 | onFlush = { 44 | -- write out all keys in "log" file 45 | logtext{filename=string.format('%s/log', rundir), keys={"loss", "accuracy"}, format={"%10.5f", "%3.2f"}}, 46 | -- write out loss in a standalone file 47 | logtext{filename=string.format('%s/loss', rundir), keys={"loss"}}, 48 | -- print on screen too 49 | logtext{keys={"loss", "accuracy"}}, 50 | }, 51 | onSet = { 52 | -- add status to log 53 | logstatus{filename=string.format('%s/log', rundir)}, 54 | -- print status to screen 55 | logstatus{}, 56 | } 57 | } 58 | 59 | -- set values 60 | log:set{ 61 | loss = 0.1, 62 | accuracy = 97 63 | } 64 | 65 | -- write some info 66 | log:status("hello world") 67 | 68 | -- flush out log 69 | log:flush() 70 | ``` 71 | ]] 72 | 73 | Log.__clear = 74 | function(self) 75 | self.__events = {onClose={}, onFlush={}, onGet={}, onSet={}} 76 | self.__data = {} 77 | end 78 | 79 | Log.__init = argcheck{ 80 | doc = [[ 81 | 82 | #### tnt.Log(@ARGP) 83 | @ARGT 84 | 85 | Creates a new Log with allowed keys (strings) `keys`. Specify event 86 | closures with table of functions `onClose`, `onFlush`, `onGet` and `onSet`, 87 | which will be called when `close()`, `flush()`, `get()`, and `set{}` 88 | methods will be called, respectively. 
89 | ]], 90 | noordered=true, 91 | {name="self", type="tnt.Log"}, 92 | {name="keys", type="table"}, 93 | {name="onClose", type="table", opt=true}, 94 | {name="onFlush", type="table", opt=true}, 95 | {name="onGet", type="table", opt=true}, 96 | {name="onSet", type="table", opt=true}, 97 | call = 98 | function(self, keys, onClose, onFlush, onGet, onSet) 99 | self.__keys = {__status__ = true} 100 | for _, key in ipairs(keys) do 101 | self.__keys[key] = true 102 | end 103 | self:__clear() 104 | if onClose then 105 | self:attach('onClose', onClose) 106 | end 107 | if onFlush then 108 | self:attach('onFlush', onFlush) 109 | end 110 | if onGet then 111 | self:attach('onGet', onGet) 112 | end 113 | if onSet then 114 | self:attach('onSet', onSet) 115 | end 116 | end 117 | } 118 | 119 | Log.status = argcheck{ 120 | doc = [[ 121 | 122 | #### tnt.Log:status(@ARGP) 123 | @ARGT 124 | 125 | Record a status message, with corresponding (optional) time of the event. 126 | ]], 127 | {name="self", type="tnt.Log"}, 128 | {name="message", type="string", opt=true}, 129 | {name="time", type="boolean", default=true}, 130 | call = 131 | function(self, message, time) 132 | local prefix = "|" 133 | if time then 134 | prefix = prefix .. " " .. os.date() .. " |" 135 | end 136 | self:set{ 137 | __status__ = string.format("%s %s", prefix, message) 138 | } 139 | end 140 | } 141 | 142 | Log.set = argcheck{ 143 | doc = [[ 144 | 145 | #### tnt.Log:set(@ARGP) 146 | @ARGT 147 | 148 | Set a number of keys (a subset of the keys provided at construction) to 149 | their corresponding values. 150 | 151 | Closures attached to the `onSet(log, key, value)` event will be called. 
152 | ]], 153 | nonamed=true, 154 | {name="self", type="tnt.Log"}, 155 | {name="keys", type="table"}, 156 | call = 157 | function(self, keys) 158 | for key, value in pairs(keys) do 159 | assert(type(key) == 'string', 'string expected for key') 160 | if not self.__keys[key] then 161 | error(string.format("unknown key <%s>", key)) 162 | end 163 | for _, closure in ipairs(self.__events.onSet) do 164 | closure(self, key, value) 165 | end 166 | self.__data[key] = value 167 | end 168 | end 169 | } 170 | 171 | Log.get = argcheck{ 172 | doc = [[ 173 | 174 | #### tnt.Log:get(@ARGP) 175 | @ARGT 176 | 177 | Get the value of a given key. 178 | 179 | Closures attached to the `onGet(log, key)` event will be called. 180 | ]], 181 | {name="self", type="tnt.Log"}, 182 | {name="key", type="string"}, 183 | call = 184 | function(self, key) 185 | if not self.__keys[key] then 186 | error(string.format("unknown key <%s>", key)) 187 | end 188 | for _, closure in ipairs(self.__events.onGet) do 189 | closure(self, key) 190 | end 191 | return self.__data[key] 192 | end 193 | } 194 | 195 | Log.flush = argcheck{ 196 | doc = [[ 197 | 198 | #### tnt.Log:flush(@ARGP) 199 | @ARGT 200 | 201 | Flush (empty) the log data. 202 | 203 | Closures attached to the `onFlush(log)` event will be called. 204 | ]], 205 | {name="self", type="tnt.Log"}, 206 | call = 207 | function(self) 208 | for _, closure in ipairs(self.__events.onFlush) do 209 | closure(self) 210 | end 211 | self.__data = {} 212 | end 213 | } 214 | 215 | Log.close = argcheck{ 216 | doc = [[ 217 | 218 | #### tnt.Log:close(@ARGP) 219 | @ARGT 220 | 221 | Close the log. 222 | 223 | Closures attached to the `onClose(log)` event will be called. 
224 | ]], 225 | {name="self", type="tnt.Log"}, 226 | call = 227 | function(self) 228 | for _, closure in ipairs(self.__events.onClose) do 229 | closure(self) 230 | end 231 | self:__clear() 232 | end 233 | } 234 | 235 | Log.attach = argcheck{ 236 | doc = [[ 237 | 238 | #### tnt.Log:attach(@ARGP) 239 | @ARGT 240 | 241 | Attach a set of functions (provided in a table) to a given event. 242 | ]], 243 | {name="self", type="tnt.Log"}, 244 | {name="event", type="string"}, 245 | {name="closures", type="table"}, 246 | call = 247 | function(self, event, closures) 248 | local events = self.__events[event] 249 | assert(events, string.format('unknown event <%s>', event)) 250 | for _, closure in ipairs(closures) do 251 | assert(type(closure) == 'function', string.format('%s: table of functions expected', event)) 252 | table.insert(events, closure) 253 | end 254 | end 255 | } 256 | -------------------------------------------------------------------------------- /log/remotelog.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local threads = require 'threads' 12 | local transfer = require 'torchnet.log.transfer' 13 | local socket = require 'socket' 14 | local argcheck = require 'argcheck' 15 | 16 | local RemoteLog, Log = torch.class('tnt.RemoteLog', 'tnt.Log', tnt) 17 | 18 | RemoteLog.__clear = 19 | function(self) 20 | local server 21 | if self.__servername then 22 | server = self.__servername 23 | else 24 | self.__mutex = threads.Mutex() 25 | local servername = torch.ByteStorage() 26 | servername:retain() 27 | self.__mutex:lock() 28 | 29 | self.__server = threads.Thread(string.format([[ 30 | local logs = {} 31 | require 'torch' 32 | local transfer = require 'torchnet.log.transfer' 33 | local threads = require 'threads' 34 | local mutex = threads.Mutex(%d) 35 | local servername = torch.pushudata(0x%x, "torch.ByteStorage") 36 | local socket = require("socket") 37 | local tnt = require 'torchnet' 38 | -- create a TCP socket and bind it to the local host, at any port 39 | local server = assert(socket.bind("*", 0)) 40 | -- find out which port the OS chose for us 41 | local ip, port = server:getsockname() 42 | servername:string(string.format("%%s:%%s", ip, port)) 43 | mutex:unlock() 44 | local function error(msg) 45 | print(string.format("$ Log server %%s:%%d error: %%s", ip, port, msg)) 46 | end 47 | local function xcall(log, funcname, ...) 48 | local status, res = pcall( 49 | function(...) 50 | return {log[funcname](log, ...)} 51 | end, 52 | ... 
53 | ) 54 | if status then 55 | return table.unpack(res) 56 | else 57 | error(res) 58 | end 59 | end 60 | -- loop forever waiting for clients 61 | while true do 62 | -- wait for a connection from any client 63 | local client = server:accept() 64 | -- make sure we don't block waiting for this client's line 65 | client:settimeout(10) 66 | -- receive the line 67 | local cmd = transfer.receive(client) 68 | if type(cmd) == 'string' then 69 | if cmd == 'close' then 70 | for _, log in pairs(logs) do 71 | xcall(log, 'close') 72 | end 73 | break 74 | elseif cmd == 'lognames' then 75 | local lognames = {} 76 | for name, _ in pairs(logs) do 77 | table.insert(lognames, name) 78 | end 79 | transfer.send(client, lognames) 80 | elseif cmd == 'create' then 81 | local logname = transfer.receive(client) 82 | local keys = transfer.receive(client) 83 | if not logs[logname] then 84 | logs[logname] = tnt.Log{keys=keys} 85 | end 86 | elseif cmd == 'attach' then 87 | local logname = transfer.receive(client) 88 | local event = transfer.receive(client) 89 | local closures = transfer.receive(client) 90 | xcall(logs[logname], 'attach', event, closures) 91 | elseif cmd == 'set' then 92 | local logname = transfer.receive(client) 93 | local keys = transfer.receive(client) 94 | xcall(logs[logname], 'set', keys) 95 | elseif cmd == 'get' then 96 | local logname = transfer.receive(client) 97 | local key = transfer.receive(client) 98 | local value = xcall(logs[logname], 'get', key) 99 | transfer.send(client, value) 100 | elseif cmd == 'flush' then 101 | local logname = transfer.receive(client) 102 | xcall(logs[logname], 'flush') 103 | end 104 | end 105 | client:close() 106 | end 107 | server:close() 108 | ]], self.__mutex:id(), torch.pointer(servername))) 109 | 110 | self.__mutex:lock() 111 | server = servername:string() 112 | 113 | -- GC Lua 5.1 114 | if newproxy then 115 | self.__gc__ = newproxy(true) 116 | getmetatable(self.__gc__).__gc = 117 | function() 118 | self:__gc() 119 | end 120 | end 121 
| end 122 | 123 | self.__ip, self.__port = server:match("^(.+)%:(.+)$") 124 | self.__port = tonumber(self.__port) 125 | assert(self.__ip and self.__port, "invalid ip:port name") 126 | 127 | -- create table 128 | local c = socket.connect(self.__ip, self.__port) 129 | transfer.send(c, "create") 130 | transfer.send(c, self.__name) 131 | local keys = {} 132 | for key, _ in pairs(self.__keys) do 133 | table.insert(keys, key) 134 | end 135 | transfer.send(c, keys) 136 | c:close() 137 | end 138 | 139 | RemoteLog.__init = argcheck{ 140 | doc = [[ 141 | 142 | #### tnt.RemoteLog(@ARGP) 143 | @ARGT 144 | 145 | Creates a new RemoteLog with allowed keys (strings) `keys`. Specify event 146 | closures with table of functions `onClose`, `onFlush`, `onGet` and `onSet`, 147 | which will be called when `close()`, `flush()`, `get()`, and `set{}` 148 | methods will be called, respectively. 149 | 150 | If `server` is not provided, RemoteLog creates a server which can later be 151 | reached at the address provided by `server()`. 152 | 153 | If `server` is provided, RemoteLog will dialog with the given server to 154 | store any values to be recorded by the Log (or query any of these values). 155 | 156 | A given server can record different Log, with different names. The default name 157 | is `default`, but can be specified with the `name` option. 158 | 159 | At this time, it is important to call the `close()` method when RemoteLog 160 | is not used anymore (before quitting the application). 
161 | ]], 162 | noordered=true, 163 | {name="self", type="tnt.RemoteLog"}, 164 | {name="keys", type="table"}, 165 | {name="server", type="string", opt=true}, 166 | {name="name", type="string", default="default"}, 167 | {name="onClose", type="table", opt=true}, 168 | {name="onFlush", type="table", opt=true}, 169 | {name="onGet", type="table", opt=true}, 170 | {name="onSet", type="table", opt=true}, 171 | call = 172 | function(self, keys, server, name, onClose, onFlush, onGet, onSet) 173 | self.__server = server 174 | self.__name = name 175 | Log.__init( 176 | self, 177 | { 178 | keys=keys, 179 | onClose=onClose, 180 | onFlush=onFlush, 181 | onGet=onGet, 182 | onSet=onSet 183 | } 184 | ) 185 | end 186 | } 187 | 188 | RemoteLog.set = argcheck{ 189 | nonamed=true, 190 | {name="self", type="tnt.RemoteLog"}, 191 | {name="keys", type="table"}, 192 | call = 193 | function(self, keys) 194 | local c = socket.connect(self.__ip, self.__port) 195 | transfer.send(c, "set") 196 | transfer.send(c, self.__name) 197 | transfer.send(c, keys) 198 | c:close() 199 | end 200 | } 201 | 202 | RemoteLog.get = argcheck{ 203 | {name="self", type="tnt.Log"}, 204 | {name="key", type="string"}, 205 | call = 206 | function(self, key) 207 | local c = socket.connect(self.__ip, self.__port) 208 | transfer.send(c, "get") 209 | transfer.send(c, self.__name) 210 | transfer.send(c, key) 211 | local value = transfer.receive(c) 212 | c:close() 213 | return value 214 | end 215 | } 216 | 217 | RemoteLog.flush = argcheck{ 218 | {name="self", type="tnt.RemoteLog"}, 219 | call = 220 | function(self) 221 | local c = socket.connect(self.__ip, self.__port) 222 | transfer.send(c, "flush") 223 | transfer.send(c, self.__name) 224 | c:close() 225 | end 226 | } 227 | 228 | RemoteLog.attach = argcheck{ 229 | {name="self", type="tnt.RemoteLog"}, 230 | {name="event", type="string"}, 231 | {name="closures", type="table"}, 232 | call = 233 | function(self, event, closures) 234 | local c = socket.connect(self.__ip, 
self.__port) 235 | transfer.send(c, "attach") 236 | transfer.send(c, self.__name) 237 | transfer.send(c, event) 238 | transfer.send(c, closures) 239 | c:close() 240 | end 241 | } 242 | 243 | RemoteLog.close = argcheck{ 244 | {name="self", type="tnt.RemoteLog"}, 245 | call = 246 | function(self) 247 | local c = socket.connect(self.__ip, self.__port) 248 | transfer.send(c, "close") 249 | c:close() 250 | end 251 | } 252 | 253 | RemoteLog.server = argcheck{ 254 | {name="self", type="tnt.RemoteLog"}, 255 | call = 256 | function(self) 257 | return string.format("%s:%s", self.__ip, self.__port) 258 | end 259 | } 260 | 261 | RemoteLog.lognames = argcheck{ 262 | {name="self", type="tnt.RemoteLog"}, 263 | call = 264 | function(self) 265 | local c = socket.connect(self.__ip, self.__port) 266 | transfer.send(c, "lognames") 267 | local lognames = transfer.receive(c) 268 | c:close() 269 | return lognames 270 | end 271 | } 272 | 273 | -- GC Lua 5.2 274 | function RemoteLog:__gc() 275 | if self.__server then 276 | local c = socket.connect(self.__ip, self.__port) 277 | transfer.send(c, "close") 278 | c:close() 279 | self.__server:free() 280 | end 281 | end 282 | -------------------------------------------------------------------------------- /dataset/paralleldatasetiterator.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local Threads = require 'threads' 12 | local argcheck = require 'argcheck' 13 | local doc = require 'argcheck.doc' 14 | 15 | local ParallelDatasetIterator = torch.class('tnt.ParallelDatasetIterator', 'tnt.DatasetIterator', tnt) 16 | 17 | ParallelDatasetIterator.__init = argcheck{ 18 | doc = [[ 19 | 20 | ##### tnt.ParallelDatasetIterator(@ARGP) 21 | @ARGT 22 | 23 | Allows to iterate over a dataset in a threaded 24 | manner. `tnt.ParallelDatasetIterator:run()` guarantees that all samples 25 | will be seen, but does not guarantee the order unless `ordered` is set to true. 26 | 27 | The purpose of this class is to have a zero pre-processing cost. 28 | When reading datasets on the fly from 29 | disk (not loading them fully in memory), or performing complex 30 | pre-processing, this can be of interest. 31 | 32 | The number of threads used to parallelize is specified by `nthread`. 33 | 34 | `init(threadid)` (where threadid=1..nthread) is a closure which may 35 | initialize the specified thread as needed, if needed. It does nothing 36 | by default. 37 | 38 | `closure(threadid)` will be called on each thread and must return a 39 | `tnt.Dataset` instance. 40 | 41 | `perm(idx)` is a permutation used to shuffle the examples. If shuffling is 42 | needed, one can use this closure, or (better) use 43 | [tnt.ShuffleDataset](#ShuffleDataset) on the underlying dataset 44 | (returned by `closure()`). 45 | 46 | `filter(sample)` is a closure which returns `true` if the given sample 47 | should be considered or `false` if not. Note that filter is called _after_ 48 | fetching the data in a threaded manner. 49 | 50 | `transform(sample)` is a function which maps the given sample to a new value. 51 | This transformation occurs before filtering. 52 | 53 | When `ordered` is set to true the ordering of samples returned by the iterator 54 | is guaranteed. This option is particularly useful for repeatable experiments. 
55 | By default `ordered` is false, which means that order is not guaranteed by 56 | `run()` (though often the ordering is similar in practice). 57 | 58 | A common error raised by this dataset is when `closure()` is not 59 | serializable. Make sure that all [upvalues](http://www.lua.org/pil/27.3.3.html) of `closure()` are 60 | serializable. It is recommended to avoid [upvalues](http://www.lua.org/pil/27.3.3.html) at all cost, 61 | and to make sure you require all the appropriate torch packages needed to (de-)serialize 62 | `closure()` in the `init()` function. 63 | 64 | 65 | For more information, check out the [threads package](https://github.com/torch/threads), 66 | on which `tnt.ParallelDatasetIterator` relies. 67 | ]], 68 | {name='self', type='tnt.ParallelDatasetIterator'}, 69 | {name='init', type='function', default=function(idx) end}, 70 | {name='closure', type='function'}, 71 | {name='nthread', type='number'}, 72 | {name='perm', type='function', default=function(idx) return idx end}, 73 | {name='filter', type='function', default=function(sample) return true end}, 74 | {name='transform', type='function', default=function(sample) return sample end}, 75 | {name='ordered', type='boolean', default=false}, 76 | call = 77 | function(self, init, closure, nthread, perm, filter, transform, ordered) 78 | local function main(idx) 79 | gdataset = closure(idx) 80 | assert(torch.isTypeOf(gdataset, 'tnt.Dataset'), 81 | 'closure should return a Dataset class') 82 | return gdataset:size() 83 | end 84 | Threads.serialization('threads.sharedserialize') 85 | local threads, sizes = Threads(nthread, init, main) 86 | self.__threads = threads 87 | self.__nthread = nthread 88 | local size = sizes[1][1] 89 | local sample -- beware: do not put this line in loop() 90 | local sampleOrigIdx 91 | function self.run() 92 | -- loading size of the dataset each time run() is called 93 | threads:addjob( 94 | function() 95 | local size = gdataset:size() 96 | return size 97 | end, 98 | 
function(_size_) 99 | size = _size_ 100 | end 101 | ) 102 | threads:dojob() 103 | local idx = 1 104 | local function enqueue() 105 | while idx <= size and threads:acceptsjob() do 106 | threads:addjob( 107 | function(argList) 108 | local origIdx, idx = unpack(argList) 109 | local sample = gdataset:get(idx) 110 | collectgarbage() 111 | collectgarbage() 112 | return {sample, origIdx} 113 | end, 114 | function(argList) 115 | sample, sampleOrigIdx = unpack(argList) 116 | end, 117 | {idx, perm(idx)} 118 | ) 119 | idx = idx + 1 120 | end 121 | end 122 | 123 | enqueue() 124 | 125 | local iterFunction 126 | if ordered then 127 | local curSampleIdx = 1 128 | local storedSamples = {} 129 | -- `samplePlaceholder` stands in for samples which have been 130 | -- filtered out by the `filter` function 131 | local samplePlaceholder = {} 132 | 133 | -- Move past placeholders (filtered out samples) in 134 | -- `storedSamples` 135 | local function advancePastPlaceholders() 136 | while storedSamples[curSampleIdx] == samplePlaceholder do 137 | storedSamples[curSampleIdx] = nil 138 | curSampleIdx = curSampleIdx + 1 139 | end 140 | end 141 | 142 | iterFunction = function() 143 | advancePastPlaceholders() 144 | 145 | -- Load into storedSamples until we find the next sample in 146 | -- the sequence or we run out of samples 147 | while storedSamples[curSampleIdx] == nil and threads:hasjob() do 148 | enqueue() 149 | threads:dojob() 150 | if threads:haserror() then 151 | threads:synchronize() 152 | end 153 | enqueue() 154 | 155 | sample = transform(sample) 156 | if filter(sample) then 157 | -- Store sample 158 | storedSamples[sampleOrigIdx] = sample 159 | else 160 | -- Mark sample as "filtered out" 161 | storedSamples[sampleOrigIdx] = samplePlaceholder 162 | end 163 | 164 | advancePastPlaceholders() 165 | end 166 | 167 | enqueue() 168 | 169 | local curSample = storedSamples[curSampleIdx] 170 | storedSamples[curSampleIdx] = nil 171 | 172 | curSampleIdx = curSampleIdx + 1 173 | 174 | return 
curSample 175 | end 176 | else 177 | iterFunction = function() 178 | while threads:hasjob() do 179 | enqueue() 180 | threads:dojob() 181 | if threads:haserror() then 182 | threads:synchronize() 183 | end 184 | enqueue() 185 | sample = transform(sample) 186 | if filter(sample) then 187 | return sample 188 | end 189 | end 190 | end 191 | end 192 | 193 | return iterFunction 194 | end 195 | end 196 | } 197 | 198 | doc[[ 199 | 200 | ##### tnt.ParallelDatasetIterator.execSingle(tnt.DatasetIterator, name, ...) 201 | 202 | Execute the given method `name` on the dataset corresponding to the first 203 | available thread, passing it the subsequent arguments, and returns what the 204 | `name` method returns. 205 | 206 | For example: 207 | ```lua 208 | local iterator = tnt.ParallelDatasetIterator{...} 209 | print(iterator:execSingle("size")) 210 | ``` 211 | will print the size of the dataset loaded in the first available thread. 212 | ]] 213 | 214 | ParallelDatasetIterator.execSingle = 215 | function(self, name, ...) 216 | assert(not self.__threads:hasjob(), 'cannot execSingle during loop') 217 | local args = {...} 218 | local res 219 | self.__threads:addjob( 220 | function() 221 | return gdataset:exec(name, table.unpack(args)) 222 | end, 223 | function(...) 224 | res = {...} 225 | end) 226 | self.__threads:synchronize() 227 | return table.unpack(res) 228 | end 229 | 230 | doc[[ 231 | 232 | ##### tnt.ParallelDatasetIterator.exec(tnt.DatasetIterator, name, ...) 233 | 234 | Execute the given method `name` on the underlying datasets in each thread, 235 | passing to each of them the subsequent arguments, and returns a table 236 | of what the `name` method returns for each thread. 237 | 238 | For example: 239 | ```lua 240 | local iterator = tnt.ParallelDatasetIterator{...} 241 | for _, v in pairs(iterator:exec("size")) do 242 | print(v) 243 | end 244 | ``` 245 | will print the size of the datasets loaded in each thread. 
246 | ]] 247 | 248 | ParallelDatasetIterator.exec = 249 | function(self, name, ...) 250 | assert(not self.__threads:hasjob(), 'cannot exec during loop') 251 | local args = {...} 252 | local res = {} 253 | self.__threads:specific(true) 254 | for i=1,self.__nthread do 255 | self.__threads:addjob(i, 256 | function() 257 | return gdataset:exec(name, table.unpack(args)) 258 | end, 259 | function(...) 260 | local r = {...} 261 | res[i] = #r > 1 and r or r[1] 262 | end) 263 | end 264 | self.__threads:specific(false) 265 | return res 266 | end 267 | -------------------------------------------------------------------------------- /transform.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local utils = require 'torchnet.utils' 13 | local doc = require 'argcheck.doc' 14 | 15 | doc[[ 16 | 17 | ### tnt.transform 18 | 19 | *Torchnet* provides a set of general data transformations. 20 | These transformations are either directly on the data (e.g., normalization) 21 | or on their structure. This is particularly handy 22 | when manipulating [tnt.Dataset](#tnt.Dataset). 23 | 24 | Most of the transformations are simple but can be [composed](#transform.compose) or 25 | [merged](#transform.merge). 26 | ]] 27 | 28 | local transform = {} 29 | tnt.transform = transform 30 | local unpack = unpack or table.unpack 31 | 32 | 33 | doc[[ 34 | 35 | #### transform.identity(...) 36 | 37 | The identity transform takes any input and return it as it is. 
38 | 39 | For example, this function is useful when composing 40 | transformations on data from multiple sources, and some of the sources 41 | must not be transformed. 42 | ]] 43 | 44 | transform.identity = 45 | function(...) 46 | local args = {...} 47 | return function() 48 | return unpack(args) 49 | end 50 | end 51 | 52 | transform.compose = argcheck{ 53 | doc = [[ 54 | 55 | #### transform.compose(@ARGP) 56 | @ARGT 57 | 58 | This function takes a `table` of functions and 59 | composes them to return one transformation. 60 | 61 | This function assumes that the table of transformations 62 | is indexed by contiguous ordered keys starting at 1. 63 | The transformations are composed in the ascending order. 64 | 65 | For example, the following code: 66 | ```lua 67 | > f = transform.compose{ 68 | [1] = function(x) return 2*x end, 69 | [2] = function(x) return x + 10 end, 70 | foo = function(x) return x / 2 end, 71 | [4] = function(x) return x - x end 72 | } 73 | > f(3) 74 | 16 75 | ``` 76 | is equivalent to compose the transformations stored in [1] and [2], i.e., 77 | defining the following transformation: 78 | ```lua 79 | > f = function(x) return 2*x + 10 end ``` 80 | Note that transformations stored with keys `foo` and `4` are ignored. 81 | ``` 82 | ]], 83 | {name='transforms', type='table'}, 84 | call = 85 | function(transforms) 86 | for k,v in ipairs(transforms) do 87 | assert(type(v) == 'function', 'table of functions expected') 88 | end 89 | transforms = utils.table.copy(transforms) 90 | return 91 | function(z) 92 | for _, trans in ipairs(transforms) do 93 | z = trans(z) 94 | end 95 | return z 96 | end 97 | end 98 | } 99 | 100 | transform.merge = argcheck{ 101 | doc = [[ 102 | 103 | #### transform.merge(@ARGP) 104 | @ARGT 105 | 106 | This function takes a `table` of transformations and 107 | merge them into one transformation. 108 | Once apply to an input, this transformation will produce a `table` of output, 109 | containing the transformed input. 
110 | 111 | For example, the following code: 112 | ```lua 113 | > f = transform.merge{ 114 | [1] = function(x) return 2*x end, 115 | [2] = function(x) return x + 10 end, 116 | foo = function(x) return x / 2 end, 117 | [4] = function(x) return x - x end 118 | } 119 | ``` 120 | produces a function which applies a set of transformations to the same input: 121 | ```lua 122 | > f(3) 123 | { 124 | 1 : 6 125 | 2 : 13 126 | foo : 1.5 127 | 4 : 0 128 | } 129 | ``` 130 | ]], 131 | {name='transforms', type='table'}, 132 | call = 133 | function(transforms) 134 | for k,v in pairs(transforms) do 135 | assert(type(v) == 'function', 'table of functions expected') 136 | end 137 | transforms = utils.table.copy(transforms) 138 | return 139 | function(z) 140 | local newz = {} 141 | for k, trans in pairs(transforms) do 142 | newz[k] = trans(z) 143 | end 144 | return utils.table.mergetensor(newz) 145 | end 146 | end 147 | } 148 | 149 | transform.tablenew = argcheck{ 150 | doc = [[ 151 | 152 | #### transform.tablenew() 153 | 154 | This function creates a new table of functions from an 155 | existing table of functions. 156 | ]], 157 | call = 158 | function() 159 | return 160 | function(func) 161 | local tbl = {} 162 | for k,v in pairs(func) do 163 | tbl[k] = v 164 | end 165 | return tbl 166 | end 167 | end 168 | } 169 | 170 | transform.tableapply = argcheck{ 171 | doc = [[ 172 | 173 | #### transform.tableapply(@ARGP) 174 | @ARGT 175 | 176 | This function applies a transformation to a table of input. 177 | It return a table of output of the same size as the input. 
178 | 179 | For example, the following code: 180 | ```lua 181 | > f = transform.tableapply(function(x) return 2*x end) 182 | ``` 183 | produces a function which multiplies any input by 2: 184 | ```lua 185 | > f({[1] = 1, [2] = 2, foo = 3, [4] = 4}) 186 | { 187 | 1 : 2 188 | 2 : 4 189 | foo : 6 190 | 4 : 8 191 | } 192 | ``` 193 | ]], 194 | {name='transform', type='function'}, 195 | call = 196 | function(transform) 197 | return 198 | function(tbl) 199 | return utils.table.foreach(tbl, transform) 200 | end 201 | end 202 | } 203 | 204 | transform.tablemergekeys = argcheck{ 205 | doc = [[ 206 | 207 | #### transform.tablemergekeys() 208 | 209 | This function merges tables by key. More precisely, the input must be a 210 | `table` of `table` and this function will reverse the table orderto 211 | make the keys from the nested table accessible first. 212 | 213 | For example, if the input is: 214 | ```lua 215 | > x = { sample1 = {input = 1, target = "a"} , sample2 = {input = 2, target = "b", flag = "hard"} 216 | ``` 217 | Then apply this function will produce: 218 | ```lua 219 | > transform.tablemergekeys(x) 220 | { 221 | input : 222 | { 223 | sample1 : 1 224 | sample2 : 2 225 | } 226 | target : 227 | { 228 | sample1 : "a" 229 | sample2 : "b" 230 | } 231 | flag : 232 | { 233 | sample2: "hard" 234 | } 235 | } 236 | ``` 237 | ]], 238 | call = 239 | function() 240 | return 241 | function(tbl) 242 | local mergedtbl = {} 243 | for idx, elem in ipairs(tbl) do 244 | for key, value in pairs(elem) do 245 | if not mergedtbl[key] then 246 | mergedtbl[key] = {} 247 | end 248 | mergedtbl[key][idx] = value 249 | end 250 | end 251 | return mergedtbl 252 | end 253 | end 254 | } 255 | 256 | transform.makebatch = argcheck{ 257 | doc = [[ 258 | 259 | #### transform.makebatch(@ARGP) 260 | @ARGT 261 | 262 | This function is used in many `tnt.Dataset` to format 263 | samples in the format used by the `tnt.Engine`. 
264 | 265 | This function first [merges keys](#transform.tablemergekeys) to 266 | produces a table of output. Then, transform this table into a tensor by 267 | either using a `merge` transformation provided by the user or by 268 | simply concatenating the table into a tensor directly. 269 | 270 | This function uses the [compose](#transform.compose) transform to apply 271 | successive transformations. 272 | ]], 273 | {name='merge', type='function', opt=true}, 274 | call = 275 | function(merge) 276 | 277 | local makebatch 278 | if merge then 279 | makebatch = transform.compose{ 280 | transform.tablemergekeys(), 281 | merge 282 | } 283 | else 284 | makebatch = transform.compose{ 285 | transform.tablemergekeys(), 286 | transform.tableapply( 287 | function(field) 288 | if utils.table.canmergetensor(field) then 289 | return utils.table.mergetensor(field) 290 | else 291 | return field 292 | end 293 | end 294 | ) 295 | } 296 | end 297 | 298 | return 299 | function(samples) 300 | assert(type(samples) == 'table', 'makebatch: table of samples expected') 301 | return makebatch(samples) 302 | end 303 | end 304 | } 305 | 306 | transform.randperm = argcheck{ 307 | doc = [[ 308 | 309 | #### transform.perm(@ARGP) 310 | @ARGT 311 | 312 | This function create a vector containing a permutation of the indices from 1 to `size`. 313 | This vector is a `LongTensor` and `size` must be a number. 314 | 315 | Once the vector created, this function can be used to call a specific indices in it. 
316 | 317 | For example: 318 | ```lua 319 | > p = transform.perm(3) 320 | ``` 321 | creates a function `p` which contains a permutation of indices: 322 | ```lua 323 | > p(1) 324 | 2 325 | > p(2) 326 | 1 327 | > p(3) 328 | 3 329 | ``` 330 | ]], 331 | {name="size", type="number"}, 332 | call = 333 | function(size) 334 | local perm = torch.randperm(size, 'torch.LongTensor') 335 | return 336 | function(idx) 337 | return perm[idx] 338 | end 339 | end 340 | } 341 | 342 | transform.normalize = argcheck{ 343 | doc = [[ 344 | 345 | #### transform.normalize(@ARGP) 346 | @ARGT 347 | 348 | This function normalizes data, i.e., it removes its mean and 349 | devide it by its standard deviation. 350 | 351 | The input must be a `Tensor`. 352 | 353 | Once create, a `threshold` can be given (must be a number). Then, 354 | the data will be devided by their standard deviation, only if this 355 | deviation is greater than the `threshold`. This is handy, if the 356 | deviation is small and deviding by it could lead to unstability. 357 | ]], 358 | {name='threshold', type='number', default=0}, 359 | call = 360 | function(threshold) 361 | return 362 | function(z) 363 | local std = z:std() 364 | z:add(-z:mean()) 365 | if std > threshold then 366 | z:div(std) 367 | end 368 | return z 369 | end 370 | end 371 | } 372 | 373 | return transform 374 | -------------------------------------------------------------------------------- /dataset/indexeddataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local IndexedDataset, Dataset = torch.class('tnt.IndexedDataset', 'tnt.Dataset', tnt)
local IndexedDatasetReader = torch.class('tnt.IndexedDatasetReader', tnt)
local IndexedDatasetWriter = torch.class('tnt.IndexedDatasetWriter', tnt)

-- dataset built on index reader
IndexedDataset.__init = argcheck{
   doc = [[

#### tnt.IndexedDataset(@ARGP)
@ARGT

A `tnt.IndexedDataset()` is a data structure built upon (possibly several)
data archives containing a bunch of tensors of the same type.

See [tnt.IndexedDatasetWriter](#IndexedDatasetWriter) and
[tnt.IndexedDatasetReader](#IndexedDatasetReader) to see how to create and
read a single archive.

Purpose: large datasets (containing a lot of files) are often not very well
handled by filesystems (especially over network). `tnt.IndexedDataset`
provides a convenient and efficient way to bundle them into a single
archive file, associated with an indexed file.

If `path` is provided, then `fields` must be a Lua array (keys being
numbers), where values are string representing a filename prefix to a
(index,archive) pair. In other word `path/field.{idx,bin}` must exist. The
i-th sample returned by this dataset will be a table containing each field
as key, and a tensor found at the corresponding archive at index i.

If `path` is not provided, then `fields` must be a Lua hash. Each key
represents sample fields and the corresponding value must be a table
containing the keys `idx` (for the index filename path) and `bin` (for the
archive filename path).

If provided (and positive), `maxload` limits the dataset size to the
specified size.

Archives and/or indexes can also be memory mapped with the `mmap` and
`mmapidx` flags.

]],
   noordered = true,
   {name='self', type='tnt.IndexedDataset'},
   {name='fields', type='table'},
   {name='path', type='string', opt=true},
   {name='maxload', type='number', opt=true},
   {name='mmap', type='boolean', default=false},
   {name='mmapidx', type='boolean', default=false},
   call =
      function(self, fields, path, maxload, mmap, mmapidx)
         self.__fields = {}
         if path then
            -- array form: each field is a filename prefix under `path`
            for _, fieldname in ipairs(fields) do
               assert(type(fieldname) == 'string', 'fields should be a list of strings (fieldnames)')
               table.insert(
                  self.__fields, {
                     name = fieldname,
                     idx = string.format('%s/%s.idx', path, fieldname),
                     bin = string.format('%s/%s.bin', path, fieldname)})
            end
         else
            -- hash form: each field provides explicit idx/bin paths
            for fieldname, paths in pairs(fields) do
               assert(
                  type(fieldname) == 'string'
                     and type(paths.bin) == 'string'
                     and type(paths.idx) == 'string',
                  'fields should be a hash of string (fieldname) -> table (idx=string, bin=string)')
               table.insert(self.__fields, {
                  name = fieldname,
                  idx = paths.idx,
                  bin = paths.bin})
            end
         end
         assert(#self.__fields > 0, 'fields should not be empty')
         local size
         -- all fields must expose the same number of samples
         for _, field in ipairs(self.__fields) do
            field.data = tnt.IndexedDatasetReader(field.idx, field.bin, mmap, mmapidx)
            if size then
               assert(field.data:size() == size, 'inconsistent index data size')
            else
               size = field.data:size()
            end
         end
         if maxload and maxload > 0 and maxload < size then
            size = maxload
         end
         self.__size = size
         print(string.format("| IndexedDataset: loaded %s with %d examples", path or '', size))
      end
}

IndexedDataset.size = argcheck{
   {name='self', type='tnt.IndexedDataset'},
   call =
      function(self)
         return self.__size
      end
}

IndexedDataset.get = argcheck{
   {name='self', type='tnt.IndexedDataset'},
   {name='idx', type='number'},
   call =
      function(self, idx)
         assert(idx >= 1 and idx <= self.__size, 'index out of bound')
         -- one tensor per field, gathered into a sample table
         local sample = {}
         for _, field in ipairs(self.__fields) do
            sample[field.name] = field.data:get(idx)
         end
         return sample
      end
}

-- supported tensor types
-- The `code` (array position) is what gets written in index files, so the
-- order of this list is part of the on-disk format and must not change.
local IndexedDatasetIndexTypes = {}
for code, type in ipairs{'byte', 'char', 'short', 'int', 'long', 'float', 'double', 'table'} do
   -- 'table' entries are serialized into CharTensors (see Writer.add below)
   local Type = (type == 'table') and 'Char' or type:sub(1,1):upper() .. type:sub(2)
   IndexedDatasetIndexTypes[type] = {
      code = code,
      name = string.format('torch.%sTensor', Type),
      read = string.format('read%s', Type),
      write = string.format('write%s', Type),
      size = torch[string.format('%sStorage', Type)].elementSize(),
      storage = torch[string.format('%sStorage', Type)],
      tensor = torch[string.format('%sTensor', Type)]
   }
end

-- index reader helper function: parses an index file written by
-- IndexedDatasetWriter.close() and fills self with type/N/S/offsets/sizes
local function readindex(self, indexfilename)
   local f = torch.DiskFile(indexfilename):binary()
   assert(f:readLong() == 0x584449544E54, "unrecognized index format") -- "TNTIDX" magic
   assert(f:readLong() == 1, "unsupported format version")
   local code = f:readLong()
   for typename, type in pairs(IndexedDatasetIndexTypes) do
      if type.code == code then
         self.type = type
         -- BUGFIX: record the type *name* as well -- __write()/__read() rely
         -- on self.typename to restore self.type after (de)serialization;
         -- previously only the mmapidx path of IndexedDatasetReader set it.
         self.typename = typename
      end
   end
   assert(self.type, "unrecognized type")
   assert(f:readLong() == self.type.size, "type size do not match")
   self.N = f:readLong()              -- number of tensors in the archive
   self.S = f:readLong()              -- total number of dimension sizes
   self.dimoffsets = torch.LongTensor(f:readLong(self.N+1))
   self.datoffsets = torch.LongTensor(f:readLong(self.N+1))
   self.sizes = torch.LongTensor(f:readLong(self.S))
   f:close()
end

-- index writer
IndexedDatasetWriter.__init = argcheck{
   doc = [[

##### tnt.IndexedDatasetWriter(@ARGP)
@ARGT

Creates a (archive,index) file pair. The archive will contain tensors of the same specified `type`.

`type` must be a string chosen in {`byte`, `char`, `short`, `int`, `long`, `float`, `double` or `table`}.

`indexfilename` is the full path to the index file to be created.
`datafilename` is the full path to the data archive file to be created.

Tensors are added to the archive with [add()](#IndexedDataset.add).

Note that you *must* call [close()](#IndexedDataset.close) to ensure all
data is written on disk and to create the index file.

The type `table` is special: data will be stored into a CharTensor,
serialized from a Lua table object. IndexedDatasetReader will then
deserialize the CharTensor into a table at read time. This allows storing
heterogenous data easily into an IndexedDataset.

]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='indexfilename', type='string'},
   {name='datafilename', type='string'},
   {name='type', type='string'},
   call =
      function(self, indexfilename, datafilename, type)
         self.BLOCKSZ = 1024 -- growth increment for the offset/size buffers
         self.indexfilename = indexfilename
         self.datafilename = datafilename
         assert(IndexedDatasetIndexTypes[type], 'invalid type (byte, char, short, int, long, float, double or table expected)')
         self.dimoffsets = torch.LongTensor(self.BLOCKSZ)
         self.datoffsets = torch.LongTensor(self.BLOCKSZ)
         self.sizes = torch.LongTensor(self.BLOCKSZ)
         self.N = 0 -- tensors written so far
         self.S = 0 -- dimension sizes written so far
         self.dimoffsets[1] = 0
         self.datoffsets[1] = 0
         self.type = IndexedDatasetIndexTypes[type]
         self.datafile = torch.DiskFile(datafilename, 'w'):binary()
      end
}

-- append mode
IndexedDatasetWriter.__init = argcheck{
   doc = [[
##### tnt.IndexedDatasetWriter(@ARGP)
@ARGT

Opens an existing (archive,index) file pair for appending. The tensor type is inferred from the provided
index file.

`indexfilename` is the full path to the index file to be opened.
`datafilename` is the full path to the data archive file to be opened.

]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='indexfilename', type='string'},
   {name='datafilename', type='string'},
   overload = IndexedDatasetWriter.__init,
   call =
      function(self, indexfilename, datafilename)
         self.BLOCKSZ = 1024
         self.indexfilename = indexfilename
         self.datafilename = datafilename
         -- restore type/N/S/offsets from the existing index
         readindex(self, indexfilename)
         self.datafile = torch.DiskFile(datafilename, 'rw'):binary()
         self.datafile:seekEnd()
      end
}

IndexedDatasetWriter.add = argcheck{
   doc = [[

###### tnt.IndexedDatasetWriter.add(@ARGP)
@ARGT

Add a tensor to the archive and record its index position. The tensor type must of the same type
than the one specified at the creation of the `tnt.IndexedDatasetWriter`.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='tensor', type='torch.*Tensor'},
   call =
      function(self, tensor)
         assert(torch.typename(tensor) == self.type.name, 'invalid tensor type')
         local size = tensor:size()
         local dim = size:size()
         local N = self.N + 1
         local S = self.S + dim
         -- grow the index buffers in BLOCKSZ chunks as needed
         if self.dimoffsets:size(1) < N+1 then -- +1 for the first 0 value
            self.dimoffsets:resize(N+self.BLOCKSZ)
            self.datoffsets:resize(N+self.BLOCKSZ)
         end
         if self.sizes:size(1) < S then
            self.sizes:resize(S+self.BLOCKSZ)
         end
         -- cumulative offsets: entry i+1 marks the end of tensor i
         self.dimoffsets[N+1] = self.dimoffsets[N] + dim
         self.datoffsets[N+1] = self.datoffsets[N] + tensor:nElement()
         if dim > 0 then
            self.sizes:narrow(1, self.S+1, dim):copy(torch.LongTensor(size))
         end
         self.N = N
         self.S = S
         if tensor:nElement() > 0 then
            -- clone() guarantees a contiguous storage holding exactly this
            -- tensor's elements before writing
            self.datafile[self.type.write](self.datafile, tensor:clone():storage())
         end
      end
}

IndexedDatasetWriter.add = argcheck{
   doc = [[
###### tnt.IndexedDatasetWriter.add(@ARGP)
@ARGT

Convenience method which given a `filename` will open the corresponding
file in `binary` mode, and reads all data in there as if it was of the type
specified at the `tnt.IndexedDatasetWriter` construction. A corresponding
tensor is then added to the archive/index pair.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='filename', type='string'},
   overload = IndexedDatasetWriter.add,
   call =
      function(self, filename)
         local f = torch.DiskFile(filename):binary()
         -- compute the file size in bytes to size the read
         f:seekEnd()
         local sz = f:position()-1
         f:seek(1)
         local storage = f[self.type.read](f, sz/self.type.size)
         f:close()
         self:add(self.type.tensor(storage))
      end
}

IndexedDatasetWriter.add = argcheck{
   doc = [[
###### tnt.IndexedDatasetWriter.add(@ARGP)
@ARGT

Convenience method only available for `table` type IndexedDataset.
The table will be serialized into a CharTensor.

]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='table', type='table'},
   nonamed = true, -- ambiguity possible with table arg
   overload = IndexedDatasetWriter.add,
   call =
      function(self, tbl)
         assert(
            self.type == IndexedDatasetIndexTypes.table
               or self.type == IndexedDatasetIndexTypes.char,
            'table convenience method is only available for "table" or "char"-based datasets')
         tbl = torch.CharTensor(torch.serializeToStorage(tbl))
         self:add(tbl)
      end
}

IndexedDatasetWriter.close = argcheck{
   doc = [[
###### tnt.IndexedDatasetWriter.close(@ARGP)
@ARGT

Finalize the index, and close the archive/index filename pair. This method
must be called to ensure the index is written and all the archive data is
flushed on disk.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   call =
      function(self)
         local f = torch.DiskFile(self.indexfilename, 'w'):binary()
         f:writeLong(0x584449544E54) -- magic number ("TNTIDX")
         f:writeLong(1) -- version
         f:writeLong(self.type.code) -- type code
         f:writeLong(self.type.size) -- element size in bytes
         f:writeLong(self.N)
         f:writeLong(self.S)

         -- resize properly underlying storages (drop unused BLOCKSZ slack)
         self.dimoffsets = torch.LongTensor( self.dimoffsets:storage():resize(self.N+1) )
         self.datoffsets = torch.LongTensor( self.datoffsets:storage():resize(self.N+1) )
         self.sizes = torch.LongTensor( self.sizes:storage():resize(self.S) )

         -- write index on disk
         f:writeLong(self.dimoffsets:storage())
         f:writeLong(self.datoffsets:storage())
         f:writeLong(self.sizes:storage())
         f:close()
         self.datafile:close()
      end
}

-- helper function that updates the meta table when data is on file:
-- gives `data` a :narrow(dim, offset, size) method which reads the requested
-- element range straight from disk (dim is ignored; archives are flat).
local function updatemetatable(data, datafilename)
   local data_mt = {}
   local f = torch.DiskFile(datafilename):binary():noBuffer()
   function data_mt:narrow(dim, offset, size)
      f:seek((offset - 1) * self.type.size + 1)
      return self.type.tensor(f[self.type.read](f, size))
   end
   setmetatable(data, {__index = data_mt})
   return data
end

-- index reader
IndexedDatasetReader.__init = argcheck{
   doc = [[

##### tnt.IndexedDatasetReader(@ARGP)
@ARGT

Reads an archive/index pair previously created by
[tnt.IndexedDatasetWriter](#IndexedDatasetWriter).

`indexfilename` is the full path to the index file.
`datafilename` is the full path to the archive file.

Memory mapping can be specified for both the archive and index through the
optional `mmap` and `mmapidx` flags.

]],
   {name='self', type='tnt.IndexedDatasetReader'},
   {name='indexfilename', type='string'},
   {name='datafilename', type='string'},
   {name='mmap', type='boolean', default=false},
   {name='mmapidx', type='boolean', default=false},
   call =
      function(self, indexfilename, datafilename, mmap, mmapidx)
         self.indexfilename = indexfilename
         self.datafilename = datafilename

         if mmapidx then -- memory mapped index
            -- walk the mapped LongStorage manually; layout must match
            -- readindex()/IndexedDatasetWriter.close()
            local idx = torch.LongStorage(indexfilename)
            local offset = 1
            assert(idx[offset] == 0x584449544E54, "unrecognized index format")
            offset = offset + 1
            assert(idx[offset] == 1, "unsupported format version")
            offset = offset + 1
            local code = idx[offset]
            offset = offset + 1
            for typename, type in pairs(IndexedDatasetIndexTypes) do
               if type.code == code then
                  self.type = type
                  self.typename = typename
               end
            end
            assert(self.type, "unrecognized type")
            assert(idx[offset] == self.type.size, "type size do not match")
            offset = offset + 1
            self.N = idx[offset]
            offset = offset + 1
            self.S = idx[offset]
            offset = offset + 1
            self.dimoffsets = torch.LongTensor(idx, offset, self.N+1)
            offset = offset + self.N+1
            self.datoffsets = torch.LongTensor(idx, offset, self.N+1)
            offset = offset + self.N+1
            self.sizes = torch.LongTensor(idx, offset, self.S)
         else -- index on file
            readindex(self, indexfilename)
         end

         if mmap then -- memory mapped data
            self.data = self.type.tensor(self.type.storage(datafilename))
         else -- data on file
            local data = {type=self.type}
            self.data = updatemetatable(data, datafilename)
         end
      end
}

function IndexedDatasetReader:__write(file)
   -- self.type holds torch constructors which should not be serialized;
   -- it is reconstructed from self.typename in __read()
   local obj = {}
   for k,v in pairs(self) do obj[k] = v end
   obj.type = nil
   if type(self.data) == 'table' then obj.data.type = nil end
   file:writeObject(obj)
   -- restore the field we nil-ed on the live object's data table
   if type(self.data) == 'table' then obj.data.type = self.type end
end

function IndexedDatasetReader:__read(file)
   for k,v in pairs(file:readObject()) do
      self[k] = v
   end
   self.type = IndexedDatasetIndexTypes[self.typename]
   if type(self.data) == 'table' then
      self.data.type = self.type
      -- re-open the on-file data accessor (file handles do not serialize)
      updatemetatable(self.data, self.datafilename)
   end
end

IndexedDatasetReader.size = argcheck{
   doc = [[

###### tnt.IndexedDatasetReader.size(@ARGP)

Returns the number of tensors present in the archive.
]],
   {name='self', type='tnt.IndexedDatasetReader'},
   call =
      function(self)
         return self.N
      end
}

IndexedDatasetReader.get = argcheck{
   doc = [[

###### tnt.IndexedDatasetReader.get(@ARGP)

Returns the tensor at the specified `index` in the archive.
]],
   {name='self', type='tnt.IndexedDatasetReader'},
   {name='index', type='number'},
   call =
      function(self, index)
         assert(index > 0 and index <= self.N, 'index out of range')
         local ndim = self.dimoffsets[index+1]-self.dimoffsets[index]
         if ndim == 0 then
            -- empty tensor was stored at this slot
            return self.type.tensor()
         end
         -- per-dimension sizes of the stored tensor
         local size = self.sizes:narrow(
            1,
            self.dimoffsets[index]+1,
            ndim
         )
         size = size:clone():storage()
         -- slice the flat archive and restore the original shape
         local data = self.data:narrow(
            1,
            self.datoffsets[index]+1,
            self.datoffsets[index+1]-self.datoffsets[index]
         ):view(size)
         if self.type == IndexedDatasetIndexTypes.table then
            return torch.deserializeFromStorage(data)
         else
            -- clone so the caller does not alias mmap-ed/underlying memory
            return data:clone()
         end
      end
}
--------------------------------------------------------------------------------