├── .gitignore ├── src ├── pipeline │ ├── matching.lua │ ├── disparity.lua │ ├── refinement.lua │ └── post.lua ├── networks │ ├── modules │ │ ├── Concatenation.lua │ │ ├── Normalize2.lua │ │ └── SpatialConvolution1_fw.lua │ ├── mc-models │ │ ├── mc-cnn │ │ │ ├── slow.lua │ │ │ └── fast.lua │ │ ├── resmatch │ │ │ ├── acrt.lua │ │ │ ├── fast.lua │ │ │ ├── hybrid.lua │ │ │ └── components.lua │ │ ├── fast.lua │ │ ├── matching.lua │ │ └── acrt.lua │ ├── criterions │ │ ├── Margin2.lua │ │ ├── BCE.lua │ │ └── MulClassNLLCriterion.lua │ ├── scores │ │ ├── DotProduct2.lua │ │ └── L2dist.lua │ ├── gdn-models │ │ ├── ref.lua │ │ └── dispnet.lua │ └── network.lua ├── datasets │ ├── kitti2015.lua │ ├── kitti.lua │ ├── mb.lua │ └── dataset.lua ├── logger.lua ├── trainer.lua ├── opts.lua ├── cv.cpp ├── curesmatch.cu ├── main.lua ├── runner.lua └── adcensus.cu ├── scripts ├── mkdirs.sh ├── download_middlebury.sh ├── preprocess_kitti.lua └── preprocess_mb.py ├── pretrained └── README.md ├── Makefile ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | net/net_* 2 | net/*/* 3 | cache/* 4 | data.* 5 | results/* 6 | out/ 7 | *.so 8 | storage 9 | tmp 10 | -------------------------------------------------------------------------------- /src/pipeline/matching.lua: -------------------------------------------------------------------------------- 1 | 2 | local M = {} 3 | 4 | function M.match(mcnet, x_batch, disp_max, directions) 5 | return mcnet:computeMatchingCost(x_batch, disp_max, directions) 6 | end 7 | 8 | return M 9 | -------------------------------------------------------------------------------- /scripts/mkdirs.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | mkdir -p tmp 4 | mkdir -p out 5 | mkdir -p storage 6 | mkdir -p results 7 | cd storage 8 | mkdir -p net 9 | mkdir -p net/mc 10 | mkdir -p net/mc/debug 11 | mkdir -p net/disparity 12 | mkdir -p net/disparity/debug 13 | mkdir -p data.kitti.rgb/disparity 14 | mkdir -p data.kitti2015.rgb/disparity 15 | cd .. 16 | -------------------------------------------------------------------------------- /src/networks/modules/Concatenation.lua: -------------------------------------------------------------------------------- 1 | local Concatenation, parent = torch.class('nn.Concatenation', 'nn.Module') 2 | 3 | function Concatenation:__init() 4 | parent.__init(self) 5 | end 6 | 7 | function Concatenation:updateOutput(input) 8 | 9 | self.output = input:view(input:size(1)/2, input:size(2)*2, input:size(3), input:size(4)) 10 | return self.output 11 | end 12 | 13 | 14 | function Concatenation:updateGradInput(input, gradOutput) 15 | self.gradInput = gradOutput:view(input:size()) 16 | return self.gradInput 17 | end 18 | -------------------------------------------------------------------------------- /src/datasets/kitti2015.lua: -------------------------------------------------------------------------------- 1 | require 'datasets/kitti.lua' 2 | 3 | local Kitti2015Dataset, parent = torch.class('Kitti2015Dataset','KittiDataset') 4 | 5 | local function createDataset(opt) 6 | return Kitti2015Dataset:new(opt) 7 | end 8 | 9 | function Kitti2015Dataset:__init(self, opt) 10 | parent.__init(parent, self, opt) 11 | self.name='kitti2015' 12 | self.dir = opt.storage .. '/data.kitti2015.' .. 
opt.color 13 | 14 | --better parameters for the network 15 | self.n_te = 200 16 | self.n_tr = 200 17 | end 18 | 19 | return createDataset -------------------------------------------------------------------------------- /src/pipeline/disparity.lua: -------------------------------------------------------------------------------- 1 | 2 | local M = {} 3 | 4 | function M.disparityImage(costs, gdn) 5 | if gdn then 6 | return gdn:disparityImage(costs) 7 | else 8 | local v1, d1 = torch.min(costs[{{1}}], 2) 9 | local disp = {} 10 | disp[2] = d1:cuda():add(-1) 11 | 12 | local v2 = torch.CudaTensor() 13 | local d2 = torch.CudaTensor() 14 | if costs:size(1) > 1 then 15 | v2, d2 = torch.min(costs[{{2}}], 2) 16 | disp[1] = d2:cuda():add(-1) 17 | end 18 | return disp, costs, {v1, v2}, {t1 = 1000, t2 = 1000} 19 | end 20 | end 21 | 22 | return M 23 | -------------------------------------------------------------------------------- /src/networks/modules/Normalize2.lua: -------------------------------------------------------------------------------- 1 | local Normalize2, parent = torch.class('nn.Normalize2', 'nn.Module') 2 | 3 | function Normalize2:__init() 4 | parent.__init(self) 5 | self.norm = torch.CudaTensor() 6 | end 7 | 8 | function Normalize2:updateOutput(input) 9 | self.norm:resize(input:size(1), 1, input:size(3), input:size(4)) 10 | self.output:resizeAs(input) 11 | adcensus.Normalize_forward(input, self.norm, self.output) 12 | return self.output 13 | end 14 | 15 | function Normalize2:updateGradInput(input, gradOutput) 16 | self.gradInput:resizeAs(input) 17 | adcensus.Normalize_backward_input(gradOutput, input, self.norm, self.gradInput) 18 | return self.gradInput 19 | end 20 | -------------------------------------------------------------------------------- /pretrained/README.md: -------------------------------------------------------------------------------- 1 | Trained Resmatch Torch models 2 | ============================ 3 | 4 | These are Resmatch models trained on the Kitti 2012 and Kitti 2015 data sets. The accuracies on the online evaluation servers are included below.
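For reference, the snippet below is a minimal sketch (not part of the repository) of how a downloaded checkpoint could be inspected from the `src/` directory. It assumes the file follows the format written by `MatchNet:save` in `src/networks/mc-models/matching.lua` (a table with `description`, `decision`, `params` and `name` fields); the file name used here is only a placeholder. In the full pipeline the path would normally be passed through the `-mcnet` option (see `src/opts.lua`) rather than loaded by hand.

```lua
-- Hypothetical example: the checkpoint name below is a placeholder, and the
-- field layout is assumed to match MatchNet:save in this repository.
require 'torch'
require 'nn'
require 'cunn'
require 'cudnn'   -- the saved convolution/activation layers are typically cudnn modules
-- Custom layers stored inside the checkpoint must be defined before
-- torch.load can deserialize them (these are the ones the fast models use).
require 'networks/modules/Normalize2'
require 'networks/scores/DotProduct2'

local model = torch.load('resmatch_kitti2012_fast_net.t7')
print(model.name)    -- network name
print(model.params)  -- hyper-parameters saved with the network

-- Rebuild the full matching network the same way MatchNet:load does
local net = nn.Sequential()
   :add(model.description)  -- description (feature extraction) sub-network
   :add(model.decision)     -- decision / matching-score sub-network
   :cuda()
```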
5 | 6 | - [Resmatch-Kitti2012-Fast](https://www.mediafire.com/?9fvwnimchbtdb27) 7 | - [Resmatch-Kitti2012-Hybrid]() 8 | - [Resmatch-Kitti2015-Fast]() 9 | - [Resmatch-Kitti2015-Hybrid]() 10 | 11 | ##### Kitti error rates 12 | 13 | | Network | NOC | ALL | 14 | | ------------- | ------ | ------ | 15 | | Resmatch-kitti2012-fast | 2.55 | 4.07 | 16 | | Resmatch-kitti2012-hybrid| | | 17 | | Resmatch-kitti2015-fast | | | 18 | | Resmatch-kitti2015-hybrid| | | 19 | -------------------------------------------------------------------------------- /src/networks/mc-models/mc-cnn/slow.lua: -------------------------------------------------------------------------------- 1 | require 'networks/mc-models/acrt' 2 | require('networks/criterions/BCE') 3 | 4 | local network = require('networks/network') 5 | 6 | local McCnnSlow, parent = torch.class('McCnnSlow','AcrtNetwork') 7 | 8 | function McCnnSlow:__init(self, opt, dataset) 9 | parent.__init(parent, self, opt, dataset) 10 | self.criterion = nn.BCECriterion2():cuda() 11 | end 12 | 13 | function McCnnSlow:getDescriptionNetwork() 14 | local description = nn.Sequential() 15 | 16 | for i = 1,self.params.l1 do 17 | description:add(Convolution(i == 1 and self.n_input_plane or self.params.fm, 18 | self.params.fm, self.params.ks, self.params.ks)) 19 | description:add(Activation()) 20 | end 21 | return description 22 | end 23 | 24 | return McCnnSlow 25 | -------------------------------------------------------------------------------- /src/networks/criterions/Margin2.lua: -------------------------------------------------------------------------------- 1 | local Margin2, parent = torch.class('nn.Margin2', 'nn.Criterion') 2 | 3 | function Margin2:__init() 4 | parent.__init(self) 5 | self.tmp = torch.CudaTensor() 6 | self.gradInput = torch.CudaTensor() 7 | self.margin = 0.2 8 | self.pow = 1 9 | end 10 | 11 | function Margin2:updateOutput(input, target) 12 | assert(input:size(2) == 1 and input:size(3) == 1 and input:size(4) == 1) 13 | self.tmp:resize(input:size(1) / 2) 14 | self.gradInput:resizeAs(input) 15 | adcensus.Margin2(input, self.tmp, self.gradInput, self.margin, self.pow) 16 | self.output = self.tmp:mean() 17 | self.gradInput:div(self.tmp:size(1)) 18 | return self.output 19 | end 20 | 21 | function Margin2:updateGradInput(input, target) 22 | return self.gradInput 23 | end 24 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PREFIX=$(HOME)/torch/install 2 | CUDA=/usr/local/cuda 3 | CFLAGS=-I$(PREFIX)/include/THC -I$(PREFIX)/include/TH -I$(PREFIX)/include 4 | LDFLAGS_NVCC=-L$(PREFIX)/lib -Xlinker -rpath,$(PREFIX)/lib -lluaT -lTHC -lTH -lpng 5 | LDFLAGS_CPP=-L$(PREFIX)/lib -lluaT -lTH `pkg-config --libs opencv` 6 | 7 | all: libcuresmatch.so libadcensus.so libcv.so 8 | 9 | libadcensus.so: src/adcensus.cu 10 | $(CUDA)/bin/nvcc -arch sm_35 -O3 -DNDEBUG --compiler-options '-fPIC' -o libadcensus.so --shared src/adcensus.cu $(CFLAGS) $(LDFLAGS_NVCC) 11 | 12 | libcuresmatch.so: src/curesmatch.cu 13 | $(CUDA)/bin/nvcc -arch sm_35 -O3 -DNDEBUG --compiler-options '-fPIC' -o libcuresmatch.so --shared src/curesmatch.cu $(CFLAGS) $(LDFLAGS_NVCC) 14 | 15 | libcv.so: src/cv.cpp 16 | g++ -fPIC -o libcv.so -shared src/cv.cpp $(CFLAGS) $(LDFLAGS_CPP) 17 | 18 | clean: 19 | rm -f libcuresmatch.so libadcensus.so libcv.so 20 | -------------------------------------------------------------------------------- /src/networks/mc-models/mc-cnn/fast.lua: 
-------------------------------------------------------------------------------- 1 | require 'networks/mc-models/fast' 2 | require('networks/modules/Normalize2') 3 | 4 | local McCnnFast, parent = torch.class('McCnnFast','FastNetwork') 5 | 6 | local function createModel(opt, dataset) 7 | return McCnnFast:new(opt, dataset) 8 | end 9 | 10 | function McCnnFast:__init(self, opt, dataset) 11 | parent.__init(parent, self, opt, dataset) 12 | end 13 | 14 | function McCnnFast:getDescriptionNetwork() 15 | local description = nn.Sequential() 16 | 17 | for i = 1,self.params.l1-1 do 18 | description:add(Convolution(i == 1 and self.n_input_plane or self.params.fm, 19 | self.params.fm, self.params.ks, self.params.ks)) 20 | description:add(Activation()) 21 | end 22 | description:add(Convolution(self.params.fm, self.params.fm, self.params.ks, self.params.ks)) 23 | description:add(nn.Normalize2()) 24 | return description 25 | end 26 | 27 | return createModel 28 | -------------------------------------------------------------------------------- /src/logger.lua: -------------------------------------------------------------------------------- 1 | local M = {} 2 | 3 | local Logger = torch.class('Logger', M) 4 | 5 | function Logger:__init(opt) 6 | local cmd = getCmd(opt) 7 | self.path = opt.log .. '/' .. cmd .. '.txt' 8 | local writer = io.open(self.path, 'a') 9 | writer:write(os.date([[%x %X]])..'\n') 10 | writer:close() 11 | end 12 | 13 | function Logger:write(str) 14 | print(str) 15 | local writer = io.open(self.path, 'a') 16 | writer:write(str) 17 | writer:close() 18 | end 19 | 20 | function getCmd(opt) 21 | 22 | local cmd_str = opt.a 23 | local i =1 24 | while i <= #arg do 25 | if arg[i] == '-mcnet' or arg[i] == '-dispnet' or arg[i] == '-sm_skip' or arg[i] == '-sm_terminate' then 26 | i = i +2 27 | elseif arg[i] == '-make_cache' or arg[i] == '-use_cache' or arg[i] == '-save_train' then 28 | i = i +1 29 | else 30 | cmd_str = cmd_str .. '_' .. 
arg[i] 31 | i = i+1 32 | end 33 | end 34 | 35 | return cmd_str 36 | end 37 | 38 | return M.Logger 39 | -------------------------------------------------------------------------------- /src/networks/mc-models/resmatch/acrt.lua: -------------------------------------------------------------------------------- 1 | require 'networks/mc-models/acrt' 2 | require('networks/criterions/BCE') 3 | local resmatch = require 'networks/mc-models/resmatch/components' 4 | 5 | local ResmatchAcrt, parent = torch.class('ResmatchAcrt','AcrtNetwork') 6 | 7 | function ResmatchAcrt:__init(self, opt, dataset) 8 | parent.__init(parent, self, opt, dataset) 9 | 10 | self.criterion = nn.BCECriterion2():cuda() 11 | 12 | self.innerType = opt.inner 13 | self.outerType = opt.outer 14 | self.ConvBlock = basicBlock 15 | 16 | self.params.arch= {{1,2},{1,2},{1,2},{1,2},{1,2}} 17 | end 18 | 19 | function ResmatchAcrt:getDescriptionNetwork() 20 | fin = self.n_input_plane 21 | local description = nn.Sequential() 22 | local fm = self.params.fm 23 | for i =1, #self.params.arch do 24 | l = self.params.arch[i] 25 | description:add(resmatch.transition(i ==1 and fin or fm,fm)) 26 | description:add(resmatch.resStack(self.ConvBlock,fm, l[1], l[2], 1, self.innerType, self.outerType, 27 | i == #self.params.arch)) 28 | end 29 | return description 30 | end 31 | 32 | return ResmatchAcrt 33 | -------------------------------------------------------------------------------- /src/networks/mc-models/resmatch/fast.lua: -------------------------------------------------------------------------------- 1 | require 'networks/mc-models/fast' 2 | require('networks/modules/Normalize2') 3 | 4 | local resmatch = require 'networks/mc-models/resmatch/components' 5 | 6 | local ResmatchFast, parent = torch.class('ResmatchFast','FastNetwork') 7 | 8 | function ResmatchFast:__init(self, opt, dataset) 9 | parent.__init(parent, self, opt, dataset) 10 | 11 | self.innerType = opt.inner 12 | self.outerType = opt.outer 13 | self.convBlock = basicBlock 14 | 15 | self.params.arch= {{1,1},{1,1},{1,1},{1,1},{1,1}} 16 | end 17 | 18 | function ResmatchFast:getDescriptionNetwork(block, fin) 19 | fin = fin or self.n_input_plane 20 | local description = nn.Sequential() 21 | local fm = self.params.fm 22 | for i =1, #self.params.arch do 23 | l = self.params.arch[i] 24 | description:add(resmatch.transition(i ==1 and fin or fm,fm)) 25 | description:add(resmatch.resStack(self.convBlock,fm, l[1], l[2], 1, self.innerType, self.outerType, 26 | i == #self.params.arch)) 27 | end 28 | description:add(nn.Normalize2()) 29 | return description 30 | end 31 | 32 | return ResmatchFast 33 | -------------------------------------------------------------------------------- /src/networks/modules/SpatialConvolution1_fw.lua: -------------------------------------------------------------------------------- 1 | local SpatialConvolution1_fw, parent = torch.class('nn.SpatialConvolution1_fw', 'nn.Module') 2 | 3 | function SpatialConvolution1_fw:__init(inputSize, outputSize, free) 4 | parent.__init(self) 5 | self.weight = torch.CudaTensor(outputSize, inputSize) 6 | self.bias = torch.CudaTensor(1, outputSize, 1, 1) 7 | self.output:cuda() 8 | self.free = free -- free memory 9 | end 10 | 11 | function SpatialConvolution1_fw:updateOutput(input) 12 | local num_ex = input:size(1) 13 | local fm_in = input:size(2) 14 | local h = input:size(3) 15 | local w = input:size(4) 16 | local fm_out = self.weight:size(1) 17 | 18 | input:resize(num_ex, fm_in, h * w) 19 | self.output:resize(num_ex, fm_out, h * w) 20 | for i = 
1,num_ex do 21 | self.output[i]:addmm(0, 1, self.weight, input[i]) 22 | end 23 | input:resize(num_ex, fm_in, h, w) 24 | self.output:resize(num_ex, fm_out, h, w) 25 | 26 | self.output:add(self.bias:expandAs(self.output)) 27 | 28 | -- Free memory 29 | if self.free then 30 | input:storage():resize(0) 31 | end 32 | return self.output 33 | end 34 | -------------------------------------------------------------------------------- /src/networks/scores/DotProduct2.lua: -------------------------------------------------------------------------------- 1 | local network = require('networks/network') 2 | local DotProduct2, parent = torch.class('nn.DotProduct2', 'nn.Module') 3 | 4 | function DotProduct2:__init() 5 | parent.__init(self) 6 | self.gradInput = torch.CudaTensor() 7 | self.tmp = torch.CudaTensor() 8 | self.output = torch.CudaTensor() 9 | end 10 | 11 | function DotProduct2:updateOutput(input) 12 | local input_L, input_R = network.sliceInput(input) 13 | self.tmp:resizeAs(input_L) 14 | self.tmp:cmul(input_L, input_R) 15 | self.output:sum(self.tmp, 2) 16 | return self.output 17 | end 18 | 19 | function DotProduct2:updateGradInput(input, gradOutput) 20 | gradOutput:cuda() 21 | input:cuda() 22 | self.gradInput:resizeAs(input) 23 | local input_L, input_R = network.sliceInput(input) 24 | local gradInput_L, gradInput_R = network.sliceInput(self.gradInput) 25 | gradInput_L:cmul(input_R, gradOutput:expandAs(input_R):cuda()) 26 | gradInput_R:cmul(input_L, gradOutput:expandAs(input_L):cuda()) 27 | return self.gradInput 28 | end 29 | 30 | function DotProduct2:computeMatchingCost(input_L, input_R, output_L, output_R) 31 | adcensus.StereoJoin(input_L, input_R, output_L, output_R) 32 | end 33 | 34 | -------------------------------------------------------------------------------- /scripts/download_middlebury.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | # This file is copied from https://github.com/jzbonter/mc-cnn 3 | 4 | mkdir -p storage/data.mb/unzip 5 | cd storage/data.mb/unzip 6 | 7 | # 2014 dataset 8 | wget -r -np -A png,pfm,txt -X "/stereo/data/scenes2014/datasets/*-perfect/" http://vision.middlebury.edu/stereo/data/scenes2014/datasets/ 9 | 10 | # 2006 dataset 11 | wget -r -np -A png,txt http://vision.middlebury.edu/stereo/data/scenes2006/HalfSize/ 12 | 13 | # 2005 dataset 14 | wget -r -np -A png,txt http://vision.middlebury.edu/stereo/data/scenes2005/HalfSize/ 15 | 16 | # 2003 dataset 17 | mkdir vision.middlebury.edu/stereo/data/scenes2003/ 18 | pushd . 19 | cd vision.middlebury.edu/stereo/data/scenes2003/ 20 | wget http://vision.middlebury.edu/stereo/data/scenes2003/newdata/full/conesH-ppm-2.zip 21 | wget http://vision.middlebury.edu/stereo/data/scenes2003/newdata/full/teddyH-ppm-2.zip 22 | unzip conesH-ppm-2.zip 23 | unzip teddyH-ppm-2.zip 24 | popd 25 | 26 | # 2001 dataset 27 | wget -r -np -A pgm,ppm,txt http://vision.middlebury.edu/stereo/data/scenes2001/data/ 28 | # get tsukuba nonocc mask 29 | pushd . 
30 | cd vision.middlebury.edu/stereo/data/scenes2001/data/tsukuba 31 | wget http://vision.middlebury.edu/stereo/eval/newEval/tsukuba/nonocc.png 32 | popd 33 | 34 | # eval3 train/test set 35 | wget http://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-H.zip 36 | unzip MiddEval3-data-H.zip 37 | 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2016, Amit Shaked, amit.shaked1@gmail.com 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /src/trainer.lua: -------------------------------------------------------------------------------- 1 | require('optim') 2 | 3 | local M = {} 4 | 5 | local Trainer = torch.class('Trainer', M) 6 | 7 | function Trainer:__init(network, ds_size, bs, optimState) 8 | self.bs = bs 9 | self.ds_size = ds_size 10 | self.network = network 11 | self.optim = optim.sgd 12 | self.optimState = optimState or { 13 | learningRate = network.params.lr, 14 | learningRateDecay = 0.0, 15 | momentum = network.params.mom, 16 | nesterov = true, 17 | dampening = 0.0, 18 | weightDecay = network.params.decay, 19 | } 20 | end 21 | 22 | function Trainer:train(epoch, trainBatch) 23 | 24 | self.optimState.learningRate = self.network:learningRate(epoch) 25 | 26 | x, dl_dx = self.network:getModelParameters() 27 | 28 | local function feval() 29 | return self.network:feval(x, dl_dx, self.inputs, self.targets) 30 | end 31 | 32 | local err_tr = 0 33 | local err_tr_cnt = 0 34 | local t = 1 35 | 36 | local indexes = torch.range(1, self.ds_size/self.bs):totable() 37 | local s = self.ds_size - self.bs 38 | for i, idx in ipairs(indexes) do 39 | xlua.progress(i,#indexes) 40 | t = (idx-1) * self.bs + 1 41 | self.inputs, self.targets = trainBatch(t, self.bs, self.network.params.ws) 42 | 43 | _, fs = self.optim(feval, x, self.optimState) 44 | local err = fs[1] 45 | if err >= 0 and err < 100 then 46 | err_tr = err_tr + err 47 | err_tr_cnt = err_tr_cnt + 1 48 | else 49 | print(('WARNING! 
err=%f'):format(err)) 50 | if err ~= err then 51 | os.exit() 52 | end 53 | end 54 | end 55 | xlua.progress(#indexes, #indexes) 56 | return err_tr / err_tr_cnt 57 | end 58 | 59 | return M.Trainer 60 | -------------------------------------------------------------------------------- /src/pipeline/refinement.lua: -------------------------------------------------------------------------------- 1 | local M = {} 2 | 3 | function M.refine(disp, vols, opt, dataset, sm_skip, sm_terminate, disp_max, conf, t1, t2) 4 | local sm_active = true 5 | if dataset.name == 'kitti' or dataset.name == 'kitti2015' then 6 | local outlier = torch.CudaTensor():resizeAs(disp[2]):zero() 7 | curesmatch.outlier_detection(disp[2], disp[1], outlier, disp_max, conf[1], conf[2], t1, t2) 8 | 9 | if sm_active and sm_skip ~= 'occlusion' then 10 | 11 | disp[2] = adcensus.interpolate_occlusion(disp[2], outlier) 12 | 13 | end 14 | sm_active = sm_active and (sm_terminate ~= 'occlusion') 15 | 16 | if sm_active and sm_skip ~= 'mismatch' then 17 | disp[2] = adcensus.interpolate_mismatch(disp[2], outlier) 18 | 19 | end 20 | sm_active = sm_active and (sm_terminate ~= 'mismatch') 21 | end 22 | if sm_active and sm_skip ~= 'subpixel_enhancement' then 23 | disp[2] = adcensus.subpixel_enchancement(disp[2], vols[{{1}}], disp_max) 24 | 25 | end 26 | sm_active = sm_active and (sm_terminate ~= 'subpixel_enchancement') 27 | 28 | if sm_active and sm_skip ~= 'median' then 29 | disp[2] = adcensus.median2d(disp[2], 5) 30 | 31 | end 32 | sm_active = sm_active and (sm_terminate ~= 'median') 33 | 34 | if sm_active and sm_skip ~= 'bilateral' then 35 | disp[2] = adcensus.mean2d(disp[2], gaussian(opt.blur_sigma):cuda(), opt.blur_t) 36 | 37 | end 38 | 39 | return disp 40 | 41 | end 42 | 43 | function gaussian(sigma) 44 | local kr = math.ceil(sigma * 3) 45 | local ks = kr * 2 + 1 46 | local k = torch.Tensor(ks, ks) 47 | for i = 1, ks do 48 | for j = 1, ks do 49 | local y = (i - 1) - kr 50 | local x = (j - 1) - kr 51 | k[{i,j}] = math.exp(-(x * x + y * y) / (2 * sigma * sigma)) 52 | end 53 | end 54 | return k 55 | end 56 | return M 57 | -------------------------------------------------------------------------------- /src/networks/criterions/BCE.lua: -------------------------------------------------------------------------------- 1 | local BCECriterion2, parent = torch.class('nn.BCECriterion2', 'nn.Criterion') 2 | 3 | local eps = 1e-12 4 | 5 | function BCECriterion2:__init() 6 | parent.__init(self) 7 | self.sizeAverage = true 8 | end 9 | 10 | function BCECriterion2:updateOutput(input, target) 11 | -- log(input) * target + log(1 - input) * (1 - target) 12 | 13 | self.term1 = self.term1 or input.new() 14 | self.term2 = self.term2 or input.new() 15 | self.term3 = self.term3 or input.new() 16 | 17 | self.term1:resizeAs(input) 18 | self.term2:resizeAs(input) 19 | self.term3:resizeAs(input) 20 | 21 | self.term1:fill(1):add(-1,target) 22 | self.term2:fill(1):add(-1,input):add(eps):log():cmul(self.term1) 23 | 24 | self.term3:copy(input):add(eps):log():cmul(target) 25 | self.term3:add(self.term2) 26 | 27 | if self.sizeAverage then 28 | self.term3:div(target:nElement()) 29 | end 30 | 31 | self.output = - self.term3:sum() 32 | 33 | return self.output 34 | end 35 | 36 | function BCECriterion2:updateGradInput(input, target) 37 | -- target / input - (1 - target) / (1 - input) 38 | 39 | self.term1 = self.term1 or input.new() 40 | self.term2 = self.term2 or input.new() 41 | self.term3 = self.term3 or input.new() 42 | 43 | self.term1:resizeAs(input) 44 | 
self.term2:resizeAs(input) 45 | self.term3:resizeAs(input) 46 | 47 | self.term1:fill(1):add(-1,target) 48 | self.term2:fill(1):add(-1,input) 49 | 50 | self.term2:add(eps) 51 | self.term1:cdiv(self.term2) 52 | 53 | self.term3:copy(input):add(eps) 54 | 55 | self.gradInput:resizeAs(input) 56 | self.gradInput:copy(target):cdiv(self.term3) 57 | 58 | self.gradInput:add(-1,self.term1) 59 | 60 | if self.sizeAverage then 61 | self.gradInput:div(target:nElement()) 62 | end 63 | 64 | self.gradInput:mul(-1) 65 | 66 | return self.gradInput 67 | end 68 | -------------------------------------------------------------------------------- /src/networks/scores/L2dist.lua: -------------------------------------------------------------------------------- 1 | local network = require 'networks/network' 2 | 3 | local L2dist, parent = torch.class('nn.L2dist', 'nn.Module') 4 | 5 | function L2dist:__init() 6 | parent.__init(self) 7 | self.gradInput = torch.CudaTensor() 8 | self.tmp = torch.CudaTensor() 9 | self.diff = torch.CudaTensor() 10 | self.outExpand = torch.CudaTensor() 11 | self.ones = torch.CudaTensor() 12 | self.output = torch.CudaTensor() 13 | end 14 | 15 | 16 | function L2dist:updateOutput(input) 17 | local input_L, input_R = network.sliceInput(input) 18 | self.diff:resizeAs(input_L) 19 | self.diff:zero() 20 | self.diff:add(input_L, -1, input_R):pow(2) 21 | self.output:sum(self.diff, 2) 22 | self.output:pow(1./2) 23 | return self.output 24 | end 25 | 26 | 27 | function L2dist:updateGradInput(input, gradOutput) 28 | -- input[2*i-1]: the i'th left patch in the batch 29 | -- input[2*i]: the i'th right patch in the batch 30 | -- self.output = ||input_L - input_R ||_2 31 | -- the gradInput should be the derivative of 2-norm: 32 | -- d/dx_k(||x - y||_2) = (x_k -y_k) * x_k' / (||x-y||_2) 33 | 34 | local input_L, input_R = network.sliceInput(input) 35 | self.gradInput:resizeAs(input) 36 | 37 | local gradInput_L, gradInput_R = network.sliceInput(self.gradInput) 38 | gradInput_L:add(input_L, -1, input_R) 39 | 40 | self.outExpand:resizeAs(self.output) 41 | self.outExpand:copy(self.output) 42 | 43 | self.outExpand:add(1.0e-6) -- Prevent divide by zero errors 44 | self.outExpand:pow(-1) 45 | 46 | gradInput_L:cmul(self.outExpand:expandAs(gradInput_L)) 47 | self.grad = self.grad or gradOutput.new() 48 | self.ones = self.ones or gradOutput.new() 49 | self.grad:resize(input_L:size(1), input_L:size(2)):zero() 50 | self.ones:resize(input_L:size(2)):fill(1) 51 | self.grad:addr(gradOutput:squeeze(), self.ones) 52 | gradInput_L:cmul(self.grad) 53 | gradInput_R:zero():add(-1,gradInput_L) 54 | return self.gradInput 55 | end 56 | 57 | function L2dist:computeMatchingCost(input_L, input_R, output_L, output_R) 58 | -- computes matching cost for all pixels and all disperities at ones 59 | adcensus.L2dist(input_L, input_R, output_L, output_R) 60 | end 61 | -------------------------------------------------------------------------------- /src/opts.lua: -------------------------------------------------------------------------------- 1 | 2 | local M = {} 3 | 4 | function M.parse (arg) 5 | 6 | local cmd = torch.CmdLine() 7 | cmd:option('-ds', 'kitti', 'Dataset') 8 | cmd:option('-mc', 'resmatch', 'Matching cost network architecture') 9 | cmd:option('-m', 'acrt', 'Training mode (fast | acrt)') 10 | cmd:option('-gdn', '', 'Global disparity network architecture') 11 | cmd:option('-mcnet', '', 'Path to MC trained network') 12 | cmd:option('-dispnet', '', 'Path to GDN trained network') 13 | cmd:option('-a', 'train_mcn | train_gdn | test | submit 
| time | predict', 'train_mc') 14 | cmd:option('-log', '../results', 'Logs dir') 15 | cmd:option('-gpu', 1, 'gpu id') 16 | cmd:option('-seed', 6, 'Random seed') 17 | cmd:option('-debug', false) 18 | cmd:option('-times', 1, 'Test the pipeline every X epochs') 19 | cmd:option('-after', 14, 'Test every epoch after this one') 20 | 21 | cmd:option('-all', false, 'Train on both train and validation sets') 22 | cmd:option('-rename', false, 'Rename the trained network') 23 | cmd:option('-mix', false, 'Train on both kitti and kitti15') 24 | cmd:option('-storage', '../storage', 'Path to dir with the training data') 25 | cmd:option('-name', '', 'Add string to the network name') 26 | cmd:option('-METHOD_NAME', 'ResMatch', 'Name for MB submission') 27 | cmd:option('-start_epoch', 1) 28 | cmd:option('-make_cache', false) 29 | cmd:option('-use_cache', false) 30 | cmd:option('-save_img', false, 'Save the images when testing') 31 | cmd:option('-sm_terminate', 'refinement', 'Terminate the stereo method after this step') 32 | cmd:option('-sm_skip', '', 'which part of the stereo method to skip') 33 | cmd:option('-subset', 1, 'Percentage of the data set used for training') 34 | cmd:option('-epochs', 15, 'The amount of epochs to train') 35 | cmd:option('-start_epoch', 1) 36 | cmd:option('-rect', 'imperfect') 37 | cmd:option('-color', 'rgb') 38 | cmd:option('-verbose', false) 39 | cmd:option('-inner', 'L', 'Inner skip-connection') 40 | cmd:option('-outer', 'L', 'Outer skip-connection') 41 | 42 | -- Parameters of the matching cost network 43 | cmd:option('-fm', 112) 44 | cmd:option('-nh2', 384) 45 | cmd:option('-margin', 0.2, '') 46 | cmd:option('-lambda', 0.8) 47 | cmd:option('-batch_size', 128) 48 | 49 | local opt = cmd:parse(arg) 50 | 51 | return opt 52 | 53 | end 54 | 55 | return M 56 | -------------------------------------------------------------------------------- /src/networks/criterions/MulClassNLLCriterion.lua: -------------------------------------------------------------------------------- 1 | local MulClassNLLCriterion, parent = torch.class( 2 | 'nn.MulClassNLLCriterion', 3 | 'nn.Criterion' 4 | ) 5 | 6 | function MulClassNLLCriterion:__init(gt_weight) 7 | parent.__init(self) 8 | 9 | if gt_weight then 10 | gt_weight:div(gt_weight:sum()) 11 | self.gt_weight = gt_weight 12 | else 13 | self.gt_weight = torch.ones(1) 14 | end 15 | 16 | assert(self.gt_weight:nElement() % 2 == 0, 'nElement of gt_weight should be even') 17 | self.half_width = (self.gt_weight:nElement())/ 2 18 | end 19 | 20 | 21 | 22 | 23 | function MulClassNLLCriterion:__len() 24 | return 0 25 | end 26 | 27 | function MulClassNLLCriterion:updateOutput(input, target) 28 | assert(type(target) ~= 'number', 'target should be a tensor') 29 | 30 | if target:type() == 'torch.CudaTensor' then 31 | self.target = target 32 | else 33 | self.target = target:long() 34 | end 35 | 36 | -- has dimension for batch-size 37 | assert(input:dim() == 2 and target:dim() == 2, 'input should be 2D') 38 | assert(target:size(2) == 1, string.format('only support 1 gt locaton, got: %d', target:size(2))) 39 | self.output = 0 40 | for i = 1,input:size(1) do 41 | local t = math.floor(target[i][1]) 42 | local s = math.max(1, t-self.half_width + 1) 43 | local e = math.min(input:size(2), t+self.half_width) 44 | 45 | local probs= input[{i,{s,e}}] 46 | local weights = self.gt_weight[{{self.half_width-(t-s), self.half_width+ (e-t)}}] 47 | 48 | self.output = self.output - torch.cmul(probs, weights):sum() 49 | end 50 | self.output = self.output / input:nElement() 51 | return 
self.output 52 | end 53 | 54 | function MulClassNLLCriterion:updateGradInput(input, target) 55 | assert(type(target) ~= 'number', 'target should be a tensor') 56 | 57 | if target:type() == 'torch.CudaTensor' then 58 | self.target = target 59 | else 60 | self.target = target:long() 61 | end 62 | 63 | assert(input:dim() == 2, 'input should be 2D') 64 | self.gradInput:resizeAs(input):zero() 65 | 66 | for i = 1,input:size(1) do 67 | local t = math.floor(target[i][1]) 68 | local s = math.max(1, t-self.half_width +1) 69 | local e = math.min(input:size(2), t+self.half_width) 70 | 71 | self.gradInput[{i,{s,e}}]:copy(self.gt_weight[{{self.half_width-(t-s), self.half_width+e-t}}]):mul(-1) 72 | end 73 | 74 | self.gradInput:div(target:nElement()) 75 | return self.gradInput 76 | end 77 | -------------------------------------------------------------------------------- /src/cv.cpp: -------------------------------------------------------------------------------- 1 | /* This file is copied from https://github.com/jzbonter/mc-cnn */ 2 | 3 | extern "C" { 4 | #include "lua.h" 5 | #include "lualib.h" 6 | #include "lauxlib.h" 7 | } 8 | 9 | #include "luaT.h" 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | 19 | #include 20 | 21 | int warp_affine(lua_State *L) 22 | { 23 | THFloatTensor *src_ = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 24 | THFloatTensor *dst_ = (THFloatTensor*)luaT_checkudata(L, 2, "torch.FloatTensor"); 25 | THFloatTensor *mat_ = (THFloatTensor*)luaT_checkudata(L, 3, "torch.FloatTensor"); 26 | 27 | float *src = THFloatTensor_data(src_); 28 | float *dst = THFloatTensor_data(dst_); 29 | float *mat = THFloatTensor_data(mat_); 30 | 31 | int src_c = THFloatTensor_size(src_, 0); 32 | int src_h = THFloatTensor_size(src_, 1); 33 | int src_w = THFloatTensor_size(src_, 2); 34 | int dst_c = THFloatTensor_size(dst_, 0); 35 | int dst_h = THFloatTensor_size(dst_, 1); 36 | int dst_w = THFloatTensor_size(dst_, 2); 37 | assert(THFloatTensor_nElement(mat_) >= 6); 38 | 39 | CvMat warp_mat = cvMat(2, 3, CV_32FC1, mat); 40 | for (int i = 0; i < src_c; i++) { 41 | CvMat src_mat = cvMat(src_h, src_w, CV_32FC1, src + i * src_h * src_w); 42 | CvMat dst_mat = cvMat(dst_h, dst_w, CV_32FC1, dst + i * dst_h * dst_w); 43 | cvWarpAffine(&src_mat, &dst_mat, &warp_mat, CV_INTER_CUBIC + CV_WARP_FILL_OUTLIERS); 44 | } 45 | 46 | return 0; 47 | } 48 | 49 | int copy_make_border(lua_State *L) 50 | { 51 | THFloatTensor *src_ = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 52 | THFloatTensor *dst_ = (THFloatTensor*)luaT_checkudata(L, 2, "torch.FloatTensor"); 53 | 54 | int top_ = luaL_checkinteger(L, 3); 55 | int bottom_ = luaL_checkinteger(L, 4); 56 | int left_ = luaL_checkinteger(L, 5); 57 | int right_ = luaL_checkinteger(L, 6); 58 | 59 | cv::Mat src_mat = cv::Mat(2, 3, CV_32FC1, src_); 60 | cv::Mat dst_mat = cv::Mat(2, 3, CV_32FC1, dst_); 61 | copyMakeBorder(src_mat, dst_mat, top_, bottom_, left_, right_, cv::BORDER_REFLECT_101); 62 | 63 | return 0; 64 | } 65 | static const struct luaL_Reg funcs[] = { 66 | {"warp_affine", warp_affine}, 67 | {"copy_make_border", copy_make_border}, 68 | {NULL, NULL} 69 | }; 70 | 71 | extern "C" int luaopen_libcv(lua_State *L) { 72 | luaL_openlib(L, "cv", funcs, 0); 73 | return 1; 74 | } 75 | -------------------------------------------------------------------------------- /src/pipeline/post.lua: -------------------------------------------------------------------------------- 1 | local M = {} 2 | 3 | function M.process(vols, 
x_batch, disp_max, params, dataset, sm_terminate, sm_skip , directions) 4 | 5 | local vol 6 | 7 | for _, direction in ipairs(directions) do 8 | vol = vols[{{direction == -1 and 1 or 2}}] 9 | 10 | sm_active = (sm_terminate ~= 'cnn') 11 | 12 | -- cross computation 13 | local x0c, x1c 14 | --print(sm_skip) 15 | if sm_active and sm_skip ~= 'cbca' then 16 | x0c = torch.CudaTensor(1, 4, vol:size(3), vol:size(4)) 17 | x1c = torch.CudaTensor(1, 4, vol:size(3), vol:size(4)) 18 | adcensus.cross(x_batch[1], x0c, params.L1, params.tau1) 19 | adcensus.cross(x_batch[2], x1c, params.L1, params.tau1) 20 | local tmp_cbca = torch.CudaTensor(1, disp_max, vol:size(3), vol:size(4)) 21 | for i = 1,params.cbca_i1 do 22 | adcensus.cbca(x0c, x1c, vol, tmp_cbca, direction) 23 | vol:copy(tmp_cbca) 24 | end 25 | tmp_cbca = nil 26 | collectgarbage() 27 | end 28 | sm_active = sm_active and (sm_terminate ~= 'cbca1') 29 | 30 | if sm_active and sm_skip ~= 'sgm' then 31 | vol = vol:transpose(2, 3):transpose(3, 4):clone() 32 | collectgarbage() 33 | do 34 | local out = torch.CudaTensor(1, vol:size(2), vol:size(3), vol:size(4)) 35 | local tmp = torch.CudaTensor(vol:size(3), vol:size(4)) 36 | for _ = 1,params.sgm_i do 37 | out:zero() 38 | adcensus.sgm2(x_batch[1], x_batch[2], vol, out, tmp, params.pi1, params.pi2, params.tau_so, 39 | params.alpha1, params.sgm_q1, params.sgm_q2, direction) 40 | vol:copy(out):div(4) 41 | end 42 | vol:resize(1, disp_max, x_batch:size(3), x_batch:size(4)) 43 | vol:copy(out:transpose(3, 4):transpose(2, 3)):div(4) 44 | 45 | end 46 | collectgarbage() 47 | end 48 | sm_active = sm_active and (sm_terminate ~= 'sgm') 49 | 50 | if sm_active and sm_skip ~= 'cbca' then 51 | local tmp_cbca = torch.CudaTensor(1, disp_max, vol:size(3), vol:size(4)) 52 | for i = 1,params.cbca_i2 do 53 | adcensus.cbca(x0c, x1c, vol, tmp_cbca, direction) 54 | vol:copy(tmp_cbca) 55 | end 56 | end 57 | sm_active = sm_active and (sm_terminate ~= 'cbca2') 58 | vols[{{direction == -1 and 1 or 2}}] = vol 59 | end 60 | return vols 61 | end 62 | 63 | return M 64 | -------------------------------------------------------------------------------- /src/networks/mc-models/resmatch/hybrid.lua: -------------------------------------------------------------------------------- 1 | local network = require 'networks/network' 2 | require 'networks/mc-models/resmatch/acrt' 3 | require('networks/criterions/Margin2') 4 | require('networks/scores/DotProduct2') 5 | require('networks/modules/Normalize2') 6 | 7 | local ResmatchHybrid, parent = torch.class('ResmatchHybrid','ResmatchAcrt') 8 | 9 | function ResmatchHybrid:__init(self, opt, dataset) 10 | 11 | parent.__init(parent, self, opt, dataset) 12 | local bce = nn.BCECriterion2() 13 | local margin = nn.Margin2() 14 | 15 | -- parallel criterion with repeated target 16 | self.criterion = nn.ParallelCriterion(true) 17 | :add(bce, 0.8) 18 | :add(margin, 0.2) 19 | :cuda() 20 | 21 | self.params.arch= {{1,2},{1,2},{1,2},{1,2},{1,2}} 22 | end 23 | 24 | function ResmatchHybrid:getDecisionNetwork() 25 | 26 | local decision = nn.Sequential() 27 | decision:add(nn.Linear(2 * self.params.fm, self.params.nh2)) 28 | decision:add(Activation(self.alpha)) 29 | for i = 1,self.params.l2 do 30 | decision:add(nn.Linear(self.params.nh2, self.params.nh2)) 31 | decision:add(Activation()) 32 | end 33 | decision:add(nn.Linear(self.params.nh2, 1)) 34 | decision:add(nn.Sigmoid()) 35 | 36 | 37 | return nn.ConcatTable() 38 | :add(nn.Sequential() 39 | :add(nn.Reshape(self.params.bs, self.params.fm *2)) 40 | :add(decision) 41 | ) 42 | 
:add(nn.Sequential() 43 | :add(nn.Normalize2()) 44 | :add(nn.DotProduct2()) 45 | ) 46 | end 47 | 48 | function ResmatchHybrid:computeMatchingCost(x_batch, disp_max, directions) 49 | local desc_l, desc_r = self:getDescriptors(x_batch) 50 | 51 | -- Replace with fully convolutional network with the same weights 52 | local testDecision = network.getTestNetwork(self.decision) 53 | 54 | -- Initialize the output with the largest matching cost 55 | -- at each possible disparity ('1') 56 | local output = torch.CudaTensor(#directions, disp_max, desc_l:size(3), desc_l:size(4)):fill(1) -- (0 / 0) 57 | 58 | local x2= torch.CudaTensor() 59 | collectgarbage() 60 | for _, direction in ipairs(directions) do 61 | --print("calculate score in direction " .. direction) 62 | local index = direction == -1 and 1 or 2 63 | for d = 1,disp_max do 64 | collectgarbage() 65 | -- Get the left and right images for this disparity 66 | local l = desc_l[{{1},{},{},{d,-1}}] 67 | local r = desc_r[{{1},{},{},{1,-d}}] 68 | x2:resize(2, r:size(2), r:size(3), r:size(4)) 69 | x2[{{1}}]:copy(l) 70 | x2[{{2}}]:copy(r) 71 | 72 | -- Compute the matching score 73 | local score = testDecision:forward(x2)[1] 74 | 75 | -- Copy to the right place in the output tensor 76 | output[{index,d,{},direction == -1 and {d,-1} or {1,-d}}]:copy(score[{1,1}]) 77 | end 78 | -- Fix the borders of the obtained map 79 | network.fixBorder(output[{{index}}], direction, self.params.ws) 80 | end 81 | collectgarbage() 82 | return output 83 | end 84 | return ResmatchHybrid 85 | -------------------------------------------------------------------------------- /src/networks/mc-models/fast.lua: -------------------------------------------------------------------------------- 1 | require 'networks/mc-models/matching' 2 | require('networks/criterions/Margin2') 3 | require('networks/scores/DotProduct2') 4 | local network = require('networks/network') 5 | 6 | local FastNetwork, parent = torch.class('FastNetwork','MatchNet') 7 | 8 | function FastNetwork:__init(self, opt, dataset) 9 | parent.__init(parent, self, opt, dataset) 10 | self.sim_score = nn.DotProduct2():cuda() 11 | self.criterion = nn.Margin2():cuda() 12 | end 13 | 14 | function FastNetwork:getDecisionNetwork() 15 | return nn.Sequential() 16 | :add(self.sim_score) 17 | end 18 | 19 | function FastNetwork:computeMatchingCost(x_batch, disp_max, directions) 20 | local desc_l, desc_r = self:getDescriptors(x_batch) 21 | 22 | -- Initialize the output with the largest matching cost 23 | -- at each possible disparity ('1') 24 | local output = torch.CudaTensor(2, disp_max, x_batch:size(3), x_batch:size(4)):fill(1) 25 | 26 | -- Compute the matching cost at each possible disparity 27 | self.sim_score:computeMatchingCost(desc_l, desc_r, output[{{1}}], output[{{2}}]) 28 | 29 | -- Fix the borders of the obtained map 30 | network.fixBorder(output[{{1}}], -1, self.params.ws) 31 | network.fixBorder(output[{{2}}], 1, self.params.ws) 32 | 33 | return output 34 | end 35 | 36 | 37 | function FastNetwork:setBestParams( opt, dataset ) 38 | self.n_input_plane = dataset.n_colors 39 | self.params = {} 40 | self.params.ks = 3 -- convulutional kernel size 41 | self.params.bs = opt.batch_size -- batch size 42 | self.params.fm = opt.fm 43 | self.params.lr = 0.002 -- learning rate 44 | self.params.mom = 0.9 -- momentum 45 | self.params.decay=1e-4 46 | self.params.at = 0 47 | 48 | 49 | self.params.L1 = 0 50 | self.params.cbca_i1 = 0 -- number of cross-based iterations before semiglobal matching 51 | self.params.cbca_i2 = 0 -- number of cross 
based iterations after semiglobal matching 52 | self.params.tau1 = 0 53 | self.params.pi1 = 4 54 | self.params.pi2 = 55.72 55 | self.params.sgm_i = 1 56 | self.params.sgm_q1 = 3 57 | self.params.sgm_q2 = 2.5 58 | self.params.alpha1 = 1.5 59 | self.params.tau_so = 0.02 60 | self.params.blur_sigma = 7.74 61 | self.params.blur_t = 5 62 | 63 | if dataset.name == 'kitti' then 64 | elseif dataset.name == 'kitti2015' then 65 | 66 | self.params.pi1 = 2.3 67 | self.params.pi2 = 18.38 68 | self.params.sgm_i = 1 69 | self.params.sgm_q1 = 3 70 | self.params.sgm_q2 = 2 71 | self.params.alpha1 = 1.25 72 | self.params.tau_so = 0.08 73 | self.params.blur_sigma = 4.64 74 | self.params.blur_t = 5 75 | 76 | elseif dataset.name == 'mb' then 77 | 78 | self.params.L1 = 0 79 | self.params.tau1 = 0.0 80 | self.params.cbca_i1 = 0 81 | self.params.cbca_i2 = 0 82 | self.params.pi1 = 2.3 83 | self.params.pi2 = 24.3 84 | self.params.sgm_i = 1 85 | self.params.sgm_q1 = 4 86 | self.params.sgm_q2 = 2 87 | self.params.alpha1 = 1.5 88 | self.params.tau_so = 0.08 89 | self.params.blur_sigma = 6 90 | self.params.blur_t = 2 91 | end 92 | 93 | end 94 | 95 | return FastNetwork 96 | -------------------------------------------------------------------------------- /src/networks/mc-models/resmatch/components.lua: -------------------------------------------------------------------------------- 1 | 2 | local resmatch = {} 3 | 4 | -- The shortcut layer is either identity or 1x1 convolution 5 | local function shortcut(nInputPlane, nOutputPlane, stride, shortcutType) 6 | if shortcutType == 'A' or shortcutType == 'B' or shortcutType == 'C' then 7 | local useConv = shortcutType == 'C' or 8 | (shortcutType == 'B' and nInputPlane ~= nOutputPlane) 9 | if useConv then 10 | -- 1x1 convolution 11 | return nn.Sequential() 12 | :add(Convolution(nInputPlane, nOutputPlane, 1, 1)) 13 | elseif nInputPlane ~= nOutputPlane then 14 | -- Strided, zero-padded identity shortcut 15 | return nn.Sequential() 16 | :add(nn.SpatialAveragePooling(1, 1, stride, stride)) 17 | :add(nn.Concat(2) 18 | :add(nn.Identity()) 19 | :add(nn.MulConstant(0))) 20 | else 21 | return nn.Identity() 22 | end 23 | elseif shortcutType == 'D' then 24 | local m = nn.Mul() 25 | m.weight:fill(1) 26 | return nn.Sequential() 27 | :add(Convolution(nInputPlane, nOutputPlane, 1, 1)) 28 | :add(m) 29 | elseif shortcutType == 'L' then 30 | local m = nn.Mul() 31 | m.weight:fill(1) 32 | return m 33 | end 34 | end 35 | 36 | local function residualBlock(model, block, fin, fout, shortcutType) 37 | 38 | concat = nn.ConcatTable() 39 | 40 | concat:add(block(fin, fout, stride)) 41 | concat:add(shortcut(fin, fout, 1, shortcutType)) 42 | model:add(concat):add(nn.CAddTable()) 43 | end 44 | 45 | function basicBlock(fin, fout, stride) 46 | block = nn.Sequential() 47 | block:add(Convolution(fin,fout,3,3,stride,stride,1,1)) 48 | 49 | block:add(Activation()) 50 | block:add(Convolution(fout,fout,3,3,1,1,1,1)) 51 | 52 | return block 53 | end 54 | 55 | function resmatch.transition(fin, fout) 56 | local stack = nn.Sequential() 57 | -- Convolution 58 | stack:add(Convolution(fin, fout, 3, 3)) 59 | 60 | --Activation 61 | stack:add(Activation()) 62 | return stack 63 | end 64 | 65 | local function innerResStack(block, f, n, stride, shortcut, last) 66 | 67 | local stack = nn.Sequential() 68 | for k=1, n do 69 | residualBlock(stack, block, f, f, shortcut) 70 | if k < n or not last then 71 | --stack:add(Activation()) -- better results with no activation? 
72 | end 73 | end 74 | return stack 75 | end 76 | 77 | function resmatch.resStack(block, f, nOut, nIn, stride, shortcutIn, shortcutOut, last) 78 | 79 | local stack = nn.Sequential() 80 | 81 | for i=1, nOut do 82 | local innerBlock = innerResStack(block, f, nIn, 1, shortcutIn, i == nOut and last) 83 | if shortcutOut and shortcutOut ~= 'none' and nIn > 1 then 84 | stack:add(nn.ConcatTable() 85 | :add(innerBlock) 86 | :add(shortcut(f,f,1,shortcutOut))) 87 | :add(nn.CAddTable()) 88 | else 89 | stack:add(innerBlock) 90 | end 91 | if not last then 92 | --stack:add(Activation()) 93 | end 94 | end 95 | 96 | return stack 97 | end 98 | 99 | return resmatch 100 | -------------------------------------------------------------------------------- /src/networks/gdn-models/ref.lua: -------------------------------------------------------------------------------- 1 | require 'networks/gdn-models/dispnet' 2 | require 'networks/criterions/MulClassNLLCriterion' 3 | 4 | local network = require 'networks/network' 5 | 6 | local Reflective, parent = torch.class('Reflective', 'DispNet') 7 | 8 | local function createModel(opt, dataset, mcnet) 9 | return Reflective:new(opt, dataset, mcnet) 10 | end 11 | 12 | function Reflective:__init(self, opt, dataset, mcnet) 13 | parent.__init(parent, self, opt, dataset) 14 | 15 | local gtw = torch.Tensor({1,4,10,10,4,1}) 16 | local mll = nn.MulClassNLLCriterion(gtw):cuda() 17 | 18 | self.criterion = nn.ParallelCriterion():add(mll, 0.85):add(nn.BCECriterion():cuda(), 0.15) 19 | self.name = 'reflective_' .. mcnet 20 | 21 | self.t1 = 0.7 22 | self.t2 = 0.1 23 | end 24 | 25 | 26 | function Reflective:feval(x, dl_dx, inputs, targets) 27 | 28 | dl_dx:zero() 29 | 30 | local prediction = self.net:forward(inputs) 31 | 32 | local probs, conf = prediction[1]:cuda(), prediction[2] 33 | local _, pred = torch.max(probs, 2) 34 | 35 | pred = pred:cuda() 36 | 37 | local tr = torch.add(pred, -1, targets):abs() 38 | local true_conf = torch.le(tr, 1):cuda() 39 | --print(true_conf:sum() / true_conf:size(1)) 40 | 41 | local loss_x = self.criterion:forward({probs:clone(), conf:clone()}, {targets:clone(), true_conf:clone()}) 42 | 43 | local back = self.criterion:backward({probs, conf}, {targets, true_conf}) 44 | 45 | self.net:backward(inputs, back) 46 | 47 | return loss_x, dl_dx 48 | 49 | end 50 | function Reflective:build() 51 | local disp_max = self.params.disp_max 52 | local disp = nn.Sequential() 53 | disp:add(Convolution(disp_max, disp_max, self.params.ks, self.params.ks)) 54 | disp:add(Activation()) 55 | 56 | disp:add(Convolution(disp_max, disp_max * 2, self.params.ks, self.params.ks)) 57 | disp:add(Activation()) 58 | 59 | disp:add(Convolution(disp_max * 2, disp_max * 2, self.params.ks, self.params.ks)) 60 | disp:add(Activation()) 61 | 62 | disp:add(Convolution(disp_max * 2, disp_max, 3, 3)) 63 | disp:add(Activation()) 64 | 65 | disp:add(Convolution(disp_max, disp_max, 1, 1)) 66 | disp:add(Activation()) 67 | disp:add(Convolution(disp_max, disp_max, 1, 1)) 68 | disp:add(Activation()) 69 | disp:add(Convolution(disp_max, disp_max, 1, 1)) 70 | 71 | disp:add(nn.Squeeze()) 72 | 73 | disp:add( 74 | nn.ConcatTable() 75 | :add(cudnn.SpatialLogSoftMax()) 76 | :add(nn.Sequential() 77 | :add(nn.Linear(disp_max, disp_max)) 78 | :add(Activation()) 79 | :add(nn.Linear(disp_max, 1)) 80 | :add(nn.Sigmoid()) 81 | ) 82 | ) 83 | 84 | self.net = disp:cuda() 85 | self.params.ws = network.getWindowSize(self.net) 86 | network.init(self.net) 87 | end 88 | 89 | function Reflective:forward(testModel, vols) 90 | vols = 
nn.Tanh():cuda():forward(vols) 91 | vols:mul(-1):add(1) 92 | local out = network.forwardFree(testModel,vols) 93 | local vols = out[1]:cuda() 94 | local conf = out[2] 95 | return vols, conf 96 | end 97 | 98 | function Reflective:disparityImage(vols) 99 | local testModel = network.getTestNetwork(self.net) 100 | local probs, conf = self:forward(testModel, vols) 101 | local _, d1 = torch.max(probs[{{1}}], 2) 102 | local disp = {} 103 | disp[2] = d1:cuda():add(-1)-- disp is [0, .. , disp_max -1] 104 | 105 | if probs:size(1) > 1 then 106 | local _, d2 = torch.max(probs[{{2}}], 2) 107 | disp[1] = d2:cuda():add(-1)-- disp is [0, .. , disp_max -1] 108 | end 109 | return disp, probs, {conf[1], conf[2]}, {t1 = self.t1,t2 = self.t2} 110 | end 111 | 112 | return createModel 113 | -------------------------------------------------------------------------------- /src/networks/gdn-models/dispnet.lua: -------------------------------------------------------------------------------- 1 | require('paths') 2 | local network = require('../network') 3 | 4 | local DispNet = torch.class('DispNet') 5 | 6 | function DispNet:__init(self, opt, dataset) 7 | self.dataset = dataset 8 | self.path = opt.storage .. '/net/disparity/' 9 | self:setBestParams(dataset) 10 | end 11 | 12 | function DispNet:getDisparityTrainingSamples(start, size, ws) 13 | 14 | local x = torch.FloatTensor(size, self.params.disp_max, ws, ws):zero() 15 | local y = torch.FloatTensor(size,1) 16 | 17 | for i=start, start+ size -1 do 18 | local idx = self.dataset.perm_disp[i] 19 | local img = self.dataset.disp[{idx, 1}] 20 | local dim3 = self.dataset.disp[{idx, 2}] 21 | local dim4 = self.dataset.disp[{idx, 3}] 22 | local d = self.dataset.disp[{idx, 4}] 23 | 24 | idx = i-start+1 25 | 26 | width = self.dataset.metadata[{img,2}] 27 | 28 | local x2 = self.dataset.X2[self.dataset.X2_idx[img]] 29 | 30 | local l = dim3 - ((ws-1)/2) 31 | local r = dim3 + ((ws-1)/2) 32 | local b = dim4 - ((ws-1)/2) 33 | local u = dim4 + ((ws-1)/2) 34 | local l1 = 1 35 | local r1 = ws 36 | local b1 = 1 37 | local u1 = ws 38 | if l < 1 then 39 | l1 = 1 + (1-l) 40 | l = 1 41 | end 42 | if r > x2:size(2) then 43 | r1 = ws -(r - x2:size(2)) 44 | r = x2:size(2) 45 | end 46 | if b < 1 then 47 | b1 = 1 + (1-b) 48 | b = 1 49 | end 50 | if u > x2:size(3) then 51 | u1 = ws -(u-x2:size(3)) 52 | u = x2:size(3) 53 | end 54 | 55 | x[{idx, {}, {l1, r1}, {b1,u1}}] = x2[{{},{l,r}, {b,u}}] 56 | --:pow(torch.uniform(0.8, 1.2)) 57 | y[idx] = d +1-- disp is [0, .. , disp_max -1] 58 | 59 | end 60 | 61 | return x:cuda(), y:cuda() 62 | end 63 | 64 | function DispNet:save(epoch, optimState) 65 | local fname = '' 66 | if epoch == 0 then 67 | fname = ('net_%s'):format(self.name) 68 | else 69 | fname = ('debug/net_%s_%d'):format(self.name, epoch) 70 | end 71 | 72 | local modelPath = paths.concat(self.path, fname .. '_net.t7') 73 | local optimPath = paths.concat(self.path, fname .. '_optim.t7') 74 | local latestPath = paths.concat(self.path, fname .. 
'.t7') 75 | local modelFile = { 76 | net = network.clean(self.net), 77 | params = self.params, 78 | name = self.name} 79 | 80 | torch.save(modelPath, modelFile) 81 | torch.save(optimPath, optimState) 82 | torch.save(latestPath, { 83 | epoch = epoch, 84 | modelPath = modelPath, 85 | optimPath = optimPath, 86 | }) 87 | 88 | return latestPath 89 | end 90 | 91 | function DispNet:load(opt) 92 | if opt.dispnet == '' then 93 | print('===> Building new disparity network...') 94 | self:build() 95 | return nil 96 | else 97 | local checkpoint = torch.load(opt.dispnet) 98 | local model, optimState 99 | if checkpoint.modelPath and paths.filep(checkpoint.modelPath) then 100 | model = torch.load(checkpoint.modelPath) 101 | optimState = torch.load(checkpoint.optimPath) 102 | else 103 | model = checkpoint 104 | checkpoint = nil 105 | end 106 | self.net = model.net:cuda() 107 | self.params = model.params 108 | self.name = model.name 109 | 110 | print('===> Loaded network '.. self.name) 111 | return checkpoint, optimState 112 | end 113 | end 114 | 115 | function DispNet:getModelParameters() 116 | return self.net:getParameters() 117 | end 118 | 119 | function DispNet:learningRate(epoch) 120 | if epoch == 12 then 121 | self.params.lr = self.params.lr / 10 122 | end 123 | return self.params.lr 124 | end 125 | 126 | function DispNet:setBestParams(dataset) 127 | self.params = {} 128 | self.params.disp_max = dataset.disp_max 129 | self.params.fm = 96 130 | self.params.ks = 3 131 | self.params.bs = 256 132 | self.params.lr = 0.002 133 | self.params.mom = 0.9 134 | self.params.decay = 1e-4 135 | end 136 | -------------------------------------------------------------------------------- /src/networks/mc-models/matching.lua: -------------------------------------------------------------------------------- 1 | require 'paths' 2 | local network = require('networks/network') 3 | 4 | local M = {} 5 | 6 | local MatchNet = torch.class('MatchNet', M) 7 | 8 | function MatchNet:__init(self, opt, dataset) 9 | self.name = self:getName(opt) 10 | self.path = opt.storage .. '/net/mc/' 11 | self:setBestParams(opt, dataset) 12 | end 13 | 14 | function MatchNet:getName(opt) 15 | local name = opt.ds .. '_' .. opt.mc .. '_' .. opt.m .. '_'.. opt.inner .. opt.outer .. '_' .. opt.color 16 | if opt.name ~= '' then 17 | name = name .. '_' .. opt.name 18 | end 19 | if opt.subset < 1 then 20 | name = name .. '_' .. opt.subset 21 | elseif opt.all then 22 | name = name .. 
'_all' 23 | end 24 | 25 | return name 26 | end 27 | 28 | function MatchNet:build() 29 | 30 | self.description = self:getDescriptionNetwork() 31 | self.decision = self:getDecisionNetwork() 32 | self.net = nn.Sequential() 33 | :add(self.description) 34 | :add(self.decision) 35 | :cuda() 36 | self.params.ws = network.getWindowSize(self.net) 37 | end 38 | 39 | function MatchNet:getDescriptors(x_batch) 40 | 41 | -- Replace with fully convolutional network 42 | local testDesc = network.getTestNetwork(self.description) 43 | testDesc:clearState() 44 | -- compute the two image decriptors 45 | -- we compute them separatly in order to reduce the memory usage 46 | -- to reduce more memory use forward_and_free 47 | local output_l = network.forwardFree(testDesc, x_batch[{{1}}]:clone()):clone() 48 | testDesc:clearState() 49 | local output_r = network.forwardFree(testDesc, x_batch[{{2}}]:clone()):clone() 50 | testDesc:clearState() 51 | 52 | return output_l, output_r 53 | 54 | end 55 | 56 | function MatchNet:save(epoch, optimState) 57 | local fname = '' 58 | if epoch == 0 then 59 | fname = (self.name) 60 | else 61 | fname = ('debug/%s_%d'):format(self.name, epoch) 62 | end 63 | 64 | local modelPath = paths.concat(self.path, fname .. '_net.t7') 65 | local optimPath = paths.concat(self.path, fname .. '_optim.t7') 66 | local latestPath = paths.concat(self.path, fname .. '.t7') 67 | local modelFile = { 68 | description = network.clean(self.description), 69 | decision = network.clean(self.decision), 70 | params = self.params, 71 | name = self.name} 72 | 73 | torch.save(modelPath, modelFile) 74 | torch.save(optimPath, optimState) 75 | torch.save(latestPath, { 76 | epoch = epoch, 77 | modelPath = modelPath, 78 | optimPath = optimPath, 79 | }) 80 | 81 | return latestPath 82 | end 83 | 84 | function MatchNet:load(opt) 85 | if opt.mcnet == '' then 86 | self:build() 87 | return nil 88 | else 89 | local checkpoint = torch.load(opt.mcnet) 90 | local model, optimState 91 | if checkpoint.modelPath and paths.filep(checkpoint.modelPath) then 92 | model = torch.load(checkpoint.modelPath) 93 | optimState = torch.load(checkpoint.optimPath) 94 | else 95 | model = checkpoint 96 | checkpoint = nil 97 | end 98 | 99 | self.description = model.description 100 | self.decision = model.decision 101 | self.net = nn.Sequential() 102 | :add(self.description) 103 | :add(self.decision) 104 | :cuda() 105 | 106 | self.params = model.params 107 | self.name = model.name 108 | 109 | return checkpoint, optimState 110 | end 111 | end 112 | 113 | function MatchNet:learningRate(epoch) 114 | if epoch == 12 then 115 | self.params.lr = self.params.lr / 10 116 | end 117 | return self.params.lr 118 | end 119 | 120 | function MatchNet:getModelParameters() 121 | return self.net:getParameters() 122 | end 123 | 124 | function MatchNet:feval(x, dl_dx, inputs, targets) 125 | 126 | dl_dx:zero() 127 | 128 | local prediction = self.net:forward(inputs) 129 | 130 | local loss_x = self.criterion:forward(prediction, targets) 131 | 132 | self.net:backward(inputs, 133 | self.criterion:backward(prediction, targets)) 134 | 135 | return loss_x, dl_dx 136 | 137 | end 138 | 139 | return MatchNet 140 | -------------------------------------------------------------------------------- /src/networks/mc-models/acrt.lua: -------------------------------------------------------------------------------- 1 | 2 | local network = require('networks/network') 3 | local MatchNet = require 'networks/mc-models/matching' 4 | local AcrtNetwork, parent = torch.class('AcrtNetwork','MatchNet') 5 
| 6 | function AcrtNetwork:__init(self, opt, dataset) 7 | parent.__init(parent, self, opt, dataset) 8 | self.criterion = nn.BCECriterion2():cuda() 9 | end 10 | 11 | function AcrtNetwork:getDecisionNetwork() 12 | local decision = nn.Sequential() 13 | decision:add(nn.Reshape(self.params.bs, self.params.fm *2)) 14 | for i = 1,self.params.l2 do 15 | decision:add(nn.Linear(i == 1 and 2 * self.params.fm or self.params.nh2, self.params.nh2)) 16 | decision:add(Activation()) 17 | end 18 | 19 | decision:add(nn.Linear(self.params.nh2, 1)) 20 | decision:add(cudnn.Sigmoid(false)) 21 | return decision 22 | end 23 | 24 | function AcrtNetwork:computeMatchingCost(x_batch, disp_max, directions) 25 | local desc_l, desc_r = self:getDescriptors(x_batch) 26 | 27 | -- Replace with fully convolutional network with the same weights 28 | local testDecision = network.getTestNetwork(self.decision) 29 | 30 | -- Initialize the output with the largest matching cost 31 | -- at each possible disparity ('1') 32 | local output = torch.CudaTensor(#directions, disp_max, desc_l:size(3), desc_l:size(4)):fill(1) 33 | 34 | local x2= torch.CudaTensor() 35 | collectgarbage() 36 | for _, direction in ipairs(directions) do 37 | local index = direction == -1 and 1 or 2 38 | for d = 1,disp_max do 39 | collectgarbage() 40 | -- Get the left and right images for this disparity 41 | local l = desc_l[{{1},{},{},{d,-1}}] 42 | local r = desc_r[{{1},{},{},{1,-d}}] 43 | x2:resize(2, r:size(2), r:size(3), r:size(4)) 44 | x2[{{1}}]:copy(l) 45 | x2[{{2}}]:copy(r) 46 | 47 | -- Compute the matching score 48 | local score = testDecision:forward(x2) 49 | 50 | -- Copy to the right place in the output tensor 51 | output[{index,d,{},direction == -1 and {d,-1} or {1,-d}}]:copy(score[{1,1}]) 52 | end 53 | -- Fix the borders of the obtained map 54 | network.fixBorder(output[{{index}}], direction, self.params.ws) 55 | end 56 | collectgarbage() 57 | return output 58 | end 59 | 60 | function AcrtNetwork:setBestParams(opt, dataset ) 61 | self.params = {} 62 | self.n_input_plane = dataset.n_colors 63 | if dataset.name == 'kitti' or dataset.name == 'kitti2015' then 64 | self.params.at=0 65 | 66 | self.params.fm = opt.fm -- number of feature maps 67 | self.params.ks=3 68 | self.params.l1=4 -- number of convolutional layers 69 | self.params.l2=4 -- number of fully connected layers 70 | self.params.nh2= opt.nh2 71 | self.params.bs = opt.batch_size -- batch size 72 | self.params.lr= 0.003 73 | self.params.mom=0.9 74 | self.params.decay=1e-4 75 | 76 | if dataset.name == 'kitti' then 77 | self.params.L1=5 78 | self.params.cbca_i1=2 79 | self.params.cbca_i2=0 80 | self.params.tau1=0.13 81 | self.params.pi1=1.32 82 | self.params.pi2=24.25 83 | self.params.sgm_i=1 84 | self.params.sgm_q1=3 85 | self.params.sgm_q2=2 86 | self.params.alpha1=2 87 | self.params.tau_so=0.08 88 | self.params.blur_sigma=5.99 89 | self.params.blur_t=6 90 | elseif dataset.name == 'kitti2015' then 91 | self.params.L1=5 92 | self.params.cbca_i1=2 93 | self.params.cbca_i2=4 94 | self.params.tau1=0.03 95 | self.params.pi1=2.3 96 | self.params.pi2=24.25 97 | self.params.sgm_i=1 98 | self.params.sgm_q1=3 99 | self.params.sgm_q2=2 100 | self.params.alpha1=1.75 101 | self.params.tau_so=0.08 102 | self.params.blur_sigma=5.99 103 | self.params.blur_t=5 104 | end 105 | elseif dataset.name == 'mb' then 106 | 107 | self.params.l1=5 108 | self.params.fm=112 109 | self.params.ks=3 110 | self.params.l2= opt.l2 or 3 111 | self.params.nh2=384 112 | self.params.lr=0.003 113 | self.params.bs=128 114 | 
self.params.mom=0.9 115 | 116 | self.params.L1=14 117 | self.params.tau1=0.02 118 | self.params.cbca_i1=2 119 | self.params.cbca_i2=16 120 | self.params.pi1=1.3 121 | self.params.pi2=13.9 122 | self.params.sgm_i=1 123 | self.params.sgm_q1=4.5 124 | self.params.sgm_q2=2 125 | self.params.alpha1=2.75 126 | self.params.tau_so=0.13 127 | self.params.blur_sigma=1.67 128 | self.params.blur_t=2 129 | end 130 | end 131 | -------------------------------------------------------------------------------- /src/datasets/kitti.lua: -------------------------------------------------------------------------------- 1 | Dataset = require('datasets/dataset') 2 | 3 | local KittiDataset, parent = torch.class('KittiDataset', 'Dataset') 4 | 5 | local function createDataset(opt) 6 | return KittiDataset:new(opt) 7 | end 8 | 9 | function KittiDataset:__init(self, opt) 10 | self.name = 'kitti' 11 | self.kittiDir = opt.storage .. '/data.kitti.' .. opt.color 12 | self.kitti2015Dir = opt.storage .. '/data.kitti2015.' .. opt.color 13 | self.dir = self.kittiDir 14 | parent.__init(parent, self, opt) 15 | end 16 | 17 | function KittiDataset:setParams() 18 | -- parameters for training 19 | self.true1 = 1 20 | self.false1 = 4 21 | self.false2 = 10 22 | -- parameters for image transformations 23 | self.hflip = 0 24 | self.vflip = 0 25 | self.rotate = 7 26 | self.hscale = 0.9 27 | self.scale = 1 28 | self.trans = 0 29 | self.hshear = 0.1 30 | self.brightness = 0.7 31 | self.contrast = 1.3 32 | self.d_vtrans = 0 33 | self.d_rotate = 0 34 | self.d_hscale = 1 35 | self.d_hshear = 0 36 | self.d_brightness = 0.3 37 | self.d_contrast = 1 38 | 39 | --parameters for the network 40 | self.height = 350 41 | self.width = 1242 42 | self.disp_max = 228 43 | self.n_te = 195 44 | self.n_tr = 194 45 | self.err_at = 3 46 | 47 | end 48 | 49 | function KittiDataset:load(opt) 50 | if not opt.mix then 51 | self:load_data() 52 | else 53 | function load(fname) 54 | local X_12 = torch.load(self.kittiDir .. '/' .. fname) 55 | local X_15 = torch.load(self.kitti2015Dir .. '/' .. fname) 56 | local X = torch.cat(X_12[{{1,194}}], X_15[{{1,200}}], 1) 57 | X = torch.cat(X, dataset == 'kitti' and X_12[{{195,389}}] or X_15[{{200,400}}], 1) 58 | return X 59 | end 60 | 61 | self.X0 = load('x0.t7') 62 | self.X1 = load('x1.t7') 63 | self.metadata = load('metadata.t7') 64 | 65 | self.dispnoc = torch.cat(torch.load(opt.storage .. self.kittiDir .. '/dispnoc.t7'), torch.load(opt.storage .. self.kitti2015Dir .. '/dispnoc.t7'), 1) 66 | self.tr = torch.cat(torch.load(opt.storage .. self.kittiDir .. '/tr.t7'), torch.load(opt.storage .. self.kitti2015Dir .. '/tr.t7'):add(194)) 67 | self.te = self.name == 'kitti' and torch.load(self.kittiDir .. '/te.t7') or torch.load(self.kitti2015Dir .. '/te.t7'):add(194) 68 | function load_nnz(fname) 69 | local X_12 = torch.load(opt.storage .. self.kittiDir .. '/' .. fname) 70 | local X_15 = torch.load(opt.storage .. self.kitti2015Dir .. '/' .. 
fname) 71 | X_15[{{},1}]:add(194) 72 | return torch.cat(X_12, X_15, 1) 73 | end 74 | 75 | self.nnz_tr = load_nnz('nnz_tr.t7') 76 | self.nnz_te = load_nnz('nnz_te.t7') 77 | end 78 | end 79 | 80 | function KittiDataset:load_data() 81 | 82 | self.X0 = torch.load(('%s/x0.t7'):format(self.dir)) 83 | self.X1 = torch.load(('%s/x1.t7'):format(self.dir)) 84 | self.dispnoc = torch.load(('%s/dispnoc.t7'):format(self.dir)) 85 | self.metadata = torch.load(('%s/metadata.t7'):format(self.dir)) 86 | self.tr_disp = torch.load(('%s/tr_disp.t7'):format(self.dir)) 87 | self.tr = torch.load(('%s/tr.t7'):format(self.dir)) 88 | self.te = torch.load(('%s/te.t7'):format(self.dir)) 89 | self.nnz_disp = torch.load(('%s/nnz_disp.t7'):format(self.dir)) 90 | self.nnz_tr = torch.load(('%s/nnz_tr.t7'):format(self.dir)) 91 | self.nnz_te = torch.load(('%s/nnz_te.t7'):format(self.dir)) 92 | end 93 | 94 | function KittiDataset:subset(ds, tr, subset) 95 | local tr_subset = Dataset.sample(tr, subset) 96 | local nnz_tr_output = torch.FloatTensor(ds:size()):zero() 97 | local t = adcensus.subset_dataset(tr_subset, ds, nnz_tr_output); 98 | 99 | return nnz_tr_output[{{1,t}}] 100 | end 101 | 102 | 103 | function KittiDataset:getTestSample(i, submit) 104 | local img = {} 105 | 106 | img.height = self.metadata[{i,1}] 107 | img.width = self.metadata[{i,2}] 108 | 109 | img.id = self.metadata[{i,3}] 110 | if not submit then 111 | img.dispnoc = self.dispnoc[{i,{},{},{1,img.width}}]:cuda() 112 | end 113 | x0 = self.X0[{{i},{},{},{1,img.width}}] 114 | x1 = self.X1[{{i},{},{},{1,img.width}}] 115 | 116 | img.x_batch = torch.CudaTensor(2, self.n_colors, self.height, self.width) 117 | img.x_batch:resize(2, self.n_colors, x0:size(3), x0:size(4)) 118 | img.x_batch[1]:copy(x0) 119 | img.x_batch[2]:copy(x1) 120 | 121 | return img 122 | end 123 | 124 | function KittiDataset:getSubmissionRange() 125 | return torch.totable(torch.range(self.X0:size(1) - self.n_te + 1, self.X0:size(1))) 126 | 127 | end 128 | 129 | function KittiDataset:getTestRange() 130 | return torch.totable(self.te) 131 | end 132 | 133 | function KittiDataset:getLR(img) 134 | local x0 = self.X0[img] 135 | local x1 = self.X1[img] 136 | return x0, x1 137 | end 138 | 139 | return createDataset 140 | -------------------------------------------------------------------------------- /src/curesmatch.cu: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include "lua.h" 3 | #include "lualib.h" 4 | #include "lauxlib.h" 5 | } 6 | 7 | #include "luaT.h" 8 | #include "THC.h" 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #define TB 128 18 | 19 | THCState* getCutorchState(lua_State* L) 20 | { 21 | lua_getglobal(L, "cutorch"); 22 | lua_getfield(L, -1, "getState"); 23 | lua_call(L, 0, 1); 24 | THCState *state = (THCState*) lua_touserdata(L, -1); 25 | lua_pop(L, 2); 26 | return state; 27 | } 28 | 29 | void checkCudaError(lua_State *L) { 30 | cudaError_t status = cudaPeekAtLastError(); 31 | if (status != cudaSuccess) { 32 | luaL_error(L, cudaGetErrorString(status)); 33 | } 34 | } 35 | 36 | __global__ void outlier_detection(float *d0, float *d1, float *outlier, int size, int dim3, float *conf1, float *conf2, int disp_max, float t1, float t2) 37 | { 38 | int id = blockIdx.x * blockDim.x + threadIdx.x; 39 | if (id < size) { 40 | int x = id % dim3; 41 | int d0i = d0[id]; 42 | if (x - d0i < 0) { 43 | //assert(0); 44 | outlier[id] = 1; 45 | } else if ((abs(d0[id] - d1[id - d0i]) < 1.1) 46 | || (conf1[id] > 
t1 47 | && (conf1[id] - conf2[id- d0i] > t2) 48 | )){ 49 | outlier[id] = 0; /* match */ 50 | } else { 51 | outlier[id] = 1; /* occlusion */ 52 | for (int d = 0; d < disp_max; d++) { 53 | if (x - d >= 0 && abs(d - d1[id - d]) < 1.1) { 54 | outlier[id] = 2; /* mismatch */ 55 | break; 56 | } 57 | } 58 | } 59 | } 60 | } 61 | 62 | int outlier_detection(lua_State *L) 63 | { 64 | THCState *state = getCutorchState(L); 65 | THCudaTensor *d0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 66 | THCudaTensor *d1 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 67 | THCudaTensor *outlier = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 68 | int disp_max = luaL_checkinteger(L, 4); 69 | THCudaTensor *conf1 = (THCudaTensor*)luaT_checkudata(L, 5, "torch.CudaTensor"); 70 | THCudaTensor *conf2 = (THCudaTensor*)luaT_checkudata(L, 6, "torch.CudaTensor"); 71 | float t1 = luaL_checknumber(L, 7); 72 | float t2 = luaL_checknumber(L, 8); 73 | 74 | outlier_detection<<<(THCudaTensor_nElement(state, d0) - 1) / TB + 1, TB>>>( 75 | THCudaTensor_data(state, d0), 76 | THCudaTensor_data(state, d1), 77 | THCudaTensor_data(state, outlier), 78 | THCudaTensor_nElement(state, d0), 79 | THCudaTensor_size(state, d0, 3), 80 | THCudaTensor_data(state, conf1), 81 | THCudaTensor_data(state, conf2), 82 | disp_max, t1, t2); 83 | checkCudaError(L); 84 | return 0; 85 | } 86 | 87 | __global__ void L2dist_(float *input_L, float *input_R, float *output_L, float *output_R, int size1_input, int size1, int size3, int size23) 88 | { 89 | int id = blockIdx.x * blockDim.x + threadIdx.x; 90 | if (id < size23) { 91 | int dim3 = id % size3; 92 | assert(size1_input <= 512); 93 | float L_cache[512]; 94 | for (int i = 0; i < size1_input; i++) { 95 | L_cache[i] = input_L[i * size23 + id]; 96 | } 97 | 98 | for (int d = 0; d < size1; d++) { 99 | if (dim3 - d >= 0) { 100 | float sum = 0; 101 | float diff = 0; 102 | for (int i = 0; i < size1_input; i++) { 103 | diff = L_cache[i] - input_R[i * size23 + id - d]; 104 | sum += diff*diff; 105 | } 106 | sum = sqrt(sum); 107 | output_L[d * size23 + id] = sum; 108 | output_R[d * size23 + id - d] = sum; 109 | } 110 | } 111 | } 112 | } 113 | 114 | int L2dist(lua_State *L) 115 | { 116 | THCState *state = getCutorchState(L); 117 | THCudaTensor *input_L = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 118 | THCudaTensor *input_R = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 119 | THCudaTensor *output_L = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 120 | THCudaTensor *output_R = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 121 | int size23 = THCudaTensor_size(state, output_L, 2) * THCudaTensor_size(state, output_L, 3); 122 | L2dist_<<<(size23 - 1) / TB + 1, TB>>>( 123 | THCudaTensor_data(state, input_L), 124 | THCudaTensor_data(state, input_R), 125 | THCudaTensor_data(state, output_L), 126 | THCudaTensor_data(state, output_R), 127 | THCudaTensor_size(state, input_L, 1), 128 | THCudaTensor_size(state, output_L, 1), 129 | THCudaTensor_size(state, output_L, 3), 130 | size23); 131 | checkCudaError(L); 132 | return 0; 133 | } 134 | 135 | static const struct luaL_Reg funcs[] = { 136 | {"outlier_detection", outlier_detection}, 137 | {"L2dist", L2dist}, 138 | {NULL, NULL} 139 | }; 140 | 141 | extern "C" int luaopen_libcuresmatch(lua_State *L) { 142 | srand(42); 143 | luaL_openlib(L, "curesmatch", funcs, 0); 144 | return 1; 145 | } 146 | -------------------------------------------------------------------------------- 
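The `outlier_detection` kernel above encodes the confidence-aware left-right consistency check: a pixel is kept as a match (label 0) when its left and right disparities agree within 1.1 pixels, or when its confidence exceeds `t1` and beats the corresponding right-view confidence by more than `t2`; otherwise it is labeled an occlusion (1), or a mismatch (2) if some other disparity would have been consistent. Below is a minimal Lua-side sketch of calling the binding; the tensor names and shapes are illustrative assumptions, and the thresholds correspond to the `t1`/`t2` values that `runner.lua` hands to the refinement step.

```lua
require 'cutorch'
require '../libcuresmatch'  -- loaded as in src/main.lua (run from the src directory)

-- Assumed inputs: d0, d1 are 1 x 1 x H x W left/right disparity maps
-- (CudaTensor) and conf1, conf2 are confidence maps of the same shape.
local outlier = d0:clone():zero()  -- filled with 0 = match, 1 = occlusion, 2 = mismatch
curesmatch.outlier_detection(d0, d1, outlier, disp_max, conf1, conf2, t1, t2)
```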
/scripts/preprocess_kitti.lua: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env luajit 2 | -- This file is copied from https://github.com/jzbonter/mc-cnn 3 | 4 | require 'image' 5 | require 'nn' 6 | require 'cutorch' 7 | require 'libadcensus' 8 | require 'os' 9 | 10 | cmd = torch.CmdLine() 11 | cmd:option('-color', 'rgb') 12 | cmd:option('-storage', 'storage') 13 | opt = cmd:parse(arg) 14 | 15 | for _, dataset in ipairs({2012, 2015}) do 16 | print(('dataset %d'):format(dataset)) 17 | 18 | torch.manualSeed(42) 19 | if dataset == 2012 then 20 | n_disp = 40 21 | n_tr = 194 22 | n_te = 195 23 | data = opt.storage .. '/data.kitti' 24 | path = data .. '.' .. opt.color 25 | image_0 = opt.colors == 1 and 'image_0' or 'colored_0' 26 | image_1 = opt.colors == 1 and 'image_1' or 'colored_1' 27 | disp_noc = 'disp_noc' 28 | nchannel = opt.color == 'rgb' and 3 or 1 29 | elseif dataset == 2015 then 30 | n_tr = 200 31 | n_disp = 40 32 | n_te = 200 33 | data = opt.storage .. '/data.kitti2015' 34 | path = data .. '.' .. opt.color 35 | image_0 = 'image_2' 36 | image_1 = 'image_3' 37 | nchannel = opt.color == 'rgb' and 3 or 1 38 | disp_noc = 'disp_noc_0' 39 | end 40 | 41 | os.execute('mkdir -p ' .. path) 42 | 43 | height = 350 44 | width = 1242 45 | 46 | x0 = torch.FloatTensor(n_tr + n_te, nchannel, height, width):zero() 47 | x1 = torch.FloatTensor(n_tr + n_te, nchannel, height, width):zero() 48 | dispnoc = torch.FloatTensor(n_tr, 1, height, width):zero() 49 | metadata = torch.IntTensor(n_tr + n_te, 3):zero() 50 | 51 | examples = {} 52 | for i = 1,n_tr do 53 | examples[#examples + 1] = {dir='training', cnt=i} 54 | end 55 | 56 | for i = 1,n_te do 57 | examples[#examples + 1] = {dir='testing', cnt=i} 58 | end 59 | 60 | for i, arg in ipairs(examples) do 61 | img_path = '%s/unzip/%s/%s/%06d_10.png' 62 | img_0 = image.loadPNG(img_path:format(data, arg['dir'], image_0, arg['cnt'] - 1), nchannel, 'byte'):float() 63 | img_1 = image.loadPNG(img_path:format(data, arg['dir'], image_1, arg['cnt'] - 1), nchannel, 'byte'):float() 64 | 65 | if opt.colors == 1 and dataset == 2015 then 66 | img_0 = image.rgb2y(img_0) 67 | img_1 = image.rgb2y(img_1) 68 | end 69 | 70 | -- crop 71 | img_height = img_0:size(2) 72 | img_width = img_0:size(3) 73 | img_0 = img_0:narrow(2, img_height - height + 1, height) 74 | img_1 = img_1:narrow(2, img_height - height + 1, height) 75 | 76 | -- preprocess 77 | print(i) 78 | 79 | img_0:add(-img_0:mean()):div(img_0:std()) 80 | img_1:add(-img_1:mean()):div(img_1:std()) 81 | 82 | x0[{i,{},{},{1,img_width}}]:copy(img_0) 83 | x1[{i,{},{},{1,img_width}}]:copy(img_1) 84 | 85 | if arg['dir'] == 'training' then 86 | img_disp = torch.FloatTensor(1, img_height, img_width) 87 | adcensus.readPNG16(img_disp, ('%s/unzip/training/%s/%06d_10.png'):format(data, disp_noc, arg['cnt'] - 1)) 88 | dispnoc[{i, 1}]:narrow(2, 1, img_width):copy(img_disp:narrow(2, img_height - height + 1, height)) 89 | end 90 | 91 | metadata[{i, 1}] = img_height 92 | metadata[{i, 2}] = img_width 93 | metadata[{i, 3}] = arg['cnt'] - 1 94 | 95 | collectgarbage() 96 | end 97 | 98 | -- split train and test 99 | perm = torch.randperm(n_tr):long() 100 | te = perm[{{1,40}}]:clone() 101 | tr_disp = perm[{{41,41 + n_disp-1}}]:clone() 102 | tr = perm[{{41+ n_disp, n_tr}}]:clone() 103 | 104 | -- prepare tr dataset 105 | nnz_disp = torch.FloatTensor(23e6, 4) 106 | nnz_tr = torch.FloatTensor(23e6, 4) 107 | nnz_te = torch.FloatTensor(23e6, 4) 108 | nnz_disp_t = 0 109 | nnz_tr_t = 0 110 | nnz_te_t = 0 111 | 
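-- The loop below prepares the KITTI training pools: each training image's
-- ground-truth disparity is cleaned on the GPU (remove_nonvisible,
-- remove_occluded, remove_white) and its remaining pixels are appended to one
-- of three sample sets -- nnz_te for the 40 validation images (te), nnz_disp
-- for the n_disp images reserved for the disparity network (tr_disp), and
-- nnz_tr for the rest, used to train the matching cost network.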
for i = 1,n_tr do 112 | local disp = dispnoc[{{i}}]:cuda() 113 | adcensus.remove_nonvisible(disp) 114 | adcensus.remove_occluded(disp) 115 | adcensus.remove_white(x0[{{i}}]:cuda(), disp) 116 | disp = disp:float() 117 | 118 | is_te = false 119 | for j = 1,te:nElement() do 120 | if i == te[j] then 121 | is_te = true 122 | end 123 | end 124 | 125 | is_disp = false 126 | for j = 1,tr_disp:nElement() do 127 | if i == tr_disp[j] then 128 | is_disp = true 129 | end 130 | end 131 | if is_te then 132 | nnz_te_t = adcensus.make_dataset2(disp, nnz_te, i, nnz_te_t) 133 | elseif is_disp then 134 | nnz_disp_t = adcensus.make_dataset2(disp, nnz_disp, i, nnz_disp_t) 135 | else 136 | nnz_tr_t = adcensus.make_dataset2(disp, nnz_tr, i, nnz_tr_t) 137 | end 138 | end 139 | nnz_disp = torch.FloatTensor(nnz_disp_t, 4):copy(nnz_disp[{{1,nnz_disp_t}}]) 140 | nnz_tr = torch.FloatTensor(nnz_tr_t, 4):copy(nnz_tr[{{1,nnz_tr_t}}]) 141 | nnz_te = torch.FloatTensor(nnz_te_t, 4):copy(nnz_te[{{1,nnz_te_t}}]) 142 | 143 | 144 | torch.save(('%s/x0.t7'):format(path), x0) 145 | torch.save(('%s/x1.t7'):format(path), x1) 146 | torch.save(('%s/dispnoc.t7'):format(path), dispnoc) 147 | torch.save(('%s/metadata.t7'):format(path), metadata) 148 | torch.save(('%s/tr_disp.t7'):format(path), tr_disp) 149 | torch.save(('%s/tr.t7'):format(path), tr) 150 | torch.save(('%s/te.t7'):format(path), te) 151 | torch.save(('%s/nnz_disp.t7'):format(path), nnz_disp) 152 | torch.save(('%s/nnz_tr.t7'):format(path), nnz_tr) 153 | torch.save(('%s/nnz_te.t7'):format(path), nnz_te) 154 | end 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Improved Stereo Matching with Constant Highway Networks and Reflective Loss 2 | =================================================================================== 3 | 4 | This implements the full pipeline of our paper [Improved Stereo Matching with Constant Highway Networks and Reflective Loss](https://arxiv.org/abs/1701.00165) by Amit Shaked and Lior Wolf 5 | 6 | The repository contains 7 | 8 | - Training of the Constant Highway Network to compute the matching cost 9 | - A few post processing steps taken from [MC-CNN](https://github.com/jzbontar/mc-cnn) 10 | - Training of the Global Disparity Network with the Reflective Loss 11 | - A confidence based outlier detection and interpolation 12 | 13 | 14 | ## Requirements 15 | - Install [Torch](http://torch.ch/docs/getting-started.html) on a machine with CUDA GPU 16 | - Install [cuDNN v4 or v5](https://developer.nvidia.com/cudnn) and the Torch [cuDNN bindings](https://github.com/soumith/cudnn.torch/tree/R4) 17 | If you already have Torch installed, update `nn`, `cunn`, and `cudnn`. 18 | - Install [OpenCV 2.4](http://opencv.org/) and [png++](http://www.nongnu.org/pngpp/) 19 | - A NVIDIA GPU with at least 6 GB of memory is required to run on the KITTI 20 | data set and 12 GB to run on the Middlebury data set. 21 | 22 | The code is released under the BSD 2-Clause license. 23 | Please cite our [paper](https://arxiv.org/abs/1701.00165 ) 24 | if you use code from this repository in your work. 
25 | 26 | @article{shaked2016stereo, 27 | title={Improved Stereo Matching with Constant Highway Networks and Reflective Loss}, 28 | author={Shaked, Amit and Wolf, Lior}, 29 | journal={arXiv preprint arXiv:1701.00165}, 30 | year={2016} 31 | } 32 | 33 | Setup 34 | ------------------------ 35 | Create a directory for the data to be stored and link it under the name "storage" in the same directory as this README: 36 | ```bash 37 | ln -s [your_dir] storage 38 | ``` 39 | Or simply create a `storage` directory: 40 | ```bash 41 | mkdir storage 42 | ``` 43 | 44 | Run the mkdirs script: 45 | ```bash 46 | scripts/mkdirs.sh 47 | ``` 48 | 49 | Compile the shared libraries: 50 | ```bash 51 | make 52 | ``` 53 | 54 | The command should produce the files `libadcensus.so`, `libcv.so`, and `libcuresmatch.so` in the lib dir. 55 | 56 | 57 | ### KITTI 58 | 59 | 60 | - Download the [KITTI 2012](http://www.cvlibs.net/download.php?file=data_stereo_flow.zip) data set and unzip it 61 | into `storage/data.kitti/unzip` (you should end up with a file `storage/data.kitti/unzip/training/image_0/000000_10.png`). 62 | - Download the [KITTI 2015](http://www.cvlibs.net/download.php?file=data_scene_flow.zip) data set and unzip it 63 | into `storage/data.kitti2015/unzip` (you should end up with a file `storage/data.kitti2015/unzip/training/image_2/000000_10.png`). 64 | 65 | 66 | Run the preprocessing script: 67 | ```bash 68 | scripts/preprocess_kitti.lua -color rgb -storage storage 69 | ``` 70 | 71 | It should output: 72 | ```bash 73 | dataset 2012 74 | 1 75 | ... 76 | 389 77 | dataset 2015 78 | 1 79 | ... 80 | 400 81 | ``` 82 | 83 | ### Middlebury 84 | Run `download_middlebury.sh` to download the training data 85 | (this can take a long time, depending on your internet connection). 86 | ```bash 87 | scripts/download_middlebury.sh 88 | ``` 89 | 90 | The data set is downloaded into the `data.mb/unzip` directory. 91 | 92 | Compile the [MiddEval3-SDK](http://vision.middlebury.edu/stereo/submit3/). You 93 | should end up with the `computemask` binary in one of the directories listed in 94 | your `PATH` environment variable. 95 | 96 | Install [ImageMagick](http://www.imagemagick.org/script/index.php); the 97 | preprocessing step requires the `convert` binary to resize the images. 98 | 99 | Run the preprocessing script: 100 | ```bash 101 | mkdir storage/data.mb.imperfect_gray 102 | scripts/preprocess_mb.py imperfect gray 103 | ``` 104 | 105 | It should output: 106 | ```bash 107 | Adirondack 108 | Backpack 109 | ... 110 | testH/Staircase 111 | ``` 112 | 113 | The preprocessing is slow (it takes around 30 minutes) the first time it is 114 | run, because the images have to be resized. 115 | 116 | 117 | Usage 118 | --------------------- 119 | Enter the `src` directory. 120 | The `main.lua` file contains different training and testing options: 121 | 122 | - 'a' is the action; it can be 'train\_mcn' to train the matching cost network, 'train\_gdn' to train the global disparity network, 'test' to check the pipeline on the validation set, and 'submit' to create the submission file for the online evaluation servers 123 | - 'ds' is the dataset (kitti, kitti2015 or mb) 124 | - 'mc' is the matching cost architecture to use 125 | - 'm' is the mode ('fast', 'acrt' or 'hybrid' for the hybrid loss) 126 | - 'gdn' is the global disparity network architecture. Use 'ref' for reflective. 127 | Don't use this option when training the matching cost network. 128 | - 'all' is to train on both training and validation data. 
129 | When choosing this option the gdn will be trained automatically and the submission file will be created. 130 | 131 | See `opts.lua` for other options. 132 | 133 | ### Training 134 | 135 | Try training the hybrid Resmatch matching cost network: 136 | ```bash 137 | th main.lua -ds kitti -a train_mcn -mc resmatch -m hybrid 138 | ``` 139 | 140 | Then train the gdn with the reflective loss, using this matching cost network: 141 | ```bash 142 | th main.lua -ds kitti -a train_gdn -mc resmatch -m hybrid -mcnet ../storage/net/mc/kitti_resmatch_hybrid_LL_rgb.t7 -gdn ref 143 | ``` 144 | 145 | You can also try training the fast Resmatch architecture on 0.2 of the data and test it every 3 epochs: 146 | 147 | ```bash 148 | th main.lua -ds kitti -a train_mcn -mc resmatch -m fast -debug -times 3 -subset 0.2 149 | ``` 150 | 151 | -------------------------------------------------------------------------------- /src/networks/network.lua: -------------------------------------------------------------------------------- 1 | require('networks/modules/Concatenation') 2 | require('networks/modules/SpatialConvolution1_fw') 3 | 4 | local network = {} 5 | local function deepCopy(tbl) 6 | -- creates a copy of a network with new modules and the same tensors 7 | local copy = {} 8 | for k, v in pairs(tbl) do 9 | if type(v) == 'table' then 10 | copy[k] = deepCopy(v) 11 | else 12 | copy[k] = v 13 | end 14 | end 15 | if torch.typename(tbl) then 16 | torch.setmetatable(copy, torch.typename(tbl)) 17 | end 18 | return copy 19 | end 20 | 21 | function network.clean(model) 22 | return deepCopy(model):float():clearState() 23 | end 24 | 25 | 26 | local function convInit(model, name) 27 | for k,v in pairs(model:findModules(name)) do 28 | local n = v.kW*v.kH*v.nOutputPlane 29 | v.weight:normal(0,math.sqrt(2/n)) 30 | if cudnn.version >= 4000 then 31 | v.bias = nil 32 | v.gradBias = nil 33 | else 34 | v.bias:zero() 35 | end 36 | end 37 | end 38 | 39 | local function bNInit(model, name) 40 | for k,v in pairs(model:findModules(name)) do 41 | v.weight:fill(1) 42 | v.bias:zero() 43 | end 44 | end 45 | 46 | local function linearInit(model, name) 47 | for k,v in pairs(model:findModules(name)) do 48 | v.bias:zero() 49 | end 50 | end 51 | 52 | function network.getWindowSize(net, ws) 53 | ws = ws or 1 54 | 55 | for i = 1,#net.modules do 56 | local module = net:get(i) 57 | if torch.typename(module) == 'cudnn.SpatialConvolution' then 58 | ws = ws + module.kW - 1 - module.padW - module.padH 59 | end 60 | if module.modules then 61 | ws = network.getWindowSize(module, ws) 62 | end 63 | end 64 | return ws 65 | end 66 | 67 | function network.init(net) 68 | convInit(net, 'cudnn.SpatialConvolution') 69 | convInit(net, 'nn.SpatialConvolution') 70 | bNInit(net, 'cudnn.SpatialBatchNormalization') 71 | bNInit(net, 'nn.SpatialBatchNormalization') 72 | linearInit(net, 'nn.Linear') 73 | end 74 | 75 | function network.fixBorder(vol, direction, ws) 76 | local n = (ws - 1) / 2 77 | for i=1,n do 78 | vol[{{},{},{},direction * i}]:copy(vol[{{},{},{},direction * (n + 1)}]) 79 | end 80 | end 81 | 82 | local function padConvs(module) 83 | -- Pads the convolutional layers to maintain the image resolution 84 | for i = 1,#module.modules do 85 | local m = module:get(i) 86 | if torch.typename(m) == 'cudnn.SpatialConvolution' then 87 | m.dW = 1 88 | m.dH = 1 89 | if m.kW > 1 then 90 | m.padW = (m.kW - 1) / 2 91 | end 92 | if m.kH > 1 then 93 | m.padH = (m.kH - 1) / 2 94 | end 95 | elseif m.modules then 96 | padConvs(m) 97 | end 98 | end 99 | end 100 | 101 | 
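-- Worked example (a sketch, not from the original file): if the description
-- network is built from four unpadded, stride-1 3x3 convolutions (the KITTI
-- setting uses params.l1 = 4 convolutional layers with ks = 3), getWindowSize
-- above returns ws = 1 + 4 * (3 - 1) = 9, so each descriptor depends on a
-- 9 x 9 image patch and fixBorder copies (ws - 1) / 2 = 4 boundary columns on
-- each side of the cost volume, as in the mc-models' call
--   network.fixBorder(output[{{index}}], direction, self.params.ws)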
function network.getTestNetwork(model) 102 | -- Replace the model with fully-convolutional network 103 | -- with the same weights, and pad it to maintain resolution 104 | 105 | local testModel = model:clone('weight', 'bias') 106 | 107 | -- replace linear with 1X1 conv 108 | local nodes, containers = testModel:findModules('nn.Linear') 109 | for i = 1, #nodes do 110 | for j = 1, #(containers[i].modules) do 111 | if containers[i].modules[j] == nodes[i] then 112 | 113 | local w = nodes[i].weight 114 | local b = nodes[i].bias 115 | local conv = nn.SpatialConvolution1_fw(w:size(2), w:size(1)):cuda() 116 | conv.weight:copy(w) 117 | conv.bias:copy(b) 118 | -- Replace with a new instance 119 | containers[i].modules[j] = conv 120 | end 121 | end 122 | end 123 | 124 | -- replace reshape with concatenation 125 | nodes, containers = testModel:findModules('nn.Reshape') 126 | for i = 1, #nodes do 127 | for j = 1, #(containers[i].modules) do 128 | if containers[i].modules[j] == nodes[i] then 129 | -- Replace with a new instance 130 | containers[i].modules[j] = nn.Concatenation():cuda() 131 | end 132 | end 133 | end 134 | 135 | -- pad convolutions 136 | padConvs(testModel) 137 | 138 | -- switch to evalutation mode 139 | testModel:evaluate() 140 | 141 | 142 | return testModel 143 | end 144 | 145 | function network.forwardFree(net, input) 146 | -- Forwards the network w.r.t input module by module 147 | -- while cleaning previous modules state 148 | local currentOutput = input 149 | for i=1, #net.modules do 150 | local m = net.modules[i] 151 | local nextOutput 152 | if torch.typename(m) == 'nn.Sequential' then 153 | nextOutput = network.forwardFree(m, currentOutput) 154 | currentOutput = nextOutput:clone() 155 | elseif torch.typename(m) == 'nn.ConcatTable' or torch.typename(m) == 'nn.ParallelTable' then 156 | nextOutput = m:forward(currentOutput) 157 | currentOutput = {} 158 | currentOutput[1] = nextOutput[1]:clone() 159 | currentOutput[2] = nextOutput[2]:clone() 160 | else 161 | nextOutput = m:updateOutput(currentOutput) 162 | currentOutput = nextOutput:clone() 163 | end 164 | m:apply( 165 | function(mod) 166 | mod:clearState() 167 | end 168 | ) 169 | 170 | collectgarbage() 171 | end 172 | 173 | return currentOutput 174 | end 175 | 176 | function network.sliceInput(input) 177 | local sizes = torch.LongStorage{input:size(1) / 2, input:size(2), input:size(3), input:size(4)} 178 | local strides = torch.LongStorage{input:stride(1) * 2, input:stride(2), input:stride(3), input:stride(4)} 179 | 180 | local input_L = torch.CudaTensor(input:storage(), 1, sizes, strides) 181 | local input_R = torch.CudaTensor(input:storage(), input:stride(1) + 1, sizes, strides) 182 | 183 | return input_L, input_R 184 | end 185 | 186 | 187 | Normalization = nn.Normalize2 188 | Activation = cudnn.ReLU 189 | Convolution = cudnn.SpatialConvolution 190 | Avg = cudnn.SpatialAveragePooling 191 | Max = nn.SpatialMaxPooling 192 | 193 | return network 194 | -------------------------------------------------------------------------------- /src/main.lua: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/lua luajit 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'cudnn' 5 | require 'nn' 6 | require 'torch' 7 | require 'image' 8 | require 'paths' 9 | 10 | require '../libadcensus' 11 | require '../libcv' 12 | require '../libcuresmatch' 13 | 14 | 15 | -- Initialize components 16 | local opts = require 'opts' 17 | local opt = opts.parse(arg) 18 | local log = require('logger')(opt) 19 | local 
dataset = require('datasets/' .. opt.ds)(opt) 20 | local mcnet = require('networks/mc-models/'..opt.mc.. 21 | '/'..opt.m):new(opt, dataset) 22 | local runner = require('runner')(mcnet, nil, dataset, opt) 23 | local Trainer = require('trainer') 24 | 25 | torch.manualSeed(opt.seed) 26 | cutorch.manualSeed(opt.seed) 27 | cutorch.setDevice(tonumber(opt.gpu)) 28 | 29 | -- Set the actions to do, optins are: 30 | -- train_mcn - train the matching cost network 31 | -- train_gdn - train the global disparity network 32 | -- test - test the pipeline with validation data 33 | -- submit - create the submission file for the dataset's 34 | -- online evaluation server 35 | pipeline = {} 36 | pipeline[opt.a] = true 37 | if opt.a == 'train_mcn' then 38 | if opt.all then 39 | pipeline['train_gdn'] = true 40 | pipeline['submit'] = true 41 | end 42 | elseif opt.a == 'train_gdn' then 43 | pipeline['test'] = true 44 | if opt.all then 45 | pipeline['submit'] = true 46 | end 47 | end 48 | 49 | -- Load last checkpoint if exists 50 | print('===> Loading matching cost network...') 51 | local checkpoint, optimState = mcnet:load(opt) 52 | print('===> Loaded! Network: ' .. mcnet.name) 53 | 54 | -- Training the matching cost network 55 | if pipeline['train_mcn'] then 56 | 57 | local start_epoch = checkpoint and checkpoint.epoch +1 or opt.start_epoch 58 | 59 | -- Initialize new trainer for the MCN 60 | local trainer = Trainer(mcnet, dataset.nnz:size(1), mcnet.params.bs/2, optimState) 61 | 62 | -- The function the trainer uses to get the next batch 63 | local function trainingBatch(start, size, ws) 64 | return dataset:trainingSamples(start,size,ws) 65 | end 66 | 67 | print('===> training matching cost network') 68 | for epoch = start_epoch, opt.epochs do 69 | dataset:shuffle() -- to get random order of samples 70 | 71 | -- Train one epoch of all the samples 72 | local err_tr = trainer:train(epoch, trainingBatch) 73 | 74 | -- Output results 75 | local msg = ('train epoch %g\t err %g\tlr %g\n') 76 | :format(epoch, err_tr, trainer.optimState.learningRate) 77 | log:write(msg) 78 | 79 | -- Save the current checkpoint 80 | mcnet:save(epoch, optimState) 81 | 82 | -- Run validation if wanted 83 | local validate = ((opt.debug and epoch % opt.times == 0) 84 | or (epoch >= opt.after)) and epoch < opt.epochs 85 | if validate then 86 | print('===> testing...') 87 | local err_te = runner:test(dataset:getTestRange(), false, false) 88 | 89 | -- Output validation results 90 | log:write(('test epoch: %g\terror: %g\n'):format(epoch, err_te)) 91 | end 92 | end 93 | 94 | -- After training is completed test and save the final model 95 | mcnet:save(0, optimState) 96 | local err_te = runner:test(dataset:getTestRange(), true, opt.make_cache) 97 | log:write(err_te) 98 | end 99 | 100 | -- Train the global disparity network 101 | local dnet 102 | if opt.gdn ~= '' then 103 | 104 | dnet = require('networks/gdn-models/' .. opt.gdn)(opt, dataset, mcnet.name) 105 | runner:setGdn(dnet) 106 | 107 | if pipeline['train_gdn'] then 108 | 109 | -- Load disparity data 110 | local ok = dataset:loadDispData(mcnet.name) 111 | 112 | -- If non exists create and save it 113 | if not ok then 114 | print('===> Creating training data for network ' .. 
mcnet.name) 115 | local samples, indexes = runner:createDispData() 116 | dataset:saveDispData(samples, indexes, mcnet.name) 117 | end 118 | 119 | -- Load last checkpoint if exists 120 | checkpoint, optimState = dnet:load(opt) 121 | local start_epoch = checkpoint and checkpoint.epoch +1 or opt.start_epoch 122 | 123 | 124 | -- Initialize new trainer for the GDN 125 | local trainer = Trainer(dnet, dataset.disp:size(1), dnet.params.bs/2, optimState) 126 | 127 | -- The function the trainer uses to get the next batch 128 | local function getTrainBatch(start, size, ws) 129 | return dnet:getDisparityTrainingSamples(start,size,ws) 130 | end 131 | 132 | print('===> training disparity network...') 133 | print('===> Starting from epoch ' .. start_epoch) 134 | for epoch = start_epoch, opt.epochs do 135 | dataset:shuffle() -- samples in random order 136 | 137 | local err_tr = trainer:train(epoch, getTrainBatch) 138 | 139 | -- Output results 140 | local msg = ('train epoch %g\t err %g\tlr %g\n'):format(epoch, err_tr, trainer.optimState.learningRate) 141 | log:write(msg) 142 | 143 | -- Save the current checkpoint 144 | dnet:save(epoch, optimState) 145 | 146 | -- Run validation if wanted 147 | local validate = ((opt.debug and epoch % opt.times == 0) or (epoch >= opt.after)) and epoch < opt.epochs 148 | if validate then 149 | print('===> testing...') 150 | local err_te = runner:test(dataset:getTestRange(), false, false) 151 | 152 | -- Output validation results 153 | log:write(('test: %g\t%g\n'):format(epoch, err_te)) 154 | end 155 | end 156 | 157 | -- After training is completed test and save the final model 158 | dnet:save(0, optimState) 159 | local err_te = runner:test(dataset:getTestRange(), true, false) 160 | log:write(err_te) 161 | else 162 | dnet:load(opt) 163 | end 164 | end 165 | 166 | -- Run validation 167 | if pipeline['test'] then 168 | local err_te = runner:test(dataset:getTestRange(), true, opt.make_cache) 169 | log:write(err_te) 170 | end 171 | 172 | -- Submit results 173 | if pipeline['submit'] then 174 | submission_range = dataset:getSubmissionRange() 175 | runner:submit(submission_range) 176 | end 177 | -------------------------------------------------------------------------------- /src/datasets/mb.lua: -------------------------------------------------------------------------------- 1 | Dataset = require('datasets/dataset') 2 | 3 | local MbDataset, parent = torch.class('MbDataset', 'Dataset') 4 | 5 | function MbDataset:__init(self, opt) 6 | parent.__init(parent, self, opt) 7 | self.name='mb' 8 | end 9 | 10 | local function createDataset(opt) 11 | return MbDataset:new(opt) 12 | end 13 | 14 | function MbDataset:setParams(opt) --parameters for training 15 | self.true1 = 0.5 16 | self.false1 = 1.5 17 | self.false2 = 6 18 | 19 | -- parameters for image transformations 20 | self.hflip = 0 21 | self.vflip= 0 22 | self.rotate = 28 23 | self.hscale = 0.8 24 | self.scale = 0.8 25 | self.trans = 0 26 | self.hshear = 0.1 27 | self.brightness = 1.3 28 | self.contrast = 1.1 29 | self.d_vtrans = 1 30 | self.d_rotate = 3 31 | self.d_hscale = 0.9 32 | self.d_hshear = 0.3 33 | self.d_brightness = 0.7 34 | self.d_contrast = 1.1 35 | 36 | self.d_light=0.2 37 | self.d_exp=0.2 38 | --parameters for the network 39 | self.rect = opt.rect 40 | self.n_colors = opt.color == 'rgb' and 3 or 1 41 | self.color = opt.color 42 | 43 | self.height = 1500 44 | self.width = 1000 45 | self.disp_max = 200 46 | 47 | self.err_at = 1 48 | 49 | 50 | end 51 | 52 | function MbDataset:load(opt) 53 | local data_dir = 
('%s/data.mb.%s_%s'):format(opt.storage, self.rect, self.color) 54 | self.te = fromfile(('%s/te.bin'):format(data_dir)) 55 | self.metadata = fromfile(('%s/meta.bin'):format(data_dir)) 56 | self.nnz_tr = fromfile(('%s/nnz_tr.bin'):format(data_dir)) 57 | self.nnz_te = fromfile(('%s/nnz_te.bin'):format(data_dir)) 58 | self.fname_submit = {} 59 | for line in io.open(('%s/fname_submit.txt'):format(data_dir), 'r'):lines() do 60 | table.insert(self.fname_submit, line) 61 | end 62 | self.X = {} 63 | self.dispnoc = {} 64 | local fname = "" 65 | for n = 1,self.metadata:size(1) do 66 | local XX = {} 67 | local light = 1 68 | while true do 69 | fname = ('%s/x_%d_%d.bin'):format(data_dir, n, light) 70 | if not paths.filep(fname) then 71 | break 72 | end 73 | table.insert(XX, fromfile(fname)) 74 | light = light + 1 75 | if opt.a == 'test_te' or opt.a == 'submit' then 76 | break -- we don't need to load training data 77 | end 78 | end 79 | table.insert(self.X, XX) 80 | 81 | fname = ('%s/dispnoc%d.bin'):format(data_dir, n) 82 | if paths.filep(fname) then 83 | table.insert(self.dispnoc, fromfile(fname)) 84 | end 85 | end 86 | end 87 | 88 | function MbDataset:subset(ds,tr, subset) 89 | local tr_2014 = Dataset.sample(torch.range(11, 23):long(), subset) 90 | local tr_2006 = Dataset.sample(torch.range(24, 44):long(), subset) 91 | local tr_2005 = Dataset.sample(torch.range(45, 50):long(), subset) 92 | local tr_2003 = Dataset.sample(torch.range(51, 52):long(), subset) 93 | local tr_2001 = Dataset.sample(torch.range(53, 60):long(), subset) 94 | 95 | local tr_subset = torch.cat(tr_2014, tr_2006) 96 | tr_subset = torch.cat(tr_subset, tr_2005) 97 | tr_subset = torch.cat(tr_subset, tr_2003) 98 | tr_subset = torch.cat(tr_subset, tr_2001) 99 | 100 | local nnz_tr_output = torch.FloatTensor(ds:size()):zero() 101 | local t = adcensus.subset_dataset(tr_subset, ds, nnz_tr_output); 102 | return nnz_tr_output[{{1,t}}] 103 | 104 | end 105 | 106 | 107 | function MbDataset:getSubmissionRange() 108 | local examples = {} 109 | -- for i = #X - 14, #X do 110 | for i = #self.X - 29, #self.X do 111 | table.insert(examples, {i, 2}) 112 | end 113 | return examples 114 | end 115 | 116 | function MbDataset:getTestTange() 117 | local examples = {} 118 | for i = 1,self.te:nElement() do 119 | table.insert(examples, {self.te[i], 2}) 120 | end 121 | table.insert(examples, {5, 3}) 122 | table.insert(examples, {5, 4}) 123 | return examples 124 | end 125 | 126 | function MbDataset:getTestSample(i) 127 | local img = {} 128 | 129 | local i, right = table.unpack(i) 130 | img.id = ('%d_%d'):format(i, right) 131 | img.disp_max = self.metadata[{i,3}] 132 | local x0 = self.X[i][1][{{1}}] 133 | local x1 = self.X[i][1][{{right}}] 134 | img.x_batch = torch.CudaTensor(2, self.n_colors, self.height, self.width) 135 | img.x_batch:resize(2, self.n_colors, x0:size(3), x0:size(4)) 136 | --print(img.x_batch:size(), x0:size()) 137 | img.x_batch[1]:copy(x0) 138 | img.x_batch[2]:copy(x1) 139 | 140 | img.dispnoc = self.dispnoc[i]:cuda() 141 | return img 142 | end 143 | 144 | function MbDataset:prepareTrainingData(subset, action, all) 145 | -- subset training dataset 146 | if subset < 1 then 147 | self.nnz = self:subset(self.nnz_tr, self.tr, subset) 148 | elseif all then 149 | self.nnz = torch.cat(self.nnz_tr, self.nnz_te, 1) 150 | else 151 | self.nnz = self.nnz_tr 152 | end 153 | self.disp = self.nnz 154 | self.nnz_disp = self.nnz_tr 155 | end 156 | 157 | function MbDataset:getLR(img) 158 | local x0, x1 159 | local light = (torch.random() % (#self.X[img] - 1)) + 2 
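-- Augmentation note (descriptive only): self.X[img][1] holds the test-time
-- image pair, so the line above draws a random ambient lighting index >= 2;
-- below, a random exposure is drawn for the left view, and with probability
-- d_exp the right view is re-drawn with a possibly different exposure, while
-- with probability d_light it uses the adjacent (one lower) lighting index,
-- so the two patches of a training pair may differ in illumination.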
160 | local exp = (torch.random() % self.X[img][light]:size(1)) + 1 161 | local light_ = light 162 | local exp_ = exp 163 | if torch.uniform() < self.d_exp then 164 | exp_ = (torch.random() % self.X[img][light]:size(1)) + 1 165 | end 166 | if torch.uniform() < self.d_light then 167 | light_ = math.max(2, light - 1) 168 | end 169 | x0 = self.X[img][light][{exp,1}] 170 | x1 = self.X[img][light_][{exp_,2}] 171 | return x0, x1 172 | end 173 | 174 | function fromfile(fname) 175 | -- initialize a tensor of the proper type and dimensions from the file fname 176 | local file = io.open(fname .. '.dim') 177 | local dim = {} 178 | for line in file:lines() do 179 | table.insert(dim, tonumber(line)) 180 | end 181 | if #dim == 1 and dim[1] == 0 then 182 | return torch.Tensor() 183 | end 184 | 185 | local file = io.open(fname .. '.type') 186 | local type = file:read('*all') 187 | 188 | local x 189 | if type == 'float32' then 190 | x = torch.FloatTensor(torch.FloatStorage(fname)) 191 | elseif type == 'int32' then 192 | x = torch.IntTensor(torch.IntStorage(fname)) 193 | elseif type == 'int64' then 194 | x = torch.LongTensor(torch.LongStorage(fname)) 195 | else 196 | print(fname, type) 197 | assert(false) 198 | end 199 | 200 | x = x:reshape(torch.LongStorage(dim)) 201 | return x 202 | end 203 | 204 | return createDataset 205 | -------------------------------------------------------------------------------- /src/datasets/dataset.lua: -------------------------------------------------------------------------------- 1 | local M = {} 2 | 3 | local Dataset = torch.class('Dataset', M) 4 | 5 | function Dataset:__init(self, opt) 6 | torch.manualSeed(opt.seed) 7 | cutorch.manualSeed(opt.seed) 8 | self.__index = self 9 | self.n_colors = opt.color == 'rgb' and 3 or 1 10 | self:setParams(opt) 11 | self:load(opt) 12 | if opt.a == 'train_mcn' or opt.a == 'train_gdn' then 13 | self:prepareTrainingData(opt.subset, opt.all) 14 | end 15 | end 16 | 17 | 18 | function Dataset:obfuscationParams() 19 | assert(self.hscale <= 1 and self.scale <= 1) 20 | 21 | local params = {} 22 | params.x0 = {} 23 | params.x1 = {} 24 | local s = torch.uniform(self.scale, 1) 25 | params.x0.scale = {s * torch.uniform(self.hscale, 1), s} 26 | if self.hflip == 1 and torch.uniform() < 0.5 then 27 | params.x0.scale[1] = -params.x0.scale[1] 28 | end 29 | if self.vflip == 1 and torch.uniform() < 0.5 then 30 | params.x0.scale[2] = -params.x0.scale[2] 31 | end 32 | params.x0.hshear = torch.uniform(-self.hshear, self.hshear) 33 | params.x0.trans = {torch.uniform(-self.trans, self.trans), torch.uniform(-self.trans, self.trans)} 34 | params.x0.phi = torch.uniform(-self.rotate * math.pi / 180, self.rotate * math.pi / 180) 35 | params.x0.brightness = torch.uniform(-self.brightness, self.brightness) 36 | 37 | assert(self.contrast >= 1 and self.d_contrast >= 1) 38 | params.x0.contrast = torch.uniform(1 / self.contrast, self.contrast) 39 | 40 | params.x1.scale = {params.x0.scale[1] * torch.uniform(self.d_hscale, 1), params.x0.scale[2]} 41 | params.x1.hshear = params.x0.hshear + torch.uniform(-self.d_hshear, self.d_hshear) 42 | params.x1.trans = {params.x0.trans[1], params.x0.trans[2] + torch.uniform(-self.d_vtrans, self.d_vtrans)} 43 | params.x1.phi = params.x0.phi + torch.uniform(-self.d_rotate * math.pi / 180, self.d_rotate * math.pi / 180) 44 | params.x1.brightness = params.x0.brightness + torch.uniform(-self.d_brightness, self.d_brightness) 45 | params.x1.contrast = params.x0.contrast * torch.uniform(1 / self.d_contrast, self.d_contrast) 46 | 47 | return params 
48 | end 49 | 50 | local function mul32(a,b) 51 | return {a[1]*b[1]+a[2]*b[4], a[1]*b[2]+a[2]*b[5], a[1]*b[3]+a[2]*b[6]+a[3], a[4]*b[1]+a[5]*b[4], a[4]*b[2]+a[5]*b[5], a[4]*b[3]+a[5]*b[6]+a[6]} 52 | end 53 | 54 | function Dataset:makePatch(src, dst, dim3, dim4, ws, params) 55 | local m = {1, 0, -dim4, 0, 1, -dim3} 56 | 57 | if params then 58 | m = mul32({1, 0, params.trans[1], 0, 1, params.trans[2]}, m) -- translate 59 | m = mul32({params.scale[1], 0, 0, 0, params.scale[2], 0}, m) -- scale 60 | local c = math.cos(params.phi) 61 | local s = math.sin(params.phi) 62 | m = mul32({c, s, 0, -s, c, 0}, m) -- rotate 63 | m = mul32({1, params.hshear, 0, 0, 1, 0}, m) -- shear 64 | end 65 | 66 | m = mul32({1, 0, (ws - 1) / 2, 0, 1, (ws - 1) / 2}, m) 67 | m = torch.FloatTensor(m) 68 | cv.warp_affine(src, dst, m) 69 | if params then 70 | dst:mul(params.contrast):add(params.brightness) 71 | end 72 | end 73 | 74 | 75 | function Dataset:prepareTrainingData(subset, all) 76 | self.nnz_tr = torch.cat(self.nnz_tr, self.nnz_disp, 1) 77 | self.tr = torch.cat(self.tr, self.tr_disp, 1) 78 | self.nnz_disp = self.nnz_tr 79 | self.tr_disp = self.tr 80 | -- subset training dataset 81 | if subset < 1 then 82 | self.nnz = self:subset(self.nnz_tr, self.tr, subset) 83 | self.disp = self:subset(self.nnz_disp, self.tr_disp, subset) 84 | elseif all then 85 | self.nnz = torch.cat(self.nnz_tr, self.nnz_te, 1) 86 | self.disp = torch.cat(self.nnz_disp, self.nnz_te, 1) 87 | else 88 | self.nnz = self.nnz_tr 89 | self.disp = self.nnz_disp 90 | end 91 | 92 | collectgarbage() 93 | end 94 | 95 | function Dataset:shuffle() 96 | self.perm = torch.randperm(self.nnz:size(1)) 97 | self.perm_disp = torch.randperm(self.disp:size(1)) 98 | end 99 | 100 | function Dataset:getDispRange(opt) 101 | -- Get disparity samples range 102 | local range 103 | if opt.all then 104 | range = torch.totable(torch.cat(self.tr_disp, self.te)) 105 | else 106 | range = torch.totable(self.tr_disp) 107 | end 108 | return range 109 | end 110 | 111 | function Dataset:saveDispData(samples, indexes, mcname) 112 | local path = ('%s/disparity/%s'):format(self.dir, mcname) 113 | torch.save(path .. '.t7', samples) 114 | torch.save(path .. '.indexes', indexes) 115 | self.X2 = samples 116 | self.X2_idx = indexes 117 | end 118 | 119 | 120 | function Dataset:loadDispData(mcname) 121 | local time = sys.clock() 122 | local path = ('%s/disparity/%s'):format(self.dir, mcname) 123 | 124 | if paths.filep(path .. '.t7') then 125 | print('===> Loading disparity training set...') 126 | self.X2 = self.X2 or torch.load(path .. '.t7') 127 | self.X2_idx = self.X2_idx or torch.load(path ..'.indexes') 128 | print(('===> Loaded! 
time=%s'):format(sys.clock() - time)) 129 | return true 130 | else 131 | print('===> No training data found for disparity network') 132 | return false 133 | end 134 | end 135 | 136 | function Dataset:trainingSamples(start, size, ws) 137 | local x = torch.FloatTensor(size * 4, self.n_colors, ws, ws) 138 | local y = torch.FloatTensor(size * 2) 139 | 140 | for i=start, start+size-1 do 141 | local idx = self.perm[i] 142 | local img = self.nnz[{idx, 1}] 143 | local dim3 = self.nnz[{idx, 2}] 144 | local dim4 = self.nnz[{idx, 3}] 145 | local d = self.nnz[{idx, 4}] 146 | 147 | local d_pos = torch.uniform(-self.true1, self.true1) 148 | local d_neg = torch.uniform(self.false1, self.false2) 149 | if torch.uniform() < 0.5 then 150 | d_neg = -d_neg 151 | end 152 | local x0, x1 = self:getLR(img) 153 | 154 | idx = i-start+1 155 | local params = self:obfuscationParams() 156 | self:makePatch(x0, x[idx * 4 - 3], dim3, dim4, ws, params.x0) 157 | self:makePatch(x1, x[idx * 4 - 2], dim3, dim4 - d + d_pos, ws, params.x1) 158 | self:makePatch(x0, x[idx * 4 - 1], dim3, dim4, ws, params.x0) 159 | self:makePatch(x1, x[idx * 4 - 0], dim3, dim4 - (d-d_neg), ws, params.x1) 160 | 161 | y[idx * 2 - 1] = 0 162 | y[idx * 2] = 1 163 | end 164 | 165 | return x:cuda(), y:cuda() 166 | end 167 | 168 | 169 | local function scale(input, size) 170 | 171 | local temp = torch.FloatTensor(input:size(1), size, size) 172 | image.scale(temp, input, 'bilinear') 173 | return temp 174 | end 175 | 176 | function Dataset.sample(xs, p) 177 | local perm = torch.randperm(xs:nElement()):long() 178 | return xs:index(1, perm[{{1, xs:size(1) * p}}]) 179 | end 180 | 181 | return M.Dataset -------------------------------------------------------------------------------- /src/runner.lua: -------------------------------------------------------------------------------- 1 | 2 | local M = {} 3 | 4 | local Runner = torch.class('Runner', M) 5 | 6 | function Runner:__init(mcnet, gdn, dataset, opt) 7 | self.matcher = require('pipeline/matching') 8 | self.disp = require('pipeline/disparity') 9 | self.post = require('pipeline/post') 10 | self.refiner = require('pipeline/refinement') 11 | self.dataset = dataset 12 | self.mcnet = mcnet 13 | self.gdn = gdn 14 | self.opt = opt 15 | self.path = ('%s/cache/%s/%s'):format(opt.storage, self.dataset.name, self.mcnet.name) 16 | end 17 | 18 | function Runner:setGdn(gdn) 19 | self.gdn = gdn 20 | end 21 | 22 | function Runner:predict(img, disp_max, directions, make_cache) 23 | 24 | local vox 25 | -- compute matching cost 26 | if self.opt.use_cache then 27 | vox = torch.load(('%s_%s.t7'):format(self.path, img.id)):cuda() 28 | else 29 | vox = self.matcher.match(self.mcnet, img.x_batch, 30 | disp_max, directions):cuda() 31 | if make_cache then 32 | torch.save(('%s_%s.t7'):format(self.path, img.id), vox) 33 | end 34 | end 35 | collectgarbage() 36 | 37 | -- post_process 38 | vox = self.post.process(vox, img.x_batch, disp_max, self.mcnet.params, self.dataset, self.opt.sm_terminate, self.opt.sm_skip, directions) 39 | 40 | -- pred after post process 41 | local vox_simple = vox:clone() 42 | 43 | -- disparity image 44 | local disp, vox, conf, t = self.disp.disparityImage(vox, self.gdn) 45 | 46 | -- refinement 47 | disp = self.refiner.refine(disp, vox_simple, self.mcnet.params, self.dataset, self.opt.sm_skip ,self.opt.sm_terminate, disp_max, conf, t.t1, t.t2) 48 | 49 | return disp[2] 50 | 51 | end 52 | 53 | function Runner:test(range, showall, make_cache) 54 | local err_sum = 0 55 | 56 | local opt = self.opt 57 | local directions = 
self.dataset.name == 'mb' and {-1} or {1, -1} 58 | 59 | for i, idx in ipairs(range) do 60 | xlua.progress(i-1, #range) 61 | local img = self.dataset:getTestSample(idx, false) 62 | local disp_max = img.disp_max or self.dataset.disp_max 63 | 64 | cutorch.synchronize() 65 | sys.tic() 66 | 67 | local pred = self:predict(img, disp_max, directions) 68 | 69 | cutorch.synchronize() 70 | local runtime = sys.toc() 71 | assert(pred:sum() == pred:sum()) 72 | 73 | local dispnoc = img.dispnoc 74 | local mask = torch.CudaTensor(dispnoc:size()):ne(dispnoc, 0) 75 | 76 | err, pred_bad, pred_good = self:calcErr(pred, dispnoc:clone(), mask, self.dataset.err_at) 77 | err_sum = err_sum + err 78 | 79 | if showall then 80 | print('\n' .. img.id, runtime, err .. '\n') 81 | end 82 | if self.opt.save_img then 83 | save_png(self.dataset, img, disp_max, pred, pred_bad, pred_good, mask) 84 | end 85 | end 86 | xlua.progress(#range, #range) 87 | return err_sum / #range 88 | end 89 | 90 | function Runner:submit(samples) 91 | os.execute('rm -rf out/*') 92 | if self.dataset.name == 'kitti2015' then 93 | os.execute('mkdir out/disp_0') 94 | end 95 | 96 | local directions = self.dataset.name == 'mb' and {-1} or {1, -1} 97 | for i, idx in ipairs(samples) do 98 | xlua.progress(i, #samples) 99 | 100 | local img = self.dataset:getTestSample(idx, true) 101 | local disp_max = img.disp_max or self.dataset.disp_max 102 | local pred = self:predict(img, disp_max, directions) 103 | 104 | if self.dataset.name == 'kitti' or self.dataset.name == 'kitti2015' then 105 | local pred_img = torch.FloatTensor(img.height, img.width):zero() 106 | pred_img:narrow(1, img.height - self.dataset.height + 1, self.dataset.height):copy(pred[{1,1}]) 107 | 108 | if self.dataset.name == 'kitti' then 109 | path = 'out' 110 | elseif self.dataset.name == 'kitti2015' then 111 | path = 'out/disp_0' 112 | end 113 | local s = ("%s/%06d_10.png"):format(path, img.id) 114 | adcensus.writePNG16(pred_img, img.height, img.width, s) 115 | elseif self.dataset.name == 'mb' then 116 | local base = 'out/' .. self.dataset.fname_submit[img.id - (#self.dataset.X - #self.dataset.fname_submit)] 117 | os.execute('mkdir -p ' .. base) 118 | adcensus.writePFM(image.vflip(pred[{1,1}]:float()), base .. '/disp0' .. opt.METHOD_NAME .. '.pfm') 119 | local f = io.open(base .. '/time' .. opt.METHOD_NAME .. 
'.txt', 'w') 120 | f:write(tostring(runtime)) 121 | f:close() 122 | end 123 | end 124 | os.execute('cd out; zip -r submission.zip .') 125 | end 126 | 127 | function Runner:createDispData() 128 | 129 | local range = self.dataset:getDispRange(self.opt) 130 | local samples = torch.FloatTensor(#range, self.dataset.disp_max, self.dataset.X0:size(3), self.dataset.X0:size(4)) 131 | local indexes = {} 132 | local directions = self.dataset.name == 'mb' and {-1} or {1, -1} 133 | for j, i in ipairs(range) do 134 | -- Get the sample to prepare 135 | local img = self.dataset:getTestSample(i) 136 | local disp_max = img.disp_max or self.dataset.disp_max 137 | 138 | -- 2 directions for left-right consistency check 139 | 140 | -- Compute the matching cost map 141 | vox = self.mcnet:computeMatchingCost(img.x_batch, self.dataset.disp_max,directions):cuda() 142 | 143 | -- Post processing 144 | vox = self.post.process(vox, img.x_batch, disp_max, self.mcnet.params, self.dataset, '','', directions) 145 | vox = nn.Tanh():cuda():forward(vox) 146 | 147 | -- Matching cost to similarity score 148 | vox:mul(-1):add(1) 149 | samples[{{j}, {}, {}, {1, vox:size(4)}}] = vox[{{1}}]:float() 150 | indexes[i] = j 151 | 152 | xlua.progress(j, #range) 153 | end 154 | return samples, indexes 155 | end 156 | function save_png(dataset, img, disp_max, pred, pred_bad, pred_good, mask) 157 | 158 | local img_pred = torch.Tensor(1, 3, pred:size(3), pred:size(4)) 159 | adcensus.grey2jet(pred:double():add(1)[{1,1}]:div(disp_max):double(), img_pred) 160 | local x0 = img.x_batch[1] 161 | if x0:size(1) == 1 then 162 | x0 = torch.repeatTensor(x0:cuda(), 3, 1, 1) 163 | end 164 | img_err = x0:mul(50):add(150):div(255) 165 | 166 | local real = torch.CudaTensor():resizeAs(img_err):copy(img_err) 167 | img_err[{1}]:add( 0.7, pred_bad) 168 | img_err[{2}]:add(-0.7, pred_bad) 169 | img_err[{3}]:add(-0.7, pred_bad) 170 | img_err[{1}]:add(-0.7, pred_good) 171 | img_err[{2}]:add( 0.7, pred_good) 172 | img_err[{3}]:add(-0.7, pred_good) 173 | 174 | local gt 175 | if dataset.name == 'kitti' or dataset.name == 'kitti2015' then 176 | gt = img.dispnoc 177 | elseif dataset.name == 'mb' then 178 | gt = img.dispnoc:resize(1, 1, pred:size(3), pred:size(4)) 179 | end 180 | local img_gt = torch.Tensor(1, 3, pred:size(3), pred:size(4)):zero() 181 | adcensus.grey2jet(gt:double():add(1)[{1}]:div(disp_max):double(), img_gt) 182 | img_gt[{1,3}]:cmul(mask:double()) 183 | 184 | image.save(('tmp/%s_%s_gt.png'):format(dataset.name, img.id), img_gt[1]) 185 | image.save(('tmp/%s_%s_real.png'):format(dataset.name, img.id), real[1]) 186 | image.save(('tmp/%s_%s_%s_pred.png'):format(dataset.name, network.name, img.id), img_pred[1]) 187 | image.save(('tmp/%s_%s_%s_err.png'):format(dataset.name, network.name, img.id), img_err[1]) 188 | end 189 | 190 | function Runner:calcErr(pred, dispnoc, mask, err_at) 191 | local pred_good = torch.CudaTensor(dispnoc:size()) 192 | local pred_bad = torch.CudaTensor(dispnoc:size()) 193 | dispnoc:add(-1, pred):abs() 194 | pred_bad:gt(dispnoc, err_at):cmul(mask) 195 | pred_good:le(dispnoc, self.dataset.err_at):cmul(mask) 196 | 197 | local err = pred_bad:sum() / mask:sum() 198 | 199 | return err, pred_bad, pred_good 200 | end 201 | 202 | return M.Runner 203 | -------------------------------------------------------------------------------- /scripts/preprocess_mb.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python2 2 | # This file is copied from https://github.com/jzbonter/mc-cnn 3 | # wget -r -np -A png,pfm,pgm,txt http://vision.middlebury.edu/stereo/data/scenes2014/datasets/ 4 | # wget -r -np -A png,pfm,pgm,txt http://vision.middlebury.edu/stereo/data/scenes2006/FullSize/ 5 | 6 | import os 7 | import re 8 | import sys 9 | import subprocess 10 | 11 | import numpy as np 12 | import cv2 13 | 14 | def load_pfm(fname, downsample): 15 | if downsample: 16 | if not os.path.isfile(fname + '.H.pfm'): 17 | x, scale = load_pfm(fname, False) 18 | x = x / 2 19 | x_ = np.zeros((x.shape[0] // 2, x.shape[1] // 2), dtype=np.float32) 20 | for i in range(0, x.shape[0], 2): 21 | for j in range(0, x.shape[1], 2): 22 | tmp = x[i:i+2,j:j+2].ravel() 23 | x_[i // 2,j // 2] = np.sort(tmp)[1] 24 | save_pfm(fname + '.H.pfm', x_, scale) 25 | return x_, scale 26 | else: 27 | fname += '.H.pfm' 28 | color = None 29 | width = None 30 | height = None 31 | scale = None 32 | endian = None 33 | 34 | file = open(fname) 35 | header = file.readline().rstrip() 36 | if header == 'PF': 37 | color = True 38 | elif header == 'Pf': 39 | color = False 40 | else: 41 | raise Exception('Not a PFM file.') 42 | 43 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) 44 | if dim_match: 45 | width, height = map(int, dim_match.groups()) 46 | else: 47 | raise Exception('Malformed PFM header.') 48 | 49 | scale = float(file.readline().rstrip()) 50 | if scale < 0: # little-endian 51 | endian = '<' 52 | scale = -scale 53 | else: 54 | endian = '>' # big-endian 55 | 56 | data = np.fromfile(file, endian + 'f') 57 | shape = (height, width, 3) if color else (height, width) 58 | return np.flipud(np.reshape(data, shape)), scale 59 | 60 | def save_pfm(fname, image, scale=1): 61 | file = open(fname, 'w') 62 | color = None 63 | 64 | if image.dtype.name != 'float32': 65 | raise Exception('Image dtype must be float32.') 66 | 67 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 68 | color = True 69 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 70 | color = False 71 | else: 72 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 73 | 74 | file.write('PF\n' if color else 'Pf\n') 75 | file.write('%d %d\n' % (image.shape[1], image.shape[0])) 76 | 77 | endian = image.dtype.byteorder 78 | 79 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 80 | scale = -scale 81 | 82 | file.write('%f\n' % scale) 83 | 84 | np.flipud(image).tofile(file) 85 | 86 | def read_im(fname, downsample): 87 | if downsample: 88 | if not os.path.isfile(fname + '.H.png'): 89 | subprocess.check_call('convert {} -resize 50% {}.H.png'.format(fname, fname).split()) 90 | fname += '.H.png' 91 | x = cv2.imread(fname).astype(np.float32) 92 | if color == 'rgb': 93 | x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB) 94 | x = x.transpose(2, 0, 1) 95 | else: 96 | x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)[None] 97 | x = (x - x.mean()) / x.std() 98 | return x[None] 99 | 100 | def tofile(fname, x): 101 | if x is None: 102 | open(fname + '.dim', 'w').write('0\n') 103 | open(fname, 'w') 104 | else: 105 | x.tofile(fname) 106 | open(fname + '.type', 'w').write(str(x.dtype)) 107 | open(fname + '.dim', 'w').write('\n'.join(map(str, x.shape))) 108 | 109 | rectification, color = sys.argv[1:] 110 | assert(rectification in set(['perfect', 'imperfect'])) 111 | assert(color in set(['gray', 'rgb'])) 112 | output_dir = 'storage/data.mb.{}_{}'.format(rectification, color) 113 | assert(os.path.isdir(output_dir)) 114 | 
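# Note on the on-disk format written by tofile() above: each tensor x is stored
# as three files -- the raw buffer in <fname>, the dtype string in <fname>.type,
# and one shape dimension per line in <fname>.dim (a lone '0' in .dim marks an
# empty tensor). src/datasets/mb.lua reads these back with fromfile() to rebuild
# the corresponding FloatTensor / IntTensor / LongTensor.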
115 | num_channels = 3 if color == 'rgb' else 1 116 | 117 | X = [] 118 | dispnoc = [] 119 | meta = [] 120 | nnz_tr = [] 121 | nnz_te = [] 122 | te = np.arange(1, 11) 123 | 124 | ### 2014 dataset ### 125 | base1 = 'storage/data.mb/unzip/vision.middlebury.edu/stereo/data/scenes2014/datasets' 126 | for dir in sorted(os.listdir(base1)): 127 | if dir.endswith('imperfect'): 128 | print(dir.split('-')[0]) 129 | 130 | base2_imperfect = os.path.join(base1, dir) 131 | base2_perfect = base2_imperfect.replace('imperfect', 'perfect') 132 | 133 | calib = open(os.path.join(base2_imperfect, 'calib.txt')).read() 134 | ndisp = int(re.search('ndisp=(.*)', calib).group(1)) / 2 135 | 136 | x0 = read_im(os.path.join(base2_imperfect, 'im0.png'), True) 137 | x1 = read_im(os.path.join(base2_imperfect, 'im1.png'), True) 138 | x1E = read_im(os.path.join(base2_imperfect, 'im1E.png'), True) 139 | x1L = read_im(os.path.join(base2_imperfect, 'im1L.png'), True) 140 | XX = [np.concatenate((x0, x1, x1E, x1L))] 141 | 142 | base3 = os.path.join(base2_perfect if rectification == 'perfect' else base2_imperfect, 'ambient') 143 | num_light = len(os.listdir(base3)) 144 | 145 | num_exp = [], [] 146 | for fname in os.listdir(base3 + '/L1'): 147 | num_exp[int(fname[2])].append(int(fname[4]) + 1) 148 | num_exp = min(max(num_exp[0]), max(num_exp[1])) 149 | rng = { 150 | 8: [1, 3, 5], 151 | 7: [1, 3, 5], 152 | 6: [0, 2, 4], 153 | 5: [0, 2, 4], 154 | 3: [0, 1, 2], 155 | 2: [0, 1], 156 | } 157 | for light in range(num_light): 158 | imgs = [] 159 | base4 = os.path.join(base3, 'L{}'.format(light + 1)) 160 | for exp in rng[num_exp]: 161 | for cam in range(2): 162 | im = read_im(base4 + '/im{}e{}.png'.format(cam, exp), True) 163 | imgs.append(im) 164 | _, _, height, width = imgs[0].shape 165 | XX.append(np.concatenate(imgs).reshape(len(imgs) // 2, 2, num_channels, height, width)) 166 | 167 | disp0, scale0 = load_pfm(os.path.join(base2_imperfect, 'disp0.pfm'), True) 168 | disp1, scale1 = load_pfm(os.path.join(base2_imperfect, 'disp1.pfm'), True) 169 | disp0y, scale0y = load_pfm(os.path.join(base2_imperfect, 'disp0y.pfm'), True) 170 | 171 | save_pfm('tmp/disp0.pfm', disp0, 1) 172 | save_pfm('tmp/disp1.pfm', disp1, 1) 173 | save_pfm('tmp/disp0y.pfm', disp0y, 1) 174 | 175 | subprocess.check_output('computemask tmp/disp0.pfm tmp/disp0y.pfm tmp/disp1.pfm -1 tmp/mask.png'.split()) 176 | 177 | mask = cv2.imread('tmp/mask.png', 0) 178 | disp0[mask != 255] = 0 179 | y, x = np.nonzero(mask == 255) 180 | 181 | X.append(XX) 182 | nnz = nnz_te if len(X) in te else nnz_tr 183 | nnz.append(np.column_stack((np.zeros_like(y) + len(X), y, x, disp0[y, x])).astype(np.float32)) 184 | dispnoc.append(disp0.astype(np.float32)) 185 | meta.append((x0.shape[2], x0.shape[3], ndisp)) 186 | 187 | print('done with 2014') 188 | print(np.vstack(nnz_tr).shape) 189 | 190 | ### 2006 & 2005 dataset ### 191 | for year in (2006, 2005): 192 | base1 = 'storage/data.mb/unzip/vision.middlebury.edu/stereo/data/scenes{}/HalfSize'.format(year) 193 | for dir in sorted(os.listdir(base1)): 194 | base2 = os.path.join(base1, dir) 195 | if not os.path.isfile(base2 + '/disp1.png'): 196 | continue 197 | 198 | print(dir) 199 | 200 | XX = [] 201 | XX.append(None) # there are no test images for this dataset 202 | for light in range(3): 203 | imgs = [] 204 | for exp in (0, 1, 2): 205 | base3 = os.path.join(base2, 'Illum{}/Exp{}'.format(light + 1, exp)) 206 | x0 = read_im(os.path.join(base3, 'view1.png'), False) 207 | x1 = read_im(os.path.join(base3, 'view5.png'), False) 208 | imgs.append(x0) 209 | 
imgs.append(x1) 210 | _, _, height, width = imgs[0].shape 211 | XX.append(np.concatenate(imgs).reshape(len(imgs) // 2, 2, num_channels, height, width)) 212 | 213 | disp0 = cv2.imread(base2 + '/disp1.png', 0).astype(np.float32) / 2 214 | disp1 = cv2.imread(base2 + '/disp5.png', 0).astype(np.float32) / 2 215 | 216 | ndisp = int(np.ceil(disp0.max())) 217 | disp0[disp0 == 0] = np.inf 218 | disp1[disp1 == 0] = np.inf 219 | 220 | save_pfm('tmp/disp0.pfm', disp0, 1) 221 | save_pfm('tmp/disp1.pfm', disp1, 1) 222 | 223 | subprocess.check_output('computemask tmp/disp0.pfm tmp/disp1.pfm -1 tmp/mask.png'.split()) 224 | 225 | mask = cv2.imread('tmp/mask.png', 0) 226 | disp0[mask != 255] = 0 227 | y, x = np.nonzero(mask == 255) 228 | 229 | X.append(XX) 230 | nnz_tr.append(np.column_stack((np.zeros_like(y) + len(X), y, x, disp0[y, x])).astype(np.float32)) 231 | dispnoc.append(disp0.astype(np.float32)) 232 | meta.append((x0.shape[2], x0.shape[3], ndisp)) 233 | print(np.vstack(nnz_tr).shape) 234 | 235 | ### 2003 dataset ### 236 | for dir in ('conesH', 'teddyH'): 237 | print(dir) 238 | base1 = 'storage/data.mb/unzip/vision.middlebury.edu/stereo/data/scenes2003/{}'.format(dir) 239 | 240 | XX = [] 241 | XX.append(None) 242 | 243 | x0 = read_im(base1 + '/im2.ppm', False) 244 | x1 = read_im(base1 + '/im6.ppm', False) 245 | _, _, height, width = x0.shape 246 | XX.append(np.concatenate((x0, x1)).reshape(1, 2, num_channels, height, width)) 247 | 248 | disp0 = cv2.imread(base1 + '/disp2.pgm', 0).astype(np.float32) / 2 249 | disp1 = cv2.imread(base1 + '/disp6.pgm', 0).astype(np.float32) / 2 250 | ndisp = int(np.ceil(disp0.max())) 251 | disp0[disp0 == 0] = np.inf 252 | disp1[disp1 == 0] = np.inf 253 | 254 | save_pfm('tmp/disp0.pfm', disp0, 1) 255 | save_pfm('tmp/disp1.pfm', disp1, 1) 256 | 257 | subprocess.check_output('computemask tmp/disp0.pfm tmp/disp1.pfm -1 tmp/mask.png'.split()) 258 | 259 | mask = cv2.imread('tmp/mask.png', 0) 260 | disp0[mask != 255] = 0 261 | y, x = np.nonzero(mask == 255) 262 | 263 | X.append(XX) 264 | nnz_tr.append(np.column_stack((np.zeros_like(y) + len(X), y, x, disp0[y, x])).astype(np.float32)) 265 | dispnoc.append(disp0.astype(np.float32)) 266 | meta.append((x0.shape[2], x0.shape[3], ndisp)) 267 | 268 | print(np.vstack(nnz_tr).shape) 269 | 270 | ### 2001 dataset ### 271 | base1 = 'storage/data.mb/unzip/vision.middlebury.edu/stereo/data/scenes2001/data' 272 | for dir in sorted(os.listdir(base1)): 273 | if dir == 'tsukuba': 274 | fname_disp0, fname_disp1, fname_x0, fname_x1 = 'truedisp.row3.col3.pgm', '', 'scene1.row3.col3.ppm', 'scene1.row3.col4.ppm' 275 | elif dir == 'map': 276 | fname_disp0, fname_disp1, fname_x0, fname_x1 = 'disp0.pgm', 'disp1.pgm', 'im0.pgm', 'im1.pgm' 277 | else: 278 | fname_disp0, fname_disp1, fname_x0, fname_x1 = 'disp2.pgm', 'disp6.pgm', 'im2.ppm', 'im6.ppm' 279 | 280 | base2 = os.path.join(base1, dir) 281 | if os.path.isfile(os.path.join(base2, fname_disp0)): 282 | print(dir) 283 | 284 | XX = [] 285 | XX.append(None) 286 | 287 | x0 = read_im(os.path.join(base2, fname_x0), False) 288 | x1 = read_im(os.path.join(base2, fname_x1), False) 289 | _, _, height, width = x0.shape 290 | XX.append(np.concatenate((x0, x1)).reshape(1, 2, num_channels, height, width)) 291 | 292 | if dir == 'tsukuba': 293 | disp0 = cv2.imread(os.path.join(base2, fname_disp0), 0).astype(np.float32) / 16 294 | mask = cv2.imread(os.path.join(base2, 'nonocc.png'), 0) 295 | else: 296 | disp0 = cv2.imread(os.path.join(base2, fname_disp0), 0).astype(np.float32) / 8 297 | disp1 = 
cv2.imread(os.path.join(base2, fname_disp1), 0).astype(np.float32) / 8 298 | 299 | save_pfm('tmp/disp0.pfm', disp0, 1) 300 | save_pfm('tmp/disp1.pfm', disp1, 1) 301 | subprocess.check_output('computemask tmp/disp0.pfm tmp/disp1.pfm -1 tmp/mask.png'.split()) 302 | 303 | mask = cv2.imread('tmp/mask.png', 0) 304 | disp0[mask != 255] = 0 305 | y, x = np.nonzero(mask == 255) 306 | 307 | X.append(XX) 308 | nnz_tr.append(np.column_stack((np.zeros_like(y) + len(X), y, x, disp0[y, x])).astype(np.float32)) 309 | dispnoc.append(disp0.astype(np.float32)) 310 | meta.append((x0.shape[2], x0.shape[3], -1)) 311 | 312 | ### test ### 313 | fname_submit = [] 314 | 315 | base1 = 'storage/data.mb/unzip/MiddEval3' 316 | for dir1 in ['trainingH', 'testH']: 317 | base2 = os.path.join(base1, dir1) 318 | for dir2 in sorted(os.listdir(base2)): 319 | base3 = os.path.join(base2, dir2) 320 | print(os.path.join(dir1, dir2)) 321 | 322 | calib = open(os.path.join(base3, 'calib.txt')).read() 323 | ndisp = int(re.search('ndisp=(.*)', calib).group(1)) 324 | 325 | x0 = read_im(os.path.join(base3, 'im0.png'), False) 326 | x1 = read_im(os.path.join(base3, 'im1.png'), False) 327 | 328 | X.append([np.concatenate((x0, x1)).astype(np.float32)]) 329 | meta.append((x0.shape[2], x0.shape[3], ndisp)) 330 | fname_submit.append(os.path.join(dir1, dir2)) 331 | 332 | meta = np.array(meta, dtype=np.int32) 333 | nnz_tr = np.vstack(nnz_tr) 334 | nnz_te = np.vstack(nnz_te) 335 | 336 | subprocess.check_call('rm -f {}/*.{{bin,dim,txt,type}} tmp/*'.format(output_dir), shell=True) 337 | for i in range(len(X)): 338 | for j in range(len(X[i])): 339 | tofile('{}/x_{}_{}.bin'.format(output_dir, i + 1, j + 1), X[i][j]) 340 | if i < len(dispnoc): 341 | tofile('{}/dispnoc{}.bin'.format(output_dir, i + 1), dispnoc[i]) 342 | tofile('{}/meta.bin'.format(output_dir), meta) 343 | tofile('{}/nnz_tr.bin'.format(output_dir), nnz_tr) 344 | tofile('{}/nnz_te.bin'.format(output_dir), nnz_te) 345 | tofile('{}/te.bin'.format(output_dir), te) 346 | open('{}/fname_submit.txt'.format(output_dir), 'w').write('\n'.join(fname_submit)) 347 | -------------------------------------------------------------------------------- /src/adcensus.cu: -------------------------------------------------------------------------------- 1 | /* This file is copied from https://github.com/jzbonter/mc-cnn */ 2 | extern "C" { 3 | #include "lua.h" 4 | #include "lualib.h" 5 | #include "lauxlib.h" 6 | } 7 | 8 | #include "luaT.h" 9 | #include "THC.h" 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #define TB 128 19 | 20 | #define DISP_MAX 256 21 | 22 | THCState* getCutorchState(lua_State* L) 23 | { 24 | lua_getglobal(L, "cutorch"); 25 | lua_getfield(L, -1, "getState"); 26 | lua_call(L, 0, 1); 27 | THCState *state = (THCState*) lua_touserdata(L, -1); 28 | lua_pop(L, 2); 29 | return state; 30 | } 31 | 32 | void checkCudaError(lua_State *L) { 33 | cudaError_t status = cudaPeekAtLastError(); 34 | if (status != cudaSuccess) { 35 | luaL_error(L, cudaGetErrorString(status)); 36 | } 37 | } 38 | 39 | #define COLOR_DIFF(x, i, j) (abs(x[i] - x[j])) 40 | 41 | THCudaTensor *new_tensor_like(THCState *state, THCudaTensor *x) 42 | { 43 | THCudaTensor *y = THCudaTensor_new(state); 44 | THCudaTensor_resizeAs(state, y, x); 45 | return y; 46 | } 47 | 48 | __device__ void sort(float *x, int n) 49 | { 50 | for (int i = 0; i < n - 1; i++) { 51 | int min = i; 52 | for (int j = i + 1; j < n; j++) { 53 | if (x[j] < x[min]) { 54 | min = j; 55 | } 56 | } 57 | float tmp = 
x[min]; 58 | x[min] = x[i]; 59 | x[i] = tmp; 60 | } 61 | } 62 | 63 | __global__ void ad(float *x0, float *x1, float *output, int size, int size2, int size3, int direction) 64 | { 65 | int id = blockIdx.x * blockDim.x + threadIdx.x; 66 | 67 | if (id < size) { 68 | int d = id; 69 | int x = d % size3; 70 | d /= size3; 71 | int y = d % size2; 72 | d /= size2; 73 | d *= direction; 74 | 75 | float dist; 76 | if (0 <= x + d && x + d < size3) { 77 | int cnt = 0; 78 | dist = 0; 79 | for (int yy = y - 4; yy <= y + 4; yy++) { 80 | for (int xx = x - 4; xx <= x + 4; xx++) { 81 | if (0 <= xx && xx < size3 && 0 <= xx + d && xx + d < size3 && 0 <= yy && yy < size2) { 82 | int ind = yy * size3 + xx; 83 | dist += abs(x0[ind] - x1[ind + d]); 84 | cnt++; 85 | } 86 | } 87 | } 88 | dist /= cnt; 89 | } else { 90 | dist = CUDART_NAN; 91 | } 92 | output[id] = dist; 93 | } 94 | } 95 | 96 | int ad(lua_State *L) 97 | { 98 | THCState *state = getCutorchState(L); 99 | THCudaTensor *x0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 100 | THCudaTensor *x1 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 101 | THCudaTensor *out = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 102 | int direction = luaL_checkinteger(L, 4); 103 | assert(direction == -1 || direction == 1); 104 | 105 | ad<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 106 | THCudaTensor_data(state, x0), 107 | THCudaTensor_data(state, x1), 108 | THCudaTensor_data(state, out), 109 | THCudaTensor_nElement(state, out), 110 | THCudaTensor_size(state, out, 2), 111 | THCudaTensor_size(state, out, 3), 112 | direction); 113 | checkCudaError(L); 114 | return 0; 115 | } 116 | 117 | 118 | __global__ void census(float *x0, float *x1, float *output, int size, int num_channels, int size2, int size3, int direction) 119 | { 120 | int id = blockIdx.x * blockDim.x + threadIdx.x; 121 | 122 | if (id < size) { 123 | int d = id; 124 | int x = d % size3; 125 | d /= size3; 126 | int y = d % size2; 127 | d /= size2; 128 | d *= direction; 129 | 130 | float dist; 131 | if (0 <= x + d && x + d < size3) { 132 | dist = 0; 133 | for (int i = 0; i < num_channels; i++) { 134 | int ind_p = (i * size2 + y) * size3 + x; 135 | for (int yy = y - 4; yy <= y + 4; yy++) { 136 | for (int xx = x - 4; xx <= x + 4; xx++) { 137 | if (0 <= xx && xx < size3 && 0 <= xx + d && xx + d < size3 && 0 <= yy && yy < size2) { 138 | int ind_q = (i * size2 + yy) * size3 + xx; 139 | if ((x0[ind_q] < x0[ind_p]) != (x1[ind_q + d] < x1[ind_p + d])) { 140 | dist++; 141 | } 142 | } else { 143 | dist++; 144 | } 145 | } 146 | } 147 | } 148 | dist /= num_channels; 149 | } else { 150 | dist = CUDART_NAN; 151 | } 152 | output[id] = dist; 153 | } 154 | } 155 | 156 | int census(lua_State *L) 157 | { 158 | THCState *state = getCutorchState(L); 159 | THCudaTensor *x0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 160 | THCudaTensor *x1 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 161 | THCudaTensor *out = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 162 | int direction = luaL_checkinteger(L, 4); 163 | assert(direction == -1 || direction == 1); 164 | 165 | census<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 166 | THCudaTensor_data(state, x0), 167 | THCudaTensor_data(state, x1), 168 | THCudaTensor_data(state, out), 169 | THCudaTensor_nElement(state, out), 170 | THCudaTensor_size(state, x0, 1), 171 | THCudaTensor_size(state, out, 2), 172 | THCudaTensor_size(state, out, 3), 173 | direction); 174 | checkCudaError(L); 175 | return 0; 176 
| } 177 | 178 | #if 0 179 | __global__ void add_vol(float *vol, float *cnt, float *out, int size, int size1, int size2, int size3, float ratio) 180 | { 181 | int id = blockIdx.x * blockDim.x + threadIdx.x; 182 | if (id < size) { 183 | int d = id; 184 | int x = d % size3; 185 | d /= size3; 186 | int y = d % size2; 187 | d /= size2; 188 | 189 | int lo = floor(d * ratio); 190 | int hi = lo + 1; 191 | float alpha = (d * ratio) - lo; 192 | assert(0 <= lo && hi < size1); 193 | 194 | float val = vol[(lo * size2 + y) * size3 + x] * (1 - alpha) + vol[(hi * size2 + y) * size3 + x] * alpha; 195 | if (!isnan(val) && cnt[id] > 0) { 196 | out[id] += val; 197 | cnt[id] += 1; 198 | } 199 | } 200 | } 201 | 202 | int add_vol(lua_State *L) 203 | { 204 | THCudaTensor *vol = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 205 | THCudaTensor *cnt = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 206 | THCudaTensor *out = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 207 | float ratio = luaL_checknumber(L, 4); 208 | 209 | add_vol<<<(THCudaTensor_nElement(out) - 1) / TB + 1, TB>>>( 210 | THCudaTensor_data(vol), 211 | THCudaTensor_data(cnt), 212 | THCudaTensor_data(out), 213 | THCudaTensor_nElement(out), 214 | THCudaTensor_size(vol, 1), 215 | THCudaTensor_size(out, 2), 216 | THCudaTensor_size(out, 3), 217 | ratio); 218 | checkCudaError(L); 219 | return 0; 220 | } 221 | 222 | __global__ void rho(float *x, int size, float lambda) 223 | { 224 | int id = blockIdx.x * blockDim.x + threadIdx.x; 225 | if (id < size) { 226 | x[id] = 1 - exp(-x[id] / lambda); 227 | } 228 | } 229 | 230 | int rho(lua_State *L) 231 | { 232 | THCudaTensor *x = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 233 | float lambda = luaL_checknumber(L, 2); 234 | 235 | rho<<<(THCudaTensor_nElement(x) - 1) / TB + 1, TB>>>( 236 | THCudaTensor_data(x), 237 | THCudaTensor_nElement(x), 238 | lambda); 239 | checkCudaError(L); 240 | return 0; 241 | } 242 | 243 | #endif 244 | 245 | __global__ void spatial_argmin(float *input, float *output, int size, int size1, int size23) 246 | { 247 | int id = blockIdx.x * blockDim.x + threadIdx.x; 248 | if (id < size) { 249 | int dim23 = id % size23; 250 | int dim0 = id / size23; 251 | 252 | int argmin = 0; 253 | float min = CUDART_INF; 254 | for (int i = 0; i < size1; i++) { 255 | float val = input[(dim0 * size1 + i) * size23 + dim23]; 256 | if (val < min) { 257 | min = val; 258 | argmin = i; 259 | } 260 | } 261 | output[id] = argmin + 1; 262 | } 263 | } 264 | 265 | int spatial_argmin(lua_State *L) 266 | { 267 | THCState *state = getCutorchState(L); 268 | THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 269 | THCudaTensor *output = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 270 | 271 | spatial_argmin<<<(THCudaTensor_nElement(state, output) - 1) / TB + 1, TB>>>( 272 | THCudaTensor_data(state, input), 273 | THCudaTensor_data(state, output), 274 | THCudaTensor_nElement(state, output), 275 | THCudaTensor_size(state, input, 1), 276 | THCudaTensor_size(state, input, 2) * THCudaTensor_size(state, output, 3)); 277 | checkCudaError(L); 278 | return 0; 279 | } 280 | 281 | __global__ void cross(float *x0, float *out, int size, int dim2, int dim3, int L1, float tau1) 282 | { 283 | int id = blockIdx.x * blockDim.x + threadIdx.x; 284 | if (id < size) { 285 | int dir = id; 286 | int x = dir % dim3; 287 | dir /= dim3; 288 | int y = dir % dim2; 289 | dir /= dim2; 290 | 291 | int dx = 0; 292 | int dy = 0; 293 | if (dir == 0) { 294 | dx = -1; 295 | } 
else if (dir == 1) { 296 | dx = 1; 297 | } else if (dir == 2) { 298 | dy = -1; 299 | } else if (dir == 3) { 300 | dy = 1; 301 | } else { 302 | assert(0); 303 | } 304 | 305 | int xx, yy, ind1, ind2, dist; 306 | ind1 = y * dim3 + x; 307 | for (xx = x + dx, yy = y + dy;;xx += dx, yy += dy) { 308 | if (xx < 0 || xx >= dim3 || yy < 0 || yy >= dim2) break; 309 | 310 | dist = max(abs(xx - x), abs(yy - y)); 311 | if (dist == 1) continue; 312 | 313 | ind2 = yy * dim3 + xx; 314 | 315 | /* rule 1 */ 316 | if (COLOR_DIFF(x0, ind1, ind2) >= tau1) break; 317 | 318 | /* rule 2 */ 319 | if (dist >= L1) break; 320 | } 321 | out[id] = dir <= 1 ? xx : yy; 322 | } 323 | } 324 | 325 | int cross(lua_State *L) 326 | { 327 | THCState *state = getCutorchState(L); 328 | THCudaTensor *x0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 329 | THCudaTensor *out = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 330 | int L1 = luaL_checkinteger(L, 3); 331 | float tau1 = luaL_checknumber(L, 4); 332 | 333 | cross<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 334 | THCudaTensor_data(state, x0), 335 | THCudaTensor_data(state, out), 336 | THCudaTensor_nElement(state, out), 337 | THCudaTensor_size(state, out, 2), 338 | THCudaTensor_size(state, out, 3), 339 | L1, tau1); 340 | checkCudaError(L); 341 | return 0; 342 | } 343 | 344 | __global__ void cbca(float *x0c, float *x1c, float *vol, float *out, int size, int dim2, int dim3, int direction) 345 | { 346 | int id = blockIdx.x * blockDim.x + threadIdx.x; 347 | if (id < size) { 348 | int d = id; 349 | int x = d % dim3; 350 | d /= dim3; 351 | int y = d % dim2; 352 | d /= dim2; 353 | 354 | if (x + d * direction < 0 || x + d * direction >= dim3) { 355 | out[id] = vol[id]; 356 | } else { 357 | float sum = 0; 358 | int cnt = 0; 359 | 360 | int yy_s = max(x0c[(2 * dim2 + y) * dim3 + x], x1c[(2 * dim2 + y) * dim3 + x + d * direction]); 361 | int yy_t = min(x0c[(3 * dim2 + y) * dim3 + x], x1c[(3 * dim2 + y) * dim3 + x + d * direction]); 362 | for (int yy = yy_s + 1; yy < yy_t; yy++) { 363 | int xx_s = max(x0c[(0 * dim2 + yy) * dim3 + x], x1c[(0 * dim2 + yy) * dim3 + x + d * direction] - d * direction); 364 | int xx_t = min(x0c[(1 * dim2 + yy) * dim3 + x], x1c[(1 * dim2 + yy) * dim3 + x + d * direction] - d * direction); 365 | for (int xx = xx_s + 1; xx < xx_t; xx++) { 366 | float val = vol[(d * dim2 + yy) * dim3 + xx]; 367 | assert(!isnan(val)); 368 | sum += val; 369 | cnt++; 370 | } 371 | } 372 | 373 | assert(cnt > 0); 374 | out[id] = sum / cnt; 375 | assert(!isnan(out[id])); 376 | } 377 | } 378 | } 379 | 380 | int cbca(lua_State *L) 381 | { 382 | THCState *state = getCutorchState(L); 383 | THCudaTensor *x0c = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 384 | THCudaTensor *x1c = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 385 | THCudaTensor *vol_in = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 386 | THCudaTensor *vol_out = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 387 | int direction = luaL_checkinteger(L, 5); 388 | 389 | assert(direction == -1 or direction == 1); 390 | cbca<<<(THCudaTensor_nElement(state, vol_out) - 1) / TB + 1, TB>>>( 391 | THCudaTensor_data(state, x0c), 392 | THCudaTensor_data(state, x1c), 393 | THCudaTensor_data(state, vol_in), 394 | THCudaTensor_data(state, vol_out), 395 | THCudaTensor_nElement(state, vol_out), 396 | THCudaTensor_size(state, vol_out, 2), 397 | THCudaTensor_size(state, vol_out, 3), 398 | direction); 399 | checkCudaError(L); 400 | return 0; 401 | } 402 | 403 | 
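/* Summary of the SGM kernels below (a rough description of what this code
 * computes, in the spirit of Hirschmüller's semi-global matching): for a
 * scanline direction r, the aggregated cost at pixel p and disparity d is
 *
 *   L_r(p, d) = C(p, d) - min_k L_r(p - r, k)
 *             + min( L_r(p - r, d),
 *                    L_r(p - r, d - 1) + P1,
 *                    L_r(p - r, d + 1) + P1,
 *                    min_k L_r(p - r, k) + P2 )
 *
 * where C is the input cost volume (`vol` / `input`). P1 and P2 are derived
 * from pi1/pi2 and are divided by sgm_q1 and/or sgm_q2 when the color
 * differences D1, D2 across the step exceed tau_so, and P1 is divided by
 * alpha1 for one of the neighbouring-disparity terms on the vertical (up/down)
 * passes. `sgm` loops over the four directions on the host and sums them into
 * `out`; `sgm2` does the same but advances one row/column per kernel launch
 * with a shared-memory min-reduction over disparities; `sgm3` stores each
 * direction in its own output plane instead of summing. */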
__global__ void sgm(float *x0, float *x1, float *vol, float *tmp, float *out, int dim1, int dim2, int dim3, float pi1, float pi2, float tau_so, float alpha1, float sgm_q1, float sgm_q2, int sgm_direction, int direction) 404 | { 405 | int x, y, dx, dy; 406 | 407 | dx = dy = 0; 408 | if (sgm_direction <= 1) { 409 | y = blockIdx.x * blockDim.x + threadIdx.x; 410 | if (y >= dim2) { 411 | return; 412 | } 413 | if (sgm_direction == 0) { 414 | x = 0; 415 | dx = 1; 416 | } else if (sgm_direction == 1) { 417 | x = dim3 - 1; 418 | dx = -1; 419 | } 420 | } else if (sgm_direction <= 3) { 421 | x = blockIdx.x * blockDim.x + threadIdx.x; 422 | if (x >= dim3) { 423 | return; 424 | } 425 | if (sgm_direction == 2) { 426 | y = 0; 427 | dy = 1; 428 | } else if (sgm_direction == 3) { 429 | y = dim2 - 1; 430 | dy = -1; 431 | } 432 | } 433 | 434 | assert(dim1 <= 400); 435 | float tmp_curr_[400]; 436 | float tmp_prev_[400]; 437 | float *tmp_curr = tmp_curr_; 438 | float *tmp_prev = tmp_prev_; 439 | 440 | float min_prev = CUDART_INF; 441 | for (; 0 <= y && y < dim2 && 0 <= x && x < dim3; x += dx, y += dy) { 442 | float min_curr = CUDART_INF; 443 | for (int d = 0; d < dim1; d++) { 444 | int ind = (d * dim2 + y) * dim3 + x; 445 | 446 | if (x + d * direction < 0 || 447 | x + d * direction >= dim3 || 448 | y - dy < 0 || 449 | y - dy >= dim2 || 450 | x + d * direction - dx < 0 || 451 | x + d * direction - dx >= dim3 || 452 | x - dx < 0 || 453 | x - dx >= dim3) { 454 | 455 | out[ind] += vol[ind]; 456 | tmp_curr[d] = vol[ind]; 457 | } else { 458 | int ind2 = y * dim3 + x; 459 | 460 | float D1 = COLOR_DIFF(x0, ind2, ind2 - dy * dim3 - dx); 461 | float D2 = COLOR_DIFF(x1, ind2 + d * direction, ind2 + d * direction - dy * dim3 - dx); 462 | float P1, P2; 463 | if (D1 < tau_so && D2 < tau_so) { 464 | P1 = pi1; 465 | P2 = (pi1 * pi2); 466 | } else if (D1 > tau_so && D2 > tau_so) { 467 | P1 = pi1 / (sgm_q1 * sgm_q2); 468 | P2 = (pi1 * pi2) / (sgm_q1 * sgm_q2); 469 | } else { 470 | P1 = pi1 / sgm_q1; 471 | P2 = (pi1 * pi2) / sgm_q1; 472 | } 473 | 474 | assert(min_prev != CUDART_INF); 475 | float cost = min(tmp_prev[d], min_prev + P2); 476 | if (d > 0) { 477 | cost = min(cost, tmp_prev[d - 1] + (sgm_direction == 2 ? P1 / alpha1 : P1)); 478 | } 479 | if (d < dim1 - 1) { 480 | cost = min(cost, tmp_prev[d + 1] + (sgm_direction == 3 ? 
P1 / alpha1 : P1)); 481 | } 482 | float val = vol[ind] + cost - min_prev; 483 | out[ind] += val; 484 | tmp_curr[d] = val; 485 | } 486 | if (tmp_curr[d] < min_curr) { 487 | min_curr = tmp_curr[d]; 488 | } 489 | } 490 | min_prev = min_curr; 491 | 492 | float *swap = tmp_curr; 493 | tmp_curr = tmp_prev; 494 | tmp_prev = swap; 495 | } 496 | } 497 | 498 | int sgm(lua_State *L) 499 | { 500 | THCState *state = getCutorchState(L); 501 | THCudaTensor *x0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 502 | THCudaTensor *x1 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 503 | THCudaTensor *vol = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 504 | THCudaTensor *tmp = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 505 | THCudaTensor *out = (THCudaTensor*)luaT_checkudata(L, 5, "torch.CudaTensor"); 506 | float pi1 = luaL_checknumber(L, 6); 507 | float pi2 = luaL_checknumber(L, 7); 508 | float tau_so = luaL_checknumber(L, 8); 509 | float alpha1 = luaL_checknumber(L, 9); 510 | float sgm_q1 = luaL_checknumber(L, 10); 511 | float sgm_q2 = luaL_checknumber(L, 11); 512 | int direction = luaL_checknumber(L, 12); 513 | 514 | int dim1 = THCudaTensor_size(state, out, 1); 515 | int dim2 = THCudaTensor_size(state, out, 2); 516 | int dim3 = THCudaTensor_size(state, out, 3); 517 | 518 | for (int sgm_direction = 0; sgm_direction < 4; sgm_direction++) { 519 | int size = sgm_direction <= 1 ? dim2 : dim3; 520 | sgm<<<(size - 1) / TB + 1, TB>>>( 521 | THCudaTensor_data(state, x0), 522 | THCudaTensor_data(state, x1), 523 | THCudaTensor_data(state, vol), 524 | THCudaTensor_data(state, tmp), 525 | THCudaTensor_data(state, out), 526 | dim1, dim2, dim3, pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, sgm_direction, direction); 527 | } 528 | checkCudaError(L); 529 | return 0; 530 | } 531 | 532 | #define INDEX(dim0, dim1, dim2, dim3) \ 533 | assert((dim1) >= 0 && (dim1) < size1 && (dim2) >= 0 && (dim2) < size2 && (dim3) >= 0 && (dim3) < size3), \ 534 | ((((dim0) * size1 + (dim1)) * size2 + (dim2)) * size3 + dim3) 535 | 536 | template <int sgm_direction> 537 | __global__ void sgm2(float *x0, float *x1, float *input, float *output, float *tmp, float pi1, float pi2, float tau_so, float alpha1, float sgm_q1, float sgm_q2, int direction, int size1, int size2, int size3, int step) 538 | { 539 | int x, y, dx, dy; 540 | int d = threadIdx.x; 541 | 542 | if (sgm_direction == 0) { 543 | /* right */ 544 | x = step; 545 | y = blockIdx.x; 546 | dx = 1; 547 | dy = 0; 548 | } else if (sgm_direction == 1) { 549 | /* left */ 550 | x = size2 - 1 - step; 551 | y = blockIdx.x; 552 | dx = -1; 553 | dy = 0; 554 | } else if (sgm_direction == 2) { 555 | /* down */ 556 | x = blockIdx.x; 557 | y = step; 558 | dx = 0; 559 | dy = 1; 560 | } else if (sgm_direction == 3) { 561 | /* up */ 562 | x = blockIdx.x; 563 | y = size1 - 1 - step; 564 | dx = 0; 565 | dy = -1; 566 | } 567 | 568 | if (y - dy < 0 || y - dy >= size1 || x - dx < 0 || x - dx >= size2) { 569 | float val = input[INDEX(0, y, x, d)]; 570 | output[INDEX(0, y, x, d)] += val; 571 | tmp[d * size2 + blockIdx.x] = val; 572 | return; 573 | } 574 | 575 | __shared__ float output_s[400], output_min[400]; 576 | 577 | output_s[d] = output_min[d] = tmp[d * size2 + blockIdx.x]; 578 | __syncthreads(); 579 | 580 | for (int i = 256; i > 0; i /= 2) { 581 | if (d < i && d + i < size3 && output_min[d + i] < output_min[d]) { 582 | output_min[d] = output_min[d + i]; 583 | } 584 | __syncthreads(); 585 | } 586 | 587 | int ind2 = y * size2 + x; 588 | float D1 = COLOR_DIFF(x0, ind2, ind2 - dy * 
size2 - dx); 589 | float D2; 590 | int xx = x + d * direction; 591 | if (xx < 0 || xx >= size2 || xx - dx < 0 || xx - dx >= size2) { 592 | D2 = 10; 593 | } else { 594 | D2 = COLOR_DIFF(x1, ind2 + d * direction, ind2 + d * direction - dy * size2 - dx); 595 | } 596 | float P1, P2; 597 | if (D1 < tau_so && D2 < tau_so) { 598 | P1 = pi1; 599 | P2 = pi2; 600 | } else if (D1 > tau_so && D2 > tau_so) { 601 | P1 = pi1 / (sgm_q1 * sgm_q2); 602 | P2 = pi2 / (sgm_q1 * sgm_q2); 603 | } else { 604 | P1 = pi1 / sgm_q1; 605 | P2 = pi2 / sgm_q1; 606 | } 607 | 608 | float cost = min(output_s[d], output_min[0] + P2); 609 | if (d - 1 >= 0) { 610 | cost = min(cost, output_s[d - 1] + (sgm_direction == 2 ? P1 / alpha1 : P1)); 611 | } 612 | if (d + 1 < size3) { 613 | cost = min(cost, output_s[d + 1] + (sgm_direction == 3 ? P1 / alpha1 : P1)); 614 | } 615 | 616 | float val = input[INDEX(0, y, x, d)] + cost - output_min[0]; 617 | output[INDEX(0, y, x, d)] += val; 618 | tmp[d * size2 + blockIdx.x] = val; 619 | } 620 | 621 | int sgm2(lua_State *L) 622 | { 623 | THCState *state = getCutorchState(L); 624 | THCudaTensor *x0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 625 | THCudaTensor *x1 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 626 | THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 627 | THCudaTensor *output = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 628 | THCudaTensor *tmp = (THCudaTensor*)luaT_checkudata(L, 5, "torch.CudaTensor"); 629 | float pi1 = luaL_checknumber(L, 6); 630 | float pi2 = luaL_checknumber(L, 7); 631 | float tau_so = luaL_checknumber(L, 8); 632 | float alpha1 = luaL_checknumber(L, 9); 633 | float sgm_q1 = luaL_checknumber(L, 10); 634 | float sgm_q2 = luaL_checknumber(L, 11); 635 | int direction = luaL_checknumber(L, 12); 636 | int size1 = THCudaTensor_size(state, output, 1) * THCudaTensor_size(state, output, 3); 637 | int size2 = THCudaTensor_size(state, output, 2) * THCudaTensor_size(state, output, 3); 638 | int disp_max = THCudaTensor_size(state, output, 3); 639 | 640 | for (int step = 0; step < THCudaTensor_size(state, input, 2); step++) { 641 | sgm2<0><<<(size1 - 1) / disp_max + 1, disp_max>>>( 642 | THCudaTensor_data(state, x0), 643 | THCudaTensor_data(state, x1), 644 | THCudaTensor_data(state, input), 645 | THCudaTensor_data(state, output), 646 | THCudaTensor_data(state, tmp), 647 | pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, direction, 648 | THCudaTensor_size(state, input, 1), 649 | THCudaTensor_size(state, input, 2), 650 | THCudaTensor_size(state, input, 3), 651 | step); 652 | } 653 | 654 | for (int step = 0; step < THCudaTensor_size(state, input, 2); step++) { 655 | sgm2<1><<<(size1 - 1) / disp_max + 1, disp_max>>>( 656 | THCudaTensor_data(state, x0), 657 | THCudaTensor_data(state, x1), 658 | THCudaTensor_data(state, input), 659 | THCudaTensor_data(state, output), 660 | THCudaTensor_data(state, tmp), 661 | pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, direction, 662 | THCudaTensor_size(state, input, 1), 663 | THCudaTensor_size(state, input, 2), 664 | THCudaTensor_size(state, input, 3), 665 | step); 666 | } 667 | 668 | for (int step = 0; step < THCudaTensor_size(state, input, 1); step++) { 669 | sgm2<2><<<(size2 - 1) / disp_max + 1, disp_max>>>( 670 | THCudaTensor_data(state, x0), 671 | THCudaTensor_data(state, x1), 672 | THCudaTensor_data(state, input), 673 | THCudaTensor_data(state, output), 674 | THCudaTensor_data(state, tmp), 675 | pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, direction, 676 | 
THCudaTensor_size(state, input, 1), 677 | THCudaTensor_size(state, input, 2), 678 | THCudaTensor_size(state, input, 3), 679 | step); 680 | } 681 | 682 | for (int step = 0; step < THCudaTensor_size(state, input, 1); step++) { 683 | sgm2<3><<<(size2 - 1) / disp_max + 1, disp_max>>>( 684 | THCudaTensor_data(state, x0), 685 | THCudaTensor_data(state, x1), 686 | THCudaTensor_data(state, input), 687 | THCudaTensor_data(state, output), 688 | THCudaTensor_data(state, tmp), 689 | pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, direction, 690 | THCudaTensor_size(state, input, 1), 691 | THCudaTensor_size(state, input, 2), 692 | THCudaTensor_size(state, input, 3), 693 | step); 694 | } 695 | 696 | checkCudaError(L); 697 | return 0; 698 | } 699 | 700 | template <int sgm_direction> 701 | __global__ void sgm3(float *x0, float *x1, float *input, float *output, float pi1, float pi2, float tau_so, float alpha1, float sgm_q1, float sgm_q2, int direction, int size1, int size2, int size3, int step) 702 | { 703 | int x, y, dx, dy; 704 | int d = threadIdx.x; 705 | 706 | if (sgm_direction == 0) { 707 | /* right */ 708 | x = step; 709 | y = blockIdx.x; 710 | dx = 1; 711 | dy = 0; 712 | } else if (sgm_direction == 1) { 713 | /* left */ 714 | x = size2 - 1 - step; 715 | y = blockIdx.x; 716 | dx = -1; 717 | dy = 0; 718 | } else if (sgm_direction == 2) { 719 | /* down */ 720 | x = blockIdx.x; 721 | y = step; 722 | dx = 0; 723 | dy = 1; 724 | } else if (sgm_direction == 3) { 725 | /* up */ 726 | x = blockIdx.x; 727 | y = size1 - 1 - step; 728 | dx = 0; 729 | dy = -1; 730 | } 731 | 732 | if (y - dy < 0 || y - dy >= size1 || x - dx < 0 || x - dx >= size2) { 733 | output[INDEX(sgm_direction, y, x, d)] = input[INDEX(0, y, x, d)]; 734 | return; 735 | } 736 | 737 | __shared__ float output_s[400], output_min[400]; 738 | 739 | output_s[d] = output_min[d] = output[INDEX(sgm_direction, y - dy, x - dx, d)]; 740 | __syncthreads(); 741 | 742 | for (int i = 256; i > 0; i /= 2) { 743 | if (d < i && d + i < size3 && output_min[d + i] < output_min[d]) { 744 | output_min[d] = output_min[d + i]; 745 | } 746 | __syncthreads(); 747 | } 748 | 749 | int ind2 = y * size2 + x; 750 | float D1 = COLOR_DIFF(x0, ind2, ind2 - dy * size2 - dx); 751 | float D2; 752 | int xx = x + d * direction; 753 | if (xx < 0 || xx >= size2 || xx - dx < 0 || xx - dx >= size2) { 754 | D2 = 10; 755 | } else { 756 | D2 = COLOR_DIFF(x1, ind2 + d * direction, ind2 + d * direction - dy * size2 - dx); 757 | } 758 | float P1, P2; 759 | if (D1 < tau_so && D2 < tau_so) { 760 | P1 = pi1; 761 | P2 = pi2; 762 | } else if (D1 > tau_so && D2 > tau_so) { 763 | P1 = pi1 / (sgm_q1 * sgm_q2); 764 | P2 = pi2 / (sgm_q1 * sgm_q2); 765 | } else { 766 | P1 = pi1 / sgm_q1; 767 | P2 = pi2 / sgm_q1; 768 | } 769 | 770 | float cost = min(output_s[d], output_min[0] + P2); 771 | if (d - 1 >= 0) { 772 | cost = min(cost, output_s[d - 1] + (sgm_direction == 2 ? P1 / alpha1 : P1)); 773 | } 774 | if (d + 1 < size3) { 775 | cost = min(cost, output_s[d + 1] + (sgm_direction == 3 ? 
P1 / alpha1 : P1)); 776 | } 777 | 778 | output[INDEX(sgm_direction, y, x, d)] = input[INDEX(0, y, x, d)] + cost - output_min[0]; 779 | } 780 | 781 | int sgm3(lua_State *L) 782 | { 783 | THCState *state = getCutorchState(L); 784 | THCudaTensor *x0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 785 | THCudaTensor *x1 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 786 | THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 787 | THCudaTensor *output = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 788 | float pi1 = luaL_checknumber(L, 5); 789 | float pi2 = luaL_checknumber(L, 6); 790 | float tau_so = luaL_checknumber(L, 7); 791 | float alpha1 = luaL_checknumber(L, 8); 792 | float sgm_q1 = luaL_checknumber(L, 9); 793 | float sgm_q2 = luaL_checknumber(L, 10); 794 | int direction = luaL_checknumber(L, 11); 795 | int size1 = THCudaTensor_size(state, output, 1) * THCudaTensor_size(state, output, 3); 796 | int size2 = THCudaTensor_size(state, output, 2) * THCudaTensor_size(state, output, 3); 797 | int disp_max = THCudaTensor_size(state, output, 3); 798 | 799 | for (int step = 0; step < THCudaTensor_size(state, input, 2); step++) { 800 | sgm3<0><<<(size1 - 1) / disp_max + 1, disp_max>>>( 801 | THCudaTensor_data(state, x0), 802 | THCudaTensor_data(state, x1), 803 | THCudaTensor_data(state, input), 804 | THCudaTensor_data(state, output), 805 | pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, direction, 806 | THCudaTensor_size(state, input, 1), 807 | THCudaTensor_size(state, input, 2), 808 | THCudaTensor_size(state, input, 3), 809 | step); 810 | } 811 | 812 | for (int step = 0; step < THCudaTensor_size(state, input, 2); step++) { 813 | sgm3<1><<<(size1 - 1) / disp_max + 1, disp_max>>>( 814 | THCudaTensor_data(state, x0), 815 | THCudaTensor_data(state, x1), 816 | THCudaTensor_data(state, input), 817 | THCudaTensor_data(state, output), 818 | pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, direction, 819 | THCudaTensor_size(state, input, 1), 820 | THCudaTensor_size(state, input, 2), 821 | THCudaTensor_size(state, input, 3), 822 | step); 823 | } 824 | 825 | for (int step = 0; step < THCudaTensor_size(state, input, 1); step++) { 826 | sgm3<2><<<(size2 - 1) / disp_max + 1, disp_max>>>( 827 | THCudaTensor_data(state, x0), 828 | THCudaTensor_data(state, x1), 829 | THCudaTensor_data(state, input), 830 | THCudaTensor_data(state, output), 831 | pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, direction, 832 | THCudaTensor_size(state, input, 1), 833 | THCudaTensor_size(state, input, 2), 834 | THCudaTensor_size(state, input, 3), 835 | step); 836 | } 837 | 838 | for (int step = 0; step < THCudaTensor_size(state, input, 1); step++) { 839 | sgm3<3><<<(size2 - 1) / disp_max + 1, disp_max>>>( 840 | THCudaTensor_data(state, x0), 841 | THCudaTensor_data(state, x1), 842 | THCudaTensor_data(state, input), 843 | THCudaTensor_data(state, output), 844 | pi1, pi2, tau_so, alpha1, sgm_q1, sgm_q2, direction, 845 | THCudaTensor_size(state, input, 1), 846 | THCudaTensor_size(state, input, 2), 847 | THCudaTensor_size(state, input, 3), 848 | step); 849 | } 850 | 851 | checkCudaError(L); 852 | return 0; 853 | } 854 | 855 | __global__ void fliplr(float *in, float *out, int size, int dim3) 856 | { 857 | int id = blockIdx.x * blockDim.x + threadIdx.x; 858 | if (id < size) { 859 | int x = id % dim3; 860 | out[id + dim3 - 2 * x - 1] = in[id]; 861 | } 862 | } 863 | 864 | int fliplr(lua_State *L) 865 | { 866 | THCState *state = getCutorchState(L); 867 | THCudaTensor *in = 
(THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 868 | THCudaTensor *out = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 869 | 870 | fliplr<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 871 | THCudaTensor_data(state, in), 872 | THCudaTensor_data(state, out), 873 | THCudaTensor_nElement(state, out), 874 | THCudaTensor_size(state, out, 3)); 875 | checkCudaError(L); 876 | return 0; 877 | } 878 | 879 | __global__ void outlier_detection(float *d0, float *d1, float *outlier, int size, int dim3, float *conf1, float *conf2, int disp_max, float t1, float t2) 880 | { 881 | int id = blockIdx.x * blockDim.x + threadIdx.x; 882 | if (id < size) { 883 | int x = id % dim3; 884 | int d0i = d0[id]; 885 | if (x - d0i < 0) { 886 | //assert(0); 887 | outlier[id] = 1; 888 | } else if ((abs(d0[id] - d1[id - d0i]) < 1.1) 889 | || (conf1[id] > t1 890 | && (conf1[id] - conf2[id- d0i] > t2) 891 | )){ 892 | outlier[id] = 0; /* match */ 893 | } else { 894 | outlier[id] = 1; /* occlusion */ 895 | for (int d = 0; d < disp_max; d++) { 896 | if (x - d >= 0 && abs(d - d1[id - d]) < 1.1) { 897 | outlier[id] = 2; /* mismatch */ 898 | break; 899 | } 900 | } 901 | } 902 | } 903 | } 904 | 905 | int outlier_detection(lua_State *L) 906 | { 907 | THCState *state = getCutorchState(L); 908 | THCudaTensor *d0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 909 | THCudaTensor *d1 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 910 | THCudaTensor *outlier = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 911 | int disp_max = luaL_checkinteger(L, 4); 912 | THCudaTensor *conf1 = (THCudaTensor*)luaT_checkudata(L, 5, "torch.CudaTensor"); 913 | THCudaTensor *conf2 = (THCudaTensor*)luaT_checkudata(L, 6, "torch.CudaTensor"); 914 | float t1 = luaL_checknumber(L, 7); 915 | float t2 = luaL_checknumber(L, 8); 916 | 917 | outlier_detection<<<(THCudaTensor_nElement(state, d0) - 1) / TB + 1, TB>>>( 918 | THCudaTensor_data(state, d0), 919 | THCudaTensor_data(state, d1), 920 | THCudaTensor_data(state, outlier), 921 | THCudaTensor_nElement(state, d0), 922 | THCudaTensor_size(state, d0, 3), 923 | THCudaTensor_data(state, conf1), 924 | THCudaTensor_data(state, conf2), 925 | disp_max, t1, t2); 926 | checkCudaError(L); 927 | return 0; 928 | } 929 | 930 | #if 0 931 | 932 | __global__ void iterative_region_voting(float *d0, float *x0c, float *x1c, float *outlier, float *d0_out, float *outlier_out, int size, int dim2, int dim3, float tau_s, float tau_h, int disp_max) 933 | { 934 | int id = blockIdx.x * blockDim.x + threadIdx.x; 935 | if (id < size) { 936 | int x = id % dim3; 937 | int y = id / dim3; 938 | 939 | d0_out[id] = d0[id]; 940 | outlier_out[id] = outlier[id]; 941 | 942 | if (outlier[id] == 0) return; 943 | 944 | assert(disp_max < DISP_MAX); 945 | int hist[DISP_MAX]; 946 | for (int i = 0; i < disp_max; i++) { 947 | hist[i] = 0; 948 | } 949 | 950 | int yy_s = x0c[(2 * dim2 + y) * dim3 + x]; 951 | int yy_t = x0c[(3 * dim2 + y) * dim3 + x]; 952 | for (int yy = yy_s + 1; yy < yy_t; yy++) { 953 | int xx_s = x0c[(0 * dim2 + yy) * dim3 + x]; 954 | int xx_t = x0c[(1 * dim2 + yy) * dim3 + x]; 955 | for (int xx = xx_s + 1; xx < xx_t; xx++) { 956 | if (outlier[yy * dim3 + xx] == 0) { 957 | hist[(int)d0[yy * dim3 + xx]]++; 958 | } 959 | } 960 | } 961 | 962 | int cnt = 0; 963 | int max_i = 0; 964 | for (int i = 0; i < disp_max; i++) { 965 | cnt += hist[i]; 966 | if (hist[i] > hist[max_i]) { 967 | max_i = i; 968 | } 969 | } 970 | 971 | if (cnt > tau_s && (float)hist[max_i] / cnt > tau_h) { 972 | 
outlier_out[id] = 0; 973 | d0_out[id] = max_i; 974 | } 975 | } 976 | } 977 | 978 | int iterative_region_voting(lua_State *L) 979 | { 980 | THCudaTensor *d0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 981 | THCudaTensor *x0c = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 982 | THCudaTensor *x1c = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 983 | THCudaTensor *outlier = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 984 | float tau_s = luaL_checknumber(L, 5); 985 | float tau_h = luaL_checknumber(L, 6); 986 | int disp_max = luaL_checkinteger(L, 7); 987 | int iterations = luaL_checkinteger(L, 8); 988 | 989 | THCudaTensor *d0_tmp = new_tensor_like(state, d0); 990 | THCudaTensor *outlier_tmp = new_tensor_like(state, outlier); 991 | 992 | assert(iterations % 2 == 0); 993 | for (int i = 0; i < iterations; i++) { 994 | iterative_region_voting<<<(THCudaTensor_nElement(d0) - 1) / TB + 1, TB>>>( 995 | THCudaTensor_data(i % 2 == 0 ? d0 : d0_tmp), 996 | THCudaTensor_data(x0c), 997 | THCudaTensor_data(x1c), 998 | THCudaTensor_data(i % 2 == 0 ? outlier : outlier_tmp), 999 | THCudaTensor_data(i % 2 == 0 ? d0_tmp : d0), 1000 | THCudaTensor_data(i % 2 == 0 ? outlier_tmp : outlier), 1001 | THCudaTensor_nElement(d0), 1002 | THCudaTensor_size(d0, 2), 1003 | THCudaTensor_size(d0, 3), 1004 | tau_s, tau_h, disp_max); 1005 | } 1006 | checkCudaError(L); 1007 | return 0; 1008 | } 1009 | #endif 1010 | 1011 | __global__ void interpolate_mismatch(float *d0, float *outlier, float *out, int size, int dim2, int dim3) 1012 | { 1013 | const float dir[] = { 1014 | 0 , 1, 1015 | -0.5, 1, 1016 | -1 , 1, 1017 | -1 , 0.5, 1018 | -1 , 0, 1019 | -1 , -0.5, 1020 | -1 , -1, 1021 | -0.5, -1, 1022 | 0 , -1, 1023 | 0.5 , -1, 1024 | 1 , -1, 1025 | 1 , -0.5, 1026 | 1 , 0, 1027 | 1 , 0.5, 1028 | 1 , 1, 1029 | 0.5 , 1 1030 | }; 1031 | 1032 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1033 | if (id < size) { 1034 | if (outlier[id] != 2) { 1035 | out[id] = d0[id]; 1036 | return; 1037 | } 1038 | 1039 | float vals[16]; 1040 | int vals_size = 0; 1041 | 1042 | int x = id % dim3; 1043 | int y = id / dim3; 1044 | for (int d = 0; d < 16; d++) { 1045 | float dx = dir[2 * d]; 1046 | float dy = dir[2 * d + 1]; 1047 | float xx = x; 1048 | float yy = y; 1049 | int xx_i = round(xx); 1050 | int yy_i = round(yy); 1051 | while (0 <= yy_i && yy_i < dim2 && 0 <= xx_i && xx_i < dim3 && outlier[yy_i * dim3 + xx_i] == 2) { 1052 | xx += dx; 1053 | yy += dy; 1054 | xx_i = round(xx); 1055 | yy_i = round(yy); 1056 | } 1057 | 1058 | int ind = yy_i * dim3 + xx_i; 1059 | if (0 <= yy_i && yy_i < dim2 && 0 <= xx_i && xx_i < dim3) { 1060 | assert(outlier[ind] != 2); 1061 | vals[vals_size++] = d0[ind]; 1062 | } 1063 | } 1064 | assert(vals_size > 0); 1065 | sort(vals, vals_size); 1066 | out[id] = vals[vals_size / 2]; 1067 | } 1068 | } 1069 | 1070 | int interpolate_mismatch(lua_State *L) 1071 | { 1072 | THCState *state = getCutorchState(L); 1073 | THCudaTensor *d0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1074 | THCudaTensor *outlier = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1075 | THCudaTensor *out = new_tensor_like(state, d0); 1076 | 1077 | interpolate_mismatch<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 1078 | THCudaTensor_data(state, d0), 1079 | THCudaTensor_data(state, outlier), 1080 | THCudaTensor_data(state, out), 1081 | THCudaTensor_nElement(state, out), 1082 | THCudaTensor_size(state, out, 2), 1083 | THCudaTensor_size(state, out, 3)); 1084 | 
checkCudaError(L); 1085 | luaT_pushudata(L, out, "torch.CudaTensor"); 1086 | return 1; 1087 | } 1088 | 1089 | __global__ void interpolate_occlusion(float *d0, float *outlier, float *out, int size, int dim3) 1090 | { 1091 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1092 | if (id < size) { 1093 | if (outlier[id] != 1) { 1094 | out[id] = d0[id]; 1095 | return; 1096 | } 1097 | int x = id % dim3; 1098 | 1099 | int dx = 0; 1100 | while (x + dx >= 0 && outlier[id + dx] != 0) { 1101 | dx--; 1102 | } 1103 | if (x + dx < 0) { 1104 | dx = 0; 1105 | while (x + dx < dim3 && outlier[id + dx] != 0) { 1106 | dx++; 1107 | } 1108 | } 1109 | if (x + dx < dim3) { 1110 | out[id] = d0[id + dx]; 1111 | } else { 1112 | out[id] = d0[id]; 1113 | } 1114 | } 1115 | } 1116 | 1117 | int interpolate_occlusion(lua_State *L) 1118 | { 1119 | THCState *state = getCutorchState(L); 1120 | THCudaTensor *d0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1121 | THCudaTensor *outlier = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1122 | THCudaTensor *out = new_tensor_like(state, d0); 1123 | 1124 | interpolate_occlusion<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 1125 | THCudaTensor_data(state, d0), 1126 | THCudaTensor_data(state, outlier), 1127 | THCudaTensor_data(state, out), 1128 | THCudaTensor_nElement(state, out), 1129 | THCudaTensor_size(state, out, 3) 1130 | ); 1131 | 1132 | checkCudaError(L); 1133 | luaT_pushudata(L, out, "torch.CudaTensor"); 1134 | return 1; 1135 | } 1136 | 1137 | #if 0 1138 | 1139 | __global__ void sobel(float *x, float *g1, float *g2, int size, int dim2, int dim3) 1140 | { 1141 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1142 | if (id < size) { 1143 | int xx = id % dim3; 1144 | int yy = id / dim3; 1145 | 1146 | if (1 <= yy && yy < dim2 - 1 && 1 <= xx && xx < dim3 - 1) { 1147 | g1[id] = -x[id-dim3-1] +x[id-dim3+1] -2*x[id-1] +2*x[id+1] -x[id+dim3-1] +x[id+dim3+1]; 1148 | g2[id] = x[id-dim3-1] +2*x[id-dim3] +x[id-dim3+1] -x[id+dim3-1] -2*x[id+dim3] -x[id+dim3+1]; 1149 | } else { 1150 | g1[id] = 0; 1151 | g2[id] = 0; 1152 | } 1153 | } 1154 | } 1155 | 1156 | int sobel(lua_State *L) { 1157 | THCudaTensor *x = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1158 | THCudaTensor *g1 = new_tensor_like(x); 1159 | THCudaTensor *g2 = new_tensor_like(x); 1160 | 1161 | sobel<<<(THCudaTensor_nElement(x) - 1) / TB + 1, TB>>>( 1162 | THCudaTensor_data(x), 1163 | THCudaTensor_data(g1), 1164 | THCudaTensor_data(g2), 1165 | THCudaTensor_nElement(x), 1166 | THCudaTensor_size(x, 2), 1167 | THCudaTensor_size(x, 3) 1168 | ); 1169 | 1170 | checkCudaError(L); 1171 | luaT_pushudata(L, g1, "torch.CudaTensor"); 1172 | luaT_pushudata(L, g2, "torch.CudaTensor"); 1173 | return 2; 1174 | } 1175 | 1176 | __global__ void depth_discontinuity_adjustment(float *d0, float *dg1, float *dg2, float *xg1, float *xg2, float *out, int size, int dim3, float tau_e) 1177 | { 1178 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1179 | if (id < size) { 1180 | if (abs(dg1[id]) > tau_e) { 1181 | out[id] = xg1[id - 1] > xg1[id + 1] ? d0[id - 1] : d0[id + 1]; 1182 | } else if (abs(dg2[id]) > tau_e) { 1183 | out[id] = xg2[id - dim3] > xg2[id + dim3] ? 
d0[id - dim3] : d0[id + dim3]; 1184 | } else { 1185 | out[id] = d0[id]; 1186 | } 1187 | } 1188 | } 1189 | 1190 | int depth_discontinuity_adjustment(lua_State *L) { 1191 | THCudaTensor *d0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1192 | THCudaTensor *dg1 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1193 | THCudaTensor *dg2 = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 1194 | THCudaTensor *xg1 = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 1195 | THCudaTensor *xg2 = (THCudaTensor*)luaT_checkudata(L, 5, "torch.CudaTensor"); 1196 | float tau_e = luaL_checknumber(L, 6); 1197 | THCudaTensor *out = new_tensor_like(d0); 1198 | 1199 | depth_discontinuity_adjustment<<<(THCudaTensor_nElement(out) - 1) / TB + 1, TB>>>( 1200 | THCudaTensor_data(d0), 1201 | THCudaTensor_data(dg1), 1202 | THCudaTensor_data(dg2), 1203 | THCudaTensor_data(xg1), 1204 | THCudaTensor_data(xg2), 1205 | THCudaTensor_data(out), 1206 | THCudaTensor_nElement(out), 1207 | THCudaTensor_size(out, 3), 1208 | tau_e); 1209 | checkCudaError(L); 1210 | luaT_pushudata(L, out, "torch.CudaTensor"); 1211 | return 1; 1212 | } 1213 | #endif 1214 | 1215 | __global__ void subpixel_enchancement(float *d0, float *c2, float *out, int size, int dim23, int disp_max) { 1216 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1217 | if (id < size) { 1218 | int d = d0[id]; 1219 | out[id] = d; 1220 | if (1 <= d && d < disp_max - 1) { 1221 | float cn = c2[(d - 1) * dim23 + id]; 1222 | float cz = c2[d * dim23 + id]; 1223 | float cp = c2[(d + 1) * dim23 + id]; 1224 | float denom = 2 * (cp + cn - 2 * cz); 1225 | if (denom > 1e-5) { 1226 | out[id] = d - min(1.0, max(-1.0, (cp - cn) / denom)); 1227 | } 1228 | } 1229 | } 1230 | } 1231 | 1232 | int subpixel_enchancement(lua_State *L) { 1233 | THCState *state = getCutorchState(L); 1234 | THCudaTensor *d0 = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1235 | THCudaTensor *c2 = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1236 | int disp_max = luaL_checkinteger(L, 3); 1237 | THCudaTensor *out = new_tensor_like(state, d0); 1238 | 1239 | subpixel_enchancement<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 1240 | THCudaTensor_data(state, d0), 1241 | THCudaTensor_data(state, c2), 1242 | THCudaTensor_data(state, out), 1243 | THCudaTensor_nElement(state, out), 1244 | THCudaTensor_size(state, out, 2) * THCudaTensor_size(state, out, 3), 1245 | disp_max); 1246 | checkCudaError(L); 1247 | luaT_pushudata(L, out, "torch.CudaTensor"); 1248 | return 1; 1249 | } 1250 | 1251 | __global__ void mean2d(float *img, float *kernel, float *out, int size, int kernel_radius, int dim2, int dim3, float alpha2) 1252 | { 1253 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1254 | if (id < size) { 1255 | int x = id % dim3; 1256 | int y = id / dim3; 1257 | 1258 | float sum = 0; 1259 | float cnt = 0; 1260 | int i = 0; 1261 | for (int xx = x - kernel_radius; xx <= x + kernel_radius; xx++) { 1262 | for (int yy = y - kernel_radius; yy <= y + kernel_radius; yy++, i++) { 1263 | if (0 <= xx && xx < dim3 && 0 <= yy && yy < dim2 && abs(img[yy * dim3 + xx] - img[y * dim3 + x]) < alpha2) { 1264 | sum += img[yy * dim3 + xx] * kernel[i]; 1265 | cnt += kernel[i]; 1266 | } 1267 | } 1268 | } 1269 | out[id] = sum / cnt; 1270 | } 1271 | } 1272 | 1273 | int mean2d(lua_State *L) { 1274 | THCState *state = getCutorchState(L); 1275 | THCudaTensor *img = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1276 | THCudaTensor *kernel = 
(THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1277 | float alpha2 = luaL_checknumber(L, 3); 1278 | THCudaTensor *out = new_tensor_like(state, img); 1279 | assert(THCudaTensor_size(state, kernel, 0) % 2 == 1); 1280 | mean2d<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 1281 | THCudaTensor_data(state, img), 1282 | THCudaTensor_data(state, kernel), 1283 | THCudaTensor_data(state, out), 1284 | THCudaTensor_nElement(state, out), 1285 | THCudaTensor_size(state, kernel, 0) / 2, 1286 | THCudaTensor_size(state, out, 2), 1287 | THCudaTensor_size(state, out, 3), 1288 | alpha2); 1289 | checkCudaError(L); 1290 | luaT_pushudata(L, out, "torch.CudaTensor"); 1291 | return 1; 1292 | } 1293 | 1294 | __global__ void Normalize_get_norm_(float *input, float *norm, int size1, int size23, int size023) 1295 | { 1296 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1297 | if (id < size023) { 1298 | int dim23 = id % size23; 1299 | int dim0 = id / size23; 1300 | 1301 | float sum = 0.0; 1302 | for (int dim1 = 0; dim1 < size1; dim1++) { 1303 | float x = input[(dim0 * size1 + dim1) * size23 + dim23]; 1304 | sum += x * x; 1305 | } 1306 | norm[dim0 * size23 + dim23] = sum + 1e-5; 1307 | } 1308 | } 1309 | 1310 | __global__ void Normalize_forward_(float *input, float *norm, float *output, int size23, int size123, int size0123) 1311 | { 1312 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1313 | if (id < size0123) { 1314 | int dim23 = id % size23; 1315 | int dim0 = (id / size123); 1316 | output[id] = input[id] / sqrtf(norm[dim0 * size23 + dim23]); 1317 | } 1318 | } 1319 | 1320 | int Normalize_forward(lua_State *L) 1321 | { 1322 | THCState *state = getCutorchState(L); 1323 | THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1324 | THCudaTensor *norm = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1325 | THCudaTensor *output = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 1326 | 1327 | Normalize_get_norm_<<<(THCudaTensor_nElement(state, norm) - 1) / TB + 1, TB>>>( 1328 | THCudaTensor_data(state, input), 1329 | THCudaTensor_data(state, norm), 1330 | THCudaTensor_size(state, input, 1), 1331 | THCudaTensor_size(state, input, 2) * THCudaTensor_size(state, input, 3), 1332 | THCudaTensor_nElement(state, norm)); 1333 | 1334 | Normalize_forward_<<<(THCudaTensor_nElement(state, output) - 1) / TB + 1, TB>>>( 1335 | THCudaTensor_data(state, input), 1336 | THCudaTensor_data(state, norm), 1337 | THCudaTensor_data(state, output), 1338 | THCudaTensor_size(state, input, 2) * THCudaTensor_size(state, input, 3), 1339 | THCudaTensor_size(state, input, 1) * THCudaTensor_size(state, input, 2) * THCudaTensor_size(state, input, 3), 1340 | THCudaTensor_nElement(state, output)); 1341 | checkCudaError(L); 1342 | return 0; 1343 | } 1344 | 1345 | __global__ void Normalize_backward_input_(float *grad_output, float *input, float *norm, float *grad_input, int size1, int size23, int size0123) 1346 | { 1347 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1348 | if (id < size0123) { 1349 | int dim0 = id; 1350 | int dim23 = dim0 % size23; 1351 | dim0 /= size23; 1352 | int dim1 = dim0 % size1; 1353 | dim0 /= size1; 1354 | 1355 | float denom = powf(norm[dim0 * size23 + dim23], 1.5); 1356 | float deriv = (norm[dim0 * size23 + dim23] - input[id] * input[id]) / denom * grad_output[id]; 1357 | 1358 | float sum = 0; 1359 | for (int dim1_ = 0; dim1_ < size1; dim1_++) { 1360 | if (dim1_ != dim1) { 1361 | int ind = (dim0 * size1 + dim1_) * size23 + dim23; 1362 | sum += input[ind] * 
grad_output[ind]; 1363 | } 1364 | } 1365 | grad_input[id] = deriv - sum * input[id] / denom; 1366 | } 1367 | } 1368 | 1369 | int Normalize_backward_input(lua_State *L) 1370 | { 1371 | THCState *state = getCutorchState(L); 1372 | THCudaTensor *grad_output = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1373 | THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1374 | THCudaTensor *norm = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 1375 | THCudaTensor *grad_input = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 1376 | 1377 | Normalize_backward_input_<<<(THCudaTensor_nElement(state, input) - 1) / TB + 1, TB>>>( 1378 | THCudaTensor_data(state, grad_output), 1379 | THCudaTensor_data(state, input), 1380 | THCudaTensor_data(state, norm), 1381 | THCudaTensor_data(state, grad_input), 1382 | THCudaTensor_size(state, input, 1), 1383 | THCudaTensor_size(state, input, 2) * THCudaTensor_size(state, input, 3), 1384 | THCudaTensor_nElement(state, input)); 1385 | checkCudaError(L); 1386 | return 0; 1387 | } 1388 | 1389 | struct Margin2_functor { 1390 | float margin; 1391 | __host__ Margin2_functor(float margin_) : margin(margin_) {}; 1392 | __device__ float forward(float pos, float neg) { 1393 | return fmaxf(0, neg - pos + margin); 1394 | } 1395 | __device__ float backward(float pos, float neg, int which) { 1396 | float f = neg - pos + margin; 1397 | if (which == 0) { 1398 | return -1 * (f > 0); 1399 | } else { 1400 | return f > 0; 1401 | } 1402 | } 1403 | }; 1404 | 1405 | struct Margin2_squared_functor { 1406 | float margin; 1407 | __host__ Margin2_squared_functor(float margin_) : margin(margin_) {}; 1408 | __device__ float forward(float pos, float neg) { 1409 | float d = fmaxf(0, neg - pos + margin); 1410 | return d * d * 0.5; 1411 | } 1412 | __device__ float backward(float pos, float neg, int which) { 1413 | float f = neg - pos + margin; 1414 | if (which == 0) { 1415 | return -f * (f > 0); 1416 | } else { 1417 | return f * (f > 0); 1418 | } 1419 | } 1420 | }; 1421 | 1422 | template <class Op> 1423 | __global__ void Margin2_(float *input, float *tmp, float *gradInput, float margin, Op op, int size) 1424 | { 1425 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1426 | if (id < size) { 1427 | float pos = input[id * 2]; 1428 | float neg = input[id * 2 + 1]; 1429 | tmp[id] = op.forward(pos, neg); 1430 | gradInput[id * 2] = op.backward(pos, neg, 0); 1431 | gradInput[id * 2 + 1] = op.backward(pos, neg, 1); 1432 | } 1433 | } 1434 | 1435 | int Margin2(lua_State *L) 1436 | { 1437 | THCState *state = getCutorchState(L); 1438 | THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1439 | THCudaTensor *tmp = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1440 | THCudaTensor *gradInput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 1441 | float margin = luaL_checknumber(L, 4); 1442 | int pow = luaL_checkinteger(L, 5); 1443 | 1444 | if (pow == 1) { 1445 | Margin2_<<<(THCudaTensor_nElement(state, tmp) - 1) / TB + 1, TB>>>( 1446 | THCudaTensor_data(state, input), 1447 | THCudaTensor_data(state, tmp), 1448 | THCudaTensor_data(state, gradInput), 1449 | margin, 1450 | Margin2_functor(margin), 1451 | THCudaTensor_nElement(state, tmp)); 1452 | } else if (pow == 2) { 1453 | Margin2_<<<(THCudaTensor_nElement(state, tmp) - 1) / TB + 1, TB>>>( 1454 | THCudaTensor_data(state, input), 1455 | THCudaTensor_data(state, tmp), 1456 | THCudaTensor_data(state, gradInput), 1457 | margin, 1458 | Margin2_squared_functor(margin), 1459 | 
THCudaTensor_nElement(state, tmp)); 1460 | } 1461 | checkCudaError(L); 1462 | return 0; 1463 | } 1464 | 1465 | 1466 | __global__ void StereoJoin_(float *input_L, float *input_R, float *output_L, float *output_R, int size1_input, int size1, int size3, int size23) 1467 | { 1468 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1469 | if (id < size23) { 1470 | int dim3 = id % size3; 1471 | assert(size1_input <= 128); 1472 | float L_cache[128]; 1473 | for (int i = 0; i < size1_input; i++) { 1474 | L_cache[i] = input_L[i * size23 + id]; 1475 | } 1476 | 1477 | for (int d = 0; d < size1; d++) { 1478 | if (dim3 - d >= 0) { 1479 | float sum = 0; 1480 | for (int i = 0; i < size1_input; i++) { 1481 | sum -= L_cache[i] * input_R[i * size23 + id - d]; 1482 | } 1483 | output_L[d * size23 + id] = sum; 1484 | output_R[d * size23 + id - d] = sum; 1485 | } 1486 | } 1487 | } 1488 | } 1489 | 1490 | 1491 | int StereoJoin(lua_State *L) 1492 | { 1493 | THCState *state = getCutorchState(L); 1494 | THCudaTensor *input_L = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1495 | THCudaTensor *input_R = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1496 | THCudaTensor *output_L = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 1497 | THCudaTensor *output_R = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 1498 | int size23 = THCudaTensor_size(state, output_L, 2) * THCudaTensor_size(state, output_L, 3); 1499 | StereoJoin_<<<(size23 - 1) / TB + 1, TB>>>( 1500 | THCudaTensor_data(state, input_L), 1501 | THCudaTensor_data(state, input_R), 1502 | THCudaTensor_data(state, output_L), 1503 | THCudaTensor_data(state, output_R), 1504 | THCudaTensor_size(state, input_L, 1), 1505 | THCudaTensor_size(state, output_L, 1), 1506 | THCudaTensor_size(state, output_L, 3), 1507 | size23); 1508 | checkCudaError(L); 1509 | return 0; 1510 | } 1511 | __global__ void L2dist_(float *input_L, float *input_R, float *output_L, float *output_R, int size1_input, int size1, int size3, int size23) 1512 | { 1513 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1514 | if (id < size23) { 1515 | int dim3 = id % size3; 1516 | assert(size1_input <= 512); 1517 | float L_cache[512]; 1518 | for (int i = 0; i < size1_input; i++) { 1519 | L_cache[i] = input_L[i * size23 + id]; 1520 | } 1521 | 1522 | for (int d = 0; d < size1; d++) { 1523 | if (dim3 - d >= 0) { 1524 | float sum = 0; 1525 | float diff = 0; 1526 | for (int i = 0; i < size1_input; i++) { 1527 | diff = L_cache[i] - input_R[i * size23 + id - d]; 1528 | sum += diff*diff; 1529 | } 1530 | sum = sqrt(sum); 1531 | output_L[d * size23 + id] = sum; 1532 | output_R[d * size23 + id - d] = sum; 1533 | } 1534 | } 1535 | } 1536 | } 1537 | 1538 | int L2dist(lua_State *L) 1539 | { 1540 | THCState *state = getCutorchState(L); 1541 | THCudaTensor *input_L = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1542 | THCudaTensor *input_R = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1543 | THCudaTensor *output_L = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); 1544 | THCudaTensor *output_R = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); 1545 | int size23 = THCudaTensor_size(state, output_L, 2) * THCudaTensor_size(state, output_L, 3); 1546 | L2dist_<<<(size23 - 1) / TB + 1, TB>>>( 1547 | THCudaTensor_data(state, input_L), 1548 | THCudaTensor_data(state, input_R), 1549 | THCudaTensor_data(state, output_L), 1550 | THCudaTensor_data(state, output_R), 1551 | THCudaTensor_size(state, input_L, 1), 1552 | THCudaTensor_size(state, output_L, 
1), 1553 | THCudaTensor_size(state, output_L, 3), 1554 | size23); 1555 | checkCudaError(L); 1556 | return 0; 1557 | } 1558 | 1559 | __global__ void StereoL2R_(float *vol_L, float *vol_R, int size2, int size3, int size) 1560 | { 1561 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1562 | if (id < size) { 1563 | int dim3 = id % size3; 1564 | int dim1 = id / (size2 * size3); 1565 | 1566 | if (dim3 + dim1 >= size3) { 1567 | vol_R[id] = CUDART_INF; 1568 | } else { 1569 | vol_R[id] = vol_L[id + dim1]; 1570 | } 1571 | } 1572 | } 1573 | 1574 | int StereoL2R(lua_State *L) 1575 | { 1576 | THCState *state = getCutorchState(L); 1577 | THCudaTensor *vol_L = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1578 | THCudaTensor *vol_R = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1579 | StereoL2R_<<<(THCudaTensor_nElement(state, vol_L) - 1) / TB + 1, TB>>>( 1580 | THCudaTensor_data(state, vol_L), 1581 | THCudaTensor_data(state, vol_R), 1582 | THCudaTensor_size(state, vol_R, 2), 1583 | THCudaTensor_size(state, vol_R, 3), 1584 | THCudaTensor_nElement(state, vol_R)); 1585 | checkCudaError(L); 1586 | return 0; 1587 | } 1588 | 1589 | __global__ void bilateral_filter(float *img, float *out, int size, int dim2, int dim3, int kernel_radius, float sigma1, float sigma2) 1590 | { 1591 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1592 | if (id < size) { 1593 | int x = id % dim3; 1594 | int y = id / dim3; 1595 | 1596 | float sum = 0; 1597 | float cnt = 0; 1598 | for (int i = -kernel_radius; i <= kernel_radius; i++) { 1599 | for (int j = -kernel_radius; j <= kernel_radius; j++) { 1600 | int yy = y + i; 1601 | int xx = x + j; 1602 | if (0 <= xx && xx < dim3 && 0 <= yy && yy < dim2) { 1603 | float color_diff = img[yy * dim3 + xx] - img[y * dim3 + x]; 1604 | float v1 = exp(-(i * i + j * j) / (2 * sigma1 * sigma1)); 1605 | float v2 = exp(-(color_diff * color_diff) / (2 * sigma2 * sigma2)); 1606 | sum += img[yy * dim3 + xx] * v1 * v2; 1607 | cnt += v1 * v2; 1608 | } 1609 | } 1610 | } 1611 | out[id] = sum / cnt; 1612 | } 1613 | } 1614 | 1615 | int bilateral_filter(lua_State *L) { 1616 | THCState *state = getCutorchState(L); 1617 | THCudaTensor *img = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1618 | float sigma1 = luaL_checknumber(L, 2); 1619 | float sigma2 = luaL_checknumber(L, 3); 1620 | THCudaTensor *out = new_tensor_like(state, img); 1621 | int kernel_radius = ceil(min(sigma1, sigma2) * 3); 1622 | bilateral_filter<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 1623 | THCudaTensor_data(state, img), 1624 | THCudaTensor_data(state, out), 1625 | THCudaTensor_nElement(state, out), 1626 | THCudaTensor_size(state, out, 2), 1627 | THCudaTensor_size(state, out, 3), 1628 | kernel_radius, sigma1, sigma2); 1629 | checkCudaError(L); 1630 | luaT_pushudata(L, out, "torch.CudaTensor"); 1631 | return 1; 1632 | } 1633 | 1634 | __global__ void median2d(float *img, float *out, int size, int dim2, int dim3, int kernel_radius) 1635 | { 1636 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1637 | if (id < size) { 1638 | int x = id % dim3; 1639 | int y = id / dim3; 1640 | 1641 | float xs[11 * 11]; 1642 | int xs_size = 0; 1643 | for (int xx = x - kernel_radius; xx <= x + kernel_radius; xx++) { 1644 | for (int yy = y - kernel_radius; yy <= y + kernel_radius; yy++) { 1645 | if (0 <= xx && xx < dim3 && 0 <= yy && yy < dim2) { 1646 | xs[xs_size++] = img[yy * dim3 + xx]; 1647 | } 1648 | } 1649 | } 1650 | sort(xs, xs_size); 1651 | out[id] = xs[xs_size / 2]; 1652 | } 1653 | } 1654 | 1655 | int 
median2d(lua_State *L) { 1656 | THCState *state = getCutorchState(L); 1657 | THCudaTensor *img = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1658 | int kernel_size = luaL_checkinteger(L, 2); 1659 | THCudaTensor *out = new_tensor_like(state, img); 1660 | assert(kernel_size % 2 == 1); 1661 | assert(kernel_size <= 11); 1662 | median2d<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 1663 | THCudaTensor_data(state, img), 1664 | THCudaTensor_data(state, out), 1665 | THCudaTensor_nElement(state, out), 1666 | THCudaTensor_size(state, out, 2), 1667 | THCudaTensor_size(state, out, 3), 1668 | kernel_size / 2); 1669 | checkCudaError(L); 1670 | luaT_pushudata(L, out, "torch.CudaTensor"); 1671 | return 1; 1672 | } 1673 | 1674 | #if 0 1675 | int histogram(lua_State *L) { 1676 | THFloatTensor *img = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 1677 | THIntTensor *hist = THIntTensor_newWithSize1d(256); 1678 | THIntTensor_zero(hist); 1679 | 1680 | float *img_data = THFloatTensor_data(img); 1681 | int *hist_data = THIntTensor_data(hist); 1682 | for (int i = 0; i < THFloatTensor_size(img, 2) * THFloatTensor_size(img, 3); i++) { 1683 | assert(0 <= img_data[i] && img_data[i] < 256); 1684 | hist_data[(int)img_data[i]]++; 1685 | } 1686 | luaT_pushudata(L, hist, "torch.IntTensor"); 1687 | return 1; 1688 | } 1689 | 1690 | int histogram_equalization_map(lua_State *L) { 1691 | THIntTensor *cdf = (THIntTensor*)luaT_checkudata(L, 1, "torch.IntTensor"); 1692 | THIntTensor *map = THIntTensor_new(); 1693 | THIntTensor_resizeAs(map, cdf); 1694 | 1695 | int *cdf_data = THIntTensor_data(cdf); 1696 | int max = cdf_data[255]; 1697 | int min = cdf_data[0]; 1698 | for (int i = 0; i < 256; i++) { 1699 | if (cdf_data[i]) { 1700 | min = cdf_data[i]; 1701 | break; 1702 | } 1703 | } 1704 | int *map_data = THIntTensor_data(map); 1705 | for (int i = 0; i < 256; i++) { 1706 | map_data[i] = round((double)(cdf_data[i] - min) / (max - min) * 255); 1707 | } 1708 | luaT_pushudata(L, map, "torch.IntTensor"); 1709 | return 1; 1710 | } 1711 | 1712 | int map_intensities(lua_State *L) { 1713 | THFloatTensor *img = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 1714 | THIntTensor *map = (THIntTensor*)luaT_checkudata(L, 2, "torch.IntTensor"); 1715 | THFloatTensor *out = THFloatTensor_new(); 1716 | THFloatTensor_resizeAs(out, img); 1717 | 1718 | float *img_data = THFloatTensor_data(img); 1719 | float *out_data = THFloatTensor_data(out); 1720 | int *map_data = THIntTensor_data(map); 1721 | for (int i = 0; i < THFloatTensor_size(img, 2) * THFloatTensor_size(img, 3); i++) { 1722 | out_data[i] = map_data[(int)img_data[i]]; 1723 | } 1724 | luaT_pushudata(L, out, "torch.FloatTensor"); 1725 | return 1; 1726 | } 1727 | #endif 1728 | 1729 | int readPNG16(lua_State *L) 1730 | { 1731 | THFloatTensor *img_ = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 1732 | const char* fname = luaL_checkstring(L, 2); 1733 | 1734 | float *img = THFloatTensor_data(img_); 1735 | png::image<png::gray_pixel_16> image(fname); 1736 | int width = image.get_width(); 1737 | int height = image.get_height(); 1738 | for (int i = 0; i < height; i++) { 1739 | for (int j = 0; j < width; j++) { 1740 | uint16_t val = image.get_pixel(j, i); 1741 | img[i * width + j] = val == 0 ?
0.0 : ((float)val)/256.0; 1742 | } 1743 | } 1744 | return 0; 1745 | } 1746 | 1747 | int writePNG16(lua_State *L) 1748 | { 1749 | THFloatTensor *img_ = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 1750 | int height = luaL_checkinteger(L, 2); 1751 | int width = luaL_checkinteger(L, 3); 1752 | const char* fname = luaL_checkstring(L, 4); 1753 | 1754 | float *img = THFloatTensor_data(img_); 1755 | png::image<png::gray_pixel_16> image(width, height); 1756 | for (int i = 0; i < height; i++) { 1757 | for (int j = 0; j < width; j++) { 1758 | float val = img[i * width + j]; 1759 | image.set_pixel(j, i, (uint16_t)(val < 1e-5 ? 0 : val * 256)); 1760 | } 1761 | } 1762 | image.write(fname); 1763 | return 0; 1764 | } 1765 | 1766 | int writePFM(lua_State *L) 1767 | { 1768 | THFloatTensor *img_ = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 1769 | const char* fname = luaL_checkstring(L, 2); 1770 | 1771 | int height = THFloatTensor_size(img_, 0); 1772 | int width = THFloatTensor_size(img_, 1); 1773 | 1774 | FILE *f = fopen(fname, "w"); 1775 | fprintf(f, "Pf\n%d %d\n-0.003922\n", width, height); 1776 | fwrite(THFloatTensor_data(img_), 4, height * width, f); 1777 | fclose(f); 1778 | 1779 | return 0; 1780 | } 1781 | 1782 | __global__ void remove_nonvisible(float *y, int size, int size3) 1783 | { 1784 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1785 | if (id < size) { 1786 | int x = id % size3; 1787 | if (y[id] >= x) { 1788 | y[id] = 0; 1789 | } 1790 | } 1791 | } 1792 | 1793 | int remove_nonvisible(lua_State *L) 1794 | { 1795 | THCState *state = getCutorchState(L); 1796 | THCudaTensor *y = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1797 | 1798 | remove_nonvisible<<<(THCudaTensor_nElement(state, y) - 1) / TB + 1, TB>>>( 1799 | THCudaTensor_data(state, y), 1800 | THCudaTensor_nElement(state, y), 1801 | THCudaTensor_size(state, y, 3)); 1802 | checkCudaError(L); 1803 | return 0; 1804 | } 1805 | 1806 | __global__ void remove_occluded(float *y, int size, int size3) 1807 | { 1808 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1809 | if (id < size) { 1810 | int x = id % size3; 1811 | for (int i = 1; x + i < size3; i++) { 1812 | if (i - y[id + i] < -y[id]) { 1813 | y[id] = 0; 1814 | break; 1815 | } 1816 | } 1817 | } 1818 | } 1819 | 1820 | int remove_occluded(lua_State *L) 1821 | { 1822 | THCState *state = getCutorchState(L); 1823 | THCudaTensor *y = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1824 | 1825 | remove_occluded<<<(THCudaTensor_nElement(state, y) - 1) / TB + 1, TB>>>( 1826 | THCudaTensor_data(state, y), 1827 | THCudaTensor_nElement(state, y), 1828 | THCudaTensor_size(state, y, 3)); 1829 | checkCudaError(L); 1830 | return 0; 1831 | } 1832 | 1833 | __global__ void remove_white(float *x, float *y, int size) 1834 | { 1835 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1836 | if (id < size) { 1837 | if (x[id] == 255) { 1838 | y[id] = 0; 1839 | } 1840 | } 1841 | } 1842 | 1843 | int remove_white(lua_State *L) 1844 | { 1845 | THCState *state = getCutorchState(L); 1846 | THCudaTensor *x = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1847 | THCudaTensor *y = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1848 | 1849 | remove_white<<<(THCudaTensor_nElement(state, y) - 1) / TB + 1, TB>>>( 1850 | THCudaTensor_data(state, x), 1851 | THCudaTensor_data(state, y), 1852 | THCudaTensor_nElement(state, y)); 1853 | checkCudaError(L); 1854 | return 0; 1855 | } 1856 | 1857 | __global__ void copy_fill(float *in, float *out, int size, int in_size2, int
in_size3, int out_size2, int out_size3) 1858 | { 1859 | int id = blockIdx.x * blockDim.x + threadIdx.x; 1860 | if (id < size) { 1861 | int out_x = id % out_size3; 1862 | int out_y = id / out_size3; 1863 | 1864 | int in_x = out_x - (out_size3 - in_size3) / 2; 1865 | int in_y = out_y - (out_size2 - in_size2) / 2; 1866 | 1867 | int x = min(in_size3 - 1, max(0, in_x)); 1868 | int y = min(in_size2 - 1, max(0, in_y)); 1869 | 1870 | out[id] = in[y * in_size3 + x]; 1871 | } 1872 | } 1873 | 1874 | int copy_fill(lua_State *L) 1875 | { 1876 | THCState *state = getCutorchState(L); 1877 | THCudaTensor *in = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor"); 1878 | THCudaTensor *out = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); 1879 | 1880 | copy_fill<<<(THCudaTensor_nElement(state, out) - 1) / TB + 1, TB>>>( 1881 | THCudaTensor_data(state, in), 1882 | THCudaTensor_data(state, out), 1883 | THCudaTensor_nElement(state, out), 1884 | THCudaTensor_size(state, in, 2), 1885 | THCudaTensor_size(state, in, 3), 1886 | THCudaTensor_size(state, out, 2), 1887 | THCudaTensor_size(state, out, 3)); 1888 | checkCudaError(L); 1889 | luaT_pushudata(L, out, "torch.CudaTensor"); 1890 | return 1; 1891 | } 1892 | 1893 | void memcpy2d(float *dst, float *src, int x, int y, int win_radius, int height, int width) 1894 | { 1895 | assert(0 <= x - win_radius); 1896 | assert(x + win_radius <= width); 1897 | assert(0 <= y - win_radius); 1898 | assert(y + win_radius <= height); 1899 | for (int i = -win_radius; i <= win_radius; i++) { 1900 | memcpy(dst, src + (y + i) * width + x - win_radius, (win_radius * 2 + 1) * sizeof(float)); 1901 | dst += win_radius * 2 + 1; 1902 | } 1903 | } 1904 | 1905 | double random_uniform() 1906 | { 1907 | return ((double)rand()/(double)RAND_MAX); 1908 | } 1909 | 1910 | int random_int(int a, int b) 1911 | { 1912 | assert(a <= b); 1913 | return floor(random_uniform() * (b - a + 1) + a); 1914 | } 1915 | 1916 | double random_exp(double lambda) 1917 | { 1918 | double u = random_uniform(); 1919 | return -log(u) / lambda; 1920 | } 1921 | 1922 | int subset_dataset(lua_State *L) 1923 | { 1924 | THLongTensor *index_ = (THLongTensor*)luaT_checkudata(L, 1, "torch.LongTensor"); 1925 | THFloatTensor *input_ = (THFloatTensor*)luaT_checkudata(L, 2, "torch.FloatTensor"); 1926 | THFloatTensor *output_ = (THFloatTensor*)luaT_checkudata(L, 3, "torch.FloatTensor"); 1927 | 1928 | long *index = THLongTensor_data(index_); 1929 | float *input = THFloatTensor_data(input_); 1930 | float *output = THFloatTensor_data(output_); 1931 | 1932 | const int N = 200; 1933 | 1934 | int set[N]; 1935 | for (int i = 0; i < N; i++) { 1936 | set[i] = 0; 1937 | } 1938 | 1939 | for (int i = 0; i < THLongTensor_nElement(index_); i++) { 1940 | assert(index[i] < N); 1941 | set[index[i]] = 1; 1942 | } 1943 | 1944 | int i = 0; 1945 | for (int j = 0; j < THFloatTensor_size(input_, 0); j++) { 1946 | int im = input[j * 4]; 1947 | if (set[im]) { 1948 | for (int k = 0; k < 4; k++) { 1949 | output[i * 4 + k] = input[j * 4 + k]; 1950 | } 1951 | i++; 1952 | } 1953 | } 1954 | 1955 | lua_pushinteger(L, i); 1956 | return 1; 1957 | } 1958 | 1959 | int make_dataset2(lua_State *L) 1960 | { 1961 | THFloatTensor *disp_ = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 1962 | THFloatTensor *nnz_ = (THFloatTensor*)luaT_checkudata(L, 2, "torch.FloatTensor"); 1963 | int img = luaL_checkinteger(L, 3); 1964 | int t = luaL_checkinteger(L, 4); 1965 | 1966 | float *disp = THFloatTensor_data(disp_); 1967 | float *nnz = THFloatTensor_data(nnz_); 
1968 | 1969 | int height = THFloatTensor_size(disp_, 2); 1970 | int width = THFloatTensor_size(disp_, 3); 1971 | int nnz_size = THFloatTensor_nElement(nnz_); 1972 | 1973 | for (int i = 0; i < height; i++) { 1974 | for (int j = 0; j < width; j++) { 1975 | if (disp[i * width + j] > 0.5) { 1976 | assert(t * 4 + 4 <= nnz_size); 1977 | nnz[t * 4 + 0] = img; 1978 | nnz[t * 4 + 1] = i; 1979 | nnz[t * 4 + 2] = j; 1980 | nnz[t * 4 + 3] = disp[i * width + j]; 1981 | t++; 1982 | } 1983 | } 1984 | } 1985 | 1986 | lua_pushinteger(L, t); 1987 | return 1; 1988 | } 1989 | 1990 | int make_dataset(lua_State *L) 1991 | { 1992 | THFloatTensor *x0_ = (THFloatTensor*)luaT_checkudata(L, 1, "torch.FloatTensor"); 1993 | THFloatTensor *x1_ = (THFloatTensor*)luaT_checkudata(L, 2, "torch.FloatTensor"); 1994 | THFloatTensor *disp_ = (THFloatTensor*)luaT_checkudata(L, 3, "torch.FloatTensor"); 1995 | THFloatTensor *x_ = (THFloatTensor*)luaT_checkudata(L, 4, "torch.FloatTensor"); 1996 | THFloatTensor *y_ = (THFloatTensor*)luaT_checkudata(L, 5, "torch.FloatTensor"); 1997 | int t = luaL_checkinteger(L, 6); 1998 | float thr_true = luaL_checknumber(L, 7); 1999 | float thr_false_l = luaL_checknumber(L, 8); 2000 | float thr_false_u = luaL_checknumber(L, 9); 2001 | 2002 | float *x0 = THFloatTensor_data(x0_); 2003 | float *x1 = THFloatTensor_data(x1_); 2004 | float *disp = THFloatTensor_data(disp_); 2005 | float *x = THFloatTensor_data(x_); 2006 | float *y = THFloatTensor_data(y_); 2007 | 2008 | int height = THFloatTensor_size(x0_, 2); 2009 | int width = THFloatTensor_size(x0_, 3); 2010 | int win_size = THFloatTensor_size(x_, 2); 2011 | int x_size = THFloatTensor_size(x_, 0); 2012 | assert(win_size % 2 == 1); 2013 | int win_radius = (win_size - 1) / 2; 2014 | 2015 | x += t * 2 * win_size * win_size; 2016 | for (int i = win_radius; i < height - win_radius; i++) { 2017 | for (int j = win_radius; j < width - win_radius; j++) { 2018 | if (disp[i * width + j] > 0.5) { 2019 | int d_true = round(disp[i * width + j]); 2020 | if (0 <= j - d_true - win_radius) { 2021 | /* true offset */ 2022 | int delta = 0; 2023 | for (;;) { 2024 | delta = random_int(-thr_true, thr_true); 2025 | if (0 <= j - d_true + delta - win_radius && j - d_true + delta + win_radius < width) { 2026 | break; 2027 | } 2028 | } 2029 | assert(t < x_size); 2030 | memcpy2d(x, x0, j, i, win_radius, height, width); x += win_size * win_size; 2031 | memcpy2d(x, x1, j - d_true + delta, i, win_radius, height, width); x += win_size * win_size; 2032 | y[t] = 1; 2033 | t++; 2034 | 2035 | /* false offset */ 2036 | delta = 0; 2037 | for (;;) { 2038 | delta = random_int(thr_false_l, thr_false_u); 2039 | if (random_uniform() < 0.5) { 2040 | delta = -delta; 2041 | } 2042 | if (0 <= j - d_true + delta - win_radius && j - d_true + delta + win_radius < width) { 2043 | break; 2044 | } 2045 | } 2046 | assert(t < x_size); 2047 | memcpy2d(x, x0, j, i, win_radius, height, width); x += win_size * win_size; 2048 | memcpy2d(x, x1, j - d_true + delta, i, win_radius, height, width); x += win_size * win_size; 2049 | y[t] = 0; 2050 | t++; 2051 | } 2052 | } 2053 | } 2054 | } 2055 | lua_pushinteger(L, t); 2056 | return 1; 2057 | } 2058 | 2059 | /* CPU implementation */ 2060 | int grey2jet(lua_State *L) 2061 | { 2062 | THDoubleTensor *grey_img = (THDoubleTensor*)luaT_checkudata(L, 1, "torch.DoubleTensor"); 2063 | THDoubleTensor *col_img = (THDoubleTensor*)luaT_checkudata(L, 2, "torch.DoubleTensor"); 2064 | 2065 | assert(grey_img->nDimension == 2); 2066 | if (3 * THDoubleTensor_nElement(grey_img) != 
THDoubleTensor_nElement(col_img)) { 2067 | luaL_error(L, "Size mismatch"); 2068 | } 2069 | 2070 | int height = THDoubleTensor_size(grey_img, 0); 2071 | int width = THDoubleTensor_size(grey_img, 1); 2072 | 2073 | double *gray_data = THDoubleTensor_data(grey_img); 2074 | double *col_data = THDoubleTensor_data(col_img); 2075 | 2076 | for (int i = 0; i < height; i++) { 2077 | for (int j = 0; j < width; j++) { 2078 | double val = gray_data[i * width + j] * 4; 2079 | double r = 0, g = 0, b = 0; 2080 | 2081 | if (-0.1 <= val && val < 0.5) { 2082 | r = 0; 2083 | g = 0; 2084 | b = 0.5 + val; 2085 | } else if (0.5 <= val && val < 1.5) { 2086 | r = 0; 2087 | g = val - 0.5; 2088 | b = 1; 2089 | } else if (1.5 <= val && val < 2.5) { 2090 | r = val - 1.5; 2091 | g = 1; 2092 | b = 1 - (val - 1.5); 2093 | } else if (2.5 <= val && val < 3.5) { 2094 | r = 1; 2095 | g = 1 - (val - 2.5); 2096 | b = 0; 2097 | } else if (3.5 <= val && val <= 4.1) { 2098 | r = 1 - (val - 3.5); 2099 | g = 0; 2100 | b = 0; 2101 | } else { 2102 | printf("val = %f\n", val); 2103 | assert(0); 2104 | } 2105 | 2106 | col_data[(0 * height + i) * width + j] = r; 2107 | col_data[(1 * height + i) * width + j] = g; 2108 | col_data[(2 * height + i) * width + j] = b; 2109 | } 2110 | } 2111 | return 0; 2112 | } 2113 | 2114 | static const struct luaL_Reg funcs[] = { 2115 | {"ad", ad}, 2116 | {"census", census}, 2117 | {"cross", cross}, 2118 | {"cbca", cbca}, 2119 | {"sgm", sgm}, 2120 | {"sgm2", sgm2}, 2121 | {"sgm3", sgm3}, 2122 | {"outlier_detection", outlier_detection}, 2123 | {"interpolate_occlusion", interpolate_occlusion}, 2124 | {"interpolate_mismatch", interpolate_mismatch}, 2125 | {"subpixel_enchancement", subpixel_enchancement}, 2126 | {"copy_fill", copy_fill}, 2127 | {"median2d", median2d}, 2128 | {"mean2d", mean2d}, 2129 | {"Normalize_forward", Normalize_forward}, 2130 | {"Normalize_backward_input", Normalize_backward_input}, 2131 | {"Margin2", Margin2}, 2132 | {"StereoJoin", StereoJoin}, 2133 | {"StereoL2R", StereoL2R}, 2134 | {"L2dist", L2dist}, 2135 | {"subset_dataset", subset_dataset}, 2136 | {"make_dataset", make_dataset}, 2137 | {"make_dataset2", make_dataset2}, 2138 | {"remove_nonvisible", remove_nonvisible}, 2139 | {"remove_occluded", remove_occluded}, 2140 | {"remove_white", remove_white}, 2141 | {"readPNG16", readPNG16}, 2142 | {"writePNG16", writePNG16}, 2143 | {"writePFM", writePFM}, 2144 | {"grey2jet", grey2jet}, 2145 | {"spatial_argmin", spatial_argmin}, 2146 | {NULL, NULL} 2147 | }; 2148 | 2149 | extern "C" int luaopen_libadcensus(lua_State *L) { 2150 | srand(42); 2151 | luaL_openlib(L, "adcensus", funcs, 0); 2152 | return 1; 2153 | } 2154 | --------------------------------------------------------------------------------
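Usage note: the functions registered in luaopen_libadcensus above are exposed to Lua through the adcensus table. The sketch below shows how the image I/O and filtering bindings might be called from Torch; the require name, file paths, and tensor sizes are illustrative assumptions, not values taken from this repository.

require 'cutorch'
require 'libadcensus'  -- assumed module name; loads the adcensus table registered above

-- Illustrative sizes and paths only.
local height, width = 370, 1226

-- readPNG16 fills a preallocated FloatTensor from a 16-bit PNG (tensor first, then file name);
-- the tensor is assumed to match the image dimensions.
local disp = torch.FloatTensor(height, width):zero()
adcensus.readPNG16(disp, 'disp_gt.png')

-- median2d takes a 4D CudaTensor and an odd kernel size (<= 11) and returns a filtered copy.
local d = torch.CudaTensor(1, 1, height, width):copy(disp)
local filtered = adcensus.median2d(d, 5)

-- writePNG16 expects the tensor data, its height and width, and the output path.
adcensus.writePNG16(filtered:float(), height, width, 'disp_filtered.png')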