├── .gitignore
├── EPECriterion.lua
├── LICENSE
├── README.md
├── data.lua
├── dataset.lua
├── donkey.lua
├── extras
│   ├── spybhwd
│   │   ├── CMakeLists.txt
│   │   ├── ScaleBHWD.cu
│   │   ├── ScaleBHWD.lua
│   │   ├── generic
│   │   │   └── ScaleBHWD.c
│   │   ├── init.c
│   │   ├── init.cu
│   │   ├── init.lua
│   │   ├── spybhwd-scm-1.rockspec
│   │   ├── test.lua
│   │   ├── utils.c
│   │   └── utils.h
│   └── stnbhwd
│       ├── AffineGridGeneratorBHWD.lua
│       ├── AffineTransformMatrixGenerator.lua
│       ├── BilinearSamplerBHWD.cu
│       ├── BilinearSamplerBHWD.lua
│       ├── CMakeLists.txt
│       ├── LICENSE
│       ├── README.md
│       ├── ScaleBHWD.cu
│       ├── ScaleBHWD.lua
│       ├── demo
│       │   ├── Optim.lua
│       │   ├── README.md
│       │   ├── demo_mnist.lua
│       │   ├── distort_mnist.lua
│       │   └── spatial_transformer.lua
│       ├── generic
│       │   ├── BilinearSamplerBHWD.c
│       │   └── ScaleBHWD.c
│       ├── init.c
│       ├── init.cu
│       ├── init.lua
│       ├── stnbhwd-scm-1.rockspec
│       ├── test.lua
│       ├── utils.c
│       └── utils.h
├── flowExtensions.lua
├── main.lua
├── model.lua
├── models
│   ├── modelL1_3.t7
│   ├── modelL1_4.t7
│   ├── modelL1_C.t7
│   ├── modelL1_F.t7
│   ├── modelL1_K.t7
│   ├── modelL2_3.t7
│   ├── modelL2_4.t7
│   ├── modelL2_C.t7
│   ├── modelL2_F.t7
│   ├── modelL2_K.t7
│   ├── modelL3_3.t7
│   ├── modelL3_4.t7
│   ├── modelL3_C.t7
│   ├── modelL3_F.t7
│   ├── modelL3_K.t7
│   ├── modelL4_3.t7
│   ├── modelL4_4.t7
│   ├── modelL4_C.t7
│   ├── modelL4_F.t7
│   ├── modelL4_K.t7
│   ├── modelL5_3.t7
│   ├── modelL5_4.t7
│   ├── modelL5_C.t7
│   ├── modelL5_F.t7
│   ├── modelL5_K.t7
│   ├── modelL6_C.t7
│   ├── modelL6_F.t7
│   ├── modelL6_K.t7
│   └── volcon.lua
├── opts.lua
├── samples
│   ├── 00001_flow.flo
│   ├── 00001_img1.ppm
│   ├── 00001_img2.ppm
│   ├── 00002_flow.flo
│   ├── 00002_img1.ppm
│   ├── 00002_img2.ppm
│   ├── 00003_flow.flo
│   ├── 00003_img1.ppm
│   └── 00003_img2.ppm
├── spynet.lua
├── test.lua
├── timing_benchmark.lua
├── timing_util.lua
├── train.lua
├── train_val_split.txt
├── transforms.lua
└── util.lua
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoint*
2 |
--------------------------------------------------------------------------------
/EPECriterion.lua:
--------------------------------------------------------------------------------
1 |
2 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
3 | -- All rights reserved.
4 | -- This software is provided for research purposes only.
5 | -- By using this software you agree to the terms of the license file
6 | -- in the root folder.
7 | -- For commercial use, please contact ps-license@tue.mpg.de.
8 |
9 | local EPECriterion, parent = torch.class('nn.EPECriterion', 'nn.Criterion')
10 |
11 | -- Computes average endpoint error for batchSize x ChannelSize x Height x Width
12 | -- flow fields or general multidimensional matrices.
13 |
14 | local eps = 1e-12
15 |
16 | function EPECriterion:__init()
17 | parent.__init(self)
18 | self.sizeAverage = true
19 | end
20 |
21 | function EPECriterion:updateOutput(input, target)
22 | assert( input:nElement() == target:nElement(),
23 | "input and target size mismatch")
24 |
25 | self.buffer = self.buffer or input.new()
26 |
27 | local buffer = self.buffer
28 | local output
29 | local npixels
30 |
31 | buffer:resizeAs(input)
32 | npixels = input:nElement()/2 -- 2 channel flow fields
33 |
34 | buffer:add(input, -1, target):pow(2)
35 | 	output = torch.sum(buffer,2):sqrt() -- per-pixel EPE: sum over the 2 flow channels (dim 2), then sqrt
36 | output = output:sum()
37 |
38 | output = output / npixels
39 |
40 | self.output = output
41 |
42 | return self.output
43 | end
44 |
45 | function EPECriterion:updateGradInput(input, target)
46 |
47 | assert( input:nElement() == target:nElement(),
48 | "input and target size mismatch")
49 |
50 | self.buffer = self.buffer or input.new()
51 |
52 | local buffer = self.buffer
53 | local gradInput = self.gradInput
54 | local npixels
55 | local loss
56 |
57 | buffer:resizeAs(input)
58 | npixels = input:nElement()/2
59 |
60 | buffer:add(input, -1, target):pow(2)
61 | loss = torch.sum(buffer,2):sqrt():add(eps) -- forms the denominator
62 | loss = torch.cat(loss, loss, 2) -- Repeat tensor to scale the gradients
63 |
64 | gradInput:resizeAs(input)
65 | gradInput:add(input, -1, target):cdiv(loss)
66 | gradInput = gradInput / npixels
67 | return gradInput
68 | end
--------------------------------------------------------------------------------
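For reference, `nn.EPECriterion` above computes the average endpoint error, EPE = (1/P) * Σ_p sqrt((u_p - û_p)² + (v_p - v̂_p)²), over all P pixels of a batch of 2-channel flow fields. A minimal usage sketch (the shapes are illustrative; run it from the repository root so `dofile` finds the file):

```lua
require 'nn'
dofile('EPECriterion.lua')  -- defines nn.EPECriterion

-- a batch of 4 predicted and ground-truth flow fields: 2 channels (u, v), 32x32 pixels
local pred   = torch.rand(4, 2, 32, 32)
local target = torch.rand(4, 2, 32, 32)

local crit = nn.EPECriterion()
local loss = crit:forward(pred, target)   -- mean per-pixel endpoint error
local grad = crit:backward(pred, target)  -- gradient w.r.t. pred, same shape as pred
print(loss, grad:size())
```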
/LICENSE:
--------------------------------------------------------------------------------
1 | SPyNet License
2 | For non-commercial and scientific research purposes
3 | Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft
4 |
5 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the trained models and code, (the "Model"). By downloading and/or using the Model, you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model.
6 |
7 | Ownership
8 | The Model has been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI") and is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”).
9 |
10 | License Grant
11 | Max-Planck grants you a non-exclusive, non-transferable, free of charge right:
12 |
13 | To download the Model and use it on computers owned, leased or otherwise controlled by you and/or your organisation;
14 | To use the Model for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects.
15 | Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, as training data for a commercial product, for commercial ergonomic analysis (e.g. product design, architectural design, etc.), or production of other artifacts for commercial purposes including, for example, web services, movies, television programs, mobile applications, or video games. The Model may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Model to train methods/algorithms/neural networks/etc. for commercial use of any kind. The Model may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. By downloading the Model, you agree not to reverse engineer it.
16 |
17 | Disclaimer of Representations and Warranties
18 | You expressly acknowledge and agree that the Model results from basic research, is provided “AS IS”, may contain errors, and that any use of the Model is at your sole risk. MAX-PLANCK MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE MODEL, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, Max-Planck makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Model, (ii) that the use of the Model will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Model will not cause any damage of any kind to you or a third party.
19 |
20 | Limitation of Liability
21 | Under no circumstances shall Max-Planck be liable for any incidental, special, indirect or consequential damages arising out of or relating to this license, including but not limited to, any lost profits, business interruption, loss of programs or other data, or all other commercial damages or losses, even if advised of the possibility thereof.
22 |
23 | No Maintenance Services
24 | You understand and agree that Max-Planck is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Model. Max-Planck nevertheless reserves the right to update, modify, or discontinue the Model at any time.
25 |
26 | Publication
27 | You agree to cite the most recent paper describing the model as specified on the download website. This website lists the most up to date bibliographic information on the about page.
28 |
29 | Media projects
30 | When using the Model in a media project please give credit to Max Planck Institute for Intelligent Systems. For example: SPyNet was used for optical flow estimation courtesy of the Max Planck Institute for Intelligent Systems.
31 | Commercial licensing opportunities
32 | For commercial use, please contact ps-license@tue.mpg.de.
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SPyNet: Spatial Pyramid Network for Optical Flow
2 | This code is based on the paper [Optical Flow Estimation using a Spatial Pyramid Network](https://arxiv.org/abs/1611.00850).
3 |
4 | [[Unofficial PyTorch version](https://github.com/sniklaus/pytorch-spynet)] [[Unofficial TensorFlow version](https://github.com/tukilabs/Video-Compression-Net/blob/master/utils/network.py)]
5 |
6 | * [First things first:](#setUp) Setting up this code
7 | * [Easy Usage:](#easyUsage) Compute Optical Flow in 5 lines
8 | * [Fast Performance Usage:](#fastPerformanceUsage) Compute Optical Flow at rocket speed
9 | * [Training:](#training) Train your own models using the spatial pyramid approach on multiple GPUs
10 | * [End2End SPyNet:](#end2end) An easily trainable end-to-end version of SPyNet
11 | * [Optical Flow Utilities:](#flowUtils) A set of Lua functions for working with optical flow
12 | * [References:](#references) For further reading
13 |
14 |
15 | ## First things first
16 | You need to have [Torch.](http://torch.ch/docs/getting-started.html#_)
17 |
18 | Install other required packages
19 | ```bash
20 | cd extras/spybhwd
21 | luarocks make
22 | cd ../stnbhwd
23 | luarocks make
24 | ```
25 |
26 | ## For Easy Usage, follow this
27 | #### Set up SPyNet
28 | ```lua
29 | spynet = require('spynet')
30 | easyComputeFlow = spynet.easy_setup()
31 | ```
32 | #### Load images and compute flow
33 | ```lua
34 | im1 = image.load('samples/00001_img1.ppm' )
35 | im2 = image.load('samples/00001_img2.ppm' )
36 | flow = easyComputeFlow(im1, im2)
37 | ```
38 | To save your flow fields to a .flo file use [flowExtensions.writeFLO](#writeFLO).
39 |
40 |
41 | ## For Fast Performance, follow this (recommended)
42 | #### Set up SPyNet
43 | Set up SPyNet according to the image size and model. For optimal performance, resize your images so that width and height are multiples of 32. You can also specify your favorite model. The currently supported modes are the fine-tuned models `sintelFinal` (default), `sintelClean` and `kittiFinal`, and the base models `chairsFinal` and `chairsClean`.
44 | ```lua
45 | spynet = require('spynet')
46 | computeFlow = spynet.setup(512, 384, 'sintelFinal') -- for 384x512 images
47 | ```
48 | Now you can call `computeFlow` anytime to estimate the optical flow between image pairs.
49 |
50 | #### Computing flow
51 | Load an image pair, then stack and normalize it.
52 | ```lua
53 | im1 = image.load('samples/00001_img1.ppm' )
54 | im2 = image.load('samples/00001_img2.ppm' )
55 | im = torch.cat(im1, im2, 1)
56 | im = spynet.normalize(im)
57 | ```
58 | SPyNet works on batches of data on CUDA, so compute the flow using
59 | ```lua
60 | im = im:resize(1, im:size(1), im:size(2), im:size(3)):cuda()
61 | flow = computeFlow(im)
62 | ```
63 | You can also use batch mode: if `im` is a tensor of size `Bx6xHxW` (a batch of `B` stacked RGB image pairs, 6 channels each), you can directly call:
64 | ```lua
65 | flow = computeFlow(im)
66 | ```
67 |
68 | ## Training
69 | Training sequentially is faster than training end-to-end, since only a small number of parameters is learned at each level. To train a level `N`, we need the trained models at levels `1` to `N-1`. We also initialize the level-`N` model with the pretrained model from level `N-1`.
70 | 
71 | E.g., to train level 3, we need the trained models at `L1` and `L2`, and we initialize it with `modelL2_3.t7`.
72 | ```bash
73 | th main.lua -fineWidth 128 -fineHeight 96 -level 3 -netType volcon \
74 | -cache checkpoint -data FLYING_CHAIRS_DIR \
75 | -L1 models/modelL1_3.t7 -L2 models/modelL2_3.t7 \
76 | -retrain models/modelL2_3.t7
77 | ```
78 |
79 | ## End2End SPyNet
80 | The end-to-end version of SPyNet is easily trainable and is available at [anuragranj/end2end-spynet](https://github.com/anuragranj/end2end-spynet).
81 |
82 |
83 | ## Optical Flow Utilities
84 | We provide `flowExtensions.lua`, containing various functions to make your life easier with optical flow while using Torch/Lua. You can just copy this file into your project directory and use it off the shelf.
85 | ```lua
86 | flowX = require 'flowExtensions'
87 | ```
88 | #### [flow_magnitude] flowX.computeNorm(flow_x, flow_y)
89 | Given `flow_x` and `flow_y` of size `MxN` each, evaluate `flow_magnitude` of size `MxN`.
90 |
91 | #### [flow_angle] flowX.computeAngle(flow_x, flow_y)
92 | Given `flow_x` and `flow_y` of size `MxN` each, evaluate `flow_angle` of size `MxN` in degrees.
93 |
94 | #### [rgb] flowX.field2rgb(flow_magnitude, flow_angle, [max], [legend])
95 | Given `flow_magnitude` and `flow_angle` of size `MxN` each, return an image of size `3xMxN` for visualizing optical flow. `max` (optional) specifies the maximum flow magnitude and `legend` (optional) is a boolean that prints a legend on the image.
96 |
97 | #### [rgb] flowX.xy2rgb(flow_x, flow_y, [max])
98 | Given `flow_x` and `flow_y` of size `MxN` each, return an image of size `3xMxN` for visualizing optical flow. `max`(optional) specifies maximum flow magnitude.
99 |
100 | #### [flow] flowX.loadFLO(filename)
101 | Reads a `.flo` file and loads the `x` and `y` components of the optical flow into a 2-channel `2xMxN` flow field. The first channel stores the `x` component, the second the `y` component.
102 |
103 |
104 | #### flowX.writeFLO(filename,F)
105 | Writes a `2xMxN` flow field `F`, whose first and second channels hold the `x` and `y` components of the flow respectively, to `filename`, a `.flo` file.
106 |
107 | #### [flow] flowX.loadPFM(filename)
108 | Reads a `.pfm` file and loads the `x` and `y` components of the optical flow into a 2-channel `2xMxN` flow field. The first channel stores the `x` component, the second the `y` component.
109 |
110 | #### [flow_rotated] flowX.rotate(flow, angle)
111 | Rotates `flow` of size `2xMxN` by `angle` in radians. Uses nearest-neighbor interpolation to avoid blurring at boundaries.
112 |
113 | #### [flow_scaled] flowX.scale(flow, sc, [opt])
114 | Scales `flow` of size `2xMxN` by a factor of `sc`. `opt` (optional) specifies the interpolation method: `simple` (default), `bilinear`, or `bicubic`.
115 |
116 | #### [flowBatch_scaled] flowX.scaleBatch(flowBatch, sc)
117 | Scales `flowBatch` of size `Bx2xMxN`, a batch of `B` flow fields, by a factor of `sc`. Uses nearest-neighbor interpolation.
118 |
119 |
120 | ## Timing Benchmarks
121 | Our timing benchmark is set up on the Flying Chairs dataset. To test it, you need to download the dataset:
122 | ```bash
123 | wget http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs/FlyingChairs.zip
124 | ```
125 | Run the timing benchmark
126 | ```bash
127 | th timing_benchmark.lua -data YOUR_FLYING_CHAIRS_DATA_DIRECTORY
128 | ```
129 |
130 |
131 | ## References
132 | 1. Our warping code is based on [qassemoquab/stnbhwd.](https://github.com/qassemoquab/stnbhwd)
133 | 2. The images in `samples` are from Flying Chairs dataset:
134 | Dosovitskiy, Alexey, et al. "Flownet: Learning optical flow with convolutional networks." 2015 IEEE International Conference on Computer Vision (ICCV). IEEE, 2015.
135 | 3. Some parts of `flowExtensions.lua` are adapted from [marcoscoffier/optical-flow](https://github.com/marcoscoffier/optical-flow/blob/master/init.lua) with help from [fguney](https://github.com/fguney).
136 | 4. The unofficial PyTorch implementation is from [sniklaus](https://github.com/sniklaus).
137 |
138 | ## License
139 | Free for non-commercial and scientific research purposes. For commercial use, please contact ps-license@tue.mpg.de. See the LICENSE file for details.
140 |
141 | ## When using this code, please cite
142 | Ranjan, Anurag, and Michael J. Black. "Optical Flow Estimation using a Spatial Pyramid Network." arXiv preprint arXiv:1611.00850 (2016).
143 |
--------------------------------------------------------------------------------
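As a worked example of the flowExtensions utilities listed in the README, the following sketch loads one of the sample `.flo` files shipped in `samples/`, renders it as an RGB visualization, and writes a rescaled copy back to disk (only functions documented above are used; the output file names are arbitrary):

```lua
require 'image'
local flowX = require 'flowExtensions'

local flow = flowX.loadFLO('samples/00001_flow.flo')  -- 2xMxN: channel 1 = x, channel 2 = y
local mag  = flowX.computeNorm(flow[1], flow[2])      -- MxN flow magnitude
local ang  = flowX.computeAngle(flow[1], flow[2])     -- MxN flow angle, in degrees
local rgb  = flowX.field2rgb(mag, ang)                -- 3xMxN visualization
image.save('flow_vis.png', rgb)

local flow_half = flowX.scale(flow, 0.5)              -- scaled flow field
flowX.writeFLO('flow_half.flo', flow_half)
```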
/data.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 | --
8 | -- Copyright (c) 2014, Facebook, Inc.
9 | -- All rights reserved.
10 | --
11 | -- This source code is licensed under the BSD-style license found in the
12 | -- LICENSE file in the root directory of this source tree. An additional grant
13 | -- of patent rights can be found in the PATENTS file in the same directory.
14 | --
15 | local ffi = require 'ffi'
16 | local Threads = require 'threads'
17 | Threads.serialization('threads.sharedserialize')
18 |
19 | -- This script contains the logic to create K threads for parallel data-loading.
20 | -- For the data-loading details, look at donkey.lua
21 | -------------------------------------------------------------------------------
22 | do -- start K datathreads (donkeys)
23 | if opt.nDonkeys > 0 then
24 | local options = opt -- make an upvalue to serialize over to donkey threads
25 | donkeys = Threads(
26 | opt.nDonkeys,
27 | function()
28 | require 'torch'
29 | end,
30 | function(idx)
31 | opt = options -- pass to all donkeys via upvalue
32 | tid = idx
33 | local seed = opt.manualSeed + idx
34 | torch.manualSeed(seed)
35 | print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
36 | paths.dofile('donkey.lua')
37 | end
38 | );
39 | else -- single threaded data loading. useful for debugging
40 | paths.dofile('donkey.lua')
41 | donkeys = {}
42 | function donkeys:addjob(f1, f2) f2(f1()) end
43 | function donkeys:synchronize() end
44 | end
45 | end
46 |
47 | nTest = 0
48 | donkeys:addjob(function() return testLoader:size() end, function(c) nTest = c end)
49 | donkeys:synchronize()
50 | assert(nTest > 0, "Failed to get nTest")
51 | print('nTest: ', nTest)
52 |
--------------------------------------------------------------------------------
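Both branches above expose the same two-call contract: `donkeys:addjob(producer, consumer)` runs `producer` on a donkey thread and hands its return values to `consumer` on the main thread, and `donkeys:synchronize()` blocks until every queued job has finished. A hedged sketch of the pattern (the `trainLoader` global and the batch size are assumptions borrowed from donkey.lua):

```lua
for i = 1, 4 do
   donkeys:addjob(
      -- producer: runs on a donkey thread, where donkey.lua has set up trainLoader
      function()
         return trainLoader:sample(32)   -- a random batch of (images, flows)
      end,
      -- consumer: runs on the main thread with the producer's return values
      function(inputs, flows)
         print(('job %d: %d samples'):format(i, inputs:size(1)))
      end
   )
end
donkeys:synchronize()   -- wait until all four jobs are done
```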
/dataset.lua:
--------------------------------------------------------------------------------
1 |
2 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
3 | -- All rights reserved.
4 | -- This software is provided for research purposes only.
5 | -- By using this software you agree to the terms of the license file
6 | -- in the root folder.
7 | -- For commercial use, please contact ps-license@tue.mpg.de.
8 |
9 | require 'torch'
10 | torch.setdefaulttensortype('torch.FloatTensor')
11 | local ffi = require 'ffi'
12 | local class = require('pl.class')
13 | local dir = require 'pl.dir'
14 | local tablex = require 'pl.tablex'
15 | local argcheck = require 'argcheck'
16 | require 'sys'
17 | require 'xlua'
18 | require 'image'
19 |
20 | local dataset = torch.class('dataLoader')
21 |
22 | local initcheck = argcheck{
23 | pack=true,
24 | help=[[
25 | A dataset class for images in a flat folder structure (folder-name is class-name).
26 | Optimized for extremely large datasets (upwards of 14 million images).
27 | Tested only on Linux (as it uses command-line linux utilities to scale up)
28 | ]],
29 | {name="inputSize",
30 | type="table",
31 | help="the size of the input images"},
32 |
33 | {name="outputSize",
34 | type="table",
35 | help="the size of the network output"},
36 |
37 | {name="split",
38 | type="number",
39 | help="Percentage of split to go to Training"
40 | },
41 |
42 | {name="samplingMode",
43 | type="string",
44 | help="Sampling mode: random | balanced ",
45 | default = "balanced"},
46 |
47 | {name="verbose",
48 | type="boolean",
49 | help="Verbose mode during initialization",
50 | default = false},
51 |
52 | {name="loadSize",
53 | type="table",
54 | help="a size to load the images to, initially",
55 | opt = true},
56 |
57 | {name="samplingIds",
58 | type="torch.LongTensor",
59 | help="the ids of training or testing images",
60 | opt = true},
61 |
62 | {name="sampleHookTrain",
63 | type="function",
64 | help="applied to sample during training(ex: for lighting jitter). "
65 | .. "It takes the image path as input",
66 | opt = true},
67 |
68 | {name="sampleHookTest",
69 | type="function",
70 | help="applied to sample during testing",
71 | opt = true},
72 | }
73 |
74 | function dataset:__init(...)
75 |
76 | -- argcheck
77 | local args = initcheck(...)
78 | print(args)
79 | for k,v in pairs(args) do self[k] = v end
80 |
81 | if not self.loadSize then self.loadSize = self.inputSize; end
82 |
83 | if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
84 | if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end
85 |
86 | local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end
87 |
88 | self.numSamples = self.samplingIds:size()[1]
89 | assert(self.numSamples > 0, "Could not find any sample in the given input paths")
90 |
91 | if self.verbose then print(self.numSamples .. ' samples found.') end
92 | end
93 |
94 | function dataset:size(class, list)
95 | return self.numSamples
96 | end
97 |
98 | -- converts a table of samples (and corresponding labels) to a clean tensor
99 | local function tableToOutput(self, imgTable, outputTable)
100 | local images, outputs
101 | local quantity = #imgTable
102 | assert(imgTable[1]:size()[1] == self.inputSize[1])
103 | assert(outputTable[1]:size()[1] == self.outputSize[1])
104 |
105 | images = torch.Tensor(quantity,
106 | self.inputSize[1], self.inputSize[2], self.inputSize[3])
107 | outputs = torch.Tensor(quantity,
108 | self.outputSize[1], self.outputSize[2], self.outputSize[3])
109 |
110 | for i=1,quantity do
111 | images[i]:copy(imgTable[i])
112 | outputs[i]:copy(outputTable[i])
113 | end
114 | return images, outputs
115 | end
116 |
117 | -- sampler, samples from the training set.
118 | function dataset:sample(quantity)
119 | assert(quantity)
120 | local imgTable = {}
121 | local outputTable = {}
122 | for i=1,quantity do
123 | local id = torch.random(1, self.numSamples)
124 | local img, output = self:sampleHookTrain(self.samplingIds[id][1]) -- single element[not tensor] from a row
125 |
126 | table.insert(imgTable, img)
127 | table.insert(outputTable, output)
128 | end
129 | local images, outputs = tableToOutput(self, imgTable, outputTable)
130 | return images, outputs
131 | end
132 |
133 | function dataset:get(i1, i2)
134 | local indices = self.samplingIds[{{i1, i2}}];
135 | local quantity = i2 - i1 + 1;
136 | assert(quantity > 0)
137 | local imgTable = {}
138 | local outputTable = {}
139 | for i=1,quantity do
140 | local img, output = self:sampleHookTest(indices[i][1])
141 | table.insert(imgTable, img)
142 | table.insert(outputTable, output)
143 | end
144 | local images, outputs = tableToOutput(self, imgTable, outputTable)
145 | return images, outputs
146 | end
147 |
148 | return dataset
149 |
--------------------------------------------------------------------------------
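A sketch of how a loader script like donkey.lua would instantiate the class above (the sampling hook and the id tensor here are stand-ins; in the real code the hooks read image pairs and flow fields from disk):

```lua
paths.dofile('dataset.lua')

-- a dummy hook; the real hooks load and preprocess samples from disk.
-- It is invoked as self:sampleHookTrain(id), hence the (self, id) signature.
local hook = function(self, id)
   return torch.rand(8, 96, 128), torch.rand(2, 96, 128)
end

local trainLoader = dataLoader{
   inputSize   = {8, 96, 128},   -- stacked image pair plus warped channels
   outputSize  = {2, 96, 128},   -- 2-channel flow field
   split       = 100,            -- percentage of data that goes to training
   samplingIds = torch.range(1, 1000):long():view(-1, 1),
   sampleHookTrain = hook,
   sampleHookTest  = hook,
   verbose     = true,
}

local images, flows   = trainLoader:sample(16)   -- random training batch
local vImages, vFlows = trainLoader:get(1, 16)   -- deterministic range (e.g. validation)
```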
/donkey.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | require 'image'
9 | require 'nn'
10 | require 'cunn'
11 | require 'cudnn'
12 | require 'nngraph'
13 | require 'stn'
14 | require 'spy'
15 |
16 | local flowX = require 'flowExtensions'
17 | local TF = require 'transforms'
18 |
19 | paths.dofile('dataset.lua')
20 | paths.dofile('util.lua')
21 |
22 | -- This file contains the data-loading logic and details.
23 | -- It is run by each data-loader thread.
24 | ------------------------------------------
25 | local eps = 1e-6
26 | -- a cache file of the training metadata (if it doesn't exist, it will be created)
27 | local trainCache = paths.concat(opt.cache, 'trainCache.t7')
28 | local testCache = paths.concat(opt.cache, 'testCache.t7')
29 |
30 | local meanstd = {
31 | mean = { 0.485, 0.456, 0.406 },
32 | std = { 0.229, 0.224, 0.225 },
33 | }
34 | local pca = {
35 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 },
36 | eigvec = torch.Tensor{
37 | { -0.5675, 0.7192, 0.4009 },
38 | { -0.5808, -0.0045, -0.8140 },
39 | { -0.5836, -0.6948, 0.4203 },
40 | },
41 | }
42 |
43 | local mean = meanstd.mean
44 | local std = meanstd.std
45 | ------------------------------------------
46 | -- Warping Function:
47 | local function createWarpModel()
48 | local imgData = nn.Identity()()
49 | local floData = nn.Identity()()
50 |
51 | local imgOut = nn.Transpose({2,3},{3,4})(imgData)
52 | local floOut = nn.Transpose({2,3},{3,4})(floData)
53 |
54 | local warpImOut = nn.Transpose({3,4},{2,3})(nn.BilinearSamplerBHWD()({imgOut, floOut}))
55 | local model = nn.gModule({imgData, floData}, {warpImOut})
56 |
57 | return model
58 | end
59 |
60 | local modelL1, modelL2, modelL3, modelL4
61 | local modelL1path, modelL2path, modelL3path, modelL4path
62 | local down1, down2, down3, down4, up2, up3, up4
63 | local warpmodel2, warpmodel3, warpmodel4
64 |
65 | modelL1path = opt.L1
66 | modelL2path = opt.L2
67 | modelL3path = opt.L3
68 | modelL4path = opt.L4
69 |
70 | if opt.level > 1 then
71 | -- Load modelL1
72 | modelL1 = torch.load(modelL1path)
73 | if torch.type(modelL1) == 'nn.DataParallelTable' then
74 | modelL1 = modelL1:get(1)
75 | end
76 | modelL1:evaluate()
77 | down1 = nn.SpatialAveragePooling(2,2,2,2):cuda()
78 | down1:evaluate()
79 | end
80 |
81 | if opt.level > 2 then
82 | -- Load modelL2
83 | modelL2 = torch.load(modelL2path)
84 | if torch.type(modelL2) == 'nn.DataParallelTable' then
85 | modelL2 = modelL2:get(1)
86 | end
87 | modelL2:evaluate()
88 |
89 | down2 = nn.SpatialAveragePooling(2,2,2,2):cuda()
90 | up2 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
91 | warpmodel2 = createWarpModel():cuda()
92 |
93 | down2:evaluate()
94 | up2:evaluate()
95 | warpmodel2:evaluate()
96 | end
97 |
98 | if opt.level > 3 then
99 | -- Load modelL3
100 | modelL3 = torch.load(modelL3path)
101 | if torch.type(modelL3) == 'nn.DataParallelTable' then
102 | modelL3 = modelL3:get(1)
103 | end
104 | modelL3:evaluate()
105 |
106 | down3 = nn.SpatialAveragePooling(2,2,2,2):cuda()
107 | up3 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
108 | warpmodel3 = createWarpModel():cuda()
109 |
110 | down3:evaluate()
111 | up3:evaluate()
112 | warpmodel3:evaluate()
113 | end
114 |
115 | if opt.level > 4 then
116 | -- Load modelL4
117 | modelL4 = torch.load(modelL4path)
118 | if torch.type(modelL4) == 'nn.DataParallelTable' then
119 | modelL4 = modelL4:get(1)
120 | end
121 | modelL4:evaluate()
122 |
123 | down4 = nn.SpatialAveragePooling(2,2,2,2):cuda()
124 | up4 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
125 | warpmodel4 = createWarpModel():cuda()
126 |
127 | down4:evaluate()
128 | up4:evaluate()
129 | warpmodel4:evaluate()
130 | end
131 |
132 | -- Check for existence of opt.data
133 | if not os.execute('cd ' .. opt.data) then
134 | error(("could not chdir to '%s'"):format(opt.data))
135 | end
136 |
137 | local loadSize = opt.loadSize
138 | local inputSize = {8, opt.fineHeight, opt.fineWidth}
139 | local outputSize = {2, opt.fineHeight, opt.fineWidth}
140 |
141 | local function getTrainValidationSplits(path)
142 | local numSamples = sys.fexecute( "ls " .. opt.data .. "| wc -l")/3
143 | local ff = torch.DiskFile(path, 'r')
144 | local trainValidationSamples = torch.IntTensor(numSamples)
145 | ff:readInt(trainValidationSamples:storage())
146 | ff:close()
147 |
148 | local train_samples = trainValidationSamples:eq(1):nonzero()
149 | local validation_samples = trainValidationSamples:eq(2):nonzero()
150 |
151 | return train_samples, validation_samples
152 | end
153 |
154 | local train_samples, validation_samples = getTrainValidationSplits(opt.trainValidationSplit)
155 |
156 | local function loadImage(path)
157 | local input = image.load(path, 3, 'float')
158 | return input
159 | end
160 |
161 | local function rotateFlow(flow, angle)
162 | local flow_rot = image.rotate(flow, angle)
163 | local fu = torch.mul(flow_rot[1], math.cos(-angle)) - torch.mul(flow_rot[2], math.sin(-angle))
164 | local fv = torch.mul(flow_rot[1], math.sin(-angle)) + torch.mul(flow_rot[2], math.cos(-angle))
165 | flow_rot[1]:copy(fu)
166 | flow_rot[2]:copy(fv)
167 |
168 | return flow_rot
169 | end
170 |
171 | local function scaleFlow(flow, height, width)
172 | -- scale the original flow to a flow of size height x width
173 | local sc = height/flow:size(2)
174 |   assert(torch.abs(width/flow:size(3) - sc) < eps, 'Aspect ratio of output flow is not the same as input flow')
175 |   local flow_scaled = image.scale(flow, width, height)*sc
176 |   return flow_scaled
177 | end
--------------------------------------------------------------------------------
/extras/spybhwd/CMakeLists.txt:
--------------------------------------------------------------------------------
29 |       MESSAGE(STATUS "Use gcc >= 4.6.2 or change your OS to enable OpenMP")
30 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unknown-pragmas")
31 | SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?" FORCE)
32 | ENDIF ()
33 | ENDIF ()
34 |
35 | IF (WITH_OPENMP)
36 | FIND_PACKAGE(OpenMP)
37 | IF(OPENMP_FOUND)
38 | MESSAGE(STATUS "Compiling with OpenMP support")
39 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
40 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
41 | SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
42 | ENDIF(OPENMP_FOUND)
43 | ENDIF (WITH_OPENMP)
44 |
45 | LINK_DIRECTORIES("${Torch_INSTALL_LIB}")
46 |
47 | SET(src init.c)
48 | FILE(GLOB luasrc *.lua)
49 | ADD_TORCH_PACKAGE(spy "${src}" "${luasrc}")
50 | TARGET_LINK_LIBRARIES(spy luaT TH)
51 |
52 |
53 | FIND_PACKAGE(CUDA 5.5)
54 |
55 | IF (CUDA_FOUND)
56 | LIST(APPEND CUDA_NVCC_FLAGS "-arch=sm_20")
57 | LIST(APPEND CUDA_NVCC_FLAGS "-Xcompiler -std=c++98")
58 |
59 | INCLUDE_DIRECTORIES("${Torch_INSTALL_INCLUDE}/THC")
60 | SET(src-cuda init.cu)
61 | CUDA_ADD_LIBRARY(cuspy MODULE ${src-cuda})
62 | TARGET_LINK_LIBRARIES(cuspy luaT THC TH)
63 | IF(APPLE)
64 | SET_TARGET_PROPERTIES(cuspy PROPERTIES
65 | LINK_FLAGS "-undefined dynamic_lookup")
66 | ENDIF()
67 | ### Torch packages suppose the library prefix is "lib"
68 | SET_TARGET_PROPERTIES(cuspy PROPERTIES
69 | PREFIX "lib"
70 | IMPORT_PREFIX "lib")
71 |
72 | INSTALL(TARGETS cuspy
73 | RUNTIME DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}"
74 | LIBRARY DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}")
75 | ENDIF(CUDA_FOUND)
76 |
--------------------------------------------------------------------------------
/extras/spybhwd/ScaleBHWD.lua:
--------------------------------------------------------------------------------
1 | local ScaleBHWD, parent = torch.class('nn.ScaleBHWD', 'nn.Module')
2 |
3 | --[[
4 | ScaleBHWD() :
5 | ScaleBHWD:updateOutput({inputImages, grids})
6 | ScaleBHWD:updateGradInput({inputImages, grids}, gradOutput)
7 |
8 | ScaleBHWD will perform bilinear sampling of the input images according to the
9 | normalized coordinates provided in the grid. Output will be of same size as the grids,
10 | with as many features as the input images.
11 |
12 | - inputImages has to be in BHWD layout
13 |
14 | - grids have to be in BHWD layout, with dim(D)=2
15 | - grids contains, for each sample (first dim), the normalized coordinates of the output wrt the input sample
16 | - first coordinate is Y coordinate, second is X
17 | - normalized coordinates : (-1,-1) points to top left, (-1,1) points to top right
18 | - if the normalized coordinates fall outside of the image, then output will be filled with zeros
19 | ]]
20 |
21 | function ScaleBHWD:__init(scale)
22 | parent.__init(self)
23 | self.scale = scale or 1
24 | end
25 |
26 | function ScaleBHWD:check(input, gradOutput)
27 | local inputImages = input
28 | -- local grids = input[2]
29 |
30 | assert(inputImages:isContiguous(), 'Input images have to be contiguous')
31 | assert(inputImages:nDimension()==4)
32 | -- assert(grids:nDimension()==4)
33 | -- assert(inputImages:size(1)==grids:size(1)) -- batch
34 | -- assert(grids:size(4)==2) -- coordinates
35 |
36 | -- if gradOutput then
37 | -- TODO: checks for output size here
38 | -- assert(inputImages:size(1)==gradOutput:size(1))
39 | -- assert(inputImages:size(2)==gradOutput:size(2))
40 | -- assert(inputImages:size(3)==gradOutput:size(3))
41 | -- end
42 | end
43 |
44 | local function addOuterDim(t)
45 | local sizes = t:size()
46 | local newsizes = torch.LongStorage(sizes:size()+1)
47 | newsizes[1]=1
48 | for i=1,sizes:size() do
49 | newsizes[i+1]=sizes[i]
50 | end
51 | return t:view(newsizes)
52 | end
53 |
54 | function ScaleBHWD:updateOutput(input)
55 | local _inputImages = input
56 | -- local _grids = input[2]
57 |
58 | local inputImages
59 | if _inputImages:nDimension()==3 then
60 | inputImages = addOuterDim(_inputImages)
61 | -- grids = addOuterDim(_grids)
62 | else
63 | inputImages = _inputImages
64 | -- grids = _grids
65 | end
66 |
67 | local input = inputImages
68 |
69 | self:check(input)
70 |
71 | self.output:resize(inputImages:size(1), self.scale*inputImages:size(2), self.scale*inputImages:size(3), inputImages:size(4))
72 |
73 | inputImages.nn.ScaleBHWD_updateOutput(self, inputImages, self.output)
74 |
75 | if _inputImages:nDimension()==3 then
76 | self.output=self.output:select(1,1)
77 | end
78 |
79 | return self.output
80 | end
81 |
82 | function ScaleBHWD:updateGradInput(_input, _gradOutput)
83 |    -- self.gradInput is resized below, after the batch dimension has been normalized
84 | local _inputImages = _input
85 |
86 | local inputImages, gradOutput
87 | if _inputImages:nDimension()==3 then
88 | inputImages = addOuterDim(_inputImages)
89 | gradOutput = addOuterDim(_gradOutput)
90 | else
91 | inputImages = _inputImages
92 | gradOutput = _gradOutput
93 | end
94 |
95 | local input = inputImages
96 |
97 | self:check(input, gradOutput)
98 | -- for i=1,#input do
99 | self.gradInput = self.gradInput or input.new()
100 | self.gradInput:resizeAs(input):zero()
101 | -- end
102 |
103 |    local gradInputImages = self.gradInput -- the input is a single tensor, so the whole gradInput is the image gradient
104 | --local gradGrids = self.gradInput[2]
105 |
106 | inputImages.nn.ScaleBHWD_updateGradInput(self, inputImages, gradInputImages, gradOutput)
107 |
108 | if _gradOutput:nDimension()==3 then
109 | self.gradInput=self.gradInput:select(1,1)
110 | -- self.gradInput[2]=self.gradInput[2]:select(1,1)
111 | end
112 |
113 | return self.gradInput
114 | end
115 |
--------------------------------------------------------------------------------
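Because `nn.ScaleBHWD` expects BHWD tensors, callers that keep data in Torch's usual BDHW layout sandwich it between two `nn.Transpose` modules; donkey.lua builds its `up2`/`up3`/`up4` upsamplers exactly this way. A CPU sketch of 2x upsampling (assuming the package has been built with `luarocks make`, which installs it as `spy`):

```lua
require 'nn'
require 'spy'   -- provides nn.ScaleBHWD

local upsample = nn.Sequential()
   :add(nn.Transpose({2,3},{3,4}))   -- BDHW -> BHWD
   :add(nn.ScaleBHWD(2))             -- bilinear 2x upsampling of H and W
   :add(nn.Transpose({3,4},{2,3}))   -- BHWD -> BDHW

local flow = torch.rand(1, 2, 24, 32)   -- a 2-channel field, 24x32
local flow_up = upsample:forward(flow)  -- 1x2x48x64
print(flow_up:size())
```

Note that `ScaleBHWD` only resamples the tensor; when the tensor is a flow field, the flow values themselves still need to be multiplied by the scale factor, as donkey.lua's `scaleFlow` does.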
/extras/spybhwd/generic/ScaleBHWD.c:
--------------------------------------------------------------------------------
1 | #ifndef TH_GENERIC_FILE
2 | #define TH_GENERIC_FILE "generic/ScaleBHWD.c"
3 | #else
4 |
5 | #include <stdbool.h>
6 |
7 |
8 | static int nn_(ScaleBHWD_updateOutput)(lua_State *L)
9 | {
10 | THTensor *inputImages = luaT_checkudata(L, 2, torch_Tensor);
11 | //THTensor *grids = luaT_checkudata(L, 3, torch_Tensor);
12 | //real scale = luaT_getfieldchecknumber(L, 3, "scale");
13 | THTensor *output = luaT_checkudata(L, 3, torch_Tensor);
14 |
15 | int batchsize = inputImages->size[0];
16 | int inputImages_height = inputImages->size[1];
17 | int inputImages_width = inputImages->size[2];
18 | int output_height = output->size[1];
19 | int output_width = output->size[2];
20 | int inputImages_channels = inputImages->size[3];
21 |
22 | int output_strideBatch = output->stride[0];
23 | int output_strideHeight = output->stride[1];
24 | int output_strideWidth = output->stride[2];
25 |
26 | int inputImages_strideBatch = inputImages->stride[0];
27 | int inputImages_strideHeight = inputImages->stride[1];
28 | int inputImages_strideWidth = inputImages->stride[2];
29 |
30 | // int grids_strideBatch = grids->stride[0];
31 | // int grids_strideHeight = grids->stride[1];
32 | // int grids_strideWidth = grids->stride[2];
33 |
34 | real *inputImages_data, *output_data;
35 | inputImages_data = THTensor_(data)(inputImages);
36 | output_data = THTensor_(data)(output);
37 | // grids_data = THTensor_(data)(grids);
38 |
39 | int b, yOut, xOut;
40 |
41 | for(b=0; b < batchsize; b++)
42 | {
43 | for(yOut=0; yOut < output_height; yOut++)
44 | {
45 | for(xOut=0; xOut < output_width; xOut++)
46 | {
47 | //read the grid
48 | //real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
49 | //real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];
50 |
51 | // get the weights for interpolation
52 | int yInTopLeft, xInTopLeft;
53 | real yWeightTopLeft, xWeightTopLeft;
54 |
55 |         real xcoord = (real)((inputImages_width - 1)*xOut) / (output_width - 1); // cast avoids integer division
56 | xInTopLeft = floor(xcoord);
57 | xWeightTopLeft = 1 - (xcoord - xInTopLeft);
58 |
59 |         real ycoord = (real)((inputImages_height - 1)*yOut) / (output_height - 1); // cast avoids integer division
60 | yInTopLeft = floor(ycoord);
61 | yWeightTopLeft = 1 - (ycoord - yInTopLeft);
62 |
63 |
64 |
65 | const int outAddress = output_strideBatch * b + output_strideHeight * yOut + output_strideWidth * xOut;
66 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
67 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
68 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
69 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;
70 |
71 | real v=0;
72 | real inTopLeft=0;
73 | real inTopRight=0;
74 | real inBottomLeft=0;
75 | real inBottomRight=0;
76 |
77 | // we are careful with the boundaries
78 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
79 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
80 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
81 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
82 |
83 | int t;
84 | // interpolation happens here
85 |         for(t=0; t < inputImages_channels; t++)
86 |         {
87 |            if(topLeftIsIn) inTopLeft = inputImages_data[inTopLeftAddress + t];
88 |            if(topRightIsIn) inTopRight = inputImages_data[inTopRightAddress + t];
89 |            if(bottomLeftIsIn) inBottomLeft = inputImages_data[inBottomLeftAddress + t];
90 |            if(bottomRightIsIn) inBottomRight = inputImages_data[inBottomRightAddress + t];
91 | 
92 |            v = xWeightTopLeft * yWeightTopLeft * inTopLeft
93 |              + (1 - xWeightTopLeft) * yWeightTopLeft * inTopRight
94 |              + xWeightTopLeft * (1 - yWeightTopLeft) * inBottomLeft
95 |              + (1 - xWeightTopLeft) * (1 - yWeightTopLeft) * inBottomRight;
96 | 
97 |            output_data[outAddress + t] = v;
98 |         }
99 |       }
100 |     }
101 |   }
102 | 
103 |   return 1;
104 | }
105 | 
106 | 
107 | 
108 | static int nn_(ScaleBHWD_updateGradInput)(lua_State *L)
109 | {
110 |   THTensor *inputImages = luaT_checkudata(L, 2, torch_Tensor);
111 |   //THTensor *grids = luaT_checkudata(L, 3, torch_Tensor);
112 |   //real scale = luaT_getfieldchecknumber(L, 3, "scale");
113 |   THTensor *gradInputImages = luaT_checkudata(L, 3, torch_Tensor);
114 |   //THTensor *gradGrids = luaT_checkudata(L, 5, torch_Tensor);
115 |   THTensor *gradOutput = luaT_checkudata(L, 4, torch_Tensor);
116 | 
117 | 
118 |   int batchsize = inputImages->size[0];
119 | int inputImages_height = inputImages->size[1];
120 | int inputImages_width = inputImages->size[2];
121 | int gradOutput_height = gradOutput->size[1];
122 | int gradOutput_width = gradOutput->size[2];
123 | int inputImages_channels = inputImages->size[3];
124 |
125 | int gradOutput_strideBatch = gradOutput->stride[0];
126 | int gradOutput_strideHeight = gradOutput->stride[1];
127 | int gradOutput_strideWidth = gradOutput->stride[2];
128 |
129 | int inputImages_strideBatch = inputImages->stride[0];
130 | int inputImages_strideHeight = inputImages->stride[1];
131 | int inputImages_strideWidth = inputImages->stride[2];
132 |
133 | int gradInputImages_strideBatch = gradInputImages->stride[0];
134 | int gradInputImages_strideHeight = gradInputImages->stride[1];
135 | int gradInputImages_strideWidth = gradInputImages->stride[2];
136 |
137 | // int grids_strideBatch = grids->stride[0];
138 | // int grids_strideHeight = grids->stride[1];
139 | // int grids_strideWidth = grids->stride[2];
140 |
141 | // int gradGrids_strideBatch = gradGrids->stride[0];
142 | // int gradGrids_strideHeight = gradGrids->stride[1];
143 | // int gradGrids_strideWidth = gradGrids->stride[2];
144 |
145 | real *inputImages_data, *gradOutput_data, *gradInputImages_data;
146 | inputImages_data = THTensor_(data)(inputImages);
147 | gradOutput_data = THTensor_(data)(gradOutput);
148 | // grids_data = THTensor_(data)(grids);
149 | // gradGrids_data = THTensor_(data)(gradGrids);
150 | gradInputImages_data = THTensor_(data)(gradInputImages);
151 |
152 | int b, yOut, xOut;
153 |
154 | for(b=0; b < batchsize; b++)
155 | {
156 | for(yOut=0; yOut < gradOutput_height; yOut++)
157 | {
158 | for(xOut=0; xOut < gradOutput_width; xOut++)
159 | {
160 | //read the grid
161 | //real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
162 | //real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];
163 |
164 | // get the weights for interpolation
165 | int yInTopLeft, xInTopLeft;
166 | real yWeightTopLeft, xWeightTopLeft;
167 |
168 |         real xcoord = (real)((inputImages_width - 1)*xOut) / (gradOutput_width - 1); // cast avoids integer division
169 | xInTopLeft = floor(xcoord);
170 | xWeightTopLeft = 1 - (xcoord - xInTopLeft);
171 |
172 |         real ycoord = (real)((inputImages_height - 1)*yOut) / (gradOutput_height - 1); // cast avoids integer division
173 | yInTopLeft = floor(ycoord);
174 | yWeightTopLeft = 1 - (ycoord - yInTopLeft);
175 |
176 |
177 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
178 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
179 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
180 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;
181 |
182 | const int gradInputImagesTopLeftAddress = gradInputImages_strideBatch * b + gradInputImages_strideHeight * yInTopLeft + gradInputImages_strideWidth * xInTopLeft;
183 | const int gradInputImagesTopRightAddress = gradInputImagesTopLeftAddress + gradInputImages_strideWidth;
184 | const int gradInputImagesBottomLeftAddress = gradInputImagesTopLeftAddress + gradInputImages_strideHeight;
185 | const int gradInputImagesBottomRightAddress = gradInputImagesBottomLeftAddress + gradInputImages_strideWidth;
186 |
187 | const int gradOutputAddress = gradOutput_strideBatch * b + gradOutput_strideHeight * yOut + gradOutput_strideWidth * xOut;
188 |
189 | real topLeftDotProduct = 0;
190 | real topRightDotProduct = 0;
191 | real bottomLeftDotProduct = 0;
192 | real bottomRightDotProduct = 0;
193 |
194 | real v=0;
195 | real inTopLeft=0;
196 | real inTopRight=0;
197 | real inBottomLeft=0;
198 | real inBottomRight=0;
199 |
200 | // we are careful with the boundaries
201 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
202 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
203 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
204 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
205 |
206 | int t;
207 |
208 |         for(t=0; t < inputImages_channels; t++)
209 |         {
210 |            real gradOutValue = gradOutput_data[gradOutputAddress + t];
211 |            if(topLeftIsIn)
212 |            {
213 |               real inTopLeft = inputImages_data[inTopLeftAddress + t];
214 |               topLeftDotProduct += inTopLeft * gradOutValue;
215 |               gradInputImages_data[gradInputImagesTopLeftAddress + t] += xWeightTopLeft * yWeightTopLeft * gradOutValue;
216 |            }
217 | 
218 |            if(topRightIsIn)
219 |            {
220 |               real inTopRight = inputImages_data[inTopRightAddress + t];
221 |               topRightDotProduct += inTopRight * gradOutValue;
222 |               gradInputImages_data[gradInputImagesTopRightAddress + t] += (1 - xWeightTopLeft) * yWeightTopLeft * gradOutValue;
223 |            }
224 | 
225 |            if(bottomLeftIsIn)
226 |            {
227 |               real inBottomLeft = inputImages_data[inBottomLeftAddress + t];
228 |               bottomLeftDotProduct += inBottomLeft * gradOutValue;
229 |               gradInputImages_data[gradInputImagesBottomLeftAddress + t] += xWeightTopLeft * (1 - yWeightTopLeft) * gradOutValue;
230 |            }
231 | 
232 |            if(bottomRightIsIn)
233 |            {
234 |               real inBottomRight = inputImages_data[inBottomRightAddress + t];
235 |               bottomRightDotProduct += inBottomRight * gradOutValue;
236 |               gradInputImages_data[gradInputImagesBottomRightAddress + t] += (1 - xWeightTopLeft) * (1 - yWeightTopLeft) * gradOutValue;
237 |            }
238 |         }
239 |       }
240 |     }
241 |   }
242 | 
243 |   return 1;
244 | }
245 | 
246 | static const struct luaL_Reg nn_(ScaleBHWD__) [] = {
247 |   {"ScaleBHWD_updateOutput", nn_(ScaleBHWD_updateOutput)},
248 |   {"ScaleBHWD_updateGradInput", nn_(ScaleBHWD_updateGradInput)},
249 |   {NULL, NULL}
250 | };
251 | 
252 | static void nn_(ScaleBHWD_init)(lua_State *L)
253 | {
254 |   luaT_pushmetatable(L, torch_Tensor);
255 |   luaT_registeratname(L, nn_(ScaleBHWD__), "nn");
256 |   lua_pop(L,1);
257 | }
258 | 
259 | #endif
260 | 
--------------------------------------------------------------------------------
/extras/spybhwd/spybhwd-scm-1.rockspec:
--------------------------------------------------------------------------------
16 | dependencies = {
17 |    "torch >= 7.0",
18 | "nn >= 1.0",
19 | }
20 |
21 | build = {
22 | type = "command",
23 | build_command = [[
24 | cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE)
25 | ]],
26 | install_command = "cd build && $(MAKE) install"
27 | }
28 |
--------------------------------------------------------------------------------
/extras/spybhwd/test.lua:
--------------------------------------------------------------------------------
1 | -- you can easily test specific units like this:
2 | -- th -lnn -e "nn.test{'LookupTable'}"
3 | -- th -lnn -e "nn.test{'LookupTable', 'Add'}"
4 |
5 | local mytester = torch.Tester()
6 | local jac
7 | local sjac
8 |
9 | local precision = 1e-5
10 | local expprecision = 1e-4
11 |
12 | local stntest = {}
13 |
14 | function stntest.AffineGridGeneratorBHWD_batch()
15 | local nframes = torch.random(2,10)
16 | local height = torch.random(2,5)
17 | local width = torch.random(2,5)
18 | local input = torch.zeros(nframes, 2, 3):uniform()
19 | local module = nn.AffineGridGeneratorBHWD(height, width)
20 |
21 | local err = jac.testJacobian(module,input)
22 | mytester:assertlt(err,precision, 'error on state ')
23 |
24 | -- IO
25 | local ferr,berr = jac.testIO(module,input)
26 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
27 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
28 |
29 | end
30 |
31 | function stntest.AffineGridGeneratorBHWD_single()
32 | local height = torch.random(2,5)
33 | local width = torch.random(2,5)
34 | local input = torch.zeros(2, 3):uniform()
35 | local module = nn.AffineGridGeneratorBHWD(height, width)
36 |
37 | local err = jac.testJacobian(module,input)
38 | mytester:assertlt(err,precision, 'error on state ')
39 |
40 | -- IO
41 | local ferr,berr = jac.testIO(module,input)
42 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
43 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
44 |
45 | end
46 |
47 | function stntest.BilinearSamplerBHWD_batch()
48 | local nframes = torch.random(2,10)
49 | local height = torch.random(1,5)
50 | local width = torch.random(1,5)
51 | local channels = torch.random(1,6)
52 | local inputImages = torch.zeros(nframes, height, width, channels):uniform()
53 | local grids = torch.zeros(nframes, height, width, 2):uniform(-1, 1)
54 | local module = nn.BilinearSamplerBHWD()
55 |
56 | -- test input images (first element of input table)
57 | module._updateOutput = module.updateOutput
58 | function module:updateOutput(input)
59 | return self:_updateOutput({input, grids})
60 | end
61 |
62 | module._updateGradInput = module.updateGradInput
63 | function module:updateGradInput(input, gradOutput)
64 | self:_updateGradInput({input, grids}, gradOutput)
65 | return self.gradInput[1]
66 | end
67 |
68 | local errImages = jac.testJacobian(module,inputImages)
69 | mytester:assertlt(errImages,precision, 'error on state ')
70 |
71 | -- test grids (second element of input table)
72 | function module:updateOutput(input)
73 | return self:_updateOutput({inputImages, input})
74 | end
75 |
76 | function module:updateGradInput(input, gradOutput)
77 | self:_updateGradInput({inputImages, input}, gradOutput)
78 | return self.gradInput[2]
79 | end
80 |
81 | local errGrids = jac.testJacobian(module,grids)
82 | mytester:assertlt(errGrids,precision, 'error on state ')
83 | end
84 |
85 | function stntest.BilinearSamplerBHWD_single()
86 | local height = torch.random(1,5)
87 | local width = torch.random(1,5)
88 | local channels = torch.random(1,6)
89 | local inputImages = torch.zeros(height, width, channels):uniform()
90 | local grids = torch.zeros(height, width, 2):uniform(-1, 1)
91 | local module = nn.BilinearSamplerBHWD()
92 |
93 | -- test input images (first element of input table)
94 | module._updateOutput = module.updateOutput
95 | function module:updateOutput(input)
96 | return self:_updateOutput({input, grids})
97 | end
98 |
99 | module._updateGradInput = module.updateGradInput
100 | function module:updateGradInput(input, gradOutput)
101 | self:_updateGradInput({input, grids}, gradOutput)
102 | return self.gradInput[1]
103 | end
104 |
105 | local errImages = jac.testJacobian(module,inputImages)
106 | mytester:assertlt(errImages,precision, 'error on state ')
107 |
108 | -- test grids (second element of input table)
109 | function module:updateOutput(input)
110 | return self:_updateOutput({inputImages, input})
111 | end
112 |
113 | function module:updateGradInput(input, gradOutput)
114 | self:_updateGradInput({inputImages, input}, gradOutput)
115 | return self.gradInput[2]
116 | end
117 |
118 | local errGrids = jac.testJacobian(module,grids)
119 | mytester:assertlt(errGrids,precision, 'error on state ')
120 | end
121 |
122 | function stntest.AffineTransformMatrixGenerator_batch()
123 | -- test all possible transformations
124 | for _,useRotation in pairs{true,false} do
125 | for _,useScale in pairs{true,false} do
126 | for _,useTranslation in pairs{true,false} do
127 | local currTest = ''
128 | if useRotation then currTest = currTest..'rotation ' end
129 | if useScale then currTest = currTest..'scale ' end
130 | if useTranslation then currTest = currTest..'translation' end
131 | if currTest=='' then currTest = 'full' end
132 |
133 | local nbNeededParams = 0
134 | if useRotation then nbNeededParams = nbNeededParams + 1 end
135 | if useScale then nbNeededParams = nbNeededParams + 1 end
136 | if useTranslation then nbNeededParams = nbNeededParams + 2 end
137 | if nbNeededParams == 0 then nbNeededParams = 6 end -- full affine case
138 |
139 | local nframes = torch.random(2,10)
140 | local params = torch.zeros(nframes,nbNeededParams):uniform()
141 | local module = nn.AffineTransformMatrixGenerator(useRotation,useScale,useTranslation)
142 |
143 | local err = jac.testJacobian(module,params)
144 | mytester:assertlt(err,precision, 'error on state for test '..currTest)
145 |
146 | -- IO
147 | local ferr,berr = jac.testIO(module,params)
148 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for test '..currTest)
149 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err for test '..currTest)
150 |
151 | end
152 | end
153 | end
154 | end
155 |
156 | function stntest.AffineTransformMatrixGenerator_single()
157 | -- test all possible transformations
158 | for _,useRotation in pairs{true,false} do
159 | for _,useScale in pairs{true,false} do
160 | for _,useTranslation in pairs{true,false} do
161 | local currTest = ''
162 | if useRotation then currTest = currTest..'rotation ' end
163 | if useScale then currTest = currTest..'scale ' end
164 | if useTranslation then currTest = currTest..'translation' end
165 | if currTest=='' then currTest = 'full' end
166 |
167 | local nbNeededParams = 0
168 | if useRotation then nbNeededParams = nbNeededParams + 1 end
169 | if useScale then nbNeededParams = nbNeededParams + 1 end
170 | if useTranslation then nbNeededParams = nbNeededParams + 2 end
171 | if nbNeededParams == 0 then nbNeededParams = 6 end -- full affine case
172 |
173 | local params = torch.zeros(nbNeededParams):uniform()
174 | local module = nn.AffineTransformMatrixGenerator(useRotation,useScale,useTranslation)
175 |
176 | local err = jac.testJacobian(module,params)
177 | mytester:assertlt(err,precision, 'error on state for test '..currTest)
178 |
179 | -- IO
180 | local ferr,berr = jac.testIO(module,params)
181 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for test '..currTest)
182 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err for test '..currTest)
183 |
184 | end
185 | end
186 | end
187 | end
188 |
189 | mytester:add(stntest)
190 |
191 | if not nn then
192 | require 'nn'
193 | jac = nn.Jacobian
194 | sjac = nn.SparseJacobian
195 | mytester:run()
196 | else
197 | jac = nn.Jacobian
198 | sjac = nn.SparseJacobian
199 | function stn.test(tests)
200 | -- randomize stuff
201 | math.randomseed(os.time())
202 | mytester:run(tests)
203 | return mytester
204 | end
205 | end
206 |
--------------------------------------------------------------------------------
/extras/spybhwd/utils.c:
--------------------------------------------------------------------------------
1 | #include "utils.h"
2 |
3 | THCState* getCutorchState(lua_State* L)
4 | {
5 | lua_getglobal(L, "cutorch");
6 | lua_getfield(L, -1, "getState");
7 | lua_call(L, 0, 1);
8 | THCState *state = (THCState*) lua_touserdata(L, -1);
9 | lua_pop(L, 2);
10 | return state;
11 | }
12 |
--------------------------------------------------------------------------------
/extras/spybhwd/utils.h:
--------------------------------------------------------------------------------
1 | #ifndef CUNN_UTILS_H
2 | #define CUNN_UTILS_H
3 |
4 | #include <lua.h>
5 | #include "THCGeneral.h"
6 |
7 | THCState* getCutorchState(lua_State* L);
8 |
9 | #endif
10 |
--------------------------------------------------------------------------------
/extras/stnbhwd/AffineGridGeneratorBHWD.lua:
--------------------------------------------------------------------------------
1 | local AGG, parent = torch.class('nn.AffineGridGeneratorBHWD', 'nn.Module')
2 |
3 | --[[
4 | AffineGridGeneratorBHWD(height, width) :
5 | AffineGridGeneratorBHWD:updateOutput(transformMatrix)
6 | AffineGridGeneratorBHWD:updateGradInput(transformMatrix, gradGrids)
7 |
8 | AffineGridGeneratorBHWD will take a 2x3 affine image transform matrix (in homogeneous
9 | coordinates) as input, and output a grid in normalized coordinates* that, once used
10 | with the Bilinear Sampler, will result in an affine transform.
11 |
12 | AffineGridGenerator
13 | - takes (B,2,3)-shaped transform matrices as input (B=batch).
14 | - outputs a grid in BHWD layout, that can be used directly with BilinearSamplerBHWD
15 | - initialization of the previous layer should be biased towards the identity transform:
16 | | 1 0 0 |
17 | | 0 1 0 |
18 |
19 | *: normalized coordinates [-1,1] correspond to the boundaries of the input image.
20 | ]]
21 |
22 | function AGG:__init(height, width)
23 | parent.__init(self)
24 | assert(height > 1)
25 | assert(width > 1)
26 | self.height = height
27 | self.width = width
28 |
29 | self.baseGrid = torch.Tensor(height, width, 3)
30 | for i=1,self.height do
31 | self.baseGrid:select(3,1):select(1,i):fill(-1 + (i-1)/(self.height-1) * 2)
32 | end
33 | for j=1,self.width do
34 | self.baseGrid:select(3,2):select(2,j):fill(-1 + (j-1)/(self.width-1) * 2)
35 | end
36 | self.baseGrid:select(3,3):fill(1)
37 | self.batchGrid = torch.Tensor(1, height, width, 3):copy(self.baseGrid)
38 | end
39 |
40 | local function addOuterDim(t)
41 | local sizes = t:size()
42 | local newsizes = torch.LongStorage(sizes:size()+1)
43 | newsizes[1]=1
44 | for i=1,sizes:size() do
45 | newsizes[i+1]=sizes[i]
46 | end
47 | return t:view(newsizes)
48 | end
49 |
50 | function AGG:updateOutput(_transformMatrix)
51 | local transformMatrix
52 | if _transformMatrix:nDimension()==2 then
53 | transformMatrix = addOuterDim(_transformMatrix)
54 | else
55 | transformMatrix = _transformMatrix
56 | end
57 | assert(transformMatrix:nDimension()==3
58 | and transformMatrix:size(2)==2
59 | and transformMatrix:size(3)==3
60 | , 'please input affine transform matrices (bx2x3)')
61 | local batchsize = transformMatrix:size(1)
62 |
63 | if self.batchGrid:size(1) ~= batchsize then
64 | self.batchGrid:resize(batchsize, self.height, self.width, 3)
65 | for i=1,batchsize do
66 | self.batchGrid:select(1,i):copy(self.baseGrid)
67 | end
68 | end
69 |
70 | self.output:resize(batchsize, self.height, self.width, 2)
71 | local flattenedBatchGrid = self.batchGrid:view(batchsize, self.width*self.height, 3)
72 | local flattenedOutput = self.output:view(batchsize, self.width*self.height, 2)
73 | torch.bmm(flattenedOutput, flattenedBatchGrid, transformMatrix:transpose(2,3))
74 | if _transformMatrix:nDimension()==2 then
75 | self.output = self.output:select(1,1)
76 | end
77 | return self.output
78 | end
79 |
80 | function AGG:updateGradInput(_transformMatrix, _gradGrid)
81 | local transformMatrix, gradGrid
82 | if _transformMatrix:nDimension()==2 then
83 | transformMatrix = addOuterDim(_transformMatrix)
84 | gradGrid = addOuterDim(_gradGrid)
85 | else
86 | transformMatrix = _transformMatrix
87 | gradGrid = _gradGrid
88 | end
89 |
90 | local batchsize = transformMatrix:size(1)
91 | local flattenedGradGrid = gradGrid:view(batchsize, self.width*self.height, 2)
92 | local flattenedBatchGrid = self.batchGrid:view(batchsize, self.width*self.height, 3)
93 | self.gradInput:resizeAs(transformMatrix):zero()
94 | self.gradInput:baddbmm(flattenedGradGrid:transpose(2,3), flattenedBatchGrid)
95 | -- torch.baddbmm doesn't work on cudatensors for some reason
96 |
97 | if _transformMatrix:nDimension()==2 then
98 | self.gradInput = self.gradInput:select(1,1)
99 | end
100 |
101 | return self.gradInput
102 | end
103 |
--------------------------------------------------------------------------------
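A sketch that ties the grid generator to the sampler it was designed for: with an identity affine transform, sampling an image through the generated grid should reproduce the image up to interpolation error (assumes the stnbhwd package has been built and is loadable as `stn`):

```lua
require 'nn'
require 'stn'   -- BHWD spatial-transformer modules

local B, H, W, D = 2, 16, 16, 3
local images = torch.rand(B, H, W, D)   -- BHWD layout

-- identity transform | 1 0 0 ; 0 1 0 | for every sample in the batch
local theta = torch.Tensor{{1,0,0},{0,1,0}}:view(1,2,3):expand(B,2,3):contiguous()

local grids = nn.AffineGridGeneratorBHWD(H, W):forward(theta)    -- BxHxWx2
local out   = nn.BilinearSamplerBHWD():forward({images, grids})  -- BxHxWxD
print((out - images):abs():max())   -- ~0 for the identity transform
```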
/extras/stnbhwd/AffineTransformMatrixGenerator.lua:
--------------------------------------------------------------------------------
1 | local ATMG, parent = torch.class('nn.AffineTransformMatrixGenerator', 'nn.Module')
2 |
3 | --[[
4 | AffineTransformMatrixGenerator(useRotation, useScale, useTranslation) :
5 | AffineTransformMatrixGenerator:updateOutput(transformParams)
6 | AffineTransformMatrixGenerator:updateGradInput(transformParams, gradParams)
7 |
8 | This module can be used in between the localisation network (that outputs the
9 | parameters of the transformation) and the AffineGridGeneratorBHWD (that expects
10 | an affine transform matrix as input).
11 |
12 | The goal is to be able to use only specific transformations or a combination of them.
13 |
14 | If no specific transformation is specified, it uses a fully parametrized
15 | linear transformation and thus expects 6 parameters as input. In this case
16 | the module is equivalent to nn.View(2,3):setNumInputDims(2).
17 |
18 | Any combination of the 3 transformations (rotation, scale and/or translation)
19 | can be used. The transform parameters must be supplied in the following order:
20 | rotation (1 param), scale (1 param) then translation (2 params).
21 |
22 | Example:
23 | AffineTransformMatrixGenerator(true,false,true) expects as input a tensor of
24 | size (B, 3) containing (rotationAngle, translationX, translationY).
25 | ]]
26 |
27 | function ATMG:__init(useRotation, useScale, useTranslation)
28 | parent.__init(self)
29 |
30 | -- if no specific transformation, use fully parametrized version
31 | self.fullMode = not(useRotation or useScale or useTranslation)
32 |
33 | if not self.fullMode then
34 | self.useRotation = useRotation
35 | self.useScale = useScale
36 | self.useTranslation = useTranslation
37 | end
38 | end
39 |
40 | function ATMG:check(input)
41 | if self.fullMode then
42 | assert(input:size(2)==6, 'Expected 6 parameters, got ' .. input:size(2))
43 | else
44 | local numberParameters = 0
45 | if self.useRotation then
46 | numberParameters = numberParameters + 1
47 | end
48 | if self.useScale then
49 | numberParameters = numberParameters + 1
50 | end
51 | if self.useTranslation then
52 | numberParameters = numberParameters + 2
53 | end
54 | assert(input:size(2)==numberParameters, 'Expected '..numberParameters..
55 | ' parameters, got ' .. input:size(2))
56 | end
57 | end
58 |
59 | local function addOuterDim(t)
60 | local sizes = t:size()
61 | local newsizes = torch.LongStorage(sizes:size()+1)
62 | newsizes[1]=1
63 | for i=1,sizes:size() do
64 | newsizes[i+1]=sizes[i]
65 | end
66 | return t:view(newsizes)
67 | end
68 |
69 | function ATMG:updateOutput(_tranformParams)
70 | local transformParams
71 | if _tranformParams:nDimension()==1 then
72 | transformParams = addOuterDim(_tranformParams)
73 | else
74 | transformParams = _tranformParams
75 | end
76 |
77 | self:check(transformParams)
78 | local batchSize = transformParams:size(1)
79 |
80 | if self.fullMode then
81 | self.output = transformParams:view(batchSize, 2, 3)
82 | else
83 | local completeTransformation = torch.zeros(batchSize,3,3):typeAs(transformParams)
84 | completeTransformation:select(3,1):select(2,1):add(1)
85 | completeTransformation:select(3,2):select(2,2):add(1)
86 | completeTransformation:select(3,3):select(2,3):add(1)
87 | local transformationBuffer = torch.Tensor(batchSize,3,3):typeAs(transformParams)
88 |
89 | local paramIndex = 1
90 | if self.useRotation then
91 | local alphas = transformParams:select(2, paramIndex)
92 | paramIndex = paramIndex + 1
93 |
94 | transformationBuffer:zero()
95 | transformationBuffer:select(3,3):select(2,3):add(1)
96 | local cosines = torch.cos(alphas)
97 | local sinuses = torch.sin(alphas)
98 | transformationBuffer:select(3,1):select(2,1):copy(cosines)
99 | transformationBuffer:select(3,2):select(2,2):copy(cosines)
100 | transformationBuffer:select(3,1):select(2,2):copy(sinuses)
101 | transformationBuffer:select(3,2):select(2,1):copy(-sinuses)
102 |
103 | completeTransformation = torch.bmm(completeTransformation, transformationBuffer)
104 | end
105 | self.rotationOutput = completeTransformation:narrow(2,1,2):narrow(3,1,2):clone()
106 |
107 | if self.useScale then
108 | local scaleFactors = transformParams:select(2,paramIndex)
109 | paramIndex = paramIndex + 1
110 |
111 | transformationBuffer:zero()
112 | transformationBuffer:select(3,1):select(2,1):copy(scaleFactors)
113 | transformationBuffer:select(3,2):select(2,2):copy(scaleFactors)
114 | transformationBuffer:select(3,3):select(2,3):add(1)
115 |
116 | completeTransformation = torch.bmm(completeTransformation, transformationBuffer)
117 | end
118 | self.scaleOutput = completeTransformation:narrow(2,1,2):narrow(3,1,2):clone()
119 |
120 | if self.useTranslation then
121 | local txs = transformParams:select(2,paramIndex)
122 | local tys = transformParams:select(2,paramIndex+1)
123 |
124 | transformationBuffer:zero()
125 | transformationBuffer:select(3,1):select(2,1):add(1)
126 | transformationBuffer:select(3,2):select(2,2):add(1)
127 | transformationBuffer:select(3,3):select(2,3):add(1)
128 | transformationBuffer:select(3,3):select(2,1):copy(txs)
129 | transformationBuffer:select(3,3):select(2,2):copy(tys)
130 |
131 | completeTransformation = torch.bmm(completeTransformation, transformationBuffer)
132 | end
133 |
134 | self.output=completeTransformation:narrow(2,1,2)
135 | end
136 |
137 | if _tranformParams:nDimension()==1 then
138 | self.output = self.output:select(1,1)
139 | end
140 | return self.output
141 | end
142 |
143 |
144 | function ATMG:updateGradInput(_tranformParams, _gradParams)
145 | local transformParams, gradParams
146 | if _tranformParams:nDimension()==1 then
147 | transformParams = addOuterDim(_tranformParams)
148 | gradParams = addOuterDim(_gradParams):clone()
149 | else
150 | transformParams = _tranformParams
151 | gradParams = _gradParams:clone()
152 | end
153 |
154 | local batchSize = transformParams:size(1)
155 | if self.fullMode then
156 | self.gradInput = gradParams:view(batchSize, 6)
157 | else
158 | local paramIndex = transformParams:size(2)
159 | self.gradInput:resizeAs(transformParams)
160 | if self.useTranslation then
161 | local gradInputTranslationParams = self.gradInput:narrow(2,paramIndex-1,2)
162 | local tParams = torch.Tensor(batchSize, 1, 2):typeAs(transformParams)
163 | tParams:select(3,1):copy(transformParams:select(2,paramIndex-1))
164 | tParams:select(3,2):copy(transformParams:select(2,paramIndex))
165 | paramIndex = paramIndex-2
166 |
167 | local selectedOutput = self.scaleOutput
168 | local selectedGradParams = gradParams:narrow(2,1,2):narrow(3,3,1):transpose(2,3)
169 | gradInputTranslationParams:copy(torch.bmm(selectedGradParams, selectedOutput))
170 |
171 | local gradientCorrection = torch.bmm(selectedGradParams:transpose(2,3), tParams)
172 | gradParams:narrow(3,1,2):narrow(2,1,2):add(1,gradientCorrection)
173 | end
174 |
175 | if self.useScale then
176 | local gradInputScaleparams = self.gradInput:narrow(2,paramIndex,1)
177 | local sParams = transformParams:select(2,paramIndex)
178 | paramIndex = paramIndex-1
179 |
180 | local selectedOutput = self.rotationOutput
181 | local selectedGradParams = gradParams:narrow(2,1,2):narrow(3,1,2)
182 | gradInputScaleparams:copy(torch.cmul(selectedOutput, selectedGradParams):sum(2):sum(3))
183 |
184 | gradParams:select(3,1):select(2,1):cmul(sParams)
185 | gradParams:select(3,2):select(2,1):cmul(sParams)
186 | gradParams:select(3,1):select(2,2):cmul(sParams)
187 | gradParams:select(3,2):select(2,2):cmul(sParams)
188 | end
189 |
190 | if self.useRotation then
191 | local gradInputRotationParams = self.gradInput:narrow(2,paramIndex,1)
192 | local rParams = transformParams:select(2,paramIndex)
193 |
194 | local rotationDerivative = torch.zeros(batchSize, 2, 2):typeAs(rParams)
195 | torch.sin(rotationDerivative:select(3,1):select(2,1),-rParams)
196 | torch.sin(rotationDerivative:select(3,2):select(2,2),-rParams)
197 | torch.cos(rotationDerivative:select(3,1):select(2,2),rParams)
198 | torch.cos(rotationDerivative:select(3,2):select(2,1),rParams):mul(-1)
199 | local selectedGradParams = gradParams:narrow(2,1,2):narrow(3,1,2)
200 | gradInputRotationParams:copy(torch.cmul(rotationDerivative,selectedGradParams):sum(2):sum(3))
201 | end
202 | end
203 |
204 | if _tranformParams:nDimension()==1 then
205 | self.gradInput = self.gradInput:select(1,1)
206 | end
207 | return self.gradInput
208 | end
209 |
210 |
211 |
--------------------------------------------------------------------------------
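A quick sketch of the module above in its constrained mode (rotation plus translation, hence three parameters per sample); this only exercises the interface documented in the header comment:

``` lua
require 'stn'

local gen = nn.AffineTransformMatrixGenerator(true, false, true)
-- parameter order: rotation angle (1 param), then translation (2 params)
local params = torch.Tensor{math.pi / 2, 0.1, -0.2}
print(gen:forward(params))   -- the composed 2x3 affine matrix for this sample
```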
/extras/stnbhwd/BilinearSamplerBHWD.lua:
--------------------------------------------------------------------------------
1 | local BilinearSamplerBHWD, parent = torch.class('nn.BilinearSamplerBHWD', 'nn.Module')
2 |
3 | --[[
4 | BilinearSamplerBHWD() :
5 | BilinearSamplerBHWD:updateOutput({inputImages, grids})
6 | BilinearSamplerBHWD:updateGradInput({inputImages, grids}, gradOutput)
7 |
8 | BilinearSamplerBHWD will perform bilinear sampling of the input images according to the
9 | normalized coordinates provided in the grid. Output will be of same size as the grids,
10 | with as many features as the input images.
11 |
12 | - inputImages has to be in BHWD layout
13 |
14 | - grids have to be in BHWD layout, with dim(D)=2
15 | - grids contains, for each sample (first dim), the normalized coordinates of the output wrt the input sample
16 | - first coordinate is Y coordinate, second is X
17 | - normalized coordinates : (-1,-1) points to top left, (-1,1) points to top right
18 | - if the normalized coordinates fall outside of the image, then output will be filled with zeros
19 | ]]
20 |
21 | function BilinearSamplerBHWD:__init()
22 | parent.__init(self)
23 | self.gradInput={}
24 | end
25 |
26 | function BilinearSamplerBHWD:check(input, gradOutput)
27 | local inputImages = input[1]
28 | local grids = input[2]
29 |
30 | assert(inputImages:isContiguous(), 'Input images have to be contiguous')
31 | assert(inputImages:nDimension()==4)
32 | assert(grids:nDimension()==4)
33 | assert(inputImages:size(1)==grids:size(1)) -- batch
34 | assert(grids:size(4)==2) -- coordinates
35 |
36 | if gradOutput then
37 | assert(grids:size(1)==gradOutput:size(1))
38 | assert(grids:size(2)==gradOutput:size(2))
39 | assert(grids:size(3)==gradOutput:size(3))
40 | end
41 | end
42 |
43 | local function addOuterDim(t)
44 | local sizes = t:size()
45 | local newsizes = torch.LongStorage(sizes:size()+1)
46 | newsizes[1]=1
47 | for i=1,sizes:size() do
48 | newsizes[i+1]=sizes[i]
49 | end
50 | return t:view(newsizes)
51 | end
52 |
53 | function BilinearSamplerBHWD:updateOutput(input)
54 | local _inputImages = input[1]
55 | local _grids = input[2]
56 |
57 | local inputImages, grids
58 | if _inputImages:nDimension()==3 then
59 | inputImages = addOuterDim(_inputImages)
60 | grids = addOuterDim(_grids)
61 | else
62 | inputImages = _inputImages
63 | grids = _grids
64 | end
65 |
66 | local input = {inputImages, grids}
67 |
68 | self:check(input)
69 |
70 | self.output:resize(inputImages:size(1), grids:size(2), grids:size(3), inputImages:size(4))
71 |
72 | inputImages.nn.BilinearSamplerBHWD_updateOutput(self, inputImages, grids, self.output)
73 |
74 | if _inputImages:nDimension()==3 then
75 | self.output=self.output:select(1,1)
76 | end
77 |
78 | return self.output
79 | end
80 |
81 | function BilinearSamplerBHWD:updateGradInput(_input, _gradOutput)
82 | local _inputImages = _input[1]
83 | local _grids = _input[2]
84 |
85 | local inputImages, grids, gradOutput
86 | if _inputImages:nDimension()==3 then
87 | inputImages = addOuterDim(_inputImages)
88 | grids = addOuterDim(_grids)
89 | gradOutput = addOuterDim(_gradOutput)
90 | else
91 | inputImages = _inputImages
92 | grids = _grids
93 | gradOutput = _gradOutput
94 | end
95 |
96 | local input = {inputImages, grids}
97 |
98 | self:check(input, gradOutput)
99 | for i=1,#input do
100 | self.gradInput[i] = self.gradInput[i] or input[1].new()
101 | self.gradInput[i]:resizeAs(input[i]):zero()
102 | end
103 |
104 | local gradInputImages = self.gradInput[1]
105 | local gradGrids = self.gradInput[2]
106 |
107 | inputImages.nn.BilinearSamplerBHWD_updateGradInput(self, inputImages, grids, gradInputImages, gradGrids, gradOutput)
108 |
109 | if _gradOutput:nDimension()==3 then
110 | self.gradInput[1]=self.gradInput[1]:select(1,1)
111 | self.gradInput[2]=self.gradInput[2]:select(1,1)
112 | end
113 |
114 | return self.gradInput
115 | end
116 |
--------------------------------------------------------------------------------
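The module defers the interpolation itself to the C/CUDA kernels, but the per-pixel arithmetic is compact enough to transcribe; the sketch below mirrors the weight computation in generic/BilinearSamplerBHWD.c (zero padding outside the image) and is for illustration only:

``` lua
require 'torch'

-- img: H x W tensor; yf, xf: normalized coordinates in [-1, 1]
local function sampleBilinear(img, yf, xf)
  local H, W = img:size(1), img:size(2)
  local x = (xf + 1) * (W - 1) / 2            -- unnormalize to pixel space
  local y = (yf + 1) * (H - 1) / 2
  local x0, y0 = math.floor(x), math.floor(y)
  local wx, wy = 1 - (x - x0), 1 - (y - y0)   -- weights of the top-left corner
  local function at(i, j)                     -- zero outside the image bounds
    if i < 0 or i > H - 1 or j < 0 or j > W - 1 then return 0 end
    return img[i + 1][j + 1]                  -- torch tensors are 1-indexed
  end
  return wy * wx * at(y0, x0) + wy * (1 - wx) * at(y0, x0 + 1)
       + (1 - wy) * wx * at(y0 + 1, x0) + (1 - wy) * (1 - wx) * at(y0 + 1, x0 + 1)
end

print(sampleBilinear(torch.eye(3), 0, 0))     -- center of a 3x3 identity: 1
```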
/extras/stnbhwd/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR)
2 | CMAKE_POLICY(VERSION 2.8)
3 |
4 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_FORCE_INLINES")
5 | SET(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake" "${CMAKE_MODULE_PATH}")
6 |
7 | FIND_PACKAGE(Torch REQUIRED)
8 |
9 | # Flags
10 | # When using MSVC
11 | IF(MSVC)
12 | # we want to respect the standard, and we are bored of those **** .
13 | ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1)
14 | ENDIF(MSVC)
15 |
16 | # OpenMP support?
17 | SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?")
18 | IF (APPLE AND CMAKE_COMPILER_IS_GNUCC)
19 | EXEC_PROGRAM (uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
20 | STRING (REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION})
21 | MESSAGE (STATUS "MAC OS Darwin Version: ${DARWIN_VERSION}")
22 | IF (DARWIN_VERSION GREATER 9)
23 | SET(APPLE_OPENMP_SUCKS 1)
24 | ENDIF (DARWIN_VERSION GREATER 9)
25 | EXECUTE_PROCESS (COMMAND ${CMAKE_C_COMPILER} -dumpversion
26 | OUTPUT_VARIABLE GCC_VERSION)
27 | IF (APPLE_OPENMP_SUCKS AND GCC_VERSION VERSION_LESS 4.6.2)
28 | MESSAGE(STATUS "Warning: Disabling OpenMP (unstable with this version of GCC)")
29 | MESSAGE(STATUS " Install GCC >= 4.6.2 or change your OS to enable OpenMP")
30 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unknown-pragmas")
31 | SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?" FORCE)
32 | ENDIF ()
33 | ENDIF ()
34 |
35 | IF (WITH_OPENMP)
36 | FIND_PACKAGE(OpenMP)
37 | IF(OPENMP_FOUND)
38 | MESSAGE(STATUS "Compiling with OpenMP support")
39 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
40 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
41 | SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
42 | ENDIF(OPENMP_FOUND)
43 | ENDIF (WITH_OPENMP)
44 |
45 | LINK_DIRECTORIES("${Torch_INSTALL_LIB}")
46 |
47 | SET(src init.c)
48 | FILE(GLOB luasrc *.lua)
49 | ADD_TORCH_PACKAGE(stn "${src}" "${luasrc}")
50 | TARGET_LINK_LIBRARIES(stn luaT TH)
51 |
52 |
53 | FIND_PACKAGE(CUDA 5.5)
54 |
55 | IF (CUDA_FOUND)
56 | LIST(APPEND CUDA_NVCC_FLAGS "-arch=sm_20")
57 | LIST(APPEND CUDA_NVCC_FLAGS "-Xcompiler -std=c++98")
58 |
59 | INCLUDE_DIRECTORIES("${Torch_INSTALL_INCLUDE}/THC")
60 | SET(src-cuda init.cu)
61 | CUDA_ADD_LIBRARY(custn MODULE ${src-cuda})
62 | TARGET_LINK_LIBRARIES(custn luaT THC TH)
63 | IF(APPLE)
64 | SET_TARGET_PROPERTIES(custn PROPERTIES
65 | LINK_FLAGS "-undefined dynamic_lookup")
66 | ENDIF()
67 | ### Torch packages supposes libraries prefix is "lib"
68 | SET_TARGET_PROPERTIES(custn PROPERTIES
69 | PREFIX "lib"
70 | IMPORT_PREFIX "lib")
71 |
72 | INSTALL(TARGETS custn
73 | RUNTIME DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}"
74 | LIBRARY DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}")
75 | ENDIF(CUDA_FOUND)
76 |
--------------------------------------------------------------------------------
/extras/stnbhwd/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2015 qassemoquab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/extras/stnbhwd/README.md:
--------------------------------------------------------------------------------
1 | # stnbhwd
2 |
3 | ## Main modules
4 |
5 | These are the basic modules (BHWD layout) needed to implement a Spatial Transformer Network (Jaderberg et al.) http://arxiv.org/abs/1506.02025
6 |
7 | ``` lua
8 | require 'stn'
9 |
10 | nn.AffineGridGeneratorBHWD(height, width)
11 | -- takes B x 2 x 3 affine transform matrices as input,
12 | -- outputs a height x width grid in normalized [-1,1] coordinates
13 | -- output layout is B,H,W,2 where the first coordinate in the 4th dimension is y, and the second is x
14 |
15 | nn.BilinearSamplerBHWD()
16 | -- takes a table {inputImages, grids} as inputs
17 | -- outputs the interpolated images according to the grids
18 | -- inputImages is a batch of samples in BHWD layout
19 | -- grids is a batch of grids (output of AffineGridGeneratorBHWD)
20 | -- output is also BHWD
21 | ```
22 |
23 | ## Advanced module
24 |
25 | This module allows the user to put a constraint on the possible transformations.
26 | It should be placed between the localisation network and the grid generator.
27 |
28 | ``` lua
29 | require 'stn'
30 |
31 | nn.AffineTransformMatrixGenerator(useRotation, useScale, useTranslation)
32 | -- takes a B x nbParams tensor as inputs
33 | -- nbParams depends on the constrained transformation
34 | -- The parameters for the selected transformation(s) should be supplied in the
35 | -- following order: rotationAngle, scaleFactor, translationX, translationY
36 | -- If no transformation is specified, it generates a generic affine transformation (nbParams = 6)
37 | -- outputs B x 2 x 3 affine transform matrices
38 | ```
39 |
40 |
41 | If this code is useful to your research, please cite this repository.
42 |
--------------------------------------------------------------------------------
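Putting the README's modules together: a sketch (not shipped code) of a constrained transform head that turns localisation-network outputs into sampling grids for the bilinear sampler.

``` lua
require 'stn'

-- rotation + scale + translation: 4 parameters per sample
local head = nn.Sequential()
head:add(nn.AffineTransformMatrixGenerator(true, true, true))
head:add(nn.AffineGridGeneratorBHWD(32, 32))

local params = torch.Tensor(8, 4):zero()
params:select(2, 2):fill(1)          -- scale 1, zero angle/translation: identity
local grids = head:forward(params)   -- 8 x 32 x 32 x 2
-- then: nn.BilinearSamplerBHWD():forward({imagesBHWD, grids})
```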
/extras/stnbhwd/ScaleBHWD.lua:
--------------------------------------------------------------------------------
1 | local ScaleBHWD, parent = torch.class('nn.ScaleBHWD', 'nn.Module')
2 |
3 | --[[
4 | ScaleBHWD() :
5 | ScaleBHWD:updateOutput({inputImages, grids})
6 | ScaleBHWD:updateGradInput({inputImages, grids}, gradOutput)
7 |
8 | ScaleBHWD will perform bilinear sampling of the input images according to the
9 | normalized coordinates provided in the grid. Output will be of same size as the grids,
10 | with as many features as the input images.
11 |
12 | - inputImages has to be in BHWD layout
13 |
14 | - grids have to be in BHWD layout, with dim(D)=2
15 | - grids contains, for each sample (first dim), the normalized coordinates of the output wrt the input sample
16 | - first coordinate is Y coordinate, second is X
17 | - normalized coordinates : (-1,-1) points to top left, (-1,1) points to top right
18 | - if the normalized coordinates fall outside of the image, then output will be filled with zeros
19 | ]]
20 |
21 | function ScaleBHWD:__init()
22 | parent.__init(self)
23 | self.gradInput={}
24 | end
25 |
26 | function ScaleBHWD:check(input, gradOutput)
27 | local inputImages = input[1]
28 | local grids = input[2]
29 |
30 | assert(inputImages:isContiguous(), 'Input images have to be contiguous')
31 | assert(inputImages:nDimension()==4)
32 | assert(grids:nDimension()==4)
33 | assert(inputImages:size(1)==grids:size(1)) -- batch
34 | assert(grids:size(4)==2) -- coordinates
35 |
36 | if gradOutput then
37 | assert(grids:size(1)==gradOutput:size(1))
38 | assert(grids:size(2)==gradOutput:size(2))
39 | assert(grids:size(3)==gradOutput:size(3))
40 | end
41 | end
42 |
43 | local function addOuterDim(t)
44 | local sizes = t:size()
45 | local newsizes = torch.LongStorage(sizes:size()+1)
46 | newsizes[1]=1
47 | for i=1,sizes:size() do
48 | newsizes[i+1]=sizes[i]
49 | end
50 | return t:view(newsizes)
51 | end
52 |
53 | function ScaleBHWD:updateOutput(input)
54 | local _inputImages = input[1]
55 | local _grids = input[2]
56 |
57 | local inputImages, grids
58 | if _inputImages:nDimension()==3 then
59 | inputImages = addOuterDim(_inputImages)
60 | grids = addOuterDim(_grids)
61 | else
62 | inputImages = _inputImages
63 | grids = _grids
64 | end
65 |
66 | local input = {inputImages, grids}
67 |
68 | self:check(input)
69 |
70 | self.output:resize(inputImages:size(1), grids:size(2), grids:size(3), inputImages:size(4))
71 |
72 | inputImages.nn.ScaleBHWD_updateOutput(self, inputImages, grids, self.output)
73 |
74 | if _inputImages:nDimension()==3 then
75 | self.output=self.output:select(1,1)
76 | end
77 |
78 | return self.output
79 | end
80 |
81 | function ScaleBHWD:updateGradInput(_input, _gradOutput)
82 | local _inputImages = _input[1]
83 | local _grids = _input[2]
84 |
85 | local inputImages, grids, gradOutput
86 | if _inputImages:nDimension()==3 then
87 | inputImages = addOuterDim(_inputImages)
88 | grids = addOuterDim(_grids)
89 | gradOutput = addOuterDim(_gradOutput)
90 | else
91 | inputImages = _inputImages
92 | grids = _grids
93 | gradOutput = _gradOutput
94 | end
95 |
96 | local input = {inputImages, grids}
97 |
98 | self:check(input, gradOutput)
99 | for i=1,#input do
100 | self.gradInput[i] = self.gradInput[i] or input[1].new()
101 | self.gradInput[i]:resizeAs(input[i]):zero()
102 | end
103 |
104 | local gradInputImages = self.gradInput[1]
105 | local gradGrids = self.gradInput[2]
106 |
107 | inputImages.nn.ScaleBHWD_updateGradInput(self, inputImages, grids, gradInputImages, gradGrids, gradOutput)
108 |
109 | if _gradOutput:nDimension()==3 then
110 | self.gradInput[1]=self.gradInput[1]:select(1,1)
111 | self.gradInput[2]=self.gradInput[2]:select(1,1)
112 | end
113 |
114 | return self.gradInput
115 | end
116 |
--------------------------------------------------------------------------------
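Since ScaleBHWD exposes the same {inputImages, grids} interface as the bilinear sampler, rescaling a BHWD image amounts to sampling it with an identity-transform grid at the target resolution; a hedged sketch (assuming the stn package exports ScaleBHWD alongside the other modules):

``` lua
require 'stn'

local img = torch.rand(1, 64, 64, 3)        -- one 64x64 RGB image in BHWD layout
local identity = torch.Tensor{{{1, 0, 0},
                               {0, 1, 0}}}  -- batch of one 2x3 identity transform
local grid = nn.AffineGridGeneratorBHWD(32, 32):forward(identity)
local small = nn.ScaleBHWD():forward({img, grid})   -- 1 x 32 x 32 x 3
```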
/extras/stnbhwd/demo/Optim.lua:
--------------------------------------------------------------------------------
1 | --[[ That would be the license for Optim.lua
2 |
3 | BSD License
4 |
5 | For fbcunn software
6 |
7 | Copyright (c) 2014, Facebook, Inc. All rights reserved.
8 |
9 | Redistribution and use in source and binary forms, with or without modification,
10 | are permitted provided that the following conditions are met:
11 |
12 | * Redistributions of source code must retain the above copyright notice, this
13 | list of conditions and the following disclaimer.
14 |
15 | * Redistributions in binary form must reproduce the above copyright notice,
16 | this list of conditions and the following disclaimer in the documentation
17 | and/or other materials provided with the distribution.
18 |
19 | * Neither the name Facebook nor the names of its contributors may be used to
20 | endorse or promote products derived from this software without specific
21 | prior written permission.
22 |
23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
27 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 | ]]
34 |
35 | -- Copyright 2004-present Facebook. All Rights Reserved.
36 |
37 | local pl = require('pl.import_into')()
38 |
39 | -- from fblualib/fb/util/data.lua, copied here because fblualib is not rockspec-ready yet.
40 | -- deepcopy routine that assumes that a 'clone' method, when present on
41 | -- userdata, should be used to deep-copy it. This matches the behavior of
42 | -- Torch tensors.
43 | local function deepcopy(x)
44 | local typename = type(x)
45 | if typename == "userdata" then
46 | return x:clone()
47 | end
48 | if typename == "table" then
49 | local retval = { }
50 | for k,v in pairs(x) do
51 | retval[deepcopy(k)] = deepcopy(v)
52 | end
53 | return retval
54 | end
55 | return x
56 | end
57 |
58 | local Optim, parent = torch.class('nn.Optim')
59 |
60 |
61 | -- Returns the weight and bias parameters, with their associated grad parameters,
62 | -- for this module. Annotates each returned pair with a flag marking whether it
63 | -- is the bias parameter set.
64 | function Optim.weight_bias_parameters(module)
65 | local weight_params, bias_params
66 | if module.weight then
67 | weight_params = {module.weight, module.gradWeight}
68 | weight_params.is_bias = false
69 | end
70 | if module.bias then
71 | bias_params = {module.bias, module.gradBias}
72 | bias_params.is_bias = true
73 | end
74 | return {weight_params, bias_params}
75 | end
76 |
77 | -- The regular `optim` package relies on `getParameters`, which is a
78 | -- beastly abomination before all. This `optim` package uses separate
79 | -- optim state for each submodule of a `nn.Module`.
80 | function Optim:__init(model, optState, checkpoint_data)
81 | assert(model)
82 | assert(checkpoint_data or optState)
83 | assert(not (checkpoint_data and optState))
84 |
85 | self.model = model
86 | self.modulesToOptState = {}
87 | -- Keep this around so we update it in setParameters
88 | self.originalOptState = optState
89 |
90 | -- Each module has some set of parameters and grad parameters. Since
91 | -- they may be allocated discontinuously, we need separate optState for
92 | -- each parameter tensor. self.modulesToOptState maps each module to
93 | -- a lua table of optState clones.
94 | if not checkpoint_data then
95 | self.model:apply(function(module)
96 | self.modulesToOptState[module] = { }
97 | local params = self.weight_bias_parameters(module)
98 | -- expects either an empty table or 2 element table, one for weights
99 | -- and one for biases
100 | assert(pl.tablex.size(params) == 0 or pl.tablex.size(params) == 2)
101 | for i, _ in ipairs(params) do
102 | self.modulesToOptState[module][i] = deepcopy(optState)
103 | if params[i] and params[i].is_bias then
104 | -- never regularize biases
105 | self.modulesToOptState[module][i].weightDecay = 0.0
106 | end
107 | end
108 | assert(module)
109 | assert(self.modulesToOptState[module])
110 | end)
111 | else
112 | local state = checkpoint_data.optim_state
113 | local modules = {}
114 | self.model:apply(function(m) table.insert(modules, m) end)
115 | assert(pl.tablex.compare_no_order(modules, pl.tablex.keys(state)))
116 | self.modulesToOptState = state
117 | end
118 | end
119 |
120 | function Optim:save()
121 | return {
122 | optim_state = self.modulesToOptState
123 | }
124 | end
125 |
126 | local function _type_all(obj, t)
127 | for k, v in pairs(obj) do
128 | if type(v) == 'table' then
129 | _type_all(v, t)
130 | else
131 | local tn = torch.typename(v)
132 | if tn and tn:find('torch%..+Tensor') then
133 | obj[k] = v:type(t)
134 | end
135 | end
136 | end
137 | end
138 |
139 | function Optim:type(t)
140 | self.model:apply(function(module)
141 | local state= self.modulesToOptState[module]
142 | assert(state)
143 | _type_all(state, t)
144 | end)
145 | end
146 |
147 | local function get_device_for_module(mod)
148 | local dev_id = nil
149 | for name, val in pairs(mod) do
150 | if torch.typename(val) == 'torch.CudaTensor' then
151 | local this_dev = val:getDevice()
152 | if this_dev ~= 0 then
153 | -- make sure the tensors are allocated consistently
154 | assert(dev_id == nil or dev_id == this_dev)
155 | dev_id = this_dev
156 | end
157 | end
158 | end
159 | return dev_id -- may still be nil if no CUDA tensors are allocated.
160 | end
161 |
162 | local function on_device_for_module(mod, f)
163 | local this_dev = get_device_for_module(mod)
164 | if this_dev ~= nil then
165 | return cutorch.withDevice(this_dev, f)
166 | end
167 | return f()
168 | end
169 |
170 | function Optim:optimize(optimMethod, inputs, targets, criterion)
171 | assert(optimMethod)
172 | assert(inputs)
173 | assert(targets)
174 | assert(criterion)
175 | assert(self.modulesToOptState)
176 |
177 | self.model:zeroGradParameters()
178 | local output = self.model:forward(inputs)
179 |
180 | local err = criterion:forward(output, targets)
181 |
182 | local df_do = criterion:backward(output, targets)
183 | self.model:backward(inputs, df_do)
184 |
185 | -- We'll set these in the loop that iterates over each module. Get them
186 | -- out here to be captured.
187 | local curGrad
188 | local curParam
189 | local function fEvalMod(x)
190 | return err, curGrad
191 | end
192 |
193 | for curMod, opt in pairs(self.modulesToOptState) do
194 | on_device_for_module(curMod, function()
195 | local curModParams = self.weight_bias_parameters(curMod)
196 | -- expects either an empty table or 2 element table, one for weights
197 | -- and one for biases
198 | assert(pl.tablex.size(curModParams) == 0 or
199 | pl.tablex.size(curModParams) == 2)
200 | if curModParams then
201 | for i, tensor in ipairs(curModParams) do
202 | if curModParams[i] then
203 | -- expect param, gradParam pair
204 | curParam, curGrad = table.unpack(curModParams[i])
205 | assert(curParam and curGrad)
206 | optimMethod(fEvalMod, curParam, opt[i])
207 | end
208 | end
209 | end
210 | end)
211 | end
212 |
213 | return err, output
214 | end
215 |
216 | function Optim:optimizeFromGradients(optimMethod, inputs, gradients)
217 | assert(optimMethod)
218 | assert(inputs)
219 | assert(gradients)
220 | assert(self.modulesToOptState)
221 |
222 | self.model:zeroGradParameters()
223 | self.model:backward(inputs, gradients)
224 |
225 | -- We'll set these in the loop that iterates over each module. Get them
226 | -- out here to be captured.
227 | local curGrad
228 | local curParam
229 | local function fEvalMod(x)
230 | return 0, curGrad
231 | end
232 |
233 | for curMod, opt in pairs(self.modulesToOptState) do
234 | on_device_for_module(curMod, function()
235 | local curModParams = self.weight_bias_parameters(curMod)
236 | -- expects either an empty table or 2 element table, one for weights
237 | -- and one for biases
238 | assert(pl.tablex.size(curModParams) == 0 or
239 | pl.tablex.size(curModParams) == 2)
240 | if curModParams then
241 | for i, tensor in ipairs(curModParams) do
242 | if curModParams[i] then
243 | -- expect param, gradParam pair
244 | curParam, curGrad = table.unpack(curModParams[i])
245 | assert(curParam and curGrad)
246 | optimMethod(fEvalMod, curParam, opt[i])
247 | end
248 | end
249 | end
250 | end)
251 | end
252 |
253 | -- nothing to return: unlike optimize(), this path evaluates no criterion
254 | end
255 |
256 | function Optim:setParameters(newParams)
257 | assert(newParams)
258 | assert(type(newParams) == 'table')
259 | local function splice(dest, src)
260 | for k,v in pairs(src) do
261 | dest[k] = v
262 | end
263 | end
264 |
265 | splice(self.originalOptState, newParams)
266 | for _,optStates in pairs(self.modulesToOptState) do
267 | for i,optState in pairs(optStates) do
268 | assert(type(optState) == 'table')
269 | splice(optState, newParams)
270 | end
271 | end
272 | end
--------------------------------------------------------------------------------
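Because Optim:setParameters splices the new values into every per-module optState clone, a learning-rate schedule can sit entirely on top of this class; a sketch of a hypothetical annealing step, using the optimizer/optimState names from demo_mnist.lua below:

``` lua
-- halve the learning rate every 10 epochs (sketch; hypothetical schedule)
if epoch % 10 == 0 then
  optimState.learningRate = optimState.learningRate * 0.5
  optimizer:setParameters{learningRate = optimState.learningRate}
end
```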
/extras/stnbhwd/demo/README.md:
--------------------------------------------------------------------------------
1 | # stnbhwd demo
2 |
3 | Download MNIST and untar in the demo folder, then run with qlua (for image.display):
4 |
5 | ```
6 | wget 'http://torch7.s3-website-us-east-1.amazonaws.com/data/mnist.t7.tgz'
7 | tar -xf mnist.t7.tgz
8 | qlua -ide demo_mnist.lua
9 | ```
10 |
11 | Images should appear after 5 epochs and show what the STN does on a test batch.
12 | You can edit demo_mnist.lua and set use_stn = false to compare accuracy.
13 |
14 | You will need to work with the getParamsByDevice branch of the 'nn' package (required for nn.Optim).
15 |
--------------------------------------------------------------------------------
/extras/stnbhwd/demo/demo_mnist.lua:
--------------------------------------------------------------------------------
1 | -- wget 'http://torch7.s3-website-us-east-1.amazonaws.com/data/mnist.t7.tgz'
2 | -- tar -xf mnist.t7.tgz
3 |
4 | require 'cunn'
5 | require 'cudnn'
6 | require 'image'
7 | require 'optim'
8 | paths.dofile('Optim.lua')
9 |
10 | use_stn = true
11 |
12 | -- distorted mnist dataset
13 | paths.dofile('distort_mnist.lua')
14 | datasetTrain, datasetVal = createDatasetsDistorted()
15 |
16 | -- model
17 | model = nn.Sequential()
18 | model:add(nn.View(32*32))
19 | model:add(nn.Linear(32*32, 128))
20 | model:add(cudnn.ReLU(true))
21 | model:add(nn.Linear(128, 128))
22 | model:add(cudnn.ReLU(true))
23 | model:add(nn.Linear(128, 10))
24 | model:add(nn.LogSoftMax())
25 |
26 | if use_stn then
27 | require 'stn'
28 | paths.dofile('spatial_transformer.lua')
29 | model:insert(spanet,1)
30 | end
31 |
32 | model:cuda()
33 | criterion = nn.ClassNLLCriterion():cuda()
34 |
35 | optimState = {learningRate = 0.01, momentum = 0.9, weightDecay = 5e-4}
36 | optimizer = nn.Optim(model, optimState)
37 |
38 | local w1,w2
39 |
40 | for epoch=1,30 do
41 | model:training()
42 | local trainError = 0
43 | for batchidx = 1, datasetTrain:getNumBatches() do
44 | local inputs, labels = datasetTrain:getBatch(batchidx)
45 | err = optimizer:optimize(optim.sgd, inputs:cuda(), labels:cuda(), criterion)
46 | --print('epoch : ', epoch, 'batch : ', batchidx, 'train error : ', err)
47 | trainError = trainError + err
48 | end
49 | print('epoch : ', epoch, 'trainError : ', trainError / datasetTrain:getNumBatches())
50 |
51 | model:evaluate()
52 | local valError = 0
53 | local correct = 0
54 | local all = 0
55 | for batchidx = 1, datasetVal:getNumBatches() do
56 | local inputs, labels = datasetVal:getBatch(batchidx)
57 | local pred = model:forward(inputs:cuda())
58 | valError = valError + criterion:forward(pred, labels:cuda())
59 | _, preds = pred:max(2)
60 | correct = correct + preds:eq(labels:cuda()):sum()
61 | all = all + preds:size(1)
62 | end
63 | print('validation error : ', valError / datasetVal:getNumBatches())
64 | print('accuracy % : ', correct / all * 100)
65 | print('')
66 |
67 | if use_stn then
68 | w1=image.display({image=spanet.output, nrow=16, legend='STN-transformed inputs, epoch : '..epoch, win=w1})
69 | w2=image.display({image=tranet:get(1).output, nrow=16, legend='Inputs, epoch : '..epoch, win=w2})
70 | end
71 |
72 | end
73 |
74 |
--------------------------------------------------------------------------------
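Optim:save() keys the optimizer state by the module objects themselves, and the restoring constructor asserts that those keys match the model's modules, so a checkpoint has to serialize model and state in one file; a hedged sketch:

``` lua
-- save model and optimizer state together so module identities survive the round trip
torch.save('demo_checkpoint.t7', {model = model, optim_state = optimizer:save()})

-- later: rebuild the optimizer from the checkpoint (the optState argument must be nil)
local ckpt = torch.load('demo_checkpoint.t7')
local restored = nn.Optim(ckpt.model, nil, ckpt.optim_state)
```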
/extras/stnbhwd/demo/distort_mnist.lua:
--------------------------------------------------------------------------------
1 | -- wget 'http://torch7.s3-website-us-east-1.amazonaws.com/data/mnist.t7.tgz'
2 | -- tar -xf mnist.t7.tgz
3 |
4 | function distortData(foo)
5 | local res=torch.FloatTensor(foo:size(1), 1, 42, 42):fill(0)
6 | for i=1,foo:size(1) do
7 | baseImg=foo:select(1,i)
8 | distImg=res:select(1,i)
9 |
10 | r = image.rotate(baseImg, torch.uniform(-3.14/4,3.14/4))
11 | scale = torch.uniform(0.7,1.2)
12 | sz = torch.floor(scale*32)
13 | s = image.scale(r, sz, sz)
14 | rest = 42-sz
15 | offsetx = torch.random(1, 1+rest)
16 | offsety = torch.random(1, 1+rest)
17 |
18 | distImg:narrow(2, offsety, sz):narrow(3,offsetx, sz):copy(s)
19 | end
20 | return res
21 | end
22 |
23 | function distortData32(foo)
24 | local res=torch.FloatTensor(foo:size(1), 1, 32, 32):fill(0)
25 | local distImg=torch.FloatTensor(1, 42, 42):fill(0)
26 | for i=1,foo:size(1) do
27 | baseImg=foo:select(1,i)
28 |
29 | r = image.rotate(baseImg, torch.uniform(-3.14/4,3.14/4))
30 | scale = torch.uniform(0.7,1.2)
31 | sz = torch.floor(scale*32)
32 | s = image.scale(r, sz, sz)
33 | rest = 42-sz
34 | offsetx = torch.random(1, 1+rest)
35 | offsety = torch.random(1, 1+rest)
36 |
37 | distImg:zero()
38 | distImg:narrow(2, offsety, sz):narrow(3,offsetx, sz):copy(s)
39 | res:select(1,i):copy(image.scale(distImg,32,32))
40 | end
41 | return res
42 | end
43 |
44 | function createDatasetsDistorted()
45 | local testFileName = 'mnist.t7/test_32x32.t7'
46 | local trainFileName = 'mnist.t7/train_32x32.t7'
47 | local train = torch.load(trainFileName, 'ascii')
48 | local test = torch.load(testFileName, 'ascii')
49 | train.data = train.data:float()
50 | train.labels = train.labels:float()
51 | test.data = test.data:float()
52 | test.labels = test.labels:float()
53 |
54 | -- distortion
55 | train.data = distortData32(train.data)
56 | test.data = distortData32(test.data)
57 |
58 | local mean = train.data:mean()
59 | local std = train.data:std()
60 | train.data:add(-mean):div(std)
61 | test.data:add(-mean):div(std)
62 |
63 | local batchSize = 256
64 |
65 | local datasetTrain = {
66 | getBatch = function(self, idx)
67 | local data = train.data:narrow(1, (idx - 1) * batchSize + 1, batchSize)
68 | local labels = train.labels:narrow(1, (idx - 1) * batchSize + 1, batchSize)
69 | return data, labels, batchSize
70 | end,
71 | getNumBatches = function()
72 | return torch.floor(60000 / batchSize)
73 | end
74 | }
75 |
76 | local datasetVal = {
77 | getBatch = function(self, idx)
78 | local data = test.data:narrow(1, (idx - 1) * batchSize + 1, batchSize)
79 | local labels = test.labels:narrow(1, (idx - 1) * batchSize + 1, batchSize)
80 | return data, labels, batchSize
81 | end,
82 | getNumBatches = function()
83 | return torch.floor(10000 / batchSize)
84 | end
85 | }
86 |
87 | return datasetTrain, datasetVal
88 | end
--------------------------------------------------------------------------------
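To eyeball what distortData32 produces before training on it, a small sketch (run under qlua so image.display is available, per the demo README):

``` lua
require 'image'
paths.dofile('distort_mnist.lua')

local train = torch.load('mnist.t7/train_32x32.t7', 'ascii')
local batch = distortData32(train.data:float():narrow(1, 1, 16))
image.display{image = batch, nrow = 4, legend = 'distorted MNIST digits'}
```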
/extras/stnbhwd/demo/spatial_transformer.lua:
--------------------------------------------------------------------------------
1 | require 'stn'
2 |
3 | spanet=nn.Sequential()
4 |
5 | local concat=nn.ConcatTable()
6 |
7 | -- first branch is there to transpose inputs to BHWD, for the bilinear sampler
8 | tranet=nn.Sequential()
9 | tranet:add(nn.Identity())
10 | tranet:add(nn.Transpose({2,3},{3,4}))
11 |
12 | -- second branch is the localization network
13 | local locnet = nn.Sequential()
14 | locnet:add(cudnn.SpatialMaxPooling(2,2,2,2))
15 | locnet:add(cudnn.SpatialConvolution(1,20,5,5))
16 | locnet:add(cudnn.ReLU(true))
17 | locnet:add(cudnn.SpatialMaxPooling(2,2,2,2))
18 | locnet:add(cudnn.SpatialConvolution(20,20,5,5))
19 | locnet:add(cudnn.ReLU(true))
20 | locnet:add(nn.View(20*2*2))
21 | locnet:add(nn.Linear(20*2*2,20))
22 | locnet:add(cudnn.ReLU(true))
23 |
24 | -- we initialize the output layer so it gives the identity transform
25 | local outLayer = nn.Linear(20,6)
26 | outLayer.weight:fill(0)
27 | local bias = torch.FloatTensor(6):fill(0)
28 | bias[1]=1
29 | bias[5]=1
30 | outLayer.bias:copy(bias)
31 | locnet:add(outLayer)
32 |
33 | -- there we generate the grids
34 | locnet:add(nn.View(2,3))
35 | locnet:add(nn.AffineGridGeneratorBHWD(32,32))
36 |
37 | -- we need a table input for the bilinear sampler, so we use concattable
38 | concat:add(tranet)
39 | concat:add(locnet)
40 |
41 | spanet:add(concat)
42 | spanet:add(nn.BilinearSamplerBHWD())
43 |
44 | -- and we transpose back to standard BDHW format for subsequent processing by nn modules
45 | spanet:add(nn.Transpose({3,4},{2,3}))
46 |
--------------------------------------------------------------------------------
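The hard-coded nn.View(20*2*2) in the localisation network above follows from the shape chain of the preceding layers; a sketch of the bookkeeping for the demo's 32x32 inputs:

``` lua
-- 32x32 -(maxpool 2x2)-> 16x16 -(conv 5x5)-> 12x12 -(maxpool 2x2)-> 6x6 -(conv 5x5)-> 2x2
local function convOut(n, k) return n - k + 1 end
local n = 32 / 2            -- first pooling: 16
n = convOut(n, 5) / 2       -- conv then pooling: (16 - 4) / 2 = 6
n = convOut(n, 5)           -- second conv: 2
print(20 * n * n)           -- 80 = 20*2*2, the nn.View size in locnet
```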
/extras/stnbhwd/generic/BilinearSamplerBHWD.c:
--------------------------------------------------------------------------------
1 | #ifndef TH_GENERIC_FILE
2 | #define TH_GENERIC_FILE "generic/BilinearSamplerBHWD.c"
3 | #else
4 |
5 | #include <stdbool.h>
6 |
7 |
8 | static int nn_(BilinearSamplerBHWD_updateOutput)(lua_State *L)
9 | {
10 | THTensor *inputImages = luaT_checkudata(L, 2, torch_Tensor);
11 | THTensor *grids = luaT_checkudata(L, 3, torch_Tensor);
12 | THTensor *output = luaT_checkudata(L, 4, torch_Tensor);
13 |
14 | int batchsize = inputImages->size[0];
15 | int inputImages_height = inputImages->size[1];
16 | int inputImages_width = inputImages->size[2];
17 | int output_height = output->size[1];
18 | int output_width = output->size[2];
19 | int inputImages_channels = inputImages->size[3];
20 |
21 | int output_strideBatch = output->stride[0];
22 | int output_strideHeight = output->stride[1];
23 | int output_strideWidth = output->stride[2];
24 |
25 | int inputImages_strideBatch = inputImages->stride[0];
26 | int inputImages_strideHeight = inputImages->stride[1];
27 | int inputImages_strideWidth = inputImages->stride[2];
28 |
29 | int grids_strideBatch = grids->stride[0];
30 | int grids_strideHeight = grids->stride[1];
31 | int grids_strideWidth = grids->stride[2];
32 |
33 | real *inputImages_data, *output_data, *grids_data;
34 | inputImages_data = THTensor_(data)(inputImages);
35 | output_data = THTensor_(data)(output);
36 | grids_data = THTensor_(data)(grids);
37 |
38 | int b, yOut, xOut;
39 |
40 | for(b=0; b < batchsize; b++)
41 | {
42 | for(yOut=0; yOut < output_height; yOut++)
43 | {
44 | for(xOut=0; xOut < output_width; xOut++)
45 | {
46 | //read the grid
47 | real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
48 | real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];
49 |
50 | // get the weights for interpolation
51 | int yInTopLeft, xInTopLeft;
52 | real yWeightTopLeft, xWeightTopLeft;
53 |
54 | real xcoord = (xf + 1) * (inputImages_width - 1) / 2;
55 | xInTopLeft = floor(xcoord);
56 | xWeightTopLeft = 1 - (xcoord - xInTopLeft);
57 |
58 | real ycoord = (yf + 1) * (inputImages_height - 1) / 2;
59 | yInTopLeft = floor(ycoord);
60 | yWeightTopLeft = 1 - (ycoord - yInTopLeft);
61 |
62 |
63 |
64 | const int outAddress = output_strideBatch * b + output_strideHeight * yOut + output_strideWidth * xOut;
65 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
66 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
67 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
68 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;
69 |
70 | real v=0;
71 | real inTopLeft=0;
72 | real inTopRight=0;
73 | real inBottomLeft=0;
74 | real inBottomRight=0;
75 |
76 | // we are careful with the boundaries
77 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
78 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
79 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
80 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
81 |
82 | int t;
83 | // interpolation happens here
84 | for(t=0; t<inputImages_channels; t++)
85 | {
86 | if(topLeftIsIn) inTopLeft = inputImages_data[inTopLeftAddress + t];
87 | if(topRightIsIn) inTopRight = inputImages_data[inTopRightAddress + t];
88 | if(bottomLeftIsIn) inBottomLeft = inputImages_data[inBottomLeftAddress + t];
89 | if(bottomRightIsIn) inBottomRight = inputImages_data[inBottomRightAddress + t];
90 |
91 | v = xWeightTopLeft * yWeightTopLeft * inTopLeft
92 | + (1 - xWeightTopLeft) * yWeightTopLeft * inTopRight
93 | + xWeightTopLeft * (1 - yWeightTopLeft) * inBottomLeft
94 | + (1 - xWeightTopLeft) * (1 - yWeightTopLeft) * inBottomRight;
95 |
96 | output_data[outAddress + t] = v;
97 | }
98 | }
99 | }
100 | }
101 |
102 | return 1;
103 | }
104 |
105 |
106 |
107 | static int nn_(BilinearSamplerBHWD_updateGradInput)(lua_State *L)
108 | {
109 | THTensor *inputImages = luaT_checkudata(L, 2, torch_Tensor);
110 | THTensor *grids = luaT_checkudata(L, 3, torch_Tensor);
111 | THTensor *gradInputImages = luaT_checkudata(L, 4, torch_Tensor);
112 | THTensor *gradGrids = luaT_checkudata(L, 5, torch_Tensor);
113 | THTensor *gradOutput = luaT_checkudata(L, 6, torch_Tensor);
114 |
115 |
116 | int batchsize = inputImages->size[0];
117 | int inputImages_height = inputImages->size[1];
118 | int inputImages_width = inputImages->size[2];
119 | int gradOutput_height = gradOutput->size[1];
120 | int gradOutput_width = gradOutput->size[2];
121 | int inputImages_channels = inputImages->size[3];
122 |
123 | int gradOutput_strideBatch = gradOutput->stride[0];
124 | int gradOutput_strideHeight = gradOutput->stride[1];
125 | int gradOutput_strideWidth = gradOutput->stride[2];
126 |
127 | int inputImages_strideBatch = inputImages->stride[0];
128 | int inputImages_strideHeight = inputImages->stride[1];
129 | int inputImages_strideWidth = inputImages->stride[2];
130 |
131 | int gradInputImages_strideBatch = gradInputImages->stride[0];
132 | int gradInputImages_strideHeight = gradInputImages->stride[1];
133 | int gradInputImages_strideWidth = gradInputImages->stride[2];
134 |
135 | int grids_strideBatch = grids->stride[0];
136 | int grids_strideHeight = grids->stride[1];
137 | int grids_strideWidth = grids->stride[2];
138 |
139 | int gradGrids_strideBatch = gradGrids->stride[0];
140 | int gradGrids_strideHeight = gradGrids->stride[1];
141 | int gradGrids_strideWidth = gradGrids->stride[2];
142 |
143 | real *inputImages_data, *gradOutput_data, *grids_data, *gradGrids_data, *gradInputImages_data;
144 | inputImages_data = THTensor_(data)(inputImages);
145 | gradOutput_data = THTensor_(data)(gradOutput);
146 | grids_data = THTensor_(data)(grids);
147 | gradGrids_data = THTensor_(data)(gradGrids);
148 | gradInputImages_data = THTensor_(data)(gradInputImages);
149 |
150 | int b, yOut, xOut;
151 |
152 | for(b=0; b < batchsize; b++)
153 | {
154 | for(yOut=0; yOut < gradOutput_height; yOut++)
155 | {
156 | for(xOut=0; xOut < gradOutput_width; xOut++)
157 | {
158 | //read the grid
159 | real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
160 | real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];
161 |
162 | // get the weights for interpolation
163 | int yInTopLeft, xInTopLeft;
164 | real yWeightTopLeft, xWeightTopLeft;
165 |
166 | real xcoord = (xf + 1) * (inputImages_width - 1) / 2;
167 | xInTopLeft = floor(xcoord);
168 | xWeightTopLeft = 1 - (xcoord - xInTopLeft);
169 |
170 | real ycoord = (yf + 1) * (inputImages_height - 1) / 2;
171 | yInTopLeft = floor(ycoord);
172 | yWeightTopLeft = 1 - (ycoord - yInTopLeft);
173 |
174 |
175 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
176 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
177 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
178 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;
179 |
180 | const int gradInputImagesTopLeftAddress = gradInputImages_strideBatch * b + gradInputImages_strideHeight * yInTopLeft + gradInputImages_strideWidth * xInTopLeft;
181 | const int gradInputImagesTopRightAddress = gradInputImagesTopLeftAddress + gradInputImages_strideWidth;
182 | const int gradInputImagesBottomLeftAddress = gradInputImagesTopLeftAddress + gradInputImages_strideHeight;
183 | const int gradInputImagesBottomRightAddress = gradInputImagesBottomLeftAddress + gradInputImages_strideWidth;
184 |
185 | const int gradOutputAddress = gradOutput_strideBatch * b + gradOutput_strideHeight * yOut + gradOutput_strideWidth * xOut;
186 |
187 | real topLeftDotProduct = 0;
188 | real topRightDotProduct = 0;
189 | real bottomLeftDotProduct = 0;
190 | real bottomRightDotProduct = 0;
191 |
192 | real v=0;
193 | real inTopLeft=0;
194 | real inTopRight=0;
195 | real inBottomLeft=0;
196 | real inBottomRight=0;
197 |
198 | // we are careful with the boundaries
199 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
200 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
201 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
202 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
203 |
204 | int t;
205 |
206 | for(t=0; t<inputImages_channels; t++)
207 | {
208 | real gradOutValue = gradOutput_data[gradOutputAddress + t];
209 | if(topLeftIsIn)
210 | {
211 | real inTopLeft = inputImages_data[inTopLeftAddress + t];
212 | topLeftDotProduct += inTopLeft * gradOutValue;
213 | gradInputImages_data[gradInputImagesTopLeftAddress + t] += xWeightTopLeft * yWeightTopLeft * gradOutValue;
214 | }
215 |
216 | if(topRightIsIn)
217 | {
218 | real inTopRight = inputImages_data[inTopRightAddress + t];
219 | topRightDotProduct += inTopRight * gradOutValue;
220 | gradInputImages_data[gradInputImagesTopRightAddress + t] += (1 - xWeightTopLeft) * yWeightTopLeft * gradOutValue;
221 | }
222 |
223 | if(bottomLeftIsIn)
224 | {
225 | real inBottomLeft = inputImages_data[inBottomLeftAddress + t];
226 | bottomLeftDotProduct += inBottomLeft * gradOutValue;
227 | gradInputImages_data[gradInputImagesBottomLeftAddress + t] += xWeightTopLeft * (1 - yWeightTopLeft) * gradOutValue;
228 | }
229 |
230 | if(bottomRightIsIn)
231 | {
232 | real inBottomRight = inputImages_data[inBottomRightAddress + t];
233 | bottomRightDotProduct += inBottomRight * gradOutValue;
234 | gradInputImages_data[gradInputImagesBottomRightAddress + t] += (1 - xWeightTopLeft) * (1 - yWeightTopLeft) * gradOutValue;
235 | }
236 | }
237 |
238 | yf = -xWeightTopLeft * topLeftDotProduct + xWeightTopLeft * bottomLeftDotProduct - (1-xWeightTopLeft) * topRightDotProduct + (1-xWeightTopLeft) * bottomRightDotProduct;
239 | xf = -yWeightTopLeft * topLeftDotProduct - (1-yWeightTopLeft) * bottomLeftDotProduct + yWeightTopLeft * topRightDotProduct + (1-yWeightTopLeft) * bottomRightDotProduct;
240 |
241 | gradGrids_data[b*gradGrids_strideBatch + yOut*gradGrids_strideHeight + xOut*gradGrids_strideWidth] = xf * (inputImages_width-1) / 2;
242 | gradGrids_data[b*gradGrids_strideBatch + yOut*gradGrids_strideHeight + xOut*gradGrids_strideWidth + 1] = yf * (inputImages_height-1) / 2;
243 |
244 | }
245 | }
246 | }
247 |
248 | return 1;
249 | }
250 |
251 | static const struct luaL_Reg nn_(BilinearSamplerBHWD__) [] = {
252 | {"BilinearSamplerBHWD_updateOutput", nn_(BilinearSamplerBHWD_updateOutput)},
253 | {"BilinearSamplerBHWD_updateGradInput", nn_(BilinearSamplerBHWD_updateGradInput)},
254 | {NULL, NULL}
255 | };
256 |
257 | static void nn_(BilinearSamplerBHWD_init)(lua_State *L)
258 | {
259 | luaT_pushmetatable(L, torch_Tensor);
260 | luaT_registeratmethods(L, nn_(BilinearSamplerBHWD__), "nn");
261 | lua_pop(L,1);
262 | }
263 |
264 | #endif
265 |
--------------------------------------------------------------------------------
/extras/stnbhwd/generic/ScaleBHWD.c:
--------------------------------------------------------------------------------
1 | #ifndef TH_GENERIC_FILE
2 | #define TH_GENERIC_FILE "generic/ScaleBHWD.c"
3 | #else
4 |
5 | #include <stdbool.h>
6 |
7 |
8 | static int nn_(ScaleBHWD_updateOutput)(lua_State *L)
9 | {
10 | THTensor *inputImages = luaT_checkudata(L, 2, torch_Tensor);
11 | THTensor *grids = luaT_checkudata(L, 3, torch_Tensor);
12 | THTensor *output = luaT_checkudata(L, 4, torch_Tensor);
13 |
14 | int batchsize = inputImages->size[0];
15 | int inputImages_height = inputImages->size[1];
16 | int inputImages_width = inputImages->size[2];
17 | int output_height = output->size[1];
18 | int output_width = output->size[2];
19 | int inputImages_channels = inputImages->size[3];
20 |
21 | int output_strideBatch = output->stride[0];
22 | int output_strideHeight = output->stride[1];
23 | int output_strideWidth = output->stride[2];
24 |
25 | int inputImages_strideBatch = inputImages->stride[0];
26 | int inputImages_strideHeight = inputImages->stride[1];
27 | int inputImages_strideWidth = inputImages->stride[2];
28 |
29 | int grids_strideBatch = grids->stride[0];
30 | int grids_strideHeight = grids->stride[1];
31 | int grids_strideWidth = grids->stride[2];
32 |
33 | real *inputImages_data, *output_data, *grids_data;
34 | inputImages_data = THTensor_(data)(inputImages);
35 | output_data = THTensor_(data)(output);
36 | grids_data = THTensor_(data)(grids);
37 |
38 | int b, yOut, xOut;
39 |
40 | for(b=0; b < batchsize; b++)
41 | {
42 | for(yOut=0; yOut < output_height; yOut++)
43 | {
44 | for(xOut=0; xOut < output_width; xOut++)
45 | {
46 | //read the grid
47 | real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
48 | real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];
49 |
50 | // get the weights for interpolation
51 | int yInTopLeft, xInTopLeft;
52 | real yWeightTopLeft, xWeightTopLeft;
53 |
54 | real xcoord = (xf + 1) * (inputImages_width - 1) / 2;
55 | xInTopLeft = floor(xcoord);
56 | xWeightTopLeft = 1 - (xcoord - xInTopLeft);
57 |
58 | real ycoord = (yf + 1) * (inputImages_height - 1) / 2;
59 | yInTopLeft = floor(ycoord);
60 | yWeightTopLeft = 1 - (ycoord - yInTopLeft);
61 |
62 |
63 |
64 | const int outAddress = output_strideBatch * b + output_strideHeight * yOut + output_strideWidth * xOut;
65 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
66 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
67 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
68 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;
69 |
70 | real v=0;
71 | real inTopLeft=0;
72 | real inTopRight=0;
73 | real inBottomLeft=0;
74 | real inBottomRight=0;
75 |
76 | // we are careful with the boundaries
77 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
78 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
79 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
80 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
81 |
82 | int t;
83 | // interpolation happens here
84 | for(t=0; tsize[0];
117 | int inputImages_height = inputImages->size[1];
118 | int inputImages_width = inputImages->size[2];
119 | int gradOutput_height = gradOutput->size[1];
120 | int gradOutput_width = gradOutput->size[2];
121 | int inputImages_channels = inputImages->size[3];
122 |
123 | int gradOutput_strideBatch = gradOutput->stride[0];
124 | int gradOutput_strideHeight = gradOutput->stride[1];
125 | int gradOutput_strideWidth = gradOutput->stride[2];
126 |
127 | int inputImages_strideBatch = inputImages->stride[0];
128 | int inputImages_strideHeight = inputImages->stride[1];
129 | int inputImages_strideWidth = inputImages->stride[2];
130 |
131 | int gradInputImages_strideBatch = gradInputImages->stride[0];
132 | int gradInputImages_strideHeight = gradInputImages->stride[1];
133 | int gradInputImages_strideWidth = gradInputImages->stride[2];
134 |
135 | int grids_strideBatch = grids->stride[0];
136 | int grids_strideHeight = grids->stride[1];
137 | int grids_strideWidth = grids->stride[2];
138 |
139 | int gradGrids_strideBatch = gradGrids->stride[0];
140 | int gradGrids_strideHeight = gradGrids->stride[1];
141 | int gradGrids_strideWidth = gradGrids->stride[2];
142 |
143 | real *inputImages_data, *gradOutput_data, *grids_data, *gradGrids_data, *gradInputImages_data;
144 | inputImages_data = THTensor_(data)(inputImages);
145 | gradOutput_data = THTensor_(data)(gradOutput);
146 | grids_data = THTensor_(data)(grids);
147 | gradGrids_data = THTensor_(data)(gradGrids);
148 | gradInputImages_data = THTensor_(data)(gradInputImages);
149 |
150 | int b, yOut, xOut;
151 |
152 | for(b=0; b < batchsize; b++)
153 | {
154 | for(yOut=0; yOut < gradOutput_height; yOut++)
155 | {
156 | for(xOut=0; xOut < gradOutput_width; xOut++)
157 | {
158 | //read the grid
159 | real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
160 | real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];
161 |
162 | // get the weights for interpolation
163 | int yInTopLeft, xInTopLeft;
164 | real yWeightTopLeft, xWeightTopLeft;
165 |
166 | real xcoord = (xf + 1) * (inputImages_width - 1) / 2;
167 | xInTopLeft = floor(xcoord);
168 | xWeightTopLeft = 1 - (xcoord - xInTopLeft);
169 |
170 | real ycoord = (yf + 1) * (inputImages_height - 1) / 2;
171 | yInTopLeft = floor(ycoord);
172 | yWeightTopLeft = 1 - (ycoord - yInTopLeft);
173 |
174 |
175 | const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
176 | const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
177 | const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
178 | const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;
179 |
180 | const int gradInputImagesTopLeftAddress = gradInputImages_strideBatch * b + gradInputImages_strideHeight * yInTopLeft + gradInputImages_strideWidth * xInTopLeft;
181 | const int gradInputImagesTopRightAddress = gradInputImagesTopLeftAddress + gradInputImages_strideWidth;
182 | const int gradInputImagesBottomLeftAddress = gradInputImagesTopLeftAddress + gradInputImages_strideHeight;
183 | const int gradInputImagesBottomRightAddress = gradInputImagesBottomLeftAddress + gradInputImages_strideWidth;
184 |
185 | const int gradOutputAddress = gradOutput_strideBatch * b + gradOutput_strideHeight * yOut + gradOutput_strideWidth * xOut;
186 |
187 | real topLeftDotProduct = 0;
188 | real topRightDotProduct = 0;
189 | real bottomLeftDotProduct = 0;
190 | real bottomRightDotProduct = 0;
191 |
192 | real v=0;
193 | real inTopLeft=0;
194 | real inTopRight=0;
195 | real inBottomLeft=0;
196 | real inBottomRight=0;
197 |
198 | // we are careful with the boundaries
199 | bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
200 | bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
201 | bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
202 | bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
203 |
204 | int t;
205 |
206 | for(t=0; t<inputImages_channels; t++)
207 | {
208 | real gradOutValue = gradOutput_data[gradOutputAddress + t];
209 | if(topLeftIsIn)
210 | {
211 | real inTopLeft = inputImages_data[inTopLeftAddress + t];
212 | topLeftDotProduct += inTopLeft * gradOutValue;
213 | gradInputImages_data[gradInputImagesTopLeftAddress + t] += xWeightTopLeft * yWeightTopLeft * gradOutValue;
214 | }
215 |
216 | if(topRightIsIn)
217 | {
218 | real inTopRight = inputImages_data[inTopRightAddress + t];
219 | topRightDotProduct += inTopRight * gradOutValue;
220 | gradInputImages_data[gradInputImagesTopRightAddress + t] += (1 - xWeightTopLeft) * yWeightTopLeft * gradOutValue;
221 | }
222 |
223 | if(bottomLeftIsIn)
224 | {
225 | real inBottomLeft = inputImages_data[inBottomLeftAddress + t];
226 | bottomLeftDotProduct += inBottomLeft * gradOutValue;
227 | gradInputImages_data[gradInputImagesBottomLeftAddress + t] += xWeightTopLeft * (1 - yWeightTopLeft) * gradOutValue;
228 | }
229 |
230 | if(bottomRightIsIn)
231 | {
232 | real inBottomRight = inputImages_data[inBottomRightAddress + t];
233 | bottomRightDotProduct += inBottomRight * gradOutValue;
234 | gradInputImages_data[gradInputImagesBottomRightAddress + t] += (1 - xWeightTopLeft) * (1 - yWeightTopLeft) * gradOutValue;
235 | }
236 | }
237 |
238 | yf = -xWeightTopLeft * topLeftDotProduct + xWeightTopLeft * bottomLeftDotProduct - (1-xWeightTopLeft) * topRightDotProduct + (1-xWeightTopLeft) * bottomRightDotProduct;
239 | xf = -yWeightTopLeft * topLeftDotProduct - (1-yWeightTopLeft) * bottomLeftDotProduct + yWeightTopLeft * topRightDotProduct + (1-yWeightTopLeft) * bottomRightDotProduct;
240 |
241 | gradGrids_data[b*gradGrids_strideBatch + yOut*gradGrids_strideHeight + xOut*gradGrids_strideWidth] = yf * (inputImages_height-1) / 2;
242 | gradGrids_data[b*gradGrids_strideBatch + yOut*gradGrids_strideHeight + xOut*gradGrids_strideWidth + 1] = xf * (inputImages_width-1) / 2;
243 |
244 | }
245 | }
246 | }
247 |
248 | return 1;
249 | }
250 |
251 | static const struct luaL_Reg nn_(ScaleBHWD__) [] = {
252 | {"ScaleBHWD_updateOutput", nn_(ScaleBHWD_updateOutput)},
253 | {"ScaleBHWD_updateGradInput", nn_(ScaleBHWD_updateGradInput)},
254 | {NULL, NULL}
255 | };
256 |
257 | static void nn_(ScaleBHWD_init)(lua_State *L)
258 | {
259 | luaT_pushmetatable(L, torch_Tensor);
260 | luaT_registeratmethods(L, nn_(ScaleBHWD__), "nn");
261 | lua_pop(L,1);
262 | }
263 |
264 | #endif
265 |
--------------------------------------------------------------------------------
/extras/stnbhwd/stnbhwd-scm-1.rockspec:
--------------------------------------------------------------------------------
1 | package = "stnbhwd"
2 | version = "scm-1"
3 |
4 | source = {
5 |    url = "git://github.com/qassemoquab/stnbhwd.git"
6 | }
7 |
8 | description = {
9 |    summary = "Modules for spatial transformer networks (BHWD layout)",
10 |    detailed = [[
11 |    ]],
12 |    homepage = "https://github.com/qassemoquab/stnbhwd",
13 |    license = "MIT"
14 | }
15 |
16 | dependencies = {
17 |    "torch >= 7.0",
18 | "nn >= 1.0",
19 | }
20 |
21 | build = {
22 | type = "command",
23 | build_command = [[
24 | cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE)
25 | ]],
26 | install_command = "cd build && $(MAKE) install"
27 | }
28 |
--------------------------------------------------------------------------------
/extras/stnbhwd/test.lua:
--------------------------------------------------------------------------------
1 | -- you can easily test specific units like this:
2 | -- th -lnn -e "nn.test{'LookupTable'}"
3 | -- th -lnn -e "nn.test{'LookupTable', 'Add'}"
4 |
5 | local mytester = torch.Tester()
6 | local jac
7 | local sjac
8 |
9 | local precision = 1e-5
10 | local expprecision = 1e-4
11 |
12 | local stntest = {}
13 |
14 | function stntest.AffineGridGeneratorBHWD_batch()
15 | local nframes = torch.random(2,10)
16 | local height = torch.random(2,5)
17 | local width = torch.random(2,5)
18 | local input = torch.zeros(nframes, 2, 3):uniform()
19 | local module = nn.AffineGridGeneratorBHWD(height, width)
20 |
21 | local err = jac.testJacobian(module,input)
22 | mytester:assertlt(err,precision, 'error on state ')
23 |
24 | -- IO
25 | local ferr,berr = jac.testIO(module,input)
26 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
27 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
28 |
29 | end
30 |
31 | function stntest.AffineGridGeneratorBHWD_single()
32 | local height = torch.random(2,5)
33 | local width = torch.random(2,5)
34 | local input = torch.zeros(2, 3):uniform()
35 | local module = nn.AffineGridGeneratorBHWD(height, width)
36 |
37 | local err = jac.testJacobian(module,input)
38 | mytester:assertlt(err,precision, 'error on state ')
39 |
40 | -- IO
41 | local ferr,berr = jac.testIO(module,input)
42 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
43 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
44 |
45 | end
46 |
47 | function stntest.BilinearSamplerBHWD_batch()
48 | local nframes = torch.random(2,10)
49 | local height = torch.random(1,5)
50 | local width = torch.random(1,5)
51 | local channels = torch.random(1,6)
52 | local inputImages = torch.zeros(nframes, height, width, channels):uniform()
53 | local grids = torch.zeros(nframes, height, width, 2):uniform(-1, 1)
54 | local module = nn.BilinearSamplerBHWD()
55 |
56 | -- test input images (first element of input table)
57 | module._updateOutput = module.updateOutput
58 | function module:updateOutput(input)
59 | return self:_updateOutput({input, grids})
60 | end
61 |
62 | module._updateGradInput = module.updateGradInput
63 | function module:updateGradInput(input, gradOutput)
64 | self:_updateGradInput({input, grids}, gradOutput)
65 | return self.gradInput[1]
66 | end
67 |
68 | local errImages = jac.testJacobian(module,inputImages)
69 | mytester:assertlt(errImages,precision, 'error on state ')
70 |
71 | -- test grids (second element of input table)
72 | function module:updateOutput(input)
73 | return self:_updateOutput({inputImages, input})
74 | end
75 |
76 | function module:updateGradInput(input, gradOutput)
77 | self:_updateGradInput({inputImages, input}, gradOutput)
78 | return self.gradInput[2]
79 | end
80 |
81 | local errGrids = jac.testJacobian(module,grids)
82 | mytester:assertlt(errGrids,precision, 'error on state ')
83 | end
84 |
85 | function stntest.BilinearSamplerBHWD_single()
86 | local height = torch.random(1,5)
87 | local width = torch.random(1,5)
88 | local channels = torch.random(1,6)
89 | local inputImages = torch.zeros(height, width, channels):uniform()
90 | local grids = torch.zeros(height, width, 2):uniform(-1, 1)
91 | local module = nn.BilinearSamplerBHWD()
92 |
93 | -- test input images (first element of input table)
94 | module._updateOutput = module.updateOutput
95 | function module:updateOutput(input)
96 | return self:_updateOutput({input, grids})
97 | end
98 |
99 | module._updateGradInput = module.updateGradInput
100 | function module:updateGradInput(input, gradOutput)
101 | self:_updateGradInput({input, grids}, gradOutput)
102 | return self.gradInput[1]
103 | end
104 |
105 | local errImages = jac.testJacobian(module,inputImages)
106 | mytester:assertlt(errImages,precision, 'error on state ')
107 |
108 | -- test grids (second element of input table)
109 | function module:updateOutput(input)
110 | return self:_updateOutput({inputImages, input})
111 | end
112 |
113 | function module:updateGradInput(input, gradOutput)
114 | self:_updateGradInput({inputImages, input}, gradOutput)
115 | return self.gradInput[2]
116 | end
117 |
118 | local errGrids = jac.testJacobian(module,grids)
119 | mytester:assertlt(errGrids,precision, 'error on state ')
120 | end
121 |
122 | function stntest.AffineTransformMatrixGenerator_batch()
123 | -- test all possible transformations
124 | for _,useRotation in pairs{true,false} do
125 | for _,useScale in pairs{true,false} do
126 | for _,useTranslation in pairs{true,false} do
127 | local currTest = ''
128 | if useRotation then currTest = currTest..'rotation ' end
129 | if useScale then currTest = currTest..'scale ' end
130 | if useTranslation then currTest = currTest..'translation' end
131 | if currTest=='' then currTest = 'full' end
132 |
133 | local nbNeededParams = 0
134 | if useRotation then nbNeededParams = nbNeededParams + 1 end
135 | if useScale then nbNeededParams = nbNeededParams + 1 end
136 | if useTranslation then nbNeededParams = nbNeededParams + 2 end
137 | if nbNeededParams == 0 then nbNeededParams = 6 end -- full affine case
138 |
139 | local nframes = torch.random(2,10)
140 | local params = torch.zeros(nframes,nbNeededParams):uniform()
141 | local module = nn.AffineTransformMatrixGenerator(useRotation,useScale,useTranslation)
142 |
143 | local err = jac.testJacobian(module,params)
144 | mytester:assertlt(err,precision, 'error on state for test '..currTest)
145 |
146 | -- IO
147 | local ferr,berr = jac.testIO(module,params)
148 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for test '..currTest)
149 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err for test '..currTest)
150 |
151 | end
152 | end
153 | end
154 | end
155 |
156 | function stntest.AffineTransformMatrixGenerator_single()
157 | -- test all possible transformations
158 | for _,useRotation in pairs{true,false} do
159 | for _,useScale in pairs{true,false} do
160 | for _,useTranslation in pairs{true,false} do
161 | local currTest = ''
162 | if useRotation then currTest = currTest..'rotation ' end
163 | if useScale then currTest = currTest..'scale ' end
164 | if useTranslation then currTest = currTest..'translation' end
165 | if currTest=='' then currTest = 'full' end
166 |
167 | local nbNeededParams = 0
168 | if useRotation then nbNeededParams = nbNeededParams + 1 end
169 | if useScale then nbNeededParams = nbNeededParams + 1 end
170 | if useTranslation then nbNeededParams = nbNeededParams + 2 end
171 | if nbNeededParams == 0 then nbNeededParams = 6 end -- full affine case
172 |
173 | local params = torch.zeros(nbNeededParams):uniform()
174 | local module = nn.AffineTransformMatrixGenerator(useRotation,useScale,useTranslation)
175 |
176 | local err = jac.testJacobian(module,params)
177 | mytester:assertlt(err,precision, 'error on state for test '..currTest)
178 |
179 | -- IO
180 | local ferr,berr = jac.testIO(module,params)
181 | mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for test '..currTest)
182 | mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err for test '..currTest)
183 |
184 | end
185 | end
186 | end
187 | end
188 |
189 | mytester:add(stntest)
190 |
191 | if not nn then
192 | require 'nn'
193 | jac = nn.Jacobian
194 | sjac = nn.SparseJacobian
195 | mytester:run()
196 | else
197 | jac = nn.Jacobian
198 | sjac = nn.SparseJacobian
199 | function stn.test(tests)
200 | -- randomize stuff
201 | math.randomseed(os.time())
202 | mytester:run(tests)
203 | return mytester
204 | end
205 | end
206 |
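Since stn.test is only defined once this file has run with nn already loaded, here is a hedged sketch of driving the suite from a Torch session at the repository root (assuming the stn package built above is installed; the `stn = stn or {}` line is a defensive fallback in case the package does not create the global table itself):

    require 'nn'
    require 'stn'                          -- the BHWD modules under test
    stn = stn or {}                        -- test.lua attaches stn.test to this table
    dofile 'extras/stnbhwd/test.lua'       -- registers stntest and defines stn.test
    stn.test()                             -- run all cases, or pass {'BilinearSamplerBHWD_batch'}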
--------------------------------------------------------------------------------
/extras/stnbhwd/utils.c:
--------------------------------------------------------------------------------
1 | #include "utils.h"
2 |
3 | THCState* getCutorchState(lua_State* L)
4 | {
5 | lua_getglobal(L, "cutorch");
6 | lua_getfield(L, -1, "getState");
7 | lua_call(L, 0, 1);
8 | THCState *state = (THCState*) lua_touserdata(L, -1);
9 | lua_pop(L, 2);
10 | return state;
11 | }
12 |
--------------------------------------------------------------------------------
/extras/stnbhwd/utils.h:
--------------------------------------------------------------------------------
1 | #ifndef CUNN_UTILS_H
2 | #define CUNN_UTILS_H
3 |
4 | #include <lua.h>
5 | #include "THCGeneral.h"
6 |
7 | THCState* getCutorchState(lua_State* L);
8 |
9 | #endif
10 |
--------------------------------------------------------------------------------
/flowExtensions.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | -------------------------
9 | -- Optical Flow Utilities
10 | -------------------------
11 | local stringx = require('pl.stringx') -- note: this module also expects the globals 'torch', 'xlua' and 'image' to be loaded by the caller
12 | local M = {}
13 |
14 | local eps = 1e-6
15 |
16 | local function computeNorm(...)
17 | -- check args
18 | local _, flow_x, flow_y = xlua.unpack(
19 | {...},
20 | 'opticalflow.computeNorm',
21 | 'computes norm (size) of flow field from flow_x and flow_y,\n',
22 | {arg='flow_x', type='torch.Tensor', help='flow field (x), (WxH)', req=true},
23 | {arg='flow_y', type='torch.Tensor', help='flow field (y), (WxH)', req=true}
24 | )
25 | local flow_norm = torch.Tensor()
26 | local x_squared = torch.Tensor():resizeAs(flow_x):copy(flow_x):cmul(flow_x)
27 | flow_norm:resizeAs(flow_y):copy(flow_y):cmul(flow_y):add(x_squared):sqrt()
28 | return flow_norm
29 | end
30 | M.computeNorm = computeNorm
31 |
32 | ------------------------------------------------------------
33 | -- computes angle (direction) of flow field from flow_x and flow_y,
34 | --
35 | -- @usage opticalflow.computeAngle() -- prints online help
36 | --
37 | -- @param flow_x flow field (x), (WxH) [required] [type = torch.Tensor]
38 | -- @param flow_y flow field (y), (WxH) [required] [type = torch.Tensor]
39 | ------------------------------------------------------------
40 | local function computeAngle(...)
41 | -- check args
42 | local _, flow_x, flow_y = xlua.unpack(
43 | {...},
44 | 'opticalflow.computeAngle',
45 | 'computes angle (direction) of flow field from flow_x and flow_y,\n',
46 | {arg='flow_x', type='torch.Tensor', help='flow field (x), (WxH)', req=true},
47 | {arg='flow_y', type='torch.Tensor', help='flow field (y), (WxH)', req=true}
48 | )
49 | local flow_angle = torch.Tensor()
50 | flow_angle:resizeAs(flow_y):copy(flow_y):cdiv(flow_x):abs():atan():mul(180/math.pi)
51 | flow_angle:map2(flow_x, flow_y, function(h,x,y)
52 | if x == 0 and y >= 0 then
53 | return 90
54 | elseif x == 0 and y <= 0 then
55 | return 270
56 | elseif x >= 0 and y >= 0 then
57 | -- all good
58 | elseif x >= 0 and y < 0 then
59 | return 360 - h
60 | elseif x < 0 and y >= 0 then
61 | return 180 - h
62 | elseif x < 0 and y < 0 then
63 | return 180 + h
64 | end
65 | end)
66 | return flow_angle
67 | end
68 | M.computeAngle = computeAngle
69 | ------------------------------------------------------------
70 | -- merges Norm and Angle flow fields into a single RGB image,
71 | -- where saturation=intensity, and hue=direction
72 | --
73 | -- @usage opticalflow.field2rgb() -- prints online help
74 | --
75 | -- @param norm flow field (norm), (WxH) [required] [type = torch.Tensor]
76 | -- @param angle flow field (angle), (WxH) [required] [type = torch.Tensor]
77 | -- @param max if not provided, norm:max() is used [type = number]
78 | -- @param legend prints a legend on the image [type = boolean]
79 | ------------------------------------------------------------
80 | local function field2rgb(...)
81 | -- check args
82 | local _, norm, angle, max, legend = xlua.unpack(
83 | {...},
84 | 'opticalflow.field2rgb',
85 | 'merges Norm and Angle flow fields into a single RGB image,\n'
86 | .. 'where saturation=intensity, and hue=direction',
87 | {arg='norm', type='torch.Tensor', help='flow field (norm), (WxH)', req=true},
88 | {arg='angle', type='torch.Tensor', help='flow field (angle), (WxH)', req=true},
89 | {arg='max', type='number', help='if not provided, norm:max() is used'},
90 | {arg='legend', type='boolean', help='prints a legend on the image', default=false}
91 | )
92 |
93 | -- max
94 | local saturate = false
95 | if max then saturate = true end
96 | max = math.max(max or norm:max(), 1e-2)
97 |
98 | -- merge them into an HSL image
99 | local hsl = torch.Tensor(3,norm:size(1), norm:size(2))
100 | -- hue = angle:
101 | hsl:select(1,1):copy(angle):div(360)
102 | -- saturation = normalized intensity:
103 | hsl:select(1,2):copy(norm):div(max)
104 | if saturate then hsl:select(1,2):tanh() end
105 | -- light varies inversely from saturation (null flow = white):
106 | hsl:select(1,3):copy(hsl:select(1,2)):mul(-0.5):add(1)
107 |
108 | -- convert HSL to RGB
109 | local rgb = image.hsl2rgb(hsl)
110 |
111 | -- legend
112 | if legend then
113 | _legend_ = _legend_
114 | or image.load(paths.concat(paths.install_lua_path, 'opticalflow/legend.png'),3)
115 | legend = torch.Tensor(3,hsl:size(2)/8, hsl:size(2)/8)
116 | image.scale(_legend_, legend, 'bilinear')
117 | rgb:narrow(1,1,legend:size(2)):narrow(2,hsl:size(2)-legend:size(2)+1,legend:size(2)):copy(legend)
118 | end
119 |
120 | -- done
121 | return rgb
122 | end
123 | M.field2rgb = field2rgb
124 | ------------------------------------------------------------
125 | -- Simplifies display of flow field in HSV colorspace when the
126 | -- available field is in x,y displacement
127 | --
128 | -- @usage opticalflow.xy2rgb() -- prints online help
129 | --
130 | -- @param x flow field (x), (WxH) [required] [type = torch.Tensor]
131 | -- @param y flow field (y), (WxH) [required] [type = torch.Tensor]
132 | ------------------------------------------------------------
133 | local function xy2rgb(...)
134 | -- check args
135 | local _, x, y, max = xlua.unpack(
136 | {...},
137 | 'opticalflow.xy2rgb',
138 | 'merges x and y flow fields into a single RGB image,\n'
139 | .. 'where saturation=intensity, and hue=direction',
140 | {arg='x', type='torch.Tensor', help='flow field (norm), (WxH)', req=true},
141 | {arg='y', type='torch.Tensor', help='flow field (angle), (WxH)', req=true},
142 | {arg='max', type='number', help='if not provided, norm:max() is used'}
143 | )
144 |
145 | local norm = computeNorm(x,y)
146 | local angle = computeAngle(x,y)
147 | return field2rgb(norm,angle,max)
148 | end
149 | M.xy2rgb = xy2rgb
150 |
151 | local function loadFLO(filename)
152 | local TAG_FLOAT = 202021.25
153 | local ff = torch.DiskFile(filename):binary()
154 | local tag = ff:readFloat()
155 | if tag ~= TAG_FLOAT then
156 | xerror('unable to read '..filename..
157 | ' perhaps bigendian error','readflo()')
158 | end
159 |
160 | local w = ff:readInt()
161 | local h = ff:readInt()
162 | local nbands = 2
163 | local tf = torch.FloatTensor(h, w, nbands)
164 | ff:readFloat(tf:storage())
165 | ff:close()
166 |
167 | local flow = tf:permute(3,1,2)
168 | return flow
169 | end
170 | M.loadFLO = loadFLO
171 |
172 | local function writeFLO(filename, F)
173 | F = F:permute(2,3,1):clone()
174 | local TAG_FLOAT = 202021.25
175 | local ff = torch.DiskFile(filename, 'w'):binary()
176 | ff:writeFloat(TAG_FLOAT)
177 |
178 | ff:writeInt(F:size(2)) -- width
179 | ff:writeInt(F:size(1)) -- height
180 |
181 | ff:writeFloat(F:storage())
182 | ff:close()
183 | end
184 | M.writeFLO = writeFLO
185 |
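Because the float payload is copied bit-for-bit, loadFLO and writeFLO round-trip exactly on well-formed files; a quick check against one of the bundled samples, run from the repository root (the output path is arbitrary):

    local flowX = require 'flowExtensions'
    local flow = flowX.loadFLO('samples/00001_flow.flo')   -- 2 x H x W FloatTensor
    flowX.writeFLO('/tmp/roundtrip.flo', flow)
    assert((flow - flowX.loadFLO('/tmp/roundtrip.flo')):abs():max() == 0)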
186 | local function loadPFM(filename)
187 | local ff = torch.DiskFile(filename):binary()
188 | local header = ff:readString("*l")
189 | local color, nbands
190 | if stringx.strip(header) == 'PF' then
191 | color = true
192 | nbands = 3
193 | else
194 | color = false
195 | nbands = 1
196 | end
197 | local dims = stringx.split(ff:readString("*l"))
198 | local scale = ff:readString("*l")
199 | if tonumber(scale) < 0 then
200 | ff:littleEndianEncoding()
201 | else
202 | ff:bigEndianEncoding()
203 | end
204 | local tf = ff:readFloat(dims[1]*dims[2]*nbands)
205 | ff:close()
206 | tf = torch.FloatTensor(tf):resize(dims[2],dims[1],nbands):permute(3,1,2)
207 | tf = image.vflip(tf)
208 | return tf[{{1,2},{},{}}]
209 | end
210 | M.loadPFM = loadPFM
211 |
212 | local function rotate(flow, angle)
213 | local flow_rot = image.rotate(flow, angle, 'simple')
214 | local fu = torch.mul(flow_rot[1], math.cos(-angle)) - torch.mul(flow_rot[2], math.sin(-angle))
215 | local fv = torch.mul(flow_rot[1], math.sin(-angle)) + torch.mul(flow_rot[2], math.cos(-angle))
216 | flow_rot[1]:copy(fu)
217 | flow_rot[2]:copy(fv)
218 |
219 | return flow_rot
220 | end
221 | M.rotate = rotate
222 |
223 | local function scale(flow, sc, opt)
224 | opt = opt or 'simple'
225 | local flow_scaled = image.scale(flow, '*'..sc, opt)*sc
226 |
227 | return flow_scaled
228 |
229 | end
230 | M.scale = scale
231 |
232 | local function scaleBatch(flow, sc)
233 | local flowR = torch.FloatTensor(opt.batchSize*2, flow:size(3), flow:size(4))
234 | local outputR = torch.FloatTensor(opt.batchSize, 2, flow:size(3)*sc, flow:size(4)*sc)
235 |
236 | flowR:copy(flow)
237 | local output = image.scale(flowR, '*'..sc, 'simple')*sc
238 | outputR:copy(output)
239 | return outputR
240 | end
241 | M.scaleBatch = scaleBatch
242 |
243 | return M
244 |
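Note that scale resizes the field and multiplies the vectors by the same factor, so displacements stay expressed in pixels of the new resolution; a small sketch with a synthetic field:

    require 'image'
    local flowX = require 'flowExtensions'
    local flow = torch.FloatTensor(2, 192, 256):fill(4)  -- uniform 4-pixel shift
    local dbl = flowX.scale(flow, 2)                     -- 2 x 384 x 512
    assert(dbl[1][1][1] == 8)                            -- vectors grow with the image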
--------------------------------------------------------------------------------
/main.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | require 'torch'
9 | require 'cutorch'
10 | require 'paths'
11 | require 'xlua'
12 | require 'optim'
13 | require 'nn'
14 |
15 | torch.setdefaulttensortype('torch.FloatTensor')
16 |
17 | local opts = paths.dofile('opts.lua')
18 |
19 | opt = opts.parse(arg)
20 |
21 | print('Saving everything to: ' .. opt.save)
22 | os.execute('mkdir -p ' .. opt.save)
23 |
24 | paths.dofile('util.lua')
25 | paths.dofile('model.lua')
26 | opt.imageSize = model.imageSize or opt.imageSize
27 | opt.outputSize = model.outputSize or opt.outputSize
28 |
29 | print(opt)
30 |
31 | cutorch.setDevice(opt.GPU) -- by default, use GPU 1
32 | torch.manualSeed(opt.manualSeed)
33 |
34 | paths.dofile('data.lua')
35 | paths.dofile('train.lua')
36 | paths.dofile('test.lua')
37 |
38 | epoch = opt.epochNumber
39 |
40 | for i=1,opt.nEpochs do
41 | train()
42 | test()
43 | epoch = epoch + 1
44 | end
45 |
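A typical training invocation, assuming the Flying Chairs layout expected by opts.lua, is along the lines of th main.lua -data flying_chairs/data -netType volcon -nDonkeys 8; all options and their defaults are defined in opts.lua below.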
--------------------------------------------------------------------------------
/model.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 | --
8 | -- Copyright (c) 2014, Facebook, Inc.
9 | -- All rights reserved.
10 | --
11 | -- This source code is licensed under the BSD-style license found in the
12 | -- LICENSE file in the root directory of this source tree. An additional grant
13 | -- of patent rights can be found in the PATENTS file in the same directory.
14 | --
15 | require 'nn'
16 | require 'cunn'
17 | require 'optim'
18 | include('EPECriterion.lua')
19 |
20 | --[[
21 | 1. Create Model
22 | 2. Create Criterion
23 | 3. Convert model to CUDA
24 | ]]--
25 |
26 | -- 1. Create Network
27 | -- 1.1 If preloading option is set, preload weights from existing models appropriately
28 | if opt.retrain ~= 'none' then
29 | assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
30 | print('Loading model from file: ' .. opt.retrain);
31 | model = loadDataParallel(opt.retrain, opt.nGPU) -- defined in util.lua
32 | else
33 | paths.dofile('models/' .. opt.netType .. '.lua')
34 | print('=> Creating model from file: models/' .. opt.netType .. '.lua')
35 | model = createModel(opt.nGPU) -- for the model creation code, check the models/ folder
36 | if opt.backend == 'cudnn' then
37 | require 'cudnn'
38 | cudnn.convert(model, cudnn)
39 | elseif opt.backend ~= 'nn' then
40 | error'Unsupported backend'
41 | end
42 | end
43 |
44 | -- 2. Create Criterion
45 | criterion = nn.EPECriterion()
46 |
47 | print('=> Model')
48 | print(model)
49 |
50 | print('=> Criterion')
51 | print(criterion)
52 |
53 | criterion:cuda()
54 |
55 | collectgarbage()
56 |
--------------------------------------------------------------------------------
/models/modelL1_3.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_3.t7
--------------------------------------------------------------------------------
/models/modelL1_4.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_4.t7
--------------------------------------------------------------------------------
/models/modelL1_C.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_C.t7
--------------------------------------------------------------------------------
/models/modelL1_F.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_F.t7
--------------------------------------------------------------------------------
/models/modelL1_K.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL1_K.t7
--------------------------------------------------------------------------------
/models/modelL2_3.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_3.t7
--------------------------------------------------------------------------------
/models/modelL2_4.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_4.t7
--------------------------------------------------------------------------------
/models/modelL2_C.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_C.t7
--------------------------------------------------------------------------------
/models/modelL2_F.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_F.t7
--------------------------------------------------------------------------------
/models/modelL2_K.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL2_K.t7
--------------------------------------------------------------------------------
/models/modelL3_3.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_3.t7
--------------------------------------------------------------------------------
/models/modelL3_4.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_4.t7
--------------------------------------------------------------------------------
/models/modelL3_C.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_C.t7
--------------------------------------------------------------------------------
/models/modelL3_F.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_F.t7
--------------------------------------------------------------------------------
/models/modelL3_K.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL3_K.t7
--------------------------------------------------------------------------------
/models/modelL4_3.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_3.t7
--------------------------------------------------------------------------------
/models/modelL4_4.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_4.t7
--------------------------------------------------------------------------------
/models/modelL4_C.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_C.t7
--------------------------------------------------------------------------------
/models/modelL4_F.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_F.t7
--------------------------------------------------------------------------------
/models/modelL4_K.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL4_K.t7
--------------------------------------------------------------------------------
/models/modelL5_3.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_3.t7
--------------------------------------------------------------------------------
/models/modelL5_4.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_4.t7
--------------------------------------------------------------------------------
/models/modelL5_C.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_C.t7
--------------------------------------------------------------------------------
/models/modelL5_F.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_F.t7
--------------------------------------------------------------------------------
/models/modelL5_K.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL5_K.t7
--------------------------------------------------------------------------------
/models/modelL6_C.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL6_C.t7
--------------------------------------------------------------------------------
/models/modelL6_F.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL6_F.t7
--------------------------------------------------------------------------------
/models/modelL6_K.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/models/modelL6_K.t7
--------------------------------------------------------------------------------
/models/volcon.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | require 'nn'
9 | require 'cutorch'
10 | require 'cunn'
11 | require 'cudnn'
12 | function createModel(nGPU)
13 | local model = nn.Sequential()
14 | model:add(nn.SpatialConvolution(8,32,7,7,1,1,3,3))
15 | model:add(nn.ReLU())
16 | model:add(nn.SpatialConvolution(32,64,7,7,1,1,3,3))
17 | model:add(nn.ReLU())
18 | model:add(nn.SpatialConvolution(64,32,7,7,1,1,3,3))
19 | model:add(nn.ReLU())
20 | model:add(nn.SpatialConvolution(32,16,7,7,1,1,3,3))
21 | model:add(nn.ReLU())
22 | model:add(nn.SpatialConvolution(16,2,7,7,1,1,3,3))
23 |
24 | if nGPU>0 then
25 | model:cuda()
26 | model = makeDataParallel(model, nGPU)
27 | end
28 |
29 | return model
30 | end
31 |
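Every 7x7 convolution above uses stride 1 and padding 3, so the network maps an 8-channel input (two RGB frames plus the 2-channel flow initialization) to a 2-channel flow residual at the same resolution. A CPU smoke-test sketch, run from the repository root (nGPU=0 skips the CUDA/DataParallel branch, though the file still pulls in the CUDA packages at load time):

    dofile 'models/volcon.lua'
    local net = createModel(0)             -- nGPU=0: stay on the CPU
    local x = torch.randn(1, 8, 32, 32)    -- two RGB frames + 2-channel flow init
    print(net:forward(x):size())           -- 1 x 2 x 32 x 32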
--------------------------------------------------------------------------------
/opts.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | local M = { }
9 |
10 | function M.parse(arg)
11 | local cmd = torch.CmdLine()
12 | cmd:text()
13 | cmd:text('SPyNet Coarse-to-Fine Optical Flow Training')
14 | cmd:text()
15 | cmd:text('Options:')
16 | ------------ General options --------------------
17 |
18 | cmd:option('-cache', 'checkpoint/', 'subdirectory in which to save/log experiments')
19 | cmd:option('-data', 'flying_chairs/data', 'Home of Flying Chairs dataset')
20 | cmd:option('-trainValidationSplit', 'train_val_split.txt', 'File containing training and validation split')
21 | cmd:option('-manualSeed', 2, 'Manually set RNG seed')
22 | cmd:option('-GPU', 1, 'Default preferred GPU')
23 | cmd:option('-nGPU', 1, 'Number of GPUs to use by default')
24 | cmd:option('-backend', 'cudnn', 'Options: cudnn | ccn2 | cunn')
25 | ------------- Data options ------------------------
26 | cmd:option('-nDonkeys', 4, 'number of donkeys to initialize (data loading threads)')
27 | cmd:option('-fineWidth', 512, 'the width of the fine flow field')
28 | cmd:option('-fineHeight', 384, 'the height of the fine flow field')
29 | cmd:option('-level', 1, 'Options: 1,2,3...; pyramid level to train (flow is initialized to zero at the first level)')
30 | ------------- Training options --------------------
31 | cmd:option('-augment', 1, 'augment the data')
32 | cmd:option('-nEpochs', 1000, 'Number of total epochs to run')
33 | cmd:option('-epochSize', 1000, 'Number of batches per epoch')
34 | cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)')
35 | cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)')
36 | ---------- Optimization options ----------------------
37 | cmd:option('-LR', 0.0, 'learning rate; if set, overrides default LR/WD recipe')
38 | cmd:option('-momentum', 0.9, 'momentum')
39 | cmd:option('-weightDecay', 5e-4, 'weight decay')
40 | cmd:option('-optimizer', 'adam', 'adam or sgd')
41 | ---------- Model options ----------------------------------
42 | cmd:option('-L1', 'models/modelL1_4.t7', 'Trained Level 1 model')
43 | cmd:option('-L2', 'models/modelL2_4.t7', 'Trained Level 2 model')
44 | cmd:option('-L3', 'models/modelL3_4.t7', 'Trained Level 3 model')
45 | cmd:option('-L4', 'models/modelL4_4.t7', 'Trained Level 4 model')
46 |
47 | cmd:option('-netType', 'volcon', 'Lua network file')
48 | cmd:option('-retrain', 'none', 'provide path to model to retrain with')
49 | cmd:option('-optimState', 'none', 'provide path to an optimState to reload from')
50 | cmd:text()
51 |
52 | local opt = cmd:parse(arg or {})
53 | opt.save = paths.concat(opt.cache)
54 | -- add date/time
55 | opt.save = paths.concat(opt.save, '' .. os.date():gsub(' ',''))
56 |
57 | opt.loadSize = {8, 384, 512}
58 | return opt
59 | end
60 |
61 | return M
62 |
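Since M.parse is a thin wrapper around torch.CmdLine, the defaults are easy to inspect programmatically; a short sketch, run from the repository root:

    require 'paths'
    local opts = dofile 'opts.lua'
    local opt = opts.parse({'-batchSize', '16', '-level', '2'})
    print(opt.batchSize, opt.level, opt.fineWidth, opt.fineHeight)  -- 16  2  512  384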
--------------------------------------------------------------------------------
/samples/00001_flow.flo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00001_flow.flo
--------------------------------------------------------------------------------
/samples/00001_img1.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00001_img1.ppm
--------------------------------------------------------------------------------
/samples/00001_img2.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00001_img2.ppm
--------------------------------------------------------------------------------
/samples/00002_flow.flo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00002_flow.flo
--------------------------------------------------------------------------------
/samples/00002_img1.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00002_img1.ppm
--------------------------------------------------------------------------------
/samples/00002_img2.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00002_img2.ppm
--------------------------------------------------------------------------------
/samples/00003_flow.flo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00003_flow.flo
--------------------------------------------------------------------------------
/samples/00003_img1.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00003_img1.ppm
--------------------------------------------------------------------------------
/samples/00003_img2.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/spynet/7c4a3f7d1a5879a50361fff3e1b8fa35ec397fcd/samples/00003_img2.ppm
--------------------------------------------------------------------------------
/spynet.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | require 'image'
9 | local TF = require 'transforms'
10 | require 'cutorch'
11 | require 'nn'
12 | require 'cunn'
13 | require 'cudnn'
14 | require 'nngraph'
15 | require 'stn'
16 | require 'spy'
17 | local flowX = require 'flowExtensions'
18 |
19 | local M = {}
20 |
21 | local eps = 1e-6
22 | local meanstd = {
23 | mean = { 0.485, 0.456, 0.406 },
24 | std = { 0.229, 0.224, 0.225 },
25 | }
26 | local pca = {
27 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 },
28 | eigvec = torch.Tensor{
29 | { -0.5675, 0.7192, 0.4009 },
30 | { -0.5808, -0.0045, -0.8140 },
31 | { -0.5836, -0.6948, 0.4203 },
32 | },
33 | }
34 |
35 | local mean = meanstd.mean
36 | local std = meanstd.std
37 | ------------------------------------------
38 | local function createWarpModel()
39 | local imgData = nn.Identity()()
40 | local floData = nn.Identity()()
41 |
42 | local imgOut = nn.Transpose({2,3},{3,4})(imgData)
43 | local floOut = nn.Transpose({2,3},{3,4})(floData)
44 |
45 | local warpImOut = nn.Transpose({3,4},{2,3})(nn.BilinearSamplerBHWD()({imgOut, floOut}))
46 | local model = nn.gModule({imgData, floData}, {warpImOut})
47 |
48 | return model
49 | end
50 |
51 | local down2 = nn.SpatialAveragePooling(2,2,2,2):cuda()
52 | local down3 = nn.SpatialAveragePooling(2,2,2,2):cuda()
53 | local down4 = nn.SpatialAveragePooling(2,2,2,2):cuda()
54 | local down5 = nn.SpatialAveragePooling(2,2,2,2):cuda()
55 | local down6 = nn.SpatialAveragePooling(2,2,2,2):cuda()
56 |
57 | local up2 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
58 | local up3 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
59 | local up4 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
60 | local up5 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
61 | local up6 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
62 |
63 | local warpmodel2 = createWarpModel():cuda()
64 | local warpmodel3 = createWarpModel():cuda()
65 | local warpmodel4 = createWarpModel():cuda()
66 | local warpmodel5 = createWarpModel():cuda()
67 | local warpmodel6 = createWarpModel():cuda()
68 |
69 | down2:evaluate()
70 | down3:evaluate()
71 | down4:evaluate()
72 | down5:evaluate()
73 | down6:evaluate()
74 |
75 | up2:evaluate()
76 | up3:evaluate()
77 | up4:evaluate()
78 | up5:evaluate()
79 | up6:evaluate()
80 |
81 | warpmodel2:evaluate()
82 | warpmodel3:evaluate()
83 | warpmodel4:evaluate()
84 | warpmodel5:evaluate()
85 | warpmodel6:evaluate()
86 |
87 | -------------------------------------------------
88 | local modelL1, modelL2, modelL3, modelL4, modelL5, modelL6
89 | local modelL1path, modelL2path, modelL3path, modelL4path, modelL5path, modelL6path
90 |
91 | local function loadImage(path)
92 | local input = image.load(path, 3, 'float')
93 | return input
94 | end
95 | M.loadImage = loadImage
96 |
97 | local function loadFlow(filename)
98 | local TAG_FLOAT = 202021.25
99 | local ff = torch.DiskFile(filename):binary()
100 | local tag = ff:readFloat()
101 | if tag ~= TAG_FLOAT then
102 | xerror('unable to read '..filename..
103 | ' perhaps bigendian error','readflo()')
104 | end
105 |
106 | local w = ff:readInt()
107 | local h = ff:readInt()
108 | local nbands = 2
109 | local tf = torch.FloatTensor(h, w, nbands)
110 | ff:readFloat(tf:storage())
111 | ff:close()
112 |
113 | local flow = tf:permute(3,1,2)
114 | return flow
115 | end
116 | M.loadFlow = loadFlow
117 |
118 |
119 | local function computeInitFlowL1(imagesL1)
120 | local h = imagesL1:size(3)
121 | local w = imagesL1:size(4)
122 | local batchSize = imagesL1:size(1)
123 |
124 | local _flowappend = torch.zeros(batchSize, 2, h, w):cuda()
125 | local images_in = torch.cat(imagesL1, _flowappend, 2)
126 |
127 | local flow_est = modelL1:forward(images_in)
128 | return flow_est
129 | end
130 | M.computeInitFlowL1 = computeInitFlowL1
131 |
132 | local function computeInitFlowL2(imagesL2)
133 | local imagesL1 = down2:forward(imagesL2:clone())
134 | local _flowappend = up2:forward(computeInitFlowL1(imagesL1))*2
135 | local _img2 = imagesL2[{{},{4,6},{},{}}]
136 | imagesL2[{{},{4,6},{},{}}]:copy(warpmodel2:forward({_img2, _flowappend}))
137 |
138 | local images_in = torch.cat(imagesL2, _flowappend, 2)
139 |
140 | local flow_est = modelL2:forward(images_in)
141 | return flow_est:add(_flowappend)
142 | end
143 | M.computeInitFlowL2 = computeInitFlowL2
144 |
145 | local function computeInitFlowL3(imagesL3)
146 | local imagesL2 = down3:forward(imagesL3:clone())
147 | local _flowappend = up3:forward(computeInitFlowL2(imagesL2))*2
148 | local _img2 = imagesL3[{{},{4,6},{},{}}]
149 | imagesL3[{{},{4,6},{},{}}]:copy(warpmodel3:forward({_img2, _flowappend}))
150 |
151 | local images_in = torch.cat(imagesL3, _flowappend, 2)
152 |
153 | local flow_est = modelL3:forward(images_in)
154 | return flow_est:add(_flowappend)
155 | end
156 | M.computeInitFlowL3 = computeInitFlowL3
157 |
158 | local function computeInitFlowL4(imagesL4)
159 | local imagesL3 = down4:forward(imagesL4)
160 | local _flowappend = up4:forward(computeInitFlowL3(imagesL3))*2
161 | local _img2 = imagesL4[{{},{4,6},{},{}}]
162 | imagesL4[{{},{4,6},{},{}}]:copy(warpmodel4:forward({_img2, _flowappend}))
163 |
164 | local images_in = torch.cat(imagesL4, _flowappend, 2)
165 |
166 | local flow_est = modelL4:forward(images_in)
167 | return flow_est:add(_flowappend)
168 | end
169 | M.computeInitFlowL4 = computeInitFlowL4
170 |
171 | local function computeInitFlowL5(imagesL5)
172 | local imagesL4 = down5:forward(imagesL5)
173 | local _flowappend = up5:forward(computeInitFlowL4(imagesL4))*2
174 |
175 | local _img2 = imagesL5[{{},{4,6},{},{}}]
176 | imagesL5[{{},{4,6},{},{}}]:copy(warpmodel5:forward({_img2, _flowappend}))
177 |
178 | local images_in = torch.cat(imagesL5, _flowappend, 2)
179 |
180 | local flow_est = modelL5:forward(images_in)
181 | return flow_est:add(_flowappend)
182 | end
183 | M.computeInitFlowL5 = computeInitFlowL5
184 |
185 | local function computeInitFlowL6(imagesL6)
186 | local imagesL5 = down6:forward(imagesL6)
187 | local _flowappend = up6:forward(computeInitFlowL5(imagesL5))*2
188 |
189 | local _img2 = imagesL6[{{},{4,6},{},{}}]
190 | imagesL6[{{},{4,6},{},{}}]:copy(warpmodel6:forward({_img2, _flowappend}))
191 |
192 | local images_in = torch.cat(imagesL6, _flowappend, 2)
193 |
194 | local flow_est = modelL6:forward(images_in)
195 | return flow_est:add(_flowappend)
196 | end
197 | M.computeInitFlowL6 = computeInitFlowL6
198 |
199 |
200 | local function setup(width, height, opt)
201 | opt = opt or "sintelFinal"
202 | local len = math.max(width, height)
203 | local computeFlow
204 | local level
205 |
206 | if len <= 32 then
207 | computeFlow = computeInitFlowL1
208 | level = 1
209 | elseif len <= 64 then
210 | computeFlow = computeInitFlowL2
211 | level = 2
212 | elseif len <= 128 then
213 | computeFlow = computeInitFlowL3
214 | level = 3
215 | elseif len <= 256 then
216 | computeFlow = computeInitFlowL4
217 | level = 4
218 | elseif len <= 512 then
219 | computeFlow = computeInitFlowL5
220 | level = 5
221 | elseif len <= 1472 then
222 | computeFlow = computeInitFlowL6
223 | level = 6
224 | else
225 | error("Only image sizes <= 1472 are supported. The next release will have full support.")
226 | end
227 |
228 | if opt=="sintelFinal" then
229 | modelL1path = paths.concat('models', 'modelL1_F.t7')
230 | modelL2path = paths.concat('models', 'modelL2_F.t7')
231 | modelL3path = paths.concat('models', 'modelL3_F.t7')
232 | modelL4path = paths.concat('models', 'modelL4_F.t7')
233 | modelL5path = paths.concat('models', 'modelL5_F.t7')
234 | modelL6path = paths.concat('models', 'modelL6_F.t7')
235 | end
236 |
237 | if opt=="sintelClean" then
238 | modelL1path = paths.concat('models', 'modelL1_C.t7')
239 | modelL2path = paths.concat('models', 'modelL2_C.t7')
240 | modelL3path = paths.concat('models', 'modelL3_C.t7')
241 | modelL4path = paths.concat('models', 'modelL4_C.t7')
242 | modelL5path = paths.concat('models', 'modelL5_C.t7')
243 | modelL6path = paths.concat('models', 'modelL6_C.t7')
244 | end
245 |
246 | if opt=="chairsClean" then
247 | modelL1path = paths.concat('models', 'modelL1_4.t7')
248 | modelL2path = paths.concat('models', 'modelL2_4.t7')
249 | modelL3path = paths.concat('models', 'modelL3_4.t7')
250 | modelL4path = paths.concat('models', 'modelL4_4.t7')
251 | modelL5path = paths.concat('models', 'modelL5_4.t7')
252 | modelL6path = paths.concat('models', 'modelL5_4.t7') -- no L6 chairs model is shipped; fall back to L5
253 | end
254 |
255 | if opt=="chairsFinal" then
256 | modelL1path = paths.concat('models', 'modelL1_3.t7')
257 | modelL2path = paths.concat('models', 'modelL2_3.t7')
258 | modelL3path = paths.concat('models', 'modelL3_3.t7')
259 | modelL4path = paths.concat('models', 'modelL4_3.t7')
260 | modelL5path = paths.concat('models', 'modelL5_3.t7')
261 | modelL6path = paths.concat('models', 'modelL5_3.t7') -- no L6 chairs model is shipped; fall back to L5
262 | end
263 |
264 | if opt=="kittiFinal" then
265 | modelL1path = paths.concat('models', 'modelL1_K.t7')
266 | modelL2path = paths.concat('models', 'modelL2_K.t7')
267 | modelL3path = paths.concat('models', 'modelL3_K.t7')
268 | modelL4path = paths.concat('models', 'modelL4_K.t7')
269 | modelL5path = paths.concat('models', 'modelL5_K.t7')
270 | modelL6path = paths.concat('models', 'modelL6_K.t7')
271 | end
272 |
273 |
274 | if level>0 then
275 | modelL1 = torch.load(modelL1path)
276 | if torch.type(modelL1) == 'nn.DataParallelTable' then
277 | modelL1 = modelL1:get(1)
278 | end
279 | modelL1:evaluate()
280 | end
281 |
282 | if level>1 then
283 | modelL2 = torch.load(modelL2path)
284 | if torch.type(modelL2) == 'nn.DataParallelTable' then
285 | modelL2 = modelL2:get(1)
286 | end
287 | modelL2:evaluate()
288 | end
289 |
290 | if level>2 then
291 | modelL3 = torch.load(modelL3path)
292 | if torch.type(modelL3) == 'nn.DataParallelTable' then
293 | modelL3 = modelL3:get(1)
294 | end
295 | modelL3:evaluate()
296 | end
297 |
298 | if level>3 then
299 | modelL4 = torch.load(modelL4path)
300 | if torch.type(modelL4) == 'nn.DataParallelTable' then
301 | modelL4 = modelL4:get(1)
302 | end
303 | modelL4:evaluate()
304 | end
305 |
306 | if level>4 then
307 | modelL5 = torch.load(modelL5path)
308 | if torch.type(modelL5) == 'nn.DataParallelTable' then
309 | modelL5 = modelL5:get(1)
310 | end
311 | modelL5:evaluate()
312 | end
313 |
314 | if level>5 then
315 | modelL6 = torch.load(modelL6path)
316 | if torch.type(modelL6) == 'nn.DataParallelTable' then
317 | modelL6 = modelL6:get(1)
318 | end
319 | modelL6:evaluate()
320 | end
321 |
322 | return computeFlow
323 | end
324 | M.setup = setup
325 |
326 | local function DeAdjustFlow(flow, h, w)
327 | local sc_h = h/flow:size(2)
328 | local sc_w = w/flow:size(3)
329 | flow = image.scale(flow, w, h, 'simple')
330 | flow[2] = flow[2]*sc_h
331 | flow[1] = flow[1]*sc_w
332 |
333 | return flow
334 | end
335 | M.DeAdjustFlow = DeAdjustFlow
336 |
337 | local function normalize(imgs)
338 | return TF.ColorNormalize(meanstd)(imgs)
339 | end
340 | M.normalize = normalize
341 |
342 | local easyComputeFlow = function(im1, im2)
343 | local imgs = torch.cat(im1, im2, 1)
344 | imgs = TF.ColorNormalize(meanstd)(imgs)
345 |
346 | local width = imgs:size(3)
347 | local height = imgs:size(2)
348 |
349 | local fineWidth, fineHeight
350 |
351 | if width%32 == 0 then
352 | fineWidth = width
353 | else
354 | fineWidth = width + 32 - math.fmod(width, 32)
355 | end
356 |
357 | if height%32 == 0 then
358 | fineHeight = height
359 | else
360 | fineHeight = height + 32 - math.fmod(height, 32)
361 | end
362 |
363 | imgs = image.scale(imgs, fineWidth, fineHeight)
364 |
365 | local len = math.max(fineWidth, fineHeight)
366 | local computeFlow
367 |
368 | if len <= 32 then
369 | computeFlow = computeInitFlowL1
370 | elseif len <= 64 then
371 | computeFlow = computeInitFlowL2
372 | elseif len <= 128 then
373 | computeFlow = computeInitFlowL3
374 | elseif len <= 256 then
375 | computeFlow = computeInitFlowL4
376 | elseif len <= 512 then
377 | computeFlow = computeInitFlowL5
378 | else
379 | computeFlow = computeInitFlowL6
380 | end
381 |
382 | imgs = imgs:resize(1,6,fineHeight,fineWidth):cuda()
383 | local flow_est = computeFlow(imgs)
384 |
385 | flow_est = flow_est:squeeze():float()
386 | flow_est = DeAdjustFlow(flow_est, height, width)
387 |
388 | return flow_est
389 |
390 | end
391 |
392 | local function easy_setup(opt)
393 | opt = opt or 'sintelFinal'
394 |
395 | if opt=="sintelFinal" then
396 | modelL1path = paths.concat('models', 'modelL1_F.t7')
397 | modelL2path = paths.concat('models', 'modelL2_F.t7')
398 | modelL3path = paths.concat('models', 'modelL3_F.t7')
399 | modelL4path = paths.concat('models', 'modelL4_F.t7')
400 | modelL5path = paths.concat('models', 'modelL5_F.t7')
401 | modelL6path = paths.concat('models', 'modelL6_F.t7')
402 | end
403 |
404 | if opt=="sintelClean" then
405 | modelL1path = paths.concat('models', 'modelL1_C.t7')
406 | modelL2path = paths.concat('models', 'modelL2_C.t7')
407 | modelL3path = paths.concat('models', 'modelL3_C.t7')
408 | modelL4path = paths.concat('models', 'modelL4_C.t7')
409 | modelL5path = paths.concat('models', 'modelL5_C.t7')
410 | modelL6path = paths.concat('models', 'modelL6_C.t7')
411 | end
412 |
413 | if opt=="chairsClean" then
414 | modelL1path = paths.concat('models', 'modelL1_4.t7')
415 | modelL2path = paths.concat('models', 'modelL2_4.t7')
416 | modelL3path = paths.concat('models', 'modelL3_4.t7')
417 | modelL4path = paths.concat('models', 'modelL4_4.t7')
418 | modelL5path = paths.concat('models', 'modelL5_4.t7')
419 | modelL6path = paths.concat('models', 'modelL5_4.t7') -- no L6 chairs model is shipped; fall back to L5
420 | end
421 |
422 | if opt=="chairsFinal" then
423 | modelL1path = paths.concat('models', 'modelL1_3.t7')
424 | modelL2path = paths.concat('models', 'modelL2_3.t7')
425 | modelL3path = paths.concat('models', 'modelL3_3.t7')
426 | modelL4path = paths.concat('models', 'modelL4_3.t7')
427 | modelL5path = paths.concat('models', 'modelL5_3.t7')
428 | modelL6path = paths.concat('models', 'modelL5_3.t7') -- no L6 chairs model is shipped; fall back to L5
429 | end
430 |
431 | if opt=="kittiFinal" then
432 | modelL1path = paths.concat('models', 'modelL1_K.t7')
433 | modelL2path = paths.concat('models', 'modelL2_K.t7')
434 | modelL3path = paths.concat('models', 'modelL3_K.t7')
435 | modelL4path = paths.concat('models', 'modelL4_K.t7')
436 | modelL5path = paths.concat('models', 'modelL5_K.t7')
437 | modelL6path = paths.concat('models', 'modelL6_K.t7')
438 | end
439 |
440 | modelL1 = torch.load(modelL1path)
441 | if torch.type(modelL1) == 'nn.DataParallelTable' then
442 | modelL1 = modelL1:get(1)
443 | end
444 | modelL1:evaluate()
445 |
446 | modelL2 = torch.load(modelL2path)
447 | if torch.type(modelL2) == 'nn.DataParallelTable' then
448 | modelL2 = modelL2:get(1)
449 | end
450 | modelL2:evaluate()
451 |
452 | modelL3 = torch.load(modelL3path)
453 | if torch.type(modelL3) == 'nn.DataParallelTable' then
454 | modelL3 = modelL3:get(1)
455 | end
456 | modelL3:evaluate()
457 |
458 | modelL4 = torch.load(modelL4path)
459 | if torch.type(modelL4) == 'nn.DataParallelTable' then
460 | modelL4 = modelL4:get(1)
461 | end
462 | modelL4:evaluate()
463 |
464 | modelL5 = torch.load(modelL5path)
465 | if torch.type(modelL5) == 'nn.DataParallelTable' then
466 | modelL5 = modelL5:get(1)
467 | end
468 | modelL5:evaluate()
469 |
470 | modelL6 = torch.load(modelL6path)
471 | if torch.type(modelL6) == 'nn.DataParallelTable' then
472 | modelL6 = modelL6:get(1)
473 | end
474 | modelL6:evaluate()
475 | return easyComputeFlow
476 | end
477 | M.easy_setup = easy_setup
478 |
479 |
480 |
481 | return M
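Putting the pieces together, an end-to-end sketch on one of the bundled sample pairs, run from the repository root (assumes a CUDA setup with the stnbhwd/spybhwd extras installed; easy_setup eagerly loads all six pyramid levels):

    local spynet = require 'spynet'
    local flowX = require 'flowExtensions'
    local computeFlow = spynet.easy_setup('sintelFinal')    -- loads models/modelL*_F.t7
    local im1 = spynet.loadImage('samples/00001_img1.ppm')  -- 3 x H x W float, CPU
    local im2 = spynet.loadImage('samples/00001_img2.ppm')
    local flow = computeFlow(im1, im2)                      -- 2 x H x W FloatTensor
    flowX.writeFLO('out_00001.flo', flow)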
--------------------------------------------------------------------------------
/test.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 | --
8 | -- Copyright (c) 2014, Facebook, Inc.
9 | -- All rights reserved.
10 | --
11 | -- This source code is licensed under the BSD-style license found in the
12 | -- LICENSE file in the root directory of this source tree. An additional grant
13 | -- of patent rights can be found in the PATENTS file in the same directory.
14 | --
15 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
16 |
17 | local batchNumber
18 | local error_center, loss -- error_center is vestigial here (never accumulated below); loss carries the per-batch EPE
19 | local timer = torch.Timer()
20 |
21 | function test()
22 | print('==> doing epoch on validation data:')
23 | print("==> online epoch # " .. epoch)
24 |
25 | batchNumber = 0
26 | cutorch.synchronize()
27 | timer:reset()
28 |
29 | -- set the dropouts to evaluate mode
30 | model:evaluate()
31 |
32 | error_center = 0
33 | loss = 0
34 | for i=1,nTest/opt.batchSize do -- nTest is set in 1_data.lua
35 | local indexStart = (i-1) * opt.batchSize + 1
36 | local indexEnd = (indexStart + opt.batchSize - 1)
37 | donkeys:addjob(
38 | -- work to be done by donkey thread
39 | function()
40 | local inputs, labels = testLoader:get(indexStart, indexEnd)
41 | return inputs, labels
42 | end,
43 | -- callback that is run in the main thread once the work is done
44 | testBatch
45 | )
46 | end
47 |
48 | donkeys:synchronize()
49 | cutorch.synchronize()
50 |
51 | error_center = error_center * 100 / nTest
52 | loss = loss / (nTest/opt.batchSize) -- because loss is calculated per batch
53 | testLogger:add{
54 | ['% top1 accuracy (test set) (center crop)'] = error_center,
55 | ['avg loss (test set)'] = loss
56 | }
57 | print(string.format('Epoch: [%d][TESTING SUMMARY] Total Time(s): %.2f \t'
58 | .. 'average loss (per batch): %.2f \t '
59 | .. 'accuracy [Center](%%):\t top-1 %.2f\t ',
60 | epoch, timer:time().real, loss, error_center))
61 |
62 | print('\n')
63 |
64 |
65 | end -- of test()
66 | -----------------------------------------------------------------------------
67 | local inputs = torch.CudaTensor()
68 | local labels = torch.CudaTensor()
69 |
70 | function testBatch(inputsCPU, labelsCPU)
71 | batchNumber = batchNumber + opt.batchSize
72 |
73 | inputs:resize(inputsCPU:size()):copy(inputsCPU)
74 | labels:resize(labelsCPU:size()):copy(labelsCPU)
75 |
76 | local outputs = model:forward(inputs)
77 | local err = criterion:forward(outputs, labels)
78 | cutorch.synchronize()
79 | local pred = outputs:float()
80 |
81 | loss = loss + err
82 |
83 | print(('Epoch: Testing [%d][%d/%d]'):format(epoch, batchNumber, nTest))
84 | end
85 |
--------------------------------------------------------------------------------
/timing_benchmark.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | require 'image'
9 | require 'cutorch'
10 |
11 | local cmd = torch.CmdLine()
12 | cmd:option('-data', '../FlyingChairs/data', 'Flying Chairs data directory')
13 | opt = cmd:parse(arg or {})
14 |
15 | opt.showFlow = 0
16 | opt.fineHeight = 384
17 | opt.fineWidth = 512
18 | opt.preprocess = 0
19 | opt.level = 5
20 | opt.polluteFlow = 0
21 | opt.augment = 0
22 | opt.warp = 1
23 | opt.batchSize = 1
24 | local donkey = require('timing_util')
25 |
26 | local train_samples, validation_samples = donkey.getTrainValidationSplits('train_val_split.txt')
27 | -- accumulators: one endpoint error and one timing per validation sample
28 | local errors = torch.zeros(validation_samples:size()[1])
29 | timings = torch.zeros(validation_samples:size()[1])
30 | local loss = 0
31 | local flowCPU = cutorch.createCudaHostTensor(640, 2,opt.fineHeight,opt.fineWidth):uniform()
32 |
33 | for i=1,validation_samples:size()[1] do
34 | collectgarbage()
35 |
36 | local id = validation_samples[i][1]
37 | local imgs, flow = donkey.testHook(id)
38 |
39 | timer = torch.Timer()
40 | imgs = imgs:resize(1,6,opt.fineHeight, opt.fineWidth):cuda()
41 | flow_est = donkey.computeInitFlowL5(imgs):squeeze()
42 | flowCPU[i]:copyAsync(flow_est)
43 | cutorch.streamSynchronize(cutorch.getStream())
44 | local time_elapsed = timer:time().real
45 |
46 | print('Time Elapsed: '..time_elapsed)
47 |
48 | timings[i] = time_elapsed
49 | end
50 | cutorch.streamSynchronize(cutorch.getStream())
51 |
52 |
53 | for i=1,validation_samples:size()[1] do
54 | local id = validation_samples[i][1]
55 | local raw_im1, raw_im2, raw_flow = donkey.getRawData(id)
56 |
57 | local _err = (raw_flow - flowCPU[i]):pow(2)
58 | local err = torch.sum(_err, 1):sqrt()
59 | loss = loss + err:float()
60 | errors[i] = err:mean()
61 |
62 | print(i, errors[i])
63 | end
64 | loss = torch.div(loss, validation_samples:size()[1])
65 | print('Average EPE = '..loss:sum()/(opt.fineWidth*opt.fineHeight))
66 | print('Mean Timing: ' ..timings:mean())
67 | print('Median Timing: ' ..timings:median()[1])
68 |
--------------------------------------------------------------------------------
/timing_util.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | require 'image'
9 | local TF = require 'transforms'
10 | require 'cutorch'
11 | require 'nn'
12 | require 'cunn'
13 | require 'cudnn'
14 | require 'nngraph'
15 | require 'stn'
16 | require 'spy'
17 | local flowX = require 'flowExtensions'
18 |
19 | local M = {}
20 |
21 | local eps = 1e-6
22 | local meanstd = {
23 | mean = { 0.485, 0.456, 0.406 },
24 | std = { 0.229, 0.224, 0.225 },
25 | }
26 | local pca = {
27 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 },
28 | eigvec = torch.Tensor{
29 | { -0.5675, 0.7192, 0.4009 },
30 | { -0.5808, -0.0045, -0.8140 },
31 | { -0.5836, -0.6948, 0.4203 },
32 | },
33 | }
34 |
35 | local mean = meanstd.mean
36 | local std = meanstd.std
37 | ------------------------------------------
38 | local function createWarpModel()
39 | local imgData = nn.Identity()()
40 | local floData = nn.Identity()()
41 |
42 | local imgOut = nn.Transpose({2,3},{3,4})(imgData)
43 | local floOut = nn.Transpose({2,3},{3,4})(floData)
44 |
45 | local warpImOut = nn.Transpose({3,4},{2,3})(nn.BilinearSamplerBHWD()({imgOut, floOut}))
46 | local model = nn.gModule({imgData, floData}, {warpImOut})
47 |
48 | return model
49 | end
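-- nn.BilinearSamplerBHWD (this repo's modified stnbhwd version, which takes a
-- per-pixel flow field) expects batch-height-width-depth layout, so the
-- Transpose pairs above convert from Torch's usual BCHW and back.
-- A usage sketch (shapes are illustrative, not taken from this file):
--   local warp = createWarpModel():cuda()
--   local img = torch.CudaTensor(1, 3, 384, 512):uniform() -- frame 2, BCHW
--   local flo = torch.CudaTensor(1, 2, 384, 512):zero()    -- flow, BCHW
--   local warped = warp:forward({img, flo})  -- zero flow returns img unchanged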
50 |
51 | local down2 = nn.SpatialAveragePooling(2,2,2,2):cuda()
52 | local down3 = nn.SpatialAveragePooling(2,2,2,2):cuda()
53 | local down4 = nn.SpatialAveragePooling(2,2,2,2):cuda()
54 | local down5 = nn.SpatialAveragePooling(2,2,2,2):cuda()
55 |
56 | local up2 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
57 | local up3 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
58 | local up4 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
59 | local up5 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda()
60 |
61 | local warpmodel2 = createWarpModel():cuda()
62 | local warpmodel3 = createWarpModel():cuda()
63 | local warpmodel4 = createWarpModel():cuda()
64 | local warpmodel5 = createWarpModel():cuda()
65 |
66 | down2:evaluate()
67 | down3:evaluate()
68 | down4:evaluate()
69 | down5:evaluate()
70 |
71 | up2:evaluate()
72 | up3:evaluate()
73 | up4:evaluate()
74 | up5:evaluate()
75 |
76 | warpmodel2:evaluate()
77 | warpmodel3:evaluate()
78 | warpmodel4:evaluate()
79 | warpmodel5:evaluate()
80 |
81 | -------------------------------------------------
82 | local modelL1, modelL2, modelL3, modelL4, modelL5
83 | local modelL1path, modelL2path, modelL3path, modelL4path, modelL5path
84 |
85 | modelL1path = paths.concat('models', 'modelL1_4.t7')
86 | modelL2path = paths.concat('models', 'modelL2_4.t7')
87 | modelL3path = paths.concat('models', 'modelL3_4.t7')
88 | modelL4path = paths.concat('models', 'modelL4_4.t7')
89 | modelL5path = paths.concat('models', 'modelL5_4.t7')
90 |
91 | modelL1 = torch.load(modelL1path)
92 | if torch.type(modelL1) == 'nn.DataParallelTable' then
93 | modelL1 = modelL1:get(1)
94 | end
95 | modelL1:evaluate()
96 |
97 | modelL2 = torch.load(modelL2path)
98 | if torch.type(modelL2) == 'nn.DataParallelTable' then
99 | modelL2 = modelL2:get(1)
100 | end
101 | modelL2:evaluate()
102 |
103 | modelL3 = torch.load(modelL3path)
104 | if torch.type(modelL3) == 'nn.DataParallelTable' then
105 | modelL3 = modelL3:get(1)
106 | end
107 | modelL3:evaluate()
108 |
109 | modelL4 = torch.load(modelL4path)
110 | if torch.type(modelL4) == 'nn.DataParallelTable' then
111 | modelL4 = modelL4:get(1)
112 | end
113 | modelL4:evaluate()
114 |
115 | modelL5 = torch.load(modelL5path)
116 | if torch.type(modelL5) == 'nn.DataParallelTable' then
117 | modelL5 = modelL5:get(1)
118 | end
119 | modelL5:evaluate()
120 |
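-- The five loading blocks above repeat one pattern; an equivalent helper,
-- shown only as a sketch (loadLevel is not defined in this repo):
--   local function loadLevel(name)
--     local m = torch.load(paths.concat('models', name))
--     if torch.type(m) == 'nn.DataParallelTable' then m = m:get(1) end
--     m:evaluate()
--     return m
--   end
--   -- modelL1 = loadLevel('modelL1_4.t7'), and so on up to modelL5.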
121 | local function getTrainValidationSplits(path)
122 | local numSamples = sys.fexecute("ls " .. opt.data .. " | wc -l")/3 -- three files per sample: img1, img2, flo
123 | local ff = torch.DiskFile(path, 'r')
124 | local trainValidationSamples = torch.IntTensor(numSamples)
125 | ff:readInt(trainValidationSamples:storage())
126 | ff:close()
127 |
128 | local train_samples = trainValidationSamples:eq(1):nonzero()
129 | local validation_samples = trainValidationSamples:eq(2):nonzero()
130 |
131 | return train_samples, validation_samples
133 | end
134 | M.getTrainValidationSplits = getTrainValidationSplits
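-- train_val_split.txt lists one integer per sample: 1 marks a training
-- sample, 2 a validation sample; eq() + nonzero() turn those flags into
-- index tensors.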
135 |
136 | local function loadImage(path)
137 | local input = image.load(path, 3, 'float')
138 | return input
139 | end
140 | M.loadImage = loadImage
141 |
142 | local function loadFlow(filename)
143 | local TAG_FLOAT = 202021.25 -- Middlebury .flo magic number
144 | local ff = torch.DiskFile(filename):binary()
145 | local tag = ff:readFloat()
146 | if tag ~= TAG_FLOAT then
147 | xerror('unable to read '..filename..
148 | ' perhaps bigendian error','readflo()')
149 | end
150 |
151 | local w = ff:readInt()
152 | local h = ff:readInt()
153 | local nbands = 2
154 | local tf = torch.FloatTensor(h, w, nbands)
155 | ff:readFloat(tf:storage())
156 | ff:close()
157 |
158 | local flow = tf:permute(3,1,2)
159 | return flow
160 | end
161 | M.loadFlow = loadFlow
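-- The Middlebury .flo layout read above: a float32 magic (202021.25), int32
-- width, int32 height, then h*w interleaved (u, v) float32 pairs in row-major
-- order. A matching writer, included only as a sketch (writeFlow is not used
-- elsewhere in this repo):
local function writeFlow(filename, flow) -- flow: 2 x h x w FloatTensor
  local h, w = flow:size(2), flow:size(3)
  local ff = torch.DiskFile(filename, 'w'):binary()
  ff:writeFloat(202021.25)
  ff:writeInt(w)
  ff:writeInt(h)
  ff:writeFloat(flow:permute(2,3,1):contiguous():float():storage())
  ff:close()
end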
162 |
163 |
164 | local function computeInitFlowL1(imagesL1)
165 | local h = imagesL1:size(3)
166 | local w = imagesL1:size(4)
167 |
168 | local _flowappend = torch.zeros(opt.batchSize, 2, h, w):cuda()
169 | local images_in = torch.cat(imagesL1, _flowappend, 2)
170 |
171 | local flow_est = modelL1:forward(images_in)
172 | return flow_est
173 | end
174 | M.computeInitFlowL1 = computeInitFlowL1
175 |
176 | local function computeInitFlowL2(imagesL2)
177 | local imagesL1 = down2:forward(imagesL2:clone())
178 | local _flowappend = up2:forward(computeInitFlowL1(imagesL1))*2
179 | local _img2 = imagesL2[{{},{4,6},{},{}}]
180 | imagesL2[{{},{4,6},{},{}}]:copy(warpmodel2:forward({_img2, _flowappend}))
181 |
182 | local images_in = torch.cat(imagesL2, _flowappend, 2)
183 |
184 | local flow_est = modelL2:forward(images_in)
185 | return flow_est:add(_flowappend)
186 | end
187 | M.computeInitFlowL2 = computeInitFlowL2
188 |
189 | local function computeInitFlowL3(imagesL3)
190 | local imagesL2 = down3:forward(imagesL3:clone())
191 | local _flowappend = up3:forward(computeInitFlowL2(imagesL2))*2
192 | local _img2 = imagesL3[{{},{4,6},{},{}}]
193 | imagesL3[{{},{4,6},{},{}}]:copy(warpmodel3:forward({_img2, _flowappend}))
194 |
195 | local images_in = torch.cat(imagesL3, _flowappend, 2)
196 |
197 | local flow_est = modelL3:forward(images_in)
198 | return flow_est:add(_flowappend)
199 | end
200 | M.computeInitFlowL3 = computeInitFlowL3
201 |
202 | local function computeInitFlowL4(imagesL4)
203 | local imagesL3 = down4:forward(imagesL4)
204 | local _flowappend = up4:forward(computeInitFlowL3(imagesL3))*2
205 | local _img2 = imagesL4[{{},{4,6},{},{}}]
206 | imagesL4[{{},{4,6},{},{}}]:copy(warpmodel4:forward({_img2, _flowappend}))
207 |
208 | local images_in = torch.cat(imagesL4, _flowappend, 2)
209 |
210 | local flow_est = modelL4:forward(images_in)
211 | return flow_est:add(_flowappend)
212 | end
213 | M.computeInitFlowL4 = computeInitFlowL4
214 |
215 | local function computeInitFlowL5(imagesL5)
216 | local imagesL4 = down5:forward(imagesL5)
217 | local _flowappend = up5:forward(computeInitFlowL4(imagesL4))*2
218 |
219 | local _img2 = imagesL5[{{},{4,6},{},{}}]
220 | imagesL5[{{},{4,6},{},{}}]:copy(warpmodel5:forward({_img2, _flowappend}))
221 |
222 | local images_in = torch.cat(imagesL5, _flowappend, 2)
223 |
224 | local flow_est = modelL5:forward(images_in)
225 | return flow_est:add(_flowappend)
226 | end
227 | M.computeInitFlowL5 = computeInitFlowL5
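-- Levels 2-5 above share one coarse-to-fine step: downsample the image pair,
-- estimate flow at the coarser level, upsample it (doubling the vectors to
-- match the finer resolution), warp frame 2 towards frame 1, and let the
-- level's network predict a residual that is added back onto the upsampled
-- flow. A generic sketch of that step (refineLevel itself is not part of
-- this module):
local function refineLevel(images, down, up, warpmodel, model, coarseFn)
  local flowInit = up:forward(coarseFn(down:forward(images:clone()))) * 2
  local img2 = images[{{},{4,6},{},{}}]
  images[{{},{4,6},{},{}}]:copy(warpmodel:forward({img2, flowInit}))
  local flowRes = model:forward(torch.cat(images, flowInit, 2))
  return flowRes:add(flowInit)
end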
228 |
229 | local function getRawData(id)
230 | local path1 = paths.concat(opt.data, (string.format("%05i", id) .."_img1.ppm"))
231 | local path2 = paths.concat(opt.data, (string.format("%05i", id) .."_img2.ppm"))
232 |
233 | local img1 = loadImage(path1)
234 | local img2 = loadImage(path2)
235 |
236 | local pathF = paths.concat(opt.data, (string.format("%05i", id) .."_flow.flo"))
237 | local flow = loadFlow(pathF)
238 |
239 | return img1, img2, flow
240 | end
241 | M.getRawData = getRawData
242 |
243 | local testHook = function(id)
244 | local path1 = paths.concat(opt.data, (string.format("%05i", id) .."_img1.ppm"))
245 | local path2 = paths.concat(opt.data, (string.format("%05i", id) .."_img2.ppm"))
246 |
247 | local img1 = loadImage(path1)
248 | local img2 = loadImage(path2)
249 | local images = torch.cat(img1, img2, 1)
250 |
251 | local pathF = paths.concat(opt.data, (string.format("%05i", id) .."_flow.flo"))
252 | local flow = loadFlow(pathF)
253 |
254 | images = TF.ColorNormalize(meanstd)(images)
255 | return images, flow
256 | end
257 | M.testHook = testHook
258 |
259 | return M
260 |
--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 | --
8 | -- Copyright (c) 2014, Facebook, Inc.
9 | -- All rights reserved.
10 | --
11 | -- This source code is licensed under the BSD-style license found in the
12 | -- LICENSE file in the root directory of this source tree. An additional grant
13 | -- of patent rights can be found in the PATENTS file in the same directory.
14 | --
15 | require 'optim'
16 |
17 | --[[
18 | 1. Setup SGD optimization state and learning rate schedule
19 | 2. Create loggers.
20 | 3. train - this function handles the high-level training loop,
21 | i.e. load data, train model, save model and state to disk
22 | 4. trainBatch - Used by train() to train a single batch after the data is loaded.
23 | ]]--
24 |
25 | -- Setup a reused optimization state (for sgd). If needed, reload it from disk
26 | local optimState = {
27 | learningRate = opt.LR,
28 | learningRateDecay = 0.0,
29 | momentum = opt.momentum,
30 | dampening = 0.0,
31 | weightDecay = opt.weightDecay
32 | }
33 |
34 | if opt.optimState ~= 'none' then
35 | assert(paths.filep(opt.optimState), 'File not found: ' .. opt.optimState)
36 | print('Loading optimState from file: ' .. opt.optimState)
37 | optimState = torch.load(opt.optimState)
38 | end
39 |
40 | -- Learning rate annealing schedule. We will build a new optimizer for
41 | -- each epoch.
42 | --
43 | -- By default we follow the regime table below. If the learningRate
44 | -- command-line parameter has been specified, though, we trust the user
45 | -- is doing something manual, and use those exact settings for all
46 | -- optimization.
47 | --
48 | -- Return values:
49 | -- diff to apply to optimState,
50 | -- true IFF this is the first epoch of a new regime
51 | local function paramsForEpoch(epoch)
52 | if opt.LR ~= 0.0 then -- if manually specified
53 | return { }
54 | end
55 | local regimes = {
56 | -- start, end, LR, WD,
57 | { 1, 10, 5e-3, 0 },
58 | { 11, 80, 1e-4, 0 },
59 | { 81, 120, 1e-4, 0 },
60 | { 121, 160, 1e-4, 0 },
61 | { 161, 1e8, 5e-5, 0 },
62 | }
63 |
64 | for _, row in ipairs(regimes) do
65 | if epoch >= row[1] and epoch <= row[2] then
66 | return { learningRate=row[3], weightDecay=row[4] }, epoch == row[1]
67 | end
68 | end
69 | end
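-- e.g. paramsForEpoch(11) returns { learningRate = 1e-4, weightDecay = 0 }, true
-- (epoch 11 opens a new regime), while paramsForEpoch(12) returns the same
-- parameters with false, so train() keeps the running optimState.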
70 |
71 | -- 2. Create loggers.
72 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
73 | local batchNumber
74 | local top1_epoch, loss_epoch
75 |
76 | -- 3. train - this function handles the high-level training loop,
77 | -- i.e. load data, train model, save model and state to disk
78 | function train()
79 | print('==> doing epoch on training data:')
80 | print("==> online epoch # " .. epoch)
81 |
82 | local params, newRegime = paramsForEpoch(epoch)
83 | if newRegime then
84 | optimState = {
85 | learningRate = params.learningRate,
86 | learningRateDecay = 0.0,
87 | momentum = opt.momentum,
88 | dampening = 0.0,
89 | weightDecay = params.weightDecay
90 | }
91 | end
92 | batchNumber = 0
93 | cutorch.synchronize()
94 |
95 | -- set the dropouts to training mode
96 | model:training()
97 |
98 | local tm = torch.Timer()
99 | top1_epoch = 0
100 | loss_epoch = 0
101 | for i=1,opt.epochSize do
102 | -- queue jobs to data-workers
103 | donkeys:addjob(
104 | -- the job callback (runs in data-worker thread)
105 | function()
106 | local inputs, labels = trainLoader:sample(opt.batchSize)
107 | return inputs, labels
108 | end,
109 | -- the end callback (runs in the main thread)
110 | trainBatch
111 | )
112 | end
113 |
114 | donkeys:synchronize()
115 | cutorch.synchronize()
116 |
117 | top1_epoch = top1_epoch * 100 / (opt.batchSize * opt.epochSize) -- trainBatch never accumulates top-1, so this stays 0 for flow training
118 | loss_epoch = loss_epoch / opt.epochSize
119 |
120 | trainLogger:add{
121 | ['% top1 accuracy (train set)'] = top1_epoch,
122 | ['avg loss (train set)'] = loss_epoch
123 | }
124 | print(string.format('Epoch: [%d][TRAINING SUMMARY] Total Time(s): %.2f\t'
125 | .. 'average loss (per batch): %.2f \t '
126 | .. 'accuracy(%%):\t top-1 %.2f\t',
127 | epoch, tm:time().real, loss_epoch, top1_epoch))
128 | print('\n')
129 |
130 | -- save model
131 | collectgarbage()
132 |
133 | -- clear the intermediate states in the model before saving to disk
134 | -- this saves lots of disk space
135 | model:clearState()
136 | saveDataParallel(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model) -- defined in util.lua
137 | torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
138 | end -- of train()
139 | -------------------------------------------------------------------------------------------
140 | -- GPU inputs (preallocate)
141 | local inputs = torch.CudaTensor()
142 | local labels = torch.CudaTensor()
143 |
144 | local timer = torch.Timer()
145 | local dataTimer = torch.Timer()
146 |
147 | local parameters, gradParameters = model:getParameters()
148 |
149 | -- 4. trainBatch - Used by train() to train a single batch after the data is loaded.
150 | function trainBatch(inputsCPU, labelsCPU)
151 | cutorch.synchronize()
152 | collectgarbage()
153 | local dataLoadingTime = dataTimer:time().real
154 | timer:reset()
155 |
156 | -- transfer over to GPU
157 | inputs:resize(inputsCPU:size()):copy(inputsCPU)
158 | labels:resize(labelsCPU:size()):copy(labelsCPU)
159 |
160 | local err, outputs
161 | feval = function(x)
162 | model:zeroGradParameters()
163 | outputs = model:forward(inputs)
164 | err = criterion:forward(outputs, labels)
165 | local gradOutputs = criterion:backward(outputs, labels)
166 | model:backward(inputs, gradOutputs)
167 | return err, gradParameters
168 | end
169 |
170 | if opt.optimizer == 'adam' then
171 | optim.adam(feval, parameters, optimState)
172 | elseif opt.optimizer == 'sgd' then
173 | optim.sgd(feval, parameters, optimState)
174 | else
175 | error("Specify Optimizer")
176 | end
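-- Both optimizers call feval(parameters) and expect (loss, dloss/dparams)
-- back. Since 'parameters' and 'gradParameters' are the flattened views
-- returned by model:getParameters(), feval can update the model in place
-- and return the shared gradient buffer without copying.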
177 |
178 | -- DataParallelTable's syncParameters
179 | if model.needsSync then
180 | model:syncParameters()
181 | end
182 |
183 | cutorch.synchronize()
184 | batchNumber = batchNumber + 1
185 | loss_epoch = loss_epoch + err
186 |
187 | -- Calculate top-1 error, and print information
188 | print(('Epoch: [%d][%d/%d]\tTime %.3f Err %.4f LR %.0e DataLoadingTime %.3f'):format(
189 | epoch, batchNumber, opt.epochSize, timer:time().real, err,
190 | optimState.learningRate, dataLoadingTime))
191 |
192 | dataTimer:reset()
193 | end
194 |
--------------------------------------------------------------------------------
/transforms.lua:
--------------------------------------------------------------------------------
1 | -- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
2 | -- All rights reserved.
3 | -- This software is provided for research purposes only.
4 | -- By using this software you agree to the terms of the license file
5 | -- in the root folder.
6 | -- For commercial use, please contact ps-license@tue.mpg.de.
7 |
8 | -- https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua
9 | --
10 | -- Copyright (c) 2016, Facebook, Inc.
11 | -- All rights reserved.
12 | --
13 | -- This source code is licensed under the BSD-style license found in the
14 | -- LICENSE file in the root directory of this source tree. An additional grant
15 | -- of patent rights can be found in the PATENTS file in the same directory.
16 | --
17 | -- Image transforms for data augmentation and input normalization
18 | --
19 |
20 | require 'image'
21 |
22 | local M = {}
23 |
24 | function M.Compose(transforms)
25 | return function(input)
26 | for _, transform in ipairs(transforms) do
27 | input = transform(input)
28 | end
29 | return input
30 | end
31 | end
32 |
33 | function M.ColorNormalize(meanstd)
34 | return function(img)
35 | img = img:clone()
36 | for i=1,3 do
37 | img[i]:add(-meanstd.mean[i])
38 | img[i]:div(meanstd.std[i])
39 | img[3+i]:add(-meanstd.mean[i])
40 | img[3+i]:div(meanstd.std[i])
41 | end
42 | return img
43 | end
44 | end
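-- Unlike the single-image fb.resnet version, this normalizes six channels:
-- the input stacks an image pair along the channel axis, and both frames
-- share the same per-channel mean/std.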
45 |
46 | -- Scales the smaller edge to size
47 | function M.Scale(size, interpolation)
48 | interpolation = interpolation or 'bicubic'
49 | return function(input)
50 | local w, h = input:size(3), input:size(2)
51 | if (w <= h and w == size) or (h <= w and h == size) then
52 | return input
53 | end
54 | if w < h then
55 | return image.scale(input, size, h/w * size, interpolation)
56 | else
57 | return image.scale(input, w/h * size, size, interpolation)
58 | end
59 | end
60 | end
61 |
62 | -- Crop to centered rectangle
63 | function M.CenterCrop(size)
64 | return function(input)
65 | local w1 = math.ceil((input:size(3) - size)/2)
66 | local h1 = math.ceil((input:size(2) - size)/2)
67 | return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch
68 | end
69 | end
70 |
71 | -- Random crop from a larger image, with optional zero padding
72 | function M.RandomCrop(size, padding)
73 | padding = padding or 0
74 |
75 | return function(input)
76 | if padding > 0 then
77 | local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding)
78 | temp:zero()
79 | :narrow(2, padding+1, input:size(2))
80 | :narrow(3, padding+1, input:size(3))
81 | :copy(input)
82 | input = temp
83 | end
84 |
85 | local w, h = input:size(3), input:size(2)
86 | if w == size and h == size then
87 | return input
88 | end
89 |
90 | local x1, y1 = torch.random(0, w - size), torch.random(0, h - size)
91 | local out = image.crop(input, x1, y1, x1 + size, y1 + size)
92 | assert(out:size(2) == size and out:size(3) == size, 'wrong crop size')
93 | return out
94 | end
95 | end
96 |
97 | -- Four corner patches and center crop from image and its horizontal reflection
98 | function M.TenCrop(size)
99 | local centerCrop = M.CenterCrop(size)
100 |
101 | return function(input)
102 | local w, h = input:size(3), input:size(2)
103 |
104 | local output = {}
105 | for _, img in ipairs{input, image.hflip(input)} do
106 | table.insert(output, centerCrop(img))
107 | table.insert(output, image.crop(img, 0, 0, size, size))
108 | table.insert(output, image.crop(img, w-size, 0, w, size))
109 | table.insert(output, image.crop(img, 0, h-size, size, h))
110 | table.insert(output, image.crop(img, w-size, h-size, w, h))
111 | end
112 |
113 | -- View as mini-batch
114 | for i, img in ipairs(output) do
115 | output[i] = img:view(1, img:size(1), img:size(2), img:size(3))
116 | end
117 |
118 | return input.cat(output, 1)
119 | end
120 | end
121 |
122 | -- Resize so the shorter side is randomly sampled from [minSize, maxSize] (ResNet-style)
123 | function M.RandomScale(minSize, maxSize)
124 | return function(input)
125 | local w, h = input:size(3), input:size(2)
126 |
127 | local targetSz = torch.random(minSize, maxSize)
128 | local targetW, targetH = targetSz, targetSz
129 | if w < h then
130 | targetH = torch.round(h / w * targetW)
131 | else
132 | targetW = torch.round(w / h * targetH)
133 | end
134 |
135 | return image.scale(input, targetW, targetH, 'bicubic')
136 | end
137 | end
138 |
139 | -- Random crop with size 8%-100% and aspect ratio 3/4 - 4/3 (Inception-style)
140 | function M.RandomSizedCrop(size)
141 | local scale = M.Scale(size)
142 | local crop = M.CenterCrop(size)
143 |
144 | return function(input)
145 | local attempt = 0
146 | repeat
147 | local area = input:size(2) * input:size(3)
148 | local targetArea = torch.uniform(0.08, 1.0) * area
149 |
150 | local aspectRatio = torch.uniform(3/4, 4/3)
151 | local w = torch.round(math.sqrt(targetArea * aspectRatio))
152 | local h = torch.round(math.sqrt(targetArea / aspectRatio))
153 |
154 | if torch.uniform() < 0.5 then
155 | w, h = h, w
156 | end
157 |
158 | if h <= input:size(2) and w <= input:size(3) then
159 | local y1 = torch.random(0, input:size(2) - h)
160 | local x1 = torch.random(0, input:size(3) - w)
161 |
162 | local out = image.crop(input, x1, y1, x1 + w, y1 + h)
163 | assert(out:size(2) == h and out:size(3) == w, 'wrong crop size')
164 |
165 | return image.scale(out, size, size, 'bicubic')
166 | end
167 | attempt = attempt + 1
168 | until attempt >= 10
169 |
170 | -- fallback
171 | return crop(scale(input))
172 | end
173 | end
174 |
175 | function M.HorizontalFlip(prob)
176 | return function(input)
177 | if torch.uniform() < prob then
178 | input = image.hflip(input)
179 | end
180 | return input
181 | end
182 | end
183 |
184 | function M.Rotation(deg)
185 | return function(input)
186 | if deg ~= 0 then
187 | input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear')
188 | end
189 | return input
190 | end
191 | end
192 |
193 | -- Lighting noise (AlexNet-style PCA-based noise)
194 | function M.Lighting(alphastd, eigval, eigvec)
195 | return function(input)
196 | if alphastd == 0 then
197 | return input
198 | end
199 |
200 | local alpha = torch.Tensor(3):normal(0, alphastd)
201 | local rgb = eigvec:clone()
202 | :cmul(alpha:view(1, 3):expand(3, 3))
203 | :cmul(eigval:view(1, 3):expand(3, 3))
204 | :sum(2)
205 | :squeeze()
206 |
207 | input = input:clone()
208 | for i=1,3 do
209 | input[i]:add(rgb[i])
210 | input[3+i]:add(rgb[i])
211 | end
212 | return input
213 | end
214 | end
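-- AlexNet-style PCA lighting: the RGB offset is eigvec * (alpha .* eigval)
-- with alpha ~ N(0, alphastd); adding the identical offset to channels 1-3
-- and 4-6 keeps the augmentation consistent across the two frames of a pair.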
215 |
216 | local function blend(img1, img2, alpha)
217 | return img1:mul(alpha):add(1 - alpha, img2)
218 | end
219 |
220 | local function grayscale(dst, img)
221 | assert(img:size(1)==3)
222 |
223 | dst[1]:zero()
224 | dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3])
225 | dst[2]:copy(dst[1])
226 | dst[3]:copy(dst[1])
227 | return dst
228 | end
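-- grayscale uses the Rec. 601 luma weights (0.299 R + 0.587 G + 0.114 B) and
-- replicates the result across all three channels so blend() can mix it with
-- the colour image elementwise.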
229 |
230 | function M.Saturation(var)
231 | local gs
232 |
233 | return function(input)
234 | gs = gs or input.new()
235 | gs:resizeAs(input)
236 |
237 | grayscale(gs[{{1,3},{},{}}], input[{{1,3},{},{}}])
238 | grayscale(gs[{{4,6},{},{}}], input[{{4,6},{},{}}])
239 |
240 | local alpha = 1.0 + torch.uniform(-var, var)
241 | blend(input, gs, alpha)
242 | return input
243 | end
244 | end
245 |
246 | function M.Brightness(var)
247 | local gs
248 |
249 | return function(input)
250 | gs = gs or input.new()
251 | gs:resizeAs(input):zero()
252 |
253 | local alpha = 1.0 + torch.uniform(-var, var)
254 | blend(input, gs, alpha)
255 | return input
256 | end
257 | end
258 |
259 | function M.Contrast(var)
260 | local gs
261 |
262 | return function(input)
263 | gs = gs or input.new()
264 | gs:resizeAs(input)
265 |
266 | grayscale(gs[{{1,3},{},{}}], input[{{1,3},{},{}}])
267 | grayscale(gs[{{4,6},{},{}}], input[{{4,6},{},{}}])
268 |
269 | gs[{{1,3},{},{}}]:fill(gs[1]:mean())
270 | gs[{{4,6},{},{}}]:fill(gs[4]:mean())
271 |
272 | local alpha = 1.0 + torch.uniform(-var, var)
273 | blend(input, gs, alpha)
274 | return input
275 | end
276 | end
277 |
278 | function M.RandomOrder(ts)
279 | return function(input)
280 | local img = input.img or input
281 | local order = torch.randperm(#ts)
282 | for i=1,#ts do
283 | img = ts[order[i]](img)
284 | end
285 | return input
286 | end
287 | end
288 |
289 | function M.ColorJitter(opt)
290 | local brightness = opt.brightness or 0
291 | local contrast = opt.contrast or 0
292 | local saturation = opt.saturation or 0
293 |
294 | local ts = {}
295 | if brightness ~= 0 then
296 | table.insert(ts, M.Brightness(brightness))
297 | end
298 | if contrast ~= 0 then
299 | table.insert(ts, M.Contrast(contrast))
300 | end
301 | if saturation ~= 0 then
302 | table.insert(ts, M.Saturation(saturation))
303 | end
304 |
305 | if #ts == 0 then
306 | return function(input) return input end
307 | end
308 |
309 | return M.RandomOrder(ts)
310 | end
311 |
312 | return M
313 |
--------------------------------------------------------------------------------
/util.lua:
--------------------------------------------------------------------------------
1 | require 'cunn'
2 | local ffi=require 'ffi'
3 |
4 | function makeDataParallel(model, nGPU)
5 | if nGPU > 1 then
6 | print('converting module to nn.DataParallelTable')
7 | assert(nGPU <= cutorch.getDeviceCount(), 'nGPU specified exceeds the number of available GPUs')
8 | local model_single = model
9 | model = nn.DataParallelTable(1)
10 | for i=1, nGPU do
11 | cutorch.setDevice(i)
12 | model:add(model_single:clone():cuda(), i)
13 | end
14 | end
15 | cutorch.setDevice(opt.GPU)
16 |
17 | return model
18 | end
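-- Usage sketch: model = makeDataParallel(model, opt.nGPU).
-- nn.DataParallelTable(1) splits dimension 1 (the batch) across the replicas,
-- so opt.batchSize should ideally be divisible by nGPU for an even split.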
19 |
20 | local function cleanDPT(module)
21 | -- This assumes this DPT was created by the function above: all the
22 | -- module.modules are clones of the same network on different GPUs
23 | -- hence we only need to keep one when saving the model to the disk.
24 | local newDPT = nn.DataParallelTable(1)
25 | cutorch.setDevice(opt.GPU)
26 | newDPT:add(module:get(1), opt.GPU)
27 | return newDPT
28 | end
29 |
30 | function saveDataParallel(filename, model)
31 | if torch.type(model) == 'nn.DataParallelTable' then
32 | torch.save(filename, cleanDPT(model))
33 | elseif torch.type(model) == 'nn.Sequential' then
34 | local temp_model = nn.Sequential()
35 | for i, module in ipairs(model.modules) do
36 | if torch.type(module) == 'nn.DataParallelTable' then
37 | temp_model:add(cleanDPT(module))
38 | else
39 | temp_model:add(module)
40 | end
41 | end
42 | torch.save(filename, temp_model)
43 | elseif torch.type(model) == 'nn.gModule' then
44 | torch.save(filename, model)
45 | else
46 | error('This saving function only works with Sequential or DataParallelTable modules.')
47 | end
48 | end
49 |
50 | function loadDataParallel(filename, nGPU)
51 | if opt.backend == 'cudnn' then
52 | require 'cudnn'
53 | end
54 | local model = torch.load(filename)
55 | if torch.type(model) == 'nn.DataParallelTable' then
56 | return makeDataParallel(model:get(1), nGPU)
57 | elseif torch.type(model) == 'nn.Sequential' then
58 | for i,module in ipairs(model.modules) do
59 | if torch.type(module) == 'nn.DataParallelTable' then
60 | model.modules[i] = makeDataParallel(module:get(1):float(), nGPU)
61 | end
62 | end
63 | return model
64 | elseif torch.type(model) == 'nn.gModule' then
65 | return model
66 | else
67 | error('The loaded model is not a Sequential or DataParallelTable module.')
68 | end
69 | end
70 |
--------------------------------------------------------------------------------