├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake └── FindMatlab.cmake ├── data ├── 000000_10.png ├── 000000_11.png ├── frame_0001.png └── frame_0002.png ├── matlab ├── Matlabdef.def ├── dcflow.m └── demo.m ├── net ├── deploy.prototxt ├── kitti.caffemodel └── sintel.caffemodel └── src_cl ├── CMakeLists.txt ├── dcflow.cpp ├── dcflow.h ├── dcflow_mex.cpp ├── helper.h ├── sgm.cl ├── sgm.cpp └── sgm.h /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | matlab/external 3 | matlab/.DS_Store 4 | build 5 | matches*.txt 6 | *.mexa64 7 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.2) 2 | project(dcflow) 3 | 4 | set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}) 5 | set(CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) 6 | 7 | add_definitions(/DMATLAB_MEX_FILE) 8 | find_package(Matlab REQUIRED) 9 | 10 | # Build DC Flow 11 | add_subdirectory(src_cl) 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Intel Corporation 2017 4 | Jia Xu, René Ranftl, Vladlen Koltun 5 | Intel Labs 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DISCONTINUATION OF PROJECT # 2 | This project will no longer be maintained by Intel. 3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project. 4 | Intel no longer accepts patches to this project. 5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project. 6 | 7 | # Code for the paper "Accurate Optical Flow via Direct Cost Volume Processing. Jia Xu, René Ranftl, and Vladlen Koltun. CVPR 2017" 8 | 9 | If you use this code or the provided models in your research, please cite the following paper: 10 | 11 | @inproceedings{XRK2017, 12 | author = {Jia Xu and Ren\'e Ranftl and Vladlen Koltun}, 13 | title = {{Accurate Optical Flow via Direct Cost Volume Processing}}, 14 | booktitle = {CVPR}, 15 | year = {2017}, 16 | } 17 | 18 | ## Dependencies 19 | 20 | - CMake 3.2 21 | - Caffe + MatCaffe (needs to be in Matlab path, see ``matlab/demo.m``) 22 | - We have seen issues with the latest version of caffe. 
We recommend using this branch (training and testing): https://github.com/Wangyida/caffe/tree/cnn_triplet 23 | - OpenCL 24 | 25 | ## Setup 26 | 27 | - Set path to OpenCL SDK: 28 | - For Intel OpenCL set `export INTELOCLSDKROOT=<path>`, e.g., `export INTELOCLSDKROOT=/usr/local/intel/opencl` 29 | - For NVIDIA OpenCL set `export CUDA_PATH=<path>`, e.g., `export CUDA_PATH=/usr/local/cuda` 30 | - For AMD OpenCL set `export AMDAPPSDKROOT=<path>`, e.g., `export AMDAPPSDKROOT=/usr/local/amd/opencl` 31 | - Set ``MATLAB_ROOT`` environment variable, e.g., `export MATLAB_ROOT=/usr/local/MATLAB/R2017a` 32 | - ``mkdir build`` 33 | - ``cd build`` 34 | - ``cmake ..`` 35 | - ``make`` 36 | - ``make install`` 37 | 38 | ## Running the code: 39 | See ``matlab/demo.m`` 40 | 41 | ## Log 42 | - Version 0.1, 2017-07-20 43 | 44 | Includes feature embedding code/models, 4-D cost volume construction and processing, and forward-backward consistency checking. Part of the post-processing (EpicFlow inpainting, homography fitting) cannot be included due to license issues. We expect to release them in future versions. You may download the EpicFlow code (or other inpainting code), and replace the match file with our output to obtain a dense optical flow field. 45 | -------------------------------------------------------------------------------- /cmake/FindMatlab.cmake: -------------------------------------------------------------------------------- 1 | # - this module looks for Matlab 2 | # Defines: 3 | # MATLAB_INCLUDE_DIR: include path for mex.h 4 | # MATLAB_LIBRARIES: required libraries: libmex, libmx 5 | # MATLAB_MEX_LIBRARY: path to libmex 6 | # MATLAB_MX_LIBRARY: path to libmx 7 | 8 | SET(MATLAB_FOUND 0) 9 | IF( "$ENV{MATLAB_ROOT}" STREQUAL "" ) 10 | MESSAGE(STATUS "MATLAB_ROOT environment variable not set." 
) 11 | MESSAGE(STATUS "In Linux this can be done in your user .bashrc file by appending the corresponding line, e.g:" ) 12 | MESSAGE(STATUS "export MATLAB_ROOT=/usr/local/MATLAB/R2012b" ) 13 | MESSAGE(STATUS "In Windows this can be done by adding system variable, e.g:" ) 14 | MESSAGE(STATUS "MATLAB_ROOT=D:\\Program Files\\MATLAB\\R2011a" ) 15 | ELSE("$ENV{MATLAB_ROOT}" STREQUAL "" ) 16 | 17 | FIND_PATH(MATLAB_INCLUDE_DIR mex.h 18 | $ENV{MATLAB_ROOT}/extern/include) 19 | 20 | INCLUDE_DIRECTORIES(${MATLAB_INCLUDE_DIR}) 21 | 22 | FIND_LIBRARY( MATLAB_MEX_LIBRARY 23 | NAMES libmex mex 24 | PATHS $ENV{MATLAB_ROOT}/bin $ENV{MATLAB_ROOT}/extern/lib 25 | PATH_SUFFIXES glnxa64 glnx86 win64/microsoft win32/microsoft) 26 | 27 | FIND_LIBRARY( MATLAB_MX_LIBRARY 28 | NAMES libmx mx 29 | PATHS $ENV{MATLAB_ROOT}/bin $ENV{MATLAB_ROOT}/extern/lib 30 | PATH_SUFFIXES glnxa64 glnx86 win64/microsoft win32/microsoft) 31 | 32 | MESSAGE (STATUS "MATLAB_ROOT: $ENV{MATLAB_ROOT}") 33 | 34 | ENDIF("$ENV{MATLAB_ROOT}" STREQUAL "" ) 35 | 36 | # This is common to UNIX and Win32: 37 | SET(MATLAB_LIBRARIES 38 | ${MATLAB_MEX_LIBRARY} 39 | ${MATLAB_MX_LIBRARY} 40 | ) 41 | 42 | IF(MATLAB_INCLUDE_DIR AND MATLAB_LIBRARIES) 43 | SET(MATLAB_FOUND 1) 44 | MESSAGE(STATUS "Matlab libraries will be used") 45 | ENDIF(MATLAB_INCLUDE_DIR AND MATLAB_LIBRARIES) 46 | 47 | MARK_AS_ADVANCED( 48 | MATLAB_LIBRARIES 49 | MATLAB_MEX_LIBRARY 50 | MATLAB_MX_LIBRARY 51 | MATLAB_INCLUDE_DIR 52 | MATLAB_FOUND 53 | MATLAB_ROOT 54 | ) 55 | -------------------------------------------------------------------------------- /data/000000_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/dcflow/12b4df2ab761f80ddd895b84e1c71a88f2a140d6/data/000000_10.png -------------------------------------------------------------------------------- /data/000000_11.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/isl-org/dcflow/12b4df2ab761f80ddd895b84e1c71a88f2a140d6/data/000000_11.png -------------------------------------------------------------------------------- /data/frame_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/dcflow/12b4df2ab761f80ddd895b84e1c71a88f2a140d6/data/frame_0001.png -------------------------------------------------------------------------------- /data/frame_0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/dcflow/12b4df2ab761f80ddd895b84e1c71a88f2a140d6/data/frame_0002.png -------------------------------------------------------------------------------- /matlab/Matlabdef.def: -------------------------------------------------------------------------------- 1 | EXPORTS mexFunction -------------------------------------------------------------------------------- /matlab/dcflow.m: -------------------------------------------------------------------------------- 1 | function matches = dcflow(im1_ori, im2_ori, param, net) 2 | maxDisp = param.maxDisp; ratio = param.ratio; 3 | im1 = im2double(im1_ori); im2 = im2double(im2_ori); 4 | im1 = max(min(imresize(im1(1:end-mod(size(im1_ori,1),ratio),1:end-mod(size(im1_ori,2),ratio),:),1/ratio),1),0); 5 | im2 = max(min(imresize(im2(1:end-mod(size(im2_ori,1),ratio),1:end-mod(size(im2_ori,2),ratio),:),1/ratio),1),0); 6 | 7 | [M, N, ~] = size(im1); 8 | range = -floor(maxDisp/ratio):floor(maxDisp/ratio); 9 | r_max = range(end); 10 | 11 | im1_ = zeros(M+8, N+8, 3, 'single'); im2_ = im1_; 12 | for i = 1:3 13 | p1 = single(im1(:,:,i)); p2 = single(im2(:,:,i)); 14 | im1_(:,:,i) = padarray((p1 - mean(p1(:)))/std(p1(:)), [4, 4], 'symmetric', 'both'); 15 | im2_(:,:,i) = padarray((p2 - mean(p2(:)))/std(p2(:)), [4, 4], 'symmetric', 'both'); 16 | end 17 | 18 | im1_ = permute(im1_, [2,1,3]); im2_ = permute(im2_, [2,1,3]); 19 | [M2, N2, 
~] = size(im1_); 20 | net.blobs('data').reshape([M2 N2 3 2]); net.reshape(); 21 | feat = net.forward({cat(4, im1_, im2_)}); 22 | feat_1 = feat{1}(:,:,:,1); feat_2 = feat{1}(:,:,:,2); 23 | feat_1 = permute(feat_1, [2,1,3]); feat_2 = permute(feat_2, [2,1,3]); 24 | feat_1_n = sqrt(sum(feat_1.^2, 3)+1e-12); feat_2_n = sqrt(sum(feat_2.^2, 3)+1e-12); 25 | feat_1 = bsxfun(@rdivide, feat_1, feat_1_n); feat_2 = bsxfun(@rdivide, feat_2, feat_2_n); 26 | 27 | [forward, backward] = dcflow_mex(feat_1, feat_2, ... 28 | single(rgb2gray(im1)), single(rgb2gray(im2)), ... 29 | r_max, param.outOfRange, param.P1, param.P2); 30 | 31 | forward = single(permute(forward, [2, 3, 1])); backward = single(permute(backward, [2, 3, 1])); 32 | matches = filter_matches(forward, backward, ratio, param.occ_threshold); 33 | end 34 | 35 | function [matches, occ] = filter_matches(forward, backward, ratio, threshold) 36 | [m,n,~] = size(forward); 37 | u = forward(:,:,1); v = forward(:,:,2); 38 | [x1, y1] = meshgrid(1:n, 1:m); 39 | x2 = x1 + u; y2 = y1 + v; 40 | 41 | % out-of-boundary pixels 42 | B = (x2>n) | (x2<1) | (y2>m) | (y2<1); 43 | x2(B) = x1(B); y2(B) = y1(B); 44 | inv_u = interp2(backward(:,:,1), x2, y2, 'linear', 0); inv_v = interp2(backward(:,:,2), x2, y2, 'linear', 0); 45 | occ = ((u+inv_u).^2 + (v+inv_v).^2) <= threshold; occ(B) = 0; 46 | 47 | % count bump in the second image 48 | count = zeros(m,n); 49 | for r=1:m, 50 | for c=1:n 51 | c1 = round(x2(r,c)); r1 = round(y2(r,c)); count(r1, c1) = count(r1, c1)+1; 52 | end 53 | end 54 | occ(count>=2) = 0; 55 | 56 | j = x1; i = y1; 57 | matches = [ratio*j(:) - ceil(ratio/2), ratio*i(:) - ceil(ratio/2), ratio*(j(:)+u(:)) - ceil(ratio/2), ratio*(i(:)+v(:)) - ceil(ratio/2)]; 58 | valid = find(occ); 59 | matches = matches(valid, :); 60 | end 61 | -------------------------------------------------------------------------------- /matlab/demo.m: -------------------------------------------------------------------------------- 1 | close all; clear all; 2 | 
3 | addpath('external/caffe/matlab') 4 | 5 | param.model_file = '../net/deploy.prototxt'; 6 | param.maxDisp = 242; % maximum displacement in both the x and y directions 7 | param.ratio = 3; % downsample scale 8 | param.P1 = 7; % SGM param 9 | param.P2 = 485; % SGM param 10 | param.outOfRange = 0.251; % Default cost for out-of-range displacements 11 | param.occ_threshold = 0.8; % threshold for fwd+bwd consistency check 12 | 13 | if 1 %% KITTI 14 | param.weight_file = '../net/kitti.caffemodel'; 15 | im1 = imread('../data/000000_10.png'); 16 | im2 = imread('../data/000000_11.png'); 17 | else %% Sintel 18 | param.P2 = 600; 19 | param.weight_file = '../net/sintel.caffemodel'; 20 | im1 = imread('../data/frame_0001.png'); 21 | im2 = imread('../data/frame_0002.png'); 22 | end 23 | 24 | caffe.set_mode_gpu(); 25 | net = caffe.Net(param.model_file, param.weight_file, 'test'); 26 | 27 | matches = dcflow(im1, im2, param, net); 28 | 29 | dlmwrite('matches.txt', matches, ' '); 30 | -------------------------------------------------------------------------------- /net/deploy.prototxt: -------------------------------------------------------------------------------- 1 | name: "train_triplet_9" 2 | 3 | input: "data" 4 | input_shape { 5 | dim: 2 6 | dim: 3 7 | dim: 100 8 | dim: 100 9 | } 10 | 11 | layer { 12 | name: "Convolution1" 13 | type: "Convolution" 14 | bottom: "data" 15 | top: "Convolution1" 16 | convolution_param { 17 | num_output: 64 18 | kernel_size: 3 19 | weight_filler { 20 | type: "xavier" 21 | } 22 | } 23 | } 24 | layer { 25 | name: "conv1" 26 | type: "ReLU" 27 | bottom: "Convolution1" 28 | top: "conv1" 29 | } 30 | layer { 31 | name: "Convolution2" 32 | type: "Convolution" 33 | bottom: "conv1" 34 | top: "Convolution2" 35 | convolution_param { 36 | num_output: 64 37 | kernel_size: 3 38 | weight_filler { 39 | type: "xavier" 40 | } 41 | } 42 | } 43 | layer { 44 | name: "conv2" 45 | type: "ReLU" 46 | bottom: "Convolution2" 47 | top: "conv2" 48 | } 49 | layer { 50 | name: 
"Convolution3" 51 | type: "Convolution" 52 | bottom: "conv2" 53 | top: "Convolution3" 54 | convolution_param { 55 | num_output: 64 56 | kernel_size: 3 57 | weight_filler { 58 | type: "xavier" 59 | } 60 | } 61 | } 62 | layer { 63 | name: "conv3" 64 | type: "ReLU" 65 | bottom: "Convolution3" 66 | top: "conv3" 67 | } 68 | layer { 69 | name: "conv4" 70 | type: "Convolution" 71 | bottom: "conv3" 72 | top: "conv4" 73 | convolution_param { 74 | num_output: 64 75 | kernel_size: 3 76 | weight_filler { 77 | type: "xavier" 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /net/kitti.caffemodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/dcflow/12b4df2ab761f80ddd895b84e1c71a88f2a140d6/net/kitti.caffemodel -------------------------------------------------------------------------------- /net/sintel.caffemodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/dcflow/12b4df2ab761f80ddd895b84e1c71a88f2a140d6/net/sintel.caffemodel -------------------------------------------------------------------------------- /src_cl/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(OpenCL REQUIRED) 2 | 3 | message(STATUS ${OpenCL_INCLUDE_DIRS}) 4 | message(STATUS ${OpenCL_LIBRARY}) 5 | 6 | set(CL_FILE dcflow) 7 | 8 | include_directories(${MATLAB_INCLUDE_DIR} ${OpenCL_INCLUDE_DIRS}) 9 | 10 | add_library(${CL_FILE} SHARED 11 | dcflow_mex.cpp 12 | dcflow.cpp 13 | sgm.cpp 14 | ${CMAKE_SOURCE_DIR}/matlab/Matlabdef.def) 15 | target_link_libraries(${CL_FILE} ${MATLAB_LIBRARIES} ${OpenCL_LIBRARIES} dl) 16 | 17 | set_target_properties(${CL_FILE} PROPERTIES PREFIX "" LINKER_LANGUAGE CXX) 18 | set_target_properties(${CL_FILE} PROPERTIES OUTPUT_NAME "dcflow_mex") 19 | 20 | set_property(TARGET ${CL_FILE} PROPERTY CXX_STANDARD 11) 21 | 
set_property(TARGET ${CL_FILE} PROPERTY CXX_STANDARD_REQUIRED ON) 22 | 23 | # 32-bit or 64-bit mex 24 | if(WIN32) 25 | if (CMAKE_CL_64) 26 | set_target_properties(${CL_FILE} PROPERTIES SUFFIX .mexw64) 27 | else(CMAKE_CL_64) 28 | set_target_properties(${CL_FILE} PROPERTIES SUFFIX .mexw32) 29 | endif(CMAKE_CL_64) 30 | else(WIN32) 31 | if (CMAKE_SIZEOF_VOID_P MATCHES "8") 32 | set_target_properties(${CL_FILE} PROPERTIES SUFFIX .mexa64 PREFIX "") 33 | else(CMAKE_SIZEOF_VOID_P MATCHES "8") 34 | set_target_properties(${CL_FILE} PROPERTIES SUFFIX .mexglx PREFIX "") 35 | endif (CMAKE_SIZEOF_VOID_P MATCHES "8") 36 | endif(WIN32) 37 | 38 | install(TARGETS ${CL_FILE} DESTINATION matlab) 39 | -------------------------------------------------------------------------------- /src_cl/dcflow.cpp: -------------------------------------------------------------------------------- 1 | #include "dcflow.h" 2 | #include "sgm.h" 3 | #include "helper.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #define PROFILE 10 | 11 | #ifdef PROFILE 12 | using namespace std::chrono; 13 | #endif 14 | 15 | void sgmflow(const float *feat1, const float *feat2, 16 | const float *im1, const float *im2, 17 | int M, int N, int L, 18 | int max_offset, float out_of_range, 19 | int P1, int P2, 20 | int16_t *unary1, int16_t *unary2) { 21 | 22 | std::vector all_plat; 23 | 24 | std::cout << M << "/" << N << "/" << max_offset << std::endl; 25 | 26 | cl::Platform::get(&all_plat); 27 | 28 | if (all_plat.size() == 0) { 29 | std::cout << "ERROR: No OpenCL platforms found, check OpenCL installation" 30 | << std::endl; 31 | return; 32 | } 33 | cl::Platform default_platform = all_plat[0]; 34 | 35 | std::vector all_dev; 36 | default_platform.getDevices(CL_DEVICE_TYPE_GPU, &all_dev); 37 | 38 | if (all_dev.size() == 0) { 39 | std::cout << "ERROR: no OpenCL device found" << std::endl; 40 | return; 41 | } 42 | 43 | cl::Device default_device = all_dev[0]; 44 | 45 | cl::Context context(default_device); 46 | 47 | 
cl::Program::Sources sources; 48 | 49 | std::string kernels = readKernel("../src_cl/sgm.cl"); 50 | 51 | sources.push_back({kernels.c_str(), kernels.length()}); 52 | 53 | cl::Program program(context, sources); 54 | if (program.build({default_device}, "-cl-mad-enable -cl-fast-relaxed-math -cl-single-precision-constant") != CL_SUCCESS) { 55 | std::cout<<"ERROR: Error building: "<< 56 | program.getBuildInfo(default_device)<< std::endl; 57 | } 58 | 59 | cl::CommandQueue queue(context, default_device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE); 60 | 61 | // Pitched memory doesn't help 62 | // Input data 63 | cl_int err; 64 | 65 | cl::Buffer unary1_d(context, CL_MEM_READ_WRITE, L*M*N*sizeof(uint8_t), NULL, &err); 66 | CL_CHECK_ERR_R(queue.enqueueFillBuffer(unary1_d, uint8_t(255*out_of_range), 0, L*M*N*sizeof(uint8_t))); 67 | 68 | cl::Buffer unary2_d(context, CL_MEM_READ_WRITE, L*M*N*sizeof(uint8_t)); 69 | CL_CHECK_ERR_R(queue.enqueueFillBuffer(unary2_d, uint8_t(255*out_of_range), 0, L*M*N*sizeof(uint8_t))); 70 | 71 | cl::Buffer feat1_d(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, FEAT_DIM*M*N*sizeof(float), (void*)feat1); 72 | cl::Buffer feat2_d(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, FEAT_DIM*M*N*sizeof(float), (void*)feat2); 73 | 74 | CL_CHECK_ERR_R(queue.finish()); 75 | 76 | // Match 77 | #ifdef PROFILE 78 | high_resolution_clock::time_point match_start = high_resolution_clock::now(); 79 | #endif 80 | 81 | std::vector match_evt(1); 82 | cl::Kernel match_kernel(program, "match"); 83 | match_kernel.setArg(0, feat1_d); match_kernel.setArg(1, feat2_d); 84 | match_kernel.setArg(2, M); match_kernel.setArg(3, N); 85 | match_kernel.setArg(4, max_offset); 86 | match_kernel.setArg(5, unary1_d); match_kernel.setArg(6, unary2_d); 87 | CL_CHECK_ERR_R(queue.enqueueNDRangeKernel(match_kernel, cl::NDRange(3, 3), 88 | cl::NDRange(M-6, N-6), 89 | cl::NullRange, NULL, &match_evt[0])); 90 | // queue.enqueueNDRangeKernel(match_kernel, cl::NullRange, 91 | // 
cl::NDRange(M,DIVUP(N,4)), 92 | // cl::NDRange(M,4), NULL, &match_evt[0]); 93 | 94 | // Fix border 95 | cl::Kernel fix_kernel(program, "fix_border"); 96 | fix_kernel.setArg(0, unary1_d); fix_kernel.setArg(1, unary2_d); 97 | fix_kernel.setArg(2, M); fix_kernel.setArg(3, N); fix_kernel.setArg(4, L); 98 | CL_CHECK_ERR_R(queue.enqueueNDRangeKernel(fix_kernel, cl::NullRange, 99 | cl::NDRange(M, N), 100 | cl::NullRange, &match_evt, NULL)); 101 | 102 | // Set up buffers for SGM 103 | cl::Buffer im1_d(context, CL_MEM_READ_ONLY, M*N*sizeof(float)); 104 | CL_CHECK_ERR_R(queue.enqueueWriteBuffer(im1_d, CL_FALSE, 0, M*N*sizeof(float), im1)); 105 | 106 | cl::Buffer im2_d(context, CL_MEM_READ_ONLY, M*N*sizeof(float)); 107 | CL_CHECK_ERR_R(queue.enqueueWriteBuffer(im2_d, CL_FALSE, 0, M*N*sizeof(float), im2)); 108 | 109 | #ifdef PROFILE 110 | match_evt[0].wait(); 111 | high_resolution_clock::time_point match_end = high_resolution_clock::now(); 112 | 113 | auto match_dur = duration_cast(match_end - match_start).count(); 114 | 115 | std::cout << "Matching: " << match_dur << " ms" << std::endl; 116 | #endif 117 | 118 | #ifdef PROFILE 119 | high_resolution_clock::time_point sgm_start = high_resolution_clock::now(); 120 | #endif 121 | 122 | // Set up SGM 123 | SGM sgm(M, N, L, max_offset, P1, P2, queue, context, program); 124 | 125 | // Forward 126 | sgm.process(unary1_d, im1_d); 127 | CL_CHECK_ERR_R(queue.enqueueReadBuffer(*sgm.recoverFlow(), CL_FALSE, 0, 128 | M*N*sizeof(cl_short2), unary1)); 129 | 130 | // Backward 131 | sgm.process(unary2_d, im2_d); 132 | CL_CHECK_ERR_R(queue.enqueueReadBuffer(*sgm.recoverFlow(), CL_FALSE, 0, 133 | M*N*sizeof(cl_short2), unary2)); 134 | 135 | CL_CHECK_ERR_R(queue.finish()); 136 | 137 | 138 | #ifdef PROFILE 139 | high_resolution_clock::time_point sgm_end = high_resolution_clock::now(); 140 | 141 | auto sgm_dur = duration_cast(sgm_end - sgm_start).count(); 142 | 143 | std::cout << "SGM (fw+bw): " << sgm_dur << " ms" << std::endl; 144 | #endif 145 | } 
146 | 147 | 148 | void wtaflow(const float *feat1, const float *feat2, 149 | int M, int N, int L, 150 | int max_offset, float out_of_range, 151 | int P1, int P2, 152 | uint8_t *unary1, uint8_t *unary2) { 153 | 154 | std::vector all_plat; 155 | 156 | 157 | std::cout << M << "/" << N << "/" << max_offset << std::endl; 158 | 159 | cl::Platform::get(&all_plat); 160 | 161 | if (all_plat.size() == 0) { 162 | std::cout << "ERROR: No OpenCL platforms found, check OpenCL installation" 163 | << std::endl; 164 | return; 165 | } 166 | cl::Platform default_platform = all_plat[0]; 167 | 168 | std::vector all_dev; 169 | default_platform.getDevices(CL_DEVICE_TYPE_GPU, &all_dev); 170 | 171 | if (all_dev.size() == 0) { 172 | std::cout << "ERROR: no OpenCL device found" << std::endl; 173 | return; 174 | } 175 | 176 | cl::Device default_device = all_dev[0]; 177 | 178 | cl::Context context(default_device); 179 | 180 | cl::Program::Sources sources; 181 | 182 | std::string kernels = readKernel("../src_cl/sgm.cl"); 183 | 184 | sources.push_back({kernels.c_str(), kernels.length()}); 185 | 186 | cl::Program program(context, sources); 187 | if (program.build({default_device}) != CL_SUCCESS) { 188 | std::cout<<"ERROR: Error building: "<< 189 | program.getBuildInfo(default_device)<< std::endl; 190 | } 191 | 192 | cl::CommandQueue queue(context, default_device, CL_QUEUE_PROFILING_ENABLE); 193 | 194 | 195 | cl::Buffer unary1_d(context, CL_MEM_READ_WRITE, L*M*N*sizeof(uint8_t)); 196 | queue.enqueueFillBuffer(unary1_d, uint8_t(255*out_of_range), 0, L*M*N); 197 | 198 | cl::Buffer unary2_d(context, CL_MEM_READ_WRITE, L*M*N*sizeof(uint8_t)); 199 | queue.enqueueFillBuffer(unary2_d, uint8_t(255*out_of_range), 0, L*M*N); 200 | 201 | cl::Buffer feat1_d(context, CL_MEM_READ_ONLY, FEAT_DIM*M*N*sizeof(float)); 202 | queue.enqueueWriteBuffer(feat1_d, CL_TRUE, 0, FEAT_DIM*M*N*sizeof(float), feat1); 203 | 204 | cl::Buffer feat2_d(context, CL_MEM_READ_ONLY, FEAT_DIM*M*N*sizeof(float)); 205 | 
queue.enqueueWriteBuffer(feat2_d, CL_TRUE, 0, FEAT_DIM*M*N*sizeof(float), feat2); 206 | 207 | auto match_kernel = cl::make_kernel(program, "match"); 210 | 211 | cl::EnqueueArgs eargs(queue, 212 | cl::NullRange, 213 | cl::NDRange(M, N), 214 | cl::NullRange); 215 | 216 | match_kernel(eargs, feat1_d, feat2_d, M, N, max_offset, 217 | unary1_d, unary2_d); 218 | 219 | queue.enqueueReadBuffer(unary1_d, CL_TRUE, 0, L*M*N*sizeof(uint8_t), unary1); 220 | queue.enqueueReadBuffer(unary2_d, CL_TRUE, 0, L*M*N*sizeof(uint8_t), unary2); 221 | } 222 | -------------------------------------------------------------------------------- /src_cl/dcflow.h: -------------------------------------------------------------------------------- 1 | #ifndef SGMFLOW_MEX 2 | #define SGMFLOW_MEX 3 | 4 | #include 5 | 6 | #define MAX_SOURCE_SIZE (0x100000) 7 | 8 | #define BLOCKDIM_X 8 9 | #define BLOCKDIM_Y 32 10 | 11 | #define FEAT_DIM 64 12 | 13 | void sgmflow(const float *feat1, const float *feat2, 14 | const float *im1, const float *im2, 15 | int M, int N, int L, 16 | int max_offset, float out_of_range, 17 | int P1, int P2, 18 | int16_t *unary1, int16_t *unary2); 19 | 20 | #endif /* end of include guard: */ 21 | -------------------------------------------------------------------------------- /src_cl/dcflow_mex.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "dcflow.h" 5 | 6 | void mexFunction(int nlhs, mxArray *plhs[], 7 | int nrhs, const mxArray *prhs[]) 8 | { 9 | if (nrhs < 8) 10 | mexErrMsgTxt("Wrong number of input arguments"); 11 | 12 | const mwSize *dims = mxGetDimensions(prhs[0]); 13 | 14 | const int M = dims[0]; 15 | const int N = dims[1]; 16 | 17 | float *feat1 = (float *)mxGetData(prhs[0]); 18 | float *feat2 = (float *)mxGetData(prhs[1]); 19 | 20 | float *im1 = (float *)mxGetData(prhs[2]); 21 | float *im2 = (float *)mxGetData(prhs[3]); 22 | 23 | int max_offset = mxGetScalar(prhs[4]); 24 | float 
out_of_range = mxGetScalar(prhs[5]); 25 | 26 | int P1 = (int)mxGetScalar(prhs[6]); 27 | int P2 = (int)mxGetScalar(prhs[7]); 28 | 29 | int n_disps = 2*max_offset+1; 30 | n_disps = n_disps*n_disps; 31 | 32 | const mwSize dims_out[] = {2, M, N}; 33 | 34 | plhs[0] = mxCreateNumericArray(3, dims_out, mxINT16_CLASS, mxREAL); 35 | int16_t *unary1 = (int16_t*)mxGetData(plhs[0]); 36 | 37 | plhs[1] = mxCreateNumericArray(3, dims_out, mxINT16_CLASS, mxREAL); 38 | int16_t *unary2 = (int16_t*)mxGetData(plhs[1]); 39 | 40 | sgmflow(feat1, feat2, im1, im2, M, N, n_disps, max_offset, out_of_range, 41 | P1, P2, unary1, unary2); 42 | } 43 | -------------------------------------------------------------------------------- /src_cl/helper.h: -------------------------------------------------------------------------------- 1 | #ifndef HELPER_H 2 | #define HELPER_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #define CL_CHECK_ERR(_err) \ 11 | if(_err) { \ 12 | std::cout << "OpenCL error in Line " << __LINE__ << " : " << getErrorString(_err) << std::endl; \ 13 | } \ 14 | 15 | #define CL_CHECK_ERR_R(_err) \ 16 | if(_err) { \ 17 | std::cout << "OpenCL error in Line " << __LINE__ << " : " << getErrorString(_err) << std::endl; \ 18 | return; \ 19 | } \ 20 | 21 | inline std::string readKernel(const std::string &fileName) 22 | { 23 | std::ifstream ifs(fileName.c_str(), std::ios::in | std::ios::binary | std::ios::ate); 24 | 25 | std::ifstream::pos_type fileSize = ifs.tellg(); 26 | ifs.seekg(0, std::ios::beg); 27 | 28 | std::vector bytes(fileSize); 29 | ifs.read(&bytes[0], fileSize); 30 | 31 | return std::string(&bytes[0], fileSize); 32 | } 33 | 34 | inline std::string getErrorString(cl_int error) 35 | { 36 | switch(error){ 37 | // run-time and JIT compiler errors 38 | case 0: return "CL_SUCCESS"; 39 | case -1: return "CL_DEVICE_NOT_FOUND"; 40 | case -2: return "CL_DEVICE_NOT_AVAILABLE"; 41 | case -3: return "CL_COMPILER_NOT_AVAILABLE"; 42 | case -4: return 
"CL_MEM_OBJECT_ALLOCATION_FAILURE"; 43 | case -5: return "CL_OUT_OF_RESOURCES"; 44 | case -6: return "CL_OUT_OF_HOST_MEMORY"; 45 | case -7: return "CL_PROFILING_INFO_NOT_AVAILABLE"; 46 | case -8: return "CL_MEM_COPY_OVERLAP"; 47 | case -9: return "CL_IMAGE_FORMAT_MISMATCH"; 48 | case -10: return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; 49 | case -11: return "CL_BUILD_PROGRAM_FAILURE"; 50 | case -12: return "CL_MAP_FAILURE"; 51 | case -13: return "CL_MISALIGNED_SUB_BUFFER_OFFSET"; 52 | case -14: return "CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST"; 53 | case -15: return "CL_COMPILE_PROGRAM_FAILURE"; 54 | case -16: return "CL_LINKER_NOT_AVAILABLE"; 55 | case -17: return "CL_LINK_PROGRAM_FAILURE"; 56 | case -18: return "CL_DEVICE_PARTITION_FAILED"; 57 | case -19: return "CL_KERNEL_ARG_INFO_NOT_AVAILABLE"; 58 | 59 | // compile-time errors 60 | case -30: return "CL_INVALID_VALUE"; 61 | case -31: return "CL_INVALID_DEVICE_TYPE"; 62 | case -32: return "CL_INVALID_PLATFORM"; 63 | case -33: return "CL_INVALID_DEVICE"; 64 | case -34: return "CL_INVALID_CONTEXT"; 65 | case -35: return "CL_INVALID_QUEUE_PROPERTIES"; 66 | case -36: return "CL_INVALID_COMMAND_QUEUE"; 67 | case -37: return "CL_INVALID_HOST_PTR"; 68 | case -38: return "CL_INVALID_MEM_OBJECT"; 69 | case -39: return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; 70 | case -40: return "CL_INVALID_IMAGE_SIZE"; 71 | case -41: return "CL_INVALID_SAMPLER"; 72 | case -42: return "CL_INVALID_BINARY"; 73 | case -43: return "CL_INVALID_BUILD_OPTIONS"; 74 | case -44: return "CL_INVALID_PROGRAM"; 75 | case -45: return "CL_INVALID_PROGRAM_EXECUTABLE"; 76 | case -46: return "CL_INVALID_KERNEL_NAME"; 77 | case -47: return "CL_INVALID_KERNEL_DEFINITION"; 78 | case -48: return "CL_INVALID_KERNEL"; 79 | case -49: return "CL_INVALID_ARG_INDEX"; 80 | case -50: return "CL_INVALID_ARG_VALUE"; 81 | case -51: return "CL_INVALID_ARG_SIZE"; 82 | case -52: return "CL_INVALID_KERNEL_ARGS"; 83 | case -53: return "CL_INVALID_WORK_DIMENSION"; 84 | case -54: 
return "CL_INVALID_WORK_GROUP_SIZE"; 85 | case -55: return "CL_INVALID_WORK_ITEM_SIZE"; 86 | case -56: return "CL_INVALID_GLOBAL_OFFSET"; 87 | case -57: return "CL_INVALID_EVENT_WAIT_LIST"; 88 | case -58: return "CL_INVALID_EVENT"; 89 | case -59: return "CL_INVALID_OPERATION"; 90 | case -60: return "CL_INVALID_GL_OBJECT"; 91 | case -61: return "CL_INVALID_BUFFER_SIZE"; 92 | case -62: return "CL_INVALID_MIP_LEVEL"; 93 | case -63: return "CL_INVALID_GLOBAL_WORK_SIZE"; 94 | case -64: return "CL_INVALID_PROPERTY"; 95 | case -65: return "CL_INVALID_IMAGE_DESCRIPTOR"; 96 | case -66: return "CL_INVALID_COMPILER_OPTIONS"; 97 | case -67: return "CL_INVALID_LINKER_OPTIONS"; 98 | case -68: return "CL_INVALID_DEVICE_PARTITION_COUNT"; 99 | 100 | // extension errors 101 | case -1000: return "CL_INVALID_GL_SHAREGROUP_REFERENCE_KHR"; 102 | case -1001: return "CL_PLATFORM_NOT_FOUND_KHR"; 103 | case -1002: return "CL_INVALID_D3D10_DEVICE_KHR"; 104 | case -1003: return "CL_INVALID_D3D10_RESOURCE_KHR"; 105 | case -1004: return "CL_D3D10_RESOURCE_ALREADY_ACQUIRED_KHR"; 106 | case -1005: return "CL_D3D10_RESOURCE_NOT_ACQUIRED_KHR"; 107 | default: return "Unknown OpenCL error"; 108 | } 109 | } 110 | 111 | #endif // HELPER_H 112 | -------------------------------------------------------------------------------- /src_cl/sgm.cl: -------------------------------------------------------------------------------- 1 | #define FEAT_DIM 64 2 | #define USE_ATOMIC_ADD 3 | 4 | void atomicAddShort(__global ushort *address, ushort val) { 5 | __global unsigned int *base_address = (__global unsigned int *) ((__global char *)address - ((size_t)address & 2)); 6 | unsigned int long_val = ((size_t)address & 2) ? 
((unsigned int)val << 16) : (unsigned short)val; 7 | unsigned int long_old = atomic_add(base_address, long_val); 8 | } 9 | 10 | void partial_reduce(const ushort val, const __private y, const __private l, 11 | const __private int M, const __private int L, const __private int ps, 12 | __global ushort *out, __local ushort *sdata) { 13 | const int yid = get_local_id(0); 14 | const int tid = get_local_id(1); 15 | 16 | if (y < M) { 17 | const int yoff = yid*get_local_size(1); 18 | sdata[yoff + tid] = (l < L) ? val : USHRT_MAX; 19 | 20 | barrier(CLK_LOCAL_MEM_FENCE); 21 | 22 | for (unsigned int s = get_local_size(1)/2; s > 0; s >>= 1) { 23 | if (tid < s) 24 | sdata[yoff+tid] = min(sdata[yoff+tid], sdata[yoff+tid+s]); 25 | barrier(CLK_LOCAL_MEM_FENCE); 26 | } 27 | 28 | if (tid == 0) out[y*get_num_groups(1) + get_group_id(1)] = sdata[yoff]; 29 | } 30 | } 31 | 32 | __constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST; 33 | 34 | ushort sgm_step(const __global float *refImg, 35 | const __global ushort *minvals, 36 | __global ushort *out, __read_only image3d_t tmp_in, 37 | __private int M, __private int N, __private int L, __private int sL, 38 | __private int x, __private int y, __private int dx, __private int dy, 39 | __private int xyps, 40 | __private int inc, __private ushort P1, __private ushort P2) { 41 | const float D1 = fabs(refImg[M*x+y] - refImg[M*x+y-inc]); 42 | const ushort P2_ = (D1 < 0.02f) ? 
P2 : P2/3.0f;

  const ushort old_min = minvals[x*M+y-inc];
  const ushort P2cost = old_min + P2_;

  ushort prev = min(P2cost,
      (ushort)read_imageui(tmp_in, sampler, (int4)(xyps,dx,dy,0)).x);

  ushort prev2 =
      (min((ushort)read_imageui(tmp_in, sampler, (int4)(xyps,dx,dy-1,0)).x,
           (ushort)read_imageui(tmp_in, sampler, (int4)(xyps,dx,dy+1,0)).x));

  ushort prev3 =
      (min((ushort)read_imageui(tmp_in, sampler, (int4)(xyps,dx+1,dy,0)).x,
           (ushort)read_imageui(tmp_in, sampler, (int4)(xyps,dx-1,dy,0)).x));

  return min(prev, add_sat(P1, min(prev3, prev2))) - old_min;
}

// One SGM update of image column `x` for the horizontal passes.
//
// Global dimension 0 runs over rows y, global dimension 1 over labels l
// (l = dy*sL + dx). For every (y, l) the kernel
//   1. takes the unary cost cost[(l*N + x)*M + y],
//   2. adds the smoothness term from the previous column (skipped for the
//      first column of the pass, i.e. x == 0 forward / x == N-1 backward),
//   3. stores the path cost in `tmp_out` for the next column,
//   4. accumulates it into out[(x*M + y)*L + l],
//   5. min-reduces the path costs over the labels of each work-group into
//      `scratch` (finished by min_reduce_horz_fin).
__kernel void sgm_slice_horz(const __global uchar *cost,
    const __global float *refImg, const __global ushort *minvals,
    __private int M, __private int N, __private int L,
    __private int sL, __private int ps,
    __private int P1, __private int P2, __global ushort *out,
    __read_only image3d_t tmp_in, __write_only image3d_t tmp_out,
    __private int x, __private int inc,
    __global ushort *scratch, __local ushort *sdata) {

  const int y = get_global_id(0);
  const int l = get_global_id(1);

  // Padding work-items (y >= M or l >= L) touch no global memory but still
  // enter the reduction below with a neutral value.
  ushort outval = USHRT_MAX;

  if ((y < M) & (l < L)) {
    outval = cost[(l*N + x)*M + y];

    const int dx = l % sL;
    const int dy = l / sL;

    if ((inc < 0) ? (x < N-1) : x > 0) {
      outval += sgm_step(refImg, minvals, out, tmp_in, M, N, L, sL,
          x, y, dx, dy, y, inc, P1, P2);
    }

    write_imageui(tmp_out, (int4)(y, dx, dy, 0), outval);

    // Write result
#ifdef USE_ATOMIC_ADD
    atomicAddShort(&(out[(x*M + y)*L + l]), outval);
#else
    out[(x*M + y)*L + l] += outval;
#endif
  }

  // Partial reduction for next pass. Called by every work-item: the
  // reduction uses barriers and reads the local slot of each work-item in the
  // group, so skipping the padding work-items would leave those slots
  // uninitialised.
  partial_reduce(outval, y, l, M, L, ps, scratch, sdata);
}

// One SGM update of image row `y` for the vertical passes. Same scheme as
// sgm_slice_horz with the roles of x and y exchanged: global dimension 0 runs
// over columns x and the slice images are indexed by x.
__kernel void sgm_slice_vert(const __global uchar *cost,
    const __global float *refImg, const __global ushort *minvals,
    __private int M, __private int N, __private int L,
    __private int sL, __private int ps,
    __private int P1, __private int P2, __global ushort *out,
    __read_only image3d_t tmp_in, __write_only image3d_t tmp_out,
    __private int y, __private int inc,
    __global ushort *scratch, __local ushort *sdata) {

  const int l = get_global_id(1);
  const int x = get_global_id(0);

  // Padding work-items (x >= N or l >= L) touch no global memory but still
  // enter the reduction below with a neutral value.
  ushort outval = USHRT_MAX;

  if ((x < N) & (l < L)) {
    outval = cost[(l*N + x)*M + y];

    const int dx = l % sL;
    const int dy = l / sL;

    if ((inc < 0) ? (y < M-1) : y > 0) {
      outval += sgm_step(refImg, minvals, out, tmp_in, M, N, L, sL,
          x, y, dx, dy, x, inc, P1, P2);
    }

    write_imageui(tmp_out, (int4)(x, dx, dy, 0), outval);

    // Write result
#ifdef USE_ATOMIC_ADD
    atomicAddShort(&(out[(x*M + y)*L + l]), outval);
#else
    out[(x*M + y)*L + l] += outval;
#endif
  }

  // Partial reduction for next pass; see sgm_slice_horz for why this is
  // outside the range check.
  partial_reduce(outval, x, l, N, L, ps, scratch, sdata);
}

// Sequential minimum of the L values starting at `in`.
// TODO(Rene): Parallel reduction in shared memory
ushort min_reduce_loc(const __global ushort *in, __private int L) {
  ushort mm = USHRT_MAX;

  #pragma unroll
  for (int l = 0; l < L; ++l)
    mm = min(in[l], mm);

  return mm;
}

// Final reduction for column x: out[x*M + y] = min(in[y*L .. y*L + L - 1]).
// One work-item per row y.
// TODO(Rene): These are redundant ->fuse
__kernel void min_reduce_horz_fin(const __global ushort *in,
    const __private int M, const __private int N, const __private int L,
    const __private int x, __global ushort *out) {
  const int y = get_global_id(0);

  if (y < M) {
    out[x*M + y] = min_reduce_loc(in+y*L, L);
  }
}

// Final reduction for row y: out[x*M + y] = min(in[x*L .. x*L + L - 1]).
// One work-item per column x.
__kernel void min_reduce_vert_fin(const __global ushort *in,
    const __private int M, const __private int N, const __private int L,
    const __private int y, __global ushort *out) {
  const int x = get_global_id(0);

  if (x < N) {
    out[x*M + y] = min_reduce_loc(in+x*L, L);
  }
}

// Builds both unary cost volumes from two FEAT_DIM-channel feature maps.
//
// Features are addressed as feat[y + M*x + M*N*c] (channel c). For pixel
// (x, y) of the first map and every displacement (j, i) in
// [-max_offset, max_offset]^2 whose target (x+j, y+i) lies inside the image,
// the cost 127*(1 - <feat1(x,y), feat2(x+j,y+i)>) is written
//   - to unary1 at pixel (x, y), label cnt (running index over (j, i)),
//   - to unary2 at pixel (x+j, y+i), label of the opposite displacement.
// Entries whose target falls outside the image are left untouched.
__kernel void match(__global float *feat1, __global float *feat2,
    __private int M, __private int N, __private int max_offset,
    __global uchar *unary1, __global uchar *unary2) {
  const int y = get_global_id(0);
  const int x = get_global_id(1);

  if ((x < N) & (y < M)) {

    const int slice_stride = M*N;

    const int label_stride = 2*max_offset + 1;
    const int base = y + M*x;

    float4 feat1_s[FEAT_DIM/4];

    // Cache left feature map
    #pragma unroll FEAT_DIM/4
    for (int l = 0; l < FEAT_DIM; l+=4) {
      // Weird, but seems to be slightly faster than direct addressing
      int addr = base + M*N*l;
      feat1_s[l >> 2] = (float4)(feat1[addr],
                                 feat1[addr+=slice_stride],
                                 feat1[addr+=slice_stride],
                                 feat1[addr+=slice_stride]);
    }

    for (int j = -max_offset, cnt = 0; j <= max_offset; ++j) {
      const int x2 = x + j;
      const int lin_idx_x = (max_offset - j)*label_stride + max_offset;

      // cnt advances for every (j, i), including skipped ones
      for (int i = -max_offset; i <= max_offset; ++i, ++cnt) {
        const int y2 = y + i;

        // TODO(Rene): We have thread divergence here, can we optimize this?
        if ((x2 < 0) | (x2 >= N) | (y2 < 0) | (y2 >= M)) {
          continue;
        }

        const int lin_idx2 = lin_idx_x - i;
        const int base2 = y2 + M*x2;

        // Compute dot product with prefetching
        float accum = 1.0f;

        #pragma unroll FEAT_DIM/4
        for (int l = 0; l < FEAT_DIM; l+=4) {
          int addr = base2 + M*N*l;
          float4 ff2 = (float4)(feat2[addr],
                                feat2[addr+=slice_stride],
                                feat2[addr+=slice_stride],
                                feat2[addr+=slice_stride]);

          accum -= dot(feat1_s[l >> 2], ff2);
        }

        accum *= 127.0f;

        unary1[base + slice_stride*cnt] = accum;
        unary2[base2 + slice_stride*lin_idx2] = accum;
      }
    }
  }
}

// Winner-takes-all: for every pixel picks the label with the smallest
// aggregated cost in cvol[(x*M + y)*L + l] and converts it to a displacement.
__kernel void recover_flow(const __global ushort *cvol,
    __private int M, __private int N,
    __private int L, __global short2 *flow,
    __private int n_range, __private int max_offset) {
  const int y = get_global_id(0);
  const int x = get_global_id(1);

  if ((y < M) & (x < N)) {
    ushort minval = USHRT_MAX;
    int minlabel = 0;

    int base = x*M + y;
    cvol += base*L;

    #pragma unroll
    for (int l = 0; l < L; ++l) {
      ushort cval = cvol[l];
      if (cval < minval) {
        minval = cval;
minlabel = l; 251 | } 252 | } 253 | 254 | flow[base] = (short2)(minlabel/n_range, minlabel % n_range) - (short2)(max_offset); 255 | } 256 | } 257 | 258 | 259 | __kernel void fix_border(__global uchar *unary1, __global uchar *unary2, 260 | __private int M, __private int N, __private int L) { 261 | const int y = get_global_id(0); 262 | const int x = get_global_id(1); 263 | 264 | // Assume fixed window size for now 265 | // left border 266 | if (x < 4) { 267 | for (int l = 0; l < L; ++l) { 268 | unary1[y + M*(x + N*l)] = unary1[y + M*(7 - x + N*l)]; 269 | unary2[y + M*(x + N*l)] = unary2[y + M*(7 - x + N*l)]; 270 | } 271 | } 272 | 273 | // right border 274 | if ((x >= N-4) & (x < N)) { 275 | for (int l = 0; l < L; ++l) { 276 | unary1[y + M*(x + N*l)] = unary1[y + M*(2*N-9 - x + N*l)]; 277 | unary2[y + M*(x + N*l)] = unary2[y + M*(2*N-9 - x + N*l)]; 278 | } 279 | } 280 | 281 | // upper border 282 | if (y < 4) { 283 | for (int l = 0; l < L; ++l) { 284 | unary1[y + M*(x + N*l)] = unary1[7 - y + M*(x + N*l)]; 285 | unary2[y + M*(x + N*l)] = unary2[7 - y + M*(x + N*l)]; 286 | } 287 | } 288 | 289 | // lower border 290 | if ((y >= M-4) & (y < M)) { 291 | for (int l = 0; l < L; ++l) { 292 | unary1[y + M*(x + N*l)] = unary1[2*M - 9 - y + M*(x + N*l)]; 293 | unary2[y + M*(x + N*l)] = unary2[2*M - 9 - y + M*(x + N*l)]; 294 | } 295 | } 296 | } 297 | -------------------------------------------------------------------------------- /src_cl/sgm.cpp: -------------------------------------------------------------------------------- 1 | #include "sgm.h" 2 | #include "helper.h" 3 | #include 4 | 5 | SGM::SGM(int M, int N, int L, int max_offset, int P1, int P2, 6 | cl::CommandQueue &queue, cl::Context &context, cl::Program &program) : 7 | M_(M), N_(N), L_(L), max_offset_(max_offset), P1_(P1), P2_(P2), 8 | sL_(2*max_offset+1), ps_(L + (4 - (L % 4))), 9 | queue_(queue), context_(context), program_(program) { 10 | 11 | ngroups_horz_ = DIVUP(ps_,HORZ_DBS)/HORZ_DBS; 12 | ngroups_vert_ = 
DIVUP(ps_,VERT_DBS)/VERT_DBS; 13 | 14 | // Set up output 15 | out_ = new cl::Buffer(context, CL_MEM_READ_WRITE, L*M*N*sizeof(uint16_t)); 16 | flow_ = new cl::Buffer(context, CL_MEM_READ_WRITE, M*N*sizeof(cl_short2)); 17 | 18 | // Set up all temporary buffers 19 | for (int i = 0; i < 4; ++i) { 20 | min_bufs_[i] = new cl::Buffer(context, CL_MEM_READ_WRITE, M*N*sizeof(uint16_t)); 21 | 22 | tmp_bufs_h_[i] = new cl::Image3D(context, CL_MEM_READ_WRITE, 23 | cl::ImageFormat(CL_R, CL_UNSIGNED_INT16), M, sL_, sL_); 24 | tmp_bufs_v_[i] = new cl::Image3D(context, CL_MEM_READ_WRITE, 25 | cl::ImageFormat(CL_R, CL_UNSIGNED_INT16), N, sL_, sL_); 26 | } 27 | 28 | for (int i = 0; i < 2; ++i) { 29 | scratch_h_[i] = new cl::Buffer(context, CL_MEM_READ_WRITE, M*ngroups_horz_*sizeof(uint16_t)); 30 | scratch_v_[i] = new cl::Buffer(context, CL_MEM_READ_WRITE, N*ngroups_vert_*sizeof(uint16_t)); 31 | } 32 | 33 | setupKernels(); 34 | } 35 | 36 | void SGM::process(cl::Buffer &unary, cl::Buffer &im) { 37 | CL_CHECK_ERR_R(queue_.enqueueFillBuffer(*out_, uint16_t(0), 0, L_*M_*N_*sizeof(uint16_t))); 38 | 39 | sgm_horz_kernel_->setArg(0, unary); 40 | sgm_horz_kernel_->setArg(1, im); 41 | 42 | sgm_vert_kernel_->setArg(0, unary); 43 | sgm_vert_kernel_->setArg(1, im); 44 | CL_CHECK_ERR_R(queue_.finish()); 45 | 46 | // Process all 4 canonical directions 47 | for (int x = 0, y = 0; (x < N_) | (y < M_); ++x, ++y) { 48 | int xr = N_-1-x; 49 | int yr = M_-1-y; 50 | 51 | if (x < N_) 52 | processHorz(x, 0); 53 | 54 | if (xr >= 0) 55 | processHorz(xr, 1); 56 | 57 | if (y < M_) 58 | processVert(y, 0); 59 | 60 | if (yr >= 0) 61 | processVert(yr, 1); 62 | } 63 | } 64 | 65 | void SGM::processHorz(int x, int reverse) { 66 | auto *tmp_from = tmp_bufs_h_[x % 2 ? reverse : 2 + reverse]; 67 | auto *tmp_to = tmp_bufs_h_[x % 2 ? 2 + reverse : reverse]; 68 | int inc = reverse ? -M_ : M_; 69 | 70 | std::vector *glob_sync = reverse ? &sync_lr_bw_ : &sync_lr_fw_; 71 | std::vector *preq_sync = reverse ? 
&preq_lr_bw_ : &preq_lr_fw_; 72 | 73 | preq_sync->push_back(cl::Event()); 74 | 75 | // SGM pass 76 | sgm_horz_kernel_->setArg(2, *min_bufs_[reverse]); 77 | sgm_horz_kernel_->setArg(11, *tmp_from); 78 | sgm_horz_kernel_->setArg(12, *tmp_to); 79 | sgm_horz_kernel_->setArg(13, x); 80 | sgm_horz_kernel_->setArg(14, inc); 81 | sgm_horz_kernel_->setArg(15, *scratch_h_[reverse]); 82 | sgm_horz_kernel_->setArg(16, cl::Local(M_*HORZ_DBS*sizeof(ushort))); 83 | 84 | CL_CHECK_ERR_R(queue_.enqueueNDRangeKernel(*sgm_horz_kernel_, 85 | cl::NullRange, 86 | cl::NDRange(DIVUP(M_,HORZ_BS_X), DIVUP(L_,HORZ_DBS)), 87 | cl::NDRange(HORZ_BS_X,HORZ_DBS), glob_sync, &(preq_sync->back()))); 88 | 89 | // Final reduction pass 90 | red_horz_kernel_->setArg(0, *scratch_h_[reverse]); 91 | red_horz_kernel_->setArg(3, ngroups_horz_); 92 | red_horz_kernel_->setArg(4, x); 93 | red_horz_kernel_->setArg(5, *min_bufs_[reverse]); 94 | 95 | glob_sync->push_back(cl::Event()); 96 | 97 | CL_CHECK_ERR_R(queue_.enqueueNDRangeKernel(*red_horz_kernel_, cl::NullRange, 98 | cl::NDRange(DIVUP(M_, HORZ_DBS)), 99 | cl::NDRange(HORZ_DBS), preq_sync, &(glob_sync->back()))); 100 | } 101 | 102 | void SGM::processVert(int y, int reverse) { 103 | auto *tmp_from = tmp_bufs_v_[y % 2 ? reverse : 2 + reverse]; 104 | auto *tmp_to = tmp_bufs_v_[y % 2 ? 2 + reverse : reverse]; 105 | int inc = reverse ? -1 : 1; 106 | 107 | std::vector *glob_sync = reverse ? &sync_tb_bw_ : &sync_tb_fw_; 108 | std::vector *preq_sync = reverse ? 
&preq_tb_bw_ : &preq_tb_fw_; 109 | 110 | preq_sync->push_back(cl::Event()); 111 | 112 | // SGM pass 113 | sgm_vert_kernel_->setArg(2, *min_bufs_[2+reverse]); 114 | sgm_vert_kernel_->setArg(11, *tmp_from); 115 | sgm_vert_kernel_->setArg(12, *tmp_to); 116 | sgm_vert_kernel_->setArg(13, y); 117 | sgm_vert_kernel_->setArg(14, inc); 118 | sgm_vert_kernel_->setArg(15, *scratch_v_[reverse]); 119 | sgm_vert_kernel_->setArg(16, cl::Local(N_*VERT_DBS*sizeof(ushort))); 120 | 121 | CL_CHECK_ERR_R(queue_.enqueueNDRangeKernel(*sgm_vert_kernel_, 122 | cl::NullRange, 123 | cl::NDRange(DIVUP(N_,VERT_BS_X), DIVUP(L_,VERT_DBS)), 124 | cl::NDRange(VERT_BS_X,VERT_DBS), glob_sync, &(preq_sync->back()))); 125 | 126 | // Final reduction pass 127 | red_vert_kernel_->setArg(0, *scratch_v_[reverse]); 128 | red_vert_kernel_->setArg(3, ngroups_vert_); 129 | red_vert_kernel_->setArg(4, y); 130 | red_vert_kernel_->setArg(5, *min_bufs_[2+reverse]); 131 | 132 | glob_sync->push_back(cl::Event()); 133 | 134 | CL_CHECK_ERR_R(queue_.enqueueNDRangeKernel(*red_vert_kernel_, cl::NullRange, 135 | cl::NDRange(DIVUP(N_, VERT_DBS)), 136 | cl::NDRange(VERT_DBS), preq_sync, &(glob_sync->back()))); 137 | } 138 | 139 | cl::Buffer *SGM::recoverFlow() { 140 | CL_CHECK_ERR(queue_.finish()); 141 | 142 | // This is slow -> Implement proper reduction! 
143 | CL_CHECK_ERR(queue_.enqueueNDRangeKernel(*recover_kernel_, 144 | cl::NullRange, 145 | cl::NDRange(M_, N_), 146 | cl::NullRange, NULL, NULL)); 147 | 148 | CL_CHECK_ERR(queue_.finish()); 149 | 150 | return flow_; 151 | } 152 | 153 | void SGM::setupKernels() { 154 | // Setup constant arguments of kernels 155 | sgm_horz_kernel_ = new cl::Kernel(program_, "sgm_slice_horz"); 156 | sgm_horz_kernel_->setArg(3, M_); 157 | sgm_horz_kernel_->setArg(4, N_); 158 | sgm_horz_kernel_->setArg(5, L_); 159 | sgm_horz_kernel_->setArg(6, sL_); 160 | sgm_horz_kernel_->setArg(7, ps_); 161 | sgm_horz_kernel_->setArg(8, P1_); 162 | sgm_horz_kernel_->setArg(9, P2_); 163 | sgm_horz_kernel_->setArg(10, *out_); 164 | 165 | sgm_vert_kernel_ = new cl::Kernel(program_, "sgm_slice_vert"); 166 | sgm_vert_kernel_->setArg(3, M_); 167 | sgm_vert_kernel_->setArg(4, N_); 168 | sgm_vert_kernel_->setArg(5, L_); 169 | sgm_vert_kernel_->setArg(6, sL_); 170 | sgm_vert_kernel_->setArg(7, ps_); 171 | sgm_vert_kernel_->setArg(8, P1_); 172 | sgm_vert_kernel_->setArg(9, P2_); 173 | sgm_vert_kernel_->setArg(10, *out_); 174 | 175 | red_horz_kernel_ = new cl::Kernel(program_, "min_reduce_horz_fin"); 176 | red_horz_kernel_->setArg(1, M_); 177 | red_horz_kernel_->setArg(2, N_); 178 | red_horz_kernel_->setArg(3, L_); 179 | 180 | red_vert_kernel_ = new cl::Kernel(program_, "min_reduce_vert_fin"); 181 | red_vert_kernel_->setArg(1, M_); 182 | red_vert_kernel_->setArg(2, N_); 183 | red_vert_kernel_->setArg(3, L_); 184 | 185 | recover_kernel_ = new cl::Kernel(program_, "recover_flow"); 186 | recover_kernel_->setArg(0, *out_); 187 | recover_kernel_->setArg(1, M_); 188 | recover_kernel_->setArg(2, N_); 189 | recover_kernel_->setArg(3, L_); 190 | recover_kernel_->setArg(4, *flow_); 191 | recover_kernel_->setArg(5, sL_); 192 | recover_kernel_->setArg(6, max_offset_); 193 | } 194 | 195 | SGM::~SGM() { 196 | CL_CHECK_ERR(queue_.finish()); 197 | 198 | for (int i = 0; i < 4; ++i) { 199 | delete min_bufs_[i]; 200 | delete 
tmp_bufs_h_[i]; 201 | delete tmp_bufs_v_[i]; 202 | } 203 | 204 | for (int i = 0; i < 2; ++i) { 205 | delete scratch_h_[i]; 206 | delete scratch_v_[i]; 207 | } 208 | 209 | delete out_; 210 | delete flow_; 211 | 212 | delete sgm_horz_kernel_; 213 | delete sgm_vert_kernel_; 214 | delete red_horz_kernel_; 215 | delete red_vert_kernel_; 216 | delete recover_kernel_; 217 | }; 218 | -------------------------------------------------------------------------------- /src_cl/sgm.h: -------------------------------------------------------------------------------- 1 | #ifndef SGM_H 2 | #include 3 | 4 | #define DIVUP(S, A) ((S + (A) - 1)/(A))*(A) 5 | 6 | #define HORZ_DBS 32 7 | #define VERT_DBS 32 8 | 9 | #define HORZ_BS_X 8 10 | #define VERT_BS_X 8 11 | 12 | 13 | class SGM 14 | { 15 | public: 16 | SGM(int M, int N, int L, int max_offset, int P1, int P2, 17 | cl::CommandQueue &queue, cl::Context &context, cl::Program &program); 18 | 19 | void process(cl::Buffer &unary, cl::Buffer &im); 20 | 21 | cl::Buffer *recoverFlow(); 22 | 23 | cl::Buffer *getCostVolume() { 24 | queue_.finish(); 25 | return out_; 26 | } 27 | 28 | ~SGM(); 29 | 30 | private: 31 | void setupKernels(); 32 | 33 | void processHorz(int x, int reverse); 34 | void processVert(int y, int reverse); 35 | 36 | int M_, N_, L_, max_offset_, P1_, P2_, sL_, ps_, ngroups_horz_, ngroups_vert_; 37 | 38 | cl::CommandQueue &queue_; 39 | cl::Context &context_; 40 | cl::Program &program_; 41 | 42 | cl::Buffer *min_bufs_[4]; 43 | 44 | cl::Image3D *tmp_bufs_h_[4]; 45 | cl::Image3D *tmp_bufs_v_[4]; 46 | 47 | cl::Buffer *scratch_h_[2]; 48 | cl::Buffer *scratch_v_[2]; 49 | 50 | cl::Buffer *out_; 51 | cl::Buffer *flow_; 52 | 53 | cl::Kernel *sgm_horz_kernel_, *sgm_vert_kernel_; 54 | cl::Kernel *red_horz_kernel_, *red_vert_kernel_; 55 | cl::Kernel *recover_kernel_; 56 | 57 | std::vector sync_lr_fw_; 58 | std::vector sync_lr_bw_; 59 | std::vector sync_tb_fw_; 60 | std::vector sync_tb_bw_; 61 | 62 | std::vector preq_lr_fw_; 63 | std::vector 
preq_lr_bw_; 64 | std::vector preq_tb_fw_; 65 | std::vector preq_tb_bw_; 66 | }; 67 | 68 | #endif // SGM_H 69 | --------------------------------------------------------------------------------