├── LICENSE
├── README.md
├── __init__.py
├── augmentations.py
├── commandline.py
├── configuration.py
├── datasets
│   ├── __init__.py
│   ├── common.py
│   ├── flyingThings3D.py
│   ├── flyingThings3DMultiframe.py
│   ├── flyingchairs.py
│   ├── flyingchairsOcc.py
│   ├── kitti_comb_multiframe.py
│   ├── kitti_combined.py
│   ├── sintel.py
│   ├── sintel_multiframe.py
│   └── transforms.py
├── inference.py
├── install.sh
├── logger.py
├── losses.py
├── main.py
├── models
│   ├── IRR_PWC.py
│   ├── IRR_PWC_occ_joint.py
│   ├── STAR.py
│   ├── __init__.py
│   ├── correlation_package
│   │   ├── __init__.py
│   │   ├── correlation.py
│   │   ├── correlation_cuda.cc
│   │   ├── correlation_cuda_kernel.cu
│   │   ├── correlation_cuda_kernel.cuh
│   │   └── setup.py
│   ├── irr_modules.py
│   ├── pwc_modules.py
│   ├── pwcnet.py
│   ├── pwcnet_irr.py
│   ├── pwcnet_irr_occ_joint.py
│   ├── pwcnet_occ_joint.py
│   ├── tr_features.py
│   └── tr_flow.py
├── optim
│   └── __init__.py
├── results.png
├── runtime.py
├── saved_checkpoint
│   ├── StarFlow_kitti
│   │   └── checkpoint_latest.ckpt
│   ├── StarFlow_sintel
│   │   └── checkpoint_latest.ckpt
│   └── StarFlow_things
│       └── checkpoint_best.ckpt
├── scripts_train
│   ├── train_starflow_chairsocc.sh
│   ├── train_starflow_kitti_full.sh
│   ├── train_starflow_sintel_full.sh
│   └── train_starflow_things.sh
├── tools.py
└── utils
    ├── __init__.py
    ├── flow.py
    └── interpolation.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # STaRFlow
2 |
3 | ![results](results.png)
4 |
5 | This repository is the PyTorch implementation of STaRFlow, a recurrent convolutional neural network for multi-frame optical flow estimation. This algorithm is presented in our paper **STaRFlow: A SpatioTemporal Recurrent Cell for Lightweight Multi-Frame Optical Flow Estimation**, Pierre Godet, [Alexandre Boulch](https://github.com/aboulch), [Aurélien Plyer](https://github.com/aplyer), and Guy Le Besnerais.
6 | [[Preprint]](https://arxiv.org/pdf/2007.05481.pdf)
7 |
8 |
9 | Please cite our paper if you find our work useful.
10 |
11 | @article{godet2020starflow,
12 | title={STaRFlow: A SpatioTemporal Recurrent Cell for Lightweight Multi-Frame Optical Flow Estimation},
13 | author={Godet, Pierre and Boulch, Alexandre and Plyer, Aur{\'e}lien and Le Besnerais, Guy},
14 | journal={arXiv preprint arXiv:2007.05481},
15 | year={2020}
16 | }
17 |
18 | Contact: pierre.godet@onera.fr
19 |
20 | ## Getting started
21 | This code has been developed and tested with Anaconda (Python 3.7, scipy 1.1, numpy 1.16), PyTorch 1.1, and CUDA 10.1 on Ubuntu 18.04.
22 |
23 | 1. Please install the following:
24 |
25 | - Anaconda (Python 3.7)
26 | - __PyTorch 1.1__ (Linux, Conda, Python 3.7, CUDA 10) (`conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch`)
27 |     - Depending on your system, configure `-gencode`, `-ccbin`, `cuda-path` in `models/correlation_package/setup.py` accordingly (a sketch of the flag layout follows step 2 below)
28 | - scipy 1.1 (`conda install scipy=1.1`)
29 | - colorama (`conda install colorama`)
30 | - tqdm 4.32 (`conda install -c conda-forge tqdm=4.32`)
31 | - pypng (`pip install pypng`)
32 |
33 | 2. Then, install the correlation package:
34 | ```
35 | ./install.sh
36 | ```
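
For reference, a minimal sketch of how the build flags in `models/correlation_package/setup.py` are typically laid out. The exact contents of the file in this repository may differ; the `arch`/`sm` values below are placeholders to adapt to your GPU:

```
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Placeholder flags -- pick the arch/sm values matching your GPU, and add
# '-ccbin', '/path/to/gcc' only if nvcc needs a specific host compiler.
nvcc_args = [
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_70,code=sm_70',
]

setup(
    name='correlation_cuda',
    ext_modules=[
        CUDAExtension(
            'correlation_cuda',
            ['correlation_cuda.cc', 'correlation_cuda_kernel.cu'],
            extra_compile_args={'cxx': ['-std=c++11'], 'nvcc': nvcc_args},
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)
```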
37 |
38 |
39 | ## Pretrained Models
40 |
41 | The `saved_checkpoint` folder contains the pre-trained models of STaRFlow trained on
42 |
43 | 1. FlyingChairsOcc -> FlyingThings3D, or
44 | 2. FlyingChairsOcc -> FlyingThings3D -> MPI Sintel, or
45 | 3. FlyingChairsOcc -> FlyingThings3D -> KITTI (2012 and 2015).
46 |
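The checkpoints are regular PyTorch files and can be inspected directly; a minimal sketch (the stored key names are not documented here, so treat them as unknowns):

```
import torch

# map_location="cpu" lets you inspect the file without a GPU.
ckpt = torch.load("saved_checkpoint/StarFlow_things/checkpoint_best.ckpt",
                  map_location="cpu")
print(list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))
```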
47 |
48 | ## Inference
49 |
50 | The script `inference.py` can be used for testing the pre-trained models. Example:
51 |
52 | python inference.py \
53 | --model StarFlow \
54 | --checkpoint saved_checkpoint/StarFlow_things/checkpoint_best.ckpt \
55 | --data-root /data/mpisintelcomplete/training/final/ambush_6/ \
56 | --file-list frame_0004.png frame_0005.png frame_0006.png frame_0007.png
57 |
58 | By default, it saves the results in `./output/`.
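
If a result is stored in the Middlebury `.flo` format (the same convention the data loaders use), it can be reloaded with the repository's own reader; a sketch, where the output file name is a placeholder:

```
from datasets import common

# Placeholder path -- adjust to the file that inference.py actually wrote.
flow = common.read_flo_as_float32("output/frame_0004.flo")
print(flow.shape)  # (H, W, 2): per-pixel horizontal and vertical displacement
```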
59 |
60 |
61 | ## Training
62 |
63 | Data loaders for multi-frame training are in the `datasets` folder, the multi-frame losses are in `losses.py`, and every architecture used in the experiments presented in our paper is available in the `models` folder.
64 |
65 | ### Datasets
66 |
67 | The following datasets were used for this project:
68 |
69 | - [FlyingChairsOcc dataset](https://github.com/visinf/irr/tree/master/flyingchairsocc)
70 | - [MPI Sintel Dataset](http://sintel.is.tue.mpg.de/downloads)
71 | - [KITTI Optical Flow 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) and [KITTI Optical Flow 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=flow)
72 | - [FlyingThings3D subset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
73 |
74 |
75 | ### Scripts for training
76 |
77 | The `scripts_train` folder contains the training scripts for STaRFlow.
78 | To train a model, simply run the corresponding script file, e.g., `./train_starflow_chairsocc.sh`.
79 | In the script files, please configure your own experiment directory (`EXPERIMENTS_HOME`) and the dataset directories on your local system (e.g., `SINTEL_HOME` or `KITTI_HOME`), as sketched below.
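
For example, the variables at the top of a script might be set like this (all paths below are placeholders):

    EXPERIMENTS_HOME=/path/to/your/experiments
    SINTEL_HOME=/path/to/mpisintelcomplete
    KITTI_HOME=/path/to/kitti_flow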
80 |
81 |
82 | ## Acknowledgement
83 |
84 | This repository is a fork of the [IRR-PWC](https://github.com/visinf/irr) implementation from Junhwa Hur and Stefan Roth.
85 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/__init__.py
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import flyingchairs
2 | from . import flyingchairsOcc
3 | from . import flyingThings3D
4 | from . import kitti_combined
5 | from . import sintel
6 |
7 | from . import flyingThings3DMultiframe
8 | from . import sintel_multiframe
9 | from . import kitti_comb_multiframe
10 |
11 |
12 | ## FlyingChairs
13 | FlyingChairsTrain = flyingchairs.FlyingChairsTrain
14 | FlyingChairsValid = flyingchairs.FlyingChairsValid
15 | FlyingChairsFull = flyingchairs.FlyingChairsFull
16 |
17 | ## Our custom FlyingChairs + Occ
18 | FlyingChairsOccTrain = flyingchairsOcc.FlyingChairsOccTrain
19 | FlyingChairsOccValid = flyingchairsOcc.FlyingChairsOccValid
20 | FlyingChairsOccFull = flyingchairsOcc.FlyingChairsOccFull
21 |
22 |
23 | ## FlyingThings3D_subset
24 | FlyingThings3dFinalTrain = flyingThings3D.FlyingThings3dFinalTrain
25 | FlyingThings3dFinalTest = flyingThings3D.FlyingThings3dFinalTest
26 | FlyingThings3dCleanTrain = flyingThings3D.FlyingThings3dCleanTrain
27 | FlyingThings3dCleanTest = flyingThings3D.FlyingThings3dCleanTest
28 |
29 |
30 | ## Sintel
31 | SintelTestClean = sintel.SintelTestClean
32 | SintelTestFinal = sintel.SintelTestFinal
33 |
34 | SintelTrainingCombFull = sintel.SintelTrainingCombFull
35 | SintelTrainingCombTrain = sintel.SintelTrainingCombTrain
36 | SintelTrainingCombValid = sintel.SintelTrainingCombValid
37 |
38 | SintelTrainingCleanFull = sintel.SintelTrainingCleanFull
39 | SintelTrainingCleanTrain = sintel.SintelTrainingCleanTrain
40 | SintelTrainingCleanValid = sintel.SintelTrainingCleanValid
41 |
42 | SintelTrainingFinalFull = sintel.SintelTrainingFinalFull
43 | SintelTrainingFinalTrain = sintel.SintelTrainingFinalTrain
44 | SintelTrainingFinalValid = sintel.SintelTrainingFinalValid
45 |
46 |
47 | ## KITTI Optical Flow 2012 + 2015
48 | KittiCombTrain = kitti_combined.KittiCombTrain
49 | KittiCombVal = kitti_combined.KittiCombVal
50 | KittiCombFull = kitti_combined.KittiCombFull
51 |
52 | KittiComb2012Train = kitti_combined.KittiComb2012Train
53 | KittiComb2012Val = kitti_combined.KittiComb2012Val
54 | KittiComb2012Full = kitti_combined.KittiComb2012Full
55 | KittiComb2012Test = kitti_combined.KittiComb2012Test
56 |
57 | KittiComb2015Train = kitti_combined.KittiComb2015Train
58 | KittiComb2015Val = kitti_combined.KittiComb2015Val
59 | KittiComb2015Full = kitti_combined.KittiComb2015Full
60 | KittiComb2015Test = kitti_combined.KittiComb2015Test
61 |
62 |
63 | ## FlyingThings3D_subset_Multiframe
64 | FlyingThings3dMultiframeCleanTrain = flyingThings3DMultiframe.FlyingThings3dMultiframeCleanTrain
65 | FlyingThings3dMultiframeCleanTest = flyingThings3DMultiframe.FlyingThings3dMultiframeCleanTest
66 |
67 |
68 | ## SintelMultiframe
69 | SintelMultiframeTrainingCombFull = sintel_multiframe.SintelMultiframeTrainingCombFull
70 | SintelMultiframeTrainingCleanFull = sintel_multiframe.SintelMultiframeTrainingCleanFull
71 | SintelMultiframeTrainingFinalFull = sintel_multiframe.SintelMultiframeTrainingFinalFull
72 |
73 | SintelMultiframeTrainingCombValid = sintel_multiframe.SintelMultiframeTrainingCombValid
74 | SintelMultiframeTrainingCleanValid = sintel_multiframe.SintelMultiframeTrainingCleanValid
75 | SintelMultiframeTrainingFinalValid = sintel_multiframe.SintelMultiframeTrainingFinalValid
76 |
77 | SintelMultiframeTrainingCombTrain = sintel_multiframe.SintelMultiframeTrainingCombTrain
78 | SintelMultiframeTrainingCleanTrain = sintel_multiframe.SintelMultiframeTrainingCleanTrain
79 | SintelMultiframeTrainingFinalTrain = sintel_multiframe.SintelMultiframeTrainingFinalTrain
80 |
81 | SintelMultiframeTestFinal = sintel_multiframe.SintelMultiframeTestFinal
82 | SintelMultiframeTestClean = sintel_multiframe.SintelMultiframeTestClean
83 |
84 |
85 | ## KITTI Optical Flow 2012 + 2015 MULTIFRAME
86 | KittiMultiframeCombTrain = kitti_comb_multiframe.KittiMultiframeCombTrain
87 | KittiMultiframeCombVal = kitti_comb_multiframe.KittiMultiframeCombVal
88 | KittiMultiframeCombFull = kitti_comb_multiframe.KittiMultiframeCombFull
89 |
90 | KittiMultiframeComb2012Train = kitti_comb_multiframe.KittiMultiframeComb2012Train
91 | KittiMultiframeComb2012Val = kitti_comb_multiframe.KittiMultiframeComb2012Val
92 | KittiMultiframeComb2012Full = kitti_comb_multiframe.KittiMultiframeComb2012Full
93 | KittiMultiframeComb2012Test = kitti_comb_multiframe.KittiMultiframeComb2012Test
94 |
95 | KittiMultiframeComb2015Train = kitti_comb_multiframe.KittiMultiframeComb2015Train
96 | KittiMultiframeComb2015Val = kitti_comb_multiframe.KittiMultiframeComb2015Val
97 | KittiMultiframeComb2015Full = kitti_comb_multiframe.KittiMultiframeComb2015Full
98 | KittiMultiframeComb2015Test = kitti_comb_multiframe.KittiMultiframeComb2015Test
99 |
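# Usage sketch (illustrative only): `args` stands for the parsed command-line
# namespace used throughout this repo, and the root path is a placeholder.
#
#   import datasets
#   train_set = datasets.FlyingThings3dMultiframeCleanTrain(
#       args, root="/data/flyingthings3d_subset", nframes=5)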
--------------------------------------------------------------------------------
/datasets/common.py:
--------------------------------------------------------------------------------
1 | ## Portions of code copyright 2018 Jochen Gast
2 |
3 | from __future__ import absolute_import, division, print_function
4 |
5 | import torch
6 | import numpy as np
7 | from scipy import ndimage
8 |
9 | import png
10 |
11 |
12 | def numpy2torch(array):
13 | assert(isinstance(array, np.ndarray))
14 | if array.ndim == 3:
15 | array = np.transpose(array, (2, 0, 1))
16 | else:
17 | array = np.expand_dims(array, axis=0)
18 | return torch.from_numpy(array.copy()).float()
19 |
20 |
21 | def read_flo_as_float32(filename):
22 | with open(filename, 'rb') as file:
23 | magic = np.fromfile(file, np.float32, count=1)
24 | assert(202021.25 == magic), "Magic number incorrect. Invalid .flo file"
25 | w = np.fromfile(file, np.int32, count=1)[0]
26 | h = np.fromfile(file, np.int32, count=1)[0]
27 | data = np.fromfile(file, np.float32, count=2*h*w)
28 | data2D = np.resize(data, (h, w, 2))
29 | return data2D
30 |
31 |
32 | def read_occ_image_as_float32(filename):
33 | occ = ndimage.imread(filename).astype(np.float32) / np.float32(255.0)
34 | if occ.ndim == 3:
35 | occ = occ[:, :, 0]
36 | return occ
37 |
38 |
39 | def read_image_as_float32(filename):
40 | return ndimage.imread(filename).astype(np.float32) / np.float32(255.0)
41 |
42 |
43 | def read_image_as_byte(filename):
44 | return ndimage.imread(filename)
45 |
46 |
47 | def read_flopng_as_float32(filename):
48 | """
49 |     Read optical flow from a KITTI .png file
50 |     :param filename: name of the flow file
51 |     :return: optical flow as an (H, W, 2) float32 array
52 | """
53 | flow_object = png.Reader(filename=filename)
54 | flow_direct = flow_object.asDirect()
55 | flow_data = list(flow_direct[2])
56 | (w, h) = flow_direct[3]['size']
57 | #print("Reading %d x %d flow file in .png format" % (h, w))
58 | flow = np.zeros((h, w, 3), dtype=np.float32)
59 | for i in range(len(flow_data)):
60 | flow[i, :, 0] = flow_data[i][0::3]
61 | flow[i, :, 1] = flow_data[i][1::3]
62 | flow[i, :, 2] = flow_data[i][2::3]
63 |
64 | invalid_idx = (flow[:, :, 2] == 0)
65 | flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
66 | flow[invalid_idx, 0] = 0
67 | flow[invalid_idx, 1] = 0
68 | return flow[:, :, :2]
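
# Usage sketch (file names are placeholders):
#
#   img = read_image_as_float32("0000000.png")  # (H, W, 3) floats in [0, 1]
#   flo = read_flo_as_float32("0000000.flo")    # (H, W, 2) float32
#   flo_t = numpy2torch(flo)                    # (2, H, W) torch.FloatTensor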
--------------------------------------------------------------------------------
/datasets/flyingThings3D.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import torch.utils.data as data
5 | from glob import glob
6 |
7 | from torchvision import transforms as vision_transforms
8 |
9 | from . import transforms
10 | from . import common
11 |
12 | import numpy as np
13 |
14 |
15 | def fillingInNaN(flow):
16 | h, w, c = flow.shape
17 | indices = np.argwhere(np.isnan(flow))
18 | neighbors = [[-1, 0], [1, 0], [0, -1], [0, 1]]
19 | for ii, idx in enumerate(indices):
20 | sum_sample = 0
21 | count = 0
22 |         for jj in range(len(neighbors)):  # check all four neighbors
23 | hh = idx[0] + neighbors[jj][0]
24 | ww = idx[1] + neighbors[jj][1]
25 | if hh < 0 or hh >= h:
26 | continue
27 | if ww < 0 or ww >= w:
28 | continue
29 | sample_flow = flow[hh, ww, idx[2]]
30 | if np.isnan(sample_flow):
31 | continue
32 | sum_sample += sample_flow
33 | count += 1
34 |         if count == 0:
35 | print('FATAL ERROR: no sample')
36 |         flow[idx[0], idx[1], idx[2]] = sum_sample / max(count, 1)  # guard against division by zero
37 |
38 | return flow
39 |
40 |
41 | class FlyingThings3d(data.Dataset):
42 | def __init__(self,
43 | args,
44 | images_root,
45 | flow_root,
46 | occ_root,
47 | photometric_augmentations=False,
48 | backward=False):
49 |
50 | self._args = args
51 | self.backward = backward
52 |
53 | if not os.path.isdir(images_root):
54 |             raise ValueError("Image directory '%s' not found!" % images_root)
55 |         if flow_root is not None and not os.path.isdir(flow_root):
56 |             raise ValueError("Flow directory '%s' not found!" % flow_root)
57 |         if occ_root is not None and not os.path.isdir(occ_root):
58 |             raise ValueError("Occ directory '%s' not found!" % occ_root)
59 |
60 | if flow_root is not None:
61 | flow_f_filenames = sorted(glob(os.path.join(flow_root, "into_future/*.flo")))
62 | flow_b_filenames = sorted(glob(os.path.join(flow_root, "into_past/*.flo")))
63 |
64 | if occ_root is not None:
65 | occ1_filenames = sorted(glob(os.path.join(occ_root, "into_future/*.png")))
66 | occ2_filenames = sorted(glob(os.path.join(occ_root, "into_past/*.png")))
67 |
68 | all_img_filenames = sorted(glob(os.path.join(images_root, "*.png")))
69 |
70 | self._image_list = []
71 | self._flow_list = [] if flow_root is not None else None
72 | self._occ_list = [] if occ_root is not None else None
73 |
74 | assert len(all_img_filenames) != 0
75 | assert len(flow_f_filenames) != 0
76 | assert len(flow_b_filenames) != 0
77 | assert len(occ1_filenames) != 0
78 | assert len(occ2_filenames) != 0
79 |
80 | ## path definition
81 | path_flow_f = os.path.join(flow_root, "into_future")
82 | path_flow_b = os.path.join(flow_root, "into_past")
83 | path_occ_f = os.path.join(occ_root, "into_future")
84 | path_occ_b = os.path.join(occ_root, "into_past")
85 |
86 | # ----------------------------------------------------------
87 | # Save list of actual filenames for inputs and flows
88 | # ----------------------------------------------------------
89 |
90 | for ii in range(0, len(flow_f_filenames)):
91 |
92 | flo_f = flow_f_filenames[ii]
93 |
94 | idx_f = os.path.splitext(os.path.basename(flo_f))[0]
95 | idx_b = str(int(idx_f) + 1).zfill(7)
96 |
97 | flo_b = os.path.join(path_flow_b, idx_b + ".flo")
98 |
99 | im1 = os.path.join(images_root, idx_f + ".png")
100 | im2 = os.path.join(images_root, idx_b + ".png")
101 | occ1 = os.path.join(path_occ_f, idx_f + ".png")
102 | occ2 = os.path.join(path_occ_b, idx_b + ".png")
103 |
104 | if not os.path.isfile(flo_f) or not os.path.isfile(flo_b) or not os.path.isfile(im1) or not os.path.isfile(
105 | im2) or not os.path.isfile(occ1) or not os.path.isfile(occ2):
106 | continue
107 |
108 | self._image_list += [[im1, im2]]
109 | self._flow_list += [[flo_f, flo_b]]
110 | self._occ_list += [[occ1, occ2]]
111 |
112 | self._size = len(self._image_list)
113 |
114 | assert len(self._image_list) == len(self._flow_list)
115 | assert len(self._occ_list) == len(self._flow_list)
116 | assert len(self._image_list) != 0
117 |
118 | # ----------------------------------------------------------
119 | # photometric_augmentations
120 | # ----------------------------------------------------------
121 | if photometric_augmentations:
122 | self._photometric_transform = transforms.ConcatTransformSplitChainer([
123 | # uint8 -> PIL
124 | vision_transforms.ToPILImage(),
125 | # PIL -> PIL : random hsv and contrast
126 | vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
127 | # PIL -> FloatTensor
128 | vision_transforms.transforms.ToTensor(),
129 | transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
130 | ], from_numpy=True, to_numpy=False)
131 |
132 | else:
133 | self._photometric_transform = transforms.ConcatTransformSplitChainer([
134 | # uint8 -> FloatTensor
135 | vision_transforms.transforms.ToTensor(),
136 | ], from_numpy=True, to_numpy=False)
137 |
138 | def __getitem__(self, index):
139 | index = index % self._size
140 |
141 | im1_filename = self._image_list[index][0]
142 | im2_filename = self._image_list[index][1]
143 | flo_f_filename = self._flow_list[index][0]
144 | flo_b_filename = self._flow_list[index][1]
145 | occ1_filename = self._occ_list[index][0]
146 | occ2_filename = self._occ_list[index][1]
147 |
148 | # read float32 images and flow
149 | im1_np0 = common.read_image_as_byte(im1_filename)
150 | im2_np0 = common.read_image_as_byte(im2_filename)
151 | flo_f_np0 = common.read_flo_as_float32(flo_f_filename)
152 | flo_b_np0 = common.read_flo_as_float32(flo_b_filename)
153 | occ1_np0 = common.read_occ_image_as_float32(occ1_filename)
154 | occ2_np0 = common.read_occ_image_as_float32(occ2_filename)
155 |
156 | # temp - check isnan
157 | if np.any(np.isnan(flo_f_np0)):
158 | flo_f_np0 = fillingInNaN(flo_f_np0)
159 |
160 | if np.any(np.isnan(flo_b_np0)):
161 | flo_b_np0 = fillingInNaN(flo_b_np0)
162 |
163 | # possibly apply photometric transformations
164 | im1, im2 = self._photometric_transform(im1_np0, im2_np0)
165 |
166 | # convert flow to FloatTensor
167 | flo_f = common.numpy2torch(flo_f_np0)
168 | flo_b = common.numpy2torch(flo_b_np0)
169 |
170 | # convert occ to FloatTensor
171 | occ1 = common.numpy2torch(occ1_np0)
172 | occ2 = common.numpy2torch(occ2_np0)
173 |
174 | # example filename
175 | basename = os.path.basename(im1_filename)[:-4]
176 |
177 | if self.backward:
178 |             # swap the flows (and occlusions): backward becomes target 1, forward becomes target 2
179 | example_dict = {
180 | "input1": im1,
181 | "input2": im2,
182 | "target1": flo_b,
183 | "target2": flo_f,
184 | "target_occ1": occ2,
185 | "target_occ2": occ1,
186 | "index": index,
187 | "basename": basename
188 | }
189 | else:
190 | example_dict = {
191 | "input1": im1,
192 | "input2": im2,
193 | "target1": flo_f,
194 | "target2": flo_b,
195 | "target_occ1": occ1,
196 | "target_occ2": occ2,
197 | "index": index,
198 | "basename": basename
199 | }
200 |
201 | return example_dict
202 |
203 | def __len__(self):
204 | return self._size
205 |
206 |
207 | class FlyingThings3dFinalTrain(FlyingThings3d):
208 | def __init__(self,
209 | args,
210 | root,
211 | photometric_augmentations=True,
212 | backward=False):
213 | images_root = os.path.join(root, "frames_finalpass")
214 | flow_root = os.path.join(root, "optical_flow")
215 | occ_root = os.path.join(root, "occlusion")
216 | super(FlyingThings3dFinalTrain, self).__init__(
217 | args,
218 | images_root=images_root,
219 | flow_root=flow_root,
220 | occ_root=occ_root,
221 | photometric_augmentations=photometric_augmentations,
222 | backward=backward)
223 |
224 |
225 | class FlyingThings3dFinalTest(FlyingThings3d):
226 | def __init__(self,
227 | args,
228 | root,
229 | photometric_augmentations=False,
230 | backward=False):
231 | images_root = os.path.join(root, "frames_finalpass")
232 | flow_root = os.path.join(root, "optical_flow")
233 | occ_root = os.path.join(root, "occlusion")
234 | super(FlyingThings3dFinalTest, self).__init__(
235 | args,
236 | images_root=images_root,
237 | flow_root=flow_root,
238 | occ_root=occ_root,
239 | photometric_augmentations=photometric_augmentations,
240 | backward=backward)
241 |
242 |
243 | class FlyingThings3dCleanTrain(FlyingThings3d):
244 | def __init__(self,
245 | args,
246 | root,
247 | photometric_augmentations=True,
248 | backward=False):
249 | images_root = os.path.join(root, "train", "image_clean", "left")
250 | flow_root = os.path.join(root, "train", "flow", "left")
251 | occ_root = os.path.join(root, "train", "flow_occlusions", "left")
252 | super(FlyingThings3dCleanTrain, self).__init__(
253 | args,
254 | images_root=images_root,
255 | flow_root=flow_root,
256 | occ_root=occ_root,
257 | photometric_augmentations=photometric_augmentations,
258 | backward=backward)
259 |
260 |
261 | class FlyingThings3dCleanTest(FlyingThings3d):
262 | def __init__(self,
263 | args,
264 | root,
265 | photometric_augmentations=False,
266 | backward=False):
267 | images_root = os.path.join(root, "val", "image_clean", "left")
268 | flow_root = os.path.join(root, "val", "flow", "left")
269 | occ_root = os.path.join(root, "val", "flow_occlusions", "left")
270 | super(FlyingThings3dCleanTest, self).__init__(
271 | args,
272 | images_root=images_root,
273 | flow_root=flow_root,
274 | occ_root=occ_root,
275 | photometric_augmentations=photometric_augmentations,
276 | backward=backward)
277 |
--------------------------------------------------------------------------------
/datasets/flyingThings3DMultiframe.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import torch
5 | import torch.utils.data as data
6 | from glob import glob
7 |
8 | import torch
9 | from torchvision import transforms as vision_transforms
10 |
11 | from . import transforms
12 | from . import common
13 |
14 | import numpy as np
15 |
16 |
17 | def fillingInNaN(flow):
18 | h, w, c = flow.shape
19 | indices = np.argwhere(np.isnan(flow))
20 | neighbors = [[-1, 0], [1, 0], [0, -1], [0, 1]]
21 | for ii, idx in enumerate(indices):
22 | sum_sample = 0
23 | count = 0
24 |         for jj in range(len(neighbors)):  # check all four neighbors
25 | hh = idx[0] + neighbors[jj][0]
26 | ww = idx[1] + neighbors[jj][1]
27 | if hh < 0 or hh >= h:
28 | continue
29 | if ww < 0 or ww >= w:
30 | continue
31 | sample_flow = flow[hh, ww, idx[2]]
32 | if np.isnan(sample_flow):
33 | continue
34 | sum_sample += sample_flow
35 | count += 1
36 |         if count == 0:
37 | print('FATAL ERROR: no sample')
38 |         flow[idx[0], idx[1], idx[2]] = sum_sample / max(count, 1)  # guard against division by zero
39 |
40 | return flow
41 |
42 |
43 | class FlyingThings3dMultiframe(data.Dataset):
44 | def __init__(self,
45 | args,
46 | images_root,
47 | flow_root,
48 | occ_root,
49 | seq_lengths_path, nframes=5,
50 | photometric_augmentations=False,
51 | backward=False):
52 |
53 | self._args = args
54 | self._nframes = nframes
55 | self.backward = backward
56 |
57 | if not os.path.isdir(images_root):
58 |             raise ValueError("Image directory '%s' not found!" % images_root)
59 |         if flow_root is not None and not os.path.isdir(flow_root):
60 |             raise ValueError("Flow directory '%s' not found!" % flow_root)
61 |         if occ_root is not None and not os.path.isdir(occ_root):
62 |             raise ValueError("Occ directory '%s' not found!" % occ_root)
63 |
64 | if flow_root is not None:
65 | flow_f_filenames = sorted(glob(os.path.join(flow_root, "into_future/*.flo")))
66 | flow_b_filenames = sorted(glob(os.path.join(flow_root, "into_past/*.flo")))
67 |
68 | if occ_root is not None:
69 | occ1_filenames = sorted(glob(os.path.join(occ_root, "into_future/*.png")))
70 | occ2_filenames = sorted(glob(os.path.join(occ_root, "into_past/*.png")))
71 |
72 | all_img_filenames = sorted(glob(os.path.join(images_root, "*.png")))
73 |
74 | self._image_list = []
75 | self._flow_list = [] if flow_root is not None else None
76 | self._occ_list = [] if occ_root is not None else None
77 |
78 | assert len(all_img_filenames) != 0
79 | assert len(flow_f_filenames) != 0
80 | assert len(flow_b_filenames) != 0
81 | assert len(occ1_filenames) != 0
82 | assert len(occ2_filenames) != 0
83 |
84 | self._seq_lengths = np.load(seq_lengths_path)
85 |
86 | ## path definition
87 | path_flow_f = os.path.join(flow_root, "into_future")
88 | path_flow_b = os.path.join(flow_root, "into_past")
89 | path_occ_f = os.path.join(occ_root, "into_future")
90 | path_occ_b = os.path.join(occ_root, "into_past")
91 |
92 | # ----------------------------------------------------------
93 | # Save list of actual filenames for inputs and flows
94 | # ----------------------------------------------------------
95 |
96 | idx_first = 0
97 |
98 | for seq_len in self._seq_lengths:
99 | list_images = []
100 | list_flows = []
101 | list_occs = []
102 |
103 | for ii in range(idx_first, idx_first + seq_len - 1):
104 | list_images.append(os.path.join(images_root, "{:07d}".format(ii) + ".png"))
105 | if self.backward:
106 | list_flows.append(os.path.join(path_flow_b, "{:07d}".format(ii+1) + ".flo"))
107 | list_occs.append(os.path.join(path_occ_b, "{:07d}".format(ii+1) + ".png"))
108 | else:
109 | list_flows.append(os.path.join(path_flow_f, "{:07d}".format(ii) + ".flo"))
110 | list_occs.append(os.path.join(path_occ_f, "{:07d}".format(ii) + ".png"))
111 | #if not os.path.isfile(flo_f) or not os.path.isfile(flo_b) or not os.path.isfile(im1) or not os.path.isfile(
112 | # im2) or not os.path.isfile(occ1) or not os.path.isfile(occ2):
113 | # continue
114 | list_images.append(os.path.join(images_root, "{:07d}".format(ii + 1) + ".png")) # ii + 1 = idx_first + seq_len - 1
115 |
116 | for i in range(len(list_images) - self._nframes + 1):
117 |
118 | imgs = list_images[i:i+self._nframes]
119 | flows = list_flows[i:i+self._nframes-1]
120 | occs = list_occs[i:i+self._nframes-1]
121 |
122 | self._image_list += [imgs]
123 | self._flow_list += [flows]
124 | self._occ_list += [occs]
125 |
126 | idx_first += seq_len
127 |
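        # Worked example: a sequence with seq_len = 7 contributes 7 images and
        # 6 forward (or backward) flows/occlusions; with nframes = 5 the window
        # loop above runs for i = 0, 1, 2, yielding 3 samples of 5 consecutive
        # images plus the 4 flows/occlusions between them.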
128 | self._size = len(self._image_list)
129 |
130 | assert len(self._image_list) == len(self._flow_list)
131 | assert len(self._occ_list) == len(self._flow_list)
132 | assert len(self._image_list) != 0
133 |
134 |
135 | # ----------------------------------------------------------
136 | # photometric_augmentations
137 | # ----------------------------------------------------------
138 | if photometric_augmentations:
139 | self._photometric_transform = transforms.ConcatTransformSplitChainer([
140 | # uint8 -> PIL
141 | vision_transforms.ToPILImage(),
142 | # PIL -> PIL : random hsv and contrast
143 | vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
144 | # PIL -> FloatTensor
145 | vision_transforms.transforms.ToTensor(),
146 | transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
147 | ], from_numpy=True, to_numpy=False)
148 |
149 | else:
150 | self._photometric_transform = transforms.ConcatTransformSplitChainer([
151 | # uint8 -> FloatTensor
152 | vision_transforms.transforms.ToTensor(),
153 | ], from_numpy=True, to_numpy=False)
154 |
155 | def __getitem__(self, index):
156 |
157 | index = index % self._size
158 |
159 | imgs_filenames = self._image_list[index]
160 | flows_filenames = self._flow_list[index]
161 | occs_filenames = self._occ_list[index]
162 |
163 | # read float32 images and flow
164 | imgs_np0 = [common.read_image_as_byte(filename) for filename in imgs_filenames]
165 | flows_np0 = [common.read_flo_as_float32(filename) for filename in flows_filenames]
166 | occs_np0 = [common.read_occ_image_as_float32(filename) for filename in occs_filenames]
167 |
168 | # temp - check isnan
169 | for ii in range(len(flows_np0)):
170 | if np.any(np.isnan(flows_np0[ii])):
171 | flows_np0[ii] = fillingInNaN(flows_np0[ii])
172 |
173 | # possibly apply photometric transformations
174 | imgs = self._photometric_transform(*imgs_np0)
175 |
176 | # convert flow to FloatTensor
177 | flows = [common.numpy2torch(flo_np0) for flo_np0 in flows_np0]
178 |
179 | # convert occ to FloatTensor
180 | occs = [common.numpy2torch(occ_np0) for occ_np0 in occs_np0]
181 |
182 | # example filename
183 | basename = [os.path.basename(filename)[:-4] for filename in imgs_filenames]
184 |
185 | example_dict = {
186 | "input1": imgs[0],
187 | "input_images": imgs, # "target_flows": torch.stack(flows, 0),
188 | "target1": flows[0],
189 | "target_flows": flows, #torch.stack(flows, 0)
190 | "target_occ1": occs[0],
191 | "target_occs": occs, #torch.stack(occs, 0)
192 | "index": index,
193 | "basename": basename,
194 | "nframes": self._nframes
195 | }
196 |
197 | return example_dict
198 |
199 | def __len__(self):
200 | return self._size
201 |
202 |
203 | # class FlyingThings3dFinalTrain(FlyingThings3d):
204 | # def __init__(self,
205 | # args,
206 | # root,
207 | # photometric_augmentations=True):
208 | # images_root = os.path.join(root, "frames_finalpass")
209 | # flow_root = os.path.join(root, "optical_flow")
210 | # occ_root = os.path.join(root, "occlusion")
211 | # seq_lengths_path = os.path.join(root, "seq_lengths.npy")
212 | # super(FlyingThings3dFinalTrain, self).__init__(
213 | # args,
214 | # images_root=images_root,
215 | # flow_root=flow_root,
216 | # occ_root=occ_root,
217 | # seq_lengths_path=seq_lengths_path,
218 | # photometric_augmentations=photometric_augmentations)
219 |
220 |
221 | # class FlyingThings3dFinalTest(FlyingThings3d):
222 | # def __init__(self,
223 | # args,
224 | # root,
225 | # photometric_augmentations=False):
226 | # images_root = os.path.join(root, "frames_finalpass")
227 | # flow_root = os.path.join(root, "optical_flow")
228 | # occ_root = os.path.join(root, "occlusion")
229 | # seq_lengths_path = os.path.join(root, "seq_lengths.npy")
230 | # super(FlyingThings3dFinalTest, self).__init__(
231 | # args,
232 | # images_root=images_root,
233 | # flow_root=flow_root,
234 | # occ_root=occ_root,
235 | # seq_lengths_path=seq_lengths_path,
236 | # photometric_augmentations=photometric_augmentations)
237 |
238 |
239 | class FlyingThings3dMultiframeCleanTrain(FlyingThings3dMultiframe):
240 | def __init__(self,
241 | args,
242 | root,
243 | nframes=5,
244 | photometric_augmentations=True,
245 | backward=False):
246 | images_root = os.path.join(root, "train", "image_clean", "left")
247 | flow_root = os.path.join(root, "train", "flow", "left")
248 | occ_root = os.path.join(root, "train", "flow_occlusions", "left")
249 | seq_lengths_path = os.path.join(root, "train", "seq_lengths.npy")
250 | super(FlyingThings3dMultiframeCleanTrain, self).__init__(
251 | args,
252 | images_root=images_root,
253 | flow_root=flow_root,
254 | occ_root=occ_root,
255 | seq_lengths_path=seq_lengths_path,
256 | photometric_augmentations=photometric_augmentations,
257 | nframes=nframes, backward=backward)
258 |
259 |
260 | class FlyingThings3dMultiframeCleanTest(FlyingThings3dMultiframe):
261 | def __init__(self,
262 | args,
263 | root,
264 | nframes=5,
265 | photometric_augmentations=False,
266 | backward=False):
267 | images_root = os.path.join(root, "val", "image_clean", "left")
268 | flow_root = os.path.join(root, "val", "flow", "left")
269 | occ_root = os.path.join(root, "val", "flow_occlusions", "left")
270 | seq_lengths_path = os.path.join(root, "val", "seq_lengths.npy")
271 | super(FlyingThings3dMultiframeCleanTest, self).__init__(
272 | args,
273 | images_root=images_root,
274 | flow_root=flow_root,
275 | occ_root=occ_root,
276 | seq_lengths_path=seq_lengths_path,
277 | photometric_augmentations=photometric_augmentations,
278 | nframes=nframes, backward=backward)
279 |
--------------------------------------------------------------------------------
/datasets/flyingchairs.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import torch.utils.data as data
5 | from glob import glob
6 |
7 | from torchvision import transforms as vision_transforms
8 |
9 | from . import transforms
10 | from . import common
11 |
12 |
13 | VALIDATE_INDICES = [
14 | 5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132,
15 | 152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336,
16 | 337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528,
17 | 531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786,
18 | 810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016,
19 | 1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173,
20 | 1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378,
21 | 1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699,
22 | 1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961,
23 | 1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216,
24 | 2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399,
25 | 2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624,
26 | 2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709,
27 | 2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966,
28 | 2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124,
29 | 3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272,
30 | 3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419,
31 | 3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555,
32 | 3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676,
33 | 3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782,
34 | 3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011,
35 | 4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246,
36 | 4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401,
37 | 4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578,
38 | 4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739,
39 | 4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904,
40 | 4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038,
41 | 5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227,
42 | 5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384,
43 | 5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566,
44 | 5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703,
45 | 5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915,
46 | 5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080,
47 | 6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220,
48 | 6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453,
49 | 6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580,
50 | 6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698,
51 | 6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810,
52 | 6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029,
53 | 7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282,
54 | 7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489,
55 | 7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722,
56 | 7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884,
57 | 7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050,
58 | 8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194,
59 | 8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407,
60 | 8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723,
61 | 8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918,
62 | 8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104,
63 | 9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270,
64 | 9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417,
65 | 9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529,
66 | 9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632,
67 | 9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825,
68 | 9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019,
69 | 10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194,
70 | 10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276,
71 | 10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393,
72 | 10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503,
73 | 10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864,
74 | 12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178,
75 | 13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834,
76 | 15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166,
77 | 17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190,
78 | 21802, 21803, 21806, 22584, 22857, 22858, 22866]
79 |
80 |
81 | class FlyingChairs(data.Dataset):
82 | def __init__(self,
83 | args,
84 | root,
85 | photometric_augmentations=False,
86 | dstype="train"):
87 |
88 | self._args = args
89 |
90 | # -------------------------------------------------------------
91 | # filenames for all input images and target flows
92 | # -------------------------------------------------------------
93 | image_filenames = sorted( glob( os.path.join(root, "*.ppm")) )
94 | flow_filenames = sorted( glob( os.path.join(root, "*.flo")) )
95 | assert (len(image_filenames)/2 == len(flow_filenames))
96 | num_flows = len(flow_filenames)
97 |
98 | # -------------------------------------------------------------
99 | # Remove invalid validation indices
100 | # -------------------------------------------------------------
101 | validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)]
102 |
103 | # ----------------------------------------------------------
104 | # Construct list of indices for training/validation
105 | # ----------------------------------------------------------
106 | list_of_indices = None
107 | if dstype == "train":
108 | list_of_indices = [x for x in range(num_flows) if x not in validate_indices]
109 | elif dstype == "valid":
110 | list_of_indices = validate_indices
111 | elif dstype == "full":
112 | list_of_indices = range(num_flows)
113 | else:
114 |             raise ValueError("FlyingChairs: dstype '%s' unknown!" % dstype)
115 |
116 |
117 | # ----------------------------------------------------------
118 | # Save list of actual filenames for inputs and flows
119 | # ----------------------------------------------------------
120 | self._image_list = []
121 | self._flow_list = []
122 | for i in list_of_indices:
123 | flo = flow_filenames[i]
124 | im1 = image_filenames[2*i]
125 | im2 = image_filenames[2*i + 1]
126 | self._image_list += [ [ im1, im2 ] ]
127 | self._flow_list += [ flo ]
128 | self._size = len(self._image_list)
129 | assert len(self._image_list) == len(self._flow_list)
130 |
131 | # ----------------------------------------------------------
132 | # photometric_augmentations
133 | # ----------------------------------------------------------
134 | if photometric_augmentations:
135 | self._photometric_transform = transforms.ConcatTransformSplitChainer([
136 | # uint8 -> PIL
137 | vision_transforms.ToPILImage(),
138 | # PIL -> PIL : random hsv and contrast
139 | vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
140 | # PIL -> FloatTensor
141 | vision_transforms.transforms.ToTensor(),
142 | transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
143 | ], from_numpy=True, to_numpy=False)
144 | else:
145 | self._photometric_transform = transforms.ConcatTransformSplitChainer([
146 | # uint8 -> FloatTensor
147 | vision_transforms.transforms.ToTensor(),
148 | ], from_numpy=True, to_numpy=False)
149 |
150 | def __getitem__(self, index):
151 | index = index % self._size
152 |
153 | im1_filename = self._image_list[index][0]
154 | im2_filename = self._image_list[index][1]
155 | flo_filename = self._flow_list[index]
156 |
157 | # read float32 images and flow
158 | im1_np0 = common.read_image_as_byte(im1_filename)
159 | im2_np0 = common.read_image_as_byte(im2_filename)
160 | flo_np0 = common.read_flo_as_float32(flo_filename)
161 |
162 | # possibly apply photometric transformations
163 | im1, im2 = self._photometric_transform(im1_np0, im2_np0)
164 |
165 | # convert flow to FloatTensor
166 | flo = common.numpy2torch(flo_np0)
167 |
168 | # target_occ: initialized by zero (not used)
169 | target_occ = common.numpy2torch(common.read_occ_image_as_float32(im1_filename)) * 0
170 |
171 | # example filename
172 | basename = os.path.basename(im1_filename)[:5]
173 |
174 | example_dict = {
175 | "input1": im1,
176 | "input2": im2,
177 | "target1": flo,
178 | "target_occ1": target_occ,
179 | "index": index,
180 | "basename": basename
181 | }
182 |
183 | return example_dict
184 |
185 | def __len__(self):
186 | return self._size
187 |
188 |
189 | class FlyingChairsTrain(FlyingChairs):
190 | def __init__(self,
191 | args,
192 | root,
193 | photometric_augmentations=True):
194 | super(FlyingChairsTrain, self).__init__(
195 | args,
196 | root=root,
197 | photometric_augmentations=photometric_augmentations,
198 | dstype="train")
199 |
200 |
201 | class FlyingChairsValid(FlyingChairs):
202 | def __init__(self,
203 | args,
204 | root,
205 | photometric_augmentations=False):
206 | super(FlyingChairsValid, self).__init__(
207 | args,
208 | root=root,
209 | photometric_augmentations=photometric_augmentations,
210 | dstype="valid")
211 |
212 |
213 | class FlyingChairsFull(FlyingChairs):
214 | def __init__(self,
215 | args,
216 | root,
217 | photometric_augmentations=False):
218 | super(FlyingChairsFull, self).__init__(
219 | args,
220 | root=root,
221 | photometric_augmentations=photometric_augmentations,
222 | dstype="full")
223 |
--------------------------------------------------------------------------------
/datasets/flyingchairsOcc.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import torch
5 | import torch.utils.data as data
6 | from glob import glob
7 |
8 | from torchvision import transforms as vision_transforms
9 |
10 | from . import transforms
11 | from . import common
12 |
13 |
14 | VALIDATE_INDICES = [
15 | 5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132,
16 | 152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336,
17 | 337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528,
18 | 531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786,
19 | 810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016,
20 | 1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173,
21 | 1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378,
22 | 1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699,
23 | 1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961,
24 | 1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216,
25 | 2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399,
26 | 2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624,
27 | 2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709,
28 | 2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966,
29 | 2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124,
30 | 3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272,
31 | 3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419,
32 | 3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555,
33 | 3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676,
34 | 3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782,
35 | 3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011,
36 | 4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246,
37 | 4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401,
38 | 4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578,
39 | 4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739,
40 | 4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904,
41 | 4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038,
42 | 5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227,
43 | 5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384,
44 | 5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566,
45 | 5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703,
46 | 5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915,
47 | 5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080,
48 | 6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220,
49 | 6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453,
50 | 6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580,
51 | 6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698,
52 | 6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810,
53 | 6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029,
54 | 7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282,
55 | 7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489,
56 | 7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722,
57 | 7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884,
58 | 7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050,
59 | 8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194,
60 | 8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407,
61 | 8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723,
62 | 8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918,
63 | 8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104,
64 | 9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270,
65 | 9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417,
66 | 9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529,
67 | 9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632,
68 | 9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825,
69 | 9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019,
70 | 10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194,
71 | 10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276,
72 | 10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393,
73 | 10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503,
74 | 10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864,
75 | 12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178,
76 | 13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834,
77 | 15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166,
78 | 17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190,
79 | 21802, 21803, 21806, 22584, 22857, 22858, 22866]
80 |
81 |
82 | class FlyingChairsOcc(data.Dataset):
83 | def __init__(self,
84 | args,
85 | root,
86 | photometric_augmentations=False,
87 | dstype="train", backward=False):
88 |
89 | self._args = args
90 | self.backward = backward
91 |
92 | # -------------------------------------------------------------
93 | # filenames for all input images and target flows
94 | # -------------------------------------------------------------
95 | image1_filenames = sorted(glob(os.path.join(root, "*_img1.png")))
96 | image2_filenames = sorted(glob(os.path.join(root, "*_img2.png")))
97 | occ1_filenames = sorted(glob(os.path.join(root, "*_occ1.png")))
98 | occ2_filenames = sorted(glob(os.path.join(root, "*_occ2.png")))
99 | flow_f_filenames = sorted(glob(os.path.join(root, "*_flow.flo")))
100 | flow_b_filenames = sorted(glob(os.path.join(root, "*_flow_b.flo")))
101 | assert (len(image1_filenames) == len(image2_filenames))
102 | assert (len(image2_filenames) == len(occ1_filenames))
103 | assert (len(occ1_filenames) == len(occ2_filenames))
104 | assert (len(occ2_filenames) == len(flow_f_filenames))
105 | assert (len(flow_f_filenames) == len(flow_b_filenames))
106 |
107 | num_flows = len(flow_f_filenames)
108 |
109 | # -------------------------------------------------------------
110 | # Remove invalid validation indices
111 | # -------------------------------------------------------------
112 | validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)]
113 |
114 | # ----------------------------------------------------------
115 | # Construct list of indices for training/validation
116 | # ----------------------------------------------------------
117 | list_of_indices = None
118 | if dstype == "train":
119 | list_of_indices = [x for x in range(num_flows) if x not in validate_indices]
120 | elif dstype == "valid":
121 | list_of_indices = validate_indices
122 | elif dstype == "full":
123 | list_of_indices = range(num_flows)
124 | else:
125 |             raise ValueError("FlyingChairs: dstype '%s' unknown!" % dstype)
126 |
127 | # ----------------------------------------------------------
128 | # Save list of actual filenames for inputs and flows
129 | # ----------------------------------------------------------
130 | self._image_list = []
131 | self._flow_list = []
132 | self._occ_list = []
133 | for i in list_of_indices:
134 | flo_f = flow_f_filenames[i]
135 | flo_b = flow_b_filenames[i]
136 | im1 = image1_filenames[i]
137 | im2 = image2_filenames[i]
138 | self._image_list += [[im1, im2]]
139 | self._flow_list += [[flo_f, flo_b]]
140 | occ1 = occ1_filenames[i]
141 | occ2 = occ2_filenames[i]
142 | self._occ_list += [[occ1, occ2]]
143 |
144 | self._size = len(self._image_list)
145 | assert len(self._image_list) == len(self._flow_list)
146 | assert len(self._occ_list) == len(self._flow_list)
147 |
148 |
149 | # ----------------------------------------------------------
150 | # photometric_augmentations
151 | # ----------------------------------------------------------
152 | if photometric_augmentations:
153 | self._photometric_transform = transforms.ConcatTransformSplitChainer([
154 | # uint8 -> PIL
155 | vision_transforms.ToPILImage(),
156 | # PIL -> PIL : random hsv and contrast
157 | vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
158 | # PIL -> FloatTensor
159 | vision_transforms.transforms.ToTensor(),
160 | transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
161 | ], from_numpy=True, to_numpy=False)
162 |
163 | else:
164 | self._photometric_transform = transforms.ConcatTransformSplitChainer([
165 | # uint8 -> FloatTensor
166 | vision_transforms.transforms.ToTensor(),
167 | ], from_numpy=True, to_numpy=False)
168 |
169 | def __getitem__(self, index):
170 | index = index % self._size
171 |
172 | im1_filename = self._image_list[index][0]
173 | im2_filename = self._image_list[index][1]
174 | flo_f_filename = self._flow_list[index][0]
175 | flo_b_filename = self._flow_list[index][1]
176 | occ1_filename = self._occ_list[index][0]
177 | occ2_filename = self._occ_list[index][1]
178 |
179 | # read float32 images and flow
180 | im1_np0 = common.read_image_as_byte(im1_filename)
181 | im2_np0 = common.read_image_as_byte(im2_filename)
182 | flo_f_np0 = common.read_flo_as_float32(flo_f_filename)
183 | flo_b_np0 = common.read_flo_as_float32(flo_b_filename)
184 | occ1_np0 = common.read_occ_image_as_float32(occ1_filename)
185 | occ2_np0 = common.read_occ_image_as_float32(occ2_filename)
186 |
187 | # possibly apply photometric transformations
188 | im1, im2 = self._photometric_transform(im1_np0, im2_np0)
189 |
190 | # convert flow to FloatTensor
191 | flo_f = common.numpy2torch(flo_f_np0)
192 | flo_b = common.numpy2torch(flo_b_np0)
193 |
194 | # convert occ to FloatTensor
195 | occ1 = common.numpy2torch(occ1_np0)
196 | occ2 = common.numpy2torch(occ2_np0)
197 |
198 | # example filename
199 | basename = os.path.basename(im1_filename)[:5]
200 |
201 | if self.backward:
202 |             # swap the flows (and occlusions): backward becomes target 1, forward becomes target 2
203 | example_dict = {
204 | "input1": im1,
205 | "input2": im2,
206 | "target1": flo_b,
207 | "target2": flo_f,
208 | "target_occ1": occ2,
209 | "target_occ2": occ1,
210 | "index": index,
211 | "basename": basename
212 | }
213 | else:
214 | example_dict = {
215 | "input1": im1,
216 | "input2": im2,
217 | "target1": flo_f,
218 | "target2": flo_b,
219 | "target_occ1": occ1,
220 | "target_occ2": occ2,
221 | "index": index,
222 | "basename": basename
223 | }
224 |
225 | return example_dict
226 |
227 | def __len__(self):
228 | return self._size
229 |
230 |
231 | class FlyingChairsOccTrain(FlyingChairsOcc):
232 | def __init__(self,
233 | args,
234 | root,
235 | photometric_augmentations=True,
236 | backward=False):
237 | super(FlyingChairsOccTrain, self).__init__(
238 | args,
239 | root=root,
240 | photometric_augmentations=photometric_augmentations,
241 | dstype="train", backward=backward)
242 |
243 |
244 | class FlyingChairsOccValid(FlyingChairsOcc):
245 | def __init__(self,
246 | args,
247 | root,
248 | photometric_augmentations=False,
249 | backward=False):
250 | super(FlyingChairsOccValid, self).__init__(
251 | args,
252 | root=root,
253 | photometric_augmentations=photometric_augmentations,
254 | dstype="valid", backward=backward)
255 |
256 |
257 | class FlyingChairsOccFull(FlyingChairsOcc):
258 | def __init__(self,
259 | args,
260 | root,
261 | photometric_augmentations=False,
262 | backward=False):
263 | super(FlyingChairsOccFull, self).__init__(
264 | args,
265 | root=root,
266 | photometric_augmentations=photometric_augmentations,
267 | dstype="full", backward=backward)
268 |
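269 | # ----------------------------------------------------------------------
270 | # Minimal usage sketch (an illustration, not part of the pipeline; the
271 | # root path below is hypothetical):
272 | #
273 | # import torch.utils.data as data
274 | # dataset = FlyingChairsOccTrain(args=None, root="/data/FlyingChairsOcc/data")
275 | # loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
276 | # batch = next(iter(loader))
277 | # # images are (B, 3, H, W); forward flow targets are (B, 2, H, W)
278 | # print(batch["input1"].shape, batch["target1"].shape)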
--------------------------------------------------------------------------------
/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | ## Portions of code copyright 2018 Jochen Gast
2 |
3 | from __future__ import absolute_import, division, print_function
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def image_random_gamma(image, min_gamma=0.7, max_gamma=1.5, clip_image=False):
10 | gamma = np.random.uniform(min_gamma, max_gamma)
11 | adjusted = torch.pow(image, gamma)
12 | if clip_image:
13 | adjusted.clamp_(0.0, 1.0)
14 | return adjusted
15 |
16 |
17 | class RandomGamma:
18 | def __init__(self, min_gamma=0.7, max_gamma=1.5, clip_image=False):
19 | self._min_gamma = min_gamma
20 | self._max_gamma = max_gamma
21 | self._clip_image = clip_image
22 |
23 | def __call__(self, image):
24 | return image_random_gamma(
25 | image,
26 | min_gamma=self._min_gamma,
27 | max_gamma=self._max_gamma,
28 | clip_image=self._clip_image)
29 |
30 |
31 | # ------------------------------------------------------------------
32 | # Allow transformation chains of the type:
33 | # im1, im2, .... = transform(im1, im2, ...)
34 | # ------------------------------------------------------------------
35 | class TransformChainer:
36 | def __init__(self, list_of_transforms):
37 | self._list_of_transforms = list_of_transforms
38 |
39 | def __call__(self, *args):
40 | list_of_args = list(args)
41 | for transform in self._list_of_transforms:
42 | list_of_args = [transform(arg) for arg in list_of_args]
43 | if len(args) == 1:
44 | return list_of_args[0]
45 | else:
46 | return list_of_args
47 |
48 |
49 | # ------------------------------------------------------------------
50 | # Allow transformation chains of the type:
51 | # im1, im2, .... = split( transform( concatenate(im1, im2, ...) ))
52 | # ------------------------------------------------------------------
53 | class ConcatTransformSplitChainer:
54 | def __init__(self, list_of_transforms, from_numpy=True, to_numpy=False):
55 | self._chainer = TransformChainer(list_of_transforms)
56 | self._from_numpy = from_numpy
57 | self._to_numpy = to_numpy
58 |
59 | def __call__(self, *args):
60 | num_splits = len(args)
61 |
62 | if self._from_numpy:
63 | concatenated = np.concatenate(args, axis=0)
64 | else:
65 | concatenated = torch.cat(args, dim=1)
66 |
67 | transformed = self._chainer(concatenated)
68 |
69 | if self._to_numpy:
70 | split = np.split(transformed, indices_or_sections=num_splits, axis=0)
71 | else:
72 | split = torch.chunk(transformed, num_splits, dim=1)
73 |
74 | return split
75 |
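76 | # ----------------------------------------------------------------------
77 | # Illustrative sketch (not used elsewhere in the repo): because the inputs
78 | # are concatenated before the transforms run, random transforms draw their
79 | # parameters once and apply them identically to every input.
80 | #
81 | # import torch
82 | # chain = ConcatTransformSplitChainer(
83 | #     [RandomGamma(0.7, 1.5, clip_image=True)],
84 | #     from_numpy=False, to_numpy=False)
85 | # im1 = torch.rand(1, 3, 64, 64)
86 | # im2 = torch.rand(1, 3, 64, 64)
87 | # t1, t2 = chain(im1, im2)  # the same gamma is applied to both images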
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from glob import glob
4 |
5 | import scipy.misc
6 | import numpy as np
7 | import torch
8 |
9 | from torchvision import transforms as vision_transforms
10 | import models
11 | from datasets import common
12 | from configuration import ModelAndLoss
13 |
14 | from utils.flow import flow_to_png_middlebury, write_flow
15 |
16 | import pylab as pl
17 | pl.interactive(True)
18 |
19 | import argparse
20 |
21 | '''
22 | Example (will save results in ./output/):
23 | python inference.py \
24 | --model StarFlow \
25 | --checkpoint saved_checkpoint/StarFlow_things/checkpoint_best.ckpt \
26 | --data-root /data/mpisintelcomplete/training/final/ambush_6/ \
27 | --file-list frame_0004.png frame_0005.png frame_0006.png frame_0007.png
28 | '''
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--savedir", type=str, default="./output")
32 | parser.add_argument("--data-root", type=str,
33 | default="./")
34 | parser.add_argument('--file-list', nargs='*', default=[], type=str)
35 |
36 | parser.add_argument("--model", type=str, default='StarFlow')
37 | parser.add_argument('--checkpoint', dest='checkpoint', default=None,
38 | metavar='PATH', help='path to pre-trained model')
39 |
40 | parser.add_argument('--device', type=int, default=0)
41 | parser.add_argument("--no-cuda", action="store_true")
42 |
43 | args = parser.parse_args()
44 |
45 | # use cuda GPU
46 | use_cuda = (not args.no_cuda) and torch.cuda.is_available()
47 |
48 | # ---------------------
49 | # Load pretrained model
50 | # ---------------------
51 | MODEL = models.__dict__[args.model]
52 | net = ModelAndLoss(None, MODEL(None), None)
53 | checkpoint_with_state = torch.load(args.checkpoint,
54 | map_location=lambda storage,
55 | loc: storage.cuda(args.device))
56 | state_dict = checkpoint_with_state['state_dict']
57 | net.load_state_dict(state_dict)
58 | net.eval()
59 | net.cuda()
60 |
61 | # -------------------
62 | # Load image sequence
63 | # -------------------
64 | if not os.path.exists(args.data_root):
65 | raise ValueError("data-root: {} not found".format(args.data_root))
66 | if len(args.file_list) == 0:
67 | raise ValueError("file-list empty")
68 | elif len(args.file_list) == 1:
69 | path = os.path.join(args.data_root, args.file_list[0])
70 | list_path_imgs = sorted(glob(path))
71 | if len(list_path_imgs) == 0:
72 | raise ValueError("no data were found")
73 | else:
74 | list_path_imgs = [os.path.join(args.data_root, file_name)
75 | for file_name in args.file_list]
76 | for path_im in list_path_imgs:
77 | if not os.path.isfile(path_im):
78 | raise ValueError("file {} not found".format(path_im))
79 | img_reader = common.read_image_as_byte
80 | #flo_reader = common.read_flo_as_float32
81 | imgs_np = [img_reader(path) for path in list_path_imgs]
82 | if imgs_np[0].squeeze().ndim == 2:
83 | imgs_np = [np.dstack([im]*3) for im in imgs_np]
84 | to_tensor = vision_transforms.ToTensor()
85 | images = [to_tensor(im).unsqueeze(0).cuda() for im in imgs_np]
86 | input_dict = {'input_images':images}
87 |
88 | # ---------------
89 | # Flow estimation
90 | # ---------------
91 | with torch.no_grad():
92 | output_dict = net._model(input_dict)
93 |
94 | estimated_flow = output_dict['flow']
95 |
96 | if len(imgs_np) > 2:
97 | estimated_flow_np = estimated_flow[:,0].cpu().numpy()
98 | estimated_flow_np = [flow for flow in estimated_flow_np]
99 | else:
100 | estimated_flow_np = [estimated_flow[0].cpu().numpy()]
101 |
102 |
103 | # ------------
104 | # Save results
105 | # ------------
106 | if not os.path.exists(os.path.join(args.savedir, "visu")):
107 | os.makedirs(os.path.join(args.savedir, "visu"))
108 | if not os.path.exists(os.path.join(args.savedir, "flow")):
109 | os.makedirs(os.path.join(args.savedir, "flow"))
110 | for t in range(len(imgs_np)-1):
111 | flow_visu = flow_to_png_middlebury(estimated_flow_np[t])
112 | basename = os.path.splitext(os.path.basename(list_path_imgs[t]))[0]
113 | file_name_flow_visu = os.path.join(args.savedir, 'visu',
114 | basename + '_flow_visu.png')
115 | file_name_flow = os.path.join(args.savedir, 'flow',
116 | basename + '_flow.flo')
117 | scipy.misc.imsave(file_name_flow_visu, flow_visu)
118 | write_flow(file_name_flow, estimated_flow_np[t].swapaxes(0, 1).swapaxes(1, 2))
119 |
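120 | # Note: scipy.misc.imsave was removed in SciPy 1.2; on newer environments,
121 | # imageio.imwrite(file_name_flow_visu, flow_visu) is a drop-in replacement
122 | # (assumption: the imageio package is installed). Note also that use_cuda
123 | # is computed above but never checked: the script requires a CUDA device
124 | # regardless, since the correlation package is CUDA-only.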
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd ./models/correlation_package
3 | python setup.py install
4 | cd ..
5 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | ## Portions of code copyright 2018 Jochen Gast
2 |
3 | from __future__ import absolute_import, division, print_function
4 |
5 | import colorama
6 | import logging
7 | import os
8 | import re
9 | import tools
10 | import sys
11 |
12 |
13 | def get_default_logging_format(colorize=False, brackets=False):
14 | style = colorama.Style.DIM if colorize else ''
15 | # color = colorama.Fore.CYAN if colorize else ''
16 | color = colorama.Fore.WHITE if colorize else ''
17 | reset = colorama.Style.RESET_ALL if colorize else ''
18 | if brackets:
19 | result = "{}{}[%(asctime)s]{} %(message)s".format(style, color, reset)
20 | else:
21 | result = "{}{}%(asctime)s{} %(message)s".format(style, color, reset)
22 | return result
23 |
24 |
25 | def get_default_logging_datefmt():
26 | return "%Y-%m-%d %H:%M:%S"
27 |
28 |
29 | def log_module_info(module):
30 | lines = module.__str__().split("\n")
31 | for line in lines:
32 | logging.info(line)
33 |
34 |
35 | class LogbookFormatter(logging.Formatter):
36 | def __init__(self, fmt=None, datefmt=None):
37 | super(LogbookFormatter, self).__init__(fmt=fmt, datefmt=datefmt)
38 | self._re = re.compile(r"\033\[[0-9]+m")
39 |
40 | def remove_colors_from_msg(self, msg):
41 | msg = re.sub(self._re, "", msg)
42 | return msg
43 |
44 | def format(self, record=None):
45 | record.msg = self.remove_colors_from_msg(record.msg)
46 | return super(LogbookFormatter, self).format(record)
47 |
48 |
49 | class ConsoleFormatter(logging.Formatter):
50 | def __init__(self, fmt=None, datefmt=None):
51 | super(ConsoleFormatter, self).__init__(fmt=fmt, datefmt=datefmt)
52 |
53 | def format(self, record=None):
54 | indent = sys.modules[__name__].global_indent
55 | record.msg = " " * indent + record.msg
56 | return super(ConsoleFormatter, self).format(record)
57 |
58 |
59 | class SkipLogbookFilter(logging.Filter):
60 | def filter(self, record):
61 | return record.levelno != logging.LOGBOOK
62 |
63 |
64 | def configure_logging(filename=None):
65 | # set global indent level
66 | sys.modules[__name__].global_indent = 0
67 |
68 |     # add custom LOGBOOK logging level (filtered out of console output below)
69 | tools.addLoggingLevel("LOGBOOK", 1000)
70 |
71 | # create logger
72 | root_logger = logging.getLogger("")
73 | root_logger.setLevel(logging.INFO)
74 |
75 | # create console handler and set level to debug
76 | console = logging.StreamHandler()
77 | console.setLevel(logging.INFO)
78 | fmt = get_default_logging_format(colorize=True, brackets=False)
79 | datefmt = get_default_logging_datefmt()
80 | formatter = ConsoleFormatter(fmt=fmt, datefmt=datefmt)
81 | console.setFormatter(formatter)
82 |
83 | # Skip logging.tqdm requests for console outputs
84 | skip_logbook_filter = SkipLogbookFilter()
85 | console.addFilter(skip_logbook_filter)
86 |
87 | # add console to root_logger
88 | root_logger.addHandler(console)
89 |
90 | # add logbook
91 | if filename is not None:
92 | # ensure dir
93 | d = os.path.dirname(filename)
94 | if not os.path.exists(d):
95 | os.makedirs(d)
96 |
97 | # --------------------------------------------------------------------------------------
98 | # Configure handler that removes color codes from logbook
99 | # --------------------------------------------------------------------------------------
100 | logbook = logging.FileHandler(filename=filename, mode="a", encoding="utf-8")
101 | logbook.setLevel(logging.INFO)
102 | fmt = get_default_logging_format(colorize=False, brackets=True)
103 | logbook_formatter = LogbookFormatter(fmt=fmt, datefmt=datefmt)
104 | logbook.setFormatter(logbook_formatter)
105 | root_logger.addHandler(logbook)
106 |
107 |
108 | class LoggingBlock:
109 | def __init__(self, title, emph=False):
110 | self._emph = emph
111 | bright = colorama.Style.BRIGHT
112 | cyan = colorama.Fore.CYAN
113 | reset = colorama.Style.RESET_ALL
114 | if emph:
115 | logging.info("%s==>%s %s%s%s" % (cyan, reset, bright, title, reset))
116 | else:
117 | logging.info(title)
118 |
119 | def __enter__(self):
120 | sys.modules[__name__].global_indent += 2
121 | return self
122 |
123 | def __exit__(self, exc_type, exc_value, traceback):
124 | sys.modules[__name__].global_indent -= 2
125 |
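126 | # ----------------------------------------------------------------------
127 | # Minimal usage sketch (for illustration only): nested LoggingBlocks
128 | # indent console messages by two spaces per level.
129 | #
130 | # configure_logging(filename=None)
131 | # with LoggingBlock("Setup", emph=True):
132 | #     logging.info("inside the block")  # printed with a two-space indent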
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import subprocess
5 | import commandline
6 | import configuration as config
7 | import runtime
8 | import logger
9 | import logging
10 | import tools
11 | import torch
12 |
13 | def main():
14 |
15 | # Change working directory
16 | os.chdir(os.path.dirname(os.path.realpath(__file__)))
17 |
18 | # Parse commandline arguments
19 | args = commandline.setup_logging_and_parse_arguments(blocktitle="Commandline Arguments")
20 |
21 | # set cuda device:
22 | if args.cuda:
23 | torch.cuda.set_device(args.device)
24 |
25 | # Set random seed, possibly on Cuda
26 | config.configure_random_seed(args)
27 |
28 | # DataLoader
29 | train_loader, validation_loader, inference_loader = config.configure_data_loaders(args)
30 | success = any(loader is not None for loader in [train_loader, validation_loader, inference_loader])
31 | if not success:
32 | logging.info("No dataset could be loaded successfully. Please check dataset paths!")
33 | quit()
34 |
35 | if args.resume:
36 | args.checkpoint = os.path.join(args.save, 'checkpoint_latest.ckpt')
37 | args.optim_checkpoint = os.path.join(args.save, 'optim_state_dict_checkpoint.pth')
38 | checkpoint_with_state = torch.load(args.checkpoint,
39 | map_location=lambda storage,
40 | loc: storage.cuda(args.device))
41 | args.start_epoch = checkpoint_with_state['epoch'] + 1
42 |
43 | # Configure data augmentation
44 | training_augmentation, validation_augmentation = config.configure_runtime_augmentations(args)
45 |
46 | # Configure model and loss
47 | model_and_loss = config.configure_model_and_loss(args)
48 |
49 | # Resume from checkpoint if available
50 | checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(args, model_and_loss)
51 |
52 | # Checkpoint and save directory
53 | with logger.LoggingBlock("Save Directory", emph=True):
54 | logging.info("Save directory: %s" % args.save)
55 | if not os.path.exists(args.save):
56 | os.makedirs(args.save)
57 |
58 | # # Multi-GPU automation
59 | # with logger.LoggingBlock("Multi GPU", emph=True):
60 | # if torch.cuda.device_count() > 1:
61 | # logging.info("Let's use %d GPUs!" % torch.cuda.device_count())
62 | # model_and_loss._model = torch.nn.DataParallel(model_and_loss._model)
63 | # else:
64 | # logging.info("Let's use %d GPU!" % torch.cuda.device_count())
65 |
66 | # Configure optimizer
67 | optimizer = config.configure_optimizer(args, model_and_loss)
68 |
69 | # Configure learning rate
70 | lr_scheduler = config.configure_lr_scheduler(args, optimizer)
71 |
72 | # If this is just an evaluation: overwrite savers and epochs
73 | if args.evaluation:
74 | args.start_epoch = 1
75 | args.total_epochs = 1
76 | train_loader = None
77 | checkpoint_saver = None
78 | optimizer = None
79 | lr_scheduler = None
80 |
81 | # Cuda optimization
82 | if args.cuda:
83 | torch.backends.cudnn.benchmark = True
84 |
85 | # Kickoff training, validation and/or testing
86 | return runtime.exec_runtime(
87 | args,
88 | checkpoint_saver=checkpoint_saver,
89 | model_and_loss=model_and_loss,
90 | optimizer=optimizer,
91 | lr_scheduler=lr_scheduler,
92 | train_loader=train_loader,
93 | validation_loader=validation_loader,
94 | inference_loader=inference_loader,
95 | training_augmentation=training_augmentation,
96 | validation_augmentation=validation_augmentation)
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
/models/IRR_PWC.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .pwc_modules import conv, upsample2d_as, rescale_flow, initialize_msra
7 | from .pwc_modules import WarpingLayer, FeatureExtractor
8 | from .pwc_modules import ContextNetwork, FlowEstimatorDense
9 | from .pwc_modules import OccContextNetwork, OccEstimatorDense
10 | from .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc
11 | from .correlation_package.correlation import Correlation
12 |
13 | import copy
14 |
15 |
16 | class PWCNet(nn.Module):
17 | def __init__(self, args, div_flow=0.05):
18 | super(PWCNet, self).__init__()
19 | self.args = args
20 | self._div_flow = div_flow
21 | self.search_range = 4
22 | self.num_chs = [3, 16, 32, 64, 96, 128, 196]
23 | self.output_level = 4
24 | self.num_levels = 7
25 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
26 |
27 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
28 | self.warping_layer = WarpingLayer()
29 |
30 | self.dim_corr = (self.search_range * 2 + 1) ** 2
31 | self.num_ch_in_flo = self.dim_corr + 32 + 2
32 | self.num_ch_in_occ = self.dim_corr + 32 + 1
33 |
34 | self.flow_estimators = FlowEstimatorDense(self.num_ch_in_flo)
35 | self.context_networks = ContextNetwork(self.num_ch_in_flo + 448 + 2)
36 | self.occ_estimators = OccEstimatorDense(self.num_ch_in_occ)
37 | self.occ_context_networks = OccContextNetwork(self.num_ch_in_occ + 448 + 1)
38 | self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1)
39 |
40 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
41 | conv(128, 32, kernel_size=1, stride=1, dilation=1),
42 | conv(96, 32, kernel_size=1, stride=1, dilation=1),
43 | conv(64, 32, kernel_size=1, stride=1, dilation=1)])
44 |
45 | self.conv_1x1_1 = conv(16, 3, kernel_size=1, stride=1, dilation=1)
46 |
47 | self.refine_flow = RefineFlow(2 + 1 + 32)
48 | self.refine_occ = RefineOcc(1 + 32 + 32)
49 |
50 | initialize_msra(self.modules())
51 |
52 | def forward(self, input_dict):
53 |
54 | x1_raw = input_dict['input1']
55 | x2_raw = input_dict['input2']
56 | batch_size, _, height_im, width_im = x1_raw.size()
57 |
58 |         # the raw images are appended as the final, full-resolution pyramid level
59 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
60 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
61 |
62 | # outputs
63 | output_dict = {}
64 | output_dict_eval = {}
65 | flows = []
66 | occs = []
67 |
68 | _, _, h_x1, w_x1, = x1_pyramid[0].size()
69 | flow_f = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda()
70 | flow_b = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda()
71 | occ_f = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda()
72 | occ_b = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda()
73 |
74 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
75 |
76 | if l <= self.output_level:
77 |
78 | # warping
79 | if l == 0:
80 | x2_warp = x2
81 | x1_warp = x1
82 | else:
83 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
84 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
85 | occ_f = upsample2d_as(occ_f, x1, mode="bilinear")
86 | occ_b = upsample2d_as(occ_b, x2, mode="bilinear")
87 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
88 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
89 |
90 | # correlation
91 | out_corr_f = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp)
92 | out_corr_b = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x2, x1_warp)
93 | out_corr_relu_f = self.leakyRELU(out_corr_f)
94 | out_corr_relu_b = self.leakyRELU(out_corr_b)
95 |
96 | if l != self.output_level:
97 | x1_1by1 = self.conv_1x1[l](x1)
98 | x2_1by1 = self.conv_1x1[l](x2)
99 | else:
100 | x1_1by1 = x1
101 | x2_1by1 = x2
102 |
103 | # concat and estimate flow
104 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)
105 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)
106 |
107 | x_intm_f, flow_res_f = self.flow_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f], dim=1))
108 | x_intm_b, flow_res_b = self.flow_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b], dim=1))
109 | flow_est_f = flow_f + flow_res_f
110 | flow_est_b = flow_b + flow_res_b
111 |
112 | flow_cont_f = flow_est_f + self.context_networks(torch.cat([x_intm_f, flow_est_f], dim=1))
113 | flow_cont_b = flow_est_b + self.context_networks(torch.cat([x_intm_b, flow_est_b], dim=1))
114 |
115 | # occ estimation
116 | x_intm_occ_f, occ_res_f = self.occ_estimators(torch.cat([out_corr_relu_f, x1_1by1, occ_f], dim=1))
117 | x_intm_occ_b, occ_res_b = self.occ_estimators(torch.cat([out_corr_relu_b, x2_1by1, occ_b], dim=1))
118 | occ_est_f = occ_f + occ_res_f
119 | occ_est_b = occ_b + occ_res_b
120 |
121 | occ_cont_f = occ_est_f + self.occ_context_networks(torch.cat([x_intm_occ_f, occ_est_f], dim=1))
122 | occ_cont_b = occ_est_b + self.occ_context_networks(torch.cat([x_intm_occ_b, occ_est_b], dim=1))
123 |
124 | # refinement
125 | img1_resize = upsample2d_as(x1_raw, flow_f, mode="bilinear")
126 | img2_resize = upsample2d_as(x2_raw, flow_b, mode="bilinear")
127 | img2_warp = self.warping_layer(img2_resize, rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)
128 | img1_warp = self.warping_layer(img1_resize, rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)
129 |
130 | # flow refine
131 | flow_f = self.refine_flow(flow_cont_f.detach(), img1_resize - img2_warp, x1_1by1)
132 | flow_b = self.refine_flow(flow_cont_b.detach(), img2_resize - img1_warp, x2_1by1)
133 |
134 | flow_cont_f = rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False)
135 | flow_cont_b = rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False)
136 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)
137 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)
138 |
139 | # occ refine
140 | x2_1by1_warp = self.warping_layer(x2_1by1, flow_f, height_im, width_im, self._div_flow)
141 | x1_1by1_warp = self.warping_layer(x1_1by1, flow_b, height_im, width_im, self._div_flow)
142 |
143 | occ_f = self.refine_occ(occ_cont_f.detach(), x1_1by1, x1_1by1 - x2_1by1_warp)
144 | occ_b = self.refine_occ(occ_cont_b.detach(), x2_1by1, x2_1by1 - x1_1by1_warp)
145 |
146 | flows.append([flow_cont_f, flow_cont_b, flow_f, flow_b])
147 | occs.append([occ_cont_f, occ_cont_b, occ_f, occ_b])
148 |
149 | else:
150 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
151 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
152 | flows.append([flow_f, flow_b])
153 |
154 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
155 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
156 | flow_b_warp = self.warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow)
157 | flow_f_warp = self.warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow)
158 |
159 | if l != self.num_levels-1:
160 | x1_in = self.conv_1x1_1(x1)
161 | x2_in = self.conv_1x1_1(x2)
162 | x1_w_in = self.conv_1x1_1(x1_warp)
163 | x2_w_in = self.conv_1x1_1(x2_warp)
164 | else:
165 | x1_in = x1
166 | x2_in = x2
167 | x1_w_in = x1_warp
168 | x2_w_in = x2_warp
169 |
170 | occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([x1_in, x2_w_in, flow_f, flow_b_warp], dim=1))
171 | occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([x2_in, x1_w_in, flow_b, flow_f_warp], dim=1))
172 |
173 | occs.append([occ_f, occ_b])
174 |
175 | output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
176 | output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode="bilinear")
177 | output_dict['flow'] = flows
178 | output_dict['occ'] = occs
179 |
180 | if self.training:
181 | return output_dict
182 | else:
183 | return output_dict_eval
184 |
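185 | # ----------------------------------------------------------------------
186 | # Output layout (descriptive note): in training mode, output_dict['flow']
187 | # holds one entry per pyramid level; levels up to output_level store
188 | # [flow_cont_f, flow_cont_b, flow_f, flow_b] (context-network and refined
189 | # estimates, forward and backward), while the remaining levels store the
190 | # upsampled [flow_f, flow_b]. In eval mode only the final forward flow
191 | # (rescaled by 1 / div_flow) and occlusion map, upsampled to the input
192 | # resolution, are returned.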
--------------------------------------------------------------------------------
/models/IRR_PWC_occ_joint.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .pwc_modules import conv, upsample2d_as, rescale_flow, initialize_msra
7 | from .pwc_modules import WarpingLayer, FeatureExtractor
8 | from .pwc_modules import FlowAndOccEstimatorDense, FlowAndOccContextNetwork
9 | from .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc
10 | from .correlation_package.correlation import Correlation
11 |
12 | import copy
13 |
14 |
15 | class PWCNet(nn.Module):
16 | def __init__(self, args, div_flow=0.05):
17 | super(PWCNet, self).__init__()
18 | self.args = args
19 | self._div_flow = div_flow
20 | self.search_range = 4
21 | self.num_chs = [3, 16, 32, 64, 96, 128, 196]
22 | self.output_level = 4
23 | self.num_levels = 7
24 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
25 |
26 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
27 | self.warping_layer = WarpingLayer()
28 |
29 | self.dim_corr = (self.search_range * 2 + 1) ** 2
30 | self.num_ch_in = self.dim_corr + 32 + 2 + 1
31 |
32 | self.flow_and_occ_estimators = FlowAndOccEstimatorDense(self.num_ch_in)
33 | self.context_networks = FlowAndOccContextNetwork(self.num_ch_in + 448 + 2 + 1)
34 | self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1)
35 |
36 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
37 | conv(128, 32, kernel_size=1, stride=1, dilation=1),
38 | conv(96, 32, kernel_size=1, stride=1, dilation=1),
39 | conv(64, 32, kernel_size=1, stride=1, dilation=1)])
40 |
41 | self.conv_1x1_1 = conv(16, 3, kernel_size=1, stride=1, dilation=1)
42 |
43 | self.refine_flow = RefineFlow(2 + 1 + 32)
44 | self.refine_occ = RefineOcc(1 + 32 + 32)
45 |
46 | initialize_msra(self.modules())
47 |
48 | def forward(self, input_dict):
49 |
50 | x1_raw = input_dict['input1']
51 | x2_raw = input_dict['input2']
52 | batch_size, _, height_im, width_im = x1_raw.size()
53 |
54 |         # the raw images are appended as the final, full-resolution pyramid level
55 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
56 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
57 |
58 | # outputs
59 | output_dict = {}
60 | output_dict_eval = {}
61 | flows = []
62 | occs = []
63 |
64 | _, _, h_x1, w_x1, = x1_pyramid[0].size()
65 | flow_f = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda()
66 | flow_b = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda()
67 | occ_f = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda()
68 | occ_b = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda()
69 |
70 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
71 |
72 | if l <= self.output_level:
73 |
74 | # warping
75 | if l == 0:
76 | x2_warp = x2
77 | x1_warp = x1
78 | else:
79 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
80 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
81 | occ_f = upsample2d_as(occ_f, x1, mode="bilinear")
82 | occ_b = upsample2d_as(occ_b, x2, mode="bilinear")
83 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
84 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
85 |
86 | # correlation
87 | out_corr_f = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp)
88 | out_corr_b = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x2, x1_warp)
89 | out_corr_relu_f = self.leakyRELU(out_corr_f)
90 | out_corr_relu_b = self.leakyRELU(out_corr_b)
91 |
92 | if l != self.output_level:
93 | x1_1by1 = self.conv_1x1[l](x1)
94 | x2_1by1 = self.conv_1x1[l](x2)
95 | else:
96 | x1_1by1 = x1
97 | x2_1by1 = x2
98 |
99 | # concat and estimate flow and occ
100 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)
101 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)
102 |
103 | x_intm_f, flow_res_f, occ_res_f = self.flow_and_occ_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f, occ_f], dim=1))
104 | x_intm_b, flow_res_b, occ_res_b = self.flow_and_occ_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b, occ_b], dim=1))
105 | flow_est_f = flow_f + flow_res_f
106 | flow_est_b = flow_b + flow_res_b
107 | occ_est_f = occ_f + occ_res_f
108 | occ_est_b = occ_b + occ_res_b
109 |
110 | flow_fine_f, occ_fine_f = self.context_networks(torch.cat([x_intm_f, flow_est_f, occ_est_f], dim=1))
111 | flow_fine_b, occ_fine_b = self.context_networks(torch.cat([x_intm_b, flow_est_b, occ_est_b], dim=1))
112 |
113 | flow_cont_f = flow_est_f + flow_fine_f
114 | flow_cont_b = flow_est_b + flow_fine_b
115 | occ_cont_f = occ_est_f + occ_fine_f
116 | occ_cont_b = occ_est_b + occ_fine_b
117 |
118 | # refinement
119 | img1_resize = upsample2d_as(x1_raw, flow_f, mode="bilinear")
120 | img2_resize = upsample2d_as(x2_raw, flow_b, mode="bilinear")
121 | img2_warp = self.warping_layer(img2_resize, rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)
122 | img1_warp = self.warping_layer(img1_resize, rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)
123 |
124 | # flow refine
125 | flow_f = self.refine_flow(flow_cont_f.detach(), img1_resize - img2_warp, x1_1by1)
126 | flow_b = self.refine_flow(flow_cont_b.detach(), img2_resize - img1_warp, x2_1by1)
127 |
128 | flow_cont_f = rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False)
129 | flow_cont_b = rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False)
130 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)
131 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)
132 |
133 | # occ refine
134 | x2_1by1_warp = self.warping_layer(x2_1by1, flow_f, height_im, width_im, self._div_flow)
135 | x1_1by1_warp = self.warping_layer(x1_1by1, flow_b, height_im, width_im, self._div_flow)
136 |
137 | occ_f = self.refine_occ(occ_cont_f.detach(), x1_1by1, x1_1by1 - x2_1by1_warp)
138 | occ_b = self.refine_occ(occ_cont_b.detach(), x2_1by1, x2_1by1 - x1_1by1_warp)
139 |
140 | flows.append([flow_cont_f, flow_cont_b, flow_f, flow_b])
141 | occs.append([occ_cont_f, occ_cont_b, occ_f, occ_b])
142 |
143 | else:
144 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
145 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
146 | flows.append([flow_f, flow_b])
147 |
148 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
149 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
150 | flow_b_warp = self.warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow)
151 | flow_f_warp = self.warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow)
152 |
153 | if l != self.num_levels-1:
154 | x1_in = self.conv_1x1_1(x1)
155 | x2_in = self.conv_1x1_1(x2)
156 | x1_w_in = self.conv_1x1_1(x1_warp)
157 | x2_w_in = self.conv_1x1_1(x2_warp)
158 | else:
159 | x1_in = x1
160 | x2_in = x2
161 | x1_w_in = x1_warp
162 | x2_w_in = x2_warp
163 |
164 | occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([x1_in, x2_w_in, flow_f, flow_b_warp], dim=1))
165 | occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([x2_in, x1_w_in, flow_b, flow_f_warp], dim=1))
166 |
167 | occs.append([occ_f, occ_b])
168 |
169 | output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
170 | output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode="bilinear")
171 | output_dict['flow'] = flows
172 | output_dict['occ'] = occs
173 |
174 | if self.training:
175 | return output_dict
176 | else:
177 | return output_dict_eval
178 |
--------------------------------------------------------------------------------
/models/STAR.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .pwc_modules import conv, upsample2d_as, rescale_flow, initialize_msra
7 | from .pwc_modules import WarpingLayer, FeatureExtractor
8 | from .pwc_modules import FlowAndOccContextNetwork, FlowAndOccEstimatorDense
9 | from .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc
10 | from .correlation_package.correlation import Correlation
11 |
12 | import copy
13 |
14 |
15 | class StarFlow(nn.Module):
16 | def __init__(self, args, div_flow=0.05):
17 | super(StarFlow, self).__init__()
18 | self.args = args
19 | self._div_flow = div_flow
20 | self.search_range = 4
21 | self.num_chs = [3, 16, 32, 64, 96, 128, 196]
22 | self.output_level = 4
23 | self.num_levels = 7
24 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
25 |
26 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
27 | self.warping_layer = WarpingLayer()
28 |
29 | self.dim_corr = (self.search_range * 2 + 1) ** 2
30 | self.num_ch_in = self.dim_corr + 32 + 2 + 1
31 |
32 | self.flow_and_occ_estimators = FlowAndOccEstimatorDense(2 * self.num_ch_in)
33 | self.context_networks = FlowAndOccContextNetwork(2 * self.num_ch_in + 448 + 2 + 1)
34 |
35 | self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1)
36 |
37 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
38 | conv(128, 32, kernel_size=1, stride=1, dilation=1),
39 | conv(96, 32, kernel_size=1, stride=1, dilation=1),
40 | conv(64, 32, kernel_size=1, stride=1, dilation=1)])
41 |
42 | self.conv_1x1_1 = conv(16, 3, kernel_size=1, stride=1, dilation=1)
43 |
44 | self.conv_1x1_time = conv(2 * self.num_ch_in + 448, self.num_ch_in, kernel_size=1, stride=1, dilation=1)
45 |
46 | self.refine_flow = RefineFlow(2 + 1 + 32)
47 | self.refine_occ = RefineOcc(1 + 32 + 32)
48 |
49 | initialize_msra(self.modules())
50 |
51 | def forward(self, input_dict):
52 |
53 | if 'input_images' in input_dict.keys():
54 | list_imgs = input_dict['input_images']
55 | else:
56 | x1_raw = input_dict['input1']
57 | x2_raw = input_dict['input2']
58 | list_imgs = [x1_raw, x2_raw]
59 |
60 | _, _, height_im, width_im = list_imgs[0].size()
61 |
62 |         # the raw images are appended as the final, full-resolution pyramid level
63 |         list_pyramids = []  # indices: [time][level]
64 | for im in list_imgs:
65 | list_pyramids.append(self.feature_pyramid_extractor(im) + [im])
66 |
67 | # outputs
68 | output_dict = {}
69 | output_dict_eval = {}
70 |         flows_f = []  # indices: [level][time]
71 |         flows_b = []  # indices: [level][time]
72 | occs_f = []
73 | occs_b = []
74 | flows_coarse_f = []
75 | occs_coarse_f = []
76 | for l in range(len(list_pyramids[0])):
77 | flows_f.append([])
78 | flows_b.append([])
79 | occs_f.append([])
80 | occs_b.append([])
81 | for l in range(self.output_level + 1):
82 | flows_coarse_f.append([])
83 | occs_coarse_f.append([])
84 |
85 | # init
86 | b_size, _, h_x1, w_x1, = list_pyramids[0][0].size()
87 | init_dtype = list_pyramids[0][0].dtype
88 | init_device = list_pyramids[0][0].device
89 | flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
90 | flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
91 | occ_f = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
92 | occ_b = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
93 | previous_features = []
94 |
95 | for i in range(len(list_imgs) - 1):
96 | x1_pyramid, x2_pyramid = list_pyramids[i:i+2]
97 |
98 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
99 |
100 | if l <= self.output_level:
101 | if i == 0:
102 | bs_, _, h_, w_, = list_pyramids[0][l].size()
103 | previous_features.append(torch.zeros(bs_, self.num_ch_in, h_, w_, dtype=init_dtype, device=init_device).float())
104 |
105 | # warping
106 | if l == 0:
107 | x2_warp = x2
108 | x1_warp = x1
109 | else:
110 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
111 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
112 | occ_f = upsample2d_as(occ_f, x1, mode="bilinear")
113 | occ_b = upsample2d_as(occ_b, x2, mode="bilinear")
114 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
115 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
116 |
117 | # correlation
118 | out_corr_f = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp)
119 | out_corr_b = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x2, x1_warp)
120 | out_corr_relu_f = self.leakyRELU(out_corr_f)
121 | out_corr_relu_b = self.leakyRELU(out_corr_b)
122 |
123 | if l != self.output_level:
124 | x1_1by1 = self.conv_1x1[l](x1)
125 | x2_1by1 = self.conv_1x1[l](x2)
126 | else:
127 | x1_1by1 = x1
128 | x2_1by1 = x2
129 |
130 |                     if i > 0:  # temporal connection
131 | previous_features[l] = self.warping_layer(previous_features[l],
132 | flows_b[l][-1], height_im, width_im, self._div_flow)
133 |
134 | # Flow and occlusions estimation
135 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)
136 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)
137 |
138 | features = torch.cat([previous_features[l], out_corr_relu_f, x1_1by1, flow_f, occ_f], 1)
139 | features_b = torch.cat([torch.zeros_like(previous_features[l]), out_corr_relu_b, x2_1by1, flow_b, occ_b], 1)
140 |
141 | x_intm_f, flow_res_f, occ_res_f = self.flow_and_occ_estimators(features)
142 | flow_est_f = flow_f + flow_res_f
143 | occ_est_f = occ_f + occ_res_f
144 | with torch.no_grad():
145 | x_intm_b, flow_res_b, occ_res_b = self.flow_and_occ_estimators(features_b)
146 | flow_est_b = flow_b + flow_res_b
147 | occ_est_b = occ_b + occ_res_b
148 |
149 | # Context:
150 | flow_cont_res_f, occ_cont_res_f = self.context_networks(torch.cat([x_intm_f, flow_est_f, occ_est_f], dim=1))
151 | flow_cont_f = flow_est_f + flow_cont_res_f
152 | occ_cont_f = occ_est_f + occ_cont_res_f
153 | with torch.no_grad():
154 | flow_cont_res_b, occ_cont_res_b = self.context_networks(torch.cat([x_intm_b, flow_est_b, occ_est_b], dim=1))
155 | flow_cont_b = flow_est_b + flow_cont_res_b
156 | occ_cont_b = occ_est_b + occ_cont_res_b
157 |
158 | # refinement
159 | img1_resize = upsample2d_as(list_imgs[i], flow_f, mode="bilinear")
160 | img2_resize = upsample2d_as(list_imgs[i+1], flow_b, mode="bilinear")
161 | img2_warp = self.warping_layer(img2_resize, rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)
162 | img1_warp = self.warping_layer(img1_resize, rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)
163 |
164 | # flow refine
165 | flow_f = self.refine_flow(flow_cont_f.detach(), img1_resize - img2_warp, x1_1by1)
166 | flow_b = self.refine_flow(flow_cont_b.detach(), img2_resize - img1_warp, x2_1by1)
167 |
168 | flow_cont_f = rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False)
169 | flow_cont_b = rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False)
170 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)
171 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)
172 |
173 | # occ refine
174 | x2_1by1_warp = self.warping_layer(x2_1by1, flow_f, height_im, width_im, self._div_flow)
175 | x1_1by1_warp = self.warping_layer(x1_1by1, flow_b, height_im, width_im, self._div_flow)
176 |
177 | occ_f = self.refine_occ(occ_cont_f.detach(), x1_1by1, x1_1by1 - x2_1by1_warp)
178 | occ_b = self.refine_occ(occ_cont_b.detach(), x2_1by1, x2_1by1 - x1_1by1_warp)
179 |
180 | # save features for temporal connection:
181 | previous_features[l] = self.conv_1x1_time(x_intm_f)
182 | flows_f[l].append(flow_f)
183 | occs_f[l].append(occ_f)
184 | flows_b[l].append(flow_b)
185 | occs_b[l].append(occ_b)
186 | flows_coarse_f[l].append(flow_cont_f)
187 | occs_coarse_f[l].append(occ_cont_f)
188 | #flows.append([flow_cont_f, flow_cont_b, flow_f, flow_b])
189 | #occs.append([occ_cont_f, occ_cont_b, occ_f, occ_b])
190 |
191 | else:
192 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear")
193 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear")
194 | flows_f[l].append(flow_f)
195 | flows_b[l].append(flow_b)
196 | #flows.append([flow_f, flow_b])
197 |
198 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)
199 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)
200 | flow_b_warp = self.warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow)
201 | flow_f_warp = self.warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow)
202 |
203 | if l != self.num_levels-1:
204 | x1_in = self.conv_1x1_1(x1)
205 | x2_in = self.conv_1x1_1(x2)
206 | x1_w_in = self.conv_1x1_1(x1_warp)
207 | x2_w_in = self.conv_1x1_1(x2_warp)
208 | else:
209 | x1_in = x1
210 | x2_in = x2
211 | x1_w_in = x1_warp
212 | x2_w_in = x2_warp
213 |
214 | occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([x1_in, x2_w_in, flow_f, flow_b_warp], dim=1))
215 | occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([x2_in, x1_w_in, flow_b, flow_f_warp], dim=1))
216 |
217 | occs_f[l].append(occ_f)
218 | occs_b[l].append(occ_b)
219 | #occs.append([occ_f, occ_b])
220 |
221 | flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
222 | flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
223 | occ_f = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
224 | occ_b = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
225 |
226 | if self.training:
227 | if len(list_imgs) > 2:
228 | for l in range(len(flows_f)):
229 | flows_f[l] = torch.stack(flows_f[l], 0)
230 | occs_f[l] = torch.stack(occs_f[l], 0)
231 | for l in range(len(flows_coarse_f)):
232 | flows_coarse_f[l] = torch.stack(flows_coarse_f[l], 0)
233 | occs_coarse_f[l] = torch.stack(occs_coarse_f[l], 0)
234 | else:
235 | for l in range(len(flows_f)):
236 | flows_f[l] = flows_f[l][0]
237 | occs_f[l] = occs_f[l][0]
238 | for l in range(len(flows_coarse_f)):
239 | flows_coarse_f[l] = flows_coarse_f[l][0]
240 | occs_coarse_f[l] = occs_coarse_f[l][0]
241 | output_dict['flow'] = flows_f
242 | output_dict['occ'] = occs_f
243 | output_dict['flow_coarse'] = flows_coarse_f
244 | output_dict['occ_coarse'] = occs_coarse_f
245 | return output_dict
246 | else:
247 | output_dict_eval = {}
248 | if len(list_imgs) > 2:
249 | out_flow = []
250 | out_occ = []
251 | for i in range(len(flows_f[0])):
252 | out_flow.append(upsample2d_as(flows_f[-1][i], list_imgs[0], mode="bilinear") * (1.0 / self._div_flow))
253 | out_occ.append(upsample2d_as(occs_f[-1][i], list_imgs[0], mode="bilinear"))
254 | out_flow = torch.stack(out_flow, 0)
255 | out_occ = torch.stack(out_occ, 0)
256 | else:
257 | out_flow = upsample2d_as(flows_f[-1][0], list_imgs[0], mode="bilinear") * (1.0 / self._div_flow)
258 | out_occ = upsample2d_as(occs_f[-1][0], list_imgs[0], mode="bilinear")
259 | output_dict_eval['flow'] = out_flow
260 | output_dict_eval['occ'] = out_occ
261 | return output_dict_eval
262 |
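263 | # ----------------------------------------------------------------------
264 | # Minimal inference sketch (mirroring inference.py; shapes are an example):
265 | # StarFlow accepts a list of T >= 2 frames under 'input_images' and, in
266 | # eval mode with T > 2, returns the T-1 forward flows stacked along a
267 | # leading time axis.
268 | #
269 | # net = StarFlow(args=None).cuda().eval()
270 | # frames = [torch.rand(1, 3, 384, 512).cuda() for _ in range(4)]
271 | # with torch.no_grad():
272 | #     out = net({'input_images': frames})
273 | # print(out['flow'].shape)  # (3, 1, 2, 384, 512) for four input frames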
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import pwcnet
2 | from . import pwcnet_irr
3 |
4 | from . import pwcnet_occ_joint
5 | from . import pwcnet_irr_occ_joint
6 |
7 | from . import tr_flow
8 | from . import tr_features
9 |
10 | from . import IRR_PWC
11 | from . import IRR_PWC_occ_joint
12 | from . import STAR
13 |
14 | PWCNet = pwcnet.PWCNet
15 | PWCNet_irr = pwcnet_irr.PWCNet
16 | PWCNet_occ_joint = pwcnet_occ_joint.PWCNet
17 | PWCNet_irr_occ_joint = pwcnet_irr_occ_joint.PWCNet
18 |
19 | TRFlow = tr_flow.TRFlow
20 | TRFlow_occjoint = tr_flow.TRFlow_occjoint
21 | TRFlow_irr = tr_flow.TRFlow_irr
22 | TRFlow_irr_occjoint = tr_flow.TRFlow_irr_occjoint
23 |
24 | TRFeat = tr_features.TRFeat
25 | TRFeat_occjoint = tr_features.TRFeat_occjoint
26 | TRFeat_irr_occjoint = tr_features.TRFeat_irr_occjoint
27 |
28 | # -- With refinement ---
29 |
30 | IRR_PWC = IRR_PWC.PWCNet
31 | IRR_occ_joint = IRR_PWC_occ_joint.PWCNet
32 |
33 | StarFlow = STAR.StarFlow
34 |
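35 | # These module-level aliases let callers resolve an architecture by name,
36 | # which is how inference.py selects its model:
37 | #
38 | # import models
39 | # MODEL = models.__dict__['StarFlow']  # same object as models.StarFlow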
--------------------------------------------------------------------------------
/models/correlation_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/models/correlation_package/__init__.py
--------------------------------------------------------------------------------
/models/correlation_package/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.module import Module
3 | from torch.autograd import Function
4 | import correlation_cuda
5 |
6 | class CorrelationFunction(Function):
7 |
8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
9 | super(CorrelationFunction, self).__init__()
10 | self.pad_size = pad_size
11 | self.kernel_size = kernel_size
12 | self.max_displacement = max_displacement
13 | self.stride1 = stride1
14 | self.stride2 = stride2
15 | self.corr_multiply = corr_multiply
16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
17 |
18 | def forward(self, input1, input2):
19 | self.save_for_backward(input1, input2)
20 |
21 | with torch.cuda.device_of(input1):
22 | rbot1 = input1.new()
23 | rbot2 = input2.new()
24 | output = input1.new()
25 |
26 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
27 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
28 |
29 | return output
30 |
31 | def backward(self, grad_output):
32 | input1, input2 = self.saved_tensors
33 |
34 | with torch.cuda.device_of(input1):
35 | rbot1 = input1.new()
36 | rbot2 = input2.new()
37 |
38 | grad_input1 = input1.new()
39 | grad_input2 = input2.new()
40 |
41 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
42 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
43 |
44 | return grad_input1, grad_input2
45 |
46 |
47 | class Correlation(Module):
48 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
49 | super(Correlation, self).__init__()
50 | self.pad_size = pad_size
51 | self.kernel_size = kernel_size
52 | self.max_displacement = max_displacement
53 | self.stride1 = stride1
54 | self.stride2 = stride2
55 | self.corr_multiply = corr_multiply
56 |
57 | def forward(self, input1, input2):
58 |
59 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)(input1, input2)
60 |
61 | return result
62 |
63 |
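64 | # ----------------------------------------------------------------------
65 | # Compatibility note: this "legacy" autograd style, where a stateful
66 | # Function instance is called directly, was deprecated in PyTorch 1.3 and
67 | # rejected from 1.5 on. A modern port (a sketch, not the original code)
68 | # would use static methods and a ctx object:
69 | #
70 | # class CorrelationFunction(torch.autograd.Function):
71 | #     @staticmethod
72 | #     def forward(ctx, input1, input2, pad_size, kernel_size,
73 | #                 max_displacement, stride1, stride2, corr_multiply):
74 | #         ctx.save_for_backward(input1, input2)
75 | #         ctx.params = (pad_size, kernel_size, max_displacement,
76 | #                       stride1, stride2, corr_multiply)
77 | #         ...  # call correlation_cuda.forward as above
78 | #
79 | # invoked as CorrelationFunction.apply(input1, input2, *params).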
--------------------------------------------------------------------------------
/models/correlation_package/correlation_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <ATen/ATen.h>
3 | #include <ATen/Context.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 | #include <stdio.h>
6 | #include <iostream>
7 |
8 | #include "correlation_cuda_kernel.cuh"
9 |
10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
11 | int pad_size,
12 | int kernel_size,
13 | int max_displacement,
14 | int stride1,
15 | int stride2,
16 | int corr_type_multiply)
17 | {
18 |
19 | int batchSize = input1.size(0);
20 |
21 | int nInputChannels = input1.size(1);
22 | int inputHeight = input1.size(2);
23 | int inputWidth = input1.size(3);
24 |
25 | int kernel_radius = (kernel_size - 1) / 2;
26 | int border_radius = kernel_radius + max_displacement;
27 |
28 | int paddedInputHeight = inputHeight + 2 * pad_size;
29 | int paddedInputWidth = inputWidth + 2 * pad_size;
30 |
31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
32 |
33 |   int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
34 |   int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
35 |
36 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
39 |
40 | rInput1.fill_(0);
41 | rInput2.fill_(0);
42 | output.fill_(0);
43 |
44 | int success = correlation_forward_cuda_kernel(
45 | output,
46 | output.size(0),
47 | output.size(1),
48 | output.size(2),
49 | output.size(3),
50 | output.stride(0),
51 | output.stride(1),
52 | output.stride(2),
53 | output.stride(3),
54 | input1,
55 | input1.size(1),
56 | input1.size(2),
57 | input1.size(3),
58 | input1.stride(0),
59 | input1.stride(1),
60 | input1.stride(2),
61 | input1.stride(3),
62 | input2,
63 | input2.size(1),
64 | input2.stride(0),
65 | input2.stride(1),
66 | input2.stride(2),
67 | input2.stride(3),
68 | rInput1,
69 | rInput2,
70 | pad_size,
71 | kernel_size,
72 | max_displacement,
73 | stride1,
74 | stride2,
75 | corr_type_multiply,
76 | at::cuda::getCurrentCUDAStream()
77 | //at::globalContext().getCurrentCUDAStream()
78 | );
79 |
80 | //check for errors
81 | if (!success) {
82 | AT_ERROR("CUDA call failed");
83 | }
84 |
85 | return 1;
86 |
87 | }
88 |
89 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
90 | at::Tensor& gradInput1, at::Tensor& gradInput2,
91 | int pad_size,
92 | int kernel_size,
93 | int max_displacement,
94 | int stride1,
95 | int stride2,
96 | int corr_type_multiply)
97 | {
98 |
99 | int batchSize = input1.size(0);
100 | int nInputChannels = input1.size(1);
101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size;
102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size;
103 |
104 | int height = input1.size(2);
105 | int width = input1.size(3);
106 |
107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
109 | gradInput1.resize_({batchSize, nInputChannels, height, width});
110 | gradInput2.resize_({batchSize, nInputChannels, height, width});
111 |
112 | rInput1.fill_(0);
113 | rInput2.fill_(0);
114 | gradInput1.fill_(0);
115 | gradInput2.fill_(0);
116 |
117 | int success = correlation_backward_cuda_kernel(gradOutput,
118 | gradOutput.size(0),
119 | gradOutput.size(1),
120 | gradOutput.size(2),
121 | gradOutput.size(3),
122 | gradOutput.stride(0),
123 | gradOutput.stride(1),
124 | gradOutput.stride(2),
125 | gradOutput.stride(3),
126 | input1,
127 | input1.size(1),
128 | input1.size(2),
129 | input1.size(3),
130 | input1.stride(0),
131 | input1.stride(1),
132 | input1.stride(2),
133 | input1.stride(3),
134 | input2,
135 | input2.stride(0),
136 | input2.stride(1),
137 | input2.stride(2),
138 | input2.stride(3),
139 | gradInput1,
140 | gradInput1.stride(0),
141 | gradInput1.stride(1),
142 | gradInput1.stride(2),
143 | gradInput1.stride(3),
144 | gradInput2,
145 | gradInput2.size(1),
146 | gradInput2.stride(0),
147 | gradInput2.stride(1),
148 | gradInput2.stride(2),
149 | gradInput2.stride(3),
150 | rInput1,
151 | rInput2,
152 | pad_size,
153 | kernel_size,
154 | max_displacement,
155 | stride1,
156 | stride2,
157 | corr_type_multiply,
158 | at::cuda::getCurrentCUDAStream()
159 | //at::globalContext().getCurrentCUDAStream()
160 | );
161 |
162 | if (!success) {
163 | AT_ERROR("CUDA call failed");
164 | }
165 |
166 | return 1;
167 | }
168 |
169 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
170 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
171 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
172 | }
173 |
174 |
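// Shape sketch (illustrative values, matching how the models call this op):
// with pad_size=4, kernel_size=1, max_displacement=4, stride1=1 and stride2=1,
// nOutputChannels = ((4/1)*2 + 1)^2 = 81, and the border/padding terms cancel,
// so a pair of (B, C, H, W) inputs yields a (B, 81, H, W) cost volume.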
--------------------------------------------------------------------------------
/models/correlation_package/correlation_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | #include "correlation_cuda_kernel.cuh"
4 |
5 | #define CUDA_NUM_THREADS 1024
6 | #define THREADS_PER_BLOCK 32
7 |
8 | #include <ATen/ATen.h>
9 | #include <ATen/NativeFunctions.h>
10 | #include <ATen/Dispatch.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 |
13 | using at::Half;
14 |
15 | template <typename scalar_t>
16 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
17 | {
18 |
19 | // n (batch size), c (num of channels), y (height), x (width)
20 | int n = blockIdx.x;
21 | int y = blockIdx.y;
22 | int x = blockIdx.z;
23 |
24 | int ch_off = threadIdx.x;
25 | scalar_t value;
26 |
27 | int dimcyx = channels * height * width;
28 | int dimyx = height * width;
29 |
30 | int p_dimx = (width + 2 * pad_size);
31 | int p_dimy = (height + 2 * pad_size);
32 | int p_dimyxc = channels * p_dimy * p_dimx;
33 | int p_dimxc = p_dimx * channels;
34 |
35 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
36 | value = input[n * dimcyx + c * dimyx + y * width + x];
37 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
38 | }
39 | }
40 |
41 | template <typename scalar_t>
42 | __global__ void correlation_forward(scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth,
43 | const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth,
44 | const scalar_t* __restrict__ rInput2,
45 | int pad_size,
46 | int kernel_size,
47 | int max_displacement,
48 | int stride1,
49 | int stride2)
50 | {
51 | // n (batch size), c (num of channels), y (height), x (width)
52 |
53 | int pInputWidth = inputWidth + 2 * pad_size;
54 | int pInputHeight = inputHeight + 2 * pad_size;
55 |
56 | int kernel_rad = (kernel_size - 1) / 2;
57 | int displacement_rad = max_displacement / stride2;
58 | int displacement_size = 2 * displacement_rad + 1;
59 |
60 | int n = blockIdx.x;
61 | int y1 = blockIdx.y * stride1 + max_displacement;
62 | int x1 = blockIdx.z * stride1 + max_displacement;
63 | int c = threadIdx.x;
64 |
65 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
66 | int pdimxc = pInputWidth * nInputChannels;
67 | int pdimc = nInputChannels;
68 |
69 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
70 | int tdimyx = outputHeight * outputWidth;
71 | int tdimx = outputWidth;
72 |
73 | scalar_t nelems = kernel_size * kernel_size * pdimc;
74 |
75 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
76 |
77 | // no significant speed-up from using on-chip memory for the input1 sub-data,
78 | // and not enough on-chip memory to accommodate the per-block input2 sub-data,
79 | // so device memory is used for both
80 |
81 | // element-wise product along channel axis
82 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
83 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
84 | prod_sum[c] = 0;
85 | int x2 = x1 + ti*stride2;
86 | int y2 = y1 + tj*stride2;
87 |
88 | for (int j = -kernel_rad; j <= kernel_rad; ++j) {
89 | for (int i = -kernel_rad; i <= kernel_rad; ++i) {
90 | for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {
91 | int indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch;
92 | int indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch;
93 |
94 | prod_sum[c] += rInput1[indx1] * rInput2[indx2];
95 | }
96 | }
97 | }
98 |
99 | // accumulate
100 | __syncthreads();
101 | if (c == 0) {
102 | scalar_t reduce_sum = 0;
103 | for (int index = 0; index < THREADS_PER_BLOCK; ++index) {
104 | reduce_sum += prod_sum[index];
105 | }
106 | int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
107 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;
108 | output[tindx] = reduce_sum / nelems;
109 | }
110 |
111 | }
112 | }
113 |
114 | }
115 |
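// Launch-mapping sketch for the kernel above: the forward pass runs one block
// per (batch, output_y, output_x) position; the THREADS_PER_BLOCK threads of a
// block stride over the channel axis, and thread 0 reduces the partial products
// in the shared prod_sum buffer to one correlation value per displacement (tj, ti).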
116 | template <typename scalar_t>
117 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,
118 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
119 | const scalar_t* __restrict__ rInput2,
120 | int pad_size,
121 | int kernel_size,
122 | int max_displacement,
123 | int stride1,
124 | int stride2)
125 | {
126 | // n (batch size), c (num of channels), y (height), x (width)
127 |
128 | int n = item;
129 | int y = blockIdx.x * stride1 + pad_size;
130 | int x = blockIdx.y * stride1 + pad_size;
131 | int c = blockIdx.z;
132 | int tch_off = threadIdx.x;
133 |
134 | int kernel_rad = (kernel_size - 1) / 2;
135 | int displacement_rad = max_displacement / stride2;
136 | int displacement_size = 2 * displacement_rad + 1;
137 |
138 | int xmin = (x - kernel_rad - max_displacement) / stride1;
139 | int ymin = (y - kernel_rad - max_displacement) / stride1;
140 |
141 | int xmax = (x + kernel_rad - max_displacement) / stride1;
142 | int ymax = (y + kernel_rad - max_displacement) / stride1;
143 |
144 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
145 | // assumes gradInput1 is pre-allocated and zero filled
146 | return;
147 | }
148 |
149 | if (xmin > xmax || ymin > ymax) {
150 | // assumes gradInput1 is pre-allocated and zero filled
151 | return;
152 | }
153 |
154 | xmin = max(0, xmin);
155 | xmax = min(outputWidth - 1, xmax);
156 |
157 | ymin = max(0, ymin);
158 | ymax = min(outputHeight - 1, ymax);
159 |
160 | int pInputWidth = inputWidth + 2 * pad_size;
161 | int pInputHeight = inputHeight + 2 * pad_size;
162 |
163 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
164 | int pdimxc = pInputWidth * nInputChannels;
165 | int pdimc = nInputChannels;
166 |
167 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
168 | int tdimyx = outputHeight * outputWidth;
169 | int tdimx = outputWidth;
170 |
171 | int odimcyx = nInputChannels * inputHeight* inputWidth;
172 | int odimyx = inputHeight * inputWidth;
173 | int odimx = inputWidth;
174 |
175 | scalar_t nelems = kernel_size * kernel_size * nInputChannels;
176 |
177 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
178 | prod_sum[tch_off] = 0;
179 |
180 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
181 |
182 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
183 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
184 |
185 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;
186 |
187 | scalar_t val2 = rInput2[indx2];
188 |
189 | for (int j = ymin; j <= ymax; ++j) {
190 | for (int i = xmin; i <= xmax; ++i) {
191 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
192 | prod_sum[tch_off] += gradOutput[tindx] * val2;
193 | }
194 | }
195 | }
196 | __syncthreads();
197 |
198 | if (tch_off == 0) {
199 | scalar_t reduce_sum = 0;
200 | for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
201 | reduce_sum += prod_sum[idx];
202 | }
203 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
204 | gradInput1[indx1] = reduce_sum / nelems;
205 | }
206 |
207 | }
208 |
209 | template <typename scalar_t>
210 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth,
211 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
212 | const scalar_t* __restrict__ rInput1,
213 | int pad_size,
214 | int kernel_size,
215 | int max_displacement,
216 | int stride1,
217 | int stride2)
218 | {
219 | // n (batch size), c (num of channels), y (height), x (width)
220 |
221 | int n = item;
222 | int y = blockIdx.x * stride1 + pad_size;
223 | int x = blockIdx.y * stride1 + pad_size;
224 | int c = blockIdx.z;
225 |
226 | int tch_off = threadIdx.x;
227 |
228 | int kernel_rad = (kernel_size - 1) / 2;
229 | int displacement_rad = max_displacement / stride2;
230 | int displacement_size = 2 * displacement_rad + 1;
231 |
232 | int pInputWidth = inputWidth + 2 * pad_size;
233 | int pInputHeight = inputHeight + 2 * pad_size;
234 |
235 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
236 | int pdimxc = pInputWidth * nInputChannels;
237 | int pdimc = nInputChannels;
238 |
239 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
240 | int tdimyx = outputHeight * outputWidth;
241 | int tdimx = outputWidth;
242 |
243 | int odimcyx = nInputChannels * inputHeight* inputWidth;
244 | int odimyx = inputHeight * inputWidth;
245 | int odimx = inputWidth;
246 |
247 | scalar_t nelems = kernel_size * kernel_size * nInputChannels;
248 |
249 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
250 | prod_sum[tch_off] = 0;
251 |
252 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
253 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
254 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
255 |
256 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
257 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1;
258 |
259 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
260 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1;
261 |
262 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
263 | // assumes gradInput2 is pre-allocated and zero filled
264 | continue;
265 | }
266 |
267 | if (xmin > xmax || ymin > ymax) {
268 | // assumes gradInput2 is pre-allocated and zero filled
269 | continue;
270 | }
271 |
272 | xmin = max(0, xmin);
273 | xmax = min(outputWidth - 1, xmax);
274 |
275 | ymin = max(0, ymin);
276 | ymax = min(outputHeight - 1, ymax);
277 |
278 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
279 | scalar_t val1 = rInput1[indx1];
280 |
281 | for (int j = ymin; j <= ymax; ++j) {
282 | for (int i = xmin; i <= xmax; ++i) {
283 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
284 | prod_sum[tch_off] += gradOutput[tindx] * val1;
285 | }
286 | }
287 | }
288 |
289 | __syncthreads();
290 |
291 | if (tch_off == 0) {
292 | scalar_t reduce_sum = 0;
293 | for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
294 | reduce_sum += prod_sum[idx];
295 | }
296 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
297 | gradInput2[indx2] = reduce_sum / nelems;
298 | }
299 |
300 | }
301 |
302 | int correlation_forward_cuda_kernel(at::Tensor& output,
303 | int ob,
304 | int oc,
305 | int oh,
306 | int ow,
307 | int osb,
308 | int osc,
309 | int osh,
310 | int osw,
311 |
312 | at::Tensor& input1,
313 | int ic,
314 | int ih,
315 | int iw,
316 | int isb,
317 | int isc,
318 | int ish,
319 | int isw,
320 |
321 | at::Tensor& input2,
322 | int gc,
323 | int gsb,
324 | int gsc,
325 | int gsh,
326 | int gsw,
327 |
328 | at::Tensor& rInput1,
329 | at::Tensor& rInput2,
330 | int pad_size,
331 | int kernel_size,
332 | int max_displacement,
333 | int stride1,
334 | int stride2,
335 | int corr_type_multiply,
336 | cudaStream_t stream)
337 | {
338 |
339 | int batchSize = ob;
340 |
341 | int nInputChannels = ic;
342 | int inputWidth = iw;
343 | int inputHeight = ih;
344 |
345 | int nOutputChannels = oc;
346 | int outputWidth = ow;
347 | int outputHeight = oh;
348 |
349 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
350 | dim3 threads_block(THREADS_PER_BLOCK);
351 |
352 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
353 |
354 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
355 | input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
356 |
357 | }));
358 |
359 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
360 |
361 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
362 | input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
363 |
364 | }));
365 |
366 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
367 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
368 |
369 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
370 |
371 | correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>
372 | (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
373 | rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
374 | rInput2.data<scalar_t>(),
375 | pad_size,
376 | kernel_size,
377 | max_displacement,
378 | stride1,
379 | stride2);
380 |
381 | }));
382 |
383 | cudaError_t err = cudaGetLastError();
384 |
385 |
386 | // check for errors
387 | if (err != cudaSuccess) {
388 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
389 | return 0;
390 | }
391 |
392 | return 1;
393 | }
394 |
395 |
396 | int correlation_backward_cuda_kernel(
397 | at::Tensor& gradOutput,
398 | int gob,
399 | int goc,
400 | int goh,
401 | int gow,
402 | int gosb,
403 | int gosc,
404 | int gosh,
405 | int gosw,
406 |
407 | at::Tensor& input1,
408 | int ic,
409 | int ih,
410 | int iw,
411 | int isb,
412 | int isc,
413 | int ish,
414 | int isw,
415 |
416 | at::Tensor& input2,
417 | int gsb,
418 | int gsc,
419 | int gsh,
420 | int gsw,
421 |
422 | at::Tensor& gradInput1,
423 | int gisb,
424 | int gisc,
425 | int gish,
426 | int gisw,
427 |
428 | at::Tensor& gradInput2,
429 | int ggc,
430 | int ggsb,
431 | int ggsc,
432 | int ggsh,
433 | int ggsw,
434 |
435 | at::Tensor& rInput1,
436 | at::Tensor& rInput2,
437 | int pad_size,
438 | int kernel_size,
439 | int max_displacement,
440 | int stride1,
441 | int stride2,
442 | int corr_type_multiply,
443 | cudaStream_t stream)
444 | {
445 |
446 | int batchSize = gob;
447 | int num = batchSize;
448 |
449 | int nInputChannels = ic;
450 | int inputWidth = iw;
451 | int inputHeight = ih;
452 |
453 | int nOutputChannels = goc;
454 | int outputWidth = gow;
455 | int outputHeight = goh;
456 |
457 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
458 | dim3 threads_block(THREADS_PER_BLOCK);
459 |
460 |
461 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] {
462 |
463 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
464 | input1.data<scalar_t>(),
465 | rInput1.data<scalar_t>(),
466 | nInputChannels,
467 | inputHeight,
468 | inputWidth,
469 | pad_size
470 | );
471 | }));
472 |
473 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
474 |
475 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
476 | input2.data<scalar_t>(),
477 | rInput2.data<scalar_t>(),
478 | nInputChannels,
479 | inputHeight,
480 | inputWidth,
481 | pad_size
482 | );
483 | }));
484 |
485 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
486 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
487 |
488 | for (int n = 0; n < num; ++n) {
489 |
490 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
491 |
492 |
493 | correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
494 | n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
495 | gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
496 | rInput2.data<scalar_t>(),
497 | pad_size,
498 | kernel_size,
499 | max_displacement,
500 | stride1,
501 | stride2);
502 | }));
503 | }
504 |
505 | for (int n = 0; n < batchSize; n++) {
506 |
507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] {
508 |
509 | correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
510 | n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
511 | gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
512 | rInput1.data<scalar_t>(),
513 | pad_size,
514 | kernel_size,
515 | max_displacement,
516 | stride1,
517 | stride2);
518 |
519 | }));
520 | }
521 |
522 | // check for errors
523 | cudaError_t err = cudaGetLastError();
524 | if (err != cudaSuccess) {
525 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
526 | return 0;
527 | }
528 |
529 | return 1;
530 | }
531 |
--------------------------------------------------------------------------------
/models/correlation_package/correlation_cuda_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 | #include <ATen/Context.h>
5 | #include <cuda_runtime.h>
6 |
7 | int correlation_forward_cuda_kernel(at::Tensor& output,
8 | int ob,
9 | int oc,
10 | int oh,
11 | int ow,
12 | int osb,
13 | int osc,
14 | int osh,
15 | int osw,
16 |
17 | at::Tensor& input1,
18 | int ic,
19 | int ih,
20 | int iw,
21 | int isb,
22 | int isc,
23 | int ish,
24 | int isw,
25 |
26 | at::Tensor& input2,
27 | int gc,
28 | int gsb,
29 | int gsc,
30 | int gsh,
31 | int gsw,
32 |
33 | at::Tensor& rInput1,
34 | at::Tensor& rInput2,
35 | int pad_size,
36 | int kernel_size,
37 | int max_displacement,
38 | int stride1,
39 | int stride2,
40 | int corr_type_multiply,
41 | cudaStream_t stream);
42 |
43 |
44 | int correlation_backward_cuda_kernel(
45 | at::Tensor& gradOutput,
46 | int gob,
47 | int goc,
48 | int goh,
49 | int gow,
50 | int gosb,
51 | int gosc,
52 | int gosh,
53 | int gosw,
54 |
55 | at::Tensor& input1,
56 | int ic,
57 | int ih,
58 | int iw,
59 | int isb,
60 | int isc,
61 | int ish,
62 | int isw,
63 |
64 | at::Tensor& input2,
65 | int gsb,
66 | int gsc,
67 | int gsh,
68 | int gsw,
69 |
70 | at::Tensor& gradInput1,
71 | int gisb,
72 | int gisc,
73 | int gish,
74 | int gisw,
75 |
76 | at::Tensor& gradInput2,
77 | int ggc,
78 | int ggsb,
79 | int ggsc,
80 | int ggsh,
81 | int ggsw,
82 |
83 | at::Tensor& rInput1,
84 | at::Tensor& rInput2,
85 | int pad_size,
86 | int kernel_size,
87 | int max_displacement,
88 | int stride1,
89 | int stride2,
90 | int corr_type_multiply,
91 | cudaStream_t stream);
92 |
--------------------------------------------------------------------------------
/models/correlation_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 |
4 | os.environ["CC"] = "gcc"
5 | os.environ["CXX"] = "gcc"
6 |
7 | import torch
8 |
9 | from setuptools import setup, find_packages
10 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
11 |
12 | cxx_args = ['-std=c++11',
13 | '-D_GLIBCXX_USE_CXX11_ABI=1',
14 | ]
15 |
16 | nvcc_args = [
17 | '-gencode', 'arch=compute_50,code=sm_50',
18 | '-gencode', 'arch=compute_52,code=sm_52',
19 | '-gencode', 'arch=compute_60,code=sm_60',
20 | '-gencode', 'arch=compute_61,code=sm_61',
21 | '-gencode', 'arch=compute_70,code=sm_70',
22 | '-gencode', 'arch=compute_75,code=sm_75',
23 | '-gencode', 'arch=compute_75,code=compute_75',
24 | '-ccbin', '/usr/bin/gcc'
25 | ]
26 |
27 | # '-ccbin', '/usr/bin/gcc-5'
28 |
29 | #nvcc_args = [
30 | # '-gencode', 'arch=compute_50,code=sm_50',
31 | # '-gencode', 'arch=compute_52,code=sm_52',
32 | # '-gencode', 'arch=compute_60,code=sm_60',
33 | # '-gencode', 'arch=compute_61,code=sm_61',
34 | # '-gencode', 'arch=compute_70,code=sm_70',
35 | # '-gencode', 'arch=compute_70,code=compute_70'
36 | #]
37 |
38 | setup(
39 | name='correlation_cuda',
40 | ext_modules=[
41 | CUDAExtension('correlation_cuda', [
42 | 'correlation_cuda.cc',
43 | 'correlation_cuda_kernel.cu'
44 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args, 'cuda-path': ['/usr/local/cuda']})
45 | ],
46 | cmdclass={
47 | 'build_ext': BuildExtension
48 | })
49 |
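# Build sketch (assumes a CUDA toolkit compatible with the installed PyTorch;
# the gencode list above must cover your GPU architecture):
#   cd models/correlation_package
#   python setup.py install
# On success the extension is importable as `import correlation_cuda`.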
--------------------------------------------------------------------------------
/models/irr_modules.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as tf
6 |
7 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
8 | if isReLU:
9 | return nn.Sequential(
10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
11 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
12 | nn.LeakyReLU(0.1, inplace=True)
13 | )
14 | else:
15 | return nn.Sequential(
16 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
17 | padding=((kernel_size - 1) * dilation) // 2, bias=True)
18 | )
19 |
20 |
21 | def upsample_factor2(inputs, target_as):
22 | inputs = tf.interpolate(inputs, scale_factor=2, mode="nearest")
23 | _, _, h, w = target_as.size()
24 | if inputs.size(2) != h or inputs.size(3) != w:
25 | return tf.interpolate(inputs, [h, w], mode="bilinear", align_corners=False)
26 | else:
27 | return inputs
28 |
29 |
30 | class OccUpsampleNetwork(nn.Module):
31 | def __init__(self, ch_in, ch_out):
32 | super(OccUpsampleNetwork, self).__init__()
33 |
34 | self.feat_dim = 32
35 | self.init_conv = conv(ch_in, self.feat_dim)
36 |
37 | self.res_convs = nn.Sequential(
38 | conv(self.feat_dim, self.feat_dim),
39 | conv(self.feat_dim, self.feat_dim, isReLU=False)
40 | )
41 | self.res_end_conv = conv(self.feat_dim, self.feat_dim)
42 | self.mul_const = 0.1
43 |
44 | self.out_convs = conv(self.feat_dim, ch_out)
45 |
46 | def forward(self, occ, x):
47 | occ = upsample_factor2(occ, x)
48 | x_in = torch.cat([occ, x], dim=1)
49 | x_init = self.init_conv(x_in)
50 | x_res = x_init
51 | x_res = x_res + self.res_convs(x_res) * self.mul_const
52 | x_res = x_res + self.res_convs(x_res) * self.mul_const
53 | x_res = x_res + self.res_convs(x_res) * self.mul_const
54 | x_init = x_init + self.res_end_conv(x_res)
55 |
56 | return self.out_convs(x_init) + occ
57 |
58 |
59 | def subtract_mean(input):
60 | return input - input.mean(2).mean(2).unsqueeze(2).unsqueeze(2).expand_as(input)
61 |
62 |
63 | class RefineFlow(nn.Module):
64 | def __init__(self, ch_in):
65 | super(RefineFlow, self).__init__()
66 |
67 | self.kernel_size = 3
68 | self.pad_size = 1
69 | self.pad_ftn = nn.ReplicationPad2d(self.pad_size)
70 |
71 | self.convs = nn.Sequential(
72 | conv(ch_in, 128, 3, 1, 1),
73 | conv(128, 128, 3, 1, 1),
74 | conv(128, 64, 3, 1, 1),
75 | conv(64, 64, 3, 1, 1),
76 | conv(64, 32, 3, 1, 1),
77 | conv(32, 32, 3, 1, 1),
78 | conv(32, self.kernel_size * self.kernel_size, 3, 1, 1)
79 | )
80 |
81 | self.softmax_feat = nn.Softmax(dim=1)
82 | self.unfold_flow = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size))
83 | self.unfold_kernel = nn.Unfold(kernel_size=(1, 1))
84 |
85 | def forward(self, flow, diff_img, feature):
86 | b, _, h, w = flow.size()
87 |
88 | flow_m = subtract_mean(flow)
89 | norm2_img = torch.norm(diff_img, p=2, dim=1, keepdim=True)
90 |
91 | feat = self.convs(torch.cat([flow_m, norm2_img, feature], dim=1))
92 | feat_kernel = self.softmax_feat(-feat ** 2)
93 |
94 | flow_x = flow[:, 0].unsqueeze(1)
95 | flow_y = flow[:, 1].unsqueeze(1)
96 |
97 | flow_x_unfold = self.unfold_flow(self.pad_ftn(flow_x))
98 | flow_y_unfold = self.unfold_flow(self.pad_ftn(flow_y))
99 | feat_kernel_unfold = self.unfold_kernel(feat_kernel)
100 |
101 | flow_out_x = torch.sum(flow_x_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w)
102 | flow_out_y = torch.sum(flow_y_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w)
103 |
104 | return torch.cat([flow_out_x, flow_out_y], dim=1)
105 |
106 |
107 | class RefineOcc(nn.Module):
108 | def __init__(self, ch_in):
109 | super(RefineOcc, self).__init__()
110 |
111 | self.kernel_size = 3
112 | self.pad_size = 1
113 | self.pad_ftn = nn.ReplicationPad2d(self.pad_size)
114 |
115 | self.convs = nn.Sequential(
116 | conv(ch_in, 128, 3, 1, 1),
117 | conv(128, 128, 3, 1, 1),
118 | conv(128, 64, 3, 1, 1),
119 | conv(64, 64, 3, 1, 1),
120 | conv(64, 32, 3, 1, 1),
121 | conv(32, 32, 3, 1, 1),
122 | conv(32, self.kernel_size * self.kernel_size, 3, 1, 1)
123 | )
124 |
125 | self.softmax_feat = nn.Softmax(dim=1)
126 | self.unfold_occ = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size))
127 | self.unfold_kernel = nn.Unfold(kernel_size=(1, 1))
128 |
129 | def forward(self, occ, feat1, feat2):
130 | b, _, h, w = occ.size()
131 |
132 | feat = self.convs(torch.cat([occ, feat1, feat2], dim=1))
133 | feat_kernel = self.softmax_feat(-feat ** 2)
134 |
135 | occ_unfold = self.unfold_occ(self.pad_ftn(occ))
136 | feat_kernel_unfold = self.unfold_kernel(feat_kernel)
137 |
138 | occ_out = torch.sum(occ_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w)
139 |
140 | return occ_out
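# Usage sketch (hypothetical channel counts): RefineFlow predicts a 3x3 kernel
# per pixel, turns it into weights with softmax(-feat**2), and replaces each
# flow value by the weighted sum of its 3x3 neighbourhood:
#
#   refine = RefineFlow(ch_in=2 + 1 + 32)        # flow + |diff_img| + features
#   flow_refined = refine(flow, diff_img, feat)  # same (B, 2, H, W) shape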
--------------------------------------------------------------------------------
/models/pwc_modules.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as tf
6 | import logging
7 |
8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
9 | if isReLU:
10 | return nn.Sequential(
11 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
12 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
13 | nn.LeakyReLU(0.1, inplace=True)
14 | )
15 | else:
16 | return nn.Sequential(
17 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
18 | padding=((kernel_size - 1) * dilation) // 2, bias=True)
19 | )
20 |
21 |
22 | def initialize_msra(modules):
23 | logging.info("Initializing MSRA")
24 | for layer in modules:
25 | if isinstance(layer, nn.Conv2d):
26 | nn.init.kaiming_normal_(layer.weight)
27 | if layer.bias is not None:
28 | nn.init.constant_(layer.bias, 0)
29 |
30 | elif isinstance(layer, nn.ConvTranspose2d):
31 | nn.init.kaiming_normal_(layer.weight)
32 | if layer.bias is not None:
33 | nn.init.constant_(layer.bias, 0)
34 |
35 | elif isinstance(layer, nn.LeakyReLU):
36 | pass
37 |
38 | elif isinstance(layer, nn.Sequential):
39 | pass
40 |
41 |
42 | def upsample2d_as(inputs, target_as, mode="bilinear"):
43 | _, _, h, w = target_as.size()
44 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
45 |
46 |
47 | def rescale_flow(flow, div_flow, width_im, height_im, to_local=True):
48 | if to_local:
49 | u_scale = float(flow.size(3) / width_im / div_flow)
50 | v_scale = float(flow.size(2) / height_im / div_flow)
51 | else:
52 | u_scale = float(width_im * div_flow / flow.size(3))
53 | v_scale = float(height_im * div_flow / flow.size(2))
54 |
55 | u, v = flow.chunk(2, dim=1)
56 | u *= u_scale
57 | v *= v_scale
58 |
59 | return torch.cat([u, v], dim=1)
60 |
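# Worked example (illustrative numbers): with width_im=1024, div_flow=0.05 and
# a pyramid level where flow.size(3) == 256, to_local=True scales u by
# 256 / 1024 / 0.05 = 5.0; to_local=False applies the reciprocal factor.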
61 |
62 | class FeatureExtractor(nn.Module):
63 | def __init__(self, num_chs):
64 | super(FeatureExtractor, self).__init__()
65 | self.num_chs = num_chs
66 | self.convs = nn.ModuleList()
67 |
68 | for l, (ch_in, ch_out) in enumerate(zip(num_chs[:-1], num_chs[1:])):
69 | layer = nn.Sequential(
70 | conv(ch_in, ch_out, stride=2),
71 | conv(ch_out, ch_out)
72 | )
73 | self.convs.append(layer)
74 |
75 | def forward(self, x):
76 | feature_pyramid = []
77 | for conv in self.convs:
78 | x = conv(x)
79 | feature_pyramid.append(x)
80 |
81 | return feature_pyramid[::-1]
82 |
83 |
84 | def get_grid(x):
85 | grid_H = torch.linspace(-1.0, 1.0, x.size(3)).view(1, 1, 1, x.size(3)).expand(x.size(0), 1, x.size(2), x.size(3))
86 | grid_V = torch.linspace(-1.0, 1.0, x.size(2)).view(1, 1, x.size(2), 1).expand(x.size(0), 1, x.size(2), x.size(3))
87 | grid = torch.cat([grid_H, grid_V], 1)
88 | grids_cuda = grid.float().requires_grad_(False)
89 | if x.is_cuda:
90 | grids_cuda = grids_cuda.cuda()
91 | return grids_cuda
92 |
93 |
94 | class WarpingLayer(nn.Module):
95 | def __init__(self):
96 | super(WarpingLayer, self).__init__()
97 |
98 | def forward(self, x, flow, height_im, width_im, div_flow):
99 | flo_list = []
100 | flo_w = flow[:, 0] * 2 / max(width_im - 1, 1) / div_flow
101 | flo_h = flow[:, 1] * 2 / max(height_im - 1, 1) / div_flow
102 | flo_list.append(flo_w)
103 | flo_list.append(flo_h)
104 | flow_for_grid = torch.stack(flo_list).transpose(0, 1)
105 | grid = torch.add(get_grid(x), flow_for_grid).transpose(1, 2).transpose(2, 3)
106 | x_warp = tf.grid_sample(x, grid)
107 |
108 | mask = torch.ones(x.size(), requires_grad=False)
109 | if x.is_cuda:
110 | mask = mask.cuda()
111 | mask = tf.grid_sample(mask, grid)
112 | mask = (mask >= 1.0).float()
113 |
114 | return x_warp * mask
115 |
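# Warping-convention sketch: the flow (stored scaled by div_flow) is converted
# to grid_sample's normalized [-1, 1] coordinates via the 2 / (size - 1) and
# 1 / div_flow factors above, and out-of-range samples are zeroed by the
# warped all-ones mask.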
116 | class OpticalFlowEstimator(nn.Module):
117 | def __init__(self, ch_in):
118 | super(OpticalFlowEstimator, self).__init__()
119 |
120 | self.convs = nn.Sequential(
121 | conv(ch_in, 128),
122 | conv(128, 128),
123 | conv(128, 96),
124 | conv(96, 64),
125 | conv(64, 32)
126 | )
127 | self.conv_last = conv(32, 2, isReLU=False)
128 |
129 | def forward(self, x):
130 | x_intm = self.convs(x)
131 | return x_intm, self.conv_last(x_intm)
132 |
133 |
134 | class FlowEstimatorDense(nn.Module):
135 | def __init__(self, ch_in):
136 | super(FlowEstimatorDense, self).__init__()
137 | self.conv1 = conv(ch_in, 128)
138 | self.conv2 = conv(ch_in + 128, 128)
139 | self.conv3 = conv(ch_in + 256, 96)
140 | self.conv4 = conv(ch_in + 352, 64)
141 | self.conv5 = conv(ch_in + 416, 32)
142 | self.conv_last = conv(ch_in + 448, 2, isReLU=False)
143 |
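    # Channel bookkeeping for the dense connections below: each conv output is
    # concatenated with its input, so widths grow as ch_in+128, ch_in+256,
    # ch_in+352, ch_in+416 and ch_in+448, which is why conv_last expects
    # ch_in + 448 input channels.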
144 | def forward(self, x):
145 | x1 = torch.cat([self.conv1(x), x], dim=1)
146 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
147 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
148 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
149 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
150 | x_out = self.conv_last(x5)
151 | return x5, x_out
152 |
153 | class OcclusionEstimator(nn.Module):
154 | def __init__(self, ch_in):
155 | super(OcclusionEstimator, self).__init__()
156 | self.convs = nn.Sequential(
157 | conv(ch_in, 128),
158 | conv(128, 128),
159 | conv(128, 96),
160 | conv(96, 64),
161 | conv(64, 32)
162 | )
163 | self.conv_last = conv(32, 1, isReLU=False)
164 |
165 | def forward(self, x):
166 | x_intm = self.convs(x)
167 | return x_intm, self.conv_last(x_intm)
168 |
169 |
170 | class OccEstimatorDense(nn.Module):
171 | def __init__(self, ch_in):
172 | super(OccEstimatorDense, self).__init__()
173 | self.conv1 = conv(ch_in, 128)
174 | self.conv2 = conv(ch_in + 128, 128)
175 | self.conv3 = conv(ch_in + 256, 96)
176 | self.conv4 = conv(ch_in + 352, 64)
177 | self.conv5 = conv(ch_in + 416, 32)
178 | self.conv_last = conv(ch_in + 448, 1, isReLU=False)
179 |
180 | def forward(self, x):
181 | x1 = torch.cat([self.conv1(x), x], dim=1)
182 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
183 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
184 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
185 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
186 | x_out = self.conv_last(x5)
187 | return x5, x_out
188 |
189 |
190 | class ContextNetwork(nn.Module):
191 | def __init__(self, ch_in):
192 | super(ContextNetwork, self).__init__()
193 |
194 | self.convs = nn.Sequential(
195 | conv(ch_in, 128, 3, 1, 1),
196 | conv(128, 128, 3, 1, 2),
197 | conv(128, 128, 3, 1, 4),
198 | conv(128, 96, 3, 1, 8),
199 | conv(96, 64, 3, 1, 16),
200 | conv(64, 32, 3, 1, 1),
201 | conv(32, 2, isReLU=False)
202 | )
203 |
204 | def forward(self, x):
205 | return self.convs(x)
206 |
207 |
208 | class OccContextNetwork(nn.Module):
209 | def __init__(self, ch_in):
210 | super(OccContextNetwork, self).__init__()
211 |
212 | self.convs = nn.Sequential(
213 | conv(ch_in, 128, 3, 1, 1),
214 | conv(128, 128, 3, 1, 2),
215 | conv(128, 128, 3, 1, 4),
216 | conv(128, 96, 3, 1, 8),
217 | conv(96, 64, 3, 1, 16),
218 | conv(64, 32, 3, 1, 1),
219 | conv(32, 1, isReLU=False)
220 | )
221 |
222 | def forward(self, x):
223 | return self.convs(x)
224 |
225 | # -------------------------------------------
226 |
227 | class FlowAndOccEstimatorDense(nn.Module):
228 | def __init__(self, ch_in):
229 | super(FlowAndOccEstimatorDense, self).__init__()
230 | self.conv1 = conv(ch_in, 128)
231 | self.conv2 = conv(ch_in + 128, 128)
232 | self.conv3 = conv(ch_in + 256, 96)
233 | self.conv4 = conv(ch_in + 352, 64)
234 | self.conv5 = conv(ch_in + 416, 32)
235 | self.conv_last = conv(ch_in + 448, 3, isReLU=False)
236 |
237 | def forward(self, x):
238 | x1 = torch.cat([self.conv1(x), x], dim=1)
239 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
240 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
241 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
242 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
243 | x_out = self.conv_last(x5)
244 | return x5, x_out[:,:2,:,:], x_out[:,2,:,:].unsqueeze(1)
245 |
246 |
247 | class FlowAndOccContextNetwork(nn.Module):
248 | def __init__(self, ch_in):
249 | super(FlowAndOccContextNetwork, self).__init__()
250 |
251 | self.convs = nn.Sequential(
252 | conv(ch_in, 128, 3, 1, 1),
253 | conv(128, 128, 3, 1, 2),
254 | conv(128, 128, 3, 1, 4),
255 | conv(128, 96, 3, 1, 8),
256 | conv(96, 64, 3, 1, 16),
257 | conv(64, 32, 3, 1, 1),
258 | conv(32, 3, isReLU=False)
259 | )
260 |
261 | def forward(self, x):
262 | x_out = self.convs(x)
263 | return x_out[:,:2,:,:], x_out[:,2,:,:].unsqueeze(1)
264 |
--------------------------------------------------------------------------------
/models/pwcnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .pwc_modules import upsample2d_as, initialize_msra
7 | from .pwc_modules import WarpingLayer, FeatureExtractor
8 | from .pwc_modules import ContextNetwork, FlowEstimatorDense
9 | from .correlation_package.correlation import Correlation
10 |
11 | class PWCNet(nn.Module):
12 | def __init__(self, args, div_flow=0.05):
13 | super(PWCNet, self).__init__()
14 | self.args = args
15 | self._div_flow = div_flow
16 | self.search_range = 4
17 | self.num_chs = [3, 16, 32, 64, 96, 128, 196]
18 | self.output_level = 4
19 | self.num_levels = 7
20 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
21 |
22 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
23 | self.warping_layer = WarpingLayer()
24 |
25 | self.flow_estimators = nn.ModuleList()
26 | self.dim_corr = (self.search_range * 2 + 1) ** 2
27 | for l, ch in enumerate(self.num_chs[::-1]):
28 | if l > self.output_level:
29 | break
30 |
31 | if l == 0:
32 | num_ch_in = self.dim_corr
33 | else:
34 | num_ch_in = self.dim_corr + ch + 2
35 |
36 | layer = FlowEstimatorDense(num_ch_in)
37 | self.flow_estimators.append(layer)
38 |
39 | self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)
40 |
41 | initialize_msra(self.modules())
42 |
43 | def forward(self, input_dict):
44 |
45 | x1_raw = input_dict['input1']
46 | x2_raw = input_dict['input2']
47 | _, _, height_im, width_im = x1_raw.size()
48 |
49 | # on the bottom level are original images
50 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
51 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
52 |
53 | # outputs
54 | output_dict = {}
55 | flows = []
56 |
57 | # init
58 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
59 | init_dtype = x1_pyramid[0].dtype
60 | init_device = x1_pyramid[0].device
61 | flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
62 |
63 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
64 |
65 | # warping
66 | if l == 0:
67 | x2_warp = x2
68 | else:
69 | flow = upsample2d_as(flow, x1, mode="bilinear")
70 | x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)
71 |
72 | # correlation
73 | out_corr = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp)
74 | out_corr_relu = self.leakyRELU(out_corr)
75 |
76 | # flow estimator
77 | if l == 0:
78 | x_intm, flow = self.flow_estimators[l](out_corr_relu)
79 | else:
80 | x_intm, flow = self.flow_estimators[l](torch.cat([out_corr_relu, x1, flow], dim=1))
81 |
82 | # upsampling or post-processing
83 | if l != self.output_level:
84 | flows.append(flow)
85 | else:
86 | flow_res = self.context_networks(torch.cat([x_intm, flow], dim=1))
87 | flow = flow + flow_res
88 | flows.append(flow)
89 | break
90 |
91 | output_dict['flow'] = flows
92 |
93 | if self.training:
94 | return output_dict
95 | else:
96 | output_dict_eval = {}
97 | out_flow = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
98 | output_dict_eval['flow'] = out_flow
99 | return output_dict_eval
100 |
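# Inference sketch (hypothetical tensors; the correlation_cuda extension must
# be built and inputs placed on the GPU):
#   net = PWCNet(args).cuda().eval()
#   out = net({'input1': img1, 'input2': img2})  # imgs: (B, 3, H, W)
#   flow = out['flow']                           # (B, 2, H, W), pixel units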
--------------------------------------------------------------------------------
/models/pwcnet_irr.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra
7 | from .pwc_modules import WarpingLayer, FeatureExtractor
8 | from .pwc_modules import ContextNetwork, FlowEstimatorDense
9 | from .correlation_package.correlation import Correlation
10 |
11 | class PWCNet(nn.Module):
12 | def __init__(self, args, div_flow=0.05):
13 | super(PWCNet, self).__init__()
14 | self.args = args
15 | self._div_flow = div_flow
16 | self.search_range = 4
17 | self.num_chs = [3, 16, 32, 64, 96, 128, 196]
18 | self.output_level = 4
19 | self.num_levels = 7
20 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
21 |
22 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
23 | self.warping_layer = WarpingLayer()
24 |
25 | self.dim_corr = (self.search_range * 2 + 1) ** 2
26 | self.num_ch_in = self.dim_corr + 32 + 2
27 |
28 | self.flow_estimators = FlowEstimatorDense(self.num_ch_in)
29 |
30 | self.context_networks = ContextNetwork(self.num_ch_in + 448 + 2)
31 |
32 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
33 | conv(128, 32, kernel_size=1, stride=1, dilation=1),
34 | conv(96, 32, kernel_size=1, stride=1, dilation=1),
35 | conv(64, 32, kernel_size=1, stride=1, dilation=1),
36 | conv(32, 32, kernel_size=1, stride=1, dilation=1)])
37 |
38 | initialize_msra(self.modules())
39 |
40 | def forward(self, input_dict):
41 |
42 | x1_raw = input_dict['input1']
43 | x2_raw = input_dict['input2']
44 | _, _, height_im, width_im = x1_raw.size()
45 |
46 | # on the bottom level are original images
47 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
48 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
49 |
50 | # outputs
51 | output_dict = {}
52 | flows = []
53 |
54 | # init
55 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
56 | init_dtype = x1_pyramid[0].dtype
57 | init_device = x1_pyramid[0].device
58 | flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
59 |
60 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
61 |
62 | # warping
63 | if l == 0:
64 | x2_warp = x2
65 | else:
66 | flow = upsample2d_as(flow, x1, mode="bilinear")
67 | x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)
68 |
69 | # correlation
70 | out_corr = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp)
71 | out_corr_relu = self.leakyRELU(out_corr)
72 |
73 | # concat and estimate flow
74 | flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=True)
75 |
76 | x1_1by1 = self.conv_1x1[l](x1)
77 | x_intm, flow_res = self.flow_estimators(torch.cat([out_corr_relu, x1_1by1, flow], dim=1))
78 | flow = flow + flow_res
79 |
80 | flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1))
81 | flow = flow + flow_fine
82 |
83 | flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=False)
84 | flows.append(flow)
85 |
86 | # upsampling or post-processing
87 | if l == self.output_level:
88 | break
89 |
90 | output_dict['flow'] = flows
91 |
92 | if self.training:
93 | return output_dict
94 | else:
95 | output_dict_eval = {}
96 | output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
97 | return output_dict_eval
98 |
--------------------------------------------------------------------------------
/models/pwcnet_irr_occ_joint.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra
7 | from .pwc_modules import WarpingLayer, FeatureExtractor
8 | from .pwc_modules import FlowAndOccContextNetwork, FlowAndOccEstimatorDense
9 | from .correlation_package.correlation import Correlation
10 |
11 | class PWCNet(nn.Module):
12 | def __init__(self, args, div_flow=0.05):
13 | super(PWCNet, self).__init__()
14 | self.args = args
15 | self._div_flow = div_flow
16 | self.search_range = 4
17 | self.num_chs = [3, 16, 32, 64, 96, 128, 196]
18 | self.output_level = 4
19 | self.num_levels = 7
20 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
21 |
22 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
23 | self.warping_layer = WarpingLayer()
24 |
25 | self.dim_corr = (self.search_range * 2 + 1) ** 2
26 | self.num_ch_in = self.dim_corr + 32 + 2 + 1
27 |
28 | self.flow_and_occ_estimators = FlowAndOccEstimatorDense(self.num_ch_in)
29 |
30 | self.context_networks = FlowAndOccContextNetwork(self.num_ch_in + 448 + 2 + 1)
31 |
32 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
33 | conv(128, 32, kernel_size=1, stride=1, dilation=1),
34 | conv(96, 32, kernel_size=1, stride=1, dilation=1),
35 | conv(64, 32, kernel_size=1, stride=1, dilation=1),
36 | conv(32, 32, kernel_size=1, stride=1, dilation=1)])
37 |
38 | initialize_msra(self.modules())
39 |
40 | def forward(self, input_dict):
41 |
42 | x1_raw = input_dict['input1']
43 | x2_raw = input_dict['input2']
44 | _, _, height_im, width_im = x1_raw.size()
45 |
46 | # on the bottom level are original images
47 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
48 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
49 |
50 | # outputs
51 | output_dict = {}
52 | flows = []
53 | occs = []
54 |
55 | # init
56 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
57 | init_dtype = x1_pyramid[0].dtype
58 | init_device = x1_pyramid[0].device
59 | flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
60 | occ = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
61 |
62 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
63 |
64 | # warping
65 | if l == 0:
66 | x2_warp = x2
67 | else:
68 | flow = upsample2d_as(flow, x1, mode="bilinear")
69 | occ = upsample2d_as(occ, x1, mode="bilinear")
70 | x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)
71 |
72 | # correlation
73 | out_corr = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp)
74 | out_corr_relu = self.leakyRELU(out_corr)
75 |
76 | # concat and estimate flow
77 | flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=True)
78 |
79 | x1_1by1 = self.conv_1x1[l](x1)
80 | x_intm, flow_res, occ_res = self.flow_and_occ_estimators(torch.cat([out_corr_relu, x1_1by1, flow, occ], dim=1))
81 | flow = flow + flow_res
82 | occ = occ + occ_res
83 |
84 | flow_fine, occ_fine = self.context_networks(torch.cat([x_intm, flow, occ], dim=1))
85 | flow = flow + flow_fine
86 | occ = occ + occ_fine
87 |
88 | flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=False)
89 | flows.append(flow)
90 | occs.append(occ)
91 |
92 | # upsampling or post-processing
93 | if l == self.output_level:
94 | break
95 |
96 | output_dict['flow'] = flows
97 | output_dict['occ'] = occs
98 |
99 | if self.training:
100 | return output_dict
101 | else:
102 | output_dict_eval = {}
103 | output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
104 | output_dict_eval['occ'] = upsample2d_as(occ, x1_raw, mode="bilinear")
105 | return output_dict_eval
106 |
--------------------------------------------------------------------------------
/models/pwcnet_occ_joint.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .pwc_modules import upsample2d_as, initialize_msra
7 | from .pwc_modules import WarpingLayer, FeatureExtractor
8 | from .pwc_modules import FlowAndOccEstimatorDense, FlowAndOccContextNetwork
9 | from .correlation_package.correlation import Correlation
10 |
11 | class PWCNet(nn.Module):
12 | def __init__(self, args, div_flow=0.05):
13 | super(PWCNet, self).__init__()
14 | self.args = args
15 | self._div_flow = div_flow
16 | self.search_range = 4
17 | self.num_chs = [3, 16, 32, 64, 96, 128, 196]
18 | self.output_level = 4
19 | self.num_levels = 7
20 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
21 |
22 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)
23 | self.warping_layer = WarpingLayer()
24 |
25 | self.flow_and_occ_estimators = nn.ModuleList()
26 | self.dim_corr = (self.search_range * 2 + 1) ** 2
27 | for l, ch in enumerate(self.num_chs[::-1]):
28 | if l > self.output_level:
29 | break
30 |
31 | if l == 0:
32 | num_ch_in = self.dim_corr
33 | else:
34 | num_ch_in = self.dim_corr + ch + 2 + 1
35 |
36 | layer = FlowAndOccEstimatorDense(num_ch_in)
37 | self.flow_and_occ_estimators.append(layer)
38 |
39 | self.context_networks = FlowAndOccContextNetwork(self.dim_corr + 32 + 2 + 1 + 448 + 2 + 1)
40 |
41 | initialize_msra(self.modules())
42 |
43 | def forward(self, input_dict):
44 |
45 | x1_raw = input_dict['input1']
46 | x2_raw = input_dict['input2']
47 | _, _, height_im, width_im = x1_raw.size()
48 |
49 | # on the bottom level are original images
50 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]
51 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]
52 |
53 | # outputs
54 | output_dict = {}
55 | flows = []
56 | occs = []
57 |
58 | # init
59 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size()
60 | init_dtype = x1_pyramid[0].dtype
61 | init_device = x1_pyramid[0].device
62 | flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
63 | occ = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()
64 |
65 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
66 |
67 | # warping
68 | if l == 0:
69 | x2_warp = x2
70 | else:
71 | flow = upsample2d_as(flow, x1, mode="bilinear")
72 | occ = upsample2d_as(occ, x1, mode="bilinear")
73 | x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)
74 |
75 | # correlation
76 | out_corr = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp)
77 | out_corr_relu = self.leakyRELU(out_corr)
78 |
79 | # flow estimator
80 | if l == 0:
81 | x_intm, flow, occ = self.flow_and_occ_estimators[l](out_corr_relu)
82 | else:
83 | x_intm, flow, occ = self.flow_and_occ_estimators[l](torch.cat([out_corr_relu, x1, flow, occ], dim=1))
84 |
85 | # upsampling or post-processing
86 | if l != self.output_level:
87 | flows.append(flow)
88 | occs.append(occ)
89 | else:
90 | flow_fine, occ_fine = self.context_networks(torch.cat([x_intm, flow, occ], dim=1))
91 | flow = flow + flow_fine
92 | occ = occ + occ_fine
93 | flows.append(flow)
94 | occs.append(occ)
95 | break
96 |
97 | output_dict['flow'] = flows
98 | output_dict['occ'] = occs
99 |
100 | if self.training:
101 | return output_dict
102 | else:
103 | output_dict_eval = {}
104 | output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow)
105 | output_dict_eval['occ'] = upsample2d_as(occ, x1_raw, mode="bilinear")
106 | return output_dict_eval
107 |
--------------------------------------------------------------------------------
/optim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | from tools import module_classes_to_dict
4 |
5 | # ------------------------------------------------------------------------------------
6 | # Export PyTorch optimizer
7 | # ------------------------------------------------------------------------------------
8 | _this = sys.modules[__name__]
9 | _optimizer_classes = module_classes_to_dict(torch.optim, exclude_classes="Optimizer")
10 | for name, constructor in _optimizer_classes.items():
11 | setattr(_this, name, constructor)
12 | __all__ = _optimizer_classes.keys()
13 |
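# Effect sketch: each torch.optim class (except the Optimizer base) becomes an
# attribute of this package, so `from optim import Adam` resolves to
# torch.optim.Adam and the --optimizer=Adam flag can be looked up by name.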
14 |
--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/results.png
--------------------------------------------------------------------------------
/saved_checkpoint/StarFlow_kitti/checkpoint_latest.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/saved_checkpoint/StarFlow_kitti/checkpoint_latest.ckpt
--------------------------------------------------------------------------------
/saved_checkpoint/StarFlow_sintel/checkpoint_latest.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/saved_checkpoint/StarFlow_sintel/checkpoint_latest.ckpt
--------------------------------------------------------------------------------
/saved_checkpoint/StarFlow_things/checkpoint_best.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/saved_checkpoint/StarFlow_things/checkpoint_best.ckpt
--------------------------------------------------------------------------------
/scripts_train/train_starflow_chairsocc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # experiments and datasets meta
4 | EXPERIMENTS_HOME="experiments"
5 |
6 | # datasets
7 | FLYINGCHAIRS_OCC_HOME=(YOUR PATH)/FlyingChairsOcc/
8 | SINTEL_HOME=(YOUR PATH)/mpisintelcomplete
9 |
10 | # model and checkpoint
11 | MODEL=StarFlow
12 | EVAL_LOSS=MultiScaleEPE_PWC_Occ_upsample
13 | CHECKPOINT=None
14 | SIZE_OF_BATCH=8
15 | DEVICE=0
16 |
17 | # save path
18 | SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-chairs"
19 |
20 | # training configuration
21 | python ../main.py \
22 | --batch_size=$SIZE_OF_BATCH \
23 | --batch_size_val=$SIZE_OF_BATCH \
24 | --checkpoint=$CHECKPOINT \
25 | --lr_scheduler=MultiStepLR \
26 | --lr_scheduler_gamma=0.5 \
27 | --lr_scheduler_milestones="[108, 144, 180]" \
28 | --model=$MODEL \
29 | --num_workers=6 \
30 | --device=$DEVICE \
31 | --optimizer=Adam \
32 | --optimizer_lr=1e-4 \
33 | --optimizer_weight_decay=4e-4 \
34 | --save=$SAVE_PATH \
35 | --total_epochs=216 \
36 | --training_augmentation=RandomAffineFlowOcc \
37 | --training_dataset=FlyingChairsOccTrain \
38 | --training_dataset_photometric_augmentations=True \
39 | --training_dataset_root=$FLYINGCHAIRS_OCC_HOME \
40 | --training_key=total_loss \
41 | --training_loss=$EVAL_LOSS \
42 | --validation_dataset=FlyingChairsOccValid \
43 | --validation_dataset_photometric_augmentations=False \
44 | --validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \
45 | --validation_key=epe \
46 | --validation_loss=$EVAL_LOSS
47 |
--------------------------------------------------------------------------------
/scripts_train/train_starflow_kitti_full.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # experiments and datasets meta
4 | EXPERIMENTS_HOME="experiments"
5 |
6 | # datasets
7 | KITTI_HOME=(YOUR PATH)/KittiComb
8 |
9 | # model and checkpoint
10 | MODEL=StarFlow
11 | EVAL_LOSS=MultiScaleEPE_PWC_Occ_upsample_KITTI
12 | CHECKPOINT=None
13 | SIZE_OF_BATCH=4
14 | NFRAMES=4
15 | DEVICE=0
16 |
17 | # save path
18 | SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-ftkitti-full"
19 |
20 | # training configuration
21 | python ../main.py \
22 | --batch_size=$SIZE_OF_BATCH \
23 | --batch_size_val=$SIZE_OF_BATCH \
24 | --checkpoint=$CHECKPOINT \
25 | --lr_scheduler=MultiStepLR \
26 | --lr_scheduler_gamma=0.5 \
27 | --lr_scheduler_milestones="[456, 659, 862, 963, 989, 1014, 1116, 1217, 1319, 1420]" \
28 | --model=$MODEL \
29 | --num_workers=6 \
30 | --device=$DEVICE \
31 | --optimizer=Adam \
32 | --optimizer_lr=3e-05 \
33 | --optimizer_weight_decay=4e-4 \
34 | --save=$SAVE_PATH \
35 | --start_epoch=1 \
36 | --total_epochs=550 \
37 | --training_augmentation=RandomAffineFlowOccVideoKitti \
38 | --training_augmentation_crop="[320,896]" \
39 | --training_dataset=KittiMultiframeCombFull \
40 | --training_dataset_nframes=$NFRAMES \
41 | --training_dataset_photometric_augmentations=True \
42 | --training_dataset_root=$KITTI_HOME \
43 | --training_dataset_preprocessing_crop=True \
44 | --training_key=total_loss \
45 | --training_loss=$EVAL_LOSS \
46 | --validation_dataset=KittiMultiframeComb2015Val \
47 | --validation_dataset_nframes=$NFRAMES \
48 | --validation_dataset_photometric_augmentations=True \
49 | --validation_dataset_root=$KITTI_HOME \
50 | --validation_dataset_preprocessing_crop=True \
51 | --validation_key=epe \
52 | --validation_loss=$EVAL_LOSS
53 |
--------------------------------------------------------------------------------
/scripts_train/train_starflow_sintel_full.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # experiments and datasets meta
4 | EXPERIMENTS_HOME="experiments"
5 |
6 | # datasets
7 | SINTEL_HOME=(YOUR PATH)/mpisintelcomplete
8 |
9 | # model and checkpoint
10 | MODEL=StarFlow
11 | EVAL_LOSS=MultiScaleEPE_PWC_Occ_video_upsample_Sintel
12 | CHECKPOINT=None
13 | SIZE_OF_BATCH=4
14 | NFRAMES=4
15 | DEVICE=0
16 |
17 | # save path
18 | SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-ftsintel1-full"
19 |
20 | # training configuration
21 | python ../main.py \
22 | --batch_size=$SIZE_OF_BATCH \
23 | --batch_size_val=$SIZE_OF_BATCH \
24 | --checkpoint=$CHECKPOINT \
25 | --lr_scheduler=MultiStepLR \
26 | --lr_scheduler_gamma=0.5 \
27 | --lr_scheduler_milestones="[89, 130, 170, 190, 195, 200, 220, 240, 260, 280]" \
28 | --model=$MODEL \
29 | --num_workers=6 \
30 | --device=$DEVICE \
31 | --optimizer=Adam \
32 | --optimizer_lr=1.5e-05 \
33 | --optimizer_weight_decay=4e-4 \
34 | --save=$SAVE_PATH \
35 | --start_epoch=1 \
36 | --total_epochs=300 \
37 | --training_augmentation=RandomAffineFlowOccVideo \
38 | --training_augmentation_crop="[384,768]" \
39 | --training_dataset=SintelMultiframeTrainingCombFull \
40 | --training_dataset_nframes=$NFRAMES \
41 | --training_dataset_photometric_augmentations=True \
42 | --training_dataset_root=$SINTEL_HOME \
43 | --training_key=total_loss \
44 | --training_loss=$EVAL_LOSS \
45 | --validation_dataset=SintelMultiframeTrainingFinalValid \
46 | --validation_dataset_nframes=$NFRAMES \
47 | --validation_dataset_photometric_augmentations=False \
48 | --validation_dataset_root=$SINTEL_HOME \
49 | --validation_key=epe \
50 | --validation_loss=$EVAL_LOSS
51 |
52 | # save path
53 | SAVE_PATH_2="$EXPERIMENTS_HOME/$MODEL-ftsintel2-full"
54 |
55 | # training configuration
56 | python ../main.py \
57 | --batch_size=$SIZE_OF_BATCH \
58 | --batch_size_val=$SIZE_OF_BATCH \
59 | --checkpoint=$SAVE_PATH \
60 | --lr_scheduler=MultiStepLR \
61 | --lr_scheduler_gamma=0.5 \
62 | --lr_scheduler_milestones="[481, 562, 643, 683, 693, 703, 743, 783, 824, 864]" \
63 | --model=$MODEL \
64 | --num_workers=6 \
65 | --device=$DEVICE \
66 | --optimizer=Adam \
67 | --optimizer_lr=1e-05 \
68 | --optimizer_weight_decay=4e-4 \
69 | --save=$SAVE_PATH_2 \
70 | --start_epoch=301 \
71 | --total_epochs=451 \
72 | --training_augmentation=RandomAffineFlowOccVideo \
73 | --training_augmentation_crop="[384,768]" \
74 | --training_dataset=SintelMultiframeTrainingFinalFull \
75 | --training_dataset_nframes=$NFRAMES \
76 | --training_dataset_photometric_augmentations=True \
77 | --training_dataset_root=$SINTEL_HOME \
78 | --training_key=total_loss \
79 | --training_loss=$EVAL_LOSS \
80 | --validation_dataset=SintelMultiframeTrainingFinalValid \
81 | --validation_dataset_nframes=$NFRAMES \
82 | --validation_dataset_photometric_augmentations=False \
83 | --validation_dataset_root=$SINTEL_HOME \
84 | --validation_key=epe \
85 | --validation_loss=$EVAL_LOSS
86 |
--------------------------------------------------------------------------------
/scripts_train/train_starflow_things.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # experiments and datasets meta
4 | EXPERIMENTS_HOME="experiments"
5 |
6 | # datasets
7 | FLYINGTHINGS_HOME="(YOUR PATH)/FlyingThings3DSubset"
8 | SINTEL_HOME="(YOUR PATH)/mpisintelcomplete"
9 |
10 | # model and checkpoint
11 | MODEL=StarFlow
12 | EVAL_LOSS=MultiScaleEPE_PWC_Occ_video_upsample
13 | CHECKPOINT=None
14 | SIZE_OF_BATCH=4
15 | NFRAMES=4
16 | DEVICE=0
17 |
18 | # save path
19 | SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-ftthings"
20 |
21 | # training configuration
22 | python ../main.py \
23 | --batch_size=$SIZE_OF_BATCH \
24 | --batch_size_val=$SIZE_OF_BATCH \
25 | --checkpoint=$CHECKPOINT \
26 | --lr_scheduler=MultiStepLR \
27 | --lr_scheduler_gamma=0.5 \
28 | --lr_scheduler_milestones="[257, 287, 307, 317]" \
29 | --model=$MODEL \
30 | --num_workers=6 \
31 | --device=$DEVICE \
32 | --optimizer=Adam \
33 | --optimizer_lr=1e-4 \
34 | --optimizer_weight_decay=4e-4 \
35 | --save=$SAVE_PATH \
36 | --start_epoch=217 \
37 | --total_epochs=327 \
38 | --training_augmentation=RandomAffineFlowOccVideo \
39 | --training_augmentation_crop="[384,768]" \
40 | --training_dataset=FlyingThings3dMultiframeCleanTrain \
41 | --training_dataset_nframes=$NFRAMES \
42 | --training_dataset_photometric_augmentations=True \
43 | --training_dataset_root=$FLYINGTHINGS_HOME \
44 | --training_key=total_loss \
45 | --training_loss=$EVAL_LOSS \
46 | --validation_dataset=FlyingThings3dMultiframeCleanTest \
47 | --validation_dataset_nframes=$NFRAMES \
48 | --validation_dataset_photometric_augmentations=False \
49 | --validation_dataset_root=$FLYINGTHINGS_HOME \
50 | --validation_key=epe \
51 | --validation_loss=$EVAL_LOSS
52 |
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
1 | ## Portions of this code copyright 2018 Jochen Gast
2 |
3 | from __future__ import absolute_import, division, print_function
4 |
5 | import os
6 | import socket
7 | import re
8 | #from pytz import timezone
9 | from datetime import datetime
10 | import fnmatch
11 | import itertools
12 | import argparse
13 | import sys
14 | import six
15 | import unicodedata
16 | import json
17 | import inspect
18 | import tqdm
19 | import logging
20 | import torch
21 | import ast
22 | import numpy as np
23 |
24 |
25 | def x2module(module_or_data_parallel):
26 | if isinstance(module_or_data_parallel, torch.nn.DataParallel):
27 | return module_or_data_parallel.module
28 | else:
29 | return module_or_data_parallel
30 |
31 |
32 | # ----------------------------------------------------------------------------------------
33 | # Comprehensively adds a new logging level to the `logging` module and the
34 | # currently configured logging class.
35 | # e.g. addLoggingLevel('TRACE', logging.DEBUG - 5)
36 | # ----------------------------------------------------------------------------------------
37 | def addLoggingLevel(level_name, level_num, method_name=None):
38 | if not method_name:
39 | method_name = level_name.lower()
40 | if hasattr(logging, level_name):
41 | raise AttributeError('{} already defined in logging module'.format(level_name))
42 | if hasattr(logging, method_name):
43 | raise AttributeError('{} already defined in logging module'.format(method_name))
44 | if hasattr(logging.getLoggerClass(), method_name):
45 | raise AttributeError('{} already defined in logger class'.format(method_name))
46 |
47 | # This method was inspired by the answers to Stack Overflow post
48 | # http://stackoverflow.com/q/2183233/2988730, especially
49 | # http://stackoverflow.com/a/13638084/2988730
50 | def logForLevel(self, message, *args, **kwargs):
51 | if self.isEnabledFor(level_num):
52 | self._log(level_num, message, args, **kwargs)
53 |
54 | def logToRoot(message, *args, **kwargs):
55 | logging.log(level_num, message, *args, **kwargs)
56 |
57 | logging.addLevelName(level_num, level_name)
58 | setattr(logging, level_name, level_num)
59 | setattr(logging.getLoggerClass(), method_name, logForLevel)
60 | setattr(logging, method_name, logToRoot)
61 |
62 |
63 | # -------------------------------------------------------------------------------------------------
64 | # Looks for sub arguments in the argument structure.
65 | # Retrieve sub arguments for modules such as optimizer_*
66 | # -------------------------------------------------------------------------------------------------
67 | def kwargs_from_args(args, name, exclude=()):
68 |     if isinstance(exclude, str):
69 |         exclude = [exclude]
70 |     exclude = list(exclude) + ["class"]  # do not mutate a shared default argument
71 |     args_dict = vars(args)
72 |     name += "_"
73 |     subargs_dict = {
74 |         key[len(name):]: value for key, value in args_dict.items()
75 |         if key.startswith(name) and all(key != name + x for x in exclude)
76 | }
77 | return subargs_dict
78 |
79 |
80 | # -------------------------------------------------------------------------------------------------
81 | # Create class instance from kwargs dictionary.
82 | # Filters out keys that are not accepted by the constructor
83 | # -------------------------------------------------------------------------------------------------
84 | def instance_from_kwargs(class_constructor, kwargs):
85 |     argspec = inspect.getfullargspec(class_constructor.__init__)
86 | full_args = argspec.args
87 | filtered_args = dict([(k,v) for k,v in kwargs.items() if k in full_args])
88 | instance = class_constructor(**filtered_args)
89 | return instance
90 |
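# -------------------------------------------------------------------------------------------------
# Editorial usage sketch (not part of the original file): kwargs_from_args and
# instance_from_kwargs combine to build a module from prefixed CLI flags; the
# flag names mirror the training scripts above, the Namespace is a toy.
# -------------------------------------------------------------------------------------------------
def _example_optimizer_from_args():
    args = argparse.Namespace(optimizer_class="Adam",
                              optimizer_lr=1e-4,
                              optimizer_weight_decay=4e-4)
    params = [torch.nn.Parameter(torch.zeros(1))]
    kwargs = kwargs_from_args(args, "optimizer")  # {'lr': 0.0001, 'weight_decay': 0.0004}
    return instance_from_kwargs(torch.optim.Adam, dict(kwargs, params=params))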
91 |
92 | def module_classes_to_dict(module, include_classes="*", exclude_classes=()):
93 | # -------------------------------------------------------------------------
94 | # If arguments are strings, convert them to a list
95 | # -------------------------------------------------------------------------
96 | if include_classes is not None:
97 | if isinstance(include_classes, str):
98 | include_classes = [include_classes]
99 |
100 | if exclude_classes is not None:
101 | if isinstance(exclude_classes, str):
102 | exclude_classes = [exclude_classes]
103 |
104 | # -------------------------------------------------------------------------
105 | # Obtain dictionary from given module
106 | # -------------------------------------------------------------------------
107 | item_dict = dict([(name, getattr(module, name)) for name in dir(module)])
108 |
109 | # -------------------------------------------------------------------------
110 | # Filter classes
111 | # -------------------------------------------------------------------------
112 | item_dict = dict([
113 | (name,value) for name, value in item_dict.items() if inspect.isclass(getattr(module, name))
114 | ])
115 |
116 | filtered_keys = filter_list_of_strings(
117 | item_dict.keys(), include=include_classes, exclude=exclude_classes)
118 |
119 | # -------------------------------------------------------------------------
120 | # Construct dictionary from matched results
121 | # -------------------------------------------------------------------------
122 | result_dict = dict([(name, value) for name, value in item_dict.items() if name in filtered_keys])
123 |
124 | return result_dict
125 |
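# -------------------------------------------------------------------------------------------------
# Editorial usage sketch (not part of the original file): build a name -> class
# registry from a module, e.g. the two averaging helpers defined further below.
# -------------------------------------------------------------------------------------------------
def _example_class_registry():
    this_module = sys.modules[__name__]
    return module_classes_to_dict(this_module, include_classes="*MovingAverage")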
126 |
127 | def ensure_dir(file_path):
128 | directory = os.path.dirname(file_path)
129 |     if directory and not os.path.exists(directory):
130 | os.makedirs(directory)
131 |
132 |
133 | def search_and_replace(string, regex, replace):
134 | while True:
135 | match = re.search(regex, string)
136 | if match:
137 | string = string.replace(match.group(0), replace)
138 | else:
139 | break
140 | return string
141 |
142 |
143 | def hostname():
144 | name = socket.gethostname()
145 | n = name.find('.')
146 | if n > 0:
147 | name = name[:n]
148 | return name
149 |
150 |
151 | def get_filenames(directory, match='*.*', not_match=()):
152 | if match is not None:
153 | if isinstance(match, str):
154 | match = [match]
155 | if not_match is not None:
156 | if isinstance(not_match, str):
157 | not_match = [not_match]
158 |
159 | result = []
160 | for dirpath, _, filenames in os.walk(directory):
161 | filtered_matches = list(itertools.chain.from_iterable(
162 | [fnmatch.filter(filenames, x) for x in match]))
163 | filtered_nomatch = list(itertools.chain.from_iterable(
164 | [fnmatch.filter(filenames, x) for x in not_match]))
165 | matched = list(set(filtered_matches) - set(filtered_nomatch))
166 | result += [os.path.join(dirpath, x) for x in matched]
167 | return result
168 |
169 |
170 | def str2bool(v):
171 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
172 | return True
173 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
174 | return False
175 | else:
176 | raise argparse.ArgumentTypeError('Boolean value expected.')
177 |
178 |
179 | def str2str_or_none(v):
180 | if v.lower() == "none":
181 | return None
182 | return v
183 |
184 |
185 | def str2dict(v):
186 | return ast.literal_eval(v)
187 |
188 |
189 | def str2intlist(v):
190 | return [int(x.strip()) for x in v.strip()[1:-1].split(',')]
191 |
192 |
193 | def str2list(v):
194 | return [str(x.strip()) for x in v.strip()[1:-1].split(',')]
195 |
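# -------------------------------------------------------------------------------------------------
# Editorial note (not part of the original file): the bracketed strings passed
# by the training scripts, e.g. --training_augmentation_crop="[384,768]" or
# --lr_scheduler_milestones="[108, 144, 180]", parse cleanly with str2intlist
# (assuming commandline.py registers it as the type for those flags).
# -------------------------------------------------------------------------------------------------
def _example_parse_script_flags():
    assert str2intlist("[384,768]") == [384, 768]
    assert str2intlist("[108, 144, 180]") == [108, 144, 180]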
196 |
197 | def read_json(filename):
198 |
199 | def _convert_from_unicode(data):
200 | new_data = dict()
201 | for name, value in six.iteritems(data):
202 | if isinstance(name, six.string_types):
203 | name = unicodedata.normalize('NFKD', name).encode(
204 | 'ascii', 'ignore')
205 | if isinstance(value, six.string_types):
206 | value = unicodedata.normalize('NFKD', value).encode(
207 | 'ascii', 'ignore')
208 | if isinstance(value, dict):
209 | value = _convert_from_unicode(value)
210 | new_data[name] = value
211 | return new_data
212 |
213 | output_dict = None
214 | with open(filename, "r") as f:
215 | lines = f.readlines()
216 | try:
217 |             output_dict = json.loads(''.join(lines))  # the 'encoding' kwarg was removed in Python 3.9
218 |         except Exception:
219 | raise ValueError('Could not read %s. %s' % (filename, sys.exc_info()[1]))
220 | output_dict = _convert_from_unicode(output_dict)
221 | return output_dict
222 |
223 |
224 | def write_json(data_dict, filename):
225 | with open(filename, "w") as file:
226 | json.dump(data_dict, file)
227 |
228 |
229 | def datestr():
230 | #pacific = timezone('US/Pacific')
231 | #now = datetime.now(pacific)
232 | now = datetime.now()
233 | return '{}{:02}{:02}_{:02}{:02}'.format(now.year, now.month, now.day, now.hour, now.minute)
234 |
235 |
236 | def filter_list_of_strings(lst, include="*", exclude=()):
237 | filtered_matches = list(itertools.chain.from_iterable([fnmatch.filter(lst, x) for x in include]))
238 | filtered_nomatch = list(itertools.chain.from_iterable([fnmatch.filter(lst, x) for x in exclude]))
239 | matched = list(set(filtered_matches) - set(filtered_nomatch))
240 | return matched
241 |
242 |
243 | # ----------------------------------------------------------------------------
244 | # Writes all (key, value) pairs to a file for bookkeeping;
245 | # the filename extension picks the format, either .txt or .json
246 | # ----------------------------------------------------------------------------
247 | def write_dictionary_to_file(arguments_dict, filename):
248 | # ensure dir
249 | d = os.path.dirname(filename)
250 | if not os.path.exists(d):
251 | os.makedirs(d)
252 |
253 | # check for json extension
254 | ext = os.path.splitext(filename)[1]
255 | if ext == ".json":
256 |
257 | def replace_quotes(x):
258 | return x.replace("\'", "\"")
259 |
260 | with open(filename, 'w') as file:
261 | file.write("{\n")
262 | for i, (key, value) in enumerate(arguments_dict):
263 | if isinstance(value, tuple):
264 | value = list(value)
265 | if value is None:
266 | file.write(" \"%s\": null" % key)
267 | elif isinstance(value, str):
268 | value = value.replace("\'", "\"")
269 |                 file.write(" \"%s\": \"%s\"" % (key, replace_quotes(str(value))))
270 | elif isinstance(value, bool):
271 | file.write(" \"%s\": %s" % (key, str(value).lower()))
272 | else:
273 | file.write(" \"%s\": %s" % (key, replace_quotes(str(value))))
274 | if i < len(arguments_dict) - 1:
275 | file.write(',\n')
276 | else:
277 | file.write('\n')
278 | file.write("}\n")
279 | else:
280 | with open(filename, 'w') as file:
281 | for key, value in arguments_dict:
282 | file.write('%s: %s\n' % (key, value))
283 |
284 |
285 | class MovingAverage:
286 | postfix = "avg"
287 |
288 | def __init__(self):
289 | self._sum = 0.0
290 | self._count = 0
291 |
292 | def add_value(self, sigma, addcount=1):
293 | self._sum += sigma
294 | self._count += addcount
295 |
296 | def add_average(self, avg, addcount):
297 | self._sum += avg*addcount
298 | self._count += addcount
299 |
300 | def mean(self):
301 | return self._sum / self._count
302 |
303 |
304 | class ExponentialMovingAverage:
305 | postfix = "ema"
306 |
307 | def __init__(self, alpha=0.7):
308 | self._weighted_sum = 0.0
309 | self._weighted_count = 0
310 | self._alpha = alpha
311 |
312 | def add_value(self, sigma, addcount=1):
313 | self._weighted_sum = sigma + (1.0 - self._alpha)*self._weighted_sum
314 | self._weighted_count = 1 + (1.0 - self._alpha)*self._weighted_count
315 |
316 | def add_average(self, avg, addcount):
317 | self._weighted_sum = avg*addcount + (1.0 - self._alpha)*self._weighted_sum
318 | self._weighted_count = addcount + (1.0 - self._alpha)*self._weighted_count
319 |
320 | def mean(self):
321 | return self._weighted_sum / self._weighted_count
322 |
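# -------------------------------------------------------------------------------------------------
# Editorial sketch (not part of the original file): both trackers agree on a
# constant stream, but ExponentialMovingAverage discounts old values by
# (1 - alpha) per update, so its mean is pulled toward recent samples.
# -------------------------------------------------------------------------------------------------
def _example_compare_averages():
    avg, ema = MovingAverage(), ExponentialMovingAverage(alpha=0.7)
    for value in [1.0, 1.0, 5.0]:
        avg.add_value(value)
        ema.add_value(value)
    return avg.mean(), ema.mean()  # ~2.33 vs. ~3.88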
323 |
324 | # -----------------------------------------------------------------
325 | # Subclass tqdm to achieve two things:
326 | # 1) Output the progress bar into the logbook.
327 | # 2) Remove the comma before {postfix} because it's annoying.
328 | # -----------------------------------------------------------------
329 | class TqdmToLogger(tqdm.tqdm):
330 | def __init__(self, iterable=None, desc=None, total=None, leave=True,
331 | file=None, ncols=None, mininterval=0.1,
332 | maxinterval=10.0, miniters=None, ascii=None, disable=False,
333 | unit='it', unit_scale=False, dynamic_ncols=False,
334 | smoothing=0.3, bar_format=None, initial=0, position=None,
335 | postfix=None,
336 | logging_on_close=True,
337 | logging_on_update=False):
338 |
339 | super(TqdmToLogger, self).__init__(
340 | iterable=iterable, desc=desc, total=total, leave=leave,
341 | file=file, ncols=ncols, mininterval=mininterval,
342 | maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=disable,
343 | unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols,
344 | smoothing=smoothing, bar_format=bar_format, initial=initial, position=position,
345 | postfix=postfix)
346 |
347 | self._logging_on_close = logging_on_close
348 | self._logging_on_update = logging_on_update
349 | self._closed = False
350 |
351 | @staticmethod
352 | def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False,
353 | unit='it', unit_scale=False, rate=None, bar_format=None,
354 | postfix=None, unit_divisor=1000):
355 |
356 | meter = tqdm.tqdm.format_meter(
357 | n=n, total=total, elapsed=elapsed, ncols=ncols, prefix=prefix, ascii=ascii,
358 | unit=unit, unit_scale=unit_scale, rate=rate, bar_format=bar_format,
359 | postfix=postfix, unit_divisor=unit_divisor)
360 |
361 | # get rid of that stupid comma before the postfix
362 | if postfix is not None:
363 | postfix_with_comma = ", %s" % postfix
364 | meter = meter.replace(postfix_with_comma, postfix)
365 |
366 | return meter
367 |
368 | def update(self, n=1):
369 | if self._logging_on_update:
370 | msg = self.__repr__()
371 | logging.logbook(msg)
372 | return super(TqdmToLogger, self).update(n=n)
373 |
374 | def close(self):
375 | if self._logging_on_close and not self._closed:
376 | msg = self.__repr__()
377 | logging.logbook(msg)
378 | self._closed = True
379 | return super(TqdmToLogger, self).close()
380 |
381 |
382 | def tqdm_with_logging(iterable=None, desc=None, total=None, leave=True,
383 | ncols=None, mininterval=0.1,
384 | maxinterval=10.0, miniters=None, ascii=None, disable=False,
385 | unit="it", unit_scale=False, dynamic_ncols=False,
386 | smoothing=0.3, bar_format=None, initial=0, position=None,
387 | postfix=None,
388 | logging_on_close=True,
389 | logging_on_update=False):
390 |
391 | return TqdmToLogger(
392 | iterable=iterable, desc=desc, total=total, leave=leave,
393 | ncols=ncols, mininterval=mininterval,
394 | maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=disable,
395 | unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols,
396 | smoothing=smoothing, bar_format=bar_format, initial=initial, position=position,
397 | postfix=postfix,
398 | logging_on_close=logging_on_close,
399 | logging_on_update=logging_on_update)
400 |
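# -----------------------------------------------------------------
# Editorial usage sketch (not part of the original file): wrap any
# iterable so the final progress line also lands in the logbook
# (assumes a 'logbook' level was registered via addLoggingLevel).
# -----------------------------------------------------------------
def _example_logged_progress(loader):
    for _batch in tqdm_with_logging(loader, desc="train", unit="batch"):
        pass  # training step goes here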
401 |
402 | def cd_dotdot(path_or_filename):
403 | return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), ".."))
404 |
405 |
406 | def cd_dotdotdot(path_or_filename):
407 | return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), "../.."))
408 |
409 |
410 | def cd_dotdotdotdot(path_or_filename):
411 | return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), "../../.."))
412 |
413 |
414 | def tensor2numpy(tensor):
415 | if isinstance(tensor, np.ndarray):
416 | return tensor
417 | else:
418 | if isinstance(tensor, torch.autograd.Variable):
419 | tensor = tensor.data
420 | if tensor.dim() == 3:
421 | return tensor.cpu().numpy().transpose([1,2,0])
422 | else:
423 | return tensor.cpu().numpy().transpose([0,2,3,1])
424 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/utils/__init__.py
--------------------------------------------------------------------------------
/utils/flow.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import numpy as np
4 | import png
5 | #import matplotlib.colors as cl
6 |
7 | TAG_CHAR = np.array([202021.25], np.float32)
8 | UNKNOWN_FLOW_THRESH = 1e7
9 |
10 |
11 | def write_flow(filename, uv, v=None):
12 | nBands = 2
13 |
14 | if v is None:
15 | assert (uv.ndim == 3)
16 | assert (uv.shape[2] == 2)
17 | u = uv[:, :, 0]
18 | v = uv[:, :, 1]
19 | else:
20 | u = uv
21 |
22 | assert (u.shape == v.shape)
23 | height, width = u.shape
24 | f = open(filename, 'wb')
25 | # write the header
26 | f.write(TAG_CHAR)
27 | np.array(width).astype(np.int32).tofile(f)
28 | np.array(height).astype(np.int32).tofile(f)
29 | # arrange into matrix form
30 | tmp = np.zeros((height, width * nBands))
31 | tmp[:, np.arange(width) * 2] = u
32 | tmp[:, np.arange(width) * 2 + 1] = v
33 | tmp.astype(np.float32).tofile(f)
34 | f.close()
35 |
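# Editorial sketch (not part of the original file): a matching reader for the
# Middlebury .flo layout written above (float32 tag, int32 width/height, then
# row-major interleaved float32 u/v values).
def read_flow(filename):
    with open(filename, 'rb') as f:
        tag = np.fromfile(f, np.float32, count=1)
        assert tag == TAG_CHAR, 'invalid .flo file: wrong magic number'
        width = int(np.fromfile(f, np.int32, count=1))
        height = int(np.fromfile(f, np.int32, count=1))
        data = np.fromfile(f, np.float32, count=2 * width * height)
    return data.reshape(height, width, 2)  # [:, :, 0] = u, [:, :, 1] = v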
36 |
37 | def write_flow_png(filename, uv, v=None, mask=None):
38 |
39 | if v is None:
40 | assert (uv.ndim == 3)
41 | assert (uv.shape[2] == 2)
42 | u = uv[:, :, 0]
43 | v = uv[:, :, 1]
44 | else:
45 | u = uv
46 |
47 | assert (u.shape == v.shape)
48 |
49 | height_img, width_img = u.shape
50 | if mask is None:
51 | valid_mask = np.ones([height_img, width_img])
52 | else:
53 | valid_mask = mask
54 |
55 | flow_u = np.clip((u * 64 + 2 ** 15), 0.0, 65535.0).astype(np.uint16)
56 | flow_v = np.clip((v * 64 + 2 ** 15), 0.0, 65535.0).astype(np.uint16)
57 |
58 |     output = np.stack((flow_u, flow_v, valid_mask.astype(np.uint16)), axis=-1)
59 |
60 | with open(filename, 'wb') as f:
61 | writer = png.Writer(width=width_img, height=height_img, bitdepth=16)
62 | writer.write(f, np.reshape(output, (-1, width_img*3)))
63 |
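# Editorial note (not part of the original file): the encoding above follows
# the KITTI 16-bit convention, png = flow * 64 + 2**15 with a validity mask in
# the third channel, so decoding is the inverse affine map:
def decode_flow_png(png_array):
    flow = (png_array[:, :, :2].astype(np.float32) - 2 ** 15) / 64.0
    valid = png_array[:, :, 2].astype(bool)
    return flow, valid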
64 |
65 | def flow_to_png(flow_map, max_value=None):
66 | _, h, w = flow_map.shape
67 | rgb_map = np.ones((h, w, 3)).astype(np.float32)
68 | if max_value is not None:
69 | normalized_flow_map = flow_map / max_value
70 | else:
71 | normalized_flow_map = flow_map / (np.abs(flow_map).max())
72 | rgb_map[:, :, 0] += normalized_flow_map[0]
73 | rgb_map[:, :, 1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1])
74 | rgb_map[:, :, 2] += normalized_flow_map[1]
75 | return rgb_map.clip(0, 1)
76 |
77 |
78 |
79 | def compute_color(u, v):
80 | """
81 | compute optical flow color map
82 | :param u: optical flow horizontal map
83 | :param v: optical flow vertical map
84 | :return: optical flow in color code
85 | """
86 | [h, w] = u.shape
87 | img = np.zeros([h, w, 3])
88 | nanIdx = np.isnan(u) | np.isnan(v)
89 | u[nanIdx] = 0
90 | v[nanIdx] = 0
91 |
92 | colorwheel = make_color_wheel()
93 | ncols = np.size(colorwheel, 0)
94 |
95 | rad = np.sqrt(u ** 2 + v ** 2)
96 | rad[rad>1] = 1
97 |
98 | a = np.arctan2(-v, -u) / np.pi
99 |
100 | fk = (a + 1) / 2 * (ncols - 1) + 1
101 |
102 | k0 = np.floor(fk).astype(int)
103 |
104 | k1 = k0 + 1
105 | k1[k1 == ncols + 1] = 1
106 | f = fk - k0
107 |
108 | for i in range(0, np.size(colorwheel, 1)):
109 | tmp = colorwheel[:, i]
110 | col0 = tmp[k0 - 1] / 255
111 | col1 = tmp[k1 - 1] / 255
112 | col = (1 - f) * col0 + f * col1
113 |
114 | idx = rad <= 1
115 | col[idx] = 1 - rad[idx] * (1 - col[idx])
116 | notidx = np.logical_not(idx)
117 |
118 | col[notidx] *= 0.75
119 | img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
120 |
121 | return img
122 |
123 |
124 | def make_color_wheel():
125 | """
126 |     Generate the color wheel according to the Middlebury color code
127 | :return: Color wheel
128 | """
129 | RY = 15
130 | YG = 6
131 | GC = 4
132 | CB = 11
133 | BM = 13
134 | MR = 6
135 |
136 | ncols = RY + YG + GC + CB + BM + MR
137 |
138 | colorwheel = np.zeros([ncols, 3])
139 |
140 | col = 0
141 |
142 | # RY
143 | colorwheel[0:RY, 0] = 255
144 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
145 | col += RY
146 |
147 | # YG
148 | colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
149 | colorwheel[col:col + YG, 1] = 255
150 | col += YG
151 |
152 | # GC
153 | colorwheel[col:col + GC, 1] = 255
154 | colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
155 | col += GC
156 |
157 | # CB
158 | colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
159 | colorwheel[col:col + CB, 2] = 255
160 | col += CB
161 |
162 | # BM
163 | colorwheel[col:col + BM, 2] = 255
164 | colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
165 |     col += BM
166 |
167 | # MR
168 | colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
169 | colorwheel[col:col + MR, 0] = 255
170 |
171 | return colorwheel
172 |
173 |
174 | def flow_to_png_middlebury(flow, maxnorm=None):
175 | """
176 | Convert flow into middlebury color code image
177 | :param flow: optical flow map
178 | :return: optical flow image in middlebury color
179 | """
180 |
181 | flow = flow.transpose([1, 2, 0])
182 | u = flow[:, :, 0]
183 | v = flow[:, :, 1]
184 |
185 | maxu = -999.
186 | maxv = -999.
187 | minu = 999.
188 | minv = 999.
189 |
190 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
191 | u[idxUnknow] = 0
192 | v[idxUnknow] = 0
193 |
194 | maxu = max(maxu, np.max(u))
195 | minu = min(minu, np.min(u))
196 |
197 | maxv = max(maxv, np.max(v))
198 | minv = min(minv, np.min(v))
199 |
200 | if maxnorm is None:
201 | rad = np.sqrt(u ** 2 + v ** 2)
202 | maxrad = max(-1, np.max(rad))
203 | else:
204 | maxrad = maxnorm
205 |
206 | u = u / (maxrad + np.finfo(float).eps)
207 | v = v / (maxrad + np.finfo(float).eps)
208 |
209 | img = compute_color(u, v)
210 |
211 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
212 | img[idx] = 0
213 |
214 | return np.uint8(img)
215 |
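# Editorial sketch (not part of the original file): render the color coding on
# a synthetic flow field; direction maps to hue, magnitude to saturation.
def _example_color_wheel_image(size=128):
    ys, xs = np.mgrid[-1:1:complex(0, size), -1:1:complex(0, size)]
    flow = np.stack([xs, ys]).astype(np.float32)  # [2, H, W], as expected above
    return flow_to_png_middlebury(flow)           # uint8 [H, W, 3]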
--------------------------------------------------------------------------------
/utils/interpolation.py:
--------------------------------------------------------------------------------
1 | ## Portions of this code copyright 2018 Jochen Gast
2 |
3 | from __future__ import absolute_import, division, print_function
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as tf
8 |
9 |
10 | def _bchw2bhwc(tensor):
11 | return tensor.transpose(1,2).transpose(2,3)
12 |
13 |
14 | def _bhwc2bchw(tensor):
15 | return tensor.transpose(2,3).transpose(1,2)
16 |
17 |
18 | class Meshgrid(nn.Module):
19 | def __init__(self):
20 | super(Meshgrid, self).__init__()
21 | self.width = 0
22 | self.height = 0
23 | self.register_buffer("xx", torch.zeros(1,1))
24 | self.register_buffer("yy", torch.zeros(1,1))
25 | self.register_buffer("rangex", torch.zeros(1,1))
26 | self.register_buffer("rangey", torch.zeros(1,1))
27 |
28 | def _compute_meshgrid(self, width, height):
29 | torch.arange(0, width, out=self.rangex)
30 | torch.arange(0, height, out=self.rangey)
31 | self.xx = self.rangex.repeat(height, 1).contiguous()
32 | self.yy = self.rangey.repeat(width, 1).t().contiguous()
33 |
34 | def forward(self, width, height):
35 | if self.width != width or self.height != height:
36 | self._compute_meshgrid(width=width, height=height)
37 | self.width = width
38 | self.height = height
39 | return self.xx, self.yy
40 |
41 |
42 | class BatchSub2Ind(nn.Module):
43 | def __init__(self):
44 | super(BatchSub2Ind, self).__init__()
45 | self.register_buffer("_offsets", torch.LongTensor())
46 |
47 | def forward(self, shape, row_sub, col_sub, out=None):
48 | batch_size = row_sub.size(0)
49 | height, width = shape
50 | ind = row_sub*width + col_sub
51 | torch.arange(batch_size, out=self._offsets)
52 | self._offsets *= (height*width)
53 |
54 | if out is None:
55 | return torch.add(ind, self._offsets.view(-1,1,1))
56 | else:
57 | torch.add(ind, self._offsets.view(-1,1,1), out=out)
58 |
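# Editorial note (not part of the original file): BatchSub2Ind is MATLAB-style
# sub2ind with a per-sample batch offset, index = b*H*W + row*W + col, which
# lets a flattened [B*H*W, C] view be gathered with index_select as below.
def _example_batch_sub2ind():
    sub2ind = BatchSub2Ind()
    rows = torch.tensor([[[1]], [[0]]])  # [B=2, 1, 1]
    cols = torch.tensor([[[2]], [[3]]])
    return sub2ind([4, 5], rows, cols)   # [[[1*5+2]], [[20+0*5+3]]] = [[[7]], [[23]]]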
59 |
60 | class Interp2(nn.Module):
61 | def __init__(self, clamp=False):
62 | super(Interp2, self).__init__()
63 | self._clamp = clamp
64 | self._batch_sub2ind = BatchSub2Ind()
65 | self.register_buffer("_x0", torch.LongTensor())
66 | self.register_buffer("_x1", torch.LongTensor())
67 | self.register_buffer("_y0", torch.LongTensor())
68 | self.register_buffer("_y1", torch.LongTensor())
69 | self.register_buffer("_i00", torch.LongTensor())
70 | self.register_buffer("_i01", torch.LongTensor())
71 | self.register_buffer("_i10", torch.LongTensor())
72 | self.register_buffer("_i11", torch.LongTensor())
73 | self.register_buffer("_v00", torch.FloatTensor())
74 | self.register_buffer("_v01", torch.FloatTensor())
75 | self.register_buffer("_v10", torch.FloatTensor())
76 | self.register_buffer("_v11", torch.FloatTensor())
77 | self.register_buffer("_x", torch.FloatTensor())
78 | self.register_buffer("_y", torch.FloatTensor())
79 |
80 | def forward(self, v, xq, yq):
81 | batch_size, channels, height, width = v.size()
82 |
83 | # clamp if wanted
84 | if self._clamp:
85 | xq.clamp_(0, width - 1)
86 | yq.clamp_(0, height - 1)
87 |
88 | # ------------------------------------------------------------------
89 | # Find neighbors
90 | #
91 | # x0 = torch.floor(xq).long(), x0.clamp_(0, width - 1)
92 | # x1 = x0 + 1, x1.clamp_(0, width - 1)
93 | # y0 = torch.floor(yq).long(), y0.clamp_(0, height - 1)
94 | # y1 = y0 + 1, y1.clamp_(0, height - 1)
95 | #
96 | # ------------------------------------------------------------------
97 | self._x0 = torch.floor(xq).long().clamp(0, width - 1)
98 | self._y0 = torch.floor(yq).long().clamp(0, height - 1)
99 |
100 | self._x1 = torch.add(self._x0, 1).clamp(0, width - 1)
101 | self._y1 = torch.add(self._y0, 1).clamp(0, height - 1)
102 |
103 | # batch_sub2ind
104 | self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00)
105 | self._batch_sub2ind([height, width], self._y0, self._x1, out=self._i01)
106 | self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10)
107 | self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11)
108 |
109 | # reshape
110 | v_flat = _bchw2bhwc(v).contiguous().view(-1, channels)
111 | torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00)
112 | torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01)
113 | torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10)
114 | torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11)
115 |
116 | # local_coords
117 | torch.add(xq, - self._x0.float(), out=self._x)
118 | torch.add(yq, - self._y0.float(), out=self._y)
119 |
120 | # weights
121 | w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1)
122 | w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1)
123 | w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1)
124 | w11 = torch.unsqueeze(self._y * self._x, dim=1)
125 |
126 | def _reshape(u):
127 | return _bhwc2bchw(u.view(batch_size, height, width, channels))
128 |
129 | # values
130 | values = _reshape(self._v00)*w00 + _reshape(self._v01)*w01 \
131 | + _reshape(self._v10)*w10 + _reshape(self._v11)*w11
132 |
133 | if self._clamp:
134 | return values
135 | else:
136 | # find_invalid
137 | invalid = ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height)).unsqueeze(dim=1).float()
138 | # maskout invalid
139 | transformed = invalid * torch.zeros_like(values) + (1.0 - invalid)*values
140 |
141 | return transformed
142 |
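# Editorial sketch (not part of the original file): querying Interp2 at the
# integer meshgrid is the identity warp, a quick sanity check; the query grids
# are float [B, H, W] pixel coordinates.
def _example_identity_warp():
    v = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).view(2, 3, 4, 5)
    xx, yy = Meshgrid()(width=5, height=4)             # [H, W] pixel coordinates
    xq = xx.unsqueeze(0).expand(2, -1, -1).contiguous()
    yq = yy.unsqueeze(0).expand(2, -1, -1).contiguous()
    return torch.allclose(Interp2()(v, xq, yq), v)     # True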
143 |
144 | class Interp2MaskBinary(nn.Module):
145 | def __init__(self, clamp=False):
146 | super(Interp2MaskBinary, self).__init__()
147 | self._clamp = clamp
148 | self._batch_sub2ind = BatchSub2Ind()
149 | self.register_buffer("_x0", torch.LongTensor())
150 | self.register_buffer("_x1", torch.LongTensor())
151 | self.register_buffer("_y0", torch.LongTensor())
152 | self.register_buffer("_y1", torch.LongTensor())
153 | self.register_buffer("_i00", torch.LongTensor())
154 | self.register_buffer("_i01", torch.LongTensor())
155 | self.register_buffer("_i10", torch.LongTensor())
156 | self.register_buffer("_i11", torch.LongTensor())
157 | self.register_buffer("_v00", torch.FloatTensor())
158 | self.register_buffer("_v01", torch.FloatTensor())
159 | self.register_buffer("_v10", torch.FloatTensor())
160 | self.register_buffer("_v11", torch.FloatTensor())
161 | self.register_buffer("_m00", torch.FloatTensor())
162 | self.register_buffer("_m01", torch.FloatTensor())
163 | self.register_buffer("_m10", torch.FloatTensor())
164 | self.register_buffer("_m11", torch.FloatTensor())
165 | self.register_buffer("_x", torch.FloatTensor())
166 | self.register_buffer("_y", torch.FloatTensor())
167 |
168 | def forward(self, v, xq, yq, mask):
169 | batch_size, channels, height, width = v.size()
170 | _, channels_mask, _, _ = mask.size()
171 |
172 | if channels_mask != channels:
173 | mask = mask.repeat(1, int(channels/channels_mask), 1, 1)
174 |
175 | # clamp if wanted
176 | if self._clamp:
177 | xq.clamp_(0, width - 1)
178 | yq.clamp_(0, height - 1)
179 |
180 | # ------------------------------------------------------------------
181 | # Find neighbors
182 | #
183 | # x0 = torch.floor(xq).long(), x0.clamp_(0, width - 1)
184 | # x1 = x0 + 1, x1.clamp_(0, width - 1)
185 | # y0 = torch.floor(yq).long(), y0.clamp_(0, height - 1)
186 | # y1 = y0 + 1, y1.clamp_(0, height - 1)
187 | #
188 | # ------------------------------------------------------------------
189 | self._x0 = torch.floor(xq).long().clamp(0, width - 1)
190 | self._y0 = torch.floor(yq).long().clamp(0, height - 1)
191 |
192 | self._x1 = torch.add(self._x0, 1).clamp(0, width - 1)
193 | self._y1 = torch.add(self._y0, 1).clamp(0, height - 1)
194 |
195 | # batch_sub2ind
196 | self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00)
197 | self._batch_sub2ind([height, width], self._y0, self._x1, out=self._i01)
198 | self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10)
199 | self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11)
200 |
201 | # reshape
202 | v_flat = _bchw2bhwc(v).contiguous().view(-1, channels)
203 | torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00)
204 | torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01)
205 | torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10)
206 | torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11)
207 |
208 | # reshape
209 | m_flat = _bchw2bhwc(mask).contiguous().view(-1, channels)
210 | torch.index_select(m_flat, dim=0, index=self._i00.view(-1), out=self._m00)
211 | torch.index_select(m_flat, dim=0, index=self._i01.view(-1), out=self._m01)
212 | torch.index_select(m_flat, dim=0, index=self._i10.view(-1), out=self._m10)
213 | torch.index_select(m_flat, dim=0, index=self._i11.view(-1), out=self._m11)
214 |
215 | # local_coords
216 | torch.add(xq, - self._x0.float(), out=self._x)
217 | torch.add(yq, - self._y0.float(), out=self._y)
218 |
219 | # weights
220 | w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1)
221 | w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1)
222 | w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1)
223 | w11 = torch.unsqueeze(self._y * self._x, dim=1)
224 |
225 | def _reshape(u):
226 | return _bhwc2bchw(u.view(batch_size, height, width, channels))
227 |
228 | # values
229 | values = _reshape(self._m00) * _reshape(self._v00) * w00 + _reshape(self._m01) * _reshape(
230 | self._v01) * w01 + _reshape(self._m10) * _reshape(self._v10) * w10 + _reshape(self._m11) * _reshape(
231 | self._v11) * w11
232 | m_weights = _reshape(self._m00) * w00 + _reshape(self._m01) * w01 + _reshape(self._m10) * w10 + _reshape(
233 | self._m11) * w11
234 | values = values / (m_weights + 1e-12)
235 | invalid_mask = (((1 - m_weights) / (m_weights + 1e-12)) > 0.5)[:, 0:1, :, :]
236 |
237 | if self._clamp:
238 | return values
239 | else:
240 | # find_invalid
241 | invalid = ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height) | invalid_mask.squeeze(dim=1)).unsqueeze(dim=1).float()
242 | transformed = invalid * torch.zeros_like(values) + (1.0 - invalid) * values
243 |
244 |             return transformed, (~invalid_mask).float()  # '1 - bool_tensor' is unsupported in newer torch
245 |
246 |
247 | def resize2D(inputs, size_targets, mode="bilinear"):
248 | size_inputs = [inputs.size(2), inputs.size(3)]
249 |
250 |     if list(size_inputs) == list(size_targets):
251 |         return inputs  # nothing to do
252 |     elif any(t < i for t, i in zip(size_targets, size_inputs)):  # element-wise, not lexicographic
253 |         resized = tf.adaptive_avg_pool2d(inputs, size_targets)  # downscaling
254 |     else:
255 |         resized = tf.interpolate(inputs, size=size_targets, mode=mode)  # upsampling (tf.upsample is deprecated)
256 |
257 | # correct scaling
258 | return resized
259 |
260 |
261 | def resize2D_as(inputs, output_as, mode="bilinear"):
262 | size_targets = [output_as.size(2), output_as.size(3)]
263 | return resize2D(inputs, size_targets, mode=mode)
264 |
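# Editorial sketch (not part of the original file): resize2D_as matches one
# tensor's spatial size to another's, e.g. upsampling a coarse flow map to
# image resolution (the flow *values* must still be rescaled separately).
def _example_resize_to_match():
    coarse = torch.randn(1, 2, 12, 16)
    target = torch.randn(1, 3, 48, 64)
    return resize2D_as(coarse, target).shape  # torch.Size([1, 2, 48, 64])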
--------------------------------------------------------------------------------