├── .editorconfig
├── .gitignore
├── images
│   ├── classic.png
│   └── compare3.png
├── .gitmodules
├── include
│   ├── THCGenerateFloatType.h
│   └── common.h
├── deep-encoding-scm-1.rockspec
├── init.lua
├── experiments
│   ├── datasets
│   │   ├── init.lua
│   │   ├── cifar10.lua
│   │   ├── stl10.lua
│   │   ├── joint-gen.lua
│   │   ├── cifar10-gen.lua
│   │   ├── stl10-gen.lua
│   │   ├── joint.lua
│   │   ├── kth-gen.lua
│   │   ├── ground-gen.lua
│   │   ├── light-gen.lua
│   │   ├── fmd.lua
│   │   ├── kth.lua
│   │   ├── light.lua
│   │   ├── ground.lua
│   │   ├── minc-gen.lua
│   │   ├── minc.lua
│   │   ├── fmd-gen.lua
│   │   └── transforms.lua
│   ├── checkpoints.lua
│   ├── main.lua
│   ├── opts.lua
│   ├── dataloader.lua
│   ├── models
│   │   ├── init.lua
│   │   └── encoding.lua
│   └── train.lua
├── lib
│   ├── HZENCODING.c
│   ├── HZEncoding.cu
│   ├── HZWeighting.cu
│   └── HZAggregate.cu
├── CMakeLists.txt
├── generic
│   ├── encoding.c
│   ├── hzencoding.c
│   ├── aggregate.c
│   └── weighting.c
├── init.cu
├── layers
│   ├── aggregate.lua
│   ├── netvlad.lua
│   └── encoding.lua
├── README.md
└── cmake
    └── select_compute_arch.cmake
/.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = tab 5 | indent_size = 2 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build.luarocks/ 2 | experiments/gen/ 3 | experiments/untitle/ 4 | *.DS_Store 5 | *.swp 6 | -------------------------------------------------------------------------------- /images/classic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanghang1989/Torch-Encoding-Layer/HEAD/images/classic.png -------------------------------------------------------------------------------- /images/compare3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanghang1989/Torch-Encoding-Layer/HEAD/images/compare3.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "experiments/fb.resnet.torch"] 2 | path = experiments/fb.resnet.torch 3 | url = git@github.com:facebook/fb.resnet.torch.git 4 | -------------------------------------------------------------------------------- /include/THCGenerateFloatType.h: -------------------------------------------------------------------------------- 1 | /* copied from cutorch */ 2 | #ifndef THC_GENERIC_FILE 3 | #error "You must define THC_GENERIC_FILE before including THCGenerateFloatType.h" 4 | #endif 5 | 6 | #define real float 7 | #define accreal float 8 | #define Real Float 9 | #define CReal Cuda 10 | #define THC_REAL_IS_FLOAT 11 | #line 1 THC_GENERIC_FILE 12 | #include THC_GENERIC_FILE 13 | #undef real 14 | #undef accreal 15 | #undef Real 16 | #undef CReal 17 | #undef THC_REAL_IS_FLOAT 18 | 19 | #ifndef THCGenerateAllTypes 20 | #undef THC_GENERIC_FILE 21 | #endif 22 | 23 | -------------------------------------------------------------------------------- /deep-encoding-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "deep-encoding" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/zhanghang1989/Deep-Encoding.git", 6 | tag = "master" 7 | } 8 | 9 | description = { 10 | summary = "Deep Encoding Network", 11 | detailed = [[ 12 | Deep Encoding Network 13 | ]], 14 | homepage = "https://github.com/zhanghang1989/Deep-Encoding" 15 | 
} 16 | 17 | dependencies = { 18 | "torch >= 7.0", 19 | "cutorch >= 1.0" 20 | } 21 | 22 | build = { 23 | type = "cmake", 24 | variables = { 25 | CMAKE_BUILD_TYPE="Release", 26 | CMAKE_PREFIX_PATH="$(LUA_BINDIR)/..", 27 | CMAKE_INSTALL_PREFIX="$(PREFIX)" 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | -- load dependencies 17 | require 'nn' 18 | require 'cutorch' 19 | require 'cunn' 20 | require 'cudnn' 21 | 22 | -- load encoding packages 23 | require 'libencoding' 24 | require 'encoding.aggregate' 25 | require 'encoding.encoding' 26 | -------------------------------------------------------------------------------- /experiments/datasets/init.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- modified from https://github.com/facebook/fb.resnet.torch 3 | -- original copyrights preserves 4 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 | 6 | local M = {} 7 | 8 | local function isvalid(opt, cachePath) 9 | local imageInfo = torch.load(cachePath) 10 | if imageInfo.basedir and imageInfo.basedir ~= opt.data then 11 | return false 12 | end 13 | return true 14 | end 15 | 16 | function M.create(opt, split) 17 | local cachePath = paths.concat(opt.gen, opt.dataset .. '.t7') 18 | if not paths.filep(cachePath) or not isvalid(opt, cachePath) then 19 | paths.mkdir('gen') 20 | print('filename = ', opt.dataset .. '-gen.lua') 21 | local script = paths.dofile(opt.dataset .. '-gen.lua') 22 | script.exec(opt, cachePath) 23 | end 24 | local imageInfo = torch.load(cachePath) 25 | 26 | local Dataset = require('datasets/' .. opt.dataset) 27 | return Dataset(imageInfo, opt, split) 28 | end 29 | 30 | return M 31 | -------------------------------------------------------------------------------- /lib/HZENCODING.c: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. 
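To make the caching contract of `M.create` above concrete, here is a minimal driver sketch (not part of the repo; the `opt` fields mirror the ones `create` and `isvalid` consult, and the paths are hypothetical):

```
-- Hypothetical usage of experiments/datasets/init.lua
local datasets = require 'datasets/init'

local opt = {
   gen = 'gen',                 -- cache directory for <dataset>.t7 metadata
   data = '/path/to/raw/data',  -- raw data root, compared against imageInfo.basedir
   dataset = 'cifar10',         -- selects datasets/cifar10.lua and datasets/cifar10-gen.lua
}

-- The first call runs cifar10-gen.lua and writes gen/cifar10.t7; later calls
-- reuse the cache unless the stored basedir no longer matches opt.data.
local trainSet = datasets.create(opt, 'train')
print('training samples: ' .. trainSet:size())
```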
The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | #ifndef THC_GENERIC_FILE 17 | #define THC_GENERIC_FILE "lib/HZENCODING.c" 18 | #else 19 | 20 | #define HZENCODING_assertSameGPU(...) THAssertMsg(THCTensor_(checkGPU)(__VA_ARGS__), \ 21 | "Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one.") 22 | 23 | #include "HZAggregate.cu" 24 | #include "HZWeighting.cu" 25 | #include "HZEncoding.cu" 26 | 27 | #endif // THC_GENERIC_FILE 28 | -------------------------------------------------------------------------------- /include/common.h: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | template <int Dim> 17 | THCDeviceTensor<real, Dim> devicetensor(THCState *state, THCudaTensor *t) { 18 | if (!t) { 19 | return THCDeviceTensor<real, Dim>(); 20 | } 21 | 22 | int inDim = THCudaTensor_nDimension(state, t); 23 | if (inDim == Dim) { 24 | return toDeviceTensor<real, Dim>(state, t); 25 | } 26 | 27 | // View in which the last dimensions are collapsed or expanded as needed 28 | THAssert(THCudaTensor_isContiguous(state, t)); 29 | int size[Dim]; 30 | for (int i = 0; i < Dim || i < inDim; ++i) { 31 | if (i < Dim && i < inDim) { 32 | size[i] = t->size[i]; 33 | } else if (i < Dim) { 34 | size[i] = 1; 35 | } else { 36 | size[Dim - 1] *= t->size[i]; 37 | } 38 | } 39 | return THCDeviceTensor<real, Dim>(THCudaTensor_data(state, t), size); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2016 6 | ## 7 | ## Free to reuse and distribute this software for research or 8 | ## non-profit purpose, subject to the following conditions: 9 | ## 1. The code must retain the above copyright notice, this list of 10 | ## conditions. 11 | ## 2. Original authors' names are not deleted. 12 | ## 3. 
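A worked example of what the `devicetensor<Dim>` helper in include/common.h does: it passes a tensor through when `nDimension == Dim`, pads trailing sizes with 1 when the tensor has fewer dimensions, and for a contiguous tensor with more dimensions collapses all extra trailing dimensions into the last one. For instance, a contiguous 8x3x96x96 input viewed with `devicetensor<2>` becomes 8 x (3*96*96) = 8 x 27648.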
The authors' names are not used to endorse or promote products 13 | ## derived from this software 14 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR) 17 | CMAKE_POLICY(VERSION 2.8) 18 | 19 | # Find Torch and CUDA 20 | FIND_PACKAGE(Torch REQUIRED) 21 | FIND_PACKAGE(CUDA 6.5 REQUIRED) 22 | 23 | # Detect CUDA architecture and get best NVCC flags 24 | IF(NOT COMMAND CUDA_SELECT_NVCC_ARCH_FLAGS) 25 | INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/select_compute_arch.cmake) 26 | ENDIF() 27 | CUDA_SELECT_NVCC_ARCH_FLAGS(NVCC_FLAGS_EXTRA $ENV{TORCH_CUDA_ARCH_LIST}) 28 | LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) 29 | 30 | INCLUDE_DIRECTORIES( 31 | ./include 32 | ${CMAKE_CURRENT_SOURCE_DIR} 33 | "${Torch_INSTALL_INCLUDE}/THC" 34 | ) 35 | LINK_DIRECTORIES("${Torch_INSTALL_LIB}") 36 | 37 | # ADD lua source files 38 | FILE(GLOB src *.c *.cu) 39 | FILE(GLOB luasrc *.lua layers/*.lua) 40 | 41 | # ADD the torch package and link dependencies 42 | ADD_TORCH_PACKAGE(encoding "${src}" "${luasrc}") 43 | TARGET_LINK_LIBRARIES(encoding 44 | THC TH luaT ${CUDA_cusparse_LIBRARY} 45 | ) 46 | -------------------------------------------------------------------------------- /generic/encoding.c: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | static int encoding_(Main_Encoding_ForwardF)(lua_State *L) 17 | /* 18 | */ 19 | { 20 | /* Check number of inputs: F, C, s and X */ 21 | if(lua_gettop(L) != 4) 22 | luaL_error(L, "Encoding: Incorrect number of arguments.\n"); 23 | THCTensor* F_ = *(THCTensor**)luaL_checkudata(L, 1, 24 | THC_Tensor); 25 | THCTensor* C_ = *(THCTensor**)luaL_checkudata(L, 2, 26 | THC_Tensor); 27 | THCTensor* s_ = *(THCTensor**)luaL_checkudata(L, 3, 28 | THC_Tensor); 29 | THCTensor* X_ = *(THCTensor**)luaL_checkudata(L, 4, 30 | THC_Tensor); 31 | /* Check input dims */ 32 | THCState *state = cutorch_getstate(L); 33 | if (THCTensor_(nDimension)(state, F_) != 3 || 34 | THCTensor_(nDimension)(state, C_) != 2 || 35 | THCTensor_(nDimension)(state, s_) != 1 || 36 | THCTensor_(nDimension)(state, X_) != 3) 37 | luaL_error(L, "Encoding: incorrect input dims. \n"); 38 | 39 | HZEncoding_ForwardF(state, F_, C_, s_, X_); 40 | /* C function return number of the outputs */ 41 | return 0; 42 | } 43 | 44 | -------------------------------------------------------------------------------- /init.cu: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. 
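Regarding the token-pasting macros in init.cu below: with `Real = Float` and `CReal = Cuda` substituted by include/THCGenerateFloatType.h, `THCTensor` expands to `THCudaTensor`, `THCTensor_(nDimension)` to `THCudaTensor_nDimension`, and the metatable string `THC_Tensor` to `"torch.CudaTensor"` — which is how the generic wrappers in lib/ and generic/ resolve to the float/CUDA implementations.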
The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | #include "TH.h" 17 | #include "luaT.h" 18 | #include <THC.h> 19 | #include "THCDeviceTensor.cuh" 20 | #include "THCDeviceTensorUtils.cuh" 21 | #include "common.h" 22 | 23 | /* extern function in cutorch */ 24 | struct THCState; 25 | #ifdef __cplusplus 26 | extern "C" struct THCState* cutorch_getstate(lua_State* L); 27 | #else 28 | extern struct THCState* cutorch_getstate(lua_State* L); 29 | #endif 30 | 31 | #define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) 32 | #define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor) 33 | #define THCTensor TH_CONCAT_3(TH,CReal,Tensor) 34 | #define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME) 35 | #define THC_Tensor TH_CONCAT_STRING_3(torch., CReal, Tensor) 36 | #define encoding_(NAME) TH_CONCAT_3(encoding_, Real, NAME) 37 | 38 | #ifdef __cplusplus 39 | extern "C" { 40 | #endif 41 | 42 | #include "lib/HZENCODING.c" 43 | #include "THCGenerateFloatType.h" 44 | 45 | #include "generic/hzencoding.c" 46 | #include "THCGenerateFloatType.h" 47 | 48 | #ifdef __cplusplus 49 | } 50 | #endif 51 | -------------------------------------------------------------------------------- /experiments/datasets/cifar10.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- modified from https://github.com/facebook/fb.resnet.torch 3 | -- original copyrights preserved 4 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 | 6 | local t = require 'datasets/transforms' 7 | 8 | local M = {} 9 | local CifarDataset = torch.class('resnet.CifarDataset', M) 10 | 11 | function CifarDataset:__init(imageInfo, opt, split) 12 | assert(imageInfo[split], split) 13 | self.imageInfo = imageInfo[split] 14 | self.split = split 15 | end 16 | 17 | function CifarDataset:get(i) 18 | local image = self.imageInfo.data[i]:float() 19 | local label = self.imageInfo.labels[i] 20 | 21 | return { 22 | input = image, 23 | target = label, 24 | } 25 | end 26 | 27 | function CifarDataset:size() 28 | return self.imageInfo.data:size(1) 29 | end 30 | 31 | -- Computed from entire CIFAR-10 training set 32 | local meanstd = { 33 | mean = {125.3, 123.0, 113.9}, 34 | std = {63.0, 62.1, 66.7}, 35 | } 36 | 37 | function CifarDataset:preprocess(opt) 38 | if self.split == 'train' then 39 | if opt.multisize then 40 | return t.Compose{ 41 | t.ColorNormalize(meanstd), 42 | t.HorizontalFlip(0.5), 43 | --t.RandomTwoCrop(24, 36, 4), 44 | t.RandomThreeCrop(28, 32, 36, 4), 45 | } 46 | else 47 | return t.Compose{ 48 | t.ColorNormalize(meanstd), 49 | t.HorizontalFlip(0.5), 50 | t.RandomCrop(32, 4), 51 | } 52 | end 53 | elseif self.split == 'val' then 54 | return t.ColorNormalize(meanstd) 55 | else 56 | error('invalid split: ' .. 
self.split) 57 | end 58 | end 59 | 60 | return M.CifarDataset 61 | -------------------------------------------------------------------------------- /experiments/datasets/stl10.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local t = require 'datasets/transforms' 17 | 18 | local M = {} 19 | local STLDataset = torch.class('resnet.STLDataset', M) 20 | 21 | function STLDataset:__init(imageInfo, opt, split) 22 | assert(imageInfo[split], split) 23 | self.imageInfo = imageInfo[split] 24 | self.split = split 25 | end 26 | 27 | function STLDataset:get(i) 28 | local image = self.imageInfo.data[i]:float() 29 | local label = self.imageInfo.labels[i] 30 | 31 | return { 32 | input = image, 33 | target = label, 34 | } 35 | end 36 | 37 | function STLDataset:size() 38 | return self.imageInfo.data:size(1) 39 | end 40 | 41 | -- Same Params as in CIFAR-10 training set 42 | local meanstd = { 43 | mean = {125.3, 123.0, 113.9}, 44 | std = {63.0, 62.1, 66.7}, 45 | } 46 | 47 | function STLDataset:preprocess() 48 | if self.split == 'train' then 49 | return t.Compose{ 50 | t.ColorNormalize(meanstd), 51 | t.HorizontalFlip(0.5), 52 | t.RandomCrop(96, 12), 53 | } 54 | elseif self.split == 'val' then 55 | return t.ColorNormalize(meanstd) 56 | else 57 | error('invalid split: ' .. self.split) 58 | end 59 | end 60 | 61 | return M.STLDataset 62 | -------------------------------------------------------------------------------- /experiments/datasets/joint-gen.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
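As a concrete illustration of the split-dependent pipelines above, this hedged sketch applies the CIFAR-10 training transform to one sample (it assumes the `Compose`, `ColorNormalize`, `HorizontalFlip` and `RandomCrop` helpers from datasets/transforms.lua behave as they are used in `preprocess`, and `dataset` is a hypothetical `CifarDataset` instance):

```
local t = require 'datasets/transforms'

local meanstd = { mean = {125.3, 123.0, 113.9}, std = {63.0, 62.1, 66.7} }
local preprocess = t.Compose{
   t.ColorNormalize(meanstd),  -- per-channel (x - mean) / std
   t.HorizontalFlip(0.5),      -- flip left-right with probability 0.5
   t.RandomCrop(32, 4),        -- random 32x32 crop after 4-pixel padding
}

local sample = dataset:get(1)            -- {input = image, target = label}
local input = preprocess(sample.input)   -- augmented 3x32x32 float tensor
```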
The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local M = {} 17 | 18 | local function isvalid(opt, cachePath) 19 | local imageInfo = torch.load(cachePath) 20 | if imageInfo.basedir and imageInfo.basedir ~= opt.data then 21 | return false 22 | end 23 | return true 24 | end 25 | 26 | 27 | function M.exec(opt, cacheFile) 28 | local dir = paths.dirname(cacheFile) 29 | local cifarPath = string.format('%s/cifar10.t7', dir) 30 | local stlPath = string.format('%s/stl10.t7', dir) 31 | if not paths.filep(cifarPath) or not isvalid(opt, cifarPath) then 32 | paths.mkdir('gen') 33 | local script = paths.dofile('cifar10-gen.lua') 34 | script.exec(opt, cifarPath) 35 | end 36 | if not paths.filep(stlPath) or not isvalid(opt, stlPath) then 37 | paths.mkdir('gen') 38 | local script = paths.dofile('stl10-gen.lua') 39 | script.exec(opt, stlPath) 40 | end 41 | local cifarData = torch.load(cifarPath) 42 | local stlData = torch.load(stlPath) 43 | 44 | print(" | saving joint dataset to " .. cacheFile) 45 | torch.save(cacheFile, { 46 | train = { 47 | set1 = cifarData.train, 48 | set2 = stlData.train, 49 | }, 50 | val = { 51 | set1 = cifarData.val, 52 | set2 = stlData.val, 53 | }, 54 | }) 55 | end 56 | 57 | return M 58 | -------------------------------------------------------------------------------- /experiments/datasets/cifar10-gen.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- modified from https://github.com/facebook/fb.resnet.torch 3 | -- original copyrights preserves 4 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 | 6 | local URL = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz' 7 | 8 | local M = {} 9 | 10 | local function convertToTensor(files) 11 | local data, labels 12 | 13 | for _, file in ipairs(files) do 14 | local m = torch.load(file, 'ascii') 15 | if not data then 16 | data = m.data:t() 17 | labels = m.labels:squeeze() 18 | else 19 | data = torch.cat(data, m.data:t(), 1) 20 | labels = torch.cat(labels, m.labels:squeeze()) 21 | end 22 | end 23 | 24 | -- This is *very* important. The downloaded files have labels 0-9, which do 25 | -- not work with CrossEntropyCriterion 26 | labels:add(1) 27 | 28 | return { 29 | data = data:contiguous():view(-1, 3, 32, 32), 30 | labels = labels, 31 | } 32 | end 33 | 34 | function M.exec(opt, cacheFile) 35 | print("=> Downloading CIFAR-10 dataset from " .. URL) 36 | local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/') 37 | assert(ok == true or ok == 0, 'error downloading CIFAR-10') 38 | 39 | print(" | combining dataset into a single file") 40 | local trainData = convertToTensor({ 41 | 'gen/cifar-10-batches-t7/data_batch_1.t7', 42 | 'gen/cifar-10-batches-t7/data_batch_2.t7', 43 | 'gen/cifar-10-batches-t7/data_batch_3.t7', 44 | 'gen/cifar-10-batches-t7/data_batch_4.t7', 45 | 'gen/cifar-10-batches-t7/data_batch_5.t7', 46 | }) 47 | local testData = convertToTensor({ 48 | 'gen/cifar-10-batches-t7/test_batch.t7', 49 | }) 50 | 51 | print(" | saving CIFAR-10 dataset to " .. 
cacheFile) 52 | torch.save(cacheFile, { 53 | train = trainData, 54 | val = testData, 55 | }) 56 | end 57 | 58 | return M 59 | -------------------------------------------------------------------------------- /generic/hzencoding.c: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | #ifndef THC_GENERIC_FILE 17 | #define THC_GENERIC_FILE "generic/hzencoding.c" 18 | #else 19 | 20 | /* load the implementation detail */ 21 | #include "aggregate.c" 22 | #include "weighting.c" 23 | #include "encoding.c" 24 | 25 | /* register the functions */ 26 | static const struct luaL_Reg encoding_(Aggregate) [] = 27 | { 28 | {"Forward", encoding_(Main_Aggregate_Forward)}, 29 | {"BackwardA", encoding_(Main_Aggregate_BackwardA)}, 30 | /* end */ 31 | {NULL, NULL} 32 | }; 33 | 34 | static const struct luaL_Reg encoding_(Weighting) [] = 35 | { 36 | {"UpdateParams", encoding_(Main_Weighting_UpdateParams)}, 37 | {"BatchRowScale", encoding_(Main_Weighting_BatchRow)}, 38 | /* end */ 39 | {NULL, NULL} 40 | }; 41 | 42 | static const struct luaL_Reg encoding_(Encoding) [] = 43 | { 44 | {"ForwardF", encoding_(Main_Encoding_ForwardF)}, 45 | /* end */ 46 | {NULL, NULL} 47 | }; 48 | 49 | DLL_EXPORT int luaopen_libencoding(lua_State *L) { 50 | lua_newtable(L); 51 | lua_pushvalue(L, -1); 52 | lua_setglobal(L, "HZENCODING"); 53 | 54 | lua_newtable(L); 55 | luaT_setfuncs(L, encoding_(Aggregate), 0); 56 | lua_setfield(L, -2, "Aggregate"); 57 | 58 | lua_newtable(L); 59 | luaT_setfuncs(L, encoding_(Weighting), 0); 60 | lua_setfield(L, -2, "Weighting"); 61 | 62 | lua_newtable(L); 63 | luaT_setfuncs(L, encoding_(Encoding), 0); 64 | lua_setfield(L, -2, "Encoding"); 65 | 66 | return 1; 67 | } 68 | 69 | #endif // THC_GENERIC_FILE 70 | -------------------------------------------------------------------------------- /experiments/checkpoints.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- modified from https://github.com/facebook/fb.resnet.torch 3 | -- original copyrights preserves 4 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 | 6 | local checkpoint = {} 7 | 8 | local function deepCopy(tbl) 9 | -- creates a copy of a network with new modules and the same tensors 10 | local copy = {} 11 | for k, v in pairs(tbl) do 12 | if type(v) == 'table' then 13 | copy[k] = deepCopy(v) 14 | else 15 | copy[k] = v 16 | end 17 | end 18 | if torch.typename(tbl) then 19 | torch.setmetatable(copy, torch.typename(tbl)) 20 | end 21 | return copy 22 | end 23 | 24 | function checkpoint.latest(opt) 25 | if opt.resume == 'none' then 26 | return nil 27 | end 28 | 29 | local latestPath = paths.concat(opt.resume, 'latest.t7') 30 | if not paths.filep(latestPath) then 31 | return nil 32 | end 33 | 34 | print('=> 
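A short sketch of the resume cycle that `checkpoint.latest` and `checkpoint.save` in this file implement (the `opt` values are hypothetical; `latest` is nil on a fresh run or when `opt.resume` is 'none'):

```
local checkpoints = require 'checkpoints'
local opt = { resume = 'checkpoints', save = 'checkpoints' }

local latest, optimState = checkpoints.latest(opt)
local startEpoch = latest and latest.epoch + 1 or 1

-- inside the training loop, after validating each epoch:
-- checkpoints.save(epoch, model, optimState, isBestModel, opt)
-- writes model_<epoch>.t7 and optimState_<epoch>.t7, then updates latest.t7
```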
Loading checkpoint ' .. latestPath) 35 | local latest = torch.load(latestPath) 36 | local optimState = torch.load(paths.concat(opt.resume, latest.optimFile)) 37 | 38 | return latest, optimState 39 | end 40 | 41 | function checkpoint.save(epoch, model, optimState, isBestModel, opt) 42 | -- don't save the DataParallelTable for easier loading on other machines 43 | if torch.type(model) == 'nn.DataParallelTable' then 44 | model = model:get(1) 45 | end 46 | 47 | -- create a clean copy on the CPU without modifying the original network 48 | model = deepCopy(model):float():clearState() 49 | 50 | local modelFile = 'model_' .. epoch .. '.t7' 51 | local optimFile = 'optimState_' .. epoch .. '.t7' 52 | 53 | torch.save(paths.concat(opt.save, modelFile), model) 54 | torch.save(paths.concat(opt.save, optimFile), optimState) 55 | torch.save(paths.concat(opt.save, 'latest.t7'), { 56 | epoch = epoch, 57 | modelFile = modelFile, 58 | optimFile = optimFile, 59 | }) 60 | 61 | if isBestModel then 62 | torch.save(paths.concat(opt.save, 'model_best.t7'), model) 63 | end 64 | end 65 | 66 | return checkpoint 67 | -------------------------------------------------------------------------------- /lib/HZEncoding.cu: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | __global__ void HZEncoding_ForwardF_kernel ( 17 | THCDeviceTensor<real, 3> F, 18 | THCDeviceTensor<real, 2> C, 19 | THCDeviceTensor<real, 1> s, 20 | THCDeviceTensor<real, 3> X) 21 | { 22 | /* declarations of the variables */ 23 | int b, k, i, d, D; 24 | real sum; 25 | /* Get the index and channels */ 26 | b = blockIdx.z; 27 | k = blockIdx.x * blockDim.x + threadIdx.x; 28 | i = blockIdx.y * blockDim.y + threadIdx.y; 29 | D = C.getSize(1); 30 | /* boundary check for output */ 31 | if (k >= X.getSize(2) || i >= X.getSize(1)) return; 32 | /* main operation */ 33 | sum = 0; 34 | for (d=0; d<D; d++) { 35 | sum += (F[b][i][d].ldg() - C[k][d].ldg()) * (F[b][i][d].ldg() - C[k][d].ldg()); 36 | } 37 | X[b][i][k] = s[k].ldg() * sum; 38 | } 39 | 40 | void HZEncoding_ForwardF(THCState *state, THCTensor *F_, THCTensor *C_, 41 | THCTensor *s_, THCTensor *X_) 42 | /* 43 | * calculate the assignment scores 44 | */ 45 | { 46 | /* Check the GPU index */ 47 | HZENCODING_assertSameGPU(state, 4, F_, C_, s_, X_); 48 | /* Device tensors */ 49 | THCDeviceTensor<real, 3> F = devicetensor<3>(state, F_); 50 | THCDeviceTensor<real, 2> C = devicetensor<2>(state, C_); 51 | THCDeviceTensor<real, 1> s = devicetensor<1>(state, s_); 52 | THCDeviceTensor<real, 3> X = devicetensor<3>(state, X_); 53 | /* kernel function */ 54 | cudaStream_t stream = THCState_getCurrentStream(state); 55 | dim3 threads(16, 16); 56 | dim3 blocks(X.getSize(2)/16+1, X.getSize(1)/16+1, X.getSize(0)); 57 | HZEncoding_ForwardF_kernel<<<blocks, threads, 0, stream>>>(F, C, s, X); 58 | THCudaCheck(cudaGetLastError()); 59 | } 60 | 61 | -------------------------------------------------------------------------------- /generic/aggregate.c: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. 
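Reading the `HZEncoding_ForwardF_kernel` above (one thread per descriptor/codeword pair, with batches on `blockIdx.z`), the operation is the scaled squared distance between descriptors and codewords: for $F \in \mathbb{R}^{B \times N \times D}$, $C \in \mathbb{R}^{K \times D}$ and smoothing factors $s \in \mathbb{R}^{K}$,

$$X_{b,i,k} = s_k \sum_{d=1}^{D} \left(F_{b,i,d} - C_{k,d}\right)^2$$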
The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | static int encoding_(Main_Aggregate_Forward)(lua_State *L) 17 | /* 18 | */ 19 | { 20 | /* Check number of inputs */ 21 | if(lua_gettop(L) != 3) 22 | luaL_error(L, "Encoding: Incorrect number of arguments.\n"); 23 | THCTensor* E_ = *(THCTensor**)luaL_checkudata(L, 1, 24 | THC_Tensor); 25 | THCTensor* A_ = *(THCTensor**)luaL_checkudata(L, 2, 26 | THC_Tensor); 27 | THCTensor* R_ = *(THCTensor**)luaL_checkudata(L, 3, 28 | THC_Tensor); 29 | /* Check input dims */ 30 | THCState *state = cutorch_getstate(L); 31 | if (THCTensor_(nDimension)(state, E_) != 3 || 32 | THCTensor_(nDimension)(state, A_) != 3 || 33 | THCTensor_(nDimension)(state, R_) != 4) 34 | luaL_error(L, "Encoding: incorrect input dims. \n"); 35 | 36 | HZAggregate_Forward(state, E_, A_, R_); 37 | /* C function return number of the outputs */ 38 | return 0; 39 | } 40 | 41 | static int encoding_(Main_Aggregate_BackwardA)(lua_State *L) 42 | /* 43 | */ 44 | { 45 | /* Check number of inputs */ 46 | if(lua_gettop(L) != 3) 47 | luaL_error(L, "Encoding: Incorrect number of arguments.\n"); 48 | THCTensor* G_ = *(THCTensor**)luaL_checkudata(L, 1, 49 | THC_Tensor); 50 | THCTensor* L_ = *(THCTensor**)luaL_checkudata(L, 2, 51 | THC_Tensor); 52 | THCTensor* R_ = *(THCTensor**)luaL_checkudata(L, 3, 53 | THC_Tensor); 54 | /* Check input dims */ 55 | THCState *state = cutorch_getstate(L); 56 | if (THCTensor_(nDimension)(state, G_) != 3 || 57 | THCTensor_(nDimension)(state, L_) != 3 || 58 | THCTensor_(nDimension)(state, R_) != 4) 59 | luaL_error(L, "Encoding: incorrect input dims. \n"); 60 | 61 | HZAggregate_BackwardA(state, G_, L_, R_); 62 | /* C function return number of the outputs */ 63 | return 0; 64 | } 65 | -------------------------------------------------------------------------------- /generic/weighting.c: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. 
The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | static int encoding_(Main_Weighting_UpdateParams)(lua_State *L) 17 | /* 18 | */ 19 | { 20 | /* Check number of inputs */ 21 | if(lua_gettop(L) != 3) 22 | luaL_error(L, "Encoding: Incorrect number of arguments.\n"); 23 | THCTensor* G_ = *(THCTensor**)luaL_checkudata(L, 1, 24 | THC_Tensor); 25 | THCTensor* L_ = *(THCTensor**)luaL_checkudata(L, 2, 26 | THC_Tensor); 27 | THCTensor* D_ = *(THCTensor**)luaL_checkudata(L, 3, 28 | THC_Tensor); 29 | /* Check input dims */ 30 | THCState *state = cutorch_getstate(L); 31 | if (THCTensor_(nDimension)(state, G_) != 2 || 32 | THCTensor_(nDimension)(state, L_) != 3 || 33 | THCTensor_(nDimension)(state, D_) != 3) 34 | luaL_error(L, "Encoding: incorrect input dims. \n"); 35 | 36 | HZWeighting_UpdateParams(state, G_, L_, D_); 37 | /* C function return number of the outputs */ 38 | return 0; 39 | } 40 | 41 | static int encoding_(Main_Weighting_BatchRow)(lua_State *L) 42 | /* 43 | */ 44 | { 45 | /* Check number of inputs */ 46 | if(lua_gettop(L) != 3) 47 | luaL_error(L, "Encoding: Incorrect number of arguments.\n"); 48 | THCTensor* G_ = *(THCTensor**)luaL_checkudata(L, 1, 49 | THC_Tensor); 50 | THCTensor* W_ = *(THCTensor**)luaL_checkudata(L, 2, 51 | THC_Tensor); 52 | THCTensor* L_ = *(THCTensor**)luaL_checkudata(L, 3, 53 | THC_Tensor); 54 | /* Check input dims */ 55 | THCState *state = cutorch_getstate(L); 56 | if (THCTensor_(nDimension)(state, G_) != 3 || 57 | THCTensor_(nDimension)(state, W_) != 2 || 58 | THCTensor_(nDimension)(state, L_) != 3) 59 | luaL_error(L, "Encoding: incorrect input dims. \n"); 60 | 61 | HZWeighting_BatchRowWeighting(state, G_, W_, L_); 62 | /* C function return number of the outputs */ 63 | return 0; 64 | } 65 | 66 | -------------------------------------------------------------------------------- /experiments/datasets/stl10-gen.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
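Because `luaopen_libencoding` in generic/hzencoding.c registers these wrappers on a global `HZENCODING` table, the weighting kernels can also be exercised directly from Lua. A hedged sketch with made-up shapes that satisfy the dimension checks above:

```
require 'cutorch'
require 'libencoding'  -- registers the global HZENCODING table

local B, K, D = 4, 32, 64
local G = torch.CudaTensor(B, K, D)           -- output, written in place
local W = torch.CudaTensor(B, K):uniform()    -- per-row weights
local L = torch.CudaTensor(B, K, D):normal()  -- input rows
-- computes G[b][k][d] = L[b][k][d] * W[b][k] (see lib/HZWeighting.cu)
HZENCODING.Weighting.BatchRowScale(G, W, L)
```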
The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz' 17 | 18 | local M = {} 19 | 20 | local function convertToTensor(inputFname, inputLabelsFname) 21 | local nSamples = 0 22 | local m=torch.DiskFile(inputFname, 'r'):binary() 23 | m:seekEnd() 24 | local length = m:position() - 1 25 | local nSamplesF = length / (3*96*96) 26 | assert(nSamplesF == math.floor(nSamplesF), 'expecting numSamples to be an exact integer') 27 | nSamples = nSamples + nSamplesF 28 | m:close() 29 | 30 | local data = torch.ByteTensor(nSamples, 3, 96, 96) 31 | local index = 1 32 | local m=torch.DiskFile(inputFname, 'r'):binary() 33 | m:seekEnd() 34 | local length = m:position() - 1 35 | local nSamplesF = length / (3*96*96) 36 | m:seek(1) 37 | for j=1,nSamplesF do 38 | local store = m:readByte(3*96*96) 39 | data[index]:copy(torch.ByteTensor(store)) 40 | index = index + 1 41 | end 42 | m:close() 43 | 44 | local m=torch.DiskFile(inputLabelsFname, 'r'):binary() 45 | local labels = torch.ByteTensor(m:readByte(nSamplesF)), 46 | m:close() 47 | return { 48 | data = data:transpose(3,4),--:view(-1, 3, 96, 96), 49 | labels=labels, 50 | } 51 | end 52 | 53 | function M.exec(opt, cacheFile) 54 | print("=> Downloading STL-10 dataset from " .. URL) 55 | local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/') 56 | assert(ok == true or ok == 0, 'error downloading STL-10') 57 | 58 | 59 | local trainData = convertToTensor('gen/stl10_binary/train_X.bin', 60 | 'gen/stl10_binary/train_y.bin') 61 | local testData = convertToTensor('gen/stl10_binary/test_X.bin', 62 | 'gen/stl10_binary/test_y.bin') 63 | 64 | print(" | saving STL-10 dataset to " .. cacheFile) 65 | torch.save(cacheFile, { 66 | train = trainData, 67 | val = testData, 68 | }) 69 | end 70 | 71 | return M 72 | -------------------------------------------------------------------------------- /experiments/datasets/joint.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
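`JointDataset:get` below pairs the two sets by cycling indices modulo each set's size: for example, with 50,000 CIFAR-10 and 5,000 STL-10 training images, sample i = 5001 maps to idx1 = 5001 and idx2 = 1, so the smaller set is revisited ten times per epoch of the larger one (`size()` returns the larger count).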
The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local t = require 'datasets/transforms' 17 | 18 | local M = {} 19 | local JointDataset = torch.class('resnet.JointDataset', M) 20 | 21 | function JointDataset:__init(imageInfo, opt, split) 22 | assert(imageInfo[split], split) 23 | self.imageInfo = imageInfo[split] 24 | self.split = split 25 | end 26 | 27 | function JointDataset:get(i) 28 | local idx1 = (i-1) % self.imageInfo.set1.data:size(1) + 1 29 | local idx2 = (i-1) % self.imageInfo.set2.data:size(1) + 1 30 | 31 | local image1 = self.imageInfo.set1.data[idx1]:float() 32 | local label1 = self.imageInfo.set1.labels[idx1] 33 | local image2 = self.imageInfo.set2.data[idx2]:float() 34 | local label2 = self.imageInfo.set2.labels[idx2] 35 | 36 | return { 37 | input = { 38 | image1, 39 | image2, 40 | }, 41 | target = { 42 | label1, 43 | label2, 44 | }, 45 | } 46 | end 47 | 48 | function JointDataset:size() 49 | return math.max(self.imageInfo.set1.data:size(1), 50 | self.imageInfo.set2.data:size(1)) 51 | end 52 | 53 | -- Same Params as in CIFAR-10 training set 54 | local meanstd = { 55 | mean = {125.3, 123.0, 113.9}, 56 | std = {63.0, 62.1, 66.7}, 57 | } 58 | 59 | function JointDataset:preprocess() 60 | if self.split == 'train' then 61 | local f1 = t.Compose{ 62 | t.ColorNormalize(meanstd), 63 | t.HorizontalFlip(0.5), 64 | t.RandomCrop(32, 4), 65 | } 66 | local f2 = t.Compose{ 67 | t.ColorNormalize(meanstd), 68 | t.HorizontalFlip(0.5), 69 | t.RandomCrop(96, 12), 70 | } 71 | return function(input) 72 | return { 73 | f1(input[1]), 74 | f2(input[2]), 75 | } 76 | end 77 | elseif self.split == 'val' then 78 | local f = t.ColorNormalize(meanstd) 79 | return function(input) 80 | return{ 81 | f(input[1]), 82 | f(input[2]), 83 | } 84 | end 85 | else 86 | error('invalid split: ' .. self.split) 87 | end 88 | end 89 | 90 | return M.JointDataset 91 | -------------------------------------------------------------------------------- /experiments/datasets/kth-gen.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | 19 | local M = {} 20 | 21 | local function findImages(dir, append, idx) 22 | -- copied from fb.resnet.torch 23 | local imagePath = torch.CharTensor() 24 | local imageClass = torch.LongTensor() 25 | -- read the txt 26 | print('reading the file') 27 | print(dir .. append .. string.format('%i.txt', idx)) 28 | local file = io.open(dir .. append .. 
string.format('%i.txt', idx), 'r') 29 | local f = io.input(file) 30 | local maxLength = -1 31 | local imagePaths = {} 32 | local imageClasses = {} 33 | 34 | -- Generate a list of all the images and their class 35 | while true do 36 | local line = f:read('*line') 37 | if not line then break end 38 | 39 | local filename, classId = line:match("([^,]+) ([^,]+)") 40 | 41 | local classId = tonumber(classId) 42 | assert(classId, 'class not found: ' .. classId) 43 | 44 | table.insert(imagePaths, filename) 45 | table.insert(imageClasses, classId) 46 | 47 | maxLength = math.max(maxLength, #filename + 1) 48 | end 49 | 50 | f:close() 51 | 52 | -- Convert the generated list to a tensor for faster loading 53 | local nImages = #imagePaths 54 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 55 | for i, path in ipairs(imagePaths) do 56 | ffi.copy(imagePath[i]:data(), path) 57 | end 58 | 59 | local imageClass = torch.LongTensor(imageClasses) 60 | return imagePath, imageClass 61 | end 62 | 63 | function M.exec(opt, cacheFile) 64 | -- copied from fb.resnet.torch 65 | -- find the image path names 66 | local imagePath = torch.CharTensor() -- path to each image in dataset 67 | local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 68 | 69 | print(" | finding all images") 70 | local idx = 1 71 | local trainImagePath, trainImageClass = findImages(opt.data, '/train', idx) 72 | local valImagePath, valImageClass = findImages(opt.data, '/test', idx) 73 | 74 | assert(trainImagePath) 75 | local info = { 76 | basedir = opt.data, 77 | classList = classList, 78 | train = { 79 | imagePath = trainImagePath, 80 | imageClass = trainImageClass, 81 | }, 82 | val = { 83 | imagePath = valImagePath, 84 | imageClass = valImageClass, 85 | }, 86 | } 87 | 88 | print(" | saving list of images to " .. cacheFile) 89 | torch.save(cacheFile, info) 90 | return info 91 | end 92 | 93 | return M 94 | -------------------------------------------------------------------------------- /experiments/main.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- modified from https://github.com/facebook/fb.resnet.torch 3 | -- original copyrights preserves 4 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 | 6 | require 'torch' 7 | require 'paths' 8 | require 'optim' 9 | require 'nn' 10 | require 'encoding' 11 | csv = require 'csvigo' 12 | 13 | package.path = package.path .. 
';./?.lua' 14 | local DataLoader = require 'dataloader' 15 | local models = require 'models/init' 16 | local Trainer = require 'train' 17 | local opts = require 'opts' 18 | local checkpoints = require 'checkpoints' 19 | 20 | torch.setdefaulttensortype('torch.FloatTensor') 21 | torch.setnumthreads(1) 22 | 23 | local opt = opts.parse(arg) 24 | 25 | torch.manualSeed(opt.manualSeed) 26 | cutorch.manualSeedAll(opt.manualSeed) 27 | 28 | -- Load previous checkpoint, if it exists 29 | local checkpoint, optimState = checkpoints.latest(opt) 30 | 31 | -- Create model 32 | local model, criterion = models.setup(opt, checkpoint) 33 | 34 | -- Data loading 35 | local trainLoader, valLoader = DataLoader.create(opt) 36 | 37 | -- The trainer handles the training loop and evaluation on validation set 38 | local trainer = Trainer(model, criterion, opt, optimState) 39 | 40 | if opt.testOnly then 41 | local top1Err, top5Err = trainer:test(0, valLoader) 42 | print(string.format(' * Results top1: %6.3f top5: %6.3f', top1Err, top5Err)) 43 | return 44 | end 45 | 46 | local function istable(x) 47 | return type(x) == 'table' and not torch.typename(x) 48 | end 49 | 50 | print('Total Epochs is ', opt.nEpochs) 51 | 52 | local startEpoch = checkpoint and checkpoint.epoch + 1 or opt.epochNumber 53 | local bestTop1 = math.huge 54 | local bestTop5 = math.huge 55 | for epoch = startEpoch, opt.nEpochs do 56 | -- Train for a single epoch 57 | local trainTop1, trainTop5, trainLoss = trainer:train(epoch, trainLoader) 58 | 59 | -- Run model on validation set 60 | local testTop1, testTop5 = trainer:test(epoch, valLoader) 61 | if istable(trainTop1) then 62 | csvf1 = csv.File(paths.concat(opt.save, 'ErrTracking1.csv'), 'a') 63 | csvf1:write({epoch, trainTop1[1], trainTop5[1], trainLoss, testTop1[1], testTop5[1]}) 64 | csvf1:close() 65 | csvf2 = csv.File(paths.concat(opt.save, 'ErrTracking2.csv'), 'a') 66 | csvf2:write({epoch, trainTop1[2], trainTop5[2], trainLoss, testTop1[2], testTop5[2]}) 67 | csvf2:close() 68 | else 69 | csvf = csv.File(paths.concat(opt.save, 'ErrTracking.csv'), 'a') 70 | csvf:write({epoch, trainTop1, trainTop5, trainLoss, testTop1, testTop5}) 71 | csvf:close() 72 | end 73 | local bestModel = false 74 | if istable(testTop1) then 75 | if testTop1[2] < bestTop1 then 76 | bestModel = true 77 | bestTop1 = testTop1[2] 78 | bestTop5 = testTop5[2] 79 | print(' * Best model for set 2', bestTop1, bestTop5) 80 | end 81 | elseif testTop1< bestTop1 then 82 | bestModel = true 83 | bestTop1 = testTop1 84 | bestTop5 = testTop5 85 | print(' * Best model ', bestTop1, bestTop5) 86 | end 87 | 88 | checkpoints.save(epoch, model, trainer.optimState, bestModel, opt) 89 | end 90 | 91 | print(string.format(' * Finished top1: %6.3f top5: %6.3f', 92 | bestTop1, bestTop5)) 93 | -------------------------------------------------------------------------------- /experiments/datasets/ground-gen.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | 19 | local M = {} 20 | 21 | local function findImages(dir, append, idx) 22 | -- copied from fb.resnet.torch 23 | local imagePath = torch.CharTensor() 24 | local imageClass = torch.LongTensor() 25 | -- read the txt 26 | print('reading the file') 27 | print(dir .. append .. string.format('%i.txt', idx)) 28 | local file = io.open(dir .. append .. string.format('%i.txt', idx), 'r') 29 | local f = io.input(file) 30 | local maxLength = -1 31 | local imagePaths = {} 32 | local imageClasses = {} 33 | 34 | -- Generate a list of all the images and their class 35 | while true do 36 | local line = f:read('*line') 37 | if not line then break end 38 | 39 | local filename, classId = line:match("([^,]+) ([^,]+)") 40 | 41 | local classId = tonumber(classId) 42 | assert(classId, 'class not found: ' .. classId) 43 | 44 | table.insert(imagePaths, filename) 45 | table.insert(imageClasses, classId) 46 | 47 | maxLength = math.max(maxLength, #filename + 1) 48 | end 49 | 50 | f:close() 51 | 52 | -- Convert the generated list to a tensor for faster loading 53 | local nImages = #imagePaths 54 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 55 | for i, path in ipairs(imagePaths) do 56 | ffi.copy(imagePath[i]:data(), path) 57 | end 58 | 59 | local imageClass = torch.LongTensor(imageClasses) 60 | return imagePath, imageClass 61 | end 62 | 63 | function M.exec(opt, cacheFile) 64 | -- copied from fb.resnet.torch 65 | -- find the image path names 66 | local imagePath = torch.CharTensor() -- path to each image in dataset 67 | local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 68 | 69 | -- TODO FIXME idx = opts.fold; 70 | idx = 5; 71 | print(" | finding all training images") 72 | local trainImagePath, trainImageClass = findImages(opt.data, '/trainlist0', idx) 73 | print(" | finding all test images") 74 | local valImagePath , valImageClass = findImages(opt.data, '/testlist0' , idx) 75 | 76 | local info = { 77 | basedir = opt.data, 78 | classList = classList, 79 | train = { 80 | imagePath = trainImagePath, 81 | imageClass = trainImageClass, 82 | }, 83 | val = { 84 | imagePath = valImagePath, 85 | imageClass = valImageClass, 86 | }, 87 | } 88 | 89 | print(" | saving list of images to " .. cacheFile) 90 | torch.save(cacheFile, info) 91 | return info 92 | end 93 | 94 | return M 95 | -------------------------------------------------------------------------------- /experiments/datasets/light-gen.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
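The `findImages` helpers in kth-gen.lua, ground-gen.lua and light-gen.lua all expect plain-text split lists under `opt.data` (named like `train1.txt` or `trainlist05.txt`, depending on the script), with one `<image-path> <classId>` pair per line; the pattern `"([^,]+) ([^,]+)"` splits each line on the space. A hypothetical example of such a file:

```
banded/banded_0002.jpg 1
banded/banded_0017.jpg 1
blotchy/blotchy_0045.jpg 2
```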
The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | 19 | local M = {} 20 | 21 | local function findImages(dir, append, idx) 22 | -- copied from fb.resnet.torch 23 | local imagePath = torch.CharTensor() 24 | local imageClass = torch.LongTensor() 25 | -- read the txt 26 | print('reading the file') 27 | print(dir .. append .. string.format('%i.txt', idx)) 28 | local file = io.open(dir .. append .. string.format('%i.txt', idx), 'r') 29 | local f = io.input(file) 30 | local maxLength = -1 31 | local imagePaths = {} 32 | local imageClasses = {} 33 | 34 | -- Generate a list of all the images and their class 35 | while true do 36 | local line = f:read('*line') 37 | if not line then break end 38 | 39 | local filename, classId = line:match("([^,]+) ([^,]+)") 40 | 41 | local classId = tonumber(classId) 42 | assert(classId, 'class not found: ' .. classId) 43 | 44 | table.insert(imagePaths, filename) 45 | table.insert(imageClasses, classId) 46 | 47 | maxLength = math.max(maxLength, #filename + 1) 48 | end 49 | 50 | f:close() 51 | 52 | -- Convert the generated list to a tensor for faster loading 53 | local nImages = #imagePaths 54 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 55 | for i, path in ipairs(imagePaths) do 56 | ffi.copy(imagePath[i]:data(), path) 57 | end 58 | 59 | local imageClass = torch.LongTensor(imageClasses) 60 | return imagePath, imageClass 61 | end 62 | 63 | function M.exec(opt, cacheFile) 64 | -- copied from fb.resnet.torch 65 | -- find the image path names 66 | local imagePath = torch.CharTensor() -- path to each image in dataset 67 | local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 68 | 69 | -- TODO FIXME idx = opts.fold; 70 | idx = 2; 71 | print(" | finding all training images") 72 | local trainImagePath, trainImageClass = findImages(opt.data, '/trainlist0', idx) 73 | print(" | finding all test images") 74 | local valImagePath , valImageClass = findImages(opt.data, '/testlist0' , idx) 75 | 76 | local info = { 77 | basedir = opt.data, 78 | classList = classList, 79 | train = { 80 | imagePath = trainImagePath, 81 | imageClass = trainImageClass, 82 | }, 83 | val = { 84 | imagePath = valImagePath, 85 | imageClass = valImageClass, 86 | }, 87 | } 88 | 89 | print(" | saving list of images to " .. cacheFile) 90 | torch.save(cacheFile, info) 91 | return info 92 | end 93 | 94 | return M 95 | -------------------------------------------------------------------------------- /lib/HZWeighting.cu: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. 
The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | __global__ void HZWeighting_UpdateParams_kernel ( 17 | THCDeviceTensor<real, 2> G, 18 | THCDeviceTensor<real, 3> L, 19 | THCDeviceTensor<real, 3> D) 20 | { 21 | /* declarations of the variables */ 22 | int b, k, i, N; 23 | real sum; 24 | /* Get the index and channels */ 25 | b = blockIdx.y; 26 | k = blockIdx.x * blockDim.x + threadIdx.x; 27 | N = L.getSize(1); 28 | /* boundary check for output */ 29 | if (k >= G.getSize(1)) return; 30 | /* main operation */ 31 | sum = 0; 32 | for(i=0; i<N; i++) { 33 | sum += L[b][i][k].ldg() * D[b][i][k].ldg(); 34 | } 35 | G[b][k] = sum; 36 | } 37 | 38 | void HZWeighting_UpdateParams(THCState *state, THCTensor *G_, THCTensor *L_, 39 | THCTensor *D_) 40 | /* 41 | * gradient with respect to the smoothing factors 42 | */ 43 | { 44 | /* Check the GPU index */ 45 | HZENCODING_assertSameGPU(state, 3, G_, L_, D_); 46 | /* Device tensors */ 47 | THCDeviceTensor<real, 2> G = devicetensor<2>(state, G_); 48 | THCDeviceTensor<real, 3> L = devicetensor<3>(state, L_); 49 | THCDeviceTensor<real, 3> D = devicetensor<3>(state, D_); 50 | /* kernel function */ 51 | cudaStream_t stream = THCState_getCurrentStream(state); 52 | dim3 threads(16); 53 | dim3 blocks(G.getSize(1)/16+1,G.getSize(0)); 54 | HZWeighting_UpdateParams_kernel<<<blocks, threads, 0, stream>>>(G, L, D); 55 | THCudaCheck(cudaGetLastError()); 56 | } 57 | 58 | __global__ void HZWeighting_BatchRowWeighting_kernel ( 59 | THCDeviceTensor<real, 3> G, 60 | THCDeviceTensor<real, 2> W, 61 | THCDeviceTensor<real, 3> L) 62 | { 63 | /* declarations of the variables */ 64 | int b, k, d; 65 | /* Get the index and channels */ 66 | b = blockIdx.z; 67 | d = blockIdx.x * blockDim.x + threadIdx.x; 68 | k = blockIdx.y * blockDim.y + threadIdx.y; 69 | /* boundary check for output */ 70 | if (k >= G.getSize(1) || d >= G.getSize(2)) return; 71 | /* main operation */ 72 | G[b][k][d] = L[b][k][d].ldg() * W[b][k].ldg(); 73 | } 74 | 75 | void HZWeighting_BatchRowWeighting(THCState *state, THCTensor *G_, THCTensor *W_, 76 | THCTensor *L_) 77 | /* 78 | * scale each row L[b][k] by its weight W[b][k] 79 | */ 80 | { 81 | /* Check the GPU index */ 82 | HZENCODING_assertSameGPU(state, 3, G_, W_, L_); 83 | /* Device tensors */ 84 | THCDeviceTensor<real, 3> G = devicetensor<3>(state, G_); 85 | THCDeviceTensor<real, 2> W = devicetensor<2>(state, W_); 86 | THCDeviceTensor<real, 3> L = devicetensor<3>(state, L_); 87 | /* kernel function */ 88 | cudaStream_t stream = THCState_getCurrentStream(state); 89 | dim3 threads(16,16); 90 | dim3 blocks(G.getSize(2)/16+1, G.getSize(1)/16+1, G.getSize(0)); 91 | HZWeighting_BatchRowWeighting_kernel<<<blocks, threads, 0, stream>>>(G, W, L); 92 | THCudaCheck(cudaGetLastError()); 93 | } 94 | 95 | -------------------------------------------------------------------------------- /lib/HZAggregate.cu: -------------------------------------------------------------------------------- 1 | /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | * Created by: Hang Zhang 3 | * ECE Department, Rutgers University 4 | * Email: zhang.hang@rutgers.edu 5 | * Copyright (c) 2016 6 | * 7 | * Free to reuse and distribute this software for research or 8 | * non-profit purpose, subject to the following conditions: 9 | * 1. The code must retain the above copyright notice, this list of 10 | * conditions. 11 | * 2. Original authors' names are not deleted. 12 | * 3. 
The authors' names are not used to endorse or promote products 13 | * derived from this software 14 | *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | */ 16 | __global__ void HZAggregate_Forward_kernel ( 17 | THCDeviceTensor<real, 3> E, 18 | THCDeviceTensor<real, 3> A, 19 | THCDeviceTensor<real, 4> R) 20 | { 21 | /* declarations of the variables */ 22 | int b, k, d, i, N; 23 | real sum; 24 | /* Get the index and channels */ 25 | b = blockIdx.z; 26 | d = blockIdx.x * blockDim.x + threadIdx.x; 27 | k = blockIdx.y * blockDim.y + threadIdx.y; 28 | N = A.getSize(1); 29 | /* boundary check for output */ 30 | sum = 0; 31 | if (d >= E.getSize(2) || k >= E.getSize(1)) return; 32 | /* main operation */ 33 | for(i=0; i<N; i++) { 34 | sum += A[b][i][k].ldg() * R[b][i][k][d].ldg(); 35 | } 36 | E[b][k][d] = sum; 37 | } 38 | 39 | void HZAggregate_Forward(THCState *state, THCTensor *E_, THCTensor *A_, 40 | THCTensor *R_) 41 | /* 42 | * aggregate the residuals with the assignment weights 43 | */ 44 | { 45 | /* Check the GPU index */ 46 | HZENCODING_assertSameGPU(state, 3, E_, A_, R_); 47 | /* Device tensors */ 48 | THCDeviceTensor<real, 3> E = devicetensor<3>(state, E_); 49 | THCDeviceTensor<real, 3> A = devicetensor<3>(state, A_); 50 | THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_); 51 | /* kernel function */ 52 | cudaStream_t stream = THCState_getCurrentStream(state); 53 | dim3 threads(16, 16); 54 | dim3 blocks(E.getSize(2)/16+1, E.getSize(1)/16+1, 55 | E.getSize(0)); 56 | HZAggregate_Forward_kernel<<<blocks, threads, 0, stream>>>(E, A, R); 57 | THCudaCheck(cudaGetLastError()); 58 | } 59 | 60 | __global__ void HZAggregate_BackwardA_kernel ( 61 | THCDeviceTensor<real, 3> G, 62 | THCDeviceTensor<real, 3> L, 63 | THCDeviceTensor<real, 4> R) 64 | { 65 | /* declarations of the variables */ 66 | int b, k, d, i, D; 67 | real sum; 68 | /* Get the index and channels */ 69 | b = blockIdx.z; 70 | k = blockIdx.x * blockDim.x + threadIdx.x; 71 | i = blockIdx.y * blockDim.y + threadIdx.y; 72 | D = L.getSize(2); 73 | /* boundary check for output */ 74 | if (k >= G.getSize(2) || i >= G.getSize(1)) return; 75 | /* main operation */ 76 | sum = 0; 77 | for(d=0; d<D; d++) { 78 | sum += L[b][k][d].ldg() * R[b][i][k][d].ldg(); 79 | } 80 | G[b][i][k] = sum; 81 | } 82 | 83 | void HZAggregate_BackwardA(THCState *state, THCTensor *G_, THCTensor *L_, 84 | THCTensor *R_) 85 | /* 86 | * gradient of the aggregation with respect to A 87 | */ 88 | { 89 | /* Check the GPU index */ 90 | HZENCODING_assertSameGPU(state, 3, G_, L_, R_); 91 | /* Device tensors */ 92 | THCDeviceTensor<real, 3> G = devicetensor<3>(state, G_); 93 | THCDeviceTensor<real, 3> L = devicetensor<3>(state, L_); 94 | THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_); 95 | /* kernel function */ 96 | cudaStream_t stream = THCState_getCurrentStream(state); 97 | dim3 threads(16, 16); 98 | dim3 blocks(G.getSize(2)/16+1, G.getSize(1)/16+1, 99 | G.getSize(0)); 100 | HZAggregate_BackwardA_kernel<<<blocks, threads, 0, stream>>>(G, L, R); 101 | THCudaCheck(cudaGetLastError()); 102 | } 103 | -------------------------------------------------------------------------------- /layers/aggregate.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
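In tensor notation, the two kernels above implement the aggregation step and its gradient with respect to the assignments: with $A \in \mathbb{R}^{B \times N \times K}$, $R \in \mathbb{R}^{B \times N \times K \times D}$ and output $E \in \mathbb{R}^{B \times K \times D}$,

$$E_{b,k,d} = \sum_{i=1}^{N} A_{b,i,k}\, R_{b,i,k,d}, \qquad \frac{\partial \ell}{\partial A_{b,i,k}} = \sum_{d=1}^{D} \frac{\partial \ell}{\partial E_{b,k,d}}\, R_{b,i,k,d}$$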
--------------------------------------------------------------------------------
/layers/aggregate.lua:
--------------------------------------------------------------------------------
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-- Created by: Hang Zhang
-- ECE Department, Rutgers University
-- Email: zhang.hang@rutgers.edu
-- Copyright (c) 2016
--
-- Free to reuse and distribute this software for research or
-- non-profit purpose, subject to the following conditions:
--  1. The code must retain the above copyright notice, this list of
--     conditions.
--  2. Original authors' names are not deleted.
--  3. The authors' names are not used to endorse or promote products
--     derived from this software
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

local Aggregate, parent = torch.class('nn.Aggregate', 'nn.Module')

local function isint(x)
   return type(x) == 'number' and x == math.floor(x)
end

function Aggregate:__init(K, D)
   parent.__init(self)
   -- assertions
   assert(self and K and D, 'should specify K and D')
   assert(isint(K) and isint(D), 'K and D should be integers')
   self.K = K
   self.D = D
   self.gradInput = {torch.Tensor(), torch.Tensor()}
end

function Aggregate:updateOutput(input)
   assert(self)
   local K = self.K
   local D = self.D
   local A = input[1]
   local R = input[2]
   -- TODO assert inputs (A \in R^{NxK}, R \in R^{NxKxD})

   if A:dim() == 2 then
      self.output:resize(K, D)
      -- aggregation operation (in matrix form)
      -- e_k = a_k^T * R_k, where a_k and R_k are sliced at the 2nd dim
      for k=1,self.K do
         self.output:select(1, k):copy(torch.mv(R:select(2, k):t(),
                                                A:select(2, k)))
      end
   elseif A:dim() == 3 then
      local B = A:size(1)
      self.output:resize(B, K, D)
      for b=1, B do
         for k=1,self.K do
            self.output[b]:select(1,k):copy(torch.mv(R[b]:select(2, k):t(),
                                                     A[b]:select(2, k)))
         end
      end
   else
      error('input must be 2D or 3D')
   end
   return self.output
end

function Aggregate:updateGradInput(input, gradOutput)
   assert(self)
   assert(self.gradInput)
   local A = input[1]
   local R = input[2]

   -- TODO assert the gradOutput size
   if #self.gradInput == 0 then
      for i = 1, 2 do self.gradInput[i] = input[i].new() end
   end

   -- N may vary during the training
   self.gradInput[1]:resizeAs(input[1]):fill(0)
   self.gradInput[2]:resizeAs(input[2]):fill(0)

   if A:dim() == 2 then
      -- d_l/d_A \in R^{NxK}
      for k = 1,self.K do
         -- d_l/d_a_k = R_k * d_l/d_e_k
         self.gradInput[1]:select(2,k):copy(
            torch.mv(R:select(2, k), gradOutput[k])
         )
      end
      -- d_l/d_R \in R^{NxKxD}
      for k = 1,self.K do
         -- d_l/d_R_k = a_k * {d_l/d_e_k}^T
         self.gradInput[2]:select(2,k):addr(
            A:select(2,k), gradOutput[k]
         )
      end
   elseif A:dim() == 3 then
      local B = A:size(1)
      -- d_l/d_A \in R^{BxNxK}
      for b=1, B do
         for k = 1,self.K do
            -- d_l/d_a_k = R_k * d_l/d_e_k
            self.gradInput[1][b]:select(2,k):copy(
               torch.mv(R[b]:select(2, k), gradOutput[b][k])
            )
         end
         -- d_l/d_R \in R^{BxNxKxD}
         for k = 1,self.K do
            -- d_l/d_R_k = a_k * {d_l/d_e_k}^T
            self.gradInput[2][b]:select(2,k):addr(
               A[b]:select(2,k), gradOutput[b][k]
            )
         end
      end
   else
      error('input must be 2D or 3D')
   end

   return self.gradInput
end

function Aggregate:__tostring__()
   return torch.type(self) ..
      string.format(
         '(Nx%d, Nx%dx%d -> %dx%d)',
         self.K, self.K, self.D, self.K, self.D
      )
end
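-- Usage sketch (illustrative sizes, not from the repo): the layer consumes
-- the pair {A, R} and computes e_k = sum_i a_ik * r_ik.
--
--   require 'encoding'                -- assumes the rock is installed
--   local agg = nn.Aggregate(2, 3)    -- K = 2 codewords, D = 3 dims
--   local A = torch.rand(4, 2)        -- soft-assignment weights, N x K
--   local R = torch.rand(4, 2, 3)     -- residuals, N x K x D
--   local E = agg:forward({A, R})     -- encoded output, K x D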
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Encoding
Created by [Hang Zhang](http://hangzh.com/)

### Table of Contents
0. [Introduction](#introduction)
0. [Installation](#installation)
0. [Experiments](#experiments)
0. [Benchmarks](#benchmarks)
0. [Acknowledgements](#acknowledgements)

## Introduction
- **Please check out our [PyTorch implementation](https://github.com/zhanghang1989/PyTorch-Encoding) (recommended, memory efficient)**.

- This repo is a Torch implementation of the Encoding Layer as described in the paper:

	**Deep TEN: Texture Encoding Network** [[arXiv]](https://arxiv.org/pdf/1612.02844.pdf)
	[Hang Zhang](http://hangzh.com/), [Jia Xue](http://jiaxueweb.com/), [Kristin Dana](http://eceweb1.rutgers.edu/vision/dana.html)
	```
	@article{zhang2016deep,
		title={Deep TEN: Texture Encoding Network},
		author={Zhang, Hang and Xue, Jia and Dana, Kristin},
		journal={arXiv preprint arXiv:1612.02844},
		year={2016}
	}
	```

![Comparison of classic approaches and the end-to-end pipeline](images/compare3.png)

Traditional methods such as bag-of-words (BoW, left) have a structural similarity to more recent FV-CNN methods (center). Each component is optimized in separate steps. In our approach (right) the entire pipeline is learned in an integrated manner, tuning each component for the task at hand (end-to-end texture/material/pattern recognition).

## Installation
On Linux
```bash
luarocks install https://raw.githubusercontent.com/zhanghang1989/Deep-Encoding/master/deep-encoding-scm-1.rockspec
```
On OSX
```bash
CC=clang CXX=clang++ luarocks install https://raw.githubusercontent.com/zhanghang1989/Deep-Encoding/master/deep-encoding-scm-1.rockspec
```
## Experiments
- The Joint Encoding experiment in Sec. 4.2 runs by default (tested using 1 Titan X GPU). It achieves a *12.89%* error rate on the STL-10 dataset, a ***49.8%*** relative improvement over the previous state of the art, *25.67%*, of Zhao *et al.* 2015:
```bash
git clone https://github.com/zhanghang1989/Deep-Encoding
cd Deep-Encoding/experiments
th main.lua
```
- Training Deep-TEN on MINC-2500 in Sec. 4.1 using 4 GPUs.

  0. Please download the pre-trained
  [ResNet-50](https://d2j0dndfm35trm.cloudfront.net/resnet-50.t7) Torch model
  and the [MINC-2500](http://opensurfaces.cs.cornell.edu/static/minc/minc-2500.tar.gz) dataset to the ``minc`` folder before executing the program (tested using 4 Titan X GPUs).
  ```bash
  th main.lua -retrain resnet-50.t7 -ft true \
  -netType encoding -nCodes 32 -dataset minc \
  -data minc/ -nClasses 23 -batchSize 64 \
  -nGPU 4 -multisize true
  ```

  0. To get comparable results using 2 GPUs, you should change the batch size and the corresponding learning rate:
  ```bash
  th main.lua -retrain resnet-50.t7 -ft true \
  -netType encoding -nCodes 32 -dataset minc \
  -data minc/ -nClasses 23 -batchSize 32 \
  -nGPU 2 -multisize true -LR 0.05
  ```

### Benchmarks
Dataset                 | MINC-2500 | FMD | GTOS | KTH | 4D-Light
:-----------------------|:---------:|:---:|:----:|:---:|:-------:
FV-SIFT                 | 46.0 | 47.0 | 65.5 | 66.3 | 58.4
FV-CNN (VD)             | 61.8 | 75.0 | 77.1 | 71.0 | 70.4
FV-CNN (VD) multi       | 63.1 | 74.0 | 79.2 | 77.8 | 76.5
FV-CNN (ResNet) multi   | 69.3 | 78.2 | 77.1 | 78.3 | 77.6
Deep-TEN\* (**ours**)   | **81.3** | 80.2±0.9 | **84.5±2.9** | **84.5±3.5** | **81.7±1.0**
State-of-the-Art        | 76.0±0.2 | **82.4±1.4** | 81.4 | 81.1±1.5 | 77.0±1.1

### Acknowledgements
We thank Wenhan Zhang from the Physics Department, Rutgers University for discussions of mathematical models.
This work was supported by National Science Foundation award IIS-1421134.
A GPU used for this research was donated by the NVIDIA Corporation.
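### Usage
The Encoding Layer can also be used as a standalone ``nn`` module; a minimal sketch (the sizes below are illustrative, and the layer expects CUDA tensors since the aggregation runs on the GPU):
```lua
require 'encoding'
local layer = nn.Encoding(32, 128):cuda() -- 32 codewords, 128-d features
local X = torch.rand(8, 49, 128):cuda()   -- B x N x D descriptors
local E = layer:forward(X)                -- B x 32 x 128 encoding
```
2D (``NxD``) inputs are also supported; see ``layers/encoding.lua`` for the full implementation.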
--------------------------------------------------------------------------------
/experiments/datasets/fmd.lua:
--------------------------------------------------------------------------------
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-- Created by: Hang Zhang
-- ECE Department, Rutgers University
-- Email: zhang.hang@rutgers.edu
-- Copyright (c) 2016
--
-- Free to reuse and distribute this software for research or
-- non-profit purpose, subject to the following conditions:
--  1. The code must retain the above copyright notice, this list of
--     conditions.
--  2. Original authors' names are not deleted.
--  3. The authors' names are not used to endorse or promote products
--     derived from this software
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

local image = require 'image'
local paths = require 'paths'
local t = require 'datasets/transforms'
local ffi = require 'ffi'

local M = {}
local FMDDataset = torch.class('resnet.FMDDataset', M)

function FMDDataset:__init(imageInfo, opt, split)
   -- copied from fb.resnet.torch
   self.imageInfo = imageInfo[split]
   self.opt = opt
   self.split = split
   self.dir = paths.concat(opt.data, 'image')
   assert(paths.dirp(self.dir), 'directory does not exist: ' .. self.dir)
end

function FMDDataset:get(i)
   -- copied from fb.resnet.torch
   local path = ffi.string(self.imageInfo.imagePath[i]:data())

   local image = self:_loadImage(paths.concat(self.dir, path))
   local class = self.imageInfo.imageClass[i]

   return {
      input = image,
      target = class,
   }
end

function FMDDataset:_loadImage(path)
   local ok, input = pcall(function()
      return image.load(path, 3, 'float')
   end)

   -- Sometimes image.load fails because the file extension does not match the
   -- image format. In that case, use image.decompress on a ByteTensor.
   if not ok then
      local f = io.open(path, 'r')
      assert(f, 'Error reading: ' .. tostring(path))
      local data = f:read('*a')
      f:close()

      local b = torch.ByteTensor(string.len(data))
      ffi.copy(b:data(), data, b:size(1))

      input = image.decompress(b, 3, 'float')
   end

   return input
end

function FMDDataset:size()
   return self.imageInfo.imageClass:size(1)
end

-- Computed from random subset of ImageNet training images
local meanstd = {
   mean = { 0.485, 0.456, 0.406 },
   std = { 0.229, 0.224, 0.225 },
}
local pca = {
   eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 },
   eigvec = torch.Tensor{
      { -0.5675,  0.7192,  0.4009 },
      { -0.5808, -0.0045, -0.8140 },
      { -0.5836, -0.6948,  0.4203 },
   },
}

function FMDDataset:preprocess(opt)
   -- copied from fb.resnet.torch
   if self.split == 'train' then
      if opt.multisize then
         return t.Compose{
            t.Scale(400),
            --t.RandomSizedCrop(352),
            t.RandomTwoSizeCrop(352, 320),
            t.ColorJitter({
               brightness = 0.4,
               contrast = 0.4,
               saturation = 0.4,
            }),
            t.Lighting(0.1, pca.eigval, pca.eigvec),
            t.ColorNormalize(meanstd),
            t.HorizontalFlip(0.5),
         }
      else
         return t.Compose{
            t.Scale(400),
            t.RandomSizedCrop(352),
            t.ColorJitter({
               brightness = 0.4,
               contrast = 0.4,
               saturation = 0.4,
            }),
            t.Lighting(0.1, pca.eigval, pca.eigvec),
            t.ColorNormalize(meanstd),
            t.HorizontalFlip(0.5),
         }
      end
   elseif self.split == 'val' then
      local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop
      return t.Compose{
         t.Scale(400),
         t.ColorNormalize(meanstd),
         Crop(352),
      }
   else
      error('invalid split: ' ..
self.split) 126 | end 127 | end 128 | 129 | return M.FMDDataset 130 | -------------------------------------------------------------------------------- /experiments/datasets/kth.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local image = require 'image' 17 | local paths = require 'paths' 18 | local t = require 'datasets/transforms' 19 | local ffi = require 'ffi' 20 | 21 | local M = {} 22 | local KTHDataset = torch.class('resnet.KTHDataset', M) 23 | 24 | function KTHDataset:__init(imageInfo, opt, split) 25 | -- copied from fb.resnet.torch 26 | self.imageInfo = imageInfo[split] 27 | self.opt = opt 28 | self.split = split 29 | self.dir = opt.data 30 | assert(paths.dirp(self.dir), 'directory does not exist: ' .. self.dir) 31 | end 32 | 33 | function KTHDataset:get(i) 34 | -- copied from fb.resnet.torch 35 | local path = ffi.string(self.imageInfo.imagePath[i]:data()) 36 | local image = self:_loadImage(paths.concat(self.dir, path)) 37 | local class = self.imageInfo.imageClass[i] 38 | 39 | return { 40 | input = image, 41 | target = class, 42 | } 43 | end 44 | 45 | function KTHDataset:_loadImage(path) 46 | local ok, input = pcall(function() 47 | return image.load(path, 3, 'float') 48 | end) 49 | 50 | -- Sometimes image.load fails because the file extension does not match the 51 | -- image format. In that case, use image.decompress on a ByteTensor. 52 | if not ok then 53 | local f = io.open(path, 'r') 54 | assert(f, 'Error reading: ' .. 
tostring(path)) 55 | local data = f:read('*a') 56 | f:close() 57 | 58 | local b = torch.ByteTensor(string.len(data)) 59 | ffi.copy(b:data(), data, b:size(1)) 60 | print('before decompressing', path) 61 | input = image.decompress(b, 3, 'float') 62 | end 63 | 64 | return input 65 | end 66 | 67 | function KTHDataset:size() 68 | return self.imageInfo.imageClass:size(1) 69 | end 70 | 71 | -- Computed from random subset of ImageNet training images 72 | local meanstd = { 73 | mean = { 0.485, 0.456, 0.406 }, 74 | std = { 0.229, 0.224, 0.225 }, 75 | } 76 | local pca = { 77 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, 78 | eigvec = torch.Tensor{ 79 | { -0.5675, 0.7192, 0.4009 }, 80 | { -0.5808, -0.0045, -0.8140 }, 81 | { -0.5836, -0.6948, 0.4203 }, 82 | }, 83 | } 84 | 85 | function KTHDataset:preprocess(opt) 86 | -- copied from fb.resnet.torch 87 | if self.split == 'train' then 88 | if opt.multisize then 89 | return t.Compose{ 90 | t.Scale(400), 91 | --t.RandomSizedCrop(352), 92 | t.RandomTwoCrop(352, 320), 93 | t.ColorJitter({ 94 | brightness = 0.4, 95 | contrast = 0.4, 96 | saturation = 0.4, 97 | }), 98 | t.Lighting(0.1, pca.eigval, pca.eigvec), 99 | t.ColorNormalize(meanstd), 100 | t.HorizontalFlip(0.5), 101 | } 102 | else 103 | return t.Compose{ 104 | t.Scale(400), 105 | t.RandomCrop(352), 106 | --[[ 107 | t.ColorJitter({ 108 | brightness = 0.4, 109 | contrast = 0.4, 110 | saturation = 0.4, 111 | }), 112 | t.Lighting(0.1, pca.eigval, pca.eigvec), 113 | --]] 114 | t.ColorNormalize(meanstd), 115 | t.HorizontalFlip(0.5), 116 | } 117 | end 118 | elseif self.split == 'val' then 119 | local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop 120 | return t.Compose{ 121 | t.Scale(400), 122 | t.ColorNormalize(meanstd), 123 | Crop(352), 124 | } 125 | else 126 | error('invalid split: ' .. self.split) 127 | end 128 | end 129 | 130 | return M.KTHDataset 131 | -------------------------------------------------------------------------------- /experiments/datasets/light.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local image = require 'image' 17 | local paths = require 'paths' 18 | local t = require 'datasets/transforms' 19 | local ffi = require 'ffi' 20 | 21 | local M = {} 22 | local LightDataset = torch.class('resnet.LightDataset', M) 23 | 24 | function LightDataset:__init(imageInfo, opt, split) 25 | -- copied from fb.resnet.torch 26 | self.imageInfo = imageInfo[split] 27 | self.opt = opt 28 | self.split = split 29 | self.dir = opt.data 30 | assert(paths.dirp(self.dir), 'directory does not exist: ' .. 
self.dir) 31 | end 32 | 33 | function LightDataset:get(i) 34 | -- copied from fb.resnet.torch 35 | local path = ffi.string(self.imageInfo.imagePath[i]:data()) 36 | local image = self:_loadImage(paths.concat(self.dir, path)) 37 | local class = self.imageInfo.imageClass[i] 38 | 39 | return { 40 | input = image, 41 | target = class, 42 | } 43 | end 44 | 45 | function LightDataset:_loadImage(path) 46 | local ok, input = pcall(function() 47 | return image.load(path, 3, 'float') 48 | end) 49 | 50 | -- Sometimes image.load fails because the file extension does not match the 51 | -- image format. In that case, use image.decompress on a ByteTensor. 52 | if not ok then 53 | local f = io.open(path, 'r') 54 | assert(f, 'Error reading: ' .. tostring(path)) 55 | local data = f:read('*a') 56 | f:close() 57 | 58 | local b = torch.ByteTensor(string.len(data)) 59 | ffi.copy(b:data(), data, b:size(1)) 60 | print('before decompressing', path) 61 | input = image.decompress(b, 3, 'float') 62 | end 63 | 64 | return input 65 | end 66 | 67 | function LightDataset:size() 68 | return self.imageInfo.imageClass:size(1) 69 | end 70 | 71 | -- Computed from random subset of ImageNet training images 72 | local meanstd = { 73 | mean = { 0.485, 0.456, 0.406 }, 74 | std = { 0.229, 0.224, 0.225 }, 75 | } 76 | local pca = { 77 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, 78 | eigvec = torch.Tensor{ 79 | { -0.5675, 0.7192, 0.4009 }, 80 | { -0.5808, -0.0045, -0.8140 }, 81 | { -0.5836, -0.6948, 0.4203 }, 82 | }, 83 | } 84 | 85 | function LightDataset:preprocess(opt) 86 | -- copied from fb.resnet.torch 87 | if self.split == 'train' then 88 | if opt.multisize then 89 | return t.Compose{ 90 | t.Scale(400), 91 | --t.RandomSizedCrop(352), 92 | --t.RandomTwoSizeCrop(352), 93 | t.RandomTwoCrop(352, 320), 94 | --[[ 95 | t.ColorJitter({ 96 | brightness = 0.4, 97 | contrast = 0.4, 98 | saturation = 0.4, 99 | }), 100 | t.Lighting(0.1, pca.eigval, pca.eigvec), 101 | --]] 102 | t.ColorNormalize(meanstd), 103 | t.HorizontalFlip(0.5), 104 | } 105 | else 106 | return t.Compose{ 107 | t.Scale(400), 108 | t.RandomCrop(352), 109 | --[[ 110 | t.ColorJitter({ 111 | brightness = 0.4, 112 | contrast = 0.4, 113 | saturation = 0.4, 114 | }), 115 | t.Lighting(0.1, pca.eigval, pca.eigvec), 116 | --]] 117 | t.ColorNormalize(meanstd), 118 | t.HorizontalFlip(0.5), 119 | } 120 | end 121 | elseif self.split == 'val' then 122 | local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop 123 | return t.Compose{ 124 | t.Scale(400), 125 | t.ColorNormalize(meanstd), 126 | Crop(352), 127 | } 128 | else 129 | error('invalid split: ' .. self.split) 130 | end 131 | end 132 | 133 | return M.LightDataset 134 | -------------------------------------------------------------------------------- /experiments/datasets/ground.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local image = require 'image' 17 | local paths = require 'paths' 18 | local t = require 'datasets/transforms' 19 | local ffi = require 'ffi' 20 | 21 | local M = {} 22 | local GroundDataset = torch.class('resnet.GroundDataset', M) 23 | 24 | function GroundDataset:__init(imageInfo, opt, split) 25 | -- copied from fb.resnet.torch 26 | self.imageInfo = imageInfo[split] 27 | self.opt = opt 28 | self.split = split 29 | self.dir = opt.data 30 | assert(paths.dirp(self.dir), 'directory does not exist: ' .. self.dir) 31 | end 32 | 33 | function GroundDataset:get(i) 34 | -- copied from fb.resnet.torch 35 | local path = ffi.string(self.imageInfo.imagePath[i]:data()) 36 | local image = self:_loadImage(paths.concat(self.dir, path)) 37 | local class = self.imageInfo.imageClass[i] 38 | 39 | return { 40 | input = image, 41 | target = class, 42 | } 43 | end 44 | 45 | function GroundDataset:_loadImage(path) 46 | local ok, input = pcall(function() 47 | return image.load(path, 3, 'float') 48 | end) 49 | 50 | -- Sometimes image.load fails because the file extension does not match the 51 | -- image format. In that case, use image.decompress on a ByteTensor. 52 | if not ok then 53 | local f = io.open(path, 'r') 54 | assert(f, 'Error reading: ' .. tostring(path)) 55 | local data = f:read('*a') 56 | f:close() 57 | 58 | local b = torch.ByteTensor(string.len(data)) 59 | ffi.copy(b:data(), data, b:size(1)) 60 | print('before decompressing', path) 61 | input = image.decompress(b, 3, 'float') 62 | end 63 | 64 | return input 65 | end 66 | 67 | function GroundDataset:size() 68 | return self.imageInfo.imageClass:size(1) 69 | end 70 | 71 | -- Computed from random subset of ImageNet training images 72 | local meanstd = { 73 | mean = { 0.485, 0.456, 0.406 }, 74 | std = { 0.229, 0.224, 0.225 }, 75 | } 76 | local pca = { 77 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, 78 | eigvec = torch.Tensor{ 79 | { -0.5675, 0.7192, 0.4009 }, 80 | { -0.5808, -0.0045, -0.8140 }, 81 | { -0.5836, -0.6948, 0.4203 }, 82 | }, 83 | } 84 | 85 | function GroundDataset:preprocess(opt) 86 | -- copied from fb.resnet.torch 87 | if self.split == 'train' then 88 | if opt.multisize then 89 | return t.Compose{ 90 | t.Scale(400), 91 | --t.RandomSizedCrop(352), 92 | --t.RandomTwoSizeCrop(352), 93 | t.RandomTwoCrop(352, 320), 94 | --[[ 95 | t.ColorJitter({ 96 | brightness = 0.4, 97 | contrast = 0.4, 98 | saturation = 0.4, 99 | }), 100 | t.Lighting(0.1, pca.eigval, pca.eigvec), 101 | --]] 102 | t.ColorNormalize(meanstd), 103 | t.HorizontalFlip(0.5), 104 | } 105 | else 106 | return t.Compose{ 107 | t.Scale(400), 108 | t.RandomCrop(352), 109 | --[[ 110 | t.ColorJitter({ 111 | brightness = 0.4, 112 | contrast = 0.4, 113 | saturation = 0.4, 114 | }), 115 | t.Lighting(0.1, pca.eigval, pca.eigvec), 116 | --]] 117 | t.ColorNormalize(meanstd), 118 | t.HorizontalFlip(0.5), 119 | } 120 | end 121 | elseif self.split == 'val' then 122 | local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop 123 | return t.Compose{ 124 | t.Scale(400), 125 | t.ColorNormalize(meanstd), 126 | Crop(352), 127 | } 128 | else 129 | error('invalid split: ' .. 
self.split) 130 | end 131 | end 132 | 133 | return M.GroundDataset 134 | -------------------------------------------------------------------------------- /experiments/datasets/minc-gen.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | 19 | local M = {} 20 | 21 | local function findClasses(dir) 22 | -- copied from fb.resnet.torch 23 | local dirs = paths.dir(dir) 24 | table.sort(dirs) 25 | 26 | local classList = {} 27 | local classToIdx = {} 28 | for _ ,class in ipairs(dirs) do 29 | if not classToIdx[class] and class ~= '.' and class ~= '..' then 30 | table.insert(classList, class) 31 | classToIdx[class] = #classList 32 | end 33 | end 34 | 35 | return classList, classToIdx 36 | end 37 | 38 | local function findImages(dir, classToIdx, append, idx) 39 | -- copied from fb.resnet.torch 40 | local imagePath = torch.CharTensor() 41 | local imageClass = torch.LongTensor() 42 | -- read the txt 43 | print('reading the file') 44 | print(dir .. append .. string.format('%i.txt', idx)) 45 | local file = io.open(dir .. append .. string.format('%i.txt', idx), 'r') 46 | local f = io.input(file) 47 | local maxLength = -1 48 | local imagePaths = {} 49 | local imageClasses = {} 50 | 51 | -- Generate a list of all the images and their class 52 | while true do 53 | local line = f:read('*line') 54 | if not line then break end 55 | 56 | local className = paths.basename(paths.dirname(line)) 57 | local filename = paths.basename(line) 58 | local path = className .. '/' .. filename 59 | 60 | local classId = classToIdx[className] 61 | assert(classId, 'class not found: ' .. className) 62 | 63 | table.insert(imagePaths, path) 64 | table.insert(imageClasses, classId) 65 | 66 | maxLength = math.max(maxLength, #path + 1) 67 | end 68 | 69 | f:close() 70 | 71 | -- Convert the generated list to a tensor for faster loading 72 | local nImages = #imagePaths 73 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 74 | for i, path in ipairs(imagePaths) do 75 | ffi.copy(imagePath[i]:data(), path) 76 | end 77 | 78 | local imageClass = torch.LongTensor(imageClasses) 79 | return imagePath, imageClass 80 | end 81 | 82 | function M.exec(opt, cacheFile) 83 | -- copied from fb.resnet.torch 84 | -- find the image path names 85 | local imagePath = torch.CharTensor() -- path to each image in dataset 86 | local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 87 | 88 | local imgDir = paths.concat(opt.data, 'images') 89 | local labelDir = paths.concat(opt.data, 'labels') 90 | assert(paths.dirp(imgDir), 'image directory not found: ' .. 
imgDir)

   print("=> Generating list of images")
   local classList, classToIdx = findClasses(imgDir)

   -- TODO FIXME idx = opts.fold
   local idx = 1
   print(" | finding all training images")
   local trainImagePath, trainImageClass = findImages(labelDir, classToIdx, '/train', idx)
   print(" | finding all test images")
   local testImagePath , testImageClass  = findImages(labelDir, classToIdx, '/test' , idx)
   print(" | finding all val images")
   local valImagePath  , valImageClass   = findImages(labelDir, classToIdx, '/validate' , idx)

   local info = {
      basedir = opt.data,
      classList = classList,
      train = {
         imagePath = trainImagePath,
         imageClass = trainImageClass,
      },
      val = {
         imagePath = valImagePath,
         imageClass = valImageClass,
      },
      test = {
         imagePath = testImagePath,
         imageClass = testImageClass,
      },
   }

   print(" | saving list of images to " .. cacheFile)
   torch.save(cacheFile, info)
   return info
end

return M

--------------------------------------------------------------------------------
/experiments/opts.lua:
--------------------------------------------------------------------------------
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-- modified from https://github.com/facebook/fb.resnet.torch
-- original copyrights preserved
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

local M = {}

function M.parse(arg)
   local cmd = torch.CmdLine()
   cmd:text()
   cmd:text('Options:')
   -- Data
   cmd:option('-dataset',    'joint', 'Options: cifar10 | stl10 | joint | minc')
   cmd:option('-multisize',  'false', 'Train with two crop sizes')
   cmd:option('-threesize',  'false', 'Train with three crop sizes')
   cmd:option('-data',       '',      'Path to dataset')
   cmd:option('-nSplit',     1,       'Index of the current split')
   cmd:option('-nThreads',   8,       'Threads for data loading')
   cmd:option('-gen',        'gen',   'Path to save generated files')
   -- Model
   cmd:option('-netType',    'encoding', 'Options: resnet | preresnet | encoding')
   cmd:option('-nCodes',     16,      'Options: 2 ~ inf')
   cmd:option('-depth',      20,      'ResNet depth: 18 | 34 | 50 | 101 | ...', 'number')
   cmd:option('-bottleneck', 'false', 'Options: true | false')
   -- Checkpointing
   cmd:option('-save',       'untitle', 'Directory in which to save')
   cmd:option('-resume',     'none',  'Resume from this directory')

   -- Training
   cmd:option('-nGPU',       1,       'Number of GPUs, 1 by default')
   cmd:option('-batchSize',  128,     'Batch size, 128 by default')
   cmd:option('-nEpochs',    0,       'Number of total epochs to run')
   cmd:option('-shareGradInput', 'false', 'Share gradInput to reduce memory')
   cmd:option('-manualSeed', 0,       'Manually set RNG seed')
   cmd:option('-LR',         0.1,     'Initial learning rate')
   cmd:option('-momentum',   0.9,     'Momentum')
   cmd:option('-weightDecay', 1e-4,   'Weight decay')
   -- Fine-tune
   cmd:option('-ft',         'false', 'Reinit the classifier for fine-tuning')
   cmd:option('-epochNumber', 1,      'Manual epoch number (useful on restarts)')
   cmd:option('-retrain',    'none',  'Path to the model to retrain with')
   cmd:option('-nClasses',   0,       'Number of classes for FT datasets')
   cmd:option('-lockEpoch',  0,       'Number of epochs to lock pre-trained features')

   -- Test
   cmd:option('-tenCrop',    'false', 'Ten-crop testing')
   cmd:option('-testOnly',   'false', 'Only testing')

   cmd:text()
   local opt = cmd:parse(arg or {})

   opt.shareGradInput = opt.shareGradInput ~= 'false'
   opt.bottleneck     = opt.bottleneck     ~= 'false'
   opt.ft             = opt.ft             ~= 'false'
   opt.tenCrop        = opt.tenCrop        ~= 'false'
   opt.testOnly       = opt.testOnly       ~= 'false'
   opt.multisize      = opt.multisize      ~= 'false'
   opt.threesize      = opt.threesize      ~= 'false'

   if not paths.dirp(opt.save) and not paths.mkdir(opt.save) then
      cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n')
   end

   if opt.dataset == 'cifar10' then
      -- Default shortcutType=A and nEpochs=164
      opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType
      opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs
   elseif opt.dataset == 'stl10' then
      -- Default shortcutType=A and nEpochs=164
      opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType
      opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs
   elseif opt.dataset == 'joint' then
      -- Default shortcutType=A and nEpochs=164
      opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType
      opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs
   elseif opt.dataset == 'minc' then
      -- add the customized dataset here
      -- Handle the most common case of missing -data flag
      local trainDir = paths.concat(opt.data, 'images')
      if not paths.dirp(opt.data) then
         cmd:error('error: missing MINC data directory')
      elseif not paths.dirp(trainDir) then
         cmd:error('error: MINC missing `images` directory: ' .. trainDir)
      end
      -- Default shortcutType=B
      opt.shortcutType = opt.shortcutType == '' and 'B' or opt.shortcutType
      if opt.ft then
         opt.nEpochs = opt.nEpochs == 0 and 60 or opt.nEpochs
      else
         opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs
      end

   else
      cmd:error('unknown dataset: ' .. opt.dataset)
   end

   return opt
end

return M

--------------------------------------------------------------------------------
/experiments/datasets/minc.lua:
--------------------------------------------------------------------------------
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-- Created by: Hang Zhang
-- ECE Department, Rutgers University
-- Email: zhang.hang@rutgers.edu
-- Copyright (c) 2016
--
-- Free to reuse and distribute this software for research or
-- non-profit purpose, subject to the following conditions:
--  1. The code must retain the above copyright notice, this list of
--     conditions.
--  2. Original authors' names are not deleted.
--  3. The authors' names are not used to endorse or promote products
--     derived from this software
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

local image = require 'image'
local paths = require 'paths'
local t = require 'datasets/transforms'
local ffi = require 'ffi'

local M = {}
local MINCDataset = torch.class('resnet.MINCDataset', M)

function MINCDataset:__init(imageInfo, opt, split)
   -- copied from fb.resnet.torch
   self.imageInfo = imageInfo[split]
   self.opt = opt
   self.split = split
   self.dir = paths.concat(opt.data, 'images')
   assert(paths.dirp(self.dir), 'directory does not exist: ' ..
self.dir) 31 | end 32 | 33 | function MINCDataset:get(i) 34 | -- copied from fb.resnet.torch 35 | local path = ffi.string(self.imageInfo.imagePath[i]:data()) 36 | 37 | local image = self:_loadImage(paths.concat(self.dir, path)) 38 | local class = self.imageInfo.imageClass[i] 39 | 40 | return { 41 | input = image, 42 | target = class, 43 | } 44 | end 45 | 46 | function MINCDataset:_loadImage(path) 47 | local ok, input = pcall(function() 48 | return image.load(path, 3, 'float') 49 | end) 50 | 51 | -- Sometimes image.load fails because the file extension does not match the 52 | -- image format. In that case, use image.decompress on a ByteTensor. 53 | if not ok then 54 | local f = io.open(path, 'r') 55 | assert(f, 'Error reading: ' .. tostring(path)) 56 | local data = f:read('*a') 57 | f:close() 58 | 59 | local b = torch.ByteTensor(string.len(data)) 60 | ffi.copy(b:data(), data, b:size(1)) 61 | 62 | input = image.decompress(b, 3, 'float') 63 | end 64 | 65 | return input 66 | end 67 | 68 | function MINCDataset:size() 69 | return self.imageInfo.imageClass:size(1) 70 | end 71 | 72 | -- Computed from random subset of ImageNet training images 73 | local meanstd = { 74 | mean = { 0.485, 0.456, 0.406 }, 75 | std = { 0.229, 0.224, 0.225 }, 76 | } 77 | local pca = { 78 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, 79 | eigvec = torch.Tensor{ 80 | { -0.5675, 0.7192, 0.4009 }, 81 | { -0.5808, -0.0045, -0.8140 }, 82 | { -0.5836, -0.6948, 0.4203 }, 83 | }, 84 | } 85 | 86 | function MINCDataset:preprocess(opt) 87 | -- copied from fb.resnet.torch 88 | if self.split == 'train' then 89 | if opt.multisize then 90 | return t.Compose{ 91 | t.Scale(400), 92 | --t.RandomSizedCrop(352), 93 | t.RandomTwoSizeCrop(352, 320), 94 | t.ColorJitter({ 95 | brightness = 0.4, 96 | contrast = 0.4, 97 | saturation = 0.4, 98 | }), 99 | t.Lighting(0.1, pca.eigval, pca.eigvec), 100 | t.ColorNormalize(meanstd), 101 | t.HorizontalFlip(0.5), 102 | } 103 | elseif opt.threesize then 104 | return t.Compose{ 105 | t.Scale(400), 106 | --t.RandomSizedCrop(352), 107 | t.RandomTwoSizeCrop(352, 320, 288), 108 | t.ColorJitter({ 109 | brightness = 0.4, 110 | contrast = 0.4, 111 | saturation = 0.4, 112 | }), 113 | t.Lighting(0.1, pca.eigval, pca.eigvec), 114 | t.ColorNormalize(meanstd), 115 | t.HorizontalFlip(0.5), 116 | } 117 | 118 | else 119 | return t.Compose{ 120 | t.Scale(400), 121 | t.RandomSizedCrop(352), 122 | t.ColorJitter({ 123 | brightness = 0.4, 124 | contrast = 0.4, 125 | saturation = 0.4, 126 | }), 127 | t.Lighting(0.1, pca.eigval, pca.eigvec), 128 | t.ColorNormalize(meanstd), 129 | t.HorizontalFlip(0.5), 130 | } 131 | end 132 | elseif self.split == 'val' then 133 | local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop 134 | return t.Compose{ 135 | t.Scale(400), 136 | t.ColorNormalize(meanstd), 137 | Crop(352), 138 | } 139 | else 140 | error('invalid split: ' .. self.split) 141 | end 142 | end 143 | 144 | return M.MINCDataset 145 | -------------------------------------------------------------------------------- /experiments/datasets/fmd-gen.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. 
The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | 19 | local M = {} 20 | 21 | local function findClasses(dir) 22 | -- copied from fb.resnet.torch 23 | local dirs = paths.dir(dir) 24 | table.sort(dirs) 25 | 26 | local classList = {} 27 | local classToIdx = {} 28 | for _ ,class in ipairs(dirs) do 29 | if not classToIdx[class] and class ~= '.' and class ~= '..' then 30 | table.insert(classList, class) 31 | classToIdx[class] = #classList 32 | end 33 | end 34 | 35 | return classList, classToIdx 36 | end 37 | 38 | local function findImages(dir, classToIdx) 39 | local imagePath = torch.CharTensor() 40 | local imageClass = torch.LongTensor() 41 | 42 | ---------------------------------------------------------------------- 43 | -- Options for the GNU and BSD find command 44 | local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 45 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 46 | for i=2,#extensionList do 47 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 48 | end 49 | 50 | -- Find all the images using the find command 51 | local f = io.popen('find -L ' .. dir .. findOptions) 52 | 53 | local maxLength = -1 54 | local imagePaths = {} 55 | local imageClasses = {} 56 | 57 | -- Generate a list of all the images and their class 58 | while true do 59 | local line = f:read('*line') 60 | if not line then break end 61 | 62 | local className = paths.basename(paths.dirname(line)) 63 | local filename = paths.basename(line) 64 | local path = dir .. '/' .. className .. '/' .. filename 65 | 66 | local classId = classToIdx[className] 67 | assert(classId, 'class not found: ' .. className) 68 | 69 | table.insert(imagePaths, path) 70 | table.insert(imageClasses, classId) 71 | 72 | maxLength = math.max(maxLength, #path + 1) 73 | end 74 | 75 | f:close() 76 | 77 | -- Convert the generated list to a tensor for faster loading 78 | local nImages = #imagePaths 79 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 80 | for i, path in ipairs(imagePaths) do 81 | ffi.copy(imagePath[i]:data(), path) 82 | end 83 | 84 | local imageClass = torch.LongTensor(imageClasses) 85 | return imagePath, imageClass 86 | end 87 | 88 | function M.exec(opt, cacheFile) 89 | -- copied from fb.resnet.torch 90 | -- find the image path names 91 | local imagePath = torch.CharTensor() -- path to each image in dataset 92 | local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 93 | 94 | local imgDir = paths.concat(opt.data, 'image') 95 | assert(paths.dirp(imgDir), 'image directory not found: ' .. imgDir) 96 | 97 | print("=> Generating list of images") 98 | local classList, classToIdx = findClasses(imgDir) 99 | 100 | idx = 1; 101 | print(" | finding all training images") 102 | local ImagePath, ImageClass = findImages(imgDir, classToIdx) 103 | 104 | -- Follow the standard in prior work (https://github.com/mcimpoi/deep-fbanks/blob/master/os_train.m#L132). 
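   -- (FMD has 10 categories x 100 images: viewed as a 10x100 grid, images
   -- 1-40 and 51-100 of each category form the training split (90 x 10 = 900)
   -- and images 41-50 form the validation split (10 x 10 = 100), which is
   -- what the slicing below implements.)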
   local trainImagePath = torch.cat(ImagePath:view(10,100,-1):narrow(2,1,40),
                                    ImagePath:view(10,100,-1):narrow(2,51,50), 2):view(900,-1)
   local valImagePath   = ImagePath:view(10,100,-1):narrow(2,41,10):reshape(100,83)
   local trainImageClass = torch.cat(ImageClass:view(10,100):narrow(2,1,40),
                                     ImageClass:view(10,100):narrow(2,51,50), 2):view(900)
   local valImageClass  = ImageClass:view(10,100):narrow(2,41,10):reshape(100)

   local info = {
      basedir = opt.data,
      classList = classList,
      train = {
         imagePath = trainImagePath,
         imageClass = trainImageClass,
      },
      val = {
         imagePath = valImagePath,
         imageClass = valImageClass,
      },
   }

   print(" | saving list of images to " .. cacheFile)
   torch.save(cacheFile, info)
   return info
end

return M

--------------------------------------------------------------------------------
/experiments/dataloader.lua:
--------------------------------------------------------------------------------
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-- modified from https://github.com/facebook/fb.resnet.torch
-- original copyrights preserved
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

local datasets = require 'datasets/init'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local M = {}
local DataLoader = torch.class('resnet.DataLoader', M)

function DataLoader.create(opt)
   -- The train and val loader
   local loaders = {}

   for i, split in ipairs{'train', 'val'} do
      local dataset = datasets.create(opt, split)
      loaders[i] = M.DataLoader(dataset, opt, split)
   end

   return table.unpack(loaders)
end

function DataLoader:__init(dataset, opt, split)
   local manualSeed = opt.manualSeed
   local function init()
      require('datasets/' .. opt.dataset)
   end
   local function main(idx)
      if manualSeed ~= 0 then
         torch.manualSeed(manualSeed + idx)
      end
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess(opt)
      return dataset:size()
   end

   local threads, sizes = Threads(opt.nThreads, init, main)
   self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
   self.threads = threads
   self.__size = sizes[1][1]
   self.batchSize = math.floor(opt.batchSize / self.nCrops)
end

function DataLoader:size()
   return math.ceil(self.__size / self.batchSize)
end

local function istable(x)
   return type(x) == 'table' and not torch.typename(x)
end

function DataLoader:run(epoch)
   local threads = self.threads
   local size, batchSize = self.__size, self.batchSize
   local perm = torch.randperm(size)

   local idx, sample = 1, nil

   local function enqueue()
      while idx <= size and threads:acceptsjob() do
         local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
         threads:addjob(
            function(indices, nCrops)
               local sz = indices:size(1)
               local batch, imageSize, imSize1, imSize2
               local target = torch.IntTensor(sz)
               local tableInput = false
               for i, idx in ipairs(indices:totable()) do
                  local sample = _G.dataset:get(idx)
                  local input = _G.preprocess(sample.input, epoch)
                  if tableInput or istable(sample.input) then
                     tableInput = true
                     if not batch then
                        imSize1 = input[1]:size():totable()
                        imSize2 = input[2]:size():totable()
                        batch = {torch.FloatTensor(sz, nCrops,
                                    table.unpack(imSize1)),
                                 torch.FloatTensor(sz, nCrops,
                                    table.unpack(imSize2))}
                        target = {torch.IntTensor(sz), torch.IntTensor(sz)}
                     end
                     batch[1][i]:copy(input[1])
                     batch[2][i]:copy(input[2])
                     target[1][i] = sample.target[1]
                     target[2][i] = sample.target[2]
                  else
                     if not batch then
                        imageSize = input:size():totable()
                        if nCrops > 1 then table.remove(imageSize, 1) end
                        batch = torch.FloatTensor(sz, nCrops, table.unpack(imageSize))
                     end
                     -- reinit the batch buffer for different image sizes
                     if batch:size(4) ~= input:size(2) then
                        imageSize = input:size():totable()
                        if nCrops > 1 then table.remove(imageSize, 1) end
                        batch:resize(sz, nCrops, table.unpack(imageSize))
                     end
                     batch[i]:copy(input)
                     target[i] = sample.target
                  end
               end
               collectgarbage()
               if tableInput then
                  return {
                     input = {
                        batch[1]:view(sz * nCrops, table.unpack(imSize1)),
                        batch[2]:view(sz * nCrops, table.unpack(imSize2)),
                     },
                     target = target,
                  }
               else
                  return {
                     input = batch:view(sz * nCrops, table.unpack(imageSize)),
                     target = target,
                  }
               end
            end,
            function(_sample_)
               sample = _sample_
            end,
            indices,
            self.nCrops
         )
         idx = idx + batchSize
      end
   end

   local n = 0
   local function loop()
      enqueue()
      if not threads:hasjob() then
         return nil
      end
      threads:dojob()
      if threads:haserror() then
         threads:synchronize()
      end
      enqueue()
      n = n + 1
      return n, sample
   end

   return loop
end

return M.DataLoader
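-- Usage sketch (mirrors the fb.resnet.torch-style training loop; `opt` comes
-- from opts.parse and the dataset must already be prepared):
--
--   local opts = require 'opts'
--   local DataLoader = require 'dataloader'
--   local opt = opts.parse(arg)
--   local trainLoader, valLoader = DataLoader.create(opt)
--   for n, sample in trainLoader:run(1) do
--      -- sample.input / sample.target hold one preprocessed mini-batch
--   end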
--------------------------------------------------------------------------------
/layers/netvlad.lua:
--------------------------------------------------------------------------------
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-- Created by: Hang Zhang
-- ECE Department, Rutgers University
-- Email: zhang.hang@rutgers.edu
-- Copyright (c) 2017
--
-- Free to reuse and distribute this software for research or
-- non-profit purpose, subject to the following conditions:
--  1. The code must retain the above copyright notice, this list of
--     conditions.
--  2. Original authors' names are not deleted.
--  3. The authors' names are not used to endorse or promote products
--     derived from this software
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

local NetVLAD, parent = torch.class('nn.NetVLAD', 'nn.Module')

local function isint(x)
   return type(x) == 'number' and x == math.floor(x)
end

function NetVLAD:__init(K, D)
   parent.__init(self)
   -- assertions
   assert(self and K and D, 'should specify K and D')
   assert(isint(K) and isint(D), 'K and D should be integers')
   self.K = K
   self.D = D
   -- the dictionary, assigning matrix and residuals
   self.weight = torch.Tensor(K, D)
   self.A = torch.Tensor()
   self.R = torch.Tensor()
   -- the assigning drops the link with the centers and
   -- is simplified as a 1x1 conv with the input
   self.assigner = nn.Sequential()
   self.assigner:add(nn.Linear(D, K, false))
   self.assigner:add(nn.SoftMax())
   -- the gradient parameters
   self.gradInput = torch.Tensor()
   self.gradWeight = torch.Tensor(K, D)
   self.gradA = torch.Tensor()
   -- init the dictionary
   self:reset()
end

function NetVLAD:reset(stdv)
   if stdv then
      stdv = stdv * math.sqrt(3)
   else
      stdv = 1./math.sqrt(self.weight:size(2))
   end
   self.weight:uniform(-stdv,stdv)
   self.assigner:reset()
   return self
end

function NetVLAD:updateOutput(input)
   assert(self)
   assert(input:dim()==2 or input:dim()==3, 'only 2D or 3D input supported')
   -- calculate the A and R
   -- X \in R^{[Bx]NxD}
   local K = self.K
   local D = self.D
   local N

   if input:dim() == 2 then
      N = input:size(1)
      -- assigning
      self.A = self.assigner:forward(input)
      -- calculate residuals
      self.R = input:view(N,1,D):expand(N,K,D)
               - self.weight:view(1,K,D):expand(N,K,D)
   elseif input:dim() == 3 then
      local B = input:size(1)
      N = input:size(2)
      -- assigning
      self.A = self.assigner:forward(input:view(B*N, D)):view(B, N, K)
      -- calculate residuals
      self.R = input:view(B,N,1,D):expand(B,N,K,D)
               - self.weight:view(1,1,K,D):expand(B,N,K,D)
   end

   if input:dim() == 2 then
      self.output:resize(K, D)
      HZENCODING.Aggregate.Forward(self.output:view(1,K,D),
                                   self.A:view(1,N,K),
                                   self.R:view(1,N,K,D))
   elseif input:dim() == 3 then
      local B = self.A:size(1)
      self.output:resize(B, K, D)
      HZENCODING.Aggregate.Forward(self.output, self.A, self.R)
   end
   return self.output
end

function NetVLAD:updateGradInput(input, gradOutput)
   assert(self)
   assert(self.gradInput)
   -- TODO assert the gradOutput size
   -- N may vary during the training
   local K = self.K
   local D = self.D
   self.gradA:resizeAs(self.A):fill(0)
   self.gradInput:resizeAs(input):fill(0)

   if self.A:dim() == 2 then
      local N = self.A:size(1)
      -- d_l/d_A \in R^{NxK}
      HZENCODING.Aggregate.BackwardA(self.gradA:view(1,N,K),
                   gradOutput:view(1,K,D), self.R:view(1,N,K,D))
      -- d_l/d_X = d_l/d_A * d_A/d_X + d_l/d_R * d_R/d_X
      self.gradInput = self.assigner:updateGradInput(input, self.gradA)
                       + self.A * gradOutput
   elseif self.A:dim() == 3 then
      local B = self.A:size(1)
      local N = input:size(2)
      -- d_l/d_A \in R^{BxNxK}
      HZENCODING.Aggregate.BackwardA(self.gradA, gradOutput, self.R)
      -- d_l/d_X = d_l/d_A * d_A/d_X + d_l/d_R * d_R/d_X
      self.gradInput = self.assigner:updateGradInput(input:view(B*N,self.D),
                                                     self.gradA:view(B*N,self.K))
                                    :view(B,N,self.D)
                       + torch.bmm(self.A, gradOutput)
   else
      error('input must be 2D or 3D')
   end

   return self.gradInput
end

function NetVLAD:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   -- accumulate the gradients of the assigner (reshaped to the 2D view the
   -- assigner was forwarded with)
   self.assigner:accGradParameters(input:view(-1, self.D),
                                   self.gradA:view(-1, self.K), scale)
   -- accumulate the gradients of the dictionary
   if self.A:dim() == 2 then
      for k = 1,self.K do
         -- d_l/d_c
         self.gradWeight[k] = self.gradWeight[k]
            - scale * self.A:select(2,k):sum() * gradOutput[k]
      end
   elseif self.A:dim() == 3 then
      local B = self.A:size(1)
      for b = 1,B do
         for k = 1,self.K do
            -- d_l/d_c
            self.gradWeight[k] = self.gradWeight[k]
               - scale * self.A[b]:select(2,k):sum() * gradOutput[b][k]
         end
      end
   end
end

function NetVLAD:__tostring__()
   return torch.type(self) ..
      string.format(
         '(Nx%d -> %dx%d)',
         self.D, self.K, self.D
      )
end

function NetVLAD:cuda()
   self.assigner:cuda()
   return parent.cuda(self)
end

function NetVLAD:training()
   self.assigner:training()
   return self
end

function NetVLAD:evaluate()
   self.assigner:evaluate()
   return self
end
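-- Usage sketch (illustrative sizes, not from the repo; the aggregation runs
-- through the HZENCODING CUDA backend, hence the :cuda() calls):
--
--   require 'encoding'
--   local vlad = nn.NetVLAD(16, 128):cuda()  -- K = 16 codewords, D = 128 dims
--   local X = torch.rand(49, 128):cuda()     -- N x D local descriptors
--   local V = vlad:forward(X)                -- K x D VLAD matrix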
--------------------------------------------------------------------------------
/experiments/models/init.lua:
--------------------------------------------------------------------------------
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-- modified from https://github.com/facebook/fb.resnet.torch
-- original copyrights preserved
--+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

require 'nn'
require 'cunn'
require 'cudnn'
require 'encoding'

local M = {}

function M.setup(opt, checkpoint)
   local model
   if checkpoint then
      local modelPath = paths.concat(opt.resume, checkpoint.modelFile)
      assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath)
      print('=> Resuming model from ' .. modelPath)
      model = torch.load(modelPath):cuda()
   elseif opt.retrain ~= 'none' then
      assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
      print('Loading model from file: ' .. opt.retrain)
      model = torch.load(opt.retrain):cuda()
   else
      print('=> Creating model from file: models/' .. opt.netType .. '.lua')
      model = require('models/' .. opt.netType)(opt)
   end

   -- First remove any DataParallelTable
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end

   -- optnet is a general library for reducing memory usage in neural networks
   if opt.optnet then
      local optnet = require 'optnet'
      local imsize = opt.dataset == 'imagenet' and 224 or 32
      local sampleInput = torch.zeros(4,3,imsize,imsize):cuda()
      optnet.optimizeMemory(model, sampleInput, {inplace = false, mode = 'training'})
   end

   -- This is useful for fitting ResNet-50 on 4 GPUs, but requires that all
   -- containers override backwards to call backwards recursively on submodules
   if opt.shareGradInput then
      M.shareGradInput(model)
   end

   -- For resetting the classifier when fine-tuning on a different dataset
   if opt.ft and not checkpoint then
      print(' => Replacing classifier with ' .. opt.nClasses .. '-way classifier')
      local orig = model:get(#model.modules)
      assert(torch.type(orig) == 'nn.Linear',
             'expected last layer to be fully connected')

      if opt.netType == 'encoding' then
         -- FC
         model:remove(#model.modules)
         -- View
         model:remove(#model.modules)
         -- Avg Pool
         model:remove(#model.modules)
         -- Assuming ResNet has 2048 channels
         local m1 = model:clone()
         local m2 = nn.Sequential()
         -- 1x1 conv to reduce channels
         local nInputPlane = 2048
         local nOutputPlane = 128
         m2:add(cudnn.SpatialConvolution(nInputPlane, nOutputPlane, 1, 1, 1, 1))
         m2:get(#m2.modules).bias:zero()
         -- BN and ReLU
         m2:add(nn.SpatialBatchNormalization(nOutputPlane))
         m2:add(cudnn.ReLU(true))
         -- BxCxWxH => BxCxN
         m2:add(nn.View(nOutputPlane,-1):setNumInputDims(3))
         -- BxCxN => BxNxC
         m2:add(nn.Transpose({2,3}))
         m2:add(nn.Encoding(opt.nCodes, nOutputPlane))
         --m2:add(nn.View(-1, nOutputPlane))--:setNumInputDims(2))
         --m2:add(nn.Normalize(2))
         m2:add(nn.View(-1):setNumInputDims(2))
         m2:add(nn.Normalize(2))
         m2:add(nn.Linear(nOutputPlane*opt.nCodes, opt.nClasses))
         m2:get(#m2.modules).bias:zero()

         model = nn.Sequential()
         model:add(m1)
         model:add(m2:cuda())

      else
         local linear = nn.Linear(orig.weight:size(2), opt.nClasses)
         linear.bias:zero()

         model:remove(#model.modules)
         local m1 = model:clone()
         local m2 = nn.Sequential()
         m2:add(linear:cuda())

         model = nn.Sequential()
         model:add(m1)
         model:add(m2)
      end
   end

   -- Set the cuDNN flags
   if opt.cudnn == 'fastest' then
      cudnn.fastest = true
      cudnn.benchmark = true
   elseif opt.cudnn == 'deterministic' then
      -- Use a deterministic convolution implementation
      model:apply(function(m)
         if m.setMode then m:setMode(1, 1, 1) end
      end)
   end

   -- Wrap the model with DataParallelTable, if using more than one GPU
   if opt.nGPU > 1 then
      local gpus = torch.range(1, opt.nGPU):totable()
      local fastest, benchmark = cudnn.fastest, cudnn.benchmark

      local dpt = nn.DataParallelTable(1, true, true)
         :add(model, gpus)
         :threads(function()
            require 'encoding'
            local cudnn = require 'cudnn'
            cudnn.fastest, cudnn.benchmark = fastest, benchmark
         end)
      dpt.gradInput = nil

      model = dpt:cuda()
   end
   print(model)
   local criterion
   if opt.dataset == 'joint' or opt.dataset == 'joint2' then
      criterion =
nn.ParallelCriterion():add(nn.CrossEntropyCriterion()):add(nn.CrossEntropyCriterion()):cuda() 135 | else 136 | criterion = nn.CrossEntropyCriterion():cuda() 137 | end 138 | return model, criterion 139 | end 140 | 141 | function M.shareGradInput(model) 142 | local function sharingKey(m) 143 | local key = torch.type(m) 144 | if m.__shareGradInputKey then 145 | key = key .. ':' .. m.__shareGradInputKey 146 | end 147 | return key 148 | end 149 | 150 | -- Share gradInput for memory efficient backprop 151 | local cache = {} 152 | model:apply(function(m) 153 | local moduleType = torch.type(m) 154 | if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' then 155 | local key = sharingKey(m) 156 | if cache[key] == nil then 157 | cache[key] = torch.CudaStorage(1) 158 | end 159 | m.gradInput = torch.CudaTensor(cache[key], 1, 0) 160 | end 161 | end) 162 | for i, m in ipairs(model:findModules('nn.ConcatTable')) do 163 | if cache[i % 2] == nil then 164 | cache[i % 2] = torch.CudaStorage(1) 165 | end 166 | m.gradInput = torch.CudaTensor(cache[i % 2], 1, 0) 167 | end 168 | end 169 | 170 | return M 171 | -------------------------------------------------------------------------------- /layers/encoding.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local Encoding, parent = torch.class('nn.Encoding', 'nn.Module') 17 | 18 | local function isint(x) 19 | return type(x) == 'number' and x == math.floor(x) 20 | end 21 | 22 | function Encoding:__init(K, D) 23 | parent.__init(self) 24 | -- assertions 25 | assert(self and K and D, 'should specify K and D') 26 | assert(isint(K) and isint(D), 'K and D should be integers') 27 | self.K = K 28 | self.D = D 29 | -- the dictionary 30 | self.weight = torch.Tensor(K, D) 31 | -- the assigning factors (smoothing) 32 | self.bias = torch.Tensor(K) 33 | -- the assigning matrix and residuals 34 | self.A = torch.Tensor() 35 | self.R = torch.Tensor() 36 | -- the soft assigning operation 37 | self.soft = nn.SoftMax() 38 | self.batchMul = nn.MM() 39 | -- the gradient parameters 40 | self.gradInput = torch.Tensor() 41 | self.gradWeight = torch.Tensor(K, D) 42 | self.gradBias = torch.Tensor(K):abs() 43 | self.gradA = torch.Tensor() 44 | -- init the dictionary 45 | self:reset() 46 | end 47 | 48 | function Encoding:reset(stdv) 49 | local stdv1, stdv2 50 | if stdv then 51 | stdv1 = stdv * math.sqrt(3) 52 | stdv2 = stdv * math.sqrt(3) 53 | else 54 | stdv1 = 1./math.sqrt(self.K * self.D) 55 | stdv2 = 1./math.sqrt(self.K) 56 | end 57 | self.weight:uniform(-stdv1,stdv1) 58 | self.bias:uniform(-stdv2,stdv2) 59 | return self 60 | end 61 | 62 | function Encoding:updateOutput(input) 63 | assert(self) 64 | assert(input:dim()==2 or input:dim()==3, 'only 2D or 3D input supported') 65 | -- lazy init for weighted L2 66 | self.L2 = self.L2 or self.A.new() 67 | self.SL2 = self.SL2 or self.A.new() 68 | -- calculate the A and R 69 | -- X \in R^{[Bx]NxD} 70 | local K = self.K 71 | local D = self.D 72 | local N 73 | if input:dim() == 2 then 74 | N = input:size(1) 75 | -- calculate residuals 76 | self.R = input:view(N,1,D):expand(N,K,D) 77 | - self.weight:view(1,K,D):expand(N,K,D) 78 | -- L2 norm of r_ik (assuming the N and K > 1) 79 | self.L2 = self.R:clone() 80 | self.L2 = self.L2:pow(2):sum(3):squeeze() 81 | -- weighted 82 | self.SL2 = - self.L2 * self.bias:diag() 83 | self.A = self.soft:forward(self.SL2) 84 | elseif input:dim() == 3 then 85 | local B = input:size(1) 86 | N = input:size(2) 87 | -- calculate residuals 88 | self.R:resize(B,N,K,D) 89 | self.A:resize(B,N,K) 90 | self.SL2:resize(B,N,K) 91 | 92 | self.R:copy( input:view(B,N,1,D):expand(B,N,K,D) 93 | - self.weight:view(1,1,K,D):expand(B,N,K,D) ) 94 | -- L2 norm of r_ik 95 | self.L2 = self.R:clone() 96 | self.L2 = self.L2:pow(2):sum(4):view(B,N,K) 97 | -- weighted 98 | self.SL2:copy( -torch.bmm( 99 | self.L2, self.bias:diag():view(1,K,K):expand(B,K,K)) ) 100 | self.A:copy( self.soft:forward( 101 | self.SL2:view(B*N,K) 102 | ):view(B,N,K) ) 103 | end 104 | 105 | if input:dim() == 2 then 106 | self.output:resize(K, D) 107 | HZENCODING.Aggregate.Forward(self.output:view(1,K,D), 108 | self.A:view(1,N,K), 109 | self.R:view(1,N,K,D)) 110 | elseif input:dim() == 3 then 111 | local B = self.A:size(1) 112 | self.output:resize(B, K, D) 113 | HZENCODING.Aggregate.Forward(self.output, self.A, self.R) 114 | end 115 | return self.output 116 | end 117 | 118 | function Encoding:updateGradInput(input, gradOutput) 119 | assert(self) 120 | assert(self.gradInput) 121 | assert(input:dim()==2 or input:dim()==3, 'only 2D or 3D input supported') 122 | -- N may vary during the training 123 | local K = self.K 124 | 
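-- Reference for the backward pass below (a hedged summary; notation follows
-- updateOutput above): r_ik = x_i - c_k are the residuals (self.R),
-- a_ik = softmax over k of (-s_k * ||r_ik||^2) are the soft assignments
-- (self.A, with the smoothing factors s = self.bias), and the layer output
-- is e_k = sum_i a_ik * r_ik. d_l/d_x_i therefore has two terms, one through
-- the assignments A (via the softmax over the scaled distances SL2) and one
-- through the residuals R; they correspond to the two summands in the
-- gradInput expressions below.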
local D = self.D 125 | self.gradA:resizeAs(self.A):fill(0) 126 | self.gradInput:resizeAs(input):fill(0) 127 | self.gradSL2 = self.gradSL2 or self.gradA.new() 128 | 129 | if input:dim() == 2 then 130 | -- d_l/d_A \in R^{NxK} 131 | local N = self.A:size(1) 132 | HZENCODING.Aggregate.BackwardA(self.gradA:view(1,N,K), 133 | gradOutput:view(1,K,D), self.R:view(1,N,K,D)) 134 | -- d_l/d_X = d_l/d_A * d_A/d_X + d_l/d_R * d_R/d_X 135 | self.gradSL2 = self.soft:updateGradInput(self.SL2, self.gradA) 136 | self.gradInput:copy( 2*torch.bmm( 137 | (self.gradSL2 * self.bias:diag()):view(N,1,K), 138 | self.R):squeeze() 139 | + self.A * gradOutput ) 140 | elseif input:dim() == 3 then 141 | local B = self.A:size(1) 142 | local N = input:size(2) 143 | -- d_l/d_A \in R^{BxNxK} 144 | HZENCODING.Aggregate.BackwardA(self.gradA, gradOutput, self.R) 145 | -- d_l/d_X = d_l/d_A * d_A/d_X + d_l/d_R * d_R/d_X 146 | self.gradSL2 = self.soft:updateGradInput(self.SL2:view(B*N,K), 147 | self.gradA:view(B*N,K)):view(B,N,K) 148 | self.gradInput:copy( 2*torch.bmm( 149 | (self.gradSL2:view(B*N,K)*self.bias:diag()):view(B*N,1,K), 150 | self.R:view(B*N,K,D)):view(B,N,D) 151 | + torch.bmm(self.A, gradOutput) ) 152 | end 153 | return self.gradInput 154 | end 155 | 156 | function Encoding:accGradParameters(input, gradOutput, scale) 157 | scale = scale or 1 158 | local K = self.K 159 | local D = self.D 160 | self.bufBias = self.bufBias or self.bias.new() 161 | self.bufWeight = self.bufWeight or self.weight.new() 162 | 163 | if input:dim() == 2 then 164 | local N = input:size(1) 165 | -- d_loss/d_C = d_loss/d_R * d_R/d_C + d_loss/d_A * d_A/d_C 166 | HZENCODING.Weighting.UpdateParams(self.gradBias:view(1,K), 167 | self.gradSL2:view(1,N,K), self.L2:view(1,N,K)) 168 | for k = 1,self.K do 169 | -- d_l/d_c 170 | self.gradWeight[k] = -scale*self.A:select(2,k):sum()*gradOutput[k] 171 | -2*scale*self.gradSL2:select(2,k):reshape(1,N) 172 | * self.bias[k] * self.R:select(2,k) 173 | end 174 | elseif input:dim() == 3 then 175 | local B = self.A:size(1) 176 | local N = input:size(2) 177 | -- batch gradient of s_k 178 | self.bufBias:resize(B, K) 179 | HZENCODING.Weighting.UpdateParams(self.bufBias, self.gradSL2, self.L2) 180 | -- accumulate the gradient for s_k over the batch 181 | self.gradBias:copy(scale * self.bufBias:sum(1):squeeze()) 182 | 183 | self.bufWeight:resize(B,K,D) 184 | HZENCODING.Weighting.BatchRowScale(self.bufWeight, self.A:sum(2):squeeze(), gradOutput) 185 | self.gradWeight:copy( -2*scale* torch.bmm( 186 | (self.gradSL2:view(B*N,K)*self.bias:diag()):view(B,N,K):transpose(2,3):reshape(B*K,1,N), 187 | self.R:transpose(2,3):reshape(B*K,N,D)):view(B,K,D):sum(1):squeeze() 188 | -scale * self.bufWeight:sum(1):squeeze()) 189 | end 190 | end 191 | 192 | function Encoding:__tostring__() 193 | return torch.type(self) .. 194 | string.format( 195 | '(Nx%d -> %dx%d)', 196 | self.D, self.K, self.D 197 | ) 198 | end 199 | 200 | -------------------------------------------------------------------------------- /experiments/models/encoding.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2016 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. 
The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | 17 | local nn = require 'nn' 18 | require 'cunn' 19 | require 'cudnn' 20 | require 'encoding' 21 | 22 | local Convolution = cudnn.SpatialConvolution 23 | local Avg = cudnn.SpatialAveragePooling 24 | local ReLU = cudnn.ReLU 25 | local Max = nn.SpatialMaxPooling 26 | local SBatchNorm = nn.SpatialBatchNormalization 27 | 28 | 29 | local function createModel(opt) 30 | local depth = opt.depth 31 | local shortcutType = opt.shortcutType or 'B' 32 | local iChannels 33 | 34 | -- The shortcut layer is either identity or 1x1 convolution 35 | local function shortcut(nInputPlane, nOutputPlane, stride) 36 | local useConv = shortcutType == 'C' or 37 | (shortcutType == 'B' and nInputPlane ~= nOutputPlane) 38 | if useConv then 39 | -- 1x1 convolution 40 | return nn.Sequential() 41 | :add(Convolution(nInputPlane, nOutputPlane, 1, 1, stride, stride)) 42 | elseif nInputPlane ~= nOutputPlane then 43 | -- Strided, zero-padded identity shortcut 44 | return nn.Sequential() 45 | :add(nn.SpatialAveragePooling(1, 1, stride, stride)) 46 | :add(nn.Concat(2) 47 | :add(nn.Identity()) 48 | :add(nn.MulConstant(0))) 49 | else 50 | return nn.Identity() 51 | end 52 | end 53 | 54 | local function ShareGradInput(module, key) 55 | assert(key) 56 | module.__shareGradInputKey = key 57 | return module 58 | end 59 | 60 | local function basicblock(n, stride, type) 61 | local nInputPlane = iChannels 62 | iChannels = n 63 | 64 | local block = nn.Sequential() 65 | local s = nn.Sequential() 66 | if type == 'both_preact' then 67 | block:add(ShareGradInput(SBatchNorm(nInputPlane), 'preact')) 68 | block:add(ReLU(true)) 69 | elseif type ~= 'no_preact' then 70 | s:add(SBatchNorm(nInputPlane)) 71 | s:add(ReLU(true)) 72 | end 73 | s:add(Convolution(nInputPlane,n,3,3,stride,stride,1,1)) 74 | s:add(SBatchNorm(n)) 75 | s:add(ReLU(true)) 76 | s:add(Convolution(n,n,3,3,1,1,1,1)) 77 | 78 | return block 79 | :add(nn.ConcatTable() 80 | :add(s) 81 | :add(shortcut(nInputPlane, n, stride))) 82 | :add(nn.CAddTable(true)) 83 | end 84 | 85 | local function bottleneck(n, stride, type) 86 | local nInputPlane = iChannels 87 | iChannels = n * 4 88 | 89 | local block = nn.Sequential() 90 | local s = nn.Sequential() 91 | if type == 'both_preact' then 92 | block:add(ShareGradInput(SBatchNorm(nInputPlane), 'preact')) 93 | block:add(ReLU(true)) 94 | elseif type ~= 'no_preact' then 95 | s:add(SBatchNorm(nInputPlane)) 96 | s:add(ReLU(true)) 97 | end 98 | s:add(Convolution(nInputPlane,n,1,1,1,1,0,0)) 99 | s:add(SBatchNorm(n)) 100 | s:add(ReLU(true)) 101 | s:add(Convolution(n,n,3,3,stride,stride,1,1)) 102 | s:add(SBatchNorm(n)) 103 | s:add(ReLU(true)) 104 | s:add(Convolution(n,n*4,1,1,1,1,0,0)) 105 | 106 | return block 107 | :add(nn.ConcatTable() 108 | :add(s) 109 | :add(shortcut(nInputPlane, n * 4, stride))) 110 | :add(nn.CAddTable(true)) 111 | end 112 | -- Creates count residual blocks with specified number of features 113 | local function layer(block, features, count, stride, type) 114 | local s = nn.Sequential() 115 | if count < 1 then 116 | return s 117 | end 118 | s:add(block(features, stride, 119 | type == 'first' and 'no_preact' or 'both_preact')) 120 | for i=2,count do 121 | s:add(block(features, 1)) 122 | end 123 | return s 124 | end 
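-- Depth bookkeeping for the CIFAR-style models below: with basic blocks,
-- each of the 3 stages stacks n blocks of 2 convolutions each, plus the stem
-- convolution and the final linear layer, so depth = 6n + 2 (e.g. depth 20
-- gives n = 3, depth 56 gives n = 9); hence the (depth - 2) % 6 == 0
-- assertions that follow.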
125 | 126 | local model = nn.Sequential() 127 | if opt.dataset == 'cifar10' or opt.dataset == 'stl10' then 128 | print('opt.bottleneck', opt.bottleneck) 129 | if opt.bottleneck then 130 | -- NOTE: the bottleneck variant is not implemented for these datasets 131 | else 132 | -- Model type specifies number of layers for CIFAR-10 model 133 | assert((depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, ..., 1202') 134 | local n = (depth - 2) / 6 135 | iChannels = 16 136 | print(' | Encoding-' .. depth .. ' CIFAR-10') 137 | 138 | -- The ResNet CIFAR-10 model 139 | model:add(Convolution(3,16,3,3,1,1,1,1)) 140 | model:add(SBatchNorm(16)) 141 | model:add(ReLU(true)) 142 | model:add(layer(basicblock, 16, n)) 143 | model:add(layer(basicblock, 32, n, 2)) 144 | model:add(layer(basicblock, 64, n, 2)) 145 | model:add(nn.View(64, -1):setNumInputDims(3)) 146 | model:add(nn.Transpose({2,3})) 147 | model:add(nn.Encoding(opt.nCodes, 64)) 148 | model:add(nn.View(-1):setNumInputDims(2)) 149 | model:add(nn.Normalize(2)) 150 | model:add(nn.Linear(64*opt.nCodes, 10)) 151 | print(model) 152 | end 153 | elseif opt.dataset == 'joint' then 154 | assert((depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, ..., 1202') 155 | local n = (depth - 2) / 6 156 | iChannels = 16 157 | -- joint encoding for cifar10 and stl10 158 | local m1 = nn.Sequential() 159 | m1:add(Convolution(3,16,3,3,1,1,1,1)) 160 | m1:add(SBatchNorm(16)) 161 | m1:add(ReLU(true)) 162 | m1:add(layer(basicblock, 16, n)) 163 | m1:add(layer(basicblock, 32, n, 2)) 164 | m1:add(layer(basicblock, 64, n, 2)) 165 | m1:add(nn.View(64,-1):setNumInputDims(3)) 166 | m1:add(nn.Transpose({2,3})) 167 | -- sharing weights for joint training 168 | local m2 = m1:clone('weight','bias','gradWeight','gradBias') 169 | 170 | local model1=nn.Sequential() 171 | model1:add(m1) 172 | model1:add(nn.Encoding(opt.nCodes, 64)) 173 | model1:add(nn.View(-1):setNumInputDims(2)) 174 | model1:add(nn.Normalize(2)) 175 | model1:add(nn.Linear(64*opt.nCodes, 10)) 176 | 177 | local model2=nn.Sequential() 178 | model2:add(m2) 179 | model2:add(nn.Encoding(opt.nCodes, 64)) 180 | model2:add(nn.View(-1):setNumInputDims(2)) 181 | model2:add(nn.Normalize(2)) 182 | model2:add(nn.Linear(64*opt.nCodes, 10)) 183 | 184 | model = nn.ParallelTable() 185 | :add(model1) 186 | :add(model2) 187 | 188 | print(model) 189 | else 190 | error('invalid dataset: ' .. 
opt.dataset) 191 | end 192 | 193 | local function ConvInit(name) 194 | for k,v in pairs(model:findModules(name)) do 195 | local n = v.kW*v.kH*v.nOutputPlane 196 | v.weight:normal(0,math.sqrt(2/n)) 197 | if cudnn.version >= 4000 then 198 | v.bias = nil 199 | v.gradBias = nil 200 | else 201 | v.bias:zero() 202 | end 203 | end 204 | end 205 | local function BNInit(name) 206 | for k,v in pairs(model:findModules(name)) do 207 | v.weight:fill(1) 208 | v.bias:zero() 209 | end 210 | end 211 | 212 | ConvInit('cudnn.SpatialConvolution') 213 | ConvInit('nn.SpatialConvolution') 214 | BNInit('fbnn.SpatialBatchNormalization') 215 | BNInit('cudnn.SpatialBatchNormalization') 216 | BNInit('nn.SpatialBatchNormalization') 217 | for k,v in pairs(model:findModules('nn.Linear')) do 218 | v.bias:zero() 219 | end 220 | model:cuda() 221 | 222 | if opt.cudnn == 'deterministic' then 223 | model:apply(function(m) 224 | if m.setMode then m:setMode(1,1,1) end 225 | end) 226 | end 227 | 228 | model:get(1).gradInput = nil 229 | 230 | return model 231 | end 232 | 233 | return createModel 234 | -------------------------------------------------------------------------------- /cmake/select_compute_arch.cmake: -------------------------------------------------------------------------------- 1 | # Synopsis: 2 | # CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures]) 3 | # -- Selects GPU arch flags for nvcc based on target_CUDA_architectures 4 | # target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...) 5 | # - "Auto" detects local machine GPU compute arch at runtime. 6 | # - "Common" and "All" cover common and entire subsets of architectures 7 | # ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX 8 | # NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal 9 | # NUM: Any number. 
Only those pairs are currently accepted by NVCC though: 10 | # 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2 11 | # Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable} 12 | # Additionally, sets ${out_variable}_readable to the resulting numeric list 13 | # Example: 14 | # CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell) 15 | # LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS}) 16 | # 17 | # More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA 18 | # 19 | 20 | # This list will be used for CUDA_ARCH_NAME = All option 21 | set(CUDA_KNOWN_GPU_ARCHITECTURES "Fermi" "Kepler" "Maxwell") 22 | 23 | # This list will be used for CUDA_ARCH_NAME = Common option (enabled by default) 24 | set(CUDA_COMMON_GPU_ARCHITECTURES "3.0" "3.5" "5.0") 25 | 26 | if (CUDA_VERSION VERSION_GREATER "6.5") 27 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra") 28 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2") 29 | endif () 30 | 31 | if (CUDA_VERSION VERSION_GREATER "7.5") 32 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal") 33 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1" "6.1+PTX") 34 | else() 35 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX") 36 | endif () 37 | 38 | 39 | 40 | ################################################################################################ 41 | # A function for automatic detection of GPUs installed (if autodetection is enabled) 42 | # Usage: 43 | # CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE) 44 | # 45 | function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE) 46 | if(NOT CUDA_GPU_DETECT_OUTPUT) 47 | set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu) 48 | 49 | file(WRITE ${cufile} "" 50 | "#include <cstdio>\n" 51 | "int main()\n" 52 | "{\n" 53 | " int count = 0;\n" 54 | " if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n" 55 | " if (count == 0) return -1;\n" 56 | " for (int device = 0; device < count; ++device)\n" 57 | " {\n" 58 | " cudaDeviceProp prop;\n" 59 | " if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n" 60 | " std::printf(\"%d.%d \", prop.major, prop.minor);\n" 61 | " }\n" 62 | " return 0;\n" 63 | "}\n") 64 | 65 | execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${cufile}" 66 | "-ccbin" ${CMAKE_CXX_COMPILER} 67 | WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" 68 | RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out 69 | ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) 70 | 71 | if(nvcc_res EQUAL 0) 72 | string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}") 73 | set(CUDA_GPU_DETECT_OUTPUT ${nvcc_out} CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE) 74 | endif() 75 | endif() 76 | 77 | if(NOT CUDA_GPU_DETECT_OUTPUT) 78 | message(STATUS "Automatic GPU detection failed. 
Building for common architectures.") 79 | set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE) 80 | else() 81 | set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT} PARENT_SCOPE) 82 | endif() 83 | endfunction() 84 | 85 | 86 | ################################################################################################ 87 | # Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list 88 | # Usage: 89 | # SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs]) 90 | function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) 91 | set(CUDA_ARCH_LIST "${ARGN}") 92 | 93 | if("X${CUDA_ARCH_LIST}" STREQUAL "X" ) 94 | set(CUDA_ARCH_LIST "Auto") 95 | endif() 96 | 97 | set(cuda_arch_bin) 98 | set(cuda_arch_ptx) 99 | 100 | if("${CUDA_ARCH_LIST}" STREQUAL "All") 101 | set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES}) 102 | elseif("${CUDA_ARCH_LIST}" STREQUAL "Common") 103 | set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES}) 104 | elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto") 105 | CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST) 106 | message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}") 107 | endif() 108 | 109 | # Now process the list and look for names 110 | string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") 111 | list(REMOVE_DUPLICATES CUDA_ARCH_LIST) 112 | foreach(arch_name ${CUDA_ARCH_LIST}) 113 | set(arch_bin) 114 | set(add_ptx FALSE) 115 | # Check to see if we are compiling PTX 116 | if(arch_name MATCHES "(.*)\\+PTX$") 117 | set(add_ptx TRUE) 118 | set(arch_name ${CMAKE_MATCH_1}) 119 | endif() 120 | if(arch_name MATCHES "(^[0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$") 121 | set(arch_bin ${CMAKE_MATCH_1}) 122 | set(arch_ptx ${arch_bin}) 123 | else() 124 | # Look for it in our list of known architectures 125 | if(${arch_name} STREQUAL "Fermi") 126 | set(arch_bin "2.0 2.1(2.0)") 127 | elseif(${arch_name} STREQUAL "Kepler+Tegra") 128 | set(arch_bin 3.2) 129 | elseif(${arch_name} STREQUAL "Kepler+Tesla") 130 | set(arch_bin 3.7) 131 | elseif(${arch_name} STREQUAL "Kepler") 132 | set(arch_bin 3.0 3.5) 133 | set(arch_ptx 3.5) 134 | elseif(${arch_name} STREQUAL "Maxwell+Tegra") 135 | set(arch_bin 5.3) 136 | elseif(${arch_name} STREQUAL "Maxwell") 137 | set(arch_bin 5.0 5.2) 138 | set(arch_ptx 5.2) 139 | elseif(${arch_name} STREQUAL "Pascal") 140 | set(arch_bin 6.0 6.1) 141 | set(arch_ptx 6.1) 142 | else() 143 | message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") 144 | endif() 145 | endif() 146 | if(NOT arch_bin) 147 | message(SEND_ERROR "arch_bin wasn't set for some reason") 148 | endif() 149 | list(APPEND cuda_arch_bin ${arch_bin}) 150 | if(add_ptx) 151 | if (NOT arch_ptx) 152 | set(arch_ptx ${arch_bin}) 153 | endif() 154 | list(APPEND cuda_arch_ptx ${arch_ptx}) 155 | endif() 156 | endforeach() 157 | 158 | # remove dots and convert to lists 159 | string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}") 160 | string(REGEX REPLACE "\\." 
"" cuda_arch_ptx "${cuda_arch_ptx}") 161 | string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}") 162 | string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}") 163 | 164 | if(cuda_arch_bin) 165 | list(REMOVE_DUPLICATES cuda_arch_bin) 166 | endif() 167 | if(cuda_arch_ptx) 168 | list(REMOVE_DUPLICATES cuda_arch_ptx) 169 | endif() 170 | 171 | set(nvcc_flags "") 172 | set(nvcc_archs_readable "") 173 | 174 | # Tell NVCC to add binaries for the specified GPUs 175 | foreach(arch ${cuda_arch_bin}) 176 | if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)") 177 | # User explicitly specified ARCH for the concrete CODE 178 | list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1}) 179 | list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1}) 180 | else() 181 | # User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE 182 | list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch}) 183 | list(APPEND nvcc_archs_readable sm_${arch}) 184 | endif() 185 | endforeach() 186 | 187 | # Tell NVCC to add PTX intermediate code for the specified architectures 188 | foreach(arch ${cuda_arch_ptx}) 189 | list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch}) 190 | list(APPEND nvcc_archs_readable compute_${arch}) 191 | endforeach() 192 | 193 | string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}") 194 | set(${out_variable} ${nvcc_flags} PARENT_SCOPE) 195 | set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE) 196 | endfunction() 197 | -------------------------------------------------------------------------------- /experiments/train.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- modified from https://github.com/facebook/fb.resnet.torch 3 | -- original copyrights preserved 4 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 | 6 | local optim = require 'optim' 7 | require 'encoding' 8 | 9 | local M = {} 10 | local Trainer = torch.class('resnet.Trainer', M) 11 | 12 | local function istable(x) 13 | return type(x) == 'table' and not torch.typename(x) 14 | end 15 | 16 | function Trainer:__init(model, criterion, opt, optimState) 17 | self.model = model 18 | self.criterion = criterion 19 | self.optimState = optimState or { 20 | learningRate = opt.LR, 21 | learningRateDecay = 0.0, 22 | momentum = opt.momentum, 23 | nesterov = true, 24 | dampening = 0.0, 25 | weightDecay = opt.weightDecay, 26 | } 27 | self.opt = opt 28 | if opt.ft and opt.lockEpoch > 0 then 29 | print('Locking the Features for Fine-tuning') 30 | -- only works for fine-tuning with encoding 31 | self.lockEpoch = opt.lockEpoch 32 | print(model:get(1):get(2)) 33 | self.params, self.gradParams = model:get(1):get(2):getParameters() 34 | self.allparams, self.allgradParams = model:getParameters() 35 | else 36 | self.lockEpoch = -1 37 | self.params, self.gradParams = model:getParameters() 38 | end 39 | end 40 | 41 | function Trainer:train(epoch, dataloader) 42 | -- Trains the model for a single epoch 43 | self.optimState.learningRate = self:learningRate(epoch) 44 | 45 | local timer = torch.Timer() 46 | local dataTimer = torch.Timer() 47 | 48 | -- release the lock: switch back to optimizing all parameters 49 | if epoch == self.lockEpoch+1 then 50 | print('Unlocking the Features for Fine-tuning') 51 | self.params, self.gradParams = self.allparams, self.allgradParams 52 | end 53 | -- forward/backward are run manually below, so feval just returns the cached loss and flattened gradients for optim.sgd 54 | local function feval() 55 | return self.criterion.output, self.gradParams 56 
| end 57 | 58 | local trainSize = dataloader:size() 59 | local top1Sum, top5Sum, lossSum = 0.0, 0.0, 0.0 60 | local N = 0 61 | 62 | print('=> Training epoch # ' .. epoch) 63 | -- set the batch norm to training mode 64 | self.model:training() 65 | for n, sample in dataloader:run(epoch) do 66 | local dataTime = dataTimer:time().real 67 | 68 | -- Copy input and target to the GPU 69 | self:copyInputs(sample) 70 | local output = self.model:forward(self.input) 71 | local batchSize 72 | if istable(output) then 73 | output = { 74 | output[1]:float(), 75 | output[2]:float(), 76 | } 77 | batchSize = output[1]:size(1) 78 | else 79 | output = output:float() 80 | batchSize = output:size(1) 81 | end 82 | 83 | local loss = self.criterion:forward(self.model.output, self.target) 84 | self.model:zeroGradParameters() 85 | self.criterion:backward(self.model.output, self.target) 86 | self.model:backward(self.input, self.criterion.gradInput) 87 | optim.sgd(feval, self.params, self.optimState) 88 | 89 | local top1, top5 = self:computeScore(output, sample.target, 1) 90 | if istable(top1) then 91 | -- lazily switch the accumulators to per-dataset tables 92 | if not istable(top1Sum) then 93 | top1Sum = {0.0, 0.0} 94 | top5Sum = {0.0, 0.0} 95 | end 96 | 97 | top1Sum[1] = top1Sum[1] + top1[1]*batchSize 98 | top5Sum[1] = top5Sum[1] + top5[1]*batchSize 99 | top1Sum[2] = top1Sum[2] + top1[2]*batchSize 100 | top5Sum[2] = top5Sum[2] + top5[2]*batchSize 101 | lossSum = lossSum + loss*batchSize 102 | print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.4f set1-top1 %7.3f set2-top1 %7.3f'):format( 103 | epoch, n, trainSize, timer:time().real, dataTime, loss, top1[1], top1[2])) 104 | else 105 | top1Sum = top1Sum + top1*batchSize 106 | top5Sum = top5Sum + top5*batchSize 107 | lossSum = lossSum + loss*batchSize 108 | print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.4f top1 %7.3f top5 %7.3f'):format( 109 | epoch, n, trainSize, timer:time().real, dataTime, loss, top1, top5)) 110 | end 111 | N = N + batchSize 112 | 113 | 114 | -- check that the storage didn't get changed due to an unfortunate getParameters call 115 | -- assert(self.params:storage() == self.model:parameters()[1]:storage()) 116 | 117 | timer:reset() 118 | dataTimer:reset() 119 | collectgarbage() 120 | end 121 | 122 | if istable(top1Sum) then 123 | top1Sum[1] = top1Sum[1] / N 124 | top5Sum[1] = top5Sum[1] / N 125 | top1Sum[2] = top1Sum[2] / N 126 | top5Sum[2] = top5Sum[2] / N 127 | return top1Sum , top5Sum , lossSum / N 128 | else 129 | return top1Sum / N, top5Sum / N, lossSum / N 130 | end 131 | end 132 | 133 | function Trainer:test(epoch, dataloader) 134 | -- Computes the top-1 and top-5 err on the validation set 135 | 136 | local timer = torch.Timer() 137 | local dataTimer = torch.Timer() 138 | local size = dataloader:size() 139 | 140 | local nCrops = self.opt.tenCrop and 10 or 1 141 | local top1Sum, top5Sum = 0.0, 0.0 142 | local N = 0 143 | 144 | self.model:evaluate() 145 | for n, sample in dataloader:run() do 146 | local dataTime = dataTimer:time().real 147 | 148 | -- Copy input and target to the GPU 149 | self:copyInputs(sample) 150 | 151 | local output = self.model:forward(self.input) 152 | local batchSize 153 | if istable(output) then 154 | output = { 155 | output[1]:float(), 156 | output[2]:float(), 157 | } 158 | batchSize = output[1]:size(1) / nCrops 159 | else 160 | output = output:float() 161 | batchSize = output:size(1) / nCrops 162 | end 163 | 164 | local loss = self.criterion:forward(self.model.output, self.target) 165 | 166 | local top1, top5 = self:computeScore(output, sample.target, nCrops) 
167 | N = N + batchSize 168 | if istable(top1) then 169 | -- lazily switch the accumulators to per-dataset tables 170 | if not istable(top1Sum) then 171 | top1Sum = {0.0, 0.0} 172 | top5Sum = {0.0, 0.0} 173 | end 174 | top1Sum[1] = top1Sum[1] + top1[1]*batchSize 175 | top5Sum[1] = top5Sum[1] + top5[1]*batchSize 176 | top1Sum[2] = top1Sum[2] + top1[2]*batchSize 177 | top5Sum[2] = top5Sum[2] + top5[2]*batchSize 178 | print((' | Test: [%d][%d/%d] Time %.3f Data %.3f set1-top1 %7.3f (%7.3f) set2-top1 %7.3f (%7.3f)'):format( 179 | epoch, n, size, timer:time().real, dataTime, top1[1], top1Sum[1] / N, top1[2], top1Sum[2] / N)) 180 | else 181 | top1Sum = top1Sum + top1*batchSize 182 | top5Sum = top5Sum + top5*batchSize 183 | print((' | Test: [%d][%d/%d] Time %.3f Data %.3f top1 %7.3f (%7.3f) top5 %7.3f (%7.3f)'):format( 184 | epoch, n, size, timer:time().real, dataTime, top1, top1Sum / N, top5, top5Sum / N)) 185 | end 186 | 187 | 188 | 189 | timer:reset() 190 | dataTimer:reset() 191 | collectgarbage() 192 | end 193 | self.model:training() 194 | 195 | if istable(top1Sum) then 196 | print((' * Finished epoch # %d set1-top1: %7.3f set2-top1: %7.3f\n'):format( 197 | epoch, top1Sum[1] / N, top1Sum[2] / N)) 198 | top1Sum[1] = top1Sum[1] / N 199 | top5Sum[1] = top5Sum[1] / N 200 | top1Sum[2] = top1Sum[2] / N 201 | top5Sum[2] = top5Sum[2] / N 202 | return top1Sum , top5Sum 203 | else 204 | print((' * Finished epoch # %d top1: %7.3f top5: %7.3f\n'):format( 205 | epoch, top1Sum / N, top5Sum / N)) 206 | return top1Sum / N, top5Sum / N 207 | end 208 | end 209 | 210 | function Trainer:computeScore(output, target, nCrops) 211 | if nCrops > 1 then 212 | -- Sum over crops 213 | output = output:view(output:size(1) / nCrops, nCrops, output:size(2)) 214 | --:exp() 215 | :sum(2):squeeze(2) 216 | end 217 | 218 | -- Computes the top1 and top5 error rate 219 | local batchSize 220 | local predictions, correct 221 | if istable(output) then 222 | batchSize = output[1]:size(1) 223 | predictions = {} 224 | correct = {} 225 | _ , predictions[1] = output[1]:float():sort(2, true) -- descending 226 | _ , predictions[2] = output[2]:float():sort(2, true) -- descending 227 | 228 | -- Find which predictions match the target 229 | correct[1] = predictions[1]:eq( 230 | target[1]:long():view(batchSize, 1):expandAs(output[1])) 231 | correct[2] = predictions[2]:eq( 232 | target[2]:long():view(batchSize, 1):expandAs(output[2])) 233 | 234 | -- Top-1 score 235 | local top1 = {1.0 - (correct[1]:narrow(2, 1, 1):sum() / batchSize), 236 | 1.0 - (correct[2]:narrow(2, 1, 1):sum() / batchSize)} 237 | -- Top-5 score, if there are at least 5 classes 238 | local len1 = math.min(5, correct[1]:size(2)) 239 | local len2 = math.min(5, correct[2]:size(2)) 240 | local top5 = {1.0 - (correct[1]:narrow(2, 1, len1):sum() / batchSize), 241 | 1.0 - (correct[2]:narrow(2, 1, len2):sum() / batchSize)} 242 | 243 | 244 | return {top1[1] * 100, top1[2] * 100}, {top5[1] * 100, top5[2] * 100} 245 | else 246 | batchSize = output:size(1) 247 | _ , predictions = output:float():sort(2, true) -- descending 248 | 249 | -- Find which predictions match the target 250 | correct = predictions:eq( 251 | target:long():view(batchSize, 1):expandAs(output)) 252 | 253 | -- Top-1 score 254 | local top1 = 1.0 - (correct:narrow(2, 1, 1):sum() / batchSize) 255 | 256 | -- Top-5 score, if there are at least 5 classes 257 | local len = math.min(5, correct:size(2)) 258 | local top5 = 1.0 - (correct:narrow(2, 1, len):sum() / batchSize) 259 | 260 | return top1 * 100, top5 * 100 261 | end 262 | 263 | 264 | end 265 | 266 | function 
Trainer:copyInputs(sample) 267 | -- Copies the input to a CUDA tensor, if using 1 GPU, or to pinned memory, 268 | -- if using DataParallelTable. The target is always copied to a CUDA tensor. 269 | if istable(sample.input) then 270 | self.input = self.input or (self.opt.nGPU == 1 271 | and {torch.CudaTensor(), torch.CudaTensor()} or 272 | {cutorch.createCudaHostTensor(), cutorch.createCudaHostTensor()}) 273 | self.target = self.target or {torch.CudaTensor(), torch.CudaTensor()} 274 | self.input[1]:resize(sample.input[1]:size()):copy(sample.input[1]) 275 | self.input[2]:resize(sample.input[2]:size()):copy(sample.input[2]) 276 | self.target[1]:resize(sample.target[1]:size()):copy(sample.target[1]) 277 | self.target[2]:resize(sample.target[2]:size()):copy(sample.target[2]) 278 | else 279 | self.input = self.input or (self.opt.nGPU == 1 280 | and torch.CudaTensor() 281 | or cutorch.createCudaHostTensor()) 282 | self.target = self.target or torch.CudaTensor() 283 | self.input:resize(sample.input:size()):copy(sample.input) 284 | self.target:resize(sample.target:size()):copy(sample.target) 285 | end 286 | end 287 | 288 | function Trainer:learningRate(epoch) 289 | -- Training schedule 290 | local decay = 0 291 | if self.opt.dataset == 'imagenet' then 292 | decay = math.floor((epoch - 1) / 30) 293 | elseif self.opt.dataset == 'cifar10' then 294 | decay = epoch >= 122 and 2 or epoch >= 81 and 1 or 0 295 | elseif self.opt.dataset == 'stl10' then 296 | decay = epoch >= 122 and 2 or epoch >= 81 and 1 or 0 297 | elseif self.opt.dataset == 'joint' then 298 | decay = epoch >= 122 and 2 or epoch >= 81 and 1 or 0 299 | elseif self.opt.ft then 300 | decay = epoch > 40 and 2 or 1 301 | else 302 | decay = epoch >= 40 and 2 or 1 303 | end 304 | local learningRate = self.opt.LR * math.pow(0.1, decay) 305 | print('Learning Rate is ', learningRate) 306 | return learningRate 307 | end 308 | 309 | return M.Trainer 310 | -------------------------------------------------------------------------------- /experiments/datasets/transforms.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- modified from https://github.com/facebook/fb.resnet.torch 3 | -- original copyrights preserved 4 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 | 6 | require 'image' 7 | 8 | local M = {} 9 | 10 | function M.Compose(transforms) 11 | return function(input, param) 12 | for _, transform in ipairs(transforms) do 13 | input = transform(input, param) 14 | end 15 | return input 16 | end 17 | end 18 | 19 | function M.ColorNormalize(meanstd) 20 | return function(img) 21 | img = img:clone() 22 | for i=1,3 do 23 | img[i]:add(-meanstd.mean[i]) 24 | img[i]:div(meanstd.std[i]) 25 | end 26 | return img 27 | end 28 | end 29 | 30 | -- Scales the smaller edge to size 31 | function M.Scale(size, interpolation) 32 | interpolation = interpolation or 'bicubic' 33 | return function(input) 34 | local w, h = input:size(3), input:size(2) 35 | if (w <= h and w == size) or (h <= w and h == size) then 36 | return input 37 | end 38 | if w < h then 39 | return image.scale(input, size, h/w * size, interpolation) 40 | else 41 | return image.scale(input, w/h * size, size, interpolation) 42 | end 43 | end 44 | end 45 | 46 | -- Crop to centered rectangle 47 | function M.CenterCrop(size) 48 | return function(input) 49 | local w1 = math.ceil((input:size(3) - size)/2) 50 | local h1 = math.ceil((input:size(2) - size)/2) 51 | 
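-- e.g. for a 32x32 input with size = 28: w1 = h1 = ceil((32 - 28)/2) = 2,
-- so the crop below covers [2, 30) in both dimensions.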
return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch 52 | end 53 | end 54 | 55 | -- Random crop from larger image with optional zero padding 56 | function M.RandomCrop(size, padding) 57 | padding = padding or 0 58 | 59 | return function(input) 60 | if padding > 0 then 61 | local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding) 62 | temp:zero() 63 | :narrow(2, padding+1, input:size(2)) 64 | :narrow(3, padding+1, input:size(3)) 65 | :copy(input) 66 | input = temp 67 | end 68 | 69 | local w, h = input:size(3), input:size(2) 70 | if w == size and h == size then 71 | return input 72 | end 73 | 74 | local x1, y1 = torch.random(0, w - size), torch.random(0, h - size) 75 | local out = image.crop(input, x1, y1, x1 + size, y1 + size) 76 | assert(out:size(2) == size and out:size(3) == size, 'wrong crop size') 77 | return out 78 | end 79 | end 80 | 81 | -- for CIFAR or STL, proof of concept 82 | function M.RandomTwoCrop(size1, size2, padding) 83 | padding = padding or 0 84 | 85 | return function(input, epoch) 86 | local size = (epoch % 2 == 0) and size1 or size2 87 | if padding > 0 then 88 | local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding) 89 | temp:zero() 90 | :narrow(2, padding+1, input:size(2)) 91 | :narrow(3, padding+1, input:size(3)) 92 | :copy(input) 93 | input = temp 94 | end 95 | 96 | local w, h = input:size(3), input:size(2) 97 | if w == size and h == size then 98 | return input 99 | end 100 | 101 | local x1, y1 = torch.random(0, w - size), torch.random(0, h - size) 102 | local out = image.crop(input, x1, y1, x1 + size, y1 + size) 103 | assert(out:size(2) == size and out:size(3) == size, 'wrong crop size') 104 | return out 105 | end 106 | end 107 | 108 | function M.RandomThreeCrop(size1, size2, size3, padding) 109 | padding = padding or 0 110 | 111 | return function(input, epoch) 112 | local size 113 | if epoch % 3 == 0 then 114 | size = size1 115 | elseif epoch % 3 == 1 then 116 | size = size2 117 | else 118 | size = size3 119 | end 120 | 121 | if padding > 0 then 122 | local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding) 123 | temp:zero() 124 | :narrow(2, padding+1, input:size(2)) 125 | :narrow(3, padding+1, input:size(3)) 126 | :copy(input) 127 | input = temp 128 | end 129 | 130 | local w, h = input:size(3), input:size(2) 131 | if w == size and h == size then 132 | return input 133 | end 134 | 135 | local x1, y1 = torch.random(0, w - size), torch.random(0, h - size) 136 | local out = image.crop(input, x1, y1, x1 + size, y1 + size) 137 | assert(out:size(2) == size and out:size(3) == size, 'wrong crop size') 138 | return out 139 | end 140 | end 141 | 142 | 143 | 144 | -- Four corner patches and center crop from image and its horizontal reflection 145 | function M.TenCrop(size) 146 | local centerCrop = M.CenterCrop(size) 147 | 148 | return function(input) 149 | local w, h = input:size(3), input:size(2) 150 | 151 | local output = {} 152 | for _, img in ipairs{input, image.hflip(input)} do 153 | table.insert(output, centerCrop(img)) 154 | table.insert(output, image.crop(img, 0, 0, size, size)) 155 | table.insert(output, image.crop(img, w-size, 0, w, size)) 156 | table.insert(output, image.crop(img, 0, h-size, size, h)) 157 | table.insert(output, image.crop(img, w-size, h-size, w, h)) 158 | end 159 | 160 | -- View as mini-batch 161 | for i, img in ipairs(output) do 162 | output[i] = img:view(1, img:size(1), img:size(2), img:size(3)) 163 | end 164 | 165 | return input.cat(output, 1) 166 | end 167 
| end 168 | 169 | -- Resized with shorter side randomly sampled from [minSize, maxSize] (ResNet-style) 170 | function M.RandomScale(minSize, maxSize) 171 | return function(input) 172 | local w, h = input:size(3), input:size(2) 173 | 174 | local targetSz = torch.random(minSize, maxSize) 175 | local targetW, targetH = targetSz, targetSz 176 | if w < h then 177 | targetH = torch.round(h / w * targetW) 178 | else 179 | targetW = torch.round(w / h * targetH) 180 | end 181 | 182 | return image.scale(input, targetW, targetH, 'bicubic') 183 | end 184 | end 185 | 186 | -- Random crop with size 9%-100% and aspect ratio 3/4 - 4/3 (Inception-style) 187 | function M.RandomSizedCrop(size) 188 | local scale = M.Scale(size) 189 | local crop = M.CenterCrop(size) 190 | 191 | return function(input) 192 | local attempt = 0 193 | repeat 194 | local area = input:size(2) * input:size(3) 195 | local targetArea = torch.uniform(0.09, 1.0) * area 196 | 197 | local aspectRatio = torch.uniform(3/4, 4/3) 198 | local w = torch.round(math.sqrt(targetArea * aspectRatio)) 199 | local h = torch.round(math.sqrt(targetArea / aspectRatio)) 200 | 201 | if torch.uniform() < 0.5 then 202 | w, h = h, w 203 | end 204 | 205 | if h <= input:size(2) and w <= input:size(3) then 206 | local y1 = torch.random(0, input:size(2) - h) 207 | local x1 = torch.random(0, input:size(3) - w) 208 | 209 | local out = image.crop(input, x1, y1, x1 + w, y1 + h) 210 | assert(out:size(2) == h and out:size(3) == w, 'wrong crop size') 211 | 212 | return image.scale(out, size, size, 'bicubic') 213 | end 214 | attempt = attempt + 1 215 | until attempt >= 10 216 | 217 | -- fallback 218 | return crop(scale(input)) 219 | end 220 | end 221 | 222 | function M.RandomTwoSizeCrop(size1, size2) 223 | return function(input, epoch) 224 | local attempt = 0 225 | local size = (epoch % 2 == 0) and size1 or size2 226 | local scale = M.Scale((size) / 2) 227 | local crop = M.CenterCrop((size)/2) 228 | 229 | repeat 230 | local area = input:size(2) * input:size(3) 231 | local targetArea = torch.uniform(0.25, 1.0) * area 232 | 233 | local aspectRatio = torch.uniform(3/4, 4/3) 234 | local w = torch.round(math.sqrt(targetArea * aspectRatio)) 235 | local h = torch.round(math.sqrt(targetArea / aspectRatio)) 236 | 237 | if torch.uniform() < 0.5 then 238 | w, h = h, w 239 | end 240 | 241 | if h <= input:size(2) and w <= input:size(3) then 242 | local y1 = torch.random(0, input:size(2) - h) 243 | local x1 = torch.random(0, input:size(3) - w) 244 | 245 | local out = image.crop(input, x1, y1, x1 + w, y1 + h) 246 | assert(out:size(2) == h and out:size(3) == w, 'wrong crop size') 247 | 248 | return image.scale(out, size, size, 'bicubic') 249 | end 250 | attempt = attempt + 1 251 | until attempt >= 10 252 | 253 | -- fallback 254 | return crop(scale(input)) 255 | end 256 | end 257 | 258 | function M.RandomThreeSizeCrop(size1, size2, size3) 259 | return function(input, epoch) 260 | local attempt = 0 261 | local size = (epoch % 3 == 0) and size1 or (epoch % 3 == 1) and size2 or size3 262 | local scale = M.Scale((size) / 2) 263 | local crop = M.CenterCrop((size)/2) 264 | 265 | repeat 266 | local area = input:size(2) * input:size(3) 267 | local targetArea = torch.uniform(0.25, 1.0) * area 268 | 269 | local aspectRatio = torch.uniform(3/4, 4/3) 270 | local w = torch.round(math.sqrt(targetArea * aspectRatio)) 271 | local h = torch.round(math.sqrt(targetArea / aspectRatio)) 272 | 273 | if torch.uniform() < 0.5 then 274 | w, h = h, w 275 | end 276 | 277 | if h <= input:size(2) and w <= 
input:size(3) then 278 | local y1 = torch.random(0, input:size(2) - h) 279 | local x1 = torch.random(0, input:size(3) - w) 280 | 281 | local out = image.crop(input, x1, y1, x1 + w, y1 + h) 282 | assert(out:size(2) == h and out:size(3) == w, 'wrong crop size') 283 | 284 | return image.scale(out, size, size, 'bicubic') 285 | end 286 | attempt = attempt + 1 287 | until attempt >= 10 288 | 289 | -- fallback 290 | return crop(scale(input)) 291 | end 292 | end 293 | 294 | 295 | 296 | 297 | function M.HorizontalFlip(prob) 298 | return function(input) 299 | if torch.uniform() < prob then 300 | input = image.hflip(input) 301 | end 302 | return input 303 | end 304 | end 305 | 306 | function M.Rotation(deg) 307 | return function(input) 308 | if deg ~= 0 then 309 | input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear') 310 | end 311 | return input 312 | end 313 | end 314 | 315 | -- Lighting noise (AlexNet-style PCA-based noise) 316 | function M.Lighting(alphastd, eigval, eigvec) 317 | return function(input) 318 | if alphastd == 0 then 319 | return input 320 | end 321 | 322 | local alpha = torch.Tensor(3):normal(0, alphastd) 323 | local rgb = eigvec:clone() 324 | :cmul(alpha:view(1, 3):expand(3, 3)) 325 | :cmul(eigval:view(1, 3):expand(3, 3)) 326 | :sum(2) 327 | :squeeze() 328 | 329 | input = input:clone() 330 | for i=1,3 do 331 | input[i]:add(rgb[i]) 332 | end 333 | return input 334 | end 335 | end 336 | 337 | local function blend(img1, img2, alpha) 338 | return img1:mul(alpha):add(1 - alpha, img2) 339 | end 340 | 341 | local function grayscale(dst, img) 342 | dst:resizeAs(img) 343 | dst[1]:zero() 344 | dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3]) 345 | dst[2]:copy(dst[1]) 346 | dst[3]:copy(dst[1]) 347 | return dst 348 | end 349 | 350 | function M.Saturation(var) 351 | local gs 352 | 353 | return function(input) 354 | gs = gs or input.new() 355 | grayscale(gs, input) 356 | 357 | local alpha = 1.0 + torch.uniform(-var, var) 358 | blend(input, gs, alpha) 359 | return input 360 | end 361 | end 362 | 363 | function M.Brightness(var) 364 | local gs 365 | 366 | return function(input) 367 | gs = gs or input.new() 368 | gs:resizeAs(input):zero() 369 | 370 | local alpha = 1.0 + torch.uniform(-var, var) 371 | blend(input, gs, alpha) 372 | return input 373 | end 374 | end 375 | 376 | function M.Contrast(var) 377 | local gs 378 | 379 | return function(input) 380 | gs = gs or input.new() 381 | grayscale(gs, input) 382 | gs:fill(gs[1]:mean()) 383 | 384 | local alpha = 1.0 + torch.uniform(-var, var) 385 | blend(input, gs, alpha) 386 | return input 387 | end 388 | end 389 | 390 | function M.RandomOrder(ts) 391 | return function(input) 392 | local img = input.img or input 393 | local order = torch.randperm(#ts) 394 | for i=1,#ts do 395 | img = ts[order[i]](img) 396 | end 397 | return input 398 | end 399 | end 400 | 401 | function M.ColorJitter(opt) 402 | local brightness = opt.brightness or 0 403 | local contrast = opt.contrast or 0 404 | local saturation = opt.saturation or 0 405 | 406 | local ts = {} 407 | if brightness ~= 0 then 408 | table.insert(ts, M.Brightness(brightness)) 409 | end 410 | if contrast ~= 0 then 411 | table.insert(ts, M.Contrast(contrast)) 412 | end 413 | if saturation ~= 0 then 414 | table.insert(ts, M.Saturation(saturation)) 415 | end 416 | 417 | if #ts == 0 then 418 | return function(input) return input end 419 | end 420 | 421 | return M.RandomOrder(ts) 422 | end 423 | 424 | return M 425 | 
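-- A minimal usage sketch (not part of the original pipeline; `meanstd` is a
-- hypothetical {mean = {...}, std = {...}} table): the dataset loaders
-- compose these primitives into a single callable, e.g.
--
--   local t = M.Compose{
--     M.RandomCrop(32, 4),
--     M.HorizontalFlip(0.5),
--     M.ColorNormalize(meanstd),
--   }
--   local out = t(img)          -- img: a 3xHxW torch.Tensor
--
-- M.Compose also threads an optional second argument through every transform
-- (the epoch), which is how the RandomTwoCrop/RandomThreeCrop variants pick
-- their crop size per epoch: out = t(img, epoch).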
--------------------------------------------------------------------------------