├── .idea
│   ├── UPFlow_pytorch.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   └── kitti_dataset.cpython-35.pyc
│   └── kitti_dataset.py
├── how_to_install.md
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── pwc_modules.cpython-35.pyc
│   │   └── upflow.cpython-35.pyc
│   ├── correlation_package
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── correlation.cpython-35.pyc
│   │   │   └── correlation.cpython-36.pyc
│   │   ├── build
│   │   │   ├── lib.linux-x86_64-3.5
│   │   │   │   └── correlation_cuda.cpython-35m-x86_64-linux-gnu.so
│   │   │   └── temp.linux-x86_64-3.5
│   │   │       ├── correlation_cuda.o
│   │   │       └── correlation_cuda_kernel.o
│   │   ├── correlation.py
│   │   ├── correlation_cuda.cc
│   │   ├── correlation_cuda.egg-info
│   │   │   ├── PKG-INFO
│   │   │   ├── SOURCES.txt
│   │   │   ├── dependency_links.txt
│   │   │   └── top_level.txt
│   │   ├── correlation_cuda_kernel.cu
│   │   ├── correlation_cuda_kernel.cuh
│   │   ├── dist
│   │   │   └── correlation_cuda-0.0.0-py3.5-linux-x86_64.egg
│   │   └── setup.py
│   ├── pwc_modules.py
│   └── upflow.py
├── requirements.txt
├── scripts
│   ├── __init__.py
│   ├── ex_runner.py
│   ├── simple_train.py
│   └── upflow_kitti2015.pth
├── test.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── loss.cpython-35.pyc
    │   ├── pytorch_correlation.cpython-35.pyc
    │   └── tools.cpython-35.pyc
    ├── loss.py
    ├── pytorch_correlation.py
    └── tools.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 coolbeam
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR2021] UPFlow: Upsampling Pyramid for Unsupervised Optical Flow Learning
2 |
3 | Kunming Luo¹, Chuan Wang¹, Shuaicheng Liu²,¹, Haoqiang Fan¹, Jue Wang¹, Jian Sun¹
4 |
5 | 1. Megvii Technology, 2. University of Electronic Science and Technology of China
6 |
7 | This is the official implementation of the paper [***UPFlow: Upsampling Pyramid for Unsupervised Optical Flow Learning***](https://openaccess.thecvf.com/content/CVPR2021/papers/Luo_UPFlow_Upsampling_Pyramid_for_Unsupervised_Optical_Flow_Learning_CVPR_2021_paper.pdf) CVPR 2021.
8 |
9 | ## Abstract
10 | We present an unsupervised learning approach for optical flow estimation by improving the upsampling and learning of pyramid network. We design a self-guided upsample module to tackle the interpolation blur problem caused by bilinear upsampling between pyramid levels. Moreover, we propose a pyramid distillation loss to add supervision for intermediate levels via distilling the finest flow as pseudo labels. By integrating these two components together, our method achieves the best performance for unsupervised optical flow learning on multiple leading benchmarks, including MPI-Sintel, KITTI 2012 and KITTI 2015. In particular, we achieve EPE=1.4 on KITTI 2012 and F1=9.38% on KITTI 2015, which outperform the previous state-of-the-art methods by 22.2% and 15.7%, respectively.
11 |
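As a rough illustration of the pyramid distillation loss described above, here is a minimal sketch: the finest flow, detached as a pseudo label, supervises every intermediate level. This is a simplification under assumed shapes and a plain L1 penalty, not the repository's actual loss (see `utils/loss.py`); the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F

def pyramid_distillation_sketch(intermediate_flows, finest_flow):
    # intermediate_flows: list of (N, 2, h_i, w_i) predictions from coarser levels
    # finest_flow: the full-resolution prediction (N, 2, H, W)
    target = finest_flow.detach()  # pseudo label: no gradient into the finest flow
    loss = 0.0
    for flow in intermediate_flows:
        h, w = flow.shape[-2:]
        scale = w / target.shape[-1]  # flow vectors shrink with spatial resolution
        pseudo = F.interpolate(target, size=(h, w), mode='bilinear', align_corners=False) * scale
        loss = loss + (flow - pseudo).abs().mean()
    return loss
```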
12 | ## This repository includes:
13 | - inference scripts; and
14 | - a pretrained model.
15 |
16 | ## Presentation Video
17 | [[Youtube](https://www.youtube.com/watch?v=voD3tA8q-lk&t=4s)], [[Bilibili](https://www.bilibili.com/video/BV1vg41137eH/)]
18 |
19 | ## Pipeline
20 | 
21 | Illustration of the pipeline of our network, which contains two stages: pyramid encoding to extract feature pairs at different scales, and pyramid decoding to estimate optical flow at each scale. Note that the parameters of the decoder module and the upsample module are shared across all pyramid levels.
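As a runnable toy sketch of this weight sharing (the decoder below is invented for illustration; the actual modules live in `model/upflow.py` and `model/pwc_modules.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedDecoder(nn.Module):
    """One decoder instance reused at every pyramid level."""
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch * 2 + 2, 2, kernel_size=3, padding=1)

    def forward(self, f1, f2, up_flow):
        # residual refinement of the upsampled coarse flow
        return up_flow + self.conv(torch.cat([f1, f2, up_flow], dim=1))

# hypothetical feature pyramids, coarse to fine
pyr1 = [torch.randn(1, 16, 8 * 2 ** i, 10 * 2 ** i) for i in range(3)]
pyr2 = [torch.randn(1, 16, 8 * 2 ** i, 10 * 2 ** i) for i in range(3)]
decoder = SharedDecoder(16)  # the same weights serve every level
flow = torch.zeros(1, 2, 8, 10)
for f1, f2 in zip(pyr1, pyr2):
    flow = F.interpolate(flow, size=f1.shape[-2:], mode='bilinear', align_corners=False) * 2.0
    flow = decoder(f1, f2, flow)
print(flow.shape)  # torch.Size([1, 2, 32, 40])
```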
22 |
23 | ## Self-Guided Upsample Module
24 | 
25 |
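A heavily simplified sketch of the self-guided upsampling idea, under assumed inputs: a learned interpolation flow relocates the bilinearly upsampled flow near motion boundaries, and a learned interpolation map blends the two results. The function and argument names are hypothetical; the real module is in `model/upflow.py`.

```python
import torch
import torch.nn.functional as F

def backwarp(x, flow):
    # sample x at positions displaced by flow (in pixels) via grid_sample
    n, _, h, w = x.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w))
    grid = torch.stack((xs, ys), dim=0).float().unsqueeze(0).to(x.device)
    pos = grid + flow
    pos_x = 2.0 * pos[:, 0] / max(w - 1, 1) - 1.0  # normalize to [-1, 1]
    pos_y = 2.0 * pos[:, 1] / max(h - 1, 1) - 1.0
    return F.grid_sample(x, torch.stack((pos_x, pos_y), dim=3), align_corners=True)

def sgu_upsample_sketch(coarse_flow, interp_flow, interp_map):
    # coarse_flow: (N, 2, h, w); interp_flow, interp_map: predicted at (2h, 2w)
    up = F.interpolate(coarse_flow, scale_factor=2, mode='bilinear', align_corners=False) * 2.0
    moved = backwarp(up, interp_flow)  # pull flow values from same-side pixels across edges
    return interp_map * up + (1.0 - interp_map) * moved
```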
26 | ## Usage
27 |
28 | Please first install the environments following `how_to_install.md`.
29 |
30 | Run `python3 test.py` to test our trained model on the KITTI 2015 dataset. Note that CUDA is needed.
31 |
32 | ## Results
33 | 
34 | Visual example of our self-guided upsample module (SGU) on the MPI-Sintel Final dataset. Results of the bilinear method and of our SGU are shown.
35 |
36 | ## Citation
37 | If you find this work helpful, please cite:
38 | ```
39 | @inproceedings{luo2021upflow,
40 | title={Upflow: Upsampling pyramid for unsupervised optical flow learning},
41 | author={Luo, Kunming and Wang, Chuan and Liu, Shuaicheng and Fan, Haoqiang and Wang, Jue and Sun, Jian},
42 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
43 | pages={1045--1054},
44 | year={2021}
45 | }
46 | ```
47 |
48 |
49 |
50 | ## Acknowledgement
51 | Parts of our code are adapted from [IRR-PWC](https://github.com/visinf/irr), [UnFlow](https://github.com/simonmeister/UnFlow), [ARFlow](https://github.com/lliuz/ARFlow) and [UFlow](https://github.com/google-research/google-research/tree/master/uflow); we thank the authors for their contributions.
52 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/dataset/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/kitti_dataset.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/dataset/__pycache__/kitti_dataset.cpython-35.pyc
--------------------------------------------------------------------------------
/dataset/kitti_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | import os
4 | from utils.tools import tools
5 | import random
6 | import cv2
7 | from torch.utils.data import Dataset, DataLoader
8 | import numpy as np
9 | import torch
10 | import tensorflow as tf
11 | import warnings # ignore warnings
12 | import zipfile
13 | from glob import glob
14 | from torchvision import transforms as vision_transforms
15 | import imageio
16 | import png
17 |
18 | '''
19 | TensorFlow is not strictly necessary here; it is a leftover from my early implementation, which was based on UnFlow and DDFlow.
20 | '''
21 |
22 | warnings.filterwarnings('ignore')
23 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' # or any {'0', '1', '2'}, ignore tensorflow log information
24 | '''
25 | you can download the KITTI multi-view (mv) dataset from:
26 | kitti 2012 multi view: http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=flow
27 | (17GB) data_stereo_flow_multiview.zip, here I save it in .../KITTI_data_mv_stereo_flow_2012/data_stereo_flow_multiview.zip
28 | kitti 2015 multi view: http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow
29 | (14GB) data_scene_flow_multiview.zip, here I save it in .../KITTI_data_mv_stereo_flow_2015/data_scene_flow_multiview.zip
30 | '''
31 | mv_data_dir = '/data/Optical_Flow_all/datasets/KITTI_data/KITTI_data_mv' # TODO important path
32 | '''
33 | KITTI flow data:
34 | kitti 2012 data: data_stereo_flow.zip (2.0GB)
35 | kitti 2015 data: data_scene_flow.zip (1.7GB)
36 | here you should unzip them
37 | '''
38 | kitti_flow_dir = '/data/Optical_Flow_all/datasets/KITTI_data' # TODO important path
39 |
40 |
41 | class img_func():
42 |
43 | @classmethod
44 | def get_process_img(cls, img_name, normalize=True, if_horizontal_flip=False):
45 | def _normalize_image(image):
46 | mean = [104.920005, 110.1753, 114.785955]
47 | stddev = 1 / 0.0039216
48 | unflow_im = (image - mean) / stddev
49 | # std__ = [69.85, 68.81, 72.45]
50 | # mean_ = [118.93, 113.97, 102.60]
51 | # unflow_pytorch = unflow_im * std__ + mean_
52 | # # check(img_rgb, 'img_rgb')
53 | # unflow_pytorch = unflow_pytorch * (1.0 / 255.0)
54 | return unflow_im
55 |
56 | # img = cv2.imread(img_name)
57 | data = tf.io.read_file(img_name)
58 | img = tf.image.decode_image(data).numpy()
59 | if if_horizontal_flip:
60 | img = np.flip(img, 1)
61 | if normalize:
62 | img = _normalize_image(img)
63 | # img = img / 255.0
64 | img = np.transpose(img, [2, 0, 1])
65 | return img
66 |
67 | @classmethod
68 | def get_process_img_only_img(cls, img, normalize=True, if_horizontal_flip=False):
69 | def _normalize_image(image):
70 | mean = [104.920005, 110.1753, 114.785955]
71 | stddev = 1 / 0.0039216
72 | unflow_im = (image - mean) / stddev
73 | # std__ = [69.85, 68.81, 72.45]
74 | # mean_ = [118.93, 113.97, 102.60]
75 | # unflow_pytorch = unflow_im * std__ + mean_
76 | # # check(img_rgb, 'img_rgb')
77 | # unflow_pytorch = unflow_pytorch * (1.0 / 255.0)
78 | return unflow_im
79 |
80 | # img = cv2.imread(img_name)
81 | # img = imageio.imread(img_name)
82 | if if_horizontal_flip:
83 | img = np.flip(img, 1)
84 | if normalize:
85 | img = _normalize_image(img)
86 | # img = img / 255.0
87 | img = np.transpose(img, [2, 0, 1])
88 | return img
89 |
90 | @classmethod
91 | def frame_name_to_num(cls, name):
92 | stripped = name.split('.')[0].lstrip('0')
93 | if stripped == '':
94 | return 0
95 | return int(stripped)
96 |
97 | @classmethod
98 | def np_2_tensor(cls, *args):
99 | def func(a):
100 | b = torch.from_numpy(a)
101 | b = b.float()
102 | return b
103 |
104 | return [func(a) for a in args]
105 |
106 | @classmethod
107 | def read_flow(cls, filename):
108 | gt = cv2.imread(filename, cv2.IMREAD_UNCHANGED)  # KITTI flow png is 16-bit; cv2 returns it in BGR order
109 | flow = (gt[:, :, 2:0:-1].astype(np.float64) - 2 ** 15) / 64.0  # (u, v) from the (R, G) channels
110 | mask = gt[:, :, 0:1]  # validity mask from the B channel
111 | flow = np.transpose(flow, [2, 0, 1])
112 | mask = np.transpose(mask, [2, 0, 1])
113 | return flow, mask
114 |
115 | @classmethod
116 | def read_flow_tf(cls, filename): # need tensorflow 2.0.0, WRONG!!! use the function read_png_flow
117 | data = tf.io.read_file(filename)
118 | gt = tf.image.decode_png(data, channels=3, dtype=tf.uint16).numpy()
119 | # gt=cv2.imread(filename)
120 | flow = (gt[:, :, 0:2] - 2 ** 15) / 64.0
121 | flow = flow.astype(np.float64)
122 | mask = gt[:, :, 2:3]
123 | mask = np.uint8(mask)
124 | flow = np.transpose(flow, [2, 0, 1])
125 | mask = np.transpose(mask, [2, 0, 1])
126 | # print('mask: max', np.max(mask), 'min', np.min(mask))
127 | return flow, mask
128 |
129 | @classmethod
130 | def read_png_flow(cls, fpath):
131 | """
132 | Read KITTI optical flow, returns u,v,valid mask
133 |
134 | """
135 |
136 | R = png.Reader(fpath)
137 | width, height, data, _ = R.asDirect()
138 | # This only worked with python2.
139 | # I = np.array(map(lambda x:x,data)).reshape((height,width,3))
140 | gt = np.array([x for x in data]).reshape((height, width, 3))
141 | flow = gt[:, :, 0:2]
142 | flow = (flow.astype('float64') - 2 ** 15) / 64.0
143 | flow = flow.astype(np.float64)  # redundant: already float64 from the line above
144 | mask = gt[:, :, 2:3]
145 | mask = np.uint8(mask)
146 | flow = np.transpose(flow, [2, 0, 1])
147 | mask = np.transpose(mask, [2, 0, 1])
148 |
149 | return flow, mask
150 |
151 | @classmethod
152 | def censusTransform(cls, src_bytes, if_debug=False):
153 | def censusTransformSingleChannel(src_bytes):
154 | h, w = src_bytes.shape
155 | # Initialize output array
156 | census = np.zeros((h - 2, w - 2), dtype='uint8')
157 | # census1 = np.zeros((h, w), dtype='uint8')
158 |
159 | # centre pixels, which are offset by (1, 1)
160 | cp = src_bytes[1:h - 1, 1:w - 1]
161 |
162 | # offsets of non-central pixels
163 | offsets = [(u, v) for v in range(3) for u in range(3) if not u == 1 == v]
164 |
165 | # Do the pixel comparisons
166 | for u, v in offsets:
167 | census = (census << 1) | (src_bytes[v:v + h - 2, u:u + w - 2] >= cp)
168 |
169 | return census
170 |
171 | # chk num of channels. if 1 call as is, if 3 call it with each channel
172 | if len(src_bytes.shape) == 2: # single channel
173 | census = censusTransformSingleChannel(np.lib.pad(src_bytes, 1, 'constant', constant_values=0))
174 |
175 | elif len(src_bytes.shape) == 3 and src_bytes.shape[2] == 3:
176 | temp = np.lib.pad(src_bytes[:, :, 0], 1, 'constant', constant_values=0)
177 | census_a = censusTransformSingleChannel(temp)
178 | if if_debug:
179 | cv2.imshow(mat=temp, winname='pad')
180 | print('temp', temp.shape, np.max(temp), np.min(temp))
181 | print('census_a', census_a.shape, np.max(census_a), np.min(census_a))
182 | cv2.imshow(mat=census_a, winname='census_a')
183 | cv2.waitKey()
184 | census_b = censusTransformSingleChannel(np.lib.pad(src_bytes[:, :, 1], 1, 'constant', constant_values=0))
185 | census_c = censusTransformSingleChannel(np.lib.pad(src_bytes[:, :, 2], 1, 'constant', constant_values=0))
186 | census = np.dstack((census_a, census_b, census_c))
187 | else:
188 | raise ValueError('wrong channel RGB ')
189 |
190 | return census
191 |
192 |
193 | class kitti_train:
194 | @classmethod
195 | def mv_data_get_file_names(cls, if_test=False):
196 | file_names_save_path = os.path.join(mv_data_dir, 'kitti_mv_file_names.pkl')
197 | if os.path.isfile(file_names_save_path) and not if_test:
198 | data = tools.pickle_saver.load_picke(file_names_save_path)
199 | return data
200 | else:
201 | mv_2012_name = 'stereo_flow_2012'
202 | mv_2012_file_name = 'data_stereo_flow_multiview.zip'
203 | mv_2012_zip_file = os.path.join(mv_data_dir, mv_2012_name, mv_2012_file_name)
204 | mv_2012_dir = os.path.join(mv_data_dir, mv_2012_name, mv_2012_file_name[:-4])
205 | if os.path.isdir(mv_2012_dir):
206 | pass
207 | else:
208 | tools.extract_zip(mv_2012_zip_file, mv_2012_dir)
209 |
210 | mv_2015_name = 'stereo_flow_2015'
211 | mv_2015_file_name = 'data_scene_flow_multiview.zip'
212 | mv_2015_zip_file = os.path.join(mv_data_dir, mv_2015_name, mv_2015_file_name)
213 | mv_2015_dir = os.path.join(mv_data_dir, mv_2015_name, mv_2015_file_name[:-4])
214 | if os.path.isdir(mv_2015_dir):
215 | pass
216 | else:
217 | tools.extract_zip(mv_2015_zip_file, mv_2015_dir)
218 |
219 | def read_mv_data(d_path):
220 | def tf_read_img(im_path):
221 | data_img = tf.io.read_file(im_path)
222 | img_read = tf.image.decode_image(data_img).numpy() # get image 1
223 | return img_read
224 |
225 | sample_ls = []
226 | for sub_dir in ['testing', 'training']:
227 | img_dir = os.path.join(d_path, sub_dir, 'image_2')
228 | file_ls = os.listdir(img_dir)
229 | file_ls.sort()
230 | print(' ')
231 | for ind in range(len(file_ls) - 1):
232 | name = file_ls[ind]
233 | nex_name = file_ls[ind + 1]
234 | id_ = int(name[-6:-4])
235 | id_nex = int(nex_name[-6:-4])
236 | if id_ != id_nex - 1 or 12 >= id_ >= 9 or 12 >= id_nex >= 9:  # skip non-consecutive pairs and frames 09-12 (the annotated frames and their neighbors)
237 | pass
238 | else:
239 | file_path = os.path.join(img_dir, name)
240 | file_path_nex = os.path.join(img_dir, nex_name)
241 | # # test
242 | if if_test:
243 | # im1 = cv2.imread(file_path)
244 | # im2 = cv2.imread(file_path_nex)
245 | im1 = tf_read_img(file_path)[:, :, ::-1]
246 | im2 = tf_read_img(file_path_nex)[:, :, ::-1]
247 | cv2.imshow(name, im1)
248 | k = cv2.waitKey()
249 | c = 0
250 | while k != ord('q'):
251 | c += 1
252 | if c % 2 == 0:
253 | cv2.imshow(name, im1)
254 | else:
255 | cv2.imshow(name, im2)
256 | k = cv2.waitKey()
257 | cv2.destroyAllWindows()
258 | sample_ls.append((file_path, file_path_nex))
259 |
260 | return sample_ls
261 |
262 | filenames = {}
263 | filenames['2012'] = read_mv_data(mv_2012_dir)
264 | filenames['2015'] = read_mv_data(mv_2015_dir)
265 | tools.pickle_saver.save_pickle(files=filenames, file_path=file_names_save_path)
266 | return filenames
267 |
268 | class kitti_data_with_start_point(Dataset):
269 | class config(tools.abstract_config):
270 |
271 | def __init__(self, **kwargs):
272 | self.crop_size = (256, 832) # original size is (512,1152), we directly set as (256, 832) during training
273 | self.rho = 8
274 | self.swap_images = True
275 | self.normalize = True
276 | self.repeat = None # if repeat the dataset in one epoch
277 | self.horizontal_flip_aug = True
278 | self.mv_type = None # '2015' or '2012'
279 | self.update(kwargs)
280 |
281 | def __call__(self):
282 | return kitti_train.kitti_data_with_start_point(self)
283 |
284 | def __init__(self, conf: config):
285 | self.conf = conf
286 | if self.conf.mv_type in ['2015', '2012']:
287 | file_dict_ = kitti_train.mv_data_get_file_names()
288 | if self.conf.mv_type == '2015':
289 | print('=' * 5)
290 | print('use multi_view dataset 2015')
291 | print('=' * 5)
292 | filenames_extended = file_dict_['2015']
293 | elif self.conf.mv_type == '2012':
294 | print('=' * 5)
295 | print('use multi_view dataset 2012')
296 | print('=' * 5)
297 | filenames_extended = file_dict_['2012']
298 | else:
299 | raise ValueError('wrong type mv dataset: %s' % self.conf.mv_type)
300 | else:
301 | raise ValueError('mv_type should be 2012 or 2015')
302 | # =====================================
303 | self.filenames_extended = filenames_extended
304 | self.N = len(self.filenames_extended)
305 |
306 | def __len__(self):
307 | if self.conf.repeat is None or self.conf.repeat <= 0:
308 | return len(self.filenames_extended)
309 | else:
310 | assert type(self.conf.repeat) == int
311 | return len(self.filenames_extended) * self.conf.repeat
312 |
313 | def __getitem__(self, index):
314 | im1, im2 = self.read_img(index=index)
315 | im1_crop, im2_crop, start = self.random_crop(im1, im2)
316 | im1, im2, im1_crop, im2_crop, start = img_func.np_2_tensor(im1, im2, im1_crop, im2_crop, start)
317 | return im1, im2, im1_crop, im2_crop, start
318 |
319 | def random_crop(self, im1, im2):
320 | height, width = im1.shape[1:]
321 | patch_size_h, patch_size_w = self.conf.crop_size
322 | x = np.random.randint(self.conf.rho, width - self.conf.rho - patch_size_w)
323 | # print(self.rho, height - self.rho - patch_size_h)
324 | y = np.random.randint(self.conf.rho, height - self.conf.rho - patch_size_h)
325 | start = np.array([x, y])
326 | start = np.expand_dims(np.expand_dims(start, 1), 2)
327 | img_1_patch = im1[:, y: y + patch_size_h, x: x + patch_size_w]
328 | img_2_patch = im2[:, y: y + patch_size_h, x: x + patch_size_w]
329 | return img_1_patch, img_2_patch, start
330 |
331 | def read_img(self, index):
332 | if self.conf.horizontal_flip_aug and random.random() < 0.5:
333 | if_horizontal_flip = True
334 | else:
335 | if_horizontal_flip = False
336 | im1_path, im2_path = self.filenames_extended[index % self.N]
337 | im1 = img_func.get_process_img(im1_path, normalize=self.conf.normalize, if_horizontal_flip=if_horizontal_flip)
338 | im2 = img_func.get_process_img(im2_path, normalize=self.conf.normalize, if_horizontal_flip=if_horizontal_flip)
339 | if self.conf.swap_images and tools.random_flag(0.5):
340 | return im2, im1
341 | else:
342 | return im1, im2
343 |
344 | @classmethod
345 | def demo(cls):
346 | def process(a):
347 | b = a.numpy()
348 | b = np.squeeze(b)
349 | b = np.transpose(b, (1, 2, 0))
350 | b = tools.im_norm(b)
351 | b = b.astype(np.uint8)
352 | return b
353 |
354 | data_conf = kitti_train.kitti_data_with_start_point.config()
355 | print('begin!!!' + '=' * 10)
356 | data_conf.crop_size = (256, 832)
357 | data_conf.rho = 8
358 | data_conf.repeat = None # if repeat the dataset in one epoch
359 | data_conf.normalize = True
360 | data_conf.horizontal_flip_aug = False
361 | data_conf.mv_type = '2012'
362 | data_conf.get_name(print_now=True)
363 | data = data_conf()
364 | N = len(data)
365 | print('len: %5s' % N)
366 | for i in range(len(data)):
367 | if i > 5:
368 | break
369 | im1, im2, im1_crop, im2_crop, start = data.__getitem__(i)
370 | tools.check_tensor(im1, 'im1')
371 | tools.check_tensor(im2, 'im2')
372 | tools.check_tensor(im1_crop, 'im1_crop')
373 | tools.check_tensor(im2_crop, 'im2_crop')
374 | tools.check_tensor(start, 'start')
375 | # im1, im2 = tools.func_decorator(process, im1, im2)
376 | # tools.cv2_show_dict(im1=im1, im2=im2)
377 |
378 |
379 | class kitti_flow:
380 | class Evaluation_bench():
381 |
382 | def __init__(self, name, if_gpu=True, batch_size=1):
383 | assert if_gpu == True
384 | self.batch_size = batch_size
385 | assert name in ['2012_train', '2015_train', '2012_test', '2015_test']
386 | self.name = name
387 | self.dataset = kitti_flow.kitti_train(name=name)
388 | self.loader = tools.data_prefetcher(self.dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
389 | # self.loader = DataLoader(dataset=self.dataset, batch_size=batc_size, num_workers=4, shuffle=False, drop_last=False)
390 | # self.loader = tools.data_prefetcher(self.loader)
391 | self.if_gpu = if_gpu
392 | self.timer = tools.time_clock()
393 |
394 | def __call__(self, test_model: tools.abs_test_model):
395 | def calculate(predflow, gt_flow, gt_mask):
396 | error_ = self.flow_error_avg(gt_flow, predflow, gt_mask)
397 | # error_avg_ = summarized_placeholder('AEE/' + name, key='eval_avg')
398 | outliers_ = self.outlier_pct(gt_flow, predflow, gt_mask)
399 | # outliers_avg = summarized_placeholder('outliers/' + name, key='eval_avg')
400 | # values_.extend([error_, outliers_])
401 | # averages_.extend([error_avg_, outliers_avg])
402 | return error_, outliers_
403 |
404 | if self.name in ['2012_test', '2015_test']:
405 | self.timer.start()
406 | index = -1
407 | # with torch.no_grad():
408 | for i in range(len(self.dataset)):
409 | im1, im2, img_name = self.dataset.__getitem__(i)
410 | im1 = torch.unsqueeze(im1, 0)
411 | im2 = torch.unsqueeze(im2, 0)
412 | if self.if_gpu:
413 | im1, im2 = tools.tensor_gpu(im1, im2, check_on=True)
414 | index += 1
415 | # im1, im2 = batch
416 | predflow = test_model.eval_forward(im1, im2, 0)
417 | test_model.eval_save_result(img_name, predflow)
418 | self.timer.end()
419 | print('=' * 3 + ' test time %ss ' % self.timer.get_during() + '=' * 3)
420 | return
421 | all_pep_error_meter = tools.AverageMeter()
422 | f1_rate_meter = tools.AverageMeter()
423 | occ_pep_error_meter = tools.AverageMeter()
424 | noc_pep_error_meter = tools.AverageMeter()
425 | self.timer.start()
426 | index = -1
427 | batch = self.loader.next()
428 | # with torch.no_grad():
429 | while batch is not None:
430 | index += 1
431 | im1, im2, occ, occmask, noc, nocmask = batch
432 | num = im1.shape[0]
433 | predflow = test_model.eval_forward(im1, im2, occ, occmask, noc, nocmask)
434 |
435 | pep_error_all, f1_rate = calculate(predflow=predflow, gt_flow=occ, gt_mask=occmask)
436 | all_pep_error_meter.update(val=pep_error_all.item(), num=num)
437 | f1_rate_meter.update(val=f1_rate.item(), num=num)
438 |
439 | noc_pep_error_, _ = calculate(predflow=predflow, gt_flow=noc, gt_mask=nocmask)
440 | noc_pep_error_meter.update(val=noc_pep_error_.item(), num=num)
441 |
442 | occ_area_mask = occmask - nocmask  # pixels valid in the occ gt but not in the noc gt, i.e. the occluded region
443 | pep_error_occ, _ = calculate(predflow=predflow, gt_flow=occ, gt_mask=occ_area_mask)
444 | occ_pep_error_meter.update(val=pep_error_occ.item(), num=num)
445 | save_name = 'all_%.2f f1_%.1f noc_%.2f occ_%.2f__%d' % (all_pep_error_meter.val, f1_rate_meter.val, noc_pep_error_meter.val, occ_pep_error_meter.val, index)
446 | test_model.eval_save_result(save_name, predflow, occmask=occmask)
447 | batch = self.loader.next()
448 | self.timer.end()
449 | print('=' * 3 + ' eval time %ss ' % self.timer.get_during() + '=' * 3)
450 | return all_pep_error_meter.avg, f1_rate_meter.avg, noc_pep_error_meter.avg, occ_pep_error_meter.avg
451 |
452 | @classmethod
453 | def flow_error_avg_tf(cls, flow_1, flow_2, mask):
454 | """Evaluates the average endpoint error between flow batches.tf batch is n h w c"""
455 |
456 | def euclidean(t):
457 | return tf.sqrt(tf.reduce_sum(t ** 2, [3], keepdims=True))
458 |
459 | diff = euclidean(flow_1 - flow_2) * mask
460 | error = tf.reduce_sum(diff) / tf.reduce_sum(mask)
461 | return error
462 |
463 | @classmethod
464 | def flow_error_avg(cls, flow_1, flow_2, mask):
465 | """Evaluates the average endpoint error between flow batches. torch batch is n c h w"""
466 |
467 | def euclidean(t):
468 | return torch.sqrt(torch.sum(t ** 2, dim=(1,), keepdim=True))
469 |
470 | diff = euclidean(flow_1 - flow_2) * mask
471 | mask_s = torch.sum(mask)
472 | diff_s = torch.sum(diff)
473 | # print('diff_s', diff_s, 'mask_s', mask_s)
474 | error = diff_s / (mask_s + 1e-6)
475 | return error
476 |
477 | @classmethod
478 | def outlier_pct(cls, gt_flow, predflow, mask, threshold=3.0, relative=0.05):
479 | def euclidean(t):
480 | return torch.sqrt(torch.sum(t ** 2, dim=(1,), keepdim=True))
481 |
482 | def outlier_ratio(gtflow, predflow, mask, threshold=3.0, relative=0.05):
483 | diff = euclidean(gtflow - predflow) * mask
484 | threshold = torch.tensor(threshold).type_as(gtflow)
485 | if relative is not None:
486 | threshold_map = torch.max(threshold, euclidean(gt_flow) * relative)
487 | # outliers = tf.cast(tf.greater_equal(diff, threshold), tf.float32)
488 | outliers = diff > threshold_map
489 | else:
490 | # outliers = tf.cast(tf.greater_equal(diff, threshold), tf.float32)
491 | outliers = diff > threshold
492 | mask_s = torch.sum(mask)
493 | outliers_s = torch.sum(outliers)
494 | # print('outliers_s', outliers_s, 'mask_s', mask_s)
495 | ratio = outliers_s / mask_s
496 | return ratio
497 |
498 | frac = outlier_ratio(gt_flow, predflow, mask, threshold, relative) * 100
499 | return frac
500 |
501 | @classmethod
502 | def demo(cls):
503 | class test_model(tools.abs_test_model):
504 |
505 | def eval_forward(self, im1, im2, gt, *args):
506 | return gt
507 |
508 | def eval_save_result(self, save_name, predflow, *args, **kwargs):
509 | print(save_name)
510 |
511 | eval_ben = kitti_flow.Evaluation_bench(name='2012_train')
512 | model = test_model()
513 | all_pep, f1_rate, noc_pep, occ_pep = eval_ben(model)
514 | print('all_pep', all_pep, 'f1_rate', f1_rate, 'noc_pep', noc_pep, 'occ_pep', occ_pep)
515 |
516 | @classmethod
517 | def get_file_names(cls, if_test=False):
518 | def get_img_flow_path_pair(im_dir, flow_occ_dir, flow_noc_dir):
519 | a = []
520 | image_files = os.listdir(im_dir)
521 | image_files.sort()
522 | flow_occ_files = os.listdir(flow_occ_dir)
523 | flow_occ_files.sort()
524 | flow_noc_files = os.listdir(flow_noc_dir)
525 | flow_noc_files.sort()
526 | assert len(image_files) % 2 == 0, 'expected pairs of images'
527 | assert len(flow_occ_files) == len(flow_noc_files), 'here goes wrong'
528 | assert len(flow_occ_files) == len(image_files) / 2, 'here goes wrong'
529 | for i in range(len(image_files) // 2):
530 | filenames_1 = os.path.join(im_dir, image_files[i * 2])
531 | filenames_2 = os.path.join(im_dir, image_files[i * 2 + 1])
532 | filenames_gt_occ = os.path.join(flow_occ_dir, flow_occ_files[i])
533 | filenames_gt_noc = os.path.join(flow_noc_dir, flow_noc_files[i])
534 | print('occ', flow_occ_files[i], 'noc', flow_noc_files[i], 'im1', image_files[i * 2], 'im2', image_files[i * 2 + 1])
535 | a.append({'flow_occ': filenames_gt_occ, 'flow_noc': filenames_gt_noc, 'im1': filenames_1, 'im2': filenames_2})
536 | return a
537 |
538 | def get_img_path_dir(im_dir):
539 | a = []
540 | image_files = os.listdir(im_dir)
541 | image_files.sort()
542 | assert len(image_files) % 2 == 0, 'expected pairs of images'
543 | for i in range(len(image_files) // 2):
544 | filenames_1 = os.path.join(im_dir, image_files[i * 2])
545 | filenames_2 = os.path.join(im_dir, image_files[i * 2 + 1])
546 | a.append({'im1': filenames_1, 'im2': filenames_2})
547 | return a
548 |
549 | file_names_save_path = os.path.join(kitti_flow_dir, 'kitti_flow_2012_2015_file_names.pkl')
550 | if os.path.isfile(file_names_save_path) and not if_test:
551 | data = tools.pickle_saver.load_picke(file_names_save_path)
552 | return data
553 | else:
554 | data = {}
555 | # get 2012 train dataset paths
556 | image_dir = os.path.join(kitti_flow_dir, 'data_stereo_flow', 'training', 'colored_0')
557 | flow_dir_occ = os.path.join(kitti_flow_dir, 'data_stereo_flow', 'training', 'flow_occ')
558 | flow_dir_noc = os.path.join(kitti_flow_dir, 'data_stereo_flow', 'training', 'flow_noc')
559 | data['2012_train'] = get_img_flow_path_pair(im_dir=image_dir, flow_occ_dir=flow_dir_occ, flow_noc_dir=flow_dir_noc)
560 | # get 2015 train dataset paths
561 | image_dir = os.path.join(kitti_flow_dir, 'data_scene_flow', 'training', 'image_2')
562 | flow_dir_occ = os.path.join(kitti_flow_dir, 'data_scene_flow', 'training', 'flow_occ')
563 | flow_dir_noc = os.path.join(kitti_flow_dir, 'data_scene_flow', 'training', 'flow_noc')
564 | data['2015_train'] = get_img_flow_path_pair(im_dir=image_dir, flow_occ_dir=flow_dir_occ, flow_noc_dir=flow_dir_noc)
565 |
566 | # get 2012 test dataset paths
567 | image_dir = os.path.join(kitti_flow_dir, 'data_stereo_flow', 'testing', 'colored_0')
568 | data['2012_test'] = get_img_path_dir(im_dir=image_dir)
569 | # get 2015 test dataset paths
570 | image_dir = os.path.join(kitti_flow_dir, 'data_scene_flow', 'testing', 'image_2')
571 | data['2015_test'] = get_img_path_dir(im_dir=image_dir)
572 | tools.pickle_saver.save_pickle(files=data, file_path=file_names_save_path)
573 | return data
574 |
575 | class kitti_train():
576 |
577 | def __init__(self, name):
578 | assert name in ['2012_train', '2015_train', '2012_test', '2015_test'] # ['2012_train', '2015_train']
579 | data = kitti_flow.get_file_names()
580 | self.file_names = data[name]
581 | self.normalize = True
582 | self.name = name
583 |
584 | def __len__(self):
585 | return len(self.file_names)
586 |
587 | def __getitem__(self, index):
588 | def pro(*args):
589 | def func(a):
590 | # crop
591 | b = a[:, y: y + th, x: x + tw]
592 | # expand to 1,c,h,w
593 | # b = np.expand_dims(b, 0)
594 | return b
595 |
596 | return [func(a) for a in args]
597 |
598 | def read_img(img_path):
599 | data_ = tf.io.read_file(img_path)
600 | img = tf.image.decode_image(data_).numpy()
601 | return img
602 |
603 | data = self.file_names[index]
604 |
605 | im1 = read_img(data['im1'])
606 | im2 = read_img(data['im2'])
607 | im1 = img_func.get_process_img_only_img(im1, normalize=self.normalize)
608 | im2 = img_func.get_process_img_only_img(im2, normalize=self.normalize)
609 | if self.name in ['2012_test', '2015_test']:
610 | # crop
611 | img_name = os.path.basename(data['im1']).replace('.png', '')
612 | h, w = im1.shape[1:]
613 | th = int(int(h / 32) * 32)
614 | tw = int(int(w / 32) * 32)
615 | x = int((w - tw) / 2)
616 | y = int((h - th) / 2)
617 | # im1, im2 = pro(im1, im2)
618 | im1, im2 = img_func.np_2_tensor(im1, im2)
619 | return im1, im2, img_name
620 | else:
621 | # use tf
622 | occ, occmask = img_func.read_png_flow(data['flow_occ'])
623 | noc, nocmask = img_func.read_png_flow(data['flow_noc'])
624 | # crop
625 | h, w = im1.shape[1:]
626 | th = int(int(h / 32) * 32)
627 | tw = int(int(w / 32) * 32)
628 | x = int((w - tw) / 2)
629 | y = int((h - th) / 2)
630 | im1, im2, occ, occmask, noc, nocmask = img_func.np_2_tensor(im1, im2, occ, occmask, noc, nocmask)
631 | return im1, im2, occ, occmask, noc, nocmask
632 |
633 | @classmethod
634 | def demo(cls):
635 | def process(a):
636 | b = a.numpy()
637 | b = np.squeeze(b)
638 | if len(b.shape) > 2:
639 | b = np.transpose(b, (1, 2, 0))
640 | b = tools.im_norm(b)
641 | b = b.astype(np.uint8)
642 | return b
643 |
644 | data = kitti_flow.kitti_train('2015_train')  # the train split returns flow gt; the test split returns (im1, im2, img_name)
645 | # loader = tools.data_prefetcher(data, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
646 | for i in range(len(data)):
647 | im1, im2, occ, occmask, noc, nocmask = data.__getitem__(i)
648 | tools.check_tensor(im1, 'im1')
649 | tools.check_tensor(im2, 'im2')
650 |
651 |
652 | if __name__ == '__main__':
653 | kitti_flow.Evaluation_bench.demo()
654 |
--------------------------------------------------------------------------------
/how_to_install.md:
--------------------------------------------------------------------------------
1 | ## How to Install
2 |
3 | 1. Install the Python environment
4 | 2. Install the CUDA correlation layer
5 |
6 | ## Python Environment
7 |
8 | Python 3.5 is needed (in my case, training memory cost was higher with Python 3.6 or later):
9 | ```
10 | conda create -n upflow python=3.5
11 | source deactivate
12 | source activate upflow
13 | ```
14 |
15 |
16 | Use `pip install -r requirements.txt` to install the Python dependencies.
17 |
18 | - Q1: ImportError: cannot import name 'DataLoaderIter'
19 |
20 | - A1: `DataLoaderIter` does not exist in PyTorch 1.2.0; use `_MultiProcessingDataLoaderIter` or `_SingleProcessDataLoaderIter` instead: `from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter as _DataLoaderIter`
21 |
22 |
23 | ## Cuda Correlation Layer
24 | You should first check where your CUDA is installed.
25 |
26 | - my case: Python 3.5 with CUDA 9.0, where CUDA is installed in /usr/local/cuda-9.0
27 | - another case: Python 3.5 with CUDA 10.0 installed in /data/cuda/cuda-10.0/cuda
28 |
29 | Then check the CUDA path in `correlation_package/setup.py`.
30 |
31 | Install the correlation layer (you may want to check your gcc version before compiling it; use `which gcc`):
32 | ```
33 | cd ./model/correlation_package
34 | python3 setup.py install --user
35 | ```
36 |
37 | - Q1: you get the error: OSError: CUDA_HOME environment variable is not set. Please set it to your CUDA install root.
38 | - A1: use:
39 | ```
40 | export PATH=/data/cuda/cuda-10.0/cuda/bin:$PATH
41 | export LD_LIBRARY_PATH=/data/cuda/cuda-10.0/cuda/lib64:/data/cuda/cuda-10.0/cudnn/v7.5.0/lib64:$LD_LIBRARY_PATH
42 | ```
43 |
44 | - Q2: permission denied
45 | - A2: try: `sudo python3 setup.py install` or `python3 setup.py install --user`
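After a successful build, a quick smoke test along these lines should run from the repository root (a hedged example: it assumes a CUDA device, and the correlation parameters are arbitrary):

```python
import torch
from model.correlation_package.correlation import Correlation

corr = Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1)
a = torch.randn(1, 16, 32, 32).cuda()
b = torch.randn(1, 16, 32, 32).cuda()
out = corr(a, b)
print(out.shape)  # expect torch.Size([1, 81, 32, 32]): (2 * 4 + 1) ** 2 = 81 displacement channels
```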
46 |
47 |
48 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: kim luo
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/model/__pycache__/pwc_modules.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/__pycache__/pwc_modules.cpython-35.pyc
--------------------------------------------------------------------------------
/model/__pycache__/upflow.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/__pycache__/upflow.cpython-35.pyc
--------------------------------------------------------------------------------
/model/correlation_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/__init__.py
--------------------------------------------------------------------------------
/model/correlation_package/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/model/correlation_package/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/correlation_package/__pycache__/correlation.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/__pycache__/correlation.cpython-35.pyc
--------------------------------------------------------------------------------
/model/correlation_package/__pycache__/correlation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/__pycache__/correlation.cpython-36.pyc
--------------------------------------------------------------------------------
/model/correlation_package/build/lib.linux-x86_64-3.5/correlation_cuda.cpython-35m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/build/lib.linux-x86_64-3.5/correlation_cuda.cpython-35m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/model/correlation_package/build/temp.linux-x86_64-3.5/correlation_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/build/temp.linux-x86_64-3.5/correlation_cuda.o
--------------------------------------------------------------------------------
/model/correlation_package/build/temp.linux-x86_64-3.5/correlation_cuda_kernel.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/build/temp.linux-x86_64-3.5/correlation_cuda_kernel.o
--------------------------------------------------------------------------------
/model/correlation_package/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.module import Module
3 | from torch.autograd import Function
4 | import correlation_cuda
5 |
6 | class CorrelationFunction(Function):
7 |
8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
9 | super(CorrelationFunction, self).__init__()
10 | self.pad_size = pad_size
11 | self.kernel_size = kernel_size
12 | self.max_displacement = max_displacement
13 | self.stride1 = stride1
14 | self.stride2 = stride2
15 | self.corr_multiply = corr_multiply
16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
17 |
18 | def forward(self, input1, input2):
19 | self.save_for_backward(input1, input2)
20 |
21 | with torch.cuda.device_of(input1):
22 | rbot1 = input1.new()
23 | rbot2 = input2.new()
24 | output = input1.new()
25 |
26 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
27 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
28 |
29 | return output
30 |
31 | def backward(self, grad_output):
32 | input1, input2 = self.saved_tensors
33 |
34 | with torch.cuda.device_of(input1):
35 | rbot1 = input1.new()
36 | rbot2 = input2.new()
37 |
38 | grad_input1 = input1.new()
39 | grad_input2 = input2.new()
40 |
41 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
42 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
43 |
44 | return grad_input1, grad_input2
45 |
46 |
47 | class Correlation(Module):
48 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
49 | super(Correlation, self).__init__()
50 | self.pad_size = pad_size
51 | self.kernel_size = kernel_size
52 | self.max_displacement = max_displacement
53 | self.stride1 = stride1
54 | self.stride2 = stride2
55 | self.corr_multiply = corr_multiply
56 |
57 | def forward(self, input1, input2):
58 |
59 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)(input1, input2)
60 |
61 | return result
62 |
63 |
--------------------------------------------------------------------------------
/model/correlation_package/correlation_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <ATen/ATen.h>
3 | #include <ATen/Context.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 | #include <stdio.h>
6 | #include <iostream>
7 |
8 | #include "correlation_cuda_kernel.cuh"
9 |
10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
11 | int pad_size,
12 | int kernel_size,
13 | int max_displacement,
14 | int stride1,
15 | int stride2,
16 | int corr_type_multiply)
17 | {
18 |
19 | int batchSize = input1.size(0);
20 |
21 | int nInputChannels = input1.size(1);
22 | int inputHeight = input1.size(2);
23 | int inputWidth = input1.size(3);
24 |
25 | int kernel_radius = (kernel_size - 1) / 2;
26 | int border_radius = kernel_radius + max_displacement;
27 |
28 | int paddedInputHeight = inputHeight + 2 * pad_size;
29 | int paddedInputWidth = inputWidth + 2 * pad_size;
30 |
31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
32 |
33 | int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
34 | int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
35 |
36 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
39 |
40 | rInput1.fill_(0);
41 | rInput2.fill_(0);
42 | output.fill_(0);
43 |
44 | int success = correlation_forward_cuda_kernel(
45 | output,
46 | output.size(0),
47 | output.size(1),
48 | output.size(2),
49 | output.size(3),
50 | output.stride(0),
51 | output.stride(1),
52 | output.stride(2),
53 | output.stride(3),
54 | input1,
55 | input1.size(1),
56 | input1.size(2),
57 | input1.size(3),
58 | input1.stride(0),
59 | input1.stride(1),
60 | input1.stride(2),
61 | input1.stride(3),
62 | input2,
63 | input2.size(1),
64 | input2.stride(0),
65 | input2.stride(1),
66 | input2.stride(2),
67 | input2.stride(3),
68 | rInput1,
69 | rInput2,
70 | pad_size,
71 | kernel_size,
72 | max_displacement,
73 | stride1,
74 | stride2,
75 | corr_type_multiply,
76 | at::cuda::getCurrentCUDAStream()
77 | //at::globalContext().getCurrentCUDAStream()
78 | );
79 |
80 | //check for errors
81 | if (!success) {
82 | AT_ERROR("CUDA call failed");
83 | }
84 |
85 | return 1;
86 |
87 | }
88 |
89 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
90 | at::Tensor& gradInput1, at::Tensor& gradInput2,
91 | int pad_size,
92 | int kernel_size,
93 | int max_displacement,
94 | int stride1,
95 | int stride2,
96 | int corr_type_multiply)
97 | {
98 |
99 | int batchSize = input1.size(0);
100 | int nInputChannels = input1.size(1);
101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size;
102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size;
103 |
104 | int height = input1.size(2);
105 | int width = input1.size(3);
106 |
107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
109 | gradInput1.resize_({batchSize, nInputChannels, height, width});
110 | gradInput2.resize_({batchSize, nInputChannels, height, width});
111 |
112 | rInput1.fill_(0);
113 | rInput2.fill_(0);
114 | gradInput1.fill_(0);
115 | gradInput2.fill_(0);
116 |
117 | int success = correlation_backward_cuda_kernel(gradOutput,
118 | gradOutput.size(0),
119 | gradOutput.size(1),
120 | gradOutput.size(2),
121 | gradOutput.size(3),
122 | gradOutput.stride(0),
123 | gradOutput.stride(1),
124 | gradOutput.stride(2),
125 | gradOutput.stride(3),
126 | input1,
127 | input1.size(1),
128 | input1.size(2),
129 | input1.size(3),
130 | input1.stride(0),
131 | input1.stride(1),
132 | input1.stride(2),
133 | input1.stride(3),
134 | input2,
135 | input2.stride(0),
136 | input2.stride(1),
137 | input2.stride(2),
138 | input2.stride(3),
139 | gradInput1,
140 | gradInput1.stride(0),
141 | gradInput1.stride(1),
142 | gradInput1.stride(2),
143 | gradInput1.stride(3),
144 | gradInput2,
145 | gradInput2.size(1),
146 | gradInput2.stride(0),
147 | gradInput2.stride(1),
148 | gradInput2.stride(2),
149 | gradInput2.stride(3),
150 | rInput1,
151 | rInput2,
152 | pad_size,
153 | kernel_size,
154 | max_displacement,
155 | stride1,
156 | stride2,
157 | corr_type_multiply,
158 | at::cuda::getCurrentCUDAStream()
159 | //at::globalContext().getCurrentCUDAStream()
160 | );
161 |
162 | if (!success) {
163 | AT_ERROR("CUDA call failed");
164 | }
165 |
166 | return 1;
167 | }
168 |
169 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
170 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
171 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
172 | }
173 |
174 |
--------------------------------------------------------------------------------
/model/correlation_package/correlation_cuda.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: correlation-cuda
3 | Version: 0.0.0
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | Author: UNKNOWN
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/model/correlation_package/correlation_cuda.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | correlation_cuda.cc
2 | correlation_cuda_kernel.cu
3 | setup.py
4 | correlation_cuda.egg-info/PKG-INFO
5 | correlation_cuda.egg-info/SOURCES.txt
6 | correlation_cuda.egg-info/dependency_links.txt
7 | correlation_cuda.egg-info/top_level.txt
--------------------------------------------------------------------------------
/model/correlation_package/correlation_cuda.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/correlation_package/correlation_cuda.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | correlation_cuda
2 |
--------------------------------------------------------------------------------
/model/correlation_package/correlation_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | #include "correlation_cuda_kernel.cuh"
4 |
5 | #define CUDA_NUM_THREADS 1024
6 | #define THREADS_PER_BLOCK 32
7 |
8 | #include <ATen/ATen.h>
9 | #include <ATen/NativeFunctions.h>
10 | #include <ATen/Dispatch.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 |
13 | using at::Half;
14 |
15 | template <typename scalar_t>
16 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
17 | {
18 |
19 | // n (batch size), c (num of channels), y (height), x (width)
20 | int n = blockIdx.x;
21 | int y = blockIdx.y;
22 | int x = blockIdx.z;
23 |
24 | int ch_off = threadIdx.x;
25 | scalar_t value;
26 |
27 | int dimcyx = channels * height * width;
28 | int dimyx = height * width;
29 |
30 | int p_dimx = (width + 2 * pad_size);
31 | int p_dimy = (height + 2 * pad_size);
32 | int p_dimyxc = channels * p_dimy * p_dimx;
33 | int p_dimxc = p_dimx * channels;
34 |
35 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
36 | value = input[n * dimcyx + c * dimyx + y * width + x];
37 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
38 | }
39 | }
40 |
41 | template <typename scalar_t>
42 | __global__ void correlation_forward(scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth,
43 | const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth,
44 | const scalar_t* __restrict__ rInput2,
45 | int pad_size,
46 | int kernel_size,
47 | int max_displacement,
48 | int stride1,
49 | int stride2)
50 | {
51 | // n (batch size), c (num of channels), y (height), x (width)
52 |
53 | int pInputWidth = inputWidth + 2 * pad_size;
54 | int pInputHeight = inputHeight + 2 * pad_size;
55 |
56 | int kernel_rad = (kernel_size - 1) / 2;
57 | int displacement_rad = max_displacement / stride2;
58 | int displacement_size = 2 * displacement_rad + 1;
59 |
60 | int n = blockIdx.x;
61 | int y1 = blockIdx.y * stride1 + max_displacement;
62 | int x1 = blockIdx.z * stride1 + max_displacement;
63 | int c = threadIdx.x;
64 |
65 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
66 | int pdimxc = pInputWidth * nInputChannels;
67 | int pdimc = nInputChannels;
68 |
69 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
70 | int tdimyx = outputHeight * outputWidth;
71 | int tdimx = outputWidth;
72 |
73 | scalar_t nelems = kernel_size * kernel_size * pdimc;
74 |
75 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
76 |
77 | // no significant speed-up in using chip memory for input1 sub-data,
78 | // not enough chip memory size to accommodate memory per block for input2 sub-data
79 | // instead i've used device memory for both
80 |
81 | // element-wise product along channel axis
82 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
83 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
84 | prod_sum[c] = 0;
85 | int x2 = x1 + ti*stride2;
86 | int y2 = y1 + tj*stride2;
87 |
88 | for (int j = -kernel_rad; j <= kernel_rad; ++j) {
89 | for (int i = -kernel_rad; i <= kernel_rad; ++i) {
90 | for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {
91 | int indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch;
92 | int indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch;
93 |
94 | prod_sum[c] += rInput1[indx1] * rInput2[indx2];
95 | }
96 | }
97 | }
98 |
99 | // accumulate
100 | __syncthreads();
101 | if (c == 0) {
102 | scalar_t reduce_sum = 0;
103 | for (int index = 0; index < THREADS_PER_BLOCK; ++index) {
104 | reduce_sum += prod_sum[index];
105 | }
106 | int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
107 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;
108 | output[tindx] = reduce_sum / nelems;
109 | }
110 |
111 | }
112 | }
113 |
114 | }
115 |
116 | template <typename scalar_t>
117 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,
118 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
119 | const scalar_t* __restrict__ rInput2,
120 | int pad_size,
121 | int kernel_size,
122 | int max_displacement,
123 | int stride1,
124 | int stride2)
125 | {
126 | // n (batch size), c (num of channels), y (height), x (width)
127 |
128 | int n = item;
129 | int y = blockIdx.x * stride1 + pad_size;
130 | int x = blockIdx.y * stride1 + pad_size;
131 | int c = blockIdx.z;
132 | int tch_off = threadIdx.x;
133 |
134 | int kernel_rad = (kernel_size - 1) / 2;
135 | int displacement_rad = max_displacement / stride2;
136 | int displacement_size = 2 * displacement_rad + 1;
137 |
138 | int xmin = (x - kernel_rad - max_displacement) / stride1;
139 | int ymin = (y - kernel_rad - max_displacement) / stride1;
140 |
141 | int xmax = (x + kernel_rad - max_displacement) / stride1;
142 | int ymax = (y + kernel_rad - max_displacement) / stride1;
143 |
144 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
145 | // assumes gradInput1 is pre-allocated and zero filled
146 | return;
147 | }
148 |
149 | if (xmin > xmax || ymin > ymax) {
150 | // assumes gradInput1 is pre-allocated and zero filled
151 | return;
152 | }
153 |
154 | xmin = max(0, xmin);
155 | xmax = min(outputWidth - 1, xmax);
156 |
157 | ymin = max(0, ymin);
158 | ymax = min(outputHeight - 1, ymax);
159 |
160 | int pInputWidth = inputWidth + 2 * pad_size;
161 | int pInputHeight = inputHeight + 2 * pad_size;
162 |
163 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
164 | int pdimxc = pInputWidth * nInputChannels;
165 | int pdimc = nInputChannels;
166 |
167 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
168 | int tdimyx = outputHeight * outputWidth;
169 | int tdimx = outputWidth;
170 |
171 | int odimcyx = nInputChannels * inputHeight* inputWidth;
172 | int odimyx = inputHeight * inputWidth;
173 | int odimx = inputWidth;
174 |
175 | scalar_t nelems = kernel_size * kernel_size * nInputChannels;
176 |
177 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
178 | prod_sum[tch_off] = 0;
179 |
180 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
181 |
182 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
183 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
184 |
185 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;
186 |
187 | scalar_t val2 = rInput2[indx2];
188 |
189 | for (int j = ymin; j <= ymax; ++j) {
190 | for (int i = xmin; i <= xmax; ++i) {
191 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
192 | prod_sum[tch_off] += gradOutput[tindx] * val2;
193 | }
194 | }
195 | }
196 | __syncthreads();
197 |
198 | if (tch_off == 0) {
199 | scalar_t reduce_sum = 0;
200 | for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
201 | reduce_sum += prod_sum[idx];
202 | }
203 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
204 | gradInput1[indx1] = reduce_sum / nelems;
205 | }
206 |
207 | }
208 |
209 | template <typename scalar_t>
210 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth,
211 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
212 | const scalar_t* __restrict__ rInput1,
213 | int pad_size,
214 | int kernel_size,
215 | int max_displacement,
216 | int stride1,
217 | int stride2)
218 | {
219 | // n (batch size), c (num of channels), y (height), x (width)
220 |
221 | int n = item;
222 | int y = blockIdx.x * stride1 + pad_size;
223 | int x = blockIdx.y * stride1 + pad_size;
224 | int c = blockIdx.z;
225 |
226 | int tch_off = threadIdx.x;
227 |
228 | int kernel_rad = (kernel_size - 1) / 2;
229 | int displacement_rad = max_displacement / stride2;
230 | int displacement_size = 2 * displacement_rad + 1;
231 |
232 | int pInputWidth = inputWidth + 2 * pad_size;
233 | int pInputHeight = inputHeight + 2 * pad_size;
234 |
235 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
236 | int pdimxc = pInputWidth * nInputChannels;
237 | int pdimc = nInputChannels;
238 |
239 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
240 | int tdimyx = outputHeight * outputWidth;
241 | int tdimx = outputWidth;
242 |
243 | int odimcyx = nInputChannels * inputHeight* inputWidth;
244 | int odimyx = inputHeight * inputWidth;
245 | int odimx = inputWidth;
246 |
247 | scalar_t nelems = kernel_size * kernel_size * nInputChannels;
248 |
249 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
250 | prod_sum[tch_off] = 0;
251 |
252 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
253 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
254 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
255 |
256 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
257 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1;
258 |
259 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
260 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1;
261 |
262 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
263 | // assumes gradInput2 is pre-allocated and zero filled
264 | continue;
265 | }
266 |
267 | if (xmin > xmax || ymin > ymax) {
268 | // assumes gradInput2 is pre-allocated and zero filled
269 | continue;
270 | }
271 |
272 | xmin = max(0, xmin);
273 | xmax = min(outputWidth - 1, xmax);
274 |
275 | ymin = max(0, ymin);
276 | ymax = min(outputHeight - 1, ymax);
277 |
278 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
279 | scalar_t val1 = rInput1[indx1];
280 |
281 | for (int j = ymin; j <= ymax; ++j) {
282 | for (int i = xmin; i <= xmax; ++i) {
283 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
284 | prod_sum[tch_off] += gradOutput[tindx] * val1;
285 | }
286 | }
287 | }
288 |
289 | __syncthreads();
290 |
291 | if (tch_off == 0) {
292 | scalar_t reduce_sum = 0;
293 | for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
294 | reduce_sum += prod_sum[idx];
295 | }
296 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
297 | gradInput2[indx2] = reduce_sum / nelems;
298 | }
299 |
300 | }
301 |
302 | int correlation_forward_cuda_kernel(at::Tensor& output,
303 | int ob,
304 | int oc,
305 | int oh,
306 | int ow,
307 | int osb,
308 | int osc,
309 | int osh,
310 | int osw,
311 |
312 | at::Tensor& input1,
313 | int ic,
314 | int ih,
315 | int iw,
316 | int isb,
317 | int isc,
318 | int ish,
319 | int isw,
320 |
321 | at::Tensor& input2,
322 | int gc,
323 | int gsb,
324 | int gsc,
325 | int gsh,
326 | int gsw,
327 |
328 | at::Tensor& rInput1,
329 | at::Tensor& rInput2,
330 | int pad_size,
331 | int kernel_size,
332 | int max_displacement,
333 | int stride1,
334 | int stride2,
335 | int corr_type_multiply,
336 | cudaStream_t stream)
337 | {
338 |
339 | int batchSize = ob;
340 |
341 | int nInputChannels = ic;
342 | int inputWidth = iw;
343 | int inputHeight = ih;
344 |
345 | int nOutputChannels = oc;
346 | int outputWidth = ow;
347 | int outputHeight = oh;
348 |
349 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
350 | dim3 threads_block(THREADS_PER_BLOCK);
351 |
352 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
353 |
354 | channels_first<scalar_t> <<<blocks_grid, threads_block, 0, stream>>> (
355 | input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
356 |
357 | }));
358 |
359 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
360 |
361 | channels_first<scalar_t> <<<blocks_grid, threads_block, 0, stream>>> (
362 | input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
363 |
364 | }));
365 |
366 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
367 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
368 |
369 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
370 |
371 | correlation_forward<scalar_t> <<<totalBlocksCorr, threadsPerBlock, 0, stream>>>
372 | (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
373 | rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
374 | rInput2.data<scalar_t>(),
375 | pad_size,
376 | kernel_size,
377 | max_displacement,
378 | stride1,
379 | stride2);
380 |
381 | }));
382 |
383 | cudaError_t err = cudaGetLastError();
384 |
385 |
386 | // check for errors
387 | if (err != cudaSuccess) {
388 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
389 | return 0;
390 | }
391 |
392 | return 1;
393 | }
394 |
395 |
396 | int correlation_backward_cuda_kernel(
397 | at::Tensor& gradOutput,
398 | int gob,
399 | int goc,
400 | int goh,
401 | int gow,
402 | int gosb,
403 | int gosc,
404 | int gosh,
405 | int gosw,
406 |
407 | at::Tensor& input1,
408 | int ic,
409 | int ih,
410 | int iw,
411 | int isb,
412 | int isc,
413 | int ish,
414 | int isw,
415 |
416 | at::Tensor& input2,
417 | int gsb,
418 | int gsc,
419 | int gsh,
420 | int gsw,
421 |
422 | at::Tensor& gradInput1,
423 | int gisb,
424 | int gisc,
425 | int gish,
426 | int gisw,
427 |
428 | at::Tensor& gradInput2,
429 | int ggc,
430 | int ggsb,
431 | int ggsc,
432 | int ggsh,
433 | int ggsw,
434 |
435 | at::Tensor& rInput1,
436 | at::Tensor& rInput2,
437 | int pad_size,
438 | int kernel_size,
439 | int max_displacement,
440 | int stride1,
441 | int stride2,
442 | int corr_type_multiply,
443 | cudaStream_t stream)
444 | {
445 |
446 | int batchSize = gob;
447 | int num = batchSize;
448 |
449 | int nInputChannels = ic;
450 | int inputWidth = iw;
451 | int inputHeight = ih;
452 |
453 | int nOutputChannels = goc;
454 | int outputWidth = gow;
455 | int outputHeight = goh;
456 |
457 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
458 | dim3 threads_block(THREADS_PER_BLOCK);
459 |
460 |
461 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_bwd_1", ([&] {
462 |
463 | channels_first<scalar_t> <<<blocks_grid, threads_block, 0, stream>>> (
464 | input1.data<scalar_t>(),
465 | rInput1.data<scalar_t>(),
466 | nInputChannels,
467 | inputHeight,
468 | inputWidth,
469 | pad_size
470 | );
471 | }));
472 |
473 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_bwd_2", ([&] {
474 |
475 | channels_first<scalar_t> <<<blocks_grid, threads_block, 0, stream>>> (
476 | input2.data<scalar_t>(),
477 | rInput2.data<scalar_t>(),
478 | nInputChannels,
479 | inputHeight,
480 | inputWidth,
481 | pad_size
482 | );
483 | }));
484 |
485 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
486 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
487 |
488 | for (int n = 0; n < num; ++n) {
489 |
490 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "correlation_backward_input1", ([&] {
491 |
492 |
493 | correlation_backward_input1<scalar_t> <<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
494 | n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
495 | gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
496 | rInput2.data<scalar_t>(),
497 | pad_size,
498 | kernel_size,
499 | max_displacement,
500 | stride1,
501 | stride2);
502 | }));
503 | }
504 |
505 | for (int n = 0; n < batchSize; n++) {
506 |
507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "correlation_backward_input2", ([&] {
508 |
509 | correlation_backward_input2<scalar_t> <<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
510 | n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
511 | gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
512 | rInput1.data<scalar_t>(),
513 | pad_size,
514 | kernel_size,
515 | max_displacement,
516 | stride1,
517 | stride2);
518 |
519 | }));
520 | }
521 |
522 | // check for errors
523 | cudaError_t err = cudaGetLastError();
524 | if (err != cudaSuccess) {
525 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
526 | return 0;
527 | }
528 |
529 | return 1;
530 | }
531 |
--------------------------------------------------------------------------------
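For a quick check of a freshly built extension, the forward pass of the kernel above can be mirrored in a few lines of plain PyTorch. The sketch below is a simplification under stated assumptions: it only covers the configuration this repository actually uses (kernel_size=1, stride1=stride2=1, pad_size equal to max_displacement), and `correlation_reference` is a hypothetical name, not part of the repo.

import torch
import torch.nn.functional as F

def correlation_reference(f1, f2, max_displacement=4):
    # mean over channels mirrors the "reduce_sum / nelems" step of the kernel
    # (nelems == nInputChannels when kernel_size == 1)
    b, c, h, w = f1.shape
    d = max_displacement
    f2_pad = F.pad(f2, (d, d, d, d))  # zero padding, as in the CUDA path
    out = []
    for tj in range(2 * d + 1):       # y displacement (outer, like tc / displacement_size)
        for ti in range(2 * d + 1):   # x displacement (inner, like tc % displacement_size)
            f2_shift = f2_pad[:, :, tj:tj + h, ti:ti + w]
            out.append((f1 * f2_shift).mean(dim=1, keepdim=True))
    return torch.cat(out, dim=1)      # [b, (2*d+1)**2, h, w]; 81 channels for d == 4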
/model/correlation_package/correlation_cuda_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 | #include <cuda.h>
5 | #include <cuda_runtime.h>
6 |
7 | int correlation_forward_cuda_kernel(at::Tensor& output,
8 | int ob,
9 | int oc,
10 | int oh,
11 | int ow,
12 | int osb,
13 | int osc,
14 | int osh,
15 | int osw,
16 |
17 | at::Tensor& input1,
18 | int ic,
19 | int ih,
20 | int iw,
21 | int isb,
22 | int isc,
23 | int ish,
24 | int isw,
25 |
26 | at::Tensor& input2,
27 | int gc,
28 | int gsb,
29 | int gsc,
30 | int gsh,
31 | int gsw,
32 |
33 | at::Tensor& rInput1,
34 | at::Tensor& rInput2,
35 | int pad_size,
36 | int kernel_size,
37 | int max_displacement,
38 | int stride1,
39 | int stride2,
40 | int corr_type_multiply,
41 | cudaStream_t stream);
42 |
43 |
44 | int correlation_backward_cuda_kernel(
45 | at::Tensor& gradOutput,
46 | int gob,
47 | int goc,
48 | int goh,
49 | int gow,
50 | int gosb,
51 | int gosc,
52 | int gosh,
53 | int gosw,
54 |
55 | at::Tensor& input1,
56 | int ic,
57 | int ih,
58 | int iw,
59 | int isb,
60 | int isc,
61 | int ish,
62 | int isw,
63 |
64 | at::Tensor& input2,
65 | int gsb,
66 | int gsc,
67 | int gsh,
68 | int gsw,
69 |
70 | at::Tensor& gradInput1,
71 | int gisb,
72 | int gisc,
73 | int gish,
74 | int gisw,
75 |
76 | at::Tensor& gradInput2,
77 | int ggc,
78 | int ggsb,
79 | int ggsc,
80 | int ggsh,
81 | int ggsw,
82 |
83 | at::Tensor& rInput1,
84 | at::Tensor& rInput2,
85 | int pad_size,
86 | int kernel_size,
87 | int max_displacement,
88 | int stride1,
89 | int stride2,
90 | int corr_type_multiply,
91 | cudaStream_t stream);
92 |
--------------------------------------------------------------------------------
/model/correlation_package/dist/correlation_cuda-0.0.0-py3.5-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/model/correlation_package/dist/correlation_cuda-0.0.0-py3.5-linux-x86_64.egg
--------------------------------------------------------------------------------
/model/correlation_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup, find_packages
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_50,code=sm_50',
12 | '-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | '-gencode', 'arch=compute_61,code=compute_61',
16 | '-ccbin', '/usr/bin/gcc-4.9'
17 | ]
18 |
19 | setup(
20 | name='correlation_cuda',
21 | ext_modules=[
22 | CUDAExtension('correlation_cuda', [
23 | 'correlation_cuda.cc',
24 | 'correlation_cuda_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args, 'cuda-path': ['/data/cuda/cuda-10.0/cuda']})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
--------------------------------------------------------------------------------
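Two of the flags above are specific to the original author's machine: '-ccbin /usr/bin/gcc-4.9' pins a particular host compiler, and 'cuda-path' is not a key that BuildExtension reads (only the 'cxx' and 'nvcc' lists of extra_compile_args are consumed), so it is silently ignored. A more portable variant could look like the sketch below; the single -gencode pair is an assumption and should be matched to the target GPU's compute capability. The extension is built from model/correlation_package with `python setup.py install`.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='correlation_cuda',
    ext_modules=[
        CUDAExtension('correlation_cuda',
                      ['correlation_cuda.cc', 'correlation_cuda_kernel.cu'],
                      # assumption: sm_61 (Pascal); adjust to your GPU
                      extra_compile_args={'cxx': ['-std=c++11'],
                                          'nvcc': ['-gencode', 'arch=compute_61,code=sm_61']})
    ],
    cmdclass={'build_ext': BuildExtension})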
/model/pwc_modules.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as tf
5 | import logging
6 | from utils.tools import tools
7 | import numpy as np
8 |
9 |
10 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True, if_IN=False, IN_affine=False, if_BN=False):
11 | if isReLU:
12 | if if_IN:
13 | return nn.Sequential(
14 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
15 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
16 | nn.LeakyReLU(0.1, inplace=True),
17 | nn.InstanceNorm2d(out_planes, affine=IN_affine)
18 | )
19 | elif if_BN:
20 | return nn.Sequential(
21 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
22 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
23 | nn.LeakyReLU(0.1, inplace=True),
24 | nn.BatchNorm2d(out_planes, affine=IN_affine)
25 | )
26 | else:
27 | return nn.Sequential(
28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
29 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
30 | nn.LeakyReLU(0.1, inplace=True)
31 | )
32 | else:
33 | if if_IN:
34 | return nn.Sequential(
35 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
36 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
37 | nn.InstanceNorm2d(out_planes, affine=IN_affine)
38 | )
39 | elif if_BN:
40 | return nn.Sequential(
41 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
42 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
43 | nn.BatchNorm2d(out_planes, affine=IN_affine)
44 | )
45 | else:
46 | return nn.Sequential(
47 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
48 | padding=((kernel_size - 1) * dilation) // 2, bias=True)
49 | )
50 |
51 |
52 | def initialize_msra(modules):
53 | logging.info("Initializing MSRA")
54 | for layer in modules:
55 | if isinstance(layer, nn.Conv2d):
56 | nn.init.kaiming_normal_(layer.weight)
57 | if layer.bias is not None:
58 | nn.init.constant_(layer.bias, 0)
59 |
60 | elif isinstance(layer, nn.ConvTranspose2d):
61 | nn.init.kaiming_normal_(layer.weight)
62 | if layer.bias is not None:
63 | nn.init.constant_(layer.bias, 0)
64 |
65 | elif isinstance(layer, nn.LeakyReLU):
66 | pass
67 |
68 | elif isinstance(layer, nn.Sequential):
69 | pass
70 |
71 |
72 | def upsample2d_as(inputs, target_as, mode="bilinear"):
73 | _, _, h, w = target_as.size()
74 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
75 |
76 |
77 | def upsample2d_flow_as(inputs, target_as, mode="bilinear", if_rate=False):
78 | _, _, h, w = target_as.size()
79 | res = tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
80 | if if_rate:
81 | _, _, h_, w_ = inputs.size()
82 | # inputs[:, 0, :, :] *= (w / w_)
83 | # inputs[:, 1, :, :] *= (h / h_)
84 | u_scale = (w / w_)
85 | v_scale = (h / h_)
86 | u, v = res.chunk(2, dim=1)
87 | u *= u_scale
88 | v *= v_scale
89 | res = torch.cat([u, v], dim=1)
90 | return res
91 |
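Flow vectors are measured in pixels, so resizing a flow field must also rescale the vectors; that is what if_rate=True does in upsample2d_flow_as above. A quick sanity-check sketch (the sizes are hypothetical): a constant 1-pixel flow at half resolution must become a 2-pixel flow at full resolution, because the same motion now spans twice as many pixels.

flow_half = torch.ones(1, 2, 4, 4)
target = torch.zeros(1, 3, 8, 8)
flow_full = upsample2d_flow_as(flow_half, target, mode="bilinear", if_rate=True)
assert torch.allclose(flow_full, torch.full((1, 2, 8, 8), 2.0))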
92 |
93 | def upsample_flow(inputs, target_size=None, target_flow=None, mode="bilinear"):
94 | if target_size is not None:
95 | h, w = target_size
96 | elif target_flow is not None:
97 | _, _, h, w = target_flow.size()
98 | else:
99 | raise ValueError('upsample_flow needs either target_size or target_flow')
100 | _, _, h_, w_ = inputs.size()
101 | res = tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
102 | res[:, 0, :, :] *= (w / w_)
103 | res[:, 1, :, :] *= (h / h_)
104 | return res
105 |
106 |
107 | def rescale_flow(flow, div_flow, width_im, height_im, to_local=True):
108 | if to_local:
109 | u_scale = float(flow.size(3) / width_im / div_flow)
110 | v_scale = float(flow.size(2) / height_im / div_flow)
111 | else:
112 | u_scale = float(width_im * div_flow / flow.size(3))
113 | v_scale = float(height_im * div_flow / flow.size(2))
114 |
115 | u, v = flow.chunk(2, dim=1)
116 | u *= u_scale
117 | v *= v_scale
118 |
119 | return torch.cat([u, v], dim=1)
120 |
121 |
122 | class FeatureExtractor(nn.Module):
123 |
124 | def __init__(self, num_chs, if_end_relu=True, if_end_norm=False):
125 | super(FeatureExtractor, self).__init__()
126 | self.num_chs = num_chs
127 | self.convs = nn.ModuleList()
128 |
129 | for l, (ch_in, ch_out) in enumerate(zip(num_chs[:-1], num_chs[1:])):
130 | layer = nn.Sequential(
131 | conv(ch_in, ch_out, stride=2),
132 | conv(ch_out, ch_out, isReLU=if_end_relu, if_IN=if_end_norm)
133 | )
134 | self.convs.append(layer)
135 |
136 | def forward(self, x):
137 | feature_pyramid = []
138 | for conv in self.convs:
139 | x = conv(x)
140 | feature_pyramid.append(x)
141 |
142 | return feature_pyramid[::-1]
143 |
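A small shape walk-through of FeatureExtractor (the 320x320 input is just an example): each level halves the resolution via the stride-2 conv, and the pyramid is returned coarsest-first.

fe = FeatureExtractor([3, 16, 32, 64, 96, 128, 196])
pyramid = fe(torch.randn(1, 3, 320, 320))
for level, f in enumerate(pyramid):
    print(level, tuple(f.shape))
# 0 (1, 196, 5, 5)
# 1 (1, 128, 10, 10)
# 2 (1, 96, 20, 20)
# 3 (1, 64, 40, 40)
# 4 (1, 32, 80, 80)
# 5 (1, 16, 160, 160)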
144 |
145 | def get_grid(x):
146 | grid_H = torch.linspace(-1.0, 1.0, x.size(3)).view(1, 1, 1, x.size(3)).expand(x.size(0), 1, x.size(2), x.size(3))
147 | grid_V = torch.linspace(-1.0, 1.0, x.size(2)).view(1, 1, x.size(2), 1).expand(x.size(0), 1, x.size(2), x.size(3))
148 | grid = torch.cat([grid_H, grid_V], 1)
149 | if x.is_cuda:
150 | grids_cuda = grid.float().requires_grad_(False).cuda()
151 | else:
152 | grids_cuda = grid.float().requires_grad_(False) # .cuda()
153 | return grids_cuda
154 |
155 |
156 | class WarpingLayer(nn.Module):
157 |
158 | def __init__(self):
159 | super(WarpingLayer, self).__init__()
160 |
161 | def forward(self, x, flow, height_im, width_im, div_flow):
162 | flo_list = []
163 | flo_w = flow[:, 0] * 2 / max(width_im - 1, 1) / div_flow
164 | flo_h = flow[:, 1] * 2 / max(height_im - 1, 1) / div_flow
165 | flo_list.append(flo_w)
166 | flo_list.append(flo_h)
167 | flow_for_grid = torch.stack(flo_list).transpose(0, 1)
168 | grid = torch.add(get_grid(x), flow_for_grid).transpose(1, 2).transpose(2, 3)
169 | x_warp = tf.grid_sample(x, grid)
170 | if x.is_cuda:
171 | mask = torch.ones(x.size(), requires_grad=False).cuda()
172 | else:
173 | mask = torch.ones(x.size(), requires_grad=False) # .cuda()
174 | mask = tf.grid_sample(mask, grid)
175 | mask = (mask >= 1.0).float()
176 | return x_warp * mask
177 |
178 |
179 | class WarpingLayer_no_div(nn.Module):
180 |
181 | def __init__(self):
182 | super(WarpingLayer_no_div, self).__init__()
183 |
184 | def forward(self, x, flow):
185 | B, C, H, W = x.size()
186 | # mesh grid
187 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
188 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
189 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
190 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
191 | grid = torch.cat((xx, yy), 1).float()
192 | if x.is_cuda:
193 | grid = grid.cuda()
194 | # print(grid.shape,flo.shape,'...')
195 | vgrid = grid + flow
196 | # scale grid to [-1,1]
197 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0
198 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0
199 | vgrid = vgrid.permute(0, 2, 3, 1) # B H,W,C
200 | x_warp = tf.grid_sample(x, vgrid, padding_mode='zeros')
201 | if x.is_cuda:
202 | mask = torch.ones(x.size(), requires_grad=False).cuda()
203 | else:
204 | mask = torch.ones(x.size(), requires_grad=False) # .cuda()
205 | mask = tf.grid_sample(mask, vgrid)
206 | mask = (mask >= 1.0).float()
207 | return x_warp * mask
208 |
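A usage sketch for WarpingLayer_no_div: with zero flow the sampling grid is the identity, so the input should come back unchanged, and the mask only zeroes pixels whose source location falls outside the image.

warp = WarpingLayer_no_div()
x = torch.randn(1, 3, 16, 16)
out = warp(x, torch.zeros(1, 2, 16, 16))
assert torch.allclose(out, x, atol=1e-6)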
209 |
210 | class OpticalFlowEstimator(nn.Module):
211 |
212 | def __init__(self, ch_in):
213 | super(OpticalFlowEstimator, self).__init__()
214 |
215 | self.convs = nn.Sequential(
216 | conv(ch_in, 128),
217 | conv(128, 128),
218 | conv(128, 96),
219 | conv(96, 64),
220 | conv(64, 32)
221 | )
222 | self.conv_last = conv(32, 2, isReLU=False)
223 |
224 | def forward(self, x):
225 | x_intm = self.convs(x)
226 | return x_intm, self.conv_last(x_intm)
227 |
228 |
229 | class FlowEstimatorDense(nn.Module):
230 |
231 | def __init__(self, ch_in):
232 | super(FlowEstimatorDense, self).__init__()
233 | self.conv1 = conv(ch_in, 128)
234 | self.conv2 = conv(ch_in + 128, 128)
235 | self.conv3 = conv(ch_in + 256, 96)
236 | self.conv4 = conv(ch_in + 352, 64)
237 | self.conv5 = conv(ch_in + 416, 32)
238 | self.conv_last = conv(ch_in + 448, 2, isReLU=False)
239 |
240 | def forward(self, x):
241 | x1 = torch.cat([self.conv1(x), x], dim=1)
242 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
243 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
244 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
245 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
246 | x_out = self.conv_last(x5)
247 | return x5, x_out
248 |
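The magic numbers in FlowEstimatorDense are dense-connection bookkeeping: each conv takes the concatenation of all previous outputs with the block input, so the input widths grow as ch_in, ch_in + 128, ch_in + 256 (128 + 128), ch_in + 352 (+96), ch_in + 416 (+64), and finally ch_in + 448 (+32) for conv_last.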
249 |
250 | class FlowEstimatorDense_v2(tools.abstract_model):
251 |
252 | def __init__(self, ch_in, f_channels=(128, 128, 96, 64, 32), out_channel=2):
253 | super(FlowEstimatorDense_v2, self).__init__()
254 | N = 0
255 | ind = 0
256 | N += ch_in
257 | self.conv1 = conv(N, f_channels[ind])
258 | N += f_channels[ind]
259 |
260 | ind += 1
261 | self.conv2 = conv(N, f_channels[ind])
262 | N += f_channels[ind]
263 |
264 | ind += 1
265 | self.conv3 = conv(N, f_channels[ind])
266 | N += f_channels[ind]
267 |
268 | ind += 1
269 | self.conv4 = conv(N, f_channels[ind])
270 | N += f_channels[ind]
271 |
272 | ind += 1
273 | self.conv5 = conv(N, f_channels[ind])
274 | N += f_channels[ind]
275 | self.n_channels = N
276 | ind += 1
277 | self.conv_last = conv(N, out_channel, isReLU=False)
278 |
279 | def forward(self, x):
280 | x1 = torch.cat([self.conv1(x), x], dim=1)
281 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
282 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
283 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
284 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
285 | x_out = self.conv_last(x5)
286 | return x5, x_out
287 |
288 |
289 | class FlowEstimatorDense_v3(tools.abstract_model):
290 |
291 | def __init__(self, ch_in, f_channels=(128, 128, 96, 64, 32), if_conv_cat=False):
292 | super(FlowEstimatorDense_v3, self).__init__()
293 | self.conv_ls = nn.ModuleList()
294 | in_channel = ch_in
295 | self.if_conv_cat = if_conv_cat
296 | for i in f_channels:
297 | out_channel = i
298 | self.conv_ls.append(conv(in_channel, out_channel))
299 | in_channel += out_channel
300 |
301 | # N = 0
302 | # ind = 0
303 | # N += ch_in
304 | # self.conv1 = conv(N, f_channels[ind])
305 | # N += f_channels[ind]
306 | #
307 | # ind += 1
308 | # self.conv2 = conv(N, f_channels[ind])
309 | # N += f_channels[ind]
310 | #
311 | # ind += 1
312 | # self.conv3 = conv(N, f_channels[ind])
313 | # N += f_channels[ind]
314 | #
315 | # ind += 1
316 | # self.conv4 = conv(N, f_channels[ind])
317 | # N += f_channels[ind]
318 | #
319 | # ind += 1
320 | # self.conv5 = conv(N, f_channels[ind])
321 | # N += f_channels[ind]
322 | self.n_channels = in_channel
323 | # ind += 1
324 | self.conv_last = conv(in_channel, 2, isReLU=False)
325 |
326 | def forward(self, x):
327 | for conv_layer in self.conv_ls:
328 | x = torch.cat([conv_layer(x), x], dim=1)
329 | # x1 = torch.cat([self.conv1(x), x], dim=1)
330 | # x2 = torch.cat([self.conv2(x1), x1], dim=1)
331 | # x3 = torch.cat([self.conv3(x2), x2], dim=1)
332 | # x4 = torch.cat([self.conv4(x3), x3], dim=1)
333 | # x5 = torch.cat([self.conv5(x4), x4], dim=1)
334 | x_out = self.conv_last(x)
335 | return x, x_out
336 |
337 |
338 | class OcclusionEstimator(nn.Module):
339 |
340 | def __init__(self, ch_in):
341 | super(OcclusionEstimator, self).__init__()
342 | self.convs = nn.Sequential(
343 | conv(ch_in, 128),
344 | conv(128, 128),
345 | conv(128, 96),
346 | conv(96, 64),
347 | conv(64, 32)
348 | )
349 | self.conv_last = conv(32, 1, isReLU=False)
350 |
351 | def forward(self, x):
352 | x_intm = self.convs(x)
353 | return x_intm, self.conv_last(x_intm)
354 |
355 |
356 | class OccEstimatorDense(nn.Module):
357 |
358 | def __init__(self, ch_in):
359 | super(OccEstimatorDense, self).__init__()
360 | self.conv1 = conv(ch_in, 128)
361 | self.conv2 = conv(ch_in + 128, 128)
362 | self.conv3 = conv(ch_in + 256, 96)
363 | self.conv4 = conv(ch_in + 352, 64)
364 | self.conv5 = conv(ch_in + 416, 32)
365 | self.conv_last = conv(ch_in + 448, 1, isReLU=False)
366 |
367 | def forward(self, x):
368 | x1 = torch.cat([self.conv1(x), x], dim=1)
369 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
370 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
371 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
372 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
373 | x_out = self.conv_last(x5)
374 | return x5, x_out
375 |
376 |
377 | class ContextNetwork(nn.Module):
378 |
379 | def __init__(self, ch_in):
380 | super(ContextNetwork, self).__init__()
381 |
382 | self.convs = nn.Sequential(
383 | conv(ch_in, 128, 3, 1, 1),
384 | conv(128, 128, 3, 1, 2),
385 | conv(128, 128, 3, 1, 4),
386 | conv(128, 96, 3, 1, 8),
387 | conv(96, 64, 3, 1, 16),
388 | conv(64, 32, 3, 1, 1),
389 | conv(32, 2, isReLU=False)
390 | )
391 |
392 | def forward(self, x):
393 | return self.convs(x)
394 |
395 |
396 | class ContextNetwork_v2_(nn.Module):
397 |
398 | def __init__(self, ch_in, f_channels=(128, 128, 128, 96, 64, 32, 2)):
399 | super(ContextNetwork_v2_, self).__init__()
400 |
401 | self.convs = nn.Sequential(
402 | conv(ch_in, f_channels[0], 3, 1, 1),
403 | conv(f_channels[0], f_channels[1], 3, 1, 2),
404 | conv(f_channels[1], f_channels[2], 3, 1, 4),
405 | conv(f_channels[2], f_channels[3], 3, 1, 8),
406 | conv(f_channels[3], f_channels[4], 3, 1, 16),
407 | conv(f_channels[4], f_channels[5], 3, 1, 1),
408 | conv(f_channels[5], f_channels[6], isReLU=False)
409 | )
410 |
411 | def forward(self, x):
412 | return self.convs(x)
413 |
414 |
415 | class ContextNetwork_v2(nn.Module):
416 |
417 | def __init__(self, num_ls=(3, 128, 128, 128, 96, 64, 32, 16)):
418 | super(ContextNetwork_v2, self).__init__()
419 | self.num_ls = num_ls
420 | cnt = 0
421 | cnt_in = num_ls[0]
422 | self.cov1 = conv(num_ls[0], num_ls[1], 3, 1, 1)
423 |
424 | cnt += 1 # 1
425 | cnt_in += num_ls[cnt]
426 | self.cov2 = conv(cnt_in, num_ls[cnt + 1], 3, 1, 2)
427 |
428 | cnt += 1 # 2
429 | cnt_in += num_ls[cnt]
430 | self.cov3 = conv(cnt_in, num_ls[cnt + 1], 3, 1, 4)
431 |
432 | cnt += 1 # 3
433 | cnt_in += num_ls[cnt]
434 | self.cov4 = conv(cnt_in, num_ls[cnt + 1], 3, 1, 8)
435 |
436 | cnt += 1 # 4
437 | cnt_in += num_ls[cnt]
438 | self.cov5 = conv(cnt_in, num_ls[cnt + 1], 3, 1, 16)
439 |
440 | cnt += 1 # 5
441 | cnt_in += num_ls[cnt]
442 | self.cov6 = conv(cnt_in, num_ls[cnt + 1], 3, 1, 1)
443 |
444 | cnt += 1
445 | cnt_in += num_ls[cnt]
446 | self.final = conv(cnt_in, num_ls[cnt + 1], isReLU=False)
447 |
448 | def forward(self, x):
449 | x = torch.cat((self.cov1(x), x), dim=1)
450 | x = torch.cat((self.cov2(x), x), dim=1)
451 | x = torch.cat((self.cov3(x), x), dim=1)
452 | x = torch.cat((self.cov4(x), x), dim=1)
453 | x = torch.cat((self.cov5(x), x), dim=1)
454 | x = torch.cat((self.cov6(x), x), dim=1)
455 | x = self.final(x)
456 | return x
457 |
458 |
459 | class OccContextNetwork(nn.Module):
460 |
461 | def __init__(self, ch_in):
462 | super(OccContextNetwork, self).__init__()
463 |
464 | self.convs = nn.Sequential(
465 | conv(ch_in, 128, 3, 1, 1),
466 | conv(128, 128, 3, 1, 2),
467 | conv(128, 128, 3, 1, 4),
468 | conv(128, 96, 3, 1, 8),
469 | conv(96, 64, 3, 1, 16),
470 | conv(64, 32, 3, 1, 1),
471 | conv(32, 1, isReLU=False)
472 | )
473 |
474 | def forward(self, x):
475 | return self.convs(x)
476 |
477 |
--------------------------------------------------------------------------------
/model/upflow.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import collections
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | # from torch.nn.utils.spectral_norm import spectral_norm
7 | from model.pwc_modules import conv, initialize_msra, upsample2d_flow_as, upsample_flow, FlowEstimatorDense_v2, ContextNetwork_v2_, OccEstimatorDense, OccContextNetwork
8 | from model.pwc_modules import WarpingLayer_no_div, FeatureExtractor
9 | from model.correlation_package.correlation import Correlation
10 | import numpy as np
11 | from utils.tools import tools
12 | from utils.loss import loss_functions
13 | from utils.pytorch_correlation import Corr_pyTorch
14 | import cv2
15 | import os
16 | import math
17 |
18 |
19 | class network_tools():
20 | class sgu_model(tools.abstract_model):
21 | def __init__(self):
22 | super(network_tools.sgu_model, self).__init__()
23 |
24 | class FlowEstimatorDense_temp(tools.abstract_model):
25 |
26 | def __init__(self, ch_in, f_channels=(128, 128, 96, 64, 32), ch_out=2):
27 | super(FlowEstimatorDense_temp, self).__init__()
28 | N = 0
29 | ind = 0
30 | N += ch_in
31 | self.conv1 = conv(N, f_channels[ind])
32 | N += f_channels[ind]
33 |
34 | ind += 1
35 | self.conv2 = conv(N, f_channels[ind])
36 | N += f_channels[ind]
37 |
38 | ind += 1
39 | self.conv3 = conv(N, f_channels[ind])
40 | N += f_channels[ind]
41 |
42 | ind += 1
43 | self.conv4 = conv(N, f_channels[ind])
44 | N += f_channels[ind]
45 |
46 | ind += 1
47 | self.conv5 = conv(N, f_channels[ind])
48 | N += f_channels[ind]
49 | self.num_feature_channel = N
50 | ind += 1
51 | self.conv_last = conv(N, ch_out, isReLU=False)
52 |
53 | def forward(self, x):
54 | x1 = torch.cat([self.conv1(x), x], dim=1)
55 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
56 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
57 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
58 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
59 | x_out = self.conv_last(x5)
60 | return x5, x_out
61 |
62 | f_channels_es = (32, 32, 32, 16, 8)
63 | in_C = 64
64 | self.warping_layer = WarpingLayer_no_div()
65 | self.dense_estimator_mask = FlowEstimatorDense_temp(in_C, f_channels=f_channels_es, ch_out=3)
66 | self.upsample_output_conv = nn.Sequential(conv(3, 16, kernel_size=3, stride=1, dilation=1),
67 | conv(16, 16, stride=2),
68 | conv(16, 32, kernel_size=3, stride=1, dilation=1),
69 | conv(32, 32, stride=2), )
70 |
71 | def forward(self, flow_init, feature_1, feature_2, output_level_flow=None):
72 | n, c, h, w = flow_init.shape
73 | n_f, c_f, h_f, w_f = feature_1.shape
74 | if h != h_f or w != w_f:
75 | flow_init = upsample2d_flow_as(flow_init, feature_1, mode="bilinear", if_rate=True)
76 | feature_2_warp = self.warping_layer(feature_2, flow_init)
77 | input_feature = torch.cat((feature_1, feature_2_warp), dim=1)
78 | feature, x_out = self.dense_estimator_mask(input_feature)
79 | inter_flow = x_out[:, :2, :, :]
80 | inter_mask = x_out[:, 2, :, :]
81 | inter_mask = torch.unsqueeze(inter_mask, 1)
82 | inter_mask = torch.sigmoid(inter_mask)
83 | n_, c_, h_, w_ = inter_flow.shape
84 | if output_level_flow is not None:
85 | inter_flow = upsample2d_flow_as(inter_flow, output_level_flow, mode="bilinear", if_rate=True)
86 | inter_mask = upsample2d_flow_as(inter_mask, output_level_flow, mode="bilinear")
87 | flow_init = output_level_flow
88 | flow_up = tools.torch_warp(flow_init, inter_flow) * (1 - inter_mask) + flow_init * inter_mask
89 | return flow_init, flow_up, inter_flow, inter_mask
90 |
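The blend computed at the end of forward above is the self-guided upsampling step: inter_flow re-samples the bilinearly upsampled flow so that interpolation does not mix values across motion boundaries, and the learned inter_mask chooses per pixel between the re-sampled flow (mask near 0) and the plain bilinear flow (mask near 1).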
91 | def output_conv(self, x):
92 | return self.upsample_output_conv(x)
93 |
94 | @classmethod
95 | def normalize_features(cls, feature_list, normalize, center, moments_across_channels=True, moments_across_images=True):
96 | """Normalizes feature tensors (e.g., before computing the cost volume).
97 | Args:
98 | feature_list: list of torch tensors, each with dimensions [b, c, h, w]
99 | normalize: bool flag, divide features by their standard deviation
100 | center: bool flag, subtract feature mean
101 | moments_across_channels: bool flag, compute mean and std across channels (UFlow defaults to True)
102 | moments_across_images: bool flag, compute mean and std across images (UFlow defaults to True)
103 |
104 | Returns:
105 | list, normalized feature_list
106 | """
107 |
108 | # Compute feature statistics.
109 |
110 | statistics = collections.defaultdict(list)
111 | axes = [1, 2, 3] if moments_across_channels else [2, 3] # [b, c, h, w]
112 | for feature_image in feature_list:
113 | mean = torch.mean(feature_image, dim=axes, keepdim=True) # [b,1,1,1] or [b,c,1,1]
114 | variance = torch.var(feature_image, dim=axes, keepdim=True) # [b,1,1,1] or [b,c,1,1]
115 | statistics['mean'].append(mean)
116 | statistics['var'].append(variance)
117 |
118 | if moments_across_images:
119 | # statistics['mean'] = ([tf.reduce_mean(input_tensor=statistics['mean'])] *
120 | # len(feature_list))
121 | # statistics['var'] = [tf.reduce_mean(input_tensor=statistics['var'])
122 | # ] * len(feature_list)
123 | statistics['mean'] = ([torch.mean(torch.stack(statistics['mean'], dim=0), dim=(0,))] * len(feature_list))
124 | statistics['var'] = ([torch.mean(torch.stack(statistics['var'], dim=0), dim=(0,))] * len(feature_list))  # mean of per-image variances, matching the TF reference above (torch.var here was a typo)
125 |
126 | statistics['std'] = [torch.sqrt(v + 1e-16) for v in statistics['var']]
127 |
128 | # Center and normalize features.
129 |
130 | if center:
131 | feature_list = [
132 | f - mean for f, mean in zip(feature_list, statistics['mean'])
133 | ]
134 | if normalize:
135 | feature_list = [f / std for f, std in zip(feature_list, statistics['std'])]
136 |
137 | return feature_list
138 |
139 | @classmethod
140 | def weighted_ssim(cls, x, y, weight, c1=float('inf'), c2=9e-6, weight_epsilon=0.01):
141 | """Computes a weighted structured image similarity measure.
142 | Args:
143 | x: a batch of images, of shape [B, C, H, W].
144 | y: a batch of images, of shape [B, C, H, W].
145 | weight: shape [B, 1, H, W], representing the weight of each
146 | pixel in both images when we come to calculate moments (means and
147 | correlations). Values are in [0, 1].
148 | c1: A floating point number, regularizes division by zero of the means.
149 | c2: A floating point number, regularizes division by zero of the second
150 | moments.
151 | weight_epsilon: A floating point number, used to regularize division by the
152 | weight.
153 |
154 | Returns:
155 | A tuple of two pytorch Tensors. First, of shape [B, C, H-2, W-2], is scalar
156 | similarity loss per pixel per channel, and the second, of shape
157 | [B, 1, H-2, W-2], is the average pooled `weight`. It is needed so that we
158 | know how much to weigh each pixel in the first tensor. For example, if
159 | `weight` was very small in some area of the images, the first tensor will
160 | still assign a loss to these pixels, but we shouldn't take the result too
161 | seriously.
162 | """
163 |
164 | def _avg_pool3x3(x):
165 | # tf kernel [b,h,w,c]
166 | return F.avg_pool2d(x, (3, 3), (1, 1))
167 | # return tf.nn.avg_pool(x, [1, 3, 3, 1], [1, 1, 1, 1], 'VALID')
168 |
169 | if c1 == float('inf') and c2 == float('inf'):
170 | raise ValueError('Both c1 and c2 are infinite, SSIM loss is zero. This is '
171 | 'likely unintended.')
172 | average_pooled_weight = _avg_pool3x3(weight)
173 | weight_plus_epsilon = weight + weight_epsilon
174 | inverse_average_pooled_weight = 1.0 / (average_pooled_weight + weight_epsilon)
175 |
176 | def weighted_avg_pool3x3(z):
177 | weighted_avg = _avg_pool3x3(z * weight_plus_epsilon)
178 | return weighted_avg * inverse_average_pooled_weight
179 |
180 | mu_x = weighted_avg_pool3x3(x)
181 | mu_y = weighted_avg_pool3x3(y)
182 | sigma_x = weighted_avg_pool3x3(x ** 2) - mu_x ** 2
183 | sigma_y = weighted_avg_pool3x3(y ** 2) - mu_y ** 2
184 | sigma_xy = weighted_avg_pool3x3(x * y) - mu_x * mu_y
185 | if c1 == float('inf'):
186 | ssim_n = (2 * sigma_xy + c2)
187 | ssim_d = (sigma_x + sigma_y + c2)
188 | elif c2 == float('inf'):
189 | ssim_n = 2 * mu_x * mu_y + c1
190 | ssim_d = mu_x ** 2 + mu_y ** 2 + c1
191 | else:
192 | ssim_n = (2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)
193 | ssim_d = (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x + sigma_y + c2)
194 | result = ssim_n / ssim_d
195 | return torch.clamp((1 - result) / 2, 0, 1), average_pooled_weight
196 |
197 | @classmethod
198 | def edge_aware_smoothness_order1(cls, img, pred):
199 | def gradient_x(img):
200 | gx = img[:, :, :-1, :] - img[:, :, 1:, :]
201 | return gx
202 |
203 | def gradient_y(img):
204 | gy = img[:, :, :, :-1] - img[:, :, :, 1:]
205 | return gy
206 |
207 | pred_gradients_x = gradient_x(pred)
208 | pred_gradients_y = gradient_y(pred)
209 |
210 | image_gradients_x = gradient_x(img)
211 | image_gradients_y = gradient_y(img)
212 |
213 | weights_x = torch.exp(-torch.mean(torch.abs(image_gradients_x), 1, keepdim=True))
214 | weights_y = torch.exp(-torch.mean(torch.abs(image_gradients_y), 1, keepdim=True))
215 |
216 | smoothness_x = torch.abs(pred_gradients_x) * weights_x
217 | smoothness_y = torch.abs(pred_gradients_y) * weights_y
218 | return torch.mean(smoothness_x) + torch.mean(smoothness_y)
219 |
220 | @classmethod
221 | def edge_aware_smoothness_order2(cls, img, pred):
222 | def gradient_x(img, stride=1):
223 | gx = img[:, :, :-stride, :] - img[:, :, stride:, :]
224 | return gx
225 |
226 | def gradient_y(img, stride=1):
227 | gy = img[:, :, :, :-stride] - img[:, :, :, stride:]
228 | return gy
229 |
230 | pred_gradients_x = gradient_x(pred)
231 | pred_gradients_xx = gradient_x(pred_gradients_x)
232 | pred_gradients_y = gradient_y(pred)
233 | pred_gradients_yy = gradient_y(pred_gradients_y)
234 |
235 | image_gradients_x = gradient_x(img, stride=2)
236 | image_gradients_y = gradient_y(img, stride=2)
237 |
238 | weights_x = torch.exp(-torch.mean(torch.abs(image_gradients_x), 1, keepdim=True))
239 | weights_y = torch.exp(-torch.mean(torch.abs(image_gradients_y), 1, keepdim=True))
240 |
241 | smoothness_x = torch.abs(pred_gradients_xx) * weights_x
242 | smoothness_y = torch.abs(pred_gradients_yy) * weights_y
243 | return torch.mean(smoothness_x) + torch.mean(smoothness_y)
244 |
245 | @classmethod
246 | def flow_smooth_delta(cls, flow, if_second_order=False):
247 | def gradient(x):
248 | D_dy = x[:, :, 1:] - x[:, :, :-1]
249 | D_dx = x[:, :, :, 1:] - x[:, :, :, :-1]
250 | return D_dx, D_dy
251 |
252 | dx, dy = gradient(flow)
253 | # dx2, dxdy = gradient(dx)
254 | # dydx, dy2 = gradient(dy)
255 | if if_second_order:
256 | dx2, dxdy = gradient(dx)
257 | dydx, dy2 = gradient(dy)
258 | smooth_loss = dx.abs().mean() + dy.abs().mean() + dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean()
259 | else:
260 | smooth_loss = dx.abs().mean() + dy.abs().mean()
261 | # smooth_loss = dx.abs().mean() + dy.abs().mean() # + dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean()
262 | # The second-order smoothness term is disabled for now: with it enabled the regularization seems too strong and the photo loss stops decreasing. TODO
263 | return smooth_loss
264 |
265 | @classmethod
266 | def photo_loss_multi_type(cls, x, y, occ_mask, photo_loss_type='abs_robust', # abs_robust, charbonnier,L1, SSIM
267 | photo_loss_delta=0.4, photo_loss_use_occ=False,
268 | ):
269 | occ_weight = occ_mask
270 | if photo_loss_type == 'abs_robust':
271 | photo_diff = x - y
272 | loss_diff = (torch.abs(photo_diff) + 0.01).pow(photo_loss_delta)
273 | elif photo_loss_type == 'charbonnier':
274 | photo_diff = x - y
275 | loss_diff = ((photo_diff) ** 2 + 1e-6).pow(photo_loss_delta)
276 | elif photo_loss_type == 'L1':
277 | photo_diff = x - y
278 | loss_diff = torch.abs(photo_diff + 1e-6)
279 | elif photo_loss_type == 'SSIM':
280 | loss_diff, occ_weight = cls.weighted_ssim(x, y, occ_mask)
281 | else:
282 | raise ValueError('wrong photo_loss type: %s' % photo_loss_type)
283 |
284 | if photo_loss_use_occ:
285 | photo_loss = torch.sum(loss_diff * occ_weight) / (torch.sum(occ_weight) + 1e-6)
286 | else:
287 | photo_loss = torch.mean(loss_diff)
288 | return photo_loss
289 |
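A small numeric illustration of the default 'abs_robust' penalty: with photo_loss_delta = 0.4 the per-pixel cost is (|x - y| + 0.01) ** 0.4, which grows sub-linearly, so a residual of 0.1 costs about 0.41, a residual of 1.0 about 1.00, and a residual of 10 only about 2.51; large residuals (e.g. from occlusions) are therefore down-weighted relative to an L1 loss.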
290 |
291 | class UPFlow_net(tools.abstract_model):
292 | class config(tools.abstract_config):
293 | def __init__(self):
294 | # occlusion check settings
295 | self.occ_type = 'for_back_check'
296 | self.alpha_1 = 0.1
297 | self.alpha_2 = 0.5
298 | self.occ_check_obj_out_all = 'obj' # if boundary dilated warping is used, this should be 'obj'
299 | self.stop_occ_gradient = False
300 | self.smooth_level = 'final' # 'final' or '1/4'
301 | self.smooth_type = 'edge' # 'edge' or 'delta'
302 | self.smooth_order_1_weight = 1
303 | # second-order smooth loss
304 | self.smooth_order_2_weight = 0
305 | # photo loss type (SSIM added)
306 | self.photo_loss_type = 'abs_robust' # 'abs_robust', 'charbonnier', 'L1', 'SSIM'
307 | self.photo_loss_delta = 0.4
308 | self.photo_loss_use_occ = False
309 | self.photo_loss_census_weight = 0
310 | # use cost volume norm
311 | self.if_norm_before_cost_volume = False
312 | self.norm_moments_across_channels = True
313 | self.norm_moments_across_images = True
314 | self.multi_scale_distillation_weight = 0
315 | self.multi_scale_distillation_style = 'upup'
316 | # choices: 'down', 'upup', 'updown'
317 | self.multi_scale_distillation_occ = True # whether to apply the occlusion mask in multi-scale distillation
318 | self.if_froze_pwc = False
319 | self.input_or_sp_input = 1 # use raw input or special input for photo loss
320 | self.if_use_boundary_warp = True # if use the boundary dilated warping
321 |
322 | self.if_sgu_upsample = False # if use sgu upsampling
323 | self.if_use_cor_pytorch = False # use the pure-PyTorch correlation layer; only for testing the model on CPU (when the CUDA correlation layer is not compiled)
324 |
325 | def __call__(self, ):
326 | # return PWCNet_unsup_irr_bi_v5_4(self)
327 | return UPFlow_net(self)
328 |
329 | def __init__(self, conf: config):
330 | super(UPFlow_net, self).__init__()
331 | # === get config file
332 | self.conf = conf
333 |
334 | # === build the network
335 | self.search_range = 4
336 | self.num_chs = [3, 16, 32, 64, 96, 128, 196]
337 | # 1/2 1/4 1/8 1/16 1/32 1/64
338 | self.estimator_f_channels = (128, 128, 96, 64, 32)
339 | self.context_f_channels = (128, 128, 128, 96, 64, 32, 2)
340 | self.output_level = 4
341 | self.num_levels = 7
342 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
343 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
344 | self.warping_layer = WarpingLayer_no_div()
345 | self.dim_corr = (self.search_range * 2 + 1) ** 2
346 | self.num_ch_in = self.dim_corr + 32 + 2
347 | self.flow_estimators = FlowEstimatorDense_v2(self.num_ch_in, f_channels=self.estimator_f_channels)
348 | self.context_networks = ContextNetwork_v2_(self.flow_estimators.n_channels + 2, f_channels=self.context_f_channels)
349 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
350 | conv(128, 32, kernel_size=1, stride=1, dilation=1),
351 | conv(96, 32, kernel_size=1, stride=1, dilation=1),
352 | conv(64, 32, kernel_size=1, stride=1, dilation=1),
353 | conv(32, 32, kernel_size=1, stride=1, dilation=1)])
354 | self.occ_check_model_ls = []
355 | self.correlation_pytorch = Corr_pyTorch(pad_size=self.search_range, kernel_size=1,
356 | max_displacement=self.search_range, stride1=1, stride2=1) # correlation layer using pytorch
357 | # === build sgu upsampling
358 | if self.conf.if_sgu_upsample:
359 | self.sgi_model = network_tools.sgu_model()
360 | else:
361 | self.sgi_model = None
362 |
363 | # === build loss function
364 | self.occ_check_model = tools.occ_check_model(occ_type=self.conf.occ_type, occ_alpha_1=self.conf.alpha_1, occ_alpha_2=self.conf.alpha_2,
365 | obj_out_all=self.conf.occ_check_obj_out_all)
366 | initialize_msra(self.modules())
367 | if self.conf.if_froze_pwc:
368 | self.froze_PWC()
369 |
370 | def forward(self, input_dict: dict):
371 | '''
372 | :param input_dict: im1, im2, im1_raw, im2_raw, start, if_loss
373 | :return: output_dict: flows, flow_f_out, flow_b_out, photo_loss
374 | '''
375 | im1_ori, im2_ori = input_dict['im1'], input_dict['im2'] # in training: the cropped image; in testing: the input image
376 | if input_dict['if_loss']:
377 | if self.conf.input_or_sp_input == 1:
378 | im1, im2 = im1_ori, im2_ori
379 | else:
380 | im1, im2 = input_dict['im1_sp'], input_dict['im2_sp'] # use the special (augmented) input images for the forward pass; the original images are still used to compute the loss
381 | else:
382 | im1, im2 = im1_ori, im2_ori
383 |
384 | output_dict = {}
385 | flow_f_pwc_out, flow_b_pwc_out, flows = self.forward_2_frame_v3(im1, im2, if_loss=input_dict['if_loss']) # forward estimation
386 | occ_fw, occ_bw = self.occ_check_model(flow_f=flow_f_pwc_out, flow_b=flow_b_pwc_out) # 0 in occluded areas, 1 elsewhere
387 |
388 | ''' ====================================== ===================================== '''
389 | output_dict['flow_f_out'] = flow_f_pwc_out
390 | output_dict['flow_b_out'] = flow_b_pwc_out
391 | output_dict['occ_fw'] = occ_fw
392 | output_dict['occ_bw'] = occ_bw
393 |
394 | if input_dict['if_loss']:
395 | # === smooth loss
396 | if self.conf.smooth_level == 'final':
397 | s_flow_f, s_flow_b = flow_f_pwc_out, flow_b_pwc_out
398 | s_im1, s_im2 = im1_ori, im2_ori
399 | elif self.conf.smooth_level == '1/4':
400 | s_flow_f, s_flow_b = flows[0] # flow in 1/4 scale
401 | _, _, temp_h, temp_w = s_flow_f.size()
402 | s_im1 = F.interpolate(im1_ori, (temp_h, temp_w), mode='area')
403 | s_im2 = F.interpolate(im2_ori, (temp_h, temp_w), mode='area')
404 | else:
405 | raise ValueError('wrong smooth level chosen: %s' % self.conf.smooth_level)
406 | smooth_loss = 0
407 | # first-order smooth loss
408 | if self.conf.smooth_order_1_weight > 0:
409 | if self.conf.smooth_type == 'edge':
410 | smooth_loss += self.conf.smooth_order_1_weight * network_tools.edge_aware_smoothness_order1(img=s_im1, pred=s_flow_f)
411 | smooth_loss += self.conf.smooth_order_1_weight * network_tools.edge_aware_smoothness_order1(img=s_im2, pred=s_flow_b)
412 | elif self.conf.smooth_type == 'delta':
413 | smooth_loss += self.conf.smooth_order_1_weight * network_tools.flow_smooth_delta(flow=s_flow_f, if_second_order=False)
414 | smooth_loss += self.conf.smooth_order_1_weight * network_tools.flow_smooth_delta(flow=s_flow_b, if_second_order=False)
415 | else:
416 | raise ValueError('wrong smooth_type: %s' % self.conf.smooth_type)
417 |
418 | # second-order smooth loss
419 | if self.conf.smooth_order_2_weight > 0:
420 | if self.conf.smooth_type == 'edge':
421 | smooth_loss += self.conf.smooth_order_2_weight * network_tools.edge_aware_smoothness_order2(img=s_im1, pred=s_flow_f)
422 | smooth_loss += self.conf.smooth_order_2_weight * network_tools.edge_aware_smoothness_order2(img=s_im2, pred=s_flow_b)
423 | elif self.conf.smooth_type == 'delta':
424 | smooth_loss += self.conf.smooth_order_2_weight * network_tools.flow_smooth_delta(flow=s_flow_f, if_second_order=True)
425 | smooth_loss += self.conf.smooth_order_2_weight * network_tools.flow_smooth_delta(flow=s_flow_b, if_second_order=True)
426 | else:
427 | raise ValueError('wrong smooth_type: %s' % self.conf.smooth_type)
428 | output_dict['smooth_loss'] = smooth_loss
429 |
430 | # === photo loss
431 | if self.conf.if_use_boundary_warp:
432 | im1_s, im2_s, start_s = input_dict['im1_raw'], input_dict['im2_raw'], input_dict['start'] # the image before cropping
433 | im1_warp = tools.boundary_dilated_warp.warp_im(im2_s, flow_f_pwc_out, start_s) # im1 reconstructed by warping im2 with the forward flow
434 | im2_warp = tools.boundary_dilated_warp.warp_im(im1_s, flow_b_pwc_out, start_s)
435 | else:
436 | im1_warp = tools.torch_warp(im2_ori, flow_f_pwc_out) # im1 reconstructed by warping im2 with the forward flow
437 | im2_warp = tools.torch_warp(im1_ori, flow_b_pwc_out)
438 | # photo loss
439 | if self.conf.stop_occ_gradient:
440 | occ_fw, occ_bw = occ_fw.clone().detach(), occ_bw.clone().detach()
441 | photo_loss = network_tools.photo_loss_multi_type(im1_ori, im1_warp, occ_fw, photo_loss_type=self.conf.photo_loss_type,
442 | photo_loss_delta=self.conf.photo_loss_delta, photo_loss_use_occ=self.conf.photo_loss_use_occ)
443 | photo_loss += network_tools.photo_loss_multi_type(im2_ori, im2_warp, occ_bw, photo_loss_type=self.conf.photo_loss_type,
444 | photo_loss_delta=self.conf.photo_loss_delta, photo_loss_use_occ=self.conf.photo_loss_use_occ)
445 | output_dict['photo_loss'] = photo_loss
446 | output_dict['im1_warp'] = im1_warp
447 | output_dict['im2_warp'] = im2_warp
448 |
449 | # === census loss
450 | if self.conf.photo_loss_census_weight > 0:
451 | census_loss = loss_functions.census_loss_torch(img1=im1_ori, img1_warp=im1_warp, mask=occ_fw, q=self.conf.photo_loss_delta,
452 | charbonnier_or_abs_robust=False, if_use_occ=self.conf.photo_loss_use_occ, averge=True) + \
453 | loss_functions.census_loss_torch(img1=im2_ori, img1_warp=im2_warp, mask=occ_bw, q=self.conf.photo_loss_delta,
454 | charbonnier_or_abs_robust=False, if_use_occ=self.conf.photo_loss_use_occ, averge=True)
455 | census_loss *= self.conf.photo_loss_census_weight
456 | else:
457 | census_loss = None
458 | output_dict['census_loss'] = census_loss
459 |
460 | # === multi scale distillation loss
461 | if self.conf.multi_scale_distillation_weight > 0:
462 | flow_fw_label = flow_f_pwc_out.clone().detach()
463 | flow_bw_label = flow_b_pwc_out.clone().detach()
464 | msd_loss_ls = []
465 | for i, (scale_fw, scale_bw) in enumerate(flows):
466 | if self.conf.multi_scale_distillation_style == 'down':
467 | flow_fw_label_scale = upsample_flow(flow_fw_label, target_flow=scale_fw)
468 | occ_scale_fw = F.interpolate(occ_fw, [scale_fw.size(2), scale_fw.size(3)], mode='nearest')
469 | flow_bw_label_scale = upsample_flow(flow_bw_label, target_flow=scale_bw)
470 | occ_scale_bw = F.interpolate(occ_bw, [scale_bw.size(2), scale_bw.size(3)], mode='nearest')
471 | elif self.conf.multi_scale_distillation_style == 'upup':
472 | flow_fw_label_scale = flow_fw_label
473 | scale_fw = upsample_flow(scale_fw, target_flow=flow_fw_label_scale)
474 | occ_scale_fw = occ_fw
475 | flow_bw_label_scale = flow_bw_label
476 | scale_bw = upsample_flow(scale_bw, target_flow=flow_bw_label_scale)
477 | occ_scale_bw = occ_bw
478 | else:
479 | raise ValueError('wrong multi_scale_distillation_style: %s' % self.conf.multi_scale_distillation_style)
480 | msd_loss_scale_fw = network_tools.photo_loss_multi_type(x=scale_fw, y=flow_fw_label_scale, occ_mask=occ_scale_fw, photo_loss_type='abs_robust',
481 | photo_loss_use_occ=self.conf.multi_scale_distillation_occ)
482 | msd_loss_ls.append(msd_loss_scale_fw)
483 | msd_loss_scale_bw = network_tools.photo_loss_multi_type(x=scale_bw, y=flow_bw_label_scale, occ_mask=occ_scale_bw, photo_loss_type='abs_robust',
484 | photo_loss_use_occ=self.conf.multi_scale_distillation_occ)
485 | msd_loss_ls.append(msd_loss_scale_bw)
486 | msd_loss = sum(msd_loss_ls)
487 | msd_loss = self.conf.multi_scale_distillation_weight * msd_loss
488 | else:
489 | msd_loss = None
490 |
491 | output_dict['msd_loss'] = msd_loss
492 | return output_dict
493 |
494 | def forward_2_frame_v3(self, x1_raw, x2_raw, if_loss=False):
495 | _, _, height_im, width_im = x1_raw.size()
496 | # on the bottom level are original images
497 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
498 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
499 | flows = []
500 | # init
501 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
502 | init_dtype = x1_pyramid[0].dtype
503 | init_device = x1_pyramid[0].device
504 | flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
505 | flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
506 | # build pyramid
507 | feature_level_ls = []
508 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
509 | x1_1by1 = self.conv_1x1[l](x1)
510 | x2_1by1 = self.conv_1x1[l](x2)
511 | feature_level_ls.append((x1, x1_1by1, x2, x2_1by1))
512 | if l == self.output_level:
513 | break
514 | for level, (x1, x1_1by1, x2, x2_1by1) in enumerate(feature_level_ls):
515 | flow_f, flow_b, flow_f_res, flow_b_res = self.decode_level_res(level=level, flow_1=flow_f, flow_2=flow_b,
516 | feature_1=x1, feature_1_1x1=x1_1by1,
517 | feature_2=x2, feature_2_1x1=x2_1by1,
518 | img_ori_1=x1_raw, img_ori_2=x2_raw)
519 | flow_f = flow_f + flow_f_res
520 | flow_b = flow_b + flow_b_res
521 | flows.append([flow_f, flow_b])
522 | flow_f_out = upsample2d_flow_as(flow_f, x1_raw, mode="bilinear", if_rate=True)
523 | flow_b_out = upsample2d_flow_as(flow_b, x1_raw, mode="bilinear", if_rate=True)
524 |
525 | # === upsample to full resolution
526 | if self.conf.if_sgu_upsample:
527 | feature_1_1x1 = self.sgi_model.output_conv(x1_raw)
528 | feature_2_1x1 = self.sgi_model.output_conv(x2_raw)
529 | flow_f_out = self.self_guided_upsample(flow_up_bilinear=flow_f, feature_1=feature_1_1x1, feature_2=feature_2_1x1, output_level_flow=flow_f_out)
530 | flow_b_out = self.self_guided_upsample(flow_up_bilinear=flow_b, feature_1=feature_2_1x1, feature_2=feature_1_1x1, output_level_flow=flow_b_out)
531 | else:
532 | pass
533 | return flow_f_out, flow_b_out, flows[::-1]
534 |
535 | def decode_level_res(self, level, flow_1, flow_2, feature_1, feature_1_1x1, feature_2, feature_2_1x1, img_ori_1, img_ori_2):
536 | flow_1_up_bilinear = upsample2d_flow_as(flow_1, feature_1, mode="bilinear", if_rate=True)
537 | flow_2_up_bilinear = upsample2d_flow_as(flow_2, feature_2, mode="bilinear", if_rate=True)
538 | # warping
539 | if level == 0:
540 | feature_2_warp = feature_2
541 | feature_1_warp = feature_1
542 | else:
543 | if self.conf.if_sgu_upsample:
544 | flow_1_up_bilinear = self.self_guided_upsample(flow_up_bilinear=flow_1_up_bilinear, feature_1=feature_1_1x1, feature_2=feature_2_1x1)
545 | flow_2_up_bilinear = self.self_guided_upsample(flow_up_bilinear=flow_2_up_bilinear, feature_1=feature_2_1x1, feature_2=feature_1_1x1)
546 | feature_2_warp = self.warping_layer(feature_2, flow_1_up_bilinear)
547 | feature_1_warp = self.warping_layer(feature_1, flow_2_up_bilinear)
548 | # if norm feature
549 | if self.conf.if_norm_before_cost_volume:
550 | feature_1, feature_2_warp = network_tools.normalize_features((feature_1, feature_2_warp), normalize=True, center=True,
551 | moments_across_channels=self.conf.norm_moments_across_channels,
552 | moments_across_images=self.conf.norm_moments_across_images)
553 | feature_2, feature_1_warp = network_tools.normalize_features((feature_2, feature_1_warp), normalize=True, center=True,
554 | moments_across_channels=self.conf.norm_moments_across_channels,
555 | moments_across_images=self.conf.norm_moments_across_images)
556 | # correlation
557 | if self.conf.if_use_cor_pytorch:
558 | out_corr_1 = self.correlation_pytorch(feature_1, feature_2_warp)
559 | out_corr_2 = self.correlation_pytorch(feature_2, feature_1_warp)
560 | else:
561 | out_corr_1 = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(feature_1, feature_2_warp)
562 | out_corr_2 = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(feature_2, feature_1_warp)
563 | out_corr_relu_1 = self.leakyRELU(out_corr_1)
564 | out_corr_relu_2 = self.leakyRELU(out_corr_2)
565 | feature_int_1, flow_res_1 = self.flow_estimators(torch.cat([out_corr_relu_1, feature_1_1x1, flow_1_up_bilinear], dim=1))
566 | feature_int_2, flow_res_2 = self.flow_estimators(torch.cat([out_corr_relu_2, feature_2_1x1, flow_2_up_bilinear], dim=1))
567 | flow_1_up_bilinear_ = flow_1_up_bilinear + flow_res_1
568 | flow_2_up_bilinear_ = flow_2_up_bilinear + flow_res_2
569 | flow_fine_1 = self.context_networks(torch.cat([feature_int_1, flow_1_up_bilinear_], dim=1))
570 | flow_fine_2 = self.context_networks(torch.cat([feature_int_2, flow_2_up_bilinear_], dim=1))
571 | flow_1_res = flow_res_1 + flow_fine_1
572 | flow_2_res = flow_res_2 + flow_fine_2
573 | return flow_1_up_bilinear, flow_2_up_bilinear, flow_1_res, flow_2_res
574 |
575 | def froze_PWC(self):
576 | for param in self.feature_pyramid_extractor.parameters():
577 | param.requires_grad = False
578 | for param in self.flow_estimators.parameters():
579 | param.requires_grad = False
580 | for param in self.context_networks.parameters():
581 | param.requires_grad = False
582 | for param in self.conv_1x1.parameters():
583 | param.requires_grad = False
584 |
585 | def self_guided_upsample(self, flow_up_bilinear, feature_1, feature_2, output_level_flow=None):
586 | flow_up_bilinear_, out_flow, inter_flow, inter_mask = self.sgi_model(flow_up_bilinear, feature_1, feature_2, output_level_flow=output_level_flow)
587 | return out_flow
588 |
589 | @classmethod
590 | def demo(cls):
591 | param_dict = {
592 | 'occ_type': 'for_back_check',
593 | 'alpha_1': 0.1,
594 | 'alpha_2': 0.5,
595 | 'occ_check_obj_out_all': 'obj',
596 | 'stop_occ_gradient': False,
597 | 'smooth_level': 'final', # final or 1/4
598 | 'smooth_type': 'edge', # edge or delta
599 | # smooth loss
600 | 'smooth_order_1_weight': 1,
601 | 'smooth_order_2_weight': 0,
602 | # photo loss
603 | 'photo_loss_type': 'abs_robust', # abs_robust, charbonnier, L1, SSIM
604 | 'photo_loss_delta': 0.4,
605 | 'photo_loss_use_occ': False,
606 | 'photo_loss_census_weight': 1,
607 | # use cost volume norm
608 | 'if_norm_before_cost_volume': True,
609 | 'norm_moments_across_channels': False,
610 | 'norm_moments_across_images': False,
611 | 'multi_scale_distillation_weight': 1,
612 | 'multi_scale_distillation_style': 'upup', # 'down', 'upup', 'updown'
613 | 'multi_scale_photo_weight': 1,
614 | 'multi_scale_distillation_occ': True, # whether to use the occlusion mask in multi-scale distillation
615 | 'if_froze_pwc': False,
616 | 'input_or_sp_input': 1,
617 | 'if_use_boundary_warp': True,
618 | 'if_use_cor_pytorch': True,
619 | }
620 | net_conf = UPFlow_net.config()
621 | net_conf.update(param_dict)
622 | net_conf.get_name(print_now=True)
623 | net = net_conf() # .cuda()
624 | net.eval()
625 | im = np.random.random((1, 3, 320, 320))
626 | start = np.zeros((1, 2, 1, 1))
627 | start = torch.from_numpy(start).float() # .cuda()
628 | im_torch = torch.from_numpy(im).float() # .cuda()
629 | input_dict = {'im1': im_torch, 'im2': im_torch,
630 | 'im1_raw': im_torch, 'im2_raw': im_torch, 'im1_sp': im_torch, 'im2_sp': im_torch, 'start': start, 'if_loss': True}
631 | output_dict = net(input_dict)
632 | print('smooth_loss', output_dict['smooth_loss'], 'photo_loss', output_dict['photo_loss'], 'census_loss', output_dict['census_loss'])
633 | for i in output_dict.keys():
634 | if output_dict[i] is None:
635 | print(i, output_dict[i])
636 | else:
637 | tools.check_tensor(output_dict[i], i)
638 |
639 |
640 | if __name__ == '__main__':
641 | UPFlow_net.demo()
642 |
--------------------------------------------------------------------------------
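Note: decode_level_res in model/upflow.py relies on upsample2d_flow_as(..., if_rate=True), defined elsewhere in the repo (not shown in this section), to bring the coarser flow up to the next pyramid level. For reference, a minimal sketch under the assumption that if_rate=True means "rescale the flow values by the resize ratio so displacements stay in pixel units at the new resolution"; the _sketch name and the u/v channel-order assumption are ours, not the repo's:

import torch
import torch.nn.functional as F

def upsample2d_flow_as_sketch(flow, target, mode="bilinear"):
    _, _, h, w = target.size()
    _, _, h0, w0 = flow.size()
    flow_up = F.interpolate(flow, (h, w), mode=mode, align_corners=True)
    # assumed channel order: u (horizontal) in channel 0, v (vertical) in channel 1
    scale = torch.tensor([w / w0, h / h0], dtype=flow.dtype, device=flow.device).view(1, 2, 1, 1)
    return flow_up * scale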
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py==2.9.0
2 | imageio==2.5.0
3 | matplotlib==3.0.3
4 | numpy==1.17.0
5 | opencv-python==4.1.0.25
6 | Pillow==6.1.0
7 | pypng==0.0.20
8 | rarfile==3.1
9 | scikit-image==0.15.0
10 | scipy==1.3.1
11 | tensorflow==2.0.0a0
12 | torch==1.1.0
13 | torchvision==0.1.6
14 | tqdm==4.36.1
15 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/ex_runner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from utils.tools import tools
4 | import cv2
5 | import numpy as np
6 | from copy import deepcopy
7 | import torch
8 | import warnings # ignore warnings
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from dataset.kitti_dataset import kitti_train, kitti_flow
12 | from model.upflow import UPFlow_net
13 | from torch.utils.data import DataLoader
14 | import time
15 |
16 | ''' scripts for training:
17 | 1. simply using photo loss and smooth loss
18 | 2. add occlusion checking
19 | 3. add teacher-student loss (ARFlow)
20 | '''
21 |
22 | class Trainer_model(tools.abs_test_model):
23 | class config(tools.abstract_config): # TODO
24 |
25 | def __init__(self, **kwargs):
26 | self.lr = 1e-2
27 | self.weight_decay = 1e-5
28 | self.optmizer_name = 'adam'
29 | self.gamma = 0.95
30 | self.gpu_opt = None # None is multi GPU
31 | self.model_path = '/data/luokunming/Optical_Flow_all/training/unsup_PWC_flyc_photo_smooth/unsupPWC_epoch_27_Flyc_epe_error(5.125).pth' # loading the model
32 | self.load_relax = False # whether to relax strictness when loading the pretrained model
33 | self.print_every = 20
34 |
35 | # parameters of spatial transform
36 | self.if_train_sp = False # switch: whether to use spatial transform augmentation
37 | self.sptrans_add_noise = True
38 | self.sptrans_hflip = True
39 | self.sptrans_rotate = [-0.01, 0.01, -0.01, 0.01]
40 | self.sptrans_squeeze = [1.0, 1.0, 1.0, 1.0]
41 | self.sptrans_trans = [0.04, 0.005]
42 | self.sptrans_vflip = True
43 | self.sptrans_zoom = [1.0, 1.4, 0.99, 1.01]
44 | self.spatial_trans_if_mask = True # use the occ mask during spatial transform distillation
45 | self.spatial_trans_eps = 0.0
46 | self.spatial_trans_q = 1.0
47 | self.spatial_trans_loss_weight = 0.01 # weight of this term in the total loss
48 | self.sp_input_or_sp_input = 1
49 | self.train_sp_msd_loss_weight = 0 # also compute the multi-scale loss on the spatial-transform pass
50 | self.train_sp_msd_loss_style = 'down' # only 'down' and 'up' for now
51 |
52 | self.final_sp_train_weight = 0 # supervise the 'final' pass with the clean output; requires 'sp_input_or_sp_input' <= 0
53 | self.final_sp_train_style = 'down' # currently unused
54 |
55 | self.multi_scale_eval = False # also compute multi-scale results during evaluation
56 |
57 | self.train_dir = '/data/luokunming/Optical_Flow_all/training/demo_unsupervised_train' # this is set in the main function
58 | self.update(kwargs)
59 |
60 | def __call__(self, net_work: tools.abstract_model):
61 | # load network
62 | if self.model_path is not None:
63 | net_work.load_model(self.model_path, if_relax=self.load_relax)
64 | return Trainer_model(self, net_work)
65 |
66 | def __init__(self, conf: config, net_work: tools.abstract_model):
67 | super(Trainer_model, self).__init__(conf=conf, net_work=net_work)
68 | self.conf = conf
69 | self.net_work = net_work
70 | if self.conf.optmizer_name == 'adam':
71 | self.optimizer = optim.Adam(self.net_work.parameters(), lr=self.conf.lr, amsgrad=True, weight_decay=self.conf.weight_decay)
72 | else:
73 | raise ValueError('wrong optmizer name: %s' % self.conf.optmizer_name)
74 | self.data_clock = tools.Clock_luo()
75 | self.msd_loss_meter = tools.AverageMeter()
76 | self.sp_msd_loss_meter = tools.AverageMeter()
77 | self.final_sp_loss_meter = tools.AverageMeter()
78 | self.app_loss_meter = tools.AverageMeter()
79 | self.appd_loss_meter = tools.AverageMeter()
80 | self.inpaint_img_loss_meter = tools.AverageMeter()
81 | self.spatial_trans_loss_meter = tools.AverageMeter()
82 | self.multi_scale_loss_meter = tools.AverageMeter()
83 | self.census_loss_meter = tools.AverageMeter()
84 | self.best_name = ''
85 | self.occ_check_model = tools.occ_check_model(occ_type=self.conf.occ_type, occ_alpha_1=self.conf.alpha_1, occ_alpha_2=self.conf.alpha_2,
86 | sum_abs_or_squar=self.conf.occ_check_sum_abs_or_squar, obj_out_all=self.conf.occ_check_obj_out_all)
87 | self.print_str = ''
88 | self.cnt = 0
89 | self.temp_save_eval_test = []
90 |
91 | # ===== spatial transform =====
92 | '''
93 | class config():
94 | def __init__(self):
95 | self.add_noise = False
96 | self.hflip = False
97 | self.rotate = [-0.01, 0.01, -0.01, 0.01]
98 | self.squeeze = [1.0, 1.0, 1.0, 1.0]
99 | self.trans = [0.04, 0.005]
100 | self.vflip = False
101 | self.zoom = [1.0, 1.4, 0.99, 1.01]
102 | '''
103 |
104 | class sp_conf():
105 |
106 | def __init__(self, conf):
107 | self.add_noise = conf.sptrans_add_noise # False
108 | self.hflip = conf.sptrans_hflip # False
109 | self.rotate = conf.sptrans_rotate # [-0.01, 0.01, -0.01, 0.01]
110 | self.squeeze = conf.sptrans_squeeze # [1.0, 1.0, 1.0, 1.0]
111 | self.trans = conf.sptrans_trans # [0.04, 0.005]
112 | self.vflip = conf.sptrans_vflip # False
113 | self.zoom = conf.sptrans_zoom # [1.0, 1.4, 0.99, 1.01]
114 |
115 | self.sp_transform = tools.SP_transform.RandomAffineFlow(
116 | sp_conf(self.conf), addnoise=self.conf.sptrans_add_noise).cuda()
117 |
118 | def train_batch(self, batch_step, im1, im2, *args, **kwargs): # train one batch
119 |
120 | if self.data_clock.start_flag:
121 | self.data_clock.end()
122 | if_print = batch_step % self.conf.print_every == 0
123 | frame_1_ls = []
124 | frame_2_ls = []
125 | print_str = '%s %s Epoch%d Iter%d [%.4fs]' % (self.conf.print_name, self.best_name, self.epoch, batch_step, self.data_clock.get_during())
126 | if_show = batch_step % self.conf.show_every == 0 and self.conf.show_every > 0
127 | batch_N = im1.shape[0]
128 | im1_crop_ori, im2_crop_ori, start = args
129 | sp_img1_ori, sp_img2_ori = kwargs['im1_crop_at'], kwargs['im2_crop_at'] # the 'final' (augmented) images
130 | _, _, h_, w_ = im1_crop_ori.size()
131 | # self.save_image(im1, 'im1')
132 | # self.save_image(im2, 'im2')
133 | # self.save_image(im1_crop_ori, 'im1_crop_ori')
134 | # self.save_image(im2_crop_ori, 'im2_crop_ori')
135 | # self.save_image(sp_img1_ori, 'sp_img1_ori')
136 | # self.save_image(sp_img2_ori, 'sp_img2_ori')
137 | # while True:
138 | # print('return')
139 | # time.sleep(1)
140 | # choose the data fed to the network
141 | im1_crop, im2_crop = im1_crop_ori, im2_crop_ori
142 | # ============================================================= network output ===================================================================
143 | self.optimizer.zero_grad()
144 | # =========== compute photo loss, smooth loss and census loss ===========
145 | if self.conf.model_name.lower() in ['pwcirrbiv5_v4', ]:
146 | input_dict = {'im1': im1_crop, 'im2': im2_crop, 'im1_sp': sp_img1_ori, 'im2_sp': sp_img2_ori,
147 | 'im1_raw': im1, 'im2_raw': im2, 'start': start, 'if_loss': True, 'if_show': if_show}
148 | output_dict = self.net_work(input_dict)
149 | flow_fw, flow_bw = output_dict['flow_f_out'], output_dict['flow_b_out']
150 | occ_fw, occ_bw = output_dict['occ_fw'], output_dict['occ_bw']
151 | photo_loss, smooth_loss, census_loss = output_dict['photo_loss'].mean(), output_dict['smooth_loss'].mean(), output_dict['census_loss']
152 | im1_warp = output_dict['im1_warp']
153 | im2_warp = output_dict['im2_warp']
154 | loss = photo_loss + smooth_loss
155 | if census_loss is None:
156 | pass
157 | else:
158 | census_loss = census_loss.mean()
159 | loss += census_loss
160 | self.census_loss_meter.update(val=census_loss.item(), num=batch_N)
161 | print_str += ' cens %.4f(%.4f)' % (self.census_loss_meter.val, self.census_loss_meter.avg)
162 | if output_dict['msd_loss'] is None:
163 | pass
164 | else:
165 | msd_loss = output_dict['msd_loss'].mean()
166 | loss += msd_loss
167 | self.msd_loss_meter.update(val=msd_loss.item(), num=batch_N)
168 | print_str += ' msd %.4f(%.4f)' % (self.msd_loss_meter.val, self.msd_loss_meter.avg)
169 | if 'app_loss' not in output_dict.keys():
170 | pass
171 | elif output_dict['app_loss'] is None:
172 | pass
173 | else:
174 | app_loss = output_dict['app_loss'].mean()
175 | loss += app_loss
176 | self.app_loss_meter.update(val=app_loss.item(), num=batch_N)
177 | print_str += ' app %.4f(%.4f)' % (self.app_loss_meter.val, self.app_loss_meter.avg)
178 |
179 | self.photo_loss_meter.update(val=photo_loss.item(), num=batch_N)
180 | self.smooth_loss_meter.update(val=smooth_loss.item(), num=batch_N)
181 | print_str += ' ph %.4f(%.4f)' % (self.photo_loss_meter.val, self.photo_loss_meter.avg)
182 | print_str += ' sm %.4f(%.4f)' % (self.smooth_loss_meter.val, self.smooth_loss_meter.avg)
183 | elif self.conf.model_name.lower() in ['pwcirrbiv5_v5', ]:
184 | input_dict = {'im1': im1_crop, 'im2': im2_crop, 'im1_sp': sp_img1_ori, 'im2_sp': sp_img2_ori,
185 | 'im1_raw': im1, 'im2_raw': im2, 'start': start, 'if_loss': True, 'if_show': if_show}
186 | output_dict = self.net_work(input_dict)
187 | flow_fw, flow_bw = output_dict['flow_f_out'], output_dict['flow_b_out']
188 | occ_fw, occ_bw = output_dict['occ_fw'], output_dict['occ_bw']
189 | photo_loss, smooth_loss, census_loss = output_dict['photo_loss'].mean(), output_dict['smooth_loss'].mean(), output_dict['census_loss']
190 | im1_warp = output_dict['im1_warp']
191 | im2_warp = output_dict['im2_warp']
192 | loss = photo_loss + smooth_loss
193 | if census_loss is None:
194 | pass
195 | else:
196 | census_loss = census_loss.mean()
197 | loss += census_loss
198 | self.census_loss_meter.update(val=census_loss.item(), num=batch_N)
199 | print_str += ' cens %.4f(%.4f)' % (self.census_loss_meter.val, self.census_loss_meter.avg)
200 | if output_dict['msd_loss'] is None:
201 | pass
202 | else:
203 | msd_loss = output_dict['msd_loss'].mean()
204 | loss += msd_loss
205 | self.msd_loss_meter.update(val=msd_loss.item(), num=batch_N)
206 | print_str += ' msd %.4f(%.4f)' % (self.msd_loss_meter.val, self.msd_loss_meter.avg)
207 | if 'occ_loss' not in output_dict.keys():
208 | pass
209 | elif output_dict['occ_loss'] is None:
210 | pass
211 | else:
212 | occ_loss = output_dict['occ_loss'].mean()
213 | loss += occ_loss
214 | self.occ_loss_meter.update(val=occ_loss.item(), num=batch_N)
215 | print_str += ' occ %.4f(%.4f)' % (self.occ_loss_meter.val, self.occ_loss_meter.avg)
216 | self.photo_loss_meter.update(val=photo_loss.item(), num=batch_N)
217 | self.smooth_loss_meter.update(val=smooth_loss.item(), num=batch_N)
218 | print_str += ' ph %.4f(%.4f)' % (self.photo_loss_meter.val, self.photo_loss_meter.avg)
219 | print_str += ' sm %.4f(%.4f)' % (self.smooth_loss_meter.val, self.smooth_loss_meter.avg)
220 | else:
221 | raise ValueError('wrong model: %s' % self.conf.model_name)
222 |
223 | # =========== compute the spatial-transform equivariance loss ===========
224 | if self.conf.if_train_sp:
225 | # s = {'imgs': [sp_img1, sp_img2], 'flows_f': [flow_fw], 'masks_f': [occ_fw]}
226 | # s=deepcopy(s)
227 | if self.conf.sp_input_or_sp_input >= 1: # use the clean images for the sp pass
228 | sp_img1, sp_img2 = im1_crop_ori, im2_crop_ori
229 | elif self.conf.sp_input_or_sp_input > 0:
230 | if tools.random_flag(threshold_0_1=self.conf.sp_input_or_sp_input):
231 | sp_img1, sp_img2 = sp_img1_ori, sp_img2_ori
232 | else:
233 | sp_img1, sp_img2 = im1_crop_ori, im2_crop_ori
234 | else: # use the 'final' images for the sp pass
235 | sp_img1, sp_img2 = sp_img1_ori, sp_img2_ori
236 | flow_fw_pseudo_label, occ_fw_pseudo_label = flow_fw.clone().detach(), occ_fw.clone().detach()
237 | flow_bw_pseudo_label, occ_bw_pseudo_label = flow_bw.clone().detach(), occ_bw.clone().detach()
238 | # run one more training pass with the 'final' data
239 | if self.conf.final_sp_train_weight > 0 and self.conf.sp_input_or_sp_input <= 0:
240 | input_dict_final_sp = {'im1': sp_img1_ori, 'im2': sp_img2_ori, 'if_loss': False,
241 | 'if_final_sp_train': True, 'final_sp_train_style': self.conf.final_sp_train_style,
242 | 'final_fw_label': flow_fw_pseudo_label, 'final_fw_occ': occ_fw_pseudo_label,
243 | 'final_bw_label': flow_bw_pseudo_label, 'final_bw_occ': occ_bw_pseudo_label}
244 | output_dict_final_sp = self.net_work(input_dict_final_sp)
245 | if output_dict_final_sp['final_sp_loss'] is None:
246 | pass
247 | else:
248 | final_sp_loss = output_dict_final_sp['final_sp_loss'].mean() * self.conf.final_sp_train_weight
249 | loss += final_sp_loss
250 | self.final_sp_loss_meter.update(val=final_sp_loss.item(), num=batch_N)
251 | print_str += ' finalsp %.4f(%.4f)' % (self.final_sp_loss_meter.val, self.final_sp_loss_meter.avg)
252 |
253 | s = {'imgs': [sp_img1, sp_img2], 'flows_f': [flow_fw_pseudo_label], 'masks_f': [occ_fw_pseudo_label]}
254 | st_res = self.sp_transform(s)
255 | flow_t, noc_t = st_res['flows_f'][0], st_res['masks_f'][0]
256 | # run 2nd pass spatial transform
257 | im1_crop_st, im2_crop_st = st_res['imgs']
258 | if self.conf.train_sp_msd_loss_weight > 0:
259 | input_dict_sp = {'im1': im1_crop_st, 'im2': im2_crop_st, 'if_loss': False,
260 | 'if_sp_msd_loss': True, 'fw_pseudo_label': flow_t, 'fw_occ_pseudo': noc_t, 'sp_msd_loss_style': self.conf.train_sp_msd_loss_style}
261 | else:
262 | input_dict_sp = {'im1': im1_crop_st, 'im2': im2_crop_st, 'if_loss': False}
263 | output_dict_sp = self.net_work(input_dict_sp)
264 | flow_fw_st, flow_bw_st = output_dict_sp['flow_f_out'], output_dict_sp['flow_b_out']
265 |
266 | if not self.conf.spatial_trans_if_mask:
267 | noc_t = torch.ones_like(noc_t)
268 | if self.conf.spatial_trans_q <= 0:
269 | l_atst = (flow_fw_st - flow_t).abs()
270 | else:
271 | l_atst = ((flow_fw_st - flow_t).abs() + self.conf.spatial_trans_eps) ** self.conf.spatial_trans_q
272 | l_atst = (l_atst * noc_t).mean() / (noc_t.mean() + 1e-6)
273 | l_atst *= self.conf.spatial_trans_loss_weight
274 | loss += l_atst
275 | self.spatial_trans_loss_meter.update(val=l_atst.item(), num=batch_N)
276 | print_str += ' sp %.4f(%.4f)' % (self.spatial_trans_loss_meter.val, self.spatial_trans_loss_meter.avg)
277 |
278 | if output_dict_sp['sp_msd_loss'] is None:
279 | pass
280 | else:
281 | sp_msd_loss = output_dict_sp['sp_msd_loss'].mean() * self.conf.train_sp_msd_loss_weight
282 | loss += sp_msd_loss
283 | self.sp_msd_loss_meter.update(val=sp_msd_loss.item(), num=batch_N)
284 | print_str += ' spmsd %.4f(%.4f)' % (self.sp_msd_loss_meter.val, self.sp_msd_loss_meter.avg)
285 |
286 | loss.backward()
287 | self.optimizer.step()
288 |
289 | # show training
290 | if if_print:
291 | print(print_str)
292 | # show img
293 | if if_show:
294 | if_show_bw = False # whether to show the backward-flow results
295 | # base thing
296 | im1_crop, im2_crop, im1_warp, flow_fw, occ_fw = tools.tensor_gpu(im1_crop, im2_crop, im1_warp, flow_fw, occ_fw, check_on=False)
297 | frame_1_ls += [('im1', im1_crop), ('im1 ', im1_crop), ('flow forward', flow_fw), ('occ forward', occ_fw)]
298 | frame_2_ls += [('im2', im2_crop), ('im1_warp', im1_warp), ('flow forward', flow_fw), ('occ forward', occ_fw)]
299 | if if_show_bw:
300 | im2_warp, flow_bw, occ_bw = tools.tensor_gpu(im2_warp, flow_bw, occ_bw, check_on=False)
301 | frame_1_ls += [('im2_warp', im2_warp), ('flow backward', flow_bw), ('occ backward', occ_bw)]
302 | frame_2_ls += [('im2', im2_crop), ('flow backward', flow_bw), ('occ backward', occ_bw)]
303 |
304 | # ============================ some models need extra operations =====================
305 | if self.conf.model_name.lower() == '加油':
306 | pass
307 | elif self.conf.model_name.lower() in ['pwcirrbiv5', ]:
308 | pass
309 | elif self.conf.model_name.lower() in ['pwcirrbiv5_v5', ]:
310 | fw_im1 = output_dict['im1_warp_ss']
311 | fw_im2 = output_dict['im2_warp_ss']
312 | fw_im1_, fw_im2_ = tools.tensor_gpu(fw_im1.clone().detach(), fw_im2.clone().detach(), check_on=False)
313 | frame_1_ls += [('fw_im1_', fw_im1_), ('fw_im2_', fw_im2_), ]
314 | frame_2_ls += [('fw_im1_', fw_im1_), ('fw_im2_', fw_im2_), ]
315 | else:
316 | pass # no operation
317 | # raise ValueError(' not implemented model name: %s' % self.conf.model_name)
318 |
319 | self.training_shower.get_batch_pair_all_list_nchw_check_flow_frame1_frame2_gif(batch_dict_ls_frame1=frame_1_ls, batch_dict_ls_frame2=frame_2_ls,
320 | name='iter_%s_%s' % (batch_step, print_str))
321 | self.training_shower.put_frame1_frame2_gif(name='Epoch %d Iteration %d ' % (self.epoch, batch_step))
322 | # compute data time
323 | self.data_clock.start()
324 |
325 | @classmethod
326 | def save_image_v2(cls, tensor_data, name, save_dir_dir, mask_or_flow_or_image='image', if_flow_data_save_png=False):
327 | def decom(a):
328 | b = tools.tensor_gpu(a, check_on=False)[0]
329 | c = b[0, :, :, :]
330 | c = np.transpose(c, (1, 2, 0))
331 | return c
332 |
333 | if mask_or_flow_or_image == 'flow':
334 | flow_f_np = decom(tensor_data)
335 | if flow_f_np.shape[2] == 2:
336 | cv2.imwrite(os.path.join(save_dir_dir, name + '_s.png'), tools.flow_to_image(flow_f_np)[:, :, ::-1])
337 | if if_flow_data_save_png:
338 | save_path = os.path.join(save_dir_dir, name + '.png')
339 | tools.write_kitti_png_file(save_path, flow_f_np)
340 | elif flow_f_np.shape[2] == 3:
341 | flow = flow_f_np[:, :, :2]
342 | mask = flow_f_np[:, :, 2]
343 | cv2.imwrite(os.path.join(save_dir_dir, name + '_s.png'), tools.flow_to_image(flow)[:, :, ::-1])
344 | if if_flow_data_save_png:
345 | save_path = os.path.join(save_dir_dir, name + '.png')
346 | tools.write_kitti_png_file(save_path, flow, mask_data=mask)
347 | else:
348 | raise ValueError('flow_f_np shape not right: %s' % flow_f_np.shape)
349 | elif mask_or_flow_or_image == 'mask':
350 | mask = decom(tensor_data)
351 | cv2.imwrite(os.path.join(save_dir_dir, name + '.png'), tools.Show_GIF.im_norm(mask * 255))
352 | elif mask_or_flow_or_image == 'image':
353 | img1_np = tools.Show_GIF.im_norm(decom(tensor_data))
354 | # img1_np = decom(tensor_data)
355 | cv2.imwrite(os.path.join(save_dir_dir, name + '.png'), img1_np[:, :, ::-1])
356 | else:
357 | raise ValueError('wrong data type: %s' % mask_or_flow_or_image)
358 |
359 | def eval_forward(self, im1, im2, flow, *args):
360 | # ==================================================================== network output ======================================================================
361 | with torch.no_grad():
362 | if self.conf.model_name.lower() == '加油':
363 | flow_fw, flow_bw, app_flow_1, app_flow_2, _, _ = self.net_work(im1, im2) # flow from im1->im2
364 | pred_flow = flow_fw
365 | elif self.conf.model_name.lower() in ['pwcirrbiv5_v4', 'pwcirrbiv5_v5']:
366 | input_dict = {'im1': im1, 'im2': im2, 'if_loss': False, 'if_test': True}
367 | output_dict = self.net_work(input_dict)
368 | flow_fw, flow_bw = output_dict['flow_f_out'], output_dict['flow_b_out']
369 | flows = output_dict['flows']
370 | pred_flow = flow_fw
371 | elif self.conf.model_name.lower() in ['pwcirrbiv5_v4_show', 'pwcirrbiv5_v5_show']:
372 | if self.conf.save_running_process:
373 | running_process_dir = os.path.join(self.training_shower.save_dir, 'running_process')
374 | sample_dir = os.path.join(running_process_dir, '%s' % self.cnt)
375 | tools.check_dir(sample_dir)
376 | input_dict = {'im1': im1, 'im2': im2, 'if_loss': False, 'if_test': True, 'save_running_process': True, 'process_dir': sample_dir}
377 | else:
378 | input_dict = {'im1': im1, 'im2': im2, 'if_loss': False, 'if_test': True}
379 | output_dict = self.net_work(input_dict)
380 | flow_fw, flow_bw = output_dict['flow_f_out'], output_dict['flow_b_out']
381 | flows = output_dict['flows']
382 | pred_flow = flow_fw
383 | # print('======')
384 | # tools.check_tensor(flow, 'gt flow')
385 | # tools.check_tensor(flow_fw, 'output_flow_fw')
386 | # for i, (fw, fb) in enumerate(flows):
387 | # tools.check_tensor(fw, '%s scale fw' % i)
388 | # print('======')
389 | if len(args) > 0 and self.conf.save_running_process:
390 | def decom(a):
391 | b = tools.tensor_gpu(a, check_on=False)[0]
392 | c = b[0, :, :, :]
393 | c = np.transpose(c, (1, 2, 0))
394 | return c
395 |
396 | occ_mask = args[0]
397 | # save gt occ_mask
398 | gt_occ_mask_np = decom(occ_mask)
399 | gt_flow_np = decom(flow)
400 | pred_flow_np = decom(pred_flow)
401 | # save gt_flow
402 | cv2.imwrite(os.path.join(sample_dir, 'gt_flow_np' + '.png'), tools.flow_to_image(gt_flow_np)[:, :, ::-1])
403 | # save pred flow_f
404 | cv2.imwrite(os.path.join(sample_dir, 'pred_flow_f' + '.png'), tools.flow_to_image(pred_flow_np)[:, :, ::-1])
405 |
406 | # show flow gt error image
407 | flow_error_image = tools.lib_to_show_flow.flow_error_image_np(pred_flow_np, gt_flow_np, gt_occ_mask_np)
408 | # print('flow_error_image', np.max(flow_error_image), np.min(flow_error_image))
409 | cv2.imwrite(os.path.join(sample_dir, 'flow_error_image' + '.png'), tools.Show_GIF.im_norm(flow_error_image))
410 | flow_error_image_gray = tools.lib_to_show_flow.flow_error_image_np(pred_flow_np, gt_flow_np, gt_occ_mask_np, log_colors=False)
411 | cv2.imwrite(os.path.join(sample_dir, 'flow_error_image_gray' + '.png'), tools.Show_GIF.im_norm(flow_error_image_gray))
412 | if self.conf.model_name.lower() == 'pwcirrbiv5_show_v3': # save some results
413 | occmask, noc_gt_flow, nocmask = args
414 | save_dir = os.path.join(self.training_shower.save_dir, 'saving_res')
415 | tools.check_dir(save_dir)
416 | dir_name = '%s' % self.cnt # + '_occ_%.3f_'%occ_value.item()
417 | save_dir_dir = os.path.join(save_dir, dir_name)
418 | tools.check_dir(save_dir_dir)
419 | occ_fw = output_dict['occ_fw']
420 | self.save_image_v2(flow_fw, 'flow_f', save_dir_dir, 'flow', True)
421 | self.save_image_v2(flow, 'gt', save_dir_dir, 'flow', True)
422 | self.save_image_v2(noc_gt_flow, 'noc_gt_flow', save_dir_dir, 'flow', True)
423 | self.save_image_v2(occmask, 'gt_occ_mask', save_dir_dir, 'mask', False)
424 | self.save_image_v2(nocmask, 'gt_noc_mask', save_dir_dir, 'mask', False)
425 | self.save_image_v2(im1, 'im1', save_dir_dir, 'image', False)
426 | self.save_image_v2(im2, 'im2', save_dir_dir, 'image', False)
427 | self.save_image_v2(occ_fw, 'occ_mask', save_dir_dir, 'mask', False)
428 | else:
429 | raise ValueError(' not implemented model name: %s' % self.conf.model_name)
430 | if self.conf.if_do_eval: # controls whether GIF results are shown during testing or evaluation
431 | self.cnt += 1
432 | if self.conf.if_do_test:
433 | im1_warp = tools.torch_warp(im2, pred_flow)
434 | warp_error = torch.sqrt((im1_warp - im1) ** 2)
435 | warp_error = warp_error.mean()
436 | print_str = 'iter_%s warp_error%.5f' % (self.cnt, warp_error.item())
437 | self.print_str = print_str
438 | if self.conf.if_test_save_show_results:
439 | im1_np, im2_np, pred_flow_np, im1_warp_np = tools.tensor_gpu(im1, im2, pred_flow, im1_warp, check_on=False)
440 | frame_1_ls = [('im1', im1_np), ('im1 ', im1_np), ('im1_warp', im1_warp_np), ('flow pred', pred_flow_np)]
441 | frame_2_ls = [('im2', im2_np), ('im1_warp', im1_warp_np), ('im1_warp', im1_warp_np), ('flow pred', pred_flow_np)]
442 | self.training_shower.get_batch_pair_all_list_nchw_check_flow_frame1_frame2_gif(batch_dict_ls_frame1=frame_1_ls, batch_dict_ls_frame2=frame_2_ls,
443 | name=print_str)
444 | else:
445 | im1_warp = tools.torch_warp(im2, pred_flow)
446 | im1_gt_warp = tools.torch_warp(im2, flow)
447 |
448 | warp_error = torch.sqrt((im1_warp - im1) ** 2)
449 | warp_error = warp_error.mean()
450 |
451 | gt_warp_error = torch.sqrt((im1_warp - im1_gt_warp) ** 2)
452 | gt_warp_error = gt_warp_error.mean()
453 | print_str = 'iter_%s warp_error%.5f gtwarperror_%.5f' % (self.cnt, warp_error.item(), gt_warp_error.item())
454 | self.print_str = print_str
455 | if self.conf.if_do_eval_save_show_result:
456 | im1_np, im2_np, gt_flow_np, pred_flow_np, im1_warp_np, im1_gt_warp_np = tools.tensor_gpu(im1, im2, flow, pred_flow, im1_warp, im1_gt_warp, check_on=False)
457 | frame_1_ls = [('im1', im1_np), ('im1', im1_np), ('gt_warp_im1', im1_gt_warp_np), ('flow pred', pred_flow_np)]
458 | frame_2_ls = [('im2', im2_np), ('im1_warp', im1_warp_np), ('im1_warp', im1_warp_np), ('flow gt', gt_flow_np)]
459 | self.training_shower.get_batch_pair_all_list_nchw_check_flow_frame1_frame2_gif(batch_dict_ls_frame1=frame_1_ls, batch_dict_ls_frame2=frame_2_ls,
460 | name=print_str)
461 | if self.conf.if_save_flow_in_eval_or_test: # controls whether flow results are saved, as .png or .flo
462 | self.temp_save_eval_test = [pred_flow, flow] # cache for eval_save_result
463 | if self.conf.multi_scale_eval:
464 | return pred_flow, flows
465 | return pred_flow
466 |
467 | def eval_save_result(self, save_name, *args, **kwargs):
468 | def flow_tensor_np_h_w_2(a):
469 | a_np = tools.tensor_gpu(a, check_on=False)[0]
470 | a_np = a_np[0, :, :, :] # n,c,h,w
471 | a_np = np.transpose(a_np, (1, 2, 0)) # h,w,2
472 | return a_np
473 |
474 | if self.conf.if_do_eval:
475 | if self.conf.if_do_eval_print:
476 | print(self.print_str + ' ' + save_name)
477 | if self.conf.if_do_test:
478 | if self.conf.if_test_save_show_results:
479 | if len(args) > 0:
480 | sample_dir_name = args[0]
481 | self.training_shower.put_frame1_frame2_gif(name=sample_dir_name + '_' + save_name + '_' + self.print_str)
482 | else:
483 | self.training_shower.put_frame1_frame2_gif(name=save_name + '_' + self.print_str)
484 | else:
485 | if self.conf.if_do_eval_save_show_result:
486 | self.training_shower.put_frame1_frame2_gif(name=save_name + '_' + self.print_str)
487 |
488 | if self.conf.if_save_flow_in_eval_or_test:
489 | if self.conf.if_do_test: # test
490 | save_dir = os.path.join(self.training_shower.save_dir, 'save_test_flow')
491 | tools.check_dir(save_dir)
492 | if len(args) > 0:
493 | sample_dir_name = args[0]
494 | if type(sample_dir_name) == str:
495 | save_dir = os.path.join(save_dir, sample_dir_name)
496 | tools.check_dir(save_dir)
497 | if self.conf.if_save_flow_in_eval_or_test_type == 'png':
498 | pred_flow, _ = self.temp_save_eval_test
499 | pred_flow_np = flow_tensor_np_h_w_2(pred_flow)
500 | save_path = os.path.join(save_dir, save_name + '.png')
501 | # saving this way works on KITTI 2015, but the same method does not work on 2012
502 | tools.write_kitti_png_file(save_path, pred_flow_np)
503 | # an alternative way to try:
504 | # tools.lib_to_show_flow.flow_write_png( u=pred_flow_np[:,:,0], v=pred_flow_np[:,:,1], fpath=save_path)
505 | elif self.conf.if_save_flow_in_eval_or_test_type == 'flo':
506 | pred_flow, _ = self.temp_save_eval_test
507 | pred_flow_np = flow_tensor_np_h_w_2(pred_flow)
508 | save_path = os.path.join(save_dir, save_name + '.flo')
509 | tools.write_flo(flow=pred_flow_np, filename=save_path) # write_flow, or, write_flo
510 | else:
511 | raise ValueError('wrong if_save_flow_eval_test_type, should be png or flo, but got: %s' % self.conf.if_save_flow_in_eval_or_test_type)
512 | else: # eval
513 | save_dir = os.path.join(self.training_shower.save_dir, 'eval_test')
514 | tools.check_dir(save_dir)
515 | if len(args) > 0:
516 | sample_dir_name = args[0]
517 | if type(sample_dir_name) == str:
518 | save_dir = os.path.join(save_dir, sample_dir_name)
519 | tools.check_dir(save_dir)
520 | if self.conf.if_save_flow_in_eval_or_test_type == 'png':
521 | pred_flow, gt_flow = self.temp_save_eval_test
522 | pred_flow_np = flow_tensor_np_h_w_2(pred_flow)
523 | save_path = os.path.join(save_dir, save_name + '.png')
524 | tools.write_kitti_png_file(save_path, pred_flow_np)
525 | # tools.WriteKittiPngFile(save_path, pred_flow_np)
526 | # save gt
527 | gt_save_path = os.path.join(save_dir, save_name + '_gt.png')
528 | tools.write_kitti_png_file(gt_save_path, flow_tensor_np_h_w_2(gt_flow))
529 | elif self.conf.if_save_flow_in_eval_or_test_type == 'flo':
530 | pred_flow, gt_flow = self.temp_save_eval_test
531 | pred_flow_np = flow_tensor_np_h_w_2(pred_flow)
532 | save_path = os.path.join(save_dir, save_name + '.flo')
533 | tools.write_flo(flow=pred_flow_np, filename=save_path) # write_flow, or, write_flo
534 | gt_save_path = os.path.join(save_dir, save_name + '_gt.flo')
535 | gt_flow_np = flow_tensor_np_h_w_2(gt_flow)
536 | tools.write_flo(flow=gt_flow_np, filename=gt_save_path) # write_flow, or, write_flo
537 | if_check = True
538 | if if_check:
539 | temp = tools.read_flo(save_path) # read_flow, or, read_flo
540 | temp_gt = tools.read_flo(gt_save_path) # read_flow, or, read_flo
541 | temp_error = temp - pred_flow_np
542 | temp_gt_error = temp_gt - gt_flow_np
543 | print('pred flow .flo read/write error: ', np.min(temp_error), np.max(temp_error), 'gt read/write error: ', np.min(temp_gt_error), np.max(temp_gt_error))
544 | else:
545 | raise ValueError('wrong if_save_flow_eval_test_type, should be png or flo, but got: %s' % self.conf.if_save_flow_in_eval_or_test_type)
546 |
547 | def train(self, epoch=0): # switch to training mode
548 | # torch.cuda.empty_cache()
549 | if hasattr(torch.cuda, 'empty_cache'):
550 | torch.cuda.empty_cache()
551 | torch.set_grad_enabled(True)
552 | self.net_work.train()
553 | self.app_loss_meter.reset()
554 | self.appd_loss_meter.reset()
555 | self.msd_loss_meter.reset()
556 | self.sp_msd_loss_meter.reset()
557 | self.final_sp_loss_meter.reset()
558 | self.occ_loss_meter.reset()
559 | self.spatial_trans_loss_meter.reset()
560 | self.photo_loss_meter.reset()
561 | self.smooth_loss_meter.reset()
562 | self.inpaint_img_loss_meter.reset()
563 | self.multi_scale_loss_meter.reset()
564 | if epoch % 1 == 0: # step the scheduler every epoch ('% 1' looks like a placeholder for a configurable interval)
565 | self.scheduler.step()
566 | print('epoch', epoch, 'lr={:.6f}'.format(self.scheduler.get_lr()[0]))
567 | self.epoch = epoch
568 |
569 |
570 | class Train_Config(tools.abstract_config):
571 |
572 | def __init__(self, **kwargs):
573 | self.batchsize = 4
574 | self.gpu_opt = None # gpu option
575 | self.n_epoch = 1000 # number of epoch
576 | self.if_eval = True # do evaluation during the training process
577 | self.train_data_name = 'kitti_2015_mv' # or kitti_2012_mv
578 | self.eval_data_name = '2015_train' # or 2012_train
579 | self.eval_per = -1 # do evaluation every N iters
580 | self.eval_batchsize = 1 # batch size for evaluation
581 | self.use_prefether = True # faster loader
582 | self.if_histmatch = False # do not use this
583 | self.update(kwargs)
584 |
585 |
586 | class Training():
587 | def __init__(self, **kwargs):
588 | self.conf = Train_Config(**kwargs)
589 | self.data_conf = self.get_train_data(**kwargs)
590 |
591 | def get_train_data(self, **kwargs):
592 | '''
593 | get dataset
594 | data config = {
595 | 'crop_size': (256, 832),
596 | 'rho': 8,
597 | 'swap_images': True,
598 | 'normalize': True,
599 | 'horizontal_flip_aug': True,
600 | }
601 | '''
602 | if self.conf.train_data_name == 'kitti_2015_mv':
603 | data_conf = kitti_train.kitti_data_with_start_point.config(mv_type='2015', **kwargs)
604 | elif self.conf.train_data_name == 'kitti_2012_mv':
605 | data_conf = kitti_train.kitti_data_with_start_point.config(mv_type='2012', **kwargs)
606 | else:
607 | raise ValueError('not implemented train data: %s' % self.conf.train_data_name)
608 | return data_conf
609 |
610 | def get_network(self):
611 | pass
612 |
613 | def get_eval_benchmark(self):
614 | pass
615 |
616 | def do_training(self):
617 | pass
618 |
619 |
620 | param_dict = {
621 | # training
622 | 'batchsize': 4,
623 | 'gpu_opt': None,
624 | 'n_epoch': 1000,
625 | 'if_eval': True,
626 | 'train_data_name': 'kitti_2015_mv',
627 | 'eval_data_name': '2015_train', # or 2012_train
628 | 'eval_per': -1, # do evaluation every N iters
629 | 'eval_batchsize': 1, # evaluation batch size
630 | 'use_prefether': True, # a bit faster; works better
631 |
632 | # data
633 | 'crop_size': (256, 832),
634 | 'rho': 8,
635 | 'swap_images': True,
636 | 'normalize': True,
637 | 'horizontal_flip_aug': True,
638 |
639 | # network
640 |
641 |
642 | }
643 |
--------------------------------------------------------------------------------
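Note: eval_save_result above round-trips flow files through tools.write_flo and tools.read_flo from utils/tools.py (not shown in this dump). For reference, a minimal sketch of the standard Middlebury .flo layout those helpers presumably follow; write_flo_sketch is our name, not the repo's:

import numpy as np

def write_flo_sketch(flow, filename):
    # flow: (H, W, 2) float32 array with u in channel 0 and v in channel 1
    h, w = flow.shape[:2]
    with open(filename, 'wb') as f:
        np.float32(202021.25).tofile(f)    # .flo magic number (the bytes spell 'PIEH')
        np.int32(w).tofile(f)              # width, then height, as int32
        np.int32(h).tofile(f)
        flow.astype(np.float32).tofile(f)  # interleaved (u, v) per pixel, row-major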
/scripts/simple_train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from utils.tools import tools
4 | import cv2
5 | import numpy as np
6 | from copy import deepcopy
7 | import torch
8 | import warnings # ignore warnings
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from dataset.kitti_dataset import kitti_train, kitti_flow
12 | from model.upflow import UPFlow_net
13 | from torch.utils.data import DataLoader
14 | import time
15 |
16 | ''' scripts for training:
17 | 1. simply using photo loss and smooth loss
18 | 2. add occlusion checking
19 | 3. add teacher-student loss (ARFlow)
20 | '''
21 |
22 | # save and log loss value during training
23 | class Loss_manager():
24 | def __init__(self):
25 | self.error_meter = tools.Avg_meter_ls()
26 |
27 | def fetch_loss(self, loss, loss_dict, name, batch_N, short_name=None):
28 | if name not in loss_dict.keys():
29 | pass
30 | elif loss_dict[name] is None:
31 | pass
32 | else:
33 | this_loss = loss_dict[name].mean()
34 | self.error_meter.update(name=name, val=this_loss.item(), num=batch_N, short_name=short_name)
35 | loss = loss + this_loss
36 | return loss
37 |
38 | def prepare_epoch(self):
39 | self.error_meter.reset()
40 |
41 | def log_info(self):
42 | p_str = self.error_meter.print_all_losses()
43 | return p_str
44 |
45 | def compute_loss(self, loss_dict, batch_N):
46 | loss = 0
47 | loss = self.fetch_loss(loss=loss, loss_dict=loss_dict, name='photo_loss', short_name='ph', batch_N=batch_N)
48 | loss = self.fetch_loss(loss=loss, loss_dict=loss_dict, name='smooth_loss', short_name='sm', batch_N=batch_N)
49 | loss = self.fetch_loss(loss=loss, loss_dict=loss_dict, name='census_loss', short_name='cen', batch_N=batch_N)
50 | # photo_loss, smooth_loss, census_loss = output_dict['photo_loss'].mean(), output_dict['smooth_loss'], output_dict['census_loss']
51 | loss = self.fetch_loss(loss=loss, loss_dict=loss_dict, name='msd_loss', short_name='msd', batch_N=batch_N)
52 | loss = self.fetch_loss(loss=loss, loss_dict=loss_dict, name='eq_loss', short_name='eq', batch_N=batch_N)
53 | loss = self.fetch_loss(loss=loss, loss_dict=loss_dict, name='oi_loss', short_name='oi', batch_N=batch_N)
54 | return loss
55 |
56 | class Eval_model(tools.abs_test_model):
57 | def __init__(self):
58 | super(Eval_model, self).__init__()
59 | self.net_work = None
60 |
61 | def eval_forward(self, im1, im2, gt, *args):
62 | if self.net_work is None:
63 | raise ValueError('not network for evaluation')
64 | # === network output
65 | with torch.no_grad():
66 | input_dict = {'im1': im1, 'im2': im2, 'if_loss': False}
67 | output_dict = self.net_work(input_dict)
68 | flow_fw, flow_bw = output_dict['flow_f_out'], output_dict['flow_b_out']
69 | pred_flow = flow_fw
70 | return pred_flow
71 |
72 | def eval_save_result(self, save_name, predflow, *args, **kwargs):
73 | # you can save flow results here
74 | # print(save_name)
75 | pass
76 |
77 | def change_model(self, net):
78 | net.eval()
79 | self.net_work = net
80 |
81 |
82 | class Trainer():
83 | class Config(tools.abstract_config):
84 | def __init__(self, **kwargs):
85 | self.exp_dir = './demo_exp'
86 | self.if_cuda = True
87 |
88 | self.batchsize = 2
89 | self.NUM_WORKERS = 4
90 | self.n_epoch = 1000
91 | self.batch_per_epoch = 500
92 | self.batch_per_print = 20
93 | self.lr = 1e-4
94 | self.weight_decay = 1e-4
95 | self.scheduler_gamma = 1
96 |
97 | # init
98 | self.update(kwargs)
99 |
100 | def __call__(self, ):
101 | t = Trainer(self)
102 | return t
103 |
104 | def __init__(self, conf: Config):
105 | self.conf = conf
106 |
107 | tools.check_dir(self.conf.exp_dir)
108 |
109 | # load network
110 | self.net = self.load_model()
111 |
112 | # for evaluation
113 | self.bench = self.load_eval_bench()
114 | self.eval_model = Eval_model()
115 |
116 | # load training dataset
117 | self.train_set = self.load_training_dataset()
118 |
119 | def training(self):
120 | train_loader = tools.data_prefetcher(self.train_set, batch_size=self.conf.batchsize, shuffle=True, num_workers=self.conf.NUM_WORKERS, pin_memory=True, drop_last=True)
121 | optimizer = optim.Adam(self.net.parameters(), lr=self.conf.lr, amsgrad=True, weight_decay=self.conf.weight_decay)
122 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.conf.scheduler_gamma)
123 | loss_manager = Loss_manager()
124 | timer = tools.time_clock()
125 | print("start training" + '=' * 10)
126 | i_batch = 0
127 | epoch = 0
128 | loss_manager.prepare_epoch()
129 | current_val, best_val, best_epoch = 0, 0, 0
130 | timer.start()
131 | while True:
132 | # prepare batch data
133 | batch_value = train_loader.next()
134 | if batch_value is None: # the prefetcher returns None when exhausted; fetch again to restart a pass
135 | batch_value = train_loader.next()
136 | assert batch_value is not None
137 | batchsize = batch_value['im1'].shape[0] # check if the im1 exists
138 | i_batch += 1
139 | # train batch
140 | self.net.train()
141 | optimizer.zero_grad()
142 | out_data = self.net(batch_value) #
143 |
144 | loss_dict = out_data['loss_dict']
145 | loss = loss_manager.compute_loss(loss_dict=loss_dict, batch_N=batchsize)
146 |
147 | loss.backward()
148 | optimizer.step()
149 | if i_batch % self.conf.batch_per_print == 0:
150 | pass
151 | if i_batch % self.conf.batch_per_epoch == 0:
152 | # TODO: do evaluation here and save the model if it improves
153 | epoch += 1
154 | timer.end()
155 | print(' === epoch use time %.2f' % timer.get_during())
156 | scheduler.step(epoch=epoch)
157 | timer.start()
158 |
159 | def evaluation(self):
160 | self.eval_model.change_model(self.net)
161 | epe_all, f1, epe_noc, epe_occ = self.bench(self.eval_model)
162 | print('EPE All = %.2f, F1 = %.2f, EPE Noc = %.2f, EPE Occ = %.2f' % (epe_all, f1, epe_noc, epe_occ))
163 | print_str = 'EPE_%.2f__F1_%.2f__Noc_%.2f__Occ_%.2f' % (epe_all, f1, epe_noc, epe_occ)
164 | return epe_all, print_str
165 |
166 | # ======
167 | def load_model(self):
168 | param_dict = {
169 | # use cost volume norm
170 | 'if_norm_before_cost_volume': True,
171 | 'norm_moments_across_channels': False,
172 | 'norm_moments_across_images': False,
173 | 'if_froze_pwc': False,
174 | 'if_use_cor_pytorch': False, # the pytorch correlation is very slow; only for debugging when the CUDA correlation is not compiled
175 | 'if_sgu_upsample': False, # turn this off first to get the pipeline running
176 | }
177 | pretrain_path = None # pretrain path
178 | net_conf = UPFlow_net.config()
179 | net_conf.update(param_dict)
180 | net = net_conf() # .cuda()
181 | if pretrain_path is not None:
182 | net.load_model(pretrain_path, if_relax=True, if_print=False)
183 | if self.conf.if_cuda:
184 | net = net.cuda()
185 | return net
186 |
187 | def load_eval_bench(self):
188 | bench = kitti_flow.Evaluation_bench(name='2015_train', if_gpu=self.conf.if_cuda, batch_size=1)
189 | return bench
190 |
191 | def load_training_dataset(self):
192 | data_config = {
193 | 'crop_size': (256, 832),
194 | 'rho': 8,
195 | 'swap_images': True,
196 | 'normalize': True,
197 | 'horizontal_flip_aug': True,
198 | }
199 | data_conf = kitti_train.kitti_data_with_start_point.config(mv_type='2015', **data_config)
200 | dataset = data_conf()
201 | return dataset
202 |
203 |
204 | if __name__ == '__main__':
205 | training_param = {} # change param here
206 | conf = Trainer.Config(**training_param)
207 | trainer = conf()
208 | trainer.training()
209 |
--------------------------------------------------------------------------------
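Note: Loss_manager.compute_loss above silently skips any loss that is missing from, or None in, the network's loss_dict, which is what lets the same training loop run with different loss terms switched on or off. A minimal illustration with stand-in values (assumes the repo and its dependencies are importable; the tensors are fake, not real network losses):

import torch
from scripts.simple_train import Loss_manager

manager = Loss_manager()
manager.prepare_epoch()
loss_dict = {'photo_loss': torch.tensor(0.5), 'smooth_loss': torch.tensor(0.1), 'census_loss': None}
total = manager.compute_loss(loss_dict=loss_dict, batch_N=2)  # census/msd/eq/oi terms are skipped
print(total)  # tensor(0.6000)
print(manager.log_info())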
/scripts/upflow_kitti2015.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/scripts/upflow_kitti2015.pth
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from utils.tools import tools
4 | import cv2
5 | import numpy as np
6 | from copy import deepcopy
7 | import torch
8 | import warnings # ignore warnings
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from dataset.kitti_dataset import kitti_train, kitti_flow
12 | from model.upflow import UPFlow_net
13 | from torch.utils.data import DataLoader
14 | import time
15 |
16 | if_cuda = True
17 |
18 |
19 | class Test_model(tools.abs_test_model):
20 | def __init__(self, pretrain_path='./scripts/upflow_kitti2015.pth'):
21 | super(Test_model, self).__init__()
22 | param_dict = {
23 | # use cost volume norm
24 | 'if_norm_before_cost_volume': True,
25 | 'norm_moments_across_channels': False,
26 | 'norm_moments_across_images': False,
27 | 'if_froze_pwc': False,
28 | 'if_use_cor_pytorch': False, # the pytorch correlation is very slow; only for debugging when the CUDA correlation is not compiled
29 | 'if_sgu_upsample': True,
30 | }
31 | net_conf = UPFlow_net.config()
32 | net_conf.update(param_dict)
33 | net = net_conf() # .cuda()
34 | net.load_model(pretrain_path, if_relax=True, if_print=True)
35 | if if_cuda:
36 | net = net.cuda()
37 | net.eval()
38 | self.net_work = net
39 |
40 | def eval_forward(self, im1, im2, gt, *args):
41 | # === network output
42 | with torch.no_grad():
43 | input_dict = {'im1': im1, 'im2': im2, 'if_loss': False}
44 | output_dict = self.net_work(input_dict)
45 | flow_fw, flow_bw = output_dict['flow_f_out'], output_dict['flow_b_out']
46 | pred_flow = flow_fw
47 | return pred_flow
48 |
49 | def eval_save_result(self, save_name, predflow, *args, **kwargs):
50 | # you can save flow results here
51 | print(save_name)
52 |
53 |
54 | def kitti_2015_test():
55 | pretrain_path = './scripts/upflow_kitti2015.pth'
56 | # note that the eval batch size should be 1 for KITTI 2012 and KITTI 2015 (image sizes may differ between sequences)
57 | bench = kitti_flow.Evaluation_bench(name='2015_train', if_gpu=if_cuda, batch_size=1)
58 | testmodel = Test_model(pretrain_path=pretrain_path)
59 | epe_all, f1, epe_noc, epe_occ = bench(testmodel)
60 | print('EPE All = %.2f, F1 = %.2f, EPE Noc = %.2f, EPE Occ = %.2f' % (epe_all, f1, epe_noc, epe_occ))
61 |
62 |
63 | if __name__ == '__main__':
64 | kitti_2015_test()
65 |
--------------------------------------------------------------------------------
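Note: the occlusion masks (occ_fw, occ_bw) used throughout these scripts come from tools.occ_check_model with occ_type='for_back_check' (utils/tools.py, not shown in this dump). For reference, a minimal sketch of the usual forward-backward consistency check that name suggests, with thresholds taken from the model demo config (alpha_1=0.1, alpha_2=0.5); the function name and the torch_warp convention (sample the first argument at x + flow) are assumptions:

import torch
from utils.tools import tools  # provides torch_warp in this repo

def forward_backward_occ_sketch(flow_fw, flow_bw, alpha_1=0.1, alpha_2=0.5):
    # bring the backward flow into frame 1 and test cycle consistency
    flow_bw_warped = tools.torch_warp(flow_bw, flow_fw)
    flow_diff = flow_fw + flow_bw_warped
    mag = (flow_fw ** 2 + flow_bw_warped ** 2).sum(dim=1, keepdim=True)
    occ = (flow_diff ** 2).sum(dim=1, keepdim=True) > alpha_1 * mag + alpha_2
    return (~occ).float()  # 1 = non-occluded, 0 = occluded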
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/loss.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/utils/__pycache__/loss.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/pytorch_correlation.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/utils/__pycache__/pytorch_correlation.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tools.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coolbeam/UPFlow_pytorch/54c26608b31c548212f337300c2c96789bfa594f/utils/__pycache__/tools.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 20-3-8 5:53 PM
3 | import torch
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from utils.tools import tools
7 |
8 |
9 | def upsample2d_as(inputs, target_as, mode="bilinear"):
10 | _, _, h, w = target_as.size()
11 | return F.interpolate(inputs, [h, w], mode=mode, align_corners=True)
12 |
13 |
14 | class loss_functions():
15 |
16 | @classmethod
17 | def photo_loss_function(cls, diff, mask, q, charbonnier_or_abs_robust, if_use_occ, averge=True):
18 | if charbonnier_or_abs_robust:
19 | if if_use_occ:
20 | p = ((diff) ** 2 + 1e-6).pow(q)
21 | p = p * mask
22 | if averge:
23 | p = p.mean()
24 | ap = mask.mean()
25 | else:
26 | p = p.sum()
27 | ap = mask.sum()
28 | loss_mean = p / (ap * 2 + 1e-6)
29 | else:
30 | p = ((diff) ** 2 + 1e-8).pow(q)
31 | if averge:
32 | p = p.mean()
33 | else:
34 | p = p.sum()
35 | return p
36 | else:
37 | if if_use_occ:
38 | diff = (torch.abs(diff) + 0.01).pow(q)
39 | diff = diff * mask
40 | diff_sum = torch.sum(diff)
41 | loss_mean = diff_sum / (torch.sum(mask) * 2 + 1e-6)
42 | else:
43 | diff = (torch.abs(diff) + 0.01).pow(q)
44 | if averge:
45 | loss_mean = diff.mean()
46 | else:
47 | loss_mean = diff.sum()
48 | return loss_mean
49 |
50 | @classmethod
51 | def census_loss_torch(cls, img1, img1_warp, mask, q, charbonnier_or_abs_robust, if_use_occ, averge=True, max_distance=3):
52 | patch_size = 2 * max_distance + 1
53 |
54 | def _ternary_transform_torch(image):
55 | R, G, B = torch.split(image, 1, 1)
56 | intensities_torch = (0.2989 * R + 0.5870 * G + 0.1140 * B) # * 255 # convert to gray
57 | # intensities = tf.image.rgb_to_grayscale(image) * 255
58 | out_channels = patch_size * patch_size
59 | w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels)) # h,w,1,out_c
60 | w_ = np.transpose(w, (3, 2, 0, 1)) # 1,out_c,h,w
61 | weight = torch.from_numpy(w_).float()
62 | if image.is_cuda:
63 | weight = weight.cuda()
64 | patches_torch = torch.conv2d(input=intensities_torch, weight=weight, bias=None, stride=[1, 1], padding=[max_distance, max_distance])
65 | transf_torch = patches_torch - intensities_torch
66 | transf_norm_torch = transf_torch / torch.sqrt(0.81 + transf_torch ** 2)
67 | return transf_norm_torch
68 |
69 | def _hamming_distance_torch(t1, t2):
70 | dist = (t1 - t2) ** 2
71 | dist = torch.sum(dist / (0.1 + dist), 1, keepdim=True)
72 | return dist
73 |
74 | def create_mask_torch(tensor, paddings):
75 | shape = tensor.shape # N,c, H,W
76 | inner_width = shape[2] - (paddings[0][0] + paddings[0][1])
77 | inner_height = shape[3] - (paddings[1][0] + paddings[1][1])
78 | inner_torch = torch.ones([shape[0], shape[1], inner_width, inner_height]).float()
79 | if tensor.is_cuda:
80 | inner_torch = inner_torch.cuda()
81 | mask2d = F.pad(inner_torch, [paddings[0][0], paddings[0][1], paddings[1][0], paddings[1][1]])
82 | return mask2d
83 |
84 | img1 = _ternary_transform_torch(img1)
85 | img1_warp = _ternary_transform_torch(img1_warp)
86 | dist = _hamming_distance_torch(img1, img1_warp)
87 | transform_mask = create_mask_torch(mask, [[max_distance, max_distance],
88 | [max_distance, max_distance]])
89 | census_loss = cls.photo_loss_function(diff=dist, mask=mask * transform_mask, q=q,
90 | charbonnier_or_abs_robust=charbonnier_or_abs_robust, if_use_occ=if_use_occ, averge=averge)
91 | return census_loss
92 |
93 | @classmethod
94 | def flow_smooth_delta(cls, flow, if_second_order=False):
95 | def gradient(x):
96 | D_dy = x[:, :, 1:] - x[:, :, :-1]
97 | D_dx = x[:, :, :, 1:] - x[:, :, :, :-1]
98 | return D_dx, D_dy
99 |
100 | dx, dy = gradient(flow)
101 | # dx2, dxdy = gradient(dx)
102 | # dydx, dy2 = gradient(dy)
103 | if if_second_order:
104 | dx2, dxdy = gradient(dx)
105 | dydx, dy2 = gradient(dy)
106 | smooth_loss = dx.abs().mean() + dy.abs().mean() + dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean()
107 | else:
108 | smooth_loss = dx.abs().mean() + dy.abs().mean()
109 | # smooth_loss = dx.abs().mean() + dy.abs().mean() # + dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean()
110 | # second-order smoothness is disabled for now; it seems too strong and keeps the photo loss from decreasing TODO
111 | return smooth_loss
112 |
113 | @classmethod
114 | def edge_aware_smoothness_per_pixel(cls, img, pred):
115 | def gradient_x(img):
116 | gx = img[:, :, :-1, :] - img[:, :, 1:, :]
117 | return gx
118 |
119 | def gradient_y(img):
120 | gy = img[:, :, :, :-1] - img[:, :, :, 1:]
121 | return gy
122 |
123 | pred_gradients_x = gradient_x(pred)
124 | pred_gradients_y = gradient_y(pred)
125 |
126 | image_gradients_x = gradient_x(img)
127 | image_gradients_y = gradient_y(img)
128 |
129 | weights_x = torch.exp(-torch.mean(torch.abs(image_gradients_x), 1, keepdim=True))
130 | weights_y = torch.exp(-torch.mean(torch.abs(image_gradients_y), 1, keepdim=True))
131 |
132 | smoothness_x = torch.abs(pred_gradients_x) * weights_x
133 | smoothness_y = torch.abs(pred_gradients_y) * weights_y
134 | return torch.mean(smoothness_x) + torch.mean(smoothness_y)
135 |
--------------------------------------------------------------------------------
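Note: census_loss_torch above compares soft ternary census signatures of the reference image and the warped image, which makes the photometric term robust to brightness changes. A minimal usage sketch (random tensors stand in for a real image pair; q=0.4 matches photo_loss_delta in the model demo config):

import torch
from utils.loss import loss_functions

img1 = torch.rand(2, 3, 64, 64)       # NCHW, RGB in [0, 1]
img1_warp = torch.rand(2, 3, 64, 64)  # stands in for im2 warped towards im1 by the predicted flow
mask = torch.ones(2, 1, 64, 64)       # all-ones mask, i.e. no occlusion handling
loss = loss_functions.census_loss_torch(img1, img1_warp, mask, q=0.4,
                                        charbonnier_or_abs_robust=False,
                                        if_use_occ=True, max_distance=3)
print(loss.item())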
/utils/pytorch_correlation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch
7 | from utils.tools import tools
8 |
9 |
10 | class Corr_pyTorch(tools.abstract_model):
11 | '''
12 | my implementation of correlation layer using pytorch
13 | note that the speed is much slower than the CUDA version
14 | '''
15 |
16 | def __init__(self, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1):
17 | assert pad_size == max_displacement
18 | assert stride1 == stride2 == 1
19 | super().__init__()
20 | self.pad_size = pad_size
21 | self.kernel_size = kernel_size
22 | self.stride1 = stride1
23 | self.stride2 = stride2
24 | self.max_hdisp = max_displacement
25 | self.padlayer = nn.ConstantPad2d(pad_size, 0)
26 |
27 | def forward(self, in1, in2):
28 | bz, cn, hei, wid = in1.shape
29 | # print(self.kernel_size, self.pad_size, self.stride1)
30 | f1 = F.unfold(in1, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=self.stride1)
31 | f2 = F.unfold(in2, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=self.stride2) # it might be nicer to do the warping interpolation right after extracting the kernels here
32 | # tools.check_tensor(in2, 'in2')
33 | # tools.check_tensor(f2, 'f2')
34 | searching_kernel_size = f2.shape[1]
35 | f2_ = torch.reshape(f2, (bz, searching_kernel_size, hei, wid))
36 | f2_ = torch.reshape(f2_, (bz * searching_kernel_size, hei, wid)).unsqueeze(1)
37 | # tools.check_tensor(f2_, 'f2_reshape')
38 | f2 = F.unfold(f2_, kernel_size=(hei, wid), padding=self.pad_size, stride=self.stride2)
39 | # tools.check_tensor(f2, 'f2_reunfold')
40 | _, kernel_number, window_number = f2.shape
41 | f2_ = torch.reshape(f2, (bz, searching_kernel_size, kernel_number, window_number))
42 | f2_2 = torch.transpose(f2_, dim0=1, dim1=3).transpose(2, 3)
43 | f1_2 = f1.unsqueeze(1)
44 | # tools.check_tensor(f1_2, 'f1_2_reshape')
45 | # tools.check_tensor(f2_2, 'f2_2_reshape')
46 | res = f2_2 * f1_2
47 | res = torch.mean(res, dim=2)
48 | res = torch.reshape(res, (bz, window_number, hei, wid))
49 | # tools.check_tensor(res, 'res')
50 | return res
51 |
--------------------------------------------------------------------------------
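Note: a quick sanity check for Corr_pyTorch: with pad_size = max_displacement = 4 the output has (2 * 4 + 1) ** 2 = 81 channels, one correlation value per displacement in the 9x9 search window, matching the layout of the CUDA Correlation layer it substitutes for. A minimal sketch with random features (assumes the repo is on PYTHONPATH):

import torch
from utils.pytorch_correlation import Corr_pyTorch

corr = Corr_pyTorch(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1)
f1 = torch.randn(1, 32, 40, 40)
f2 = torch.randn(1, 32, 40, 40)
out = corr(f1, f2)
print(out.shape)  # torch.Size([1, 81, 40, 40])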