├── .travis.yml ├── CMakeLists.txt ├── CONTRIBUTING.md ├── LICENSE ├── PATENTS ├── README.md ├── dataset ├── batchdataset.lua ├── concatdataset.lua ├── coroutinebatchdataset.lua ├── datasetiterator.lua ├── indexeddataset.lua ├── init.lua ├── listdataset.lua ├── paralleldatasetiterator.lua ├── resampledataset.lua ├── shuffledataset.lua ├── splitdataset.lua ├── tabledataset.lua └── transformdataset.lua ├── engine ├── init.lua ├── optimengine.lua └── sgdengine.lua ├── env.lua ├── example └── mnist.lua ├── init.lua ├── log ├── init.lua └── view │ ├── json.lua │ ├── status.lua │ └── text.lua ├── meter ├── apmeter.lua ├── aucmeter.lua ├── averagevaluemeter.lua ├── classerrormeter.lua ├── confusionmeter.lua ├── init.lua ├── mapmeter.lua ├── movingaveragevaluemeter.lua ├── msemeter.lua ├── multilabelconfusionmeter.lua ├── ndcgmeter.lua ├── precisionatkmeter.lua ├── precisionmeter.lua ├── recallmeter.lua └── timemeter.lua ├── rocks └── torchnet-scm-1.rockspec ├── test ├── datasets.lua ├── iterators.lua ├── meters.lua └── test.lua ├── transform.lua └── utils ├── init.lua ├── nn.lua ├── sys.lua └── table.lua /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | os: 3 | - linux 4 | compiler: 5 | - clang 6 | cache: 7 | directories: 8 | - $HOME/OpenBlasInstall 9 | sudo: false 10 | env: 11 | - TORCH_LUA_VERSION=LUAJIT21 12 | - TORCH_LUA_VERSION=LUA52 13 | addons: 14 | apt: 15 | packages: 16 | - cmake 17 | - gfortran 18 | - gcc-multilib 19 | - gfortran-multilib 20 | - liblapack-dev 21 | - build-essential 22 | - gcc 23 | - g++ 24 | - curl 25 | - cmake 26 | - libreadline-dev 27 | - git-core 28 | - libqt4-core 29 | - libqt4-gui 30 | - libqt4-dev 31 | - libjpeg-dev 32 | - libpng-dev 33 | - ncurses-dev 34 | - imagemagick 35 | - libzmq3-dev 36 | - gfortran 37 | - unzip 38 | - gnuplot 39 | - gnuplot-x11 40 | before_script: 41 | - export ROOT_TRAVIS_DIR=$(pwd) 42 | - export INSTALL_PREFIX=~/torch/install 43 | - ls 
$HOME/OpenBlasInstall/lib || (cd /tmp/ && git clone https://github.com/xianyi/OpenBLAS.git -b master && cd OpenBLAS && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make PREFIX=$HOME/OpenBlasInstall install) 44 | - git clone https://github.com/torch/distro.git ~/torch --recursive 45 | - cd ~/torch && git submodule update --init --recursive 46 | - mkdir build && cd build 47 | - export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH 48 | - cmake .. -DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON 49 | - make && make install 50 | - cd $ROOT_TRAVIS_DIR 51 | - export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH 52 | - ${INSTALL_PREFIX}/bin/luarocks install torch 53 | script: 54 | - ${INSTALL_PREFIX}/bin/luarocks make rocks/torchnet-scm-1.rockspec 55 | - export PATH=${INSTALL_PREFIX}/bin:$PATH 56 | - export TESTLUA=$(which luajit lua | head -n 1) 57 | - ${TESTLUA} -e "tnt = require 'torchnet'; t=tnt.test(); if t.errors[1] then os.exit(1) end" 58 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 2.8) 2 | cmake_policy(VERSION 2.8) 3 | 4 | set(PKGNAME torchnet) 5 | 6 | file(GLOB_RECURSE luafiles RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.lua") 7 | 8 | foreach(file ${luafiles}) 9 | get_filename_component(dir ${file} PATH) 10 | install(FILES ${file} DESTINATION ${LUA_PATH}/${PKGNAME}/${dir}) 11 | endforeach() 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Torchnet 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 
4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Make sure your code lints. 12 | 5. If you haven't already, complete the Contributor License Agreement ("CLA"). 13 | 14 | ## Contributor License Agreement ("CLA") 15 | In order to accept your pull request, we need you to submit a CLA. You only need 16 | to do this once to work on any of Facebook's open source projects. 17 | 18 | Complete your CLA here: 19 | 20 | ## Issues 21 | We use GitHub issues to track public bugs. Please ensure your description is 22 | clear and has sufficient instructions to be able to reproduce the issue. 23 | 24 | ## Coding Style 25 | * 3 spaces for indentation rather than tabs 26 | * 80 character line length 27 | * variables names all lower-case, no underlines 28 | 29 | ## License 30 | By contributing to Torchnet, you agree that your contributions will be licensed 31 | under its BSD license. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For Torchnet software 4 | 5 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 
16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the Torchnet software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 
13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /dataset/batchdataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local transform = require 'torchnet.transform' 13 | 14 | local BatchDataset = 15 | torch.class('tnt.BatchDataset', 'tnt.Dataset', tnt) 16 | 17 | BatchDataset.__init = argcheck{ 18 | doc = [[ 19 | 20 | #### tnt.BatchDataset(@ARGP) 21 | @ARGT 22 | 23 | Given a `dataset`, `tnt.BatchDataset` merges samples from this dataset to 24 | form a new sample which can be interpreted as a batch (of size 25 | `batchsize`). 26 | 27 | The `merge` function controls how the batch is performed. It is a closure 28 | taking a Lua array as input containing all occurrences (for a given batch) 29 | of a field of the sample, and returning the aggregated version of these 30 | occurrences. By default the occurrences are supposed to be tensors, and 31 | they aggregated along the first dimension. 32 | 33 | More formally, if the i-th sample of the underlying dataset is written as: 34 | ```lua 35 | {input=, target=} 36 | ``` 37 | assuming only two fields `input` and `target` in the sample, then `merge()` 38 | will be passed tables of the form: 39 | ```lua 40 | {, , ... } 41 | ``` 42 | or 43 | ```lua 44 | {, , ... } 45 | ``` 46 | with `n` being the batch size. 47 | 48 | It is often important to shuffle examples while performing the batch 49 | operation. `perm(idx, size)` is a closure which returns the shuffled index 50 | of the sample at position `idx` in the underlying dataset. For convenience, 51 | the `size` of the underlying dataset is also passed to the closure. By 52 | default, the closure is the identity. 53 | 54 | The underlying dataset size might or might not be always divisible by 55 | `batchsize`. The optional `policy` string specify how to handle corner 56 | cases: 57 | - `include-last` makes sure all samples of the underlying dataset will be seen, batches will be of size equal or inferior to `batchsize`. 
58 | - `skip-last` will skip last examples of the underlying dataset if its size is not properly divisible. Batches will be always of size equal to `batchsize`. 59 | - `divisible-only` will raise an error if the underlying dataset has not a size divisible by `batchsize`. 60 | 61 | Purpose: the concept of batch is problem dependent. In *torchnet*, it is up 62 | to the user to interpret a sample as a batch or not. When one wants to 63 | assemble samples from an existing dataset into a batch, then 64 | `tnt.BatchDataset` is suited for the job. Sometimes it is however more 65 | convenient to write a dataset from scratch providing "batched" samples. 66 | ]], 67 | {name='self', type='tnt.BatchDataset'}, 68 | {name='dataset', type='tnt.Dataset'}, 69 | {name='batchsize', type='number'}, 70 | {name='perm', type='function', default=function(idx, size) return idx end}, 71 | {name='merge', type='function', opt=true}, 72 | {name='policy', type='string', default='include-last'}, 73 | {name='filter', type='function', default=function(sample) return true end}, 74 | call = 75 | function(self, dataset, batchsize, perm, merge, policy, filter) 76 | assert(batchsize > 0 and math.floor(batchsize) == batchsize, 77 | 'batchsize should be a positive integer number') 78 | self.dataset = dataset 79 | self.perm = perm 80 | self.batchsize = batchsize 81 | self.makebatch = transform.makebatch{merge=merge} 82 | self.policy = policy 83 | self.filter = filter 84 | self:size() -- check policy 85 | end 86 | } 87 | 88 | BatchDataset.size = argcheck{ 89 | {name='self', type='tnt.BatchDataset'}, 90 | call = 91 | function(self) 92 | local policy = self.policy 93 | if policy == 'include-last' then 94 | return math.ceil(self.dataset:size()/self.batchsize) 95 | elseif policy == 'skip-last' then 96 | return math.floor(self.dataset:size()/self.batchsize) 97 | elseif policy == 'divisible-only' then 98 | assert(self.dataset:size() % self.batchsize == 0, 'dataset size is not divisible by batch size') 99 | 
return self.dataset:size()/self.batchsize 100 | else 101 | error('invalid policy (include-last | skip-last | divisible-only expected)') 102 | end 103 | end 104 | } 105 | 106 | BatchDataset.get = argcheck{ 107 | {name='self', type='tnt.BatchDataset'}, 108 | {name='idx', type='number'}, 109 | call = 110 | function(self, idx) 111 | assert(idx >= 1 and idx <= self:size(), 'index out of bound') 112 | local samples = {} 113 | local maxidx = self.dataset:size() 114 | for i=1,self.batchsize do 115 | local idx = (idx - 1)*self.batchsize + i 116 | if idx > maxidx then 117 | break 118 | end 119 | idx = self.perm(idx, maxidx) 120 | local sample = self.dataset:get(idx) 121 | if self.filter(sample) then table.insert(samples, sample) end 122 | end 123 | samples = self.makebatch(samples) 124 | collectgarbage() 125 | collectgarbage() 126 | return samples 127 | end 128 | } 129 | -------------------------------------------------------------------------------- /dataset/concatdataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local ConcatDataset = 14 | torch.class('tnt.ConcatDataset', 'tnt.Dataset', tnt) 15 | 16 | ConcatDataset.__init = argcheck{ 17 | doc = [[ 18 | 19 | #### tnt.ConcatDataset(@ARGP) 20 | @ARGT 21 | 22 | Given a Lua array (`datasets`) of [tnt.Dataset](#Dataset), concatenates 23 | them into a single dataset. The size of the new dataset is the sum of the 24 | underlying dataset sizes. 
25 | 26 | Purpose: useful to assemble different existing datasets, possibly 27 | large-scale datasets as the concatenation operation is done in an 28 | on-the-fly manner. 29 | ]], 30 | noordered=true, 31 | {name='self', type='tnt.ConcatDataset'}, 32 | {name='datasets', type='table'}, 33 | call = 34 | function(self, datasets) 35 | assert(#datasets > 0, 'datasets should not be an empty table') 36 | local indices = torch.LongTensor(#datasets, 2) -- indices: begin, end 37 | local size = 0 38 | for i, dataset in ipairs(datasets) do 39 | assert(torch.isTypeOf(dataset, 'tnt.Dataset'), 40 | 'each member of datasets table should be a tnt.Dataset') 41 | indices[i][1] = size+1 42 | size = size + dataset:size() 43 | indices[i][2] = size 44 | end 45 | self.__datasets = datasets 46 | self.__indices = indices 47 | self.__size = size 48 | end 49 | } 50 | 51 | ConcatDataset.size = argcheck{ 52 | {name='self', type='tnt.ConcatDataset'}, 53 | call = 54 | function(self) 55 | return self.__size 56 | end 57 | } 58 | 59 | ConcatDataset.get = argcheck{ 60 | {name='self', type='tnt.ConcatDataset'}, 61 | {name='idx', type='number'}, 62 | call = 63 | function(self, idx) 64 | assert(idx >= 1 and idx <= self.__size, 'index out of bound') 65 | local indices = self.__indices 66 | local l, r = 1, indices:size(1) 67 | while l ~= r do 68 | local m = math.floor((r-l)/2) + l 69 | if l == m then 70 | if idx > indices[l][2] then 71 | l, r = r, r 72 | else 73 | l, r = l, l 74 | end 75 | else 76 | if idx > indices[m][2] then 77 | l, r = m, r 78 | elseif idx < indices[m][1] then 79 | l, r = l, m 80 | else 81 | l, r = m, m 82 | end 83 | end 84 | end 85 | return self.__datasets[l]:get(idx-self.__indices[l][1]+1) 86 | end 87 | } 88 | -------------------------------------------------------------------------------- /dataset/coroutinebatchdataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 
4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local CoroutineBatchDataset, BatchDataset = 14 | torch.class('tnt.CoroutineBatchDataset', 'tnt.BatchDataset', tnt) 15 | 16 | CoroutineBatchDataset.__init = argcheck{ 17 | doc = [[ 18 | 19 | #### tnt.CoroutineBatchDataset(@ARGP) 20 | @ARGT 21 | 22 | Given a `dataset`, `tnt.CoroutineBatchDataset` merges samples from this dataset 23 | to form a new sample which can be interpreted as a batch (of size `batchsize`). 24 | 25 | It behaves the same and has the same arguments as `tnt.BatchDataset` (see the 26 | documentation there for additional details), with one important distinction: 27 | it allows the underlying dataset to postpone returning the individual samples 28 | once by doing a call to `coroutine.yield()` (from the underlying dataset). 29 | 30 | This is useful when using datasets that are inefficient or slow when they need 31 | to provide the required sample immediately after a call to `dataset:get()`. The 32 | general pattern of code in the underlying `dataset:get()` would be: 33 | 34 | ```lua 35 | FooDataset.get = function(self, idx) 36 | prepare(idx) -- stores sample in self.__data[idx] 37 | coroutine.yield() 38 | return self.__data[idx] 39 | end 40 | ``` 41 | 42 | Herein, the function `prepare(idx)` can implement, for instance, a buffering of 43 | indices before actually fetching them. 
44 | ]], 45 | {name = 'self', type = 'tnt.CoroutineBatchDataset'}, 46 | {name = 'dataset', type = 'tnt.Dataset'}, 47 | {name = 'batchsize', type = 'number'}, 48 | {name = 'perm', type = 'function', default = function(idx, size) return idx end}, 49 | {name = 'merge', type = 'function', opt = true}, 50 | {name = 'policy', type = 'string', default = 'include-last'}, 51 | {name='filter', type='function', default=function(sample) return true end}, 52 | call = function(self, dataset, batchsize, perm, merge, policy, filter) 53 | BatchDataset.__init(self, dataset, batchsize, perm, merge, policy, filter) 54 | end 55 | } 56 | 57 | CoroutineBatchDataset.get = argcheck{ 58 | {name = 'self', type = 'tnt.CoroutineBatchDataset'}, 59 | {name = 'idx', type = 'number'}, 60 | call = function(self, idx) 61 | assert(idx >= 1 and idx <= self:size(), 'index out of bound') 62 | assert(idx == math.floor(idx), 'index should be integer value') 63 | 64 | -- create and start coroutines that perform get(): 65 | local crs, samples, maxidx = {}, {}, self.dataset:size() 66 | for n = 1,self.batchsize do 67 | local idx = (idx - 1) * self.batchsize + n 68 | if idx > maxidx then break end 69 | 70 | -- start coroutine: 71 | crs[n] = coroutine.create( 72 | function() return self.dataset:get(self.perm(idx)) end 73 | ) -- create coroutine that gets example 74 | local status, sample = coroutine.resume(crs[n]) -- register sample 75 | if not status and 76 | string.format('%s', sample) == 'not enough memory' then 77 | collectgarbage() 78 | collectgarbage() 79 | status, sample = coroutine.resume(crs[n]) -- register sample 80 | end 81 | 82 | if not status then 83 | error(string.format('dataset threw error: %s', sample)) 84 | end 85 | 86 | -- if coroutine does not yield but dies, store sample: 87 | if coroutine.status(crs[n]) == 'dead' then samples[n] = sample end 88 | end 89 | 90 | -- get the samples from coroutines that are suspended: 91 | for n = 1,self.batchsize do 92 | if crs[n] and coroutine.status(crs[n]) 
== 'suspended' then 93 | local status, sample = coroutine.resume(crs[n]) 94 | if not status then 95 | error(string.format('dataset threw error: %s', sample)) 96 | end 97 | assert(coroutine.status(crs[n]) == 'dead', 'coroutine did not die') 98 | samples[n] = sample 99 | end 100 | end 101 | 102 | -- filter the samples: 103 | local filtered = {} 104 | for n = 1,self.batchsize do 105 | if self.filter(samples[n]) then table.insert(filtered, samples[n]) end 106 | end 107 | 108 | -- return batch: 109 | samples = self.makebatch(filtered) 110 | collectgarbage() 111 | return samples 112 | end 113 | } 114 | -------------------------------------------------------------------------------- /dataset/datasetiterator.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | doc[[ 15 | ### Dataset Iterators 16 | 17 | It is easy to iterate over datasets using a for loop. However, sometimes 18 | one wants to filter out samples in a on-the-fly manner or thread sample fetching. 19 | 20 | Iterators are here for this particular cases. In general, refrain from writing 21 | iterators for handling custom cases, and write instead a `tnt.Dataset` 22 | 23 | Iterators implement two methods: 24 | 25 | * `run()` which returns a Lua iterator usable in a for loop. 26 | * `exec(funcname, ...)` which execute a given funcname on the underlying dataset. 
27 | 28 | Typical usage is achieved with a for loop: 29 | ```lua 30 | for sample in iterator:run() do 31 | 32 | end 33 | ``` 34 | 35 | Iterators implement the `__call` event, so one might also use the `()` operator: 36 | ```lua 37 | for sample in iterator() do 38 | 39 | end 40 | ``` 41 | 42 | ]] 43 | 44 | local DatasetIterator = torch.class('tnt.DatasetIterator', tnt) 45 | 46 | -- iterate over a dataset 47 | DatasetIterator.__init = argcheck{ 48 | doc = [[ 49 | 50 | #### tnt.DatasetIterator(@ARGP) 51 | @ARGT 52 | 53 | The default dataset iterator. 54 | 55 | `perm(idx)` is a permutation used to shuffle the examples. If shuffling 56 | is needed, one can use this closure, or (better) use 57 | [tnt.ShuffleDataset](#ShuffleDataset) on the underlying dataset. 58 | 59 | `filter(sample)` is a closure which returns `true` if the given sample 60 | should be considered or `false` if not. 61 | 62 | `transform(sample)` is a closure which can perform online transformation of 63 | samples. It returns a modified version of the given `sample`. It is the 64 | identity by default. It is often more interesting to use 65 | [tnt.TransformDataset](#TransformDataset) for that purpose. 
66 | ]], 67 | {name='self', type='tnt.DatasetIterator'}, 68 | {name='dataset', type='tnt.Dataset'}, 69 | {name='perm', type='function', default=function(idx) return idx end}, 70 | {name='filter', type='function', default=function(sample) return true end}, 71 | {name='transform', type='function', default=function(sample) return sample end}, 72 | call = 73 | function(self, dataset, perm, filter, transform) 74 | self.dataset = dataset 75 | function self.run() 76 | local size = dataset:size() 77 | local idx = 1 78 | return 79 | function() 80 | while idx <= size do 81 | local sample = transform(dataset:get(perm(idx))) 82 | idx = idx + 1 83 | if filter(sample) then 84 | return sample 85 | end 86 | end 87 | end 88 | end 89 | end 90 | } 91 | 92 | -- iterates from another iterator 93 | DatasetIterator.__init = argcheck{ 94 | {name='self', type='tnt.DatasetIterator'}, 95 | {name='iterator', type='tnt.DatasetIterator'}, 96 | {name='filter', type='function', default=function(sample) return true end}, 97 | {name='transform', type='function', default=function(sample) return sample end}, 98 | overload = DatasetIterator.__init, 99 | call = 100 | function(self, iterator, filter, transform) 101 | self.iterator = iterator 102 | function self.run() 103 | local loop = iterator:run() 104 | return 105 | function() 106 | repeat 107 | local sample = loop() 108 | if sample then 109 | sample = transform(sample) 110 | if filter(sample) then 111 | return sample 112 | end 113 | end 114 | until not sample 115 | end 116 | end 117 | end 118 | } 119 | 120 | DatasetIterator.__call__ = 121 | function(self, ...) 122 | return self:run(...) 123 | end 124 | 125 | doc[[ 126 | 127 | #### tnt.DatasetIterator.exec(tnt.DatasetIterator, name, ...) 128 | 129 | Execute the given method `name` on the underlying dataset, passing it the 130 | subsequent arguments, and returns what the `name` method returns. 131 | ]] 132 | 133 | DatasetIterator.exec = 134 | function(self, name, ...) 
135 | if type(self[name]) == 'function' then 136 | return self[name](self, ...) 137 | elseif self.dataset then 138 | return self.dataset:exec(name, ...) 139 | elseif self.iterator then 140 | return self.iterator:exec(name, ...) 141 | else 142 | error(string.format('unknown function <%s>', name)) 143 | end 144 | end 145 | 146 | DatasetIterator.filter = 147 | function(self, filter) 148 | return tnt.DatasetIterator({ iterator = self, filter = filter }) 149 | end 150 | 151 | DatasetIterator.transform = 152 | function(self, transform) 153 | return tnt.DatasetIterator({ iterator = self, transform = transform }) 154 | end 155 | -------------------------------------------------------------------------------- /dataset/init.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | doc[[ 15 | 16 | ### tnt.Dataset() 17 | 18 | *torchnet* provides a variety of data containers, which can be easily 19 | plugged between each others, allowing the user to easily concat, split, 20 | batch, resample etc... datasets. 21 | 22 | A instance `dataset` of a `tnt.Dataset()` implements two main methods: 23 | 24 | * `dataset:size()` which returns the size of the dataset. 25 | * `dataset:get(idx)` where `idx` is a number between 1 and the dataset size. 26 | 27 | While it is easy to iterate over a dataset with a for loop, several 28 | `DatasetIterator` iterators are nevertheless provided, allowing the user to 29 | filter out some samples in an on-the-fly manner, or to parallelize easily 30 | data fetching. 
31 | 32 | In *torchnet*, a sample returned by `dataset:get()` is supposed to be a Lua 33 | `table`. Fields of the table can be arbitrary, even though many datasets 34 | will only work with torch tensors. 35 | 36 | ]] 37 | 38 | local Dataset = torch.class('tnt.Dataset', tnt) 39 | 40 | Dataset.__init = 41 | function() 42 | end 43 | 44 | -- returns a number 45 | Dataset.size = 46 | function(self) 47 | error(string.format( 48 | 'size not implemented for class <%s>', 49 | torch.type(self))) 50 | end 51 | 52 | -- execute a function 53 | Dataset.exec = 54 | function(self, name, ...) 55 | if type(self[name]) == 'function' then 56 | return self[name](self, ...) 57 | elseif self.dataset then 58 | return self.dataset:exec(name, ...) 59 | elseif self.__dataset then 60 | return self.__dataset:exec(name, ...) 61 | else 62 | error(string.format('unknown function <%s>', name)) 63 | end 64 | end 65 | 66 | -- returns a table of tensors 67 | Dataset.get = 68 | function(self) 69 | error(string.format( 70 | 'get not implemented for class <%s>', 71 | torch.type(self))) 72 | end 73 | 74 | Dataset.batch = 75 | function(...) 76 | return tnt.BatchDataset(...) 77 | end 78 | 79 | Dataset.sample = 80 | function(...) 81 | return tnt.ResampleDataset(...) 82 | end 83 | 84 | Dataset.shuffle = 85 | function(...) 86 | return tnt.ShuffleDataset(...) 87 | end 88 | 89 | Dataset.split = 90 | function(...) 91 | return tnt.SplitDataset(...) 92 | end 93 | 94 | Dataset.transform = 95 | function(...) 96 | return tnt.TransformDataset(...) 97 | end 98 | 99 | Dataset.iterator = 100 | function(...) 101 | return tnt.DatasetIterator(...) 
102 | end 103 | 104 | Dataset.parallel = argcheck{ 105 | {name='self', type='tnt.Dataset'}, 106 | {name='init', type='function', default=function(idx) end}, 107 | {name='nthread', type='number'}, 108 | {name='perm', type='function', default=function(idx) return idx end}, 109 | {name='filter', type='function', default=function(sample) return true end}, 110 | {name='transform', type='function', default=function(sample) return sample end}, 111 | {name='ordered', type='boolean', default=false}, 112 | call = 113 | function(self, init, nthread, perm, filter, transform, ordered) 114 | local closure = function() return self end 115 | return tnt.ParallelDatasetIterator(init, closure, nthread, perm, filter, transform, ordered) 116 | end 117 | } 118 | -------------------------------------------------------------------------------- /dataset/listdataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local tds = require 'tds' 12 | local argcheck = require 'argcheck' 13 | local transform = require 'torchnet.transform' 14 | 15 | local ListDataset, Dataset = torch.class('tnt.ListDataset', 'tnt.Dataset', tnt) 16 | 17 | ListDataset.__init = argcheck{ 18 | doc = [[ 19 | 20 | #### tnt.ListDataset(@ARGP) 21 | @ARGT 22 | 23 | Considering a `list` (can be a `tds.Hash`, `table` or a `torch.LongTensor`) the 24 | i-th sample of a dataset will be returned by `load(list[i])`, where `load()` is 25 | a closure provided by the user. 26 | 27 | If `path` is provided, list is assumed to be a list of string, and will 28 | each element `list[i]` will prefixed by `path/` when fed to `load()`. 
29 | 30 | Purpose: many low or medium-scale datasets can be seen as a list of files 31 | (for example representing input samples). For this list of file, a target 32 | can be often inferred in a simple manner. 33 | 34 | ]], 35 | {name='self', type='tnt.ListDataset'}, 36 | {name='list', type='tds.Hash'}, 37 | {name='load', type='function'}, 38 | {name='path', type='string', opt=true}, 39 | call = 40 | function(self, list, load, path) 41 | Dataset.__init(self) 42 | self.list = list 43 | self.load = load 44 | self.path = path 45 | end 46 | } 47 | 48 | ListDataset.__init = argcheck{ 49 | {name='self', type='tnt.ListDataset'}, 50 | {name='list', type='table'}, 51 | {name='load', type='function'}, 52 | {name='path', type='string', opt=true}, 53 | overload = ListDataset.__init, 54 | call = 55 | function(self, list, load, path) 56 | Dataset.__init(self) 57 | self.list = list 58 | self.load = load 59 | self.path = path 60 | end 61 | } 62 | 63 | ListDataset.__init = argcheck{ 64 | {name='self', type='tnt.ListDataset'}, 65 | {name='list', type='tds.Vec'}, 66 | {name='load', type='function'}, 67 | {name='path', type='string', opt=true}, 68 | overload = ListDataset.__init, 69 | call = 70 | function(self, list, load, path) 71 | Dataset.__init(self) 72 | self.list = list 73 | self.load = load 74 | self.path = path 75 | end 76 | } 77 | 78 | ListDataset.__init = argcheck{ 79 | {name='self', type='tnt.ListDataset'}, 80 | {name='list', type='torch.LongTensor'}, 81 | {name='load', type='function'}, 82 | {name='path', type='string', opt=true}, 83 | overload = ListDataset.__init, 84 | call = 85 | function(self, list, load, path) 86 | Dataset.__init(self) 87 | self.list = list 88 | self.load = load 89 | self.path = path 90 | end 91 | } 92 | 93 | ListDataset.__init = argcheck{ 94 | doc = [[ 95 | #### tnt.ListDataset(@ARGP) 96 | @ARGT 97 | 98 | The file specified by `filename` is interpreted as a list of strings (one 99 | string per line). 
The i-th sample of a dataset will be returned by 100 | `load(line[i])`, where `load()` is a closure provided by the user and 101 | `line[i]` is the i-th line of `filename`. 102 | 103 | If `path` is provided, `list` is assumed to be a list of strings, and 104 | each element `list[i]` will be prefixed by `path/` when fed to `load()`. 105 | 106 | ]], 107 | {name='self', type='tnt.ListDataset'}, 108 | {name='filename', type='string'}, 109 | {name='load', type='function'}, 110 | {name='maxload', type='number', opt=true}, 111 | {name='path', type='string', opt=true}, 112 | overload = ListDataset.__init, 113 | call = 114 | function(self, filename, load, maxload, path) 115 | local list = tds.hash() 116 | for filename in io.lines(filename) do 117 | list[#list+1] = filename 118 | if maxload and maxload > 0 and #list == maxload then 119 | break 120 | end 121 | end 122 | ListDataset.__init(self, list, load, path) 123 | print(string.format("| loaded <%s> with %d examples", filename, #list)) 124 | end 125 | } 126 | 127 | ListDataset.size = argcheck{ 128 | {name='self', type='tnt.ListDataset'}, 129 | call = 130 | function(self) 131 | return torch.isTensor(self.list) and self.list:size(1) 132 | or #self.list 133 | end 134 | } 135 | 136 | ListDataset.get = argcheck{ 137 | {name='self', type='tnt.ListDataset'}, 138 | {name='idx', type='number'}, 139 | call = 140 | function(self, idx) 141 | assert(idx >= 1 and idx <= self:size(), 'out of bound') 142 | if self.path then 143 | return self.load(string.format("%s/%s", self.path, self.list[idx])) 144 | else 145 | return self.load(self.list[idx]) 146 | end 147 | end 148 | } 149 | -------------------------------------------------------------------------------- /dataset/paralleldatasetiterator.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 
4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local Threads = require 'threads' 12 | local argcheck = require 'argcheck' 13 | local doc = require 'argcheck.doc' 14 | 15 | local ParallelDatasetIterator = torch.class('tnt.ParallelDatasetIterator', 'tnt.DatasetIterator', tnt) 16 | 17 | ParallelDatasetIterator.__init = argcheck{ 18 | doc = [[ 19 | 20 | #### tnt.ParallelDatasetIterator(@ARGP) 21 | @ARGT 22 | 23 | Allows to iterate over a dataset in a thread 24 | manner. `tnt.ParallelDatasetIterator:run()` guarantees that all samples 25 | will be seen, but does not guarantee the order unless `ordered` is set to true. 26 | 27 | The purpose of this class is to have a zero pre-processing cost. 28 | When reading datasets on the fly from 29 | disk (not loading them fully in memory), or performing complex 30 | pre-processing this can be of interest. 31 | 32 | The number of threads used to parallelize is specified by `nthread`. 33 | 34 | `init(threadid)` (where threadid=1..nthread) is a closure which may 35 | initialize the specified thread as needed, if needed. It is doing nothing 36 | by default. 37 | 38 | `closure(threadid)` will be called on each thread and must return a 39 | `tnt.Dataset` instance. 40 | 41 | `perm(idx)` is a permutation used to shuffle the examples. If shuffling is 42 | needed, one can use this closure, or (better) use 43 | [tnt.ShuffleDataset](#ShuffleDataset) on the underlying dataset 44 | (returned by `closure()`). 45 | 46 | `filter(sample)` is a closure which returns `true` if the given sample 47 | should be considered or `false` if not. Note that filter is called _after_ 48 | fetching the data in a threaded manner. 49 | 50 | `transform(sample)` is a function which maps the given sample to a new value. 
51 | This transformation occurs before filtering. 52 | 53 | When `ordered` is set to true the ordering of samples returned by the iterator 54 | is guaranteed. This option is particularly useful for repeatable experiments. 55 | By default `ordered` is false, which means that order is not guaranteed by 56 | `run()` (though often the ordering is similar in practice). 57 | 58 | A common error raised by this dataset is when `closure()` is not 59 | serializable. Make sure that all [upvalues](http://www.lua.org/pil/27.3.3.html) of `closure()` are 60 | serializable. It is recommended to avoid [upvalues](http://www.lua.org/pil/27.3.3.html) at all cost, 61 | and to make sure you require all the appropriate torch packages needed to (de-)serialize 62 | `closure()` in the `init()` function. 63 | 64 | 65 | For more information, check out the [threads package](https://github.com/torch/threads), 66 | on which `tnt.ParallelDatasetIterator` relies. 67 | ]], 68 | {name='self', type='tnt.ParallelDatasetIterator'}, 69 | {name='init', type='function', default=function(idx) end}, 70 | {name='closure', type='function'}, 71 | {name='nthread', type='number'}, 72 | {name='perm', type='function', default=function(idx) return idx end}, 73 | {name='filter', type='function', default=function(sample) return true end}, 74 | {name='transform', type='function', default=function(sample) return sample end}, 75 | {name='ordered', type='boolean', default=false}, 76 | call = 77 | function(self, init, closure, nthread, perm, filter, transform, ordered) 78 | local function main(idx) 79 | gdataset = closure(idx) 80 | assert(torch.isTypeOf(gdataset, 'tnt.Dataset'), 81 | 'closure should return a Dataset class') 82 | return gdataset:size() 83 | end 84 | Threads.serialization('threads.sharedserialize') 85 | local threads, sizes = Threads(nthread, init, main) 86 | self.__threads = threads 87 | self.__nthread = nthread 88 | local size = sizes[1][1] 89 | local sample -- beware: do not put this line in loop() 90 | 
local sampleOrigIdx 91 | function self.run() 92 | -- make sure we are not in the middle of something 93 | threads:synchronize() 94 | -- loading size of the dataset each time run() is called 95 | threads:addjob( 96 | function() 97 | local size = gdataset:size() 98 | return size 99 | end, 100 | function(_size_) 101 | size = _size_ 102 | end 103 | ) 104 | threads:dojob() 105 | local idx = 1 106 | local function enqueue() 107 | while idx <= size and threads:acceptsjob() do 108 | threads:addjob( 109 | function(origIdx, idx) 110 | local sample = gdataset:get(idx) 111 | collectgarbage() 112 | collectgarbage() 113 | return sample, origIdx 114 | end, 115 | function(_sample_, _origIdx_) 116 | sample, sampleOrigIdx = _sample_, _origIdx_ 117 | end, 118 | idx, perm(idx) 119 | ) 120 | idx = idx + 1 121 | end 122 | end 123 | 124 | enqueue() 125 | 126 | local iterFunction 127 | if ordered then 128 | local curSampleIdx = 1 129 | local storedSamples = {} 130 | -- `samplePlaceholder` stands in for samples which have been 131 | -- filtered out by the `filter` function 132 | local samplePlaceholder = {} 133 | 134 | -- Move past placeholders (filtered out samples) in 135 | -- `storedSamples` 136 | local function advancePastPlaceholders() 137 | while storedSamples[curSampleIdx] == samplePlaceholder do 138 | storedSamples[curSampleIdx] = nil 139 | curSampleIdx = curSampleIdx + 1 140 | end 141 | end 142 | 143 | iterFunction = function() 144 | advancePastPlaceholders() 145 | 146 | -- Load into storedSamples until we find the next sample in 147 | -- the sequence or we run out of samples 148 | while storedSamples[curSampleIdx] == nil and threads:hasjob() do 149 | enqueue() 150 | threads:dojob() 151 | if threads:haserror() then 152 | threads:synchronize() 153 | end 154 | enqueue() 155 | 156 | sample = transform(sample) 157 | if filter(sample) then 158 | -- Store sample 159 | storedSamples[sampleOrigIdx] = sample 160 | else 161 | -- Mark sample as "filtered out" 162 | 
storedSamples[sampleOrigIdx] = samplePlaceholder 163 | end 164 | 165 | advancePastPlaceholders() 166 | end 167 | 168 | enqueue() 169 | 170 | local curSample = storedSamples[curSampleIdx] 171 | storedSamples[curSampleIdx] = nil 172 | 173 | curSampleIdx = curSampleIdx + 1 174 | 175 | return curSample 176 | end 177 | else 178 | iterFunction = function() 179 | while threads:hasjob() do 180 | enqueue() 181 | threads:dojob() 182 | if threads:haserror() then 183 | threads:synchronize() 184 | end 185 | enqueue() 186 | sample = transform(sample) 187 | if filter(sample) then 188 | return sample 189 | end 190 | end 191 | end 192 | end 193 | 194 | return iterFunction 195 | end 196 | end 197 | } 198 | 199 | doc[[ 200 | 201 | #### tnt.ParallelDatasetIterator.execSingle(tnt.DatasetIterator, name, ...) 202 | 203 | Execute the given method `name` on the dataset corresponding to the first 204 | available thread, passing it the subsequent arguments, and returns what the 205 | `name` method returns. 206 | 207 | For example: 208 | ```lua 209 | local iterator = tnt.ParallelDatasetIterator{...} 210 | print(iterator:execSingle("size")) 211 | ``` 212 | will print the size of the dataset loaded in the first available thread. 213 | ]] 214 | 215 | ParallelDatasetIterator.execSingle = 216 | function(self, name, ...) 217 | assert(not self.__threads:hasjob(), 'cannot execSingle during loop') 218 | local args = {...} 219 | local res 220 | self.__threads:addjob( 221 | function() 222 | return gdataset:exec(name, table.unpack(args)) 223 | end, 224 | function(...) 225 | res = {...} 226 | end) 227 | self.__threads:synchronize() 228 | return table.unpack(res) 229 | end 230 | 231 | doc[[ 232 | 233 | #### tnt.ParallelDatasetIterator.exec(tnt.DatasetIterator, name, ...) 234 | 235 | Execute the given method `name` on the underlying datasets in each thread, 236 | passing to each of them the subsequent arguments, and returns a table 237 | of what the `name` method returns for each thread. 
238 | 239 | For example: 240 | ```lua 241 | local iterator = tnt.ParallelDatasetIterator{...} 242 | for _, v in pairs(iterator:exec("size")) do 243 | print(v) 244 | end 245 | ``` 246 | will print the size of the datasets loaded in each thread. 247 | ]] 248 | 249 | ParallelDatasetIterator.exec = 250 | function(self, name, ...) 251 | assert(not self.__threads:hasjob(), 'cannot exec during loop') 252 | local args = {...} 253 | local res = {} 254 | self.__threads:specific(true) 255 | for i=1,self.__nthread do 256 | self.__threads:addjob(i, 257 | function() 258 | return gdataset:exec(name, table.unpack(args)) 259 | end, 260 | function(...) 261 | local r = {...} 262 | res[i] = #r > 1 and r or r[1] 263 | end) 264 | end 265 | self.__threads:specific(false) 266 | return res 267 | end 268 | -------------------------------------------------------------------------------- /dataset/resampledataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local ResampleDataset = 14 | torch.class('tnt.ResampleDataset', 'tnt.Dataset', tnt) 15 | 16 | ResampleDataset.__init = argcheck{ 17 | doc = [[ 18 | 19 | #### tnt.ResampleDataset(@ARGP) 20 | 21 | Given a `dataset`, creates a new dataset which will (re-)sample from this 22 | underlying dataset using the provided `sampler(dataset, idx)` closure. 23 | 24 | If `size` is provided, then the newly created dataset will have the 25 | specified `size`, which might be different than the underlying dataset 26 | size. 
27 | 28 | If `size` is not provided, then the new dataset will have the same size 29 | as the underlying one. 30 | 31 | By default `sampler(dataset, idx)` is the identity, simply `return`ing `idx`. 32 | `dataset` corresponds to the underlying dataset provided at construction, and 33 | `idx` may take a value between 1 and `size`. It must return an index in the range 34 | acceptable for the underlying dataset. 35 | 36 | Purpose: shuffling data, re-weighting samples, getting a subset of the 37 | data. Note that an important sub-class is [tnt.ShuffleDataset](#ShuffleDataset), 38 | provided for convenience. 39 | ]], 40 | {name='self', type='tnt.ResampleDataset'}, 41 | {name='dataset', type='tnt.Dataset'}, 42 | {name='sampler', type='function', default=function(dataset, idx) return idx end}, 43 | {name='size', type='number', opt=true}, 44 | call = 45 | function(self, dataset, sampler, size) 46 | self.__sampler = sampler 47 | self.__dataset = dataset 48 | self.__size = size 49 | end 50 | } 51 | 52 | ResampleDataset.size = argcheck{ 53 | {name='self', type='tnt.ResampleDataset'}, 54 | call = 55 | function(self) 56 | return (self.__size and self.__size > 0) and self.__size or self.__dataset:size() 57 | end 58 | } 59 | 60 | ResampleDataset.get = argcheck{ 61 | {name='self', type='tnt.ResampleDataset'}, 62 | {name='idx', type='number'}, 63 | call = 64 | function(self, idx) 65 | assert(idx >= 1 and idx <= self:size(), 'index out of bound') 66 | idx = self.__sampler(self.__dataset, idx) 67 | assert(idx >= 1 and idx <= self.__dataset:size(), 'index out of bound (sampler)') 68 | return self.__dataset:get(idx) 69 | end 70 | } 71 | -------------------------------------------------------------------------------- /dataset/shuffledataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 
4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local ShuffleDataset, ResampleDataset = 14 | torch.class('tnt.ShuffleDataset', 'tnt.ResampleDataset', tnt) 15 | 16 | ShuffleDataset.__init = argcheck{ 17 | doc = [[ 18 | 19 | #### tnt.ShuffleDataset(@ARGP) 20 | @ARGT 21 | 22 | `tnt.ShuffleDataset` is a sub-class of 23 | [tnt.ResampleDataset](#ResampleDataset) provided for convenience. 24 | 25 | It samples uniformly from the given `dataset` with, or without 26 | `replacement`. The chosen partition can be redrawn by calling 27 | [resample()](#ShuffleDataset.resample). 28 | 29 | If `replacement` is `true`, then the specified `size` may be larger than 30 | the underlying `dataset`. 31 | 32 | If `size` is not provided, then the new dataset size will be equal to the 33 | underlying `dataset` size. 34 | 35 | Purpose: the easiest way to shuffle a dataset! 
36 | ]], 37 | {name='self', type='tnt.ShuffleDataset'}, 38 | {name='dataset', type='tnt.Dataset'}, 39 | {name='size', type='number', opt=true}, 40 | {name='replacement', type='boolean', default=false}, 41 | call = 42 | function(self, dataset, size, replacement) 43 | if size and not replacement and size > dataset:size() then 44 | error('size cannot be larger than underlying dataset size when sampling without replacement') 45 | end 46 | self.__replacement = replacement 47 | local function sampler(dataset, idx) 48 | return self.__perm[idx] 49 | end 50 | ResampleDataset.__init(self, { 51 | dataset = dataset, 52 | sampler = sampler, 53 | size = size}) 54 | self:resample() 55 | end 56 | } 57 | 58 | ShuffleDataset.resample = argcheck{ 59 | doc = [[ 60 | 61 | ##### tnt.ShuffleDataset.resample(@ARGP) 62 | 63 | The permutation associated to `tnt.ShuffleDataset` is fixed, such that two 64 | calls to the same index will return the same sample from the underlying 65 | dataset. 66 | 67 | Call `resample()` to draw randomly a new permutation. 68 | ]], 69 | {name='self', type='tnt.ShuffleDataset'}, 70 | call = 71 | function(self) 72 | self.__perm = self.__replacement 73 | and torch.LongTensor(self:size()):random(self.__dataset:size()) 74 | or torch.randperm(self.__dataset:size()):long():narrow(1, 1, self:size()) 75 | end 76 | } 77 | -------------------------------------------------------------------------------- /dataset/splitdataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local SplitDataset = torch.class('tnt.SplitDataset', 'tnt.Dataset', tnt) 14 | 15 | SplitDataset.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.SplitDataset(@ARGP) 19 | @ARGT 20 | 21 | Partition a given `dataset`, according to the specified `partitions`. Use 22 | the method [select()](#SplitDataset.select) to select the current partition 23 | in use. 24 | 25 | The Lua hash table `partitions` is of the form (key, value) where key is a 26 | user-chosen string naming the partition, and value is a number representing 27 | the weight (as a number between 0 and 1) or the size (in number of samples) 28 | of the corresponding partition. 29 | 30 | Partitioning is achieved linearly (no shuffling). See 31 | [tnt.ShuffleDataset](#ShuffleDataset) if you want to shuffle the dataset 32 | before partitioning. 33 | 34 | The optional variable `initialpartition` specifies the partition that is loaded 35 | initially. 36 | 37 | Purpose: useful in machine learning to perform validation procedures. 
38 | ]], 39 | {name='self', type='tnt.SplitDataset'}, 40 | {name='dataset', type='tnt.Dataset'}, 41 | {name='partitions', type='table'}, 42 | {name='initialpartition', type='string', opt=true}, 43 | call = 44 | function(self, dataset, partitions, initialpartition) 45 | 46 | -- created sorted list of partition names (for determinism): 47 | local partitionnames = {} 48 | for name,_ in pairs(partitions) do 49 | table.insert(partitionnames, name) 50 | end 51 | table.sort(partitionnames) 52 | 53 | -- create partition size tensor and table with partition names: 54 | self.__dataset = dataset 55 | self.__partitionsizes = torch.DoubleTensor(#partitionnames) 56 | self.__names = {} 57 | for n, key in ipairs(partitionnames) do 58 | local val = partitions[key] 59 | self.__partitionsizes[n] = val 60 | self.__names[key] = n 61 | end 62 | 63 | -- assertions: 64 | assert( 65 | self.__partitionsizes:nElement() >= 2, 66 | 'SplitDataset should have at least two partitions' 67 | ) 68 | assert( 69 | self.__partitionsizes:min() >= 0, 70 | 'some partition sizes are negative' 71 | ) 72 | assert( 73 | self.__partitionsizes:max() > 0, 74 | 'all partitions are empty' 75 | ) 76 | 77 | -- if partition sizes are fractions, convert to sizes: 78 | if self.__partitionsizes:sum() <= 1 then 79 | self.__partitionsizes = self.__partitionsizes:double() 80 | self.__partitionsizes:mul(self.__dataset:size()):floor() 81 | else 82 | assert(torch.eq(self.__partitionsizes, 83 | torch.floor(self.__partitionsizes)):all(), 84 | 'partition sizes should be integer numbers, or sum up to <= 1 ' 85 | ) 86 | end 87 | self.__partitionsizes = self.__partitionsizes:long() 88 | assert( 89 | self.__partitionsizes:sum() <= self.__dataset:size(), 90 | 'split cannot involve more samples than dataset size' 91 | ) 92 | 93 | -- select first partition: 94 | if initialpartition then self:select(initialpartition) end 95 | end 96 | } 97 | 98 | SplitDataset.select = argcheck{ 99 | doc = [[ 100 | 101 | ##### 
tnt.SplitDataset.select(@ARGP) 102 | @ARGT 103 | 104 | Switch the current partition in use to the one specified by `partition`, 105 | which must be a string corresponding to one of the names provided at 106 | construction. 107 | 108 | The current dataset size changes accordingly, as well as the samples returned 109 | by the `get()` method. 110 | ]], 111 | {name='self', type='tnt.SplitDataset'}, 112 | {name='partition', type='string'}, 113 | call = 114 | function(self, partition) 115 | local id = self.__names[partition] 116 | if not id then error('partition not found') end 117 | if self.__partition then 118 | self.__partition[1] = id 119 | else 120 | self.__partition = torch.LongTensor{id} 121 | end 122 | end 123 | } 124 | 125 | SplitDataset.size = argcheck{ 126 | {name='self', type='tnt.SplitDataset'}, 127 | call = 128 | function(self) 129 | assert(self.__partition, 'select a partition before accessing data') 130 | return self.__partitionsizes[self.__partition[1]] 131 | end 132 | } 133 | 134 | SplitDataset.get = argcheck{ 135 | {name='self', type='tnt.SplitDataset'}, 136 | {name='idx', type='number'}, 137 | call = 138 | function(self, idx) 139 | assert(self.__partition, 'select a partition before accessing data') 140 | assert(idx >= 1 and idx <= self:size(), 'index out of bounds') 141 | local offset = (self.__partition[1] == 1) and 0 or 142 | self.__partitionsizes:narrow(1, 1, self.__partition[1] - 1):sum() 143 | return self.__dataset:get(offset + idx) 144 | end 145 | } 146 | -------------------------------------------------------------------------------- /dataset/tabledataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. 
An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local utils = require 'torchnet.utils' 13 | 14 | local TableDataset, ListDataset 15 | = torch.class('tnt.TableDataset', 'tnt.ListDataset', tnt) 16 | 17 | TableDataset.__init = argcheck{ 18 | doc = [[ 19 | 20 | #### tnt.TableDataset(@ARGP) 21 | @ARGT 22 | 23 | `tnt.TableDataset` interfaces existing data 24 | to torchnet. It is useful if you want to use torchnet on a small dataset. 25 | 26 | The data must be contained in a `tds.Hash`. 27 | 28 | `tnt.TableDataset` does a shallow copy of the data. 29 | 30 | Data are loaded while constructing the `tnt.TableDataset`: 31 | ```lua 32 | > a = tnt.TableDataset{data = {1,2,3}} 33 | > print(a:size()) 34 | 3 35 | ``` 36 | `tnt.TableDataset` assumes that table has contiguous keys starting at 1. 37 | ]], 38 | noordered = true, 39 | {name = 'self', type = 'tnt.TableDataset'}, 40 | {name = 'data', type = 'table'}, 41 | call = function(self, data) 42 | for i = 1, #data do 43 | assert(data[i], "keys are not contiguous integers starting at 1") 44 | end 45 | local size = 0 46 | for _, _ in pairs(data) do size = size + 1 end 47 | assert(size == #data, "keys are not contiguous integers starting at 1") 48 | self.data = data 49 | end 50 | } 51 | 52 | TableDataset.size = argcheck{ 53 | {name = 'self', type = 'tnt.TableDataset'}, 54 | call = function(self) 55 | return #self.data 56 | end 57 | } 58 | 59 | TableDataset.get = argcheck{ 60 | {name = 'self', type = 'tnt.TableDataset'}, 61 | {name = 'idx', type = 'number'}, 62 | call = function(self, idx) 63 | assert(idx >= 1 and idx <= self:size(), 'index out of bound') 64 | assert(idx == math.floor(idx), 'index must be an integer') 65 | return utils.table.clone(self.data[idx]) 66 | end 67 | } 68 | -------------------------------------------------------------------------------- 
/dataset/transformdataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local utils = require 'torchnet.utils' 13 | 14 | local TransformDataset = 15 | torch.class('tnt.TransformDataset', 'tnt.Dataset', tnt) 16 | 17 | TransformDataset.__init = argcheck{ 18 | doc = [[ 19 | 20 | #### tnt.TransformDataset(@ARGP) 21 | @ARGT 22 | 23 | Given a closure `transform()`, and a `dataset`, `tnt.TransformDataset` 24 | applies the closure in an on-the-fly manner when querying a sample with 25 | `tnt.Dataset:get()`. 26 | 27 | If key is provided, the closure is applied to the sample field specified 28 | by `key` (only). The closure must return the new corresponding field value. 29 | 30 | If key is not provided, the closure is applied on the full sample. The 31 | closure must return the new sample table. 32 | 33 | The size of the new dataset is equal to the size of the underlying `dataset`. 34 | 35 | Purpose: when performing pre-processing operations, it is convenient to be 36 | able to perform on-the-fly transformations to a 37 | dataset. 
38 | ]], 39 | {name='self', type='tnt.TransformDataset'}, 40 | {name='dataset', type='tnt.Dataset'}, 41 | {name='transform', type='function'}, 42 | {name='key', type='string', opt=true}, 43 | call = 44 | function(self, dataset, transform, key) 45 | self.dataset = dataset 46 | if key then 47 | function self.__transform(z, idx) 48 | assert(z[key], 'missing key in sample') 49 | z[key] = transform(z[key], idx) 50 | return z 51 | end 52 | else 53 | function self.__transform(z, idx) 54 | return transform(z, idx) 55 | end 56 | end 57 | end 58 | } 59 | 60 | TransformDataset.__init = argcheck{ 61 | doc = [[ 62 | 63 | #### tnt.TransformDataset(@ARGP) 64 | @ARGT 65 | 66 | Given a set of closures and a `dataset`, `tnt.TransformDataset` applies 67 | these closures in an on-the-fly manner when querying a sample with 68 | `tnt.Dataset:get()`. 69 | 70 | Closures are provided in `transforms`, a Lua table, where a (key,value) 71 | pair represents a (sample field name, corresponding closure to be applied 72 | to the field name). 73 | 74 | Each closure must return the new value of the corresponding field. 
75 | ]], 76 | {name='self', type='tnt.TransformDataset'}, 77 | {name='dataset', type='tnt.Dataset'}, 78 | {name='transforms', type='table'}, 79 | overload = TransformDataset.__init, 80 | call = 81 | function(self, dataset, transforms) 82 | for k,v in pairs(transforms) do 83 | assert(type(v) == 'function', 84 | 'key/function table expected for transforms') 85 | end 86 | self.dataset = dataset 87 | transforms = utils.table.copy(transforms) 88 | function self.__transform(z) 89 | for key,transform in pairs(transforms) do 90 | assert(z[key], 'missing key in sample') 91 | z[key] = transform(z[key]) 92 | end 93 | return z 94 | end 95 | end 96 | } 97 | 98 | TransformDataset.size = argcheck{ 99 | {name='self', type='tnt.TransformDataset'}, 100 | call = 101 | function(self) 102 | return self.dataset:size() 103 | end 104 | } 105 | 106 | TransformDataset.get = argcheck{ 107 | {name='self', type='tnt.TransformDataset'}, 108 | {name='idx', type='number'}, 109 | call = 110 | function(self, idx) 111 | return self.__transform( 112 | self.dataset:get(idx), idx 113 | ) 114 | end 115 | } 116 | -------------------------------------------------------------------------------- /engine/init.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | doc[[ 15 | 16 | ### tnt.Engine 17 | 18 | In experimenting with different models and datasets, the underlying training 19 | procedure is often the same. The `Engine` module provides the boilerplate logic 20 | necessary for the training and testing of models. 
This might include conducting 21 | the interaction between model (`nn.Module`), `tnt.DatasetIterator`s, 22 | `nn.Criterion`s, and `tnt.Meter`s. 23 | 24 | An instance `engine` of a `tnt.Engine()` implements two main methods: 25 | 26 | * `engine:train()`, for training the model on data 27 | (i.e. sample data, forward prop, backward prop). 28 | * `engine:test()`, for evaluating a model on data 29 | (optionally with respect to a `nn.Criterion`). 30 | 31 | The `Engine` can be implemented for any common underlying training and testing 32 | procedure involving a model and data. It can also be designed to allow user 33 | control after certain events such as forward prop, criterion evaluation, or the 34 | end of an epoch, by using coroutines (see `tnt.SGDEngine`). 35 | 36 | ]] 37 | 38 | local Engine = torch.class('tnt.Engine', tnt) 39 | 40 | Engine.__init = argcheck{ 41 | nonamed=true, -- to avoid ambiguities 42 | {name="self", type="tnt.Engine"}, 43 | {name="hooks", type="table"}, 44 | call = 45 | function(self, hooks) 46 | self.hooks = {} 47 | for _, name in ipairs(hooks) do 48 | assert(type(name) == 'string', 'hooks must be a table of hook names (strings)') 49 | self.hooks[name] = function() end 50 | end 51 | setmetatable( 52 | self.hooks, 53 | { 54 | __index = 55 | function(hooks, name) 56 | assert(type(name) == 'string', 'hook name must be a string') 57 | error(string.format('unknown hook <%s>', name)) 58 | end, 59 | __newindex = 60 | function(self, name) 61 | assert(type(name) == 'string', 'hook name must be a string') 62 | error(string.format('unknown hook <%s>', name)) 63 | end, 64 | __call = 65 | function(hooks, name, ...) 66 | return hooks[name](...) 
67 | end 68 | } 69 | ) 70 | end 71 | } 72 | 73 | Engine.train = argcheck{ 74 | {name="self", type="tnt.Engine"}, 75 | call = 76 | function(self) 77 | error('A tnt.Engine should implement the train() function.') 78 | end 79 | } 80 | 81 | Engine.test = argcheck{ 82 | {name="self", type="tnt.Engine"}, 83 | call = 84 | function(self) 85 | error('A tnt.Engine should implement the test() function.') 86 | end 87 | } 88 | -------------------------------------------------------------------------------- /engine/optimengine.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | doc[[ 15 | 16 | ### tnt.OptimEngine 17 | 18 | The `OptimEngine` module wraps the optimization functions from 19 | https://github.com/torch/optim. At the start of training, the engine will call 20 | `getParameters` on the provided network. 
require 'nn'

-- OptimEngine specializes SGDEngine: instead of calling updateParameters
-- with a fixed learning rate, it delegates the parameter update to an
-- optimizer function from the 'optim' package (e.g. optim.sgd, optim.adam).
local OptimEngine, SGDEngine = torch.class('tnt.OptimEngine', 'tnt.SGDEngine', tnt)

OptimEngine.__init = argcheck{
   {name="self", type="tnt.OptimEngine"},
   call =
      function(self)
         -- no extra state: reuse the parent's hook registration
         SGDEngine.__init(self)
      end
}

-- Train a network with an optim-style optimizer.
-- In addition to the SGDEngine.train arguments it takes:
--   optimMethod : the optimizer function, called as optimMethod(feval, x, config, state)
--   config      : optional table of optimizer hyper-parameters (e.g. learningRate)
--   optimState  : optional table of persistent optimizer state (momentum buffers, ...)
--   paramFun    : optional function returning (params, gradParams) flat tensors;
--                 defaults to network:getParameters()
-- NOTE: the loop below mirrors SGDEngine.train (same hooks, same order) —
-- keep the two implementations in sync.
OptimEngine.train = argcheck{
   {name="self", type="tnt.OptimEngine"},
   {name="network", type="nn.Module"},
   {name="criterion", type="nn.Criterion"},
   {name="iterator", type="tnt.DatasetIterator"},
   {name="maxepoch", type="number", default=1000},
   {name="optimMethod", type="function"},
   {name="config", type="table", opt=true},
   {name="optimState", type="table", opt=true},
   {name="paramFun", type="function", opt=true},
   call =
      function(self, network, criterion, iterator, maxepoch, optimMethod,
            config, optimState, paramFun)
         -- state table shared with (and mutable by) every hook closure
         local state = {
            network = network,
            criterion = criterion,
            iterator = iterator,
            maxepoch = maxepoch,
            optimMethod = optimMethod,
            sample = {},
            config = config or {},
            optim = optimState or {},
            epoch = 0, -- epoch done so far
            t = 0, -- samples seen so far
            training = true
         }

         -- flatten network parameters once; paramFun allows callers to
         -- supply their own flattened (params, gradParams) pair
         if paramFun then
            state.params, state.gradParams = paramFun()
         else
            state.params, state.gradParams = state.network:getParameters()
         end

         -- closure handed to the optim method: gradients were already
         -- computed by the explicit forward/backward below, so feval only
         -- reports (loss, gradients) and does not re-evaluate the network
         local function feval()
            return state.criterion.output, state.gradParams
         end

         self.hooks("onStart", state)
         while state.epoch < state.maxepoch do
            state.network:training()

            self.hooks("onStartEpoch", state)
            for sample in state.iterator() do
               state.sample = sample
               self.hooks("onSample", state)

               state.network:forward(sample.input)
               self.hooks("onForward", state)
               state.criterion:forward(state.network.output, sample.target)
               self.hooks("onForwardCriterion", state)

               state.network:zeroGradParameters()
               -- criteria with trainable parameters also need zeroing
               if state.criterion.zeroGradParameters then
                  state.criterion:zeroGradParameters()
               end

               state.criterion:backward(state.network.output, sample.target)
               self.hooks("onBackwardCriterion", state)
               state.network:backward(sample.input, state.criterion.gradInput)
               self.hooks("onBackward", state)

               -- the optimizer updates state.params in place
               state.optimMethod(feval, state.params, state.config, state.optim)
               state.t = state.t + 1
               self.hooks("onUpdate", state)
            end
            state.epoch = state.epoch + 1
            self.hooks("onEndEpoch", state)
         end
         self.hooks("onEnd", state)
      end
}
It also operates as a coroutine allowing a user control 21 | (i.e. increment some sort of `tnt.Meter`) at events such as 'start', 22 | 'start-epoch', 'forward', 'forward-criterion', 'backward', etc. 23 | The available hooks are the following: 24 | ```lua 25 | hooks = { 26 | ['onStart'] = function() end, -- Right before training 27 | ['onStartEpoch'] = function() end, -- Before new epoch 28 | ['onSample'] = function() end, -- After getting a sample 29 | ['onForward'] = function() end, -- After model:forward 30 | ['onForwardCriterion'] = function() end, -- After criterion:forward 31 | ['onBackwardCriterion'] = function() end, -- After criterion:backward 32 | ['onBackward'] = function() end, -- After model:backward 33 | ['onUpdate'] = function() end, -- After UpdateParameters 34 | ['onEndEpoch'] = function() end, -- Right before completing epoch 35 | ['onEnd'] = function() end, -- After training 36 | } 37 | ``` 38 | To specify a new closure for a given hook, we can access to it with 39 | `engine.hooks.`. For example, we could reset a `Meter` before every 40 | epoch by: 41 | ```lua 42 | local engine = tnt.SGDEngine() 43 | local meter = tnt.AverageValueMeter() 44 | engine.hooks.onStartEpoch = function(state) 45 | meter:reset() 46 | end 47 | ``` 48 | 49 | Accordingly, `train` requires a network (`nn.Module`), a criterion expressing the 50 | loss function (`nn.Criterion`), a dataset iterator (`tnt.DatasetIterator`), and a 51 | learning rate, at the minimum. The `test` function allows for simple evaluation 52 | of a model on a dataset. 53 | 54 | A `state` is maintained for external access to outputs and parameters of modules 55 | as well as sampled data. 
require 'nn'

-- SGDEngine implements the vanilla stochastic-gradient-descent training
-- loop with hooks fired at each stage; see the doc[[ ]] block above for
-- the hook names and the layout of the shared `state` table.
local SGDEngine, Engine = torch.class('tnt.SGDEngine', 'tnt.Engine', tnt)

SGDEngine.__init = argcheck{
   {name="self", type="tnt.SGDEngine"},
   call =
      function(self)
         -- declare the set of hook names users may attach closures to
         Engine.__init(self, {
            "onStart", "onStartEpoch", "onSample",
            "onForward", "onForwardCriterion",
            "onBackward", "onBackwardCriterion",
            "onEndEpoch", "onUpdate", "onEnd"
         })
      end
}

-- Train `network` against `criterion` over the samples produced by
-- `iterator`, for `maxepoch` passes, updating parameters with plain SGD
-- at learning rate `lr` (and `lrcriterion` for criteria that themselves
-- have parameters).
SGDEngine.train = argcheck{
   {name="self", type="tnt.SGDEngine"},
   {name="network", type="nn.Module"},
   {name="criterion", type="nn.Criterion"},
   {name="iterator", type="tnt.DatasetIterator"},
   {name="lr", type="number"},
   -- argcheck `defaulta`: lrcriterion defaults to the value passed for lr
   {name="lrcriterion", type="number", defaulta="lr"},
   {name="maxepoch", type="number", default=1000},
   call =
      function(self, network, criterion, iterator, lr, lrcriterion, maxepoch)
         -- state table shared with (and mutable by) every hook closure
         local state = {
            network = network,
            criterion = criterion,
            iterator = iterator,
            lr = lr,
            lrcriterion = lrcriterion,
            maxepoch = maxepoch,
            sample = {},
            epoch = 0, -- epoch done so far
            t = 0, -- samples seen so far
            training = true
         }

         self.hooks("onStart", state)
         while state.epoch < state.maxepoch do
            state.network:training()

            self.hooks("onStartEpoch", state)
            for sample in state.iterator() do
               state.sample = sample
               self.hooks("onSample", state)

               state.network:forward(sample.input)
               self.hooks("onForward", state)
               state.criterion:forward(state.network.output, sample.target)
               self.hooks("onForwardCriterion", state)

               state.network:zeroGradParameters()
               -- criteria with trainable parameters also need zeroing
               if state.criterion.zeroGradParameters then
                  state.criterion:zeroGradParameters()
               end

               state.criterion:backward(state.network.output, sample.target)
               self.hooks("onBackwardCriterion", state)
               state.network:backward(sample.input, state.criterion.gradInput)
               self.hooks("onBackward", state)

               -- learning rates may be changed by hooks, so re-validate here
               assert(state.lrcriterion >= 0, 'lrcriterion should be positive or zero')
               if state.lrcriterion > 0 and state.criterion.updateParameters then
                  state.criterion:updateParameters(state.lrcriterion)
               end
               assert(state.lr >= 0, 'lr should be positive or zero')
               if state.lr > 0 then
                  state.network:updateParameters(state.lr)
               end
               state.t = state.t + 1
               self.hooks("onUpdate", state)
            end
            state.epoch = state.epoch + 1
            self.hooks("onEndEpoch", state)
         end
         self.hooks("onEnd", state)
      end
}

-- Evaluate `network` over `iterator`; if `criterion` is given, the loss is
-- also computed (onForwardCriterion fires only in that case). Only the
-- onStart, onSample, onForward, onForwardCriterion and onEnd hooks fire.
SGDEngine.test = argcheck{
   {name="self", type="tnt.SGDEngine"},
   {name="network", type="nn.Module"},
   {name="iterator", type="tnt.DatasetIterator"},
   {name="criterion", type="nn.Criterion", opt=true},
   call = function(self, network, iterator, criterion)
      -- evaluation-only state: no lr/epoch fields, training = false so
      -- hooks shared with train() can tell the two modes apart
      local state = {
         network = network,
         iterator = iterator,
         criterion = criterion,
         sample = {},
         t = 0, -- samples seen so far
         training = false
      }

      self.hooks("onStart", state)
      state.network:evaluate()
      for sample in state.iterator() do
         state.sample = sample
         self.hooks("onSample", state)
         state.network:forward(sample.input)
         state.t = state.t + 1
         self.hooks("onForward", state)

         if state.criterion then
            state.criterion:forward(state.network.output, sample.target)
            self.hooks("onForwardCriterion", state)
         end

      end
      self.hooks("onEnd", state)
   end
}
-- set up logistic regressor:
-- (784 = 28x28 flattened MNIST pixels, 10 output classes)
local net = nn.Sequential():add(nn.Linear(784,10))
local criterion = nn.CrossEntropyCriterion()

-- set up training engine:
local engine = tnt.SGDEngine()
local meter = tnt.AverageValueMeter()          -- running average of the loss
local clerr = tnt.ClassErrorMeter{topk = {1}}  -- top-1 classification error
engine.hooks.onStartEpoch = function(state)
   -- restart metering at the beginning of every epoch
   meter:reset()
   clerr:reset()
end
engine.hooks.onForwardCriterion = function(state)
   -- accumulate loss/error after each criterion forward; the same hook
   -- fires during engine:test(), where state.training is false so nothing
   -- is printed and the meters simply accumulate test statistics
   meter:add(state.criterion.output)
   clerr:add(state.network.output, state.sample.target)
   if state.training then
      print(string.format('avg. loss: %2.4f; avg. error: %2.4f',
         meter:value(), clerr:value{k = 1}))
   end
end

-- set up GPU training:
if config.usegpu then

   -- copy model to GPU:
   require 'cunn'
   net = net:cuda()
   criterion = criterion:cuda()

   -- copy sample to GPU buffer:
   -- the two CudaTensors are allocated once and reused for every batch
   local igpu, tgpu = torch.CudaTensor(), torch.CudaTensor()
   engine.hooks.onSample = function(state)
      igpu:resize(state.sample.input:size() ):copy(state.sample.input)
      tgpu:resize(state.sample.target:size()):copy(state.sample.target)
      state.sample.input = igpu
      state.sample.target = tgpu
   end -- alternatively, this logic can be implemented via a TransformDataset
end

-- train the model:
engine:train{
   network = net,
   iterator = getIterator('train'),
   criterion = criterion,
   lr = 0.2,
   maxepoch = 5,
}

-- measure test loss and error:
-- reset the meters so test statistics are not mixed with training ones
meter:reset()
clerr:reset()
engine:test{
   network = net,
   iterator = getIterator('test'),
   criterion = criterion,
}
print(string.format('test loss: %2.4f; test error: %2.4f',
   meter:value(), clerr:value{k = 1}))
8 | ]]-- 9 | 10 | require 'torch' 11 | 12 | local tnt = require 'torchnet.env' 13 | local doc = require 'argcheck.doc' 14 | 15 | doc[[ 16 | [![Build Status](https://travis-ci.org/torchnet/torchnet.svg)](https://travis-ci.org/torchnet/torchnet) 17 | 18 | # torchnet 19 | 20 | *torchnet* is a framework for [torch](http://torch.ch) which provides a set 21 | of abstractions aiming at encouraging code re-use as well as encouraging 22 | modular programming. 23 | 24 | At the moment, *torchnet* provides four set of important classes: 25 | - [`Dataset`](#tntdataset): handling and pre-processing data in various ways. 26 | - [`Engine`](#tntengine): training/testing machine learning algorithm. 27 | - [`Meter`](#tntmeter): meter performance or any other quantity. 28 | - [`Log`](#tntlog): output performance or any other string to file / disk in a consistent manner. 29 | 30 | For an overview of the *torchnet* framework, please also refer to 31 | [this paper](https://lvdmaaten.github.io/publications/papers/Torchnet_2016.pdf). 32 | 33 | 34 | ## Installation 35 | 36 | Please install *torch* first, following instructions on 37 | [torch.ch](http://torch.ch/docs/getting-started.html). If *torch* is 38 | already installed, make sure you have an up-to-date version of 39 | [*argcheck*](https://github.com/torch/argcheck), otherwise you will get 40 | weird errors at runtime. 41 | 42 | Assuming *torch* is already installed, the *torchnet* core is only a set of 43 | lua files, so it is straightforward to install it with *luarocks* 44 | ``` 45 | luarocks install torchnet 46 | ``` 47 | 48 | To run the MNIST example from the paper, install the `mnist` package: 49 | ``` 50 | luarocks install mnist 51 | ``` 52 | 53 | `cd` into the installed `torchnet` package directory and run: 54 | ``` 55 | th example/mnist.lua 56 | ``` 57 | 58 | 59 | ## Documentation 60 | 61 | Requiring *torchnet* returns a local variable containing all *torchnet* 62 | class constructors. 
63 | ``` 64 | local tnt = require 'torchnet' 65 | ``` 66 | 67 | ]] 68 | 69 | require 'torchnet.dataset' 70 | require 'torchnet.dataset.listdataset' 71 | require 'torchnet.dataset.tabledataset' 72 | require 'torchnet.dataset.indexeddataset' 73 | require 'torchnet.dataset.transformdataset' 74 | require 'torchnet.dataset.batchdataset' 75 | require 'torchnet.dataset.coroutinebatchdataset' 76 | require 'torchnet.dataset.concatdataset' 77 | require 'torchnet.dataset.resampledataset' 78 | require 'torchnet.dataset.shuffledataset' 79 | require 'torchnet.dataset.splitdataset' 80 | require 'torchnet.dataset.datasetiterator' 81 | require 'torchnet.dataset.paralleldatasetiterator' 82 | 83 | require 'torchnet.engine' 84 | require 'torchnet.engine.sgdengine' 85 | require 'torchnet.engine.optimengine' 86 | 87 | require 'torchnet.meter' 88 | require 'torchnet.meter.apmeter' 89 | require 'torchnet.meter.averagevaluemeter' 90 | require 'torchnet.meter.aucmeter' 91 | require 'torchnet.meter.confusionmeter' 92 | require 'torchnet.meter.mapmeter' 93 | require 'torchnet.meter.movingaveragevaluemeter' 94 | require 'torchnet.meter.msemeter' 95 | require 'torchnet.meter.multilabelconfusionmeter' 96 | require 'torchnet.meter.classerrormeter' 97 | require 'torchnet.meter.timemeter' 98 | require 'torchnet.meter.precisionatkmeter' 99 | require 'torchnet.meter.recallmeter' 100 | require 'torchnet.meter.precisionmeter' 101 | require 'torchnet.meter.ndcgmeter' 102 | 103 | require 'torchnet.log' 104 | 105 | require 'torchnet.utils' 106 | require 'torchnet.transform' 107 | 108 | require 'torchnet.test.test' 109 | 110 | -- function that makes package serializable: 111 | local function _makepackageserializable(packagetbl, packagename) 112 | local mt = torch.class('package.' .. 
packagename) 113 | function mt:__write() end 114 | function mt:__read() end 115 | function mt:__factory() return require(packagename) end 116 | setmetatable(packagetbl, mt) 117 | end 118 | 119 | -- this can be removed when @locronan implements a real torch.isclass(): 120 | function torch.isclass(obj) 121 | local REG = debug.getregistry() 122 | return REG[obj] and true or false 123 | end 124 | 125 | -- make torchnet serializable: 126 | local argcheck = require 'argcheck' 127 | tnt.makepackageserializable = argcheck{ 128 | {name = 'packagetbl', type = 'table'}, 129 | {name = 'packagename', type = 'string'}, 130 | call = function(packagetbl, packagename) 131 | assert(not torch.isclass(getmetatable(packagetbl)) 132 | and not torch.isclass(packagetbl), 'input cant be a class (instance)') 133 | _makepackageserializable(packagetbl, packagename) 134 | for key, val in pairs(packagetbl) do 135 | if type(val) == 'table' and not torch.isclass(getmetatable(val)) 136 | and not torch.isclass(val) then 137 | tnt.makepackageserializable(val, packagename .. '.' .. key) 138 | end 139 | end 140 | end 141 | } 142 | tnt.makepackageserializable(tnt, 'torchnet') 143 | 144 | return tnt 145 | -------------------------------------------------------------------------------- /log/init.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | local Log = torch.class('tnt.Log', tnt) 15 | 16 | doc[[ 17 | ### tnt.Log 18 | 19 | `Log` classes act as tables indexed by string keys. 
-- reset attached event closures and stored data; used by __init and close()
Log.__clear =
   function(self)
      self.__events = {onClose={}, onFlush={}, onGet={}, onSet={}}
      self.__data = {}
   end

Log.__init = argcheck{
   doc = [[

#### tnt.Log(@ARGP)
@ARGT

Creates a new `Log` with allowed keys (strings) `keys`. Specifiy event
closures with table of functions `onClose`, `onFlush`, `onGet` and `onSet`,
which will be called when `close()`, `flush()`, `get()`, and `set{}`
methods will be called, respectively.
]],
   noordered=true,
   {name="self", type="tnt.Log"},
   {name="keys", type="table"},
   {name="onClose", type="table", opt=true},
   {name="onFlush", type="table", opt=true},
   {name="onGet", type="table", opt=true},
   {name="onSet", type="table", opt=true},
   call =
      function(self, keys, onClose, onFlush, onGet, onSet)
         -- __keys is the whitelist of settable keys; __status__ is always
         -- allowed so that log:status() works on any Log
         self.__keys = {__status__ = true}
         for _, key in ipairs(keys) do
            self.__keys[key] = true
         end
         self:__clear()
         if onClose then
            self:attach('onClose', onClose)
         end
         if onFlush then
            self:attach('onFlush', onFlush)
         end
         if onGet then
            self:attach('onGet', onGet)
         end
         if onSet then
            self:attach('onSet', onSet)
         end
      end
}

Log.status = argcheck{
   doc = [[

#### tnt.Log:status(@ARGP)
@ARGT

Record a status message, with corresponding (optional) time of the event.
]],
   {name="self", type="tnt.Log"},
   {name="message", type="string", opt=true},
   {name="time", type="boolean", default=true},
   call =
      function(self, message, time)
         local prefix = "|"
         if time then
            prefix = prefix .. " " .. os.date() .. " |"
         end
         -- NOTE(review): message is optional, but string.format("%s %s", ..)
         -- with a nil message raises an error — confirm callers always pass one
         self:set{
            __status__ = string.format("%s %s", prefix, message)
         }
      end
}

Log.set = argcheck{
   doc = [[

#### tnt.Log:set(@ARGP)
@ARGT

Set a number of keys (a subset of the keys provided at construction) to
their corresponding values.

Closures attached to the `onSet(log, key, value)` event will be called.
]],
   nonamed=true,
   {name="self", type="tnt.Log"},
   {name="keys", type="table"},
   call =
      function(self, keys)
         for key, value in pairs(keys) do
            assert(type(key) == 'string', 'string expected for key')
            if not self.__keys[key] then
               error(string.format("unknown key <%s>", key))
            end
            -- viewers run before the value is stored
            for _, closure in ipairs(self.__events.onSet) do
               closure(self, key, value)
            end
            self.__data[key] = value
         end
      end
}

Log.get = argcheck{
   doc = [[

#### tnt.Log:get(@ARGP)
@ARGT

Get the value of a given key.

Closures attached to the `onGet(log, key)` event will be called.
]],
   {name="self", type="tnt.Log"},
   {name="key", type="string"},
   call =
      function(self, key)
         if not self.__keys[key] then
            error(string.format("unknown key <%s>", key))
         end
         for _, closure in ipairs(self.__events.onGet) do
            closure(self, key)
         end
         return self.__data[key]
      end
}

Log.flush = argcheck{
   doc = [[

#### tnt.Log:flush(@ARGP)
@ARGT

Flush (empty) the log data.

Closures attached to the `onFlush(log)` event will be called.
]],
   {name="self", type="tnt.Log"},
   call =
      function(self)
         -- viewers run first (so they can still read the data), then clear
         for _, closure in ipairs(self.__events.onFlush) do
            closure(self)
         end
         self.__data = {}
      end
}

Log.close = argcheck{
   doc = [[

#### tnt.Log:close(@ARGP)
@ARGT

Close the log.

Closures attached to the `onClose(log)` event will be called.
]],
   {name="self", type="tnt.Log"},
   call =
      function(self)
         for _, closure in ipairs(self.__events.onClose) do
            closure(self)
         end
         -- clearing drops both data and all attached event closures
         self:__clear()
      end
}

Log.attach = argcheck{
   doc = [[

#### tnt.Log:attach(@ARGP)
@ARGT

Attach a set of functions (provided in a table) to a given event.
]],
   {name="self", type="tnt.Log"},
   {name="event", type="string"},
   {name="closures", type="table"},
   call =
      function(self, event, closures)
         local events = self.__events[event]
         assert(events, string.format('unknown event <%s>', event))
         for _, closure in ipairs(closures) do
            assert(type(closure) == 'function', string.format('%s: table of functions expected', event))
            table.insert(events, closure)
         end
      end
}
8 | ]]-- 9 | 10 | local argcheck = require 'argcheck' 11 | 12 | local json = argcheck{ 13 | noordered=true, 14 | {name="file", type="torch.File", opt=true}, 15 | {name="filename", type="string", opt=true}, 16 | {name="keys", type="table"}, 17 | {name="format", type="table", opt=true}, 18 | {name="append", type="boolean", default=false}, 19 | call = 20 | function(file, filename, keys__, format__, append) 21 | assert(not file or not filename, "file or filename expected (not both)") 22 | if filename then 23 | file = torch.DiskFile(filename, append and "rw" or "w") 24 | end 25 | local keys = {} 26 | for idx, key in ipairs(keys__) do 27 | local format = format__ and format__[idx] 28 | if not format then 29 | table.insert(keys, {name=key, format=tostring}) 30 | elseif type(format) == 'function' then 31 | table.insert(keys, {name=key, format=format}) 32 | elseif type(format) == 'string' then 33 | table.insert(keys, {name=key, format=function(value) return string.format(format, value) end}) 34 | else 35 | error('format must be a string or a function') 36 | end 37 | end 38 | return function(log) 39 | local txt = {} 40 | for _, key in ipairs(keys) do 41 | local format = key.format(log:get(key.name)) 42 | assert(type(format) == 'string', string.format("value for key %s cannot be converted to string", key)) 43 | table.insert(txt, string.format('"%s": "%s"', key.name, format)) 44 | end 45 | txt = string.format("{%s}", table.concat(txt, ", ")) 46 | if file then 47 | file:seekEnd() 48 | file:writeString(txt) 49 | file:writeString("\n") 50 | else 51 | print(txt) 52 | end 53 | end 54 | end 55 | } 56 | 57 | return json 58 | -------------------------------------------------------------------------------- /log/view/status.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 
4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local argcheck = require 'argcheck' 11 | 12 | local status = argcheck{ 13 | noordered=true, 14 | {name="file", type="torch.File", opt=true}, 15 | {name="filename", type="string", opt=true}, 16 | {name="append", type="boolean", default=false}, 17 | call = 18 | function(file, filename, append) 19 | assert(not file or not filename, "file or filename expected (not both)") 20 | if filename then 21 | file = torch.DiskFile(filename, append and "rw" or "w") 22 | end 23 | return function(data, key, value) 24 | if key == '__status__' then 25 | local status = tostring(value) 26 | if file then 27 | file:seekEnd() 28 | file:writeString(status) 29 | file:writeString("\n") 30 | else 31 | print(status) 32 | end 33 | end 34 | end 35 | end 36 | } 37 | 38 | return status 39 | -------------------------------------------------------------------------------- /log/view/text.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local argcheck = require 'argcheck' 11 | 12 | local text = argcheck{ 13 | noordered=true, 14 | {name="file", type="torch.File", opt=true}, 15 | {name="filename", type="string", opt=true}, 16 | {name="keys", type="table"}, 17 | {name="format", type="table", opt=true}, 18 | {name="separator", type="string", default=" | "}, 19 | {name="append", type="boolean", default=false}, 20 | call = 21 | function(file, filename, keys__, format__, separator, append) 22 | assert(not file or not filename, "file or filename expected (not both)") 23 | if filename then 24 | file = torch.DiskFile(filename, append and "rw" or "w") 25 | end 26 | local keys = {} 27 | for idx, key in ipairs(keys__) do 28 | local format = format__ and format__[idx] 29 | if not format then 30 | table.insert(keys, {name=key, format=function(value) return string.format("%s %s", key, value) end}) 31 | elseif type(format) == 'function' then 32 | table.insert(keys, {name=key, format=format}) 33 | elseif type(format) == 'string' then 34 | table.insert(keys, {name=key, format=function(value) return string.format(format, value) end}) 35 | else 36 | error('format must be a string or a function') 37 | end 38 | end 39 | return function(log) 40 | local txt = {} 41 | for _, key in ipairs(keys) do 42 | if log:get(key.name) then 43 | local format = key.format(log:get(key.name)) 44 | assert(type(format) == 'string', string.format("value for key %s cannot be converted to string", key)) 45 | table.insert(txt, format) 46 | end 47 | end 48 | txt = table.concat(txt, separator) 49 | if file then 50 | file:seekEnd() 51 | file:writeString(txt) 52 | file:writeString("\n") 53 | else 54 | print(txt) 55 | end 56 | end 57 | end 58 | } 59 | 60 | return text 61 | -------------------------------------------------------------------------------- /meter/apmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 
4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local APMeter = torch.class('tnt.APMeter', 'tnt.Meter', tnt) 14 | 15 | 16 | APMeter.__init = argcheck{ 17 | doc = [[ 18 | 19 | #### tnt.APMeter(@ARGP) 20 | @ARGT 21 | 22 | The `tnt.APMeter` measures the average precision per class. 23 | 24 | The `tnt.APMeter` is designed to operate on `NxK` Tensors `output` and 25 | `target`, and optionally a `Nx1` Tensor weight where (1) the `output` contains 26 | model output scores for `N` examples and `K` classes that ought to be higher 27 | when the model is more convinced that the example should be positively labeled, 28 | and smaller when the model believes the example should be negatively labeled 29 | (for instance, the output of a sigmoid function); (2) the `target` contains 30 | only values 0 (for negative examples) and 1 (for positive examples); and (3) 31 | the `weight` ( > 0) reprsents weight for each sample. 32 | 33 | The `tnt.APMeter` has no parameters to be set. 
34 | ]], 35 | {name="self", type="tnt.APMeter"}, 36 | call = function(self) 37 | self:reset() 38 | end 39 | } 40 | 41 | APMeter.reset = argcheck{ 42 | {name="self", type="tnt.APMeter"}, 43 | call = function(self) 44 | self.scores = torch.DoubleTensor(torch.DoubleStorage()) 45 | self.targets = torch.LongTensor( torch.LongStorage()) 46 | self.weights = torch.DoubleTensor(torch.DoubleStorage()) 47 | end 48 | } 49 | 50 | APMeter.add = argcheck{ 51 | {name="self", type="tnt.APMeter"}, 52 | {name="output", type="torch.*Tensor"}, 53 | {name="target", type="torch.*Tensor"}, 54 | {name="weight", type="torch.*Tensor", opt=true}, 55 | call = function(self, output, target, weight) 56 | 57 | -- assertions on the input: 58 | if weight then 59 | weight = weight:squeeze() 60 | end 61 | if output:nDimension() == 1 then 62 | output = output:view(output:size(1), 1) 63 | else 64 | assert(output:nDimension() == 2, 65 | 'wrong output size (should be 1D or 2D with one column per class)' 66 | ) 67 | end 68 | if target:nDimension() == 1 then 69 | target = target:view(target:size(1), 1) 70 | else 71 | assert(target:nDimension() == 2, 72 | 'wrong target size (should be 1D or 2D with one column per class)' 73 | ) 74 | end 75 | if weight then 76 | assert(weight:nDimension() == 1, 'Weight dimension should be 1') 77 | assert(weight:nElement() == target:size(1), 78 | 'Weight dimension 1 should be the same as that of target' 79 | ) 80 | assert(torch.ge(weight, 0):all(), 'Weight should be non-negative only') 81 | end 82 | 83 | assert(output:size(1) == target:size(1) and 84 | output:size(2) == target:size(2), 85 | 'dimensions for output and target does not match' 86 | ) 87 | assert(torch.eq(torch.eq(target, 0):add(torch.eq(target, 1)), 1):all(), 88 | 'targets should be binary (0 or 1)' 89 | ) 90 | if self.scores:nElement() > 0 then 91 | assert(output:size(2) == self.scores:size(2), 92 | 'dimensions for output should match previously added examples.' 
)
   end
   if self.targets:nElement() > 0 then
      assert(target:size(2) == self.targets:size(2),
         'dimensions for output should match previously added examples.'
      )
   end

   -- Grow the underlying storages geometrically (factor 1.5) so that
   -- repeated add() calls cost amortized O(1) per stored element:
   if self.scores:storage():size() < self.scores:nElement() + output:nElement() then
      local newsize = math.ceil(self.scores:storage():size() * 1.5)
      local newweightsize = math.ceil(self.weights:storage():size() * 1.5)
      self.scores:storage():resize(newsize + output:nElement())
      self.targets:storage():resize(newsize + output:nElement())
      if weight then
         self.weights:storage():resize(newweightsize + output:size(1))
      end
   end

   -- Append the new scores and targets after previously stored examples:
   local offset = (self.scores:dim() > 0) and self.scores:size(1) or 0
   self.scores:resize(offset + output:size(1), output:size(2))
   self.targets:resize(offset + target:size(1), target:size(2))

   self.scores:narrow(1, offset + 1, output:size(1)):copy(output)
   self.targets:narrow(1, offset + 1, target:size(1)):copy(target)

   if weight then
      self.weights:resize(offset + weight:size(1))
      self.weights:narrow(1, offset + 1, weight:size(1)):copy(weight)
   end
end
}

APMeter.value = argcheck{
   {name="self", type="tnt.APMeter"},
   call = function(self)

      -- BUGFIX: the original guard read
      --    if not self.scores:nElement() == 0 then return 0 end
      -- In Lua, `not` binds tighter than `==`, so this parsed as
      -- `(not n) == 0`, which is always false (note also that `not 0`
      -- is false, since 0 is truthy in Lua). The documented empty-meter
      -- early return therefore never fired, and value() went on to call
      -- self.scores:size(2) on an empty tensor. Return 0 when no
      -- examples have been added:
      if self.scores:nElement() == 0 then return 0 end

      -- compute average precision for each class:
      local ap = torch.DoubleTensor(self.scores:size(2)):fill(0)
      local range = torch.range(1, self.scores:size(1), 'torch.DoubleTensor')
      local weight, weightedtruth
      if self.weights:nElement() > 0 then
         weight = self.weights.new(self.weights:size())
         weightedtruth = self.weights.new(self.weights:size())
      end
      for k = 1,self.scores:size(2) do

         -- sort scores for class k in descending order:
         local scores = self.scores:select(2, k)
         local targets =
self.targets:select(2, k) 145 | local _,sortind = torch.sort(scores, 1, true) 146 | local truth = targets:index(1, sortind) 147 | if self.weights:nElement() > 0 then 148 | weight:index(self.weights, 1, sortind) 149 | torch.cmul(weightedtruth, truth:double(), weight) 150 | range = weight:cumsum() 151 | end 152 | -- compute true positive sums: 153 | local tp = weightedtruth and weightedtruth:cumsum() 154 | or truth:double():cumsum() 155 | 156 | -- compute precision curve: 157 | local precision = tp:cdiv(range) 158 | 159 | -- compute average precision: 160 | ap[k] = precision[truth:byte()]:sum() / math.max(truth:sum(), 1) 161 | end 162 | return ap 163 | end 164 | } 165 | -------------------------------------------------------------------------------- /meter/aucmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local AUCMeter = torch.class('tnt.AUCMeter', 'tnt.Meter', tnt) 14 | 15 | AUCMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.AUCMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.AUCMeter` measures the area under the receiver-operating characteristic 22 | (ROC) curve for binary classification problems. The area under the curve (AUC) 23 | can be interpreted as the probability that, given a randomly selected positive 24 | example and a randomly selected negative example, the positive example is 25 | assigned a higher score by the classification model than the negative example. 

The `tnt.AUCMeter` is designed to operate on one-dimensional Tensors `output`
and `target`, where (1) the `output` contains model output scores that ought to
be higher when the model is more convinced that the example should be positively
labeled, and smaller when the model believes the example should be negatively
labeled (for instance, the output of a sigmoid function); and (2) the `target`
contains only values 0 (for negative examples) and 1 (for positive examples).

The `tnt.AUCMeter` has no parameters to be set.
]],
   {name="self", type="tnt.AUCMeter"},
   call =
      function(self)
         self:reset()
      end
}

AUCMeter.reset = argcheck{
   {name="self", type="tnt.AUCMeter"},
   call =
      function(self)
         -- empty, growable buffers for the scores and targets that
         -- add() accumulates; value() consumes them to build the ROC:
         self.scores  = torch.DoubleTensor(torch.DoubleStorage())
         self.targets = torch.LongTensor( torch.LongStorage())
      end
}

AUCMeter.add = argcheck{
   {name="self", type="tnt.AUCMeter"},
   {name="output", type="torch.*Tensor"},
   {name="target", type="torch.*Tensor"},
   call =
      function(self, output, target)
         -- squeeze away singleton dimensions (e.g. Nx1 model outputs):
         target = target:squeeze()
         output = output:squeeze()
         assert(
            output:nDimension() == 1,
            'dimensionality of output should be 1 (e.g., nn.Sigmoid output)'
         )
         assert(
            target:nDimension() == 1,
            'dimensionality of targets should be 1'
         )
         assert(
            output:size(1) == target:size(1),
            'number of outputs and targets does not match'
         )
         assert(
            torch.eq(torch.eq(target, 0):add(torch.eq(target, 1)), 1):all(),
            'targets should be binary (0 or 1)'
         )

         -- Grow storage geometrically (factor 1.5) so repeated add()
         -- calls cost amortized O(1) per stored element:
         if self.scores:storage():size() < self.scores:nElement() + output:nElement() then
            local newsize = math.ceil(self.scores:storage():size() * 1.5)
            self.scores:storage():resize(newsize + output:nElement())
            self.targets:storage():resize(newsize + output:nElement())
| end 83 | 84 | -- store scores and targets in storage: 85 | local offset = self.scores:nElement() 86 | self.scores:resize(offset + output:nElement()) 87 | self.targets:resize(offset + target:nElement()) 88 | self.scores:narrow(1, offset + 1, output:nElement()):copy(output) 89 | self.targets:narrow(1, offset + 1, target:nElement()):copy(target) 90 | end 91 | } 92 | 93 | AUCMeter.value = argcheck{ 94 | {name="self", type="tnt.AUCMeter"}, 95 | call = 96 | function(self) 97 | 98 | -- sort the scores: 99 | if self.scores:nElement() == 0 then return 0.5 end 100 | local scores, sortind = torch.sort(self.scores, 1, true) 101 | 102 | -- construct the ROC curve: 103 | local tpr = torch.DoubleTensor(scores:nElement() + 1):zero() 104 | local fpr = torch.DoubleTensor(scores:nElement() + 1):zero() 105 | for n = 2,scores:nElement() + 1 do 106 | if self.targets[sortind[n - 1]] == 1 then 107 | tpr[n], fpr[n] = tpr[n - 1] + 1, fpr[n - 1] 108 | else 109 | tpr[n], fpr[n] = tpr[n - 1], fpr[n - 1] + 1 110 | end 111 | end 112 | tpr:div(self.targets:sum()) 113 | fpr:div(torch.mul(self.targets, -1):add(1):sum()) 114 | 115 | -- compute AUC: 116 | local auc = torch.cmul( 117 | tpr:narrow(1, 1, tpr:nElement() - 1), 118 | fpr:narrow(1, 2, fpr:nElement() - 1) - 119 | fpr:narrow(1, 1, fpr:nElement() - 1)):sum() 120 | 121 | -- return AUC and ROC curve: 122 | return auc, tpr, fpr 123 | end 124 | } 125 | -------------------------------------------------------------------------------- /meter/averagevaluemeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local AverageValueMeter = torch.class('tnt.AverageValueMeter', 'tnt.Meter', tnt)

AverageValueMeter.__init = argcheck{
   doc = [[

#### tnt.AverageValueMeter(@ARGP)
@ARGT

The `tnt.AverageValueMeter` measures and returns the average value and the
standard deviation of any collection of numbers that are `add`ed to it. It is
useful, for instance, to measure the average loss over a collection of examples.

The `add()` function expects as input a Lua number `value`, which is the value
that needs to be added to the list of values to average. It also takes as input
an optional parameter `n` that assigns a weight to `value` in the average, in
order to facilitate computing weighted averages (default = 1).

The `tnt.AverageValueMeter` has no parameters to be set at initialization time.
]],
   {name="self", type="tnt.AverageValueMeter"},
   call =
      function(self)
         self:reset()
      end
}

AverageValueMeter.reset = argcheck{
   {name="self", type="tnt.AverageValueMeter"},
   call =
      function(self)
         self.sum = 0   -- weighted sum of the added values
         self.n = 0     -- total weight of the added values
         self.var = 0   -- weighted sum of the squared values
      end
}

AverageValueMeter.add = argcheck{
   {name="self", type="tnt.AverageValueMeter"},
   {name="value", type="number"},
   {name="n", type="number", default=1},
   call =
      function(self, value, n)
         assert(n >= 0, 'example weights cannot be negative')
         -- accumulate weighted first and second moments:
         self.sum = self.sum + n * value
         self.var = self.var + n * value * value
         self.n = self.n + n
      end
}

AverageValueMeter.value = argcheck{
   {name="self", type="tnt.AverageValueMeter"},
   call =
      function(self)
         local n = self.n
         local mean = self.sum / n
         -- Unbiased estimator of the standard deviation. BUGFIX: the
         -- radicand is clamped at zero because floating-point
         -- cancellation in `self.var - n * mean * mean` can produce a
         -- tiny negative number (e.g. when all added values are equal),
         -- whose square root would be NaN even though the true variance
         -- is 0.
         local std = math.sqrt(math.max(0, self.var - n * mean * mean) / (n-1))
         return mean,
std 71 | end 72 | } 73 | -------------------------------------------------------------------------------- /meter/classerrormeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local ClassErrorMeter = torch.class('tnt.ClassErrorMeter', 'tnt.Meter', tnt) 14 | 15 | ClassErrorMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.ClassErrorMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.ClassErrorMeter` measures the classification error (in %) of 22 | classification models (zero-one loss). The meter can also measure the error of 23 | predicting the correct label among the top-k scoring labels (for instance, in 24 | the Imagenet competition, one generally measures classification@5 errors). 25 | 26 | At initialization time, it takes to optional parameters: (1) a table 27 | `topk` that contains the values at which the classification@k errors should be 28 | measures (default = {1}); and (2) a boolean `accuracy` that makes the meter 29 | output accuracies instead of errors (accuracy = 1 - error). 30 | 31 | The `add(output, target)` method takes as input an NxK-tensor `output` that 32 | contains the output scores for each of the N examples and each of the K classes, 33 | and an N-tensor `target` that contains the targets corresponding to each of the 34 | N examples (targets are integers between 1 and K). If only one example is 35 | `add`ed, `output` may also be a K-tensor and target a 1-tensor. 36 | 37 | Please note that `topk` (if specified) may not contain values larger than K. 
38 | 39 | The `value()` returns a table with the classification@k errors for all values 40 | at k that were specified in `topk` at initialization time. Alternatively, 41 | `value(k)` returns the classification@k error as a number; only values of `k` 42 | that were element of `topk` are allowed. If `accuracy` was set to `true` at 43 | initialization time, the `value()` method returns accuracies instead of errors. 44 | ]], 45 | noordered = true, 46 | {name="self", type="tnt.ClassErrorMeter"}, 47 | {name="topk", type="table", default={1}}, 48 | {name="accuracy", type="boolean", default=false}, 49 | call = 50 | function(self, topk, accuracy) 51 | self.topk = torch.LongTensor(topk):sort():totable() 52 | self.accuracy = accuracy 53 | self:reset() 54 | end 55 | } 56 | 57 | ClassErrorMeter.reset = argcheck{ 58 | {name="self", type="tnt.ClassErrorMeter"}, 59 | call = 60 | function(self) 61 | self.sum = {} 62 | for _,k in ipairs(self.topk) do 63 | self.sum[k] = 0 64 | end 65 | self.n = 0 66 | end 67 | } 68 | 69 | ClassErrorMeter.add = argcheck{ 70 | {name="self", type="tnt.ClassErrorMeter"}, 71 | {name="output", type="torch.*Tensor"}, 72 | {name="target", type="torch.*Tensor"}, 73 | call = 74 | function(self, output, target) 75 | target = target:squeeze() 76 | output = output:squeeze() 77 | if output:nDimension() == 1 then 78 | output = output:view(1, output:size(1)) 79 | assert( 80 | type(target) == 'number', 81 | 'target and output do not match') 82 | target = torch.Tensor(1):fill(target) 83 | else 84 | assert( 85 | output:nDimension() == 2, 86 | 'wrong output size (1D or 2D expected)') 87 | assert( 88 | target:nDimension() == 1, 89 | 'target and output do not match') 90 | end 91 | assert( 92 | target:size(1) == output:size(1), 93 | 'target and output do not match') 94 | 95 | local topk = self.topk 96 | local maxk = topk[#topk] 97 | local no = output:size(1) 98 | 99 | local _, pred = output:double():topk(maxk, 2, true, true) 100 | local correct = pred:typeAs(target):eq( 
101 | target:view(no, 1):expandAs(pred)) 102 | 103 | for _,k in ipairs(topk) do 104 | self.sum[k] = self.sum[k] + no - correct:narrow(2, 1, k):sum() 105 | end 106 | self.n = self.n + no 107 | end 108 | } 109 | 110 | ClassErrorMeter.value = argcheck{ 111 | {name="self", type="tnt.ClassErrorMeter"}, 112 | {name="k", type="number", opt=true}, 113 | call = 114 | function(self, k) 115 | if k then 116 | assert(self.sum[k], 'invalid k (this k was not provided at construction time)') 117 | return self.accuracy and (1-self.sum[k] / self.n)*100 or self.sum[k]*100 / self.n 118 | else 119 | local value = {} 120 | for _,k in ipairs(self.topk) do 121 | value[k] = self:value(k) 122 | end 123 | return value 124 | end 125 | end 126 | } 127 | -------------------------------------------------------------------------------- /meter/confusionmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local ConfusionMeter = torch.class('tnt.ConfusionMeter', 'tnt.Meter', tnt) 14 | 15 | ConfusionMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.ConfusionMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.ConfusionMeter` constructs a confusion matrix for a multi-class 22 | classification problems. It does not support multi-label, multi-class problems: 23 | for such problems, please use `tnt.MultiLabelConfusionMeter`. 24 | 25 | At initialization time, the `k` parameter that indicates the number 26 | of classes in the classification problem under consideration must be specified. 
27 | Additionally, an optional parameter `normalized` (default = `false`) may be 28 | specified that determines whether or not the confusion matrix is normalized 29 | (that is, it contains percentages) or not (that is, it contains counts). 30 | 31 | The `add(output, target)` method takes as input an NxK tensor `output` that 32 | contains the output scores obtained from the model for N examples and K classes, 33 | and a corresponding N-tensor or NxK-tensor `target` that provides the targets 34 | for the N examples. When `target` is an N-tensor, the targets are assumed to be 35 | integer values between 1 and K. When target is an NxK-tensor, the targets are 36 | assumed to be provided as one-hot vectors (that is, vectors that contain only 37 | zeros and a single one at the location of the target value to be encoded). 38 | 39 | The `value()` method has no parameters and returns the confusion matrix in a 40 | KxK tensor. In the confusion matrix, rows correspond to ground-truth targets and 41 | columns correspond to predicted targets. 
42 | ]], 43 | {name="self", type="tnt.ConfusionMeter"}, 44 | {name="k", type="number"}, 45 | {name="normalized", type="boolean", default=false}, 46 | call = 47 | function(self, k, normalized) 48 | self.conf = torch.LongTensor(k, k) 49 | self.normalized = normalized 50 | self:reset() 51 | end 52 | } 53 | 54 | ConfusionMeter.reset = argcheck{ 55 | {name="self", type="tnt.ConfusionMeter"}, 56 | call = 57 | function(self) 58 | self.conf:zero() 59 | end 60 | } 61 | 62 | ConfusionMeter.add = argcheck{ 63 | {name="self", type="tnt.ConfusionMeter"}, 64 | {name="output", type="torch.*Tensor"}, 65 | {name="target", type="torch.*Tensor"}, 66 | call = 67 | function(self, output, target) 68 | target = target:squeeze() 69 | output = output:squeeze() 70 | if output:nDimension() == 1 then 71 | output = output:view(1, output:size(1)) 72 | if type(target) == 'number' then 73 | target = torch.Tensor(1):fill(target) 74 | end 75 | end 76 | local onehot = not (target:nDimension() == 1) 77 | assert( 78 | target:size(1) == output:size(1), 79 | 'number of targets and outputs do not match' 80 | ) 81 | assert( 82 | output:size(2) == self.conf:size(1), 83 | 'number of outputs does not match size of confusion matrix' 84 | ) 85 | assert( 86 | not onehot or target:size(2) == output:size(2), 87 | 'target should be 1D Tensor or have size of output (one-hot)' 88 | ) 89 | if onehot then 90 | assert( 91 | torch.eq(torch.eq(target, 0):add(torch.eq(target, 1)), 1):all(), 92 | 'in one-hot encoding, target values should be 0 or 1' 93 | ) 94 | assert( 95 | torch.eq(target:sum(2), 1):all(), 96 | 'multi-label setting is not supported' 97 | ) 98 | end 99 | 100 | -- update confusion matrix: 101 | local pos 102 | local _,pred = output:double():max(2) 103 | for n = 1,pred:size(1) do 104 | if onehot then _,pos = target[n]:max(1); pos = pos[1] 105 | else pos = target[n] end 106 | self.conf[pos][pred[n][1]] = self.conf[pos][pred[n][1]] + 1 107 | end 108 | end 109 | } 110 | 111 | ConfusionMeter.value = argcheck{ 
112 | {name="self", type="tnt.ConfusionMeter"}, 113 | call = 114 | function(self) 115 | local confmat 116 | if self.normalized then 117 | confmat = torch.DoubleTensor(self.conf:size()):copy(self.conf) 118 | confmat:cdiv(confmat:sum(2):cmax(1e-12):expandAs(confmat)) 119 | else 120 | confmat = self.conf 121 | end 122 | return confmat 123 | end 124 | } 125 | -------------------------------------------------------------------------------- /meter/init.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local doc = require 'argcheck.doc' 13 | 14 | local Meter = torch.class('tnt.Meter', tnt) 15 | 16 | doc[[ 17 | ### tnt.Meter 18 | 19 | When training a model, you generally would like to measure how the model is 20 | performing. Specifically, you may want to measure the average processing time 21 | required per batch of data, the classification error or AUC of a classifier a 22 | validation set, or the precision@k of a retrieval model. 23 | 24 | Meters provide a standardized way to measure a range of different measures, 25 | which makes it easy to measure a wide range of properties of your models. 26 | 27 | Nearly all meters (except `tnt.TimeMeter`) implement three methods: 28 | 29 | * `add()` which adds an observation to the meter. 30 | * `value()` which returns the value of the meter, taking into account all observations. 31 | * `reset()` which removes all previously added observations, resetting the meter. 32 | 33 | The exact input arguments to the `add()` method vary depending on the meter. 
34 | Most meters define the method as `add(output, target)`, where `output` is the 35 | output produced by the model and `target` is the ground-truth label of the data. 36 | 37 | The `value()` method is parameterless for most meters, but for measures that 38 | have a parameter (such as the k parameter in precision@k), they may take an 39 | input argument. 40 | 41 | An example of a typical usage of a meter is as follows: 42 | ```lua 43 | local meter = tnt.Meter() -- initialize meter 44 | for state, event in tnt.Engine:train{ 45 | network = network, 46 | criterion = criterion, 47 | iterator = iterator, 48 | } do 49 | if state == 'start-epoch' then 50 | meter:reset() -- reset meter 51 | elseif state == 'forward-criterion' then 52 | meter:add(state.network.output, sample.target) -- add value to meter 53 | elseif state == 'end-epoch' then 54 | print('value of meter:' .. meter:value()) -- get value of meter 55 | end 56 | end 57 | ``` 58 | ]] 59 | 60 | Meter.__init = argcheck{ 61 | {name="self", type="tnt.Meter"}, 62 | call = 63 | function(self) 64 | end 65 | } 66 | 67 | Meter.reset = argcheck{ 68 | {name="self", type="tnt.Meter"}, 69 | call = 70 | function(self) 71 | error('A tnt.Meter should implement the reset() function.') 72 | end 73 | } 74 | 75 | Meter.value = argcheck{ 76 | {name="self", type="tnt.Meter"}, 77 | call = 78 | function(self) 79 | error('A tnt.Meter should implement the value() function.') 80 | end 81 | } 82 | 83 | Meter.add = argcheck{ 84 | {name="self", type="tnt.Meter"}, 85 | call = 86 | function(self) 87 | error('A tnt.Meter should implement the add() function.') 88 | end 89 | } 90 | 91 | Meter.__updatetopk = argcheck{ 92 | {name='self', type='tnt.Meter'}, 93 | {name='output', type='torch.*Tensor'}, 94 | {name='target', type='torch.*Tensor'}, -- target is k-hot vector 95 | {name='topk', type='number'}, -- number of values to maintain 96 | {name='dim', type='number', default=1}, -- top-k selection dimension 97 | {name='desc', type='boolean', 
default=true}, -- maintain largest values 98 | call = function(self, output, target, topk, dim, desc) 99 | assert(dim == 1 or dim == 2) 100 | 101 | -- make sure top-k buffer has the right size: 102 | local firstinput = not (self.__topkoutput and self.__topktarget) 103 | self.__topkoutput = self.__topkoutput or output.new() 104 | self.__topktarget = self.__topktarget or target.new() 105 | self.__topkoutput:resize(output:size(1) + ((dim == 1) and topk or 0), 106 | output:size(2) + ((dim == 2) and topk or 0)) 107 | self.__topktarget:resize(target:size(1) + ((dim == 1) and topk or 0), 108 | target:size(2) + ((dim == 2) and topk or 0)) 109 | if firstinput then 110 | self.__topkoutput:fill(desc and -math.huge or math.huge) 111 | end 112 | 113 | -- copy new inputs into buffer: 114 | local otherdim = (dim == 1) and 2 or 1 115 | assert(output:size(otherdim) == self.__topkoutput:size(otherdim), 116 | string.format('incorrect size of dimension %d of output', otherdim)) 117 | assert(target:size(otherdim) == self.__topktarget:size(otherdim), 118 | string.format('incorrect size of dimension %d of target', otherdim)) 119 | self.__topkoutput:narrow(dim, topk + 1, output:size(dim)):copy(output) 120 | self.__topktarget:narrow(dim, topk + 1, target:size(dim)):copy(target) 121 | 122 | -- update top-k scores: 123 | local topoutput, topind = torch.topk(self.__topkoutput, topk, dim, desc) 124 | self.__topkoutput:narrow(dim, 1, topk):copy(topoutput) 125 | self.__topktarget:narrow(dim, 1, topk):copy( 126 | self.__topktarget:gather(dim, topind) 127 | ) 128 | end 129 | } 130 | -------------------------------------------------------------------------------- /meter/mapmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. 
An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local mAPMeter = torch.class('tnt.mAPMeter', 'tnt.Meter', tnt) 14 | 15 | 16 | mAPMeter.__init = argcheck{ 17 | doc = [[ 18 | 19 | #### tnt.mAPMeter(@ARGP) 20 | @ARGT 21 | 22 | The `tnt.mAPMeter` measures the mean average precision over all classes. 23 | 24 | The `tnt.mAPMeter` is designed to operate on `NxK` Tensors `output` and 25 | `target`, and optionally a `Nx1` Tensor weight where (1) the `output` contains 26 | model output scores for `N` examples and `K` classes that ought to be higher 27 | when the model is more convinced that the example should be positively labeled, 28 | and smaller when the model believes the example should be negatively labeled 29 | (for instance, the output of a sigmoid function); (2) the `target` contains 30 | only values 0 (for negative examples) and 1 (for positive examples); and (3) 31 | the `weight` ( > 0) reprsents weight for each sample. 32 | 33 | The `tnt.mAPMeter` has no parameters to be set. 
34 | ]], 35 | {name="self", type="tnt.mAPMeter"}, 36 | call = function(self) 37 | self.apmeter = tnt.APMeter() 38 | end 39 | } 40 | 41 | mAPMeter.reset = argcheck{ 42 | {name="self", type="tnt.mAPMeter"}, 43 | call = function(self) 44 | self.apmeter:reset() 45 | end 46 | } 47 | 48 | mAPMeter.add = argcheck{ 49 | {name="self", type="tnt.mAPMeter"}, 50 | {name="output", type="torch.*Tensor"}, 51 | {name="target", type="torch.*Tensor"}, 52 | {name="weight", type="torch.*Tensor", opt=true}, 53 | call = function(self, output, target, weight) 54 | self.apmeter:add{output = output, target = target, weight = weight} 55 | end 56 | } 57 | 58 | mAPMeter.value = argcheck{ 59 | {name="self", type="tnt.mAPMeter"}, 60 | call = function(self) 61 | return self.apmeter:value():mean() 62 | end 63 | } 64 | -------------------------------------------------------------------------------- /meter/movingaveragevaluemeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local MovingAverageValueMeter = torch.class( 14 | 'tnt.MovingAverageValueMeter', 15 | 'tnt.Meter', tnt 16 | ) 17 | 18 | MovingAverageValueMeter.__init = argcheck{ 19 | doc = [[ 20 | 21 | #### tnt.MovingAverageValueMeter(@ARGP) 22 | @ARGT 23 | 24 | The `tnt.MovingAverageValueMeter` measures and returns the average value 25 | and the standard deviation of any collection of numbers that are `add`ed to it 26 | within the most recent moving average window. It is useful, for instance, 27 | to measure the average loss over a collection of examples withing the 28 | most recent window. 
29 | 30 | The `add()` function expects as input a Lua number `value`, which is the value 31 | that needs to be added to the list of values to average. 32 | 33 | The `tnt.MovingAverageValueMeter` needs the moving window size to be set at 34 | initialization time. 35 | ]], 36 | {name="self", type="tnt.MovingAverageValueMeter"}, 37 | {name="windowsize", type="number"}, 38 | call = 39 | function(self, windowsize) 40 | self.windowsize = windowsize; 41 | self.valuequeue = torch.Tensor(self.windowsize) 42 | self:reset() 43 | end 44 | } 45 | 46 | MovingAverageValueMeter.reset = argcheck{ 47 | {name="self", type="tnt.MovingAverageValueMeter"}, 48 | call = 49 | function(self) 50 | self.sum = 0 51 | self.n = 0 52 | self.var = 0 53 | self.valuequeue:fill(0.) 54 | end 55 | } 56 | 57 | MovingAverageValueMeter.add = argcheck{ 58 | {name="self", type="tnt.MovingAverageValueMeter"}, 59 | {name="value", type="number"}, 60 | call = 61 | function(self, value) 62 | local queueid = (self.n % self.windowsize) + 1 63 | local oldvalue = self.valuequeue[queueid] 64 | self.sum = self.sum + value - oldvalue 65 | self.var = self.var + value * value 66 | - oldvalue * oldvalue 67 | self.valuequeue[queueid] = value 68 | self.n = self.n + 1 69 | end 70 | } 71 | 72 | MovingAverageValueMeter.value = argcheck{ 73 | {name="self", type="tnt.MovingAverageValueMeter"}, 74 | call = 75 | function(self) 76 | local n = math.min(self.n, self.windowsize) 77 | local mean = self.sum / math.max(1, n) 78 | -- unbiased estimator of the variance: 79 | local std = math.sqrt((self.var - n * mean * mean) / math.max(1, n-1)) 80 | return mean, std 81 | end 82 | } 83 | -------------------------------------------------------------------------------- /meter/msemeter.lua: -------------------------------------------------------------------------------- 1 | local tnt = require 'torchnet.env' 2 | local argcheck = require 'argcheck' 3 | 4 | local MSEMeter = torch.class('tnt.MSEMeter', 'tnt.Meter', tnt) 5 | 6 | MSEMeter.__init = 
argcheck{
   {name = 'self', type = 'tnt.MSEMeter'},
   {name = 'root', type = 'boolean', default = false},
   call = function(self, root)
      -- When `root` is true, value() reports the root-mean-squared
      -- error (RMSE) instead of the mean-squared error (MSE).
      self:reset()
      self.root = root
   end
}

MSEMeter.reset = argcheck{
   {name = 'self', type = 'tnt.MSEMeter'},
   call = function(self)
      -- n counts elements seen so far; sesum accumulates squared errors.
      self.n = 0
      self.sesum = 0
   end
}

MSEMeter.add = argcheck{
   {name = 'self', type = 'tnt.MSEMeter'},
   {name = 'output', type = 'torch.*Tensor'},
   {name = 'target', type = 'torch.*Tensor'},
   call = function(self, output, target)
      -- output and target must agree in both shape and tensor type:
      assert(output:isSameSizeAs(target), 'output and target not the same size')
      assert(torch.isTypeOf(output, torch.typename(target)),
         'output and target not the same type')
      self.n = self.n + output:nElement()
      -- element-wise squared error, accumulated over all add() calls:
      self.sesum = self.sesum + torch.add(output, -target):pow(2):sum()
   end
}

MSEMeter.value = argcheck{
   {name = 'self', type = 'tnt.MSEMeter'},
   call = function(self)
      -- math.max guards against division by zero before any add():
      local mse = self.sesum / math.max(1, self.n)
      return self.root and math.sqrt(mse) or mse
   end
}
--------------------------------------------------------------------------------
/meter/multilabelconfusionmeter.lua:
--------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local tds = require 'tds' 12 | local argcheck = require 'argcheck' 13 | 14 | local MultiLabelConfusionMeter = 15 | torch.class('tnt.MultiLabelConfusionMeter', 'tnt.Meter', tnt) 16 | 17 | MultiLabelConfusionMeter.__init = argcheck{ 18 | doc = [[ 19 | 20 | #### tnt.MultiLabelConfusionMeter(@ARGP) 21 | @ARGT 22 | 23 | The `tnt.MultiLabelConfusionMeter` constructs a confusion matrix for multi- 24 | label, multi-class classification problems. In constructing the confusion 25 | matrix, the number of positive predictions is assumed to be equal to the number 26 | of positive labels in the ground-truth. Correct predictions (that is, labels in 27 | the prediction set that are also in the ground-truth set) are added to the 28 | diagonal of the confusion matrix. Incorrect predictions (that is, labels in the 29 | prediction set that are not in the ground-truth set) are equally divided over 30 | all non-predicted labels in the ground-truth set. 31 | 32 | At initialization time, the `k` parameter that indicates the number 33 | of classes in the classification problem under consideration must be specified. 34 | Additionally, an optional parameter `normalized` (default = `false`) may be 35 | specified that determines whether or not the confusion matrix is normalized 36 | (that is, it contains percentages) or not (that is, it contains counts). 37 | 38 | The `add(output, target)` method takes as input an NxK tensor `output` that 39 | contains the output scores obtained from the model for N examples and K classes, 40 | and a corresponding NxK-tensor `target` that provides the targets for the N 41 | examples using one-hot vectors (that is, vectors that contain only zeros and a 42 | single one at the location of the target value to be encoded). 43 | 44 | The `value()` method has no parameters and returns the confusion matrix in a 45 | KxK tensor. 
In the confusion matrix, rows correspond to ground-truth targets and 46 | columns correspond to predicted targets. 47 | ]], 48 | {name="self", type="tnt.MultiLabelConfusionMeter"}, 49 | {name="k", type="number"}, 50 | {name="normalized", type="boolean", default=true}, 51 | call = 52 | function(self, k, normalized) 53 | self.conf = torch.DoubleTensor(k, k) 54 | self.normalized = normalized 55 | self:reset() 56 | end 57 | } 58 | 59 | MultiLabelConfusionMeter.reset = argcheck{ 60 | {name="self", type="tnt.MultiLabelConfusionMeter"}, 61 | call = 62 | function(self) 63 | self.conf:zero() 64 | end 65 | } 66 | 67 | MultiLabelConfusionMeter.add = argcheck{ 68 | {name="self", type="tnt.MultiLabelConfusionMeter"}, 69 | {name="output", type="torch.*Tensor"}, 70 | {name="target", type="torch.*Tensor"}, 71 | call = 72 | function(self, output, target) 73 | target = target:squeeze() 74 | output = output:squeeze() 75 | if output:nDimension() == 1 then 76 | output = output:view(1, output:size(1)) 77 | end 78 | if target:nDimension() == 1 then 79 | target = target:view(1, target:size(1)) 80 | end 81 | assert( 82 | target:nDimension() == output:nDimension() and 83 | torch.eq( 84 | torch.LongTensor(target:size()), 85 | torch.LongTensor(output:size()) 86 | ):all(), 87 | 'number of targets and outputs do not match' 88 | ) 89 | assert( 90 | torch.eq(torch.eq(target, 0):add(torch.eq(target, 1)), 1):all(), 91 | 'target values should be 0 or 1' 92 | ) 93 | assert( 94 | target:size(2) == self.conf:size(1), 95 | 'target size does not match size of confusion matrix' 96 | ) 97 | 98 | -- update confusion matrix: 99 | local nc = output:size(2) 100 | local _,pred = output:double():sort(2, true) 101 | for n = 1,pred:size(1) do 102 | 103 | -- convert targets and predictions to sets: 104 | local targetTable, predTable = tds.hash(), tds.hash() 105 | local pos = 106 | torch.range(1, nc):round():long()[torch.eq(target[n], 1)] 107 | for k = 1,pos:nElement() do 108 | targetTable[pos[k]] = 1 109 | 
predTable[pred[n][k]] = 1 110 | end 111 | 112 | -- loop over correct predictions: 113 | for key,_ in pairs(targetTable) do 114 | if predTable[key] then 115 | self.conf[key][key] = self.conf[key][key] + 1 116 | targetTable[key] = nil 117 | predTable[key] = nil 118 | end 119 | end 120 | 121 | -- equally distribute mass of incorrect predictions: 122 | local weight = 1 / #predTable 123 | for key1,_ in pairs(targetTable) do 124 | for key2,_ in pairs(predTable) do 125 | self.conf[key1][key2] = self.conf[key1][key2] + weight 126 | end 127 | end 128 | end 129 | end 130 | } 131 | 132 | MultiLabelConfusionMeter.value = argcheck{ 133 | {name="self", type="tnt.MultiLabelConfusionMeter"}, 134 | call = 135 | function(self) 136 | local conf = torch.DoubleTensor(self.conf:size()):copy(self.conf) 137 | if self.normalized then 138 | conf:cdiv(conf:sum(2):expandAs(conf):add(1e-8)) 139 | end 140 | return conf 141 | end 142 | } 143 | -------------------------------------------------------------------------------- /meter/ndcgmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local NDCGMeter = torch.class('tnt.NDCGMeter', 'tnt.Meter', tnt) 14 | 15 | -- function computing discounted cumulative gain: 16 | local function computeDCG(relevance, index, K) 17 | 18 | -- assertions: 19 | assert(relevance) 20 | assert(index) 21 | assert(K) 22 | assert(type(K) == 'number') 23 | assert(index:max() <= relevance:nElement()) 24 | assert(index:nElement() >= K) 25 | relevance = relevance:squeeze() 26 | index = index:squeeze() 27 | assert(relevance:dim() == 1) 28 | assert(index:dim() == 1) 29 | 30 | -- return DCG: 31 | local dcg = relevance[index[1]] 32 | if K > 1 then 33 | dcg = dcg + relevance:index(1, index:narrow(1, 2, K - 1)):cdiv( 34 | torch.range(2, K):log():div(math.log(2)):typeAs(relevance) 35 | ):sum() 36 | end 37 | return dcg 38 | end 39 | 40 | -- function computing ideal discounted cumulative gain: 41 | local function computeIDCG(relevance, K) 42 | relevance = relevance:squeeze() 43 | assert(relevance:dim() == 1) 44 | local _,sortind = torch.sort(relevance, 1, true) -- descending order 45 | return computeDCG(relevance, sortind, K) 46 | end 47 | 48 | -- function computing the normalized discounted cumulative gain: 49 | local function computeNCDG(relevance, index, K) 50 | local r = computeDCG(relevance, index, K) / computeIDCG(relevance, K) 51 | assert(r >= 0 and r <= 1) 52 | return r 53 | end 54 | 55 | NDCGMeter.__init = argcheck{ 56 | doc = [[ 57 | 58 | #### tnt.NDCGMeter(@ARGP) 59 | @ARGT 60 | 61 | The `tnt.NDCGMeter` measures the normalized discounted cumulative gain (NDCG) of 62 | a ranking produced by a model at prespecified levels k, and averages the NDCG 63 | over all examples. 64 | 65 | The discounted cumulative gain at level k is defined as: 66 | 67 | DCG_k = rel_1 + \sum{i = 2}^k (rel_i / log_2(i)) 68 | 69 | Herein, rel_i is the relevance of item i as specified by an external rater. 
70 | Defining ideal DCG (IDCG) as the best possible DCG for a given example, the NDCG 71 | at level k is defined as: 72 | 73 | NDCG_k = DCG_k / IDCG_k 74 | 75 | At initialization time, the meter takes as input a table `K` that contains all 76 | the levels k at which the NDCG is computed. 77 | 78 | The `add(output, relevance)` method takes as input (1) a NxC tensor of model 79 | `outputs`, which scores for all C possible outputs for a batch of N examples; 80 | and (2) a NxC tensor `relevance` that contains the corresponding relevances for 81 | these scores, as provided by an external rater. Relevances are generally 82 | obtained from human raters. 83 | 84 | The `value()` method returns a table that contains the NDCG values for all 85 | levels K that were provided at initialization time. Alternatively, the NDCG at 86 | a specific level k can be obtained by calling `value(k)`. Note that the level 87 | `k` should be an element of the table `K` specified at initialization time. 88 | 89 | Please note that the number of outputs and relevances C should always be at 90 | least as high as the highest NDCG level k that the meter is computing. 
91 | ]], 92 | {name="self", type="tnt.NDCGMeter"}, 93 | {name="K", type="table", default = {1}}, 94 | noordered=true, 95 | call = 96 | function(self, K) 97 | self.K = torch.LongTensor(K):sort():totable() 98 | self:reset() 99 | end 100 | } 101 | 102 | NDCGMeter.reset = argcheck{ 103 | {name="self", type="tnt.NDCGMeter"}, 104 | call = 105 | function(self) 106 | self.ndcg = {} 107 | for _,k in ipairs(self.K) do self.ndcg[k] = 0 end 108 | self.n = 0 109 | end 110 | } 111 | 112 | NDCGMeter.add = argcheck{ 113 | {name="self", type="tnt.NDCGMeter"}, 114 | {name="output", type="torch.*Tensor"}, 115 | {name="relevance", type="torch.*Tensor"}, 116 | call = 117 | function(self, output, relevance) 118 | 119 | -- check inputs: 120 | if output:dim() == 1 then 121 | output:resize(1, output:nElement()) 122 | end 123 | if relevance:dim() == 1 then 124 | relevance:resize(1, relevance:nElement()) 125 | end 126 | assert(output:dim() == 2) 127 | assert(relevance:dim() == 2) 128 | assert(output:size(1) == relevance:size(1), 'batch size must match') 129 | assert(output:size(2) == relevance:size(2), 'result size must match') 130 | assert( 131 | relevance:size(2) >= self.K[#self.K], 132 | 'too few results for value of K' 133 | ) 134 | 135 | -- compute average NDCG: 136 | relevance = relevance:double() 137 | local _,index = torch.sort(output, 2, true) -- descending order 138 | for n = 1,index:size(1) do 139 | for _,k in ipairs(self.K) do 140 | self.ndcg[k] = 141 | self.ndcg[k] + computeNCDG(relevance[n], index[n], k) 142 | end 143 | end 144 | self.n = self.n + index:size(1) 145 | end 146 | } 147 | 148 | NDCGMeter.value = argcheck{ 149 | {name="self", type="tnt.NDCGMeter"}, 150 | {name="K", type="number", opt=true}, 151 | call = 152 | function(self, K) 153 | if K then 154 | assert( 155 | self.ndcg[K], 'invalid k (was not provided at construction time)' 156 | ) 157 | return self.ndcg[K] / self.n 158 | else 159 | local value = {} 160 | for _,k in ipairs(self.K) do 161 | value[k] = self.ndcg[k] 
/ self.n 162 | end 163 | return value 164 | end 165 | end 166 | } 167 | -------------------------------------------------------------------------------- /meter/precisionatkmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local PrecisionAtKMeter = torch.class('tnt.PrecisionAtKMeter', 'tnt.Meter', tnt) 14 | 15 | PrecisionAtKMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.PrecisionAtKMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.PrecisionAtKMeter` measures the precision@k of ranking methods at pre-specified 22 | levels k. The precision@k is the percentage of the k front-ranked 23 | items according to the model that is in the list of correct (positive) targets. 24 | 25 | At initialization time, a table `topk` may be given as input that specifies the 26 | levels k at which the precision@k will be measures (default = `{10}`). In 27 | addition, a number `dim` may be provided that specifies over which dimension the 28 | precision@k should be computed (default = 2), and a boolean `online` may be 29 | specified that indicates whether we see all inputs along dimension `dim` at once 30 | (default = `false`). 31 | 32 | The `add(output, target)` method takes two inputs. In the default mode (`dim=2` 33 | and `online=false`), the inputs mean: 34 | * A NxC tensor that for each of the N examples (queries) contains a score 35 | indicating to what extent each of the C classes (documents) is relevant to 36 | the query, according to the model. 
37 | * A binary NxC `target` tensor that encodes which of the C classes 38 | (documents) are actually relevant to the N-th input (query). For 39 | instance, a row of {0, 1, 0, 1} indicates that the example is associated 40 | with classes 2 and 4. 41 | 42 | The result of setting `dim` to `1` is identical to transposing the tensors 43 | `output` and `target` in the above. The result of setting `online=true` is that 44 | the function assumes that it is not the number of queries `N` that is growing 45 | with repeated calls to `add()`, but the number of candidate documents `C`. (Use 46 | this mode in scenarios where `C` is large but `N` is small.) 47 | 48 | The `value()` method returns a table that contains the precision@k (that is, the 49 | percentage of targets predicted correctly) at the cutoff levels in `topk` that 50 | were specified at initialization time. Alternatively, the precision@k at 51 | a specific level k can be obtained by calling `value(k)`. Note that the level 52 | `k` should be an element of the table `topk` specified at initialization time. 53 | 54 | Please note that the maximum value in `topk` cannot be higher than the total 55 | number of classes (documents). 
56 | ]], 57 | noordered = true, 58 | {name="self", type="tnt.PrecisionAtKMeter"}, 59 | {name="topk", type="table", default={10}}, 60 | {name="dim", type="number", default=2}, 61 | {name="online", type="boolean", default=false}, 62 | call = 63 | function(self, topk, dim, online) 64 | assert(dim == 1 or dim == 2, 'value of dimension should be 1 or 2') 65 | self.topk = torch.LongTensor(topk):sort():totable() 66 | self.maxk = self.topk[#self.topk] 67 | self.dim = dim 68 | self.online = online 69 | self:reset() 70 | end 71 | } 72 | 73 | PrecisionAtKMeter.reset = argcheck{ 74 | {name="self", type="tnt.PrecisionAtKMeter"}, 75 | call = 76 | function(self) 77 | self.tp = {} 78 | for _,k in ipairs(self.topk) do self.tp[k] = 0 end 79 | self.n = 0 80 | end 81 | } 82 | 83 | PrecisionAtKMeter.add = argcheck{ 84 | {name="self", type="tnt.PrecisionAtKMeter"}, 85 | {name="output", type="torch.*Tensor"}, 86 | {name="target", type="torch.*Tensor"}, -- target is k-hot vector 87 | call = 88 | function(self, output, target) 89 | output = output:squeeze() 90 | if output:nDimension() == 1 then 91 | output = output:view(1, output:size(1)) 92 | else 93 | assert( 94 | output:nDimension() == 2, 95 | 'wrong output size (1D or 2D expected)' 96 | ) 97 | end 98 | if target:nDimension() == 1 then 99 | target = target:view(1, target:size(1)) 100 | else 101 | assert( 102 | target:nDimension() == 2, 103 | 'wrong target size (1D or 2D expected)' 104 | ) 105 | end 106 | for s = 1,#output:size() do 107 | assert(output:size(s) == target:size(s), 108 | string.format('target and output do not match on dim %d', s)) 109 | end 110 | 111 | -- add new output-target pairs: 112 | if self.online then 113 | 114 | -- update top-k list along dimension dim: 115 | self:__updatetopk{ 116 | output = output, 117 | target = target, 118 | topk = self.maxk, 119 | dim = self.dim, 120 | desc = true, 121 | } 122 | self.n = output:size((self.dim == 1) and 2 or 1) 123 | else 124 | 125 | -- accumulate counts of true positives and 
total # of inputs: 126 | local topout, topind = torch.topk(output, self.maxk, self.dim, true) 127 | local _,sortind = torch.sort(topout, self.dim, true) 128 | local topind = topind:gather(self.dim, sortind) 129 | local sorttarget = target:gather(self.dim, topind) 130 | for _,k in ipairs(self.topk) do 131 | self.tp[k] = self.tp[k] + sorttarget:narrow(self.dim, 1, k):sum() 132 | end 133 | self.n = self.n + target:size((self.dim == 1) and 2 or 1) 134 | end 135 | end 136 | } 137 | 138 | PrecisionAtKMeter.value = argcheck{ 139 | {name="self", type="tnt.PrecisionAtKMeter"}, 140 | {name="k", type="number", opt=true}, 141 | call = 142 | function(self, k) 143 | 144 | -- in online mode, sort outputs and corresponding targets: 145 | if self.online then 146 | local topoutput = self.__topkoutput:narrow(self.dim, 1, self.maxk) 147 | local toptarget = self.__topktarget:narrow(self.dim, 1, self.maxk) 148 | local _,sortind = torch.sort(topoutput, self.dim, true) 149 | local sorttarget = toptarget:gather(self.dim, sortind) 150 | for _,k in ipairs(self.topk) do 151 | self.tp[k] = sorttarget:narrow(self.dim, 1, k):sum() 152 | end 153 | end 154 | 155 | -- compute the precision@k: 156 | if k then 157 | assert(self.tp[k], 'invalid k (not provided at construction time)') 158 | return (self.tp[k] / (self.n * k)) * 100 159 | else 160 | local value = {} 161 | for _,k in ipairs(self.topk) do value[k] = self:value(k) end 162 | return value 163 | end 164 | end 165 | } 166 | -------------------------------------------------------------------------------- /meter/precisionmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local PrecisionMeter = torch.class('tnt.PrecisionMeter', 'tnt.Meter', tnt) 14 | 15 | PrecisionMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.PrecisionMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.PrecisionMeter` measures the precision of ranking methods at pre- 22 | specified thresholds. The precision is the percentage of the positively labeled 23 | items according to the model that is in the list of correct (positive) targets. 24 | 25 | At initialization time, the `tnt.PrecisionMeter` provides two optional 26 | parameters. The first parameter is a table `threshold` that contains all 27 | thresholds at which the precision is measured (default = {0.5}). Thresholds 28 | should be numbers between 0 and 1. The second parameter is a boolean `perclass` 29 | that makes the meter measure the precision per class when set to `true` 30 | (default = `false`). When `perclass` is set to `false`, the precision is simply 31 | averaged over all examples. 32 | 33 | The `add(output, target)` method takes two inputs: 34 | * A NxK tensor that for each of the N examples indicates the probability 35 | of the example belonging to each of the K classes, according to the model. 36 | The probabilities should sum to one over all classes; that is, the row sums 37 | of `output` should all be one. 38 | * A binary NxK `target` tensor that encodes which of the K classes 39 | are associated with the N-th input. For instance, a row of {0, 1, 0, 1} 40 | indicates that the example is associated with classes 2 and 4. 41 | 42 | The `value()` method returns a table containing the precision of the model 43 | predictions measured at the `threshold`s specified at initialization time. The 44 | `value(t)` method returns the precision at a particular threshold `t`. Note that 45 | this threshold `t` should be an element of the `threshold` table specified at 46 | initialization time of the meter. 
47 | ]], 48 | noordered = true, 49 | {name="self", type="tnt.PrecisionMeter"}, 50 | {name="threshold", type="table", default={0.5}}, 51 | {name="perclass", type="boolean", default=false}, 52 | call = function(self, threshold, perclass) 53 | self.threshold = {} 54 | for _,n in pairs(threshold) do 55 | assert(n >= 0 and n <= 1, 'threshold should be between 0 and 1') 56 | table.insert(self.threshold, n) 57 | end 58 | table.sort(self.threshold) 59 | self.perclass = perclass 60 | self:reset() 61 | end 62 | } 63 | 64 | PrecisionMeter.reset = argcheck{ 65 | {name="self", type="tnt.PrecisionMeter"}, 66 | call = function(self) 67 | self.tp = {} 68 | self.tpfp = {} 69 | for _,t in ipairs(self.threshold) do 70 | self.tp[t] = torch.Tensor() 71 | self.tpfp[t] = torch.Tensor() 72 | end 73 | end 74 | } 75 | 76 | PrecisionMeter.add = argcheck{ 77 | {name="self", type="tnt.PrecisionMeter"}, 78 | {name="output", type="torch.*Tensor"}, 79 | {name="target", type="torch.*Tensor"}, -- target is k-hot vector 80 | call = function(self, output, target) 81 | output = output:squeeze() 82 | if output:nDimension() == 1 then 83 | output = output:view(1, output:size(1)) 84 | else 85 | assert( 86 | output:nDimension() == 2, 87 | 'wrong output size (1D or 2D expected)' 88 | ) 89 | end 90 | if target:nDimension() == 1 then 91 | target = target:view(1, target:size(1)) 92 | else 93 | assert( 94 | target:nDimension() == 2, 95 | 'wrong target size (1D or 2D expected)' 96 | ) 97 | end 98 | for s = 1,#output:size() do 99 | assert( 100 | output:size(s) == target:size(s), 101 | string.format('target and output do not match on dim %d', s) 102 | ) 103 | end 104 | 105 | -- initialize upon receiving first inputs: 106 | for _,t in ipairs(self.threshold) do 107 | if self.tp[t]:nElement() == 0 then 108 | self.tp[t]:resize( target:size(2)):typeAs(output):fill(0) 109 | self.tpfp[t]:resize(target:size(2)):typeAs(output):fill(0) 110 | end 111 | end 112 | 113 | -- scores of true positives: 114 | local true_pos = 
output:clone() 115 | true_pos[torch.eq(target, 0)] = -1 116 | 117 | -- sum all the things: 118 | for _,t in ipairs(self.threshold) do 119 | self.tp[t]:add( torch.ge(true_pos, t):typeAs(output):sum(1):squeeze()) 120 | self.tpfp[t]:add(torch.ge(output, t):typeAs(output):sum(1):squeeze()) 121 | end 122 | end 123 | } 124 | 125 | PrecisionMeter.value = argcheck{ 126 | {name="self", type="tnt.PrecisionMeter"}, 127 | {name="t", type="number", opt=true}, 128 | call = function(self, t) 129 | if t then 130 | assert( 131 | self.tp[t] and self.tpfp[t], 132 | string.format('%f is an incorrect threshold [not set]', t) 133 | ) 134 | if self.perclass then 135 | local precisionPerClass = 136 | torch.cdiv(self.tp[t], self.tpfp[t]):mul(100) 137 | precisionPerClass[torch.eq(self.tpfp[t], 0)] = 100 138 | return precisionPerClass 139 | else 140 | if self.tpfp[t]:sum() == 0 then return 100 end 141 | return self.tp[t]:sum() / self.tpfp[t]:sum() * 100 142 | end 143 | else 144 | local value = {} 145 | for _,t in ipairs(self.threshold) do 146 | value[t] = self:value(t) 147 | end 148 | return value 149 | end 150 | end 151 | } 152 | -------------------------------------------------------------------------------- /meter/recallmeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local RecallMeter = torch.class('tnt.RecallMeter', 'tnt.Meter', tnt) 14 | 15 | RecallMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.RecallMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.RecallMeter` measures the recall of ranking methods at pre- 22 | specified thresholds. The recall is the percentage of the correct (positive) 23 | targets that is in the list of positively labeled items according to the model. 24 | 25 | At initialization time, the `tnt.RecallMeter` provides two optional 26 | parameters. The first parameter is a table `threshold` that contains all 27 | thresholds at which the recall is measured (default = {0.5}). Thresholds 28 | should be numbers between 0 and 1. The second parameter is a boolean `perclass` 29 | that makes the meter measure the recall per class when set to `true` 30 | (default = `false`). When `perclass` is set to `false`, the recall is simply 31 | averaged over all examples. 32 | 33 | The `add(output, target)` method takes two inputs: 34 | * A NxK tensor that for each of the N examples indicates the probability 35 | of the example belonging to each of the K classes, according to the model. 36 | The probabilities should sum to one over all classes; that is, the row sums 37 | of `output` should all be one. 38 | * A binary NxK `target` tensor that encodes which of the K classes 39 | are associated with the N-th input. For instance, a row of {0, 1, 0, 1} 40 | indicates that the example is associated with classes 2 and 4. 41 | 42 | The `value()` method returns a table containing the recall of the model 43 | predictions measured at the `threshold`s specified at initialization time. The 44 | `value(t)` method returns the recall at a particular threshold `t`. Note that 45 | this threshold `t` should be an element of the `threshold` table specified at 46 | initialization time of the meter. 
47 | ]], 48 | noordered = true, 49 | {name="self", type="tnt.RecallMeter"}, 50 | {name="threshold", type="table", default={0.5}}, 51 | {name="perclass", type="boolean", default=false}, 52 | call = function(self, threshold, perclass) 53 | self.threshold = {} 54 | for _,n in pairs(threshold) do 55 | assert(n >= 0 and n <= 1, 'threshold should be between 0 and 1') 56 | table.insert(self.threshold, n) 57 | end 58 | table.sort(self.threshold) 59 | self.perclass = perclass 60 | self:reset() 61 | end 62 | } 63 | 64 | RecallMeter.reset = argcheck{ 65 | {name="self", type="tnt.RecallMeter"}, 66 | call = function(self) 67 | self.tp = {} 68 | self.tpfn = {} 69 | for _,t in ipairs(self.threshold) do 70 | self.tp[t] = torch.Tensor() 71 | self.tpfn[t] = torch.Tensor() 72 | end 73 | end 74 | } 75 | 76 | RecallMeter.add = argcheck{ 77 | {name="self", type="tnt.RecallMeter"}, 78 | {name="output", type="torch.*Tensor"}, 79 | {name="target", type="torch.*Tensor"}, -- target is k-hot vector 80 | call = function(self, output, target) 81 | output = output:squeeze() 82 | if output:nDimension() == 1 then 83 | output = output:view(1, output:size(1)) 84 | else 85 | assert( 86 | output:nDimension() == 2, 87 | 'wrong output size (1D or 2D expected)' 88 | ) 89 | end 90 | if target:nDimension() == 1 then 91 | target = target:view(1, target:size(1)) 92 | else 93 | assert( 94 | target:nDimension() == 2, 95 | 'wrong target size (1D or 2D expected)' 96 | ) 97 | end 98 | for s = 1,#output:size() do 99 | assert( 100 | output:size(s) == target:size(s), 101 | string.format('target and output do not match on dim %d', s) 102 | ) 103 | end 104 | 105 | -- initialize upon receiving first inputs: 106 | for _,t in ipairs(self.threshold) do 107 | if self.tp[t]:nElement() == 0 then 108 | self.tp[t]:resize( target:size(2)):typeAs(output):fill(0) 109 | self.tpfn[t]:resize(target:size(2)):typeAs(output):fill(0) 110 | end 111 | end 112 | 113 | -- scores of true positives: 114 | local true_pos = output:clone() 115 | 
true_pos[torch.eq(target, 0)] = -1 116 | 117 | -- sum all the things: 118 | for _,t in ipairs(self.threshold) do 119 | self.tp[t]:add(torch.ge(true_pos, t):typeAs(output):sum(1):squeeze()) 120 | self.tpfn[t]:add(target:typeAs(output):sum(1):squeeze()) 121 | end 122 | end 123 | } 124 | 125 | RecallMeter.value = argcheck{ 126 | {name="self", type="tnt.RecallMeter"}, 127 | {name="t", type="number", opt=true}, 128 | call = function(self, t) 129 | if t then 130 | assert( 131 | self.tp[t] and self.tpfn[t], 132 | string.format('%f is an incorrect threshold [not set]', t) 133 | ) 134 | if self.perclass then 135 | local recallPerClass = 136 | torch.cdiv(self.tp[t], self.tpfn[t]):mul(100) 137 | recallPerClass[torch.eq(self.tpfn[t], 0)] = 100 138 | return recallPerClass 139 | else 140 | if self.tpfn[t]:sum() == 0 then return 100 end 141 | return (self.tp[t]:sum() / self.tpfn[t]:sum()) * 100 142 | end 143 | else 144 | local value = {} 145 | for _,t in ipairs(self.threshold) do 146 | value[t] = self:value(t) 147 | end 148 | return value 149 | end 150 | end 151 | } 152 | -------------------------------------------------------------------------------- /meter/timemeter.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | 13 | local TimeMeter = torch.class('tnt.TimeMeter', 'tnt.Meter', tnt) 14 | 15 | TimeMeter.__init = argcheck{ 16 | doc = [[ 17 | 18 | #### tnt.TimeMeter(@ARGP) 19 | @ARGT 20 | 21 | The `tnt.TimeMeter` is designed to measure the time between events and can be 22 | used to measure, for instance, the average processing time per batch of data. 23 | It is different from most other meters in terms of the methods it provides: 24 | 25 | At initialization time, an optional boolean parameter `unit` may be provided 26 | (default = `false`). When set to `true`, the value returned by the meter 27 | will be divided by the number of times that the `incUnit()` method is called. 28 | This allows the user to compute, for instance, the average processing time per 29 | batch by simply calling the `incUnit()` method after processing a batch. 30 | 31 | The `tnt.TimeMeter` provides the following methods: 32 | 33 | * `reset()` resets the timer, setting the timer and unit counter to zero. 34 | * `stop()` stops the timer. 35 | * `resume()` resumes the timer. 36 | * `incUnit()` increments the unit counter by one. 37 | * `value()` returns the time passed since the last `reset()`; divided by the counter value when `unit=true`. 
38 | ]], 39 | {name="self", type="tnt.TimeMeter"}, 40 | {name="unit", type="boolean", default=false}, 41 | call = 42 | function(self, unit) 43 | self.unit = unit 44 | self.timer = torch.Timer() 45 | self:reset() 46 | end 47 | } 48 | 49 | TimeMeter.reset = argcheck{ 50 | {name="self", type="tnt.TimeMeter"}, 51 | call = 52 | function(self) 53 | self.timer:reset() 54 | self.n = 0 55 | end 56 | } 57 | 58 | TimeMeter.stop = argcheck{ 59 | {name="self", type="tnt.TimeMeter"}, 60 | call = 61 | function(self) 62 | self.timer:stop() 63 | end 64 | } 65 | 66 | TimeMeter.resume = argcheck{ 67 | {name="self", type="tnt.TimeMeter"}, 68 | call = 69 | function(self) 70 | self.timer:resume() 71 | end 72 | } 73 | 74 | TimeMeter.incUnit = argcheck{ 75 | {name="self", type="tnt.TimeMeter"}, 76 | {name="value", type="number", default=1}, 77 | call = 78 | function(self, value) 79 | self.n = self.n + value 80 | end 81 | } 82 | 83 | TimeMeter.value = argcheck{ 84 | {name="self", type="tnt.TimeMeter"}, 85 | call = 86 | function(self) 87 | local time = self.timer:time().real 88 | if self.unit then 89 | return time/self.n 90 | else 91 | return time 92 | end 93 | end 94 | } 95 | -------------------------------------------------------------------------------- /rocks/torchnet-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "torchnet" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/torchnet/torchnet.git" 6 | } 7 | 8 | description = { 9 | summary = "Torch on steroids", 10 | detailed = [[ 11 | Various abstractions for torch7. 
12 | ]], 13 | homepage = "https://github.com/torchnet/torchnet", 14 | license = "BSD" 15 | } 16 | 17 | dependencies = { 18 | "lua >= 5.1", 19 | "torch >= 7.0", 20 | "nn >= 1.0", 21 | "argcheck >= 1.0", 22 | "threads >= 1.0", 23 | "md5 >= 1.0", 24 | "luafilesystem >= 1.0", 25 | "optim >= 1.0", 26 | "tds >= 1.0", 27 | } 28 | 29 | build = { 30 | type = "cmake", 31 | variables = { 32 | CMAKE_BUILD_TYPE="Release", 33 | LUA_PATH="$(LUADIR)", 34 | LUA_CPATH="$(LIBDIR)" 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /test/datasets.lua: -------------------------------------------------------------------------------- 1 | local tnt = require 'torchnet.env' 2 | local tds = require 'tds' 3 | 4 | local tester 5 | local test = torch.TestSuite() 6 | 7 | function test.TableDataset() 8 | local d = tnt.TableDataset{data = {1, 2, 3}} 9 | tester:eq(d:size(), 3) 10 | tester:eq(d:get(1), 1) 11 | end 12 | 13 | function test.ListDataset() 14 | local identity = function(...) return ... end 15 | 16 | local h = tds.hash({ 1, 2, 3}) 17 | local d = tnt.ListDataset(h, identity) 18 | tester:eq(d:size(), 3) 19 | tester:eq(d:get(1), 1) 20 | 21 | local v = tds.vec({ 1, 2, 3}) 22 | local d = tnt.ListDataset(v, identity) 23 | tester:eq(d:size(), 3) 24 | tester:eq(d:get(1), 1) 25 | 26 | local tbl = {1, 2, 3} 27 | local d = tnt.ListDataset(tbl, identity) 28 | tester:eq(d:size(), 3) 29 | tester:eq(d:get(1), 1) 30 | 31 | local tensor = torch.LongTensor{1, 2, 3} 32 | local d = tnt.ListDataset(tensor, identity) -- exercise the tensor-backed list path (was mistakenly re-testing `tbl`) 33 | tester:eq(d:size(), 3) 34 | tester:eq(d:get(1), 1) 35 | end 36 | 37 | function test.ListDataset_path() 38 | -- With path option 39 | local prefix = function(x) return 'bar/' .. 
x end 40 | local tbl = {1, 2, 3} 41 | local d = tnt.ListDataset(tbl, prefix, 'foo') 42 | tester:eq(d:size(), 3) 43 | tester:eq(d:get(3), 'bar/foo/3') 44 | end 45 | 46 | function test.ListDataset_file() 47 | local filename = os.tmpname() 48 | local f = io.open(filename, 'w') 49 | for i = 1, 50 do 50 | f:write(tostring(i) .. '\n') 51 | end 52 | f:close() 53 | 54 | local identity = function(...) return ... end 55 | local d = tnt.ListDataset(filename, identity, 'foo') 56 | tester:eq(d:size(), 50) 57 | tester:eq(d:get(15), 'foo/15') 58 | 59 | os.remove(filename) 60 | end 61 | 62 | function test.TransformDataset() 63 | local d = tnt.TransformDataset{ 64 | dataset = tnt.TableDataset{data = {1, 2, 3}}, 65 | transform = function(x) return x * 2 end 66 | } 67 | tester:eq(d:size(), 3) 68 | tester:eq(d:get(2), 4) 69 | 70 | local data = { 71 | { input = 1, target = 1 }, 72 | { input = 2, target = 2 }, 73 | { input = 3, target = 3 }, 74 | } 75 | 76 | local d = tnt.TransformDataset{ 77 | dataset = tnt.TableDataset{data = data}, 78 | transform = function(x) return x * 2 end, 79 | key = 'input', 80 | } 81 | tester:eq(d:size(), 3) 82 | tester:eq(d:get(2).input, 4) 83 | tester:eq(d:get(2).target, 2) 84 | 85 | 86 | local d = tnt.TransformDataset{ 87 | dataset = tnt.TableDataset{data = data}, 88 | transforms = { 89 | input = function(x) return x + 1 end, 90 | target = function(x) return x * 2 end, 91 | }, 92 | } 93 | tester:eq(d:size(), 3) 94 | tester:eq(d:get(2).input, 3) 95 | tester:eq(d:get(2).target, 4) 96 | 97 | -- alternative way of expressing the same transform 98 | local d = tnt.TableDataset{data = data} 99 | :transform(function(x) return x + 1 end, 'input') 100 | :transform(function(x) return x * 2 end, 'target') 101 | tester:eq(d:size(), 3) 102 | tester:eq(d:get(2).input, 3) 103 | tester:eq(d:get(2).target, 4) 104 | end 105 | 106 | function test.ConcatDataset() 107 | local datasets = { 108 | tnt.TableDataset{data={1, 2, 3}}, 109 | tnt.TableDataset{data={4, 5, 6}}, 110 | 
tnt.TableDataset{data={7, 8, 9}}, 111 | } 112 | local d = tnt.ConcatDataset{datasets=datasets} 113 | tester:eq(d:size(), 9) 114 | for i = 1, 9 do 115 | tester:eq(d:get(i), i) 116 | end 117 | end 118 | 119 | function test.ResampleDataset() 120 | local tbl = tnt.TableDataset{data={1, 2, 3}} 121 | local function sampler(dataset, i) 122 | return (i % 3) + 1 123 | end 124 | local d = tnt.ResampleDataset(tbl, sampler) 125 | tester:eq(d:size(), 3) 126 | tester:eq(d:get(1), 2) 127 | tester:eq(d:get(3), 1) 128 | 129 | local d = tnt.ResampleDataset(tbl, sampler, 2) 130 | tester:eq(d:size(), 2) 131 | tester:eq(d:get(1), 2) 132 | local ok, _ = pcall(function() d:get(3) end) 133 | tester:assert(not ok, 'should be out of range') 134 | end 135 | 136 | function test.ShuffleDataset() 137 | local tbl = tnt.TableDataset{data={1, 2, 3, 4, 5}} 138 | local d = tnt.ShuffleDataset(tbl, sampler) 139 | tester:eq(d:size(), 5) 140 | local present = {} 141 | for i = 1, d:size() do 142 | local val = d:get(i) 143 | tester:assert(not present[val], 'every item should appear exactly once') 144 | present[val] = true 145 | end 146 | for i = 1, d:size() do 147 | tester:assert(present[d:get(i)], 'every item should appear exactly once') 148 | end 149 | end 150 | 151 | -- function test.SplitDataset() 152 | -- local tbl = tnt.TableDataset{data={1, 2, 3, 4, 5, 6}} 153 | -- local d = tnt.SplitDataset(tbl, {train=2, val=4}) 154 | -- -- partitions are sorted alphabetically, train comes before val 155 | -- tester:assert(d:size() == 2) 156 | -- 157 | -- d:select('train') 158 | -- tester:eq(d:size(), 2) 159 | -- tester:eq(d:get(1), 1) 160 | -- 161 | -- d:select('val') 162 | -- tester:eq(d:size(), 4) 163 | -- tester:eq(d:get(1), 3) 164 | -- 165 | -- -- equal weight 166 | -- local d = tnt.SplitDataset(tbl, {train=0.2, val=0.3}) 167 | -- d:select('train'); tester:eq(d:size(), 3) 168 | -- d:select('val'); tester:eq(d:size(), 3) 169 | -- end 170 | 171 | function test.BatchDataset() 172 | local data = {} 173 | for i = 
1, 100 do 174 | data[i] = { 175 | input = i, 176 | target = torch.LongTensor{i, 2*i}, 177 | } 178 | end 179 | local tbl = tnt.TableDataset{data=data} 180 | local d = tnt.BatchDataset(tbl, 30) 181 | tester:eq(d:size(), 4) 182 | 183 | local batch = d:get(2) 184 | tester:eq(torch.type(batch.input), 'table') 185 | tester:eq(#batch.input, 30) 186 | tester:eq(batch.input[1], 31) 187 | tester:eq(torch.type(batch.target), 'torch.LongTensor') 188 | tester:eq(batch.target:size(1), 30) 189 | tester:eq(batch.target:numel(), 60) 190 | tester:eq(batch.target[1][2], 62) 191 | 192 | -- last batch has the remainder 193 | tester:eq(#d:get(4).input, 10) 194 | 195 | local d = tnt.BatchDataset(tbl, 30, 'skip-last') 196 | tester:eq(d:size(), 3) 197 | tester:eq(#d:get(3).input, 30) 198 | 199 | -- divisible-only should trigger an error with batch size 30 200 | tester:assertErrorPattern( 201 | function() tnt.BatchDataset(tbl, 30, 'divisible-only') end, 202 | 'not divisible') 203 | 204 | -- divisible-only should succeed with batch size 20 205 | local d = tnt.BatchDataset(tbl, 20, 'divisible-only') 206 | tester:eq(d:size(), 5) 207 | tester:eq(#d:get(3).input, 20) 208 | 209 | -- test with custom merge 210 | local d = tnt.BatchDataset{ 211 | dataset = tbl, 212 | batchsize = 30, 213 | merge = tnt.transform.tableapply(function(field) 214 | if type(field[1]) == 'number' then 215 | return torch.IntTensor(field) 216 | else 217 | return tnt.utils.table.mergetensor(field) 218 | end 219 | end), 220 | } 221 | tester:eq(d:size(), 4) 222 | tester:eq(torch.type(d:get(1).input), 'torch.IntTensor') 223 | tester:eq(d:get(1).input:numel(), 30) 224 | end 225 | 226 | function test.IndexedDataset() 227 | local tmpdir = os.tmpname() 228 | os.remove(tmpdir) -- tmpname creates the file 229 | assert(paths.mkdir(tmpdir)) 230 | 231 | -- write integers tensors 232 | local w = tnt.IndexedDatasetWriter{ 233 | indexfilename = paths.concat(tmpdir, 'ints.idx'), 234 | datafilename = paths.concat(tmpdir, 'ints.bin'), 235 | 
type = 'int', 236 | } 237 | for i = 1, 10 do 238 | local t = torch.range(1, i):int() 239 | w:add(t) 240 | end 241 | w:close() 242 | 243 | -- write indexed tables 244 | local w = tnt.IndexedDatasetWriter{ 245 | indexfilename = paths.concat(tmpdir, 'tables.idx'), 246 | datafilename = paths.concat(tmpdir, 'tables.bin'), 247 | type = 'table', 248 | } 249 | for i = 1, 10 do 250 | w:add({'a', 'b', i}) 251 | end 252 | w:close() 253 | 254 | local d = tnt.IndexedDataset{ 255 | path = tmpdir, 256 | fields = {'ints','tables'}, 257 | } 258 | tester:eq(d:size(), 10) 259 | tester:eq(d:get(4), { 260 | ints = torch.range(1, 4):int(), 261 | tables = {'a', 'b', 4 }, 262 | }) 263 | 264 | local d = tnt.IndexedDataset{ 265 | path = tmpdir, 266 | fields = {'tables'}, 267 | standalone = true 268 | } 269 | tester:eq(d:size(), 10) 270 | for i=1,10 do 271 | tester:eq( 272 | d:get(i), 273 | {'a', 'b', i } 274 | ) 275 | end 276 | 277 | local d = tnt.IndexedDataset{ 278 | path = tmpdir, 279 | fields = {'ints'}, 280 | standalone = true 281 | } 282 | tester:eq(d:size(), 10) 283 | for i=1,10 do 284 | tester:eq( 285 | d:get(i), 286 | torch.range(1, i):int() 287 | ) 288 | end 289 | 290 | assert(os.remove(paths.concat(tmpdir, 'ints.idx'))) 291 | assert(os.remove(paths.concat(tmpdir, 'ints.bin'))) 292 | assert(os.remove(paths.concat(tmpdir, 'tables.idx'))) 293 | assert(os.remove(paths.concat(tmpdir, 'tables.bin'))) 294 | assert(os.remove(tmpdir)) 295 | end 296 | 297 | function test.CoroutineBatchDataset_basic() 298 | -- CoroutineBatchDataset without coroutines should work like BatchDataset 299 | local data = {} 300 | for i = 1, 100 do 301 | data[i] = { 302 | input = i, 303 | target = torch.LongTensor{i, 2*i}, 304 | } 305 | end 306 | local tbl = tnt.TableDataset{data=data} 307 | local d = tnt.CoroutineBatchDataset(tbl, 20) 308 | tester:eq(d:size(), 5) 309 | tester:eq(d:get(2).input, torch.range(21, 40):totable()) 310 | tester:eq(d:get(2).target:size(), torch.LongStorage{20, 2}) 311 | end 312 | 313 | 
function test.CoroutineBatchDataset() 314 | local base = tnt.Dataset() 315 | base.__keys = {} 316 | base.__buffer = {} 317 | function base:size() 318 | return 100 319 | end 320 | function base:get(i) 321 | table.insert(self.__keys, i) 322 | coroutine.yield() 323 | if self.__buffer[i] == nil then 324 | tester:assert(i % 20 == 1, 'expected to be the first sample') 325 | tester:eq(#self.__keys, 20, 'expected batch of 20') 326 | for _, k in ipairs(self.__keys) do 327 | self.__buffer[k] = { input = k * 2 } 328 | end 329 | self.__keys = {} 330 | end 331 | tester:eq(#self.__keys, 0, 'keys should be empty') 332 | local val = self.__buffer[i] 333 | tester:assert(val ~= nil) 334 | self.__buffer[i] = nil 335 | return val 336 | end 337 | local d = tnt.CoroutineBatchDataset(base, 20) 338 | tester:eq(d:size(), 5) 339 | tester:eq(d:get(2), { input = torch.range(42, 80, 2):totable() }) 340 | end 341 | 342 | return function(_tester_) 343 | tester = _tester_ 344 | return test 345 | end 346 | -------------------------------------------------------------------------------- /test/iterators.lua: -------------------------------------------------------------------------------- 1 | local tnt = require 'torchnet.env' 2 | local tds = require 'tds' 3 | 4 | local tester 5 | local test = torch.TestSuite() 6 | 7 | function test.DatasetIterator() 8 | local d = tnt.TableDataset{data = {1, 2, 3, 4, 5, 6}} 9 | 10 | local itr = tnt.DatasetIterator(d) 11 | local count = 0 12 | for sample in itr:run() do 13 | count = count + 1 14 | tester:eq(sample, count) 15 | end 16 | tester:eq(count, 6) 17 | end 18 | 19 | function test.DatasetIterator_filter() 20 | local d = tnt.TableDataset{data = {1, 2, 3, 4, 5, 6}} 21 | local itr = tnt.DatasetIterator{ 22 | dataset = d, 23 | filter = function(x) return x % 2 == 0 end, 24 | } 25 | local count = 0 26 | for sample in itr:run() do 27 | count = count + 1 28 | tester:eq(sample, count * 2, 'error at ' .. 
count) 29 | end 30 | tester:eq(count, 3) 31 | end 32 | 33 | function test.DatasetIterator_transform() 34 | local d = tnt.TableDataset{data = {1, 2, 3, 4, 5, 6}} 35 | local itr = tnt.DatasetIterator{ 36 | dataset = d, 37 | transform = function(x) return x - 1 end, 38 | } 39 | local count = 0 40 | for sample in itr:run() do 41 | count = count + 1 42 | tester:eq(sample, count - 1, 'error at ' .. count) 43 | end 44 | tester:eq(count, 6) 45 | end 46 | 47 | function test.DatasetIterator_exec() 48 | local itr = tnt.TableDataset{data = torch.range(1, 100):totable()} 49 | :shuffle() 50 | :transform(function(i) return i end) 51 | :iterator() 52 | 53 | local output1 = {} 54 | local output2 = {} 55 | 56 | for value in itr:run() do table.insert(output1, value) end 57 | itr:exec('resample') 58 | for value in itr:run() do table.insert(output2, value) end 59 | 60 | tester:ne(output1, output2, 'dataset not shuffled') 61 | 62 | table.sort(output1) 63 | table.sort(output2) 64 | tester:eq(output1, torch.range(1, 100):totable(), 'output1 incorrect') 65 | tester:eq(output2, torch.range(1, 100):totable(), 'output2 incorrect') 66 | end 67 | 68 | function test.DatasetIterator_perm() 69 | local d = tnt.TableDataset{data = {1, 2, 3, 4, 5, 6}} 70 | local itr = tnt.DatasetIterator{ 71 | dataset = d, 72 | perm = function(x) return (x % 6) + 1 end, 73 | } 74 | local count = 0 75 | for sample in itr:run() do 76 | count = count + 1 77 | tester:eq(sample, (count % 6) + 1, 'error at ' .. count) 78 | end 79 | tester:eq(count, 6) 80 | end 81 | 82 | function test.ParallelDatasetIterator() 83 | local d = tnt.TableDataset{data = {1, 2, 3, 4, 5, 6}} 84 | local itr = tnt.ParallelDatasetIterator{ 85 | closure = function() return d end, 86 | init = function() require 'torchnet' end, 87 | nthread = 3, 88 | } 89 | local count = 0 90 | local present = {} 91 | for sample in itr:run() do 92 | tester:eq(present[sample], nil, 'duplicate sample: ' .. 
tostring(sample)) 93 | present[sample] = true 94 | count = count + 1 95 | end 96 | tester:eq(count, d:size()) 97 | for i = 1, d:size() do 98 | tester:eq(present[i], true, 'missing sample: ' .. tostring(i)) 99 | end 100 | end 101 | 102 | function test.ParallelDatasetIterator_ordered() 103 | -- Create a dataset in which the second item is likely to be returned out 104 | -- of order 105 | local tds = require 'tds' 106 | local c = tds.AtomicCounter(0) 107 | local d = tnt.TableDataset{data = {1, 2, 3, 4, 5, 6}}:transform(function(s) 108 | if s == 2 then 109 | repeat until c:get() ~= 0 110 | elseif s > 2 then 111 | c:inc() 112 | end 113 | return s 114 | end) 115 | 116 | local itr = tnt.ParallelDatasetIterator{ 117 | closure = function() return d end, 118 | init = function() require 'torchnet'; require 'tds' end, 119 | nthread = 3, 120 | ordered = true, 121 | } 122 | 123 | local count = 0 124 | for sample in itr:run() do 125 | count = count + 1 126 | tester:eq(sample, count, 'sample out of order') 127 | end 128 | tester:eq(count, d:size()) 129 | end 130 | 131 | return function(_tester_) 132 | tester = _tester_ 133 | return test 134 | end 135 | -------------------------------------------------------------------------------- /test/meters.lua: -------------------------------------------------------------------------------- 1 | local tnt = require 'torchnet.env' 2 | 3 | local tester 4 | local test = torch.TestSuite() 5 | 6 | function test.AverageValueMeter() 7 | local mtr = tnt.AverageValueMeter() 8 | 9 | mtr:add(1) 10 | local avg, var = mtr:value() 11 | 12 | tester:eq(avg, 1) 13 | tester:assert(var ~= var, "Variance for a single value is undefined") 14 | 15 | mtr:add(3) 16 | avg, var = mtr:value() 17 | 18 | tester:eq(avg, 2) 19 | tester:eq(var, math.sqrt(2)) 20 | end 21 | 22 | function test.ClassErrorMeter() 23 | local mtr = tnt.ClassErrorMeter{topk = {1}} 24 | 25 | local output = torch.Tensor({{1,0,0},{0,1,0},{0,0,1}}) 26 | local target = torch.Tensor({1,2,3}) 27 | 
mtr:add(output, target) 28 | local error = mtr:value() 29 | 30 | tester:eq(error, {0}, "All should be correct") 31 | 32 | target[1] = 2 33 | target[2] = 1 34 | target[3] = 1 35 | mtr:add(output, target) 36 | 37 | error = mtr:value() 38 | tester:eq(error, {50}, "Half, i.e. 50%, should be correct") 39 | end 40 | 41 | function test.ConfusionMeter() 42 | local mtr = tnt.ConfusionMeter{k = 3} 43 | 44 | -- The max value is the one that is correct 45 | local output = torch.Tensor({{.8,0.1,0.1},{10,11,10},{0.2,0.2,.3}}) 46 | local target = torch.Tensor({1,2,3}) 47 | mtr:add(output, target) 48 | local conf_mtrx = mtr:value() 49 | 50 | tester:eq(conf_mtrx:sum(), 3, "All should be correct") 51 | tester:eq(torch.diag(conf_mtrx):sum(), 3, "All should be correct") 52 | 53 | target[1] = 2 54 | target[2] = 1 55 | target[3] = 1 56 | mtr:add(output, target) 57 | 58 | tester:eq(conf_mtrx:sum(), 6, "Six tests should give six values") 59 | tester:eq(torch.diag(conf_mtrx):sum(), 3, "Shouldn't have changed since all new values were false") 60 | tester:eq(conf_mtrx[1]:sum(), 3, "All top have gotten one guess") 61 | tester:eq(conf_mtrx[2]:sum(), 2, "Two first at the 2nd row have a guess") 62 | tester:eq(conf_mtrx[2][3], 0, "The last one should be empty") 63 | tester:eq(conf_mtrx[3]:sum(), 1, "Bottom row has only the first test correct") 64 | tester:eq(conf_mtrx[3][3], 1, "Bottom row has only the first test correct") 65 | 66 | -- Test normalized version 67 | mtr = tnt.ConfusionMeter{k = 4, normalized=true} 68 | output = torch.Tensor({ 69 | {.8,0.1,0.1,0}, 70 | {10,11,10,0}, 71 | {0.2,0.2,.3,0}, 72 | {0,0,0,1} 73 | }) 74 | 75 | target = torch.Tensor({1,2,3,4}) 76 | mtr:add(output, target) 77 | conf_mtrx = mtr:value() 78 | 79 | tester:eq(conf_mtrx:sum(), output:size(2), "All should be correct") 80 | tester:eq(torch.diag(conf_mtrx):sum(), output:size(2), "All should be correct") 81 | 82 | target[1] = 2 83 | target[2] = 1 84 | target[3] = 1 85 | mtr:add(output, target) 86 | conf_mtrx = 
mtr:value() 87 | 88 | tester:eq(conf_mtrx:sum(), output:size(2), "The noramlization should sum all values to 1") 89 | for i=1,output:size(2) do 90 | tester:eq(conf_mtrx[i]:sum(), 1, "Row no " .. i .. " fails to sum to one in normalized mode") 91 | end 92 | end 93 | 94 | function test.MultilabelConfusionMeter() 95 | local mtr = tnt.MultiLabelConfusionMeter{k = 3, normalized=false} 96 | 97 | -- The max value is the one that is correct 98 | local output = torch.Tensor({{.8,0.1,0.1},{10,11,10},{0.2,0.2,.3}}) 99 | local target = torch.LongTensor({1,2,3}) 100 | local one_hot = torch.zeros(output:size()) 101 | one_hot:scatter(2, target:view(-1,1), 1) 102 | mtr:add(output, one_hot) 103 | local conf_mtrx = mtr:value() 104 | 105 | tester:eq(conf_mtrx, torch.eye(3), "All should be correct") 106 | 107 | target[1] = 2 108 | target[2] = 1 109 | target[3] = 1 110 | one_hot = torch.zeros(output:size()) 111 | one_hot:scatter(2, target:view(-1,1), 1) 112 | mtr:add(output, one_hot) 113 | conf_mtrx = mtr:value() 114 | 115 | tester:eq(conf_mtrx:sum(), 6, "Six tests should give six values") 116 | tester:eq(torch.diag(conf_mtrx):sum(), 3, "Shouldn't have changed since all new values were false") 117 | tester:eq(conf_mtrx[1]:sum(), 3, "All top have gotten one guess") 118 | tester:eq(conf_mtrx[2]:sum(), 2, "Two first at the 2nd row have a guess") 119 | tester:eq(conf_mtrx[2][3], 0, "The last one should be empty") 120 | tester:eq(conf_mtrx[3]:sum(), 1, "Bottom row has only the first test correct") 121 | tester:eq(conf_mtrx[3][3], 1, "Bottom row has only the first test correct") 122 | 123 | -- Test normalized version 124 | mtr = tnt.MultiLabelConfusionMeter{k = 4, normalized=true} 125 | output = torch.Tensor({ 126 | {.8,0.1,0.1,0}, 127 | {10,11,10,0}, 128 | {0.2,0.2,.3,0}, 129 | {0,0,0,1} 130 | }) 131 | 132 | target = torch.LongTensor({1,2,3,4}) 133 | one_hot = torch.zeros(output:size()) 134 | one_hot:scatter(2, target:view(-1,1), 1) 135 | mtr:add(output, one_hot) 136 | conf_mtrx = 
mtr:value() 137 | 138 | tester:eq(conf_mtrx:sum(), output:size(2), "All should be correct", 10^-3) 139 | tester:eq(torch.diag(conf_mtrx):sum(), output:size(2), "All should be correct", 10^-3) 140 | 141 | target[1] = 2 142 | target[2] = 1 143 | target[3] = 1 144 | one_hot = torch.zeros(output:size()) 145 | one_hot:scatter(2, target:view(-1,1), 1) 146 | mtr:add(output, one_hot) 147 | conf_mtrx = mtr:value() 148 | 149 | tester:eq(conf_mtrx:sum(), output:size(2), "The noramlization should sum all values to 1", 10^-3) 150 | for i=1,output:size(2) do 151 | tester:eq(conf_mtrx[i]:sum(), 1, "Row no " .. i .. " fails to sum to one in normalized mode", 10^-3) 152 | end 153 | end 154 | 155 | function test.AUCMeter() 156 | local mtr = tnt.AUCMeter() 157 | 158 | torch.manualSeed(41412) 159 | local test_size = 10^3 160 | mtr:add(torch.rand(test_size), torch.zeros(test_size)) 161 | mtr:add(torch.rand(test_size), torch.Tensor(test_size):fill(1)) 162 | local err = mtr:value() 163 | tester:eq(err, 0.5, "Random guesses should provide a AUC close to 0.5", 10^-1) 164 | 165 | mtr:reset() 166 | mtr:add(torch.Tensor(test_size):fill(0), torch.zeros(test_size)) 167 | mtr:add(torch.Tensor(test_size):fill(.1), torch.zeros(test_size)) 168 | mtr:add(torch.Tensor(test_size):fill(.2), torch.zeros(test_size)) 169 | mtr:add(torch.Tensor(test_size):fill(.3), torch.zeros(test_size)) 170 | mtr:add(torch.Tensor(test_size):fill(.4), torch.zeros(test_size)) 171 | mtr:add(torch.Tensor(test_size):fill(1), torch.Tensor(test_size):fill(1)) 172 | err = mtr:value() 173 | tester:eq(err, 1, "Only correct guesses should provide a AUC close to 1", 10^-1) 174 | end 175 | 176 | 177 | function test.APMeter() 178 | local mtr = tnt.APMeter() 179 | 180 | local target = torch.Tensor{0, 1, 0, 1} 181 | local output = torch.Tensor{.1, 0.2, 0.3, 4} 182 | local weight = torch.Tensor{0.5, 1.0, 2.0, 0.1} 183 | mtr:add(output, target, weight) 184 | 185 | local ap = mtr:value() 186 | tester:eq( 187 | ap[1], (1*0.1/0.1 + 0*2.0/2.1 
+ 1.1*1/3.1+ 0*1/4)/2.0, 188 | 'aptest1 failed' 189 | ) 190 | 191 | mtr:reset() 192 | mtr:add(output, target) 193 | ap = mtr:value() 194 | tester:eq(ap[1], (1*1/1 + 0*1/2 + 2*1/3 + 0*1/4)/2, 'aptest2 failed') 195 | 196 | target = torch.Tensor{0,1,0,1} 197 | output = torch.Tensor{4,3,2,1} 198 | weight = torch.Tensor{1,2,3,4} 199 | 200 | mtr:reset() 201 | mtr:add(output, target, weight) 202 | ap = mtr:value() 203 | tester:eq(ap[1], (0*1/1 + 1*2/3 + 2*0/6 + 6*1/10)/2, 'aptest3 failed') 204 | 205 | mtr:reset() 206 | mtr:add(output, target) 207 | ap = mtr:value() 208 | tester:eq(ap[1], (0*1 + 1*1/2 + 0*1/3 + 2*1/4)/2, 'aptest4 failed') 209 | 210 | target = torch.Tensor{0,1,0,1} 211 | output = torch.Tensor{1,4,2,3} 212 | weight = torch.Tensor{1,2,3,4} 213 | mtr:reset() 214 | mtr:add(output, target, weight) 215 | ap = mtr:value() 216 | tester:eq(ap[1], (4*1/4 + 6*1/6 + 0*6/9 + 0*6/10)/2, 'aptest5 failed') 217 | 218 | mtr:reset() 219 | mtr:add(output, target) 220 | ap = mtr:value() 221 | tester:eq(ap[1], (1*1 + 2*1/2 + 0*1/3 + 0*1/4)/2, 'aptest6 failed') 222 | 223 | target = torch.Tensor{0,0,0,0} 224 | output = torch.Tensor{1,4,2,3} 225 | weight = torch.Tensor{1.0, 0.1, 0.0, 0.5} 226 | mtr:reset() 227 | mtr:add(output, target, weight) 228 | 229 | ap = mtr:value() 230 | tester:eq(ap[1], 0) 231 | 232 | mtr:reset() 233 | mtr:add(output, target) 234 | ap = mtr:value() 235 | tester:eq(ap[1], 0) 236 | 237 | target = torch.Tensor{1,1,0} 238 | output = torch.Tensor{3,1,2} 239 | weight = torch.Tensor{1,0.1,3} 240 | mtr:reset() 241 | mtr:add(output, target, weight) 242 | ap = mtr:value() 243 | tester:eq(ap[1], (1*1/1 + 1*0/4 + 1.1/4.1)/2, 'aptest7 failed') 244 | 245 | mtr:reset() 246 | mtr:add(output, target) 247 | ap = mtr:value() 248 | tester:eq(ap[1], (1*1 + 0*1/2 + 2*1/3)/2, 'aptest8 failed') 249 | 250 | -- Test multiple K:s 251 | target = torch.Tensor{ 252 | {0,1,0,1}, 253 | {0,1,0,1} 254 | }:transpose(1,2) 255 | output = torch.Tensor{ 256 | {.1,.2,.3,4}, 257 | {4,3,2,1} 258 | 
}:transpose(1,2) 259 | weight = torch.Tensor{ 260 | {1.0, 0.5, 2.0, 3.0}, 261 | }:transpose(1,2) 262 | 263 | mtr:reset() 264 | mtr:add(output, target, weight) 265 | ap = mtr:value() 266 | tester:eq( 267 | ap, 268 | torch.DoubleTensor{ 269 | (1*3.0/3.0 + 0*3.0/5.0 + 3.5*1/5.5 + 0*3.5/6.5)/2, 270 | (0*1.0/1.0 + 1*0.5/1.5 + 0*0.5/3.5 + 1*3.5/6.5)/2 271 | }, 272 | 'aptest9 failed' 273 | ) 274 | 275 | mtr:reset() 276 | mtr:add(output, target) 277 | ap = mtr:value() 278 | tester:eq( 279 | ap, 280 | torch.DoubleTensor{ 281 | (1*1 + 0*1/2 + 2*1/3 + 0*1/4)/2, 282 | (0*1 + 1*1/2 + 0*1/3 + 2*1/4)/2 283 | }, 284 | 'aptest10 failed' 285 | ) 286 | 287 | mtr:reset() 288 | output = torch.DoubleTensor(5, 4):fill(.25) 289 | target = torch.LongTensor(5, 4):fill(1) 290 | mtr:add(output, target) 291 | output = torch.DoubleTensor(1, 4):fill(.25) 292 | target = torch.LongTensor(1, 4):fill(1) 293 | mtr:add(output, target) 294 | tester:assert(mtr:value(), 'aptest11 failed') 295 | end 296 | 297 | 298 | function test.mAPMeter() 299 | local mtr = tnt.mAPMeter() 300 | 301 | local target = torch.Tensor{0,1,0,1} 302 | local output = torch.Tensor{.1,.2,.3,4} 303 | local weight = torch.Tensor{0.5, 1.0, 2.0, 0.1} 304 | mtr:add(output, target) 305 | 306 | local ap = mtr:value() 307 | tester:eq(ap, (1*1 + 0*1/2 + 2*1/3 + 0*1/4)/2) 308 | 309 | mtr:reset() 310 | mtr:add(output, target, weight) 311 | ap = mtr:value() 312 | tester:eq(ap, (1*0.1/0.1 + 0*2.0/2.1 + 1.1*1/3.1+ 0*1/4)/2.0) 313 | 314 | -- Test multiple K:s 315 | mtr:reset() 316 | target = torch.Tensor{ 317 | {0,1,0,1}, 318 | {0,1,0,1} 319 | }:transpose(1,2) 320 | output = torch.Tensor{ 321 | {.1,.2,.3,4}, 322 | {4,3,2,1} 323 | }:transpose(1,2) 324 | weight = torch.Tensor{ 325 | {1.0, 0.5, 2.0, 3.0}, 326 | }:transpose(1,2) 327 | 328 | mtr:add(output, target) 329 | 330 | ap = mtr:value() 331 | tester:eq( 332 | ap, 333 | torch.DoubleTensor{ 334 | (1*1 + 0*1/2 + 2*1/3 + 0*1/4)/2, 335 | (0*1 + 1*1/2 + 0*1/3 + 2*1/4)/2 336 | }:mean() 337 | ) 338 | 
339 | mtr:reset() 340 | mtr:add(output, target, weight) 341 | ap = mtr:value() 342 | tester:eq( 343 | ap, 344 | torch.DoubleTensor{ 345 | (1*3.0/3.0 + 0*3.0/5.0 + 3.5*1/5.5 + 0*3.5/6.5)/2, 346 | (0*1.0/1.0 + 1*0.5/1.5 + 0*0.5/3.5 + 1*3.5/6.5)/2 347 | }:mean() 348 | ) 349 | end 350 | 351 | 352 | function test.MovingAverageValueMeter() 353 | -- Moving average meter with windowsize = 3 354 | local mtr = tnt.MovingAverageValueMeter(3) 355 | 356 | mtr:add(1) 357 | local avg, var = mtr:value() 358 | 359 | tester:eq(avg, 1) 360 | tester:eq(var, 0) 361 | 362 | mtr:add(3) 363 | avg, var = mtr:value() 364 | 365 | tester:eq(avg, 2) 366 | tester:eq(var, math.sqrt(2)) 367 | 368 | mtr:add(5) 369 | avg, var = mtr:value() 370 | 371 | tester:eq(avg, 3) 372 | tester:eq(var, 2) 373 | 374 | mtr:add(4) 375 | avg, var = mtr:value() 376 | 377 | tester:eq(avg, 4) 378 | tester:eq(var, 1) 379 | 380 | mtr:add(0) 381 | avg, var = mtr:value() 382 | 383 | tester:eq(avg, 3) 384 | tester:eq(var, math.sqrt(7)) 385 | end 386 | 387 | 388 | function test.NDCGMeter() 389 | local mtr = tnt.NDCGMeter{K = {6}} 390 | 391 | -- From: https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG 392 | local relevance = torch.DoubleTensor{3,2,3,0,1,2} 393 | local output = torch.linspace(relevance:size(1), 1, relevance:size(1)):double() 394 | mtr:add(output, relevance) 395 | 396 | local est = mtr:value() 397 | tester:eq(est[6], 0.932, "Problematic nDGC with K=6", 10^-3) 398 | end 399 | 400 | function test.PrecisionMeter() 401 | local mtr = tnt.PrecisionMeter{} 402 | 403 | local target = torch.zeros(10) 404 | target:narrow(1, 3, 7):fill(1) 405 | local output = torch.zeros(10) 406 | output:narrow(1, 2, 6):fill(1) 407 | 408 | mtr:add(output, target) 409 | 410 | local tp = 5 411 | local fp = 1 412 | tester:eq(mtr:value()[0.5], 100 * tp / (tp + fp), "Basic test", 10^-2) 413 | 414 | mtr = tnt.PrecisionMeter{ 415 | threshold = {.5, .7} 416 | } 417 | 418 | target = torch.zeros(10) 419 | target:narrow(1, 3, 
7):fill(1) 420 | output = torch.zeros(10) 421 | output:narrow(1, 2, 3):fill(.5) 422 | output:narrow(1, 5, 3):fill(.7) 423 | 424 | mtr:add(output, target) 425 | tp = 5 426 | fp = 1 427 | tester:eq(mtr:value()[0.5], 100 * tp / (tp + fp), "Cutoff at 0.5", 10^-2) 428 | tester:eq(mtr:value()[0.7], 100, "Cutoff at 0.7", 10^-2) 429 | end 430 | 431 | function test.PrecisionAtKMeter() 432 | local mtr = tnt.PrecisionAtKMeter{topk = {1, 2, 3}} 433 | 434 | local target = torch.eye(3) 435 | local output = torch.Tensor{ 436 | {.5,.1,.4}, 437 | {.1,.5,.4}, 438 | {.1,.4,.5} 439 | } 440 | mtr:add(output, target) 441 | tester:eq(mtr:value()[1], 100*3/3, "Top 1 matches", 10^-3) 442 | tester:eq(mtr:value()[2], 100*3/(3*2), "Top 2 matches", 10^-3) 443 | tester:eq(mtr:value()[3], 100*3/(3*3), "Top 3 matches", 10^-3) 444 | 445 | mtr:add(output, target) -- Adding the same twice shouldn't change anything 446 | tester:eq(mtr:value()[1], 100*3/3, "Top 1 matches", 10^-3) 447 | tester:eq(mtr:value()[2], 100*3/(3*2), "Top 2 matches", 10^-3) 448 | tester:eq(mtr:value()[3], 100*3/(3*3), "Top 3 matches", 10^-3) 449 | 450 | mtr:reset() 451 | target[1][3] = 1 452 | 453 | mtr:add(output, target) 454 | tester:eq(mtr:value()[1], 100*3/3, "Top 1 matches", 10^-3) 455 | tester:eq(mtr:value()[2], 100*(3 + 1)/(3 * 2), "Top 2 matches", 10^-3) 456 | tester:eq(mtr:value()[3], 100*(3 + 1)/(3 * 3), "Top 3 matches", 10^-3) 457 | 458 | mtr:reset() 459 | output = torch.Tensor{ 460 | {.1,.5,.4}, 461 | {.1,.5,.4}, 462 | {.5,.4,.1} 463 | } 464 | mtr:add(output, target) 465 | tester:eq(mtr:value()[1], 100*1/3, "Top 1 matches", 10^-3) 466 | tester:eq(mtr:value()[2], 100*(1 + 1)/(3 * 2), "Top 2 matches", 10^-3) 467 | tester:eq(mtr:value()[3], 100*(1 + 1 + 2)/(3 * 3), "Top 3 matches", 10^-3) 468 | 469 | local mtr = tnt.PrecisionAtKMeter{topk = {1, 2, 3}, online = true} 470 | 471 | local target = torch.eye(3) 472 | local output = torch.Tensor{ 473 | {.5,.1,.4}, 474 | {.1,.5,.4}, 475 | {.1,.4,.5} 476 | } 477 | 
mtr:add(output, target) 478 | tester:eq(mtr:value()[1], 100*3/3, "Top 1 matches", 10^-3) 479 | tester:eq(mtr:value()[2], 100*3/(3*2), "Top 2 matches", 10^-3) 480 | tester:eq(mtr:value()[3], 100*3/(3*3), "Top 3 matches", 10^-3) 481 | 482 | mtr:add(output, target) 483 | tester:eq(mtr:value()[1], 100*3/3, "Top 1 matches", 10^-3) 484 | tester:eq(mtr:value()[2], 100*3*2/(3*2), "Top 2 matches", 10^-3) 485 | tester:eq(mtr:value()[3], 100*3*2/(3*3), "Top 3 matches", 10^-3) 486 | end 487 | 488 | function test.RecallMeter() 489 | local mtr = tnt.RecallMeter{} 490 | 491 | local target = torch.zeros(10) 492 | target:narrow(1, 3, 7):fill(1) 493 | local output = torch.zeros(10) 494 | output:narrow(1, 2, 6):fill(1) 495 | 496 | mtr:add(output, target) 497 | 498 | local tp = 5 499 | local fn = 2 500 | tester:eq(mtr:value()[0.5], 100 * tp / (tp + fn), "Basic test", 10^-2) 501 | 502 | mtr = tnt.RecallMeter{ 503 | threshold = {.5, .7} 504 | } 505 | 506 | target = torch.zeros(10) 507 | target:narrow(1, 3, 7):fill(1) 508 | output = torch.zeros(10) 509 | output:narrow(1, 2, 3):fill(.5) 510 | output:narrow(1, 5, 3):fill(.7) 511 | 512 | mtr:add(output, target) 513 | tester:eq(mtr:value()[0.5], 100 * tp / (tp + fn), "Cutoff at 0.5", 10^-2) 514 | tp = tp - 2 515 | fn = fn + 2 516 | tester:eq(mtr:value()[0.7], 100 * tp / (tp + fn), "Cutoff at 0.7", 10^-2) 517 | end 518 | 519 | function test.TimeMeter() 520 | local mtr = tnt.TimeMeter() 521 | 522 | local function wait(seconds) 523 | local start = os.time() 524 | repeat until os.time() > start + seconds 525 | end 526 | 527 | mtr:reset() 528 | wait(1) 529 | local passed_time = mtr:value() 530 | tester:assert(passed_time < 2, 531 | ("Too long time passed: %.1f sec >= 2 sec"):format(passed_time)) 532 | tester:assert(passed_time > .5, 533 | ("Too short time passed: %.1f sec <= 0.5 sec"):format(passed_time)) 534 | end 535 | 536 | return function(_tester_) 537 | tester = _tester_ 538 | return test 539 | end 540 | 
-------------------------------------------------------------------------------- /test/test.lua: -------------------------------------------------------------------------------- 1 | local __main__ = package.loaded['torchnet.env'] == nil 2 | 3 | local tnt = require 'torchnet.env' 4 | local tds = require 'tds' 5 | 6 | if __main__ then 7 | require 'torchnet' 8 | end 9 | 10 | local tester = torch.Tester() 11 | tester:add(paths.dofile('datasets.lua')(tester)) 12 | tester:add(paths.dofile('iterators.lua')(tester)) 13 | tester:add(paths.dofile('meters.lua')(tester)) 14 | 15 | function tnt.test(tests) 16 | tester:run(tests) 17 | return tester 18 | end 19 | 20 | if __main__ then 21 | require 'torchnet' 22 | if #arg > 0 then 23 | tnt.test(arg) 24 | else 25 | tnt.test() 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /transform.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local tnt = require 'torchnet.env' 11 | local argcheck = require 'argcheck' 12 | local utils = require 'torchnet.utils' 13 | local doc = require 'argcheck.doc' 14 | 15 | doc[[ 16 | 17 | ### tnt.transform 18 | 19 | *Torchnet* provides a set of general data transformations. 20 | These transformations are either directly on the data (e.g., normalization) 21 | or on their structure. This is particularly handy 22 | when manipulating [tnt.Dataset](#tnt.Dataset). 23 | 24 | Most of the transformations are simple but can be [composed](#transform.compose) or 25 | [merged](#transform.merge). 
]]

local transform = {}
tnt.transform = transform

-- `unpack` is a global in Lua 5.1 and lives in `table` from 5.2 onwards.
local unpack = unpack or table.unpack


doc[[

#### transform.identity(...)

The identity transform takes any input and returns it as it is.

For example, this function is useful when composing
transformations on data from multiple sources, and some of the sources
must not be transformed.
]]

transform.identity =
   function(...)
      -- Record the true argument count so that trailing nils survive the
      -- round-trip through the capture table (plain unpack(args) would
      -- stop at the first hole).
      local n = select('#', ...)
      local args = {...}
      return function()
         return unpack(args, 1, n)
      end
   end

transform.compose = argcheck{
   doc = [[

#### transform.compose(@ARGP)
@ARGT

This function takes a `table` of functions and
composes them to return one transformation.

This function assumes that the table of transformations
is indexed by contiguous ordered keys starting at 1.
The transformations are composed in the ascending order.

For example, the following code:
```lua
> f = transform.compose{
      [1] = function(x) return 2*x end,
      [2] = function(x) return x + 10 end,
      foo = function(x) return x / 2 end,
      [4] = function(x) return x - x end
   }
> f(3)
16
```
is equivalent to composing the transformations stored in [1] and [2], i.e.,
defining the following transformation:
```lua
> f = function(x) return 2*x + 10 end
```
Note that transformations stored with keys `foo` and `4` are ignored.
]],
   {name='transforms', type='table'},
   call =
      function(transforms)
         -- only the contiguous integer-keyed part is validated and used;
         -- ipairs stops at the first hole, so `foo` and `4` are skipped
         for k,v in ipairs(transforms) do
            assert(type(v) == 'function', 'table of functions expected')
         end
         -- shallow copy so later mutation of the caller's table has no effect
         transforms = utils.table.copy(transforms)
         return
            function(z)
               for _, trans in ipairs(transforms) do
                  z = trans(z)
               end
               return z
            end
      end
}

transform.merge = argcheck{
   doc = [[

#### transform.merge(@ARGP)
@ARGT

This function takes a `table` of transformations and
merges them into one transformation.
Once applied to an input, this transformation will produce a `table` of
outputs, containing the transformed input.

For example, the following code:
```lua
> f = transform.merge{
      [1] = function(x) return 2*x end,
      [2] = function(x) return x + 10 end,
      foo = function(x) return x / 2 end,
      [4] = function(x) return x - x end
   }
```
produces a function which applies a set of transformations to the same input:
```lua
> f(3)
{
  1 : 6
  2 : 13
  foo : 1.5
  4 : 0
}
```
]],
   {name='transforms', type='table'},
   call =
      function(transforms)
         -- unlike compose, every key (hash keys included) is used here
         for k,v in pairs(transforms) do
            assert(type(v) == 'function', 'table of functions expected')
         end
         transforms = utils.table.copy(transforms)
         return
            function(z)
               local newz = {}
               for k, trans in pairs(transforms) do
                  newz[k] = trans(z)
               end
               -- NOTE(review): the doc example above shows a plain table
               -- being returned, but mergetensor assumes the values are
               -- same-sized tensors -- confirm which contract is intended.
               return utils.table.mergetensor(newz)
            end
      end
}

transform.tablenew = argcheck{
   doc = [[

#### transform.tablenew()

This function creates a new table of functions from an
existing table of functions.
]],
   call =
      function()
         return
            function(func)
               -- shallow copy of the function table
               local tbl = {}
               for k,v in pairs(func) do
                  tbl[k] = v
               end
               return tbl
            end
      end
}

transform.tableapply = argcheck{
   doc = [[

#### transform.tableapply(@ARGP)
@ARGT

This function applies a transformation to a table of inputs.
It returns a table of outputs of the same size as the input.

For example, the following code:
```lua
> f = transform.tableapply(function(x) return 2*x end)
```
produces a function which multiplies any input by 2:
```lua
> f({[1] = 1, [2] = 2, foo = 3, [4] = 4})
{
  1 : 2
  2 : 4
  foo : 6
  4 : 8
}
```
]],
   {name='transform', type='function'},
   call =
      function(transform)
         return
            function(tbl)
               return utils.table.foreach(tbl, transform)
            end
      end
}

transform.tablemergekeys = argcheck{
   doc = [[

#### transform.tablemergekeys()

This function merges tables by key. More precisely, the input must be a
`table` of `table` and this function will reverse the table order to
make the keys from the nested table accessible first.

For example, if the input is:
```lua
> x = { sample1 = {input = 1, target = "a"} , sample2 = {input = 2, target = "b", flag = "hard"} }
```
Then applying this function will produce:
```lua
> transform.tablemergekeys(x)
{
  input :
        {
          sample1 : 1
          sample2 : 2
        }
  target :
        {
          sample1 : "a"
          sample2 : "b"
        }
  flag :
        {
          sample2: "hard"
        }
}
```
]],
   call =
      function()
         return
            function(tbl)
               -- pivot tbl[idx][key] into mergedtbl[key][idx]; keys missing
               -- from some inner tables simply have no entry for that idx
               local mergedtbl = {}
               for idx, elem in pairs(tbl) do
                  for key, value in pairs(elem) do
                     if not mergedtbl[key] then
                        mergedtbl[key] = {}
                     end
                     mergedtbl[key][idx] = value
                  end
               end
               return mergedtbl
            end
      end
}

transform.makebatch = argcheck{
   doc = [[

#### transform.makebatch(@ARGP)
@ARGT

This function is used in many `tnt.Dataset` to format
samples in the format used by the `tnt.Engine`.

This function first [merges keys](#transform.tablemergekeys) to
produce a table of outputs. Then, it transforms this table into a tensor by
either using a `merge` transformation provided by the user or by
simply concatenating the table into a tensor directly.

This function uses the [compose](#transform.compose) transform to apply
successive transformations.
]],
   {name='merge', type='function', opt=true},
   call =
      function(merge)

         local makebatch
         if merge then
            -- user-provided merge takes full control of tensor creation
            makebatch = transform.compose{
               transform.tablemergekeys(),
               merge
            }
         else
            -- default: merge each field into a tensor when possible,
            -- leave it as a table otherwise
            makebatch = transform.compose{
               transform.tablemergekeys(),
               transform.tableapply(
                  function(field)
                     if utils.table.canmergetensor(field) then
                        return utils.table.mergetensor(field)
                     else
                        return field
                     end
                  end
               )
            }
         end

         return
            function(samples)
               assert(type(samples) == 'table', 'makebatch: table of samples expected')
               return makebatch(samples)
            end
      end
}

transform.randperm = argcheck{
   doc = [[

#### transform.randperm(@ARGP)
@ARGT

This function creates a vector containing a permutation of the indices from 1 to `size`.
This vector is a `LongTensor` and `size` must be a number.

Once the vector is created, this function can be used to query specific indices in it.

For example:
```lua
> p = transform.randperm(3)
```
creates a function `p` which contains a permutation of indices:
```lua
> p(1)
2
> p(2)
1
> p(3)
3
```
]],
   {name="size", type="number"},
   call =
      function(size)
         -- the permutation is drawn once at construction time and then
         -- shared by all calls to the returned closure
         local perm = torch.randperm(size, 'torch.LongTensor')
         return
            function(idx)
               return perm[idx]
            end
      end
}

transform.normalize = argcheck{
   doc = [[

#### transform.normalize(@ARGP)
@ARGT

This function normalizes data, i.e., it removes its mean and
divides it by its standard deviation.

The input must be a `Tensor`.

Once created, a `threshold` can be given (must be a number). Then,
the data will be divided by their standard deviation, only if this
deviation is greater than the `threshold`. This is handy, if the
deviation is small and dividing by it could lead to instability.
]],
   {name='threshold', type='number', default=0},
   call =
      function(threshold)
         return
            function(z)
               -- compute the deviation before centering (mean removal
               -- does not change it); guard against dividing by a
               -- tiny/zero deviation
               local std = z:std()
               z:add(-z:mean())
               if std > threshold then
                  z:div(std)
               end
               return z
            end
      end
}

return transform
--------------------------------------------------------------------------------
/utils/init.lua:
--------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local doc = require 'argcheck.doc'

doc[[

### tnt.utils

*Torchnet* provides a set of util functions which are used all over torchnet.
]]

local utils = {}
tnt.utils = utils

utils.table = require 'torchnet.utils.table'
utils.nn = require 'torchnet.utils.nn'
utils.sys = require 'torchnet.utils.sys'

return utils
--------------------------------------------------------------------------------
/utils/nn.lua:
--------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local argcheck = require 'argcheck'

local unn = {}

-- Inverse size computation for one convolution-like layer: given the
-- output width/stride, recover the input width/stride.
local function isz(okw, odw, kw, dw, pw)
   dw = dw or 1
   pw = pw or 0
   okw = okw*dw + kw-dw-pw*2
   odw = odw*dw
   return okw, odw
end

-- in specs, list of {kw=,dw=,pw=} (dw and pw optionals)
-- kw (kernel width)
-- dw (kernel stride)
-- pw (padding)
unn.inferinputsize = argcheck{
   {name="specs", type="table"},
   {name="size", type="number", default=1},
   {name="verbose", type="boolean", default=false},
   call =
      function(specs, size, verbose)
         local okw, odw = size, 1
         -- walk the layers backwards, from the output towards the input;
         -- layers without a kw (e.g. non-spatial ones) are skipped
         for i=#specs,1,-1 do
            if specs[i].kw then
               okw, odw = isz(okw, odw, specs[i].kw, specs[i].dw, specs[i].pw)
            end
            if verbose then
               print(string.format(
                  "|| layer %d: size=%d stride=%d",
                  i,
                  okw,
                  odw))
            end
         end
         return okw, odw
      end
}

-- Forward size computation for one layer: the standard convolution output
-- width formula floor((w + 2*pad - kernel)/stride) + 1.
local function iszr(okw, odw, kw, dw, pw)
   dw = dw or 1
   pw = pw or 0
   okw = math.floor((okw+2*pw-kw)/dw)+1
   odw = odw * dw
   return okw, odw
end

unn.inferoutputsize = argcheck{
   {name="specs", type="table"},
   {name="size", type="number"},
   {name="verbose", type="boolean", default=false},
   call =
      function(specs, size, verbose)
         local okw, odw = size, 1
         for i=1,#specs do
            okw, odw = iszr(okw, odw, specs[i].kw, specs[i].dw, specs[i].pw)
            if verbose then
               -- bug fix: the original format string consumed three values
               -- ("layer %d: size=%dx stride=%d") but only two arguments
               -- (okw, odw) were passed -- the missing layer index made
               -- string.format raise an error whenever verbose was true
               print(string.format(
                  "|| layer %d: size=%d stride=%d", i, okw, odw))
            end
         end
         return okw, odw
      end
}

return unn
--------------------------------------------------------------------------------
/utils/sys.lua:
--------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local md5 = require 'md5'
local lfs = require 'lfs'
local tds = require 'tds'

local sys = {}

-- Hash an arbitrary torch-serializable object to an md5 hex digest.
function sys.md5(obj)
   local str = torch.serialize(obj)
   return md5.sumhexa(str)
end

-- Create a directory (and any missing parents) through the shell.
function sys.mkdir(path)
   -- fix: single-quote the path so spaces and shell metacharacters cannot
   -- break or inject into the command line; an embedded single quote is
   -- rewritten as '\'' (close quote, escaped quote, reopen quote).
   -- NOTE(review): quoting disables tilde expansion -- confirm no caller
   -- relied on passing '~' paths.
   local quoted = "'" .. tostring(path):gsub("'", "'\\''") .. "'"
   assert(
      os.execute(string.format('mkdir -p %s', quoted)),
      'could not create directory'
   )
end

-- Compile and run one command-line snippet inside the sandbox table `env`.
-- While the snippet runs, `env` falls back to _G for reads.
local function cmdlinecode(code, env)
   assert(type(code) == 'string', 'string expected')
   local msg
   if loadstring then -- lua 5.1
      code, msg = loadstring(code)
      if code then
         setfenv(code, env)
      end
   else
      code, msg = load(code, nil, nil, env) -- lua 5.2
   end
   if not code then
      error(string.format('compilation error: %s', msg))
   end
   assert(not getmetatable(env), 'env should have no metatable')
   setmetatable(env, {__index=_G})
   local status, msg = pcall(code)
   setmetatable(env, nil)
   if not status then
      error(msg)
   end
end

-- Execute each element of `arg` as a Lua chunk within `env`.
function sys.cmdline(arg, env)
   for _, code in ipairs(arg) do
      cmdlinecode(code, env)
   end
end

-- Load a line-separated list file into a tds.hash (index -> line).
-- With `revert`, also stores the reverse mapping (line -> index).
-- `maxload` optionally caps the number of lines read.
function sys.loadlist(path, revert, maxload)
   local lst = tds.hash()
   local idx = 0
   for elem in io.lines(path) do
      idx = idx + 1
      lst[idx] = elem
      if revert then
         lst[elem] = idx
      end
      if maxload and maxload == idx then
         break
      end
   end
   return lst
end

-- Recursively collect jpeg/jpg/png files under `path` into `lst`.
-- Files with a non-image extension are reported and skipped; files with
-- no extension at all are silently ignored.
function sys.listimgfiles(path, lst)
   lst = lst or tds.hash()
   for filename in lfs.dir(path) do
      if filename ~= '.' and filename ~= '..' then
         local fullpath = string.format('%s/%s', path, filename)
         if lfs.attributes(fullpath, 'mode') == 'directory' then
            sys.listimgfiles(fullpath, lst)
         else
            local ext = filename:match('[^%.]+$')
            if ext then
               ext = ext:lower()
               if ext == 'jpeg' or ext == 'jpg' or ext == 'png' then
                  lst[#lst+1] = fullpath
               else
                  print(string.format('ignoring <%s>', fullpath))
               end
            end
         end
      end
   end
   return lst
end

return sys
--------------------------------------------------------------------------------
/utils/table.lua:
--------------------------------------------------------------------------------
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

local argcheck = require 'argcheck'
local doc = require 'argcheck.doc'

local utable = {}

doc[[

#### tnt.utils.table.clone(table)

This function does a deep copy of a table.

]]

function utable.clone(tbl)
   -- a serialize/deserialize round-trip yields a fully independent copy
   return torch.deserializeFromStorage(torch.serializeToStorage(tbl))
end

-- Shallow copy: top-level keys are copied, values are shared.
function utable.copy(tbl)
   local cpy = {}
   for k,v in pairs(tbl) do
      cpy[k] = v
   end
   return cpy
end

utable.merge = argcheck{
   doc = [[

#### tnt.utils.table.merge(@ARGP)
@ARGT

This function adds to the destination table `dst`, the
elements contained in the source table `src`.

The copy is shallow.

If a key exists in both tables, then the element in the source table
is preferred.
]],
   {name = "dst", type = 'table'},
   {name = "src", type = 'table'},
   call = function (dst, src)
      -- shallow merge: values from src overwrite those already in dst
      for k,v in pairs(src) do
         dst[k] = v
      end
      return dst
   end
}

utable.foreach = argcheck{
   doc = [[

#### tnt.utils.table.foreach(@ARGP)
@ARGT

This function applies the function defined by `closure` to the
table `tbl`.

If `recursive` is given and set to `true`, the `closure` function
will be applied recursively to the table.
]],
   {name = "tbl", type = 'table'},
   {name = "closure", type = 'function'},
   {name = "recursive", type = 'boolean', default = false},
   call = function(tbl, closure, recursive)
      -- builds a new table; the input table is never mutated
      local newtbl = {}
      for k,v in pairs(tbl) do
         if recursive and type(v) == 'table' then
            newtbl[k] = utable.foreach(v, closure, recursive)
         else
            newtbl[k] = closure(v)
         end
      end
      return newtbl
   end
}

doc[[

#### tnt.utils.table.canmergetensor(tbl)

Check if a table can be merged into a tensor.
]]

-- A table is mergeable when it is an array of tensors that all share the
-- same number of elements; anything else (non-table, empty table, mixed
-- content) is not mergeable.
function utable.canmergetensor(tbl)
   if type(tbl) ~= 'table' then
      return false
   end

   local typename = torch.typename(tbl[1])
   if typename and typename:match('Tensor') then
      local sz = tbl[1]:nElement()
      for i=2,#tbl do
         -- can't merge tensors of different sizes
         if tbl[i]:nElement() ~= sz then
            return false
         end
      end
      return true
   end
   return false
end

utable.mergetensor = argcheck{
   doc = [[

#### tnt.utils.table.mergetensor(@ARGP)
@ARGT

Merge a table into a tensor in one extra dimension.
]],
   {name = 'tbl', type = 'table'},
   call = function(tbl)
      -- stack all tensors along a new leading dimension of size #tbl
      local sz = tbl[1]:size():totable()
      table.insert(sz, 1, #tbl)
      sz = torch.LongStorage(sz)
      local res = tbl[1].new():resize(sz)
      for i=1,#tbl do
         res:select(1, i):copy(tbl[i])
      end
      return res
   end
}

return utable
--------------------------------------------------------------------------------