├── .gitignore
├── LICENSE
├── PWC_src
│   ├── __init__.py
│   ├── correlation_package
│   │   ├── __init__.py
│   │   ├── correlation.py
│   │   ├── correlation_cuda.cc
│   │   ├── correlation_cuda.egg-info
│   │   │   ├── PKG-INFO
│   │   │   ├── SOURCES.txt
│   │   │   ├── dependency_links.txt
│   │   │   └── top_level.txt
│   │   ├── correlation_cuda_kernel.cu
│   │   ├── correlation_cuda_kernel.cuh
│   │   └── setup.py
│   ├── flowlib.py
│   └── pwc.py
├── README.md
├── demo.py
├── example
│   ├── 0img0.ppm
│   ├── 0img1.ppm
│   ├── 1img0.ppm
│   ├── 1img1.ppm
│   └── flow0.png
├── misc
│   └── demo.png
├── models
│   ├── chairs-things.pytorch
│   └── sintel.pytorch
└── requirements.txt
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*build/
*dist/
*pyc
*__pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                   GNU LESSER GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.


  This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.

  0. Additional Definitions.

  As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.

  "The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.

  An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.

  A "Combined Work" is a work produced by combining or linking an
Application with the Library.  The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".

  The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.

  The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.

  1. Exception to Section 3 of the GNU GPL.

  You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.

  2. Conveying Modified Versions.

  If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:

   a) under this License, provided that you make a good faith effort to
   ensure that, in the event an Application does not supply the
   function or data, the facility still operates, and performs
   whatever part of its purpose remains meaningful, or

   b) under the GNU GPL, with none of the additional permissions of
   this License applicable to that copy.

  3. Object Code Incorporating Material from Library Header Files.
  The object code form of an Application may incorporate material from
a header file that is part of the Library.  You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:

   a) Give prominent notice with each copy of the object code that the
   Library is used in it and that the Library and its use are
   covered by this License.

   b) Accompany the object code with a copy of the GNU GPL and this license
   document.

  4. Combined Works.

  You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:

   a) Give prominent notice with each copy of the Combined Work that
   the Library is used in it and that the Library and its use are
   covered by this License.

   b) Accompany the Combined Work with a copy of the GNU GPL and this license
   document.

   c) For a Combined Work that displays copyright notices during
   execution, include the copyright notice for the Library among
   these notices, as well as a reference directing the user to the
   copies of the GNU GPL and this license document.
   d) Do one of the following:

       0) Convey the Minimal Corresponding Source under the terms of this
       License, and the Corresponding Application Code in a form
       suitable for, and under terms that permit, the user to
       recombine or relink the Application with a modified version of
       the Linked Version to produce a modified Combined Work, in the
       manner specified by section 6 of the GNU GPL for conveying
       Corresponding Source.

       1) Use a suitable shared library mechanism for linking with the
       Library.  A suitable mechanism is one that (a) uses at run time
       a copy of the Library already present on the user's computer
       system, and (b) will operate properly with a modified version
       of the Library that is interface-compatible with the Linked
       Version.

   e) Provide Installation Information, but only if you would otherwise
   be required to provide such information under section 6 of the
   GNU GPL, and only to the extent that such information is
   necessary to install and execute a modified version of the
   Combined Work produced by recombining or relinking the
   Application with a modified version of the Linked Version.  (If
   you use option 4d0, the Installation Information must accompany
   the Minimal Corresponding Source and Corresponding Application
   Code.  If you use option 4d1, you must provide the Installation
   Information in the manner specified by section 6 of the GNU GPL
   for conveying Corresponding Source.)

  5. Combined Libraries.
  You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:

   a) Accompany the combined library with a copy of the same work based
   on the Library, uncombined with any other library facilities,
   conveyed under the terms of this License.

   b) Give prominent notice with the combined library that part of it
   is a work based on the Library, and explaining where to find the
   accompanying uncombined form of the same work.

  6. Revised Versions of the GNU Lesser General Public License.

  The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time.  Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.

  Each version is given a distinguishing version number.  If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation.  If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
  If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.
--------------------------------------------------------------------------------
/PWC_src/__init__.py:
--------------------------------------------------------------------------------
from .pwc import PWC_Net
from .flowlib import *
--------------------------------------------------------------------------------
/PWC_src/correlation_package/__init__.py:
--------------------------------------------------------------------------------
from .correlation import Correlation
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation.py:
--------------------------------------------------------------------------------
import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
import correlation_cuda

class CorrelationFunction(Function):

    def __init__(self, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1):
        super(CorrelationFunction, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply
        # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)

    def forward(self, input1, input2):
        self.save_for_backward(input1, input2)
        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()

            correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
                self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)

        return output

    def backward(self, grad_output):
        input1, input2 = self.saved_tensors

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()

            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)

        return grad_input1, grad_input2


class Correlation(Module):
    def __init__(self, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1):
        super(Correlation, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

    def forward(self, input1, input2):

        result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)(input1, input2)

        return result
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda.cc:
--------------------------------------------------------------------------------
// NOTE: the header names in the six includes below were stripped during
// extraction; they are reconstructed here from the symbols this file uses
// (at::Tensor, at::cuda::getCurrentCUDAStream, AT_ERROR, pybind11 module).
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <stdio.h>
#include <iostream>

#include "correlation_cuda_kernel.cuh"

int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
                       int pad_size,
                       int kernel_size,
                       int max_displacement,
                       int stride1,
                       int stride2,
                       int corr_type_multiply)
{

  int batchSize = input1.size(0);

  int nInputChannels = input1.size(1);
  int inputHeight = input1.size(2);
  int inputWidth = input1.size(3);

  int kernel_radius = (kernel_size - 1) / 2;
  int border_radius = kernel_radius + max_displacement;

  int paddedInputHeight = inputHeight + 2 * pad_size;
  int paddedInputWidth = inputWidth + 2 * pad_size;

  int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);

  // NOTE: the static_cast template arguments were stripped during extraction;
  // <float> is reconstructed here.
  int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
  int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));

  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});

  rInput1.fill_(0);
  rInput2.fill_(0);
  output.fill_(0);

  int success = correlation_forward_cuda_kernel(
    output,
    output.size(0),
    output.size(1),
    output.size(2),
    output.size(3),
    output.stride(0),
    output.stride(1),
    output.stride(2),
    output.stride(3),
    input1,
    input1.size(1),
    input1.size(2),
    input1.size(3),
    input1.stride(0),
    input1.stride(1),
    input1.stride(2),
    input1.stride(3),
    input2,
    input2.size(1),
    input2.stride(0),
    input2.stride(1),
    input2.stride(2),
    input2.stride(3),
    rInput1,
    rInput2,
    pad_size,
    kernel_size,
    max_displacement,
    stride1,
    stride2,
    corr_type_multiply,
    at::cuda::getCurrentCUDAStream()
  );
  // Original: at::globalContext().getCurrentCUDAStream()
  // For CUDA-10.0, we need at::cuda::getCurrentCUDAStream()

  // check for errors
  if (!success) {
    AT_ERROR("CUDA call failed");
  }

  return 1;

}

int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
                       at::Tensor& gradInput1, at::Tensor& gradInput2,
                       int pad_size,
                       int kernel_size,
                       int max_displacement,
                       int stride1,
                       int stride2,
                       int corr_type_multiply)
{

  int batchSize = input1.size(0);
  int nInputChannels = input1.size(1);
  int paddedInputHeight = input1.size(2) + 2 * pad_size;
  int paddedInputWidth = input1.size(3) + 2 * pad_size;

  int height = input1.size(2);
  int width = input1.size(3);

  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
  gradInput1.resize_({batchSize, nInputChannels, height, width});
  gradInput2.resize_({batchSize, nInputChannels, height, width});

  rInput1.fill_(0);
  rInput2.fill_(0);
  gradInput1.fill_(0);
  gradInput2.fill_(0);

  int success = correlation_backward_cuda_kernel(gradOutput,
    gradOutput.size(0),
    gradOutput.size(1),
    gradOutput.size(2),
    gradOutput.size(3),
    gradOutput.stride(0),
    gradOutput.stride(1),
    gradOutput.stride(2),
    gradOutput.stride(3),
    input1,
    input1.size(1),
    input1.size(2),
    input1.size(3),
    input1.stride(0),
    input1.stride(1),
    input1.stride(2),
    input1.stride(3),
    input2,
    input2.stride(0),
    input2.stride(1),
    input2.stride(2),
    input2.stride(3),
    gradInput1,
    gradInput1.stride(0),
    gradInput1.stride(1),
    gradInput1.stride(2),
    gradInput1.stride(3),
    gradInput2,
    gradInput2.size(1),
    gradInput2.stride(0),
    gradInput2.stride(1),
    gradInput2.stride(2),
    gradInput2.stride(3),
    rInput1,
    rInput2,
    pad_size,
    kernel_size,
    max_displacement,
    stride1,
    stride2,
    corr_type_multiply,
    at::cuda::getCurrentCUDAStream()
  );

  if (!success) {
    AT_ERROR("CUDA call failed");
  }

  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
  m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
}
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
Metadata-Version: 1.0
Name: correlation-cuda
Version: 0.0.0
Summary: UNKNOWN
Home-page: UNKNOWN
Author: UNKNOWN
Author-email: UNKNOWN
License: UNKNOWN
Description: UNKNOWN
Platform: UNKNOWN
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
correlation_cuda.cc
correlation_cuda_kernel.cu
setup.py
correlation_cuda.egg-info/PKG-INFO
correlation_cuda.egg-info/SOURCES.txt
correlation_cuda.egg-info/dependency_links.txt
correlation_cuda.egg-info/top_level.txt
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda.egg-info/top_level.txt:
--------------------------------------------------------------------------------
correlation_cuda
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda_kernel.cu:
--------------------------------------------------------------------------------
#include <stdio.h>

#include "correlation_cuda_kernel.cuh"

#define CUDA_NUM_THREADS 1024
#define THREADS_PER_BLOCK 32

// NOTE: the header names in the four ATen includes below were stripped during
// extraction; they are reconstructed from the symbols this file uses
// (AT_DISPATCH_FLOATING_TYPES_AND_HALF, at::Half, tensor accessors).
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

using at::Half;

// NOTE: template parameter lists (<typename scalar_t>) in this file were
// stripped during extraction and are reconstructed here.
template <typename scalar_t>
__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
{

  // n (batch size), c (num of channels), y (height), x (width)
  int n = blockIdx.x;
  int y = blockIdx.y;
  int x = blockIdx.z;

  int ch_off = threadIdx.x;
  scalar_t value;

  int dimcyx = channels * height * width;
  int dimyx = height * width;

  int p_dimx = (width + 2 * pad_size);
  int p_dimy = (height + 2 * pad_size);
  int p_dimyxc = channels * p_dimy * p_dimx;
  int p_dimxc = p_dimx * channels;

  for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
    value = input[n * dimcyx + c * dimyx + y * width + x];
    rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
  }
}

template <typename scalar_t>
__global__ void correlation_forward(scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth,
                const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth,
                const scalar_t* __restrict__ rInput2,
                int pad_size,
                int kernel_size,
                int max_displacement,
                int stride1,
                int stride2)
{
  // n (batch size), c (num of channels), y (height), x (width)

  int pInputWidth = inputWidth + 2 * pad_size;
  int pInputHeight = inputHeight + 2 * pad_size;

  int kernel_rad = (kernel_size - 1) / 2;
  int displacement_rad = max_displacement / stride2;
  int displacement_size = 2 * displacement_rad + 1;

  int n = blockIdx.x;
  int y1 = blockIdx.y * stride1 + max_displacement;
  int x1 = blockIdx.z * stride1 + max_displacement;
  int c = threadIdx.x;

  int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
  int pdimxc = pInputWidth * nInputChannels;
  int pdimc = nInputChannels;

  int tdimcyx = nOutputChannels * outputHeight * outputWidth;
  int tdimyx = outputHeight * outputWidth;
  int tdimx = outputWidth;

  scalar_t nelems = kernel_size * kernel_size * pdimc;

  __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];

  // no significant speed-up in using chip memory for input1 sub-data,
  // not enough chip memory size to accomodate memory per block for input2 sub-data
  // instead i've used device memory for both

  // element-wise product along channel axis
  for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
    for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
      prod_sum[c] = 0;
      int x2 = x1 + ti * stride2;
      int y2 = y1 + tj * stride2;

      for (int j = -kernel_rad; j <= kernel_rad; ++j) {
        for (int i = -kernel_rad; i <= kernel_rad; ++i) {
          for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {
            int indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch;
            int indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch;

            prod_sum[c] += rInput1[indx1] * rInput2[indx2];
          }
        }
      }

      // accumulate
      __syncthreads();
      if (c == 0) {
        scalar_t reduce_sum = 0;
        for (int index = 0; index < THREADS_PER_BLOCK; ++index) {
          reduce_sum += prod_sum[index];
        }
        int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
        const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;
        output[tindx] = reduce_sum / nelems;
      }

    }
  }

}

template <typename scalar_t>
__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,
                const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
                const scalar_t* __restrict__ rInput2,
                int pad_size,
                int kernel_size,
                int max_displacement,
                int stride1,
                int stride2)
{
  // n (batch size), c (num of channels), y (height), x (width)

  int n = item;
  int y = blockIdx.x * stride1 + pad_size;
  int x = blockIdx.y * stride1 + pad_size;
  int c = blockIdx.z;
  int tch_off = threadIdx.x;

  int kernel_rad = (kernel_size - 1) / 2;
  int displacement_rad = max_displacement / stride2;
  int displacement_size = 2 * displacement_rad + 1;

  int xmin = (x - kernel_rad - max_displacement) / stride1;
  int ymin = (y - kernel_rad - max_displacement) / stride1;

  int xmax = (x + kernel_rad - max_displacement) / stride1;
  int ymax = (y + kernel_rad - max_displacement) / stride1;

  if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
    // assumes gradInput1 is pre-allocated and zero filled
    return;
  }

  if (xmin > xmax || ymin > ymax) {
    // assumes gradInput1 is pre-allocated and zero filled
    return;
  }

  xmin = max(0, xmin);
  xmax = min(outputWidth - 1, xmax);

  ymin = max(0, ymin);
  ymax = min(outputHeight - 1, ymax);

  int pInputWidth = inputWidth + 2 * pad_size;
  int pInputHeight = inputHeight + 2 * pad_size;

  int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
  int pdimxc = pInputWidth * nInputChannels;
  int pdimc = nInputChannels;

  int tdimcyx = nOutputChannels * outputHeight * outputWidth;
  int tdimyx = outputHeight * outputWidth;
  int tdimx = outputWidth;

  int odimcyx = nInputChannels * inputHeight * inputWidth;
  int odimyx = inputHeight * inputWidth;
  int odimx = inputWidth;

  scalar_t nelems = kernel_size * kernel_size * nInputChannels;

  __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
  prod_sum[tch_off] = 0;
179 | 180 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 181 | 182 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 183 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 184 | 185 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; 186 | 187 | scalar_t val2 = rInput2[indx2]; 188 | 189 | for (int j = ymin; j <= ymax; ++j) { 190 | for (int i = xmin; i <= xmax; ++i) { 191 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 192 | prod_sum[tch_off] += gradOutput[tindx] * val2; 193 | } 194 | } 195 | } 196 | __syncthreads(); 197 | 198 | if(tch_off == 0) { 199 | scalar_t reduce_sum = 0; 200 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 201 | reduce_sum += prod_sum[idx]; 202 | } 203 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 204 | gradInput1[indx1] = reduce_sum / nelems; 205 | } 206 | 207 | } 208 | 209 | template 210 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, 211 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 212 | const scalar_t* __restrict__ rInput1, 213 | int pad_size, 214 | int kernel_size, 215 | int max_displacement, 216 | int stride1, 217 | int stride2) 218 | { 219 | // n (batch size), c (num of channels), y (height), x (width) 220 | 221 | int n = item; 222 | int y = blockIdx.x * stride1 + pad_size; 223 | int x = blockIdx.y * stride1 + pad_size; 224 | int c = blockIdx.z; 225 | 226 | int tch_off = threadIdx.x; 227 | 228 | int kernel_rad = (kernel_size - 1) / 2; 229 | int displacement_rad = max_displacement / stride2; 230 | int displacement_size = 2 * displacement_rad + 1; 231 | 232 | int pInputWidth = inputWidth + 2 * pad_size; 233 | int pInputHeight = inputHeight + 2 * pad_size; 234 | 235 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 236 | int pdimxc = pInputWidth * nInputChannels; 
237 | int pdimc = nInputChannels; 238 | 239 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 240 | int tdimyx = outputHeight * outputWidth; 241 | int tdimx = outputWidth; 242 | 243 | int odimcyx = nInputChannels * inputHeight* inputWidth; 244 | int odimyx = inputHeight * inputWidth; 245 | int odimx = inputWidth; 246 | 247 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 248 | 249 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 250 | prod_sum[tch_off] = 0; 251 | 252 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 253 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 254 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 255 | 256 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1; 257 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1; 258 | 259 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1; 260 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1; 261 | 262 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 263 | // assumes gradInput2 is pre-allocated and zero filled 264 | continue; 265 | } 266 | 267 | if (xmin > xmax || ymin > ymax) { 268 | // assumes gradInput2 is pre-allocated and zero filled 269 | continue; 270 | } 271 | 272 | xmin = max(0,xmin); 273 | xmax = min(outputWidth-1,xmax); 274 | 275 | ymin = max(0,ymin); 276 | ymax = min(outputHeight-1,ymax); 277 | 278 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; 279 | scalar_t val1 = rInput1[indx1]; 280 | 281 | for (int j = ymin; j <= ymax; ++j) { 282 | for (int i = xmin; i <= xmax; ++i) { 283 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 284 | prod_sum[tch_off] += gradOutput[tindx] * val1; 285 | } 286 | } 287 | } 288 | 289 | __syncthreads(); 290 | 291 | if(tch_off == 0) { 292 | scalar_t reduce_sum = 0; 293 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 294 | reduce_sum += prod_sum[idx]; 295 | } 296 | const 
int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 297 | gradInput2[indx2] = reduce_sum / nelems; 298 | } 299 | 300 | } 301 | 302 | int correlation_forward_cuda_kernel(at::Tensor& output, 303 | int ob, 304 | int oc, 305 | int oh, 306 | int ow, 307 | int osb, 308 | int osc, 309 | int osh, 310 | int osw, 311 | 312 | at::Tensor& input1, 313 | int ic, 314 | int ih, 315 | int iw, 316 | int isb, 317 | int isc, 318 | int ish, 319 | int isw, 320 | 321 | at::Tensor& input2, 322 | int gc, 323 | int gsb, 324 | int gsc, 325 | int gsh, 326 | int gsw, 327 | 328 | at::Tensor& rInput1, 329 | at::Tensor& rInput2, 330 | int pad_size, 331 | int kernel_size, 332 | int max_displacement, 333 | int stride1, 334 | int stride2, 335 | int corr_type_multiply, 336 | cudaStream_t stream) 337 | { 338 | 339 | int batchSize = ob; 340 | 341 | int nInputChannels = ic; 342 | int inputWidth = iw; 343 | int inputHeight = ih; 344 | 345 | int nOutputChannels = oc; 346 | int outputWidth = ow; 347 | int outputHeight = oh; 348 | 349 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 350 | dim3 threads_block(THREADS_PER_BLOCK); 351 | 352 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { 353 | 354 | channels_first<<>>( 355 | input1.data(), rInput1.data(), nInputChannels, inputHeight, inputWidth, pad_size); 356 | 357 | })); 358 | 359 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { 360 | 361 | channels_first<<>> ( 362 | input2.data(), rInput2.data(), nInputChannels, inputHeight, inputWidth, pad_size); 363 | 364 | })); 365 | 366 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 367 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); 368 | 369 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { 370 | 371 | correlation_forward<<>> 372 | (output.data(), nOutputChannels, outputHeight, outputWidth, 373 | rInput1.data(), nInputChannels, inputHeight, inputWidth, 374 | 
rInput2.data(), 375 | pad_size, 376 | kernel_size, 377 | max_displacement, 378 | stride1, 379 | stride2); 380 | 381 | })); 382 | 383 | cudaError_t err = cudaGetLastError(); 384 | 385 | 386 | // check for errors 387 | if (err != cudaSuccess) { 388 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); 389 | return 0; 390 | } 391 | 392 | return 1; 393 | } 394 | 395 | 396 | int correlation_backward_cuda_kernel( 397 | at::Tensor& gradOutput, 398 | int gob, 399 | int goc, 400 | int goh, 401 | int gow, 402 | int gosb, 403 | int gosc, 404 | int gosh, 405 | int gosw, 406 | 407 | at::Tensor& input1, 408 | int ic, 409 | int ih, 410 | int iw, 411 | int isb, 412 | int isc, 413 | int ish, 414 | int isw, 415 | 416 | at::Tensor& input2, 417 | int gsb, 418 | int gsc, 419 | int gsh, 420 | int gsw, 421 | 422 | at::Tensor& gradInput1, 423 | int gisb, 424 | int gisc, 425 | int gish, 426 | int gisw, 427 | 428 | at::Tensor& gradInput2, 429 | int ggc, 430 | int ggsb, 431 | int ggsc, 432 | int ggsh, 433 | int ggsw, 434 | 435 | at::Tensor& rInput1, 436 | at::Tensor& rInput2, 437 | int pad_size, 438 | int kernel_size, 439 | int max_displacement, 440 | int stride1, 441 | int stride2, 442 | int corr_type_multiply, 443 | cudaStream_t stream) 444 | { 445 | 446 | int batchSize = gob; 447 | int num = batchSize; 448 | 449 | int nInputChannels = ic; 450 | int inputWidth = iw; 451 | int inputHeight = ih; 452 | 453 | int nOutputChannels = goc; 454 | int outputWidth = gow; 455 | int outputHeight = goh; 456 | 457 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 458 | dim3 threads_block(THREADS_PER_BLOCK); 459 | 460 | 461 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { 462 | 463 | channels_first<<>>( 464 | input1.data(), 465 | rInput1.data(), 466 | nInputChannels, 467 | inputHeight, 468 | inputWidth, 469 | pad_size 470 | ); 471 | })); 472 | 473 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 474 | 475 
| channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>( 476 | input2.data<scalar_t>(), 477 | rInput2.data<scalar_t>(), 478 | nInputChannels, 479 | inputHeight, 480 | inputWidth, 481 | pad_size 482 | ); 483 | })); 484 | 485 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 486 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); 487 | 488 | for (int n = 0; n < num; ++n) { 489 | 490 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 491 | 492 | 493 | correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> ( 494 | n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, 495 | gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth, 496 | rInput2.data<scalar_t>(), 497 | pad_size, 498 | kernel_size, 499 | max_displacement, 500 | stride1, 501 | stride2); 502 | })); 503 | } 504 | 505 | for(int n = 0; n < batchSize; n++) { 506 | 507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { 508 | 509 | correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>( 510 | n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, 511 | gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth, 512 | rInput1.data<scalar_t>(), 513 | pad_size, 514 | kernel_size, 515 | max_displacement, 516 | stride1, 517 | stride2); 518 | 519 | })); 520 | } 521 | 522 | // check for errors 523 | cudaError_t err = cudaGetLastError(); 524 | if (err != cudaSuccess) { 525 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); 526 | return 0; 527 | } 528 | 529 | return 1; 530 | } 531 | -------------------------------------------------------------------------------- /PWC_src/correlation_package/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <ATen/ATen.h> 4 | #include <cuda.h> 5 | #include <cuda_runtime.h> 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int
isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /PWC_src/correlation_package/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from setuptools import setup, find_packages 5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 6 | 7 | cxx_args = ['-std=c++11'] 8 | 9 | nvcc_args = [ 10 | '-gencode', 'arch=compute_37,code=sm_37', 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | '-gencode', 'arch=compute_70,code=sm_70', 16 | '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='correlation_cuda', 21 | ext_modules=[ 
22 | CUDAExtension('correlation_cuda', [ 23 | 'correlation_cuda.cc', 24 | 'correlation_cuda_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /PWC_src/flowlib.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """ 3 | # ============================== 4 | # flowlib.py 5 | # library for optical flow processing 6 | # Author: Ruoteng Li 7 | # Date: 6th Aug 2016 8 | # ============================== 9 | """ 10 | import png 11 | import numpy as np 12 | import matplotlib.colors as cl 13 | import matplotlib.pyplot as plt 14 | from PIL import Image 15 | 16 | 17 | UNKNOWN_FLOW_THRESH = 1e7 18 | SMALLFLOW = 0.0 19 | LARGEFLOW = 1e8 20 | 21 | """ 22 | ============= 23 | Flow Section 24 | ============= 25 | """ 26 | 27 | 28 | def show_flow(filename): 29 | """ 30 | visualize optical flow map using matplotlib 31 | :param filename: optical flow file 32 | :return: None 33 | """ 34 | flow = read_flow(filename) 35 | img = flow_to_image(flow) 36 | plt.imshow(img) 37 | plt.show() 38 | 39 | 40 | def visualize_flow(flow, mode='Y'): 41 | """ 42 | this function visualize the input flow 43 | :param flow: input flow in array 44 | :param mode: choose which color mode to visualize the flow (Y: Ccbcr, RGB: RGB color) 45 | :return: None 46 | """ 47 | if mode == 'Y': 48 | # Ccbcr color wheel 49 | img = flow_to_image(flow) 50 | plt.imshow(img) 51 | plt.show() 52 | elif mode == 'RGB': 53 | (h, w) = flow.shape[0:2] 54 | du = flow[:, :, 0] 55 | dv = flow[:, :, 1] 56 | valid = flow[:, :, 2] 57 | max_flow = max(np.max(du), np.max(dv)) 58 | img = np.zeros((h, w, 3), dtype=np.float64) 59 | # angle layer 60 | img[:, :, 0] = np.arctan2(dv, du) / (2 * np.pi) 61 | # magnitude layer, normalized to 1 62 | img[:, :, 1] = np.sqrt(du * du + dv * dv) * 8 / max_flow 63 | # phase 
layer 64 | img[:, :, 2] = 8 - img[:, :, 1] 65 | # clip to [0,1] 66 | small_idx = img[:, :, 0:3] < 0 67 | large_idx = img[:, :, 0:3] > 1 68 | img[small_idx] = 0 69 | img[large_idx] = 1 70 | # convert to rgb 71 | img = cl.hsv_to_rgb(img) 72 | # remove invalid point 73 | img[:, :, 0] = img[:, :, 0] * valid 74 | img[:, :, 1] = img[:, :, 1] * valid 75 | img[:, :, 2] = img[:, :, 2] * valid 76 | # show 77 | plt.imshow(img) 78 | plt.show() 79 | 80 | return None 81 | 82 | 83 | def read_flow(filename): 84 | """ 85 | read optical flow from Middlebury .flo file 86 | :param filename: name of the flow file 87 | :return: optical flow data in matrix 88 | """ 89 | f = open(filename, 'rb') 90 | try: 91 | magic = np.fromfile(f, np.float32, count=1)[0] # For Python3.x 92 | except: 93 | magic = np.fromfile(f, np.float32, count=1) # For Python2.x 94 | data2d = None 95 | 96 | if 202021.25 != magic: 97 | print('Magic number incorrect. Invalid .flo file') 98 | else: 99 | w = np.fromfile(f, np.int32, count=1) 100 | h = np.fromfile(f, np.int32, count=1) 101 | #print("Reading %d x %d flo file" % (h, w)) 102 | data2d = np.fromfile(f, np.float32, count=2 * w[0] * h[0]) 103 | # reshape data into 3D array (columns, rows, channels) 104 | data2d = np.resize(data2d, (h[0], w[0], 2)) 105 | f.close() 106 | return data2d 107 | 108 | 109 | def read_flow_png(flow_file): 110 | """ 111 | Read optical flow from KITTI .png file 112 | :param flow_file: name of the flow file 113 | :return: optical flow data in matrix 114 | """ 115 | flow_object = png.Reader(filename=flow_file) 116 | flow_direct = flow_object.asDirect() 117 | flow_data = list(flow_direct[2]) 118 | (w, h) = flow_direct[3]['size'] 119 | flow = np.zeros((h, w, 3), dtype=np.float64) 120 | for i in range(len(flow_data)): 121 | flow[i, :, 0] = flow_data[i][0::3] 122 | flow[i, :, 1] = flow_data[i][1::3] 123 | flow[i, :, 2] = flow_data[i][2::3] 124 | 125 | invalid_idx = (flow[:, :, 2] == 0) 126 | flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0 127 
| flow[invalid_idx, 0] = 0 128 | flow[invalid_idx, 1] = 0 129 | return flow 130 | 131 | 132 | def write_flow(flow, filename): 133 | """ 134 | write optical flow in Middlebury .flo format 135 | :param flow: optical flow map 136 | :param filename: optical flow file path to be saved 137 | :return: None 138 | """ 139 | f = open(filename, 'wb') 140 | magic = np.array([202021.25], dtype=np.float32) 141 | (height, width) = flow.shape[0:2] 142 | w = np.array([width], dtype=np.int32) 143 | h = np.array([height], dtype=np.int32) 144 | magic.tofile(f) 145 | w.tofile(f) 146 | h.tofile(f) 147 | flow.tofile(f) 148 | f.close() 149 | 150 | 151 | def segment_flow(flow): 152 | h = flow.shape[0] 153 | w = flow.shape[1] 154 | u = flow[:, :, 0] 155 | v = flow[:, :, 1] 156 | 157 | idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW)) 158 | idx2 = (abs(u) == SMALLFLOW) 159 | class0 = (v == 0) & (u == 0) 160 | u[idx2] = 0.00001 161 | tan_value = v / u 162 | 163 | class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0) 164 | class2 = (tan_value >= 1) & (u >= 0) & (v >= 0) 165 | class3 = (tan_value < -1) & (u <= 0) & (v >= 0) 166 | class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0) 167 | class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0) 168 | class7 = (tan_value < -1) & (u >= 0) & (v <= 0) 169 | class6 = (tan_value >= 1) & (u <= 0) & (v <= 0) 170 | class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0) 171 | 172 | seg = np.zeros((h, w)) 173 | 174 | seg[class1] = 1 175 | seg[class2] = 2 176 | seg[class3] = 3 177 | seg[class4] = 4 178 | seg[class5] = 5 179 | seg[class6] = 6 180 | seg[class7] = 7 181 | seg[class8] = 8 182 | seg[class0] = 0 183 | seg[idx] = 0 184 | 185 | return seg 186 | 187 | 188 | def flow_error(tu, tv, u, v): 189 | """ 190 | Calculate average end point error 191 | :param tu: ground-truth horizontal flow map 192 | :param tv: ground-truth vertical flow map 193 | :param u: estimated horizontal flow map 194 | :param v: estimated 
vertical flow map 195 | :return: End point error of the estimated flow 196 | """ 197 | smallflow = 0.0 198 | ''' 199 | stu = tu[bord+1:end-bord,bord+1:end-bord] 200 | stv = tv[bord+1:end-bord,bord+1:end-bord] 201 | su = u[bord+1:end-bord,bord+1:end-bord] 202 | sv = v[bord+1:end-bord,bord+1:end-bord] 203 | ''' 204 | stu = tu[:] 205 | stv = tv[:] 206 | su = u[:] 207 | sv = v[:] 208 | 209 | idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH) 210 | stu[idxUnknow] = 0 211 | stv[idxUnknow] = 0 212 | su[idxUnknow] = 0 213 | sv[idxUnknow] = 0 214 | 215 | ind2 = [(np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)] 216 | index_su = su[ind2] 217 | index_sv = sv[ind2] 218 | an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1) 219 | un = index_su * an 220 | vn = index_sv * an 221 | 222 | index_stu = stu[ind2] 223 | index_stv = stv[ind2] 224 | tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1) 225 | tun = index_stu * tn 226 | tvn = index_stv * tn 227 | 228 | ''' 229 | angle = un * tun + vn * tvn + (an * tn) 230 | index = [angle == 1.0] 231 | angle[index] = 0.999 232 | ang = np.arccos(angle) 233 | mang = np.mean(ang) 234 | mang = mang * 180 / np.pi 235 | ''' 236 | 237 | epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2) 238 | epe = epe[ind2] 239 | mepe = np.mean(epe) 240 | return mepe 241 | 242 | 243 | def flow_to_image(flow, display=False): 244 | """ 245 | Convert flow into middlebury color code image 246 | :param flow: optical flow map 247 | :return: optical flow image in middlebury color 248 | """ 249 | u = flow[:, :, 0] 250 | v = flow[:, :, 1] 251 | 252 | maxu = -999. 253 | maxv = -999. 254 | minu = 999. 255 | minv = 999. 
256 | 257 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 258 | u[idxUnknow] = 0 259 | v[idxUnknow] = 0 260 | 261 | maxu = max(maxu, np.max(u)) 262 | minu = min(minu, np.min(u)) 263 | 264 | maxv = max(maxv, np.max(v)) 265 | minv = min(minv, np.min(v)) 266 | 267 | rad = np.sqrt(u ** 2 + v ** 2) 268 | maxrad = max(-1, np.max(rad)) 269 | 270 | if display: 271 | print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv)) 272 | 273 | u = u/(maxrad + np.finfo(float).eps) 274 | v = v/(maxrad + np.finfo(float).eps) 275 | 276 | img = compute_color(u, v) 277 | 278 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 279 | img[idx] = 0 280 | 281 | return np.uint8(img) 282 | 283 | 284 | def evaluate_flow_file(gt, pred): 285 | """ 286 | evaluate the estimated optical flow end point error according to ground truth provided 287 | :param gt: ground truth file path 288 | :param pred: estimated optical flow file path 289 | :return: end point error, float32 290 | """ 291 | # Read flow files and calculate the errors 292 | gt_flow = read_flow(gt) # ground truth flow 293 | eva_flow = read_flow(pred) # predicted flow 294 | # Calculate errors 295 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1]) 296 | return average_pe 297 | 298 | 299 | def evaluate_flow(gt_flow, pred_flow): 300 | """ 301 | gt: ground-truth flow 302 | pred: estimated flow 303 | """ 304 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], pred_flow[:, :, 0], pred_flow[:, :, 1]) 305 | return average_pe 306 | 307 | 308 | """ 309 | ============== 310 | Disparity Section 311 | ============== 312 | """ 313 | 314 | 315 | def read_disp_png(file_name): 316 | """ 317 | Read optical flow from KITTI .png file 318 | :param file_name: name of the flow file 319 | :return: optical flow data in matrix 320 | """ 321 | image_object = png.Reader(filename=file_name) 322 | image_direct = 
image_object.asDirect() 323 | image_data = list(image_direct[2]) 324 | (w, h) = image_direct[3]['size'] 325 | channel = len(image_data[0]) // w  # integer division: channel count must be an int under Python 3 326 | flow = np.zeros((h, w, channel), dtype=np.uint16) 327 | for i in range(len(image_data)): 328 | for j in range(channel): 329 | flow[i, :, j] = image_data[i][j::channel] 330 | return flow[:, :, 0] / 256 331 | 332 | 333 | def disp_to_flowfile(disp, filename): 334 | """ 335 | Write a KITTI disparity map into a Middlebury .flo flow file 336 | :param disp: disparity matrix 337 | :param filename: the flow file name to save 338 | :return: None 339 | """ 340 | f = open(filename, 'wb') 341 | magic = np.array([202021.25], dtype=np.float32) 342 | (height, width) = disp.shape[0:2] 343 | w = np.array([width], dtype=np.int32) 344 | h = np.array([height], dtype=np.int32) 345 | empty_map = np.zeros((height, width), dtype=np.float32) 346 | data = np.dstack((disp, empty_map)) 347 | magic.tofile(f) 348 | w.tofile(f) 349 | h.tofile(f) 350 | data.tofile(f) 351 | f.close() 352 | 353 | 354 | """ 355 | ============== 356 | Image Section 357 | ============== 358 | """ 359 | 360 | 361 | def read_image(filename): 362 | """ 363 | Read normal image of any format 364 | :param filename: name of the image file 365 | :return: image data in matrix uint8 type 366 | """ 367 | img = Image.open(filename) 368 | im = np.array(img) 369 | return im 370 | 371 | 372 | def warp_image(im, flow): 373 | """ 374 | Use optical flow to warp image to the next 375 | :param im: image to warp 376 | :param flow: optical flow 377 | :return: warped image 378 | """ 379 | from scipy import interpolate 380 | image_height = im.shape[0] 381 | image_width = im.shape[1] 382 | flow_height = flow.shape[0] 383 | flow_width = flow.shape[1] 384 | n = image_height * image_width 385 | (iy, ix) = np.mgrid[0:image_height, 0:image_width] 386 | (fy, fx) = np.mgrid[0:flow_height, 0:flow_width] 387 | fx = fx.astype(np.float64) + flow[:, :, 0]  # np.mgrid yields int arrays; in-place += with float flow fails 388 | fy = fy.astype(np.float64) + flow[:, :, 1] 389 | mask = np.logical_or(fx < 0, fx > flow_width) 390 | mask = 
np.logical_or(mask, fy < 0) 391 | mask = np.logical_or(mask, fy > flow_height) 392 | fx = np.minimum(np.maximum(fx, 0), flow_width) 393 | fy = np.minimum(np.maximum(fy, 0), flow_height) 394 | points = np.concatenate((ix.reshape(n,1), iy.reshape(n,1)), axis=1) 395 | xi = np.concatenate((fx.reshape(n, 1), fy.reshape(n,1)), axis=1) 396 | warp = np.zeros((image_height, image_width, im.shape[2])) 397 | for i in range(im.shape[2]): 398 | channel = im[:, :, i] 399 | # plt.imshow(channel, cmap='gray')  # leftover debug display; disabled to avoid opening a figure per channel 400 | values = channel.reshape(n, 1) 401 | new_channel = interpolate.griddata(points, values, xi, method='cubic') 402 | new_channel = np.reshape(new_channel, [flow_height, flow_width]) 403 | new_channel[mask] = 1 404 | warp[:, :, i] = new_channel.astype(np.uint8) 405 | 406 | return warp.astype(np.uint8) 407 | 408 | 409 | """ 410 | ============== 411 | Others 412 | ============== 413 | """ 414 | 415 | def scale_image(image, new_range): 416 | """ 417 | Linearly scale the image into desired range 418 | :param image: input image 419 | :param new_range: the new range to be aligned 420 | :return: image normalized in new range 421 | """ 422 | min_val = np.min(image).astype(np.float32) 423 | max_val = np.max(image).astype(np.float32) 424 | min_val_new = np.array(min(new_range), dtype=np.float32) 425 | max_val_new = np.array(max(new_range), dtype=np.float32) 426 | scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new 427 | return scaled_image.astype(np.uint8) 428 | 429 | 430 | def compute_color(u, v): 431 | """ 432 | compute optical flow color map 433 | :param u: optical flow horizontal map 434 | :param v: optical flow vertical map 435 | :return: optical flow in color code 436 | """ 437 | [h, w] = u.shape 438 | img = np.zeros([h, w, 3]) 439 | nanIdx = np.isnan(u) | np.isnan(v) 440 | u[nanIdx] = 0 441 | v[nanIdx] = 0 442 | 443 | colorwheel = make_color_wheel() 444 | ncols = np.size(colorwheel, 0) 445 | 446 | rad = np.sqrt(u**2+v**2) 447 | 448
| a = np.arctan2(-v, -u) / np.pi 449 | 450 | fk = (a+1) / 2 * (ncols - 1) + 1 451 | 452 | k0 = np.floor(fk).astype(int) 453 | 454 | k1 = k0 + 1 455 | k1[k1 == ncols+1] = 1 456 | f = fk - k0 457 | 458 | for i in range(0, np.size(colorwheel,1)): 459 | tmp = colorwheel[:, i] 460 | col0 = tmp[k0-1] / 255 461 | col1 = tmp[k1-1] / 255 462 | col = (1-f) * col0 + f * col1 463 | 464 | idx = rad <= 1 465 | col[idx] = 1-rad[idx]*(1-col[idx]) 466 | notidx = np.logical_not(idx) 467 | 468 | col[notidx] *= 0.75 469 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 470 | 471 | return img 472 | 473 | 474 | def make_color_wheel(): 475 | """ 476 | Generate color wheel according Middlebury color code 477 | :return: Color wheel 478 | """ 479 | RY = 15 480 | YG = 6 481 | GC = 4 482 | CB = 11 483 | BM = 13 484 | MR = 6 485 | 486 | ncols = RY + YG + GC + CB + BM + MR 487 | 488 | colorwheel = np.zeros([ncols, 3]) 489 | 490 | col = 0 491 | 492 | # RY 493 | colorwheel[0:RY, 0] = 255 494 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 495 | col += RY 496 | 497 | # YG 498 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 499 | colorwheel[col:col+YG, 1] = 255 500 | col += YG 501 | 502 | # GC 503 | colorwheel[col:col+GC, 1] = 255 504 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 505 | col += GC 506 | 507 | # CB 508 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 509 | colorwheel[col:col+CB, 2] = 255 510 | col += CB 511 | 512 | # BM 513 | colorwheel[col:col+BM, 2] = 255 514 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 515 | col += + BM 516 | 517 | # MR 518 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 519 | colorwheel[col:col+MR, 0] = 255 520 | 521 | return colorwheel 522 | -------------------------------------------------------------------------------- /PWC_src/pwc.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import getopt 4 | import math 5 | import numpy 6 | import os 7 | import PIL 8 | import PIL.Image 9 | import sys 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .correlation_package import correlation 15 | 16 | 17 | class Extractor(nn.Module): 18 | def __init__(self): 19 | super(Extractor, self).__init__() 20 | 21 | self.moduleOne = nn.Sequential( 22 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), 23 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 24 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 25 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 26 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 27 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 28 | ) 29 | 30 | self.moduleTwo = nn.Sequential( 31 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 32 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 33 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 34 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 35 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 36 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 37 | ) 38 | 39 | self.moduleThr = nn.Sequential( 40 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 41 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 42 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 43 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 44 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 45 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 46 | ) 47 | 48 | self.moduleFou = nn.Sequential( 49 | nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), 50 | 
nn.LeakyReLU(inplace=False, negative_slope=0.1), 51 | nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 52 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 53 | nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 54 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 55 | ) 56 | 57 | self.moduleFiv = nn.Sequential( 58 | nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), 59 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 60 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 61 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 62 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 63 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 64 | ) 65 | 66 | self.moduleSix = nn.Sequential( 67 | nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), 68 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 69 | nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 70 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 71 | nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 72 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 73 | ) 74 | 75 | def forward(self, tensorInput): 76 | tensorOne = self.moduleOne(tensorInput) 77 | tensorTwo = self.moduleTwo(tensorOne) 78 | tensorThr = self.moduleThr(tensorTwo) 79 | tensorFou = self.moduleFou(tensorThr) 80 | tensorFiv = self.moduleFiv(tensorFou) 81 | tensorSix = self.moduleSix(tensorFiv) 82 | 83 | return [ tensorOne, tensorTwo, tensorThr, tensorFou, tensorFiv, tensorSix ] 84 | 85 | 86 | class Backward(nn.Module): 87 | def __init__(self): 88 | super(Backward, self).__init__() 89 | 90 | def forward(self, tensorInput, tensorFlow): 91 | if hasattr(self, 'tensorPartial') == False or self.tensorPartial.size(0) != tensorFlow.size(0) or self.tensorPartial.size(2) != tensorFlow.size(2) or 
self.tensorPartial.size(3) != tensorFlow.size(3): 92 | self.tensorPartial = tensorFlow.new_ones(tensorFlow.size(0), 1, tensorFlow.size(2), tensorFlow.size(3)) 93 | 94 | if hasattr(self, 'tensorGrid') == False or self.tensorGrid.size(0) != tensorFlow.size(0) or self.tensorGrid.size(2) != tensorFlow.size(2) or self.tensorGrid.size(3) != tensorFlow.size(3): 95 | tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1) 96 | tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3)) 97 | self.tensorGrid = torch.cat([tensorHorizontal, tensorVertical], 1).cuda() 98 | 99 | tensorInput = torch.cat([tensorInput, self.tensorPartial], 1) 100 | tensorFlow = torch.cat([tensorFlow[:, 0:1, :, :]/((tensorInput.size(3)-1.0)/2.0), tensorFlow[:, 1:2, :, :]/((tensorInput.size(2)-1.0)/2.0)], 1) 101 | 102 | tensorOutput = F.grid_sample(input=tensorInput, grid=(self.tensorGrid + tensorFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros') 103 | tensorMask = tensorOutput[:, -1:, :, :]; tensorMask[tensorMask > 0.999] = 1.0; tensorMask[tensorMask < 1.0] = 0.0 104 | 105 | return tensorOutput[:, :-1, :, :] * tensorMask 106 | 107 | 108 | class Decoder(nn.Module): 109 | def __init__(self, intLevel): 110 | super(Decoder, self).__init__() 111 | 112 | intPrevious = [None, None, 81+32+2+2, 81+64+2+2, 81+96+2+2, 81+128+2+2, 81, None][intLevel+1] 113 | intCurrent = [None, None, 81+32+2+2, 81+64+2+2, 81+96+2+2, 81+128+2+2, 81, None][intLevel+0] 114 | 115 | if intLevel < 6: self.moduleUpflow = nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1) 116 | if intLevel < 6: self.moduleUpfeat = nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1) 117 | 118 | if intLevel < 6: self.dblBackward = [None, None, None, 
5.0, 2.5, 1.25, 0.625, None ][intLevel+1] 119 | if intLevel < 6: self.moduleBackward = Backward() 120 | 121 | self.moduleCorrelation = correlation.Correlation() 122 | self.moduleCorreleaky = nn.LeakyReLU(inplace=False, negative_slope=0.1) 123 | 124 | self.moduleOne = nn.Sequential( 125 | nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1), 126 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 127 | ) 128 | 129 | self.moduleTwo = nn.Sequential( 130 | nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1), 131 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 132 | ) 133 | 134 | self.moduleThr = nn.Sequential( 135 | nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1), 136 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 137 | ) 138 | 139 | self.moduleFou = nn.Sequential( 140 | nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1), 141 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 142 | ) 143 | 144 | self.moduleFiv = nn.Sequential( 145 | nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1), 146 | nn.LeakyReLU(inplace=False, negative_slope=0.1) 147 | ) 148 | 149 | self.moduleSix = nn.Sequential( 150 | nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1) 151 | ) 152 | 153 | def forward(self, tensorFirst, tensorSecond, objectPrevious): 154 | tensorFlow = None 155 | tensorFeat = None 156 | 157 | if objectPrevious is None: 158 | tensorFlow = None 159 | tensorFeat = None 160 | 161 | tensorVolume = self.moduleCorreleaky(self.moduleCorrelation(tensorFirst, tensorSecond)) 162 | tensorFeat = torch.cat([tensorVolume], 1) 163 | 164 | elif objectPrevious is not None: 165 | tensorFlow = self.moduleUpflow(objectPrevious['tensorFlow']) 166 | tensorFeat = 
self.moduleUpfeat(objectPrevious['tensorFeat']) 167 | tensorVolume = self.moduleCorreleaky(self.moduleCorrelation(tensorFirst, self.moduleBackward(tensorSecond, tensorFlow*self.dblBackward))) 168 | tensorFeat = torch.cat([tensorVolume, tensorFirst, tensorFlow, tensorFeat], 1) 169 | 170 | tensorFeat = torch.cat([self.moduleOne(tensorFeat), tensorFeat], 1) 171 | tensorFeat = torch.cat([self.moduleTwo(tensorFeat), tensorFeat], 1) 172 | tensorFeat = torch.cat([self.moduleThr(tensorFeat), tensorFeat], 1) 173 | tensorFeat = torch.cat([self.moduleFou(tensorFeat), tensorFeat], 1) 174 | tensorFeat = torch.cat([self.moduleFiv(tensorFeat), tensorFeat], 1) 175 | tensorFlow = self.moduleSix(tensorFeat) 176 | 177 | return { 178 | 'tensorFlow': tensorFlow, 179 | 'tensorFeat': tensorFeat 180 | } 181 | 182 | 183 | class Refiner(nn.Module): 184 | def __init__(self): 185 | super(Refiner, self).__init__() 186 | 187 | self.moduleMain = nn.Sequential( 188 | nn.Conv2d(in_channels=81+32+2+2+128+128+96+64+32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), 189 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 190 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), 191 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 192 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), 193 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 194 | nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), 195 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 196 | nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), 197 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 198 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), 199 | nn.LeakyReLU(inplace=False, negative_slope=0.1), 200 | nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) 201 | ) 
202 | 203 | def forward(self, tensorInput): 204 | return self.moduleMain(tensorInput) 205 | 206 | 207 | class PWC_Net(nn.Module): 208 | def __init__(self, model_path=None): 209 | super(PWC_Net, self).__init__() 210 | self.model_path = model_path 211 | 212 | self.moduleExtractor = Extractor() 213 | self.moduleTwo = Decoder(2) 214 | self.moduleThr = Decoder(3) 215 | self.moduleFou = Decoder(4) 216 | self.moduleFiv = Decoder(5) 217 | self.moduleSix = Decoder(6) 218 | self.moduleRefiner = Refiner() 219 | if self.model_path is not None: self.load_state_dict(torch.load(self.model_path))  # skip loading when no checkpoint is given 220 | 221 | def forward(self, tensorFirst, tensorSecond): 222 | tensorFirst = self.moduleExtractor(tensorFirst) 223 | tensorSecond = self.moduleExtractor(tensorSecond) 224 | 225 | objectEstimate = self.moduleSix(tensorFirst[-1], tensorSecond[-1], None) 226 | objectEstimate = self.moduleFiv(tensorFirst[-2], tensorSecond[-2], objectEstimate) 227 | objectEstimate = self.moduleFou(tensorFirst[-3], tensorSecond[-3], objectEstimate) 228 | objectEstimate = self.moduleThr(tensorFirst[-4], tensorSecond[-4], objectEstimate) 229 | objectEstimate = self.moduleTwo(tensorFirst[-5], tensorSecond[-5], objectEstimate) 230 | 231 | return objectEstimate['tensorFlow'] + self.moduleRefiner(objectEstimate['tensorFeat']) 232 | 233 | 234 | if __name__ == '__main__': 235 | net = PWC_Net(model_path='models/sintel.pytorch') 236 | net.cuda() 237 | 238 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## PWC-Net (PyTorch v1.0.1) 2 | 3 | PyTorch implementation of [PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume](https://arxiv.org/pdf/1709.02371.pdf). We provide it as an off-the-shelf package: 4 | - After installation, just copy the whole `PWC_src` folder into your codebase. See `demo.py` for details.
5 | 6 | ### Environment 7 | 8 | This code has been tested with Python 3.6 and PyTorch 1.0.1 on a Tesla K80 GPU, running Ubuntu 14.04 with CUDA 10.0. All required Python packages are listed in `requirements.txt`. 9 | 10 | ### Installation 11 | 12 | # install custom layers 13 | cd PWC_src/correlation_package 14 | python setup.py install 15 | 16 | Note: you might need to add a `gencode` entry [here](https://github.com/vt-vl-lab/pwc-net.pytorch/blob/master/PWC_src/correlation_package/setup.py#L9) matching the GPU you use. You can find more information about `gencode` [here](https://developer.nvidia.com/cuda-gpus) and [here](http://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/). 17 | 18 | ### Converted Caffe Pre-trained Models 19 | You can find them in the `models` folder. 20 | 21 | ### Inference mode 22 | Modify the path to your input, then 23 | 24 | ``` 25 | python demo.py 26 | ``` 27 | 28 | If installation is successful, you should see the following: 29 | ![PWC-Net Sample Prediction](https://github.com/vt-vl-lab/pwc-net.pytorch/blob/master/misc/demo.png) 30 | 31 | ### Reference 32 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper using: 33 | ```` 34 | @inproceedings{sun2018pwc, 35 | title={PWC-Net: CNNs for optical flow using pyramid, warping, and cost volume}, 36 | author={Sun, Deqing and Yang, Xiaodong and Liu, Ming-Yu and Kautz, Jan}, 37 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 38 | pages={8934--8943}, 39 | year={2018} 40 | } 41 | ```` 42 | 43 | ### Acknowledgments 44 | * [sniklaus/pytorch-pwc](https://github.com/sniklaus/pytorch-pwc): Network definition and converted PyTorch model weights. 45 | * [NVIDIA/flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch): Correlation module.
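To make the `gencode` note above concrete, here is a sketch of how `nvcc_args` in `PWC_src/correlation_package/setup.py` could be extended. The appended `compute_75/sm_75` pair is an assumption for a Turing-class card (e.g. an RTX 2080); look up the compute capability of your own GPU and substitute its values before building.

```python
# Sketch of nvcc_args from PWC_src/correlation_package/setup.py, with one
# hypothetical extra -gencode pair appended for a compute-capability-7.5 GPU.
nvcc_args = [
    '-gencode', 'arch=compute_37,code=sm_37',
    '-gencode', 'arch=compute_50,code=sm_50',
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_70,code=sm_70',
    '-gencode', 'arch=compute_70,code=compute_70',
    # added entry -- adjust compute_XX/sm_XX to your card's compute capability
    '-gencode', 'arch=compute_75,code=sm_75',
]
```

Each `-gencode` pair makes nvcc emit binary (or PTX, for `code=compute_XX`) for one architecture, so a missing entry typically shows up at runtime as a "no kernel image is available" error rather than at build time.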
46 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from PWC_src import PWC_Net 7 | from PWC_src import flow_to_image 8 | 9 | FLOW_SCALE = 20.0 10 | 11 | 12 | if __name__ == '__main__': 13 | # Prepare img pair (sizes need to be multiples of 64) 14 | im1 = cv2.imread('example/0img0.ppm') 15 | im2 = cv2.imread('example/0img1.ppm') 16 | im1 = torch.from_numpy((im1/255.).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) 17 | im2 = torch.from_numpy((im2/255.).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) 18 | im1_v = im1.cuda() 19 | im2_v = im2.cuda() 20 | 21 | # Build model 22 | pwc = PWC_Net(model_path='models/sintel.pytorch') 23 | #pwc = PWC_Net(model_path='models/chairs-things.pytorch') 24 | pwc = pwc.cuda() 25 | pwc.eval() 26 | 27 | import time 28 | start = time.time() 29 | flow = FLOW_SCALE*pwc(im1_v, im2_v) 30 | print(time.time()-start) 31 | flow = flow.data.cpu() 32 | flow = flow[0].numpy().transpose((1,2,0)) 33 | flow_im = flow_to_image(flow) 34 | 35 | # Visualization 36 | import matplotlib.pyplot as plt 37 | plt.imshow(flow_im) 38 | plt.show() 39 | 40 | -------------------------------------------------------------------------------- /example/0img0.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/0img0.ppm -------------------------------------------------------------------------------- /example/0img1.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/0img1.ppm -------------------------------------------------------------------------------- /example/1img0.ppm: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/1img0.ppm -------------------------------------------------------------------------------- /example/1img1.ppm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/1img1.ppm -------------------------------------------------------------------------------- /example/flow0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/flow0.png -------------------------------------------------------------------------------- /misc/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/misc/demo.png -------------------------------------------------------------------------------- /models/chairs-things.pytorch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/models/chairs-things.pytorch -------------------------------------------------------------------------------- /models/sintel.pytorch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/models/sintel.pytorch -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2019.3.9 2 | cycler==0.10.0 3 | joblib==0.11 4 | kiwisolver==1.1.0 5 | 
matplotlib==3.1.0 6 | numpy==1.16.4 7 | opencv-python==3.4.3.18 8 | Pillow==6.2.0 9 | pyparsing==2.4.0 10 | pypng==0.0.19 11 | python-dateutil==2.8.0 12 | six==1.12.0 13 | torch==1.0.1 14 | --------------------------------------------------------------------------------
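As the comment in `demo.py` notes, input height and width must be multiples of 64 (the feature pyramid downsamples six times). A small helper along these lines — hypothetical, not part of this repo — can zero-pad arbitrary images before they are fed to the network; the predicted flow can then be cropped back to the original size:

```python
import numpy as np

def pad_to_multiple(im, k=64):
    """Zero-pad an H x W x C image at the bottom/right so that H and W
    become multiples of k (64 for PWC-Net). Hypothetical helper."""
    h, w = im.shape[:2]
    pad_h = (k - h % k) % k  # 0 when h is already a multiple of k
    pad_w = (k - w % k) % k
    return np.pad(im, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')
```

For example, a 436 x 1024 Sintel frame would be padded to 448 x 1024 before the `torch.from_numpy(...)` conversion in the demo.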