├── .gitignore ├── EPECriterion.lua ├── LICENSE ├── README.md ├── data.lua ├── dataset.lua ├── donkey.lua ├── extras ├── spybhwd │ ├── CMakeLists.txt │ ├── ScaleBHWD.cu │ ├── ScaleBHWD.lua │ ├── generic │ │ └── ScaleBHWD.c │ ├── init.c │ ├── init.cu │ ├── init.lua │ ├── spybhwd-scm-1.rockspec │ ├── test.lua │ ├── utils.c │ └── utils.h └── stnbhwd │ ├── AffineGridGeneratorBHWD.lua │ ├── AffineTransformMatrixGenerator.lua │ ├── BilinearSamplerBHWD.cu │ ├── BilinearSamplerBHWD.lua │ ├── CMakeLists.txt │ ├── LICENSE │ ├── README.md │ ├── ScaleBHWD.cu │ ├── ScaleBHWD.lua │ ├── demo │ ├── Optim.lua │ ├── README.md │ ├── demo_mnist.lua │ ├── distort_mnist.lua │ └── spatial_transformer.lua │ ├── generic │ ├── BilinearSamplerBHWD.c │ └── ScaleBHWD.c │ ├── init.c │ ├── init.cu │ ├── init.lua │ ├── stnbhwd-scm-1.rockspec │ ├── test.lua │ ├── utils.c │ └── utils.h ├── flowExtensions.lua ├── main.lua ├── model.lua ├── models ├── modelL1_3.t7 ├── modelL1_4.t7 ├── modelL1_C.t7 ├── modelL1_F.t7 ├── modelL1_K.t7 ├── modelL2_3.t7 ├── modelL2_4.t7 ├── modelL2_C.t7 ├── modelL2_F.t7 ├── modelL2_K.t7 ├── modelL3_3.t7 ├── modelL3_4.t7 ├── modelL3_C.t7 ├── modelL3_F.t7 ├── modelL3_K.t7 ├── modelL4_3.t7 ├── modelL4_4.t7 ├── modelL4_C.t7 ├── modelL4_F.t7 ├── modelL4_K.t7 ├── modelL5_3.t7 ├── modelL5_4.t7 ├── modelL5_C.t7 ├── modelL5_F.t7 ├── modelL5_K.t7 ├── modelL6_C.t7 ├── modelL6_F.t7 ├── modelL6_K.t7 └── volcon.lua ├── opts.lua ├── samples ├── 00001_flow.flo ├── 00001_img1.ppm ├── 00001_img2.ppm ├── 00002_flow.flo ├── 00002_img1.ppm ├── 00002_img2.ppm ├── 00003_flow.flo ├── 00003_img1.ppm └── 00003_img2.ppm ├── spynet.lua ├── test.lua ├── timing_benchmark.lua ├── timing_util.lua ├── train.lua ├── train_val_split.txt ├── transforms.lua └── util.lua /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoint* 2 | -------------------------------------------------------------------------------- /EPECriterion.lua: 
--------------------------------------------------------------------------------

-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact ps-license@tue.mpg.de.

local EPECriterion, parent = torch.class('nn.EPECriterion', 'nn.Criterion')

-- Computes the average endpoint error (EPE) between two
-- batchSize x Channels x Height x Width tensors (typically 2-channel flow
-- fields, but any channel count works): at every pixel the Euclidean norm of
-- (input - target) is taken over the channel dimension (dim 2); these norms
-- are summed and, when self.sizeAverage is true (the default), divided by
-- the number of pixels, i.e. input:nElement() / input:size(2).

-- Added to every per-pixel norm in the backward pass so that pixels where
-- input == target do not cause a division by zero.
local eps = 1e-12

function EPECriterion:__init()
   parent.__init(self)
   self.sizeAverage = true
end

-- Returns the divisor applied to both the loss and its gradient: the number
-- of pixels (all elements except the channel dimension) when averaging, or 1
-- when self.sizeAverage has explicitly been set to false.
local function normalizer(self, input)
   if self.sizeAverage == false then
      return 1
   end
   return input:nElement() / input:size(2)
end

-- Fills (and returns) self.buffer with the element-wise (input - target)^2.
local function squaredDiff(self, input, target)
   assert(input:nElement() == target:nElement(),
      "input and target size mismatch")
   self.buffer = self.buffer or input.new()
   self.buffer:resizeAs(input)
   -- add(t1, value, t2) stores t1 + value * t2, i.e. input - target
   self.buffer:add(input, -1, target):pow(2)
   return self.buffer
end

-- Forward pass: returns the (averaged) endpoint error as a Lua number.
function EPECriterion:updateOutput(input, target)
   local buffer = squaredDiff(self, input, target)
   -- per-pixel endpoint error = sqrt of the sum over channels, then summed
   local total = torch.sum(buffer, 2):sqrt():sum()
   self.output = total / normalizer(self, input)
   return self.output
end

-- Backward pass: d EPE / d input = (input - target) / ||input - target||,
-- scaled by the same normalizer as the forward pass. The gradient is written
-- into self.gradInput, which is also returned.
function EPECriterion:updateGradInput(input, target)
   local buffer = squaredDiff(self, input, target)
   -- per-pixel norm (channel dim kept with size 1), kept away from zero by eps
   local norm = torch.sum(buffer, 2):sqrt():add(eps)

   local gradInput = self.gradInput
   gradInput:resizeAs(input)
   -- expandAs broadcasts the per-pixel norm over every channel without a copy
   gradInput:add(input, -1, target):cdiv(norm:expandAs(input))
   -- scale in place so self.gradInput holds exactly the returned gradient
   gradInput:div(normalizer(self, input))
   return gradInput
end
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | SPyNet License 2 | For non-commercial and scientific research purposes 3 | Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft 4 | 5 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the trained models and code, (the "Model"). By downloading and/or using the Model, you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model. 6 | 7 | Ownership 8 | The Model has been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI") and is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”). 9 | 10 | License Grant 11 | Max-Planck grants you a non-exclusive, non-transferable, free of charge right: 12 | 13 | To download the Model and use it on computers owned, leased or otherwise controlled by you and/or your organisation; 14 | To use the Model for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects. 15 | Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, as training data for a commercial product, for commercial ergonomic analysis (e.g.
product design, architectural design, etc.), or production of other artifacts for commercial purposes including, for example, web services, movies, television programs, mobile applications, or video games. The Model may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Model to train methods/algorithms/neural networks/etc. for commercial use of any kind. The Model may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. By downloading the Model, you agree not to reverse engineer it. 16 | 17 | Disclaimer of Representations and Warranties 18 | You expressly acknowledge and agree that the Model results from basic research, is provided “AS IS”, may contain errors, and that any use of the Model is at your sole risk. MAX-PLANCK MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE MODEL, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, Max-Planck makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Model, (ii) that the use of the Model will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Model will not cause any damage of any kind to you or a third party. 19 | 20 | Limitation of Liability 21 | Under no circumstances shall Max-Planck be liable for any incidental, special, indirect or consequential damages arising out of or relating to this license, including but not limited to, any lost profits, business interruption, loss of programs or other data, or all other commercial damages or losses, even if advised of the possibility thereof. 
22 | 23 | No Maintenance Services 24 | You understand and agree that Max-Planck is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Model. Max-Planck nevertheless reserves the right to update, modify, or discontinue the Model at any time. 25 | 26 | Publication 27 | You agree to cite the most recent paper describing the model as specified on the download website. This website lists the most up to date bibliographic information on the about page. 28 | 29 | Media projects 30 | When using the Model in a media project please give credit to Max Planck Institute for Intelligent Systems. For example: SPyNet was used for optical flow estimation courtesy of the Max Planck Institute for Intelligent Systems. 31 | Commercial licensing opportunities 32 | For commercial use, please contact ps-license@tue.mpg.de. 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPyNet: Spatial Pyramid Network for Optical Flow 2 | This code is based on the paper [Optical Flow Estimation using a Spatial Pyramid Network](https://arxiv.org/abs/1611.00850). 
3 | 4 | [[Unofficial PyTorch version](https://github.com/sniklaus/pytorch-spynet)] [[Unofficial TensorFlow version](https://github.com/tukilabs/Video-Compression-Net/blob/master/utils/network.py)] 5 | 6 | * [First things first:](#setUp) Setting up this code 7 | * [Easy Usage:](#easyUsage) Compute Optical Flow in 5 lines 8 | * [Fast Performance Usage:](#fastPerformanceUsage) Compute Optical Flow at rocket speed 9 | * [Training:](#training) Train your own models using the Spatial Pyramid approach on multiple GPUs 10 | * [End2End SPyNet:](#end2end) An easily trainable end-to-end version of SPyNet 11 | * [Optical Flow Utilities:](#flowUtils) A set of Lua functions for working with optical flow 12 | * [References:](#references) For further reading 13 | 14 | 15 | ## First things first 16 | You need to have [Torch.](http://torch.ch/docs/getting-started.html#_) 17 | 18 | Install the other required packages: 19 | ```bash 20 | cd extras/spybhwd 21 | luarocks make 22 | cd ../stnbhwd 23 | luarocks make 24 | ``` 25 | 26 | ## For Easy Usage, follow this 27 | #### Set up SPyNet 28 | ```lua 29 | spynet = require('spynet') 30 | easyComputeFlow = spynet.easy_setup() 31 | ``` 32 | #### Load images and compute flow 33 | ```lua 34 | im1 = image.load('samples/00001_img1.ppm' ) 35 | im2 = image.load('samples/00001_img2.ppm' ) 36 | flow = easyComputeFlow(im1, im2) 37 | ``` 38 | To save your flow fields to a `.flo` file, use [flowExtensions.writeFLO](#writeFLO). 39 | 40 | 41 | ## For Fast Performance, follow this (recommended) 42 | #### Set up SPyNet 43 | Set up SPyNet according to the image size and model. For optimal performance, resize your images so that their width and height are multiples of 32. You can also specify your favorite model. The currently supported models are the fine-tuned models `sintelFinal` (default), `sintelClean` and `kittiFinal`, and the base models `chairsFinal` and `chairsClean`.
44 | ```lua 45 | spynet = require('spynet') 46 | computeFlow = spynet.setup(512, 384, 'sintelFinal') -- for 384x512 images 47 | ``` 48 | Now you can call computeFlow anytime to estimate optical flow between image pairs. 49 | 50 | #### Computing flow 51 | Load an image pair, stack it, and normalize it. 52 | ```lua 53 | im1 = image.load('samples/00001_img1.ppm' ) 54 | im2 = image.load('samples/00001_img2.ppm' ) 55 | im = torch.cat(im1, im2, 1) 56 | im = spynet.normalize(im) 57 | ``` 58 | SPyNet works with batches of data on CUDA, so compute flow using: 59 | ```lua 60 | im = im:resize(1, im:size(1), im:size(2), im:size(3)):cuda() 61 | flow = computeFlow(im) 62 | ``` 63 | You can also use batch mode: if your images `im` are a tensor of size `Bx6xHxW` (a batch of B image pairs, each pair stacked into 6 RGB channels), you can directly use: 64 | ```lua 65 | flow = computeFlow(im) 66 | ``` 67 | 68 | ## Training 69 | Training sequentially is faster than training end-to-end, since you only need to learn a small number of parameters at each level. To train level `N`, you need the trained models at levels `1` to `N-1`. You also initialize the model with a pretrained model at level `N-1`. 70 | 71 | For example, to train level 3, you need the trained models at `L1` and `L2`, and you initialize the model with `modelL2_3.t7`. 72 | ```bash 73 | th main.lua -fineWidth 128 -fineHeight 96 -level 3 -netType volcon \ 74 | -cache checkpoint -data FLYING_CHAIRS_DIR \ 75 | -L1 models/modelL1_3.t7 -L2 models/modelL2_3.t7 \ 76 | -retrain models/modelL2_3.t7 77 | ``` 78 | 79 | ## End2End SPyNet 80 | The end-to-end version of SPyNet is easily trainable and is available at [anuragranj/end2end-spynet](https://github.com/anuragranj/end2end-spynet). 81 | 82 | 83 | ## Optical Flow Utilities 84 | We provide `flowExtensions.lua`, containing various functions that make working with optical flow in Torch/Lua easier. You can just copy this file into your project directory and use it off the shelf.
85 | ```lua 86 | flowX = require 'flowExtensions' 87 | ``` 88 | #### [flow_magnitude] flowX.computeNorm(flow_x, flow_y) 89 | Given `flow_x` and `flow_y` of size `MxN` each, evaluate `flow_magnitude` of size `MxN`. 90 | 91 | #### [flow_angle] flowX.computeAngle(flow_x, flow_y) 92 | Given `flow_x` and `flow_y` of size `MxN` each, evaluate `flow_angle` of size `MxN` in degrees. 93 | 94 | #### [rgb] flowX.field2rgb(flow_magnitude, flow_angle, [max], [legend]) 95 | Given `flow_magnitude` and `flow_angle` of size `MxN` each, return an image of size `3xMxN` for visualizing optical flow. `max`(optional) specifies maximum flow magnitude and `legend`(optional) is boolean that prints a legend on the image. 96 | 97 | #### [rgb] flowX.xy2rgb(flow_x, flow_y, [max]) 98 | Given `flow_x` and `flow_y` of size `MxN` each, return an image of size `3xMxN` for visualizing optical flow. `max`(optional) specifies maximum flow magnitude. 99 | 100 | #### [flow] flowX.loadFLO(filename) 101 | Reads a `.flo` file. Loads `x` and `y` components of optical flow in a 2 channel `2xMxN` optical flow field. First channel stores `x` component and second channel stores `y` component. 102 | 103 | 104 | #### flowX.writeFLO(filename,F) 105 | Write a `2xMxN` flow field `F` containing `x` and `y` components of its flow fields in its first and second channel respectively to `filename`, a `.flo` file. 106 | 107 | #### [flow] flowX.loadPFM(filename) 108 | Reads a `.pfm` file. Loads `x` and `y` components of optical flow in a 2 channel `2xMxN` optical flow field. First channel stores `x` component and second channel stores `y` component. 109 | 110 | #### [flow_rotated] flowX.rotate(flow, angle) 111 | Rotates `flow` of size `2xMxN` by `angle` in radians. Uses nearest-neighbor interpolation to avoid blurring at boundaries. 112 | 113 | #### [flow_scaled] flowX.scale(flow, sc, [opt]) 114 | Scales `flow` of size `2xMxN` by `sc` times. 
`opt`(optional) specifies the interpolation method: `simple` (default), `bilinear`, or `bicubic`. 115 | 116 | #### [flowBatch_scaled] flowX.scaleBatch(flowBatch, sc) 117 | Scales `flowBatch` of size `Bx2xMxN`, a batch of `B` flow fields, by `sc` times. Uses nearest-neighbor interpolation. 118 | 119 | 120 | ## Timing Benchmarks 121 | Our timing benchmark is set up on the Flying Chairs dataset. To run it, you first need to download the dataset: 122 | ```bash 123 | wget http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs/FlyingChairs.zip 124 | ``` 125 | Run the timing benchmark: 126 | ```bash 127 | th timing_benchmark.lua -data YOUR_FLYING_CHAIRS_DATA_DIRECTORY 128 | ``` 129 | 130 | 131 | ## References 132 | 1. Our warping code is based on [qassemoquab/stnbhwd.](https://github.com/qassemoquab/stnbhwd) 133 | 2. The images in `samples` are from the Flying Chairs dataset: 134 | Dosovitskiy, Alexey, et al. "FlowNet: Learning optical flow with convolutional networks." 2015 IEEE International Conference on Computer Vision (ICCV). IEEE, 2015. 135 | 3. Some parts of `flowExtensions.lua` are adapted from [marcoscoffier/optical-flow](https://github.com/marcoscoffier/optical-flow/blob/master/init.lua) with help from [fguney](https://github.com/fguney). 136 | 4. The unofficial PyTorch implementation is from [sniklaus](https://github.com/sniklaus). 137 | 138 | ## License 139 | Free for non-commercial and scientific research purposes. For commercial use, please contact ps-license@tue.mpg.de. See the LICENSE file for details. 140 | 141 | ## When using this code, please cite 142 | Ranjan, Anurag, and Michael J. Black. "Optical Flow Estimation using a Spatial Pyramid Network." arXiv preprint arXiv:1611.00850 (2016). 143 | -------------------------------------------------------------------------------- /data.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. 2 | -- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact ps-license@tue.mpg.de.
--
-- Copyright (c) 2014, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
local ffi = require 'ffi'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

-- This script contains the logic to create K threads for parallel data-loading.
-- For the data-loading details, look at donkey.lua
--
-- It reads the global `opt` (nDonkeys, manualSeed) and defines the globals
-- `donkeys` (a job queue with addjob/synchronize) and `nTest` (the size
-- reported by testLoader -- NOTE(review): testLoader is not defined in this
-- file; presumably donkey.lua creates it).
-------------------------------------------------------------------------------
do -- start K datathreads (donkeys)
   if opt.nDonkeys > 0 then
      local options = opt -- make an upvalue to serialize over to donkey threads
      donkeys = Threads(
         opt.nDonkeys,
         function()
            require 'torch'
         end,
         function(idx)
            opt = options -- pass to all donkeys via upvalue
            tid = idx
            -- a distinct seed per donkey so their random streams differ
            local seed = opt.manualSeed + idx
            torch.manualSeed(seed)
            print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
            paths.dofile('donkey.lua')
         end
      );
   else -- single threaded data loading. useful for debugging
      paths.dofile('donkey.lua')
      -- minimal stand-in for the Threads API: every job runs immediately here
      donkeys = {}
      function donkeys:addjob(f1, f2) f2(f1()) end
      function donkeys:synchronize() end
   end
end

nTest = 0
donkeys:addjob(function() return testLoader:size() end, function(c) nTest = c end)
donkeys:synchronize()
assert(nTest > 0, "Failed to get nTest")
print('nTest: ', nTest)
-------------------------------------------------------------------------------- /dataset.lua: --------------------------------------------------------------------------------

-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact ps-license@tue.mpg.de.

require 'torch'
torch.setdefaulttensortype('torch.FloatTensor')
local ffi = require 'ffi'
local class = require('pl.class')
local dir = require 'pl.dir'
local tablex = require 'pl.tablex'
local argcheck = require 'argcheck'
require 'sys'
require 'xlua'
require 'image'

-- dataLoader serves (input, target) batches for a fixed set of sample ids
-- (samplingIds). Loading one sample is delegated to sampleHookTrain /
-- sampleHookTest, which are called with a sample id and must return an input
-- tensor matching inputSize and a target tensor matching outputSize.
local dataset = torch.class('dataLoader')

local initcheck = argcheck{
   pack=true,
   help=[[
     A dataset class for images in a flat folder structure (folder-name is class-name).
     Optimized for extremely large datasets (upwards of 14 million images).
     Tested only on Linux (as it uses command-line linux utilities to scale up)
]],
   {name="inputSize",
    type="table",
    help="the size of the input images"},

   {name="outputSize",
    type="table",
    help="the size of the network output"},

   {name="split",
    type="number",
    help="Percentage of split to go to Training"
   },

   {name="samplingMode",
    type="string",
    help="Sampling mode: random | balanced ",
    default = "balanced"},

   {name="verbose",
    type="boolean",
    help="Verbose mode during initialization",
    default = false},

   {name="loadSize",
    type="table",
    help="a size to load the images to, initially",
    opt = true},

   {name="samplingIds",
    type="torch.LongTensor",
    help="the ids of training or testing images",
    opt = true},

   {name="sampleHookTrain",
    type="function",
    help="applied to sample during training(ex: for lighting jitter). "
       .. "It takes the image path as input",
    opt = true},

   {name="sampleHookTest",
    type="function",
    help="applied to sample during testing",
    opt = true},
}

-- Constructor: validates the arguments with initcheck and copies them onto self.
function dataset:__init(...)

   -- argcheck
   local args = initcheck(...)
   print(args)
   for k,v in pairs(args) do self[k] = v end

   if not self.loadSize then self.loadSize = self.inputSize; end

   if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
   if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end

   -- samplingIds is optional for argcheck, but every sampling method needs it
   assert(self.samplingIds, "samplingIds must be provided")
   -- an empty result of nonzero() has no dimensions, so guard size(1)
   self.numSamples = self.samplingIds:dim() > 0 and self.samplingIds:size(1) or 0
   assert(self.numSamples > 0, "Could not find any sample in the given input paths")

   if self.verbose then print(self.numSamples .. ' samples found.') end
end

-- Returns the number of samples. `class` and `list` are unused; they are kept
-- only for interface compatibility.
function dataset:size(class, list)
   return self.numSamples
end

-- converts a table of samples (and corresponding labels) to a clean tensor:
-- returns quantity x inputSize images and quantity x outputSize outputs.
local function tableToOutput(self, imgTable, outputTable)
   local quantity = #imgTable
   assert(quantity > 0, "cannot build a batch from zero samples")
   assert(imgTable[1]:size()[1] == self.inputSize[1])
   assert(outputTable[1]:size()[1] == self.outputSize[1])

   local images = torch.Tensor(quantity,
      self.inputSize[1], self.inputSize[2], self.inputSize[3])
   local outputs = torch.Tensor(quantity,
      self.outputSize[1], self.outputSize[2], self.outputSize[3])

   for i = 1, quantity do
      images[i]:copy(imgTable[i])
      outputs[i]:copy(outputTable[i])
   end
   return images, outputs
end

-- sampler, samples from the training set: draws `quantity` ids uniformly at
-- random (with replacement) and loads them through sampleHookTrain.
function dataset:sample(quantity)
   assert(quantity and quantity > 0, "quantity must be a positive number")
   local imgTable = {}
   local outputTable = {}
   for i = 1, quantity do
      local id = torch.random(1, self.numSamples)
      -- single element[not tensor] from a row
      local img, output = self:sampleHookTrain(self.samplingIds[id][1])
      table.insert(imgTable, img)
      table.insert(outputTable, output)
   end
   return tableToOutput(self, imgTable, outputTable)
end

-- Loads the samples at positions i1..i2 (inclusive) of samplingIds through
-- sampleHookTest and returns them as batch tensors.
function dataset:get(i1, i2)
   local quantity = i2 - i1 + 1
   -- validate the range before using it to index samplingIds
   assert(quantity > 0)
   local indices = self.samplingIds[{{i1, i2}}]
   local imgTable = {}
   local outputTable = {}
   for i = 1, quantity do
      local img, output = self:sampleHookTest(indices[i][1])
      table.insert(imgTable, img)
      table.insert(outputTable, output)
   end
   return tableToOutput(self, imgTable, outputTable)
end

return dataset
-------------------------------------------------------------------------------- /donkey.lua:
--------------------------------------------------------------------------------
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact ps-license@tue.mpg.de.

require 'image'
require 'nn'
require 'cunn'
require 'cudnn'
require 'nngraph'
require 'stn'
require 'spy'

local flowX = require 'flowExtensions'
local TF = require 'transforms'

paths.dofile('dataset.lua')
paths.dofile('util.lua')

-- This file contains the data-loading logic and details.
-- It is run by each data-loader thread.
------------------------------------------
local eps = 1e-6
-- a cache file of the training metadata (if doesnt exist, will be created)
local trainCache = paths.concat(opt.cache, 'trainCache.t7')
local testCache = paths.concat(opt.cache, 'testCache.t7')

local meanstd = {
   mean = { 0.485, 0.456, 0.406 },
   std = { 0.229, 0.224, 0.225 },
}
local pca = {
   eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 },
   eigvec = torch.Tensor{
      { -0.5675, 0.7192, 0.4009 },
      { -0.5808, -0.0045, -0.8140 },
      { -0.5836, -0.6948, 0.4203 },
   },
}

local mean = meanstd.mean
local std = meanstd.std
------------------------------------------
-- Warping Function:
-- Builds an nngraph module taking {images, flows}. The Transpose layers move
-- dim 2 to the end (BCHW -> BHWC) for nn.BilinearSamplerBHWD, and the final
-- Transpose moves it back.
local function createWarpModel()
   local imgData = nn.Identity()()
   local floData = nn.Identity()()

   local imgOut = nn.Transpose({2,3},{3,4})(imgData)
   local floOut = nn.Transpose({2,3},{3,4})(floData)

   local warpImOut = nn.Transpose({3,4},{2,3})(nn.BilinearSamplerBHWD()({imgOut, floOut}))
   local model = nn.gModule({imgData, floData}, {warpImOut})

   return model
end

-- Loads a trained lower-level model from `path`, unwraps it when it was saved
-- inside an nn.DataParallelTable, and switches it to evaluation mode.
local function loadFrozenModel(path)
   local model = torch.load(path)
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end
   model:evaluate()
   return model
end

-- 2x2 average pooling with stride 2 (halves height and width), on the GPU,
-- in evaluation mode.
local function createDownsampler()
   local model = nn.SpatialAveragePooling(2,2,2,2):cuda()
   model:evaluate()
   return model
end

-- Scales by 2 with nn.ScaleBHWD, which works in BHWD layout (hence the
-- surrounding Transposes); on the GPU, in evaluation mode.
local function createUpsampler()
   local model = nn.Sequential()
   model:add(nn.Transpose({2,3},{3,4}))
   model:add(nn.ScaleBHWD(2))
   model:add(nn.Transpose({3,4},{2,3}))
   model = model:cuda()
   model:evaluate()
   return model
end

-- Warping module (see createWarpModel) on the GPU, in evaluation mode.
local function createEvalWarpModel()
   local model = createWarpModel():cuda()
   model:evaluate()
   return model
end

local modelL1, modelL2, modelL3, modelL4
local modelL1path, modelL2path, modelL3path, modelL4path
local down1, down2, down3, down4, up2, up3, up4
local warpmodel2, warpmodel3, warpmodel4

modelL1path = opt.L1
modelL2path = opt.L2
modelL3path = opt.L3
modelL4path = opt.L4

-- Training level N uses the already-trained models of levels 1..N-1 together
-- with the down/up-sampling and warping modules connecting them.
if opt.level > 1 then
   modelL1 = loadFrozenModel(modelL1path)
   down1 = createDownsampler()
end

if opt.level > 2 then
   modelL2 = loadFrozenModel(modelL2path)
   down2 = createDownsampler()
   up2 = createUpsampler()
   warpmodel2 = createEvalWarpModel()
end

if opt.level > 3 then
   modelL3 = loadFrozenModel(modelL3path)
   down3 = createDownsampler()
   up3 = createUpsampler()
   warpmodel3 = createEvalWarpModel()
end

if opt.level > 4 then
   modelL4 = loadFrozenModel(modelL4path)
   down4 = createDownsampler()
   up4 = createUpsampler()
   warpmodel4 = createEvalWarpModel()
end

-- Check for existence of opt.data.
-- Testing `os.execute('cd ' .. opt.data)` does not work under Lua 5.1/LuaJIT:
-- there os.execute returns a numeric exit status, which is truthy even when
-- the command fails, so the error below could never be raised.
if not paths.dirp(opt.data) then
   error(("could not chdir to '%s'"):format(opt.data))
end

local loadSize = opt.loadSize
local inputSize = {8, opt.fineHeight, opt.fineWidth}
local outputSize = {2, opt.fineHeight, opt.fineWidth}

-- Counts the non-hidden entries of directory `dirname` (the same entries
-- `ls dirname | wc -l` counts), without going through a shell.
local function countVisibleEntries(dirname)
   local n = 0
   for name in paths.files(dirname) do
      if name:sub(1, 1) ~= '.' then
         n = n + 1
      end
   end
   return n
end

-- Reads the split file at `path`: one integer per sample, where 1 selects a
-- training sample and 2 a validation sample. Returns the (1-based) indices of
-- each group as returned by nonzero().
local function getTrainValidationSplits(path)
   -- every sample contributes three files to opt.data (presumably img1, img2
   -- and flow, as in samples/)
   local numSamples = math.floor(countVisibleEntries(opt.data) / 3)
   assert(numSamples > 0, ("no samples found in '%s'"):format(opt.data))

   local ff = torch.DiskFile(path, 'r')
   local trainValidationSamples = torch.IntTensor(numSamples)
   ff:readInt(trainValidationSamples:storage())
   ff:close()

   local train_samples = trainValidationSamples:eq(1):nonzero()
   local validation_samples = trainValidationSamples:eq(2):nonzero()

   return train_samples, validation_samples
end

local train_samples, validation_samples = getTrainValidationSplits(opt.trainValidationSplit)

-- Loads the image at `path` as a 3-channel float tensor.
local function loadImage(path)
   local input = image.load(path, 3, 'float')
   return input
end

-- Rotates a flow field spatially with image.rotate and then rotates every
-- flow vector (channels 1 and 2) by -angle so the vectors follow the image.
-- NOTE(review): assumes `flow` is 2 x H x W -- confirm against callers.
local function rotateFlow(flow, angle)
   local flow_rot = image.rotate(flow, angle)
   local fu = torch.mul(flow_rot[1], math.cos(-angle)) - torch.mul(flow_rot[2], math.sin(-angle))
   local fv = torch.mul(flow_rot[1], math.sin(-angle)) + torch.mul(flow_rot[2], math.cos(-angle))
   flow_rot[1]:copy(fu)
   flow_rot[2]:copy(fv)

   return flow_rot
end

-- NOTE(review): the rest of donkey.lua and the start of extras/spybhwd/CMakeLists.txt appear to be garbled in this dump (text between a '<' and a '>' seems to have been stripped); the next line is kept verbatim.
171 | local function scaleFlow(flow, height, width) 172 | -- scale the original flow to a flow of size height x width 173 | local sc = height/flow:size(2) 174 | assert(torch.abs(width/flow:size(3) - sc)= 4.6.2 or change your OS to enable OpenMP") 30 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}
-Wno-unknown-pragmas") 31 | SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?" FORCE) 32 | ENDIF () 33 | ENDIF () 34 | 35 | IF (WITH_OPENMP) 36 | FIND_PACKAGE(OpenMP) 37 | IF(OPENMP_FOUND) 38 | MESSAGE(STATUS "Compiling with OpenMP support") 39 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 40 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 41 | SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 42 | ENDIF(OPENMP_FOUND) 43 | ENDIF (WITH_OPENMP) 44 | 45 | LINK_DIRECTORIES("${Torch_INSTALL_LIB}") 46 | 47 | SET(src init.c) 48 | FILE(GLOB luasrc *.lua) 49 | ADD_TORCH_PACKAGE(spy "${src}" "${luasrc}") 50 | TARGET_LINK_LIBRARIES(spy luaT TH) 51 | 52 | 53 | FIND_PACKAGE(CUDA 5.5) 54 | 55 | IF (CUDA_FOUND) 56 | LIST(APPEND CUDA_NVCC_FLAGS "-arch=sm_20") 57 | LIST(APPEND CUDA_NVCC_FLAGS "-Xcompiler -std=c++98") 58 | 59 | INCLUDE_DIRECTORIES("${Torch_INSTALL_INCLUDE}/THC") 60 | SET(src-cuda init.cu) 61 | CUDA_ADD_LIBRARY(cuspy MODULE ${src-cuda}) 62 | TARGET_LINK_LIBRARIES(cuspy luaT THC TH) 63 | IF(APPLE) 64 | SET_TARGET_PROPERTIES(cuspy PROPERTIES 65 | LINK_FLAGS "-undefined dynamic_lookup") 66 | ENDIF() 67 | ### Torch packages supposes libraries prefix is "lib" 68 | SET_TARGET_PROPERTIES(cuspy PROPERTIES 69 | PREFIX "lib" 70 | IMPORT_PREFIX "lib") 71 | 72 | INSTALL(TARGETS cuspy 73 | RUNTIME DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}" 74 | LIBRARY DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}") 75 | ENDIF(CUDA_FOUND) 76 | -------------------------------------------------------------------------------- /extras/spybhwd/ScaleBHWD.lua: --------------------------------------------------------------------------------
local ScaleBHWD, parent = torch.class('nn.ScaleBHWD', 'nn.Module')

--[[
   ScaleBHWD(scale) :
   ScaleBHWD:updateOutput(inputImages)
   ScaleBHWD:updateGradInput(inputImages, gradOutput)

   ScaleBHWD resizes images by the factor `scale` (default 1) along the
   height and width dimensions. The resampling itself is done by the backend
   functions ScaleBHWD_updateOutput / ScaleBHWD_updateGradInput.

   - inputImages has to be contiguous and in BHWD layout
     (batch x height x width x channels); a 3D HWD tensor is also accepted
     and is processed as a batch of one.
   - the output is B x floor(scale*H) x floor(scale*W) x D (or the matching
     HWD size for a 3D input).

   NOTE(review): generic/ScaleBHWD.c computes the source coordinate as
   (inputImages_width - 1)*xOut / (output_width - 1) with int operands, i.e.
   integer division before interpolation, and divides by zero when an output
   side has length 1 -- confirm this is intended.
]]

function ScaleBHWD:__init(scale)
   parent.__init(self)
   self.scale = scale or 1
end

-- Output extent of a spatial dimension of length n.
local function scaledSize(scale, n)
   return math.floor(scale * n)
end

-- Validates the (batched, 4D) input and, when given, that gradOutput has the
-- size updateOutput produces for that input.
function ScaleBHWD:check(input, gradOutput)
   local inputImages = input

   assert(inputImages:isContiguous(), 'Input images have to be contiguous')
   assert(inputImages:nDimension()==4)

   if gradOutput then
      assert(gradOutput:nDimension() == 4, 'gradOutput has to be 4D (BHWD)')
      assert(gradOutput:size(1) == inputImages:size(1)
         and gradOutput:size(2) == scaledSize(self.scale, inputImages:size(2))
         and gradOutput:size(3) == scaledSize(self.scale, inputImages:size(3))
         and gradOutput:size(4) == inputImages:size(4),
         'gradOutput size does not match the scaled input size')
   end
end

-- Returns a view of t with an extra leading dimension of size 1
-- (t must be contiguous for :view).
local function addOuterDim(t)
   local sizes = t:size()
   local newsizes = torch.LongStorage(sizes:size()+1)
   newsizes[1]=1
   for i=1,sizes:size() do
      newsizes[i+1]=sizes[i]
   end
   return t:view(newsizes)
end

function ScaleBHWD:updateOutput(input)
   -- a 3D (HWD) input is processed as a batch of one
   local batchMode = input:nDimension() ~= 3
   local inputImages = input
   if not batchMode then
      inputImages = addOuterDim(input)
   end

   self:check(inputImages)

   self.output:resize(inputImages:size(1),
                      scaledSize(self.scale, inputImages:size(2)),
                      scaledSize(self.scale, inputImages:size(3)),
                      inputImages:size(4))

   inputImages.nn.ScaleBHWD_updateOutput(self, inputImages, self.output)

   if not batchMode then
      self.output = self.output:select(1,1)
   end

   return self.output
end

function ScaleBHWD:updateGradInput(_input, _gradOutput)
   -- mirror updateOutput: a 3D input (and its gradOutput) is a batch of one
   local batchMode = _input:nDimension() ~= 3
   local inputImages, gradOutput = _input, _gradOutput
   if not batchMode then
      inputImages = addOuterDim(_input)
      gradOutput = addOuterDim(_gradOutput)
   end

   self:check(inputImages, gradOutput)

   self.gradInput = self.gradInput or inputImages.new()
   self.gradInput:resizeAs(inputImages):zero()

   -- gradInput is a single BHWD tensor here, so the whole tensor goes to the
   -- backend (indexing it with [1] would select only the first batch element)
   inputImages.nn.ScaleBHWD_updateGradInput(self, inputImages, self.gradInput, gradOutput)

   if not batchMode then
      self.gradInput = self.gradInput:select(1,1)
   end

   return self.gradInput
end
-------------------------------------------------------------------------------- /extras/spybhwd/generic/ScaleBHWD.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/ScaleBHWD.c" 3 | #else 4 | 5 | #include 6 | 7 | 8 | static int nn_(ScaleBHWD_updateOutput)(lua_State *L) 9 | { 10 | THTensor *inputImages = luaT_checkudata(L, 2, torch_Tensor); 11 | //THTensor *grids = luaT_checkudata(L, 3, torch_Tensor); 12 | //real scale = luaT_getfieldchecknumber(L, 3, "scale"); 13 | THTensor *output =
luaT_checkudata(L, 3, torch_Tensor); 14 | 15 | int batchsize = inputImages->size[0]; 16 | int inputImages_height = inputImages->size[1]; 17 | int inputImages_width = inputImages->size[2]; 18 | int output_height = output->size[1]; 19 | int output_width = output->size[2]; 20 | int inputImages_channels = inputImages->size[3]; 21 | 22 | int output_strideBatch = output->stride[0]; 23 | int output_strideHeight = output->stride[1]; 24 | int output_strideWidth = output->stride[2]; 25 | 26 | int inputImages_strideBatch = inputImages->stride[0]; 27 | int inputImages_strideHeight = inputImages->stride[1]; 28 | int inputImages_strideWidth = inputImages->stride[2]; 29 | 30 | // int grids_strideBatch = grids->stride[0]; 31 | // int grids_strideHeight = grids->stride[1]; 32 | // int grids_strideWidth = grids->stride[2]; 33 | 34 | real *inputImages_data, *output_data; 35 | inputImages_data = THTensor_(data)(inputImages); 36 | output_data = THTensor_(data)(output); 37 | // grids_data = THTensor_(data)(grids); 38 | 39 | int b, yOut, xOut; 40 | 41 | for(b=0; b < batchsize; b++) 42 | { 43 | for(yOut=0; yOut < output_height; yOut++) 44 | { 45 | for(xOut=0; xOut < output_width; xOut++) 46 | { 47 | //read the grid 48 | //real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth]; 49 | //real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1]; 50 | 51 | // get the weights for interpolation 52 | int yInTopLeft, xInTopLeft; 53 | real yWeightTopLeft, xWeightTopLeft; 54 | 55 | real xcoord = (inputImages_width - 1)*xOut / (output_width -1); 56 | xInTopLeft = floor(xcoord); 57 | xWeightTopLeft = 1 - (xcoord - xInTopLeft); 58 | 59 | real ycoord = (inputImages_height -1)*yOut / (output_height -1); 60 | yInTopLeft = floor(ycoord); 61 | yWeightTopLeft = 1 - (ycoord - yInTopLeft); 62 | 63 | 64 | 65 | const int outAddress = output_strideBatch * b + output_strideHeight * yOut + output_strideWidth * xOut; 66 | const int 
inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft; 67 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth; 68 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight; 69 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth; 70 | 71 | real v=0; 72 | real inTopLeft=0; 73 | real inTopRight=0; 74 | real inBottomLeft=0; 75 | real inBottomRight=0; 76 | 77 | // we are careful with the boundaries 78 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 79 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 80 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 81 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 82 | 83 | int t; 84 | // interpolation happens here 85 | for(t=0; tsize[0]; 119 | int inputImages_height = inputImages->size[1]; 120 | int inputImages_width = inputImages->size[2]; 121 | int gradOutput_height = gradOutput->size[1]; 122 | int gradOutput_width = gradOutput->size[2]; 123 | int inputImages_channels = inputImages->size[3]; 124 | 125 | int gradOutput_strideBatch = gradOutput->stride[0]; 126 | int gradOutput_strideHeight = gradOutput->stride[1]; 127 | int gradOutput_strideWidth = gradOutput->stride[2]; 128 | 129 | int inputImages_strideBatch = inputImages->stride[0]; 130 | int inputImages_strideHeight = inputImages->stride[1]; 131 | int inputImages_strideWidth = inputImages->stride[2]; 132 | 133 | int gradInputImages_strideBatch = gradInputImages->stride[0]; 134 | int gradInputImages_strideHeight = gradInputImages->stride[1]; 135 | int gradInputImages_strideWidth = 
gradInputImages->stride[2]; 136 | 137 | // int grids_strideBatch = grids->stride[0]; 138 | // int grids_strideHeight = grids->stride[1]; 139 | // int grids_strideWidth = grids->stride[2]; 140 | 141 | // int gradGrids_strideBatch = gradGrids->stride[0]; 142 | // int gradGrids_strideHeight = gradGrids->stride[1]; 143 | // int gradGrids_strideWidth = gradGrids->stride[2]; 144 | 145 | real *inputImages_data, *gradOutput_data, *gradInputImages_data; 146 | inputImages_data = THTensor_(data)(inputImages); 147 | gradOutput_data = THTensor_(data)(gradOutput); 148 | // grids_data = THTensor_(data)(grids); 149 | // gradGrids_data = THTensor_(data)(gradGrids); 150 | gradInputImages_data = THTensor_(data)(gradInputImages); 151 | 152 | int b, yOut, xOut; 153 | 154 | for(b=0; b < batchsize; b++) 155 | { 156 | for(yOut=0; yOut < gradOutput_height; yOut++) 157 | { 158 | for(xOut=0; xOut < gradOutput_width; xOut++) 159 | { 160 | //read the grid 161 | //real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth]; 162 | //real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1]; 163 | 164 | // get the weights for interpolation 165 | int yInTopLeft, xInTopLeft; 166 | real yWeightTopLeft, xWeightTopLeft; 167 | 168 | real xcoord = (inputImages_width - 1)*xOut / (gradOutput_width -1); 169 | xInTopLeft = floor(xcoord); 170 | xWeightTopLeft = 1 - (xcoord - xInTopLeft); 171 | 172 | real ycoord = (inputImages_height -1)*yOut / (gradOutput_height -1); 173 | yInTopLeft = floor(ycoord); 174 | yWeightTopLeft = 1 - (ycoord - yInTopLeft); 175 | 176 | 177 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft; 178 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth; 179 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight; 180 | const int inBottomRightAddress = inBottomLeftAddress + 
inputImages_strideWidth; 181 | 182 | const int gradInputImagesTopLeftAddress = gradInputImages_strideBatch * b + gradInputImages_strideHeight * yInTopLeft + gradInputImages_strideWidth * xInTopLeft; 183 | const int gradInputImagesTopRightAddress = gradInputImagesTopLeftAddress + gradInputImages_strideWidth; 184 | const int gradInputImagesBottomLeftAddress = gradInputImagesTopLeftAddress + gradInputImages_strideHeight; 185 | const int gradInputImagesBottomRightAddress = gradInputImagesBottomLeftAddress + gradInputImages_strideWidth; 186 | 187 | const int gradOutputAddress = gradOutput_strideBatch * b + gradOutput_strideHeight * yOut + gradOutput_strideWidth * xOut; 188 | 189 | real topLeftDotProduct = 0; 190 | real topRightDotProduct = 0; 191 | real bottomLeftDotProduct = 0; 192 | real bottomRightDotProduct = 0; 193 | 194 | real v=0; 195 | real inTopLeft=0; 196 | real inTopRight=0; 197 | real inBottomLeft=0; 198 | real inBottomRight=0; 199 | 200 | // we are careful with the boundaries 201 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 202 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 203 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 204 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 205 | 206 | int t; 207 | 208 | for(t=0; t= 7.0", 18 | "nn >= 1.0", 19 | } 20 | 21 | build = { 22 | type = "command", 23 | build_command = [[ 24 | cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." 
-DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) 25 | ]], 26 | install_command = "cd build && $(MAKE) install" 27 | } 28 | -------------------------------------------------------------------------------- /extras/spybhwd/test.lua: -------------------------------------------------------------------------------- 1 | -- you can easily test specific units like this: 2 | -- th -lnn -e "nn.test{'LookupTable'}" 3 | -- th -lnn -e "nn.test{'LookupTable', 'Add'}" 4 | 5 | local mytester = torch.Tester() 6 | local jac 7 | local sjac 8 | 9 | local precision = 1e-5 10 | local expprecision = 1e-4 11 | 12 | local stntest = {} 13 | 14 | function stntest.AffineGridGeneratorBHWD_batch() 15 | local nframes = torch.random(2,10) 16 | local height = torch.random(2,5) 17 | local width = torch.random(2,5) 18 | local input = torch.zeros(nframes, 2, 3):uniform() 19 | local module = nn.AffineGridGeneratorBHWD(height, width) 20 | 21 | local err = jac.testJacobian(module,input) 22 | mytester:assertlt(err,precision, 'error on state ') 23 | 24 | -- IO 25 | local ferr,berr = jac.testIO(module,input) 26 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') 27 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') 28 | 29 | end 30 | 31 | function stntest.AffineGridGeneratorBHWD_single() 32 | local height = torch.random(2,5) 33 | local width = torch.random(2,5) 34 | local input = torch.zeros(2, 3):uniform() 35 | local module = nn.AffineGridGeneratorBHWD(height, width) 36 | 37 | local err = jac.testJacobian(module,input) 38 | mytester:assertlt(err,precision, 'error on state ') 39 | 40 | -- IO 41 | local ferr,berr = jac.testIO(module,input) 42 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') 43 | mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err ') 44 | 45 | end 46 | 47 | function stntest.BilinearSamplerBHWD_batch() 48 | local nframes = torch.random(2,10) 49 | local height = torch.random(1,5) 50 | local width = torch.random(1,5) 51 | local channels = torch.random(1,6) 52 | local inputImages = torch.zeros(nframes, height, width, channels):uniform() 53 | local grids = torch.zeros(nframes, height, width, 2):uniform(-1, 1) 54 | local module = nn.BilinearSamplerBHWD() 55 | 56 | -- test input images (first element of input table) 57 | module._updateOutput = module.updateOutput 58 | function module:updateOutput(input) 59 | return self:_updateOutput({input, grids}) 60 | end 61 | 62 | module._updateGradInput = module.updateGradInput 63 | function module:updateGradInput(input, gradOutput) 64 | self:_updateGradInput({input, grids}, gradOutput) 65 | return self.gradInput[1] 66 | end 67 | 68 | local errImages = jac.testJacobian(module,inputImages) 69 | mytester:assertlt(errImages,precision, 'error on state ') 70 | 71 | -- test grids (second element of input table) 72 | function module:updateOutput(input) 73 | return self:_updateOutput({inputImages, input}) 74 | end 75 | 76 | function module:updateGradInput(input, gradOutput) 77 | self:_updateGradInput({inputImages, input}, gradOutput) 78 | return self.gradInput[2] 79 | end 80 | 81 | local errGrids = jac.testJacobian(module,grids) 82 | mytester:assertlt(errGrids,precision, 'error on state ') 83 | end 84 | 85 | function stntest.BilinearSamplerBHWD_single() 86 | local height = torch.random(1,5) 87 | local width = torch.random(1,5) 88 | local channels = torch.random(1,6) 89 | local inputImages = torch.zeros(height, width, channels):uniform() 90 | local grids = torch.zeros(height, width, 2):uniform(-1, 1) 91 | local module = nn.BilinearSamplerBHWD() 92 | 93 | -- test input images (first element of input table) 94 | module._updateOutput = module.updateOutput 95 | function module:updateOutput(input) 96 | return self:_updateOutput({input, grids}) 97 | end 
98 | 99 | module._updateGradInput = module.updateGradInput 100 | function module:updateGradInput(input, gradOutput) 101 | self:_updateGradInput({input, grids}, gradOutput) 102 | return self.gradInput[1] 103 | end 104 | 105 | local errImages = jac.testJacobian(module,inputImages) 106 | mytester:assertlt(errImages,precision, 'error on state ') 107 | 108 | -- test grids (second element of input table) 109 | function module:updateOutput(input) 110 | return self:_updateOutput({inputImages, input}) 111 | end 112 | 113 | function module:updateGradInput(input, gradOutput) 114 | self:_updateGradInput({inputImages, input}, gradOutput) 115 | return self.gradInput[2] 116 | end 117 | 118 | local errGrids = jac.testJacobian(module,grids) 119 | mytester:assertlt(errGrids,precision, 'error on state ') 120 | end 121 | 122 | function stntest.AffineTransformMatrixGenerator_batch() 123 | -- test all possible transformations 124 | for _,useRotation in pairs{true,false} do 125 | for _,useScale in pairs{true,false} do 126 | for _,useTranslation in pairs{true,false} do 127 | local currTest = '' 128 | if useRotation then currTest = currTest..'rotation ' end 129 | if useScale then currTest = currTest..'scale ' end 130 | if useTranslation then currTest = currTest..'translation' end 131 | if currTest=='' then currTest = 'full' end 132 | 133 | local nbNeededParams = 0 134 | if useRotation then nbNeededParams = nbNeededParams + 1 end 135 | if useScale then nbNeededParams = nbNeededParams + 1 end 136 | if useTranslation then nbNeededParams = nbNeededParams + 2 end 137 | if nbNeededParams == 0 then nbNeededParams = 6 end -- full affine case 138 | 139 | local nframes = torch.random(2,10) 140 | local params = torch.zeros(nframes,nbNeededParams):uniform() 141 | local module = nn.AffineTransformMatrixGenerator(useRotation,useScale,useTranslation) 142 | 143 | local err = jac.testJacobian(module,params) 144 | mytester:assertlt(err,precision, 'error on state for test '..currTest) 145 | 146 | -- IO 147 | 
local ferr,berr = jac.testIO(module,params) 148 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for test '..currTest) 149 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err for test '..currTest) 150 | 151 | end 152 | end 153 | end 154 | end 155 | 156 | function stntest.AffineTransformMatrixGenerator_single() 157 | -- test all possible transformations 158 | for _,useRotation in pairs{true,false} do 159 | for _,useScale in pairs{true,false} do 160 | for _,useTranslation in pairs{true,false} do 161 | local currTest = '' 162 | if useRotation then currTest = currTest..'rotation ' end 163 | if useScale then currTest = currTest..'scale ' end 164 | if useTranslation then currTest = currTest..'translation' end 165 | if currTest=='' then currTest = 'full' end 166 | 167 | local nbNeededParams = 0 168 | if useRotation then nbNeededParams = nbNeededParams + 1 end 169 | if useScale then nbNeededParams = nbNeededParams + 1 end 170 | if useTranslation then nbNeededParams = nbNeededParams + 2 end 171 | if nbNeededParams == 0 then nbNeededParams = 6 end -- full affine case 172 | 173 | local params = torch.zeros(nbNeededParams):uniform() 174 | local module = nn.AffineTransformMatrixGenerator(useRotation,useScale,useTranslation) 175 | 176 | local err = jac.testJacobian(module,params) 177 | mytester:assertlt(err,precision, 'error on state for test '..currTest) 178 | 179 | -- IO 180 | local ferr,berr = jac.testIO(module,params) 181 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for test '..currTest) 182 | mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err for test '..currTest) 183 | 184 | end 185 | end 186 | end 187 | end 188 | 189 | mytester:add(stntest) 190 | 191 | if not nn then 192 | require 'nn' 193 | jac = nn.Jacobian 194 | sjac = nn.SparseJacobian 195 | mytester:run() 196 | else 197 | jac = nn.Jacobian 198 | sjac = nn.SparseJacobian 199 | function stn.test(tests) 200 | -- randomize stuff 201 | math.randomseed(os.time()) 202 | mytester:run(tests) 203 | return mytester 204 | end 205 | end 206 | -------------------------------------------------------------------------------- /extras/spybhwd/utils.c: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | THCState* getCutorchState(lua_State* L) 4 | { 5 | lua_getglobal(L, "cutorch"); 6 | lua_getfield(L, -1, "getState"); 7 | lua_call(L, 0, 1); 8 | THCState *state = (THCState*) lua_touserdata(L, -1); 9 | lua_pop(L, 2); 10 | return state; 11 | } 12 | -------------------------------------------------------------------------------- /extras/spybhwd/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef CUNN_UTILS_H 2 | #define CUNN_UTILS_H 3 | 4 | #include 5 | #include "THCGeneral.h" 6 | 7 | THCState* getCutorchState(lua_State* L); 8 | 9 | #endif 10 | -------------------------------------------------------------------------------- /extras/stnbhwd/AffineGridGeneratorBHWD.lua: -------------------------------------------------------------------------------- 1 | local AGG, parent = torch.class('nn.AffineGridGeneratorBHWD', 'nn.Module') 2 | 3 | --[[ 4 | AffineGridGeneratorBHWD(height, width) : 5 | AffineGridGeneratorBHWD:updateOutput(transformMatrix) 6 | AffineGridGeneratorBHWD:updateGradInput(transformMatrix, gradGrids) 7 | 8 | AffineGridGeneratorBHWD will take 2x3 an affine image transform matrix (homogeneous 9 | coordinates) as input, and output a grid, in normalized coordinates* that, once used 10 | with the Bilinear Sampler, will result in an 
affine transform. 11 | 12 | AffineGridGenerator 13 | - takes (B,2,3)-shaped transform matrices as input (B=batch). 14 | - outputs a grid in BHWD layout, that can be used directly with BilinearSamplerBHWD 15 | - initialization of the previous layer should biased towards the identity transform : 16 | | 1 0 0 | 17 | | 0 1 0 | 18 | 19 | *: normalized coordinates [-1,1] correspond to the boundaries of the input image. 20 | ]] 21 | 22 | function AGG:__init(height, width) 23 | parent.__init(self) 24 | assert(height > 1) 25 | assert(width > 1) 26 | self.height = height 27 | self.width = width 28 | 29 | self.baseGrid = torch.Tensor(height, width, 3) 30 | for i=1,self.height do 31 | self.baseGrid:select(3,1):select(1,i):fill(-1 + (i-1)/(self.height-1) * 2) 32 | end 33 | for j=1,self.width do 34 | self.baseGrid:select(3,2):select(2,j):fill(-1 + (j-1)/(self.width-1) * 2) 35 | end 36 | self.baseGrid:select(3,3):fill(1) 37 | self.batchGrid = torch.Tensor(1, height, width, 3):copy(self.baseGrid) 38 | end 39 | 40 | local function addOuterDim(t) 41 | local sizes = t:size() 42 | local newsizes = torch.LongStorage(sizes:size()+1) 43 | newsizes[1]=1 44 | for i=1,sizes:size() do 45 | newsizes[i+1]=sizes[i] 46 | end 47 | return t:view(newsizes) 48 | end 49 | 50 | function AGG:updateOutput(_transformMatrix) 51 | local transformMatrix 52 | if _transformMatrix:nDimension()==2 then 53 | transformMatrix = addOuterDim(_transformMatrix) 54 | else 55 | transformMatrix = _transformMatrix 56 | end 57 | assert(transformMatrix:nDimension()==3 58 | and transformMatrix:size(2)==2 59 | and transformMatrix:size(3)==3 60 | , 'please input affine transform matrices (bx2x3)') 61 | local batchsize = transformMatrix:size(1) 62 | 63 | if self.batchGrid:size(1) ~= batchsize then 64 | self.batchGrid:resize(batchsize, self.height, self.width, 3) 65 | for i=1,batchsize do 66 | self.batchGrid:select(1,i):copy(self.baseGrid) 67 | end 68 | end 69 | 70 | self.output:resize(batchsize, self.height, self.width, 2) 71 | 
local flattenedBatchGrid = self.batchGrid:view(batchsize, self.width*self.height, 3) 72 | local flattenedOutput = self.output:view(batchsize, self.width*self.height, 2) 73 | torch.bmm(flattenedOutput, flattenedBatchGrid, transformMatrix:transpose(2,3)) 74 | if _transformMatrix:nDimension()==2 then 75 | self.output = self.output:select(1,1) 76 | end 77 | return self.output 78 | end 79 | 80 | function AGG:updateGradInput(_transformMatrix, _gradGrid) 81 | local transformMatrix, gradGrid 82 | if _transformMatrix:nDimension()==2 then 83 | transformMatrix = addOuterDim(_transformMatrix) 84 | gradGrid = addOuterDim(_gradGrid) 85 | else 86 | transformMatrix = _transformMatrix 87 | gradGrid = _gradGrid 88 | end 89 | 90 | local batchsize = transformMatrix:size(1) 91 | local flattenedGradGrid = gradGrid:view(batchsize, self.width*self.height, 2) 92 | local flattenedBatchGrid = self.batchGrid:view(batchsize, self.width*self.height, 3) 93 | self.gradInput:resizeAs(transformMatrix):zero() 94 | self.gradInput:baddbmm(flattenedGradGrid:transpose(2,3), flattenedBatchGrid) 95 | -- torch.baddbmm doesn't work on cudatensors for some reason 96 | 97 | if _transformMatrix:nDimension()==2 then 98 | self.gradInput = self.gradInput:select(1,1) 99 | end 100 | 101 | return self.gradInput 102 | end 103 | -------------------------------------------------------------------------------- /extras/stnbhwd/AffineTransformMatrixGenerator.lua: -------------------------------------------------------------------------------- 1 | local ATMG, parent = torch.class('nn.AffineTransformMatrixGenerator', 'nn.Module') 2 | 3 | --[[ 4 | AffineTransformMatrixGenerator(useRotation, useScale, useTranslation) : 5 | AffineTransformMatrixGenerator:updateOutput(transformParams) 6 | AffineTransformMatrixGenerator:updateGradInput(transformParams, gradParams) 7 | 8 | This module can be used in between the localisation network (that outputs the 9 | parameters of the transformation) and the AffineGridGeneratorBHWD (that 
expects 10 | an affine transform matrix as input). 11 | 12 | The goal is to be able to use only specific transformations or a combination of them. 13 | 14 | If no specific transformation is specified, it uses a fully parametrized 15 | linear transformation and thus expects 6 parameters as input. In this case 16 | the module is equivalent to nn.View(2,3):setNumInputDims(2). 17 | 18 | Any combination of the 3 transformations (rotation, scale and/or translation) 19 | can be used. The transform parameters must be supplied in the following order: 20 | rotation (1 param), scale (1 param) then translation (2 params). 21 | 22 | Example: 23 | AffineTransformMatrixGenerator(true,false,true) expects as input a tensor of 24 | if size (B, 3) containing (rotationAngle, translationX, translationY). 25 | ]] 26 | 27 | function ATMG:__init(useRotation, useScale, useTranslation) 28 | parent.__init(self) 29 | 30 | -- if no specific transformation, use fully parametrized version 31 | self.fullMode = not(useRotation or useScale or useTranslation) 32 | 33 | if not self.fullMode then 34 | self.useRotation = useRotation 35 | self.useScale = useScale 36 | self.useTranslation = useTranslation 37 | end 38 | end 39 | 40 | function ATMG:check(input) 41 | if self.fullMode then 42 | assert(input:size(2)==6, 'Expected 6 parameters, got ' .. input:size(2)) 43 | else 44 | local numberParameters = 0 45 | if self.useRotation then 46 | numberParameters = numberParameters + 1 47 | end 48 | if self.useScale then 49 | numberParameters = numberParameters + 1 50 | end 51 | if self.useTranslation then 52 | numberParameters = numberParameters + 2 53 | end 54 | assert(input:size(2)==numberParameters, 'Expected '..numberParameters.. 55 | ' parameters, got ' .. 
input:size(2)) 56 | end 57 | end 58 | 59 | local function addOuterDim(t) 60 | local sizes = t:size() 61 | local newsizes = torch.LongStorage(sizes:size()+1) 62 | newsizes[1]=1 63 | for i=1,sizes:size() do 64 | newsizes[i+1]=sizes[i] 65 | end 66 | return t:view(newsizes) 67 | end 68 | 69 | function ATMG:updateOutput(_tranformParams) 70 | local transformParams 71 | if _tranformParams:nDimension()==1 then 72 | transformParams = addOuterDim(_tranformParams) 73 | else 74 | transformParams = _tranformParams 75 | end 76 | 77 | self:check(transformParams) 78 | local batchSize = transformParams:size(1) 79 | 80 | if self.fullMode then 81 | self.output = transformParams:view(batchSize, 2, 3) 82 | else 83 | local completeTransformation = torch.zeros(batchSize,3,3):typeAs(transformParams) 84 | completeTransformation:select(3,1):select(2,1):add(1) 85 | completeTransformation:select(3,2):select(2,2):add(1) 86 | completeTransformation:select(3,3):select(2,3):add(1) 87 | local transformationBuffer = torch.Tensor(batchSize,3,3):typeAs(transformParams) 88 | 89 | local paramIndex = 1 90 | if self.useRotation then 91 | local alphas = transformParams:select(2, paramIndex) 92 | paramIndex = paramIndex + 1 93 | 94 | transformationBuffer:zero() 95 | transformationBuffer:select(3,3):select(2,3):add(1) 96 | local cosines = torch.cos(alphas) 97 | local sinuses = torch.sin(alphas) 98 | transformationBuffer:select(3,1):select(2,1):copy(cosines) 99 | transformationBuffer:select(3,2):select(2,2):copy(cosines) 100 | transformationBuffer:select(3,1):select(2,2):copy(sinuses) 101 | transformationBuffer:select(3,2):select(2,1):copy(-sinuses) 102 | 103 | completeTransformation = torch.bmm(completeTransformation, transformationBuffer) 104 | end 105 | self.rotationOutput = completeTransformation:narrow(2,1,2):narrow(3,1,2):clone() 106 | 107 | if self.useScale then 108 | local scaleFactors = transformParams:select(2,paramIndex) 109 | paramIndex = paramIndex + 1 110 | 111 | transformationBuffer:zero() 112 
| transformationBuffer:select(3,1):select(2,1):copy(scaleFactors) 113 | transformationBuffer:select(3,2):select(2,2):copy(scaleFactors) 114 | transformationBuffer:select(3,3):select(2,3):add(1) 115 | 116 | completeTransformation = torch.bmm(completeTransformation, transformationBuffer) 117 | end 118 | self.scaleOutput = completeTransformation:narrow(2,1,2):narrow(3,1,2):clone() 119 | 120 | if self.useTranslation then 121 | local txs = transformParams:select(2,paramIndex) 122 | local tys = transformParams:select(2,paramIndex+1) 123 | 124 | transformationBuffer:zero() 125 | transformationBuffer:select(3,1):select(2,1):add(1) 126 | transformationBuffer:select(3,2):select(2,2):add(1) 127 | transformationBuffer:select(3,3):select(2,3):add(1) 128 | transformationBuffer:select(3,3):select(2,1):copy(txs) 129 | transformationBuffer:select(3,3):select(2,2):copy(tys) 130 | 131 | completeTransformation = torch.bmm(completeTransformation, transformationBuffer) 132 | end 133 | 134 | self.output=completeTransformation:narrow(2,1,2) 135 | end 136 | 137 | if _tranformParams:nDimension()==1 then 138 | self.output = self.output:select(1,1) 139 | end 140 | return self.output 141 | end 142 | 143 | 144 | function ATMG:updateGradInput(_tranformParams, _gradParams) 145 | local transformParams, gradParams 146 | if _tranformParams:nDimension()==1 then 147 | transformParams = addOuterDim(_tranformParams) 148 | gradParams = addOuterDim(_gradParams):clone() 149 | else 150 | transformParams = _tranformParams 151 | gradParams = _gradParams:clone() 152 | end 153 | 154 | local batchSize = transformParams:size(1) 155 | if self.fullMode then 156 | self.gradInput = gradParams:view(batchSize, 6) 157 | else 158 | local paramIndex = transformParams:size(2) 159 | self.gradInput:resizeAs(transformParams) 160 | if self.useTranslation then 161 | local gradInputTranslationParams = self.gradInput:narrow(2,paramIndex-1,2) 162 | local tParams = torch.Tensor(batchSize, 1, 2):typeAs(transformParams) 163 | 
tParams:select(3,1):copy(transformParams:select(2,paramIndex-1)) 164 | tParams:select(3,2):copy(transformParams:select(2,paramIndex)) 165 | paramIndex = paramIndex-2 166 | 167 | local selectedOutput = self.scaleOutput 168 | local selectedGradParams = gradParams:narrow(2,1,2):narrow(3,3,1):transpose(2,3) 169 | gradInputTranslationParams:copy(torch.bmm(selectedGradParams, selectedOutput)) 170 | 171 | local gradientCorrection = torch.bmm(selectedGradParams:transpose(2,3), tParams) 172 | gradParams:narrow(3,1,2):narrow(2,1,2):add(1,gradientCorrection) 173 | end 174 | 175 | if self.useScale then 176 | local gradInputScaleparams = self.gradInput:narrow(2,paramIndex,1) 177 | local sParams = transformParams:select(2,paramIndex) 178 | paramIndex = paramIndex-1 179 | 180 | local selectedOutput = self.rotationOutput 181 | local selectedGradParams = gradParams:narrow(2,1,2):narrow(3,1,2) 182 | gradInputScaleparams:copy(torch.cmul(selectedOutput, selectedGradParams):sum(2):sum(3)) 183 | 184 | gradParams:select(3,1):select(2,1):cmul(sParams) 185 | gradParams:select(3,2):select(2,1):cmul(sParams) 186 | gradParams:select(3,1):select(2,2):cmul(sParams) 187 | gradParams:select(3,2):select(2,2):cmul(sParams) 188 | end 189 | 190 | if self.useRotation then 191 | local gradInputRotationParams = self.gradInput:narrow(2,paramIndex,1) 192 | local rParams = transformParams:select(2,paramIndex) 193 | 194 | local rotationDerivative = torch.zeros(batchSize, 2, 2):typeAs(rParams) 195 | torch.sin(rotationDerivative:select(3,1):select(2,1),-rParams) 196 | torch.sin(rotationDerivative:select(3,2):select(2,2),-rParams) 197 | torch.cos(rotationDerivative:select(3,1):select(2,2),rParams) 198 | torch.cos(rotationDerivative:select(3,2):select(2,1),rParams):mul(-1) 199 | local selectedGradParams = gradParams:narrow(2,1,2):narrow(3,1,2) 200 | gradInputRotationParams:copy(torch.cmul(rotationDerivative,selectedGradParams):sum(2):sum(3)) 201 | end 202 | end 203 | 204 | if _tranformParams:nDimension()==1 
then 205 | self.gradInput = self.gradInput:select(1,1) 206 | end 207 | return self.gradInput 208 | end 209 | 210 | 211 | -------------------------------------------------------------------------------- /extras/stnbhwd/BilinearSamplerBHWD.lua: -------------------------------------------------------------------------------- 1 | local BilinearSamplerBHWD, parent = torch.class('nn.BilinearSamplerBHWD', 'nn.Module') 2 | 3 | --[[ 4 | BilinearSamplerBHWD() : 5 | BilinearSamplerBHWD:updateOutput({inputImages, grids}) 6 | BilinearSamplerBHWD:updateGradInput({inputImages, grids}, gradOutput) 7 | 8 | BilinearSamplerBHWD will perform bilinear sampling of the input images according to the 9 | normalized coordinates provided in the grid. Output will be of same size as the grids, 10 | with as many features as the input images. 11 | 12 | - inputImages has to be in BHWD layout 13 | 14 | - grids have to be in BHWD layout, with dim(D)=2 15 | - grids contains, for each sample (first dim), the normalized coordinates of the output wrt the input sample 16 | - first coordinate is X coordinate, second is Y (this is the order in which the CPU kernel in generic/BilinearSamplerBHWD.c reads them; NOTE(review): confirm the CUDA kernel uses the same order) 17 | - normalized coordinates : (-1,-1) points to top left, (1,-1) points to top right 18 | - if the normalized coordinates fall outside of the image, then output will be filled with zeros 19 | ]] 20 | --[[ self.gradInput is a table holding {gradInputImages, gradGrids}. ]] 21 | function BilinearSamplerBHWD:__init() 22 | parent.__init(self) 23 | self.gradInput={} 24 | end 25 | --[[ check: asserts batched (4D) inputs with contiguous images, equal batch sizes and 2 coordinates per grid location; when gradOutput is given, also that its B, H and W match grids. ]] 26 | function BilinearSamplerBHWD:check(input, gradOutput) 27 | local inputImages = input[1] 28 | local grids = input[2] 29 | 30 | assert(inputImages:isContiguous(), 'Input images have to be contiguous') 31 | assert(inputImages:nDimension()==4) 32 | assert(grids:nDimension()==4) 33 | assert(inputImages:size(1)==grids:size(1)) -- batch 34 | assert(grids:size(4)==2) -- coordinates 35 | 36 | if gradOutput then 37 | assert(grids:size(1)==gradOutput:size(1)) 38 | assert(grids:size(2)==gradOutput:size(2)) 39 | assert(grids:size(3)==gradOutput:size(3)) 40 | end 41 | end 42 | --[[ addOuterDim: returns a view of t with a leading singleton batch dimension; it shares t's storage. ]] 43 | local
function addOuterDim(t) 44 | local sizes = t:size() 45 | local newsizes = torch.LongStorage(sizes:size()+1) 46 | newsizes[1]=1 47 | for i=1,sizes:size() do 48 | newsizes[i+1]=sizes[i] 49 | end 50 | return t:view(newsizes) 51 | end 52 | --[[ Forward pass. A single sample (3D inputImages and grids) is viewed as a batch of one, and the output is squeezed back to 3D. The output is B x gridH x gridW x channels of inputImages. ]] 53 | function BilinearSamplerBHWD:updateOutput(input) 54 | local _inputImages = input[1] 55 | local _grids = input[2] 56 | 57 | local inputImages, grids 58 | if _inputImages:nDimension()==3 then 59 | inputImages = addOuterDim(_inputImages) 60 | grids = addOuterDim(_grids) 61 | else 62 | inputImages = _inputImages 63 | grids = _grids 64 | end 65 | 66 | local input = {inputImages, grids} 67 | 68 | self:check(input) 69 | 70 | self.output:resize(inputImages:size(1), grids:size(2), grids:size(3), inputImages:size(4)) 71 | 72 | inputImages.nn.BilinearSamplerBHWD_updateOutput(self, inputImages, grids, self.output) 73 | 74 | if _inputImages:nDimension()==3 then 75 | self.output=self.output:select(1,1) 76 | end 77 | 78 | return self.output 79 | end 80 | --[[ Backward pass: returns {gradInputImages, gradGrids}, both resized like their inputs and zeroed before the C call. Unbatched (3D) inputs are handled exactly as in updateOutput. ]] 81 | function BilinearSamplerBHWD:updateGradInput(_input, _gradOutput) 82 | local _inputImages = _input[1] 83 | local _grids = _input[2] 84 | 85 | local inputImages, grids, gradOutput 86 | if _inputImages:nDimension()==3 then 87 | inputImages = addOuterDim(_inputImages) 88 | grids = addOuterDim(_grids) 89 | gradOutput = addOuterDim(_gradOutput) 90 | else 91 | inputImages = _inputImages 92 | grids = _grids 93 | gradOutput = _gradOutput 94 | end 95 | 96 | local input = {inputImages, grids} 97 | 98 | self:check(input, gradOutput) 99 | for i=1,#input do 100 | self.gradInput[i] = self.gradInput[i] or input[1].new() 101 | self.gradInput[i]:resizeAs(input[i]):zero() 102 | end 103 | 104 | local gradInputImages = self.gradInput[1] 105 | local gradGrids = self.gradInput[2] 106 | 107 | inputImages.nn.BilinearSamplerBHWD_updateGradInput(self, inputImages, grids, gradInputImages, gradGrids, gradOutput) 108 | 109 | if _inputImages:nDimension()==3 then --[[ squeeze under the same condition that added the batch dimension above ]] 110 | self.gradInput[1]=self.gradInput[1]:select(1,1) 111
| self.gradInput[2]=self.gradInput[2]:select(1,1) 112 | end 113 | 114 | return self.gradInput 115 | end 116 | -------------------------------------------------------------------------------- /extras/stnbhwd/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR) 2 | CMAKE_POLICY(VERSION 2.8) 3 | 4 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_FORCE_INLINES") 5 | SET(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake" "${CMAKE_MODULE_PATH}") 6 | 7 | FIND_PACKAGE(Torch REQUIRED) 8 | 9 | # Flags 10 | # When using MSVC 11 | IF(MSVC) 12 | # we want to respect the standard, and we are bored of those **** . 13 | ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1) 14 | ENDIF(MSVC) 15 | 16 | # OpenMP support? 17 | SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?") 18 | IF (APPLE AND CMAKE_COMPILER_IS_GNUCC) 19 | EXEC_PROGRAM (uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION) 20 | STRING (REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION}) 21 | MESSAGE (STATUS "MAC OS Darwin Version: ${DARWIN_VERSION}") 22 | IF (DARWIN_VERSION GREATER 9) 23 | SET(APPLE_OPENMP_SUCKS 1) 24 | ENDIF (DARWIN_VERSION GREATER 9) 25 | EXECUTE_PROCESS (COMMAND ${CMAKE_C_COMPILER} -dumpversion 26 | OUTPUT_VARIABLE GCC_VERSION) 27 | IF (APPLE_OPENMP_SUCKS AND GCC_VERSION VERSION_LESS 4.6.2) 28 | MESSAGE(STATUS "Warning: Disabling OpenMP (unstable with this version of GCC)") 29 | MESSAGE(STATUS " Install GCC >= 4.6.2 or change your OS to enable OpenMP") 30 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unknown-pragmas") 31 | SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?"
FORCE) 32 | ENDIF () 33 | ENDIF () 34 | 35 | IF (WITH_OPENMP) 36 | FIND_PACKAGE(OpenMP) 37 | IF(OPENMP_FOUND) 38 | MESSAGE(STATUS "Compiling with OpenMP support") 39 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 40 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 41 | SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 42 | ENDIF(OPENMP_FOUND) 43 | ENDIF (WITH_OPENMP) 44 | 45 | LINK_DIRECTORIES("${Torch_INSTALL_LIB}") 46 | 47 | SET(src init.c) 48 | FILE(GLOB luasrc *.lua) 49 | ADD_TORCH_PACKAGE(stn "${src}" "${luasrc}") 50 | TARGET_LINK_LIBRARIES(stn luaT TH) 51 | 52 | 53 | FIND_PACKAGE(CUDA 5.5) 54 | 55 | IF (CUDA_FOUND) 56 | LIST(APPEND CUDA_NVCC_FLAGS "-arch=sm_20") 57 | LIST(APPEND CUDA_NVCC_FLAGS "-Xcompiler -std=c++98") 58 | 59 | INCLUDE_DIRECTORIES("${Torch_INSTALL_INCLUDE}/THC") 60 | SET(src-cuda init.cu) 61 | CUDA_ADD_LIBRARY(custn MODULE ${src-cuda}) 62 | TARGET_LINK_LIBRARIES(custn luaT THC TH) 63 | IF(APPLE) 64 | SET_TARGET_PROPERTIES(custn PROPERTIES 65 | LINK_FLAGS "-undefined dynamic_lookup") 66 | ENDIF() 67 | ### Torch packages supposes libraries prefix is "lib" 68 | SET_TARGET_PROPERTIES(custn PROPERTIES 69 | PREFIX "lib" 70 | IMPORT_PREFIX "lib") 71 | 72 | INSTALL(TARGETS custn 73 | RUNTIME DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}" 74 | LIBRARY DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}") 75 | ENDIF(CUDA_FOUND) 76 | -------------------------------------------------------------------------------- /extras/stnbhwd/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 qassemoquab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or 
sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /extras/stnbhwd/README.md: -------------------------------------------------------------------------------- 1 | # stnbhwd 2 | 3 | ## Main modules 4 | 5 | These are the basic modules (BHWD layout) needed to implement a Spatial Transformer Network (Jaderberg et al.) http://arxiv.org/abs/1506.02025 6 | 7 | ``` lua 8 | require 'stn' 9 | 10 | nn.AffineGridGeneratorBHWD(height, width) 11 | -- takes B x 2 x 3 affine transform matrices as input, 12 | -- outputs a height x width grid in normalized [-1,1] coordinates 13 | -- output layout is B,H,W,2 where the first coordinate in the 4th dimension is y, and the second is x 14 | 15 | nn.BilinearSamplerBHWD() 16 | -- takes a table {inputImages, grids} as inputs 17 | -- outputs the interpolated images according to the grids 18 | -- inputImages is a batch of samples in BHWD layout 19 | -- grids is a batch of grids (output of AffineGridGeneratorBHWD) 20 | -- output is also BHWD 21 | ``` 22 | 23 | ## Advanced module 24 | 25 | This module allows the user to put a constraint on the possible transformations. 
26 | It should be placed between the localisation network and the grid generator. 27 | 28 | ``` lua 29 | require 'stn' 30 | 31 | nn.AffineTransformMatrixGenerator(useRotation, useScale, useTranslation) 32 | -- takes a B x nbParams tensor as inputs 33 | -- nbParams depends on the constrained transformation 34 | -- The parameters for the selected transformation(s) should be supplied in the 35 | -- following order: rotationAngle, scaleFactor, translationX, translationY 36 | -- If no transformation is specified, it generates a generic affine transformation (nbParams = 6) 37 | -- outputs B x 2 x 3 affine transform matrices 38 | ``` 39 | 40 | 41 | If this code is useful to your research, please cite this repository. 42 | -------------------------------------------------------------------------------- /extras/stnbhwd/ScaleBHWD.lua: -------------------------------------------------------------------------------- 1 | local ScaleBHWD, parent = torch.class('nn.ScaleBHWD', 'nn.Module') 2 | 3 | --[[ 4 | ScaleBHWD() : 5 | ScaleBHWD:updateOutput({inputImages, grids}) 6 | ScaleBHWD:updateGradInput({inputImages, grids}, gradOutput) 7 | 8 | ScaleBHWD will perform bilinear sampling of the input images according to the 9 | normalized coordinates provided in the grid. Output will be of same size as the grids, 10 | with as many features as the input images. NOTE(review): this description mirrors the one in BilinearSamplerBHWD.lua almost word for word; verify it against generic/ScaleBHWD.c before relying on it.
11 | 12 | - inputImages has to be in BHWD layout 13 | 14 | - grids have to be in BHWD layout, with dim(D)=2 15 | - grids contains, for each sample (first dim), the normalized coordinates of the output wrt the input sample 16 | - first coordinate is Y coordinate, second is X 17 | - normalized coordinates : (-1,-1) points to top left, (-1,1) points to top right 18 | - if the normalized coordinates fall outside of the image, then output will be filled with zeros 19 | ]] 20 | --[[ self.gradInput is a table holding {gradInputImages, gradGrids}. ]] 21 | function ScaleBHWD:__init() 22 | parent.__init(self) 23 | self.gradInput={} 24 | end 25 | --[[ check: asserts batched (4D) inputs with contiguous images, equal batch sizes and 2 coordinates per grid location; when gradOutput is given, also that its B, H and W match grids. ]] 26 | function ScaleBHWD:check(input, gradOutput) 27 | local inputImages = input[1] 28 | local grids = input[2] 29 | 30 | assert(inputImages:isContiguous(), 'Input images have to be contiguous') 31 | assert(inputImages:nDimension()==4) 32 | assert(grids:nDimension()==4) 33 | assert(inputImages:size(1)==grids:size(1)) -- batch 34 | assert(grids:size(4)==2) -- coordinates 35 | 36 | if gradOutput then 37 | assert(grids:size(1)==gradOutput:size(1)) 38 | assert(grids:size(2)==gradOutput:size(2)) 39 | assert(grids:size(3)==gradOutput:size(3)) 40 | end 41 | end 42 | --[[ addOuterDim: returns a view of t with a leading singleton batch dimension; it shares t's storage. ]] 43 | local function addOuterDim(t) 44 | local sizes = t:size() 45 | local newsizes = torch.LongStorage(sizes:size()+1) 46 | newsizes[1]=1 47 | for i=1,sizes:size() do 48 | newsizes[i+1]=sizes[i] 49 | end 50 | return t:view(newsizes) 51 | end 52 | --[[ Forward pass. A single sample (3D inputImages and grids) is viewed as a batch of one, and the output is squeezed back to 3D. The output is B x gridH x gridW x channels of inputImages. ]] 53 | function ScaleBHWD:updateOutput(input) 54 | local _inputImages = input[1] 55 | local _grids = input[2] 56 | 57 | local inputImages, grids 58 | if _inputImages:nDimension()==3 then 59 | inputImages = addOuterDim(_inputImages) 60 | grids = addOuterDim(_grids) 61 | else 62 | inputImages = _inputImages 63 | grids = _grids 64 | end 65 | 66 | local input = {inputImages, grids} 67 | 68 | self:check(input) 69 | 70 | self.output:resize(inputImages:size(1), grids:size(2), grids:size(3), inputImages:size(4)) 71 | 72 | inputImages.nn.ScaleBHWD_updateOutput(self, inputImages, grids, self.output) 73 | 74 | if
_inputImages:nDimension()==3 then 75 | self.output=self.output:select(1,1) 76 | end 77 | 78 | return self.output 79 | end 80 | --[[ Backward pass: returns {gradInputImages, gradGrids}, both resized like their inputs and zeroed before the C call. Unbatched (3D) inputs are handled exactly as in updateOutput. ]] 81 | function ScaleBHWD:updateGradInput(_input, _gradOutput) 82 | local _inputImages = _input[1] 83 | local _grids = _input[2] 84 | 85 | local inputImages, grids, gradOutput 86 | if _inputImages:nDimension()==3 then 87 | inputImages = addOuterDim(_inputImages) 88 | grids = addOuterDim(_grids) 89 | gradOutput = addOuterDim(_gradOutput) 90 | else 91 | inputImages = _inputImages 92 | grids = _grids 93 | gradOutput = _gradOutput 94 | end 95 | 96 | local input = {inputImages, grids} 97 | 98 | self:check(input, gradOutput) 99 | for i=1,#input do 100 | self.gradInput[i] = self.gradInput[i] or input[1].new() 101 | self.gradInput[i]:resizeAs(input[i]):zero() 102 | end 103 | 104 | local gradInputImages = self.gradInput[1] 105 | local gradGrids = self.gradInput[2] 106 | 107 | inputImages.nn.ScaleBHWD_updateGradInput(self, inputImages, grids, gradInputImages, gradGrids, gradOutput) 108 | 109 | if _inputImages:nDimension()==3 then --[[ squeeze under the same condition that added the batch dimension above ]] 110 | self.gradInput[1]=self.gradInput[1]:select(1,1) 111 | self.gradInput[2]=self.gradInput[2]:select(1,1) 112 | end 113 | 114 | return self.gradInput 115 | end 116 | -------------------------------------------------------------------------------- /extras/stnbhwd/demo/Optim.lua: -------------------------------------------------------------------------------- 1 | --[[ That would be the license for Optim.lua 2 | 3 | BSD License 4 | 5 | For fbcunn software 6 | 7 | Copyright (c) 2014, Facebook, Inc. All rights reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without modification, 10 | are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer.
14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name Facebook nor the names of its contributors may be used to 20 | endorse or promote products derived from this software without specific 21 | prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 24 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 25 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 27 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 29 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 30 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 32 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | ]] 34 | 35 | -- Copyright 2004-present Facebook. All Rights Reserved. 36 | 37 | local pl = require('pl.import_into')() 38 | 39 | -- from fblualib/fb/util/data.lua , copied here because fblualib is not rockspec ready yet. 40 | -- deepcopy routine that assumes the presence of a 'clone' method in user 41 | -- data should be used to deeply copy. This matches the behavior of Torch 42 | -- tensors. 
43 | local function deepcopy(x) 44 | local typename = type(x) 45 | if typename == "userdata" then 46 | return x:clone() 47 | end 48 | if typename == "table" then 49 | local retval = { } 50 | for k,v in pairs(x) do 51 | retval[deepcopy(k)] = deepcopy(v) 52 | end 53 | return retval 54 | end 55 | return x 56 | end 57 | 58 | local Optim = torch.class('nn.Optim') 59 | --[[ nn.Optim keeps a separate copy of the optim state for every weight and bias tensor of every module of the model. ]] 60 | 61 | -- Returns weight parameters and bias parameters and associated grad parameters 62 | -- for this module. Annotates the return values with flag marking parameter set 63 | -- as bias parameters set 64 | function Optim.weight_bias_parameters(module) 65 | local weight_params, bias_params 66 | if module.weight then 67 | weight_params = {module.weight, module.gradWeight} 68 | weight_params.is_bias = false 69 | end 70 | if module.bias then 71 | bias_params = {module.bias, module.gradBias} 72 | bias_params.is_bias = true 73 | end 74 | return {weight_params, bias_params} 75 | end 76 | 77 | -- The regular `optim` package relies on `getParameters`, which is a 78 | -- beastly abomination before all. This `optim` package uses separate 79 | -- optim state for each submodule of a `nn.Module`. 80 | function Optim:__init(model, optState, checkpoint_data) 81 | assert(model) 82 | assert(checkpoint_data or optState) 83 | assert(not (checkpoint_data and optState)) 84 | 85 | self.model = model 86 | self.modulesToOptState = {} 87 | -- Keep this around so we update it in setParameters 88 | self.originalOptState = optState 89 | 90 | -- Each module has some set of parameters and grad parameters. Since 91 | -- they may be allocated discontinuously, we need separate optState for 92 | -- each parameter tensor. self.modulesToOptState maps each module to 93 | -- a lua table of optState clones.
94 | if not checkpoint_data then 95 | self.model:apply(function(module) 96 | self.modulesToOptState[module] = { } 97 | local params = self.weight_bias_parameters(module) 98 | -- expects either an empty table or 2 element table, one for weights 99 | -- and one for biases 100 | assert(pl.tablex.size(params) == 0 or pl.tablex.size(params) == 2) 101 | for i, _ in ipairs(params) do 102 | self.modulesToOptState[module][i] = deepcopy(optState) 103 | if params[i] and params[i].is_bias then 104 | -- never regularize biases 105 | self.modulesToOptState[module][i].weightDecay = 0.0 106 | end 107 | end 108 | assert(module) 109 | assert(self.modulesToOptState[module]) 110 | end) 111 | else 112 | local state = checkpoint_data.optim_state 113 | local modules = {} 114 | self.model:apply(function(m) table.insert(modules, m) end) 115 | assert(pl.tablex.compare_no_order(modules, pl.tablex.keys(state))) 116 | self.modulesToOptState = state 117 | end 118 | end 119 | --[[ Serializable optimizer state; pass it back as checkpoint_data to the constructor to resume. ]] 120 | function Optim:save() 121 | return { 122 | optim_state = self.modulesToOptState 123 | } 124 | end 125 | --[[ Recursively replace every torch Tensor found in obj, nested tables included, by its conversion to type t. ]] 126 | local function _type_all(obj, t) 127 | for k, v in pairs(obj) do 128 | if type(v) == 'table' then 129 | _type_all(v, t) 130 | else 131 | local tn = torch.typename(v) 132 | if tn and tn:find('torch%..+Tensor') then 133 | obj[k] = v:type(t) 134 | end 135 | end 136 | end 137 | end 138 | --[[ Convert all per-module optimizer state tensors to type t. ]] 139 | function Optim:type(t) 140 | self.model:apply(function(module) 141 | local state= self.modulesToOptState[module] 142 | assert(state) 143 | _type_all(state, t) 144 | end) 145 | end 146 | --[[ CUDA device of the module's CudaTensor fields; asserts that they all agree. ]] 147 | local function get_device_for_module(mod) 148 | local dev_id = nil 149 | for name, val in pairs(mod) do 150 | if torch.typename(val) == 'torch.CudaTensor' then 151 | local this_dev = val:getDevice() 152 | if this_dev ~= 0 then 153 | -- _make sure the tensors are allocated consistently 154 | assert(dev_id == nil or dev_id == this_dev) 155 | dev_id = this_dev 156 | end 157 | end 158 | end 159 | return dev_id -- nil when no CudaTensor field of mod reports a non-zero
device. 160 | end 161 | --[[ Run f, switching to the module's CUDA device with cutorch.withDevice when one is found. ]] 162 | local function on_device_for_module(mod, f) 163 | local this_dev = get_device_for_module(mod) 164 | if this_dev ~= nil then 165 | return cutorch.withDevice(this_dev, f) 166 | end 167 | return f() 168 | end 169 | --[[ One training step: model forward, criterion forward and backward, model backward, then optimMethod applied to each module's weight and bias separately, each with its own optim state. Returns the loss and the model output. ]] 170 | function Optim:optimize(optimMethod, inputs, targets, criterion) 171 | assert(optimMethod) 172 | assert(inputs) 173 | assert(targets) 174 | assert(criterion) 175 | assert(self.modulesToOptState) 176 | 177 | self.model:zeroGradParameters() 178 | local output = self.model:forward(inputs) 179 | 180 | local err = criterion:forward(output, targets) 181 | 182 | local df_do = criterion:backward(output, targets) 183 | self.model:backward(inputs, df_do) 184 | 185 | -- We'll set these in the loop that iterates over each module. Get them 186 | -- out here to be captured. 187 | local curGrad 188 | local curParam 189 | local function fEvalMod(x) 190 | return err, curGrad 191 | end 192 | 193 | for curMod, opt in pairs(self.modulesToOptState) do 194 | on_device_for_module(curMod, function() 195 | local curModParams = self.weight_bias_parameters(curMod) 196 | -- expects either an empty table or 2 element table, one for weights 197 | -- and one for biases 198 | assert(pl.tablex.size(curModParams) == 0 or 199 | pl.tablex.size(curModParams) == 2) 200 | if curModParams then 201 | for i, tensor in ipairs(curModParams) do 202 | if curModParams[i] then 203 | -- expect param, gradParam pair 204 | curParam, curGrad = table.unpack(curModParams[i]) 205 | assert(curParam and curGrad) 206 | optimMethod(fEvalMod, curParam, opt[i]) 207 | end 208 | end 209 | end 210 | end) 211 | end 212 | 213 | return err, output 214 | end 215 | --[[ Like optimize, but backpropagates caller-supplied gradients with respect to the model output instead of running forward and a criterion; the closure handed to optimMethod reports a loss of 0. ]] 216 | function Optim:optimizeFromGradients(optimMethod, inputs, gradients) 217 | assert(optimMethod) 218 | assert(inputs) 219 | assert(gradients) 220 | assert(self.modulesToOptState) 221 | 222 | self.model:zeroGradParameters() 223 | self.model:backward(inputs, gradients) 224 | 225 | -- We'll set these in the loop that
iterates over each module. Get them 226 | -- out here to be captured. 227 | local curGrad 228 | local curParam 229 | local function fEvalMod(x) 230 | return 0, curGrad 231 | end 232 | 233 | for curMod, opt in pairs(self.modulesToOptState) do 234 | on_device_for_module(curMod, function() 235 | local curModParams = self.weight_bias_parameters(curMod) 236 | -- expects either an empty table or 2 element table, one for weights 237 | -- and one for biases 238 | assert(pl.tablex.size(curModParams) == 0 or 239 | pl.tablex.size(curModParams) == 2) 240 | if curModParams then 241 | for i, tensor in ipairs(curModParams) do 242 | if curModParams[i] then 243 | -- expect param, gradParam pair 244 | curParam, curGrad = table.unpack(curModParams[i]) 245 | assert(curParam and curGrad) 246 | optimMethod(fEvalMod, curParam, opt[i]) 247 | end 248 | end 249 | end 250 | end) 251 | end 252 | 253 | return nil, nil --[[ no loss or model output is computed in this method ]] 254 | end 255 | --[[ Merge newParams into every per-tensor optim state, and into the original optState when there is one. ]] 256 | function Optim:setParameters(newParams) 257 | assert(newParams) 258 | assert(type(newParams) == 'table') 259 | local function splice(dest, src) 260 | for k,v in pairs(src) do 261 | dest[k] = v 262 | end 263 | end 264 | 265 | if self.originalOptState then splice(self.originalOptState, newParams) end --[[ nil when constructed from checkpoint_data ]] 266 | for _,optStates in pairs(self.modulesToOptState) do 267 | for i,optState in pairs(optStates) do 268 | assert(type(optState) == 'table') 269 | splice(optState, newParams) 270 | end 271 | end 272 | end -------------------------------------------------------------------------------- /extras/stnbhwd/demo/README.md: -------------------------------------------------------------------------------- 1 | # stnbhwd demo 2 | 3 | Download MNIST and untar in the demo folder, then run with qlua (for image.display): 4 | 5 | ``` 6 | wget 'http://torch7.s3-website-us-east-1.amazonaws.com/data/mnist.t7.tgz' 7 | tar -xf mnist.t7.tgz 8 | qlua -ide demo_mnist.lua 9 | ``` 10 | 11 | Images should appear after 5 epochs and show what the STN does on a test batch.
12 | You can edit demo_mnist.lua and set use_stn = false to compare accuracy. 13 | 14 | You will need to work with the getParamsByDevice branch of the 'nn' package (required for nn.Optim). 15 | -------------------------------------------------------------------------------- /extras/stnbhwd/demo/demo_mnist.lua: -------------------------------------------------------------------------------- 1 | -- wget 'http://torch7.s3-website-us-east-1.amazonaws.com/data/mnist.t7.tgz' 2 | -- tar -xf mnist.t7.tgz 3 | --[[ Trains a 3-layer MLP on distorted 32x32 MNIST for 30 epochs, optionally preceded by the spatial transformer (use_stn), printing training error, validation error and accuracy after each epoch. ]] 4 | require 'cunn' 5 | require 'cudnn' 6 | require 'image' 7 | require 'optim' 8 | paths.dofile('Optim.lua') 9 | 10 | use_stn = true 11 | 12 | -- distorted mnist dataset 13 | paths.dofile('distort_mnist.lua') 14 | datasetTrain, datasetVal = createDatasetsDistorted() 15 | 16 | -- model 17 | model = nn.Sequential() 18 | model:add(nn.View(32*32)) 19 | model:add(nn.Linear(32*32, 128)) 20 | model:add(cudnn.ReLU(true)) 21 | model:add(nn.Linear(128, 128)) 22 | model:add(cudnn.ReLU(true)) 23 | model:add(nn.Linear(128, 10)) 24 | model:add(nn.LogSoftMax()) 25 | 26 | if use_stn then 27 | require 'stn' 28 | paths.dofile('spatial_transformer.lua') 29 | model:insert(spanet,1) 30 | end 31 | 32 | model:cuda() 33 | criterion = nn.ClassNLLCriterion():cuda() 34 | 35 | optimState = {learningRate = 0.01, momentum = 0.9, weightDecay = 5e-4} 36 | optimizer = nn.Optim(model, optimState) 37 | 38 | local w1,w2 39 | 40 | for epoch=1,30 do 41 | model:training() 42 | local trainError = 0 43 | for batchidx = 1, datasetTrain:getNumBatches() do 44 | local inputs, labels = datasetTrain:getBatch(batchidx) 45 | local err = optimizer:optimize(optim.sgd, inputs:cuda(), labels:cuda(), criterion) 46 | --print('epoch : ', epoch, 'batch : ', batchidx, 'train error : ', err) 47 | trainError = trainError + err 48 | end 49 | print('epoch : ', epoch, 'trainError : ', trainError / datasetTrain:getNumBatches()) 50 | 51 | model:evaluate() 52 | local valError = 0 53 | local correct = 0 54 | local all = 0 55 | for batchidx = 1,
datasetVal:getNumBatches() do 56 | local inputs, labels = datasetVal:getBatch(batchidx) 57 | local pred = model:forward(inputs:cuda()) 58 | valError = valError + criterion:forward(pred, labels:cuda()) 59 | local _, preds = pred:max(2) 60 | correct = correct + preds:eq(labels:cuda()):sum() 61 | all = all + preds:size(1) 62 | end 63 | print('validation error : ', valError / datasetVal:getNumBatches()) 64 | print('accuracy % : ', correct / all * 100) 65 | print('') 66 | 67 | if use_stn then 68 | w1=image.display({image=spanet.output, nrow=16, legend='STN-transformed inputs, epoch : '..epoch, win=w1}) 69 | w2=image.display({image=tranet:get(1).output, nrow=16, legend='Inputs, epoch : '..epoch, win=w2}) 70 | end 71 | 72 | end 73 | 74 | -------------------------------------------------------------------------------- /extras/stnbhwd/demo/distort_mnist.lua: -------------------------------------------------------------------------------- 1 | -- wget 'http://torch7.s3-website-us-east-1.amazonaws.com/data/mnist.t7.tgz' 2 | -- tar -xf mnist.t7.tgz 3 | --[[ distortData: rotates each image by a random angle within about +/-45 degrees (pi/4 rad), rescales it to floor(32*U(0.7,1.2)) pixels and pastes it at a random offset into a zeroed 42x42 canvas; returns a FloatTensor of size N x 1 x 42 x 42. Assumes foo holds N single-channel 32x32 images - TODO confirm. ]] 4 | function distortData(foo) 5 | local res=torch.FloatTensor(foo:size(1), 1, 42, 42):fill(0) 6 | for i=1,foo:size(1) do 7 | local baseImg=foo:select(1,i) 8 | local distImg=res:select(1,i) 9 | 10 | local r = image.rotate(baseImg, torch.uniform(-3.14/4,3.14/4)) 11 | local scale = torch.uniform(0.7,1.2) 12 | local sz = torch.floor(scale*32) 13 | local s = image.scale(r, sz, sz) 14 | local rest = 42-sz 15 | local offsetx = torch.random(1, 1+rest) 16 | local offsety = torch.random(1, 1+rest) 17 | 18 | distImg:narrow(2, offsety, sz):narrow(3,offsetx, sz):copy(s) 19 | end 20 | return res 21 | end 22 | --[[ distortData32: same distortion as distortData, but each 42x42 canvas is rescaled back to 32x32; returns a FloatTensor of size N x 1 x 32 x 32. ]] 23 | function distortData32(foo) 24 | local res=torch.FloatTensor(foo:size(1), 1, 32, 32):fill(0) 25 | local distImg=torch.FloatTensor(1, 42, 42):fill(0) 26 | for i=1,foo:size(1) do 27 | local baseImg=foo:select(1,i) 28 | 29 | local r = image.rotate(baseImg, torch.uniform(-3.14/4,3.14/4)) 30 | local scale = torch.uniform(0.7,1.2) 31 | local sz = torch.floor(scale*32) 32 | local s = image.scale(r, sz, sz) 33 | local rest = 42-sz 34 |
local offsetx = torch.random(1, 1+rest) 35 | local offsety = torch.random(1, 1+rest) 36 | 37 | distImg:zero() 38 | distImg:narrow(2, offsety, sz):narrow(3,offsetx, sz):copy(s) 39 | res:select(1,i):copy(image.scale(distImg,32,32)) 40 | end 41 | return res 42 | end 43 | --[[ createDatasetsDistorted: loads mnist.t7/train_32x32.t7 and mnist.t7/test_32x32.t7, distorts both with distortData32 and standardizes them with the training mean and std. Returns train and validation dataset tables exposing getBatch(idx) and getNumBatches(); batches hold 256 samples and a trailing partial batch is dropped. ]] 44 | function createDatasetsDistorted() 45 | local testFileName = 'mnist.t7/test_32x32.t7' 46 | local trainFileName = 'mnist.t7/train_32x32.t7' 47 | local train = torch.load(trainFileName, 'ascii') 48 | local test = torch.load(testFileName, 'ascii') 49 | train.data = train.data:float() 50 | train.labels = train.labels:float() 51 | test.data = test.data:float() 52 | test.labels = test.labels:float() 53 | 54 | -- distortion 55 | train.data = distortData32(train.data) 56 | test.data = distortData32(test.data) 57 | 58 | local mean = train.data:mean() 59 | local std = train.data:std() 60 | train.data:add(-mean):div(std) 61 | test.data:add(-mean):div(std) 62 | 63 | local batchSize = 256 64 | 65 | local datasetTrain = { 66 | getBatch = function(self, idx) 67 | local data = train.data:narrow(1, (idx - 1) * batchSize + 1, batchSize) 68 | local labels = train.labels:narrow(1, (idx - 1) * batchSize + 1, batchSize) 69 | return data, labels, batchSize 70 | end, 71 | getNumBatches = function() 72 | return torch.floor(train.data:size(1) / batchSize) 73 | end 74 | } 75 | 76 | local datasetVal = { 77 | getBatch = function(self, idx) 78 | local data = test.data:narrow(1, (idx - 1) * batchSize + 1, batchSize) 79 | local labels = test.labels:narrow(1, (idx - 1) * batchSize + 1, batchSize) 80 | return data, labels, batchSize 81 | end, 82 | getNumBatches = function() 83 | return torch.floor(test.data:size(1) / batchSize) 84 | end 85 | } 86 | 87 | return datasetTrain, datasetVal 88 | end -------------------------------------------------------------------------------- /extras/stnbhwd/demo/spatial_transformer.lua: -------------------------------------------------------------------------------- 1 | require 'stn' 2 | 3 | spanet=nn.Sequential() 4 | 5 | local
concat=nn.ConcatTable() 6 | 7 | -- first branch is there to transpose inputs to BHWD, for the bilinear sampler 8 | tranet=nn.Sequential() 9 | tranet:add(nn.Identity()) 10 | tranet:add(nn.Transpose({2,3},{3,4})) 11 | 12 | -- second branch is the localization network 13 | local locnet = nn.Sequential() 14 | locnet:add(cudnn.SpatialMaxPooling(2,2,2,2)) 15 | locnet:add(cudnn.SpatialConvolution(1,20,5,5)) 16 | locnet:add(cudnn.ReLU(true)) 17 | locnet:add(cudnn.SpatialMaxPooling(2,2,2,2)) 18 | locnet:add(cudnn.SpatialConvolution(20,20,5,5)) 19 | locnet:add(cudnn.ReLU(true)) 20 | locnet:add(nn.View(20*2*2)) 21 | locnet:add(nn.Linear(20*2*2,20)) 22 | locnet:add(cudnn.ReLU(true)) 23 | 24 | -- we initialize the output layer so it gives the identity transform 25 | local outLayer = nn.Linear(20,6) 26 | outLayer.weight:fill(0) 27 | local bias = torch.FloatTensor(6):fill(0) 28 | bias[1]=1 29 | bias[5]=1 30 | outLayer.bias:copy(bias) 31 | locnet:add(outLayer) 32 | 33 | -- here we generate the grids 34 | locnet:add(nn.View(2,3)) 35 | locnet:add(nn.AffineGridGeneratorBHWD(32,32)) 36 | 37 | -- we need a table input for the bilinear sampler, so we use concattable 38 | concat:add(tranet) 39 | concat:add(locnet) 40 | 41 | spanet:add(concat) 42 | spanet:add(nn.BilinearSamplerBHWD()) 43 | 44 | -- and we transpose back to standard BDHW format for subsequent processing by nn modules 45 | spanet:add(nn.Transpose({3,4},{2,3})) 46 | -------------------------------------------------------------------------------- /extras/stnbhwd/generic/BilinearSamplerBHWD.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/BilinearSamplerBHWD.c" 3 | #else 4 | 5 | #include 6 | 7 | 8 | static int nn_(BilinearSamplerBHWD_updateOutput)(lua_State *L) 9 | { 10 | THTensor *inputImages = luaT_checkudata(L, 2, torch_Tensor); 11 | THTensor *grids = luaT_checkudata(L, 3, torch_Tensor); 12 | THTensor *output =
luaT_checkudata(L, 4, torch_Tensor); 13 | 14 | int batchsize = inputImages->size[0]; 15 | int inputImages_height = inputImages->size[1]; 16 | int inputImages_width = inputImages->size[2]; 17 | int output_height = output->size[1]; 18 | int output_width = output->size[2]; 19 | int inputImages_channels = inputImages->size[3]; 20 | 21 | int output_strideBatch = output->stride[0]; 22 | int output_strideHeight = output->stride[1]; 23 | int output_strideWidth = output->stride[2]; 24 | 25 | int inputImages_strideBatch = inputImages->stride[0]; 26 | int inputImages_strideHeight = inputImages->stride[1]; 27 | int inputImages_strideWidth = inputImages->stride[2]; 28 | 29 | int grids_strideBatch = grids->stride[0]; 30 | int grids_strideHeight = grids->stride[1]; 31 | int grids_strideWidth = grids->stride[2]; 32 | 33 | real *inputImages_data, *output_data, *grids_data; 34 | inputImages_data = THTensor_(data)(inputImages); 35 | output_data = THTensor_(data)(output); 36 | grids_data = THTensor_(data)(grids); 37 | 38 | int b, yOut, xOut; 39 | 40 | for(b=0; b < batchsize; b++) 41 | { 42 | for(yOut=0; yOut < output_height; yOut++) 43 | { 44 | for(xOut=0; xOut < output_width; xOut++) 45 | { 46 | //read the grid 47 | real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth]; 48 | real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1]; 49 | 50 | // get the weights for interpolation 51 | int yInTopLeft, xInTopLeft; 52 | real yWeightTopLeft, xWeightTopLeft; 53 | 54 | real xcoord = (xf + 1) * (inputImages_width - 1) / 2; 55 | xInTopLeft = floor(xcoord); 56 | xWeightTopLeft = 1 - (xcoord - xInTopLeft); 57 | 58 | real ycoord = (yf + 1) * (inputImages_height - 1) / 2; 59 | yInTopLeft = floor(ycoord); 60 | yWeightTopLeft = 1 - (ycoord - yInTopLeft); 61 | 62 | 63 | 64 | const int outAddress = output_strideBatch * b + output_strideHeight * yOut + output_strideWidth * xOut; 65 | const int inTopLeftAddress = 
inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft; 66 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth; 67 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight; 68 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth; 69 | 70 | real v=0; 71 | real inTopLeft=0; 72 | real inTopRight=0; 73 | real inBottomLeft=0; 74 | real inBottomRight=0; 75 | 76 | // we are careful with the boundaries 77 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 78 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 79 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 80 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 81 | 82 | int t; 83 | // interpolation happens here 84 | for(t=0; tsize[0]; 117 | int inputImages_height = inputImages->size[1]; 118 | int inputImages_width = inputImages->size[2]; 119 | int gradOutput_height = gradOutput->size[1]; 120 | int gradOutput_width = gradOutput->size[2]; 121 | int inputImages_channels = inputImages->size[3]; 122 | 123 | int gradOutput_strideBatch = gradOutput->stride[0]; 124 | int gradOutput_strideHeight = gradOutput->stride[1]; 125 | int gradOutput_strideWidth = gradOutput->stride[2]; 126 | 127 | int inputImages_strideBatch = inputImages->stride[0]; 128 | int inputImages_strideHeight = inputImages->stride[1]; 129 | int inputImages_strideWidth = inputImages->stride[2]; 130 | 131 | int gradInputImages_strideBatch = gradInputImages->stride[0]; 132 | int gradInputImages_strideHeight = gradInputImages->stride[1]; 133 | int gradInputImages_strideWidth = 
gradInputImages->stride[2]; 134 | 135 | int grids_strideBatch = grids->stride[0]; 136 | int grids_strideHeight = grids->stride[1]; 137 | int grids_strideWidth = grids->stride[2]; 138 | 139 | int gradGrids_strideBatch = gradGrids->stride[0]; 140 | int gradGrids_strideHeight = gradGrids->stride[1]; 141 | int gradGrids_strideWidth = gradGrids->stride[2]; 142 | 143 | real *inputImages_data, *gradOutput_data, *grids_data, *gradGrids_data, *gradInputImages_data; 144 | inputImages_data = THTensor_(data)(inputImages); 145 | gradOutput_data = THTensor_(data)(gradOutput); 146 | grids_data = THTensor_(data)(grids); 147 | gradGrids_data = THTensor_(data)(gradGrids); 148 | gradInputImages_data = THTensor_(data)(gradInputImages); 149 | 150 | int b, yOut, xOut; 151 | 152 | for(b=0; b < batchsize; b++) 153 | { 154 | for(yOut=0; yOut < gradOutput_height; yOut++) 155 | { 156 | for(xOut=0; xOut < gradOutput_width; xOut++) 157 | { 158 | //read the grid 159 | real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth]; 160 | real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1]; 161 | 162 | // get the weights for interpolation 163 | int yInTopLeft, xInTopLeft; 164 | real yWeightTopLeft, xWeightTopLeft; 165 | 166 | real xcoord = (xf + 1) * (inputImages_width - 1) / 2; 167 | xInTopLeft = floor(xcoord); 168 | xWeightTopLeft = 1 - (xcoord - xInTopLeft); 169 | 170 | real ycoord = (yf + 1) * (inputImages_height - 1) / 2; 171 | yInTopLeft = floor(ycoord); 172 | yWeightTopLeft = 1 - (ycoord - yInTopLeft); 173 | 174 | 175 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft; 176 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth; 177 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight; 178 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth; 179 | 180 | const 
int gradInputImagesTopLeftAddress = gradInputImages_strideBatch * b + gradInputImages_strideHeight * yInTopLeft + gradInputImages_strideWidth * xInTopLeft; 181 | const int gradInputImagesTopRightAddress = gradInputImagesTopLeftAddress + gradInputImages_strideWidth; 182 | const int gradInputImagesBottomLeftAddress = gradInputImagesTopLeftAddress + gradInputImages_strideHeight; 183 | const int gradInputImagesBottomRightAddress = gradInputImagesBottomLeftAddress + gradInputImages_strideWidth; 184 | 185 | const int gradOutputAddress = gradOutput_strideBatch * b + gradOutput_strideHeight * yOut + gradOutput_strideWidth * xOut; 186 | 187 | real topLeftDotProduct = 0; 188 | real topRightDotProduct = 0; 189 | real bottomLeftDotProduct = 0; 190 | real bottomRightDotProduct = 0; 191 | 192 | real v=0; 193 | real inTopLeft=0; 194 | real inTopRight=0; 195 | real inBottomLeft=0; 196 | real inBottomRight=0; 197 | 198 | // we are careful with the boundaries 199 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 200 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 201 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 202 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 203 | 204 | int t; 205 | 206 | for(t=0; t 6 | 7 | 8 | static int nn_(ScaleBHWD_updateOutput)(lua_State *L) 9 | { 10 | THTensor *inputImages = luaT_checkudata(L, 2, torch_Tensor); 11 | THTensor *grids = luaT_checkudata(L, 3, torch_Tensor); 12 | THTensor *output = luaT_checkudata(L, 4, torch_Tensor); 13 | 14 | int batchsize = inputImages->size[0]; 15 | int inputImages_height = inputImages->size[1]; 16 | int inputImages_width = inputImages->size[2]; 17 | int output_height = 
output->size[1]; 18 | int output_width = output->size[2]; 19 | int inputImages_channels = inputImages->size[3]; 20 | 21 | int output_strideBatch = output->stride[0]; 22 | int output_strideHeight = output->stride[1]; 23 | int output_strideWidth = output->stride[2]; 24 | 25 | int inputImages_strideBatch = inputImages->stride[0]; 26 | int inputImages_strideHeight = inputImages->stride[1]; 27 | int inputImages_strideWidth = inputImages->stride[2]; 28 | 29 | int grids_strideBatch = grids->stride[0]; 30 | int grids_strideHeight = grids->stride[1]; 31 | int grids_strideWidth = grids->stride[2]; 32 | 33 | real *inputImages_data, *output_data, *grids_data; 34 | inputImages_data = THTensor_(data)(inputImages); 35 | output_data = THTensor_(data)(output); 36 | grids_data = THTensor_(data)(grids); 37 | 38 | int b, yOut, xOut; 39 | 40 | for(b=0; b < batchsize; b++) 41 | { 42 | for(yOut=0; yOut < output_height; yOut++) 43 | { 44 | for(xOut=0; xOut < output_width; xOut++) 45 | { 46 | //read the grid 47 | real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth]; 48 | real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1]; 49 | 50 | // get the weights for interpolation 51 | int yInTopLeft, xInTopLeft; 52 | real yWeightTopLeft, xWeightTopLeft; 53 | 54 | real xcoord = (xf + 1) * (inputImages_width - 1) / 2; 55 | xInTopLeft = floor(xcoord); 56 | xWeightTopLeft = 1 - (xcoord - xInTopLeft); 57 | 58 | real ycoord = (yf + 1) * (inputImages_height - 1) / 2; 59 | yInTopLeft = floor(ycoord); 60 | yWeightTopLeft = 1 - (ycoord - yInTopLeft); 61 | 62 | 63 | 64 | const int outAddress = output_strideBatch * b + output_strideHeight * yOut + output_strideWidth * xOut; 65 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft; 66 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth; 67 | const int inBottomLeftAddress = 
inTopLeftAddress + inputImages_strideHeight; 68 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth; 69 | 70 | real v=0; 71 | real inTopLeft=0; 72 | real inTopRight=0; 73 | real inBottomLeft=0; 74 | real inBottomRight=0; 75 | 76 | // we are careful with the boundaries 77 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 78 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 79 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 80 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 81 | 82 | int t; 83 | // interpolation happens here 84 | for(t=0; tsize[0]; 117 | int inputImages_height = inputImages->size[1]; 118 | int inputImages_width = inputImages->size[2]; 119 | int gradOutput_height = gradOutput->size[1]; 120 | int gradOutput_width = gradOutput->size[2]; 121 | int inputImages_channels = inputImages->size[3]; 122 | 123 | int gradOutput_strideBatch = gradOutput->stride[0]; 124 | int gradOutput_strideHeight = gradOutput->stride[1]; 125 | int gradOutput_strideWidth = gradOutput->stride[2]; 126 | 127 | int inputImages_strideBatch = inputImages->stride[0]; 128 | int inputImages_strideHeight = inputImages->stride[1]; 129 | int inputImages_strideWidth = inputImages->stride[2]; 130 | 131 | int gradInputImages_strideBatch = gradInputImages->stride[0]; 132 | int gradInputImages_strideHeight = gradInputImages->stride[1]; 133 | int gradInputImages_strideWidth = gradInputImages->stride[2]; 134 | 135 | int grids_strideBatch = grids->stride[0]; 136 | int grids_strideHeight = grids->stride[1]; 137 | int grids_strideWidth = grids->stride[2]; 138 | 139 | int gradGrids_strideBatch = gradGrids->stride[0]; 140 | 
int gradGrids_strideHeight = gradGrids->stride[1]; 141 | int gradGrids_strideWidth = gradGrids->stride[2]; 142 | 143 | real *inputImages_data, *gradOutput_data, *grids_data, *gradGrids_data, *gradInputImages_data; 144 | inputImages_data = THTensor_(data)(inputImages); 145 | gradOutput_data = THTensor_(data)(gradOutput); 146 | grids_data = THTensor_(data)(grids); 147 | gradGrids_data = THTensor_(data)(gradGrids); 148 | gradInputImages_data = THTensor_(data)(gradInputImages); 149 | 150 | int b, yOut, xOut; 151 | 152 | for(b=0; b < batchsize; b++) 153 | { 154 | for(yOut=0; yOut < gradOutput_height; yOut++) 155 | { 156 | for(xOut=0; xOut < gradOutput_width; xOut++) 157 | { 158 | //read the grid 159 | real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth]; 160 | real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1]; 161 | 162 | // get the weights for interpolation 163 | int yInTopLeft, xInTopLeft; 164 | real yWeightTopLeft, xWeightTopLeft; 165 | 166 | real xcoord = (xf + 1) * (inputImages_width - 1) / 2; 167 | xInTopLeft = floor(xcoord); 168 | xWeightTopLeft = 1 - (xcoord - xInTopLeft); 169 | 170 | real ycoord = (yf + 1) * (inputImages_height - 1) / 2; 171 | yInTopLeft = floor(ycoord); 172 | yWeightTopLeft = 1 - (ycoord - yInTopLeft); 173 | 174 | 175 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft; 176 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth; 177 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight; 178 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth; 179 | 180 | const int gradInputImagesTopLeftAddress = gradInputImages_strideBatch * b + gradInputImages_strideHeight * yInTopLeft + gradInputImages_strideWidth * xInTopLeft; 181 | const int gradInputImagesTopRightAddress = gradInputImagesTopLeftAddress + 
gradInputImages_strideWidth; 182 | const int gradInputImagesBottomLeftAddress = gradInputImagesTopLeftAddress + gradInputImages_strideHeight; 183 | const int gradInputImagesBottomRightAddress = gradInputImagesBottomLeftAddress + gradInputImages_strideWidth; 184 | 185 | const int gradOutputAddress = gradOutput_strideBatch * b + gradOutput_strideHeight * yOut + gradOutput_strideWidth * xOut; 186 | 187 | real topLeftDotProduct = 0; 188 | real topRightDotProduct = 0; 189 | real bottomLeftDotProduct = 0; 190 | real bottomRightDotProduct = 0; 191 | 192 | real v=0; 193 | real inTopLeft=0; 194 | real inTopRight=0; 195 | real inBottomLeft=0; 196 | real inBottomRight=0; 197 | 198 | // we are careful with the boundaries 199 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 200 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1; 201 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 202 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1; 203 | 204 | int t; 205 | 206 | for(t=0; t= 7.0", 18 | "nn >= 1.0", 19 | } 20 | 21 | build = { 22 | type = "command", 23 | build_command = [[ 24 | cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." 
-DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) 25 | ]], 26 | install_command = "cd build && $(MAKE) install" 27 | } 28 | -------------------------------------------------------------------------------- /extras/stnbhwd/test.lua: -------------------------------------------------------------------------------- 1 | -- you can easily test specific units like this: 2 | -- th -lnn -e "nn.test{'LookupTable'}" 3 | -- th -lnn -e "nn.test{'LookupTable', 'Add'}" 4 | 5 | local mytester = torch.Tester() 6 | local jac 7 | local sjac 8 | 9 | local precision = 1e-5 10 | local expprecision = 1e-4 11 | 12 | local stntest = {} 13 | 14 | function stntest.AffineGridGeneratorBHWD_batch() 15 | local nframes = torch.random(2,10) 16 | local height = torch.random(2,5) 17 | local width = torch.random(2,5) 18 | local input = torch.zeros(nframes, 2, 3):uniform() 19 | local module = nn.AffineGridGeneratorBHWD(height, width) 20 | 21 | local err = jac.testJacobian(module,input) 22 | mytester:assertlt(err,precision, 'error on state ') 23 | 24 | -- IO 25 | local ferr,berr = jac.testIO(module,input) 26 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') 27 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') 28 | 29 | end 30 | 31 | function stntest.AffineGridGeneratorBHWD_single() 32 | local height = torch.random(2,5) 33 | local width = torch.random(2,5) 34 | local input = torch.zeros(2, 3):uniform() 35 | local module = nn.AffineGridGeneratorBHWD(height, width) 36 | 37 | local err = jac.testJacobian(module,input) 38 | mytester:assertlt(err,precision, 'error on state ') 39 | 40 | -- IO 41 | local ferr,berr = jac.testIO(module,input) 42 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') 43 | mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err ') 44 | 45 | end 46 | 47 | function stntest.BilinearSamplerBHWD_batch() 48 | local nframes = torch.random(2,10) 49 | local height = torch.random(1,5) 50 | local width = torch.random(1,5) 51 | local channels = torch.random(1,6) 52 | local inputImages = torch.zeros(nframes, height, width, channels):uniform() 53 | local grids = torch.zeros(nframes, height, width, 2):uniform(-1, 1) 54 | local module = nn.BilinearSamplerBHWD() 55 | 56 | -- test input images (first element of input table) 57 | module._updateOutput = module.updateOutput 58 | function module:updateOutput(input) 59 | return self:_updateOutput({input, grids}) 60 | end 61 | 62 | module._updateGradInput = module.updateGradInput 63 | function module:updateGradInput(input, gradOutput) 64 | self:_updateGradInput({input, grids}, gradOutput) 65 | return self.gradInput[1] 66 | end 67 | 68 | local errImages = jac.testJacobian(module,inputImages) 69 | mytester:assertlt(errImages,precision, 'error on state ') 70 | 71 | -- test grids (second element of input table) 72 | function module:updateOutput(input) 73 | return self:_updateOutput({inputImages, input}) 74 | end 75 | 76 | function module:updateGradInput(input, gradOutput) 77 | self:_updateGradInput({inputImages, input}, gradOutput) 78 | return self.gradInput[2] 79 | end 80 | 81 | local errGrids = jac.testJacobian(module,grids) 82 | mytester:assertlt(errGrids,precision, 'error on state ') 83 | end 84 | 85 | function stntest.BilinearSamplerBHWD_single() 86 | local height = torch.random(1,5) 87 | local width = torch.random(1,5) 88 | local channels = torch.random(1,6) 89 | local inputImages = torch.zeros(height, width, channels):uniform() 90 | local grids = torch.zeros(height, width, 2):uniform(-1, 1) 91 | local module = nn.BilinearSamplerBHWD() 92 | 93 | -- test input images (first element of input table) 94 | module._updateOutput = module.updateOutput 95 | function module:updateOutput(input) 96 | return self:_updateOutput({input, grids}) 97 | end 
98 | 99 | module._updateGradInput = module.updateGradInput 100 | function module:updateGradInput(input, gradOutput) 101 | self:_updateGradInput({input, grids}, gradOutput) 102 | return self.gradInput[1] 103 | end 104 | 105 | local errImages = jac.testJacobian(module,inputImages) 106 | mytester:assertlt(errImages,precision, 'error on state ') 107 | 108 | -- test grids (second element of input table) 109 | function module:updateOutput(input) 110 | return self:_updateOutput({inputImages, input}) 111 | end 112 | 113 | function module:updateGradInput(input, gradOutput) 114 | self:_updateGradInput({inputImages, input}, gradOutput) 115 | return self.gradInput[2] 116 | end 117 | 118 | local errGrids = jac.testJacobian(module,grids) 119 | mytester:assertlt(errGrids,precision, 'error on state ') 120 | end 121 | 122 | function stntest.AffineTransformMatrixGenerator_batch() 123 | -- test all possible transformations 124 | for _,useRotation in pairs{true,false} do 125 | for _,useScale in pairs{true,false} do 126 | for _,useTranslation in pairs{true,false} do 127 | local currTest = '' 128 | if useRotation then currTest = currTest..'rotation ' end 129 | if useScale then currTest = currTest..'scale ' end 130 | if useTranslation then currTest = currTest..'translation' end 131 | if currTest=='' then currTest = 'full' end 132 | 133 | local nbNeededParams = 0 134 | if useRotation then nbNeededParams = nbNeededParams + 1 end 135 | if useScale then nbNeededParams = nbNeededParams + 1 end 136 | if useTranslation then nbNeededParams = nbNeededParams + 2 end 137 | if nbNeededParams == 0 then nbNeededParams = 6 end -- full affine case 138 | 139 | local nframes = torch.random(2,10) 140 | local params = torch.zeros(nframes,nbNeededParams):uniform() 141 | local module = nn.AffineTransformMatrixGenerator(useRotation,useScale,useTranslation) 142 | 143 | local err = jac.testJacobian(module,params) 144 | mytester:assertlt(err,precision, 'error on state for test '..currTest) 145 | 146 | -- IO 147 | 
local ferr,berr = jac.testIO(module,params) 148 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for test '..currTest) 149 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err for test '..currTest) 150 | 151 | end 152 | end 153 | end 154 | end 155 | 156 | function stntest.AffineTransformMatrixGenerator_single() 157 | -- test all possible transformations 158 | for _,useRotation in pairs{true,false} do 159 | for _,useScale in pairs{true,false} do 160 | for _,useTranslation in pairs{true,false} do 161 | local currTest = '' 162 | if useRotation then currTest = currTest..'rotation ' end 163 | if useScale then currTest = currTest..'scale ' end 164 | if useTranslation then currTest = currTest..'translation' end 165 | if currTest=='' then currTest = 'full' end 166 | 167 | local nbNeededParams = 0 168 | if useRotation then nbNeededParams = nbNeededParams + 1 end 169 | if useScale then nbNeededParams = nbNeededParams + 1 end 170 | if useTranslation then nbNeededParams = nbNeededParams + 2 end 171 | if nbNeededParams == 0 then nbNeededParams = 6 end -- full affine case 172 | 173 | local params = torch.zeros(nbNeededParams):uniform() 174 | local module = nn.AffineTransformMatrixGenerator(useRotation,useScale,useTranslation) 175 | 176 | local err = jac.testJacobian(module,params) 177 | mytester:assertlt(err,precision, 'error on state for test '..currTest) 178 | 179 | -- IO 180 | local ferr,berr = jac.testIO(module,params) 181 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for test '..currTest) 182 | mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err for test '..currTest) 183 | 184 | end 185 | end 186 | end 187 | end 188 | 189 | mytester:add(stntest) 190 | 191 | if not nn then 192 | require 'nn' 193 | jac = nn.Jacobian 194 | sjac = nn.SparseJacobian 195 | mytester:run() 196 | else 197 | jac = nn.Jacobian 198 | sjac = nn.SparseJacobian 199 | function stn.test(tests) 200 | -- randomize stuff 201 | math.randomseed(os.time()) 202 | mytester:run(tests) 203 | return mytester 204 | end 205 | end 206 | -------------------------------------------------------------------------------- /extras/stnbhwd/utils.c: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | THCState* getCutorchState(lua_State* L) 4 | { 5 | lua_getglobal(L, "cutorch"); 6 | lua_getfield(L, -1, "getState"); 7 | lua_call(L, 0, 1); 8 | THCState *state = (THCState*) lua_touserdata(L, -1); 9 | lua_pop(L, 2); 10 | return state; 11 | } 12 | -------------------------------------------------------------------------------- /extras/stnbhwd/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef CUNN_UTILS_H 2 | #define CUNN_UTILS_H 3 | 4 | #include 5 | #include "THCGeneral.h" 6 | 7 | THCState* getCutorchState(lua_State* L); 8 | 9 | #endif 10 | -------------------------------------------------------------------------------- /flowExtensions.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. 2 | -- All rights reserved. 3 | -- This software is provided for research purposes only. 4 | -- By using this software you agree to the terms of the license file 5 | -- in the root folder. 6 | -- For commercial use, please contact ps-license@tue.mpg.de. 
-------------------------
-- Optical Flow Utilities
-------------------------
-- Helpers to measure, visualise, read and write optical flow fields held in
-- torch tensors. This file only requires `pl.stringx`; `torch`, `xlua`,
-- `image` and `paths` are used as globals and must already be loaded.
local stringx = require('pl.stringx')
local M = {}

-- Sanity-check tag stored as the first float of every .flo file.
local TAG_FLOAT = 202021.25

-- Legend image for field2rgb, loaded on first use and then reused.
local cached_legend

------------------------------------------------------------
-- computes norm (magnitude) of flow field from flow_x and flow_y,
-- i.e. sqrt(flow_x^2 + flow_y^2) element-wise
--
-- @usage opticalflow.computeNorm() -- prints online help
--
-- @param flow_x flow field (x), (WxH) [required] [type = torch.Tensor]
-- @param flow_y flow field (y), (WxH) [required] [type = torch.Tensor]
-- @return new tensor (default tensor type) with the size of flow_y
------------------------------------------------------------
local function computeNorm(...)
   -- check args
   local _, flow_x, flow_y = xlua.unpack(
      {...},
      'opticalflow.computeNorm',
      'computes norm (size) of flow field from flow_x and flow_y,\n',
      {arg='flow_x', type='torch.Tensor', help='flow field (x), (WxH)', req=true},
      {arg='flow_y', type='torch.Tensor', help='flow field (y), (WxH)', req=true}
   )
   local x_squared = torch.Tensor():resizeAs(flow_x):copy(flow_x):cmul(flow_x)
   local flow_norm = torch.Tensor()
   flow_norm:resizeAs(flow_y):copy(flow_y):cmul(flow_y):add(x_squared):sqrt()
   return flow_norm
end
M.computeNorm = computeNorm

------------------------------------------------------------
-- computes angle (direction) of flow field from flow_x and flow_y,
-- in degrees, in [0, 360)
--
-- @usage opticalflow.computeAngle() -- prints online help
--
-- @param flow_x flow field (x), (WxH) [required] [type = torch.Tensor]
-- @param flow_y flow field (y), (WxH) [required] [type = torch.Tensor]
------------------------------------------------------------
local function computeAngle(...)
   -- check args
   local _, flow_x, flow_y = xlua.unpack(
      {...},
      'opticalflow.computeAngle',
      'computes angle (direction) of flow field from flow_x and flow_y,\n',
      {arg='flow_x', type='torch.Tensor', help='flow field (x), (WxH)', req=true},
      {arg='flow_y', type='torch.Tensor', help='flow field (y), (WxH)', req=true}
   )
   -- reference angle atan(|y/x|) in degrees; the quadrant is resolved per
   -- element below from the signs of x and y (x == 0 is handled explicitly,
   -- so the inf/nan produced by the division there is never kept)
   local flow_angle = torch.Tensor()
   flow_angle:resizeAs(flow_y):copy(flow_y):cdiv(flow_x):abs():atan():mul(180/math.pi)
   flow_angle:map2(flow_x, flow_y, function(h, x, y)
      if x == 0 and y >= 0 then
         return 90
      elseif x == 0 and y <= 0 then
         return 270
      elseif x >= 0 and y >= 0 then
         return h -- first quadrant: the reference angle is already correct
      elseif x >= 0 and y < 0 then
         return 360 - h
      elseif x < 0 and y >= 0 then
         return 180 - h
      elseif x < 0 and y < 0 then
         return 180 + h
      end
   end)
   return flow_angle
end
M.computeAngle = computeAngle

------------------------------------------------------------
-- merges Norm and Angle flow fields into a single RGB image,
-- where saturation=intensity, and hue=direction
--
-- @usage opticalflow.field2rgb() -- prints online help
--
-- @param norm flow field (norm), (WxH) [required] [type = torch.Tensor]
-- @param angle flow field (angle), (WxH) [required] [type = torch.Tensor]
-- @param max if not provided, norm:max() is used [type = number]
-- @param legend prints a legend on the image [type = boolean]
------------------------------------------------------------
local function field2rgb(...)
   -- check args
   local _, norm, angle, max, legend = xlua.unpack(
      {...},
      'opticalflow.field2rgb',
      'merges Norm and Angle flow fields into a single RGB image,\n'
      .. 'where saturation=intensity, and hue=direction',
      {arg='norm', type='torch.Tensor', help='flow field (norm), (WxH)', req=true},
      {arg='angle', type='torch.Tensor', help='flow field (angle), (WxH)', req=true},
      {arg='max', type='number', help='if not provided, norm:max() is used'},
      {arg='legend', type='boolean', help='prints a legend on the image', default=false}
   )

   -- an explicit max also switches on tanh saturation of the intensity;
   -- the 1e-2 floor avoids dividing by ~0 for an all-zero field
   local saturate = not not max
   max = math.max(max or norm:max(), 1e-2)

   -- merge them into an HSL image
   local hsl = torch.Tensor(3, norm:size(1), norm:size(2))
   -- hue = angle / 360
   hsl:select(1,1):copy(angle):div(360)
   -- saturation = normalized intensity
   hsl:select(1,2):copy(norm):div(max)
   if saturate then hsl:select(1,2):tanh() end
   -- lightness varies inversely with saturation (null flow = white)
   hsl:select(1,3):copy(hsl:select(1,2)):mul(-0.5):add(1)

   -- convert HSL to RGB
   local rgb = image.hsl2rgb(hsl)

   -- legend: a square of side height/8 pasted into the top-right corner
   if legend then
      cached_legend = cached_legend
         or image.load(paths.concat(paths.install_lua_path, 'opticalflow/legend.png'), 3)
      local side = math.min(math.max(math.floor(hsl:size(2) / 8), 1), hsl:size(3))
      local legend_img = torch.Tensor(3, side, side)
      -- image.scale(dst, src, mode) resizes src into the size of dst
      image.scale(legend_img, cached_legend, 'bilinear')
      -- narrow the spatial dims (2 = rows, 3 = columns), not the channels
      rgb:narrow(2, 1, side):narrow(3, rgb:size(3) - side + 1, side):copy(legend_img)
   end

   return rgb
end
M.field2rgb = field2rgb

------------------------------------------------------------
-- Simplifies display of flow field in HSV colorspace when the
-- available field is in x,y displacement
--
-- @usage opticalflow.xy2rgb() -- prints online help
--
-- @param x flow field (x), (WxH) [required] [type = torch.Tensor]
-- @param y flow field (y), (WxH) [required] [type = torch.Tensor]
------------------------------------------------------------
local function xy2rgb(...)
   -- check args
   local _, x, y, max = xlua.unpack(
      {...},
      'opticalflow.xy2rgb',
      'merges x and y flow fields into a single RGB image,\n'
      .. 'where saturation=intensity, and hue=direction',
      {arg='x', type='torch.Tensor', help='flow field (norm), (WxH)', req=true},
      {arg='y', type='torch.Tensor', help='flow field (angle), (WxH)', req=true},
      {arg='max', type='number', help='if not provided, norm:max() is used'}
   )

   local norm = computeNorm(x, y)
   local angle = computeAngle(x, y)
   return field2rgb(norm, angle, max)
end
M.xy2rgb = xy2rgb

------------------------------------------------------------
-- reads a .flo optical flow file
--
-- @param filename path of the file to read
-- @return FloatTensor 2 x H x W (a permuted, non-contiguous view)
------------------------------------------------------------
local function loadFLO(filename)
   local ff = torch.DiskFile(filename):binary()
   local tag = ff:readFloat()
   if tag ~= TAG_FLOAT then
      ff:close() -- do not leak the handle when bailing out
      xlua.error('unable to read '..filename..
         ' perhaps bigendian error','readflo()')
   end

   local w = ff:readInt()
   local h = ff:readInt()
   local nbands = 2
   -- values are stored as an H x W x 2 array of (u, v) pairs
   local tf = torch.FloatTensor(h, w, nbands)
   ff:readFloat(tf:storage())
   ff:close()

   return tf:permute(3,1,2)
end
M.loadFLO = loadFLO

------------------------------------------------------------
-- writes a 2 x H x W flow tensor to a .flo file
--
-- @param filename path of the file to write
-- @param F flow tensor, 2 x H x W (any tensor type convertible with :float())
------------------------------------------------------------
local function writeFLO(filename, F)
   -- H x W x 2 float copy with its own exactly-sized storage, so that the
   -- storage written below contains precisely the flow values
   local hwc = F:permute(2,3,1):float():clone()
   local ff = torch.DiskFile(filename, 'w'):binary()
   ff:writeFloat(TAG_FLOAT)

   ff:writeInt(hwc:size(2)) -- width
   ff:writeInt(hwc:size(1)) -- height

   ff:writeFloat(hwc:storage())
   ff:close()
end
M.writeFLO = writeFLO

------------------------------------------------------------
-- reads a PFM file and returns its first two channels as flow
--
-- @param filename path of the file to read
-- @return FloatTensor 2 x H x W, flipped vertically (PFM rows are stored
--         bottom-to-top). NOTE: a single-channel ('Pf') file has only one
--         band, so selecting channels {1,2} raises an error for it.
------------------------------------------------------------
local function loadPFM(filename)
   local ff = torch.DiskFile(filename):binary()
   local header = ff:readString("*l")
   local nbands
   if stringx.strip(header) == 'PF' then
      nbands = 3
   else
      nbands = 1
   end
   local dims = stringx.split(ff:readString("*l"))
   local width, height = tonumber(dims[1]), tonumber(dims[2])
   -- the sign of the scale line encodes endianness: negative = little endian
   local scale_factor = tonumber(ff:readString("*l"))
   if scale_factor < 0 then
      ff:littleEndianEncoding()
   else
      ff:bigEndianEncoding()
   end
   local data = ff:readFloat(width*height*nbands)
   ff:close()
   local tf = torch.FloatTensor(data):resize(height, width, nbands):permute(3,1,2)
   tf = image.vflip(tf)
   return tf[{{1,2},{},{}}]
end
M.loadPFM = loadPFM

------------------------------------------------------------
-- rotates a 2 x H x W flow field by `angle` radians: the pixel grid is
-- rotated with image.rotate, then each (u, v) vector is rotated so that it
-- stays consistent with the rotated grid
------------------------------------------------------------
local function rotate(flow, angle)
   local flow_rot = image.rotate(flow, angle, 'simple')
   local fu = torch.mul(flow_rot[1], math.cos(-angle)) - torch.mul(flow_rot[2], math.sin(-angle))
   local fv = torch.mul(flow_rot[1], math.sin(-angle)) + torch.mul(flow_rot[2], math.cos(-angle))
   flow_rot[1]:copy(fu)
   flow_rot[2]:copy(fv)

   return flow_rot
end
M.rotate = rotate

------------------------------------------------------------
-- resizes a flow field by factor `sc` and multiplies the vectors by `sc`
-- so displacements are expressed in pixels of the resized image
--
-- @param mode image.scale interpolation mode, default 'simple'
------------------------------------------------------------
local function scale(flow, sc, mode)
   mode = mode or 'simple'
   local flow_scaled = image.scale(flow, '*'..sc, mode)*sc

   return flow_scaled
end
M.scale = scale

------------------------------------------------------------
-- batched version of `scale` for a B x C x H x W flow tensor (C = 2 for
-- flow). Batch size and channel count are taken from the input itself
-- rather than from the global `opt`.
--
-- @return FloatTensor B x C x (H*sc) x (W*sc)
------------------------------------------------------------
local function scaleBatch(flow, sc)
   local batch, channels = flow:size(1), flow:size(2)
   local height, width = flow:size(3), flow:size(4)
   -- fold batch and channels together so image.scale sees one 3D stack
   local flowR = torch.FloatTensor(batch*channels, height, width)
   local outputR = torch.FloatTensor(batch, channels, height*sc, width*sc)

   flowR:copy(flow)
   local output = image.scale(flowR, '*'..sc, 'simple')*sc
   outputR:copy(output)
   return outputR
end
M.scaleBatch = scaleBatch

return M
-------------------------------------------------------------------------------- /main.lua: --------------------------------------------------------------------------------
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact ps-license@tue.mpg.de.
7 | 8 | require 'torch' 9 | require 'cutorch' 10 | require 'paths' 11 | require 'xlua' 12 | require 'optim' 13 | require 'nn' 14 | 15 | torch.setdefaulttensortype('torch.FloatTensor') 16 | 17 | local opts = paths.dofile('opts.lua') 18 | 19 | opt = opts.parse(arg) 20 | 21 | print('Saving everything to: ' .. opt.save) 22 | os.execute('mkdir -p ' .. opt.save) 23 | 24 | paths.dofile('util.lua') 25 | paths.dofile('model.lua') 26 | opt.imageSize = model.imageSize or opt.imageSize 27 | opt.outputSize = model.outputSize or opt.outputSize 28 | 29 | print(opt) 30 | 31 | cutorch.setDevice(opt.GPU) -- by default, use GPU 1 32 | torch.manualSeed(opt.manualSeed) 33 | 34 | paths.dofile('data.lua') 35 | paths.dofile('train.lua') 36 | paths.dofile('test.lua') 37 | 38 | epoch = opt.epochNumber 39 | 40 | for i=1,opt.nEpochs do 41 | train() 42 | test() 43 | epoch = epoch + 1 44 | end 45 | -------------------------------------------------------------------------------- /model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. 2 | -- All rights reserved. 3 | -- This software is provided for research purposes only. 4 | -- By using this software you agree to the terms of the license file 5 | -- in the root folder. 6 | -- For commercial use, please contact ps-license@tue.mpg.de. 7 | -- 8 | -- Copyright (c) 2014, Facebook, Inc. 9 | -- All rights reserved. 10 | -- 11 | -- This source code is licensed under the BSD-style license found in the 12 | -- LICENSE file in the root directory of this source tree. An additional grant 13 | -- of patent rights can be found in the PATENTS file in the same directory. 14 | -- 15 | require 'nn' 16 | require 'cunn' 17 | require 'optim' 18 | include('EPECriterion.lua') 19 | 20 | --[[ 21 | 1. Create Model 22 | 2. Create Criterion 23 | 3. Convert model to CUDA 24 | ]]-- 25 | 26 | -- 1. 
Create Network 27 | -- 1.1 If preloading option is set, preload weights from existing models appropriately 28 | if opt.retrain ~= 'none' then 29 | assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain) 30 | print('Loading model from file: ' .. opt.retrain); 31 | model = loadDataParallel(opt.retrain, opt.nGPU) -- defined in util.lua 32 | else 33 | paths.dofile('models/' .. opt.netType .. '.lua') 34 | print('=> Creating model from file: models/' .. opt.netType .. '.lua') 35 | model = createModel(opt.nGPU) -- for the model creation code, check the models/ folder 36 | if opt.backend == 'cudnn' then 37 | require 'cudnn' 38 | cudnn.convert(model, cudnn) 39 | elseif opt.backend ~= 'nn' then 40 | error('Unsupported backend: ' .. tostring(opt.backend) .. ' (expected cudnn or nn)') 41 | end 42 | end 43 | 44 | -- 2. Create Criterion 45 | criterion = nn.EPECriterion() 46 | 47 | print('=> Model') 48 | print(model) 49 | 50 | print('=> Criterion') 51 | print(criterion) 52 | 53 | criterion:cuda() 54 | 55 | collectgarbage() 56 | -------------------------------------------------------------------------------- /models/modelL1_3.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_3.t7 -------------------------------------------------------------------------------- /models/modelL1_4.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_4.t7 -------------------------------------------------------------------------------- /models/modelL1_C.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_C.t7 -------------------------------------------------------------------------------- /models/modelL1_F.t7:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_F.t7 -------------------------------------------------------------------------------- /models/modelL1_K.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_K.t7 -------------------------------------------------------------------------------- /models/modelL2_3.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_3.t7 -------------------------------------------------------------------------------- /models/modelL2_4.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_4.t7 -------------------------------------------------------------------------------- /models/modelL2_C.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_C.t7 -------------------------------------------------------------------------------- /models/modelL2_F.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_F.t7 -------------------------------------------------------------------------------- /models/modelL2_K.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_K.t7 
-------------------------------------------------------------------------------- /models/modelL3_3.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_3.t7 -------------------------------------------------------------------------------- /models/modelL3_4.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_4.t7 -------------------------------------------------------------------------------- /models/modelL3_C.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_C.t7 -------------------------------------------------------------------------------- /models/modelL3_F.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_F.t7 -------------------------------------------------------------------------------- /models/modelL3_K.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_K.t7 -------------------------------------------------------------------------------- /models/modelL4_3.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_3.t7 -------------------------------------------------------------------------------- /models/modelL4_4.t7: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_4.t7 -------------------------------------------------------------------------------- /models/modelL4_C.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_C.t7 -------------------------------------------------------------------------------- /models/modelL4_F.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_F.t7 -------------------------------------------------------------------------------- /models/modelL4_K.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_K.t7 -------------------------------------------------------------------------------- /models/modelL5_3.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_3.t7 -------------------------------------------------------------------------------- /models/modelL5_4.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_4.t7 -------------------------------------------------------------------------------- /models/modelL5_C.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_C.t7 -------------------------------------------------------------------------------- /models/modelL5_F.t7: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_F.t7 -------------------------------------------------------------------------------- /models/modelL5_K.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_K.t7 -------------------------------------------------------------------------------- /models/modelL6_C.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL6_C.t7 -------------------------------------------------------------------------------- /models/modelL6_F.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL6_F.t7 -------------------------------------------------------------------------------- /models/modelL6_K.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL6_K.t7 -------------------------------------------------------------------------------- /models/volcon.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. 2 | -- All rights reserved. 3 | -- This software is provided for research purposes only. 4 | -- By using this software you agree to the terms of the license file 5 | -- in the root folder. 6 | -- For commercial use, please contact ps-license@tue.mpg.de. 
7 | 8 | require 'nn' 9 | require 'cutorch' 10 | require 'cunn' 11 | require 'cudnn' 12 | function createModel(nGPU) -- builds a 5-layer conv net: 7x7 convolutions with stride 1 and padding 3 (H and W preserved), 8 -> 32 -> 64 -> 32 -> 16 -> 2 channels, ReLU between layers 13 | local model = nn.Sequential() 14 | model:add(nn.SpatialConvolution(8,32,7,7,1,1,3,3)) 15 | model:add(nn.ReLU()) 16 | model:add(nn.SpatialConvolution(32,64,7,7,1,1,3,3)) 17 | model:add(nn.ReLU()) 18 | model:add(nn.SpatialConvolution(64,32,7,7,1,1,3,3)) 19 | model:add(nn.ReLU()) 20 | model:add(nn.SpatialConvolution(32,16,7,7,1,1,3,3)) 21 | model:add(nn.ReLU()) 22 | model:add(nn.SpatialConvolution(16,2,7,7,1,1,3,3)) 23 | -- assumption: the 8 input channels are two RGB frames plus a 2-channel flow estimate, and the 2 outputs are a flow (residual); TODO confirm against the data loader 24 | if nGPU>0 then -- move to CUDA and wrap with makeDataParallel (presumably defined in util.lua, like loadDataParallel) 25 | model:cuda() 26 | model = makeDataParallel(model, nGPU) 27 | end 28 | 29 | return model 30 | end 31 | -------------------------------------------------------------------------------- /opts.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. 2 | -- All rights reserved. 3 | -- This software is provided for research purposes only. 4 | -- By using this software you agree to the terms of the license file 5 | -- in the root folder. 6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 | 8 | local M = { } 9 | 10 | function M.parse(arg) -- parses command-line arguments into an options table; also derives opt.save (cache dir + timestamp) and opt.loadSize 11 | local cmd = torch.CmdLine() 12 | cmd:text() 13 | cmd:text('SPyNet Coarse-to-Fine Optical Flow Training') 14 | cmd:text() 15 | cmd:text('Options:') 16 | ------------ General options -------------------- 17 | 18 | cmd:option('-cache', 'checkpoint/', 'subdirectory in which to save/log experiments') 19 | cmd:option('-data', 'flying_chairs/data', 'Home of Flying Chairs dataset') 20 | cmd:option('-trainValidationSplit', 'train_val_split.txt', 'File containing training and validation split') 21 | cmd:option('-manualSeed', 2, 'Manually set RNG seed') 22 | cmd:option('-GPU', 1, 'Default preferred GPU') 23 | cmd:option('-nGPU', 1, 'Number of GPUs to use by default') 24 | cmd:option('-backend', 'cudnn', 'Options: cudnn | nn') 25 | ------------- Data options ------------------------ 26 | cmd:option('-nDonkeys', 4, 'number of donkeys to initialize (data loading threads)') 27 | cmd:option('-fineWidth', 512, 'the width of the fine flow field') 28 | cmd:option('-fineHeight', 384, 'the height of the fine flow field') 29 | cmd:option('-level', 1, 'Options: 1,2,3.., whether to initialize flow to zero' ) 30 | ------------- Training options -------------------- 31 | cmd:option('-augment', 1, 'augment the data') 32 | cmd:option('-nEpochs', 1000, 'Number of total epochs to run') 33 | cmd:option('-epochSize', 1000, 'Number of batches per epoch') 34 | cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)') 35 | cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)') 36 | ---------- Optimization options ---------------------- 37 | cmd:option('-LR', 0.0, 'learning rate; if set, overrides default LR/WD recipe') 38 | cmd:option('-momentum', 0.9, 'momentum') 39 | cmd:option('-weightDecay', 5e-4, 'weight decay') 40 | cmd:option('-optimizer', 'adam', 'adam or sgd') 41 | ---------- Model options ---------------------------------- 42 | cmd:option('-L1', 'models/modelL1_4.t7', 'Trained Level 1 model')
43 | cmd:option('-L2', 'models/modelL2_4.t7', 'Trained Level 2 model') 44 | cmd:option('-L3', 'models/modelL3_4.t7', 'Trained Level 3 model') 45 | cmd:option('-L4', 'models/modelL4_4.t7', 'Trained Level 4 model') 46 | 47 | cmd:option('-netType', 'volcon', 'Lua network file') 48 | cmd:option('-retrain', 'none', 'provide path to model to retrain with') 49 | cmd:option('-optimState', 'none', 'provide path to an optimState to reload from') 50 | cmd:text() 51 | 52 | local opt = cmd:parse(arg or {}) 53 | opt.save = paths.concat(opt.cache) 54 | -- add date/time 55 | opt.save = paths.concat(opt.save, '' .. os.date():gsub(' ','')) 56 | 57 | opt.loadSize = {8, 384, 512} -- channels x height x width of a loaded sample (presumably two RGB frames + 2-channel flow; TODO confirm whether it should track -fineHeight/-fineWidth) 58 | return opt 59 | end 60 | 61 | return M 62 | -------------------------------------------------------------------------------- /samples/00001_flow.flo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00001_flow.flo -------------------------------------------------------------------------------- /samples/00001_img1.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00001_img1.ppm -------------------------------------------------------------------------------- /samples/00001_img2.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00001_img2.ppm -------------------------------------------------------------------------------- /samples/00002_flow.flo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00002_flow.flo
-------------------------------------------------------------------------------- /samples/00002_img1.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00002_img1.ppm -------------------------------------------------------------------------------- /samples/00002_img2.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00002_img2.ppm -------------------------------------------------------------------------------- /samples/00003_flow.flo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00003_flow.flo -------------------------------------------------------------------------------- /samples/00003_img1.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00003_img1.ppm -------------------------------------------------------------------------------- /samples/00003_img2.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00003_img2.ppm -------------------------------------------------------------------------------- /spynet.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. 2 | -- All rights reserved. 3 | -- This software is provided for research purposes only. 4 | -- By using this software you agree to the terms of the license file 5 | -- in the root folder. 
6 | -- For commercial use, please contact ps-license@tue.mpg.de. 7 | 8 | require 'image' 9 | local TF = require 'transforms' 10 | require 'cutorch' 11 | require 'nn' 12 | require 'cunn' 13 | require 'cudnn' 14 | require 'nngraph' 15 | require 'stn' 16 | require 'spy' 17 | local flowX = require 'flowExtensions' 18 | 19 | local M = {} 20 | 21 | local eps = 1e-6 22 | local meanstd = { 23 | mean = { 0.485, 0.456, 0.406 }, 24 | std = { 0.229, 0.224, 0.225 }, 25 | } 26 | local pca = { 27 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, 28 | eigvec = torch.Tensor{ 29 | { -0.5675, 0.7192, 0.4009 }, 30 | { -0.5808, -0.0045, -0.8140 }, 31 | { -0.5836, -0.6948, 0.4203 }, 32 | }, 33 | } 34 | 35 | local mean = meanstd.mean 36 | local std = meanstd.std 37 | ------------------------------------------ 38 | local function createWarpModel() -- {img, flow} (presumably B x C x H x W, as passed by computeInitFlowL*) -> img warped by flow: both are transposed to B x H x W x C for BilinearSamplerBHWD and the result is transposed back 39 | local imgData = nn.Identity()() 40 | local floData = nn.Identity()() 41 | 42 | local imgOut = nn.Transpose({2,3},{3,4})(imgData) 43 | local floOut = nn.Transpose({2,3},{3,4})(floData) 44 | 45 | local warpImOut = nn.Transpose({3,4},{2,3})(nn.BilinearSamplerBHWD()({imgOut, floOut})) 46 | local model = nn.gModule({imgData, floData}, {warpImOut}) 47 | 48 | return model 49 | end 50 | 51 | local down2 = nn.SpatialAveragePooling(2,2,2,2):cuda() 52 | local down3 = nn.SpatialAveragePooling(2,2,2,2):cuda() 53 | local down4 = nn.SpatialAveragePooling(2,2,2,2):cuda() 54 | local down5 = nn.SpatialAveragePooling(2,2,2,2):cuda() 55 | local down6 = nn.SpatialAveragePooling(2,2,2,2):cuda() 56 | 57 | local up2 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda() 58 | local up3 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda() 59 | local up4 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda() 60 | local up5 =
nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda() 61 | local up6 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda() 62 | 63 | local warpmodel2 = createWarpModel():cuda() 64 | local warpmodel3 = createWarpModel():cuda() 65 | local warpmodel4 = createWarpModel():cuda() 66 | local warpmodel5 = createWarpModel():cuda() 67 | local warpmodel6 = createWarpModel():cuda() 68 | 69 | down2:evaluate() 70 | down3:evaluate() 71 | down4:evaluate() 72 | down5:evaluate() 73 | down6:evaluate() 74 | 75 | up2:evaluate() 76 | up3:evaluate() 77 | up4:evaluate() 78 | up5:evaluate() 79 | up6:evaluate() 80 | 81 | warpmodel2:evaluate() 82 | warpmodel3:evaluate() 83 | warpmodel4:evaluate() 84 | warpmodel5:evaluate() 85 | warpmodel6:evaluate() 86 | 87 | ------------------------------------------------- 88 | local modelL1, modelL2, modelL3, modelL4, modelL5, modelL6 89 | local modelL1path, modelL2path, modelL3path, modelL4path, modelL5path, modelL6path 90 | local SUPPORTED_MODELS = { sintelFinal = true, sintelClean = true, chairsClean = true, chairsFinal = true, kittiFinal = true } -- model sets that setup()/easy_setup() know how to load 91 | local function loadImage(path) 92 | local input = image.load(path, 3, 'float') 93 | return input 94 | end 95 | M.loadImage = loadImage 96 | 97 | local function loadFlow(filename) -- reads a .flo file (float tag, int width, int height, then h*w*2 floats) and returns it as a 2 x h x w FloatTensor view 98 | local TAG_FLOAT = 202021.25 -- expected header tag; local so it no longer leaks a global 99 | local ff = torch.DiskFile(filename):binary() 100 | local tag = ff:readFloat() 101 | if tag ~= TAG_FLOAT then ff:close() -- release the file handle before raising 102 | xerror('unable to read '..filename..
103 | ' perhaps bigendian error','readflo()') 104 | end 105 | 106 | local w = ff:readInt() 107 | local h = ff:readInt() 108 | local nbands = 2 109 | local tf = torch.FloatTensor(h, w, nbands) 110 | ff:readFloat(tf:storage()) 111 | ff:close() 112 | 113 | local flow = tf:permute(3,1,2) 114 | return flow 115 | end 116 | M.loadFlow = loadFlow 117 | 118 | 119 | local function computeInitFlowL1(imagesL1) -- coarsest level: runs modelL1 on the images concatenated with an all-zero 2-channel initial flow 120 | local h = imagesL1:size(3) 121 | local w = imagesL1:size(4) 122 | local batchSize = imagesL1:size(1) 123 | 124 | local _flowappend = torch.zeros(batchSize, 2, h, w):cuda() 125 | local images_in = torch.cat(imagesL1, _flowappend, 2) 126 | 127 | local flow_est = modelL1:forward(images_in) 128 | return flow_est 129 | end 130 | M.computeInitFlowL1 = computeInitFlowL1 131 | 132 | local function computeInitFlowL2(imagesL2) -- one pyramid step: downsample, recurse, upsample the coarse flow with up2 and double it, warp channels 4-6 by it, then add modelL2's output to it 133 | local imagesL1 = down2:forward(imagesL2:clone()) 134 | local _flowappend = up2:forward(computeInitFlowL1(imagesL1))*2 135 | local _img2 = imagesL2[{{},{4,6},{},{}}] 136 | imagesL2[{{},{4,6},{},{}}]:copy(warpmodel2:forward({_img2, _flowappend})) 137 | 138 | local images_in = torch.cat(imagesL2, _flowappend, 2) 139 | 140 | local flow_est = modelL2:forward(images_in) 141 | return flow_est:add(_flowappend) 142 | end 143 | M.computeInitFlowL2 = computeInitFlowL2 144 | 145 | local function computeInitFlowL3(imagesL3) 146 | local imagesL2 = down3:forward(imagesL3:clone()) 147 | local _flowappend = up3:forward(computeInitFlowL2(imagesL2))*2 148 | local _img2 = imagesL3[{{},{4,6},{},{}}] 149 | imagesL3[{{},{4,6},{},{}}]:copy(warpmodel3:forward({_img2, _flowappend})) 150 | 151 | local images_in = torch.cat(imagesL3, _flowappend, 2) 152 | 153 | local flow_est = modelL3:forward(images_in) 154 | return flow_est:add(_flowappend) 155 | end 156 | M.computeInitFlowL3 = computeInitFlowL3 157 | 158 | local function computeInitFlowL4(imagesL4) 159 | local imagesL3 = down4:forward(imagesL4) 160 | local _flowappend = up4:forward(computeInitFlowL3(imagesL3))*2 161 |
local _img2 = imagesL4[{{},{4,6},{},{}}] 162 | imagesL4[{{},{4,6},{},{}}]:copy(warpmodel4:forward({_img2, _flowappend})) 163 | 164 | local images_in = torch.cat(imagesL4, _flowappend, 2) 165 | 166 | local flow_est = modelL4:forward(images_in) 167 | return flow_est:add(_flowappend) 168 | end 169 | M.computeInitFlowL4 = computeInitFlowL4 170 | 171 | local function computeInitFlowL5(imagesL5) 172 | local imagesL4 = down5:forward(imagesL5) 173 | local _flowappend = up5:forward(computeInitFlowL4(imagesL4))*2 174 | 175 | local _img2 = imagesL5[{{},{4,6},{},{}}] 176 | imagesL5[{{},{4,6},{},{}}]:copy(warpmodel5:forward({_img2, _flowappend})) 177 | 178 | local images_in = torch.cat(imagesL5, _flowappend, 2) 179 | 180 | local flow_est = modelL5:forward(images_in) 181 | return flow_est:add(_flowappend) 182 | end 183 | M.computeInitFlowL5 = computeInitFlowL5 184 | 185 | local function computeInitFlowL6(imagesL6) 186 | local imagesL5 = down6:forward(imagesL6) 187 | local _flowappend = up6:forward(computeInitFlowL5(imagesL5))*2 188 | 189 | local _img2 = imagesL6[{{},{4,6},{},{}}] 190 | imagesL6[{{},{4,6},{},{}}]:copy(warpmodel6:forward({_img2, _flowappend})) 191 | 192 | local images_in = torch.cat(imagesL6, _flowappend, 2) 193 | 194 | local flow_est = modelL6:forward(images_in) 195 | return flow_est:add(_flowappend) 196 | end 197 | M.computeInitFlowL6 = computeInitFlowL6 198 | 199 | 200 | local function setup(width, height, opt) -- chooses the pyramid depth from max(width, height), loads the level models of the chosen set (default "sintelFinal") and returns the matching computeInitFlowL* function 201 | opt = opt or "sintelFinal" 202 | local len = math.max(width, height) 203 | local computeFlow 204 | local level 205 | 206 | if len <= 32 then 207 | computeFlow = computeInitFlowL1 208 | level = 1 209 | elseif len <= 64 then 210 | computeFlow = computeInitFlowL2 211 | level = 2 212 | elseif len <= 128 then 213 | computeFlow = computeInitFlowL3 214 | level = 3 215 | elseif len <= 256 then 216 | computeFlow = computeInitFlowL4 217 | level = 4 218 | elseif len <= 512 then 219 | computeFlow = computeInitFlowL5 220 | level = 5 221 | elseif len <= 1472
then 222 | computeFlow = computeInitFlowL6 223 | level = 6 224 | else 225 | error("Only image size <= 1472 supported. Next release will have full support.") 226 | end 227 | 228 | if opt=="sintelFinal" then 229 | modelL1path = paths.concat('models', 'modelL1_F.t7') 230 | modelL2path = paths.concat('models', 'modelL2_F.t7') 231 | modelL3path = paths.concat('models', 'modelL3_F.t7') 232 | modelL4path = paths.concat('models', 'modelL4_F.t7') 233 | modelL5path = paths.concat('models', 'modelL5_F.t7') 234 | modelL6path = paths.concat('models', 'modelL6_F.t7') 235 | end 236 | 237 | if opt=="sintelClean" then 238 | modelL1path = paths.concat('models', 'modelL1_C.t7') 239 | modelL2path = paths.concat('models', 'modelL2_C.t7') 240 | modelL3path = paths.concat('models', 'modelL3_C.t7') 241 | modelL4path = paths.concat('models', 'modelL4_C.t7') 242 | modelL5path = paths.concat('models', 'modelL5_C.t7') 243 | modelL6path = paths.concat('models', 'modelL6_C.t7') 244 | end 245 | 246 | if opt=="chairsClean" then 247 | modelL1path = paths.concat('models', 'modelL1_4.t7') 248 | modelL2path = paths.concat('models', 'modelL2_4.t7') 249 | modelL3path = paths.concat('models', 'modelL3_4.t7') 250 | modelL4path = paths.concat('models', 'modelL4_4.t7') 251 | modelL5path = paths.concat('models', 'modelL5_4.t7') 252 | modelL6path = paths.concat('models', 'modelL5_4.t7') -- models/ ships no modelL6_4.t7, so the level-5 weights are reused at level 6 253 | end 254 | 255 | if opt=="chairsFinal" then 256 | modelL1path = paths.concat('models', 'modelL1_3.t7') 257 | modelL2path = paths.concat('models', 'modelL2_3.t7') 258 | modelL3path = paths.concat('models', 'modelL3_3.t7') 259 | modelL4path = paths.concat('models', 'modelL4_3.t7') 260 | modelL5path = paths.concat('models', 'modelL5_3.t7') 261 | modelL6path = paths.concat('models', 'modelL5_3.t7') -- models/ ships no modelL6_3.t7, so the level-5 weights are reused at level 6 262 | end 263 | 264 | if opt=="kittiFinal" then 265 | modelL1path = paths.concat('models', 'modelL1_K.t7') 266 | modelL2path = paths.concat('models', 'modelL2_K.t7') 267 | modelL3path = paths.concat('models', 'modelL3_K.t7') 268 |
modelL4path = paths.concat('models', 'modelL4_K.t7') 269 | modelL5path = paths.concat('models', 'modelL5_K.t7') 270 | modelL6path = paths.concat('models', 'modelL6_K.t7') 271 | end 272 | if not SUPPORTED_MODELS[opt] then error("Unsupported model type: " .. tostring(opt)) end -- otherwise the modelL*path upvalues are nil, or stale from an earlier call 273 | 274 | if level>0 then 275 | modelL1 = torch.load(modelL1path) 276 | if torch.type(modelL1) == 'nn.DataParallelTable' then 277 | modelL1 = modelL1:get(1) 278 | end 279 | modelL1:evaluate() 280 | end 281 | 282 | if level>1 then 283 | modelL2 = torch.load(modelL2path) 284 | if torch.type(modelL2) == 'nn.DataParallelTable' then 285 | modelL2 = modelL2:get(1) 286 | end 287 | modelL2:evaluate() 288 | end 289 | 290 | if level>2 then 291 | modelL3 = torch.load(modelL3path) 292 | if torch.type(modelL3) == 'nn.DataParallelTable' then 293 | modelL3 = modelL3:get(1) 294 | end 295 | modelL3:evaluate() 296 | end 297 | 298 | if level>3 then 299 | modelL4 = torch.load(modelL4path) 300 | if torch.type(modelL4) == 'nn.DataParallelTable' then 301 | modelL4 = modelL4:get(1) 302 | end 303 | modelL4:evaluate() 304 | end 305 | 306 | if level>4 then 307 | modelL5 = torch.load(modelL5path) 308 | if torch.type(modelL5) == 'nn.DataParallelTable' then 309 | modelL5 = modelL5:get(1) 310 | end 311 | modelL5:evaluate() 312 | end 313 | 314 | if level>5 then 315 | modelL6 = torch.load(modelL6path) 316 | if torch.type(modelL6) == 'nn.DataParallelTable' then 317 | modelL6 = modelL6:get(1) 318 | end 319 | modelL6:evaluate() 320 | end 321 | 322 | return computeFlow 323 | end 324 | M.setup = setup 325 | 326 | local function DeAdjustFlow(flow, h, w) -- resizes a 2 x H x W flow to h x w and rescales u (channel 1) by w/W and v (channel 2) by h/H 327 | local sc_h = h/flow:size(2) 328 | local sc_w = w/flow:size(3) 329 | flow = image.scale(flow, w, h, 'simple') 330 | flow[2] = flow[2]*sc_h 331 | flow[1] = flow[1]*sc_w 332 | 333 | return flow 334 | end 335 | M.DeAdjustFlow = DeAdjustFlow 336 | 337 | local function normalize(imgs) 338 | return TF.ColorNormalize(meanstd)(imgs) 339 | end 340 | M.normalize = normalize 341 | 342 | local easyComputeFlow = function(im1, im2) -- normalizes the image pair, resizes it up to multiples of 32, runs the deepest level needed and maps the flow back to the input size 343 | local imgs = torch.cat(im1, im2, 1) 344 |
imgs = TF.ColorNormalize(meanstd)(imgs) 345 | 346 | local width = imgs:size(3) 347 | local height = imgs:size(2) 348 | 349 | local fineWidth, fineHeight 350 | 351 | if width%32 == 0 then 352 | fineWidth = width 353 | else 354 | fineWidth = width + 32 - math.fmod(width, 32) 355 | end 356 | 357 | if height%32 == 0 then 358 | fineHeight = height 359 | else 360 | fineHeight = height + 32 - math.fmod(height, 32) 361 | end 362 | 363 | imgs = image.scale(imgs, fineWidth, fineHeight) 364 | 365 | local len = math.max(fineWidth, fineHeight) 366 | local computeFlow 367 | 368 | if len <= 32 then 369 | computeFlow = computeInitFlowL1 370 | elseif len <= 64 then 371 | computeFlow = computeInitFlowL2 372 | elseif len <= 128 then 373 | computeFlow = computeInitFlowL3 374 | elseif len <= 256 then 375 | computeFlow = computeInitFlowL4 376 | elseif len <= 512 then 377 | computeFlow = computeInitFlowL5 378 | else 379 | computeFlow = computeInitFlowL6 380 | end 381 | 382 | imgs = imgs:resize(1,6,fineHeight,fineWidth):cuda() 383 | local flow_est = computeFlow(imgs) 384 | 385 | flow_est = flow_est:squeeze():float() 386 | flow_est = DeAdjustFlow(flow_est, height, width) 387 | 388 | return flow_est 389 | 390 | end 391 | 392 | local function easy_setup(opt) -- loads all six level models of the chosen set (default 'sintelFinal') and returns easyComputeFlow 393 | opt = opt or 'sintelFinal' 394 | 395 | if opt=="sintelFinal" then 396 | modelL1path = paths.concat('models', 'modelL1_F.t7') 397 | modelL2path = paths.concat('models', 'modelL2_F.t7') 398 | modelL3path = paths.concat('models', 'modelL3_F.t7') 399 | modelL4path = paths.concat('models', 'modelL4_F.t7') 400 | modelL5path = paths.concat('models', 'modelL5_F.t7') 401 | modelL6path = paths.concat('models', 'modelL6_F.t7') 402 | end 403 | 404 | if opt=="sintelClean" then 405 | modelL1path = paths.concat('models', 'modelL1_C.t7') 406 | modelL2path = paths.concat('models', 'modelL2_C.t7') 407 | modelL3path = paths.concat('models', 'modelL3_C.t7') 408 | modelL4path = paths.concat('models', 'modelL4_C.t7') 409 | modelL5path =
paths.concat('models', 'modelL5_C.t7') 410 | modelL6path = paths.concat('models', 'modelL6_C.t7') 411 | end 412 | 413 | if opt=="chairsClean" then 414 | modelL1path = paths.concat('models', 'modelL1_4.t7') 415 | modelL2path = paths.concat('models', 'modelL2_4.t7') 416 | modelL3path = paths.concat('models', 'modelL3_4.t7') 417 | modelL4path = paths.concat('models', 'modelL4_4.t7') 418 | modelL5path = paths.concat('models', 'modelL5_4.t7') 419 | modelL6path = paths.concat('models', 'modelL5_4.t7') -- models/ ships no modelL6_4.t7, so the level-5 weights are reused at level 6 420 | end 421 | 422 | if opt=="chairsFinal" then 423 | modelL1path = paths.concat('models', 'modelL1_3.t7') 424 | modelL2path = paths.concat('models', 'modelL2_3.t7') 425 | modelL3path = paths.concat('models', 'modelL3_3.t7') 426 | modelL4path = paths.concat('models', 'modelL4_3.t7') 427 | modelL5path = paths.concat('models', 'modelL5_3.t7') 428 | modelL6path = paths.concat('models', 'modelL5_3.t7') -- models/ ships no modelL6_3.t7, so the level-5 weights are reused at level 6 429 | end 430 | 431 | if opt=="kittiFinal" then 432 | modelL1path = paths.concat('models', 'modelL1_K.t7') 433 | modelL2path = paths.concat('models', 'modelL2_K.t7') 434 | modelL3path = paths.concat('models', 'modelL3_K.t7') 435 | modelL4path = paths.concat('models', 'modelL4_K.t7') 436 | modelL5path = paths.concat('models', 'modelL5_K.t7') 437 | modelL6path = paths.concat('models', 'modelL6_K.t7') 438 | end 439 | if not SUPPORTED_MODELS[opt] then error("Unsupported model type: " .. tostring(opt)) end -- fail fast instead of torch.load(nil) or reusing stale paths from an earlier call 440 | modelL1 = torch.load(modelL1path) 441 | if torch.type(modelL1) == 'nn.DataParallelTable' then 442 | modelL1 = modelL1:get(1) 443 | end 444 | modelL1:evaluate() 445 | 446 | modelL2 = torch.load(modelL2path) 447 | if torch.type(modelL2) == 'nn.DataParallelTable' then 448 | modelL2 = modelL2:get(1) 449 | end 450 | modelL2:evaluate() 451 | 452 | modelL3 = torch.load(modelL3path) 453 | if torch.type(modelL3) == 'nn.DataParallelTable' then 454 | modelL3 = modelL3:get(1) 455 | end 456 | modelL3:evaluate() 457 | 458 | modelL4 = torch.load(modelL4path) 459 | if torch.type(modelL4) == 'nn.DataParallelTable' then 460 | modelL4 = modelL4:get(1) 461 | end 462 |
modelL4:evaluate() 463 | 464 | modelL5 = torch.load(modelL5path) 465 | if torch.type(modelL5) == 'nn.DataParallelTable' then 466 | modelL5 = modelL5:get(1) 467 | end 468 | modelL5:evaluate() 469 | 470 | modelL6 = torch.load(modelL6path) 471 | if torch.type(modelL6) == 'nn.DataParallelTable' then 472 | modelL6 = modelL6:get(1) 473 | end 474 | modelL6:evaluate() 475 | return easyComputeFlow 476 | end 477 | M.easy_setup = easy_setup 478 | 479 | 480 | 481 | return M -------------------------------------------------------------------------------- /test.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. 2 | -- All rights reserved. 3 | -- This software is provided for research purposes only. 4 | -- By using this software you agree to the terms of the license file 5 | -- in the root folder. 6 | -- For commercial use, please contact ps-license@tue.mpg.de. 7 | -- 8 | -- Copyright (c) 2014, Facebook, Inc. 9 | -- All rights reserved. 10 | -- 11 | -- This source code is licensed under the BSD-style license found in the 12 | -- LICENSE file in the root directory of this source tree. An additional grant 13 | -- of patent rights can be found in the PATENTS file in the same directory. 14 | -- 15 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) 16 | 17 | local batchNumber 18 | local error_center, loss 19 | local timer = torch.Timer() 20 | 21 | function test() 22 | print('==> doing epoch on validation data:') 23 | print("==> online epoch # " ..
epoch) 24 | 25 | batchNumber = 0 26 | cutorch.synchronize() 27 | timer:reset() 28 | 29 | -- set the dropouts to evaluate mode 30 | model:evaluate() 31 | 32 | error_center = 0 33 | loss = 0 34 | for i=1,nTest/opt.batchSize do -- nTest is set in 1_data.lua 35 | local indexStart = (i-1) * opt.batchSize + 1 36 | local indexEnd = (indexStart + opt.batchSize - 1) 37 | donkeys:addjob( 38 | -- work to be done by donkey thread 39 | function() 40 | local inputs, labels = testLoader:get(indexStart, indexEnd) 41 | return inputs, labels 42 | end, 43 | -- callback that is run in the main thread once the work is done 44 | testBatch 45 | ) 46 | end 47 | 48 | donkeys:synchronize() 49 | cutorch.synchronize() 50 | 51 | error_center = error_center * 100 / nTest 52 | loss = loss / (nTest/opt.batchSize) -- because loss is calculated per batch 53 | testLogger:add{ 54 | ['% top1 accuracy (test set) (center crop)'] = error_center, 55 | ['avg loss (test set)'] = loss 56 | } 57 | print(string.format('Epoch: [%d][TESTING SUMMARY] Total Time(s): %.2f \t' 58 | .. 'average loss (per batch): %.2f \t ' 59 | .. 
'accuracy [Center](%%):\t top-1 %.2f\t ', 60 | epoch, timer:time().real, loss, error_center)) 61 | 62 | print('\n') 63 | 64 | 65 | end -- of test() 66 | ----------------------------------------------------------------------------- 67 | local inputs = torch.CudaTensor() 68 | local labels = torch.CudaTensor() 69 | 70 | function testBatch(inputsCPU, labelsCPU) 71 | batchNumber = batchNumber + opt.batchSize 72 | 73 | inputs:resize(inputsCPU:size()):copy(inputsCPU) 74 | labels:resize(labelsCPU:size()):copy(labelsCPU) 75 | 76 | local outputs = model:forward(inputs) 77 | local err = criterion:forward(outputs, labels) 78 | cutorch.synchronize() 79 | local pred = outputs:float() 80 | 81 | loss = loss + err 82 | 83 | print(('Epoch: Testing [%d][%d/%d]'):format(epoch, batchNumber, nTest)) 84 | end 85 | -------------------------------------------------------------------------------- /timing_benchmark.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. 2 | -- All rights reserved. 3 | -- This software is provided for research purposes only. 4 | -- By using this software you agree to the terms of the license file 5 | -- in the root folder. 6 | -- For commercial use, please contact ps-license@tue.mpg.de. 
-- Timing benchmark: runs the 5-level pyramid (timing_util.computeInitFlowL5)
-- on every validation sample, times each pass, then prints the average
-- endpoint error and the mean/median timing.

require 'image'
require 'cutorch'

local cmd = torch.CmdLine()
cmd:option('-data', '../FlyingChairs/data', 'Flying Chairs data directory')
opt = cmd:parse(arg or {})

opt.showFlow = 0
opt.fineHeight = 384
opt.fineWidth = 512
opt.preprocess = 0
opt.level = 5
opt.polluteFlow = 0
opt.augment = 0
opt.warp = 1
opt.batchSize = 1
local donkey = require('timing_util')

local train_samples, validation_samples = donkey.getTrainValidationSplits('train_val_split.txt')
local nValidation = validation_samples:size(1)
local errors = torch.zeros(nValidation)
timings = torch.zeros(nValidation)
local loss = 0
-- Pinned host buffer receiving every estimated flow field. It is sized to
-- the validation split; it was hard-coded to 640 rows, which overflowed for
-- larger splits.
local flowCPU = cutorch.createCudaHostTensor(nValidation, 2, opt.fineHeight, opt.fineWidth):uniform()

for i = 1, nValidation do
  collectgarbage()

  local id = validation_samples[i][1]
  local imgs, flow = donkey.testHook(id)

  -- Timed region: upload, 5-level pyramid, async copy back, stream sync.
  local timer = torch.Timer()
  imgs = imgs:resize(1,6,opt.fineHeight, opt.fineWidth):cuda()
  local flow_est = donkey.computeInitFlowL5(imgs):squeeze()
  flowCPU[i]:copyAsync(flow_est)
  cutorch.streamSynchronize(cutorch.getStream())
  local time_elapsed = timer:time().real

  print('Time Elapsed: '..time_elapsed)

  timings[i] = time_elapsed
end
cutorch.streamSynchronize(cutorch.getStream())


for i = 1, nValidation do
  local id = validation_samples[i][1]
  local raw_im1, raw_im2, raw_flow = donkey.getRawData(id)

  -- Per-pixel endpoint error: Euclidean norm over the first (band) dimension.
  local _err = (raw_flow - flowCPU[i]):pow(2)
  local err = torch.sum(_err, 1):sqrt()
  loss = loss + err:float()
  errors[i] = err:mean()

  print(i, errors[i])
end
loss = torch.div(loss, nValidation)
print('Average EPE = '..loss:sum()/(opt.fineWidth*opt.fineHeight))
print('Mean Timing: ' ..timings:mean())
print('Median Timing: ' ..timings:median()[1])
-------------------------------------------------------------------------------- /timing_util.lua: --------------------------------------------------------------------------------
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact ps-license@tue.mpg.de.

-- Helpers for timing_benchmark.lua: Flying Chairs data loading and a fixed
-- 5-level flow pyramid built from the models/modelL*_4.t7 checkpoints.
require 'image'
local TF = require 'transforms'
require 'cutorch'
require 'nn'
require 'cunn'
require 'cudnn'
require 'nngraph'
require 'stn'
require 'spy'
local flowX = require 'flowExtensions'

local M = {}

local eps = 1e-6
local meanstd = {
  mean = { 0.485, 0.456, 0.406 },
  std = { 0.229, 0.224, 0.225 },
}
local pca = {
  eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 },
  eigvec = torch.Tensor{
    { -0.5675, 0.7192, 0.4009 },
    { -0.5808, -0.0045, -0.8140 },
    { -0.5836, -0.6948, 0.4203 },
  },
}

local mean = meanstd.mean
local std = meanstd.std
------------------------------------------
-- Builds an nngraph module mapping {images, flow} to the images resampled
-- by nn.BilinearSamplerBHWD. The transposes move dimension 2 to the last
-- position before sampling and back afterwards.
local function createWarpModel()
  local imgData = nn.Identity()()
  local floData = nn.Identity()()

  local imgOut = nn.Transpose({2,3},{3,4})(imgData)
  local floOut = nn.Transpose({2,3},{3,4})(floData)

  local warpImOut = nn.Transpose({3,4},{2,3})(nn.BilinearSamplerBHWD()({imgOut, floOut}))
  local model = nn.gModule({imgData, floData}, {warpImOut})

  return model
end

-- 2x2 stride-2 average pooling on the GPU, in evaluation mode.
local function createDownModel()
  local model = nn.SpatialAveragePooling(2,2,2,2):cuda()
  model:evaluate()
  return model
end

-- nn.ScaleBHWD(2) wrapped in the same transposes as the warp model, on the
-- GPU, in evaluation mode.
local function createUpModel()
  local model = nn.Sequential()
    :add(nn.Transpose({2,3},{3,4}))
    :add(nn.ScaleBHWD(2))
    :add(nn.Transpose({3,4},{2,3}))
    :cuda()
  model:evaluate()
  return model
end

local down2 = createDownModel()
local down3 = createDownModel()
local down4 = createDownModel()
local down5 = createDownModel()

local up2 = createUpModel()
local up3 = createUpModel()
local up4 = createUpModel()
local up5 = createUpModel()

local warpmodel2 = createWarpModel():cuda()
local warpmodel3 = createWarpModel():cuda()
local warpmodel4 = createWarpModel():cuda()
local warpmodel5 = createWarpModel():cuda()

warpmodel2:evaluate()
warpmodel3:evaluate()
warpmodel4:evaluate()
warpmodel5:evaluate()

-------------------------------------------------
local modelL0, modelL1, modelL2, modelL3, modelL4, modelL5
local modelL1path, modelL2path, modelL3path, modelL4path, modelL5path

modelL1path = paths.concat('models', 'modelL1_4.t7')
modelL2path = paths.concat('models', 'modelL2_4.t7')
modelL3path = paths.concat('models', 'modelL3_4.t7')
modelL4path = paths.concat('models', 'modelL4_4.t7')
modelL5path = paths.concat('models', 'modelL5_4.t7')

-- Loads a checkpoint, unwraps an nn.DataParallelTable to its first replica
-- and switches the module to evaluation mode.
local function loadEvalModel(path)
  local model = torch.load(path)
  if torch.type(model) == 'nn.DataParallelTable' then
    model = model:get(1)
  end
  model:evaluate()
  return model
end

modelL1 = loadEvalModel(modelL1path)
modelL2 = loadEvalModel(modelL2path)
modelL3 = loadEvalModel(modelL3path)
modelL4 = loadEvalModel(modelL4path)
modelL5 = loadEvalModel(modelL5path)

--- Split sample ids into training and validation sets.
-- `path` is read as one int per sample in opt.data (1 = train, 2 = validation).
-- @return train_samples, validation_samples as nonzero() index tensors
local function getTrainValidationSplits(path)
  -- opt.data is assumed to hold exactly three files per sample
  -- (the _img1.ppm/_img2.ppm/_flow.flo triple read by getRawData).
  local fileCount = tonumber(sys.fexecute("ls " .. opt.data .. "| wc -l"))
  assert(fileCount, 'could not count the files in ' .. tostring(opt.data))
  local numSamples = fileCount/3
  local ff = torch.DiskFile(path, 'r')
  local trainValidationSamples = torch.IntTensor(numSamples)
  ff:readInt(trainValidationSamples:storage())
  ff:close()

  local train_samples = trainValidationSamples:eq(1):nonzero()
  local validation_samples = trainValidationSamples:eq(2):nonzero()

  return train_samples, validation_samples
end
M.getTrainValidationSplits = getTrainValidationSplits

-- Loads an image as a 3-channel float tensor.
local function loadImage(path)
  local input = image.load(path, 3, 'float')
  return input
end
M.loadImage = loadImage

-- Header tag expected at the start of a .flo file.
local TAG_FLOAT = 202021.25

--- Read a .flo file: float tag, int width, int height, then h*w*2 floats.
-- @return FloatTensor of size 2 x h x w (a permuted view of the h x w x 2 data)
local function loadFlow(filename)
  local ff = torch.DiskFile(filename):binary()
  local tag = ff:readFloat()
  if tag ~= TAG_FLOAT then
    ff:close()
    xerror('unable to read '..filename..
           ' perhaps bigendian error','readflo()')
  end

  local w = ff:readInt()
  local h = ff:readInt()
  local nbands = 2
  local tf = torch.FloatTensor(h, w, nbands)
  ff:readFloat(tf:storage())
  ff:close()

  local flow = tf:permute(3,1,2)
  return flow
end
M.loadFlow = loadFlow


--- Coarsest level: run modelL1 on the images with two zero flow channels
-- appended along dimension 2. Uses opt.batchSize for the zero block.
local function computeInitFlowL1(imagesL1)
  local h = imagesL1:size(3)
  local w = imagesL1:size(4)

  local _flowappend = torch.zeros(opt.batchSize, 2, h, w):cuda()
  local images_in = torch.cat(imagesL1, _flowappend, 2)

  local flow_est = modelL1:forward(images_in)
  return flow_est
end
M.computeInitFlowL1 = computeInitFlowL1

-- Each computeInitFlowLk (k = 2..5) below:
--   1. downsamples its input and recurses one level,
--   2. upsamples that flow and multiplies it by 2,
--   3. overwrites channels 4-6 of its ARGUMENT in place with those channels
--      warped by the upsampled flow,
--   4. adds the upsampled flow in place onto modelLk's output and returns it.
local function computeInitFlowL2(imagesL2)
  local imagesL1 = down2:forward(imagesL2:clone())
  local _flowappend = up2:forward(computeInitFlowL1(imagesL1))*2
  local _img2 = imagesL2[{{},{4,6},{},{}}]
  imagesL2[{{},{4,6},{},{}}]:copy(warpmodel2:forward({_img2, _flowappend}))

  local images_in = torch.cat(imagesL2, _flowappend, 2)

  local flow_est = modelL2:forward(images_in)
  return flow_est:add(_flowappend)
end
M.computeInitFlowL2 = computeInitFlowL2

local function computeInitFlowL3(imagesL3)
  local imagesL2 = down3:forward(imagesL3:clone())
  local _flowappend = up3:forward(computeInitFlowL2(imagesL2))*2
  local _img2 = imagesL3[{{},{4,6},{},{}}]
  imagesL3[{{},{4,6},{},{}}]:copy(warpmodel3:forward({_img2, _flowappend}))

  local images_in = torch.cat(imagesL3, _flowappend, 2)

  local flow_est = modelL3:forward(images_in)
  return flow_est:add(_flowappend)
end
M.computeInitFlowL3 = computeInitFlowL3

local function computeInitFlowL4(imagesL4)
  local imagesL3 = down4:forward(imagesL4)
  local _flowappend = up4:forward(computeInitFlowL3(imagesL3))*2
  local _img2 = imagesL4[{{},{4,6},{},{}}]
  imagesL4[{{},{4,6},{},{}}]:copy(warpmodel4:forward({_img2, _flowappend}))

  local images_in = torch.cat(imagesL4, _flowappend, 2)

  local flow_est = modelL4:forward(images_in)
  return flow_est:add(_flowappend)
end
M.computeInitFlowL4 = computeInitFlowL4

local function computeInitFlowL5(imagesL5)
  local imagesL4 = down5:forward(imagesL5)
  local _flowappend = up5:forward(computeInitFlowL4(imagesL4))*2

  local _img2 = imagesL5[{{},{4,6},{},{}}]
  imagesL5[{{},{4,6},{},{}}]:copy(warpmodel5:forward({_img2, _flowappend}))

  local images_in = torch.cat(imagesL5, _flowappend, 2)

  local flow_est = modelL5:forward(images_in)
  return flow_est:add(_flowappend)
end
M.computeInitFlowL5 = computeInitFlowL5

-- Paths of the image pair and ground-truth flow for sample `id` in opt.data.
local function samplePaths(id)
  local prefix = string.format("%05i", id)
  return paths.concat(opt.data, (prefix .."_img1.ppm")),
         paths.concat(opt.data, (prefix .."_img2.ppm")),
         paths.concat(opt.data, (prefix .."_flow.flo"))
end

--- Unnormalized first image, second image and ground-truth flow of a sample.
local function getRawData(id)
  local path1, path2, pathF = samplePaths(id)

  local img1 = loadImage(path1)
  local img2 = loadImage(path2)
  local flow = loadFlow(pathF)

  return img1, img2, flow
end
M.getRawData = getRawData

--- Both images stacked along dimension 1 and colour-normalized, plus the
-- ground-truth flow of a sample.
local testHook = function(id)
  local path1, path2, pathF = samplePaths(id)

  local img1 = loadImage(path1)
  local img2 = loadImage(path2)
  local images = torch.cat(img1, img2, 1)

  local flow = loadFlow(pathF)

  images = TF.ColorNormalize(meanstd)(images)
  return images, flow
end
M.testHook = testHook

return M
-------------------------------------------------------------------------------- /train.lua: --------------------------------------------------------------------------------
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact ps-license@tue.mpg.de.
--
-- Copyright (c) 2014, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
require 'optim'

--[[
   1. Setup the optimization state and learning rate schedule
   2. Create loggers.
   3. train - this function handles the high-level training loop,
              i.e. load data, train model, save model and state to disk
   4. trainBatch - Used by train() to train a single batch after the data is loaded.
]]--

-- Optimization state shared by every batch and passed to optim.sgd or
-- optim.adam (chosen by opt.optimizer). If requested, reload it from disk.
local optimState = {
  learningRate = opt.LR,
  learningRateDecay = 0.0,
  momentum = opt.momentum,
  dampening = 0.0,
  weightDecay = opt.weightDecay
}

if opt.optimState ~= 'none' then
  assert(paths.filep(opt.optimState), 'File not found: ' .. opt.optimState)
  print('Loading optimState from file: ' .. opt.optimState)
  optimState = torch.load(opt.optimState)
end

-- Learning rate annealing schedule. train() rebuilds optimState from scratch
-- at the first epoch of every regime row below.
--
-- With opt.LR == 0 the table is followed:
--   * LR 5e-3 for epochs 1-10;
--   * LR 1e-4 for epochs 11-160, split over three rows, so optimState is
--     also reset at epochs 81 and 121;
--   * LR 5e-5 afterwards.
-- Weight decay is 0 throughout. A non-zero opt.LR means the user is tuning
-- manually, and their exact settings are used for all optimization.
--
-- Return values:
--    diff to apply to optimState,
--    true IFF this is the first epoch of a new regime
local function paramsForEpoch(epoch)
  if opt.LR ~= 0.0 then -- if manually specified
    return { }
  end
  local regimes = {
    -- start,     end,    LR,   WD,
    {  1,     10,   5e-3,   0 },
    { 11,     80,   1e-4,   0 },
    { 81,    120,   1e-4,   0 },
    { 121,   160,   1e-4,   0 },
    { 161,   1e8,   5e-5,   0 },
  }

  for _, row in ipairs(regimes) do
    if epoch >= row[1] and epoch <= row[2] then
      return { learningRate=row[3], weightDecay=row[4] }, epoch == row[1]
    end
  end
end

-- 2. Create loggers.
trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
local batchNumber
local top1_epoch, loss_epoch

-- 3. train - this function handles the high-level training loop,
--    i.e. load data, train model, save model and state to disk
function train()
  print('==> doing epoch on training data:')
  print("==> online epoch # " .. epoch)

  local params, newRegime = paramsForEpoch(epoch)
  if newRegime then
    optimState = {
      learningRate = params.learningRate,
      learningRateDecay = 0.0,
      momentum = opt.momentum,
      dampening = 0.0,
      weightDecay = params.weightDecay
    }
  end
  batchNumber = 0
  cutorch.synchronize()

  -- set the dropouts to training mode
  model:training()

  local tm = torch.Timer()
  top1_epoch = 0
  loss_epoch = 0
  for i=1,opt.epochSize do
    -- queue jobs to data-workers
    donkeys:addjob(
      -- the job callback (runs in data-worker thread)
      function()
        local inputs, labels = trainLoader:sample(opt.batchSize)
        return inputs, labels
      end,
      -- the end callback (runs in the main thread)
      trainBatch
    )
  end

  donkeys:synchronize()
  cutorch.synchronize()

  -- NOTE(review): trainBatch never updates top1_epoch, so the logged top-1
  -- value is always 0. The column is kept for log compatibility.
  top1_epoch = top1_epoch * 100 / (opt.batchSize * opt.epochSize)
  loss_epoch = loss_epoch / opt.epochSize

  trainLogger:add{
    ['% top1 accuracy (train set)'] = top1_epoch,
    ['avg loss (train set)'] = loss_epoch
  }
  print(string.format('Epoch: [%d][TRAINING SUMMARY] Total Time(s): %.2f\t'
    .. 'average loss (per batch): %.2f \t '
    .. 'accuracy(%%):\t top-1 %.2f\t',
    epoch, tm:time().real, loss_epoch, top1_epoch))
  print('\n')

  -- save model
  collectgarbage()

  -- clear the intermediate states in the model before saving to disk
  -- this saves lots of disk space
  model:clearState()
  saveDataParallel(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model) -- defined in util.lua
  torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
end -- of train()
-------------------------------------------------------------------------------------------
-- GPU inputs (preallocate)
local inputs = torch.CudaTensor()
local labels = torch.CudaTensor()

local timer = torch.Timer()
local dataTimer = torch.Timer()

local parameters, gradParameters = model:getParameters()

-- 4. trainBatch - Used by train() to train a single batch after the data is loaded.
function trainBatch(inputsCPU, labelsCPU)
  cutorch.synchronize()
  collectgarbage()
  local dataLoadingTime = dataTimer:time().real
  timer:reset()

  -- transfer over to GPU
  inputs:resize(inputsCPU:size()):copy(inputsCPU)
  labels:resize(labelsCPU:size()):copy(labelsCPU)

  local err, outputs
  -- Closure handed to the optimizer: one forward/backward pass returning the
  -- loss and the gradient of the flattened parameters. Kept local so it does
  -- not leak into the global table.
  local function feval(x)
    model:zeroGradParameters()
    outputs = model:forward(inputs)
    err = criterion:forward(outputs, labels)
    local gradOutputs = criterion:backward(outputs, labels)
    model:backward(inputs, gradOutputs)
    return err, gradParameters
  end

  if opt.optimizer == 'adam' then
    optim.adam(feval, parameters, optimState)
  elseif opt.optimizer == 'sgd' then
    optim.sgd(feval, parameters, optimState)
  else
    error("Specify Optimizer")
  end

  -- DataParallelTable's syncParameters
  if model.needsSync then
    model:syncParameters()
  end

  cutorch.synchronize()
  batchNumber = batchNumber + 1
  loss_epoch = loss_epoch + err

  -- Print per-batch progress (no top-1 error is computed here)
  print(('Epoch: [%d][%d/%d]\tTime %.3f Err %.4f LR %.0e DataLoadingTime %.3f'):format(
    epoch, batchNumber, opt.epochSize, timer:time().real, err,
    optimState.learningRate, dataLoadingTime))

  dataTimer:reset()
end
--------------------------------------------------------------------------------
/transforms.lua: --------------------------------------------------------------------------------
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact ps-license@tue.mpg.de.

-- https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Image transforms for data augmentation and input normalization
--

require 'image'

local M = {}

-- Chains transforms left to right; each receives the previous result.
function M.Compose(transforms)
  return function(input)
    for _, transform in ipairs(transforms) do
      input = transform(input)
    end
    return input
  end
end

-- Normalizes a clone of a stacked image pair: channels i and 3+i (i = 1..3)
-- both get (x - meanstd.mean[i]) / meanstd.std[i], so the input must have
-- at least 6 channels.
function M.ColorNormalize(meanstd)
  return function(img)
    img = img:clone()
    for i=1,3 do
      img[i]:add(-meanstd.mean[i])
      img[i]:div(meanstd.std[i])
      img[3+i]:add(-meanstd.mean[i])
      img[3+i]:div(meanstd.std[i])
    end
    return img
  end
end

-- Scales the smaller edge to size
function M.Scale(size, interpolation)
  interpolation = interpolation or 'bicubic'
  return function(input)
    local w, h = input:size(3), input:size(2)
    if (w <= h and w == size) or (h <= w and h == size) then
      return input
    end
    if w < h then
      return image.scale(input, size, h/w * size, interpolation)
    else
      return image.scale(input, w/h * size, size, interpolation)
    end
  end
end

-- Crop to centered rectangle
function M.CenterCrop(size)
  return function(input)
    local w1 = math.ceil((input:size(3) - size)/2)
    local h1 = math.ceil((input:size(2) - size)/2)
    return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch
  end
end

-- Random crop from a larger image with optional zero padding.
-- The padded buffer keeps the input's channel count. It was hard-coded to 3,
-- which made the final copy fail for the 6-channel pairs ColorNormalize handles.
function M.RandomCrop(size, padding)
  padding = padding or 0

  return function(input)
    if padding > 0 then
      local temp = input.new(input:size(1), input:size(2) + 2*padding, input:size(3) + 2*padding)
      temp:zero()
        :narrow(2, padding+1, input:size(2))
        :narrow(3, padding+1, input:size(3))
        :copy(input)
      input = temp
    end

    local w, h = input:size(3), input:size(2)
    if w == size and h == size then
      return input
    end

    local x1, y1 = torch.random(0, w - size), torch.random(0, h - size)
    local out = image.crop(input, x1, y1, x1 + size, y1 + size)
    assert(out:size(2) == size and out:size(3) == size, 'wrong crop size')
    return out
  end
end

-- Four corner patches and center crop from image and its horizontal reflection
function M.TenCrop(size)
  local centerCrop = M.CenterCrop(size)

  return function(input)
    local w, h = input:size(3), input:size(2)

    local output = {}
    for _, img in ipairs{input, image.hflip(input)} do
      table.insert(output, centerCrop(img))
      table.insert(output, image.crop(img, 0, 0, size, size))
      table.insert(output, image.crop(img, w-size, 0, w, size))
      table.insert(output, image.crop(img, 0, h-size, size, h))
      table.insert(output, image.crop(img, w-size, h-size, w, h))
    end

    -- View as mini-batch
    for i, img in ipairs(output) do
      output[i] = img:view(1, img:size(1), img:size(2), img:size(3))
    end

    return input.cat(output, 1)
  end
end

-- Resized with shorter side randomly sampled from [minSize, maxSize] (ResNet-style)
function M.RandomScale(minSize, maxSize)
  return function(input)
    local w, h = input:size(3), input:size(2)

    local targetSz = torch.random(minSize, maxSize)
    local targetW, targetH = targetSz, targetSz
    if w < h then
      targetH = torch.round(h / w * targetW)
    else
      targetW = torch.round(w / h * targetH)
    end

    return image.scale(input, targetW, targetH, 'bicubic')
  end
end

-- Random crop with size 8%-100% and aspect ratio 3/4 - 4/3 (Inception-style)
function M.RandomSizedCrop(size)
  local scale = M.Scale(size)
  local crop = M.CenterCrop(size)

  return function(input)
    local attempt = 0
    repeat
      local area = input:size(2) * input:size(3)
      local targetArea = torch.uniform(0.08, 1.0) * area

      local aspectRatio = torch.uniform(3/4, 4/3)
      local w = torch.round(math.sqrt(targetArea * aspectRatio))
      local h = torch.round(math.sqrt(targetArea / aspectRatio))

      if torch.uniform() < 0.5 then
        w, h = h, w
      end

      if h <= input:size(2) and w <= input:size(3) then
        local y1 = torch.random(0, input:size(2) - h)
        local x1 = torch.random(0, input:size(3) - w)

        local out = image.crop(input, x1, y1, x1 + w, y1 + h)
        assert(out:size(2) == h and out:size(3) == w, 'wrong crop size')

        return image.scale(out, size, size, 'bicubic')
      end
      attempt = attempt + 1
    until attempt >= 10

    -- fallback
    return crop(scale(input))
  end
end

-- Horizontally flips the input with probability `prob`.
function M.HorizontalFlip(prob)
  return function(input)
    if torch.uniform() < prob then
      input = image.hflip(input)
    end
    return input
  end
end

-- Rotates by a uniform random angle in [-deg/2, deg/2] degrees (bilinear).
function M.Rotation(deg)
  return function(input)
    if deg ~= 0 then
      input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear')
    end
    return input
  end
end

-- Lighting noise (AlexNet-style PCA-based noise). The same per-channel
-- offset is added to channels 1-3 and 4-6 of a clone of the input.
function M.Lighting(alphastd, eigval, eigvec)
  return function(input)
    if alphastd == 0 then
      return input
    end

    local alpha = torch.Tensor(3):normal(0, alphastd)
    local rgb = eigvec:clone()
      :cmul(alpha:view(1, 3):expand(3, 3))
      :cmul(eigval:view(1, 3):expand(3, 3))
      :sum(2)
      :squeeze()

    input = input:clone()
    for i=1,3 do
      input[i]:add(rgb[i])
      input[3+i]:add(rgb[i])
    end
    return input
  end
end

-- In place: img1 <- alpha * img1 + (1 - alpha) * img2; returns img1.
local function blend(img1, img2, alpha)
  return img1:mul(alpha):add(1 - alpha, img2)
end

-- Writes the luma (0.299 R + 0.587 G + 0.114 B) of a 3-channel `img` into
-- all three channels of `dst`.
local function grayscale(dst, img)
  assert(img:size(1)==3)

  dst[1]:zero()
  dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3])
  dst[2]:copy(dst[1])
  dst[3]:copy(dst[1])
  return dst
end

-- Blends each image of the pair (channels 1-3, 4-6) in place toward its
-- own grayscale by a random factor in [1 - var, 1 + var].
function M.Saturation(var)
  local gs

  return function(input)
    gs = gs or input.new()
    gs:resizeAs(input)

    grayscale(gs[{{1,3},{},{}}], input[{{1,3},{},{}}])
    grayscale(gs[{{4,6},{},{}}], input[{{4,6},{},{}}])

    local alpha = 1.0 + torch.uniform(-var, var)
    blend(input, gs, alpha)
    return input
  end
end

-- Scales the input in place by a random factor in [1 - var, 1 + var]
-- (a blend toward zero).
function M.Brightness(var)
  local gs

  return function(input)
    gs = gs or input.new()
    gs:resizeAs(input):zero()

    local alpha = 1.0 + torch.uniform(-var, var)
    blend(input, gs, alpha)
    return input
  end
end

-- Blends each image of the pair in place toward its mean gray level by a
-- random factor in [1 - var, 1 + var].
function M.Contrast(var)
  local gs

  return function(input)
    gs = gs or input.new()
    gs:resizeAs(input)

    grayscale(gs[{{1,3},{},{}}], input[{{1,3},{},{}}])
    grayscale(gs[{{4,6},{},{}}], input[{{4,6},{},{}}])

    gs[{{1,3},{},{}}]:fill(gs[1]:mean())
    gs[{{4,6},{},{}}]:fill(gs[4]:mean())

    local alpha = 1.0 + torch.uniform(-var, var)
    blend(input, gs, alpha)
    return input
  end
end

-- Applies the transforms `ts` in a random order and returns `input` itself.
-- This relies on each transform modifying the tensor in place, as the colour
-- transforms above do through blend().
function M.RandomOrder(ts)
  return function(input)
    local img = input.img or input
    local order = torch.randperm(#ts)
    for i=1,#ts do
      img = ts[order[i]](img)
    end
    return input
  end
end

-- Builds a random-order combination of Brightness/Contrast/Saturation from
-- the non-zero fields of `opt`; with none set it returns the identity.
function M.ColorJitter(opt)
  local brightness = opt.brightness or 0
  local contrast = opt.contrast or 0
  local saturation = opt.saturation or 0

  local ts = {}
  if brightness ~= 0 then
    table.insert(ts, M.Brightness(brightness))
  end
  if contrast ~= 0 then
    table.insert(ts, M.Contrast(contrast))
  end
  if saturation ~= 0 then
    table.insert(ts, M.Saturation(saturation))
  end

  if #ts == 0 then
    return function(input) return input end
  end

  return M.RandomOrder(ts)
end

return M
-------------------------------------------------------------------------------- /util.lua: --------------------------------------------------------------------------------
require 'cunn'
local ffi=require 'ffi'

--- Replicate `model` on GPUs 1..nGPU inside an nn.DataParallelTable when
-- nGPU > 1; otherwise return it unchanged. Always leaves opt.GPU selected
-- as the current device.
function makeDataParallel(model, nGPU)
  if nGPU > 1 then
    print('converting module to nn.DataParallelTable')
    assert(nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than nGPU specified')
    local model_single = model
    model = nn.DataParallelTable(1)
    for i=1, nGPU do
      cutorch.setDevice(i)
      model:add(model_single:clone():cuda(), i)
    end
  end
  cutorch.setDevice(opt.GPU)

  return model
end

local function cleanDPT(module)
  -- This assumes this DPT was created by the function above: all the
  -- module.modules are clones of the same network on different GPUs
  -- hence we only need to keep one when saving the model to the disk.
  local newDPT = nn.DataParallelTable(1)
  cutorch.setDevice(opt.GPU)
  newDPT:add(module:get(1), opt.GPU)
  return newDPT
end

--- Save `model` to `filename`. A DataParallelTable at the top level or as
-- a direct child of an nn.Sequential is reduced to its first replica first;
-- an nn.gModule is saved as-is. Other module types raise an error.
function saveDataParallel(filename, model)
  if torch.type(model) == 'nn.DataParallelTable' then
    torch.save(filename, cleanDPT(model))
  elseif torch.type(model) == 'nn.Sequential' then
    local temp_model = nn.Sequential()
    for i, module in ipairs(model.modules) do
      if torch.type(module) == 'nn.DataParallelTable' then
        temp_model:add(cleanDPT(module))
      else
        temp_model:add(module)
      end
    end
    torch.save(filename, temp_model)
  elseif torch.type(model) == 'nn.gModule' then
    torch.save(filename, model)
  else
    error('This saving function only works with Sequential, gModule or DataParallelTable modules.')
  end
end

--- Load `filename` and re-replicate any DataParallelTable (top level or a
-- direct child of an nn.Sequential) over nGPU GPUs. An nn.gModule is
-- returned as loaded. Requires cudnn first when opt.backend == 'cudnn'.
function loadDataParallel(filename, nGPU)
  if opt.backend == 'cudnn' then
    require 'cudnn'
  end
  local model = torch.load(filename)
  if torch.type(model) == 'nn.DataParallelTable' then
    return makeDataParallel(model:get(1), nGPU)
  elseif torch.type(model) == 'nn.Sequential' then
    for i,module in ipairs(model.modules) do
      if torch.type(module) == 'nn.DataParallelTable' then
        model.modules[i] = makeDataParallel(module:get(1):float(), nGPU)
      end
    end
    return model
  elseif torch.type(model) == 'nn.gModule' then
    return model
  else
    error('The loaded model is not a Sequential, gModule or DataParallelTable module.')
  end
end
--------------------------------------------------------------------------------