├── LICENSE ├── README.md ├── __init__.py ├── augmentations.py ├── commandline.py ├── configuration.py ├── datasets ├── __init__.py ├── common.py ├── flyingThings3D.py ├── flyingThings3DMultiframe.py ├── flyingchairs.py ├── flyingchairsOcc.py ├── kitti_comb_multiframe.py ├── kitti_combined.py ├── sintel.py ├── sintel_multiframe.py └── transforms.py ├── inference.py ├── install.sh ├── logger.py ├── losses.py ├── main.py ├── models ├── IRR_PWC.py ├── IRR_PWC_occ_joint.py ├── STAR.py ├── __init__.py ├── correlation_package │ ├── __init__.py │ ├── correlation.py │ ├── correlation_cuda.cc │ ├── correlation_cuda_kernel.cu │ ├── correlation_cuda_kernel.cuh │ └── setup.py ├── irr_modules.py ├── pwc_modules.py ├── pwcnet.py ├── pwcnet_irr.py ├── pwcnet_irr_occ_joint.py ├── pwcnet_occ_joint.py ├── tr_features.py └── tr_flow.py ├── optim └── __init__.py ├── results.png ├── runtime.py ├── saved_checkpoint ├── StarFlow_kitti │ └── checkpoint_latest.ckpt ├── StarFlow_sintel │ └── checkpoint_latest.ckpt └── StarFlow_things │ └── checkpoint_best.ckpt ├── scripts_train ├── train_starflow_chairsocc.sh ├── train_starflow_kitti_full.sh ├── train_starflow_sintel_full.sh └── train_starflow_things.sh ├── tools.py └── utils ├── __init__.py ├── flow.py └── interpolation.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # STaRFlow 2 | 3 | 4 | 5 | This repository is the PyTorch implementation of STaRFlow, a recurrent convolutional neural network for multi-frame optical flow estimation. 
This algorithm is presented in our paper **STaRFlow: A SpatioTemporal Recurrent Cell for Lightweight Multi-Frame Optical Flow Estimation**, by Pierre Godet, [Alexandre Boulch](https://github.com/aboulch), [Aurélien Plyer](https://github.com/aplyer), and Guy Le Besnerais. 6 | [[Preprint]](https://arxiv.org/pdf/2007.05481.pdf) 7 | 8 | 9 | Please cite our paper if you find our work useful. 10 | 11 | @article{godet2020starflow, 12 | title={STaRFlow: A SpatioTemporal Recurrent Cell for Lightweight Multi-Frame Optical Flow Estimation}, 13 | author={Godet, Pierre and Boulch, Alexandre and Plyer, Aur{\'e}lien and Le Besnerais, Guy}, 14 | journal={arXiv preprint arXiv:2007.05481}, 15 | year={2020} 16 | } 17 | 18 | Contact: pierre.godet@onera.fr 19 | 20 | ## Getting started 21 | This code has been developed and tested with Anaconda (Python 3.7, scipy 1.1, numpy 1.16), PyTorch 1.1, and CUDA 10.1 on Ubuntu 18.04. 22 | 23 | 1. Please install the following: 24 | 25 | - Anaconda (Python 3.7) 26 | - __PyTorch 1.1__ (Linux, Conda, Python 3.7, CUDA 10) (`conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch`) 27 | - Depending on your system, configure `-gencode`, `-ccbin`, and `cuda-path` in `models/correlation_package/setup.py` accordingly 28 | - scipy 1.1 (`conda install scipy=1.1`) 29 | - colorama (`conda install colorama`) 30 | - tqdm 4.32 (`conda install -c conda-forge tqdm=4.32`) 31 | - pypng (`pip install pypng`) 32 | 33 | 2. Then, install the correlation package: 34 | ``` 35 | ./install.sh 36 | ``` 37 | 38 | 39 | ## Pretrained Models 40 | 41 | The `saved_checkpoint` folder contains the pre-trained models of STaRFlow trained on 42 | 43 | 1. FlyingChairsOcc -> FlyingThings3D, or 44 | 2. FlyingChairsOcc -> FlyingThings3D -> MPI Sintel, or 45 | 3. FlyingChairsOcc -> FlyingThings3D -> KITTI (2012 and 2015). 46 | 47 | 48 | ## Inference 49 | 50 | The script `inference.py` can be used for testing the pre-trained models. Example: 51 | 52 | python inference.py \ 53 | --model StarFlow \ 54 | --checkpoint saved_checkpoint/StarFlow_things/checkpoint_best.ckpt \ 55 | --data-root /data/mpisintelcomplete/training/final/ambush_6/ \ 56 | --file-list frame_0004.png frame_0005.png frame_0006.png frame_0007.png 57 | 58 | By default, it saves the results in `./output/`. 59 | 60 | 61 | ## Training 62 | 63 | Data-loaders for multi-frame training can be found in the `datasets` folder, multi-frame losses are in `losses.py`, and every architecture used in the experiments presented in our paper is available in the `models` folder. 64 | 65 | ### Datasets 66 | 67 | The datasets used for this project are the following: 68 | 69 | - [FlyingChairsOcc dataset](https://github.com/visinf/irr/tree/master/flyingchairsocc) 70 | - [MPI Sintel Dataset](http://sintel.is.tue.mpg.de/downloads) 71 | - [KITTI Optical Flow 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) and [KITTI Optical Flow 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=flow) 72 | - [FlyingThings3D subset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 73 | 74 | 75 | ### Scripts for training 76 | 77 | The `scripts_train` folder contains training scripts for STaRFlow. 78 | To train a model, simply run the corresponding script file, e.g., `./train_starflow_chairsocc.sh`. 79 | In the script files, please configure your own experiment directory (EXPERIMENTS_HOME) and the dataset directories on your local system (e.g., SINTEL_HOME or KITTI_HOME). Two usage sketches for the data-loaders and I/O helpers are given below.
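The data-loaders are standard PyTorch `Dataset` classes, so they can also be used on their own, outside the provided scripts. Below is a minimal sketch, assuming a FlyingThings3D subset at a placeholder path; the dataset classes in `datasets/` only store the `args` object they receive, so an empty namespace is enough for simply loading samples:

```python
from argparse import Namespace
from torch.utils.data import DataLoader
import datasets

root = "/data/flyingthings3d_subset"  # placeholder path

# 5-frame samples from the clean-pass training split
train_set = datasets.FlyingThings3dMultiframeCleanTrain(
    Namespace(), root=root, nframes=5, photometric_augmentations=True)

loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)
batch = next(iter(loader))

# "input_images" holds nframes image tensors; "target_flows" and
# "target_occs" hold the nframes-1 ground-truth flows and occlusion maps.
print(len(batch["input_images"]), batch["input_images"][0].shape)
```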
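Likewise, the helpers in `datasets/common.py` can be reused directly, e.g., to read `.flo` files and compute the average end-point error (EPE) between a predicted flow and the ground truth. A short sketch (both file names are placeholders, and `frame_0004_pred.flo` is a hypothetical prediction you would have saved yourself):

```python
import numpy as np
from datasets import common

gt = common.read_flo_as_float32("frame_0004.flo")         # (H, W, 2) numpy array
pred = common.read_flo_as_float32("frame_0004_pred.flo")  # hypothetical prediction

# average end-point error: mean Euclidean distance between flow vectors
epe = np.sqrt(((pred - gt) ** 2).sum(axis=-1)).mean()
print("EPE: %.3f" % epe)

# the data-loaders convert flows to CHW float tensors via numpy2torch
flo_tensor = common.numpy2torch(gt)  # torch.FloatTensor of shape (2, H, W)
```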
80 | 81 | 82 | ## Acknowledgement 83 | 84 | This repository is a fork of the [IRR-PWC](https://github.com/visinf/irr) implementation from Junhwa Hur and Stefan Roth. 85 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/__init__.py -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import flyingchairs 2 | from . import flyingchairsOcc 3 | from . import flyingThings3D 4 | from . import kitti_combined 5 | from . import sintel 6 | 7 | from . import flyingThings3DMultiframe 8 | from . import sintel_multiframe 9 | from . import kitti_comb_multiframe 10 | 11 | 12 | ## FlyingChairs 13 | FlyingChairsTrain = flyingchairs.FlyingChairsTrain 14 | FlyingChairsValid = flyingchairs.FlyingChairsValid 15 | FlyingChairsFull = flyingchairs.FlyingChairsFull 16 | 17 | ## Our custom FlyingChairs + Occ 18 | FlyingChairsOccTrain = flyingchairsOcc.FlyingChairsOccTrain 19 | FlyingChairsOccValid = flyingchairsOcc.FlyingChairsOccValid 20 | FlyingChairsOccFull = flyingchairsOcc.FlyingChairsOccFull 21 | 22 | 23 | ## FlyingThings3D_subset 24 | FlyingThings3dFinalTrain = flyingThings3D.FlyingThings3dFinalTrain 25 | FlyingThings3dFinalTest = flyingThings3D.FlyingThings3dFinalTest 26 | FlyingThings3dCleanTrain = flyingThings3D.FlyingThings3dCleanTrain 27 | FlyingThings3dCleanTest = flyingThings3D.FlyingThings3dCleanTest 28 | 29 | 30 | ## Sintel 31 | SintelTestClean = sintel.SintelTestClean 32 | SintelTestFinal = sintel.SintelTestFinal 33 | 34 | SintelTrainingCombFull = sintel.SintelTrainingCombFull 35 | SintelTrainingCombTrain = sintel.SintelTrainingCombTrain 36 | SintelTrainingCombValid = sintel.SintelTrainingCombValid 37 | 38 | SintelTrainingCleanFull = sintel.SintelTrainingCleanFull 39 | SintelTrainingCleanTrain = sintel.SintelTrainingCleanTrain 40 | SintelTrainingCleanValid = sintel.SintelTrainingCleanValid 41 | 42 | SintelTrainingFinalFull = sintel.SintelTrainingFinalFull 43 | SintelTrainingFinalTrain = sintel.SintelTrainingFinalTrain 44 | SintelTrainingFinalValid = sintel.SintelTrainingFinalValid 45 | 46 | 47 | ## KITTI Optical Flow 2012 + 2015 48 | KittiCombTrain = kitti_combined.KittiCombTrain 49 | KittiCombVal = kitti_combined.KittiCombVal 50 | KittiCombFull = kitti_combined.KittiCombFull 51 | 52 | KittiComb2012Train = kitti_combined.KittiComb2012Train 53 | KittiComb2012Val = kitti_combined.KittiComb2012Val 54 | KittiComb2012Full = kitti_combined.KittiComb2012Full 55 | KittiComb2012Test = kitti_combined.KittiComb2012Test 56 | 57 | KittiComb2015Train = kitti_combined.KittiComb2015Train 58 | KittiComb2015Val = kitti_combined.KittiComb2015Val 59 | KittiComb2015Full = kitti_combined.KittiComb2015Full 60 | KittiComb2015Test = kitti_combined.KittiComb2015Test 61 | 62 | 63 | ## FlyingThings3D_subset_Multiframe 64 | FlyingThings3dMultiframeCleanTrain = flyingThings3DMultiframe.FlyingThings3dMultiframeCleanTrain 65 | FlyingThings3dMultiframeCleanTest = flyingThings3DMultiframe.FlyingThings3dMultiframeCleanTest 66 | 67 | 68 | ## SintelMultiframe 69 | SintelMultiframeTrainingCombFull = sintel_multiframe.SintelMultiframeTrainingCombFull 70 | SintelMultiframeTrainingCleanFull = sintel_multiframe.SintelMultiframeTrainingCleanFull 71 | SintelMultiframeTrainingFinalFull = 
sintel_multiframe.SintelMultiframeTrainingFinalFull 72 | 73 | SintelMultiframeTrainingCombValid = sintel_multiframe.SintelMultiframeTrainingCombValid 74 | SintelMultiframeTrainingCleanValid = sintel_multiframe.SintelMultiframeTrainingCleanValid 75 | SintelMultiframeTrainingFinalValid = sintel_multiframe.SintelMultiframeTrainingFinalValid 76 | 77 | SintelMultiframeTrainingCombTrain = sintel_multiframe.SintelMultiframeTrainingCombTrain 78 | SintelMultiframeTrainingCleanTrain = sintel_multiframe.SintelMultiframeTrainingCleanTrain 79 | SintelMultiframeTrainingFinalTrain = sintel_multiframe.SintelMultiframeTrainingFinalTrain 80 | 81 | SintelMultiframeTestFinal = sintel_multiframe.SintelMultiframeTestFinal 82 | SintelMultiframeTestClean = sintel_multiframe.SintelMultiframeTestClean 83 | 84 | 85 | ## KITTI Optical Flow 2012 + 2015 MULTIFRAME 86 | KittiMultiframeCombTrain = kitti_comb_multiframe.KittiMultiframeCombTrain 87 | KittiMultiframeCombVal = kitti_comb_multiframe.KittiMultiframeCombVal 88 | KittiMultiframeCombFull = kitti_comb_multiframe.KittiMultiframeCombFull 89 | 90 | KittiMultiframeComb2012Train = kitti_comb_multiframe.KittiMultiframeComb2012Train 91 | KittiMultiframeComb2012Val = kitti_comb_multiframe.KittiMultiframeComb2012Val 92 | KittiMultiframeComb2012Full = kitti_comb_multiframe.KittiMultiframeComb2012Full 93 | KittiMultiframeComb2012Test = kitti_comb_multiframe.KittiMultiframeComb2012Test 94 | 95 | KittiMultiframeComb2015Train = kitti_comb_multiframe.KittiMultiframeComb2015Train 96 | KittiMultiframeComb2015Val = kitti_comb_multiframe.KittiMultiframeComb2015Val 97 | KittiMultiframeComb2015Full = kitti_comb_multiframe.KittiMultiframeComb2015Full 98 | KittiMultiframeComb2015Test = kitti_comb_multiframe.KittiMultiframeComb2015Test 99 | -------------------------------------------------------------------------------- /datasets/common.py: -------------------------------------------------------------------------------- 1 | ## Portions of Code from, copyright 2018 Jochen Gast 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import torch 6 | import numpy as np 7 | from scipy import ndimage 8 | 9 | import png 10 | 11 | 12 | def numpy2torch(array): 13 | assert(isinstance(array, np.ndarray)) 14 | if array.ndim == 3: 15 | array = np.transpose(array, (2, 0, 1)) 16 | else: 17 | array = np.expand_dims(array, axis=0) 18 | return torch.from_numpy(array.copy()).float() 19 | 20 | 21 | def read_flo_as_float32(filename): 22 | with open(filename, 'rb') as file: 23 | magic = np.fromfile(file, np.float32, count=1) 24 | assert(202021.25 == magic), "Magic number incorrect. 
Invalid .flo file" 25 | w = np.fromfile(file, np.int32, count=1)[0] 26 | h = np.fromfile(file, np.int32, count=1)[0] 27 | data = np.fromfile(file, np.float32, count=2*h*w) 28 | data2D = np.resize(data, (h, w, 2)) 29 | return data2D 30 | 31 | 32 | def read_occ_image_as_float32(filename): 33 | occ = ndimage.imread(filename).astype(np.float32) / np.float32(255.0) 34 | if occ.ndim == 3: 35 | occ = occ[:, :, 0] 36 | return occ 37 | 38 | 39 | def read_image_as_float32(filename): 40 | return ndimage.imread(filename).astype(np.float32) / np.float32(255.0) 41 | 42 | 43 | def read_image_as_byte(filename): 44 | return ndimage.imread(filename) 45 | 46 | 47 | def read_flopng_as_float32(filename): 48 | """ 49 | Read from KITTI .png file 50 | :param flow_file: name of the flow file 51 | :return: optical flow data in matrix 52 | """ 53 | flow_object = png.Reader(filename=filename) 54 | flow_direct = flow_object.asDirect() 55 | flow_data = list(flow_direct[2]) 56 | (w, h) = flow_direct[3]['size'] 57 | #print("Reading %d x %d flow file in .png format" % (h, w)) 58 | flow = np.zeros((h, w, 3), dtype=np.float32) 59 | for i in range(len(flow_data)): 60 | flow[i, :, 0] = flow_data[i][0::3] 61 | flow[i, :, 1] = flow_data[i][1::3] 62 | flow[i, :, 2] = flow_data[i][2::3] 63 | 64 | invalid_idx = (flow[:, :, 2] == 0) 65 | flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0 66 | flow[invalid_idx, 0] = 0 67 | flow[invalid_idx, 1] = 0 68 | return flow[:, :, :2] -------------------------------------------------------------------------------- /datasets/flyingThings3D.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import torch.utils.data as data 5 | from glob import glob 6 | 7 | from torchvision import transforms as vision_transforms 8 | 9 | from . import transforms 10 | from . 
import common 11 | 12 | import numpy as np 13 | 14 | 15 | def fillingInNaN(flow): 16 | h, w, c = flow.shape 17 | indices = np.argwhere(np.isnan(flow)) 18 | neighbors = [[-1, 0], [1, 0], [0, -1], [0, 1]] 19 | for ii, idx in enumerate(indices): 20 | sum_sample = 0 21 | count = 0 22 | for jj in range(len(neighbors)): 23 | hh = idx[0] + neighbors[jj][0] 24 | ww = idx[1] + neighbors[jj][1] 25 | if hh < 0 or hh >= h: 26 | continue 27 | if ww < 0 or ww >= w: 28 | continue 29 | sample_flow = flow[hh, ww, idx[2]] 30 | if np.isnan(sample_flow): 31 | continue 32 | sum_sample += sample_flow 33 | count += 1 34 | if count == 0: 35 | print('FATAL ERROR: no sample') 36 | flow[idx[0], idx[1], idx[2]] = sum_sample / count 37 | 38 | return flow 39 | 40 | 41 | class FlyingThings3d(data.Dataset): 42 | def __init__(self, 43 | args, 44 | images_root, 45 | flow_root, 46 | occ_root, 47 | photometric_augmentations=False, 48 | backward=False): 49 | 50 | self._args = args 51 | self.backward = backward 52 | 53 | if not os.path.isdir(images_root): 54 | raise ValueError("Image directory '%s' not found!" % images_root) 55 | if flow_root is not None and not os.path.isdir(flow_root): 56 | raise ValueError("Flow directory '%s' not found!" % flow_root) 57 | if occ_root is not None and not os.path.isdir(occ_root): 58 | raise ValueError("Occ directory '%s' not found!" % occ_root) 59 | 60 | if flow_root is not None: 61 | flow_f_filenames = sorted(glob(os.path.join(flow_root, "into_future/*.flo"))) 62 | flow_b_filenames = sorted(glob(os.path.join(flow_root, "into_past/*.flo"))) 63 | 64 | if occ_root is not None: 65 | occ1_filenames = sorted(glob(os.path.join(occ_root, "into_future/*.png"))) 66 | occ2_filenames = sorted(glob(os.path.join(occ_root, "into_past/*.png"))) 67 | 68 | all_img_filenames = sorted(glob(os.path.join(images_root, "*.png"))) 69 | 70 | self._image_list = [] 71 | self._flow_list = [] if flow_root is not None else None 72 | self._occ_list = [] if occ_root is not None else None 73 | 74 | assert len(all_img_filenames) != 0 75 | assert len(flow_f_filenames) != 0 76 | assert len(flow_b_filenames) != 0 77 | assert len(occ1_filenames) != 0 78 | assert len(occ2_filenames) != 0 79 | 80 | ## path definition 81 | path_flow_f = os.path.join(flow_root, "into_future") 82 | path_flow_b = os.path.join(flow_root, "into_past") 83 | path_occ_f = os.path.join(occ_root, "into_future") 84 | path_occ_b = os.path.join(occ_root, "into_past") 85 | 86 | # ---------------------------------------------------------- 87 | # Save list of actual filenames for inputs and flows 88 | # ---------------------------------------------------------- 89 | 90 | for ii in range(0, len(flow_f_filenames)): 91 | 92 | flo_f = flow_f_filenames[ii] 93 | 94 | idx_f = os.path.splitext(os.path.basename(flo_f))[0] 95 | idx_b = str(int(idx_f) + 1).zfill(7) 96 | 97 | flo_b = os.path.join(path_flow_b, idx_b + ".flo") 98 | 99 | im1 = os.path.join(images_root, idx_f + ".png") 100 | im2 = os.path.join(images_root, idx_b + ".png") 101 | occ1 = os.path.join(path_occ_f, idx_f + ".png") 102 | occ2 = os.path.join(path_occ_b, idx_b + ".png") 103 | 104 | if not os.path.isfile(flo_f) or not os.path.isfile(flo_b) or not os.path.isfile(im1) or not os.path.isfile( 105 | im2) or not os.path.isfile(occ1) or not os.path.isfile(occ2): 106 | continue 107 | 108 | self._image_list += [[im1, im2]] 109 | self._flow_list += [[flo_f, flo_b]] 110 | self._occ_list += [[occ1, occ2]] 111 | 112 | self._size = len(self._image_list) 113 | 114 | assert len(self._image_list) == len(self._flow_list) 115 | assert len(self._occ_list) == len(self._flow_list) 116 | assert len(self._image_list) != 0 117 | 118 | # ---------------------------------------------------------- 119 | # photometric_augmentations 120 | # ---------------------------------------------------------- 121 | if photometric_augmentations: 122 | self._photometric_transform = transforms.ConcatTransformSplitChainer([ 123 | # uint8 -> PIL 124 | vision_transforms.ToPILImage(), 125 | # PIL -> PIL : random hsv and contrast 126 | vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), 127 | # PIL -> FloatTensor 128 | vision_transforms.transforms.ToTensor(), 129 | transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True), 130 | ], from_numpy=True, to_numpy=False) 131 | 132 | else: 133 | self._photometric_transform = transforms.ConcatTransformSplitChainer([ 134 | # uint8 -> FloatTensor 135 | vision_transforms.transforms.ToTensor(), 136 | ], from_numpy=True, to_numpy=False) 137 | 138 | def __getitem__(self, index): 139 | index = index % self._size 140 | 141 | im1_filename = self._image_list[index][0] 142 | im2_filename = self._image_list[index][1] 143 | flo_f_filename = self._flow_list[index][0] 144 | flo_b_filename = self._flow_list[index][1] 145 | occ1_filename = self._occ_list[index][0] 146 | occ2_filename = self._occ_list[index][1] 147 | 148 | # read float32 images and flow 149 | im1_np0 = common.read_image_as_byte(im1_filename) 150 | im2_np0 = common.read_image_as_byte(im2_filename) 151 | flo_f_np0 = common.read_flo_as_float32(flo_f_filename) 152 | flo_b_np0 = common.read_flo_as_float32(flo_b_filename) 153 | occ1_np0 = common.read_occ_image_as_float32(occ1_filename) 154 | occ2_np0 = common.read_occ_image_as_float32(occ2_filename) 155 | 156 | # temp - check isnan 157 | if np.any(np.isnan(flo_f_np0)): 158 | flo_f_np0 = fillingInNaN(flo_f_np0) 159 | 160 | if np.any(np.isnan(flo_b_np0)): 161 | flo_b_np0 = fillingInNaN(flo_b_np0) 162 | 163 | # possibly apply photometric transformations 164 | im1, im2 = self._photometric_transform(im1_np0, im2_np0) 165 | 166 | # convert flow to FloatTensor 167 | flo_f = common.numpy2torch(flo_f_np0) 168 | flo_b = common.numpy2torch(flo_b_np0) 169 | 170 | # convert occ to FloatTensor 171 | occ1 = common.numpy2torch(occ1_np0) 172 | occ2 = common.numpy2torch(occ2_np0) 173 | 174 | # example filename 175 | basename = os.path.basename(im1_filename)[:-4] 176 | 177 | if self.backward: 178 | # flows (and occs) are swapped: backward flow in slot 1 and forward flow in slot 2 179 | example_dict = { 180 | "input1": im1, 181 | "input2": im2, 182 | "target1": flo_b, 183 | "target2": flo_f, 184 | "target_occ1": occ2, 185 | "target_occ2": occ1, 186 | "index": index, 187 | "basename": basename 188 | } 189 | else: 190 | example_dict = { 191 | "input1": im1, 192 | "input2": im2, 193 | "target1": flo_f, 194 | "target2": flo_b, 195 | "target_occ1": occ1, 196 | "target_occ2": occ2, 197 | "index": index, 198 | "basename": basename 199 | } 200 | 201 | return example_dict 202 | 203 | def __len__(self): 204 | return self._size 205 | 206 | 207 | class FlyingThings3dFinalTrain(FlyingThings3d): 208 | def __init__(self, 209 | args, 210 | root, 211 | photometric_augmentations=True, 212 | backward=False): 213 | images_root = os.path.join(root, "frames_finalpass") 214 | flow_root = os.path.join(root, "optical_flow") 215 | occ_root = os.path.join(root, "occlusion") 216 | super(FlyingThings3dFinalTrain, self).__init__( 217 | args, 218 | images_root=images_root, 219 | flow_root=flow_root, 220 | occ_root=occ_root, 221 | 
photometric_augmentations=photometric_augmentations, 222 | backward=backward) 223 | 224 | 225 | class FlyingThings3dFinalTest(FlyingThings3d): 226 | def __init__(self, 227 | args, 228 | root, 229 | photometric_augmentations=False, 230 | backward=False): 231 | images_root = os.path.join(root, "frames_finalpass") 232 | flow_root = os.path.join(root, "optical_flow") 233 | occ_root = os.path.join(root, "occlusion") 234 | super(FlyingThings3dFinalTest, self).__init__( 235 | args, 236 | images_root=images_root, 237 | flow_root=flow_root, 238 | occ_root=occ_root, 239 | photometric_augmentations=photometric_augmentations, 240 | backward=backward) 241 | 242 | 243 | class FlyingThings3dCleanTrain(FlyingThings3d): 244 | def __init__(self, 245 | args, 246 | root, 247 | photometric_augmentations=True, 248 | backward=False): 249 | images_root = os.path.join(root, "train", "image_clean", "left") 250 | flow_root = os.path.join(root, "train", "flow", "left") 251 | occ_root = os.path.join(root, "train", "flow_occlusions", "left") 252 | super(FlyingThings3dCleanTrain, self).__init__( 253 | args, 254 | images_root=images_root, 255 | flow_root=flow_root, 256 | occ_root=occ_root, 257 | photometric_augmentations=photometric_augmentations, 258 | backward=backward) 259 | 260 | 261 | class FlyingThings3dCleanTest(FlyingThings3d): 262 | def __init__(self, 263 | args, 264 | root, 265 | photometric_augmentations=False, 266 | backward=False): 267 | images_root = os.path.join(root, "val", "image_clean", "left") 268 | flow_root = os.path.join(root, "val", "flow", "left") 269 | occ_root = os.path.join(root, "val", "flow_occlusions", "left") 270 | super(FlyingThings3dCleanTest, self).__init__( 271 | args, 272 | images_root=images_root, 273 | flow_root=flow_root, 274 | occ_root=occ_root, 275 | photometric_augmentations=photometric_augmentations, 276 | backward=backward) 277 | -------------------------------------------------------------------------------- /datasets/flyingThings3DMultiframe.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import torch 5 | import torch.utils.data as data 6 | from glob import glob 7 | 8 | import torch 9 | from torchvision import transforms as vision_transforms 10 | 11 | from . import transforms 12 | from . 
import common 13 | 14 | import numpy as np 15 | 16 | 17 | def fillingInNaN(flow): 18 | h, w, c = flow.shape 19 | indices = np.argwhere(np.isnan(flow)) 20 | neighbors = [[-1, 0], [1, 0], [0, -1], [0, 1]] 21 | for ii, idx in enumerate(indices): 22 | sum_sample = 0 23 | count = 0 24 | for jj in range(len(neighbors)): 25 | hh = idx[0] + neighbors[jj][0] 26 | ww = idx[1] + neighbors[jj][1] 27 | if hh < 0 or hh >= h: 28 | continue 29 | if ww < 0 or ww >= w: 30 | continue 31 | sample_flow = flow[hh, ww, idx[2]] 32 | if np.isnan(sample_flow): 33 | continue 34 | sum_sample += sample_flow 35 | count += 1 36 | if count == 0: 37 | print('FATAL ERROR: no sample') 38 | flow[idx[0], idx[1], idx[2]] = sum_sample / count 39 | 40 | return flow 41 | 42 | 43 | class FlyingThings3dMultiframe(data.Dataset): 44 | def __init__(self, 45 | args, 46 | images_root, 47 | flow_root, 48 | occ_root, 49 | seq_lengths_path, nframes=5, 50 | photometric_augmentations=False, 51 | backward=False): 52 | 53 | self._args = args 54 | self._nframes = nframes 55 | self.backward = backward 56 | 57 | if not os.path.isdir(images_root): 58 | raise ValueError("Image directory '%s' not found!" % images_root) 59 | if flow_root is not None and not os.path.isdir(flow_root): 60 | raise ValueError("Flow directory '%s' not found!" % flow_root) 61 | if occ_root is not None and not os.path.isdir(occ_root): 62 | raise ValueError("Occ directory '%s' not found!" % occ_root) 63 | 64 | if flow_root is not None: 65 | flow_f_filenames = sorted(glob(os.path.join(flow_root, "into_future/*.flo"))) 66 | flow_b_filenames = sorted(glob(os.path.join(flow_root, "into_past/*.flo"))) 67 | 68 | if occ_root is not None: 69 | occ1_filenames = sorted(glob(os.path.join(occ_root, "into_future/*.png"))) 70 | occ2_filenames = sorted(glob(os.path.join(occ_root, "into_past/*.png"))) 71 | 72 | all_img_filenames = sorted(glob(os.path.join(images_root, "*.png"))) 73 | 74 | self._image_list = [] 75 | self._flow_list = [] if flow_root is not None else None 76 | self._occ_list = [] if occ_root is not None else None 77 | 78 | assert len(all_img_filenames) != 0 79 | assert len(flow_f_filenames) != 0 80 | assert len(flow_b_filenames) != 0 81 | assert len(occ1_filenames) != 0 82 | assert len(occ2_filenames) != 0 83 | 84 | self._seq_lengths = np.load(seq_lengths_path) 85 | 86 | ## path definition 87 | path_flow_f = os.path.join(flow_root, "into_future") 88 | path_flow_b = os.path.join(flow_root, "into_past") 89 | path_occ_f = os.path.join(occ_root, "into_future") 90 | path_occ_b = os.path.join(occ_root, "into_past") 91 | 92 | # ---------------------------------------------------------- 93 | # Save list of actual filenames for inputs and flows 94 | # ---------------------------------------------------------- 95 | 96 | idx_first = 0 97 | 98 | for seq_len in self._seq_lengths: 99 | list_images = [] 100 | list_flows = [] 101 | list_occs = [] 102 | 103 | for ii in range(idx_first, idx_first + seq_len - 1): 104 | list_images.append(os.path.join(images_root, "{:07d}".format(ii) + ".png")) 105 | if self.backward: 106 | list_flows.append(os.path.join(path_flow_b, "{:07d}".format(ii+1) + ".flo")) 107 | list_occs.append(os.path.join(path_occ_b, "{:07d}".format(ii+1) + ".png")) 108 | else: 109 | list_flows.append(os.path.join(path_flow_f, "{:07d}".format(ii) + ".flo")) 110 | list_occs.append(os.path.join(path_occ_f, "{:07d}".format(ii) + ".png")) 111 | #if not os.path.isfile(flo_f) or not os.path.isfile(flo_b) or not os.path.isfile(im1) or not os.path.isfile( 112 | # im2) or not 
os.path.isfile(occ1) or not os.path.isfile(occ2): 113 | # continue 114 | list_images.append(os.path.join(images_root, "{:07d}".format(ii + 1) + ".png")) # ii + 1 = idx_first + seq_len - 1 115 | 116 | for i in range(len(list_images) - self._nframes + 1): 117 | 118 | imgs = list_images[i:i+self._nframes] 119 | flows = list_flows[i:i+self._nframes-1] 120 | occs = list_occs[i:i+self._nframes-1] 121 | 122 | self._image_list += [imgs] 123 | self._flow_list += [flows] 124 | self._occ_list += [occs] 125 | 126 | idx_first += seq_len 127 | 128 | self._size = len(self._image_list) 129 | 130 | assert len(self._image_list) == len(self._flow_list) 131 | assert len(self._occ_list) == len(self._flow_list) 132 | assert len(self._image_list) != 0 133 | 134 | 135 | # ---------------------------------------------------------- 136 | # photometric_augmentations 137 | # ---------------------------------------------------------- 138 | if photometric_augmentations: 139 | self._photometric_transform = transforms.ConcatTransformSplitChainer([ 140 | # uint8 -> PIL 141 | vision_transforms.ToPILImage(), 142 | # PIL -> PIL : random hsv and contrast 143 | vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), 144 | # PIL -> FloatTensor 145 | vision_transforms.transforms.ToTensor(), 146 | transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True), 147 | ], from_numpy=True, to_numpy=False) 148 | 149 | else: 150 | self._photometric_transform = transforms.ConcatTransformSplitChainer([ 151 | # uint8 -> FloatTensor 152 | vision_transforms.transforms.ToTensor(), 153 | ], from_numpy=True, to_numpy=False) 154 | 155 | def __getitem__(self, index): 156 | 157 | index = index % self._size 158 | 159 | imgs_filenames = self._image_list[index] 160 | flows_filenames = self._flow_list[index] 161 | occs_filenames = self._occ_list[index] 162 | 163 | # read float32 images and flow 164 | imgs_np0 = [common.read_image_as_byte(filename) for filename in imgs_filenames] 165 | flows_np0 = [common.read_flo_as_float32(filename) for filename in flows_filenames] 166 | occs_np0 = [common.read_occ_image_as_float32(filename) for filename in occs_filenames] 167 | 168 | # temp - check isnan 169 | for ii in range(len(flows_np0)): 170 | if np.any(np.isnan(flows_np0[ii])): 171 | flows_np0[ii] = fillingInNaN(flows_np0[ii]) 172 | 173 | # possibly apply photometric transformations 174 | imgs = self._photometric_transform(*imgs_np0) 175 | 176 | # convert flow to FloatTensor 177 | flows = [common.numpy2torch(flo_np0) for flo_np0 in flows_np0] 178 | 179 | # convert occ to FloatTensor 180 | occs = [common.numpy2torch(occ_np0) for occ_np0 in occs_np0] 181 | 182 | # example filename 183 | basename = [os.path.basename(filename)[:-4] for filename in imgs_filenames] 184 | 185 | example_dict = { 186 | "input1": imgs[0], 187 | "input_images": imgs, # "target_flows": torch.stack(flows, 0), 188 | "target1": flows[0], 189 | "target_flows": flows, #torch.stack(flows, 0) 190 | "target_occ1": occs[0], 191 | "target_occs": occs, #torch.stack(occs, 0) 192 | "index": index, 193 | "basename": basename, 194 | "nframes": self._nframes 195 | } 196 | 197 | return example_dict 198 | 199 | def __len__(self): 200 | return self._size 201 | 202 | 203 | # class FlyingThings3dFinalTrain(FlyingThings3d): 204 | # def __init__(self, 205 | # args, 206 | # root, 207 | # photometric_augmentations=True): 208 | # images_root = os.path.join(root, "frames_finalpass") 209 | # flow_root = os.path.join(root, "optical_flow") 210 | # occ_root = os.path.join(root, 
"occlusion") 211 | # seq_lengths_path = os.path.join(root, "seq_lengths.npy") 212 | # super(FlyingThings3dFinalTrain, self).__init__( 213 | # args, 214 | # images_root=images_root, 215 | # flow_root=flow_root, 216 | # occ_root=occ_root, 217 | # seq_lengths_path=seq_lengths_path, 218 | # photometric_augmentations=photometric_augmentations) 219 | 220 | 221 | # class FlyingThings3dFinalTest(FlyingThings3d): 222 | # def __init__(self, 223 | # args, 224 | # root, 225 | # photometric_augmentations=False): 226 | # images_root = os.path.join(root, "frames_finalpass") 227 | # flow_root = os.path.join(root, "optical_flow") 228 | # occ_root = os.path.join(root, "occlusion") 229 | # seq_lengths_path = os.path.join(root, "seq_lengths.npy") 230 | # super(FlyingThings3dFinalTest, self).__init__( 231 | # args, 232 | # images_root=images_root, 233 | # flow_root=flow_root, 234 | # occ_root=occ_root, 235 | # seq_lengths_path=seq_lengths_path, 236 | # photometric_augmentations=photometric_augmentations) 237 | 238 | 239 | class FlyingThings3dMultiframeCleanTrain(FlyingThings3dMultiframe): 240 | def __init__(self, 241 | args, 242 | root, 243 | nframes=5, 244 | photometric_augmentations=True, 245 | backward=False): 246 | images_root = os.path.join(root, "train", "image_clean", "left") 247 | flow_root = os.path.join(root, "train", "flow", "left") 248 | occ_root = os.path.join(root, "train", "flow_occlusions", "left") 249 | seq_lengths_path = os.path.join(root, "train", "seq_lengths.npy") 250 | super(FlyingThings3dMultiframeCleanTrain, self).__init__( 251 | args, 252 | images_root=images_root, 253 | flow_root=flow_root, 254 | occ_root=occ_root, 255 | seq_lengths_path=seq_lengths_path, 256 | photometric_augmentations=photometric_augmentations, 257 | nframes=nframes, backward=backward) 258 | 259 | 260 | class FlyingThings3dMultiframeCleanTest(FlyingThings3dMultiframe): 261 | def __init__(self, 262 | args, 263 | root, 264 | nframes=5, 265 | photometric_augmentations=False, 266 | backward=False): 267 | images_root = os.path.join(root, "val", "image_clean", "left") 268 | flow_root = os.path.join(root, "val", "flow", "left") 269 | occ_root = os.path.join(root, "val", "flow_occlusions", "left") 270 | seq_lengths_path = os.path.join(root, "val", "seq_lengths.npy") 271 | super(FlyingThings3dMultiframeCleanTest, self).__init__( 272 | args, 273 | images_root=images_root, 274 | flow_root=flow_root, 275 | occ_root=occ_root, 276 | seq_lengths_path=seq_lengths_path, 277 | photometric_augmentations=photometric_augmentations, 278 | nframes=nframes, backward=backward) 279 | -------------------------------------------------------------------------------- /datasets/flyingchairs.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import torch.utils.data as data 5 | from glob import glob 6 | 7 | from torchvision import transforms as vision_transforms 8 | 9 | from . import transforms 10 | from . 
import common 11 | 12 | 13 | VALIDATE_INDICES = [ 14 | 5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132, 15 | 152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336, 16 | 337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528, 17 | 531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786, 18 | 810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016, 19 | 1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173, 20 | 1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378, 21 | 1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699, 22 | 1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961, 23 | 1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216, 24 | 2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399, 25 | 2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624, 26 | 2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709, 27 | 2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966, 28 | 2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124, 29 | 3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272, 30 | 3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419, 31 | 3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555, 32 | 3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676, 33 | 3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782, 34 | 3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011, 35 | 4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246, 36 | 4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401, 37 | 4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578, 38 | 4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739, 39 | 4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904, 40 | 4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038, 41 | 5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227, 42 | 5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384, 43 | 5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566, 44 | 5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703, 45 | 5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915, 46 | 5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080, 47 | 6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220, 48 | 6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453, 49 | 6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580, 50 | 6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698, 51 | 6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810, 52 | 6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029, 53 | 7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282, 54 | 7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489, 55 | 7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722, 56 | 7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884, 57 | 7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050, 58 | 8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194, 59 | 8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407, 60 | 8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723, 61 | 8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918, 62 | 8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104, 63 | 9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270, 64 | 9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417, 65 | 9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529, 66 | 9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632, 67 | 9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 
9825, 68 | 9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019, 69 | 10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194, 70 | 10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276, 71 | 10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393, 72 | 10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503, 73 | 10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864, 74 | 12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178, 75 | 13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834, 76 | 15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166, 77 | 17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190, 78 | 21802, 21803, 21806, 22584, 22857, 22858, 22866] 79 | 80 | 81 | class FlyingChairs(data.Dataset): 82 | def __init__(self, 83 | args, 84 | root, 85 | photometric_augmentations=False, 86 | dstype="train"): 87 | 88 | self._args = args 89 | 90 | # ------------------------------------------------------------- 91 | # filenames for all input images and target flows 92 | # ------------------------------------------------------------- 93 | image_filenames = sorted( glob( os.path.join(root, "*.ppm")) ) 94 | flow_filenames = sorted( glob( os.path.join(root, "*.flo")) ) 95 | assert (len(image_filenames)/2 == len(flow_filenames)) 96 | num_flows = len(flow_filenames) 97 | 98 | # ------------------------------------------------------------- 99 | # Remove invalid validation indices 100 | # ------------------------------------------------------------- 101 | validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)] 102 | 103 | # ---------------------------------------------------------- 104 | # Construct list of indices for training/validation 105 | # ---------------------------------------------------------- 106 | list_of_indices = None 107 | if dstype == "train": 108 | list_of_indices = [x for x in range(num_flows) if x not in validate_indices] 109 | elif dstype == "valid": 110 | list_of_indices = validate_indices 111 | elif dstype == "full": 112 | list_of_indices = range(num_flows) 113 | else: 114 | raise ValueError("FlyingChairs: dstype '%s' unknown!", dstype) 115 | 116 | 117 | # ---------------------------------------------------------- 118 | # Save list of actual filenames for inputs and flows 119 | # ---------------------------------------------------------- 120 | self._image_list = [] 121 | self._flow_list = [] 122 | for i in list_of_indices: 123 | flo = flow_filenames[i] 124 | im1 = image_filenames[2*i] 125 | im2 = image_filenames[2*i + 1] 126 | self._image_list += [ [ im1, im2 ] ] 127 | self._flow_list += [ flo ] 128 | self._size = len(self._image_list) 129 | assert len(self._image_list) == len(self._flow_list) 130 | 131 | # ---------------------------------------------------------- 132 | # photometric_augmentations 133 | # ---------------------------------------------------------- 134 | if photometric_augmentations: 135 | self._photometric_transform = transforms.ConcatTransformSplitChainer([ 136 | # uint8 -> PIL 137 | vision_transforms.ToPILImage(), 138 | # PIL -> PIL : random hsv and contrast 139 | vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), 140 | # PIL -> FloatTensor 141 | vision_transforms.transforms.ToTensor(), 142 | transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True), 143 | ], from_numpy=True, to_numpy=False) 144 | else: 145 | self._photometric_transform = transforms.ConcatTransformSplitChainer([ 146 | # uint8 -> FloatTensor 147 | vision_transforms.transforms.ToTensor(), 148 | ], from_numpy=True, 
to_numpy=False) 149 | 150 | def __getitem__(self, index): 151 | index = index % self._size 152 | 153 | im1_filename = self._image_list[index][0] 154 | im2_filename = self._image_list[index][1] 155 | flo_filename = self._flow_list[index] 156 | 157 | # read float32 images and flow 158 | im1_np0 = common.read_image_as_byte(im1_filename) 159 | im2_np0 = common.read_image_as_byte(im2_filename) 160 | flo_np0 = common.read_flo_as_float32(flo_filename) 161 | 162 | # possibly apply photometric transformations 163 | im1, im2 = self._photometric_transform(im1_np0, im2_np0) 164 | 165 | # convert flow to FloatTensor 166 | flo = common.numpy2torch(flo_np0) 167 | 168 | # target_occ: initialized by zero (not used) 169 | target_occ = common.numpy2torch(common.read_occ_image_as_float32(im1_filename)) * 0 170 | 171 | # example filename 172 | basename = os.path.basename(im1_filename)[:5] 173 | 174 | example_dict = { 175 | "input1": im1, 176 | "input2": im2, 177 | "target1": flo, 178 | "target_occ1": target_occ, 179 | "index": index, 180 | "basename": basename 181 | } 182 | 183 | return example_dict 184 | 185 | def __len__(self): 186 | return self._size 187 | 188 | 189 | class FlyingChairsTrain(FlyingChairs): 190 | def __init__(self, 191 | args, 192 | root, 193 | photometric_augmentations=True): 194 | super(FlyingChairsTrain, self).__init__( 195 | args, 196 | root=root, 197 | photometric_augmentations=photometric_augmentations, 198 | dstype="train") 199 | 200 | 201 | class FlyingChairsValid(FlyingChairs): 202 | def __init__(self, 203 | args, 204 | root, 205 | photometric_augmentations=False): 206 | super(FlyingChairsValid, self).__init__( 207 | args, 208 | root=root, 209 | photometric_augmentations=photometric_augmentations, 210 | dstype="valid") 211 | 212 | 213 | class FlyingChairsFull(FlyingChairs): 214 | def __init__(self, 215 | args, 216 | root, 217 | photometric_augmentations=False): 218 | super(FlyingChairsFull, self).__init__( 219 | args, 220 | root=root, 221 | photometric_augmentations=photometric_augmentations, 222 | dstype="full") 223 | -------------------------------------------------------------------------------- /datasets/flyingchairsOcc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import torch 5 | import torch.utils.data as data 6 | from glob import glob 7 | 8 | from torchvision import transforms as vision_transforms 9 | 10 | from . import transforms 11 | from . 
import common 12 | 13 | 14 | VALIDATE_INDICES = [ 15 | 5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132, 16 | 152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336, 17 | 337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528, 18 | 531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786, 19 | 810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016, 20 | 1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173, 21 | 1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378, 22 | 1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699, 23 | 1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961, 24 | 1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216, 25 | 2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399, 26 | 2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624, 27 | 2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709, 28 | 2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966, 29 | 2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124, 30 | 3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272, 31 | 3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419, 32 | 3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555, 33 | 3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676, 34 | 3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782, 35 | 3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011, 36 | 4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246, 37 | 4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401, 38 | 4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578, 39 | 4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739, 40 | 4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904, 41 | 4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038, 42 | 5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227, 43 | 5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384, 44 | 5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566, 45 | 5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703, 46 | 5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915, 47 | 5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080, 48 | 6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220, 49 | 6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453, 50 | 6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580, 51 | 6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698, 52 | 6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810, 53 | 6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029, 54 | 7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282, 55 | 7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489, 56 | 7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722, 57 | 7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884, 58 | 7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050, 59 | 8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194, 60 | 8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407, 61 | 8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723, 62 | 8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918, 63 | 8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104, 64 | 9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270, 65 | 9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417, 66 | 9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529, 67 | 9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632, 68 | 9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 
9825, 69 | 9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019, 70 | 10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194, 71 | 10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276, 72 | 10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393, 73 | 10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503, 74 | 10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864, 75 | 12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178, 76 | 13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834, 77 | 15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166, 78 | 17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190, 79 | 21802, 21803, 21806, 22584, 22857, 22858, 22866] 80 | 81 | 82 | class FlyingChairsOcc(data.Dataset): 83 | def __init__(self, 84 | args, 85 | root, 86 | photometric_augmentations=False, 87 | dstype="train", backward=False): 88 | 89 | self._args = args 90 | self.backward = backward 91 | 92 | # ------------------------------------------------------------- 93 | # filenames for all input images and target flows 94 | # ------------------------------------------------------------- 95 | image1_filenames = sorted(glob(os.path.join(root, "*_img1.png"))) 96 | image2_filenames = sorted(glob(os.path.join(root, "*_img2.png"))) 97 | occ1_filenames = sorted(glob(os.path.join(root, "*_occ1.png"))) 98 | occ2_filenames = sorted(glob(os.path.join(root, "*_occ2.png"))) 99 | flow_f_filenames = sorted(glob(os.path.join(root, "*_flow.flo"))) 100 | flow_b_filenames = sorted(glob(os.path.join(root, "*_flow_b.flo"))) 101 | assert (len(image1_filenames) == len(image2_filenames)) 102 | assert (len(image2_filenames) == len(occ1_filenames)) 103 | assert (len(occ1_filenames) == len(occ2_filenames)) 104 | assert (len(occ2_filenames) == len(flow_f_filenames)) 105 | assert (len(flow_f_filenames) == len(flow_b_filenames)) 106 | 107 | num_flows = len(flow_f_filenames) 108 | 109 | # ------------------------------------------------------------- 110 | # Keep only the validation indices that exist in this dataset 111 | # ------------------------------------------------------------- 112 | validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)] 113 | 114 | # ---------------------------------------------------------- 115 | # Construct list of indices for training/validation 116 | # ---------------------------------------------------------- 117 | list_of_indices = None 118 | if dstype == "train": 119 | list_of_indices = [x for x in range(num_flows) if x not in validate_indices] 120 | elif dstype == "valid": 121 | list_of_indices = validate_indices 122 | elif dstype == "full": 123 | list_of_indices = range(num_flows) 124 | else: 125 | raise ValueError("FlyingChairsOcc: dstype '%s' unknown!" % dstype) 126 | 127 | # ---------------------------------------------------------- 128 | # Save list of actual filenames for inputs and flows 129 | # ---------------------------------------------------------- 130 | self._image_list = [] 131 | self._flow_list = [] 132 | self._occ_list = [] 133 | for i in list_of_indices: 134 | flo_f = flow_f_filenames[i] 135 | flo_b = flow_b_filenames[i] 136 | im1 = image1_filenames[i] 137 | im2 = image2_filenames[i] 138 | self._image_list += [[im1, im2]] 139 | self._flow_list += [[flo_f, flo_b]] 140 | occ1 = occ1_filenames[i] 141 | occ2 = occ2_filenames[i] 142 | self._occ_list += [[occ1, occ2]] 143 | 144 | self._size = len(self._image_list) 145 | assert len(self._image_list) == len(self._flow_list) 146 | assert len(self._occ_list) == len(self._flow_list) 147 | 148 | 149 | #
---------------------------------------------------------- 150 | # photometric_augmentations 151 | # ---------------------------------------------------------- 152 | if photometric_augmentations: 153 | self._photometric_transform = transforms.ConcatTransformSplitChainer([ 154 | # uint8 -> PIL 155 | vision_transforms.ToPILImage(), 156 | # PIL -> PIL : random brightness, contrast, saturation and hue 157 | vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), 158 | # PIL -> FloatTensor 159 | vision_transforms.transforms.ToTensor(), 160 | transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True), 161 | ], from_numpy=True, to_numpy=False) 162 | 163 | else: 164 | self._photometric_transform = transforms.ConcatTransformSplitChainer([ 165 | # uint8 -> FloatTensor 166 | vision_transforms.transforms.ToTensor(), 167 | ], from_numpy=True, to_numpy=False) 168 | 169 | def __getitem__(self, index): 170 | index = index % self._size 171 | 172 | im1_filename = self._image_list[index][0] 173 | im2_filename = self._image_list[index][1] 174 | flo_f_filename = self._flow_list[index][0] 175 | flo_b_filename = self._flow_list[index][1] 176 | occ1_filename = self._occ_list[index][0] 177 | occ2_filename = self._occ_list[index][1] 178 | 179 | # read images (uint8), flows and occlusion maps (float32) 180 | im1_np0 = common.read_image_as_byte(im1_filename) 181 | im2_np0 = common.read_image_as_byte(im2_filename) 182 | flo_f_np0 = common.read_flo_as_float32(flo_f_filename) 183 | flo_b_np0 = common.read_flo_as_float32(flo_b_filename) 184 | occ1_np0 = common.read_occ_image_as_float32(occ1_filename) 185 | occ2_np0 = common.read_occ_image_as_float32(occ2_filename) 186 | 187 | # possibly apply photometric transformations 188 | im1, im2 = self._photometric_transform(im1_np0, im2_np0) 189 | 190 | # convert flow to FloatTensor 191 | flo_f = common.numpy2torch(flo_f_np0) 192 | flo_b = common.numpy2torch(flo_b_np0) 193 | 194 | # convert occ to FloatTensor 195 | occ1 = common.numpy2torch(occ1_np0) 196 | occ2 = common.numpy2torch(occ2_np0) 197 | 198 | # short basename identifying the example 199 | basename = os.path.basename(im1_filename)[:5] 200 | 201 | if self.backward: 202 | # swap the flows (and occlusions): the backward flow becomes target 1 and the forward flow target 2 203 | example_dict = { 204 | "input1": im1, 205 | "input2": im2, 206 | "target1": flo_b, 207 | "target2": flo_f, 208 | "target_occ1": occ2, 209 | "target_occ2": occ1, 210 | "index": index, 211 | "basename": basename 212 | } 213 | else: 214 | example_dict = { 215 | "input1": im1, 216 | "input2": im2, 217 | "target1": flo_f, 218 | "target2": flo_b, 219 | "target_occ1": occ1, 220 | "target_occ2": occ2, 221 | "index": index, 222 | "basename": basename 223 | } 224 | 225 | return example_dict 226 | 227 | def __len__(self): 228 | return self._size 229 | 230 | 231 | class FlyingChairsOccTrain(FlyingChairsOcc): 232 | def __init__(self, 233 | args, 234 | root, 235 | photometric_augmentations=True, 236 | backward=False): 237 | super(FlyingChairsOccTrain, self).__init__( 238 | args, 239 | root=root, 240 | photometric_augmentations=photometric_augmentations, 241 | dstype="train", backward=backward) 242 | 243 | 244 | class FlyingChairsOccValid(FlyingChairsOcc): 245 | def __init__(self, 246 | args, 247 | root, 248 | photometric_augmentations=False, 249 | backward=False): 250 | super(FlyingChairsOccValid, self).__init__( 251 | args, 252 | root=root, 253 | photometric_augmentations=photometric_augmentations, 254 | dstype="valid", backward=backward) 255 | 256 | 257 | class FlyingChairsOccFull(FlyingChairsOcc): 258 | def __init__(self, 259 |
args, 260 | root, 261 | photometric_augmentations=False, 262 | backward=False): 263 | super(FlyingChairsOccFull, self).__init__( 264 | args, 265 | root=root, 266 | photometric_augmentations=photometric_augmentations, 267 | dstype="full", backward=backward) 268 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | ## Portions of Code from, copyright 2018 Jochen Gast 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def image_random_gamma(image, min_gamma=0.7, max_gamma=1.5, clip_image=False): 10 | gamma = np.random.uniform(min_gamma, max_gamma) 11 | adjusted = torch.pow(image, gamma) 12 | if clip_image: 13 | adjusted.clamp_(0.0, 1.0) 14 | return adjusted 15 | 16 | 17 | class RandomGamma: 18 | def __init__(self, min_gamma=0.7, max_gamma=1.5, clip_image=False): 19 | self._min_gamma = min_gamma 20 | self._max_gamma = max_gamma 21 | self._clip_image = clip_image 22 | 23 | def __call__(self, image): 24 | return image_random_gamma( 25 | image, 26 | min_gamma=self._min_gamma, 27 | max_gamma=self._max_gamma, 28 | clip_image=self._clip_image) 29 | 30 | 31 | # ------------------------------------------------------------------ 32 | # Allow transformation chains of the type: 33 | # im1, im2, .... = transform(im1, im2, ...) 34 | # ------------------------------------------------------------------ 35 | class TransformChainer: 36 | def __init__(self, list_of_transforms): 37 | self._list_of_transforms = list_of_transforms 38 | 39 | def __call__(self, *args): 40 | list_of_args = list(args) 41 | for transform in self._list_of_transforms: 42 | list_of_args = [transform(arg) for arg in list_of_args] 43 | if len(args) == 1: 44 | return list_of_args[0] 45 | else: 46 | return list_of_args 47 | 48 | 49 | # ------------------------------------------------------------------ 50 | # Allow transformation chains of the type: 51 | # im1, im2, .... = split( transform( concatenate(im1, im2, ...) 
)) 52 | # ------------------------------------------------------------------ 53 | class ConcatTransformSplitChainer: 54 | def __init__(self, list_of_transforms, from_numpy=True, to_numpy=False): 55 | self._chainer = TransformChainer(list_of_transforms) 56 | self._from_numpy = from_numpy 57 | self._to_numpy = to_numpy 58 | 59 | def __call__(self, *args): 60 | num_splits = len(args) 61 | 62 | if self._from_numpy: 63 | concatenated = np.concatenate(args, axis=0) 64 | else: 65 | concatenated = torch.cat(args, dim=1) 66 | 67 | transformed = self._chainer(concatenated) 68 | 69 | if self._to_numpy: 70 | split = np.split(transformed, indices_or_sections=num_splits, axis=0) 71 | else: 72 | split = torch.chunk(transformed, num_splits, dim=1) 73 | 74 | return split 75 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from glob import glob 4 | 5 | import scipy.misc 6 | import numpy as np 7 | import torch 8 | 9 | from torchvision import transforms as vision_transforms 10 | import models 11 | from datasets import common 12 | from configuration import ModelAndLoss 13 | 14 | from utils.flow import flow_to_png_middlebury, write_flow 15 | 16 | import pylab as pl 17 | pl.interactive(True) 18 | 19 | import argparse 20 | 21 | ''' 22 | Example (will save results in ./output/): 23 | python inference.py \ 24 | --model StarFlow \ 25 | --checkpoint saved_checkpoint/StarFlow_things/checkpoint_best.ckpt \ 26 | --data-root /data/mpisintelcomplete/training/final/ambush_6/ \ 27 | --file-list frame_0004.png frame_0005.png frame_0006.png frame_0007.png 28 | ''' 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--savedir", type=str, default="./output") 32 | parser.add_argument("--data-root", type=str, 33 | default="./") 34 | parser.add_argument('--file-list', nargs='*', default=[-1], type=str) 35 | 36 | parser.add_argument("--model", type=str, default='StarFlow') 37 | parser.add_argument('--checkpoint', dest='checkpoint', default=None, 38 | metavar='PATH', help='path to pre-trained model') 39 | 40 | parser.add_argument('--device', type=int, default=0) 41 | parser.add_argument("--no-cuda", action="store_true") 42 | 43 | args = parser.parse_args() 44 | 45 | # use cuda GPU 46 | use_cuda = (not args.no_cuda) and torch.cuda.is_available() 47 | 48 | # --------------------- 49 | # Load pretrained model 50 | # --------------------- 51 | MODEL = models.__dict__[args.model] 52 | net = ModelAndLoss(None, MODEL(None), None) 53 | checkpoint_with_state = torch.load(args.checkpoint, 54 | map_location=lambda storage, 55 | loc: storage.cuda(args.device)) 56 | state_dict = checkpoint_with_state['state_dict'] 57 | net.load_state_dict(state_dict) 58 | net.eval() 59 | net.cuda() 60 | 61 | # ------------------- 62 | # Load image sequence 63 | # ------------------- 64 | if not os.path.exists(args.data_root): 65 | raise ValueError("data-root: {} not found".format(args.data_root)) 66 | if len(args.file_list) == 0: 67 | raise ValueError("file-list empty") 68 | elif len(args.file_list) == 1: 69 | path = os.path.join(args.data_root, args.file_list[0]) 70 | list_path_imgs = sorted(glob(path)) 71 | if len(list_path_imgs) == 0: 72 | raise ValueError("no data were found") 73 | else: 74 | list_path_imgs = [os.path.join(args.data_root, file_name) 75 | for file_name in args.file_list] 76 | for path_im in list_path_imgs: 77 | if not os.path.isfile(path_im): 78 | raise 
ValueError("file {} not found".format(path_im)) 79 | img_reader = common.read_image_as_byte 80 | #flo_reader = common.read_flo_as_float32 81 | imgs_np = [img_reader(path) for path in list_path_imgs] 82 | if imgs_np[0].squeeze().ndim == 2: 83 | imgs_np = [np.dstack([im]*3) for im in imgs_np] 84 | to_tensor = vision_transforms.ToTensor() 85 | images = [to_tensor(im).unsqueeze(0).cuda() for im in imgs_np] 86 | input_dict = {'input_images':images} 87 | 88 | # --------------- 89 | # Flow estimation 90 | # --------------- 91 | with torch.no_grad(): 92 | output_dict = net._model(input_dict) 93 | 94 | estimated_flow = output_dict['flow'] 95 | 96 | if len(imgs_np) > 2: 97 | estimated_flow_np = estimated_flow[:,0].cpu().numpy() 98 | estimated_flow_np = [flow for flow in estimated_flow_np] 99 | else: 100 | estimated_flow_np = [estimated_flow[0].cpu().numpy()] 101 | 102 | 103 | # ------------ 104 | # Save results 105 | # ------------ 106 | if not os.path.exists(os.path.join(args.savedir, "visu")): 107 | os.makedirs(os.path.join(args.savedir, "visu")) 108 | if not os.path.exists(os.path.join(args.savedir, "flow")): 109 | os.makedirs(os.path.join(args.savedir, "flow")) 110 | for t in range(len(imgs_np)-1): 111 | flow_visu = flow_to_png_middlebury(estimated_flow_np[t]) 112 | basename = os.path.splitext(os.path.basename(list_path_imgs[t]))[0] 113 | file_name_flow_visu = os.path.join(args.savedir, 'visu', 114 | basename + '_flow_visu.png') 115 | file_name_flow = os.path.join(args.savedir, 'flow', 116 | basename + '_flow.flo') 117 | scipy.misc.imsave(file_name_flow_visu, flow_visu) 118 | write_flow(file_name_flow, estimated_flow_np[t].swapaxes(0, 1).swapaxes(1, 2)) 119 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ./models/correlation_package 3 | python setup.py install 4 | cd .. 
5 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | ## Portions of Code from, copyright 2018 Jochen Gast 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import colorama 6 | import logging 7 | import os 8 | import re 9 | import tools 10 | import sys 11 | 12 | 13 | def get_default_logging_format(colorize=False, brackets=False): 14 | style = colorama.Style.DIM if colorize else '' 15 | # color = colorama.Fore.CYAN if colorize else '' 16 | color = colorama.Fore.WHITE if colorize else '' 17 | reset = colorama.Style.RESET_ALL if colorize else '' 18 | if brackets: 19 | result = "{}{}[%(asctime)s]{} %(message)s".format(style, color, reset) 20 | else: 21 | result = "{}{}%(asctime)s{} %(message)s".format(style, color, reset) 22 | return result 23 | 24 | 25 | def get_default_logging_datefmt(): 26 | return "%Y-%m-%d %H:%M:%S" 27 | 28 | 29 | def log_module_info(module): 30 | lines = module.__str__().split("\n") 31 | for line in lines: 32 | logging.info(line) 33 | 34 | 35 | class LogbookFormatter(logging.Formatter): 36 | def __init__(self, fmt=None, datefmt=None): 37 | super(LogbookFormatter, self).__init__(fmt=fmt, datefmt=datefmt) 38 | self._re = re.compile(r"\033\[[0-9]+m") 39 | 40 | def remove_colors_from_msg(self, msg): 41 | msg = re.sub(self._re, "", msg) 42 | return msg 43 | 44 | def format(self, record=None): 45 | record.msg = self.remove_colors_from_msg(record.msg) 46 | return super(LogbookFormatter, self).format(record) 47 | 48 | 49 | class ConsoleFormatter(logging.Formatter): 50 | def __init__(self, fmt=None, datefmt=None): 51 | super(ConsoleFormatter, self).__init__(fmt=fmt, datefmt=datefmt) 52 | 53 | def format(self, record=None): 54 | indent = sys.modules[__name__].global_indent 55 | record.msg = " " * indent + record.msg 56 | return super(ConsoleFormatter, self).format(record) 57 | 58 | 59 | class SkipLogbookFilter(logging.Filter): 60 | def filter(self, record): 61 | return record.levelno != logging.LOGBOOK 62 | 63 | 64 | def configure_logging(filename=None): 65 | # set global indent level 66 | sys.modules[__name__].global_indent = 0 67 | 68 | # add custom tqdm logger 69 | tools.addLoggingLevel("LOGBOOK", 1000) 70 | 71 | # create logger 72 | root_logger = logging.getLogger("") 73 | root_logger.setLevel(logging.INFO) 74 | 75 | # create console handler and set level to debug 76 | console = logging.StreamHandler() 77 | console.setLevel(logging.INFO) 78 | fmt = get_default_logging_format(colorize=True, brackets=False) 79 | datefmt = get_default_logging_datefmt() 80 | formatter = ConsoleFormatter(fmt=fmt, datefmt=datefmt) 81 | console.setFormatter(formatter) 82 | 83 | # Skip logging.tqdm requests for console outputs 84 | skip_logbook_filter = SkipLogbookFilter() 85 | console.addFilter(skip_logbook_filter) 86 | 87 | # add console to root_logger 88 | root_logger.addHandler(console) 89 | 90 | # add logbook 91 | if filename is not None: 92 | # ensure dir 93 | d = os.path.dirname(filename) 94 | if not os.path.exists(d): 95 | os.makedirs(d) 96 | 97 | # -------------------------------------------------------------------------------------- 98 | # Configure handler that removes color codes from logbook 99 | # -------------------------------------------------------------------------------------- 100 | logbook = logging.FileHandler(filename=filename, mode="a", encoding="utf-8") 101 | logbook.setLevel(logging.INFO) 102 | fmt = 
get_default_logging_format(colorize=False, brackets=True) 103 | logbook_formatter = LogbookFormatter(fmt=fmt, datefmt=datefmt) 104 | logbook.setFormatter(logbook_formatter) 105 | root_logger.addHandler(logbook) 106 | 107 | 108 | class LoggingBlock: 109 | def __init__(self, title, emph=False): 110 | self._emph = emph 111 | bright = colorama.Style.BRIGHT 112 | cyan = colorama.Fore.CYAN 113 | reset = colorama.Style.RESET_ALL 114 | if emph: 115 | logging.info("%s==>%s %s%s%s" % (cyan, reset, bright, title, reset)) 116 | else: 117 | logging.info(title) 118 | 119 | def __enter__(self): 120 | sys.modules[__name__].global_indent += 2 121 | return self 122 | 123 | def __exit__(self, exc_type, exc_value, traceback): 124 | sys.modules[__name__].global_indent -= 2 125 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import subprocess 5 | import commandline 6 | import configuration as config 7 | import runtime 8 | import logger 9 | import logging 10 | import tools 11 | import torch 12 | 13 | def main(): 14 | 15 | # Change working directory 16 | os.chdir(os.path.dirname(os.path.realpath(__file__))) 17 | 18 | # Parse commandline arguments 19 | args = commandline.setup_logging_and_parse_arguments(blocktitle="Commandline Arguments") 20 | 21 | # set cuda device: 22 | if args.cuda: 23 | torch.cuda.set_device(args.device) 24 | 25 | # Set random seed, possibly on Cuda 26 | config.configure_random_seed(args) 27 | 28 | # DataLoader 29 | train_loader, validation_loader, inference_loader = config.configure_data_loaders(args) 30 | success = any(loader is not None for loader in [train_loader, validation_loader, inference_loader]) 31 | if not success: 32 | logging.info("No dataset could be loaded successfully. Please check dataset paths!") 33 | quit() 34 | 35 | if args.resume: 36 | args.checkpoint = os.path.join(args.save, 'checkpoint_latest.ckpt') 37 | args.optim_checkpoint = os.path.join(args.save, 'optim_state_dict_checkpoint.pth') 38 | checkpoint_with_state = torch.load(args.checkpoint, 39 | map_location=lambda storage, 40 | loc: storage.cuda(args.device)) 41 | args.start_epoch = checkpoint_with_state['epoch'] + 1 42 | 43 | # Configure data augmentation 44 | training_augmentation, validation_augmentation = config.configure_runtime_augmentations(args) 45 | 46 | # Configure model and loss 47 | model_and_loss = config.configure_model_and_loss(args) 48 | 49 | # Resume from checkpoint if available 50 | checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(args, model_and_loss) 51 | 52 | # Checkpoint and save directory 53 | with logger.LoggingBlock("Save Directory", emph=True): 54 | logging.info("Save directory: %s" % args.save) 55 | if not os.path.exists(args.save): 56 | os.makedirs(args.save) 57 | 58 | # # Multi-GPU automation 59 | # with logger.LoggingBlock("Multi GPU", emph=True): 60 | # if torch.cuda.device_count() > 1: 61 | # logging.info("Let's use %d GPUs!" % torch.cuda.device_count()) 62 | # model_and_loss._model = torch.nn.DataParallel(model_and_loss._model) 63 | # else: 64 | # logging.info("Let's use %d GPU!" 
% torch.cuda.device_count()) 65 | 66 | # Configure optimizer 67 | optimizer = config.configure_optimizer(args, model_and_loss) 68 | 69 | # Configure learning rate 70 | lr_scheduler = config.configure_lr_scheduler(args, optimizer) 71 | 72 | # If this is just an evaluation: overwrite savers and epochs 73 | if args.evaluation: 74 | args.start_epoch = 1 75 | args.total_epochs = 1 76 | train_loader = None 77 | checkpoint_saver = None 78 | optimizer = None 79 | lr_scheduler = None 80 | 81 | # Cuda optimization 82 | if args.cuda: 83 | torch.backends.cudnn.benchmark = True 84 | 85 | # Kickoff training, validation and/or testing 86 | return runtime.exec_runtime( 87 | args, 88 | checkpoint_saver=checkpoint_saver, 89 | model_and_loss=model_and_loss, 90 | optimizer=optimizer, 91 | lr_scheduler=lr_scheduler, 92 | train_loader=train_loader, 93 | validation_loader=validation_loader, 94 | inference_loader=inference_loader, 95 | training_augmentation=training_augmentation, 96 | validation_augmentation=validation_augmentation) 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /models/IRR_PWC.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .pwc_modules import conv, upsample2d_as, rescale_flow, initialize_msra 7 | from .pwc_modules import WarpingLayer, FeatureExtractor 8 | from .pwc_modules import ContextNetwork, FlowEstimatorDense 9 | from .pwc_modules import OccContextNetwork, OccEstimatorDense 10 | from .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc 11 | from .correlation_package.correlation import Correlation 12 | 13 | import copy 14 | 15 | 16 | class PWCNet(nn.Module): 17 | def __init__(self, args, div_flow=0.05): 18 | super(PWCNet, self).__init__() 19 | self.args = args 20 | self._div_flow = div_flow 21 | self.search_range = 4 22 | self.num_chs = [3, 16, 32, 64, 96, 128, 196] 23 | self.output_level = 4 24 | self.num_levels = 7 25 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True) 26 | 27 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs) 28 | self.warping_layer = WarpingLayer() 29 | 30 | self.dim_corr = (self.search_range * 2 + 1) ** 2 31 | self.num_ch_in_flo = self.dim_corr + 32 + 2 32 | self.num_ch_in_occ = self.dim_corr + 32 + 1 33 | 34 | self.flow_estimators = FlowEstimatorDense(self.num_ch_in_flo) 35 | self.context_networks = ContextNetwork(self.num_ch_in_flo + 448 + 2) 36 | self.occ_estimators = OccEstimatorDense(self.num_ch_in_occ) 37 | self.occ_context_networks = OccContextNetwork(self.num_ch_in_occ + 448 + 1) 38 | self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1) 39 | 40 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1), 41 | conv(128, 32, kernel_size=1, stride=1, dilation=1), 42 | conv(96, 32, kernel_size=1, stride=1, dilation=1), 43 | conv(64, 32, kernel_size=1, stride=1, dilation=1)]) 44 | 45 | self.conv_1x1_1 = conv(16, 3, kernel_size=1, stride=1, dilation=1) 46 | 47 | self.refine_flow = RefineFlow(2 + 1 + 32) 48 | self.refine_occ = RefineOcc(1 + 32 + 32) 49 | 50 | initialize_msra(self.modules()) 51 | 52 | def forward(self, input_dict): 53 | 54 | x1_raw = input_dict['input1'] 55 | x2_raw = input_dict['input2'] 56 | batch_size, _, height_im, width_im = x1_raw.size() 57 | 58 | # on the bottom level are original images 59 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + 
[x1_raw] 60 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw] 61 | 62 | # outputs 63 | output_dict = {} 64 | output_dict_eval = {} 65 | flows = [] 66 | occs = [] 67 | 68 | _, _, h_x1, w_x1, = x1_pyramid[0].size() 69 | flow_f = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda() 70 | flow_b = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda() 71 | occ_f = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda() 72 | occ_b = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda() 73 | 74 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): 75 | 76 | if l <= self.output_level: 77 | 78 | # warping 79 | if l == 0: 80 | x2_warp = x2 81 | x1_warp = x1 82 | else: 83 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear") 84 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear") 85 | occ_f = upsample2d_as(occ_f, x1, mode="bilinear") 86 | occ_b = upsample2d_as(occ_b, x2, mode="bilinear") 87 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow) 88 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow) 89 | 90 | # correlation 91 | out_corr_f = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp) 92 | out_corr_b = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x2, x1_warp) 93 | out_corr_relu_f = self.leakyRELU(out_corr_f) 94 | out_corr_relu_b = self.leakyRELU(out_corr_b) 95 | 96 | if l != self.output_level: 97 | x1_1by1 = self.conv_1x1[l](x1) 98 | x2_1by1 = self.conv_1x1[l](x2) 99 | else: 100 | x1_1by1 = x1 101 | x2_1by1 = x2 102 | 103 | # concat and estimate flow 104 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True) 105 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True) 106 | 107 | x_intm_f, flow_res_f = self.flow_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f], dim=1)) 108 | x_intm_b, flow_res_b = self.flow_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b], dim=1)) 109 | flow_est_f = flow_f + flow_res_f 110 | flow_est_b = flow_b + flow_res_b 111 | 112 | flow_cont_f = flow_est_f + self.context_networks(torch.cat([x_intm_f, flow_est_f], dim=1)) 113 | flow_cont_b = flow_est_b + self.context_networks(torch.cat([x_intm_b, flow_est_b], dim=1)) 114 | 115 | # occ estimation 116 | x_intm_occ_f, occ_res_f = self.occ_estimators(torch.cat([out_corr_relu_f, x1_1by1, occ_f], dim=1)) 117 | x_intm_occ_b, occ_res_b = self.occ_estimators(torch.cat([out_corr_relu_b, x2_1by1, occ_b], dim=1)) 118 | occ_est_f = occ_f + occ_res_f 119 | occ_est_b = occ_b + occ_res_b 120 | 121 | occ_cont_f = occ_est_f + self.occ_context_networks(torch.cat([x_intm_occ_f, occ_est_f], dim=1)) 122 | occ_cont_b = occ_est_b + self.occ_context_networks(torch.cat([x_intm_occ_b, occ_est_b], dim=1)) 123 | 124 | # refinement 125 | img1_resize = upsample2d_as(x1_raw, flow_f, mode="bilinear") 126 | img2_resize = upsample2d_as(x2_raw, flow_b, mode="bilinear") 127 | img2_warp = self.warping_layer(img2_resize, rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow) 128 | img1_warp = self.warping_layer(img1_resize, rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow) 129 | 130 | # flow refine 131 | flow_f = self.refine_flow(flow_cont_f.detach(), img1_resize - img2_warp, x1_1by1) 132 | flow_b = 
self.refine_flow(flow_cont_b.detach(), img2_resize - img1_warp, x2_1by1) 133 | 134 | flow_cont_f = rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False) 135 | flow_cont_b = rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False) 136 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False) 137 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False) 138 | 139 | # occ refine 140 | x2_1by1_warp = self.warping_layer(x2_1by1, flow_f, height_im, width_im, self._div_flow) 141 | x1_1by1_warp = self.warping_layer(x1_1by1, flow_b, height_im, width_im, self._div_flow) 142 | 143 | occ_f = self.refine_occ(occ_cont_f.detach(), x1_1by1, x1_1by1 - x2_1by1_warp) 144 | occ_b = self.refine_occ(occ_cont_b.detach(), x2_1by1, x2_1by1 - x1_1by1_warp) 145 | 146 | flows.append([flow_cont_f, flow_cont_b, flow_f, flow_b]) 147 | occs.append([occ_cont_f, occ_cont_b, occ_f, occ_b]) 148 | 149 | else: 150 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear") 151 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear") 152 | flows.append([flow_f, flow_b]) 153 | 154 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow) 155 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow) 156 | flow_b_warp = self.warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow) 157 | flow_f_warp = self.warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow) 158 | 159 | if l != self.num_levels-1: 160 | x1_in = self.conv_1x1_1(x1) 161 | x2_in = self.conv_1x1_1(x2) 162 | x1_w_in = self.conv_1x1_1(x1_warp) 163 | x2_w_in = self.conv_1x1_1(x2_warp) 164 | else: 165 | x1_in = x1 166 | x2_in = x2 167 | x1_w_in = x1_warp 168 | x2_w_in = x2_warp 169 | 170 | occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([x1_in, x2_w_in, flow_f, flow_b_warp], dim=1)) 171 | occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([x2_in, x1_w_in, flow_b, flow_f_warp], dim=1)) 172 | 173 | occs.append([occ_f, occ_b]) 174 | 175 | output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow) 176 | output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode="bilinear") 177 | output_dict['flow'] = flows 178 | output_dict['occ'] = occs 179 | 180 | if self.training: 181 | return output_dict 182 | else: 183 | return output_dict_eval 184 | -------------------------------------------------------------------------------- /models/IRR_PWC_occ_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .pwc_modules import conv, upsample2d_as, rescale_flow, initialize_msra 7 | from .pwc_modules import WarpingLayer, FeatureExtractor 8 | from .pwc_modules import FlowAndOccEstimatorDense, FlowAndOccContextNetwork 9 | from .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc 10 | from .correlation_package.correlation import Correlation 11 | 12 | import copy 13 | 14 | 15 | class PWCNet(nn.Module): 16 | def __init__(self, args, div_flow=0.05): 17 | super(PWCNet, self).__init__() 18 | self.args = args 19 | self._div_flow = div_flow 20 | self.search_range = 4 21 | self.num_chs = [3, 16, 32, 64, 96, 128, 196] 22 | self.output_level = 4 23 | self.num_levels = 7 24 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True) 25 | 26 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs) 27 | self.warping_layer = WarpingLayer() 
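        # For orientation, some illustrative arithmetic mirroring the two lines below:
        # with search_range = 4 the correlation layer produces
        # dim_corr = (2 * 4 + 1) ** 2 = 81 cost-volume channels, so the joint
        # flow/occlusion estimator consumes
        # num_ch_in = 81 + 32 (1x1 features) + 2 (flow) + 1 (occlusion) = 116 channels.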
28 | 29 | self.dim_corr = (self.search_range * 2 + 1) ** 2 30 | self.num_ch_in = self.dim_corr + 32 + 2 + 1 31 | 32 | self.flow_and_occ_estimators = FlowAndOccEstimatorDense(self.num_ch_in) 33 | self.context_networks = FlowAndOccContextNetwork(self.num_ch_in + 448 + 2 + 1) 34 | self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1) 35 | 36 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1), 37 | conv(128, 32, kernel_size=1, stride=1, dilation=1), 38 | conv(96, 32, kernel_size=1, stride=1, dilation=1), 39 | conv(64, 32, kernel_size=1, stride=1, dilation=1)]) 40 | 41 | self.conv_1x1_1 = conv(16, 3, kernel_size=1, stride=1, dilation=1) 42 | 43 | self.refine_flow = RefineFlow(2 + 1 + 32) 44 | self.refine_occ = RefineOcc(1 + 32 + 32) 45 | 46 | initialize_msra(self.modules()) 47 | 48 | def forward(self, input_dict): 49 | 50 | x1_raw = input_dict['input1'] 51 | x2_raw = input_dict['input2'] 52 | batch_size, _, height_im, width_im = x1_raw.size() 53 | 54 | # on the bottom level are original images 55 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw] 56 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw] 57 | 58 | # outputs 59 | output_dict = {} 60 | output_dict_eval = {} 61 | flows = [] 62 | occs = [] 63 | 64 | _, _, h_x1, w_x1, = x1_pyramid[0].size() 65 | flow_f = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda() 66 | flow_b = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda() 67 | occ_f = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda() 68 | occ_b = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda() 69 | 70 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): 71 | 72 | if l <= self.output_level: 73 | 74 | # warping 75 | if l == 0: 76 | x2_warp = x2 77 | x1_warp = x1 78 | else: 79 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear") 80 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear") 81 | occ_f = upsample2d_as(occ_f, x1, mode="bilinear") 82 | occ_b = upsample2d_as(occ_b, x2, mode="bilinear") 83 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow) 84 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow) 85 | 86 | # correlation 87 | out_corr_f = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp) 88 | out_corr_b = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x2, x1_warp) 89 | out_corr_relu_f = self.leakyRELU(out_corr_f) 90 | out_corr_relu_b = self.leakyRELU(out_corr_b) 91 | 92 | if l != self.output_level: 93 | x1_1by1 = self.conv_1x1[l](x1) 94 | x2_1by1 = self.conv_1x1[l](x2) 95 | else: 96 | x1_1by1 = x1 97 | x2_1by1 = x2 98 | 99 | # concat and estimate flow and occ 100 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True) 101 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True) 102 | 103 | x_intm_f, flow_res_f, occ_res_f = self.flow_and_occ_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f, occ_f], dim=1)) 104 | x_intm_b, flow_res_b, occ_res_b = self.flow_and_occ_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b, occ_b], dim=1)) 105 | flow_est_f = flow_f + flow_res_f 106 | flow_est_b = flow_b + flow_res_b 107 | occ_est_f = occ_f + occ_res_f 108 | occ_est_b = occ_b + occ_res_b 109 | 110 | flow_fine_f, occ_fine_f = self.context_networks(torch.cat([x_intm_f, flow_est_f, occ_est_f], dim=1)) 111 | flow_fine_b, occ_fine_b 
= self.context_networks(torch.cat([x_intm_b, flow_est_b, occ_est_b], dim=1)) 112 | 113 | flow_cont_f = flow_est_f + flow_fine_f 114 | flow_cont_b = flow_est_b + flow_fine_b 115 | occ_cont_f = occ_est_f + occ_fine_f 116 | occ_cont_b = occ_est_b + occ_fine_b 117 | 118 | # refinement 119 | img1_resize = upsample2d_as(x1_raw, flow_f, mode="bilinear") 120 | img2_resize = upsample2d_as(x2_raw, flow_b, mode="bilinear") 121 | img2_warp = self.warping_layer(img2_resize, rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow) 122 | img1_warp = self.warping_layer(img1_resize, rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow) 123 | 124 | # flow refine 125 | flow_f = self.refine_flow(flow_cont_f.detach(), img1_resize - img2_warp, x1_1by1) 126 | flow_b = self.refine_flow(flow_cont_b.detach(), img2_resize - img1_warp, x2_1by1) 127 | 128 | flow_cont_f = rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False) 129 | flow_cont_b = rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False) 130 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False) 131 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False) 132 | 133 | # occ refine 134 | x2_1by1_warp = self.warping_layer(x2_1by1, flow_f, height_im, width_im, self._div_flow) 135 | x1_1by1_warp = self.warping_layer(x1_1by1, flow_b, height_im, width_im, self._div_flow) 136 | 137 | occ_f = self.refine_occ(occ_cont_f.detach(), x1_1by1, x1_1by1 - x2_1by1_warp) 138 | occ_b = self.refine_occ(occ_cont_b.detach(), x2_1by1, x2_1by1 - x1_1by1_warp) 139 | 140 | flows.append([flow_cont_f, flow_cont_b, flow_f, flow_b]) 141 | occs.append([occ_cont_f, occ_cont_b, occ_f, occ_b]) 142 | 143 | else: 144 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear") 145 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear") 146 | flows.append([flow_f, flow_b]) 147 | 148 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow) 149 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow) 150 | flow_b_warp = self.warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow) 151 | flow_f_warp = self.warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow) 152 | 153 | if l != self.num_levels-1: 154 | x1_in = self.conv_1x1_1(x1) 155 | x2_in = self.conv_1x1_1(x2) 156 | x1_w_in = self.conv_1x1_1(x1_warp) 157 | x2_w_in = self.conv_1x1_1(x2_warp) 158 | else: 159 | x1_in = x1 160 | x2_in = x2 161 | x1_w_in = x1_warp 162 | x2_w_in = x2_warp 163 | 164 | occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([x1_in, x2_w_in, flow_f, flow_b_warp], dim=1)) 165 | occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([x2_in, x1_w_in, flow_b, flow_f_warp], dim=1)) 166 | 167 | occs.append([occ_f, occ_b]) 168 | 169 | output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode="bilinear") * (1.0 / self._div_flow) 170 | output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode="bilinear") 171 | output_dict['flow'] = flows 172 | output_dict['occ'] = occs 173 | 174 | if self.training: 175 | return output_dict 176 | else: 177 | return output_dict_eval 178 | -------------------------------------------------------------------------------- /models/STAR.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import 
torch.nn as nn 5 | 6 | from .pwc_modules import conv, upsample2d_as, rescale_flow, initialize_msra 7 | from .pwc_modules import WarpingLayer, FeatureExtractor 8 | from .pwc_modules import FlowAndOccContextNetwork, FlowAndOccEstimatorDense 9 | from .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc 10 | from .correlation_package.correlation import Correlation 11 | 12 | import copy 13 | 14 | 15 | class StarFlow(nn.Module): 16 | def __init__(self, args, div_flow=0.05): 17 | super(StarFlow, self).__init__() 18 | self.args = args 19 | self._div_flow = div_flow 20 | self.search_range = 4 21 | self.num_chs = [3, 16, 32, 64, 96, 128, 196] 22 | self.output_level = 4 23 | self.num_levels = 7 24 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True) 25 | 26 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs) 27 | self.warping_layer = WarpingLayer() 28 | 29 | self.dim_corr = (self.search_range * 2 + 1) ** 2 30 | self.num_ch_in = self.dim_corr + 32 + 2 + 1 31 | 32 | self.flow_and_occ_estimators = FlowAndOccEstimatorDense(2 * self.num_ch_in) 33 | self.context_networks = FlowAndOccContextNetwork(2 * self.num_ch_in + 448 + 2 + 1) 34 | 35 | self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1) 36 | 37 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1), 38 | conv(128, 32, kernel_size=1, stride=1, dilation=1), 39 | conv(96, 32, kernel_size=1, stride=1, dilation=1), 40 | conv(64, 32, kernel_size=1, stride=1, dilation=1)]) 41 | 42 | self.conv_1x1_1 = conv(16, 3, kernel_size=1, stride=1, dilation=1) 43 | 44 | self.conv_1x1_time = conv(2 * self.num_ch_in + 448, self.num_ch_in, kernel_size=1, stride=1, dilation=1) 45 | 46 | self.refine_flow = RefineFlow(2 + 1 + 32) 47 | self.refine_occ = RefineOcc(1 + 32 + 32) 48 | 49 | initialize_msra(self.modules()) 50 | 51 | def forward(self, input_dict): 52 | 53 | if 'input_images' in input_dict.keys(): 54 | list_imgs = input_dict['input_images'] 55 | else: 56 | x1_raw = input_dict['input1'] 57 | x2_raw = input_dict['input2'] 58 | list_imgs = [x1_raw, x2_raw] 59 | 60 | _, _, height_im, width_im = list_imgs[0].size() 61 | 62 | # on the bottom level are original images 63 | list_pyramids = [] #indices : [time][level] 64 | for im in list_imgs: 65 | list_pyramids.append(self.feature_pyramid_extractor(im) + [im]) 66 | 67 | # outputs 68 | output_dict = {} 69 | output_dict_eval = {} 70 | flows_f = [] #indices : [level][time] 71 | flows_b = [] #indices : [level][time] 72 | occs_f = [] 73 | occs_b = [] 74 | flows_coarse_f = [] 75 | occs_coarse_f = [] 76 | for l in range(len(list_pyramids[0])): 77 | flows_f.append([]) 78 | flows_b.append([]) 79 | occs_f.append([]) 80 | occs_b.append([]) 81 | for l in range(self.output_level + 1): 82 | flows_coarse_f.append([]) 83 | occs_coarse_f.append([]) 84 | 85 | # init 86 | b_size, _, h_x1, w_x1, = list_pyramids[0][0].size() 87 | init_dtype = list_pyramids[0][0].dtype 88 | init_device = list_pyramids[0][0].device 89 | flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 90 | flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 91 | occ_f = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 92 | occ_b = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 93 | previous_features = [] 94 | 95 | for i in range(len(list_imgs) - 1): 96 | x1_pyramid, x2_pyramid = list_pyramids[i:i+2] 97 | 98 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): 99 | 100 | if l <= 
self.output_level: 101 | if i == 0: 102 | bs_, _, h_, w_, = list_pyramids[0][l].size() 103 | previous_features.append(torch.zeros(bs_, self.num_ch_in, h_, w_, dtype=init_dtype, device=init_device).float()) 104 | 105 | # warping 106 | if l == 0: 107 | x2_warp = x2 108 | x1_warp = x1 109 | else: 110 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear") 111 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear") 112 | occ_f = upsample2d_as(occ_f, x1, mode="bilinear") 113 | occ_b = upsample2d_as(occ_b, x2, mode="bilinear") 114 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow) 115 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow) 116 | 117 | # correlation 118 | out_corr_f = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp) 119 | out_corr_b = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x2, x1_warp) 120 | out_corr_relu_f = self.leakyRELU(out_corr_f) 121 | out_corr_relu_b = self.leakyRELU(out_corr_b) 122 | 123 | if l != self.output_level: 124 | x1_1by1 = self.conv_1x1[l](x1) 125 | x2_1by1 = self.conv_1x1[l](x2) 126 | else: 127 | x1_1by1 = x1 128 | x2_1by1 = x2 129 | 130 | if i > 0: #temporal connection: 131 | previous_features[l] = self.warping_layer(previous_features[l], 132 | flows_b[l][-1], height_im, width_im, self._div_flow) 133 | 134 | # Flow and occlusions estimation 135 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True) 136 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True) 137 | 138 | features = torch.cat([previous_features[l], out_corr_relu_f, x1_1by1, flow_f, occ_f], 1) 139 | features_b = torch.cat([torch.zeros_like(previous_features[l]), out_corr_relu_b, x2_1by1, flow_b, occ_b], 1) 140 | 141 | x_intm_f, flow_res_f, occ_res_f = self.flow_and_occ_estimators(features) 142 | flow_est_f = flow_f + flow_res_f 143 | occ_est_f = occ_f + occ_res_f 144 | with torch.no_grad(): 145 | x_intm_b, flow_res_b, occ_res_b = self.flow_and_occ_estimators(features_b) 146 | flow_est_b = flow_b + flow_res_b 147 | occ_est_b = occ_b + occ_res_b 148 | 149 | # Context: 150 | flow_cont_res_f, occ_cont_res_f = self.context_networks(torch.cat([x_intm_f, flow_est_f, occ_est_f], dim=1)) 151 | flow_cont_f = flow_est_f + flow_cont_res_f 152 | occ_cont_f = occ_est_f + occ_cont_res_f 153 | with torch.no_grad(): 154 | flow_cont_res_b, occ_cont_res_b = self.context_networks(torch.cat([x_intm_b, flow_est_b, occ_est_b], dim=1)) 155 | flow_cont_b = flow_est_b + flow_cont_res_b 156 | occ_cont_b = occ_est_b + occ_cont_res_b 157 | 158 | # refinement 159 | img1_resize = upsample2d_as(list_imgs[i], flow_f, mode="bilinear") 160 | img2_resize = upsample2d_as(list_imgs[i+1], flow_b, mode="bilinear") 161 | img2_warp = self.warping_layer(img2_resize, rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow) 162 | img1_warp = self.warping_layer(img1_resize, rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow) 163 | 164 | # flow refine 165 | flow_f = self.refine_flow(flow_cont_f.detach(), img1_resize - img2_warp, x1_1by1) 166 | flow_b = self.refine_flow(flow_cont_b.detach(), img2_resize - img1_warp, x2_1by1) 167 | 168 | flow_cont_f = rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False) 169 | 
flow_cont_b = rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False) 170 | flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False) 171 | flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False) 172 | 173 | # occ refine 174 | x2_1by1_warp = self.warping_layer(x2_1by1, flow_f, height_im, width_im, self._div_flow) 175 | x1_1by1_warp = self.warping_layer(x1_1by1, flow_b, height_im, width_im, self._div_flow) 176 | 177 | occ_f = self.refine_occ(occ_cont_f.detach(), x1_1by1, x1_1by1 - x2_1by1_warp) 178 | occ_b = self.refine_occ(occ_cont_b.detach(), x2_1by1, x2_1by1 - x1_1by1_warp) 179 | 180 | # save features for temporal connection: 181 | previous_features[l] = self.conv_1x1_time(x_intm_f) 182 | flows_f[l].append(flow_f) 183 | occs_f[l].append(occ_f) 184 | flows_b[l].append(flow_b) 185 | occs_b[l].append(occ_b) 186 | flows_coarse_f[l].append(flow_cont_f) 187 | occs_coarse_f[l].append(occ_cont_f) 188 | #flows.append([flow_cont_f, flow_cont_b, flow_f, flow_b]) 189 | #occs.append([occ_cont_f, occ_cont_b, occ_f, occ_b]) 190 | 191 | else: 192 | flow_f = upsample2d_as(flow_f, x1, mode="bilinear") 193 | flow_b = upsample2d_as(flow_b, x2, mode="bilinear") 194 | flows_f[l].append(flow_f) 195 | flows_b[l].append(flow_b) 196 | #flows.append([flow_f, flow_b]) 197 | 198 | x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow) 199 | x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow) 200 | flow_b_warp = self.warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow) 201 | flow_f_warp = self.warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow) 202 | 203 | if l != self.num_levels-1: 204 | x1_in = self.conv_1x1_1(x1) 205 | x2_in = self.conv_1x1_1(x2) 206 | x1_w_in = self.conv_1x1_1(x1_warp) 207 | x2_w_in = self.conv_1x1_1(x2_warp) 208 | else: 209 | x1_in = x1 210 | x2_in = x2 211 | x1_w_in = x1_warp 212 | x2_w_in = x2_warp 213 | 214 | occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([x1_in, x2_w_in, flow_f, flow_b_warp], dim=1)) 215 | occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([x2_in, x1_w_in, flow_b, flow_f_warp], dim=1)) 216 | 217 | occs_f[l].append(occ_f) 218 | occs_b[l].append(occ_b) 219 | #occs.append([occ_f, occ_b]) 220 | 221 | flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 222 | flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 223 | occ_f = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 224 | occ_b = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 225 | 226 | if self.training: 227 | if len(list_imgs) > 2: 228 | for l in range(len(flows_f)): 229 | flows_f[l] = torch.stack(flows_f[l], 0) 230 | occs_f[l] = torch.stack(occs_f[l], 0) 231 | for l in range(len(flows_coarse_f)): 232 | flows_coarse_f[l] = torch.stack(flows_coarse_f[l], 0) 233 | occs_coarse_f[l] = torch.stack(occs_coarse_f[l], 0) 234 | else: 235 | for l in range(len(flows_f)): 236 | flows_f[l] = flows_f[l][0] 237 | occs_f[l] = occs_f[l][0] 238 | for l in range(len(flows_coarse_f)): 239 | flows_coarse_f[l] = flows_coarse_f[l][0] 240 | occs_coarse_f[l] = occs_coarse_f[l][0] 241 | output_dict['flow'] = flows_f 242 | output_dict['occ'] = occs_f 243 | output_dict['flow_coarse'] = flows_coarse_f 244 | output_dict['occ_coarse'] = occs_coarse_f 245 | return output_dict 246 | else: 247 | output_dict_eval = {} 248 | if len(list_imgs) > 2: 249 | out_flow = 
[] 250 | out_occ = [] 251 | for i in range(len(flows_f[0])): 252 | out_flow.append(upsample2d_as(flows_f[-1][i], list_imgs[0], mode="bilinear") * (1.0 / self._div_flow)) 253 | out_occ.append(upsample2d_as(occs_f[-1][i], list_imgs[0], mode="bilinear")) 254 | out_flow = torch.stack(out_flow, 0) 255 | out_occ = torch.stack(out_occ, 0) 256 | else: 257 | out_flow = upsample2d_as(flows_f[-1][0], list_imgs[0], mode="bilinear") * (1.0 / self._div_flow) 258 | out_occ = upsample2d_as(occs_f[-1][0], list_imgs[0], mode="bilinear") 259 | output_dict_eval['flow'] = out_flow 260 | output_dict_eval['occ'] = out_occ 261 | return output_dict_eval 262 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import pwcnet 2 | from . import pwcnet_irr 3 | 4 | from . import pwcnet_occ_joint 5 | from . import pwcnet_irr_occ_joint 6 | 7 | from . import tr_flow 8 | from . import tr_features 9 | 10 | from . import IRR_PWC 11 | from . import IRR_PWC_occ_joint 12 | from . import STAR 13 | 14 | PWCNet = pwcnet.PWCNet 15 | PWCNet_irr = pwcnet_irr.PWCNet 16 | PWCNet_occ_joint = pwcnet_occ_joint.PWCNet 17 | PWCNet_irr_occ_joint = pwcnet_irr_occ_joint.PWCNet 18 | 19 | TRFlow = tr_flow.TRFlow 20 | TRFlow_occjoint = tr_flow.TRFlow_occjoint 21 | TRFlow_irr = tr_flow.TRFlow_irr 22 | TRFlow_irr_occjoint = tr_flow.TRFlow_irr_occjoint 23 | 24 | TRFeat = tr_features.TRFeat 25 | TRFeat_occjoint = tr_features.TRFeat_occjoint 26 | TRFeat_irr_occjoint = tr_features.TRFeat_irr_occjoint 27 | 28 | # -- With refinement --- 29 | 30 | IRR_PWC = IRR_PWC.PWCNet 31 | IRR_occ_joint = IRR_PWC_occ_joint.PWCNet 32 | 33 | StarFlow = STAR.StarFlow 34 | -------------------------------------------------------------------------------- /models/correlation_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/models/correlation_package/__init__.py -------------------------------------------------------------------------------- /models/correlation_package/correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.module import Module 3 | from torch.autograd import Function 4 | import correlation_cuda 5 | 6 | class CorrelationFunction(Function): 7 | 8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): 9 | super(CorrelationFunction, self).__init__() 10 | self.pad_size = pad_size 11 | self.kernel_size = kernel_size 12 | self.max_displacement = max_displacement 13 | self.stride1 = stride1 14 | self.stride2 = stride2 15 | self.corr_multiply = corr_multiply 16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) 17 | 18 | def forward(self, input1, input2): 19 | self.save_for_backward(input1, input2) 20 | 21 | with torch.cuda.device_of(input1): 22 | rbot1 = input1.new() 23 | rbot2 = input2.new() 24 | output = input1.new() 25 | 26 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output, 27 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) 28 | 29 | return output 30 | 31 | def backward(self, grad_output): 32 | input1, input2 = self.saved_tensors 33 | 34 | with torch.cuda.device_of(input1): 35 | rbot1 = input1.new() 36 | rbot2 = input2.new() 37 | 38 | 
grad_input1 = input1.new() 39 | grad_input2 = input2.new() 40 | 41 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, 42 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) 43 | 44 | return grad_input1, grad_input2 45 | 46 | 47 | class Correlation(Module): 48 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): 49 | super(Correlation, self).__init__() 50 | self.pad_size = pad_size 51 | self.kernel_size = kernel_size 52 | self.max_displacement = max_displacement 53 | self.stride1 = stride1 54 | self.stride2 = stride2 55 | self.corr_multiply = corr_multiply 56 | 57 | def forward(self, input1, input2): 58 | 59 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)(input1, input2) 60 | 61 | return result 62 | 63 | -------------------------------------------------------------------------------- /models/correlation_package/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "correlation_cuda_kernel.cuh" 9 | 10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 11 | int pad_size, 12 | int kernel_size, 13 | int max_displacement, 14 | int stride1, 15 | int stride2, 16 | int corr_type_multiply) 17 | { 18 | 19 | int batchSize = input1.size(0); 20 | 21 | int nInputChannels = input1.size(1); 22 | int inputHeight = input1.size(2); 23 | int inputWidth = input1.size(3); 24 | 25 | int kernel_radius = (kernel_size - 1) / 2; 26 | int border_radius = kernel_radius + max_displacement; 27 | 28 | int paddedInputHeight = inputHeight + 2 * pad_size; 29 | int paddedInputWidth = inputWidth + 2 * pad_size; 30 | 31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 32 | 33 | int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); 34 | int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); 35 | 36 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 39 | 40 | rInput1.fill_(0); 41 | rInput2.fill_(0); 42 | output.fill_(0); 43 | 44 | int success = correlation_forward_cuda_kernel( 45 | output, 46 | output.size(0), 47 | output.size(1), 48 | output.size(2), 49 | output.size(3), 50 | output.stride(0), 51 | output.stride(1), 52 | output.stride(2), 53 | output.stride(3), 54 | input1, 55 | input1.size(1), 56 | input1.size(2), 57 | input1.size(3), 58 | input1.stride(0), 59 | input1.stride(1), 60 | input1.stride(2), 61 | input1.stride(3), 62 | input2, 63 | input2.size(1), 64 | input2.stride(0), 65 | input2.stride(1), 66 | input2.stride(2), 67 | input2.stride(3), 68 | rInput1, 69 | rInput2, 70 | pad_size, 71 | kernel_size, 72 | max_displacement, 73 | stride1, 74 | stride2, 75 | corr_type_multiply, 76 | at::cuda::getCurrentCUDAStream() 77 | //at::globalContext().getCurrentCUDAStream() 78 | ); 79 | 80 | //check for errors 81 | if (!success) { 82 | AT_ERROR("CUDA call failed"); 83 | } 84 | 85 | return 1; 86 | 87 | } 88 | 89 | int 
correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 90 | at::Tensor& gradInput1, at::Tensor& gradInput2, 91 | int pad_size, 92 | int kernel_size, 93 | int max_displacement, 94 | int stride1, 95 | int stride2, 96 | int corr_type_multiply) 97 | { 98 | 99 | int batchSize = input1.size(0); 100 | int nInputChannels = input1.size(1); 101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 103 | 104 | int height = input1.size(2); 105 | int width = input1.size(3); 106 | 107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 109 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 110 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 111 | 112 | rInput1.fill_(0); 113 | rInput2.fill_(0); 114 | gradInput1.fill_(0); 115 | gradInput2.fill_(0); 116 | 117 | int success = correlation_backward_cuda_kernel(gradOutput, 118 | gradOutput.size(0), 119 | gradOutput.size(1), 120 | gradOutput.size(2), 121 | gradOutput.size(3), 122 | gradOutput.stride(0), 123 | gradOutput.stride(1), 124 | gradOutput.stride(2), 125 | gradOutput.stride(3), 126 | input1, 127 | input1.size(1), 128 | input1.size(2), 129 | input1.size(3), 130 | input1.stride(0), 131 | input1.stride(1), 132 | input1.stride(2), 133 | input1.stride(3), 134 | input2, 135 | input2.stride(0), 136 | input2.stride(1), 137 | input2.stride(2), 138 | input2.stride(3), 139 | gradInput1, 140 | gradInput1.stride(0), 141 | gradInput1.stride(1), 142 | gradInput1.stride(2), 143 | gradInput1.stride(3), 144 | gradInput2, 145 | gradInput2.size(1), 146 | gradInput2.stride(0), 147 | gradInput2.stride(1), 148 | gradInput2.stride(2), 149 | gradInput2.stride(3), 150 | rInput1, 151 | rInput2, 152 | pad_size, 153 | kernel_size, 154 | max_displacement, 155 | stride1, 156 | stride2, 157 | corr_type_multiply, 158 | at::cuda::getCurrentCUDAStream() 159 | //at::globalContext().getCurrentCUDAStream() 160 | ); 161 | 162 | if (!success) { 163 | AT_ERROR("CUDA call failed"); 164 | } 165 | 166 | return 1; 167 | } 168 | 169 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 170 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 171 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 172 | } 173 | 174 | -------------------------------------------------------------------------------- /models/correlation_package/correlation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | 3 | #include "correlation_cuda_kernel.cuh" 4 | 5 | #define CUDA_NUM_THREADS 1024 6 | #define THREADS_PER_BLOCK 32 7 | 8 | #include <ATen/ATen.h> 9 | #include <ATen/NativeFunctions.h> 10 | #include <ATen/Dispatch.h> 11 | #include <ATen/cuda/CUDAApplyUtils.cuh> 12 | 13 | using at::Half; 14 | 15 | template <typename scalar_t> 16 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) 17 | { 18 | 19 | // n (batch size), c (num of channels), y (height), x (width) 20 | int n = blockIdx.x; 21 | int y = blockIdx.y; 22 | int x = blockIdx.z; 23 | 24 | int ch_off = threadIdx.x; 25 | scalar_t value; 26 | 27 | int dimcyx = channels * height * width; 28 | int dimyx = height * width; 29 | 30 | int p_dimx = (width + 2 * pad_size); 31 | int p_dimy = (height + 2 * pad_size); 32 | int p_dimyxc = channels * p_dimy * p_dimx; 33 | int p_dimxc = p_dimx * channels; 34 | 35 |
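  // Each block handles one (batch, y, x) position; its THREADS_PER_BLOCK threads
  // stride over the channel axis, copying that pixel's channel vector from the
  // NCHW input into the zero-padded NHWC buffer rinput, so the correlation
  // kernels below can read channel-contiguous data.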
for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { 36 | value = input[n * dimcyx + c * dimyx + y * width + x]; 37 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; 38 | } 39 | } 40 | 41 | template <typename scalar_t> 42 | __global__ void correlation_forward(scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth, 43 | const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth, 44 | const scalar_t* __restrict__ rInput2, 45 | int pad_size, 46 | int kernel_size, 47 | int max_displacement, 48 | int stride1, 49 | int stride2) 50 | { 51 | // n (batch size), c (num of channels), y (height), x (width) 52 | 53 | int pInputWidth = inputWidth + 2 * pad_size; 54 | int pInputHeight = inputHeight + 2 * pad_size; 55 | 56 | int kernel_rad = (kernel_size - 1) / 2; 57 | int displacement_rad = max_displacement / stride2; 58 | int displacement_size = 2 * displacement_rad + 1; 59 | 60 | int n = blockIdx.x; 61 | int y1 = blockIdx.y * stride1 + max_displacement; 62 | int x1 = blockIdx.z * stride1 + max_displacement; 63 | int c = threadIdx.x; 64 | 65 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 66 | int pdimxc = pInputWidth * nInputChannels; 67 | int pdimc = nInputChannels; 68 | 69 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 70 | int tdimyx = outputHeight * outputWidth; 71 | int tdimx = outputWidth; 72 | 73 | scalar_t nelems = kernel_size * kernel_size * pdimc; 74 | 75 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 76 | 77 | // no significant speed-up in using chip memory for input1 sub-data, 78 | // not enough chip memory size to accommodate memory per block for input2 sub-data 79 | // instead I've used device memory for both 80 | 81 | // element-wise product along channel axis 82 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { 83 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { 84 | prod_sum[c] = 0; 85 | int x2 = x1 + ti*stride2; 86 | int y2 = y1 + tj*stride2; 87 | 88 | for (int j = -kernel_rad; j <= kernel_rad; ++j) { 89 | for (int i = -kernel_rad; i <= kernel_rad; ++i) { 90 | for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) { 91 | int indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch; 92 | int indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch; 93 | 94 | prod_sum[c] += rInput1[indx1] * rInput2[indx2]; 95 | } 96 | } 97 | } 98 | 99 | // accumulate 100 | __syncthreads(); 101 | if (c == 0) { 102 | scalar_t reduce_sum = 0; 103 | for (int index = 0; index < THREADS_PER_BLOCK; ++index) { 104 | reduce_sum += prod_sum[index]; 105 | } 106 | int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad); 107 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z; 108 | output[tindx] = reduce_sum / nelems; 109 | } 110 | 111 | } 112 | } 113 | 114 | } 115 | 116 | template <typename scalar_t> 117 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, 118 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 119 | const scalar_t* __restrict__ rInput2, 120 | int pad_size, 121 | int kernel_size, 122 | int max_displacement, 123 | int stride1, 124 | int stride2) 125 | { 126 | // n (batch size), c (num of channels), y (height), x (width) 127 | 128 | int n = item; 129 | int y = blockIdx.x * stride1 + pad_size; 130 | int x = blockIdx.y * stride1 + pad_size; 131 | int c = blockIdx.z; 132 | int tch_off =
threadIdx.x; 133 | 134 | int kernel_rad = (kernel_size - 1) / 2; 135 | int displacement_rad = max_displacement / stride2; 136 | int displacement_size = 2 * displacement_rad + 1; 137 | 138 | int xmin = (x - kernel_rad - max_displacement) / stride1; 139 | int ymin = (y - kernel_rad - max_displacement) / stride1; 140 | 141 | int xmax = (x + kernel_rad - max_displacement) / stride1; 142 | int ymax = (y + kernel_rad - max_displacement) / stride1; 143 | 144 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 145 | // assumes gradInput1 is pre-allocated and zero filled 146 | return; 147 | } 148 | 149 | if (xmin > xmax || ymin > ymax) { 150 | // assumes gradInput1 is pre-allocated and zero filled 151 | return; 152 | } 153 | 154 | xmin = max(0, xmin); 155 | xmax = min(outputWidth - 1, xmax); 156 | 157 | ymin = max(0, ymin); 158 | ymax = min(outputHeight - 1, ymax); 159 | 160 | int pInputWidth = inputWidth + 2 * pad_size; 161 | int pInputHeight = inputHeight + 2 * pad_size; 162 | 163 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 164 | int pdimxc = pInputWidth * nInputChannels; 165 | int pdimc = nInputChannels; 166 | 167 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 168 | int tdimyx = outputHeight * outputWidth; 169 | int tdimx = outputWidth; 170 | 171 | int odimcyx = nInputChannels * inputHeight* inputWidth; 172 | int odimyx = inputHeight * inputWidth; 173 | int odimx = inputWidth; 174 | 175 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 176 | 177 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 178 | prod_sum[tch_off] = 0; 179 | 180 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 181 | 182 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 183 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 184 | 185 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; 186 | 187 | scalar_t val2 = rInput2[indx2]; 188 | 189 | for (int j = ymin; j <= ymax; ++j) { 190 | for (int i = xmin; i <= xmax; ++i) { 191 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 192 | prod_sum[tch_off] += gradOutput[tindx] * val2; 193 | } 194 | } 195 | } 196 | __syncthreads(); 197 | 198 | if (tch_off == 0) { 199 | scalar_t reduce_sum = 0; 200 | for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 201 | reduce_sum += prod_sum[idx]; 202 | } 203 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 204 | gradInput1[indx1] = reduce_sum / nelems; 205 | } 206 | 207 | } 208 | 209 | template <typename scalar_t> 210 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, 211 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 212 | const scalar_t* __restrict__ rInput1, 213 | int pad_size, 214 | int kernel_size, 215 | int max_displacement, 216 | int stride1, 217 | int stride2) 218 | { 219 | // n (batch size), c (num of channels), y (height), x (width) 220 | 221 | int n = item; 222 | int y = blockIdx.x * stride1 + pad_size; 223 | int x = blockIdx.y * stride1 + pad_size; 224 | int c = blockIdx.z; 225 | 226 | int tch_off = threadIdx.x; 227 | 228 | int kernel_rad = (kernel_size - 1) / 2; 229 | int displacement_rad = max_displacement / stride2; 230 | int displacement_size = 2 * displacement_rad + 1; 231 | 232 | int pInputWidth = inputWidth + 2 * pad_size; 233 | int pInputHeight = inputHeight + 2 * pad_size; 234 | 235 | int pdimyxc = pInputHeight * pInputWidth *
nInputChannels; 236 | int pdimxc = pInputWidth * nInputChannels; 237 | int pdimc = nInputChannels; 238 | 239 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 240 | int tdimyx = outputHeight * outputWidth; 241 | int tdimx = outputWidth; 242 | 243 | int odimcyx = nInputChannels * inputHeight* inputWidth; 244 | int odimyx = inputHeight * inputWidth; 245 | int odimx = inputWidth; 246 | 247 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 248 | 249 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 250 | prod_sum[tch_off] = 0; 251 | 252 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 253 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 254 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 255 | 256 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1; 257 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1; 258 | 259 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1; 260 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1; 261 | 262 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 263 | // assumes gradInput2 is pre-allocated and zero filled 264 | continue; 265 | } 266 | 267 | if (xmin > xmax || ymin > ymax) { 268 | // assumes gradInput2 is pre-allocated and zero filled 269 | continue; 270 | } 271 | 272 | xmin = max(0, xmin); 273 | xmax = min(outputWidth - 1, xmax); 274 | 275 | ymin = max(0, ymin); 276 | ymax = min(outputHeight - 1, ymax); 277 | 278 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; 279 | scalar_t val1 = rInput1[indx1]; 280 | 281 | for (int j = ymin; j <= ymax; ++j) { 282 | for (int i = xmin; i <= xmax; ++i) { 283 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 284 | prod_sum[tch_off] += gradOutput[tindx] * val1; 285 | } 286 | } 287 | } 288 | 289 | __syncthreads(); 290 | 291 | if (tch_off == 0) { 292 | scalar_t reduce_sum = 0; 293 | for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 294 | reduce_sum += prod_sum[idx]; 295 | } 296 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 297 | gradInput2[indx2] = reduce_sum / nelems; 298 | } 299 | 300 | } 301 | 302 | int correlation_forward_cuda_kernel(at::Tensor& output, 303 | int ob, 304 | int oc, 305 | int oh, 306 | int ow, 307 | int osb, 308 | int osc, 309 | int osh, 310 | int osw, 311 | 312 | at::Tensor& input1, 313 | int ic, 314 | int ih, 315 | int iw, 316 | int isb, 317 | int isc, 318 | int ish, 319 | int isw, 320 | 321 | at::Tensor& input2, 322 | int gc, 323 | int gsb, 324 | int gsc, 325 | int gsh, 326 | int gsw, 327 | 328 | at::Tensor& rInput1, 329 | at::Tensor& rInput2, 330 | int pad_size, 331 | int kernel_size, 332 | int max_displacement, 333 | int stride1, 334 | int stride2, 335 | int corr_type_multiply, 336 | cudaStream_t stream) 337 | { 338 | 339 | int batchSize = ob; 340 | 341 | int nInputChannels = ic; 342 | int inputWidth = iw; 343 | int inputHeight = ih; 344 | 345 | int nOutputChannels = oc; 346 | int outputWidth = ow; 347 | int outputHeight = oh; 348 | 349 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 350 | dim3 threads_block(THREADS_PER_BLOCK); 351 | 352 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { 353 | 354 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>( 355 | input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size); 356 | 357 | })); 358 | 359 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { 360 | 361 |
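  // Same NHWC repacking for the second input; the correlation_forward launch
  // below then reads both padded buffers.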
channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>> ( 362 | input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size); 363 | 364 | })); 365 | 366 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 367 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); 368 | 369 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { 370 | 371 | correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> 372 | (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth, 373 | rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, 374 | rInput2.data<scalar_t>(), 375 | pad_size, 376 | kernel_size, 377 | max_displacement, 378 | stride1, 379 | stride2); 380 | 381 | })); 382 | 383 | cudaError_t err = cudaGetLastError(); 384 | 385 | 386 | // check for errors 387 | if (err != cudaSuccess) { 388 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); 389 | return 0; 390 | } 391 | 392 | return 1; 393 | } 394 | 395 | 396 | int correlation_backward_cuda_kernel( 397 | at::Tensor& gradOutput, 398 | int gob, 399 | int goc, 400 | int goh, 401 | int gow, 402 | int gosb, 403 | int gosc, 404 | int gosh, 405 | int gosw, 406 | 407 | at::Tensor& input1, 408 | int ic, 409 | int ih, 410 | int iw, 411 | int isb, 412 | int isc, 413 | int ish, 414 | int isw, 415 | 416 | at::Tensor& input2, 417 | int gsb, 418 | int gsc, 419 | int gsh, 420 | int gsw, 421 | 422 | at::Tensor& gradInput1, 423 | int gisb, 424 | int gisc, 425 | int gish, 426 | int gisw, 427 | 428 | at::Tensor& gradInput2, 429 | int ggc, 430 | int ggsb, 431 | int ggsc, 432 | int ggsh, 433 | int ggsw, 434 | 435 | at::Tensor& rInput1, 436 | at::Tensor& rInput2, 437 | int pad_size, 438 | int kernel_size, 439 | int max_displacement, 440 | int stride1, 441 | int stride2, 442 | int corr_type_multiply, 443 | cudaStream_t stream) 444 | { 445 | 446 | int batchSize = gob; 447 | int num = batchSize; 448 | 449 | int nInputChannels = ic; 450 | int inputWidth = iw; 451 | int inputHeight = ih; 452 | 453 | int nOutputChannels = goc; 454 | int outputWidth = gow; 455 | int outputHeight = goh; 456 | 457 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 458 | dim3 threads_block(THREADS_PER_BLOCK); 459 | 460 | 461 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { 462 | 463 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>( 464 | input1.data<scalar_t>(), 465 | rInput1.data<scalar_t>(), 466 | nInputChannels, 467 | inputHeight, 468 | inputWidth, 469 | pad_size 470 | ); 471 | })); 472 | 473 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 474 | 475 | channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>( 476 | input2.data<scalar_t>(), 477 | rInput2.data<scalar_t>(), 478 | nInputChannels, 479 | inputHeight, 480 | inputWidth, 481 | pad_size 482 | ); 483 | })); 484 | 485 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 486 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); 487 | 488 | for (int n = 0; n < num; ++n) { 489 | 490 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 491 | 492 | 493 | correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> ( 494 | n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, 495 | gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth, 496 | rInput2.data<scalar_t>(), 497 | pad_size, 498 | kernel_size, 499 | max_displacement, 500 | stride1, 501 | stride2); 502 | })); 503 | } 504 | 505 | for (int n = 0; n < batchSize; n++) { 506 | 507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { 508 | 509 | correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>( 510 | n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, 511 | gradOutput.data<scalar_t>(),
nOutputChannels, outputHeight, outputWidth, 512 | rInput1.data<scalar_t>(), 513 | pad_size, 514 | kernel_size, 515 | max_displacement, 516 | stride1, 517 | stride2); 518 | 519 | })); 520 | } 521 | 522 | // check for errors 523 | cudaError_t err = cudaGetLastError(); 524 | if (err != cudaSuccess) { 525 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); 526 | return 0; 527 | } 528 | 529 | return 1; 530 | } 531 | -------------------------------------------------------------------------------- /models/correlation_package/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <ATen/ATen.h> 4 | #include <ATen/NativeFunctions.h> 5 | #include <ATen/Dispatch.h> 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /models/correlation_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | 4 | os.environ["CC"] = "gcc" 5 | os.environ["CXX"] = "gcc" 6 | 7 | import torch 8 | 9 | from setuptools import setup, find_packages 10 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 11 | 12 | cxx_args = ['-std=c++11', 13 | '-D_GLIBCXX_USE_CXX11_ABI=1',  # must match torch.compiled_with_cxx11_abi() of the installed build 14 | ] 15 | 16 | nvcc_args = [ 17 | '-gencode', 'arch=compute_50,code=sm_50', 18 | '-gencode', 'arch=compute_52,code=sm_52', 19 | '-gencode', 'arch=compute_60,code=sm_60', 20 | '-gencode', 'arch=compute_61,code=sm_61', 21 | '-gencode', 'arch=compute_70,code=sm_70', 22 | '-gencode', 'arch=compute_75,code=sm_75', 23 | '-gencode', 'arch=compute_75,code=compute_75', 24 | '-ccbin', '/usr/bin/gcc' 25 | ] 26 | 27 | # '-ccbin', '/usr/bin/gcc-5' 28 | 29 | #nvcc_args = [ 30 | # '-gencode', 'arch=compute_50,code=sm_50', 31 | # '-gencode', 'arch=compute_52,code=sm_52', 32 | # '-gencode', 'arch=compute_60,code=sm_60', 33 | # '-gencode', 'arch=compute_61,code=sm_61', 34 | # '-gencode', 'arch=compute_70,code=sm_70', 35 | # '-gencode', 'arch=compute_70,code=compute_70' 36 | #] 37 | 38 | setup( 39 | name='correlation_cuda', 40 | ext_modules=[ 41 |
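        # Compiles the .cc binding and the .cu kernels into the `correlation_cuda`
        # module imported by correlation.py; running `python setup.py install` from
        # this directory (with a CUDA toolchain available) builds it for the
        # architectures listed in nvcc_args above.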
CUDAExtension('correlation_cuda', [ 42 | 'correlation_cuda.cc', 43 | 'correlation_cuda_kernel.cu' 44 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})  # only 'cxx'/'nvcc' keys are honored; the CUDA toolkit location is taken from the CUDA_HOME environment variable 45 | ], 46 | cmdclass={ 47 | 'build_ext': BuildExtension 48 | }) 49 | -------------------------------------------------------------------------------- /models/irr_modules.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as tf 6 | 7 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True): 8 | if isReLU: 9 | return nn.Sequential( 10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 11 | padding=((kernel_size - 1) * dilation) // 2, bias=True), 12 | nn.LeakyReLU(0.1, inplace=True) 13 | ) 14 | else: 15 | return nn.Sequential( 16 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 17 | padding=((kernel_size - 1) * dilation) // 2, bias=True) 18 | ) 19 | 20 | 21 | def upsample_factor2(inputs, target_as): 22 | inputs = tf.interpolate(inputs, scale_factor=2, mode="nearest") 23 | _, _, h, w = target_as.size() 24 | if inputs.size(2) != h or inputs.size(3) != w: 25 | return tf.interpolate(inputs, [h, w], mode="bilinear", align_corners=False) 26 | else: 27 | return inputs 28 | 29 | 30 | class OccUpsampleNetwork(nn.Module): 31 | def __init__(self, ch_in, ch_out): 32 | super(OccUpsampleNetwork, self).__init__() 33 | 34 | self.feat_dim = 32 35 | self.init_conv = conv(ch_in, self.feat_dim) 36 | 37 | self.res_convs = nn.Sequential( 38 | conv(self.feat_dim, self.feat_dim), 39 | conv(self.feat_dim, self.feat_dim, isReLU=False) 40 | ) 41 | self.res_end_conv = conv(self.feat_dim, self.feat_dim) 42 | self.mul_const = 0.1 43 | 44 | self.out_convs = conv(self.feat_dim, ch_out) 45 | 46 | def forward(self, occ, x): 47 | occ = upsample_factor2(occ, x) 48 | x_in = torch.cat([occ, x], dim=1) 49 | x_init = self.init_conv(x_in) 50 | x_res = x_init 51 | x_res = x_res + self.res_convs(x_res) * self.mul_const 52 | x_res = x_res + self.res_convs(x_res) * self.mul_const 53 | x_res = x_res + self.res_convs(x_res) * self.mul_const 54 | x_init = x_init + self.res_end_conv(x_res) 55 | 56 | return self.out_convs(x_init) + occ 57 | 58 | 59 | def subtract_mean(input): 60 | return input - input.mean(2).mean(2).unsqueeze(2).unsqueeze(2).expand_as(input) 61 | 62 | 63 | class RefineFlow(nn.Module): 64 | def __init__(self, ch_in): 65 | super(RefineFlow, self).__init__() 66 | 67 | self.kernel_size = 3 68 | self.pad_size = 1 69 | self.pad_ftn = nn.ReplicationPad2d(self.pad_size) 70 | 71 | self.convs = nn.Sequential( 72 | conv(ch_in, 128, 3, 1, 1), 73 | conv(128, 128, 3, 1, 1), 74 | conv(128, 64, 3, 1, 1), 75 | conv(64, 64, 3, 1, 1), 76 | conv(64, 32, 3, 1, 1), 77 | conv(32, 32, 3, 1, 1), 78 | conv(32, self.kernel_size * self.kernel_size, 3, 1, 1) 79 | ) 80 | 81 | self.softmax_feat = nn.Softmax(dim=1) 82 | self.unfold_flow = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size)) 83 | self.unfold_kernel = nn.Unfold(kernel_size=(1, 1)) 84 | 85 | def forward(self, flow, diff_img, feature): 86 | b, _, h, w = flow.size() 87 | 88 | flow_m = subtract_mean(flow) 89 | norm2_img = torch.norm(diff_img, p=2, dim=1, keepdim=True) 90 | 91 | feat = self.convs(torch.cat([flow_m, norm2_img, feature], dim=1)) 92 | feat_kernel = self.softmax_feat(-feat ** 2) 93 | 94 |
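        # feat_kernel holds, per pixel, a 3x3 grid of softmax weights. Below,
        # nn.Unfold gathers the 3x3 neighbourhood of each flow component so the
        # weighted sum implements a per-pixel adaptive filtering of the flow field.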
flow_x = flow[:, 0].unsqueeze(1) 95 | flow_y = flow[:, 1].unsqueeze(1) 96 | 97 | flow_x_unfold = self.unfold_flow(self.pad_ftn(flow_x)) 98 | flow_y_unfold = self.unfold_flow(self.pad_ftn(flow_y)) 99 | feat_kernel_unfold = self.unfold_kernel(feat_kernel) 100 | 101 | flow_out_x = torch.sum(flow_x_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w) 102 | flow_out_y = torch.sum(flow_y_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w) 103 | 104 | return torch.cat([flow_out_x, flow_out_y], dim=1) 105 | 106 | 107 | class RefineOcc(nn.Module): 108 | def __init__(self, ch_in): 109 | super(RefineOcc, self).__init__() 110 | 111 | self.kernel_size = 3 112 | self.pad_size = 1 113 | self.pad_ftn = nn.ReplicationPad2d(self.pad_size) 114 | 115 | self.convs = nn.Sequential( 116 | conv(ch_in, 128, 3, 1, 1), 117 | conv(128, 128, 3, 1, 1), 118 | conv(128, 64, 3, 1, 1), 119 | conv(64, 64, 3, 1, 1), 120 | conv(64, 32, 3, 1, 1), 121 | conv(32, 32, 3, 1, 1), 122 | conv(32, self.kernel_size * self.kernel_size, 3, 1, 1) 123 | ) 124 | 125 | self.softmax_feat = nn.Softmax(dim=1) 126 | self.unfold_occ = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size)) 127 | self.unfold_kernel = nn.Unfold(kernel_size=(1, 1)) 128 | 129 | def forward(self, occ, feat1, feat2): 130 | b, _, h, w = occ.size() 131 | 132 | feat = self.convs(torch.cat([occ, feat1, feat2], dim=1)) 133 | feat_kernel = self.softmax_feat(-feat ** 2) 134 | 135 | occ_unfold = self.unfold_occ(self.pad_ftn(occ)) 136 | feat_kernel_unfold = self.unfold_kernel(feat_kernel) 137 | 138 | occ_out = torch.sum(occ_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w) 139 | 140 | return occ_out -------------------------------------------------------------------------------- /models/pwc_modules.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as tf 6 | import logging 7 | 8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True): 9 | if isReLU: 10 | return nn.Sequential( 11 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 12 | padding=((kernel_size - 1) * dilation) // 2, bias=True), 13 | nn.LeakyReLU(0.1, inplace=True) 14 | ) 15 | else: 16 | return nn.Sequential( 17 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 18 | padding=((kernel_size - 1) * dilation) // 2, bias=True) 19 | ) 20 | 21 | 22 | def initialize_msra(modules): 23 | logging.info("Initializing MSRA") 24 | for layer in modules: 25 | if isinstance(layer, nn.Conv2d): 26 | nn.init.kaiming_normal_(layer.weight) 27 | if layer.bias is not None: 28 | nn.init.constant_(layer.bias, 0) 29 | 30 | elif isinstance(layer, nn.ConvTranspose2d): 31 | nn.init.kaiming_normal_(layer.weight) 32 | if layer.bias is not None: 33 | nn.init.constant_(layer.bias, 0) 34 | 35 | elif isinstance(layer, nn.LeakyReLU): 36 | pass 37 | 38 | elif isinstance(layer, nn.Sequential): 39 | pass 40 | 41 | 42 | def upsample2d_as(inputs, target_as, mode="bilinear"): 43 | _, _, h, w = target_as.size() 44 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True) 45 | 46 | 47 | def rescale_flow(flow, div_flow, width_im, height_im, to_local=True): 48 | if to_local: 49 | u_scale = float(flow.size(3) / width_im / div_flow) 50 | v_scale = float(flow.size(2) / height_im / div_flow) 51 | else: 52 | u_scale = 
float(width_im * div_flow / flow.size(3)) 53 | v_scale = float(height_im * div_flow / flow.size(2)) 54 | 55 | u, v = flow.chunk(2, dim=1) 56 | u *= u_scale 57 | v *= v_scale 58 | 59 | return torch.cat([u, v], dim=1) 60 | 61 | 62 | class FeatureExtractor(nn.Module): 63 | def __init__(self, num_chs): 64 | super(FeatureExtractor, self).__init__() 65 | self.num_chs = num_chs 66 | self.convs = nn.ModuleList() 67 | 68 | for l, (ch_in, ch_out) in enumerate(zip(num_chs[:-1], num_chs[1:])): 69 | layer = nn.Sequential( 70 | conv(ch_in, ch_out, stride=2), 71 | conv(ch_out, ch_out) 72 | ) 73 | self.convs.append(layer) 74 | 75 | def forward(self, x): 76 | feature_pyramid = [] 77 | for conv in self.convs: 78 | x = conv(x) 79 | feature_pyramid.append(x) 80 | 81 | return feature_pyramid[::-1] 82 | 83 | 84 | def get_grid(x): 85 | grid_H = torch.linspace(-1.0, 1.0, x.size(3)).view(1, 1, 1, x.size(3)).expand(x.size(0), 1, x.size(2), x.size(3)) 86 | grid_V = torch.linspace(-1.0, 1.0, x.size(2)).view(1, 1, x.size(2), 1).expand(x.size(0), 1, x.size(2), x.size(3)) 87 | grid = torch.cat([grid_H, grid_V], 1) 88 | grids_cuda = grid.float().requires_grad_(False) 89 | if x.is_cuda: 90 | grids_cuda = grids_cuda.cuda() 91 | return grids_cuda 92 | 93 | 94 | class WarpingLayer(nn.Module): 95 | def __init__(self): 96 | super(WarpingLayer, self).__init__() 97 | 98 | def forward(self, x, flow, height_im, width_im, div_flow): 99 | flo_list = [] 100 | flo_w = flow[:, 0] * 2 / max(width_im - 1, 1) / div_flow 101 | flo_h = flow[:, 1] * 2 / max(height_im - 1, 1) / div_flow 102 | flo_list.append(flo_w) 103 | flo_list.append(flo_h) 104 | flow_for_grid = torch.stack(flo_list).transpose(0, 1) 105 | grid = torch.add(get_grid(x), flow_for_grid).transpose(1, 2).transpose(2, 3) 106 | x_warp = tf.grid_sample(x, grid) 107 | 108 | mask = torch.ones(x.size(), requires_grad=False) 109 | if x.is_cuda: 110 | mask = mask.cuda() 111 | mask = tf.grid_sample(mask, grid) 112 | mask = (mask >= 1.0).float() 113 | 114 | return x_warp * mask 115 | 116 | class OpticalFlowEstimator(nn.Module): 117 | def __init__(self, ch_in): 118 | super(OpticalFlowEstimator, self).__init__() 119 | 120 | self.convs = nn.Sequential( 121 | conv(ch_in, 128), 122 | conv(128, 128), 123 | conv(128, 96), 124 | conv(96, 64), 125 | conv(64, 32) 126 | ) 127 | self.conv_last = conv(32, 2, isReLU=False) 128 | 129 | def forward(self, x): 130 | x_intm = self.convs(x) 131 | return x_intm, self.conv_last(x_intm) 132 | 133 | 134 | class FlowEstimatorDense(nn.Module): 135 | def __init__(self, ch_in): 136 | super(FlowEstimatorDense, self).__init__() 137 | self.conv1 = conv(ch_in, 128) 138 | self.conv2 = conv(ch_in + 128, 128) 139 | self.conv3 = conv(ch_in + 256, 96) 140 | self.conv4 = conv(ch_in + 352, 64) 141 | self.conv5 = conv(ch_in + 416, 32) 142 | self.conv_last = conv(ch_in + 448, 2, isReLU=False) 143 | 144 | def forward(self, x): 145 | x1 = torch.cat([self.conv1(x), x], dim=1) 146 | x2 = torch.cat([self.conv2(x1), x1], dim=1) 147 | x3 = torch.cat([self.conv3(x2), x2], dim=1) 148 | x4 = torch.cat([self.conv4(x3), x3], dim=1) 149 | x5 = torch.cat([self.conv5(x4), x4], dim=1) 150 | x_out = self.conv_last(x5) 151 | return x5, x_out 152 | 153 | class OcclusionEstimator(nn.Module): 154 | def __init__(self, ch_in): 155 | super(OcclusionEstimator, self).__init__() 156 | self.convs = nn.Sequential( 157 | conv(ch_in, 128), 158 | conv(128, 128), 159 | conv(128, 96), 160 | conv(96, 64), 161 | conv(64, 32) 162 | ) 163 | self.conv_last = conv(32, 1, isReLU=False) 164 | 165 | def forward(self, x): 
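        # Returns the 32-channel intermediate features together with a
        # single-channel occlusion logit map.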
166 | x_intm = self.convs(x) 167 | return x_intm, self.conv_last(x_intm) 168 | 169 | 170 | class OccEstimatorDense(nn.Module): 171 | def __init__(self, ch_in): 172 | super(OccEstimatorDense, self).__init__() 173 | self.conv1 = conv(ch_in, 128) 174 | self.conv2 = conv(ch_in + 128, 128) 175 | self.conv3 = conv(ch_in + 256, 96) 176 | self.conv4 = conv(ch_in + 352, 64) 177 | self.conv5 = conv(ch_in + 416, 32) 178 | self.conv_last = conv(ch_in + 448, 1, isReLU=False) 179 | 180 | def forward(self, x): 181 | x1 = torch.cat([self.conv1(x), x], dim=1) 182 | x2 = torch.cat([self.conv2(x1), x1], dim=1) 183 | x3 = torch.cat([self.conv3(x2), x2], dim=1) 184 | x4 = torch.cat([self.conv4(x3), x3], dim=1) 185 | x5 = torch.cat([self.conv5(x4), x4], dim=1) 186 | x_out = self.conv_last(x5) 187 | return x5, x_out 188 | 189 | 190 | class ContextNetwork(nn.Module): 191 | def __init__(self, ch_in): 192 | super(ContextNetwork, self).__init__() 193 | 194 | self.convs = nn.Sequential( 195 | conv(ch_in, 128, 3, 1, 1), 196 | conv(128, 128, 3, 1, 2), 197 | conv(128, 128, 3, 1, 4), 198 | conv(128, 96, 3, 1, 8), 199 | conv(96, 64, 3, 1, 16), 200 | conv(64, 32, 3, 1, 1), 201 | conv(32, 2, isReLU=False) 202 | ) 203 | 204 | def forward(self, x): 205 | return self.convs(x) 206 | 207 | 208 | class OccContextNetwork(nn.Module): 209 | def __init__(self, ch_in): 210 | super(OccContextNetwork, self).__init__() 211 | 212 | self.convs = nn.Sequential( 213 | conv(ch_in, 128, 3, 1, 1), 214 | conv(128, 128, 3, 1, 2), 215 | conv(128, 128, 3, 1, 4), 216 | conv(128, 96, 3, 1, 8), 217 | conv(96, 64, 3, 1, 16), 218 | conv(64, 32, 3, 1, 1), 219 | conv(32, 1, isReLU=False) 220 | ) 221 | 222 | def forward(self, x): 223 | return self.convs(x) 224 | 225 | # ------------------------------------------- 226 | 227 | class FlowAndOccEstimatorDense(nn.Module): 228 | def __init__(self, ch_in): 229 | super(FlowAndOccEstimatorDense, self).__init__() 230 | self.conv1 = conv(ch_in, 128) 231 | self.conv2 = conv(ch_in + 128, 128) 232 | self.conv3 = conv(ch_in + 256, 96) 233 | self.conv4 = conv(ch_in + 352, 64) 234 | self.conv5 = conv(ch_in + 416, 32) 235 | self.conv_last = conv(ch_in + 448, 3, isReLU=False) 236 | 237 | def forward(self, x): 238 | x1 = torch.cat([self.conv1(x), x], dim=1) 239 | x2 = torch.cat([self.conv2(x1), x1], dim=1) 240 | x3 = torch.cat([self.conv3(x2), x2], dim=1) 241 | x4 = torch.cat([self.conv4(x3), x3], dim=1) 242 | x5 = torch.cat([self.conv5(x4), x4], dim=1) 243 | x_out = self.conv_last(x5) 244 | return x5, x_out[:,:2,:,:], x_out[:,2,:,:].unsqueeze(1) 245 | 246 | 247 | class FlowAndOccContextNetwork(nn.Module): 248 | def __init__(self, ch_in): 249 | super(FlowAndOccContextNetwork, self).__init__() 250 | 251 | self.convs = nn.Sequential( 252 | conv(ch_in, 128, 3, 1, 1), 253 | conv(128, 128, 3, 1, 2), 254 | conv(128, 128, 3, 1, 4), 255 | conv(128, 96, 3, 1, 8), 256 | conv(96, 64, 3, 1, 16), 257 | conv(64, 32, 3, 1, 1), 258 | conv(32, 3, isReLU=False) 259 | ) 260 | 261 | def forward(self, x): 262 | x_out = self.convs(x) 263 | return x_out[:,:2,:,:], x_out[:,2,:,:].unsqueeze(1) 264 | -------------------------------------------------------------------------------- /models/pwcnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .pwc_modules import upsample2d_as, initialize_msra 7 | from .pwc_modules import WarpingLayer, FeatureExtractor 8 | from .pwc_modules import 
ContextNetwork, FlowEstimatorDense 9 | from .correlation_package.correlation import Correlation 10 | 11 | class PWCNet(nn.Module): 12 | def __init__(self, args, div_flow=0.05): 13 | super(PWCNet, self).__init__() 14 | self.args = args 15 | self._div_flow = div_flow 16 | self.search_range = 4 17 | self.num_chs = [3, 16, 32, 64, 96, 128, 196] 18 | self.output_level = 4 19 | self.num_levels = 7 20 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True) 21 | 22 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs) 23 | self.warping_layer = WarpingLayer() 24 | 25 | self.flow_estimators = nn.ModuleList() 26 | self.dim_corr = (self.search_range * 2 + 1) ** 2 27 | for l, ch in enumerate(self.num_chs[::-1]): 28 | if l > self.output_level: 29 | break 30 | 31 | if l == 0: 32 | num_ch_in = self.dim_corr 33 | else: 34 | num_ch_in = self.dim_corr + ch + 2 35 | 36 | layer = FlowEstimatorDense(num_ch_in) 37 | self.flow_estimators.append(layer) 38 | 39 | self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2) 40 | 41 | initialize_msra(self.modules()) 42 | 43 | def forward(self, input_dict): 44 | 45 | x1_raw = input_dict['input1'] 46 | x2_raw = input_dict['input2'] 47 | _, _, height_im, width_im = x1_raw.size() 48 | 49 | # on the bottom level are original images 50 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw] 51 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw] 52 | 53 | # outputs 54 | output_dict = {} 55 | flows = [] 56 | 57 | # init 58 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size() 59 | init_dtype = x1_pyramid[0].dtype 60 | init_device = x1_pyramid[0].device 61 | flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 62 | 63 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): 64 | 65 | # warping 66 | if l == 0: 67 | x2_warp = x2 68 | else: 69 | flow = upsample2d_as(flow, x1, mode="bilinear") 70 | x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow) 71 | 72 | # correlation 73 | out_corr = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp) 74 | out_corr_relu = self.leakyRELU(out_corr) 75 | 76 | # flow estimator 77 | if l == 0: 78 | x_intm, flow = self.flow_estimators[l](out_corr_relu) 79 | else: 80 | x_intm, flow = self.flow_estimators[l](torch.cat([out_corr_relu, x1, flow], dim=1)) 81 | 82 | # upsampling or post-processing 83 | if l != self.output_level: 84 | flows.append(flow) 85 | else: 86 | flow_res = self.context_networks(torch.cat([x_intm, flow], dim=1)) 87 | flow = flow + flow_res 88 | flows.append(flow) 89 | break 90 | 91 | output_dict['flow'] = flows 92 | 93 | if self.training: 94 | return output_dict 95 | else: 96 | output_dict_eval = {} 97 | out_flow = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow) 98 | output_dict_eval['flow'] = out_flow 99 | return output_dict_eval 100 | -------------------------------------------------------------------------------- /models/pwcnet_irr.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra 7 | from .pwc_modules import WarpingLayer, FeatureExtractor 8 | from .pwc_modules import ContextNetwork, FlowEstimatorDense 9 | from .correlation_package.correlation import Correlation 10 | 11 | class 
PWCNet(nn.Module): 12 | def __init__(self, args, div_flow=0.05): 13 | super(PWCNet, self).__init__() 14 | self.args = args 15 | self._div_flow = div_flow 16 | self.search_range = 4 17 | self.num_chs = [3, 16, 32, 64, 96, 128, 196] 18 | self.output_level = 4 19 | self.num_levels = 7 20 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True) 21 | 22 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs) 23 | self.warping_layer = WarpingLayer() 24 | 25 | self.dim_corr = (self.search_range * 2 + 1) ** 2 26 | self.num_ch_in = self.dim_corr + 32 + 2 27 | 28 | self.flow_estimators = FlowEstimatorDense(self.num_ch_in) 29 | 30 | self.context_networks = ContextNetwork(self.num_ch_in + 448 + 2) 31 | 32 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1), 33 | conv(128, 32, kernel_size=1, stride=1, dilation=1), 34 | conv(96, 32, kernel_size=1, stride=1, dilation=1), 35 | conv(64, 32, kernel_size=1, stride=1, dilation=1), 36 | conv(32, 32, kernel_size=1, stride=1, dilation=1)]) 37 | 38 | initialize_msra(self.modules()) 39 | 40 | def forward(self, input_dict): 41 | 42 | x1_raw = input_dict['input1'] 43 | x2_raw = input_dict['input2'] 44 | _, _, height_im, width_im = x1_raw.size() 45 | 46 | # on the bottom level are original images 47 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw] 48 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw] 49 | 50 | # outputs 51 | output_dict = {} 52 | flows = [] 53 | 54 | # init 55 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size() 56 | init_dtype = x1_pyramid[0].dtype 57 | init_device = x1_pyramid[0].device 58 | flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 59 | 60 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): 61 | 62 | # warping 63 | if l == 0: 64 | x2_warp = x2 65 | else: 66 | flow = upsample2d_as(flow, x1, mode="bilinear") 67 | x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow) 68 | 69 | # correlation 70 | out_corr = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp) 71 | out_corr_relu = self.leakyRELU(out_corr) 72 | 73 | # concat and estimate flow 74 | flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=True) 75 | 76 | x1_1by1 = self.conv_1x1[l](x1) 77 | x_intm, flow_res = self.flow_estimators(torch.cat([out_corr_relu, x1_1by1, flow], dim=1)) 78 | flow = flow + flow_res 79 | 80 | flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1)) 81 | flow = flow + flow_fine 82 | 83 | flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=False) 84 | flows.append(flow) 85 | 86 | # upsampling or post-processing 87 | if l == self.output_level: 88 | break 89 | 90 | output_dict['flow'] = flows 91 | 92 | if self.training: 93 | return output_dict 94 | else: 95 | output_dict_eval = {} 96 | output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow) 97 | return output_dict_eval 98 | -------------------------------------------------------------------------------- /models/pwcnet_irr_occ_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra 7 | from .pwc_modules import WarpingLayer, FeatureExtractor 8 | from .pwc_modules import 
FlowAndOccContextNetwork, FlowAndOccEstimatorDense 9 | from .correlation_package.correlation import Correlation 10 | 11 | class PWCNet(nn.Module): 12 | def __init__(self, args, div_flow=0.05): 13 | super(PWCNet, self).__init__() 14 | self.args = args 15 | self._div_flow = div_flow 16 | self.search_range = 4 17 | self.num_chs = [3, 16, 32, 64, 96, 128, 196] 18 | self.output_level = 4 19 | self.num_levels = 7 20 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True) 21 | 22 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs) 23 | self.warping_layer = WarpingLayer() 24 | 25 | self.dim_corr = (self.search_range * 2 + 1) ** 2 26 | self.num_ch_in = self.dim_corr + 32 + 2 + 1 27 | 28 | self.flow_and_occ_estimators = FlowAndOccEstimatorDense(self.num_ch_in) 29 | 30 | self.context_networks = FlowAndOccContextNetwork(self.num_ch_in + 448 + 2 + 1) 31 | 32 | self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1), 33 | conv(128, 32, kernel_size=1, stride=1, dilation=1), 34 | conv(96, 32, kernel_size=1, stride=1, dilation=1), 35 | conv(64, 32, kernel_size=1, stride=1, dilation=1), 36 | conv(32, 32, kernel_size=1, stride=1, dilation=1)]) 37 | 38 | initialize_msra(self.modules()) 39 | 40 | def forward(self, input_dict): 41 | 42 | x1_raw = input_dict['input1'] 43 | x2_raw = input_dict['input2'] 44 | _, _, height_im, width_im = x1_raw.size() 45 | 46 | # on the bottom level are original images 47 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw] 48 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw] 49 | 50 | # outputs 51 | output_dict = {} 52 | flows = [] 53 | occs = [] 54 | 55 | # init 56 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size() 57 | init_dtype = x1_pyramid[0].dtype 58 | init_device = x1_pyramid[0].device 59 | flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 60 | occ = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 61 | 62 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): 63 | 64 | # warping 65 | if l == 0: 66 | x2_warp = x2 67 | else: 68 | flow = upsample2d_as(flow, x1, mode="bilinear") 69 | occ = upsample2d_as(occ, x1, mode="bilinear") 70 | x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow) 71 | 72 | # correlation 73 | out_corr = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp) 74 | out_corr_relu = self.leakyRELU(out_corr) 75 | 76 | # concat and estimate flow 77 | flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=True) 78 | 79 | x1_1by1 = self.conv_1x1[l](x1) 80 | x_intm, flow_res, occ_res = self.flow_and_occ_estimators(torch.cat([out_corr_relu, x1_1by1, flow, occ], dim=1)) 81 | flow = flow + flow_res 82 | occ = occ + occ_res 83 | 84 | flow_fine, occ_fine = self.context_networks(torch.cat([x_intm, flow, occ], dim=1)) 85 | flow = flow + flow_fine 86 | occ = occ + occ_fine 87 | 88 | flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=False) 89 | flows.append(flow) 90 | occs.append(occ) 91 | 92 | # upsampling or post-processing 93 | if l == self.output_level: 94 | break 95 | 96 | output_dict['flow'] = flows 97 | output_dict['occ'] = occs 98 | 99 | if self.training: 100 | return output_dict 101 | else: 102 | output_dict_eval = {} 103 | output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow) 104 | output_dict_eval['occ'] = upsample2d_as(occ, x1_raw, 
mode="bilinear") 105 | return output_dict_eval 106 | -------------------------------------------------------------------------------- /models/pwcnet_occ_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .pwc_modules import upsample2d_as, initialize_msra 7 | from .pwc_modules import WarpingLayer, FeatureExtractor 8 | from .pwc_modules import FlowAndOccEstimatorDense, FlowAndOccContextNetwork 9 | from .correlation_package.correlation import Correlation 10 | 11 | class PWCNet(nn.Module): 12 | def __init__(self, args, div_flow=0.05): 13 | super(PWCNet, self).__init__() 14 | self.args = args 15 | self._div_flow = div_flow 16 | self.search_range = 4 17 | self.num_chs = [3, 16, 32, 64, 96, 128, 196] 18 | self.output_level = 4 19 | self.num_levels = 7 20 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True) 21 | 22 | self.feature_pyramid_extractor = FeatureExtractor(self.num_chs) 23 | self.warping_layer = WarpingLayer() 24 | 25 | self.flow_and_occ_estimators = nn.ModuleList() 26 | self.dim_corr = (self.search_range * 2 + 1) ** 2 27 | for l, ch in enumerate(self.num_chs[::-1]): 28 | if l > self.output_level: 29 | break 30 | 31 | if l == 0: 32 | num_ch_in = self.dim_corr 33 | else: 34 | num_ch_in = self.dim_corr + ch + 2 + 1 35 | 36 | layer = FlowAndOccEstimatorDense(num_ch_in) 37 | self.flow_and_occ_estimators.append(layer) 38 | 39 | self.context_networks = FlowAndOccContextNetwork(self.dim_corr + 32 + 2 + 1 + 448 + 2 + 1) 40 | 41 | initialize_msra(self.modules()) 42 | 43 | def forward(self, input_dict): 44 | 45 | x1_raw = input_dict['input1'] 46 | x2_raw = input_dict['input2'] 47 | _, _, height_im, width_im = x1_raw.size() 48 | 49 | # on the bottom level are original images 50 | x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw] 51 | x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw] 52 | 53 | # outputs 54 | output_dict = {} 55 | flows = [] 56 | occs = [] 57 | 58 | # init 59 | b_size, _, h_x1, w_x1, = x1_pyramid[0].size() 60 | init_dtype = x1_pyramid[0].dtype 61 | init_device = x1_pyramid[0].device 62 | flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 63 | occ = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float() 64 | 65 | for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): 66 | 67 | # warping 68 | if l == 0: 69 | x2_warp = x2 70 | else: 71 | flow = upsample2d_as(flow, x1, mode="bilinear") 72 | occ = upsample2d_as(occ, x1, mode="bilinear") 73 | x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow) 74 | 75 | # correlation 76 | out_corr = Correlation(pad_size=self.search_range, kernel_size=1, max_displacement=self.search_range, stride1=1, stride2=1, corr_multiply=1)(x1, x2_warp) 77 | out_corr_relu = self.leakyRELU(out_corr) 78 | 79 | # flow estimator 80 | if l == 0: 81 | x_intm, flow, occ = self.flow_and_occ_estimators[l](out_corr_relu) 82 | else: 83 | x_intm, flow, occ = self.flow_and_occ_estimators[l](torch.cat([out_corr_relu, x1, flow, occ], dim=1)) 84 | 85 | # upsampling or post-processing 86 | if l != self.output_level: 87 | flows.append(flow) 88 | occs.append(occ) 89 | else: 90 | flow_fine, occ_fine = self.context_networks(torch.cat([x_intm, flow, occ], dim=1)) 91 | flow = flow + flow_fine 92 | occ = occ + occ_fine 93 | flows.append(flow) 94 | occs.append(occ) 95 | break 96 | 97 | output_dict['flow'] = flows 98 
| output_dict['occ'] = occs 99 | 100 | if self.training: 101 | return output_dict 102 | else: 103 | output_dict_eval = {} 104 | output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode="bilinear") * (1.0 / self._div_flow) 105 | output_dict_eval['occ'] = upsample2d_as(occ, x1_raw, mode="bilinear") 106 | return output_dict_eval 107 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | from tools import module_classes_to_dict 4 | 5 | # ------------------------------------------------------------------------------------ 6 | # Export PyTorch optimizer 7 | # ------------------------------------------------------------------------------------ 8 | _this = sys.modules[__name__] 9 | _optimizer_classes = module_classes_to_dict(torch.optim, exclude_classes="Optimizer") 10 | for name, constructor in _optimizer_classes.items(): 11 | setattr(_this, name, constructor) 12 | __all__ = _optimizer_classes.keys() 13 | 14 | -------------------------------------------------------------------------------- /results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/results.png -------------------------------------------------------------------------------- /saved_checkpoint/StarFlow_kitti/checkpoint_latest.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/saved_checkpoint/StarFlow_kitti/checkpoint_latest.ckpt -------------------------------------------------------------------------------- /saved_checkpoint/StarFlow_sintel/checkpoint_latest.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/saved_checkpoint/StarFlow_sintel/checkpoint_latest.ckpt -------------------------------------------------------------------------------- /saved_checkpoint/StarFlow_things/checkpoint_best.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/saved_checkpoint/StarFlow_things/checkpoint_best.ckpt -------------------------------------------------------------------------------- /scripts_train/train_starflow_chairsocc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # experiments and datasets meta 4 | EXPERIMENTS_HOME="experiments" 5 | 6 | # datasets 7 | FLYINGCHAIRS_OCC_HOME=(YOUR PATH)/FlyingChairsOcc/ 8 | SINTEL_HOME=(YOUR PATH)/mpisintelcomplete 9 | 10 | # model and checkpoint 11 | MODEL=StarFlow 12 | EVAL_LOSS=MultiScaleEPE_PWC_Occ_upsample 13 | CHECKPOINT=None 14 | SIZE_OF_BATCH=8 15 | DEVICE=0 16 | 17 | # save path 18 | SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-chairs" 19 | 20 | # training configuration 21 | python ../main.py \ 22 | --batch_size=$SIZE_OF_BATCH \ 23 | --batch_size_val=$SIZE_OF_BATCH \ 24 | --checkpoint=$CHECKPOINT \ 25 | --lr_scheduler=MultiStepLR \ 26 | --lr_scheduler_gamma=0.5 \ 27 | --lr_scheduler_milestones="[108, 144, 180]" \ 28 | --model=$MODEL \ 29 | --num_workers=6 \ 30 | --device=$DEVICE \ 31 | --optimizer=Adam \ 32 | --optimizer_lr=1e-4 \ 33 | --optimizer_weight_decay=4e-4 \ 34 | 
--save=$SAVE_PATH \ 35 | --total_epochs=216 \ 36 | --training_augmentation=RandomAffineFlowOcc \ 37 | --training_dataset=FlyingChairsOccTrain \ 38 | --training_dataset_photometric_augmentations=True \ 39 | --training_dataset_root=$FLYINGCHAIRS_OCC_HOME \ 40 | --training_key=total_loss \ 41 | --training_loss=$EVAL_LOSS \ 42 | --validation_dataset=FlyingChairsOccValid \ 43 | --validation_dataset_photometric_augmentations=False \ 44 | --validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \ 45 | --validation_key=epe \ 46 | --validation_loss=$EVAL_LOSS 47 | -------------------------------------------------------------------------------- /scripts_train/train_starflow_kitti_full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # experiments and datasets meta 4 | EXPERIMENTS_HOME="experiments" 5 | 6 | # datasets 7 | KITTI_HOME=(YOUR PATH)/KittiComb 8 | 9 | # model and checkpoint 10 | MODEL=StarFlow 11 | EVAL_LOSS=MultiScaleEPE_PWC_Occ_upsample_KITTI 12 | CHECKPOINT=None 13 | SIZE_OF_BATCH=4 14 | NFRAMES=4 15 | DEVICE=0 16 | 17 | # save path 18 | SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-ftkitti-full" 19 | 20 | # training configuration 21 | python ../main.py \ 22 | --batch_size=$SIZE_OF_BATCH \ 23 | --batch_size_val=$SIZE_OF_BATCH \ 24 | --checkpoint=$CHECKPOINT \ 25 | --lr_scheduler=MultiStepLR \ 26 | --lr_scheduler_gamma=0.5 \ 27 | --lr_scheduler_milestones="[456, 659, 862, 963, 989, 1014, 1116, 1217, 1319, 1420]" \ 28 | --model=$MODEL \ 29 | --num_workers=6 \ 30 | --device=$DEVICE \ 31 | --optimizer=Adam \ 32 | --optimizer_lr=3e-05 \ 33 | --optimizer_weight_decay=4e-4 \ 34 | --save=$SAVE_PATH \ 35 | --start_epoch=1 \ 36 | --total_epochs=550 \ 37 | --training_augmentation=RandomAffineFlowOccVideoKitti \ 38 | --training_augmentation_crop="[320,896]" \ 39 | --training_dataset=KittiMultiframeCombFull \ 40 | --training_dataset_nframes=$NFRAMES \ 41 | --training_dataset_photometric_augmentations=True \ 42 | --training_dataset_root=$KITTI_HOME \ 43 | --training_dataset_preprocessing_crop=True \ 44 | --training_key=total_loss \ 45 | --training_loss=$EVAL_LOSS \ 46 | --validation_dataset=KittiMultiframeComb2015Val \ 47 | --validation_dataset_nframes=$NFRAMES \ 48 | --validation_dataset_photometric_augmentations=True \ 49 | --validation_dataset_root=$KITTI_HOME \ 50 | --validation_dataset_preprocessing_crop=True \ 51 | --validation_key=epe \ 52 | --validation_loss=$EVAL_LOSS 53 | -------------------------------------------------------------------------------- /scripts_train/train_starflow_sintel_full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # experiments and datasets meta 4 | EXPERIMENTS_HOME="experiments" 5 | 6 | # datasets 7 | SINTEL_HOME=(YOUR PATH)/mpisintelcomplete 8 | 9 | # model and checkpoint 10 | MODEL=StarFlow 11 | EVAL_LOSS=MultiScaleEPE_PWC_Occ_video_upsample_Sintel 12 | CHECKPOINT=None 13 | SIZE_OF_BATCH=4 14 | NFRAMES=4 15 | DEVICE=0 16 | 17 | # save path 18 | SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-ftsintel1-full" 19 | 20 | # training configuration 21 | python ../main.py \ 22 | --batch_size=$SIZE_OF_BATCH \ 23 | --batch_size_val=$SIZE_OF_BATCH \ 24 | --checkpoint=$CHECKPOINT \ 25 | --lr_scheduler=MultiStepLR \ 26 | --lr_scheduler_gamma=0.5 \ 27 | --lr_scheduler_milestones="[89, 130, 170, 190, 195, 200, 220, 240, 260, 280]" \ 28 | --model=$MODEL \ 29 | --num_workers=6 \ 30 | --device=$DEVICE \ 31 | --optimizer=Adam \ 32 | --optimizer_lr=1.5e-05 \ 33 | --optimizer_weight_decay=4e-4 
\ 34 | --save=$SAVE_PATH \ 35 | --start_epoch=1 \ 36 | --total_epochs=300 \ 37 | --training_augmentation=RandomAffineFlowOccVideo \ 38 | --training_augmentation_crop="[384,768]" \ 39 | --training_dataset=SintelMultiframeTrainingCombFull \ 40 | --training_dataset_nframes=$NFRAMES \ 41 | --training_dataset_photometric_augmentations=True \ 42 | --training_dataset_root=$SINTEL_HOME \ 43 | --training_key=total_loss \ 44 | --training_loss=$EVAL_LOSS \ 45 | --validation_dataset=SintelMultiframeTrainingFinalValid \ 46 | --validation_dataset_nframes=$NFRAMES \ 47 | --validation_dataset_photometric_augmentations=False \ 48 | --validation_dataset_root=$SINTEL_HOME \ 49 | --validation_key=epe \ 50 | --validation_loss=$EVAL_LOSS 51 | 52 | # save path 53 | SAVE_PATH_2="$EXPERIMENTS_HOME/$MODEL-ftsintel2-full" 54 | 55 | # training configuration 56 | python ../main.py \ 57 | --batch_size=$SIZE_OF_BATCH \ 58 | --batch_size_val=$SIZE_OF_BATCH \ 59 | --checkpoint=$SAVE_PATH \ 60 | --lr_scheduler=MultiStepLR \ 61 | --lr_scheduler_gamma=0.5 \ 62 | --lr_scheduler_milestones="[481, 562, 643, 683, 693, 703, 743, 783, 824, 864]" \ 63 | --model=$MODEL \ 64 | --num_workers=6 \ 65 | --device=$DEVICE \ 66 | --optimizer=Adam \ 67 | --optimizer_lr=1e-05 \ 68 | --optimizer_weight_decay=4e-4 \ 69 | --save=$SAVE_PATH_2 \ 70 | --start_epoch=301 \ 71 | --total_epochs=451 \ 72 | --training_augmentation=RandomAffineFlowOccVideo \ 73 | --training_augmentation_crop="[384,768]" \ 74 | --training_dataset=SintelMultiframeTrainingFinalFull \ 75 | --training_dataset_nframes=$NFRAMES \ 76 | --training_dataset_photometric_augmentations=True \ 77 | --training_dataset_root=$SINTEL_HOME \ 78 | --training_key=total_loss \ 79 | --training_loss=$EVAL_LOSS \ 80 | --validation_dataset=SintelMultiframeTrainingFinalValid \ 81 | --validation_dataset_nframes=$NFRAMES \ 82 | --validation_dataset_photometric_augmentations=False \ 83 | --validation_dataset_root=$SINTEL_HOME \ 84 | --validation_key=epe \ 85 | --validation_loss=$EVAL_LOSS 86 | -------------------------------------------------------------------------------- /scripts_train/train_starflow_things.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # experiments and datasets meta 4 | EXPERIMENTS_HOME="experiments" 5 | 6 | # datasets 7 | FLYINGTHINGS_HOME=(YOUR PATH)/FlyingThings3DSubset 8 | SINTEL_HOME=(YOUR PATH)/mpisintelcomplete 9 | 10 | # model and checkpoint 11 | MODEL=StarFlow 12 | EVAL_LOSS=MultiScaleEPE_PWC_Occ_video_upsample 13 | CHECKPOINT=None 14 | SIZE_OF_BATCH=4 15 | NFRAMES=4 16 | DEVICE=0 17 | 18 | # save path 19 | SAVE_PATH="$EXPERIMENTS_HOME/$MODEL-ftthings" 20 | 21 | # training configuration 22 | python ../main.py \ 23 | --batch_size=$SIZE_OF_BATCH \ 24 | --batch_size_val=$SIZE_OF_BATCH \ 25 | --checkpoint=$CHECKPOINT \ 26 | --lr_scheduler=MultiStepLR \ 27 | --lr_scheduler_gamma=0.5 \ 28 | --lr_scheduler_milestones="[257, 287, 307, 317]" \ 29 | --model=$MODEL \ 30 | --num_workers=6 \ 31 | --device=$DEVICE \ 32 | --optimizer=Adam \ 33 | --optimizer_lr=1e-4 \ 34 | --optimizer_weight_decay=4e-4 \ 35 | --save=$SAVE_PATH \ 36 | --start_epoch=217 \ 37 | --total_epochs=327 \ 38 | --training_augmentation=RandomAffineFlowOccVideo \ 39 | --training_augmentation_crop="[384,768]" \ 40 | --training_dataset=FlyingThings3dMultiframeCleanTrain \ 41 | --training_dataset_nframes=$NFRAMES \ 42 | --training_dataset_photometric_augmentations=True \ 43 | --training_dataset_root=$FLYINGTHINGS_HOME \ 44 | --training_key=total_loss \ 45 | 
--training_loss=$EVAL_LOSS \ 46 | --validation_dataset=FlyingThings3dMultiframeCleanTest \ 47 | --validation_dataset_nframes=$NFRAMES \ 48 | --validation_dataset_photometric_augmentations=False \ 49 | --validation_dataset_root=$FLYINGTHINGS_HOME \ 50 | --validation_key=epe \ 51 | --validation_loss=$EVAL_LOSS 52 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | ## Portions of Code from, copyright 2018 Jochen Gast 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import os 6 | import socket 7 | import re 8 | #from pytz import timezone 9 | from datetime import datetime 10 | import fnmatch 11 | import itertools 12 | import argparse 13 | import sys 14 | import six 15 | import unicodedata 16 | import json 17 | import inspect 18 | import tqdm 19 | import logging 20 | import torch 21 | import ast 22 | import numpy as np 23 | 24 | 25 | def x2module(module_or_data_parallel): 26 | if isinstance(module_or_data_parallel, torch.nn.DataParallel): 27 | return module_or_data_parallel.module 28 | else: 29 | return module_or_data_parallel 30 | 31 | 32 | # ---------------------------------------------------------------------------------------- 33 | # Comprehensively adds a new logging level to the `logging` module and the 34 | # currently configured logging class. 35 | # e.g. addLoggingLevel('TRACE', logging.DEBUG - 5) 36 | # ---------------------------------------------------------------------------------------- 37 | def addLoggingLevel(level_name, level_num, method_name=None): 38 | if not method_name: 39 | method_name = level_name.lower() 40 | if hasattr(logging, level_name): 41 | raise AttributeError('{} already defined in logging module'.format(level_name)) 42 | if hasattr(logging, method_name): 43 | raise AttributeError('{} already defined in logging module'.format(method_name)) 44 | if hasattr(logging.getLoggerClass(), method_name): 45 | raise AttributeError('{} already defined in logger class'.format(method_name)) 46 | 47 | # This method was inspired by the answers to Stack Overflow post 48 | # http://stackoverflow.com/q/2183233/2988730, especially 49 | # http://stackoverflow.com/a/13638084/2988730 50 | def logForLevel(self, message, *args, **kwargs): 51 | if self.isEnabledFor(level_num): 52 | self._log(level_num, message, args, **kwargs) 53 | 54 | def logToRoot(message, *args, **kwargs): 55 | logging.log(level_num, message, *args, **kwargs) 56 | 57 | logging.addLevelName(level_num, level_name) 58 | setattr(logging, level_name, level_num) 59 | setattr(logging.getLoggerClass(), method_name, logForLevel) 60 | setattr(logging, method_name, logToRoot) 61 | 62 | 63 | # ------------------------------------------------------------------------------------------------- 64 | # Looks for sub arguments in the argument structure. 
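# (hypothetical example: for a parsed argparse namespace holding
#  optimizer_lr=1e-4 and optimizer_weight_decay=4e-4, the call
#  kwargs_from_args(args, "optimizer") returns
#  {"lr": 1e-4, "weight_decay": 4e-4}, ready to be unpacked as **kwargs)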
65 | # Retrieve sub arguments for modules such as optimizer_*
66 | # -------------------------------------------------------------------------------------------------
67 | def kwargs_from_args(args, name, exclude=[]):
68 |     if isinstance(exclude, str):
69 |         exclude = [exclude]
70 |     exclude = exclude + ["class"]  # do not mutate the shared default list
71 |     args_dict = vars(args)
72 |     name += "_"
73 |     subargs_dict = {
74 |         key[len(name):]: value for key, value in args_dict.items()
75 |         if key.startswith(name) and all([key != name + x for x in exclude])
76 |     }
77 |     return subargs_dict
78 | 
79 | 
80 | # -------------------------------------------------------------------------------------------------
81 | # Create class instance from kwargs dictionary.
82 | # Filters out keys that are not in the constructor
83 | # -------------------------------------------------------------------------------------------------
84 | def instance_from_kwargs(class_constructor, kwargs):
85 |     argspec = inspect.getfullargspec(class_constructor.__init__)  # getargspec was removed in Python 3.11
86 |     full_args = argspec.args
87 |     filtered_args = dict([(k, v) for k, v in kwargs.items() if k in full_args])
88 |     instance = class_constructor(**filtered_args)
89 |     return instance
90 | 
91 | 
92 | def module_classes_to_dict(module, include_classes="*", exclude_classes=()):
93 |     # -------------------------------------------------------------------------
94 |     # If arguments are strings, convert them to a list
95 |     # -------------------------------------------------------------------------
96 |     if include_classes is not None:
97 |         if isinstance(include_classes, str):
98 |             include_classes = [include_classes]
99 | 
100 |     if exclude_classes is not None:
101 |         if isinstance(exclude_classes, str):
102 |             exclude_classes = [exclude_classes]
103 | 
104 |     # -------------------------------------------------------------------------
105 |     # Obtain dictionary from given module
106 |     # -------------------------------------------------------------------------
107 |     item_dict = dict([(name, getattr(module, name)) for name in dir(module)])
108 | 
109 |     # -------------------------------------------------------------------------
110 |     # Filter classes
111 |     # -------------------------------------------------------------------------
112 |     item_dict = dict([
113 |         (name, value) for name, value in item_dict.items() if inspect.isclass(getattr(module, name))
114 |     ])
115 | 
116 |     filtered_keys = filter_list_of_strings(
117 |         item_dict.keys(), include=include_classes, exclude=exclude_classes)
118 | 
119 |     # -------------------------------------------------------------------------
120 |     # Construct dictionary from matched results
121 |     # -------------------------------------------------------------------------
122 |     result_dict = dict([(name, value) for name, value in item_dict.items() if name in filtered_keys])
123 | 
124 |     return result_dict
125 | 
126 | 
127 | def ensure_dir(file_path):
128 |     directory = os.path.dirname(file_path)
129 |     if not os.path.exists(directory):
130 |         os.makedirs(directory)
131 | 
132 | 
133 | def search_and_replace(string, regex, replace):
134 |     while True:
135 |         match = re.search(regex, string)
136 |         if match:
137 |             string = string.replace(match.group(0), replace)
138 |         else:
139 |             break
140 |     return string
141 | 
142 | 
143 | def hostname():
144 |     name = socket.gethostname()
145 |     n = name.find('.')
146 |     if n > 0:
147 |         name = name[:n]
148 |     return name
149 | 
150 | 
151 | def get_filenames(directory, match='*.*', not_match=()):
152 |     if match is not None:
153 |         if isinstance(match, str):
154 |             match = [match]
155 |     if not_match is not None:
156 |         if isinstance(not_match, str):
157 |             not_match = [not_match]
158 | 
159 |     result = []
160 |     for dirpath, _, filenames in os.walk(directory):
161 |         filtered_matches = list(itertools.chain.from_iterable(
162 |             [fnmatch.filter(filenames, x) for x in match]))
163 |         filtered_nomatch = list(itertools.chain.from_iterable(
164 |             [fnmatch.filter(filenames, x) for x in not_match]))
165 |         matched = list(set(filtered_matches) - set(filtered_nomatch))
166 |         result += [os.path.join(dirpath, x) for x in matched]
167 |     return result
168 | 
169 | 
170 | def str2bool(v):
171 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
172 |         return True
173 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
174 |         return False
175 |     else:
176 |         raise argparse.ArgumentTypeError('Boolean value expected.')
177 | 
178 | 
179 | def str2str_or_none(v):
180 |     if v.lower() == "none":
181 |         return None
182 |     return v
183 | 
184 | 
185 | def str2dict(v):
186 |     return ast.literal_eval(v)
187 | 
188 | 
189 | def str2intlist(v):
190 |     return [int(x.strip()) for x in v.strip()[1:-1].split(',')]
191 | 
192 | 
193 | def str2list(v):
194 |     return [str(x.strip()) for x in v.strip()[1:-1].split(',')]
195 | 
196 | 
197 | def read_json(filename):
198 | 
199 |     def _convert_from_unicode(data):
200 |         new_data = dict()
201 |         for name, value in six.iteritems(data):
202 |             if isinstance(name, six.string_types):
203 |                 name = unicodedata.normalize('NFKD', name).encode(
204 |                     'ascii', 'ignore')
205 |             if isinstance(value, six.string_types):
206 |                 value = unicodedata.normalize('NFKD', value).encode(
207 |                     'ascii', 'ignore')
208 |             if isinstance(value, dict):
209 |                 value = _convert_from_unicode(value)
210 |             new_data[name] = value
211 |         return new_data
212 | 
213 |     output_dict = None
214 |     with open(filename, "r") as f:
215 |         lines = f.readlines()
216 |         try:
217 |             output_dict = json.loads(''.join(lines))  # the "encoding" kwarg was removed in Python 3.9
218 |         except Exception:
219 |             raise ValueError('Could not read %s. %s' % (filename, sys.exc_info()[1]))
220 |     output_dict = _convert_from_unicode(output_dict)
221 |     return output_dict
222 | 
223 | 
224 | def write_json(data_dict, filename):
225 |     with open(filename, "w") as file:
226 |         json.dump(data_dict, file)
227 | 
228 | 
229 | def datestr():
230 |     #pacific = timezone('US/Pacific')
231 |     #now = datetime.now(pacific)
232 |     now = datetime.now()
233 |     return '{}{:02}{:02}_{:02}{:02}'.format(now.year, now.month, now.day, now.hour, now.minute)
234 | 
235 | 
236 | def filter_list_of_strings(lst, include="*", exclude=()):
237 |     filtered_matches = list(itertools.chain.from_iterable([fnmatch.filter(lst, x) for x in include]))
238 |     filtered_nomatch = list(itertools.chain.from_iterable([fnmatch.filter(lst, x) for x in exclude]))
239 |     matched = list(set(filtered_matches) - set(filtered_nomatch))
240 |     return matched
241 | 
242 | 
243 | # ----------------------------------------------------------------------------
244 | # Writes all (key, value) pairs to a file for bookkeeping
245 | # Either .txt or .json
246 | # ----------------------------------------------------------------------------
247 | def write_dictionary_to_file(arguments_dict, filename):
248 |     # ensure dir
249 |     d = os.path.dirname(filename)
250 |     if not os.path.exists(d):
251 |         os.makedirs(d)
252 | 
253 |     # check for json extension
254 |     ext = os.path.splitext(filename)[1]
255 |     if ext == ".json":
256 | 
257 |         def replace_quotes(x):
258 |             return x.replace("\'", "\"")
259 | 
260 |         with open(filename, 'w') as file:
261 |             file.write("{\n")
262 |             for i, (key, value) in enumerate(arguments_dict):
263 |                 if isinstance(value, tuple):
264 |                     value = list(value)
265 |                 if value is None:
266 |                     file.write(" \"%s\": null" % key)
267 |                 elif isinstance(value, str):
268 |                     value = value.replace("\'", "\"")
269 |                     file.write(" \"%s\": \"%s\"" % (key, replace_quotes(str(value))))
270 |                 elif isinstance(value, bool):
271 |                     file.write(" \"%s\": %s" % (key, str(value).lower()))
272 |                 else:
273 |                     file.write(" \"%s\": %s" % (key, replace_quotes(str(value))))
274 |                 if i < len(arguments_dict) - 1:
275 |                     file.write(',\n')
276 |                 else:
277 |                     file.write('\n')
278 |             file.write("}\n")
279 |     else:
280 |         with open(filename, 'w') as file:
281 |             for key, value in arguments_dict:
282 |                 file.write('%s: %s\n' % (key, value))
283 | 
284 | 
285 | class MovingAverage:
286 |     postfix = "avg"
287 | 
288 |     def __init__(self):
289 |         self._sum = 0.0
290 |         self._count = 0
291 | 
292 |     def add_value(self, sigma, addcount=1):
293 |         self._sum += sigma
294 |         self._count += addcount
295 | 
296 |     def add_average(self, avg, addcount):
297 |         self._sum += avg*addcount
298 |         self._count += addcount
299 | 
300 |     def mean(self):
301 |         return self._sum / self._count
302 | 
303 | 
304 | class ExponentialMovingAverage:
305 |     postfix = "ema"
306 | 
307 |     def __init__(self, alpha=0.7):
308 |         self._weighted_sum = 0.0
309 |         self._weighted_count = 0
310 |         self._alpha = alpha
311 | 
312 |     def add_value(self, sigma, addcount=1):
313 |         self._weighted_sum = sigma + (1.0 - self._alpha)*self._weighted_sum
314 |         self._weighted_count = 1 + (1.0 - self._alpha)*self._weighted_count
315 | 
316 |     def add_average(self, avg, addcount):
317 |         self._weighted_sum = avg*addcount + (1.0 - self._alpha)*self._weighted_sum
318 |         self._weighted_count = addcount + (1.0 - self._alpha)*self._weighted_count
319 | 
320 |     def mean(self):
321 |         return self._weighted_sum / self._weighted_count
322 | 
323 | 
324 | # -----------------------------------------------------------------
325 | # Subclass tqdm to achieve two things:
326 | # 
1) Output the progress bar into the logbook. 327 | # 2) Remove the comma before {postfix} because it's annoying. 328 | # ----------------------------------------------------------------- 329 | class TqdmToLogger(tqdm.tqdm): 330 | def __init__(self, iterable=None, desc=None, total=None, leave=True, 331 | file=None, ncols=None, mininterval=0.1, 332 | maxinterval=10.0, miniters=None, ascii=None, disable=False, 333 | unit='it', unit_scale=False, dynamic_ncols=False, 334 | smoothing=0.3, bar_format=None, initial=0, position=None, 335 | postfix=None, 336 | logging_on_close=True, 337 | logging_on_update=False): 338 | 339 | super(TqdmToLogger, self).__init__( 340 | iterable=iterable, desc=desc, total=total, leave=leave, 341 | file=file, ncols=ncols, mininterval=mininterval, 342 | maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=disable, 343 | unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols, 344 | smoothing=smoothing, bar_format=bar_format, initial=initial, position=position, 345 | postfix=postfix) 346 | 347 | self._logging_on_close = logging_on_close 348 | self._logging_on_update = logging_on_update 349 | self._closed = False 350 | 351 | @staticmethod 352 | def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False, 353 | unit='it', unit_scale=False, rate=None, bar_format=None, 354 | postfix=None, unit_divisor=1000): 355 | 356 | meter = tqdm.tqdm.format_meter( 357 | n=n, total=total, elapsed=elapsed, ncols=ncols, prefix=prefix, ascii=ascii, 358 | unit=unit, unit_scale=unit_scale, rate=rate, bar_format=bar_format, 359 | postfix=postfix, unit_divisor=unit_divisor) 360 | 361 | # get rid of that stupid comma before the postfix 362 | if postfix is not None: 363 | postfix_with_comma = ", %s" % postfix 364 | meter = meter.replace(postfix_with_comma, postfix) 365 | 366 | return meter 367 | 368 | def update(self, n=1): 369 | if self._logging_on_update: 370 | msg = self.__repr__() 371 | logging.logbook(msg) 372 | return super(TqdmToLogger, self).update(n=n) 373 | 374 | def close(self): 375 | if self._logging_on_close and not self._closed: 376 | msg = self.__repr__() 377 | logging.logbook(msg) 378 | self._closed = True 379 | return super(TqdmToLogger, self).close() 380 | 381 | 382 | def tqdm_with_logging(iterable=None, desc=None, total=None, leave=True, 383 | ncols=None, mininterval=0.1, 384 | maxinterval=10.0, miniters=None, ascii=None, disable=False, 385 | unit="it", unit_scale=False, dynamic_ncols=False, 386 | smoothing=0.3, bar_format=None, initial=0, position=None, 387 | postfix=None, 388 | logging_on_close=True, 389 | logging_on_update=False): 390 | 391 | return TqdmToLogger( 392 | iterable=iterable, desc=desc, total=total, leave=leave, 393 | ncols=ncols, mininterval=mininterval, 394 | maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=disable, 395 | unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols, 396 | smoothing=smoothing, bar_format=bar_format, initial=initial, position=position, 397 | postfix=postfix, 398 | logging_on_close=logging_on_close, 399 | logging_on_update=logging_on_update) 400 | 401 | 402 | def cd_dotdot(path_or_filename): 403 | return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), "..")) 404 | 405 | 406 | def cd_dotdotdot(path_or_filename): 407 | return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), "../..")) 408 | 409 | 410 | def cd_dotdotdotdot(path_or_filename): 411 | return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), "../../..")) 412 | 413 | 414 | def 
tensor2numpy(tensor): 415 | if isinstance(tensor, np.ndarray): 416 | return tensor 417 | else: 418 | if isinstance(tensor, torch.autograd.Variable): 419 | tensor = tensor.data 420 | if tensor.dim() == 3: 421 | return tensor.cpu().numpy().transpose([1,2,0]) 422 | else: 423 | return tensor.cpu().numpy().transpose([0,2,3,1]) 424 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pgodet/star_flow/cedb96ff339d11abf71d12d09e794593a742ccce/utils/__init__.py -------------------------------------------------------------------------------- /utils/flow.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import numpy as np 4 | import png 5 | #import matplotlib.colors as cl 6 | 7 | TAG_CHAR = np.array([202021.25], np.float32) 8 | UNKNOWN_FLOW_THRESH = 1e7 9 | 10 | 11 | def write_flow(filename, uv, v=None): 12 | nBands = 2 13 | 14 | if v is None: 15 | assert (uv.ndim == 3) 16 | assert (uv.shape[2] == 2) 17 | u = uv[:, :, 0] 18 | v = uv[:, :, 1] 19 | else: 20 | u = uv 21 | 22 | assert (u.shape == v.shape) 23 | height, width = u.shape 24 | f = open(filename, 'wb') 25 | # write the header 26 | f.write(TAG_CHAR) 27 | np.array(width).astype(np.int32).tofile(f) 28 | np.array(height).astype(np.int32).tofile(f) 29 | # arrange into matrix form 30 | tmp = np.zeros((height, width * nBands)) 31 | tmp[:, np.arange(width) * 2] = u 32 | tmp[:, np.arange(width) * 2 + 1] = v 33 | tmp.astype(np.float32).tofile(f) 34 | f.close() 35 | 36 | 37 | def write_flow_png(filename, uv, v=None, mask=None): 38 | 39 | if v is None: 40 | assert (uv.ndim == 3) 41 | assert (uv.shape[2] == 2) 42 | u = uv[:, :, 0] 43 | v = uv[:, :, 1] 44 | else: 45 | u = uv 46 | 47 | assert (u.shape == v.shape) 48 | 49 | height_img, width_img = u.shape 50 | if mask is None: 51 | valid_mask = np.ones([height_img, width_img]) 52 | else: 53 | valid_mask = mask 54 | 55 | flow_u = np.clip((u * 64 + 2 ** 15), 0.0, 65535.0).astype(np.uint16) 56 | flow_v = np.clip((v * 64 + 2 ** 15), 0.0, 65535.0).astype(np.uint16) 57 | 58 | output = np.stack((flow_u, flow_v, valid_mask), axis=-1) 59 | 60 | with open(filename, 'wb') as f: 61 | writer = png.Writer(width=width_img, height=height_img, bitdepth=16) 62 | writer.write(f, np.reshape(output, (-1, width_img*3))) 63 | 64 | 65 | def flow_to_png(flow_map, max_value=None): 66 | _, h, w = flow_map.shape 67 | rgb_map = np.ones((h, w, 3)).astype(np.float32) 68 | if max_value is not None: 69 | normalized_flow_map = flow_map / max_value 70 | else: 71 | normalized_flow_map = flow_map / (np.abs(flow_map).max()) 72 | rgb_map[:, :, 0] += normalized_flow_map[0] 73 | rgb_map[:, :, 1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1]) 74 | rgb_map[:, :, 2] += normalized_flow_map[1] 75 | return rgb_map.clip(0, 1) 76 | 77 | 78 | 79 | def compute_color(u, v): 80 | """ 81 | compute optical flow color map 82 | :param u: optical flow horizontal map 83 | :param v: optical flow vertical map 84 | :return: optical flow in color code 85 | """ 86 | [h, w] = u.shape 87 | img = np.zeros([h, w, 3]) 88 | nanIdx = np.isnan(u) | np.isnan(v) 89 | u[nanIdx] = 0 90 | v[nanIdx] = 0 91 | 92 | colorwheel = make_color_wheel() 93 | ncols = np.size(colorwheel, 0) 94 | 95 | rad = np.sqrt(u ** 2 + v ** 2) 96 | rad[rad>1] = 1 97 | 98 | a = np.arctan2(-v, -u) / np.pi 99 | 100 | fk = (a + 
1) / 2 * (ncols - 1) + 1
101 | 
102 |     k0 = np.floor(fk).astype(int)
103 | 
104 |     k1 = k0 + 1
105 |     k1[k1 == ncols + 1] = 1
106 |     f = fk - k0
107 | 
108 |     for i in range(0, np.size(colorwheel, 1)):
109 |         tmp = colorwheel[:, i]
110 |         col0 = tmp[k0 - 1] / 255
111 |         col1 = tmp[k1 - 1] / 255
112 |         col = (1 - f) * col0 + f * col1
113 | 
114 |         idx = rad <= 1
115 |         col[idx] = 1 - rad[idx] * (1 - col[idx])
116 |         notidx = np.logical_not(idx)
117 | 
118 |         col[notidx] *= 0.75
119 |         img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
120 | 
121 |     return img
122 | 
123 | 
124 | def make_color_wheel():
125 |     """
126 |     Generate color wheel according to the Middlebury color code
127 |     :return: Color wheel
128 |     """
129 |     RY = 15
130 |     YG = 6
131 |     GC = 4
132 |     CB = 11
133 |     BM = 13
134 |     MR = 6
135 | 
136 |     ncols = RY + YG + GC + CB + BM + MR
137 | 
138 |     colorwheel = np.zeros([ncols, 3])
139 | 
140 |     col = 0
141 | 
142 |     # RY
143 |     colorwheel[0:RY, 0] = 255
144 |     colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
145 |     col += RY
146 | 
147 |     # YG
148 |     colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
149 |     colorwheel[col:col + YG, 1] = 255
150 |     col += YG
151 | 
152 |     # GC
153 |     colorwheel[col:col + GC, 1] = 255
154 |     colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
155 |     col += GC
156 | 
157 |     # CB
158 |     colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
159 |     colorwheel[col:col + CB, 2] = 255
160 |     col += CB
161 | 
162 |     # BM
163 |     colorwheel[col:col + BM, 2] = 255
164 |     colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
165 |     col += BM
166 | 
167 |     # MR
168 |     colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
169 |     colorwheel[col:col + MR, 0] = 255
170 | 
171 |     return colorwheel
172 | 
173 | 
174 | def flow_to_png_middlebury(flow, maxnorm=None):
175 |     """
176 |     Convert flow into Middlebury color code image
177 |     :param flow: optical flow map
178 |     :return: optical flow image in Middlebury color
179 |     """
180 | 
181 |     flow = flow.transpose([1, 2, 0])
182 |     u = flow[:, :, 0]
183 |     v = flow[:, :, 1]
184 | 
185 |     maxu = -999.
186 |     maxv = -999.
187 |     minu = 999.
188 |     minv = 999.
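    # (descriptive note: the block below zeroes out "unknown" flow vectors,
    # i.e. components larger than UNKNOWN_FLOW_THRESH, tracks the value
    # ranges, and normalizes u and v by the largest flow radius, or by
    # maxnorm when given, before color coding)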
189 | 190 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 191 | u[idxUnknow] = 0 192 | v[idxUnknow] = 0 193 | 194 | maxu = max(maxu, np.max(u)) 195 | minu = min(minu, np.min(u)) 196 | 197 | maxv = max(maxv, np.max(v)) 198 | minv = min(minv, np.min(v)) 199 | 200 | if maxnorm is None: 201 | rad = np.sqrt(u ** 2 + v ** 2) 202 | maxrad = max(-1, np.max(rad)) 203 | else: 204 | maxrad = maxnorm 205 | 206 | u = u / (maxrad + np.finfo(float).eps) 207 | v = v / (maxrad + np.finfo(float).eps) 208 | 209 | img = compute_color(u, v) 210 | 211 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 212 | img[idx] = 0 213 | 214 | return np.uint8(img) 215 | -------------------------------------------------------------------------------- /utils/interpolation.py: -------------------------------------------------------------------------------- 1 | ## Portions of Code from, copyright 2018 Jochen Gast 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as tf 8 | 9 | 10 | def _bchw2bhwc(tensor): 11 | return tensor.transpose(1,2).transpose(2,3) 12 | 13 | 14 | def _bhwc2bchw(tensor): 15 | return tensor.transpose(2,3).transpose(1,2) 16 | 17 | 18 | class Meshgrid(nn.Module): 19 | def __init__(self): 20 | super(Meshgrid, self).__init__() 21 | self.width = 0 22 | self.height = 0 23 | self.register_buffer("xx", torch.zeros(1,1)) 24 | self.register_buffer("yy", torch.zeros(1,1)) 25 | self.register_buffer("rangex", torch.zeros(1,1)) 26 | self.register_buffer("rangey", torch.zeros(1,1)) 27 | 28 | def _compute_meshgrid(self, width, height): 29 | torch.arange(0, width, out=self.rangex) 30 | torch.arange(0, height, out=self.rangey) 31 | self.xx = self.rangex.repeat(height, 1).contiguous() 32 | self.yy = self.rangey.repeat(width, 1).t().contiguous() 33 | 34 | def forward(self, width, height): 35 | if self.width != width or self.height != height: 36 | self._compute_meshgrid(width=width, height=height) 37 | self.width = width 38 | self.height = height 39 | return self.xx, self.yy 40 | 41 | 42 | class BatchSub2Ind(nn.Module): 43 | def __init__(self): 44 | super(BatchSub2Ind, self).__init__() 45 | self.register_buffer("_offsets", torch.LongTensor()) 46 | 47 | def forward(self, shape, row_sub, col_sub, out=None): 48 | batch_size = row_sub.size(0) 49 | height, width = shape 50 | ind = row_sub*width + col_sub 51 | torch.arange(batch_size, out=self._offsets) 52 | self._offsets *= (height*width) 53 | 54 | if out is None: 55 | return torch.add(ind, self._offsets.view(-1,1,1)) 56 | else: 57 | torch.add(ind, self._offsets.view(-1,1,1), out=out) 58 | 59 | 60 | class Interp2(nn.Module): 61 | def __init__(self, clamp=False): 62 | super(Interp2, self).__init__() 63 | self._clamp = clamp 64 | self._batch_sub2ind = BatchSub2Ind() 65 | self.register_buffer("_x0", torch.LongTensor()) 66 | self.register_buffer("_x1", torch.LongTensor()) 67 | self.register_buffer("_y0", torch.LongTensor()) 68 | self.register_buffer("_y1", torch.LongTensor()) 69 | self.register_buffer("_i00", torch.LongTensor()) 70 | self.register_buffer("_i01", torch.LongTensor()) 71 | self.register_buffer("_i10", torch.LongTensor()) 72 | self.register_buffer("_i11", torch.LongTensor()) 73 | self.register_buffer("_v00", torch.FloatTensor()) 74 | self.register_buffer("_v01", torch.FloatTensor()) 75 | self.register_buffer("_v10", torch.FloatTensor()) 76 | self.register_buffer("_v11", torch.FloatTensor()) 77 | self.register_buffer("_x", 
torch.FloatTensor()) 78 | self.register_buffer("_y", torch.FloatTensor()) 79 | 80 | def forward(self, v, xq, yq): 81 | batch_size, channels, height, width = v.size() 82 | 83 | # clamp if wanted 84 | if self._clamp: 85 | xq.clamp_(0, width - 1) 86 | yq.clamp_(0, height - 1) 87 | 88 | # ------------------------------------------------------------------ 89 | # Find neighbors 90 | # 91 | # x0 = torch.floor(xq).long(), x0.clamp_(0, width - 1) 92 | # x1 = x0 + 1, x1.clamp_(0, width - 1) 93 | # y0 = torch.floor(yq).long(), y0.clamp_(0, height - 1) 94 | # y1 = y0 + 1, y1.clamp_(0, height - 1) 95 | # 96 | # ------------------------------------------------------------------ 97 | self._x0 = torch.floor(xq).long().clamp(0, width - 1) 98 | self._y0 = torch.floor(yq).long().clamp(0, height - 1) 99 | 100 | self._x1 = torch.add(self._x0, 1).clamp(0, width - 1) 101 | self._y1 = torch.add(self._y0, 1).clamp(0, height - 1) 102 | 103 | # batch_sub2ind 104 | self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00) 105 | self._batch_sub2ind([height, width], self._y0, self._x1, out=self._i01) 106 | self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10) 107 | self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11) 108 | 109 | # reshape 110 | v_flat = _bchw2bhwc(v).contiguous().view(-1, channels) 111 | torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00) 112 | torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01) 113 | torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10) 114 | torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11) 115 | 116 | # local_coords 117 | torch.add(xq, - self._x0.float(), out=self._x) 118 | torch.add(yq, - self._y0.float(), out=self._y) 119 | 120 | # weights 121 | w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1) 122 | w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1) 123 | w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1) 124 | w11 = torch.unsqueeze(self._y * self._x, dim=1) 125 | 126 | def _reshape(u): 127 | return _bhwc2bchw(u.view(batch_size, height, width, channels)) 128 | 129 | # values 130 | values = _reshape(self._v00)*w00 + _reshape(self._v01)*w01 \ 131 | + _reshape(self._v10)*w10 + _reshape(self._v11)*w11 132 | 133 | if self._clamp: 134 | return values 135 | else: 136 | # find_invalid 137 | invalid = ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height)).unsqueeze(dim=1).float() 138 | # maskout invalid 139 | transformed = invalid * torch.zeros_like(values) + (1.0 - invalid)*values 140 | 141 | return transformed 142 | 143 | 144 | class Interp2MaskBinary(nn.Module): 145 | def __init__(self, clamp=False): 146 | super(Interp2MaskBinary, self).__init__() 147 | self._clamp = clamp 148 | self._batch_sub2ind = BatchSub2Ind() 149 | self.register_buffer("_x0", torch.LongTensor()) 150 | self.register_buffer("_x1", torch.LongTensor()) 151 | self.register_buffer("_y0", torch.LongTensor()) 152 | self.register_buffer("_y1", torch.LongTensor()) 153 | self.register_buffer("_i00", torch.LongTensor()) 154 | self.register_buffer("_i01", torch.LongTensor()) 155 | self.register_buffer("_i10", torch.LongTensor()) 156 | self.register_buffer("_i11", torch.LongTensor()) 157 | self.register_buffer("_v00", torch.FloatTensor()) 158 | self.register_buffer("_v01", torch.FloatTensor()) 159 | self.register_buffer("_v10", torch.FloatTensor()) 160 | self.register_buffer("_v11", torch.FloatTensor()) 161 | self.register_buffer("_m00", torch.FloatTensor()) 162 
| self.register_buffer("_m01", torch.FloatTensor()) 163 | self.register_buffer("_m10", torch.FloatTensor()) 164 | self.register_buffer("_m11", torch.FloatTensor()) 165 | self.register_buffer("_x", torch.FloatTensor()) 166 | self.register_buffer("_y", torch.FloatTensor()) 167 | 168 | def forward(self, v, xq, yq, mask): 169 | batch_size, channels, height, width = v.size() 170 | _, channels_mask, _, _ = mask.size() 171 | 172 | if channels_mask != channels: 173 | mask = mask.repeat(1, int(channels/channels_mask), 1, 1) 174 | 175 | # clamp if wanted 176 | if self._clamp: 177 | xq.clamp_(0, width - 1) 178 | yq.clamp_(0, height - 1) 179 | 180 | # ------------------------------------------------------------------ 181 | # Find neighbors 182 | # 183 | # x0 = torch.floor(xq).long(), x0.clamp_(0, width - 1) 184 | # x1 = x0 + 1, x1.clamp_(0, width - 1) 185 | # y0 = torch.floor(yq).long(), y0.clamp_(0, height - 1) 186 | # y1 = y0 + 1, y1.clamp_(0, height - 1) 187 | # 188 | # ------------------------------------------------------------------ 189 | self._x0 = torch.floor(xq).long().clamp(0, width - 1) 190 | self._y0 = torch.floor(yq).long().clamp(0, height - 1) 191 | 192 | self._x1 = torch.add(self._x0, 1).clamp(0, width - 1) 193 | self._y1 = torch.add(self._y0, 1).clamp(0, height - 1) 194 | 195 | # batch_sub2ind 196 | self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00) 197 | self._batch_sub2ind([height, width], self._y0, self._x1, out=self._i01) 198 | self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10) 199 | self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11) 200 | 201 | # reshape 202 | v_flat = _bchw2bhwc(v).contiguous().view(-1, channels) 203 | torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00) 204 | torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01) 205 | torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10) 206 | torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11) 207 | 208 | # reshape 209 | m_flat = _bchw2bhwc(mask).contiguous().view(-1, channels) 210 | torch.index_select(m_flat, dim=0, index=self._i00.view(-1), out=self._m00) 211 | torch.index_select(m_flat, dim=0, index=self._i01.view(-1), out=self._m01) 212 | torch.index_select(m_flat, dim=0, index=self._i10.view(-1), out=self._m10) 213 | torch.index_select(m_flat, dim=0, index=self._i11.view(-1), out=self._m11) 214 | 215 | # local_coords 216 | torch.add(xq, - self._x0.float(), out=self._x) 217 | torch.add(yq, - self._y0.float(), out=self._y) 218 | 219 | # weights 220 | w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1) 221 | w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1) 222 | w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1) 223 | w11 = torch.unsqueeze(self._y * self._x, dim=1) 224 | 225 | def _reshape(u): 226 | return _bhwc2bchw(u.view(batch_size, height, width, channels)) 227 | 228 | # values 229 | values = _reshape(self._m00) * _reshape(self._v00) * w00 + _reshape(self._m01) * _reshape( 230 | self._v01) * w01 + _reshape(self._m10) * _reshape(self._v10) * w10 + _reshape(self._m11) * _reshape( 231 | self._v11) * w11 232 | m_weights = _reshape(self._m00) * w00 + _reshape(self._m01) * w01 + _reshape(self._m10) * w10 + _reshape( 233 | self._m11) * w11 234 | values = values / (m_weights + 1e-12) 235 | invalid_mask = (((1 - m_weights) / (m_weights + 1e-12)) > 0.5)[:, 0:1, :, :] 236 | 237 | if self._clamp: 238 | return values 239 | else: 240 | # find_invalid 241 | invalid 
= ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height) | invalid_mask.squeeze(dim=1)).unsqueeze(dim=1).float()
242 |             transformed = invalid * torch.zeros_like(values) + (1.0 - invalid) * values
243 | 
244 |             return transformed, (1 - invalid_mask).float()
245 | 
246 | 
247 | def resize2D(inputs, size_targets, mode="bilinear"):
248 |     size_inputs = [inputs.size(2), inputs.size(3)]
249 | 
250 |     if size_inputs == size_targets:
251 |         return inputs  # nothing to do
252 |     elif any(t < i for t, i in zip(size_targets, size_inputs)):
253 |         resized = tf.adaptive_avg_pool2d(inputs, size_targets)  # downscaling
254 |     else:
255 |         resized = tf.interpolate(inputs, size=size_targets, mode=mode)  # upsampling (tf.upsample was deprecated in favor of tf.interpolate)
256 | 
257 |     # note: only the spatial size changes; values are not rescaled
258 |     return resized
259 | 
260 | 
261 | def resize2D_as(inputs, output_as, mode="bilinear"):
262 |     size_targets = [output_as.size(2), output_as.size(3)]
263 |     return resize2D(inputs, size_targets, mode=mode)
264 | 
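# A minimal usage sketch (hypothetical shapes), e.g. bringing a flow field
# down to the resolution of a coarser feature map:
#
#   flow = torch.randn(4, 2, 128, 256)      # B x 2 x H x W
#   feat = torch.randn(4, 32, 32, 64)
#   flow_small = resize2D_as(flow, feat)    # -> 4 x 2 x 32 x 64
#
# Rescaling the flow values to the new resolution is left to the caller.
--------------------------------------------------------------------------------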