├── env.lua ├── .gitignore ├── image ├── init.lua └── transformimage.lua ├── test ├── transformimage.lua └── test.lua ├── CMakeLists.txt ├── datasets ├── init.lua ├── ethzfood101.lua ├── upmcfood101.lua ├── mit67.lua └── utils.lua ├── init.lua ├── models ├── init.lua ├── inceptionv4.lua ├── inceptionresnetv2.lua ├── resnet.lua ├── inceptionv3.lua ├── vggm.lua ├── vgg16.lua └── overfeat.lua ├── rocks └── torchnet-vision-scm-1.rockspec ├── CONTRIBUTING.md ├── LICENSE ├── .travis.yml ├── PATENTS ├── README.md └── example ├── upmcfood101extract.lua └── mit67finetuning.lua /env.lua: -------------------------------------------------------------------------------- 1 | local vision = {} 2 | 3 | return vision 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | ._.DS_Store 3 | build.luarocks 4 | example/data 5 | example/logs 6 | example/models 7 | example/features 8 | -------------------------------------------------------------------------------- /image/init.lua: -------------------------------------------------------------------------------- 1 | local vision = require 'torchnet-vision.env' 2 | 3 | local image = {} 4 | vision.image = image 5 | 6 | image.transformimage = require 'torchnet-vision.image.transformimage' 7 | 8 | return image 9 | -------------------------------------------------------------------------------- /test/transformimage.lua: -------------------------------------------------------------------------------- 1 | local vision = require 'torchnet-vision.env' 2 | local tds = require 'tds' 3 | 4 | local tester 5 | local test = torch.TestSuite() 6 | 7 | function test.TransformImage() 8 | end 9 | 10 | return function(_tester_) 11 | tester = _tester_ 12 | return test 13 | end -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 2.8) 2 | cmake_policy(VERSION 2.8) 3 | 4 | set(PKGNAME torchnet-vision) 5 | 6 | file(GLOB_RECURSE luafiles RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.lua") 7 | 8 | foreach(file ${luafiles}) 9 | get_filename_component(dir ${file} PATH) 10 | install(FILES ${file} DESTINATION ${LUA_PATH}/${PKGNAME}/${dir}) 11 | endforeach() 12 | -------------------------------------------------------------------------------- /datasets/init.lua: -------------------------------------------------------------------------------- 1 | local vision = require 'torchnet-vision.env' 2 | 3 | local datasets = {} 4 | vision.datasets = datasets 5 | 6 | datasets.utils = require 'torchnet-vision.datasets.utils' 7 | datasets.upmcfood101 = require 'torchnet-vision.datasets.upmcfood101' 8 | datasets.ethzfood101 = require 'torchnet-vision.datasets.ethzfood101' 9 | datasets.mit67 = require 'torchnet-vision.datasets.mit67' 10 | 11 | return datasets 12 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | 3 | local tnt = require 'torchnet' 4 | local vision = require 'torchnet-vision.env' 5 | local doc = require 'argcheck.doc' 6 | 7 | -- doc[[]] 8 | 9 | require 'torchnet-vision.image' 10 | 11 | require 'torchnet-vision.models' 12 | 13 | require 'torchnet-vision.datasets' 14 | 15 | require 'torchnet-vision.test.test' 16 | 17 | tnt.makepackageserializable(vision, 
'torchnet-vision') 18 | 19 | return vision 20 | -------------------------------------------------------------------------------- /test/test.lua: -------------------------------------------------------------------------------- 1 | local __main__ = package.loaded['torchnet-vision.env'] == nil 2 | 3 | local vision = require 'torchnet-vision.env' 4 | local tds = require 'tds' 5 | 6 | if __main__ then 7 | require 'torchnet-vision' 8 | end 9 | 10 | local tester = torch.Tester() 11 | tester:add(paths.dofile('transformimage.lua')(tester)) 12 | 13 | function vision.test(tests) 14 | tester:run(tests) 15 | return tester 16 | end 17 | 18 | if __main__ then 19 | require 'torchnet-vision' 20 | if #arg > 0 then 21 | vision.test(arg) 22 | else 23 | vision.test() 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /models/init.lua: -------------------------------------------------------------------------------- 1 | local vision = require 'torchnet-vision.env' 2 | 3 | local models = {} 4 | vision.models = models 5 | 6 | models.overfeat = require 'torchnet-vision.models.overfeat' 7 | models.vggm = require 'torchnet-vision.models.vggm' 8 | models.vgg16 = require 'torchnet-vision.models.vgg16' 9 | models.inceptionv3 = require 'torchnet-vision.models.inceptionv3' 10 | models.inceptionv4 = require 'torchnet-vision.models.inceptionv4' 11 | models.inceptionresnetv2 = require 'torchnet-vision.models.inceptionresnetv2' 12 | models.resnet = require 'torchnet-vision.models.resnet' 13 | 14 | return models 15 | -------------------------------------------------------------------------------- /rocks/torchnet-vision-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "torchnet-vision" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/Cadene/torchnet-vision.git" 6 | } 7 | 8 | description = { 9 | summary = "Vision plugin for Torchnet", 10 | detailed = [[ 11 | Various abstractions for vision processing. 12 | ]], 13 | homepage = "https://github.com/Cadene/torchnet-vision", 14 | license = "BSD" 15 | } 16 | 17 | dependencies = { 18 | "lua >= 5.1", 19 | "torch >= 7.0", 20 | "argcheck >= 1.0", 21 | "tds >= 1.0", 22 | "image >= 1.0", 23 | "torchnet >= 1.0" 24 | } 25 | 26 | build = { 27 | type = "cmake", 28 | variables = { 29 | CMAKE_BUILD_TYPE="Release", 30 | LUA_PATH="$(LUADIR)", 31 | LUA_CPATH="$(LIBDIR)" 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Torchnet-Vision 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Make sure your code lints. 12 | 13 | ## Issues 14 | We use GitHub issues to track public bugs. Please ensure your description is 15 | clear and has sufficient instructions to be able to reproduce the issue. 16 | 17 | ## Coding Style 18 | * 3 spaces for indentation rather than tabs 19 | * 80 character line length 20 | * variable names all lower-case, no underscores 21 | 22 | ## License 23 | By contributing to Torchnet-Vision, you agree that your contributions will be licensed 24 | under its BSD license. 
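For reference, a short snippet following the Coding Style above (an illustrative sketch, not code from the repository):

```lua
-- 3-space indentation, lines under 80 characters,
-- lower-case variable names without underscores
local function clampvalue(value, minvalue, maxvalue)
   if value < minvalue then return minvalue end
   if value > maxvalue then return maxvalue end
   return value
end
```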
25 | -------------------------------------------------------------------------------- /models/inceptionv4.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local inceptionv4 = {} 4 | 5 | inceptionv4.__download = argcheck{ 6 | {name='filename', type='string'}, 7 | call = 8 | function(filename) 9 | os.execute('mkdir -p '..paths.dirname(filename)..';' 10 | ..'wget http://webia.lip6.fr/~cadene/Downloads/inceptionv4.t7' 11 | ..' -O '..filename) 12 | end 13 | } 14 | 15 | inceptionv4.load = argcheck{ 16 | {name='filename', type='string', default='inceptionv4.t7'}, 17 | call = 18 | function(filename) 19 | if not path.exists(filename) then 20 | inceptionv4.__download(filename) 21 | end 22 | return torch.load(filename) 23 | end 24 | } 25 | 26 | inceptionv4.colorMode = 'RGB' 27 | inceptionv4.pixelRange = {0,1} -- [0,1] instead of [0,255] 28 | inceptionv4.inputSize = {3, 299, 299} 29 | inceptionv4.mean = torch.Tensor{0.5, 0.5, 0.5} 30 | inceptionv4.std = torch.Tensor{0.5, 0.5, 0.5} 31 | 32 | return inceptionv4 33 | -------------------------------------------------------------------------------- /models/inceptionresnetv2.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local inceptionresnetv2 = {} 4 | 5 | inceptionresnetv2.__download = argcheck{ 6 | {name='filename', type='string'}, 7 | call = 8 | function(filename) 9 | os.execute('mkdir -p '..paths.dirname(filename)..';' 10 | ..'wget http://webia.lip6.fr/~cadene/Downloads/inceptionresnetv2.t7' 11 | ..' -O '..filename) 12 | end 13 | } 14 | 15 | inceptionresnetv2.load = argcheck{ 16 | {name='filename', type='string', default='inceptionresnetv2.t7'}, 17 | call = 18 | function(filename) 19 | if not path.exists(filename) then 20 | inceptionresnetv2.__download(filename) 21 | end 22 | return torch.load(filename) 23 | end 24 | } 25 | 26 | inceptionresnetv2.colorMode = 'RGB' 27 | inceptionresnetv2.pixelRange = {0,1} -- [0,1] instead of [0,255] 28 | inceptionresnetv2.inputSize = {3, 299, 299} 29 | inceptionresnetv2.mean = torch.Tensor{0.5, 0.5, 0.5} 30 | inceptionresnetv2.std = torch.Tensor{0.5, 0.5, 0.5} 31 | 32 | return inceptionresnetv2 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For Torchnet software 4 | 5 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /models/resnet.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local resnet = {} 4 | 5 | resnet.__download = argcheck{ 6 | {name='filename', type='string'}, 7 | {name='length', type='number'}, 8 | call = 9 | function(filename, length) 10 | os.execute('mkdir -p '..paths.dirname(filename)..';' 11 | ..'wget https://d2j0dndfm35trm.cloudfront.net/resnet-'..length..'.t7' 12 | ..' -O '..filename) 13 | end 14 | } 15 | 16 | resnet.load = argcheck{ 17 | {name='filename', type='string'}, 18 | {name='length', type='number', 19 | help='18, 34, 50, 101, 152, 200', 20 | check=function(length) 21 | return length == 18 or length == 34 or length == 50 or 22 | length == 101 or length == 152 or length == 200 23 | end 24 | }, 25 | call = 26 | function(filename, length) 27 | if not path.exists(filename) then 28 | resnet.__download(filename, length) 29 | end 30 | return torch.load(filename) 31 | end 32 | } 33 | 34 | resnet.loadFinetuning = argcheck{ 35 | {name='filename', type='string'}, 36 | {name='length', type='number'}, 37 | {name='nclasses', type='number'}, 38 | {name='ftfactor', type='number', default=10}, 39 | call = 40 | function(filename, length, nclasses, ftfactor) 41 | local net = resnet.load(filename, length) 42 | net:remove() -- nn.Linear 43 | net:add(nn.GradientReversal(-1.0/ftfactor)) 44 | net:add(nn.Linear(2048, nclasses)) -- 2048 features for resnet-50 and deeper 45 | return net 46 | end 47 | } 48 | 49 | resnet.colorMode = 'RGB' 50 | resnet.inputSize = {3, 224, 224} 51 | resnet.mean = torch.Tensor{0.485, 0.456, 0.406} 52 | resnet.std = torch.Tensor{0.229, 0.224, 0.225} 53 | 54 | return resnet 55 | -------------------------------------------------------------------------------- /models/inceptionv3.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local inceptionv3 = {} 4 | 5 | inceptionv3.__download = argcheck{ 6 | {name='filename', type='string'}, 7 | call = 8 | function(filename) 9 | os.execute('mkdir -p '..paths.dirname(filename)..';' 10 | ..'wget http://webia.lip6.fr/~cadene/Downloads/inceptionv3.t7' 11 | ..' 
-O '..filename) 12 | end 13 | } 14 | 15 | inceptionv3.load = argcheck{ 16 | {name='filename', type='string'}, 17 | call = 18 | function(filename) 19 | if not path.exists(filename) then 20 | inceptionv3.__download(filename) 21 | end 22 | return torch.load(filename) 23 | end 24 | } 25 | 26 | inceptionv3.loadFinetuning = argcheck{ 27 | {name='filename', type='string'}, 28 | {name='nclasses', type='number'}, 29 | {name='ftfactor', type='number', default=10}, 30 | call = 31 | function(filename, nclasses, ftfactor) 32 | local net = inceptionv3.load(filename) 33 | net:remove() -- nn.SoftMax 34 | net:remove() -- nn.Linear 35 | net:add(nn.GradientReversal(-1.0/ftfactor)) 36 | net:add(nn.Linear(2048, nclasses)) 37 | return net 38 | end 39 | } 40 | 41 | inceptionv3.loadExtracting = argcheck{ 42 | {name='filename', type='string'}, 43 | {name='layerid', type='number'}, 44 | call = 45 | function(filename, layerid) 46 | local net = inceptionv3.load(filename) 47 | for i=net:size(), layerid+1, -1 do 48 | net:remove() 49 | end 50 | net:evaluate() 51 | return net 52 | end 53 | } 54 | 55 | inceptionv3.colorMode = 'RGB' 56 | inceptionv3.inputSize = {3, 299, 299} 57 | inceptionv3.mean = torch.Tensor{128, 128, 128} 58 | inceptionv3.std = torch.Tensor{128, 128, 128} 59 | 60 | return inceptionv3 61 | -------------------------------------------------------------------------------- /models/vggm.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local vggm = {} 4 | 5 | vggm.__download = argcheck{ 6 | {name='filename', type='string'}, 7 | call = 8 | function(filename) 9 | os.execute('mkdir -p '..paths.dirname(filename)..';' 10 | ..'wget http://webia.lip6.fr/~cadene/Downloads/vggm.t7' 11 | ..' -O '..filename) 12 | end 13 | } 14 | 15 | vggm.load = argcheck{ 16 | {name='filename', type='string'}, 17 | call = 18 | function(filename) 19 | if not path.exists(filename) then 20 | vggm.__download(filename) 21 | end 22 | return torch.load(filename) 23 | end 24 | } 25 | 26 | vggm.loadFinetuning = argcheck{ 27 | {name='filename', type='string'}, 28 | {name='nclasses', type='number'}, 29 | {name='ftfactor', type='number', default=10}, 30 | {name='layerid', type='number', default=22}, 31 | call = 32 | function(filename, nclasses, ftfactor, layerid) 33 | local net = vggm.load(filename) 34 | net:remove(24) -- nn.SoftMax 35 | net:remove(23) -- nn.Linear 36 | net:add(nn.Linear(4096, nclasses)) 37 | for i=net:size(), layerid+1, -1 do 38 | net:get(i):reset() 39 | end 40 | net:insert(nn.GradientReversal(-1.0/ftfactor), layerid) 41 | return net 42 | end 43 | } 44 | 45 | vggm.loadExtracting = argcheck{ 46 | {name='filename', type='string'}, 47 | {name='layerid', type='number'}, 48 | call = 49 | function(filename, layerid) 50 | local net = vggm.load(filename) 51 | for i=net:size(), layerid+1, -1 do 52 | net:remove() 53 | end 54 | net:evaluate() 55 | return net 56 | end 57 | } 58 | 59 | vggm.colorMode = 'BGR' 60 | vggm.inputSize = {3, 221, 221} 61 | vggm.mean = torch.Tensor{123.68, 116.779, 103.939} 62 | -- no std 63 | 64 | return vggm 65 | -------------------------------------------------------------------------------- /models/vgg16.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local vgg16 = {} 4 | 5 | vgg16.__download = argcheck{ 6 | {name='filename', type='string'}, 7 | call = 8 | function(filename) 9 | os.execute('mkdir -p '..paths.dirname(filename)..';' 10 | ..'wget 
http://webia.lip6.fr/~cadene/Downloads/vgg16.t7' 11 | ..' -O '..filename) 12 | end 13 | } 14 | 15 | vgg16.load = argcheck{ 16 | {name='filename', type='string'}, 17 | call = 18 | function(filename) 19 | if not path.exists(filename) then 20 | vgg16.__download(filename) 21 | end 22 | return torch.load(filename) 23 | end 24 | } 25 | 26 | vgg16.loadFinetuning = argcheck{ 27 | {name='filename', type='string'}, 28 | {name='nclasses', type='number'}, 29 | {name='ftfactor', type='number', default=10}, 30 | {name='layerid', type='number', default=38}, 31 | call = 32 | function(filename, nclasses, ftfactor, layerid) 33 | local net = vgg16.load(filename) 34 | net:remove() -- nn.SoftMax 35 | net:remove() -- nn.Linear 36 | net:add(nn.Linear(4096, nclasses)) 37 | for i=net:size(), layerid+1, -1 do 38 | net:get(i):reset() 39 | end 40 | net:insert(nn.GradientReversal(-1.0/ftfactor), layerid) 41 | return net 42 | end 43 | } 44 | 45 | vgg16.loadExtracting = argcheck{ 46 | {name='filename', type='string'}, 47 | {name='layerid', type='number'}, 48 | call = 49 | function(filename, layerid) 50 | local net = vgg16.load(filename) 51 | for i=net:size(), layerid+1, -1 do 52 | net:remove() 53 | end 54 | net:evaluate() 55 | return net 56 | end 57 | } 58 | 59 | vgg16.colorMode = 'BGR' 60 | vgg16.inputSize = {3, 221, 221} 61 | vgg16.mean = torch.Tensor{123.68, 116.779, 103.939} 62 | -- no std 63 | 64 | return vgg16 65 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | os: 3 | - linux 4 | - osx 5 | compiler: 6 | - clang 7 | cache: 8 | directories: 9 | - $HOME/OpenBlasInstall 10 | sudo: false 11 | env: 12 | - TORCH_LUA_VERSION=LUAJIT21 13 | - TORCH_LUA_VERSION=LUA52 14 | addons: 15 | apt: 16 | packages: 17 | - cmake 18 | - gfortran 19 | - gcc-multilib 20 | - gfortran-multilib 21 | - liblapack-dev 22 | - build-essential 23 | - gcc 24 | - g++ 25 | - curl 26 | - cmake 27 | - libreadline-dev 28 | - git-core 29 | - libqt4-core 30 | - libqt4-gui 31 | - libqt4-dev 32 | - libjpeg-dev 33 | - libpng-dev 34 | - ncurses-dev 35 | - imagemagick 36 | - libzmq3-dev 37 | - gfortran 38 | - unzip 39 | - gnuplot 40 | - gnuplot-x11 41 | before_script: 42 | - export ROOT_TRAVIS_DIR=$(pwd) 43 | - export INSTALL_PREFIX=~/torch/install 44 | - ls $HOME/OpenBlasInstall/lib || (cd /tmp/ && git clone https://github.com/xianyi/OpenBLAS.git -b master && cd OpenBLAS && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make PREFIX=$HOME/OpenBlasInstall install) 45 | - git clone https://github.com/torch/distro.git ~/torch --recursive 46 | - cd ~/torch && git submodule update --init --recursive 47 | - mkdir build && cd build 48 | - export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH 49 | - cmake .. 
-DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON 50 | - make && make install 51 | - cd $ROOT_TRAVIS_DIR 52 | - export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH 53 | - ${INSTALL_PREFIX}/bin/luarocks install torch 54 | script: 55 | - ${INSTALL_PREFIX}/bin/luarocks make rocks/torchnet-vision-scm-1.rockspec 56 | - export PATH=${INSTALL_PREFIX}/bin:$PATH 57 | - export TESTLUA=$(which luajit lua | head -n 1) 58 | - ${TESTLUA} -e "vision = require 'torchnet-vision'; t=vision.test(); if t.errors[1] then os.exit(1) end" 59 | -------------------------------------------------------------------------------- /datasets/ethzfood101.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | local tnt = require 'torchnet' 3 | local utils = require 'torchnet-vision.datasets.utils' 4 | local lsplit = string.split 5 | 6 | local ethzfood101 = {} 7 | 8 | ethzfood101.__download = argcheck{ 9 | {name='dirname', type='string'}, 10 | call = 11 | function(dirname) 12 | os.execute('mkdir -p '..dirname..'; ' 13 | --..'cp /net/big/cadene/doc/Deep6Framework/data/raw/UPMC_Food101/UPMC_Food101.tar.gz'..' '..dirname..'; ' 14 | ..'wget http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz -P '..dirname..'; ' 15 | ..'tar -xzf '..dirname..'/food-101.tar.gz -C '..dirname) 16 | end 17 | } 18 | 19 | ethzfood101.load = argcheck{ 20 | {name='dirname', type='string', default='data/raw/ethzfood101'}, 21 | call = 22 | function(dirname) 23 | local dirimg = paths.concat(dirname,'food-101','images') 24 | local traintxt = paths.concat(dirname,'food-101','meta','train.txt') 25 | local testtxt = paths.concat(dirname,'food-101','meta','test.txt') 26 | if not paths.dirp(dirname) then 27 | ethzfood101.__download(dirname) 28 | end 29 | local classes, class2target = utils.findClasses(dirimg) 30 | local loadSample = function(line) 31 | local spl = lsplit(line, '/') 32 | local sample = {} 33 | sample.path = line..'.jpg' 34 | sample.label = spl[#spl-1] 35 | sample.target = class2target[sample.label] 36 | return sample 37 | end 38 | local trainset = tnt.ListDataset{ 39 | filename = traintxt, 40 | path = dirimg, 41 | load = loadSample 42 | } 43 | local testset = tnt.ListDataset{ 44 | filename = testtxt, 45 | path = dirimg, 46 | load = loadSample 47 | } 48 | return trainset, testset, classes, class2target 49 | end 50 | } 51 | 52 | return ethzfood101 53 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the Torchnet software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 
13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /models/overfeat.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local overfeat = {} 4 | 5 | overfeat.__download = argcheck{ 6 | {name='filename', type='string'}, 7 | call = 8 | function(filename) 9 | os.execute('mkdir -p '..paths.dirname(filename)..';' 10 | ..'wget http://webia.lip6.fr/~cadene/Downloads/overfeat.t7' 11 | ..' 
-O '..filename) 12 | end 13 | } 14 | 15 | overfeat.load = argcheck{ 16 | {name='filename', type='string'}, 17 | call = 18 | function(filename) 19 | if not path.exists(filename) then 20 | overfeat.__download(filename) 21 | end 22 | return torch.load(filename) 23 | end 24 | } 25 | 26 | -- overfeat.loadFinetuning = argcheck{ 27 | -- {name='filename', type='string'}, 28 | -- {name='nclasses', type='number'}, 29 | -- {name='ftfactor', type='number', default=10}, 30 | -- {name='layerid', type='number', default=38}, 31 | -- call = 32 | -- function(filename, nclasses, ftfactor, layerid) 33 | -- local net = overfeat.load(filename) 34 | -- net:remove() -- nn.SoftMax 35 | -- net:remove() -- nn.Linear 36 | -- net:add(nn.Linear(4096, nclasses)) 37 | -- for i=net:size(), layerid+1, -1 do 38 | -- net:get(i):reset() 39 | -- end 40 | -- net:insert(nn.GradientReversal(-1.0/ftfactor), layerid) 41 | -- return net 42 | -- end 43 | -- } 44 | 45 | -- overfeat.loadExtracting = argcheck{ 46 | -- {name='filename', type='string'}, 47 | -- {name='layerid', type='number'}, 48 | -- call = 49 | -- function(filename, layerid) 50 | -- local net = overfeat.load(filename) 51 | -- for i=net:size(), layerid+1, -1 do 52 | -- net:remove() 53 | -- end 54 | -- net:evaluate() 55 | -- return net 56 | -- end 57 | -- } 58 | 59 | overfeat.colorMode = 'RGB' 60 | overfeat.inputSize = {3, 224, 224} 61 | overfeat.mean = torch.Tensor{118.380948, 118.380948, 118.380948} 62 | overfeat.std = torch.Tensor{61.896913, 61.896913, 61.896913} 63 | 64 | return overfeat 65 | -------------------------------------------------------------------------------- /datasets/upmcfood101.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | local tnt = require 'torchnet' 3 | local utils = require 'torchnet-vision.datasets.utils' 4 | local lsplit = string.split 5 | 6 | local upmcfood101 = {} 7 | 8 | upmcfood101.__download = argcheck{ 9 | {name='dirname', type='string'}, 10 | call = 11 | function(dirname) 12 | os.execute('mkdir -p '..dirname..'; ' 13 | --..'cp /net/big/cadene/doc/Deep6Framework/data/raw/UPMC_Food101/UPMC_Food101.tar.gz'..' 
'..dirname..'; ' 14 | ..'wget http://visiir.lip6.fr/data/public/UPMC_Food101.tar.gz -P '..dirname..'; ' 15 | ..'tar -xzf '..dirname..'/UPMC_Food101.tar.gz -C '..dirname) 16 | end 17 | } 18 | 19 | upmcfood101.load = argcheck{ 20 | {name='dirname', type='string', default='data/raw/upmcfood101'}, 21 | call = 22 | function(dirname) 23 | local dirimg = paths.concat(dirname,'images') 24 | local dirtrain = paths.concat(dirimg,'train') 25 | local dirtest = paths.concat(dirimg,'test') 26 | local traintxt = paths.concat(dirtrain,'TrainImages.txt') 27 | local testtxt = paths.concat(dirtest,'TestImages.txt') 28 | if not paths.dirp(dirname) then 29 | upmcfood101.__download(dirname) 30 | end 31 | local classes, class2target = utils.findClasses(dirtrain) 32 | if not paths.filep(traintxt) then 33 | utils.findFilenames(dirtrain, classes, 'TrainImages.txt') 34 | end 35 | if not paths.filep(testtxt) then 36 | utils.findFilenames(dirtest, classes, 'TestImages.txt') 37 | end 38 | local loadSample = function(line) 39 | local spl = lsplit(line, '/') 40 | local sample = {} 41 | sample.path = line 42 | sample.label = spl[#spl-1] 43 | sample.target = class2target[sample.label] 44 | return sample 45 | end 46 | local trainset = tnt.ListDataset{ 47 | filename = traintxt, 48 | path = dirtrain, 49 | load = loadSample 50 | } 51 | local testset = tnt.ListDataset{ 52 | filename = testtxt, 53 | path = dirtest, 54 | load = loadSample 55 | } 56 | return trainset, testset, classes, class2target 57 | end 58 | } 59 | 60 | return upmcfood101 61 | -------------------------------------------------------------------------------- /datasets/mit67.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | local tnt = require 'torchnet' 3 | local utils = require 'torchnet-vision.datasets.utils' 4 | local lsplit = string.split 5 | 6 | local mit67 = {} 7 | 8 | mit67.__download = argcheck{ 9 | {name='dirname', type='string', default='data/raw/mit67'}, 10 | call = 11 | function(dirname) 12 | local urlremote = 'http://groups.csail.mit.edu/vision/LabelMe/NewImages/indoorCVPR_09.tar' 13 | local urltrainimages = 'http://web.mit.edu/torralba/www/TrainImages.txt' 14 | local urltestimages = 'http://web.mit.edu/torralba/www/TestImages.txt' 15 | os.execute('mkdir -p '..dirname..'; '.. 16 | 'wget '..urlremote..' -P '..dirname..'; '.. 17 | 'tar -C '..dirname..' -xf '..dirname..'/indoorCVPR_09.tar') 18 | os.execute('wget '..urltrainimages..' -P '..dirname) 19 | os.execute('wget '..urltestimages..' -P '..dirname) 20 | local dirimg = paths.concat(dirname,'Images') 21 | local classes, _ = utils.findClasses(dirimg) 22 | for _, class in pairs(classes) do 23 | print('Convert class '..class..' 
to jpg ') 24 | os.execute('mogrify -format jpg ' 25 | ..paths.concat(dirimg, class, '*.jpg')) 26 | -- the files already have a .jpg extension, but the images still need to be re-encoded as actual JPEG 27 | end 28 | end 29 | } 30 | 31 | mit67.load = argcheck{ 32 | {name='dirname', type='string', default='data/raw/mit67'}, 33 | call = 34 | function(dirname) 35 | local dirimg = paths.concat(dirname, 'Images') 36 | local traintxt = paths.concat(dirname, 'TrainImages.txt') 37 | local testtxt = paths.concat(dirname, 'TestImages.txt') 38 | if not (paths.dirp(dirname) and paths.dirp(dirimg) and 39 | paths.filep(traintxt) and paths.filep(testtxt)) then 40 | mit67.__download(dirname) 41 | end 42 | local classes, class2target = utils.findClasses(dirimg) 43 | local loadSample = function(line) 44 | local spl = lsplit(line, '/') 45 | local sample = {} 46 | sample.path = line 47 | sample.label = spl[#spl-1] 48 | sample.target = class2target[sample.label] 49 | return sample 50 | end 51 | local trainset = tnt.ListDataset{ 52 | filename = traintxt, 53 | path = dirimg, 54 | load = loadSample 55 | } 56 | local testset = tnt.ListDataset{ 57 | filename = testtxt, 58 | path = dirimg, 59 | load = loadSample 60 | } 61 | return trainset, testset, classes, class2target 62 | end 63 | } 64 | 65 | return mit67 66 | -------------------------------------------------------------------------------- /datasets/utils.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | local tnt = require 'torchnet' 3 | 4 | local utils = {} 5 | 6 | utils.findClasses = argcheck{ 7 | {name='dirname', type='string'}, 8 | call = 9 | function(dirname) 10 | local find = 'find' 11 | local handle = io.popen(find..' '..dirname..' -mindepth 1 -maxdepth 1 -type d' 12 | ..' | grep -o \'[^/]*$\' | sort') 13 | local classes = {} 14 | local class2target = {} 15 | local key = 1 16 | for class in handle:lines() do 17 | table.insert(classes, class) 18 | class2target[class] = key 19 | key = key + 1 20 | end 21 | handle:close() 22 | return classes, class2target 23 | end 24 | } 25 | 26 | utils.findFilenames = argcheck{ 27 | {name='dirname', type='string'}, 28 | {name='classes', type='table'}, 29 | {name='filename', type='string', default='filename.txt'}, 30 | call = 31 | function(dirname, classes, filename) 32 | local pathfilename = paths.concat(dirname,filename) 33 | local find = 'find' 34 | local extensionList = {'jpg', 'png', 'JPG', 'PNG', 'JPEG', 35 | 'ppm', 'PPM', 'bmp', 'BMP'} 36 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 37 | for i = 2, #extensionList do 38 | findOptions = findOptions .. ' -o' 39 | ..' -iname "*.' .. extensionList[i] .. '"' 40 | end 41 | print(pathfilename) 42 | assert(not paths.filep(pathfilename), 43 | 'filename already exists, you should remove it first') 44 | for _, class in pairs(classes) do 45 | print(find..' "'..dirname..'/'..class..'" '..findOptions 46 | ..' | grep -o \'[^/]*/[^/]*$\' >> '..pathfilename) 47 | os.execute(find..' "'..dirname..'/'..class..'" '..findOptions 48 | ..' 
| grep -o \'[^/]*/[^/]*$\' >> '..pathfilename) 49 | end 50 | return pathfilename 51 | end 52 | } 53 | 54 | utils.loadDataset = argcheck{ 55 | {name='dirname', type='string'}, 56 | {name='filename', type='string', default='filename.txt'}, 57 | call = 58 | function(dirname, filename) 59 | local classes, class2target = utils.findClasses(dirname) 60 | local pathfilename = dirname..'/'..filename 61 | if not paths.filep(pathfilename) then 62 | utils.findFilenames(dirname, classes, filename) 63 | end 64 | local dataset = tnt.ListDataset{ 65 | filename = pathfilename, 66 | path = dirname, 67 | load = function(line) 68 | local sample = { 69 | path = line 70 | } 71 | return sample 72 | end 73 | } 74 | return dataset, classes, class2target 75 | end 76 | } 77 | 78 | return utils 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchnet-vision 2 | 3 | *torchnet-vision* is a plugin for [torchnet](http://github.com/torchnet/torchnet) which provides a set 4 | of abstractions aimed at encouraging code re-use and 5 | modular programming. 6 | 7 | At the moment, *torchnet-vision* provides several functionalities: 8 | - TransformImage: pre-process images in various ways. 9 | - Datasets: download and load datasets easily. 10 | - Models: download and load pretrained models easily. 11 | 12 | For an overview of the *torchnet* framework, please also refer to [this paper](https://lvdmaaten.github.io/publications/papers/Torchnet_2016.pdf). 13 | 14 | 15 | ## Installation 16 | 17 | Please install *torch* first, following instructions on 18 | [torch.ch](http://torch.ch/docs/getting-started.html). If *torch* is 19 | already installed, make sure you have an up-to-date version of 20 | [*argcheck*](https://github.com/torch/argcheck), otherwise you will get 21 | weird errors at runtime. 
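If needed, *argcheck* can be updated with a single standard *luarocks* command (shown here for convenience):

```
luarocks install argcheck
```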
22 | 23 | Assuming *torch* is already installed, the *torchnet* and *torchnet-vision* cores are only a set of 24 | lua files, so they are straightforward to install with *luarocks*: 25 | ``` 26 | luarocks install torchnet 27 | git clone https://github.com/Cadene/torchnet-vision.git 28 | cd torchnet-vision 29 | luarocks make rocks/torchnet-vision-scm-1.rockspec 30 | ``` 31 | 32 | 33 | ## Documentation 34 | 35 | ### Extracting features from lena with inceptionv3 36 | 37 | ```lua 38 | require 'image' 39 | tnt = require 'torchnet' 40 | vision = require 'torchnet-vision' 41 | 42 | augmentation = tnt.transform.compose{ 43 | vision.image.transformimage.randomScale{minSize=299,maxSize=350}, 44 | vision.image.transformimage.randomCrop(299), 45 | vision.image.transformimage.colorNormalize{ 46 | mean = vision.models.inceptionv3.mean, 47 | std = vision.models.inceptionv3.std 48 | }, 49 | function(img) return img:float() end 50 | } 51 | img = augmentation(image.lena()) 52 | 53 | net = vision.models.inceptionv3.loadExtracting{ -- download included 54 | filename = 'tmp/inceptionv3.t7', 55 | layerid = 30 56 | } 57 | net:evaluate() 58 | print(net:forward(img:view(1,3,299,299)):size()) -- 2048 59 | ``` 60 | 61 | ### Fine-tuning on MIT67 in 250 lines of code 62 | 63 | ``` 64 | $ CUDA_VISIBLE_DEVICES=0 th example/mit67finetuning.lua -usegpu true 65 | $ ls example/logs/mit67/*/ 66 | $ cat example/logs/mit67/*/trainlog.txt 67 | $ cat example/logs/mit67/*/testlog.txt 68 | ``` 69 | 70 | ## Other projects using torchnet 71 | 72 | - [Wide residual network](https://github.com/szagoruyko/wide-residual-networks) 73 | - [MultiGPU ImageNet](https://github.com/karandwivedi42/imagenet-multiGPU.torchnet) 74 | - [Learning to Compare Image Patches via Convolutional Neural Networks (CVPR 2015)](https://github.com/szagoruyko/cvpr15deepcompare/blob/master/training/train.lua) 75 | - [Facebook "A MultiPath Network for Object Detection"](https://github.com/facebookresearch/multipathnet) 76 | 77 | ## Roadmap 78 | 79 | - defining names for package and classes (vision?) 
80 | - add docs to TransformImage methods 81 | - add tests 82 | - add a method to tnt.DataIterator to process the mean and std 83 | - add a better system to preprocess images than tnt.transform (especially to add or remove TransformImage.colorNormalize) 84 | - add data loaders (large-scale or not) 85 | - add video directory 86 | -------------------------------------------------------------------------------- /example/upmcfood101extract.lua: -------------------------------------------------------------------------------- 1 | local tnt = require 'torchnet' 2 | local vision = require 'torchnet-vision' 3 | require 'image' 4 | require 'os' 5 | require 'optim' 6 | ffi = require 'ffi' 7 | local logtext = require 'torchnet.log.view.text' 8 | local logstatus = require 'torchnet.log.view.status' 9 | local transformimage = require 'torchnet-vision.image.transformimage' 10 | local upmcfood101 = require 'torchnet-vision.datasets.upmcfood101' 11 | 12 | local cmd = torch.CmdLine() 13 | cmd:option('-seed', 1337, 'seed for cpu and gpu') 14 | cmd:option('-usegpu', true, 'use gpu') 15 | cmd:option('-bsize', 25, 'batch size') 16 | cmd:option('-nthread', 3, 'threads number for parallel iterator') 17 | cmd:option('-model', 'vgg16', 'options: vgg16 | vggm | resnet200 | inceptionv3') 18 | cmd:option('-layerid', 37, 'ex: vgg16 + 37 = 2nd FC layer after ReLU ') 19 | local config = cmd:parse(arg) 20 | print(string.format('running on %s', config.usegpu and 'GPU' or 'CPU')) 21 | 22 | config.idGPU = os.getenv('CUDA_VISIBLE_DEVICES') or -1 23 | config.date = os.date("%y_%m_%d_%X") 24 | 25 | torch.setdefaulttensortype('torch.FloatTensor') 26 | torch.manualSeed(config.seed) 27 | 28 | local pathdataset = paths.concat('example/data/processed/upmcfood101') 29 | local pathtrainset = paths.concat(pathdataset,'trainset.t7') 30 | local pathtestset = paths.concat(pathdataset,'testset.t7') 31 | local pathdata = paths.concat('example/data/raw/upmcfood101') 32 | local pathmodel = paths.concat('example/models',config.model,'net.t7') 33 | local pathextract = paths.concat('example/features/upmcfood101',config.date) 34 | local pathconfig = paths.concat(pathextract,'config.t7') 35 | 36 | local model = vision.models[config.model] 37 | local net = model.loadExtracting{ 38 | filename = pathmodel, 39 | layerid = config.layerid 40 | } 41 | print(net) 42 | local criterion = nn.CrossEntropyCriterion():float() 43 | 44 | local trainset, testset, classes, class2target = upmcfood101.load() 45 | -- testset = testset:shuffle(300) 46 | -- trainset = trainset:shuffle(300) 47 | 48 | local function addTransforms(dataset, model) 49 | dataset = dataset:transform(function(sample) 50 | sample.input = tnt.transform.compose{ 51 | function(path) return image.load(path, 3) end, 52 | vision.image.transformimage.scale(model.inputSize[2]), 53 | vision.image.transformimage.centerCrop(model.inputSize[2]), 54 | vision.image.transformimage.colorNormalize{ 55 | mean = model.mean, 56 | std = model.std 57 | } 58 | }(sample.path) 59 | return sample 60 | end) 61 | return dataset 62 | end 63 | testset = addTransforms(testset, model) 64 | trainset = addTransforms(trainset, model) 65 | function trainset:manualSeed(seed) torch.manualSeed(seed) end 66 | 67 | os.execute('mkdir -p '..pathdataset) 68 | os.execute('mkdir -p '..pathextract) 69 | torch.save(pathconfig, config) 70 | torch.save(pathtrainset, trainset) 71 | torch.save(pathtestset, testset) 72 | 73 | local function getIterator(mode) 74 | local iterator = tnt.ParallelDatasetIterator{ 75 | nthread = config.nthread, 76 | init = 
function() 77 | require 'torchnet' 78 | require 'torchnet-vision' 79 | end, 80 | closure = function(threadid) 81 | local dataset = torch.load(pathdataset..'/'..mode..'set.t7') 82 | return dataset:batch(config.bsize) 83 | end, 84 | transform = function(sample) 85 | sample.target = torch.Tensor(sample.target):view(-1,1) 86 | return sample 87 | end 88 | } 89 | print('Stats of '..mode..'set') 90 | for i, v in pairs(iterator:exec('size')) do 91 | print(i, v) 92 | end 93 | return iterator 94 | end 95 | 96 | local meter = { 97 | timem = tnt.TimeMeter{unit = false}, 98 | } 99 | 100 | local engine = tnt.OptimEngine() 101 | local file 102 | local count, nbatch 103 | engine.hooks.onStart = function(state) 104 | count = 1 105 | nbatch = state.iterator:execSingle("size") 106 | for _,m in pairs(meter) do m:reset() end 107 | print(engine.mode) 108 | file = assert(io.open(pathextract..'/'..engine.mode..'set.csv', "w")) 109 | file:write('path;gttarget;gtclass') 110 | for i=1, #classes do file:write(';pred'..i) end 111 | file:write("\n") 112 | end 113 | engine.hooks.onForward = function(state) 114 | local output = state.network.output 115 | print('Mode: '..engine.mode, 116 | 'Inputid: '..count..' / '..nbatch, 117 | 'Size: '..state.sample.input:size(1) 118 | ..' '..output:size(1) 119 | ..' '..state.sample.target:size(1) 120 | ..' '..#state.sample.path) 121 | count = count + 1 122 | if state.sample.input:size(1) == output:size(1) then -- hotfix 123 | for i=1, output:size(1) do 124 | file:write(state.sample.path[i]); 125 | if engine.mode ~= 'test' then 126 | file:write(';') 127 | file:write(state.sample.target[i][1]); file:write(';') 128 | file:write(state.sample.label[i]) 129 | end 130 | for j=1, output:size(2) do 131 | file:write(';'); file:write(output[i][j]) 132 | end 133 | file:write("\n") 134 | end 135 | end 136 | end 137 | engine.hooks.onEnd = function(state) 138 | print('End of extracting on '..engine.mode..'set') 139 | print('Took '..meter.timem:value()) 140 | file:close() 141 | end 142 | 143 | if config.usegpu then 144 | require 'cutorch' 145 | cutorch.manualSeed(config.seed) 146 | require 'cunn' 147 | require 'cudnn' 148 | cudnn.convert(net, cudnn) 149 | net = net:cuda() 150 | criterion = criterion:cuda() 151 | local igpu, tgpu = torch.CudaTensor(), torch.CudaTensor() 152 | engine.hooks.onSample = function(state) 153 | igpu:resize(state.sample.input:size() ):copy(state.sample.input) 154 | tgpu:resize(state.sample.target:size()):copy(state.sample.target) 155 | state.sample.input = igpu 156 | state.sample.target = tgpu 157 | end -- alternatively, this logic can be implemented via a TransformDataset 158 | end 159 | 160 | print('Extracting trainset ...') 161 | engine.mode = 'train' 162 | engine:test{ 163 | network = net, 164 | iterator = getIterator('train'), 165 | criterion = criterion 166 | } 167 | 168 | print('Extracting testset ...') 169 | engine.mode = 'test' 170 | engine:test{ 171 | network = net, 172 | iterator = getIterator('test'), 173 | criterion = criterion 174 | } 175 | -------------------------------------------------------------------------------- /example/mit67finetuning.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | require 'os' 3 | require 'optim' 4 | ffi = require 'ffi' 5 | local tnt = require 'torchnet' 6 | local vision = require 'torchnet-vision' 7 | local logtext = require 'torchnet.log.view.text' 8 | local logstatus = require 'torchnet.log.view.status' 9 | local transformimage = require 
'torchnet-vision.image.transformimage' 10 | 11 | local cmd = torch.CmdLine() 12 | cmd:option('-seed', 1337, 'seed for cpu and gpu') 13 | cmd:option('-usegpu', true, 'use gpu') 14 | cmd:option('-usecudnn', true, 'use cudnn') 15 | cmd:option('-bsize', 20, 'batch size') 16 | cmd:option('-nepoch', 50, 'epoch number') 17 | cmd:option('-optim', 'adam', 'optimization method, options: sgd | ...') 18 | cmd:option('-lr', 1e-4, 'learning rate') 19 | cmd:option('-lrd', 0, 'learning rate decay (adam compatible)') 20 | cmd:option('-ftfactor', 10, 'fine tuning factor') 21 | cmd:option('-nthread', 3, 'threads number for parallel iterator') 22 | cmd:option('-model', 'vgg16', 'model name, options: inceptionv3 | vggm') 23 | local config = cmd:parse(arg) 24 | print(string.format('running on %s', config.usegpu and 'GPU' or 'CPU')) 25 | 26 | config.idGPU = os.getenv('CUDA_VISIBLE_DEVICES') or -1 27 | config.date = os.date("%y_%m_%d_%X") 28 | torch.setdefaulttensortype('torch.FloatTensor') 29 | torch.manualSeed(config.seed) 30 | 31 | local path = './example' 32 | local pathmodel = path..'/models/'..config.model..'/net.t7' 33 | local pathdata = path..'/data/raw/mit67' 34 | local pathdataset = path..'/data/processed/mit67' 35 | local pathlog = path..'/logs/mit67/'..config.date 36 | local pathtrainset = pathdataset..'/trainset.t7' 37 | local pathtestset = pathdataset..'/testset.t7' 38 | local pathtrainlog = pathlog..'/trainlog.txt' 39 | local pathtestlog = pathlog..'/testlog.txt' 40 | local pathbestepoch = pathlog..'/bestepoch.t7' 41 | local pathbestnet = pathlog..'/net.t7' 42 | local pathconfig = pathlog..'/config.t7' 43 | os.execute('mkdir -p '..pathdataset) -- here we save datasets for threads 44 | os.execute('mkdir -p '..pathlog) -- here we save experiments logs and best net 45 | torch.save(pathconfig, config) 46 | 47 | local trainset, testset, classes, class2target 48 | = vision.datasets.mit67.load(pathdata) 49 | print('Trainset size', trainset:size()) 50 | print('Testset size', testset:size()) 51 | print('Classes number', #classes) 52 | print('First class and its associated target', classes[1], class2target[classes[1]]) 53 | print('') 54 | 55 | local model = vision.models[config.model] 56 | local net = model.loadFinetuning{ -- download included 57 | filename = pathmodel, 58 | ftfactor = config.ftfactor, 59 | nclasses = #classes 60 | } 61 | print(net) 62 | print('Input size', {model.inputSize[1],model.inputSize[2],model.inputSize[3]}) 63 | print('Color mode', model.colorMode) 64 | print('') 65 | local criterion = nn.CrossEntropyCriterion():float() 66 | 67 | local function addTransforms(dataset, model) 68 | dataset = dataset:transform(function(sample) 69 | sample.input = tnt.transform.compose{ 70 | function(path) return image.load(path, 3) end, 71 | transformimage.randomScale{ 72 | minSize = model.inputSize[2], -- 224 for vgg16 or 299 for inceptionv3 73 | maxSize = model.inputSize[2] + 20 -- randomly load bigger image 74 | }, -- keep image ratio by cropping instead of rescaling to a squared image 75 | transformimage.randomCrop(model.inputSize[2]), 76 | transformimage.verticalFlip(0.5), 77 | transformimage.rotation(0.05), 78 | function(img) 79 | if model.colorMode == 'BGR' then 80 | return transformimage.moveColor()(img * 255) 81 | else -- vggm and vgg16 take img color=BGR intensity=[0,255] 82 | return img -- inceptionv3 takes img color=RGB intensity=[0,1] 83 | end 84 | end, 85 | transformimage.colorNormalize(model.mean, model.std) 86 | }(sample.path) 87 | return sample 88 | end) 89 | return dataset 90 | end 91 
| 92 | trainset = trainset:shuffle() -- trainset:shuffle(300) to try out with 300 images 93 | trainset = addTransforms(trainset, model) 94 | testset = addTransforms(testset, model) 95 | -- manualSeed is called after each epoch before shuffling the trainset 96 | function trainset:manualSeed(seed) torch.manualSeed(seed) end 97 | torch.save(pathtrainset, trainset) 98 | torch.save(pathtestset, testset) 99 | 100 | local function getIterator(mode) -- mode options= train | test 101 | local iterator = tnt.ParallelDatasetIterator{ 102 | nthread = config.nthread, 103 | init = function() 104 | require 'torchnet' 105 | require 'torchnet-vision' 106 | end, 107 | closure = function(threadid) 108 | local dataset = torch.load(pathdataset..'/'..mode..'set.t7') 109 | return dataset:batch(config.bsize) 110 | end, 111 | transform = function(sample) 112 | sample.target = torch.Tensor(sample.target):view(-1,1) 113 | return sample 114 | end 115 | } 116 | print('Stats of '..mode..'set') 117 | -- all threads have the same # of batch 118 | for i, v in pairs(iterator:exec('size')) do 119 | print('Theadid='..i, 'Batch number='..v) 120 | end 121 | return iterator 122 | end 123 | 124 | local meter = { 125 | avgvm = tnt.AverageValueMeter(), 126 | confm = tnt.ConfusionMeter{k=#classes, normalized=true}, 127 | timem = tnt.TimeMeter{unit = false}, 128 | clerr = tnt.ClassErrorMeter{topk = {1,5}} 129 | } 130 | 131 | local function createLog(mode, pathlog) 132 | local keys = {'epoch', 'loss', 'acc1', 'acc5', 'time'} 133 | local format = {'%d', '%.5f', '%3.2f%%', '%3.2f%%', '%.1f'} 134 | for i=1, #keys do format[i] = keys[i]..' '..format[i] end 135 | local log = tnt.Log{ 136 | keys = keys, 137 | onFlush = { 138 | logtext{filename=pathlog, keys=keys}, 139 | logtext{keys=keys, format=format}, 140 | }, 141 | onSet = { 142 | logstatus{filename=pathlog}, 143 | logstatus{}, -- print status to screen 144 | } 145 | } 146 | log:status("Mode "..mode) 147 | return log 148 | end 149 | local log = { 150 | train = createLog('train', pathtrainlog), 151 | test = createLog('test', pathtestlog) 152 | } 153 | 154 | local engine = tnt.OptimEngine() 155 | engine.hooks.onStart = function(state) 156 | for _, m in pairs(meter) do m:reset() end 157 | end 158 | engine.hooks.onStartEpoch = function(state) -- training only 159 | engine.epoch = engine.epoch and (engine.epoch + 1) or 1 160 | end 161 | engine.hooks.onForwardCriterion = function(state) 162 | meter.timem:incUnit() 163 | meter.avgvm:add(state.criterion.output) 164 | meter.clerr:add(state.network.output, state.sample.target) 165 | meter.confm:add(state.network.output, state.sample.target) 166 | log[engine.mode]:set{ 167 | epoch = engine.epoch, 168 | loss = meter.avgvm:value(), 169 | acc1 = 100 - meter.clerr:value{k = 1}, 170 | acc5 = 100 - meter.clerr:value{k = 5}, 171 | time = meter.timem:value() 172 | } 173 | print(string.format('%s epoch: %i; avg. loss: %2.4f; avg. acctop1: %2.4f%%', 174 | engine.mode, engine.epoch, meter.avgvm:value(), 100 - meter.clerr:value{k = 1})) 175 | end 176 | engine.hooks.onEnd = function(state) 177 | print('End of epoch '..engine.epoch..' 
on '..engine.mode..'set') 178 | log[engine.mode]:flush() 179 | print('Confusion matrix saved (rows = gt, cols = pred)\n') 180 | image.save(pathlog..'/confm_epoch,'..engine.epoch..'.pgm', meter.confm:value()) 181 | end 182 | if config.usegpu then 183 | require 'cutorch' 184 | cutorch.manualSeed(config.seed) 185 | require 'cunn' 186 | if config.usecudnn then 187 | require 'cudnn' 188 | cudnn.convert(net, cudnn) 189 | end 190 | net = net:cuda() 191 | criterion = criterion:cuda() 192 | local igpu, tgpu = torch.CudaTensor(), torch.CudaTensor() 193 | engine.hooks.onSample = function(state) 194 | igpu:resize(state.sample.input:size() ):copy(state.sample.input) 195 | tgpu:resize(state.sample.target:size()):copy(state.sample.target) 196 | state.sample.input = igpu 197 | state.sample.target = tgpu 198 | end -- alternatively, this logic can be implemented via a TransformDataset 199 | end 200 | 201 | -- Iterator 202 | local trainiter = getIterator('train') 203 | local testiter = getIterator('test') 204 | 205 | local bestepoch = { 206 | acctop1 = 0, 207 | acctop5 = 0, 208 | epoch = 0 209 | } 210 | 211 | for epoch = 1, config.nepoch do 212 | print('Training ...') 213 | engine.mode = 'train' 214 | trainiter:exec('manualSeed', config.seed + epoch) -- call trainset:manualSeed(seed) 215 | trainiter:exec('resample') -- shuffle trainset 216 | engine:train{ 217 | maxepoch = 1, -- we control the epoch with for loop 218 | network = net, 219 | iterator = trainiter, 220 | criterion = criterion, 221 | optimMethod = optim[config.optim], 222 | config = { 223 | learningRate = config.lr, 224 | learningRateDecay = config.lrd 225 | }, 226 | } 227 | print('Testing ...') 228 | engine.mode = 'test' 229 | engine:test{ 230 | network = net, 231 | iterator = testiter, 232 | criterion = criterion, 233 | } 234 | if bestepoch.acctop1 < 100 - meter.clerr:value{k = 1} then 235 | bestepoch = { 236 | acctop1 = 100 - meter.clerr:value{k = 1}, 237 | acctop5 = 100 - meter.clerr:value{k = 5}, 238 | epoch = epoch, 239 | confm = meter.confm:value():clone() 240 | } 241 | torch.save(pathbestepoch, bestepoch) 242 | torch.save(pathbestnet, net:clearState()) 243 | end 244 | end 245 | -------------------------------------------------------------------------------- /image/transformimage.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
8 | -- 9 | -- Image transforms for data augmentation and input normalization 10 | -- 11 | 12 | local argcheck = require 'argcheck' 13 | 14 | require 'image' 15 | 16 | local transformimage = {} 17 | 18 | transformimage.colorNormalize = argcheck{ 19 | doc = [[Normalize the image by subtracting the mean and (optionally) dividing by the std]], 20 | {name='mean', type='torch.*Tensor', opt=true}, 21 | {name='std', type='torch.*Tensor', opt=true}, 22 | call = 23 | function(mean, std) 24 | return function(img) 25 | if not (mean or std) then 26 | return img 27 | end 28 | if mean:dim() == 1 and mean:size(1) == 3 then 29 | for i=1,3 do 30 | img[i]:add(-mean[i]) 31 | if std then 32 | img[i]:div(std[i]) 33 | end 34 | end 35 | elseif mean:dim() == 3 then 36 | img:add(-1, mean) 37 | if std then 38 | img:cdiv(std) 39 | end 40 | else 41 | assert(false, 'mean must be a 3-component vector or a 3d tensor') 42 | end 43 | return img 44 | end 45 | end 46 | } 47 | 48 | -- Scales the smaller edge to size 49 | transformimage.scale = argcheck{ 50 | {name='size', type='number'}, 51 | {name='interpolation', type='string', default='bicubic'}, 52 | call = 53 | function(size, interpolation) 54 | return function(img) 55 | local w, h = img:size(3), img:size(2) 56 | if (w <= h and w == size) or (h <= w and h == size) then 57 | return img 58 | end 59 | if w < h then 60 | return image.scale(img, size, h/w * size, interpolation) 61 | else 62 | return image.scale(img, w/h * size, size, interpolation) 63 | end 64 | end 65 | end 66 | } 67 | 68 | -- Crop to centered rectangle 69 | transformimage.centerCrop = argcheck{ 70 | {name='size', type='number'}, 71 | call = 72 | function(size) 73 | return function(img) 74 | local w1 = math.ceil((img:size(3) - size)/2) 75 | local h1 = math.ceil((img:size(2) - size)/2) 76 | return image.crop(img, w1, h1, w1 + size, h1 + size) -- center patch 77 | end 78 | end 79 | } 80 | 81 | -- Random crop from larger image with optional zero padding 82 | transformimage.randomCrop = argcheck{ 83 | {name='size', type='number'}, 84 | {name='padding', type='number', default=0}, 85 | call = 86 | function(size, padding) 87 | return function(img) 88 | if padding > 0 then 89 | local temp = img.new(3, img:size(2) + 2*padding, img:size(3) + 2*padding) 90 | temp:zero() 91 | :narrow(2, padding+1, img:size(2)) 92 | :narrow(3, padding+1, img:size(3)) 93 | :copy(img) 94 | img = temp 95 | end 96 | 97 | local w, h = img:size(3), img:size(2) 98 | if w == size and h == size then 99 | return img 100 | end 101 | 102 | local x1, y1 = torch.random(0, w - size), torch.random(0, h - size) 103 | local out = image.crop(img, x1, y1, x1 + size, y1 + size) 104 | assert(out:size(2) == size and out:size(3) == size, 'wrong crop size') 105 | return out 106 | end 107 | end 108 | } 109 | 110 | -- Four corner patches and center crop from image and its horizontal reflection 111 | transformimage.tenCrop = argcheck{ 112 | {name='size', type='number'}, 113 | call = 114 | function(size) 115 | local centerCrop = transformimage.centerCrop(size) 116 | 117 | return function(img) 118 | local w, h = img:size(3), img:size(2) 119 | 120 | local output = {} 121 | for _, img in ipairs{img, image.hflip(img)} do 122 | table.insert(output, centerCrop(img)) 123 | table.insert(output, image.crop(img, 0, 0, size, size)) 124 | table.insert(output, image.crop(img, w-size, 0, w, size)) 125 | table.insert(output, image.crop(img, 0, h-size, size, h)) 126 | table.insert(output, image.crop(img, w-size, h-size, w, h)) 127 | end 128 | 129 | -- View as mini-batch 130 | for i, img in ipairs(output) do 131 | output[i] = img:view(1, img:size(1), 
132 |             end
133 | 
134 |             return img.cat(output, 1) -- 10 x C x size x size mini-batch
135 |          end
136 |       end
137 | }
138 | 
139 | -- Resizes so the shorter side is randomly sampled from [minSize, maxSize] (ResNet-style)
140 | transformimage.randomScale = argcheck{
141 |    {name='minSize', type='number'},
142 |    {name='maxSize', type='number'},
143 |    call =
144 |       function(minSize, maxSize)
145 |          return function(img)
146 |             local w, h = img:size(3), img:size(2)
147 | 
148 |             local targetSz = torch.random(minSize, maxSize)
149 |             local targetW, targetH = targetSz, targetSz
150 |             if w < h then
151 |                targetH = torch.round(h / w * targetW)
152 |             else
153 |                targetW = torch.round(w / h * targetH)
154 |             end
155 | 
156 |             return image.scale(img, targetW, targetH, 'bicubic')
157 |          end
158 |       end
159 | }
160 | 
161 | -- Random crop with size 8%-100% and aspect ratio 3/4 - 4/3 (Inception-style)
162 | transformimage.randomSizedCrop = argcheck{
163 |    {name='size', type='number'},
164 |    call =
165 |       function(size)
166 |          local scale = transformimage.scale(size)
167 |          local crop = transformimage.centerCrop(size)
168 | 
169 |          return function(img)
170 |             local attempt = 0
171 |             repeat
172 |                local area = img:size(2) * img:size(3)
173 |                local targetArea = torch.uniform(0.08, 1.0) * area
174 | 
175 |                local aspectRatio = torch.uniform(3/4, 4/3)
176 |                local w = torch.round(math.sqrt(targetArea * aspectRatio))
177 |                local h = torch.round(math.sqrt(targetArea / aspectRatio))
178 | 
179 |                if torch.uniform() < 0.5 then
180 |                   w, h = h, w
181 |                end
182 | 
183 |                if h <= img:size(2) and w <= img:size(3) then
184 |                   local y1 = torch.random(0, img:size(2) - h)
185 |                   local x1 = torch.random(0, img:size(3) - w)
186 | 
187 |                   local out = image.crop(img, x1, y1, x1 + w, y1 + h)
188 |                   assert(out:size(2) == h and out:size(3) == w, 'wrong crop size')
189 | 
190 |                   return image.scale(out, size, size, 'bicubic')
191 |                end
192 |                attempt = attempt + 1
193 |             until attempt >= 10
194 | 
195 |             -- fallback after 10 attempts: scale the shorter side, then center crop
196 |             return crop(scale(img))
197 |          end
198 |       end
199 | }
200 | 
201 | transformimage.horizontalFlip = argcheck{
202 |    {name='prob', type='number', default=0.5},
203 |    call =
204 |       function(prob)
205 |          return function(img)
206 |             if torch.uniform() < prob then
207 |                img = image.hflip(img)
208 |             end
209 |             return img
210 |          end
211 |       end
212 | }
213 | 
214 | transformimage.verticalFlip = argcheck{
215 |    {name='prob', type='number', default=0.5},
216 |    call =
217 |       function(prob)
218 |          return function(img)
219 |             if torch.uniform() < prob then
220 |                img = image.vflip(img)
221 |             end
222 |             return img
223 |          end
224 |       end
225 | }
226 | 
227 | transformimage.rotation = argcheck{
228 |    {name='deg', type='number'},
229 |    call =
230 |       function(deg)
231 |          return function(img)
232 |             if deg ~= 0 then
233 |                img = image.rotate(img, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear') -- random angle in [-deg/2, deg/2] degrees
234 |             end
235 |             return img
236 |          end
237 |       end
238 | }
239 | 
240 | -- Lighting noise (AlexNet-style PCA-based noise)
241 | transformimage.lighting = argcheck{
242 |    {name='alphastd', type='number'},
243 |    {name='eigval', type='torch.*Tensor',
244 |     default=torch.Tensor{ 0.2175, 0.0188, 0.0045 },
245 |     check=function(x)
246 |        return x:dim() == 1 and x:size(1) == 3
247 |     end},
248 |    {name='eigvec', type='torch.*Tensor',
249 |     default=torch.Tensor{
250 |        { -0.5675,  0.7192,  0.4009 },
251 |        { -0.5808, -0.0045, -0.8140 },
252 |        { -0.5836, -0.6948,  0.4203 },
253 |     },
254 |     check=function(x)
255 |        return x:dim() == 2 and x:size(1) == 3
256 |           and x:size(2) == 3
257 |     end},
258 |    call =
259 |       function(alphastd, eigval, eigvec)
260 |          return function(img)
261 |             if alphastd == 0 then
262 |                return img
263 |             end
264 | 
265 |             local alpha = torch.Tensor(3):normal(0, alphastd)
266 |             local rgb = eigvec:clone()
267 |                :cmul(alpha:view(1, 3):expand(3, 3))
268 |                :cmul(eigval:view(1, 3):expand(3, 3))
269 |                :sum(2)
270 |                :squeeze()
271 | 
272 |             img = img:clone()
273 |             for i=1,3 do
274 |                img[i]:add(rgb[i])
275 |             end
276 |             return img
277 |          end
278 |       end
279 | }
280 | 
281 | local function blend(img1, img2, alpha) -- in-place: img1 <- alpha*img1 + (1-alpha)*img2
282 |    return img1:mul(alpha):add(1 - alpha, img2)
283 | end
284 | 
285 | transformimage.grayscale = argcheck{
286 |    {name='rgbval', type='torch.*Tensor',
287 |     default=torch.Tensor{ 0.299, 0.587, 0.114 }, -- ITU-R BT.601 luma weights
288 |     check=function(x)
289 |        return x:dim() == 1 and x:size(1) == 3
290 |     end},
291 |    call =
292 |       function(rgbval)
293 |          return function(img)
294 |             local dst = img.new():resizeAs(img)
295 |             dst[1]:zero()
296 |             dst[1]:add(rgbval[1], img[1])
297 |                :add(rgbval[2], img[2])
298 |                :add(rgbval[3], img[3])
299 |             dst[2]:copy(dst[1])
300 |             dst[3]:copy(dst[1])
301 |             return dst
302 |          end
303 |       end
304 | }
305 | 
306 | transformimage.moveColor = argcheck{ -- permutes color channels (default swaps R and B)
307 |    {name='colormap', type='torch.ByteTensor',
308 |     default=torch.ByteTensor{ 3, 2, 1 },
309 |     check=function(x)
310 |        return x:dim() == 1 and x:size(1) == 3
311 |     end},
312 |    call =
313 |       function(colormap)
314 |          return function(img)
315 |             local dst = img.new():resizeAs(img)
316 |             dst[colormap[1]]:copy(img[1])
317 |             dst[colormap[2]]:copy(img[2])
318 |             dst[colormap[3]]:copy(img[3])
319 |             return dst
320 |          end
321 |       end
322 | }
323 | 
324 | transformimage.saturation = argcheck{
325 |    {name='var', type='number'},
326 |    call =
327 |       function(var)
328 |          local gs -- cached grayscale transform
329 |          return function(img)
330 |             gs = gs or transformimage.grayscale()
331 |             local gsimg = gs(img)
332 |             local alpha = 1.0 + torch.uniform(-var, var)
333 |             blend(img, gsimg, alpha)
334 |             return img
335 |          end
336 |       end
337 | }
338 | 
339 | transformimage.brightness = argcheck{
340 |    {name='var', type='number'},
341 |    call =
342 |       function(var)
343 |          local gs
344 |          return function(img)
345 |             gs = gs or img.new()
346 |             gs:resizeAs(img):zero() -- blend towards a black image
347 |             local alpha = 1.0 + torch.uniform(-var, var)
348 |             blend(img, gs, alpha)
349 |             return img
350 |          end
351 |       end
352 | }
353 | 
354 | transformimage.contrast = argcheck{
355 |    {name='var', type='number'},
356 |    call =
357 |       function(var)
358 |          local gs -- cached grayscale transform
359 | 
360 |          return function(img)
361 |             gs = gs or transformimage.grayscale()
362 |             local gsimg = gs(img)
363 |             gsimg:fill(gsimg[1]:mean()) -- blend towards the mean intensity
364 | 
365 |             local alpha = 1.0 + torch.uniform(-var, var)
366 |             blend(img, gsimg, alpha)
367 |             return img
368 |          end
369 |       end
370 | }
371 | 
372 | transformimage.randomOrder = argcheck{
373 |    {name='ts', type='table'},
374 |    call =
375 |       function(ts)
376 |          return function(img)
377 |             local img = img.img or img -- accepts either an image or a {img=...} sample
378 |             local order = torch.randperm(#ts)
379 |             for i=1,#ts do
380 |                img = ts[order[i]](img)
381 |             end
382 |             return img
383 |          end
384 |       end
385 | }
386 | 
387 | transformimage.colorJitter = argcheck{
388 |    {name='brightness', type='number', default=0},
389 |    {name='contrast', type='number', default=0},
390 |    {name='saturation', type='number', default=0},
391 |    call =
392 |       function(brightness, contrast, saturation)
393 |          local ts = {}
394 |          if brightness ~= 0 then
395 |             table.insert(ts, transformimage.brightness(brightness))
396 |          end
397 |          if contrast ~= 0 then
398 |             table.insert(ts, transformimage.contrast(contrast))
399 |          end
400 |          if saturation ~= 0 then
401 |             table.insert(ts, transformimage.saturation(saturation))
402 |          end
403 | 
404 |          if #ts == 0 then
405 |             return function(img) return img end
406 |          end
407 | 
408 |          return transformimage.randomOrder(ts)
409 |       end
410 | }
411 | 
412 | return transformimage
413 | 
--------------------------------------------------------------------------------
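Each entry in transformimage returns a plain closure over an image tensor, so transforms compose by ordinary function application. Below is a minimal usage sketch, not part of the repository: an Inception-style augmentation pipeline where the `mean`/`std` values are placeholders for per-channel statistics computed on your own training set.

local vision = require 'torchnet-vision'
local t = vision.image.transformimage

-- placeholder per-channel statistics; substitute values computed on your data
local mean = torch.Tensor{0.485, 0.456, 0.406}
local std  = torch.Tensor{0.229, 0.224, 0.225}

-- Inception-style training-time augmentation pipeline
local augment = {
   t.randomSizedCrop(224),
   t.horizontalFlip(0.5),
   t.colorJitter{brightness = 0.4, contrast = 0.4, saturation = 0.4},
   t.colorNormalize{mean = mean, std = std},
}

-- applies the transforms in order to a 3xHxW image with values in [0, 1]
local function augmentImage(img)
   for _, f in ipairs(augment) do
      img = f(img)
   end
   return img
end

Note that several transforms (colorNormalize and the colorJitter family) modify their input in place, so clone the image first if the original must be preserved.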