├── .gitignore
├── LICENSE
├── PWC_src
│   ├── __init__.py
│   ├── correlation_package
│   │   ├── __init__.py
│   │   ├── correlation.py
│   │   ├── correlation_cuda.cc
│   │   ├── correlation_cuda.egg-info
│   │   │   ├── PKG-INFO
│   │   │   ├── SOURCES.txt
│   │   │   ├── dependency_links.txt
│   │   │   └── top_level.txt
│   │   ├── correlation_cuda_kernel.cu
│   │   ├── correlation_cuda_kernel.cuh
│   │   └── setup.py
│   ├── flowlib.py
│   └── pwc.py
├── README.md
├── demo.py
├── example
│   ├── 0img0.ppm
│   ├── 0img1.ppm
│   ├── 1img0.ppm
│   ├── 1img1.ppm
│   └── flow0.png
├── misc
│   └── demo.png
├── models
│   ├── chairs-things.pytorch
│   └── sintel.pytorch
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | *build/
2 | *dist/
3 | *pyc
4 | *__pycache__
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/PWC_src/__init__.py:
--------------------------------------------------------------------------------
1 | from .pwc import PWC_Net
2 | from .flowlib import *
3 |
--------------------------------------------------------------------------------
/PWC_src/correlation_package/__init__.py:
--------------------------------------------------------------------------------
1 | from .correlation import Correlation
2 |
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.module import Module
3 | from torch.autograd import Function
4 | import correlation_cuda
5 |
6 | class CorrelationFunction(Function):
7 |
8 | def __init__(self, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1):
9 | super(CorrelationFunction, self).__init__()
10 | self.pad_size = pad_size
11 | self.kernel_size = kernel_size
12 | self.max_displacement = max_displacement
13 | self.stride1 = stride1
14 | self.stride2 = stride2
15 | self.corr_multiply = corr_multiply
16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
17 |
18 | def forward(self, input1, input2):
19 | self.save_for_backward(input1, input2)
20 | with torch.cuda.device_of(input1):
21 | rbot1 = input1.new()
22 | rbot2 = input2.new()
23 | output = input1.new()
24 |
25 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
26 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
27 |
28 | return output
29 |
30 | def backward(self, grad_output):
31 | input1, input2 = self.saved_tensors
32 |
33 | with torch.cuda.device_of(input1):
34 | rbot1 = input1.new()
35 | rbot2 = input2.new()
36 |
37 | grad_input1 = input1.new()
38 | grad_input2 = input2.new()
39 |
40 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
41 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
42 |
43 | return grad_input1, grad_input2
44 |
45 |
46 | class Correlation(Module):
47 | def __init__(self, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1):
48 | super(Correlation, self).__init__()
49 | self.pad_size = pad_size
50 | self.kernel_size = kernel_size
51 | self.max_displacement = max_displacement
52 | self.stride1 = stride1
53 | self.stride2 = stride2
54 | self.corr_multiply = corr_multiply
55 |
56 | def forward(self, input1, input2):
57 |
58 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2)
59 |
60 | return result
61 |
62 |
--------------------------------------------------------------------------------
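A note on `correlation.py` above: it uses the legacy stateful `torch.autograd.Function` interface (hyper-parameters stored on `self`, the `Function` instantiated and called directly), which newer PyTorch releases reject in favor of static `forward`/`backward` methods taking a `ctx`. Independent of the framework version, the forward pass can be sanity-checked against a plain NumPy sketch. This is an illustration only: `correlation_ref` is a hypothetical helper name, `kernel_size` is fixed at 1 (the package default), and the loops are deliberately unoptimized.

```python
import numpy as np

def correlation_ref(a, b, pad_size=4, max_displacement=4, stride1=1, stride2=1):
    """Reference correlation volume for kernel_size=1 (mirrors the CUDA kernel)."""
    B, C, H, W = a.shape
    # Zero-pad both inputs, as channels_first does after its implicit padding.
    pa = np.pad(a, ((0, 0), (0, 0), (pad_size, pad_size), (pad_size, pad_size)))
    pb = np.pad(b, ((0, 0), (0, 0), (pad_size, pad_size), (pad_size, pad_size)))
    d = max_displacement // stride2          # displacement radius
    D = 2 * d + 1                            # displacements per axis
    out_h = int(np.ceil((H + 2 * pad_size - 2 * max_displacement) / stride1))
    out_w = int(np.ceil((W + 2 * pad_size - 2 * max_displacement) / stride1))
    out = np.zeros((B, D * D, out_h, out_w), dtype=a.dtype)
    for tj in range(-d, d + 1):
        for ti in range(-d, d + 1):
            tc = (tj + d) * D + (ti + d)     # output channel for this offset
            for oy in range(out_h):
                for ox in range(out_w):
                    y1 = oy * stride1 + max_displacement
                    x1 = ox * stride1 + max_displacement
                    y2, x2 = y1 + tj * stride2, x1 + ti * stride2
                    # Dot product along channels, normalized by C (nelems).
                    out[:, tc, oy, ox] = (pa[:, :, y1, x1] * pb[:, :, y2, x2]).mean(axis=1)
    return out
```

With two identical all-ones inputs, the zero-displacement channel (index 40 of 81 under the defaults) is 1.0 wherever both sampled positions fall inside the image.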
/PWC_src/correlation_package/correlation_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <ATen/ATen.h>
3 | #include <ATen/Context.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 | #include <stdio.h>
6 | #include <iostream>
7 |
8 | #include "correlation_cuda_kernel.cuh"
9 |
10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
11 | int pad_size,
12 | int kernel_size,
13 | int max_displacement,
14 | int stride1,
15 | int stride2,
16 | int corr_type_multiply)
17 | {
18 |
19 | int batchSize = input1.size(0);
20 |
21 | int nInputChannels = input1.size(1);
22 | int inputHeight = input1.size(2);
23 | int inputWidth = input1.size(3);
24 |
25 | int kernel_radius = (kernel_size - 1) / 2;
26 | int border_radius = kernel_radius + max_displacement;
27 |
28 | int paddedInputHeight = inputHeight + 2 * pad_size;
29 | int paddedInputWidth = inputWidth + 2 * pad_size;
30 |
31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
32 |
33 | int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
34 | int outputWidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
35 |
36 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputWidth});
39 |
40 | rInput1.fill_(0);
41 | rInput2.fill_(0);
42 | output.fill_(0);
43 | int success = correlation_forward_cuda_kernel(
44 | output,
45 | output.size(0),
46 | output.size(1),
47 | output.size(2),
48 | output.size(3),
49 | output.stride(0),
50 | output.stride(1),
51 | output.stride(2),
52 | output.stride(3),
53 | input1,
54 | input1.size(1),
55 | input1.size(2),
56 | input1.size(3),
57 | input1.stride(0),
58 | input1.stride(1),
59 | input1.stride(2),
60 | input1.stride(3),
61 | input2,
62 | input2.size(1),
63 | input2.stride(0),
64 | input2.stride(1),
65 | input2.stride(2),
66 | input2.stride(3),
67 | rInput1,
68 | rInput2,
69 | pad_size,
70 | kernel_size,
71 | max_displacement,
72 | stride1,
73 | stride2,
74 | corr_type_multiply,
75 | at::cuda::getCurrentCUDAStream()
76 | );
77 | // Original: at::globalContext().getCurrentCUDAStream()
78 | // For CUDA-10.0, we need at::cuda::getCurrentCUDAStream()
79 |
80 | //check for errors
81 | if (!success) {
82 | AT_ERROR("CUDA call failed");
83 | }
84 |
85 | return 1;
86 |
87 | }
88 |
89 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
90 | at::Tensor& gradInput1, at::Tensor& gradInput2,
91 | int pad_size,
92 | int kernel_size,
93 | int max_displacement,
94 | int stride1,
95 | int stride2,
96 | int corr_type_multiply)
97 | {
98 |
99 | int batchSize = input1.size(0);
100 | int nInputChannels = input1.size(1);
101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size;
102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size;
103 |
104 | int height = input1.size(2);
105 | int width = input1.size(3);
106 |
107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
109 | gradInput1.resize_({batchSize, nInputChannels, height, width});
110 | gradInput2.resize_({batchSize, nInputChannels, height, width});
111 |
112 | rInput1.fill_(0);
113 | rInput2.fill_(0);
114 | gradInput1.fill_(0);
115 | gradInput2.fill_(0);
116 |
117 | int success = correlation_backward_cuda_kernel(gradOutput,
118 | gradOutput.size(0),
119 | gradOutput.size(1),
120 | gradOutput.size(2),
121 | gradOutput.size(3),
122 | gradOutput.stride(0),
123 | gradOutput.stride(1),
124 | gradOutput.stride(2),
125 | gradOutput.stride(3),
126 | input1,
127 | input1.size(1),
128 | input1.size(2),
129 | input1.size(3),
130 | input1.stride(0),
131 | input1.stride(1),
132 | input1.stride(2),
133 | input1.stride(3),
134 | input2,
135 | input2.stride(0),
136 | input2.stride(1),
137 | input2.stride(2),
138 | input2.stride(3),
139 | gradInput1,
140 | gradInput1.stride(0),
141 | gradInput1.stride(1),
142 | gradInput1.stride(2),
143 | gradInput1.stride(3),
144 | gradInput2,
145 | gradInput2.size(1),
146 | gradInput2.stride(0),
147 | gradInput2.stride(1),
148 | gradInput2.stride(2),
149 | gradInput2.stride(3),
150 | rInput1,
151 | rInput2,
152 | pad_size,
153 | kernel_size,
154 | max_displacement,
155 | stride1,
156 | stride2,
157 | corr_type_multiply,
158 | at::cuda::getCurrentCUDAStream()
159 | );
160 |
161 | if (!success) {
162 | AT_ERROR("CUDA call failed");
163 | }
164 |
165 | return 1;
166 | }
167 |
168 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
169 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
170 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
171 | }
172 |
173 |
--------------------------------------------------------------------------------
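The `resize_` arithmetic in `correlation_forward_cuda` fixes the output shape: `(2*(max_displacement/stride2) + 1)^2` channels, with the spatial extent shrunk by twice the border radius and divided by `stride1`. A small helper (hypothetical name, mirroring the C++ above) makes the arithmetic easy to check:

```python
import math

def corr_output_shape(h, w, pad_size=4, kernel_size=1, max_displacement=4,
                      stride1=1, stride2=1):
    """Reproduce the output-shape computation of correlation_forward_cuda."""
    kernel_radius = (kernel_size - 1) // 2
    border_radius = kernel_radius + max_displacement
    # One output channel per (dy, dx) displacement pair.
    n_out = (2 * (max_displacement // stride2) + 1) ** 2
    out_h = math.ceil((h + 2 * pad_size - 2 * border_radius) / stride1)
    out_w = math.ceil((w + 2 * pad_size - 2 * border_radius) / stride1)
    return n_out, out_h, out_w
```

With the package defaults (`pad_size=4`, `max_displacement=4`, `kernel_size=1`, unit strides), the padding and the border exactly cancel, so an `H x W` feature map yields an `81 x H x W` cost volume.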
/PWC_src/correlation_package/correlation_cuda.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: correlation-cuda
3 | Version: 0.0.0
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | Author: UNKNOWN
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | correlation_cuda.cc
2 | correlation_cuda_kernel.cu
3 | setup.py
4 | correlation_cuda.egg-info/PKG-INFO
5 | correlation_cuda.egg-info/SOURCES.txt
6 | correlation_cuda.egg-info/dependency_links.txt
7 | correlation_cuda.egg-info/top_level.txt
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | correlation_cuda
2 |
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | #include "correlation_cuda_kernel.cuh"
4 |
5 | #define CUDA_NUM_THREADS 1024
6 | #define THREADS_PER_BLOCK 32
7 |
8 | #include <ATen/ATen.h>
9 | #include <ATen/NativeFunctions.h>
10 | #include <ATen/Dispatch.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 |
13 | using at::Half;
14 |
15 | template <typename scalar_t>
16 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
17 | {
18 |
19 | // n (batch size), c (num of channels), y (height), x (width)
20 | int n = blockIdx.x;
21 | int y = blockIdx.y;
22 | int x = blockIdx.z;
23 |
24 | int ch_off = threadIdx.x;
25 | scalar_t value;
26 |
27 | int dimcyx = channels * height * width;
28 | int dimyx = height * width;
29 |
30 | int p_dimx = (width + 2 * pad_size);
31 | int p_dimy = (height + 2 * pad_size);
32 | int p_dimyxc = channels * p_dimy * p_dimx;
33 | int p_dimxc = p_dimx * channels;
34 |
35 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
36 | value = input[n * dimcyx + c * dimyx + y * width + x];
37 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
38 | }
39 | }
40 |
41 | template <typename scalar_t>
42 | __global__ void correlation_forward( scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth,
43 | const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth,
44 | const scalar_t* __restrict__ rInput2,
45 | int pad_size,
46 | int kernel_size,
47 | int max_displacement,
48 | int stride1,
49 | int stride2)
50 | {
51 | // n (batch size), c (num of channels), y (height), x (width)
52 |
53 | int pInputWidth = inputWidth + 2 * pad_size;
54 | int pInputHeight = inputHeight + 2 * pad_size;
55 |
56 | int kernel_rad = (kernel_size - 1) / 2;
57 | int displacement_rad = max_displacement / stride2;
58 | int displacement_size = 2 * displacement_rad + 1;
59 |
60 | int n = blockIdx.x;
61 | int y1 = blockIdx.y * stride1 + max_displacement;
62 | int x1 = blockIdx.z * stride1 + max_displacement;
63 | int c = threadIdx.x;
64 |
65 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
66 | int pdimxc = pInputWidth * nInputChannels;
67 | int pdimc = nInputChannels;
68 |
69 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
70 | int tdimyx = outputHeight * outputWidth;
71 | int tdimx = outputWidth;
72 |
73 | scalar_t nelems = kernel_size * kernel_size * pdimc;
74 |
75 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
76 |
77 | // no significant speed-up from using on-chip (shared) memory for input1 sub-data,
78 | // and not enough on-chip memory to accommodate per-block storage for input2 sub-data;
79 | // instead I've used device memory for both
80 |
81 | // element-wise product along channel axis
82 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj ) {
83 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti ) {
84 | prod_sum[c] = 0;
85 | int x2 = x1 + ti*stride2;
86 | int y2 = y1 + tj*stride2;
87 |
88 | for (int j = -kernel_rad; j <= kernel_rad; ++j) {
89 | for (int i = -kernel_rad; i <= kernel_rad; ++i) {
90 | for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {
91 | int indx1 = n * pdimyxc + (y1+j) * pdimxc + (x1 + i) * pdimc + ch;
92 | int indx2 = n * pdimyxc + (y2+j) * pdimxc + (x2 + i) * pdimc + ch;
93 |
94 | prod_sum[c] += rInput1[indx1] * rInput2[indx2];
95 | }
96 | }
97 | }
98 |
99 | // accumulate
100 | __syncthreads();
101 | if (c == 0) {
102 | scalar_t reduce_sum = 0;
103 | for (int index = 0; index < THREADS_PER_BLOCK; ++index) {
104 | reduce_sum += prod_sum[index];
105 | }
106 | int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
107 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;
108 | output[tindx] = reduce_sum / nelems;
109 | }
110 |
111 | }
112 | }
113 |
114 | }
115 |
116 | template <typename scalar_t>
117 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,
118 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
119 | const scalar_t* __restrict__ rInput2,
120 | int pad_size,
121 | int kernel_size,
122 | int max_displacement,
123 | int stride1,
124 | int stride2)
125 | {
126 | // n (batch size), c (num of channels), y (height), x (width)
127 |
128 | int n = item;
129 | int y = blockIdx.x * stride1 + pad_size;
130 | int x = blockIdx.y * stride1 + pad_size;
131 | int c = blockIdx.z;
132 | int tch_off = threadIdx.x;
133 |
134 | int kernel_rad = (kernel_size - 1) / 2;
135 | int displacement_rad = max_displacement / stride2;
136 | int displacement_size = 2 * displacement_rad + 1;
137 |
138 | int xmin = (x - kernel_rad - max_displacement) / stride1;
139 | int ymin = (y - kernel_rad - max_displacement) / stride1;
140 |
141 | int xmax = (x + kernel_rad - max_displacement) / stride1;
142 | int ymax = (y + kernel_rad - max_displacement) / stride1;
143 |
144 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
145 | // assumes gradInput1 is pre-allocated and zero filled
146 | return;
147 | }
148 |
149 | if (xmin > xmax || ymin > ymax) {
150 | // assumes gradInput1 is pre-allocated and zero filled
151 | return;
152 | }
153 |
154 | xmin = max(0,xmin);
155 | xmax = min(outputWidth-1,xmax);
156 |
157 | ymin = max(0,ymin);
158 | ymax = min(outputHeight-1,ymax);
159 |
160 | int pInputWidth = inputWidth + 2 * pad_size;
161 | int pInputHeight = inputHeight + 2 * pad_size;
162 |
163 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
164 | int pdimxc = pInputWidth * nInputChannels;
165 | int pdimc = nInputChannels;
166 |
167 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
168 | int tdimyx = outputHeight * outputWidth;
169 | int tdimx = outputWidth;
170 |
171 | int odimcyx = nInputChannels * inputHeight* inputWidth;
172 | int odimyx = inputHeight * inputWidth;
173 | int odimx = inputWidth;
174 |
175 | scalar_t nelems = kernel_size * kernel_size * nInputChannels;
176 |
177 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
178 | prod_sum[tch_off] = 0;
179 |
180 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
181 |
182 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
183 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
184 |
185 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;
186 |
187 | scalar_t val2 = rInput2[indx2];
188 |
189 | for (int j = ymin; j <= ymax; ++j) {
190 | for (int i = xmin; i <= xmax; ++i) {
191 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
192 | prod_sum[tch_off] += gradOutput[tindx] * val2;
193 | }
194 | }
195 | }
196 | __syncthreads();
197 |
198 | if(tch_off == 0) {
199 | scalar_t reduce_sum = 0;
200 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
201 | reduce_sum += prod_sum[idx];
202 | }
203 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
204 | gradInput1[indx1] = reduce_sum / nelems;
205 | }
206 |
207 | }
208 |
209 | template <typename scalar_t>
210 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth,
211 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
212 | const scalar_t* __restrict__ rInput1,
213 | int pad_size,
214 | int kernel_size,
215 | int max_displacement,
216 | int stride1,
217 | int stride2)
218 | {
219 | // n (batch size), c (num of channels), y (height), x (width)
220 |
221 | int n = item;
222 | int y = blockIdx.x * stride1 + pad_size;
223 | int x = blockIdx.y * stride1 + pad_size;
224 | int c = blockIdx.z;
225 |
226 | int tch_off = threadIdx.x;
227 |
228 | int kernel_rad = (kernel_size - 1) / 2;
229 | int displacement_rad = max_displacement / stride2;
230 | int displacement_size = 2 * displacement_rad + 1;
231 |
232 | int pInputWidth = inputWidth + 2 * pad_size;
233 | int pInputHeight = inputHeight + 2 * pad_size;
234 |
235 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
236 | int pdimxc = pInputWidth * nInputChannels;
237 | int pdimc = nInputChannels;
238 |
239 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
240 | int tdimyx = outputHeight * outputWidth;
241 | int tdimx = outputWidth;
242 |
243 | int odimcyx = nInputChannels * inputHeight* inputWidth;
244 | int odimyx = inputHeight * inputWidth;
245 | int odimx = inputWidth;
246 |
247 | scalar_t nelems = kernel_size * kernel_size * nInputChannels;
248 |
249 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
250 | prod_sum[tch_off] = 0;
251 |
252 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
253 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
254 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
255 |
256 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
257 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1;
258 |
259 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
260 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1;
261 |
262 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
263 | // assumes gradInput2 is pre-allocated and zero filled
264 | continue;
265 | }
266 |
267 | if (xmin > xmax || ymin > ymax) {
268 | // assumes gradInput2 is pre-allocated and zero filled
269 | continue;
270 | }
271 |
272 | xmin = max(0,xmin);
273 | xmax = min(outputWidth-1,xmax);
274 |
275 | ymin = max(0,ymin);
276 | ymax = min(outputHeight-1,ymax);
277 |
278 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
279 | scalar_t val1 = rInput1[indx1];
280 |
281 | for (int j = ymin; j <= ymax; ++j) {
282 | for (int i = xmin; i <= xmax; ++i) {
283 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
284 | prod_sum[tch_off] += gradOutput[tindx] * val1;
285 | }
286 | }
287 | }
288 |
289 | __syncthreads();
290 |
291 | if(tch_off == 0) {
292 | scalar_t reduce_sum = 0;
293 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
294 | reduce_sum += prod_sum[idx];
295 | }
296 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
297 | gradInput2[indx2] = reduce_sum / nelems;
298 | }
299 |
300 | }
301 |
302 | int correlation_forward_cuda_kernel(at::Tensor& output,
303 | int ob,
304 | int oc,
305 | int oh,
306 | int ow,
307 | int osb,
308 | int osc,
309 | int osh,
310 | int osw,
311 |
312 | at::Tensor& input1,
313 | int ic,
314 | int ih,
315 | int iw,
316 | int isb,
317 | int isc,
318 | int ish,
319 | int isw,
320 |
321 | at::Tensor& input2,
322 | int gc,
323 | int gsb,
324 | int gsc,
325 | int gsh,
326 | int gsw,
327 |
328 | at::Tensor& rInput1,
329 | at::Tensor& rInput2,
330 | int pad_size,
331 | int kernel_size,
332 | int max_displacement,
333 | int stride1,
334 | int stride2,
335 | int corr_type_multiply,
336 | cudaStream_t stream)
337 | {
338 |
339 | int batchSize = ob;
340 |
341 | int nInputChannels = ic;
342 | int inputWidth = iw;
343 | int inputHeight = ih;
344 |
345 | int nOutputChannels = oc;
346 | int outputWidth = ow;
347 | int outputHeight = oh;
348 |
349 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
350 | dim3 threads_block(THREADS_PER_BLOCK);
351 |
352 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
353 |
354 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
355 | input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
356 |
357 | }));
358 |
359 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
360 |
361 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>> (
362 | input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
363 |
364 | }));
365 |
366 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
367 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
368 |
369 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
370 |
371 | correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>
372 | (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
373 | rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
374 | rInput2.data<scalar_t>(),
375 | pad_size,
376 | kernel_size,
377 | max_displacement,
378 | stride1,
379 | stride2);
380 |
381 | }));
382 |
383 | cudaError_t err = cudaGetLastError();
384 |
385 |
386 | // check for errors
387 | if (err != cudaSuccess) {
388 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
389 | return 0;
390 | }
391 |
392 | return 1;
393 | }
394 |
395 |
396 | int correlation_backward_cuda_kernel(
397 | at::Tensor& gradOutput,
398 | int gob,
399 | int goc,
400 | int goh,
401 | int gow,
402 | int gosb,
403 | int gosc,
404 | int gosh,
405 | int gosw,
406 |
407 | at::Tensor& input1,
408 | int ic,
409 | int ih,
410 | int iw,
411 | int isb,
412 | int isc,
413 | int ish,
414 | int isw,
415 |
416 | at::Tensor& input2,
417 | int gsb,
418 | int gsc,
419 | int gsh,
420 | int gsw,
421 |
422 | at::Tensor& gradInput1,
423 | int gisb,
424 | int gisc,
425 | int gish,
426 | int gisw,
427 |
428 | at::Tensor& gradInput2,
429 | int ggc,
430 | int ggsb,
431 | int ggsc,
432 | int ggsh,
433 | int ggsw,
434 |
435 | at::Tensor& rInput1,
436 | at::Tensor& rInput2,
437 | int pad_size,
438 | int kernel_size,
439 | int max_displacement,
440 | int stride1,
441 | int stride2,
442 | int corr_type_multiply,
443 | cudaStream_t stream)
444 | {
445 |
446 | int batchSize = gob;
447 | int num = batchSize;
448 |
449 | int nInputChannels = ic;
450 | int inputWidth = iw;
451 | int inputHeight = ih;
452 |
453 | int nOutputChannels = goc;
454 | int outputWidth = gow;
455 | int outputHeight = goh;
456 |
457 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
458 | dim3 threads_block(THREADS_PER_BLOCK);
459 |
460 |
461 |   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_bwd_1", ([&] {
462 |
463 |     channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
464 |         input1.data<scalar_t>(),
465 |         rInput1.data<scalar_t>(),
466 | nInputChannels,
467 | inputHeight,
468 | inputWidth,
469 | pad_size
470 | );
471 | }));
472 |
473 |   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_bwd_2", ([&] {
474 |
475 |     channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
476 |         input2.data<scalar_t>(),
477 |         rInput2.data<scalar_t>(),
478 | nInputChannels,
479 | inputHeight,
480 | inputWidth,
481 | pad_size
482 | );
483 | }));
484 |
485 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
486 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
487 |
488 | for (int n = 0; n < num; ++n) {
489 |
490 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "correlation_backward_input1", ([&] {
491 |
492 |
493 |       correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
494 |           n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
495 |           gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
496 |           rInput2.data<scalar_t>(),
497 | pad_size,
498 | kernel_size,
499 | max_displacement,
500 | stride1,
501 | stride2);
502 | }));
503 | }
504 |
505 | for(int n = 0; n < batchSize; n++) {
506 |
507 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "correlation_backward_input2", ([&] {
508 |
509 |       correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
510 |           n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
511 |           gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
512 |           rInput1.data<scalar_t>(),
513 | pad_size,
514 | kernel_size,
515 | max_displacement,
516 | stride1,
517 | stride2);
518 |
519 | }));
520 | }
521 |
522 | // check for errors
523 | cudaError_t err = cudaGetLastError();
524 | if (err != cudaSuccess) {
525 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
526 | return 0;
527 | }
528 |
529 | return 1;
530 | }
531 |
--------------------------------------------------------------------------------
/PWC_src/correlation_package/correlation_cuda_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 | #include <cuda.h>
5 | #include <cuda_runtime.h>
6 |
7 | int correlation_forward_cuda_kernel(at::Tensor& output,
8 | int ob,
9 | int oc,
10 | int oh,
11 | int ow,
12 | int osb,
13 | int osc,
14 | int osh,
15 | int osw,
16 |
17 | at::Tensor& input1,
18 | int ic,
19 | int ih,
20 | int iw,
21 | int isb,
22 | int isc,
23 | int ish,
24 | int isw,
25 |
26 | at::Tensor& input2,
27 | int gc,
28 | int gsb,
29 | int gsc,
30 | int gsh,
31 | int gsw,
32 |
33 | at::Tensor& rInput1,
34 | at::Tensor& rInput2,
35 | int pad_size,
36 | int kernel_size,
37 | int max_displacement,
38 | int stride1,
39 | int stride2,
40 | int corr_type_multiply,
41 | cudaStream_t stream);
42 |
43 |
44 | int correlation_backward_cuda_kernel(
45 | at::Tensor& gradOutput,
46 | int gob,
47 | int goc,
48 | int goh,
49 | int gow,
50 | int gosb,
51 | int gosc,
52 | int gosh,
53 | int gosw,
54 |
55 | at::Tensor& input1,
56 | int ic,
57 | int ih,
58 | int iw,
59 | int isb,
60 | int isc,
61 | int ish,
62 | int isw,
63 |
64 | at::Tensor& input2,
65 | int gsb,
66 | int gsc,
67 | int gsh,
68 | int gsw,
69 |
70 | at::Tensor& gradInput1,
71 | int gisb,
72 | int gisc,
73 | int gish,
74 | int gisw,
75 |
76 | at::Tensor& gradInput2,
77 | int ggc,
78 | int ggsb,
79 | int ggsc,
80 | int ggsh,
81 | int ggsw,
82 |
83 | at::Tensor& rInput1,
84 | at::Tensor& rInput2,
85 | int pad_size,
86 | int kernel_size,
87 | int max_displacement,
88 | int stride1,
89 | int stride2,
90 | int corr_type_multiply,
91 | cudaStream_t stream);
92 |
--------------------------------------------------------------------------------
/PWC_src/correlation_package/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from setuptools import setup, find_packages
5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
6 |
7 | cxx_args = ['-std=c++11']
8 |
9 | nvcc_args = [
10 | '-gencode', 'arch=compute_37,code=sm_37',
11 | '-gencode', 'arch=compute_50,code=sm_50',
12 | '-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | '-gencode', 'arch=compute_70,code=sm_70',
16 | '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 |
19 | setup(
20 | name='correlation_cuda',
21 | ext_modules=[
22 | CUDAExtension('correlation_cuda', [
23 | 'correlation_cuda.cc',
24 | 'correlation_cuda_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
--------------------------------------------------------------------------------
/PWC_src/flowlib.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | """
3 | # ==============================
4 | # flowlib.py
5 | # library for optical flow processing
6 | # Author: Ruoteng Li
7 | # Date: 6th Aug 2016
8 | # ==============================
9 | """
10 | import png
11 | import numpy as np
12 | import matplotlib.colors as cl
13 | import matplotlib.pyplot as plt
14 | from PIL import Image
15 |
16 |
17 | UNKNOWN_FLOW_THRESH = 1e7
18 | SMALLFLOW = 0.0
19 | LARGEFLOW = 1e8
20 |
21 | """
22 | =============
23 | Flow Section
24 | =============
25 | """
26 |
27 |
28 | def show_flow(filename):
29 | """
30 | visualize optical flow map using matplotlib
31 | :param filename: optical flow file
32 | :return: None
33 | """
34 | flow = read_flow(filename)
35 | img = flow_to_image(flow)
36 | plt.imshow(img)
37 | plt.show()
38 |
39 |
40 | def visualize_flow(flow, mode='Y'):
41 | """
42 |     visualize the input flow
43 |     :param flow: input flow in array
44 |     :param mode: color mode used to visualize the flow (Y: YCbCr, RGB: RGB color)
45 | :return: None
46 | """
47 | if mode == 'Y':
48 |         # YCbCr color wheel
49 | img = flow_to_image(flow)
50 | plt.imshow(img)
51 | plt.show()
52 | elif mode == 'RGB':
53 | (h, w) = flow.shape[0:2]
54 | du = flow[:, :, 0]
55 | dv = flow[:, :, 1]
56 | valid = flow[:, :, 2]
57 | max_flow = max(np.max(du), np.max(dv))
58 | img = np.zeros((h, w, 3), dtype=np.float64)
59 | # angle layer
60 | img[:, :, 0] = np.arctan2(dv, du) / (2 * np.pi)
61 | # magnitude layer, normalized to 1
62 | img[:, :, 1] = np.sqrt(du * du + dv * dv) * 8 / max_flow
63 | # phase layer
64 | img[:, :, 2] = 8 - img[:, :, 1]
65 | # clip to [0,1]
66 | small_idx = img[:, :, 0:3] < 0
67 | large_idx = img[:, :, 0:3] > 1
68 | img[small_idx] = 0
69 | img[large_idx] = 1
70 | # convert to rgb
71 | img = cl.hsv_to_rgb(img)
72 | # remove invalid point
73 | img[:, :, 0] = img[:, :, 0] * valid
74 | img[:, :, 1] = img[:, :, 1] * valid
75 | img[:, :, 2] = img[:, :, 2] * valid
76 | # show
77 | plt.imshow(img)
78 | plt.show()
79 |
80 | return None
81 |
82 |
83 | def read_flow(filename):
84 | """
85 | read optical flow from Middlebury .flo file
86 | :param filename: name of the flow file
87 | :return: optical flow data in matrix
88 | """
89 | f = open(filename, 'rb')
90 | try:
91 | magic = np.fromfile(f, np.float32, count=1)[0] # For Python3.x
92 | except:
93 | magic = np.fromfile(f, np.float32, count=1) # For Python2.x
94 | data2d = None
95 |
96 | if 202021.25 != magic:
97 | print('Magic number incorrect. Invalid .flo file')
98 | else:
99 | w = np.fromfile(f, np.int32, count=1)
100 | h = np.fromfile(f, np.int32, count=1)
101 | #print("Reading %d x %d flo file" % (h, w))
102 | data2d = np.fromfile(f, np.float32, count=2 * w[0] * h[0])
103 | # reshape data into 3D array (columns, rows, channels)
104 | data2d = np.resize(data2d, (h[0], w[0], 2))
105 | f.close()
106 | return data2d
107 |
108 |
109 | def read_flow_png(flow_file):
110 | """
111 | Read optical flow from KITTI .png file
112 | :param flow_file: name of the flow file
113 | :return: optical flow data in matrix
114 | """
115 | flow_object = png.Reader(filename=flow_file)
116 | flow_direct = flow_object.asDirect()
117 | flow_data = list(flow_direct[2])
118 | (w, h) = flow_direct[3]['size']
119 | flow = np.zeros((h, w, 3), dtype=np.float64)
120 | for i in range(len(flow_data)):
121 | flow[i, :, 0] = flow_data[i][0::3]
122 | flow[i, :, 1] = flow_data[i][1::3]
123 | flow[i, :, 2] = flow_data[i][2::3]
124 |
125 | invalid_idx = (flow[:, :, 2] == 0)
126 | flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
127 | flow[invalid_idx, 0] = 0
128 | flow[invalid_idx, 1] = 0
129 | return flow
130 |
131 |
132 | def write_flow(flow, filename):
133 | """
134 | write optical flow in Middlebury .flo format
135 | :param flow: optical flow map
136 | :param filename: optical flow file path to be saved
137 | :return: None
138 | """
139 | f = open(filename, 'wb')
140 | magic = np.array([202021.25], dtype=np.float32)
141 | (height, width) = flow.shape[0:2]
142 | w = np.array([width], dtype=np.int32)
143 | h = np.array([height], dtype=np.int32)
144 | magic.tofile(f)
145 | w.tofile(f)
146 | h.tofile(f)
147 | flow.tofile(f)
148 | f.close()
149 |
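The pair of functions above implements the Middlebury `.flo` layout: a float32 magic number 202021.25, then int32 width, int32 height, then row-major interleaved (u, v) float32 values. As a sanity check, the round trip can be sketched self-contained with plain NumPy (the `flo_roundtrip` helper and its file name are illustrative, not part of this module):

```python
import numpy as np

def flo_roundtrip(flow, path):
    """Write `flow` (H x W x 2 float32) in .flo format and read it back."""
    h, w = flow.shape[:2]
    with open(path, 'wb') as f:
        np.array([202021.25], dtype=np.float32).tofile(f)  # magic
        np.array([w], dtype=np.int32).tofile(f)            # width first
        np.array([h], dtype=np.int32).tofile(f)            # then height
        flow.astype(np.float32).tofile(f)                  # interleaved (u, v)
    with open(path, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)[0]
        assert magic == np.float32(202021.25), 'invalid .flo file'
        w2 = int(np.fromfile(f, np.int32, count=1)[0])
        h2 = int(np.fromfile(f, np.int32, count=1)[0])
        data = np.fromfile(f, np.float32, count=2 * w2 * h2)
    return data.reshape(h2, w2, 2)
```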
150 |
151 | def segment_flow(flow):
152 | h = flow.shape[0]
153 | w = flow.shape[1]
154 | u = flow[:, :, 0]
155 | v = flow[:, :, 1]
156 |
157 | idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW))
158 | idx2 = (abs(u) == SMALLFLOW)
159 | class0 = (v == 0) & (u == 0)
160 | u[idx2] = 0.00001
161 | tan_value = v / u
162 |
163 | class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0)
164 | class2 = (tan_value >= 1) & (u >= 0) & (v >= 0)
165 | class3 = (tan_value < -1) & (u <= 0) & (v >= 0)
166 | class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0)
167 | class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0)
168 | class7 = (tan_value < -1) & (u >= 0) & (v <= 0)
169 | class6 = (tan_value >= 1) & (u <= 0) & (v <= 0)
170 | class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0)
171 |
172 | seg = np.zeros((h, w))
173 |
174 | seg[class1] = 1
175 | seg[class2] = 2
176 | seg[class3] = 3
177 | seg[class4] = 4
178 | seg[class5] = 5
179 | seg[class6] = 6
180 | seg[class7] = 7
181 | seg[class8] = 8
182 | seg[class0] = 0
183 | seg[idx] = 0
184 |
185 | return seg
186 |
187 |
188 | def flow_error(tu, tv, u, v):
189 | """
190 | Calculate average end point error
191 | :param tu: ground-truth horizontal flow map
192 | :param tv: ground-truth vertical flow map
193 | :param u: estimated horizontal flow map
194 | :param v: estimated vertical flow map
195 | :return: End point error of the estimated flow
196 | """
197 | smallflow = 0.0
198 | '''
199 | stu = tu[bord+1:end-bord,bord+1:end-bord]
200 | stv = tv[bord+1:end-bord,bord+1:end-bord]
201 | su = u[bord+1:end-bord,bord+1:end-bord]
202 | sv = v[bord+1:end-bord,bord+1:end-bord]
203 | '''
204 | stu = tu[:]
205 | stv = tv[:]
206 | su = u[:]
207 | sv = v[:]
208 |
209 | idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH)
210 | stu[idxUnknow] = 0
211 | stv[idxUnknow] = 0
212 | su[idxUnknow] = 0
213 | sv[idxUnknow] = 0
214 |
215 |     ind2 = (np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)
216 | index_su = su[ind2]
217 | index_sv = sv[ind2]
218 | an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1)
219 | un = index_su * an
220 | vn = index_sv * an
221 |
222 | index_stu = stu[ind2]
223 | index_stv = stv[ind2]
224 | tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1)
225 | tun = index_stu * tn
226 | tvn = index_stv * tn
227 |
228 | '''
229 | angle = un * tun + vn * tvn + (an * tn)
230 | index = [angle == 1.0]
231 | angle[index] = 0.999
232 | ang = np.arccos(angle)
233 | mang = np.mean(ang)
234 | mang = mang * 180 / np.pi
235 | '''
236 |
237 | epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2)
238 | epe = epe[ind2]
239 | mepe = np.mean(epe)
240 | return mepe
241 |
242 |
243 | def flow_to_image(flow, display=False):
244 | """
245 | Convert flow into middlebury color code image
246 | :param flow: optical flow map
247 | :return: optical flow image in middlebury color
248 | """
249 | u = flow[:, :, 0]
250 | v = flow[:, :, 1]
251 |
252 | maxu = -999.
253 | maxv = -999.
254 | minu = 999.
255 | minv = 999.
256 |
257 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
258 | u[idxUnknow] = 0
259 | v[idxUnknow] = 0
260 |
261 | maxu = max(maxu, np.max(u))
262 | minu = min(minu, np.min(u))
263 |
264 | maxv = max(maxv, np.max(v))
265 | minv = min(minv, np.min(v))
266 |
267 | rad = np.sqrt(u ** 2 + v ** 2)
268 | maxrad = max(-1, np.max(rad))
269 |
270 | if display:
271 | print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv))
272 |
273 | u = u/(maxrad + np.finfo(float).eps)
274 | v = v/(maxrad + np.finfo(float).eps)
275 |
276 | img = compute_color(u, v)
277 |
278 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
279 | img[idx] = 0
280 |
281 | return np.uint8(img)
282 |
283 |
284 | def evaluate_flow_file(gt, pred):
285 | """
286 | evaluate the estimated optical flow end point error according to ground truth provided
287 | :param gt: ground truth file path
288 | :param pred: estimated optical flow file path
289 | :return: end point error, float32
290 | """
291 | # Read flow files and calculate the errors
292 | gt_flow = read_flow(gt) # ground truth flow
293 | eva_flow = read_flow(pred) # predicted flow
294 | # Calculate errors
295 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1])
296 | return average_pe
297 |
298 |
299 | def evaluate_flow(gt_flow, pred_flow):
300 | """
301 | gt: ground-truth flow
302 | pred: estimated flow
303 | """
304 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], pred_flow[:, :, 0], pred_flow[:, :, 1])
305 | return average_pe
306 |
307 |
308 | """
309 | ==============
310 | Disparity Section
311 | ==============
312 | """
313 |
314 |
315 | def read_disp_png(file_name):
316 | """
317 | Read optical flow from KITTI .png file
318 | :param file_name: name of the flow file
319 | :return: optical flow data in matrix
320 | """
321 | image_object = png.Reader(filename=file_name)
322 | image_direct = image_object.asDirect()
323 | image_data = list(image_direct[2])
324 | (w, h) = image_direct[3]['size']
325 |     channel = len(image_data[0]) // w
326 | flow = np.zeros((h, w, channel), dtype=np.uint16)
327 | for i in range(len(image_data)):
328 | for j in range(channel):
329 | flow[i, :, j] = image_data[i][j::channel]
330 | return flow[:, :, 0] / 256
331 |
332 |
333 | def disp_to_flowfile(disp, filename):
334 | """
335 | Read KITTI disparity file in png format
336 | :param disp: disparity matrix
337 | :param filename: the flow file name to save
338 | :return: None
339 | """
340 | f = open(filename, 'wb')
341 | magic = np.array([202021.25], dtype=np.float32)
342 | (height, width) = disp.shape[0:2]
343 | w = np.array([width], dtype=np.int32)
344 | h = np.array([height], dtype=np.int32)
345 | empty_map = np.zeros((height, width), dtype=np.float32)
346 | data = np.dstack((disp, empty_map))
347 | magic.tofile(f)
348 | w.tofile(f)
349 | h.tofile(f)
350 | data.tofile(f)
351 | f.close()
352 |
353 |
354 | """
355 | ==============
356 | Image Section
357 | ==============
358 | """
359 |
360 |
361 | def read_image(filename):
362 | """
363 | Read normal image of any format
364 | :param filename: name of the image file
365 | :return: image data in matrix uint8 type
366 | """
367 | img = Image.open(filename)
368 | im = np.array(img)
369 | return im
370 |
371 |
372 | def warp_image(im, flow):
373 | """
374 | Use optical flow to warp image to the next
375 | :param im: image to warp
376 | :param flow: optical flow
377 | :return: warped image
378 | """
379 | from scipy import interpolate
380 | image_height = im.shape[0]
381 | image_width = im.shape[1]
382 | flow_height = flow.shape[0]
383 | flow_width = flow.shape[1]
384 | n = image_height * image_width
385 | (iy, ix) = np.mgrid[0:image_height, 0:image_width]
386 | (fy, fx) = np.mgrid[0:flow_height, 0:flow_width]
387 |     fx = fx.astype(np.float64) + flow[:, :, 0]
388 |     fy = fy.astype(np.float64) + flow[:, :, 1]
389 | mask = np.logical_or(fx <0 , fx > flow_width)
390 | mask = np.logical_or(mask, fy < 0)
391 | mask = np.logical_or(mask, fy > flow_height)
392 | fx = np.minimum(np.maximum(fx, 0), flow_width)
393 | fy = np.minimum(np.maximum(fy, 0), flow_height)
394 | points = np.concatenate((ix.reshape(n,1), iy.reshape(n,1)), axis=1)
395 | xi = np.concatenate((fx.reshape(n, 1), fy.reshape(n,1)), axis=1)
396 | warp = np.zeros((image_height, image_width, im.shape[2]))
397 | for i in range(im.shape[2]):
398 | channel = im[:, :, i]
400 | values = channel.reshape(n, 1)
401 | new_channel = interpolate.griddata(points, values, xi, method='cubic')
402 | new_channel = np.reshape(new_channel, [flow_height, flow_width])
403 | new_channel[mask] = 1
404 | warp[:, :, i] = new_channel.astype(np.uint8)
405 |
406 | return warp.astype(np.uint8)
407 |
408 |
409 | """
410 | ==============
411 | Others
412 | ==============
413 | """
414 |
415 | def scale_image(image, new_range):
416 | """
417 | Linearly scale the image into desired range
418 | :param image: input image
419 | :param new_range: the new range to be aligned
420 | :return: image normalized in new range
421 | """
422 | min_val = np.min(image).astype(np.float32)
423 | max_val = np.max(image).astype(np.float32)
424 | min_val_new = np.array(min(new_range), dtype=np.float32)
425 | max_val_new = np.array(max(new_range), dtype=np.float32)
426 | scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new
427 | return scaled_image.astype(np.uint8)
428 |
429 |
430 | def compute_color(u, v):
431 | """
432 | compute optical flow color map
433 | :param u: optical flow horizontal map
434 | :param v: optical flow vertical map
435 | :return: optical flow in color code
436 | """
437 | [h, w] = u.shape
438 | img = np.zeros([h, w, 3])
439 | nanIdx = np.isnan(u) | np.isnan(v)
440 | u[nanIdx] = 0
441 | v[nanIdx] = 0
442 |
443 | colorwheel = make_color_wheel()
444 | ncols = np.size(colorwheel, 0)
445 |
446 | rad = np.sqrt(u**2+v**2)
447 |
448 | a = np.arctan2(-v, -u) / np.pi
449 |
450 | fk = (a+1) / 2 * (ncols - 1) + 1
451 |
452 | k0 = np.floor(fk).astype(int)
453 |
454 | k1 = k0 + 1
455 | k1[k1 == ncols+1] = 1
456 | f = fk - k0
457 |
458 | for i in range(0, np.size(colorwheel,1)):
459 | tmp = colorwheel[:, i]
460 | col0 = tmp[k0-1] / 255
461 | col1 = tmp[k1-1] / 255
462 | col = (1-f) * col0 + f * col1
463 |
464 | idx = rad <= 1
465 | col[idx] = 1-rad[idx]*(1-col[idx])
466 | notidx = np.logical_not(idx)
467 |
468 | col[notidx] *= 0.75
469 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
470 |
471 | return img
472 |
473 |
474 | def make_color_wheel():
475 | """
476 | Generate color wheel according Middlebury color code
477 | :return: Color wheel
478 | """
479 | RY = 15
480 | YG = 6
481 | GC = 4
482 | CB = 11
483 | BM = 13
484 | MR = 6
485 |
486 | ncols = RY + YG + GC + CB + BM + MR
487 |
488 | colorwheel = np.zeros([ncols, 3])
489 |
490 | col = 0
491 |
492 | # RY
493 | colorwheel[0:RY, 0] = 255
494 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
495 | col += RY
496 |
497 | # YG
498 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
499 | colorwheel[col:col+YG, 1] = 255
500 | col += YG
501 |
502 | # GC
503 | colorwheel[col:col+GC, 1] = 255
504 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
505 | col += GC
506 |
507 | # CB
508 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
509 | colorwheel[col:col+CB, 2] = 255
510 | col += CB
511 |
512 | # BM
513 | colorwheel[col:col+BM, 2] = 255
514 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
515 |     col += BM
516 |
517 | # MR
518 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
519 | colorwheel[col:col+MR, 0] = 255
520 |
521 | return colorwheel
522 |
--------------------------------------------------------------------------------
/PWC_src/pwc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import getopt
4 | import math
5 | import numpy
6 | import os
7 | import PIL
8 | import PIL.Image
9 | import sys
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from .correlation_package import correlation
15 |
16 |
17 | class Extractor(nn.Module):
18 | def __init__(self):
19 | super(Extractor, self).__init__()
20 |
21 | self.moduleOne = nn.Sequential(
22 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
23 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
24 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
25 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
26 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
27 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
28 | )
29 |
30 | self.moduleTwo = nn.Sequential(
31 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
32 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
33 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
34 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
35 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
36 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
37 | )
38 |
39 | self.moduleThr = nn.Sequential(
40 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
41 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
42 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
43 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
44 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
45 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
46 | )
47 |
48 | self.moduleFou = nn.Sequential(
49 | nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),
50 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
51 | nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
52 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
53 | nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
54 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
55 | )
56 |
57 | self.moduleFiv = nn.Sequential(
58 | nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),
59 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
60 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
61 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
62 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
63 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
64 | )
65 |
66 | self.moduleSix = nn.Sequential(
67 | nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),
68 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
69 | nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
70 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
71 | nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
72 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
73 | )
74 |
75 | def forward(self, tensorInput):
76 | tensorOne = self.moduleOne(tensorInput)
77 | tensorTwo = self.moduleTwo(tensorOne)
78 | tensorThr = self.moduleThr(tensorTwo)
79 | tensorFou = self.moduleFou(tensorThr)
80 | tensorFiv = self.moduleFiv(tensorFou)
81 | tensorSix = self.moduleSix(tensorFiv)
82 |
83 | return [ tensorOne, tensorTwo, tensorThr, tensorFou, tensorFiv, tensorSix ]
84 |
85 |
86 | class Backward(nn.Module):
87 | def __init__(self):
88 | super(Backward, self).__init__()
89 |
90 | def forward(self, tensorInput, tensorFlow):
91 | if hasattr(self, 'tensorPartial') == False or self.tensorPartial.size(0) != tensorFlow.size(0) or self.tensorPartial.size(2) != tensorFlow.size(2) or self.tensorPartial.size(3) != tensorFlow.size(3):
92 | self.tensorPartial = tensorFlow.new_ones(tensorFlow.size(0), 1, tensorFlow.size(2), tensorFlow.size(3))
93 |
94 | if hasattr(self, 'tensorGrid') == False or self.tensorGrid.size(0) != tensorFlow.size(0) or self.tensorGrid.size(2) != tensorFlow.size(2) or self.tensorGrid.size(3) != tensorFlow.size(3):
95 | tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1)
96 | tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3))
97 | self.tensorGrid = torch.cat([tensorHorizontal, tensorVertical], 1).cuda()
98 |
99 | tensorInput = torch.cat([tensorInput, self.tensorPartial], 1)
100 | tensorFlow = torch.cat([tensorFlow[:, 0:1, :, :]/((tensorInput.size(3)-1.0)/2.0), tensorFlow[:, 1:2, :, :]/((tensorInput.size(2)-1.0)/2.0)], 1)
101 |
102 | tensorOutput = F.grid_sample(input=tensorInput, grid=(self.tensorGrid + tensorFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros')
103 | tensorMask = tensorOutput[:, -1:, :, :]; tensorMask[tensorMask > 0.999] = 1.0; tensorMask[tensorMask < 1.0] = 0.0
104 |
105 | return tensorOutput[:, :-1, :, :] * tensorMask
106 |
107 |
108 | class Decoder(nn.Module):
109 | def __init__(self, intLevel):
110 | super(Decoder, self).__init__()
111 |
112 | intPrevious = [None, None, 81+32+2+2, 81+64+2+2, 81+96+2+2, 81+128+2+2, 81, None][intLevel+1]
113 | intCurrent = [None, None, 81+32+2+2, 81+64+2+2, 81+96+2+2, 81+128+2+2, 81, None][intLevel+0]
114 |
115 | if intLevel < 6: self.moduleUpflow = nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
116 | if intLevel < 6: self.moduleUpfeat = nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
117 |
118 | if intLevel < 6: self.dblBackward = [None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel+1]
119 | if intLevel < 6: self.moduleBackward = Backward()
120 |
121 | self.moduleCorrelation = correlation.Correlation()
122 | self.moduleCorreleaky = nn.LeakyReLU(inplace=False, negative_slope=0.1)
123 |
124 | self.moduleOne = nn.Sequential(
125 | nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),
126 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
127 | )
128 |
129 | self.moduleTwo = nn.Sequential(
130 | nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),
131 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
132 | )
133 |
134 | self.moduleThr = nn.Sequential(
135 | nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),
136 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
137 | )
138 |
139 | self.moduleFou = nn.Sequential(
140 | nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1),
141 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
142 | )
143 |
144 | self.moduleFiv = nn.Sequential(
145 | nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),
146 | nn.LeakyReLU(inplace=False, negative_slope=0.1)
147 | )
148 |
149 | self.moduleSix = nn.Sequential(
150 | nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)
151 | )
152 |
153 | def forward(self, tensorFirst, tensorSecond, objectPrevious):
154 | tensorFlow = None
155 | tensorFeat = None
156 |
157 | if objectPrevious is None:
158 | tensorFlow = None
159 | tensorFeat = None
160 |
161 | tensorVolume = self.moduleCorreleaky(self.moduleCorrelation(tensorFirst, tensorSecond))
162 | tensorFeat = torch.cat([tensorVolume], 1)
163 |
164 | elif objectPrevious is not None:
165 | tensorFlow = self.moduleUpflow(objectPrevious['tensorFlow'])
166 | tensorFeat = self.moduleUpfeat(objectPrevious['tensorFeat'])
167 | tensorVolume = self.moduleCorreleaky(self.moduleCorrelation(tensorFirst, self.moduleBackward(tensorSecond, tensorFlow*self.dblBackward)))
168 | tensorFeat = torch.cat([tensorVolume, tensorFirst, tensorFlow, tensorFeat], 1)
169 |
170 | tensorFeat = torch.cat([self.moduleOne(tensorFeat), tensorFeat], 1)
171 | tensorFeat = torch.cat([self.moduleTwo(tensorFeat), tensorFeat], 1)
172 | tensorFeat = torch.cat([self.moduleThr(tensorFeat), tensorFeat], 1)
173 | tensorFeat = torch.cat([self.moduleFou(tensorFeat), tensorFeat], 1)
174 | tensorFeat = torch.cat([self.moduleFiv(tensorFeat), tensorFeat], 1)
175 | tensorFlow = self.moduleSix(tensorFeat)
176 |
177 | return {
178 | 'tensorFlow': tensorFlow,
179 | 'tensorFeat': tensorFeat
180 | }
181 |
182 |
183 | class Refiner(nn.Module):
184 | def __init__(self):
185 | super(Refiner, self).__init__()
186 |
187 | self.moduleMain = nn.Sequential(
188 | nn.Conv2d(in_channels=81+32+2+2+128+128+96+64+32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),
189 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
190 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),
191 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
192 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),
193 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
194 | nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),
195 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
196 | nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),
197 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
198 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
199 | nn.LeakyReLU(inplace=False, negative_slope=0.1),
200 | nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)
201 | )
202 |
203 | def forward(self, tensorInput):
204 | return self.moduleMain(tensorInput)
205 |
206 |
207 | class PWC_Net(nn.Module):
208 | def __init__(self, model_path=None):
209 | super(PWC_Net, self).__init__()
210 | self.model_path = model_path
211 |
212 | self.moduleExtractor = Extractor()
213 | self.moduleTwo = Decoder(2)
214 | self.moduleThr = Decoder(3)
215 | self.moduleFou = Decoder(4)
216 | self.moduleFiv = Decoder(5)
217 | self.moduleSix = Decoder(6)
218 | self.moduleRefiner = Refiner()
219 | self.load_state_dict(torch.load(self.model_path))
220 |
221 | def forward(self, tensorFirst, tensorSecond):
222 | tensorFirst = self.moduleExtractor(tensorFirst)
223 | tensorSecond = self.moduleExtractor(tensorSecond)
224 |
225 | objectEstimate = self.moduleSix(tensorFirst[-1], tensorSecond[-1], None)
226 | objectEstimate = self.moduleFiv(tensorFirst[-2], tensorSecond[-2], objectEstimate)
227 | objectEstimate = self.moduleFou(tensorFirst[-3], tensorSecond[-3], objectEstimate)
228 | objectEstimate = self.moduleThr(tensorFirst[-4], tensorSecond[-4], objectEstimate)
229 | objectEstimate = self.moduleTwo(tensorFirst[-5], tensorSecond[-5], objectEstimate)
230 |
231 | return objectEstimate['tensorFlow'] + self.moduleRefiner(objectEstimate['tensorFeat'])
232 |
233 |
234 | if __name__ == '__main__':
235 | net = PWC_Net(model_path='models/sintel.pytorch')
236 | net.cuda()
237 |
238 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## PWC-Net (PyTorch v1.0.1)
2 |
3 | PyTorch implementation of [PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume](https://arxiv.org/pdf/1709.02371.pdf). We provide it as an off-the-shelf package:
4 | - After installation, just copy the whole `PWC_src` folder into your codebase. See `demo.py` for details.
5 |
6 | ### Environment
7 |
8 | This code has been tested with Python 3.6 and PyTorch 1.0.1 on a Tesla K80 GPU, under Ubuntu 14.04 with CUDA 10.0. All required Python packages are listed in `requirements.txt`.
9 |
10 | ### Installation
11 |
12 | # install custom layers
13 | cd PWC_src/correlation_package
14 | python setup.py install
15 |
16 | Note: you might need to add `gencode` [here](https://github.com/vt-vl-lab/pwc-net.pytorch/blob/master/PWC_src/correlation_package/setup.py#L9), according to the GPU you use. You can find more information about `gencode` [here](https://developer.nvidia.com/cuda-gpus) and [here](http://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/).
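For instance, assuming `setup.py` follows the flownet2-pytorch correlation-package layout (an `nvcc_args` list passed to `CUDAExtension`), supporting an additional architecture is a matter of appending a `-gencode` pair; the exact `arch=compute_XX,code=sm_XX` values depend on your GPU:

```python
# Sketch of the nvcc_args list in PWC_src/correlation_package/setup.py
# (variable name assumed from the flownet2-pytorch correlation package).
nvcc_args = [
    '-gencode', 'arch=compute_60,code=sm_60',  # Pascal (e.g. GTX 10xx)
    '-gencode', 'arch=compute_70,code=sm_70',  # Volta  (e.g. V100)
    '-gencode', 'arch=compute_75,code=sm_75',  # Turing (e.g. RTX 20xx)
]
```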
17 |
18 | ### Converted Caffe Pre-trained Models
19 | You can find them in the `models` folder.
20 |
21 | ### Inference mode
22 | Modify the input paths in `demo.py` to point to your images, then run:
23 |
24 | ```
25 | python demo.py
26 | ```
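`demo.py` assumes each input dimension is a multiple of 64 (the feature extractor downsamples six times by a factor of 2). For other image sizes, a small helper like the following (not part of this repo) can compute how much padding each dimension needs:

```python
def pad_to_multiple(height, width, multiple=64):
    """Extra rows/cols needed so each dimension reaches the next
    multiple of `multiple` (zero when already aligned)."""
    pad_h = (multiple - height % multiple) % multiple
    pad_w = (multiple - width % multiple) % multiple
    return pad_h, pad_w

# A 436x1024 Sintel frame needs 12 extra rows and no extra columns.
print(pad_to_multiple(436, 1024))  # (12, 0)
```

The returned amounts can be applied with e.g. `torch.nn.functional.pad` before calling the network, and cropped from the output flow afterwards.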
27 |
28 | If the installation is successful, you should see the following:
29 | 
30 |
31 | ### Reference
32 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper using:
33 | ````
34 | @inproceedings{sun2018pwc,
35 | title={PWC-Net: CNNs for optical flow using pyramid, warping, and cost volume},
36 | author={Sun, Deqing and Yang, Xiaodong and Liu, Ming-Yu and Kautz, Jan},
37 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
38 | pages={8934--8943},
39 | year={2018}
40 | }
41 | ````
42 |
43 | ### Acknowledgments
44 | * [sniklaus/pytorch-pwc](https://github.com/sniklaus/pytorch-pwc): Network definition and converted PyTorch model weights.
45 | * [NVIDIA/flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch): Correlation module.
46 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | import torch
5 |
6 | from PWC_src import PWC_Net
7 | from PWC_src import flow_to_image
8 |
9 | FLOW_SCALE = 20.0
10 |
11 |
12 | if __name__ == '__main__':
13 |     # Prepare image pair (each spatial dimension needs to be a multiple of 64)
14 | im1 = cv2.imread('example/0img0.ppm')
15 | im2 = cv2.imread('example/0img1.ppm')
16 | im1 = torch.from_numpy((im1/255.).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
17 | im2 = torch.from_numpy((im2/255.).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
18 | im1_v = im1.cuda()
19 | im2_v = im2.cuda()
20 |
21 | # Build model
22 | pwc = PWC_Net(model_path='models/sintel.pytorch')
23 | #pwc = PWC_Net(model_path='models/chairs-things.pytorch')
24 | pwc = pwc.cuda()
25 | pwc.eval()
26 |
27 | import time
28 | start = time.time()
29 | flow = FLOW_SCALE*pwc(im1_v, im2_v)
30 | print(time.time()-start)
31 | flow = flow.data.cpu()
32 | flow = flow[0].numpy().transpose((1,2,0))
33 | flow_im = flow_to_image(flow)
34 |
35 | # Visualization
36 | import matplotlib.pyplot as plt
37 | plt.imshow(flow_im)
38 | plt.show()
39 |
40 |
--------------------------------------------------------------------------------
/example/0img0.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/0img0.ppm
--------------------------------------------------------------------------------
/example/0img1.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/0img1.ppm
--------------------------------------------------------------------------------
/example/1img0.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/1img0.ppm
--------------------------------------------------------------------------------
/example/1img1.ppm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/1img1.ppm
--------------------------------------------------------------------------------
/example/flow0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/example/flow0.png
--------------------------------------------------------------------------------
/misc/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/misc/demo.png
--------------------------------------------------------------------------------
/models/chairs-things.pytorch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/models/chairs-things.pytorch
--------------------------------------------------------------------------------
/models/sintel.pytorch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vt-vl-lab/pwc-net.pytorch/a03f0e5fa3c85fd4612a87b3eb9016815bb717e9/models/sintel.pytorch
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2019.3.9
2 | cycler==0.10.0
3 | joblib==0.11
4 | kiwisolver==1.1.0
5 | matplotlib==3.1.0
6 | numpy==1.16.4
7 | opencv-python==3.4.3.18
8 | Pillow==6.2.0
9 | pyparsing==2.4.0
10 | pypng==0.0.19
11 | python-dateutil==2.8.0
12 | six==1.12.0
13 | torch==1.0.1
14 |
--------------------------------------------------------------------------------